class Tensorflow::Graph::Graph

Attributes

control_inputs [R] (read-only attribute)

Public Class Methods

default() click to toggle source
# File lib/tensorflow/graph/graph.rb, line 9
# Returns the process-wide default graph, creating it lazily on first use.
#
# The graph is memoized in a class-level instance variable, so every call
# returns the same Graph object until reset_default replaces it.
def self.default
  return @default if @default

  @default = Graph.new
end
finalize(pointer) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 17
# Builds the finalizer proc that releases a graph's native handle.
#
# The proc closes over +pointer+ only (it is created in a class method, so it
# does not capture the Graph instance and cannot keep it alive). ObjectSpace
# passes the object id when it runs a finalizer; a proc (not a lambda) is used
# so that extra argument is tolerated and ignored.
def self.finalize(pointer)
  proc { |_object_id| FFI::TF_DeleteGraph(pointer) }
end
new() click to toggle source
# File lib/tensorflow/graph/graph.rb, line 23
# Creates an empty graph backed by a freshly allocated native graph
# (FFI.TF_NewGraph).
#
# Collections and control inputs start empty. A finalizer built by
# Graph.finalize is registered so the native graph is released when this
# object is garbage collected.
def initialize
  @collections = {}
  @name_scope = NameScope.new
  @pointer = FFI.TF_NewGraph
  @control_inputs = []
  # Build the finalizer through the class method so the proc does not capture +self+.
  ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
end
reset_default() click to toggle source
# File lib/tensorflow/graph/graph.rb, line 13
# Discards the current default graph and installs a brand-new one.
#
# Returns the newly created default graph.
def self.reset_default
  replacement = Graph.new
  @default = replacement
end

Public Instance Methods

add_function(function, gradient=nil) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 170
# Copies +function+ (and the optional +gradient+ function) into this graph
# using TF_GraphCopyFunction.
#
# The native status is routed through Status.check (presumably raising on a
# non-OK status — confirm); its return value is returned.
def add_function(function, gradient=nil)
  Status.check { |status| FFI.TF_GraphCopyFunction(self, function, gradient, status) }
end
add_to_collection(name, value) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 39
# Appends +value+ to the collection called +name+, creating the collection
# on first use.
#
# Returns the collection's backing array.
def add_to_collection(name, value)
  (@collections[name] ||= []) << value
end
add_to_collections(names, value) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 44
# Adds +value+ to every collection named in +names+ (via add_to_collection).
#
# Returns +names+.
def add_to_collections(names, value)
  names.each { |collection_name| add_to_collection(collection_name, value) }
end
as_default() { |self| ... } click to toggle source
# File lib/tensorflow/graph/graph.rb, line 58
# Makes this graph the current execution context while the block runs.
#
# Yields the graph and returns the block's value. The context is always popped
# again, even if the block raises.
#
# Raises Error::InvalidArgumentError when called without a block.
def as_default
  unless block_given?
    raise(Error::InvalidArgumentError, "Must provide block")
  end

  ExecutionContext.push(self)
  begin
    yield(self)
  ensure
    # Only reached once push succeeded, so the stack stays balanced.
    ExecutionContext.pop
  end
end
as_graph_def() click to toggle source
# File lib/tensorflow/graph/graph.rb, line 217
# Serializes this graph into a GraphDef message.
#
# Fills a native buffer via TF_GraphToGraphDef, reads its bytes, and returns
# GraphDef.decode of those bytes.
#
# The native buffer obtained from TF_NewBuffer is always released, including
# when serialization fails inside Status.check.
def as_graph_def
  buffer_ptr = FFI.TF_NewBuffer
  Status.check do |status|
    FFI.TF_GraphToGraphDef(self, buffer_ptr, status)
  end

  buffer = FFI::Buffer.new(buffer_ptr)
  string = buffer[:data].read_string(buffer[:length])
  GraphDef.decode(string)
ensure
  # Free the pointer returned by TF_NewBuffer. The previous code deleted the
  # Ruby wrapper `buffer`, which is still nil if Status.check raised, so the
  # native buffer leaked on error.
  FFI.TF_DeleteBuffer(buffer_ptr) if buffer_ptr
end
backward(operation) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 101
# Returns a Set with +operation+ plus every operation it transitively depends
# on, found by following each input to the operation that produces it.
#
# Each operation is expanded at most once. Shared ancestors are therefore not
# re-walked, and cycles terminate instead of recursing forever. Insertion order
# is a depth-first pre-order walk over inputs, as before.
#
# The walk uses a local lambda instead of defining a method inside this method,
# which used to redefine backward_internal on every call.
def backward(operation)
  result = Set.new([operation])
  visit = lambda do |current|
    current.inputs.each do |input|
      producer = input.operation
      # Set#add? returns nil when the producer was already collected.
      visit.call(producer) if result.add?(producer)
    end
  end
  visit.call(operation)
  result
end
backward_internal(set, operation) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 102
# Adds to +set+, recursively, every operation that +operation+ depends on
# through its inputs, and returns +set+.
#
# An operation already in +set+ is not expanded again. This avoids re-walking
# shared ancestors and stops infinite recursion on cycles. Only #include? and
# #<< are used on +set+, so a Set or an Array both work.
def backward_internal(set, operation)
  operation.inputs.each do |input|
    producer = input.operation
    next if set.include?(producer)

    set << producer
    backward_internal(set, producer)
  end
  set
end
clear_collection(name) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 54
# Empties the collection called +name+ by storing a fresh array under that
# name. Arrays previously handed out for it are left untouched.
#
# Returns the new, empty array.
def clear_collection(name)
  @collections.store(name, [])
end
collections() click to toggle source
# File lib/tensorflow/graph/graph.rb, line 35
# Returns a new array with the names of all known collections, in the order
# they were first created.
def collections
  @collections.map { |collection_name, _values| collection_name }
end
control_dependencies(control_inputs) { |self| ... } click to toggle source
# File lib/tensorflow/graph/graph.rb, line 68
# Sets this graph's control inputs (exposed through the control_inputs
# reader) to +control_inputs+ while the block runs. Yields the graph and
# returns the block's value.
#
# +control_inputs+ may be a single object, an array, or nil; Array() coerces it.
#
# The previous control inputs are restored afterwards, even if the block
# raises. The old code reset them to [], so a nested call wiped out the
# dependencies of the enclosing block for the rest of that block.
def control_dependencies(control_inputs)
  previous = @control_inputs
  @control_inputs = Array(control_inputs)
  begin
    yield self
  ensure
    @control_inputs = previous
  end
end
create_operation(op_type, inputs=[], attrs={}) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 129
# Describes a new operation of type +op_type+ with the given +inputs+ and
# +attrs+ on this graph, then saves it.
#
# Returns whatever OperationDescription#save returns (presumably the created
# Operation — confirm).
def create_operation(op_type, inputs=[], attrs={})
  OperationDescription.new(self, op_type, inputs, attrs).save
end
execute(operations, feed_dict={}) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 134
# Runs +operations+ in a short-lived Session on this graph and returns the
# result of Session#run.
#
# +feed_dict+ is passed straight through to Session#run.
#
# The session is closed even if run raises. Previously an exception from run
# skipped session.close and leaked the session.
def execute(operations, feed_dict={})
  session = Session.new(self, SessionOptions.new)
  begin
    session.run(operations, feed_dict)
  ensure
    session.close
  end
end
forward(operation) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 89
# Returns a Set with +operation+ plus every operation that transitively
# consumes its outputs, found by following each consumer's operation.
#
# Each operation is expanded at most once. Shared descendants are therefore
# not re-walked, and cycles terminate instead of recursing forever. Insertion
# order is a depth-first pre-order walk over consumers, as before.
#
# The walk uses a local lambda instead of defining a method inside this method,
# which used to redefine forward_internal on every call.
def forward(operation)
  result = Set.new([operation])
  visit = lambda do |current|
    current.consumers.each do |consumer|
      downstream = consumer.operation
      # Set#add? returns nil when the operation was already collected.
      visit.call(downstream) if result.add?(downstream)
    end
  end
  visit.call(operation)
  result
end
forward_internal(set, operation) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 90
# Adds to +set+, recursively, every operation that consumes +operation+'s
# outputs (directly or transitively), and returns +set+.
#
# An operation already in +set+ is not expanded again. This avoids re-walking
# shared descendants and stops infinite recursion on cycles. Only #include?
# and #<< are used on +set+, so a Set or an Array both work.
def forward_internal(set, operation)
  operation.consumers.each do |consumer|
    downstream = consumer.operation
    next if set.include?(downstream)

    set << downstream
    forward_internal(set, downstream)
  end
  set
end
get_collection_ref(name, scope=nil) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 50
# Returns the live array backing the collection called +name+, or nil if no
# such collection exists. Mutating the returned array mutates the collection.
#
# NOTE(review): +scope+ is accepted but currently ignored (no filtering).
def get_collection_ref(name, scope=nil)
  # @collections is created without a default value, so fetch(name, nil)
  # behaves exactly like @collections[name].
  @collections.fetch(name, nil)
end
import(graph_def, options=nil) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 230
# Imports a serialized graph into this graph using TF_GraphImportGraphDef.
#
# +graph_def+ - a GraphDef message (encoded with GraphDef.encode first), or an
#               already-encoded byte String that is used as-is.
# +options+   - import options; a fresh GraphDefOptions is used when nil.
#
# The status is routed through Status.check, whose return value is returned.
def import(graph_def, options=nil)
  options ||= GraphDefOptions.new

  serialized = graph_def.is_a?(GraphDef) ? GraphDef.encode(graph_def) : graph_def
  size = serialized.bytesize

  # Copy the bytes into native memory. `bytes` and `buffer` stay referenced by
  # locals until the native call below returns.
  bytes = ::FFI::MemoryPointer.new(:char, size)
  bytes.put_bytes(0, serialized)

  buffer = FFI::Buffer.new
  buffer[:data] = bytes
  buffer[:length] = size

  Status.check { |status| FFI.TF_GraphImportGraphDef(self, buffer, options, status) }
end
op_def(op_type) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 77
# Looks up the registered definition of +op_type+ and returns it decoded with
# OpDef.decode.
#
# TF_GraphGetOpDef fills a native buffer, whose bytes are then decoded. That
# buffer (from TF_NewBuffer) is always released, including when the lookup
# fails inside Status.check.
def op_def(op_type)
  buffer_ptr = FFI.TF_NewBuffer
  Status.check do |status|
    FFI.TF_GraphGetOpDef(self, op_type, buffer_ptr, status)
  end
  buffer = FFI::Buffer.new(buffer_ptr)
  string = buffer[:data].read_string(buffer[:length])
  OpDef.decode(string)
ensure
  # Free the pointer returned by TF_NewBuffer. The previous code deleted the
  # Ruby wrapper `buffer`, which is still nil if Status.check raised, so the
  # native buffer leaked on error.
  FFI.TF_DeleteBuffer(buffer_ptr) if buffer_ptr
end
operation(name) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 124
# Looks up the operation called +name+ in this graph.
#
# Returns an Operation wrapping the native handle, or nil when
# TF_GraphOperationByName returns a NULL pointer.
def operation(name)
  op_ptr = FFI.TF_GraphOperationByName(self, name)
  return nil if op_ptr.null?

  Operation.new(self, op_ptr)
end
operations() { |operation| ... } click to toggle source
# File lib/tensorflow/graph/graph.rb, line 113
# Yields every operation in this graph, wrapped in Operation.
#
# Without a block, returns an Enumerator over the same sequence.
# Iteration stops when TF_GraphNextOperation returns a NULL pointer.
def operations
  return enum_for(:operations) unless block_given?

  # TF_GraphNextOperation advances this size_t cursor in place; the trailing
  # `true` zero-fills it so iteration starts at the first operation.
  cursor = ::FFI::MemoryPointer.new(:size_t, 1, true)
  op_ptr = FFI.TF_GraphNextOperation(self, cursor)
  while op_ptr && !op_ptr.null?
    yield Operation.new(self, op_ptr)
    op_ptr = FFI.TF_GraphNextOperation(self, cursor)
  end
end
output_shapes(operation) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 141
# Returns one shape per output of +operation+. Each shape is an Array of
# Integer dimensions read with TF_GraphGetTensorShape.
#
# When TF_GraphGetTensorNumDims reports -1 (unknown rank), [] is returned for
# that output.
# NOTE(review): that is indistinguishable from a rank-0 (scalar) shape.
def output_shapes(operation)
  operation.outputs.map do |output|
    rank = Status.check { |status| FFI.TF_GraphGetTensorNumDims(self, output, status) }
    next [] if rank == -1

    dims = ::FFI::MemoryPointer.new(:int64, rank)
    Status.check { |status| FFI.TF_GraphGetTensorShape(self, output, dims, rank, status) }
    dims.read_array_of_int64(rank)
  end
end
tensor_set_shape(operation, shape) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 159
# Sets the static shape of +operation+'s first output (index 0) to +shape+,
# an Array of Integer dimensions, using TF_GraphSetTensorShape.
#
# The status is routed through Status.check, whose return value is returned.
def tensor_set_shape(operation, shape)
  rank = shape.length
  dims = ::FFI::MemoryPointer.new(:int64, rank)
  dims.write_array_of_int64(shape)

  # Only output index 0 is targeted; other outputs are left unchanged.
  target = FFI::Output.new
  target[:oper] = operation
  target[:index] = 0

  Status.check { |status| FFI.TF_GraphSetTensorShape(self, target, dims, rank, status) }
end
to_function(name, operators, input_operations, output_operations, output_names=nil) click to toggle source
# File lib/tensorflow/graph/graph.rb, line 176
# Converts (part of) this graph into a TensorFlow function named +name+ via
# TF_GraphToFunction.
#
# +name+              - function name (no hash is appended).
# +operators+         - operations to include in the function body, or nil to
#                       pass a count of -1 (NOTE(review): per the TF C API this
#                       means "use all operations" — confirm). The value is
#                       passed as-is; confirm the FFI binding converts an Array
#                       to the expected pointer array.
# +input_operations+  - operations whose outputs become function inputs; nil => none.
# +output_operations+ - operations whose outputs become function outputs; nil => none.
# +output_names+      - optional output names; must have one entry per output.
#
# Returns a Function built from the native function plus the output types and
# output shapes of +output_operations+.
#
# Raises ArgumentError if +output_names+ is given with the wrong length.
def to_function(name, operators, input_operations, output_operations, output_names=nil)
  # Normalize once. Previously a nil output_operations was tolerated when
  # building outputs, but then crashed on output_types/output_shapes below.
  input_operations ||= []
  output_operations ||= []

  inputs = input_operations.map(&:outputs).flatten
  inputs_ptr = FFI::Output.array_to_ptr(inputs.map(&:output))

  outputs = output_operations.map(&:outputs).flatten
  outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))

  if output_names && output_names.length != outputs.length
    raise(ArgumentError, "output_names length must equal outputs length or be nil")
  end

  # Keep the per-name string pointers and the pointer array in locals so they
  # are not garbage collected before TF_GraphToFunction returns.
  output_name_strings = nil
  output_names_ptr = nil
  if output_names
    output_name_strings = output_names.map do |output_name|
      ::FFI::MemoryPointer.from_string(output_name)
    end
    output_names_ptr = ::FFI::MemoryPointer.new(:pointer, output_name_strings.length, true)
    output_names_ptr.write_array_of_pointer(output_name_strings)
  end

  append_hash_to_fn_name = 0
  options = nil
  description = nil

  func = Status.check do |status|
    FFI.TF_GraphToFunction(self, name, append_hash_to_fn_name,
                           operators ? operators.length : -1, operators,
                           inputs.length, inputs_ptr,
                           outputs.length, outputs_ptr,
                           output_names_ptr,
                           options, description, status)
  end

  output_types = output_operations.map(&:output_types).flatten(1)
  output_shapes = output_operations.map(&:output_shapes).flatten(1)
  Function.new(func, output_types, output_shapes)
end
to_ptr() click to toggle source
# File lib/tensorflow/graph/graph.rb, line 31
# Returns the native graph pointer created in initialize (FFI.TF_NewGraph).
# Methods in this class pass +self+ directly to FFI calls (e.g.
# TF_GraphCopyFunction), which presumably relies on this conversion.
def to_ptr; @pointer; end