class TensorFlow::Keras::Metrics::Mean
Public Class Methods
new(name: nil, dtype: :float)
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 5
# Builds a streaming-mean metric backed by two accumulator weights:
# a running total of all values and a running count of elements.
#
# @param name [Object, nil] accepted for API compatibility; not used by this method
# @param dtype [Symbol] dtype handed to both accumulator weights (default :float)
def initialize(name: nil, dtype: :float)
  @dtype = dtype
  # Both weights start from the "zeros" initializer; "total" is created first, then "count".
  @total, @count = %w[total count].map do |weight_name|
    Utils.add_weight(name: weight_name, initializer: "zeros", dtype: @dtype)
  end
end
Public Instance Methods
call(*args)
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 11
# Makes the metric invocable like a function: every positional argument is
# forwarded unchanged to #update_state, whose return value is passed back.
def call(*inputs)
  update_state(*inputs)
end
reset_states()
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 26
# Clears the accumulated state so the metric starts over from zero.
#
# The previous body was empty, so calling this kept the old running sum and
# count and #result kept reporting the stale mean. Both accumulators are now
# rebuilt with the same Utils.add_weight calls #initialize makes
# ("zeros" initializer, metric dtype).
# NOTE(review): if the project's Variable exposes an in-place `assign`,
# `@total.assign(0)` / `@count.assign(0)` would keep existing references to the
# weights valid; confirm the API before switching.
#
# @return [nil]
def reset_states
  @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
  @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
  nil
end
result()
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 22
# Returns the current mean: the accumulated total divided by the accumulated
# count. Uses DivNoNan, which per TensorFlow's op definition yields 0 rather
# than NaN when the count is still zero.
def result
  # Cast the count to the metric's own dtype so it matches @total. The old
  # hard-coded :float mismatched x and y for any metric built with another
  # dtype (e.g. :double).
  RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, @dtype))
end
update_state(values)
click to toggle source
# File lib/tensorflow/keras/metrics/mean.rb, line 15
# Folds a batch of values into the running mean.
#
# The values are converted to a tensor and cast to the metric dtype. Their sum is
# added to the running total, and their element count (cast to the same dtype)
# is added to the running count.
#
# @param values [Object] anything TensorFlow.convert_to_tensor accepts
# @return the value returned by the count accumulator's assign_add
def update_state(values)
  tensor = TensorFlow.cast(TensorFlow.convert_to_tensor(values), @dtype)
  @total.assign_add(Math.reduce_sum(tensor))
  element_count = RawOps.size(input: tensor)
  @count.assign_add(TensorFlow.cast(element_count, @dtype))
end