class Torchrb::ModelBase
Constants
- REQUIRED_OPTIONS
Public Class Methods
error_rate()
click to toggle source
# File lib/torchrb/model_base.rb, line 33
# Reports the error rate currently held by the torch wrapper.
#
# @return the value of +torch.error_rate+, passed through unchanged
def error_rate
  torch.error_rate
end
predict(sample)
click to toggle source
# File lib/torchrb/model_base.rb, line 59
# Asks the torch wrapper for a prediction on +sample+.
#
# The configured +network_storage_path+ is handed over as the second
# argument of +torch.predict+.
#
# @param sample the sample to predict; forwarded untouched
# @return the value of +torch.predict+
def predict(sample)
  torch.predict(sample, network_storage_path)
end
progress_callback(progress: nil, message: nil, error_rate: Float::NAN)
click to toggle source
# File lib/torchrb/model_base.rb, line 5
# Hook for progress notifications; concrete models are expected to override it.
#
# This base version accepts the keyword arguments and always raises.
#
# @param progress   defaults to nil
# @param message    defaults to nil
# @param error_rate defaults to Float::NAN
# @raise [NotImplementedError] always
def progress_callback(progress: nil, message: nil, error_rate: Float::NAN)
  raise NotImplementedError, "Implement this method in your Model"
end
setup_nn(options={})
click to toggle source
# File lib/torchrb/model_base.rb, line 9
# Configures the model class for training and prediction.
#
# Steps, in order:
# 1. validates +options+ through +check_options+;
# 2. merges +options+ over the defaults listed below and, for every resulting
#    key, declares a reader via +cattr_reader+ and stores the value in the
#    class variable of the same name;
# 3. declares a +torch+ reader whose block builds +Torchrb::Torch.new+ from the
#    caller-supplied +options+ (not the merged settings);
# 4. loads the net and trainer extensions and remembers the option hashes they
#    return in +@net_options+ / +@trainer_options+.
#
# @param options [Hash] overrides for the defaults; any extra keys are also
#   exposed as readers
# @return the value of the last +load_extension+ call
def setup_nn(options = {})
  check_options(options)

  settings = {
    net: Torchrb::NN::Basic,
    trainer: Torchrb::NN::TrainerDefault,
    tensor_type: "DoubleTensor",
    dimensions: [0],
    classes: [],
    dataset_split: [80, 10, 10],
    normalize: false,
    enable_cuda: false,
    auto_store_trained_network: true,
    network_storage_path: "tmp/cache/torchrb",
    debug: false,
  }.merge(options)

  settings.each do |option, value|
    cattr_reader(option)
    class_variable_set(:"@@#{option}", value)
  end

  cattr_reader(:torch) { Torchrb::Torch.new options }

  # Fall back to the merged defaults when the caller did not name a net or
  # trainer. Reading only the raw options would pass nil to load_extension
  # (and on to `extend`), so the declared defaults would never apply.
  @net_options = load_extension(options[:net] || settings[:net])
  @trainer_options = load_extension(options[:trainer] || settings[:trainer])
end
train()
click to toggle source
# File lib/torchrb/model_base.rb, line 37
# Runs one complete training pass and returns the resulting error rate.
#
# The steps below run strictly in this order; +progress_callback+ is told
# about each phase through its +message+ keyword.
#
# @return the value of +torch.error_rate+ after training
def train
  # Phase 1: data.
  progress_callback(message: 'Loading data')
  load_model_data

  # Phase 2: wire the callback and build the network and trainer.
  torch.iteration_callback = method(:progress_callback)
  define_nn(@net_options)
  define_trainer(@trainer_options)
  if enable_cuda
    torch.cudify
  end

  # Phase 3: training.
  progress_callback(message: 'Start training')
  torch.train
  progress_callback(message: 'Done')

  # Phase 4: reporting, optional persistence and the optional model hook.
  torch.print_results
  if auto_store_trained_network
    torch.store_network(network_storage_path)
  end
  if respond_to?(:after_training)
    after_training
  end

  torch.error_rate
end
Private Class Methods
check_options(options)
click to toggle source
# File lib/torchrb/model_base.rb, line 65
# Verifies that every key named in REQUIRED_OPTIONS is present in +options+.
#
# Only key presence is checked; a key mapped to nil still counts as present.
#
# @param options [Hash] the options handed to +setup_nn+
# @raise [RuntimeError] naming the first required key that is missing
# @return REQUIRED_OPTIONS when nothing is missing
def check_options(options)
  REQUIRED_OPTIONS.each do |name|
    next if options.has_key?(name)

    raise "Option '#{name}' is required."
  end
end
load_dataset(set_name, collection)
click to toggle source
# File lib/torchrb/model_base.rb, line 96
# Builds a Torchrb::DataSet for +collection+, loads it while reporting
# progress, and normalizes it when that applies.
#
# Each time the data set's +load+ yields, +@progress+ grows by
# 0.333 / collection.size and is reported through +progress_callback+.
#
# @param set_name   name of the set; its humanized string form goes into the
#   progress message
# @param collection the elements of the set; must respond to +size+
# @return the value of +normalize!+ when normalization ran, otherwise nil
def load_dataset(set_name, collection)
  progress_callback(
    progress: @progress,
    message: "Loading #{set_name.to_s.humanize} with #{collection.size} element(s)."
  )

  dataset = Torchrb::DataSet.new(set_name, self, collection)
  dataset.load do
    @progress += 0.333 / collection.size
    progress_callback(progress: @progress)
  end

  # Normalization runs only when enabled and the data set is the train set.
  return unless normalize && dataset.is_trainset?

  dataset.normalize!
end
load_extension(extension)
click to toggle source
# File lib/torchrb/model_base.rb, line 86
# Mixes an extension module into the receiver and returns its options.
#
# +extension+ comes in two shapes:
# * a module: it is extended and an empty hash is returned;
# * a Hash: its first key is extended and its values are merged pairwise
#   into the returned options.
#
# @param extension [Module, Hash] the extension, optionally with options
# @return [Hash] the options belonging to the extension
def load_extension(extension)
  unless extension.is_a?(Hash)
    extend extension
    return {}
  end

  extension_module = extension.keys.first
  extend extension_module
  extension.values.inject { |merged, opts| merged.merge(opts) }
end
load_model_data()
click to toggle source
# File lib/torchrb/model_base.rb, line 71
# Shuffles the ids of +data_model+, cuts them into train / test / validation
# portions according to +dataset_split+ (percentages), and loads each portion
# through +load_dataset+.
#
# Resets +@progress+ to 0 before loading.
#
# @raise [RuntimeError] when the receiver lacks +to_tensor+ or +prediction_class+
# @return [Array] one entry per set name: the result of +load_dataset+, or nil
#   for a set whose split entry is nil
def load_model_data
  # respond_to? takes (name, include_all). Passing both names in one call
  # turns the second into the include_all flag, so it is never checked.
  # Each name is tested on its own; include_all stays true so private
  # implementations are still accepted.
  required_methods = [:to_tensor, :prediction_class]
  unless required_methods.all? { |name| respond_to?(name, true) }
    raise "#{self} needs to implement '#to_tensor(var_name, data)' and '#prediction_class' method."
  end

  @progress = 0
  start = 0
  all_ids = data_model.ids.shuffle
  [:train_set, :test_set, :validation_set].zip(dataset_split).map do |set, split|
    next if split.nil?
    # NOTE(review): size and offset are Floats; Array#slice appears to truncate
    # them, so trailing ids may be left out when the split does not divide the
    # id count evenly. Confirm whether that is intended.
    size = all_ids.count * split.to_f / 100.0
    offset = start
    start = start + size
    collection = data_model.where(id: all_ids.slice(offset, size))
    load_dataset set, collection
  end
end