class Tensorflow::Graph::FunctionDef

Constants

Signature

Attributes

ruby_method[R]
signatures[R]

Public Class Methods

new(ruby_method, input_signatures = []) click to toggle source
# File lib/tensorflow/graph/function_def.rb, line 8
# Creates a function definition for +ruby_method+ and immediately installs the
# tracing wrapper on the method's owner.
#
# @param ruby_method the method to wrap; presumably an UnboundMethod, since it
#   is later +bind+-ed to a receiver — confirm against callers
# @param input_signatures [Array] one [dtype, shape] entry per parameter of
#   +ruby_method+ (validated by +process_signatures+)
def initialize(ruby_method, input_signatures = [])
  @ruby_method = ruby_method
  # Validate and store signatures before wrapping, so a bad signature list
  # raises without touching the owner's method table.
  process_signatures(ruby_method, input_signatures)
  wrap_ruby_method
end

Public Instance Methods

aliased_name() click to toggle source
# File lib/tensorflow/graph/function_def.rb, line 24
# Name under which the untouched original method is kept once it has been
# wrapped: the method's original name followed by "_original".
#
# @return [String]
def aliased_name
  base_name = ruby_method.original_name
  "#{base_name}_original"
end
build_function(object) click to toggle source
# File lib/tensorflow/graph/function_def.rb, line 46
# Traces the original Ruby method into a graph function.
#
# Inside a fresh default graph, one placeholder is created per method
# parameter, typed and shaped by the signature at the same position. The
# original method is then called, bound to +object+, with those placeholders,
# and the graph is converted with +to_function+.
#
# @param object receiver the original method is bound to while tracing
# @return whatever +graph.to_function+ returns (the value of the +as_default+ block)
def build_function(object)
  Graph.new.as_default do |graph|
    inputs = ruby_method.parameters.each_with_index.map do |param, position|
      spec = signatures[position]
      # param is a [kind, name] pair from Method#parameters; its last element
      # names the placeholder.
      Tensorflow.placeholder(spec.dtype, name: param.last, shape: spec.shape)
    end

    # Run the untouched original implementation to record ops into the graph.
    outputs = ruby_method.bind(object).call(*inputs)

    # Array() lets a method return either one tensor or a list of tensors.
    graph.to_function(ruby_method.original_name.to_s, nil, inputs, Array(outputs))
  end
end
process_signatures(ruby_method, input_signatures) click to toggle source
# File lib/tensorflow/graph/function_def.rb, line 14
# Checks that every parameter of +ruby_method+ has an input signature and
# stores the signatures as +Signature+ objects in @signatures.
#
# @param ruby_method method whose +parameters+ count must match
# @param input_signatures [Array] entries destructured as [dtype, shape]
# @return [Array<Signature>] the stored signatures
# @raise [Error::InvalidArgumentError] when the counts differ
def process_signatures(ruby_method, input_signatures)
  expected_count = ruby_method.parameters.length
  unless input_signatures.length == expected_count
    raise(Error::InvalidArgumentError, "Must specify input signature for each method parameter")
  end

  @signatures = input_signatures.map { |dtype, shape| Signature.new(dtype, shape) }
end
wrap_ruby_method() click to toggle source
# File lib/tensorflow/graph/function_def.rb, line 28
# Replaces the original method on its owner with a tracing wrapper, keeping
# the original implementation reachable under +aliased_name+.
#
# After this runs, calling the method on an instance does the following:
# - it ignores the call's arguments;
# - it builds a fresh function via +build_function+, bound to that instance;
# - it registers that function with +ExecutionContext.current+;
# - it returns the function.
#
# @return the result of defining the wrapper (the method name symbol)
def wrap_ruby_method
  definition = self
  backup_name = aliased_name
  target_name = ruby_method.original_name

  # Reopen the owner once. self inside the block is the owner, so alias_method
  # and define_method act on it.
  ruby_method.owner.class_eval do
    # Keep the untraced implementation callable as "<name>_original".
    alias_method(backup_name, target_name)

    # Any arguments are accepted but not used: build_function traces the
    # method with placeholders rather than concrete values.
    define_method(target_name) do |*_args|
      function = definition.build_function(self)
      ExecutionContext.current.add_function(function)
      function
    end
  end
end