blob: b6176d45136ae688150741892eb0b0cbf5c43f50 [file] [log] [blame]
// Single-dimension reduction using a nested loop.
// This does not use subgroup arithmetic or prefetching.
#version 450
// TODO(benvanik): tile.
// Each workgroup contains a single invocation (no tiling yet; see the TODO
// above), so every invocation is identified by its workgroup ID alone.
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
// Values to reduce, addressed through the [A, R, B] input view described
// below.
layout(std430, binding = 0) buffer readonly arg0_binding {
float input_0[];
};
// Starting accumulator value for the reduction; only element 0 is read.
layout(std430, binding = 1) buffer readonly arg1_binding {
float initial_value_0[];
};
// Reduced results; each invocation writes exactly one element.
layout(std430, binding = 2) buffer writeonly ret0_binding {
float output_0[];
};
// Reduction operator, chosen via specialization constant id 100:
//   0 = add (default)
//   1 = max
//   2 = min
// Any other value applies no operator, so the initial value is written to the
// output unchanged.
layout(constant_id = 100) const uint kOperationId = 0;
// This math allows us to handle all single-dimension reductions regardless of
// input/output ranks as if they were effectively 2-/3-D values.
//
// The output array can always be viewed as a 2D array of shape [A, B] with
// strides [B*C, C]: reshape so that all the dimensions to the left of the
// reduced dimension are linearized into A and all the dimensions to the right
// of it are linearized into B.
// Similarly, the input array can always be recast as a 3D array of shape
// [A, R, B] with strides [C*R*B, B*C, C], where R is the reduced dimension.
//
// So the output index is [i, j] and the input index is [i, r, j]:
// inputLinearIndex =
//     floor(outputLinearIndex / (B * C)) * C * R * B +   // i term
//     r * B * C +                                         // r term
//     ((outputLinearIndex / C) % B) * C;                  // j term
// Extents of the [A, B] output / [A, R, B] input views described above,
// supplied as specialization constants 101-104 (each defaults to 1).
// NOTE(review): kA is declared but not referenced anywhere in this shader.
layout(constant_id = 101) const uint kA = 1;
layout(constant_id = 102) const uint kB = 1;
layout(constant_id = 103) const uint kC = 1;
layout(constant_id = 104) const uint kR = 1;
// Returns the flattened index of this invocation across the whole dispatch.
// Workgroups are linearized x-fastest (x, then y, then z); the result is the
// linear workgroup index times the invocations per workgroup, plus this
// invocation's index within its workgroup.
uint GetLinearizedOutputIndex() {
  // Linearize with integer arithmetic on purpose: GLSL has no integer dot(),
  // so dot() on uvec3 operands is implicitly converted to the float overload,
  // which cannot represent indices above 2^24 exactly.
  uint workGroupIndex =
      gl_WorkGroupID.x +
      gl_NumWorkGroups.x *
          (gl_WorkGroupID.y + gl_NumWorkGroups.y * gl_WorkGroupID.z);
  uint invocationsPerWorkGroup =
      gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
  return workGroupIndex * invocationsPerWorkGroup + gl_LocalInvocationIndex;
}
// Computes one reduced output element. The invocation's output index [i, j]
// selects the input row [i, :, j]. The R values along the reduced dimension are
// folded into initial_value_0[0] using the operator chosen by kOperationId.
void main() {
  uint output_index = GetLinearizedOutputIndex();

  // Split the output index into [i, j] of the [A, B] output view (strides
  // [B*C, C]) and locate input element [i, 0, j] of the [A, R, B] input view
  // (strides [C*R*B, B*C, C]), following the inputLinearIndex formula above.
  // NOTE(review): this treats output_index in the same stride units as the
  // formula, which coincides with a dense element index when kC == 1.
  uint i = output_index / (kB * kC);
  uint j = (output_index / kC) % kB;
  uint input_index = i * (kC * kR * kB) + j * kC;

  // Advancing r by one moves a full B*C stride through the input.
  const uint reduction_stride = kB * kC;

  float value = initial_value_0[0];
  for (uint r = 0u; r < kR; ++r, input_index += reduction_stride) {
    const float next_value = input_0[input_index];
    if (kOperationId == 0) {
      value = value + next_value;
    } else if (kOperationId == 1) {
      value = max(value, next_value);
    } else if (kOperationId == 2) {
      value = min(value, next_value);
    }
  }
  output_0[output_index] = value;
}