class Informers::FeatureExtraction
Public Class Methods
new(model_path)
click to toggle source
# File lib/informers/feature_extraction.rb, line 18
# Builds a feature extractor from the tokenizer file shipped under vendor/
# and the ONNX model stored at +model_path+.
#
# @param model_path path handed unchanged to OnnxRuntime::Model.new
def initialize(model_path)
  # The tokenizer file is resolved relative to this source file's directory,
  # not the process working directory.
  @tokenizer = BlingFire.load_model(
    File.expand_path("../../vendor/bert_base_cased_tok.bin", __dir__)
  )
  @model = OnnxRuntime::Model.new(model_path)
end
Public Instance Methods
predict(texts)
click to toggle source
# File lib/informers/feature_extraction.rb, line 24
# Runs the model on one text or on a batch of texts and returns the model's
# "output_0" value.
#
# Every text is tokenized, wrapped in the CLS (101) and SEP (102) token ids,
# and right-padded with zeros up to the longest sequence of the batch. The
# attention mask holds 1 for real tokens and 0 for padding.
#
# @param texts a single text, or an Array of texts
# @return the first entry of "output_0" when a single (non-Array) text was
#   given, otherwise the whole "output_0" value; [] for an empty Array
def predict(texts)
  singular = !texts.is_a?(Array)
  texts = [texts] if singular

  # An empty batch has no longest sequence to pad to (max would be nil),
  # so there is nothing meaningful to send to the model.
  return [] if texts.empty?

  # tokenize: 100 is passed as the unknown-token id; a fresh array is built
  # so the tokenizer's return value is not mutated.
  input_ids = texts.map do |text|
    [101, *@tokenizer.text_to_ids(text, nil, 100), 102]
  end

  # pad every sequence to the batch maximum and build the matching mask
  max_tokens = input_ids.map(&:size).max
  attention_mask = input_ids.map do |ids|
    padding = [0] * (max_tokens - ids.size)
    mask = ([1] * ids.size) + padding
    ids.concat(padding)
    mask
  end

  # infer
  output = @model.predict(input_ids: input_ids, attention_mask: attention_mask)
  scores = output["output_0"]
  singular ? scores.first : scores
end