[Encoding] Add resolver swizzle verification (#22867)


Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
index 11c6ca8..662295c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp
@@ -7,9 +7,41 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/DialectImplementation.h"
 
 // clang-format off
 #define GET_TYPEDEF_CLASSES
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.cpp.inc" // IWYU pragma: export
 // clang-format on
+
+namespace mlir::iree_compiler::IREE::Codegen {
+
+int64_t TileSwizzle::getExpandedSize() const {
+  int64_t totalExpandedDims = 0;
+  for (const ExpandShapeDimVectorType &expandDims : expandShape) {
+    totalExpandedDims += static_cast<int64_t>(expandDims.size());
+  }
+  return totalExpandedDims;
+}
+
+LogicalResult
+TileSwizzle::verify(function_ref<InFlightDiagnostic()> emitError) const {
+  int64_t totalExpandedDims = getExpandedSize();
+
+  // The permutation size must match the total expanded dimensions.
+  if (static_cast<int64_t>(permutation.size()) != totalExpandedDims) {
+    return emitError() << "swizzle permutation size (" << permutation.size()
+                       << ") does not match total expanded dimensions ("
+                       << totalExpandedDims << ")";
+  }
+
+  // Check that permutation is valid.
+  if (!isPermutationVector(permutation)) {
+    return emitError() << "swizzle permutation is not a valid permutation";
+  }
+
+  return success();
+}
+
+} // namespace mlir::iree_compiler::IREE::Codegen
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
index a5a8e81..037db02 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h
@@ -11,6 +11,7 @@
 
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/Support/LLVM.h"
 
 // clang-format off
@@ -48,7 +49,7 @@
       // "unrolled" across subgroups. Such dimensions are cross-subgroup, so in
       // particular they are cross-thread.
       CrossThread,
-      // This dimensions is across intrinsics, as in, actual instructions in the
+      // This dimension is across intrinsics, as in, actual instructions in the
       // generated code. In other words, it is an actual unrolling factor,
       // resulting in this many more instructions being generated and executed
       // on each thread/subgroup.
@@ -82,8 +83,8 @@
         : kind(kind), size(size), distributionSize(distributionSize) {}
   };
 
-  using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;
-  using ExpandShapeType = llvm::SmallVector<ExpandShapeDimVectorType>;
+  using ExpandShapeDimVectorType = SmallVector<Dim, 4>;
+  using ExpandShapeType = SmallVector<ExpandShapeDimVectorType>;
 
   // This vector-of-vectors contains all the information needed to generate
   // a `tensor.expand_shape` creating additional internal dimensions into the
@@ -96,7 +97,15 @@
   // to generate a `linalg.transpose` changing the layout of the tile. For
   // example, permutation[0] dictates which of the expanded dimensions becomes
   // the leading dimension of the layout.
-  llvm::SmallVector<int64_t> permutation;
+  SmallVector<int64_t> permutation;
+
+  // Returns the total number of expanded dimensions.
+  int64_t getExpandedSize() const;
+
+  // Verifies consistency of the tile swizzle:
+  // - The permutation size must match the total number of expanded dimensions.
+  // - The permutation indices must be valid (within bounds and unique).
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError) const;
 };
 
 using ScalableTileFlags = SmallVector<bool>;
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
index 6cf64d1..a2a3424 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/Utils.h
@@ -109,6 +109,41 @@
       return failure();
     }
 
+    // Verify swizzle if present.
+    if (info->swizzle) {
+      const IREE::Codegen::TileSwizzle &swizzle = *info->swizzle;
+
+      // Verify internal consistency of the swizzle (permutation size and
+      // validity).
+      if (failed(swizzle.verify(emitError))) {
+        return failure();
+      }
+
+      // The expand shape should have the same number of entries as inner tile
+      // dimensions.
+      if (swizzle.expandShape.size() != info->innerTileSizes.size()) {
+        return emitError() << "swizzle expandShape size ("
+                           << swizzle.expandShape.size()
+                           << ") does not match innerTileSizes size ("
+                           << info->innerTileSizes.size() << ")";
+      }
+
+      // For each inner dimension, the product of expanded sizes should match
+      // the inner tile size.
+      for (auto [idx, expandDims] : llvm::enumerate(swizzle.expandShape)) {
+        int64_t product = 1;
+        for (const Codegen::TileSwizzle::Dim &dim : expandDims) {
+          product *= dim.size;
+        }
+        if (product != info->innerTileSizes[idx]) {
+          return emitError()
+                 << "swizzle expandShape[" << idx << "] product (" << product
+                 << ") does not match innerTileSizes[" << idx << "] ("
+                 << info->innerTileSizes[idx] << ")";
+        }
+      }
+    }
+
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir
index c6766cc..ccb67d0 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/invalid.mlir
@@ -61,6 +61,70 @@
 
 // -----
 
+// Test that swizzle with permutation size mismatch is rejected.
+// The expandShape produces 2 total dimensions, but permutation has 3 elements.
+
+// expected-error @+2 {{swizzle permutation size (3) does not match total expanded dimensions (2)}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [0, 1, 2]}}}>
+func.func @invalid_swizzle_permutation_size_mismatch(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+  return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle with permutation index out of bounds is rejected.
+// The expandShape produces 2 dimensions (indices 0-1), but permutation contains index 2.
+
+// expected-error @+2 {{swizzle permutation is not a valid permutation}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [2, 0]}}}>
+func.func @invalid_swizzle_permutation_out_of_bounds(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+  return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle with duplicate permutation indices is rejected.
+
+// expected-error @+2 {{swizzle permutation is not a valid permutation}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [0, 0]}}}>
+func.func @invalid_swizzle_permutation_duplicate(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+  return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle expandShape size mismatch with innerTileSizes is rejected.
+// innerTileSizes has 2 entries but expandShape has 3.
+
+// expected-error @+2 {{swizzle expandShape size (3) does not match innerTileSizes size (2)}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]], [["Internal", 8]]], permutation = [0, 1, 2]}}}>
+func.func @invalid_swizzle_expand_shape_size_mismatch(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+  return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle expandShape product mismatch with innerTileSizes is rejected.
+// innerTileSizes[0] is 16, but expandShape[0] product is 4*8 = 32.
+
+// expected-error @+2 {{swizzle expandShape[0] product (32) does not match innerTileSizes[0] (16)}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 4], ["Internal", 8]], [["Internal", 16]]], permutation = [0, 1, 2]}}}>
+func.func @invalid_swizzle_expand_shape_product_mismatch(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+  return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
+// Test that swizzle with negative permutation index is rejected.
+
+// expected-error @+2 {{swizzle permutation is not a valid permutation}}
+#encoding_swizzle = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [-1, 0]}}}>
+func.func @invalid_swizzle_permutation_negative(%arg0: tensor<32x64xf32, #encoding_swizzle>) -> tensor<32x64xf32, #encoding_swizzle> {
+  return %arg0 : tensor<32x64xf32, #encoding_swizzle>
+}
+
+// -----
+
 // CPU encoding resolver invalid test.
 // Note: Most encoding verifier tests are done on gpu_encoding_resolver above.
 // This test verifies that cpu_encoding_resolver also rejects invalid encodings.
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir
index c64bfa6..69f3e43 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/test/roundtrip.mlir
@@ -1,14 +1,14 @@
 // RUN: iree-opt --split-input-file %s | FileCheck %s
 
 // Test that the GPU encoding resolver with valid encoding_info roundtrips correctly.
-// The verifier checks that innerDimsPos indices are valid for the tensor rank.
+// The verifier checks that innerDimsPos, outerDimsPerm and swizzle are valid.
 
-#encoding = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1]}}>
+#encoding = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["Internal", 16]], [["Internal", 16]]], permutation = [0, 1]}}}>
 func.func @valid_2d_encoding(%arg0: tensor<32x64xf32, #encoding>) -> tensor<32x64xf32, #encoding> {
   return %arg0 : tensor<32x64xf32, #encoding>
 }
 // CHECK-LABEL: func.func @valid_2d_encoding
-// CHECK-SAME:    tensor<32x64xf32, #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1]}}>>
+// CHECK-SAME:    tensor<32x64xf32, #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = {{\[}}{{\[}}{{\[}}"Internal", 16{{\]}}{{\]}}, {{\[}}{{\[}}"Internal", 16{{\]}}{{\]}}{{\]}}, permutation = [0, 1]}}}>
 
 // -----
 
@@ -60,6 +60,16 @@
 
 // -----
 
+// Test encoding with valid swizzle - multi-dimensional expand shape with CrossThread and CrossIntrinsic kinds.
+#encoding_swizzle_multi = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [128, 16], outerDimsPerm = [0, 1], swizzle = {expandShape = [[["CrossThread", 2], ["CrossIntrinsic", 4], ["CrossThread", 16]], [["CrossIntrinsic", 4], ["CrossThread", 4]]], permutation = [0, 1, 4, 2, 3]}}}>
+func.func @valid_swizzle_multi_expand(%arg0: tensor<256x128xf32, #encoding_swizzle_multi>) -> tensor<256x128xf32, #encoding_swizzle_multi> {
+  return %arg0 : tensor<256x128xf32, #encoding_swizzle_multi>
+}
+// CHECK-LABEL: func.func @valid_swizzle_multi_expand
+// CHECK-SAME:    swizzle = {expandShape = {{\[}}{{\[}}{{\[}}"CrossThread", 2], ["CrossIntrinsic", 4], ["CrossThread", 16{{\]}}{{\]}}, {{\[}}{{\[}}"CrossIntrinsic", 4], ["CrossThread", 4{{\]}}{{\]}}{{\]}}, permutation = [0, 1, 4, 2, 3]}
+
+// -----
+
 // Test with non-trivial outerDimsPerm (transpose).
 #encoding = #iree_gpu.gpu_encoding_resolver<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [16, 16], outerDimsPerm = [1, 0]}}>
 func.func @valid_transposed_outer_dims(%arg0: tensor<32x64xf32, #encoding>) -> tensor<32x64xf32, #encoding> {