class Tensorflow::Datasets::Images::Mnist

Constants

BASE_URL

Public Instance Methods

dataset(images_file, labels_file)
# File lib/datasets/images/mnist.rb, line 32
# Builds a zipped (image, label) dataset from one pair of gzipped MNIST files.
#
# Both files are fetched from BASE_URL through Datasets::DownloadManager and
# read as fixed-length records:
# - images: 28 * 28 bytes per record, after a 16-byte header;
# - labels: 1 byte per record, after an 8-byte header.
# Header sizes are taken as written here; presumably they are the IDX headers
# implied by the "idx3"/"idx1" file names (TODO confirm).
#
# @param images_file [String] name of the gzipped image file under BASE_URL
# @param labels_file [String] name of the gzipped label file under BASE_URL
# @return [Data::ZipDataset] decoded images zipped with decoded labels
def dataset(images_file, labels_file)
  downloader = Datasets::DownloadManager.new
  urls = [images_file, labels_file].map { |file_name| "#{BASE_URL}/#{file_name}" }
  downloaded = downloader.download(urls)

  # NOTE(review): decode_image / decode_label are called here with no
  # arguments, and their results are handed to map_func. As written, each
  # takes one argument, so this presumably relies on something outside this
  # view (e.g. a decorator) making the zero-arg call return a mappable
  # function. Confirm.
  image_records = Data::FixedLengthRecordDataset.new(
    downloaded.first.path, 28 * 28,
    header_bytes: 16, compression_type: 'GZIP'
  )
  images = image_records.map_func(self.decode_image)

  label_records = Data::FixedLengthRecordDataset.new(
    downloaded.last.path, 1,
    header_bytes: 8, compression_type: 'GZIP'
  )
  labels = label_records.map_func(self.decode_label)

  Data::ZipDataset.new(images, labels)
end
decode_image(image)
# File lib/datasets/images/mnist.rb, line 16
# Decodes one raw image record into a flat, normalized float tensor.
#
# Steps:
# 1. Read the record's bytes as uint8.
# 2. Cast to float32.
# 3. Reshape to a 28 * 28 = 784-element vector, matching the record length
#    #dataset reads per image.
# 4. Scale into [0.0, 1.0].
#
# @param image raw record bytes (presumably a string tensor produced by
#   FixedLengthRecordDataset; confirm)
# @return float tensor of shape [784] with values in [0.0, 1.0]
def decode_image(image)
  # Qualify IO explicitly, as decode_label does. A bare `IO` reaches
  # TensorFlow's IO module only via lexical module nesting; otherwise it
  # resolves to Ruby's core ::IO, which has no decode_raw.
  pixels = Tf::IO.decode_raw(image, Tf.uint8)
  pixels = Tf.cast(pixels, Tf.float32)
  pixels = Tf.reshape(pixels, [28 * 28])
  # uint8 values span [0, 255]; dividing by 255.0 maps them into [0.0, 1.0].
  pixels / 255.0
end
decode_label(label)
# File lib/datasets/images/mnist.rb, line 25
# Decodes one raw label record into a scalar int32 tensor.
#
# @param label raw record bytes (presumably a one-byte string tensor produced
#   by FixedLengthRecordDataset; confirm)
# @return int32 tensor of shape [] (a scalar)
def decode_label(label)
  # bytes -> uint8 tensor -> reshaped to a scalar (shape []) -> cast to int32
  Tf.cast(Tf.reshape(Tf::IO.decode_raw(label, Tf.uint8), []), Tf.int32)
end
test()
# File lib/datasets/images/mnist.rb, line 48
# Builds the (image, label) dataset from the "t10k" image and label files.
#
# @return whatever #dataset returns for that file pair
def test
  images_file = 't10k-images-idx3-ubyte.gz'
  labels_file = 't10k-labels-idx1-ubyte.gz'
  dataset(images_file, labels_file)
end
train()
# File lib/datasets/images/mnist.rb, line 44
# Builds the (image, label) dataset from the "train" image and label files.
#
# @return whatever #dataset returns for that file pair
def train
  images_file = 'train-images-idx3-ubyte.gz'
  labels_file = 'train-labels-idx1-ubyte.gz'
  dataset(images_file, labels_file)
end