class Rblearn::CrossValidation::KFold
Public Class Methods
new(n, n_folds, shuffle)
click to toggle source
TODO: make indices and n_folds private
# File lib/rblearn/CrossValidation.rb, line 19
# Prepares a K-fold splitter over the sample positions 0...n.
#
# @param n [Integer] number of samples; the positions 0...n are what get
#   distributed among the folds
# @param n_folds [Integer] how many folds #create should produce
# @param shuffle [Boolean] when truthy, the positions are randomly permuted
#   with Ruby's default random generator before being stored
def initialize(n, n_folds, shuffle)
  positions = Array(0...n)
  @indices = shuffle ? positions.shuffle : positions
  @n_folds = n_folds
end
Public Instance Methods
create()
click to toggle source
# File lib/rblearn/CrossValidation.rb, line 26
# Splits the stored indices into @n_folds contiguous folds and returns one
# [training, held_out] pair per fold.
#
# Fold sizes are balanced: the first (n % n_folds) folds get one extra index,
# so every fold holds floor(n / n_folds) or ceil(n / n_folds) members.
# Slicing by ceil(n / n_folds) alone can produce fewer than n_folds groups
# (e.g. n = 10, n_folds = 6 gives 5 groups), which previously made the loop
# read past the end of the group list and crash with a TypeError.
#
# @return [Array<Array(Array<Integer>, Array<Integer>)>] n_folds pairs; in
#   pair k the second element is fold k (held out) and the first element is
#   all other folds concatenated in fold order (the training portion)
# @raise [ArgumentError] if n_folds is not between 1 and the number of indices
def create
  total = @indices.size
  unless @n_folds.between?(1, total)
    raise ArgumentError, "n_folds must be between 1 and #{total} (got #{@n_folds})"
  end

  base, extra = total.divmod(@n_folds)
  offset = 0
  folds = Array.new(@n_folds) do |k|
    # The first `extra` folds absorb the remainder, one index each.
    size = k < extra ? base + 1 : base
    fold = @indices[offset, size]
    offset += size
    fold
  end

  folds.each_index.map do |k|
    training = folds.reject.with_index { |_, j| j == k }.flatten(1)
    [training, folds[k]]
  end
end