class NN::BatchNorm

Attributes

d_beta[R]
d_gamma[R]

Public Class Methods

new(nn, index)
# File lib/nn.rb, line 412
# Binds this batch-normalization layer to the network that owns its parameters.
#
# @param nn network object whose +gammas+ / +betas+ collections hold this
#   layer's scale and shift
# @param index position of this layer's entry in +nn.gammas+ / +nn.betas+
def initialize(nn, index)
  @nn, @index = nn, index
end

Public Instance Methods

backward(dout)
# File lib/nn.rb, line 426
# Backpropagates +dout+ through the batch-normalization transform.
#
# Must run after #forward: it reads the cached @xc (centered input),
# @std (standard deviation along axis 0) and @xn (normalized input).
# Stores the parameter gradients in @d_beta and @d_gamma and returns the
# gradient with respect to the input x that #forward received.
#
# @param dout gradient of the loss w.r.t. this layer's output
#   (assumed to have the same shape as #forward's x — TODO confirm)
# @return gradient of the loss w.r.t. x
def backward(dout)
  # Divide by the number of rows #forward actually averaged over, not the
  # network's configured batch size. The two can differ (e.g. a short final
  # mini-batch), and the mean/variance derivatives are only correct with
  # the real row count.
  n = @xc.shape[0]
  @d_beta = dout.sum(0)
  @d_gamma = (@xn * dout).sum(0)
  dxn = @nn.gammas[@index] * dout
  # Path through xn = xc / std, first w.r.t. xc directly ...
  dxc = dxn / @std
  # ... then through std = sqrt(var + eps) and var = mean(xc ** 2).
  dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
  dvar = 0.5 * dstd / @std
  dxc += (2.0 / n) * @xc * dvar
  # Finally through xc = x - mean(x).
  dmean = dxc.sum(0)
  dxc - dmean / n
end
forward(x)
# File lib/nn.rb, line 417
# Normalizes +x+ with statistics taken along axis 0, then applies this
# layer's scale (gamma) and shift (beta) looked up in @nn at @index.
#
# Caches @mean, @xc, @var, @std and @xn; #backward reads @xc, @std and @xn.
#
# @param x input batch; mean and variance are reduced along axis 0
# @return gamma * normalized(x) + beta
def forward(x)
  # Center the input on its mean along axis 0.
  @mean = x.mean(0)
  @xc = x - @mean

  # Population variance (divides by the row count); the small epsilon keeps
  # the square root, and the division below, away from zero.
  @var = (@xc ** 2).mean(0)
  @std = NMath.sqrt(@var + 1e-7)
  @xn = @xc / @std

  scale = @nn.gammas[@index]
  shift = @nn.betas[@index]
  scale * @xn + shift
end