class TorchVision::Models::ResNet

Constants

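MODEL_URLS — pretrained-weight URLs keyed by architecture name (looked up by make_model when pretrained: true)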

Public Class Methods

make_model(arch, block, layers, pretrained: false, **kwargs)
# File lib/torchvision/models/resnet.rb, line 118
# Builds a ResNet from a residual block class and per-stage block counts,
# optionally loading pretrained weights.
#
# arch        - key into MODEL_URLS; only consulted when pretrained is true.
# block       - residual block class, passed through to ResNet.new.
# layers      - per-stage block counts, passed through to ResNet.new.
# pretrained: - when true, download the state dict for +arch+ and load it
#               into the model.
# **kwargs    - remaining ResNet options. +num_classes+ may be given here as
#               a keyword; it is forwarded positionally because
#               ResNet#initialize declares it as a positional parameter
#               (passing it as a keyword would raise "unknown keyword").
#
# Returns the constructed model.
# Raises ArgumentError when pretrained is true and +arch+ has no entry in
# MODEL_URLS.
def self.make_model(arch, block, layers, pretrained: false, **kwargs)
  # **kwargs is a fresh Hash for every call, so deleting from it is safe.
  positional = [block, layers]
  positional << kwargs.delete(:num_classes) if kwargs.key?(:num_classes)
  model = ResNet.new(*positional, **kwargs)
  if pretrained
    # Fail with a clear message instead of passing nil to the downloader.
    url = MODEL_URLS.fetch(arch) do
      raise ArgumentError, "no pretrained weights available for #{arch.inspect}"
    end
    state_dict = Torch::Hub.load_state_dict_from_url(url)
    model.load_state_dict(state_dict)
  end
  model
end
new(block, layers, num_classes=1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
Calls superclass method
# File lib/torchvision/models/resnet.rb, line 16
# Builds the ResNet network:
# - a 7x7 stride-2 stem convolution with norm, ReLU and 3x3 max pooling;
# - four residual stages built by #_make_layer;
# - adaptive average pooling to 1x1;
# - a final Linear classifier.
#
# block                         - residual block class; #_make_layer calls
#                                 block.expansion and block.new on it.
# layers                        - number of blocks in each stage (indices 0..3 are read).
# num_classes                   - output features of the final Linear layer (default 1000).
# zero_init_residual:           - when true, zero the weight of bn3 (Bottleneck) or
#                                 bn2 (BasicBlock) in every residual block.
# groups:                       - stored as @groups and passed to every block.
# width_per_group:              - stored as @base_width and passed to every block.
# replace_stride_with_dilation: - nil or a 3-element array; element i is passed as
#                                 dilate: to stage i + 2, which then trades its
#                                 stride 2 for dilation.
# norm_layer:                   - normalization class (defaults to Torch::NN::BatchNorm2d).
#
# Raises ArgumentError if replace_stride_with_dilation does not have length 3.
def initialize(block, layers, num_classes=1000, zero_init_residual: false,
  groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)

  # Explicit () so no arguments are forwarded to the superclass initializer.
  super()
  norm_layer ||= Torch::NN::BatchNorm2d
  @norm_layer = norm_layer

  # Running input-channel count and dilation. #_make_layer reads and updates
  # both, so the stages below must be built in order.
  @inplanes = 64
  @dilation = 1
  if replace_stride_with_dilation.nil?
    # each element in the tuple indicates if we should replace
    # the 2x2 stride with a dilated convolution instead
    replace_stride_with_dilation = [false, false, false]
  end
  if replace_stride_with_dilation.length != 3
    raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
  end
  @groups = groups
  @base_width = width_per_group
  @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
  @bn1 = norm_layer.new(@inplanes)
  @relu = Torch::NN::ReLU.new(inplace: true)
  @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)
  @layer1 = _make_layer(block, 64, layers[0])
  @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
  @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
  @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])
  @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
  @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

  # Weight init:
  # - Conv2d weights use Kaiming-normal (fan_out, relu);
  # - BatchNorm2d/GroupNorm layers start with weight 1 and bias 0.
  # Only these classes are matched, so a custom norm_layer of another class
  # keeps whatever init it had.
  modules.each do |m|
    case m
    when Torch::NN::Conv2d
      Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
    when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
      Torch::NN::Init.constant!(m.weight, 1)
      Torch::NN::Init.constant!(m.bias, 0)
    end
  end

  # Zero-initialize the last BN in each residual branch,
  # so that the residual branch starts with zeros, and each residual block behaves like an identity.
  # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
  # Must run after the loop above, which would otherwise reset these weights to 1.
  if zero_init_residual
    modules.each do |m|
      case m
      when Bottleneck
        Torch::NN::Init.constant!(m.bn3.weight, 0)
      when BasicBlock
        Torch::NN::Init.constant!(m.bn2.weight, 0)
      end
    end
  end
end

Public Instance Methods

_forward_impl(x)
# File lib/torchvision/models/resnet.rb, line 96
# Runs the network on +x+ and returns the classifier output. The steps are:
# 1. the stem (conv1 -> bn1 -> relu -> maxpool);
# 2. the four residual stages;
# 3. adaptive average pooling;
# 4. flattening from dimension 1 onward;
# 5. the final Linear layer.
def _forward_impl(x)
  stem = [@conv1, @bn1, @relu, @maxpool]
  stages = [@layer1, @layer2, @layer3, @layer4]
  # Thread the input through each module in order.
  pooled = (stem + stages + [@avgpool]).reduce(x) { |acc, stage| stage.call(acc) }
  @fc.call(Torch.flatten(pooled, 1))
end
_make_layer(block, planes, blocks, stride: 1, dilate: false)
# File lib/torchvision/models/resnet.rb, line 71
# Builds one residual stage as a Sequential of +blocks+ block instances.
# At least one block is always created.
#
# block   - residual block class (uses block.expansion and block.new).
# planes  - base channel width; the stage outputs planes * block.expansion.
# blocks  - number of blocks in the stage.
# stride: - stride of the first block (ignored, i.e. set to 1, when dilate: is true).
# dilate: - when true, multiply @dilation by stride instead of striding.
#
# Side effects: updates @inplanes to the stage's output channels and may
# grow @dilation; later stages depend on both.
def _make_layer(block, planes, blocks, stride: 1, dilate: false)
  out_channels = planes * block.expansion
  # The first block keeps the dilation in effect before this stage.
  first_dilation = @dilation
  if dilate
    @dilation *= stride
    stride = 1
  end

  # 1x1 conv + norm shortcut, passed to the first block, whenever the stride
  # or the channel count changes.
  downsample =
    if stride != 1 || @inplanes != out_channels
      Torch::NN::Sequential.new(
        Torch::NN::Conv2d.new(@inplanes, out_channels, 1, stride: stride, bias: false),
        @norm_layer.new(out_channels)
      )
    end

  head = block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups,
                   base_width: @base_width, dilation: first_dilation, norm_layer: @norm_layer)
  @inplanes = out_channels
  tail = (1...blocks).map do
    block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: @norm_layer)
  end

  Torch::NN::Sequential.new(head, *tail)
end
forward(x)
# File lib/torchvision/models/resnet.rb, line 114
# Forward pass: delegates entirely to #_forward_impl and returns its result.
def forward(input)
  _forward_impl(input)
end