Moves the linalg_ext dialect to iree_linalg_ext under the iree-dialects project. (#7657)
Mostly NFC:
* Pre-factors iree-dialects into a consistent state wrt namespaces and directory layouts.
* Moves linalg_ext to a new dialect under iree-dialects. Does some adaptation to upstream style along the way.
* Redirects everything that was using it to use the new one.
* Enables tests for iree-dialects in the cmake CI (they were not enabled) and fixes some type prefixing that had drifted. (Will follow up with enabling them in the internal builds to guard against further regression).
Non-NFC:
* When tiling, the old pass was directly generating flow.dispatch.workgroup.* ops to query the current workgroup ID. We had been planning to add those to the input dialect, so I pulled part of that patch forward.
* Since these ops now lack lowerings, I expected to hit some test failures that would point me to what needs adapting, but this appears to be dead code (outside of integration tests?). We'll see what the CI says.
* Can keep coding on this patch to adapt whatever is needed.
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
index 75be3ec..fde1221 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -3,9 +3,9 @@
Utils.cpp
LINK_LIBS PUBLIC
MLIRIR
- IREEDialectsIREEDialect
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEPyDMPasses
+ IREEInputDialect
+ IREEPyDMDialect
+ IREEPyDMPasses
)
iree_dialects_target_includes(IREEDialectsCAPI)
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index 003c765..ac169f1 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -6,9 +6,9 @@
#include "iree-dialects-c/Dialects.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/CAPI/Registration.h"
@@ -24,7 +24,8 @@
// IREEDialect
//===----------------------------------------------------------------------===//
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IREE, iree, mlir::iree::IREEDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(
+ IREEInput, iree_input, mlir::iree_compiler::IREE::Input::IREEInputDialect)
//===----------------------------------------------------------------------===//
// IREEPyDMDialect
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
index 61df04e..620c526 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
@@ -1,2 +1,3 @@
-add_subdirectory(IREE)
-add_subdirectory(IREEPyDM)
+add_subdirectory(Input)
+add_subdirectory(LinalgExt)
+add_subdirectory(PyDM)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/CMakeLists.txt
deleted file mode 100644
index 0630336..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_library(IREEDialectsIREEDialect
- IREEDialect.cpp
- IREEOps.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${IREE_DIALECTS_SOURCE_DIR}/include
-
- DEPENDS
- MLIRIREEOpsIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSideEffectInterfaces
-)
-
-iree_dialects_target_includes(IREEDialectsIREEDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp
deleted file mode 100644
index 6b2ab7f..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEDialect.cpp
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/Support/LLVM.h"
-
-using namespace mlir;
-using namespace mlir::iree;
-
-#include "iree-dialects/Dialect/IREE/IREEOpsDialect.cpp.inc"
-
-#define GET_TYPEDEF_CLASSES
-#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
-
-void IREEDialect::initialize() {
- addTypes<
-#define GET_TYPEDEF_LIST
-#include "iree-dialects/Dialect/IREE/IREEOpsTypes.cpp.inc"
- >();
- addOperations<
-#define GET_OP_LIST
-#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"
- >();
-}
-
-Type IREEDialect::parseType(DialectAsmParser &parser) const {
- StringRef typeTag;
- Type genType;
- if (succeeded(parser.parseKeyword(&typeTag)))
- generatedTypeParser(parser, typeTag, genType);
- return genType;
-}
-
-void IREEDialect::printType(Type type, DialectAsmPrinter &printer) const {
- (void)generatedTypePrinter(type, printer);
-}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
deleted file mode 100644
index b6293fe..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMDialect
- Dialect.cpp
- Ops.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${IREE_DIALECTS_SOURCE_DIR}/include
-
- DEPENDS
- MLIRIREEPyDMOpsIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSideEffectInterfaces
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
deleted file mode 100644
index db6f84f..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-add_subdirectory(Optimize)
-add_subdirectory(RTL)
-add_subdirectory(ToIREE)
-
-add_mlir_library(IREEDialectsIREEPyDMPasses
- Passes.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMOptimizePasses
- IREEDialectsIREEPyDMRTLPasses
- IREEDialectsIREEPyDMToIREEPasses
- MLIRTransforms
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
deleted file mode 100644
index 093f273..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMOptimizePasses
- FixateWeakNumeric.cpp
- LocalPropagateTypes.cpp
- VariablesToSSA.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEPyDMUtils
- MLIRIR
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMOptimizePasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/PassDetail.h b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/PassDetail.h
deleted file mode 100644
index a95ac11..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/PassDetail.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
-#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
-
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-
-namespace iree {
-class IREEDialect;
-}
-
-namespace iree_compiler {
-namespace IREE {
-namespace PYDM {
-
-class FuncOp;
-
-#define GEN_PASS_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h.inc"
-
-} // namespace PYDM
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/CMakeLists.txt
deleted file mode 100644
index a9139a4..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/CMakeLists.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMRTLPasses
- LinkageAnalysis.cpp
- LinkRTLPass.cpp
- LowerToRTLPass.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEDialect
- MLIRIR
- MLIRParser
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMRTLPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/CMakeLists.txt
deleted file mode 100644
index 8c332af..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMToIREEPasses
- ConversionPass.cpp
- LoweringPatterns.cpp
- TypeConverter.cpp
-
- DEPENDS
- MLIRIREEPyDMTransformsPassesIncGen
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- IREEDialectsIREEDialect
- MLIRArithmetic
- MLIRIR
- MLIRStandard
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMToIREEPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/CMakeLists.txt
deleted file mode 100644
index efca191..0000000
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/CMakeLists.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-add_mlir_library(IREEDialectsIREEPyDMUtils
- TypeInference.cpp
-
- LINK_LIBS PUBLIC
- IREEDialectsIREEPyDMDialect
- MLIRIR
- MLIRStandard
- MLIRTransformUtils
-)
-
-iree_dialects_target_includes(IREEDialectsIREEPyDMUtils)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/Input/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/Input/CMakeLists.txt
new file mode 100644
index 0000000..b755303
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(IREEInputDialect
+ InputDialect.cpp
+ InputOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${IREE_DIALECTS_SOURCE_DIR}/include
+
+ DEPENDS
+ IREEInputIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+)
+
+iree_dialects_target_includes(IREEInputDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp
new file mode 100644
index 0000000..ef6a95a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputDialect.cpp
@@ -0,0 +1,43 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::Input;
+
+#include "iree-dialects/Dialect/Input/InputDialect.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "iree-dialects/Dialect/Input/InputTypes.cpp.inc"
+
+// Registers the dialect's types and operations. The template argument lists
+// passed to addTypes<> and addOperations<> come from the included .inc files,
+// selected by the GET_TYPEDEF_LIST / GET_OP_LIST macros defined right before
+// each include.
+void IREEInputDialect::initialize() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "iree-dialects/Dialect/Input/InputTypes.cpp.inc"
+ >();
+ addOperations<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/Input/InputOps.cpp.inc"
+ >();
+}
+
+/// Parses a type of this dialect, identified by a leading keyword mnemonic.
+///
+/// Returns the parsed type on success and a null Type otherwise. Unlike a
+/// plain pass-through to the generated parser, a mnemonic that the generated
+/// parser does not handle is reported explicitly, so a null result is never
+/// returned without a diagnostic from this function.
+Type IREEInputDialect::parseType(DialectAsmParser &parser) const {
+  StringRef typeTag;
+  if (failed(parser.parseKeyword(&typeTag)))
+    return Type();
+
+  Type genType;
+  // NOTE(review): assumes generatedTypeParser returns an OptionalParseResult
+  // that is empty when the mnemonic is unknown -- confirm against the
+  // generated InputTypes.cpp.inc.
+  OptionalParseResult parseResult =
+      generatedTypeParser(parser, typeTag, genType);
+  if (parseResult.hasValue())
+    return succeeded(*parseResult) ? genType : Type();
+
+  // No generated parser claimed the mnemonic: emit an error instead of
+  // silently handing a null type back to the caller.
+  parser.emitError(parser.getNameLoc(), "unknown type mnemonic: ") << typeTag;
+  return Type();
+}
+
+// Prints `type` by delegating to generatedTypePrinter. The return value of the
+// generated printer is explicitly discarded.
+void IREEInputDialect::printType(Type type, DialectAsmPrinter &printer) const {
+ (void)generatedTypePrinter(type, printer);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp
similarity index 92%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp
index a723b58..894a74f 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREE/IREEOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp
@@ -4,16 +4,16 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
-using namespace mlir::iree;
+using namespace mlir::iree_compiler::IREE::Input;
//===----------------------------------------------------------------------===//
// custom<SymbolVisibility>($sym_visibility)
@@ -91,4 +91,4 @@
}
#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/IREE/IREEOps.cpp.inc"
+#include "iree-dialects/Dialect/Input/InputOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
new file mode 100644
index 0000000..9f57627
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt
new file mode 100644
index 0000000..6022a57
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -0,0 +1,29 @@
+add_mlir_library(IREELinalgExtDialect
+ LinalgExtDialect.cpp
+ LinalgExtInterfaces.cpp
+ LinalgExtOps.cpp
+ TiledOpInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${IREE_DIALECTS_SOURCE_DIR}/include
+
+ DEPENDS
+ IREELinalgExtIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAffine
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLinalg
+ MLIRMath
+ MLIRMemRef
+ MLIRPass
+ MLIRSideEffectInterfaces
+ MLIRSupport
+ MLIRSCF
+ MLIRStandard
+ MLIRTensor
+ MLIRViewLikeInterface
+)
+
+iree_dialects_target_includes(IREELinalgExtDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
new file mode 100644
index 0000000..4657c12
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtDialect.cpp
@@ -0,0 +1,29 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+// Registers the operations of the dialect. The template argument list of
+// addOperations<> is pulled in from LinalgExtOps.cpp.inc with GET_OP_LIST
+// defined. No dialect interfaces are registered yet (see the TODO below).
+void IREELinalgExtDialect::initialize() {
+ // TODO(hanchung): Add interface to the dialect.
+ // addInterfaces<IREEInlinerInterface>();
+#define GET_OP_LIST
+ addOperations<
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc"
+ >();
+}
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
new file mode 100644
index 0000000..5fdaf8a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
@@ -0,0 +1,51 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+/// Converts the vector of operand handles into the vector of values they
+/// currently hold, preserving order.
+OpOperandVector::operator SmallVector<Value>() {
+  SmallVector<Value> values;
+  values.reserve(this->size());
+  for (OpOperand *operand : *this)
+    values.push_back(operand->get());
+  return values;
+}
+
+/// Verifies the invariants shared by all ops implementing LinalgExtOp.
+///
+/// An op without results must have buffer semantics. An op with results must
+/// have tensor semantics, as many results as output operands, and every result
+/// type must equal the type of the output operand at the same position.
+LogicalResult IREE::LinalgExt::detail::verifyLinalgExtOpInterface(
+    Operation *op) {
+  auto interfaceOp = cast<LinalgExtOp>(op);
+
+  // Result-less form: only the buffer-semantics requirement applies.
+  if (!op->getNumResults()) {
+    if (interfaceOp.hasBufferSemantics())
+      return success();
+    return interfaceOp.emitOpError(
+        "expected inputs and outputs to be MemRefType or scalar");
+  }
+
+  if (!interfaceOp.hasTensorSemantics()) {
+    return interfaceOp.emitOpError(
+        "expected inputs and outputs to be RankedTensorType or scalar");
+  }
+  if (op->getNumResults() != interfaceOp.outputs().size()) {
+    return interfaceOp.emitOpError(
+        "expected number of outputs to be same as the number of results");
+  }
+
+  // Results are tied positionally to the `outs` operands.
+  for (auto indexedResult : llvm::enumerate(op->getResultTypes())) {
+    Type resultType = indexedResult.value();
+    Type outputType =
+        interfaceOp.outputs()[indexedResult.index()].getType();
+    if (resultType == outputType)
+      continue;
+    return interfaceOp.emitOpError("expected type of `outs` operand #")
+           << indexedResult.index() << " " << outputType
+           << " to be same as result type " << resultType;
+  }
+  return success();
+}
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc" // IWYU pragma: export
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
new file mode 100644
index 0000000..ebcb0cd
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -0,0 +1,985 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/SMLoc.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+namespace IREE = mlir::iree_compiler::IREE;
+
+//===----------------------------------------------------------------------===//
+// Utils.
+//===----------------------------------------------------------------------===//
+
+/// Appends memory effects to `effects`: an Allocate for every value in
+/// `results`, a Read for every value in `inputBuffers`, and a Read followed by
+/// a Write for every value in `outputBuffers`. All effects are recorded on the
+/// default side-effect resource.
+static void getEffectsImpl(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects,
+    ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
+  SideEffects::Resource *resource = SideEffects::DefaultResource::get();
+  for (Value result : results)
+    effects.emplace_back(MemoryEffects::Allocate::get(), result, resource);
+  for (Value input : inputBuffers)
+    effects.emplace_back(MemoryEffects::Read::get(), input, resource);
+  for (Value output : outputBuffers) {
+    effects.emplace_back(MemoryEffects::Read::get(), output, resource);
+    effects.emplace_back(MemoryEffects::Write::get(), output, resource);
+  }
+}
+
+/// Builds a slice of `source` described by `offsets`, `sizes` and `strides`.
+/// A ranked tensor source yields a tensor.extract_slice, a memref source a
+/// memref.subview; any other source type yields a null Value.
+static Value getSlice(OpBuilder &b, Location loc, Value source,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      ArrayRef<OpFoldResult> strides) {
+  Type sourceType = source.getType();
+  if (sourceType.isa<RankedTensorType>()) {
+    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                            strides);
+  }
+  if (sourceType.isa<MemRefType>()) {
+    return b.create<memref::SubViewOp>(loc, source, offsets, sizes, strides);
+  }
+  return Value();
+}
+
+/// Creates an op that queries dimension `dim` of `v`: tensor.dim for a ranked
+/// tensor, memref.dim for a memref. Returns a null Value for any other type.
+Value IREE::LinalgExt::getDimValue(OpBuilder &builder, Location loc, Value v,
+                                   int64_t dim) {
+  Type valueType = v.getType();
+  if (valueType.isa<RankedTensorType>())
+    return builder.create<tensor::DimOp>(loc, v, dim);
+  if (valueType.isa<MemRefType>())
+    return builder.create<memref::DimOp>(loc, v, dim);
+  return Value();
+}
+
+/// Returns dimension `dim` of the shaped value `v` as an OpFoldResult: an i64
+/// integer attribute when the dimension is static, otherwise the value
+/// produced by getDimValue.
+OpFoldResult IREE::LinalgExt::getDim(OpBuilder &builder, Location loc, Value v,
+                                     int64_t dim) {
+  auto shapedType = v.getType().cast<ShapedType>();
+  if (!shapedType.isDynamicDim(dim))
+    return builder.getI64IntegerAttr(shapedType.getDimSize(dim));
+  return getDimValue(builder, loc, v, dim);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+/// Verifies the operand, shape and region invariants of a ScatterOp.
+///
+/// Checks, in order:
+///  - exactly two input operands and one output operand;
+///  - indices are rank 2 with i32 elements, and the index depth is static;
+///  - the update value has rank >= 1 and matches indices at dim#0;
+///  - rank(original) == indexDepth + rank(update) - 1, and the dimensions of
+///    original from `indexDepth` onwards match update dims 1, 2, ...;
+///  - the region has two int/float scalar arguments whose types equal the
+///    update and original element types (and each other);
+///  - the region yields a single value of that same type.
+static LogicalResult verifyScatterOp(ScatterOp op) {
+  if (op.inputs().size() != 2) {
+    return op.emitOpError("expected two input operands");
+  }
+  if (op.outputs().size() != 1) {
+    return op.emitOpError("expected one output operand");
+  }
+  auto checkDimensionsMatch = [](ShapedType t1, ShapedType t2, unsigned dim) {
+    return t1.getShape()[dim] == t2.getShape()[dim];
+  };
+
+  auto indicesType = op.getIndicesType();
+  if (indicesType.getRank() != 2 ||
+      !indicesType.getElementType().isInteger(32)) {
+    return op.emitOpError(
+        "expected indices to be of rank 2 of i32 element type");
+  }
+  auto indexDepth = op.getIndexDepth();
+  if (indexDepth == ShapedType::kDynamicSize) {
+    return op.emitOpError("expected index depth is static");
+  }
+
+  // The first dimension of the indices should match the first dimension of the
+  // update value. Both indicate the number of updates.
+  auto updateType = op.getUpdateType();
+  if (updateType.getRank() < 1) {
+    return op.emitOpError("expected update value to be at least rank 1");
+  }
+  if (!checkDimensionsMatch(indicesType, updateType, 0)) {
+    return op.emitOpError(
+        "mismatch in shape of indices and update value at dim#0");
+  }
+  auto originalType = op.getOriginalType();
+  // indexDepth + update dims should match the original dims. The first dim of
+  // update is the number of updates.
+  if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
+    return op.emitOpError(
+        "mismatch in rank of update value, index depth and original value");
+  }
+  for (auto dim : llvm::seq<unsigned>(indexDepth, originalType.getRank())) {
+    // Offset by one because the first dim is the number of updates.
+    if (updateType.getDimSize(1 + dim - indexDepth) !=
+        originalType.getDimSize(dim)) {
+      return op.emitOpError("mismatch in shape of update value dim#")
+             << (1 + dim - indexDepth) << " and original value at dim#" << dim;
+    }
+  }
+
+  // Region checks: two scalar arguments and a single yielded value, all of
+  // the same type.
+  Region &region = op.region();
+  Block *body = &region.front();
+  if (body->getNumArguments() != 2) {
+    return op.emitOpError("expected region to have two arguments");
+  }
+  Type arg0Type = body->getArgument(0).getType();
+  Type arg1Type = body->getArgument(1).getType();
+  if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
+    return op.emitOpError(
+        "expected region to have scalar argument of integer or float types");
+  }
+  if (arg0Type != updateType.getElementType()) {
+    return op.emitOpError("mismatch in argument 0 of region ")
+           << arg0Type << " and element type of update value "
+           << updateType.getElementType();
+  }
+  if (arg1Type != originalType.getElementType()) {
+    return op.emitOpError("mismatch in argument 1 of region ")
+           << arg1Type << " and element type of original value "
+           << originalType.getElementType();
+  }
+  if (arg0Type != arg1Type) {
+    return op.emitOpError("mismatch in region argument types ")
+           << arg0Type << " and " << arg1Type;
+  }
+  auto yieldOp = cast<IREE::LinalgExt::YieldOp>(body->getTerminator());
+  if (yieldOp->getNumOperands() != 1) {
+    return yieldOp.emitOpError("expected region to yield a single value");
+  }
+  auto yieldedType = yieldOp->getOperand(0).getType();
+  if (yieldedType != arg0Type) {
+    return yieldOp.emitOpError("mismatch in type of yielded value ")
+           << yieldedType << " and argument of the region " << arg0Type;
+  }
+  return success();
+}
+
+/// Returns one iterator type per dimension of the update value; all of them
+/// are parallel.
+SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
+  return SmallVector<StringRef>(getUpdateType().getRank(),
+                                getParallelIteratorTypeName());
+}
+
+/// Returns the iteration domain of the scatter: one range per dimension of
+/// the update value, running from 0 to that dimension's size in steps of 1.
+SmallVector<Range> ScatterOp::getLoopBounds(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
+  const int64_t updateRank = getUpdateType().getRank();
+  SmallVector<Range> loopBounds;
+  for (int64_t dim = 0; dim < updateRank; ++dim) {
+    Value upperBound = getDimValue(builder, loc, updates(), dim);
+    loopBounds.push_back(Range{lowerBound, upperBound, step});
+  }
+  return loopBounds;
+}
+
+// Builds the tiled version of this scatter for the tile described by
+// `offsets` and `sizes`.
+//
+// Slices are taken of the update value, the indices and `outputs[0]`; this op
+// is then cloned onto those slices. For every result of the clone, a
+// tensor.insert_slice back into `outputs[0]` is created and its result is
+// appended to `results`. Returns the cloned (tiled) op.
+Operation *ScatterOp::getTiledImplementation(OpBuilder &builder,
+ ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) {
+ assert(outputs.size() >= 1 && offsets.size() >= 1 && sizes.size() >= 1);
+ Location loc = getLoc();
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ auto oneAttr = builder.getI64IntegerAttr(1);
+
+ // Slice of the updates.
+ // The tile offsets/sizes are applied to the update value directly, with
+ // unit strides.
+ auto updateRank = getUpdateType().getRank();
+ SmallVector<OpFoldResult> updateStrides(updateRank, oneAttr);
+ Value tiledUpdate =
+ getSlice(builder, loc, updates(), offsets, sizes, updateStrides);
+ assert(tiledUpdate && "failed to get slice of update");
+
+ // Slice of indices.
+ // Only dim 0 is tiled (using offsets[0]/sizes[0]); every other dimension is
+ // taken in full starting at offset 0.
+ auto indicesRank = getIndicesType().getRank();
+ SmallVector<OpFoldResult> indicesOffsets(indicesRank, zeroAttr);
+ SmallVector<OpFoldResult> indicesSizes(indicesRank);
+ indicesOffsets[0] = offsets[0];
+ indicesSizes[0] = sizes[0];
+ for (auto dim : llvm::seq<int64_t>(1, indicesRank)) {
+ indicesSizes[dim] = getDim(builder, loc, indices(), dim);
+ }
+ SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
+ Value tiledIndices = getSlice(builder, loc, indices(), indicesOffsets,
+ indicesSizes, indicesStrides);
+ assert(tiledIndices && "failed to get slice of indices");
+
+ // Slice of the original.
+ // The leading (originalRank - updateRank + 1) dimensions are taken in full
+ // at offset 0. The remaining dimensions reuse the tile's offsets/sizes,
+ // shifted by (originalRank - updateRank) so that original dim `d` maps to
+ // tile dim `d - (originalRank - updateRank)`.
+ auto originalRank = getOriginalType().getRank();
+ SmallVector<OpFoldResult> originalOffsets(originalRank, zeroAttr);
+ SmallVector<OpFoldResult> originalSizes(originalRank);
+ for (auto dim : llvm::seq<int64_t>(0, originalRank - updateRank + 1)) {
+ originalSizes[dim] = getDim(builder, loc, original(), dim);
+ }
+ for (auto dim :
+ llvm::seq<int64_t>(originalRank - updateRank + 1, originalRank)) {
+ originalOffsets[dim] = offsets[dim - (originalRank - updateRank)];
+ originalSizes[dim] = sizes[dim - (originalRank - updateRank)];
+ }
+ SmallVector<OpFoldResult> originalStrides(originalRank, oneAttr);
+ Value tiledOriginal = getSlice(builder, loc, outputs[0], originalOffsets,
+ originalSizes, originalStrides);
+ assert(tiledOriginal && "failed to get slice of original tensor");
+
+ // The clone only has a result when this op has one; its type is the type of
+ // the sliced original.
+ SmallVector<Type> resultTypes;
+ if (getNumResults()) {
+ resultTypes.push_back(tiledOriginal.getType());
+ }
+ Operation *tiledScatterOp =
+ cast<LinalgExtOp>(getOperation())
+ .clone(builder, loc, resultTypes,
+ ValueRange{tiledUpdate, tiledIndices, tiledOriginal});
+ // Write each tiled result back into outputs[0] at the same offsets/sizes
+ // that were used to slice it.
+ for (auto result : llvm::enumerate(tiledScatterOp->getResults())) {
+ auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+ loc, result.value(), outputs[0], originalOffsets, originalSizes,
+ originalStrides);
+ results.push_back(insertSliceOp.getResult());
+ }
+ return tiledScatterOp;
+}
+
+/// Emits the scalar body of the scatter for the iteration point `ivs`.
+///
+/// Loads the update element at `ivs`, reads `indexDepth` i32 indices from row
+/// `ivs[0]` of the indices operand and casts them to `index`, then appends the
+/// remaining induction variables to form the destination coordinates. The
+/// region body is cloned with (update, original element) as arguments and the
+/// yielded value is stored back to the original operand.
+LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
+                                                      Location loc,
+                                                      ValueRange ivs) {
+  auto indexDepth = getIndexDepth();
+  Value updateElement = b.create<memref::LoadOp>(loc, updates(), ivs);
+
+  // Coordinates into the indices operand: (update number, index component).
+  SmallVector<Value> indicesCoords = {ivs.front(), Value()};
+  SmallVector<Value> destCoords;
+  for (auto component : llvm::seq<unsigned>(0, indexDepth)) {
+    indicesCoords.back() = b.create<arith::ConstantIndexOp>(loc, component);
+    Value rawIndex = b.create<memref::LoadOp>(loc, indices(), indicesCoords);
+    Value castIndex =
+        b.create<arith::IndexCastOp>(loc, b.getIndexType(), rawIndex);
+    destCoords.push_back(castIndex);
+  }
+  // All induction variables after the first address the update slice itself.
+  ValueRange sliceIvs = ivs.drop_front();
+  destCoords.append(sliceIvs.begin(), sliceIvs.end());
+  Value originalElement = b.create<memref::LoadOp>(loc, original(), destCoords);
+
+  Block &body = region().front();
+  BlockAndValueMapping mapping;
+  mapping.map(body.getArgument(0), updateElement);
+  mapping.map(body.getArgument(1), originalElement);
+  for (Operation &bodyOp : body.without_terminator())
+    b.clone(bodyOp, mapping);
+
+  // The terminator is the linalg_ext.yield op; store its (remapped) operand to
+  // the destination.
+  Value yielded = mapping.lookupOrDefault(body.getTerminator()->getOperand(0));
+  b.create<memref::StoreOp>(loc, yielded, original(), destCoords);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SortOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the operand and region invariants of a SortOp.
+///
+/// The op takes no inputs and at least one `outs` operand, and all operands
+/// share the same rank and shape. The region block takes two arguments per
+/// operand, each typed as that operand's element type, and yields exactly one
+/// i1 value. The sort dimension must be specified when rank > 1 and must lie
+/// in the half-open range [0, rank).
+static LogicalResult verifySortOp(SortOp op) {
+  if (op.getNumInputs()) {
+    return op.emitOpError("does not expect to take any inputs");
+  }
+  if (op.getNumOutputs() == 0) {
+    return op.emitOpError("expected at least one `outs` operand");
+  }
+
+  Block &block = op.region().front();
+  size_t numOutputs = op.getNumOutputs();
+  if (block.getNumArguments() != 2 * numOutputs) {
+    return op.emitOpError("region block should have ")
+           << 2 * numOutputs << " arguments";
+  }
+
+  int64_t rank = op.getOperandRank();
+  ArrayRef<int64_t> shape = op.getOperandShape();
+  if (rank > 1 && !op.dimensionAttr()) {
+    return op.emitOpError("dimension must be specified if rank > 1");
+  }
+  // Keep the dimension at the same width as `rank` so the attribute value is
+  // not narrowed to `int` before the range check.
+  int64_t dimension = 0;
+  if (op.dimensionAttr()) {
+    dimension = op.dimension().getValue();
+  }
+  if (dimension < 0 || dimension >= rank) {
+    // The accepted range is half-open: 0 is valid, `rank` is not.
+    return op.emitOpError("dimension must be within [0, ") << rank << ")";
+  }
+
+  for (auto indexedOperand : llvm::enumerate(op.outputs())) {
+    int index = indexedOperand.index();
+    auto operandType = op.getOperandType(index);
+    if (operandType.getRank() != rank) {
+      return op.emitOpError("expected operand ")
+             << index << " to be rank " << rank << ", same as other operands";
+    }
+    if (operandType.getShape() != shape) {
+      return op.emitOpError("expected operand ")
+             << index << " to have same shape as other operands";
+    }
+    // Operand #index owns block arguments 2*index and 2*index+1.
+    Type elemType = operandType.getElementType();
+    for (int i : {2 * index, 2 * index + 1}) {
+      Type argType = block.getArgument(i).getType();
+      if (argType != elemType) {
+        return op.emitOpError("region block argument #")
+               << i << " should be of type " << elemType << " but got "
+               << argType;
+      }
+    }
+  }
+
+  auto yieldOp = cast<YieldOp>(block.getTerminator());
+  if (yieldOp.getNumOperands() != 1) {
+    return op.emitOpError("should yield exactly one operand");
+  }
+  auto ty = yieldOp.getOperand(0).getType().dyn_cast<IntegerType>();
+  if (!ty || ty.getWidth() != 1) {
+    return op.emitOpError("should yield i1 type");
+  }
+
+  return success();
+}
+
+ /// Returns one iterator type per operand dimension: "parallel" everywhere
+ /// except the sorted dimension, which is marked as a reduction.
+ SmallVector<StringRef> SortOp::getLoopIteratorTypes() {
+   const int64_t rank = getOperandRank();
+   SmallVector<StringRef> kinds(rank, getParallelIteratorTypeName());
+   kinds[getSortedDimension()] = getReductionIteratorTypeName();
+   return kinds;
+ }
+
+ /// Builds the iteration domain of the op: for every dimension of operand 0 a
+ /// range starting at 0, with unit stride, whose size is that dimension's
+ /// extent.
+ SmallVector<Range> SortOp::getLoopBounds(OpBuilder &builder) {
+   Location loc = getLoc();
+   Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+   Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
+   Value firstOperand = operand(0);
+   const int64_t rank = getOperandRank();
+   SmallVector<Range> bounds;
+   bounds.reserve(rank);
+   for (int64_t dim = 0; dim < rank; ++dim) {
+     Value extent = getDimValue(builder, loc, firstOperand, dim);
+     bounds.push_back(Range{lowerBound, extent, step});
+   }
+   return bounds;
+ }
+
+ /// Returns the loops that may be partitioned: every dimension except the
+ /// sorted one, keeping only the last `maxNumParallelDims` of them.
+ SmallVector<unsigned> SortOp::getPartitionableLoops(
+     unsigned maxNumParallelDims) {
+   const auto rank = static_cast<unsigned>(getOperandRank());
+   const auto sortedDim = static_cast<unsigned>(getSortedDimension());
+   SmallVector<unsigned> loops;
+   for (unsigned dim = 0; dim < rank; ++dim) {
+     if (dim != sortedDim) loops.push_back(dim);
+   }
+   // Drop the outermost candidates when there are more than requested.
+   if (loops.size() > maxNumParallelDims) {
+     const size_t excess = loops.size() - maxNumParallelDims;
+     loops.erase(loops.begin(), loops.begin() + excess);
+   }
+   return loops;
+ }
+
+ /// Creates the tiled version of this op: slices every entry of `outputs` at
+ /// (`offsets`, `sizes`) with unit strides, clones the op onto those slices,
+ /// and appends to `results` one `tensor.insert_slice` per result of the clone
+ /// that writes the tiled result back into the matching entry of `outputs`.
+ /// Returns the cloned (tiled) operation.
+ Operation *SortOp::getTiledImplementation(OpBuilder &builder,
+                                           ValueRange outputs,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes,
+                                           SmallVectorImpl<Value> &results) {
+   assert(outputs.size() == this->outputs().size());
+   const int64_t rank = getOperandRank();
+   assert(offsets.size() == static_cast<size_t>(rank) &&
+          sizes.size() == static_cast<size_t>(rank));
+   Location loc = getLoc();
+   SmallVector<OpFoldResult> unitStrides(rank, builder.getI64IntegerAttr(1));
+
+   SmallVector<Value> slices;
+   slices.reserve(outputs.size());
+   for (Value output : outputs) {
+     Value slice = getSlice(builder, loc, output, offsets, sizes, unitStrides);
+     assert(slice && "failed to get slice of operand");
+     slices.push_back(slice);
+   }
+
+   // Result types are only needed when the op actually produces results.
+   SmallVector<Type, 4> resultTypes;
+   if (getNumResults()) {
+     for (Value slice : slices) resultTypes.push_back(slice.getType());
+   }
+
+   Operation *tiledSortOp = cast<LinalgExtOp>(getOperation())
+                                .clone(builder, loc, resultTypes, slices);
+   unsigned resultIdx = 0;
+   for (Value tiledResult : tiledSortOp->getResults()) {
+     auto writeBack = builder.create<tensor::InsertSliceOp>(
+         loc, tiledResult, outputs[resultIdx], offsets, sizes, unitStrides);
+     results.push_back(writeBack.getResult());
+     ++resultIdx;
+   }
+   return tiledSortOp;
+ }
+
+ /// Emits the scalar form of the sort for the iteration point `ivs`: an
+ /// `scf.for` over `[0, size(sortDim) - 1)` that, for every output, loads the
+ /// adjacent pair `(iv, iv + 1)` along the sorted dimension, evaluates a clone
+ /// of the comparator region on the loaded values, and swaps the pair in
+ /// memory when the comparator yields false.
+ LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc,
+                                                    ValueRange ivs) {
+   auto sortDim = getSortedDimension();
+   // Values loaded in the loop body, laid out as (lhs, rhs) per output, i.e. in
+   // the same order as the comparator's block arguments.
+   SmallVector<Value> sortBlkArgs;
+   // Bubble sort innermost loop.
+   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+   Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+   Value ub;
+   if (getOperandType(0).isDynamicDim(sortDim)) {
+     ub = b.create<memref::DimOp>(loc, operand(0), sortDim);
+   } else {
+     ub = b.create<arith::ConstantIndexOp>(
+         loc, getOperandType(0).getDimSize(sortDim));
+   }
+   // The loop compares element `iv` with `iv + 1`, so it stops one short.
+   ub = b.create<arith::SubIOp>(loc, ub, one);
+   auto scfFor = b.create<scf::ForOp>(
+       loc, zero, ub, one, ValueRange{},
+       [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) {
+         SmallVector<Value> indices(ivs);
+         Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
+         for (auto output : getOutputOperands()) {
+           indices[sortDim] = iv;
+           sortBlkArgs.push_back(
+               b.create<memref::LoadOp>(loc, output->get(), indices));
+           indices[sortDim] = ivPlusOne;
+           sortBlkArgs.push_back(
+               b.create<memref::LoadOp>(loc, output->get(), indices));
+         }
+       });
+
+   // Inline the comparator into the loop body, mapping its block arguments to
+   // the values loaded above.
+   auto &srcBlock = region().front();
+   Region &loopRegion = scfFor.region();
+   BlockAndValueMapping bvm;
+   {
+     OpBuilder::InsertionGuard guard(b);
+     auto &block = loopRegion.front();
+     b.setInsertionPointToEnd(&block);
+     for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) {
+       bvm.map(std::get<0>(it), std::get<1>(it));
+     }
+     for (auto &blockOp : srcBlock.without_terminator()) {
+       b.clone(blockOp, bvm);
+     }
+   }
+   // The comparator's yielded value decides whether the pair is in order.
+   Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0));
+
+   OpBuilder::InsertionGuard g(b);
+   b.setInsertionPointToEnd(&loopRegion.front());
+   b.create<scf::IfOp>(
+       loc, TypeRange{}, cond,
+       [&](OpBuilder &b, Location loc) {
+         // Do not swap the pairs if true.
+         b.create<scf::YieldOp>(loc);
+       },
+       [&](OpBuilder &b, Location loc) {
+         // Swap the pairs if false.
+         SmallVector<Value> indices(ivs.begin(), ivs.end());
+         Value ivPlusOne =
+             b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
+         for (int i = 0, e = getNumOutputs(); i < e; ++i) {
+           Value v1 = sortBlkArgs[i * 2];
+           Value v2 = sortBlkArgs[i * 2 + 1];
+           indices[sortDim] = scfFor.getInductionVar();
+           b.create<memref::StoreOp>(loc, v2, getOutputOperand(i)->get(),
+                                     indices);
+           indices[sortDim] = ivPlusOne;
+           b.create<memref::StoreOp>(loc, v1, getOutputOperand(i)->get(),
+                                     indices);
+         }
+         b.create<scf::YieldOp>(loc);
+       });
+   // Terminate the loop body, which was built without a terminator above.
+   b.create<scf::YieldOp>(loc);
+   return success();
+ }
+
+//===----------------------------------------------------------------------===//
+// FftOp
+//===----------------------------------------------------------------------===//
+
+ /// Verifies `FftOp`: a static FFT length must be a power of two, the first
+ /// input must be the scalar `stage`, any further inputs must be exactly two
+ /// non-scalar coefficient operands, and there must be exactly two outputs.
+ static LogicalResult verifyFftOp(FftOp op) {
+   auto fftLength = op.getFftLength();
+   // A dynamic length can show up after tiling (the slice ops do not infer the
+   // `1 << x` shape), in which case there is nothing more to check here.
+   if (fftLength == ShapedType::kDynamicSize) return success();
+   const bool isPowerOfTwo = (fftLength & (fftLength - 1)) == 0;
+   if (!isPowerOfTwo) {
+     return op.emitOpError("only powers of 2 are handled currently");
+   }
+
+   const auto numInputs = op.getNumInputs();
+   if (numInputs == 0 || !op.isScalar(op.getInputOperand(0))) {
+     return op.emitOpError("expected to carry `stage` input");
+   }
+   if (numInputs != 1) {
+     const bool hasCoeffBuffers = numInputs == 3 &&
+                                  !op.isScalar(op.getInputOperand(1)) &&
+                                  !op.isScalar(op.getInputOperand(2));
+     if (!hasCoeffBuffers) {
+       return op.emitOpError("expected to carry real and imag coeff inputs");
+     }
+   }
+   if (op.getNumOutputs() != 2) {
+     return op.emitOpError("expected outputs to be real and imag tensor/memref");
+   }
+   return success();
+ }
+
+ /// Every loop of the op is parallel: the `rank - 1` outer loops plus the
+ /// innermost loop, which performs one stage's merge step (combining two
+ /// half-size tensors into one).
+ SmallVector<StringRef> FftOp::getLoopIteratorTypes() {
+   return SmallVector<StringRef>(getOperandRank(),
+                                 getParallelIteratorTypeName());
+ }
+
+ /// Builds the iteration domain: unit-stride ranges over all but the last
+ /// dimension, followed by a range over the last dimension whose stride is
+ /// `1 << stage`.
+ SmallVector<Range> FftOp::getLoopBounds(OpBuilder &builder) {
+   Location loc = getLoc();
+   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+   Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+   SmallVector<Range> bounds;
+   ArrayRef<int64_t> outerShape = getOperandShape().drop_back();
+   for (int64_t dim = 0, e = outerShape.size(); dim < e; ++dim) {
+     Value extent;
+     if (outerShape[dim] != ShapedType::kDynamicSize) {
+       extent = builder.create<arith::ConstantIndexOp>(loc, outerShape[dim]);
+     } else {
+       extent = getDimValue(builder, loc, getReal(), dim);
+     }
+     bounds.emplace_back(Range{/*offset=*/zero, extent, /*stride=*/one});
+   }
+
+   // The innermost loop steps over whole blocks of `1 << stage` elements.
+   Value innerExtent =
+       getDimValue(builder, loc, getReal(), getOperandRank() - 1);
+   Value innerStride = builder.create<arith::ShLIOp>(loc, one, getStage());
+   bounds.emplace_back(
+       Range{/*offset=*/zero, innerExtent, /*stride=*/innerStride});
+   return bounds;
+ }
+
+ /// Emits the butterfly shared by both scalar FFT lowerings and terminates the
+ /// enclosing `linalg.generic` body.
+ ///
+ /// With w = wReal + wImag*i, lhs = a[k + j] and rhs = a[k + j + mh]:
+ ///   t = w * rhs
+ ///   a[k + j]      = lhs + t
+ ///   a[k + j + mh] = lhs - t
+ /// The yielded values are (real(lhs + t), imag(lhs + t), real(lhs - t),
+ /// imag(lhs - t)).
+ static void emitFftButterflyAndYield(OpBuilder &b, Location loc, Value wReal,
+                                      Value wImag, Value lhsReal, Value lhsImag,
+                                      Value rhsReal, Value rhsImag) {
+   // t = w * a[k + j + mh];
+   // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
+   Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal);
+   Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag);
+   Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag);
+   Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal);
+   Value tReal = b.create<arith::SubFOp>(loc, xu, yv);
+   Value tImag = b.create<arith::AddFOp>(loc, xv, yu);
+
+   // cplx u = a[k + j];
+   // a[k + j] = u + t;
+   // a[k + j + mh] = u - t;
+   Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal);
+   Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag);
+   Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal);
+   Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag);
+   b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
+ }
+
+ /// Emits the merge step as a `linalg.generic` over `operands` (lhs real, lhs
+ /// imag, rhs real, rhs imag views) when no coefficient buffers are provided.
+ /// The twiddle factor is computed inline as cos/sin of
+ /// `-2 * PI / wholeSize * j`, where `j` is the index along the last
+ /// dimension.
+ void FftOp::generateScalarImplWithoutCoeffBuf(OpBuilder &b, Location loc,
+                                               ArrayRef<Value> operands,
+                                               Value wholeSize) {
+   auto rank = getOperandRank();
+   SmallVector<AffineMap> maps(operands.size(), b.getMultiDimIdentityMap(rank));
+
+   auto f32Type = b.getF32Type();
+   // Converts an index value to f32 by going through i32.
+   auto indexToF32 = [](OpBuilder &builder, Location loc, Value v) -> Value {
+     v = builder.create<arith::IndexCastOp>(loc, builder.getI32Type(), v);
+     return builder.create<arith::SIToFPOp>(loc, builder.getF32Type(), v);
+   };
+
+   // We will need exp(-2 * PI * j / m * I), compute "-2 * PI / m" for imag part
+   // first.
+   Value coeff = b.create<arith::ConstantFloatOp>(
+       loc, llvm::APFloat(static_cast<float>(-2 * acos(-1))), f32Type);
+   coeff = b.create<arith::DivFOp>(loc, coeff, indexToF32(b, loc, wholeSize));
+
+   b.create<linalg::GenericOp>(
+       loc, TypeRange{}, ValueRange{}, operands, maps, getLoopIteratorTypes(),
+       [&](OpBuilder &b, Location loc, ValueRange args) {
+         Value lhsReal = args[0];
+         Value lhsImag = args[1];
+         Value rhsReal = args[2];
+         Value rhsImag = args[3];
+
+         // Compute "-2 * PI / m * j"
+         Value w = b.create<arith::MulFOp>(
+             loc, coeff,
+             indexToF32(b, loc, b.create<linalg::IndexOp>(loc, rank - 1)));
+         Value wReal = b.create<math::CosOp>(loc, w);
+         Value wImag = b.create<math::SinOp>(loc, w);
+
+         emitFftButterflyAndYield(b, loc, wReal, wImag, lhsReal, lhsImag,
+                                  rhsReal, rhsImag);
+       });
+ }
+
+ /// Emits the merge step as a `linalg.generic` over `operands` (lhs real, lhs
+ /// imag, rhs real, rhs imag views) reading the twiddle factors from the
+ /// op's real/imag coefficient buffers, which are indexed by the last loop
+ /// dimension.
+ void FftOp::generateScalarImplWithCoeffBuf(OpBuilder &b, Location loc,
+                                            ArrayRef<Value> operands) {
+   auto rank = getOperandRank();
+   SmallVector<AffineMap> maps;
+   // The size of the coefficient buffer is expected to match `2^(stage-1)`,
+   // which equals the last dim of operands.
+   maps.append(
+       2, AffineMap::get(rank, 0, b.getAffineDimExpr(rank - 1), b.getContext()));
+   maps.append(operands.size(), b.getMultiDimIdentityMap(rank));
+
+   b.create<linalg::GenericOp>(
+       loc, TypeRange{}, ValueRange{getRealCoeff(), getImagCoeff()}, operands,
+       maps, getLoopIteratorTypes(),
+       [&](OpBuilder &b, Location loc, ValueRange args) {
+         Value wReal = args[0];
+         Value wImag = args[1];
+         Value lhsReal = args[2];
+         Value lhsImag = args[3];
+         Value rhsReal = args[4];
+         Value rhsImag = args[5];
+
+         emitFftButterflyAndYield(b, loc, wReal, wImag, lhsReal, lhsImag,
+                                  rhsReal, rhsImag);
+       });
+ }
+
+ // Emits the scalar implementation of one FFT stage, following the
+ // Cooley–Tukey algorithm. In pseudo code, with `s` the stage of
+ // linalg_ext.fft:
+ //   int m = 1 << s;
+ //   int mh = m >> 1;
+ //   for (int k = 0; k < n; k += m) {
+ //     for (int j = 0; j < mh; ++j) {
+ //       cplx w = exp(-2 * PI * j / m * I);
+ //       cplx t = w * a[k + j + mh];
+ //       cplx u = a[k + j];
+ //       a[k + j] = u + t;
+ //       a[k + j + mh] = u - t;
+ //     }
+ //   }
+ // For the point `ivs` this takes `mh`-wide views of the real and imaginary
+ // buffers at `ivs` and at `ivs` shifted by `mh` along the last dimension, and
+ // hands them to the with/without coefficient-buffer implementation.
+ LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc,
+                                                   ValueRange ivs) {
+   Value realBuf = getReal();
+   Value imagBuf = getImag();
+   Value c1 = b.create<arith::ConstantIndexOp>(loc, 1);
+   Value m = b.create<arith::ShLIOp>(loc, c1, getStage());
+   Value mh = b.create<arith::ShRSIOp>(loc, m, c1);
+
+   const auto rank = getOperandRank();
+   SmallVector<OpFoldResult> unitStrides(rank, b.getIndexAttr(1));
+   // Views are 1 x ... x 1 x mh.
+   SmallVector<OpFoldResult> viewSizes(rank, b.getIndexAttr(1));
+   viewSizes.back() = mh;
+
+   SmallVector<Value> operands;
+   // Appends the real view followed by the imaginary view at `viewOffsets`.
+   auto appendViews = [&](ArrayRef<OpFoldResult> viewOffsets) {
+     for (Value buffer : {realBuf, imagBuf}) {
+       operands.push_back(b.create<memref::SubViewOp>(
+           loc, buffer, viewOffsets, viewSizes, unitStrides));
+     }
+   };
+
+   SmallVector<OpFoldResult> lowerOffsets(ivs.begin(), ivs.end());
+   appendViews(lowerOffsets);
+
+   SmallVector<OpFoldResult> upperOffsets(ivs.begin(), ivs.end());
+   upperOffsets.back() =
+       b.create<arith::AddIOp>(loc, ivs.back(), mh).getResult();
+   appendViews(upperOffsets);
+
+   if (!hasCoeff()) {
+     generateScalarImplWithoutCoeffBuf(b, loc, operands, m);
+   } else {
+     generateScalarImplWithCoeffBuf(b, loc, operands);
+   }
+
+   return success();
+ }
+
+ /// Always reports that the payload does not use the operand's value.
+ bool FftOp::payloadUsesValueFromOperand(OpOperand * /*opOperand*/) {
+   return false;
+ }
+
+ /// Returns the loops that may be partitioned: all dimensions, minus the last
+ /// one when no coefficient buffers are present, keeping only the last
+ /// `maxNumParallelDims` of them.
+ SmallVector<unsigned> FftOp::getPartitionableLoops(
+     unsigned maxNumParallelDims) {
+   SmallVector<unsigned> loops;
+   const auto rank = static_cast<unsigned>(getOperandRank());
+   for (unsigned dim = 0; dim < rank; ++dim) {
+     loops.push_back(dim);
+   }
+   // Indices matter for coeff computation.
+   if (!hasCoeff()) {
+     loops.pop_back();
+   }
+   if (loops.size() > maxNumParallelDims) {
+     const size_t excess = loops.size() - maxNumParallelDims;
+     loops.erase(loops.begin(), loops.begin() + excess);
+   }
+   return loops;
+ }
+
+ /// Creates the tiled version of this op. The stage and coefficient operands
+ /// are passed through untouched; every entry of `outputs` is sliced at
+ /// (`offsets`, `sizes`) with unit strides. One `tensor.insert_slice` per
+ /// result of the tiled op is appended to `results`. Returns the tiled op.
+ Operation *FftOp::getTiledImplementation(OpBuilder &builder, ValueRange outputs,
+                                          ArrayRef<OpFoldResult> offsets,
+                                          ArrayRef<OpFoldResult> sizes,
+                                          SmallVectorImpl<Value> &results) {
+   Location loc = getLoc();
+   const int64_t rank = getOperandRank();
+   SmallVector<OpFoldResult> unitStrides(rank, builder.getI64IntegerAttr(1));
+
+   // Inputs first: stage, real coefficients, imaginary coefficients.
+   SmallVector<Value> tiledOperands;
+   tiledOperands.push_back(getStage());
+   tiledOperands.push_back(getRealCoeff());
+   tiledOperands.push_back(getImagCoeff());
+
+   SmallVector<Type, 4> resultTypes;
+   for (Value output : outputs) {
+     Value slice = getSlice(builder, loc, output, offsets, sizes, unitStrides);
+     tiledOperands.push_back(slice);
+     if (hasTensorSemantics()) {
+       resultTypes.push_back(slice.getType());
+     }
+   }
+
+   Operation *tiledFftOp = cast<LinalgExtOp>(getOperation())
+                               .clone(builder, loc, resultTypes, tiledOperands);
+   unsigned resultIdx = 0;
+   for (Value tiledResult : tiledFftOp->getResults()) {
+     auto writeBack = builder.create<tensor::InsertSliceOp>(
+         loc, tiledResult, outputs[resultIdx], offsets, sizes, unitStrides);
+     results.push_back(writeBack.getResult());
+     ++resultIdx;
+   }
+   return tiledFftOp;
+ }
+
+//===----------------------------------------------------------------------===//
+// ReverseOp
+//===----------------------------------------------------------------------===//
+
+ /// Verifies `ReverseOp`: exactly one input and one output with identical
+ /// element types and ranks, statically known dimensions that agree, and a
+ /// list of reversal dimensions that are in range and unique.
+ static LogicalResult verifyReverseOp(ReverseOp op) {
+   if (op.getNumInputs() != 1) {
+     return op.emitOpError("expected exactly one input");
+   }
+   if (op.getNumOutputs() != 1) {
+     return op.emitOpError("expected exactly one output");
+   }
+
+   auto inputType = op.input().getType().cast<ShapedType>();
+   auto outputType = op.output().getType().cast<ShapedType>();
+   if (inputType.getElementType() != outputType.getElementType()) {
+     return op.emitOpError(
+         "expected input/output element types to be identical");
+   }
+
+   ArrayRef<int64_t> inputShapes = inputType.getShape();
+   ArrayRef<int64_t> outputShapes = outputType.getShape();
+   if (inputShapes.size() != outputShapes.size()) {
+     return op.emitOpError("expexted input/output to have identical ranks");
+   }
+   // Dimensions only conflict when both sides are static and differ.
+   for (auto dimPair : llvm::zip(inputShapes, outputShapes)) {
+     int64_t inDim = std::get<0>(dimPair);
+     int64_t outDim = std::get<1>(dimPair);
+     if (inDim == ShapedType::kDynamicSize ||
+         outDim == ShapedType::kDynamicSize) {
+       continue;
+     }
+     if (inDim != outDim) {
+       return op.emitOpError("incompatible input/output shapes");
+     }
+   }
+
+   int64_t rank = op.getOperandRank();
+   llvm::SmallSetVector<int64_t, 4> seenDims;
+   for (auto dim : op.dims()) {
+     if (dim < 0 || dim >= rank) {
+       return op.emitOpError("all the dimensions must be within [0, ")
+              << rank << ")";
+     }
+     // `insert` returns false when the dimension was already recorded.
+     if (!seenDims.insert(dim)) {
+       return op.emitOpError("expected dimensions numbers are all unique");
+     }
+   }
+
+   return success();
+ }
+
+ /// Always reports that the payload does not use the operand's value.
+ bool ReverseOp::payloadUsesValueFromOperand(OpOperand * /*opOperand*/) {
+   return false;
+ }
+
+ /// Every loop of the op (one per operand dimension) is parallel.
+ SmallVector<StringRef> ReverseOp::getLoopIteratorTypes() {
+   return SmallVector<StringRef>(getOperandRank(),
+                                 getParallelIteratorTypeName());
+ }
+
+ /// Builds the iteration domain: for every dimension of the input a range
+ /// starting at 0, with unit stride, sized by that dimension's extent.
+ SmallVector<Range> ReverseOp::getLoopBounds(OpBuilder &builder) {
+   Location loc = getLoc();
+   Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+   Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
+   SmallVector<Range> bounds;
+   const int64_t rank = getOperandRank();
+   for (int64_t dim = 0; dim < rank; ++dim) {
+     Value extent = getDimValue(builder, loc, input(), dim);
+     bounds.emplace_back(Range{lowerBound, extent, step});
+   }
+   return bounds;
+ }
+
+ /// Emits the scalar body for the point `ivs`: loads the input element at
+ /// `ivs` and stores it to the output at the mirrored position, where every
+ /// reversed dimension `d` uses index `size(d) - 1 - ivs[d]`.
+ LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b,
+                                                       Location loc,
+                                                       ValueRange ivs) {
+   SmallVector<Value> mirrored(ivs.begin(), ivs.end());
+   for (auto dim : dims()) {
+     Value extent = getDimValue(b, loc, input(), dim);
+     Value c1 = b.create<arith::ConstantIndexOp>(loc, 1);
+     Value lastIndex = b.create<arith::SubIOp>(loc, extent, c1);
+     mirrored[dim] = b.create<arith::SubIOp>(loc, lastIndex, mirrored[dim]);
+   }
+   Value element = b.create<memref::LoadOp>(loc, input(), ivs);
+   b.create<memref::StoreOp>(loc, element, output(), mirrored);
+   return success();
+ }
+
+ /// Creates the tiled version of this op. The input is sliced at (`offsets`,
+ /// `sizes`); the output is sliced at the mirrored offsets, where every
+ /// reversed dimension uses `dimSize - offset - tileSize`. For each result of
+ /// the tiled op a `tensor.insert_slice` into the matching entry of `outputs`
+ /// (at the mirrored offsets) is appended to `results`. Returns the tiled op.
+ Operation *ReverseOp::getTiledImplementation(OpBuilder &builder,
+                                              ValueRange outputs,
+                                              ArrayRef<OpFoldResult> offsets,
+                                              ArrayRef<OpFoldResult> sizes,
+                                              SmallVectorImpl<Value> &results) {
+   Location loc = getLoc();
+   const int64_t rank = getOperandRank();
+   SmallVector<OpFoldResult> unitStrides(rank, builder.getI64IntegerAttr(1));
+
+   SmallVector<Value> tiledOperands;
+   tiledOperands.emplace_back(
+       getSlice(builder, loc, input(), offsets, sizes, unitStrides));
+
+   // mirrored offset = dimSize - offset - tileSize
+   AffineExpr sym0, sym1, sym2;
+   bindSymbols(builder.getContext(), sym0, sym1, sym2);
+   AffineMap mirrorMap =
+       AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {sym0 - sym1 - sym2});
+   SmallVector<OpFoldResult> mirrorOffsets(offsets.begin(), offsets.end());
+   for (auto dim : dims()) {
+     Value dimSize = getDimValue(builder, loc, input(), dim);
+     Value tileOffset =
+         getValueOrCreateConstantIndexOp(builder, loc, mirrorOffsets[dim]);
+     Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]);
+     auto mirrored = builder.create<AffineApplyOp>(
+         loc, mirrorMap, ValueRange{dimSize, tileOffset, tileSize});
+     mirrorOffsets[dim] = mirrored.getResult();
+   }
+
+   // The output slice is taken the same way for tensors and buffers; only the
+   // tensor form contributes a result type.
+   tiledOperands.emplace_back(
+       getSlice(builder, loc, output(), mirrorOffsets, sizes, unitStrides));
+   SmallVector<Type, 4> resultTypes;
+   if (hasTensorSemantics()) {
+     resultTypes.push_back(tiledOperands[1].getType());
+   }
+
+   Operation *tiledRevOp = cast<LinalgExtOp>(getOperation())
+                               .clone(builder, loc, resultTypes, tiledOperands);
+
+   unsigned resultIdx = 0;
+   for (Value tiledResult : tiledRevOp->getResults()) {
+     auto writeBack = builder.create<tensor::InsertSliceOp>(
+         loc, tiledResult, outputs[resultIdx], mirrorOffsets, sizes,
+         unitStrides);
+     results.push_back(writeBack.getResult());
+     ++resultIdx;
+   }
+   return tiledRevOp;
+ }
+
+ // Defines `OP_NAME::getEffects` for a LinalgExt op: it collects the op's
+ // input and output buffer operands and forwards them, together with the op's
+ // results, to `getEffectsImpl`, which fills in `effects`.
+ #define DEFINE_OP_GET_EFFECTS(OP_NAME) \
+ void OP_NAME::getEffects( \
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
+ &effects) { \
+ SmallVector<Value> inputBuffers = getInputBufferOperands(); \
+ SmallVector<Value> outputBuffers = getOutputBufferOperands(); \
+ getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \
+ outputBuffers); \
+ }
+
+ // Instantiate `getEffects` for each op defined in this file.
+ DEFINE_OP_GET_EFFECTS(ScatterOp)
+ DEFINE_OP_GET_EFFECTS(SortOp)
+ DEFINE_OP_GET_EFFECTS(FftOp)
+ DEFINE_OP_GET_EFFECTS(ReverseOp)
+
+ namespace {
+ /// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
+ /// changes.
+ ///
+ /// Canonicalization pattern for any `LinalgExtOp`: when an operand is produced
+ /// by a `tensor.cast` that `canFoldIntoConsumerOp` accepts, the op is cloned
+ /// onto the cast's source instead. Results whose type changes as a
+ /// consequence are cast back to the original result type so existing users
+ /// are unaffected.
+ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgExtOp> {
+ using OpInterfaceRewritePattern<LinalgExtOp>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(LinalgExtOp op,
+ PatternRewriter &rewriter) const override {
+ // If no operand comes from a tensor::CastOp and can be folded then fail.
+ bool hasTensorCastOperand =
+ llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
+ // Block arguments have no defining op, so there is nothing to fold.
+ if (opOperand->get().isa<BlockArgument>()) return false;
+ auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+ return castOp && canFoldIntoConsumerOp(castOp);
+ });
+ if (!hasTensorCastOperand) return failure();
+
+ SmallVector<Type, 4> newResultTypes;
+ newResultTypes.reserve(op->getNumResults());
+ SmallVector<Value, 4> newOperands;
+ newOperands.reserve(op->getNumOperands());
+ // Inputs may fold.
+ for (OpOperand *opOperand : op.getInputOperands()) {
+ auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+ newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
+ ? tensorCastOp.source()
+ : opOperand->get());
+ }
+ // Init tensors may fold, in which case the resultType must also change.
+ for (OpOperand *opOperand : op.getOutputOperands()) {
+ auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+ bool fold = canFoldIntoConsumerOp(tensorCastOp);
+ newOperands.push_back(fold ? tensorCastOp.getOperand()
+ : opOperand->get());
+ // One result type is recorded per output operand, taken from the
+ // (possibly folded) operand just pushed.
+ newResultTypes.push_back(newOperands.back().getType());
+ }
+ // Clone op.
+ Operation *newOp =
+ op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
+ SmallVector<Value, 4> replacements;
+ replacements.reserve(newOp->getNumResults());
+ for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
+ Value oldResult = std::get<0>(result);
+ Value newResult = std::get<1>(result);
+ // Re-introduce a cast when folding changed the result type, so the
+ // replacement value has the type users of the old result expect.
+ if (newResult.getType() != oldResult.getType()) {
+ replacements.push_back(rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult));
+ } else {
+ replacements.push_back(newResult);
+ }
+ }
+ rewriter.replaceOp(op, replacements);
+
+ return success();
+ }
+ };
+ } // namespace
+
+//===----------------------------------------------------------------------===//
+// LinalgExtDialect
+//===----------------------------------------------------------------------===//
+
+ /// Registers the dialect-wide canonicalization patterns; currently this is
+ /// only `FoldTensorCastOp`.
+ void IREELinalgExtDialect::getCanonicalizationPatterns(
+     RewritePatternSet &results) const {
+   MLIRContext *context = getContext();
+   results.add<FoldTensorCastOp>(context);
+ }
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp
new file mode 100644
index 0000000..105c394
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -0,0 +1,310 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.h"
+
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+#define DEBUG_TYPE "iree-tiled-op-interface"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+#include "iree-dialects/Dialect/LinalgExt/IR/TiledOpInterface.cpp.inc"
+
+ /// Materializes an `OpFoldResult` as a `Value`: an attribute (which must be
+ /// an `IntegerAttr`) becomes a new index constant op, while an existing
+ /// `Value` is returned unchanged.
+ static Value getValue(OpBuilder &builder, Location loc,
+                       OpFoldResult valueOrAttr) {
+   Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+   if (!attr) {
+     return valueOrAttr.get<Value>();
+   }
+   int64_t constant = attr.cast<IntegerAttr>().getInt();
+   return builder.create<arith::ConstantIndexOp>(loc, constant);
+ }
+
+//===----------------------------------------------------------------------===//
+// Interface implementations for external operations.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+ /// External model for `tensor.extract_slice`.
+ ///
+ /// The iteration space is the shape of the op's result. Each tile is produced
+ /// by a smaller `tensor.extract_slice` of the original source, which is then
+ /// inserted into the destination tensor.
+ struct ExtractSliceTiledOpInterface
+     : public TiledOpInterface::ExternalModel<ExtractSliceTiledOpInterface,
+                                              tensor::ExtractSliceOp> {
+   /// Returns a freshly created `init_tensor` with the reified shape and
+   /// element type of the result, to serve as the destination.
+   SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+     // No operand of `tensor.extract_slice` serves as a destination operand. So
+     // create an `init_tensor` op of the same size as the result.
+     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+     ReifiedRankedShapedTypeDims returnShape;
+     (void)extractSliceOp.reifyResultShapes(b, returnShape);
+     auto ofrShape = llvm::to_vector<4>(llvm::map_range(
+         returnShape[0], [](Value v) { return getAsOpFoldResult(v); }));
+     Value initTensor = b.create<linalg::InitTensorOp>(
+         op->getLoc(), ofrShape, extractSliceOp.getType().getElementType());
+     return {initTensor};
+   }
+
+   /// All loops (one per result dimension) are parallel.
+   SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+     return SmallVector<StringRef>(extractSliceOp.getType().getRank(),
+                                   getParallelIteratorTypeName());
+   }
+
+   /// Returns one `[0, dim, 1)` range per dimension of the reified result
+   /// shape.
+   SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+     ReifiedRankedShapedTypeDims returnShape;
+     (void)extractSliceOp.reifyResultShapes(b, returnShape);
+     Location loc = op->getLoc();
+     Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+     Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+     SmallVector<Range> loopRanges(returnShape[0].size(),
+                                   Range{zero, nullptr, one});
+     for (auto ub : enumerate(returnShape[0])) {
+       loopRanges[ub.index()].size = ub.value();
+     }
+     return loopRanges;
+   }
+
+   /// Creates the tiled `tensor.extract_slice` for the tile described by
+   /// `offsets`/`sizes` (expressed in result dimensions), inserts it into
+   /// `outputs[0]` at that tile, appends the inserted tensor to `results`, and
+   /// returns the tiled extract op.
+   Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+                                     ValueRange outputs,
+                                     ArrayRef<OpFoldResult> offsets,
+                                     ArrayRef<OpFoldResult> sizes,
+                                     SmallVectorImpl<Value> &results) const {
+     auto extractOp = cast<tensor::ExtractSliceOp>(op);
+     Location loc = extractOp.getLoc();
+     auto oneAttr = b.getI64IntegerAttr(1);
+
+     // Compute the offset and sizes for the tiled `tensor.extract_slice`
+     // operation.
+     llvm::SmallDenseSet<unsigned> droppedDims = extractOp.getDroppedDims();
+     unsigned resultDimPos = 0;
+     auto opOffsets = extractOp.getMixedOffsets();
+     auto opSizes = extractOp.getMixedSizes();
+     auto opStrides = extractOp.getMixedStrides();
+     MLIRContext *context = b.getContext();
+
+     // Source offset of a tile = tileOffset * opStride + opOffset. Strides of
+     // the original op need not be 1: they are folded into the offset here and
+     // reused unchanged as the strides of the tiled op.
+     AffineExpr d0, s0, s1;
+     bindDims(context, d0);
+     bindSymbols(context, s0, s1);
+     AffineMap map = AffineMap::get(1, 2, d0 * s0 + s1);
+
+     SmallVector<OpFoldResult> newOffset, newSizes, newStrides;
+     for (auto opOffset : enumerate(opOffsets)) {
+       // If the dimension is dropped, use the same offset.
+       if (droppedDims.count(opOffset.index())) {
+         newOffset.push_back(opOffset.value());
+         newSizes.push_back(opSizes[opOffset.index()]);
+       } else {
+         SmallVector<Value> operands = {
+             getValue(b, loc, offsets[resultDimPos]),
+             getValue(b, loc, opStrides[opOffset.index()]),
+             getValue(b, loc, opOffset.value())};
+         Value offset = b.create<AffineApplyOp>(loc, map, operands);
+         newOffset.push_back(offset);
+         newSizes.push_back(sizes[resultDimPos]);
+         // Only non-dropped source dimensions consume a result dimension.
+         resultDimPos++;
+       }
+       newStrides.push_back(opStrides[opOffset.index()]);
+     }
+
+     // Generate the tiled `tensor.extract_slice` operation.
+     Type resultType = tensor::ExtractSliceOp::inferRankReducedResultType(
+         extractOp.getType().getRank(), extractOp.getSourceType(), newOffset,
+         newSizes, newStrides);
+     auto tiledExtractOp = b.create<tensor::ExtractSliceOp>(
+         loc, resultType.cast<RankedTensorType>(), extractOp.source(), newOffset,
+         newSizes, newStrides);
+
+     // Insert the tiled extract into the result tensor.
+     SmallVector<OpFoldResult> resultStrides(offsets.size(), oneAttr);
+     auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+         loc, tiledExtractOp.result(), outputs[0], offsets, sizes,
+         resultStrides);
+     results.push_back(tiledInsertOp.result());
+     return tiledExtractOp;
+   }
+ };
+
+ /// External model for `tensor.insert_slice`.
+ ///
+ /// The iteration space is the shape of the inserted source tensor. Each tile
+ /// extracts a slice of the source and inserts it into the destination at the
+ /// op's own offsets shifted by the tile offsets.
+ struct InsertSliceTiledOpInterface
+ : public TiledOpInterface::ExternalModel<InsertSliceTiledOpInterface,
+ tensor::InsertSliceOp> {
+ /// The destination operand is the `dest` of the `tensor.insert_slice`.
+ SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+ SmallVector<Value> dest;
+ dest.push_back(cast<tensor::InsertSliceOp>(op).dest());
+ return dest;
+ }
+
+ /// All loops (one per source dimension) are parallel.
+ SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ return SmallVector<StringRef>(insertSliceOp.getSourceType().getRank(),
+ getParallelIteratorTypeName());
+ }
+
+ /// Returns one `[0, dim(source, i), 1)` range per source dimension.
+ SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ Value source = insertSliceOp.source();
+ RankedTensorType sourceType = insertSliceOp.getSourceType();
+ Location loc = op->getLoc();
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ SmallVector<Range> loopBounds(sourceType.getRank(),
+ Range{zero, nullptr, one});
+ for (auto dim :
+ llvm::seq<int64_t>(0, insertSliceOp.getSourceType().getRank())) {
+ loopBounds[dim].size = b.create<tensor::DimOp>(loc, source, dim);
+ }
+ return loopBounds;
+ }
+
+ /// Creates the tiled implementation for the tile described by
+ /// `offsets`/`sizes` (expressed in source dimensions): a
+ /// `tensor.extract_slice` of the source followed by a `tensor.insert_slice`
+ /// into `outputs[0]`. The inserted tensor is appended to `results` and the
+ /// extract op is returned. Emits an error and returns nullptr when the op has
+ /// a stride that is not the constant 1.
+ Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+ ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) const {
+ auto insertOp = cast<tensor::InsertSliceOp>(op);
+ // Compute a subtensor of the source based on the offsets.
+ auto opStrides = insertOp.getMixedStrides();
+ if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
+ Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
+ return intVal && *intVal == 1;
+ })) {
+ op->emitOpError("unable to tile operation with non-unit stride");
+ return nullptr;
+ }
+ MLIRContext *context = b.getContext();
+ Location loc = insertOp.getLoc();
+ auto oneAttr = b.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
+ auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
+ loc, insertOp.source(), offsets, sizes, strides);
+
+ // The offsets for the insert is based on the op offsets plus the offsets of
+ // the loops passed in.
+ auto opOffsets = insertOp.getMixedOffsets();
+ auto opSizes = insertOp.getMixedSizes();
+ // Position in the source/tile dimensions; it advances only for destination
+ // dimensions that are not treated as rank-reduced.
+ unsigned offsetIndex = 0;
+ ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
+ int64_t destRank = insertOp.getType().getRank();
+ SmallVector<OpFoldResult> resultOffsets(destRank);
+ SmallVector<OpFoldResult> resultSizes(destRank);
+ for (auto opOffset : llvm::enumerate(opOffsets)) {
+ // Check for rank-reducing by checking that
+ // 1) The corresponding opSize value is 1
+ // 2) The current rank of the source is not 1.
+ // Then the opOffset is for the rank-reduced dimension. Skip.
+ unsigned opOffsetIndex = opOffset.index();
+ OpFoldResult opOffsetVal = opOffset.value();
+ Optional<int64_t> opSizeVal = getConstantIntValue(opSizes[opOffsetIndex]);
+ // The same path is taken once all source dimensions have been consumed.
+ if (offsetIndex >= sourceShape.size() ||
+ (opSizeVal && *opSizeVal == 1 && sourceShape[offsetIndex] != 1)) {
+ resultOffsets[opOffsetIndex] = opOffsetVal;
+ resultSizes[opOffsetIndex] = oneAttr;
+ continue;
+ }
+ OpFoldResult offset = offsets[offsetIndex];
+ // Fold the sum into an attribute when both offsets are attributes;
+ // otherwise emit an affine.apply computing `offset + opOffset`.
+ if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
+ resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
+ *getConstantIntValue(opOffsetVal) + *getConstantIntValue(offset));
+ } else {
+ AffineExpr d0, s0;
+ bindDims(context, d0);
+ bindSymbols(context, s0);
+ AffineMap map = AffineMap::get(1, 1, d0 + s0);
+ SmallVector<Value> operands = {getValue(b, loc, offset),
+ getValue(b, loc, opOffsetVal)};
+ resultOffsets[opOffsetIndex] =
+ b.create<AffineApplyOp>(loc, map, operands).getResult();
+ }
+ resultSizes[opOffsetIndex] = sizes[offsetIndex];
+ offsetIndex++;
+ }
+ SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
+ auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+ loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes,
+ resultStrides);
+ results.push_back(tiledInsertOp.result());
+ return extractSliceOp;
+ }
+ };
+
+/// Forwards the implementation of `TiledOpInterface` to the upstream
+/// `TilingInterface`. Note that this forwarding is only valid when the
+/// iteration space is the same as the data space of the result(s). This is
+/// due to the difference between the tiling algorithm being developed around
+/// `TilingInterface` and the one used with `TiledOpInterface`: the former
+/// only needs the tiled operation, and not the value of the whole tensor,
+/// so the tiled results are inserted back into `dest` here.
+template <typename OpTy>
+struct ForwardToTilingInterface
+ : public TiledOpInterface::ExternalModel<ForwardToTilingInterface<OpTy>,
+ OpTy> {
+ SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+ return cast<OpTy>(op).getDestinationOperands(b);
+ }
+
+ SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+ return cast<OpTy>(op).getLoopIteratorTypes();
+ }
+ SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+ return cast<OpTy>(op).getLoopBounds(b);
+ }
+ Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+ ValueRange dest,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) const {
+ Operation *tiledOp =
+ cast<OpTy>(op).getTiledImplementation(b, dest, offsets, sizes);
+ if (!tiledOp) {
+ op->emitOpError("failed to tile operation");
+ return nullptr;
+ }
+ if (tiledOp->getNumResults() != dest.size()) {  // need one destination per tiled result
+ op->emitOpError(
+ "mismatch in the number of results of the tiled operation and the "
+ "number of results expected");
+ return nullptr;
+ }
+ Location loc = op->getLoc();
+ auto oneAttr = b.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
+ for (auto result : llvm::enumerate(tiledOp->getResults())) {
+ // Assume that the shape of the result is the same as the loop bounds of
+ // the op. This implies the result can be inserted into `dest` at
+ // `offsets` and `sizes`; the insert would be illegal if that is not the
+ // case. This is a point of difference between the `TiledOpInterface` in
+ // IREE and `TilingInterface` in MLIR, since the latter sees fusion and
+ // tiling as the same thing. So it returns just the tiled op, and not the
+ // result of the full tensor as the current tiling algorithm expects.
+ auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+ loc, result.value(), dest[result.index()], offsets, sizes, strides);
+ results.push_back(tiledInsertOp);
+ }
+ return tiledOp;
+ }
+};
+
+} // namespace
+
+void IREE::LinalgExt::registerTiledOpInterfaceExternalModels(
+ DialectRegistry &registry) {  // attaches TiledOpInterface models to tensor/linalg ops
+ LLVM_DEBUG(
+ { llvm::dbgs() << "Adding external models of tiled op interface\n"; });
+ registry
+ .addOpInterface<tensor::ExtractSliceOp, ExtractSliceTiledOpInterface>();
+ registry.addOpInterface<tensor::InsertSliceOp, InsertSliceTiledOpInterface>();
+ registry.addOpInterface<linalg::PadTensorOp,  // forwarded to the op's TilingInterface
+ ForwardToTilingInterface<linalg::PadTensorOp>>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..66078b9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -0,0 +1,24 @@
+add_mlir_library(IREELinalgExtPasses  # LinalgExt lowering-to-loops and tiling sources
+ ConvertToLoops.cpp
+ Passes.cpp
+ Tiling.cpp
+
+ DEPENDS
+ IREELinalgExtTransformsPassesIncGen  # presumably the generated Passes.h.inc - confirm
+
+ LINK_LIBS PUBLIC
+ IREEInputDialect
+ IREELinalgExtDialect
+ MLIRAffine
+ MLIRIR
+ MLIRLinalg
+ MLIRLinalgTransforms
+ MLIRMath
+ MLIRMemRef
+ MLIRPass
+ MLIRSCF
+ MLIRStandard
+ MLIRSupport
+ MLIRTensor
+ MLIRTransforms
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
new file mode 100644
index 0000000..0c2fcd0
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
@@ -0,0 +1,115 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+/// Recursive helper that emits one scf.for per loop dimension of the
+/// `TiledOpInterface` op, then the scalar implementation inside the nest.
+static LogicalResult lowerToLoopsImpl(OpBuilder &builder,
+ TiledOpInterface tilableOp,
+ ArrayRef<Range> loopRanges,
+ unsigned loopDepth,
+ SmallVectorImpl<Value> &ivs) {
+ Location loc = tilableOp.getLoc();
+ if (loopDepth == loopRanges.size()) {  // all loops emitted: generate the scalar body
+ return tilableOp.generateScalarImplementation(builder, loc, ivs);
+ }
+ LogicalResult status = success();
+ builder.create<scf::ForOp>(
+ loc, loopRanges[loopDepth].offset, loopRanges[loopDepth].size,
+ loopRanges[loopDepth].stride, ValueRange{},
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ ivs.push_back(iv);  // induction variables accumulate outermost-first
+ status = lowerToLoopsImpl(b, tilableOp, loopRanges, loopDepth + 1, ivs);
+ b.create<scf::YieldOp>(loc);
+ });
+ return status;  // result of the innermost generateScalarImplementation call
+}
+
+/// Main entry point: lowers a `TiledOpInterface` op to a nest of scf.for loops.
+static LogicalResult lowerToLoops(OpBuilder &builder,
+ TiledOpInterface tilableOp) {
+ SmallVector<Range> loopBounds = tilableOp.getLoopBounds(builder);  // one Range per loop
+ SmallVector<Value> ivs;  // induction variables, filled in during the recursion
+ return lowerToLoopsImpl(builder, tilableOp, loopBounds, 0, ivs);
+}
+
+/// Lowers a `TiledOpInterface` op that has no shaped results to scalar loops.
+namespace {
+struct TiledOpInterfaceLowerToLoopsPattern : public RewritePattern {
+ TiledOpInterfaceLowerToLoopsPattern(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}  // matches any op type
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto tilableOp = dyn_cast<TiledOpInterface>(op);
+ if (!tilableOp) {
+ return failure();
+ }
+ if (llvm::any_of(tilableOp->getResults(),  // shaped results mean tensor semantics
+ [&](Value v) { return v.getType().isa<ShapedType>(); })) {
+ return rewriter.notifyMatchFailure(
+ tilableOp, "lower to loops needs to have buffer semantics");
+ }
+ if (failed(lowerToLoops(rewriter, tilableOp))) {
+ return failure();
+ }
+ rewriter.eraseOp(op);  // the generated loop nest replaces the op
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct LinalgExtToLoopsPass  // applies TiledOpInterfaceLowerToLoopsPattern greedily
+ : public LinalgExtToLoopsBase<LinalgExtToLoopsPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, StandardOpsDialect,
+ mlir::arith::ArithmeticDialect, math::MathDialect,
+ memref::MemRefDialect, scf::SCFDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+
+ OwningRewritePatternList patterns(context);
+ patterns.insert<TiledOpInterfaceLowerToLoopsPattern>(context);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();  // the greedy driver reported failure
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+IREE::LinalgExt::createLinalgExtToLoopsPass() {  // factory for LinalgExtToLoopsPass
+ return std::make_unique<LinalgExtToLoopsPass>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp
new file mode 100644
index 0000000..c41b9ed
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp
@@ -0,0 +1,33 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+namespace detail {  // scopes the registration code pulled in from Passes.h.inc
+#define GEN_PASS_REGISTRATION
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h.inc" // IWYU pragma: export
+} // namespace detail
+
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+void IREE::LinalgExt::registerPasses() {  // public wrapper over detail::registerPasses()
+ IREE::LinalgExt::detail::registerPasses();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
new file mode 100644
index 0000000..9df31de
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -0,0 +1,354 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+//===----------------------------------------------------------------------===//
+// Utility methods for tiling a linalg_ext operation that implements a
+// TiledOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Fails on interchange, non-scf.for loop types, or non-cyclic distribution.
+static LogicalResult verifySupportedTilingOptions(
+ PatternRewriter &rewriter, Operation *op,
+ const linalg::LinalgTilingOptions &options) {
+ if (!options.interchangeVector.empty()) {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported interchange during tiling");
+ }
+ if (options.loopType != linalg::LinalgTilingLoopType::Loops) {
+ return rewriter.notifyMatchFailure(op,
+ "only tiling with scf.for is supported");
+ }
+ if (options.distribution) {  // distribution is optional; check only when requested
+ if (llvm::any_of(options.distribution->distributionMethod,
+ [](linalg::DistributionMethod method) {
+ return method != linalg::DistributionMethod::Cyclic;
+ })) {
+ return rewriter.notifyMatchFailure(op,
+ "only cyclic distribution is allowed");
+ }
+ }
+ return success();
+}
+
+/// Converts an `OpFoldResult` to a `Value`, building an index constant op
+/// if the `OpFoldResult` holds an attribute (which must be an `IntegerAttr`).
+static Value getValue(OpBuilder &builder, Location loc,
+ OpFoldResult valueOrAttr) {
+ if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+ return builder.create<arith::ConstantIndexOp>(
+ loc, attr.cast<IntegerAttr>().getInt());
+ }
+ return valueOrAttr.get<Value>();  // already a Value: returned unchanged
+}
+
+/// Returns true if the loop is untiled, i.e. its tile size is statically
+/// zero. It is assumed that a `Value` defined by a constant op has already
+/// been converted to an `IntegerAttr` of that value, so this returns true
+/// only when `getConstantIntValue` resolves the operand to zero.
+static bool isUntiledLoop(OpFoldResult valueOrAttr) {
+ Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
+ return intVal && *intVal == 0;
+}
+
+/// Generates the tiled loops and the body by invoking the interface methods of
+/// TiledOpInterface. Recurses once per loop of the operation.
+/// - `outputs` are the operands to use for outputs of the tiled operation.
+/// - `tileSizes` are tile sizes specified for all loops of the operation (0 for
+/// an untiled loop); entries are overwritten with the sizes actually used.
+/// - `iteratorTypes` are the loop iterator types returned by the
+/// TiledOpInterface.
+/// - `loopBounds` are the bounds of all the loops of the op returned by the
+/// TiledOpInterface.
+/// - `loopDepth` is the current loop depth being processed.
+/// - `offsets` are the `OpFoldResult`s that give the position of the tile being
+/// operated on. The offsets are computed as the tiled loops are being
+/// generated.
+/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for
+/// distributed loops. It is a stack, and once an entry at the top of the
+/// stack is used for distribution it is popped before processing the inner
+/// loops.
+static FailureOr<TiledOp> tileInterfaceOpImpl(
+ OpBuilder &builder, TiledOpInterface tilableOp, ValueRange outputs,
+ MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<Range> loopBounds, unsigned loopDepth,
+ SmallVectorImpl<OpFoldResult> &offsets,
+ ArrayRef<linalg::ProcInfo> distributionInfo) {
+ Location loc = tilableOp.getLoc();
+ // Past the innermost loop: generate the tiled implementation of the op by
+ // invoking the TiledOpInterface methods.
+ if (loopDepth == tileSizes.size()) {
+ TiledOp ret;
+ ret.op = tilableOp.getTiledImplementation(builder, outputs, offsets,
+ tileSizes, ret.results);
+ if (!ret.op) {
+ return static_cast<LogicalResult>(
+ tilableOp.emitOpError("failed to get tiled implementation"));
+ }
+ return ret;
+ }
+
+ // Tile size zero at this depth: no loop is created (offset 0, full size).
+ if (isUntiledLoop(tileSizes[loopDepth])) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ offsets.push_back(zeroAttr);
+ assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
+ "expected loop bounds to have lower bound of zero");
+ tileSizes[loopDepth] = getAsOpFoldResult(loopBounds[loopDepth].size);
+ return tileInterfaceOpImpl(builder, tilableOp, outputs, tileSizes,
+ iteratorTypes, loopBounds, loopDepth + 1,
+ offsets, distributionInfo);
+ }
+
+ // Generate an scf.for for the current loop depth.
+ Value lb = loopBounds[loopDepth].offset;
+ Value ub = loopBounds[loopDepth].size;
+ // TODO(#7073): Put the check back. This is required by tiling linalg_ext.fft
+ // op. We can put the check back after updating linalg_ext.fft semantics.
+ // if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
+ // return static_cast<LogicalResult>(
+ // tilableOp.emitOpError("expected stride to be 1"));
+ //}
+ Value step = getValue(builder, loc, tileSizes[loopDepth]);
+
+ // Update lb, ub and step for cyclic distribution (parallel loops only).
+ if (!distributionInfo.empty() &&
+ iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
+ linalg::updateBoundsForCyclicDistribution(
+ builder, loc, distributionInfo.front().procId,
+ distributionInfo.front().nprocs, lb, ub, step);
+ distributionInfo = distributionInfo.drop_front();  // this entry is now consumed
+ }
+ FailureOr<TiledOp> innerReturnValue;
+ bool isBufferTiling = tilableOp->getNumResults() == 0;  // no results: nothing to carry
+ ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
+ auto forOp = builder.create<scf::ForOp>(
+ loc, lb, ub, step, initValues,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ offsets.push_back(iv);
+ auto affineMaps = AffineMap::inferFromExprList({ArrayRef<AffineExpr>{
+ b.getAffineSymbolExpr(0),
+ b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0];
+ // Similar to linalg tiling, the tile size is min(tileSize, ub - iv)
+ // to account for cases where the tile size does not divide (ub - lb)
+ // exactly.
+ Value inBoundsTileSize = b.create<AffineMinOp>(
+ loc, affineMaps,
+ ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub});  // NOTE(review): uses `builder`, not `b` - confirm intended
+ tileSizes[loopDepth] = getAsOpFoldResult(inBoundsTileSize);
+ // Recursively proceed to generate the tiled loop for the next level.
+ innerReturnValue =
+ tileInterfaceOpImpl(b, tilableOp, (isBufferTiling ? outputs : args),
+ tileSizes, iteratorTypes, loopBounds,
+ loopDepth + 1, offsets, distributionInfo);
+ if (failed(innerReturnValue)) return;  // no scf.yield is created on this path
+ b.create<scf::YieldOp>(loc, innerReturnValue->results);
+ });
+ if (failed(innerReturnValue)) {
+ return innerReturnValue;
+ }
+ innerReturnValue->loops.insert(innerReturnValue->loops.begin(),  // keeps outermost loop first
+ forOp.getOperation());
+ innerReturnValue->results = forOp.getResults();  // callers see the loop's results
+ return innerReturnValue;
+}
+
+FailureOr<TiledOp> tileInterfaceOp(OpBuilder &b, TiledOpInterface tilableOp,  // tiles `tilableOp` per `options`
+ const linalg::LinalgTilingOptions &options) {
+ SmallVector<Value> dest = tilableOp.getDestinationOperands(b);
+ if (dest.empty()) {
+ return static_cast<LogicalResult>(tilableOp.emitOpError(
+ "cannot tile operation without destination operands"));
+ }
+
+ SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
+ SmallVector<Value, 4> tileSizesVals =
+ options.tileSizeComputationFunction(b, tilableOp);  // NOTE(review): assumes the callback is set - confirm
+ auto zeroAttr = b.getI64IntegerAttr(0);
+
+ // The tile sizes are converted to `OpFoldResult`s and padded with zeros (or
+ // truncated) to one entry per loop. Currently only "parallel" loops can be
+ // tiled: a non-zero tile size on any other iterator type is an error.
+ auto tileSizes = getAsOpFoldResult(tileSizesVals);
+ tileSizes.resize(iteratorTypes.size(), zeroAttr);
+ for (auto en : llvm::enumerate(iteratorTypes)) {
+ if (en.value() == getParallelIteratorTypeName()) continue;
+ if (!isUntiledLoop(tileSizes[en.index()])) {
+ return static_cast<LogicalResult>(tilableOp.emitOpError(
+ "unimplemented tiling of non-parallel loop iterator type"));
+ }
+ }
+
+ // Trivial early exit: all tile sizes are zero, so the op is returned as-is.
+ if (llvm::all_of(tileSizes, isUntiledLoop)) {
+ return TiledOp{tilableOp, {}, {}};
+ }
+
+ SmallVector<Range> loopBounds = tilableOp.getLoopBounds(b);
+ SmallVector<linalg::ProcInfo> distributionInfo;
+ // If the tiled loops are distributed, get the proc_id and nprocs for the
+ // distributed loops. First collect the parallel loops by iterating over the
+ // tileSizes and getting the loops that are distributed, i.e.,
+ // - parallel, i.e. iteratorTypes is "parallel"
+ // - tiled, i.e. tileSize != 0
+ if (options.distribution) {
+ SmallVector<Range> distributedLoopRange;
+ for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
+ if (isUntiledLoop(tileSizes[i])) continue;
+ if (iteratorTypes[i] != getParallelIteratorTypeName()) continue;
+ distributedLoopRange.push_back(loopBounds[i]);
+ }
+ distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(),
+ distributedLoopRange);
+ }
+
+ SmallVector<OpFoldResult> offsets;  // filled in as the loops are generated
+ return tileInterfaceOpImpl(b, tilableOp, dest, tileSizes, iteratorTypes,
+ loopBounds, 0, offsets, distributionInfo);
+}
+
+LogicalResult TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
+ TiledOpInterface tilableOp, PatternRewriter &rewriter,
+ TiledOp &result) const {  // `result` is written only when tiling succeeds
+ if (failed(filter.checkAndNotify(rewriter, tilableOp))) {
+ return failure();
+ }
+ if (failed(verifySupportedTilingOptions(rewriter, tilableOp, options))) {
+ return failure();
+ }
+
+ FailureOr<TiledOp> res = tileInterfaceOp(rewriter, tilableOp, options);
+ if (failed(res)) return res;
+ result = *res;
+ if (result.op) {  // NOTE(review): presumably re-tags the op with the filter's replacement - confirm
+ filter.replaceLinalgTransformationFilter(rewriter, result.op);
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test pass for tiling Linalg Ext ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TiledOpInterfaceTilingPass  // test pass; runOnOperation is defined out of line
+ : public TiledOpInterfaceTilingBase<TiledOpInterfaceTilingPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<
+ AffineDialect, IREE::Input::IREEInputDialect, linalg::LinalgDialect,
+ IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
+ StandardOpsDialect, mlir::arith::ArithmeticDialect, math::MathDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
+ }
+ void runOnOperation() override;
+};
+} // namespace
+
+template <typename OpTy>
+static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {  // creates an `OpTy` for `dim`
+ return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);  // reuses the loc of the op at the insertion point
+}
+
+void TiledOpInterfaceTilingPass::runOnOperation() {  // one tiling pattern per test filter
+ FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp.getContext();
+
+ RewritePatternSet patterns(context);
+ patterns.add<TiledOpInterfaceTilingPattern>(  // tile sizes 10 and 20
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
+ linalg::LinalgTransformationFilter(  // NOTE(review): presumably (match marker, replacement marker) - confirm
+ Identifier::get("tiling_input", context),
+ Identifier::get("tiling_output", context)));
+ patterns.add<TiledOpInterfaceTilingPattern>(  // tile size 0: no loop is tiled
+ context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("no_tiling_input", context),
+ Identifier::get("no_tiling_output", context)));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(  // first loop untiled, second tiled by 20
+ context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("outer_reduce_input", context),
+ Identifier::get("outer_reduce_output", context)));
+ patterns.add<TiledOpInterfaceTilingPattern>(  // only the first loop is tiled, by 10
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("inner_reduce_input", context),
+ Identifier::get("inner_reduce_output", context)));
+
+ static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+ [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+ auto numParallelDims = parallelLoopRanges.size();
+
+ SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
+ for (size_t dim = 0; dim < numParallelDims; ++dim) {
+ procInfo[numParallelDims - dim - 1] = {  // reversed: workgroup dim 0 goes to the last loop
+ buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupIDOp>(
+ builder, dim),
+ buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupCountOp>(
+ builder, dim)};
+ }
+ return procInfo;
+ },
+ {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+ linalg::DistributionMethod::Cyclic},
+ DenseMap<StringRef,
+ std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
+
+ patterns.add<TiledOpInterfaceTilingPattern>(  // tiling plus workgroup distribution
+ context,
+ linalg::LinalgTilingOptions()
+ .setTileSizes(ArrayRef<int64_t>{10, 0, 30})
+ .setDistributionOptions(workgroupDistributionOptions),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("distribute_input", context),
+ Identifier::get("distribute_output", context)));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(  // single tile size of 32
+ context,
+ linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{32}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("tiling_1d_stage5_fft_input", context),
+ Identifier::get("tiling_1d_stage5_fft_output", context)));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(  // tile sizes 10 and 32
+ context,
+ linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{10, 32}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("tiling_2d_stage5_fft_input", context),
+ Identifier::get("tiling_2d_stage5_fft_output", context)));
+
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+IREE::LinalgExt::createTiledOpInterfaceTilingPass() {  // factory for TiledOpInterfaceTilingPass
+ return std::make_unique<TiledOpInterfaceTilingPass>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/CMakeLists.txt
similarity index 100%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/CMakeLists.txt
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/CMakeLists.txt
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt
new file mode 100644
index 0000000..7ca64d8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(IREEPyDMDialect  # PyDM dialect and op sources
+ PyDMDialect.cpp
+ PyDMOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${IREE_DIALECTS_SOURCE_DIR}/include
+
+ DEPENDS
+ IREEPyDMIncGen  # presumably the generated dialect headers - confirm
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSideEffectInterfaces
+)
+
+iree_dialects_target_includes(IREEPyDMDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
similarity index 94%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
index af9e341..5db10f7 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Dialect.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
@@ -4,10 +4,10 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Interfaces.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -18,11 +18,11 @@
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
using namespace PYDM;
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.cpp.inc"
-#include "iree-dialects/Dialect/IREEPyDM/IR/OpInterfaces.cpp.inc"
-#include "iree-dialects/Dialect/IREEPyDM/IR/TypeInterfaces.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/Types.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
//------------------------------------------------------------------------------
// Dialect implementation
@@ -38,11 +38,11 @@
void IREEPyDMDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
-#include "iree-dialects/Dialect/IREEPyDM/IR/Types.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
>();
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
index d79cbd8..7ef7bfc 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/Ops.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
@@ -807,4 +807,4 @@
}
#define GET_OP_CLASSES
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..fd9704a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_subdirectory(Optimize)  # each subdirectory is linked into IREEPyDMPasses below
+add_subdirectory(RTL)
+add_subdirectory(ToIREE)
+
+add_mlir_library(IREEPyDMPasses  # umbrella library over the three pass libraries
+ Passes.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEPyDMOptimizePasses
+ IREEPyDMRTLPasses
+ IREEPyDMToIREEPasses
+ MLIRTransforms
+)
+
+iree_dialects_target_includes(IREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt
new file mode 100644
index 0000000..1f7abb1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(IREEPyDMOptimizePasses  # sources from Transforms/Optimize
+ FixateWeakNumeric.cpp
+ LocalPropagateTypes.cpp
+ VariablesToSSA.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen  # same IncGen target as IREEPyDMPasses
+
+ LINK_LIBS PUBLIC
+ IREEPyDMDialect
+ IREEPyDMUtils
+ MLIRIR
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMOptimizePasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
similarity index 96%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
index 992c53f..fc17764 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/FixateWeakNumeric.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
@@ -5,8 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/LocalPropagateTypes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/LocalPropagateTypes.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
index 3966f41..4f4a73f 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/LocalPropagateTypes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
@@ -5,9 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/Utils/TypeInference.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Utils/TypeInference.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/VariablesToSSA.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/VariablesToSSA.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
index 482bf99..0a0a141 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Optimize/VariablesToSSA.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
@@ -5,8 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
new file mode 100644
index 0000000..9fbfc52
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
@@ -0,0 +1,32 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace iree {
+class IREEDialect;
+}
+
+namespace iree_compiler {
+namespace IREE {
+namespace PYDM {
+
+class FuncOp;
+
+#define GEN_PASS_CLASSES
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
+
+} // namespace PYDM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
similarity index 92%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
index cbaf558..820338c 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
@@ -44,7 +44,7 @@
namespace PYDM_generated {
namespace {
#define GEN_PASS_REGISTRATION
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h.inc"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
} // namespace
} // namespace PYDM_generated
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt
new file mode 100644
index 0000000..08392f8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_library(IREEPyDMRTLPasses
+ LinkageAnalysis.cpp
+ LinkRTLPass.cpp
+ LowerToRTLPass.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEInputDialect
+ IREEPyDMDialect
+ MLIRIR
+ MLIRParser
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMRTLPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
similarity index 97%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
index 315dfa4..f61214a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
@@ -7,9 +7,9 @@
#include <memory>
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
similarity index 83%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
index dc1a112..aee2bf2 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/RTL/LinkageAnalysis.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
#include "mlir/IR/SymbolTable.h"
using namespace mlir;
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LowerToRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
similarity index 96%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LowerToRTLPass.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
index ce3b970..67957a6 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LowerToRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
@@ -5,10 +5,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -221,8 +221,8 @@
struct LowerIREEPyDMToRTLPass
: public LowerIREEPyDMToRTLBase<LowerIREEPyDMToRTLPass> {
void getDependentDialects(DialectRegistry &registry) const override {
- registry
- .insert<mlir::iree::IREEDialect, BuiltinDialect, StandardOpsDialect>();
+ registry.insert<mlir::iree_compiler::IREE::Input::IREEInputDialect,
+ BuiltinDialect, StandardOpsDialect>();
}
void runOnOperation() override {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt
new file mode 100644
index 0000000..4c8a7b4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_library(IREEPyDMToIREEPasses
+ ConversionPass.cpp
+ LoweringPatterns.cpp
+ TypeConverter.cpp
+
+ DEPENDS
+ IREEPyDMTransformsPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEInputDialect
+ IREEPyDMDialect
+ MLIRArithmetic
+ MLIRIR
+ MLIRStandard
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMToIREEPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
similarity index 79%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
index 63e343c..f49d0b1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -5,12 +5,12 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "../PassDetail.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/Patterns.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -25,8 +25,8 @@
struct ConvertIREEPyDMToIREEPass
: public ConvertIREEPyDMToIREEBase<ConvertIREEPyDMToIREEPass> {
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<mlir::iree::IREEDialect, BuiltinDialect, StandardOpsDialect,
- math::MathDialect>();
+ registry.insert<mlir::iree_compiler::IREE::Input::IREEInputDialect,
+ BuiltinDialect, StandardOpsDialect, math::MathDialect>();
}
void runOnOperation() override {
@@ -39,7 +39,8 @@
ConversionTarget target(*context);
target.addIllegalDialect<IREEPyDMDialect>();
target.addLegalDialect<BuiltinDialect>();
- target.addLegalDialect<mlir::iree::IREEDialect>();
+ target
+ .addLegalDialect<mlir::iree_compiler::IREE::Input::IREEInputDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addLegalDialect<mlir::math::MathDialect>();
target.addLegalDialect<mlir::StandardOpsDialect>();
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
similarity index 95%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 3f33b13..0d5317d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -4,9 +4,9 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREE/IREEOps.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Ops.h"
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/Patterns.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -18,6 +18,7 @@
using llvm::enumerate;
using namespace mlir;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+namespace Input = mlir::iree_compiler::IREE::Input;
using namespace PYDM;
namespace {
@@ -39,15 +40,16 @@
} // namespace
static Type getVariantListType(Builder &builder) {
- return builder.getType<iree::ListType>(builder.getType<iree::VariantType>());
+ return builder.getType<Input::ListType>(
+ builder.getType<Input::VariantType>());
}
static Value getNullValue(Location loc, OpBuilder &builder, Type t) {
return TypeSwitch<Type, Value>(t)
- .Case<iree::ListType>([&](auto t) -> Value {
+ .Case<Input::ListType>([&](auto t) -> Value {
// TODO: If it becomes important to optimize this, come up with a way
// to return an empty list without creating one.
- return builder.create<iree::ListCreateOp>(
+ return builder.create<Input::ListCreateOp>(
loc, getVariantListType(builder), /*capacity=*/nullptr);
})
.Default([&](Type t) -> Value {
@@ -75,8 +77,8 @@
}
static Value createUndefObjectList(Location loc, OpBuilder &builder) {
- return builder.create<iree::ListCreateOp>(loc, getVariantListType(builder),
- /*capacity=*/nullptr);
+ return builder.create<Input::ListCreateOp>(loc, getVariantListType(builder),
+ /*capacity=*/nullptr);
}
void resetObjectList(Location loc, OpBuilder &builder, Value list, int typeCode,
@@ -85,13 +87,13 @@
// to truly reset, we have to resize. Low level optimizations should be able
// to elide this if it turns out to be unnecessary.
auto size = builder.create<arith::ConstantIndexOp>(loc, 2);
- builder.create<iree::ListResizeOp>(loc, list, size);
+ builder.create<Input::ListResizeOp>(loc, list, size);
auto index0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value typeCodeValue = builder.create<arith::ConstantOp>(
loc, builder.getI32IntegerAttr(typeCode));
- builder.create<iree::ListSetOp>(loc, list, index0, typeCodeValue);
+ builder.create<Input::ListSetOp>(loc, list, index0, typeCodeValue);
auto index1 = builder.create<arith::ConstantIndexOp>(loc, 1);
- builder.create<iree::ListSetOp>(loc, list, index1, data);
+ builder.create<Input::ListSetOp>(loc, list, index1, data);
}
static Value createObjectList(Location loc, OpBuilder &builder, int typeCode,
@@ -334,7 +336,7 @@
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, slice.getType());
Value listSizeIndex =
- rewriter.create<iree::ListSizeOp>(loc, indexType, sequence);
+ rewriter.create<Input::ListSizeOp>(loc, indexType, sequence);
Value listSizeInteger = rewriter.create<arith::IndexCastOp>(
loc, slice.getType(), listSizeIndex);
Block *entryBlock = rewriter.getInsertionBlock();
@@ -387,7 +389,7 @@
{
rewriter.setInsertionPointToEnd(setElementBlock);
Value successResult = getSuccessStatusValue(loc, rewriter);
- rewriter.create<iree::ListSetOp>(
+ rewriter.create<Input::ListSetOp>(
loc, sequence, setElementBlock->getArgument(0), valueToSet);
rewriter.create<mlir::BranchOp>(loc, continuationBlock,
ValueRange{successResult});
@@ -555,7 +557,7 @@
rewriter.setInsertionPointToEnd(entryBlock);
auto arityValue =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(arity));
- Value listSize = rewriter.create<iree::ListSizeOp>(
+ Value listSize = rewriter.create<Input::ListSizeOp>(
loc, rewriter.getIndexType(), adaptor.sequence());
Value arityMatch = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, arityValue, listSize);
@@ -571,7 +573,7 @@
for (auto it : enumerate(slotTypes)) {
Value index = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(it.index()));
- Value slotValue = rewriter.create<iree::ListGetOp>(
+ Value slotValue = rewriter.create<Input::ListGetOp>(
loc, it.value(), adaptor.sequence(), index);
branchArgs.push_back(slotValue);
}
@@ -687,8 +689,8 @@
Type i32Type = rewriter.getIntegerType(32);
Value index0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
- Value typeCode =
- rewriter.create<iree::ListGetOp>(loc, i32Type, adaptor.value(), index0);
+ Value typeCode = rewriter.create<Input::ListGetOp>(loc, i32Type,
+ adaptor.value(), index0);
rewriter.replaceOp(
srcOp,
castIntegerValue(loc, typeCode, resultType.cast<mlir::IntegerType>(),
@@ -712,8 +714,8 @@
auto list = adaptor.getOperands()[0];
auto index1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
- rewriter.replaceOpWithNewOp<iree::ListGetOp>(srcOp, resultType, list,
- index1);
+ rewriter.replaceOpWithNewOp<Input::ListGetOp>(srcOp, resultType, list,
+ index1);
return success();
}
};
@@ -737,13 +739,13 @@
auto size = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(adaptor.elements().size()));
auto list =
- rewriter.create<iree::ListCreateOp>(loc, getVariantListType(rewriter),
- /*capacity=*/size);
- rewriter.create<iree::ListResizeOp>(loc, list, size);
+ rewriter.create<Input::ListCreateOp>(loc, getVariantListType(rewriter),
+ /*capacity=*/size);
+ rewriter.create<Input::ListResizeOp>(loc, list, size);
for (auto it : enumerate(adaptor.elements())) {
auto index = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(it.index()));
- rewriter.create<iree::ListSetOp>(loc, list, index, it.value());
+ rewriter.create<Input::ListSetOp>(loc, list, index, it.value());
}
rewriter.replaceOp(srcOp, ValueRange{list});
@@ -766,13 +768,13 @@
auto size = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(adaptor.slots().size()));
auto list =
- rewriter.create<iree::ListCreateOp>(loc, getVariantListType(rewriter),
- /*capacity=*/size);
- rewriter.create<iree::ListResizeOp>(loc, list, size);
+ rewriter.create<Input::ListCreateOp>(loc, getVariantListType(rewriter),
+ /*capacity=*/size);
+ rewriter.create<Input::ListResizeOp>(loc, list, size);
for (auto it : enumerate(adaptor.slots())) {
auto index = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(it.index()));
- rewriter.create<iree::ListSetOp>(loc, list, index, it.value());
+ rewriter.create<Input::ListSetOp>(loc, list, index, it.value());
}
rewriter.replaceOp(srcOp, ValueRange{list});
@@ -901,7 +903,7 @@
Type indexType = rewriter.getType<IndexType>();
Type listType = listOperand.getType();
Value subListSize =
- rewriter.create<iree::ListSizeOp>(loc, indexType, listOperand);
+ rewriter.create<Input::ListSizeOp>(loc, indexType, listOperand);
Value countInteger = countOperand;
Value countIndex =
rewriter.create<arith::IndexCastOp>(loc, indexType, countOperand);
@@ -916,8 +918,8 @@
Value newListSize =
rewriter.create<arith::MulIOp>(loc, subListSize, clampedCountIndex);
Value newList =
- rewriter.create<iree::ListCreateOp>(loc, listType, clampedCountIndex);
- rewriter.create<iree::ListResizeOp>(loc, newList, newListSize);
+ rewriter.create<Input::ListCreateOp>(loc, listType, clampedCountIndex);
+ rewriter.create<Input::ListResizeOp>(loc, newList, newListSize);
// Split blocks to loop.
// TODO: Use a new list.copy op instead of an inner loop.
@@ -969,9 +971,9 @@
Value newListIt = innerBody->getArgument(0);
Value subListIt = innerBody->getArgument(1);
- Value elementValue = rewriter.create<iree::ListGetOp>(
+ Value elementValue = rewriter.create<Input::ListGetOp>(
loc, listElementType, listOperand, subListIt);
- rewriter.create<iree::ListSetOp>(loc, newList, newListIt, elementValue);
+ rewriter.create<Input::ListSetOp>(loc, newList, newListIt, elementValue);
newListIt = rewriter.create<arith::AddIOp>(loc, newListIt, oneIndex);
subListIt = rewriter.create<arith::AddIOp>(loc, subListIt, oneIndex);
@@ -1057,7 +1059,7 @@
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, slice.getType());
Value listSizeIndex =
- rewriter.create<iree::ListSizeOp>(loc, indexType, adaptor.value());
+ rewriter.create<Input::ListSizeOp>(loc, indexType, adaptor.value());
Value listSizeInteger = rewriter.create<arith::IndexCastOp>(
loc, slice.getType(), listSizeIndex);
@@ -1117,7 +1119,7 @@
{
rewriter.setInsertionPointToEnd(getElementBlock);
Value successResult = getSuccessStatusValue(loc, rewriter);
- Value resultValue = rewriter.create<iree::ListGetOp>(
+ Value resultValue = rewriter.create<Input::ListGetOp>(
loc, resultType, adaptor.value(), getElementBlock->getArgument(0));
rewriter.create<mlir::BranchOp>(loc, continuationBlock,
ValueRange{successResult, resultValue});
@@ -1180,7 +1182,7 @@
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value requiredTypeCodeValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(typeCode));
- Value actualTypeCodeValue = rewriter.create<iree::ListGetOp>(
+ Value actualTypeCodeValue = rewriter.create<Input::ListGetOp>(
loc, rewriter.getI32Type(), list, index0);
Value typeCodeEqual = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, requiredTypeCodeValue,
@@ -1195,7 +1197,7 @@
auto index1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Value successResult = getSuccessStatusValue(loc, rewriter);
- Value unboxedValue = rewriter.create<iree::ListGetOp>(
+ Value unboxedValue = rewriter.create<Input::ListGetOp>(
loc, targetUnboxedType, list, index1);
rewriter.create<mlir::BranchOp>(loc, continuationBlock,
ValueRange{successResult, unboxedValue});
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
similarity index 87%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
index a7b664f..a16fe16 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -4,19 +4,21 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h"
-#include "iree-dialects/Dialect/IREE/IREEDialect.h"
-#include "iree-dialects/Dialect/IREEPyDM/IR/Dialect.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
namespace PYDM = mlir::iree_compiler::IREE::PYDM;
using namespace PYDM;
static Type getVariantListType(Builder &builder) {
- return builder.getType<iree::ListType>(builder.getType<iree::VariantType>());
+ return builder.getType<IREE::Input::ListType>(
+ builder.getType<IREE::Input::VariantType>());
}
LoweringTypeConverter::LoweringTypeConverter() {
@@ -81,7 +83,7 @@
addConversion([](mlir::IntegerType t) -> Optional<Type> { return t; });
addConversion([](mlir::FloatType t) -> Optional<Type> { return t; });
addConversion([](mlir::IndexType t) -> Optional<Type> { return t; });
- addConversion([](iree::ListType t) -> Optional<Type> { return t; });
+ addConversion([](IREE::Input::ListType t) -> Optional<Type> { return t; });
}
Type LoweringTypeConverter::getBoolType(Builder b) const {
@@ -103,7 +105,7 @@
bool LoweringTypeConverter::isTypeLegal(Type t) const {
return t.isa<mlir::IntegerType, mlir::FloatType, mlir::IndexType,
- iree::ListType>();
+ IREE::Input::ListType>();
}
bool LoweringTypeConverter::areTypesLegal(TypeRange types) const {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt
new file mode 100644
index 0000000..2417731
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(IREEPyDMUtils
+ TypeInference.cpp
+
+ LINK_LIBS PUBLIC
+ IREEPyDMDialect
+ MLIRIR
+ MLIRStandard
+ MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMUtils)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/TypeInference.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
similarity index 98%
rename from llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/TypeInference.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
index d0b8d40..ed09f96 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Utils/TypeInference.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/IREEPyDM/Utils/TypeInference.h"
+#include "iree-dialects/Dialect/PyDM/Utils/TypeInference.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeRange.h"