class Tensorflow::Graph::Session

Attributes

graph[RW]
options[RW]

Public Class Methods

finalize(pointer)
# File lib/tensorflow/graph/session.rb, line 31
# Builds a finalizer proc that releases the native session behind +pointer+.
#
# A plain proc (not a lambda) is returned on purpose: ObjectSpace passes the
# object id to finalizers, and a proc silently ignores that extra argument.
# The proc closes over +pointer+ only (it is built in a class method), so it
# holds no reference to a Session instance.
#
# NOTE(review): the TensorFlow C API declares
# TF_DeleteSession(TF_Session*, TF_Status*). Confirm the FFI binding used here
# really takes a single argument.
#
# @param pointer the raw TF_Session pointer to delete
# @return [Proc] a callable that deletes the session when invoked
def self.finalize(pointer)
  proc { FFI.TF_DeleteSession(pointer) }
end
new(graph, options)
# File lib/tensorflow/graph/session.rb, line 37
# Opens a native TensorFlow session for +graph+ configured by +options+.
#
# @param graph the graph this session executes; exposed through #graph
# @param options the session options; exposed through #options
# @raise whatever Status.check raises when TF_NewSession reports a failure
def initialize(graph, options)
  @graph = graph
  # Store the options so the public #options accessor returns them instead of nil.
  @options = options
  Status.check do |status|
    @pointer = FFI.TF_NewSession(graph, options, status)
  end
end
run(graph) { |session| ... }
# File lib/tensorflow/graph/session.rb, line 24
# Opens a session on +graph+ with default SessionOptions, yields it, and closes it.
#
# The session is closed even when the block raises, so the native session is
# never left open by an exception. If creating the session itself fails,
# nothing is closed.
#
# @param graph the graph to run
# @yieldparam session [Session] the freshly opened session
# @return the value returned by the block
def self.run(graph)
  session = new(graph, SessionOptions.new)
  begin
    yield session
  ensure
    session.close
  end
end

Public Instance Methods

close()
# File lib/tensorflow/graph/session.rb, line 135
# Closes the native session via TF_CloseSession.
#
# Any failure reported through the status is surfaced by Status.check. The
# return value is whatever Status.check returns for the block.
def close
  Status.check { |status| FFI.TF_CloseSession(self, status) }
end
run(operations, feed_dict={})
# File lib/tensorflow/graph/session.rb, line 48
# Executes +operations+ in this session via TF_SessionRun and returns their values.
#
# +operations+ is wrapped with Array(), flattened and stripped of nils. Each entry
# must be one of:
# * an Operation or Variable: all of its outputs are fetched, and it is run as a target;
# * an OperationOutput: that single output is fetched, and its owning operation
#   is run as a target.
#
# @param operations [Object, Array] what to evaluate
# @param feed_dict [Hash] fed keys mapped to values; non-Tensor values are converted
#   by #values_to_tensors
# @return the single fetched value when exactly one operation was requested and exactly
#   one value was fetched. Otherwise an Array with the fetched values in request order,
#   plus a +nil+ placeholder for each requested operation that has no outputs.
# @raise [Error::UnimplementedError] for an unsupported entry in +operations+
def run(operations, feed_dict={})
  operations = Array(operations).flatten.compact

  # Inputs: the outputs being fed, paired with the tensors fed into them.
  key_outputs = feed_dict.keys.map(&:outputs).flatten
  keys_ptr = FFI::Output.array_to_ptr(key_outputs.map(&:output))

  values = self.values_to_tensors(feed_dict)
  values_ptr = ::FFI::MemoryPointer.new(:pointer, values.length)
  values_ptr.write_array_of_pointer(values)

  # Outputs: every output to fetch, flattened in request order.
  outputs = operations.map do |operation|
    case operation
      when Operation, Variable
        operation.outputs
      when OperationOutput
        operation
      else
        raise(Error::UnimplementedError, "Unsupported operation type: #{operation}")
    end
  end.flatten

  outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
  results_ptr = ::FFI::MemoryPointer.new(:pointer, outputs.length)

  # Targets: the operations to run, one per requested entry.
  targets = operations.map do |operation|
    case operation
      when Operation, Variable
        operation
      when OperationOutput
        operation.operation
      else
        # Same error class as the other dispatch branches in this method.
        raise(Error::UnimplementedError, "Unsupported target: #{operation}")
    end
  end
  targets_ptr = ::FFI::MemoryPointer.new(:pointer, targets.length)
  targets_ptr.write_array_of_pointer(targets)

  run_options = nil
  metadata = nil

  Status.check do |status|
    FFI.TF_SessionRun(self, run_options,
                      # Inputs
                      keys_ptr, values_ptr, feed_dict.keys.length,
                      # Outputs
                      outputs_ptr, results_ptr, outputs.length,
                      # Targets
                      targets_ptr, targets.length,
                      metadata,
                      status)
  end

  results = results_ptr.read_array_of_pointer(outputs.length).map do |pointer|
    Tensor.from_pointer(pointer).value
  end

  # Regroup the flat results: each requested operation contributes its own outputs,
  # or a single nil placeholder when it has none.
  start = 0
  result = operations.each_with_object([]) do |operation, array|
    length = case operation
               when Operation, Variable
                 operation.outputs.length
               when OperationOutput
                 1
               else
                 raise(Error::UnimplementedError, "Unsupported operation type: #{operation}")
             end

    if length == 0
      array << nil
    else
      array.concat(results[start, length])
      start += length
    end
  end

  # NOTE(review): if the single requested operation has no outputs, +results+ is empty
  # and this returns [nil] rather than nil. Confirm that this is intended.
  if operations.length == 1 && results.length == 1
    result.first
  else
    result
  end
end
to_ptr()
# File lib/tensorflow/graph/session.rb, line 44
# Returns the raw TF_Session pointer stored by #initialize (nil if none was set).
# #close and #run pass +self+ straight to FFI calls, presumably relying on this
# method for the conversion to a native pointer.
def to_ptr; @pointer; end
values_to_tensors(values)
# File lib/tensorflow/graph/session.rb, line 141
# Converts each value of a feed dictionary into a Tensor, in iteration order.
#
# Tensor values are passed through unchanged. Any other value is wrapped in a new
# Tensor whose dtype is taken from the key's sole output type.
#
# @param values [Hash] fed keys mapped to raw values or Tensors
# @return [Array<Tensor>] one tensor per entry
# @raise [Error::UnknownError] when a non-Tensor value is fed to a key that does not
#   have exactly one output, because its dtype cannot be chosen
def values_to_tensors(values)
  values.map do |key, value|
    next value if value.is_a?(Tensor)

    raise(Error::UnknownError, "Cannot determine dtype: #{key}") if key.num_outputs != 1

    Tensor.new(value, dtype: key.output_types.first)
  end
end