module TensorFlow::Utils

Constants

NUMO_TYPE_MAP

Public Class Methods

check_status(status)
# File lib/tensorflow/utils.rb, line 17
# Raises when a TensorFlow status object reports a failure.
#
# @param status [Object] status handle filled in by a prior TF call
#   (presumably a TF_Status* pointer; confirm against the FFI bindings)
# @return [nil] when FFI.TF_GetCode reports 0
# @raise [Error] carrying FFI.TF_Message(status) for any non-zero code
def check_status(status)
  code = FFI.TF_GetCode(status)
  return if code == 0

  raise Error, FFI.TF_Message(status)
end
default_context()
# File lib/tensorflow/utils.rb, line 23
# Returns the shared Context, creating it on first use and memoizing it in
# @default_context so later calls get the same object.
#
# @return [Context]
def default_context
  return @default_context if @default_context

  @default_context = Context.new
end
execute(op_name, inputs = [], **attrs)
# File lib/tensorflow/utils.rb, line 27
# Builds and eagerly executes the TensorFlow op +op_name+ in the default
# context.
#
# @param op_name [String] name of the op to run
# @param inputs [Array] op inputs. Values that respond to +to_ptr+ are passed
#   as-is; anything else goes through TensorFlow.convert_to_tensor. For
#   "TensorSliceDataset" the first input is passed as a list of pointers.
# @param attrs [Hash] op attributes. Nil values are not set on the op; every
#   other value is encoded according to the type TFE_OpGetAttrType reports.
# @return [Tensor, Array<Tensor>, nil] a single Tensor for one output, an
#   Array of Tensors for several, nil when the op reports no outputs
# @raise [Error] when any status-reporting TensorFlow call fails
# @raise [RuntimeError] for a nil input or an unsupported attribute type
def execute(op_name, inputs = [], **attrs)
  context = default_context
  status = FFI.TF_NewStatus # TODO reuse status between ops?
  op = FFI.TFE_NewOp(context, op_name, status)
  check_status status

  attrs.each do |attr_name, attr_value|
    next if attr_value.nil?

    attr_name = attr_name.to_s

    # TFE_OpGetAttrType returns the attribute's type and, via is_list,
    # whether the op expects a list of that type.
    is_list = ::FFI::MemoryPointer.new(:int)
    type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
    check_status status

    if is_list.read_int == 1
      num_values = attr_value.size

      case FFI::AttrType[type]
      when :int
        values = ::FFI::MemoryPointer.new(:int64, num_values)
        values.write_array_of_int64(attr_value)
        FFI.TFE_OpSetAttrIntList(op, attr_name, values, num_values)
      when :float
        values = ::FFI::MemoryPointer.new(:float, num_values)
        values.write_array_of_float(attr_value)
        FFI.TFE_OpSetAttrFloatList(op, attr_name, values, num_values)
      when :shape
        # One int64 buffer per shape. dims_ptrs keeps them referenced while
        # the C call below reads them.
        dims_ptrs =
          attr_value.map do |shape|
            ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
            ptr.write_array_of_int64(shape)
            ptr
          end
        dims = ::FFI::MemoryPointer.new(:pointer, num_values)
        dims.write_array_of_pointer(dims_ptrs)

        num_dims = ::FFI::MemoryPointer.new(:int, num_values)
        num_dims.write_array_of_int(attr_value.map(&:size))

        FFI.TFE_OpSetAttrShapeList(op, attr_name, dims, num_dims, num_values, status)
        # Surface errors here, as the scalar :shape branch does. Otherwise a
        # later call reusing the same status could hide the failure.
        check_status status
      when :type
        values = ::FFI::MemoryPointer.new(:int, num_values)
        types =
          attr_value.map do |v|
            if v.is_a?(Symbol)
              FFI::DataType[v]
            else
              v
            end
          end
        values.write_array_of_int(types)
        FFI.TFE_OpSetAttrTypeList(op, attr_name, values, num_values)
      else
        raise "Unknown list type: #{FFI::AttrType[type]}"
      end
    else
      case FFI::AttrType[type]
      when :string
        FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
      when :int
        FFI.TFE_OpSetAttrInt(op, attr_name, attr_value)
      when :float
        FFI.TFE_OpSetAttrFloat(op, attr_name, attr_value)
      when :bool
        FFI.TFE_OpSetAttrBool(op, attr_name, attr_value ? 1 : 0)
      when :type
        attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol)
        FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
      when :shape
        ptr = ::FFI::MemoryPointer.new(:int64, attr_value.size)
        ptr.write_array_of_int64(attr_value)
        FFI.TFE_OpSetAttrShape(op, attr_name, ptr, attr_value.size, status)
        check_status status
      # when :tensor
      # when :placeholder
      # when :func
      else
        raise "Unknown type: #{FFI::AttrType[type]}"
      end
    end
  end

  inputs.each_with_index do |input, i|
    # TODO handle this better
    if op_name == "TensorSliceDataset" && i == 0
      input_ptr = ::FFI::MemoryPointer.new(:pointer, input.size)
      input_ptr.write_array_of_pointer(input)
      FFI.TFE_OpAddInputList(op, input_ptr, input.size, status)
    else
      raise "Missing argument" if input.nil?

      input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr)
      FFI.TFE_OpAddInput(op, input, status)
    end
    check_status status
  end

  # TODO decide how many retvals to allocate
  # num_retvals must hold a count of pointer slots. MemoryPointer#size is a
  # byte count, so writing it would advertise more slots than the buffer
  # holds and let TFE_Execute write past its end. 16 slots matches the
  # capacity the byte count used to advertise on 64-bit platforms.
  max_retvals = 16
  retvals = ::FFI::MemoryPointer.new(:pointer, max_retvals)
  num_retvals = ::FFI::MemoryPointer.new(:int)
  num_retvals.write_int(max_retvals)
  FFI.TFE_Execute(op, retvals, num_retvals, status)
  check_status status

  # Read back the number of outputs TFE_Execute reported.
  n = num_retvals.read_int
  if n > 0
    tensors =
      retvals.read_array_of_pointer(n).map do |handle|
        Tensor.new(pointer: handle)
      end

    # TODO handle case where n = 1 and still want an array for retvals
    n == 1 ? tensors.first : tensors
  end
ensure
  FFI.TF_DeleteStatus(status) if status
  FFI.TFE_DeleteOp(op) if op
end
infer_type(value)
# File lib/tensorflow/utils.rb, line 146
# Infers the TensorFlow dtype symbol for +value+.
#
# Numo arrays are matched against NUMO_TYPE_MAP by class. Other collections
# are inspected element by element:
# - all strings          -> :string
# - all true/false       -> :bool
# - all integers         -> :int32 if every value fits in signed 32 bits,
#                           otherwise :int64
# - all numeric, with at least one Complex -> :complex128
# - all numeric otherwise                  -> :float
#
# @param value [Numo::NArray, Array] values to inspect (assumed already
#   flattened by the caller; TODO confirm — nested arrays raise here)
# @return [Symbol] the inferred dtype
# @raise [Error] if +value+ is empty, has mixed element kinds, or is a Numo
#   array whose class is not in NUMO_TYPE_MAP
def infer_type(value)
  if value.is_a?(Numo::NArray)
    type = NUMO_TYPE_MAP.find { |_, klass| value.is_a?(klass) }
    raise Error, "Unable to infer data type" unless type

    type.first
  elsif value.empty?
    raise Error, "Unable to infer data type"
  elsif value.all? { |v| v.is_a?(String) }
    :string
  elsif value.all? { |v| v.is_a?(TrueClass) || v.is_a?(FalseClass) }
    :bool
  elsif value.all? { |v| v.is_a?(Integer) }
    value.all? { |v| v.between?(-2_147_483_648, 2_147_483_647) } ? :int32 : :int64
  elsif value.all? { |v| v.is_a?(Numeric) }
    # Complex is itself a Numeric, so a single complex element must promote
    # the whole collection; otherwise its imaginary part would be lost.
    value.any? { |v| v.is_a?(Complex) } ? :complex128 : :float
  else
    raise Error, "Unable to infer data type"
  end
end
to_tensor_array(values)
# File lib/tensorflow/utils.rb, line 175
# Coerces every element of +values+ into a Tensor.
#
# Elements that are already Tensor instances are returned as the very same
# object. Everything else is passed through TensorFlow.convert_to_tensor.
#
# @param values [Enumerable] values to coerce
# @return [Array] the coerced elements, in their original order
def to_tensor_array(values)
  values.map { |item| item.is_a?(Tensor) ? item : TensorFlow.convert_to_tensor(item) }
end