class TorchVision::Models::VGG

Constants

CFGS
MODEL_URLS

Public Class Methods

make_layers(cfg, batch_norm)
# File lib/torchvision/models/vgg.rb, line 73
# Builds the convolutional feature extractor described by +cfg+.
#
# Each entry of +cfg+ is handled in order:
# - the string "M" becomes a max-pool layer (kernel 2, stride 2);
# - any other value is used as an output-channel count for a 3x3 convolution
#   with padding 1, followed by BatchNorm2d (only when +batch_norm+ is truthy)
#   and an in-place ReLU.
# The first convolution reads 3 input channels; each later convolution reads
# the channel count produced by the previous one.
#
# @param cfg [Array] layer spec (e.g. one of the CFGS values)
# @param batch_norm [Boolean] insert BatchNorm2d after every convolution
# @return [Torch::NN::Sequential] the layers, in cfg order
def self.make_layers(cfg, batch_norm)
  channels = 3
  stages = cfg.flat_map do |entry|
    next [Torch::NN::MaxPool2d.new(2, stride: 2)] if entry == "M"

    # Construction order (conv, then norm, then relu) matches the Sequential order.
    block = [Torch::NN::Conv2d.new(channels, entry, 3, padding: 1)]
    block << Torch::NN::BatchNorm2d.new(entry) if batch_norm
    block << Torch::NN::ReLU.new(inplace: true)
    channels = entry
    block
  end
  Torch::NN::Sequential.new(*stages)
end
make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
# File lib/torchvision/models/vgg.rb, line 62
# Builds a VGG network from a named layer configuration and optionally loads
# pretrained weights.
#
# Both lookups are validated before any layers are constructed, so a bad key
# fails fast with a clear message instead of a NoMethodError on nil deep
# inside make_layers or the download code.
#
# @param arch key into MODEL_URLS; consulted only when +pretrained+ is true
# @param cfg key into CFGS selecting the layer layout
# @param batch_norm [Boolean] forwarded to make_layers
# @param pretrained [Boolean] download and load the weights for +arch+
# @param kwargs [Hash] forwarded to VGG.new (e.g. num_classes:, init_weights:)
# @return [VGG]
# @raise [ArgumentError] if +cfg+ is not a key of CFGS, or if +pretrained+ is
#   true and +arch+ is not a key of MODEL_URLS
def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
  layout = CFGS.fetch(cfg) { raise ArgumentError, "unknown VGG config: #{cfg.inspect}" }

  if pretrained
    url = MODEL_URLS.fetch(arch) { raise ArgumentError, "no pretrained weights for #{arch.inspect}" }
    # Skip the model's own weight init: the pretrained state dict is loaded right after.
    kwargs[:init_weights] = false
  end

  model = VGG.new(make_layers(layout, batch_norm), **kwargs)
  model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if pretrained
  model
end
new(features, num_classes: 1000, init_weights: true)
Calls superclass method
# File lib/torchvision/models/vgg.rb, line 15
# Assembles a VGG network around a prebuilt feature extractor.
#
# The head is: adaptive average pooling to a 7x7 grid, then two
# Linear -> ReLU(inplace) -> Dropout stages of width 4096, then a final
# Linear producing +num_classes+ outputs.
#
# @param features [Torch::NN::Module] convolutional trunk (see make_layers).
#   NOTE(review): presumably it must emit 512 channels, since the first
#   Linear expects 512 * 7 * 7 inputs after pooling — confirm.
# @param num_classes [Integer] output width of the last Linear layer
# @param init_weights [Boolean] run _initialize_weights after construction
def initialize(features, num_classes: 1000, init_weights: true)
  super()
  # Keep the ivar assignment order (features, avgpool, classifier) and the layer
  # construction order unchanged. NOTE(review): submodule enumeration, and hence
  # init order, likely follows them.
  @features = features
  @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])

  hidden = 4096
  head = [[512 * 7 * 7, hidden], [hidden, hidden]].flat_map do |fan_in, fan_out|
    [
      Torch::NN::Linear.new(fan_in, fan_out),
      Torch::NN::ReLU.new(inplace: true),
      Torch::NN::Dropout.new
    ]
  end
  head << Torch::NN::Linear.new(hidden, num_classes)
  @classifier = Torch::NN::Sequential.new(*head)

  _initialize_weights if init_weights
end

Public Instance Methods

_initialize_weights()
# File lib/torchvision/models/vgg.rb, line 39
# Initializes, in place, the parameters of every module returned by +modules+:
# - Conv2d: Kaiming-normal weights (mode "fan_out", nonlinearity "relu"), bias 0
# - BatchNorm2d: weight 1, bias 0
# - Linear: normal weights with mean 0 and std 0.01, bias 0
# Other module types are left untouched.
#
# Any parameter that is nil (e.g. a layer built without a bias) is skipped
# instead of being handed to the initializer. The Conv2d branch already did
# this; BatchNorm2d and Linear now do the same. +features+ is caller-supplied,
# so those layers may lack parameters.
def _initialize_weights
  modules.each do |m|
    case m
    when Torch::NN::Conv2d
      Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
      Torch::NN::Init.constant!(m.bias, 0) if m.bias
    when Torch::NN::BatchNorm2d
      Torch::NN::Init.constant!(m.weight, 1) if m.weight
      Torch::NN::Init.constant!(m.bias, 0) if m.bias
    when Torch::NN::Linear
      Torch::NN::Init.normal!(m.weight, mean: 0, std: 0.01)
      Torch::NN::Init.constant!(m.bias, 0) if m.bias
    end
  end
end
forward(x)
# File lib/torchvision/models/vgg.rb, line 31
# Runs the network on a batch of inputs.
#
# Applies the feature extractor, then the adaptive average pool. It then
# flattens from dimension 1 onward (dimension 0 is kept) and returns the
# classifier's output.
#
# @param x [Torch::Tensor] input batch — presumably (N, 3, H, W); TODO confirm
# @return [Torch::Tensor] output of the classifier head
def forward(x)
  pooled = @avgpool.call(@features.call(x))
  flat = Torch.flatten(pooled, 1)
  @classifier.call(flat)
end