class Informers::FeatureExtraction

Public Class Methods

new(model_path)
# File lib/informers/feature_extraction.rb, line 18
# Builds a feature extractor from an ONNX model file.
#
# The tokenizer is always the bundled BlingFire file
# "bert_base_cased_tok.bin", resolved relative to this source file's directory
# (two levels up, in vendor/); only the model location is configurable.
#
# @param model_path [String] path to the ONNX model passed to OnnxRuntime
def initialize(model_path)
  vocab_file = File.expand_path("../../vendor/bert_base_cased_tok.bin", __dir__)
  @tokenizer = BlingFire.load_model(vocab_file)
  @model = OnnxRuntime::Model.new(model_path)
end

Public Instance Methods

predict(texts)
# File lib/informers/feature_extraction.rb, line 24
# Runs the model on one text or a batch of texts and returns its "output_0".
#
# Each text is tokenized with the BlingFire tokenizer (100 is used as the id
# for unknown tokens) and wrapped in the special ids 101 (CLS) and 102 (SEP).
# All sequences are right-padded with 0 to the longest one in the batch.
# They are sent to the model together with an attention mask that is
# 1 for real tokens and 0 for padding.
#
# @param texts [String, Array<String>] a single text, or an Array of texts
# @return the "output_0" entry for the single text when +texts+ is not an
#   Array; otherwise the full "output_0" batch ([] for an empty Array, in
#   which case the model is not called)
def predict(texts)
  singular = !texts.is_a?(Array)
  texts = [texts] if singular

  # An empty batch has nothing to pad (max length would be nil) and nothing
  # to infer; don't hand empty tensors to the model.
  return [] if texts.empty?

  # tokenize: CLS + token ids (unknown => 100) + SEP
  sequences = texts.map do |text|
    [101, *@tokenizer.text_to_ids(text, nil, 100), 102]
  end

  # pad to the longest sequence and mask out the padding
  max_tokens = sequences.map(&:size).max
  input_ids = sequences.map { |ids| ids + [0] * (max_tokens - ids.size) }
  attention_mask = sequences.map { |ids| [1] * ids.size + [0] * (max_tokens - ids.size) }

  # infer; pass an explicit Hash so it is never mistaken for keyword options
  input = {
    input_ids: input_ids,
    attention_mask: attention_mask
  }
  scores = @model.predict(input)["output_0"]

  singular ? scores.first : scores
end