class Tensorflow::Keras::Metrics::Mean

# Public class methods

# new(name: nil, dtype: :float)
# File lib/tensorflow/keras/metrics/mean.rb, line 5
# Builds a streaming-mean metric backed by two zero-initialised weights:
# a running sum ("total") and a running element count ("count").
#
# name  - accepted for API compatibility; not currently stored or used.
# dtype - dtype shared by both accumulator weights (default :float).
def initialize(name: nil, dtype: :float)
  @dtype = dtype
  # Create the two accumulators in order: total first, then count.
  @total, @count = %w[total count].map do |weight_name|
    Utils.add_weight(name: weight_name, initializer: "zeros", dtype: @dtype)
  end
end

# Public instance methods

# call(*args)
# File lib/tensorflow/keras/metrics/mean.rb, line 11
# Makes the metric callable: every positional argument is forwarded
# unchanged to #update_state, whose return value is passed back.
def call(*inputs)
  update_state(*inputs)
end
# reset_states()
# File lib/tensorflow/keras/metrics/mean.rb, line 25
# Clears the accumulated state: sets both the running total and the running
# count back to zero so the metric starts a fresh mean.
#
# NOTE(review): assumes the weights returned by Utils.add_weight expose
# `assign` (alongside the `assign_add` used in #update_state) and convert a
# plain 0 to the weight's dtype — confirm against the Variable API.
def reset_states
  @total.assign(0)
  @count.assign(0)
end
# result()
# File lib/tensorflow/keras/metrics/mean.rb, line 21
# Returns the current mean, i.e. total / count.
#
# Uses DivNoNan, so the result is 0 rather than NaN when nothing has been
# accumulated yet (count == 0).
def result
  # Cast the count to the metric's own dtype (not a hard-coded :float) so both
  # DivNoNan operands share @total's type; identical for the default :float.
  RawOps.div_no_nan(@total, Tensorflow.cast(@count, @dtype))
end
# update_state(values)
# File lib/tensorflow/keras/metrics/mean.rb, line 15
# Folds a batch of observations into the running mean.
#
# values - the observations to accumulate. They are cast to the metric's
#          dtype, their sum is added to @total, and their element count
#          (cast to the same dtype) is added to @count.
def update_state(values)
  # Cast the incoming `values`. Referencing `input` on its own right-hand side
  # would read a still-nil local and silently ignore the argument.
  # NOTE(review): other call sites in this class pass the dtype to
  # Tensorflow.cast positionally; confirm the `destination_dtype:` keyword form.
  input = Tensorflow.cast(values, destination_dtype: @dtype)
  @total.assign_add(Math.reduce_sum(input))
  @count.assign_add(Tensorflow.cast(RawOps.size(input), @dtype))
end