[Integrate] Drop LLVM revert of "Remove matmul_transpose variants" (#21344)
Issue: https://github.com/iree-org/iree/issues/21349
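Dropping the local revert re-applies the upstream removal of the
`linalg.matmul_transpose_{a,b}` / `linalg.batch_matmul_transpose_{a,b}`
named-op variants, so all remaining uses (mostly in lit tests) are rewritten
to `linalg.matmul` / `linalg.batch_matmul` with explicit `indexing_maps`.
A representative before/after, taken from the tests in this patch:

```mlir
// Before: dedicated transpose-B named op.
%0 = linalg.matmul_transpose_b
       ins(%lhs, %rhs : tensor<100x250xf32>, tensor<500x250xf32>)
       outs(%acc : tensor<100x500xf32>) -> tensor<100x500xf32>

// After: plain linalg.matmul; the RHS map reads B as (N, K), i.e. transposed.
%0 = linalg.matmul
       indexing_maps = [
         affine_map<(d0, d1, d2) -> (d0, d2)>, // LHS: (M, K)
         affine_map<(d0, d1, d2) -> (d1, d2)>, // RHS: (N, K)
         affine_map<(d0, d1, d2) -> (d0, d1)>  // OUT: (M, N)
       ]
       ins(%lhs, %rhs : tensor<100x250xf32>, tensor<500x250xf32>)
       outs(%acc : tensor<100x500xf32>) -> tensor<100x500xf32>
```

On the C++ side, new `IREE::LinalgExt::isPureMatmul` / `isPureBatchMatmul`
helpers identify default-layout (non-transposed, non-broadcast) matmuls where
code previously matched the dedicated named ops directly.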
---------
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
Signed-off-by: Abhishek Varma <abhvarma@amd.com>
Co-authored-by: Yu-Zhewen <zhewenyu@amd.com>
Co-authored-by: Abhishek Varma <abhvarma@amd.com>
Co-authored-by: Ean Garvey <ean.garvey@amd.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_pdl_patterns.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_pdl_patterns.mlir
index 0b3473c..149b833 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_pdl_patterns.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_pdl_patterns.mlir
@@ -1,16 +1,28 @@
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-rocm-apply-builtin-pdl-patterns{targets=gfx942 enable-specialization=true}))' \
-// RUN: --mlir-print-local-scope --split-input-file %s | FileCheck %s
+// RUN: --split-input-file %s | FileCheck %s
func.func @transpose_matmul_f16(%lhs : tensor<10x20xf16>, %rhs : tensor<40x20xf16>,
%outs : tensor<10x40xf32>) -> tensor<10x40xf32> {
- %matmul = linalg.matmul_transpose_b ins(%lhs, %rhs : tensor<10x20xf16>, tensor<40x20xf16>)
- outs(%outs : tensor<10x40xf32>) -> tensor<10x40xf32>
+ %matmul = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<10x20xf16>, tensor<40x20xf16>)
+ outs(%outs : tensor<10x40xf32>) -> tensor<10x40xf32>
return %matmul : tensor<10x40xf32>
}
func.func @transpose_matmul_f8E4M3FNUZ(%lhs : tensor<10x20xf8E4M3FNUZ>, %rhs : tensor<40x20xf8E4M3FNUZ>,
%outs : tensor<10x40xf32>) -> tensor<10x40xf32> {
- %matmul = linalg.matmul_transpose_b ins(%lhs, %rhs : tensor<10x20xf8E4M3FNUZ>, tensor<40x20xf8E4M3FNUZ>)
- outs(%outs : tensor<10x40xf32>) -> tensor<10x40xf32>
+ %matmul = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<10x20xf8E4M3FNUZ>, tensor<40x20xf8E4M3FNUZ>)
+ outs(%outs : tensor<10x40xf32>) -> tensor<10x40xf32>
return %matmul : tensor<10x40xf32>
}
func.func @normal_matmul(%lhs : tensor<10x20xf16>, %rhs : tensor<20x40xf16>,
@@ -21,17 +33,27 @@
}
func.func @transpose_matmul_f32(%lhs : tensor<10x20xf32>, %rhs : tensor<40x20xf32>,
%outs : tensor<10x40xf32>) -> tensor<10x40xf32> {
- %matmul = linalg.matmul_transpose_b ins(%lhs, %rhs : tensor<10x20xf32>, tensor<40x20xf32>)
- outs(%outs : tensor<10x40xf32>) -> tensor<10x40xf32>
+ %matmul = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %rhs : tensor<10x20xf32>, tensor<40x20xf32>)
+ outs(%outs : tensor<10x40xf32>) -> tensor<10x40xf32>
return %matmul : tensor<10x40xf32>
}
+// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
// CHECK-LABEL: func @transpose_matmul_f16
-// CHECK: linalg.matmul_transpose_b
+// CHECK: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK-SAME: iree_codegen.specialization_ranges
// CHECK-SAME: [<umin = 2048, udiv = 256>, <umin = 2048, udiv = 256>, <udiv = 64>]
// CHECK-SAME: [<umin = 1024, udiv = 128>, <umin = 1024, udiv = 128>, <udiv = 64>]
// CHECK-LABEL: func @transpose_matmul_f8E4M3FNUZ
-// CHECK: linalg.matmul_transpose_b
+// CHECK: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK-SAME: iree_codegen.specialization_ranges
// CHECK-SAME: [<umin = 2048, udiv = 256>, <umin = 2048, udiv = 256>, <udiv = 128>]
// CHECK-SAME: [<umin = 1024, udiv = 128>, <umin = 1024, udiv = 128>, <udiv = 128>]
diff --git a/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir b/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir
index 9fc26e1..e9b0298 100644
--- a/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/builtins/tuning/test/spec_gfx942.mlir
@@ -11,7 +11,11 @@
// RUN: --verify-diagnostics %s | FileCheck %s --check-prefix=MI300X
// Check that the default configuration for mmt_2048x1280x5120_f16_f16_f32
-// applies to the `linalg.matmul_transpose_b` below.
+// applies to the `matmul_transpose_b` below, i.e., a `linalg.matmul` with the
+// following indexing maps:
+// affine_map<(d0, d1, d2) -> (d0, d2)>
+// affine_map<(d0, d1, d2) -> (d1, d2)>
+// affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @mmt_2048x1280x5120_f16_f16_f32
// CHECK: linalg.generic
@@ -46,7 +50,12 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 5120], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x5120xf16>> -> tensor<1280x5120xf16>
%5 = tensor.empty() : tensor<2048x1280xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
- %7 = linalg.matmul_transpose_b
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
ins(%3, %4 : tensor<2048x5120xf16>, tensor<1280x5120xf16>)
outs(%6 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 1280], strides = [1, 1] : tensor<2048x1280xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x1280xf32>>
diff --git a/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir b/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir
index 4fdd4d1..d1cf575 100644
--- a/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir
+++ b/compiler/plugins/target/ROCM/test/lowering_strategy_from_tuning_spec.mlir
@@ -43,7 +43,12 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
%5 = tensor.empty() : tensor<2048x10240xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
- %7 = linalg.matmul_transpose_b
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
index e9578c3..d53f5a2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUGeneralizeNamedOps.cpp
@@ -55,7 +55,7 @@
funcOp.walk([&](linalg::LinalgOp linalgOp) {
if (isa<linalg::BatchMatmulTransposeBOp, linalg::MatmulTransposeBOp,
linalg::VecmatOp, linalg::MatvecOp, linalg::TransposeOp>(
- linalgOp))
+ linalgOp.getOperation()))
namedOpCandidates.push_back(linalgOp);
});
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
index 63b7c98..0bb5dc8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir
@@ -82,7 +82,15 @@
%extracted_slice_1 = tensor.extract_slice %7[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16>
%extracted_slice_2 = tensor.extract_slice %10[0, %arg0] [64, 4] [1, 1] : tensor<64x1280xf16> to tensor<64x4xf16>
%13 = linalg.copy {lowering_config = #iree_gpu.lowering_config<{thread = [1, 1]}>} ins(%extracted_slice_1 : tensor<64x4xf16>) outs(%extracted_slice_2 : tensor<64x4xf16>) -> tensor<64x4xf16>
- %14 = linalg.matmul_transpose_b {lowering_config = #iree_gpu.lowering_config<{thread = [4, 4]}>} ins(%12, %13 : tensor<64x4xf16>, tensor<64x4xf16>) outs(%arg1 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ %14 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #iree_gpu.lowering_config<{thread = [4, 4]}>}
+ ins(%12, %13 : tensor<64x4xf16>, tensor<64x4xf16>)
+ outs(%arg1 : tensor<64x64xf32>) -> tensor<64x64xf32>
scf.yield %14 : tensor<64x64xf32>
}
return %11 : tensor<64x64xf32>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_generalize_named_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_generalize_named_ops.mlir
index e9fd29c..67b2531 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_generalize_named_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_generalize_named_ops.mlir
@@ -4,7 +4,14 @@
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<1x32000xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x32000xf32>) -> tensor<1x32000xf32>
- %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<1x4096xf32>, tensor<32000x4096xf32>) outs(%1 : tensor<1x32000xf32>) -> tensor<1x32000xf32>
+ %2 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<1x4096xf32>, tensor<32000x4096xf32>)
+ outs(%1 : tensor<1x32000xf32>) -> tensor<1x32000xf32>
return %2 : tensor<1x32000xf32>
}
@@ -76,7 +83,14 @@
%f0 = arith.constant 0.0 : f16
%empty = tensor.empty(%dim) : tensor<32x1x?xf16>
%fill = linalg.fill ins(%f0 : f16) outs(%empty : tensor<32x1x?xf16>) -> tensor<32x1x?xf16>
- %2 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : tensor<32x1x128xf16>, tensor<32x?x128xf16>) outs(%fill : tensor<32x1x?xf16>) -> tensor<32x1x?xf16>
+ %2 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<32x1x128xf16>, tensor<32x?x128xf16>)
+ outs(%fill : tensor<32x1x?xf16>) -> tensor<32x1x?xf16>
return %2 : tensor<32x1x?xf16>
}
@@ -99,9 +113,15 @@
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<512x512xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x512xf32>) -> tensor<512x512xf32>
- %2 = linalg.matmul_transpose_b {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 16, 16]]>}
- ins(%arg0, %arg1 : tensor<512x128xf16>, tensor<512x128xf16>)
- outs(%1 : tensor<512x512xf32>) -> tensor<512x512xf32>
+ %2 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 16, 16]]>}
+ ins(%arg0, %arg1 : tensor<512x128xf16>, tensor<512x128xf16>)
+ outs(%1 : tensor<512x512xf32>) -> tensor<512x512xf32>
return %2 : tensor<512x512xf32>
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
index 02cdadc..140902a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
@@ -201,7 +201,12 @@
%init = tensor.empty(%0) : tensor<?x2048xf32>
%init2 = tensor.empty(%0) : tensor<?x2048xf16>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x2048xf32>) -> tensor<?x2048xf32>
- %1 = linalg.matmul_transpose_b
+ %1 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
ins(%lhs, %rhs : tensor<?x4096xf16>, tensor<2048x4096xf16>)
outs(%fill : tensor<?x2048xf32>) -> tensor<?x2048xf32>
%2 = linalg.generic {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
index 8d78602..83e8c3e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
@@ -99,12 +99,20 @@
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<512x512xi32>>
%empty = tensor.empty() : tensor<512x512xi32>
%fill = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<512x512xi32>) -> tensor<512x512xi32>
- %5 = linalg.matmul_transpose_b
+ %5 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
ins(%1, %2 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%fill : tensor<512x512xi32>) -> tensor<512x512xi32>
iree_tensor_ext.dispatch.tensor.store %5, %3, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<512x512xi32>>
return
}
+// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func.func @nonacc_gemm
-// CHECK: linalg.matmul_transpose_b
+// CHECK: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK-NOT: linalg.generic
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir
index 90931a8..49f3218 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_x86_64.mlir
@@ -1668,7 +1668,14 @@
%6 = iree_encoding.set_encoding %arg0 : tensor<256x128xf32> -> tensor<256x128xf32, #encoding_lhs>
%10 = iree_encoding.set_encoding %arg1 : tensor<256x512xf32> -> tensor<256x512xf32, #encoding_rhs>
%14 = iree_encoding.set_encoding %arg2 : tensor<128x512xf32> -> tensor<128x512xf32, #encoding_result>
- %15 = linalg.matmul_transpose_a ins(%6, %10 : tensor<256x128xf32, #encoding_lhs>, tensor<256x512xf32, #encoding_rhs>) outs(%14 : tensor<128x512xf32, #encoding_result>) -> tensor<128x512xf32, #encoding_result>
+ %15 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%6, %10 : tensor<256x128xf32, #encoding_lhs>, tensor<256x512xf32, #encoding_rhs>)
+ outs(%14 : tensor<128x512xf32, #encoding_result>) -> tensor<128x512xf32, #encoding_result>
%16 = iree_encoding.unset_encoding %15 : tensor<128x512xf32, #encoding_result> -> tensor<128x512xf32>
return %16 : tensor<128x512xf32>
}
@@ -1707,7 +1714,14 @@
%6 = iree_encoding.set_encoding %arg0 : tensor<128x256xf32> -> tensor<128x256xf32, #encoding_lhs>
%10 = iree_encoding.set_encoding %arg1 : tensor<512x256xf32> -> tensor<512x256xf32, #encoding_rhs>
%14 = iree_encoding.set_encoding %arg2 : tensor<128x512xf32> -> tensor<128x512xf32, #encoding_result>
- %15 = linalg.matmul_transpose_b ins(%6, %10 : tensor<128x256xf32, #encoding_lhs>, tensor<512x256xf32, #encoding_rhs>) outs(%14 : tensor<128x512xf32, #encoding_result>) -> tensor<128x512xf32, #encoding_result>
+ %15 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%6, %10 : tensor<128x256xf32, #encoding_lhs>, tensor<512x256xf32, #encoding_rhs>)
+ outs(%14 : tensor<128x512xf32, #encoding_result>) -> tensor<128x512xf32, #encoding_result>
%16 = iree_encoding.unset_encoding %15 : tensor<128x512xf32, #encoding_result> -> tensor<128x512xf32>
return %16 : tensor<128x512xf32>
}
@@ -1746,7 +1760,14 @@
%7 = iree_encoding.set_encoding %arg0 : tensor<2x256x128xf32> -> tensor<2x256x128xf32, #encoding_lhs>
%12 = iree_encoding.set_encoding %arg1 : tensor<2x256x512xf32> -> tensor<2x256x512xf32, #encoding_rhs>
%17 = iree_encoding.set_encoding %arg2 : tensor<2x128x512xf32> -> tensor<2x128x512xf32, #encoding_result>
- %18 = linalg.batch_matmul_transpose_a ins(%7, %12 : tensor<2x256x128xf32, #encoding_lhs>, tensor<2x256x512xf32, #encoding_rhs>) outs(%17 : tensor<2x128x512xf32, #encoding_result>) -> tensor<2x128x512xf32, #encoding_result>
+ %18 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%7, %12 : tensor<2x256x128xf32, #encoding_lhs>, tensor<2x256x512xf32, #encoding_rhs>)
+ outs(%17 : tensor<2x128x512xf32, #encoding_result>) -> tensor<2x128x512xf32, #encoding_result>
%19 = iree_encoding.unset_encoding %18 : tensor<2x128x512xf32, #encoding_result> -> tensor<2x128x512xf32>
return %19 : tensor<2x128x512xf32>
}
@@ -1785,7 +1806,14 @@
%7 = iree_encoding.set_encoding %arg0 : tensor<2x128x256xf32> -> tensor<2x128x256xf32, #encoding_lhs>
%12 = iree_encoding.set_encoding %arg1 : tensor<2x512x256xf32> -> tensor<2x512x256xf32, #encoding_rhs>
%17 = iree_encoding.set_encoding %arg2 : tensor<2x128x512xf32> -> tensor<2x128x512xf32, #encoding_result>
- %18 = linalg.batch_matmul_transpose_b ins(%7, %12 : tensor<2x128x256xf32, #encoding_lhs>, tensor<2x512x256xf32, #encoding_rhs>) outs(%17 : tensor<2x128x512xf32, #encoding_result>) -> tensor<2x128x512xf32, #encoding_result>
+ %18 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%7, %12 : tensor<2x128x256xf32, #encoding_lhs>, tensor<2x512x256xf32, #encoding_rhs>)
+ outs(%17 : tensor<2x128x512xf32, #encoding_result>) -> tensor<2x128x512xf32, #encoding_result>
%19 = iree_encoding.unset_encoding %18 : tensor<2x128x512xf32, #encoding_result> -> tensor<2x128x512xf32>
return %19 : tensor<2x128x512xf32>
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/specialize_exports.mlir b/compiler/src/iree/compiler/Codegen/Common/test/specialize_exports.mlir
index 1019853..9c879fe 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/specialize_exports.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/specialize_exports.mlir
@@ -28,8 +28,13 @@
%8 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x4096xf16>> -> tensor<1024x4096xf16>
%9 = tensor.empty(%4) : tensor<?x1024xf32>
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?x1024xf32>) -> tensor<?x1024xf32>
- %11 = linalg.matmul_transpose_b {
- iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
+ %11 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
[<umin = 128, umax = 4096, udiv = 128>, <umin = 128, umax = 4096, udiv = 128>, <umin = 64, udiv = 64>],
[<umin = 4096, udiv = 256>, <umin = 4096, udiv = 256>, <udiv = 64>]]>}
ins(%7, %8 : tensor<?x4096xf16>, tensor<1024x4096xf16>) outs(%10 : tensor<?x1024xf32>) -> tensor<?x1024xf32>
@@ -97,8 +102,13 @@
%8 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x4096xf16>> -> tensor<1024x4096xf16>
%9 = tensor.empty(%4) : tensor<?x1024xf32>
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?x1024xf32>) -> tensor<?x1024xf32>
- %11 = linalg.matmul_transpose_b {
- iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
+ %11 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
[<umin = 128, umax = 4096, udiv = 128>, <umin = 128, umax = 4096, udiv = 128>, <umin = 64, udiv = 64>],
[<umin = 0, udiv = 512>, <umin = 0, udiv = 512>, <udiv = 64>], [<udiv = 16>, <udiv = 16>, <udiv = 64>],
[<umin = 0, udiv = 512>, <umin = 0, udiv = 512>, <udiv = 64>]]>}
@@ -165,8 +175,13 @@
%11 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0], sizes = [%5, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x4096xf16>>{%5} -> tensor<?x4096xf16>
%12 = tensor.empty(%6, %5) : tensor<?x?xf32>
%13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %14 = linalg.matmul_transpose_b {
- iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
+ %14 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
[<umin = 128, umax = 4096, udiv = 128>, <umin = 128, umax = 4096, udiv = 128>, <umin = 64, udiv = 64>],
[<umin = 4096, udiv = 256>, <umin = 4096, udiv = 256>, <udiv = 64>]]>}
ins(%10, %11 : tensor<?x4096xf16>, tensor<?x4096xf16>) outs(%13 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -265,8 +280,13 @@
%8 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x4096xf16>> -> tensor<1024x4096xf16>
%9 = tensor.empty(%4) : tensor<?x1024xf32>
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?x1024xf32>) -> tensor<?x1024xf32>
- %11 = linalg.matmul_transpose_b {
- iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
+ %11 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
[<umin = 4096, udiv = 256>, <udiv = 256>, <udiv = 64>]]>}
ins(%7, %8 : tensor<?x4096xf16>, tensor<1024x4096xf16>) outs(%10 : tensor<?x1024xf32>) -> tensor<?x1024xf32>
iree_tensor_ext.dispatch.tensor.store %11, %6, offsets = [0, 0], sizes = [%4, 1024], strides = [1, 1] : tensor<?x1024xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x1024xf32>>{%4}
@@ -327,8 +347,13 @@
%11 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0], sizes = [%5, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x4096xf16>>{%5} -> tensor<?x4096xf16>
%12 = tensor.empty(%6, %5) : tensor<?x?xf32>
%13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %14 = linalg.matmul_transpose_b {
- iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
+ %14 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {iree_codegen.specialization_ranges = #util<int.assumption.multi_array[
[<umin = 128, umax = 4096, udiv = 128>, <umin = 128, umax = 4096, udiv = 128>, <umin = 64, udiv = 64>],
[<umin = 4096, udiv = 256>, <umin = 4096, udiv = 256>, <udiv = 64>]]>}
ins(%10, %11 : tensor<?x4096xf16>, tensor<?x4096xf16>) outs(%13 : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index da54abd..d2c6c76 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
@@ -1595,7 +1596,9 @@
SmallVectorImpl<int64_t> &sizes,
SmallVectorImpl<bool> &scalableSizeFlags) {
// Double-check the operation is one that is supported for lowering to ArmSME.
- if (!llvm::isa<linalg::MatmulOp, linalg::MatmulTransposeAOp>(op))
+ Operation *rawOp = op.getOperation();
+ if (!(IREE::LinalgExt::isPureMatmul(rawOp) ||
+ isa<linalg::MatmulTransposeAOp>(rawOp)))
return;
auto elementType = nonWideningLinalgElementType(op);
@@ -1655,7 +1658,8 @@
hasAnyVFeature(targetAttr.getConfiguration())) {
// Use default tile size for matmul_transpose_b &
// batch_matmul_transpose_b to avoid performance drop.
- if (!isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(op)) {
+ if (!isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(
+ op.getOperation())) {
// Try to maximize the vector register utilization rate for matmul.
getMatmulRISCVVectorSizes(entryPointFn, op, vectorSize, matmulTileSizes,
matmulScalableFlags);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index a8e5695..27a7f3a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -30,7 +30,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
%5 = tensor.empty() : tensor<2048x10240xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
@@ -102,7 +108,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
%5 = tensor.empty() : tensor<2048x10240xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
@@ -174,7 +186,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
%5 = tensor.empty() : tensor<2048x10240xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
@@ -250,7 +268,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280x!eltype>> -> tensor<10240x1280x!eltype>
%5 = tensor.empty() : tensor<2048x10240x!aeltype>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280x!eltype>, tensor<10240x1280x!eltype>)
outs(%6 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240x!aeltype> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240x!aeltype>>
@@ -305,7 +329,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280x!eltype>> -> tensor<10240x1280x!eltype>
%5 = tensor.empty() : tensor<2048x10240x!aeltype>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280x!eltype>, tensor<10240x1280x!eltype>)
outs(%6 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240x!aeltype> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240x!aeltype>>
@@ -360,7 +390,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280x!eltype>> -> tensor<10240x1280x!eltype>
%5 = tensor.empty() : tensor<2048x10240x!aeltype>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280x!eltype>, tensor<10240x1280x!eltype>)
outs(%6 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240x!aeltype> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240x!aeltype>>
@@ -415,7 +451,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280x!eltype>> -> tensor<10240x1280x!eltype>
%5 = tensor.empty() : tensor<2048x10240x!aeltype>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280x!eltype>, tensor<10240x1280x!eltype>)
outs(%6 : tensor<2048x10240x!aeltype>) -> tensor<2048x10240x!aeltype>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240x!aeltype> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240x!aeltype>>
@@ -1066,7 +1108,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [10240, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<10240x1280xf16>> -> tensor<10240x1280xf16>
%5 = tensor.empty() : tensor<2048x10240xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
- %7 = linalg.matmul_transpose_b {lowering_config = #config}
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
ins(%3, %4 : tensor<2048x1280xf16>, tensor<10240x1280xf16>)
outs(%6 : tensor<2048x10240xf32>) -> tensor<2048x10240xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2048, 10240], strides = [1, 1] : tensor<2048x10240xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2048x10240xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
index 56cf46e..69c8496 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx942.mlir
@@ -373,7 +373,14 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
%5 = tensor.empty() : tensor<256x256xi32>
%6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
- %7 = linalg.matmul_transpose_b {lowering_config = #config} ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ {lowering_config = #config}
+ ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<256x256xi32>>
return
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx950.mlir
index ca93cf0..cb46a69 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx950.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_gfx950.mlir
@@ -373,7 +373,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<256x256xi8>
%5 = tensor.empty() : tensor<256x256xi32>
%6 = linalg.fill ins(%cst : i32) outs(%5 : tensor<256x256xi32>) -> tensor<256x256xi32>
- %7 = linalg.matmul_transpose_b {lowering_config = #config} ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+      {lowering_config = #config}
+      ins(%3, %4 : tensor<256x256xi8>, tensor<256x256xi8>) outs(%6 : tensor<256x256xi32>) -> tensor<256x256xi32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [256, 256], strides = [1, 1] : tensor<256x256xi32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<256x256xi32>>
return
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_warp_reduction.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_warp_reduction.mlir
index 1a000ae..6e82966 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_warp_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_warp_reduction.mlir
@@ -63,7 +63,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>> -> tensor<1280x1280xf32>
%5 = tensor.empty() : tensor<1x1280xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
- %7 = linalg.matmul_transpose_b ins(%3, %4 : tensor<1x1280xf32>, tensor<1280x1280xf32>) outs(%6 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%3, %4 : tensor<1x1280xf32>, tensor<1280x1280xf32>) outs(%6 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : tensor<1x1280xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
return
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
index 1f3770d..0919ace 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
@@ -335,16 +335,27 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
%5 = tensor.empty() : tensor<2x32000xf16>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<2x32000xf16>) -> tensor<2x32000xf16>
- %7 = linalg.matmul_transpose_b ins(%3, %4 : tensor<2x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<2x32000xf16>) -> tensor<2x32000xf16>
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%3, %4 : tensor<2x4096xf16>, tensor<32000x4096xf16>)
+ outs(%6 : tensor<2x32000xf16>) -> tensor<2x32000xf16>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [2, 32000], strides = [1, 1] : tensor<2x32000xf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x32000xf16>>
return
}
+// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 512]{{\]}}>
// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: func.func @skinny_mmt()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
-// CHECK: linalg.matmul_transpose_b
+// CHECK: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK-SAME: lowering_config = #[[$CONFIG]]
// -----
@@ -366,7 +377,14 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
%5 = tensor.empty() : tensor<32000x2xf16>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<32000x2xf16>) -> tensor<32000x2xf16>
- %7 = linalg.matmul_transpose_b ins(%4, %3 : tensor<32000x4096xf16>, tensor<2x4096xf16>) outs(%6 : tensor<32000x2xf16>) -> tensor<32000x2xf16>
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%4, %3 : tensor<32000x4096xf16>, tensor<2x4096xf16>)
+ outs(%6 : tensor<32000x2xf16>) -> tensor<32000x2xf16>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [32000, 2], strides = [1, 1] : tensor<32000x2xf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32000x2xf16>>
return
}
@@ -375,7 +393,7 @@
// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: func.func @skinny_mmt()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
-// CHECK: linalg.matmul_transpose_b
+// CHECK: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK-SAME: lowering_config = #[[$CONFIG]]
// -----
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir
index c5498e8..986a67e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_pipeline_generalize_named_ops.mlir
@@ -4,11 +4,12 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-codegen-rocdl-configuration-pipeline)" --iree-gpu-test-target=gfx942 \
// RUN: --split-input-file %s | FileCheck %s
-// Make sure that the GPU configuration pipelines generalize named ops, e.g., linalg.matmul_transpose_b to linalg.generic.
+// Make sure that the GPU configuration pipelines generalize named ops, e.g.,
+// matmul_transpose_b (a linalg.matmul with indexing_maps = [
+//   affine_map<(d0, d1, d2) -> (d0, d2)>,
+//   affine_map<(d0, d1, d2) -> (d1, d2)>,
+//   affine_map<(d0, d1, d2) -> (d0, d1)>])
+// to linalg.generic.
// CHECK: linalg.fill
// CHECK-NEXT: linalg.generic
-// CHECK-NOT: linalg.matmul_transpose_b
+// CHECK-NOT: linalg.matmul indexing_maps
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
@@ -27,7 +28,13 @@
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 1280], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1280x1280xf32>> -> tensor<1280x1280xf32>
%5 = tensor.empty() : tensor<1x1280xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
- %7 = linalg.matmul_transpose_b ins(%3, %4 : tensor<1x1280xf32>, tensor<1280x1280xf32>) outs(%6 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
+ %7 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%3, %4 : tensor<1x1280xf32>, tensor<1280x1280xf32>) outs(%6 : tensor<1x1280xf32>) -> tensor<1x1280xf32>
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 1280], strides = [1, 1] : tensor<1x1280xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x1280xf32>>
return
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
index 24df1a3..797726b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
@@ -948,4 +948,16 @@
return !foundNonScalar;
}
+bool isPureMatmul(Operation *op) {
+ auto matmulOp = dyn_cast_or_null<linalg::MatmulOp>(op);
+ return matmulOp &&
+ linalg::MatmulOp::isDefaultIndexingMaps(matmulOp.getIndexingMaps());
+}
+
+bool isPureBatchMatmul(Operation *op) {
+ auto batchMatmulOp = dyn_cast_or_null<linalg::BatchMatmulOp>(op);
+ return batchMatmulOp && linalg::BatchMatmulOp::isDefaultIndexingMaps(
+ batchMatmulOp.getIndexingMaps());
+}
+
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
index 6541be9..92d1590 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
@@ -236,5 +236,12 @@
/// either as inputs or as implicit captures.
bool hasOnlyScalarInputs(linalg::GenericOp op);
+/// Returns true if the operation is a pure MatmulOp (no transpose/broadcast).
+bool isPureMatmul(Operation *op);
+
+/// Returns true if the operation is a pure BatchMatmulOp (no
+/// transpose/broadcast).
+bool isPureBatchMatmul(Operation *op);
+
} // namespace mlir::iree_compiler::IREE::LinalgExt
#endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
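For context, a minimal sketch of how these helpers are intended to be used,
modeled on the `KernelDispatch.cpp` hunk above (the `collectPureContractions`
wrapper is hypothetical, not part of this patch):

```cpp
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// With the transpose variants folded into linalg.matmul, matching the op name
// alone no longer implies the default (M, K) x (K, N) layout; these helpers
// also verify that the indexing maps are the default ones.
static void collectPureContractions(func::FuncOp funcOp,
                                    SmallVectorImpl<Operation *> &found) {
  funcOp.walk([&](linalg::LinalgOp linalgOp) {
    Operation *rawOp = linalgOp.getOperation();
    if (iree_compiler::IREE::LinalgExt::isPureMatmul(rawOp) ||
        iree_compiler::IREE::LinalgExt::isPureBatchMatmul(rawOp))
      found.push_back(rawOp);
  });
}
```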
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index f1e1f84..031482c 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -603,18 +603,27 @@
%12 = arith.extui %in : i8 to i32
linalg.yield %12 : i32
} -> tensor<?x?x?xi32>
- %op = linalg.batch_matmul_transpose_b
+ %op = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
ins(%dequant, %rhs : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
outs(%init : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
util.return %op : tensor<?x?x?xi32>
}
+// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-LABEL: func public @broadcasting_dequant_op(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xi8>
// CHECK-NOT: flow.dispatch.region
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK: %[[RETURN:.+]] = flow.dispatch.region
-// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul_transpose_b
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK-SAME: ins(%[[GENERIC]],
// CHECK: flow.return %[[MATMUL]]
// CHECK: return %[[RETURN]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir
index 10ee965..7c3f006 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fuse_horizontal_contractions.mlir
@@ -142,7 +142,12 @@
^bb0(%in: i8, %out: i8):
linalg.yield %in : i8
} -> tensor<2x640x640xi8>
- %5 = linalg.batch_matmul_transpose_b
+ %5 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
ins(%arg0, %rhs0 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>)
outs(%2 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
%6 = linalg.generic {
@@ -182,7 +187,12 @@
^bb0(%in: i8, %out: i8):
linalg.yield %in : i8
} -> tensor<2x640x640xi8>
- %8 = linalg.batch_matmul_transpose_b
+ %8 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
ins(%arg0, %rhs1 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>)
outs(%2 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
%9 = linalg.generic {
@@ -234,8 +244,20 @@
%0 = tensor.empty() : tensor<2x4096x640xf16>
%1 = tensor.empty() : tensor<2x4096x640xi32>
%2 = linalg.fill ins(%c0_i32 : i32) outs(%1 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
- %3 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%2 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
- %4 = linalg.batch_matmul_transpose_b ins(%arg0, %arg4 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%2 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
+ %3 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%2 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
+ %4 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg4 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%2 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %3 : tensor<2x4096x640xi32>, tensor<2x4096x640xi32>) outs(%0 : tensor<2x4096x640xf16>) {
^bb0(%in: i32, %in_0: i32, %out: f16):
%6 = arith.sitofp %in : i32 to f32
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir b/compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir
index 2695044..02f64cc 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/pipeline_tests.mlir
@@ -330,13 +330,24 @@
%result = iree_linalg_ext.gather dimension_map = [1, 0]
ins(%source, %indices : tensor<20x20x100xi32>, tensor<100x2xi32>)
outs(%empty: tensor<100x100xi32>) -> tensor<100x100xi32>
- %mm = linalg.matmul_transpose_b ins(%result, %arg2 : tensor<100x100xi32>, tensor<100x100xi32>) outs(%arg3 : tensor<100x100xi32>) -> tensor<100x100xi32>
+ %mm = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%result, %arg2 : tensor<100x100xi32>, tensor<100x100xi32>)
+ outs(%arg3 : tensor<100x100xi32>) -> tensor<100x100xi32>
util.return %mm : tensor<100x100xi32>
}
// CHECK-LABEL: util.func public @gather_matmul
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.workgroups
// CHECK: %[[GATHER:.+]] = iree_linalg_ext.gather
-// CHECK: %[[MATMUL:.+]] = linalg.matmul_transpose_b
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: indexing_maps = [
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d1, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>]
// CHECK-SAME: ins(%[[GATHER]]
// CHECK: iree_tensor_ext.dispatch.tensor.store %[[MATMUL]]
// CHECK: util.return %[[DISPATCH]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
index 3a777cd..5161ab1 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/set_encoding.mlir
@@ -811,7 +811,13 @@
util.func public @matmul_transpose_a_f32f32f32(%arg0 : tensor<250x100xf32>, %arg1 : tensor<250x500xf32>,
%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
- %0 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<250x100xf32>, tensor<250x500xf32>)
+ %0 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<250x100xf32>, tensor<250x500xf32>)
outs(%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32>
util.return %0 : tensor<100x500xf32>
}
@@ -832,7 +838,8 @@
// CHECK-SAME: tensor<250x500xf32, #[[RHS_ENCODING]]>
// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-SAME: tensor<100x500xf32, #[[OUT_ENCODING]]>
-// CHECK: %[[MATMUL:.+]] = linalg.matmul_transpose_a
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]]
@@ -842,7 +849,13 @@
util.func public @matmul_transpose_b_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<500x250xf32>,
%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
- %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<100x250xf32>, tensor<500x250xf32>)
+ %0 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<100x250xf32>, tensor<500x250xf32>)
outs(%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32>
util.return %0 : tensor<100x500xf32>
}
@@ -862,7 +875,8 @@
// CHECK-SAME: tensor<500x250xf32, #[[RHS_ENCODING]]>
// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-SAME: tensor<100x500xf32, #[[OUT_ENCODING]]>
-// CHECK: %[[MATMUL:.+]] = linalg.matmul_transpose_b
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]]
@@ -872,7 +886,13 @@
util.func public @batch_matmul_transpose_a_f32f32f32(%arg0 : tensor<2x250x100xf32>, %arg1 : tensor<2x250x500xf32>,
%arg2 : tensor<2x100x500xf32>) -> tensor<2x100x500xf32> {
- %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : tensor<2x250x100xf32>, tensor<2x250x500xf32>)
+ %0 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<2x250x100xf32>, tensor<2x250x500xf32>)
outs(%arg2 : tensor<2x100x500xf32>) -> tensor<2x100x500xf32>
util.return %0 : tensor<2x100x500xf32>
}
@@ -892,7 +912,8 @@
// CHECK-SAME: tensor<2x250x500xf32, #[[RHS_ENCODING]]>
// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-SAME: tensor<2x100x500xf32, #[[OUT_ENCODING]]>
-// CHECK: %[[BATCH_MATMUL:.+]] = linalg.batch_matmul_transpose_a
+// CHECK: %[[BATCH_MATMUL:.+]] = linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[BATCH_MATMUL]]
@@ -902,7 +923,13 @@
util.func public @batch_matmul_transpose_b_f32f32f32(%arg0 : tensor<2x100x250xf32>, %arg1 : tensor<2x500x250xf32>,
%arg2 : tensor<2x100x500xf32>) -> tensor<2x100x500xf32> {
- %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : tensor<2x100x250xf32>, tensor<2x500x250xf32>)
+ %0 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<2x100x250xf32>, tensor<2x500x250xf32>)
outs(%arg2 : tensor<2x100x500xf32>) -> tensor<2x100x500xf32>
util.return %0 : tensor<2x100x500xf32>
}
@@ -922,7 +949,8 @@
// CHECK-SAME: tensor<2x500x250xf32, #[[RHS_ENCODING]]>
// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-SAME: tensor<2x100x500xf32, #[[OUT_ENCODING]]>
-// CHECK: %[[BATCH_MATMUL:.+]] = linalg.batch_matmul_transpose_b
+// CHECK: %[[BATCH_MATMUL:.+]] = linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[BATCH_MATMUL]]
@@ -1092,7 +1120,13 @@
linalg.yield %14 : i32
} -> tensor<?x?x?xi32>
%12 = linalg.fill ins(%c0_i32_0 : i32) outs(%9 : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
- %13 = linalg.batch_matmul_transpose_b ins(%11, %6 : tensor<?x?x?xi32>, tensor<?x?x?xi32>) outs(%12 : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+ %13 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%11, %6 : tensor<?x?x?xi32>, tensor<?x?x?xi32>) outs(%12 : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
flow.return %13 : tensor<?x?x?xi32>
}
util.return %7 : tensor<?x?x?xi32>
@@ -1121,7 +1155,8 @@
// CHECK-SAME: -> tensor<?x?x?xi32, #[[RHS_ENCODING]]>
// CHECK: %[[INIT:.+]] = tensor.empty({{.+}}) : tensor<?x?x?xi32, #[[OUT_ENCODING]]>
// CHECK: %[[FILL:.+]] = linalg.fill ins({{.+}}) outs(%[[INIT]]
-// CHECK: %[[GEMM:.+]] = linalg.batch_matmul_transpose_b
+// CHECK: %[[GEMM:.+]] = linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[FILL]]
// CHECK: %[[UNSET:.+]] = iree_encoding.unset_encoding %[[GEMM]]{{.+}} -> tensor<?x?x?xi32>{%[[ARG1_D0]], %[[ARG0_D0]], %[[ARG1_D1]]}
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir b/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
index f91ddac..3299462 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/sink_reshapes.mlir
@@ -17,7 +17,13 @@
%0 = tensor.empty(%m, %n) : tensor<?x?xf32>
%m_by_2 = arith.divsi %m, %c2 : index
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>)
+ %2 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>)
outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = tensor.expand_shape %2 [[0, 1], [2]] output_shape [2, %m, %n]: tensor<?x?xf32> into tensor<2x?x?xf32>
%4 = tensor.empty(%m_by_2, %n) : tensor<2x?x?xf16>
@@ -59,8 +65,11 @@
func.return %8 : tensor<2x?x?xf16>
}
+// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @do_not_sink_across_already_fusable_ops
-// CHECK: %[[GEMM:.+]] = linalg.matmul_transpose_b
+// CHECK: %[[GEMM:.+]] = linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GEMM]],
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[GENERIC1]]
@@ -134,7 +143,13 @@
// -> consumer are fusable.
func.func @better_producer_estimate(%lhs : tensor<2x4096x640xi32>, %rhs : tensor<2x640x640xi32>,
%fill0 : tensor<2x4096x640xi32>, %fill1 : tensor<2x4096xi32>) -> tensor<2x4096x640x1xf16> {
- %bmm = linalg.batch_matmul_transpose_b ins(%lhs, %rhs : tensor<2x4096x640xi32>, tensor<2x640x640xi32>)
+ %bmm = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%lhs, %rhs : tensor<2x4096x640xi32>, tensor<2x640x640xi32>)
outs(%fill0 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
%reduction = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
@@ -162,8 +177,11 @@
} -> tensor<2x4096x640x1xf16>
return %quant : tensor<2x4096x640x1xf16>
}
+// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-LABEL: func @better_producer_estimate(
-// CHECK: %[[BMM:.+]] = linalg.batch_matmul_transpose_b
+// CHECK: %[[BMM:.+]] = linalg.batch_matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// CHECK: %[[REDUCTION:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK: %[[GENERIC:.+]] = linalg.generic
diff --git a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
index 9193db6..df83d07 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DemoteContractionInputsToBF16.cpp
@@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
@@ -78,7 +79,8 @@
})
->getResults()[0]);
}
- auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(linalgOp);
+ auto namedOp = cast<std::remove_pointer_t<decltype(typePtr)>>(
+ linalgOp.getOperation());
rewriter.replaceOpWithNewOp<std::remove_pointer_t<decltype(typePtr)>>(
linalgOp, demotedInputs, linalgOp.getDpsInits(),
linalg::getPrunedAttributeList(namedOp));
@@ -90,37 +92,38 @@
bool demoteConv = (demoteOption == DemotionOption::All) ||
(demoteOption == DemotionOption::Conv);
- if (demoteMatmul && isa<linalg::MatmulOp>(linalgOp)) {
+ Operation *op = linalgOp.getOperation();
+ if (demoteMatmul && IREE::LinalgExt::isPureMatmul(op)) {
replaceOpInputs(static_cast<linalg::MatmulOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::MatvecOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::MatvecOp>(op)) {
replaceOpInputs(static_cast<linalg::MatvecOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::VecmatOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::VecmatOp>(op)) {
replaceOpInputs(static_cast<linalg::VecmatOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::BatchMatmulOp>(linalgOp)) {
+ } else if (demoteMatmul && IREE::LinalgExt::isPureBatchMatmul(op)) {
replaceOpInputs(static_cast<linalg::BatchMatmulOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::BatchMatvecOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchMatvecOp>(op)) {
replaceOpInputs(static_cast<linalg::BatchMatvecOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::BatchVecmatOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchVecmatOp>(op)) {
replaceOpInputs(static_cast<linalg::BatchVecmatOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::MatmulTransposeAOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::MatmulTransposeAOp>(op)) {
replaceOpInputs(static_cast<linalg::MatmulTransposeAOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::MatmulTransposeBOp>(op)) {
replaceOpInputs(static_cast<linalg::MatmulTransposeBOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeAOp>(op)) {
replaceOpInputs(static_cast<linalg::BatchMatmulTransposeAOp *>(nullptr));
- } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
+ } else if (demoteMatmul && isa<linalg::BatchMatmulTransposeBOp>(op)) {
replaceOpInputs(static_cast<linalg::BatchMatmulTransposeBOp *>(nullptr));
- } else if (demoteConv && isa<linalg::Conv2DOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DOp>(op)) {
replaceOpInputs(static_cast<linalg::Conv2DOp *>(nullptr));
- } else if (demoteConv && isa<linalg::Conv2DNchwFchwOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNchwFchwOp>(op)) {
replaceOpInputs(static_cast<linalg::Conv2DNchwFchwOp *>(nullptr));
- } else if (demoteConv && isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNhwcHwcfOp>(op)) {
replaceOpInputs(static_cast<linalg::Conv2DNhwcHwcfOp *>(nullptr));
- } else if (demoteConv && isa<linalg::Conv2DNhwcFhwcOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNhwcFhwcOp>(op)) {
replaceOpInputs(static_cast<linalg::Conv2DNhwcFhwcOp *>(nullptr));
- } else if (demoteConv && isa<linalg::Conv2DNgchwFgchwOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNgchwFgchwOp>(op)) {
replaceOpInputs(static_cast<linalg::Conv2DNgchwFgchwOp *>(nullptr));
- } else if (demoteConv && isa<linalg::Conv2DNgchwGfchwOp>(linalgOp)) {
+ } else if (demoteConv && isa<linalg::Conv2DNgchwGfchwOp>(op)) {
replaceOpInputs(static_cast<linalg::Conv2DNgchwGfchwOp *>(nullptr));
} else {
return failure();
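Note: the demotion dispatch above now routes matmuls through IREE::LinalgExt::isPureMatmul / isPureBatchMatmul rather than bare isa<> checks, since a linalg.matmul may now carry non-default indexing_maps. A sketch of the distinction, assuming the helpers accept only the default maps (consistent with the separate transpose branches kept above); all value names are illustrative:

  // Presumably matched by isPureMatmul: default maps, no transpose.
  %m0 = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
          outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>
  // Presumably rejected: the same op with transposed-B maps.
  %m1 = linalg.matmul
          indexing_maps = [
            affine_map<(d0, d1, d2) -> (d0, d2)>,
            affine_map<(d0, d1, d2) -> (d1, d2)>,
            affine_map<(d0, d1, d2) -> (d0, d1)>
          ]
          ins(%a, %bt : tensor<4x8xf32>, tensor<16x8xf32>)
          outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>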
diff --git a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
index 7e370d5..da79610 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/PropagateLinalgTranspose.cpp
@@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "llvm/Support/Debug.h"
@@ -892,7 +893,7 @@
namespace {
-template <typename OpTy, typename ReplTy, int64_t inputIdx>
+template <typename OpTy, int64_t inputIdx>
class NamedOpConversion : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -920,8 +921,32 @@
SmallVector<NamedAttribute> attrs = getPrunedAttributeList(namedOp);
SmallVector<Value> newInputs = namedOp.getInputs();
newInputs[inputIdx] = transpose.getInput();
- rewriter.replaceOpWithNewOp<ReplTy>(namedOp, newInputs,
- namedOp.getDpsInits(), attrs);
+
+ auto replaceOp = [&](auto *typePtr) {
+ rewriter.replaceOpWithNewOp<std::remove_pointer_t<decltype(typePtr)>>(
+ namedOp, newInputs, namedOp.getDpsInits(), attrs);
+ };
+
+ Operation *op = namedOp.getOperation();
+ if (isa<linalg::MatmulTransposeAOp>(op) && inputIdx == 0) {
+ replaceOp(static_cast<linalg::MatmulOp *>(nullptr));
+ } else if (isa<linalg::MatmulTransposeBOp>(op) && inputIdx == 1) {
+ replaceOp(static_cast<linalg::MatmulOp *>(nullptr));
+ } else if (IREE::LinalgExt::isPureMatmul(op) && inputIdx == 0) {
+ replaceOp(static_cast<linalg::MatmulTransposeAOp *>(nullptr));
+ } else if (IREE::LinalgExt::isPureMatmul(op) && inputIdx == 1) {
+ replaceOp(static_cast<linalg::MatmulTransposeBOp *>(nullptr));
+ } else if (isa<linalg::BatchMatmulTransposeAOp>(op) && inputIdx == 0) {
+ replaceOp(static_cast<linalg::BatchMatmulOp *>(nullptr));
+ } else if (isa<linalg::BatchMatmulTransposeBOp>(op) && inputIdx == 1) {
+ replaceOp(static_cast<linalg::BatchMatmulOp *>(nullptr));
+ } else if (IREE::LinalgExt::isPureBatchMatmul(op) && inputIdx == 0) {
+ replaceOp(static_cast<linalg::BatchMatmulTransposeAOp *>(nullptr));
+ } else if (IREE::LinalgExt::isPureBatchMatmul(op) && inputIdx == 1) {
+ replaceOp(static_cast<linalg::BatchMatmulTransposeBOp *>(nullptr));
+ } else {
+ return failure();
+ }
return success();
}
@@ -957,46 +982,18 @@
static void populateNamedOpSinkingPatterns(MLIRContext *context,
RewritePatternSet &sinkingPatterns) {
- sinkingPatterns
- .insert<NamedOpConversion</*OpType=*/linalg::MatmulOp,
- /*ReplacementType=*/linalg::MatmulTransposeBOp,
- /*inputIdx=*/1>>(context,
- SmallVector<int64_t>{1, 0});
- sinkingPatterns
- .insert<NamedOpConversion</*OpType=*/linalg::MatmulOp,
- /*ReplacementType=*/linalg::MatmulTransposeAOp,
- /*inputIdx=*/0>>(context,
- SmallVector<int64_t>{1, 0});
- sinkingPatterns
- .insert<NamedOpConversion</*OpType=*/linalg::MatmulTransposeBOp,
- /*ReplacementType=*/linalg::MatmulOp,
- /*inputIdx=*/1>>(context,
- SmallVector<int64_t>{1, 0});
- sinkingPatterns
- .insert<NamedOpConversion</*OpType=*/linalg::MatmulTransposeAOp,
- /*ReplacementType=*/linalg::MatmulOp,
- /*inputIdx=*/0>>(context,
- SmallVector<int64_t>{1, 0});
- sinkingPatterns.insert<
- NamedOpConversion</*OpType=*/linalg::BatchMatmulOp,
- /*ReplacementType=*/linalg::BatchMatmulTransposeBOp,
- /*inputIdx=*/1>>(context,
- SmallVector<int64_t>{0, 2, 1});
- sinkingPatterns.insert<
- NamedOpConversion</*OpType=*/linalg::BatchMatmulOp,
- /*ReplacementType=*/linalg::BatchMatmulTransposeAOp,
- /*inputIdx=*/0>>(context,
- SmallVector<int64_t>{0, 2, 1});
- sinkingPatterns
- .insert<NamedOpConversion</*OpType=*/linalg::BatchMatmulTransposeBOp,
- /*ReplacementType=*/linalg::BatchMatmulOp,
- /*inputIdx=*/1>>(context,
- SmallVector<int64_t>{0, 2, 1});
- sinkingPatterns
- .insert<NamedOpConversion</*OpType=*/linalg::BatchMatmulTransposeAOp,
- /*ReplacementType=*/linalg::BatchMatmulOp,
- /*inputIdx=*/0>>(context,
- SmallVector<int64_t>{0, 2, 1});
+ sinkingPatterns.insert<NamedOpConversion</*OpType=*/linalg::MatmulOp,
+ /*inputIdx=*/1>>(
+ context, SmallVector<int64_t>{1, 0});
+ sinkingPatterns.insert<NamedOpConversion</*OpType=*/linalg::MatmulOp,
+ /*inputIdx=*/0>>(
+ context, SmallVector<int64_t>{1, 0});
+ sinkingPatterns.insert<NamedOpConversion</*OpType=*/linalg::BatchMatmulOp,
+ /*inputIdx=*/1>>(
+ context, SmallVector<int64_t>{0, 2, 1});
+ sinkingPatterns.insert<NamedOpConversion</*OpType=*/linalg::BatchMatmulOp,
+ /*inputIdx=*/0>>(
+ context, SmallVector<int64_t>{0, 2, 1});
}
static void
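Note: dropping the ReplTy template parameter works because the replacement op is now derivable from OpTy and inputIdx alone. The net effect of sinking a transpose into the RHS of a default matmul, sketched with illustrative names and shapes (the propagate_linalg_transpose.mlir tests below exercise exactly this rewrite):

  // Before: an explicit transpose feeds the matmul RHS.
  %bt = linalg.transpose ins(%b : tensor<16x8xf32>)
          outs(%e : tensor<8x16xf32>) permutation = [1, 0]
  %r0 = linalg.matmul ins(%a, %bt : tensor<4x8xf32>, tensor<8x16xf32>)
          outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>
  // After: the transpose is absorbed into the RHS indexing map.
  %r1 = linalg.matmul
          indexing_maps = [
            affine_map<(d0, d1, d2) -> (d0, d2)>,
            affine_map<(d0, d1, d2) -> (d1, d2)>,
            affine_map<(d0, d1, d2) -> (d0, d1)>
          ]
          ins(%a, %b : tensor<4x8xf32>, tensor<16x8xf32>)
          outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>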
diff --git a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
index 763cc93..a17e9e6 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/RaiseSpecialOps.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/GlobalOptimization/Passes.h"
#include "iree/compiler/GlobalOptimization/Utils.h"
#include "llvm/ADT/STLExtras.h"
@@ -448,7 +449,7 @@
// Method to match a linalg.matmul(a, linalg.transpose(b)). Returns `b` on
// success.
static std::optional<Value> matchATransposeBMatmul(linalg::LinalgOp matmulOp) {
- if (!isa<linalg::MatmulOp>(matmulOp.getOperation())) {
+ if (!IREE::LinalgExt::isPureMatmul(matmulOp)) {
return std::nullopt;
}
auto rhs = matmulOp.getDpsInputOperand(1);
@@ -463,7 +464,7 @@
// success.
static std::optional<Value>
matchATransposeBBatchMatmul(linalg::LinalgOp bmmOp) {
- if (!isa<linalg::BatchMatmulOp>(bmmOp.getOperation())) {
+ if (!IREE::LinalgExt::isPureBatchMatmul(bmmOp)) {
return std::nullopt;
}
auto rhs = bmmOp.getDpsInputOperand(1);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
index 387c8a0..d218508 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/demote_contraction_inputs_to_bf16.mlir
@@ -133,10 +133,19 @@
util.func public @batch_matmul_transpose_a_f32f32f32(%arg0 : tensor<4x250x100xf32>, %arg1 : tensor<4x250x500xf32>,
%arg2 : tensor<4x100x500xf32>) -> tensor<4x100x500xf32> {
- %0 = linalg.batch_matmul_transpose_a ins(%arg0, %arg1 : tensor<4x250x100xf32>, tensor<4x250x500xf32>)
+ %0 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<4x250x100xf32>, tensor<4x250x500xf32>)
outs(%arg2 : tensor<4x100x500xf32>) -> tensor<4x100x500xf32>
util.return %0 : tensor<4x100x500xf32>
}
+// MATMUL-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// MATMUL-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// MATMUL-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// MATMUL: @batch_matmul_transpose_a_f32f32f32
// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x250x100xf32>
@@ -148,7 +157,8 @@
// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x250x500xf32>)
// MATMUL: arith.truncf {{.*}} : f32 to bf16
-// MATMUL: linalg.batch_matmul_transpose_a
+// MATMUL: linalg.batch_matmul
+// MATMUL-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x250x100xbf16>, tensor<4x250x500xbf16>)
// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
@@ -156,10 +166,19 @@
util.func public @batch_matmul_transpose_b_f32f32f32(%arg0 : tensor<4x100x250xf32>, %arg1 : tensor<4x500x250xf32>,
%arg2 : tensor<4x100x500xf32>) -> tensor<4x100x500xf32> {
- %0 = linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : tensor<4x100x250xf32>, tensor<4x500x250xf32>)
+ %0 = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : tensor<4x100x250xf32>, tensor<4x500x250xf32>)
outs(%arg2 : tensor<4x100x500xf32>) -> tensor<4x100x500xf32>
util.return %0 : tensor<4x100x500xf32>
}
+// MATMUL-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// MATMUL-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// MATMUL-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// MATMUL: @batch_matmul_transpose_b_f32f32f32
// MATMUL-SAME: %[[ARG0:.+]]: tensor<4x100x250xf32>
@@ -171,7 +190,8 @@
// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
// MATMUL-SAME: ins(%[[ARG1]] : tensor<4x500x250xf32>)
// MATMUL: arith.truncf {{.*}} : f32 to bf16
-// MATMUL: linalg.batch_matmul_transpose_b
+// MATMUL: linalg.batch_matmul
+// MATMUL-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<4x100x250xbf16>, tensor<4x500x250xbf16>)
// MATMUL-SAME: outs(%[[ARG2]] : tensor<4x100x500xf32>)
@@ -179,10 +199,19 @@
util.func public @matmul_transpose_a_f32f32f32(%arg0 : tensor<250x100xf32>, %arg1 : tensor<250x500xf32>,
%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
- %0 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<250x100xf32>, tensor<250x500xf32>)
+ %0 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<250x100xf32>, tensor<250x500xf32>)
outs(%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32>
util.return %0 : tensor<100x500xf32>
}
+// MATMUL-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// MATMUL-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// MATMUL-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// MATMUL: @matmul_transpose_a_f32f32f32
// MATMUL-SAME: %[[ARG0:.+]]: tensor<250x100xf32>
@@ -194,7 +223,8 @@
// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
// MATMUL-SAME: ins(%[[ARG1]] : tensor<250x500xf32>)
// MATMUL: arith.truncf {{.*}} : f32 to bf16
-// MATMUL: linalg.matmul_transpose_a
+// MATMUL: linalg.matmul
+// MATMUL-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<250x100xbf16>, tensor<250x500xbf16>)
// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
@@ -202,10 +232,19 @@
util.func public @matmul_transpose_b_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<500x250xf32>,
%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
- %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<100x250xf32>, tensor<500x250xf32>)
+ %0 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<100x250xf32>, tensor<500x250xf32>)
outs(%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32>
util.return %0 : tensor<100x500xf32>
}
+// MATMUL-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// MATMUL-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// MATMUL-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// MATMUL: @matmul_transpose_b_f32f32f32
// MATMUL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
@@ -217,7 +256,8 @@
// MATMUL: %[[DEMOTED1:.+]] = linalg.generic
// MATMUL-SAME: ins(%[[ARG1]] : tensor<500x250xf32>)
// MATMUL: arith.truncf {{.*}} : f32 to bf16
-// MATMUL: linalg.matmul_transpose_b
+// MATMUL: linalg.matmul
+// MATMUL-SAME: indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// MATMUL-SAME: ins(%[[DEMOTED0]], %[[DEMOTED1]] : tensor<100x250xbf16>, tensor<500x250xbf16>)
// MATMUL-SAME: outs(%[[ARG2]] : tensor<100x500xf32>)
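Note: the MATMUL checks above verify that demotion inserts an elementwise truncation per input and leaves the contraction's indexing maps intact. A sketch of one demoted input, with illustrative names:

  %demoted = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<100x250xf32>) outs(%empty : tensor<100x250xbf16>) {
    ^bb0(%in: f32, %out: bf16):
      %t = arith.truncf %in : f32 to bf16
      linalg.yield %t : bf16
  } -> tensor<100x250xbf16>

The matmul then consumes the bf16 operands while still accumulating into the original f32 init tensor.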
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
index ba373f0..ae22d08 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/propagate_linalg_transpose.mlir
@@ -123,9 +123,16 @@
outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
util.return %second_mm : tensor<16x16xf32>
}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-LABEL: util.func public @propagate_to_matmul_ops
-// CHECK: linalg.matmul_transpose_b
-// CHECK: %[[SECOND_MM:.+]] = linalg.matmul_transpose_a
+// CHECK: linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK: %[[SECOND_MM:.+]] = linalg.matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP2]]]
// CHECK: util.return %[[SECOND_MM]]
// -----
@@ -136,13 +143,25 @@
%empty = tensor.empty(): tensor<16x16xf32>
%transpose_b = linalg.transpose ins(%rhs : tensor<16x16xf32>)
outs(%empty : tensor<16x16xf32>) permutation = [1, 0]
- %first_mm = linalg.matmul_transpose_b ins(%lhs, %transpose_b : tensor<16x16xf32>, tensor<16x16xf32>)
- outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %first_mm = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%lhs, %transpose_b : tensor<16x16xf32>, tensor<16x16xf32>)
+ outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
%transpose_a = linalg.transpose ins(%second_lhs : tensor<16x16xf32>)
outs(%empty : tensor<16x16xf32>) permutation = [1, 0]
- %second_mm = linalg.matmul_transpose_a ins(%transpose_a, %first_mm : tensor<16x16xf32>, tensor<16x16xf32>)
- outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %second_mm = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%transpose_a, %first_mm : tensor<16x16xf32>, tensor<16x16xf32>)
+ outs(%empty : tensor<16x16xf32>) -> tensor<16x16xf32>
util.return %second_mm : tensor<16x16xf32>
}
// CHECK-LABEL: util.func public @propagate_to_transposed_matmul_ops
@@ -167,9 +186,16 @@
outs(%empty : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
util.return %second_bmm : tensor<2x16x16xf32>
}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-LABEL: util.func public @propagate_to_bmm_ops
-// CHECK: linalg.batch_matmul_transpose_b
-// CHECK: %[[SECOND_MM:.+]] = linalg.batch_matmul_transpose_a
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK: %[[SECOND_MM:.+]] = linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP2]]]
// CHECK: util.return %[[SECOND_MM]]
// -----
@@ -180,13 +206,25 @@
%empty = tensor.empty(): tensor<2x16x16xf32>
%transpose_b = linalg.transpose ins(%rhs : tensor<2x16x16xf32>)
outs(%empty : tensor<2x16x16xf32>) permutation = [0, 2, 1]
- %first_bmm = linalg.batch_matmul_transpose_b ins(%lhs, %transpose_b : tensor<2x16x16xf32>, tensor<2x16x16xf32>)
- outs(%empty : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+ %first_bmm = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%lhs, %transpose_b : tensor<2x16x16xf32>, tensor<2x16x16xf32>)
+ outs(%empty : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
%transpose_a = linalg.transpose ins(%second_lhs : tensor<2x16x16xf32>)
outs(%empty : tensor<2x16x16xf32>) permutation = [0, 2, 1]
- %second_bmm = linalg.batch_matmul_transpose_a ins(%transpose_a, %first_bmm : tensor<2x16x16xf32>, tensor<2x16x16xf32>)
- outs(%empty : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+ %second_bmm = linalg.batch_matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%transpose_a, %first_bmm : tensor<2x16x16xf32>, tensor<2x16x16xf32>)
+ outs(%empty : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
util.return %second_bmm : tensor<2x16x16xf32>
}
// CHECK-LABEL: util.func public @propagate_to_transposed_bmm_ops
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir
index 60cdd31..9313cf0 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/raise_special_ops.mlir
@@ -188,7 +188,11 @@
// CHECK-LABEL: util.func public @aTransposeBMatmul
// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<40x20xf32>
-// CHECK: %[[RESULT:.+]] = linalg.matmul_transpose_b
+// CHECK: %[[RESULT:.+]] = linalg.matmul
+// CHECK-SAME: indexing_maps = [
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d1, d2)>,
+// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK: util.return %[[RESULT]]
@@ -214,7 +218,11 @@
// CHECK-LABEL: util.func public @aTransposeBBatchMatmul
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x10x20xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<5x40x20xf32>
-// CHECK: %[[RESULT:.+]] = linalg.batch_matmul_transpose_b
+// CHECK: %[[RESULT:.+]] = linalg.batch_matmul
+// CHECK-SAME: indexing_maps = [
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK: util.return %[[RESULT]]
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp
index a533339..788716a 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp
@@ -30,9 +30,7 @@
if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) {
return;
}
- if (isa_and_nonnull<linalg::MatmulOp, linalg::MatmulTransposeBOp,
- linalg::BatchMatmulOp,
- linalg::BatchMatmulTransposeBOp>(linalgOp)) {
+ if (isa_and_nonnull<linalg::MatmulOp, linalg::BatchMatmulOp>(linalgOp)) {
namedOpCandidates.push_back(linalgOp);
}
});
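Note: with the transpose variants folded into linalg.matmul and linalg.batch_matmul, the candidate list shrinks to two op types; a matmul carrying transposed indexing_maps generalizes to a linalg.generic with those same maps. A sketch of the generalized transpose-B form, names illustrative:

  %g = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%a, %b : tensor<10x20xf32>, tensor<40x20xf32>)
      outs(%acc : tensor<10x40xf32>) {
    ^bb0(%l: f32, %r: f32, %o: f32):
      %p = arith.mulf %l, %r : f32
      %s = arith.addf %o, %p : f32
      linalg.yield %s : f32
  } -> tensor<10x40xf32>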
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir
index 813952d..30b086d 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/conv_to_channels_last.mlir
@@ -125,18 +125,29 @@
%zero = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<2048x1280xf32>
%filled = linalg.fill ins(%zero : f32) outs(%empty : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
- %res = linalg.matmul_transpose_b
+ %res = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
ins(%arg0, %arg1 : tensor<2048x1280xf16>, tensor<1280x1280xf16>)
outs(%filled : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
util.return %res : tensor<2048x1280xf32>
}
+// CHECK-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: @mmt_no_transpose
// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul_transpose_b
+// CHECK: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
+// TILE16-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// TILE16-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// TILE16-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// TILE16-LABEL: @mmt_no_transpose
// TILE16-NOT: linalg.generic
-// TILE16: linalg.matmul_transpose_b
+// TILE16: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
// -----
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
index efa4a30..6f21af7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
@@ -236,7 +236,13 @@
// CHECK-SAME: %[[ARG2:.+]]: tensor<10x4xf16>)
func.func @skip_skinny_n_mmtb(%arg0 : tensor<10x20xf16>, %arg1 : tensor<4x20xf16>, %arg2 : tensor<10x4xf16>) -> tensor<10x4xf16>
attributes {hal.device.targets = [#hal.device.target<"rocm", [#rocm_executable_target]>]} {
- %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<10x20xf16>, tensor<4x20xf16>)
+ %0 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<10x20xf16>, tensor<4x20xf16>)
outs(%arg2 : tensor<10x4xf16>) -> tensor<10x4xf16>
return %0 : tensor<10x4xf16>
}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
index ece0283..bfec8dd 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
@@ -43,7 +43,13 @@
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf16>,
// CHECK-SAME: %[[ARG2:.+]]: tensor<10x?xf16>)
func.func @mmtb_dynamic_k_n(%arg0 : tensor<10x?xf16>, %arg1 : tensor<?x?xf16>, %arg2 : tensor<10x?xf16>) -> tensor<10x?xf16> {
- %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<10x?xf16>, tensor<?x?xf16>)
+ %0 = linalg.matmul
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ]
+ ins(%arg0, %arg1 : tensor<10x?xf16>, tensor<?x?xf16>)
outs(%arg2 : tensor<10x?xf16>) -> tensor<10x?xf16>
return %0 : tensor<10x?xf16>
}
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir
index 49728ce..e61088b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir
@@ -217,13 +217,17 @@
// -----
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module attributes {transform.with_named_sequence} {
-
// CHECK: func.func @matmul_repeated_operand
func.func @matmul_repeated_operand(%input: tensor<32x64xi8>, %dest: tensor<32x32xi32>) -> tensor<32x32xi32> {
- // CHECK-NEXT: linalg.matmul_transpose_b
+ // CHECK-NEXT: linalg.matmul
// CHECK-SAME: match_status = "matched"
- %res = linalg.matmul_transpose_b {match_status = "unmatched"}
+ %res = linalg.matmul
+ indexing_maps = [#map0, #map1, #map2]
+ {match_status = "unmatched"}
ins(%input, %input : tensor<32x64xi8>, tensor<32x64xi8>)
outs(%dest : tensor<32x32xi32>) -> tensor<32x32xi32>
return %res : tensor<32x32xi32>
@@ -231,9 +235,11 @@
// CHECK: func.func @matmul_non_repeated_operand
func.func @matmul_non_repeated_operand(%input0: tensor<32x64xi8>, %input1: tensor<32x64xi8>, %dest: tensor<32x32xi32>) -> tensor<32x32xi32> {
- // CHECK-NEXT: linalg.matmul_transpose_b
+ // CHECK-NEXT: linalg.matmul
// CHECK-SAME: match_status = "unmatched"
- %res = linalg.matmul_transpose_b {match_status = "unmatched"}
+ %res = linalg.matmul
+ indexing_maps = [#map0, #map1, #map2]
+ {match_status = "unmatched"}
ins(%input0, %input1 : tensor<32x64xi8>, tensor<32x64xi8>)
outs(%dest : tensor<32x32xi32>) -> tensor<32x32xi32>
return %res : tensor<32x32xi32>
@@ -242,7 +248,9 @@
transform.named_sequence @match_matmul_repeated_operand(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
%inputs, %outputs = transform.iree.match.cast_compatible_dag_from_root %arg0 {
^bb0(%arg1: tensor<32x64xi8>, %arg2: tensor<32x32xi32>):
- %1 = linalg.matmul_transpose_b {match_status = "unmatched"}
+ %1 = linalg.matmul
+ indexing_maps = [#map0, #map1, #map2]
+ {match_status = "unmatched"}
ins(%arg1, %arg1 : tensor<32x64xi8>, tensor<32x64xi8>)
outs(%arg2 : tensor<32x32xi32>) -> tensor<32x32xi32>
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/transpose_matmul.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/transpose_matmul.mlir
index 8f62db5..cc3c68b 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/transpose_matmul.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/transpose_matmul.mlir
@@ -1,9 +1,16 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-preprocessing-transpose-matmul-pass{input=lhs}))" %s | FileCheck %s --check-prefixes=CHECK,LHS
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-preprocessing-transpose-matmul-pass{input=rhs}))" %s | FileCheck %s --check-prefixes=CHECK,RHS
+// LHS-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// LHS-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// LHS-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// RHS-DAG: #[[$MA:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// RHS-DAG: #[[$MB:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// RHS-DAG: #[[$MC:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
// CHECK-LABEL: @matmul
-// LHS: linalg.matmul_transpose_a
-// RHS: linalg.matmul_transpose_b
+// LHS: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
+// RHS: linalg.matmul indexing_maps = [#[[$MA]], #[[$MB]], #[[$MC]]]
func.func @matmul(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
%cst = arith.constant 0.0 : f32
%init = tensor.empty() : tensor<16x16xf32>
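Note: for @matmul above, the RHS variant of the pass presumably materializes a transpose of %B and rewrites the contraction to the transpose-B map form the CHECK lines expect (names other than %A and %B are illustrative):

  %Bt = linalg.transpose ins(%B : tensor<8x16xf32>)
          outs(%empty : tensor<16x8xf32>) permutation = [1, 0]
  %mm = linalg.matmul
          indexing_maps = [
            affine_map<(d0, d1, d2) -> (d0, d2)>,
            affine_map<(d0, d1, d2) -> (d1, d2)>,
            affine_map<(d0, d1, d2) -> (d0, d1)>
          ]
          ins(%A, %Bt : tensor<16x8xf32>, tensor<16x8xf32>)
          outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>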
diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py
index 95d990a..253c97b 100755
--- a/tests/e2e/matmul/generate_e2e_matmul_tests.py
+++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -719,7 +719,7 @@
acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"
if transpose_rhs:
- op_name = "linalg.matmul_transpose_b"
+ op_name = "linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]"
else:
op_name = "linalg.matmul"
diff --git a/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp16_rocm.json b/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp16_rocm.json
index 412dcc6..feab91d 100644
--- a/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp16_rocm.json
+++ b/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp16_rocm.json
@@ -32,7 +32,7 @@
}
],
"real_weights": "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/punet_weights.irpa",
- "mlir": "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp16.mlir",
+ "mlir": "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/08042025/mlir/stable_diffusion_xl_base_1_0_punet_bs1_64_1024x1024_i8.mlir",
"device": "hip",
"compiler_flags": [
"--iree-hal-target-device=hip",
diff --git a/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp8_rocm.json b/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp8_rocm.json
index 0e9d3f6..ba2d9f3 100644
--- a/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp8_rocm.json
+++ b/tests/external/iree-test-suites/sharktank_models/quality_tests/sdxl/punet_int8_fp8_rocm.json
@@ -32,7 +32,7 @@
}
],
"real_weights": "https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/punet_fp8_weights.irpa",
- "mlir": "https://sharkpublic.blob.core.windows.net/sharkpublic/stan/sdxl-punet/11-26-2024/punet_fp8.mlir",
+ "mlir": "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/08122025_transpose_manual/mlir/punet_fp8.mlir",
"device": "hip",
"compiler_flags": [
"--iree-hal-target-device=hip",
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 7caf2d7..9a14b1d 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 7caf2d7443737d3516cab8ac67083b18db0bf41b
+Subproject commit 9a14b1d254a43dc0d4445c3ffa3d393bca007ba3