[NFC][GPU] Move SPIRVTile to `Codegen/Common/GPU` (#17142)
This PR moves the SPIRVTile pass to codegen common so it can be used in
LLVMGPU. The current GPUTensorTile pass needs some revisions, so for the
time being this PR introduces GPUTile, which will give some extra needed
functionality.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 4db9dda..5979b5d 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
#include "iree/compiler/PluginAPI/Client.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index ebbf7b3..eb6149e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -66,6 +66,7 @@
"GPUTensorAlloc.cpp",
"GPUTensorTile.cpp",
"GPUTensorTileToSerialLoops.cpp",
+ "GPUTile.cpp",
"GPUTileReduction.cpp",
"GPUVectorAlloc.cpp",
"GPUVectorDistribution.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index ebb98a1..ab3be99 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -64,6 +64,7 @@
"GPUTensorAlloc.cpp"
"GPUTensorTile.cpp"
"GPUTensorTileToSerialLoops.cpp"
+ "GPUTile.cpp"
"GPUTileReduction.cpp"
"GPUVectorAlloc.cpp"
"GPUVectorDistribution.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp
similarity index 94%
rename from compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
rename to compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp
index a92bdd2..fe5c54a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTile.cpp
@@ -4,17 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//===- SPIRVTile.cpp ------------------------------------------------------===//
+//===- GPUTile.cpp ------------------------------------------------------===//
//
// This pass tiles and Linalg ops with tensor semantics to invocations.
//
//===----------------------------------------------------------------------===//
+#include "iree/compiler/Codegen/Common/GPU/PassDetail.h"
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
-#include "iree/compiler/Codegen/SPIRV/Passes.h"
-#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
@@ -32,7 +31,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "iree-spirv-tile"
+#define DEBUG_TYPE "iree-codegen-gpu-tile"
namespace mlir::iree_compiler {
@@ -134,7 +133,7 @@
// after bufferization. So add attributes to the tiled loop nest to
// indicate that they should be distributed to invocations.
ArrayRef<LoopLikeOpInterface> loops = tileAndFuseResult.value().loops;
- const char *attrName = getSPIRVDistributeAttrName();
+ const char *attrName = getGPUDistributeAttrName();
// We can have more than 3 dimensions being tiled (e.g., for convolutions with
// non-1 batch). But only the innermost 3 dimensions are distributed.
for (auto [dim, loop] : zip(llvm::seq(0, 3), llvm::reverse(loops))) {
@@ -255,10 +254,10 @@
namespace {
-class SPIRVTilePass final : public SPIRVTileBase<SPIRVTilePass> {
+class GPUTilePass final : public GPUTileBase<GPUTilePass> {
public:
- SPIRVTilePass() = default;
- SPIRVTilePass(const SPIRVTilePass &pass) = default;
+ GPUTilePass() = default;
+ GPUTilePass(const GPUTilePass &pass) = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
@@ -292,7 +291,7 @@
concretizePadShape(funcOp);
- auto reductionTileComputeFn = getSPIRVScfTileSizeComputeFn(funcOp, 2);
+ auto reductionTileComputeFn = getGPUScfTileSizeComputeFn(funcOp, 2);
if (failed(reductionTileComputeFn) ||
failed(tileReduction(funcOp, reductionTileComputeFn.value()))) {
return signalPassFailure();
@@ -331,9 +330,8 @@
};
} // namespace
-std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createSPIRVTilePass() {
- return std::make_unique<SPIRVTilePass>();
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> createGPUTilePass() {
+ return std::make_unique<GPUTilePass>();
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
index eabe4b4..7a7325f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
@@ -133,6 +133,9 @@
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUTensorTileToSerialLoops();
+/// Pass to tile Linalg ops with tensor semantics to invocations.
+std::unique_ptr<InterfacePass<FunctionOpInterface>> createGPUTilePass();
+
/// Tile reductions and generate serial loops around reductions.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createGPUTileReductionPass();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 3bffe60..bc681630 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -108,6 +108,11 @@
let constructor = "mlir::iree_compiler::createGPUTensorTileToSerialLoops()";
}
+def GPUTile : InterfacePass<"iree-codegen-gpu-tile", "mlir::FunctionOpInterface"> {
+ let summary = "Tile Linalg ops with tensor semantics to invocations";
+ let constructor = "mlir::iree_compiler::createGPUTilePass()";
+}
+
def GPUTileReduction :
InterfacePass<"iree-codegen-gpu-tile-reduction", "mlir::FunctionOpInterface"> {
let summary = "Pass to tile linalg reduction dimensions.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index ef7467e..9441191 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -31,6 +31,7 @@
"gpu_reorder_workgroups.mlir",
"gpu_tensor_alloc.mlir",
"gpu_tensor_tile.mlir",
+ "gpu_tile.mlir",
"gpu_tile_reduction.mlir",
"gpu_vector_alloc.mlir",
"gpu_vector_distribution.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index 1023626..ff9cbae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -27,6 +27,7 @@
"gpu_reorder_workgroups_static.mlir"
"gpu_tensor_alloc.mlir"
"gpu_tensor_tile.mlir"
+ "gpu_tile.mlir"
"gpu_tile_reduction.mlir"
"gpu_vector_alloc.mlir"
"gpu_vector_distribution.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_linalg_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tile.mlir
similarity index 98%
rename from compiler/src/iree/compiler/Codegen/SPIRV/test/tile_linalg_ops.mlir
rename to compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tile.mlir
index 59d7826..097c1d0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_linalg_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_tile.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file --pass-pipeline="builtin.module(func.func(iree-spirv-tile))" %s | FileCheck %s
+// RUN: iree-opt -split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-tile))" %s | FileCheck %s
func.func @innermost_reduction() {
%c1 = arith.constant 1 : index
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index de5695e..294619d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -16,6 +16,7 @@
#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
index 03001c9..0aeb9cc 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
@@ -74,7 +74,6 @@
":LoweringConfigGen",
":UKernelOpsGen",
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
- "//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
index 05b830b..6fcfef2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
@@ -55,7 +55,6 @@
MLIRTransformDialectTransforms
MLIRViewLikeInterface
iree::compiler::Codegen::Interfaces::UKernelOpInterface
- iree::compiler::Codegen::Utils
iree::compiler::Dialect::HAL::IR
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
index e519c78..8a89f94 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
@@ -10,7 +10,6 @@
#ifndef IREE_COMPILER_CODEGEN_DIALECT_LOWERINGCONFIG_H_
#define IREE_COMPILER_CODEGEN_DIALECT_LOWERINGCONFIG_H_
-#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp
index 7604542..19cdffd 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.cpp
@@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp
index 1947d3d..9d03f0d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.cpp
@@ -7,7 +7,6 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
-#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp
index 1c8bbf4..2c2eae5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ExpandF16OpToF32Pass.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp
index 70831e4..e14807e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPeel.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
index 27f8068..bf263c6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULowerExecutableTarget.cpp
@@ -11,6 +11,7 @@
#include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
#include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp
index ff6538d..5951e75 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
index 4519b20..689393c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp
index c2cba1c..6675547 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLLowerExecutableTarget.cpp
@@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
index 66b44ab..d922d29 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
@@ -69,7 +69,6 @@
"SPIRVMapMemRefStorageClass.cpp",
"SPIRVMaterializeExecutableConditions.cpp",
"SPIRVSelectLoweringStrategy.cpp",
- "SPIRVTile.cpp",
"SPIRVTileAndDistribute.cpp",
"SPIRVTileAndPromote.cpp",
"SPIRVTileAndVectorizeToCooperativeOps.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
index b1040c2..0bb66be 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
@@ -68,7 +68,6 @@
"SPIRVMapMemRefStorageClass.cpp"
"SPIRVMaterializeExecutableConditions.cpp"
"SPIRVSelectLoweringStrategy.cpp"
- "SPIRVTile.cpp"
"SPIRVTileAndDistribute.cpp"
"SPIRVTileAndPromote.cpp"
"SPIRVTileAndVectorizeToCooperativeOps.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 524ec19..08b358d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -22,6 +22,7 @@
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 4c02cb3..384982b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -313,7 +313,7 @@
// Tile to GPU invocations and vectorize.
funcPassManager.addPass(createGPUCreateFastSlowPathPass());
- funcPassManager.addPass(createSPIRVTilePass());
+ funcPassManager.addPass(createGPUTilePass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
{
@@ -347,7 +347,7 @@
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
- funcPassManager.addPass(createSPIRVTilePass());
+ funcPassManager.addPass(createGPUTilePass());
funcPassManager.addPass(
IREE::LinalgExt::createDecomposeWinogradTransformPass());
funcPassManager.addPass(createCanonicalizerPass());
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index 9fd67ac..ba4d1e5 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -151,9 +151,6 @@
createSPIRVTileAndPromotePass(bool promoteCMatrix = false,
bool skipThreadLevel = false);
-/// Pass to tile Linalg ops with tensor semantics to invocations.
-std::unique_ptr<InterfacePass<FunctionOpInterface>> createSPIRVTilePass();
-
/// Pass to tile Linalg ops with buffer semantics suitable for lowering to
/// SPIR-V cooperative ops.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
index 3ac49b7..f8555e2 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.td
@@ -107,11 +107,6 @@
"mlir::iree_compiler::createSPIRVSelectLoweringStrategyPass()";
}
-def SPIRVTile : InterfacePass<"iree-spirv-tile", "mlir::FunctionOpInterface"> {
- let summary = "Tile Linalg ops with tensor semantics to invocations";
- let constructor = "mlir::iree_compiler::createSPIRVTilePass()";
-}
-
def SPIRVTileAndDistribute : InterfacePass<"iree-spirv-tile-and-distribute", "mlir::FunctionOpInterface"> {
let summary = "Tile and distribute Linalg ops with buffer semantics to "
"invocations";
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp
index c79a17e..6e8ec1f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVAnnotateWinogradLoops.cpp
@@ -29,7 +29,7 @@
MLIRContext *context = &getContext();
OpBuilder builder(context);
- const char *attrName = getSPIRVDistributeAttrName();
+ const char *attrName = getGPUDistributeAttrName();
for (auto [index, forOp] : llvm::enumerate(forOps)) {
if (index > kNumGPUDims)
break;
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVDistribute.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVDistribute.cpp
index 39529e8..b337270 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVDistribute.cpp
@@ -6,7 +6,7 @@
//===- SPIRVDistribute.cpp ------------------------------------------------===//
//
-// This pass distributes tiled loop nests with `iree.spirv.distribute_dim`
+// This pass distributes tiled loop nests with `iree.gpu.distribute_dim`
// attributes to invocations.
//
//===----------------------------------------------------------------------===//
@@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -34,7 +35,7 @@
PatternRewriter &rewriter) const override {
// Only distribute if we see the marker attribute.
auto numDimAttr =
- forOp->getAttrOfType<IntegerAttr>(getSPIRVDistributeAttrName());
+ forOp->getAttrOfType<IntegerAttr>(getGPUDistributeAttrName());
if (!numDimAttr)
return failure();
@@ -61,7 +62,7 @@
forOp.getLowerBoundMutable().assign(newLb);
forOp.getStepMutable().assign(newStep);
// Remove the attribute to avoid endless recursion.
- forOp->removeAttr(getSPIRVDistributeAttrName());
+ forOp->removeAttr(getGPUDistributeAttrName());
return success();
}
};
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
index ca33d80..34e7a11 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLowerExecutableTargetPass.cpp
@@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/SPIRV/KernelConfig.h"
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
index e9d2be0..cb72a7c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVSelectLoweringStrategy.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
index 7c1b7e6..102c442 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Verifiers.cpp
@@ -8,6 +8,7 @@
#include "iree/compiler/Codegen/SPIRV/PassDetail.h"
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/Passes.h"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index a9a3b39..75bc3c0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -70,7 +70,6 @@
"tile_and_vectorize_matmul.mlir",
"tile_and_vectorize_pooling.mlir",
"tile_and_vectorize_to_cooperative_ops.mlir",
- "tile_linalg_ops.mlir",
"trim_executable_target_env.mlir",
"vectorize_conv.mlir",
"vectorize_elementwise_ops.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 4e41326..b2deb6d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -66,7 +66,6 @@
"tile_and_vectorize_matmul.mlir"
"tile_and_vectorize_pooling.mlir"
"tile_and_vectorize_to_cooperative_ops.mlir"
- "tile_linalg_ops.mlir"
"trim_executable_target_env.mlir"
"vectorize_conv.mlir"
"vectorize_elementwise_ops.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/annotate_winograd_loops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/annotate_winograd_loops.mlir
index d391989..d2c04e8 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/annotate_winograd_loops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/annotate_winograd_loops.mlir
@@ -109,11 +109,11 @@
// CHECK-SAME: %[[ARG4]], %[[ARG6]]] [8, 8, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<8x8xf32> into
// CHECK-SAME: tensor<8x8x1x2x2x32xf32>
// CHECK: scf.yield %[[INSERTED_SLICE_3]] : tensor<8x8x1x2x2x32xf32>
-// CHECK: } {iree.spirv.distribute_dim = 0 : index}
+// CHECK: } {iree.gpu.distribute_dim = 0 : index}
// CHECK: scf.yield %[[D13]] : tensor<8x8x1x2x2x32xf32>
-// CHECK: } {iree.spirv.distribute_dim = 1 : index}
+// CHECK: } {iree.gpu.distribute_dim = 1 : index}
// CHECK: scf.yield %[[D10]] : tensor<8x8x1x2x2x32xf32>
-// CHECK: } {iree.spirv.distribute_dim = 2 : index}
+// CHECK: } {iree.gpu.distribute_dim = 2 : index}
// CHECK: flow.dispatch.tensor.store %[[D7]], %[[D2]], offsets = [0, 0, %[[ARG0]], 0, 0, %[[ARG1]]], sizes =
// CHECK-SAME: [8, 8, 1, 2, 2, 32], strides = [1, 1, 1, 1, 1, 1] : tensor<8x8x1x2x2x32xf32> ->
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<8x8x2x2x2x1280xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/distribute_to_invocations.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/distribute_to_invocations.mlir
index 87ed573..456b82e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/distribute_to_invocations.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/distribute_to_invocations.mlir
@@ -9,7 +9,7 @@
%init = tensor.empty() : tensor<2x128xf32>
scf.for %iv = %lb to %ub step %step {
memref.store %zero, %output[%iv] : memref<?xf32>
- } {iree.spirv.distribute_dim = 0 : index}
+ } {iree.gpu.distribute_dim = 0 : index}
return
}
@@ -34,7 +34,7 @@
%init = tensor.empty() : tensor<2x128xf32>
scf.for %iv = %lb to %ub step %step {
memref.store %zero, %output[%iv] : memref<?xf32>
- } {iree.spirv.distribute_dim = 1 : index}
+ } {iree.gpu.distribute_dim = 1 : index}
return
}
@@ -59,7 +59,7 @@
%init = tensor.empty() : tensor<2x128xf32>
scf.for %iv = %lb to %ub step %step {
memref.store %zero, %output[%iv] : memref<?xf32>
- } {iree.spirv.distribute_dim = 2 : index}
+ } {iree.gpu.distribute_dim = 2 : index}
return
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
index 81c9919..9ffdc68 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[1, 8, 64], [1, 8, 4], [0, 0, 0, 4]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index 819e63e..4351407 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-create-fast-slow-path,iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-create-fast-slow-path,iree-codegen-gpu-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 4, 4, 16], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
index 16a77a7..7a60449 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' %s | FileCheck %s
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[8, 64], [8, 4], [0, 0, 4]]>
#translation = #iree_codegen.translation_info<SPIRVBaseVectorize>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
index 1222e0e..2149860 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_pooling.mlir
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file \
-// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-spirv-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
+// RUN: --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-tile,canonicalize,cse,iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse)))))' \
// RUN: %s | FileCheck %s
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 2, 2, 8], [0, 1, 1, 4], [0, 0, 0, 0, 1, 1], [0, 1, 0, 0]]>
diff --git a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
index 20c69c2..71a81af 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Utils/BUILD.bazel
@@ -33,6 +33,7 @@
"Utils.h",
],
deps = [
+ "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:PartitionableLoopsInterface",
"//compiler/src/iree/compiler/Codegen/Interfaces:ProcessorOpInterfaces",
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",
diff --git a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
index 0c22b45..2d54729 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Utils/CMakeLists.txt
@@ -51,6 +51,7 @@
MLIRTransformUtils
MLIRVectorDialect
MLIRViewLikeInterface
+ iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Interfaces::PartitionableLoopsInterface
iree::compiler::Codegen::Interfaces::ProcessorOpInterfaces
iree::compiler::Codegen::Interfaces::UKernelOpInterface
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index 94c0e9a..f9c55a6 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -150,6 +150,41 @@
}
//===----------------------------------------------------------------------===//
+// GPU tiling and distribution
+//===----------------------------------------------------------------------===//
+
+const char *getGPUDistributeAttrName() { return "iree.gpu.distribute_dim"; }
+
+FailureOr<SmallVector<int64_t>> getGPUTileSize(mlir::FunctionOpInterface funcOp,
+ int tilingLevel) {
+ SmallVector<Operation *> computeOps = getComputeOps(funcOp);
+ auto config = getLoweringConfig(computeOps);
+ if (failed(config)) {
+ return funcOp.emitOpError("failed to get lowering configuration");
+ }
+
+ return config->getTileSizeVals(tilingLevel);
+}
+
+FailureOr<scf::SCFTileSizeComputationFunction>
+getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel) {
+ FailureOr<SmallVector<int64_t>> tileSizes =
+ getGPUTileSize(funcOp, tilingLevel);
+ if (failed(tileSizes))
+ return failure();
+ scf::SCFTileSizeComputationFunction computeFn =
+ [tileSizes](OpBuilder &builder,
+ Operation *op) -> SmallVector<OpFoldResult> {
+ auto tileSizesOfr = getAsIndexOpFoldResult(op->getContext(), *tileSizes);
+ auto zeroAttr = builder.getIndexAttr(0);
+ int numLoops = cast<TilingInterface>(op).getLoopIteratorTypes().size();
+ tileSizesOfr.resize(numLoops, zeroAttr);
+ return tileSizesOfr;
+ };
+ return computeFn;
+}
+
+//===----------------------------------------------------------------------===//
// GPU workgroup memory
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index f03ca93..4b0a751 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
#define IREE_COMPILER_CODEGEN_UTILS_GPUUTILS_H_
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -51,6 +52,23 @@
gpuMmaUnrollOrder(vector::ContractionOp contract);
//===----------------------------------------------------------------------===//
+// GPU tiling and distribution
+//===----------------------------------------------------------------------===//
+
+/// Returns the attribute name carrying information about distribution.
+const char *getGPUDistributeAttrName();
+
+/// Returns the tile sizes at the given `tilingLevel` for compute ops in
+/// `funcOp`.
+FailureOr<SmallVector<int64_t>> getGPUTileSize(mlir::FunctionOpInterface funcOp,
+ int tilingLevel);
+
+/// Returns the functor to compute tile sizes at the given `tilingLevel` for
+/// compute ops in `funcOp`.
+FailureOr<scf::SCFTileSizeComputationFunction>
+getGPUScfTileSizeComputeFn(mlir::FunctionOpInterface funcOp, int tilingLevel);
+
+//===----------------------------------------------------------------------===//
// GPU workgroup memory
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel b/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel
index 52f5005..17bf919 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/VMVX/BUILD.bazel
@@ -87,6 +87,7 @@
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TransformUtils",
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
index 391caba..de03492 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
@@ -67,6 +67,7 @@
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRPass
+ MLIRSCFDialect
MLIRTensorDialect
MLIRTensorTransforms
MLIRTransformUtils
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp
index 0eddc3a..e487e98 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerExecutableTargetPass.cpp
@@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/VMVX/PassDetail.h"
#include "iree/compiler/Codegen/VMVX/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp
index 5b25180..376622e 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXSelectLoweringStrategy.cpp
@@ -13,6 +13,8 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index c84ec09..1f5eb15 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"