class Mnist::Loader

Constants

IMAGE_FILE_MAGIC
LABEL_FILE_MAGIC

Attributes

filename[R]

Public Class Methods

new(filename)
# File lib/mnist.rb, line 16
# Builds a loader for a single MNIST data file.
#
# No I/O happens here; the file is opened lazily on the first read.
#
# @param filename [String] path of the file to read
def initialize(filename)
  @filename = filename
end

Public Instance Methods

load_images()
# File lib/mnist.rb, line 22
# Reads a whole image file: header first, then every image body.
#
# The calls below consume the input stream in order (magic, count,
# dimensions, pixels), so their sequence must not change.
#
# @return [Array] +[nrows, ncols, images]+, where +images+ holds one raw
#   pixel string per image, as returned by read_image
# @raise [InvalidMagic] if the file does not start with IMAGE_FILE_MAGIC
def load_images
  check_magic(IMAGE_FILE_MAGIC)
  count = read_total_count
  rows, cols = read_image_size
  [rows, cols, count.times.map { read_image(rows, cols) }]
end
load_labels()
# File lib/mnist.rb, line 32
# Reads a whole label file: magic number, label count, then the labels.
#
# The count is read before the labels because both come from the same
# sequential stream.
#
# @return [Array<Integer>] the labels as returned by read_labels
# @raise [InvalidMagic] if the file does not start with LABEL_FILE_MAGIC
def load_labels
  check_magic(LABEL_FILE_MAGIC)
  read_labels(read_total_count)
end

Private Instance Methods

check_magic(expected_magic)
# File lib/mnist.rb, line 40
# Consumes the magic number at the current stream position and checks it.
#
# @param expected_magic [Integer] value the header must contain
# @return [nil] when the magic number matches
# @raise [InvalidMagic] when the value read differs from +expected_magic+
def check_magic(expected_magic)
  found = read_magic
  return if found == expected_magic

  raise InvalidMagic, "Expected #{expected_magic}, but #{found} is given"
end
input()
# File lib/mnist.rb, line 73
# Returns the gzip reader for +filename+, opening it on first use.
#
# The reader is memoized, so all read_* helpers share one sequential stream.
#
# NOTE(review): nothing in this view closes the reader; confirm whether the
# file handle is expected to live as long as the loader.
def input
  return @input if @input

  @input = Zlib::GzipReader.open(filename)
end
read_image(nrows, ncols)
# File lib/mnist.rb, line 69
# Reads the raw pixel bytes of one image (one byte per pixel).
#
# @param nrows [Integer] number of pixel rows
# @param ncols [Integer] number of pixel columns
# @return [String] exactly +nrows * ncols+ bytes
# @raise [EOFError] if the stream ends before a full image is available
def read_image(nrows, ncols)
  size = nrows * ncols
  pixels = input.read(size)
  # IO#read gives nil at EOF and a short string on a truncated file; without
  # this check a partial or nil "image" would be returned silently.
  got = pixels.to_s.bytesize
  raise EOFError, "expected #{size} bytes, got #{got}" if got < size

  pixels
end
read_image_size()
# File lib/mnist.rb, line 63
# Reads the image dimensions stored in the header.
#
# @return [Array<Integer>] +[nrows, ncols]+, as two 32-bit big-endian values
def read_image_size
  read_uint32(2)
end
read_labels(n=1)
Alias for: read_uint8
read_magic()
# File lib/mnist.rb, line 55
# Reads the header's magic number (a single 32-bit big-endian value).
#
# @return [Integer, nil] the value, or nil if read_uint32 yields no values
def read_magic
  read_uint32(1).first
end
read_total_count()
# File lib/mnist.rb, line 59
# Reads the item count from the header (a single 32-bit big-endian value).
#
# @return [Integer, nil] the count, or nil if read_uint32 yields no values
def read_total_count
  read_uint32(1).first
end
read_uint32(n=1)
# File lib/mnist.rb, line 51
# Reads +n+ unsigned 32-bit big-endian integers from the input stream.
#
# @param n [Integer] how many integers to read (default 1)
# @return [Array<Integer>] exactly +n+ values
# @raise [EOFError] if fewer than +4 * n+ bytes remain in the stream
def read_uint32(n = 1)
  byte_count = 4 * n
  bytes = input.read(byte_count)
  # IO#read returns nil at EOF and a short string on truncation. Previously
  # this surfaced as NoMethodError on nil, or silently produced fewer values.
  got = bytes.to_s.bytesize
  raise EOFError, "expected #{byte_count} bytes, got #{got}" if got < byte_count

  bytes.unpack('N*')
end
read_uint8(n=1)
# File lib/mnist.rb, line 47
# Reads +n+ unsigned 8-bit integers from the input stream.
#
# This method is also exposed as read_labels.
#
# @param n [Integer] how many bytes to read (default 1)
# @return [Array<Integer>] exactly +n+ values in 0..255
# @raise [EOFError] if fewer than +n+ bytes remain in the stream
def read_uint8(n = 1)
  bytes = input.read(n)
  # Guard against nil (EOF) and short reads from a truncated file, which
  # would otherwise raise NoMethodError or silently drop labels.
  got = bytes.to_s.bytesize
  raise EOFError, "expected #{n} bytes, got #{got}" if got < n

  bytes.unpack('C*')
end
Also aliased as: read_labels