[mlir][gpu] Pack and unpack to enable f16 and int8 warp reduce. (#11349)
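
gpu.shuffle only operates on 32-bit values, so warp reductions of f16 and i8
inputs now keep the per-lane partial result as a small vector (vector<2xf16>
or vector<4xi8>), pack it into an i32 around every butterfly shuffle, and
unpack it again before accumulating; when the per-lane data does not fill
32 bits it is padded with the reduction identity. A minimal sketch of the
per-step sequence that gets emitted (SSA names are illustrative, taken from
the lit tests below):

  %cast   = vector.bitcast %lane : vector<2xf16> to vector<1xi32>
  %packed = vector.extract %cast[0] : vector<1xi32>
  %shfl, %valid = gpu.shuffle xor %packed, %offset, %width : i32
  %bcast  = vector.broadcast %shfl : i32 to vector<1xi32>
  %unpack = vector.bitcast %bcast : vector<1xi32> to vector<2xf16>
  %acc    = arith.addf %lane, %unpack : vector<2xf16>

After the last shuffle, a vector.reduction collapses the vector back to a
scalar f16/i8 result.
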
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
index 1cf1680..b60b354 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
@@ -20,6 +20,8 @@
#define DEBUG_TYPE "iree-codegen-reduction-distribution"
+static constexpr unsigned kShuffleBitWidth = 32;
+
namespace mlir {
namespace iree_compiler {
@@ -40,28 +42,90 @@
return builder.create<memref::AllocOp>(loc, memrefType);
}
+/// Packs a scalar element into its vector equivalent
+/// (i.e., f16 -> vector<1xf16> and f32 -> vector<1xf32>).
+static Value promoteElementToVector(Location loc, OpBuilder &builder,
+ Value input) {
+ VectorType vectorTypeBroadcast = VectorType::get({1}, input.getType());
+ Value vectorInput =
+ builder.create<vector::BroadcastOp>(loc, vectorTypeBroadcast, input);
+ return vectorInput;
+}
+
+/// Packs a lower-precision vector into a single 32-bit-wide element
+/// (i.e., vector<2xf16> -> i32 and vector<4xi8> -> i32).
+static Value packVectorToSupportedWidth(Location loc, OpBuilder &builder,
+ Value input) {
+ LLVM_DEBUG({
+ auto vecType = input.getType().cast<VectorType>();
+ Type elementType = vecType.getElementType();
+ assert(vecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() ==
+ kShuffleBitWidth &&
+ "vecSize * vecBitWidth needs to packable into 32-bitwidth.");
+ assert(elementType.isIntOrFloat() &&
+ "Only int and float packing is supported.");
+ });
+ VectorType packed32Type = VectorType::get({1}, builder.getI32Type());
+ Value packedInputVec =
+ builder.create<vector::BitCastOp>(loc, packed32Type, input);
+ Value packedInput = builder.create<vector::ExtractOp>(loc, packedInputVec, 0);
+ return packedInput;
+}
+
+/// Unpacks a single scalar element into a target vector type
+/// (i.e., i32 -> vector<4xi8> or f32 -> vector<2xf16>).
+static Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput,
+ VectorType targetVecType) {
+ LLVM_DEBUG({
+ Type packedType = packedInput.getType();
+ assert(packedType.isIntOrFloat() && "Only ints and floats are unpackable.");
+ Type elementType = targetVecType.getElementType();
+ assert(targetVecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() ==
+ packedType.getIntOrFloatBitWidth() &&
+ "packed width needs to be unpackable to vecSize * vecBitWidth.");
+ });
+ Value packedVector = promoteElementToVector(loc, builder, packedInput);
+ Value unpackedVector =
+ builder.create<vector::BitCastOp>(loc, targetVecType, packedVector);
+ return unpackedVector;
+}
+
/// Emit warp reduction code sequence for a given input.
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
vector::CombiningKind kind, uint32_t warpSize,
uint32_t numLaneToReduce) {
+ VectorType unpackedType = input.getType().dyn_cast<VectorType>();
Value laneVal = input;
assert(llvm::isPowerOf2_32(numLaneToReduce));
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < numLaneToReduce; i <<= 1) {
+ Value shuffleInput = laneVal;
+ if (unpackedType) {
+ shuffleInput = packVectorToSupportedWidth(loc, builder, laneVal);
+ }
Value shuffled = builder
- .create<gpu::ShuffleOp>(loc, laneVal, i,
+ .create<gpu::ShuffleOp>(loc, shuffleInput, i,
/*width=*/warpSize,
/*mode=*/gpu::ShuffleMode::XOR)
.getShuffleResult();
+ if (unpackedType) {
+ shuffled = unpackToVector(loc, builder, shuffled, unpackedType);
+ }
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
// Broadcast the result to all the lanes.
if (warpSize != numLaneToReduce) {
+ if (unpackedType) {
+ laneVal = packVectorToSupportedWidth(loc, builder, laneVal);
+ }
laneVal = builder
.create<gpu::ShuffleOp>(loc, laneVal, 0,
/*width=*/warpSize,
/*mode=*/gpu::ShuffleMode::IDX)
.getShuffleResult();
+ if (unpackedType) {
+ laneVal = unpackToVector(loc, builder, laneVal, unpackedType);
+ }
}
return laneVal;
}
@@ -105,6 +169,78 @@
return Attribute();
}
+/// Compute the value on a single thread to get the per-lane reduction value.
+/// If the bit-width is lower than what shuffle operations support, we
+/// represent the per-lane value as a vector so that it can be packed into a
+/// single 32-bit-wide value for the shuffles.
+static Value reduceToSupportedWidth(Location loc, OpBuilder &builder,
+ Value input, vector::CombiningKind kind) {
+ auto vecType = input.getType().cast<VectorType>();
+ Type elementType = vecType.getElementType();
+ int64_t vecSize = vecType.getDimSize(0);
+ unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+ // Simply reduce if it's already 32 bits.
+ if (bitWidth == kShuffleBitWidth) {
+ return builder.create<vector::ReductionOp>(loc, kind, input);
+ }
+ assert(kShuffleBitWidth % bitWidth == 0 &&
+ "Bitwidth needs to be able to be packed into shuffle-bitwidth.");
+ int64_t unrollCount = kShuffleBitWidth / bitWidth;
+ // The original size needs to be divisible by, or smaller than, the unroll
+ // count to determine the slice size.
+ assert(vecSize % unrollCount == 0 || vecSize < unrollCount);
+ unsigned sliceSize = vecSize / unrollCount;
+ VectorType unrolledLaneValType = VectorType::get({unrollCount}, elementType);
+ Value perLaneReduction = builder.create<arith::ConstantOp>(
+ loc, builder.getZeroAttr(unrolledLaneValType));
+ if (vecSize % unrollCount == 0) {
+ // Unroll the reduction so the results can be packed into the supported 32-bit width.
+ for (int64_t i = 0; i < unrollCount; i++) {
+ Value laneValSlice = builder.create<vector::ExtractStridedSliceOp>(
+ loc, input,
+ /*offsets=*/ArrayRef<int64_t>{sliceSize * i},
+ /*sizes=*/ArrayRef<int64_t>{sliceSize},
+ /*strides=*/ArrayRef<int64_t>{1});
+ Value reductionSlice =
+ builder.create<vector::ReductionOp>(loc, kind, laneValSlice);
+ SmallVector<int64_t> perLaneUnrollId = {i};
+ perLaneReduction = builder.create<vector::InsertOp>(
+ loc, reductionSlice, perLaneReduction, perLaneUnrollId);
+ }
+ } else {
+ // In cases where vecSize < unrollCount, pad the vector with identity
+ // elements until its total bit size is 32.
+ Attribute identityAttr =
+ getCombiningKindIdentity(builder, kind, elementType);
+ identityAttr = DenseElementsAttr::get(unrolledLaneValType, identityAttr);
+ Value identity = builder.create<arith::ConstantOp>(loc, identityAttr,
+ unrolledLaneValType);
+ perLaneReduction = builder.create<vector::InsertStridedSliceOp>(
+ loc, input, identity, /*offsets=*/ArrayRef<int64_t>{0},
+ /*strides=*/ArrayRef<int64_t>{1});
+ }
+ return perLaneReduction;
+}
+
+/// Emit the identity constant for the given combining kind and type.
+static Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
+ vector::CombiningKind kind,
+ Type identityType) {
+ auto vectorType = identityType.dyn_cast<VectorType>();
+ Type elementType = identityType;
+ if (vectorType) {
+ elementType = vectorType.getElementType();
+ }
+ Attribute identityAttr = getCombiningKindIdentity(builder, kind, elementType);
+ if (vectorType) {
+ identityAttr = DenseElementsAttr::get(vectorType, identityAttr);
+ }
+ assert(identityAttr && "Unknown identity value for the reduction");
+ Value identity =
+ builder.create<arith::ConstantOp>(loc, identityAttr, identityType);
+ return identity;
+}
+
/// Emit reduction across a group for a given input.
static Value groupReduction(Location loc, OpBuilder &builder, Value input,
vector::CombiningKind kind, uint32_t size,
@@ -113,7 +249,7 @@
size % warpSize == 0 &&
"Group reduction only support for sizes aligned on warp size for now.");
// First reduce on a single thread to get per lane reduction value.
- Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ Value laneVal = reduceToSupportedWidth(loc, builder, input, kind);
laneVal = warpReduction(loc, builder, laneVal, kind, warpSize, warpSize);
// if we have more than one warp, reduce across warps.
if (size > warpSize) {
@@ -148,15 +284,17 @@
Value useIdentityElement = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, laneId, cstNumWarp);
numWarp = llvm::PowerOf2Ceil(numWarp);
- Attribute identityAttr =
- getCombiningKindIdentity(builder, kind, loadVal.getType());
- assert(identityAttr && "Unknown identity value for the reduction");
- Value identity = builder.create<arith::ConstantOp>(loc, identityAttr);
+ Value identity =
+ getCombiningIdentityValue(loc, builder, kind, loadVal.getType());
loadVal = builder.create<arith::SelectOp>(loc, useIdentityElement,
identity, loadVal);
}
laneVal = warpReduction(loc, builder, loadVal, kind, warpSize, numWarp);
}
+ // Handle sub-32-bit precision cases where the output is still in vector form.
+ if (laneVal.getType().isa<VectorType>()) {
+ laneVal = builder.create<vector::ReductionOp>(loc, kind, laneVal);
+ }
return laneVal;
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir b/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir
index 79f0f17..3bb8176 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir
@@ -85,6 +85,283 @@
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
+hal.executable private @simple_half_reduce {
+ hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+ hal.executable.export @simple_half_reduce layout(#pipeline_layout) attributes {
+ workgroup_size = [32 : index, 1 : index, 1 : index]
+ }
+ builtin.module {
+ func.func @simple_half_reduce() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.000000e+00> : vector<1xf16>
+ %cst_0 = arith.constant 0.000000e+00 : f16
+ %cst_1 = arith.constant dense<3.840000e+02> : vector<1xf16>
+ %c32 = arith.constant 32 : index
+ %c384 = arith.constant 384 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384xf16>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xf16>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %2 = gpu.thread_id x
+ %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x]
+ %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x384xf16>, vector<384xf16>
+ %5 = vector.broadcast %4 : vector<384xf16> to vector<1x384xf16>
+ %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x384xf16> to vector<1xf16>
+ %7 = arith.divf %6, %cst_1 : vector<1xf16>
+ vector.transfer_write %7, %1[%3] {in_bounds = [true]} : vector<1xf16>, memref<128xf16>
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func.func @simple_half_reduce() {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0.000000e+00> : vector<2xf16>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+// CHECK: %[[READ_X:.*]] = affine.apply #{{.*}}()[%[[TID]], %{{.*}}]
+// CHECK: %[[READ_Y:.*]] = affine.apply #{{.*}}()[%[[TID]]]
+// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[READ_X]], %[[READ_Y]]], %[[C0F]] {in_bounds = [true]} : memref<128x384xf16>, vector<12xf16>
+// CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[V1]] {offsets = [0], sizes = [6], strides = [1]} : vector<12xf16> to vector<6xf16>
+// CHECK: %[[V3:.*]] = vector.reduction <add>, %[[V2]] : vector<6xf16> into f16
+// CHECK: %[[V4:.*]] = vector.insert %[[V3]], %[[IDENTITY]] [0] : f16 into vector<2xf16>
+// CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V1]] {offsets = [6], sizes = [6], strides = [1]} : vector<12xf16> to vector<6xf16>
+// CHECK: %[[V6:.*]] = vector.reduction <add>, %[[V5]] : vector<6xf16> into f16
+// CHECK: %[[V7:.*]] = vector.insert %[[V6]], %[[V4]] [1] : f16 into vector<2xf16>
+// CHECK: %[[CAST0:.*]] = vector.bitcast %[[V7]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32>
+// CHECK: %[[S0:.*]], %{{.*}} = gpu.shuffle xor %[[PACK0]], %[[C1]], %[[C32]] : i32
+// CHECK: %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S1:.*]] = arith.addf %[[V7]], %[[UNPACK0]] : vector<2xf16>
+// CHECK: %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32>
+// CHECK: %[[S2:.*]], %{{.*}} = gpu.shuffle xor %[[PACK1]], %[[C2]], %[[C32]] : i32
+// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S3:.*]] = arith.addf %[[S1]], %[[UNPACK1]] : vector<2xf16>
+// CHECK: %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32>
+// CHECK: %[[S4:.*]], %{{.*}} = gpu.shuffle xor %[[PACK2]], %[[C4]], %[[C32]] : i32
+// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S5:.*]] = arith.addf %[[S3]], %[[UNPACK2]] : vector<2xf16>
+// CHECK: %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32>
+// CHECK: %[[S6:.*]], %{{.*}} = gpu.shuffle xor %[[PACK3]], %[[C8]], %[[C32]] : i32
+// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S7:.*]] = arith.addf %[[S5]], %[[UNPACK3]] : vector<2xf16>
+// CHECK: %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32>
+// CHECK: %[[S8:.*]], %{{.*}} = gpu.shuffle xor %[[PACK4]], %[[C16]], %[[C32]] : i32
+// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S9:.*]] = arith.addf %[[S7]], %[[UNPACK4]] : vector<2xf16>
+// CHECK: %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<2xf16> into f16
+// CHECK: %[[S11:.*]] = arith.addf %[[S10]], %[[C0F]] : f16
+// CHECK: %[[B:.*]] = vector.broadcast %[[S11]] : f16 to vector<1xf16>
+// CHECK: %[[DIV:.*]] = arith.divf %[[B]], %{{.*}} : vector<1xf16>
+// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index
+// CHECK: scf.if %[[CMP]] {
+// CHECK: vector.transfer_write %[[DIV]], {{.*}} : vector<1xf16>, memref<128xf16>
+// CHECK: }
+// CHECK: return
+
+// -----
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+hal.executable private @simple_half_reduce_small {
+ hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+ hal.executable.export @simple_half_reduce_small layout(#pipeline_layout) attributes {
+ workgroup_size = [32 : index, 1 : index, 1 : index]
+ }
+ builtin.module {
+ func.func @simple_half_reduce_small() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.000000e+00> : vector<1xf16>
+ %cst_0 = arith.constant 0.000000e+00 : f16
+ %cst_1 = arith.constant dense<3.840000e+02> : vector<1xf16>
+ %c32 = arith.constant 32 : index
+ %c384 = arith.constant 384 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x32xf16>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xf16>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %2 = gpu.thread_id x
+ %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x]
+ %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x32xf16>, vector<32xf16>
+ %5 = vector.broadcast %4 : vector<32xf16> to vector<1x32xf16>
+ %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x32xf16> to vector<1xf16>
+ %7 = arith.divf %6, %cst_1 : vector<1xf16>
+ vector.transfer_write %7, %1[%3] {in_bounds = [true]} : vector<1xf16>, memref<128xf16>
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func.func @simple_half_reduce_small() {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0.000000e+00> : vector<2xf16>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-DAG: %[[C1F:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+// CHECK: %[[ID:.*]] = affine.apply
+// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[ID]], %{{.*}}], %{{.*}} {in_bounds = [true]} : memref<128x32xf16>, vector<1xf16>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[IDENTITY]] {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
+// CHECK: %[[CAST0:.*]] = vector.bitcast %[[V2]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32>
+// CHECK: %[[S0:.*]], %{{.*}} = gpu.shuffle xor %[[PACK0]], %[[C1]], %[[C32]] : i32
+// CHECK: %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S1:.*]] = arith.addf %[[V2]], %[[UNPACK0]] : vector<2xf16>
+// CHECK: %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32>
+// CHECK: %[[S2:.*]], %{{.*}} = gpu.shuffle xor %[[PACK1]], %[[C2]], %[[C32]] : i32
+// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S3:.*]] = arith.addf %[[S1]], %[[UNPACK1]] : vector<2xf16>
+// CHECK: %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32>
+// CHECK: %[[S4:.*]], %{{.*}} = gpu.shuffle xor %[[PACK2]], %[[C4]], %[[C32]] : i32
+// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S5:.*]] = arith.addf %[[S3]], %[[UNPACK2]] : vector<2xf16>
+// CHECK: %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32>
+// CHECK: %[[S6:.*]], %{{.*}} = gpu.shuffle xor %[[PACK3]], %[[C8]], %[[C32]] : i32
+// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S7:.*]] = arith.addf %[[S5]], %[[UNPACK3]] : vector<2xf16>
+// CHECK: %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<2xf16> to vector<1xi32>
+// CHECK: %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32>
+// CHECK: %[[S8:.*]], %{{.*}} = gpu.shuffle xor %[[PACK4]], %[[C16]], %[[C32]] : i32
+// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<2xf16>
+// CHECK: %[[S9:.*]] = arith.addf %[[S7]], %[[UNPACK4]] : vector<2xf16>
+// CHECK: %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<2xf16> into f16
+// CHECK: %[[S11:.*]] = arith.addf %[[S10]], %[[C1F]] : f16
+// CHECK: %[[B:.*]] = vector.broadcast %[[S11]] : f16 to vector<1xf16>
+// CHECK: %[[DIV:.*]] = arith.divf %[[B]], %{{.*}} : vector<1xf16>
+// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index
+// CHECK: scf.if %[[CMP]] {
+// CHECK: vector.transfer_write %[[DIV]], {{.*}} : vector<1xf16>, memref<128xf16>
+// CHECK: }
+// CHECK: return
+
+// -----
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+hal.executable private @simple_quarter_reduce {
+ hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+ hal.executable.export @simple_quarter_reduce layout(#pipeline_layout) attributes {
+ workgroup_size = [32 : index, 1 : index, 1 : index]
+ }
+ builtin.module {
+ func.func @simple_quarter_reduce() {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0> : vector<1xi8>
+ %cst_0 = arith.constant 0 : i8
+ %c32 = arith.constant 32 : index
+ %c384 = arith.constant 384 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384xi8>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xi8>
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %2 = gpu.thread_id x
+ %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x]
+ %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x384xi8>, vector<32xi8>
+ %5 = vector.broadcast %4 : vector<32xi8> to vector<1x32xi8>
+ %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x32xi8> to vector<1xi8>
+ vector.transfer_write %6, %1[%3] {in_bounds = [true]} : vector<1xi8>, memref<128xi8>
+ return
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func.func @simple_quarter_reduce() {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0> : vector<4xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+// CHECK: %[[ID:.*]] = affine.apply
+// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[ID]], %{{.*}}], %{{.*}} {in_bounds = [true]} : memref<128x384xi8>, vector<1xi8>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[IDENTITY]] {offsets = [0], strides = [1]} : vector<1xi8> into vector<4xi8>
+// CHECK: %[[CAST0:.*]] = vector.bitcast %[[V2]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32>
+// CHECK: %[[S0:.*]], %{{.*}} = gpu.shuffle xor %[[PACK0]], %[[C1]], %[[C32]] : i32
+// CHECK: %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<4xi8>
+// CHECK: %[[S1:.*]] = arith.addi %[[V2]], %[[UNPACK0]] : vector<4xi8>
+// CHECK: %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32>
+// CHECK: %[[S2:.*]], %{{.*}} = gpu.shuffle xor %[[PACK1]], %[[C2]], %[[C32]] : i32
+// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<4xi8>
+// CHECK: %[[S3:.*]] = arith.addi %[[S1]], %[[UNPACK1]] : vector<4xi8>
+// CHECK: %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32>
+// CHECK: %[[S4:.*]], %{{.*}} = gpu.shuffle xor %[[PACK2]], %[[C4]], %[[C32]] : i32
+// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<4xi8>
+// CHECK: %[[S5:.*]] = arith.addi %[[S3]], %[[UNPACK2]] : vector<4xi8>
+// CHECK: %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32>
+// CHECK: %[[S6:.*]], %{{.*}} = gpu.shuffle xor %[[PACK3]], %[[C8]], %[[C32]] : i32
+// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<4xi8>
+// CHECK: %[[S7:.*]] = arith.addi %[[S5]], %[[UNPACK3]] : vector<4xi8>
+// CHECK: %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32>
+// CHECK: %[[S8:.*]], %{{.*}} = gpu.shuffle xor %[[PACK4]], %[[C16]], %[[C32]] : i32
+// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32>
+// CHECK: %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<4xi8>
+// CHECK: %[[S9:.*]] = arith.addi %[[S7]], %[[UNPACK4]] : vector<4xi8>
+// CHECK: %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<4xi8> into i8
+// CHECK: %[[B:.*]] = vector.broadcast %[[S10]] : i8 to vector<1xi8>
+// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index
+// CHECK: scf.if %[[CMP]] {
+// CHECK: vector.transfer_write %[[B]], {{.*}} : vector<1xi8>, memref<128xi8>
+// CHECK: }
+// CHECK: return
+
+// -----
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
hal.executable private @simple_reduce_multi_warp {
hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
hal.executable.export @simple_reduce_multi_warp layout(#pipeline_layout) attributes {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index adab5f5..f1e4bd7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -553,8 +553,8 @@
.getElementType();
if (!elementType.isIntOrFloat()) return failure();
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
- // Reduction distribution only supports 32-bit types now.
- if (bitWidth != 32) return failure();
+ // Reduction distribution only supports 8/16/32 bit types now.
+ if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) return failure();
const unsigned largestLoadSizeInBits = 128;
unsigned vectorSize = largestLoadSizeInBits / bitWidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index fed8c1b..dab068a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -969,8 +969,8 @@
op.getOutputs()[0].getType().cast<ShapedType>().getElementType();
if (!elementType.isIntOrFloat()) return failure();
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
- // Reduction distribution only supports 32-bit types now.
- if (bitWidth != 32) return failure();
+ // Reduction distribution only supports 8/16/32 bit types now.
+ if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) return failure();
// Let each thread handle `vectorSize` elements.
unsigned vectorSize = kMaxVectorNumBits / bitWidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
index 0c7e6a1..e48bd68 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
@@ -101,11 +101,11 @@
}
}
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 256], [1, 4], [0, 0, 4]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 4096]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
// CHECK: hal.executable.export public @subgroup_reduce_f16
// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
+// CHECK-SAME: workgroup_size = [512 : index, 1 : index, 1 : index]
// CHECK: func.func @subgroup_reduce_f16()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[CONFIG]]
diff --git a/tests/e2e/regression/large_reduction.mlir b/tests/e2e/regression/large_reduction.mlir
index 15e6fe4..cdc24b6 100644
--- a/tests/e2e/regression/large_reduction.mlir
+++ b/tests/e2e/regression/large_reduction.mlir
@@ -49,3 +49,37 @@
check.expect_almost_eq_const(%result, dense<40.96> : tensor<2xf32>) : tensor<2xf32>
return
}
+
+func.func @half_reduction_aligned() {
+ %in = util.unfoldable_constant dense<0.001> : tensor<2x4096xf16>
+ %cst = arith.constant 0.0 : f16
+ %init = tensor.empty() : tensor<2xf16>
+ %fill = linalg.fill ins(%cst : f16) outs(%init : tensor<2xf16>) -> tensor<2xf16>
+ %result = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%in : tensor<2x4096xf16>) outs(%fill : tensor<2xf16>) {
+ ^bb0(%arg3: f16, %arg4: f16): // no predecessors
+ %2 = arith.addf %arg3, %arg4 : f16
+ linalg.yield %2 : f16
+ } -> tensor<2xf16>
+ check.expect_almost_eq_const(%result, dense<4.096> : tensor<2xf16>) : tensor<2xf16>
+ return
+}
+
+func.func @quarter_reduction_aligned_smaller() {
+ %in = util.unfoldable_constant dense<1> : tensor<128x128xi8>
+ %cst = arith.constant 0 : i8
+ %init = tensor.empty() : tensor<128xi8>
+ %fill = linalg.fill ins(%cst : i8) outs(%init : tensor<128xi8>) -> tensor<128xi8>
+ %result = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%in : tensor<128x128xi8>) outs(%fill : tensor<128xi8>) {
+ ^bb0(%arg3: i8, %arg4: i8): // no predecessors
+ %2 = arith.addi %arg3, %arg4 : i8
+ linalg.yield %2 : i8
+ } -> tensor<128xi8>
+ check.expect_eq_const(%result, dense<128> : tensor<128xi8>) : tensor<128xi8>
+ return
+}
\ No newline at end of file