Add LinalgExt TypePropagation pattern that handles i1 inputs/outputs (#13936)
This pattern functions under several assumptions:
* Input ("Updates") element type matches output element type.
* There is exactly one updates input and one output -- although the
operands are listed as variadic, LinalgExt doesn't actually support
variadic inputs/outputs. The input list must always have length 2
(updates + indices) and the output list must always have length 1.
* We are only handling illegal types for updates and outputs. This is
because an input/output matrix of type i1 (for example) is a reasonable
situation, but an index matrix of an illegal type is an indication that
something else is going wrong (at least, I personally can't think of an
example where an illegal index type would be legitimate).
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 1d82847..2b10f5a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -24,6 +24,8 @@
//
//===---------------------------------------------------------------------===//
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Codegen/Common/CommonPasses.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
@@ -320,6 +322,90 @@
}
};
+/// Pattern to legalize `iree_linalg_ext.scatter` ops with an illegal updates type.
+struct IREELinalgExtScatterTypePropagation
+ : TypePropagationPattern<IREE::LinalgExt::ScatterOp> {
+ using TypePropagationPattern<
+ IREE::LinalgExt::ScatterOp>::TypePropagationPattern;
+ LogicalResult matchAndRewrite(
+ IREE::LinalgExt::ScatterOp scatterOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ auto opOperands = scatterOp->getOpOperands();
+ Type inputType = opOperands[0].get().getType();
+ Type legalizedInputType = this->getTypeConverter()->convertType(inputType);
+
+ if (inputType == legalizedInputType) {
+ return scatterOp.emitOpError(
+ "unexpected all types legal within conversion pattern");
+ }
+
+ Type resultType = opOperands[2].get().getType();
+ Type legalizedResultType =
+ this->getTypeConverter()->convertType(resultType);
+
+ // Clone the op with the adaptor operands and legalized result type, minus regions.
+ auto modifiedOp =
+ cast<IREE::LinalgExt::ScatterOp>(mlir::cloneWithoutRegions(
+ rewriter, scatterOp, {legalizedResultType}, adaptor.getOperands()));
+
+ // Move the body of the original op into the region of the clone.
+ rewriter.inlineRegionBefore(scatterOp->getRegions().front(),
+ modifiedOp->getRegions().front(),
+ modifiedOp->getRegions().front().begin());
+ Region &modifiedOpRegion = modifiedOp->getRegions().front();
+
+ // Rewrite region arguments 0 and 1 to the legalized element type; both are
+ // given the legalized form of argument 0's type.
+ TypeConverter::SignatureConversion signatureConverter(
+ modifiedOpRegion.getNumArguments());
+ Type argType = modifiedOpRegion.getArguments()[0].getType();
+ std::optional<Type> legalizedArgType = legalizeStorageElementType(argType);
+ if (!legalizedArgType) {
+ return scatterOp.emitOpError("failed to get legalized type for argument");
+ }
+ signatureConverter.addInputs(0, legalizedArgType.value());
+ signatureConverter.addInputs(1, legalizedArgType.value());
+ rewriter.applySignatureConversion(&modifiedOpRegion, signatureConverter);
+
+ {
+ // Convert the legalized block arguments back to the original element type
+ // at the start of the block and redirect their existing uses to the result.
+ OpBuilder::InsertionGuard g(rewriter);
+ Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front();
+ BlockArgument inputArg = entryBlock->getArgument(0);
+ BlockArgument outputArg = entryBlock->getArgument(1);
+
+ auto destType = getElementTypeOrSelf(inputType);
+ rewriter.setInsertionPointToStart(entryBlock);
+
+ Value replacementInput =
+ convertElementType(rewriter, inputArg.getLoc(), destType, inputArg);
+ rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(0),
+ replacementInput);
+ Value replacementOutput =
+ convertElementType(rewriter, outputArg.getLoc(), destType, outputArg);
+ rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(1),
+ replacementOutput);
+
+ // The yielded value is always converted to the legalized element type
+ // (unconditionally), and the terminator is re-created with that value.
+ auto yieldOp = entryBlock->getTerminator();
+
+ rewriter.setInsertionPoint(yieldOp);
+ OpOperand *modifiedOpOperand = &yieldOp->getOpOperand(0);
+
+ auto yieldOperand = convertElementType(rewriter, yieldOp->getLoc(),
+ legalizedArgType.value(),
+ modifiedOpOperand->get());
+
+ rewriter.replaceOpWithNewOp<IREE::LinalgExt::YieldOp>(yieldOp,
+ yieldOperand);
+ }
+ rewriter.replaceOp(scatterOp, modifiedOp->getResults());
+ return success();
+ }
+};
+
/// Simple rewrite pattern that just forwards the source as the result if the
/// result type is not legal (but source type is)
template <typename OpTy>
@@ -407,8 +493,8 @@
NamedOpTypePropagation<linalg::BatchMatmulOp>,
NamedOpTypePropagation<linalg::MatmulOp>,
NamedOpTypePropagation<linalg::MatvecOp>,
- NamedOpTypePropagation<linalg::DotOp>, TensorExtractTypePropagation>(
- typeConverter, context);
+ NamedOpTypePropagation<linalg::DotOp>, TensorExtractTypePropagation,
+ IREELinalgExtScatterTypePropagation>(typeConverter, context);
ConversionTarget target(*context);
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
index 312a493..46d1c3a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
@@ -359,3 +359,43 @@
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: return %[[GEMM]]
+
+// -----
+
+func.func @scatter() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8xi8>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<8x1xi32>>
+ %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<3xi8>>
+ %3 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [8], strides = [1] : !flow.dispatch.tensor<readonly:tensor<8xi8>> -> tensor<8xi8>
+ %4 = arith.trunci %3 : tensor<8xi8> to tensor<8xi1>
+ %5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [8, 1], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<8x1xi32>> -> tensor<8x1xi32>
+ %6 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [3], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<3xi8>> -> tensor<3xi8>
+ %7 = arith.trunci %6 : tensor<3xi8> to tensor<3xi1>
+ %8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%4, %5 : tensor<8xi1>, tensor<8x1xi32>) outs(%7 : tensor<3xi1>) {
+ ^bb0(%arg0: i1, %arg1: i1):
+ %10 = arith.minui %arg1, %arg0 : i1
+ iree_linalg_ext.yield %10 : i1
+ } -> tensor<3xi1>
+ %9 = arith.extui %8 : tensor<3xi1> to tensor<3xi8>
+ flow.dispatch.tensor.store %9, %2, offsets = [0], sizes = [3], strides = [1] : tensor<3xi8> -> !flow.dispatch.tensor<readwrite:tensor<3xi8>>
+ return
+}
+
+// CHECK-LABEL: func.func @scatter()
+// CHECK-DAG: %[[UPDATES:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[INDICES:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(2)
+// CHECK-DAG: %[[UPDATES_TENSOR:.+]] = flow.dispatch.tensor.load %[[UPDATES]]
+// CHECK-DAG: %[[INDICES_TENSOR:.+]] = flow.dispatch.tensor.load %[[INDICES]]
+// CHECK-DAG: %[[OUT_TENSOR:.+]] = flow.dispatch.tensor.load %[[OUT]]
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false)
+// CHECK-SAME: ins(%[[UPDATES_TENSOR]], %[[INDICES_TENSOR]] : tensor<8xi8>, tensor<8x1xi32>)
+// CHECK-SAME: outs(%[[OUT_TENSOR]] : tensor<3xi8>)
+// CHECK-NEXT: ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: i8, %[[ARG1:[a-zA-Z0-9]+]]: i8)
+// CHECK-DAG: %[[TRUNC0:.+]] = arith.trunci %[[ARG0]] : i8 to i1
+// CHECK-DAG: %[[TRUNC1:.+]] = arith.trunci %[[ARG1]] : i8 to i1
+// CHECK-DAG: %[[MIN:.+]] = arith.minui %[[TRUNC1]], %[[TRUNC0]] : i1
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[MIN]] : i1 to i8
+// CHECK: iree_linalg_ext.yield %[[EXTUI]]
+// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[OUT]]