class Tensorflow::Datasets::Images::Mnist
Constants
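- BASE_URL — base URL prefix from which #dataset downloads the gzipped image and label files (its value is not shown on this page)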
Public Instance Methods
dataset(images_file, labels_file)
click to toggle source
# File lib/datasets/images/mnist.rb, line 32
# Downloads a pair of gzipped MNIST files from BASE_URL and builds a dataset
# that zips decoded images with decoded labels.
#
# @param images_file [String] image archive file name, appended to BASE_URL
# @param labels_file [String] label archive file name, appended to BASE_URL
# @return [Data::ZipDataset] dataset pairing each decoded image with its label
def dataset(images_file, labels_file)
  download_manager = Datasets::DownloadManager.new
  urls = ["#{BASE_URL}/#{images_file}", "#{BASE_URL}/#{labels_file}"]
  # NOTE(review): assumes download returns resources in the same order as urls
  # (first = images, last = labels) — confirm against DownloadManager.
  resources = download_manager.download(urls)

  # Fixed-length records: 28 * 28 bytes per image after a 16-byte header,
  # 1 byte per label after an 8-byte header.
  image_records = Data::FixedLengthRecordDataset.new(resources.first.path, 28 * 28,
                                                     header_bytes: 16, compression_type: 'GZIP')
  label_records = Data::FixedLengthRecordDataset.new(resources.last.path, 1,
                                                     header_bytes: 8, compression_type: 'GZIP')

  # Pass the decoders as Method objects. Writing `self.decode_image` would call
  # decode_image right here with zero arguments and raise ArgumentError.
  images = image_records.map_func(method(:decode_image))
  labels = label_records.map_func(method(:decode_label))

  Data::ZipDataset.new(images, labels)
end
decode_image(image)
click to toggle source
# File lib/datasets/images/mnist.rb, line 16
# Decodes one raw image record into a flat float32 tensor scaled to [0.0, 1.0].
#
# @param image [Tensor] raw bytes of a single image record (tf.string)
# @return [Tensor] float32 tensor reshaped to [28 * 28] (784 elements), values divided by 255.0
def decode_image(image)
  # Reference Tf::IO explicitly, as decode_label does. A bare `IO` may resolve
  # to Ruby's core ::IO, which has no decode_raw.
  image = Tf::IO.decode_raw(image, Tf.uint8)
  image = Tf.cast(image, Tf.float32)
  image = Tf.reshape(image, [28 * 28])
  # Normalize from [0, 255] to [0.0, 1.0]
  image / 255.0
end
decode_label(label)
click to toggle source
# File lib/datasets/images/mnist.rb, line 25
# Converts one raw label record into a scalar int32 tensor.
#
# @param label [Tensor] raw bytes of a single label record (tf.string)
# @return [Tensor] rank-0 int32 tensor holding the label value
def decode_label(label)
  raw_bytes = Tf::IO.decode_raw(label, Tf.uint8) # tf.string -> [Tf.uint8]
  scalar = Tf.reshape(raw_bytes, [])             # reshape to a rank-0 (scalar) tensor
  Tf.cast(scalar, Tf.int32)
end
test()
click to toggle source
# File lib/datasets/images/mnist.rb, line 48
# Builds the dataset for the t10k image/label archive pair.
#
# @return the dataset produced by #dataset for the t10k files
def test
  images_archive = 't10k-images-idx3-ubyte.gz'
  labels_archive = 't10k-labels-idx1-ubyte.gz'
  dataset(images_archive, labels_archive)
end
train()
click to toggle source
# File lib/datasets/images/mnist.rb, line 44
# Builds the dataset for the train image/label archive pair.
#
# @return the dataset produced by #dataset for the train files
def train
  images_archive = 'train-images-idx3-ubyte.gz'
  labels_archive = 'train-labels-idx1-ubyte.gz'
  dataset(images_archive, labels_archive)
end