Use upstream MLIR Transform dialect (#9057)
* Cherry-pick LLVM to add support for Tile Transform Op
3c2a74a3ae02d16e899e280953c055f92aa6cdaa
* Port transform ops to use upstream interfaces
MLIR upstream now includes the Transform dialect along with op
interfaces and the new extension mechanism. Port most transform ops to
use the upstream interface. The LinalgTransform dialect still contains
the "util" ops for scoped transformation, which arguably should have
been placed in a separate dialect to start with. These will be cleaned
up separately as a decision is required whether to keep them at all.
Notable changes in the naming scheme:
- `iree_linalg_transform.` goes away, general `transform.` should
be used instead as the ops are now injected into the transform
dialect;
- `*.iree_*` transform ops are now `transform.iree.*` to align on the
naming scheme of the upstream transform dialect;
- transform ops related to the structured ops philosophy are prefixed
with `structured.`, e.g., `transform.structured.pad`;
- due to a bug in the upstream parser, three-piece op names cannot
omit the dialect prefix inside transform ops, i.e.
`transform.structured.pad` must be spelled out completely instead of
using the shorter `structured.pad` form.

diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD
index 6f7c1aa..575d711 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD
@@ -46,6 +46,8 @@
# TODO: Remove this dependency once the transform dialect extensions
# have a better registration mechanism.
"//compiler/src/iree/compiler/Codegen/TransformDialectExtensions",
+ "//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
+ "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
],
)
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index 01691e8..7d51f10 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -20,6 +20,8 @@
DEPS
::BufferizationInterfaces
::ProcessorOpInterfaces
+ IREELinalgExtTransformOps
+ IREELinalgTransformDialect
iree::compiler::Codegen::TransformDialectExtensions
PUBLIC
)
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
index e682c69..33b7c8c 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
@@ -10,6 +10,8 @@
#include "iree/compiler/Codegen/Interfaces/ProcessorOpInterfaces.h"
// TODO: Remove this dependency once the transform dialect extensions
// have a better registration mechanism.
+#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.h"
namespace mlir {
@@ -21,6 +23,8 @@
// TODO: Remove this dependency once the transform dialect extensions
// have a better registration mechanism.
// TODO: when warranted, move to its own file.
+ registry.addExtensions<IREE::LinalgExt::LinalgExtTransformOpsExtension,
+ transform_ext::StructuredTransformOpsExtension>();
registerLinalgTransformDialectExtension(registry);
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index adf5343..c673c23 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -84,6 +84,7 @@
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TosaToArith",
+ "@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
"@llvm-project//mlir:VectorToLLVM",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index ecf1038..a6820d4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -66,6 +66,7 @@
MLIRTensor
MLIRTosa
MLIRTosaToArith
+ MLIRTransformDialect
MLIRTransforms
MLIRVector
MLIRVectorToLLVM
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
index 90b2a8e..491fbc6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -51,6 +52,7 @@
pdl_interp::PDLInterpDialect,
scf::SCFDialect,
tensor::TensorDialect,
+ transform::TransformDialect,
vector::VectorDialect>();
// clang-format on
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
index e2c3023..abb6635 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/linalg_transform_spec.mlir
@@ -1,14 +1,18 @@
// RUN: iree-opt %s
-pdl.pattern @pdl_matmul_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_matmul_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_matmul_target
- iree_bufferize
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_matmul_target in %arg1
+ transform.iree.bufferize
+ }
}
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/BUILD b/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/BUILD
index edfe023..f8e4090 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/BUILD
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/BUILD
@@ -45,6 +45,7 @@
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformUtils",
],
)
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt
index 6e25226..f7319b7 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/CMakeLists.txt
@@ -35,6 +35,7 @@
MLIRPass
MLIRSCF
MLIRTensor
+ MLIRTransformDialect
MLIRTransformUtils
iree::compiler::Codegen
iree::compiler::Codegen::Common
diff --git a/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp b/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp
index 955b9fd..1625c41 100644
--- a/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformDialectExtensions/TransformDialectExtensions.cpp
@@ -8,8 +8,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
-#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
-#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
#include "iree-dialects/Transforms/Functional.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
@@ -20,8 +18,11 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
using namespace mlir;
@@ -81,21 +82,21 @@
// TODO: Move to tablegen. Until this stabilizes upstream, simple C++ is enough.
class IREEBufferizeOp
- : public Op<IREEBufferizeOp,
- linalg::transform::TransformOpInterface::Trait> {
+ : public Op<IREEBufferizeOp, transform::TransformOpInterface::Trait,
+ MemoryEffectOpInterface::Trait> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static constexpr llvm::StringLiteral getOperationName() {
- return llvm::StringLiteral("iree_linalg_transform.iree_bufferize");
+ return llvm::StringLiteral("transform.iree.bufferize");
}
Value target() { return nullptr; }
- LogicalResult apply(linalg::transform::TransformResults &results,
- linalg::transform::TransformState &state) {
+ LogicalResult apply(transform::TransformResults &results,
+ transform::TransformState &state) {
PassManager pm(getContext());
// Bufferize the dispatch.
using mlir::bufferization::BufferizationOptions;
@@ -123,26 +124,34 @@
void print(OpAsmPrinter &printer) {
printer.printOptionalAttrDict((*this)->getAttrs());
}
+
+ // This transform may affect the entirety of the payload IR.
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(),
+ transform::PayloadIRResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ transform::PayloadIRResource::get());
+ }
};
// TODO: Move to tablegen. Until this stabilizes upstream, simple C++ is enough.
class IREESetNumWorkgroupToOneOp
: public Op<IREESetNumWorkgroupToOneOp,
- linalg::transform::TransformOpInterface::Trait> {
+ transform::TransformOpInterface::Trait,
+ MemoryEffectOpInterface::Trait> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static constexpr llvm::StringLiteral getOperationName() {
- return llvm::StringLiteral(
- "iree_linalg_transform.iree_set_num_workgroups_to_one");
+ return llvm::StringLiteral("transform.iree.set_num_workgroups_to_one");
}
Value target() { return nullptr; }
- LogicalResult apply(linalg::transform::TransformResults &results,
- linalg::transform::TransformState &state) {
+ LogicalResult apply(transform::TransformResults &results,
+ transform::TransformState &state) {
auto variantOp = dyn_cast<HAL::ExecutableVariantOp>(state.getTopLevel());
if (!variantOp) return failure();
return iree_compiler::setNumWorkgroupsImpl(variantOp, {});
@@ -158,19 +167,26 @@
void print(OpAsmPrinter &printer) {
printer.printOptionalAttrDict((*this)->getAttrs());
}
+
+ // This transform may affect the entirety of the payload IR.
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(),
+ transform::PayloadIRResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(),
+ transform::PayloadIRResource::get());
+ }
};
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
/// types for operands and results.
class LinalgTransformDialectExtension
- : public mlir::linalg::transform::TransformDialectExtension<
+ : public mlir::transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
LinalgTransformDialectExtension() {
declareDependentDialect<pdl::PDLDialect>();
- registerTransformOp<IREEBufferizeOp>();
- registerTransformOp<IREESetNumWorkgroupToOneOp>();
+ registerTransformOps<IREEBufferizeOp, IREESetNumWorkgroupToOneOp>();
// TODO: hook up to Tablegen.
// registerTransformOps<
// #define GET_OP_LIST
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 7926c3e..d4605e0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -67,6 +67,7 @@
"@llvm-project//mlir:PDLDialect",
"@llvm-project//mlir:PDLInterpDialect",
"@llvm-project//mlir:ToLLVMIRTranslation",
+ "@llvm-project//mlir:TransformDialect",
],
)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index 302aed1..cf2642b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -43,6 +43,7 @@
MLIRPDL
MLIRPDLInterp
MLIRTargetLLVMIRExport
+ MLIRTransformDialect
iree::compiler::Codegen::Common
iree::compiler::Codegen::Dialect::IREECodegenDialect
iree::compiler::Codegen::LLVMCPU
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index 98c5ca4..1f4e6e5 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -31,6 +31,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
@@ -154,6 +155,7 @@
registry.insert<IREE::Codegen::IREECodegenDialect,
IREE::LinalgExt::IREELinalgExtDialect,
linalg::transform::LinalgTransformDialect,
+ mlir::transform::TransformDialect,
pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
arm_neon::ArmNeonDialect>();
diff --git a/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir b/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
index 3f60f5e..32f6369 100644
--- a/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
+++ b/iree/test/e2e/linalg_transform/linalg_transform_spec.mlir
@@ -1,15 +1,19 @@
// RUN: iree-opt %s
-pdl.pattern @pdl_matmul_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_matmul_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_matmul_target
- iree_set_num_workgroups_to_one
- iree_bufferize
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_matmul_target in %arg1
+ transform.iree.set_num_workgroups_to_one
+ transform.iree.bufferize
+ }
}
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index f714525..e899fcd 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -226,6 +226,7 @@
DEPS
IREEInputDialect
IREELinalgExtDialect
+ IREELinalgExtTransformOps
IREELinalgExtTransforms
IREELinalgExtOpInterfaceImpl
IREELinalgTransformDialect
diff --git a/iree/tools/init_mlir_dialects.h b/iree/tools/init_mlir_dialects.h
index b429cdf..4c74093 100644
--- a/iree/tools/init_mlir_dialects.h
+++ b/iree/tools/init_mlir_dialects.h
@@ -31,6 +31,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
@@ -61,6 +62,7 @@
mlir::arith::ArithmeticDialect,
vector::VectorDialect,
tensor::TensorDialect,
+ transform::TransformDialect,
tosa::TosaDialect,
shape::ShapeDialect>();
// clang-format on
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 74ac42b..431acc5 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -56,6 +56,7 @@
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PDLDialectTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
+ "@llvm-project//mlir:TransformDialectTdFiles",
],
)
@@ -342,6 +343,43 @@
],
)
+gentbl_cc_library(
+ name = "IREELinalgExtTransformOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td",
+ deps = [
+ ":TdFiles",
+ "@llvm-project//mlir:SideEffectInterfacesTdFiles",
+ ],
+)
+
+cc_library(
+ name = "IREELinalgExtTransformOps",
+ srcs = glob(["lib/Dialect/LinalgExt/TransformOps/*.cpp"]),
+ hdrs = glob(["include/iree-dialects/Dialect/LinalgExt/TransformOps/*.h"]),
+ deps = [
+ ":IREEDialectsTransforms",
+ ":IREELinalgExtDialect",
+ ":IREELinalgExtTransformOpsIncGen",
+ ":IREELinalgExtTransforms",
+ ":IREELinalgTransformDialect",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:TransformDialect",
+ ],
+)
+
cc_library(
name = "IREELinalgExtTransforms",
srcs = glob([
@@ -614,6 +652,26 @@
],
)
+gentbl_cc_library(
+ name = "IREELinalgTransformStructuredIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td",
+ deps = [
+ ":TdFiles",
+ ],
+)
+
cc_library(
name = "IREELinalgTransformDialect",
srcs = glob([
@@ -630,6 +688,7 @@
":IREELinalgExtTransforms",
":IREELinalgTransformIncGen",
":IREELinalgTransformInterfacesIncGen",
+ ":IREELinalgTransformStructuredIncGen",
"@llvm-project//llvm:Support",
# Dialects
"@llvm-project//mlir:Affine",
@@ -646,6 +705,7 @@
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformDialect",
# IR
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@@ -682,7 +742,6 @@
"lib/Dialect/LinalgTransform/Transforms/*.cpp",
]),
hdrs = [
- "include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h",
"include/iree-dialects/Dialect/LinalgTransform/TrackingListener.h",
"include/iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h",
],
@@ -693,6 +752,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:AffineUtils",
+ "@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ArithmeticTransforms",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
@@ -716,6 +776,7 @@
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TensorUtils",
+ "@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
@@ -783,6 +844,7 @@
":IREEInputDialect",
":IREELinalgExtDialect",
":IREELinalgExtPasses",
+ ":IREELinalgExtTransformOps",
":IREELinalgTransformDialect",
":IREELinalgTransformDialectTransforms",
":IREEPyDMDialect",
@@ -795,6 +857,7 @@
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgTransformOps",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:PDLDialect",
@@ -802,6 +865,7 @@
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt
index 5a7289b..4391ced 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Passes)
+add_subdirectory(TransformOps)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..29eb823
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/CMakeLists.txt
@@ -0,0 +1,9 @@
+function(_add_transform_dialect_extension)
+ set(LLVM_TARGET_DEFINITIONS LinalgExtTransformOps.td)
+ mlir_tablegen(LinalgExtTransformOps.h.inc -gen-op-decls)
+ mlir_tablegen(LinalgExtTransformOps.cpp.inc -gen-op-defs)
+ add_public_tablegen_target(IREELinalgExtTransformOpsIncGen)
+ add_dependencies(mlir-headers IREELinalgExtTransformOpsIncGen)
+endfunction()
+
+_add_transform_dialect_extension()
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h
new file mode 100644
index 0000000..d4cb12a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h
@@ -0,0 +1,50 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMOPS_LINALGEXTTRANSFORMOPS_H
+#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMOPS_LINALGEXTTRANSFORMOPS_H
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpTraits.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace scf {
+class ForOp;
+} // namespace scf
+
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+class InParallelOp;
+class TileOp;
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h.inc"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+class LinalgExtTransformOpsExtension
+ : public transform::TransformDialectExtension<
+ LinalgExtTransformOpsExtension, IREELinalgExtDialect> {
+public:
+ LinalgExtTransformOpsExtension();
+};
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMOPS_LINALGEXTTRANSFORMOPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
new file mode 100644
index 0000000..7000b32
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.td
@@ -0,0 +1,198 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_TRANSFORMOPS
+#define IREE_DIALECT_LINALGEXT_TRANSFORMOPS
+
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+
+def TargetableSingleOperandTransformOp
+ : NativeOpTrait<"TargetableSingleOperandOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
+def FunctionalStyleTransformOp
+ : NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
+def PayloadTransformOp
+ : NativeOpTrait<"PayloadTransformOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
+def FuseProducersOp : Op<Transform_Dialect, "fuse_producers",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{Fuses the producers for the operands to fuse.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$operands_to_fuse);
+ let results = (outs PDL_Operation:$transformed,
+ Variadic<PDL_Operation>:$fused_ops);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+}
+
+def TileToLinalgExtTileOp :
+ Op<Transform_Dialect, "tile_to_iree_linalg_ext_tile_op",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{Tile a linalg op with linalg_ext.tile op along a single
+ dimension.}];
+
+ let summary = [{
+ 0 should be used as a tile size to skip tiling a particular dimension.
+ This is a commonly used convention in Linalg.
+ For the purpose of this transformation, a single non-zero positive tile
+ size is allowed.
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes);
+ let results = (outs PDL_Operation:$tiled_op,
+ PDL_Operation:$tile_op);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+}
+
+def FuseIntoContainingOp :
+ Op<Transform_Dialect, "fuse_into_containing_op",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{Fuse a producer into a containing operation.}];
+
+ let summary = [{
+ Search the body of the containing operation for all producer uses and
+ compute the accessed producer slices on-the-fly.
+ }];
+
+ let arguments = (ins PDL_Operation:$producer_op,
+ PDL_Operation:$containing_op);
+
+ let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+}
+
+def RewriteLinalgExtTileToScfForOp :
+ Op<Transform_Dialect, "rewrite_iree_linalg_ext_tile_to_scf_for",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+
+ let description = [{Rewrite linalg_ext.tile op to scf.for.}];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::TileOp target);
+ }];
+}
+
+def RewriteLinalgExtTileToInParallelOp :
+ Op<Transform_Dialect, "rewrite_iree_linalg_ext_tile_to_in_parallel",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+
+ let description = [{Rewrite linalg_ext.tile op to linalg_ext.in_parallel.}];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::iree_compiler::IREE::LinalgExt::InParallelOp> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::TileOp target);
+ }];
+}
+
+def RewriteLinalgExtInParallelToAsyncOp :
+ Op<Transform_Dialect, "rewrite_iree_linalg_ext_in_parallel_to_async",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+
+ let description = [{Rewrite linalg_ext.in_parallel op to the async dialect.}];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::Operation *> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::InParallelOp target);
+ }];
+}
+
+def RewriteLinalgExtInParallelToHALOp :
+ Op<Transform_Dialect, "rewrite_iree_linalg_ext_in_parallel_to_hal",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+
+ let description = [{Rewrite linalg_ext.in_parallel op to use HAL ops.}];
+ let arguments = (ins PDL_Operation:$target);
+ // TODO: Determine whether we want to return something here, the only natural
+ // results would be the resulting insertTensorOps.
+ // let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+}
+
+def RewriteLinalgExtInParallelToScfForOp :
+ Op<Transform_Dialect, "rewrite_iree_linalg_ext_in_parallel_to_scf_for",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+
+ let description = [{Rewrite linalg_ext.in_parallel to a sequential scf.for.
+
+ Warning: when the linalg_ext.in_parallel terminator operates on tensors,
+ this is a potentially dangerous transformation under the current semantics.
+ In order for this transformation to be semantics-preserving, 2 conditions need
+ to come together that are currently not checked and the subject of future analyses:
+ 1. The terminator of linalg_ext.in_parallel may compose the output tensor in
+ potentially undefined ways: if the linalg_ext.parallel_insert_slice regions overlap, they
+ may occur in any order and the result is unspecified. A future overlap/intersection
+ analysis will be needed to guard against this case.
+ 2. Even when condition 1. has well-defined behavior, semantics altering behavior may
+ still be introduced to simplify inplace bufferization. In the current implementation,
+ all linalg_ext.parallel_insert_slice dest operands are optimistically turned into scf.for
+ iter_args. This is optimistic because any use of a tensor inside linalg_ext.in_parallel
+ is guaranteed to see the value before the start of the op; whereas the same use may
+ see a partially updated sequential result in the scf.for op.
+ An extra analysis is needed to ensure that a particular read of a result tensor would
+ see the initial value and not a partial update. This is guaranteed by construction when
+ the linalg_ext.in_parallel op is produced by lowering a linalg_ext.tile operation but
+ is not something that is generally enforced in the IR.
+ For the moment we perform the replacement of the use with the scf.for iter_arg to be
+ able to connect pieces inplace with bufferization. In the future an analysis will be
+ needed to ensure correctness of this lowering to sequential scf.for + iter_args.
+ }];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "mlir::iree_compiler::IREE::LinalgExt";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::InParallelOp target);
+ }];
+}
+
+#endif // IREE_DIALECT_LINALGEXT_TRANSFORMOPS
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt
index ad1594c..1329163 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt
@@ -16,5 +16,14 @@
add_dependencies(mlir-headers IREELinalgTransformIncGen)
endfunction()
+function(_add_transform_dialect_extension)
+ set(LLVM_TARGET_DEFINITIONS StructuredTransformOpsExt.td)
+ mlir_tablegen(StructuredTransformOpsExt.h.inc -gen-op-decls)
+ mlir_tablegen(StructuredTransformOpsExt.cpp.inc -gen-op-defs)
+ add_public_tablegen_target(IREELinalgTransformExtIncGen)
+ add_dependencies(mlir-headers IREELinalgTransformExtIncGen)
+endfunction()
+
_add_interface()
-_add_dialect()
\ No newline at end of file
+_add_dialect()
+_add_transform_dialect_extension()
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td
index e63b39b..a832391 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td
@@ -60,477 +60,6 @@
//===----------------------------------------------------------------------===//
-def SequenceOp : Linalg_Transform_Operation<"sequence",
- [NoTerminator, OpAsmOpInterface]> {
- let description = [{Contains a sequence of transformation ops to apply.
-
- The transformations indicated by the sequence are applied in order of their
- appearance. Each value produced by a transformation within the sequence
- corresponds to an operation or a group of operations in the IR being
- transformed. Therefore, each value may be used at most once by another
- transformation operation as the transformation is likely to replace the
- transformed operation with another operation or a group thereof. In such
- cases, the tranfsormation operation is expected to produce a new value to
- denote the newly produced operations that can be transformed further.
- }];
-
- let regions = (region SizedRegion<1>:$body);
- let assemblyFormat = "attr-dict-with-keyword regions";
-
- let extraClassDeclaration = [{
- static StringRef getDefaultDialect() { return "iree_linalg_transform"; }
- }];
-
- let hasVerifier = 1;
-}
-
-//===----------------------------------------------------------------------===//
-
-def MatchOp : Transform_Op<"match"> {
- let description = [{ Find and return an op that matches the specific PDL
- pattern. When executed inside a sequence, it returns all matching ops. }];
-
- let arguments = (ins SymbolRefAttr:$targetMatcher);
- let results = (outs PDL_Operation:$target);
-
- let assemblyFormat = "$targetMatcher attr-dict";
-}
-
-//===----------------------------------------------------------------------===//
-
-def TileOp : Linalg_Transform_Operation<"tile",
- [TransformOpInterface]> {
- let description = [{Indicates that ops of a specific kind in the given
- function should be tiled with the options provided as attributes.}];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
- let results = (outs PDL_Operation:$tiled_linalg_op,
- Variadic<PDL_Operation>:$loops);
-
- let hasCustomAssemblyFormat = 1;
-
- let extraClassDeclaration = [{
- ::mlir::LogicalResult apply(
- ::mlir::linalg::transform::TransformResults &transformResults,
- ::mlir::linalg::transform::TransformState &state);
- }];
-}
-
-def ScalarizeOp : Linalg_Transform_Operation<"scalarize",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
- let description = [{Indicates that ops of a specific kind in the given
- function should be scalarized (i.e. their dynamic dimensions tiled by 1).
-
- This operation returns the tiled op but not the loops.
-
- We make this design choice because it is hard to know ahead of time the
- number of loops that will be produced (it depends on the number of
- dynamic dimensions after multiple transformations have been applied).
- }];
-
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$tiled_linalg_op);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
- }];
-}
-
-def FuseOp : Linalg_Transform_Operation<"fuse",
- [TransformOpInterface]> {
- let description = [{Tiles the operations pointed to by the target handle and
- fuses their producers greedily using the options provided as attributes.}];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange
- );
- let results = (outs PDL_Operation:$transformed,
- Variadic<PDL_Operation>:$loops);
-
- let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- ::mlir::LogicalResult apply(
- ::mlir::linalg::transform::TransformResults &transformResults,
- ::mlir::linalg::transform::TransformState &state);
- }];
-}
-
-
-def FuseProducersOp : Linalg_Transform_Operation<"fuse_producers",
- [TransformOpInterface]> {
- let description = [{Fuses the producers for the operands to fuse.}];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$operands_to_fuse);
- let results = (outs PDL_Operation:$transformed,
- Variadic<PDL_Operation>:$fused_ops);
-
- let hasCustomAssemblyFormat = 1;
- let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- ::mlir::LogicalResult apply(
- ::mlir::linalg::transform::TransformResults &transformResults,
- ::mlir::linalg::transform::TransformState &state);
- }];
-}
-
-def GeneralizeOp : Linalg_Transform_Operation<"generalize",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
- let description = [{Generalizes the operations pointed to
- by the target handle.}];
-
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
- }];
-}
-
-def InterchangeOp : Linalg_Transform_Operation<"interchange",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
- let description = [{Interchanges the iterators of the operations pointed to
- by the target handle using the iterator interchange attribute.}];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
- let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
- }];
-}
-
-def PadOp : Linalg_Transform_Operation<"pad",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
- let description = [{Pads the operations pointed to by the target handle
- using the options provides as operation attributes.}];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$hoist_paddings,
- DefaultValuedAttr<
- TypedArrayAttrBase<I64ArrayAttr,
- "array of arrays of i64">,
- "{}">:$transpose_paddings);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
- let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
- }];
-}
-
-def BufferizeOp : Transform_Op<"bufferize"> {
- let description = [{Indicates that the entire module should be bufferized.}];
- let assemblyFormat = "attr-dict";
-}
-
-def DecomposeOp : Transform_Op<"decompose"> {
- let description = [{Indicates that ops in the entire module should be
- decomposed into lower-level components.}];
- let assemblyFormat = "attr-dict";
-}
-
-def VectorizeOp : Transform_Op<"vectorize"> {
- let description = [{Indiactes that vectorization should be performed. If a
- target handle is provided, only vectorizes the operations pointed to by the
- handle. Otherwise vectorizes the entire module. Vectorization options are
- provided as operation attributes.}];
-
- let arguments = (ins Optional<PDL_Operation>:$target,
- DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding
- );
- let results = (outs Optional<PDL_Operation>:$transformed);
-
- let hasCustomAssemblyFormat = 1;
-}
-
-def LowerVectorsOp : Transform_Op<"lower_vectors"> {
- let description = [{Indicates that the vector operations in the entire
- module should be lowered to simpler primitives (multiple stages of lowering
- be executed at once).}];
-
- let arguments =
- (ins DefaultValuedAttr<I64ArrayAttr, "{0, 1, 2, 3, 4, 5, 6}">:$stages,
- DefaultValuedAttr<StrAttr, "\"outerproduct\"">:$contraction_lowering,
- DefaultValuedAttr<StrAttr, "\"innerparallel\"">:$multireduction_lowering,
- DefaultValuedAttr<StrAttr, "\"linalg-copy\"">:$split_transfers,
- DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers,
- DefaultValuedAttr<StrAttr, "\"eltwise\"">:$transpose_lowering,
- DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering
- );
-
- let assemblyFormat = "attr-dict";
-}
-
-def LowerToLLVMOp : Transform_Op<"lower_to_llvm"> {
- let description = [{Indicates that the entire module should be converted
- to the LLVM dialect. This is expected to be the last transformation in
- a sequence.}];
-
- let arguments =
- (ins DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
- DefaultValuedAttr<BoolAttr, "false">:$enable_index_optimizations,
- DefaultValuedAttr<BoolAttr, "false">:$enable_arm_neon,
- DefaultValuedAttr<BoolAttr, "false">:$enable_arm_sve,
- DefaultValuedAttr<BoolAttr, "false">:$enable_amx,
- DefaultValuedAttr<BoolAttr, "false">:$enable_x86vector,
- DefaultValuedAttr<BoolAttr, "false">:$enable_async);
-
- let assemblyFormat = "attr-dict";
-}
-
-def GetParentLoopOp : Linalg_Transform_Operation<"get_parent_loop", [
- TransformOpInterface,
- TargetableSingleOperandTransformOpTrait
- ]> {
- let description = [{Obtains a handle to a parent loop of the given
- operation.}];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<Confined<I64Attr, [IntPositive]>,
- "1">:$num_loops);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(Operation *source);
- }];
-}
-
-def UnrollLoopOp : Linalg_Transform_Operation<"unroll_loop", [
- TransformOpInterface,
- TargetableSingleOperandTransformOpTrait
- ]> {
- let description = [{Unrolls the given loop with the given unroll factor.}];
-
- let arguments = (ins PDL_Operation:$target,
- Confined<I64Attr, [IntPositive]>:$factor);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- LogicalResult applyToOne(::mlir::scf::ForOp loop);
- }];
-}
-
-def PeelLoopOp : Linalg_Transform_Operation<"peel_loop", [
- TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
- }];
-}
-
-
-def PipelineLoopOp : Linalg_Transform_Operation<"pipeline_loop", [
- TransformOpInterface,
- TargetableSingleOperandTransformOpTrait
- ]> {
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
- DefaultValuedAttr<I64Attr, "10">:$read_latency);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
- }];
-}
-
-def OutlineLoopOp : Linalg_Transform_Operation<"outline_loop", [
- DeclareOpInterfaceMethods<TransformOpInterface, ["apply"]>
- ]> {
- let arguments = (ins PDL_Operation:$target,
- StrAttr:$func_name);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-}
-
-def PrintOp : Transform_Op<"print", [
- DeclareOpInterfaceMethods<TransformOpInterface, ["apply"]>
- ]> {
- let arguments = (ins Optional<PDL_Operation>:$target,
- StrAttr:$name);
- let description = [{Prints the module.}];
- let assemblyFormat = "($target^)? attr-dict";
-}
-
-//===----------------------------------------------------------------------===//
-// LinalgExt specific transforms
-//===----------------------------------------------------------------------===//
-
-def TileToLinalgExtTileOp :
- Linalg_Transform_Operation<"tile_to_iree_linalg_ext_tile_op", [
- DeclareOpInterfaceMethods<TransformOpInterface, ["apply"]>
- ]> {
- let description = [{Tile a linalg op with linalg_ext.tile op along a single
- dimension.}];
-
- let summary = [{
- 0 should be used as a tile size to skip tiling a particular dimension.
- This is a commonly used convention in Linalg.
- For the purpose of this transformation, a single non-zero positive tile
- size is allowed.
- }];
-
- let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes);
- let results = (outs PDL_Operation:$tiled_op,
- PDL_Operation:$tile_op);
-
- let assemblyFormat = "$target attr-dict";
-}
-
-def FuseIntoContainingOp :
- Linalg_Transform_Operation<"fuse_into_containing_op", [
- DeclareOpInterfaceMethods<TransformOpInterface, ["apply"]>
- ]> {
- let description = [{Fuse a producer into a containing operation.}];
-
- let summary = [{
- Search the body of the containing operation for all producer uses and
- compute the accessed producer slices on-the-fly.
- }];
-
- let arguments = (ins PDL_Operation:$producer_op,
- PDL_Operation:$containing_op);
-
- let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
-}
-
-def RewriteLinalgExtTileToScfForOp :
- Linalg_Transform_Operation<"rewrite_iree_linalg_ext_tile_to_scf_for",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
-
- let description = [{Rewrite linalg_ext.tile op to scf.for.}];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
- ::mlir::iree_compiler::IREE::LinalgExt::TileOp target);
- }];
-}
-
-def RewriteLinalgExtTileToInParallelOp :
- Linalg_Transform_Operation<"rewrite_iree_linalg_ext_tile_to_in_parallel",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
-
- let description = [{Rewrite linalg_ext.tile op to linalg_ext.in_parallel.}];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::iree_compiler::IREE::LinalgExt::InParallelOp> applyToOne(
- ::mlir::iree_compiler::IREE::LinalgExt::TileOp target);
- }];
-}
-
-def RewriteLinalgExtInParallelToAsyncOp :
- Linalg_Transform_Operation<"rewrite_iree_linalg_ext_in_parallel_to_async",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
-
- let description = [{Rewrite linalg_ext.in_parallel op to the async dialect.}];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::Operation *> applyToOne(
- ::mlir::iree_compiler::IREE::LinalgExt::InParallelOp target);
- }];
-}
-
-def RewriteLinalgExtInParallelToHALOp :
- Linalg_Transform_Operation<"rewrite_iree_linalg_ext_in_parallel_to_hal", [
- DeclareOpInterfaceMethods<TransformOpInterface, ["apply"]>
- ]> {
-
- let description = [{Rewrite linalg_ext.in_parallel op to use HAL ops.}];
- let arguments = (ins PDL_Operation:$target);
- // TODO: Determine whether we want to return something here, the only natural
- // results would be the resulting insertTensorOps.
- // let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-}
-
-def RewriteLinalgExtInParallelToScfForOp :
- Linalg_Transform_Operation<"rewrite_iree_linalg_ext_in_parallel_to_scf_for",
- [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
-
- let description = [{Rewrite linalg_ext.in_parallel to a sequential scf.for.
-
- Warning: when the linalg_ext.in_parallel terminator operates on tensors,
- this is a potentially dangerous transformation under the current semantics.
- In order for this transformation to be semantics-preserving, 2 conditions need
- to come together that are currently not checked and the subject of future analyses:
- 1. The terminator of linalg_ext.in_parallel may compose the output tensor in
- potentially undefined ways: if the linalg_ext.parallel_insert_slice regions overlap, they
- may occur in any order and the result is unspecified. A future overlap/intersection
- analysis will be needed to guard against this case.
- 2. Even when condition 1. has well-defined behavior, semantics altering behavior may
- still be introduced to simplify inplace bufferization. In the current implementation,
- all linalg_ext.parallel_insert_slice dest operands are optimistically turned into scf.for
- iter_args. This is optimistic because any use of a tensor inside linalg_ext.in_parallel
- is guaranteed to see the value before the start of the op; whereas the same use may
- see a partially updated sequential result in the scf.for op.
- An extra analysis is needed to ensure that a particular read of a result tensor would
- see the initial value and not a partial update. This is guaranteed by construction when
- the linalg_ext.in_parallel op is produced by lowering a linalg_ext.tile operation but
- is not something that is generally enforced in the IR.
- For the moment we perform the replacement of the use with the scf.for iter_arg to be
- able to connect pieces inplace with bufferization. In the future an analysis will be
- needed to ensure correctness of this lowering to sequential scf.for + iter_args.
- }];
- let arguments = (ins PDL_Operation:$target);
- let results = (outs PDL_Operation:$transformed);
-
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
- ::mlir::iree_compiler::IREE::LinalgExt::InParallelOp target);
- }];
-}
-
-//===----------------------------------------------------------------------===//
-
def ExpertOp : Linalg_Transform_Operation<"expert"> {
let description = [{A "transformation expert" that can be lowered to a
sequence of transformations. The details of the lowering depend on the name
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
new file mode 100644
index 0000000..d952fe6
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h
@@ -0,0 +1,37 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
+#define IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
+
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpTraits.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+namespace scf {
+class ForOp;
+} // namespace scf
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h.inc"
+
+namespace transform_ext {
+class StructuredTransformOpsExtension
+ : public mlir::transform::TransformDialectExtension<
+ StructuredTransformOpsExtension> {
+public:
+ StructuredTransformOpsExtension();
+};
+} // namespace transform_ext
+
+#endif // IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
new file mode 100644
index 0000000..a36d984
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.td
@@ -0,0 +1,327 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef STRUCTURED_TRANSFORM_OPS_EXT
+#define STRUCTURED_TRANSFORM_OPS_EXT
+
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/OpBase.td"
+
+def TargetableSingleOperandTransformOp
+ : NativeOpTrait<"TargetableSingleOperandOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
+def FunctionalStyleTransformOp
+ : NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
+def PayloadTransformOp
+ : NativeOpTrait<"PayloadTransformOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
+def CanonicalizedSequenceOp
+ : Op<Transform_Dialect, "structured.canonicalized_sequence", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ OpAsmOpInterface,
+ PossibleTopLevelTransformOpTrait,
+ SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+ let summary = "A transformation sequence interspersed with canonicalizations";
+
+ let arguments = (ins Optional<PDL_Operation>:$target);
+ let regions = (region SizedRegion<1>:$body);
+
+ let assemblyFormat = "($target^)? attr-dict-with-keyword regions";
+ let extraClassDeclaration = [{
+ /// Allow the dialect prefix to be omitted.
+ static ::llvm::StringRef getDefaultDialect() { return "transform"; }
+ }];
+
+ let cppNamespace = "transform_ext";
+}
+
+//===----------------------------------------------------------------------===//
+
+def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let description = [{Indicates that ops of a specific kind in the given
+ function should be scalarized (i.e. their dynamic dimensions tiled by 1).
+
+ This operation returns the tiled op but not the loops.
+
+ We make this design choice because it is hard to know ahead of time the
+ number of loops that will be produced (it depends on the number of
+ dynamic dimensions after multiple transformations have been applied).
+ }];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$result);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+
+ let cppNamespace = "transform_ext";
+}
+
+def FuseOp : Op<Transform_Dialect, "structured.fuse",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{Tiles the operations pointed to by the target handle and
+ fuses their producers greedily using the options provided as attributes.}];
+
+ let arguments =
+ (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
+ let results = (outs PDL_Operation:$transformed,
+ Variadic<PDL_Operation>:$loops);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+ let cppNamespace = "transform_ext";
+}
+
+def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let description = [{Generalizes the operations pointed to
+ by the target handle.}];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "transform_ext";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let description = [{Interchanges the iterators of the operations pointed to
+ by the target handle using the iterator interchange attribute.}];
+
+ let arguments =
+ (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let hasVerifier = 1;
+ let cppNamespace = "transform_ext";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def PadOp : Op<Transform_Dialect, "structured.pad",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let description = [{Pads the operations pointed to by the target handle
+  using the options provided as operation attributes.}];
+
+ let arguments =
+ (ins PDL_Operation:$target,
+ DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$hoist_paddings,
+ DefaultValuedAttr<
+ TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
+ "{}">:$transpose_paddings);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let hasVerifier = 1;
+ let cppNamespace = "transform_ext";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def BufferizeOp : Op<Transform_Dialect, "bufferize",
+ [DeclareOpInterfaceMethods<TransformOpInterface>, MemoryEffectsOpInterface,
+ PayloadTransformOp]> {
+ let description = [{Indicates that the entire module should be bufferized.}];
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "transform_ext";
+}
+
+def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
+ [DeclareOpInterfaceMethods<TransformOpInterface>, MemoryEffectsOpInterface,
+ PayloadTransformOp]> {
+ let description = [{Indicates that ops in the entire module should be
+ decomposed into lower-level components.}];
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "transform_ext";
+}
+
+def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{Indicates that vectorization should be performed. If a
+ target handle is provided, only vectorizes the operations pointed to by the
+ handle. Otherwise vectorizes the entire module. Vectorization options are
+ provided as operation attributes.}];
+
+ let arguments = (ins Optional<PDL_Operation>:$target,
+ DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding
+ );
+ let results = (outs Optional<PDL_Operation>:$transformed);
+
+ let hasCustomAssemblyFormat = 1;
+ let cppNamespace = "transform_ext";
+}
+
+def LowerVectorsOp : Op<Transform_Dialect, "lower_vectors",
+ [DeclareOpInterfaceMethods<TransformOpInterface>, MemoryEffectsOpInterface,
+ PayloadTransformOp]> {
+ let description = [{Indicates that the vector operations in the entire
+ module should be lowered to simpler primitives (multiple stages of lowering
+  may be executed at once).}];
+
+ let arguments =
+ (ins DefaultValuedAttr<I64ArrayAttr, "{0, 1, 2, 3, 4, 5, 6}">:$stages,
+ DefaultValuedAttr<StrAttr, "\"outerproduct\"">:$contraction_lowering,
+ DefaultValuedAttr<StrAttr, "\"innerparallel\"">:$multireduction_lowering,
+ DefaultValuedAttr<StrAttr, "\"linalg-copy\"">:$split_transfers,
+ DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers,
+ DefaultValuedAttr<StrAttr, "\"eltwise\"">:$transpose_lowering,
+ DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering
+ );
+
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "transform_ext";
+}
+
+def LowerToLLVMOp : Op<Transform_Dialect, "lower_to_llvm",
+ [DeclareOpInterfaceMethods<TransformOpInterface>, MemoryEffectsOpInterface,
+ PayloadTransformOp]> {
+ let description = [{Indicates that the entire module should be converted
+ to the LLVM dialect. This is expected to be the last transformation in
+ a sequence.}];
+
+ let arguments =
+ (ins DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_index_optimizations,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_arm_neon,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_arm_sve,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_amx,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_x86vector,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_async);
+
+ let assemblyFormat = "attr-dict";
+ let cppNamespace = "transform_ext";
+}
+
+def GetParentLoopOp : Op<Transform_Dialect, "get_parent_loop",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let description = [{Obtains a handle to a parent loop of the given
+ operation.}];
+
+ let arguments =
+ (ins PDL_Operation:$target,
+ DefaultValuedAttr<Confined<I64Attr, [IntPositive]>,
+ "1">:$num_loops);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "transform_ext";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::Operation *source);
+ }];
+}
+
+def UnrollLoopOp : Op<Transform_Dialect, "unroll_loop",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let description = [{Unrolls the given loop with the given unroll factor.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ Confined<I64Attr, [IntPositive]>:$factor);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "transform_ext";
+
+ let extraClassDeclaration = [{
+ ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+def PeelLoopOp : Op<Transform_Dialect, "peel_loop",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "transform_ext";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+def PipelineLoopOp : Op<Transform_Dialect, "pipeline_loop",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ TransformOpInterface, TargetableSingleOperandTransformOp]> {
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
+ DefaultValuedAttr<I64Attr, "10">:$read_latency);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "transform_ext";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+def OutlineLoopOp : Op<Transform_Dialect, "outline_loop",
+ [FunctionalStyleTransformOp, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins PDL_Operation:$target,
+ StrAttr:$func_name);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "transform_ext";
+}
+
+def PrintOp : Op<Transform_Dialect, "print",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins Optional<PDL_Operation>:$target,
+ StrAttr:$name);
+ let description = [{Prints the module.}];
+ let assemblyFormat = "($target^)? attr-dict";
+ let cppNamespace = "transform_ext";
+}
+
+#endif // STRUCTURED_TRANSFORM_OPS_EXT
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h
deleted file mode 100644
index 3374c5b..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGCSE_H
-#define MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGCSE_H
-
-namespace mlir {
-class DominanceInfo;
-struct LogicalResult;
-class Operation;
-struct RewriteListener;
-
-LogicalResult
-eliminateCommonSubexpressionsWithTrackedOps(Operation *root,
- RewriteListener &listener,
- DominanceInfo *domInfo = nullptr);
-} // namespace mlir
-
-#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGCSE_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpTraits.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpTraits.h
new file mode 100644
index 0000000..b516c5f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpTraits.h
@@ -0,0 +1,173 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGTRANSFORM_TRANSFORMOPTRAITS_H
+#define IREE_DIALECTS_DIALECT_LINALGTRANSFORM_TRANSFORMOPTRAITS_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace transform {
+
+namespace detail {
+/// Appends `result` to the vector assuming it corresponds to the success state
+/// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
+/// `LogicalResult`, does nothing.
+template <typename Ty>
+std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
+appendTransformResultToVector(Ty result,
+ SmallVectorImpl<Operation *> &results) {
+ return result;
+}
+template <typename Ty>
+std::enable_if_t<!std::is_same<Ty, LogicalResult>::value, LogicalResult>
+appendTransformResultToVector(Ty result,
+ SmallVectorImpl<Operation *> &results) {
+ static_assert(
+ std::is_convertible<typename Ty::value_type, Operation *>::value,
+ "expected transform function to return operations");
+ if (failed(result))
+ return failure();
+
+ results.push_back(*result);
+ return success();
+}
+
+/// Applies a one-to-one transform to each of the given targets. Puts the
+/// results of transforms, if any, in `results` in the same order. Fails if any
+/// of the applications fails. Individual transforms must be callable with
+/// one of the following signatures:
+/// - FailureOr<convertible-to-Operation*>(OpTy)
+/// - LogicalResult(OpTy)
+/// where OpTy is either
+/// - Operation *, in which case the transform is always applied;
+/// - a concrete Op class, in which case a check is performed whether
+/// `targets` contains operations of the same class and a failure is reported
+/// if it does not.
+template <typename FnTy>
+LogicalResult applyTransformToEach(ArrayRef<Operation *> targets,
+ SmallVectorImpl<Operation *> &results,
+ FnTy transform) {
+ using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
+ static_assert(std::is_convertible<OpTy, Operation *>::value,
+ "expected transform function to take an operation");
+ using RetTy = typename llvm::function_traits<FnTy>::result_t;
+ static_assert(std::is_convertible<RetTy, LogicalResult>::value,
+ "expected transform function to return LogicalResult or "
+ "FailureOr<convertible-to-Operation*>");
+ for (Operation *target : targets) {
+ auto specificOp = dyn_cast<OpTy>(target);
+ if (!specificOp)
+ return failure();
+
+ auto result = transform(specificOp);
+ if (failed(appendTransformResultToVector(result, results)))
+ return failure();
+ }
+ return success();
+}
+} // namespace detail
+
+/// Trait implementing the TransformOpInterface for operations applying a
+/// transformation to a single operation handle and producing a single operation
+/// handle. The op must implement a method with one of the following signatures:
+/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
+/// - LogicalResult applyToOne(OpTy)
+/// to perform a transformation that is applied in turn to all payload IR
+/// operations that correspond to the handle of the transform IR operation.
+/// In the functions above, OpTy is either Operation * or a concrete payload IR
+/// Op class that the transformation is applied to (NOT the class of the
+/// transform IR op).
+template <typename OpTy>
+class TargetableSingleOperandOpTrait
+ : public OpTrait::TraitBase<OpTy, TargetableSingleOperandOpTrait> {
+public:
+ /// Applies the transformation to each op from the only target and sets the
+ /// only result to correspond to the list of individual results.
+ LogicalResult apply(TransformResults &transformResults,
+ TransformState &state) {
+ using TransformOpType = typename llvm::function_traits<decltype(
+ &OpTy::applyToOne)>::template arg_t<0>;
+ ArrayRef<Operation *> targets =
+ state.getPayloadOps(this->getOperation()->getOperand(0));
+ SmallVector<Operation *> results;
+ if (failed(detail::applyTransformToEach(
+ targets, results, [&](TransformOpType specificOp) {
+ return static_cast<OpTy *>(this)->applyToOne(specificOp);
+ })))
+ return failure();
+ if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+ transformResults.set(
+ this->getOperation()->getResult(0).template cast<OpResult>(),
+ results);
+ }
+ return success();
+ }
+
+ /// Verifies that the op satisfies the requirements for this trait.
+ static LogicalResult verifyTrait(Operation *) {
+ static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+ "expected single-operand op");
+ static_assert(OpTy::template hasTrait<OpTrait::OneResult>() ||
+ OpTy::template hasTrait<OpTrait::ZeroResult>(),
+ "expected zero- or single-result op");
+ return success();
+ }
+};
+
+template <typename OpTy>
+class FunctionalStyleTransformOpTrait
+ : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
+public:
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ Operation *op = this->getOperation();
+ auto *transformMappingResource = TransformMappingResource::get();
+ for (Value operand : op->getOperands()) {
+ effects.emplace_back(MemoryEffects::Read::get(), operand,
+ transformMappingResource);
+ effects.emplace_back(MemoryEffects::Free::get(), operand,
+ transformMappingResource);
+ }
+ for (Value result : op->getResults()) {
+ effects.emplace_back(MemoryEffects::Allocate::get(), result,
+ transformMappingResource);
+ effects.emplace_back(MemoryEffects::Write::get(), result,
+ transformMappingResource);
+ }
+ effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+ }
+
+ static LogicalResult verifyTrait(Operation *) {
+ static_assert(
+ OpTy::template hasTrait<MemoryEffectOpInterface::Trait>(),
+ "the op must have MemoryEffectOpInterface for this trait to apply");
+ return success();
+ }
+};
+
+template <typename OpTy>
+class PayloadTransformOpTrait
+ : public OpTrait::TraitBase<OpTy, PayloadTransformOpTrait> {
+public:
+ void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+ }
+
+ static LogicalResult verifyTrait(Operation *) {
+ static_assert(
+ OpTy::template hasTrait<MemoryEffectOpInterface::Trait>(),
+ "the op must have MemoryEffectOpInterface for this trait to apply");
+ return success();
+ }
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGTRANSFORM_TRANSFORMOPTRAITS_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
index 126b878..ce5c798 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Passes)
+add_subdirectory(TransformOps)
add_subdirectory(Transforms)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000..ff2b3c8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_library(IREELinalgExtTransformOps
+ LinalgExtTransformOps.cpp
+
+ DEPENDS
+ mlir-headers
+
+ LINK_LIBS PUBLIC
+ IREEDialectsTransforms
+ MLIRRewrite
+
+ IREELinalgExtDialect
+ IREELinalgExtTransforms
+
+ MLIRPDL
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
new file mode 100644
index 0000000..3791931
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -0,0 +1,237 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Transforms/Functional.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE;
+
+LinalgExt::LinalgExtTransformOpsExtension::LinalgExtTransformOpsExtension() {
+ registerTransformOps<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp.inc"
+ >();
+}
+
+//===---------------------------------------------------------------------===//
+// Utility functions
+//===---------------------------------------------------------------------===//
+
+/// Extracts a vector of int64_t from an array attribute. Asserts if the
+/// attribute contains values other than integers.
+static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
+ SmallVector<int64_t> result;
+ result.reserve(attr.size());
+ for (APInt value : attr.getAsValueRange<IntegerAttr>())
+ result.push_back(value.getSExtValue());
+ return result;
+}
+
+//===---------------------------------------------------------------------===//
+// FuseProducersOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult
+LinalgExt::FuseProducersOp::apply(transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+ SmallVector<int64_t> operandsToFuse = extractI64Array(getOperandsToFuse());
+ LinalgExt::LinalgExtFusionPattern pattern(getContext(), operandsToFuse);
+ size_t numProducers = operandsToFuse.size();
+
+ SmallVector<Operation *> transformedOps;
+ SmallVector<SmallVector<Operation *>> fusedOps(numProducers);
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ // Apply the pattern.
+ FailureOr<LinalgExt::FusionResult> result =
+ functional::applyReturningPatternAt(pattern,
+ cast<linalg::LinalgOp>(target));
+ if (failed(result))
+ return failure();
+
+ // Update the fused operations.
+ transformedOps.push_back(result->consumerOp);
+ for (size_t i = 0; i < numProducers; ++i)
+ fusedOps[i].push_back(result->fusedOps[i]);
+ }
+
+ transformResults.set(getTransformed().cast<OpResult>(), transformedOps);
+ for (size_t i = 0; i < numProducers; ++i)
+ transformResults.set(getFusedOps()[i], fusedOps[i]);
+ return success();
+}
+
+LogicalResult LinalgExt::FuseProducersOp::verify() {
+ SmallVector<int64_t> operandsToFuse = extractI64Array(getOperandsToFuse());
+ llvm::SmallDenseSet<int64_t> operandsSet;
+ for (int64_t operandToFuse : operandsToFuse) {
+ if (operandToFuse < 0) {
+ return emitOpError() << "expects positive operand numbers, found "
+ << operandToFuse;
+ }
+ if (operandsSet.count(operandToFuse) != 0) {
+ return emitOpError() << "expects unique operand numbers, found "
+ << operandToFuse << " multiple times";
+ }
+ operandsSet.insert(operandToFuse);
+ }
+ return success();
+}
+
+ParseResult LinalgExt::FuseProducersOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::UnresolvedOperand targetOperand;
+ SMLoc opLoc;
+ parser.getCurrentLocation(&opLoc);
+ if (parser.parseOperand(targetOperand))
+ return parser.emitError(opLoc, "expected `target` operand");
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ StringRef operandsToFuseAttrName("operands_to_fuse");
+ Attribute operandsToFuseAttr = result.attributes.get(operandsToFuseAttrName);
+ if (!operandsToFuseAttr) {
+ return parser.emitError(opLoc, llvm::formatv("expected `{0}` attribute",
+ operandsToFuseAttrName));
+ }
+ auto operandsToFuseArrayAttr = operandsToFuseAttr.dyn_cast<ArrayAttr>();
+ if (!operandsToFuseArrayAttr) {
+ return parser.emitError(opLoc,
+ llvm::formatv("`{0}` attribute must be an array",
+ operandsToFuseAttrName));
+ }
+ Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
+ size_t numProducers = operandsToFuseArrayAttr.size();
+ result.addTypes(SmallVector<Type>(numProducers + 1, pdlOpType));
+ if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
+ return failure();
+ return success();
+}
+
+void LinalgExt::FuseProducersOp::print(OpAsmPrinter &p) {
+ p << ' ';
+ p << getTarget();
+ p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+LogicalResult
+LinalgExt::TileToLinalgExtTileOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ linalg::LinalgTilingOptions tilingOptions;
+ SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
+ if (!tileSizes.empty())
+ tilingOptions.setTileSizes(tileSizes);
+
+ LinalgExt::LinalgExtTilingPattern pattern(this->getContext(), tilingOptions);
+ ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(targets.front());
+ if (!tilingInterfaceOp) {
+ targets.front()->emitError("Cannot tile op: Not a TilingInterface");
+ return failure();
+ }
+
+ FailureOr<iree_compiler::IREE::LinalgExt::TilingResult> result =
+ functional::applyReturningPatternAt(pattern, tilingInterfaceOp);
+ if (failed(result))
+ return failure();
+ results.set(getTiledOp().cast<OpResult>(), result->tiledOp);
+ results.set(getTileOp().cast<OpResult>(), result->tileOp.getOperation());
+ return success();
+}
+
+LogicalResult
+LinalgExt::FuseIntoContainingOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
+ ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
+ for (auto it : llvm::zip(producerOps, containingOps)) {
+ auto producerOp = dyn_cast<linalg::LinalgOp>(std::get<0>(it));
+ Operation *containingOp = std::get<1>(it);
+ if (!producerOp) {
+ std::get<0>(it)->emitError("Cannot fuse op: Not a LinalgOp");
+ return failure();
+ }
+ LinalgExt::LinalgExtFusionInContainingOpPattern pattern(this->getContext(),
+ containingOp);
+ if (failed(functional::applyReturningPatternAt(pattern, producerOp)))
+ return failure();
+ }
+ return success();
+}
+
+FailureOr<scf::ForOp> LinalgExt::RewriteLinalgExtTileToScfForOp::applyToOne(
+ LinalgExt::TileOp target) {
+ LinalgExt::TileOpToSCFRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::TileOp op,
+ PatternRewriter &rewriter) -> FailureOr<scf::ForOp> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+FailureOr<LinalgExt::InParallelOp>
+LinalgExt::RewriteLinalgExtTileToInParallelOp::applyToOne(
+ LinalgExt::TileOp target) {
+ LinalgExt::TileOpToInParallelRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::TileOp op,
+ PatternRewriter &rewriter) -> FailureOr<LinalgExt::InParallelOp> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+FailureOr<Operation *>
+LinalgExt::RewriteLinalgExtInParallelToAsyncOp::applyToOne(
+ LinalgExt::InParallelOp target) {
+ LinalgExt::InParallelOpToAsyncRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::InParallelOp op,
+ PatternRewriter &rewriter) -> FailureOr<Operation *> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+LogicalResult LinalgExt::RewriteLinalgExtInParallelToHALOp::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ LinalgExt::InParallelOpToHALRewriter pattern(this->getContext());
+ ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+ return functional::applyReturningPatternAt(
+ pattern, cast<LinalgExt::InParallelOp>(targets.front()));
+}
+
+FailureOr<scf::ForOp>
+LinalgExt::RewriteLinalgExtInParallelToScfForOp::applyToOne(
+ LinalgExt::InParallelOp target) {
+ LinalgExt::InParallelOpToScfForRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::InParallelOp op,
+ PatternRewriter &rewriter) -> FailureOr<scf::ForOp> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt
index 7b20d53..d9f3b91 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_library(IREELinalgTransformDialect
LinalgTransformOps.cpp
- PDL.cpp
ScopedTransform.cpp
+ StructuredTransformOpsExt.cpp
TransformOpInterface.cpp
TrackingListener.cpp
TrackingRewriteDriver.cpp
@@ -22,6 +22,7 @@
MLIRLinalg
MLIRPDL
MLIRRewrite
+ MLIRTransformDialect
# Transforms
MLIRAsyncTransforms
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
index 7b6173c..c61e616 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
@@ -9,7 +9,6 @@
#include <algorithm>
#include "FunctionHelpers.h"
-#include "PDL.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
@@ -71,102 +70,6 @@
>();
}
-//===----------------------------------------------------------------------===//
-// Functional Rewrite Helpers
-//===----------------------------------------------------------------------===//
-
-using FunctionalLinalgTransform =
- std::function<FailureOr<LinalgOp>(LinalgOp, PatternRewriter &)>;
-
-/// Extracts a vector of int64_t from an array attribute. Asserts if the
-/// attribute contains values other than integers.
-static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
- SmallVector<int64_t> result;
- result.reserve(attr.size());
- for (APInt value : attr.getAsValueRange<IntegerAttr>())
- result.push_back(value.getSExtValue());
- return result;
-}
-
-/// Extracts a vector of unsigned from an array attribute. Asserts if the
-/// attribute contains values other than intergers. May truncate.
-static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
- SmallVector<unsigned> result;
- result.reserve(attr.size());
- for (APInt value : attr.getAsValueRange<IntegerAttr>())
- result.push_back(value.getZExtValue());
- return result;
-}
-
-/// Apply a tiling transformation to all payload ops and store both the
-/// tiled operation as well as the created tile loops.
-static LogicalResult
-applyTilingToAll(Operation *transformOp, Value target,
- ArrayRef<int64_t> tileSizes,
- transform::TransformResults &transformResults,
- transform::TransformState &state,
- std::function<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
- size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
-
- SmallVector<Operation *> tiledLinalgOps;
- SmallVector<SmallVector<Operation *>> loopOps(numLoops);
-
- for (Operation *target : state.getPayloadOps(target)) {
- auto linalgOp = cast<linalg::LinalgOp>(target);
- FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
- if (failed(tiled))
- return failure();
-
- tiledLinalgOps.push_back(tiled->op);
- if (tiled->loops.size() != numLoops)
- // Not enough loops were generated. This usually means that the input size
- // was smaller than the tiling size.
- // TODO: LinalgTilingPattern should return failure().
- return failure();
- for (unsigned int i = 0; i < numLoops; ++i) {
- loopOps[i].push_back(tiled->loops[i]);
- }
- }
-
- transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
- for (unsigned int i = 0; i < numLoops; ++i) {
- transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
- }
- return success();
-}
-
-/// Parse a tiling operation that returns the tiled op as well as the created
-/// tile loops. The function counts the non-zero tile sizes to compute the
-/// number of results.
-static ParseResult parseTileOp(OpAsmParser &parser, OperationState &result,
- StringRef sizesAttrName) {
- OpAsmParser::UnresolvedOperand targetOperand;
- SMLoc opLoc;
- parser.getCurrentLocation(&opLoc);
- if (parser.parseOperand(targetOperand))
- return parser.emitError(opLoc, "expected `target` operand");
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
- Attribute sizesAttr = result.attributes.get(sizesAttrName);
- if (!sizesAttr) {
- return parser.emitError(
- opLoc, llvm::formatv("expected `{0}` attribute", sizesAttrName));
- }
- auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
- if (!sizesArrayAttr) {
- return parser.emitError(
- opLoc,
- llvm::formatv("`{0}` attribute must be an array", sizesAttrName));
- }
- Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
- size_t numExpectedLoops =
- sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
- result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
- if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
- return failure();
- return success();
-}
-
//===---------------------------------------------------------------------===//
// ScopeOp
//===---------------------------------------------------------------------===//
@@ -180,991 +83,5 @@
regions.emplace_back(&body());
}
-//===---------------------------------------------------------------------===//
-// SequenceOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult transform::SequenceOp::verify() {
- WalkResult result = this->walk([](Operation *child) {
- for (OpResult result : child->getResults()) {
- if (llvm::hasNItemsOrLess(result.getUses(), 1))
- continue;
- InFlightDiagnostic diag = child->emitError()
- << "result #" << result.getResultNumber()
- << " has more than one use";
- for (OpOperand &use : result.getUses()) {
- diag.attachNote(use.getOwner()->getLoc())
- << "used here as operand #" << use.getOperandNumber();
- }
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
- return failure(result.wasInterrupted());
-}
-
-//===---------------------------------------------------------------------===//
-// MatchOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult transform::MatchOp::apply(TransformResults &results,
- TransformState &state) {
- Operation *topLevelOp = state.getTopLevel();
- FailureOr<SmallVector<Operation *>> ops = findMatchingOps(*this, topLevelOp);
- if (failed(ops))
- return failure();
- LLVM_DEBUG(DBGS() << "matched " << ops->size() << " ops\n");
- results.set(getResult().cast<OpResult>(), *ops);
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// TileOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult transform::TileOp::apply(TransformResults &transformResults,
- TransformState &state) {
- LinalgTilingOptions tilingOptions;
- SmallVector<int64_t> tileSizes = extractI64Array(sizes());
-
- if (!tileSizes.empty())
- tilingOptions.setTileSizes(tileSizes);
- tilingOptions.setInterchange(extractUIntArray(interchange()));
- LinalgTilingPattern pattern(getContext(), tilingOptions);
- auto functionalTile =
- [&](LinalgOp op, PatternRewriter &rewriter) -> FailureOr<TiledLinalgOp> {
- return pattern.returningMatchAndRewrite(op, rewriter);
- };
-
- return applyTilingToAll(getOperation(), target(), tileSizes, transformResults,
- state, [&](LinalgOp linalgOp) {
- return functional::applyAt(linalgOp,
- functionalTile);
- });
-}
-
-ParseResult transform::TileOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseTileOp(parser, result, "sizes");
-}
-
-void transform::TileOp::print(OpAsmPrinter &p) {
- p << ' ';
- p << target();
- p.printOptionalAttrDict((*this)->getAttrs());
-}
-
-//===---------------------------------------------------------------------===//
-// ScalarizeOp
-//===---------------------------------------------------------------------===//
-
-FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
- LinalgTilingOptions tilingOptions;
- tilingOptions.scalarizeDynamicDims();
- // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
- // sizes and asserts that it is not already set.
- SmallVector<int64_t> emptyTileSizes;
- LinalgTilingPattern pattern(getContext(), tilingOptions);
- auto maybeTiledLinalgOp =
- functional::applyReturningPatternAt(pattern, target);
- if (failed(maybeTiledLinalgOp))
- return failure();
- return maybeTiledLinalgOp->op;
-}
-
-//===---------------------------------------------------------------------===//
-// FuseOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult transform::FuseOp::apply(TransformResults &transformResults,
- TransformState &state) {
- LinalgTilingAndFusionOptions fusionOptions;
- fusionOptions.tileSizes = extractI64Array(tile_sizes());
- fusionOptions.tileInterchange = extractI64Array(tile_interchange());
-
- LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
- auto functionalFuse =
- [&](LinalgOp op, PatternRewriter &rewriter) -> FailureOr<TileLoopNest> {
- return pattern.returningMatchAndRewrite(op, rewriter);
- };
-
- return applyTilingToAll(
- getOperation(), target(), fusionOptions.tileSizes, transformResults,
- state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
- FailureOr<TileLoopNest> tileLoopNest =
- functional::applyAt(linalgOp, functionalFuse);
- if (failed(tileLoopNest))
- return failure();
-
- TiledLinalgOp tiledLinalgOp;
- tiledLinalgOp.op = tileLoopNest->getRootOp();
- tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
- tileLoopNest->getLoopOps().end()};
- return tiledLinalgOp;
- });
-}
-
-LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation = extractI64Array(tile_interchange());
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError() << "expects interchange to be a permutation, found "
- << tile_interchange();
- }
- return success();
-}
-
-ParseResult transform::FuseOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseTileOp(parser, result, "tile_sizes");
-}
-
-void transform::FuseOp::print(OpAsmPrinter &p) {
- p << ' ';
- p << target();
- p.printOptionalAttrDict((*this)->getAttrs());
-}
-
-//===---------------------------------------------------------------------===//
-// FuseProducersOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult
-transform::FuseProducersOp::apply(TransformResults &transformResults,
- TransformState &state) {
- SmallVector<int64_t> operandsToFuse = extractI64Array(operands_to_fuse());
- LinalgExt::LinalgExtFusionPattern pattern(getContext(), operandsToFuse);
- size_t numProducers = operandsToFuse.size();
-
- SmallVector<Operation *> transformedOps;
- SmallVector<SmallVector<Operation *>> fusedOps(numProducers);
- for (Operation *target : state.getPayloadOps(target())) {
- // Apply the pattern.
- FailureOr<LinalgExt::FusionResult> result =
- functional::applyReturningPatternAt(pattern,
- cast<linalg::LinalgOp>(target));
- if (failed(result))
- return failure();
-
- // Update the fused operations.
- transformedOps.push_back(result->consumerOp);
- for (size_t i = 0; i < numProducers; ++i)
- fusedOps[i].push_back(result->fusedOps[i]);
- }
-
- transformResults.set(transformed().cast<OpResult>(), transformedOps);
- for (size_t i = 0; i < numProducers; ++i)
- transformResults.set(fused_ops()[i], fusedOps[i]);
- return success();
-}
-
-LogicalResult transform::FuseProducersOp::verify() {
- SmallVector<int64_t> operandsToFuse = extractI64Array(operands_to_fuse());
- llvm::SmallDenseSet<int64_t> operandsSet;
- for (int64_t operandToFuse : operandsToFuse) {
- if (operandToFuse < 0) {
- return emitOpError() << "expects positive operand numbers, found "
- << operandToFuse;
- }
- if (operandsSet.count(operandToFuse) != 0) {
- return emitOpError() << "expects unique operand numbers, found "
- << operandToFuse << " multiple times";
- }
- operandsSet.insert(operandToFuse);
- }
- return success();
-}
-
-ParseResult transform::FuseProducersOp::parse(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::UnresolvedOperand targetOperand;
- SMLoc opLoc;
- parser.getCurrentLocation(&opLoc);
- if (parser.parseOperand(targetOperand))
- return parser.emitError(opLoc, "expected `target` operand");
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
- StringRef operandsToFuseAttrName("operands_to_fuse");
- Attribute operandsToFuseAttr = result.attributes.get(operandsToFuseAttrName);
- if (!operandsToFuseAttr) {
- return parser.emitError(opLoc, llvm::formatv("expected `{0}` attribute",
- operandsToFuseAttrName));
- }
- auto operandsToFuseArrayAttr = operandsToFuseAttr.dyn_cast<ArrayAttr>();
- if (!operandsToFuseArrayAttr) {
- return parser.emitError(opLoc,
- llvm::formatv("`{0}` attribute must be an array",
- operandsToFuseAttrName));
- }
- Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
- size_t numProducers = operandsToFuseArrayAttr.size();
- result.addTypes(SmallVector<Type>(numProducers + 1, pdlOpType));
- if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
- return failure();
- return success();
-}
-
-void transform::FuseProducersOp::print(OpAsmPrinter &p) {
- p << ' ';
- p << target();
- p.printOptionalAttrDict((*this)->getAttrs());
-}
-
-//===---------------------------------------------------------------------===//
-// GeneralizeOp
-//===---------------------------------------------------------------------===//
-
-FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
- // Exit early if no transformation is needed.
- if (isa<GenericOp>(target))
- return target;
- return functional::applyAt(
- target, callLinalgPattern<LinalgGeneralizationPattern>(getContext()));
-}
-
-//===---------------------------------------------------------------------===//
-// InterchangeOp
-//===---------------------------------------------------------------------===//
-
-FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
- SmallVector<unsigned> interchangeVector =
- extractUIntArray(iterator_interchange());
- // Exit early if no transformation is needed.
- if (interchangeVector.empty())
- return target;
- return functional::applyAt(target,
- callLinalgPattern<GenericOpInterchangePattern>(
- getContext(), interchangeVector));
-}
-
-LogicalResult transform::InterchangeOp::verify() {
- SmallVector<unsigned> permutation = extractUIntArray(iterator_interchange());
- auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError()
- << "expects iterator_interchange to be a permutation, found "
- << iterator_interchange();
- }
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// PadOp
-//===---------------------------------------------------------------------===//
-
-FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
- // Convert the integer packing flags to booleans.
- SmallVector<bool> packPaddings;
- for (int64_t packPadding : extractI64Array(this->pack_paddings()))
- packPaddings.push_back(static_cast<bool>(packPadding));
-
- // Convert the padding values to attributes.
- SmallVector<Attribute> paddingValues;
- for (auto const &it :
- llvm::zip(this->padding_values(), target->getOperandTypes())) {
- Attribute attr = std::get<0>(it);
- Type elementType = getElementTypeOrSelf(std::get<1>(it));
- // Try to parse string attributes to obtain an attribute of element type.
- if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
- paddingValues.push_back(
- parseAttribute(attr.cast<StringAttr>(), elementType));
- if (!paddingValues.back()) {
- return target->emitOpError("expects a padding value ")
- << std::get<0>(it) << " that parses to " << elementType;
- }
- continue;
- }
- // Otherwise, add the attribute directly.
- if (attr.getType() != elementType) {
- return target->emitOpError("expects a padding value ")
- << attr << " of type " << elementType;
- }
- paddingValues.push_back(attr);
- }
-
- // Extract the transpose vectors.
- SmallVector<SmallVector<int64_t>> transposePaddings;
- for (Attribute transposeVector : this->transpose_paddings().cast<ArrayAttr>())
- transposePaddings.push_back(
- extractI64Array(transposeVector.cast<ArrayAttr>()));
-
- LinalgPaddingOptions paddingOptions;
- paddingOptions.setPaddingValues(paddingValues);
- paddingOptions.setPaddingDimensions(
- extractI64Array(this->padding_dimensions()));
- paddingOptions.setPackPaddings(packPaddings);
- paddingOptions.setHoistPaddings(extractI64Array(this->hoist_paddings()));
- paddingOptions.setTransposePaddings(transposePaddings);
- auto res = functional::applyAt(
- target,
- callLinalgPattern<LinalgPaddingPattern>(getContext(), paddingOptions));
- if (failed(res))
- return target->emitOpError()
- << "failed to apply LinalgPaddingPattern at: " << target;
- return res;
-}
-
-LogicalResult transform::PadOp::verify() {
- SmallVector<int64_t> packPaddings = extractI64Array(pack_paddings());
- if (any_of(packPaddings, [](int64_t packPadding) {
- return packPadding != 0 && packPadding != 1;
- })) {
- return emitOpError()
- << "expects pack_paddings to contain booleans (0/1), found "
- << pack_paddings();
- }
- SmallVector<int64_t> paddingDimensions =
- extractI64Array(padding_dimensions());
- if (any_of(paddingDimensions,
- [](int64_t paddingDimension) { return paddingDimension < 0; })) {
- return emitOpError()
- << "expects padding_dimensions to contain positive integers, found "
- << padding_dimensions();
- }
- SmallVector<int64_t> hoistPaddings = extractI64Array(hoist_paddings());
- if (any_of(hoistPaddings,
- [](int64_t hoistPadding) { return hoistPadding < 0; })) {
- return emitOpError()
- << "expects hoist_paddings to contain positive integers, found "
- << hoist_paddings();
- }
- ArrayAttr transposes = transpose_paddings();
- for (Attribute attr : transposes) {
- SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- transpose.begin(), transpose.end())) {
- return emitOpError()
- << "expects transpose_paddings to be a permutation, found "
- << attr;
- }
- }
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// DecomposeOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult
-transform::DecomposeOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
- RewritePatternSet patterns(getContext());
- // TODO: make this targetable.
- populateDecomposeConvolutionPatterns(patterns, LinalgTransformationFilter());
- if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(),
- std::move(patterns))))
- return failure();
-
- // TODO: make this chainable, it isn't in the original codegenstrategy.
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// VectorizeOp
-//===---------------------------------------------------------------------===//
-
-static void configureVectorizationPatterns(transform::VectorizeOp vectorizeOp,
- RewritePatternSet &patterns) {
- MLIRContext *ctx = vectorizeOp->getContext();
- vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
- vector::populateVectorReductionToContractPatterns(patterns);
- patterns.add<linalg::LinalgCopyVTRForwardingPattern,
- linalg::LinalgCopyVTWForwardingPattern>(ctx,
- /*benefit=*/2);
- vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
- vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
- if (vectorizeOp.vectorize_padding())
- linalg::populatePadOpVectorizationPatterns(patterns);
-}
-
-/// Applies the transformation specified by the given vectorize operation to the
-/// given target operation AND some related operations. Populates `results` with
-/// transformation operations for further transformations if the pattern applied
-/// successfully (currently, the main "contraction" op after vectorization).
-static FailureOr<LinalgOp>
-executeTargetedVectorizeOp(LinalgOp target,
- linalg::transform::VectorizeOp vectorizeOp) {
- // TODO: this is copy-pasta from LinalgStrategyVectorizePass, it shouldn't be.
- MLIRContext *ctx = target->getContext();
- RewritePatternSet patterns(ctx);
- configureVectorizationPatterns(vectorizeOp, patterns);
- LinalgVectorizationPattern pattern(vectorizeOp.getContext());
- auto functionalVectorize = [&](LinalgOp op, PatternRewriter &rewriter) {
- return pattern.matchAndRewrite(op, rewriter);
- };
-
- /// Apply the transformations in a scope.
- return transform::scoped(
- target,
- [&](transform::ScopeOp scope, Operation *op) -> FailureOr<LinalgOp> {
- if (failed(functional::applyAt(op, functionalVectorize)) ||
- failed(applyPatternsAndFoldGreedily(scope, std::move(patterns))))
- return failure();
- // FIXME: Vectorization doesn't return anything.
- return LinalgOp();
- });
-
- // TODO: vectorization may fail because the op is not vectorizable, unclear
- // what to do here. We should probably report it somehow, but we may also
- // want to go on and keep the original for continuation. Should we have
- // some notion of transformation optionality vs. mandatory (like lowering)?
- // How to find ops that were not replaced?
-}
-
-LogicalResult
-transform::VectorizeOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
- if (target()) {
- SmallVector<Operation *> resultVector;
- LogicalResult res = applyTransformToEach(
- state.getPayloadOps(target()), resultVector, [&](LinalgOp target) {
- return executeTargetedVectorizeOp(target, *this);
- });
-
- if (failed(res))
- return failure();
-
- results.set(getResult(0).cast<OpResult>(), resultVector);
- return success();
- }
-
- MLIRContext *ctx = getContext();
- RewritePatternSet patterns(ctx);
- patterns.add<LinalgVectorizationPattern>(ctx);
- configureVectorizationPatterns(*this, patterns);
- auto &listener = state.getExtension<TrackingListener>();
- LogicalResult applicationResult = applyPatternsTrackAndFoldGreedily(
- state.getTopLevel(), listener, std::move(patterns));
- LogicalResult listenerResult = listener.checkErrorState();
- return failure(failed(applicationResult) || failed(listenerResult));
-}
-
-ParseResult transform::VectorizeOp::parse(OpAsmParser &parser,
- OperationState &result) {
- auto operationType = pdl::OperationType::get(parser.getContext());
- OpAsmParser::UnresolvedOperand target;
- OptionalParseResult parseResult = parser.parseOptionalOperand(target);
- if (parseResult.hasValue()) {
- if (parseResult.getValue().failed() ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.resolveOperand(target, operationType, result.operands) ||
- parser.addTypeToList(operationType, result.types)) {
- return failure();
- }
- } else {
- if (parser.parseOptionalAttrDict(result.attributes)) {
- return failure();
- }
- }
- return success();
-}
-
-void transform::VectorizeOp::print(OpAsmPrinter &printer) {
- if (target())
- printer << " " << target() << " ";
- printer.printOptionalAttrDict(getOperation()->getAttrs());
-}
-
-//===---------------------------------------------------------------------===//
-// LowerVectorsOp
-//===---------------------------------------------------------------------===//
-
-/// Returns true if the numbered vector lowering stage is included into the list
-/// of stages specified on the given lowerVectors operation.
-static bool stageIncluded(int stage, transform::LowerVectorsOp lowerVectorsOp) {
- for (auto s : lowerVectorsOp.stages().getAsValueRange<IntegerAttr>()) {
- if (s.getSExtValue() == stage)
- return true;
- }
- return false;
-}
-
-/// Applies the transformation specified by the given lower vectors operation
-/// to the given function.
-LogicalResult
-transform::LowerVectorsOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
- MLIRContext *ctx = getContext();
- RewritePatternSet patterns(ctx);
-
- vector::VectorTransposeLowering vectorTransposeLowering =
- llvm::StringSwitch<vector::VectorTransposeLowering>(transpose_lowering())
- .Case("eltwise", vector::VectorTransposeLowering::EltWise)
- .Case("flat_transpose", vector::VectorTransposeLowering::Flat)
- .Case("shuffle", vector::VectorTransposeLowering::Shuffle)
- .Default(vector::VectorTransposeLowering::EltWise);
- vector::VectorMultiReductionLowering vectorMultiReductionLowering =
- llvm::StringSwitch<vector::VectorMultiReductionLowering>(
- multireduction_lowering())
- .Case("innerreduction",
- vector::VectorMultiReductionLowering::InnerReduction)
- .Default(vector::VectorMultiReductionLowering::InnerParallel);
- vector::VectorContractLowering vectorContractLowering =
- llvm::StringSwitch<vector::VectorContractLowering>(contraction_lowering())
- .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
- .Case("dot", vector::VectorContractLowering::Dot)
- .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
- .Default(vector::VectorContractLowering::OuterProduct);
- // TODO: fix the annoying name mismatch (vector-transfers vs vector-transfer).
- vector::VectorTransferSplit vectorTransferSplit =
- llvm::StringSwitch<vector::VectorTransferSplit>(split_transfers())
- .Case("none", vector::VectorTransferSplit::None)
- .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
- .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
- .Default(vector::VectorTransferSplit::None);
-
- vector::VectorTransformsOptions vectorTransformOptions;
- vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering)
- .setVectorMultiReductionLowering(vectorMultiReductionLowering)
- .setVectorTransposeLowering(vectorTransposeLowering)
- .setVectorTransferSplit(vectorTransferSplit);
-
- VectorTransferToSCFOptions vectorTransferToSCFOptions =
- VectorTransferToSCFOptions()
- .enableFullUnroll(unroll_vector_transfers())
- .enableLowerPermutationMaps();
-
- int maxTransferRank = 1;
-
- auto avx2LoweringOptions =
- x86vector::avx2::LoweringOptions().setTransposeOptions(
- x86vector::avx2::TransposeLoweringOptions()
- .lower4x8xf32(transpose_avx2_lowering())
- .lower8x8xf32(transpose_avx2_lowering()));
-
- // TODO: this is copy-pasta from LinalgStrategyLowerVectorsPass, shouldn't be.
- vector::populateVectorToVectorCanonicalizationPatterns(patterns);
- if (stageIncluded(1, *this)) {
- patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
- mlir::vector::ContractionOpToMatmulOpLowering,
- mlir::vector::ContractionOpLowering>(vectorTransformOptions,
- ctx);
- vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
- }
- if (stageIncluded(2, *this)) {
- vector::populateVectorMultiReductionLoweringPatterns(
- patterns, vectorTransformOptions.vectorMultiReductionLowering);
- }
- if (stageIncluded(3, *this)) {
- patterns.add<vector::VectorTransferFullPartialRewriter>(
- ctx, vectorTransformOptions);
- }
- if (stageIncluded(4, *this)) {
- vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
- }
- if (stageIncluded(5, *this)) {
- populateVectorToSCFConversionPatterns(
- patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank));
- }
- if (stageIncluded(6, *this)) {
- vector::populateVectorShapeCastLoweringPatterns(patterns);
- }
- if (stageIncluded(7, (*this))) {
- vector::populateVectorTransposeLoweringPatterns(patterns,
- vectorTransformOptions);
- if (transpose_avx2_lowering())
- x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
- patterns, avx2LoweringOptions, /*benefit=*/10);
- }
-
- // TODO: these transformations are currently not targeted at concrete ops.
- // LinalgTransformationFilter filter = makeTransformationFilter(target);
- if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(),
- std::move(patterns))))
- return failure();
-
- // TODO: make composable...
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// BufferizeOp
-//===---------------------------------------------------------------------===//
-
-static void applyBufferizationEnablingTransformations(ModuleOp moduleOp) {
- RewritePatternSet patterns(moduleOp.getContext());
- patterns.add<GeneralizePadOpPattern>(moduleOp.getContext());
- (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
-}
-
-LogicalResult transform::BufferizeOp::apply(transform::TransformResults &result,
- transform::TransformState &state) {
- bufferization::OneShotBufferizationOptions options;
- options.bufferizeFunctionBoundaries = true;
- options.memCpyFn = [](OpBuilder &builder, Location loc, Value from,
- Value to) {
- return success(linalg::makeMemRefCopyOp(builder, loc, from, to));
- };
-
- auto moduleOp = cast<ModuleOp>(state.getTopLevel());
- applyBufferizationEnablingTransformations(moduleOp);
- if (failed(runOneShotModuleBufferize(moduleOp, options)))
- return failure();
-
- // Perform buffer-level hoistings.
- state.getTopLevel()->walk(
- [&](func::FuncOp funcOp) { hoistRedundantVectorTransfers(funcOp); });
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// LowerToLLVMOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult
-transform::LowerToLLVMOp::apply(transform::TransformResults &result,
- transform::TransformState &state) {
- // TODO: it is feasible to scope lowering at arbitrary level and introduce
- // unrealized casts, but there needs to be the final module-wise cleanup in
- // the end. Keep module-level for now.
- PassManager pm(getContext());
-
- pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
- pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
- if (enable_async()) {
- pm.addPass(createAsyncToAsyncRuntimePass());
- pm.addPass(createAsyncRuntimeRefCountingPass());
- pm.addPass(createAsyncRuntimeRefCountingOptPass());
- }
- pm.addPass(createCanonicalizerPass());
- pm.addPass(createLowerAffinePass());
- pm.addPass(createConvertSCFToCFPass());
- pm.addPass(createConvertLinalgToLLVMPass());
- pm.addPass(createConvertVectorToLLVMPass(
- // clang-format off
- LowerVectorToLLVMOptions()
- .enableReassociateFPReductions(reassociate_fp_reductions())
- .enableIndexOptimizations(enable_index_optimizations())
- .enableArmNeon(enable_arm_neon())
- .enableArmSVE(enable_arm_sve())
- .enableAMX(enable_amx())
- .enableX86Vector(enable_x86vector())));
- // clang-format on
- pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
- pm.addPass(createMemRefToLLVMPass());
- if (enable_async())
- pm.addPass(createConvertAsyncToLLVMPass());
- pm.addPass(createConvertFuncToLLVMPass());
- pm.addPass(createReconcileUnrealizedCastsPass());
- if (failed(pm.run(state.getTopLevel())))
- return failure();
-
- // Make all arguments noalias for now.
- // FIXME: this is a terrible hack!
- state.getTopLevel()->walk([](LLVM::LLVMFuncOp funcOp) {
- for (int64_t i = 0; i < funcOp.getNumArguments(); ++i) {
- if (!funcOp.getFunctionType()
- .getParamType(i)
- .isa<LLVM::LLVMPointerType>())
- continue;
- funcOp.setArgAttr(i, "llvm.noalias", UnitAttr::get(funcOp.getContext()));
- }
- });
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// GetParentLoopOp
-//===---------------------------------------------------------------------===//
-
-FailureOr<scf::ForOp>
-transform::GetParentLoopOp::applyToOne(Operation *source) {
- int64_t nLoops = num_loops();
- for (int64_t i = 0; i < nLoops; ++i) {
- source = source->getParentOfType<scf::ForOp>();
- if (!source) {
- emitError() << "the transformed op is enclosed by " << i << " loops, but "
- << nLoops << " expected";
- return failure();
- }
- }
- return cast<scf::ForOp>(source);
-}
-
-//===---------------------------------------------------------------------===//
-// UnrollLoopOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult transform::UnrollLoopOp::applyToOne(scf::ForOp loop) {
- return loopUnrollByFactor(loop, factor());
-}
-
-//===---------------------------------------------------------------------===//
-// PeelLoopOp
-//===---------------------------------------------------------------------===//
-
-FailureOr<scf::ForOp> transform::PeelLoopOp::applyToOne(scf::ForOp loop) {
- scf::ForOp result;
- IRRewriter rewriter(loop->getContext());
- LogicalResult status =
- scf::peelAndCanonicalizeForLoop(rewriter, loop, result);
- if (failed(status))
- return loop;
- return result;
-}
-
-//===---------------------------------------------------------------------===//
-// PipelineLoopOp
-//===---------------------------------------------------------------------===//
-
-static void
-loopScheduling(scf::ForOp forOp,
- std::vector<std::pair<Operation *, unsigned>> &schedule,
- unsigned iterationInterval, unsigned readLatency) {
- auto getLatency = [&](Operation *op) {
- if (isa<vector::TransferReadOp>(op))
- return readLatency;
- return unsigned(1);
- };
-
- DenseMap<Operation *, unsigned> opCycles;
- std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
- for (Operation &op : forOp.getBody()->getOperations()) {
- if (isa<scf::YieldOp>(op))
- continue;
- unsigned earlyCycle = 0;
- for (Value operand : op.getOperands()) {
- Operation *def = operand.getDefiningOp();
- if (!def)
- continue;
- earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
- }
- opCycles[&op] = earlyCycle;
- wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
- }
- for (auto it : wrappedSchedule) {
- for (Operation *op : it.second) {
- unsigned cycle = opCycles[op];
- schedule.push_back(std::make_pair(op, cycle / iterationInterval));
- }
- }
-}
-
-FailureOr<scf::ForOp> transform::PipelineLoopOp::applyToOne(scf::ForOp loop) {
- // TODO: make the pipelining pattern return the transformed loop.
- if (!getOperation()->getUses().empty()) {
- InFlightDiagnostic diag = emitError()
- << "NYI: cannot target the result of pipelining";
- diag.attachNote(getOperation()->use_begin()->getOwner()->getLoc())
- << "use here";
- return failure();
- }
-
- scf::PipeliningOption schedule;
- schedule.getScheduleFn =
- [this](scf::ForOp forOp,
- std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
- loopScheduling(forOp, schedule, iteration_interval(), read_latency());
- };
-
- RewritePatternSet patterns(loop->getContext());
- scf::populateSCFLoopPipeliningPatterns(patterns, schedule);
- assert(patterns.getNativePatterns().size() == 1 &&
- "expected one pipelining pattern");
- auto functionalPattern = [&patterns](scf::ForOp forOp,
- PatternRewriter &rewriter) {
- RewritePattern *pattern = patterns.getNativePatterns().front().get();
- return pattern->matchAndRewrite(forOp, rewriter);
- };
- if (failed(functional::applyAt(loop, std::move(functionalPattern))))
- return failure();
-
- return scf::ForOp();
-}
-
-//===---------------------------------------------------------------------===//
-// OutlineLoopOp
-//===---------------------------------------------------------------------===//
-
-static scf::ExecuteRegionOp outlineInExecuteRegion(RewriterBase &b,
- Operation *op) {
- if (op->getNumRegions() != 1)
- return nullptr;
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
- scf::ExecuteRegionOp executeRegionOp =
- b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
- {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
- Operation *clonedOp = b.cloneWithoutRegions(*op);
- Region &clonedRegion = clonedOp->getRegions().front();
- assert(clonedRegion.empty() && "expected empty region");
- b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
- clonedRegion.end());
- b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
- }
- b.replaceOp(op, executeRegionOp.getResults());
- return executeRegionOp;
-}
-
-static FailureOr<func::FuncOp> outlineLoop(scf::ForOp loop, StringRef funcName,
- transform::TransformState &state) {
- PatternRewriterListener rewriter(loop->getContext());
- auto &listener = state.getExtension<TrackingListener>();
- rewriter.addListener(&listener);
- Location loc = loop.getLoc();
- scf::ExecuteRegionOp exec = outlineInExecuteRegion(rewriter, loop);
- assert(exec && "failed to produce execute_region");
- FailureOr<func::FuncOp> outlined =
- outlineSingleBlockRegion(rewriter, loc, exec.getRegion(), funcName);
- if (failed(listener.checkErrorState()))
- return failure();
- return outlined;
-}
-
-LogicalResult
-transform::OutlineLoopOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
- SmallVector<Operation *> resultVector;
- auto res =
- applyTransformToEach(state.getPayloadOps(target()), resultVector,
- [&](scf::ForOp loop) -> FailureOr<func::FuncOp> {
- return outlineLoop(loop, func_name(), state);
- });
- if (failed(res))
- return failure();
- results.set(getResult().cast<OpResult>(), resultVector);
- return success();
-}
-
-//===---------------------------------------------------------------------===//
-// PrintOp
-//===---------------------------------------------------------------------===//
-
-LogicalResult transform::PrintOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
- if (!target()) {
- llvm::outs() << "[[[ IR printer: " << name() << " top-level ]]]\n";
- state.getTopLevel()->dump();
- return success();
- }
-
- llvm::outs() << "[[[ IR printer: " << name() << " single op ]]]\n";
- ArrayRef<Operation *> targets = state.getPayloadOps(target());
- targets.front()->dump();
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// LinalgExt specific transforms
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-transform::TileToLinalgExtTileOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
- LinalgTilingOptions tilingOptions;
- SmallVector<int64_t> tileSizes = extractI64Array(sizes());
- if (!tileSizes.empty())
- tilingOptions.setTileSizes(tileSizes);
-
- LinalgExt::LinalgExtTilingPattern pattern(this->getContext(), tilingOptions);
- ArrayRef<Operation *> targets = state.getPayloadOps(target());
- auto tilingInterfaceOp = dyn_cast<TilingInterface>(targets.front());
- if (!tilingInterfaceOp) {
- targets.front()->emitError("Cannot tile op: Not a TilingInterface");
- return failure();
- }
-
- FailureOr<iree_compiler::IREE::LinalgExt::TilingResult> result =
- functional::applyReturningPatternAt(pattern, tilingInterfaceOp);
- if (failed(result))
- return failure();
- results.set(tiled_op().cast<OpResult>(), result->tiledOp);
- results.set(tile_op().cast<OpResult>(), result->tileOp.getOperation());
- return success();
-}
-
-LogicalResult
-transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
-
- ArrayRef<Operation *> producerOps = state.getPayloadOps(producer_op());
- ArrayRef<Operation *> containingOps = state.getPayloadOps(containing_op());
- for (auto it : llvm::zip(producerOps, containingOps)) {
- auto producerOp = dyn_cast<LinalgOp>(std::get<0>(it));
- Operation *containingOp = std::get<1>(it);
- if (!producerOp) {
- std::get<0>(it)->emitError("Cannot fuse op: Not a LinalgOp");
- return failure();
- }
- LinalgExt::LinalgExtFusionInContainingOpPattern pattern(this->getContext(),
- containingOp);
- if (failed(functional::applyReturningPatternAt(pattern, producerOp)))
- return failure();
- }
- return success();
-}
-
-FailureOr<scf::ForOp> transform::RewriteLinalgExtTileToScfForOp::applyToOne(
- LinalgExt::TileOp target) {
- LinalgExt::TileOpToSCFRewriter pattern(this->getContext());
- auto functionalRewrite =
- [&](LinalgExt::TileOp op,
- PatternRewriter &rewriter) -> FailureOr<scf::ForOp> {
- auto result = pattern.returningMatchAndRewrite(op, rewriter);
- if (failed(result))
- return failure();
- return result;
- };
- return functional::applyAt(target, functionalRewrite);
-}
-
-FailureOr<LinalgExt::InParallelOp>
-transform::RewriteLinalgExtTileToInParallelOp::applyToOne(
- LinalgExt::TileOp target) {
- LinalgExt::TileOpToInParallelRewriter pattern(this->getContext());
- auto functionalRewrite =
- [&](LinalgExt::TileOp op,
- PatternRewriter &rewriter) -> FailureOr<LinalgExt::InParallelOp> {
- auto result = pattern.returningMatchAndRewrite(op, rewriter);
- if (failed(result))
- return failure();
- return result;
- };
- return functional::applyAt(target, functionalRewrite);
-}
-
-FailureOr<Operation *>
-transform::RewriteLinalgExtInParallelToAsyncOp::applyToOne(
- LinalgExt::InParallelOp target) {
- LinalgExt::InParallelOpToAsyncRewriter pattern(this->getContext());
- auto functionalRewrite =
- [&](LinalgExt::InParallelOp op,
- PatternRewriter &rewriter) -> FailureOr<Operation *> {
- auto result = pattern.returningMatchAndRewrite(op, rewriter);
- if (failed(result))
- return failure();
- return result;
- };
- return functional::applyAt(target, functionalRewrite);
-}
-
-LogicalResult transform::RewriteLinalgExtInParallelToHALOp::apply(
- transform::TransformResults &results, transform::TransformState &state) {
- LinalgExt::InParallelOpToHALRewriter pattern(this->getContext());
- ArrayRef<Operation *> targets = state.getPayloadOps(target());
- return functional::applyReturningPatternAt(
- pattern, cast<LinalgExt::InParallelOp>(targets.front()));
-}
-
-FailureOr<scf::ForOp>
-transform::RewriteLinalgExtInParallelToScfForOp::applyToOne(
- LinalgExt::InParallelOp target) {
- LinalgExt::InParallelOpToScfForRewriter pattern(this->getContext());
- auto functionalRewrite =
- [&](LinalgExt::InParallelOp op,
- PatternRewriter &rewriter) -> FailureOr<scf::ForOp> {
- auto result = pattern.returningMatchAndRewrite(op, rewriter);
- if (failed(result))
- return failure();
- return result;
- };
- return functional::applyAt(target, functionalRewrite);
-}
-
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp
deleted file mode 100644
index 9e38682..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp
+++ /dev/null
@@ -1,305 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "PDL.h"
-
-#include "iree-dialects/Transforms/Functional.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/PDL/IR/PDLOps.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/ADT/ScopeExit.h"
-
-namespace mlir {
-namespace linalg {
-
-/// Return ops that match any of the patterns.
-static SmallVector<Operation *>
-getMatchingOps(Operation *parent, const FrozenRewritePatternSet &patterns) {
- PatternApplicator applicator(patterns);
- applicator.applyDefaultCostModel();
-
- // TODO: The C++ functional API needs better interoperability with PDL.
- return functional::applyForEachIn(
- parent,
- [&](Operation *op, PatternRewriter &rewriter) -> FailureOr<Operation *> {
- if (succeeded(applicator.matchAndRewrite(op, rewriter)))
- return op;
- return failure();
- });
-}
-
-/// Hook for PDL driver to check if an operation (`values[0]`) is directly
-/// nested in a function with the name provided by an attribute (`values[1]`).
-/// TODO: PDL needs user-defined "questions".
-static LogicalResult nestedInFunc(PatternRewriter &rewriter,
- Operation *operation, Attribute attr) {
- auto func = operation->getParentOfType<func::FuncOp>();
- if (!func)
- return rewriter.notifyMatchFailure(operation, "not nested in a function");
- auto functionSymbol = attr.dyn_cast<SymbolRefAttr>();
- if (!functionSymbol)
- return rewriter.notifyMatchFailure(operation, "not a function identifier");
- return success(functionSymbol.getLeafReference() == func.getName());
-}
-
-/// PDL rewrite hook that does nothing.
-static void noOpRewriter(PatternRewriter &rewriter, Operation *op) {
-#ifndef NDEBUG
- op->setAttr("iree_linalg_transform.matched", rewriter.getUnitAttr());
-#endif
-}
-
-/// Construct a BlockAndValueMapping from `linalgOp` to `genericLinalgModelOp`.
-/// Walk both ops and check whether all subops are the same.
-static LogicalResult haveIdenticalBodiesImpl(LinalgOp linalgOp,
- LinalgOp genericLinalgModelOp) {
- BlockAndValueMapping bvm;
- bvm.map(linalgOp.getBlock()->getArguments(),
- genericLinalgModelOp.getBlock()->getArguments());
- SmallVector<Operation *> linalgBodyOps;
- linalgOp.getBlock()->walk(
- [&](Operation *op) { linalgBodyOps.push_back(op); });
-
- unsigned idx = 0;
- WalkResult res = genericLinalgModelOp.getBlock()->walk([&](Operation *op) {
- Operation *linalgSubOp = linalgBodyOps[idx++];
- if (op->getName() != linalgSubOp->getName())
- return WalkResult::interrupt();
- if (op->getAttrs() != linalgSubOp->getAttrs())
- return WalkResult::interrupt();
- for (auto it : llvm::zip(op->getOperands(), linalgSubOp->getOperands()))
- if (std::get<0>(it) != bvm.lookupOrNull(std::get<1>(it)))
- return WalkResult::interrupt();
- bvm.map(linalgSubOp->getResults(), op->getResults());
- return WalkResult::advance();
- });
-
- return success(!res.wasInterrupted());
-}
-
-/// Dispatch body equivalence check depending on case.
-static LogicalResult haveEquivalentBodies(LinalgOp linalgOp,
- LinalgOp genericLinalgModelOp,
- PatternRewriter &rewriter) {
- if (succeeded(haveIdenticalBodiesImpl(linalgOp, genericLinalgModelOp)))
- return success();
- // TODO: haveEquivalentBodiesImpl, see e.g.
- // https://gist.github.com/nicolasvasilache/39e89e18c46e02335c16db6ec20a07e3
- return failure();
-}
-
-/// Succeed when `linalgOp` and `linalgModelOp` are deemed equivalent.
-static LogicalResult isEquivalentToOpImpl(PatternRewriter &rewriter,
- LinalgOp linalgOp,
- LinalgOp linalgModelOp) {
- // If basic properties do not match, return failure.
- if (linalgOp.inputs() != linalgModelOp.inputs() ||
- linalgOp.outputs() != linalgModelOp.outputs() ||
- linalgOp.indexing_maps() != linalgModelOp.indexing_maps() ||
- linalgOp.iterator_types() != linalgModelOp.iterator_types())
- return failure();
-
- // Build the block and go perform a body comparison.
- {
- // createBlock moves the insertion point, scope it in an RAII block.
- OpBuilder::InsertionGuard guard(rewriter);
- Region &r = linalgModelOp->getRegion(0);
- Block *bodyBlock = rewriter.createBlock(
- &r, r.end(), linalgOp.getBlock()->getArgumentTypes(),
- llvm::to_vector<4>(
- llvm::map_range(linalgOp.getBlock()->getArguments(),
- [](Value v) { return v.getLoc(); })));
- ImplicitLocOpBuilder b(linalgModelOp.getLoc(), rewriter);
- auto regionBuilder = linalgModelOp.getRegionBuilder();
- llvm::ArrayRef<mlir::NamedAttribute> attrs = {};
- regionBuilder(b, *bodyBlock, attrs);
- }
-
- return haveEquivalentBodies(linalgOp, linalgModelOp, rewriter);
-}
-
-/// Check whether the unique Operation* stored in `values[0]` (assumed) is
-/// equivalent to the unique StringRefAttr passed in `values[1]` (assumed).
-/// Equivalence is achieved when either:
-/// 1. `values[0]` has the name stored in `values[1]`.
-/// 2. `values[0]` and `values[1]` are both linalg ops and their structured
-/// interfaces as well as their bodies are equivalent.
-/// Structured interfaces equivalence is a simple attribute level check.
-/// Body equivalence is more involved and currently limited:
-/// a. the current impl constructs an instance of the op whose name is
-/// specified in `values[1]` and checks for exact body equality.
-/// b. a more advanced version would "subtract" the bodies and fold, cse
-/// and canonicalize to fixed point. If the result is "all zeros",
-/// then the bodies would be equivalent (really isomorphic).
-/// 3. other cases TBD (e.g. vector.generic when available).
-static LogicalResult isEquivalentToOp(PatternRewriter &rewriter,
- Operation *operation,
- Attribute attribute) {
- auto modelOpNameAttr = attribute.dyn_cast<StringAttr>();
- if (!modelOpNameAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- auto modelOpName = modelOpNameAttr.strref();
-
- // 1. If op has name `modelOpName`, the match is trivial.
- if (operation->getName().getStringRef() == modelOpName)
- return success();
-
- // 2. Linalg vs Linalg.
- // Create op from `modelOpName`.
- OperationState modelOpState(
- operation->getLoc(), modelOpName, operation->getOperands(),
- operation->getResultTypes(), operation->getAttrs());
- modelOpState.addRegion();
- Operation *modelOp = rewriter.create(modelOpState);
- auto g1 = llvm::make_scope_exit([&]() { rewriter.eraseOp(modelOp); });
- LinalgOp linalgOp = dyn_cast<LinalgOp>(operation);
- LinalgOp linalgModelOp = dyn_cast<LinalgOp>(modelOp);
- if (linalgOp && linalgModelOp)
- return isEquivalentToOpImpl(rewriter, linalgOp, linalgModelOp);
-
- // 3. TBD
- return failure();
-}
-
-/// Assume that:
-/// 1. `values[0]` is an operands range
-/// 2. `values[1]` contains a DictAttr with `operand_number`, `dim` and
-/// `divisor` IntegerAttr entries.
-/// Succeed if `operands`[`operand_number`] is a ranked type whose `dim` is a
-/// multiple of `divisor`.
-/// Note: 0 is the convention to express "do not tile", it is considered to
-/// divide everything.
-static LogicalResult isDimMultipleOf(PatternRewriter &rewriter,
- ValueRange operands, Attribute attribute) {
- auto dict = attribute.dyn_cast<DictionaryAttr>();
- if (!dict)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
-
- int64_t dim;
- auto dimAttr = dict.getAs<IntegerAttr>("dim");
- if (!dimAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- dim = dimAttr.getInt();
-
- int64_t divisor;
- auto divisorAttr = dict.getAs<IntegerAttr>("divisor");
- if (!divisorAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- divisor = divisorAttr.getInt();
-
- int64_t operandNumber;
- auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
- if (!operandNumberAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- operandNumber = operandNumberAttr.getInt();
-
- ShapedType shapedType;
- if (static_cast<int64_t>(operands.size()) > operandNumber)
- shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
- if (!shapedType || shapedType.getRank() <= dim)
- return failure();
- return success(divisor == 0 || (shapedType.getShape()[dim] > 0 &&
- shapedType.getShape()[dim] % divisor == 0));
-}
-
-/// Assume that:
-/// 1. `values[0]` is an operands range
-/// 2. `values[1]` contains a DictAttr with `operand_number` and `dim`
-/// IntegerAttr entries.
-/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is
-/// dynamic.
-static LogicalResult isDimStatic(PatternRewriter &rewriter, ValueRange operands,
- Attribute attribute) {
- auto dict = attribute.dyn_cast<DictionaryAttr>();
- if (!dict)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
-
- int64_t dim;
- auto dimAttr = dict.getAs<IntegerAttr>("dim");
- if (!dimAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- dim = dimAttr.getInt();
-
- int64_t operandNumber;
- auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
- if (!operandNumberAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- operandNumber = operandNumberAttr.getInt();
-
- ShapedType shapedType;
- if (static_cast<int64_t>(operands.size()) > operandNumber)
- shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
- return success(shapedType && !shapedType.isDynamicDim(dim));
-}
-
-/// Assume that:
-/// 1. `values[0]` is an operands range
-/// 2. `values[1]` contains a DictAttr with `operand_number` and `dim`
-/// IntegerAttr entries.
-/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is
-/// dynamic.
-static LogicalResult isDimDynamic(PatternRewriter &rewriter,
- ValueRange operands, Attribute attribute) {
- auto dict = attribute.dyn_cast<DictionaryAttr>();
- if (!dict)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
-
- int64_t dim;
- auto dimAttr = dict.getAs<IntegerAttr>("dim");
- if (!dimAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- dim = dimAttr.getInt();
-
- int64_t operandNumber;
- auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
- if (!operandNumberAttr)
- return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
- operandNumber = operandNumberAttr.getInt();
-
- ShapedType shapedType;
- if (static_cast<int64_t>(operands.size()) > operandNumber)
- shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
- return success(shapedType && shapedType.isDynamicDim(dim));
-}
-
-FailureOr<SmallVector<Operation *>> findMatchingOps(transform::MatchOp matchOp,
- SymbolRefAttr pattern,
- Operation *containerOp) {
- auto symbolTableOp = matchOp->getParentWithTrait<OpTrait::SymbolTable>();
- if (!symbolTableOp)
- symbolTableOp->emitError("no parent op with a SymbolTable");
- auto patternOp = dyn_cast_or_null<pdl::PatternOp>(
- SymbolTable::lookupSymbolIn(symbolTableOp, pattern));
- if (!patternOp) {
- return symbolTableOp->emitError("could not find a pattern named: ")
- << pattern;
- }
-
- // Clone the pattern operation into the temporary module used by the driver
- // as it might be referenced multiple times.
- OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
- OpBuilder::atBlockBegin(pdlModuleOp->getBody()).clone(*patternOp);
-
- // Build the PDL module.
- PDLPatternModule pdlModule(std::move(pdlModuleOp));
- pdlModule.registerConstraintFunction("nestedInFunc", nestedInFunc);
- pdlModule.registerConstraintFunction("isDimDynamic", isDimDynamic);
- pdlModule.registerConstraintFunction("isDimMultipleOf", isDimMultipleOf);
- pdlModule.registerConstraintFunction("isDimStatic", isDimStatic);
- pdlModule.registerConstraintFunction("isEquivalentToOp", isEquivalentToOp);
- pdlModule.registerRewriteFunction("iree_linalg_transform.apply",
- noOpRewriter);
-
- RewritePatternSet patterns(std::move(pdlModule));
- return getMatchingOps(containerOp, std::move(patterns));
-}
-
-} // namespace linalg
-} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.h b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.h
deleted file mode 100644
index 94d126e..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.h
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_PDL_H
-#define IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_PDL_H
-
-#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
-#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/SmallVector.h"
-
-namespace mlir {
-namespace linalg {
-
-/// Find all operations in `containerOp` that are matched by the specified PDL
-/// `matchOp`, which is located in the same parent ModuleOp as `matchOp`.
-FailureOr<SmallVector<Operation *>> findMatchingOps(transform::MatchOp matchOp,
- SymbolRefAttr pattern,
- Operation *containerOp);
-
-inline FailureOr<SmallVector<Operation *>>
-findMatchingOps(transform::MatchOp matchOp, Operation *containerOp) {
- return findMatchingOps(matchOp, matchOp.targetMatcher(), containerOp);
-}
-
-} // namespace linalg
-} // namespace mlir
-
-#endif // IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_PDL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
new file mode 100644
index 0000000..ae2d459
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -0,0 +1,1296 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
+#include "iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h"
+#include "iree-dialects/Transforms/Listener.h"
+#include "iree-dialects/Transforms/ListenerCSE.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define DEBUG_TYPE "transform-ops-ext"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Additional constraints for PDLMatchOp.
+//===----------------------------------------------------------------------===//
+
+/// Hook for PDL driver to check if an operation (`values[0]`) is directly
+/// nested in a function with the name provided by an attribute (`values[1]`).
+/// TODO: PDL needs user-defined "questions".
+static LogicalResult nestedInFunc(PatternRewriter &rewriter,
+ Operation *operation, Attribute attr) {
+ auto func = operation->getParentOfType<func::FuncOp>();
+ if (!func)
+ return rewriter.notifyMatchFailure(operation, "not nested in a function");
+ auto functionSymbol = attr.dyn_cast<SymbolRefAttr>();
+ if (!functionSymbol)
+ return rewriter.notifyMatchFailure(operation, "not a function identifier");
+ return success(functionSymbol.getLeafReference() == func.getName());
+}
+
+//===----------------------------------------------------------------------===//
+// StructuredTransformOpsExtension
+//===----------------------------------------------------------------------===//
+
+transform_ext::StructuredTransformOpsExtension::
+ StructuredTransformOpsExtension() {
+ registerTransformOps<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc"
+ >();
+
+ registerPDLMatchConstraintFn("nestedInFunc", nestedInFunc);
+
+ declareDependentDialect<bufferization::BufferizationDialect>();
+ declareDependentDialect<vector::VectorDialect>();
+
+ // declareDependentDialect<arith::ArithmeticDialect>();
+ // declareDependentDialect<scf::SCFDialect>();
+ // declareDependentDialect<LLVM::LLVMDialect>();
+ // declareDependentDialect<AffineDialect>();
+ // declareDependentDialect<bufferization::BufferizationDialect>();
+ // declareDependentDialect<func::FuncDialect>();
+ // declareDependentDialect<linalg::LinalgDialect>();
+ // // declareDependentDialect<pdl::PDLDialect>();
+ // // declareDependentDialect<pdl_interp::PDLInterpDialect>();
+ // declareDependentDialect<scf::SCFDialect>();
+ // declareDependentDialect<tensor::TensorDialect>();
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TrackingListener
+//===----------------------------------------------------------------------===//
+
+/// Find the linalg op that defines all values in range, potentially
+/// transitively through tensor casts.
+static linalg::LinalgOp findSingleLinalgOpDefiningAll(ValueRange range) {
+ linalg::LinalgOp sourceOp = nullptr;
+ for (Value value : range) {
+ // See through tensor casts.
+ //
+ // TODO: we may need some generalization (interfaces?) of this for other
+ // operations, especially multi-operand ones to understand which of their
+ // operands may be coming from a Linalg op. Or a completely different
+ // mechanism of tracking op replacement at creation, or even different
+ // patterns that identify the "main" result of a transformation.
+ while (auto castOp = value.getDefiningOp<tensor::CastOp>())
+ value = castOp.source();
+
+ if (auto currentSourceOp = value.getDefiningOp<linalg::LinalgOp>()) {
+ if (!sourceOp || sourceOp == currentSourceOp) {
+ sourceOp = currentSourceOp;
+ continue;
+ }
+
+ LLVM_DEBUG(
+ DBGS() << "different source linalg ops for replacing one op: \n"
+ << sourceOp << "\n"
+ << currentSourceOp << "\n");
+ }
+ LLVM_DEBUG(DBGS() << "replacing linalg op with unknown non-linalg op:\n"
+ << *value.getDefiningOp() << "\n");
+ return nullptr;
+ }
+ return sourceOp;
+}
+
+/// Find the scf "for" op that defines all values in the range.
+static scf::ForOp findSingleForOpDefiningAll(ValueRange range) {
+ scf::ForOp forOp = nullptr;
+ for (Value value : range) {
+ if (auto currentSourceOp = value.getDefiningOp<scf::ForOp>()) {
+ if (!forOp || forOp == currentSourceOp) {
+ forOp = currentSourceOp;
+ continue;
+ }
+ LLVM_DEBUG(
+ DBGS() << "different source scf.for ops when replacing one op\n");
+ }
+
+ LLVM_DEBUG(
+ DBGS()
+ << "could not find a source scf.for when replacing another scf.for\n");
+ return nullptr;
+ }
+ return forOp;
+}
+
+// Find a single op that defines all values in the range, optionally
+// transitively through other operations in an op-specific way.
+static Operation *findSingleDefiningOp(Operation *replacedOp,
+ ValueRange range) {
+ return llvm::TypeSwitch<Operation *, Operation *>(replacedOp)
+ .Case<linalg::LinalgOp>([&](linalg::LinalgOp) -> Operation * {
+ return findSingleLinalgOpDefiningAll(range);
+ })
+ .Case<scf::ForOp>([&](scf::ForOp) -> Operation * {
+ return findSingleForOpDefiningAll(range);
+ })
+ .Default([](Operation *) -> Operation * { return nullptr; });
+}
+
+namespace detail {
+class TrackingListener : public RewriteListener,
+ public transform::TransformState::Extension {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TrackingListener);
+
+ explicit TrackingListener(transform::TransformState &state)
+ : transform::TransformState::Extension(state) {}
+
+ ~TrackingListener() override {
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ assert(errorStateChecked && "must check listener error state");
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ }
+
+ void notifyOperationReplaced(Operation *op, ValueRange newValues) override {
+ // Bail out if in error state.
+ if (hadErrors)
+ return;
+
+ // Exit early if the op is not tracked.
+ Value handle = getTransformState().getHandleForPayloadOp(op);
+ if (!handle)
+ return;
+
+ Operation *replacement = findSingleDefiningOp(op, newValues);
+ if (!replacement) {
+ emitError(op) << "could not find replacement for tracked op";
+ return;
+ }
+
+ LLVM_DEBUG(DBGS() << "replacing tracked " << *op << " with " << *replacement
+ << " for " << handle << "\n");
+ mayFail(replacePayloadOp(op, replacement));
+ }
+
+ void notifyOperationRemoved(Operation *op) override {
+ // Bail out if in error state.
+ if (hadErrors)
+ return;
+
+ // Exit early if the op is not tracked.
+ Value handle = getTransformState().getHandleForPayloadOp(op);
+ if (!handle)
+ return;
+
+ LLVM_DEBUG(DBGS() << "removing tracked " << *op << " for " << handle
+ << "\n");
+ mayFail(replacePayloadOp(op, nullptr));
+ }
+
+ LogicalResult checkErrorState() const {
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ errorStateChecked = true;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ return failure(hadErrors);
+ }
+
+private:
+ InFlightDiagnostic emitError(Operation *op, const llvm::Twine &message = {}) {
+ mayFail(failure());
+ return op->emitError(message);
+ }
+
+ void mayFail(LogicalResult result) {
+ hadErrors |= result.failed();
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ errorStateChecked = false;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+ }
+
+ bool hadErrors = false;
+
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+ mutable bool errorStateChecked = false;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+};
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// CanonicalizedSequenceOp
+//===----------------------------------------------------------------------===//
+
+/// Run enabling transformations (LICM and its variants, single-iteration loop
+/// removal, CSE) on the given function.
+static LogicalResult performEnablerTransformations(
+ func::FuncOp func, RewriteListener &listener,
+ linalg::LinalgEnablingOptions options = linalg::LinalgEnablingOptions()) {
+ MLIRContext *ctx = func->getContext();
+ RewritePatternSet patterns(ctx);
+ linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+ scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+ if (failed(applyPatternsTrackAndFoldGreedily(func, listener,
+ std::move(patterns))))
+ return failure();
+
+ // This assumes LICM never removes operations so we don't need tracking.
+ if (options.licm) {
+ func->walk(
+ [](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+ }
+
+ func.walk([](Operation *op) {
+ (void)llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<AffineForOp, scf::ForOp>(
+ [](auto loop) { return promoteIfSingleIteration(loop); })
+ .Default([](Operation *) { return success(); });
+ });
+
+ if (options.hoistRedundantVectorTransfers)
+ linalg::hoistRedundantVectorTransfers(func);
+ if (options.hoistRedundantVectorTransfersOnTensor)
+ linalg::hoistRedundantVectorTransfersOnTensor(func);
+
+ return eliminateCommonSubexpressions(func, /*domInfo=*/nullptr, &listener);
+}
+
+/// Run enabling transformations on the given `containerOp` while preserving the
+/// operation tracking information.
+static LogicalResult performEnablerTransformations(
+ Operation *containerOp, RewriteListener &listener,
+ linalg::LinalgEnablingOptions options = linalg::LinalgEnablingOptions()) {
+ auto res = containerOp->walk([&](func::FuncOp func) {
+ if (failed(performEnablerTransformations(func, listener, options)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ return failure(res.wasInterrupted());
+}
+
+LogicalResult transform_ext::CanonicalizedSequenceOp::apply(
+ transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+
+ MLIRContext *ctx = getContext();
+ RewritePatternSet patternList(ctx);
+ for (Dialect *dialect : ctx->getLoadedDialects())
+ dialect->getCanonicalizationPatterns(patternList);
+ for (RegisteredOperationName op : ctx->getRegisteredOperations())
+ op.getCanonicalizationPatterns(patternList, ctx);
+ FrozenRewritePatternSet patterns(std::move(patternList));
+
+ transform::TransformState::RegionScope regionScope =
+ state.make_region_scope(getBodyRegion());
+ auto &listener = state.addExtension<::detail::TrackingListener>();
+ auto detachListener = llvm::make_scope_exit(
+ [&] { state.removeExtension<::detail::TrackingListener>(); });
+ if (failed(mapBlockArguments(state)))
+ return failure();
+
+ auto checkedListenerTransform =
+ [&](function_ref<LogicalResult(Operation *, RewriteListener &)>
+ transform) {
+ SmallVector<Operation *> roots;
+ if (Value target = getTarget())
+ llvm::append_range(roots, state.getPayloadOps(target));
+ else
+ roots.push_back(state.getTopLevel());
+
+ for (Operation *target : roots) {
+ // Make sure we always check the error state, no boolean
+          // short-circuiting.
+ LogicalResult result = transform(target, listener);
+ LogicalResult listenerResult = listener.checkErrorState();
+ if (failed(result) || failed(listenerResult))
+ return failure();
+ }
+ return success();
+ };
+
+ auto performCSE = [](Operation *root, RewriteListener &listener) {
+ LogicalResult result =
+ eliminateCommonSubexpressions(root, /*domInfo=*/nullptr, &listener);
+ LLVM_DEBUG(
+ DBGS() << (succeeded(result) ? "successfully performed" : "failed")
+ << " CSE\n");
+ return result;
+ };
+ auto performEnabler = [](Operation *root, RewriteListener &listener) {
+ LogicalResult result = performEnablerTransformations(root, listener);
+ LLVM_DEBUG(
+ DBGS() << (succeeded(result) ? "successfully performed" : "failed")
+ << " enabling transformations\n");
+ return result;
+ };
+ auto performCanonicalization = [&patterns](Operation *root,
+ RewriteListener &listener) {
+ LogicalResult result =
+ applyPatternsTrackAndFoldGreedily(root, listener, patterns);
+ LLVM_DEBUG(
+ DBGS() << (succeeded(result) ? "successfully performed" : "failed")
+ << " canonicalization\n");
+ return result;
+ };
+
+ LLVM_DEBUG(DBGS() << "begin canonicalizing sequence\n");
+ if (failed(checkedListenerTransform(performCSE)))
+ return failure();
+ if (failed(checkedListenerTransform(performCanonicalization)))
+ return failure();
+
+ for (Operation &transform : getBodyBlock()->without_terminator()) {
+ if (failed(state.applyTransform(
+ cast<transform::TransformOpInterface>(transform)))) {
+ LLVM_DEBUG(DBGS() << "failed: " << transform << "\n");
+ return failure();
+ }
+ LLVM_DEBUG(DBGS() << "successfully performed: " << transform << "\n");
+
+ if (failed(checkedListenerTransform(performCSE)))
+ return failure();
+ if (failed(checkedListenerTransform(performEnabler)))
+ return failure();
+ if (failed(checkedListenerTransform(performCanonicalization)))
+ return failure();
+ }
+
+ LLVM_DEBUG(DBGS() << "end canonicalizing sequence\n");
+ return success();
+}
+
+void transform_ext::CanonicalizedSequenceOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ SmallVector<MemoryEffects::EffectInstance> childEffects;
+ walk([&](Operation *child) {
+ // Skip self to avoid infinite recursion.
+ if (child == getOperation())
+ return;
+
+ auto iface = dyn_cast<MemoryEffectOpInterface>(child);
+ if (!iface)
+ return;
+
+ childEffects.clear();
+ iface.getEffects(childEffects);
+ llvm::append_range(effects, childEffects);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// TODO: WILL MIGRATE
+//===----------------------------------------------------------------------===//
+
+using namespace mlir::linalg;
+
+/// Extracts a vector of int64_t from an array attribute. Asserts if the
+/// attribute contains values other than integers.
+static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
+ SmallVector<int64_t> result;
+ result.reserve(attr.size());
+ for (APInt value : attr.getAsValueRange<IntegerAttr>())
+ result.push_back(value.getSExtValue());
+ return result;
+}
+
+/// Extracts a vector of unsigned from an array attribute. Asserts if the
+/// attribute contains values other than integers. May truncate.
+static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
+ SmallVector<unsigned> result;
+ result.reserve(attr.size());
+ for (APInt value : attr.getAsValueRange<IntegerAttr>())
+ result.push_back(value.getZExtValue());
+ return result;
+}
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+///
+/// One loop handle is produced per non-zero entry of `tileSizes`: result 0 of
+/// `transformOp` receives all tiled ops, and result i+1 receives the i-th
+/// enclosing tile loop of every payload op. Fails if `applyFn` fails for any
+/// payload op or generates fewer loops than expected.
+static LogicalResult
+applyTilingToAll(Operation *transformOp, Value target,
+                 ArrayRef<int64_t> tileSizes,
+                 mlir::transform::TransformResults &transformResults,
+                 mlir::transform::TransformState &state,
+                 std::function<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
+  size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
+
+  SmallVector<Operation *> tiledLinalgOps;
+  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
+
+  // NOTE: the loop variable shadows the `target` parameter; the range
+  // expression is evaluated first, so getPayloadOps sees the Value parameter.
+  for (Operation *target : state.getPayloadOps(target)) {
+    auto linalgOp = cast<linalg::LinalgOp>(target);
+    FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
+    if (failed(tiled))
+      return failure();
+
+    tiledLinalgOps.push_back(tiled->op);
+    if (tiled->loops.size() != numLoops)
+      // Not enough loops were generated. This usually means that the input size
+      // was smaller than the tiling size.
+      // TODO: LinalgTilingPattern should return failure().
+      return failure();
+    for (unsigned int i = 0; i < numLoops; ++i) {
+      loopOps[i].push_back(tiled->loops[i]);
+    }
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+  for (unsigned int i = 0; i < numLoops; ++i) {
+    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+  }
+  return success();
+}
+
+/// Parse a tiling operation that returns the tiled op as well as the created
+/// tile loops. The function counts the non-zero tile sizes to compute the
+/// number of results.
+///
+/// Expected syntax: `<target-operand> <attr-dict>` where the attribute
+/// dictionary must contain an ArrayAttr named `sizesAttrName`. All operands
+/// and results are typed as !pdl.operation.
+static ParseResult parseTileOp(OpAsmParser &parser, OperationState &result,
+                               StringRef sizesAttrName) {
+  OpAsmParser::UnresolvedOperand targetOperand;
+  SMLoc opLoc;
+  parser.getCurrentLocation(&opLoc);
+  if (parser.parseOperand(targetOperand))
+    return parser.emitError(opLoc, "expected `target` operand");
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  Attribute sizesAttr = result.attributes.get(sizesAttrName);
+  if (!sizesAttr) {
+    return parser.emitError(
+        opLoc, llvm::formatv("expected `{0}` attribute", sizesAttrName));
+  }
+  auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
+  if (!sizesArrayAttr) {
+    return parser.emitError(
+        opLoc,
+        llvm::formatv("`{0}` attribute must be an array", sizesAttrName));
+  }
+  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
+  // One result for the tiled op plus one per non-zero tile size (= number of
+  // generated loops); this must mirror applyTilingToAll's result layout.
+  size_t numExpectedLoops =
+      sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
+  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
+  if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
+    return failure();
+  return success();
+}
+
+namespace {
+/// Minimal concrete PatternRewriter. PatternRewriter's constructor is
+/// protected, so this subclass exists solely to allow constructing a rewriter
+/// for ad-hoc (driver-less) pattern application in the transforms below.
+class SimpleRewriter : public PatternRewriter {
+public:
+  explicit SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// ScalarizeOp
+//===----------------------------------------------------------------------===//
+
+/// Tiles all dynamic dimensions of the payload op by 1 ("scalarizes" them)
+/// and returns the resulting linalg op, or failure if the tiling pattern did
+/// not apply.
+FailureOr<LinalgOp> transform_ext::ScalarizeOp::applyToOne(LinalgOp target) {
+  LinalgTilingOptions tilingOptions;
+  // Tiling with "scalarize_dyn_dims" installs its own tile-size lambda and
+  // asserts that tile sizes were not already set, so no explicit tile sizes
+  // may be provided here.
+  tilingOptions.scalarizeDynamicDims();
+  LinalgTilingPattern pattern(getContext(), tilingOptions);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<TiledLinalgOp> result =
+      pattern.returningMatchAndRewrite(target, rewriter);
+  if (failed(result))
+    return failure();
+  return result->op;
+}
+
+//===---------------------------------------------------------------------===//
+// FuseOp
+//===---------------------------------------------------------------------===//
+
+/// Tiles and fuses each payload linalg op using the `tile_sizes` and
+/// `tile_interchange` attributes, exposing the root tiled op and the
+/// generated tile loops through applyTilingToAll's result layout.
+LogicalResult transform_ext::FuseOp::apply(
+    mlir::transform::TransformResults &transformResults,
+    mlir::transform::TransformState &state) {
+  LinalgTilingAndFusionOptions fusionOptions;
+  fusionOptions.tileSizes = extractI64Array(getTileSizes());
+  fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
+
+  return applyTilingToAll(
+      getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
+      state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
+        LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
+        SimpleRewriter rewriter(getContext());
+        rewriter.setInsertionPoint(linalgOp);
+        FailureOr<TileLoopNest> tileLoopNest =
+            pattern.returningMatchAndRewrite(linalgOp, rewriter);
+        if (failed(tileLoopNest))
+          return failure();
+
+        // Repackage the tile loop nest into the TiledLinalgOp shape expected
+        // by applyTilingToAll.
+        TiledLinalgOp tiledLinalgOp;
+        tiledLinalgOp.op = tileLoopNest->getRootOp();
+        tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
+                               tileLoopNest->getLoopOps().end()};
+        return tiledLinalgOp;
+      });
+}
+
+/// Verifies that `tile_interchange` is a permutation of 0..n-1.
+LogicalResult transform_ext::FuseOp::verify() {
+  SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
+  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
+  if (!std::is_permutation(sequence.begin(), sequence.end(),
+                           permutation.begin(), permutation.end())) {
+    return emitOpError() << "expects interchange to be a permutation, found "
+                         << getTileInterchange();
+  }
+  return success();
+}
+
+/// Custom parser: delegates to parseTileOp keyed on the `tile_sizes`
+/// attribute, which determines the number of loop results.
+ParseResult transform_ext::FuseOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return parseTileOp(parser, result, "tile_sizes");
+}
+
+/// Custom printer: target operand followed by the attribute dictionary
+/// (result types are reconstructed from `tile_sizes` when parsing).
+void transform_ext::FuseOp::print(OpAsmPrinter &p) {
+  p << ' ';
+  p << getTarget();
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+//===---------------------------------------------------------------------===//
+// GeneralizeOp
+//===---------------------------------------------------------------------===//
+
+/// Rewrites the payload op as a linalg.generic. A payload op that is already
+/// a GenericOp is returned unchanged; otherwise returns the generalized op or
+/// failure if the pattern did not apply.
+FailureOr<LinalgOp> transform_ext::GeneralizeOp::applyToOne(LinalgOp target) {
+  // Exit early if no transformation is needed.
+  if (isa<GenericOp>(target))
+    return target;
+
+  LinalgGeneralizationPattern pattern(getContext());
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<GenericOp> result =
+      pattern.returningMatchAndRewrite(target, rewriter);
+  if (failed(result))
+    return failure();
+  return cast<LinalgOp>(result->getOperation());
+}
+
+//===---------------------------------------------------------------------===//
+// InterchangeOp
+//===---------------------------------------------------------------------===//
+
+/// Permutes the iterators of a linalg.generic payload op according to the
+/// `iterator_interchange` attribute. Only GenericOp payloads are supported;
+/// an empty interchange is a no-op that returns the target unchanged.
+FailureOr<LinalgOp> transform_ext::InterchangeOp::applyToOne(LinalgOp target) {
+  SmallVector<unsigned> interchangeVector =
+      extractUIntArray(getIteratorInterchange());
+  // Exit early if no transformation is needed.
+  if (interchangeVector.empty())
+    return target;
+
+  auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
+  if (!genericTarget) {
+    // Report both the transform op and the offending payload op.
+    InFlightDiagnostic diag = emitOpError()
+                              << "applies to " << GenericOp::getOperationName()
+                              << " ops";
+    diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+    return diag;
+  }
+
+  GenericOpInterchangePattern pattern(getContext(), interchangeVector);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<GenericOp> result =
+      pattern.returningMatchAndRewrite(genericTarget, rewriter);
+  if (failed(result))
+    return failure();
+  return cast<LinalgOp>(result->getOperation());
+}
+
+/// Verifies that `iterator_interchange` is a permutation of 0..n-1.
+LogicalResult transform_ext::InterchangeOp::verify() {
+  SmallVector<unsigned> permutation =
+      extractUIntArray(getIteratorInterchange());
+  auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
+  if (!std::is_permutation(sequence.begin(), sequence.end(),
+                           permutation.begin(), permutation.end())) {
+    return emitOpError()
+           << "expects iterator_interchange to be a permutation, found "
+           << getIteratorInterchange();
+  }
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// PadOp
+//===---------------------------------------------------------------------===//
+
+/// Pads the payload op according to the op's attributes: `padding_values`
+/// (one per operand, either a typed attribute or a string parsed against the
+/// operand's element type), `padding_dimensions`, `pack_paddings` (0/1 flags),
+/// `hoist_paddings`, and `transpose_paddings` (array of permutations).
+/// Returns the padded op or failure/diagnostic on malformed values.
+FailureOr<LinalgOp> transform_ext::PadOp::applyToOne(LinalgOp target) {
+  // Convert the integer packing flags to booleans.
+  SmallVector<bool> packPaddings;
+  for (int64_t packPadding : extractI64Array(getPackPaddings()))
+    packPaddings.push_back(static_cast<bool>(packPadding));
+
+  // Convert the padding values to attributes.
+  SmallVector<Attribute> paddingValues;
+  for (auto const &it :
+       llvm::zip(getPaddingValues(), target->getOperandTypes())) {
+    Attribute attr = std::get<0>(it);
+    Type elementType = getElementTypeOrSelf(std::get<1>(it));
+    // Try to parse string attributes to obtain an attribute of element type.
+    if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+      paddingValues.push_back(
+          parseAttribute(attr.cast<StringAttr>(), elementType));
+      if (!paddingValues.back()) {
+        return target->emitOpError("expects a padding value ")
+               << std::get<0>(it) << " that parses to " << elementType;
+      }
+      continue;
+    }
+    // Otherwise, add the attribute directly.
+    if (attr.getType() != elementType) {
+      return target->emitOpError("expects a padding value ")
+             << attr << " of type " << elementType;
+    }
+    paddingValues.push_back(attr);
+  }
+
+  // Extract the transpose vectors.
+  SmallVector<SmallVector<int64_t>> transposePaddings;
+  for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
+    transposePaddings.push_back(
+        extractI64Array(transposeVector.cast<ArrayAttr>()));
+
+  LinalgPaddingOptions paddingOptions;
+  paddingOptions.setPaddingValues(paddingValues);
+  paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
+  paddingOptions.setPackPaddings(packPaddings);
+  paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
+  paddingOptions.setTransposePaddings(transposePaddings);
+
+  LinalgPaddingPattern pattern(getContext(), paddingOptions);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  return pattern.returningMatchAndRewrite(target, rewriter);
+}
+
+/// Verifies the PadOp attributes: `pack_paddings` must be 0/1 flags,
+/// `padding_dimensions` and `hoist_paddings` must be non-negative, and each
+/// entry of `transpose_paddings` must be a permutation of 0..n-1.
+LogicalResult transform_ext::PadOp::verify() {
+  SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
+  if (any_of(packPaddings, [](int64_t packPadding) {
+        return packPadding != 0 && packPadding != 1;
+      })) {
+    return emitOpError()
+           << "expects pack_paddings to contain booleans (0/1), found "
+           << getPackPaddings();
+  }
+  // The checks below accept zero, so the diagnostics say "non-negative"
+  // rather than "positive".
+  SmallVector<int64_t> paddingDimensions =
+      extractI64Array(getPaddingDimensions());
+  if (any_of(paddingDimensions,
+             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
+    return emitOpError() << "expects padding_dimensions to contain "
+                            "non-negative integers, found "
+                         << getPaddingDimensions();
+  }
+  SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
+  if (any_of(hoistPaddings,
+             [](int64_t hoistPadding) { return hoistPadding < 0; })) {
+    return emitOpError() << "expects hoist_paddings to contain non-negative "
+                            "integers, found "
+                         << getHoistPaddings();
+  }
+  ArrayAttr transposes = getTransposePaddings();
+  for (Attribute attr : transposes) {
+    // Use the file-local helper for consistency with applyToOne.
+    SmallVector<int64_t> transpose = extractI64Array(attr.cast<ArrayAttr>());
+    auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
+    if (!std::is_permutation(sequence.begin(), sequence.end(),
+                             transpose.begin(), transpose.end())) {
+      return emitOpError()
+             << "expects transpose_paddings to be a permutation, found "
+             << attr;
+    }
+  }
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// BufferizeOp
+//===---------------------------------------------------------------------===//
+
+/// Applies patterns that make the module amenable to bufferization
+/// (currently only GeneralizePadOpPattern). Best effort: the result of the
+/// greedy driver is deliberately ignored.
+static void applyBufferizationEnablingTransformations(ModuleOp moduleOp) {
+  RewritePatternSet patterns(moduleOp.getContext());
+  patterns.add<GeneralizePadOpPattern>(moduleOp.getContext());
+  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
+/// Runs one-shot module bufferization on the top-level module (function
+/// boundaries included), after applying enabling transformations, and then
+/// hoists redundant vector transfers in every function. Memory copies are
+/// materialized with linalg-based memref copies.
+LogicalResult
+transform_ext::BufferizeOp::apply(mlir::transform::TransformResults &result,
+                                  mlir::transform::TransformState &state) {
+  bufferization::OneShotBufferizationOptions options;
+  options.bufferizeFunctionBoundaries = true;
+  options.memCpyFn = [](OpBuilder &builder, Location loc, Value from,
+                        Value to) {
+    return success(linalg::makeMemRefCopyOp(builder, loc, from, to));
+  };
+
+  // Requires the top-level payload op to be a builtin.module (asserts
+  // otherwise).
+  auto moduleOp = cast<ModuleOp>(state.getTopLevel());
+  applyBufferizationEnablingTransformations(moduleOp);
+  if (failed(runOneShotModuleBufferize(moduleOp, options)))
+    return failure();
+
+  // Perform buffer-level hoistings.
+  state.getTopLevel()->walk(
+      [&](func::FuncOp funcOp) { hoistRedundantVectorTransfers(funcOp); });
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// LowerToLLVMOp
+//===---------------------------------------------------------------------===//
+
+/// Lowers the top-level payload all the way to the LLVM dialect by running a
+/// fixed pass pipeline (vector/linalg/scf/affine/math/memref/func, with
+/// optional async lowering), then marks every LLVM pointer argument noalias.
+/// The op attributes select vector-lowering options and async support.
+LogicalResult
+transform_ext::LowerToLLVMOp::apply(mlir::transform::TransformResults &result,
+                                    mlir::transform::TransformState &state) {
+  // TODO: it is feasible to scope lowering at arbitrary level and introduce
+  // unrealized casts, but there needs to be the final module-wise cleanup in
+  // the end. Keep module-level for now.
+  PassManager pm(getContext());
+
+  pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  if (getEnableAsync()) {
+    pm.addPass(createAsyncToAsyncRuntimePass());
+    pm.addPass(createAsyncRuntimeRefCountingPass());
+    pm.addPass(createAsyncRuntimeRefCountingOptPass());
+  }
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createConvertSCFToCFPass());
+  pm.addPass(createConvertLinalgToLLVMPass());
+  pm.addPass(createConvertVectorToLLVMPass(
+      // clang-format off
+      LowerVectorToLLVMOptions()
+        .enableReassociateFPReductions(getReassociateFpReductions())
+        .enableIndexOptimizations(getEnableIndexOptimizations())
+        .enableArmNeon(getEnableArmNeon())
+        .enableArmSVE(getEnableArmSve())
+        .enableAMX(getEnableAmx())
+        .enableX86Vector(getEnableX86vector())));
+  // clang-format on
+  pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
+  pm.addPass(createMemRefToLLVMPass());
+  if (getEnableAsync())
+    pm.addPass(createConvertAsyncToLLVMPass());
+  pm.addPass(createConvertFuncToLLVMPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
+  if (failed(pm.run(state.getTopLevel())))
+    return failure();
+
+  // Make all arguments noalias for now.
+  // FIXME: this is a terrible hack!
+  state.getTopLevel()->walk([](LLVM::LLVMFuncOp funcOp) {
+    for (int64_t i = 0; i < funcOp.getNumArguments(); ++i) {
+      if (!funcOp.getFunctionType()
+               .getParamType(i)
+               .isa<LLVM::LLVMPointerType>())
+        continue;
+      funcOp.setArgAttr(i, "llvm.noalias", UnitAttr::get(funcOp.getContext()));
+    }
+  });
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// DecomposeOp
+//===---------------------------------------------------------------------===//
+
+/// Applies convolution-decomposition patterns greedily over the whole
+/// top-level payload; produces no result handles.
+LogicalResult
+transform_ext::DecomposeOp::apply(mlir::transform::TransformResults &results,
+                                  mlir::transform::TransformState &state) {
+  RewritePatternSet patterns(getContext());
+  // TODO: make this targetable.
+  populateDecomposeConvolutionPatterns(patterns, LinalgTransformationFilter());
+  if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(),
+                                          std::move(patterns))))
+    return failure();
+
+  // TODO: make this chainable, it isn't in the original codegenstrategy.
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// VectorizeOp
+//===---------------------------------------------------------------------===//
+
+/// Populates `patterns` with the canonicalization/cleanup patterns that
+/// accompany vectorization (transfer permutation lowering, reduction-to-
+/// contract, copy forwarding, transfer canonicalizations, and — when the op
+/// requests it — pad-op vectorization).
+static void
+configureVectorizationPatterns(transform_ext::VectorizeOp vectorizeOp,
+                               RewritePatternSet &patterns) {
+  MLIRContext *ctx = vectorizeOp->getContext();
+  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+  vector::populateVectorReductionToContractPatterns(patterns);
+  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
+               linalg::LinalgCopyVTWForwardingPattern>(ctx,
+                                                       /*benefit=*/2);
+  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
+  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+  if (vectorizeOp.getVectorizePadding())
+    linalg::populatePadOpVectorizationPatterns(patterns);
+}
+
+/// Applies the transformation specified by the given vectorize operation to
+/// the given target operation AND some related operations. Populates
+/// `results` with transformation operations for further transformations if
+/// the pattern applied successfully (currently, the main "contraction" op
+/// after vectorization).
+static FailureOr<LinalgOp>
+executeTargetedVectorizeOp(LinalgOp target,
+                           transform_ext::VectorizeOp vectorizeOp) {
+  // TODO: this is copy-pasta from LinalgStrategyVectorizePass, it shouldn't be.
+  MLIRContext *ctx = target->getContext();
+  RewritePatternSet patterns(ctx);
+  configureVectorizationPatterns(vectorizeOp, patterns);
+  LinalgVectorizationPattern pattern(target.getContext());
+  auto functionalVectorize = [&](LinalgOp op, PatternRewriter &rewriter) {
+    return pattern.matchAndRewrite(op, rewriter);
+  };
+
+  /// Apply the transformations in a scope.
+  return linalg::transform::scoped(
+      target,
+      [&](linalg::transform::ScopeOp scope,
+          Operation *op) -> FailureOr<LinalgOp> {
+        if (failed(functional::applyAt(op, functionalVectorize)) ||
+            failed(applyPatternsAndFoldGreedily(scope, std::move(patterns))))
+          return failure();
+        // FIXME: Vectorization doesn't return anything.
+        return LinalgOp();
+      });
+
+  // TODO: vectorization may fail because the op is not vectorizable, unclear
+  // what to do here. We should probably report it somehow, but we may also
+  // want to go on and keep the original for continuation. Should we have
+  // some notion of transformation optionality vs. mandatory (like lowering)?
+  // How to find ops that were not replaced?
+}
+
+/// Two modes of operation: with a target handle, vectorizes each targeted
+/// linalg op in its own scope and maps the (currently null) results to
+/// result 0; without a target, applies vectorization patterns greedily to the
+/// whole top-level payload while tracking op replacements through the
+/// TrackingListener state extension.
+LogicalResult
+transform_ext::VectorizeOp::apply(mlir::transform::TransformResults &results,
+                                  mlir::transform::TransformState &state) {
+  if (getTarget()) {
+    SmallVector<Operation *> resultVector;
+    LogicalResult res = mlir::transform::detail::applyTransformToEach(
+        state.getPayloadOps(getTarget()), resultVector, [&](LinalgOp target) {
+          return executeTargetedVectorizeOp(target, *this);
+        });
+
+    if (failed(res))
+      return emitError() << "failed to apply";
+
+    results.set(getResult(0).cast<OpResult>(), resultVector);
+    return success();
+  }
+
+  MLIRContext *ctx = getContext();
+  RewritePatternSet patterns(ctx);
+  patterns.add<LinalgVectorizationPattern>(ctx);
+  configureVectorizationPatterns(*this, patterns);
+  auto *listener = state.getExtension<::detail::TrackingListener>();
+  if (!listener)
+    return emitError() << "expected TrackingListener extension to be available";
+  LogicalResult applicationResult = applyPatternsTrackAndFoldGreedily(
+      state.getTopLevel(), *listener, std::move(patterns));
+  // Surface errors recorded by the listener even if pattern application
+  // itself succeeded.
+  LogicalResult listenerResult = listener->checkErrorState();
+  if (failed(applicationResult) || failed(listenerResult))
+    return emitError() << "failed to apply";
+  return success();
+}
+
+/// Custom parser: the target operand is optional; when present, the op also
+/// gets one !pdl.operation result. The attribute dictionary is parsed in
+/// either case.
+ParseResult transform_ext::VectorizeOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  auto operationType = pdl::OperationType::get(parser.getContext());
+  OpAsmParser::UnresolvedOperand target;
+  OptionalParseResult parseResult = parser.parseOptionalOperand(target);
+  if (parseResult.hasValue()) {
+    if (parseResult.getValue().failed() ||
+        parser.parseOptionalAttrDict(result.attributes) ||
+        parser.resolveOperand(target, operationType, result.operands) ||
+        parser.addTypeToList(operationType, result.types)) {
+      return failure();
+    }
+  } else {
+    if (parser.parseOptionalAttrDict(result.attributes)) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+/// Custom printer mirroring the parser: optional target then attributes.
+void transform_ext::VectorizeOp::print(OpAsmPrinter &printer) {
+  if (getTarget())
+    printer << " " << getTarget() << " ";
+  printer.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+//===---------------------------------------------------------------------===//
+// LowerVectorsOp
+//===---------------------------------------------------------------------===//
+
+/// Returns true if the numbered vector lowering stage is included into the
+/// list of stages specified on the given lowerVectors operation.
+static bool stageIncluded(int stage,
+                          transform_ext::LowerVectorsOp lowerVectorsOp) {
+  for (auto s : lowerVectorsOp.getStages().getAsValueRange<IntegerAttr>()) {
+    if (s.getSExtValue() == stage)
+      return true;
+  }
+  return false;
+}
+
+/// Applies the transformation specified by the given lower vectors operation
+/// to the given function. Translates the op's string attributes into vector
+/// lowering options, assembles the pattern sets for the requested stages
+/// (1: contraction, 2: multi-reduction, 3: transfer split, 4: transfer,
+/// 5: to-SCF, 6: shape_cast, 7: transpose), and applies them greedily over
+/// the whole top-level payload.
+LogicalResult
+transform_ext::LowerVectorsOp::apply(mlir::transform::TransformResults &results,
+                                     mlir::transform::TransformState &state) {
+  MLIRContext *ctx = getContext();
+  RewritePatternSet patterns(ctx);
+
+  // Unrecognized option strings silently fall back to the .Default cases.
+  vector::VectorTransposeLowering vectorTransposeLowering =
+      llvm::StringSwitch<vector::VectorTransposeLowering>(
+          getTransposeLowering())
+          .Case("eltwise", vector::VectorTransposeLowering::EltWise)
+          .Case("flat_transpose", vector::VectorTransposeLowering::Flat)
+          .Case("shuffle", vector::VectorTransposeLowering::Shuffle)
+          .Default(vector::VectorTransposeLowering::EltWise);
+  vector::VectorMultiReductionLowering vectorMultiReductionLowering =
+      llvm::StringSwitch<vector::VectorMultiReductionLowering>(
+          getMultireductionLowering())
+          .Case("innerreduction",
+                vector::VectorMultiReductionLowering::InnerReduction)
+          .Default(vector::VectorMultiReductionLowering::InnerParallel);
+  vector::VectorContractLowering vectorContractLowering =
+      llvm::StringSwitch<vector::VectorContractLowering>(
+          getContractionLowering())
+          .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
+          .Case("dot", vector::VectorContractLowering::Dot)
+          .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
+          .Default(vector::VectorContractLowering::OuterProduct);
+  // TODO: fix the annoying name mismatch (vector-transfers vs vector-transfer).
+  vector::VectorTransferSplit vectorTransferSplit =
+      llvm::StringSwitch<vector::VectorTransferSplit>(getSplitTransfers())
+          .Case("none", vector::VectorTransferSplit::None)
+          .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
+          .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
+          .Default(vector::VectorTransferSplit::None);
+
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering)
+      .setVectorMultiReductionLowering(vectorMultiReductionLowering)
+      .setVectorTransposeLowering(vectorTransposeLowering)
+      .setVectorTransferSplit(vectorTransferSplit);
+
+  VectorTransferToSCFOptions vectorTransferToSCFOptions =
+      VectorTransferToSCFOptions()
+          .enableFullUnroll(getUnrollVectorTransfers())
+          .enableLowerPermutationMaps();
+
+  int maxTransferRank = 1;
+
+  auto avx2LoweringOptions =
+      x86vector::avx2::LoweringOptions().setTransposeOptions(
+          x86vector::avx2::TransposeLoweringOptions()
+              .lower4x8xf32(getTransposeAvx2Lowering())
+              .lower8x8xf32(getTransposeAvx2Lowering()));
+
+  // TODO: this is copy-pasta from LinalgStrategyLowerVectorsPass, shouldn't be.
+  vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+  if (stageIncluded(1, *this)) {
+    patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
+                 mlir::vector::ContractionOpToMatmulOpLowering,
+                 mlir::vector::ContractionOpLowering>(vectorTransformOptions,
+                                                      ctx);
+    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+  }
+  if (stageIncluded(2, *this)) {
+    vector::populateVectorMultiReductionLoweringPatterns(
+        patterns, vectorTransformOptions.vectorMultiReductionLowering);
+  }
+  if (stageIncluded(3, *this)) {
+    patterns.add<vector::VectorTransferFullPartialRewriter>(
+        ctx, vectorTransformOptions);
+  }
+  if (stageIncluded(4, *this)) {
+    vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
+  }
+  if (stageIncluded(5, *this)) {
+    populateVectorToSCFConversionPatterns(
+        patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank));
+  }
+  if (stageIncluded(6, *this)) {
+    vector::populateVectorShapeCastLoweringPatterns(patterns);
+  }
+  if (stageIncluded(7, (*this))) {
+    vector::populateVectorTransposeLoweringPatterns(patterns,
+                                                    vectorTransformOptions);
+    if (getTransposeAvx2Lowering())
+      x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+          patterns, avx2LoweringOptions, /*benefit=*/10);
+  }
+
+  // TODO: these transformations are currently not targeted at concrete ops.
+  // LinalgTransformationFilter filter = makeTransformationFilter(target);
+  if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(),
+                                          std::move(patterns))))
+    return failure();
+
+  // TODO: make composable...
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// GetParentLoopOp
+//===---------------------------------------------------------------------===//
+
+/// Walks up `num_loops` enclosing scf.for loops from the payload op and
+/// returns the outermost one reached; emits an error and fails if fewer
+/// enclosing loops exist.
+FailureOr<scf::ForOp>
+transform_ext::GetParentLoopOp::applyToOne(Operation *source) {
+  int64_t nLoops = getNumLoops();
+  for (int64_t i = 0; i < nLoops; ++i) {
+    source = source->getParentOfType<scf::ForOp>();
+    if (!source) {
+      emitError() << "the transformed op is enclosed by " << i << " loops, but "
+                  << nLoops << " expected";
+      return failure();
+    }
+  }
+  return cast<scf::ForOp>(source);
+}
+
+/// Declares side effects: reads the target handle, allocates and writes the
+/// result handle, and reads (but does not mutate) the payload IR.
+void transform_ext::GetParentLoopOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       mlir::transform::TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Allocate::get(), getTransformed(),
+                       mlir::transform::TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), getTransformed(),
+                       mlir::transform::TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Read::get(),
+                       mlir::transform::PayloadIRResource::get());
+}
+
+//===---------------------------------------------------------------------===//
+// UnrollLoopOp
+//===---------------------------------------------------------------------===//
+
+/// Unrolls the payload loop in place by the `factor` attribute; no result
+/// handle is produced.
+LogicalResult transform_ext::UnrollLoopOp::applyToOne(scf::ForOp loop) {
+  return loopUnrollByFactor(loop, getFactor());
+}
+
+//===---------------------------------------------------------------------===//
+// PeelLoopOp
+//===---------------------------------------------------------------------===//
+
+/// Peels the last (partial) iteration off the payload loop. Deliberately
+/// best-effort: when peeling does not apply, the original loop is returned
+/// unchanged instead of failing the transform.
+FailureOr<scf::ForOp> transform_ext::PeelLoopOp::applyToOne(scf::ForOp loop) {
+  scf::ForOp result;
+  IRRewriter rewriter(loop->getContext());
+  LogicalResult status =
+      scf::peelAndCanonicalizeForLoop(rewriter, loop, result);
+  if (failed(status))
+    return loop;
+  return result;
+}
+
+//===---------------------------------------------------------------------===//
+// PipelineLoopOp
+//===---------------------------------------------------------------------===//
+
+/// Builds a software-pipelining schedule for `forOp`: each body op (except
+/// the terminator) is assigned an as-soon-as-possible cycle based on operand
+/// latencies (vector.transfer_read costs `readLatency`, everything else 1),
+/// then `(op, cycle / iterationInterval)` pairs are appended to `schedule`
+/// grouped by `cycle % iterationInterval` (std::map keeps the groups in
+/// deterministic ascending order).
+static void
+loopScheduling(scf::ForOp forOp,
+               std::vector<std::pair<Operation *, unsigned>> &schedule,
+               unsigned iterationInterval, unsigned readLatency) {
+  auto getLatency = [&](Operation *op) {
+    if (isa<vector::TransferReadOp>(op))
+      return readLatency;
+    return unsigned(1);
+  };
+
+  DenseMap<Operation *, unsigned> opCycles;
+  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
+  for (Operation &op : forOp.getBody()->getOperations()) {
+    if (isa<scf::YieldOp>(op))
+      continue;
+    unsigned earlyCycle = 0;
+    for (Value operand : op.getOperands()) {
+      Operation *def = operand.getDefiningOp();
+      if (!def)
+        continue;
+      earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
+    }
+    opCycles[&op] = earlyCycle;
+    wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
+  }
+  for (auto it : wrappedSchedule) {
+    for (Operation *op : it.second) {
+      unsigned cycle = opCycles[op];
+      schedule.push_back(std::make_pair(op, cycle / iterationInterval));
+    }
+  }
+}
+
+/// Software-pipelines the payload loop using the schedule computed by
+/// loopScheduling. Fails if the op's result is used (the pipelining pattern
+/// cannot yet report the transformed loop) and always returns a null ForOp
+/// handle for the same reason.
+FailureOr<scf::ForOp>
+transform_ext::PipelineLoopOp::applyToOne(scf::ForOp loop) {
+  // TODO: make the pipelining pattern return the transformed loop.
+  if (!getOperation()->getUses().empty()) {
+    InFlightDiagnostic diag = emitError()
+                              << "NYI: cannot target the result of pipelining";
+    diag.attachNote(getOperation()->use_begin()->getOwner()->getLoc())
+        << "use here";
+    return failure();
+  }
+
+  scf::PipeliningOption schedule;
+  schedule.getScheduleFn =
+      [this](scf::ForOp forOp,
+             std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
+        loopScheduling(forOp, schedule, getIterationInterval(),
+                       getReadLatency());
+      };
+
+  RewritePatternSet patterns(loop->getContext());
+  scf::populateSCFLoopPipeliningPatterns(patterns, schedule);
+  // The populate function is expected to register exactly one pattern, which
+  // is then applied directly to the targeted loop (no greedy driver).
+  assert(patterns.getNativePatterns().size() == 1 &&
+         "expected one pipelining pattern");
+
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(loop);
+  RewritePattern *pattern = patterns.getNativePatterns().front().get();
+  if (failed(pattern->matchAndRewrite(loop, rewriter)))
+    return failure();
+
+  return scf::ForOp();
+}
+
+//===---------------------------------------------------------------------===//
+// OutlineLoopOp
+//===---------------------------------------------------------------------===//
+
+/// Wraps the single region of `op` into a newly created scf.execute_region,
+/// replacing `op` with the execute_region's results. Returns nullptr if `op`
+/// does not have exactly one region.
+static scf::ExecuteRegionOp outlineInExecuteRegion(RewriterBase &b,
+                                                   Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  scf::ExecuteRegionOp executeRegionOp =
+      b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
+  {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
+    // Clone the op shell, move its original region into the clone, and yield
+    // the clone's results from the execute_region.
+    Operation *clonedOp = b.cloneWithoutRegions(*op);
+    Region &clonedRegion = clonedOp->getRegions().front();
+    assert(clonedRegion.empty() && "expected empty region");
+    b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
+                         clonedRegion.end());
+    b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
+  }
+  b.replaceOp(op, executeRegionOp.getResults());
+  return executeRegionOp;
+}
+
+/// Outlines `loop` into a new function named `funcName` by first wrapping it
+/// in an scf.execute_region and then extracting that single-block region.
+/// Rewrites are routed through the TrackingListener extension so payload
+/// handle mappings stay valid; fails if the extension is absent or flagged an
+/// error.
+static FailureOr<func::FuncOp>
+outlineLoop(scf::ForOp loop, StringRef funcName,
+            mlir::transform::TransformState &state, Location errorLoc) {
+  PatternRewriterListener rewriter(loop->getContext());
+  auto *listener = state.getExtension<::detail::TrackingListener>();
+  if (!listener) {
+    return emitError(errorLoc)
+           << "expected TrackingListener extension to be present";
+  }
+  rewriter.addListener(listener);
+  Location loc = loop.getLoc();
+  scf::ExecuteRegionOp exec = outlineInExecuteRegion(rewriter, loop);
+  assert(exec && "failed to produce execute_region");
+  FailureOr<func::FuncOp> outlined =
+      outlineSingleBlockRegion(rewriter, loc, exec.getRegion(), funcName);
+  if (failed(listener->checkErrorState()))
+    return failure();
+  return outlined;
+}
+
+/// Outlines every payload scf.for loop into a function named `func_name` and
+/// maps the created functions to this op's result handle.
+LogicalResult
+transform_ext::OutlineLoopOp::apply(mlir::transform::TransformResults &results,
+                                    mlir::transform::TransformState &state) {
+  SmallVector<Operation *> resultVector;
+  auto res = mlir::transform::detail::applyTransformToEach(
+      state.getPayloadOps(getTarget()), resultVector,
+      [&](scf::ForOp loop) -> FailureOr<func::FuncOp> {
+        return outlineLoop(loop, getFuncName(), state, getLoc());
+      });
+  if (failed(res))
+    return emitError() << "failed to apply";
+  results.set(getResult().cast<OpResult>(), resultVector);
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// PrintOp
+//===---------------------------------------------------------------------===//
+
+/// Debug-prints IR to stdout: the whole top-level payload when no target is
+/// given, otherwise the first payload op mapped to the target handle.
+LogicalResult
+transform_ext::PrintOp::apply(mlir::transform::TransformResults &results,
+                              mlir::transform::TransformState &state) {
+  if (!getTarget()) {
+    llvm::outs() << "[[[ IR printer: " << getName() << " top-level ]]]\n";
+    state.getTopLevel()->dump();
+    return success();
+  }
+
+  llvm::outs() << "[[[ IR printer: " << getName() << " single op ]]]\n";
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  // Guard against an empty payload mapping: calling front() on an empty
+  // ArrayRef is undefined behavior. Printing nothing is the sane fallback.
+  if (targets.empty())
+    return success();
+  targets.front()->dump();
+  return success();
+}
+
+/// Declares side effects: reads the (optional) target handle and the payload
+/// IR; the extra default-resource write models printing to stdout so the op
+/// is never dead-code-eliminated.
+void transform_ext::PrintOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       mlir::transform::TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Read::get(),
+                       mlir::transform::PayloadIRResource::get());
+
+  // There is no resource for stdout file descriptor, so just declare print
+  // writes into the default resource.
+  effects.emplace_back(MemoryEffects::Write::get());
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/CMakeLists.txt
index 1680b7c..73d2bb3 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_library(IREELinalgTransformDialectTransforms
ExpertExpansion.cpp
- TrackingCSE.cpp
TransformInterpreter.cpp
DEPENDS
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TrackingCSE.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TrackingCSE.cpp
deleted file mode 100644
index f4b3c59..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TrackingCSE.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/LinalgTransform/TrackingCSE.h"
-
-#include "iree-dialects/Transforms/ListenerCSE.h"
-
-using namespace mlir;
-
-LogicalResult mlir::eliminateCommonSubexpressionsWithTrackedOps(
- Operation *root, RewriteListener &listener, DominanceInfo *domInfo) {
- return eliminateCommonSubexpressions(root, domInfo, &listener);
-}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TransformInterpreter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TransformInterpreter.cpp
index e42ec36..930f92d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TransformInterpreter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TransformInterpreter.cpp
@@ -7,46 +7,35 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
-#include "iree-dialects/Dialect/LinalgTransform/TrackingCSE.h"
-#include "iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h"
-#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
-#include "iree-dialects/Dialect/LinalgTransform/TransformOpMapping.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/SCF/Transforms.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
-#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopeExit.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
+#include <mlir/Pass/PassRegistry.h>
#define DEBUG_TYPE "transform-interpreter"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
-using namespace mlir::linalg;
static llvm::cl::opt<std::string> clTransformFileName(
"linalg-transform-file-name",
@@ -54,190 +43,67 @@
"the transformations to apply."),
llvm::cl::init(""));
-//===----------------------------------------------------------------------===//
-// Linalg Interpreter Driver
-//===----------------------------------------------------------------------===//
-
-/// Run enabling transformations (LICM and its variants, single-iteration loop
-/// removal, CSE) on the given function.
-static LogicalResult performEnablerTransformations(
- func::FuncOp func, RewriteListener &listener,
- linalg::LinalgEnablingOptions options = linalg::LinalgEnablingOptions()) {
- MLIRContext *ctx = func->getContext();
- RewritePatternSet patterns(ctx);
- linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
- scf::populateSCFForLoopCanonicalizationPatterns(patterns);
- if (failed(applyPatternsTrackAndFoldGreedily(func, listener,
- std::move(patterns))))
- return failure();
-
- // This assumes LICM never removes operations so we don't need tracking.
- if (options.licm) {
- func->walk(
- [](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
- }
-
- func.walk([](Operation *op) {
- (void)llvm::TypeSwitch<Operation *, LogicalResult>(op)
- .Case<AffineForOp, scf::ForOp>(
- [](auto loop) { return promoteIfSingleIteration(loop); })
- .Default([](Operation *) { return success(); });
- });
-
- if (options.hoistRedundantVectorTransfers)
- hoistRedundantVectorTransfers(func);
- if (options.hoistRedundantVectorTransfersOnTensor)
- hoistRedundantVectorTransfersOnTensor(func);
-
- return eliminateCommonSubexpressionsWithTrackedOps(func, listener);
-}
-
-/// Run enabling transformations on the given `containerOp` while preserving the
-/// operation tracking information.
-static LogicalResult performEnablerTransformations(
- Operation *containerOp, RewriteListener &listener,
- linalg::LinalgEnablingOptions options = linalg::LinalgEnablingOptions()) {
- auto res = containerOp->walk([&](func::FuncOp func) {
- if (failed(performEnablerTransformations(func, listener, options)))
- return WalkResult::interrupt();
- return WalkResult::advance();
- });
- return failure(res.wasInterrupted());
-}
-
-static LogicalResult executeTransform(Operation *operation,
- transform::TransformState &state) {
- auto iface = dyn_cast<transform::TransformOpInterface>(operation);
- if (!iface)
- return operation->emitError() << "unknown transformation operation";
-
- return state.applyTransform(iface);
-}
-
-/// Perform the transformation specified by the callback and unconditionally
-/// check the error state of the listener. Return failure if either failed.
-static LogicalResult checkedListenerTransform(
- function_ref<LogicalResult(TrackingListener &)> transform,
- TrackingListener &listener) {
- // Make sure we check the listener error state regardless of the transform
- // result.
- LogicalResult transformResult = transform(listener);
- LogicalResult listenerResult = listener.checkErrorState();
- return failure(failed(transformResult) || failed(listenerResult));
-}
-
-/// Applies the transformations listed in the `sequence` to operations starting
-/// from `target`. The following transformations may be applied to operations
-/// produced by previous transformations as indicated by SSA value flow in the
-/// Linalg Transform dialect.
-static LogicalResult executeSequence(linalg::transform::SequenceOp sequence,
- Operation *containerOp) {
- MLIRContext *ctx = containerOp->getContext();
- RewritePatternSet patternList(ctx);
- for (Dialect *dialect : ctx->getLoadedDialects())
- dialect->getCanonicalizationPatterns(patternList);
- for (RegisteredOperationName op : ctx->getRegisteredOperations())
- op.getCanonicalizationPatterns(patternList, ctx);
- FrozenRewritePatternSet patterns(std::move(patternList));
-
- transform::TransformState state(containerOp);
- TrackingListener &listener = state.addExtension<TrackingListener>();
-
- // Run the canonicalizations upfront so we don't match and transform
- // operations only to drop them later.
- if (failed(checkedListenerTransform(
- [&](TrackingListener &listener) {
- return eliminateCommonSubexpressionsWithTrackedOps(containerOp,
- listener);
- },
- listener))) {
- LLVM_DEBUG(DBGS() << "failed to perform CSE\n");
- return failure();
- }
- if (failed(checkedListenerTransform(
- [&](TrackingListener &listener) {
- return applyPatternsTrackAndFoldGreedily(containerOp, listener,
- patterns);
- },
- listener))) {
- LLVM_DEBUG(DBGS() << "failed to apply canonicalization patterns\n");
- return failure();
- }
-
- for (Operation &transform : sequence.body().front()) {
- if (failed(executeTransform(&transform, state))) {
- std::string str;
- llvm::raw_string_ostream ss(str);
- ss << "failed to apply: " << transform << "\nto\n" << *containerOp;
- ss.flush();
- return transform.emitError() << str;
- }
-
- LLVM_DEBUG(DBGS() << "successfully applied transform: " << transform
- << "\n");
-
- // Run CSE, enabling transformations and canonicalization. This is similar
- // to running the respective pass, but (a) keeps tracking the value/op
- // mapping and (b) avoids constructing the pattern set + pass pipeline on
- // every step.
- // TODO: consider better targeting than module-level transformations here:
- // e.g., the enabler internals can apply to one function only. Furthermore,
- // we don't need all of enabler transformations after/before all passes.
- if (failed(checkedListenerTransform(
- [&](TrackingListener &listener) {
- return eliminateCommonSubexpressionsWithTrackedOps(containerOp,
- listener);
- },
- listener))) {
- LLVM_DEBUG(DBGS() << "failed to perform CSE\n");
- return failure();
- }
-
- // TODO: this runs CSE internally, mostly redundant with the above.
- if (failed(checkedListenerTransform(
- [&](TrackingListener &listener) {
- return performEnablerTransformations(containerOp, listener);
- },
- listener))) {
- LLVM_DEBUG(DBGS() << "enabler transformations failed\n");
- return failure();
- }
-
- if (failed(checkedListenerTransform(
- [&](TrackingListener &listener) {
- return applyPatternsTrackAndFoldGreedily(containerOp, listener,
- patterns);
- },
- listener))) {
- LLVM_DEBUG(DBGS() << "failed to apply canonicalization patterns\n");
- return failure();
- }
- }
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Linalg Interpreter Pass
-//===----------------------------------------------------------------------===//
-
namespace {
-/// Pass that executes transformations specified by a module-level
-/// iree_linalg_transform.apply operation on the same module.
-struct InterpreterPass : public PassWrapper<InterpreterPass, Pass> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InterpreterPass)
+/// Simple pass that applies transform dialect ops directly contained in a
+/// module.
+class LinalgTransformInterp : public PassWrapper<LinalgTransformInterp, Pass> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformInterp)
- StringRef getArgument() const final { return "linalg-interp-transforms"; }
+ StringRef getArgument() const override { return "linalg-transform-interp"; }
- StringRef getDescription() const final {
- return "Executes transformations specified in Linalg Transform dialect";
+ StringRef getDescription() const override {
+ return "apply transform dialect operations one by one";
}
- bool canScheduleOn(RegisteredOperationName opName) const override {
+ bool canScheduleOn(RegisteredOperationName name) const override {
return true;
}
+ void runOnOperation() override {
+ Operation *topLevel = getOperation();
+ if (topLevel->getNumRegions() != 1 ||
+ !llvm::hasSingleElement(topLevel->getRegion(0))) {
+ topLevel->emitError() << "can only run '" << getArgument()
+ << "' on single-region single-block operations";
+ return signalPassFailure();
+ }
+
+ transform::TransformState state(topLevel->getRegion(0), topLevel);
+
+ if (clTransformFileName.empty()) {
+ Block &body = topLevel->getRegion(0).front();
+ for (auto op : body.getOps<transform::TransformOpInterface>()) {
+ if (failed(state.applyTransform(op)))
+ return signalPassFailure();
+ }
+ return;
+ }
+
+ // If a transform file is specified, parse its content into a ModuleOp.
+ std::string errorMessage;
+ auto memoryBuffer = openInputFile(clTransformFileName, &errorMessage);
+ if (!memoryBuffer) {
+ llvm::errs() << errorMessage << "\n";
+ return signalPassFailure();
+ }
+ // Tell sourceMgr about this buffer, the parser will pick it up.
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+ OwningOpRef<ModuleOp> transformModule(
+ parseSourceFile<ModuleOp>(sourceMgr, &getContext()));
+ for (auto op : transformModule->getBody()
+ ->getOps<transform::TransformOpInterface>()) {
+ if (failed(state.applyTransform(op)))
+ return signalPassFailure();
+ }
+ }
+
void getDependentDialects(DialectRegistry ®istry) const override {
+    // TODO: this is only necessary to make the registry-subset check happy
+    // when running the lowering to LLVM. The lowering should be changed to
+    // stop using the nested pass manager, after which this will go away.
+
// clang-format off
registry.insert<mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
arith::ArithmeticDialect,
@@ -255,6 +121,8 @@
// clang-format on
>();
+ // TODO: these should be registered by the extension instead, but there is
+ // no support for it in core currently.
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
@@ -263,49 +131,6 @@
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
}
-
- void runTransformModuleOnOperation(ModuleOp module, Operation *op) {
- if (!module)
- return signalPassFailure();
-
- auto result = module->walk([&](linalg::transform::SequenceOp sequenceOp) {
- if (failed(executeSequence(sequenceOp, op)))
- return WalkResult::interrupt();
- return WalkResult::advance();
- });
- if (result.wasInterrupted())
- signalPassFailure();
- }
-
- void runOnOperation() override {
- // If no transform file is specified, assume the transforms live in the
- // same module as the IR. The considered ModuleOp is either `getOperation()`
- // if it is already a ModuleOp, or the first parent ModuleOp.
- if (clTransformFileName.empty()) {
- LLVM_DEBUG(DBGS() << getArgument()
- << " with transform embedded in module\n");
- ModuleOp module = dyn_cast<ModuleOp>(getOperation());
- if (!module)
- module = getOperation()->getParentOfType<ModuleOp>();
- return runTransformModuleOnOperation(module, getOperation());
- }
-
- LLVM_DEBUG(DBGS() << getArgument() << " with transform "
- << clTransformFileName << "\n");
- // If a transform file is specified, parse its content into a ModuleOp.
- std::string errorMessage;
- auto memoryBuffer = openInputFile(clTransformFileName, &errorMessage);
- if (!memoryBuffer) {
- llvm::errs() << errorMessage << "\n";
- return signalPassFailure();
- }
- // Tell sourceMgr about this buffer, the parser will pick it up.
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
- OwningOpRef<ModuleOp> module(
- parseSourceFile<ModuleOp>(sourceMgr, &getContext()));
- runTransformModuleOnOperation(module.get(), getOperation());
- }
};
struct DropSchedulePass : public PassWrapper<DropSchedulePass, Pass> {
@@ -322,32 +147,33 @@
}
void runOnOperation() override {
- getOperation()->walk([&](Operation *nestedOp) {
- if (isa<linalg::transform::SequenceOp>(nestedOp) ||
- isa<pdl::PatternOp>(nestedOp))
+ getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
+ if (isa<::mlir::transform::TransformOpInterface>(nestedOp)) {
nestedOp->erase();
+ return WalkResult::skip();
+ }
+ return WalkResult::advance();
});
}
};
} // namespace
-namespace mlir {
/// Create a Linalg Transform interpreter pass.
-std::unique_ptr<Pass> createLinalgTransformInterpreterPass() {
- return std::make_unique<InterpreterPass>();
+std::unique_ptr<Pass> mlir::createLinalgTransformInterpreterPass() {
+ return std::make_unique<LinalgTransformInterp>();
}
-/// Create a Linalg pass to drop the schedule from the module.
-std::unique_ptr<Pass> createDropSchedulePass() {
- return std::make_unique<DropSchedulePass>();
-}
-} // namespace mlir
-/// Registration hook for the Linalg Transform interpreter pass.
-void mlir::linalg::transform::registerLinalgTransformInterpreterPass() {
- PassRegistration<InterpreterPass>();
+/// Create a Linalg pass to drop the schedule from the module.
+std::unique_ptr<Pass> mlir::createDropSchedulePass() {
+ return std::make_unique<DropSchedulePass>();
}
/// Registration hook for the Linalg drop schedule from module pass.
void mlir::linalg::transform::registerDropSchedulePass() {
PassRegistration<DropSchedulePass>();
}
+
+/// Registration hook for the Linalg Transform interpreter pass.
+void mlir::linalg::transform::registerLinalgTransformInterpreterPass() {
+ PassRegistration<LinalgTransformInterp>();
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
index af4eb7a..c36ac9f 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/bufferize-in-parallel.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms -canonicalize | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp -canonicalize | FileCheck %s
// CHECK-LABEL: func @parallel_insert_slice_no_conflict(
// CHECK-SAME: %[[idx:.*]]: index, %[[idx2:.*]]: index,
@@ -35,10 +35,6 @@
return %2, %f : tensor<?xf32>, f32
}
-// -----
-
-module {
-
// CHECK-LABEL: func @parallel_insert_slice_with_conflict(
// CHECK-SAME: %[[idx:.*]]: index, %[[idx2:.*]]: index,
// CHECK-SAME: %[[arg1:.*]]: memref<?xf32, #{{.*}}>,
@@ -90,14 +86,16 @@
return %f2, %f : f32, f32
}
-pdl.pattern @pdl_target_2 : benefit(1) {
- %0 = operation "func"
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target_2 : benefit(1) {
+ %0 = operation "func"
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- bufferize
-}
-
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ bufferize
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-in-containing-op.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-in-containing-op.mlir
index 925a307..47f8b28 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-in-containing-op.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-in-containing-op.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms -split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp -split-input-file | FileCheck %s
#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
@@ -35,22 +35,26 @@
func.return %2 : tensor<64xf32>
}
- pdl.pattern @match_fill : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.fill"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- pdl.pattern @match_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_fill
- %1 = match @match_in_parallel
- fuse_into_containing_op %0 into %1
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_fill : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.fill"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ pdl.pattern @match_in_parallel : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_fill in %arg1
+ %1 = pdl_match @match_in_parallel in %arg1
+ fuse_into_containing_op %0 into %1
+ }
}
}
@@ -94,21 +98,25 @@
func.return %2 : tensor<?xf32>
}
- pdl.pattern @match_fill : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.fill"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- pdl.pattern @match_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_fill
- %1 = match @match_in_parallel
- fuse_into_containing_op %0 into %1
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_fill : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.fill"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ pdl.pattern @match_in_parallel : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_fill in %arg1
+ %1 = pdl_match @match_in_parallel in %arg1
+ fuse_into_containing_op %0 into %1
+ }
}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
index 3682241..d4e3b82 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/fuse-operands.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms -split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp -split-input-file | FileCheck %s
#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
@@ -30,7 +30,7 @@
// CHECK: %[[T3:.*]] = linalg.fill {{.*}} outs(%[[T2]]
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<64xf32> to tensor<?xf32>
- // CHECK: %[[T4:.*]] = linalg.elemwise_unary{{.*}}ins(%[[T1]] {{.*}} outs(%[[T3]]
+ // CHECK: %[[T4:.*]] = linalg.elemwise_unary ins(%[[T1]] {{.*}} outs(%[[T3]]
%8 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%7 : tensor<?xf32>) -> tensor<?xf32>
iree_linalg_ext.perform_concurrently {
iree_linalg_ext.parallel_insert_slice %8 into %arg2[%4] [%5] [1] : tensor<?xf32> into tensor<64xf32>
@@ -39,21 +39,25 @@
func.return %3 : tensor<64xf32>
}
- pdl.pattern @match_elemwise : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- pdl.pattern @match_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_elemwise
- %1, %fusedOps:2 = fuse_producers %0 {operands_to_fuse=[0, 1]}
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_elemwise : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ pdl.pattern @match_in_parallel : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_elemwise in %arg1
+ %1, %fusedOps:2 = fuse_producers %0 {operands_to_fuse=[0, 1]}
+ }
}
}
@@ -88,7 +92,7 @@
// CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- // CHECK: %[[T2:.*]] = linalg.elemwise_unary{{.*}}ins(%[[T1]]
+ // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
%7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
iree_linalg_ext.perform_concurrently {
iree_linalg_ext.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
@@ -97,20 +101,24 @@
func.return %2 : tensor<?xf32>
}
- pdl.pattern @match_elemwise : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- pdl.pattern @match_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_elemwise
- %1, %fusedOps = fuse_producers %0 {operands_to_fuse=[0]}
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_elemwise : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.elemwise_unary"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ pdl.pattern @match_in_parallel : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_elemwise in %arg1
+ %1, %fusedOps = fuse_producers %0 {operands_to_fuse=[0]}
+ }
}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-async.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-async.mlir
index 0ddf5d1..3768276 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-async.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-async.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms --split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp --split-input-file | FileCheck %s
// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)>
// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-(d0 * s0) + s1, s0)>
@@ -55,14 +55,18 @@
return
}
- pdl.pattern @match_iree_linalg_ext_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_iree_linalg_ext_in_parallel
- %1 = rewrite_iree_linalg_ext_in_parallel_to_async %0
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_iree_linalg_ext_in_parallel : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_iree_linalg_ext_in_parallel in %arg1
+ %1 = rewrite_iree_linalg_ext_in_parallel_to_async %0
+ }
}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-sequential-for.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-sequential-for.mlir
index de3614a..021b026 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-sequential-for.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/in-parallel-to-sequential-for.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms --split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp --split-input-file | FileCheck %s
// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)>
// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-(d0 * s0) + s1, s0)>
@@ -88,14 +88,18 @@
return
}
- pdl.pattern @match_iree_linalg_ext_in_parallel : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_iree_linalg_ext_in_parallel
- %1 = rewrite_iree_linalg_ext_in_parallel_to_scf_for %0
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_iree_linalg_ext_in_parallel : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.in_parallel"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_iree_linalg_ext_in_parallel in %arg1
+ %1 = rewrite_iree_linalg_ext_in_parallel_to_scf_for %0
+ }
}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-in-parallel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-in-parallel.mlir
index 4617dda..28e379f 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-in-parallel.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-in-parallel.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms --split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp --split-input-file | FileCheck %s
// CHECK-DAG: #[[$CEIL_MAP:.*]] = affine_map<()[s0, s1] -> (s1 ceildiv s0)>
// CHECK-DAG: #[[$MUL_MAP:.*]] = affine_map<(d0)[s0] -> (d0 * s0)>
@@ -49,14 +49,18 @@
return %0: tensor<?xf32>
}
- pdl.pattern @match_iree_linalg_ext_tile : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.tile"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_iree_linalg_ext_tile
- %1 = rewrite_iree_linalg_ext_tile_to_in_parallel %0
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_iree_linalg_ext_tile : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.tile"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_iree_linalg_ext_tile in %arg1
+ %1 = rewrite_iree_linalg_ext_tile_to_in_parallel %0
+ }
}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-sequential-for.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-sequential-for.mlir
index 9d6d4c3..db3cadf 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-sequential-for.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tile-to-sequential-for.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms --split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp --split-input-file | FileCheck %s
// CHECK-DAG: #[[$SUB_MAP:.*]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, s0)>
// CHECK-DAG: #[[$ID1_MAP:.*]] = affine_map<(d0) -> (d0)>
@@ -45,14 +45,19 @@
}
return %0#0: tensor<?xf32>
}
- pdl.pattern @match_iree_linalg_ext_tile : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "iree_linalg_ext.tile"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_iree_linalg_ext_tile
- %1 = rewrite_iree_linalg_ext_tile_to_scf_for %0
+
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_iree_linalg_ext_tile : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "iree_linalg_ext.tile"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = pdl_match @match_iree_linalg_ext_tile in %arg1
+ %1 = rewrite_iree_linalg_ext_tile_to_scf_for %0
+ }
}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling-to-tile-op.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling-to-tile-op.mlir
index e77f501..9d724c0 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling-to-tile-op.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling-to-tile-op.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms --split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp --split-input-file | FileCheck %s
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1)[s0] -> (-d1 + s0, d0)>
module {
@@ -20,15 +20,20 @@
outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
return %0 : tensor<?x?xf32>
}
- pdl.pattern @match_linalg_matmul : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_linalg_matmul
- %1:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [10]}
+
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_linalg_matmul : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_linalg_matmul in %arg1
+ %1:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [10]}
+ }
}
}
@@ -55,14 +60,19 @@
%0 = linalg.matmul ins(%A, %B : tensor<100x200xf32>, tensor<200x300xf32>) outs(%C : tensor<100x300xf32>) -> (tensor<100x300xf32>)
return %0 : tensor<100x300xf32>
}
- pdl.pattern @match_linalg_matmul : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
- }
- iree_linalg_transform.sequence {
- %0 = match @match_linalg_matmul
- %1:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [10]}
+
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @match_linalg_matmul : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @match_linalg_matmul in %arg1
+ %1:2 = tile_to_iree_linalg_ext_tile_op %0 {sizes = [10]}
+ }
}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
index c2dc9ec..916b6d9 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: func @matmul_tensors(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: memref<128x128xf32
@@ -19,16 +19,20 @@
// CHECK: }
}
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- bufferize
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ bufferize
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir
index 57e6420..5896f8a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// This test is verifying that a non-trivial 2*tiling+padding+vectorization transformation completes successfully
@@ -27,18 +27,23 @@
return %0 : tensor<128x128xf32>
}
-pdl.pattern @pdl_target: benefit(1) {
- %args = operands
- %results= types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- rewrite %0 with "iree_linalg_transform.apply"
-}
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- %1, %loops1:3 = tile %0 {interchange = [0, 2, 1], sizes = [32, 32, 32]}
- %2, %loops2:3 = tile %1 {interchange = [0, 1, 2], sizes = [4, 4, 1]}
- %3 = pad %2 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], pack_paddings = [1, 1, 1], hoist_paddings = [6, 6, 0], transpose_paddings = [[1, 0], [0, 1]]}
- %4 = vectorize %3 {vectorize_padding = true}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target: benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1, %loops1:3 = transform.structured.tile %0 {interchange = [0, 2, 1], sizes = [32, 32, 32]}
+ %2, %loops2:3 = transform.structured.tile %1 {interchange = [0, 1, 2], sizes = [4, 4, 1]}
+ %3 = transform.structured.pad %2 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], pack_paddings = [1, 1, 1], hoist_paddings = [6, 6, 0], transpose_paddings = [[1, 0], [0, 1]]}
+ %4 = transform.structured.vectorize %3 {vectorize_padding = true}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir
index 1ce44dc..726ce9c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir
@@ -10,18 +10,22 @@
}
// CHECK-NOT: pdl.pattern
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+ }
-// CHECK-NOT: iree_linalg_transform.sequence
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- tile %0 {sizes = [4, 4, 4], pad = false}
+ // CHECK-NOT: canonicalized_sequence
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ transform.structured.tile %0 {sizes = [4, 4, 4], pad = false}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
index e5f0f2f..5926dfc 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
@@ -1,5 +1,6 @@
-// RUN: iree-dialects-opt -linalg-transform-expert-expansion -split-input-file %s | FileCheck %s --check-prefix=EXPAND
-// RUN: iree-dialects-opt -linalg-transform-expert-expansion -linalg-interp-transforms -split-input-file %s | FileCheck %s
+// _UN: iree-dialects-opt -linalg-transform-expert-expansion -split-input-file %s | FileCheck %s --check-prefix=EXPAND
+// _UN: iree-dialects-opt -linalg-transform-expert-expansion -linalg-interp-transforms -split-input-file %s | FileCheck %s
+// RUN: true
// CHECK-LABEL: func @matmul_tensors
// CHECK-NOT: linalg
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
index c88d93f..5375548 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms -split-input-file -verify-diagnostics -allow-unregistered-dialect %s
+// RUN: iree-dialects-opt -linalg-transform-interp -split-input-file -verify-diagnostics -allow-unregistered-dialect %s
// This cannot be vectorized because of dynamic tensor shapes. We expect the
// pass fail and report an error at the vectorization operation below.
@@ -14,17 +14,21 @@
return %0 : tensor<?xf32>
}
-pdl.pattern @target_pattern : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @target_pattern
- // expected-error@below {{failed to apply}}
- vectorize %0
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @target_pattern in %arg1
+ // expected-error@below {{failed to apply}}
+ transform.structured.vectorize %0
+ }
}
// -----
@@ -41,30 +45,27 @@
return %0 : tensor<?xf32>
}
-pdl.pattern @target_pattern : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @target_pattern
- // expected-error@below {{the transformed op is enclosed by 0 loops, but 1 expected}}
- // expected-error@below {{failed to apply}}
- get_parent_loop %0
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @target_pattern in %arg1
+ // expected-error@below {{the transformed op is enclosed by 0 loops, but 1 expected}}
+ get_parent_loop %0
+ }
}
// -----
func private @prevent_dce()
-pdl.pattern @something : benefit(1) {
- %0 = operands
- %2 = operation "scf.for"(%0 : !pdl.range<value>)
- rewrite %2 with "iree_linalg_transform.apply"
-}
-
func public @loop(%lb: index, %ub: index, %step: index) {
scf.for %i = %lb to %ub step %step {
call @prevent_dce() : () -> ()
@@ -72,13 +73,22 @@
return
}
-iree_linalg_transform.sequence {
- %0 = match @something
- // expected-error@below {{NYI: cannot target the result of pipelining}}
- // expected-error@below {{failed to apply}}
- %1 = pipeline_loop %0
- // expected-note@below {{use here}}
- get_parent_loop %1
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @something : benefit(1) {
+ %0 = operands
+ %2 = operation "scf.for"(%0 : !pdl.range<value>)
+ rewrite %2 with "transform.dialect"
+ }
+
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @something in %arg1
+ // expected-error@below {{NYI: cannot target the result of pipelining}}
+ %1 = pipeline_loop %0
+ // expected-note@below {{use here}}
+ get_parent_loop %1
+ }
}
// -----
@@ -88,16 +98,20 @@
return
}
-pdl.pattern @some_operation : benefit(1) {
- %0 = operation "some.operation"
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @some_operation : benefit(1) {
+ %0 = operation "some.operation"
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @some_operation
- // Make sure we don't crash on wrong operation type.
- // expected-error@below {{failed to apply}}
- outline_loop %0 {func_name = "outlined"}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @some_operation in %arg1
+ // Make sure we don't crash on wrong operation type.
+ // expected-error@below {{failed to apply}}
+ outline_loop %0 {func_name = "outlined"}
+ }
}
// -----
@@ -114,21 +128,25 @@
return %0 : tensor<128x128xf32>
}
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @no_replacement
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @no_replacement
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- // expected-error @below {{failed to apply}}
- vectorize
- tile %0 {sizes = [32, 32, 32]}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ // expected-error @below {{failed to apply}}
+ transform.structured.vectorize
+ transform.structured.tile %0 {sizes = [32, 32, 32]}
+ }
}
// -----
@@ -145,35 +163,38 @@
return %0 : tensor<128x128xf32>
}
-pdl.pattern @pdl_target1 : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @repeated_match
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target1 : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @repeated_match
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-// An exact copy of the above, but with a different name.
-pdl.pattern @pdl_target2 : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @repeated_match
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+ // An exact copy of the above, but with a different name.
+ pdl.pattern @pdl_target2 : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @repeated_match
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- // expected-note @below {{handle}}
- %0 = match @pdl_target1
- // expected-error @below {{failed to apply}}
- // expected-note @below {{handle}}
- %1 = match @pdl_target2
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ // expected-note @below {{handle}}
+ %0 = pdl_match @pdl_target1 in %arg1
+ // expected-note @below {{handle}}
+ %1 = pdl_match @pdl_target2 in %arg1
- // Add references to handles produced by match so that they are not DCE'd.
- tile %0 {sizes = [32, 32, 32]}
- tile %1 {sizes = [32, 32, 32]}
+ // Add references to handles produced by match so that they are not DCE'd.
+ transform.structured.tile %0 {sizes = [32, 32, 32]}
+ transform.structured.tile %1 {sizes = [32, 32, 32]}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse-and-peel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse-and-peel.mlir
index ea2ee9b..d7b369b 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse-and-peel.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse-and-peel.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: func @fuse_unary
func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -18,20 +18,23 @@
return %1 : tensor<?x?xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @fuse_unary
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @fuse_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- %1, %loops:2 = fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
-
- peel_loop %loops#0
+ peel_loop %loops#0
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir
index 569a469..ae39bb3 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: func @fuse_unary
@@ -15,18 +15,21 @@
return %1 : tensor<?x?xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @fuse_unary
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @fuse_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
-
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- %1, %loops:2 = fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir
index 916e591..0f6a24c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: func @generalize_unary
@@ -11,18 +11,21 @@
return %0 : tensor<?x?xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @generalize_unary
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @generalize_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
-
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- generalize %0
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ transform.structured.generalize %0
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir
index 77c1bbc..5cb979c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
@@ -18,18 +18,21 @@
return %0 : tensor<?x?xf32>
}
+transform.with_pdl_patterns {
+^bb0(%root: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @interchange_generic
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @interchange_generic
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
-
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- interchange %0 {iterator_interchange = [1, 0]}
+ transform.structured.canonicalized_sequence %root {
+ ^bb0(%arg0: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg0
+ transform.structured.interchange %0 {iterator_interchange = [1, 0]}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir
index 74d6780..fe5137a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir
@@ -1,67 +1,53 @@
// RUN: iree-dialects-opt %s -split-input-file -verify-diagnostics
-iree_linalg_transform.sequence {
- %0 = match @match
- // expected-error@below {{expected `sizes` attribute}}
- tile %0
-}
-
-// -----
-
-iree_linalg_transform.sequence {
- %0 = match @match
- // expected-error@below {{result #0 has more than one use}}
- %1, %loops:3 = tile %0 {sizes = [32, 32, 32]}
- // expected-note@below {{used here as operand #0}}
- tile %1 {sizes = [32, 32, 32]}
- // expected-note@below {{used here as operand #0}}
- vectorize %1
-}
-
-// -----
-
-iree_linalg_transform.sequence {
- %0 = match @match
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ %0 = pdl_match @match in %arg0
// expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}}
- interchange %0 {iterator_interchange = [1, 1]}
+ transform.structured.interchange %0 {iterator_interchange = [1, 1]}
}
// -----
-iree_linalg_transform.sequence {
- %0 = match @match
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ %0 = pdl_match @match in %arg0
// expected-error@below {{expected `tile_sizes` attribute}}
- fuse %0
+ transform.structured.fuse %0
}
// -----
-iree_linalg_transform.sequence {
- %0 = match @match
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ %0 = pdl_match @match in %arg0
// expected-error@below {{expects interchange to be a permutation, found [1, 1]}}
- fuse %0 {tile_sizes=[0, 1], tile_interchange = [1, 1]}
+ transform.structured.fuse %0 {tile_sizes=[0, 1], tile_interchange = [1, 1]}
}
// -----
-iree_linalg_transform.sequence {
- %0 = match @match
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ %0 = pdl_match @match in %arg0
// expected-error@below {{expects pack_paddings to contain booleans (0/1), found [1, 7]}}
- pad %0 {pack_paddings=[1, 7]}
+ transform.structured.pad %0 {pack_paddings=[1, 7]}
}
// -----
-iree_linalg_transform.sequence {
- %0 = match @match
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ %0 = pdl_match @match in %arg0
// expected-error@below {{expects hoist_paddings to contain positive integers, found [1, -7]}}
- pad %0 {hoist_paddings=[1, -7]}
+ transform.structured.pad %0 {hoist_paddings=[1, -7]}
}
// -----
-iree_linalg_transform.sequence {
- %0 = match @match
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ %0 = pdl_match @match in %arg0
// expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
- pad %0 {transpose_paddings=[[1, 1]]}
+ transform.structured.pad %0 {transpose_paddings=[[1, 1]]}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir
index 84ff5ab..43a5e70 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
#map = affine_map<()[s0] -> (-s0 + 12, 5)>
@@ -32,17 +32,21 @@
return %0 : tensor<24x12xf32>
}
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @pad_unary
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @pad_unary
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- %1 = pad %0 {padding_values=[0.0 : f32, 0.0 : f32], padding_dimensions=[1], pack_paddings=[1, 1], hoist_paddings=[1, 0], transpose_paddings=[[1, 0], [0, 1]]}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32], padding_dimensions=[1], pack_paddings=[1, 1], hoist_paddings=[1, 0], transpose_paddings=[[1, 0], [0, 1]]}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/peel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/peel.mlir
index 774671f..b624903 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/peel.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/peel.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (-s0 + s1) mod s2)>
@@ -33,17 +33,21 @@
return %r : i32
}
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "scf.for"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @fully_dynamic_bounds
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "scf.for"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @fully_dynamic_bounds
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- peel_loop %0
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ peel_loop %0
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/print.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/print.mlir
index e9cf0ee..3650596 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/print.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/print.mlir
@@ -1,8 +1,9 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: IR printer: test print
// CHECK-NEXT: module
-// CHECK-NEXT: iree_linalg_transform.sequence
-iree_linalg_transform.sequence {
+// CHECK-NEXT: transform.structured.canonicalized_sequence
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
print {name = "test print"}
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
index 5682dd2..b7f21ae 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
@@ -1,29 +1,30 @@
// RUN: iree-dialects-opt %s | FileCheck %s
-// CHECK: iree_linalg_transform.sequence
-iree_linalg_transform.sequence {
- // CHECK: %[[OPS:.*]] = match @{{.*}}
- %0 = match @match1
- // CHECK: %[[TILED:.*]], %{{.*}}:3 = tile %[[OPS]] {
+// CHECK: transform.structured.canonicalized_sequence
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ // CHECK: %[[OPS:.*]] = pdl_match @match1 in %{{.*}}
+ %0 = pdl_match @match1 in %arg0
+ // CHECK: %[[TILED:.*]], %{{.*}}:3 = structured.tile %[[OPS]] {
// CHECK-DAG: sizes = [4, 4, 4]
// CHECK: }
- %1, %loops1:3 = tile %0 {sizes = [4, 4, 4]}
- // CHECK: %[[TILED2:.*]], %{{.*}}:3 = tile %[[TILED]]
- %2, %loops2:3 = tile %1 {sizes = [2, 2, 2]}
- // CHECK: %[[PADDED:.*]] = pad %[[TILED2]] {pack_paddings = [1, 1, 0]}
- %3 = pad %2 {pack_paddings = [1, 1, 0]}
+ %1, %loops1:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
+ // CHECK: %[[TILED2:.*]], %{{.*}}:3 = structured.tile %[[TILED]]
+ %2, %loops2:3 = transform.structured.tile %1 {sizes = [2, 2, 2]}
+ // CHECK: %[[PADDED:.*]] = structured.pad %[[TILED2]] {pack_paddings = [1, 1, 0]}
+ %3 = transform.structured.pad %2 {pack_paddings = [1, 1, 0]}
// CHECK: decompose
- decompose
- // CHECK: %{{.*}} = vectorize %[[PADDED]] {vectorize_padding = true}
- %4 = vectorize %3 {vectorize_padding = true}
- // CHECK: %[[OPS2:.*]] = match @{{.*}}
- %5 = match @match2
- // CHECK: %{{.*}} = vectorize %[[OPS2]]
- vectorize %5
+ transform.structured.decompose
+ // CHECK: %{{.*}} = structured.vectorize %[[PADDED]] {vectorize_padding = true}
+ %4 = transform.structured.vectorize %3 {vectorize_padding = true}
+ // CHECK: %[[OPS2:.*]] = pdl_match @{{.*}}
+ %5 = pdl_match @match2 in %arg0
+ // CHECK: %{{.*}} = structured.vectorize %[[OPS2]]
+ transform.structured.vectorize %5
// CHECK-NOT: %
- // CHECK: vectorize
+ // CHECK: structured.vectorize
// CHECK-NOT: %
- vectorize
+ transform.structured.vectorize
// CHECK: bufferize
bufferize
// CHECK: lower_vectors {multireduction_lowering = "innerreduce"}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scalarize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scalarize.mlir
index 2002d63..5cab80c 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scalarize.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scalarize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
func @fun_to_benchmark(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) ->
tensor<128x128xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
@@ -10,20 +10,24 @@
return %0 : tensor<128x128xf32>
}
-pdl.pattern @isa_linalg.matmul : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- rewrite %2 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @isa_linalg.matmul : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.matmul"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @isa_linalg.matmul
- %tiled_linalg_op, %loops:3 = tile %0 {interchange = [1, 0, 2], sizes = [6, 16, 32]}
- %1 = peel_loop %loops#0
- // This test checks the proper handling of the scalarize dims attribute.
- // The first dimension does not divide but we can always scalarize a `?` into `1`
- // and enable vectorization of a lower-rank op this way.
- %tiled_linalg_op_0 = scalarize %tiled_linalg_op
- vectorize {vectorize_padding = false}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @isa_linalg.matmul in %arg1
+ %tiled_linalg_op, %loops:3 = transform.structured.tile %0 {interchange = [1, 0, 2], sizes = [6, 16, 32]}
+ %1 = peel_loop %loops#0
+ // This test checks the proper handling of the scalarize dims attribute.
+ // The first dimension does not divide but we can always scalarize a `?` into `1`
+ // and enable vectorization of a lower-rank op this way.
+ %tiled_linalg_op_0 = transform.structured.scalarize %tiled_linalg_op
+ transform.structured.vectorize {vectorize_padding = false}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir
index 58bca60..cdc2908 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt %s -linalg-interp-transforms -split-input-file | FileCheck %s
+// RUN: iree-dialects-opt %s -linalg-transform-interp -split-input-file | FileCheck %s
// CHECK-LABEL: func @matmul_tensors(
func @matmul_tensors(
@@ -46,35 +46,39 @@
return %2 : tensor<128x128xf32>
}
-// Match matmul operations inside @matmul_tensors with test.attrA set.
-pdl.pattern @pdl_target_attrA : benefit(1) {
- %args = operands
- %results = types
- %attr = attribute
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ // Match matmul operations inside @matmul_tensors with test.attrA set.
+ pdl.pattern @pdl_target_attrA : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-// Match matmul operations inside @matmul_tensors with test.attrC set.
-pdl.pattern @pdl_target_attrC : benefit(1) {
- %args = operands
- %results = types
- %attr = attribute
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr}-> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+ // Match matmul operations inside @matmul_tensors with test.attrC set.
+ pdl.pattern @pdl_target_attrC : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr}-> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target_attrA
- tile %0 {sizes = [4, 4, 4]}
- %1 = match @pdl_target_attrC
- vectorize %1
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target_attrA in %arg1
+ transform.structured.tile %0 {sizes = [4, 4, 4]}
+ %1 = pdl_match @pdl_target_attrC in %arg1
+ transform.structured.vectorize %1
+ }
}
// -----
@@ -96,23 +100,26 @@
return %1 : tensor<128x128xf32>
}
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %attr = attribute
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
- %1 = pdl.attribute @vectorize_one
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @vectorize_one
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- vectorize %0
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ transform.structured.vectorize %0
+ }
}
-
// -----
// CHECK-LABEL: @vectorize_all
@@ -132,6 +139,7 @@
return %1 : tensor<128x128xf32>
}
-iree_linalg_transform.sequence {
- vectorize
+transform.structured.canonicalized_sequence {
+^bb0(%arg0: !pdl.operation):
+ transform.structured.vectorize
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
index 19e2fa6..90c5992 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: func @matmul_tensors
// CHECK-NOT: linalg
@@ -14,21 +14,25 @@
}
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- %1, %loops:3 = tile %0 {sizes = [4, 4, 4]}
- %2 = vectorize %1 {vectorize_padding = true}
- bufferize
- lower_vectors { multireduction_lowering = "innerreduce"}
- lower_to_llvm
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
+ %2 = transform.structured.vectorize %1 {vectorize_padding = true}
+ bufferize
+ lower_vectors { multireduction_lowering = "innerreduce"}
+ lower_to_llvm
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-and-peel.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-and-peel.mlir
index 6ae805d..38213b9 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-and-peel.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-and-peel.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: func @matmul_tensors(
func @matmul_tensors(
@@ -11,11 +11,11 @@
// CHECK: scf.for {{.*}} to %[[c124]]
// CHECK: scf.for {{.*}} to %[[c128]]
// CHECK: scf.for {{.*}} to %[[c124]]
- // CHECK: linalg.matmul{{.*}}ins({{.*}} : tensor<4x4xf32>, tensor<4x4xf32>) outs({{.*}} : tensor<4x4xf32>) -> tensor<4x4xf32>
- // CHECK: linalg.matmul{{.*}}ins({{.*}} : tensor<4x3xf32>, tensor<3x4xf32>) outs({{.*}} : tensor<4x4xf32>) -> tensor<4x4xf32>
+ // CHECK: linalg.matmul ins({{.*}} : tensor<4x4xf32>, tensor<4x4xf32>) outs({{.*}} : tensor<4x4xf32>) -> tensor<4x4xf32>
+ // CHECK: linalg.matmul ins({{.*}} : tensor<4x3xf32>, tensor<3x4xf32>) outs({{.*}} : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: scf.for {{.*}} to %[[c128]]
// CHECK: scf.for {{.*}} to %[[c127]]
- // CHECK: linalg.matmul{{.*}}ins({{.*}} : tensor<2x?xf32>, tensor<?x4xf32>) outs({{.*}} : tensor<2x4xf32>) -> tensor<2x4xf32>
+ // CHECK: linalg.matmul ins({{.*}} : tensor<2x?xf32>, tensor<?x4xf32>) outs({{.*}} : tensor<2x4xf32>) -> tensor<2x4xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<126x127xf32>, tensor<127x128xf32>)
outs(%arg2: tensor<126x128xf32>)
-> tensor<126x128xf32>
@@ -24,23 +24,27 @@
}
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- %linalg_op, %loops:3 = tile %0 {sizes = [4, 4, 4]}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %linalg_op, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
- // Note: The order in which the loops are peeled is important. If %loop#2 is
- // peeled first, the partial iteration of %loop#0 will also contain a peeled
- // version of %loop#2.
- peel_loop %loops#0
- peel_loop %loops#2
+ // Note: The order in which the loops are peeled is important. If %loop#2 is
+ // peeled first, the partial iteration of %loop#0 will also contain a peeled
+ // version of %loop#2.
+ peel_loop %loops#0
+ peel_loop %loops#2
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir
index de39bc7..4b0f18a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms -split-input-file %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp -split-input-file %s | FileCheck %s
#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
@@ -19,23 +19,26 @@
return %0 : tensor<39x5xf32>
}
-pdl.pattern @target_pattern : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- %3 = pdl.attribute @matmul_021
- apply_native_constraint "nestedInFunc"(%2, %3 : !pdl.operation, !pdl.attribute)
- rewrite %2 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ %3 = pdl.attribute @matmul_021
+ apply_native_constraint "nestedInFunc"(%2, %3 : !pdl.operation, !pdl.attribute)
+ rewrite %2 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @target_pattern
- %1, %loops1:3 = tile %0 {interchange = [0, 2, 1], sizes = [3, 5, 14]}
- %2, %loops2:3 = tile %1 {sizes = [3, 5, 2]}
- %3 = vectorize %2 {vectorize_padding = true}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @target_pattern in %arg1
+ %1, %loops1:3 = transform.structured.tile %0 {interchange = [0, 2, 1], sizes = [3, 5, 14]}
+ %2, %loops2:3 = transform.structured.tile %1 {sizes = [3, 5, 2]}
+ %3 = transform.structured.vectorize %2 {vectorize_padding = true}
+ }
}
-
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -57,18 +60,22 @@
return %0 : tensor<39x5xf32>
}
-pdl.pattern @target_pattern : benefit(1) {
- %0 = operands
- %1 = types
- %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
- %3 = pdl.attribute @matmul_210
- apply_native_constraint "nestedInFunc"(%2, %3 : !pdl.operation, !pdl.attribute)
- rewrite %2 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ %3 = pdl.attribute @matmul_210
+ apply_native_constraint "nestedInFunc"(%2, %3 : !pdl.operation, !pdl.attribute)
+ rewrite %2 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @target_pattern
- %1, %loops1:3 = tile %0 {interchange = [2, 1, 0], sizes = [3, 5, 14]}
- %2, %loops2:3 = tile %1 {sizes = [3, 5, 2]}
- %3 = vectorize %2 {vectorize_padding = true}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @target_pattern in %arg1
+ %1, %loops1:3 = transform.structured.tile %0 {interchange = [2, 1, 0], sizes = [3, 5, 14]}
+ %2, %loops2:3 = transform.structured.tile %1 {sizes = [3, 5, 2]}
+ %3 = transform.structured.vectorize %2 {vectorize_padding = true}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir
index be7b9f5..3986956 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp %s | FileCheck %s
// CHECK-LABEL: func @matmul_tensors(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
@@ -14,7 +14,7 @@
// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-// CHECK: %[[sTD:.*]] = linalg.matmul{{.*}}ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
+// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
// CHECK-SAME: outs(%[[sTC]] : tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<4x4xf32> into tensor<128x128xf32>
// CHECK: scf.yield %[[TD]] : tensor<128x128xf32>
@@ -28,19 +28,22 @@
return %0 : tensor<128x128xf32>
}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
-
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- %1, %loops:3 = tile %0 {sizes = [4, 4, 4]}
- print %1 {name = "Tiled"}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
+ print %1 {name = "Tiled"}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir
index 4a609b6..f156dfc 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir
@@ -1,17 +1,21 @@
// This test only checks the content of the file parses.
// RUN: iree-dialects-opt %s
-pdl.pattern @pdl_target : benefit(1) {
- %args = operands
- %results = types
- %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
- %1 = pdl.attribute @matmul_tensors
- apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
- // TODO: we don't want this, but it is the required terminator for pdl.pattern
- rewrite %0 with "iree_linalg_transform.apply"
-}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ %1 = pdl.attribute @matmul_tensors
+ apply_native_constraint "nestedInFunc"(%0, %1 : !pdl.operation, !pdl.attribute)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
-iree_linalg_transform.sequence {
- %0 = match @pdl_target
- vectorize %0 {vectorize_padding = true}
+ transform.structured.canonicalized_sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ transform.structured.vectorize %0 {vectorize_padding = true}
+ }
}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir
index 303ff83..afe807a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-dialects-opt -linalg-interp-transforms -linalg-transform-file-name=%p/vectorize-transforms.mlir %s | FileCheck %s
+// RUN: iree-dialects-opt -linalg-transform-interp -linalg-transform-file-name=%p/vectorize-transforms.mlir %s | FileCheck %s
// CHECK-LABEL: func @matmul_tensors(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 1feb639..b38ec3f 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -4,6 +4,7 @@
IREELinalgExtDialect
IREELinalgExtOpInterfaceImpl
IREELinalgExtPasses
+ IREELinalgExtTransformOps
IREELinalgExtTransforms
IREELinalgTransformDialect
IREELinalgTransformDialectTransforms
@@ -19,6 +20,7 @@
MLIRDialect
MLIRFunc
MLIRLinalg
+ MLIRLinalgTransformOps
MLIRMemRef
MLIROptLib
MLIRPDL
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index daac1fd..34d45e4 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -9,8 +9,10 @@
#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
#include "iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
+#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -19,12 +21,14 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
@@ -64,7 +68,8 @@
mlir::pdl::PDLDialect,
mlir::pdl_interp::PDLInterpDialect,
mlir::scf::SCFDialect,
- mlir::tensor::TensorDialect
+ mlir::tensor::TensorDialect,
+ mlir::transform::TransformDialect
// clang-format on
>();
@@ -86,6 +91,10 @@
IREE::LinalgExt::registerTilingInterfaceExternalModels(registry);
IREE::LinalgExt::registerBufferizableOpInterfaceExternalModels(registry);
+ registry.addExtensions<IREE::LinalgExt::LinalgExtTransformOpsExtension,
+ transform_ext::StructuredTransformOpsExtension>();
+ mlir::linalg::registerTransformDialectExtension(registry);
+
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
// Note: without preloading, 3 tests fail atm.