class Tensorflow::Eager::Operation
Attributes
context[R]
guessed_dtype[R]
op_def[R]
status[R]
Public Class Methods
new(context, op_type, inputs, attrs)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 6
#
# Builds an eager TFE_Op for +op_type+, feeds it +inputs+ and applies +attrs+.
#
# @param context the eager context handed to FFI.TFE_NewOp
# @param op_type [Graph::Function, Object] a graph function (its function_def signature
#   is used as the op definition) or anything Tensorflow.op_def can resolve
# @param inputs [Object, Array] a single input or an Array of inputs
# @param attrs [Hash] attribute values keyed by symbol; :name is ignored. The caller's
#   hash is not modified.
# @raise [Error::InvalidArgumentError] if no op definition is found for +op_type+
def initialize(context, op_type, inputs, attrs)
  @context = context
  @op_def =
    case op_type
    when Graph::Function
      op_type.function_def.signature
    else
      Tensorflow.op_def(op_type)
    end
  raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

  @status = Status.new
  @pointer = FFI.TFE_NewOp(context, self.op_def.name, self.status)
  # TFE_NewOp reports failures through the status; surface them before using the op.
  self.status.check

  # Work on a copy so the caller's hash is left untouched, and drop :name so that
  # setup_attrs does not try to apply it as an op attribute.
  attrs = attrs.dup
  attrs.delete(:name)
  inputs = Array(inputs)
  @guessed_dtype = figure_dtype(attrs, inputs)
  setup_inputs(inputs, attrs)
  setup_attrs(attrs)
end
Public Instance Methods
add_input(value)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 199
#
# Adds one input to the op.
#
# +value+ is normally a single handle. It may also be an Array, e.g. the outputs of an
# operation with multiple outputs. Several elements are packed together with
# Tensorflow.pack so they fit into one input; a one-element Array is unwrapped.
#
# @param value [Object, Array] handle (or handles) to add
# @raise whatever Status#check raises if +value+ is an empty Array or the input is rejected
def add_input(value)
  if value.is_a?(Array)
    case value.length
    when 0
      self.status.set(:tf_invalid_argument, "Cannot add an empty list as a single input")
      return self.status.check
    when 1
      # A lone element needs no packing; an Array itself cannot be passed as a pointer.
      value = value.first
    else
      value = Tensorflow.pack(value)
    end
  end

  FFI.TFE_OpAddInput(self, value, self.status)
  self.status.check
end
add_input_list(values)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 211
#
# Adds +values+ to the op as one list input.
#
# The handles are copied into a freshly allocated C array of pointers, which is passed to
# TFE_OpAddInputList together with the element count.
#
# @param values [Array] input handles
# @raise whatever Status#check raises if TensorFlow rejects the list
def add_input_list(values)
  count = values.length
  handles_ptr = ::FFI::MemoryPointer.new(:pointer, count)
  handles_ptr.write_array_of_pointer(values)
  FFI.TFE_OpAddInputList(self, handles_ptr, count, self.status)
  self.status.check
end
add_list_attr(type, attr_name, attr_value)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 73
#
# Sets a list-valued attribute on the op.
#
# @param type [Symbol] attribute kind as returned by FFI.TFE_OpGetAttrType;
#   :int, :float, :shape and :type are supported
# @param attr_name [String] attribute name
# @param attr_value [Array] the values. For :shape, an Array of dimension Arrays; for
#   :type, Symbols (looked up in FFI::DataType) or raw enum values.
# @raise [RuntimeError] for unsupported list kinds
# @raise whatever Status#check raises if a shape list is rejected
def add_list_attr(type, attr_name, attr_value)
  num_values = attr_value.size
  case type
  when :int
    values = ::FFI::MemoryPointer.new(:int64, num_values)
    values.write_array_of_int64(attr_value)
    FFI.TFE_OpSetAttrIntList(self, attr_name, values, num_values)
  when :float
    values = ::FFI::MemoryPointer.new(:float, num_values)
    values.write_array_of_float(attr_value)
    FFI.TFE_OpSetAttrFloatList(self, attr_name, values, num_values)
  when :shape
    dims_pointer = ::FFI::MemoryPointer.new(:pointer, num_values)
    num_dims_pointer = ::FFI::MemoryPointer.new(:int32, num_values)
    # dims_pointer only stores raw addresses, and those do not keep the per-shape
    # MemoryPointers alive. Hold them in a Ruby array for the rest of the method so GC
    # cannot free a buffer before TFE_OpSetAttrShapeList reads it.
    dim_pointers = attr_value.each_with_index.map do |shape, i|
      dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
      dim_pointer.write_array_of_int64(shape)
      dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
      num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
      dim_pointer
    end
    FFI.TFE_OpSetAttrShapeList(self, attr_name, dims_pointer, num_dims_pointer, num_values, self.status)
    self.status.check
  when :type
    types = attr_value.map { |v| v.is_a?(Symbol) ? FFI::DataType[v] : v }
    values = ::FFI::MemoryPointer.new(:int, num_values)
    values.write_array_of_int(types)
    FFI.TFE_OpSetAttrTypeList(self, attr_name, values, num_values)
  else
    raise "Unknown list type: #{type}"
  end
end
add_scalar_attr(type, attr_name, attr_value)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 113
#
# Sets a single (non-list) attribute on the op.
#
# @param type [Symbol] attribute kind as returned by FFI.TFE_OpGetAttrType
# @param attr_name [String] attribute name
# @param attr_value [Object] the value; its expected form depends on +type+. Examples: a
#   Symbol (looked up in FFI::DataType) or raw enum value for :type, a dimension Array
#   for :shape, a Graph::Function or function-name String for :func.
# @raise whatever Status#check raises when the value is rejected or +type+ is unsupported
def add_scalar_attr(type, attr_name, attr_value)
  case type
  when :string
    FFI.TFE_OpSetAttrString(self, attr_name, attr_value, attr_value.bytesize)
  when :int
    FFI.TFE_OpSetAttrInt(self, attr_name, attr_value)
  when :float
    FFI.TFE_OpSetAttrFloat(self, attr_name, attr_value)
  when :bool
    FFI.TFE_OpSetAttrBool(self, attr_name, attr_value ? 1 : 0)
  when :type
    attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
    FFI.TFE_OpSetAttrType(self, attr_name, attr_value)
  when :shape
    ptr = ::FFI::MemoryPointer.new(:int64, attr_value.size)
    ptr.write_array_of_int64(attr_value)
    FFI.TFE_OpSetAttrShape(self, attr_name, ptr, attr_value.size, self.status)
  when :tensor
    attr_value = TensorHandle.from_value(self.context, attr_value)
    FFI.TFE_OpSetAttrTensor(self, attr_name, attr_value.tensor, self.status)
  when :func
    func_name =
      case attr_value
      when Graph::Function then attr_value.name
      when String then attr_value
      end
    if func_name
      # The C API expects a byte length, as in the :string branch, not a character count.
      FFI.TFE_OpSetAttrFunctionName(self, attr_name, func_name, func_name.bytesize)
    else
      self.status.set(:tf_invalid_argument, "Invalid function attribute for attribute: #{attr_name}")
    end
  else
    # Includes :placeholder, which is not handled yet.
    self.status.set(:tf_unknown, "Unsupported attribute type: #{type}")
  end
  self.status.check
end
check_input(arg_def, input, dtype)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 155
#
# Converts one raw input into the handle the op should receive.
#
# Non-Variable inputs go through TensorHandle.from_value with +dtype+. For a Variable, a
# resource-typed argument (arg_def.type == :DT_RESOURCE) gets its handle; any other
# argument gets its value_handle.
#
# @param arg_def the input argument definition from op_def
# @param input [Object] the raw input value
# @param dtype the dtype hint forwarded to TensorHandle.from_value
def check_input(arg_def, input, dtype)
  return TensorHandle.from_value(self.context, input, dtype: dtype) unless Variable === input

  if arg_def.type == :DT_RESOURCE
    input.handle
  else
    input.value_handle
  end
end
dtype()
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 31
#
# Queries the op for its attribute named 'dtype'.
#
# NOTE(review): TFE_OpGetAttrType returns the attribute's *kind* (setup_attrs dispatches on
# this same result as :int/:float/:type/...), not the data type stored in it. Confirm this
# is what callers of #dtype expect.
#
# @return the value returned by FFI.TFE_OpGetAttrType for 'dtype'
# @raise whatever Status#check raises, e.g. if the op has no 'dtype' attribute
def dtype
  # The C API writes a single unsigned-char "is list" flag here.
  is_list_ptr = ::FFI::MemoryPointer.new(:uint8)
  attr_type = FFI.TFE_OpGetAttrType(self, 'dtype', is_list_ptr, self.status)
  self.status.check
  attr_type
end
figure_dtype(attrs, inputs)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 36
#
# Guesses the dtype the op works with.
#
# The first op_def attribute of type 'type' is looked up in +attrs+. If that yields
# nothing truthy, the first input that is an Operation (its first output type) or a
# Variable (its dtype) decides. Otherwise the attrs lookup result (nil or false) is returned.
#
# @param attrs [Hash] op attributes keyed by symbol
# @param inputs [Array] op inputs
def figure_dtype(attrs, inputs)
  type_attr = self.op_def.attr.find { |candidate| candidate.type == 'type' }
  explicit = type_attr && attrs[type_attr.name.to_sym]
  return explicit if explicit

  source = inputs.find { |input| Operation === input || Variable === input }
  return explicit unless source

  Operation === source ? source.output_types.first : source.dtype
end
setup_attrs(attrs)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 55
#
# Applies every entry of +attrs+ to the op.
#
# TensorFlow is asked for each attribute's kind and whether it is a list; the value is
# then handed to add_list_attr or add_scalar_attr. Entries whose value is nil are
# skipped. +false+ is applied, because it is a meaningful value for bool attributes.
#
# @param attrs [Hash] attribute name (Symbol or String) => value
# @raise whatever Status#check raises, e.g. for an attribute the op does not define
def setup_attrs(attrs)
  attrs.each do |attr_name, attr_value|
    next if attr_value.nil?

    attr_name = attr_name.to_s
    # TFE_OpGetAttrType writes a single unsigned-char "is list" flag.
    is_list_ptr = ::FFI::MemoryPointer.new(:uint8)
    type = FFI.TFE_OpGetAttrType(self, attr_name, is_list_ptr, self.status)
    self.status.check

    if Boolean(is_list_ptr.read_uint8)
      add_list_attr(type, attr_name, attr_value)
    else
      add_scalar_attr(type, attr_name, attr_value)
    end
  end
end
setup_input(index, value, attrs)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 164
#
# Converts +value+ and adds it as the op's input at position +index+.
#
# List arguments (those with a number_attr or type_list_attr) accept either an Array of
# values or a single value, which is treated as a one-element list. Array values may come
# from an operation with multiple outputs, e.g. calling Pack with the outputs of Split.
#
# @param index [Integer] position in op_def.input_arg
# @param value [Object] the input; must not be nil
# @param attrs [Hash] op attributes; the arg's type_attr names the dtype hint to use
# @raise whatever Status#check raises if +value+ is nil or the op has no input at +index+
def setup_input(index, value, attrs)
  if value.nil?
    self.status.set(:tf_invalid_argument, "Argument is unset. Index: #{index}")
    self.status.check
  end

  arg_def = self.op_def.input_arg[index]
  unless arg_def
    self.status.set(:tf_invalid_argument, "Too many inputs. Index: #{index}")
    self.status.check
  end
  dtype = attrs[arg_def.type_attr.to_sym]

  number_list = !arg_def.number_attr.empty?
  type_list = !arg_def.type_list_attr.empty?

  if number_list || type_list
    # Wrap a lone value so list arguments always receive a list.
    values = value.is_a?(Array) ? value : [value]
    checked_values = values.map { |sub_value| self.check_input(arg_def, sub_value, dtype) }

    if type_list || !arg_def.type_attr.empty?
      # Heterogeneous list (type_list_attr) or homogeneous list (number_attr + type_attr)
      self.add_input_list(checked_values)
    else
      # A number_attr list with a fixed type is added one input at a time
      checked_values.each { |checked| self.add_input(checked) }
    end
  else
    # This input is a single item
    self.add_input(self.check_input(arg_def, value, dtype))
  end
end
setup_inputs(inputs, attrs)
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 149
#
# Feeds every input to setup_input, in order, together with its position.
#
# @param inputs [Array] op inputs
# @param attrs [Hash] op attributes, passed through unchanged
def setup_inputs(inputs, attrs)
  inputs.each.with_index do |input, position|
    setup_input(position, input, attrs)
  end
end
to_ptr()
click to toggle source
# File lib/tensorflow/eager/operation.rb, line 27
#
# Returns the raw TFE_Op pointer created in #initialize. Because of this, +self+ can be
# passed straight to FFI functions that expect a pointer (ruby-ffi calls #to_ptr on such
# arguments).
def to_ptr; @pointer; end