[tilingInterface] Update the tile sizes to i64 attr type (#17761)
Update the tile sizes to contain i64Attrs instead of arith.constant.
Somehow it's giving dynamic shapes in tensor.extract_slice since the
arith.constant op isn't folded or seen as a constant.
To fix Issue: https://github.com/iree-org/iree/issues/17441
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index 2e76e9f..a8611dc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -65,6 +65,19 @@
return {distributeLB, distributeStep};
}
+// Helper function to change arith.constant to i64 attribute.
+static void changeArithCstToI64Attr(OpBuilder &b,
+ MutableArrayRef<OpFoldResult> constants) {
+ for (OpFoldResult &val : constants) {
+ if (auto dyn_cast = llvm::dyn_cast_if_present<Value>(val)) {
+ APInt intVal;
+ if (matchPattern(dyn_cast, m_ConstantInt(&intVal))) {
+ val = b.getI64IntegerAttr(intVal.getSExtValue());
+ }
+ }
+ }
+}
+
//===----------------------------------------------------------------------===//
// TileDispatchUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
@@ -166,6 +179,13 @@
loops.push_back(loop);
builder.setInsertionPoint(loop.getBody()->getTerminator());
}
+
+ // Update the sizes if it contains arith.index with i64 attrs.
+ // TODO: tensor.extract_slice is unable to determine the
+ // result type if arith.constant is present. This is a workaround
+ // to ensure that the result type is determined.
+ changeArithCstToI64Attr(builder, sizes);
+
return loops;
}
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir
index ee4c519..fe2ddfd 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir
@@ -40,7 +40,6 @@
// CHECK: %[[DIM_X:.+]] = gpu.block_dim x
// CHECK: scf.for %[[IV_X:.+]] = %[[TID_X]] to %{{.+}} step %[[DIM_X]]
// CHECK: %[[DEST:.+]] = memref.subview %[[WG_OUTPUT]][0, 0, %[[IV_X]]]
-// CHECK: %[[CAST:.+]] = memref.cast %[[DEST]]
// CHECK: iree_linalg_ext.sort
// CHECK-SAME: dimension(1)
-// CHECK-SAME: outs(%[[CAST]]
+// CHECK-SAME: outs(%[[DEST]]