[mlir][gpu] Pack and unpack to enable f16 and int8 warp reduce. (#11349)
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
index 1cf1680..b60b354 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
@@ -20,6 +20,8 @@
 #define DEBUG_TYPE "iree-codegen-reduction-distribution"
 
+static constexpr unsigned kShuffleBitWidth = 32;
+
 namespace mlir {
 namespace iree_compiler {
@@ -40,28 +42,90 @@
   return builder.create<memref::AllocOp>(loc, memrefType);
 }
 
+/// Packs a scalar element into its vector equivalent.
+/// (i.e., f16 -> vector<1xf16> and f32 -> vector<1xf32>)
+static Value promoteElementToVector(Location loc, OpBuilder &builder,
+                                    Value input) {
+  VectorType vectorTypeBroadcast = VectorType::get({1}, input.getType());
+  Value vectorInput =
+      builder.create<vector::BroadcastOp>(loc, vectorTypeBroadcast, input);
+  return vectorInput;
+}
+
+/// Packs a vector of lower-precision elements into a single 32-bit wide
+/// element. (i.e., <2xf16> -> i32 and <4xi8> -> i32)
+static Value packVectorToSupportedWidth(Location loc, OpBuilder &builder,
+                                        Value input) {
+  LLVM_DEBUG({
+    auto vecType = input.getType().cast<VectorType>();
+    Type elementType = vecType.getElementType();
+    assert(vecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() ==
+               kShuffleBitWidth &&
+           "vecSize * vecBitWidth needs to be packable into a 32-bit width.");
+    assert(elementType.isIntOrFloat() &&
+           "Only int and float packing is supported.");
+  });
+  VectorType packed32Type = VectorType::get({1}, builder.getI32Type());
+  Value packedInputVec =
+      builder.create<vector::BitCastOp>(loc, packed32Type, input);
+  Value packedInput = builder.create<vector::ExtractOp>(loc, packedInputVec, 0);
+  return packedInput;
+}
+
+/// Unpacks a single scalar element into a target vector type.
+/// (i.e., i32 -> vector<4xi8> or f32 -> vector<2xf16>)
+static Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput,
+                            VectorType targetVecType) {
+  LLVM_DEBUG({
+    Type packedType = packedInput.getType();
+    assert(packedType.isIntOrFloat() && "Only ints and floats are unpackable.");
+    Type elementType = targetVecType.getElementType();
+    assert(targetVecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() ==
+               packedType.getIntOrFloatBitWidth() &&
+           "packed width needs to be unpackable to vecSize * vecBitWidth.");
+  });
+  Value packedVector = promoteElementToVector(loc, builder, packedInput);
+  Value unpackedVector =
+      builder.create<vector::BitCastOp>(loc, targetVecType, packedVector);
+  return unpackedVector;
+}
+
 /// Emit warp reduction code sequence for a given input.
 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                            vector::CombiningKind kind, uint32_t warpSize,
                            uint32_t numLaneToReduce) {
+  VectorType unpackedType = input.getType().dyn_cast<VectorType>();
   Value laneVal = input;
   assert(llvm::isPowerOf2_32(numLaneToReduce));
   // Parallel reduction using butterfly shuffles.
   for (uint64_t i = 1; i < numLaneToReduce; i <<= 1) {
+    Value shuffleInput = laneVal;
+    if (unpackedType) {
+      shuffleInput = packVectorToSupportedWidth(loc, builder, laneVal);
+    }
     Value shuffled = builder
-                         .create<gpu::ShuffleOp>(loc, laneVal, i,
+                         .create<gpu::ShuffleOp>(loc, shuffleInput, i,
                                                  /*width=*/warpSize,
                                                  /*mode=*/gpu::ShuffleMode::XOR)
                          .getShuffleResult();
+    if (unpackedType) {
+      shuffled = unpackToVector(loc, builder, shuffled, unpackedType);
+    }
     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
   }
   // Broadcast the result to all the lanes.
   if (warpSize != numLaneToReduce) {
+    if (unpackedType) {
+      laneVal = packVectorToSupportedWidth(loc, builder, laneVal);
+    }
     laneVal = builder
                   .create<gpu::ShuffleOp>(loc, laneVal, 0,
                                           /*width=*/warpSize,
                                           /*mode=*/gpu::ShuffleMode::IDX)
                   .getShuffleResult();
+    if (unpackedType) {
+      laneVal = unpackToVector(loc, builder, laneVal, unpackedType);
+    }
   }
   return laneVal;
 }
@@ -105,6 +169,78 @@
   return Attribute();
 }
 
+/// Computes the per-lane reduction value on a single thread.
+/// If the element bit-width is not supported by shuffle operations (i.e. it is
+/// a lower precision), the result is kept as a vector s.t. it can be packed
+/// into a single 32-bit wide element for shuffles.
+static Value reduceToSupportedWidth(Location loc, OpBuilder &builder,
+                                    Value input, vector::CombiningKind kind) {
+  auto vecType = input.getType().cast<VectorType>();
+  Type elementType = vecType.getElementType();
+  int64_t vecSize = vecType.getDimSize(0);
+  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+  // Simply reduce if it's already 32 bits.
+  if (bitWidth == kShuffleBitWidth) {
+    return builder.create<vector::ReductionOp>(loc, kind, input);
+  }
+  assert(kShuffleBitWidth % bitWidth == 0 &&
+         "Bit-width needs to be packable into the shuffle bit-width.");
+  int64_t unrollCount = kShuffleBitWidth / bitWidth;
+  // The original size needs to be divisible by or less than the unroll count
+  // to determine the slice size.
+  assert(vecSize % unrollCount == 0 || vecSize < unrollCount);
+  unsigned sliceSize = vecSize / unrollCount;
+  VectorType unrolledLaneValType = VectorType::get({unrollCount}, elementType);
+  Value perLaneReduction = builder.create<arith::ConstantOp>(
+      loc, builder.getZeroAttr(unrolledLaneValType));
+  if (vecSize % unrollCount == 0) {
+    // Unroll the reductions s.t. the results can be packed into a supported
+    // 32-bit wide format.
+    for (int64_t i = 0; i < unrollCount; i++) {
+      Value laneValSlice = builder.create<vector::ExtractStridedSliceOp>(
+          loc, input,
+          /*offsets=*/ArrayRef<int64_t>{sliceSize * i},
+          /*sizes=*/ArrayRef<int64_t>{sliceSize},
+          /*strides=*/ArrayRef<int64_t>{1});
+      Value reductionSlice =
+          builder.create<vector::ReductionOp>(loc, kind, laneValSlice);
+      SmallVector<int64_t> perLaneUnrollId = {i};
+      perLaneReduction = builder.create<vector::InsertOp>(
+          loc, reductionSlice, perLaneReduction, perLaneUnrollId);
+    }
+  } else {
+    // In cases where vecSize < unrollCount, pad the vector with identity
+    // elements until its total bit size is 32.
+    Attribute identityAttr =
+        getCombiningKindIdentity(builder, kind, elementType);
+    identityAttr = DenseElementsAttr::get(unrolledLaneValType, identityAttr);
+    Value identity = builder.create<arith::ConstantOp>(loc, identityAttr,
+                                                       unrolledLaneValType);
+    perLaneReduction = builder.create<vector::InsertStridedSliceOp>(
+        loc, input, identity, /*offsets=*/ArrayRef<int64_t>{0},
+        /*strides=*/ArrayRef<int64_t>{1});
+  }
+  return perLaneReduction;
+}
+
+/// Emits the identity constant for the given combining kind and type.
+static Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
+                                       vector::CombiningKind kind,
+                                       Type identityType) {
+  auto vectorType = identityType.dyn_cast<VectorType>();
+  Type elementType = identityType;
+  if (vectorType) {
+    elementType = vectorType.getElementType();
+  }
+  Attribute identityAttr = getCombiningKindIdentity(builder, kind, elementType);
+  if (vectorType) {
+    identityAttr = DenseElementsAttr::get(vectorType, identityAttr);
+  }
+  assert(identityAttr && "Unknown identity value for the reduction");
+  Value identity =
+      builder.create<arith::ConstantOp>(loc, identityAttr, identityType);
+  return identity;
+}
+
 /// Emit reduction across a group for a given input.
 static Value groupReduction(Location loc, OpBuilder &builder, Value input,
                             vector::CombiningKind kind, uint32_t size,
@@ -113,7 +249,7 @@
       size % warpSize == 0 &&
       "Group reduction only support for sizes aligned on warp size for now.");
   // First reduce on a single thread to get per lane reduction value.
-  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+  Value laneVal = reduceToSupportedWidth(loc, builder, input, kind);
   laneVal = warpReduction(loc, builder, laneVal, kind, warpSize, warpSize);
   // if we have more than one warp, reduce across warps.
   if (size > warpSize) {
@@ -148,15 +284,17 @@
     Value useIdentityElement = builder.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sge, laneId, cstNumWarp);
     numWarp = llvm::PowerOf2Ceil(numWarp);
-    Attribute identityAttr =
-        getCombiningKindIdentity(builder, kind, loadVal.getType());
-    assert(identityAttr && "Unknown identity value for the reduction");
-    Value identity = builder.create<arith::ConstantOp>(loc, identityAttr);
+    Value identity =
+        getCombiningIdentityValue(loc, builder, kind, loadVal.getType());
     loadVal = builder.create<arith::SelectOp>(loc, useIdentityElement, identity,
                                               loadVal);
   }
   laneVal = warpReduction(loc, builder, loadVal, kind, warpSize, numWarp);
   }
+  // Handle sub-32-bit precision cases where the output is still in vector
+  // form.
+  if (laneVal.getType().isa<VectorType>()) {
+    laneVal = builder.create<vector::ReductionOp>(loc, kind, laneVal);
+  }
   return laneVal;
 }
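Note: for sub-32-bit element types, each butterfly step of warpReduction now wraps the existing 32-bit gpu.shuffle in a bitcast-based pack/unpack. A rough sketch of the IR one such step produces for an f16 add reduction (SSA names are illustrative only; the exact output is captured by the CHECK lines in the test below):

  %cast         = vector.bitcast %lane : vector<2xf16> to vector<1xi32>
  %packed       = vector.extract %cast[0] : vector<1xi32>
  %shuf, %valid = gpu.shuffle xor %packed, %offset, %width : i32
  %bcast        = vector.broadcast %shuf : i32 to vector<1xi32>
  %peer         = vector.bitcast %bcast : vector<1xi32> to vector<2xf16>
  %acc          = arith.addf %lane, %peer : vector<2xf16>

The i8 path is the same shape, except four elements are packed per i32 and the combine uses arith.addi.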
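Note: reduceToSupportedWidth prepares the per-lane value before the shuffles. With unrollCount = 32 / bitWidth, a per-thread vector whose size is a multiple of unrollCount is reduced slice by slice into a vector<unrollCount x elemType>, while a smaller vector is padded with the combining identity. A sketch for f16 (unrollCount = 2, illustrative names, mirroring the tests below):

  // vecSize divisible by unrollCount, e.g. vector<12xf16>:
  %s0 = vector.extract_strided_slice %v {offsets = [0], sizes = [6], strides = [1]} : vector<12xf16> to vector<6xf16>
  %r0 = vector.reduction <add>, %s0 : vector<6xf16> into f16
  %a0 = vector.insert %r0, %zero [0] : f16 into vector<2xf16>
  // ... repeated for the slice at offset 6 to fill the vector<2xf16> lane value.

  // vecSize smaller than unrollCount, e.g. vector<1xf16>: pad with the identity.
  %lane = vector.insert_strided_slice %v, %identity {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>

After the warp shuffles, groupReduction folds the remaining vector<2xf16> (or vector<4xi8>) back to a scalar with a final vector.reduction.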
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir b/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir index 79f0f17..3bb8176 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir
@@ -85,6 +85,283 @@ #hal.descriptor_set.binding<1, storage_buffer> ]> ]> +hal.executable private @simple_half_reduce { + hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb { + hal.executable.export @simple_half_reduce layout(#pipeline_layout) attributes { + workgroup_size = [32 : index, 1 : index, 1 : index] + } + builtin.module { + func.func @simple_half_reduce() { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : vector<1xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant dense<3.840000e+02> : vector<1xf16> + %c32 = arith.constant 32 : index + %c384 = arith.constant 384 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384xf16> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xf16> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %2 = gpu.thread_id x + %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x] + %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x384xf16>, vector<384xf16> + %5 = vector.broadcast %4 : vector<384xf16> to vector<1x384xf16> + %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x384xf16> to vector<1xf16> + %7 = arith.divf %6, %cst_1 : vector<1xf16> + vector.transfer_write %7, %1[%3] {in_bounds = [true]} : vector<1xf16>, memref<128xf16> + return + } + } + } +} + +// CHECK-LABEL: func.func @simple_half_reduce() { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0.000000e+00> : vector<2xf16> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32 +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32 +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32 +// CHECK-DAG: %[[C0F:.*]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %[[TID:.*]] = gpu.thread_id x +// CHECK: %[[READ_X:.*]] = affine.apply #{{.*}}()[%[[TID]], %{{.*}}] +// CHECK: %[[READ_Y:.*]] = affine.apply #{{.*}}()[%[[TID]]] +// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[READ_X]], %[[READ_Y]]], %[[C0F]] {in_bounds = [true]} : memref<128x384xf16>, vector<12xf16> +// CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[V1]] {offsets = [0], sizes = [6], strides = [1]} : vector<12xf16> to vector<6xf16> +// CHECK: %[[V3:.*]] = vector.reduction <add>, %[[V2]] : vector<6xf16> into f16 +// CHECK: %[[V4:.*]] = vector.insert %[[V3]], %[[IDENTITY]] [0] : f16 into vector<2xf16> +// CHECK: %[[V5:.*]] = vector.extract_strided_slice %[[V1]] {offsets = [6], sizes = [6], strides = [1]} : vector<12xf16> to vector<6xf16> +// CHECK: %[[V6:.*]] = vector.reduction <add>, %[[V5]] : vector<6xf16> into f16 +// CHECK: %[[V7:.*]] = vector.insert %[[V6]], %[[V4]] [1] : f16 into vector<2xf16> +// CHECK: %[[CAST0:.*]] = vector.bitcast %[[V7]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32> +// CHECK: %[[S0:.*]], %{{.*}} = gpu.shuffle xor %[[PACK0]], %[[C1]], %[[C32]] : i32 +// CHECK: %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S1:.*]] = arith.addf %[[V7]], %[[UNPACK0]] : vector<2xf16> +// CHECK: %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<2xf16> 
to vector<1xi32> +// CHECK: %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32> +// CHECK: %[[S2:.*]], %{{.*}} = gpu.shuffle xor %[[PACK1]], %[[C2]], %[[C32]] : i32 +// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S3:.*]] = arith.addf %[[S1]], %[[UNPACK1]] : vector<2xf16> +// CHECK: %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32> +// CHECK: %[[S4:.*]], %{{.*}} = gpu.shuffle xor %[[PACK2]], %[[C4]], %[[C32]] : i32 +// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S5:.*]] = arith.addf %[[S3]], %[[UNPACK2]] : vector<2xf16> +// CHECK: %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32> +// CHECK: %[[S6:.*]], %{{.*}} = gpu.shuffle xor %[[PACK3]], %[[C8]], %[[C32]] : i32 +// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S7:.*]] = arith.addf %[[S5]], %[[UNPACK3]] : vector<2xf16> +// CHECK: %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32> +// CHECK: %[[S8:.*]], %{{.*}} = gpu.shuffle xor %[[PACK4]], %[[C16]], %[[C32]] : i32 +// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S9:.*]] = arith.addf %[[S7]], %[[UNPACK4]] : vector<2xf16> +// CHECK: %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<2xf16> into f16 +// CHECK: %[[S11:.*]] = arith.addf %[[S10]], %[[C0F]] : f16 +// CHECK: %[[B:.*]] = vector.broadcast %[[S11]] : f16 to vector<1xf16> +// CHECK: %[[DIV:.*]] = arith.divf %[[B]], %{{.*}} : vector<1xf16> +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index +// CHECK: scf.if %[[CMP]] { +// CHECK: vector.transfer_write %[[DIV]], {{.*}} : vector<1xf16>, memref<128xf16> +// CHECK: } +// CHECK: return + +// ----- + +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> +#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [ + #hal.descriptor_set.layout<0, bindings = [ + #hal.descriptor_set.binding<0, storage_buffer>, + #hal.descriptor_set.binding<1, storage_buffer> + ]> +]> +hal.executable private @simple_half_reduce_small { + hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb { + hal.executable.export @simple_half_reduce_small layout(#pipeline_layout) attributes { + workgroup_size = [32 : index, 1 : index, 1 : index] + } + builtin.module { + func.func @simple_half_reduce_small() { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : vector<1xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant dense<3.840000e+02> : vector<1xf16> + %c32 = arith.constant 32 : index + %c384 = arith.constant 384 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x32xf16> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xf16> + 
%workgroup_id_x = hal.interface.workgroup.id[0] : index + %2 = gpu.thread_id x + %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x] + %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x32xf16>, vector<32xf16> + %5 = vector.broadcast %4 : vector<32xf16> to vector<1x32xf16> + %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x32xf16> to vector<1xf16> + %7 = arith.divf %6, %cst_1 : vector<1xf16> + vector.transfer_write %7, %1[%3] {in_bounds = [true]} : vector<1xf16>, memref<128xf16> + return + } + } + } +} + +// CHECK-LABEL: func.func @simple_half_reduce_small() { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0.000000e+00> : vector<2xf16> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32 +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32 +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32 +// CHECK-DAG: %[[C1F:.*]] = arith.constant 0.000000e+00 : f16 +// CHECK-DAG: %[[TID:.*]] = gpu.thread_id x +// CHECK: %[[ID:.*]] = affine.apply +// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[ID]], %{{.*}}], %{{.*}} {in_bounds = [true]} : memref<128x32xf16>, vector<1xf16> +// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[IDENTITY]] {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16> +// CHECK: %[[CAST0:.*]] = vector.bitcast %[[V2]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32> +// CHECK: %[[S0:.*]], %{{.*}} = gpu.shuffle xor %[[PACK0]], %[[C1]], %[[C32]] : i32 +// CHECK: %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S1:.*]] = arith.addf %[[V2]], %[[UNPACK0]] : vector<2xf16> +// CHECK: %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32> +// CHECK: %[[S2:.*]], %{{.*}} = gpu.shuffle xor %[[PACK1]], %[[C2]], %[[C32]] : i32 +// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S3:.*]] = arith.addf %[[S1]], %[[UNPACK1]] : vector<2xf16> +// CHECK: %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32> +// CHECK: %[[S4:.*]], %{{.*}} = gpu.shuffle xor %[[PACK2]], %[[C4]], %[[C32]] : i32 +// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S5:.*]] = arith.addf %[[S3]], %[[UNPACK2]] : vector<2xf16> +// CHECK: %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<2xf16> to vector<1xi32> +// CHECK: %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32> +// CHECK: %[[S6:.*]], %{{.*}} = gpu.shuffle xor %[[PACK3]], %[[C8]], %[[C32]] : i32 +// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S7:.*]] = arith.addf %[[S5]], %[[UNPACK3]] : vector<2xf16> +// CHECK: %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<2xf16> to vector<1xi32> +// 
CHECK: %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32> +// CHECK: %[[S8:.*]], %{{.*}} = gpu.shuffle xor %[[PACK4]], %[[C16]], %[[C32]] : i32 +// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<2xf16> +// CHECK: %[[S9:.*]] = arith.addf %[[S7]], %[[UNPACK4]] : vector<2xf16> +// CHECK: %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<2xf16> into f16 +// CHECK: %[[S11:.*]] = arith.addf %[[S10]], %[[C1F]] : f16 +// CHECK: %[[B:.*]] = vector.broadcast %[[S11]] : f16 to vector<1xf16> +// CHECK: %[[DIV:.*]] = arith.divf %[[B]], %{{.*}} : vector<1xf16> +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index +// CHECK: scf.if %[[CMP]] { +// CHECK: vector.transfer_write %[[DIV]], {{.*}} : vector<1xf16>, memref<128xf16> +// CHECK: } +// CHECK: return + +// ----- + +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> +#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [ + #hal.descriptor_set.layout<0, bindings = [ + #hal.descriptor_set.binding<0, storage_buffer>, + #hal.descriptor_set.binding<1, storage_buffer> + ]> +]> +hal.executable private @simple_quarter_reduce { + hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb { + hal.executable.export @simple_quarter_reduce layout(#pipeline_layout) attributes { + workgroup_size = [32 : index, 1 : index, 1 : index] + } + builtin.module { + func.func @simple_quarter_reduce() { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0> : vector<1xi8> + %cst_0 = arith.constant 0 : i8 + %c32 = arith.constant 32 : index + %c384 = arith.constant 384 : index + %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384xi8> + %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xi8> + %workgroup_id_x = hal.interface.workgroup.id[0] : index + %2 = gpu.thread_id x + %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x] + %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x384xi8>, vector<32xi8> + %5 = vector.broadcast %4 : vector<32xi8> to vector<1x32xi8> + %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x32xi8> to vector<1xi8> + vector.transfer_write %6, %1[%3] {in_bounds = [true]} : vector<1xi8>, memref<128xi8> + return + } + } + } +} + +// CHECK-LABEL: func.func @simple_quarter_reduce() { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0> : vector<4xi8> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32 +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : i32 +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : i32 +// CHECK-DAG: %[[TID:.*]] = gpu.thread_id x +// CHECK: %[[ID:.*]] = affine.apply +// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[ID]], %{{.*}}], %{{.*}} {in_bounds = [true]} : memref<128x384xi8>, vector<1xi8> +// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[IDENTITY]] {offsets = [0], strides = [1]} : vector<1xi8> into vector<4xi8> +// CHECK: %[[CAST0:.*]] = vector.bitcast %[[V2]] : vector<4xi8> to vector<1xi32> +// CHECK: %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32> +// CHECK: %[[S0:.*]], %{{.*}} = gpu.shuffle 
xor %[[PACK0]], %[[C1]], %[[C32]] : i32 +// CHECK: %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<4xi8> +// CHECK: %[[S1:.*]] = arith.addi %[[V2]], %[[UNPACK0]] : vector<4xi8> +// CHECK: %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<4xi8> to vector<1xi32> +// CHECK: %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32> +// CHECK: %[[S2:.*]], %{{.*}} = gpu.shuffle xor %[[PACK1]], %[[C2]], %[[C32]] : i32 +// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<4xi8> +// CHECK: %[[S3:.*]] = arith.addi %[[S1]], %[[UNPACK1]] : vector<4xi8> +// CHECK: %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<4xi8> to vector<1xi32> +// CHECK: %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32> +// CHECK: %[[S4:.*]], %{{.*}} = gpu.shuffle xor %[[PACK2]], %[[C4]], %[[C32]] : i32 +// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<4xi8> +// CHECK: %[[S5:.*]] = arith.addi %[[S3]], %[[UNPACK2]] : vector<4xi8> +// CHECK: %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<4xi8> to vector<1xi32> +// CHECK: %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32> +// CHECK: %[[S6:.*]], %{{.*}} = gpu.shuffle xor %[[PACK3]], %[[C8]], %[[C32]] : i32 +// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<4xi8> +// CHECK: %[[S7:.*]] = arith.addi %[[S5]], %[[UNPACK3]] : vector<4xi8> +// CHECK: %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<4xi8> to vector<1xi32> +// CHECK: %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32> +// CHECK: %[[S8:.*]], %{{.*}} = gpu.shuffle xor %[[PACK4]], %[[C16]], %[[C32]] : i32 +// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32> +// CHECK: %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<4xi8> +// CHECK: %[[S9:.*]] = arith.addi %[[S7]], %[[UNPACK4]] : vector<4xi8> +// CHECK: %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<4xi8> into i8 +// CHECK: %[[B:.*]] = vector.broadcast %[[S10]] : i8 to vector<1xi8> +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index +// CHECK: scf.if %[[CMP]] { +// CHECK: vector.transfer_write %[[B]], {{.*}} : vector<1xi8>, memref<128xi8> +// CHECK: } +// CHECK: return + +// ----- + +#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb"> +#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [ + #hal.descriptor_set.layout<0, bindings = [ + #hal.descriptor_set.binding<0, storage_buffer>, + #hal.descriptor_set.binding<1, storage_buffer> + ]> +]> hal.executable private @simple_reduce_multi_warp { hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb { hal.executable.export @simple_reduce_multi_warp layout(#pipeline_layout) attributes {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index adab5f5..f1e4bd7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -553,8 +553,8 @@
                              .getElementType();
   if (!elementType.isIntOrFloat()) return failure();
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  // Reduction distribution only supports 32-bit types now.
-  if (bitWidth != 32) return failure();
+  // Reduction distribution only supports 8/16/32 bit types now.
+  if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) return failure();
   const unsigned largestLoadSizeInBits = 128;
   unsigned vectorSize = largestLoadSizeInBits / bitWidth;
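Note: with largestLoadSizeInBits = 128 unchanged, each thread now handles 128 / bitWidth elements per load: 4 for f32, 8 for f16, and 16 for i8.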
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index fed8c1b..dab068a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -969,8 +969,8 @@
       op.getOutputs()[0].getType().cast<ShapedType>().getElementType();
   if (!elementType.isIntOrFloat()) return failure();
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  // Reduction distribution only supports 32-bit types now.
-  if (bitWidth != 32) return failure();
+  // Reduction distribution only supports 8/16/32 bit types now.
+  if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) return failure();
   // Let each thread handle `vectorSize` elements.
   unsigned vectorSize = kMaxVectorNumBits / bitWidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
index 0c7e6a1..e48bd68 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
@@ -101,11 +101,11 @@
   }
 }
 
-// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 256], [1, 4], [0, 0, 4]{{\]}}>
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize>
+// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 4096]{{\]}}>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
 // CHECK: hal.executable.export public @subgroup_reduce_f16
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
+// CHECK-SAME: workgroup_size = [512 : index, 1 : index, 1 : index]
 // CHECK: func.func @subgroup_reduce_f16()
 // CHECK: linalg.generic
 // CHECK-SAME: lowering_config = #[[CONFIG]]
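Note: the new values are consistent with the relaxed SPIRV config logic above, assuming this test reduces a 4096-element f16 dimension: kMaxVectorNumBits / 16 = 8 elements per thread, so 4096 / 8 = 512 threads, i.e. workgroup_size = [512, 1, 1] with a single reduction tile of 4096.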
diff --git a/tests/e2e/regression/large_reduction.mlir b/tests/e2e/regression/large_reduction.mlir
index 15e6fe4..cdc24b6 100644
--- a/tests/e2e/regression/large_reduction.mlir
+++ b/tests/e2e/regression/large_reduction.mlir
@@ -49,3 +49,37 @@
   check.expect_almost_eq_const(%result, dense<40.96> : tensor<2xf32>) : tensor<2xf32>
   return
 }
+
+func.func @half_reduction_aligned() {
+  %in = util.unfoldable_constant dense<0.001> : tensor<2x4096xf16>
+  %cst = arith.constant 0.0 : f16
+  %init = tensor.empty() : tensor<2xf16>
+  %fill = linalg.fill ins(%cst : f16) outs(%init : tensor<2xf16>) -> tensor<2xf16>
+  %result = linalg.generic {indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%in : tensor<2x4096xf16>) outs(%fill : tensor<2xf16>) {
+    ^bb0(%arg3: f16, %arg4: f16):  // no predecessors
+      %2 = arith.addf %arg3, %arg4 : f16
+      linalg.yield %2 : f16
+    } -> tensor<2xf16>
+  check.expect_almost_eq_const(%result, dense<4.096> : tensor<2xf16>) : tensor<2xf16>
+  return
+}
+
+func.func @quarter_reduction_aligned_smaller() {
+  %in = util.unfoldable_constant dense<1> : tensor<128x128xi8>
+  %cst = arith.constant 0 : i8
+  %init = tensor.empty() : tensor<128xi8>
+  %fill = linalg.fill ins(%cst : i8) outs(%init : tensor<128xi8>) -> tensor<128xi8>
+  %result = linalg.generic {indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%in : tensor<128x128xi8>) outs(%fill : tensor<128xi8>) {
+    ^bb0(%arg3: i8, %arg4: i8):  // no predecessors
+      %2 = arith.addi %arg3, %arg4 : i8
+      linalg.yield %2 : i8
+    } -> tensor<128xi8>
+  check.expect_eq_const(%result, dense<128> : tensor<128xi8>) : tensor<128xi8>
+  return
+}
\ No newline at end of file
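Note: the expected values follow from the reduction lengths: 4096 * 0.001 ≈ 4.096 for the f16 test (checked with expect_almost_eq_const since f16 accumulation is inexact) and 128 * 1 = 128 for the i8 test.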