% # Renders an OpenCL kernel that finds, for every output element, the position
% # of the largest (f == :argmax) or smallest (f == :argmin) value along one
% # axis of the input buffer.
% #
% # Template locals read below:
% #   dtype     - input element type; also used as the kernel-name suffix
% #   out_dtype - element type of the index buffer that is written
% #   shape     - array of input dimensions
% #   o_shape   - array of output dimensions
% #   axis      - axis to reduce over (Integer, or an Array whose first entry is used)
% #   f         - :argmax or :argmin; anything else raises
% c_dtype = dtype_to_c_type(dtype)
% out_c_dtype = dtype_to_c_type(out_dtype)
% # Row-major strides for a list of dimensions, e.g. [2, 3, 4] -> [12, 4, 1].
% # The fold must be seeded with [1]: without a seed the accumulator starts as
% # an Integer (no #last), and a one-dimensional shape would yield nil.
% row_major_strides = ->(dims) { dims.drop(1).reverse.inject([1]) { |acc, dim| acc << dim * acc.last }.reverse }
% o_multipliers = row_major_strides.call(o_shape)
% i_multipliers = row_major_strides.call(shape)
% out_ops = o_multipliers.each_with_index.map { |m, index| "id_#{index} * #{m}" }.join(' + ')
% # Accept either a bare Integer or a one-element Array; nil falls back to 0.
% axis_index = Array(axis).first.to_i
% # Wrap negative axes so they can be compared against stride positions.
% axis_index += shape.size if axis_index < 0
% # Offset contributed by the loop variable walking along the reduced axis.
% in_axis_ops = "i * #{i_multipliers[axis_index]}"
% # Remaining input strides, renumbered so they pair up with the id_N work-item ids.
% in_output_multipliers = i_multipliers.reject.with_index { |_m, index| index == axis_index }
% in_output_ops = in_output_multipliers.each_with_index.map { |m, index| "id_#{index} * #{m}" }.join(' + ')
% comparison = case f
%   when :argmax then '>'
%   when :argmin then '<'
%   else raise "unkown redunction func #{f}"
% end
__kernel void arg_axis_<%= dtype %>(__global const <%= c_dtype %> *value, __global <%= out_c_dtype %> *output) {
    // Get the index of the current element to be processed
<% o_multipliers.each_index do |index| %>
    const int id_<%= index %> = get_global_id(<%= index %>);
<% end %>
    // Starting value for the running best; presumably the extreme opposite of
    // the search direction so that any real element replaces it (TODO confirm
    // against min_value_for / max_value_for).
    <%= c_dtype %> min_or_max_value = <%= f == :argmax ? min_value_for(dtype) : max_value_for(dtype) %>;
    int min_or_max_index = 0;

    // Walk every position along the reduced axis.
    for (int i = 0; i < <%= shape[axis_index] %>; i++) {
        int index = <%= in_axis_ops %>;
<% unless in_output_ops.empty? %>
        index += <%= in_output_ops %>;
<% end %>
        // Strict comparison: on ties the earliest position is kept.
        if (value[index] <%= comparison %> min_or_max_value) {
            min_or_max_index = i;
            min_or_max_value = value[index];
        }
    }

    output[<%= out_ops %>] = (<%= out_c_dtype %>)min_or_max_index;
}