class TensorFlow::Keras::Metrics::Mean

Public Class Methods

new(name: nil, dtype: :float)
# File lib/tensorflow/keras/metrics/mean.rb, line 5
# Creates a streaming-mean metric.
#
# Two zero-initialised weights are allocated via Utils.add_weight, both
# with the requested dtype:
#   @total - running sum of every value passed to update_state
#   @count - running number of values passed to update_state
#
# @param name [Object] accepted for API compatibility; not used by this body
# @param dtype [Symbol] dtype used for both accumulators (default :float)
def initialize(name: nil, dtype: :float)
  @dtype = dtype
  # Same order as before: "total" is created first, then "count".
  @total, @count = %w[total count].map do |weight_name|
    Utils.add_weight(name: weight_name, initializer: "zeros", dtype: @dtype)
  end
end

Public Instance Methods

call(*args)
# File lib/tensorflow/keras/metrics/mean.rb, line 11
# Shorthand for update_state, so the metric can be invoked as
# `metric.(values)` / `metric.call(values)`.
#
# All positional arguments are forwarded unchanged; update_state takes a
# single argument, so any other arity raises ArgumentError there.
#
# @return whatever update_state returns
def call(*forwarded)
  update_state(*forwarded)
end
reset_states()
# File lib/tensorflow/keras/metrics/mean.rb, line 26
# Clears all accumulated state so the metric starts from zero again
# (e.g. at the start of a new epoch).
#
# Previously this method was empty, so result kept averaging over every value
# ever seen. The accumulators are rebuilt with exactly the same
# Utils.add_weight call used when the metric was constructed, so they come
# back zero-initialised with the metric's dtype.
#
# NOTE(review): this replaces the weight objects instead of zeroing them in
# place. Any external reference to the old @total/@count objects would be
# stale; none is exposed by the visible class.
#
# @return [nil]
def reset_states
  @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
  @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
  nil
end
result()
# File lib/tensorflow/keras/metrics/mean.rb, line 22
# Returns the current mean: @total / @count, computed with div_no_nan.
# Per the TensorFlow DivNoNan op, the result is 0 rather than NaN when the
# count is zero, i.e. before any values have been observed.
#
# Both operands must share a dtype for DivNoNan. @count was previously always
# cast to :float, which broke any metric built with another dtype
# (e.g. :double). It is now cast to the metric's own dtype, which is a no-op
# for the default :float.
#
# NOTE(review): DivNoNan is only defined for floating/complex types, so an
# integer dtype presumably still fails here. Confirm whether that should be
# supported.
def result
  RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, @dtype))
end
update_state(values)
# File lib/tensorflow/keras/metrics/mean.rb, line 15
# Folds a batch of values into the running mean.
#
# The input is converted to a tensor and cast to the metric's dtype. Its sum
# is added to @total, and its element count (RawOps.size, cast to the same
# dtype) is added to @count.
#
# @param values anything TensorFlow.convert_to_tensor accepts
# @return the value of the final @count.assign_add call
def update_state(values)
  tensor = TensorFlow.cast(TensorFlow.convert_to_tensor(values), @dtype)
  @total.assign_add(Math.reduce_sum(tensor))
  element_count = TensorFlow.cast(RawOps.size(input: tensor), @dtype)
  @count.assign_add(element_count)
end