module TorchText::Datasets::TextClassification

Constants

DATASETS
FILENAMES
LABELS
PATHS
URLS
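
These constants are per-dataset lookup tables keyed by dataset name. A hypothetical sketch of their shape; the values below are placeholders, not the gem's real entries (DATASETS and LABELS follow the same keyed-by-name pattern):

URLS      = { "AG_NEWS" => "https://example.com/ag_news_csv.tar.gz" } # placeholder URL
FILENAMES = { "AG_NEWS" => "ag_news_csv.tar.gz" }                     # archive name on disk
PATHS     = { "AG_NEWS" => "ag_news_csv" }                            # directory inside the archive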

Public Class Methods

ag_news(*args, **kwargs)
# File lib/torchtext/datasets/text_classification.rb, line 15
def ag_news(*args, **kwargs)
  setup_datasets("AG_NEWS", *args, **kwargs)
end
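
A minimal usage sketch, assuming the gem is loaded; the root and ngrams values here are illustrative:

require "torchtext"

train_dataset, test_dataset =
  TorchText::Datasets::TextClassification.ag_news(root: ".data", ngrams: 2)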

Private Class Methods

_create_data_from_iterator(vocab, iterator, include_unk)
# File lib/torchtext/datasets/text_classification.rb, line 70
def _create_data_from_iterator(vocab, iterator, include_unk)
  data = []
  labels = []
  iterator.each do |cls, tokens|
    if include_unk
      # keep every token, including out-of-vocabulary ones
      tokens = Torch.tensor(tokens.map { |token| vocab[token] })
    else
      # drop out-of-vocabulary tokens before building the tensor
      token_ids = tokens.map { |token| vocab[token] }.select { |x| x != Vocab::UNK }
      tokens = Torch.tensor(token_ids)
    end
    data << [cls, tokens]
    labels << cls
  end
  [data, Set.new(labels)]
end
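
This is a private helper, but its return shape is worth seeing. An illustrative sketch (labels and ids are made up):

data, labels = _create_data_from_iterator(vocab, iterator, false)
data.first # => [2, <Torch::Tensor of token ids>] -- a [label, tensor] pair
labels     # => Set of every label seen, e.g. #<Set: {0, 1, 2, 3}>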
_csv_iterator(data_path, ngrams, yield_cls: false)
# File lib/torchtext/datasets/text_classification.rb, line 55
def _csv_iterator(data_path, ngrams, yield_cls: false)
  return enum_for(:_csv_iterator, data_path, ngrams, yield_cls: yield_cls) unless block_given?

  tokenizer = Data.tokenizer("basic_english")
  CSV.foreach(data_path) do |row|
    # first column is the label; the remaining columns are the text
    tokens = row[1..-1].join(" ")
    tokens = tokenizer.call(tokens)
    if yield_cls
      # labels in the CSV are 1-based; shift them to 0-based
      yield row[0].to_i - 1, Data::Utils.ngrams_iterator(tokens, ngrams)
    else
      yield Data::Utils.ngrams_iterator(tokens, ngrams)
    end
  end
end
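
Without a block this returns an Enumerator (via enum_for), which is how setup_datasets feeds it to Vocab.build_vocab_from_iterator. A sketch of both modes; the path is hypothetical and the calls are shown as if made from inside the module, since the method is private:

# enumerates one ngrams iterator per CSV row
vocab_iter = _csv_iterator(".data/ag_news_csv/train.csv", 2)

# with yield_cls: true, each row yields a zero-based label plus its ngrams
_csv_iterator(".data/ag_news_csv/train.csv", 2, yield_cls: true) do |label, ngrams|
  # label comes from row[0].to_i - 1; ngrams covers unigrams and bigrams here
end
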
download_from_url(url, root:, filename:)

Takes an extra filename parameter.

# File lib/torchtext/datasets/text_classification.rb, line 87
def download_from_url(url, root:, filename:)
  path = File.join(root, filename)
  # reuse a previously downloaded file
  return path if File.exist?(path)

  FileUtils.mkdir_p(root)

  puts "Downloading #{url}..."
  download_url_to_file(url, path)
end
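
A sketch (URL and filename are placeholders). The early return makes the method idempotent: a cached file is returned without any network request.

path = download_from_url("https://example.com/ag_news_csv.tar.gz",
  root: ".data", filename: "ag_news_csv.tar.gz")
# => ".data/ag_news_csv.tar.gz"
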
download_url_to_file(url, dst)

Follows HTTP redirects.

# File lib/torchtext/datasets/text_classification.rb, line 98
def download_url_to_file(url, dst)
  uri = URI(url)
  tmp = nil
  location = nil

  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
    request = Net::HTTP::Get.new(uri)

    http.request(request) do |response|
      case response
      when Net::HTTPRedirection
        # don't download the body; remember the target and retry below
        location = response["location"]
      when Net::HTTPSuccess
        tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
        File.open(tmp, "wb") do |f|
          # stream the body to disk instead of buffering it in memory
          response.read_body do |chunk|
            f.write(chunk)
          end
        end
      else
        raise Error, "Bad response"
      end
    end
  end

  if location
    # follow the redirect with a fresh request
    download_url_to_file(location, dst)
  else
    # move the finished download into place and return the destination
    FileUtils.mv(tmp, dst)
    dst
  end
end
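
A hedged sketch of the redirect behavior (URLs are placeholders). The body is streamed to a temp file and only moved to dst after the download completes, so an interrupted transfer leaves no partial file at the destination:

# hypothetical: the first URL answers 302 with a CDN location, which is retried
download_url_to_file("https://example.com/data.tar.gz", ".data/data.tar.gz")
# => ".data/data.tar.gz"
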
extract_archive(from_path, to_path: nil, overwrite: nil)

extract_tar_gz doesn't return a list of extracted files, so just return to_path

# File lib/torchtext/datasets/text_classification.rb, line 132
def extract_archive(from_path, to_path: nil, overwrite: nil)
  # default to extracting next to the archive
  to_path ||= File.dirname(from_path)

  if from_path.end_with?(".tar.gz") || from_path.end_with?(".tgz")
    File.open(from_path, "rb") do |io|
      # RubyGems ships a tar.gz extractor, so no extra dependency is needed
      Gem::Package.new("").extract_tar_gz(io, to_path)
    end
    return to_path
  end

  raise "We currently only support tar.gz and tgz archives"
end
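
A sketch with a hypothetical path; the return value is the directory extracted into, which defaults to the archive's own directory:

extract_archive(".data/ag_news_csv.tar.gz")
# => ".data" (the tarball's contents land under .data/, e.g. .data/ag_news_csv/)
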
setup_datasets(dataset_name, root: ".data", ngrams: 1, vocab: nil, include_unk: false)
# File lib/torchtext/datasets/text_classification.rb, line 21
def setup_datasets(dataset_name, root: ".data", ngrams: 1, vocab: nil, include_unk: false)
  # download (or reuse) the archive, then extract it next to itself
  dataset_tar = download_from_url(URLS[dataset_name], root: root, filename: FILENAMES[dataset_name])
  to_path = extract_archive(dataset_tar)
  extracted_files = Dir["#{to_path}/#{PATHS[dataset_name]}/*"]

  # locate the train and test CSVs among the extracted files
  train_csv_path = nil
  test_csv_path = nil
  extracted_files.each do |fname|
    if fname.end_with?("train.csv")
      train_csv_path = fname
    elsif fname.end_with?("test.csv")
      test_csv_path = fname
    end
  end

  # build the vocabulary from the training data unless one was supplied
  if vocab.nil?
    vocab = Vocab.build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
  else
    unless vocab.is_a?(Vocab)
      raise ArgumentError, "Passed vocabulary is not of type Vocab"
    end
  end
  # materialize both splits with the same vocabulary
  train_data, train_labels = _create_data_from_iterator(vocab, _csv_iterator(train_csv_path, ngrams, yield_cls: true), include_unk)
  test_data, test_labels = _create_data_from_iterator(vocab, _csv_iterator(test_csv_path, ngrams, yield_cls: true), include_unk)
  # Set#^ is symmetric difference: a label unique to one split is an error
  if (train_labels ^ test_labels).length > 0
    raise ArgumentError, "Training and test labels don't match"
  end

  [
    TextClassificationDataset.new(vocab, train_data, train_labels),
    TextClassificationDataset.new(vocab, test_data, test_labels)
  ]
end
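
Putting the pieces together, a hedged end-to-end sketch (shown as if called from inside the module; the dataset name must be a key of URLS, FILENAMES, and PATHS):

train_dataset, test_dataset = setup_datasets("AG_NEWS", root: ".data", ngrams: 2)

# passing vocab: skips the vocabulary build and reuses the one given
# (some_vocab is a stand-in for an existing Vocab instance);
# include_unk: true keeps out-of-vocabulary tokens in the tensors
setup_datasets("AG_NEWS", ngrams: 2, vocab: some_vocab, include_unk: true)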