class TorchVision::Models::VGG
Constants
- CFGS
- MODEL_URLS
Public Class Methods
make_layers(cfg, batch_norm)
click to toggle source
# File lib/torchvision/models/vgg.rb, line 73
# Builds the feature-extractor stack described by +cfg+ and wraps it in a
# Torch::NN::Sequential.
#
# Every entry of +cfg+ is either the string "M", which appends a
# MaxPool2d.new(2, stride: 2), or a value used as the output channel count of
# a Conv2d (kernel size 3, padding 1) followed by an in-place ReLU. When
# +batch_norm+ is truthy a BatchNorm2d is inserted between the two.
#
# @param cfg [#each] layer specification, iterated in order
# @param batch_norm [Boolean] whether to add BatchNorm2d after each convolution
# @return [Torch::NN::Sequential] the assembled layers
def self.make_layers(cfg, batch_norm)
  stack = []
  channels = 3 # input channel count of the first convolution

  cfg.each do |entry|
    if entry == "M"
      stack << Torch::NN::MaxPool2d.new(2, stride: 2)
      next
    end

    # Layers are created in conv -> batch norm -> ReLU order.
    stack << Torch::NN::Conv2d.new(channels, entry, 3, padding: 1)
    stack << Torch::NN::BatchNorm2d.new(entry) if batch_norm
    stack << Torch::NN::ReLU.new(inplace: true)

    # The next convolution consumes what this one produced.
    channels = entry
  end

  Torch::NN::Sequential.new(*stack)
end
make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
click to toggle source
# File lib/torchvision/models/vgg.rb, line 62
# Builds a VGG model from a named layer configuration and, when requested,
# loads pretrained weights into it.
#
# Both lookups (CFGS[cfg] and, for pretrained models, MODEL_URLS[arch]) are
# validated before the network is constructed, so a bad argument fails fast
# with a descriptive error instead of surfacing later as a nil-related failure.
#
# @param arch [Object] key into MODEL_URLS; only consulted when +pretrained+ is truthy
# @param cfg [Object] key into CFGS selecting the layer configuration
# @param batch_norm [Boolean] forwarded to make_layers
# @param pretrained [Boolean] when truthy, fetches a state dict via
#   Torch::Hub.load_state_dict_from_url and loads it into the model
# @param kwargs [Hash] extra keyword arguments forwarded to VGG.new
# @return [VGG] the constructed model
# @raise [ArgumentError] if +cfg+ has no entry in CFGS, or if +pretrained+ is
#   truthy and +arch+ has no entry in MODEL_URLS
def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
  layer_cfg = CFGS[cfg]
  raise ArgumentError, "Unknown VGG configuration: #{cfg.inspect}" unless layer_cfg

  url = nil
  if pretrained
    url = MODEL_URLS[arch]
    raise ArgumentError, "No pretrained weights available for #{arch.inspect}" unless url

    # The loaded state dict overwrites every weight, so skip random initialization.
    kwargs[:init_weights] = false
  end

  model = VGG.new(make_layers(layer_cfg, batch_norm), **kwargs)
  model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if pretrained
  model
end
new(features, num_classes: 1000, init_weights: true)
click to toggle source
Calls superclass method
# File lib/torchvision/models/vgg.rb, line 15
# Sets up the three stages of the network: the supplied feature extractor, an
# adaptive average pool with output size [7, 7], and a classifier made of
# three Linear layers with in-place ReLU and Dropout between them.
#
# @param features [Object] module stored as @features and invoked first by #forward
# @param num_classes [Integer] output size of the final Linear layer
# @param init_weights [Boolean] when truthy, runs #_initialize_weights after
#   the layers have been created
def initialize(features, num_classes: 1000, init_weights: true)
  super()

  @features = features
  @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])

  hidden_units = 4096
  classifier_layers = [
    # 512 * 7 * 7 is the classifier's input width; the 7x7 matches the pool size above.
    Torch::NN::Linear.new(512 * 7 * 7, hidden_units),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::Dropout.new,
    Torch::NN::Linear.new(hidden_units, hidden_units),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::Dropout.new,
    Torch::NN::Linear.new(hidden_units, num_classes)
  ]
  @classifier = Torch::NN::Sequential.new(*classifier_layers)

  return unless init_weights

  _initialize_weights
end
Public Instance Methods
_initialize_weights()
click to toggle source
# File lib/torchvision/models/vgg.rb, line 39
# Re-initializes the parameters of every module yielded by +modules+,
# choosing the scheme by module class:
#
# * Conv2d      - Kaiming-normal weights (mode "fan_out", nonlinearity "relu");
#                 bias set to 0 when a bias is present
# * BatchNorm2d - weight set to 1, bias set to 0
# * Linear      - normal weights (mean 0, std 0.01), bias set to 0
#
# Modules of any other class are left untouched.
#
# @return [Object] whatever +modules.each+ returns
def _initialize_weights
  init = Torch::NN::Init

  modules.each do |layer|
    case layer
    when Torch::NN::Conv2d
      init.kaiming_normal!(layer.weight, mode: "fan_out", nonlinearity: "relu")
      # Only touch the bias when the convolution actually has one.
      if layer.bias
        init.constant!(layer.bias, 0)
      end
    when Torch::NN::BatchNorm2d
      init.constant!(layer.weight, 1)
      init.constant!(layer.bias, 0)
    when Torch::NN::Linear
      init.normal!(layer.weight, mean: 0, std: 0.01)
      init.constant!(layer.bias, 0)
    end
  end
end
forward(x)
click to toggle source
# File lib/torchvision/models/vgg.rb, line 31
# Runs the forward pass: feature extractor, then average pool, then
# Torch.flatten(_, 1), then the classifier.
#
# @param x [Object] input passed to @features
# @return [Object] the classifier's output
def forward(x)
  pooled = @avgpool.call(@features.call(x))
  flattened = Torch.flatten(pooled, 1)
  @classifier.call(flattened)
end