[mlir][gpu] Pack and unpack to enable f16 and int8 warp reduce. (#11349)
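
gpu.shuffle is emitted on 32-bit values here (kShuffleBitWidth), so reduction
distribution previously rejected sub-32-bit element types. This change packs
f16/i8 per-lane vectors into a single i32 around each shuffle, unpacks the
shuffled value back into the original vector type before combining, and pads
with the combining identity when a lane holds fewer elements than fit in 32
bits. The 32-bit-only checks in the LLVMGPU and SPIR-V kernel configurations
are relaxed accordingly.

A sketch of one butterfly step for an add reduction of a vector<2xf16> lane
value (%lane, %offset, and %width stand in for the actual SSA values; the
updated lit tests show the exact output):

  %cast        = vector.bitcast %lane : vector<2xf16> to vector<1xi32>
  %packed      = vector.extract %cast[0] : vector<1xi32>
  %res, %valid = gpu.shuffle xor %packed, %offset, %width : i32
  %bcast       = vector.broadcast %res : i32 to vector<1xi32>
  %peer        = vector.bitcast %bcast : vector<1xi32> to vector<2xf16>
  %acc         = arith.addf %lane, %peer : vector<2xf16>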

diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
index 1cf1680..b60b354 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
@@ -20,6 +20,8 @@
 
 #define DEBUG_TYPE "iree-codegen-reduction-distribution"
 
+static constexpr unsigned kShuffleBitWidth = 32;
+
 namespace mlir {
 namespace iree_compiler {
 
@@ -40,28 +42,90 @@
   return builder.create<memref::AllocOp>(loc, memrefType);
 }
 
+/// Packs a scalar element into its vector equivalent.
+/// (e.g. f16 -> vector<1xf16> and f32 -> vector<1xf32>)
+static Value promoteElementToVector(Location loc, OpBuilder &builder,
+                                    Value input) {
+  VectorType vectorTypeBroadcast = VectorType::get({1}, input.getType());
+  Value vectorInput =
+      builder.create<vector::BroadcastOp>(loc, vectorTypeBroadcast, input);
+  return vectorInput;
+}
+
+/// Packs a lower-precision vector into a single 32-bit wide element.
+/// (e.g. vector<2xf16> -> i32 and vector<4xi8> -> i32)
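+/// A sketch of the emitted ops for a vector<2xf16> input (names illustrative):
+///   %cast = vector.bitcast %input : vector<2xf16> to vector<1xi32>
+///   %packed = vector.extract %cast[0] : vector<1xi32>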
+static Value packVectorToSupportedWidth(Location loc, OpBuilder &builder,
+                                        Value input) {
+  LLVM_DEBUG({
+    auto vecType = input.getType().cast<VectorType>();
+    Type elementType = vecType.getElementType();
+    assert(vecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() ==
+               kShuffleBitWidth &&
+           "vecSize * vecBitWidth needs to be packable into 32-bit width.");
+    assert(elementType.isIntOrFloat() &&
+           "Only int and float packing is supported.");
+  });
+  VectorType packed32Type = VectorType::get({1}, builder.getI32Type());
+  Value packedInputVec =
+      builder.create<vector::BitCastOp>(loc, packed32Type, input);
+  Value packedInput = builder.create<vector::ExtractOp>(loc, packedInputVec, 0);
+  return packedInput;
+}
+
+/// Unpacks a single scalar element into a target vector type.
+/// (e.g. i32 -> vector<4xi8> or f32 -> vector<2xf16>)
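+/// A sketch of the emitted ops for unpacking an i32 into vector<2xf16>
+/// (names illustrative):
+///   %vec = vector.broadcast %packedInput : i32 to vector<1xi32>
+///   %unpacked = vector.bitcast %vec : vector<1xi32> to vector<2xf16>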
+static Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput,
+                            VectorType targetVecType) {
+  LLVM_DEBUG({
+    Type packedType = packedInput.getType();
+    assert(packedType.isIntOrFloat() && "Only ints and floats are unpackable.");
+    Type elementType = targetVecType.getElementType();
+    assert(targetVecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() ==
+               packedType.getIntOrFloatBitWidth() &&
+           "packed width needs to be unpackable to vecSize * vecBitWidth.");
+  });
+  Value packedVector = promoteElementToVector(loc, builder, packedInput);
+  Value unpackedVector =
+      builder.create<vector::BitCastOp>(loc, targetVecType, packedVector);
+  return unpackedVector;
+}
+
 /// Emit warp reduction code sequence for a given input.
 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                            vector::CombiningKind kind, uint32_t warpSize,
                            uint32_t numLaneToReduce) {
+  VectorType unpackedType = input.getType().dyn_cast<VectorType>();
   Value laneVal = input;
   assert(llvm::isPowerOf2_32(numLaneToReduce));
   // Parallel reduction using butterfly shuffles.
   for (uint64_t i = 1; i < numLaneToReduce; i <<= 1) {
+    Value shuffleInput = laneVal;
+    if (unpackedType) {
+      shuffleInput = packVectorToSupportedWidth(loc, builder, laneVal);
+    }
     Value shuffled = builder
-                         .create<gpu::ShuffleOp>(loc, laneVal, i,
+                         .create<gpu::ShuffleOp>(loc, shuffleInput, i,
                                                  /*width=*/warpSize,
                                                  /*mode=*/gpu::ShuffleMode::XOR)
                          .getShuffleResult();
+    if (unpackedType) {
+      shuffled = unpackToVector(loc, builder, shuffled, unpackedType);
+    }
     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
   }
   // Broadcast the result to all the lanes.
   if (warpSize != numLaneToReduce) {
+    if (unpackedType) {
+      laneVal = packVectorToSupportedWidth(loc, builder, laneVal);
+    }
     laneVal = builder
                   .create<gpu::ShuffleOp>(loc, laneVal, 0,
                                           /*width=*/warpSize,
                                           /*mode=*/gpu::ShuffleMode::IDX)
                   .getShuffleResult();
+    if (unpackedType) {
+      laneVal = unpackToVector(loc, builder, laneVal, unpackedType);
+    }
   }
   return laneVal;
 }
@@ -105,6 +169,78 @@
   return Attribute();
 }
 
+/// Compute the per-lane reduction value on a single thread.
+/// If the element bit-width is lower than what shuffle operations support,
+/// the result is kept as a vector so that it can later be packed into a
+/// single 32-bit wide element for the shuffles.
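+/// For example, a vector<12xf16> input is reduced into a vector<2xf16>
+/// per-lane value (names illustrative; see warp_reduction.mlir for the exact
+/// output):
+///   %r0 = vector.reduction <add>, %slice0 : vector<6xf16> into f16
+///   %v0 = vector.insert %r0, %zero [0] : f16 into vector<2xf16>
+///   %r1 = vector.reduction <add>, %slice1 : vector<6xf16> into f16
+///   %v1 = vector.insert %r1, %v0 [1] : f16 into vector<2xf16>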
+static Value reduceToSupportedWidth(Location loc, OpBuilder &builder,
+                                    Value input, vector::CombiningKind kind) {
+  auto vecType = input.getType().cast<VectorType>();
+  Type elementType = vecType.getElementType();
+  int64_t vecSize = vecType.getDimSize(0);
+  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+  // Simply reduce if it's already 32 bits.
+  if (bitWidth == kShuffleBitWidth) {
+    return builder.create<vector::ReductionOp>(loc, kind, input);
+  }
+  assert(kShuffleBitWidth % bitWidth == 0 &&
+         "Bit-width must evenly divide the shuffle bit-width.");
+  int64_t unrollCount = kShuffleBitWidth / bitWidth;
+  // The original size needs to be divisible by, or smaller than, the unroll
+  // count to determine the slice size.
+  assert(vecSize % unrollCount == 0 || vecSize < unrollCount);
+  unsigned sliceSize = vecSize / unrollCount;
+  VectorType unrolledLaneValType = VectorType::get({unrollCount}, elementType);
+  Value perLaneReduction = builder.create<arith::ConstantOp>(
+      loc, builder.getZeroAttr(unrolledLaneValType));
+  if (vecSize % unrollCount == 0) {
+    // Unroll reductions so they pack into the supported 32-bit width.
+    for (int64_t i = 0; i < unrollCount; i++) {
+      Value laneValSlice = builder.create<vector::ExtractStridedSliceOp>(
+          loc, input,
+          /*offsets=*/ArrayRef<int64_t>{sliceSize * i},
+          /*sizes=*/ArrayRef<int64_t>{sliceSize},
+          /*strides=*/ArrayRef<int64_t>{1});
+      Value reductionSlice =
+          builder.create<vector::ReductionOp>(loc, kind, laneValSlice);
+      SmallVector<int64_t> perLaneUnrollId = {i};
+      perLaneReduction = builder.create<vector::InsertOp>(
+          loc, reductionSlice, perLaneReduction, perLaneUnrollId);
+    }
+  } else {
+    // In cases where vecSize < unrollCount, pad the vector with identity
+    // elements until its total bit size is 32.
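+    // e.g. a vector<1xf16> input becomes (sketch):
+    //   %padded = vector.insert_strided_slice %input, %identity
+    //       {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>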
+    Attribute identityAttr =
+        getCombiningKindIdentity(builder, kind, elementType);
+    identityAttr = DenseElementsAttr::get(unrolledLaneValType, identityAttr);
+    Value identity = builder.create<arith::ConstantOp>(loc, identityAttr,
+                                                       unrolledLaneValType);
+    perLaneReduction = builder.create<vector::InsertStridedSliceOp>(
+        loc, input, identity, /*offsets=*/ArrayRef<int64_t>{0},
+        /*strides=*/ArrayRef<int64_t>{1});
+  }
+  return perLaneReduction;
+}
+
+/// Emit a constant holding the combining identity value for the given scalar
+/// or vector type.
+static Value getCombiningIdentityValue(Location loc, OpBuilder &builder,
+                                       vector::CombiningKind kind,
+                                       Type identityType) {
+  auto vectorType = identityType.dyn_cast<VectorType>();
+  Type elementType = identityType;
+  if (vectorType) {
+    elementType = vectorType.getElementType();
+  }
+  Attribute identityAttr = getCombiningKindIdentity(builder, kind, elementType);
+  if (vectorType) {
+    identityAttr = DenseElementsAttr::get(vectorType, identityAttr);
+  }
+  assert(identityAttr && "Unknown identity value for the reduction");
+  Value identity =
+      builder.create<arith::ConstantOp>(loc, identityAttr, identityType);
+  return identity;
+}
+
 /// Emit reduction across a group for a given input.
 static Value groupReduction(Location loc, OpBuilder &builder, Value input,
                             vector::CombiningKind kind, uint32_t size,
@@ -113,7 +249,7 @@
       size % warpSize == 0 &&
       "Group reduction only support for sizes aligned on warp size for now.");
   // First reduce on a single thread to get per lane reduction value.
-  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+  Value laneVal = reduceToSupportedWidth(loc, builder, input, kind);
   laneVal = warpReduction(loc, builder, laneVal, kind, warpSize, warpSize);
   // if we have more than one warp, reduce across warps.
   if (size > warpSize) {
@@ -148,15 +284,17 @@
       Value useIdentityElement = builder.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sge, laneId, cstNumWarp);
       numWarp = llvm::PowerOf2Ceil(numWarp);
-      Attribute identityAttr =
-          getCombiningKindIdentity(builder, kind, loadVal.getType());
-      assert(identityAttr && "Unknown identity value for the reduction");
-      Value identity = builder.create<arith::ConstantOp>(loc, identityAttr);
+      Value identity =
+          getCombiningIdentityValue(loc, builder, kind, loadVal.getType());
       loadVal = builder.create<arith::SelectOp>(loc, useIdentityElement,
                                                 identity, loadVal);
     }
     laneVal = warpReduction(loc, builder, loadVal, kind, warpSize, numWarp);
   }
+  // Handle sub-32-bit cases where the per-lane value is still in vector form.
+  if (laneVal.getType().isa<VectorType>()) {
+    laneVal = builder.create<vector::ReductionOp>(loc, kind, laneVal);
+  }
   return laneVal;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir b/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir
index 79f0f17..3bb8176 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/warp_reduction.mlir
@@ -85,6 +85,283 @@
     #hal.descriptor_set.binding<1, storage_buffer>
   ]>
 ]>
+hal.executable private @simple_half_reduce  {
+  hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+    hal.executable.export @simple_half_reduce layout(#pipeline_layout) attributes {
+      workgroup_size = [32 : index, 1 : index, 1 : index]
+    }
+    builtin.module {
+    func.func @simple_half_reduce() {
+      %c0 = arith.constant 0 : index
+      %cst = arith.constant dense<0.000000e+00> : vector<1xf16>
+      %cst_0 = arith.constant 0.000000e+00 : f16
+      %cst_1 = arith.constant dense<3.840000e+02> : vector<1xf16>
+      %c32 = arith.constant 32 : index
+      %c384 = arith.constant 384 : index
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384xf16>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xf16>
+      %workgroup_id_x = hal.interface.workgroup.id[0] : index
+      %2 = gpu.thread_id  x
+      %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x]
+      %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x384xf16>, vector<384xf16>
+      %5 = vector.broadcast %4 : vector<384xf16> to vector<1x384xf16>
+      %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x384xf16> to vector<1xf16>
+      %7 = arith.divf %6, %cst_1 : vector<1xf16>
+      vector.transfer_write %7, %1[%3] {in_bounds = [true]} : vector<1xf16>, memref<128xf16>
+      return
+    }
+    }
+  }
+}
+
+// CHECK-LABEL: func.func @simple_half_reduce() {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[IDENTITY:.*]] = arith.constant dense<0.000000e+00> : vector<2xf16>
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : i32
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : i32
+//   CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : i32
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : i32
+//   CHECK-DAG:   %[[C32:.*]] = arith.constant 32 : i32
+//   CHECK-DAG:   %[[C0F:.*]] = arith.constant 0.000000e+00 : f16
+//   CHECK-DAG:   %[[TID:.*]] = gpu.thread_id  x
+//       CHECK:   %[[READ_X:.*]] = affine.apply #{{.*}}()[%[[TID]], %{{.*}}]
+//       CHECK:   %[[READ_Y:.*]] = affine.apply #{{.*}}()[%[[TID]]]
+//       CHECK:   %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[READ_X]], %[[READ_Y]]], %[[C0F]] {in_bounds = [true]} : memref<128x384xf16>, vector<12xf16>
+//       CHECK:   %[[V2:.*]] = vector.extract_strided_slice %[[V1]] {offsets = [0], sizes = [6], strides = [1]} : vector<12xf16> to vector<6xf16>
+//       CHECK:   %[[V3:.*]] = vector.reduction <add>, %[[V2]] : vector<6xf16> into f16
+//       CHECK:   %[[V4:.*]] = vector.insert %[[V3]], %[[IDENTITY]] [0] : f16 into vector<2xf16>
+//       CHECK:   %[[V5:.*]] = vector.extract_strided_slice %[[V1]] {offsets = [6], sizes = [6], strides = [1]} : vector<12xf16> to vector<6xf16>
+//       CHECK:   %[[V6:.*]] = vector.reduction <add>, %[[V5]] : vector<6xf16> into f16
+//       CHECK:   %[[V7:.*]] = vector.insert %[[V6]], %[[V4]] [1] : f16 into vector<2xf16>
+//       CHECK:   %[[CAST0:.*]] = vector.bitcast %[[V7]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32>
+//       CHECK:   %[[S0:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK0]], %[[C1]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S1:.*]] = arith.addf %[[V7]], %[[UNPACK0]] : vector<2xf16>
+//       CHECK:   %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32>
+//       CHECK:   %[[S2:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK1]], %[[C2]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S3:.*]] = arith.addf %[[S1]], %[[UNPACK1]] : vector<2xf16>
+//       CHECK:   %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32>
+//       CHECK:   %[[S4:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK2]], %[[C4]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S5:.*]] = arith.addf %[[S3]], %[[UNPACK2]] : vector<2xf16>
+//       CHECK:   %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32>
+//       CHECK:   %[[S6:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK3]], %[[C8]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S7:.*]] = arith.addf %[[S5]], %[[UNPACK3]] : vector<2xf16>
+//       CHECK:   %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32>
+//       CHECK:   %[[S8:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK4]], %[[C16]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S9:.*]] = arith.addf %[[S7]], %[[UNPACK4]] : vector<2xf16>
+//       CHECK:   %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<2xf16> into f16
+//       CHECK:   %[[S11:.*]] = arith.addf %[[S10]], %[[C0F]] : f16
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[S11]] : f16 to vector<1xf16>
+//       CHECK:   %[[DIV:.*]] = arith.divf %[[B]], %{{.*}} : vector<1xf16>
+//       CHECK:   %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index
+//       CHECK:   scf.if %[[CMP]] {
+//       CHECK:     vector.transfer_write %[[DIV]], {{.*}} : vector<1xf16>, memref<128xf16>
+//       CHECK:   }
+//       CHECK:   return
+
+// -----
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable private @simple_half_reduce_small  {
+  hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+    hal.executable.export @simple_half_reduce_small layout(#pipeline_layout) attributes {
+      workgroup_size = [32 : index, 1 : index, 1 : index]
+    }
+    builtin.module {
+    func.func @simple_half_reduce_small() {
+      %c0 = arith.constant 0 : index
+      %cst = arith.constant dense<0.000000e+00> : vector<1xf16>
+      %cst_0 = arith.constant 0.000000e+00 : f16
+      %cst_1 = arith.constant dense<3.840000e+02> : vector<1xf16>
+      %c32 = arith.constant 32 : index
+      %c384 = arith.constant 384 : index
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x32xf16>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xf16>
+      %workgroup_id_x = hal.interface.workgroup.id[0] : index
+      %2 = gpu.thread_id  x
+      %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x]
+      %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x32xf16>, vector<32xf16>
+      %5 = vector.broadcast %4 : vector<32xf16> to vector<1x32xf16>
+      %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x32xf16> to vector<1xf16>
+      %7 = arith.divf %6, %cst_1 : vector<1xf16>
+      vector.transfer_write %7, %1[%3] {in_bounds = [true]} : vector<1xf16>, memref<128xf16>
+      return
+    }
+    }
+  }
+}
+
+// CHECK-LABEL: func.func @simple_half_reduce_small() {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[IDENTITY:.*]] = arith.constant dense<0.000000e+00> : vector<2xf16>
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : i32
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : i32
+//   CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : i32
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : i32
+//   CHECK-DAG:   %[[C32:.*]] = arith.constant 32 : i32
+//   CHECK-DAG:   %[[C1F:.*]] = arith.constant 0.000000e+00 : f16
+//   CHECK-DAG:   %[[TID:.*]] = gpu.thread_id  x
+//       CHECK:   %[[ID:.*]] = affine.apply
+//       CHECK:   %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[ID]], %{{.*}}], %{{.*}} {in_bounds = [true]} : memref<128x32xf16>, vector<1xf16>
+//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[IDENTITY]] {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
+//       CHECK:   %[[CAST0:.*]] = vector.bitcast %[[V2]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32>
+//       CHECK:   %[[S0:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK0]], %[[C1]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S1:.*]] = arith.addf %[[V2]], %[[UNPACK0]] : vector<2xf16>
+//       CHECK:   %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32>
+//       CHECK:   %[[S2:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK1]], %[[C2]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S3:.*]] = arith.addf %[[S1]], %[[UNPACK1]] : vector<2xf16>
+//       CHECK:   %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32>
+//       CHECK:   %[[S4:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK2]], %[[C4]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S5:.*]] = arith.addf %[[S3]], %[[UNPACK2]] : vector<2xf16>
+//       CHECK:   %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32>
+//       CHECK:   %[[S6:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK3]], %[[C8]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S7:.*]] = arith.addf %[[S5]], %[[UNPACK3]] : vector<2xf16>
+//       CHECK:   %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<2xf16> to vector<1xi32>
+//       CHECK:   %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32>
+//       CHECK:   %[[S8:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK4]], %[[C16]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<2xf16>
+//       CHECK:   %[[S9:.*]] = arith.addf %[[S7]], %[[UNPACK4]] : vector<2xf16>
+//       CHECK:   %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<2xf16> into f16
+//       CHECK:   %[[S11:.*]] = arith.addf %[[S10]], %[[C1F]] : f16
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[S11]] : f16 to vector<1xf16>
+//       CHECK:   %[[DIV:.*]] = arith.divf %[[B]], %{{.*}} : vector<1xf16>
+//       CHECK:   %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index
+//       CHECK:   scf.if %[[CMP]] {
+//       CHECK:     vector.transfer_write %[[DIV]], {{.*}} : vector<1xf16>, memref<128xf16>
+//       CHECK:   }
+//       CHECK:   return
+
+// -----
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable private @simple_quarter_reduce  {
+  hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
+    hal.executable.export @simple_quarter_reduce layout(#pipeline_layout) attributes {
+      workgroup_size = [32 : index, 1 : index, 1 : index]
+    }
+    builtin.module {
+    func.func @simple_quarter_reduce() {
+      %c0 = arith.constant 0 : index
+      %cst = arith.constant dense<0> : vector<1xi8>
+      %cst_0 = arith.constant 0 : i8
+      %c32 = arith.constant 32 : index
+      %c384 = arith.constant 384 : index
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : memref<128x384xi8>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<128xi8>
+      %workgroup_id_x = hal.interface.workgroup.id[0] : index
+      %2 = gpu.thread_id  x
+      %3 = affine.apply affine_map<()[s0, s1] -> (s1 * 2 + s0 floordiv 32)>()[%2, %workgroup_id_x]
+      %4 = vector.transfer_read %0[%3, %c0], %cst_0 {in_bounds = [true]} : memref<128x384xi8>, vector<32xi8>
+      %5 = vector.broadcast %4 : vector<32xi8> to vector<1x32xi8>
+      %6 = vector.multi_reduction <add>, %5, %cst [1] : vector<1x32xi8> to vector<1xi8>
+      vector.transfer_write %6, %1[%3] {in_bounds = [true]} : vector<1xi8>, memref<128xi8>
+      return
+    }
+    }
+  }
+}
+
+// CHECK-LABEL: func.func @simple_quarter_reduce() {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[IDENTITY:.*]] = arith.constant dense<0> : vector<4xi8>
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : i32
+//   CHECK-DAG:   %[[C4:.*]] = arith.constant 4 : i32
+//   CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : i32
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : i32
+//   CHECK-DAG:   %[[C32:.*]] = arith.constant 32 : i32
+//   CHECK-DAG:   %[[TID:.*]] = gpu.thread_id  x
+//       CHECK:   %[[ID:.*]] = affine.apply
+//       CHECK:   %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[ID]], %{{.*}}], %{{.*}} {in_bounds = [true]} : memref<128x384xi8>, vector<1xi8>
+//       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[IDENTITY]] {offsets = [0], strides = [1]} : vector<1xi8> into vector<4xi8>
+//       CHECK:   %[[CAST0:.*]] = vector.bitcast %[[V2]] : vector<4xi8> to vector<1xi32>
+//       CHECK:   %[[PACK0:.*]] = vector.extract %[[CAST0]][0] : vector<1xi32>
+//       CHECK:   %[[S0:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK0]], %[[C1]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST0:.*]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK0:.*]] = vector.bitcast %[[BROADCAST0]] : vector<1xi32> to vector<4xi8>
+//       CHECK:   %[[S1:.*]] = arith.addi %[[V2]], %[[UNPACK0]] : vector<4xi8>
+//       CHECK:   %[[CAST1:.*]] = vector.bitcast %[[S1]] : vector<4xi8> to vector<1xi32>
+//       CHECK:   %[[PACK1:.*]] = vector.extract %[[CAST1]][0] : vector<1xi32>
+//       CHECK:   %[[S2:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK1]], %[[C2]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST1:.*]] = vector.broadcast %[[S2]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK1:.*]] = vector.bitcast %[[BROADCAST1]] : vector<1xi32> to vector<4xi8>
+//       CHECK:   %[[S3:.*]] = arith.addi %[[S1]], %[[UNPACK1]] : vector<4xi8>
+//       CHECK:   %[[CAST2:.*]] = vector.bitcast %[[S3]] : vector<4xi8> to vector<1xi32>
+//       CHECK:   %[[PACK2:.*]] = vector.extract %[[CAST2]][0] : vector<1xi32>
+//       CHECK:   %[[S4:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK2]], %[[C4]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST2:.*]] = vector.broadcast %[[S4]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK2:.*]] = vector.bitcast %[[BROADCAST2]] : vector<1xi32> to vector<4xi8>
+//       CHECK:   %[[S5:.*]] = arith.addi %[[S3]], %[[UNPACK2]] : vector<4xi8>
+//       CHECK:   %[[CAST3:.*]] = vector.bitcast %[[S5]] : vector<4xi8> to vector<1xi32>
+//       CHECK:   %[[PACK3:.*]] = vector.extract %[[CAST3]][0] : vector<1xi32>
+//       CHECK:   %[[S6:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK3]], %[[C8]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST3:.*]] = vector.broadcast %[[S6]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK3:.*]] = vector.bitcast %[[BROADCAST3]] : vector<1xi32> to vector<4xi8>
+//       CHECK:   %[[S7:.*]] = arith.addi %[[S5]], %[[UNPACK3]] : vector<4xi8>
+//       CHECK:   %[[CAST4:.*]] = vector.bitcast %[[S7]] : vector<4xi8> to vector<1xi32>
+//       CHECK:   %[[PACK4:.*]] = vector.extract %[[CAST4]][0] : vector<1xi32>
+//       CHECK:   %[[S8:.*]], %{{.*}} = gpu.shuffle  xor %[[PACK4]], %[[C16]], %[[C32]] : i32
+//       CHECK:   %[[BROADCAST4:.*]] = vector.broadcast %[[S8]] : i32 to vector<1xi32>
+//       CHECK:   %[[UNPACK4:.*]] = vector.bitcast %[[BROADCAST4]] : vector<1xi32> to vector<4xi8>
+//       CHECK:   %[[S9:.*]] = arith.addi %[[S7]], %[[UNPACK4]] : vector<4xi8>
+//       CHECK:   %[[S10:.*]] = vector.reduction <add>, %[[S9]] : vector<4xi8> into i8
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[S10]] : i8 to vector<1xi8>
+//       CHECK:   %[[CMP:.*]] = arith.cmpi eq, %[[TID]], %[[C0]] : index
+//       CHECK:   scf.if %[[CMP]] {
+//       CHECK:     vector.transfer_write %[[B]], {{.*}} : vector<1xi8>, memref<128xi8>
+//       CHECK:   }
+//       CHECK:   return
+
+// -----
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
 hal.executable private @simple_reduce_multi_warp  {
   hal.executable.variant @cuda, target = #executable_target_cuda_nvptx_fb {
     hal.executable.export @simple_reduce_multi_warp layout(#pipeline_layout) attributes {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index adab5f5..f1e4bd7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -553,8 +553,8 @@
                                .getElementType();
   if (!elementType.isIntOrFloat()) return failure();
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  // Reduction distribution only supports 32-bit types now.
-  if (bitWidth != 32) return failure();
+  // Reduction distribution only supports 8/16/32 bit types now.
+  if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) return failure();
 
   const unsigned largestLoadSizeInBits = 128;
   unsigned vectorSize = largestLoadSizeInBits / bitWidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index fed8c1b..dab068a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -969,8 +969,8 @@
       op.getOutputs()[0].getType().cast<ShapedType>().getElementType();
   if (!elementType.isIntOrFloat()) return failure();
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  // Reduction distribution only supports 32-bit types now.
-  if (bitWidth != 32) return failure();
+  // Reduction distribution only supports 8/16/32 bit types now.
+  if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8) return failure();
 
   // Let each thread handle `vectorSize` elements.
   unsigned vectorSize = kMaxVectorNumBits / bitWidth;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
index 0c7e6a1..e48bd68 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_reduction.mlir
@@ -101,11 +101,11 @@
   }
 }
 
-//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 256], [1, 4], [0, 0, 4]{{\]}}>
-//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVBaseVectorize>
+//  CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 4096]{{\]}}>
+//  CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVSubgroupReduce>
 //      CHECK: hal.executable.export public @subgroup_reduce_f16
 // CHECK-SAME:   translation_info = #[[TRANSLATION]]
-// CHECK-SAME:   workgroup_size = [64 : index, 1 : index, 1 : index]
+// CHECK-SAME:   workgroup_size = [512 : index, 1 : index, 1 : index]
 //      CHECK: func.func @subgroup_reduce_f16()
 //      CHECK:   linalg.generic
 // CHECK-SAME:     lowering_config = #[[CONFIG]]
diff --git a/tests/e2e/regression/large_reduction.mlir b/tests/e2e/regression/large_reduction.mlir
index 15e6fe4..cdc24b6 100644
--- a/tests/e2e/regression/large_reduction.mlir
+++ b/tests/e2e/regression/large_reduction.mlir
@@ -49,3 +49,37 @@
   check.expect_almost_eq_const(%result, dense<40.96> : tensor<2xf32>) : tensor<2xf32>
   return
 }
+
+func.func @half_reduction_aligned() {
+  %in = util.unfoldable_constant dense<0.001> : tensor<2x4096xf16>
+  %cst = arith.constant 0.0 : f16
+  %init = tensor.empty() : tensor<2xf16>
+  %fill = linalg.fill ins(%cst : f16) outs(%init : tensor<2xf16>) -> tensor<2xf16>
+  %result = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%in : tensor<2x4096xf16>) outs(%fill : tensor<2xf16>) {
+    ^bb0(%arg3: f16, %arg4: f16):  // no predecessors
+      %2 = arith.addf %arg3, %arg4 : f16
+      linalg.yield %2 : f16
+    } -> tensor<2xf16>
+  check.expect_almost_eq_const(%result, dense<4.096> : tensor<2xf16>) : tensor<2xf16>
+  return
+}
+
+func.func @quarter_reduction_aligned_smaller() {
+  %in = util.unfoldable_constant dense<1> : tensor<128x128xi8>
+  %cst = arith.constant 0 : i8
+  %init = tensor.empty() : tensor<128xi8>
+  %fill = linalg.fill ins(%cst : i8) outs(%init : tensor<128xi8>) -> tensor<128xi8>
+  %result = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%in : tensor<128x128xi8>) outs(%fill : tensor<128xi8>) {
+    ^bb0(%arg3: i8, %arg4: i8):  // no predecessors
+      %2 = arith.addi %arg3, %arg4 : i8
+      linalg.yield %2 : i8
+    } -> tensor<128xi8>
+  check.expect_eq_const(%result, dense<128> : tensor<128xi8>) : tensor<128xi8>
+  return
+}
\ No newline at end of file