class Tensorflow::Variable
Attributes
dtype[R]
handle[R]
name[R]
Public Class Methods
new(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
click to toggle source
# File lib/tensorflow/variable.rb, line 7
# Creates a resource variable backed by RawOps.var_handle_op and registers it
# in the current execution context's variable collections.
#
# @param initial_value [nil, Graph::Operation, Tensor, Object] value assigned
#   on creation. Plain Ruby values are converted with Tensor.from_value.
#   When nil (or false), no assignment is made and #initializer stays unset.
# @param dtype [Object, nil] element type. Used as-is when initial_value is
#   nil, preferred over the operation's dtype for a Graph::Operation, passed
#   to Tensor.from_value for plain values, and ignored for a Tensor.
# @param shape [Array, nil] variable shape. An explicit shape always wins;
#   otherwise it comes from the initial value, or defaults to [] when there
#   is no initial value.
# @param shared_name [String, nil] resource name; defaults to the generated
#   unique name.
# @param name [String, nil] base name handed to ExecutionContext#unique_name.
# @param trainable [Boolean] also add to Graph::GraphKeys::TRAINABLE_VARIABLES.
def initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
  initial_value =
    case initial_value
    when NilClass
      @dtype = dtype
      # Honor a caller-supplied shape; previously it was overwritten with [].
      shape ||= []
      initial_value
    when Graph::Operation
      @dtype = dtype || initial_value.dtype
      shape ||= initial_value.output_shapes.first
      initial_value
    when Tensor
      @dtype = initial_value.dtype
      shape ||= initial_value.shape
      initial_value
    else
      tensor = Tensor.from_value(initial_value, dtype: dtype)
      @dtype = tensor.dtype
      # Same precedence as the Tensor branch: explicit shape first.
      shape ||= tensor.shape
      tensor
    end

  name = name&.to_s
  shared_name = shared_name&.to_s
  # name defaults to 'Variable', so shared_name only seeds the unique name
  # when name is explicitly passed as nil.
  unique_name = ExecutionContext.current.unique_name(name || shared_name)
  shared_name ||= unique_name
  @name = unique_name

  collections = [Graph::GraphKeys::GLOBAL_VARIABLES]
  collections << Graph::GraphKeys::TRAINABLE_VARIABLES if trainable
  ExecutionContext.current.add_to_collections(collections, self)

  @handle = RawOps.var_handle_op(dtype: @dtype, shape: shape, shared_name: shared_name, name: unique_name)
  self.value = initial_value if initial_value
end
Public Instance Methods
assign_add(value, dtype: nil)
click to toggle source
# File lib/tensorflow/variable.rb, line 101
# Adds +value+ to the variable in place.
#
# +value+ is converted with Tensor.from_value (using +dtype+ when given) and
# then cast to the variable's dtype. The cast tensor is what the op receives.
# The memoized read handle is cleared so #value_handle issues a new read.
#
# @param value [Object] amount to add
# @param dtype [Object, nil] dtype used when converting +value+ to a tensor
# @return the result of RawOps.assign_add_variable_op
def assign_add(value, dtype: nil)
  @value_handle = nil
  tensor = Tensor.from_value(value, dtype: dtype)
  tensor = Tensorflow.cast(tensor, self.dtype)
  # Pass the cast tensor, not the raw value, so the op input agrees with the
  # dtype declared here (previously the cast result was discarded).
  RawOps.assign_add_variable_op(handle, tensor, dtype: tensor.dtype)
end
assign_sub(value)
click to toggle source
# File lib/tensorflow/variable.rb, line 108
# Subtracts +value+ from the variable in place.
#
# +value+ is converted with Tensor.from_value using the variable's own dtype,
# then cast to that dtype. The cast tensor is what the op receives.
# The memoized read handle is cleared so #value_handle issues a new read.
#
# @param value [Object] amount to subtract
# @return the result of RawOps.assign_sub_variable_op
def assign_sub(value)
  @value_handle = nil
  # There is no dtype keyword here; the bare `dtype` in the original resolved
  # to the variable's dtype reader, made explicit below.
  tensor = Tensor.from_value(value, dtype: self.dtype)
  tensor = Tensorflow.cast(tensor, self.dtype)
  # Pass the cast tensor, not the raw value (previously the cast was discarded).
  RawOps.assign_sub_variable_op(handle, tensor, dtype: tensor.dtype)
end
consumers()
click to toggle source
These methods mirror the operation API so that variables can be used with gradients and sessions.
# File lib/tensorflow/variable.rb, line 71
# Consumers of this variable, delegated straight to the underlying handle
# (part of the operation-like interface used by gradients and sessions).
def consumers
  handle.consumers
end
initialized?()
click to toggle source
# File lib/tensorflow/variable.rb, line 66
# Returns whatever RawOps.var_is_initialized_op yields for this variable's
# handle (presumably a boolean tensor or graph op; confirm per execution mode).
def initialized?
  RawOps.var_is_initialized_op(handle)
end
initializer()
click to toggle source
# File lib/tensorflow/variable.rb, line 62
# The assign op stored by #value=, or nil when no value has been assigned.
def initializer
  @initializer
end
inspect()
click to toggle source
# File lib/tensorflow/variable.rb, line 119
# Debug representation listing the handle's name (if it has one), followed
# by the shape and dtype of the read handle.
def inspect
  parts = []
  parts << "name: #{handle.name}" if handle.respond_to?(:name)
  parts << "shape: #{value_handle.shape}"
  parts << "dtype: #{value_handle.dtype}"
  "#<#{self.class} #{parts.join(", ")}>"
end
outputs()
click to toggle source
This allows a variable to be executed in a session to fetch its value.
# File lib/tensorflow/variable.rb, line 76
# A single-element list holding output 0 of the read handle, so the variable
# can be passed where an operation's outputs are expected.
def outputs
  first_output = Graph::OperationOutput.from_index(value_handle, 0)
  [first_output]
end
rank()
click to toggle source
# File lib/tensorflow/variable.rb, line 93
# Number of dimensions, i.e. the size of #shape.
def rank
  shape.size
end
reshape(shape)
click to toggle source
# File lib/tensorflow/variable.rb, line 97
# Builds a reshape of this variable via RawOps.reshape.
#
# @param shape target shape (shadows the #shape reader inside this method)
def reshape(shape)
  target_shape = shape
  RawOps.reshape(self, target_shape)
end
shape()
click to toggle source
# File lib/tensorflow/variable.rb, line 84
# Shape reported by the (memoized) read handle.
def shape
  value_handle.shape
end
tensor()
click to toggle source
# File lib/tensorflow/variable.rb, line 88
# The read handle's tensor. Only meaningful in eager mode.
#
# @raise [Error::UnavailableError] when Tensorflow.execution_mode is GRAPH_MODE
def tensor
  if Tensorflow.execution_mode == Tensorflow::GRAPH_MODE
    raise(Error::UnavailableError, "Only supported in eager execution mode")
  end

  value_handle.tensor
end
to_ptr()
click to toggle source
# File lib/tensorflow/variable.rb, line 80
# Pointer of the underlying handle, delegated unchanged.
def to_ptr
  handle.to_ptr
end
to_s()
click to toggle source
# File lib/tensorflow/variable.rb, line 115
# String form is the same as #inspect.
def to_s
  self.inspect
end
value()
click to toggle source
# File lib/tensorflow/variable.rb, line 49
# Current value of the variable.
#
# In eager mode (Eager::TensorHandle) this returns the handle's value.
# In graph mode (Graph::Operation) it returns the read operation itself.
# Any other handle type yields nil.
def value
  current = value_handle
  case current
  when Eager::TensorHandle then current.value
  when Graph::Operation then current
  end
end
value=(value)
click to toggle source
# File lib/tensorflow/variable.rb, line 58
# Assigns +value+ to the variable via RawOps.assign_variable_op and stores
# the resulting op as #initializer.
#
# The memoized read handle is dropped first, as assign_add/assign_sub do.
# Otherwise #value / #value_handle could keep returning a read handle created
# before this assignment.
#
# @param value [Object] new value for the variable
def value=(value)
  @value_handle = nil
  @initializer = RawOps.assign_variable_op(handle, value, dtype: @dtype)
end
value_handle()
click to toggle source
# File lib/tensorflow/variable.rb, line 45
# Lazily creates, and then reuses, a read of the variable through
# RawOps.read_variable_op. The cache is reset by the assign_* methods.
def value_handle
  return @value_handle if @value_handle

  @value_handle = RawOps.read_variable_op(handle, dtype: @dtype)
end