Legalize unsigned edge case for zero-extent `tensor.empty` (#14447)
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp index c85e303..5ac2a39 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToIREEInputDialects.cpp
@@ -336,6 +336,28 @@ } }; +struct TensorEmptyPattern final : OpConversionPattern<tensor::EmptyOp> { + using OpConversionPattern<tensor::EmptyOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto oldType = cast<ShapedType>(op.getType()); + auto newType = getTypeConverter()->convertType(oldType); + if (newType == oldType) + return failure(); + + if (!newType) + return rewriter.notifyMatchFailure(op, "result type conversion failed"); + + rewriter.replaceOpWithNewOp<tensor::EmptyOp>( + op, oldType.getShape(), + getTypeConverter()->convertType(oldType.getElementType()), + op.getDynamicSizes()); + return success(); + } +}; + struct GlobalOpPattern final : OpConversionPattern<ml_program::GlobalOp> { using OpConversionPattern<ml_program::GlobalOp>::OpConversionPattern; @@ -494,7 +516,7 @@ // Structural patterns (functions, cfg, terminators). patterns.add<BuiltinFuncOpPattern>(*typeConverter, context); - patterns.add<GlobalOpPattern>(*typeConverter, context); + patterns.add<GlobalOpPattern, TensorEmptyPattern>(*typeConverter, context); for (StringRef opName : {func::ReturnOp::getOperationName(), func::CallOp::getOperationName(), @@ -560,6 +582,10 @@ return typeConverter->isLegal(op.getType()); }); + target.addDynamicallyLegalOp<tensor::EmptyOp>([&](tensor::EmptyOp op) { + return typeConverter->isLegal(op.getType()); + }); + // Let the rest fall through. target.addLegalDialect<BuiltinDialect>(); target.addLegalDialect<IREE::LinalgExt::IREELinalgExtDialect>();
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_iree_input_dialects.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_iree_input_dialects.mlir index 93e599b..3a51c37 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_iree_input_dialects.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/test/stablehlo_to_iree_input_dialects.mlir
@@ -88,3 +88,13 @@ return %arg0 : tensor<5x6xcomplex<f32>> } } + +// ----- + +// CHECK-LABEL: @empty_zero_extent +func.func public @empty_zero_extent(%arg0: tensor<ui8>, %arg1: tensor<0x4xui32>) -> (tensor<0x4xui32>) { + // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<0x4xi32> + %0 = tensor.empty() : tensor<0x4xui32> + // CHECK: return %[[EMPTY]] + return %0 : tensor<0x4xui32> +}