Add type propagation for illegal types for LinalgExt sort (#14225)
This is very similar to the LinalgExt scatter pattern, but the differences are significant enough that generalizing the two introduces more pain than it alleviates.
Going to monitor how many ops end up needing a similar change and potentially reevaluate accordingly.
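
For context, a minimal before/after sketch of the rewrite the new pattern performs (illustrative IR, not taken from the patch): an `iree_linalg_ext.sort` whose operand has an illegal i1 element type is recreated with the legalized i8 storage type, and `arith.trunci` ops are inserted at the top of the comparator region to recover the original i1 values. Unlike scatter, each sort operand contributes two comparator arguments (the pair of elements being compared), so both must be converted; the yielded i1 is the comparison predicate, not stored data, and stays as-is.

  // Before: tensor<4xi1> has an illegal i1 element type.
  %0 = iree_linalg_ext.sort dimension(0) outs(%in : tensor<4xi1>) {
  ^bb0(%a: i1, %b: i1):
    %c = arith.cmpi ult, %a, %b : i1
    iree_linalg_ext.yield %c : i1
  } -> tensor<4xi1>

  // After: the op is cloned with i8 storage, and both comparator
  // arguments are truncated back to i1 before the original body runs.
  %1 = iree_linalg_ext.sort dimension(0) outs(%in_i8 : tensor<4xi8>) {
  ^bb0(%a: i8, %b: i8):
    %a0 = arith.trunci %a : i8 to i1
    %b0 = arith.trunci %b : i8 to i1
    %c = arith.cmpi ult, %a0, %b0 : i1
    iree_linalg_ext.yield %c : i1
  } -> tensor<4xi8>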
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 10c7a5d..99789c6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -409,6 +409,71 @@
}
};
+/// Pattern to legalize `iree_linalg_ext.sort` operations.
+struct IREELinalgExtSortTypePropagation
+ : TypePropagationPattern<IREE::LinalgExt::SortOp> {
+ using TypePropagationPattern<IREE::LinalgExt::SortOp>::TypePropagationPattern;
+ LogicalResult
+ matchAndRewrite(IREE::LinalgExt::SortOp sortOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ SmallVector<Type> legalizedResultTypes;
+ for (Type resultType : sortOp->getResultTypes()) {
+ Type legalizedType = this->getTypeConverter()->convertType(resultType);
+ legalizedResultTypes.push_back(legalizedType);
+ }
+
+ // Create a clone of the operation without cloning its regions.
+ auto modifiedOp = cast<IREE::LinalgExt::SortOp>(mlir::cloneWithoutRegions(
+ rewriter, sortOp, {legalizedResultTypes}, adaptor.getOperands()));
+
+ // Inline the region from the original operation into the new operation.
+ rewriter.inlineRegionBefore(sortOp->getRegions().front(),
+ modifiedOp->getRegions().front(),
+ modifiedOp->getRegions().front().begin());
+ Region &modifiedOpRegion = modifiedOp->getRegions().front();
+
+ // Convert the signature of the region to use the corresponding element
+ // type.
+ TypeConverter::SignatureConversion signatureConverter(
+ modifiedOpRegion.getNumArguments());
+ for (auto [index, arg] : llvm::enumerate(modifiedOpRegion.getArguments())) {
+ std::optional<Type> legalizedArgType =
+ legalizeStorageElementType(arg.getType());
+ if (!legalizedArgType) {
+ return sortOp.emitOpError("failed to get legalized type for argument");
+ }
+ signatureConverter.addInputs(index, legalizedArgType.value());
+ }
+ rewriter.applySignatureConversion(&modifiedOpRegion, signatureConverter);
+
+ {
+ // Introduce scalar conversion operations to convert back to the original
+ // scalar type.
+ OpBuilder::InsertionGuard g(rewriter);
+ Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front();
+ for (auto [index, operand] : llvm::enumerate(sortOp->getOpOperands())) {
+ BlockArgument firstInputArg = entryBlock->getArgument(index * 2);
+ BlockArgument secondInputArg = entryBlock->getArgument(index * 2 + 1);
+
+ auto destType = getElementTypeOrSelf(operand.get().getType());
+ rewriter.setInsertionPointToStart(entryBlock);
+ if (destType != getElementTypeOrSelf(legalizedResultTypes[index])) {
+ Value replacementFirstInput = convertElementType(
+ rewriter, firstInputArg.getLoc(), destType, firstInputArg);
+ rewriter.replaceUsesOfBlockArgument(firstInputArg,
+ replacementFirstInput);
+ Value replacementSecondInput = convertElementType(
+ rewriter, secondInputArg.getLoc(), destType, secondInputArg);
+ rewriter.replaceUsesOfBlockArgument(secondInputArg,
+ replacementSecondInput);
+ }
+ }
+ }
+ rewriter.replaceOp(sortOp, modifiedOp->getResults());
+ return success();
+ }
+};
+
/// Simple rewrite pattern that just forwards the source as the result if the
/// result type is not legal (but source type is)
template <typename OpTy>
@@ -497,7 +562,8 @@
NamedOpTypePropagation<linalg::MatmulOp>,
NamedOpTypePropagation<linalg::MatvecOp>,
NamedOpTypePropagation<linalg::DotOp>, TensorExtractTypePropagation,
- IREELinalgExtScatterTypePropagation>(typeConverter, context);
+ IREELinalgExtScatterTypePropagation, IREELinalgExtSortTypePropagation>(
+ typeConverter, context);
ConversionTarget target(*context);
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
index 46d1c3a..336253d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
@@ -399,3 +399,69 @@
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[MIN]] : i1 to i8
// CHECK: iree_linalg_ext.yield %[[EXTUI]]
// CHECK: flow.dispatch.tensor.store %[[SCATTER]], %[[OUT]]
+
+// -----
+
+func.func @sort() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1xi8>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<1xi32>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [1], strides = [1] : !flow.dispatch.tensor<readonly:tensor<1xi8>> -> tensor<1xi8>
+ %3 = arith.trunci %2 : tensor<1xi8> to tensor<1xi1>
+ %4 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [1], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<1xi32>> -> tensor<1xi32>
+ %5:2 = iree_linalg_ext.sort dimension(0) outs(%3, %4 : tensor<1xi1>, tensor<1xi32>) {
+ ^bb0(%arg0: i1, %arg1: i1, %arg2: i32, %arg3: i32):
+ %6 = arith.cmpi ult, %arg0, %arg1 : i1
+ iree_linalg_ext.yield %6 : i1
+ } -> tensor<1xi1>, tensor<1xi32>
+ flow.dispatch.tensor.store %5#1, %1, offsets = [0], sizes = [1], strides = [1] : tensor<1xi32> -> !flow.dispatch.tensor<readwrite:tensor<1xi32>>
+ return
+}
+
+// CHECK-LABEL: func.func @sort()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[A:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[B:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[A_TENSOR:.+]] = flow.dispatch.tensor.load %[[A]]
+// CHECK-DAG: %[[B_TENSOR:.+]] = flow.dispatch.tensor.load %[[B]]
+// CHECK: %[[SORT:.+]]:2 = iree_linalg_ext.sort dimension(0)
+// CHECK-SAME: outs(%[[A_TENSOR]], %[[B_TENSOR]] : tensor<1xi8>, tensor<1xi32>)
+// CHECK-NEXT: ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: i8, %[[ARG1:[a-zA-Z0-9]+]]: i8, %[[ARG2:[a-zA-Z0-9]+]]: i32, %[[ARG3:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[TRUNC_A_1:.+]] = arith.trunci %[[ARG0]] : i8 to i1
+// CHECK-DAG: %[[TRUNC_A_2:.+]] = arith.trunci %[[ARG1]] : i8 to i1
+// CHECK-DAG: %[[CMPI:.+]] = arith.cmpi ult, %[[TRUNC_A_1]], %[[TRUNC_A_2]] : i1
+// CHECK: iree_linalg_ext.yield %[[CMPI]]
+// CHECK: flow.dispatch.tensor.store %[[SORT]]#1, %[[B]]
+
+
+// -----
+
+func.func @sort_secondary() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1xi32>>
+ %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<1xi8>>
+ %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [1], strides = [1] : !flow.dispatch.tensor<readonly:tensor<1xi32>> -> tensor<1xi32>
+ %3 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [1], strides = [1] : !flow.dispatch.tensor<readwrite:tensor<1xi8>> -> tensor<1xi8>
+ %4 = arith.trunci %3 : tensor<1xi8> to tensor<1xi1>
+ %5:2 = iree_linalg_ext.sort dimension(0) outs(%2, %4 : tensor<1xi32>, tensor<1xi1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1, %arg3: i1):
+ %6 = arith.cmpi ult, %arg0, %arg1 : i32
+ iree_linalg_ext.yield %6 : i1
+ } -> tensor<1xi32>, tensor<1xi1>
+ %7 = arith.extui %5#1 : tensor<1xi1> to tensor<1xi8>
+ flow.dispatch.tensor.store %7, %1, offsets = [0], sizes = [1], strides = [1] : tensor<1xi8> -> !flow.dispatch.tensor<readwrite:tensor<1xi8>>
+ return
+}
+
+// CHECK-LABEL: func.func @sort_secondary()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[A:.+]] = hal.interface.binding.subspan set(0) binding(0)
+// CHECK-DAG: %[[B:.+]] = hal.interface.binding.subspan set(0) binding(1)
+// CHECK-DAG: %[[A_TENSOR:.+]] = flow.dispatch.tensor.load %[[A]]
+// CHECK-DAG: %[[B_TENSOR:.+]] = flow.dispatch.tensor.load %[[B]]
+// CHECK: %[[SORT:.+]]:2 = iree_linalg_ext.sort dimension(0)
+// CHECK-SAME: outs(%[[A_TENSOR]], %[[B_TENSOR]] : tensor<1xi32>, tensor<1xi8>)
+// CHECK-NEXT: ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: i32, %[[ARG1:[a-zA-Z0-9]+]]: i32, %[[ARG2:[a-zA-Z0-9]+]]: i8, %[[ARG3:[a-zA-Z0-9]+]]: i8)
+// CHECK-DAG: %[[CMPI:.+]] = arith.cmpi ult, %[[ARG0]], %[[ARG1]] : i32
+// CHECK: iree_linalg_ext.yield %[[CMPI]]
+// CHECK: flow.dispatch.tensor.store %[[SORT]]#1, %[[B]]