[Encoding] Add resolver swizzle verification (#22867)
Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
index 11c6ca8..662295c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
@@ -7,9 +7,41 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/DialectImplementation.h"
// clang-format off
#define GET_TYPEDEF_CLASSES
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp.inc" // IWYU pragma: export
// clang-format on
+
+namespace mlir::iree_compiler::IREE::Codegen {
+
+// Counts the expanded dimensions described by this swizzle, i.e. the sum of
+// the lengths of all inner vectors of `expandShape`.
+int64_t TileSwizzle::getExpandedSize() const {
+  int64_t numExpandedDims = 0;
+  for (const ExpandShapeDimVectorType &group : expandShape) {
+    // Each group holds the dims that one source dimension is expanded into.
+    numExpandedDims += static_cast<int64_t>(group.size());
+  }
+  return numExpandedDims;
+}
+
+// Checks the internal consistency of the swizzle. On a violation, a diagnostic
+// is emitted through `emitError` and failure is returned. Two conditions are
+// checked, in this order:
+//   1. `permutation` has exactly one entry per expanded dimension.
+//   2. `permutation` is a valid permutation vector.
+LogicalResult
+TileSwizzle::verify(function_ref<InFlightDiagnostic()> emitError) const {
+  // One permutation entry is required for every expanded dimension.
+  const int64_t numExpandedDims = getExpandedSize();
+  const auto numPermutationEntries = static_cast<int64_t>(permutation.size());
+  if (numPermutationEntries != numExpandedDims) {
+    return emitError() << "swizzle permutation size (" << permutation.size()
+                       << ") does not match total expanded dimensions ("
+                       << numExpandedDims << ")";
+  }
+
+  // With the size established, the entries themselves must form a permutation.
+  if (isPermutationVector(permutation)) {
+    return success();
+  }
+  return emitError() << "swizzle permutation is not a valid permutation";
+}
+
+} // namespace mlir::iree_compiler::IREE::Codegen
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
index a5a8e81..037db02 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
@@ -11,6 +11,7 @@
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
// clang-format off
@@ -48,7 +49,7 @@
// "unrolled" across subgroups. Such dimensions are cross-subgroup, so in
// particular they are cross-thread.
CrossThread,
- // This dimensions is across intrinsics, as in, actual instructions in the
+ // This dimension is across intrinsics, as in, actual instructions in the
// generated code. In other words, it is an actual unrolling factor,
// resulting in this many more instructions being generated and executed
// on each thread/subgroup.
@@ -82,8 +83,8 @@
: kind(kind), size(size), distributionSize(distributionSize) {}
};
- using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;
- using ExpandShapeType = llvm::SmallVector<ExpandShapeDimVectorType>;
+ using ExpandShapeDimVectorType = SmallVector<Dim, 4>;
+ using ExpandShapeType = SmallVector<ExpandShapeDimVectorType>;
// This vector-of-vectors contains all the information needed to generate
// a `tensor.expand_shape` creating additional internal dimensions into the
@@ -96,7 +97,15 @@
// to generate a `linalg.transpose` changing the layout of the tile. For
// example, permutation[0] dictates which of the expanded dimensions becomes
// the leading dimension of the layout.
- llvm::SmallVector<int64_t> permutation;
+ SmallVector<int64_t> permutation;
+
+ // Returns the total number of expanded dimensions, i.e. the sum of the
+ // sizes of all inner vectors of `expandShape`.
+ int64_t getExpandedSize() const;
+
+ // Verifies consistency of the tile swizzle:
+ // - The permutation size must match the total number of expanded dimensions.
+ // - The permutation indices must be valid (within bounds and unique).
+ LogicalResult verify(function_ref<InFlightDiagnostic()> emitError) const;
};
using ScalableTileFlags = SmallVector<bool>;
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
index 6cf64d1..a2a3424 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
@@ -109,6 +109,41 @@
return failure();
}
+    // When a swizzle is attached, validate it on its own and against the
+    // inner tile sizes of the encoding info.
+    if (info->swizzle) {
+      const IREE::Codegen::TileSwizzle &swizzle = *info->swizzle;
+
+      // Internal consistency first: permutation size and permutation
+      // validity.
+      if (failed(swizzle.verify(emitError))) {
+        return failure();
+      }
+
+      // There must be exactly one expandShape entry per inner tile
+      // dimension.
+      if (swizzle.expandShape.size() != info->innerTileSizes.size()) {
+        return emitError() << "swizzle expandShape size ("
+                           << swizzle.expandShape.size()
+                           << ") does not match innerTileSizes size ("
+                           << info->innerTileSizes.size() << ")";
+      }
+
+      // Every inner tile size must equal the product of the sizes of the
+      // dimensions it is expanded into.
+      for (auto [idx, expandDims] : llvm::enumerate(swizzle.expandShape)) {
+        int64_t expandedProduct = 1;
+        for (const IREE::Codegen::TileSwizzle::Dim &dim : expandDims) {
+          expandedProduct *= dim.size;
+        }
+        if (expandedProduct == info->innerTileSizes[idx]) {
+          continue;
+        }
+        return emitError()
+               << "swizzle expandShape[" << idx << "] product ("
+               << expandedProduct << ") does not match innerTileSizes[" << idx
+               << "] (" << info->innerTileSizes[idx] << ")";
+      }
+    }
+
return success();
}
};
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir
index c6766cc..ccb67d0 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir
@@ -61,6 +61,70 @@
// -----
+// Test that swizzle with permutation size mismatch is rejected.
+// The expandShape produces 2 total dimensions, but permutation has 3 elements.
+
+// expected-error @+2 {{swizzle permutation size (3) does not match total expanded dimensions (2)}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [0, 1, 2]}}}>
+func.func @invalid_swizzle_permutation_size_mismatch(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+ return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle with permutation index out of bounds is rejected.
+// The expandShape produces 2 dimensions (indices 0-1), but permutation contains index 2.
+
+// expected-error @+2 {{swizzle permutation is not a valid permutation}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [2, 0]}}}>
+func.func @invalid_swizzle_permutation_out_of_bounds(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+ return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
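+// Test that swizzle with duplicate permutation indices is rejected.
+// The expandShape produces 2 dimensions, but permutation [0, 0] repeats index 0 and never references index 1.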
+
+// expected-error @+2 {{swizzle permutation is not a valid permutation}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [0, 0]}}}>
+func.func @invalid_swizzle_permutation_duplicate(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+ return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle expandShape size mismatch with innerTileSizes is rejected.
+// innerTileSizes has 2 entries but expandShape has 3.
+
+// expected-error @+2 {{swizzle expandShape size (3) does not match innerTileSizes size (2)}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]], [["Internal", 8]]], permutation = [0, 1, 2]}}}>
+func.func @invalid_swizzle_expand_shape_size_mismatch(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+ return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle expandShape product mismatch with innerTileSizes is rejected.
+// innerTileSizes[0] is 16, but expandShape[0] product is 4*8 = 32.
+
+// expected-error @+2 {{swizzle expandShape[0] product (32) does not match innerTileSizes[0] (16)}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 4], ["Internal", 8]], [["Internal", 16]]], permutation = [0, 1, 2]}}}>
+func.func @invalid_swizzle_expand_shape_product_mismatch(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+ return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle with negative permutation index is rejected.
+// The expandShape produces 2 dimensions (indices 0-1), but permutation contains index -1.
+
+// expected-error @+2 {{swizzle permutation is not a valid permutation}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [-1, 0]}}}>
+func.func @invalid_swizzle_permutation_negative(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+ return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
// CPU encoding resolver invalid test.
// Note: Most encoding verifier tests are done on gpu_encoding_resolver above.
// This test verifies that cpu_encoding_resolver also rejects invalid encodings.
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir
index c64bfa6..69f3e43 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir
@@ -1,14 +1,14 @@
// RUN: iree-opt --split-input-file %s | FileCheck %s
// Test that the GPU encoding resolver with valid encoding_info roundtrips correctly.
-// The verifier checks that innerDimsPos indices are valid for the tensor rank.
+// The verifier checks that innerDimsPos, outerDimsPerm and swizzle are valid.
-#encoding = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1]}}>
+#encoding = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [0, 1]}}}>
func.func @valid_2d_encoding(%arg0: tensor<32x64xf32, #encoding>) -> tensor<32x64xf32, #encoding> {
return %arg0 : tensor<32x64xf32, #encoding>
}
// CHECK-LABEL: func.func @valid_2d_encoding
-// CHECK-SAME: tensor<32x64xf32, #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1]}}>>
+// CHECK-SAME: tensor<32x64xf32, #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = {{\[}}{{\[}}{{\[}}"Internal", 16{{\]}}{{\]}}, {{\[}}{{\[}}"Internal", 16{{\]}}{{\]}}{{\]}}, permutation = [0, 1]}}}>
// -----
@@ -60,6 +60,16 @@
// -----
+// Test encoding with valid swizzle - multi-dimensional expand shape with CrossThread and CrossIntrinsic kinds.
+#encoding_swizzle_multi = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [128, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["CrossThread", 2], ["CrossIntrinsic", 4], ["CrossThread", 16]], [["CrossIntrinsic", 4], ["CrossThread", 4]]], permutation = [0, 1, 4, 2, 3]}}}>
+func.func @valid_swizzle_multi_expand(%arg0: tensor<256x128xf32, #encoding_swizzle_multi>) -> tensor<256x128xf32, #encoding_swizzle_multi> {
+ return %arg0 : tensor<256x128xf32, #encoding_swizzle_multi>
+}
+// CHECK-LABEL: func.func @valid_swizzle_multi_expand
+// CHECK-SAME: swizzle = {expandShape = {{\[}}{{\[}}{{\[}}"CrossThread", 2], ["CrossIntrinsic", 4], ["CrossThread", 16{{\]}}{{\]}}, {{\[}}{{\[}}"CrossIntrinsic", 4], ["CrossThread", 4{{\]}}{{\]}}{{\]}}, permutation = [0, 1, 4, 2, 3]}
+
+// -----
+
// Test with non-trivial outerDimsPerm (transpose).
#encoding = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [1, 0]}}>
func.func @valid_transposed_outer_dims(%arg0: tensor<32x64xf32, #encoding>) -> tensor<32x64xf32, #encoding> {