class TorchVision::Datasets::MNIST

Public Class Methods

new(root, train: true, download: false, transform: nil, target_transform: nil)

Loads the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/).

Calls the superclass initializer.
# File lib/torchvision/datasets/mnist.rb, line 5
# Builds the dataset for one split, optionally downloading it first.
#
# root              - base directory, forwarded to the superclass
# train:            - load the training split when true, the test split otherwise
# download:         - run #download before loading
# transform:        - forwarded to the superclass (presumably stored as
#                     @transform and applied to images in #[] — confirm)
# target_transform: - forwarded to the superclass (presumably stored as
#                     @target_transform and applied to labels in #[] — confirm)
#
# Raises Error when the processed split files are not present.
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train

  # The `download` keyword shadows the method, so the receiver is explicit.
  self.download if download
  raise Error, "Dataset not found. You can use download: true to download it" unless check_exists

  split_file =
    if @train
      training_file
    else
      test_file
    end
  @data, @targets = Torch.load(File.join(processed_folder, split_file))
end

Public Instance Methods

[](index)
# File lib/torchvision/datasets/mnist.rb, line 23
# Returns [image, target] for the sample at `index`.
#
# The image is produced by Utils.image_from_array from the @data entry and
# then passed through @transform when one is set. The target is the scalar
# from `@targets[index].item`, passed through @target_transform when set.
def [](index)
  raw = @data[index]
  label = @targets[index].item

  image = Utils.image_from_array(raw)
  image = @transform ? @transform.call(image) : image
  label = @target_transform ? @target_transform.call(label) : label

  [image, label]
end
check_exists()
# File lib/torchvision/datasets/mnist.rb, line 43
# True only when both the processed training file and the processed test
# file exist in #processed_folder.
def check_exists
  [training_file, test_file].all? do |name|
    File.exist?(File.join(processed_folder, name))
  end
end
download()
# File lib/torchvision/datasets/mnist.rb, line 48
# Fetches the raw archives listed in #resources into #raw_folder. It then
# converts them into the processed training and test files under
# #processed_folder. Does nothing when #check_exists is already true.
def download
  return if check_exists

  [raw_folder, processed_folder].each { |dir| FileUtils.mkdir_p(dir) }

  resources.each do |res|
    url = res[:url]
    download_file(url, download_root: raw_folder, filename: url.split("/").last, sha256: res[:sha256])
  end

  puts "Processing..."

  # The offsets skip the IDX headers: 16 bytes for the image files and
  # 8 bytes for the label files.
  splits = {
    training_file => [
      unpack_mnist("train-images-idx3-ubyte", 16, [60000, 28, 28]),
      unpack_mnist("train-labels-idx1-ubyte", 8, [60000])
    ],
    test_file => [
      unpack_mnist("t10k-images-idx3-ubyte", 16, [10000, 28, 28]),
      unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
    ]
  }
  splits.each { |file, set| Torch.save(set, File.join(processed_folder, file)) }

  puts "Done!"
end
processed_folder()
# File lib/torchvision/datasets/mnist.rb, line 39
# Directory for the converted files: <root>/<unqualified class name>/processed.
def processed_folder
  dataset_name = self.class.name.rpartition("::").last
  File.join(@root, dataset_name, "processed")
end
raw_folder()
# File lib/torchvision/datasets/mnist.rb, line 35
# Directory for the downloaded archives: <root>/<unqualified class name>/raw.
def raw_folder
  dataset_name = self.class.name.rpartition("::").last
  File.join(@root, dataset_name, "raw")
end
size()
# File lib/torchvision/datasets/mnist.rb, line 19
# Number of samples in the loaded split: the length of dimension 0 of @data.
def size
  sample_dim = 0
  @data.size(sample_dim)
end

Private Instance Methods

resources()
# File lib/torchvision/datasets/mnist.rb, line 78
# Download URLs and SHA-256 checksums of the four raw MNIST archives.
#
# Returns an Array of Hashes with :url and :sha256 keys.
#
# The archives are fetched over HTTPS from the S3 mirror that torchvision
# also lists first. The original host (yann.lecun.com) is plain HTTP and
# frequently refuses these downloads. Both mirrors serve the same files, so
# the checksums are unchanged. #download uses each URL's basename as the
# local filename, and #unpack_mnist reads those names back. The filenames
# below must therefore stay exactly as they are.
def resources
  mirror = "https://ossci-datasets.s3.amazonaws.com/mnist/"
  {
    "train-images-idx3-ubyte.gz" => "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609",
    "train-labels-idx1-ubyte.gz" => "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c",
    "t10k-images-idx3-ubyte.gz" => "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6",
    "t10k-labels-idx1-ubyte.gz" => "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
  }.map { |file, sha256| {url: "#{mirror}#{file}", sha256: sha256} }
end
test_file()
# File lib/torchvision/datasets/mnist.rb, line 103
# Filename (inside #processed_folder) of the processed test split.
def test_file
  'test.pt'
end
training_file()
# File lib/torchvision/datasets/mnist.rb, line 99
# Filename (inside #processed_folder) of the processed training split.
def training_file
  'training.pt'
end
unpack_mnist(path, offset, shape)
# File lib/torchvision/datasets/mnist.rb, line 107
# Decompresses <raw_folder>/<path>.gz, skips its header and returns the
# remaining bytes as a tensor built from a Numo::UInt8 array of `shape`.
#
# path   - archive basename without the ".gz" suffix
# offset - number of leading header bytes to discard
# shape  - dimensions passed to Numo::UInt8.from_string for the payload
def unpack_mnist(path, offset, shape)
  path = File.join(raw_folder, "#{path}.gz")
  # The block form closes the GzipReader (and its underlying file) even if
  # decoding raises. The previous version closed only the File and left the
  # reader's zlib stream to the garbage collector.
  Zlib::GzipReader.open(path) do |gz|
    gz.read(offset)
    Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
  end
end