Remove transform.structured.lower_vectors
This op has been upstreamed as transform.lower_vectors.
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index b9b7628..276494a 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -689,5 +689,6 @@
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorTransformOps",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
index c3d7f00..a0197c5 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
@@ -70,28 +70,6 @@
let cppNamespace = "mlir::transform_ext";
}
-def LowerVectorsOp : Op<Transform_Dialect, "lower_vectors",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
- let description = [{Indicates that the vector operations in the entire
- module should be lowered to simpler primitives (multiple stages of lowering
- be executed at once).}];
-
- let arguments =
- (ins DefaultValuedAttr<I64ArrayAttr, "{0, 1, 2, 3, 4, 5, 6}">:$stages,
- DefaultValuedAttr<StrAttr, "\"outerproduct\"">:$contraction_lowering,
- DefaultValuedAttr<StrAttr, "\"innerparallel\"">:$multireduction_lowering,
- DefaultValuedAttr<StrAttr, "\"linalg-copy\"">:$split_transfers,
- DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers,
- DefaultValuedAttr<StrAttr, "\"eltwise\"">:$transpose_lowering,
- DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering
- );
-
- let assemblyFormat = "attr-dict";
- let cppNamespace = "mlir::transform_ext";
-}
-
def LowerToLLVMOp : Op<Transform_Dialect, "lower_to_llvm",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 82b3063..f7392d1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -1006,119 +1006,6 @@
}
//===---------------------------------------------------------------------===//
-// LowerVectorsOp
-//===---------------------------------------------------------------------===//
-
-/// Returns true of the numbered vector lowering stage is included into the list
-/// of stages specified on the given lowerVectors operation.
-static bool stageIncluded(int stage,
- transform_ext::LowerVectorsOp lowerVectorsOp) {
- for (auto s : lowerVectorsOp.getStages().getAsValueRange<IntegerAttr>()) {
- if (s.getSExtValue() == stage)
- return true;
- }
- return false;
-}
-
-// Applies the transformation specified by the given lower vectors operation
-/// to the given function.
-DiagnosedSilenceableFailure
-transform_ext::LowerVectorsOp::apply(mlir::transform::TransformResults &results,
- mlir::transform::TransformState &state) {
- MLIRContext *ctx = getContext();
- RewritePatternSet patterns(ctx);
-
- vector::VectorTransposeLowering vectorTransposeLowering =
- llvm::StringSwitch<vector::VectorTransposeLowering>(
- getTransposeLowering())
- .Case("eltwise", vector::VectorTransposeLowering::EltWise)
- .Case("flat_transpose", vector::VectorTransposeLowering::Flat)
- .Case("shuffle", vector::VectorTransposeLowering::Shuffle)
- .Default(vector::VectorTransposeLowering::EltWise);
- vector::VectorMultiReductionLowering vectorMultiReductionLowering =
- llvm::StringSwitch<vector::VectorMultiReductionLowering>(
- getMultireductionLowering())
- .Case("innerreduction",
- vector::VectorMultiReductionLowering::InnerReduction)
- .Default(vector::VectorMultiReductionLowering::InnerParallel);
- vector::VectorContractLowering vectorContractLowering =
- llvm::StringSwitch<vector::VectorContractLowering>(
- getContractionLowering())
- .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
- .Case("dot", vector::VectorContractLowering::Dot)
- .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
- .Default(vector::VectorContractLowering::OuterProduct);
- // TODO: fix the annoying name mismatch (vector-transfers vs vector-transfer).
- vector::VectorTransferSplit vectorTransferSplit =
- llvm::StringSwitch<vector::VectorTransferSplit>(getSplitTransfers())
- .Case("none", vector::VectorTransferSplit::None)
- .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
- .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
- .Default(vector::VectorTransferSplit::None);
-
- vector::VectorTransformsOptions vectorTransformOptions;
- vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering)
- .setVectorMultiReductionLowering(vectorMultiReductionLowering)
- .setVectorTransposeLowering(vectorTransposeLowering)
- .setVectorTransferSplit(vectorTransferSplit);
-
- VectorTransferToSCFOptions vectorTransferToSCFOptions =
- VectorTransferToSCFOptions().enableFullUnroll(getUnrollVectorTransfers());
-
- int maxTransferRank = 1;
-
- auto avx2LoweringOptions =
- x86vector::avx2::LoweringOptions().setTransposeOptions(
- x86vector::avx2::TransposeLoweringOptions()
- .lower4x8xf32(getTransposeAvx2Lowering())
- .lower8x8xf32(getTransposeAvx2Lowering()));
-
- // TODO: this is copy-pasta from LinalgStrategyLowerVectorsPass, shouldn't be.
- vector::populateVectorToVectorCanonicalizationPatterns(patterns);
- if (stageIncluded(1, *this)) {
- patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
- mlir::vector::ContractionOpToMatmulOpLowering,
- mlir::vector::ContractionOpLowering>(vectorTransformOptions,
- ctx);
- vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
- }
- if (stageIncluded(2, *this)) {
- vector::populateVectorMultiReductionLoweringPatterns(
- patterns, vectorTransformOptions.vectorMultiReductionLowering);
- }
- if (stageIncluded(3, *this)) {
- patterns.add<vector::VectorTransferFullPartialRewriter>(
- ctx, vectorTransformOptions);
- }
- if (stageIncluded(4, *this)) {
- vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
- }
- if (stageIncluded(5, *this)) {
- populateVectorToSCFConversionPatterns(
- patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank));
- }
- if (stageIncluded(6, *this)) {
- vector::populateVectorShapeCastLoweringPatterns(patterns);
- }
- if (stageIncluded(7, (*this))) {
- vector::populateVectorTransposeLoweringPatterns(patterns,
- vectorTransformOptions);
- if (getTransposeAvx2Lowering())
- x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
- patterns, avx2LoweringOptions, /*benefit=*/10);
- }
-
- // TODO: these transformations are currently not targeted at concrete ops.
- // LinalgTransformationFilter filter = makeTransformationFilter(target);
- if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(),
- std::move(patterns))))
- return DiagnosedSilenceableFailure::definiteFailure();
-
- // TODO: make composable...
- return DiagnosedSilenceableFailure::success();
-}
-
-//===---------------------------------------------------------------------===//
// MatchCallbackOp
//===---------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
index 518e04b..3097c74 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
@@ -19,8 +19,10 @@
transform.structured.vectorize %5
// CHECK: bufferize
bufferize
- // CHECK: lower_vectors {multireduction_lowering = "innerreduce"}
- lower_vectors { multireduction_lowering = "innerreduce"}
+ // CHECK: %[[FUNC:.*]] = transform.structured.match ops{["func.func"]} in %arg0
+ // CHECK: lower_vectors %[[FUNC]] {multireduction_lowering = "innerreduce"}
+ %6 = transform.structured.match ops{["func.func"]} in %arg0
+ transform.vector.lower_vectors %6 { multireduction_lowering = "innerreduce"}
// CHECK: lower_to_llvm
lower_to_llvm
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
index d0df60f..c832533 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
@@ -20,6 +20,7 @@
%2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
transform.structured.vectorize %2 { vectorize_padding }
bufferize
- lower_vectors { multireduction_lowering = "innerreduce"}
+ %3 = transform.structured.match ops{["func.func"]} in %module_op
+ transform.vector.lower_vectors %3 { multireduction_lowering = "innerreduce"}
lower_to_llvm
}
diff --git a/llvm-external-projects/iree-dialects/test/python/dialects/iree_structured_transform.py b/llvm-external-projects/iree-dialects/test/python/dialects/iree_structured_transform.py
deleted file mode 100644
index 57e76a8..0000000
--- a/llvm-external-projects/iree-dialects/test/python/dialects/iree_structured_transform.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# RUN: %PYTHON %s | FileCheck %s
-
-import iree.compiler.ir as ir
-import iree.compiler.dialects.transform.iree_structured as iree_structured_transform
-import iree.compiler._mlir_libs._ireeDialects.transform
-
-
-def constructAndPrintInModule(f):
- print("\nTEST:", f.__name__)
- with ir.Context() as ctx, ir.Location.unknown():
- iree.compiler._mlir_libs._ireeDialects.transform.register_dialect(ctx)
- module = ir.Module.create()
- with ir.InsertionPoint(module.body):
- f()
- print(module)
- return f
-
-
-# CHECK-LABEL: TEST: testLowerVectorsOp
-# CHECK: transform.lower_vectors {contraction_lowering = "outerproduct", multireduction_lowering = "innerparallel", split_transfers = "linalg-copy", stages = [1], transpose_avx2_lowering = false, transpose_lowering = "shuffle", unroll_vector_transfers = true}
-@constructAndPrintInModule
-def testLowerVectorsOp():
- op = iree_structured_transform.LowerVectorsOp(
- contraction_lowering="outerproduct",
- multireduction_lowering="innerparallel",
- split_transfers="linalg-copy",
- stages=[1],
- transpose_avx2_lowering=False,
- transpose_lowering="shuffle",
- unroll_vector_transfers=True)
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 3741c34..9b87ed5 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -28,6 +28,7 @@
MLIRSCFTransforms
MLIRTensorDialect
MLIRTransforms
+ MLIRVectorTransformOps
)
add_llvm_tool(iree-dialects-opt
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index a3021ce..d4d8826 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -28,6 +28,7 @@
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
@@ -67,7 +68,8 @@
mlir::pdl_interp::PDLInterpDialect,
mlir::scf::SCFDialect,
mlir::tensor::TensorDialect,
- mlir::transform::TransformDialect
+ mlir::transform::TransformDialect,
+ mlir::vector::VectorDialect
// clang-format on
>();
@@ -91,6 +93,7 @@
transform_ext::StructuredTransformOpsExtension>();
mlir::linalg::registerTransformDialectExtension(registry);
mlir::scf::registerTransformDialectExtension(registry);
+ mlir::vector::registerTransformDialectExtension(registry);
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,