module TorchText::Data::Metrics
Public Class Methods
bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
  # File lib/torchtext/data/metrics.rb, line 5
  def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4)
    unless max_n == weights.length
      raise "Length of the \"weights\" list has to be equal to max_n"
    end
    unless candidate_corpus.length == references_corpus.length
      raise "The length of candidate and reference corpus should be the same"
    end

    clipped_counts = Torch.zeros(max_n)
    total_counts = Torch.zeros(max_n)
    weights = Torch.tensor(weights)

    candidate_len = 0.0
    refs_len = 0.0

    candidate_corpus.zip(references_corpus) do |candidate, refs|
      candidate_len += candidate.length

      # Get the length of the reference that's closest in length to the candidate
      refs_len_list = refs.map { |ref| ref.length.to_f }
      refs_len += refs_len_list.min_by { |x| (candidate.length - x).abs }

      # For each n-gram, keep the maximum count seen across all references
      reference_counters = compute_ngram_counter(refs[0], max_n)
      refs[1..-1].each do |ref|
        reference_counters = reference_counters.merge(compute_ngram_counter(ref, max_n)) { |_, v1, v2| v1 > v2 ? v1 : v2 }
      end

      candidate_counter = compute_ngram_counter(candidate, max_n)

      # Clip each candidate n-gram count to its maximum reference count
      shared_keys = candidate_counter.keys & reference_counters.keys
      clipped_counter = candidate_counter.slice(*shared_keys).merge(reference_counters.slice(*shared_keys)) { |_, v1, v2| v1 < v2 ? v1 : v2 }

      clipped_counter.each_key do |ngram|
        clipped_counts[ngram.length - 1] += clipped_counter[ngram]
      end

      candidate_counter.each_key do |ngram|
        total_counts[ngram.length - 1] += candidate_counter[ngram]
      end
    end

    if clipped_counts.to_a.min == 0
      0.0
    else
      # Weighted geometric mean of the modified n-gram precisions
      pn = clipped_counts / total_counts
      log_pn = weights * Torch.log(pn)
      score = Torch.exp(log_pn.sum)

      # Brevity penalty for candidates shorter than their references
      bp = Math.exp([1 - refs_len / candidate_len, 0].min)

      bp * score.item
    end
  end
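The result is the brevity penalty times the weighted geometric mean of the clipped n-gram precisions, bp * exp(sum of weights[n] * log(pn[n])). A minimal usage sketch follows; the corpora are illustrative (borrowed from the analogous torchtext example) and assume tokens are already split:

  candidate_corpus = [["My", "full", "pytorch", "test"], ["Another", "Sentence"]]
  # Each candidate pairs with an array of references, each itself a token array
  references_corpus = [
    [["My", "full", "pytorch", "test"], ["Completely", "Different"]],
    [["No", "Match"]]
  ]

  TorchText::Data::Metrics.bleu_score(candidate_corpus, references_corpus)
  # => 0.8408964276313782

Here the candidate and closest-reference lengths both total 6, so the brevity penalty is 1 and the score reduces to exp(0.25 * log(4/6 * 3/4 * 2/2 * 1/1)), about 0.8409.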
Private Class Methods
compute_ngram_counter(tokens, max_n)
  # File lib/torchtext/data/metrics.rb, line 61
  def compute_ngram_counter(tokens, max_n)
    raise "Failed assert" unless max_n > 0
    # Count every n-gram of order 1..max_n; keys are token arrays
    Hash[
      TorchText::Data::Utils.ngrams_iterator(tokens, max_n)
        .map { |x| x.split(" ") }
        .group_by { |v| v }
        .map { |k, v| [k, v.size] }
    ]
  end
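The returned hash maps each n-gram, represented as a token array, to the number of times it occurs in tokens. An illustrative sketch (the method is private, so this reaches it with send purely for demonstration):

  TorchText::Data::Metrics.send(:compute_ngram_counter, ["me", "me", "you"], 2)
  # => {["me"]=>2, ["you"]=>1, ["me", "me"]=>1, ["me", "you"]=>1}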