| // Single-dimension reduction using a nested loop. |
| // This does not use subgroup arithmetic or prefetching. |
| |
| #version 450 |
| |
| // Workgroup size is 1x1x1, i.e. one invocation per workgroup. main() writes |
| // exactly one output element per invocation. |
| // TODO(benvanik): tile. |
| layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; |
| |
| // Reduction source values; main() reads them through a computed linear index. |
| layout(std430, binding = 0) buffer readonly arg0_binding { |
| float input_0[]; |
| }; |
| // Initial accumulator value; main() only reads element 0. |
| layout(std430, binding = 1) buffer readonly arg1_binding { |
| float initial_value_0[]; |
| }; |
| // Reduction results; each invocation writes the single element at its |
| // linearized output index. |
| layout(std430, binding = 2) buffer writeonly ret0_binding { |
| float output_0[]; |
| }; |
| |
| // Selects the reduction operator applied to the accumulator in main(): |
| // 0 = add |
| // 1 = max |
| // 2 = min |
| // Any other value matches no branch in main(), so the accumulator keeps the |
| // initial value and that is what gets written out. |
| layout(constant_id = 100) const uint kOperationId = 0; |
| |
| // This math allows us to handle all single-dimension reductions regardless of |
| // input/output ranks as if they were effectively 2-/3-D values. |
| // |
| // You can always see the output Array as a 2D array of shape [A, B] with |
| // strides [B*C, C]. You can reshape to linearize all the dimensions to the |
| // left of the reduced dimension and all the dimensions to the right of it. |
| // Similarly the input array can always be recast into a 3D array of shape |
| // [A, R, B] with stride [C*R*B, B*C, C]. |
| // |
| // So the output index is [i, j] and the input index is [i, r, j]: |
| // inputLinearIndex = |
| // floor(outputLinearIndex / (B * C)) * C * R * B + |
| // r * B * C + |
| // ((outputLinearIndex / C) % B) * C; |
| // Specialization constants supplying A, B, C and R for the index math |
| // described above. Each defaults to 1 when it is not specialized. |
| // kR is also the trip count of the reduction loop in main(). |
| // kA is not referenced by the shader code visible in this file. |
| layout(constant_id = 101) const uint kA = 1; |
| layout(constant_id = 102) const uint kB = 1; |
| layout(constant_id = 103) const uint kC = 1; |
| layout(constant_id = 104) const uint kR = 1; |
| |
| // Returns the flat index of this invocation across the whole dispatch. |
| // |
| // Workgroups are linearized with x varying fastest, then y, then z. The result |
| // is scaled by the number of invocations per workgroup and offset by this |
| // invocation's index within its workgroup. |
| // |
| // The workgroup index is computed with unsigned integer arithmetic on purpose: |
| // dot() is only defined for floating-point vectors, so feeding it uvec3 |
| // operands relies on an implicit uint->float conversion and cannot represent |
| // indices above 2^24 exactly. |
| uint GetLinearizedOutputIndex() { |
|   const uint workgroup_index = |
|       gl_WorkGroupID.x + |
|       gl_NumWorkGroups.x * |
|           (gl_WorkGroupID.y + gl_NumWorkGroups.y * gl_WorkGroupID.z); |
|   const uint invocations_per_workgroup = |
|       gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z; |
|   return workgroup_index * invocations_per_workgroup + gl_LocalInvocationIndex; |
| } |
| |
| // Reduces kR input elements into a single output element per invocation. |
| // |
| // The accumulator starts at initial_value_0[0] and is combined with each input |
| // element using the operator selected by kOperationId (0 = add, 1 = max, |
| // 2 = min). Any other id leaves the accumulator untouched. |
| void main() { |
|   const uint output_index = GetLinearizedOutputIndex(); |
|  |
|   // Treat the output as [A, B] and the input as [A, R, B] with input strides |
|   // [C*R*B, B*C, C]. Split the output index into its outer (i) and inner (j) |
|   // coordinates, then locate input element [i, 0, j]. |
|   // NOTE(review): as in the documented formula, output_index % kC does not |
|   // contribute to the input index - confirm this is intended when kC > 1. |
|   const uint reduction_stride = kB * kC; |
|   const uint outer_index = output_index / reduction_stride; |
|   const uint inner_index = (output_index / kC) % kB; |
|   uint input_index = outer_index * kC * kR * kB + inner_index * kC; |
|  |
|   float value = initial_value_0[0]; |
|   // Walk the reduced dimension: consecutive r values are B*C elements apart. |
|   for (uint r = 0; r < kR; ++r, input_index += reduction_stride) { |
|     const float next_value = input_0[input_index]; |
|     if (kOperationId == 0) { |
|       value = value + next_value; |
|     } else if (kOperationId == 1) { |
|       value = max(value, next_value); |
|     } else if (kOperationId == 2) { |
|       value = min(value, next_value); |
|     } |
|   } |
|   output_0[output_index] = value; |
| } |