Add support for bufferizing LinalgExt ops. (#6377)
Since LinalgExtInterface is a subset of LinalgInterface, we can make
convertAnyLinalgOp a template and share it between Linalg and LinalgExt
ops. The analyseLinalg*Ops functions have separate implementations
because LinalgExtOp does not define indexing maps.
Also adds an interface method: clone.
This is a step towards https://github.com/google/iree/issues/6154
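As a rough sketch (for illustration only; the signature follows the diff
below and the body is elided), the shared surface looks like:

    // Both linalg::LinalgOp and linalg_ext::LinalgExtOp provide the
    // accessors this pass relies on (hasTensorSemantics(),
    // hasBufferSemantics(), getInputOperands(), getOutputOperands()),
    // so one template body serves both interfaces.
    template <typename OpTy>
    static LogicalResult convertAnyLinalgOp(
        OpBuilder &b, OpTy op, BlockAndValueMapping &bvm,
        BufferizationPlan &plan,
        WorkgroupMemoryAllocationFn allocationFn) {
      // Skip ops inserted by this pass; they already operate on memrefs.
      if (op.hasBufferSemantics()) return success();
      // ... operands and results are looked up through bvm (elided) ...
      return success();
    }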
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index afd1b47..6e6b5c7 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -50,6 +50,7 @@
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/LinalgExt/IR",
"//iree/compiler/Dialect/Shape/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index ae91a28..8466201 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -58,6 +58,7 @@
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::Shape::IR
PUBLIC
)
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index 7a2d0e0..ad130d0 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -46,6 +46,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -380,6 +381,26 @@
return tiedOperands;
}
+static LogicalResult analyseLinalgExtOps(linalg_ext::LinalgExtOp op,
+ BufferizationPlan &plan) {
+ if (!op.hasTensorSemantics()) return success();
+  // TODO(hanchung): Revisit whether we can tie op.getOutputOperands() to the
+  // corresponding op.getInputOperands(). For now only a few LinalgExt ops
+  // exist and there is no use case, so we do not.
+  // Note: this is what should be done for LinalgOps too, except for what is
+  // done for operand fusion today.
+ for (auto input : op.getInputOperands()) {
+ plan.insert(input->get());
+ }
+ for (auto output : op.getOutputOperands()) {
+ plan.insert(output->get());
+ }
+ for (auto result : op->getResults()) {
+ plan.insert(result);
+ }
+ return success();
+}
+
/// Adds the corresponding `outs` and result tensors of the linalg op into the
/// same equivalence class.
static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp,
@@ -580,6 +601,10 @@
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
return analyseLinalgOps(linalgOp, plan);
})
+ .Case<linalg_ext::LinalgExtOp>(
+ [&](linalg_ext::LinalgExtOp linalgExtOp) {
+ return analyseLinalgExtOps(linalgExtOp, plan);
+ })
.Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
[&](auto reshapeOp) {
return analyseSingleOperandResultOp(reshapeOp.src(),
@@ -910,7 +935,8 @@
resultBuffer =
TypeSwitch<Operation *, Value>(op)
.Case<scf::IfOp, scf::ForOp, linalg::LinalgOp,
- tensor::InsertSliceOp, vector::TransferWriteOp>(
+ linalg_ext::LinalgExtOp, tensor::InsertSliceOp,
+ vector::TransferWriteOp>(
[&](auto op) { return resultBuffer; })
.Case<linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp>(
[&](auto reshapeOp) {
@@ -1123,9 +1149,10 @@
 /// Generic conversion pattern that matches any linalg::LinalgOp. This avoids
 /// template-instantiating one pattern for each linalg::LinalgOp. The method
 /// expects that all operands and results have already been mapped to memrefs.
+template <typename OpTy>
static LogicalResult convertAnyLinalgOp(
- OpBuilder &b, linalg::LinalgOp op, BlockAndValueMapping &bvm,
- BufferizationPlan &plan, WorkgroupMemoryAllocationFn allocationFn) {
+ OpBuilder &b, OpTy op, BlockAndValueMapping &bvm, BufferizationPlan &plan,
+ WorkgroupMemoryAllocationFn allocationFn) {
// Skip linalg ops inserted by this pass.
if (op.hasBufferSemantics()) return success();
@@ -1539,12 +1566,12 @@
}
return convertPadTensorOp(b, padTensorOp, bvm);
})
- .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
- if (failed(getOrAllocateResultBuffers(b, linalgOp.getOperation(), bvm,
- plan, allocationFn))) {
+ .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>([&](auto op) {
+ if (failed(
+ getOrAllocateResultBuffers(b, op, bvm, plan, allocationFn))) {
return failure();
}
- return convertAnyLinalgOp(b, linalgOp, bvm, plan, allocationFn);
+ return convertAnyLinalgOp(b, op, bvm, plan, allocationFn);
})
.Case<tensor::InsertSliceOp>(
[&](tensor::InsertSliceOp subTensorInsertOp) {
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index e21945f..07257b6 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -2384,3 +2384,23 @@
// CHECK: scf.if
// CHECK-DAG: memref.store %[[V1]], %[[INOUT]][%[[P1]]]
// CHECK-DAG: memref.store %[[V2]], %[[INOUT]][%[[ARG1]]]
+
+// -----
+
+func @linalg_ext_sort_1d() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@rw[%c0] : !flow.dispatch.tensor<readwrite:128xi32>
+ %1 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:128xi32> -> tensor<128xi32>
+ %2 = linalg_ext.sort {dimension = 0 : i64} outs(%1 : tensor<128xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32): // no predecessors
+ %3 = cmpi sgt, %arg0, %arg1 : i32
+ linalg_ext.yield %3 : i1
+ } -> tensor<128xi32>
+ flow.dispatch.tensor.store %2, %0, offsets = [], sizes = [], strides = [] : tensor<128xi32> -> !flow.dispatch.tensor<readwrite:128xi32>
+ return
+}
+// CHECK-LABEL: func @linalg_ext_sort_1d()
+// CHECK-DAG: %[[INOUT:.+]] = hal.interface.binding.subspan @io::@rw
+// CHECK: linalg_ext.sort
+// CHECK-SAME: dimension = 0 : i64
+// CHECK-SAME: outs(%[[INOUT]] : memref<128xi32>)
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
index f67a982..35584d8 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h
@@ -7,6 +7,8 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
index 19424d2..77d0ddf 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td
@@ -438,6 +438,30 @@
return opOperand->get().getType().template isa<RankedTensorType>();
});
}]
+ >,
+ //===------------------------------------------------------------------===//
+ // Other static interface methods.
+ //===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+      Clone the current operation with the given location and operands. This
+      is used to abstract away the creation of the optional underlying
+      region. It does not change the balance between the input,
+      output_buffer, and init_tensors operands.
+ }],
+ /*retTy=*/"Operation *",
+ /*methodName=*/"clone",
+ (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+ "ValueRange":$operands),
+ [{
+ BlockAndValueMapping bvm;
+ OperationState state(
+ loc, ConcreteOp::getOperationName(), operands, resultTypes,
+ $_op->getAttrs());
+ for (Region &r : $_op->getRegions())
+ r.cloneInto(state.addRegion(), bvm);
+ return b.createOperation(state);
+ }]
>
];
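
For context, a hypothetical use of the new clone method (the call site and
the names b, loc, resultTypes, and newOperands are illustrative, not part
of this change):

    // Rebuild the op with different operands. clone() copies the attributes
    // and regions, so e.g. the comparator region of linalg_ext.sort is
    // preserved across the rebuild.
    Operation *newOp =
        linalgExtOp.clone(b, loc, resultTypes, newOperands);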