class NN::Dropout

Public Class Methods

new(nn)
# File lib/nn.rb, line 384
# Creates a dropout layer attached to its owning network.
#
# @param nn the owning network object. #forward and #backward read its
#   +training+ flag and its +dropout_ratio+ on every call.
def initialize(nn)
  @nn = nn   # owning network: source of the training flag and dropout ratio
  @mask = nil # no drop mask yet; #forward fills it in while training
end

Public Instance Methods

backward(dout)
# File lib/nn.rb, line 399
# Back-propagates the upstream gradient through the dropout layer.
#
# Training mode: entries that #forward dropped (those flagged in @mask)
# receive zero gradient. All other entries pass through unchanged. This
# relies on #forward having run in training mode so that @mask is set.
#
# Inference mode: #forward multiplies its input by (1 - dropout_ratio), so
# the gradient is scaled by the same factor here to match that path.
#
# @param dout gradient flowing in from the next layer (presumably a
#   Numo array of the same shape as the forward input -- TODO confirm)
# @return a new gradient array; +dout+ itself is never modified.
def backward(dout)
  if @nn.training
    # Copy first: the caller may still hold +dout+, so it must not be zeroed in place.
    masked = dout.dup
    masked[@mask] = 0
    masked
  else
    dout * (1 - @nn.dropout_ratio)
  end
end
forward(x)
# File lib/nn.rb, line 389
# Applies (non-inverted) dropout to the layer input.
#
# Training mode: draws a fresh random mask with one value per element of +x+
# (Numo `rand`, uniform in [0, 1) by default). Elements whose draw falls below
# @nn.dropout_ratio are zeroed. The survivors are NOT rescaled. The mask is
# kept in @mask so #backward can block gradient for the same elements.
#
# Inference mode: nothing is dropped. Instead the whole input is scaled by
# (1 - dropout_ratio), because training left the kept units unscaled and
# this matches their expected magnitude.
#
# @param x layer input (presumably a Numo array -- TODO confirm)
# @return a new array; +x+ itself is never modified.
def forward(x)
  if @nn.training
    @mask = SFloat.ones(*x.shape).rand < @nn.dropout_ratio
    # Mask a copy: zeroing +x+ in place would corrupt any reference the
    # caller or the previous layer still holds to this array.
    dropped = x.dup
    dropped[@mask] = 0
    dropped
  else
    x * (1 - @nn.dropout_ratio)
  end
end