class Rblearn::CrossValidation::KFold

Public Class Methods

new(n, n_folds, shuffle)

TODO: make indices and n_folds private

# File lib/rblearn/CrossValidation.rb, line 19
# Build a K-fold splitter over the sample indices 0...n.
#
# @param n [Integer] number of samples; indices 0 through n - 1 are used.
# @param n_folds [Integer] number of folds that #create will produce.
# @param shuffle [Boolean] when truthy, the index order is randomly permuted
#   (using Ruby's default random generator) before folds are formed.
def initialize(n, n_folds, shuffle)
  ordered = Array(0...n)
  @indices = shuffle ? ordered.shuffle : ordered
  @n_folds = n_folds
end

Public Instance Methods

create()
# File lib/rblearn/CrossValidation.rb, line 26
# Split the stored indices into @n_folds contiguous folds and build one
# train/test pair per fold.
#
# Fold sizes differ by at most one: the first (size % n_folds) folds each get
# one extra index. The earlier ceil-based slicing could produce fewer than
# n_folds slices (e.g. n = 10, n_folds = 6 gave only 5 slices). It then
# crashed with a TypeError when it concatenated the missing fold.
#
# @return [Array<Array(Array<Integer>, Array<Integer>)>] one
#   [training_indices, test_indices] pair per fold, in fold order. The
#   training indices are all other folds concatenated in their original order.
# @raise [ArgumentError] if n_folds is not an Integer in 1..indices.size.
def create
  size = @indices.size
  unless @n_folds.is_a?(Integer) && @n_folds.between?(1, size)
    raise ArgumentError, "n_folds must be an Integer between 1 and #{size}, got #{@n_folds.inspect}"
  end

  base, extra = size.divmod(@n_folds)
  offset = 0
  folds = Array.new(@n_folds) do |k|
    # Spread the remainder over the leading folds so sizes stay balanced.
    length = k < extra ? base + 1 : base
    fold = @indices[offset, length]
    offset += length
    fold
  end

  folds.map.with_index do |test_set, k|
    training_set = folds.reject.with_index { |_, j| j == k }.flatten(1)
    [training_set, test_set]
  end
end