class TorchVision::Datasets::MNIST
Public Class Methods
new(root, train: true, download: false, transform: nil, target_transform: nil)
click to toggle source
Calls superclass method
# File lib/torchvision/datasets/mnist.rb, line 5
#
# Loads one split of MNIST from the processed files on disk.
#
# root              - base directory, forwarded to the superclass
# train:            - true selects training_file, false selects test_file
# download:         - when true, #download runs before the existence check
# transform:        - forwarded to the superclass (#[] applies @transform to images)
# target_transform: - forwarded to the superclass (#[] applies @target_transform to targets)
#
# Raises Error when the processed files are still missing after the
# optional download.
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train

  # `download` names the keyword argument in this scope, so the method
  # call needs an explicit receiver.
  self.download if download
  raise Error, "Dataset not found. You can use download: true to download it" unless check_exists

  split_file = @train ? training_file : test_file
  @data, @targets = Torch.load(File.join(processed_folder, split_file))
end
Public Instance Methods
[](index)
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 23
#
# Returns the sample at +index+ as a two-element array [image, target].
# The image comes from Utils.image_from_array(@data[index]); the target is
# @targets[index].item. @transform / @target_transform are applied only
# when they are set.
def [](index)
  raw_image = @data[index]
  label = @targets[index].item
  image = Utils.image_from_array(raw_image)

  image = @transform ? @transform.call(image) : image
  label = @target_transform ? @target_transform.call(label) : label

  [image, label]
end
check_exists()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 43
#
# True only when both the processed training file and the processed test
# file exist inside processed_folder; stops at the first missing one.
def check_exists
  [training_file, test_file].all? do |file_name|
    File.exist?(File.join(processed_folder, file_name))
  end
end
download()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 48
#
# Fetches and preprocesses MNIST unless check_exists already reports both
# processed files.
#
# 1. creates raw_folder and processed_folder (FileUtils.mkdir_p);
# 2. downloads every entry of #resources into raw_folder, naming each file
#    after the last "/"-separated segment of its URL and forwarding the
#    entry's sha256 to download_file (presumably verified there — confirm);
# 3. decodes the four raw archives with unpack_mnist, skipping a 16-byte
#    header for image files and an 8-byte header for label files;
# 4. writes [images, labels] for each split with Torch.save.
#
# Prints "Processing..." before step 3 and "Done!" at the end.
def download
  return if check_exists

  [raw_folder, processed_folder].each { |dir| FileUtils.mkdir_p(dir) }

  resources.each do |entry|
    url = entry[:url]
    download_file(url, download_root: raw_folder, filename: url.split("/").last, sha256: entry[:sha256])
  end

  puts "Processing..."

  # All four archives are decoded before anything is saved.
  splits = [
    [training_file, [
      unpack_mnist("train-images-idx3-ubyte", 16, [60000, 28, 28]),
      unpack_mnist("train-labels-idx1-ubyte", 8, [60000])
    ]],
    [test_file, [
      unpack_mnist("t10k-images-idx3-ubyte", 16, [10000, 28, 28]),
      unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
    ]]
  ]
  splits.each do |file_name, tensors|
    Torch.save(tensors, File.join(processed_folder, file_name))
  end

  puts "Done!"
end
processed_folder()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 39
#
# Directory holding the preprocessed split files:
# <@root>/<unqualified class name>/processed (e.g. ".../MNIST/processed").
def processed_folder
  dataset_name = self.class.name.split("::").last
  File.join(@root, dataset_name, "processed")
end
raw_folder()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 35
#
# Directory receiving the downloaded archives:
# <@root>/<unqualified class name>/raw (e.g. ".../MNIST/raw").
def raw_folder
  dataset_name = self.class.name.split("::").last
  File.join(@root, dataset_name, "raw")
end
size()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 19 def size @data.size(0) end
Private Instance Methods
resources()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 78
#
# Download manifest for the four raw MNIST archives. Each entry is a hash
# with :url and :sha256; #download forwards both to download_file.
#
# NOTE(review): all URLs are plain HTTP; integrity presumably rests on the
# sha256 check in download_file — confirm.
def resources
  [
    ["http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"],
    ["http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"],
    ["http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"],
    ["http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"]
  ].map { |url, digest| {url: url, sha256: digest} }
end
test_file()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 103 def test_file "test.pt" end
training_file()
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 99 def training_file "training.pt" end
unpack_mnist(path, offset, shape)
click to toggle source
# File lib/torchvision/datasets/mnist.rb, line 107
#
# Decompresses one raw archive from raw_folder into a uint8 tensor.
#
# path   - archive base name without the ".gz" suffix
#          (e.g. "train-images-idx3-ubyte")
# offset - number of leading header bytes to discard (#download passes 16
#          for image files and 8 for label files)
# shape  - dimensions handed to Numo::UInt8.from_string for the payload
#
# Returns Torch.tensor of the remaining bytes reshaped to +shape+.
def unpack_mnist(path, offset, shape)
  gz_path = File.join(raw_folder, "#{path}.gz")
  # The block form opens the file in binary mode and closes both the
  # inflate stream and the file on exit; the value of the block is returned.
  # (Previously GzipReader.new(f) was never closed, leaving the zlib stream
  # to be reclaimed by GC.)
  Zlib::GzipReader.open(gz_path) do |gz|
    gz.read(offset) # discard the header bytes
    Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
  end
end