Moves the linalg_ext dialect to iree_linalg_ext under the iree-dialects project. (#7657)
Mostly NFC:
* Pre-factors iree-dialects into a consistent state with respect to namespaces and directory layouts.
* Moves linalg_ext to a new dialect under iree-dialects. Does some adaptation to upstream style along the way.
* Redirects everything that was using it to use the new one.
* Enables tests for iree-dialects in the CMake CI (they were not enabled) and fixes some type prefixing that had drifted. (Will follow up by enabling them in the internal builds to guard against further regression.)
Non-NFC:
* When tiling, the old pass was directly generating flow.dispatch.workgroup.* ops to query the current workgroup id. We had been planning to add those to the input dialect, so I pulled part of that patch forward.
* Since these ops now lack lowerings, I was expecting to hit some test failures that would guide me to where to adapt, but this appears to be dead code (outside of integration tests?). We'll see what the CI says.
* Can keep coding on this patch to adapt whatever is needed.
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index a112719..f2c7ea6 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -6,7 +6,7 @@
licenses = ["notice"],
)
-exports_files(glob(["include/iree-dialects/Dialect/IREE/*.td"]))
+exports_files(glob(["include/iree-dialects/Dialect/Input/*.td"]))
exports_files(glob(["python/*.cpp"]))
@@ -50,17 +50,21 @@
filegroup(
name = "TdFilegroup",
srcs = glob([
- "include/iree-dialects/Dialect/IREE/*.td",
- "include/iree-dialects/Dialect/IREEPyDM/IR/*.td",
- "include/iree-dialects/Dialect/IREEPyDM/Transforms/*.td",
+ "include/iree-dialects/Dialect/Input/*.td",
+ "include/iree-dialects/Dialect/LinalgExt/IR/*.td",
+ "include/iree-dialects/Dialect/LinalgExt/Transforms/*.td",
+ "include/iree-dialects/Dialect/PyDM/IR/*.td",
+ "include/iree-dialects/Dialect/PyDM/Transforms/*.td",
]),
)
td_library(
name = "TdFiles",
srcs = glob([
- "include/iree-dialects/Dialect/IREE/*.td",
- "include/iree-dialects/Dialect/IREEPyDM/IR/*.td",
+ "include/iree-dialects/Dialect/Input/*.td",
+ "include/iree-dialects/Dialect/LinalgExt/IR/*.td",
+ "include/iree-dialects/Dialect/LinalgExt/Transforms/*.td",
+ "include/iree-dialects/Dialect/PyDM/IR/*.td",
"python/iree/compiler/dialects/*.td",
]) + [
"@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td",
@@ -73,52 +77,52 @@
)
################################################################################
-# IREE dialect
+# IREEInput dialect
################################################################################
gentbl_cc_library(
- name = "IREEOpsIncGen",
+ name = "IREEInputIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-dialect-decls"],
- "include/iree-dialects/Dialect/IREE/IREEOpsDialect.h.inc",
+ "include/iree-dialects/Dialect/Input/InputDialect.h.inc",
),
(
["-gen-dialect-defs"],
- "include/iree-dialects/Dialect/IREE/IREEOpsDialect.cpp.inc",
+ "include/iree-dialects/Dialect/Input/InputDialect.cpp.inc",
),
(
["-gen-op-decls"],
- "include/iree-dialects/Dialect/IREE/IREEOps.h.inc",
+ "include/iree-dialects/Dialect/Input/InputOps.h.inc",
),
(
["-gen-op-defs"],
- "include/iree-dialects/Dialect/IREE/IREEOps.cpp.inc",
+ "include/iree-dialects/Dialect/Input/InputOps.cpp.inc",
),
(
["-gen-typedef-decls"],
- "include/iree-dialects/Dialect/IREE/IREEOpsTypes.h.inc",
+ "include/iree-dialects/Dialect/Input/InputTypes.h.inc",
),
(
["-gen-typedef-defs"],
- "include/iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc",
+ "include/iree-dialects/Dialect/Input/InputTypes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "include/iree-dialects/Dialect/IREE/IREEOps.td",
+ td_file = "include/iree-dialects/Dialect/Input/InputOps.td",
deps = [":TdFiles"],
)
cc_library(
- name = "IREEDialect",
+ name = "IREEInputDialect",
srcs = glob([
- "lib/Dialect/IREE/*.cpp",
+ "lib/Dialect/Input/*.cpp",
]),
- hdrs = glob(["include/iree-dialects/Dialect/IREE/*.h"]),
+ hdrs = glob(["include/iree-dialects/Dialect/Input/*.h"]),
includes = ["include"],
deps = [
- ":IREEOpsIncGen",
+ ":IREEInputIncGen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
@@ -126,7 +130,7 @@
)
gentbl_filegroup(
- name = "IREEDialectPyGen",
+ name = "IREEInputDialectPyGen",
tbl_outs = [
(
[
@@ -137,13 +141,197 @@
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "python/iree/compiler/dialects/IreeBinding.td",
+ td_file = "python/iree/compiler/dialects/IreeInputBinding.td",
deps = [
":TdFiles",
],
)
################################################################################
+# IREELinalgExt Dialect
+################################################################################
+
+gentbl_cc_library(
+ name = "IREELinalgExtIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ [
+ "-dialect=iree_linalg_ext",
+ "-gen-dialect-decls",
+ ],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h.inc",
+ ),
+ (
+ [
+ "-dialect=iree_linalg_ext",
+ "-gen-dialect-defs",
+ ],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.cpp.inc",
+ ),
+ (
+ ["-gen-op-decls"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc",
+ ),
+ (
+ ["-gen-typedef-decls"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypes.h.inc",
+ ),
+ (
+ ["-gen-typedef-defs"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypes.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td",
+ deps = [
+ ":TdFiles",
+ "@llvm-project//mlir:CallInterfacesTdFiles",
+ "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "IREELinalgExtInterfacesIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc",
+ ),
+ (
+ ["-gen-type-interface-decls"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypeInterfaces.h.inc",
+ ),
+ (
+ ["-gen-type-interface-defs"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtTypeInterfaces.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td",
+ deps = [
+ ":TdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "IREELinalgExtTiledOpInterfacesIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.td",
+ deps = [
+ ":TdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "IREELinalgExtTransformsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-pass-decls"],
+ "include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.h.inc",
+ ),
+ (
+ ["-gen-pass-capi-header"],
+ "include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.capi.h.inc",
+ ),
+ (
+ ["-gen-pass-capi-impl"],
+ "include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.capi.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.td",
+ deps = [
+ ":TdFiles",
+ "@llvm-project//mlir:PassBaseTdFiles",
+ ],
+)
+
+cc_library(
+ name = "IREELinalgExtDialect",
+ srcs = glob([
+ "lib/Dialect/LinalgExt/IR/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/iree-dialects/Dialect/LinalgExt/IR/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":IREELinalgExtIncGen",
+ ":IREELinalgExtInterfacesIncGen",
+ ":IREELinalgExtTiledOpInterfacesIncGen",
+ ":IREELinalgExtTransformsIncGen",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:SideEffectInterfaces",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:ViewLikeInterface",
+ ],
+)
+
+cc_library(
+ name = "IREELinalgExtTransforms",
+ srcs = glob([
+ "lib/Dialect/LinalgExt/Transforms/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/iree-dialects/Dialect/LinalgExt/Transforms/*.h",
+ ]),
+ deps = [
+ ":IREEInputDialect",
+ ":IREELinalgExtDialect",
+ ":IREELinalgExtTransformsIncGen",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:ArithmeticDialect",
+ "@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:MathDialect",
+ "@llvm-project//mlir:MemRefDialect",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TransformUtils",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+################################################################################
# IREEPyDM Dialect
################################################################################
@@ -153,31 +341,31 @@
tbl_outs = [
(
["-gen-dialect-decls"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h.inc",
),
(
["-gen-dialect-defs"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.cpp.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.cpp.inc",
),
(
["-gen-op-decls"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/Ops.h.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h.inc",
),
(
["-gen-op-defs"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc",
),
(
["-gen-typedef-decls"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/Types.h.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMTypes.h.inc",
),
(
["-gen-typedef-defs"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/Types.cpp.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td",
+ td_file = "include/iree-dialects/Dialect/PyDM/IR/PyDMOps.td",
deps = [
":TdFiles",
"@llvm-project//mlir:CallInterfacesTdFiles",
@@ -191,23 +379,23 @@
tbl_outs = [
(
["-gen-op-interface-decls"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/OpInterfaces.h.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.h.inc",
),
(
["-gen-op-interface-defs"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/OpInterfaces.cpp.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.cpp.inc",
),
(
["-gen-type-interface-decls"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/TypeInterfaces.h.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.h.inc",
),
(
["-gen-type-interface-defs"],
- "include/iree-dialects/Dialect/IREEPyDM/IR/TypeInterfaces.cpp.inc",
+ "include/iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td",
+ td_file = "include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.td",
deps = [
":TdFiles",
],
@@ -219,19 +407,19 @@
tbl_outs = [
(
["-gen-pass-decls"],
- "include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h.inc",
+ "include/iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc",
),
(
["-gen-pass-capi-header"],
- "include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.capi.h.inc",
+ "include/iree-dialects/Dialect/PyDM/Transforms/Passes.capi.h.inc",
),
(
["-gen-pass-capi-impl"],
- "include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.capi.cpp.inc",
+ "include/iree-dialects/Dialect/PyDM/Transforms/Passes.capi.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td",
+ td_file = "include/iree-dialects/Dialect/PyDM/Transforms/Passes.td",
deps = [
":TdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
@@ -261,10 +449,10 @@
cc_library(
name = "IREEPyDMDialect",
srcs = glob([
- "lib/Dialect/IREEPyDM/IR/*.cpp",
+ "lib/Dialect/PyDM/IR/*.cpp",
]),
hdrs = glob([
- "include/iree-dialects/Dialect/IREEPyDM/IR/*.h",
+ "include/iree-dialects/Dialect/PyDM/IR/*.h",
]),
includes = ["include"],
deps = [
@@ -283,20 +471,20 @@
cc_library(
name = "IREEPyDMTransforms",
srcs = glob([
- "lib/Dialect/IREEPyDM/Transforms/*.cpp",
- "lib/Dialect/IREEPyDM/Transforms/Optimize/*.cpp",
- "lib/Dialect/IREEPyDM/Transforms/RTL/*.cpp",
- "lib/Dialect/IREEPyDM/Transforms/ToIREE/*.cpp",
- "lib/Dialect/IREEPyDM/Utils/*.cpp",
+ "lib/Dialect/PyDM/Transforms/*.cpp",
+ "lib/Dialect/PyDM/Transforms/Optimize/*.cpp",
+ "lib/Dialect/PyDM/Transforms/RTL/*.cpp",
+ "lib/Dialect/PyDM/Transforms/ToIREE/*.cpp",
+ "lib/Dialect/PyDM/Utils/*.cpp",
]),
hdrs = glob([
- "include/iree-dialects/Dialect/IREEPyDM/Transforms/*.h",
- "include/iree-dialects/Dialect/IREEPyDM/Transforms/**/*.h",
- "include/iree-dialects/Dialect/IREEPyDM/Utils/*.h",
- "lib/Dialect/IREEPyDM/Transforms/*.h",
+ "include/iree-dialects/Dialect/PyDM/Transforms/*.h",
+ "include/iree-dialects/Dialect/PyDM/Transforms/**/*.h",
+ "include/iree-dialects/Dialect/PyDM/Utils/*.h",
+ "lib/Dialect/PyDM/Transforms/*.h",
]),
deps = [
- ":IREEDialect",
+ ":IREEInputDialect",
":IREEPyDMDialect",
":IREEPyDMTransformsIncGen",
"@llvm-project//llvm:Support",
@@ -323,7 +511,7 @@
hdrs = glob(["include/iree-dialects-c/*.h"]),
includes = ["include"],
deps = [
- ":IREEDialect",
+ ":IREEInputDialect",
":IREEPyDMDialect",
":IREEPyDMTransforms",
"@llvm-project//mlir:CAPIIR",
@@ -343,15 +531,21 @@
],
tags = ["hostonly"],
deps = [
- ":IREEDialect",
+ ":IREEInputDialect",
+ ":IREELinalgExtDialect",
+ ":IREELinalgExtTransforms",
":IREEPyDMDialect",
":IREEPyDMTransforms",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
index a56ef4d..5b5d93d 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H
+#ifndef IREE_DIALECTS_C_DIALECTS_H
+#define IREE_DIALECTS_C_DIALECTS_H
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
@@ -19,7 +19,7 @@
// IREEDialect
//===----------------------------------------------------------------------===//
-MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREE, iree);
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREEInput, iree_input);
//===----------------------------------------------------------------------===//
// IREEPyDMDialect
@@ -111,4 +111,4 @@
}
#endif
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H
+#endif // IREE_DIALECTS_C_DIALECTS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
index 696f6ad..8d0252e 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_UTILS_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_UTILS_H
+#ifndef IREE_DIALECTS_C_UTILS_H
+#define IREE_DIALECTS_C_UTILS_H
#include "mlir-c/IR.h"
@@ -23,4 +23,4 @@
}
#endif
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_UTILS_H
+#endif // IREE_DIALECTS_C_UTILS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
index 61df04e..620c526 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
@@ -1,2 +1,3 @@
-add_subdirectory(IREE)
-add_subdirectory(IREEPyDM)
+add_subdirectory(Input)
+add_subdirectory(LinalgExt)
+add_subdirectory(PyDM)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/CMakeLists.txt
deleted file mode 100644
index 219c6d9..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/CMakeLists.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-add_mlir_dialect(IREEOps iree)
-add_mlir_doc(IREEDialect IREEDialect IREE/ -gen-dialect-doc)
-add_mlir_doc(IREEOps IREEOps IREE/ -gen-op-doc)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEBase.td
deleted file mode 100644
index 59160c5..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEBase.td
+++ /dev/null
@@ -1,114 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_BASE_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_BASE_TD
-
-include "mlir/IR/OpBase.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-
-def IREE_Dialect : Dialect {
- let name = "iree";
- let summary = "Public ops/type/attributes legal for input to IREE's compiler";
- let description = [{
- IREE's compiler allows as input a number of common dialects. This dialect
- contains structural and unique ops that do not exist elsewhere or that IREE
- has an interest in maintaining as a stable set.
-
- The contents of this dialect often mirror various constructs in IREE's
- internal implementation. The focus here is on simplicity and stability
- over time. Generally, this dialect does not use "advanced" features and
- should be broadly source compatible over a range of LLVM versions. There
- are of course, limits, and source-compatibility is not guaranteed, since
- LLVM/MLIR's API surface is itself unstable.
- }];
- let cppNamespace = "::mlir::iree";
-}
-
-class IREE_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<IREE_Dialect, mnemonic, traits>;
-class IREE_PureOp<string mnemonic, list<OpTrait> traits = []> :
- Op<IREE_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])>;
-class IREE_Type<string name> : TypeDef<IREE_Dialect, name>;
-
-//===----------------------------------------------------------------------===//
-// Predicates
-//===----------------------------------------------------------------------===//
-
-class IREE_AliasedSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
- "symbol reference attribute"> {
- let storageType = [{ FlatSymbolRefAttr }];
- let returnType = [{ StringRef }];
- let valueType = NoneType;
- let constBuilderCall = "mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
-}
-class IREE_AnyPtrOf<list<Type> types> :
- Type<And<[
- CPred<"$_self.isa<::mlir::iree::PtrType>()">,
- Or<!foreach(type, types,
- SubstLeaves<
- "$_self",
- "$_self.cast<::mlir::iree::PtrType>().getTargetType()",
- type.predicate>)>,
- ]>, !interleave(!foreach(type, types, type.summary), " or ")> {
- string builderCall = "";
-}
-
-def IREE_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>;
-def IREE_Tensor : TypeAlias<AnyRankedTensor>;
-
-def IREE_AnyList : DialectType<
- IREE_Dialect,
- CPred<"$_self.isa<::mlir::iree::ListType>()">,
- "list"> {
- let description = [{
- A mutable, resizable list of some type.
- }];
-}
-
-class IREE_ListOf<Type type> :
- Type<And<[
- CPred<"$_self.isa<::mlir::iree::ListType>()">,
- SubstLeaves<"$_self",
- "$_self.cast<::mlir::iree::ListType>().getElementType()",
- type.predicate>
- ]>, "list<" # type.summary # ">"> {
- // Set the builder call if the base type has a builder call.
- string builderCall = !if(!empty(type.builderCall),
- "", "::mlir::iree::ListType::get(" # type.builderCall # ")");
-}
-
-def IREE_ElementTypeParameter : TypeParameter<
- "::mlir::Type", "A type suitable as an element type of a container">;
-def IREE_PtrTargetTypeParameter : TypeParameter<
- "::mlir::Type", "A type suitable as a target type of a pointer">;
-
-def IREE_Dim : TypeAlias<Index>;
-def IREE_Dims : Variadic<IREE_Dim>;
-def IREE_Shape : Variadic<IREE_Dim>;
-def IREE_ShapeDynamicDims : Variadic<IREE_Dim>;
-
-def IREE_GlobalRefAttr : IREE_AliasedSymbolRefAttr;
-def IREE_AnyGlobalPtr : IREE_AnyPtrOf<[IREE_Tensor, IREE_PrimitiveType]>;
-
-class IREE_IndexAttrBase<string descr> :
- TypedAttrBase<
- Index, "IntegerAttr",
- And<[
- CPred<"$_self.isa<IntegerAttr>()">,
- CPred<"$_self.cast<IntegerAttr>().getType().isIndex()">,
- ]>,
- descr> {
- let returnType = [{ APInt }];
-}
-def IREE_IndexAttr : IREE_IndexAttrBase<"size_t">;
-
-def IREE_TiedOpStorageAttr :
- TypedArrayAttrBase<IREE_IndexAttr, "64-bit integer array attribute"> {
- let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
-}
-
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_BASE_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.h
deleted file mode 100644
index ec1bca7..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.h
+++ /dev/null
@@ -1,19 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_H
-
-#include "mlir/IR/Dialect.h"
-
-// Include generated dialect code (this comment blocks clang-format from
-// clobbering order).
-#include "iree-dialects/Dialect/IREE/IREEOpsDialect.h.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "iree-dialects/Dialect/IREE/IREEOpsTypes.h.inc"
-
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.h
deleted file mode 100644
index 9fd719e..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_H
-
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-
-#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/IREE/IREEOps.h.inc"
-
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td
deleted file mode 100644
index f937ef7..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEOps.td
+++ /dev/null
@@ -1,525 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD
-
-include "iree-dialects/Dialect/IREE/IREEDialect.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/SymbolInterfaces.td"
-
-def IREE_NullOp : IREE_PureOp<"null"> {
- let summary = "a null value";
- let description = [{
- Initializes reference and variant types with a null value.
- }];
-
- let results = (outs
- AnyType:$result
- );
-
- let assemblyFormat = [{
- attr-dict `:` type($result)
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// Casts
-//===----------------------------------------------------------------------===//
-
-def IREE_TensorToBufferViewOp : IREE_PureOp<"cast.tensor_to_buffer_view"> {
- let summary = "Casts a tensor to a BufferView, capturing dynamic dims";
- let arguments = (ins
- IREE_Tensor:$source,
- IREE_ShapeDynamicDims:$source_dims
- );
- let results = (outs IREE_BufferViewType:$target);
-
- let assemblyFormat = [{
- $source `:` type($source) (`{` $source_dims^ `}`)? `->` type($target)
- attr-dict-with-keyword
- }];
-}
-
-def IREE_BufferViewToTensorOp : IREE_PureOp<"cast.buffer_view_to_tensor"> {
- let summary = "Casts a BufferView to a tensor, providing dynamic dims";
- let arguments = (ins
- IREE_BufferViewType:$source,
- IREE_ShapeDynamicDims:$target_dims
- );
- let results = (outs IREE_Tensor:$target);
-
- let assemblyFormat = [{
- $source `:` type($source) `->` type($target) (`{` $target_dims^ `}`)?
- attr-dict-with-keyword
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// Global variables
-//===----------------------------------------------------------------------===//
-
-def IREE_GlobalOp : IREE_Op<"global", [
- Symbol,
- ]> {
- let summary = [{stateful global variable declaration}];
- let description = [{
- Declares a global variable that maintains its value across invocations.
- The value is tied to the execution context of the module and different
- contexts will have different global storage.
- }];
-
- let arguments = (ins
- OptionalAttr<StrAttr>:$sym_visibility,
- SymbolNameAttr:$sym_name,
- TypeAttr:$type,
- UnitAttr:$is_mutable,
- OptionalAttr<FlatSymbolRefAttr>:$initializer,
- OptionalAttr<AnyAttr>:$initial_value
- );
-
- let assemblyFormat = [{
- custom<SymbolVisibility>($sym_visibility)
- (`mutable` $is_mutable^)?
- $sym_name
- attr-dict
- (`initializer` `(` $initializer^ `)`):(``)?
- custom<TypeOrAttr>($type, $initial_value)
- }];
-}
-
-def IREE_GlobalAddressOp : IREE_PureOp<"global.address"> {
- let summary = [{returns an address reference to a global}];
- let description = [{
- Returns the address of a global as a typed reference. Can be used with the
- global load and store indirect ops.
- }];
-
- let arguments = (ins
- IREE_GlobalRefAttr:$global
- );
- let results = (outs
- IREE_AnyGlobalPtr:$result
- );
-
- let assemblyFormat = [{
- $global attr-dict `:` type($result)
- }];
-}
-
-def IREE_GlobalLoadOp : IREE_Op<"global.load"> {
- let summary = [{loads a value from a global variable}];
- let description = [{
- Returns a copy of the global value.
- }];
-
- let arguments = (ins
- IREE_GlobalRefAttr:$global
- );
- let results = (outs
- AnyType:$result
- );
-
- let assemblyFormat = [{
- $global attr-dict `:` type($result)
- }];
-}
-
-def IREE_GlobalLoadIndirectOp : IREE_Op<"global.load.indirect"> {
- let summary = [{loads a value from a global variable}];
- let description = [{
- Returns a copy of the global value.
- }];
-
- let arguments = (ins
- IREE_AnyGlobalPtr:$global
- );
- let results = (outs
- AnyType:$result
- );
-
- let assemblyFormat = [{
- $global attr-dict `:` type($global) `->` type($result)
- }];
-}
-
-def IREE_GlobalStoreOp : IREE_Op<"global.store"> {
- let summary = [{stores a value into a global variable}];
- let description = [{
- Stores a copy of the value into a global.
- }];
-
- let arguments = (ins
- AnyType:$value,
- IREE_GlobalRefAttr:$global
- );
-
- let assemblyFormat = [{
- $value `,` $global attr-dict `:` type($value)
- }];
-}
-
-def IREE_GlobalStoreIndirectOp : IREE_Op<"global.store.indirect"> {
- let summary = [{stores a value into a global variable}];
- let description = [{
- Stores a copy of the value into a global.
- }];
-
- let arguments = (ins
- AnyType:$value,
- IREE_AnyGlobalPtr:$global
- );
-
- let assemblyFormat = [{
- $value `,` $global attr-dict `:` type($value) `->` type($global)
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// Buffer Views
-//===----------------------------------------------------------------------===//
-
-def IREE_BufferViewRankOp : IREE_PureOp<"buffer_view.rank"> {
- let summary = [{buffer view rank query}];
- let description = [{
- Returns the rank of the buffer view.
- }];
-
- let arguments = (ins
- IREE_BufferViewType:$buffer_view
- );
- let results = (outs
- IREE_Dim:$result
- );
-
- let assemblyFormat = [{
- $buffer_view attr-dict `:` type($result)
- }];
-}
-
-def IREE_BufferViewDimOp : IREE_PureOp<"buffer_view.dim"> {
- let summary = [{buffer view dimension value query}];
- let description = [{
- Returns the value of the given dimension.
- }];
-
- let arguments = (ins
- IREE_BufferViewType:$buffer_view,
- IndexAttr:$index
- );
- let results = (outs
- IREE_Dim:$result
- );
-
- let assemblyFormat = [{
- $buffer_view `,` $index attr-dict `:` type($result)
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// Mutable Lists
-//===----------------------------------------------------------------------===//
-
-def IREE_ListCreateOp : IREE_PureOp<
- "list.create", [MemoryEffects<[MemAlloc]>]> {
- let summary = [{creates a new empty list}];
- let description = [{
- Creates a new empty list with an optional initial capacity.
- }];
-
- let arguments = (ins
- Optional<Index>:$initial_capacity
- );
- let results = (outs
- IREE_AnyList:$result
- );
-
- let assemblyFormat = "($initial_capacity^)? attr-dict `:` type($result)";
-}
-
-def IREE_ListSizeOp : IREE_Op<"list.size", [MemoryEffects<[MemRead]>]> {
- let summary = [{the size of the list in elements}];
- let description = [{
- Returns the current size of the list in elements.
- }];
-
- let arguments = (ins
- IREE_AnyList:$list
- );
- let results = (outs
- Index:$result
- );
-
- let assemblyFormat = "operands attr-dict `:` type($list)";
-}
-
-def IREE_ListResizeOp : IREE_Op<"list.resize", [MemoryEffects<[MemWrite]>]> {
- let summary = [{resizes the list to a new count in elements}];
- let description = [{
- Resizes the list to contain `new_size` elements. This will either truncate
- the list if the existing size is greater than `new_size` or extend the list
- with the default list value of the element type.
- }];
-
- let arguments = (ins
- IREE_AnyList:$list,
- Index:$new_size
- );
-
- let assemblyFormat = "operands attr-dict `:` type($list)";
-}
-
-def IREE_ListGetOp : IREE_Op<"list.get", [MemoryEffects<[MemRead]>]> {
- let summary = [{element accessor}];
- let description = [{
- Returns the value of the element at the given index. Note that the value
- may be null if the element is null or the type does not match.
- }];
-
- let arguments = (ins
- IREE_AnyList:$list,
- Index:$index
- );
- let results = (outs
- AnyType:$result
- );
-
- let assemblyFormat = "$list `[` $index `]` attr-dict `:` type($list) `->` type($result)";
-}
-
-def IREE_ListSetOp : IREE_Op<"list.set", [MemoryEffects<[MemWrite]>]> {
- let summary = [{element mutator}];
- let description = [{
- Sets the element at the given index to the new value.
- }];
-
- let arguments = (ins
- IREE_AnyList:$list,
- Index:$index,
- AnyType:$value
- );
-
- let assemblyFormat = "$list `[` $index `]` `,` $value attr-dict `:` type($list) `,` type($value)";
-}
-
-//===----------------------------------------------------------------------===//
-// Tensor ops
-//===----------------------------------------------------------------------===//
-
-def IREE_TensorReshapeOp : IREE_PureOp<"tensor.reshape", [
- AllElementTypesMatch<["source", "result"]>,
- AttrSizedOperandSegments,
- ]> {
- let summary = [{reshapes a tensor}];
- let description = [{
- Reshapes a tensor to a new shape without modifying the contents.
- }];
-
- let arguments = (ins
- IREE_Tensor:$source,
- IREE_ShapeDynamicDims:$source_dims,
- IREE_ShapeDynamicDims:$result_dims
- );
- let results = (outs
- IREE_Tensor:$result
- );
-
- let assemblyFormat = [{
- $source `:`
- type($source) (`{` $source_dims^ `}`)? `->`
- type($result) (`{` $result_dims^ `}`)?
- attr-dict-with-keyword
- }];
-}
-
-def IREE_TensorLoadOp : IREE_PureOp<"tensor.load", [
- TypesMatchWith<"value type matches element type of target operand",
- "source", "result",
- "$_self.cast<ShapedType>().getElementType()">,
- AttrSizedOperandSegments,
- ]> {
- let summary = [{loads a value from a tensor element}];
- let description = [{
- Returns the element at the given location from within the tensor.
- }];
-
- let arguments = (ins
- IREE_Tensor:$source,
- IREE_ShapeDynamicDims:$source_dims,
- Variadic<IREE_Dim>:$indices
- );
- let results = (outs
- AnyTypeOf<[IREE_PrimitiveType, AnyVector]>:$result
- );
-
- let assemblyFormat = [{
- $source (`[` $indices^ `]`)? `:`
- type($source) (`{` $source_dims^ `}`)?
- attr-dict-with-keyword
- }];
-
-}
-
-def IREE_TensorStoreOp : IREE_PureOp<"tensor.store", [
- AllTypesMatch<["target", "result"]>,
- TypesMatchWith<"value type matches element type of target operand",
- "target", "value",
- "$_self.cast<ShapedType>().getElementType()">,
- AttrSizedOperandSegments,
- ]> {
- let summary = [{stores a value into a tensor element}];
- let description = [{
- Returns a tensor with the element at the given index set to the given value.
- }];
-
- let arguments = (ins
- AnyTypeOf<[IREE_PrimitiveType, AnyVector]>:$value,
- IREE_Tensor:$target,
- IREE_ShapeDynamicDims:$target_dims,
- Variadic<IREE_Dim>:$indices
- );
- let results = (outs
- IREE_Tensor:$result
- );
-
- let assemblyFormat = [{
- $value `,` $target (`[` $indices^ `]`)? `:`
- type($target) (`{` $target_dims^ `}`)?
- attr-dict-with-keyword
- }];
-}
-
-def IREE_TensorSplatOp : IREE_PureOp<"tensor.splat", [
- TypesMatchWith<"value type matches element type of result",
- "result", "value",
- "$_self.cast<ShapedType>().getElementType()">,
- ]> {
- let summary = [{splats a value into a shaped tensor}];
- let description = [{
- Returns a tensor initialized to the given primitive value.
- }];
-
- let arguments = (ins
- IREE_PrimitiveType:$value,
- IREE_ShapeDynamicDims:$result_dims
- );
- let results = (outs
- IREE_Tensor:$result
- );
-
- let assemblyFormat = [{
- $value `:` type($result) (`{` $result_dims^ `}`)?
- attr-dict-with-keyword
- }];
-}
-
-def IREE_TensorCloneOp : IREE_PureOp<"tensor.clone", [
- AllTypesMatch<["operand", "result"]>,
- ]> {
- let summary = [{performs a full tensor clone operation}];
- let description = [{
- Clones the input tensor into an identical output tensor.
- }];
-
- let arguments = (ins
- IREE_Tensor:$operand,
- IREE_ShapeDynamicDims:$operand_dims
- );
- let results = (outs
- IREE_Tensor:$result
- );
-
- let assemblyFormat = [{
- $operand `:` type($result) (`{` $operand_dims^ `}`)?
- attr-dict-with-keyword
- }];
-}
-
-def IREE_TensorSliceOp : IREE_PureOp<"tensor.slice", [
- AllRanksMatch<["source", "result"]>,
- AllElementTypesMatch<["source", "result"]>,
- AttrSizedOperandSegments,
- ]> {
- let summary = [{slices out a subregion of a tensor}];
- let description = [{
- Clones a subregion of a tensor.
- }];
-
- let arguments = (ins
- IREE_Tensor:$source,
- IREE_ShapeDynamicDims:$source_dims,
- Variadic<IREE_Dim>:$start_indices,
- Variadic<IREE_Dim>:$lengths,
- IREE_ShapeDynamicDims:$result_dims
- );
- let results = (outs
- IREE_Tensor:$result
- );
-
- let assemblyFormat = [{
- $source `[` $start_indices `for` $lengths `]` `:`
- type($source) (`{` $source_dims^ `}`)? `->`
- type($result) (`{` $result_dims^ `}`)?
- attr-dict-with-keyword
- }];
-}
-
-def IREE_TensorUpdateOp : IREE_PureOp<"tensor.update", [
- AllRanksMatch<["update", "target", "result"]>,
- AllTypesMatch<["target", "result"]>,
- AllElementTypesMatch<["update", "target", "result"]>,
- AttrSizedOperandSegments,
- ]> {
- let summary = [{updates a tensor with the contents of another tensor}];
- let description = [{
- Updates the target tensor with the contents of the update tensor at the
- given offset indices.
- }];
-
- let arguments = (ins
- IREE_Tensor:$target,
- IREE_ShapeDynamicDims:$target_dims,
- Variadic<IREE_Dim>:$start_indices,
- IREE_Tensor:$update,
- IREE_ShapeDynamicDims:$update_dims
- );
- let results = (outs
- IREE_Tensor:$result
- );
-
- let assemblyFormat = [{
- $update `,` $target `[` $start_indices `]` `:`
- type($update) (`{` $update_dims^ `}`)? `->`
- type($result) (`{` $target_dims^ `}`)?
- attr-dict-with-keyword
- }];
-
- let builders = [
- OpBuilder<(ins
- "Value":$target,
- "ValueRange":$start_indices,
- "Value":$update)>,
- ];
-}
-
-def IREE_TensorTraceOp : IREE_Op<"tensor.trace", []> {
- let summary = [{trace value(s) operation}];
- let description = [{
- Traces out to a runtime trace sink (console, log file, etc) the given
- tensors and titles them with the given key. The key is informational only
- and useful for titling/marking specific sets of tensors for easier
- searching.
- }];
-
- let arguments = (ins
- StrAttr:$key,
- Variadic<IREE_Tensor>:$operands
- );
-
- let assemblyFormat = "$key attr-dict ($operands^ `:` type($operands))?";
-}
-
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_OPS_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/CMakeLists.txt
deleted file mode 100644
index ef138f8..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/CMakeLists.txt
+++ /dev/null
@@ -1,24 +0,0 @@
-function(_add_interfaces)
- set(LLVM_TARGET_DEFINITIONS Interfaces.td)
- mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
- mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
- mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
- mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
- add_public_tablegen_target(MLIRIREEPyDMInterfacesIncGen)
-endfunction()
-
-function(_add_dialect)
- set(LLVM_TARGET_DEFINITIONS Ops.td)
- mlir_tablegen(Ops.h.inc -gen-op-decls)
- mlir_tablegen(Ops.cpp.inc -gen-op-defs)
- mlir_tablegen(Types.h.inc -gen-typedef-decls)
- mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
- mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=iree_pydm)
- mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=iree_pydm)
- add_public_tablegen_target(MLIRIREEPyDMOpsIncGen)
- add_dependencies(MLIRIREEPyDMOpsIncGen MLIRIREEPyDMInterfacesIncGen)
- add_dependencies(mlir-headers MLIRIREEPyDMOpsIncGen)
-endfunction()
-
-_add_dialect()
-_add_interfaces()
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h
deleted file mode 100644
index 4b66798..0000000
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_H
-
-#include "iree-dialects/Dialect/IREEPyDM/IR/Constants.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace PYDM {
-
-enum class BuiltinTypeCode;
-
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#include "iree-dialects/Dialect/IREEPyDM/IR/OpInterfaces.h.inc"
-#include "iree-dialects/Dialect/IREEPyDM/IR/TypeInterfaces.h.inc"
-
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/CMakeLists.txt
new file mode 100644
index 0000000..e43beef
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/CMakeLists.txt
@@ -0,0 +1,13 @@
+function(_add_dialect)
+ set(LLVM_TARGET_DEFINITIONS InputOps.td)
+ mlir_tablegen(InputOps.h.inc -gen-op-decls)
+ mlir_tablegen(InputOps.cpp.inc -gen-op-defs)
+ mlir_tablegen(InputTypes.h.inc -gen-typedef-decls)
+ mlir_tablegen(InputTypes.cpp.inc -gen-typedef-defs)
+ mlir_tablegen(InputDialect.h.inc -gen-dialect-decls -dialect=iree_input)
+ mlir_tablegen(InputDialect.cpp.inc -gen-dialect-defs -dialect=iree_input)
+ add_public_tablegen_target(IREEInputIncGen)
+ add_dependencies(mlir-headers IREEInputIncGen)
+endfunction()
+
+_add_dialect()
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputBase.td
new file mode 100644
index 0000000..6ed641b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputBase.td
@@ -0,0 +1,114 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_INPUT_BASE_TD
+#define IREE_DIALECTS_DIALECT_INPUT_BASE_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def IREEInput_Dialect : Dialect {
+ let name = "iree_input";
+ let summary = "Public ops/type/attributes legal for input to IREE's compiler";
+ let description = [{
+ IREE's compiler allows as input a number of common dialects. This dialect
+ contains structural and unique ops that do not exist elsewhere or that IREE
+ has an interest in maintaining as a stable set.
+
+ The contents of this dialect often mirror various constructs in IREE's
+ internal implementation. The focus here is on simplicity and stability
+ over time. Generally, this dialect does not use "advanced" features and
+ should be broadly source compatible over a range of LLVM versions. There
+ are of course, limits, and source-compatibility is not guaranteed, since
+ LLVM/MLIR's API surface is itself unstable.
+ }];
+ let cppNamespace = "::mlir::iree_compiler::IREE::Input";
+}
+
+class IREEInput_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREEInput_Dialect, mnemonic, traits>;
+class IREEInput_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREEInput_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])>;
+class IREEInput_Type<string name> : TypeDef<IREEInput_Dialect, name>;
+
+//===----------------------------------------------------------------------===//
+// Predicates
+//===----------------------------------------------------------------------===//
+
+class IREEInput_AliasedSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
+ "symbol reference attribute"> {
+ let storageType = [{ FlatSymbolRefAttr }];
+ let returnType = [{ StringRef }];
+ let valueType = NoneType;
+ let constBuilderCall = "mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
+}
+class IREEInput_AnyPtrOf<list<Type> types> :
+ Type<And<[
+ CPred<"$_self.isa<::mlir::iree_compiler::IREE::Input::PtrType>()">,
+ Or<!foreach(type, types,
+ SubstLeaves<
+ "$_self",
+ "$_self.cast<::mlir::iree_compiler::IREE::Input::PtrType>().getTargetType()",
+ type.predicate>)>,
+ ]>, !interleave(!foreach(type, types, type.summary), " or ")> {
+ string builderCall = "";
+}
+
+def IREEInput_PrimitiveType : AnyTypeOf<[Index, AnySignlessInteger, AnyFloat]>;
+def IREEInput_Tensor : TypeAlias<AnyRankedTensor>;
+
+def IREEInput_AnyList : DialectType<
+ IREEInput_Dialect,
+ CPred<"$_self.isa<::mlir::iree_compiler::IREE::Input::ListType>()">,
+ "list"> {
+ let description = [{
+ A mutable, resizable list of some type.
+ }];
+}
+
+class IREEInput_ListOf<Type type> :
+ Type<And<[
+ CPred<"$_self.isa<::mlir::iree_compiler::IREE::Input::ListType>()">,
+ SubstLeaves<"$_self",
+ "$_self.cast<::mlir::iree_compiler::IREE::Input::ListType>().getElementType()",
+ type.predicate>
+ ]>, "list<" # type.summary # ">"> {
+ // Set the builder call if the base type has a builder call.
+ string builderCall = !if(!empty(type.builderCall),
+ "", "::mlir::iree_compiler::IREE::Input::ListType::get(" # type.builderCall # ")");
+}
+
+def IREEInput_ElementTypeParameter : TypeParameter<
+ "::mlir::Type", "A type suitable as an element type of a container">;
+def IREEInput_PtrTargetTypeParameter : TypeParameter<
+ "::mlir::Type", "A type suitable as a target type of a pointer">;
+
+def IREEInput_Dim : TypeAlias<Index>;
+def IREEInput_Dims : Variadic<IREEInput_Dim>;
+def IREEInput_Shape : Variadic<IREEInput_Dim>;
+def IREEInput_ShapeDynamicDims : Variadic<IREEInput_Dim>;
+
+def IREEInput_GlobalRefAttr : IREEInput_AliasedSymbolRefAttr;
+def IREEInput_AnyGlobalPtr : IREEInput_AnyPtrOf<[IREEInput_Tensor, IREEInput_PrimitiveType]>;
+
+class IREEInput_IndexAttrBase<string descr> :
+ TypedAttrBase<
+ Index, "IntegerAttr",
+ And<[
+ CPred<"$_self.isa<IntegerAttr>()">,
+ CPred<"$_self.cast<IntegerAttr>().getType().isIndex()">,
+ ]>,
+ descr> {
+ let returnType = [{ APInt }];
+}
+def IREEInput_IndexAttr : IREEInput_IndexAttrBase<"size_t">;
+
+def IREEInput_TiedOpStorageAttr :
+ TypedArrayAttrBase<IREEInput_IndexAttr, "64-bit integer array attribute"> {
+ let constBuilderCall = "$_builder.getIndexArrayAttr($0)";
+}
+
+#endif // IREE_DIALECTS_DIALECT_INPUT_BASE_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h
new file mode 100644
index 0000000..4d6dbb7
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h
@@ -0,0 +1,19 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_INPUT_DIALECT_H
+#define IREE_DIALECTS_DIALECT_INPUT_DIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+// Include generated dialect code (this comment blocks clang-format from
+// clobbering order).
+#include "iree-dialects/Dialect/Input/InputDialect.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "iree-dialects/Dialect/Input/InputTypes.h.inc"
+
+#endif // IREE_DIALECTS_DIALECT_INPUT_DIALECT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.td
similarity index 73%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.td
index 691359a..cde0652 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREE/IREEDialect.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.td
@@ -4,16 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_TD
+#ifndef IREE_DIALECTS_DIALECT_INPUT_DIALECT_TD
+#define IREE_DIALECTS_DIALECT_INPUT_DIALECT_TD
-include "iree-dialects/Dialect/IREE/IREEBase.td"
+include "iree-dialects/Dialect/Input/InputBase.td"
//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//
-def IREE_BufferViewType : IREE_Type<"BufferView"> {
+def IREEInput_BufferViewType : IREEInput_Type<"BufferView"> {
let mnemonic = "buffer_view";
let summary = "View into a buffer, with runtime shape and element type";
@@ -32,16 +32,9 @@
to higher level code to ensure that aliasing rules are enforced at such
boundaries.
}];
- let printer = [{
- $_printer << "buffer_view";
- }];
-
- let parser = [{
- return get($_ctxt);
- }];
}
-def IREE_VariantType : IREE_Type<"Variant"> {
+def IREEInput_VariantType : IREEInput_Type<"Variant"> {
let mnemonic = "variant";
let summary = "Represents any legal or reference type in the IREE runtime";
@@ -50,16 +43,9 @@
The variant type is typically used to parameterize container types that
can contain any legal primitive, reference or null in the IREE type system.
}];
- let printer = [{
- $_printer << "variant";
- }];
-
- let parser = [{
- return get($_ctxt);
- }];
}
-def IREE_ListType : IREE_Type<"List"> {
+def IREEInput_ListType : IREEInput_Type<"List"> {
let mnemonic = "list";
let summary = "A one dimensional list of runtime values";
@@ -73,10 +59,10 @@
by parameterizing them with a VariantType.
}];
- let parameters = (ins IREE_ElementTypeParameter:$elementType);
+ let parameters = (ins IREEInput_ElementTypeParameter:$elementType);
let printer = [{
- $_printer << "list<" << getElementType() << ">";
+ $_printer << "<" << getElementType() << ">";
}];
let parser = [{
@@ -88,14 +74,14 @@
}];
}
-def IREE_PtrType : IREE_Type<"Ptr"> {
+def IREEInput_PtrType : IREEInput_Type<"Ptr"> {
let mnemonic = "ptr";
let summary = "Pointer to a concrete type";
- let parameters = (ins IREE_PtrTargetTypeParameter:$targetType);
+ let parameters = (ins IREEInput_PtrTargetTypeParameter:$targetType);
let printer = [{
- $_printer << "ptr<" << getTargetType() << ">";
+ $_printer << "<" << getTargetType() << ">";
}];
let parser = [{
@@ -107,4 +93,4 @@
}];
}
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREE_IREE_DIALECT_TD
+#endif // IREE_DIALECTS_DIALECT_INPUT_DIALECT_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h
new file mode 100644
index 0000000..6d7cc49
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h
@@ -0,0 +1,21 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_INPUT_OPS_H
+#define IREE_DIALECTS_DIALECT_INPUT_OPS_H
+
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/Input/InputOps.h.inc"
+
+#endif // IREE_DIALECTS_DIALECT_INPUT_OPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td
new file mode 100644
index 0000000..e2bd966
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td
@@ -0,0 +1,631 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_INPUT_OPS_TD
+#define IREE_DIALECTS_DIALECT_INPUT_OPS_TD
+
+include "iree-dialects/Dialect/Input/InputDialect.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+def IREEInput_NullOp : IREEInput_PureOp<"null"> {
+ let summary = "a null value";
+ let description = [{
+ Initializes reference and variant types with a null value.
+ }];
+
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ attr-dict `:` type($result)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Casts
+//===----------------------------------------------------------------------===//
+
+def IREEInput_TensorToBufferViewOp : IREEInput_PureOp<"cast.tensor_to_buffer_view"> {
+ let summary = "Casts a tensor to a BufferView, capturing dynamic dims";
+ let arguments = (ins
+ IREEInput_Tensor:$source,
+ IREEInput_ShapeDynamicDims:$source_dims
+ );
+ let results = (outs IREEInput_BufferViewType:$target);
+
+ let assemblyFormat = [{
+ $source `:` type($source) (`{` $source_dims^ `}`)? `->` type($target)
+ attr-dict-with-keyword
+ }];
+}
+
+def IREEInput_BufferViewToTensorOp : IREEInput_PureOp<"cast.buffer_view_to_tensor"> {
+ let summary = "Casts a BufferView to a tensor, providing dynamic dims";
+ let arguments = (ins
+ IREEInput_BufferViewType:$source,
+ IREEInput_ShapeDynamicDims:$target_dims
+ );
+ let results = (outs IREEInput_Tensor:$target);
+
+ let assemblyFormat = [{
+ $source `:` type($source) `->` type($target) (`{` $target_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Global variables
+//===----------------------------------------------------------------------===//
+
+def IREEInput_GlobalOp : IREEInput_Op<"global", [
+ Symbol,
+ ]> {
+ let summary = [{stateful global variable declaration}];
+ let description = [{
+ Declares a global variable that maintains its value across invocations.
+ The value is tied to the execution context of the module and different
+ contexts will have different global storage.
+ }];
+
+ let arguments = (ins
+ OptionalAttr<StrAttr>:$sym_visibility,
+ SymbolNameAttr:$sym_name,
+ TypeAttr:$type,
+ UnitAttr:$is_mutable,
+ OptionalAttr<FlatSymbolRefAttr>:$initializer,
+ OptionalAttr<AnyAttr>:$initial_value
+ );
+
+ let assemblyFormat = [{
+ custom<SymbolVisibility>($sym_visibility)
+ (`mutable` $is_mutable^)?
+ $sym_name
+ attr-dict
+ (`initializer` `(` $initializer^ `)`):(``)?
+ custom<TypeOrAttr>($type, $initial_value)
+ }];
+}
+
+def IREEInput_GlobalAddressOp : IREEInput_PureOp<"global.address"> {
+ let summary = [{returns an address reference to a global}];
+ let description = [{
+ Returns the address of a global as a typed reference. Can be used with the
+ global load and store indirect ops.
+ }];
+
+ let arguments = (ins
+ IREEInput_GlobalRefAttr:$global
+ );
+ let results = (outs
+ IREEInput_AnyGlobalPtr:$result
+ );
+
+ let assemblyFormat = [{
+ $global attr-dict `:` type($result)
+ }];
+}
+
+def IREEInput_GlobalLoadOp : IREEInput_Op<"global.load"> {
+ let summary = [{loads a value from a global variable}];
+ let description = [{
+ Returns a copy of the global value.
+ }];
+
+ let arguments = (ins
+ IREEInput_GlobalRefAttr:$global
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ $global attr-dict `:` type($result)
+ }];
+}
+
+def IREEInput_GlobalLoadIndirectOp : IREEInput_Op<"global.load.indirect"> {
+ let summary = [{loads a value from a global variable through a pointer}];
+ let description = [{
+ Returns a copy of the value of the global referenced by the given pointer.
+ }];
+
+ let arguments = (ins
+ IREEInput_AnyGlobalPtr:$global
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = [{
+ $global attr-dict `:` type($global) `->` type($result)
+ }];
+}
+
+def IREEInput_GlobalStoreOp : IREEInput_Op<"global.store"> {
+ let summary = [{stores a value into a global variable}];
+ let description = [{
+ Stores a copy of the value into a global.
+ }];
+
+ let arguments = (ins
+ AnyType:$value,
+ IREEInput_GlobalRefAttr:$global
+ );
+
+ let assemblyFormat = [{
+ $value `,` $global attr-dict `:` type($value)
+ }];
+}
+
+def IREEInput_GlobalStoreIndirectOp : IREEInput_Op<"global.store.indirect"> {
+ let summary = [{stores a value into a global variable through a pointer}];
+ let description = [{
+ Stores a copy of the value into the global referenced by the given pointer.
+ }];
+
+ let arguments = (ins
+ AnyType:$value,
+ IREEInput_AnyGlobalPtr:$global
+ );
+
+ let assemblyFormat = [{
+ $value `,` $global attr-dict `:` type($value) `->` type($global)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Buffer Views
+//===----------------------------------------------------------------------===//
+
+def IREEInput_BufferViewRankOp : IREEInput_PureOp<"buffer_view.rank"> {
+ let summary = [{buffer view rank query}];
+ let description = [{
+ Returns the rank of the buffer view.
+ }];
+
+ let arguments = (ins
+ IREEInput_BufferViewType:$buffer_view
+ );
+ let results = (outs
+ IREEInput_Dim:$result
+ );
+
+ let assemblyFormat = [{
+ $buffer_view attr-dict `:` type($result)
+ }];
+}
+
+def IREEInput_BufferViewDimOp : IREEInput_PureOp<"buffer_view.dim"> {
+ let summary = [{buffer view dimension value query}];
+ let description = [{
+ Returns the value of the given dimension.
+ }];
+
+ let arguments = (ins
+ IREEInput_BufferViewType:$buffer_view,
+ IndexAttr:$index
+ );
+ let results = (outs
+ IREEInput_Dim:$result
+ );
+
+ let assemblyFormat = [{
+ $buffer_view `,` $index attr-dict `:` type($result)
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Mutable Lists
+//===----------------------------------------------------------------------===//
+
+def IREEInput_ListCreateOp : IREEInput_PureOp<
+ "list.create", [MemoryEffects<[MemAlloc]>]> {
+ let summary = [{creates a new empty list}];
+ let description = [{
+ Creates a new empty list with an optional initial capacity.
+ }];
+
+ let arguments = (ins
+ Optional<Index>:$initial_capacity
+ );
+ let results = (outs
+ IREEInput_AnyList:$result
+ );
+
+ let assemblyFormat = "($initial_capacity^)? attr-dict `:` type($result)";
+}
+
+def IREEInput_ListSizeOp : IREEInput_Op<"list.size", [MemoryEffects<[MemRead]>]> {
+ let summary = [{the size of the list in elements}];
+ let description = [{
+ Returns the current size of the list in elements.
+ }];
+
+ let arguments = (ins
+ IREEInput_AnyList:$list
+ );
+ let results = (outs
+ Index:$result
+ );
+
+ let assemblyFormat = "operands attr-dict `:` type($list)";
+}
+
+def IREEInput_ListResizeOp : IREEInput_Op<"list.resize", [MemoryEffects<[MemWrite]>]> {
+ let summary = [{resizes the list to a new count in elements}];
+ let description = [{
+ Resizes the list to contain `new_size` elements. This will either truncate
+ the list if the existing size is greater than `new_size` or extend the list
+ with the default list value of the element type.
+ }];
+
+ let arguments = (ins
+ IREEInput_AnyList:$list,
+ Index:$new_size
+ );
+
+ let assemblyFormat = "operands attr-dict `:` type($list)";
+}
+
+def IREEInput_ListGetOp : IREEInput_Op<"list.get", [MemoryEffects<[MemRead]>]> {
+ let summary = [{element accessor}];
+ let description = [{
+ Returns the value of the element at the given index. Note that the value
+ may be null if the element is null or the type does not match.
+ }];
+
+ let arguments = (ins
+ IREEInput_AnyList:$list,
+ Index:$index
+ );
+ let results = (outs
+ AnyType:$result
+ );
+
+ let assemblyFormat = "$list `[` $index `]` attr-dict `:` type($list) `->` type($result)";
+}
+
+def IREEInput_ListSetOp : IREEInput_Op<"list.set", [MemoryEffects<[MemWrite]>]> {
+ let summary = [{element mutator}];
+ let description = [{
+ Sets the element at the given index to the new value.
+ }];
+
+ let arguments = (ins
+ IREEInput_AnyList:$list,
+ Index:$index,
+ AnyType:$value
+ );
+
+ let assemblyFormat = "$list `[` $index `]` `,` $value attr-dict `:` type($list) `,` type($value)";
+}
+
+//===----------------------------------------------------------------------===//
+// Tensor ops
+//===----------------------------------------------------------------------===//
+
+def IREEInput_TensorReshapeOp : IREEInput_PureOp<"tensor.reshape", [
+ AllElementTypesMatch<["source", "result"]>,
+ AttrSizedOperandSegments,
+ ]> {
+ let summary = [{reshapes a tensor}];
+ let description = [{
+ Reshapes a tensor to a new shape without modifying the contents.
+ }];
+
+ let arguments = (ins
+ IREEInput_Tensor:$source,
+ IREEInput_ShapeDynamicDims:$source_dims,
+ IREEInput_ShapeDynamicDims:$result_dims
+ );
+ let results = (outs
+ IREEInput_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ $source `:`
+ type($source) (`{` $source_dims^ `}`)? `->`
+ type($result) (`{` $result_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+}
+
+def IREEInput_TensorLoadOp : IREEInput_PureOp<"tensor.load", [
+ TypesMatchWith<"result type matches element type of source operand",
+ "source", "result",
+ "$_self.cast<ShapedType>().getElementType()">,
+ AttrSizedOperandSegments,
+ ]> {
+ let summary = [{loads a value from a tensor element}];
+ let description = [{
+ Returns the element at the given location from within the tensor.
+ }];
+
+ let arguments = (ins
+ IREEInput_Tensor:$source,
+ IREEInput_ShapeDynamicDims:$source_dims,
+ Variadic<IREEInput_Dim>:$indices
+ );
+ let results = (outs
+ AnyTypeOf<[IREEInput_PrimitiveType, AnyVector]>:$result
+ );
+
+ let assemblyFormat = [{
+ $source (`[` $indices^ `]`)? `:`
+ type($source) (`{` $source_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+
+}
+
+def IREEInput_TensorStoreOp : IREEInput_PureOp<"tensor.store", [
+ AllTypesMatch<["target", "result"]>,
+ TypesMatchWith<"value type matches element type of target operand",
+ "target", "value",
+ "$_self.cast<ShapedType>().getElementType()">,
+ AttrSizedOperandSegments,
+ ]> {
+ let summary = [{stores a value into a tensor element}];
+ let description = [{
+ Returns a tensor with the element at the given index set to the given value.
+ }];
+
+ let arguments = (ins
+ AnyTypeOf<[IREEInput_PrimitiveType, AnyVector]>:$value,
+ IREEInput_Tensor:$target,
+ IREEInput_ShapeDynamicDims:$target_dims,
+ Variadic<IREEInput_Dim>:$indices
+ );
+ let results = (outs
+ IREEInput_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ $value `,` $target (`[` $indices^ `]`)? `:`
+ type($target) (`{` $target_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+}
+
+def IREEInput_TensorSplatOp : IREEInput_PureOp<"tensor.splat", [
+ TypesMatchWith<"value type matches element type of result",
+ "result", "value",
+ "$_self.cast<ShapedType>().getElementType()">,
+ ]> {
+ let summary = [{splats a value into a shaped tensor}];
+ let description = [{
+ Returns a tensor initialized to the given primitive value.
+ }];
+
+ let arguments = (ins
+ IREEInput_PrimitiveType:$value,
+ IREEInput_ShapeDynamicDims:$result_dims
+ );
+ let results = (outs
+ IREEInput_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ $value `:` type($result) (`{` $result_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+}
+
+def IREEInput_TensorCloneOp : IREEInput_PureOp<"tensor.clone", [
+ AllTypesMatch<["operand", "result"]>,
+ ]> {
+ let summary = [{performs a full tensor clone operation}];
+ let description = [{
+ Clones the input tensor into an identical output tensor.
+ }];
+
+ let arguments = (ins
+ IREEInput_Tensor:$operand,
+ IREEInput_ShapeDynamicDims:$operand_dims
+ );
+ let results = (outs
+ IREEInput_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ $operand `:` type($result) (`{` $operand_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+}
+
+def IREEInput_TensorSliceOp : IREEInput_PureOp<"tensor.slice", [
+ AllRanksMatch<["source", "result"]>,
+ AllElementTypesMatch<["source", "result"]>,
+ AttrSizedOperandSegments,
+ ]> {
+ let summary = [{slices out a subregion of a tensor}];
+ let description = [{
+ Clones a subregion of a tensor.
+ }];
+
+ let arguments = (ins
+ IREEInput_Tensor:$source,
+ IREEInput_ShapeDynamicDims:$source_dims,
+ Variadic<IREEInput_Dim>:$start_indices,
+ Variadic<IREEInput_Dim>:$lengths,
+ IREEInput_ShapeDynamicDims:$result_dims
+ );
+ let results = (outs
+ IREEInput_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ $source `[` $start_indices `for` $lengths `]` `:`
+ type($source) (`{` $source_dims^ `}`)? `->`
+ type($result) (`{` $result_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+}
+
+def IREEInput_TensorUpdateOp : IREEInput_PureOp<"tensor.update", [
+ AllRanksMatch<["update", "target", "result"]>,
+ AllTypesMatch<["target", "result"]>,
+ AllElementTypesMatch<["update", "target", "result"]>,
+ AttrSizedOperandSegments,
+ ]> {
+ let summary = [{updates a tensor with the contents of another tensor}];
+ let description = [{
+ Updates the target tensor with the contents of the update tensor at the
+ given offset indices.
+ }];
+
+ let arguments = (ins
+ IREEInput_Tensor:$target,
+ IREEInput_ShapeDynamicDims:$target_dims,
+ Variadic<IREEInput_Dim>:$start_indices,
+ IREEInput_Tensor:$update,
+ IREEInput_ShapeDynamicDims:$update_dims
+ );
+ let results = (outs
+ IREEInput_Tensor:$result
+ );
+
+ let assemblyFormat = [{
+ $update `,` $target `[` $start_indices `]` `:`
+ type($update) (`{` $update_dims^ `}`)? `->`
+ type($result) (`{` $target_dims^ `}`)?
+ attr-dict-with-keyword
+ }];
+
+ let builders = [
+ OpBuilder<(ins
+ "Value":$target,
+ "ValueRange":$start_indices,
+ "Value":$update)>,
+ ];
+}
+
+def IREEInput_TensorTraceOp : IREEInput_Op<"tensor.trace", []> {
+ let summary = [{trace value(s) operation}];
+ let description = [{
+ Traces out to a runtime trace sink (console, log file, etc) the given
+ tensors and titles them with the given key. The key is informational only
+ and useful for titling/marking specific sets of tensors for easier
+ searching.
+ }];
+
+ let arguments = (ins
+ StrAttr:$key,
+ Variadic<IREEInput_Tensor>:$operands
+ );
+
+ let assemblyFormat = "$key attr-dict ($operands^ `:` type($operands))?";
+}
+
+//===----------------------------------------------------------------------===//
+// Workgroup dispatch
+// These ops allow both scheduling and accessing parameters of workgroup
+// dispatches across an arbitrary nd-grid.
+//===----------------------------------------------------------------------===//
+
+// TODO: Define dispatch.workgroup op.
+
+def IREEInput_DispatchWorkgroupRankOp : IREEInput_PureOp<"dispatch.workgroup.rank"> {
+ let summary = [{returns the rank of the workgroup dimensions}];
+ let description = [{
+ The number of workgroup dimensions used during dispatch, bounding the
+ `iree_input.dispatch.workgroup.*` query functions.
+
+ ```mlir
+ %rank = iree_input.dispatch.workgroup.rank : index
+ ```
+ }];
+ let arguments = (ins);
+ let results = (outs IREEInput_Dim:$result);
+ let assemblyFormat = "attr-dict `:` type($result)";
+}
+
+def IREEInput_DispatchWorkgroupIDOp : IREEInput_PureOp<"dispatch.workgroup.id"> {
+ let summary = [{returns the index of the current workgroup in the grid}];
+ let description = [{
+ The global workgroup ID of the current workgroup in the range of
+ `[0, iree_input.dispatch.workgroup.count)` along each dimension.
+
+ Corresponds to the `WorkgroupId` SPIR-V built-in and the `blockIdx` CUDA
+ built-in variable, only in the iree dialect the number of dimensions is not
+ restricted to 3 (XYZ).
+
+ ```mlir
+ %x = iree_input.dispatch.workgroup.id[0] : index
+ %y = iree_input.dispatch.workgroup.id[1] : index
+ ```
+ }];
+ let arguments = (ins IndexAttr:$dimension);
+ let results = (outs IREEInput_Dim:$result);
+ let builders = [
+ OpBuilder<(ins "unsigned":$dim),
+ [{
+ build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
+ }]>,
+ ];
+ let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)";
+}
+
+def IREEInput_DispatchWorkgroupCountOp : IREEInput_PureOp<"dispatch.workgroup.count"> {
+ let summary = [{returns the total workgroup count of the grid}];
+ let description = [{
+ The total number of workgroups along each dimension in the dispatch grid.
+
+ Corresponds to the `NumWorkgroups` SPIR-V built-in and the `gridDim` CUDA
+ built-in variable, only in the iree dialect the number of dimensions is not
+ restricted to 3 (XYZ).
+
+ ```mlir
+ %x = iree_input.dispatch.workgroup.count[0] : index
+ %y = iree_input.dispatch.workgroup.count[1] : index
+ ```
+ }];
+ let arguments = (ins IndexAttr:$dimension);
+ let results = (outs IREEInput_Dim:$result);
+ let builders = [
+ OpBuilder<(ins "unsigned":$dim),
+ [{
+ build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
+ }]>,
+ ];
+ let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)";
+}
+
+def IREEInput_DispatchWorkgroupSizeOp : IREEInput_PureOp<"dispatch.workgroup.size"> {
+ let summary = [{returns the size of each workgroup in invocations}];
+ let description = [{
+ The number of local invocations within the current workgroup along each
+ dimension. Depending on backend this may map to the SIMT thread count or
+ inner loop nest parameters.
+
+ Workgroup sizes are not determined at the iree dialect level as they are
+ dependent on the target backend determined when lowering into the HAL. It's
+ still possible to use the symbolic workgroup size inside of dispatch
+ executables as a placeholder for the resolved value once in the HAL.
+
+ Corresponds to the `WorkgroupSize` SPIR-V built-in and the `blockDim` CUDA
+ built-in variable, only in the iree dialect the number of dimensions is not
+ restricted to 3 (XYZ).
+
+ ```mlir
+ %x = iree_input.dispatch.workgroup.size[0] : index
+ %y = iree_input.dispatch.workgroup.size[1] : index
+ ```
+ }];
+ let arguments = (ins IndexAttr:$dimension);
+ let results = (outs IREEInput_Dim:$result);
+ let builders = [
+ OpBuilder<(ins "unsigned":$dim),
+ [{
+ build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim));
+ }]>,
+ ];
+ let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)";
+}
+
+#endif // IREE_DIALECTS_DIALECT_INPUT_OPS_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt
similarity index 100%
copy from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt
copy to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/CMakeLists.txt
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/CMakeLists.txt
new file mode 100644
index 0000000..b2527bc
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -0,0 +1,33 @@
+function(_add_interfaces) # Op and type interface decls/defs generated from LinalgExtInterfaces.td.
+ set(LLVM_TARGET_DEFINITIONS LinalgExtInterfaces.td)
+ mlir_tablegen(LinalgExtOpInterfaces.h.inc -gen-op-interface-decls)
+ mlir_tablegen(LinalgExtOpInterfaces.cpp.inc -gen-op-interface-defs)
+ mlir_tablegen(LinalgExtTypeInterfaces.h.inc -gen-type-interface-decls)
+ mlir_tablegen(LinalgExtTypeInterfaces.cpp.inc -gen-type-interface-defs)
+ add_public_tablegen_target(IREELinalgExtInterfacesIncGen)
+ add_dependencies(IREELinalgExtIncGen IREELinalgExtInterfacesIncGen) # IREELinalgExtIncGen is created in _add_dialect(), which is called before this function.
+endfunction()
+
+function(_add_tiled_op_interface) # Op interface decls/defs generated from TiledOpInterface.td.
+ set(LLVM_TARGET_DEFINITIONS TiledOpInterface.td)
+ mlir_tablegen(TiledOpInterface.h.inc -gen-op-interface-decls)
+ mlir_tablegen(TiledOpInterface.cpp.inc -gen-op-interface-defs)
+ add_public_tablegen_target(IREELinalgExtTiledOpInterfaceIncGen)
+ add_dependencies(IREELinalgExtIncGen IREELinalgExtTiledOpInterfaceIncGen) # IREELinalgExtIncGen is created in _add_dialect(), which is called before this function.
+endfunction()
+
+function(_add_dialect) # Op, typedef and dialect decls/defs generated from LinalgExtOps.td.
+ set(LLVM_TARGET_DEFINITIONS LinalgExtOps.td)
+ mlir_tablegen(LinalgExtOps.h.inc -gen-op-decls)
+ mlir_tablegen(LinalgExtOps.cpp.inc -gen-op-defs)
+ mlir_tablegen(LinalgExtTypes.h.inc -gen-typedef-decls)
+ mlir_tablegen(LinalgExtTypes.cpp.inc -gen-typedef-defs)
+ mlir_tablegen(LinalgExtDialect.h.inc -gen-dialect-decls -dialect=iree_linalg_ext)
+ mlir_tablegen(LinalgExtDialect.cpp.inc -gen-dialect-defs -dialect=iree_linalg_ext)
+ add_public_tablegen_target(IREELinalgExtIncGen) # The two interface functions in this file attach their targets to this one.
+ add_dependencies(mlir-headers IREELinalgExtIncGen) # Makes the mlir-headers target depend on this generation step.
+endfunction()
+
+_add_dialect()
+_add_interfaces()
+_add_tiled_op_interface()
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
new file mode 100644
index 0000000..fb15af2
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -0,0 +1,37 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_BASE
+#define IREE_DIALECT_LINALGEXT_BASE
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Dialect definition
+//===----------------------------------------------------------------------===//
+
+def IREELinalgExt_Dialect : Dialect {
+  let name = "iree_linalg_ext";
+  let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt";
+  let description = [{
+    The `iree_linalg_ext` dialect is intended for experimenting with support for
+    non-structured operations, i.e. ones that Linalg operations cannot represent.
+  }];
+  let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Type definitions
+//===----------------------------------------------------------------------===//
+
+class RankedTensorOrMemRefOf<list<Type> allowedTypes> : // Type constraint described as "ranked tensor or memref".
+ ShapedContainerType<allowedTypes,
+ Or<[IsMemRefTypePred, And<[IsTensorTypePred, HasRankPred]>]>, // any memref, or a tensor that also has a rank
+ "ranked tensor or memref", "::mlir::ShapedType">;
+
+def AnyRankedTensorOrMemRefType : RankedTensorOrMemRefOf<[AnyType]>;
+
+#endif // IREE_DIALECT_LINALGEXT_BASE
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h
new file mode 100644
index 0000000..8b15ec9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h
@@ -0,0 +1,17 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+// clang-format off: must be included after all LLVM/MLIR headers
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h.inc" // IWYU pragma: keep
+// clang-format on
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTDIALECT_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
new file mode 100644
index 0000000..7bec2f6
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
@@ -0,0 +1,41 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
+
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+class LinalgExtOp; // Op interface class; LinalgExtInterfaces.td declares it as `OpInterface<"LinalgExtOp">`.
+
+/// Vector of `OpOperand *` that implicitly converts to a `SmallVector<Value>`.
+struct OpOperandVector : public SmallVector<OpOperand *> {
+ operator SmallVector<Value>();
+};
+
+namespace detail {
+LogicalResult verifyLinalgExtOpInterface(Operation *op); // Called from the `verify` hook of the interface in LinalgExtInterfaces.td.
+} // namespace detail
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
+
+/// Include the generated op interface declarations (deliberately inside the
+/// namespaces opened above).
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.h.inc" // IWYU pragma: export
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
new file mode 100644
index 0000000..638d4ed
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -0,0 +1,490 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_INTERFACES
+#define IREE_DIALECT_LINALGEXT_INTERFACES
+
+include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td"
+
+// The interface is a subset of LinalgStructuredInterface.
+def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
+ let methods = [
+ //===------------------------------------------------------------------===//
+ // Num input/output arguments handling.
+ //===------------------------------------------------------------------===//
+ // `inputs` must be defined by each op that wants to implement the
+ // LinalgExtInterface.
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input shape operands.
+ }],
+ /*retTy=*/"ValueRange",
+ /*methodName=*/"inputs",
+ /*args=*/(ins)
+ >,
+ // These special methods rely on `inputs` and `outputs` being defined by
+ // each op that wants to implement the LinalgExtInterface.
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of inputs.
+ }],
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getNumInputs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.inputs().size();
+ }]
+ >,
+ // `outputs` must be defined by each op that wants to implement the
+ // LinalgExtInterface.
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the output shape operands.
+ }],
+ /*retTy=*/"ValueRange",
+ /*methodName=*/"outputs",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of outputs.
+ }],
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getNumOutputs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.outputs().size();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of inputs and outputs.
+ }],
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getNumInputsAndOutputs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getNumInputs() + getNumOutputs();
+ }]
+ >,
+ //===------------------------------------------------------------------===//
+ // Input operands handling.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input operands.
+ }],
+ /*retTy=*/"OpOperandVector",
+ /*methodName=*/"getInputOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ int64_t numInputs = getNumInputs();
+ OpOperandVector result;
+ result.reserve(numInputs);
+ llvm::transform(
+ this->getOperation()->getOpOperands().take_front(numInputs),
+ std::back_inserter(result),
+ [](OpOperand &opOperand) { return &opOperand; });
+ return result;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the `i`-th input operand.
+ }],
+ /*retTy=*/"OpOperand*",
+ /*methodName=*/"getInputOperand",
+ /*args=*/(ins "int64_t":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(i >= 0 && i < getNumInputs());
+ return &this->getOperation()->getOpOperand(i);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the subset of input operands that are of buffer type.
+ }],
+ /*retTy=*/"OpOperandVector",
+ /*methodName=*/"getInputBufferOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ OpOperandVector result;
+ result.reserve(getNumInputs());
+ llvm::copy_if(getInputOperands(),
+ std::back_inserter(result),
+ [](OpOperand *opOperand) {
+ return opOperand->get().getType().template isa<MemRefType>();
+ });
+ return result;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the subset of input operands that are of tensor type.
+ }],
+ /*retTy=*/"OpOperandVector",
+ /*methodName=*/"getInputTensorOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ OpOperandVector result;
+ result.reserve(getNumInputs());
+ llvm::copy_if(getInputOperands(),
+ std::back_inserter(result),
+ [](OpOperand *opOperand) {
+ return opOperand->get().getType().template isa<RankedTensorType>();
+ });
+ return result;
+ }]
+ >,
+ //===------------------------------------------------------------------===//
+ // Output operands handling.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the output operands.
+ }],
+ /*retTy=*/"OpOperandVector",
+ /*methodName=*/"getOutputOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ int64_t numOutputs = getNumOutputs();
+ OpOperandVector result;
+ result.reserve(numOutputs);
+ llvm::transform(
+ this->getOperation()->getOpOperands()
+ .drop_front(getNumInputs())
+ .take_front(numOutputs),
+ std::back_inserter(result),
+ [](OpOperand &opOperand) { return &opOperand; });
+ return result;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the `i`-th output operand.
+ }],
+ /*retTy=*/"OpOperand*",
+ /*methodName=*/"getOutputOperand",
+ /*args=*/(ins "int64_t":$i),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(i >= 0 && i < getNumOutputs());
+ return &this->getOperation()->getOpOperand(getNumInputs() + i);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the subset of output operands that are of buffer type.
+ }],
+ /*retTy=*/"OpOperandVector",
+ /*methodName=*/"getOutputBufferOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ OpOperandVector result;
+ result.reserve(getNumOutputs());
+ llvm::copy_if(getOutputOperands(),
+ std::back_inserter(result),
+ [](OpOperand *opOperand) {
+ return opOperand->get().getType().template isa<MemRefType>();
+ });
+ return result;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the subset of output operands that are of tensor type.
+ }],
+ /*retTy=*/"OpOperandVector",
+ /*methodName=*/"getOutputTensorOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ OpOperandVector result;
+ result.reserve(getNumOutputs());
+ llvm::copy_if(getOutputOperands(),
+ std::back_inserter(result),
+ [](OpOperand *opOperand) {
+ return opOperand->get().getType().template isa<RankedTensorType>();
+ });
+ return result;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the types of the subset of output operands that are of buffer type.
+ }],
+ /*retTy=*/"SmallVector<MemRefType>",
+ /*methodName=*/"getOutputBufferTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<MemRefType> result;
+ result.reserve(getNumOutputs());
+ llvm::transform(getOutputBufferOperands(),
+ std::back_inserter(result),
+ [](OpOperand *opOperands) {
+ return opOperands->get().getType().cast<MemRefType>();
+ });
+ return result;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the types of the subset of output operands that are of tensor type.
+ }],
+ /*retTy=*/"SmallVector<RankedTensorType>",
+ /*methodName=*/"getOutputTensorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<RankedTensorType> result;
+ result.reserve(getNumOutputs());
+ llvm::transform(getOutputTensorOperands(),
+ std::back_inserter(result),
+ [](OpOperand *opOperands) {
+ return opOperands->get().getType().cast<RankedTensorType>();
+ });
+ return result;
+ }]
+ >,
+ //===------------------------------------------------------------------===//
+ // Input and Output arguments handling.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the range over input and output operands.
+ }],
+ /*retTy=*/"OpOperandVector",
+ /*methodName=*/"getInputAndOutputOperands",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ int64_t numInputsAndOutputs = getNumInputsAndOutputs();
+ OpOperandVector result;
+ result.reserve(numInputsAndOutputs);
+ llvm::transform(
+ this->getOperation()->getOpOperands()
+ .take_front(numInputsAndOutputs),
+ std::back_inserter(result),
+ [](OpOperand &opOperand) { return &opOperand; });
+ return result;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if the payload uses the value loaded from `opOperand`. This
+ is useful to avoid loading from "write-only" memory that may be
+ uninitialized, as well as properly cloning "read-write" operands.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"payloadUsesValueFromOperand",
+ /*args=*/(ins "OpOperand *":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ unsigned bbArgNumber = opOperand->getOperandNumber();
+ // Safeguard against the named linalg ops that are manually defined and
+ // that only support buffer semantics: we should not be there.
+ // Such ops have an empty regionBuilder and are not constructed with a
+ // region for now. In the future they are slated to disappear.
+ assert(this->getOperation()->getNumRegions() == 1 && "unexpected "
+ "missing region (calling `payloadUsesValueFromOperand` on "
+ "manually defined named Linalg op?)");
+ Block &block = this->getOperation()->getRegion(0).front();
+ // Init tensors have uses.
+ return !block.getArgument(bbArgNumber).use_empty();
+ }]
+ >,
+ InterfaceMethod<
+   /*desc=*/[{
+     Return true if `opOperand` is an input tensor.
+   }],
+   /*retTy=*/"bool",
+   /*methodName=*/"isInputTensor",
+   /*args=*/(ins "OpOperand *":$opOperand),
+   /*methodBody=*/"",
+   /*defaultImplementation=*/[{
+     // Ranked-tensor operands located before position `getNumInputs()`.
+     const bool isRankedTensor =
+         opOperand->get().getType().template isa<RankedTensorType>();
+     return isRankedTensor &&
+            opOperand->getOperandNumber() < $_op.getNumInputs();
+   }]
+ >,
+ InterfaceMethod<
+   /*desc=*/[{
+     Return true if `opOperand` is an output tensor.
+   }],
+   /*retTy=*/"bool",
+   /*methodName=*/"isOutputTensor",
+   /*args=*/(ins "OpOperand *":$opOperand),
+   /*methodBody=*/"",
+   /*defaultImplementation=*/[{
+     // Ranked-tensor operands located at or after position `getNumInputs()`.
+     const bool isRankedTensor =
+         opOperand->get().getType().template isa<RankedTensorType>();
+     return isRankedTensor &&
+            opOperand->getOperandNumber() >= $_op.getNumInputs();
+   }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if `opOperand` is an init tensor. This is true when it is
+ an output tensor operand whose value is used in the payload region.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isInitTensor",
+ /*args=*/(ins "OpOperand *":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if (!$_op.isOutputTensor(opOperand))
+ return false;
+ return payloadUsesValueFromOperand(opOperand);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the `opOperand` rank or zero for scalars.
+ }],
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getRank",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == this->getOperation());
+ if (auto shapedType =
+ opOperand->get().getType().template dyn_cast<ShapedType>())
+ return shapedType.getRank();
+ return 0;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the `opOperand` shape or an empty vector for scalars.
+ }],
+ /*retTy=*/"ArrayRef<int64_t>",
+ /*methodName=*/"getShape",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == this->getOperation());
+ if (auto shapedType =
+ opOperand->get().getType().template dyn_cast<ShapedType>())
+ return shapedType.getShape();
+ return {};
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if the `opOperand` is a scalar value.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isScalar",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == this->getOperation());
+ return !opOperand->get().getType().template isa<ShapedType>();
+ }]
+ >,
+ //===------------------------------------------------------------------===//
+ // Other interface methods.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return whether the op has only MemRef input and outputs.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasBufferSemantics",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return this->getOperation()->getNumResults() == 0 &&
+ llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
+ return isScalar(opOperand) ||
+ opOperand->get().getType().template isa<MemRefType>();
+ }) &&
+ llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
+ return opOperand->get().getType().template isa<MemRefType>();
+ });
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return whether the op has only RankedTensor input and outputs.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasTensorSemantics",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return
+ llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
+ return isScalar(opOperand) ||
+ opOperand->get().getType().template isa<RankedTensorType>();
+ }) &&
+ llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
+ return opOperand->get().getType().template isa<RankedTensorType>();
+ });
+ }]
+ >,
+ //===------------------------------------------------------------------===//
+ // Cloning interface methods.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Clone the current operation with the given location and operands. This
+ is used to abstract away the optional underlying region creation. This
+ does not change the balance between input, output_buffer and
+ init_tensors operands.
+ }],
+ /*retTy=*/"Operation *",
+ /*methodName=*/"clone",
+ (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+ "ValueRange":$operands),
+ [{
+ BlockAndValueMapping bvm;
+ OperationState state(
+ loc, ConcreteOp::getOperationName(), operands, resultTypes,
+ $_op->getAttrs());
+ for (Region &r : $_op->getRegions())
+ r.cloneInto(state.addRegion(), bvm);
+ return b.createOperation(state);
+ }]
+ >
+ ];
+
+ let extraClassDeclaration = [{
+ //========================================================================//
+ // Helper functions to mutate the `operand_segment_sizes` attribute.
+ // These are useful when cloning and changing operand types.
+ //========================================================================//
+ void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
+ void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
+
+ private:
+ void setOperandSegmentAt(unsigned idx, unsigned val) {
+ auto attr = (*this)->getAttr("operand_segment_sizes")
+ .cast<DenseIntElementsAttr>();
+ unsigned i = 0;
+ auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
+ [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
+ getOperation()->setAttr("operand_segment_sizes", newAttr);
+ }
+ }];
+
+ let verify = [{ return detail::verifyLinalgExtOpInterface($_op); }];
+}
+
+#endif // IREE_DIALECT_LINALGEXT_INTERFACES
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h
new file mode 100644
index 0000000..8e9fed2
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -0,0 +1,40 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+/// Returns the `Value` of a `memref.dim` or `tensor.dim` operation that queries
+/// the size of `v` along dimension `dim`.
+Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);
+
+/// Returns the size of `v` along dimension `dim` as an `OpFoldResult`: a
+/// `memref.dim`/`tensor.dim` result, or an `IntegerAttr` if the size is constant.
+OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim);
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
new file mode 100644
index 0000000..fa75ea3
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -0,0 +1,322 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_OPS
+#define IREE_DIALECT_LINALGEXT_OPS
+
+include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td"
+include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
+include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Base class.
+//===----------------------------------------------------------------------===//
+
+class IREELinalgExt_PureOp<string mnemonic, list<OpTrait> traits = []> :
+ Op<IREELinalgExt_Dialect, mnemonic, traits> {
+}
+
+class IREELinalgExt_Op<string mnemonic, list<OpTrait> traits = []> :
+ IREELinalgExt_PureOp<mnemonic, !listconcat(traits,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ LinalgExtInterface,
+ SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">
+ ])> {
+ let verifier = [{ return verify$cppClass(*this); }];
+ let printer = [{ return print$cppClass(p, *this); }];
+ let parser = [{ return parse$cppClass(parser, result); }];
+ code extraLinalgExtOpClassDeclaration = [{
+ SmallVector<Value> getDestinationOperands(OpBuilder &b) {
+ SmallVector<Value> dest(outputs().begin(), outputs().end());
+ return dest;
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Non-structured ops
+//===----------------------------------------------------------------------===//
+
+def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
+ [DeclareOpInterfaceMethods<TiledOpInterface,
+ ["getTiledImplementation", "generateScalarImplementation"]>]> {
+ let summary = "Scatter operator";
+ let description = [{
+ Based on XLA operation semantics, takes two `inputs` (`updates` and
+ `indices`) and `outputs` value (`original`). The operation updates
+ the value at the slices specified by `indices` by combining the
+ current value with the value in `updates` using the computation
+ specified in `region`. The `region` specifies a binary operation
+ of signature (T, T) -> T, where `T` is the element-type of
+ `updates` (and `original`). The first argument corresponds to the
+ value to be updated (i.e. from `updates`), and the second the
+ current value (i.e. value from `original`).
+
+ The `indices` is a 2D tensor/memref type. The first dim is the number of
+ updates, and the second dim is index depth. The index depth should always be
+ static.
+
+ The first dim of `updates` and `indices` is identical, since they represent
+ the number of updates.
+
+ The rank of the `original`/`result` is `index_depth + rank(%updates) - 1`.
+ The first `index_depth` indices are derived from `indices` and the shape of
+ update value must match the rest shape of `original`.
+
+ The shapes definition follows tensorflow operations except that it forces
+ batch dims to be 1D. See more information in
+ https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
+ }];
+ let arguments = (ins
+ Variadic<AnyRankedTensorOrMemRefType>:$inputs,
+ Variadic<AnyRankedTensorOrMemRefType>:$outputs
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$results);
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = [{
+ attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+ `outs` `(` $outputs `:` type($outputs) `)`
+ $region (`->` type($results)^)?
+ }];
+ let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+
+ int64_t getIndexDepth() {
+ return getInputOperand(1)
+ ->get()
+ .getType()
+ .cast<ShapedType>()
+ .getShape()
+ .back();
+ }
+
+ Value updates() {
+ return getInputOperand(0)->get();
+ }
+
+ ShapedType getUpdateType() {
+ return updates().getType().cast<ShapedType>();
+ }
+
+ Value indices() {
+ return getInputOperand(1)->get();
+ }
+
+ ShapedType getIndicesType() {
+ return indices().getType().cast<ShapedType>();
+ }
+
+ Value original() {
+ return getOutputOperand(0)->get();
+ }
+
+ ShapedType getOriginalType() {
+ return original().getType().cast<ShapedType>();
+ }
+
+ int64_t getUpdateSliceRank() {
+ return updates().getType().cast<ShapedType>().getRank() - 1;
+ }
+
+ bool isScalarUpdate() {
+ return getUpdateSliceRank() == 0;
+ }
+ }];
+}
+
+def IREELinalgExt_SortOp : IREELinalgExt_Op<"sort",
+ [DeclareOpInterfaceMethods<TiledOpInterface,
+ ["getPartitionableLoops", "generateScalarImplementation",
+ "getTiledImplementation"]>]> {
+ let summary = "Sort operator";
+ let description = [{
+ Based on XLA operation semantics, sorts the given `operands` at the given
+ `dimension` with the given `comparator`.
+
+ See https://www.tensorflow.org/xla/operation_semantics#sort.
+ }];
+
+ // Define arguments and results like linalg.generic op. The attribute has the
+ // same definition as mhlo.sort::dimension. If the rank is greater than 1,
+ // the attribute must be set. If the rank is exactly 1, the dimension is
+ // optional.
+ let arguments = (ins Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ OptionalAttr<I64Attr>:$dimension
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$results);
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = [{
+ (`dimension` `(` $dimension^ `)`)?
+ attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+ `outs` `(` $outputs `:` type($outputs) `)`
+ $region (`->` type($results)^)?
+ }];
+ let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+ Value operand(int index) {
+ return outputs()[index];
+ }
+ ShapedType getOperandType(int index) {
+ return operand(index).getType().cast<ShapedType>();
+ }
+ int64_t getOperandRank() {
+ return getOperandType(0).getRank();
+ }
+ ArrayRef<int64_t> getOperandShape() {
+ return getOperandType(0).getShape();
+ }
+ uint64_t getSortedDimension() {
+ uint64_t sortedDim = 0;
+ if (auto setSortedDim = dimension()) {
+ sortedDim = *setSortedDim;
+ }
+ return sortedDim;
+ }
+ }];
+}
+
+def IREELinalgExt_FftOp : IREELinalgExt_Op<"fft", [
+ DeclareOpInterfaceMethods<TiledOpInterface,
+ [
+ "getPartitionableLoops", "getTiledImplementation",
+ "generateScalarImplementation"
+ ]>,
+ DeclareOpInterfaceMethods<LinalgExtInterface,
+ // FftOp does not have a region, so we have to
+ // overwrite the method.
+ ["payloadUsesValueFromOperand"]>
+]> {
+ let summary = "Fft operator";
+ let description = [{
+ Apply 1D FFT to innermost dim. This is an iterative FFT, not recursive.
+ Thus, the bit reversal is assumed applied on the input. The op carries an
+ input -- stage, which indicates the level of reduction loop in the
+ algorithm. It represents the computation body. For more details, see
+ "Data reordering, bit reversal, and in-place algorithms" section in
+ https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm
+
+ The size of innermost dim is expected to be a power of 2.
+
+ It is optional to carry coefficient tensors/buffers as inputs. In this
+ context, they will be the second and third inputs.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$results);
+ let assemblyFormat = [{
+ attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+ `outs` `(` $outputs `:` type($outputs) `)`
+ (`:` type($results)^)?
+ }];
+ let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+ Value getStage() { return inputs()[0]; }
+ Value getReal() { return outputs()[0]; }
+ Value getImag() { return outputs()[1]; }
+ bool hasCoeff() { return getNumInputs() > 1; }
+ void generateScalarImplWithoutCoeffBuf(
+ OpBuilder & b, Location loc, ArrayRef<Value> operands, Value wholeSize);
+ void generateScalarImplWithCoeffBuf(OpBuilder & b, Location loc,
+ ArrayRef<Value> operands);
+ Value getRealCoeff() {
+ if (!hasCoeff()) return Value();
+ return inputs()[1];
+ }
+ Value getImagCoeff() {
+ if (!hasCoeff()) return Value();
+ return inputs()[2];
+ }
+ ShapedType getOperandType() {
+ return getReal().getType().cast<ShapedType>();
+ }
+ int64_t getOperandRank() {
+ return getOperandType().getRank();
+ }
+ ArrayRef<int64_t> getOperandShape() {
+ return getOperandType().getShape();
+ }
+ int64_t getFftLength() {
+ return getOperandShape().back();
+ }
+ }];
+}
+
+def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [
+  DeclareOpInterfaceMethods<
+      TiledOpInterface,
+      ["generateScalarImplementation", "getTiledImplementation"]>,
+  DeclareOpInterfaceMethods<LinalgExtInterface,
+                            // ReverseOp does not have a region, so we have to
+                            // overwrite the method.
+                            ["payloadUsesValueFromOperand"]>]> {
+  let summary = "Reverse operator";
+  let description = [{
+    A temporary solution for lowering reverse ops into IREE, allowing IREE to
+    tile and distribute them.
+  }];
+
+  let arguments = (ins Variadic<AnyShaped>:$inputs,
+                       Variadic<AnyShaped>:$outputs,
+                       I64ElementsAttr:$dimensions
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let assemblyFormat = [{
+    `dimensions` `(` $dimensions `)`
+    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+    (`outs` `(` $outputs^ `:` type($outputs) `)`)?
+    (`:` type($results)^)?
+  }];
+  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+    Value input() {
+      return getInputOperand(0)->get();
+    }
+    Value output() {
+      return getOutputOperand(0)->get();
+    }
+    ShapedType getOperandType() {
+      return input().getType().cast<ShapedType>();
+    }
+    int64_t getOperandRank() {
+      return getOperandType().getRank();
+    }
+    ArrayRef<int64_t> getOperandShape() { return getOperandType().getShape(); }
+    // Misspelled legacy accessor, kept so existing callers keep compiling.
+    ArrayRef<int64_t> getOprerandShape() { return getOperandShape(); }
+    // Returns the values of the `dimensions` attribute as a vector.
+    SmallVector<int64_t> dims() {
+      SmallVector<int64_t> ret;
+      for (const APInt& elem : dimensions()) {
+        ret.push_back(elem.getLimitedValue());
+      }
+      return ret;
+    }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Pure ops
+//===----------------------------------------------------------------------===//
+
+def IREELinalgExt_YieldOp : IREELinalgExt_PureOp<"yield", [NoSideEffect, ReturnLike, Terminator]> {
+ let summary = "LinalgExt yield op";
+ let description = [{
+ `iree_linalg_ext.yield` is a special terminator operation for blocks inside
+ regions in `iree_linalg_ext` ops.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [
+ OpBuilder<(ins), [{ /* nothing to do */ }]>,
+ ];
+
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+
+#endif // IREE_DIALECT_LINALGEXT_OPS
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h
new file mode 100644
index 0000000..17d8faf
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h
@@ -0,0 +1,33 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
+
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LLVM.h"
+
+/// Include the ODS generated interface header files.
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h.inc"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+/// Registers external models implemented for the `TiledOpInterface`.
+void registerTiledOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.td
new file mode 100644
index 0000000..3c9ce51
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.td
@@ -0,0 +1,127 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
+#define IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def TiledOpInterface : OpInterface<"TiledOpInterface"> {
+ let description = [{
+ Interface for allowing operations to expose information needed to
+ tile it (similar to LinalgOp, but without having access to
+ indexing maps)
+ }];
+ let cppNamespace = "::mlir::iree_compiler::IREE::LinalgExt";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the destination operands. For op with `memref`
+ operands, this is the result buffers. For op with `tensor`
+ operands, this is the operands that contain the initial
+ value for the result. These are "tied" to the result
+ buffers. For example, for a `LinalgOp`/`LinalgExt` ops, it
+ is the `outs` parameters. For `tensor.insert_slice`, it is
+ the `dest` parameter.
+ }],
+ /*retType=*/"SmallVector<Value>",
+ /*methodName=*/"getDestinationOperands",
+ /*args=*/(ins "OpBuilder &":$b),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return ValueRange{};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of `StringRef`s that describe the number of
+ loops and the iterator types of the operation. The list is
+ expected to use
+ `getParallelIteratorTypeName()`/`getReductionIteratorTypeName()`
+ from MLIR Structured Op Utils.
+ }],
+ /*retType=*/"SmallVector<StringRef>",
+ /*methodName=*/"getLoopIteratorTypes"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of ranges that describe the loop bounds and
+ step for the loops of the operation.
+ }],
+ /*retTy=*/"SmallVector<Range>",
+ /*methodName=*/"getLoopBounds",
+ /*args=*/(ins "OpBuilder &":$b)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the loops that are to be distributed given the
+ maximum amount of logical processor dimensions available.
+ }],
+ /*retTy=*/"SmallVector<unsigned>",
+ /*methodName=*/"getPartitionableLoops",
+ /*args=*/(ins "unsigned ":$maxNumParallelDims),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<unsigned> partitionableLoops;
+ auto interfaceOp = cast<TiledOpInterface>($_op.getOperation());
+ for (auto iteratorType :
+ llvm::enumerate(interfaceOp.getLoopIteratorTypes())) {
+ if (iteratorType.value() != getParallelIteratorTypeName()) {
+ break;
+ }
+ partitionableLoops.push_back(iteratorType.index());
+ }
+ if (partitionableLoops.size() > maxNumParallelDims) {
+ partitionableLoops.erase(
+ partitionableLoops.begin(),
+ std::next(partitionableLoops.begin(),
+ partitionableLoops.size() - maxNumParallelDims));
+ }
+ return partitionableLoops;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Generates a tiled version of the operation given the tile
+ size for the loops.
+
+ Returns the tiled operation generated. If the operation has
+ tensor semantics then the result of the tiled values are to
+ be inserted into the `outputs` and return in `results`.
+ }],
+ /*retType=*/"Operation *",
+ /*methodName=*/"getTiledImplementation",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "ValueRange ":$outputs,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVectorImpl<Value> &":$results),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return nullptr;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Generates the loop body implementation. Assume that all the parallel
+ loops and reduction loops are created and the insertion point of the
+ build is set to the innermost of the loop. This method implements the
+ loop body IRs.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"generateScalarImplementation",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "Location ":$loc,
+ "ValueRange ":$ivs),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >
+ ];
+}
+
+#endif // IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CMakeLists.txt
similarity index 75%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/CMakeLists.txt
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 6b2f75c..29737fc 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -2,4 +2,4 @@
mlir_tablegen(Passes.h.inc -gen-pass-decls)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl)
-add_public_tablegen_target(MLIRIREEPyDMTransformsPassesIncGen)
+add_public_tablegen_target(IREELinalgExtTransformsPassesIncGen)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h
new file mode 100644
index 0000000..4457f87
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h
@@ -0,0 +1,19 @@
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+#define GEN_PASS_CLASSES
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h.inc" // IWYU pragma: keep
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASS_DETAIL_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.h
new file mode 100644
index 0000000..c294f92
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.h
@@ -0,0 +1,28 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+std::unique_ptr<OperationPass<FuncOp>> createTiledOpInterfaceTilingPass();
+
+std::unique_ptr<OperationPass<FuncOp>> createLinalgExtToLoopsPass();
+
+void registerPasses();
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.td
new file mode 100644
index 0000000..1c132b7
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Passes.td
@@ -0,0 +1,24 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_PASSES
+#define IREE_DIALECT_LINALGEXT_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LinalgExtToLoops :
+ Pass<"iree-linalg-ext-to-loops", "FuncOp"> {
+ let summary = "Convert LinalgExt ops to loops and Linalg ops.";
+ let constructor = "mlir::iree_compiler::IREE::LinalgExt::createLinalgExtToLoopsPass()";
+}
+
+def TiledOpInterfaceTiling :
+ Pass<"iree-linalg-ext-tile", "FuncOp"> {
+ let summary = "Test pass for tiling using TiledOpInterface";
+ let constructor = "mlir::iree_compiler::IREE::LinalgExt::createTiledOpInterfaceTilingPass()";
+}
+
+#endif // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
new file mode 100644
index 0000000..6fa1f51
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -0,0 +1,93 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+#define IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+/// Structure to represent the result of a tiling operation.
+struct TiledOp {
+  /// Tiled op. Defaults to null so that an unset result is detectable.
+  Operation *op = nullptr;
+  /// Loops generated during tiling.
+  SmallVector<Operation *> loops;
+  /// Values that are replacements for the untiled operations.
+  SmallVector<Value> results;
+};
+
+/// Main entry point for tiling LinalgExtOps using TiledOpInterface.
+FailureOr<TiledOp> tileLinalgExtOp(OpBuilder &b, TiledOpInterface tilableOp,
+ const linalg::LinalgTilingOptions &options);
+
+/// Base rewrite pattern to tile and distribute operations that implement the
+/// `TiledOpInterface`. It only provides `matchAndRewriteBase`; derived
+/// patterns call it and then decide how to replace the original op.
+struct TiledOpInterfaceBaseTilingPattern
+ : public OpInterfaceRewritePattern<TiledOpInterface> {
+ TiledOpInterfaceBaseTilingPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern(context, benefit),
+ filter(filter),
+ options(options) {}
+ /// Shared matching/tiling entry point; the outcome is reported via `result`.
+ LogicalResult matchAndRewriteBase(TiledOpInterface tilableOp,
+ PatternRewriter &rewriter,
+ TiledOp &result) const;
+
+ private:
+ /// LinalgTransformationFilter handles special attribute manipulations.
+ linalg::LinalgTransformationFilter filter;
+ /// Options to control tiling.
+ linalg::LinalgTilingOptions options;
+};
+
+/// Tiling pattern that reuses the base matching/tiling logic and then erases
+/// or replaces the original op.
+struct TiledOpInterfaceTilingPattern
+    : public TiledOpInterfaceBaseTilingPattern {
+  TiledOpInterfaceTilingPattern(MLIRContext *context,
+                                linalg::LinalgTilingOptions options,
+                                linalg::LinalgTransformationFilter filter =
+                                    linalg::LinalgTransformationFilter(),
+                                PatternBenefit benefit = 1)
+      : TiledOpInterfaceBaseTilingPattern(context, options, filter, benefit) {}
+
+  LogicalResult matchAndRewrite(TiledOpInterface tilableOp,
+                                PatternRewriter &rewriter) const override {
+    TiledOp tiled;
+    // Run the shared tiling logic provided by the base pattern.
+    LogicalResult status = matchAndRewriteBase(tilableOp, rewriter, tiled);
+    // Tiling failed, or succeeded without producing a tiled op (do-nothing).
+    if (failed(status) || !tiled.op) return failure();
+    // The op was returned unchanged, so there is nothing to swap out.
+    if (tiled.op == tilableOp) return success();
+    // Without replacement values the untiled op is simply dropped.
+    if (tiled.results.empty()) {
+      rewriter.eraseOp(tilableOp);
+      return success();
+    }
+    rewriter.replaceOp(tilableOp, tiled.results);
+    return success();
+  }
+};
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/CMakeLists.txt
similarity index 100%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/CMakeLists.txt
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/CMakeLists.txt
new file mode 100644
index 0000000..060b427
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/CMakeLists.txt
@@ -0,0 +1,24 @@
+function(_add_interfaces)
+ set(LLVM_TARGET_DEFINITIONS PyDMInterfaces.td)
+ mlir_tablegen(PyDMOpInterfaces.h.inc -gen-op-interface-decls)
+ mlir_tablegen(PyDMOpInterfaces.cpp.inc -gen-op-interface-defs)
+ mlir_tablegen(PyDMTypeInterfaces.h.inc -gen-type-interface-decls)
+ mlir_tablegen(PyDMTypeInterfaces.cpp.inc -gen-type-interface-defs)
+ add_public_tablegen_target(IREEPyDMInterfacesIncGen)
+endfunction()
+
+function(_add_dialect)
+ set(LLVM_TARGET_DEFINITIONS PyDMOps.td)
+ mlir_tablegen(PyDMOps.h.inc -gen-op-decls)
+ mlir_tablegen(PyDMOps.cpp.inc -gen-op-defs)
+ mlir_tablegen(PyDMTypes.h.inc -gen-typedef-decls)
+ mlir_tablegen(PyDMTypes.cpp.inc -gen-typedef-defs)
+ mlir_tablegen(PyDMDialect.h.inc -gen-dialect-decls -dialect=iree_pydm)
+ mlir_tablegen(PyDMDialect.cpp.inc -gen-dialect-defs -dialect=iree_pydm)
+ add_public_tablegen_target(IREEPyDMIncGen)
+ add_dependencies(IREEPyDMIncGen IREEPyDMInterfacesIncGen)
+ add_dependencies(mlir-headers IREEPyDMIncGen)
+endfunction()
+
+_add_dialect()
+_add_interfaces()
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Constants.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/Constants.h
similarity index 95%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Constants.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/Constants.h
index 2bcf191..a8e7655 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Constants.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/Constants.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_CONSTANTS_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_CONSTANTS_H
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_CONSTANTS_H
+#define IREE_DIALECTS_DIALECT_PYDM_IR_CONSTANTS_H
namespace mlir {
namespace iree_compiler {
@@ -149,4 +149,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_CONSTANTS_H
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_CONSTANTS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Base.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMBase.td
similarity index 88%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Base.td
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMBase.td
index 04241b7..96ebfbf 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Base.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMBase.td
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_BASE_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_BASE_TD
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_BASE_TD
+#define IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_BASE_TD
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -44,4 +44,4 @@
}
class IREEPyDM_TypeDef<string name, list<Trait> traits = []> : TypeDef<IREEPyDM_Dialect, name, traits>;
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_BASE_TD
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_BASE_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h
similarity index 71%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h
index 56082aa..0c0a40e 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.h
@@ -4,11 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_H
+#define IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_H
-#include "iree-dialects/Dialect/IREEPyDM/IR/Constants.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h"
+#include "iree-dialects/Dialect/PyDM/IR/Constants.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
@@ -32,10 +32,10 @@
// Include generated dialect code (this comment blocks clang-format from
// clobbering order).
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h.inc"
#define GET_TYPEDEF_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/Types.h.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.h.inc"
namespace mlir {
namespace iree_compiler {
@@ -53,4 +53,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.td
similarity index 96%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.td
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.td
index 0125fb9..2a30fdb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Dialect.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMDialect.td
@@ -4,11 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_TD
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_TD
+#define IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_TD
-include "iree-dialects/Dialect/IREEPyDM/IR/Base.td"
-include "iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td"
+include "iree-dialects/Dialect/PyDM/IR/PyDMBase.td"
+include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.td"
//===----------------------------------------------------------------------===//
// Variable Types
@@ -152,7 +152,6 @@
let printer = [{
auto w = getImpl()->bitWidth;
- $_printer << "integer";
if (w) {
$_printer << "<";
if (*w == 0) {
@@ -219,7 +218,6 @@
];
let printer = [{
- $_printer << getMnemonic();
if (getImpl()->uniformElementType ||
getImpl()->storageClass != CollectionStorageClass::Boxed) {
$_printer << "<";
@@ -335,7 +333,6 @@
let printer = [{
auto ft = getImpl()->floatType;
- $_printer << "real";
if (ft)
$_printer << "<" << ft << ">";
}];
@@ -429,7 +426,6 @@
];
let printer = [{
- $_printer << getMnemonic();
if (getImpl()->primitiveType)
$_printer << "<" << getImpl()->primitiveType << ">";
}];
@@ -484,7 +480,6 @@
let genVerifyDecl = 1;
let printer = [{
- $_printer << getMnemonic();
llvm::interleaveComma(getAlternatives(), $_printer);
}];
@@ -538,4 +533,4 @@
let constBuilderCall = ?;
}
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_TD
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_DIALECT_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h
new file mode 100644
index 0000000..63ef309
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h
@@ -0,0 +1,30 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_H
+#define IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_H
+
+#include "iree-dialects/Dialect/PyDM/IR/Constants.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace PYDM {
+
+enum class BuiltinTypeCode;
+
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.h.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.h.inc"
+
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.td
similarity index 91%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.td
index bd553ca..8383a47 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Interfaces.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.td
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_TD
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_TD
+#define IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_TD
include "mlir/IR/OpBase.td"
@@ -82,4 +82,4 @@
];
}
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_INTERFACES_TD
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_INTERFACES_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h
similarity index 60%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h
index a70ae13..18fe67c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.h
@@ -4,11 +4,11 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_IREEPYDM_IR_OPS_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_IREEPYDM_IR_OPS_H
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_OPS_H
+#define IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_OPS_H
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -20,6 +20,6 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h.inc"
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_IREEPYDM_IR_OPS_H
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_OPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.td
similarity index 98%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.td
index b09d017..5d5e821 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/Ops.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/IR/PyDMOps.td
@@ -4,10 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_OPS_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_OPS_TD
+#ifndef IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_OPS_TD
+#define IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_OPS_TD
-include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.td"
+include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
@@ -795,4 +795,4 @@
}];
}
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_OPS_TD
+#endif // IREE_DIALECTS_DIALECT_PYDM_IR_PYDM_OPS_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/CMakeLists.txt
similarity index 75%
copy from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/CMakeLists.txt
copy to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/CMakeLists.txt
index 6b2f75c..2e5c4cb 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/CMakeLists.txt
@@ -2,4 +2,4 @@
mlir_tablegen(Passes.h.inc -gen-pass-decls)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl)
-add_public_tablegen_target(MLIRIREEPyDMTransformsPassesIncGen)
+add_public_tablegen_target(IREEPyDMTransformsPassesIncGen)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.h
similarity index 86%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.h
index 477b394..d29cef4 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
#include <memory>
@@ -57,4 +57,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.td
similarity index 91%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.td
index 2376b9e..24d127e 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/Passes.td
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_CONVERSION_TO_IREE_PASSES_TD
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_CONVERSION_TO_IREE_PASSES_TD
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_CONVERSION_TO_IREE_PASSES_TD
+#define IREE_DIALECTS_DIALECT_IREEPYDM_CONVERSION_TO_IREE_PASSES_TD
include "mlir/Pass/PassBase.td"
@@ -71,4 +71,4 @@
let constructor = "mlir::iree_compiler::IREE::PYDM::createConvertIREEPyDMToIREEPass()";
}
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_CONVERSION_TO_IREE_PASSES_TD
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_CONVERSION_TO_IREE_PASSES_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h
similarity index 75%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h
index d99c830..c71ee27 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -39,4 +39,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_RTL_LINKAGE_ANALYSIS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/Patterns.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h
similarity index 72%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/Patterns.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h
index 3cced96..5e2ca56 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/Patterns.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_PATTERNS_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_PATTERNS_H
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_PATTERNS_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_PATTERNS_H
namespace mlir {
@@ -28,4 +28,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_LOWERING_PATTERNS_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_PATTERNS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h
similarity index 77%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h
index e58a1d5..3d3cf46 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -43,4 +43,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_TOIREE_TYPECONVERTER_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Utils/TypeInference.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Utils/TypeInference.h
similarity index 92%
rename from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Utils/TypeInference.h
rename to llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Utils/TypeInference.h
index 214175e..cdc628b 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Utils/TypeInference.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/PyDM/Utils/TypeInference.h
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
@@ -95,4 +95,4 @@
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_UTILS_TYPE_INFERENCE_H
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
index 75be3ec..fde1221 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -3,9 +3,9 @@
Utils.cpp
LINK_LIBS PUBLIC
MLIRIR
- IREEDialectsIREEDialect
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEPyDMPasses
+ IREEInputDialect
+ IREEPyDMDialect
+ IREEPyDMPasses
)
iree_dialects_target_includes(IREEDialectsCAPI)
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index 003c765..ac169f1 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -6,9 +6,9 @@
#include "iree-dialects-c/Dialects.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/CAPI/Registration.h"
@@ -24,7 +24,8 @@
// IREEDialect
//===----------------------------------------------------------------------===//
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IREE, iree, mlir::iree::IREEDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(
+ IREEInput, iree_input, mlir::iree_compiler::IREE::Input::IREEInputDialect)
//===----------------------------------------------------------------------===//
// IREEPyDMDialect
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
index 61df04e..620c526 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
@@ -1,2 +1,3 @@
-add_subdirectory(IREE)
-add_subdirectory(IREEPyDM)
+add_subdirectory(Input)
+add_subdirectory(LinalgExt)
+add_subdirectory(PyDM)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/CMakeLists.txt
deleted file mode 100644
index 0630336..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_library(IREEDialectsIREEDialect
- IREEDialect.cpp
- IREEOps.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${IREE_DIALECTS_SOURCE_DIR}/include
-
- DEPENDS
- MLIRIREEOpsIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSideEffectInterfaces
-)
-
-iree_dialects_target_includes(IREEDialectsIREEDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp
deleted file mode 100644
index 6b2ab7f..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/Support/LLVM.h"
-
-using namespace mlir;
-using namespace mlir::iree;
-
-#include "iree-dialects/Dialect/IREE/IREEOpsDialect.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
-
-void IREEDialect::initialize() {
- addTypes<
-#define GET_TYPEDEF_LIST
-#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
- >();
- addOperations<
-#define GET_OP_LIST
-#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"
- >();
-}
-
-Type IREEDialect::parseType(DialectAsmParser &parser) const {
- StringRef typeTag;
- Type genType;
- if (succeeded(parser.parseKeyword(&typeTag)))
- generatedTypeParser(parser, typeTag, genType);
- return genType;
-}
-
-void IREEDialect::printType(Type type, DialectAsmPrinter &printer) const {
- (void)generatedTypePrinter(type, printer);
-}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
deleted file mode 100644
index b6293fe..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMDialect
- Dialect.cpp
- Ops.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${IREE_DIALECTS_SOURCE_DIR}/include
-
- DEPENDS
- MLIRIREEPyDMOpsIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSideEffectInterfaces
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
deleted file mode 100644
index db6f84f..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-add_subdirectory(Optimize)
-add_subdirectory(RTL)
-add_subdirectory(ToIREE)
-
-add_mlir_library(IREEDialectsIREEPyDMPasses
- Passes.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMOptimizePasses
- IREEDialectsIREEPyDMRTLPasses
- IREEDialectsIREEPyDMToIREEPasses
- MLIRTransforms
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
deleted file mode 100644
index 093f273..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMOptimizePasses
- FixateWeakNumeric.cpp
- LocalPropagateTypes.cpp
- VariablesToSSA.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEPyDMUtils
- MLIRIR
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMOptimizePasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/PassDetail.h b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/PassDetail.h
deleted file mode 100644
index a95ac11..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/PassDetail.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
-
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-
-namespace iree {
-class IREEDialect;
-}
-
-namespace iree_compiler {
-namespace IREE {
-namespace PYDM {
-
-class FuncOp;
-
-#define GEN_PASS_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h.inc"
-
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/CMakeLists.txt
deleted file mode 100644
index a9139a4..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/CMakeLists.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMRTLPasses
- LinkageAnalysis.cpp
- LinkRTLPass.cpp
- LowerToRTLPass.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEDialect
- MLIRIR
- MLIRParser
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMRTLPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/CMakeLists.txt
deleted file mode 100644
index 8c332af..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMToIREEPasses
- ConversionPass.cpp
- LoweringPatterns.cpp
- TypeConverter.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEDialect
- MLIRArithmetic
- MLIRIR
- MLIRStandard
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMToIREEPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/CMakeLists.txt
deleted file mode 100644
index efca191..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/CMakeLists.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMUtils
- TypeInference.cpp
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- MLIRIR
- MLIRStandard
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMUtils)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/Input/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/Input/CMakeLists.txt
new file mode 100644
index 0000000..b755303
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(IREEInputDialect
+ InputDialect.cpp
+ InputOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${IREE_DIALECTS_SOURCE_DIR}/include
+
+ DEPENDS
+ IREEInputIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+)
+
+iree_dialects_target_includes(IREEInputDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp
new file mode 100644
index 0000000..ef6a95a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp
@@ -0,0 +1,43 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::Input;
+
+#include "iree-dialects/Dialect/Input/InputDialect.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "iree-dialects/Dialect/Input/InputTypes.cpp.inc"
+
+void IREEInputDialect::initialize() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "iree-dialects/Dialect/Input/InputTypes.cpp.inc"
+ >();
+ addOperations<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/Input/InputOps.cpp.inc"
+ >();
+}
+
+Type IREEInputDialect::parseType(DialectAsmParser &parser) const {
+ StringRef typeTag;
+ Type genType;
+ if (succeeded(parser.parseKeyword(&typeTag)))
+ generatedTypeParser(parser, typeTag, genType);
+ return genType;
+}
+
+void IREEInputDialect::printType(Type type, DialectAsmPrinter &printer) const {
+ (void)generatedTypePrinter(type, printer);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp
similarity index 92%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp
index a723b58..894a74f 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp
@@ -4,16 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
-using namespace mlir::iree;
+using namespace mlir::iree_compiler::IREE::Input;
//===----------------------------------------------------------------------===//
// custom<SymbolVisibility>($sym_visibility)
@@ -91,4 +91,4 @@
}
#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"
+#include "iree-dialects/Dialect/Input/InputOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
similarity index 100%
copy from llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt
copy to llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt
new file mode 100644
index 0000000..6022a57
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -0,0 +1,29 @@
+add_mlir_library(IREELinalgExtDialect
+ LinalgExtDialect.cpp
+ LinalgExtInterfaces.cpp
+ LinalgExtOps.cpp
+ TiledOpInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${IREE_DIALECTS_SOURCE_DIR}/include
+
+ DEPENDS
+ IREELinalgExtIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAffine
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLinalg
+ MLIRMath
+ MLIRMemRef
+ MLIRPass
+ MLIRSideEffectInterfaces
+ MLIRSupport
+ MLIRSCF
+ MLIRStandard
+ MLIRTensor
+ MLIRViewLikeInterface
+)
+
+iree_dialects_target_includes(IREELinalgExtDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
new file mode 100644
index 0000000..4657c12
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -0,0 +1,29 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+void IREELinalgExtDialect::initialize() {
+ // TODO(hanchung): Add interface to the dialect.
+ // addInterfaces<IREEInlinerInterface>();
+#define GET_OP_LIST
+ addOperations<
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc"
+ >();
+}
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
new file mode 100644
index 0000000..5fdaf8a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
@@ -0,0 +1,51 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+/// Converts this vector of operand handles into the values they hold,
+/// preserving the order of the operands.
+OpOperandVector::operator SmallVector<Value>() {
+  SmallVector<Value> values;
+  values.reserve(this->size());
+  for (OpOperand *operand : *this) {
+    values.push_back(operand->get());
+  }
+  return values;
+}
+
+/// Verifies the invariants shared by all ops implementing LinalgExtOp.
+///
+/// An op without results must have buffer semantics. An op with results must
+/// have tensor semantics, exactly one result per output operand, and each
+/// result type must equal the type of the corresponding output operand.
+LogicalResult IREE::LinalgExt::detail::verifyLinalgExtOpInterface(
+    Operation *op) {
+  LinalgExtOp linalgExtOp = cast<LinalgExtOp>(op);
+
+  // No results: the op is expected to operate on buffers.
+  if (!op->getNumResults()) {
+    if (linalgExtOp.hasBufferSemantics()) return success();
+    return linalgExtOp.emitOpError(
+        "expected inputs and outputs to be MemRefType or scalar");
+  }
+
+  // Results present: the op is expected to operate on tensors.
+  if (!linalgExtOp.hasTensorSemantics()) {
+    return linalgExtOp.emitOpError(
+        "expected inputs and outputs to be RankedTensorType or scalar");
+  }
+  if (op->getNumResults() != linalgExtOp.outputs().size()) {
+    return linalgExtOp.emitOpError(
+        "expected number of outputs to be same as the number of results");
+  }
+
+  // Results and outputs are matched up positionally.
+  size_t position = 0;
+  for (Type resultType : op->getResultTypes()) {
+    Type outputType = linalgExtOp.outputs()[position].getType();
+    if (resultType != outputType) {
+      return linalgExtOp.emitOpError("expected type of `outs` operand #")
+             << position << " " << outputType << " to be same as result type "
+             << resultType;
+    }
+    ++position;
+  }
+  return success();
+}
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc" // IWYU pragma: export
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
new file mode 100644
index 0000000..ebcb0cd
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -0,0 +1,985 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/SMLoc.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+namespace IREE = mlir::iree_compiler::IREE;
+
+//===----------------------------------------------------------------------===//
+// Utils.
+//===----------------------------------------------------------------------===//
+
+/// Appends memory effects to `effects`, all on the default resource:
+///  - an Allocate for every value in `results`,
+///  - a Read for every value in `inputBuffers`,
+///  - a Read followed by a Write for every value in `outputBuffers`.
+static void getEffectsImpl(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects,
+    ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
+  auto record = [&effects](MemoryEffects::Effect *effect, Value buffer) {
+    effects.emplace_back(effect, buffer, SideEffects::DefaultResource::get());
+  };
+
+  for (Value result : results) {
+    record(MemoryEffects::Allocate::get(), result);
+  }
+  for (Value input : inputBuffers) {
+    record(MemoryEffects::Read::get(), input);
+  }
+  for (Value output : outputBuffers) {
+    record(MemoryEffects::Read::get(), output);
+    record(MemoryEffects::Write::get(), output);
+  }
+}
+
+/// Creates a slice of `source` described by `offsets`, `sizes` and `strides`:
+/// a tensor.extract_slice when `source` is a ranked tensor, a memref.subview
+/// when it is a memref. Returns a null Value for any other type.
+static Value getSlice(OpBuilder &b, Location loc, Value source,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      ArrayRef<OpFoldResult> strides) {
+  Type sourceType = source.getType();
+  if (sourceType.isa<RankedTensorType>()) {
+    Value slice = b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                                   strides);
+    return slice;
+  }
+  if (sourceType.isa<MemRefType>()) {
+    Value view =
+        b.create<memref::SubViewOp>(loc, source, offsets, sizes, strides);
+    return view;
+  }
+  return Value();
+}
+
+/// Creates an op that queries dimension `dim` of `v`: tensor.dim for a ranked
+/// tensor, memref.dim for a memref. Returns a null Value for any other type.
+Value IREE::LinalgExt::getDimValue(OpBuilder &builder, Location loc, Value v,
+                                   int64_t dim) {
+  Type valueType = v.getType();
+  if (valueType.isa<RankedTensorType>()) {
+    Value tensorDim = builder.create<tensor::DimOp>(loc, v, dim);
+    return tensorDim;
+  }
+  if (valueType.isa<MemRefType>()) {
+    Value memrefDim = builder.create<memref::DimOp>(loc, v, dim);
+    return memrefDim;
+  }
+  return Value();
+}
+
+/// Returns the size of dimension `dim` of the shaped value `v`: an i64
+/// integer attribute when the dimension is static, otherwise the result of
+/// `getDimValue` for that dimension.
+OpFoldResult IREE::LinalgExt::getDim(OpBuilder &builder, Location loc, Value v,
+                                     int64_t dim) {
+  auto shapedType = v.getType().cast<ShapedType>();
+  if (!shapedType.isDynamicDim(dim)) {
+    return builder.getI64IntegerAttr(shapedType.getDimSize(dim));
+  }
+  return getDimValue(builder, loc, v, dim);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+/// Verifies a ScatterOp.
+///
+/// Checks, in order: the operand counts (two inputs, one output); that the
+/// indices are rank-2 i32 with a static index depth; that the update value has
+/// rank >= 1 and agrees with the indices on dim 0; that the ranks and trailing
+/// dimensions of update and original are consistent; and that the region takes
+/// two scalar arguments matching the update/original element types and yields
+/// a single value of that same type.
+static LogicalResult verifyScatterOp(ScatterOp op) {
+  if (op.inputs().size() != 2) {
+    return op.emitOpError("expected two input operands");
+  }
+  if (op.outputs().size() != 1) {
+    return op.emitOpError("expected one output operand");
+  }
+  auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
+    return t1.getShape()[dim] == t2.getShape()[dim];
+  };
+
+  auto indicesType = op.getIndicesType();
+  if (indicesType.getRank() != 2 ||
+      !indicesType.getElementType().isInteger(32)) {
+    return op.emitOpError(
+        "expected indices to be of rank 2 of i32 element type");
+  }
+  auto indexDepth = op.getIndexDepth();
+  if (indexDepth == ShapedType::kDynamicSize) {
+    return op.emitOpError("expected index depth is static");
+  }
+
+  // The first dimension of the indices should match the first dimension of
+  // the update value. Both give the number of updates.
+  auto updateType = op.getUpdateType();
+  if (updateType.getRank() < 1) {
+    return op.emitOpError("expected update value to be at least rank 1");
+  }
+  if (!checkDimensionsMatch(indicesType, updateType, 0)) {
+    return op.emitOpError(
+        "mismatch in shape of indices and update value at dim#0");
+  }
+  auto originalType = op.getOriginalType();
+  // indexDepth + update dims should match to original dims. The first dim of
+  // update is the number of updates.
+  if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
+    return op.emitOpError(
+        "mismatch in rank of update value, index depth and original value");
+  }
+  for (auto dim : llvm::seq<unsigned>(indexDepth, originalType.getRank())) {
+    // Offset one because the first dim is the number of updates.
+    if (updateType.getDimSize(1 + dim - indexDepth) !=
+        originalType.getDimSize(dim)) {
+      return op.emitOpError("mismatch in shape of update value dim#")
+             << (1 + dim - indexDepth) << " and original value at dim#" << dim;
+    }
+  }
+
+  // The region combines an update element (argument 0) with the current
+  // element of the original value (argument 1).
+  Region &region = op.region();
+  Block *body = &region.front();
+  if (body->getNumArguments() != 2) {
+    return op.emitOpError("expected region to have two arguments");
+  }
+  Type arg0Type = body->getArgument(0).getType();
+  Type arg1Type = body->getArgument(1).getType();
+  if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
+    return op.emitOpError(
+        "expected region to have scalar argument of integer or float types");
+  }
+  if (arg0Type != updateType.getElementType()) {
+    return op.emitOpError("mismatch in argument 0 of region ")
+           << arg0Type << " and element type of update value "
+           << updateType.getElementType();
+  }
+  if (arg1Type != originalType.getElementType()) {
+    return op.emitOpError("mismatch in argument 1 of region ")
+           << arg1Type << " and element type of original value "
+           << originalType.getElementType();
+  }
+  if (arg0Type != arg1Type) {
+    return op.emitOpError("mismatch in region argument types ")
+           << arg0Type << " and " << arg1Type;
+  }
+  auto yieldOp = cast<IREE::LinalgExt::YieldOp>(body->getTerminator());
+  if (yieldOp->getNumOperands() != 1) {
+    return yieldOp.emitOpError("expected region to yield a single value");
+  }
+  auto yieldedType = yieldOp->getOperand(0).getType();
+  if (yieldedType != arg0Type) {
+    return yieldOp.emitOpError("mismatch in type of yielded value ")
+           << yieldedType << " and argument of the region " << arg0Type;
+  }
+  return success();
+}
+
+/// Returns one parallel iterator type per dimension of the update value.
+SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
+  return SmallVector<StringRef>(getUpdateType().getRank(),
+                                getParallelIteratorTypeName());
+}
+
+/// Returns the iteration domain of the op: one range [0, dim size) with unit
+/// stride per dimension of the update value. The bounds are materialized as
+/// ops created through `builder`.
+SmallVector<Range> ScatterOp::getLoopBounds(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+  SmallVector<Range> loopRanges;
+  for (int64_t dim = 0, rank = getUpdateType().getRank(); dim < rank; ++dim) {
+    Value upperBound = getDimValue(builder, loc, updates(), dim);
+    loopRanges.push_back(Range{lowerBound, upperBound, step});
+  }
+  return loopRanges;
+}
+
+/// Creates the tiled version of this op for the tile described by `offsets`
+/// and `sizes` (indexed by the dimensions of the update value).
+///
+/// The updates are sliced with the tile offsets/sizes directly. The indices
+/// are sliced only along dim 0. The slice of `outputs[0]` keeps the leading
+/// `originalRank - updateRank + 1` dimensions whole and tiles the trailing
+/// ones. The op is then cloned onto the three slices; every result of the
+/// clone is inserted back into `outputs[0]` with tensor.insert_slice and the
+/// inserted value is appended to `results`. Returns the cloned op.
+Operation *ScatterOp::getTiledImplementation(OpBuilder &builder,
+ ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) {
+ assert(outputs.size() >= 1 && offsets.size() >= 1 && sizes.size() >= 1);
+ Location loc = getLoc();
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ auto oneAttr = builder.getI64IntegerAttr(1);
+
+ // Slice of the updates.
+ auto updateRank = getUpdateType().getRank();
+ SmallVector<OpFoldResult> updateStrides(updateRank, oneAttr);
+ Value tiledUpdate =
+ getSlice(builder, loc, updates(), offsets, sizes, updateStrides);
+ assert(tiledUpdate && "failed to get slice of update");
+
+ // Slice of indices.
+ // Only dim 0 is tiled; all other dims start at 0 and span the full extent.
+ auto indicesRank = getIndicesType().getRank();
+ SmallVector<OpFoldResult> indicesOffsets(indicesRank, zeroAttr);
+ SmallVector<OpFoldResult> indicesSizes(indicesRank);
+ indicesOffsets[0] = offsets[0];
+ indicesSizes[0] = sizes[0];
+ for (auto dim : llvm::seq<int64_t>(1, indicesRank)) {
+ indicesSizes[dim] = getDim(builder, loc, indices(), dim);
+ }
+ SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
+ Value tiledIndices = getSlice(builder, loc, indices(), indicesOffsets,
+ indicesSizes, indicesStrides);
+ assert(tiledIndices && "failed to get slice of indices");
+
+ // Slice of the original.
+ // Leading dims are taken whole (offset 0, full size).
+ auto originalRank = getOriginalType().getRank();
+ SmallVector<OpFoldResult> originalOffsets(originalRank, zeroAttr);
+ SmallVector<OpFoldResult> originalSizes(originalRank);
+ for (auto dim : llvm::seq<int64_t>(0, originalRank - updateRank + 1)) {
+ originalSizes[dim] = getDim(builder, loc, original(), dim);
+ }
+ // Trailing dims reuse the tile offset/size of the matching update dim,
+ // i.e. original dim `d` maps to update dim `d - (originalRank - updateRank)`.
+ for (auto dim :
+ llvm::seq<int64_t>(originalRank - updateRank + 1, originalRank)) {
+ originalOffsets[dim] = offsets[dim - (originalRank - updateRank)];
+ originalSizes[dim] = sizes[dim - (originalRank - updateRank)];
+ }
+ SmallVector<OpFoldResult> originalStrides(originalRank, oneAttr);
+ Value tiledOriginal = getSlice(builder, loc, outputs[0], originalOffsets,
+ originalSizes, originalStrides);
+ assert(tiledOriginal && "failed to get slice of original tensor");
+
+ // The clone only has a result type when this op itself produces a result.
+ SmallVector<Type> resultTypes;
+ if (getNumResults()) {
+ resultTypes.push_back(tiledOriginal.getType());
+ }
+ Operation *tiledScatterOp =
+ cast<LinalgExtOp>(getOperation())
+ .clone(builder, loc, resultTypes,
+ ValueRange{tiledUpdate, tiledIndices, tiledOriginal});
+ // Write each tiled result back into the full destination.
+ for (auto result : llvm::enumerate(tiledScatterOp->getResults())) {
+ auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+ loc, result.value(), outputs[0], originalOffsets, originalSizes,
+ originalStrides);
+ results.push_back(insertSliceOp.getResult());
+ }
+ return tiledScatterOp;
+}
+
+/// Emits the scalar body of the scatter for one point `ivs` of the iteration
+/// space, using memref loads/stores.
+///
+/// Loads the update element at `ivs`, reads `indexDepth` index values from row
+/// `ivs[0]` of the indices, and forms the destination coordinates from those
+/// indices followed by `ivs[1..]`. The region is then cloned with the update
+/// and the current destination element as arguments, and the yielded value is
+/// stored to the destination. Always returns success.
+LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
+ Location loc,
+ ValueRange ivs) {
+ auto indexDepth = getIndexDepth();
+ Value update = b.create<memref::LoadOp>(loc, updates(), ivs);
+ SmallVector<Value> starts;
+ // `loadIndices` is (update number, column); the column slot is a placeholder
+ // that is overwritten on every iteration below.
+ SmallVector<Value> loadIndices;
+ loadIndices.push_back(ivs.front());
+ loadIndices.push_back(Value());
+ for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
+ loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
+ Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
+ starts.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx));
+ }
+ // Remaining coordinates come from the loop induction variables after the
+ // first one.
+ starts.append(std::next(ivs.begin()), ivs.end());
+ Value init = b.create<memref::LoadOp>(loc, original(), starts);
+
+ // Inline the region body with (update, current destination value) bound to
+ // its two block arguments.
+ BlockAndValueMapping bvm;
+ Block &block = region().front();
+ bvm.map(block.getArgument(0), update);
+ bvm.map(block.getArgument(1), init);
+ for (auto &blockOp : block.without_terminator()) {
+ b.clone(blockOp, bvm);
+ }
+ // The last op is linalg_ext.yield op. Store the operand to
+ // destination.
+ b.create<memref::StoreOp>(
+ loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
+ original(), starts);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SortOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a SortOp.
+///
+/// The op must have no inputs and at least one output; the region must take
+/// two arguments per output whose types equal that output's element type; all
+/// outputs must share rank and shape; the sort dimension must be given when
+/// the rank exceeds one and must lie within the rank; and the region must
+/// yield a single i1 value.
+static LogicalResult verifySortOp(SortOp op) {
+  if (op.getNumInputs()) {
+    return op.emitOpError("does not expect to take any inputs");
+  }
+  size_t outputCount = op.getNumOutputs();
+  if (outputCount == 0) {
+    return op.emitOpError("expected at least one `outs` operand");
+  }
+
+  // The comparator receives a pair of scalars for every sorted operand.
+  Block &comparator = op.region().front();
+  if (comparator.getNumArguments() != 2 * outputCount) {
+    return op.emitOpError("region block should have ")
+           << 2 * outputCount << " arguments";
+  }
+
+  int64_t rank = op.getOperandRank();
+  ArrayRef<int64_t> shape = op.getOperandShape();
+  const bool hasDimension = static_cast<bool>(op.dimensionAttr());
+  if (!hasDimension && rank > 1) {
+    return op.emitOpError("dimension must be specified if rank > 1");
+  }
+  int dimension = 0;
+  if (hasDimension) dimension = op.dimension().getValue();
+  // Note: the accepted range is [0, rank); the diagnostic text below prints it
+  // as "(0, rank]".
+  if (dimension < 0 || dimension >= rank) {
+    return op.emitOpError("dimension must be within (0, ") << rank << "]";
+  }
+
+  for (int index = 0, e = op.outputs().size(); index != e; ++index) {
+    auto operandType = op.getOperandType(index);
+    if (operandType.getRank() != rank) {
+      return op.emitOpError("expected operand ")
+             << index << " to be rank " << rank << ", same as other operands";
+    }
+    if (operandType.getShape() != shape) {
+      return op.emitOpError("expected operand ")
+             << index << " to have same shape as other operands";
+    }
+    // Block arguments 2*index and 2*index+1 belong to this operand.
+    Type elemType = operandType.getElementType();
+    for (int side = 0; side < 2; ++side) {
+      int argIndex = 2 * index + side;
+      Type argType = comparator.getArgument(argIndex).getType();
+      if (argType != elemType) {
+        return op.emitOpError("region block argument #")
+               << argIndex << " should be of type " << elemType << " but got "
+               << argType;
+      }
+    }
+  }
+
+  auto yieldOp = cast<YieldOp>(comparator.getTerminator());
+  if (yieldOp.getNumOperands() != 1) {
+    return op.emitOpError("should yield exactly one operand");
+  }
+  auto yieldedType = yieldOp.getOperand(0).getType().dyn_cast<IntegerType>();
+  if (!yieldedType || yieldedType.getWidth() != 1) {
+    return op.emitOpError("should yield i1 type");
+  }
+
+  return success();
+}
+
+/// Returns one iterator type per operand dimension: a reduction for the
+/// sorted dimension, parallel for every other dimension.
+SmallVector<StringRef> SortOp::getLoopIteratorTypes() {
+  SmallVector<StringRef> loopKinds(getOperandRank(),
+                                   getParallelIteratorTypeName());
+  // Sorting walks along this dimension, so it cannot be parallel.
+  loopKinds[getSortedDimension()] = getReductionIteratorTypeName();
+  return loopKinds;
+}
+
+/// Returns the iteration domain of the op: one range [0, dim size) with unit
+/// stride per dimension of operand 0. The bounds are materialized as ops
+/// created through `builder`.
+SmallVector<Range> SortOp::getLoopBounds(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value source = operand(0);
+
+  SmallVector<Range> bounds;
+  for (int64_t dim = 0, rank = getOperandRank(); dim < rank; ++dim) {
+    Value extent = getDimValue(builder, loc, source, dim);
+    bounds.push_back(Range{lowerBound, extent, step});
+  }
+  return bounds;
+}
+
+/// Returns the loops that may be partitioned: every operand dimension except
+/// the sorted one, in increasing order. When there are more than
+/// `maxNumParallelDims` of them, only the last `maxNumParallelDims` are kept.
+SmallVector<unsigned> SortOp::getPartitionableLoops(
+    unsigned maxNumParallelDims) {
+  const unsigned sortedDim = getSortedDimension();
+  SmallVector<unsigned> loops;
+  for (unsigned dim = 0, rank = getOperandRank(); dim < rank; ++dim) {
+    if (dim != sortedDim) loops.push_back(dim);
+  }
+  // Drop the outermost loops when over the limit.
+  if (loops.size() > maxNumParallelDims) {
+    loops.erase(loops.begin(), loops.end() - maxNumParallelDims);
+  }
+  return loops;
+}
+
+/// Creates the tiled version of this op for the tile described by `offsets`
+/// and `sizes` (one entry per operand dimension, unit strides).
+///
+/// Every value in `outputs` is sliced with the same offsets/sizes, the op is
+/// cloned onto the slices, and each result of the clone is inserted back into
+/// the matching entry of `outputs` with tensor.insert_slice. The inserted
+/// values are appended to `results`. Returns the cloned op.
+Operation *SortOp::getTiledImplementation(OpBuilder &builder,
+ ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) {
+ assert(outputs.size() == this->outputs().size());
+ int64_t rank = getOperandRank();
+ assert(offsets.size() == static_cast<size_t>(rank) &&
+ sizes.size() == static_cast<size_t>(rank));
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(rank, oneAttr);
+ Location loc = getLoc();
+ SmallVector<Value> tiledOperands(outputs.size());
+ for (auto en : llvm::enumerate(outputs)) {
+ tiledOperands[en.index()] =
+ getSlice(builder, getLoc(), en.value(), offsets, sizes, strides);
+ assert(tiledOperands[en.index()] && "failed to get slice of operand");
+ }
+ // The clone only has result types when this op itself produces results; in
+ // that case each result has the type of the corresponding slice.
+ SmallVector<Type, 4> resultTypes;
+ if (getNumResults()) {
+ resultTypes = llvm::to_vector<4>(
+ llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); }));
+ }
+ Operation *tiledSortOp = cast<LinalgExtOp>(getOperation())
+ .clone(builder, loc, resultTypes, tiledOperands);
+ // Write each tiled result back into the full destination.
+ for (auto result : llvm::enumerate(tiledSortOp->getResults())) {
+ auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+ loc, result.value(), outputs[result.index()], offsets, sizes, strides);
+ results.push_back(insertSliceOp.getResult());
+ }
+ return tiledSortOp;
+}
+
+/// Emits the scalar body of the sort for one point `ivs` of the iteration
+/// space, using memref loads/stores.
+///
+/// Generates an scf.for over [0, size - 1) along the sorted dimension. Each
+/// iteration loads the adjacent elements at `iv` and `iv + 1` from every
+/// output, evaluates a clone of the comparator region on them, and swaps the
+/// two elements in every output when the comparator yields false. Always
+/// returns success.
+LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc,
+                                                   ValueRange ivs) {
+  auto sortDim = getSortedDimension();
+  // Loaded (lhs, rhs) pairs, in the order of the comparator block arguments.
+  SmallVector<Value> sortBlkArgs;
+  // Bubble sort innermost loop.
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+  Value ub;
+  if (getOperandType(0).isDynamicDim(sortDim)) {
+    ub = b.create<memref::DimOp>(loc, operand(0), sortDim);
+  } else {
+    ub = b.create<arith::ConstantIndexOp>(
+        loc, getOperandType(0).getDimSize(sortDim));
+  }
+  // The loop compares element `iv` with `iv + 1`, so it stops one short.
+  ub = b.create<arith::SubIOp>(loc, ub, one);
+  auto scfFor = b.create<scf::ForOp>(
+      loc, zero, ub, one, ValueRange{},
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) {
+        SmallVector<Value> indices(ivs);
+        Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
+        for (auto output : getOutputOperands()) {
+          indices[sortDim] = iv;
+          sortBlkArgs.push_back(
+              b.create<memref::LoadOp>(loc, output->get(), indices));
+          indices[sortDim] = ivPlusOne;
+          sortBlkArgs.push_back(
+              b.create<memref::LoadOp>(loc, output->get(), indices));
+        }
+      });
+
+  // Inline the comparator region at the end of the loop body, binding its
+  // block arguments to the values loaded above.
+  auto &srcBlock = region().front();
+  Region &region = scfFor.region();
+  BlockAndValueMapping bvm;
+  {
+    OpBuilder::InsertionGuard guard(b);
+    auto &block = region.front();
+    b.setInsertionPointToEnd(&block);
+    for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) {
+      bvm.map(std::get<0>(it), std::get<1>(it));
+    }
+    for (auto &blockOp : srcBlock.without_terminator()) {
+      b.clone(blockOp, bvm);
+    }
+  }
+  Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0));
+
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPointToEnd(&region.front());
+  b.create<scf::IfOp>(
+      loc, TypeRange{}, cond,
+      [&](OpBuilder &b, Location loc) {
+        // Do not swap the pairs if true.
+        b.create<scf::YieldOp>(loc);
+      },
+      [&](OpBuilder &b, Location loc) {
+        // Swap the pairs if false.
+        SmallVector<Value> indices(ivs.begin(), ivs.end());
+        Value ivPlusOne =
+            b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
+        for (int i = 0, e = getNumOutputs(); i < e; ++i) {
+          Value v1 = sortBlkArgs[i * 2];
+          Value v2 = sortBlkArgs[i * 2 + 1];
+          indices[sortDim] = scfFor.getInductionVar();
+          b.create<memref::StoreOp>(loc, v2, getOutputOperand(i)->get(),
+                                    indices);
+          indices[sortDim] = ivPlusOne;
+          b.create<memref::StoreOp>(loc, v1, getOutputOperand(i)->get(),
+                                    indices);
+        }
+        b.create<scf::YieldOp>(loc);
+      });
+  // Terminate the scf.for body.
+  b.create<scf::YieldOp>(loc);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FftOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies an FftOp.
+///
+/// A dynamic FFT length is accepted without further checks. Otherwise the
+/// length must be a power of two, the first input must be the scalar `stage`,
+/// any additional inputs must be exactly two non-scalar coefficient operands,
+/// and there must be exactly two outputs.
+static LogicalResult verifyFftOp(FftOp op) {
+  auto fftLength = op.getFftLength();
+  // After tiling, the length can be dynamic, because subview/subtensor does
+  // not infer the type correctly in the (1 << x) cases.
+  if (fftLength == ShapedType::kDynamicSize) return success();
+  if ((fftLength & (fftLength - 1)) != 0) {
+    return op.emitOpError("only powers of 2 are handled currently");
+  }
+  if (!op.getNumInputs() || !op.isScalar(op.getInputOperand(0))) {
+    return op.emitOpError("expected to carry `stage` input");
+  }
+  // Beyond `stage`, the only accepted inputs are the two coefficient buffers.
+  if (op.getNumInputs() != 1 &&
+      (op.getNumInputs() != 3 || op.isScalar(op.getInputOperand(1)) ||
+       op.isScalar(op.getInputOperand(2)))) {
+    return op.emitOpError("expected to carry real and imag coeff inputs");
+  }
+  if (op.getNumOutputs() != 2) {
+    return op.emitOpError("expected outputs to be real and imag tensor/memref");
+  }
+  return success();
+}
+
+/// Returns one parallel iterator type per operand dimension.
+SmallVector<StringRef> FftOp::getLoopIteratorTypes() {
+  // There are `rank-1` outer loops. The fft itself has one loop for each
+  // stage, which handles the merge step -- taking two half-size tensors and
+  // merging them into one tensor.
+  return SmallVector<StringRef>(getOperandRank(),
+                                getParallelIteratorTypeName());
+}
+
+/// Returns the iteration domain of the op. All dimensions but the last get a
+/// range [0, dim size) with unit stride. The last dimension gets the range
+/// [0, dim size) with stride `1 << stage`.
+SmallVector<Range> FftOp::getLoopBounds(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+  SmallVector<Range> bounds;
+  int64_t dim = 0;
+  for (int64_t extent : getOperandShape().drop_back()) {
+    // Static extents become constants; dynamic ones are queried from the
+    // real operand.
+    Value extentValue;
+    if (extent != ShapedType::kDynamicSize) {
+      extentValue = builder.create<arith::ConstantIndexOp>(loc, extent);
+    } else {
+      extentValue = getDimValue(builder, loc, getReal(), dim);
+    }
+    bounds.push_back(Range{/*offset=*/zero, extentValue, /*stride=*/one});
+    ++dim;
+  }
+
+  Value lastExtent = getDimValue(builder, loc, getReal(), getOperandRank() - 1);
+  Value lastStride = builder.create<arith::ShLIOp>(loc, one, getStage());
+  bounds.push_back(Range{/*offset=*/zero, lastExtent, /*stride=*/lastStride});
+  return bounds;
+}
+
+/// Emits the butterfly computation as a linalg.generic over `operands` when no
+/// coefficient buffers are provided; the twiddle factors are computed inline.
+///
+/// `operands` are used as the outputs of the generic op, with block arguments
+/// read as (lhsReal, lhsImag, rhsReal, rhsImag). `wholeSize` is the divisor
+/// `m` in the twiddle angle `-2 * PI / m * j`, where `j` is the index along
+/// the innermost dimension.
+void FftOp::generateScalarImplWithoutCoeffBuf(OpBuilder &b, Location loc,
+ ArrayRef<Value> operands,
+ Value wholeSize) {
+ auto rank = getOperandRank();
+ SmallVector<AffineMap> maps(operands.size(), b.getMultiDimIdentityMap(rank));
+
+ auto f32Type = b.getF32Type();
+ // Converts an index value to f32 by going through i32.
+ auto indexToF32 = [](OpBuilder &builder, Location loc, Value v) -> Value {
+ v = builder.create<arith::IndexCastOp>(loc, builder.getI32Type(), v);
+ return builder.create<arith::SIToFPOp>(loc, builder.getF32Type(), v);
+ };
+
+ // We will need exp(-2 * PI * j / m * I), compute "-2 * PI / m" for imag part
+ // first.
+ // acos(-1) evaluates to PI.
+ Value coeff = b.create<arith::ConstantFloatOp>(
+ loc, llvm::APFloat(static_cast<float>(-2 * acos(-1))), f32Type);
+ coeff = b.create<arith::DivFOp>(loc, coeff, indexToF32(b, loc, wholeSize));
+
+ b.create<linalg::GenericOp>(
+ loc, TypeRange{}, ValueRange{}, operands, maps, getLoopIteratorTypes(),
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value lhsReal = args[0];
+ Value lhsImag = args[1];
+ Value rhsReal = args[2];
+ Value rhsImag = args[3];
+
+ // Compute "-2 * PI / m * j"
+ Value w = b.create<arith::MulFOp>(
+ loc, coeff,
+ indexToF32(b, loc, b.create<linalg::IndexOp>(loc, rank - 1)));
+ Value wReal = b.create<math::CosOp>(loc, w);
+ Value wImag = b.create<math::SinOp>(loc, w);
+
+ // t = w * a[k + j + mh];
+ // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+ Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal);
+ Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag);
+ Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag);
+ Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal);
+ Value tReal = b.create<arith::SubFOp>(loc, xu, yv);
+ Value tImag = b.create<arith::AddFOp>(loc, xv, yu);
+
+ // cplx u = a[k + j];
+ // a[k + j] = u + t;
+ // a[k + j + mh] = u - t;
+ Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal);
+ Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag);
+ Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal);
+ Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag);
+ b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
+ });
+}
+
+/// Emits the butterfly computation as a linalg.generic over `operands` when
+/// coefficient buffers are provided.
+///
+/// The real and imaginary coefficient buffers are the inputs of the generic
+/// op, indexed by the innermost loop dimension only. `operands` are its
+/// outputs, with block arguments read as (lhsReal, lhsImag, rhsReal, rhsImag)
+/// after the two coefficient arguments.
+void FftOp::generateScalarImplWithCoeffBuf(OpBuilder &b, Location loc,
+ ArrayRef<Value> operands) {
+ auto rank = getOperandRank();
+ SmallVector<AffineMap> maps;
+ // The size of the coefficient buffer is expected to match `2^(stage-1)`,
+ // which equals the last dim of operands.
+ maps.append(
+ 2, AffineMap::get(rank, 0, b.getAffineDimExpr(rank - 1), b.getContext()));
+ maps.append(operands.size(), b.getMultiDimIdentityMap(rank));
+
+ b.create<linalg::GenericOp>(
+ loc, TypeRange{}, ValueRange{getRealCoeff(), getImagCoeff()}, operands,
+ maps, getLoopIteratorTypes(),
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value wReal = args[0];
+ Value wImag = args[1];
+ Value lhsReal = args[2];
+ Value lhsImag = args[3];
+ Value rhsReal = args[4];
+ Value rhsImag = args[5];
+
+ // t = w * a[k + j + mh];
+ // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+ Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal);
+ Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag);
+ Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag);
+ Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal);
+ Value tReal = b.create<arith::SubFOp>(loc, xu, yv);
+ Value tImag = b.create<arith::AddFOp>(loc, xv, yu);
+
+ // cplx u = a[k + j];
+ // a[k + j] = u + t;
+ // a[k + j + mh] = u - t;
+ Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal);
+ Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag);
+ Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal);
+ Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag);
+ b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
+ });
+}
+
+// Generates FFT stage scalar implementation. This follows Cooley–Tukey FFT
+// algorithm. The pseudo reference code is:
+//   let s <- stage of linalg_ext.fft
+//   int m = 1 << s;
+//   int mh = m >> 1;
+//   for (int k = 0; k < n; k += m) {
+//     for (int j = 0; j < mh; ++j) {
+//       cplx w = exp(-2 * PI * j / m * I);
+//       cplx t = w * a[k + j + mh];
+//       cplx u = a[k + j];
+//       a[k + j] = u + t;
+//       a[k + j + mh] = u - t;
+//     }
+//   }
+//
+// For the loop point `ivs`, this takes four subviews of extent `mh` along the
+// last dimension (size 1 elsewhere): the real and imag buffers at `ivs` (the
+// "lhs" halves) and at `ivs` shifted by `mh` in the last dimension (the "rhs"
+// halves). The butterfly itself is emitted by one of the two helpers,
+// depending on whether coefficient buffers are present. Always returns
+// success.
+LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc,
+ ValueRange ivs) {
+ Value real = getReal();
+ Value imag = getImag();
+ Value stage = getStage();
+ Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ // wholeSize is `m` and halfSize is `mh` in the pseudo code above.
+ Value wholeSize = b.create<arith::ShLIOp>(loc, one, stage);
+ Value halfSize = b.create<arith::ShRSIOp>(loc, wholeSize, one);
+
+ auto rank = getOperandRank();
+ SmallVector<Value> operands;
+ SmallVector<OpFoldResult> lhsIvs(ivs.begin(), ivs.end());
+ SmallVector<OpFoldResult> ones(rank, b.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes(rank, b.getIndexAttr(1));
+ sizes.back() = halfSize;
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, real, lhsIvs, sizes, ones));
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, imag, lhsIvs, sizes, ones));
+
+ // The rhs halves start `halfSize` elements further along the last dimension.
+ SmallVector<OpFoldResult> rhsIvs(ivs.begin(), ivs.end());
+ rhsIvs.back() =
+ b.create<arith::AddIOp>(loc, ivs.back(), halfSize).getResult();
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, real, rhsIvs, sizes, ones));
+ operands.push_back(
+ b.create<memref::SubViewOp>(loc, imag, rhsIvs, sizes, ones));
+
+ if (hasCoeff()) {
+ generateScalarImplWithCoeffBuf(b, loc, operands);
+ } else {
+ generateScalarImplWithoutCoeffBuf(b, loc, operands, wholeSize);
+ }
+
+ return success();
+}
+
+/// Returns false for every operand.
+bool FftOp::payloadUsesValueFromOperand(OpOperand * /*opOperand*/) {
+  return false;
+}
+
+/// Returns the loops that may be partitioned, in increasing order: all
+/// operand dimensions when coefficient buffers are present, all but the last
+/// one otherwise. When there are more than `maxNumParallelDims` of them, only
+/// the last `maxNumParallelDims` are kept.
+SmallVector<unsigned> FftOp::getPartitionableLoops(
+    unsigned maxNumParallelDims) {
+  unsigned loopCount = getOperandRank();
+  // Indices matter for coeff computation.
+  if (!hasCoeff()) {
+    --loopCount;
+  }
+  // Drop the outermost loops when over the limit.
+  unsigned firstLoop =
+      loopCount > maxNumParallelDims ? loopCount - maxNumParallelDims : 0;
+
+  SmallVector<unsigned> loops;
+  for (unsigned dim = firstLoop; dim < loopCount; ++dim) {
+    loops.push_back(dim);
+  }
+  return loops;
+}
+
+/// Creates the tiled version of this op for the tile described by `offsets`
+/// and `sizes` (unit strides).
+///
+/// The stage and the two coefficient operands are forwarded untouched; every
+/// value in `outputs` is sliced with the tile offsets/sizes. The op is cloned
+/// onto these operands and each result of the clone is inserted back into the
+/// matching entry of `outputs` with tensor.insert_slice. The inserted values
+/// are appended to `results`. Returns the cloned op.
+Operation *FftOp::getTiledImplementation(OpBuilder &builder, ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) {
+ int64_t rank = getOperandRank();
+ SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
+ Location loc = getLoc();
+ // NOTE(review): the coefficient operands are forwarded unconditionally;
+ // confirm what getRealCoeff()/getImagCoeff() return when hasCoeff() is false.
+ SmallVector<Value> tiledOperands(3);
+ tiledOperands[0] = getStage();
+ tiledOperands[1] = getRealCoeff();
+ tiledOperands[2] = getImagCoeff();
+ SmallVector<Type, 4> resultTypes;
+
+ for (auto out : outputs) {
+ tiledOperands.push_back(
+ getSlice(builder, getLoc(), out, offsets, sizes, strides));
+ // With tensor semantics each slice type is also a result type of the clone.
+ if (hasTensorSemantics()) {
+ resultTypes.push_back(tiledOperands.back().getType());
+ }
+ }
+ Operation *tiledFftOp = cast<LinalgExtOp>(getOperation())
+ .clone(builder, loc, resultTypes, tiledOperands);
+ // Write each tiled result back into the full destination.
+ for (auto result : llvm::enumerate(tiledFftOp->getResults())) {
+ auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+ loc, result.value(), outputs[result.index()], offsets, sizes, strides);
+ results.push_back(insertSliceOp.getResult());
+ }
+ return tiledFftOp;
+}
+
+//===----------------------------------------------------------------------===//
+// ReverseOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies a ReverseOp.
+///
+/// The op must have exactly one input and one output with identical element
+/// types and ranks; every pair of static dimensions must agree; and the
+/// reversed dimensions must be unique and lie within [0, rank).
+static LogicalResult verifyReverseOp(ReverseOp op) {
+  if (op.getNumInputs() != 1) {
+    return op.emitOpError("expected exactly one input");
+  }
+  if (op.getNumOutputs() != 1) {
+    return op.emitOpError("expected exactly one output");
+  }
+
+  auto inputType = op.input().getType().cast<ShapedType>();
+  auto outputType = op.output().getType().cast<ShapedType>();
+  if (inputType.getElementType() != outputType.getElementType()) {
+    return op.emitOpError(
+        "expected input/output element types to be identical");
+  }
+  ArrayRef<int64_t> inputShapes = inputType.getShape();
+  ArrayRef<int64_t> outputShapes = outputType.getShape();
+  if (inputShapes.size() != outputShapes.size()) {
+    return op.emitOpError("expexted input/output to have identical ranks");
+  }
+  // A dynamic extent on either side is compatible with anything.
+  for (size_t i = 0, e = inputShapes.size(); i != e; ++i) {
+    int64_t inputDim = inputShapes[i];
+    int64_t outputDim = outputShapes[i];
+    if (inputDim != ShapedType::kDynamicSize &&
+        outputDim != ShapedType::kDynamicSize && inputDim != outputDim) {
+      return op.emitOpError("incompatible input/output shapes");
+    }
+  }
+
+  int64_t rank = op.getOperandRank();
+  llvm::SmallSetVector<int64_t, 4> seenDims;
+  for (auto dim : op.dims()) {
+    if (dim < 0 || dim >= rank) {
+      return op.emitOpError("all the dimensions must be within [0, ")
+             << rank << ")";
+    }
+    // insert() reports false when the dimension was already recorded.
+    if (!seenDims.insert(dim)) {
+      return op.emitOpError("expected dimensions numbers are all unique");
+    }
+  }
+
+  return success();
+}
+
+/// Always reports `false`, irrespective of which operand is queried.
+bool ReverseOp::payloadUsesValueFromOperand(OpOperand * /*opOperand*/) {
+  return false;
+}
+
+/// Returns one iterator type per operand dimension; all of them are parallel.
+SmallVector<StringRef> ReverseOp::getLoopIteratorTypes() {
+  return SmallVector<StringRef>(getOperandRank(),
+                                getParallelIteratorTypeName());
+}
+
+/// Builds the iteration domain of the op: for every dimension of the input a
+/// range starting at 0, with unit step, bounded by the size of that dimension.
+SmallVector<Range> ReverseOp::getLoopBounds(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
+  int64_t rank = getOperandRank();
+  SmallVector<Range> loopBounds;
+  loopBounds.reserve(rank);
+  for (int64_t dim = 0; dim < rank; ++dim) {
+    Value upperBound = getDimValue(builder, loc, input(), dim);
+    loopBounds.push_back(Range{lowerBound, upperBound, step});
+  }
+  return loopBounds;
+}
+
+/// Emits the scalar body for one point `ivs` of the iteration space: loads
+/// the input element at `ivs` and stores it into the output at the mirrored
+/// position, where each reversed dimension `d` uses `size(d) - 1 - ivs[d]`.
+LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b,
+                                                      Location loc,
+                                                      ValueRange ivs) {
+  SmallVector<Value> storeIndices(ivs.begin(), ivs.end());
+  for (auto dim : dims()) {
+    Value extent = getDimValue(b, loc, input(), dim);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    Value lastIndex = b.create<arith::SubIOp>(loc, extent, one);
+    storeIndices[dim] =
+        b.create<arith::SubIOp>(loc, lastIndex, storeIndices[dim]);
+  }
+  Value element = b.create<memref::LoadOp>(loc, input(), ivs);
+  b.create<memref::StoreOp>(loc, element, output(), storeIndices);
+  return success();
+}
+
+/// Creates the tiled version of the op for the tile described by `offsets`
+/// and `sizes`.
+///
+/// The input tile is sliced at `offsets`. The output tile is sliced at the
+/// mirrored offsets: for each reversed dimension the offset becomes
+/// `dimSize - offset - tileSize`. For every result of the tiled op a
+/// `tensor.insert_slice` into the matching entry of `outputs` (at the mirrored
+/// offsets) is created and its result is appended to `results`.
+/// Returns the cloned, tiled operation.
+Operation *ReverseOp::getTiledImplementation(OpBuilder &builder,
+                                             ValueRange outputs,
+                                             ArrayRef<OpFoldResult> offsets,
+                                             ArrayRef<OpFoldResult> sizes,
+                                             SmallVectorImpl<Value> &results) {
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> unitStrides(getOperandRank(),
+                                        builder.getI64IntegerAttr(1));
+
+  // Input tile.
+  SmallVector<Value> tiledOperands;
+  tiledOperands.push_back(
+      getSlice(builder, loc, input(), offsets, sizes, unitStrides));
+
+  // mirrored = dimSize - offset - tileSize.
+  AffineExpr dimSizeSym, offsetSym, tileSizeSym;
+  bindSymbols(builder.getContext(), dimSizeSym, offsetSym, tileSizeSym);
+  AffineMap mirrorMap =
+      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3,
+                     {dimSizeSym - offsetSym - tileSizeSym});
+  SmallVector<OpFoldResult> mirrorOffsets(offsets.begin(), offsets.end());
+  for (auto dim : dims()) {
+    Value dimSize = getDimValue(builder, loc, input(), dim);
+    Value offset =
+        getValueOrCreateConstantIndexOp(builder, loc, mirrorOffsets[dim]);
+    Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]);
+    auto mirrored = builder.create<AffineApplyOp>(
+        loc, mirrorMap, ValueRange{dimSize, offset, tileSize});
+    mirrorOffsets[dim] = mirrored.getResult();
+  }
+
+  // Output tile; only the tensor form of the op produces a result.
+  tiledOperands.push_back(
+      getSlice(builder, loc, output(), mirrorOffsets, sizes, unitStrides));
+  SmallVector<Type, 4> resultTypes;
+  if (hasTensorSemantics()) {
+    resultTypes.push_back(tiledOperands[1].getType());
+  }
+
+  Operation *tiledRevOp = cast<LinalgExtOp>(getOperation())
+                              .clone(builder, loc, resultTypes, tiledOperands);
+
+  for (auto result : llvm::enumerate(tiledRevOp->getResults())) {
+    auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+        loc, result.value(), outputs[result.index()], mirrorOffsets, sizes,
+        unitStrides);
+    results.push_back(insertSliceOp.getResult());
+  }
+  return tiledRevOp;
+}
+
+/// Defines `OP_NAME::getEffects` by forwarding the op's results together with
+/// its input and output buffer operands to `getEffectsImpl`.
+#define DEFINE_OP_GET_EFFECTS(OP_NAME)                                        \
+  void OP_NAME::getEffects(                                                   \
+      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>     \
+          &effects) {                                                         \
+    SmallVector<Value> ins = getInputBufferOperands();                        \
+    SmallVector<Value> outs = getOutputBufferOperands();                      \
+    getEffectsImpl(effects, getOperation()->getResults(), ins, outs);         \
+  }
+
+DEFINE_OP_GET_EFFECTS(ScatterOp)
+DEFINE_OP_GET_EFFECTS(SortOp)
+DEFINE_OP_GET_EFFECTS(FftOp)
+DEFINE_OP_GET_EFFECTS(ReverseOp)
+
+namespace {
+/// Folds `tensor.cast` producers into a consuming `LinalgExtOp` whenever
+/// `canFoldIntoConsumerOp` allows it. The op is cloned with the cast sources
+/// as operands; results whose type changed are cast back to the original
+/// type so that existing users stay valid.
+///
+/// This is derived from the pattern of the same name in
+/// mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp.
+struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgExtOp> {
+  using OpInterfaceRewritePattern<LinalgExtOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgExtOp op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to do unless some operand is produced by a foldable cast.
+    auto isFoldableCast = [&](OpOperand *opOperand) {
+      if (opOperand->get().isa<BlockArgument>()) return false;
+      auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+      return castOp && canFoldIntoConsumerOp(castOp);
+    };
+    if (llvm::none_of(op.getInputAndOutputOperands(), isFoldableCast)) {
+      return failure();
+    }
+
+    SmallVector<Type, 4> newResultTypes;
+    newResultTypes.reserve(op->getNumResults());
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(op->getNumOperands());
+
+    // Inputs: replace a foldable cast by its source.
+    for (OpOperand *opOperand : op.getInputOperands()) {
+      auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+      if (canFoldIntoConsumerOp(castOp)) {
+        newOperands.push_back(castOp.source());
+      } else {
+        newOperands.push_back(opOperand->get());
+      }
+    }
+    // Outputs: folding a cast here also changes the matching result type.
+    for (OpOperand *opOperand : op.getOutputOperands()) {
+      auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+      if (canFoldIntoConsumerOp(castOp)) {
+        newOperands.push_back(castOp.getOperand());
+      } else {
+        newOperands.push_back(opOperand->get());
+      }
+      newResultTypes.push_back(newOperands.back().getType());
+    }
+
+    // Clone the op on the new operands and reconcile result types.
+    Location loc = op->getLoc();
+    Operation *newOp = op.clone(rewriter, loc, newResultTypes, newOperands);
+    SmallVector<Value, 4> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto resultPair : llvm::zip(op->getResults(), newOp->getResults())) {
+      Value oldResult = std::get<0>(resultPair);
+      Value replacement = std::get<1>(resultPair);
+      if (replacement.getType() != oldResult.getType()) {
+        replacement = rewriter.create<tensor::CastOp>(
+            loc, oldResult.getType(), replacement);
+      }
+      replacements.push_back(replacement);
+    }
+    rewriter.replaceOp(op, replacements);
+
+    return success();
+  }
+};
+}  // namespace
+
+//===----------------------------------------------------------------------===//
+// LinalgExtDialect
+//===----------------------------------------------------------------------===//
+
+/// Registers the dialect-wide canonicalization patterns; currently this is
+/// only `FoldTensorCastOp`.
+void IREELinalgExtDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  MLIRContext *ctx = getContext();
+  results.add<FoldTensorCastOp>(ctx);
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp
new file mode 100644
index 0000000..105c394
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -0,0 +1,310 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
+
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+#define DEBUG_TYPE "iree-tiled-op-interface"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.cpp.inc"
+
+/// Materializes an `OpFoldResult` as a `Value`. When it holds an attribute
+/// (expected to be an `IntegerAttr`) an index constant with that value is
+/// created at `loc`; otherwise the contained `Value` is returned unchanged.
+static Value getValue(OpBuilder &builder, Location loc,
+                      OpFoldResult valueOrAttr) {
+  Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+  if (!attr) {
+    return valueOrAttr.get<Value>();
+  }
+  int64_t constant = attr.cast<IntegerAttr>().getInt();
+  return builder.create<arith::ConstantIndexOp>(loc, constant);
+}
+
+//===----------------------------------------------------------------------===//
+// Interface implementations for external operations.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// External model for `tensor.extract_slice`.
+struct ExtractSliceTiledOpInterface
+    : public TiledOpInterface::ExternalModel<ExtractSliceTiledOpInterface,
+                                             tensor::ExtractSliceOp> {
+  /// No operand of `tensor.extract_slice` serves as a destination operand, so
+  /// an `init_tensor` op with the reified result shape and the result element
+  /// type is created to act as the destination.
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    ReifiedRankedShapedTypeDims returnShape;
+    (void)extractSliceOp.reifyResultShapes(b, returnShape);
+    auto ofrShape = llvm::to_vector<4>(llvm::map_range(
+        returnShape[0], [](Value v) { return getAsOpFoldResult(v); }));
+    Value initTensor = b.create<linalg::InitTensorOp>(
+        op->getLoc(), ofrShape, extractSliceOp.getType().getElementType());
+    return {initTensor};
+  }
+
+  /// One parallel iterator per dimension of the result.
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    return SmallVector<StringRef>(extractSliceOp.getType().getRank(),
+                                  getParallelIteratorTypeName());
+  }
+
+  /// The iteration space is the reified shape of the result: each loop runs
+  /// from 0 to the result extent with unit step.
+  SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    ReifiedRankedShapedTypeDims returnShape;
+    (void)extractSliceOp.reifyResultShapes(b, returnShape);
+    Location loc = op->getLoc();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    SmallVector<Range> loopRanges(returnShape[0].size(),
+                                  Range{zero, nullptr, one});
+    for (auto ub : llvm::enumerate(returnShape[0])) {
+      loopRanges[ub.index()].size = ub.value();
+    }
+    return loopRanges;
+  }
+
+  /// Creates the `tensor.extract_slice` for the tile given by `offsets` and
+  /// `sizes` (expressed in the result space of the original op) and inserts
+  /// it into `outputs[0]` at that tile. The insert result is appended to
+  /// `results`; the tiled extract op is returned.
+  Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+                                    ValueRange outputs,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVectorImpl<Value> &results) const {
+    auto extractOp = cast<tensor::ExtractSliceOp>(op);
+    Location loc = extractOp.getLoc();
+    auto oneAttr = b.getI64IntegerAttr(1);
+
+    // Compute the offset and sizes for the tiled `tensor.extract_slice`
+    // operation. Strides of the original op are kept; they are folded into
+    // the new offsets below, so no unit-stride restriction is needed.
+    llvm::SmallDenseSet<unsigned> droppedDims = extractOp.getDroppedDims();
+    unsigned resultDimPos = 0;
+    auto opOffsets = extractOp.getMixedOffsets();
+    auto opSizes = extractOp.getMixedSizes();
+    auto opStrides = extractOp.getMixedStrides();
+    MLIRContext *context = b.getContext();
+
+    // newOffset = tileOffset * opStride + opOffset. The map does not depend
+    // on the dimension, so it is built once.
+    AffineExpr d0, s0, s1;
+    bindDims(context, d0);
+    bindSymbols(context, s0, s1);
+    AffineMap map = AffineMap::get(1, 2, d0 * s0 + s1);
+
+    SmallVector<OpFoldResult> newOffset, newSizes, newStrides;
+    for (auto opOffset : llvm::enumerate(opOffsets)) {
+      // If the dimension is dropped, use the same offset and size; it does
+      // not consume an entry of the tile `offsets`/`sizes`.
+      if (droppedDims.count(opOffset.index())) {
+        newOffset.push_back(opOffset.value());
+        newSizes.push_back(opSizes[opOffset.index()]);
+      } else {
+        SmallVector<Value> operands = {
+            getValue(b, loc, offsets[resultDimPos]),
+            getValue(b, loc, opStrides[opOffset.index()]),
+            getValue(b, loc, opOffset.value())};
+        Value offset = b.create<AffineApplyOp>(loc, map, operands);
+        newOffset.push_back(offset);
+        newSizes.push_back(sizes[resultDimPos]);
+        resultDimPos++;
+      }
+      newStrides.push_back(opStrides[opOffset.index()]);
+    }
+
+    // Generate the tiled `tensor.extract_slice` operation.
+    Type resultType = tensor::ExtractSliceOp::inferRankReducedResultType(
+        extractOp.getType().getRank(), extractOp.getSourceType(), newOffset,
+        newSizes, newStrides);
+    auto tiledExtractOp = b.create<tensor::ExtractSliceOp>(
+        loc, resultType.cast<RankedTensorType>(), extractOp.source(), newOffset,
+        newSizes, newStrides);
+
+    // Insert the tiled extract into the result tensor.
+    SmallVector<OpFoldResult> resultStrides(offsets.size(), oneAttr);
+    auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+        loc, tiledExtractOp.result(), outputs[0], offsets, sizes,
+        resultStrides);
+    results.push_back(tiledInsertOp.result());
+    return tiledExtractOp;
+  }
+};
+
+/// External model for `tensor.insert_slice`. The iteration space is the shape
+/// of the source tensor, and the `dest` operand is the destination.
+struct InsertSliceTiledOpInterface
+ : public TiledOpInterface::ExternalModel<InsertSliceTiledOpInterface,
+ tensor::InsertSliceOp> {
+ // The `dest` operand of the insert is its only destination operand.
+ SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+ SmallVector<Value> dest;
+ dest.push_back(cast<tensor::InsertSliceOp>(op).dest());
+ return dest;
+ }
+
+ // One parallel iterator per dimension of the source tensor.
+ SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ return SmallVector<StringRef>(insertSliceOp.getSourceType().getRank(),
+ getParallelIteratorTypeName());
+ }
+
+ // Each loop runs from 0 to the corresponding `tensor.dim` of the source with
+ // unit step.
+ SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ Value source = insertSliceOp.source();
+ RankedTensorType sourceType = insertSliceOp.getSourceType();
+ Location loc = op->getLoc();
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ SmallVector<Range> loopBounds(sourceType.getRank(),
+ Range{zero, nullptr, one});
+ for (auto dim :
+ llvm::seq<int64_t>(0, insertSliceOp.getSourceType().getRank())) {
+ loopBounds[dim].size = b.create<tensor::DimOp>(loc, source, dim);
+ }
+ return loopBounds;
+ }
+
+ // Extracts the tile (`offsets`, `sizes`) of the source and inserts it into
+ // `outputs[0]` at the original insert offsets shifted by the tile offsets.
+ // Emits an op error and returns nullptr when any stride of the original op
+ // is not the constant 1. Appends the insert result to `results` and returns
+ // the extract op created for the source tile.
+ Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+ ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) const {
+ auto insertOp = cast<tensor::InsertSliceOp>(op);
+ // Compute a subtensor of the source based on the offsets.
+ auto opStrides = insertOp.getMixedStrides();
+ if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
+ Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
+ return intVal && *intVal == 1;
+ })) {
+ op->emitOpError("unable to tile operation with non-unit stride");
+ return nullptr;
+ }
+ MLIRContext *context = b.getContext();
+ Location loc = insertOp.getLoc();
+ auto oneAttr = b.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
+ auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
+ loc, insertOp.source(), offsets, sizes, strides);
+
+ // The offsets for the insert is based on the op offsets plus the offsets of
+ // the loops passed in.
+ auto opOffsets = insertOp.getMixedOffsets();
+ auto opSizes = insertOp.getMixedSizes();
+ // `offsetIndex` walks the source (tile) dimensions while the loop below
+ // walks the destination dimensions; it only advances for destination
+ // dimensions that are not treated as rank-reduced.
+ unsigned offsetIndex = 0;
+ ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
+ int64_t destRank = insertOp.getType().getRank();
+ SmallVector<OpFoldResult> resultOffsets(destRank);
+ SmallVector<OpFoldResult> resultSizes(destRank);
+ for (auto opOffset : llvm::enumerate(opOffsets)) {
+ // Check for rank-reducing by checking that
+ // 1) The corresponding opSize value is 1
+ // 2) The current rank of the source is not 1.
+ // Then the opOffset is for the rank-reduced dimension. Skip.
+ unsigned opOffsetIndex = opOffset.index();
+ OpFoldResult opOffsetVal = opOffset.value();
+ Optional<int64_t> opSizeVal = getConstantIntValue(opSizes[opOffsetIndex]);
+ // Also taken once all source dimensions have been consumed.
+ if (offsetIndex >= sourceShape.size() ||
+ (opSizeVal && *opSizeVal == 1 && sourceShape[offsetIndex] != 1)) {
+ resultOffsets[opOffsetIndex] = opOffsetVal;
+ resultSizes[opOffsetIndex] = oneAttr;
+ continue;
+ }
+ OpFoldResult offset = offsets[offsetIndex];
+ // Fold the sum when both offsets are attributes; otherwise build an
+ // affine.apply computing `tile offset + op offset`.
+ if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
+ resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
+ *getConstantIntValue(opOffsetVal) + *getConstantIntValue(offset));
+ } else {
+ AffineExpr d0, s0;
+ bindDims(context, d0);
+ bindSymbols(context, s0);
+ AffineMap map = AffineMap::get(1, 1, d0 + s0);
+ SmallVector<Value> operands = {getValue(b, loc, offset),
+ getValue(b, loc, opOffsetVal)};
+ resultOffsets[opOffsetIndex] =
+ b.create<AffineApplyOp>(loc, map, operands).getResult();
+ }
+ resultSizes[opOffsetIndex] = sizes[offsetIndex];
+ offsetIndex++;
+ }
+ SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
+ auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+ loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes,
+ resultStrides);
+ results.push_back(tiledInsertOp.result());
+ return extractSliceOp;
+ }
+};
+
+/// Adapter that implements `TiledOpInterface` for `OpTy` by delegating to the
+/// upstream `TilingInterface` methods of `OpTy`.
+///
+/// The delegation is only valid when the iteration space coincides with the
+/// data space of the result(s). The tiling algorithm built around
+/// `TilingInterface` only needs the tiled operation, whereas the one used with
+/// `TiledOpInterface` also needs the value of the whole tensor, so this
+/// adapter inserts each tiled result back into the destination.
+template <typename OpTy>
+struct ForwardToTilingInterface
+    : public TiledOpInterface::ExternalModel<ForwardToTilingInterface<OpTy>,
+                                             OpTy> {
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    auto concreteOp = cast<OpTy>(op);
+    return concreteOp.getDestinationOperands(b);
+  }
+
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    auto concreteOp = cast<OpTy>(op);
+    return concreteOp.getLoopIteratorTypes();
+  }
+
+  SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+    auto concreteOp = cast<OpTy>(op);
+    return concreteOp.getLoopBounds(b);
+  }
+
+  /// Tiles `op` through its `TilingInterface` implementation, then inserts
+  /// every result of the tiled op into the matching entry of `dest` at
+  /// (`offsets`, `sizes`) with unit strides. The insert results are appended
+  /// to `results`. Emits an op error and returns nullptr if tiling fails or
+  /// the number of tiled results differs from `dest.size()`.
+  Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+                                    ValueRange dest,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVectorImpl<Value> &results) const {
+    Operation *tiledOp =
+        cast<OpTy>(op).getTiledImplementation(b, dest, offsets, sizes);
+    if (!tiledOp) {
+      op->emitOpError("failed to tile operation");
+      return nullptr;
+    }
+    if (tiledOp->getNumResults() != dest.size()) {
+      op->emitOpError(
+          "mismatch in the number of results of the tiled operation and the "
+          "number of results expected");
+      return nullptr;
+    }
+
+    Location loc = op->getLoc();
+    SmallVector<OpFoldResult> unitStrides(offsets.size(),
+                                          b.getI64IntegerAttr(1));
+    // The result shape is assumed to match the loop bounds of the op, which
+    // is what makes inserting into `dest` at `offsets`/`sizes` legal. This is
+    // where `TiledOpInterface` (IREE) and `TilingInterface` (MLIR) differ: the
+    // latter treats fusion and tiling alike and hands back only the tiled op,
+    // not the full result tensor that the current tiling algorithm expects.
+    for (unsigned idx = 0, numResults = tiledOp->getNumResults();
+         idx < numResults; ++idx) {
+      auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+          loc, tiledOp->getResult(idx), dest[idx], offsets, sizes,
+          unitStrides);
+      results.push_back(tiledInsertOp.getResult());
+    }
+    return tiledOp;
+  }
+};
+
+} // namespace
+
+/// Attaches the `TiledOpInterface` external models defined in this file to
+/// `tensor.extract_slice`, `tensor.insert_slice` and `linalg.pad_tensor` in
+/// the given `registry`.
+void IREE::LinalgExt::registerTiledOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  LLVM_DEBUG(
+      { llvm::dbgs() << "Adding external models of tiled op interface\n"; });
+  registry
+      .addOpInterface<tensor::ExtractSliceOp, ExtractSliceTiledOpInterface>();
+  registry.addOpInterface<tensor::InsertSliceOp, InsertSliceTiledOpInterface>();
+  registry.addOpInterface<linalg::PadTensorOp,
+                          ForwardToTilingInterface<linalg::PadTensorOp>>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..66078b9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Library with the LinalgExt transformation sources (lowering to loops, pass
+# registration and tiling). It depends on the generated pass declarations and
+# links publicly against the IREE Input/LinalgExt dialects and the MLIR
+# libraries listed below.
+add_mlir_library(IREELinalgExtPasses
+ ConvertToLoops.cpp
+ Passes.cpp
+ Tiling.cpp
+
+ DEPENDS
+ IREELinalgExtTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEInputDialect
+ IREELinalgExtDialect
+ MLIRAffine
+ MLIRIR
+ MLIRLinalg
+ MLIRLinalgTransforms
+ MLIRMath
+ MLIRMemRef
+ MLIRPass
+ MLIRSCF
+ MLIRStandard
+ MLIRSupport
+ MLIRTensor
+ MLIRTransforms
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
new file mode 100644
index 0000000..0c2fcd0
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
@@ -0,0 +1,115 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+/// Recursively lowers `tilableOp` to scalar loops, one loop dimension per
+/// call. At depth `loopDepth` an `scf.for` over `loopRanges[loopDepth]` is
+/// created and its induction variable appended to `ivs`; once every range has
+/// a loop, the scalar implementation of the op is generated at `ivs`.
+/// Returns the status reported by the innermost scalar generation.
+static LogicalResult lowerToLoopsImpl(OpBuilder &builder,
+                                      TiledOpInterface tilableOp,
+                                      ArrayRef<Range> loopRanges,
+                                      unsigned loopDepth,
+                                      SmallVectorImpl<Value> &ivs) {
+  Location loc = tilableOp.getLoc();
+  // Base case: all loops exist, emit the body.
+  if (loopDepth == loopRanges.size()) {
+    return tilableOp.generateScalarImplementation(builder, loc, ivs);
+  }
+
+  const Range &range = loopRanges[loopDepth];
+  LogicalResult nestedStatus = success();
+  auto buildBody = [&](OpBuilder &b, Location loc, Value iv,
+                       ValueRange args) {
+    ivs.push_back(iv);
+    nestedStatus =
+        lowerToLoopsImpl(b, tilableOp, loopRanges, loopDepth + 1, ivs);
+    b.create<scf::YieldOp>(loc);
+  };
+  builder.create<scf::ForOp>(loc, range.offset, range.size, range.stride,
+                             ValueRange{}, buildBody);
+  return nestedStatus;
+}
+
+/// Entry point for lowering a `TiledOpInterface` op to loops: queries the
+/// loop bounds of the op and starts the recursive lowering at depth 0 with no
+/// induction variables.
+static LogicalResult lowerToLoops(OpBuilder &builder,
+                                  TiledOpInterface tilableOp) {
+  SmallVector<Value> inductionVars;
+  SmallVector<Range> bounds = tilableOp.getLoopBounds(builder);
+  return lowerToLoopsImpl(builder, tilableOp, bounds, /*loopDepth=*/0,
+                          inductionVars);
+}
+
+namespace {
+/// Pattern rewriter hook to lower a `TiledOpInterface` to loops.
+///
+/// Matches any operation implementing `TiledOpInterface`. Ops that still
+/// produce shaped results are rejected, since the scalar implementation is
+/// generated without results; matched ops are erased after the loops have
+/// been emitted.
+struct TiledOpInterfaceLowerToLoopsPattern : public RewritePattern {
+  TiledOpInterfaceLowerToLoopsPattern(MLIRContext *context,
+                                      PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto tilableOp = dyn_cast<TiledOpInterface>(op);
+    if (!tilableOp) {
+      return failure();
+    }
+    // A shaped result means the op is still in tensor form, which this
+    // lowering cannot handle.
+    if (llvm::any_of(tilableOp->getResults(),
+                     [&](Value v) { return v.getType().isa<ShapedType>(); })) {
+      return rewriter.notifyMatchFailure(
+          tilableOp, "lower to loops needs to have buffer semantics");
+    }
+    if (failed(lowerToLoops(rewriter, tilableOp))) {
+      return failure();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+}  // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass that greedily applies `TiledOpInterfaceLowerToLoopsPattern` to the
+/// operation it runs on and signals failure if the rewrite driver fails.
+struct LinalgExtToLoopsPass
+    : public LinalgExtToLoopsBase<LinalgExtToLoopsPass> {
+  /// Declares the dialects this pass registers as dependencies.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, StandardOpsDialect,
+                    mlir::arith::ArithmeticDialect, math::MathDialect,
+                    memref::MemRefDialect, scf::SCFDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+
+    OwningRewritePatternList patterns(context);
+    patterns.insert<TiledOpInterfaceLowerToLoopsPattern>(context);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+}  // namespace
+
+/// Factory for the pass that lowers LinalgExt ops to loops.
+std::unique_ptr<OperationPass<FuncOp>>
+IREE::LinalgExt::createLinalgExtToLoopsPass() {
+  auto pass = std::make_unique<LinalgExtToLoopsPass>();
+  return pass;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp
new file mode 100644
index 0000000..c41b9ed
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp
@@ -0,0 +1,33 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+// The generated pass registration code is included inside `detail`; the
+// `registerPasses` function defined after this namespace block calls
+// `detail::registerPasses()`.
+namespace detail {
+#define GEN_PASS_REGISTRATION
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h.inc" // IWYU pragma: export
+} // namespace detail
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+/// Registers the LinalgExt passes by delegating to the generated
+/// `detail::registerPasses`.
+void IREE::LinalgExt::registerPasses() {
+  // Resolved within `IREE::LinalgExt`, i.e. `IREE::LinalgExt::detail`.
+  detail::registerPasses();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
new file mode 100644
index 0000000..9df31de
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -0,0 +1,354 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+//===----------------------------------------------------------------------===//
+// Utility methods for tiling a linalg_ext operation that implements a
+// TiledOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Succeeds only for: no interchange, `Loops` loop type, cyclic distribution.
+static LogicalResult verifySupportedTilingOptions(
+ PatternRewriter &rewriter, Operation *op,
+ const linalg::LinalgTilingOptions &options) {
+ if (!options.interchangeVector.empty()) {  // Any interchange is rejected.
+ return rewriter.notifyMatchFailure(op,
+ "unsupported interchange during tiling");
+ }
+ if (options.loopType != linalg::LinalgTilingLoopType::Loops) {
+ return rewriter.notifyMatchFailure(op,
+ "only tiling with scf.for is supported");
+ }
+ if (options.distribution) {  // Distribution is optional; check it only if set.
+ if (llvm::any_of(options.distribution->distributionMethod,
+ [](linalg::DistributionMethod method) {
+ return method != linalg::DistributionMethod::Cyclic;
+ })) {
+ return rewriter.notifyMatchFailure(op,
+ "only cyclic distribution is allowed");
+ }
+ }
+ return success();
+}
+
+/// Converts an `OpFoldResult` to a `Value`: an attribute (cast to
+/// `IntegerAttr`) becomes an index constant op; a `Value` is returned as-is.
+static Value getValue(OpBuilder &builder, Location loc,
+ OpFoldResult valueOrAttr) {
+ if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+ return builder.create<arith::ConstantIndexOp>(
+ loc, attr.cast<IntegerAttr>().getInt());
+ }
+ return valueOrAttr.get<Value>();
+}
+
+/// Returns true if the loop is untiled, i.e. its tile size is statically
+/// zero. It is assumed that a `Value` defined by a constant op has already
+/// been converted to an `IntegerAttr` of that value, so this simply returns
+/// true when `getConstantIntValue` yields a value and that value is zero.
+static bool isUntiledLoop(OpFoldResult valueOrAttr) {
+ Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
+ return intVal && *intVal == 0;
+}
+
+/// Generates the tiled loops and the body by invoking the interface methods of
+/// TiledOpInterface.
+/// - `outputs` are the operands to use for outputs of the tiled operation.
+/// - `tileSizes` are tile sizes specified for all loops of the operation. If a
+/// loop is to be untiled it is set to 0.
+/// - `iteratorTypes` are the iterator types of the loops returned by the
+/// TiledOpInterface.
+/// - `loopBounds` are the bounds of all the loops of the op returned by the
+/// TiledOpInterface.
+/// - `loopDepth` is the current loop depth being processed.
+/// - `offsets` are the `Value`s that represent the position of the tile being
+/// operated on. The offsets are computed as the tiled loops are being
+/// generated.
+/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for
+/// distributed loops. It is a stack, and once an entry at the top of the
+/// stack is used for distribution it is popped before processing the inner
+/// loops.
+static FailureOr<TiledOp> tileInterfaceOpImpl(
+ OpBuilder &builder, TiledOpInterface tilableOp, ValueRange outputs,
+ MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<Range> loopBounds, unsigned loopDepth,
+ SmallVectorImpl<OpFoldResult> &offsets,
+ ArrayRef<linalg::ProcInfo> distributionInfo) {
+ Location loc = tilableOp.getLoc();
+ // If all loops have been processed, generate the tiled implementation of
+ // the op by invoking the TiledOpInterface methods.
+ if (loopDepth == tileSizes.size()) {
+ TiledOp ret;
+ ret.op = tilableOp.getTiledImplementation(builder, outputs, offsets,
+ tileSizes, ret.results);
+ if (!ret.op) {
+ return static_cast<LogicalResult>(
+ tilableOp.emitOpError("failed to get tiled implementation"));
+ }
+ return ret;
+ }
+
+ // Tile size of zero: leave this loop untiled and recurse to the next depth.
+ if (isUntiledLoop(tileSizes[loopDepth])) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ offsets.push_back(zeroAttr);
+ assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
+ "expected loop bounds to have lower bound of zero");
+ tileSizes[loopDepth] = getAsOpFoldResult(loopBounds[loopDepth].size);
+ return tileInterfaceOpImpl(builder, tilableOp, outputs, tileSizes,
+ iteratorTypes, loopBounds, loopDepth + 1,
+ offsets, distributionInfo);
+ }
+
+ // Generate an scf.for for the current loop depth.
+ Value lb = loopBounds[loopDepth].offset;
+ Value ub = loopBounds[loopDepth].size;
+ // TODO(#7073): Put the check back. This is required by tiling linalg_ext.fft
+ // op. We can put the check back after updating linalg_ext.fft semantics.
+ // if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
+ // return static_cast<LogicalResult>(
+ // tilableOp.emitOpError("expected stride to be 1"));
+ //}
+ Value step = getValue(builder, loc, tileSizes[loopDepth]);
+
+ // Update lb, ub and step for cyclic distribution.
+ if (!distributionInfo.empty() &&
+ iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
+ linalg::updateBoundsForCyclicDistribution(
+ builder, loc, distributionInfo.front().procId,
+ distributionInfo.front().nprocs, lb, ub, step);
+ distributionInfo = distributionInfo.drop_front();
+ }
+ FailureOr<TiledOp> innerReturnValue;
+ bool isBufferTiling = tilableOp->getNumResults() == 0;
+ ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
+ auto forOp = builder.create<scf::ForOp>(
+ loc, lb, ub, step, initValues,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ offsets.push_back(iv);
+ auto affineMaps = AffineMap::inferFromExprList({ArrayRef<AffineExpr>{
+ b.getAffineSymbolExpr(0),
+ b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0];
+ // Similar to linalg tiling, the tile size is min(tileSize, ub - iv),
+ // to account for cases where the tile size does not divide (ub - lb)
+ // exactly. Operands below are ordered as {d0 = iv, s0 = size, s1 = ub}.
+ Value inBoundsTileSize = b.create<AffineMinOp>(
+ loc, affineMaps,
+ ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub});
+ tileSizes[loopDepth] = getAsOpFoldResult(inBoundsTileSize);
+ // Recursively proceed to generate the tiled loop for the next level.
+ innerReturnValue =
+ tileInterfaceOpImpl(b, tilableOp, (isBufferTiling ? outputs : args),
+ tileSizes, iteratorTypes, loopBounds,
+ loopDepth + 1, offsets, distributionInfo);
+ if (failed(innerReturnValue)) return;
+ b.create<scf::YieldOp>(loc, innerReturnValue->results);
+ });
+ if (failed(innerReturnValue)) {
+ return innerReturnValue;
+ }
+ innerReturnValue->loops.insert(innerReturnValue->loops.begin(),
+ forOp.getOperation());
+ innerReturnValue->results = forOp.getResults();
+ return innerReturnValue;
+}
+
+FailureOr<TiledOp> tileInterfaceOp(OpBuilder &b, TiledOpInterface tilableOp,
+ const linalg::LinalgTilingOptions &options) {
+ SmallVector<Value> dest = tilableOp.getDestinationOperands(b);
+ if (dest.empty()) {
+ return static_cast<LogicalResult>(tilableOp.emitOpError(
+ "cannot tile operation without destination operands"));
+ }
+
+ SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
+ SmallVector<Value, 4> tileSizesVals =
+ options.tileSizeComputationFunction(b, tilableOp);
+ auto zeroAttr = b.getI64IntegerAttr(0);
+
+ // The tile sizes actually used convert a `Value` defined by a constant 0 into
+ // a zero integer attribute. Currently if the iterator type is not "parallel",
+ // the tile size must be zero as well; anything else is rejected below.
+ auto tileSizes = getAsOpFoldResult(tileSizesVals);
+ tileSizes.resize(iteratorTypes.size(), zeroAttr);
+ for (auto en : llvm::enumerate(iteratorTypes)) {
+ if (en.value() == getParallelIteratorTypeName()) continue;
+ if (!isUntiledLoop(tileSizes[en.index()])) {
+ return static_cast<LogicalResult>(tilableOp.emitOpError(
+ "unimplemented tiling of non-parallel loop iterator type"));
+ }
+ }
+
+ // Trivial early exit case of tile sizes being zero for all parallel loops.
+ if (llvm::all_of(tileSizes, isUntiledLoop)) {
+ return TiledOp{tilableOp, {}, {}};
+ }
+
+ SmallVector<Range> loopBounds = tilableOp.getLoopBounds(b);
+ SmallVector<linalg::ProcInfo> distributionInfo;
+ // If the tiled loops are distributed, get the proc_id and nprocs for the
+ // distributed loops. First collect the parallel loops by iterating over the
+ // tileSizes and getting the loops that are distributed, i.e.,
+ // - parallel, i.e. iteratorTypes is "parallel"
+ // - tiled, i.e. tileSize != 0
+ if (options.distribution) {
+ SmallVector<Range> distributedLoopRange;
+ for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
+ if (isUntiledLoop(tileSizes[i])) continue;
+ if (iteratorTypes[i] != getParallelIteratorTypeName()) continue;
+ distributedLoopRange.push_back(loopBounds[i]);
+ }
+ distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(),
+ distributedLoopRange);
+ }
+
+ SmallVector<OpFoldResult> offsets;
+ return tileInterfaceOpImpl(b, tilableOp, dest, tileSizes, iteratorTypes,
+ loopBounds, 0, offsets, distributionInfo);
+}
+
+LogicalResult TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
+ TiledOpInterface tilableOp, PatternRewriter &rewriter,
+ TiledOp &result) const {
+ if (failed(filter.checkAndNotify(rewriter, tilableOp))) {
+ return failure();
+ }
+ if (failed(verifySupportedTilingOptions(rewriter, tilableOp, options))) {
+ return failure();
+ }
+
+ FailureOr<TiledOp> res = tileInterfaceOp(rewriter, tilableOp, options);
+ if (failed(res)) return res;
+ result = *res;  // Hand the tiled op, loops and results back to the caller.
+ if (result.op) {
+ filter.replaceLinalgTransformationFilter(rewriter, result.op);
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test pass for tiling Linalg Ext ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TiledOpInterfaceTilingPass  // Test pass that tiles TiledOpInterface ops.
+ : public TiledOpInterfaceTilingBase<TiledOpInterfaceTilingPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<
+ AffineDialect, IREE::Input::IREEInputDialect, linalg::LinalgDialect,
+ IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
+ StandardOpsDialect, mlir::arith::ArithmeticDialect, math::MathDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
+ }
+ void runOnOperation() override;  // Defined out of line below.
+};
+} // namespace
+
+template <typename OpTy>  // OpTy is built from (location, dim) and used as a Value.
+static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
+ return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);  // Location comes from the op at the insertion point.
+}
+
+void TiledOpInterfaceTilingPass::runOnOperation() {
+ FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp.getContext();
+
+ RewritePatternSet patterns(context);  // Each pattern pairs fixed tile sizes with a filter.
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("tiling_input", context),
+ Identifier::get("tiling_output", context)));
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("no_tiling_input", context),
+ Identifier::get("no_tiling_output", context)));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("outer_reduce_input", context),
+ Identifier::get("outer_reduce_output", context)));
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("inner_reduce_input", context),
+ Identifier::get("inner_reduce_output", context)));
+
+ static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+ [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {  // NOTE(review): `loc` is unused here.
+ auto numParallelDims = parallelLoopRanges.size();
+
+ SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
+ for (size_t dim = 0; dim < numParallelDims; ++dim) {
+ procInfo[numParallelDims - dim - 1] = {  // Reversed: workgroup dim 0 fills the last entry.
+ buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupIDOp>(
+ builder, dim),
+ buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupCountOp>(
+ builder, dim)};
+ }
+ return procInfo;
+ },
+ {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+ linalg::DistributionMethod::Cyclic},
+ DenseMap<StringRef,
+ std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setTileSizes(ArrayRef<int64_t>{10, 0, 30})
+ .setDistributionOptions(workgroupDistributionOptions),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("distribute_input", context),
+ Identifier::get("distribute_output", context)));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context,
+ linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{32}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("tiling_1d_stage5_fft_input", context),
+ Identifier::get("tiling_1d_stage5_fft_output", context)));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context,
+ linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{10, 32}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("tiling_2d_stage5_fft_input", context),
+ Identifier::get("tiling_2d_stage5_fft_output", context)));
+
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+IREE::LinalgExt::createTiledOpInterfaceTilingPass() {
+ return std::make_unique<TiledOpInterfaceTilingPass>();  // New instance per call.
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/CMakeLists.txt
similarity index 100%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/CMakeLists.txt
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/CMakeLists.txt
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt
new file mode 100644
index 0000000..7ca64d8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(IREEPyDMDialect
+ PyDMDialect.cpp
+ PyDMOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${IREE_DIALECTS_SOURCE_DIR}/include
+
+ DEPENDS
+ IREEPyDMIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+)
+
+iree_dialects_target_includes(IREEPyDMDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
similarity index 94%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
index af9e341..5db10f7 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
@@ -4,10 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -18,11 +18,11 @@
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
using namespace PYDM;
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.cpp.inc"
-#include "iree-dialects/Dialect/IREEPyDM/IR/OpInterfaces.cpp.inc"
-#include "iree-dialects/Dialect/IREEPyDM/IR/TypeInterfaces.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/Types.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
//------------------------------------------------------------------------------
// Dialect implementation
@@ -38,11 +38,11 @@
void IREEPyDMDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
-#include "iree-dialects/Dialect/IREEPyDM/IR/Types.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
>();
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
index d79cbd8..7ef7bfc 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
@@ -807,4 +807,4 @@
}
#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..fd9704a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_subdirectory(Optimize)
+add_subdirectory(RTL)
+add_subdirectory(ToIREE)
+
+add_mlir_library(IREEPyDMPasses
+ Passes.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEPyDMOptimizePasses
+ IREEPyDMRTLPasses
+ IREEPyDMToIREEPasses
+ MLIRTransforms
+)
+
+iree_dialects_target_includes(IREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt
new file mode 100644
index 0000000..1f7abb1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(IREEPyDMOptimizePasses
+ FixateWeakNumeric.cpp
+ LocalPropagateTypes.cpp
+ VariablesToSSA.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEPyDMDialect
+ IREEPyDMUtils
+ MLIRIR
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMOptimizePasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
similarity index 96%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
index 992c53f..fc17764 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
@@ -5,8 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/LocalPropagateTypes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/LocalPropagateTypes.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
index 3966f41..4f4a73f 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/LocalPropagateTypes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
@@ -5,9 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/Utils/TypeInference.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Utils/TypeInference.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/VariablesToSSA.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/VariablesToSSA.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
index 482bf99..0a0a141 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/VariablesToSSA.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
@@ -5,8 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
new file mode 100644
index 0000000..9fbfc52
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
@@ -0,0 +1,32 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace iree {
+class IREEDialect;  // NOTE(review): looks stale; users now name IREE::Input::IREEInputDialect — confirm and drop.
+}
+
+namespace iree_compiler {
+namespace IREE {
+namespace PYDM {
+
+class FuncOp;  // Presumably referenced by the generated pass base classes — confirm.
+
+#define GEN_PASS_CLASSES
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
+
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
similarity index 92%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
index cbaf558..820338c 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
@@ -44,7 +44,7 @@
namespace PYDM_generated {
namespace {
#define GEN_PASS_REGISTRATION
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h.inc"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
} // namespace
} // namespace PYDM_generated
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt
new file mode 100644
index 0000000..08392f8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_library(IREEPyDMRTLPasses
+ LinkageAnalysis.cpp
+ LinkRTLPass.cpp
+ LowerToRTLPass.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEInputDialect
+ IREEPyDMDialect
+ MLIRIR
+ MLIRParser
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMRTLPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
similarity index 97%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
index 315dfa4..f61214a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
@@ -7,9 +7,9 @@
#include <memory>
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
similarity index 83%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
index dc1a112..aee2bf2 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "mlir/IR/SymbolTable.h"
using namespace mlir;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LowerToRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
similarity index 96%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LowerToRTLPass.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
index ce3b970..67957a6 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LowerToRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
@@ -5,10 +5,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -221,8 +221,8 @@
struct LowerIREEPyDMToRTLPass
: public LowerIREEPyDMToRTLBase<LowerIREEPyDMToRTLPass> {
void getDependentDialects(DialectRegistry &registry) const override {
- registry
- .insert<mlir::iree::IREEDialect, BuiltinDialect, StandardOpsDialect>();
+ registry.insert<mlir::iree_compiler::IREE::Input::IREEInputDialect,
+ BuiltinDialect, StandardOpsDialect>();
}
void runOnOperation() override {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt
new file mode 100644
index 0000000..4c8a7b4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_library(IREEPyDMToIREEPasses
+ ConversionPass.cpp
+ LoweringPatterns.cpp
+ TypeConverter.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC  # NOTE(review): ConversionPass.cpp uses math::MathDialect; confirm MLIRMath arrives transitively.
+ IREEInputDialect
+ IREEPyDMDialect
+ MLIRArithmetic
+ MLIRIR
+ MLIRStandard
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMToIREEPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
similarity index 79%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
index 63e343c..f49d0b1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -5,12 +5,12 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/Patterns.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -25,8 +25,8 @@
struct ConvertIREEPyDMToIREEPass
: public ConvertIREEPyDMToIREEBase<ConvertIREEPyDMToIREEPass> {
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<mlir::iree::IREEDialect, BuiltinDialect, StandardOpsDialect,
- math::MathDialect>();
+ registry.insert<mlir::iree_compiler::IREE::Input::IREEInputDialect,
+ BuiltinDialect, StandardOpsDialect, math::MathDialect>();
}
void runOnOperation() override {
@@ -39,7 +39,8 @@
ConversionTarget target(*context);
target.addIllegalDialect<IREEPyDMDialect>();
target.addLegalDialect<BuiltinDialect>();
- target.addLegalDialect<mlir::iree::IREEDialect>();
+ target
+ .addLegalDialect<mlir::iree_compiler::IREE::Input::IREEInputDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<mlir::math::MathDialect>();
target.addLegalDialect<mlir::StandardOpsDialect>();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
similarity index 95%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 3f33b13..0d5317d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/Patterns.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -18,6 +18,7 @@
using llvm::enumerate;
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+namespace Input = mlir::iree_compiler::IREE::Input;
using namespace PYDM;
namespace {
@@ -39,15 +40,16 @@
} // namespace
static Type getVariantListType(Builder &builder) {
- return builder.getType<iree::ListType>(builder.getType<iree::VariantType>());
+ return builder.getType<Input::ListType>(
+ builder.getType<Input::VariantType>());
}
static Value getNullValue(Location loc, OpBuilder &builder, Type t) {
return TypeSwitch<Type, Value>(t)
- .Case<iree::ListType>([&](auto t) -> Value {
+ .Case<Input::ListType>([&](auto t) -> Value {
// TODO: If it becomes important to optimize this, come up with a way
// to return an empty list without creating one.
- return builder.create<iree::ListCreateOp>(
+ return builder.create<Input::ListCreateOp>(
loc, getVariantListType(builder), /*capacity=*/nullptr);
})
.Default([&](Type t) -> Value {
@@ -75,8 +77,8 @@
}
static Value createUndefObjectList(Location loc, OpBuilder &builder) {
- return builder.create<iree::ListCreateOp>(loc, getVariantListType(builder),
- /*capacity=*/nullptr);
+ return builder.create<Input::ListCreateOp>(loc, getVariantListType(builder),
+ /*capacity=*/nullptr);
}
void resetObjectList(Location loc, OpBuilder &builder, Value list, int typeCode,
@@ -85,13 +87,13 @@
// to truly reset, we have to resize. Low level optimizations should be able
// to elide this if it turns out to be unnecessary.
auto size = builder.create<arith::ConstantIndexOp>(loc, 2);
- builder.create<iree::ListResizeOp>(loc, list, size);
+ builder.create<Input::ListResizeOp>(loc, list, size);
auto index0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value typeCodeValue = builder.create<arith::ConstantOp>(
loc, builder.getI32IntegerAttr(typeCode));
- builder.create<iree::ListSetOp>(loc, list, index0, typeCodeValue);
+ builder.create<Input::ListSetOp>(loc, list, index0, typeCodeValue);
auto index1 = builder.create<arith::ConstantIndexOp>(loc, 1);
- builder.create<iree::ListSetOp>(loc, list, index1, data);
+ builder.create<Input::ListSetOp>(loc, list, index1, data);
}
static Value createObjectList(Location loc, OpBuilder &builder, int typeCode,
@@ -334,7 +336,7 @@
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, slice.getType());
Value listSizeIndex =
- rewriter.create<iree::ListSizeOp>(loc, indexType, sequence);
+ rewriter.create<Input::ListSizeOp>(loc, indexType, sequence);
Value listSizeInteger = rewriter.create<arith::IndexCastOp>(
loc, slice.getType(), listSizeIndex);
Block *entryBlock = rewriter.getInsertionBlock();
@@ -387,7 +389,7 @@
{
rewriter.setInsertionPointToEnd(setElementBlock);
Value successResult = getSuccessStatusValue(loc, rewriter);
- rewriter.create<iree::ListSetOp>(
+ rewriter.create<Input::ListSetOp>(
loc, sequence, setElementBlock->getArgument(0), valueToSet);
rewriter.create<mlir::BranchOp>(loc, continuationBlock,
ValueRange{successResult});
@@ -555,7 +557,7 @@
rewriter.setInsertionPointToEnd(entryBlock);
auto arityValue =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(arity));
- Value listSize = rewriter.create<iree::ListSizeOp>(
+ Value listSize = rewriter.create<Input::ListSizeOp>(
loc, rewriter.getIndexType(), adaptor.sequence());
Value arityMatch = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, arityValue, listSize);
@@ -571,7 +573,7 @@
for (auto it : enumerate(slotTypes)) {
Value index = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(it.index()));
- Value slotValue = rewriter.create<iree::ListGetOp>(
+ Value slotValue = rewriter.create<Input::ListGetOp>(
loc, it.value(), adaptor.sequence(), index);
branchArgs.push_back(slotValue);
}
@@ -687,8 +689,8 @@
Type i32Type = rewriter.getIntegerType(32);
Value index0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
- Value typeCode =
- rewriter.create<iree::ListGetOp>(loc, i32Type, adaptor.value(), index0);
+ Value typeCode = rewriter.create<Input::ListGetOp>(loc, i32Type,
+ adaptor.value(), index0);
rewriter.replaceOp(
srcOp,
castIntegerValue(loc, typeCode, resultType.cast<mlir::IntegerType>(),
@@ -712,8 +714,8 @@
auto list = adaptor.getOperands()[0];
auto index1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
- rewriter.replaceOpWithNewOp<iree::ListGetOp>(srcOp, resultType, list,
- index1);
+ rewriter.replaceOpWithNewOp<Input::ListGetOp>(srcOp, resultType, list,
+ index1);
return success();
}
};
@@ -737,13 +739,13 @@
auto size = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(adaptor.elements().size()));
auto list =
- rewriter.create<iree::ListCreateOp>(loc, getVariantListType(rewriter),
- /*capacity=*/size);
- rewriter.create<iree::ListResizeOp>(loc, list, size);
+ rewriter.create<Input::ListCreateOp>(loc, getVariantListType(rewriter),
+ /*capacity=*/size);
+ rewriter.create<Input::ListResizeOp>(loc, list, size);
for (auto it : enumerate(adaptor.elements())) {
auto index = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(it.index()));
- rewriter.create<iree::ListSetOp>(loc, list, index, it.value());
+ rewriter.create<Input::ListSetOp>(loc, list, index, it.value());
}
rewriter.replaceOp(srcOp, ValueRange{list});
@@ -766,13 +768,13 @@
auto size = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(adaptor.slots().size()));
auto list =
- rewriter.create<iree::ListCreateOp>(loc, getVariantListType(rewriter),
- /*capacity=*/size);
- rewriter.create<iree::ListResizeOp>(loc, list, size);
+ rewriter.create<Input::ListCreateOp>(loc, getVariantListType(rewriter),
+ /*capacity=*/size);
+ rewriter.create<Input::ListResizeOp>(loc, list, size);
for (auto it : enumerate(adaptor.slots())) {
auto index = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(it.index()));
- rewriter.create<iree::ListSetOp>(loc, list, index, it.value());
+ rewriter.create<Input::ListSetOp>(loc, list, index, it.value());
}
rewriter.replaceOp(srcOp, ValueRange{list});
@@ -901,7 +903,7 @@
Type indexType = rewriter.getType<IndexType>();
Type listType = listOperand.getType();
Value subListSize =
- rewriter.create<iree::ListSizeOp>(loc, indexType, listOperand);
+ rewriter.create<Input::ListSizeOp>(loc, indexType, listOperand);
Value countInteger = countOperand;
Value countIndex =
rewriter.create<arith::IndexCastOp>(loc, indexType, countOperand);
@@ -916,8 +918,8 @@
Value newListSize =
rewriter.create<arith::MulIOp>(loc, subListSize, clampedCountIndex);
Value newList =
- rewriter.create<iree::ListCreateOp>(loc, listType, clampedCountIndex);
- rewriter.create<iree::ListResizeOp>(loc, newList, newListSize);
+ rewriter.create<Input::ListCreateOp>(loc, listType, clampedCountIndex);
+ rewriter.create<Input::ListResizeOp>(loc, newList, newListSize);
// Split blocks to loop.
// TODO: Use a new list.copy op instead of an inner loop.
@@ -969,9 +971,9 @@
Value newListIt = innerBody->getArgument(0);
Value subListIt = innerBody->getArgument(1);
- Value elementValue = rewriter.create<iree::ListGetOp>(
+ Value elementValue = rewriter.create<Input::ListGetOp>(
loc, listElementType, listOperand, subListIt);
- rewriter.create<iree::ListSetOp>(loc, newList, newListIt, elementValue);
+ rewriter.create<Input::ListSetOp>(loc, newList, newListIt, elementValue);
newListIt = rewriter.create<arith::AddIOp>(loc, newListIt, oneIndex);
subListIt = rewriter.create<arith::AddIOp>(loc, subListIt, oneIndex);
@@ -1057,7 +1059,7 @@
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, slice.getType());
Value listSizeIndex =
- rewriter.create<iree::ListSizeOp>(loc, indexType, adaptor.value());
+ rewriter.create<Input::ListSizeOp>(loc, indexType, adaptor.value());
Value listSizeInteger = rewriter.create<arith::IndexCastOp>(
loc, slice.getType(), listSizeIndex);
@@ -1117,7 +1119,7 @@
{
rewriter.setInsertionPointToEnd(getElementBlock);
Value successResult = getSuccessStatusValue(loc, rewriter);
- Value resultValue = rewriter.create<iree::ListGetOp>(
+ Value resultValue = rewriter.create<Input::ListGetOp>(
loc, resultType, adaptor.value(), getElementBlock->getArgument(0));
rewriter.create<mlir::BranchOp>(loc, continuationBlock,
ValueRange{successResult, resultValue});
@@ -1180,7 +1182,7 @@
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value requiredTypeCodeValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(typeCode));
- Value actualTypeCodeValue = rewriter.create<iree::ListGetOp>(
+ Value actualTypeCodeValue = rewriter.create<Input::ListGetOp>(
loc, rewriter.getI32Type(), list, index0);
Value typeCodeEqual = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, requiredTypeCodeValue,
@@ -1195,7 +1197,7 @@
auto index1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Value successResult = getSuccessStatusValue(loc, rewriter);
- Value unboxedValue = rewriter.create<iree::ListGetOp>(
+ Value unboxedValue = rewriter.create<Input::ListGetOp>(
loc, targetUnboxedType, list, index1);
rewriter.create<mlir::BranchOp>(loc, continuationBlock,
ValueRange{successResult, unboxedValue});
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
similarity index 87%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
index a7b664f..a16fe16 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -4,19 +4,21 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
using namespace PYDM;
static Type getVariantListType(Builder &builder) {
- return builder.getType<iree::ListType>(builder.getType<iree::VariantType>());
+ return builder.getType<IREE::Input::ListType>(
+ builder.getType<IREE::Input::VariantType>());
}
LoweringTypeConverter::LoweringTypeConverter() {
@@ -81,7 +83,7 @@
addConversion([](mlir::IntegerType t) -> Optional<Type> { return t; });
addConversion([](mlir::FloatType t) -> Optional<Type> { return t; });
addConversion([](mlir::IndexType t) -> Optional<Type> { return t; });
- addConversion([](iree::ListType t) -> Optional<Type> { return t; });
+ addConversion([](IREE::Input::ListType t) -> Optional<Type> { return t; });
}
Type LoweringTypeConverter::getBoolType(Builder b) const {
@@ -103,7 +105,7 @@
bool LoweringTypeConverter::isTypeLegal(Type t) const {
return t.isa<mlir::IntegerType, mlir::FloatType, mlir::IndexType,
- iree::ListType>();
+ IREE::Input::ListType>();
}
bool LoweringTypeConverter::areTypesLegal(TypeRange types) const {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt
new file mode 100644
index 0000000..2417731
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(IREEPyDMUtils
+ TypeInference.cpp
+
+ LINK_LIBS PUBLIC
+ IREEPyDMDialect
+ MLIRIR
+ MLIRStandard
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMUtils)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/TypeInference.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/TypeInference.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
index d0b8d40..ed09f96 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/TypeInference.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Utils/TypeInference.h"
+#include "iree-dialects/Dialect/PyDM/Utils/TypeInference.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeRange.h"
diff --git a/llvm-external-projects/iree-dialects/python/CMakeLists.txt b/llvm-external-projects/iree-dialects/python/CMakeLists.txt
index ce1b44f..da3a77c 100644
--- a/llvm-external-projects/iree-dialects/python/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/python/CMakeLists.txt
@@ -18,9 +18,9 @@
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT IREEDialectsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler"
- TD_FILE dialects/IreeBinding.td
- SOURCES dialects/iree.py
- DIALECT_NAME iree
+ TD_FILE dialects/IreeInputBinding.td
+ SOURCES dialects/iree_input.py
+ DIALECT_NAME iree_input
)
declare_mlir_dialect_python_bindings(
diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
index 965dce7..b3efba8 100644
--- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
+++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
@@ -77,11 +77,11 @@
//===--------------------------------------------------------------------===//
// IREEDialect
//===--------------------------------------------------------------------===//
- auto iree_m = m.def_submodule("iree");
+ auto iree_m = m.def_submodule("iree_input");
iree_m.def(
"register_dialect",
[](MlirContext context, bool load) {
- MlirDialectHandle handle = mlirGetDialectHandle__iree__();
+ MlirDialectHandle handle = mlirGetDialectHandle__iree_input__();
mlirDialectHandleRegisterDialect(handle, context);
if (load) {
mlirDialectHandleLoadDialect(handle, context);
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeBinding.td b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeInputBinding.td
similarity index 87%
rename from llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeBinding.td
rename to llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeInputBinding.td
index 7ce2899..eabf3c4 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeBinding.td
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreeInputBinding.td
@@ -8,6 +8,6 @@
#define PYTHON_BINDINGS_IREE_OPS
include "mlir/Bindings/Python/Attributes.td"
-include "iree-dialects/Dialect/IREE/IREEOps.td"
+include "iree-dialects/Dialect/Input/InputOps.td"
#endif // PYTHON_BINDINGS_IREE_OPS
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreePyDmBinding.td b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreePyDmBinding.td
index ff6371f..688abe6 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreePyDmBinding.td
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/IreePyDmBinding.td
@@ -8,6 +8,6 @@
#define PYTHON_BINDINGS_IREE_PYDM_OPS
include "mlir/Bindings/Python/Attributes.td"
-include "iree-dialects/Dialect/IREEPyDM/IR/Ops.td"
+include "iree-dialects/Dialect/PyDM/IR/PyDMOps.td"
#endif // PYTHON_BINDINGS_IREE_PYDM_OPS
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_input.py
similarity index 71%
rename from llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree.py
rename to llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_input.py
index c70532b..822ae82 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_input.py
@@ -4,5 +4,5 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from ._iree_ops_gen import *
-from .._mlir_libs._ireeDialects.iree import *
+from ._iree_input_ops_gen import *
+from .._mlir_libs._ireeDialects.iree_input import *
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/importer.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/importer.py
index 39030f3..b31de24 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/importer.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/importer.py
@@ -198,7 +198,9 @@
self.fctx = fctx
def visit(self, node):
- self.fctx.update_loc(node)
+ # Some pseudo-nodes (like old 'Index' types) do not have location info.
+ if hasattr(node, "lineno"):
+ self.fctx.update_loc(node)
return super().visit(node)
def generic_visit(self, ast_node: ast.AST):
@@ -840,6 +842,13 @@
def visit_Constant(self, ast_node):
self._set_result(self.fctx.ic.emit_constant(ast_node.value))
+ if sys.version_info < (3, 9, 0):
+ # Starting in 3.9, Index nodes are no longer generated (they used to be
+ # a layer of indirection in subscripts). They aren't "real" nodes and
+ # we just pass them through.
+ def visit_Index(self, ast_node):
+ self.visit(ast_node.value)
+
def _get_function_ast(f) -> Tuple[str, ast.AST]:
filename = inspect.getsourcefile(f)
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
index d8bfd1c..e832938 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
@@ -93,7 +93,7 @@
del self._ip_stack[-1]
@property
- def ip(self):
+ def ip(self) -> ir.InsertionPoint:
assert self._ip_stack, "InsertionPoint requested but stack is empty"
return self._ip_stack[-1]
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/canonicalize.mlir b/llvm-external-projects/iree-dialects/test/iree_linalgext/canonicalize.mlir
new file mode 100644
index 0000000..acb8344
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_linalgext/canonicalize.mlir
@@ -0,0 +1,21 @@
+// RUN: iree-dialects-opt -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @tensor.cast(
+func @tensor.cast(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
+ %init = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+
+ %casted_arg0 = tensor.cast %arg0 : tensor<3x5xi32> to tensor<?x?xi32>
+ %casted_init = tensor.cast %init : tensor<3x5xi32> to tensor<?x?xi32>
+
+// CHECK: iree_linalg_ext.reverse
+// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>)
+// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>)
+ %0 = iree_linalg_ext.reverse
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%casted_arg0 : tensor<?x?xi32>)
+ outs(%casted_init : tensor<?x?xi32>) : tensor<?x?xi32>
+
+ %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<3x5xi32>
+
+ return %1: tensor<3x5xi32>
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir b/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir
new file mode 100644
index 0000000..a50a1b2
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_linalgext/convert_to_loops.mlir
@@ -0,0 +1,507 @@
+// RUN: iree-dialects-opt -split-input-file -iree-linalg-ext-to-loops %s | FileCheck --enable-var-scope --dump-input=fail %s
+
+func @sort_1d(%arg0: memref<128xi32>) {
+ iree_linalg_ext.sort dimension(0)
+ outs(%arg0 : memref<128xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %0 = arith.cmpi sgt, %arg2, %arg3 : i32
+ iree_linalg_ext.yield %0 : i1
+ }
+ return
+}
+// CHECK-LABEL: func @sort_1d
+// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C127:.+]] = arith.constant 127 : index
+// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
+// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C127]] step %[[C1]]
+// CHECK: %[[T1:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
+// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[ARG2]]]
+// CHECK: %[[V2:.+]] = memref.load %[[BUF]][%[[T1]]]
+// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[V1]], %[[V2]] : i32
+// CHECK: scf.if %[[COND]] {
+// CHECK: } else {
+// CHECK: %[[T2:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
+// CHECK: memref.store %[[V2]], %[[BUF]][%[[ARG2]]]
+// CHECK: memref.store %[[V1]], %[[BUF]][%[[T2]]]
+// CHECK: }
+
+// -----
+
+func @sort_2d(%arg0: memref<16x32xi32>) {
+ iree_linalg_ext.sort dimension(0)
+ outs(%arg0 : memref<16x32xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %0 = arith.cmpi sgt, %arg2, %arg3 : i32
+ iree_linalg_ext.yield %0 : i1
+ }
+ return
+}
+// CHECK-LABEL: func @sort_2d
+// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index
+// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C16]] step %[[C1]]
+// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C32]] step %[[C1]]
+// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C15]] step %[[C1]]
+// CHECK: %[[T1:.+]] = arith.addi %[[ARG3]], %[[C1]] : index
+// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[V2:.+]] = memref.load %[[BUF]][%[[T1]], %[[ARG2]]]
+// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[V1]], %[[V2]] : i32
+// CHECK: scf.if %[[COND]] {
+// CHECK: } else {
+// CHECK: %[[T2:.+]] = arith.addi %[[ARG3]], %[[C1]] : index
+// CHECK: memref.store %[[V2]], %[[BUF]][%[[ARG3]], %[[ARG2]]]
+// CHECK: memref.store %[[V1]], %[[BUF]][%[[T2]], %[[ARG2]]]
+// CHECK: }
+
+// -----
+
+func @sort_multi(%arg0: memref<128xf32>, %arg1: memref<128xi32>) {
+ iree_linalg_ext.sort
+ outs(%arg0, %arg1 : memref<128xf32>, memref<128xi32>) {
+ ^bb0(%arg2: f32, %arg3: f32, %arg4: i32, %arg5: i32): // no predecessors
+ %0 = arith.cmpf ogt, %arg2, %arg3 : f32
+ iree_linalg_ext.yield %0 : i1
+ }
+ return
+}
+// CHECK-LABEL: func @sort_multi
+// CHECK-SAME: %[[BUF1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[BUF2:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C127:.+]] = arith.constant 127 : index
+// CHECK: scf.for %[[ARG1:.+]] = %[[C0]] to %[[C128]] step %[[C1]]
+// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C127]] step %[[C1]]
+// CHECK: %[[T1:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
+// CHECK: %[[V1:.+]] = memref.load %[[BUF1]][%[[ARG2]]]
+// CHECK: %[[V2:.+]] = memref.load %[[BUF1]][%[[T1]]]
+// CHECK: %[[V3:.+]] = memref.load %[[BUF2]][%[[ARG2]]]
+// CHECK: %[[V4:.+]] = memref.load %[[BUF2]][%[[T1]]]
+// CHECK: %[[COND:.+]] = arith.cmpf ogt, %[[V1]], %[[V2]] : f32
+// CHECK: scf.if %[[COND]] {
+// CHECK: } else {
+// CHECK: %[[T2:.+]] = arith.addi %[[ARG2]], %[[C1]] : index
+// CHECK: memref.store %[[V2]], %[[BUF1]][%[[ARG2]]]
+// CHECK: memref.store %[[V1]], %[[BUF1]][%[[T2]]]
+// CHECK: memref.store %[[V4]], %[[BUF2]][%[[ARG2]]]
+// CHECK: memref.store %[[V3]], %[[BUF2]][%[[T2]]]
+// CHECK: }
+
+// -----
+
+func @scatter_update_scalar_1D(
+ %original: memref<8xi32>, %indices: memref<3x1xi32>,
+ %updates: memref<3xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
+ outs(%original : memref<8xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ iree_linalg_ext.yield %arg0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_update_scalar_1D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
+// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
+// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
+// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]
+
+// -----
+
+func @scatter_add_scalar_2D(
+ %original: memref<4x3xi32>, %indices: memref<3x2xi32>,
+ %updates: memref<3xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>)
+ outs(%original : memref<4x3xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ %0 = arith.addi %arg1, %arg0 : i32
+ iree_linalg_ext.yield %0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_add_scalar_2D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
+// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x2xi32>
+// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
+// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<3x2xi32>
+// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
+// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<4x3xi32>
+// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
+// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]
+
+// -----
+
+func @scatter_update_slice_2D(
+ %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
+ %updates: memref<2x3xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
+ outs(%original : memref<4x3xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ iree_linalg_ext.yield %arg0 : i32
+ }
+ return
+}
+// CHECK: func @scatter_update_slice_2D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[UPDATE:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
+// CHECK: %[[INDEX:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
+// CHECK: %[[LOC:.+]] = arith.index_cast %[[INDEX]] : i32 to index
+// CHECK: memref.store %[[UPDATE]], %[[ORIGINAL]][%[[LOC]], %[[J]]]
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// Accumulating scalar scatter on static shapes: the region adds each update
+// to the value already stored at the indexed position of %original.
+func @scatter_add_scalar_1D(
+ %original: memref<8xi32>, %indices: memref<3x1xi32>,
+ %updates: memref<3xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>)
+ outs(%original : memref<8xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ %0 = arith.addi %arg1, %arg0 : i32
+ iree_linalg_ext.yield %0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_add_scalar_1D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
+// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
+// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
+// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX]]] : memref<8xi32>
+// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
+// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX]]]
+
+// -----
+
+// Accumulating slice scatter on static shapes: every row of %updates is added
+// to the row of %original selected by the matching entry of %indices.
+func @scatter_add_slice_2D(
+ %original: memref<4x3xi32>, %indices: memref<2x1xi32>,
+ %updates: memref<2x3xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>)
+ outs(%original : memref<4x3xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ %0 = arith.addi %arg1, %arg0 : i32
+ iree_linalg_ext.yield %0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_add_slice_2D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// The inner-loop bound is captured here so that it does not depend on a
+// capture left over from an earlier test in this file.
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
+// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
+// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
+// CHECK: %[[ORIGINALVAL:.+]] = memref.load %[[ORIGINAL]][%[[INDEX]], %[[J]]]
+// CHECK: %[[STOREVAL:.+]] = arith.addi %[[ORIGINALVAL]], %[[UPDATEVAL]]
+// CHECK: memref.store %[[STOREVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
+
+// -----
+
+// Overwriting scalar scatter on dynamically shaped memrefs: the region yields
+// the update unchanged, and the loop bound is taken from memref.dim of
+// %updates.
+func @scatter_update_scalar_dynamic_1D(
+ %original: memref<?xi32>, %indices: memref<?x1xi32>,
+ %updates: memref<?xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<?xi32>, memref<?x1xi32>)
+ outs(%original : memref<?xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ iree_linalg_ext.yield %arg0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_update_scalar_dynamic_1D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
+// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
+// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x1xi32>
+// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
+// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]
+
+// -----
+
+// Accumulating scalar scatter with an index depth of two: both columns of
+// %indices are loaded and cast to address a single element of the 2-D
+// %original.
+func @scatter_add_scalar_dynamic_2D(
+ %original: memref<?x?xi32>, %indices: memref<?x2xi32>,
+ %updates: memref<?xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<?xi32>, memref<?x2xi32>)
+ outs(%original : memref<?x?xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ %0 = arith.addi %arg1, %arg0 : i32
+ iree_linalg_ext.yield %0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_add_scalar_dynamic_2D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[UB:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?xi32>
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB]] step %[[C1]] {
+// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<?xi32>
+// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<?x2xi32>
+// CHECK: %[[IDX1:.+]] = arith.index_cast %[[T2]] : i32 to index
+// CHECK: %[[T3:.+]] = memref.load %[[INDICES]][%[[I]], %[[C1]]] : memref<?x2xi32>
+// CHECK: %[[IDX2:.+]] = arith.index_cast %[[T3]] : i32 to index
+// CHECK: %[[ORI:.+]] = memref.load %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]] : memref<?x?xi32>
+// CHECK: %[[ADD:.+]] = arith.addi %[[ORI]], %[[T1]] : i32
+// CHECK: memref.store %[[ADD]], %[[ORIGINAL]][%[[IDX1]], %[[IDX2]]]
+
+// -----
+
+// Overwriting slice scatter on dynamically shaped memrefs: both loop bounds
+// come from memref.dim of %updates, and each row of %updates replaces the row
+// of %original selected by %indices.
+func @scatter_update_slice_dynamic_2D(
+ %original: memref<?x?xi32>, %indices: memref<?x1xi32>,
+ %updates: memref<?x?xi32>) {
+ iree_linalg_ext.scatter
+ ins(%updates, %indices : memref<?x?xi32>, memref<?x1xi32>)
+ outs(%original : memref<?x?xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ iree_linalg_ext.yield %arg0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_update_slice_dynamic_2D
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[UB1:.+]] = memref.dim %[[UPDATES]], %[[C0]] : memref<?x?xi32>
+// CHECK-DAG: %[[UB2:.+]] = memref.dim %[[UPDATES]], %[[C1]] : memref<?x?xi32>
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[UB1]] step %[[C1]] {
+// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[UB2]] step %[[C1]] {
+// CHECK: %[[UPDATEVAL:.+]] = memref.load %[[UPDATES]][%[[I]], %[[J]]]
+// CHECK: %[[INDEXVAL:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]]
+// CHECK: %[[INDEX:.+]] = arith.index_cast %[[INDEXVAL]] : i32 to index
+// CHECK: memref.store %[[UPDATEVAL]], %[[ORIGINAL]][%[[INDEX]], %[[J]]]
+
+// -----
+
+// One FFT stage (stage = 1) on 16-element real/imaginary buffers. No
+// coefficient operands are passed, so the expected linalg.generic computes the
+// twiddle factors itself with math.cos / math.sin.
+func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) {
+ %stage = arith.constant 1 : index
+ iree_linalg_ext.fft
+ ins(%stage: index)
+ outs(%real, %imag: memref<16xf32>, memref<16xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @fft_1D
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG: %[[SCALE:.+]] = arith.constant -6.28318548 : f32
+// CHECK-DAG: %[[NODE_RNG:.+]] = arith.shli %[[C1]], %[[C1]] : index
+// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[NODE_RNG]]
+// CHECK-DAG: %[[M:.+]] = arith.shli %[[C1]], %[[C1]] : index
+// CHECK-DAG: %[[HM:.+]] = arith.shrsi %[[M]], %[[C1]] : index
+// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[K]]] [%[[HM]]] [1]
+// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[K]]] [%[[HM]]] [1]
+// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[HM]] : index
+// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[R_OFFSET]]] [%[[HM]]] [1]
+// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[R_OFFSET]]] [%[[HM]]] [1]
+// CHECK: %[[M_I32:.+]] = arith.index_cast %[[M]] : index to i32
+// CHECK: %[[M_F32:.+]] = arith.sitofp %[[M_I32]] : i32 to f32
+// CHECK: %[[COEFF:.+]] = arith.divf %[[SCALE]], %[[M_F32]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
+// CHECK: ^bb0(%[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32)
+//
+// Compute exp coeff.
+// CHECK: %[[J_IDX:.+]] = linalg.index 0 : index
+// CHECK: %[[J_I32:.+]] = arith.index_cast %[[J_IDX]] : index to i32
+// CHECK: %[[J_F32:.+]] = arith.sitofp %[[J_I32]] : i32 to f32
+// CHECK: %[[EXP_COEF:.+]] = arith.mulf %[[COEFF]], %[[J_F32]] : f32
+// CHECK: %[[W_REAL:.+]] = math.cos %[[EXP_COEF]]
+// CHECK: %[[W_IMAG:.+]] = math.sin %[[EXP_COEF]]
+//
+// Compute "t = w * a[k + j + mh]" by expanding
+// (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+// CHECK-DAG: %[[XU:.+]] = arith.mulf %[[W_REAL]], %[[R_REAL]]
+// CHECK-DAG: %[[YV:.+]] = arith.mulf %[[W_IMAG]], %[[R_IMAG]]
+// CHECK-DAG: %[[XV:.+]] = arith.mulf %[[W_REAL]], %[[R_IMAG]]
+// CHECK-DAG: %[[YU:.+]] = arith.mulf %[[W_IMAG]], %[[R_REAL]]
+// CHECK: %[[T_REAL:.+]] = arith.subf %[[XU]], %[[YV]]
+// CHECK: %[[T_IMAG:.+]] = arith.addf %[[XV]], %[[YU]]
+//
+// Compute the results.
+// u = a[k + j];
+// a[k + j] = u + t;
+// a[k + j + mh] = u - t;
+// CHECK: %[[RES1:.+]] = arith.addf %[[L_REAL]], %[[T_REAL]]
+// CHECK: %[[RES2:.+]] = arith.addf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK: %[[RES3:.+]] = arith.subf %[[L_REAL]], %[[T_REAL]]
+// CHECK: %[[RES4:.+]] = arith.subf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]]
+
+// -----
+
+// One FFT stage (stage = 2) on ?x16 real/imaginary buffers: an outer loop
+// walks the dynamic leading dimension and the inner loop walks the 16-element
+// dimension.
+func @fft_2D(%real: memref<?x16xf32>, %imag: memref<?x16xf32>) {
+ %stage = arith.constant 2 : index
+ iree_linalg_ext.fft
+ ins(%stage: index)
+ outs(%real, %imag: memref<?x16xf32>, memref<?x16xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fft_2D
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// The inner-loop bound is captured here so that it does not depend on a
+// capture left over from an earlier test in this file.
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref<?x16xf32>
+// CHECK-DAG: %[[NODE_RNG:.+]] = arith.shli %[[C1]], %[[C2]] : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
+// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[NODE_RNG]]
+// CHECK-DAG: %[[M:.+]] = arith.shli %[[C1]], %[[C2]] : index
+// CHECK-DAG: %[[HM:.+]] = arith.shrsi %[[M]], %[[C1]] : index
+// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[HM]] : index
+// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
+//
+// The computation is basically the same as in fft_1D and is checked
+// there. Only the part that differs is checked here.
+// CHECK: %{{.+}} = linalg.index 1 : index
+
+// -----
+
+// One FFT stage (stage = 1) on ?x16 buffers with the coefficients supplied as
+// operands: the expected linalg.generic reads them through ins instead of
+// computing them in its region.
+func @fft_2D_coef_buf(%real: memref<?x16xf32>, %imag: memref<?x16xf32>,
+ %coef_real: memref<1xf32>, %coef_imag: memref<1xf32>) {
+ %stage = arith.constant 1 : index
+ iree_linalg_ext.fft
+ ins(%stage, %coef_real, %coef_imag: index, memref<1xf32>, memref<1xf32>)
+ outs(%real, %imag: memref<?x16xf32>, memref<?x16xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 16 + s0 + d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fft_2D_coef_buf
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// The inner-loop bound is captured here so that it does not depend on a
+// capture left over from an earlier test in this file.
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[REAL]], %[[C0]] : memref<?x16xf32>
+// CHECK-DAG: %[[NODE_RNG:.+]] = arith.shli %[[C1]], %[[C1]] : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
+// CHECK: scf.for %[[K:.+]] = %[[C0]] to %[[C16]] step %[[NODE_RNG]]
+// CHECK-DAG: %[[M:.+]] = arith.shli %[[C1]], %[[C1]] : index
+// CHECK-DAG: %[[HM:.+]] = arith.shrsi %[[M]], %[[C1]] : index
+// CHECK: %[[L_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[L_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[K]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[R_OFFSET:.+]] = arith.addi %[[K]], %[[HM]] : index
+// CHECK: %[[R_REAL_SLICE:.+]] = memref.subview %[[REAL]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK: %[[R_IMAG_SLICE:.+]] = memref.subview %[[IMAG]][%[[I]], %[[R_OFFSET]]] [1, %[[HM]]] [1, 1]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP2]], #[[MAP2]], #[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[COEF_REAL]], %[[COEF_IMAG]]
+// CHECK-SAME: outs(%[[L_REAL_SLICE]], %[[L_IMAG_SLICE]], %[[R_REAL_SLICE]], %[[R_IMAG_SLICE]]
+// CHECK: ^bb0(%[[W_REAL:.+]]: f32, %[[W_IMAG:.+]]: f32, %[[L_REAL:.+]]: f32, %[[L_IMAG:.+]]: f32, %[[R_REAL:.+]]: f32, %[[R_IMAG:.+]]: f32)
+// Compute "t = w * a[k + j + mh]" by expanding
+// (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+// CHECK-DAG: %[[XU:.+]] = arith.mulf %[[W_REAL]], %[[R_REAL]]
+// CHECK-DAG: %[[YV:.+]] = arith.mulf %[[W_IMAG]], %[[R_IMAG]]
+// CHECK-DAG: %[[XV:.+]] = arith.mulf %[[W_REAL]], %[[R_IMAG]]
+// CHECK-DAG: %[[YU:.+]] = arith.mulf %[[W_IMAG]], %[[R_REAL]]
+// CHECK: %[[T_REAL:.+]] = arith.subf %[[XU]], %[[YV]]
+// CHECK: %[[T_IMAG:.+]] = arith.addf %[[XV]], %[[YU]]
+//
+// Compute the results.
+// u = a[k + j];
+// a[k + j] = u + t;
+// a[k + j + mh] = u - t;
+// CHECK: %[[RES1:.+]] = arith.addf %[[L_REAL]], %[[T_REAL]]
+// CHECK: %[[RES2:.+]] = arith.addf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK: %[[RES3:.+]] = arith.subf %[[L_REAL]], %[[T_REAL]]
+// CHECK: %[[RES4:.+]] = arith.subf %[[L_IMAG]], %[[T_IMAG]]
+// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]]
+
+// -----
+
+// Reversal along dimension 0 of a dynamically shaped memref: element (i, j)
+// of the input is stored at (dim0 - 1 - i, j) of the output.
+func @reverse_dim_0(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>) {
+ iree_linalg_ext.reverse
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%arg0 : memref<?x?xi32>)
+ outs(%arg1 : memref<?x?xi32>)
+ return
+}
+// CHECK-LABEL: func @reverse_dim_0
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// Use the captured names rather than literal SSA names so the expectations
+// do not depend on how the printer numbers values.
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[IN]], %[[C0]] : memref<?x?xi32>
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[IN]], %[[C1]] : memref<?x?xi32>
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
+// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]]
+// CHECK: %[[T0:.+]] = memref.dim %[[IN]], %[[C0]]
+// CHECK: %[[T1:.+]] = arith.subi %[[T0]], %[[C1]] : index
+// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[I]] : index
+// CHECK: %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]]
+// CHECK: memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref<?x?xi32>
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/invalid.mlir b/llvm-external-projects/iree-dialects/test/iree_linalgext/invalid.mlir
new file mode 100644
index 0000000..067739f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_linalgext/invalid.mlir
@@ -0,0 +1,418 @@
+// RUN: iree-dialects-opt -split-input-file -verify-diagnostics %s
+
+// dimension(1) on a rank-1 operand must be rejected by the sort verifier.
+func @sort_invalid_dimension(%arg0: tensor<128xi32>) -> tensor<128xi32> {
+ // expected-error @+1 {{dimension must be within (0, 1]}}
+ %sorted = iree_linalg_ext.sort dimension(1)
+ outs(%arg0 : tensor<128xi32>) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %gt = arith.cmpi sgt, %lhs, %rhs : i32
+ iree_linalg_ext.yield %gt : i1
+ } -> tensor<128xi32>
+ return %sorted : tensor<128xi32>
+}
+
+// -----
+
+// A rank-2 sort that omits the dimension attribute must be rejected.
+func @sort_without_dimension(%arg0: tensor<3x4xi32>) -> tensor<3x4xi32> {
+ // expected-error @+1 {{dimension must be specified if rank > 1}}
+ %sorted = iree_linalg_ext.sort
+ outs(%arg0 : tensor<3x4xi32>) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %gt = arith.cmpi sgt, %lhs, %rhs : i32
+ iree_linalg_ext.yield %gt : i1
+ } -> tensor<3x4xi32>
+ return %sorted : tensor<3x4xi32>
+}
+
+// -----
+
+// The second sort operand is rank 1 while the first is rank 2.
+func @sort_mismatch_rank(%arg0: tensor<?x?xi32>, %arg1: tensor<?xf32>)
+ -> (tensor<?x?xi32>, tensor<?xf32>) {
+ // expected-error @+1 {{expected operand 1 to be rank 2, same as other operands}}
+ %sorted:2 = iree_linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?xf32>) {
+ ^bb0(%lhs0: i32, %rhs0: i32, %lhs1: f32, %rhs1: f32):
+ %gt = arith.cmpf ogt, %lhs1, %rhs1 : f32
+ iree_linalg_ext.yield %gt : i1
+ } -> tensor<?x?xi32>, tensor<?xf32>
+ return %sorted#0, %sorted#1 : tensor<?x?xi32>, tensor<?xf32>
+}
+
+// -----
+
+// Both sort operands are rank 1, but their shapes (? versus 42) differ.
+func @sort_mismatch_shape(%arg0: tensor<?xi32>, %arg1: tensor<42xf32>)
+ -> (tensor<?xi32>, tensor<42xf32>) {
+ // expected-error @+1 {{expected operand 1 to have same shape as other operands}}
+ %sorted:2 = iree_linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : tensor<?xi32>, tensor<42xf32>) {
+ ^bb0(%lhs0: i32, %rhs0: i32, %lhs1: f32, %rhs1: f32):
+ %gt = arith.cmpf ogt, %lhs1, %rhs1 : f32
+ iree_linalg_ext.yield %gt : i1
+ } -> tensor<?xi32>, tensor<42xf32>
+ return %sorted#0, %sorted#1 : tensor<?xi32>, tensor<42xf32>
+}
+
+// -----
+
+// Tensor-result scatter whose update operand is a memref.
+func @scatter_mixed_tensor_memref(
+ %update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+
+// -----
+
+// Tensor-result scatter whose indices operand is a memref.
+func @scatter_mixed_tensor_memref(
+ %update : tensor<?x?xf32>, %indices : memref<?x1xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, memref<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+
+// -----
+
+// Scatter declaring two results while providing a single outs operand.
+func @scatter_extra_outputs(
+ %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ // expected-error @+1 {{expected number of outputs to be same as the number of results}}
+ %res0, %res1 = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>, tensor<?x?xf32>
+ return %res0, %res1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// Tensor-result scatter whose outs operand is a memref.
+func @scatter_mixed_tensor_memref(
+ %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
+ %original : memref<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : memref<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+
+// -----
+
+// Scatter on tensor operands whose declared result type is a memref.
+func @scatter_mixed_tensor_memref(
+ %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xf32>) -> memref<?x?xf32> {
+ // expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'memref<?x?xf32>'}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> memref<?x?xf32>
+ return %res : memref<?x?xf32>
+}
+
+// -----
+
+// Result-less scatter whose indices operand is a tensor.
+func @scatter_mixed_tensor_memref(
+ %update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
+ %original : memref<?x?xf32>) {
+ // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
+ iree_linalg_ext.scatter
+ ins(%update, %indices : memref<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : memref<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ }
+ return
+}
+
+// -----
+
+// Result-less scatter whose outs operand is a tensor.
+func @scatter_mixed_tensor_memref(
+ %update : memref<?x?xf32>, %indices : memref<?x1xi32>,
+ %original : tensor<?x?xf32>) {
+ // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}}
+ iree_linalg_ext.scatter
+ ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ }
+ return
+}
+
+// -----
+
+// Leading dimension of the indices (48) against a dynamic update dimension.
+func @scatter_dim_mismatch(
+ %update : tensor<?x?xf32>, %indices : tensor<48x1xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<48x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+
+// -----
+
+// Leading dimension of the indices (48) against a static update dimension (64).
+func @scatter_dim_mismatch(
+ %update : tensor<64x?xf32>, %indices : tensor<48x1xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+
+// -----
+
+// Rank-3 update scattered into a rank-2 original with an index depth of one.
+func @scatter_dim_mismatch(
+ %update : tensor<?x?x?xf32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in rank of update value, index depth and original value}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+
+// -----
+
+// Static trailing update dimension (4) against a dynamic original dimension.
+func @scatter_dim_mismatch(
+ %update : tensor<?x4xf32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x4xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %sum = arith.addf %lhs, %rhs : f32
+ iree_linalg_ext.yield %sum : f32
+ } -> tensor<?x?xf32>
+ return %res : tensor<?x?xf32>
+}
+
+// -----
+
+// Region arguments of type index are rejected.
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ // expected-error @+1 {{expected region to have scalar argument of integer or float types}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xi32>) {
+ ^bb0(%lhs: index, %rhs: index):
+ %sum = arith.addi %lhs, %rhs : index
+ %cast = arith.index_cast %sum : index to i32
+ iree_linalg_ext.yield %cast : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+
+// -----
+
+// First region argument is i64 while the update element type is i32.
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xi32>) {
+ ^bb0(%lhs: i64, %rhs: i32):
+ %narrow = arith.trunci %lhs : i64 to i32
+ %sum = arith.addi %narrow, %rhs : i32
+ iree_linalg_ext.yield %sum : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+
+// -----
+
+// Second region argument is i64 while the original element type is i32.
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xi32>) {
+ ^bb0(%lhs: i32, %rhs: i64):
+ %narrow = arith.trunci %rhs : i64 to i32
+ %sum = arith.addi %narrow, %lhs : i32
+ iree_linalg_ext.yield %sum : i32
+ } -> tensor<?x?xi32>
+ return %res : tensor<?x?xi32>
+}
+
+// -----
+
+// The two region arguments have different types (i32 and i64).
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%lhs: i32, %rhs: i64):
+ %wide = arith.extsi %lhs : i32 to i64
+ %sum = arith.addi %wide, %rhs : i64
+ iree_linalg_ext.yield %sum : i64
+ } -> tensor<?x?xi64>
+ return %res : tensor<?x?xi64>
+}
+
+// -----
+
+// The region declares three block arguments instead of two.
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ // expected-error @+1 {{expected region to have two arguments}}
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%lhs: i64, %rhs: i64, %extra: i64):
+ %sum = arith.addi %lhs, %rhs : i64
+ iree_linalg_ext.yield %sum : i64
+ } -> tensor<?x?xi64>
+ return %res : tensor<?x?xi64>
+}
+
+
+// -----
+
+// The region yields an i32 although its arguments are i64.
+func @scatter_yield_mismatch(
+ %update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%lhs: i64, %rhs: i64):
+ %sum = arith.addi %lhs, %rhs : i64
+ %narrow = arith.trunci %sum : i64 to i32
+ // expected-error @+1 {{mismatch in type of yielded value 'i32' and argument of the region 'i64'}}
+ iree_linalg_ext.yield %narrow : i32
+ } -> tensor<?x?xi64>
+ return %res : tensor<?x?xi64>
+}
+
+// -----
+
+// The region yields two values instead of one.
+func @scatter_yield_mismatch(
+ %update : tensor<?x?xi64>, %indices : tensor<?x1xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ %res = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%lhs: i64, %rhs: i64):
+ %sum = arith.addi %lhs, %rhs : i64
+ %narrow = arith.trunci %sum : i64 to i32
+ // expected-error @+1 {{expected region to yield a single value}}
+ iree_linalg_ext.yield %sum, %narrow : i64, i32
+ } -> tensor<?x?xi64>
+ return %res : tensor<?x?xi64>
+}
+
+// -----
+
+// The trailing (index depth) dimension of %indices is dynamic. The region is
+// kept valid so that this is the only thing wrong with the op.
+func @scatter_index_depth_dynamic(
+ %update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ // expected-error @+1 {{expected index depth is static}}
+ %0 = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?x?xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%arg1: i64, %arg2: i64):
+ %1 = arith.addi %arg1, %arg2 : i64
+ iree_linalg_ext.yield %1 : i64
+ } -> tensor<?x?xi64>
+ return %0 : tensor<?x?xi64>
+}
+
+// -----
+
+// Rank-2 update with an index depth of two scattered into a rank-2 original.
+// The region is kept valid so that the rank mismatch is the only thing wrong
+// with the op.
+func @scatter_original_rank_mismatch(
+ %update : tensor<?x?xi64>, %indices : tensor<?x2xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ // expected-error @+1 {{mismatch in rank of update value, index depth and original value}}
+ %0 = iree_linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?x2xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%arg1: i64, %arg2: i64):
+ %1 = arith.addi %arg1, %arg2 : i64
+ iree_linalg_ext.yield %1 : i64
+ } -> tensor<?x?xi64>
+ return %0 : tensor<?x?xi64>
+}
+
+// -----
+
+// Reverse from an i32 input into an f32 destination of the same shape.
+func @reverse_diff_element_type(%arg0: tensor<3x5xi32>) -> tensor<3x5xf32> {
+ %dest = linalg.init_tensor [3, 5] : tensor<3x5xf32>
+ // expected-error @+1 {{expected input/output element types to be identical}}
+ %reversed = iree_linalg_ext.reverse
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%arg0 : tensor<3x5xi32>)
+ outs(%dest : tensor<3x5xf32>) : tensor<3x5xf32>
+ return %reversed : tensor<3x5xf32>
+}
+
+// -----
+
+// Reverse from a 3x5 input into a 3x6 destination.
+func @reverse_diff_shape(%arg0: tensor<3x5xi32>) -> tensor<3x6xi32> {
+ %dest = linalg.init_tensor [3, 6] : tensor<3x6xi32>
+ // expected-error @+1 {{incompatible input/output shapes}}
+ %reversed = iree_linalg_ext.reverse
+ dimensions(dense<0> : tensor<1xi64>)
+ ins(%arg0 : tensor<3x5xi32>)
+ outs(%dest : tensor<3x6xi32>) : tensor<3x6xi32>
+ return %reversed : tensor<3x6xi32>
+}
+
+// -----
+
+// Reverse whose dimensions attribute lists dimension 0 twice.
+func @reverse_dup_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
+ %dest = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+ // expected-error @+1 {{expected dimensions numbers are all unique}}
+ %reversed = iree_linalg_ext.reverse
+ dimensions(dense<[0, 0]> : tensor<2xi64>)
+ ins(%arg0 : tensor<3x5xi32>)
+ outs(%dest : tensor<3x5xi32>) : tensor<3x5xi32>
+ return %reversed : tensor<3x5xi32>
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/iree_linalgext/roundtrip.mlir
new file mode 100644
index 0000000..1b4bde8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_linalgext/roundtrip.mlir
@@ -0,0 +1,494 @@
+// RUN: iree-dialects-opt -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @sort_tensor
+// CHECK: iree_linalg_ext.sort
+// CHECK-SAME: outs({{.*}})
+// CHECK: iree_linalg_ext.yield
+func @sort_tensor(%arg0: tensor<128xi32>) -> tensor<128xi32> {
+  %sorted = iree_linalg_ext.sort
+      outs(%arg0 : tensor<128xi32>) {
+    ^bb0(%lhs: i32, %rhs: i32):  // comparison region; yields an i1
+      %is_greater = arith.cmpi sgt, %lhs, %rhs : i32
+      iree_linalg_ext.yield %is_greater : i1
+  } -> tensor<128xi32>
+  return %sorted : tensor<128xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sort_memref
+// CHECK: iree_linalg_ext.sort
+// CHECK-SAME: outs({{.*}})
+// CHECK: iree_linalg_ext.yield
+func @sort_memref(%arg0: memref<128xi32>) {
+  iree_linalg_ext.sort dimension(0)
+      outs(%arg0 : memref<128xi32>) {
+    ^bb0(%lhs: i32, %rhs: i32):  // comparison region; yields an i1
+      %is_greater = arith.cmpi sgt, %lhs, %rhs : i32
+      iree_linalg_ext.yield %is_greater : i1
+  }
+  return
+}
+
+// -----
+
+func @sort_multi_result_tensor(
+    %arg0: tensor<?x?xi32>, %arg1: tensor<?x?xf32>)
+    -> (tensor<?x?xi32>, tensor<?x?xf32>) {
+  %sorted:2 = iree_linalg_ext.sort dimension(0)
+      outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xf32>) {
+    ^bb0(%i_lhs: i32, %i_rhs: i32, %f_lhs: f32, %f_rhs: f32):  // only the f32 pair is compared
+      %is_greater = arith.cmpf ogt, %f_lhs, %f_rhs : f32
+      iree_linalg_ext.yield %is_greater : i1
+  } -> tensor<?x?xi32>, tensor<?x?xf32>
+  return %sorted#0, %sorted#1 : tensor<?x?xi32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @sort_multi_result_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.sort dimension(0)
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]]
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @sort_multi_result_memref(
+    %arg0: memref<?x?xi32>, %arg1: memref<?x?xf32>) {
+  iree_linalg_ext.sort dimension(0)
+      outs(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xf32>) {
+    ^bb0(%i_lhs: i32, %i_rhs: i32, %f_lhs: f32, %f_rhs: f32):  // only the f32 pair is compared
+      %is_greater = arith.cmpf ogt, %f_lhs, %f_rhs : f32
+      iree_linalg_ext.yield %is_greater : i1
+  }
+  return
+}
+// CHECK-LABEL: func @sort_multi_result_memref
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xi32>
+// CHECK-SAME: %[[ARG1:.+]]: memref<?x?xf32>
+// CHECK: iree_linalg_ext.sort dimension(0)
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]]
+
+// -----
+
+func @scatter_tensor_dynamic(
+    %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+    %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %result = iree_linalg_ext.scatter
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+      outs(%original: tensor<?x?xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @scatter_tensor_dynamic(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_tensor_static(
+    %original: tensor<128x3xf32>, %indices: tensor<48x1xi32>,
+    %update: tensor<48x3xf32>) -> tensor<128x3xf32> {
+  %result = iree_linalg_ext.scatter
+      ins(%update, %indices : tensor<48x3xf32>, tensor<48x1xi32>)
+      outs(%original: tensor<128x3xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  } -> tensor<128x3xf32>
+  return %result : tensor<128x3xf32>
+}
+// CHECK-LABEL: func @scatter_tensor_static(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<128x3xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48x1xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_tensor_multi_index_depth(
+    %original: tensor<1x128x3xf32>, %indices: tensor<48x2xi32>,
+    %update: tensor<48x3xf32>) -> tensor<1x128x3xf32> {
+  %result = iree_linalg_ext.scatter
+      ins(%update, %indices : tensor<48x3xf32>, tensor<48x2xi32>)
+      outs(%original: tensor<1x128x3xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  } -> tensor<1x128x3xf32>
+  return %result : tensor<1x128x3xf32>
+}
+// CHECK-LABEL: func @scatter_tensor_multi_index_depth(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<1x128x3xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48x2xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_memref_dynamic(
+    %original: memref<?x?xf32>, %indices: memref<?x1xi32>,
+    %update: memref<?x?xf32>) {
+  iree_linalg_ext.scatter
+      ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
+      outs(%original: memref<?x?xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  }
+  return
+}
+// CHECK-LABEL: func @scatter_memref_dynamic(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<?x1xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return
+
+// -----
+
+func @scatter_memref_static(
+    %original: memref<128x3xf32>, %indices: memref<48x1xi32>,
+    %update: memref<48x3xf32>) {
+  iree_linalg_ext.scatter
+      ins(%update, %indices : memref<48x3xf32>, memref<48x1xi32>)
+      outs(%original: memref<128x3xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  }
+  return
+}
+// CHECK-LABEL: func @scatter_memref_static(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<128x3xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<48x1xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref<48x3xf32>
+// CHECK: iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return
+
+// -----
+
+func @scatter_memref_multi_index_depth(
+    %original: memref<1x128x3xf32>, %indices: memref<48x2xi32>,
+    %update: memref<48x3xf32>) {
+  iree_linalg_ext.scatter
+      ins(%update, %indices : memref<48x3xf32>, memref<48x2xi32>)
+      outs(%original: memref<1x128x3xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  }
+  return
+}
+// CHECK-LABEL: func @scatter_memref_multi_index_depth(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<1x128x3xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<48x2xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref<48x3xf32>
+// CHECK: iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return
+
+// -----
+
+func @scatter_update_scalar_1D(
+    %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+    %updates: tensor<3xi32>) -> tensor<8xi32> {
+  %result = iree_linalg_ext.scatter
+      ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
+      outs(%original : tensor<8xi32>) {
+    ^bb0(%a: i32, %b: i32):  // region forwards its first argument unchanged
+      iree_linalg_ext.yield %a : i32
+  } -> tensor<8xi32>
+  return %result : tensor<8xi32>
+}
+// CHECK-LABEL: func @scatter_update_scalar_1D(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_update_scalar_2D(
+    %original: tensor<4x3xi32>, %indices: tensor<3x2xi32>,
+    %updates: tensor<3xi32>) -> tensor<4x3xi32> {
+  %result = iree_linalg_ext.scatter
+      ins(%updates, %indices : tensor<3xi32>, tensor<3x2xi32>)
+      outs(%original : tensor<4x3xi32>) {
+    ^bb0(%a: i32, %b: i32):  // region forwards its first argument unchanged
+      iree_linalg_ext.yield %a : i32
+  } -> tensor<4x3xi32>
+  return %result : tensor<4x3xi32>
+}
+// CHECK-LABEL: func @scatter_update_scalar_2D(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_update_slice_2D(
+    %original: tensor<4x3xi32>, %indices: tensor<1x1xi32>,
+    %updates: tensor<1x3xi32>) -> tensor<4x3xi32> {
+  %result = iree_linalg_ext.scatter
+      ins(%updates, %indices : tensor<1x3xi32>, tensor<1x1xi32>)
+      outs(%original : tensor<4x3xi32>) {
+    ^bb0(%a: i32, %b: i32):  // region forwards its first argument unchanged
+      iree_linalg_ext.yield %a : i32
+  } -> tensor<4x3xi32>
+  return %result : tensor<4x3xi32>
+}
+// CHECK-LABEL: func @scatter_update_slice_2D(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @fft_tensor(%real: tensor<1024xf32>, %imag: tensor<1024xf32>)
+    -> (tensor<1024xf32>, tensor<1024xf32>) {
+  %stage = arith.constant 1 : index
+  %fft:2 = iree_linalg_ext.fft
+      ins(%stage: index)
+      outs(%real, %imag: tensor<1024xf32>, tensor<1024xf32>)
+      : tensor<1024xf32>, tensor<1024xf32>
+  return %fft#0, %fft#1 : tensor<1024xf32>, tensor<1024xf32>
+}
+// CHECK-LABEL: func @fft_tensor(
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]]
+// CHECK: %[[CST:.+]] = arith.constant 1 : index
+// CHECK: %[[RES:.+]]:2 = iree_linalg_ext.fft
+// CHECK-SAME: ins(%[[CST]] : index)
+// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : tensor<1024xf32>, tensor<1024xf32>)
+// CHECK-SAME: : tensor<1024xf32>, tensor<1024xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func @fft_memref(%real: memref<1024xf32>, %imag: memref<1024xf32>) {
+  %stage = arith.constant 1 : index
+  iree_linalg_ext.fft
+      ins(%stage: index)
+      outs(%real, %imag: memref<1024xf32>, memref<1024xf32>)
+  return
+}
+// CHECK-LABEL: func @fft_memref(
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]]
+// CHECK: %[[CST:.+]] = arith.constant 1 : index
+// CHECK: iree_linalg_ext.fft
+// CHECK-SAME: ins(%[[CST]] : index)
+// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : memref<1024xf32>, memref<1024xf32>)
+// CHECK: return
+
+// -----
+
+func @fft_tensor_coef(%real: tensor<1024xf32>, %imag: tensor<1024xf32>,
+    %coef_real: tensor<1xf32>, %coef_imag: tensor<1xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
+  %stage = arith.constant 1 : index
+  %fft:2 = iree_linalg_ext.fft
+      ins(%stage, %coef_real, %coef_imag: index, tensor<1xf32>, tensor<1xf32>)
+      outs(%real, %imag: tensor<1024xf32>, tensor<1024xf32>)
+      : tensor<1024xf32>, tensor<1024xf32>
+  return %fft#0, %fft#1 : tensor<1024xf32>, tensor<1024xf32>
+}
+// CHECK-LABEL: func @fft_tensor_coef(
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]]
+// CHECK: %[[CST:.+]] = arith.constant 1 : index
+// CHECK: %[[RES:.+]]:2 = iree_linalg_ext.fft
+// CHECK-SAME: ins(%[[CST]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<1xf32>, tensor<1xf32>)
+// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : tensor<1024xf32>, tensor<1024xf32>)
+// CHECK-SAME: : tensor<1024xf32>, tensor<1024xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func @fft_memref_coef(%real: memref<1024xf32>, %imag: memref<1024xf32>,
+    %coef_real: memref<1xf32>, %coef_imag: memref<1xf32>) {
+  %stage = arith.constant 1 : index
+  iree_linalg_ext.fft
+      ins(%stage, %coef_real, %coef_imag: index, memref<1xf32>, memref<1xf32>)
+      outs(%real, %imag: memref<1024xf32>, memref<1024xf32>)
+  return
+}
+// CHECK-LABEL: func @fft_memref_coef(
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]]
+// CHECK: %[[CST:.+]] = arith.constant 1 : index
+// CHECK: iree_linalg_ext.fft
+// CHECK-SAME: ins(%[[CST]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, memref<1xf32>, memref<1xf32>)
+// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : memref<1024xf32>, memref<1024xf32>)
+// CHECK: return
+
+// -----
+
+// The size of coefficient tensor is 2^(stage-1).
+func @fft_tensor_coef_stage_5(%real: tensor<1024xf32>, %imag: tensor<1024xf32>,
+    %coef_real: tensor<16xf32>, %coef_imag: tensor<16xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
+  %stage = arith.constant 5 : index  // stage 5 pairs with 16-element coefficient tensors
+  %fft:2 = iree_linalg_ext.fft
+      ins(%stage, %coef_real, %coef_imag: index, tensor<16xf32>, tensor<16xf32>)
+      outs(%real, %imag: tensor<1024xf32>, tensor<1024xf32>)
+      : tensor<1024xf32>, tensor<1024xf32>
+  return %fft#0, %fft#1 : tensor<1024xf32>, tensor<1024xf32>
+}
+// CHECK-LABEL: func @fft_tensor_coef_stage_5(
+// CHECK-SAME: %[[REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[IMAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]]
+// CHECK: %[[CST:.+]] = arith.constant 5 : index
+// CHECK: %[[RES:.+]]:2 = iree_linalg_ext.fft
+// CHECK-SAME: ins(%[[CST]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<16xf32>, tensor<16xf32>)
+// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : tensor<1024xf32>, tensor<1024xf32>)
+// CHECK-SAME: : tensor<1024xf32>, tensor<1024xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func @reverse_tensor(%input: tensor<3x5xi32>) -> tensor<3x5xi32> {
+  %dest = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+  %reversed = iree_linalg_ext.reverse
+      dimensions(dense<0> : tensor<1xi64>)
+      ins(%input : tensor<3x5xi32>)
+      outs(%dest : tensor<3x5xi32>) : tensor<3x5xi32>
+  return %reversed : tensor<3x5xi32>
+}
+// CHECK-LABEL: func @reverse_tensor
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [3, 5]
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[INIT]]
+
+// -----
+
+func @reverse_memref(%input: memref<3x5xi32>, %output: memref<3x5xi32>) {
+  iree_linalg_ext.reverse
+      dimensions(dense<0> : tensor<1xi64>)
+      ins(%input : memref<3x5xi32>)
+      outs(%output : memref<3x5xi32>)
+  return
+}
+// CHECK-LABEL: func @reverse_memref
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x5xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x5xi32>
+// CHECK: iree_linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[ARG1]]
+
+// -----
+
+func @reverse_dynamic_tensor(%input: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %input, %c0 : tensor<?x?xi32>
+  %dim1 = tensor.dim %input, %c1 : tensor<?x?xi32>
+  %dest = linalg.init_tensor [%dim0, %dim1] : tensor<?x?xi32>  // destination sized from the input dims
+  %reversed = iree_linalg_ext.reverse
+      dimensions(dense<1> : tensor<1xi64>)
+      ins(%input : tensor<?x?xi32>)
+      outs(%dest : tensor<?x?xi32>) : tensor<?x?xi32>
+  return %reversed : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @reverse_dynamic_tensor
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xi32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[INIT]]
+
+// -----
+
+func @reverse_static_dynamic_tensor(%input: tensor<3x5xi32>) -> tensor<?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %input, %c0 : tensor<3x5xi32>
+  %dim1 = tensor.dim %input, %c1 : tensor<3x5xi32>
+  %dest = linalg.init_tensor [%dim0, %dim1] : tensor<?x?xi32>  // dynamic destination for a static input
+  %reversed = iree_linalg_ext.reverse
+      dimensions(dense<1> : tensor<1xi64>)
+      ins(%input : tensor<3x5xi32>)
+      outs(%dest : tensor<?x?xi32>) : tensor<?x?xi32>
+  return %reversed : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @reverse_static_dynamic_tensor
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[INIT]]
+
+// -----
+
+func @reverse_multi_dims(%input: tensor<3x5xi32>) -> tensor<3x5xi32> {
+  %dest = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+  %reversed = iree_linalg_ext.reverse
+      dimensions(dense<[0, 1]> : tensor<2xi64>)
+      ins(%input : tensor<3x5xi32>)
+      outs(%dest : tensor<3x5xi32>) : tensor<3x5xi32>
+  return %reversed : tensor<3x5xi32>
+}
+// CHECK-LABEL: func @reverse_multi_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [3, 5]
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[INIT]]
diff --git a/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir b/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir
new file mode 100644
index 0000000..c0352ef
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_linalgext/tiling.mlir
@@ -0,0 +1,1208 @@
+// RUN: iree-dialects-opt -iree-linalg-ext-tile -split-input-file %s | FileCheck --enable-var-scope --dump-input=fail %s
+
+func @scatter_tiling(
+    %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+    %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %result = iree_linalg_ext.scatter
+      {__internal_linalg_transform__ = "tiling_input"}
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+      outs(%original : tensor<?x?xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func @scatter_tiling(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[TILESIZEY:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[TILESIZEX:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZEY]]
+// CHECK-SAME: iter_args(%[[INITY:.+]] = %[[ORIGINAL]])
+// CHECK-DAG: %[[USED_TILESIZEY:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[TILESIZEY]], %[[D0]]]
+// CHECK: %[[RESULT_INNER:.+]] = scf.for %[[IV1:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZEX]]
+// CHECK-SAME: iter_args(%[[INITX:.+]] = %[[INITY]])
+// CHECK: %[[USED_TILESIZEX:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[TILESIZEX]], %[[D1]]]
+// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV0]], %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZEY]], %[[USED_TILESIZEX]]]
+// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV0]], 0]
+// CHECK-SAME: [%[[USED_TILESIZEY]], 1]
+// CHECK: %[[SCATTER_DIM:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]]
+// CHECK: %[[ORIGINAL_SLICE:.+]] = tensor.extract_slice %[[INITX]][0, %[[IV1]]]
+// CHECK-SAME: [%[[SCATTER_DIM]], %[[USED_TILESIZEX]]]
+// CHECK: %[[SCATTER_TILE:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "tiling_output"
+// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+// CHECK-SAME: outs(%[[ORIGINAL_SLICE]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INITX]][0, %[[IV1]]]
+// CHECK-SAME: [%[[SCATTER_DIM]], %[[USED_TILESIZEX]]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: scf.yield %[[RESULT_INNER]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_tiling_memref(
+    %original: memref<?x?xf32>, %indices: memref<?x1xi32>,
+    %update : memref<?x?xf32>) {
+  iree_linalg_ext.scatter
+      {__internal_linalg_transform__ = "tiling_input"}
+      ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
+      outs(%original : memref<?x?xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  }
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func @scatter_tiling_memref(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[TILESIZEY:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[TILESIZEX:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]]
+// CHECK: scf.for %[[IV0:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZEY]]
+// CHECK-DAG: %[[USED_TILESIZEY:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[TILESIZEY]], %[[D0]]]
+// CHECK: scf.for %[[IV1:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZEX]]
+// CHECK-DAG: %[[USED_TILESIZEX:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[TILESIZEX]], %[[D1]]]
+// CHECK: %[[UPDATE_SLICE:.+]] = memref.subview %[[UPDATES]][%[[IV0]], %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZEY]], %[[USED_TILESIZEX]]]
+// CHECK: %[[INDEX_SLICE:.+]] = memref.subview %[[INDICES]][%[[IV0]], 0]
+// CHECK-SAME: [%[[USED_TILESIZEY]], 1]
+// CHECK: %[[SCATTER_DIM:.+]] = memref.dim %[[ORIGINAL]], %[[C0]]
+// CHECK: %[[ORIGINAL_SLICE:.+]] = memref.subview %[[ORIGINAL]][0, %[[IV1]]]
+// CHECK-SAME: [%[[SCATTER_DIM]], %[[USED_TILESIZEX]]]
+// CHECK: iree_linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "tiling_output"
+// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+// CHECK-SAME: outs(%[[ORIGINAL_SLICE]]
+
+// -----
+
+func @scatter_tiling_distribution(
+    %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+    %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %result = iree_linalg_ext.scatter
+      {__internal_linalg_transform__ = "distribute_input"}
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+      outs(%original : tensor<?x?xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @scatter_tiling_distribution(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]]
+// CHECK-DAG: %[[ID:.+]] = iree_input.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[COUNT:.+]] = iree_input.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[ID]]]
+// CHECK-DAG: %[[STEP:.+]] = affine.apply #[[MAP0]]()[%[[COUNT]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[OFFSET]] to %[[D0]] step %[[STEP]]
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ORIGINAL]])
+// CHECK: %[[USED_TILESIZE:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], 1]
+// CHECK: %[[D2:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]]
+// CHECK: %[[ORIGINAL_SLICE:.+]] = tensor.extract_slice %[[INIT]][0, 0]
+// CHECK-SAME: [%[[D2]], %[[D1]]]
+// CHECK: %[[SCATTER_TILE:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "distribute_output"
+// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+// CHECK-SAME: outs(%[[ORIGINAL_SLICE]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0]
+// CHECK-SAME: [%[[D2]], %[[D1]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_no_tiling(
+    %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+    %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %result = iree_linalg_ext.scatter
+      {__internal_linalg_transform__ = "no_tiling_input"}
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+      outs(%original : tensor<?x?xf32>) {
+    ^bb0(%a: f32, %b: f32):  // region yields the sum of its two arguments
+      %sum = arith.addf %a, %b : f32
+      iree_linalg_ext.yield %sum : f32
+  } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+// CHECK: func @scatter_no_tiling
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "no_tiling_output"
+// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @sort_1d(%operand: tensor<?xi32>) -> tensor<?xi32> {
+  %sorted = iree_linalg_ext.sort
+      {__internal_linalg_transform__ = "outer_reduce_input"}
+      outs(%operand : tensor<?xi32>) {
+    ^bb0(%lhs: i32, %rhs: i32):  // comparison region; yields an i1
+      %is_greater = arith.cmpi sgt, %lhs, %rhs : i32
+      iree_linalg_ext.yield %is_greater : i1
+  } -> tensor<?xi32>
+  return %sorted : tensor<?xi32>
+}
+// CHECK: func @sort_1d(
+// CHECK-SAME: %[[OPERAND:.+]]: tensor<?xi32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.sort
+// CHECK-SAME: {__internal_linalg_transform__ = "outer_reduce_output"}
+// CHECK-SAME: outs(%[[OPERAND]] :
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @sort_2d(%operand: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %sorted = iree_linalg_ext.sort dimension(1)
+      {__internal_linalg_transform__ = "inner_reduce_input"}
+      outs(%operand : tensor<?x?xi32>) {
+    ^bb0(%lhs: i32, %rhs: i32):  // comparison region; yields an i1
+      %is_greater = arith.cmpi sgt, %lhs, %rhs : i32
+      iree_linalg_ext.yield %is_greater : i1
+  } -> tensor<?x?xi32>
+  return %sorted : tensor<?x?xi32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @sort_2d(
+// CHECK-SAME: %[[OPERAND:.+]]: tensor<?x?xi32>
+// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[SORT_TILE:.+]] = iree_linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND_SLICE]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @sort_2d_inner_parallel(%operand: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %sorted = iree_linalg_ext.sort dimension(0)
+      {__internal_linalg_transform__ = "outer_reduce_input"}
+      outs(%operand : tensor<?x?xi32>) {
+    ^bb0(%lhs: i32, %rhs: i32):  // comparison region; yields an i1
+      %is_greater = arith.cmpi sgt, %lhs, %rhs : i32
+      iree_linalg_ext.yield %is_greater : i1
+  } -> tensor<?x?xi32>
+  return %sorted : tensor<?x?xi32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func @sort_2d_inner_parallel(
+// CHECK-SAME: %[[OPERAND:.+]]: tensor<?x?xi32>
+// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]]
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]]
+// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][0, %[[IV]]]
+// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]]
+// CHECK: %[[SORT_TILE:.+]] = iree_linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND_SLICE]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][0, %[[IV]]]
+// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @sort_2d_multi_result(
+    %arg0: tensor<?x?xi32>, %arg1: tensor<?x?xf32>)
+    -> (tensor<?x?xi32>, tensor<?x?xf32>) {
+  %sorted:2 = iree_linalg_ext.sort dimension(1)
+      {__internal_linalg_transform__ = "inner_reduce_input"}
+      outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xf32>) {
+    ^bb0(%i_lhs: i32, %i_rhs: i32, %f_lhs: f32, %f_rhs: f32):  // only the f32 pair is compared
+      %is_greater = arith.cmpf ogt, %f_lhs, %f_rhs : f32
+      iree_linalg_ext.yield %is_greater : i1
+  } -> tensor<?x?xi32>, tensor<?x?xf32>
+  return %sorted#0, %sorted#1 : tensor<?x?xi32>, tensor<?x?xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @sort_2d_multi_result(
+// CHECK-SAME: %[[OPERAND1:.+]]: tensor<?x?xi32>
+// CHECK-SAME: %[[OPERAND2:.+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]]
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT2]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[SORT_TILE:.+]]:2 = iree_linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_TILE]]#0 into %[[INIT1]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_TILE]]#1 into %[[INIT2]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]]
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @sort_2d_multi_result_memref(  // Two-operand in-place memref sort along dimension 0; the expectations below tile dimension 1 in steps of 20.
+ %arg0: memref<?x?xi32>, %arg1: memref<?x?xf32>) {
+ iree_linalg_ext.sort dimension(0)
+ {__internal_linalg_transform__ = "outer_reduce_input"}
+ outs(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // comparator reads only the f32 pair (%arg4, %arg5)
+ %0 = arith.cmpf ogt, %arg4, %arg5 : f32
+ iree_linalg_ext.yield %0 : i1
+ }
+ return
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func @sort_2d_multi_result_memref(
+// CHECK-SAME: %[[OPERAND1:.+]]: memref<?x?xi32>
+// CHECK-SAME: %[[OPERAND2:.+]]: memref<?x?xf32>
+// CHECK-DAG: %[[TILESIZE:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]]
+// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]]
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][0, %[[IV]]]
+// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][0, %[[IV]]]
+// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]]
+// CHECK: iree_linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+
+// -----
+
+func @sort_3d_multi_result_distribute(  // 3-d tensor sort along dimension 1; the expectations below distribute dimension 0 over workgroup id[1] (tile 10) and dimension 2 over workgroup id[0] (tile 30).
+ %arg0: tensor<?x?x?xi32>, %arg1 : tensor<?x?x?xf32>)
+ -> (tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+ %0, %1 = iree_linalg_ext.sort dimension(1)
+ {__internal_linalg_transform__ = "distribute_input"}
+ outs(%arg0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // comparator reads only the f32 pair (%arg4, %arg5)
+ %2 = arith.cmpf ogt, %arg4, %arg5 : f32
+ iree_linalg_ext.yield %2 : i1
+ } -> tensor<?x?x?xi32>, tensor<?x?x?xf32>
+ return %0, %1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// CHECK: func @sort_3d_multi_result_distribute(
+// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: tensor<?x?x?xi32>
+// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[TILESIZE1:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[TILESIZE2:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[OPERAND1]], %[[C2]]
+// CHECK-DAG: %[[IDX:.+]] = iree_input.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[COUNTX:.+]] = iree_input.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[IDY:.+]] = iree_input.dispatch.workgroup.id[1]
+// CHECK-DAG: %[[COUNTY:.+]] = iree_input.dispatch.workgroup.count[1]
+// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]]
+// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]]
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]]
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]])
+// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]]
+// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]]
+// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]]
+// CHECK: %[[RESULT_INNER:.+]]:2 = scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]]
+// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT1]], %[[INIT4:.+]] = %[[INIT2]])
+// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT3]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT4]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: %[[SORT_SLICE:.+]]:2 = iree_linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "distribute_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_SLICE]]#0
+// CHECK-SAME: into %[[INIT3]][%[[IV0]], 0, %[[IV1]]]
+// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_SLICE]]#1
+// CHECK-SAME: into %[[INIT4]][%[[IV0]], 0, %[[IV1]]]
+// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]]
+// CHECK: scf.yield %[[RESULT_INNER]]#0, %[[RESULT_INNER]]#1
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @sort_3d_multi_result_distribute_memref(  // 3-d in-place memref sort along dimension 1; the expectations below distribute dimension 0 over workgroup id[1] (tile 10) and dimension 2 over workgroup id[0] (tile 30).
+ %arg0: memref<?x?x?xi32>, %arg1 : memref<?x?x?xf32>) {
+ iree_linalg_ext.sort dimension(1)
+ {__internal_linalg_transform__ = "distribute_input"}
+ outs(%arg0, %arg1 : memref<?x?x?xi32>, memref<?x?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // comparator reads only the f32 pair (%arg4, %arg5)
+ %0 = arith.cmpf ogt, %arg4, %arg5 : f32
+ iree_linalg_ext.yield %0 : i1
+ }
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// CHECK: func @sort_3d_multi_result_distribute_memref(
+// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: memref<?x?x?xi32>
+// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: memref<?x?x?xf32>
+// CHECK-DAG: %[[TILESIZE1:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[TILESIZE2:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = memref.dim %[[OPERAND1]], %[[C2]]
+// CHECK-DAG: %[[IDX:.+]] = iree_input.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[COUNTX:.+]] = iree_input.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[IDY:.+]] = iree_input.dispatch.workgroup.id[1]
+// CHECK-DAG: %[[COUNTY:.+]] = iree_input.dispatch.workgroup.count[1]
+// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]]
+// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]]
+// CHECK: scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]]
+// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]]
+// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]]
+// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]]
+// CHECK: scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]]
+// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: iree_linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "distribute_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+
+// -----
+
+func @slice_insert(%source :tensor<?x?xf32>, %dest: tensor<?x?xf32>,  // Inserts all of %source into %dest at offsets (%idx0, %idx1); the expectations below tile the insert into a two-level scf.for nest.
+ %idx0 : index, %idx1 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %source, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %source, %c1 : tensor<?x?xf32>
+ %2 = tensor.insert_slice %source into %dest[%idx0, %idx1] [%0, %1] [1, 1]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?xf32> into tensor<?x?xf32>
+ return %2 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @slice_insert(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]]
+// CHECK: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]]
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]]
+// CHECK-SAME: into %{{.+}}[%[[OFFSET0]], %[[OFFSET1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD1]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @slice_insert_rank_reduce(%source :tensor<?x?xf32>, %dest: tensor<?x?x?xf32>,  // Rank-reducing insert: the 2-d %source fills a [%0, 1, %1] window of the 3-d %dest at offsets (%idx0, 0, %idx1).
+ %idx0 : index, %idx1 : index) -> tensor<?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %source, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %source, %c1 : tensor<?x?xf32>
+ %2 = tensor.insert_slice %source into %dest[%idx0, 0, %idx1] [%0, 1, %1] [1, 1, 1]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %2 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @slice_insert_rank_reduce(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]]
+// CHECK: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]]
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]]
+// CHECK-SAME: into %{{.+}}[%[[OFFSET0]], 0, %[[OFFSET1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD1]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @fft_1d_stage_5(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>,  // iree_linalg_ext.fft on two 1024-element tensors with index input 5 and two 16-element coefficient inputs; the expectations below tile in steps of 32.
+ %arg2: tensor<16xf32>, %arg3: tensor<16xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
+ %cst1 = arith.constant 5 : index
+ %0:2 = iree_linalg_ext.fft
+ {__internal_linalg_transform__ = "tiling_1d_stage5_fft_input"}
+ ins(%cst1, %arg2, %arg3: index, tensor<16xf32>, tensor<16xf32>)
+ outs(%arg0, %arg1: tensor<1024xf32>, tensor<1024xf32>)
+ : tensor<1024xf32>, tensor<1024xf32>
+ return %0#0, %0#1 : tensor<1024xf32>, tensor<1024xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK: func @fft_1d_stage_5(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK: %[[RES:.+]]:2 = scf.for %[[I:.+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG5:.+]] = %[[ARG0]], %[[ARG6:.+]] = %[[ARG1]])
+// CHECK-SAME: -> (tensor<1024xf32>, tensor<1024xf32>) {
+// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C32]], %[[C1024]]]
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG5]][%[[I]]] [%[[SIZE]]] [1] : tensor<1024xf32> to tensor<?xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG6]][%[[I]]] [%[[SIZE]]] [1] : tensor<1024xf32> to tensor<?xf32>
+// CHECK: %[[FFT:.+]]:2 = iree_linalg_ext.fft
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_1d_stage5_fft_output"}
+// CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<16xf32>, tensor<16xf32>)
+// CHECK-SAME: outs(%[[SLICE1]], %[[SLICE2]] : tensor<?xf32>, tensor<?xf32>)
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[FFT]]#0 into %[[ARG5]][%[[I]]] [%[[SIZE]]] [1] : tensor<?xf32> into tensor<1024xf32>
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[FFT]]#1 into %[[ARG6]][%[[I]]] [%[[SIZE]]] [1] : tensor<?xf32> into tensor<1024xf32>
+// CHECK: scf.yield %[[INSERT1]], %[[INSERT2]]
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : tensor<1024xf32>, tensor<1024xf32>
+
+// -----
+
+func @fft_2d_stage_5(%arg0: tensor<3x1024xf32>, %arg1: tensor<3x1024xf32>,  // iree_linalg_ext.fft on two 3x1024 tensors with index input 5; the expectations below tile dimension 0 in steps of 10 and dimension 1 in steps of 32.
+ %arg2: tensor<16xf32>, %arg3: tensor<16xf32>) -> (tensor<3x1024xf32>, tensor<3x1024xf32>) {
+ %cst1 = arith.constant 5 : index
+ %0:2 = iree_linalg_ext.fft
+ {__internal_linalg_transform__ = "tiling_2d_stage5_fft_input"}
+ ins(%cst1, %arg2, %arg3: index, tensor<16xf32>, tensor<16xf32>)
+ outs(%arg0, %arg1: tensor<3x1024xf32>, tensor<3x1024xf32>)
+ : tensor<3x1024xf32>, tensor<3x1024xf32>
+ return %0#0, %0#1 : tensor<3x1024xf32>, tensor<3x1024xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK: func @fft_2d_stage_5(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK: %[[RES:.+]]:2 = scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ARG5:.+]] = %[[ARG0]], %[[ARG6:.+]] = %[[ARG1]])
+// CHECK-SAME: -> (tensor<3x1024xf32>, tensor<3x1024xf32>) {
+// CHECK: %[[SZ1:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[C3]]]
+// CHECK: %{{.+}} = scf.for %[[J:.+]] = %[[C0]] to %[[C1024]] step %[[C32]]
+// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG5]], %[[ARG9:.+]] = %[[ARG6]]) -> (tensor<3x1024xf32>, tensor<3x1024xf32>) {
+// CHECK: %[[SZ2:.+]] = affine.min #[[MAP1]](%[[J]])[%[[C32]], %[[C1024]]]
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG8]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1]
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG9]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1]
+// CHECK: %[[FFT:.+]]:2 = iree_linalg_ext.fft
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_2d_stage5_fft_output"}
+// CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, tensor<16xf32>, tensor<16xf32>)
+// CHECK-SAME: outs(%[[SLICE1]], %[[SLICE2]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[FFT]]#0 into %[[ARG8]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1]
+// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[FFT]]#1 into %[[ARG9]][%[[I]], %[[J]]] [%[[SZ1]], %[[SZ2]]] [1, 1]
+// CHECK: scf.yield %[[INSERT1]], %[[INSERT2]] : tensor<3x1024xf32>, tensor<3x1024xf32>
+
+// -----
+
+func @fft_1d_stage_5_memref(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>,  // In-place memref form of iree_linalg_ext.fft on two 1024-element buffers with index input 5; the expectations below tile in steps of 32 using memref.subview.
+ %arg2: memref<16xf32>, %arg3: memref<16xf32>) {
+ %cst1 = arith.constant 5 : index
+ iree_linalg_ext.fft
+ {__internal_linalg_transform__ = "tiling_1d_stage5_fft_input"}
+ ins(%cst1, %arg2, %arg3: index, memref<16xf32>, memref<16xf32>)
+ outs(%arg0, %arg1: memref<1024xf32>, memref<1024xf32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (32, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @fft_1d_stage_5_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_REAL:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[COEF_IMAG:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C1024]] step %[[C32]] {
+// CHECK: %[[SZ:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C32]], %[[C1024]]]
+// CHECK: %[[SUB1:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref<?xf32, #[[MAP1]]>
+// CHECK: %[[SUB2:.+]] = memref.subview %[[ARG1]][%[[I]]] [%[[SZ]]] [1] : memref<1024xf32> to memref<?xf32, #[[MAP1]]>
+// CHECK: iree_linalg_ext.fft
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_1d_stage5_fft_output"}
+// CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, memref<16xf32>, memref<16xf32>)
+// CHECK-SAME: outs(%[[SUB1]], %[[SUB2]] : memref<?xf32, #[[MAP1]]>, memref<?xf32, #[[MAP1]]>)
+
+// -----
+
+func @reverse_memref(%arg0: memref<?xi32>, %arg1: memref<?xi32>) {  // Reverse along dimension 0 from %arg0 into %arg1; the expectations below tile in steps of 10 and place each output tile at offset (dim - iv - size).
+ iree_linalg_ext.reverse
+ dimensions(dense<0> : tensor<1xi64>)
+ {__internal_linalg_transform__ = "tiling_input"}
+ ins(%arg0: memref<?xi32>)
+ outs(%arg1: memref<?xi32>)
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
+// CHECK: func @reverse_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xi32>
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] {
+// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[D0]]]
+// CHECK: %[[SUB_IN:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SIZE]]] [1]
+// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xi32>
+// CHECK: %[[IDX:.+]] = affine.apply #[[MAP2]]()[%[[T0]], %[[I]], %[[SIZE]]]
+// CHECK: %[[SUB_OUT:.+]] = memref.subview %[[ARG1]][%[[IDX]]] [%[[SIZE]]] [1]
+// CHECK: iree_linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_output"}
+// CHECK-SAME: ins(%[[SUB_IN]]
+// CHECK-SAME: outs(%[[SUB_OUT]]
+
+// -----
+
+func @reverse_tensor_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {  // Reverse along dimensions 0 and 1 of %arg0 into an init tensor of the same dynamic shape; the expectations below tile with steps 10 and 20.
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+ %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xi32>
+ %0 = iree_linalg_ext.reverse
+ dimensions(dense<[0, 1]> : tensor<2xi64>)
+ {__internal_linalg_transform__ = "tiling_input"}
+ ins(%arg0: tensor<?x?xi32>)
+ outs(%init: tensor<?x?xi32>) : tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
+// CHECK: func @reverse_tensor_multi_dim(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] : tensor<?x?xi32>
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
+// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[SIZE_I:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[D0]]]
+// CHECK: %[[RES2:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT2]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[SIZE_J:.+]] = affine.min #[[MAP1]](%[[J]])[%[[C20]], %[[D1]]]
+// CHECK: %[[SUB_IN:.+]] = tensor.extract_slice
+// CHECK-SAME: %[[ARG0]][%[[I]], %[[J]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
+// CHECK: %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK: %[[IDX0:.+]] = affine.apply #[[MAP2]]()[%[[T0]], %[[I]], %[[SIZE_I]]]
+// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
+// CHECK: %[[IDX1:.+]] = affine.apply #[[MAP2]]()[%[[T1]], %[[J]], %[[SIZE_J]]]
+// CHECK: %[[SUB_INIT:.+]] = tensor.extract_slice
+// CHECK-SAME: %[[INIT]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
+// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_output"}
+// CHECK-SAME: ins(%[[SUB_IN]]
+// CHECK-SAME: outs(%[[SUB_INIT]]
+// CHECK: %[[RES3:.+]] = tensor.insert_slice %[[REV]] into
+// CHECK-SAME: %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
+// CHECK: scf.yield %[[RES3]]
+// CHECK: scf.yield %[[RES2]]
+// CHECK: return %[[RES]]
+
+// -----
+
+func @dynamic_insert_slice(%arg0 : tensor<?xf32>, %arg1 : tensor<?x?xf32>,  // Rank-reducing insert of the 1-d %arg0 into a [1, %d0] window of %arg1 at offsets (%arg2, %arg3); the expectations below tile the dynamic dimension in steps of 10.
+ %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3] [1, %d0] [1, 1]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?xf32> into tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @dynamic_insert_slice(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[ARG4:.+]] = %[[C0]] to %[[D0]]
+// CHECK-SAME: step %[[C10]] iter_args(%[[ARG5:.+]] = %[[ARG1]])
+// CHECK: %[[TILESIZE:.+]] = affine.min #[[MAP0]](%[[ARG4]])[%[[C10]], %[[D0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]]] [%[[TILESIZE]]]
+// CHECK: %[[OFFSET:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG3]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT]] into %[[ARG5]]
+// CHECK-SAME: [%[[ARG2]], %[[OFFSET]]] [1, %[[TILESIZE]]]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: return %[[RESULT]]
+
+
+// -----
+
+func @insert_slice_rank_reduced_inner(%arg0 : tensor<?xf32>,  // Rank-reducing insert of the 1-d %arg0 into a [1, %d0, 1] window of the 3-d %arg1; the expectations below tile only the middle dimension, in steps of 10.
+ %arg1 : tensor<?x?x?xf32>, %arg2: index, %arg3 : index, %arg4 : index) -> tensor<?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [1, %d0, 1] [1, 1, 1]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?xf32> into tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK: func @insert_slice_rank_reduced_inner(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[LB:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[STEP:.+]] = arith.constant 10 : index
+// CHECK: %[[UB:.+]] = tensor.dim %[[ARG0]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = %[[LB]]
+// CHECK-SAME: to %[[UB]] step %[[STEP]] iter_args(%[[ARG6:.+]] = %[[ARG1]])
+// CHECK: %[[TILESIZE:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[STEP]], %[[UB]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]]] [%[[TILESIZE]]]
+// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[IV0]])[%[[ARG3]]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG6]]
+// CHECK-SAME: [%[[ARG2]], %[[APPLY]], %[[ARG4]]] [1, %[[TILESIZE]], 1]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice(%arg0 : tensor<?x?xf32>, %arg1: index,  // Fully dynamic 2-d extract_slice: offsets (%arg1, %arg2), sizes (%arg3, %arg4) and strides (%arg5, %arg6) are all function arguments.
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index,
+ %arg6 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2] [%arg3, %arg4] [%arg5, %arg6]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG3]], %[[ARG4]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG3]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG3]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG4]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG5]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG6]], %[[ARG2]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]]] [%[[TILE_Y]], %[[TILE_X]]] [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_static(%arg0 : tensor<50x60xf32>) -> tensor<20x30xf32> {  // Fully static extract_slice: offsets [2, 3], sizes [20, 30], strides [5, 6].
+ %0 = tensor.extract_slice %arg0[2, 3] [20, 30] [5, 6]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<50x60xf32> to tensor<20x30xf32>
+ return %0 : tensor<20x30xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_static
+// CHECK-SAME: %[[ARG0:.+]]: tensor<50x60xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [20, 30]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[C20]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<20x30xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C20]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[C30]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<20x30xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C30]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[C5]], %[[C2]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[C6]], %[[C3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]]] [%[[TILE_Y]], %[[TILE_X]]] [5, 6]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_outer(%arg0 : tensor<?x?x?xf32>, %arg1 : index,  // Rank-reducing extract_slice: the outermost dimension has static size 1 and does not appear in the 2-d result type.
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3] [1, %arg4, %arg5] [%arg6, %arg7, %arg8]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_outer
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG4]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG5]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG2]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG8]], %[[ARG3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[ARG1]], %[[OFFSET_Y]], %[[OFFSET_X]]]
+// CHECK-SAME: [1, %[[TILE_Y]], %[[TILE_X]]] [%[[ARG6]], %[[ARG7]], %[[ARG8]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Rank-reducing tensor.extract_slice (3-D source, 2-D result) whose unit-size
+// dimension is the middle one. The expected IR tiles the two kept dimensions
+// with loop steps 10 and 20; each tiled offset is `iv * stride + offset`
+// (MAP2), while the dropped dimension keeps its original offset and size 1.
+func @extract_slice_reduced_rank_middle(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3] [%arg4, 1, %arg5] [%arg6, %arg7, %arg8]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_middle
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG4]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG5]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG6]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG8]], %[[ARG3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[ARG2]], %[[OFFSET_X]]]
+// CHECK-SAME: [%[[TILE_Y]], 1, %[[TILE_X]]] [%[[ARG6]], %[[ARG7]], %[[ARG8]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Rank-reducing tensor.extract_slice (3-D source, 2-D result) whose unit-size
+// dimension is the innermost one. The expected IR tiles the two kept
+// dimensions with loop steps 10 and 20; each tiled offset is
+// `iv * stride + offset` (MAP2), while the dropped dimension keeps its
+// original offset and size 1.
+func @extract_slice_reduced_rank_inner(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3] [%arg4, %arg5, 1] [%arg6, %arg7, %arg8]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_inner
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG4]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG5]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG6]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG7]], %[[ARG2]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]], %[[ARG3]]]
+// CHECK-SAME: [%[[TILE_Y]], %[[TILE_X]], 1] [%[[ARG6]], %[[ARG7]], %[[ARG8]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Rank-reducing tensor.extract_slice (4-D source, 2-D result) with unit sizes
+// in dimensions 1 and 3. The expected IR tiles the kept dimensions 0 and 2
+// with loop steps 10 and 20; each tiled offset is `iv * stride + offset`
+// (MAP2), while the dropped dimensions keep their original offsets and size 1.
+func @extract_slice_reduced_rank_two_dims_1(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [%arg5, 1, %arg6, 1] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_two_dims_1
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG9]], %[[ARG3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[ARG2]], %[[OFFSET_X]], %[[ARG4]]]
+// CHECK-SAME: [%[[TILE_Y]], 1, %[[TILE_X]], 1]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Rank-reducing tensor.extract_slice (4-D source, 2-D result) with unit sizes
+// in dimensions 1 and 2. The expected IR tiles the kept dimensions 0 and 3
+// with loop steps 10 and 20; each tiled offset is `iv * stride + offset`
+// (MAP2), while the dropped dimensions keep their original offsets and size 1.
+func @extract_slice_reduced_rank_two_dims_2(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [%arg5, 1, 1, %arg6] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// All eleven arguments are captured from this function's own signature so the
+// stride operands (ARG7..ARG10) used further down do not depend on captures
+// made by an earlier test case in this file.
+// CHECK: func @extract_slice_reduced_rank_two_dims_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG10]], %[[ARG4]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[ARG2]], %[[ARG3]], %[[OFFSET_X]]]
+// CHECK-SAME: [%[[TILE_Y]], 1, 1, %[[TILE_X]]]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Rank-reducing tensor.extract_slice (4-D source, 2-D result) with unit sizes
+// in dimensions 0 and 2. The expected IR tiles the kept dimensions 1 and 3
+// with loop steps 10 and 20; each tiled offset is `iv * stride + offset`
+// (MAP2), while the dropped dimensions keep their original offsets and size 1.
+func @extract_slice_reduced_rank_two_dims_3(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [1, %arg5, 1, %arg6] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_two_dims_3
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG8]], %[[ARG2]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG10]], %[[ARG4]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[ARG1]], %[[OFFSET_Y]], %[[ARG3]], %[[OFFSET_X]]]
+// CHECK-SAME: [1, %[[TILE_Y]], 1, %[[TILE_X]]]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Rank-reducing tensor.extract_slice (4-D source, 2-D result) with unit sizes
+// in dimensions 2 and 3. The expected IR tiles the kept dimensions 0 and 1
+// with loop steps 10 and 20; each tiled offset is `iv * stride + offset`
+// (MAP2), while the dropped dimensions keep their original offsets and size 1.
+func @extract_slice_reduced_rank_two_dims_4(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [%arg5, %arg6, 1, 1] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_two_dims_4
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG8]], %[[ARG2]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]], %[[ARG3]], %[[ARG4]]]
+// CHECK-SAME: [%[[TILE_Y]], %[[TILE_X]], 1, 1]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Tiling of a 2-D linalg.pad_tensor with dynamic low/high padding and a
+// padding value taken from %arg5. The expected IR derives each loop upper
+// bound as `low + high + dim` (MAP0), iterates with steps 10 and 20, produces
+// each tile through an scf.if, and inserts it into the loop-carried result at
+// the current induction-variable offsets.
+func @pad_tensor(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor<?x?xf32> {
+ %0 = linalg.pad_tensor %arg0 low[%arg1, %arg2] high[%arg3, %arg4] {
+ ^bb0(%arg6 : index, %arg7 : index):
+ linalg.yield %arg5 : f32
+ } {__internal_linalg_transform__ = "tiling_input"}
+ : tensor<?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
+// CHECK: func @pad_tensor
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[UBY:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG3]], %[[D0]]]
+// CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[UBX:.+]] = affine.apply #[[MAP0]]()[%[[ARG2]], %[[ARG4]], %[[D1]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[UBY]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ARG7:.+]] = %[[INIT]])
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[UBX]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ARG9:.+]] = %[[ARG7]])
+// CHECK: %[[PAD_TILE:.+]] = scf.if
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD_TILE]] into %[[ARG9]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
index a71a199..7b0d4b0 100644
--- a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
@@ -28,16 +28,16 @@
// CHECK-LABEL: @box
// NOTE: "78" is the type code for signed i32
+// The boxed object is expected to lower to a two-element variant list:
+// element 0 holds the type code and element 1 holds the boxed value.
iree_pydm.func @box(%arg0 : !iree_pydm.integer<32>) -> (!iree_pydm.exception_result, !iree_pydm.object<!iree_pydm.integer<32>>) {
- // CHECK: %[[LIST:.*]] = iree.list.create : !iree.list<!iree.variant>
+ // CHECK: %[[LIST:.*]] = iree_input.list.create : !iree_input.list<!iree_input.variant>
// CHECK: %[[c2:.*]] = arith.constant 2 : index
- // CHECK: iree.list.resize %[[LIST]], %c2 : !iree.list<!iree.variant>
+ // CHECK: iree_input.list.resize %[[LIST]], %[[c2]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[c9:.*]] = arith.constant 78 : i32
- // CHECK: iree.list.set %[[LIST]][%[[c0]]], %[[c9]] : !iree.list<!iree.variant>, i32
+ // CHECK: iree_input.list.set %[[LIST]][%[[c0]]], %[[c9]] : !iree_input.list<!iree_input.variant>, i32
// CHECK: %[[c1:.*]] = arith.constant 1 : index
- // CHECK: iree.list.set %[[LIST]][%[[c1]]], %arg0 : !iree.list<!iree.variant>, i32
+ // CHECK: iree_input.list.set %[[LIST]][%[[c1]]], %arg0 : !iree_input.list<!iree_input.variant>, i32
// CHECK: %[[c0_i32:.*]] = arith.constant 0 : i32
- // return %[[c0_i32]], %[[LIST]] : i32, !iree.list<!iree.variant>
+ // CHECK: return %[[c0_i32]], %[[LIST]] : i32, !iree_input.list<!iree_input.variant>
%0 = box %arg0 : !iree_pydm.integer<32> -> !iree_pydm.object<!iree_pydm.integer<32>>
return %0 : !iree_pydm.object<!iree_pydm.integer<32>>
}
@@ -46,18 +46,18 @@
// CHECK-LABEL: @alloc_store_load_var
// NOTE: 256 is the type code for a plain object
iree_pydm.func @alloc_store_load_var(%arg0 : !iree_pydm.object) -> (!iree_pydm.exception_result, !iree_pydm.object) {
- // CHECK: %[[A:.*]] = iree.list.create : !iree.list<!iree.variant>
+ // CHECK: %[[A:.*]] = iree_input.list.create : !iree_input.list<!iree_input.variant>
%a = alloc_free_var "a" -> !iree_pydm.free_var_ref
// CHECK: %[[c2:.*]] = arith.constant 2 : index
- // CHECK: iree.list.resize %[[A]], %[[c2]] : !iree.list<!iree.variant>
+ // CHECK: iree_input.list.resize %[[A]], %[[c2]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[object_code:.*]] = arith.constant 256 : i32
- // CHECK: iree.list.set %[[A]][%[[c0]]], %[[object_code]]
+ // CHECK: iree_input.list.set %[[A]][%[[c0]]], %[[object_code]]
// CHECK: %[[c1:.*]] = arith.constant 1 : index
- // CHECK: iree.list.set %[[A]][%[[c1]]], %arg0 : !iree.list<!iree.variant>, !iree.list<!iree.variant>
+ // CHECK: iree_input.list.set %[[A]][%[[c1]]], %arg0 : !iree_input.list<!iree_input.variant>, !iree_input.list<!iree_input.variant>
store_var %a = %arg0 : !iree_pydm.free_var_ref, !iree_pydm.object
// CHECK: %[[c1_0:.*]] = arith.constant 1 : index
- // CHECK: %[[LOADED:.*]] = iree.list.get %[[A]][%[[c1_0]]] : !iree.list<!iree.variant> -> !iree.list<!iree.variant>
+ // CHECK: %[[LOADED:.*]] = iree_input.list.get %[[A]][%[[c1_0]]] : !iree_input.list<!iree_input.variant> -> !iree_input.list<!iree_input.variant>
%0 = load_var %a : !iree_pydm.free_var_ref -> !iree_pydm.object
// CHECK: return {{.*}}, %[[LOADED]]
return %0 : !iree_pydm.object
@@ -69,7 +69,7 @@
iree_pydm.func @unbox(%arg0 : !iree_pydm.object) -> (!iree_pydm.exception_result, !iree_pydm.integer<32>) {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[NEEDED_TYPE_CODE:.*]] = arith.constant 78 : i32
- // CHECK: %[[TYPE_CODE:.*]] = iree.list.get %arg0[%[[c0]]] : !iree.list<!iree.variant> -> i32
+ // CHECK: %[[TYPE_CODE:.*]] = iree_input.list.get %arg0[%[[c0]]] : !iree_input.list<!iree_input.variant> -> i32
// CHECK: %[[TYPE_EQ:.*]] = arith.cmpi eq, %[[NEEDED_TYPE_CODE]], %[[TYPE_CODE]] : i32
// CHECK: cond_br %[[TYPE_EQ]], ^bb1, ^bb4
@@ -77,7 +77,7 @@
// CHECK: ^bb1:
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[c0_i32:.*]] = arith.constant 0 : i32
- // CHECK: %[[CONTENTS:.*]] = iree.list.get %arg0[%[[c1]]] : !iree.list<!iree.variant> -> i32
+ // CHECK: %[[CONTENTS:.*]] = iree_input.list.get %arg0[%[[c1]]] : !iree_input.list<!iree_input.variant> -> i32
// CHECK: br ^bb2(%[[c0_i32]], %[[CONTENTS]] : i32, i32)
// bb2: Check status code (from raise_on_failure)
@@ -105,11 +105,11 @@
// bb1: success
// CHECK: ^bb1:
// CHECK: %[[c0_i32_0:.*]] = arith.constant 0 : i32
- // CHECK: return %[[c0_i32_0]], %arg1 : i32, !iree.list<!iree.variant>
+ // CHECK: return %[[c0_i32_0]], %arg1 : i32, !iree_input.list<!iree_input.variant>
// bb2: failure
// CHECK: ^bb2:
- // CHECK: %[[NULL:.*]] = iree.list.create : !iree.list<!iree.variant>
- // CHECK: return %arg0, %[[NULL]] : i32, !iree.list<!iree.variant>
+ // CHECK: %[[NULL:.*]] = iree_input.list.create : !iree_input.list<!iree_input.variant>
+ // CHECK: return %arg0, %[[NULL]] : i32, !iree_input.list<!iree_input.variant>
raise_on_failure %arg0 : !iree_pydm.exception_result
return %arg1 : !iree_pydm.object
}
@@ -142,7 +142,7 @@
// CHECK-LABEL: @get_type_code
+// The type code is expected to be read as an i32 from element 0 of the
+// variant list that backs the object.
iree_pydm.func @get_type_code(%arg0 : !iree_pydm.object) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
- // CHECK: %[[R:.*]] = iree.list.get %arg0[%[[c0]]] : !iree.list<!iree.variant> -> i32
+ // CHECK: %[[R:.*]] = iree_input.list.get %arg0[%[[c0]]] : !iree_input.list<!iree_input.variant> -> i32
%0 = get_type_code %arg0 : !iree_pydm.object -> !iree_pydm.integer
return %0 : !iree_pydm.integer
}
@@ -157,17 +157,17 @@
// -----
// CHECK-LABEL: func @make_tuple(
-// CHECK-SAME: %[[VAL_0:.*]]: !iree.list<!iree.variant>,
-// CHECK-SAME: %[[VAL_1:.*]]: !iree.list<!iree.variant>) -> (i32, !iree.list<!iree.variant>) {
+// CHECK-SAME: %[[VAL_0:.*]]: !iree_input.list<!iree_input.variant>,
+// CHECK-SAME: %[[VAL_1:.*]]: !iree_input.list<!iree_input.variant>) -> (i32, !iree_input.list<!iree_input.variant>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_3:.*]] = iree.list.create %[[VAL_2]] : !iree.list<!iree.variant>
-// CHECK: iree.list.resize %[[VAL_3]], %[[VAL_2]] : !iree.list<!iree.variant>
+// CHECK: %[[VAL_3:.*]] = iree_input.list.create %[[VAL_2]] : !iree_input.list<!iree_input.variant>
+// CHECK: iree_input.list.resize %[[VAL_3]], %[[VAL_2]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: iree.list.set %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_0]] : !iree.list<!iree.variant>, !iree.list<!iree.variant>
+// CHECK: iree_input.list.set %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_0]] : !iree_input.list<!iree_input.variant>, !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK: iree.list.set %[[VAL_3]]{{\[}}%[[VAL_5]]], %[[VAL_1]] : !iree.list<!iree.variant>, !iree.list<!iree.variant>
+// CHECK: iree_input.list.set %[[VAL_3]]{{\[}}%[[VAL_5]]], %[[VAL_1]] : !iree_input.list<!iree_input.variant>, !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
-// CHECK: return %[[VAL_6]], %[[VAL_3]] : i32, !iree.list<!iree.variant>
+// CHECK: return %[[VAL_6]], %[[VAL_3]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: }
iree_pydm.func @make_tuple(%arg0 : !iree_pydm.object<!iree_pydm.integer>, %arg1 : !iree_pydm.object<!iree_pydm.integer>) -> (!iree_pydm.exception_result, !iree_pydm.tuple) {
%0 = make_tuple %arg0, %arg1 : !iree_pydm.object<!iree_pydm.integer>, !iree_pydm.object<!iree_pydm.integer> -> !iree_pydm.tuple
@@ -176,17 +176,17 @@
// -----
// CHECK-LABEL: func @make_list(
-// CHECK-SAME: %[[VAL_0:.*]]: !iree.list<!iree.variant>,
-// CHECK-SAME: %[[VAL_1:.*]]: !iree.list<!iree.variant>) -> (i32, !iree.list<!iree.variant>) {
+// CHECK-SAME: %[[VAL_0:.*]]: !iree_input.list<!iree_input.variant>,
+// CHECK-SAME: %[[VAL_1:.*]]: !iree_input.list<!iree_input.variant>) -> (i32, !iree_input.list<!iree_input.variant>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_3:.*]] = iree.list.create %[[VAL_2]] : !iree.list<!iree.variant>
-// CHECK: iree.list.resize %[[VAL_3]], %[[VAL_2]] : !iree.list<!iree.variant>
+// CHECK: %[[VAL_3:.*]] = iree_input.list.create %[[VAL_2]] : !iree_input.list<!iree_input.variant>
+// CHECK: iree_input.list.resize %[[VAL_3]], %[[VAL_2]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: iree.list.set %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_0]] : !iree.list<!iree.variant>, !iree.list<!iree.variant>
+// CHECK: iree_input.list.set %[[VAL_3]]{{\[}}%[[VAL_4]]], %[[VAL_0]] : !iree_input.list<!iree_input.variant>, !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK: iree.list.set %[[VAL_3]]{{\[}}%[[VAL_5]]], %[[VAL_1]] : !iree.list<!iree.variant>, !iree.list<!iree.variant>
+// CHECK: iree_input.list.set %[[VAL_3]]{{\[}}%[[VAL_5]]], %[[VAL_1]] : !iree_input.list<!iree_input.variant>, !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
-// CHECK: return %[[VAL_6]], %[[VAL_3]] : i32, !iree.list<!iree.variant>
+// CHECK: return %[[VAL_6]], %[[VAL_3]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: }
iree_pydm.func @make_list(%arg0 : !iree_pydm.object<!iree_pydm.integer>, %arg1 : !iree_pydm.object<!iree_pydm.integer>) -> (!iree_pydm.exception_result, !iree_pydm.list) {
%0 = make_list %arg0, %arg1 : !iree_pydm.object<!iree_pydm.integer>, !iree_pydm.object<!iree_pydm.integer> -> !iree_pydm.list
@@ -195,17 +195,17 @@
// -----
// CHECK-LABEL: func @dynamic_unpack(
-// CHECK-SAME: %[[VAL_0:.*]]: !iree.list<!iree.variant>) -> (i32, !iree.list<!iree.variant>) {
+// CHECK-SAME: %[[VAL_0:.*]]: !iree_input.list<!iree_input.variant>) -> (i32, !iree_input.list<!iree_input.variant>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_2:.*]] = iree.list.size %[[VAL_0]] : !iree.list<!iree.variant>
+// CHECK: %[[VAL_2:.*]] = iree_input.list.size %[[VAL_0]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_3:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : index
// CHECK: cond_br %[[VAL_3]], ^bb1, ^bb4
// CHECK: ^bb1:
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_6:.*]] = iree.list.get %[[VAL_0]]{{\[}}%[[VAL_5]]] : !iree.list<!iree.variant> -> i32
+// CHECK: %[[VAL_6:.*]] = iree_input.list.get %[[VAL_0]]{{\[}}%[[VAL_5]]] : !iree_input.list<!iree_input.variant> -> i32
// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_8:.*]] = iree.list.get %[[VAL_0]]{{\[}}%[[VAL_7]]] : !iree.list<!iree.variant> -> i1
+// CHECK: %[[VAL_8:.*]] = iree_input.list.get %[[VAL_0]]{{\[}}%[[VAL_7]]] : !iree_input.list<!iree_input.variant> -> i1
// CHECK: br ^bb2(%[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : i32, i32, i1)
// CHECK: ^bb2(%[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: i1):
// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i32
@@ -213,22 +213,22 @@
// CHECK: cond_br %[[VAL_13]], ^bb3, ^bb5
// CHECK: ^bb3:
// CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_15:.*]] = iree.list.create %[[VAL_14]] : !iree.list<!iree.variant>
-// CHECK: iree.list.resize %[[VAL_15]], %[[VAL_14]] : !iree.list<!iree.variant>
+// CHECK: %[[VAL_15:.*]] = iree_input.list.create %[[VAL_14]] : !iree_input.list<!iree_input.variant>
+// CHECK: iree_input.list.resize %[[VAL_15]], %[[VAL_14]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
-// CHECK: iree.list.set %[[VAL_15]]{{\[}}%[[VAL_16]]], %[[VAL_10]] : !iree.list<!iree.variant>, i32
+// CHECK: iree_input.list.set %[[VAL_15]]{{\[}}%[[VAL_16]]], %[[VAL_10]] : !iree_input.list<!iree_input.variant>, i32
// CHECK: %[[VAL_17:.*]] = arith.constant 1 : index
-// CHECK: iree.list.set %[[VAL_15]]{{\[}}%[[VAL_17]]], %[[VAL_11]] : !iree.list<!iree.variant>, i1
+// CHECK: iree_input.list.set %[[VAL_15]]{{\[}}%[[VAL_17]]], %[[VAL_11]] : !iree_input.list<!iree_input.variant>, i1
// CHECK: %[[VAL_18:.*]] = arith.constant 0 : i32
-// CHECK: return %[[VAL_18]], %[[VAL_15]] : i32, !iree.list<!iree.variant>
+// CHECK: return %[[VAL_18]], %[[VAL_15]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: ^bb4:
// CHECK: %[[VAL_19:.*]] = arith.constant -4 : i32
// CHECK: %[[VAL_20:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_21:.*]] = arith.constant false
// CHECK: br ^bb2(%[[VAL_19]], %[[VAL_20]], %[[VAL_21]] : i32, i32, i1)
// CHECK: ^bb5:
-// CHECK: %[[VAL_22:.*]] = iree.list.create : !iree.list<!iree.variant>
-// CHECK: return %[[VAL_9]], %[[VAL_22]] : i32, !iree.list<!iree.variant>
+// CHECK: %[[VAL_22:.*]] = iree_input.list.create : !iree_input.list<!iree_input.variant>
+// CHECK: return %[[VAL_9]], %[[VAL_22]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: }
iree_pydm.func @dynamic_unpack(%arg0 : !iree_pydm.tuple) -> (!iree_pydm.exception_result, !iree_pydm.tuple) {
%exc_result, %0, %1 = dynamic_unpack %arg0 : !iree_pydm.tuple -> !iree_pydm.exception_result, [!iree_pydm.integer, !iree_pydm.bool]
@@ -239,9 +239,9 @@
// -----
// CHECK-LABEL: func @list_duplicate(
-// CHECK-SAME: %[[VAL_0:.*]]: !iree.list<!iree.variant>,
-// CHECK-SAME: %[[VAL_1:.*]]: i32) -> (i32, !iree.list<!iree.variant>) {
-// CHECK: %[[VAL_2:.*]] = iree.list.size %[[VAL_0]] : !iree.list<!iree.variant>
+// CHECK-SAME: %[[VAL_0:.*]]: !iree_input.list<!iree_input.variant>,
+// CHECK-SAME: %[[VAL_1:.*]]: i32) -> (i32, !iree_input.list<!iree_input.variant>) {
+// CHECK: %[[VAL_2:.*]] = iree_input.list.size %[[VAL_0]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_1]] : i32 to index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
@@ -249,8 +249,8 @@
// CHECK: %[[VAL_7:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_6]] : i32
// CHECK: %[[VAL_8:.*]] = select %[[VAL_7]], %[[VAL_4]], %[[VAL_3]] : index
// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
-// CHECK: %[[VAL_10:.*]] = iree.list.create %[[VAL_8]] : !iree.list<!iree.variant>
-// CHECK: iree.list.resize %[[VAL_10]], %[[VAL_9]] : !iree.list<!iree.variant>
+// CHECK: %[[VAL_10:.*]] = iree_input.list.create %[[VAL_8]] : !iree_input.list<!iree_input.variant>
+// CHECK: iree_input.list.resize %[[VAL_10]], %[[VAL_9]] : !iree_input.list<!iree_input.variant>
// CHECK: br ^bb1(%[[VAL_4]] : index)
// CHECK: ^bb1(%[[VAL_11:.*]]: index):
// CHECK: %[[VAL_12:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_9]] : index
@@ -259,14 +259,14 @@
// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_14]], %[[VAL_2]] : index
// CHECK: cond_br %[[VAL_15]], ^bb3(%[[VAL_13]], %[[VAL_14]] : index, index), ^bb1(%[[VAL_13]] : index)
// CHECK: ^bb3(%[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index):
-// CHECK: %[[VAL_18:.*]] = iree.list.get %[[VAL_0]]{{\[}}%[[VAL_17]]] : !iree.list<!iree.variant> -> !iree.list<!iree.variant>
-// CHECK: iree.list.set %[[VAL_10]]{{\[}}%[[VAL_16]]], %[[VAL_18]] : !iree.list<!iree.variant>, !iree.list<!iree.variant>
+// CHECK: %[[VAL_18:.*]] = iree_input.list.get %[[VAL_0]]{{\[}}%[[VAL_17]]] : !iree_input.list<!iree_input.variant> -> !iree_input.list<!iree_input.variant>
+// CHECK: iree_input.list.set %[[VAL_10]]{{\[}}%[[VAL_16]]], %[[VAL_18]] : !iree_input.list<!iree_input.variant>, !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_17]], %[[VAL_5]] : index
// CHECK: br ^bb2(%[[VAL_19]], %[[VAL_20]] : index, index)
// CHECK: ^bb4:
// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
-// CHECK: return %[[VAL_21]], %[[VAL_10]] : i32, !iree.list<!iree.variant>
+// CHECK: return %[[VAL_21]], %[[VAL_10]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: }
iree_pydm.func @list_duplicate(%arg0 : !iree_pydm.list, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.list) {
%result = sequence_clone %arg0 * %arg1 : !iree_pydm.list, !iree_pydm.integer -> !iree_pydm.list
@@ -275,10 +275,10 @@
// -----
// CHECK-LABEL: func @subscript_list(
-// CHECK-SAME: %[[VAL_0:.*]]: !iree.list<!iree.variant>,
-// CHECK-SAME: %[[VAL_1:.*]]: i32) -> (i32, !iree.list<!iree.variant>) {
+// CHECK-SAME: %[[VAL_0:.*]]: !iree_input.list<!iree_input.variant>,
+// CHECK-SAME: %[[VAL_1:.*]]: i32) -> (i32, !iree_input.list<!iree_input.variant>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_3:.*]] = iree.list.size %[[VAL_0]] : !iree.list<!iree.variant>
+// CHECK: %[[VAL_3:.*]] = iree_input.list.size %[[VAL_0]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_3]] : index to i32
// CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_2]] : i32
// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_1]] : i32 to index
@@ -292,22 +292,22 @@
// CHECK: cond_br %[[VAL_10]], ^bb3(%[[VAL_9]] : index), ^bb6
// CHECK: ^bb3(%[[VAL_11:.*]]: index):
// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_13:.*]] = iree.list.get %[[VAL_0]]{{\[}}%[[VAL_11]]] : !iree.list<!iree.variant> -> !iree.list<!iree.variant>
-// CHECK: br ^bb4(%[[VAL_12]], %[[VAL_13]] : i32, !iree.list<!iree.variant>)
-// CHECK: ^bb4(%[[VAL_14:.*]]: i32, %[[VAL_15:.*]]: !iree.list<!iree.variant>):
+// CHECK: %[[VAL_13:.*]] = iree_input.list.get %[[VAL_0]]{{\[}}%[[VAL_11]]] : !iree_input.list<!iree_input.variant> -> !iree_input.list<!iree_input.variant>
+// CHECK: br ^bb4(%[[VAL_12]], %[[VAL_13]] : i32, !iree_input.list<!iree_input.variant>)
+// CHECK: ^bb4(%[[VAL_14:.*]]: i32, %[[VAL_15:.*]]: !iree_input.list<!iree_input.variant>):
// CHECK: %[[VAL_16:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_17:.*]] = arith.cmpi eq, %[[VAL_16]], %[[VAL_14]] : i32
// CHECK: cond_br %[[VAL_17]], ^bb5, ^bb7
// CHECK: ^bb5:
// CHECK: %[[VAL_18:.*]] = arith.constant 0 : i32
-// CHECK: return %[[VAL_18]], %[[VAL_15]] : i32, !iree.list<!iree.variant>
+// CHECK: return %[[VAL_18]], %[[VAL_15]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: ^bb6:
// CHECK: %[[VAL_19:.*]] = arith.constant -7 : i32
-// CHECK: %[[VAL_20:.*]] = iree.list.create : !iree.list<!iree.variant>
-// CHECK: br ^bb4(%[[VAL_19]], %[[VAL_20]] : i32, !iree.list<!iree.variant>)
+// CHECK: %[[VAL_20:.*]] = iree_input.list.create : !iree_input.list<!iree_input.variant>
+// CHECK: br ^bb4(%[[VAL_19]], %[[VAL_20]] : i32, !iree_input.list<!iree_input.variant>)
// CHECK: ^bb7:
-// CHECK: %[[VAL_21:.*]] = iree.list.create : !iree.list<!iree.variant>
-// CHECK: return %[[VAL_14]], %[[VAL_21]] : i32, !iree.list<!iree.variant>
+// CHECK: %[[VAL_21:.*]] = iree_input.list.create : !iree_input.list<!iree_input.variant>
+// CHECK: return %[[VAL_14]], %[[VAL_21]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: }
iree_pydm.func @subscript_list(%arg0 : !iree_pydm.list, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.object) {
%exc_result, %result = subscript %arg0[%arg1] : !iree_pydm.list, !iree_pydm.integer -> !iree_pydm.object
@@ -317,11 +317,11 @@
// -----
// CHECK-LABEL: func @assign_subscript_list(
-// CHECK-SAME: %[[VAL_0:.*]]: !iree.list<!iree.variant>,
+// CHECK-SAME: %[[VAL_0:.*]]: !iree_input.list<!iree_input.variant>,
// CHECK-SAME: %[[VAL_1:.*]]: i32,
-// CHECK-SAME: %[[VAL_2:.*]]: !iree.list<!iree.variant>) -> (i32, !iree.list<!iree.variant>) {
+// CHECK-SAME: %[[VAL_2:.*]]: !iree_input.list<!iree_input.variant>) -> (i32, !iree_input.list<!iree_input.variant>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_4:.*]] = iree.list.size %[[VAL_0]] : !iree.list<!iree.variant>
+// CHECK: %[[VAL_4:.*]] = iree_input.list.size %[[VAL_0]] : !iree_input.list<!iree_input.variant>
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i32
// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : i32
// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_1]] : i32 to index
@@ -335,11 +335,11 @@
// CHECK: cond_br %[[VAL_11]], ^bb3(%[[VAL_10]] : index), ^bb5
// CHECK: ^bb3(%[[VAL_12:.*]]: index):
// CHECK: %[[VAL_13:.*]] = arith.constant 0 : i32
-// CHECK: iree.list.set %[[VAL_0]]{{\[}}%[[VAL_12]]], %[[VAL_2]] : !iree.list<!iree.variant>, !iree.list<!iree.variant>
+// CHECK: iree_input.list.set %[[VAL_0]]{{\[}}%[[VAL_12]]], %[[VAL_2]] : !iree_input.list<!iree_input.variant>, !iree_input.list<!iree_input.variant>
// CHECK: br ^bb4(%[[VAL_13]] : i32)
// CHECK: ^bb4(%[[VAL_14:.*]]: i32):
// CHECK: %[[VAL_15:.*]] = arith.constant 0 : i32
-// CHECK: return %[[VAL_15]], %[[VAL_0]] : i32, !iree.list<!iree.variant>
+// CHECK: return %[[VAL_15]], %[[VAL_0]] : i32, !iree_input.list<!iree_input.variant>
// CHECK: ^bb5:
// CHECK: %[[VAL_16:.*]] = arith.constant -7 : i32
// CHECK: br ^bb4(%[[VAL_16]] : i32)
diff --git a/llvm-external-projects/iree-dialects/test/python/smoketest.py b/llvm-external-projects/iree-dialects/test/python/smoketest.py
index 651d9e0..cef2e6c 100644
--- a/llvm-external-projects/iree-dialects/test/python/smoketest.py
+++ b/llvm-external-projects/iree-dialects/test/python/smoketest.py
@@ -1,7 +1,7 @@
# RUN: %PYTHON %s
import iree.compiler.ir
-from iree.compiler.dialects import iree as iree_d
+from iree.compiler.dialects import iree_input as iree_d
from iree.compiler.dialects import iree_pydm as pydm_d
with iree.compiler.ir.Context() as ctx:
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
index c721847..6c347b4 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -1,13 +1,19 @@
set(LIBS
+ MLIRArithmetic
MLIRDialect
+ MLIRLinalg
+ MLIRMemRef
MLIROptLib
MLIRSCF
MLIRSCFTransforms
MLIRStandard
+ MLIRTensor
MLIRTransforms
- IREEDialectsIREEDialect
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEPyDMPasses
+ IREEInputDialect
+ IREELinalgExtDialect
+ IREELinalgExtPasses
+ IREEPyDMDialect
+ IREEPyDMPasses
)
add_llvm_tool(iree-dialects-opt
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
index e663cd1..827635d 100644
--- a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -4,19 +4,26 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
-using namespace mlir::iree;
+namespace IREE = mlir::iree_compiler::IREE;
int main(int argc, char **argv) {
registerAsmPrinterCLOptions();
@@ -27,13 +34,20 @@
// Local dialects.
mlir::iree_compiler::IREE::PYDM::registerPasses();
+ mlir::iree_compiler::IREE::LinalgExt::registerPasses();
DialectRegistry registry;
registry.insert<
// Local dialects
- mlir::iree::IREEDialect, mlir::iree_compiler::IREE::PYDM::IREEPyDMDialect,
+ mlir::iree_compiler::IREE::Input::IREEInputDialect,
+ mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
+ mlir::iree_compiler::IREE::PYDM::IREEPyDMDialect,
// Upstream dialects
- mlir::StandardOpsDialect, mlir::scf::SCFDialect>();
+ mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
+ mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
+ mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
+
+ IREE::LinalgExt::registerTiledOpInterfaceExternalModels(registry);
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,