class Tensorflow::Graph::Session
Attributes
graph[RW]
options[RW]
Public Class Methods
finalize(pointer)
click to toggle source
# File lib/tensorflow/graph/session.rb, line 31
# Builds a proc that deletes the native session handle +pointer+ when called.
#
# The proc captures only +pointer+, not a Session instance. Presumably it is
# meant for ObjectSpace.define_finalizer; confirm at the registration site.
# A plain proc (not a lambda) is returned, so any extra arguments passed on
# invocation, such as an object id, are ignored.
#
# @param pointer the native session handle to release
# @return [Proc] calls FFI.TF_DeleteSession(pointer)
def self.finalize(pointer)
  proc { FFI.TF_DeleteSession(pointer) }
end
new(graph, options)
click to toggle source
# File lib/tensorflow/graph/session.rb, line 37
# Creates a native TensorFlow session for +graph+.
#
# @param graph the graph this session executes; it is passed directly to
#   FFI.TF_NewSession.
# @param options the session options passed to FFI.TF_NewSession.
# NOTE(review): Status.check presumably raises when TF_NewSession reports an
#   error through +status+; confirm against Status.
def initialize(graph, options)
  @graph = graph
  # Store the options so the +options+ attribute reflects what this session
  # was created with. Previously only @graph was set, so +options+ read nil.
  @options = options
  Status.check do |status|
    @pointer = FFI.TF_NewSession(graph, options, status)
  end
end
run(graph) { |session| ... }
click to toggle source
# File lib/tensorflow/graph/session.rb, line 24
# Opens a session on +graph+ with a fresh SessionOptions, yields it, then
# closes it.
#
# The session is closed in an +ensure+ clause, so it is released even when
# the block raises. Previously a raising block left the session open.
#
# @param graph the graph to run
# @yieldparam session [Session] the open session
# @return [Object] whatever the block returns
def self.run(graph)
  session = self.new(graph, SessionOptions.new)
  begin
    yield session
  ensure
    session.close
  end
end
Public Instance Methods
close()
click to toggle source
# File lib/tensorflow/graph/session.rb, line 135
# Closes this session by calling FFI.TF_CloseSession with +self+ and the
# status object yielded by Status.check.
#
# @return whatever Status.check returns
def close
  Status.check { |tf_status| FFI.TF_CloseSession(self, tf_status) }
end
run(operations, feed_dict={})
click to toggle source
# File lib/tensorflow/graph/session.rb, line 48
# Runs +operations+ in this session, feeding the values in +feed_dict+.
#
# @param operations an Operation, Variable or OperationOutput, or a (possibly
#   nested) array of them. Nils are dropped.
# @param feed_dict [Hash] maps graph nodes to the values fed into them.
#   Non-Tensor values are converted by #values_to_tensors.
# @return [Object, Array] the bare value when exactly one operation was
#   requested and exactly one output was fetched. Otherwise an array with the
#   fetched values in order, plus a +nil+ entry for each operation that has
#   no outputs. A single output-less operation therefore returns [nil].
# @raise [Error::UnimplementedError] if an element of +operations+ is not an
#   Operation, Variable or OperationOutput.
def run(operations, feed_dict={})
  operations = Array(operations).flatten.compact

  # Inputs: the outputs of every feed key, and the tensors built from the fed values.
  key_outputs = feed_dict.keys.map(&:outputs).flatten
  keys_ptr = FFI::Output.array_to_ptr(key_outputs.map(&:output))
  values = self.values_to_tensors(feed_dict)
  values_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
  values_ptr.write_array_of_pointer(values)

  # Outputs: resolved once per operation. This validates every operation and
  # keeps the per-operation grouping needed to regroup the results below.
  outputs_per_operation = operations.map { |operation| outputs_to_fetch(operation) }
  outputs = outputs_per_operation.flatten
  outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
  results_ptr = ::FFI::MemoryPointer.new(:pointer, outputs.length)

  # Targets: the operations TensorFlow must execute.
  targets = operations.map { |operation| target_to_run(operation) }
  targets_ptr = ::FFI::MemoryPointer.new(:pointer, targets.length)
  targets_ptr.write_array_of_pointer(targets)

  run_options = nil
  metadata = nil
  Status.check do |status|
    # NOTE(review): keys_ptr holds every output of every feed key, but only
    # feed_dict.keys.length inputs are passed. The pairing with values_ptr is
    # only exact when each key has a single output; confirm this is intended.
    FFI.TF_SessionRun(self, run_options,
                      # Inputs
                      keys_ptr, values_ptr, feed_dict.keys.length,
                      # Outputs
                      outputs_ptr, results_ptr, outputs.length,
                      # Targets
                      targets_ptr, targets.length,
                      metadata, status)
  end

  results = results_ptr.read_array_of_pointer(outputs.length).map do |pointer|
    Tensor.from_pointer(pointer).value
  end

  # Regroup the flat results. Each operation contributes its own outputs'
  # values in order, or a single nil when it has no outputs.
  offset = 0
  result = outputs_per_operation.each_with_object([]) do |operation_outputs, grouped|
    count = operation_outputs.length
    if count.zero?
      grouped << nil
    else
      grouped.concat(results[offset, count])
      offset += count
    end
  end

  # Unwrap the common "one operation, one value" case.
  if operations.length == 1 && results.length == 1
    result.first
  else
    result
  end
end

# Returns the outputs #run fetches for one requested +operation+.
private def outputs_to_fetch(operation)
  case operation
  when Operation, Variable
    operation.outputs
  when OperationOutput
    [operation]
  else
    raise(Error::UnimplementedError, "Unsupported operation type: #{operation}")
  end
end

# Returns the operation #run executes as a target for one requested +operation+.
private def target_to_run(operation)
  case operation
  when Operation, Variable
    operation
  when OperationOutput
    operation.operation
  else
    raise(Error::UnimplementedError, "Unsupported target: #{operation}")
  end
end
to_ptr()
click to toggle source
# File lib/tensorflow/graph/session.rb, line 44
# Returns the native session handle stored by #initialize.
#
# This lets +self+ be handed straight to FFI functions, as #close and #run do.
# Ruby FFI converts objects that respond to +to_ptr+ when they are passed as
# pointers.
#
# @return the handle returned by FFI.TF_NewSession, or nil if none was stored
def to_ptr
  @pointer
end
values_to_tensors(values)
click to toggle source
# File lib/tensorflow/graph/session.rb, line 141
# Converts the values of a feed dictionary into Tensors.
#
# Values that are already Tensors pass through unchanged. Any other value is
# wrapped in a new Tensor whose dtype is the key's first (and only) output
# type, so the fed data matches the node it feeds.
#
# @param values [Hash] feed dictionary mapping graph nodes to values
# @return [Array<Tensor>] one tensor per entry, in the hash's iteration order
# @raise [Error::UnknownError] if a non-Tensor value is fed to a key whose
#   +num_outputs+ is not 1, because its dtype would be ambiguous
def values_to_tensors(values)
  values.map do |node, fed_value|
    next fed_value if Tensor === fed_value

    unless node.num_outputs == 1
      raise(Error::UnknownError, "Cannot determine dtype: #{node}")
    end

    Tensor.new(fed_value, dtype: node.output_types.first)
  end
end