class Torchrb::Torch
Attributes
error_rate [RW]
network_loaded [RW]
network_timestamp [RW]
Public Class Methods
new(options={})
click to toggle source
Calls superclass method Torchrb::Lua::new
# File lib/torchrb/torch.rb, line 7
#
# Initializes the Lua-backed Torch wrapper and, when a previously stored
# network exists, loads it.
#
# @param options [Hash] forwarded unchanged to the superclass initializer;
#   +:network_storage_path+ optionally names a network written by
#   #store_network.
def initialize(options = {})
  super
  @network_loaded = false
  @error_rate = Float::NAN
  network_storage_path = options[:network_storage_path]
  # A fresh model has nothing stored yet: a missing path/file is the normal
  # case, not an error, so skip loading instead of raising and swallowing.
  return unless network_storage_path && File.exist?(network_storage_path)

  begin
    load_network network_storage_path
  rescue StandardError => e
    # Loading stays best-effort (as before), but failures are no longer
    # completely invisible when debugging.
    warn "Could not load network from #{network_storage_path}: #{e.message}" if debug
  end
end
Public Instance Methods
cudify()
click to toggle source
# File lib/torchrb/torch.rb, line 102
#
# Moves the Lua datasets, the loss criterion and the network onto the GPU
# by running a Lua snippet in the embedded interpreter.
#
# NOTE(review): assumes the Lua globals train_set, test_set, validation_set
# and net already exist and that cutorch/cudnn are loaded — TODO confirm.
def cudify
  eval(<<-EOF, __FILE__, __LINE__)
    -- print(sys.COLORS.red .. '==> using CUDA GPU #' .. cutorch.getDevice() .. sys.COLORS.black)
    train_set.input = train_set.input:cuda()
    train_set.label = train_set.label:cuda()
    test_set.input = test_set.input:cuda()
    test_set.label = test_set.label:cuda()
    validation_set.input = validation_set.input:cuda()
    validation_set.label = validation_set.label:cuda()
    criterion = nn.ClassNLLCriterion():cuda()
    net = cudnn.convert(net:cuda(), cudnn)
  EOF
end
iteration_callback=(callback)
click to toggle source
# File lib/torchrb/torch.rb, line 14
#
# Registers +callback+ as the Lua function "iteration_callback".
#
# When invoked as (trainer, iteration, currentError), it stores
# currentError / 100.0 as #error_rate and calls +callback+.
#
# @param callback [#call] receives keyword arguments +progress:+ (iteration
#   divided by state['number_of_iterations']) and +error_rate:+.
def iteration_callback=(callback)
  state.function "iteration_callback" do |_trainer, iteration, current_error|
    # fdiv guarantees a fractional result; plain `/` truncates to 0 when both
    # operands happen to be Integers.
    progress = iteration.fdiv(state['number_of_iterations'])
    self.error_rate = current_error / 100.0
    callback.call progress: progress, error_rate: error_rate
  end
end
load_network(network_storage_path)
click to toggle source
# File lib/torchrb/torch.rb, line 51
#
# Loads a network written by #store_network.
#
# The network is read from +network_storage_path+ and its metadata table
# {classes, error_rate, timestamp} from "<path>.meta". The results go into
# the Lua globals net, metadata, classes and timestamp. The stored error
# rate and timestamp are then mirrored on this object.
#
# @param network_storage_path [String] path of the stored network
# @raise [RuntimeError] if the path is nil or no file exists there
def load_network network_storage_path
  unless network_storage_path && File.exist?(network_storage_path)
    raise "Neuronal net not trained yet. Call 'Torch#update_training_data'."
  end

  # Escape backslashes and single quotes so the path is a valid Lua '...' literal.
  lua_path = network_storage_path.to_s.gsub(/[\\']/) { |char| "\\#{char}" }
  stored_error_rate = eval(<<-EOF, __FILE__, __LINE__).to_ruby
    net = torch.load('#{lua_path}')
    metadata = torch.load('#{lua_path}.meta')
    classes = metadata[1]
    timestamp = metadata[3]
    return metadata[2]
  EOF
  self.error_rate = stored_error_rate
  self.network_timestamp = @state['timestamp']
  puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] loaded from #{network_storage_path} @ #{network_timestamp}" if debug
  self.network_loaded = true
end
predict(sample, network_storage_path=nil)
click to toggle source
# File lib/torchrb/torch.rb, line 36
#
# Runs +sample+ through the network and maps each class name to its
# confidence.
#
# The network output is exponentiated in Lua and stored in the Lua global
# `confidences`, which is keyed by class index. The indices are then
# translated through the Lua `classes` table.
#
# @param sample [#to_tensor] converted to Lua code defining `sample_data`
# @param network_storage_path [String, nil] loaded first unless a network
#   is already loaded
# @return [Hash] class name => confidence
def predict sample, network_storage_path=nil
  load_network network_storage_path unless network_loaded
  forward = enable_cuda ? "net:forward(sample_data:cuda()):float()" : "net:forward(sample_data)"
  classes = eval <<-EOF, __FILE__, __LINE__
    #{sample.to_tensor("sample_data").strip}
    local prediction = #{forward}
    prediction = prediction:exp()
    confidences = prediction:totable()
    return classes
  EOF
  puts "predicted #{@state['confidences'].to_h} based on network @ #{network_timestamp}" if debug
  class_names = classes.to_h
  # Build the result in one pass. Array#to_h keeps the last value for a
  # duplicated key, like the former merge-reduce, but without O(n^2) copying.
  @state['confidences'].to_h.map { |index, confidence| [class_names[index], confidence] }.to_h
end
print_results()
click to toggle source
# File lib/torchrb/torch.rb, line 74
#
# Evaluates the network on the Lua `test_set` and reports a per-class score.
#
# For each class i, the score is 100 * (test samples of class i whose
# top-ranked prediction is i) / (total test-set size). The scores
# therefore sum to the overall top-1 accuracy.
#
# The result is printed when a DEBUG constant is defined.
#
# @return [Array<Array>] pairs of [class, score]
def print_results
  result = eval <<-EOF, __FILE__, __LINE__
    class_performance = torch.LongTensor(#classes):fill(0):totable()
    test_set_size = test_set:size()
    for i=1,test_set_size do
      local groundtruth = test_set.label[i]
      local prediction = net:forward(test_set.input[i])
      local confidences, indices = torch.sort(prediction, true) -- true means sort in descending order
      -- only count the sample when the top-ranked class matches the label
      if groundtruth == indices[1] then
        class_performance[groundtruth] = class_performance[groundtruth] + 1
      end
    end
    local result = {}
    for i=1,#classes do
      local confidence = 100*class_performance[i]/test_set_size
      table.insert(result, { classes[i], confidence } )
    end
    return result
  EOF
  result = result.to_ruby.map(&:to_ruby)
  if defined?(DEBUG)
    puts "#" * 80
    puts "Results: #{result.to_h}"
    puts "#" * 80
  end
  # Return the scores so callers can use them (previously always nil).
  result
end
store_network(network_storage_path)
click to toggle source
# File lib/torchrb/torch.rb, line 66
#
# Saves the Lua network to +network_storage_path+.
#
# The metadata table {classes, error_rate, timestamp} is saved to
# "<path>.meta", which is the layout #load_network reads back.
#
# @param network_storage_path [String] destination path of the network
def store_network network_storage_path
  # Escape backslashes and single quotes so values are valid Lua '...' literals.
  lua_path = network_storage_path.to_s.gsub(/[\\']/) { |char| "\\#{char}" }
  lua_timestamp = network_timestamp.to_s.gsub(/[\\']/) { |char| "\\#{char}" }
  # Ruby prints NaN/Infinity as bare words that Lua would read as undefined
  # (nil) globals, and nil as empty text (a Lua syntax error). Emit Lua's own
  # spellings instead. The initial error_rate is Float::NAN.
  rate = error_rate.to_f
  lua_error_rate =
    if error_rate.nil? || rate.nan? then '0/0'
    elsif rate.infinite? then rate.positive? ? 'math.huge' : '-math.huge'
    else error_rate
    end
  eval <<-EOF, __FILE__, __LINE__
    torch.save('#{lua_path}', net)
    torch.save('#{lua_path}.meta', {classes, #{lua_error_rate}, '#{lua_timestamp}'} )
  EOF
  puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] stored in #{network_storage_path} @ #{network_timestamp}" if debug
end
train()
click to toggle source
# File lib/torchrb/torch.rb, line 22
#
# Trains the network by calling the Lua `trainer:train(train_set)`.
#
# Lua's global `print` is silenced for the duration of training. On
# success, marks the network as loaded and stamps it with the current time.
#
# Errors raised by the Lua trainer propagate to the caller.
def train
  eval <<-EOF, __FILE__, __LINE__
    local oldprint = print
    print = function(...) end
    -- pcall ensures print is restored even if training fails; the error is re-raised afterwards.
    local ok, err = pcall(function() trainer:train(train_set) end)
    print = oldprint
    if not ok then error(err, 0) end
  EOF
  self.network_loaded = true
  self.network_timestamp = Time.now
end