class TorchVision::Datasets::CIFAR10
Public Class Methods
new(root, train: true, download: false, transform: nil, target_transform: nil)
click to toggle source
https://www.cs.toronto.edu/~kriz/cifar.html
Calls superclass method
# File lib/torchvision/datasets/cifar10.rb, line 6
# Loads the training or test batches found under +root+ into memory.
#
# root             - directory containing (or receiving) the dataset folder;
#                    forwarded to the superclass (presumably stored as @root --
#                    confirm against the superclass).
# train            - when truthy the files from train_list are read, otherwise
#                    the files from test_list.
# download         - when truthy, #download is invoked before the integrity check.
# transform        - forwarded to the superclass; applied to images in #[].
# target_transform - forwarded to the superclass; applied to targets in #[].
#
# Raises Error when _check_integrity fails for the expected batch files.
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)

  @train = train

  # `download` is shadowed by the keyword argument here, so the method has to
  # be called through an explicit receiver.
  self.download if download

  unless _check_integrity
    raise Error, "Dataset not found or corrupted. You can use download: true to download it"
  end

  downloaded_list = @train ? train_list : test_list

  # Raw bytes are accumulated in binary strings and converted once at the end.
  @data = String.new
  @targets = String.new
  downloaded_list.each do |file|
    file_path = File.join(@root, base_folder, file[:filename])
    File.open(file_path, "rb") do |f|
      until f.eof?
        # Skip one leading byte per record when the dataset has multiple labels.
        f.read(1) if multiple_labels?
        @targets << f.read(1)
        # 3072 = 3 * 32 * 32 bytes, matching the reshape below.
        @data << f.read(3072)
      end
    end
  end

  # One unsigned byte per target.
  @targets = @targets.unpack("C*")
  # TODO switch i to -1 when Numo supports it
  @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
  # Move the size-3 axis (presumably the colour channels) to the end: N x 32 x 32 x 3.
  @data = @data.transpose(0, 2, 3, 1)
end
Public Instance Methods
[](index)
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 42
# Returns the sample at +index+ as a two-element array [image, target].
#
# The image is built with Utils.image_from_array and then passed through
# @transform when one is set; the target is passed through @target_transform
# when one is set.
def [](index)
  # TODO remove trues when Numo supports it
  pixels = @data[index, true, true, true]
  label = @targets[index]

  image = Utils.image_from_array(pixels)
  image = @transform.call(image) if @transform
  label = @target_transform.call(label) if @target_transform

  [image, label]
end
_check_integrity()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 55
# Returns true when every file listed in train_list and test_list passes
# check_integrity against its recorded sha256, false as soon as one fails.
def _check_integrity
  dataset_root = @root
  batches = train_list + test_list

  # all? stops at the first failing entry, so later files are not checked.
  batches.all? do |entry|
    batch_path = File.join(dataset_root, base_folder, entry[:filename])
    check_integrity(batch_path, entry[:sha256])
  end
end
download()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 64
# Fetches the dataset archive into @root and unpacks it there.
#
# Does nothing except print a notice (and return nil) when _check_integrity
# already succeeds. Otherwise the archive is fetched with download_file and
# extracted into @root.
def download
  return puts("Files already downloaded and verified") if _check_integrity

  download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

  archive_path = File.join(@root, filename)
  File.open(archive_path, "rb") do |archive|
    # Gem::Package is used only for its tar.gz extraction helper.
    Gem::Package.new("").extract_tar_gz(archive, @root)
  end
end
size()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 38
# Returns the number of samples, i.e. the length of the first axis of @data.
def size
  dims = @data.shape
  dims[0]
end
Private Instance Methods
base_folder()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 80
# Name of the directory (relative to @root) that batch filenames are joined to.
def base_folder
  "cifar-10-batches-bin"
end
filename()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 88
# File name under which the downloaded archive is stored inside @root.
def filename
  "cifar-10-binary.tar.gz"
end
multiple_labels?()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 112
# Always false for this dataset. When a dataset answers true, the loader in
# initialize discards one extra leading byte from every record.
def multiple_labels?
  false
end
test_list()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 106
# Files making up the test split, each with the SHA-256 it is verified against.
def test_list
  test_batch = {
    filename: "test_batch.bin",
    sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"
  }

  [test_batch]
end
tgz_sha256()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 92
# SHA-256 passed to download_file for the downloaded archive.
def tgz_sha256
  "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
end
train_list()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 96
# Files making up the training split, each with the SHA-256 it is verified
# against. Batches are numbered from 1 in the order their checksums appear.
def train_list
  checksums = [
    "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5",
    "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed",
    "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67",
    "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48",
    "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"
  ]

  checksums.each_with_index.map do |checksum, position|
    {filename: "data_batch_#{position + 1}.bin", sha256: checksum}
  end
end
url()
click to toggle source
# File lib/torchvision/datasets/cifar10.rb, line 84
# Location the dataset archive is downloaded from.
def url
  "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
end