class TensorFlow::Variable

Attributes

name[R]

Public Class Methods

new(initial_value = nil, dtype: nil, shape: nil, name: nil) click to toggle source
# File lib/tensorflow/variable.rb, line 5
# Creates a variable backed by a resource handle obtained from
# RawOps.var_handle_op, then assigns the initial value when one is given.
#
# @param initial_value value assigned right after the handle is created;
#   +nil+ (the default) skips the assignment
# @param dtype data type; when omitted it is inferred with Utils.infer_type
#   from the flattened initial value
# @param shape stored in @shape
# @param name optional label, stored in @name
def initialize(initial_value = nil, dtype: nil, shape: nil, name: nil)
  @dtype = dtype || Utils.infer_type(Array(initial_value).flatten)
  @shape = shape
  @name = name
  # NOTE(review): the handle is created with shape: [] rather than @shape --
  # confirm this is intentional.
  @pointer = RawOps.var_handle_op(dtype: type_enum, shape: [], shared_name: Utils.default_context.shared_name)
  # Test for nil explicitly: `false` is a legitimate initial value, and a
  # plain truthiness check would silently skip assigning it.
  assign(initial_value) unless initial_value.nil?
end

Public Instance Methods

+(other) click to toggle source
# File lib/tensorflow/variable.rb, line 35
# Adds +other+ to a copy of this variable and returns the copy's value.
#
# A new Variable is built from the current value, so the add is applied
# to that copy's handle rather than to @pointer.
#
# @param other operand forwarded to assign_add on the copy
# @return the result of read_value on the copy
def +(other)
  current = read_value.value
  Variable.new(current, dtype: @dtype).assign_add(other).read_value
end
-(other) click to toggle source
# File lib/tensorflow/variable.rb, line 40
# Subtracts +other+ from a copy of this variable and returns the copy's value.
#
# A new Variable is built from the current value, so the subtraction is
# applied to that copy's handle rather than to @pointer.
#
# @param other operand forwarded to assign_sub on the copy
# @return the result of read_value on the copy
def -(other)
  current = read_value.value
  Variable.new(current, dtype: @dtype).assign_sub(other).read_value
end
assign(value) click to toggle source
# File lib/tensorflow/variable.rb, line 13
# Overwrites the variable's contents.
#
# The value is first converted with TensorFlow.convert_to_tensor using this
# variable's dtype, then written through RawOps.assign_variable_op.
#
# @param value the new contents
# @return [self] so calls can be chained
def assign(value)
  tensor = TensorFlow.convert_to_tensor(value, dtype: @dtype)
  RawOps.assign_variable_op(
    resource: @pointer,
    value: tensor
  )
  self
end
assign_add(value) click to toggle source
# File lib/tensorflow/variable.rb, line 19
# Adds +value+ to the variable in place.
#
# The value is first converted with TensorFlow.convert_to_tensor using this
# variable's dtype, then applied through RawOps.assign_add_variable_op.
#
# @param value the amount to add
# @return [self] so calls can be chained
def assign_add(value)
  delta = TensorFlow.convert_to_tensor(value, dtype: @dtype)
  RawOps.assign_add_variable_op(
    resource: @pointer,
    value: delta
  )
  self
end
assign_sub(value) click to toggle source
# File lib/tensorflow/variable.rb, line 25
# Subtracts +value+ from the variable in place.
#
# The value is first converted with TensorFlow.convert_to_tensor using this
# variable's dtype, then applied through RawOps.assign_sub_variable_op.
#
# @param value the amount to subtract
# @return [self] so calls can be chained
def assign_sub(value)
  delta = TensorFlow.convert_to_tensor(value, dtype: @dtype)
  RawOps.assign_sub_variable_op(
    resource: @pointer,
    value: delta
  )
  self
end
inspect() click to toggle source
# File lib/tensorflow/variable.rb, line 53
def inspect
  value = read_value
  inspection = %w(numo shape dtype).map { |v| "#{v}: #{value.send(v).inspect}"}
  inspection.unshift("name: #{name}") if name
  "#<#{self.class} #{inspection.join(", ")}>"
end
read_value() click to toggle source
# File lib/tensorflow/variable.rb, line 31
# Reads the variable's current contents through RawOps.read_variable_op,
# passing this variable's handle and its dtype enum.
#
# @return whatever RawOps.read_variable_op returns for this handle
def read_value
  dtype_code = type_enum
  RawOps.read_variable_op(resource: @pointer, dtype: dtype_code)
end
shape() click to toggle source
# File lib/tensorflow/variable.rb, line 49
# Reports the shape of the variable's current value (taken from a fresh
# read_value rather than from @shape).
#
# @return the +shape+ of the value returned by read_value
def shape
  current = read_value
  current.shape
end
to_ptr() click to toggle source
# File lib/tensorflow/variable.rb, line 60
# Delegates to the +to_ptr+ of the variable's current value.
#
# @return the result of calling +to_ptr+ on read_value
def to_ptr
  current = read_value
  current.to_ptr
end
to_s() click to toggle source
# File lib/tensorflow/variable.rb, line 45
def to_s
  inspect
end

Private Instance Methods

type_enum() click to toggle source
# File lib/tensorflow/variable.rb, line 66
# Looks up this variable's dtype in FFI::DataType.
#
# @return the FFI::DataType entry for the dtype (as a symbol), or +nil+
#   when no dtype is set
def type_enum
  return unless @dtype

  FFI::DataType[@dtype.to_sym]
end