class Lurn::Neighbors::KNNBase

Attributes

k[RW]
predictors[RW]
targets[RW]

Public Class Methods

new(k)
# File lib/lurn/neighbors/knn_base.rb, line 7
def initialize(k)
  @k = k
end

Public Instance Methods

fit(predictors, targets)

Trains the KNN model to predict the target variable from the predictors. KNN defers all computation until prediction time, so fit only stores the data: each predictor row is converted to a Vector and the targets are kept as given.

@param predictors [Array-like] An array of arrays containing the predictor data
@param targets [Array-like] An array of the target values you want to predict, one per row of predictors
@return nil

# File lib/lurn/neighbors/knn_base.rb, line 18
def fit(predictors, targets)
  @predictors = predictors.map { |pred| Vector.elements(pred) }
  @targets = targets

  nil
end
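One likely reason fit converts each row with Vector.elements: Ruby's Array#- is set difference, not element-wise subtraction, whereas euclidian_distance needs element-wise subtraction via Vector#-. A minimal illustration, using Ruby's matrix library (a bundled gem since Ruby 3.1); the arrays are arbitrary sample values:

```ruby
require 'matrix'

a = [3, 4]
b = [0, 4]

# Array#- is set difference: it drops every element of a that also appears in b.
array_diff = a - b                                    # => [3]

# Vector#- subtracts position by position, which a distance calculation needs.
vector_diff = Vector.elements(a) - Vector.elements(b) # => Vector[3, 0]
```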

nearest_neighbors(vector)

Returns the predictors and target values of the k training points nearest to vector, measured by Euclidean distance.

@param vector [Array-like] An array of the same length and type as the predictors used to train the model
@return [Array, Array] Two values. The first is an array of the predictors of the k nearest neighbors, ordered from nearest to farthest. The second is an array of the corresponding target values.

# File lib/lurn/neighbors/knn_base.rb, line 31
def nearest_neighbors(vector)
  vector = Vector.elements(vector)

  distances = @predictors.map.with_index do |p, index|
    { index: index, distance: euclidian_distance(p, vector), value: targets[index] }
  end

  distances.sort! { |x,y| x[:distance] <=> y[:distance] }

  neighboring_predictors = distances.first(@k).map { |neighbor| @predictors[neighbor[:index]] }
  neighboring_targets = distances.first(@k).map { |neighbor| @targets[neighbor[:index]] }

  return neighboring_predictors, neighboring_targets
end
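To make the return value concrete, here is a small worked sketch. TinyKNN is a hypothetical stand-in that mirrors fit and nearest_neighbors so the example runs without the lurn gem; it swaps the sort! comparator and the euclidian_distance helper for sort_by and Vector#norm, which rank rows by the same Euclidean distance. The final averaging step is an assumption about how a regression subclass might use the neighbors, not something KNNBase does itself.

```ruby
require 'matrix'

# Hypothetical stand-in mirroring KNNBase#fit and #nearest_neighbors.
class TinyKNN
  def initialize(k)
    @k = k
  end

  def fit(predictors, targets)
    @predictors = predictors.map { |row| Vector.elements(row) }
    @targets = targets
    nil
  end

  def nearest_neighbors(vector)
    query = Vector.elements(vector)
    # Rank training row indices by Euclidean distance to the query; keep the k closest.
    nearest = @predictors.each_index
                         .sort_by { |i| (@predictors[i] - query).norm }
                         .first(@k)
    [nearest.map { |i| @predictors[i] }, nearest.map { |i| @targets[i] }]
  end
end

knn = TinyKNN.new(2)
knn.fit([[0, 0], [1, 1], [5, 5]], [1.0, 2.0, 9.0])

# Distances from [0.9, 0.9]: about 0.14 to [1, 1], 1.27 to [0, 0], 5.8 to [5, 5].
preds, targs = knn.nearest_neighbors([0.9, 0.9])
# preds => [Vector[1, 1], Vector[0, 0]], targs => [2.0, 1.0]

# Hypothetical regression step: average the neighbors' targets.
prediction = targs.sum / targs.size # => 1.5
```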

Private Instance Methods

euclidian_distance(vector1, vector2)
# File lib/lurn/neighbors/knn_base.rb, line 48
def euclidian_distance(vector1, vector2)
  Math.sqrt((vector1 - vector2).map { |v| (v.abs)**2 }.inject(:+))
end
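euclidian_distance computes d(x, y) = sqrt(sum over i of (x_i - y_i)^2). Ruby's Vector#norm returns the Euclidean length of a vector, so (vector1 - vector2).norm should agree with it. A quick check on a 3-4-5 right triangle, copying the method body above into a standalone function:

```ruby
require 'matrix'

# Same expression as KNNBase#euclidian_distance, copied here so the check runs standalone.
def euclidian_distance(vector1, vector2)
  Math.sqrt((vector1 - vector2).map { |v| (v.abs)**2 }.inject(:+))
end

a = Vector[0, 0]
b = Vector[3, 4]

manual = euclidian_distance(a, b) # => 5.0
builtin = (a - b).norm            # => 5.0
```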