class TorchVision::Models::AlexNet
Public Class Methods
new(num_classes: 1000)
click to toggle source
Calls superclass method
# File lib/torchvision/models/alexnet.rb, line 4 def initialize(num_classes: 1000) super() @features = Torch::NN::Sequential.new( Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2), Torch::NN::ReLU.new(inplace: true), Torch::NN::MaxPool2d.new(3, stride: 2), Torch::NN::Conv2d.new(64, 192, 5, padding: 2), Torch::NN::ReLU.new(inplace: true), Torch::NN::MaxPool2d.new(3, stride: 2), Torch::NN::Conv2d.new(192, 384, 3, padding: 1), Torch::NN::ReLU.new(inplace: true), Torch::NN::Conv2d.new(384, 256, 3, padding: 1), Torch::NN::ReLU.new(inplace: true), Torch::NN::Conv2d.new(256, 256, 3, padding: 1), Torch::NN::ReLU.new(inplace: true), Torch::NN::MaxPool2d.new(3, stride: 2), ) @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6]) @classifier = Torch::NN::Sequential.new( Torch::NN::Dropout.new, Torch::NN::Linear.new(256 * 6 * 6, 4096), Torch::NN::ReLU.new(inplace: true), Torch::NN::Dropout.new, Torch::NN::Linear.new(4096, 4096), Torch::NN::ReLU.new(inplace: true), Torch::NN::Linear.new(4096, num_classes) ) end
Public Instance Methods
forward(x)
click to toggle source
# File lib/torchvision/models/alexnet.rb, line 33 def forward(x) x = @features.call(x) x = @avgpool.call(x) x = Torch.flatten(x, 1) x = @classifier.call(x) x end