Extend linalg_ext lowering to handle signedness and more ops (#6987)
MHLO ops use signed (although written as signless for historical
reasons) and unsigned ops. The upstream MHLO->Linalg already handles
this and we need to handle it here as well.
In addition, this changes the approach to dealing with detensorizing,
relying more heavily on the dialect conversion framework type converter.
This leaves things as unrealized_conversion_cast ops, to be cleaned up
later after lowering the rest of MHLO. We could continue to use "real"
ops for [de]tensoring, but we need unrealized_conversion_cast for
signedness conversions anyway, so I think it's cleaner to use it
everywhere. It's perhaps unfortunate because really the correct cast
there *is* a bitcast, but like all std ops, std.bitcast does not support
[un]signed integers.
Finally, it adds a bunch of ops for lowering within linalg_ext regions.
We should be handling everything that lowers to standard naturally, not
picking the handful of ops we've run into so far.
Fixes https://github.com/google/iree/issues/6154
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index c976b49..7c8e3e9 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -28,6 +28,42 @@
namespace mlir {
namespace iree_compiler {
+namespace {
+
+static Type convertInteger(IntegerType intType) {
+ return IntegerType::get(intType.getContext(),
+ intType.getIntOrFloatBitWidth());
+}
+
+static Optional<Type> convertTensor(TensorType tensorType) {
+ if (!tensorType.hasRank() || tensorType.getRank() != 0) return llvm::None;
+ Type elementType = tensorType.getElementType();
+ if (auto intType = elementType.dyn_cast<IntegerType>()) {
+ elementType = convertInteger(intType);
+ }
+ return elementType;
+}
+
+static Value materializeUnrealizedConversion(OpBuilder &builder, Type type,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, type, inputs[0])
+ ->getResult(0);
+}
+
+class MhloToStdTypeConverter : public TypeConverter {
+ public:
+ MhloToStdTypeConverter() {
+ addConversion([](Type type) { return type; });
+
+ addConversion(convertTensor);
+ addConversion(convertInteger);
+
+ addTargetMaterialization(materializeUnrealizedConversion);
+ addSourceMaterialization(materializeUnrealizedConversion);
+ addArgumentMaterialization(materializeUnrealizedConversion);
+ }
+};
+
//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//
@@ -47,8 +83,6 @@
return ret;
}
-namespace {
-
//===----------------------------------------------------------------------===//
// Region operations lowering.
//===----------------------------------------------------------------------===//
@@ -60,20 +94,13 @@
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (!isInBodyOfLinalgExtOps(op)) return failure();
- if (!op.getResult().getType().template isa<TensorType>()) return failure();
+ TensorType origRetType = op.getType().template dyn_cast<TensorType>();
+ if (!origRetType) return failure();
SmallVector<Value> scalarArgs;
- for (auto arg : args) {
- if (auto ty = arg.getType().template dyn_cast<TensorType>()) {
- assert(ty.hasRank() && ty.getRank() == 0 &&
- "Have non-0D tensors in the region?");
- scalarArgs.push_back(
- rewriter.create<tensor::ExtractOp>(op.getLoc(), arg));
- } else {
- scalarArgs.push_back(arg);
- }
- }
- Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
- op, getElementTypeOrSelf(op.getType()), scalarArgs, &rewriter);
+ Type newRetType = getElementTypeOrSelf(
+ this->typeConverter->convertType(origRetType.getElementType()));
+ Value result =
+ lmhlo::HloOpToStdScalarOp::map<OpTy>(op, newRetType, args, &rewriter);
rewriter.replaceOp(op, result);
return success();
}
@@ -99,12 +126,12 @@
using OpConversionPattern<mhlo::SortOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::SortOp op, ArrayRef<Value> args,
+ mhlo::SortOp mhloSortOp, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
auto sortOp = rewriter.create<linalg_ext::SortOp>(
- op.getLoc(), op.getResultTypes(),
- /*inputs=*/ValueRange{}, args, op.dimensionAttr());
- rewriter.inlineRegionBefore(op.comparator(), sortOp.region(),
+ mhloSortOp.getLoc(), mhloSortOp.getResultTypes(),
+ /*inputs=*/ValueRange{}, args, mhloSortOp.dimensionAttr());
+ rewriter.inlineRegionBefore(mhloSortOp.comparator(), sortOp.region(),
sortOp.region().begin());
Region ®ion = sortOp.region();
Block &block = region.front();
@@ -116,7 +143,7 @@
}
rewriter.applySignatureConversion(®ion, signature_converter);
- rewriter.replaceOp(op, sortOp->getResults());
+ rewriter.replaceOp(mhloSortOp, sortOp->getResults());
return success();
}
};
@@ -391,28 +418,79 @@
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<linalg_ext::LinalgExtDialect, linalg::LinalgDialect,
IREE::Flow::FlowDialect, StandardOpsDialect,
- tensor::TensorDialect>();
+ complex::ComplexDialect, tensor::TensorDialect>();
}
void runOnOperation() override {
OwningRewritePatternList patterns(&getContext());
MLIRContext *context = &getContext();
+ MhloToStdTypeConverter typeConverter;
patterns.insert<SortOpConversion, ScatterOpConversion, FftOpConversion>(
- context);
- patterns.insert<LinalgExtRegionHLOOpConversion<mhlo::CompareOp>,
- LinalgExtRegionHLOOpConversion<mhlo::AddOp>,
- LinalgExtRegionHLOOpConversion<mhlo::SubOp>,
- LinalgExtRegionHLOOpConversion<mhlo::BitcastConvertOp>,
- LinalgExtRegionReturnOpConversion>(context,
- PatternBenefit(1000));
+ typeConverter, context);
+ // FIXME: It shouldn't be necessary to list every matching MHLO op here,
+ // especially since they're already listed in
+ // populateHLOToLinalgConversionPattern and in HloOpToStdScalarOp. These
+ // lists are all the same. Can we leverage SFINAE here?
+ patterns
+ .insert<LinalgExtRegionHLOOpConversion<mhlo::AbsOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::AddOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::AndOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::Atan2Op>,
+ LinalgExtRegionHLOOpConversion<mhlo::BitcastConvertOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::CeilOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ClampOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::CompareOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ComplexOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ConvertOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::CopyOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::CosOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::DivOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ExpOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::Expm1Op>,
+ LinalgExtRegionHLOOpConversion<mhlo::FloorOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ImagOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::IsFiniteOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::LogOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::LogisticOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::Log1pOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::MaxOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::MinOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::MulOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::NegOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::NotOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::OrOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::PowOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::RealOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::RemOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::RsqrtOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::SelectOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ShiftLeftOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ShiftRightArithmeticOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::ShiftRightLogicalOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::SignOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::SinOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::SqrtOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::SubOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::TanhOp>,
+ LinalgExtRegionHLOOpConversion<mhlo::XorOp>,
+ LinalgExtRegionReturnOpConversion>(typeConverter, context);
ConversionTarget target(getContext());
target.addLegalDialect<linalg_ext::LinalgExtDialect, linalg::LinalgDialect,
IREE::Flow::FlowDialect, StandardOpsDialect,
- tensor::TensorDialect>();
+ tensor::TensorDialect, complex::ComplexDialect>();
target.addIllegalOp<mhlo::SortOp, mhlo::ScatterOp, mhlo::FftOp>();
- target.addLegalOp<mhlo::ComplexOp>();
+ // FFT conversion creates complex ops which will be converted by the normal
+ // MHLO lowering, but these should still be converted if present inside
+ // other linalg_ext op regions.
+ target.addDynamicallyLegalOp<mhlo::ComplexOp>(
+ [](mhlo::ComplexOp complexOp) {
+ return !isInBodyOfLinalgExtOps(complexOp);
+ });
+ // We deliberately allow unrealized casts to persist. These should fall away
+ // when the rest of MHLO is converted.
+ target.addLegalOp<UnrealizedConversionCastOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
index 38e602e..2938375 100644
--- a/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
@@ -8,8 +8,9 @@
}) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>) -> (tensor<128xi32>)
return %0 : tensor<128xi32>
}
-// CHECK-LABEL: func @sort_1d
-// CHECK: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-LABEL: func @sort_1d(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: )
// CHECK: %[[SORT:.+]] = linalg_ext.sort
// CHECK-SAME: dimension(0)
// CHECK-SAME: outs(%[[ARG0]] : tensor<128xi32>)
@@ -30,16 +31,17 @@
return %1 : tensor<1x10xi32>
}
-// CHECK-LABEL: func @sort_with_cst
-// CHECK: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-LABEL: func @sort_with_cst(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: )
// CHECK: %[[CST:.+]] = mhlo.constant dense<0> : tensor<i32>
-// CHECK: %{{.+}} = linalg_ext.sort dimension(1) outs(%[[ARG0]] : tensor<1x10xi32>) {
+// CHECK: %[[SORT:.+]] = linalg_ext.sort dimension(1) outs(%[[ARG0]] : tensor<1x10xi32>) {
// CHECK: ^bb0(%[[ARG1:.+]]: i32, %{{.*}}: i32)
-// CHECK: %[[SCALAR:.+]] = tensor.extract %[[CST]][] : tensor<i32>
+// CHECK: %[[SCALAR:.+]] = builtin.unrealized_conversion_cast %[[CST]] : tensor<i32> to i32
// CHECK: %[[RES:.+]] = cmpi slt, %[[ARG1]], %[[SCALAR]] : i32
// CHECK: linalg_ext.yield %[[RES]] : i1
// CHECK: } -> tensor<1x10xi32>
-// CHECK: }
+// CHECK: return %[[SORT]]
// -----
@@ -51,8 +53,9 @@
}) {dimension = 0 : i64, is_stable = false} : (tensor<16x32xi32>) -> (tensor<16x32xi32>)
return %0 : tensor<16x32xi32>
}
-// CHECK-LABEL: func @sort_2d
-// CHECK: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-LABEL: func @sort_2d(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: )
// CHECK: %[[SORT:.+]] = linalg_ext.sort
// CHECK-SAME: dimension(0)
// CHECK-SAME: outs(%[[ARG0]] : tensor<16x32xi32>)
@@ -63,6 +66,89 @@
// -----
+func @sort_unsigned(%arg0: tensor<1x5xf32>) -> tensor<1x5xf32> {
+ %1 = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %2 = "mhlo.bitcast_convert"(%arg1) : (tensor<f32>) -> tensor<ui32>
+ %3 = "mhlo.bitcast_convert"(%arg2) : (tensor<f32>) -> tensor<ui32>
+ %4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<ui32>, tensor<ui32>) -> tensor<i1>
+ "mhlo.return"(%4) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = true} : (tensor<1x5xf32>) -> tensor<1x5xf32>
+ return %1 : tensor<1x5xf32>
+}
+
+// CHECK-LABEL: func @sort_unsigned(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: )
+// CHECK: %[[SORT:.+]] = linalg_ext.sort
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: outs(%[[ARG0]] : tensor<1x5xf32>)
+// CHECK: ^bb0(%[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK: %[[CAST1:.+]] = bitcast %[[ARG1]] : f32 to i32
+// CHECK: %[[CAST2:.+]] = bitcast %[[ARG2]] : f32 to i32
+// CHECK: %[[CMP:.+]] = cmpi ult, %[[CAST1]], %[[CAST2]] : i32
+// CHECK: linalg_ext.yield %[[CMP]]
+// CHECK: return %[[SORT]]
+
+// -----
+
+func @sort_unsigned_external_cst(%arg0: tensor<1x5xf32>) -> tensor<1x5xf32> {
+ %ui32 = mhlo.constant dense<2> : tensor<ui32>
+ %1 = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %2 = "mhlo.bitcast_convert"(%arg1) : (tensor<f32>) -> tensor<ui32>
+ %3 = "mhlo.compare"(%2, %ui32) {comparison_direction = "LT"} : (tensor<ui32>, tensor<ui32>) -> tensor<i1>
+ "mhlo.return"(%3) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = true} : (tensor<1x5xf32>) -> tensor<1x5xf32>
+ return %1 : tensor<1x5xf32>
+}
+
+// CHECK-LABEL: func @sort_unsigned_external_cst(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: )
+// CHECK: %[[UI32:.+]] = mhlo.constant dense<2> : tensor<ui32>
+// CHECK: %[[SORT:.+]] = linalg_ext.sort
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: outs(%[[ARG0]] : tensor<1x5xf32>)
+// CHECK: ^bb0(%[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK: %[[CAST1:.+]] = bitcast %[[ARG1]] : f32 to i32
+// CHECK: %[[CONVERSION_CAST_CST:.+]] = builtin.unrealized_conversion_cast %[[UI32]] : tensor<ui32> to i32
+// CHECK: %[[CMP:.+]] = cmpi ult, %[[CAST1]], %[[CONVERSION_CAST_CST]] : i32
+// CHECK: linalg_ext.yield %[[CMP]]
+// CHECK: return %[[SORT]]
+
+// -----
+
+// For testing that complex within an linalg_ext op gets lowered
+func @sort_with_complex(%arg0: tensor<1x5xf32>, %arg1 : tensor<complex<f32>>) -> tensor<1x5xf32> {
+ %ui32 = mhlo.constant dense<2> : tensor<ui32>
+ %1 = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
+ %2 = "mhlo.complex"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<complex<f32>>
+ %3 = mhlo.add %2, %arg1 : tensor<complex<f32>>
+ %4 = "mhlo.real"(%3) : (tensor<complex<f32>>) -> tensor<f32>
+ %5 = "mhlo.imag"(%3) : (tensor<complex<f32>>) -> tensor<f32>
+ %6 = "mhlo.compare"(%4, %5) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%6) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = true} : (tensor<1x5xf32>) -> tensor<1x5xf32>
+ return %1 : tensor<1x5xf32>
+}
+
+// CHECK-LABEL: func @sort_with_complex(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: )
+// CHECK: %[[SORT:.+]] = linalg_ext.sort
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: outs(%[[ARG0]] : tensor<1x5xf32>)
+// CHECK: ^bb0(%[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
+// CHECK-NOT: mhlo.complex
+// CHECK: %[[CMP:.+]] = cmpf olt, %{{.+}}, %{{.+}} : f32
+// CHECK: linalg_ext.yield %[[CMP]]
+// CHECK: return %[[SORT]]
+
+// -----
+
func @topk(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> (tensor<128xi32>) {
%0:2 = "mhlo.sort"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors