class MultiArmedBandit::Softmax
Attributes
counts[RW]
n_arms[RW]
probs[RW]
temperature[RW]
values[RW]
Public Class Methods
new(temperature, n_arms)
click to toggle source
Initialize an object
# File lib/multi_armed_bandit/softmax.rb, line 10
# Set up a new instance.
#
# temperature - stored as-is in @temperature
# n_arms      - stored as-is in @n_arms
#
# After storing both arguments, calls reset.
def initialize(temperature, n_arms)
  @temperature = temperature
  @n_arms = n_arms
  reset
end
Public Instance Methods
bulk_update(new_counts, new_rewards)
click to toggle source
Update all arms in bulk. new_counts is a list of the number of trials of each arm, and new_rewards is a list of the rewards of each arm. Every number in both new_counts and new_rewards should be a cumulative total.
# File lib/multi_armed_bandit/softmax.rb, line 26
# Replace the state of every arm at once and recompute the selection
# probabilities.
#
# new_counts  - list holding the cumulative number of trials of each arm;
#               stored as @counts
# new_rewards - list holding the cumulative reward of each arm, in the same
#               order as new_counts
#
# Returns the recomputed list of probabilities (also stored in @probs).
def bulk_update(new_counts, new_rewards)
  @counts = new_counts

  # Mean reward per arm. An arm that has never been tried gets 0.0 instead of
  # the NaN/Infinity a division by zero would produce, which would otherwise
  # contaminate every probability below.
  @values = @counts.zip(new_rewards).map do |n, reward|
    n.zero? ? 0.0 : reward / n.to_f
  end

  # Softmax of value / temperature. Subtracting the largest scaled value
  # before exponentiating leaves the resulting ratios unchanged but keeps
  # Math.exp from overflowing to Infinity.
  scaled = @values.map { |value| value / @temperature }
  peak = scaled.max
  weights = scaled.map { |s| Math.exp(s - peak) }
  total = weights.reduce(:+)
  @probs = weights.map { |w| w / total }
end
reset()
click to toggle source
Reset instance variables
# File lib/multi_armed_bandit/softmax.rb, line 17
# Return the per-arm state to its initial condition: @counts becomes a list
# of @n_arms integer zeros, while @values and @probs each become a list of
# @n_arms 0.0 floats.
def reset
  @values = Array.new(@n_arms) { 0.0 }
  @counts = Array.new(@n_arms) { 0 }
  @probs = Array.new(@n_arms) { 0.0 }
end
select_arm()
click to toggle source
# File lib/multi_armed_bandit/softmax.rb, line 57
# Choose an arm at random, weighting each arm by the softmax of its value.
#
# Recomputes @probs from @values and @temperature, then passes the result to
# categorical_draw.
#
# Returns whatever categorical_draw returns for the recomputed probabilities.
def select_arm
  # Softmax of value / temperature. Subtracting the largest scaled value
  # before exponentiating leaves the resulting ratios unchanged but keeps
  # Math.exp from overflowing to Infinity.
  scaled = @values.map { |value| value / @temperature }
  peak = scaled.max
  weights = scaled.map { |s| Math.exp(s - peak) }
  total = weights.reduce(:+)
  @probs = weights.map { |w| w / total }

  categorical_draw(@probs)
end
update(chosen_arm, reward)
click to toggle source
# File lib/multi_armed_bandit/softmax.rb, line 46
# Record one observed reward for a single arm.
#
# chosen_arm - index into @counts and @values of the arm that was played
# reward     - the reward observed for that play
#
# Increments the arm's count and replaces its stored value with the running
# average of the rewards, weighting the previous value by (n - 1) / n and the
# new reward by 1 / n, where n is the updated count.
#
# Returns nil.
def update(chosen_arm, reward)
  pulls = (@counts[chosen_arm] += 1)
  pulls_f = pulls.to_f
  previous = @values[chosen_arm]

  @values[chosen_arm] = ((pulls - 1) / pulls_f) * previous + (1 / pulls_f) * reward
  nil
end
Private Instance Methods
categorical_draw(probs)
click to toggle source
# File lib/multi_armed_bandit/softmax.rb, line 64
# Sample an index from a list of probabilities.
#
# probs - list of per-index weights
#
# Draws a single random number with rand, then walks the list while summing
# the weights; the first index whose running sum exceeds the random number is
# the result.
#
# Returns that index, or probs.size - 1 when no running sum exceeds the
# random number (which is -1 for an empty list).
def categorical_draw(probs)
  threshold = rand
  running_total = 0.0

  hit = (0...probs.size).find do |index|
    running_total += probs[index]
    running_total > threshold
  end

  hit || probs.size - 1
end