class TorchText::NN::MultiheadAttentionContainer

Public Class Methods

new(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
Calls superclass method
# File lib/torchtext/nn/multihead_attention_container.rb, line 4
def initialize(nhead, in_proj_container, attention_layer, out_proj, batch_first: false)
  super()
  @nhead = nhead
  @in_proj_container = in_proj_container
  @attention_layer = attention_layer
  @out_proj = out_proj
  @batch_first = batch_first
end
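
A minimal construction sketch. It assumes the companion TorchText::NN::InProjContainer and TorchText::NN::ScaledDotProduct classes (mirroring the Python torchtext.nn API) and Torch::NN::Linear are available; embed_dim and nhead below are illustrative values.

# Illustrative sizes: 10-dim embeddings split across 5 heads
embed_dim = 10
nhead = 5

# Separate linear projections for query, key, and value (assumed API)
in_proj = TorchText::NN::InProjContainer.new(
  Torch::NN::Linear.new(embed_dim, embed_dim),
  Torch::NN::Linear.new(embed_dim, embed_dim),
  Torch::NN::Linear.new(embed_dim, embed_dim)
)

mha = TorchText::NN::MultiheadAttentionContainer.new(
  nhead,
  in_proj,
  TorchText::NN::ScaledDotProduct.new,
  Torch::NN::Linear.new(embed_dim, embed_dim)
)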

Public Instance Methods

forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
# File lib/torchtext/nn/multihead_attention_container.rb, line 13
def forward(query, key, value, attn_mask: nil, bias_k: nil, bias_v: nil)
  # convert batch-first (N, L, E) inputs to the internal (L, N, E) layout
  if @batch_first
    query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
  end

  tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
  # apply the query/key/value projections
  q, k, v = @in_proj_container.call(query, key, value)
  unless q.size(-1) % @nhead == 0
    raise "query's embed_dim must be divisible by the number of heads"
  end
  # split the projected query into nhead heads; key and value are split the same way below
  head_dim = q.size(-1).div(@nhead)
  q = q.reshape(tgt_len, bsz * @nhead, head_dim)

  unless k.size(-1) % @nhead == 0
    raise "key's embed_dim must be divisible by the number of heads"
  end
  head_dim = k.size(-1).div(@nhead)
  k = k.reshape(src_len, bsz * @nhead, head_dim)

  unless v.size(-1) % @nhead == 0
    raise "value's embed_dim must be divisible by the number of heads"
  end
  head_dim = v.size(-1).div(@nhead)
  v = v.reshape(src_len, bsz * @nhead, head_dim)

  # run attention over the heads, then merge the heads and apply the output projection
  attn_output, attn_output_weights = @attention_layer.call(q, k, v, attn_mask: attn_mask, bias_k: bias_k, bias_v: bias_v)
  attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
  attn_output = @out_proj.call(attn_output)

  # restore the batch-first layout if requested
  if @batch_first
    attn_output = attn_output.transpose(-3, -2)
  end

  [attn_output, attn_output_weights]
end
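
A usage sketch continuing the construction example above, with the default batch_first: false layout: inputs are shaped (sequence length, batch size, embed_dim), and the method returns the projected attention output together with the attention weights. Tensor sizes are illustrative.

# Illustrative shapes: target length 21, source length 16, batch 64, embed_dim 10
query = Torch.rand(21, 64, 10)
key = Torch.rand(16, 64, 10)
value = Torch.rand(16, 64, 10)

attn_output, attn_weights = mha.call(query, key, value)
# attn_output is shaped [tgt_len, bsz, embed_dim] => [21, 64, 10]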