[LLVMGPU] Combine parallel and reduction padding in LLVMGPUPadAndVectorDistribute (#18771)
Since https://github.com/iree-org/iree/pull/18748, tensor.pad ops can be fused in during tiling. This patch therefore combines the separate parallel-dim and reduction-dim runs of LLVMGPUPromoteMatmulToFitMMA into a single pass invocation that pads every loop dimension at once, to a multiple of the larger of its workgroup and reduction tile sizes; the resulting pads are later fused away during tiling. Along the way the target-dimensions pass option, the nofold flags, and the ad-hoc pad/extract_slice rewrite are dropped, and reduction-level tiling now runs with allowZeroSlices enabled.
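For reference, a minimal standalone sketch (plain C++, not the pass code itself) of the combined pad-multiple computation: each loop dimension is padded to a multiple of the maximum of its workgroup and reduction tile sizes, so parallel and reduction dims are handled in one shot. The concrete tile values below are assumptions chosen to be consistent with the batch_matmul_f16 test expectations further down.

// pad_multiples_sketch.cpp -- standalone illustration only; it does not use
// the MLIR/IREE APIs. Tile sizes are assumed, not taken from a real config.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Each loop dimension is padded to a multiple of the larger of its workgroup
// and reduction tile sizes; 0 means "untiled at this level" and falls back to 1.
std::vector<int64_t> combinePadMultiples(const std::vector<int64_t> &wgTiles,
                                         const std::vector<int64_t> &redTiles) {
  assert(wgTiles.size() == redTiles.size() && "rank mismatch");
  std::vector<int64_t> multiples(wgTiles.size(), 1);
  for (std::size_t i = 0; i < multiples.size(); ++i)
    multiples[i] = std::max({wgTiles[i], redTiles[i], multiples[i]});
  return multiples;
}

// Round a static size up to the given multiple, as the inserted tensor.pad
// effectively does for each padded dimension.
int64_t roundUp(int64_t size, int64_t multiple) {
  return ((size + multiple - 1) / multiple) * multiple;
}

int main() {
  // Assumed loop order (batch, M, N, K) with assumed tile sizes.
  std::vector<int64_t> wgTiles = {1, 64, 128, 0}; // workgroup level
  std::vector<int64_t> redTiles = {0, 0, 0, 16};  // reduction level
  std::vector<int64_t> multiples = combinePadMultiples(wgTiles, redTiles);
  // -> {1, 64, 128, 16}: parallel and reduction dims are padded together,
  // e.g. a K extent of 1281 becomes roundUp(1281, 16) == 1296.
  for (std::size_t i = 0; i < multiples.size(); ++i)
    std::cout << "loop " << i << " pads to a multiple of " << multiples[i] << "\n";
  std::cout << "K = 1281 pads to " << roundUp(1281, multiples[3]) << "\n";
  return 0;
}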
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
index dbcc5b1..2421494 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUPromoteMatmulToFitMMA.cpp
@@ -27,25 +27,18 @@
public:
using impl::LLVMGPUPromoteMatmulToFitMMAPassBase<
LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase;
- explicit LLVMGPUPromoteMatmulToFitMMAPass(
- const LLVMGPUMatmulPadOption &option) {
- this->targetDimensions.setValue(option);
- }
  void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
}
void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
- ArrayRef<int64_t> paddingDims,
- ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
- assert(paddingDims.size() == padToMultipleOf.size() &&
- "invalid pad multiples for padding dimensions");
-
+ ArrayRef<int64_t> padToMultipleOf) const {
LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);
- SmallVector<bool> nofoldFlags(op.getNumDpsInputs(), noFold);
+ SmallVector<int64_t> paddingDims =
+ llvm::to_vector(llvm::seq<int64_t>(padToMultipleOf.size()));
SmallVector<Attribute> paddingValueAttributes;
for (auto &operand : op->getOpOperands()) {
@@ -58,7 +51,6 @@
.setPaddingDimensions(paddingDims)
.setPaddingValues(paddingValueAttributes)
.setPadToMultipleOf(padToMultipleOf)
- .setNofoldFlags(nofoldFlags)
.setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);
FailureOr<linalg::LinalgOp> result =
@@ -72,26 +64,6 @@
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();
- // Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
- // we can kick canonicalization patterns to fold outer tensor.pad ops away.
- bool noFold = false;
- utils::IteratorType targetIterType = utils::IteratorType::parallel;
- switch (targetDimensions) {
- case LLVMGPUMatmulPadOption::ParallelDims:
- LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
- targetIterType = utils::IteratorType::parallel;
- noFold = false;
- break;
- case LLVMGPUMatmulPadOption::ReductionDims:
- LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
- targetIterType = utils::IteratorType::reduction;
- noFold = true;
- break;
- default: // Unreachable.
- assert(false);
- break;
- };
-
SmallVector<linalg::LinalgOp> candidates;
funcOp->walk([&](linalg::LinalgOp op) {
if (linalg::isaContractionOpInterface(op)) {
@@ -101,46 +73,27 @@
IRRewriter rewriter(ctx);
for (linalg::LinalgOp op : candidates) {
- SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
getLoweringConfig(op));
- if (config) {
- switch (targetDimensions) {
- case LLVMGPUMatmulPadOption::ParallelDims:
- padMultiples = config.getStaticTilingLevelSizes(
- static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
- break;
- case LLVMGPUMatmulPadOption::ReductionDims:
- padMultiples = config.getStaticTilingLevelSizes(
- static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
- break;
- default:
- assert(false && "Unexpected target dimensions");
- break;
- }
+ if (!config) {
+ continue;
}
- // Populate padding dimensions.
- SmallVector<int64_t> paddingDimensions;
- for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
- if (iter == targetIterType) {
- paddingDimensions.push_back(idx);
- }
- }
+ SmallVector<int64_t> wgTiles = config.getStaticTilingLevelSizes(
+ static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
+ SmallVector<int64_t> redTiles = config.getStaticTilingLevelSizes(
+ static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
- // Populate tile sizes. We pad to multiples of workgroup/reduction
- // tile sizes based on the selected target tiling dimensions.
- // This pass is ran after the select target tiling is done to pad
- // all dimensions to the select tile sizes.
- SmallVector<int64_t> padToMultipleOf;
- for (int64_t dim : paddingDimensions) {
- if (padMultiples[dim] != 0) {
- padToMultipleOf.push_back(padMultiples[dim]);
- }
+ // Populate padding dimensions to maximum of possible tile sizes.
+ SmallVector<int64_t> padToMultipleOf(op.getNumLoops(), 1);
+ for (auto [wgTile, redTile, padMultiple] :
+ llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) {
+ padMultiple = std::max({wgTile, redTile, padMultiple});
}
+ SmallVector<int64_t> paddingDimensions =
+ llvm::to_vector(llvm::seq<int64_t>(op.getNumLoops()));
- padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
- noFold);
+ padWithZeroValue(rewriter, op, padToMultipleOf);
}
{
@@ -156,58 +109,8 @@
return signalPassFailure();
}
}
-
- // XXX(hanchung): This is needed for pad op fusion, which will remove
- // outer pad ops. I.e., it mainly wants to remove first pad op in the
- // pad->extract_slice->pad chain, while the canonicalization pattern can
- // only recognize slice->pad->slice->pad.
- {
- SmallVector<tensor::PadOp> padOps;
- funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
- for (auto op : padOps) {
- auto srcExtractSliceOp =
- op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
- if (!srcExtractSliceOp) {
- continue;
- }
- auto producerPadOp =
- srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
- if (!producerPadOp) {
- continue;
- }
- auto src = producerPadOp.getSource()
- .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
- if (!src) {
- continue;
- }
-
- rewriter.setInsertionPointAfter(src);
- SmallVector<OpFoldResult> sizes =
- tensor::getMixedSizes(rewriter, op.getLoc(), src);
- SmallVector<OpFoldResult> offsets(sizes.size(),
- rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(sizes.size(),
- rewriter.getIndexAttr(1));
- auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- op.getLoc(), src.getResult(), offsets, sizes, strides);
- rewriter.startOpModification(op);
- producerPadOp.getSourceMutable().assign(extractSliceOp.getResult());
- rewriter.finalizeOpModification(op);
- }
-
- RewritePatternSet patterns(ctx);
- tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
- }
- }
}
};
} // namespace
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
- return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
-}
-
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 76b1af3..51fcc6b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -858,25 +858,20 @@
funcPassManager.addPass(createCSEPass());
if (usePadToModelSharedMemcpy) {
- LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims;
- funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
+ funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
}
// Tile to reduction loops.
{
GPUApplyTilingLevelPassOptions options;
options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
+ options.allowZeroSlices = true;
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
funcPassManager.addPass(affine::createLoopCoalescingPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}
- if (usePadToModelSharedMemcpy) {
- LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims;
- funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
- }
-
funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index c118177..d932564 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -103,10 +103,6 @@
// Wrappers that not use tablegen options.
//------------------------------------------------------------------------------
-enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims };
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option);
-
enum class GPUTensorCoreType {
WMMA = 0,
MMA_SYNC = 1,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index ef51a6a..815a82f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -105,19 +105,6 @@
def LLVMGPUPromoteMatmulToFitMMAPass :
InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> {
let summary = "Pass to promote contraction ops to fit mma shapes";
- let options = [
- Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption",
- /*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims",
- "Select the strategy to control how multi_reduction is lowered.",
- [{::llvm::cl::values(
- clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims,
- "parallel",
- "Pad all the parallel dims for contraction ops."),
- clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims,
- "reduction",
- "Pad all the reduction dims for contraction ops.")
- )}]>
- ];
}
def LLVMGPUSelectLoweringStrategyPass :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
index 610e114..d21faf8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx940.mlir
@@ -511,7 +511,7 @@
// CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
// CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
-// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
+// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
// CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
// CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
// CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
@@ -581,9 +581,11 @@
// CHECK-SAME: memref<196x16x24xf32
// CHECK-SAME: vector<1x1x1xf32>
// RHS
+// The dynamic dimension should be removed after:
+// https://github.com/llvm/llvm-project/pull/112236
// CHECK: vector.transfer_read
-// CHECK-SAME: in_bounds = [true, true, false]
-// CHECK-SAME: memref<1x8x24xf32
+// CHECK-SAME: in_bounds = [true, false, false]
+// CHECK-SAME: memref<1x?x24xf32
// CHECK-SAME: vector<1x1x2xf32>
// CHECK: scf.yield
// OUTPUT
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
index bda4836..21bc2fc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/promote_matmul_to_fit_mma.mlir
@@ -1,5 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=parallel}))" %s | FileCheck %s --check-prefixes=ALL,PARALLEL
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma{target-dimensions=reduction}))" %s | FileCheck %s --check-prefixes=ALL,REDUCTION
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-llvmgpu-promote-matmul-to-fit-mma))" %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -34,114 +33,20 @@
flow.dispatch.tensor.store %11, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
return
}
-// ALL-LABEL: func.func @batch_matmul_f16
-// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
-// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
-// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
-// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
-// ALL-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
-// PARALLEL: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
-// PARALLEL: } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16>
-// PARALLEL: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
-// PARALLEL: } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16>
-// PARALLEL: %[[FILL:.+]] = linalg.fill
-// PARALLEL: %[[GEMM:.+]] = linalg.batch_matmul
-// PARALLEL-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
-// PARALLEL-SAME: outs(%[[FILL]]
+// CHECK-LABEL: func.func @batch_matmul_f16
+// CHECK: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
+// CHECK: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
+// CHECK: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
+// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
+// CHECK: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
+// CHECK: } : tensor<1x?x1281xf16> to tensor<1x64x1296xf16>
+// CHECK: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
+// CHECK: } : tensor<1x1281x?xf16> to tensor<1x1296x128xf16>
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[GEMM:.+]] = linalg.batch_matmul
+// CHECK-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
+// CHECK-SAME: outs(%[[FILL]]
-// The reduction dim is not tiled in the test case, so it pads it to the
-// matmul intrinsic k.
-// REDUCTION-DAG: %[[FILL_DEST:.+]] = flow.dispatch.tensor.load %[[OUT_HANDLE]]
-// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[FILL_DEST]]
-// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS]]
-// REDUCTION: } : tensor<1x?x1281xf16> to tensor<1x?x1296xf16>
-// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS]]
-// REDUCTION: } : tensor<1x1281x?xf16> to tensor<1x1296x?xf16>
-// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul
-// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
-// REDUCTION-SAME: outs(%[[FILL]]
-
-// ALL: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]]
-// ALL: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]
-
-// -----
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>
-]>
-#map = affine_map<()[s0] -> (s0 * 64)>
-#map1 = affine_map<()[s0] -> (s0 * 128)>
-#map2 = affine_map<()[s0] -> (s0 * -64 + 968, 64)>
-#map3 = affine_map<()[s0] -> (s0 * -128 + 1281, 128)>
-#map4 = affine_map<()[s0] -> (-s0 + 64)>
-#map5 = affine_map<()[s0] -> (-s0 + 128)>
-#map6 = affine_map<(d0) -> (-d0 + 1281, 64)>
-func.func @batch_matmul_pad_reduction_after_tiling() {
- %c64 = arith.constant 64 : index
- %c1281 = arith.constant 1281 : index
- %c2 = arith.constant 2 : index
- %c1 = arith.constant 1 : index
- %cst = arith.constant 0.000000e+00 : f16
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
- %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
- %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
- %workgroup_id_z = hal.interface.workgroup.id[2] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %3 = affine.apply #map()[%workgroup_id_y]
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %4 = affine.apply #map1()[%workgroup_id_x]
- %5 = affine.min #map2()[%workgroup_id_y]
- %6 = affine.min #map3()[%workgroup_id_x]
- %7 = flow.dispatch.tensor.load %0, offsets = [%workgroup_id_z, %3, 0], sizes = [1, %5, 1281], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>> -> tensor<1x?x1281xf16>
- %dim = tensor.dim %7, %c1 : tensor<1x?x1281xf16>
- %8 = flow.dispatch.tensor.load %1, offsets = [%workgroup_id_z, 0, %4], sizes = [1, 1281, %6], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>> -> tensor<1x1281x?xf16>
- %dim_0 = tensor.dim %8, %c2 : tensor<1x1281x?xf16>
- %9 = affine.apply #map4()[%5]
- %padded = tensor.pad %7 low[0, 0, 0] high[0, %9, 0] {
- ^bb0(%arg0: index, %arg1: index, %arg2: index):
- tensor.yield %cst : f16
- } : tensor<1x?x1281xf16> to tensor<1x64x1281xf16>
- %10 = affine.apply #map5()[%6]
- %padded_2 = tensor.pad %8 low[0, 0, 0] high[0, 0, %10] {
- ^bb0(%arg0: index, %arg1: index, %arg2: index):
- tensor.yield %cst : f16
- } : tensor<1x1281x?xf16> to tensor<1x1281x128xf16>
- %11 = tensor.empty() : tensor<1x64x128xf16>
- %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
- %13 = scf.for %arg0 = %c0 to %c1281 step %c64 iter_args(%arg1 = %12) -> (tensor<1x64x128xf16>) {
- %14 = affine.min #map6(%arg0)
- %extracted_slice_4 = tensor.extract_slice %padded[0, 0, %arg0] [1, 64, %14] [1, 1, 1] : tensor<1x64x1281xf16> to tensor<1x64x?xf16>
- %extracted_slice_5 = tensor.extract_slice %padded_2[0, %arg0, 0] [1, %14, 128] [1, 1, 1] : tensor<1x1281x128xf16> to tensor<1x?x128xf16>
- %15 = linalg.batch_matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<1x64x?xf16>, tensor<1x?x128xf16>) outs(%arg1 : tensor<1x64x128xf16>) -> tensor<1x64x128xf16>
- scf.yield %15 : tensor<1x64x128xf16>
- }
- %extracted_slice_3 = tensor.extract_slice %13[0, 0, 0] [1, %5, %6] [1, 1, 1] : tensor<1x64x128xf16> to tensor<1x?x?xf16>
- flow.dispatch.tensor.store %extracted_slice_3, %2, offsets = [%workgroup_id_z, %3, %4], sizes = [1, %5, %6], strides = [1, 1, 1] : tensor<1x?x?xf16> -> !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
- return
-}
-// The padding on parallel dims is a nop because they are already padded. Skip
-// the check for the testcase.
-// ALL-LABEL: func.func @batch_matmul_pad_reduction_after_tiling
-// ALL: %[[LHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x968x1281xf16>>
-// ALL: %[[RHS_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64x1281x1281xf16>>
-// ALL: %[[OUT_HANDLE:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64x968x1281xf16>>
-// ALL-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_HANDLE]]
-// ALL-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_HANDLE]]
-// REDUCTION: %[[INIT:.+]] = tensor.empty() : tensor<1x64x128xf16>
-// REDUCTION: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[INIT]]
-// REDUCTION: %[[RES:.+]] = scf.for {{.+}} iter_args(%[[ITER:.+]] = %[[FILL]])
-// REDUCTION: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
-// REDUCTION: %[[PADDED_LHS:.+]] = tensor.pad %[[LHS_SLICE]]
-// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x64xf16>
-// REDUCTION: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
-// REDUCTION: %[[PADDED_RHS:.+]] = tensor.pad %[[RHS_SLICE]]
-// REDUCTION: } : tensor<1x?x?xf16> to tensor<1x64x128xf16>
-// REDUCTION: %[[GEMM:.+]] = linalg.batch_matmul
-// REDUCTION-SAME: ins(%[[PADDED_LHS]], %[[PADDED_RHS]]
-// REDUCTION-SAME: outs(%[[ITER]]
-// REDUCTION: scf.yield %[[GEMM]]
-// REDUCTION: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[RES]]
-// REDUCTION: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]
+// CHECK: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[GEMM]]
+// CHECK: flow.dispatch.tensor.store %[[OUT_SLICE]], %[[OUT_HANDLE]]