[Codegen] Check for workgroup level tile sizes in workgroup tiling (#18538)
TileDispatchUsingForall relies on lowering configurations having
workgroup level tile sizes, so this PR adds the additional check that
the tilableOp has workgroup level tile sizes. It also adds verification
that there is only one op with a workgroup tiling level.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index cce23a8..03d2780 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -53,11 +53,18 @@
static FailureOr<TilingInfo>
getTiledAndDistributionInfo(RewriterBase &rewriter,
ArrayRef<Operation *> computeOps) {
+ // It is expected that at most one compute op has a workgroup tiling level.
Operation *tilableOp = nullptr;
for (Operation *op : llvm::reverse(computeOps)) {
if (getLoweringConfig(op)) {
+ if (!getLoweringConfig(op).hasWorkgroupTilingLevel()) {
+ continue;
+ }
+ if (tilableOp) {
+ return op->emitOpError("expected only one op with a workgroup tiling"
+ "level.");
+ }
tilableOp = op;
- break;
}
}
if (!tilableOp) {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index 3ba545e..6c80f50 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -323,6 +323,10 @@
return !getTileSizeVals(level).empty();
}
+bool LoweringConfigAttr::hasWorkgroupTilingLevel() const {
+ return !getWorkgroupTileSizes().empty();
+}
+
LogicalResult
LoweringConfigAttr::verify(function_ref<InFlightDiagnostic()> emitError,
LoweringConfigTilingLevelsAttr levels,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index 48c3917..ab4648b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -296,6 +296,7 @@
"getStaticTilingLevelSizes",
"getTilingLevelSizes",
"hasTilingLevel",
+ "hasWorkgroupTilingLevel",
]>
]> {
let mnemonic = "lowering_config";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
index ee67fab..2f8dffb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
@@ -62,6 +62,19 @@
>,
InterfaceMethod<
/*desc=*/[{
+ Returns true if the lowering config specifies tile sizes for the
+ workgroup tiling level.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasWorkgroupTilingLevel",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
Returns the tile sizes for the specified tiling level. The
interpretation of |level| is attribute and backend dependent. The
|target| is the operation this lowering configuration annotates.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
index 7b7afd1..0a73886 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -1348,6 +1348,10 @@
.empty();
}
+bool LoweringConfigAttr::hasWorkgroupTilingLevel() const {
+ return !getWorkgroupTileSizes().empty();
+}
+
constexpr StringLiteral kMmaKindName = "mma_kind";
IREE::GPU::MmaInterfaceAttr LoweringConfigAttr::getMmaKind() const {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
index 1e25185..e91fe46 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td
@@ -39,6 +39,7 @@
"getStaticTilingLevelSizes",
"getTilingLevelSizes",
"hasTilingLevel",
+ "hasWorkgroupTilingLevel",
]>
]> {
let mnemonic = "lowering_config";