Convert Einsum to use Matmul/BMM (#6653)
diff --git a/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel b/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel
index afeede3..687d9fc 100644
--- a/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel
+++ b/integrations/tensorflow/build_tools/overlay/mlir-hlo/BUILD.bazel
@@ -24,6 +24,7 @@
for name in [
"chlo_legalize_to_hlo",
"legalize_control_flow",
+ "legalize_einsum_to_dot_general",
"legalize_gather_to_torch_index_select",
"legalize_to_linalg",
"lhlo",
diff --git a/iree/compiler/InputConversion/MHLO/BUILD b/iree/compiler/InputConversion/MHLO/BUILD
index aaf1475..396ece1 100644
--- a/iree/compiler/InputConversion/MHLO/BUILD
+++ b/iree/compiler/InputConversion/MHLO/BUILD
@@ -88,6 +88,7 @@
"@llvm-project//mlir:Transforms",
"@mlir-hlo//:chlo_legalize_to_hlo",
"@mlir-hlo//:hlo",
+ "@mlir-hlo//:legalize_einsum_to_dot_general",
"@mlir-hlo//:legalize_gather_to_torch_index_select",
"@mlir-hlo//:legalize_to_linalg",
"@mlir-hlo//:map_lmhlo_to_scalar_op",
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
index 5e3655d..72a209d 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
@@ -842,6 +842,8 @@
}
OwningRewritePatternList patterns(&getContext());
+ // TODO: Remove once we have a general contraction to matmul pass.
+ mhlo::PopulateEinsumToDotGeneralPatterns(context, &patterns);
mhlo::PopulateUnfuseBatchNormPatterns(context, &patterns);
mhlo::PopulateComplexLoweringPatterns(context, &patterns);
mhlo::PopulateGatherToTorchIndexSelectPatterns(context, &patterns);