module Rblearn::CrossValidation

Public Class Methods

train_test_split(x, y, test_size=0.33)

x, y: NArray objects. Samples (rows) are selected with x[Array<Integer>, true].

# File lib/rblearn/CrossValidation.rb, line 6
# Randomly splits paired samples into a training set and a test set.
#
# Sample i is row i (first axis) of +x+ and element/row i of +y+. Both arrays
# are indexed with the same shuffled index list, so features and targets stay
# aligned in every returned piece.
#
# x         - NArray-like array whose first axis (+shape[0]+) enumerates samples.
#             Assumed to support +shape+ and +[Array<Integer>, true, ...]+
#             indexing (NArray / Numo::NArray) — TODO confirm supported flavors.
# y         - NArray-like targets with the same number of samples as +x+; may be
#             a 1-D vector or a multi-dimensional array.
# test_size - Fraction (0..1) of samples placed in the test set. The test count
#             is truncated toward zero, e.g. 10 samples * 0.33 => 3 test rows.
# random:   - RNG handed to Array#shuffle (e.g. Random.new(seed)) for
#             reproducible splits. Defaults to Ruby's global generator.
#
# Returns [x_train, y_train, x_test, y_test] — note this is NOT sklearn's
# (x_train, x_test, y_train, y_test) order.
# Raises ArgumentError when x and y differ in sample count or test_size is
# outside 0..1.
def self.train_test_split(x, y, test_size = 0.33, random: Random)
  doc_size = x.shape[0]
  unless y.shape[0] == doc_size
    raise ArgumentError,
          "x and y must have the same number of samples (#{doc_size} vs #{y.shape[0]})"
  end
  # Out-of-range fractions previously produced nil/negative-index slices.
  unless test_size.between?(0, 1)
    raise ArgumentError, "test_size must be between 0 and 1, got #{test_size}"
  end

  # Select whole samples: index the first axis and take everything (+true+)
  # along each remaining axis, so a 1-D vector is simply arr[indices] and a
  # 2-D matrix is arr[indices, true] as before.
  take_rows = lambda do |arr, indices|
    arr[indices, *Array.new(arr.shape.size - 1, true)]
  end

  random_indices = (0...doc_size).to_a.shuffle(random: random)
  endpoint = (doc_size * test_size).to_i
  test_indices = random_indices[0...endpoint]
  # endpoint <= doc_size after validation, so this is [] (never nil) at worst.
  train_indices = random_indices[endpoint..-1]

  [
    take_rows.call(x, train_indices),
    take_rows.call(y, train_indices),
    take_rows.call(x, test_indices),
    take_rows.call(y, test_indices)
  ]
end