class MemoryViewTestHelper::NDArray

Constants

INTEGER_TYPES
SIZEOF_DTYPE

Public Class Methods

new(p1, p2) click to toggle source
/*
 * NDArray#initialize(shape, dtype)
 *
 * Fills in the ndarray_t wrapped by +obj+: copies +shape_ary+ into a newly
 * allocated shape buffer, computes row-major strides for it, and allocates
 * the data buffer.
 *
 * shape_ary  - Array of Fixnums; every element must be zero or positive.
 * dtype_name - converted by ndarray_obj_to_dtype_t().
 *
 * Raises TypeError when shape_ary is not an Array of Fixnums and
 * ArgumentError when a size is negative.  Returns nil.
 */
static VALUE
ndarray_initialize(VALUE obj, VALUE shape_ary, VALUE dtype_name)
{
  ssize_t i;

  Check_Type(shape_ary, T_ARRAY);

  /* Validate every dimension before anything is allocated, so a bad
   * element cannot leave a half-filled buffer behind. */
  const ssize_t ndim = (ssize_t)RARRAY_LEN(shape_ary);
  for (i = 0; i < ndim; ++i) {
    VALUE si = RARRAY_AREF(shape_ary, i);
    Check_Type(si, T_FIXNUM);
    if (NUM2SSIZET(si) < 0) {
      rb_raise(rb_eArgError, "negative size is given in shape");
    }
  }

  /* Both of these may raise; run them while no xmalloc'ed memory is
   * outstanding so that nothing leaks on failure. */
  ndarray_dtype_t dtype = ndarray_obj_to_dtype_t(dtype_name);

  ndarray_t *nar;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);

  ssize_t *shape = ALLOC_N(ssize_t, ndim);
  for (i = 0; i < ndim; ++i) {
    VALUE si = RARRAY_AREF(shape_ary, i);
    shape[i] = NUM2SSIZET(si);
  }

  ssize_t *strides = ALLOC_N(ssize_t, ndim);
  ndarray_init_row_major_strides(dtype, ndim, shape, strides);

  /* With an empty shape there is no strides[0]/shape[0] to read.  Use the
   * size of one item, which is what ndarray_reshape_impl computes for an
   * empty new_shape. */
  ssize_t byte_size = (ndim > 0) ? strides[0] * shape[0]
                                 : (ssize_t)SIZEOF_DTYPE(dtype);
  nar->data = ALLOC_N(uint8_t, byte_size);
  nar->byte_size = byte_size;
  nar->dtype = dtype;
  nar->ndim = ndim;
  nar->shape = shape;
  nar->strides = strides;

  return Qnil;
}
try_convert(obj, dtype: nil, order: :row_major) click to toggle source
# File lib/memory-view-test-helper.rb, line 7
#
# Builds an NDArray from an array-like object.
#
# +obj+ must respond to +to_ary+.  +dtype+ is forwarded to
# detect_dtype_and_shape; +order+ is accepted but not read by this method.
#
# Raises ArgumentError when +obj+ cannot be converted with +to_ary+.
def self.try_convert(obj, dtype: nil, order: :row_major)
  not_convertible = "the argument must be converted to an Array by to_ary (#{obj.class} given)"

  # An object lacking to_ary would raise NoMethodError, which the rescue
  # below does not cover, so check for the method up front.
  raise ArgumentError, not_convertible unless obj.respond_to?(:to_ary)

  begin
    ary = obj.to_ary
  rescue TypeError
    raise ArgumentError, not_convertible
  end

  dtype, shape, cache = detect_dtype_and_shape(ary, dtype)
  nar = new(shape, dtype)
  assign_cache(nar, cache)
  return nar
end

Private Class Methods

assign_cache(nar, cache) click to toggle source
# File lib/memory-view-test-helper.rb, line 20
#
# Writes the values recorded in +cache+ into +nar+.
#
# A one-dimensional +nar+ is filled directly from <tt>cache[0][:ary]</tt>;
# any other rank is delegated to assign_cache_recursive starting at cache
# entry 0 with an empty index prefix.
def self.assign_cache(nar, cache)
  return assign_cache_recursive(nar, [], cache, 0) unless nar.ndim == 1

  values = cache[0][:ary]
  values.each_with_index { |value, pos| nar[pos] = value }
end
assign_cache_recursive(nar, idx, cache, k) click to toggle source
# File lib/memory-view-test-helper.rb, line 31
#
# Walks +cache+ starting at entry +k+ and stores its values into +nar+.
#
# +idx+ is the index prefix accumulated so far.  An entry whose
# <tt>:dim + 1</tt> equals <tt>nar.ndim</tt> holds values and is written
# element by element; any other entry recurses once per element, each call
# starting at the cache entry after the last one consumed.
#
# Returns the index of the last cache entry that was consumed.
def self.assign_cache_recursive(nar, idx, cache, k)
  entry = cache[k]
  innermost = entry[:dim] + 1 == nar.ndim
  elements = entry[:ary]

  if innermost
    elements.each_with_index { |value, i| nar[*idx, i] = value }
    return k
  end

  elements.each_index.reduce(k) do |last_used, i|
    assign_cache_recursive(nar, [*idx, i], cache, last_used + 1)
  end
end
detect_dtype(obj) click to toggle source
# File lib/memory-view-test-helper.rb, line 87
#
# Maps a Ruby object to the dtype symbol used to store it.
#
# Returns +:int64+ for Integer, +:float64+ for Float and Rational, the
# dtype of the real part for a Complex whose imaginary part equals zero,
# and +nil+ for array-like objects (Enumerable or responding to +to_ary+).
#
# Raises TypeError for anything else.
def self.detect_dtype(obj)
  case obj
  when Integer
    :int64
  when Float, Rational
    :float64
  when ->(value) { value.is_a?(Complex) && value.imag == 0 }
    # The lambda parameter is not visible here; use obj itself.
    detect_dtype(obj.real)
  when Enumerable, ->(value) { value.respond_to?(:to_ary) }
    nil
  else
    raise TypeError, "#{obj.class} is unsupported"
  end
end
detect_dtype_and_shape(ary, dtype) click to toggle source
# File lib/memory-view-test-helper.rb, line 44
#
# Entry point for shape and dtype detection.
#
# Starts detect_dtype_and_shape_recursive at dimension 0 with an empty
# shape and an empty conversion cache, passing +dtype+ through as the fixed
# dtype.
#
# Returns <tt>[dtype, shape, cache]</tt>.
def self.detect_dtype_and_shape(ary, dtype)
  _, dtype, shape, cache = detect_dtype_and_shape_recursive(ary, 0, nil, dtype, [], [])
  return dtype, shape, cache
end
detect_dtype_and_shape_recursive(obj, dim, max_dim, fixed_dtype, out_shape, conversion_cache) click to toggle source
# File lib/memory-view-test-helper.rb, line 52
#
# Recursively inspects +obj+ at nesting depth +dim+.
#
# +max_dim+ is the depth at which scalars were first seen (+nil+ until
# then).  +fixed_dtype+, when given, disables dtype promotion and is what
# the array branch returns.  +out_shape+ and +conversion_cache+ are
# mutated in place: the former records the size seen at each depth, the
# latter gets one <tt>{obj:, ary:, dim:}</tt> entry per array-like object
# in visiting order.
#
# Returns <tt>[max_dim, dtype, out_shape, conversion_cache]</tt>.
#
# Raises ArgumentError when scalars appear at different depths or when two
# arrays at the same depth differ in length.
def self.detect_dtype_and_shape_recursive(obj, dim, max_dim, fixed_dtype, out_shape, conversion_cache)
  dtype = detect_dtype(obj)
  unless dtype.nil?
    # obj is scalar
    # TODO handle scalar object
    if max_dim.nil?
      max_dim = dim # update max_dim
    elsif dim != max_dim
      dim_failed = [dim, max_dim].min
      raise ArgumentError, "inhomogeneous array detected at the #{dim_failed}#{ordinal(dim_failed)} dimension"
    end
    return max_dim, dtype, out_shape, conversion_cache
  end

  # obj is array-like
  ary = Array(obj)
  conversion_cache << {obj:obj, ary:ary, dim:dim}

  dim_size = ary.length
  if out_shape.length <= dim
    # first array seen at this depth: record its size
    out_shape[dim] = dim_size
  elsif out_shape[dim] != dim_size
    raise ArgumentError, "size mismatch at the #{dim}#{ordinal(dim)} dimension (#{dim_size} for #{out_shape[dim]})"
  end

  # recursive detection
  ary.each do |sub|
    max_dim, dtype_sub, = detect_dtype_and_shape_recursive(sub, dim + 1, max_dim, fixed_dtype, out_shape, conversion_cache)
    dtype = promote_dtype(dtype, dtype_sub) unless fixed_dtype
  end

  return max_dim, (fixed_dtype || dtype), out_shape, conversion_cache
end
ordinal(n) click to toggle source
# File lib/memory-view-test-helper.rb, line 145
#
# Returns the English ordinal suffix ("st", "nd", "rd" or "th") for the
# integer +n+.
#
# Numbers whose last two digits are 11, 12 or 13 take "th" (11th, 112th),
# not only the literal values 11, 12 and 13.
def self.ordinal(n)
  magnitude = n.abs
  return "th" if (11..13).cover?(magnitude % 100)

  case magnitude % 10
  when 1 then "st"
  when 2 then "nd"
  when 3 then "rd"
  else "th"
  end
end
promote_dtype(dtype_a, dtype_b) click to toggle source
# File lib/memory-view-test-helper.rb, line 113
#
# Picks the dtype able to hold values of both +dtype_a+ and +dtype_b+.
#
# * equal dtypes, or one side +nil+: the non-nil dtype (or +nil+)
# * two integer dtypes: the one with the larger SIZEOF_DTYPE entry;
#   equal sizes raise TypeError
# * one integer and one non-integer dtype: the non-integer one
# * two non-integer dtypes: +dtype_a+ if it is strictly larger, else
#   +dtype_b+
def self.promote_dtype(dtype_a, dtype_b)
  return dtype_a if dtype_a == dtype_b
  return dtype_a || dtype_b if dtype_a.nil? || dtype_b.nil?

  size_a = SIZEOF_DTYPE[dtype_a]
  size_b = SIZEOF_DTYPE[dtype_b]
  a_is_int = INTEGER_TYPES.include?(dtype_a)
  b_is_int = INTEGER_TYPES.include?(dtype_b)

  if a_is_int && b_is_int
    return dtype_a if size_a > size_b
    return dtype_b if size_b > size_a

    raise TypeError, "auto promotion between signed and unsigned is not supported"
  end

  # Exactly one integer side: the other (float) side wins.
  return dtype_b if a_is_int
  return dtype_a if b_is_int

  size_a > size_b ? dtype_a : dtype_b
end

Public Instance Methods

==(p1) click to toggle source
/*
 * NDArray#==(other)
 *
 * Returns Qtrue when +other+ is the same object, Qfalse when it is not an
 * NDArray or differs in the number of dimensions.  One-dimensional arrays
 * are compared element by element with rb_equal(); all other ranks are
 * handed to ndarray_md_eq().
 */
static VALUE
ndarray_eq(VALUE obj, VALUE other)
{
  if (obj == other) {
    return Qtrue;
  }
  if (!rb_typeddata_is_kind_of(other, &ndarray_data_type)) {
    return Qfalse;
  }

  ndarray_t *lhs, *rhs;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, lhs);
  TypedData_Get_Struct(other, ndarray_t, &ndarray_data_type, rhs);

  const ssize_t rank = lhs->ndim;
  if (rank != rhs->ndim) {
    return Qfalse;
  }

  /* Anything that is not 1-D goes through the generic comparison. */
  if (rank != 1) {
    return ndarray_md_eq(lhs, rhs);
  }

  const ssize_t len = lhs->shape[0];
  if (len != rhs->shape[0]) {
    return Qfalse;
  }

  ssize_t pos = 0;
  while (pos < len) {
    VALUE left_item = ndarray_1d_aref(lhs, pos);
    VALUE right_item = ndarray_1d_aref(rhs, pos);
    if (!rb_equal(left_item, right_item)) {
      return Qfalse;
    }
    ++pos;
  }

  return Qtrue;
}
[](*args) click to toggle source
/*
 * NDArray#[](*indices)
 *
 * Reads one element.  The number of indices must equal the number of
 * dimensions, otherwise IndexError is raised.  A single index is served by
 * ndarray_1d_aref(); otherwise the indices are converted into a ssize_t
 * buffer and passed to ndarray_md_aref().
 */
static VALUE
ndarray_aref(int argc, VALUE *argv, VALUE obj)
{
  ndarray_t *self;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, self);

  if (self->ndim != argc) {
    rb_raise(rb_eIndexError, "index dimension mismatched (%d for %"PRIdSIZE")", argc, self->ndim);
  }

  const ssize_t rank = self->ndim;
  if (rank == 1) {
    const ssize_t pos = NUM2SSIZET(argv[0]);
    return ndarray_1d_aref(self, pos);
  }

  /* Use the on-stack buffer when it is large enough; otherwise borrow a
   * temporary buffer that is released with RB_ALLOCV_END below. */
  ssize_t stack_indices[MAX_INLINE_DIM] = { 0, };
  VALUE heap_indices = 0;
  ssize_t *index_buf = (rank > MAX_INLINE_DIM)
                         ? RB_ALLOCV_N(ssize_t, heap_indices, rank)
                         : stack_indices;

  ssize_t axis;
  for (axis = 0; axis < rank; ++axis) {
    index_buf[axis] = NUM2SSIZET(argv[axis]);
  }

  VALUE item = ndarray_md_aref(self, index_buf);
  RB_ALLOCV_END(heap_indices);
  return item;
}
[]=(*args) click to toggle source
/*
 * NDArray#[]=(*indices, value)
 *
 * Stores +value+ at the given position.  The last argument is the value;
 * the preceding arguments are the indices and their count must equal the
 * number of dimensions.
 *
 * Raises FrozenError (via rb_check_frozen) on a frozen receiver and
 * IndexError when the index count is wrong or, for a one-dimensional
 * array, when the index is outside 0...shape[0].
 */
static VALUE
ndarray_aset(int argc, VALUE *argv, VALUE obj)
{
  ndarray_t *nar;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);

  rb_check_frozen(obj);

  if (nar->ndim != argc - 1) {
    rb_raise(rb_eIndexError, "index dimension mismatched (%d for %"PRIdSIZE")", argc - 1, nar->ndim);
  }

  const VALUE val = argv[argc-1];
  const int item_size = SIZEOF_DTYPE(nar->dtype);

  const ssize_t ndim = nar->ndim;
  if (ndim == 1) {
    /* special case for 1-D array */
    const ssize_t i = NUM2SSIZET(argv[0]);
    const ssize_t n = nar->shape[0];

    /* The pointer below is computed straight from the index, so reject
     * anything that would land outside the data buffer. */
    if (i < 0 || i >= n) {
      rb_raise(rb_eIndexError, "index %"PRIdSIZE" is out of bounds (size %"PRIdSIZE")", i, n);
    }

    uint8_t *p = ((uint8_t *)nar->data) + i * item_size;
    return ndarray_set_value(p, nar->dtype, val);
  }
  else {
    ssize_t inline_indices_buf[MAX_INLINE_DIM] = { 0, };
    ssize_t *indices = inline_indices_buf;

    /* More dimensions than the inline buffer holds: borrow a temporary
     * buffer, released by RB_ALLOCV_END below. */
    VALUE heap_indices_buf = 0;
    if (ndim > MAX_INLINE_DIM) {
      indices = RB_ALLOCV_N(ssize_t, heap_indices_buf, ndim);
    }

    ssize_t i;
    for (i = 0; i < ndim; ++i) {
      indices[i] = NUM2SSIZET(argv[i]);
    }

    VALUE res = ndarray_md_aset(nar, indices, val);
    RB_ALLOCV_END(heap_indices_buf);
    return res;
  }
}
byte_size() click to toggle source
/*
 * NDArray#byte_size
 *
 * Returns the byte_size field of the wrapped ndarray_t as a Ruby Integer.
 */
static VALUE
ndarray_get_byte_size(VALUE obj)
{
  ndarray_t *self;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, self);

  VALUE result = SSIZET2NUM(self->byte_size);
  return result;
}
dtype() click to toggle source
/*
 * NDArray#dtype
 *
 * Returns the dtype as a Symbol, or nil when the stored dtype value is not
 * strictly between ndarray_dtype_none and NDARRAY_NUM_DTYPES.
 */
static VALUE
ndarray_get_dtype(VALUE obj)
{
  ndarray_t *self;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, self);

  /* Outside the open interval (none, NUM_DTYPES) there is no name. */
  if (self->dtype <= ndarray_dtype_none || self->dtype >= NDARRAY_NUM_DTYPES) {
    return Qnil;
  }

  return ID2SYM(DTYPE_ID(self->dtype));
}
ndim() click to toggle source
/*
 * NDArray#ndim
 *
 * Returns the ndim field of the wrapped ndarray_t as a Ruby Integer.
 */
static VALUE
ndarray_get_ndim(VALUE obj)
{
  ndarray_t *self;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, self);

  VALUE result = SSIZET2NUM(self->ndim);
  return result;
}
reshape(new_shape, order: :row_major) click to toggle source
# File lib/memory-view-test-helper.rb, line 158
#
# Normalises the arguments and forwards them to reshape_impl:
# +new_shape+ is converted with +to_ary+ and +order+ with +to_sym+.
def reshape(new_shape, order: :row_major)
  dims = new_shape.to_ary
  layout = order.to_sym
  reshape_impl(dims, layout)
end
shape() click to toggle source
/*
 * NDArray#shape
 *
 * Returns a new Array holding the size of each dimension, in order.
 */
static VALUE
ndarray_get_shape(VALUE obj)
{
  ndarray_t *self;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, self);

  VALUE dims = rb_ary_new_capa(self->ndim);

  int axis = 0;
  while (axis < self->ndim) {
    rb_ary_push(dims, SSIZET2NUM(self->shape[axis]));
    ++axis;
  }

  return dims;
}
strides() click to toggle source
/*
 * NDArray#strides
 *
 * Returns a new Array holding the stride of each dimension, in order, or
 * an empty Array when the strides pointer is NULL.
 */
static VALUE
ndarray_get_strides(VALUE obj)
{
  ndarray_t *self;
  TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, self);

  const int has_strides = (self->strides != NULL);
  VALUE result = rb_ary_new_capa(has_strides ? self->ndim : 0);

  if (has_strides) {
    int axis = 0;
    while (axis < self->ndim) {
      rb_ary_push(result, SSIZET2NUM(self->strides[axis]));
      ++axis;
    }
  }

  return result;
}

Private Instance Methods

reshape_impl(p1, p2) click to toggle source
/*
 * NDArray#reshape_impl(new_shape, order)
 *
 * Creates a new NDArray of the same class that shares +base+'s data
 * pointer, byte size and dtype but has the shape +new_shape_v+ with
 * row-major strides.  The view records +base+ in its +base+ field.
 *
 * new_shape_v - Array of positive integers whose product, multiplied by
 *               the item size, must equal the byte size of +base+.
 * order       - validated by check_order(); only the row-major order is
 *               implemented.
 *
 * Raises NotImplementedError for the :auto and :column_major orders and
 * ArgumentError when a size is zero or negative or when the new shape does
 * not cover exactly the same number of bytes as +base+.
 */
static VALUE
ndarray_reshape_impl(VALUE base, VALUE new_shape_v, VALUE order)
{
  ndarray_t *nar_base;
  TypedData_Get_Struct(base, ndarray_t, &ndarray_data_type, nar_base);

  Check_Type(new_shape_v, T_ARRAY);
  check_order(order);

  if (order == sym_auto) {
    rb_raise(rb_eNotImpError, ":auto order is not implemented");
  }
  else if (order == sym_column_major) {
    rb_raise(rb_eNotImpError, ":column_major order is not implemented");
  }

  const ssize_t new_ndim = RARRAY_LEN(new_shape_v);

  /* preparing the buffer for new_shape */

  /* The scratch buffer comes from RB_ALLOCV_N rather than xmalloc, so it
   * is reclaimed even when NUM2SSIZET raises on a non-integer element. */
  ssize_t inline_new_shape_buf[MAX_INLINE_DIM] = { 0, };
  ssize_t *new_shape = inline_new_shape_buf;

  VALUE heap_new_shape_buf = 0;
  if (new_ndim > MAX_INLINE_DIM) {
    new_shape = RB_ALLOCV_N(ssize_t, heap_new_shape_buf, new_ndim);
  }

  /* extracting new_shape */

  const ssize_t base_byte_size = nar_base->byte_size;
  ssize_t byte_size = SIZEOF_DTYPE(nar_base->dtype);
  int exceeds_base = 0;
  ssize_t i;
  for (i = 0; i < new_ndim; ++i) {
    const ssize_t dim_size = NUM2SSIZET(RARRAY_AREF(new_shape_v, i));
    if (dim_size <= 0) {
      RB_ALLOCV_END(heap_new_shape_buf);
      rb_raise(rb_eArgError, "zero or negative size is given in new_shape");
    }
    new_shape[i] = dim_size;

    /* Every size is >= 1, so once the running product exceeds the base
     * size it can never match again.  Stop multiplying at that point so
     * the product cannot overflow, but keep validating the rest. */
    if (exceeds_base) {
      continue;
    }
    if (byte_size > 0 && dim_size > base_byte_size / byte_size) {
      exceeds_base = 1;
    }
    else {
      byte_size *= dim_size;
    }
  }

  if (exceeds_base || byte_size != base_byte_size) {
    RB_ALLOCV_END(heap_new_shape_buf);
    rb_raise(rb_eArgError,
             "new_shape is incompatible with the base shape (%"PRIsVALUE" for %"PRIsVALUE")",
             new_shape_v, ndarray_get_shape(base));
  }

  /* preparing view array */

  VALUE view = ndarray_s_allocate(CLASS_OF(base));

  ndarray_t *nar;
  TypedData_Get_Struct(view, ndarray_t, &ndarray_data_type, nar);

  nar->data = nar_base->data;
  nar->byte_size = nar_base->byte_size;
  nar->dtype = nar_base->dtype;
  nar->base = base;
  nar->ndim = new_ndim;

  /* The view owns its own copy of the shape. */
  nar->shape = ALLOC_N(ssize_t, new_ndim);
  MEMCPY(nar->shape, new_shape, ssize_t, new_ndim);
  RB_ALLOCV_END(heap_new_shape_buf);

  nar->strides = ALLOC_N(ssize_t, new_ndim);

  if (order == sym_row_major) {
    ndarray_init_row_major_strides(nar->dtype, new_ndim, nar->shape, nar->strides);
  }

  return view;
}