Sandbox integrate (#8581)
* Bring LinalgTransform dialect from the sandbox to iree-dialects.
Temporarily name the dialect "iree_linalg_transform" instead of "linalg_transform" to avoid name conflicts during transition and thus ease it.
* LinalgTransform python bindings
Temporarily name the dialect "iree_linalg_transform" instead of "linalg_transform" to avoid name conflicts during transition and thus ease it.
* [NFC] Add the MLIR clang-format and format iree-dialects
* Update to sandbox 77ca66e88d130b195b2eac169f17b95305a98577.
* Move Dialect tests to a location consistent with core MLIR
* Update sandbox to 3738d5792a3da6f03628c4375183cb39e3a82d51
* Format
* Drop spurious dependency
* clang-format
* Build fixes
* Move include/Transforms -> include/iree-dialects/Transforms
* Disable pytype on _iree_linalg_transforms_ops_ext.py
* clang-format
* More BUILD fixes
* Fix unit test
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index af9ae07..09a2de1 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1625,7 +1625,7 @@
if (auto endPerformOp = llvm::dyn_cast<EndPerformConcurrentlyOp>(op)) {
continue;
}
- llvm_unreachable("Unexpected operation in perform_concurrently");
+ assert(false, "Unexpected operation in perform_concurrently");
}
return ret;
}
diff --git a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
index 6a03048..cdaf0b4 100644
--- a/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
+++ b/integrations/tensorflow/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
@@ -168,7 +168,7 @@
PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
LogicalResult bufferize(Operation *op, RewriterBase &b,
const BufferizationState &state) const {
- llvm_unreachable("op does not have any tensor OpOperands / OpResults");
+ assert(false, "op does not have any tensor OpOperands / OpResults");
return failure();
}
};
diff --git a/llvm-external-projects/iree-dialects/.clang-format b/llvm-external-projects/iree-dialects/.clang-format
new file mode 100644
index 0000000..a74fda4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/.clang-format
@@ -0,0 +1,2 @@
+BasedOnStyle: LLVM
+AlwaysBreakTemplateDeclarations: Yes
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 350a209..3f47bd6 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -44,6 +44,7 @@
"include/iree-dialects/Dialect/Input/*.td",
"include/iree-dialects/Dialect/LinalgExt/IR/*.td",
"include/iree-dialects/Dialect/LinalgExt/Passes/*.td",
+ "include/iree-dialects/Dialect/LinalgTransform/*.td",
"include/iree-dialects/Dialect/PyDM/IR/*.td",
"python/iree/compiler/dialects/*.td",
]) + [
@@ -53,6 +54,7 @@
deps = [
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
+ "@llvm-project//mlir:PDLDialectTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
],
)
@@ -134,6 +136,23 @@
# IREELinalgExt Dialect
################################################################################
+cc_library(
+ name = "IREEDialectsTransforms",
+ srcs = glob([
+ "lib/Transforms/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/iree-dialects/Transforms/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Rewrite",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
gentbl_cc_library(
name = "IREELinalgExtIncGen",
strip_include_prefix = "include",
@@ -229,7 +248,7 @@
)
gentbl_cc_library(
- name = "IREELinalgExtTransformsIncGen",
+ name = "IREELinalgExtPassIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
@@ -265,8 +284,8 @@
deps = [
":IREELinalgExtIncGen",
":IREELinalgExtInterfacesIncGen",
+ ":IREELinalgExtPassIncGen",
":IREELinalgExtTiledOpInterfacesIncGen",
- ":IREELinalgExtTransformsIncGen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:ArithmeticUtils",
@@ -298,7 +317,7 @@
deps = [
":IREEInputDialect",
":IREELinalgExtDialect",
- ":IREELinalgExtTransformsIncGen",
+ ":IREELinalgExtPassIncGen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:ArithmeticDialect",
@@ -320,6 +339,44 @@
],
)
+cc_library(
+ name = "IREELinalgExtTransforms",
+ srcs = glob([
+ "lib/Dialect/LinalgExt/Transforms/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/iree-dialects/Dialect/LinalgExt/Transforms/*.h",
+ ]) + ["include/iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h"],
+ deps = [
+ ":IREEInputDialect",
+ ":IREELinalgExtDialect",
+ ":IREELinalgExtPassIncGen",
+ ":IREELinalgExtPasses",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:Async",
+ "@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgStructuredOpsIncGen",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TensorUtils",
+ "@llvm-project//mlir:TilingInterface",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
################################################################################
# IREEPyDM Dialect
################################################################################
@@ -494,6 +551,169 @@
)
################################################################################
+# IREELinalgTransform Dialect
+################################################################################
+
+gentbl_cc_library(
+ name = "IREELinalgTransformIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ [
+ "-dialect=iree_linalg_transform",
+ "-gen-dialect-decls",
+ ],
+ "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.h.inc",
+ ),
+ (
+ [
+ "-dialect=iree_linalg_transform",
+ "-gen-dialect-defs",
+ ],
+ "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.cpp.inc",
+ ),
+ (
+ ["-gen-op-decls"],
+ "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td",
+ deps = [
+ ":TdFiles",
+ "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "IREELinalgTransformInterfacesIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.td",
+ deps = [
+ ":TdFiles",
+ ],
+)
+
+cc_library(
+ name = "IREELinalgTransformDialect",
+ srcs = glob([
+ "lib/Dialect/LinalgTransform/IR/*.cpp",
+ "lib/Dialect/LinalgTransform/IR/*.h",
+ ]),
+ hdrs = glob([
+ "include/iree-dialects/Dialect/LinalgTransform/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":IREEDialectsTransforms",
+ ":IREELinalgExtDialect",
+ ":IREELinalgExtTransforms",
+ ":IREELinalgTransformIncGen",
+ ":IREELinalgTransformInterfacesIncGen",
+ "@llvm-project//llvm:Support",
+ # Dialects
+ "@llvm-project//mlir:Async",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:BufferizationTransforms",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:TensorDialect",
+ # IR
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Rewrite",
+ # Interfaces
+ "@llvm-project//mlir:ControlFlowInterfaces",
+ "@llvm-project//mlir:LinalgInterfaces",
+ # Transforms
+ "@llvm-project//mlir:AsyncTransforms",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:AffineToStandard",
+ "@llvm-project//mlir:ModuleBufferization",
+ "@llvm-project//mlir:SCFTransforms",
+ "@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:ReconcileUnrealizedCasts",
+ # Utils
+ "@llvm-project//mlir:ArithmeticUtils",
+ "@llvm-project//mlir:DialectUtils",
+ # Conversions
+ "@llvm-project//mlir:AsyncToLLVM",
+ "@llvm-project//mlir:FuncToLLVM",
+ "@llvm-project//mlir:LinalgToLLVM",
+ "@llvm-project//mlir:LinalgToStandard",
+ "@llvm-project//mlir:MathToLLVM",
+ "@llvm-project//mlir:MemRefToLLVM",
+ "@llvm-project//mlir:SCFToControlFlow",
+ "@llvm-project//mlir:VectorToLLVM",
+ ],
+)
+
+cc_library(
+ name = "IREELinalgTransformDialectTransforms",
+ srcs = glob([
+ "lib/Dialect/LinalgTransform/Transforms/*.cpp",
+ ]),
+ hdrs = [
+ "include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h",
+ "include/iree-dialects/Dialect/LinalgTransform/TrackingListener.h",
+ "include/iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h",
+ ],
+ deps = [
+ ":IREEDialectsTransforms",
+ ":IREELinalgTransformDialect",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:AffineUtils",
+ "@llvm-project//mlir:ArithmeticTransforms",
+ "@llvm-project//mlir:BufferizationDialect",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:ModuleBufferization",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:PDLInterpDialect",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Rewrite",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SCFTransforms",
+ "@llvm-project//mlir:SCFUtils",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TensorTransforms",
+ "@llvm-project//mlir:TensorUtils",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ "@llvm-project//mlir:VectorOps",
+ "@llvm-project//mlir:VectorToLLVM",
+ "@llvm-project//mlir:VectorTransforms",
+ ],
+)
+
+################################################################################
# CAPI
################################################################################
@@ -505,6 +725,7 @@
deps = [
":IREEInputDialect",
":IREELinalgExtDialect",
+ ":IREELinalgTransformDialect",
":IREEPyDMDialect",
":IREEPyDMTransforms",
"@llvm-project//mlir:CAPIIR",
@@ -514,6 +735,27 @@
)
################################################################################
+# Test lib
+################################################################################
+
+cc_library(
+ name = "IREEDialectsTest",
+ srcs = glob([
+ "test/lib/**/*.cpp",
+ ]),
+ deps = [
+ ":IREEDialectsTransforms",
+ ":IREELinalgExtDialect",
+ ":IREELinalgTransformDialect",
+ ":IREELinalgTransformDialectTransforms",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Rewrite",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+################################################################################
# Tools
################################################################################
@@ -524,9 +766,12 @@
],
tags = ["hostonly"],
deps = [
+ ":IREEDialectsTest",
":IREEInputDialect",
":IREELinalgExtDialect",
":IREELinalgExtPasses",
+ ":IREELinalgTransformDialect",
+ ":IREELinalgTransformDialectTransforms",
":IREEPyDMDialect",
":IREEPyDMTransforms",
"@llvm-project//llvm:Support",
@@ -537,6 +782,8 @@
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
+ "@llvm-project//mlir:PDLDialect",
+ "@llvm-project//mlir:PDLInterpDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
index eb6276b..163405c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
@@ -27,16 +27,22 @@
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREELinalgExt, iree_linalg_ext);
+//===--------------------------------------------------------------------===//
+// LinalgTransform
+//===--------------------------------------------------------------------===//
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LinalgTransform, iree_linalg_transform);
+
//===----------------------------------------------------------------------===//
// IREEPyDMDialect
//===----------------------------------------------------------------------===//
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREEPyDM, iree_pydm);
-#define DEFINE_C_API_STRUCT(name, storage) \
- struct name { \
- storage *ptr; \
- }; \
+#define DEFINE_C_API_STRUCT(name, storage) \
+ struct name { \
+ storage *ptr; \
+ }; \
typedef struct name name
DEFINE_C_API_STRUCT(IREEPyDMSourceBundle, void);
@@ -55,13 +61,13 @@
ireePyDMSourceBundleCreateFile(MlirStringRef filePath);
/// Destroys a created source bundle.
-MLIR_CAPI_EXPORTED void ireePyDMSourceBundleDestroy(
- IREEPyDMSourceBundle bundle);
+MLIR_CAPI_EXPORTED void
+ireePyDMSourceBundleDestroy(IREEPyDMSourceBundle bundle);
MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDMPrimitiveType(MlirType type);
-#define IREEPYDM_DECLARE_NULLARY_TYPE(Name) \
- MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDM##Name(MlirType type); \
+#define IREEPYDM_DECLARE_NULLARY_TYPE(Name) \
+ MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDM##Name(MlirType type); \
MLIR_CAPI_EXPORTED MlirType mlirIREEPyDM##Name##TypeGet(MlirContext ctx);
IREEPYDM_DECLARE_NULLARY_TYPE(Bool)
@@ -96,25 +102,27 @@
MLIR_CAPI_EXPORTED IREEPyDMLoweringOptions ireePyDMLoweringOptionsCreate();
/// Sets the RTL link source bundle to the lowering options.
-MLIR_CAPI_EXPORTED void ireePyDMLoweringOptionsLinkRtl(
- IREEPyDMLoweringOptions options, IREEPyDMSourceBundle source);
+MLIR_CAPI_EXPORTED void
+ireePyDMLoweringOptionsLinkRtl(IREEPyDMLoweringOptions options,
+ IREEPyDMSourceBundle source);
/// Destroys a created lowering options struct.
-MLIR_CAPI_EXPORTED void ireePyDMLoweringOptionsDestroy(
- IREEPyDMLoweringOptions options);
+MLIR_CAPI_EXPORTED void
+ireePyDMLoweringOptionsDestroy(IREEPyDMLoweringOptions options);
/// Builds a pass pipeline which should be run immediately post import to
/// perform non-local structural transformations not suitable at the AST level
/// and do local type inference.
-MLIR_CAPI_EXPORTED void mlirIREEPyDMBuildPostImportPassPipeline(
- MlirOpPassManager passManager);
+MLIR_CAPI_EXPORTED void
+mlirIREEPyDMBuildPostImportPassPipeline(MlirOpPassManager passManager);
/// Builds a pass pipeline which lowers the iree_pydm dialect to IREE.
-MLIR_CAPI_EXPORTED void mlirIREEPyDMBuildLowerToIREEPassPipeline(
- MlirOpPassManager passManager, IREEPyDMLoweringOptions options);
+MLIR_CAPI_EXPORTED void
+mlirIREEPyDMBuildLowerToIREEPassPipeline(MlirOpPassManager passManager,
+ IREEPyDMLoweringOptions options);
#ifdef __cplusplus
}
#endif
-#endif // IREE_DIALECTS_C_DIALECTS_H
+#endif // IREE_DIALECTS_C_DIALECTS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
index 8d0252e..6132ee9 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
@@ -23,4 +23,4 @@
}
#endif
-#endif // IREE_DIALECTS_C_UTILS_H
+#endif // IREE_DIALECTS_C_UTILS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
index 620c526..504c744 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Input)
add_subdirectory(LinalgExt)
+add_subdirectory(LinalgTransform)
add_subdirectory(PyDM)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h
index 4d6dbb7..18234ef 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h
@@ -16,4 +16,4 @@
#define GET_TYPEDEF_CLASSES
#include "iree-dialects/Dialect/Input/InputTypes.h.inc"
-#endif // IREE_DIALECTS_DIALECT_INPUT_DIALECT_H
+#endif // IREE_DIALECTS_DIALECT_INPUT_DIALECT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h
index abf07a3..152dd69 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h
@@ -19,4 +19,4 @@
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/Input/InputOps.h.inc"
-#endif // IREE_DIALECTS_DIALECT_INPUT_OPS_H
+#endif // IREE_DIALECTS_DIALECT_INPUT_OPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h
index 8b15ec9..a9fe5d9 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h
@@ -11,7 +11,7 @@
#include "mlir/IR/OpDefinition.h"
// clang-format off: must be included after all LLVM/MLIR headers
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h.inc" // IWYU pragma: keep
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h.inc" // IWYU pragma: keep
// clang-format on
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
index 7bec2f6..3dd0bec 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
@@ -28,14 +28,14 @@
LogicalResult verifyLinalgExtOpInterface(Operation *op);
}
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
/// Include the generated interface declarations.
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.h.inc" // IWYU pragma: export
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.h.inc" // IWYU pragma: export
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h
index 8e9fed2..e1abad2 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -29,12 +29,12 @@
/// `dim`. If the shape is constant, returns the shape as an `IntegerAttr`.
OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim);
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h
index 17d8faf..6ee12a9 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h
@@ -25,9 +25,9 @@
/// Registers external models implemented for the `TiledOpInterface`.
void registerTiledOpInterfaceExternalModels(DialectRegistry ®istry);
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h
index c1b60b6..ddef7a5 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h
@@ -1,10 +1,8 @@
-//===-- LinalgExtBufferization.h - Linalg Extension bufferization ---------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_BUFFERIZATION_H_
#define IREE_DIALECTS_DIALECT_LINALGEXT_BUFFERIZATION_H_
@@ -19,9 +17,9 @@
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_BUFFERIZATION_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_BUFFERIZATION_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h
index e5c044d..3f3fe9b 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h
@@ -1,3 +1,9 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
@@ -9,11 +15,11 @@
namespace LinalgExt {
#define GEN_PASS_CLASSES
-#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: keep
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: keep
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
index fb857f3..febec87 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Passes.h
@@ -24,9 +24,9 @@
void registerPasses();
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h
index 6fa1f51..b005ebf 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Passes/Transforms.h
@@ -40,15 +40,14 @@
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern(context, benefit),
- filter(filter),
+ : OpInterfaceRewritePattern(context, benefit), filter(filter),
options(options) {}
LogicalResult matchAndRewriteBase(TiledOpInterface tilableOp,
PatternRewriter &rewriter,
TiledOp &result) const;
- private:
+private:
/// LinalgTransformMarker handles special attribute manipulations.
linalg::LinalgTransformationFilter filter;
/// Options to control tiling;
@@ -73,7 +72,8 @@
return failure();
}
// Check for do-nothing case.
- if (!tiledOp.op) return failure();
+ if (!tiledOp.op)
+ return failure();
if (tiledOp.op != tilableOp) {
if (tiledOp.results.empty()) {
rewriter.eraseOp(tilableOp);
@@ -85,9 +85,9 @@
}
};
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 3099515..a5644d4 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -1,10 +1,8 @@
-//===- Transforms.h - LinalgExt transformations as patterns -----*- C++ -*-===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
@@ -27,15 +25,15 @@
LinalgExtTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions opt)
: OpInterfaceRewritePattern<TilingInterface>(context), options(opt) {}
- FailureOr<Operation *> returningMatchAndRewrite(
- TilingInterface op, PatternRewriter &rewriter) const;
+ FailureOr<Operation *>
+ returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
LogicalResult matchAndRewrite(TilingInterface op,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(op, rewriter);
}
- private:
+private:
linalg::LinalgTilingOptions options;
};
@@ -43,8 +41,8 @@
struct TileOpToSCFRewriter : public OpRewritePattern<TileOp> {
using OpRewritePattern::OpRewritePattern;
- FailureOr<scf::ForOp> returningMatchAndRewrite(
- TileOp tileOp, PatternRewriter &rewriter) const;
+ FailureOr<scf::ForOp>
+ returningMatchAndRewrite(TileOp tileOp, PatternRewriter &rewriter) const;
LogicalResult matchAndRewrite(TileOp tileOp,
PatternRewriter &rewriter) const override {
@@ -56,8 +54,8 @@
struct TileOpToInParallelRewriter : public OpRewritePattern<TileOp> {
using OpRewritePattern::OpRewritePattern;
- FailureOr<InParallelOp> returningMatchAndRewrite(
- TileOp tileOp, PatternRewriter &rewriter) const;
+ FailureOr<InParallelOp>
+ returningMatchAndRewrite(TileOp tileOp, PatternRewriter &rewriter) const;
LogicalResult matchAndRewrite(TileOp tileOp,
PatternRewriter &rewriter) const override {
@@ -69,8 +67,9 @@
struct InParallelOpToAsyncRewriter : public OpRewritePattern<InParallelOp> {
using OpRewritePattern::OpRewritePattern;
- FailureOr<Operation *> returningMatchAndRewrite(
- InParallelOp inParallelOp, PatternRewriter &rewriter) const;
+ FailureOr<Operation *>
+ returningMatchAndRewrite(InParallelOp inParallelOp,
+ PatternRewriter &rewriter) const;
LogicalResult matchAndRewrite(InParallelOp inParallelOp,
PatternRewriter &rewriter) const override {
@@ -82,8 +81,9 @@
struct InParallelOpToScfForRewriter : public OpRewritePattern<InParallelOp> {
using OpRewritePattern::OpRewritePattern;
- FailureOr<scf::ForOp> returningMatchAndRewrite(
- InParallelOp inParallelOp, PatternRewriter &rewriter) const;
+ FailureOr<scf::ForOp>
+ returningMatchAndRewrite(InParallelOp inParallelOp,
+ PatternRewriter &rewriter) const;
LogicalResult matchAndRewrite(InParallelOp inParallelOp,
PatternRewriter &rewriter) const override {
@@ -91,9 +91,9 @@
}
};
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h
index 534e794..e06eb17 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Utils.h
@@ -1,10 +1,8 @@
-//===- Utils.h - Utils for simplifying writing transformations -*- C++ -*-===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===---------------------------------------------------------------------===//
#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_UTILS_H_
#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_UTILS_H_
@@ -111,14 +109,14 @@
vals);
}
- private:
+private:
OpBuilder &b;
Location loc;
};
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_UTILS_H_
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_UTILS_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt
new file mode 100644
index 0000000..ad1594c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/CMakeLists.txt
@@ -0,0 +1,20 @@
+function(_add_interface)
+ set(LLVM_TARGET_DEFINITIONS TransformOpInterface.td)
+ mlir_tablegen(TransformOpInterface.h.inc -gen-op-interface-decls)
+ mlir_tablegen(TransformOpInterface.cpp.inc -gen-op-interface-defs)
+ add_public_tablegen_target(IREELinalgTransformOpInterface)
+ add_dependencies(IREELinalgExtIncGen IREELinalgTransformOpInterface)
+endfunction()
+
+function(_add_dialect)
+ set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
+ mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
+ mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
+ mlir_tablegen(LinalgTransformDialect.h.inc -gen-dialect-decls -dialect=iree_linalg_transform)
+ mlir_tablegen(LinalgTransformDialect.cpp.inc -gen-dialect-defs -dialect=iree_linalg_transform)
+ add_public_tablegen_target(IREELinalgTransformIncGen)
+ add_dependencies(mlir-headers IREELinalgTransformIncGen)
+endfunction()
+
+_add_interface()
+_add_dialect()
\ No newline at end of file
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
new file mode 100644
index 0000000..6eda492
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h
@@ -0,0 +1,30 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H
+#define MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H
+
+#include "TrackingListener.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/TrackingListener.h"
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace scf {
+class ForOp;
+} // namespace scf
+} // namespace mlir
+
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_LINALG_IR_LINALGTRANSFORMOPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td
new file mode 100644
index 0000000..04d5156
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td
@@ -0,0 +1,452 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef LINALG_TRANSFORM_OPS
+#define LINALG_TRANSFORM_OPS
+
+include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.td"
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+
+def Linalg_Transform_Dialect : Dialect {
+ let name = "iree_linalg_transform";
+ let cppNamespace = "::mlir::linalg::transform";
+ let dependentDialects = [
+ "linalg::LinalgDialect",
+ ];
+}
+
+// Operations with this trait must provide the following methods:
+// - `Value target()` - returns the operation handle (value of !pdl.operation
+// type) targeted by this transformation, if available;
+// - `Optional<SymbolRefAttr> matcher()` - returns the name of the PDL matcher
+// that selects the ops targeted by this transformation, if provided.
+class Linalg_Transform_Operation<string name, list<Trait> props = []>
+ : Op<Linalg_Transform_Dialect, name, props> {
+ let cppNamespace = "::mlir::linalg::transform";
+}
+
+class Transform_Op<string name, list<Trait> props = []>
+ : Linalg_Transform_Operation<name, !listconcat(props, [
+ DeclareOpInterfaceMethods<TransformOpInterface, ["apply"]>])>;
+
+//===----------------------------------------------------------------------===//
+
+def ScopeOp : Linalg_Transform_Operation<"util.scope",
+ [IsolatedFromAbove, DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
+ let description = [{An operation to restrict transformation scopes.}];
+
+ let regions = (region AnyRegion:$body);
+ let arguments = (ins Variadic<AnyType>:$ins);
+ let results = (outs Variadic<AnyType>:$outs);
+ let assemblyFormat = [{ `(` operands `)` attr-dict-with-keyword $body
+ `:` functional-type(operands, results) }];
+}
+
+def ForwardOp : Linalg_Transform_Operation<"util.forward",
+ [Terminator, HasParent<"ScopeOp">]> {
+ let description = [{Terminator for a scope operation, indicating the results
+ that should be forwarded out of the scope.}];
+
+ let arguments = (ins Variadic<AnyType>:$ins);
+ let assemblyFormat = "operands attr-dict `:` type(operands)";
+}
+
+//===----------------------------------------------------------------------===//
+
+def SequenceOp : Linalg_Transform_Operation<"sequence",
+ [NoTerminator, OpAsmOpInterface]> {
+ let description = [{Contains a sequence of transformation ops to apply.
+
+ The transformations indicated by the sequence are applied in order of their
+ appearance. Each value produced by a transformation within the sequence
+ corresponds to an operation or a group of operations in the IR being
+ transformed. Therefore, each value may be used at most once by another
+ transformation operation as the transformation is likely to replace the
+ transformed operation with another operation or a group thereof. In such
+ cases, the tranfsormation operation is expected to produce a new value to
+ denote the newly produced operations that can be transformed further.
+ }];
+
+ let regions = (region SizedRegion<1>:$body);
+ let assemblyFormat = "attr-dict-with-keyword regions";
+
+ let extraClassDeclaration = [{
+ static StringRef getDefaultDialect() { return "iree_linalg_transform"; }
+ }];
+
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+
+def MatchOp : Transform_Op<"match"> {
+ let description = [{ Find and return an op that matches the specific PDL
+ pattern. When executed inside a sequence, it returns all matching ops. }];
+
+ let arguments = (ins SymbolRefAttr:$targetMatcher);
+ let results = (outs PDL_Operation:$target);
+
+ let assemblyFormat = "$targetMatcher attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+
+def TileOp : Linalg_Transform_Operation<"tile",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Indicates that ops of a specific kind in the given
+ function should be tiled with the options provided as attributes.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$peel,
+ DefaultValuedAttr<BoolAttr, "false">:$scalarize_dyn_dims);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def FuseOp : Linalg_Transform_Operation<"fuse",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Tiles the operations pointed to by the target handle and
+ fuses their producers greedily using the options provided as attributes.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange
+ );
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def GeneralizeOp : Linalg_Transform_Operation<"generalize",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Generalizes the operations pointed to
+ by the target handle.}];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def InterchangeOp : Linalg_Transform_Operation<"interchange",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Interchanges the iterators of the operations pointed to
+ by the target handle using the iterator interchange attribute.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def PadOp : Linalg_Transform_Operation<"pad",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Pads the operations pointed to by the target handle
+ using the options provides as operation attributes.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$hoist_paddings,
+ DefaultValuedAttr<
+ TypedArrayAttrBase<I64ArrayAttr,
+ "array of arrays of i64">,
+ "{}">:$transpose_paddings);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def BufferizeOp : Transform_Op<"bufferize"> {
+ let description = [{Indicates that the entire module should be bufferized.}];
+ let assemblyFormat = "attr-dict";
+}
+
+def DecomposeOp : Transform_Op<"decompose"> {
+ let description = [{Indicates that ops in the entire module should be
+ decomposed into lower-level components.}];
+ let assemblyFormat = "attr-dict";
+}
+
+def VectorizeOp : Transform_Op<"vectorize"> {
+ let description = [{Indiactes that vectorization should be performed. If a
+ target handle is provided, only vectorizes the operations pointed to by the
+ handle. Otherwise vectorizes the entire module. Vectorization options are
+ provided as operation attributes.}];
+
+ let arguments = (ins Optional<PDL_Operation>:$target,
+ DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding
+ );
+ let results = (outs Optional<PDL_Operation>:$transformed);
+
+ let hasCustomAssemblyFormat = 1;
+}
+
+def LowerVectorsOp : Transform_Op<"lower_vectors"> {
+ let description = [{Indicates that the vector operations in the entire
+ module should be lowered to simpler primitives (multiple stages of lowering
+ be executed at once).}];
+
+ let arguments =
+ (ins DefaultValuedAttr<I64ArrayAttr, "{0, 1, 2, 3, 4, 5, 6}">:$stages,
+ DefaultValuedAttr<StrAttr, "\"outerproduct\"">:$contraction_lowering,
+ DefaultValuedAttr<StrAttr, "\"innerparallel\"">:$multireduction_lowering,
+ DefaultValuedAttr<StrAttr, "\"linalg-copy\"">:$split_transfers,
+ DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers,
+ DefaultValuedAttr<StrAttr, "\"eltwise\"">:$transpose_lowering,
+ DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering
+ );
+
+ let assemblyFormat = "attr-dict";
+}
+
+def LowerToLLVMOp : Transform_Op<"lower_to_llvm"> {
+ let description = [{Indicates that the entire module should be converted
+ to the LLVM dialect. This is expected to be the last transformation in
+ a sequence.}];
+
+ let arguments =
+ (ins DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_index_optimizations,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_arm_neon,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_arm_sve,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_amx,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_x86vector,
+ DefaultValuedAttr<BoolAttr, "false">:$enable_async);
+
+ let assemblyFormat = "attr-dict";
+}
+
+def GetParentLoopOp : Linalg_Transform_Operation<"get_parent_loop", [
+ TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Obtains a handle to a parent loop of the given
+ operation.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<Confined<I64Attr, [IntPositive]>,
+ "1">:$num_loops);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(Operation *source);
+ }];
+}
+
+def UnrollLoopOp : Linalg_Transform_Operation<"unroll_loop", [
+ TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Unrolls the given loop with the given unroll factor.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ Confined<I64Attr, [IntPositive]>:$factor);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ LogicalResult applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+def PipelineLoopOp : Linalg_Transform_Operation<"pipeline_loop", [
+ TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
+ DefaultValuedAttr<I64Attr, "10">:$read_latency);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+def OutlineLoopOp : Linalg_Transform_Operation<"outline_loop", [
+ DeclareOpInterfaceMethods<TransformOpInterface, ["apply"]>]> {
+ let arguments = (ins PDL_Operation:$target,
+ StrAttr:$func_name);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::FuncOp> applyToOne(::mlir::scf::ForOp loop);
+ }];
+}
+
+def PrintOp : Transform_Op<"print"> {
+ let arguments = (ins StrAttr:$name);
+ let description = [{Prints the module.}];
+ let assemblyFormat = "attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// LinalgExt specific transforms
+//===----------------------------------------------------------------------===//
+
+def TileToLinalgExtTileOp : Linalg_Transform_Operation<"tile_to_iree_linalg_ext_tile_op",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+ let description = [{Tile a linalg op with linalg_ext.tile op along a single
+ dimension.}];
+
+ let summary = [{
+ 0 should be used as a tile size to skip tiling a particular dimension.
+ This is a commonly used convention in Linalg.
+ For the purpose of this transformation, a single non-zero positive tile
+ size is allowed.
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<Operation *> applyToOne(
+ ::mlir::TilingInterface target);
+ }];
+}
+
+def RewriteLinalgExtTileToScfForOp :
+ Linalg_Transform_Operation<"rewrite_iree_linalg_ext_tile_to_scf_for",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+
+ let description = [{Rewrite linalg_ext.tile op to scf.for.}];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::TileOp target);
+ }];
+}
+
+def RewriteLinalgExtTileToInParallelOp :
+ Linalg_Transform_Operation<"rewrite_iree_linalg_ext_tile_to_in_parallel",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+
+ let description = [{Rewrite linalg_ext.tile op to linalg_ext.in_parallel.}];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::iree_compiler::IREE::LinalgExt::InParallelOp> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::TileOp target);
+ }];
+}
+
+def RewriteLinalgExtInParallelToAsyncOp :
+ Linalg_Transform_Operation<"rewrite_iree_linalg_ext_in_parallel_to_async",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+
+ let description = [{Rewrite linalg_ext.in_parallel op to the async dialect.}];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::Operation *> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::InParallelOp target);
+ }];
+}
+
+def RewriteLinalgExtInParallelToScfForOp :
+ Linalg_Transform_Operation<"rewrite_iree_linalg_ext_in_parallel_to_scf_for",
+ [TransformOpInterface, TargetableSingleOperandTransformOpTrait]> {
+
+ let description = [{Rewrite linalg_ext.in_parallel to a sequential scf.for.
+
+ Warning: when the linalg_ext.in_parallel terminator operates on tensors,
+ this is a potentially dangerous transformation under the current semantics.
+ In order for this transformation to be semantics-preserving, 2 conditions need
+ to come together that are currently not checked and the subject of future analyses:
+ 1. The terminator of linalg_ext.in_parallel may compose the output tensor in
+ potentially undefined ways: if the linalg_ext.parallel_insert_slice regions overlap, they
+ may occur in any order and the result is unspecified. A future overlap/intersection
+ analysis will be needed to guard against this case.
+ 2. Even when condition 1. has well-defined behavior, semantics altering behavior may
+ still be introduced to simplify inplace bufferization. In the current implementation,
+ all linalg_ext.parallel_insert_slice dest operands are optimistically turned into scf.for
+ iter_args. This is optimistic because any use of a tensor inside linalg_ext.in_parallel
+ is guaranteed to see the value before the start of the op; whereas the same use may
+ see a partially updated sequential result in the scf.for op.
+ An extra analysis is needed to ensure that a particular read of a result tensor would
+ see the initial value and not a partial update. This is guaranteed by construction when
+ the linalg_ext.in_parallel op is produced by lowering a linalg_ext.tile operation but
+ is not something that is generally enforced in the IR.
+ For the moment we perform the replacement of the use with the scf.for iter_arg to be
+ able to connect pieces inplace with bufferization. In the future an analysis will be
+ needed to ensure correctness of this lowering to sequential scf.for + iter_args.
+ }];
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+ ::mlir::iree_compiler::IREE::LinalgExt::InParallelOp target);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+
+def ExpertOp : Linalg_Transform_Operation<"expert"> {
+ let description = [{A "transformation expert" that can be lowered to a
+ sequence of transformations. The details of the lowering depend on the name
+ and are expressed declaratively.}];
+
+ let arguments = (ins PDL_Operation:$target,
+ StrAttr:$expertName);
+ let results = (outs PDL_Operation:$transformed);
+
+ let assemblyFormat = "`apply` $expertName `to` $target attr-dict";
+}
+
+#endif // LINALG_TRANSFORM_OPS
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/Passes.h
new file mode 100644
index 0000000..f2db6da
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/Passes.h
@@ -0,0 +1,25 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+
+namespace mlir {
+namespace linalg {
+namespace transform {
+
+void registerLinalgTransformInterpreterPass();
+void registerLinalgTransformExpertExpansionPass();
+void registerDropSchedulePass();
+
+} // namespace transform
+} // namespace linalg
+} // namespace mlir
+
+namespace mlir {
+class Pass;
+std::unique_ptr<Pass> createLinalgTransformInterpreterPass();
+std::unique_ptr<Pass> createDropSchedulePass();
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/ScopedTransform.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/ScopedTransform.h
new file mode 100644
index 0000000..8bb66b3
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/ScopedTransform.h
@@ -0,0 +1,31 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_SCOPEDTRANSFORM_H
+#define IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_SCOPEDTRANSFORM_H
+
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+
+namespace mlir {
+namespace linalg {
+namespace transform {
+ScopeOp wrapInScope(Operation *op);
+FailureOr<SmallVector<Operation *>> unwrapScope(ScopeOp scope);
+
+template <typename TransformT>
+auto scoped(Operation *target, TransformT &&transform) {
+ auto scope = wrapInScope(target);
+ Operation &op = *scope.body().front().begin();
+ auto result = transform(scope, &op);
+ if (failed(unwrapScope(scope)) || failed(result))
+ return decltype(result)(failure());
+ return result;
+}
+} // namespace transform
+} // namespace linalg
+} // namespace mlir
+
+#endif // IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_SCOPEDTRANSFORM_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h
new file mode 100644
index 0000000..880a39f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h
@@ -0,0 +1,20 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class MLIRContext;
+
+/// The only purpose of this class is to enable creation of PatternRewriter
+/// instances as the base class doesn't have a public constructor.
+class SimplePatternRewriter : public PatternRewriter {
+public:
+ explicit SimplePatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
+};
+
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h
new file mode 100644
index 0000000..3374c5b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingCSE.h
@@ -0,0 +1,22 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGCSE_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGCSE_H
+
+namespace mlir {
+class DominanceInfo;
+struct LogicalResult;
+class Operation;
+struct RewriteListener;
+
+LogicalResult
+eliminateCommonSubexpressionsWithTrackedOps(Operation *root,
+ RewriteListener &listener,
+ DominanceInfo *domInfo = nullptr);
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGCSE_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingListener.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingListener.h
new file mode 100644
index 0000000..bc9d795
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingListener.h
@@ -0,0 +1,79 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_TRACKINGLISTENER_H
+#define IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_TRACKINGLISTENER_H
+
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
+#include "iree-dialects/Transforms/Listener.h"
+
+namespace mlir {
+namespace linalg {
+/// A tracking listener using to perform CSE and canonicalization passes while
+/// tracking certain linalg operation handles live in a linalg transform
+/// interpreter.
+class TrackingListener : public RewriteListener,
+ public transform::TransformState::Extension {
+public:
+ TrackingListener(transform::TransformState &state);
+ TrackingListener(TrackingListener &&other)
+ : transform::TransformState::Extension(
+ std::forward<transform::TransformState::Extension>(other)),
+ trackedOperationKeys(std::move(other.trackedOperationKeys)),
+ hadErrors(other.hadErrors) {
+#ifndef NDEBUG
+ errorStateChecked = other.errorStateChecked;
+ other.errorStateChecked = true;
+#endif
+ }
+ ~TrackingListener() {
+#ifndef NDEBUG
+ assert(errorStateChecked && "must check listener error state");
+#endif // NDEBUG
+ }
+
+ /// When a tracked linalg operation is replaced, try to find a single linalg
+ /// op responsible for the replacement values and substitute the handle of the
+ /// replaced op for this op.
+ void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+
+ /// When a tracked operation is removed (due to CSE or canonicalization), then
+ /// any further transformations on the op are redundant. Remove it from the
+ /// tracked operation list.
+ void notifyOperationRemoved(Operation *op) override;
+
+ void notifySetPayload(Value handle,
+ ArrayRef<Operation *> operations) override;
+ void notifyRemovePayload(Value handle,
+ ArrayRef<Operation *> operations) override;
+
+ /// Emits an error pointing at the given operation. Use this instead of
+ /// directly emitting an error on the operation to set the listener into the
+ /// error state and thus communicate with its user.
+ InFlightDiagnostic emitError(Operation *op, const llvm::Twine &message = {});
+
+ /// Converts the current error state into LogicalResult and clears it.
+ LogicalResult checkErrorState() {
+ LogicalResult result = failure(hadErrors);
+#ifndef NDEBUG
+ errorStateChecked = true;
+#endif // NDEBUG
+ return result;
+ }
+
+private:
+ /// A map from a tracked operation (LinalgOp cannot be used as a key) to its
+ /// key in the map.
+ DenseMap<Operation *, Value> trackedOperationKeys;
+ bool hadErrors = false;
+#ifndef NDEBUG
+ bool errorStateChecked = false;
+#endif // NDEBUG
+};
+} // namespace linalg
+} // namespace mlir
+
+#endif // IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_TRACKINGLISTENER_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h
new file mode 100644
index 0000000..25657a9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h
@@ -0,0 +1,28 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGREWRITEDRIVER_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGREWRITEDRIVER_H
+
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpMapping.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+struct RewriteListener;
+
+/// Apply the given list of transformations to the regions of the
+/// isolated-from-above operation `root` greedily until convergence. Update
+/// Linalg operations in values of `trackedOperations` if they are replaced by
+/// other Linalg operations during the rewriting process. Tracked operations
+/// must be replaced with Linalg operations and must not be erased in the
+/// patterns.
+LogicalResult applyPatternsTrackAndFoldGreedily(
+ Operation *root, RewriteListener &listener,
+ const FrozenRewritePatternSet &patterns,
+ GreedyRewriteConfig config = GreedyRewriteConfig());
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRACKINGREWRITEDRIVER_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h
new file mode 100644
index 0000000..4093736
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h
@@ -0,0 +1,348 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORM_TRANSFORM_OP_INTERFACE_H
+#define MLIR_DIALECT_LINALG_TRANSFORM_TRANSFORM_OP_INTERFACE_H
+
+#include <mlir/IR/OpDefinition.h>
+
+#include <type_traits>
+
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpMapping.h"
+#include "iree-dialects/Transforms/Functional.h"
+
+namespace mlir {
+namespace linalg {
+namespace transform {
+
+class TransformOpInterface;
+
+/// The state maintained across applications of various ops implementing the
+/// TransformOpInterface. The operations implementing this interface and the
+/// surrounding structure are referred to as transform IR. The operations to
+/// which transformations apply are referred to as payload IR. The state thus
+/// contains the mapping between values defined transform IR ops and payload IR
+/// ops. It assumes that each value in the transform IR can be used at most once
+/// (since transformations are likely to change the payload IR ops the value
+/// corresponds to). Checks that transform IR values correspond to disjoint sets
+/// of payload IR ops throughout the transformation.
+class TransformState {
+public:
+ /// Creates a state for the transformation rooted at the given op.
+ explicit TransformState(Operation *root);
+
+ /// Returns the op at which the transformation state is rooted. This is
+ /// typically helpful for transformations that apply globally.
+ Operation *getTopLevel() const;
+
+ /// Returns the list of ops that the given transform IR value corresponds to.
+ /// This is helpful for transformations that apply to a particular handle.
+ ArrayRef<Operation *> getPayloadOps(Value value) const;
+
+ /// Applies the transformation specified by the given transform op and updates
+ /// the state accordingly.
+ LogicalResult applyTransform(TransformOpInterface transform);
+
+ /// The extension mechanism for TransformState. Extensions are expected to
+ /// derive this class and may use its methods to access the state. Extensions
+ /// are identified by their type and a state can only have one extension of
+ /// a particular type.
+ class Extension {
+ friend class TransformState;
+
+ public:
+ // Out-of-line implementation to ensure vtable and metadata are emitted in
+ // a single .o file.
+ virtual ~Extension();
+
+ protected:
+ Extension(TransformState &state) : state(state) {}
+
+ /// Read-only access to the mapping between transform IR values and payload
+ /// IR operations contained in the state.
+ const TransformOpMapping &getMapping() const { return state.operations; }
+
+ /// Notifies the extension that payload IR operations were associated with
+ /// the given transform IR handle. Concrete extensions that are willing to
+ /// be notified should override this method.
+ virtual void notifySetPayload(Value handle,
+ ArrayRef<Operation *> operations) {}
+ /// Notifies the extension that the association between a transform IR
+ /// handle and a list of payload IR operations is about to be removed.
+ /// Concrete extensions that are willing to be notified should override this
+ /// method.
+ virtual void notifyRemovePayload(Value handle,
+ ArrayRef<Operation *> operations) {}
+
+ /// Notifies the extension that the ops associated with the transform IR
+ /// handle changed. Concrete extensions that are willing to be notified
+ /// should override this method.
+ virtual void notifyUpdatePayload(Value handle, ArrayRef<Operation *> oldOps,
+ ArrayRef<Operation *> newOps) {}
+
+ /// Sets the payload IR ops associated with the given transform IR value.
+ /// Fails if this would result in multiple transform IR values with uses
+ /// corresponding to the same payload IR ops. This extension will NOT
+ /// be notified about this event.
+ LogicalResult setPayloadOps(Value handle,
+ ArrayRef<Operation *> operations) {
+ propagatingSetPayload = true;
+ LogicalResult result = state.setPayloadOps(handle, operations);
+ propagatingSetPayload = false;
+ return result;
+ }
+
+ /// Forgets the payload IR ops associated with the given transform IR value.
+ /// This extension will NOT be notified about this event.
+ void removePayloadOps(Value handle) {
+ propagatingRemovePayload = true;
+ state.removePayloadOps(handle);
+ propagatingRemovePayload = false;
+ }
+
+ /// Updates the payload IR ops associated with the given transform IR value.
+ /// The callback function is called once per associated operation and is
+ /// expected to return the modified operation or nullptr. In the latter
+ /// case, the corresponding operation is no longer associated with the
+ /// transform IR value. This extension will NOT be notified about it.
+ void updatePayloadOps(Value handle,
+ function_ref<Operation *(Operation *)> callback) {
+ propagatingUpdatePayload = true;
+ state.updatePayloadOps(handle, callback);
+ propagatingUpdatePayload = false;
+ }
+
+ private:
+ /// Flags indicating whether a notifiable event originates at this
+ /// extension. If set, this extension is not notified about the event.
+ bool propagatingSetPayload = false;
+ bool propagatingRemovePayload = false;
+ bool propagatingUpdatePayload = false;
+
+ /// Sends notifications to about an event to the current extension. Expected
+ /// to be called by the TransformState only.
+ void sendNotifySetPayload(Value handle, ArrayRef<Operation *> operations) {
+ if (!propagatingSetPayload)
+ notifySetPayload(handle, operations);
+ }
+ void sendNotifyRemovePayload(Value handle,
+ ArrayRef<Operation *> operations) {
+ if (!propagatingRemovePayload)
+ notifyRemovePayload(handle, operations);
+ }
+ void sendNotifyUpdatePayload(Value handle, ArrayRef<Operation *> oldOps,
+ ArrayRef<Operation *> newOps) {
+ if (!propagatingUpdatePayload)
+ notifyUpdatePayload(handle, oldOps, newOps);
+ }
+
+ /// Back-reference to the state this is extending.
+ TransformState &state;
+ };
+
+ /// Adds a new extension of the type specifeid as template parameter,
+ /// constructing it with the arguments provided. The extension is owned by the
+ /// TransformState. It is expected that the state does not already have an
+ /// extension of the same type. Extension constructors are expected to take
+ /// a reference to TransformState as first argument, automatically supplied
+ /// by this call.
+ template <typename Ty, typename... Args>
+ Ty &addExtension(Args &&... args) {
+ static_assert(
+ std::is_base_of<Extension, Ty>::value,
+ "only an class derived from TransformState::Extension is allowed here");
+ auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+ auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+ assert(result.second && "extension already added");
+ return *static_cast<Ty *>(result.first->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ Ty &getExtension() {
+ static_assert(
+ std::is_base_of<Extension, Ty>::value,
+ "only an class derived from TransformState::Extension is allowed here");
+ auto iter = extensions.find(TypeID::get<Ty>());
+ assert(iter != extensions.end() && "extension not found");
+ return *static_cast<Ty *>(iter->second.get());
+ }
+
+ /// Removes the extension of the specified type.
+ template <typename Ty>
+ void removeExtension() {
+ static_assert(
+ std::is_base_of<Extension, Ty>::value,
+ "only an class derived from TransformState::Extension is allowed here");
+ extensions.erase(TypeID::get<Ty>());
+ }
+
+private:
+ /// Identifier for storing top-level value in the `operations` mapping.
+ constexpr const static Value kTopLevelValue = Value();
+
+ /// Sets the payload IR ops associated with the given transform IR value.
+ /// Fails if this would result in multiple transform IR values with uses
+ /// corresponding to the same payload IR ops.
+ LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
+
+ /// Forgets the payload IR ops associated with the given transform IR value.
+ void removePayloadOps(Value value);
+
+ /// Updates the payload IR ops associated with the given transform IR value.
+ /// The callback function is called once per associated operation and is
+ /// expected to return the modified operation or nullptr. In the latter case,
+ /// the corresponding operation is no longer associated with the transform IR
+ /// value.
+ void updatePayloadOps(Value value,
+ function_ref<Operation *(Operation *)> callback);
+
+ /// The mapping between payload IR values and transform IR ops.
+ TransformOpMapping operations;
+
+ /// Extensions attached to the TransformState, identified by the TypeID of
+ /// their type. Only one extension of any given type is allowed.
+ DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
+};
+
+/// Local mapping between values defined by a specific op implementing the
+/// TransformOpInterface and the payload IR ops they correspond to.
+class TransformResults {
+ friend class TransformState;
+
+public:
+ /// Indicates that the result of the transform IR op at the given position
+ /// corresponds to the given list of payload IR ops. Each result must be set
+ /// by the transformation exactly once.
+ void set(OpResult value, ArrayRef<Operation *> ops);
+
+private:
+ /// Creates an instance of TransformResults that expects mappings for
+ /// `numSegments` values.
+ explicit TransformResults(unsigned numSegments);
+
+ /// Gets the list of operations associated with the result at the given
+ /// position.
+ ArrayRef<Operation *> get(unsigned position) const;
+
+ /// Storage for pointers to payload IR ops that are associated with results of
+ /// a transform IR op. `segments` contains as many entries as the transform IR
+ /// op has results. Each entry is a reference to a contiguous segment in
+ /// the `operations` list that contains the pointers to operations. This
+ /// allows for operations to be stored contiguously without nested vectors and
+ /// for different segments to be set in any order.
+ SmallVector<ArrayRef<Operation *>, 2> segments;
+ SmallVector<Operation *> operations;
+};
+
+namespace detail {
+/// Appends `result` to the vector assuming it corresponds to the success state
+/// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
+/// `LogicalResult`, does nothing.
+template <typename Ty>
+std::enable_if_t<std::is_same<Ty, LogicalResult>::value>
+appendTransformResultToVector(Ty result,
+ SmallVectorImpl<Operation *> &results) {}
+
+template <typename Ty>
+std::enable_if_t<!std::is_same<Ty, LogicalResult>::value>
+appendTransformResultToVector(Ty result,
+ SmallVectorImpl<Operation *> &results) {
+ static_assert(
+ std::is_convertible<typename Ty::value_type, Operation *>::value,
+ "Expected transform function to return operations");
+ results.push_back(*result);
+}
+} // namespace detail
+
+/// Applies a one-to-one transform to each of the given targets. Puts the
+/// results of transforms, if any, in `results` in the same order. Fails if any
+/// of the application fails. Individual transforms must be callable with
+/// one of the following signatures:
+/// - FailureOr<convertible-to-Operation*>(OpTy)
+/// - LogicalResult(OpTy)
+/// where OpTy is either
+/// - Operation *, in which case the transform is always applied;
+/// - a concrete Op class, in which case a check is performed whether
+/// `targets` contains operations of the same class and a failure is reported
+/// if it does not.
+template <typename FnTy>
+LogicalResult applyTransformToEach(ArrayRef<Operation *> targets,
+ SmallVectorImpl<Operation *> &results,
+ FnTy transform) {
+ using TransformOpType =
+ typename llvm::function_traits<FnTy>::template arg_t<0>;
+ static_assert(std::is_convertible<TransformOpType, Operation *>::value,
+ "Expected transform function to take an operation");
+ for (Operation *target : targets) {
+ auto specificOp =
+ functional::detail::IsaOr<TransformOpType>::dyn_cast(target);
+ if (!specificOp)
+ return failure();
+
+ auto result = transform(specificOp);
+ if (failed(result))
+ return failure();
+
+ detail::appendTransformResultToVector(result, results);
+ }
+ return success();
+}
+
+/// Trait implementing the TransformOpInterface for operations applying a
+/// transformation to a single operation handle and producing a single operation
+/// handle. The op must implement a method with one of the following signatures:
+/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
+/// - LogicalResult applyToOne(OpTy)
+/// to perform a transformation that is applied in turn to all payload IR
+/// operations that correspond to the handle of the transform IR operation.
+/// In the functions above, OpTy is either Operation * or a concrete payload IR
+/// Op class that the transformation is applied to (NOT the class of the
+/// transform IR op).
+template <typename OpTy>
+class TargetableSingleOperandOpTrait
+ : public OpTrait::TraitBase<OpTy, TargetableSingleOperandOpTrait> {
+public:
+ /// Applies the transformation to each op from the only target and sets the
+ /// only result to correspond to the list of individual results.
+ LogicalResult apply(TransformResults &transformResults,
+ TransformState &state) {
+ using TransformOpType = typename llvm::function_traits<decltype(
+ &OpTy::applyToOne)>::template arg_t<0>;
+ ArrayRef<Operation *> targets =
+ state.getPayloadOps(this->getOperation()->getOperand(0));
+ SmallVector<Operation *> results;
+ if (failed(applyTransformToEach(
+ targets, results, [&](TransformOpType specificOp) {
+ return static_cast<OpTy *>(this)->applyToOne(specificOp);
+ })))
+ return failure();
+ if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+ transformResults.set(
+ this->getOperation()->getResult(0).template cast<OpResult>(),
+ results);
+ }
+ return success();
+ }
+
+ /// Verifies that the op satisfies the requirements for this trait.
+ static LogicalResult verifyTrait(Operation *) {
+ static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+ "expected single-operand op");
+ static_assert(OpTy::template hasTrait<OpTrait::OneResult>() ||
+ OpTy::template hasTrait<OpTrait::ZeroResult>(),
+ "expected zero- or single-result op");
+ return success();
+ }
+};
+
+} // namespace transform
+} // namespace linalg
+} // namespace mlir
+
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORM_TRANSFORM_OP_INTERFACE_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.td
new file mode 100644
index 0000000..2e24d04
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpInterface.td
@@ -0,0 +1,49 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef TRANSFORM_OP_INTERFACE
+#define TRANSFORM_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def TransformOpInterface : OpInterface<"TransformOpInterface"> {
+ let description = [{This interface is to be implemented by operations that
+ identify transformations to be performed on other operations. The former
+ are referred to as transform IR operations. The latter are referred to as
+ payload IR operations. Such transform IR operations provide a fine-grain
+ control mechanism over how transformations are applied by using and defining
+ transform IR values, referred to as handles, that correspond to sets of
+ operations in the payload IR. Transformations are applied starting from
+ the operations identified by handles, but may affect other operations as
+ well. Further restrictions may be imposed by flows that rely on transform
+ IR operations to control transformations.
+ }];
+
+ let cppNamespace = "::mlir::linalg::transform";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{Applies the transformation represented by the current
+ operation. This accepts as arguments the object that must be populated
+ with results of the current transformation and a transformation state
+ object that can be used for queries, e.g., to obtain the list of
+ operations on which the transformation represented by the current op is
+ targeted.}],
+ /*returnType=*/"LogicalResult",
+ /*name=*/"apply",
+ /*arguments=*/(ins
+ "::mlir::linalg::transform::TransformResults &":$transformResults,
+ "::mlir::linalg::transform::TransformState &":$state)
+ >,
+ ];
+}
+
+def TargetableSingleOperandTransformOpTrait
+ : NativeOpTrait<"TargetableSingleOperandOpTrait"> {
+ let cppNamespace = "::mlir::linalg::transform";
+}
+
+#endif // TRANSFORM_OP_INTERFACE
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpMapping.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpMapping.h
new file mode 100644
index 0000000..7158f33
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgTransform/TransformOpMapping.h
@@ -0,0 +1,21 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef MLIR_DIALECT_LINALGTRANSFORMS_TRANSFORMOPMAPPING_H
+#define MLIR_DIALECT_LINALGTRANSFORMS_TRANSFORMOPMAPPING_H
+
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class Operation;
+class Value;
+
+using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>;
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALGTRANSFORMS_TRANSFORMOPMAPPING_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/Constants.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/Constants.h
index a8e7655..b662605 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/Constants.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/Constants.h
@@ -144,9 +144,9 @@
FirstCustom = 0x101,
};
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_PYDM_IR_CONSTANTS_H
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_CONSTANTS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h
index 0c0a40e..84729d8 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h
@@ -20,15 +20,15 @@
/// Base class for all unboxed primitive types.
class PrimitiveType : public mlir::Type {
- public:
+public:
using mlir::Type::Type;
static bool classof(Type type);
};
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
// Include generated dialect code (this comment blocks clang-format from
// clobbering order).
@@ -48,9 +48,9 @@
ListType, NoneType, RealType, StrType, TupleType, TypeType>();
}
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_H
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h
index 63ef309..61201bb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h
@@ -19,12 +19,12 @@
enum class BuiltinTypeCode;
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
#include "iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.h.inc"
#include "iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.h.inc"
-#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_H
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h
index 18fe67c..905f3a3 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h
@@ -22,4 +22,4 @@
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h.inc"
-#endif // IREE_DIALECTS_IREEPYDM_IR_OPS_H
+#endif // IREE_DIALECTS_IREEPYDM_IR_OPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.h
index d29cef4..b5a0fd2 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.h
@@ -41,20 +41,20 @@
std::unique_ptr<OperationPass<FuncOp>> createLocalPropagateTypesPass();
std::unique_ptr<OperationPass<FuncOp>> createVariablesToSSAPass();
std::unique_ptr<OperationPass<>> createFixateWeakNumericPass();
-std::unique_ptr<OperationPass<ModuleOp>> createLinkIREEPyDMRTLPass(
- Optional<SourceBundle> linkRtlSourceBundle = None);
+std::unique_ptr<OperationPass<ModuleOp>>
+createLinkIREEPyDMRTLPass(Optional<SourceBundle> linkRtlSourceBundle = None);
std::unique_ptr<OperationPass<ModuleOp>> createLowerIREEPyDMToRTLPass();
-void buildPostImportPassPipeline(OpPassManager& passManager);
-void buildLowerToIREEPassPipeline(OpPassManager& passManager,
- const LowerToIREEOptions& options);
+void buildPostImportPassPipeline(OpPassManager &passManager);
+void buildLowerToIREEPassPipeline(OpPassManager &passManager,
+ const LowerToIREEOptions &options);
/// Register all passes and pass pipelines.
void registerPasses();
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h
index c71ee27..9fafff7 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h
@@ -19,7 +19,7 @@
/// An analysis of the external linkage which must be satisfied.
class LinkageAnalysis {
- public:
+public:
LinkageAnalysis(Operation *moduleOp);
/// Whether there are any external functions that need resolution.
@@ -30,13 +30,13 @@
return externFuncOps;
}
- private:
+private:
llvm::SmallVector<Operation *> externFuncOps;
};
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h
index 5e2ca56..9f0d23c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h
@@ -23,9 +23,9 @@
TypeConverter &typeConverter,
RewritePatternSet &patterns);
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_LOWERING_PATTERNS_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_LOWERING_PATTERNS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h
index 3d3cf46..f8a9be0 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h
@@ -16,7 +16,7 @@
namespace PYDM {
class LoweringTypeConverter : public mlir::TypeConverter {
- public:
+public:
enum class WeakFloatType {
F32,
F64,
@@ -32,15 +32,15 @@
bool isTypeLegal(Type t) const;
bool areTypesLegal(TypeRange types) const;
- private:
+private:
bool boolBits = 32;
int weakIntegerBits = 32;
WeakFloatType weakFloatType = WeakFloatType::F32;
};
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Utils/TypeInference.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Utils/TypeInference.h
index aec96be..35a8abb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Utils/TypeInference.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Utils/TypeInference.h
@@ -7,14 +7,14 @@
#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
#define IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Allocator.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Allocator.h"
namespace mlir {
namespace iree_compiler {
@@ -27,7 +27,7 @@
/// generally, duplicating/permuting blocks or regions is preferred over
/// unifying.
class PermutedTypePropagator {
- public:
+public:
PermutedTypePropagator(MLIRContext *context) : context(context) {}
// ---------------------------------------------------------------------------
@@ -80,7 +80,7 @@
TypeRange newArgumentTypes,
BlockPermuteCallback initializeCallback);
- private:
+private:
MLIRContext *context;
llvm::BumpPtrAllocator allocator;
@@ -90,9 +90,9 @@
Block *block);
};
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Functional.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Functional.h
new file mode 100644
index 0000000..44e0539
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Functional.h
@@ -0,0 +1,347 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef MLIR_REWRITE_FUNCTIONAL_H
+#define MLIR_REWRITE_FUNCTIONAL_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+
+namespace mlir {
+namespace functional {
+
+/// A "Functional Pattern" is a callable concept that accepts as its first
+/// argument an operation or operation interface and as its second argument a
+/// `RewriterBase` or `PatternRewriter`. Beyond these, it can accept additional
+/// parameters of any type.
+///
+/// A functional pattern is expected to return a type convertible to
+/// `LogicalResult`. If the result is a `FailureOr<T>`, then `T` is forwarded to
+/// subsequent patterns in sequences. Additionally, if `T` is a pair or tuple,
+/// its elements are unpacked and passed as separate arguments to subsequent
+/// patterns.
+template <typename PatternT>
+struct PatternConcept {
+ static constexpr bool verify() {
+ using Traits = llvm::function_traits<std::decay_t<PatternT>>;
+ static_assert(Traits::num_args >= 2,
+ "Patterns must have at least two arguments.");
+
+ using OpT = typename Traits::template arg_t<0>;
+ static_assert(std::is_convertible<OpT, Operation *>::value,
+ "The first argument of a pattern must be Operation * or "
+ "convertible to Operation *");
+
+ using RewriterT = typename Traits::template arg_t<1>;
+ static_assert(std::is_convertible<PatternRewriter &, RewriterT>::value,
+ "The second argument of a pattern must be convertible from "
+ "PatternRewriter & (e.g. RewriterBase &)");
+
+ using ResultT = typename Traits::result_t;
+ static_assert(
+ std::is_convertible<ResultT, LogicalResult>::value,
+ "The result of a pattern must be convertible to LogicalResult");
+
+ return true;
+ }
+};
+
+namespace detail {
+/// A simple pattern rewriter that implements no special logic.
+class SimpleRewriter : public PatternRewriter {
+public:
+ SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
+};
+} // end namespace detail
+
+/// Apply a pattern directly on an operation. This function instantiates a
+/// simple pattern rewriter and calls the pattern directly on the operation with
+/// the given arguments.
+template <typename OpT, typename PatternT, typename... Args,
+ bool = PatternConcept<PatternT>::verify()>
+auto applyAt(OpT op, PatternT &&pattern, Args &&... args) {
+ detail::SimpleRewriter rewriter(op->getContext());
+ rewriter.setInsertionPoint(op);
+ return pattern(op, rewriter, std::forward<Args>(args)...);
+}
+
+/// Given a scope, apply a pattern with the given arguments until the first
+/// successful match and return the result. This function instantiates a simple
+/// pattern rewriter.
+template <typename PatternT, typename... Args,
+ bool = PatternConcept<PatternT>::verify()>
+auto applyOnceIn(Operation *scope, PatternT &&pattern, Args &&... args) {
+ assert(scope->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "scope is not isolated from above");
+ using Traits = llvm::function_traits<std::decay_t<PatternT>>;
+ using OpT = typename Traits::template arg_t<0>;
+
+ detail::SimpleRewriter rewriter(scope->getContext());
+ typename Traits::result_t result = failure();
+ scope->walk([pattern = std::forward<PatternT>(pattern), &result, &rewriter,
+ &args...](OpT op) {
+ rewriter.setInsertionPoint(op);
+ result = pattern(op, rewriter, std::forward<Args>(args)...);
+ return failed(result) ? WalkResult::advance() : WalkResult::interrupt();
+ });
+ return result;
+}
+
+namespace detail {
+/// If a pattern returns `FailureOr<Type>`, unpack the nested value of `Type`.
+/// Otherwise, just return the whole value.
+template <typename ReturnType>
+struct UnpackFailureOr {
+ using type = ReturnType;
+ /// Base case. If a pattern does not return `FailureOr`, just forward the
+ /// whole result. Usually, this is a `LogicalResult`.
+ static type unpack(type &&value) { return value; }
+};
+template <typename NestedType>
+struct UnpackFailureOr<FailureOr<NestedType>> {
+ using type = NestedType;
+ /// Specialized case for `FailureOr`. Assumes that the pattern succeeded.
+ /// Return the contained value.
+ static type unpack(FailureOr<type> &&value) { return std::move(*value); }
+};
+} // end namespace detail
+
+/// Given a scope, apply a pattern once per operation in the scope, saving the
+/// result of each match. The result list is empty if the pattern failed to
+/// match at all.
+template <typename PatternT, typename... Args,
+ bool = PatternConcept<PatternT>::verify()>
+auto applyForEachIn(Operation *scope, PatternT &&pattern, Args &&... args) {
+ using Traits = llvm::function_traits<std::decay_t<PatternT>>;
+ using Unpack = detail::UnpackFailureOr<typename Traits::result_t>;
+
+ assert(scope->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "scope is not isolated from above");
+
+ detail::SimpleRewriter rewriter(scope->getContext());
+ // A list of all the results.
+ SmallVector<typename Unpack::type> results;
+ scope->walk([pattern = std::forward<PatternT>(pattern), &rewriter, &results,
+ &args...](Operation *op) {
+ rewriter.setInsertionPoint(op);
+ auto result = pattern(op, rewriter, std::forward<Args>(args)...);
+ // If the pattern applied, unpack the result and store it.
+ if (succeeded(result))
+ results.push_back(Unpack::unpack(std::move(result)));
+ });
+ return results;
+}
+
+/// Apply a pattern directly on an operation, for each operation in a list.
+template <typename ListT, typename PatternT, typename... Args,
+ bool = PatternConcept<PatternT>::verify()>
+auto applyForEach(ListT &&list, PatternT &&pattern, Args &&... args) {
+ using Traits = llvm::function_traits<std::decay_t<PatternT>>;
+ using Unpack = detail::UnpackFailureOr<typename Traits::result_t>;
+
+ // A list of all the results.
+ SmallVector<typename Unpack::type> results;
+ for (auto op : list) {
+ auto result = applyAt(op, std::forward<PatternT>(pattern),
+ std::forward<Args>(args)...);
+ // The pattern applied, unpack the result and store it.
+ if (succeeded(result))
+ results.push_back(Unpack::unpack(std::move(result)));
+ }
+ return results;
+}
+
+namespace detail {
+/// Utility struct for handling functional patterns that may operate on generic
+/// `Operation *` or a more specific interface or op type. In the base case,
+/// patterns need to check that the correct type was passed and may need to cast
+/// to that type.
+template <typename OpT>
+struct IsaOr {
+ static bool apply(Operation *op) { return isa<OpT>(op); }
+ static OpT cast(Operation *op) { return ::mlir::cast<OpT>(op); }
+ static OpT dyn_cast(Operation *op) { return ::mlir::dyn_cast<OpT>(op); }
+};
+/// In the special case, nothing needs to be done. Just pass the generic op
+/// directly into the pattern.
+template <>
+struct IsaOr<Operation *> {
+ static bool apply(Operation *op) { return true; }
+ static Operation *cast(Operation *op) { return op; }
+ static Operation *dyn_cast(Operation *op) { return op; }
+};
+
+/// A sequence here is a tuple of unique functions. However, when constructing
+/// the sequence, the result types of subsequent functions are not visible. In
+/// order to generically pass around the entire sequence, it is stored as a list
+/// of opaque pointers.
+using OpaqueSeq = ArrayRef<void *>;
+
+/// Unpack a non-tuple return type into a tuple.
+template <typename ResultT>
+struct UnpackIntoTuple {
+ static auto apply(ResultT &&result) {
+ return std::make_tuple(std::forward<ResultT>(result));
+ }
+};
+/// Unpack a pair into a tuple.
+template <typename FirstT, typename SecondT>
+struct UnpackIntoTuple<std::pair<FirstT, SecondT>> {
+ static auto apply(std::pair<FirstT, SecondT> &&result) {
+ return std::make_tuple(std::move(result.first), std::move(result.second));
+ }
+};
+/// If the result type is already a tuple, just forward it.
+template <typename... ResultTs>
+struct UnpackIntoTuple<std::tuple<ResultTs...>> {
+ static auto apply(std::tuple<ResultTs...> &&result) {
+ return std::forward<std::tuple<ResultTs...>>(result);
+ }
+};
+
+/// Utility function for calling another pattern in the sequence where the
+/// arguments are packed into a tuple. Similar to `std::apply`, except the first
+/// argument is passed as the operation.
+template <typename PatternT, typename OpT, typename Args, size_t... Indices>
+auto callNextPattern(PatternT &&pattern, OpT op, PatternRewriter &rewriter,
+ Args &&args, std::index_sequence<Indices...>) {
+ rewriter.setInsertionPoint(op);
+ return pattern(op, rewriter, std::move(std::get<Indices + 1>(args))...);
+}
+
+/// A pattern sequence is implemented as a tuple of unique functions.
+template <typename... UniqueFunctionTs>
+struct GenericSequence : public std::tuple<UniqueFunctionTs...> {
+ /// Inherit the tuple constructor.
+ using std::tuple<UniqueFunctionTs...>::tuple;
+ /// The number of patterns in the sequence.
+ static constexpr size_t NumPatterns = sizeof...(UniqueFunctionTs);
+
+ /// Get the equivalent tuple type, for use with tuple type utilities.
+ template <typename...>
+ struct GenericToTupleType;
+ template <typename... Ts>
+ struct GenericToTupleType<GenericSequence<Ts...>> {
+ using type = std::tuple<Ts...>;
+ };
+
+ /// Create a new sequence that contains all the patterns of an existing
+ /// sequence but appended with a new pattern. The previous sequence is
+ /// invalidated.
+ template <typename ResultT, typename PrevT, size_t... Indices>
+ static auto moveInto(PrevT &&prev, std::index_sequence<Indices...>) {
+ return GenericSequence<
+ std::tuple_element_t<Indices,
+ typename GenericToTupleType<PrevT>::type>...,
+ llvm::unique_function<ResultT(Operation *, PatternRewriter &,
+ OpaqueSeq)>>(
+ std::move(std::get<Indices>(prev))...,
+ // Populate an empty function and define it later.
+ llvm::unique_function<ResultT(Operation *, PatternRewriter &,
+ OpaqueSeq)>());
+ }
+
+ /// Chain a pattern with another pattern. When calling `seq.then(...)`, the
+ /// results of the previous pattern are passed to the subsequent pattern as
+ /// follows:
+ /// - if the pattern failed, then sequence execution is aborted
+ /// - the first result must be convertible to `Operation *`, and is passed as
+ /// the operation into the next pattern
+ /// - additional results are passed as arguments
+ template <typename PatternT, bool = PatternConcept<PatternT>::verify()>
+ auto then(PatternT &&pattern) {
+ using Traits = llvm::function_traits<std::decay_t<PatternT>>;
+ using OpT = typename Traits::template arg_t<0>;
+ using ResultT = typename Traits::result_t;
+
+ // Get the type of the previous pattern fucntion.
+ using PrevT =
+ std::remove_reference_t<decltype(std::get<NumPatterns - 1>(*this))>;
+ // Copy all the patterns into a new sequence.
+ auto seq = moveInto<ResultT>(std::move(*this),
+ std::make_index_sequence<NumPatterns>());
+ // Get a reference to the n
+ auto &next = std::get<NumPatterns>(seq);
+ // Each pattern in the sequence calls the previous pattern, except the first
+ // pattern.
+ next = [pattern = std::forward<PatternT>(pattern)](
+ Operation *op, PatternRewriter &rewriter,
+ OpaqueSeq opaqueSeq) -> ResultT {
+ // FIXME: this is a hack to get around knowing all the return types.
+ auto prevResult =
+ (*(PrevT *)opaqueSeq[NumPatterns - 1])(op, rewriter, opaqueSeq);
+ if (failed(prevResult))
+ return failure();
+
+ // The previous pattern succeeded. Unpack the results into a tuple to pass
+ // as arguments to the next pattern.
+ auto args = UnpackIntoTuple<std::remove_reference_t<decltype(
+ *prevResult)>>::apply(std::move(*prevResult));
+ // Grab the first result value as the operation.
+ Operation *nextOp = std::get<0>(args);
+ if (!detail::IsaOr<OpT>::apply(nextOp))
+ return failure();
+ // Call the next pattern.
+ return callNextPattern(
+ pattern, detail::IsaOr<OpT>::cast(nextOp), rewriter, std::move(args),
+ std::make_index_sequence<std::tuple_size<decltype(args)>::value -
+ 1>());
+ };
+ return seq;
+ }
+
+ // Convert a generic sequence into an opaque sequence.
+ template <typename SequenceT, size_t... Indices>
+ static auto getAsOpaqueSeq(SequenceT &seq, std::index_sequence<Indices...>) {
+ SmallVector<void *> opaqueSeq = {(void *)&std::get<Indices>(seq)...};
+ return opaqueSeq;
+ }
+
+ // Implement the call operator using tail-recursion.
+ auto operator()(Operation *op, PatternRewriter &rewriter) {
+ return std::get<NumPatterns - 1>(*this)(
+ op, rewriter,
+ getAsOpaqueSeq(*this, std::make_index_sequence<NumPatterns>()));
+ }
+};
+
+/// The starting point for constructing a pattern sequence.
+struct SequenceBuilder {
+ /// The first pattern of the sequence receives arguments directly from the
+ /// caller and does not recurse.
+ template <typename PatternT, typename... Args,
+ bool = PatternConcept<PatternT>::verify()>
+ auto begin(PatternT &&pattern, Args &&... args) {
+ using Traits = llvm::function_traits<std::decay_t<PatternT>>;
+ using OpT = typename Traits::template arg_t<0>;
+ using ResultT = typename Traits::result_t;
+
+ /// Create the function that calls the pattern.
+ llvm::unique_function<ResultT(Operation *, PatternRewriter &, OpaqueSeq)>
+ call = [pattern = std::forward<PatternT>(pattern),
+ &args...](Operation *op, PatternRewriter &rewriter,
+ OpaqueSeq) -> ResultT {
+ if (!detail::IsaOr<OpT>::apply(op))
+ return failure();
+ rewriter.setInsertionPoint(op);
+ return pattern(detail::IsaOr<OpT>::cast(op), rewriter,
+ std::forward<Args>(args)...);
+ };
+
+ /// Insert it into a generic sequence and return.
+ return GenericSequence<llvm::unique_function<ResultT(
+ Operation *, PatternRewriter &, OpaqueSeq)>>(std::move(call));
+ }
+};
+
+} // end namespace detail
+
+struct SequenceBuilder : public detail::SequenceBuilder {};
+
+} // end namespace functional
+} // end namespace mlir
+
+#endif // MLIR_REWRITE_FUNCTIONAL_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h
new file mode 100644
index 0000000..da3fb71
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/Listener.h
@@ -0,0 +1,119 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
+#define IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// RewriteListener
+//===----------------------------------------------------------------------===//
+
+/// This class represents a listener that can be used to hook on to various
+/// rewrite events in an `OpBuilder` or `PatternRewriter`. The class is notified
+/// by when:
+///
+/// - an operation is removed
+/// - an operation is inserted
+/// - an operation is replaced
+/// - a block is created
+///
+/// Listeners can be used to track IR mutations throughout pattern rewrites.
+struct RewriteListener {
+ virtual ~RewriteListener();
+
+ /// These are the callback methods that subclasses can choose to implement if
+ /// they would like to be notified about certain types of mutations.
+
+ /// Notification handler for when an operation is inserted into the builder.
+ /// op` is the operation that was inserted.
+ virtual void notifyOperationInserted(Operation *op) {}
+
+ /// Notification handler for when a block is created using the builder.
+ /// `block` is the block that was created.
+ virtual void notifyBlockCreated(Block *block) {}
+
+ /// Notification handler for when the specified operation is about to be
+ /// replaced with another set of operations. This is called before the uses of
+ /// the operation have been replaced with the specific values.
+ virtual void notifyOperationReplaced(Operation *op, ValueRange newValues) {}
+
+ /// Notification handler for when an the specified operation is about to be
+ /// deleted. At this point, the operation has zero uses.
+ virtual void notifyOperationRemoved(Operation *op) {}
+
+ /// Notify the listener that a pattern failed to match the given operation,
+ /// and provide a callback to populate a diagnostic with the reason why the
+ /// failure occurred. This method allows for derived listeners to optionally
+ /// hook into the reason why a rewrite failed, and display it to users.
+ virtual void
+ notifyMatchFailure(Operation *op,
+ function_ref<void(Diagnostic &)> reasonCallback) {}
+};
+
+//===----------------------------------------------------------------------===//
+// ListenerList
+//===----------------------------------------------------------------------===//
+
+/// This class contains multiple listeners to which rewrite events can be sent.
+class ListenerList : public RewriteListener {
+public:
+ /// Add a listener to the list.
+ void addListener(RewriteListener *listener) { listeners.push_back(listener); }
+
+ /// Send notification of an operation being inserted to all listeners.
+ void notifyOperationInserted(Operation *op) override;
+ /// Send notification of a block being created to all listeners.
+ void notifyBlockCreated(Block *block) override;
+ /// Send notification that an operation has been replaced to all listeners.
+ void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+ /// Send notification that an operation is about to be deleted to all
+ /// listeners.
+ void notifyOperationRemoved(Operation *op) override;
+ /// Notify all listeners that a pattern match failed.
+ void
+ notifyMatchFailure(Operation *op,
+ function_ref<void(Diagnostic &)> reasonCallback) override;
+
+private:
+ /// The list of listeners to send events to.
+ SmallVector<RewriteListener *, 1> listeners;
+};
+
+//===----------------------------------------------------------------------===//
+// PatternRewriterListener
+//===----------------------------------------------------------------------===//
+
+/// This class implements a pattern rewriter with a rewrite listener. Rewrite
+/// events are forwarded to the provided rewrite listener.
+class PatternRewriterListener : public PatternRewriter, public ListenerList {
+public:
+ PatternRewriterListener(MLIRContext *context) : PatternRewriter(context) {}
+
+ /// When an operation is about to be replaced, send out an event to all
+ /// attached listeners.
+ void replaceOp(Operation *op, ValueRange newValues) override {
+ notifyOperationReplaced(op, newValues);
+ PatternRewriter::replaceOp(op, newValues);
+ }
+
+ void notifyOperationInserted(Operation *op) override {
+ ListenerList::notifyOperationInserted(op);
+ }
+ void notifyBlockCreated(Block *block) override {
+ ListenerList::notifyBlockCreated(block);
+ }
+ void notifyOperationRemoved(Operation *op) override {
+ ListenerList::notifyOperationRemoved(op);
+ }
+};
+
+} // namespace mlir
+
+#endif // IREE_LLVM_SANDBOX_TRANSFORMS_LISTENER_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
new file mode 100644
index 0000000..a062f2a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerCSE.h
@@ -0,0 +1,21 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
+#define LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
+
+#include "iree-dialects/Transforms/Listener.h"
+
+namespace mlir {
+class DominanceInfo;
+class Operation;
+
+LogicalResult eliminateCommonSubexpressions(Operation *op,
+ DominanceInfo *domInfo,
+ RewriteListener *listener);
+} // namespace mlir
+
+#endif // LLVM_IREE_SANDBOX_TRANSFORMS_LISTENERCSE_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h
new file mode 100644
index 0000000..b5274c5
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h
@@ -0,0 +1,30 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Transforms/Listener.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+
+namespace mlir {
+struct GreedyRewriteConfig;
+
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefits patterns in a greedy manner. Returns
+/// success if no more patterns can be matched. `erased` is set to true if `op`
+/// was folded away or erased as a result of becoming dead. Note: This does not
+/// apply any patterns recursively to the regions of `op`. Accepts a listener
+/// so the caller can be notified of rewrite events.
+LogicalResult applyPatternsAndFoldGreedily(
+ MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config, RewriteListener *listener);
+inline LogicalResult applyPatternsAndFoldGreedily(
+ Operation *op, const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config, RewriteListener *listener) {
+ return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config,
+ listener);
+}
+
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
index 5c0e24d..f9159aa 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -5,6 +5,7 @@
MLIRIR
IREEInputDialect
IREELinalgExtDialect
+ IREELinalgTransformDialect
IREEPyDMDialect
IREEPyDMPasses
)
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index 569e530..3aa63ad 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -8,6 +8,7 @@
#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "mlir/CAPI/IR.h"
@@ -36,6 +37,14 @@
IREELinalgExt, iree_linalg_ext,
mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect)
+//===--------------------------------------------------------------------===//
+// IREELinalgTransform
+//===--------------------------------------------------------------------===//
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(
+ IREELinalgTransform, iree_linalg_transform,
+ mlir::linalg::transform::LinalgTransformDialect)
+
//===----------------------------------------------------------------------===//
// IREEPyDMDialect
//===----------------------------------------------------------------------===//
@@ -52,12 +61,12 @@
return unwrap(type).isa<PYDM::PrimitiveType>();
}
-#define IREEPYDM_DEFINE_NULLARY_TYPE(Name) \
- bool mlirTypeIsAIREEPyDM##Name(MlirType type) { \
- return unwrap(type).isa<PYDM::Name##Type>(); \
- } \
- MlirType mlirIREEPyDM##Name##TypeGet(MlirContext ctx) { \
- return wrap(PYDM::Name##Type::get(unwrap(ctx))); \
+#define IREEPYDM_DEFINE_NULLARY_TYPE(Name) \
+ bool mlirTypeIsAIREEPyDM##Name(MlirType type) { \
+ return unwrap(type).isa<PYDM::Name##Type>(); \
+ } \
+ MlirType mlirIREEPyDM##Name##TypeGet(MlirContext ctx) { \
+ return wrap(PYDM::Name##Type::get(unwrap(ctx))); \
}
IREEPYDM_DEFINE_NULLARY_TYPE(Bool)
@@ -97,8 +106,8 @@
return wrap(PYDM::ObjectType::get(unwrap(ctx), cppType));
}
-MLIR_CAPI_EXPORTED void mlirIREEPyDMBuildPostImportPassPipeline(
- MlirOpPassManager passManager) {
+MLIR_CAPI_EXPORTED void
+mlirIREEPyDMBuildPostImportPassPipeline(MlirOpPassManager passManager) {
auto *passManagerCpp = unwrap(passManager);
PYDM::buildPostImportPassPipeline(*passManagerCpp);
}
diff --git a/llvm-external-projects/iree-dialects/lib/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CMakeLists.txt
index 47ce6dd..76b98c3 100644
--- a/llvm-external-projects/iree-dialects/lib/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(CAPI)
add_subdirectory(Dialect)
+add_subdirectory(Transforms)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
index 620c526..504c744 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Input)
add_subdirectory(LinalgExt)
+add_subdirectory(LinalgTransform)
add_subdirectory(PyDM)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp
index a12a1b9..62c755b 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp
@@ -7,9 +7,9 @@
#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/Input/InputOps.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::Input;
@@ -63,7 +63,7 @@
printer << "<" << getTargetType() << ">";
}
-} // namespace Input
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace Input
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
index 4657c12..5d824a0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -7,12 +7,12 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
index 5fdaf8a..19e437a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
@@ -18,8 +18,8 @@
return result;
}
-LogicalResult IREE::LinalgExt::detail::verifyLinalgExtOpInterface(
- Operation *op) {
+LogicalResult
+IREE::LinalgExt::detail::verifyLinalgExtOpInterface(Operation *op) {
LinalgExtOp linalgExtOp = cast<LinalgExtOp>(op);
if (op->getNumResults()) {
if (!linalgExtOp.hasTensorSemantics()) {
@@ -48,4 +48,4 @@
return success();
}
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc" // IWYU pragma: export
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc" // IWYU pragma: export
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index af9ae07..6c06913 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -7,11 +7,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallSet.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/SMLoc.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -33,6 +28,11 @@
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/SMLoc.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
@@ -330,7 +330,8 @@
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
- if (starts[i]) cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
+ if (starts[i])
+ cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
starts[i] = cast;
}
@@ -435,8 +436,8 @@
return loopBounds;
}
-SmallVector<unsigned> SortOp::getPartitionableLoops(
- unsigned maxNumParallelDims) {
+SmallVector<unsigned>
+SortOp::getPartitionableLoops(unsigned maxNumParallelDims) {
auto range = llvm::seq<unsigned>(0, getOperandRank());
SmallVector<unsigned> partitionableLoops(range.begin(), range.end());
partitionableLoops.erase(std::next(partitionableLoops.begin(), dimension()));
@@ -568,7 +569,8 @@
// After tiling, it could be dynamic shape. (Because
// subview/subtensor does not inference the type correctly
// on (1 << x)) cases).
- if (length == ShapedType::kDynamicSize) return success();
+ if (length == ShapedType::kDynamicSize)
+ return success();
if (length & (length - 1)) {
return op->emitOpError("only powers of 2 are handled currently");
}
@@ -763,8 +765,8 @@
return success();
}
-SmallVector<unsigned> FftOp::getPartitionableLoops(
- unsigned maxNumParallelDims) {
+SmallVector<unsigned>
+FftOp::getPartitionableLoops(unsigned maxNumParallelDims) {
auto range = llvm::seq<unsigned>(0, getOperandRank());
SmallVector<unsigned> partitionableLoops(range.begin(), range.end());
// Indices matter for coeff computation.
@@ -842,7 +844,8 @@
}
SmallVector<int64_t> expectedAccumulatorShape;
for (int i = 0; i < inputType.getRank(); i++) {
- if (i != dimension()) expectedAccumulatorShape.push_back(inputShapes[i]);
+ if (i != dimension())
+ expectedAccumulatorShape.push_back(inputShapes[i]);
}
if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
[](std::tuple<int64_t, int64_t> s) {
@@ -892,8 +895,8 @@
return iteratorTypes;
}
-SmallVector<unsigned> ScanOp::getPartitionableLoops(
- unsigned maxNumParallelDims) {
+SmallVector<unsigned>
+ScanOp::getPartitionableLoops(unsigned maxNumParallelDims) {
auto range = llvm::seq<unsigned>(0, getOperandRank());
SmallVector<unsigned> partitionableLoops(range.begin(), range.end());
partitionableLoops.erase(std::next(partitionableLoops.begin(), dimension()));
@@ -921,7 +924,8 @@
bool isInclusive = inclusive();
SmallVector<Value> accIndices;
for (int i = 0; i < indices.size(); i++) {
- if (i != scanDim) accIndices.push_back(indices[i]);
+ if (i != scanDim)
+ accIndices.push_back(indices[i]);
}
auto scfIf = b.create<scf::IfOp>(
@@ -943,9 +947,11 @@
indices[scanDim] = ivMinusOne;
scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
Value i0;
- if (!isInclusive) i0 = b.create<memref::LoadOp>(loc, input(), indices);
+ if (!isInclusive)
+ i0 = b.create<memref::LoadOp>(loc, input(), indices);
indices[scanDim] = iv;
- if (isInclusive) i0 = b.create<memref::LoadOp>(loc, input(), indices);
+ if (isInclusive)
+ i0 = b.create<memref::LoadOp>(loc, input(), indices);
scanBlkArgs.push_back(i0);
});
@@ -1174,14 +1180,14 @@
return tiledRevOp;
}
-#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
- void OP_NAME::getEffects( \
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
- &effects) { \
- SmallVector<Value> inputBuffers = getInputBufferOperands(); \
- SmallVector<Value> outputBuffers = getOutputBufferOperands(); \
- getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \
- outputBuffers); \
+#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
+ void OP_NAME::getEffects( \
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
+ &effects) { \
+ SmallVector<Value> inputBuffers = getInputBufferOperands(); \
+ SmallVector<Value> outputBuffers = getOutputBufferOperands(); \
+ getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \
+ outputBuffers); \
}
DEFINE_OP_GET_EFFECTS(ScatterOp)
@@ -1201,11 +1207,13 @@
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
- if (opOperand->get().isa<BlockArgument>()) return false;
+ if (opOperand->get().isa<BlockArgument>())
+ return false;
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
- if (!hasTensorCastOperand) return failure();
+ if (!hasTensorCastOperand)
+ return failure();
SmallVector<Type, 4> newResultTypes;
newResultTypes.reserve(op->getNumResults());
@@ -1246,7 +1254,7 @@
return success();
}
};
-} // namespace
+} // namespace
//===----------------------------------------------------------------------===//
// TileOp
@@ -1300,7 +1308,8 @@
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << tile_size() << ' ';
- if (tiled_dim() > 0) p << "tiled_dim = " << tiled_dim() << ' ';
+ if (tiled_dim() > 0)
+ p << "tiled_dim = " << tiled_dim() << ' ';
if (!outs().empty()) {
p << "outs(";
llvm::interleaveComma(outs(), p,
@@ -1357,7 +1366,8 @@
result.operands))
return failure();
}
- if (parser.parseArrowTypeList(result.types)) return failure();
+ if (parser.parseArrowTypeList(result.types))
+ return failure();
SmallVector<OpAsmParser::OperandType, 8> regionOperands;
std::unique_ptr<Region> region = std::make_unique<Region>();
@@ -1366,7 +1376,8 @@
return failure();
// Parse the optional attribute list.
- if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
TileOp::ensureTerminator(*region, builder, result.location);
result.addRegion(std::move(region));
@@ -1426,7 +1437,8 @@
if (parser.parseOperand(numThreads) ||
parser.resolveOperand(numThreads, indexType, result.operands))
return failure();
- if (parser.parseArrowTypeList(result.types)) return failure();
+ if (parser.parseArrowTypeList(result.types))
+ return failure();
SmallVector<OpAsmParser::OperandType, 8> regionOperands;
SmallVector<Type, 8> regionTypes;
@@ -1437,7 +1449,8 @@
result.addRegion(std::move(region));
// Parse the optional attribute list.
- if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
return success();
}
@@ -1541,7 +1554,7 @@
/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
class ParallelInsertSliceOpConstantArgumentFolder final
: public OpRewritePattern<ParallelInsertSliceOp> {
- public:
+public:
using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
@@ -1569,7 +1582,7 @@
return success();
}
};
-} // namespace
+} // namespace
void ParallelInsertSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
@@ -1604,14 +1617,16 @@
result.addRegion(std::move(region));
// Parse the optional attribute list.
- if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
return success();
}
SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
- return llvm::to_vector(llvm::map_range(
- this->yieldingOps(),
- [](ParallelInsertSliceOp op) { return op.yieldedType(); }));
+ return llvm::to_vector(
+ llvm::map_range(this->yieldingOps(), [](ParallelInsertSliceOp op) {
+ return op.yieldedType();
+ }));
}
SmallVector<ParallelInsertSliceOp> PerformConcurrentlyOp::yieldingOps() {
@@ -1625,7 +1640,7 @@
if (auto endPerformOp = llvm::dyn_cast<EndPerformConcurrentlyOp>(op)) {
continue;
}
- llvm_unreachable("Unexpected operation in perform_concurrently");
+ assert(false && "Unexpected operation in perform_concurrently");
}
return ret;
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp
index 06bc712..e509489 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -6,13 +6,13 @@
#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
-#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "iree-tiled-op-interface"
@@ -299,7 +299,7 @@
}
};
-} // namespace
+} // namespace
void IREE::LinalgExt::registerTiledOpInterfaceExternalModels(
DialectRegistry ®istry) {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp
index da62126..10629da 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp
@@ -8,9 +8,6 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
@@ -20,6 +17,9 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
@@ -81,7 +81,7 @@
return success();
}
};
-} // namespace
+} // namespace
//===----------------------------------------------------------------------===//
// Pass
@@ -107,7 +107,7 @@
}
}
};
-} // namespace
+} // namespace
std::unique_ptr<OperationPass<FuncOp>>
IREE::LinalgExt::createLinalgExtToLoopsPass() {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp
index a2fe9bd..340e6c4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp
@@ -77,7 +77,8 @@
needsPad = true;
}
}
- if (!needsPad) return false;
+ if (!needsPad)
+ return false;
auto resultType = RankedTensorType::get(newStaticDims, type.getElementType());
Value zeroConstant = builder.create<arith::ConstantOp>(
@@ -132,7 +133,7 @@
});
}
};
-} // namespace
+} // namespace
std::unique_ptr<OperationPass<>>
IREE::LinalgExt::createPadContractionToBlockSizePass() {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
index f038541..70a6526 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
@@ -20,13 +20,13 @@
namespace detail {
#define GEN_PASS_REGISTRATION
-#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: export
-} // namespace detail
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: export
+} // namespace detail
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
void IREE::LinalgExt::registerPasses() {
IREE::LinalgExt::detail::registerPasses();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
index fd66bff..c95f927 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
@@ -11,7 +11,6 @@
#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
@@ -33,9 +33,9 @@
//===----------------------------------------------------------------------===//
/// Returns failure if the options are unsupported.
-static LogicalResult verifySupportedTilingOptions(
- PatternRewriter &rewriter, Operation *op,
- const linalg::LinalgTilingOptions &options) {
+static LogicalResult
+verifySupportedTilingOptions(PatternRewriter &rewriter, Operation *op,
+ const linalg::LinalgTilingOptions &options) {
if (!options.interchangeVector.empty()) {
return rewriter.notifyMatchFailure(op,
"unsupported interchange during tiling");
@@ -93,12 +93,13 @@
/// distributed loops. It is a stack, and once an entry at the top of the
/// stack is used for distribution it is popped before processing the inner
/// loops.
-static FailureOr<TiledOp> tileInterfaceOpImpl(
- OpBuilder &builder, TiledOpInterface tilableOp, ValueRange outputs,
- MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
- ArrayRef<Range> loopBounds, unsigned loopDepth,
- SmallVectorImpl<OpFoldResult> &offsets,
- ArrayRef<linalg::ProcInfo> distributionInfo) {
+static FailureOr<TiledOp>
+tileInterfaceOpImpl(OpBuilder &builder, TiledOpInterface tilableOp,
+ ValueRange outputs, MutableArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<Range> loopBounds, unsigned loopDepth,
+ SmallVectorImpl<OpFoldResult> &offsets,
+ ArrayRef<linalg::ProcInfo> distributionInfo) {
Location loc = tilableOp.getLoc();
// If this is the innermost loop, then generated the tiled implementation of
// the op by invoking the TiledOpInterface methods.
@@ -166,7 +167,8 @@
tileInterfaceOpImpl(b, tilableOp, (isBufferTiling ? outputs : args),
tileSizes, iteratorTypes, loopBounds,
loopDepth + 1, offsets, distributionInfo);
- if (failed(innerReturnValue)) return;
+ if (failed(innerReturnValue))
+ return;
b.create<scf::YieldOp>(loc, innerReturnValue->results);
});
if (failed(innerReturnValue)) {
@@ -197,7 +199,8 @@
auto tileSizes = getAsOpFoldResult(tileSizesVals);
tileSizes.resize(iteratorTypes.size(), zeroAttr);
for (auto en : llvm::enumerate(iteratorTypes)) {
- if (en.value() == getParallelIteratorTypeName()) continue;
+ if (en.value() == getParallelIteratorTypeName())
+ continue;
if (!isUntiledLoop(tileSizes[en.index()])) {
return static_cast<LogicalResult>(tilableOp.emitOpError(
"unimplemented tiling of non-parallel loop iterator type"));
@@ -219,8 +222,10 @@
if (options.distribution) {
SmallVector<Range> distributedLoopRange;
for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
- if (isUntiledLoop(tileSizes[i])) continue;
- if (iteratorTypes[i] != getParallelIteratorTypeName()) continue;
+ if (isUntiledLoop(tileSizes[i]))
+ continue;
+ if (iteratorTypes[i] != getParallelIteratorTypeName())
+ continue;
distributedLoopRange.push_back(loopBounds[i]);
}
distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(),
@@ -243,7 +248,8 @@
}
FailureOr<TiledOp> res = tileInterfaceOp(rewriter, tilableOp, options);
- if (failed(res)) return res;
+ if (failed(res))
+ return res;
result = *res;
if (result.op) {
filter.replaceLinalgTransformationFilter(rewriter, result.op);
@@ -267,7 +273,7 @@
}
void runOnOperation() override;
};
-} // namespace
+} // namespace
template <typename OpTy>
static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToAsync.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToAsync.cpp
index 64514bb..61989dc 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToAsync.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToAsync.cpp
@@ -1,17 +1,14 @@
-//===- InParallelToAsync.cpp - Rewrite InParallel as Async ----------------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#include <cstdlib>
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
-#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Async/IR/Async.h"
@@ -24,6 +21,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToSequentialFor.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToSequentialFor.cpp
index 683629b..f5a4d40 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToSequentialFor.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToSequentialFor.cpp
@@ -1,15 +1,12 @@
-//===- InParallelToSequentialFor.cpp.cpp - Rewrite InParallel as ForOp ---===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===---------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
-#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -21,6 +18,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
@@ -32,7 +30,7 @@
op.yieldingOps(), [](ParallelInsertSliceOp op) { return op.dest(); }));
}
-} // namespace
+} // namespace
FailureOr<scf::ForOp> InParallelOpToScfForRewriter::returningMatchAndRewrite(
InParallelOp inParallelOp, PatternRewriter &rewriter) const {
@@ -86,7 +84,8 @@
for (Value toReplace : valuesToYield) {
for (OpOperand &u : toReplace.getUses()) {
Operation *op = u.getOwner();
- if (!forOp->isProperAncestor(op)) continue;
+ if (!forOp->isProperAncestor(op))
+ continue;
opsToReplace.push_back(op);
}
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
index 6a03048..2b8f8ec 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
@@ -1,10 +1,8 @@
-//===-- LinalgExtBufferization.cpp - Linalg Extension bufferization -------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h"
@@ -18,6 +16,15 @@
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
+using bufferization::AnalysisState;
+using bufferization::BufferizableOpInterface;
+using bufferization::BufferizationState;
+using bufferization::BufferRelation;
+using bufferization::getMemRefType;
+using bufferization::replaceOpWithBufferizedValues;
+using bufferization::replaceOpWithNewBufferizedOp;
+using tensor::ExtractSliceOp;
+
/// Return the destinations that an InParallelOp is inserting into. One per
/// ParallelInsertSliceOp.
static SmallVector<OpOperand *> getInsertionDest(InParallelOp inParallelOp) {
@@ -34,15 +41,6 @@
}
namespace mlir {
-
-using bufferization::BufferizableOpInterface;
-using bufferization::BufferizationState;
-using bufferization::BufferRelation;
-using bufferization::getMemRefType;
-using bufferization::replaceOpWithBufferizedValues;
-using bufferization::replaceOpWithNewBufferizedOp;
-using tensor::ExtractSliceOp;
-
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {
@@ -54,15 +52,16 @@
struct InParallelOpInterface
: public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
InParallelOp> {
- SmallVector<OpOperand *> getAliasingOpOperand(
- Operation *op, OpResult opResult, const BufferizationState &state) const {
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ const AnalysisState &state) const {
// Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
auto inParallelOp = cast<InParallelOp>(op);
return {getInsertionDest(inParallelOp)[opResult.getResultNumber()]};
}
bool isMemoryWrite(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// This op is a memory write. Stop lookup here to avoid finding false
// conflicts involving this op and one of the ops in the region. This is
// similar to how scf.if ops are analyzed.
@@ -72,12 +71,12 @@
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &b,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
OpBuilder::InsertionGuard g(b);
auto inParallelOp = cast<InParallelOp>(op);
Block *body = &inParallelOp.region().front();
@@ -89,7 +88,7 @@
SmallVector<Value> newResults;
for (OpResult opResult : inParallelOp->getOpResults()) {
SmallVector<OpOperand *> insertDestOperands =
- state.getAliasingOpOperand(opResult);
+ state.getAnalysisState().getAliasingOpOperand(opResult);
assert(insertDestOperands.size() == 1 &&
"expected exactly one aliasing OpOperand");
// Insert copies right before the PerformConcurrentlyOp terminator. They
@@ -153,7 +152,8 @@
b.eraseOp(insertOp);
return WalkResult::advance();
});
- if (walkResult.wasInterrupted()) return failure();
+ if (walkResult.wasInterrupted())
+ return failure();
// Replace the op.
replaceOpWithBufferizedValues(b, op, newResults);
@@ -167,18 +167,19 @@
: public BufferizableOpInterface::ExternalModel<
PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
LogicalResult bufferize(Operation *op, RewriterBase &b,
- const BufferizationState &state) const {
- llvm_unreachable("op does not have any tensor OpOperands / OpResults");
+ BufferizationState &state) const {
+ assert(false && "op does not have any tensor OpOperands / OpResults");
return failure();
}
};
/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
-static bool areEquivalentExtractSliceOps(const BufferizationState &state,
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
ExtractSliceOp st,
ParallelInsertSliceOp sti) {
- if (!st || !sti) return false;
+ if (!st || !sti)
+ return false;
if (st != sti &&
!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
return false;
@@ -189,12 +190,12 @@
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationState &state,
- Value value,
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
ParallelInsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) return true;
+ if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+ return true;
return false;
};
@@ -206,10 +207,10 @@
struct ParallelInsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<
ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
- SmallVector<OpResult> getAliasingOpResult(
- Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
- if (&opOperand != &op->getOpOperand(1) /*dest*/) return {};
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ if (&opOperand != &op->getOpOperand(1) /*dest*/)
+ return {};
// ParallelInsertSliceOp itself has no results. Tensors are returned via
// the parent op.
@@ -223,7 +224,8 @@
unsigned int opIdx = 0;
for (ParallelInsertSliceOp insertOp :
block->getOps<ParallelInsertSliceOp>()) {
- if (insertOp.getOperation() == op) break;
+ if (insertOp.getOperation() == op)
+ break;
++opIdx;
}
assert(opIdx < inParallelOp->getNumResults() &&
@@ -233,22 +235,22 @@
}
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &b,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
// Will be bufferized as part of InParallelOp.
return failure();
}
@@ -257,7 +259,7 @@
// the code.
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -332,10 +334,10 @@
return false;
}
};
-} // namespace LinalgExt
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
void mlir::iree_compiler::IREE::LinalgExt::
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToInParallel.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToInParallel.cpp
index 83ece71..96810d0 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToInParallel.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToInParallel.cpp
@@ -1,16 +1,12 @@
-//===- TileToInParallel.cpp.cpp - Rewrite TileOp as InParallel -----------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,14 +18,16 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
-FailureOr<iree_compiler::IREE::LinalgExt::InParallelOp> mlir::iree_compiler::
- IREE::LinalgExt::TileOpToInParallelRewriter::returningMatchAndRewrite(
- iree_compiler::IREE::LinalgExt::TileOp tileOp,
- PatternRewriter &rewriter) const {
+FailureOr<iree_compiler::IREE::LinalgExt::InParallelOp>
+mlir::iree_compiler::IREE::LinalgExt::TileOpToInParallelRewriter::
+ returningMatchAndRewrite(iree_compiler::IREE::LinalgExt::TileOp tileOp,
+ PatternRewriter &rewriter) const {
// TODO: verifier.
assert(tileOp.getNumResults() > 0 &&
tileOp.outs().size() == tileOp.getNumResults());
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToSequentialFor.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToSequentialFor.cpp
index 657eedd..e6451cc 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToSequentialFor.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToSequentialFor.cpp
@@ -1,15 +1,12 @@
-//===- LowerToSCF.cpp.cpp - Lower to SCF ---------------------------------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===---------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
-#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -21,14 +18,15 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
-FailureOr<scf::ForOp> mlir::iree_compiler::IREE::LinalgExt::
- TileOpToSCFRewriter::returningMatchAndRewrite(
- iree_compiler::IREE::LinalgExt::TileOp tileOp,
- PatternRewriter &rewriter) const {
+FailureOr<scf::ForOp>
+mlir::iree_compiler::IREE::LinalgExt::TileOpToSCFRewriter::
+ returningMatchAndRewrite(iree_compiler::IREE::LinalgExt::TileOp tileOp,
+ PatternRewriter &rewriter) const {
// TODO: verifier.
assert(tileOp.getNumResults() > 0 &&
tileOp.outs().size() == tileOp.getNumResults());
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
index 0e55970..bd75540 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -1,10 +1,8 @@
-//===- Tiling.cpp - Tiling using TilingInterface --------------------------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
@@ -142,16 +140,17 @@
struct OpTilingPattern : public OpInterfaceRewritePattern<TilingInterface> {
OpTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions opt,
linalg::LinalgTransformationFilter filt)
- : OpInterfaceRewritePattern<TilingInterface>(context),
- options(opt),
+ : OpInterfaceRewritePattern<TilingInterface>(context), options(opt),
filter(filt) {}
LogicalResult matchAndRewrite(TilingInterface op,
PatternRewriter &rewriter) const override {
- if (failed(filter.checkAndNotify(rewriter, op))) return failure();
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
/// Currently only handle single result operations.
- if (op->getNumResults() != 1) return failure();
+ if (op->getNumResults() != 1)
+ return failure();
Location loc = op->getLoc();
// Get rank and tile sizes.
@@ -178,7 +177,7 @@
return success();
}
- private:
+private:
linalg::LinalgTilingOptions options;
linalg::LinalgTransformationFilter filter;
};
@@ -190,14 +189,14 @@
SliceOpTiledOpSwapPattern(MLIRContext *context,
linalg::LinalgTilingOptions opt,
linalg::LinalgTransformationFilter filt)
- : OpRewritePattern<tensor::ExtractSliceOp>(context),
- options(opt),
+ : OpRewritePattern<tensor::ExtractSliceOp>(context), options(opt),
filter(filt) {}
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
auto sourceOp = sliceOp.source().getDefiningOp<TilingInterface>();
- if (!sourceOp || !filter.hasReplacementFilter(sourceOp)) return failure();
+ if (!sourceOp || !filter.hasReplacementFilter(sourceOp))
+ return failure();
SmallVector<Operation *> tiledOps = sourceOp.getTiledImplementation(
rewriter, sourceOp.getDestinationOperands(rewriter),
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
@@ -208,9 +207,9 @@
return success();
}
- private:
+private:
linalg::LinalgTilingOptions options;
linalg::LinalgTransformationFilter filter;
};
-} // namespace
+} // namespace
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingExternalModels.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingExternalModels.cpp
index 7174daa..174d4ff 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingExternalModels.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingExternalModels.cpp
@@ -1,19 +1,17 @@
-//===- TilingExternalModels.cpp - External models for TilingInterface -----===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-ext-tiling"
@@ -22,7 +20,8 @@
using namespace mlir::iree_compiler::IREE::LinalgExt;
static Value getAsValue(OpBuilder &b, Location loc, OpFoldResult ofr) {
- if (auto v = ofr.dyn_cast<Value>()) return v;
+ if (auto v = ofr.dyn_cast<Value>())
+ return v;
return b.create<arith::ConstantIndexOp>(
loc, ofr.get<Attribute>().cast<IntegerAttr>().getInt());
}
@@ -30,16 +29,15 @@
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> vals;
vals.reserve(ofrs.size());
- for (auto ofr : ofrs) vals.push_back(getAsValue(b, loc, ofr));
+ for (auto ofr : ofrs)
+ vals.push_back(getAsValue(b, loc, ofr));
return vals;
}
-static SmallVector<Value, 4> makeTiledInputShapes(OpBuilder &b, Location loc,
- LinalgOp linalgOp,
- ArrayRef<Value> valuesToTile,
- ArrayRef<Value> ivsRef,
- ArrayRef<Value> tileSizesRef,
- ArrayRef<Value> sizeBounds) {
+static SmallVector<Value, 4>
+makeTiledInputShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
+ ArrayRef<Value> valuesToTile, ArrayRef<Value> ivsRef,
+ ArrayRef<Value> tileSizesRef, ArrayRef<Value> sizeBounds) {
assert(static_cast<int64_t>(valuesToTile.size()) == linalgOp.getNumInputs() &&
"expected one value to tile for every operand");
@@ -96,10 +94,11 @@
return linalgOp.createLoopRanges(b, op->getLoc());
}
- SmallVector<Operation *> getTiledImplementation(
- Operation *op, OpBuilder &b, ValueRange tiledDest,
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
- bool tileDestOperands) const {
+ SmallVector<Operation *>
+ getTiledImplementation(Operation *op, OpBuilder &b, ValueRange tiledDest,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ bool tileDestOperands) const {
LinalgOp linalgOp = cast<LinalgOp>(op);
if (op->getNumResults() != 1) {
// TODO: Need a failure message here, but `notifyMatchFailure` is only a
@@ -109,11 +108,13 @@
Location loc = op->getLoc();
AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
- if (!shapeSizesToLoopsMap) return {};
+ if (!shapeSizesToLoopsMap)
+ return {};
OpOperand *outOperand = linalgOp.getOutputOperand(0);
AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
- if (!indexingMap.isProjectedPermutation()) return {};
+ if (!indexingMap.isProjectedPermutation())
+ return {};
SmallVector<Value> offsetsVals = getAsValues(b, loc, offsets);
SmallVector<Value> sizeVals = getAsValues(b, loc, sizes);
@@ -153,7 +154,7 @@
return {linalgOp.clone(b, loc, tiledDest.getTypes(), tiledOperands)};
}
};
-} // namespace
+} // namespace
template <typename OpType>
void registerOne(DialectRegistry ®istry) {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp
index ba8cc4d..ba14818 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp
@@ -1,10 +1,8 @@
-//===- TilingToTileOp.cpp - Tiling using to TileOp TilingInterface --------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
@@ -63,9 +61,10 @@
return TilingResult{tileOp, tiledOp};
}
-FailureOr<Operation *> mlir::iree_compiler::IREE::LinalgExt::
- LinalgExtTilingPattern::returningMatchAndRewrite(
- TilingInterface op, PatternRewriter &rewriter) const {
+FailureOr<Operation *>
+mlir::iree_compiler::IREE::LinalgExt::LinalgExtTilingPattern::
+ returningMatchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const {
/// Currently only handle single result operations.
if (op->getNumResults() != 1)
return rewriter.notifyMatchFailure(op, "Not a single result");
@@ -78,7 +77,8 @@
int64_t dim = -1;
for (auto en : llvm::enumerate(tileSizes)) {
Optional<int64_t> maybeTileSize = getConstantIntValue(en.value());
- if (maybeTileSize && *maybeTileSize == 0) continue;
+ if (maybeTileSize && *maybeTileSize == 0)
+ continue;
if (maybeTileSize && *maybeTileSize < 0)
return rewriter.notifyMatchFailure(op, "Negative tile size");
if (dim >= 0)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp
index 9b250b8..da58824 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp
@@ -1,10 +1,8 @@
-//===- Utils.cpp - LinalgExt transform utils ------------------------------===//
+// Copyright 2021 The IREE Authors
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
@@ -36,7 +34,8 @@
offsets = SmallVector<Value>(leadingOffsets.begin(), leadingOffsets.end());
sizes = SmallVector<Value>(leadingSizes.begin(), leadingSizes.end());
strides = SmallVector<Value>(leadingStrides.begin(), leadingStrides.end());
- if (leadingRank >= tensorRank) return;
+ if (leadingRank >= tensorRank)
+ return;
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
for (int64_t i = leadingRank, e = tensorRank; i < e; ++i) {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt
new file mode 100644
index 0000000..9f57627
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt
new file mode 100644
index 0000000..7b20d53
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/CMakeLists.txt
@@ -0,0 +1,40 @@
+add_mlir_library(IREELinalgTransformDialect
+ LinalgTransformOps.cpp
+ PDL.cpp
+ ScopedTransform.cpp
+ TransformOpInterface.cpp
+ TrackingListener.cpp
+ TrackingRewriteDriver.cpp
+
+ DEPENDS
+ mlir-headers
+
+ LINK_LIBS PUBLIC
+ IREEDialectsTransforms
+ MLIRIR
+
+ # Dialects
+ IREELinalgExtDialect
+ IREELinalgExtTransforms
+
+ MLIRAsync
+ MLIRControlFlowInterfaces
+ MLIRLinalg
+ MLIRPDL
+ MLIRRewrite
+
+ # Transforms
+ MLIRAsyncTransforms
+ MLIRLinalgTransforms
+ MLIRAffineToStandard
+ MLIRTransforms
+ MLIRReconcileUnrealizedCasts
+
+ # Conversions
+ MLIRAsyncToLLVM
+ MLIRMemRefToLLVM
+ MLIRMathToLLVM
+ MLIRVectorToLLVM
+ MLIRLinalgToLLVM
+ MLIRSCFToControlFlow
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/FunctionHelpers.h b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/FunctionHelpers.h
new file mode 100644
index 0000000..bbc8ae3
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/FunctionHelpers.h
@@ -0,0 +1,49 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+#include <utility>
+
+namespace mlir {
+namespace linalg {
+
+// Pure C++ functional patterns requires some type plumbing.
+namespace detail {
+template <typename OpT>
+struct ConvertOrForward {
+ static OpT to(LinalgOp op) { return cast<OpT>(op.getOperation()); }
+ static LinalgOp from(OpT op) { return cast<LinalgOp>(op.getOperation()); }
+};
+template <>
+struct ConvertOrForward<LinalgOp> {
+ static LinalgOp to(LinalgOp op) { return op; }
+ static LinalgOp from(LinalgOp op) { return op; }
+};
+} // namespace detail
+
+/// Wrap a call to a Linalg pattern where the input is a `LinalgOp` and the
+/// output is a `LinalgOp`.
+template <typename FunctionalLinalgPattern, typename... Args>
+auto callLinalgPattern(Args &&... args) {
+ FunctionalLinalgPattern pattern(std::forward<Args>(args)...);
+ using Traits = llvm::function_traits<decltype(
+ &FunctionalLinalgPattern::returningMatchAndRewrite)>;
+ using OpT = typename Traits::template arg_t<0>;
+ return
+ [pattern = std::move(pattern)](
+ LinalgOp linalgOp, PatternRewriter &rewriter) -> FailureOr<LinalgOp> {
+ OpT op = detail::ConvertOrForward<OpT>::to(linalgOp);
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return detail::ConvertOrForward<decltype(*result)>::from(*result);
+ };
+}
+
+} // namespace linalg
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
new file mode 100644
index 0000000..e28dcd3
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/LinalgTransformOps.cpp
@@ -0,0 +1,889 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+
+#include <algorithm>
+
+#include "FunctionHelpers.h"
+#include "PDL.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
+#include "iree-dialects/Dialect/LinalgTransform/TrackingListener.h"
+#include "iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h"
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
+#include "iree-dialects/Transforms/Listener.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-transform-dialect"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::iree_compiler::IREE;
+
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformDialect.cpp.inc"
+
+void transform::LinalgTransformDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// Functional Rewrite Helpers
+//===----------------------------------------------------------------------===//
+
+using FunctionalLinalgTransform =
+ std::function<FailureOr<LinalgOp>(LinalgOp, PatternRewriter &)>;
+
+/// Extracts a vector of int64_t from an array attribute. Asserts if the
+/// attribute contains values other than integers.
+static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
+ SmallVector<int64_t> result;
+ result.reserve(attr.size());
+ for (APInt value : attr.getAsValueRange<IntegerAttr>())
+ result.push_back(value.getSExtValue());
+ return result;
+}
+
+/// Extracts a vector of unsigned from an array attribute. Asserts if the
+/// attribute contains values other than intergers. May truncate.
+static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
+ SmallVector<unsigned> result;
+ result.reserve(attr.size());
+ for (APInt value : attr.getAsValueRange<IntegerAttr>())
+ result.push_back(value.getZExtValue());
+ return result;
+}
+
+//===---------------------------------------------------------------------===//
+// ScopeOp
+//===---------------------------------------------------------------------===//
+
+void transform::ScopeOp::getSuccessorRegions(
+ Optional<unsigned> index, ArrayRef<Attribute> operands,
+ SmallVectorImpl<RegionSuccessor> ®ions) {
+ if (index)
+ regions.emplace_back(getResults());
+ else
+ regions.emplace_back(&body());
+}
+
+//===---------------------------------------------------------------------===//
+// SequenceOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult transform::SequenceOp::verify() {
+ WalkResult result = this->walk([](Operation *child) {
+ for (OpResult result : child->getResults()) {
+ if (llvm::hasNItemsOrLess(result.getUses(), 1))
+ continue;
+ InFlightDiagnostic diag = child->emitError()
+ << "result #" << result.getResultNumber()
+ << " has more than one use";
+ for (OpOperand &use : result.getUses()) {
+ diag.attachNote(use.getOwner()->getLoc())
+ << "used here as operand #" << use.getOperandNumber();
+ }
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return failure(result.wasInterrupted());
+}
+
+//===---------------------------------------------------------------------===//
+// MatchOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult transform::MatchOp::apply(TransformResults &results,
+ TransformState &state) {
+ Operation *topLevelOp = state.getTopLevel();
+ FailureOr<SmallVector<Operation *>> ops = findMatchingOps(*this, topLevelOp);
+ if (failed(ops))
+ return failure();
+ LLVM_DEBUG(DBGS() << "matched " << ops->size() << " ops\n");
+ results.set(getResult().cast<OpResult>(), *ops);
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// TileOp
+//===---------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::TileOp::applyToOne(LinalgOp target) {
+ LinalgTilingOptions tilingOptions;
+ SmallVector<int64_t> tileSizes = extractI64Array(sizes());
+ // "scalarize_dyn_dims" actually sets the same lambda as the tile sizes and
+ // asserts that it is not already set.
+ if (!tileSizes.empty() || !scalarize_dyn_dims())
+ tilingOptions.setTileSizes(tileSizes);
+ tilingOptions.setInterchange(extractUIntArray(interchange()));
+ tilingOptions.setPeeledLoops(extractI64Array(peel()));
+ if (scalarize_dyn_dims())
+ tilingOptions.scalarizeDynamicDims();
+
+ LinalgTilingPattern pattern(getContext(), tilingOptions);
+ auto functionalTile = [&](LinalgOp op,
+ PatternRewriter &rewriter) -> FailureOr<LinalgOp> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result->op;
+ };
+ return functional::applyAt(target, functionalTile);
+}
+
+LogicalResult transform::TileOp::verify() {
+ if (!sizes().empty() && scalarize_dyn_dims()) {
+ return emitOpError() << sizesAttrName() << " and "
+ << scalarize_dyn_dimsAttrName()
+ << " attributes are mutually exclusive";
+ }
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// FuseOp
+//===---------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::FuseOp::applyToOne(LinalgOp target) {
+ LinalgTilingAndFusionOptions fusionOptions;
+ fusionOptions.tileSizes = extractI64Array(tile_sizes());
+ fusionOptions.tileInterchange = extractI64Array(tile_interchange());
+
+ LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
+ auto functionalFuse = [&](LinalgOp op,
+ PatternRewriter &rewriter) -> FailureOr<LinalgOp> {
+ auto tileLoopNest = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(tileLoopNest))
+ return failure();
+ return tileLoopNest->getRootOp();
+ };
+ return functional::applyAt(target, functionalFuse);
+}
+
+LogicalResult transform::FuseOp::verify() {
+ SmallVector<int64_t> permutation = extractI64Array(tile_interchange());
+ auto sequence = llvm::seq<int64_t>(0, permutation.size());
+ if (!std::is_permutation(sequence.begin(), sequence.end(),
+ permutation.begin(), permutation.end())) {
+ return emitOpError() << "expects interchange to be a permutation, found "
+ << tile_interchange();
+ }
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// GeneralizeOp
+//===---------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
+ // Exit early if no transformation is needed.
+ if (isa<GenericOp>(target))
+ return target;
+ return functional::applyAt(
+ target, callLinalgPattern<LinalgGeneralizationPattern>(getContext()));
+}
+
+//===---------------------------------------------------------------------===//
+// InterchangeOp
+//===---------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
+ SmallVector<unsigned> interchangeVector =
+ extractUIntArray(iterator_interchange());
+ // Exit early if no transformation is needed.
+ if (interchangeVector.empty())
+ return target;
+ return functional::applyAt(target,
+ callLinalgPattern<GenericOpInterchangePattern>(
+ getContext(), interchangeVector));
+}
+
+LogicalResult transform::InterchangeOp::verify() {
+ SmallVector<unsigned> permutation = extractUIntArray(iterator_interchange());
+ auto sequence = llvm::seq<unsigned>(0, permutation.size());
+ if (!std::is_permutation(sequence.begin(), sequence.end(),
+ permutation.begin(), permutation.end())) {
+ return emitOpError()
+ << "expects iterator_interchange to be a permutation, found "
+ << iterator_interchange();
+ }
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// PadOp
+//===---------------------------------------------------------------------===//
+
+/// Returns the neutral value for a Linalg operation that produces the given
+/// operand, construct using the provided builder. Currently assumes the
+/// reduction in the Linalg operation is an addition and, therefore, the neutral
+/// value is zero.
+static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
+ auto t = getElementTypeOrSelf(op.get().getType());
+ return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
+ b.getZeroAttr(t));
+}
+
+FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
+ // Copy the stack allocated options since the lambdas have a longer lifetime.
+ SmallVector<int64_t> packPaddings = extractI64Array(this->pack_paddings());
+ auto packFunc = [=](OpOperand &opOperand) {
+ return opOperand.getOperandNumber() < packPaddings.size()
+ ? packPaddings[opOperand.getOperandNumber()] != 0
+ : false;
+ };
+ SmallVector<int64_t> hoistPaddings = extractI64Array(this->hoist_paddings());
+ auto hoistingFunc = [=](OpOperand &opOperand) {
+ return opOperand.getOperandNumber() < hoistPaddings.size()
+ ? hoistPaddings[opOperand.getOperandNumber()]
+ : 0;
+ };
+ ArrayAttr transposePaddings = this->transpose_paddings().cast<ArrayAttr>();
+ auto transposeFunc = [=](OpOperand &opOperand) {
+ if (opOperand.getOperandNumber() >= transposePaddings.size())
+ return SmallVector<int64_t>();
+ return extractI64Array(
+ transposePaddings[opOperand.getOperandNumber()].cast<ArrayAttr>());
+ };
+ LinalgPaddingOptions paddingOptions;
+ paddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
+ paddingOptions.setPaddingNoFoldComputationFunction(packFunc);
+ paddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
+ paddingOptions.setPaddingTransposeComputationFunction(transposeFunc);
+
+ return functional::applyAt(target, callLinalgPattern<LinalgPaddingPattern>(
+ getContext(), paddingOptions));
+}
+
+LogicalResult transform::PadOp::verify() {
+ SmallVector<int64_t> packPaddings = extractI64Array(pack_paddings());
+ if (any_of(packPaddings, [](int64_t packPadding) {
+ return packPadding != 0 && packPadding != 1;
+ })) {
+ return emitOpError()
+ << "expects pack_paddings to contain booleans (0/1), found "
+ << pack_paddings();
+ }
+ SmallVector<int64_t> hoistPaddings = extractI64Array(hoist_paddings());
+ if (any_of(hoistPaddings,
+ [](int64_t hoistPadding) { return hoistPadding < 0; })) {
+ return emitOpError()
+ << "expects hoist_paddings to contain positive integers, found "
+ << hoist_paddings();
+ }
+ ArrayAttr transposes = transpose_paddings();
+ for (Attribute attr : transposes) {
+ SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
+ auto sequence = llvm::seq<int64_t>(0, transpose.size());
+ if (!std::is_permutation(sequence.begin(), sequence.end(),
+ transpose.begin(), transpose.end())) {
+ return emitOpError()
+ << "expects transpose_paddings to be a permutation, found "
+ << attr;
+ }
+ }
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// DecomposeOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult
+transform::DecomposeOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ RewritePatternSet patterns(getContext());
+ // TODO: make this targetable.
+ populateDecomposeConvolutionPatterns(patterns, LinalgTransformationFilter());
+ if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(),
+ std::move(patterns))))
+ return failure();
+
+ // TODO: make this chainable, it isn't in the original codegenstrategy.
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// VectorizeOp
+//===---------------------------------------------------------------------===//
+
+static void configureVectorizationPatterns(transform::VectorizeOp vectorizeOp,
+ RewritePatternSet &patterns) {
+ MLIRContext *ctx = vectorizeOp->getContext();
+ vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+ vector::populateVectorReductionToContractPatterns(patterns);
+ patterns.add<linalg::LinalgCopyVTRForwardingPattern,
+ linalg::LinalgCopyVTWForwardingPattern>(ctx,
+ /*benefit=*/2);
+ vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
+ vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+ if (vectorizeOp.vectorize_padding())
+ linalg::populatePadOpVectorizationPatterns(patterns);
+}
+
+/// Applies the transformation specified by the given vectorize operation to the
+/// given target operation AND some related operations.Populates `results` with
+/// transformation operations for further transformations if the pattern applied
+/// successfully (currently, the main "contraction" op after vectorization).
+static FailureOr<LinalgOp>
+executeTargetedVectorizeOp(LinalgOp target,
+ linalg::transform::VectorizeOp vectorizeOp) {
+ // TODO: this is copy-pasta from LinalgStrategyVectorizePass, it shouldn't be.
+ MLIRContext *ctx = target->getContext();
+ RewritePatternSet patterns(ctx);
+ configureVectorizationPatterns(vectorizeOp, patterns);
+ LinalgVectorizationPattern pattern(vectorizeOp.getContext());
+ auto functionalVectorize = [&](LinalgOp op, PatternRewriter &rewriter) {
+ return pattern.matchAndRewrite(op, rewriter);
+ };
+
+ /// Apply the transformations in a scope.
+ return transform::scoped(
+ target,
+ [&](transform::ScopeOp scope, Operation *op) -> FailureOr<LinalgOp> {
+ if (failed(functional::applyAt(op, functionalVectorize)) ||
+ failed(applyPatternsAndFoldGreedily(scope, std::move(patterns))))
+ return failure();
+ // FIXME: Vectorization doesn't return anything.
+ return LinalgOp();
+ });
+
+ // TODO: vectorization may fail because the op is not vectorizable, unclear
+ // what to do here. We should probably report it somehow, but we may also
+ // want to go on and keep the original for continuation. Should we have
+ // some notion of transformation optionality vs. mandatory (like lowering)?
+ // How to find ops that were not replaced?
+}
+
+LogicalResult
+transform::VectorizeOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ if (target()) {
+ SmallVector<Operation *> resultVector;
+ LogicalResult res = applyTransformToEach(
+ state.getPayloadOps(target()), resultVector, [&](LinalgOp target) {
+ return executeTargetedVectorizeOp(target, *this);
+ });
+
+ if (failed(res))
+ return failure();
+
+ results.set(getResult(0).cast<OpResult>(), resultVector);
+ return success();
+ }
+
+ MLIRContext *ctx = getContext();
+ RewritePatternSet patterns(ctx);
+ patterns.add<LinalgVectorizationPattern>(ctx);
+ configureVectorizationPatterns(*this, patterns);
+ auto &listener = state.getExtension<TrackingListener>();
+ LogicalResult applicationResult = applyPatternsTrackAndFoldGreedily(
+ state.getTopLevel(), listener, std::move(patterns));
+ LogicalResult listenerResult = listener.checkErrorState();
+ return failure(failed(applicationResult) || failed(listenerResult));
+}
+
+ParseResult transform::VectorizeOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ auto operationType = pdl::OperationType::get(parser.getContext());
+ OpAsmParser::OperandType target;
+ OptionalParseResult parseResult = parser.parseOptionalOperand(target);
+ if (parseResult.hasValue()) {
+ if (parseResult.getValue().failed() ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.resolveOperand(target, operationType, result.operands) ||
+ parser.addTypeToList(operationType, result.types)) {
+ return failure();
+ }
+ } else {
+ if (parser.parseOptionalAttrDict(result.attributes)) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+void transform::VectorizeOp::print(OpAsmPrinter &printer) {
+ if (target())
+ printer << " " << target() << " ";
+ printer.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+//===---------------------------------------------------------------------===//
+// LowerVectorsOp
+//===---------------------------------------------------------------------===//
+
+/// Returns true of the numbered vector lowering stage is included into the list
+/// of stages specified on the given lowerVectors operation.
+static bool stageIncluded(int stage, transform::LowerVectorsOp lowerVectorsOp) {
+ for (auto s : lowerVectorsOp.stages().getAsValueRange<IntegerAttr>()) {
+ if (s.getSExtValue() == stage)
+ return true;
+ }
+ return false;
+}
+
+// Applies the transformation specified by the given lower vectors operation
+/// to the given function.
+LogicalResult
+transform::LowerVectorsOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ MLIRContext *ctx = getContext();
+ RewritePatternSet patterns(ctx);
+
+ vector::VectorTransposeLowering vectorTransposeLowering =
+ llvm::StringSwitch<vector::VectorTransposeLowering>(transpose_lowering())
+ .Case("eltwise", vector::VectorTransposeLowering::EltWise)
+ .Case("flat_transpose", vector::VectorTransposeLowering::Flat)
+ .Case("shuffle", vector::VectorTransposeLowering::Shuffle)
+ .Default(vector::VectorTransposeLowering::EltWise);
+ vector::VectorMultiReductionLowering vectorMultiReductionLowering =
+ llvm::StringSwitch<vector::VectorMultiReductionLowering>(
+ multireduction_lowering())
+ .Case("innerreduction",
+ vector::VectorMultiReductionLowering::InnerReduction)
+ .Default(vector::VectorMultiReductionLowering::InnerParallel);
+ vector::VectorContractLowering vectorContractLowering =
+ llvm::StringSwitch<vector::VectorContractLowering>(contraction_lowering())
+ .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
+ .Case("dot", vector::VectorContractLowering::Dot)
+ .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
+ .Default(vector::VectorContractLowering::OuterProduct);
+ // TODO: fix the annoying name mismatch (vector-transfers vs vector-transfer).
+ vector::VectorTransferSplit vectorTransferSplit =
+ llvm::StringSwitch<vector::VectorTransferSplit>(split_transfers())
+ .Case("none", vector::VectorTransferSplit::None)
+ .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
+ .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer)
+ .Default(vector::VectorTransferSplit::None);
+
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering)
+ .setVectorMultiReductionLowering(vectorMultiReductionLowering)
+ .setVectorTransposeLowering(vectorTransposeLowering)
+ .setVectorTransferSplit(vectorTransferSplit);
+
+ VectorTransferToSCFOptions vectorTransferToSCFOptions =
+ VectorTransferToSCFOptions()
+ .enableFullUnroll(unroll_vector_transfers())
+ .enableLowerPermutationMaps();
+
+ int maxTransferRank = 1;
+
+ auto avx2LoweringOptions =
+ x86vector::avx2::LoweringOptions().setTransposeOptions(
+ x86vector::avx2::TransposeLoweringOptions()
+ .lower4x8xf32(transpose_avx2_lowering())
+ .lower8x8xf32(transpose_avx2_lowering()));
+
+ // TODO: this is copy-pasta from LinalgStrategyLowerVectorsPass, shouldn't be.
+ vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+ if (stageIncluded(1, *this)) {
+ patterns.add<mlir::vector::ContractionOpToOuterProductOpLowering,
+ mlir::vector::ContractionOpToMatmulOpLowering,
+ mlir::vector::ContractionOpLowering>(vectorTransformOptions,
+ ctx);
+ vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+ }
+ if (stageIncluded(2, *this)) {
+ vector::populateVectorMultiReductionLoweringPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+ }
+ if (stageIncluded(3, *this)) {
+ patterns.add<vector::VectorTransferFullPartialRewriter>(
+ ctx, vectorTransformOptions);
+ }
+ if (stageIncluded(4, *this)) {
+ vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank);
+ }
+ if (stageIncluded(5, *this)) {
+ populateVectorToSCFConversionPatterns(
+ patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank));
+ }
+ if (stageIncluded(6, *this)) {
+ vector::populateVectorShapeCastLoweringPatterns(patterns);
+ }
+ if (stageIncluded(7, (*this))) {
+ vector::populateVectorTransposeLoweringPatterns(patterns,
+ vectorTransformOptions);
+ if (transpose_avx2_lowering())
+ x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+ patterns, avx2LoweringOptions, /*benefit=*/10);
+ }
+
+ // TODO: these transformations are currently not targeted at concrete ops.
+ // LinalgTransformationFilter filter = makeTransformationFilter(target);
+ if (failed(applyPatternsAndFoldGreedily(state.getTopLevel(),
+ std::move(patterns))))
+ return failure();
+
+ // TODO: make composable...
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// BufferizeOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult transform::BufferizeOp::apply(transform::TransformResults &result,
+ transform::TransformState &state) {
+ PassManager pm(getContext());
+
+ bufferization::OneShotBufferizationOptions options;
+ options.memCpyFn = [](OpBuilder &builder, Location loc, Value from,
+ Value to) {
+ return success(linalg::makeMemRefCopyOp(builder, loc, from, to));
+ };
+ pm.addPass(createLinalgComprehensiveModuleBufferizePass(options));
+ if (failed(pm.run(state.getTopLevel())))
+ return failure();
+
+ // Perform buffer-level hoistings.
+ state.getTopLevel()->walk(
+ [&](FuncOp funcOp) { hoistRedundantVectorTransfers(funcOp); });
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// LowerToLLVMOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult
+transform::LowerToLLVMOp::apply(transform::TransformResults &result,
+ transform::TransformState &state) {
+ // TODO: it is feasible to scope lowering at arbitrary level and introduce
+ // unrealized casts, but there needs to be the final module-wise cleanup in
+ // the end. Keep module-level for now.
+ PassManager pm(getContext());
+
+ pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
+ pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
+ if (enable_async()) {
+ pm.addPass(createAsyncToAsyncRuntimePass());
+ pm.addPass(createAsyncRuntimeRefCountingPass());
+ pm.addPass(createAsyncRuntimeRefCountingOptPass());
+ }
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createLowerAffinePass());
+ pm.addPass(createConvertSCFToCFPass());
+ pm.addPass(createConvertLinalgToLLVMPass());
+ pm.addPass(createConvertVectorToLLVMPass(
+ // clang-format off
+ LowerVectorToLLVMOptions()
+ .enableReassociateFPReductions(reassociate_fp_reductions())
+ .enableIndexOptimizations(enable_index_optimizations())
+ .enableArmNeon(enable_arm_neon())
+ .enableArmSVE(enable_arm_sve())
+ .enableAMX(enable_amx())
+ .enableX86Vector(enable_x86vector())));
+ // clang-format on
+ pm.addNestedPass<FuncOp>(createConvertMathToLLVMPass());
+ pm.addPass(createMemRefToLLVMPass());
+ if (enable_async())
+ pm.addPass(createConvertAsyncToLLVMPass());
+ pm.addPass(createConvertFuncToLLVMPass());
+ pm.addPass(createReconcileUnrealizedCastsPass());
+ if (failed(pm.run(state.getTopLevel())))
+ return failure();
+
+ // Make all arguments noalias for now.
+ // FIXME: this is a terrible hack!
+ state.getTopLevel()->walk([](LLVM::LLVMFuncOp funcOp) {
+ for (int64_t i = 0; i < funcOp.getNumArguments(); ++i) {
+ if (!funcOp.getType().getParamType(i).isa<LLVM::LLVMPointerType>())
+ continue;
+ funcOp.setArgAttr(i, "llvm.noalias", UnitAttr::get(funcOp.getContext()));
+ }
+ });
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// GetParentLoopOp
+//===---------------------------------------------------------------------===//
+
+FailureOr<scf::ForOp>
+transform::GetParentLoopOp::applyToOne(Operation *source) {
+ int64_t nLoops = num_loops();
+ for (int64_t i = 0; i < nLoops; ++i) {
+ source = source->getParentOfType<scf::ForOp>();
+ if (!source) {
+ emitError() << "the transformed op is enclosed by " << i << " loops, but "
+ << nLoops << " expected";
+ return failure();
+ }
+ }
+ return cast<scf::ForOp>(source);
+}
+
+//===---------------------------------------------------------------------===//
+// UnrollLoopOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult transform::UnrollLoopOp::applyToOne(scf::ForOp loop) {
+ return loopUnrollByFactor(loop, factor());
+}
+
+//===---------------------------------------------------------------------===//
+// PipelineLoopOp
+//===---------------------------------------------------------------------===//
+
+static void
+loopScheduling(scf::ForOp forOp,
+ std::vector<std::pair<Operation *, unsigned>> &schedule,
+ unsigned iterationInterval, unsigned readLatency) {
+ auto getLatency = [&](Operation *op) {
+ if (isa<vector::TransferReadOp>(op))
+ return readLatency;
+ return unsigned(1);
+ };
+
+ DenseMap<Operation *, unsigned> opCycles;
+ std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
+ for (Operation &op : forOp.getBody()->getOperations()) {
+ if (isa<scf::YieldOp>(op))
+ continue;
+ unsigned earlyCycle = 0;
+ for (Value operand : op.getOperands()) {
+ Operation *def = operand.getDefiningOp();
+ if (!def)
+ continue;
+ earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
+ }
+ opCycles[&op] = earlyCycle;
+ wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
+ }
+ for (auto it : wrappedSchedule) {
+ for (Operation *op : it.second) {
+ unsigned cycle = opCycles[op];
+ schedule.push_back(std::make_pair(op, cycle / iterationInterval));
+ }
+ }
+}
+
+FailureOr<scf::ForOp> transform::PipelineLoopOp::applyToOne(scf::ForOp loop) {
+ // TODO: make the pipelining pattern return the transformed loop.
+ if (!getOperation()->getUses().empty()) {
+ InFlightDiagnostic diag = emitError()
+ << "NYI: cannot target the result of pipelining";
+ diag.attachNote(getOperation()->use_begin()->getOwner()->getLoc())
+ << "use here";
+ return failure();
+ }
+
+ scf::PipeliningOption schedule;
+ schedule.getScheduleFn =
+ [this](scf::ForOp forOp,
+ std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
+ loopScheduling(forOp, schedule, iteration_interval(), read_latency());
+ };
+
+ RewritePatternSet patterns(loop->getContext());
+ scf::populateSCFLoopPipeliningPatterns(patterns, schedule);
+ assert(patterns.getNativePatterns().size() == 1 &&
+ "expected one pipelining pattern");
+ auto functionalPattern = [&patterns](scf::ForOp forOp,
+ PatternRewriter &rewriter) {
+ RewritePattern *pattern = patterns.getNativePatterns().front().get();
+ return pattern->matchAndRewrite(forOp, rewriter);
+ };
+ if (failed(functional::applyAt(loop, std::move(functionalPattern))))
+ return failure();
+
+ return scf::ForOp();
+}
+
+//===---------------------------------------------------------------------===//
+// OutlineLoopOp
+//===---------------------------------------------------------------------===//
+
+static scf::ExecuteRegionOp outlineInExecuteRegion(RewriterBase &b,
+ Operation *op) {
+ if (op->getNumRegions() != 1)
+ return nullptr;
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ scf::ExecuteRegionOp executeRegionOp =
+ b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
+ {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
+ Operation *clonedOp = b.cloneWithoutRegions(*op);
+ Region &clonedRegion = clonedOp->getRegions().front();
+ assert(clonedRegion.empty() && "expected empty region");
+ b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
+ clonedRegion.end());
+ b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
+ }
+ b.replaceOp(op, executeRegionOp.getResults());
+ return executeRegionOp;
+}
+
+static FailureOr<FuncOp> outlineLoop(scf::ForOp loop, StringRef funcName,
+ transform::TransformState &state) {
+ PatternRewriterListener rewriter(loop->getContext());
+ auto &listener = state.getExtension<TrackingListener>();
+ rewriter.addListener(&listener);
+ Location loc = loop.getLoc();
+ scf::ExecuteRegionOp exec = outlineInExecuteRegion(rewriter, loop);
+ assert(exec && "failed to produce execute_region");
+ FailureOr<FuncOp> outlined =
+ outlineSingleBlockRegion(rewriter, loc, exec.getRegion(), funcName);
+ if (failed(listener.checkErrorState()))
+ return failure();
+ return outlined;
+}
+
+LogicalResult
+transform::OutlineLoopOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Operation *> resultVector;
+ auto res =
+ applyTransformToEach(state.getPayloadOps(target()), resultVector,
+ [&](scf::ForOp loop) -> FailureOr<FuncOp> {
+ return outlineLoop(loop, func_name(), state);
+ });
+ if (failed(res))
+ return failure();
+ results.set(getResult().cast<OpResult>(), resultVector);
+ return success();
+}
+
+//===---------------------------------------------------------------------===//
+// PrintOp
+//===---------------------------------------------------------------------===//
+
+LogicalResult transform::PrintOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ llvm::outs() << "[[[ IR printer: " << name() << " ]]]\n";
+ state.getTopLevel()->dump();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LinalgExt specific transforms
+//===----------------------------------------------------------------------===//
+
+FailureOr<Operation *>
+transform::TileToLinalgExtTileOp::applyToOne(TilingInterface target) {
+ LinalgTilingOptions tilingOptions;
+ SmallVector<int64_t> tileSizes = extractI64Array(sizes());
+ if (!tileSizes.empty())
+ tilingOptions.setTileSizes(tileSizes);
+
+ LinalgExt::LinalgExtTilingPattern pattern(this->getContext(), tilingOptions);
+ auto functionalTile =
+ [&](TilingInterface op,
+ PatternRewriter &rewriter) -> FailureOr<Operation *> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+
+ auto tileSeq = functional::SequenceBuilder().begin(std::move(functionalTile));
+ return functional::applyAt(target, tileSeq);
+}
+
+FailureOr<scf::ForOp> transform::RewriteLinalgExtTileToScfForOp::applyToOne(
+ LinalgExt::TileOp target) {
+ LinalgExt::TileOpToSCFRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::TileOp op,
+ PatternRewriter &rewriter) -> FailureOr<scf::ForOp> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+FailureOr<LinalgExt::InParallelOp>
+transform::RewriteLinalgExtTileToInParallelOp::applyToOne(
+ LinalgExt::TileOp target) {
+ LinalgExt::TileOpToInParallelRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::TileOp op,
+ PatternRewriter &rewriter) -> FailureOr<LinalgExt::InParallelOp> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+FailureOr<Operation *>
+transform::RewriteLinalgExtInParallelToAsyncOp::applyToOne(
+ LinalgExt::InParallelOp target) {
+ LinalgExt::InParallelOpToAsyncRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::InParallelOp op,
+ PatternRewriter &rewriter) -> FailureOr<Operation *> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+FailureOr<scf::ForOp>
+transform::RewriteLinalgExtInParallelToScfForOp::applyToOne(
+ LinalgExt::InParallelOp target) {
+ LinalgExt::InParallelOpToScfForRewriter pattern(this->getContext());
+ auto functionalRewrite =
+ [&](LinalgExt::InParallelOp op,
+ PatternRewriter &rewriter) -> FailureOr<scf::ForOp> {
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return result;
+ };
+ return functional::applyAt(target, functionalRewrite);
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp
new file mode 100644
index 0000000..e8d9477
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.cpp
@@ -0,0 +1,331 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "PDL.h"
+
+#include "iree-dialects/Transforms/Functional.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
+
+namespace mlir {
+namespace linalg {
+
+/// Return ops that match any of the patterns.
+static SmallVector<Operation *>
+getMatchingOps(Operation *parent, const FrozenRewritePatternSet &patterns) {
+ PatternApplicator applicator(patterns);
+ applicator.applyDefaultCostModel();
+
+ // TODO: The C++ functional API needs better interoperability with PDL.
+ return functional::applyForEachIn(
+ parent,
+ [&](Operation *op, PatternRewriter &rewriter) -> FailureOr<Operation *> {
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ return op;
+ return failure();
+ });
+}
+
+/// Hook for PDL driver to check if an operation (`value`) is directly nested in
+/// a function with the name provided as constant parameter.
+/// TODO: PDL needs user-defined "questions".
+static LogicalResult nestedInFunc(PDLValue value, ArrayAttr constantParams,
+ PatternRewriter &rewriter) {
+ auto *operation = value.cast<Operation *>();
+ auto func = operation->getParentOfType<FuncOp>();
+ assert(constantParams.size() == 1 &&
+ "expected a constant param with function name");
+ auto functionSymbol = constantParams[0].dyn_cast<SymbolRefAttr>();
+ assert(functionSymbol && "expected a function name");
+
+ if (!func)
+ return rewriter.notifyMatchFailure(operation, "not nested in a function");
+ return success(functionSymbol.getLeafReference() == func.getName());
+}
+
+/// PDL rewrite hook that does nothing.
+static void noOpRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+ PatternRewriter &rewriter, PDLResultList &results) {
+ assert(args.size() == 1 && "expected one argument");
+#ifndef NDEBUG
+ args.front().cast<Operation *>()->setAttr("iree_linalg_transform.matched",
+ rewriter.getUnitAttr());
+#endif
+}
+
+/// Construct a BlockAndValueMapping from `linalgOp` to `genericLinalgModelOp`.
+/// Walk both ops and check whether all subops are the same.
+static LogicalResult haveIdenticalBodiesImpl(LinalgOp linalgOp,
+ LinalgOp genericLinalgModelOp) {
+ BlockAndValueMapping bvm;
+ bvm.map(linalgOp.getBlock()->getArguments(),
+ genericLinalgModelOp.getBlock()->getArguments());
+ SmallVector<Operation *> linalgBodyOps;
+ linalgOp.getBlock()->walk(
+ [&](Operation *op) { linalgBodyOps.push_back(op); });
+
+ unsigned idx = 0;
+ WalkResult res = genericLinalgModelOp.getBlock()->walk([&](Operation *op) {
+ Operation *linalgSubOp = linalgBodyOps[idx++];
+ if (op->getName() != linalgSubOp->getName())
+ return WalkResult::interrupt();
+ if (op->getAttrs() != linalgSubOp->getAttrs())
+ return WalkResult::interrupt();
+ for (auto it : llvm::zip(op->getOperands(), linalgSubOp->getOperands()))
+ if (std::get<0>(it) != bvm.lookupOrNull(std::get<1>(it)))
+ return WalkResult::interrupt();
+ bvm.map(linalgSubOp->getResults(), op->getResults());
+ return WalkResult::advance();
+ });
+
+ return success(!res.wasInterrupted());
+}
+
+/// Dispatch body equivalence check depending on case.
+static LogicalResult haveEquivalentBodies(LinalgOp linalgOp,
+ LinalgOp genericLinalgModelOp,
+ PatternRewriter &rewriter) {
+ if (succeeded(haveIdenticalBodiesImpl(linalgOp, genericLinalgModelOp)))
+ return success();
+ // TODO: haveEquivalentBodiesImpl, see e.g.
+ // https://gist.github.com/nicolasvasilache/39e89e18c46e02335c16db6ec20a07e3
+ return failure();
+}
+
+/// Succeed when `linalgOp` and `linalgModelOp` are deemed equivalent.
+static LogicalResult isEquivalentToOpImpl(LinalgOp linalgOp,
+ LinalgOp linalgModelOp,
+ PatternRewriter &rewriter) {
+ // If basic properties do not match, return failure.
+ if (linalgOp.inputs() != linalgModelOp.inputs() ||
+ linalgOp.outputs() != linalgModelOp.outputs() ||
+ linalgOp.indexing_maps() != linalgModelOp.indexing_maps() ||
+ linalgOp.iterator_types() != linalgModelOp.iterator_types())
+ return failure();
+
+ // Build the block and go perform a body comparison.
+ {
+ // createBlock moves the insertion point, scope it in an RAII block.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Region &r = linalgModelOp->getRegion(0);
+ Block *bodyBlock = rewriter.createBlock(
+ &r, r.end(), linalgOp.getBlock()->getArgumentTypes(),
+ llvm::to_vector<4>(
+ llvm::map_range(linalgOp.getBlock()->getArguments(),
+ [](Value v) { return v.getLoc(); })));
+ ImplicitLocOpBuilder b(linalgModelOp.getLoc(), rewriter);
+ auto regionBuilder = linalgModelOp.getRegionBuilder();
+ llvm::ArrayRef<mlir::NamedAttribute> attrs = {};
+ regionBuilder(b, *bodyBlock, attrs);
+ }
+
+ return haveEquivalentBodies(linalgOp, linalgModelOp, rewriter);
+}
+
+/// Check whether the unique Operation* `op` stored in `value` (assumed) is
+/// equivalent to the unique StringRefAttr operation name passed in
+/// `constantParams`.
+/// Equivalence is achieved when either:
+/// 1. `op` has the name stored in `constantParams`.
+/// 2. `op` and `constantParams` are both linalg ops and their structured
+/// interfaces as well as their bodies are equivalent.
+/// Structured interfaces equivalence is a simple attribute level check.
+/// Body equivalence is more involved and currently limited:
+/// a. the current impl constructs an instance of the op whose name is
+/// specified in `constantParams` and checks for exact body equality.
+/// b. a more advanced version would "subtract" the bodies and fold, cse
+/// and canonicalize to fixed point. If the result is "all zeros",
+/// then the bodies would be equivalent (really isomorphic).
+/// 3. other cases TBD (e.g. vector.generic when available).
+static LogicalResult isEquivalentToOp(PDLValue value, ArrayAttr constantParams,
+ PatternRewriter &rewriter) {
+ auto *maybeOp = value.dyn_cast<Operation *>();
+ if (!maybeOp)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ Operation *op = maybeOp;
+
+ ArrayRef<Attribute> attrs = constantParams.getValue();
+ if (attrs.size() != 1)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ auto modelOpNameAttr = attrs.front().dyn_cast<StringAttr>();
+ if (!modelOpNameAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ auto modelOpName = modelOpNameAttr.strref();
+
+ // 1. If op has name `modelOpName`, the match is trivial.
+ if (op->getName().getStringRef() == modelOpName)
+ return success();
+
+ // 2. Linalg vs Linalg.
+ // Create op from `constantParams`.
+ OperationState modelOpState(op->getLoc(), modelOpName, op->getOperands(),
+ op->getResultTypes(), op->getAttrs());
+ modelOpState.addRegion();
+ Operation *modelOp = rewriter.createOperation(modelOpState);
+ auto g1 = llvm::make_scope_exit([&]() { rewriter.eraseOp(modelOp); });
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+ LinalgOp linalgModelOp = dyn_cast<LinalgOp>(modelOp);
+ if (linalgOp && linalgModelOp)
+ return isEquivalentToOpImpl(linalgOp, linalgModelOp, rewriter);
+
+ // 3. TBD
+ return failure();
+}
+
+/// Assume that:
+/// 1. `value` is an operands range
+/// 2. `constantParams` contains a DictAttr with `operand_number`, `dim` and
+/// `divisor` IntegerAttr entries.
+/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is a
+/// multiple of `divisor`.
+/// Note: 0 is the convention to express "do not tile", it is considered to
+/// divide everything.
+static LogicalResult isDimMultipleOf(PDLValue value, ArrayAttr constantParams,
+ PatternRewriter &rewriter) {
+ auto maybeOperands = value.dyn_cast<ValueRange>();
+ if (!maybeOperands)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ auto operands = *maybeOperands;
+
+ auto dict = constantParams.begin()->dyn_cast<DictionaryAttr>();
+ if (!dict)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+
+ int64_t dim;
+ auto dimAttr = dict.getAs<IntegerAttr>("dim");
+ if (!dimAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ dim = dimAttr.getInt();
+
+ int64_t divisor;
+ auto divisorAttr = dict.getAs<IntegerAttr>("divisor");
+ if (!divisorAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ divisor = divisorAttr.getInt();
+
+ int64_t operandNumber;
+ auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
+ if (!operandNumberAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ operandNumber = operandNumberAttr.getInt();
+
+ ShapedType shapedType;
+ if (static_cast<int64_t>(operands.size()) > operandNumber)
+ shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
+ if (!shapedType || shapedType.getRank() <= dim)
+ return failure();
+ return success(divisor == 0 || (shapedType.getShape()[dim] > 0 &&
+ shapedType.getShape()[dim] % divisor == 0));
+}
+
+/// Assume that:
+/// 1. `value` is an operands range
+/// 2. `constantParams` contains a DictAttr with `operand_number` and `dim`
+/// IntegerAttr entries.
+/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is
+/// dynamic.
+static LogicalResult isDimStatic(PDLValue value, ArrayAttr constantParams,
+ PatternRewriter &rewriter) {
+ auto maybeOperands = value.dyn_cast<ValueRange>();
+ if (!maybeOperands)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ auto operands = *maybeOperands;
+
+ auto dict = constantParams.begin()->dyn_cast<DictionaryAttr>();
+ if (!dict)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+
+ int64_t dim;
+ auto dimAttr = dict.getAs<IntegerAttr>("dim");
+ if (!dimAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ dim = dimAttr.getInt();
+
+ int64_t operandNumber;
+ auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
+ if (!operandNumberAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ operandNumber = operandNumberAttr.getInt();
+
+ ShapedType shapedType;
+ if (static_cast<int64_t>(operands.size()) > operandNumber)
+ shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
+ return success(shapedType && !shapedType.isDynamicDim(dim));
+}
+
+/// Assume that:
+/// 1. `value` is an operands range
+/// 2. `constantParams` contains a DictAttr with `operand_number` and `dim`
+/// IntegerAttr entries.
+/// Succeed if `value`[`operand_number`] is a ranked type whose `dim` is
+/// dynamic.
+static LogicalResult isDimDynamic(PDLValue value, ArrayAttr constantParams,
+ PatternRewriter &rewriter) {
+ auto maybeOperands = value.dyn_cast<ValueRange>();
+ if (!maybeOperands)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ auto operands = *maybeOperands;
+
+ auto dict = constantParams.begin()->dyn_cast<DictionaryAttr>();
+ if (!dict)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+
+ int64_t dim;
+ auto dimAttr = dict.getAs<IntegerAttr>("dim");
+ if (!dimAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ dim = dimAttr.getInt();
+
+ int64_t operandNumber;
+ auto operandNumberAttr = dict.getAs<IntegerAttr>("operand_number");
+ if (!operandNumberAttr)
+ return failure(); // TODO: notifyMatchFailure needs an Operation* handle.
+ operandNumber = operandNumberAttr.getInt();
+
+ ShapedType shapedType;
+ if (static_cast<int64_t>(operands.size()) > operandNumber)
+ shapedType = operands[operandNumber].getType().dyn_cast<ShapedType>();
+ return success(shapedType && shapedType.isDynamicDim(dim));
+}
+
+FailureOr<SmallVector<Operation *>> findMatchingOps(transform::MatchOp matchOp,
+ SymbolRefAttr pattern,
+ Operation *containerOp) {
+ auto symbolTableOp = matchOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!symbolTableOp)
+ return {symbolTableOp->emitError("no parent op with a SymbolTable")};
+ auto patternOp = dyn_cast_or_null<pdl::PatternOp>(
+ SymbolTable::lookupSymbolIn(symbolTableOp, pattern));
+ if (!patternOp)
+ return {symbolTableOp->emitError("could not find a pattern named: ")
+ << pattern};
+
+ // Clone the pattern operation into the temporary module used by the driver
+ // as it might be referenced multiple times.
+ OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
+ OpBuilder::atBlockBegin(pdlModuleOp->getBody()).clone(*patternOp);
+
+ // Build the PDL module.
+ PDLPatternModule pdlModule(std::move(pdlModuleOp));
+ pdlModule.registerConstraintFunction("nestedInFunc", nestedInFunc);
+ pdlModule.registerConstraintFunction("isDimDynamic", isDimDynamic);
+ pdlModule.registerConstraintFunction("isDimMultipleOf", isDimMultipleOf);
+ pdlModule.registerConstraintFunction("isDimStatic", isDimStatic);
+ pdlModule.registerConstraintFunction("isEquivalentToOp", isEquivalentToOp);
+ pdlModule.registerRewriteFunction("iree_linalg_transform.apply",
+ noOpRewriter);
+
+ RewritePatternSet patterns(std::move(pdlModule));
+ return getMatchingOps(containerOp, std::move(patterns));
+}
+
+} // namespace linalg
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.h b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.h
new file mode 100644
index 0000000..94d126e
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/PDL.h
@@ -0,0 +1,34 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_PDL_H
+#define IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_PDL_H
+
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace linalg {
+
+/// Find all operations in `containerOp` that are matched by the specified PDL
+/// `matchOp`, which is located in the same parent ModuleOp as `matchOp`.
+FailureOr<SmallVector<Operation *>> findMatchingOps(transform::MatchOp matchOp,
+ SymbolRefAttr pattern,
+ Operation *containerOp);
+
+inline FailureOr<SmallVector<Operation *>>
+findMatchingOps(transform::MatchOp matchOp, Operation *containerOp) {
+ return findMatchingOps(matchOp, matchOp.targetMatcher(), containerOp);
+}
+
+} // namespace linalg
+} // namespace mlir
+
+#endif // IREE_LLVM_SANDBOX_DIALECTS_LINALGTRANSFORM_TRANSFORMS_PDL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/ScopedTransform.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/ScopedTransform.cpp
new file mode 100644
index 0000000..b3bfeab
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/ScopedTransform.cpp
@@ -0,0 +1,82 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
+
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+
+namespace {
+struct Rewriter : public PatternRewriter {
+ Rewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
+};
+} // namespace
+
+linalg::transform::ScopeOp linalg::transform::wrapInScope(Operation *op) {
+ Rewriter rewriter(op->getContext());
+ rewriter.setInsertionPoint(op);
+
+ auto scope = rewriter.create<linalg::transform::ScopeOp>(
+ op->getLoc(), op->getResultTypes(), op->getOperands());
+ Region &body = scope.body();
+ rewriter.setInsertionPointToStart(&body.emplaceBlock());
+ BlockAndValueMapping bv;
+ SmallVector<Location> locs(op->getOperandTypes().size(), op->getLoc());
+ bv.map(op->getOperands(), body.addArguments(op->getOperandTypes(), locs));
+
+ Operation *cloneInScope = rewriter.clone(*op, bv);
+ rewriter.create<ForwardOp>(op->getLoc(), cloneInScope->getResults());
+
+ rewriter.replaceOp(op, scope.getResults());
+ return scope;
+}
+
+namespace {
+/// Instruct the inliner to inline everything. Scopes have no semantic meaning
+/// so moving operations in and out of them, regardless of whether their
+/// dialects have implemented an inliner interface, is valid.
+struct ScopeInliner : public InlinerInterface {
+ using InlinerInterface::InlinerInterface;
+
+ bool isLegalToInline(Operation *call, Operation *callable,
+ bool wouldBeCloned) const override {
+ return true;
+ }
+ bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+ BlockAndValueMapping &valueMapping) const override {
+ return true;
+ }
+ bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
+ BlockAndValueMapping &valueMapping) const override {
+ return true;
+ }
+
+ /// Don't recursively analyze operations, because they can all be "inlined".
+ bool shouldAnalyzeRecursively(Operation *op) const override { return false; }
+
+ /// Replace uses of the results with the `forward` op's operands.
+ void handleTerminator(Operation *op,
+ ArrayRef<Value> valuesToRepl) const override {
+ assert(isa<linalg::transform::ForwardOp>(op));
+ for (auto value : llvm::zip(op->getOperands(), valuesToRepl))
+ std::get<1>(value).replaceAllUsesWith(std::get<0>(value));
+ }
+};
+} // namespace
+
+FailureOr<SmallVector<Operation *>>
+linalg::transform::unwrapScope(linalg::transform::ScopeOp scope) {
+ ScopeInliner interface(scope->getContext());
+ SmallVector<Operation *> ops;
+ scope.body().walk([&](Operation *op) { ops.push_back(op); });
+ if (failed(inlineRegion(interface, &scope.body(), scope, scope.getOperands(),
+ scope.getResults(), /*inlineLoc=*/{},
+ /*shouldCloneInlinedRegion=*/false)))
+ return failure();
+ Rewriter(scope->getContext()).eraseOp(scope);
+ return ops;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TrackingListener.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TrackingListener.cpp
new file mode 100644
index 0000000..5969b7f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TrackingListener.cpp
@@ -0,0 +1,183 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/TrackingListener.h"
+
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tracking-listener"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace linalg {
+
+/// Find the linalg op that defines all values in range, potentially
+/// transitively through tensor casts.
+static LinalgOp findSingleLinalgOpDefiningAll(ValueRange range) {
+ LinalgOp sourceOp = nullptr;
+ for (Value value : range) {
+ // See through tensor casts.
+ //
+ // TODO: we may need some generalization (interfaces?) of this for other
+ // operations, especially multi-operand ones to understand which of their
+ // operands may be coming from a Linalg op. Or a completely different
+ // mechanism of tracking op replacement at creation, or even different
+ // patterns that identify the "main" result of a transformation.
+ while (auto castOp = value.getDefiningOp<tensor::CastOp>())
+ value = castOp.source();
+
+ if (auto currentSourceOp = value.getDefiningOp<LinalgOp>()) {
+ if (!sourceOp || sourceOp == currentSourceOp) {
+ sourceOp = currentSourceOp;
+ continue;
+ }
+
+ LLVM_DEBUG(
+ DBGS() << "different source linalg ops for replacing one op: \n"
+ << sourceOp << "\n"
+ << currentSourceOp << "\n");
+ }
+ LLVM_DEBUG(DBGS() << "replacing linalg op with unknown non-linalg op:\n"
+ << *value.getDefiningOp() << "\n");
+ return nullptr;
+ }
+ return sourceOp;
+}
+
+/// Find the scf "for" op that defines all values in the range.
+static scf::ForOp findSingleForOpDefiningAll(ValueRange range) {
+ scf::ForOp forOp = nullptr;
+ for (Value value : range) {
+ if (auto currentSourceOp = value.getDefiningOp<scf::ForOp>()) {
+ if (!forOp || forOp == currentSourceOp) {
+ forOp = currentSourceOp;
+ continue;
+ }
+ LLVM_DEBUG(
+ DBGS() << "different source scf.for ops when replacing one op\n");
+ }
+
+ LLVM_DEBUG(
+ DBGS()
+ << "could not find a source scf.for when replacing another scf.for\n");
+ return nullptr;
+ }
+ return forOp;
+}
+
+// Find a single op that defines all values in the range, optionally
+// transitively through other operations in an op-specific way.
+static Operation *findSingleDefiningOp(Operation *replacedOp,
+ ValueRange range) {
+ return llvm::TypeSwitch<Operation *, Operation *>(replacedOp)
+ .Case<LinalgOp>([&](LinalgOp) -> Operation * {
+ return findSingleLinalgOpDefiningAll(range);
+ })
+ .Case<scf::ForOp>([&](scf::ForOp) -> Operation * {
+ return findSingleForOpDefiningAll(range);
+ })
+ .Default([](Operation *) -> Operation * { return nullptr; });
+}
+
+TrackingListener::TrackingListener(transform::TransformState &state)
+ : transform::TransformState::Extension(state) {
+ for (const auto &pair : getMapping())
+ for (Operation *op : pair.second)
+ trackedOperationKeys.try_emplace(op, pair.first);
+}
+
+void TrackingListener::notifyOperationReplaced(Operation *op,
+ ValueRange newValues) {
+ // Don't attempt to track in error state.
+ if (hadErrors)
+ return;
+
+ // Exit early if the op is not tracked.
+ auto keyIt = trackedOperationKeys.find(op);
+ if (keyIt == trackedOperationKeys.end())
+ return;
+ Value key = keyIt->second;
+
+ Operation *replacement = findSingleDefiningOp(op, newValues);
+ if (!replacement) {
+ emitError(op) << "could not find replacement for tracked op";
+ return;
+ }
+
+ LLVM_DEBUG(DBGS() << "replacing tracked " << *op << " with " << *replacement
+ << " for " << key << "\n");
+ updatePayloadOps(key, [op, replacement](Operation *tracked) {
+ return tracked == op ? replacement : tracked;
+ });
+
+ // Update the backwards map. The replacement operation must not be already
+ // associated with another key as that would break the bidirectional mapping
+ // invariant. Note that operations are pointer-like so we must ensure the
+ // absence of accidental reuse of the pointer address with some deleted
+ // operation that stayed in this mapping.
+ trackedOperationKeys.erase(op);
+ bool replaced = trackedOperationKeys.try_emplace(replacement, key).second;
+ if (!replaced) {
+ InFlightDiagnostic diag =
+ emitError(replacement)
+ << "replacement operation is already associated with another key";
+ diag.attachNote(op->getLoc()) << "replacing this operation";
+ diag.attachNote(trackedOperationKeys.lookup(replacement).getLoc())
+ << "old key";
+ diag.attachNote(key.getLoc()) << "new key";
+ return;
+ }
+}
+
+void TrackingListener::notifyOperationRemoved(Operation *op) {
+ // Don't attempt to track in error state.
+ if (hadErrors)
+ return;
+
+ auto keyIt = trackedOperationKeys.find(op);
+ if (keyIt == trackedOperationKeys.end())
+ return;
+ Value key = keyIt->second;
+
+ LLVM_DEBUG(DBGS() << "removing tracked " << *op << " for " << key << "\n");
+
+ // If a tracked operation is CSE'd, then any further transformations are
+ // redundant. Just remove it.
+ trackedOperationKeys.erase(op);
+ updatePayloadOps(key, [op](Operation *tracked) {
+ return tracked != op ? tracked : nullptr;
+ });
+}
+
+void TrackingListener::notifySetPayload(Value handle,
+ ArrayRef<Operation *> operations) {
+ for (Operation *op : operations) {
+ assert(trackedOperationKeys.lookup(op) == Value() &&
+ "payload op already associated with another key");
+ trackedOperationKeys[op] = handle;
+ }
+}
+
+void TrackingListener::notifyRemovePayload(Value handle,
+ ArrayRef<Operation *> operations) {
+ for (Operation *op : operations)
+ trackedOperationKeys.erase(op);
+}
+
+InFlightDiagnostic TrackingListener::emitError(Operation *op,
+ const llvm::Twine &message) {
+ hadErrors = true;
+#ifndef NDEBUG
+ errorStateChecked = false;
+#endif // NDEBUG
+ return op->emitError(message);
+}
+} // namespace linalg
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TrackingRewriteDriver.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TrackingRewriteDriver.cpp
new file mode 100644
index 0000000..12fda30
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TrackingRewriteDriver.cpp
@@ -0,0 +1,17 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h"
+
+#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+LogicalResult mlir::applyPatternsTrackAndFoldGreedily(
+ Operation *root, RewriteListener &listener,
+ const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config) {
+ return applyPatternsAndFoldGreedily(root, patterns, config, &listener);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TransformOpInterface.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TransformOpInterface.cpp
new file mode 100644
index 0000000..6ae0291
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/TransformOpInterface.cpp
@@ -0,0 +1,155 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
+
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+//===----------------------------------------------------------------------===//
+// TransformState
+//===----------------------------------------------------------------------===//
+
+constexpr const Value transform::TransformState::kTopLevelValue;
+
+transform::TransformState::TransformState(Operation *root) {
+ operations[kTopLevelValue].push_back(root);
+}
+
+Operation *transform::TransformState::getTopLevel() const {
+ return operations.lookup(kTopLevelValue).front();
+}
+
+ArrayRef<Operation *>
+transform::TransformState::getPayloadOps(Value value) const {
+ auto iter = operations.find(value);
+ assert(iter != operations.end() && "unknown handle");
+ return iter->getSecond();
+}
+
+LogicalResult
+transform::TransformState::setPayloadOps(Value value,
+ ArrayRef<Operation *> targets) {
+ assert(value != kTopLevelValue &&
+ "attempting to reset the transformation root");
+
+ if (value.use_empty())
+ return success();
+
+ SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
+ bool inserted = operations.insert({value, std::move(storedTargets)}).second;
+ assert(inserted && "value is already associated with another list");
+ (void)inserted;
+
+ const SmallVector<Operation *> ¤tOperationList =
+ operations.lookup(value);
+ llvm::SmallPtrSet<Operation *, 4> currentOperationSet(
+ currentOperationList.begin(), currentOperationList.end());
+ for (const auto &kvp : operations) {
+ if (kvp.getFirst() == value)
+ continue;
+ for (Operation *trackedOp : kvp.getSecond()) {
+ if (currentOperationSet.contains(trackedOp)) {
+ InFlightDiagnostic diag = trackedOp->emitError()
+ << "operation tracked by two handles";
+ diag.attachNote(value.getLoc()) << "handle";
+ diag.attachNote(kvp.getFirst().getLoc()) << "handle";
+ return diag;
+ }
+ }
+ }
+
+ for (const auto &keyedExtension : extensions)
+ keyedExtension.getSecond()->sendNotifySetPayload(value, targets);
+
+ return success();
+}
+
+void transform::TransformState::removePayloadOps(Value value) {
+ auto it = operations.find(value);
+ if (it == operations.end())
+ return;
+
+ for (const auto &keyedExtension : extensions)
+ keyedExtension.getSecond()->sendNotifyRemovePayload(value, it->getSecond());
+
+ operations.erase(it);
+}
+
+void transform::TransformState::updatePayloadOps(
+ Value value, function_ref<Operation *(Operation *)> callback) {
+ auto it = operations.find(value);
+ assert(it != operations.end() && "unknown handle");
+ SmallVector<Operation *> &association = it->getSecond();
+ SmallVector<Operation *> updated;
+ updated.reserve(association.size());
+
+ for (Operation *op : association)
+ if (Operation *updatedOp = callback(op))
+ updated.push_back(updatedOp);
+
+ for (const auto &keyedExtension : extensions)
+ keyedExtension.getSecond()->sendNotifyUpdatePayload(value, association,
+ updated);
+
+ std::swap(association, updated);
+}
+
+LogicalResult
+transform::TransformState::applyTransform(TransformOpInterface transform) {
+ transform::TransformResults results(transform->getNumResults());
+ if (failed(transform.apply(results, *this)))
+ return failure();
+
+ for (Value target : transform->getOperands())
+ removePayloadOps(target);
+
+ for (auto en : llvm::enumerate(transform->getResults()))
+ if (failed(setPayloadOps(en.value(), results.get(en.index()))))
+ return failure();
+
+ return success();
+}
+
+// Out-of-line definition to ensure vtable and metadata are emitted to a single
+// .o file.
+transform::TransformState::Extension::~Extension() {}
+
+//===----------------------------------------------------------------------===//
+// TransformResults
+//===----------------------------------------------------------------------===//
+
+transform::TransformResults::TransformResults(unsigned numSegments) {
+ segments.resize(numSegments,
+ ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
+}
+
+void transform::TransformResults::set(OpResult value,
+ ArrayRef<Operation *> ops) {
+ unsigned position = value.getResultNumber();
+ assert(position < segments.size() &&
+ "setting results for a non-existent handle");
+ assert(segments[position].data() == nullptr && "results already set");
+ unsigned start = operations.size();
+ llvm::append_range(operations, ops);
+ segments[position] = makeArrayRef(operations).drop_front(start);
+}
+
+ArrayRef<Operation *>
+transform::TransformResults::get(unsigned position) const {
+ assert(position < segments.size() &&
+ "querying results for a non-existent handle");
+ assert(segments[position].data() != nullptr && "querying unset results");
+ return segments[position];
+}
+
+//===----------------------------------------------------------------------===//
+// Generated interface implementation.
+//===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..1680b7c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_library(IREELinalgTransformDialectTransforms
+ ExpertExpansion.cpp
+ TrackingCSE.cpp
+ TransformInterpreter.cpp
+
+ DEPENDS
+ mlir-headers
+
+ LINK_LIBS PUBLIC
+ IREELinalgTransformDialect
+
+ MLIRBufferization
+ MLIRIR
+ MLIRLinalg
+ MLIRLLVMIR
+ MLIRMath
+ MLIRMathToLLVM
+ MLIRMemRef
+ MLIRMemRefToLLVM
+ MLIRPass
+ MLIRTensor
+ MLIRTransforms
+ MLIRVector
+ MLIRVectorToLLVM
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/ExpertExpansion.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/ExpertExpansion.cpp
new file mode 100644
index 0000000..829e620
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/ExpertExpansion.cpp
@@ -0,0 +1,118 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
+#include "iree-dialects/Dialect/LinalgTransform/SimplePatternRewriter.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "expert-expansion"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]")
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Expands the linalg::transform::ExpertOp instances in the `module` into lists
+/// of transformations as described by the `expansions` module that contains
+/// PDL.
+static void expandStrategyOps(ModuleOp module, ModuleOp expansions) {
+ mlir::OwningOpRef<mlir::ModuleOp> clonedExpansions(
+ cast<ModuleOp>(expansions->clone()));
+ RewritePatternSet patterns(std::move(clonedExpansions));
+ FrozenRewritePatternSet frozen(std::move(patterns));
+ PatternApplicator applicator(frozen);
+ applicator.applyDefaultCostModel();
+
+ SimplePatternRewriter rewriter(module.getContext());
+ module.walk([&](transform::ExpertOp expertOp) {
+ rewriter.setInsertionPoint(expertOp);
+ if (failed(applicator.matchAndRewrite(expertOp, rewriter))) {
+ LLVM_DEBUG(DBGS() << "failed to rewrite strategy \""
+ << expertOp.expertName() << "\"\n");
+ }
+ });
+}
+
+namespace {
+struct ExpertExpansion : public PassWrapper<ExpertExpansion, Pass> {
+ Pass::Option<std::string> strategyModuleName{
+ *this, "strategy-module-name", llvm::cl::init("strategies"),
+ llvm::cl::desc(
+ "Name of the nested module containing expert strategies.")};
+
+ explicit ExpertExpansion(StringRef name = "strategies")
+ : PassWrapper<ExpertExpansion, Pass>() {
+ strategyModuleName = name.str();
+ }
+
+ ExpertExpansion(const ExpertExpansion &other)
+ : PassWrapper<ExpertExpansion, Pass>(other) {
+ strategyModuleName = other.strategyModuleName.getValue();
+ }
+
+ StringRef getArgument() const final {
+ return "linalg-transform-expert-expansion";
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
+ }
+
+ StringRef getDescription() const final {
+ return "Expands transformation experts into individual transformations";
+ }
+
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+
+ void runOnOperation() override {
+ auto module = dyn_cast<ModuleOp>(getOperation());
+ if (!module)
+ return signalPassFailure();
+
+ ModuleOp strategyModule = nullptr;
+ for (auto nestedModule : module.getOps<ModuleOp>()) {
+ Optional<StringRef> name = nestedModule.sym_name();
+ if (!name)
+ continue;
+
+ if (*name == strategyModuleName) {
+ if (!strategyModule) {
+ strategyModule = nestedModule;
+ continue;
+ }
+ InFlightDiagnostic diag = nestedModule->emitError()
+ << "more than one strategy module provided";
+ diag.attachNote(strategyModule->getLoc()) << "previous strategy module";
+ return signalPassFailure();
+ }
+ }
+
+ if (!strategyModule) {
+ module->emitError() << "expected a nested strategy module";
+ return signalPassFailure();
+ }
+
+ expandStrategyOps(module, strategyModule);
+ strategyModule->erase();
+ }
+};
+} // namespace
+
+void mlir::linalg::transform::registerLinalgTransformExpertExpansionPass() {
+ PassRegistration<ExpertExpansion>(
+ []() { return std::make_unique<ExpertExpansion>(); });
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TrackingCSE.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TrackingCSE.cpp
new file mode 100644
index 0000000..f4b3c59
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TrackingCSE.cpp
@@ -0,0 +1,16 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/TrackingCSE.h"
+
+#include "iree-dialects/Transforms/ListenerCSE.h"
+
+using namespace mlir;
+
+LogicalResult mlir::eliminateCommonSubexpressionsWithTrackedOps(
+ Operation *root, RewriteListener &listener, DominanceInfo *domInfo) {
+ return eliminateCommonSubexpressions(root, domInfo, &listener);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TransformInterpreter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TransformInterpreter.cpp
new file mode 100644
index 0000000..286e313
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/Transforms/TransformInterpreter.cpp
@@ -0,0 +1,344 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
+#include "iree-dialects/Dialect/LinalgTransform/TrackingCSE.h"
+#include "iree-dialects/Dialect/LinalgTransform/TrackingRewriteDriver.h"
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpInterface.h"
+#include "iree-dialects/Dialect/LinalgTransform/TransformOpMapping.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/SourceMgr.h"
+
+#define DEBUG_TYPE "transform-interpreter"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+static llvm::cl::opt<std::string> clTransformFileName(
+ "linalg-transform-file-name",
+ llvm::cl::desc("mlir file containing a top-level module that specifies "
+ "the transformations to apply."),
+ llvm::cl::init(""));
+
+//===----------------------------------------------------------------------===//
+// Linalg Interpreter Driver
+//===----------------------------------------------------------------------===//
+
+/// Run enabling transformations (LICM and its variants, single-iteration loop
+/// removal, CSE) on the given function.
+static LogicalResult performEnablerTransformations(
+ FuncOp func, RewriteListener &listener,
+ linalg::LinalgEnablingOptions options = linalg::LinalgEnablingOptions()) {
+ MLIRContext *ctx = func->getContext();
+ RewritePatternSet patterns(ctx);
+ linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
+ scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+ if (failed(applyPatternsTrackAndFoldGreedily(func, listener,
+ std::move(patterns))))
+ return failure();
+
+ // This assumes LICM never removes operations so we don't need tracking.
+ if (options.licm) {
+ WalkResult result =
+ func->walk([](LoopLikeOpInterface loopLike) -> WalkResult {
+ return moveLoopInvariantCode(loopLike);
+ });
+ if (result.wasInterrupted())
+ return failure();
+ }
+
+ func.walk([](Operation *op) {
+ (void)llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<AffineForOp, scf::ForOp>(
+ [](auto loop) { return promoteIfSingleIteration(loop); })
+ .Default([](Operation *) { return success(); });
+ });
+
+ if (options.hoistRedundantVectorTransfers)
+ hoistRedundantVectorTransfers(func);
+ if (options.hoistRedundantVectorTransfersOnTensor)
+ hoistRedundantVectorTransfersOnTensor(func);
+
+ return eliminateCommonSubexpressionsWithTrackedOps(func, listener);
+}
+
+/// Run enabling transformations on the given `containerOp` while preserving the
+/// operation tracking information.
+static LogicalResult performEnablerTransformations(
+ Operation *containerOp, RewriteListener &listener,
+ linalg::LinalgEnablingOptions options = linalg::LinalgEnablingOptions()) {
+ auto res = containerOp->walk([&](FuncOp func) {
+ if (failed(performEnablerTransformations(func, listener, options)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ return failure(res.wasInterrupted());
+}
+
+static LogicalResult executeTransform(Operation *operation,
+ transform::TransformState &state) {
+ auto iface = dyn_cast<transform::TransformOpInterface>(operation);
+ if (!iface)
+ return operation->emitError() << "unknown transformation operation";
+
+ return state.applyTransform(iface);
+}
+
+/// Perform the transformation specified by the callback and unconditionally
+/// check the error state of the listener. Return failure if either failed.
+static LogicalResult checkedListenerTransform(
+ function_ref<LogicalResult(TrackingListener &)> transform,
+ TrackingListener &listener) {
+ // Make sure we check the listener error state regardless of the transform
+ // result.
+ LogicalResult transformResult = transform(listener);
+ LogicalResult listenerResult = listener.checkErrorState();
+ return failure(failed(transformResult) || failed(listenerResult));
+}
+
+/// Applies the transformations listed in the `sequence` to operations starting
+/// from `target`. The following transformations may be applied to operations
+/// produced by previous transformations as indicated by SSA value flow in the
+/// Linalg Transform dialect.
+static LogicalResult executeSequence(linalg::transform::SequenceOp sequence,
+ Operation *containerOp) {
+ MLIRContext *ctx = containerOp->getContext();
+ RewritePatternSet patternList(ctx);
+ for (Dialect *dialect : ctx->getLoadedDialects())
+ dialect->getCanonicalizationPatterns(patternList);
+ for (RegisteredOperationName op : ctx->getRegisteredOperations())
+ op.getCanonicalizationPatterns(patternList, ctx);
+ FrozenRewritePatternSet patterns(std::move(patternList));
+
+ transform::TransformState state(containerOp);
+ TrackingListener &listener = state.addExtension<TrackingListener>();
+
+ // Run the canonicalizations upfront so we don't match and transform
+ // operations only to drop them later.
+ if (failed(checkedListenerTransform(
+ [&](TrackingListener &listener) {
+ return eliminateCommonSubexpressionsWithTrackedOps(containerOp,
+ listener);
+ },
+ listener))) {
+ LLVM_DEBUG(DBGS() << "failed to perform CSE\n");
+ return failure();
+ }
+ if (failed(checkedListenerTransform(
+ [&](TrackingListener &listener) {
+ return applyPatternsTrackAndFoldGreedily(containerOp, listener,
+ patterns);
+ },
+ listener))) {
+ LLVM_DEBUG(DBGS() << "failed to apply canonicalization patterns\n");
+ return failure();
+ }
+
+ for (Operation &transform : sequence.body().front()) {
+ if (failed(executeTransform(&transform, state))) {
+ std::string str;
+ llvm::raw_string_ostream ss(str);
+ ss << "failed to apply: " << transform << "\nto\n" << containerOp;
+ ss.flush();
+ return transform.emitError() << str;
+ }
+
+ LLVM_DEBUG(DBGS() << "successfully applied transform: " << transform
+ << "\n");
+
+ // Run CSE, enabling transformations and canonicalization. This is similar
+ // to running the respective pass, but (a) keeps tracking the value/op
+ // mapping and (b) avoids constructing the pattern set + pass pipeline on
+ // every step.
+ // TODO: consider better targeting than module-level transformations here:
+ // e.g., the enabler internals can apply to one function only. Furthermore,
+ // we don't need all of enabler transformations after/before all passes.
+ if (failed(checkedListenerTransform(
+ [&](TrackingListener &listener) {
+ return eliminateCommonSubexpressionsWithTrackedOps(containerOp,
+ listener);
+ },
+ listener))) {
+ LLVM_DEBUG(DBGS() << "failed to perform CSE\n");
+ return failure();
+ }
+
+ // TODO: this runs CSE internally, mostly redundant with the above.
+ if (failed(checkedListenerTransform(
+ [&](TrackingListener &listener) {
+ return performEnablerTransformations(containerOp, listener);
+ },
+ listener))) {
+ LLVM_DEBUG(DBGS() << "enabler transformations failed\n");
+ return failure();
+ }
+
+ if (failed(checkedListenerTransform(
+ [&](TrackingListener &listener) {
+ return applyPatternsTrackAndFoldGreedily(containerOp, listener,
+ patterns);
+ },
+ listener))) {
+ LLVM_DEBUG(DBGS() << "failed to apply canonicalization patterns\n");
+ return failure();
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Linalg Interpreter Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass that executes transformations specified by a module-level
+/// iree_linalg_transform.apply operation on the same module.
+struct InterpreterPass : public PassWrapper<InterpreterPass, Pass> {
+ StringRef getArgument() const final { return "linalg-interp-transforms"; }
+
+ StringRef getDescription() const final {
+ return "Executes transformations specified in Linalg Transform dialect";
+ }
+
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ // clang-format off
+ registry.insert<arith::ArithmeticDialect,
+ AffineDialect,
+ bufferization::BufferizationDialect,
+ func::FuncDialect,
+ linalg::LinalgDialect,
+ linalg::transform::LinalgTransformDialect,
+ LLVM::LLVMDialect,
+ pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect,
+ scf::SCFDialect,
+ tensor::TensorDialect,
+ vector::VectorDialect
+ // clang-format on
+ >();
+
+ arith::registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerBufferizableOpInterfaceExternalModels(registry);
+ linalg::comprehensive_bufferize::std_ext::
+ registerModuleBufferizationExternalModels(registry);
+ tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ vector::registerBufferizableOpInterfaceExternalModels(registry);
+ }
+
+ void runTransformModuleOnOperation(ModuleOp module, Operation *op) {
+ if (!module)
+ return signalPassFailure();
+
+ auto result = module->walk([&](linalg::transform::SequenceOp sequenceOp) {
+ if (failed(executeSequence(sequenceOp, op)))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (result.wasInterrupted())
+ signalPassFailure();
+ }
+
+ void runOnOperation() override {
+ // If no transform file is specified, assume the transforms live in the
+ // same module as the IR. The considered ModuleOp is either `getOperation()`
+ // if it is already a ModuleOp, or the first parent ModuleOp.
+ if (clTransformFileName.empty()) {
+ ModuleOp module = dyn_cast<ModuleOp>(getOperation());
+ if (!module)
+ module = getOperation()->getParentOfType<ModuleOp>();
+ return runTransformModuleOnOperation(module, getOperation());
+ }
+
+ // If a transform file is specified, parse its content into a ModuleOp.
+ std::string errorMessage;
+ auto memoryBuffer = openInputFile(clTransformFileName, &errorMessage);
+ if (!memoryBuffer) {
+ llvm::errs() << errorMessage << "\n";
+ return signalPassFailure();
+ }
+ // Tell sourceMgr about this buffer, the parser will pick it up.
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+ OwningOpRef<ModuleOp> module(
+ parseSourceFile<ModuleOp>(sourceMgr, &getContext()));
+ runTransformModuleOnOperation(module.get(), getOperation());
+ }
+};
+
+struct DropSchedulePass : public PassWrapper<DropSchedulePass, Pass> {
+ StringRef getArgument() const final { return "linalg-drop-schedule"; }
+
+ StringRef getDescription() const final {
+ return "Drop the schedule from the operation";
+ }
+
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+
+ void runOnOperation() override {
+ getOperation()->walk([&](Operation *nestedOp) {
+ if (isa<linalg::transform::SequenceOp>(nestedOp) ||
+ isa<pdl::PatternOp>(nestedOp))
+ nestedOp->erase();
+ });
+ }
+};
+} // namespace
+
+namespace mlir {
+/// Create a Linalg Transform interpreter pass.
+std::unique_ptr<Pass> createLinalgTransformInterpreterPass() {
+ return std::make_unique<InterpreterPass>();
+}
+/// Create a Linalg pass to drop the schedule from the module.
+std::unique_ptr<Pass> createDropSchedulePass() {
+ return std::make_unique<DropSchedulePass>();
+}
+} // namespace mlir
+
+/// Registration hook for the Linalg Transform interpreter pass.
+void mlir::linalg::transform::registerLinalgTransformInterpreterPass() {
+ PassRegistration<InterpreterPass>();
+}
+
+/// Registration hook for the Linalg drop schedule from module pass.
+void mlir::linalg::transform::registerDropSchedulePass() {
+ PassRegistration<DropSchedulePass>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
index 82381bc..164187e 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
@@ -8,11 +8,11 @@
#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
@@ -110,11 +110,14 @@
}
// IntegerType
-LogicalResult PYDM::IntegerType::verify(
- function_ref<InFlightDiagnostic()> emitError, Optional<int> bitWidth) {
- if (!bitWidth) return success();
+LogicalResult
+PYDM::IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
+ Optional<int> bitWidth) {
+ if (!bitWidth)
+ return success();
int w = abs(*bitWidth);
- if (w == 0 || w == 8 || w == 16 || w == 32 || w == 64) return success();
+ if (w == 0 || w == 8 || w == 16 || w == 32 || w == 64)
+ return success();
return emitError() << "unsupported python integer bit width: " << w;
}
@@ -124,10 +127,12 @@
return parser.emitError(parser.getCurrentLocation());
};
// Weak
- if (failed(parser.parseOptionalLess())) return get(ctxt);
+ if (failed(parser.parseOptionalLess()))
+ return get(ctxt);
// AP
if (succeeded(parser.parseOptionalStar())) {
- if (failed(parser.parseGreater())) return Type();
+ if (failed(parser.parseGreater()))
+ return Type();
return get(ctxt, None);
}
@@ -140,9 +145,12 @@
}
int width;
- if (failed(parser.parseInteger(width))) return Type();
- if (failed(parser.parseGreater())) return Type();
- if (!isSigned) width = -width;
+ if (failed(parser.parseInteger(width)))
+ return Type();
+ if (failed(parser.parseGreater()))
+ return Type();
+ if (!isSigned)
+ width = -width;
return getChecked(emitError, ctxt, width);
}
@@ -169,32 +177,36 @@
StringRef PYDM::IntegerType::getPythonTypeName() const { return "int"; }
Optional<NumericCategory> PYDM::IntegerType::getNumericCategory() const {
- if (isWeak()) return NumericCategory::WeakInteger;
- if (getBitWidth() == 0) return NumericCategory::APSigned;
- if (isSigned()) return NumericCategory::Signed;
+ if (isWeak())
+ return NumericCategory::WeakInteger;
+ if (getBitWidth() == 0)
+ return NumericCategory::APSigned;
+ if (isSigned())
+ return NumericCategory::Signed;
return NumericCategory::Unsigned;
}
Optional<int> PYDM::IntegerType::getNumericSubTypeCode() const {
- if (isWeak()) return 0;
+ if (isWeak())
+ return 0;
IntegerSubTypeCode stc;
switch (getBitWidth()) {
- case 8:
- stc = IntegerSubTypeCode::Integer8;
- break;
- case 16:
- stc = IntegerSubTypeCode::Integer16;
- break;
- case 32:
- stc = IntegerSubTypeCode::Integer32;
- break;
- case 64:
- stc = IntegerSubTypeCode::Integer64;
- break;
- default: {
- stc = IntegerSubTypeCode::Integer8; // Arbitrarily picked value.
- assert(false && "unsupported numeric bitwidth");
- }
+ case 8:
+ stc = IntegerSubTypeCode::Integer8;
+ break;
+ case 16:
+ stc = IntegerSubTypeCode::Integer16;
+ break;
+ case 32:
+ stc = IntegerSubTypeCode::Integer32;
+ break;
+ case 64:
+ stc = IntegerSubTypeCode::Integer64;
+ break;
+ default: {
+ stc = IntegerSubTypeCode::Integer8; // Arbitrarily picked value.
+ assert(false && "unsupported numeric bitwidth");
+ }
}
return static_cast<int>(stc);
}
@@ -221,15 +233,15 @@
getImpl()->storageClass != CollectionStorageClass::Boxed) {
printer << "<";
switch (getImpl()->storageClass) {
- case CollectionStorageClass::Boxed:
- printer << "boxed";
- break;
- case CollectionStorageClass::Empty:
- printer << "empty";
- break;
- case CollectionStorageClass::Unboxed:
- printer << "unboxed";
- break;
+ case CollectionStorageClass::Boxed:
+ printer << "boxed";
+ break;
+ case CollectionStorageClass::Empty:
+ printer << "empty";
+ break;
+ case CollectionStorageClass::Unboxed:
+ printer << "unboxed";
+ break;
}
if (getImpl()->uniformElementType) {
@@ -247,10 +259,14 @@
Type t;
StringRef storageClassKeyword;
- if (parser.parseKeyword(&storageClassKeyword)) return Type();
- if (parser.parseComma()) return Type();
- if (parser.parseType(t)) return Type();
- if (parser.parseGreater()) return Type();
+ if (parser.parseKeyword(&storageClassKeyword))
+ return Type();
+ if (parser.parseComma())
+ return Type();
+ if (parser.parseType(t))
+ return Type();
+ if (parser.parseGreater())
+ return Type();
CollectionStorageClass storageClass;
if (storageClassKeyword == "boxed")
@@ -274,9 +290,11 @@
}
bool PYDM::ListType::isRefinable() const {
- if (getStorageClass() == CollectionStorageClass::Empty) return false;
+ if (getStorageClass() == CollectionStorageClass::Empty)
+ return false;
- if (!getUniformElementType()) return true;
+ if (!getUniformElementType())
+ return true;
if (auto pyType = getUniformElementType().dyn_cast<PythonTypeInterface>())
return pyType.isRefinable();
@@ -286,16 +304,16 @@
Type PYDM::ListType::getElementStorageType() const {
switch (getStorageClass()) {
- case CollectionStorageClass::Boxed:
- case CollectionStorageClass::Empty:
- return ObjectType::get(getContext());
- case CollectionStorageClass::Unboxed:
- assert(getUniformElementType() &&
- "unboxed list should have uniform element type");
- return getUniformElementType();
- default:
- assert(false && "unsupported storage class");
- return {};
+ case CollectionStorageClass::Boxed:
+ case CollectionStorageClass::Empty:
+ return ObjectType::get(getContext());
+ case CollectionStorageClass::Unboxed:
+ assert(getUniformElementType() &&
+ "unboxed list should have uniform element type");
+ return getUniformElementType();
+ default:
+ assert(false && "unsupported storage class");
+ return {};
}
}
@@ -310,11 +328,14 @@
Type PyObjectType::parse(mlir::AsmParser &parser) {
MLIRContext *ctxt = parser.getContext();
- if (parser.parseOptionalLess()) return get(ctxt, nullptr);
+ if (parser.parseOptionalLess())
+ return get(ctxt, nullptr);
Type t;
- if (parser.parseType(t)) return Type();
- if (parser.parseGreater()) return Type();
+ if (parser.parseType(t))
+ return Type();
+ if (parser.parseGreater())
+ return Type();
if (auto primitiveType = t.dyn_cast<PrimitiveType>())
return get(ctxt, primitiveType);
else {
@@ -330,7 +351,8 @@
StringRef PYDM::ObjectType::getPythonTypeName() const { return "object"; }
bool PYDM::ObjectType::isRefinable() const {
- if (!getPrimitiveType()) return true;
+ if (!getPrimitiveType())
+ return true;
if (auto pyType = getPrimitiveType().dyn_cast<PythonTypeInterface>())
return pyType.isRefinable();
@@ -341,7 +363,8 @@
// RealType
void PyRealType::print(mlir::AsmPrinter &printer) const {
auto ft = getImpl()->floatType;
- if (ft) printer << "<" << ft << ">";
+ if (ft)
+ printer << "<" << ft << ">";
}
Type PyRealType::parse(mlir::AsmParser &parser) {
@@ -351,17 +374,22 @@
return parser.emitError(parser.getCurrentLocation());
};
// Weak
- if (failed(parser.parseOptionalLess())) return get(ctxt);
+ if (failed(parser.parseOptionalLess()))
+ return get(ctxt);
// Explicit
FloatType subType;
- if (failed(parser.parseType(subType))) return Type();
- if (failed(parser.parseGreater())) return Type();
+ if (failed(parser.parseType(subType)))
+ return Type();
+ if (failed(parser.parseGreater()))
+ return Type();
return getChecked(emitError, ctxt, subType);
}
-LogicalResult PYDM::RealType::verify(
- function_ref<InFlightDiagnostic()> emitError, FloatType floatType) {
- if (!floatType) return success();
+LogicalResult
+PYDM::RealType::verify(function_ref<InFlightDiagnostic()> emitError,
+ FloatType floatType) {
+ if (!floatType)
+ return success();
if (!floatType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>()) {
return emitError() << "unsupported Python floating point type: "
<< floatType;
@@ -377,12 +405,14 @@
StringRef PYDM::RealType::getPythonTypeName() const { return "float"; }
Optional<NumericCategory> PYDM::RealType::getNumericCategory() const {
- if (isWeak()) return NumericCategory::WeakReal;
+ if (isWeak())
+ return NumericCategory::WeakReal;
return NumericCategory::Real;
}
Optional<int> PYDM::RealType::getNumericSubTypeCode() const {
- if (isWeak()) return 0;
+ if (isWeak())
+ return 0;
RealSubTypeCode stc =
TypeSwitch<Type, RealSubTypeCode>(getFloatType())
.Case([](BFloat16Type t) { return RealSubTypeCode::BF16; })
@@ -438,13 +468,15 @@
Type PyUnionType::parse(mlir::AsmParser &parser) {
MLIRContext *ctxt = parser.getContext();
- if (parser.parseOptionalLess()) return get(ctxt, {});
+ if (parser.parseOptionalLess())
+ return get(ctxt, {});
SmallVector<::mlir::Type> alternatives;
do {
Type type;
- if (parser.parseType(type)) return Type();
+ if (parser.parseType(type))
+ return Type();
alternatives.push_back(type);
} while (succeeded(parser.parseOptionalComma()));
@@ -452,9 +484,9 @@
ctxt, alternatives);
}
-LogicalResult PYDM::UnionType::verify(
- llvm::function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<Type> alternatives) {
+LogicalResult
+PYDM::UnionType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<Type> alternatives) {
int lastTypeCode = 0;
for (Type alternative : alternatives) {
if (auto pythonType = alternative.dyn_cast<PYDM::PythonTypeInterface>()) {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
index 2010688..b281874 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
@@ -7,14 +7,14 @@
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
-#include "llvm/ADT/SmallSet.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
@@ -50,7 +50,8 @@
for (int operandIndex = 0, e = operands.size(); operandIndex < e;
++operandIndex) {
Value &operand = operands[operandIndex];
- if (operandIndices && !operandIndices->contains(operandIndex)) continue;
+ if (operandIndices && !operandIndices->contains(operandIndex))
+ continue;
if (auto objectType = operand.getType().dyn_cast<ObjectType>()) {
Type primitiveType = objectType.getPrimitiveType();
if (primitiveType) {
@@ -73,7 +74,7 @@
Optional<llvm::SmallSet<int, 4>> operandIndices;
};
-} // namespace
+} // namespace
static Value getNumericZeroConstant(Location loc, Type numericType,
OpBuilder &builder) {
@@ -132,7 +133,8 @@
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ApplyBinaryOp op,
PatternRewriter &rewriter) const override {
- if (op.dunder_name() != "mul") return failure();
+ if (op.dunder_name() != "mul")
+ return failure();
Value listOperand;
Value countOperand;
if (isBuiltinSequence(op.left()) && isInteger(op.right())) {
@@ -157,7 +159,7 @@
return operand.getType().isa<PYDM::IntegerType>();
}
};
-} // namespace
+} // namespace
void ApplyBinaryOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
@@ -219,7 +221,7 @@
namespace {
struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> {
- public:
+public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AsBoolOp op,
PatternRewriter &rewriter) const override {
@@ -232,14 +234,16 @@
};
struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> {
- public:
+public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AsBoolOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto ptType = op.value().getType().dyn_cast<PythonTypeInterface>();
- if (!ptType) return failure();
- if (!ptType.getNumericPromotionOrder()) return failure();
+ if (!ptType)
+ return failure();
+ if (!ptType.getNumericPromotionOrder())
+ return failure();
auto boolType = rewriter.getType<BoolType>();
Value zeroValue =
@@ -254,7 +258,7 @@
}
};
-} // namespace
+} // namespace
void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
@@ -288,7 +292,8 @@
//===----------------------------------------------------------------------===//
OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[0]) return {};
+ if (!operands[0])
+ return {};
// Since both BoolType and I1 share the attribute form (an IntegerAttr of I1),
// we can just return it.
return operands[0];
@@ -351,7 +356,7 @@
/// or insert specific PromoteNumeric ops.
struct ResolveNumericDynamicBinaryPromote
: public OpRewritePattern<DynamicBinaryPromoteOp> {
- public:
+public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
PatternRewriter &rewriter) const override {
@@ -362,7 +367,8 @@
auto rightResultType = op.getResultTypes()[1];
auto leftPt = leftType.dyn_cast<PythonTypeInterface>();
auto rightPt = rightType.dyn_cast<PythonTypeInterface>();
- if (!leftPt || !rightPt) return failure();
+ if (!leftPt || !rightPt)
+ return failure();
Optional<int> leftOrder = leftPt.getNumericPromotionOrder();
Optional<int> rightOrder = rightPt.getNumericPromotionOrder();
@@ -395,7 +401,7 @@
/// numeric type, then the op has no meaning and is elided.
struct ElideNonNumericDynamicBinaryPromote
: public OpRewritePattern<DynamicBinaryPromoteOp> {
- public:
+public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
PatternRewriter &rewriter) const override {
@@ -415,14 +421,16 @@
}
static bool isConcreteNonNumericType(Type t) {
- if (t.isa<ObjectType>()) return false;
+ if (t.isa<ObjectType>())
+ return false;
auto pt = t.dyn_cast<PythonTypeInterface>();
- if (!pt || pt.getNumericPromotionOrder()) return false;
+ if (!pt || pt.getNumericPromotionOrder())
+ return false;
return true;
}
};
-} // namespace
+} // namespace
void DynamicBinaryPromoteOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
@@ -456,7 +464,8 @@
parser.resolveOperand(cond, conditionType, result.operands))
return failure();
// Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types)) return failure();
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
@@ -471,7 +480,8 @@
}
// Parse the optional attribute list.
- if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
return success();
}
@@ -518,7 +528,8 @@
// Don't consider the else region if it is empty.
Region *elseRegion = &this->elseRegion();
- if (elseRegion->empty()) elseRegion = nullptr;
+ if (elseRegion->empty())
+ elseRegion = nullptr;
// Otherwise, the successor is dependent on the condition.
if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
@@ -529,7 +540,8 @@
// If the condition isn't constant, both regions may be executed.
regions.push_back(RegionSuccessor(&thenRegion()));
// If the else region does not exist, it is not a viable successor.
- if (elseRegion) regions.push_back(RegionSuccessor(elseRegion));
+ if (elseRegion)
+ regions.push_back(RegionSuccessor(elseRegion));
}
}
@@ -567,31 +579,31 @@
LogicalResult MakeListOp::verify() {
auto listType = list().getType().cast<ListType>();
switch (listType.getStorageClass()) {
- case CollectionStorageClass::Boxed:
- for (auto element : elements()) {
- if (!element.getType().isa<ObjectType>()) {
- return emitOpError() << "making a list with boxed storage class "
- "must have object elements. Got: "
- << element.getType();
- }
+ case CollectionStorageClass::Boxed:
+ for (auto element : elements()) {
+ if (!element.getType().isa<ObjectType>()) {
+ return emitOpError() << "making a list with boxed storage class "
+ "must have object elements. Got: "
+ << element.getType();
}
- break;
- case CollectionStorageClass::Unboxed:
- for (auto element : elements()) {
- if (element.getType().isa<ObjectType>()) {
- return emitOpError() << "making a list with unboxed storage class "
- "must not have object elements. Got: "
- << element.getType();
- }
+ }
+ break;
+ case CollectionStorageClass::Unboxed:
+ for (auto element : elements()) {
+ if (element.getType().isa<ObjectType>()) {
+ return emitOpError() << "making a list with unboxed storage class "
+ "must not have object elements. Got: "
+ << element.getType();
}
- break;
- case CollectionStorageClass::Empty:
- if (!elements().empty()) {
- return emitOpError()
- << "making a list with empty storage class must have zero "
- "elements";
- }
- break;
+ }
+ break;
+ case CollectionStorageClass::Empty:
+ if (!elements().empty()) {
+ return emitOpError()
+ << "making a list with empty storage class must have zero "
+ "elements";
+ }
+ break;
}
return success();
}
@@ -617,8 +629,8 @@
// PatternMatchCallOp
//===----------------------------------------------------------------------===//
-LogicalResult PatternMatchCallOp::verifySymbolUses(
- SymbolTableCollection &symbolTable) {
+LogicalResult
+PatternMatchCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
for (auto symbolAttr : symbols) {
auto symbol = symbolAttr.cast<FlatSymbolRefAttr>();
@@ -634,13 +646,15 @@
if (!genericsAttr)
return emitOpError(
"requires a 'generic_match' array of symbol reference attributes");
- if (failed(verifySymbols(genericsAttr))) return failure();
+ if (failed(verifySymbols(genericsAttr)))
+ return failure();
auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
if (!specificsAttr)
return emitOpError(
"requires a 'specific_match' array of symbol reference attributes");
- if (failed(verifySymbols(specificsAttr))) return failure();
+ if (failed(verifySymbols(specificsAttr)))
+ return failure();
return success();
}
@@ -650,7 +664,8 @@
//===----------------------------------------------------------------------===//
OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[0]) return {};
+ if (!operands[0])
+ return {};
Builder b(getContext());
Attribute fromAttr = operands[0];
@@ -703,7 +718,8 @@
//===----------------------------------------------------------------------===//
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[0]) return {};
+ if (!operands[0])
+ return {};
BoolAttr boolAttr = operands[0].cast<BoolAttr>();
if (boolAttr.getValue())
@@ -783,8 +799,8 @@
// DynamicCallOp
//===----------------------------------------------------------------------===//
-LogicalResult DynamicCallOp::verifySymbolUses(
- SymbolTableCollection &symbolTable) {
+LogicalResult
+DynamicCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the callee attribute was specified.
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
index fc17764..b6e5b51 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
@@ -98,13 +98,14 @@
}
}
- if (!modified) return ft;
+ if (!modified)
+ return ft;
return FunctionType::get(ft.getContext(), inputs, results);
}
};
-} // namespace
+} // namespace
std::unique_ptr<OperationPass<>> PYDM::createFixateWeakNumericPass() {
return std::make_unique<FixateWeakNumericPass>();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
index 0836591..e2edfff 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
@@ -8,9 +8,9 @@
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "iree-dialects/Dialect/PyDM/Utils/TypeInference.h"
-#include "llvm/Support/Debug.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
@@ -59,10 +59,13 @@
return signalPassFailure();
}
changed = false;
- if (sinkStaticInfoCasts()) changed = true;
- if (refineResultTypes()) changed = true;
+ if (sinkStaticInfoCasts())
+ changed = true;
+ if (refineResultTypes())
+ changed = true;
permuteRefinedBlocks(propagator);
- if (!changed) break;
+ if (!changed)
+ break;
}
// Now that iteration is complete and we are no longer using the
@@ -142,7 +145,8 @@
Operation *refinableOp = refinable.getOperation();
SmallVector<Type> originalResultTypes(refinableOp->getResultTypes());
LLVM_DEBUG(dbgs() << " refineResultTypes: " << *refinableOp << "\n");
- if (!refinable.refineResultTypes()) return;
+ if (!refinable.refineResultTypes())
+ return;
LLVM_DEBUG(dbgs() << " refineResultTypes changed results: "
<< *refinableOp << "\n");
OpBuilder builder(refinableOp);
@@ -152,7 +156,8 @@
Type origType = std::get<0>(it);
OpResult result = std::get<1>(it);
Type newType = result.getType();
- if (origType == newType) continue;
+ if (origType == newType)
+ continue;
// Insert a static info cast.
// In the future, we could further query the use for refinable
// support and skip creating the op.
@@ -173,7 +178,8 @@
}
auto casted = builder.create<StaticInfoCastOp>(refinableOp->getLoc(),
origType, newResult);
- if (!replaceExcept) replaceExcept = casted;
+ if (!replaceExcept)
+ replaceExcept = casted;
result.replaceAllUsesExcept(casted, replaceExcept);
changed = true;
}
@@ -195,7 +201,8 @@
for (auto *block : blocks) {
auto mismatchedPredecessors =
propagator.findMismatchedBlockPredecessors(block);
- if (mismatchedPredecessors.empty()) continue;
+ if (mismatchedPredecessors.empty())
+ continue;
LLVM_DEBUG(dbgs() << " ++ Processing block " << block << " ("
<< mismatchedPredecessors.size()
<< " mismatched predecessors)\n");
@@ -244,7 +251,7 @@
}
};
-} // namespace
+} // namespace
std::unique_ptr<OperationPass<PYDM::FuncOp>>
PYDM::createLocalPropagateTypesPass() {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
index 057ffef..7c1c4a4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
@@ -7,12 +7,12 @@
#include "../PassDetail.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/Support/Debug.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
@@ -53,7 +53,8 @@
changed = false;
for (auto &block : getOperation().getBody().getBlocks()) {
auto &info = blockAccessInfos[&block];
- if (canonicalizeBlockVariableAccess(block, info)) changed = true;
+ if (canonicalizeBlockVariableAccess(block, info))
+ changed = true;
hoistLoadsFromBlock(block, info);
// Invalidate internal value map and re-initialize from block arg
@@ -62,7 +63,8 @@
info.variableValueMap = info.blockArgVariableValueMap;
}
- if (!changed) break;
+ if (!changed)
+ break;
}
// We should now have eliminated as many loads as possible, so we can
@@ -167,7 +169,8 @@
// legal form where all allocs are done in the entry block.
void hoistLoadsFromBlock(Block &block, BlockAccessInfo &info) {
// Entry block: nowhere to hoist.
- if (block.isEntryBlock()) return;
+ if (block.isEntryBlock())
+ return;
SmallVector<std::tuple<Location, Value, Type>> loadVarTypes;
// Redirect each live load to a block argument.
@@ -211,7 +214,7 @@
}
};
-} // namespace
+} // namespace
std::unique_ptr<OperationPass<PYDM::FuncOp>> PYDM::createVariablesToSSAPass() {
return std::make_unique<VariablesToSSAPass>();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
index 9fbfc52..5b23dc4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
@@ -24,9 +24,9 @@
#define GEN_PASS_CLASSES
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
-#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
index 820338c..515fc26 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
@@ -14,15 +14,15 @@
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
using namespace PYDM;
-void PYDM::buildPostImportPassPipeline(OpPassManager& passManager) {
+void PYDM::buildPostImportPassPipeline(OpPassManager &passManager) {
passManager.addNestedPass<PYDM::FuncOp>(createVariablesToSSAPass());
passManager.addNestedPass<PYDM::FuncOp>(createLocalPropagateTypesPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());
}
-void PYDM::buildLowerToIREEPassPipeline(OpPassManager& passManager,
- const LowerToIREEOptions& options) {
+void PYDM::buildLowerToIREEPassPipeline(OpPassManager &passManager,
+ const LowerToIREEOptions &options) {
// TODO: Needs to be iterative, support optimization passes, etc.
passManager.addPass(createLowerIREEPyDMToRTLPass());
if (options.linkRtlSource) {
@@ -45,22 +45,22 @@
namespace {
#define GEN_PASS_REGISTRATION
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
-} // namespace
-} // namespace PYDM_generated
+} // namespace
+} // namespace PYDM_generated
void PYDM::registerPasses() {
PYDM_generated::registerPasses();
PassPipelineRegistration<> postImportPassPipeline(
"pydm-post-import-pipeline",
"Runs passes to cleanup PyDM immediately post-import",
- [](OpPassManager& passManager) {
+ [](OpPassManager &passManager) {
buildPostImportPassPipeline(passManager);
});
PassPipelineRegistration<> lowerToIREEPipeline(
"pydm-lower-to-iree-pipeline",
"Runs passes to lower PyDM to IREE's input dialects",
- [](OpPassManager& passManager) {
+ [](OpPassManager &passManager) {
LowerToIREEOptions options;
buildLowerToIREEPassPipeline(passManager, options);
});
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
index 82246a4..7da6c8d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
@@ -10,12 +10,12 @@
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/Debug.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Parser/Parser.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "iree_pydm"
@@ -34,12 +34,12 @@
namespace {
class LinkIREEPyDMRTLPass : public LinkIREEPyDMRTLBase<LinkIREEPyDMRTLPass> {
- public:
+public:
LinkIREEPyDMRTLPass() = default;
LinkIREEPyDMRTLPass(Optional<SourceBundle> linkRtlSourceBundle)
: linkRtlSourceBundle(std::move(linkRtlSourceBundle)) {}
- private:
+private:
LogicalResult initialize(MLIRContext *context) override {
SourceBundle localSource;
if (linkRtlSourceBundle) {
@@ -52,13 +52,15 @@
if (localSource.asmBlob) {
// Parse from inline asm.
auto owningOp = parseSourceString(*localSource.asmBlob, context);
- if (!owningOp) return failure();
+ if (!owningOp)
+ return failure();
rtlModule = std::make_shared<mlir::OwningOpRef<mlir::ModuleOp>>(
std::move(owningOp));
} else if (localSource.asmFilePath) {
// Parse from a file.
auto owningOp = parseSourceFile(*localSource.asmFilePath, context);
- if (!owningOp) return failure();
+ if (!owningOp)
+ return failure();
rtlModule = std::make_shared<mlir::OwningOpRef<mlir::ModuleOp>>(
std::move(owningOp));
} else {
@@ -146,7 +148,8 @@
LLVM_DEBUG(llvm::dbgs() << "+++ Inlining module\n";);
auto result = importModule.getOp()->walk<WalkOrder::PreOrder>(
[&](Operation *importOp) -> WalkResult {
- if (importOp == importModule.getOp()) return WalkResult::advance();
+ if (importOp == importModule.getOp())
+ return WalkResult::advance();
if (auto symbolImportOp = dyn_cast<SymbolOpInterface>(importOp)) {
StringAttr name = symbolImportOp.getNameAttr();
Operation *existing = programSymbolTable.lookup(name);
@@ -172,7 +175,8 @@
}
return WalkResult::skip();
});
- if (result.wasInterrupted()) return failure();
+ if (result.wasInterrupted())
+ return failure();
LLVM_DEBUG(llvm::dbgs() << "--- Inlining complete\n";);
return success();
}
@@ -209,9 +213,9 @@
Optional<SourceBundle> linkRtlSourceBundle;
};
-} // namespace
+} // namespace
-std::unique_ptr<OperationPass<ModuleOp>> PYDM::createLinkIREEPyDMRTLPass(
- Optional<SourceBundle> linkRtlSourceBundle) {
+std::unique_ptr<OperationPass<ModuleOp>>
+PYDM::createLinkIREEPyDMRTLPass(Optional<SourceBundle> linkRtlSourceBundle) {
return std::make_unique<LinkIREEPyDMRTLPass>(std::move(linkRtlSourceBundle));
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
index aee2bf2..c348715 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
@@ -15,7 +15,8 @@
LinkageAnalysis::LinkageAnalysis(Operation *moduleOp) {
moduleOp->walk<WalkOrder::PreOrder>([&](PYDM::FuncOp f) {
- if (f.empty()) externFuncOps.push_back(f);
+ if (f.empty())
+ externFuncOps.push_back(f);
// We don't need to descend into functions so just skip them.
return WalkResult::skip();
});
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
index b8aa35e..8ae2afd 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
@@ -20,7 +20,7 @@
namespace {
class RtlFunc {
- protected:
+protected:
FunctionType makeRaisingSignature(Builder b, ArrayRef<Type> inputs,
Type output) {
return b.getType<FunctionType>(
@@ -33,7 +33,8 @@
OpBuilder builder(symbolTable.getOp()->getContext());
auto name = builder.getStringAttr(rtlFunc.getRtlName());
auto *existing = symbolTable.lookup(name);
- if (existing) return existing;
+ if (existing)
+ return existing;
// Does not exist - create detached and insert.
FunctionType signature = rtlFunc.getRtlSignature(builder);
@@ -66,7 +67,7 @@
/// pydmrtl$apply_binary_${dunderName} RTL func.
class ApplyBinaryFunc : public RtlFunc {
- public:
+public:
ApplyBinaryFunc(StringRef dunderName) : rtlName("pydmrtl$apply_binary_") {
rtlName.append(dunderName.begin(), dunderName.end());
}
@@ -76,13 +77,13 @@
return makeRaisingSignature(b, {objectType, objectType}, objectType);
}
- private:
+private:
std::string rtlName;
};
/// pydmrtl$apply_compare_${dunderName} RTL func.
class ApplyCompareFunc : public RtlFunc {
- public:
+public:
ApplyCompareFunc(StringRef dunderName) : rtlName("pydmrtl$apply_compare_") {
rtlName.append(dunderName.begin(), dunderName.end());
}
@@ -93,19 +94,19 @@
return makeRaisingSignature(b, {objectType, objectType}, boolType);
}
- private:
+private:
std::string rtlName;
};
template <typename RtlFuncTy, typename OpTy>
class EmitImportCallBase : public OpRewritePattern<OpTy> {
- public:
+public:
EmitImportCallBase(SymbolTable &symbolTable, PatternBenefit benefit = 1)
: OpRewritePattern<OpTy>::OpRewritePattern(
symbolTable.getOp()->getContext(), benefit),
symbolTable(symbolTable) {}
- protected:
+protected:
Value emitImportCall(Location loc, ValueRange inputs, RtlFuncTy rtlFunc,
PatternRewriter &rewriter) const {
auto rtlName = rtlFunc.getRtlName();
@@ -152,7 +153,7 @@
}
}
- private:
+private:
SymbolTable &symbolTable;
};
@@ -244,7 +245,7 @@
}
};
-} // namespace
+} // namespace
std::unique_ptr<OperationPass<ModuleOp>> PYDM::createLowerIREEPyDMToRTLPass() {
return std::make_unique<LowerIREEPyDMToRTLPass>();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
index cd5e652..8efe7ae 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -69,7 +69,7 @@
}
};
-} // namespace
+} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
PYDM::createConvertIREEPyDMToIREEPass() {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 9530171..9d92711 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -7,7 +7,6 @@
#include "iree-dialects/Dialect/Input/InputOps.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
@@ -15,6 +14,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
using llvm::enumerate;
using namespace mlir;
@@ -38,7 +38,7 @@
UnboundLocalError = -10,
};
-} // namespace
+} // namespace
static Type getVariantListType(Builder &builder) {
return builder.getType<Input::ListType>(
@@ -117,8 +117,9 @@
}
}
-static Optional<arith::CmpIPredicate> convertIntegerComparePredicate(
- StringAttr dunderName, bool isSigned, Builder &builder) {
+static Optional<arith::CmpIPredicate>
+convertIntegerComparePredicate(StringAttr dunderName, bool isSigned,
+ Builder &builder) {
StringRef v = dunderName.getValue();
if (v == "lt") {
return isSigned ? arith::CmpIPredicate::slt : arith::CmpIPredicate::ult;
@@ -137,8 +138,8 @@
return {};
}
-static Optional<arith::CmpFPredicate> convertFpComparePredicate(
- StringAttr dunderName, Builder &builder) {
+static Optional<arith::CmpFPredicate>
+convertFpComparePredicate(StringAttr dunderName, Builder &builder) {
StringRef v = dunderName.getValue();
if (v == "lt") {
return arith::CmpFPredicate::OLT;
@@ -164,9 +165,11 @@
/// Returns nullptr for unsupported cases, not emitting diagnostics.
static Value boxConvertedValue(Location loc, Type pythonType,
Value convertedValue, OpBuilder &builder) {
- if (pythonType.isa<PYDM::ObjectType>()) return convertedValue;
+ if (pythonType.isa<PYDM::ObjectType>())
+ return convertedValue;
auto ptiType = pythonType.dyn_cast<PYDM::PythonTypeInterface>();
- if (!ptiType) return {};
+ if (!ptiType)
+ return {};
auto typeCode = ptiType.getTypeCode();
auto list = createObjectList(loc, builder, static_cast<int>(typeCode),
convertedValue);
@@ -179,9 +182,9 @@
: public OpConversionPattern<PYDM::AllocFreeVarOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::AllocFreeVarOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::AllocFreeVarOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// TODO: We may want to initialize the list structurally in some way.
// This will fail either way on read from unassigned variable, but we need
// to see what works better for good runtime error messages.
@@ -196,9 +199,9 @@
: public OpConversionPattern<PYDM::ApplyBinaryOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::ApplyBinaryOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::ApplyBinaryOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Type pyLeftType = srcOp.left().getType();
Type pyRightType = srcOp.right().getType();
Type leftType = adaptor.left().getType();
@@ -276,16 +279,16 @@
: public OpConversionPattern<PYDM::ApplyCompareOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::ApplyCompareOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::ApplyCompareOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Type leftType = adaptor.left().getType();
Type rightType = adaptor.right().getType();
if (leftType != rightType) {
return rewriter.notifyMatchFailure(srcOp, "not same type operands");
}
if (leftType.isa<mlir::IntegerType>()) {
- bool isSigned = true; // TODO: Unsigned.
+ bool isSigned = true; // TODO: Unsigned.
auto predicate = convertIntegerComparePredicate(adaptor.dunder_nameAttr(),
isSigned, rewriter);
if (!predicate)
@@ -310,9 +313,9 @@
class AssignSubscriptListConversion
: public OpConversionPattern<PYDM::AssignSubscriptOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::AssignSubscriptOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::AssignSubscriptOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto pySequence = srcOp.lhs();
if (!pySequence.getType().isa<PYDM::ListType>())
return rewriter.notifyMatchFailure(srcOp, "not builtin sequence");
@@ -411,14 +414,14 @@
Value boxIfNecessary(Location loc, PYDM::ListType listType, Type origRhsType,
Value rhs, ConversionPatternRewriter &rewriter) const {
switch (listType.getStorageClass()) {
- case CollectionStorageClass::Boxed:
- case CollectionStorageClass::Empty: {
- return boxConvertedValue(loc, origRhsType, rhs, rewriter);
- break;
- }
- case CollectionStorageClass::Unboxed:
- // TODO: Implement.
- return nullptr;
+ case CollectionStorageClass::Boxed:
+ case CollectionStorageClass::Empty: {
+ return boxConvertedValue(loc, origRhsType, rhs, rewriter);
+ break;
+ }
+ case CollectionStorageClass::Unboxed:
+ // TODO: Implement.
+ return nullptr;
}
}
};
@@ -426,9 +429,9 @@
class BoolToPredConversion : public OpConversionPattern<PYDM::BoolToPredOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::BoolToPredOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::BoolToPredOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(srcOp, adaptor.value());
return success();
}
@@ -437,9 +440,9 @@
class BoxOpConversion : public OpConversionPattern<PYDM::BoxOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::BoxOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::BoxOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
Value boxedValue = boxConvertedValue(loc, srcOp.primitive().getType(),
adaptor.primitive(), rewriter);
@@ -454,9 +457,9 @@
class CallOpConversion : public OpConversionPattern<PYDM::CallOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::CallOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::CallOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(getTypeConverter()->convertTypes(srcOp.getResultTypes(),
resultTypes))) {
@@ -472,9 +475,9 @@
class ConstantOpConversion : public OpConversionPattern<PYDM::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::ConstantOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::ConstantOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
Type resultType = typeConverter->convertType(srcOp.getResult().getType());
if (!resultType)
@@ -522,9 +525,9 @@
: public OpConversionPattern<PYDM::DynamicUnpackOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::DynamicUnpackOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::DynamicUnpackOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
// Convert types.
Type excResultType =
@@ -604,9 +607,9 @@
class ElideStaticInfoCast : public OpConversionPattern<PYDM::StaticInfoCastOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::StaticInfoCastOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::StaticInfoCastOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(srcOp, srcOp.value());
return success();
}
@@ -617,9 +620,9 @@
class FailureOpConversion : public OpConversionPattern<PYDM::FailureOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::FailureOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::FailureOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Type i32 = rewriter.getI32Type();
// '-3' == RuntimeError
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
@@ -631,9 +634,9 @@
class FuncOpConversion : public OpConversionPattern<PYDM::FuncOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::FuncOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::FuncOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
FunctionType srcFuncType = srcOp.getType();
TypeConverter::SignatureConversion signatureConversion(
srcOp.getNumArguments());
@@ -679,9 +682,9 @@
class GetTypeCodeConversion : public OpConversionPattern<PYDM::GetTypeCodeOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::GetTypeCodeOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::GetTypeCodeOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
// Gets the 0'th element of the object list, optionally casting it to the
// converted integer type.
@@ -694,10 +697,10 @@
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value typeCode = rewriter.create<Input::ListGetOp>(loc, i32Type,
adaptor.value(), index0);
- rewriter.replaceOp(
- srcOp,
- castIntegerValue(loc, typeCode, resultType.cast<mlir::IntegerType>(),
- rewriter));
+ rewriter.replaceOp(srcOp,
+ castIntegerValue(loc, typeCode,
+ resultType.cast<mlir::IntegerType>(),
+ rewriter));
return success();
}
};
@@ -705,9 +708,9 @@
class LoadVarOpConversion : public OpConversionPattern<PYDM::LoadVarOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::LoadVarOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::LoadVarOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto resultType =
getTypeConverter()->convertType(srcOp.getResult().getType());
@@ -726,9 +729,9 @@
class MakeListOpBoxedConversion : public OpConversionPattern<PYDM::MakeListOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::MakeListOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::MakeListOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto listType = srcOp.list().getType().cast<PYDM::ListType>();
if (listType.getStorageClass() != CollectionStorageClass::Boxed ||
@@ -759,9 +762,9 @@
class MakeTupleOpConversion : public OpConversionPattern<PYDM::MakeTupleOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::MakeTupleOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::MakeTupleOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto resultType = getTypeConverter()->convertType(srcOp.tuple().getType());
if (!resultType)
@@ -788,9 +791,9 @@
/// Converts the `neg` op on integer operand/result to a corresponding sub.
class NegIntegerOpConversion : public OpConversionPattern<PYDM::NegOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::NegOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::NegOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Type valueType = adaptor.value().getType();
Type resultType = getTypeConverter()->convertType(srcOp.result().getType());
if (!valueType.isa<mlir::IntegerType>() || valueType != resultType)
@@ -808,9 +811,9 @@
class NoneOpConversion : public OpConversionPattern<PYDM::NoneOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::NoneOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::NoneOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Type i32 = rewriter.getI32Type();
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
srcOp, i32, rewriter.getIntegerAttr(i32, 0));
@@ -825,9 +828,9 @@
: public OpConversionPattern<PYDM::RaiseOnFailureOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::RaiseOnFailureOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::RaiseOnFailureOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
Value status = adaptor.getOperands()[0];
@@ -866,9 +869,9 @@
class ReturnOpConversion : public OpConversionPattern<PYDM::ReturnOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::ReturnOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::ReturnOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto zeroResult =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
@@ -883,12 +886,14 @@
: public OpConversionPattern<PYDM::SequenceCloneOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::SequenceCloneOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::SequenceCloneOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
Type origListType = srcOp.sequence().getType();
- if (!isSupportedList(origListType)) return failure();
- if (origListType != srcOp.getResult().getType()) return failure();
+ if (!isSupportedList(origListType))
+ return failure();
+ if (origListType != srcOp.getResult().getType())
+ return failure();
Type resultType = typeConverter->convertType(srcOp.getResult().getType());
if (!resultType) {
return rewriter.notifyMatchFailure(srcOp, "cannot convert result type");
@@ -1016,9 +1021,9 @@
class StoreVarOpConversion : public OpConversionPattern<PYDM::StoreVarOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::StoreVarOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::StoreVarOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto origStoreType =
@@ -1043,9 +1048,9 @@
: public OpConversionPattern<PYDM::SubscriptOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::SubscriptOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::SubscriptOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto pySequence = srcOp.value();
if (!pySequence.getType().isa<PYDM::ListType, PYDM::TupleType>())
return rewriter.notifyMatchFailure(srcOp, "not builtin sequence");
@@ -1149,9 +1154,9 @@
class UnboxOpConversion : public OpConversionPattern<PYDM::UnboxOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- PYDM::UnboxOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(PYDM::UnboxOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto list = adaptor.getOperands()[0];
@@ -1234,9 +1239,9 @@
class BuiltinBranchConversion : public OpConversionPattern<mlir::cf::BranchOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- mlir::cf::BranchOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(mlir::cf::BranchOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(srcOp, srcOp.getDest(),
adaptor.getDestOperands());
return success();
@@ -1246,9 +1251,9 @@
class BuiltinCondBranchConversion
: public OpConversionPattern<mlir::cf::CondBranchOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- mlir::cf::CondBranchOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(mlir::cf::CondBranchOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
srcOp, adaptor.getCondition(), srcOp.getTrueDest(),
adaptor.getTrueDestOperands(), srcOp.getFalseDest(),
@@ -1260,9 +1265,9 @@
class BuiltinSelectConversion
: public OpConversionPattern<mlir::arith::SelectOp> {
using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- mlir::arith::SelectOp srcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(mlir::arith::SelectOp srcOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::arith::SelectOp>(
srcOp, adaptor.getCondition(), adaptor.getTrueValue(),
adaptor.getFalseValue());
@@ -1270,7 +1275,7 @@
}
};
-} // namespace
+} // namespace
void PYDM::populatePyDMToIREELoweringPatterns(MLIRContext *context,
TypeConverter &typeConverter,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
index a16fe16..c618545 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -96,10 +96,10 @@
Type LoweringTypeConverter::getWeakFloatType(Builder b) const {
switch (weakFloatType) {
- case WeakFloatType::F32:
- return b.getF32Type();
- case WeakFloatType::F64:
- return b.getF64Type();
+ case WeakFloatType::F32:
+ return b.getF32Type();
+ case WeakFloatType::F64:
+ return b.getF64Type();
}
}
@@ -110,7 +110,8 @@
bool LoweringTypeConverter::areTypesLegal(TypeRange types) const {
for (Type t : types) {
- if (!isTypeLegal(t)) return false;
+ if (!isTypeLegal(t))
+ return false;
}
return true;
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
index c7212e2..2acd110 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
@@ -49,7 +49,8 @@
FunctionType signature) {
for (PermutedBlockInfo *info = parentInfo->permutationHead; info;
info = info->next) {
- if (info->signature == signature) return info->permutedBlock;
+ if (info->signature == signature)
+ return info->permutedBlock;
}
return nullptr;
}
@@ -57,7 +58,8 @@
static bool checkAllBlockArgsMapped(Block *block,
BlockAndValueMapping &mapping) {
for (Value arg : block->getArguments()) {
- if (!mapping.contains(arg)) return false;
+ if (!mapping.contains(arg))
+ return false;
}
return true;
}
@@ -95,7 +97,8 @@
auto branchOp = llvm::cast<BranchOpInterface>(terminator);
unsigned successorIndex = 0;
for (Block *successor : terminator->getSuccessors()) {
- if (successor == block) break;
+ if (successor == block)
+ break;
successorIndex += 1;
}
auto successorOperands = branchOp.getSuccessorOperands(successorIndex);
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..08d39ff
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/CMakeLists.txt
@@ -0,0 +1,18 @@
+
+add_mlir_library(IREEDialectsTransforms
+ Listener.cpp
+ ListenerCSE.cpp
+ ListenerGreedyPatternRewriteDriver.cpp
+
+ LINK_LIBS PRIVATE
+ # TODO: break dialect dependency by implementing the transformation separately
+ # and registering it.
+ MLIRAsync
+ MLIRLinalg
+ MLIRLinalgTransforms
+
+ DEPENDS
+ mlir-headers
+ IREELinalgExtIncGen
+ IREELinalgExtInterfacesIncGen
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp
new file mode 100644
index 0000000..5120993
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/Listener.cpp
@@ -0,0 +1,48 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Transforms/Listener.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// RewriteListener
+//===----------------------------------------------------------------------===//
+
+RewriteListener::~RewriteListener() = default;
+
+//===----------------------------------------------------------------------===//
+// ListenerList
+//===----------------------------------------------------------------------===//
+
+void ListenerList::notifyOperationInserted(Operation *op) {
+ for (RewriteListener *listener : listeners)
+ listener->notifyOperationInserted(op);
+}
+
+void ListenerList::notifyBlockCreated(Block *block) {
+ for (RewriteListener *listener : listeners)
+ listener->notifyBlockCreated(block);
+}
+
+void ListenerList::notifyOperationReplaced(Operation *op,
+ ValueRange newValues) {
+ for (RewriteListener *listener : listeners)
+ listener->notifyOperationReplaced(op, newValues);
+}
+
+void ListenerList::notifyOperationRemoved(Operation *op) {
+ for (RewriteListener *listener : listeners)
+ listener->notifyOperationRemoved(op);
+}
+
+void ListenerList::notifyMatchFailure(
+ Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
+ for (RewriteListener *listener : listeners)
+ listener->notifyMatchFailure(op, reasonCallback);
+}
+
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
new file mode 100644
index 0000000..1ea8eb9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerCSE.cpp
@@ -0,0 +1,309 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Transforms/ListenerCSE.h"
+
+#include <deque>
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/RecyclingAllocator.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// BEGIN copied from mlir/lib/Transforms/CSE.cpp
+//===----------------------------------------------------------------------===//
+namespace {
+struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
+ static unsigned getHashValue(const Operation *opC) {
+ return OperationEquivalence::computeHash(
+ const_cast<Operation *>(opC),
+ /*hashOperands=*/OperationEquivalence::directHashValue,
+ /*hashResults=*/OperationEquivalence::ignoreHashValue,
+ OperationEquivalence::IgnoreLocations);
+ }
+ static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
+ auto *lhs = const_cast<Operation *>(lhsC);
+ auto *rhs = const_cast<Operation *>(rhsC);
+ if (lhs == rhs)
+ return true;
+ if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
+ rhs == getTombstoneKey() || rhs == getEmptyKey())
+ return false;
+ return OperationEquivalence::isEquivalentTo(
+ const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
+ /*mapOperands=*/OperationEquivalence::exactValueMatch,
+ /*mapResults=*/OperationEquivalence::ignoreValueEquivalence,
+ OperationEquivalence::IgnoreLocations);
+ }
+};
+} // namespace
+
+namespace {
+/// Simple common sub-expression elimination.
+struct CSE {
+ /// Shared implementation of operation elimination and scoped map definitions.
+ using AllocatorTy = llvm::RecyclingAllocator<
+ llvm::BumpPtrAllocator,
+ llvm::ScopedHashTableVal<Operation *, Operation *>>;
+ using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
+ SimpleOperationInfo, AllocatorTy>;
+
+ //===----------------------------------------------------------------------===//
+ // END copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+ CSE(DominanceInfo *domInfo, RewriteListener *listener)
+ : domInfo(domInfo), listener(listener) {}
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+
+ /// Represents a single entry in the depth first traversal of a CFG.
+ struct CFGStackNode {
+ CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
+ : scope(knownValues), node(node), childIterator(node->begin()),
+ processed(false) {}
+
+ /// Scope for the known values.
+ ScopedMapTy::ScopeTy scope;
+
+ DominanceInfoNode *node;
+ DominanceInfoNode::const_iterator childIterator;
+
+ /// If this node has been fully processed yet or not.
+ bool processed;
+ };
+
+ /// Attempt to eliminate a redundant operation. Returns success if the
+ /// operation was marked for removal, failure otherwise.
+ LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+ bool hasSSADominance);
+ void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
+ void simplifyRegion(ScopedMapTy &knownValues, Region ®ion);
+ /// Return the number of erased operations.
+ unsigned simplify(Operation *rootOp);
+
+private:
+ /// Operations marked as dead and to be erased.
+ std::vector<Operation *> opsToErase;
+
+ /// The dominance info to use.
+ DominanceInfo *domInfo;
+ //===----------------------------------------------------------------------===//
+ // END copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+ /// An optional listener to notify of replaced or erased operations.
+ RewriteListener *listener;
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+};
+
+} // namespace
+
+/// Attempt to eliminate a redundant operation.
+LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+ bool hasSSADominance) {
+ // Don't simplify terminator operations.
+ if (op->hasTrait<OpTrait::IsTerminator>())
+ return failure();
+
+ // If the operation is already trivially dead just add it to the erase list.
+ if (isOpTriviallyDead(op)) {
+ opsToErase.push_back(op);
+ return success();
+ }
+
+ // Don't simplify operations with nested blocks. We don't currently model
+ // equality comparisons correctly among other things. It is also unclear
+ // whether we would want to CSE such operations.
+ if (op->getNumRegions() != 0)
+ return failure();
+
+ // TODO: We currently only eliminate non side-effecting
+ // operations.
+ if (!MemoryEffectOpInterface::hasNoEffect(op))
+ return failure();
+
+ // Look for an existing definition for the operation.
+ if (auto *existing = knownValues.lookup(op)) {
+ // If we find one then replace all uses of the current operation with the
+ // existing one and mark it for deletion. We can only replace an operand in
+ // an operation if it has not been visited yet.
+ if (hasSSADominance) {
+ // If the region has SSA dominance, then we are guaranteed to have not
+ // visited any use of the current operation.
+ //===----------------------------------------------------------------------===//
+ // END copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+ if (listener)
+ listener->notifyOperationReplaced(op, existing->getResults());
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+ op->replaceAllUsesWith(existing);
+ opsToErase.push_back(op);
+ } else {
+ // When the region does not have SSA dominance, we need to check if we
+ // have visited a use before replacing any use.
+ for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
+ std::get<0>(it).replaceUsesWithIf(
+ std::get<1>(it), [&](OpOperand &operand) {
+ return !knownValues.count(operand.getOwner());
+ });
+ }
+
+ // There may be some remaining uses of the operation.
+ if (op->use_empty())
+ opsToErase.push_back(op);
+ }
+
+ // If the existing operation has an unknown location and the current
+ // operation doesn't, then set the existing op's location to that of the
+ // current op.
+ if (existing->getLoc().isa<UnknownLoc>() &&
+ !op->getLoc().isa<UnknownLoc>()) {
+ existing->setLoc(op->getLoc());
+ }
+
+ return success();
+ }
+
+ // Otherwise, we add this operation to the known values map.
+ knownValues.insert(op, op);
+ return failure();
+}
+
+void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+ bool hasSSADominance) {
+ for (auto &op : *bb) {
+ // If the operation is simplified, we don't process any held regions.
+ if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
+ continue;
+
+ // Most operations don't have regions, so fast path that case.
+ if (op.getNumRegions() == 0)
+ continue;
+
+ // If this operation is isolated above, we can't process nested regions with
+ // the given 'knownValues' map. This would cause the insertion of implicit
+ // captures in explicit capture only regions.
+ if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+ ScopedMapTy nestedKnownValues;
+ for (auto ®ion : op.getRegions())
+ simplifyRegion(nestedKnownValues, region);
+ continue;
+ }
+
+ // Otherwise, process nested regions normally.
+ for (auto ®ion : op.getRegions())
+ simplifyRegion(knownValues, region);
+ }
+}
+
+void CSE::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) {
+ // If the region is empty there is nothing to do.
+ if (region.empty())
+ return;
+
+ bool hasSSADominance = domInfo->hasSSADominance(®ion);
+
+ // If the region only contains one block, then simplify it directly.
+ if (region.hasOneBlock()) {
+ ScopedMapTy::ScopeTy scope(knownValues);
+ simplifyBlock(knownValues, ®ion.front(), hasSSADominance);
+ return;
+ }
+
+ // If the region does not have dominanceInfo, then skip it.
+ // TODO: Regions without SSA dominance should define a different
+ // traversal order which is appropriate and can be used here.
+ if (!hasSSADominance)
+ return;
+
+ // Note, deque is being used here because there was significant performance
+ // gains over vector when the container becomes very large due to the
+ // specific access patterns. If/when these performance issues are no
+ // longer a problem we can change this to vector. For more information see
+ // the llvm mailing list discussion on this:
+ // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
+ std::deque<std::unique_ptr<CFGStackNode>> stack;
+
+ // Process the nodes of the dom tree for this region.
+ stack.emplace_back(std::make_unique<CFGStackNode>(
+ knownValues, domInfo->getRootNode(®ion)));
+
+ while (!stack.empty()) {
+ auto ¤tNode = stack.back();
+
+ // Check to see if we need to process this node.
+ if (!currentNode->processed) {
+ currentNode->processed = true;
+ simplifyBlock(knownValues, currentNode->node->getBlock(),
+ hasSSADominance);
+ }
+
+ // Otherwise, check to see if we need to process a child node.
+ if (currentNode->childIterator != currentNode->node->end()) {
+ auto *childNode = *(currentNode->childIterator++);
+ stack.emplace_back(
+ std::make_unique<CFGStackNode>(knownValues, childNode));
+ } else {
+ // Finally, if the node and all of its children have been processed
+ // then we delete the node.
+ stack.pop_back();
+ }
+ }
+}
+
+unsigned CSE::simplify(Operation *rootOp) {
+ /// A scoped hash table of defining operations within a region.
+ ScopedMapTy knownValues;
+
+ for (auto ®ion : rootOp->getRegions())
+ simplifyRegion(knownValues, region);
+
+ /// Erase any operations that were marked as dead during simplification.
+ for (auto *op : opsToErase) {
+ //===----------------------------------------------------------------------===//
+ // END copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+ if (listener)
+ listener->notifyOperationRemoved(op);
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/CSE.cpp
+ //===----------------------------------------------------------------------===//
+ op->erase();
+ }
+
+ return opsToErase.size();
+}
+
+//===----------------------------------------------------------------------===//
+// END copied from mlir/lib/Transforms/CSE.cpp
+//===----------------------------------------------------------------------===//
+
+/// Run CSE on the provided operation
+LogicalResult mlir::eliminateCommonSubexpressions(Operation *op,
+ DominanceInfo *domInfo,
+ RewriteListener *listener) {
+ assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "can only do CSE on isolated-from-above ops");
+
+ Optional<DominanceInfo> defaultDomInfo;
+ if (domInfo == nullptr) {
+ defaultDomInfo.emplace(op);
+ domInfo = &*defaultDomInfo;
+ }
+
+ CSE cse(domInfo, listener);
+ cse.simplify(op);
+ return success();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp
new file mode 100644
index 0000000..f3a2a96
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/ListenerGreedyPatternRewriteDriver.cpp
@@ -0,0 +1,436 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
+
+#include "iree-dialects/Transforms/Listener.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "listener-greedy-rewriter"
+
+//===----------------------------------------------------------------------===//
+// GreedyPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
+/// applies the locally optimal patterns in a roughly "bottom up" way.
+class GreedyPatternRewriteDriver : public RewriteListener {
+public:
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+ explicit GreedyPatternRewriteDriver(
+ MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config,
+ //===----------------------------------------------------------------------===//
+ // END copied from
+ // mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+ RewriteListener *listener);
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+
+ /// Simplify the operations within the given regions.
+ bool simplify(MutableArrayRef<Region> regions);
+
+ /// Add the given operation to the worklist.
+ void addToWorklist(Operation *op);
+
+ /// Pop the next operation from the worklist.
+ Operation *popFromWorklist();
+
+ /// If the specified operation is in the worklist, remove it.
+ void removeFromWorklist(Operation *op);
+
+protected:
+ // Implement the hook for inserting operations, and make sure that newly
+ // inserted ops are added to the worklist for processing.
+ void notifyOperationInserted(Operation *op) override;
+
+ // Look over the provided operands for any defining operations that should
+ // be re-added to the worklist. This function should be called when an
+ // operation is modified or removed, as it may trigger further
+ // simplifications.
+ template <typename Operands>
+ void addToWorklist(Operands &&operands);
+
+ // If an operation is about to be removed, make sure it is not in our
+ // worklist anymore because we'd get dangling references to it.
+ void notifyOperationRemoved(Operation *op) override;
+
+ //===----------------------------------------------------------------------===//
+ // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+ // When the root of a pattern is about to be replaced, it can trigger
+ // simplifications to its users - make sure to add them to the worklist
+ // before the root is changed.
+ void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+
+ /// PatternRewriter hook for notifying match failure reasons.
+ void
+ notifyMatchFailure(Operation *op,
+ function_ref<void(Diagnostic &)> reasonCallback) override;
+
+ /// The low-level pattern applicator.
+ PatternApplicator matcher;
+
+ /// The worklist for this transformation keeps track of the operations that
+ /// need to be revisited, plus their index in the worklist. This allows us to
+ /// efficiently remove operations from the worklist when they are erased, even
+ /// if they aren't the root of a pattern.
+ std::vector<Operation *> worklist;
+ DenseMap<Operation *, unsigned> worklistMap;
+
+ /// Non-pattern based folder for operations.
+ OperationFolder folder;
+
+private:
+ /// Configuration information for how to simplify.
+ GreedyRewriteConfig config;
+
+ //===----------------------------------------------------------------------===//
+ // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+ /// The pattern rewriter to use.
+ PatternRewriterListener rewriter;
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG
+ /// A logger used to emit information during the application process.
+ llvm::ScopedPrinter logger{llvm::dbgs()};
+#endif
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+//===----------------------------------------------------------------------===//
+GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
+ MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config, RewriteListener *listener)
+ : matcher(patterns), folder(ctx), config(config), rewriter(ctx) {
+ // Add self as a listener and the user-provided listener.
+ rewriter.addListener(this);
+ if (listener)
+ rewriter.addListener(listener);
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+
+ worklist.reserve(64);
+
+ // Apply a simple cost model based solely on pattern benefit.
+ matcher.applyDefaultCostModel();
+}
+
+bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+#ifndef NDEBUG
+ const char *logLineComment =
+ "//===-------------------------------------------===//\n";
+
+ /// A utility function to log a process result for the given reason.
+ auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
+ logger.unindent();
+ logger.startLine() << "} -> " << result;
+ if (!msg.isTriviallyEmpty())
+ logger.getOStream() << " : " << msg;
+ logger.getOStream() << "\n";
+ };
+ auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
+ logResult(result, msg);
+ logger.startLine() << logLineComment;
+ };
+#endif
+
+ bool changed = false;
+ unsigned iteration = 0;
+ do {
+ worklist.clear();
+ worklistMap.clear();
+
+ if (!config.useTopDownTraversal) {
+ // Add operations to the worklist in postorder.
+ for (auto ®ion : regions)
+ region.walk([this](Operation *op) { addToWorklist(op); });
+ } else {
+ // Add all nested operations to the worklist in preorder.
+ for (auto ®ion : regions)
+ region.walk<WalkOrder::PreOrder>(
+ [this](Operation *op) { worklist.push_back(op); });
+
+ // Reverse the list so our pop-back loop processes them in-order.
+ std::reverse(worklist.begin(), worklist.end());
+ // Remember the reverse index.
+ for (size_t i = 0, e = worklist.size(); i != e; ++i)
+ worklistMap[worklist[i]] = i;
+ }
+
+ // These are scratch vectors used in the folding loop below.
+ SmallVector<Value, 8> originalOperands, resultValues;
+
+ changed = false;
+ while (!worklist.empty()) {
+ auto *op = popFromWorklist();
+
+ // Nulls get added to the worklist when operations are removed, ignore
+ // them.
+ if (op == nullptr)
+ continue;
+
+ LLVM_DEBUG({
+ logger.getOStream() << "\n";
+ logger.startLine() << logLineComment;
+ logger.startLine() << "Processing operation : '" << op->getName()
+ << "'(" << op << ") {\n";
+ logger.indent();
+
+ // If the operation has no regions, just print it here.
+ if (op->getNumRegions() == 0) {
+ op->print(
+ logger.startLine(),
+ OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
+ logger.getOStream() << "\n\n";
+ }
+ });
+
+ // If the operation is trivially dead - remove it.
+ if (isOpTriviallyDead(op)) {
+ rewriter.notifyOperationRemoved(op);
+ op->erase();
+ changed = true;
+
+ LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
+ continue;
+ }
+
+ // Collects all the operands and result uses of the given `op` into work
+ // list. Also remove `op` and nested ops from worklist.
+ originalOperands.assign(op->operand_begin(), op->operand_end());
+ auto preReplaceAction = [&](Operation *op) {
+ // Add the operands to the worklist for visitation.
+ addToWorklist(originalOperands);
+
+ // Add all the users of the result to the worklist so we make sure
+ // to revisit them.
+ for (auto result : op->getResults())
+ for (auto *userOp : result.getUsers())
+ addToWorklist(userOp);
+
+ rewriter.notifyOperationRemoved(op);
+ };
+
+ // Add the given operation to the worklist.
+ auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
+ // Try to fold this op.
+ bool inPlaceUpdate;
+ if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
+ &inPlaceUpdate)))) {
+ LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
+
+ changed = true;
+ if (!inPlaceUpdate)
+ continue;
+ }
+
+ // Try to match one of the patterns. The rewriter is automatically
+ // notified of any necessary changes, so there is nothing else to do
+ // here.
+#ifndef NDEBUG
+ auto canApply = [&](const Pattern &pattern) {
+ LLVM_DEBUG({
+ logger.getOStream() << "\n";
+ logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
+ << op->getName() << " -> (";
+ llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
+ logger.getOStream() << ")' {\n";
+ logger.indent();
+ });
+ return true;
+ };
+ auto onFailure = [&](const Pattern &pattern) {
+ LLVM_DEBUG(logResult("failure", "pattern failed to match"));
+ };
+ auto onSuccess = [&](const Pattern &pattern) {
+ LLVM_DEBUG(logResult("success", "pattern applied successfully"));
+ return success();
+ };
+
+ LogicalResult matchResult =
+ matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
+ if (succeeded(matchResult))
+ LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
+ else
+ LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+#else
+ LogicalResult matchResult = matcher.matchAndRewrite(op, rewriter);
+#endif
+ changed |= succeeded(matchResult);
+ }
+
+ // After applying patterns, make sure that the CFG of each of the regions
+ // is kept up to date.
+ if (config.enableRegionSimplification)
+ changed |= succeeded(simplifyRegions(rewriter, regions));
+ } while (changed &&
+ (++iteration < config.maxIterations ||
+ config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
+
+ // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+ return !changed;
+}
+
+void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+ // Check to see if the worklist already contains this op.
+ if (worklistMap.count(op))
+ return;
+
+ worklistMap[op] = worklist.size();
+ worklist.push_back(op);
+}
+
+Operation *GreedyPatternRewriteDriver::popFromWorklist() {
+ auto *op = worklist.back();
+ worklist.pop_back();
+
+ // This operation is no longer in the worklist, keep worklistMap up to date.
+ if (op)
+ worklistMap.erase(op);
+ return op;
+}
+
+void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
+ auto it = worklistMap.find(op);
+ if (it != worklistMap.end()) {
+ assert(worklist[it->second] == op && "malformed worklist data structure");
+ worklist[it->second] = nullptr;
+ worklistMap.erase(it);
+ }
+}
+
+void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ addToWorklist(op);
+}
+
+template <typename Operands>
+void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) {
+ for (Value operand : operands) {
+ // If the use count of this operand is now < 2, we re-add the defining
+ // operation to the worklist.
+ // TODO: This is based on the fact that zero use operations
+ // may be deleted, and that single use values often have more
+ // canonicalization opportunities.
+ if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
+ continue;
+ if (auto *defOp = operand.getDefiningOp())
+ addToWorklist(defOp);
+ }
+}
+
+void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ addToWorklist(op->getOperands());
+ op->walk([this](Operation *operation) {
+ removeFromWorklist(operation);
+ folder.notifyRemoval(operation);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+//===----------------------------------------------------------------------===//
+void GreedyPatternRewriteDriver::notifyOperationReplaced(Operation *op,
+ ValueRange newValues) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ for (auto result : op->getResults())
+ for (auto *user : result.getUsers())
+ addToWorklist(user);
+}
+//===----------------------------------------------------------------------===//
+// BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+//===----------------------------------------------------------------------===//
+
+void GreedyPatternRewriteDriver::notifyMatchFailure(
+ Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
+ LLVM_DEBUG({
+ Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+ reasonCallback(diag);
+ logger.startLine() << "** Failure : " << diag.str() << "\n";
+ });
+}
+
+/// Rewrite the regions of the specified operation, which must be isolated from
+/// above, by repeatedly applying the highest benefit patterns in a greedy
+/// work-list driven manner. Return success if no more patterns can be matched
+/// in the result operation regions. Note: This does not apply patterns to the
+/// top-level operation itself.
+///
+//===----------------------------------------------------------------------===//
+// END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+//===----------------------------------------------------------------------===//
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+ MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config, RewriteListener *listener) {
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+ if (regions.empty())
+ return success();
+
+ // The top-level operation must be known to be isolated from above to
+ // prevent performing canonicalizations on operations defined at or above
+ // the region containing 'op'.
+ auto regionIsIsolated = [](Region ®ion) {
+ return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
+ };
+ (void)regionIsIsolated;
+ assert(llvm::all_of(regions, regionIsIsolated) &&
+ "patterns can only be applied to operations IsolatedFromAbove");
+
+ // Start the pattern driver.
+ //===----------------------------------------------------------------------===//
+ // END copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+ GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
+ listener);
+ //===----------------------------------------------------------------------===//
+ // BEGIN copied from mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+ //===----------------------------------------------------------------------===//
+ bool converged = driver.simplify(regions);
+ LLVM_DEBUG(if (!converged) {
+ llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+ << config.maxIterations << " times\n";
+ });
+ return success(converged);
+}
diff --git a/llvm-external-projects/iree-dialects/python/CMakeLists.txt b/llvm-external-projects/iree-dialects/python/CMakeLists.txt
index 724982b..3aa20dd 100644
--- a/llvm-external-projects/iree-dialects/python/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/python/CMakeLists.txt
@@ -34,6 +34,15 @@
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT IREEDialectsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler"
+ TD_FILE dialects/LinalgTransformBinding.td
+ SOURCES dialects/iree_linalg_transform.py
+ dialects/_iree_linalg_transform_ops_ext.py
+ DIALECT_NAME iree_linalg_transform
+ )
+
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT IREEDialectsPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler"
TD_FILE dialects/IreePyDmBinding.td
SOURCES
dialects/_iree_pydm_ops_ext.py
@@ -70,8 +79,9 @@
# build burden by ~5x. Make it stop.
MLIRPythonSources.Core
MLIRPythonSources.Dialects.builtin
- MLIRPythonSources.Dialects.func
MLIRPythonSources.Dialects.cf
+ MLIRPythonSources.Dialects.func
+ MLIRPythonSources.Dialects.pdl
MLIRPythonSources.Passes
IREEDialectsPythonSources
IREEDialectsPythonExtensions
diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
index 3647c47..4f85b1b 100644
--- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
+++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
@@ -26,7 +26,8 @@
}
PyIREEPyDMSourceBundle(const PyIREEPyDMSourceBundle &) = delete;
~PyIREEPyDMSourceBundle() {
- if (wrapped.ptr) ireePyDMSourceBundleDestroy(wrapped);
+ if (wrapped.ptr)
+ ireePyDMSourceBundleDestroy(wrapped);
}
IREEPyDMSourceBundle wrapped;
};
@@ -39,12 +40,13 @@
}
PyIREEPyDMLoweringOptions(const PyIREEPyDMLoweringOptions &) = delete;
~PyIREEPyDMLoweringOptions() {
- if (wrapped.ptr) ireePyDMLoweringOptionsDestroy(wrapped);
+ if (wrapped.ptr)
+ ireePyDMLoweringOptionsDestroy(wrapped);
}
IREEPyDMLoweringOptions wrapped;
};
-} // namespace
+} // namespace
PYBIND11_MODULE(_ireeDialects, m) {
m.doc() = "iree-dialects main python extension";
@@ -105,6 +107,22 @@
py::arg("context") = py::none(), py::arg("load") = true);
//===--------------------------------------------------------------------===//
+ // LinalgTransform
+ //===--------------------------------------------------------------------===//
+ auto iree_linalg_transform_m = m.def_submodule("iree_linalg_transform");
+ iree_linalg_transform_m.def(
+ "register_dialect",
+ [](MlirContext context, bool load) {
+ MlirDialectHandle handle =
+ mlirGetDialectHandle__iree_linalg_transform__();
+ mlirDialectHandleRegisterDialect(handle, context);
+ if (load) {
+ mlirDialectHandleLoadDialect(handle, context);
+ }
+ },
+ py::arg("context") = py::none(), py::arg("load") = true);
+
+ //===--------------------------------------------------------------------===//
// IREEPyDMDialect
//===--------------------------------------------------------------------===//
auto iree_pydm_m = m.def_submodule("iree_pydm");
@@ -171,14 +189,14 @@
},
py::arg("pass_manager"));
-#define DEFINE_IREEPYDM_NULLARY_TYPE(Name) \
- mlir_type_subclass(iree_pydm_m, #Name "Type", mlirTypeIsAIREEPyDM##Name, \
- typeClass) \
- .def_classmethod( \
- "get", \
- [](py::object cls, MlirContext context) { \
- return cls(mlirIREEPyDM##Name##TypeGet(context)); \
- }, \
+#define DEFINE_IREEPYDM_NULLARY_TYPE(Name) \
+ mlir_type_subclass(iree_pydm_m, #Name "Type", mlirTypeIsAIREEPyDM##Name, \
+ typeClass) \
+ .def_classmethod( \
+ "get", \
+ [](py::object cls, MlirContext context) { \
+ return cls(mlirIREEPyDM##Name##TypeGet(context)); \
+ }, \
py::arg("cls"), py::arg("context") = py::none());
DEFINE_IREEPYDM_NULLARY_TYPE(Bool)
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/LinalgTransformBinding.td b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/LinalgTransformBinding.td
new file mode 100644
index 0000000..e908e79
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/LinalgTransformBinding.td
@@ -0,0 +1,13 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef PYTHON_BINDINGS_IREE_LINALGTRANSFORM_BINDING
+#define PYTHON_BINDINGS_IREE_LINALGTRANSFORM_BINDING
+
+include "mlir/Bindings/Python/Attributes.td"
+include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.td"
+
+#endif // PYTHON_BINDINGS_IREE_LINALGTRANSFORM_BINDING
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py
new file mode 100644
index 0000000..45baa1d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/_iree_linalg_transform_ops_ext.py
@@ -0,0 +1,405 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Disable PyType, it does not seem to like the specialization pattern used in
+# MLIR.
+# pytype: skip-file
+
+try:
+ from .. import ir
+ from ..dialects import pdl
+ from typing import Optional, Sequence, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+BoolArg = Optional[Union[bool, ir.BoolAttr]]
+IntArg = Optional[Union[int, ir.IntegerAttr]]
+IntListArg = Optional[Union[Sequence[int], ir.ArrayAttr]]
+IntListListArg = Optional[Union[Sequence[Union[Sequence[int], ir.ArrayAttr]],
+ ir.ArrayAttr]]
+StringArg = Optional[Union[str, ir.StringAttr]]
+
+
+def _defaulted_ensure(f):
+
+ def inner(value, default=None):
+ assert value is not None or default is not None
+ return f(default if value is None else value)
+
+ return inner
+
+
+@_defaulted_ensure
+def _ensure_array_attr(value: IntListArg):
+ i64 = ir.IntegerType.get_signless(64)
+ if isinstance(value, Sequence):
+ return ir.ArrayAttr.get([ir.IntegerAttr.get(i64, i) for i in value])
+ return value
+
+
+@_defaulted_ensure
+def _ensure_array_of_array_attr(value: IntListListArg):
+ if isinstance(value, Sequence):
+ return ir.ArrayAttr.get([_ensure_array_attr(inner) for inner in value])
+ return value
+
+
+@_defaulted_ensure
+def _ensure_int_attr(value: IntArg):
+ if isinstance(value, int):
+ return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
+ return value
+
+
+@_defaulted_ensure
+def _ensure_bool_attr(value: BoolArg):
+ if isinstance(value, bool):
+ return ir.BoolAttr.get(value)
+ return value
+
+
+@_defaulted_ensure
+def _ensure_string_attr(value: StringArg):
+ if isinstance(value, str):
+ return ir.StringAttr.get(value)
+ return value
+
+
+class MatchOp:
+ """Specialization for the MatchOp class."""
+
+ def __init__(self, target: Union[str, ir.FlatSymbolRefAttr]):
+ if isinstance(target, str):
+ target = ir.FlatSymbolRefAttr.get(target)
+
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type, target)
+
+
+class LowerVectorsOp:
+ """Specialization for the LowerVectorsOp class."""
+
+ def __init__(self,
+ *,
+ stages: IntListArg = None,
+ contraction_lowering: StringArg = None,
+ multireduction_lowering: StringArg = None,
+ split_transfers: StringArg = None,
+ unroll_vector_transfers: BoolArg = None,
+ transpose_lowering: StringArg = None,
+ transpose_avx2_lowering: BoolArg = None,
+ loc=None,
+ ip=None):
+ stages = _ensure_array_attr(stages, [0, 1, 2, 3, 4, 5, 6])
+ contraction_lowering = _ensure_string_attr(contraction_lowering,
+ "outerproduct")
+ multireduction_lowering = _ensure_string_attr(multireduction_lowering,
+ "innerparallel")
+ split_transfers = _ensure_string_attr(split_transfers, "linalg-copy")
+ unroll_vector_transfers = _ensure_bool_attr(unroll_vector_transfers, True)
+ transpose_lowering = _ensure_string_attr(transpose_lowering, "eltwise")
+ transpose_avx2_lowering = _ensure_bool_attr(transpose_avx2_lowering, False)
+
+ super().__init__(stages,
+ contraction_lowering,
+ multireduction_lowering,
+ split_transfers,
+ unroll_vector_transfers,
+ transpose_lowering,
+ transpose_avx2_lowering,
+ loc=loc,
+ ip=ip)
+
+
+class LowerToLLVMOp:
+ """Specialization for the LowerToLLVMOp class."""
+
+ def __init__(self,
+ *,
+ reassociate_fp_reductions: BoolArg = None,
+ enable_index_optimizations: BoolArg = None,
+ enable_arm_neon: BoolArg = None,
+ enable_arm_sve: BoolArg = None,
+ enable_amx: BoolArg = None,
+ enable_x86vector: BoolArg = None,
+ enable_async: BoolArg = None,
+ loc=None,
+ ip=None):
+ super().__init__(_ensure_bool_attr(reassociate_fp_reductions, False),
+ _ensure_bool_attr(enable_index_optimizations, False),
+ _ensure_bool_attr(enable_arm_neon, False),
+ _ensure_bool_attr(enable_arm_sve, False),
+ _ensure_bool_attr(enable_amx, False),
+ _ensure_bool_attr(enable_x86vector, False),
+ _ensure_bool_attr(enable_async, False),
+ loc=loc,
+ ip=ip)
+
+
+class FuseOp:
+ """Specialization for the FuseOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ tile_sizes: IntListArg = None,
+ tile_interchange: IntListArg = None,
+ loc=None,
+ ip=None):
+ tile_sizes = _ensure_array_attr(tile_sizes, [])
+ tile_interchange = _ensure_array_attr(tile_interchange, [])
+ operation_type = pdl.OperationType.get()
+
+ super().__init__(operation_type,
+ target,
+ tile_sizes,
+ tile_interchange,
+ loc=loc,
+ ip=ip)
+
+
+class TileOp:
+ """Specialization for the TileOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ sizes: IntListArg = None,
+ interchange: IntListArg = None,
+ peel: IntListArg = None,
+ scalarize_dyn_dims: BoolArg = None,
+ loc=None,
+ ip=None):
+ sizes = _ensure_array_attr(sizes, [])
+ interchange = _ensure_array_attr(interchange, [])
+ peel = _ensure_array_attr(peel, [])
+ scalarize_dyn_dims = _ensure_bool_attr(scalarize_dyn_dims, False)
+ operation_type = pdl.OperationType.get()
+
+ super().__init__(operation_type,
+ target,
+ sizes,
+ interchange,
+ peel,
+ scalarize_dyn_dims,
+ loc=loc,
+ ip=ip)
+
+
+class PadOp:
+ """Specialization for the PadOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ pack_paddings: IntListArg = None,
+ hoist_paddings: IntListArg = None,
+ transpose_paddings: IntListListArg = None,
+ loc=None,
+ ip=None):
+ pack_paddings = _ensure_array_attr(pack_paddings, [])
+ hoist_paddings = _ensure_array_attr(hoist_paddings, [])
+ transpose_paddings = _ensure_array_of_array_attr(transpose_paddings, [])
+ operation_type = pdl.OperationType.get()
+
+ super().__init__(operation_type,
+ target,
+ pack_paddings,
+ hoist_paddings,
+ transpose_paddings,
+ loc=loc,
+ ip=ip)
+
+
+class GeneralizeOp:
+ """Specialization for the GeneralizeOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ loc=None,
+ ip=None):
+ operation_type = pdl.OperationType.get()
+
+ super().__init__(operation_type, target, loc=loc, ip=ip)
+
+
+class InterchangeOp:
+ """Specialization for the InterchangeOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ iterator_interchange: IntListArg = None,
+ loc=None,
+ ip=None):
+ iterator_interchange = _ensure_array_attr(iterator_interchange, [])
+ operation_type = pdl.OperationType.get()
+
+ super().__init__(operation_type,
+ target,
+ iterator_interchange,
+ loc=loc,
+ ip=ip)
+
+
+class VectorizeOp:
+
+ def __init__(self,
+ target: Optional[Union[ir.Value, ir.Operation,
+ ir.OpView]] = None,
+ *,
+ vectorize_padding: BoolArg = None,
+ loc=None,
+ ip=None):
+ operation_type = pdl.OperationType.get()
+
+ super().__init__(operation_type if target is not None else None,
+ target,
+ _ensure_bool_attr(vectorize_padding, False),
+ loc=loc,
+ ip=ip)
+
+
+class GetParentLoopOp:
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ num_loops: IntArg = None,
+ loc=None,
+ ip=None):
+ operation_type = pdl.OperationType.get()
+ num_loops = _ensure_int_attr(num_loops, 1)
+ super().__init__(operation_type, target, num_loops, loc=loc, ip=ip)
+
+
+class UnrollLoopOp:
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ factor: Union[int, ir.IntegerAttr],
+ loc=None,
+ ip=None):
+ # Factor must not be None, do not provide the default value here.
+ factor = _ensure_int_attr(factor)
+ super().__init__(target, factor, loc=loc, ip=ip)
+
+
+class PipelineLoopOp:
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ iteration_interval: IntArg,
+ read_latency: IntArg,
+ loc=None,
+ ip=None):
+ iteration_interval = _ensure_int_attr(iteration_interval, 1)
+ read_latency = _ensure_int_attr(read_latency, 10)
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type,
+ target,
+ iteration_interval,
+ read_latency,
+ loc=loc,
+ ip=ip)
+
+
+class OutlineLoopOp:
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ func_name: StringArg,
+ loc=None,
+ ip=None):
+ # Function name must not be None, do not provide the default value.
+ func_name = _ensure_string_attr(func_name)
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type, target, func_name, loc=loc, ip=ip)
+
+
+class SequenceOp:
+
+ def __init__(self, *, loc=None, ip=None):
+ super().__init__(loc=loc, ip=ip)
+ self.body.blocks.append()
+
+
+class PrintOp:
+
+ def __init__(self, *, name: StringArg, loc=None, ip=None):
+ name = _ensure_string_attr(name)
+ super().__init__(name, loc=loc, ip=ip)
+
+
+##===----------------------------------------------------------------------===##
+## LinalgExt specific transforms
+##===----------------------------------------------------------------------===##
+
+
+class TileToLinalgExtTileOp:
+ """Specialization for the TileToLinalgExtTileOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ sizes: IntListArg = None,
+ loc=None,
+ ip=None):
+ sizes = _ensure_array_attr(sizes, [])
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type, target, sizes, loc=loc, ip=ip)
+
+
+class RewriteLinalgExtTileToScfForOp:
+ """Specialization for the RewriteLinalgExtTileToScfForOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ loc=None,
+ ip=None):
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type, target, loc=loc, ip=ip)
+
+
+class RewriteLinalgExtTileToInParallelOp:
+ """Specialization for the RewriteLinalgExtTileToInParallelOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ loc=None,
+ ip=None):
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type, target, loc=loc, ip=ip)
+
+
+class RewriteLinalgExtInParallelToScfForOp:
+ """Specialization for the RewriteLinalgExtInParallelToScfForOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ loc=None,
+ ip=None):
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type, target, loc=loc, ip=ip)
+
+
+class RewriteLinalgExtInParallelToAsyncOp:
+ """Specialization for the RewriteLinalgExtInParallelToAsyncOp class."""
+
+ def __init__(self,
+ target: Union[ir.Value, ir.Operation, ir.OpView],
+ *,
+ loc=None,
+ ip=None):
+ operation_type = pdl.OperationType.get()
+ super().__init__(operation_type, target, loc=loc, ip=ip)
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_linalg_transform.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_linalg_transform.py
new file mode 100644
index 0000000..8bb7799
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_linalg_transform.py
@@ -0,0 +1,8 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._iree_linalg_transform_ops_gen import *
+from .._mlir_libs._ireeDialects.iree_linalg_transform import *
diff --git a/llvm-external-projects/iree-dialects/test/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/CMakeLists.txt
index e342e16..09a3e9c 100644
--- a/llvm-external-projects/iree-dialects/test/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/test/CMakeLists.txt
@@ -28,3 +28,5 @@
set_target_properties(check-iree-dialects PROPERTIES FOLDER "Tests")
add_lit_testsuites(IREE_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${IREE_DIALECTS_TEST_DEPENDS})
+
+add_subdirectory(lib)
\ No newline at end of file
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/canonicalize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/canonicalize.mlir
similarity index 100%
rename from llvm-external-projects/iree-dialects/test/iree_linalgext/canonicalize.mlir
rename to llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/canonicalize.mlir
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir
similarity index 100%
rename from llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir
rename to llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/convert_to_loops.mlir
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
similarity index 100%
rename from llvm-external-projects/iree-dialects/test/iree_linalgext/invalid.mlir
rename to llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/pad_contraction_to_block_size.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/pad_contraction_to_block_size.mlir
similarity index 100%
rename from llvm-external-projects/iree-dialects/test/iree_linalgext/pad_contraction_to_block_size.mlir
rename to llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/pad_contraction_to_block_size.mlir
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/pad_tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/pad_tiling.mlir
similarity index 97%
rename from llvm-external-projects/iree-dialects/test/iree_linalgext/pad_tiling.mlir
rename to llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/pad_tiling.mlir
index f71ae8f..d4ad8f0 100644
--- a/llvm-external-projects/iree-dialects/test/iree_linalgext/pad_tiling.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/pad_tiling.mlir
@@ -8,7 +8,7 @@
%0 = tensor.pad %arg0 low[%arg1, %arg2] high[%arg3, %arg4] {
^bb0(%arg6 : index, %arg7 : index):
tensor.yield %arg5 : f32
- } {__internal_linalg_transform__ = "tiling_input"}
+ } {__internal_iree_linalg_transform__ = "tiling_input"}
: tensor<?x?xf32> to tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
similarity index 100%
rename from llvm-external-projects/iree-dialects/test/iree_linalgext/roundtrip.mlir
rename to llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
similarity index 100%
rename from llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir
rename to llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
new file mode 100644
index 0000000..5ca985e
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/bufferize.mlir
@@ -0,0 +1,34 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: memref<128x128xf32
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: memref<128x128xf32
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: memref<128x128xf32
+// CHECK-NOT: -> tensor
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // CHECK: linalg.matmul ins(%[[TA]], %[[TB]] : memref{{.*}}, memref{{.*}} outs(%[[TC]] : memref{{.*}})
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ // CHECK: return
+ // CHECK-NOT: %{{.*}}
+ return %0 : tensor<128x128xf32>
+// CHECK: }
+}
+
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ bufferize
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir
new file mode 100644
index 0000000..74d6cff
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/double-tiling.mlir
@@ -0,0 +1,43 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+// This test is verifying that a non-trivial 2*tiling+padding+vectorization transformation completes successfully
+
+// CHECK-LABEL: func @matmul_tensors(
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // Pack transposed padding of 1st operand.
+ // CHECK: tensor.pad
+ // CHECK: linalg.generic
+
+ // Pack padding of 2nd operand.
+ // CHECK: tensor.pad
+
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: linalg.generic
+ // CHECK: vector.contract
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
+
+pdl.pattern @pdl_target: benefit(1) {
+ %args = operands
+ %results= types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ %1 = tile %0 {interchange = [0, 2, 1], peel = [], scalarize_dyn_dims = false, sizes = [32, 32, 32]}
+ %2 = tile %1 {interchange = [0, 1, 2], peel = [], scalarize_dyn_dims = false, sizes = [4, 4, 1]}
+ %3 = pad %2 {pack_paddings = [1, 1, 1], hoist_paddings = [6, 6, 0], transpose_paddings = [[1, 0], [0, 1]]}
+ %4 = vectorize %3 {vectorize_padding = true}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir
new file mode 100644
index 0000000..c82252b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/drop-schedule.mlir
@@ -0,0 +1,26 @@
+// RUN: iree-dialects-opt -linalg-drop-schedule %s | FileCheck %s
+
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+// CHECK-NOT: pdl.pattern
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+// CHECK-NOT: iree_linalg_transform.sequence
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ tile %0 {sizes = [4, 4, 4], pad = false}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
new file mode 100644
index 0000000..b5825ee
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/expert.mlir
@@ -0,0 +1,164 @@
+// RUN: iree-dialects-opt -linalg-transform-expert-expansion -split-input-file %s | FileCheck %s --check-prefix=EXPAND
+// RUN: iree-dialects-opt -linalg-transform-expert-expansion -linalg-interp-transforms -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors
+// CHECK-NOT: linalg
+// CHECK: llvm
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ // This should match the strategy below.
+ // EXPAND-NOT: expert apply
+ // EXPAND: %[[OP:.*]] = match @pdl_target
+ // EXPAND: %[[HANDLE:.*]] = tile %[[OP]] {sizes = [4, 4, 4]}
+ // EXPAND: %[[HANDLE2:.*]] = vectorize %[[HANDLE]] {vectorize_padding = true}
+ // EXPAND: bufferize
+ // EXPAND: lower_vectors {multireduction_lowering = "innerreduce"}
+ // EXPAND: lower_to_llvm
+ %0 = match @pdl_target
+ expert apply "single_tiling" to %0
+ {
+ tile_sizes = [4, 4, 4],
+ vectorize_padding = true,
+ multireduction_lowering = "innerreduce"
+ }
+}
+
+// CHECK-NOT: @strategies
+// EXPAND-NOT: @strategies
+module @strategies {
+ pdl.pattern @single_tiling_matcher : benefit(1) {
+ %tile_sizes = attribute
+ %vectorize_padding = attribute
+ %multireduction_lowering = attribute
+ %name = attribute : "single_tiling"
+ %type = type : !pdl.operation
+ %target = operand : %type
+ %transformed = type
+ %root = operation "iree_linalg_transform.expert"(%target : !pdl.value) {
+ "expertName" = %name,
+ "tile_sizes" = %tile_sizes,
+ "vectorize_padding" = %vectorize_padding,
+ "multireduction_lowering" = %multireduction_lowering
+ } -> (%transformed : !pdl.type)
+
+ rewrite %root {
+ %tile = operation "iree_linalg_transform.tile"(%target : !pdl.value) {
+ "sizes" = %tile_sizes
+ } -> (%transformed : !pdl.type)
+ %handle = result 0 of %tile
+
+ %vectorize = operation "iree_linalg_transform.vectorize"(%handle : !pdl.value) {
+ "vectorize_padding" = %vectorize_padding
+ } -> (%transformed : !pdl.type)
+ %handle2 = result 0 of %vectorize
+
+ %bufferize = operation "iree_linalg_transform.bufferize"
+ %lower_vectors = operation "iree_linalg_transform.lower_vectors" {
+ "multireduction_lowering" = %multireduction_lowering
+ }
+ %lower_to_llvm = operation "iree_linalg_transform.lower_to_llvm"
+
+ replace %root with (%handle2 : !pdl.value)
+ }
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_tensors2
+// CHECK-NOT: linalg
+// CHECK: llvm
+func @matmul_tensors2(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
+
+pdl.pattern @pdl_target2 : benefit(1) {
+ %args = pdl.operands
+ %results = pdl.types
+ %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ pdl.apply_native_constraint "nestedInFunc"[@matmul_tensors2](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ pdl.rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ // This should match the strategy below.
+ // EXPAND-NOT: expert apply
+ // EXPAND: %[[OP:.*]] = match @pdl_target2
+ // EXPAND: %[[HANDLE:.*]] = tile %[[OP]] {sizes = [32, 8, 8]}
+ // EXPAND: %[[HANDLE2:.*]] = tile %[[HANDLE]] {sizes = [4, 4, 4]}
+ // EXPAND: %[[HANDLE3:.*]] = vectorize %[[HANDLE2]] {vectorize_padding = false}
+ // EXPAND: bufferize
+ // EXPAND: lower_vectors {multireduction_lowering = "innerparallel"}
+ // EXPAND: lower_to_llvm
+ %0 = match @pdl_target2
+ %1 = tile %0 {sizes = [32, 8, 8]}
+ expert apply "single_tiling" to %1
+ {
+ tile_sizes = [4, 4, 4],
+ vectorize_padding = false,
+ multireduction_lowering = "innerparallel"
+ }
+}
+
+module @strategies {
+ pdl.pattern @single_tiling_operand : benefit(1) {
+ %tile_sizes = attribute
+ %vectorize_padding = attribute
+ %multireduction_lowering = attribute
+ %name = attribute : "single_tiling"
+ %type = type : !pdl.operation
+ %target = operand : %type
+ %transformed = type
+ %root = operation "iree_linalg_transform.expert"(%target : !pdl.value) {
+ "expertName" = %name,
+ "tile_sizes" = %tile_sizes,
+ "vectorize_padding" = %vectorize_padding,
+ "multireduction_lowering" = %multireduction_lowering
+ } -> (%transformed : !pdl.type)
+
+ rewrite %root {
+ %tile = operation "iree_linalg_transform.tile"(%target : !pdl.value) {
+ "sizes" = %tile_sizes
+ } -> (%transformed : !pdl.type)
+ %handle = result 0 of %tile
+
+ %vectorize = operation "iree_linalg_transform.vectorize"(%handle : !pdl.value) {
+ "vectorize_padding" = %vectorize_padding
+ } -> (%transformed : !pdl.type)
+ %handle2 = result 0 of %vectorize
+
+ %bufferize = operation "iree_linalg_transform.bufferize"
+ %lower_vectors = operation "iree_linalg_transform.lower_vectors" {
+ "multireduction_lowering" = %multireduction_lowering
+ }
+ %lower_to_llvm = operation "iree_linalg_transform.lower_to_llvm"
+
+ replace %root with (%handle2 : !pdl.value)
+ }
+ }
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
new file mode 100644
index 0000000..f0ecf7c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/failure.mlir
@@ -0,0 +1,176 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms -split-input-file -verify-diagnostics -allow-unregistered-dialect %s
+
+// This cannot be vectorized because of dynamic tensor shapes. We expect the
+// pass fail and report an error at the vectorization operation below.
+func public @non_vectorizable(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%arg0: tensor<?xf32>) outs(%arg1: tensor<?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ %1 = arith.mulf %arg2, %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @target_pattern
+ // expected-error@below {{failed to apply}}
+ vectorize %0
+}
+
+// -----
+
+func public @no_loop(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%arg0: tensor<?xf32>) outs(%arg1: tensor<?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ %1 = arith.mulf %arg2, %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @target_pattern
+ // expected-error@below {{the transformed op is enclosed by 0 loops, but 1 expected}}
+ // expected-error@below {{failed to apply}}
+ get_parent_loop %0
+}
+
+// -----
+
+func private @prevent_dce()
+
+pdl.pattern @something : benefit(1) {
+ %0 = operands
+ %2 = operation "scf.for"(%0 : !pdl.range<value>)
+ rewrite %2 with "iree_linalg_transform.apply"
+}
+
+func public @loop(%lb: index, %ub: index, %step: index) {
+ scf.for %i = %lb to %ub step %step {
+ call @prevent_dce() : () -> ()
+ }
+ return
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @something
+ // expected-error@below {{NYI: cannot target the result of pipelining}}
+ // expected-error@below {{failed to apply}}
+ %1 = pipeline_loop %0
+ // expected-note@below {{use here}}
+ get_parent_loop %1
+}
+
+// -----
+
+func public @no_outlining() {
+ "some.operation"() ({}, {}) : () -> ()
+ return
+}
+
+pdl.pattern @some_operation : benefit(1) {
+ %0 = operation "some.operation"
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @some_operation
+ // Make sure we don't crash on wrong operation type.
+ // expected-error@below {{failed to apply}}
+ outline_loop %0 {func_name = "outlined"}
+}
+
+// -----
+
+func @no_replacement(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+ %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // expected-error @below {{could not find replacement for tracked op}}
+ %0 = linalg.matmul {test.attrA}
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@no_replacement](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ // expected-error @below {{failed to apply}}
+ vectorize
+ tile %0
+}
+
+// -----
+
+func @repeated_match(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+ %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // expected-error @below {{operation tracked by two handles}}
+ %0 = linalg.matmul {test.attrA}
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+pdl.pattern @pdl_target1 : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@repeated_match](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+// An exact copy of the above, but with a different name.
+pdl.pattern @pdl_target2 : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@repeated_match](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ // expected-note @below {{handle}}
+ %0 = match @pdl_target1
+ // expected-error @below {{failed to apply}}
+ // expected-note @below {{handle}}
+ %1 = match @pdl_target2
+
+ // Add references to handles produced by match so that they are not DCE'd.
+ tile %0
+ tile %1
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir
new file mode 100644
index 0000000..6a78eb3
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/fuse.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+
+// CHECK-LABEL: func @fuse_unary
+func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: linalg.elemwise_unary
+ // CHECK: linalg.elemwise_binary
+ %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@fuse_unary](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ %1 = fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir
new file mode 100644
index 0000000..ea12b9a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/generalize.mlir
@@ -0,0 +1,27 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+
+// CHECK-LABEL: func @generalize_unary
+func @generalize_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK-NOT: linalg.elemwise_unary
+ // CHECK: linalg.generic
+ %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@generalize_unary](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ generalize %0
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir
new file mode 100644
index 0000000..e988133
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/interchange.mlir
@@ -0,0 +1,34 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @interchange_generic
+func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK: linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ %1 = math.exp %arg2 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@interchange_generic](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ interchange %0 {iterator_interchange = [1, 0]}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir
new file mode 100644
index 0000000..d9c7e28
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/invalid.mlir
@@ -0,0 +1,59 @@
+// RUN: iree-dialects-opt %s -split-input-file -verify-diagnostics
+
+iree_linalg_transform.sequence {
+ %0 = match @match
+ // expected-error@below {{result #0 has more than one use}}
+ %1 = tile %0
+ // expected-note@below {{used here as operand #0}}
+ tile %1
+ // expected-note@below {{used here as operand #0}}
+ vectorize %1
+}
+
+// -----
+
+iree_linalg_transform.sequence {
+ %0 = match @match
+ // expected-error@below {{"sizes" and "scalarize_dyn_dims" attributes are mutually exclusive}}
+ tile %0 {sizes = [1,2,3], scalarize_dyn_dims = true}
+}
+
+// -----
+
+iree_linalg_transform.sequence {
+ %0 = match @match
+ // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}}
+ interchange %0 {iterator_interchange = [1, 1]}
+}
+
+// -----
+
+iree_linalg_transform.sequence {
+ %0 = match @match
+ // expected-error@below {{expects interchange to be a permutation, found [1, 1]}}
+ fuse %0 {tile_sizes=[0, 1], tile_interchange = [1, 1]}
+}
+
+// -----
+
+iree_linalg_transform.sequence {
+ %0 = match @match
+ // expected-error@below {{expects pack_paddings to contain booleans (0/1), found [1, 7]}}
+ pad %0 {pack_paddings=[1, 7]}
+}
+
+// -----
+
+iree_linalg_transform.sequence {
+ %0 = match @match
+ // expected-error@below {{expects hoist_paddings to contain positive integers, found [1, -7]}}
+ pad %0 {hoist_paddings=[1, -7]}
+}
+
+// -----
+
+iree_linalg_transform.sequence {
+ %0 = match @match
+ // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
+ pad %0 {transpose_paddings=[[1, 1]]}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir
new file mode 100644
index 0000000..d6d627b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/pad.mlir
@@ -0,0 +1,42 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+
+// CHECK-LABEL: func @pad_unary
+func @pad_unary(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<24x12xf32>) -> tensor<24x12xf32> {
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c4 = arith.constant 4 : index
+
+ // CHECK: scf.for
+ // CHECK: tensor.pad
+ // CHECK: linalg.generic
+ // CHECK: scf.for
+ %0 = scf.for %arg3 = %c0 to %c12 step %c4 iter_args(%arg2 = %arg1) -> (tensor<24x12xf32>) {
+ %1 = tensor.extract_slice %arg0[0, %arg3] [24, 4] [1, 1] : tensor<24x12xf32> to tensor<24x4xf32>
+ %2 = tensor.extract_slice %arg2[0, %arg3] [24, 4] [1, 1] : tensor<24x12xf32> to tensor<24x4xf32>
+
+ // CHECK: linalg.generic
+ // CHECK: tensor.pad
+ // CHECK: linalg.elemwise_unary
+ %3 = linalg.elemwise_unary ins(%1 : tensor<24x4xf32>)
+ outs(%2: tensor<24x4xf32>) -> tensor<24x4xf32>
+ %4 = tensor.insert_slice %3 into %arg2[0, %arg3] [24, 4] [1, 1] : tensor<24x4xf32> into tensor<24x12xf32>
+ scf.yield %4 : tensor<24x12xf32>
+ }
+ return %0 : tensor<24x12xf32>
+}
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@pad_unary](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ %1 = pad %0 {pack_paddings=[1, 1], hoist_paddings=[1, 0], transpose_paddings=[[1, 0], [0, 1]]}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
new file mode 100644
index 0000000..7ff0112
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/roundtrip.mlir
@@ -0,0 +1,33 @@
+// RUN: iree-dialects-opt %s | FileCheck %s
+
+// CHECK: iree_linalg_transform.sequence
+iree_linalg_transform.sequence {
+ // CHECK: %[[OPS:.*]] = match @{{.*}}
+ %0 = match @match1
+ // CHECK: %[[TILED:.*]] = tile %[[OPS]] {
+ // CHECK-DAG: sizes = [4, 4, 4]
+ // CHECK: }
+ %1 = tile %0 {sizes = [4, 4, 4]}
+ // CHECK: %[[TILED2:.*]] = tile %[[TILED]]
+ %2 = tile %1 {sizes = [2, 2, 2]}
+ // CHECK: %[[PADDED:.*]] = pad %[[TILED2]] {pack_paddings = [1, 1, 0]}
+ %3 = pad %2 {pack_paddings = [1, 1, 0]}
+ // CHECK: decompose
+ decompose
+ // CHECK: %{{.*}} = vectorize %[[PADDED]] {vectorize_padding = true}
+ %4 = vectorize %3 {vectorize_padding = true}
+ // CHECK: %[[OPS2:.*]] = match @{{.*}}
+ %5 = match @match2
+ // CHECK: %{{.*}} = vectorize %[[OPS2]]
+ vectorize %5
+ // CHECK-NOT: %
+ // CHECK: vectorize
+ // CHECK-NOT: %
+ vectorize
+ // CHECK: bufferize
+ bufferize
+ // CHECK: lower_vectors {multireduction_lowering = "innerreduce"}
+ lower_vectors { multireduction_lowering = "innerreduce"}
+ // CHECK: lower_to_llvm
+ lower_to_llvm
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scoped.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scoped.mlir
new file mode 100644
index 0000000..6964ef1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/scoped.mlir
@@ -0,0 +1,30 @@
+// RUN: iree-dialects-opt -test-wrap-scope='opname=arith.addi' %s | FileCheck %s --check-prefix WRAP
+// RUN: iree-dialects-opt -test-unwrap-scope %s | FileCheck %s --check-prefix UNWRAP
+
+// WRAP-LABEL: @test_wrap
+// WRAP-SAME: (%[[ARG0:.*]]: i32) -> i32
+func @test_wrap(%arg0: i32) -> i32 {
+ // WRAP: %[[V:.*]] = iree_linalg_transform.util.scope(%[[ARG0]], %[[ARG0]]) {
+ // WRAP-NEXT: ^[[B:.*]](%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+ // WRAP-NEXT: %[[ADD:.*]] = arith.addi %[[ARG2]], %[[ARG2]]
+ // WRAP-NEXT: iree_linalg_transform.util.forward %[[ADD]]
+ // WRAP-NEXT: } : (i32, i32) -> i32
+ %0 = arith.addi %arg0, %arg0 : i32
+ // WRAP: return %[[V]]
+ return %0 : i32
+}
+
+// UNWRAP-LABEL: @test_unwrap
+// UNWRAP-SAME: (%[[ARG0:.*]]: i32) -> (i32, i32)
+func @test_unwrap(%arg0: i32) -> (i32, i32) {
+ // UNWRAP: %[[V0:.*]] = arith.addi %[[ARG0]], %[[ARG0]]
+ // UNWRAP-NEXT: %[[V1:.*]] = arith.addi %[[V0]], %[[ARG0]]
+ %0:2 = iree_linalg_transform.util.scope(%arg0) {
+ ^bb0(%arg1: i32):
+ %1 = arith.addi %arg1, %arg1 : i32
+ %2 = arith.addi %1, %arg1 : i32
+ iree_linalg_transform.util.forward %1, %2 : i32, i32
+ } : (i32) -> (i32, i32)
+ // UNWRAP-NEXT: return %[[V0]], %[[V1]]
+ return %0#0, %0#1 : i32, i32
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir
new file mode 100644
index 0000000..fdcd2f9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/selective-targeting.mlir
@@ -0,0 +1,134 @@
+// RUN: iree-dialects-opt %s -linalg-interp-transforms -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
+ %arg3: tensor<128x128xf32>, %arg4: tensor<128x128xf32>, %arg5: tensor<128x128xf32>,
+ %arg6: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // This operation is marked for tiling only.
+ // CHECK-COUNT-3: scf.for
+ // CHECK-COUNT-3: tensor.extract_slice
+ // CHECK: linalg.matmul
+ // CHECK-SAME: -> tensor<4x4xf32>
+ %0 = linalg.matmul { test.attrA}
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ // This operation is marked for tiling and vectorization.
+ // Note that the loop-invariant read is hoisted out of the innermost loop.
+ // CHECK: scf.for
+ // CHECK: scf.for
+ // CHECK: vector.transfer_read
+ // CHECK: scf.for
+ // CHECK: vector.transfer_read
+ // CHECK: vector.transfer_read
+ // CHECK: vector.contract
+ // CHECK-NOT: linalg.matmul
+ // CHECK: vector.transfer_write
+ %1 = linalg.matmul { test.attrA, test.attrC}
+ ins(%arg3, %arg4: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg5: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ // This operation is marked for vectorization only.
+ // CHECK-NOT: scf.for
+ // CHECK-COUNT-3: vector.transfer_read
+ // CHECK: vector.contract
+ // CHECK-SAME: into vector<128x128xf32>
+ // CHECK: vector.transfer_write
+ %2 = linalg.matmul { test.attrC}
+ ins(%0, %1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg6: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %2 : tensor<128x128xf32>
+}
+
+// Match matmul operations inside @matmul_tensors with test.attrA set.
+pdl.pattern @pdl_target_attrA : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+// Match matmul operations inside @matmul_tensors with test.attrC set.
+pdl.pattern @pdl_target_attrC : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr}-> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target_attrA
+ tile %0 {sizes = [4, 4, 4]}
+ %1 = match @pdl_target_attrC
+ vectorize %1
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_one
+func @vectorize_one(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
+ %arg3: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // CHECK: vector.contract
+ %0 = linalg.matmul {test.attrA}
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ // CHECK: linalg.matmul
+ %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg3: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ return %1 : tensor<128x128xf32>
+}
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@vectorize_one](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ vectorize %0
+}
+
+
+// -----
+
+// CHECK-LABEL: @vectorize_all
+func @vectorize_all(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
+ %arg3: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // CHECK: vector.contract
+ %0 = linalg.matmul {test.attrA}
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ // CHECK: vector.contract
+ %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg3: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ return %1 : tensor<128x128xf32>
+}
+
+iree_linalg_transform.sequence {
+ vectorize
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
new file mode 100644
index 0000000..adffa86
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
@@ -0,0 +1,33 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors
+// CHECK-NOT: linalg
+// CHECK: llvm
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
+
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ %1 = tile %0 {sizes = [4, 4, 4]}
+ %2 = vectorize %1 {vectorize_padding = true}
+ bufferize
+ lower_vectors { multireduction_lowering = "innerreduce"}
+ lower_to_llvm
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir
new file mode 100644
index 0000000..88286aa
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile-interchange.mlir
@@ -0,0 +1,72 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms -split-input-file %s | FileCheck %s
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// Check that vectorization applies after interchange+tiling.
+
+// CHECK-LABEL: @matmul_021
+// CHECK-NOT: linalg.generic
+// CHECK: vector.contract
+func public @matmul_021(%arg0: tensor<39x154xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<154x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<39x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<39x5xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<39x154xf32>, tensor<154x5xf32>) outs(%arg2 : tensor<39x5xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %1 = arith.mulf %arg3, %arg4 : f32
+ %2 = arith.addf %arg5, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<39x5xf32>
+ return %0 : tensor<39x5xf32>
+}
+
+pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc" [@matmul_021](%2 : !pdl.operation)
+ rewrite %2 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @target_pattern
+ %1 = tile %0 {interchange = [0, 2, 1], sizes = [3, 5, 14]}
+ %2 = tile %1 {sizes = [3, 5, 2]}
+ %3 = vectorize %2 {vectorize_padding = true}
+}
+
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// Check that vectorization applies after interchange+tiling.
+
+// CHECK-LABEL: @matmul_210
+// CHECK-NOT: linalg.generic
+// CHECK: vector.contract
+func public @matmul_210(%arg0: tensor<39x154xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<154x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg2: tensor<39x5xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<39x5xf32> attributes {passthrough = ["noinline", ["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} {
+ %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<39x154xf32>, tensor<154x5xf32>) outs(%arg2 : tensor<39x5xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %1 = arith.mulf %arg3, %arg4 : f32
+ %2 = arith.addf %arg5, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<39x5xf32>
+ return %0 : tensor<39x5xf32>
+}
+
+pdl.pattern @target_pattern : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc" [@matmul_210](%2 : !pdl.operation)
+ rewrite %2 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @target_pattern
+ %1 = tile %0 {interchange = [2, 1, 0], sizes = [3, 5, 14]}
+ %2 = tile %1 {sizes = [3, 5, 2]}
+ %3 = vectorize %2 {vectorize_padding = true}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir
new file mode 100644
index 0000000..ba94d44
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/tile.mlir
@@ -0,0 +1,44 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms %s | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: -> tensor<128x128xf32> {
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
+// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+// CHECK: %[[sTD:.*]] = linalg.matmul {{.*}} ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
+// CHECK-SAME: outs(%[[sTC]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<4x4xf32> into tensor<128x128xf32>
+// CHECK: scf.yield %[[TD]] : tensor<128x128xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<128x128xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<128x128xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+// CHECK: return %[[TD0]] : tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ tile %0 {sizes = [4, 4, 4]}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir
new file mode 100644
index 0000000..60864ee
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize-transforms.mlir
@@ -0,0 +1,16 @@
+// This test only checks the content of the file parses.
+// RUN: iree-dialects-opt %s
+
+pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ apply_native_constraint "nestedInFunc"[@matmul_tensors](%0 : !pdl.operation)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "iree_linalg_transform.apply"
+}
+
+iree_linalg_transform.sequence {
+ %0 = match @pdl_target
+ vectorize %0 {vectorize_padding = true}
+}
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir
new file mode 100644
index 0000000..303ff83
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/vectorize.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-dialects-opt -linalg-interp-transforms -linalg-transform-file-name=%p/vectorize-transforms.mlir %s | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: -> tensor<128x128xf32> {
+func @matmul_tensors(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32> { linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // CHECK: %[[VA:.*]] = vector.transfer_read %[[TA]]
+ // CHECK: %[[VB:.*]] = vector.transfer_read %[[TB]]
+ // CHECK: %[[VC:.*]] = vector.transfer_read %[[TC]]
+ // CHECK: %[[VCU:.*]] = vector.contract {{.*}} %[[VA]], %[[VB]], %[[VC]]
+ // CHECK: vector.transfer_write %[[VCU]], %[[TC]]
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
diff --git a/llvm-external-projects/iree-dialects/test/Transforms/test-listener-canonicalize.mlir b/llvm-external-projects/iree-dialects/test/Transforms/test-listener-canonicalize.mlir
new file mode 100644
index 0000000..e1be42c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Transforms/test-listener-canonicalize.mlir
@@ -0,0 +1,102 @@
+// RUN: iree-dialects-opt %s -allow-unregistered-dialect -test-listener-canonicalize --split-input-file | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Everything below copied from mlir/test/Dialect/Standard/canonicalize.mlir
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @select_same_val
+// CHECK: return %arg1
+func @select_same_val(%arg0: i1, %arg1: i64) -> i64 {
+ %0 = arith.select %arg0, %arg1, %arg1 : i64
+ return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_eq_select
+// CHECK: return %arg1
+func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
+ %0 = arith.cmpi eq, %arg0, %arg1 : i64
+ %1 = arith.select %0, %arg0, %arg1 : i64
+ return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_ne_select
+// CHECK: return %arg0
+func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
+ %0 = arith.cmpi ne, %arg0, %arg1 : i64
+ %1 = arith.select %0, %arg0, %arg1 : i64
+ return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_extui
+// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64
+// CHECK: return %[[res]]
+func @select_extui(%arg0: i1) -> i64 {
+ %c0_i64 = arith.constant 0 : i64
+ %c1_i64 = arith.constant 1 : i64
+ %res = arith.select %arg0, %c1_i64, %c0_i64 : i64
+ return %res : i64
+}
+
+// CHECK-LABEL: @select_extui2
+// CHECK-DAG: %true = arith.constant true
+// CHECK-DAG: %[[xor:.+]] = arith.xori %arg0, %true : i1
+// CHECK-DAG: %[[res:.+]] = arith.extui %[[xor]] : i1 to i64
+// CHECK: return %[[res]]
+func @select_extui2(%arg0: i1) -> i64 {
+ %c0_i64 = arith.constant 0 : i64
+ %c1_i64 = arith.constant 1 : i64
+ %res = arith.select %arg0, %c0_i64, %c1_i64 : i64
+ return %res : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_extui_i1
+// CHECK-NEXT: return %arg0
+func @select_extui_i1(%arg0: i1) -> i1 {
+ %c0_i1 = arith.constant false
+ %c1_i1 = arith.constant true
+ %res = arith.select %arg0, %c1_i1, %c0_i1 : i1
+ return %res : i1
+}
+
+// -----
+
+// CHECK-LABEL: @branchCondProp
+// CHECK: %[[trueval:.+]] = arith.constant true
+// CHECK: %[[falseval:.+]] = arith.constant false
+// CHECK: "test.consumer1"(%[[trueval]]) : (i1) -> ()
+// CHECK: "test.consumer2"(%[[falseval]]) : (i1) -> ()
+func @branchCondProp(%arg0: i1) {
+ cf.cond_br %arg0, ^trueB, ^falseB
+
+^trueB:
+ "test.consumer1"(%arg0) : (i1) -> ()
+ cf.br ^exit
+
+^falseB:
+ "test.consumer2"(%arg0) : (i1) -> ()
+ cf.br ^exit
+
+^exit:
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @selToNot
+// CHECK: %[[trueval:.+]] = arith.constant true
+// CHECK: %{{.+}} = arith.xori %arg0, %[[trueval]] : i1
+func @selToNot(%arg0: i1) -> i1 {
+ %true = arith.constant true
+ %false = arith.constant false
+ %res = arith.select %arg0, %false, %true : i1
+ return %res : i1
+}
+
diff --git a/llvm-external-projects/iree-dialects/test/Transforms/test-listener-cse.mlir b/llvm-external-projects/iree-dialects/test/Transforms/test-listener-cse.mlir
new file mode 100644
index 0000000..13d2994
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Transforms/test-listener-cse.mlir
@@ -0,0 +1,250 @@
+// RUN: iree-dialects-opt %s -allow-unregistered-dialect -test-listener-cse | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Everything below copied from mlir/test/Transforms/cse.mlir
+//===----------------------------------------------------------------------===//
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
+#map0 = affine_map<(d0) -> (d0 mod 2)>
+
+// CHECK-LABEL: @simple_constant
+func @simple_constant() -> (i32, i32) {
+ // CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
+ %0 = arith.constant 1 : i32
+
+ // CHECK-NEXT: return %c1_i32, %c1_i32 : i32, i32
+ %1 = arith.constant 1 : i32
+ return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @basic
+func @basic() -> (index, index) {
+ // CHECK: %c0 = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 0 : index
+
+ // CHECK-NEXT: %0 = affine.apply #[[$MAP]](%c0)
+ %0 = affine.apply #map0(%c0)
+ %1 = affine.apply #map0(%c1)
+
+ // CHECK-NEXT: return %0, %0 : index, index
+ return %0, %1 : index, index
+}
+
+// CHECK-LABEL: @many
+func @many(f32, f32) -> (f32) {
+^bb0(%a : f32, %b : f32):
+ // CHECK-NEXT: %0 = arith.addf %arg0, %arg1 : f32
+ %c = arith.addf %a, %b : f32
+ %d = arith.addf %a, %b : f32
+ %e = arith.addf %a, %b : f32
+ %f = arith.addf %a, %b : f32
+
+ // CHECK-NEXT: %1 = arith.addf %0, %0 : f32
+ %g = arith.addf %c, %d : f32
+ %h = arith.addf %e, %f : f32
+ %i = arith.addf %c, %e : f32
+
+ // CHECK-NEXT: %2 = arith.addf %1, %1 : f32
+ %j = arith.addf %g, %h : f32
+ %k = arith.addf %h, %i : f32
+
+ // CHECK-NEXT: %3 = arith.addf %2, %2 : f32
+ %l = arith.addf %j, %k : f32
+
+ // CHECK-NEXT: return %3 : f32
+ return %l : f32
+}
+
+/// Check that operations are not eliminated if they have different operands.
+// CHECK-LABEL: @different_ops
+func @different_ops() -> (i32, i32) {
+ // CHECK: %c0_i32 = arith.constant 0 : i32
+ // CHECK: %c1_i32 = arith.constant 1 : i32
+ %0 = arith.constant 0 : i32
+ %1 = arith.constant 1 : i32
+
+ // CHECK-NEXT: return %c0_i32, %c1_i32 : i32, i32
+ return %0, %1 : i32, i32
+}
+
+/// Check that operations are not eliminated if they have different result
+/// types.
+// CHECK-LABEL: @different_results
+func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4x?xf32>) {
+ // CHECK: %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+ // CHECK-NEXT: %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
+ %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
+ %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>
+
+ // CHECK-NEXT: return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
+}
+
+/// Check that operations are not eliminated if they have different attributes.
+// CHECK-LABEL: @different_attributes
+func @different_attributes(index, index) -> (i1, i1, i1) {
+^bb0(%a : index, %b : index):
+ // CHECK: %0 = arith.cmpi slt, %arg0, %arg1 : index
+ %0 = arith.cmpi slt, %a, %b : index
+
+ // CHECK-NEXT: %1 = arith.cmpi ne, %arg0, %arg1 : index
+ /// Predicate 1 means inequality comparison.
+ %1 = arith.cmpi ne, %a, %b : index
+ %2 = "arith.cmpi"(%a, %b) {predicate = 1} : (index, index) -> i1
+
+ // CHECK-NEXT: return %0, %1, %1 : i1, i1, i1
+ return %0, %1, %2 : i1, i1, i1
+}
+
+/// Check that operations with side effects are not eliminated.
+// CHECK-LABEL: @side_effect
+func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
+ // CHECK: %0 = memref.alloc() : memref<2x1xf32>
+ %0 = memref.alloc() : memref<2x1xf32>
+
+ // CHECK-NEXT: %1 = memref.alloc() : memref<2x1xf32>
+ %1 = memref.alloc() : memref<2x1xf32>
+
+ // CHECK-NEXT: return %0, %1 : memref<2x1xf32>, memref<2x1xf32>
+ return %0, %1 : memref<2x1xf32>, memref<2x1xf32>
+}
+
+/// Check that operation definitions are properly propagated down the dominance
+/// tree.
+// CHECK-LABEL: @down_propagate_for
+func @down_propagate_for() {
+ // CHECK: %c1_i32 = arith.constant 1 : i32
+ %0 = arith.constant 1 : i32
+
+ // CHECK-NEXT: affine.for {{.*}} = 0 to 4 {
+ affine.for %i = 0 to 4 {
+ // CHECK-NEXT: "foo"(%c1_i32, %c1_i32) : (i32, i32) -> ()
+ %1 = arith.constant 1 : i32
+ "foo"(%0, %1) : (i32, i32) -> ()
+ }
+ return
+}
+
+// CHECK-LABEL: @down_propagate
+func @down_propagate() -> i32 {
+ // CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
+ %0 = arith.constant 1 : i32
+
+ // CHECK-NEXT: %true = arith.constant true
+ %cond = arith.constant true
+
+ // CHECK-NEXT: cf.cond_br %true, ^bb1, ^bb2(%c1_i32 : i32)
+ cf.cond_br %cond, ^bb1, ^bb2(%0 : i32)
+
+^bb1: // CHECK: ^bb1:
+ // CHECK-NEXT: cf.br ^bb2(%c1_i32 : i32)
+ %1 = arith.constant 1 : i32
+ cf.br ^bb2(%1 : i32)
+
+^bb2(%arg : i32):
+ return %arg : i32
+}
+
+/// Check that operation definitions are NOT propagated up the dominance tree.
+// CHECK-LABEL: @up_propagate_for
+func @up_propagate_for() -> i32 {
+ // CHECK: affine.for {{.*}} = 0 to 4 {
+ affine.for %i = 0 to 4 {
+ // CHECK-NEXT: %c1_i32_0 = arith.constant 1 : i32
+ // CHECK-NEXT: "foo"(%c1_i32_0) : (i32) -> ()
+ %0 = arith.constant 1 : i32
+ "foo"(%0) : (i32) -> ()
+ }
+
+ // CHECK: %c1_i32 = arith.constant 1 : i32
+ // CHECK-NEXT: return %c1_i32 : i32
+ %1 = arith.constant 1 : i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: func @up_propagate
+func @up_propagate() -> i32 {
+ // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
+ %0 = arith.constant 0 : i32
+
+ // CHECK-NEXT: %true = arith.constant true
+ %cond = arith.constant true
+
+ // CHECK-NEXT: cf.cond_br %true, ^bb1, ^bb2(%c0_i32 : i32)
+ cf.cond_br %cond, ^bb1, ^bb2(%0 : i32)
+
+^bb1: // CHECK: ^bb1:
+ // CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
+ %1 = arith.constant 1 : i32
+
+ // CHECK-NEXT: cf.br ^bb2(%c1_i32 : i32)
+ cf.br ^bb2(%1 : i32)
+
+^bb2(%arg : i32): // CHECK: ^bb2
+ // CHECK-NEXT: %c1_i32_0 = arith.constant 1 : i32
+ %2 = arith.constant 1 : i32
+
+ // CHECK-NEXT: %1 = arith.addi %0, %c1_i32_0 : i32
+ %add = arith.addi %arg, %2 : i32
+
+ // CHECK-NEXT: return %1 : i32
+ return %add : i32
+}
+
+/// The same test as above except that we are testing on a cfg embedded within
+/// an operation region.
+// CHECK-LABEL: func @up_propagate_region
+func @up_propagate_region() -> i32 {
+ // CHECK-NEXT: %0 = "foo.region"
+ %0 = "foo.region"() ({
+ // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
+ // CHECK-NEXT: %true = arith.constant true
+ // CHECK-NEXT: cf.cond_br
+
+ %1 = arith.constant 0 : i32
+ %true = arith.constant true
+ cf.cond_br %true, ^bb1, ^bb2(%1 : i32)
+
+ ^bb1: // CHECK: ^bb1:
+ // CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
+ // CHECK-NEXT: br
+
+ %c1_i32 = arith.constant 1 : i32
+ cf.br ^bb2(%c1_i32 : i32)
+
+ ^bb2(%arg : i32): // CHECK: ^bb2(%1: i32):
+ // CHECK-NEXT: %c1_i32_0 = arith.constant 1 : i32
+ // CHECK-NEXT: %2 = arith.addi %1, %c1_i32_0 : i32
+ // CHECK-NEXT: "foo.yield"(%2) : (i32) -> ()
+
+ %c1_i32_0 = arith.constant 1 : i32
+ %2 = arith.addi %arg, %c1_i32_0 : i32
+ "foo.yield" (%2) : (i32) -> ()
+ }) : () -> (i32)
+ return %0 : i32
+}
+
+/// This test checks that nested regions that are isolated from above are
+/// properly handled.
+// CHECK-LABEL: @nested_isolated
+func @nested_isolated() -> i32 {
+ // CHECK-NEXT: arith.constant 1
+ %0 = arith.constant 1 : i32
+
+ // CHECK-NEXT: @nested_func
+ builtin.func @nested_func() {
+ // CHECK-NEXT: arith.constant 1
+ %foo = arith.constant 1 : i32
+ "foo.yield"(%foo) : (i32) -> ()
+ }
+
+ // CHECK: "foo.region"
+ "foo.region"() ({
+ // CHECK-NEXT: arith.constant 1
+ %foo = arith.constant 1 : i32
+ "foo.yield"(%foo) : (i32) -> ()
+ }) : () -> ()
+
+ return %0 : i32
+}
diff --git a/llvm-external-projects/iree-dialects/test/Transforms/test-with-listener.mlir b/llvm-external-projects/iree-dialects/test/Transforms/test-with-listener.mlir
new file mode 100644
index 0000000..b765ebc
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/Transforms/test-with-listener.mlir
@@ -0,0 +1,20 @@
+// RUN: iree-dialects-opt -test-listener-canonicalize='listener=1' %s | FileCheck %s --check-prefix CANON
+// RUN: iree-dialects-opt -test-listener-cse='listener=1' %s | FileCheck %s --check-prefix CSE
+
+func @test_canonicalize(%arg0: i32) -> (i32, i32) {
+ // CANON: REPLACED arith.addi
+ // CANON: REMOVED arith.addi
+ %c5 = arith.constant -5 : i32
+ %0 = arith.addi %c5, %arg0 : i32
+ %1 = arith.addi %c5, %0 : i32
+ return %0, %1 : i32, i32
+}
+
+func @test_cse(%arg0: i32) -> (i32, i32) {
+ // CSE: REPLACED arith.addi
+ // CSE: REMOVED arith.addi
+ %c5 = arith.constant -5 : i32
+ %0 = arith.addi %c5, %arg0 : i32
+ %1 = arith.addi %c5, %arg0 : i32
+ return %0, %1 : i32, i32
+}
diff --git a/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
new file mode 100644
index 0000000..557daa8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Dialect)
+add_subdirectory(Transforms)
diff --git a/llvm-external-projects/iree-dialects/test/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/Dialect/CMakeLists.txt
new file mode 100644
index 0000000..1da2860
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/Dialect/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(LinalgTransform)
diff --git a/llvm-external-projects/iree-dialects/test/lib/Dialect/LinalgTransform/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/Dialect/LinalgTransform/CMakeLists.txt
new file mode 100644
index 0000000..30392ca
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/Dialect/LinalgTransform/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_library(IREELinalgTransformTestPasses
+ TestScopedTransform.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ DEPENDS
+ mlir-headers
+
+ LINK_LIBS PUBLIC
+ IREELinalgTransformDialectTransforms
+ MLIRPass
+ )
diff --git a/llvm-external-projects/iree-dialects/test/lib/Dialect/LinalgTransform/TestScopedTransform.cpp b/llvm-external-projects/iree-dialects/test/lib/Dialect/LinalgTransform/TestScopedTransform.cpp
new file mode 100644
index 0000000..d157a44
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/Dialect/LinalgTransform/TestScopedTransform.cpp
@@ -0,0 +1,61 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgTransform/ScopedTransform.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct TestWrapScopePass : public PassWrapper<TestWrapScopePass, Pass> {
+ TestWrapScopePass() = default;
+ TestWrapScopePass(const TestWrapScopePass &other) : PassWrapper(other) {}
+
+ StringRef getArgument() const final { return "test-wrap-scope"; }
+ StringRef getDescription() const final { return "Test wrap scope pass."; }
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<linalg::transform::LinalgTransformDialect>();
+ }
+
+ void runOnOperation() override {
+ getOperation()->walk([&](Operation *op) {
+ if (op->getName().getStringRef() != opToWrap)
+ return;
+ linalg::transform::wrapInScope(op);
+ });
+ }
+
+ Pass::Option<std::string> opToWrap{*this, "opname",
+ llvm::cl::desc("Op to wrap")};
+};
+
+struct TestUnwrapScopePass : public PassWrapper<TestUnwrapScopePass, Pass> {
+ StringRef getArgument() const final { return "test-unwrap-scope"; }
+ StringRef getDescription() const final { return "Test unwrap scope pass."; }
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+
+ void runOnOperation() override {
+ getOperation()->walk(
+ [](linalg::transform::ScopeOp scope) { (void)unwrapScope(scope); });
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test_ext {
+void registerTestLinalgTransformWrapScope() {
+ PassRegistration<TestWrapScopePass>();
+ PassRegistration<TestUnwrapScopePass>();
+}
+} // namespace test_ext
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/test/lib/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/lib/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..7d39dbd
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/Transforms/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_library(IREETransformsTestPasses
+ TestListenerPasses.cpp
+
+ DEPENDS
+ mlir-headers
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ IREELinalgTransformDialect
+ MLIRPass
+ )
diff --git a/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp b/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
new file mode 100644
index 0000000..4927824
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/lib/Transforms/TestListenerPasses.cpp
@@ -0,0 +1,99 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Transforms/Listener.h"
+#include "iree-dialects/Transforms/ListenerCSE.h"
+#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// The test listener prints stuff to `stdout` so that it can be checked by lit
+/// tests.
+struct TestListener : public RewriteListener {
+ void notifyOperationReplaced(Operation *op, ValueRange newValues) override {
+ llvm::outs() << "REPLACED " << op->getName() << "\n";
+ }
+ void notifyOperationRemoved(Operation *op) override {
+ llvm::outs() << "REMOVED " << op->getName() << "\n";
+ }
+};
+
+struct TestListenerCanonicalizePass
+ : public PassWrapper<TestListenerCanonicalizePass, Pass> {
+ TestListenerCanonicalizePass() = default;
+ TestListenerCanonicalizePass(const TestListenerCanonicalizePass &other)
+ : PassWrapper(other) {}
+
+ StringRef getArgument() const final { return "test-listener-canonicalize"; }
+ StringRef getDescription() const final { return "Test canonicalize pass."; }
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+
+ void runOnOperation() override {
+ TestListener listener;
+ RewriteListener *listenerToUse = nullptr;
+ if (withListener)
+ listenerToUse = &listener;
+
+ RewritePatternSet patterns(&getContext());
+ for (Dialect *dialect : getContext().getLoadedDialects())
+ dialect->getCanonicalizationPatterns(patterns);
+ for (RegisteredOperationName op : getContext().getRegisteredOperations())
+ op.getCanonicalizationPatterns(patterns, &getContext());
+
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ GreedyRewriteConfig(),
+ listenerToUse)))
+ signalPassFailure();
+ }
+
+ Pass::Option<bool> withListener{
+ *this, "listener", llvm::cl::desc("Whether to run with a test listener"),
+ llvm::cl::init(false)};
+};
+
+struct TestListenerCSEPass : public PassWrapper<TestListenerCSEPass, Pass> {
+ TestListenerCSEPass() = default;
+ TestListenerCSEPass(const TestListenerCSEPass &other) : PassWrapper(other) {}
+
+ StringRef getArgument() const final { return "test-listener-cse"; }
+ StringRef getDescription() const final { return "Test CSE pass."; }
+ bool canScheduleOn(RegisteredOperationName opName) const override {
+ return true;
+ }
+
+ void runOnOperation() override {
+ TestListener listener;
+ RewriteListener *listenerToUse = nullptr;
+ if (withListener)
+ listenerToUse = &listener;
+
+ if (failed(eliminateCommonSubexpressions(getOperation(),
+ /*domInfo=*/nullptr,
+ listenerToUse)))
+ signalPassFailure();
+ }
+
+ Pass::Option<bool> withListener{
+ *this, "listener", llvm::cl::desc("Whether to run with a test listener"),
+ llvm::cl::init(false)};
+};
+
+} // namespace
+
+namespace mlir {
+namespace test_ext {
+void registerTestListenerPasses() {
+ PassRegistration<TestListenerCanonicalizePass>();
+ PassRegistration<TestListenerCSEPass>();
+}
+} // namespace test_ext
+} // namespace mlir
diff --git a/llvm-external-projects/iree-dialects/test/python/smoketest.py b/llvm-external-projects/iree-dialects/test/python/smoketest.py
index 6804fec..d8b1f97 100644
--- a/llvm-external-projects/iree-dialects/test/python/smoketest.py
+++ b/llvm-external-projects/iree-dialects/test/python/smoketest.py
@@ -3,11 +3,13 @@
import iree.compiler.ir
from iree.compiler.dialects import iree_input as iree_d
from iree.compiler.dialects import iree_linalg_ext
+from iree.compiler.dialects import iree_linalg_transform
from iree.compiler.dialects import iree_pydm as pydm_d
with iree.compiler.ir.Context() as ctx:
iree_d.register_dialect()
iree_linalg_ext.register_dialect()
+ iree_linalg_transform.register_dialect()
pydm_d.register_dialect()
# iree_pydm types.
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index 6aecef3..60e1afd 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -1,21 +1,29 @@
set(LIBS
- MLIRArithmetic
- MLIRControlFlow
- MLIRDialect
- MLIRLinalg
- MLIRMemRef
- MLIROptLib
- MLIRSCF
- MLIRSCFTransforms
- MLIRFunc
- MLIRTensor
- MLIRTransforms
+ # Local dialects.
IREEInputDialect
IREELinalgExtDialect
IREELinalgExtPasses
IREELinalgExtTransforms
+ IREELinalgTransformDialect
+ IREELinalgTransformDialectTransforms
+ IREELinalgTransformTestPasses
+ IREETransformsTestPasses
IREEPyDMDialect
IREEPyDMPasses
+ # Core dialects.
+ MLIRArithmetic
+ MLIRControlFlow
+ MLIRDialect
+ MLIRFunc
+ MLIRLinalg
+ MLIRMemRef
+ MLIROptLib
+ MLIRPDL
+ MLIRPDLInterp
+ MLIRSCF
+ MLIRSCFTransforms
+ MLIRTensor
+ MLIRTransforms
)
add_llvm_tool(iree-dialects-opt
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index d1e844b..e18941d 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -8,6 +8,8 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
+#include "iree-dialects/Dialect/LinalgTransform/Passes.h"
#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -15,6 +17,8 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -26,32 +30,58 @@
using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
+namespace mlir {
+namespace test_ext {
+/// Test passes, do not deserve an include.
+void registerTestLinalgTransformWrapScope();
+void registerTestListenerPasses();
+} // namespace test_ext
+} // namespace mlir
+
int main(int argc, char **argv) {
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
- registerTransformsPasses();
- registerSCFPasses();
-
- // Local dialects.
- mlir::iree_compiler::IREE::PYDM::registerPasses();
- mlir::iree_compiler::IREE::LinalgExt::registerPasses();
-
DialectRegistry registry;
registry.insert<
+ // clang-format off
// Local dialects
mlir::iree_compiler::IREE::Input::IREEInputDialect,
mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
mlir::iree_compiler::IREE::PYDM::IREEPyDMDialect,
+ mlir::linalg::transform::LinalgTransformDialect,
// Upstream dialects
- mlir::arith::ArithmeticDialect, mlir::cf::ControlFlowDialect,
- mlir::linalg::LinalgDialect, mlir::memref::MemRefDialect,
- mlir::func::FuncDialect, mlir::scf::SCFDialect,
- mlir::tensor::TensorDialect>();
+ mlir::arith::ArithmeticDialect,
+ mlir::AffineDialect,
+ mlir::cf::ControlFlowDialect,
+ mlir::func::FuncDialect,
+ mlir::linalg::LinalgDialect,
+ mlir::memref::MemRefDialect,
+ mlir::pdl::PDLDialect,
+ mlir::pdl_interp::PDLInterpDialect,
+ mlir::scf::SCFDialect,
+ mlir::tensor::TensorDialect
+ // clang-format on
+ >();
+ // Core dialect passes.
+ registerTransformsPasses();
+ registerSCFPasses();
+ // Local dialect passes.
+ mlir::iree_compiler::IREE::PYDM::registerPasses();
+ mlir::iree_compiler::IREE::LinalgExt::registerPasses();
+ mlir::linalg::transform::registerLinalgTransformInterpreterPass();
+ mlir::linalg::transform::registerLinalgTransformExpertExpansionPass();
+ mlir::linalg::transform::registerDropSchedulePass();
+ // Local test passes.
+ mlir::test_ext::registerTestLinalgTransformWrapScope();
+ mlir::test_ext::registerTestListenerPasses();
+
+ // External models.
IREE::LinalgExt::registerTiledOpInterfaceExternalModels(registry);
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
- /*preloadDialectsInContext=*/false));
+ // Note: without preloading, 3 tests fail atm.
+ /*preloadDialectsInContext=*/true));
}