class TorchVision::Models::ResNet
Constants
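- MODEL_URLS — maps architecture names to the URLs of their pretrained weights (looked up by make_model when pretrained: true)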
Public Class Methods
make_model(arch, block, layers, pretrained: false, **kwargs)
click to toggle source
# File lib/torchvision/models/resnet.rb, line 118
# Builds a ResNet and, optionally, loads the pretrained weights registered for +arch+.
#
# @param arch [String] key into MODEL_URLS; only consulted when +pretrained+ is true
# @param block [Class] residual block class forwarded to ResNet.new
# @param layers [Array<Integer>] number of blocks per stage, forwarded to ResNet.new
# @param pretrained [Boolean] when true, download the state dict at MODEL_URLS[arch] and load it
# @param kwargs [Hash] remaining ResNet.new options; +num_classes+ is forwarded positionally
#   because ResNet#initialize declares it as an optional positional parameter, not a keyword
# @return [ResNet]
# @raise [ArgumentError] if +pretrained+ is true and MODEL_URLS has no entry for +arch+
def self.make_model(arch, block, layers, pretrained: false, **kwargs)
  # ResNet#initialize takes num_classes positionally (num_classes=1000); forwarding it
  # inside **kwargs would raise "unknown keyword: :num_classes". kwargs is a fresh Hash
  # built by the double splat, so deleting from it does not touch the caller's data.
  positional = kwargs.key?(:num_classes) ? [kwargs.delete(:num_classes)] : []
  model = ResNet.new(block, layers, *positional, **kwargs)
  return model unless pretrained

  # Fail loudly instead of handing a nil URL to the downloader.
  url = MODEL_URLS.fetch(arch) do
    raise ArgumentError, "No pretrained weights available for #{arch.inspect}"
  end
  model.load_state_dict(Torch::Hub.load_state_dict_from_url(url))
  model
end
new(block, layers, num_classes=1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
click to toggle source
Calls superclass method
# File lib/torchvision/models/resnet.rb, line 16
# Builds the network: a 7x7 strided convolution stem with max pooling, four residual
# stages produced by _make_layer, adaptive average pooling and a Linear classifier.
# It then initializes Conv2d / BatchNorm2d / GroupNorm parameters.
#
# @param block [Class] residual block class; must respond to .expansion and to the
#   .new signature used by _make_layer
# @param layers [Array<Integer>] block counts for the four stages (layers[0]..layers[3])
# @param num_classes [Integer] output size of the final Linear layer
# @param zero_init_residual [Boolean] when true, zero bn3.weight of every Bottleneck and
#   bn2.weight of every BasicBlock
# @param groups [Integer] stored as @groups and forwarded to each block
# @param width_per_group [Integer] stored as @base_width and forwarded to each block
# @param replace_stride_with_dilation [Array<Boolean>, nil] one flag each for stages 2-4;
#   nil means [false, false, false]
# @param norm_layer [Class, nil] normalization layer class; Torch::NN::BatchNorm2d when nil
# @raise [ArgumentError] if replace_stride_with_dilation does not have exactly 3 elements
def initialize(block, layers, num_classes=1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
  super()

  norm_layer ||= Torch::NN::BatchNorm2d
  @norm_layer = norm_layer
  @inplanes = 64
  @dilation = 1

  # Each flag says whether that stage swaps its 2x2 stride for a dilated convolution.
  replace_stride_with_dilation = [false] * 3 if replace_stride_with_dilation.nil?
  unless replace_stride_with_dilation.length == 3
    raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
  end

  @groups = groups
  @base_width = width_per_group

  # Stem.
  @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
  @bn1 = norm_layer.new(@inplanes)
  @relu = Torch::NN::ReLU.new(inplace: true)
  @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)

  # Residual stages. _make_layer advances @inplanes and @dilation, so these calls
  # must run in exactly this order.
  @layer1 = _make_layer(block, 64, layers[0])
  @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
  @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
  @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])

  # Head.
  @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
  @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

  modules.each do |mod|
    case mod
    when Torch::NN::Conv2d
      Torch::NN::Init.kaiming_normal!(mod.weight, mode: "fan_out", nonlinearity: "relu")
    when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
      Torch::NN::Init.constant!(mod.weight, 1)
      Torch::NN::Init.constant!(mod.bias, 0)
    end
  end

  return unless zero_init_residual

  # Zero-initialize the last BN in each residual branch, so that the residual branch
  # starts with zeros and each residual block behaves like an identity.
  # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
  modules.each do |mod|
    case mod
    when Bottleneck then Torch::NN::Init.constant!(mod.bn3.weight, 0)
    when BasicBlock then Torch::NN::Init.constant!(mod.bn2.weight, 0)
    end
  end
end
Public Instance Methods
_forward_impl(x)
click to toggle source
# File lib/torchvision/models/resnet.rb, line 96
# Forward pass. Feeds +x+ through the stem, the four residual stages and average
# pooling in order, flattens the result from dimension 1 onward and applies @fc.
#
# @param x input tensor — presumably (N, 3, H, W) given the 3-channel stem conv; TODO confirm
# @return output of the final Linear layer
def _forward_impl(x)
  trunk = [@conv1, @bn1, @relu, @maxpool, @layer1, @layer2, @layer3, @layer4, @avgpool]
  features = trunk.reduce(x) { |activation, stage| stage.call(activation) }
  @fc.call(Torch.flatten(features, 1))
end
_make_layer(block, planes, blocks, stride: 1, dilate: false)
click to toggle source
# File lib/torchvision/models/resnet.rb, line 71
# Builds one residual stage as a Sequential of +blocks+ instances of +block+.
#
# Only the first block receives +stride+, the optional downsample projection and the
# dilation in effect before this call. The remaining blocks use the default stride and
# the (possibly updated) @dilation.
# Side effects: sets @inplanes to planes * block.expansion; when +dilate+ is true,
# multiplies @dilation by +stride+ and uses stride 1 instead.
#
# @param block [Class] residual block class (responds to .expansion and .new)
# @param planes [Integer] base channel count for the stage
# @param blocks [Integer] number of blocks in the stage (the first block is always built)
# @param stride [Integer] stride requested for the first block
# @param dilate [Boolean] replace the stride with dilation
# @return [Torch::NN::Sequential]
def _make_layer(block, planes, blocks, stride: 1, dilate: false)
  out_channels = planes * block.expansion
  entry_dilation = @dilation

  if dilate
    @dilation *= stride
    stride = 1
  end

  # Project the shortcut with a strided 1x1 conv + norm when stride or width changes.
  downsample =
    if stride != 1 || @inplanes != out_channels
      Torch::NN::Sequential.new(
        Torch::NN::Conv2d.new(@inplanes, out_channels, 1, stride: stride, bias: false),
        @norm_layer.new(out_channels)
      )
    end

  head = block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups,
                   base_width: @base_width, dilation: entry_dilation, norm_layer: @norm_layer)
  @inplanes = out_channels

  tail = (blocks - 1).times.map do
    block.new(@inplanes, planes, groups: @groups, base_width: @base_width,
              dilation: @dilation, norm_layer: @norm_layer)
  end

  Torch::NN::Sequential.new(head, *tail)
end
forward(x)
click to toggle source
# File lib/torchvision/models/resnet.rb, line 114
# Module entry point; all of the work happens in _forward_impl.
#
# @param x input tensor — presumably (N, 3, H, W); TODO confirm
# @return whatever _forward_impl returns (the classifier output)
def forward(x)
  _forward_impl(x)
end