module Torchrb::NN::TrainerDefault

Public Instance Methods

define_trainer(options)
# File lib/torchrb/nn/trainer_default.rb, line 3
  # Defines a Lua-side `trainer`: an nn.StochasticGradient optimiser over the
  # Lua global `net`, using an nn.ClassNLLCriterion, with `iteration_callback`
  # installed as the per-iteration hook.
  #
  # NOTE(review): `net` and `iteration_callback` are not created here; they are
  # assumed to already exist in the Lua state when this is called — confirm.
  #
  # @param options [Hash]
  # @option options [Integer] :iterations (50) value for trainer.maxIteration
  # @option options [Float] :learning_rate (0.001) value for trainer.learningRate
  # @raise [ArgumentError, TypeError] if an option cannot be coerced to a number
  def define_trainer(options)
    # Coerce before interpolating so only well-formed numeric literals can
    # ever reach the generated Lua source (no malformed or injected code).
    iterations = Integer(options.fetch(:iterations, 50))
    learning_rate = Float(options.fetch(:learning_rate, 0.001))

    # `__LINE__ + 1`: the heredoc body begins on the line after this call, so
    # this keeps reported line numbers aligned with the Lua text below.
    torch.eval <<-EOF, __FILE__, __LINE__ + 1
      number_of_iterations = #{iterations} -- Must be set for the callback to work

      criterion = nn.ClassNLLCriterion()

      trainer = nn.StochasticGradient(net, criterion)
      trainer.verbose = false
      trainer.learningRate = #{learning_rate}
      trainer.maxIteration = number_of_iterations
      trainer.hookIteration = iteration_callback
    EOF
  end