Integrate llvm/llvm-project@27ac46e6bea2 (#17662)
Updated to llvm/llvm-project@27ac46e6bea2
* Used LLVM `MathExtras.h` to replace MLIR one
* Updated `applySignatureConversion` usage
Updated to openxla/stablehlo@dd48ec5
* `chlo.minimum_broadcast_shapes` op was removed
https://github.com/openxla/stablehlo/pull/2287
* `chlo.dynamic_reshape` op was removed
https://github.com/openxla/stablehlo/pull/2286
* Added batching dims to scatter dims
https://github.com/openxla/stablehlo/pull/2259
Updated to llvm/torch-mlir@77d7f64
---------
Co-authored-by: hanhanW <hanhan0912@gmail.com>
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Quinn Dawkins <quinn@nod-labs.com>
Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
index 432fcf2..8dc2d8d 100644
--- a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
@@ -18,7 +18,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/BroadcastUtils.h"
@@ -443,38 +442,6 @@
}
};
-struct ConvertDynamicReshapeOp final
- : OpRewritePattern<mlir::chlo::DynamicReshapeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mlir::chlo::DynamicReshapeOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- TypedValue<TensorType> tensor = op.getOperand();
- TypedValue<RankedTensorType> shape = op.getOutputShape();
-
- auto shapeTy = cast<ShapedType>(shape.getType());
- auto resultTy = cast<ShapedType>(op.getType());
-
- Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
- Value numEls = rewriter.create<shape::NumElementsOp>(loc, inputShape);
- Value cstr =
- rewriter.create<mlir::stablehlo::CstrReshapableOp>(loc, numEls, shape);
- rewriter.replaceOpWithNewOp<shape::AssumingOp>(
- op, cstr, [&](OpBuilder &b, Location l) {
- Value computedShape =
- b.create<mlir::stablehlo::ComputeReshapeShapeOp>(l, shapeTy,
- numEls, shape);
- SmallVector<Value> result;
- result.push_back(b.create<mlir::stablehlo::DynamicReshapeOp>(
- l, resultTy, tensor, computedShape));
- return result;
- });
-
- return success();
- }
-};
-
//===----------------------------------------------------------------------===//
// Decomposition Patterns.
//===----------------------------------------------------------------------===//
@@ -2192,7 +2159,6 @@
ConversionTarget conversionTarget(getContext());
RewritePatternSet conversionPatterns(ctx);
conversionTarget.addIllegalDialect<chlo::ChloDialect>();
- conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
conversionTarget.addLegalDialect<
mlir::stablehlo::StablehloDialect, mlir::arith::ArithDialect,
mlir::shape::ShapeDialect, mlir::scf::SCFDialect,
@@ -2239,9 +2205,7 @@
context, patterns, 10);
populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
- patterns
- ->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
- context);
+ patterns->add<ConvertConstantLikeOp, ConvertSelectOp>(context);
}
static void populateDecompositionPatterns(MLIRContext *context,
diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
index da22c9e..4970994 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
@@ -601,7 +601,8 @@
auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdateWindowDims,
- dimNumbers.getInsertedWindowDims(),
+ dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+ dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
dimNumbers.getIndexVectorDim() + 1);
@@ -700,7 +701,8 @@
auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdatedWindowDims,
- dimNumbers.getInsertedWindowDims(),
+ dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+ dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
/*indexVectorDim=*/1);
@@ -801,7 +803,8 @@
auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdatedWindowDims,
- dimNumbers.getInsertedWindowDims(),
+ dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+ dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
/*indexVectorDim=*/indexVectorDim);
@@ -939,6 +942,8 @@
auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdatedWindowDims, newInsertedWindowDims,
+ dimNumbers.getInputBatchingDims(),
+ dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
/*indexVectorDim=*/1);
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
index 1103337..0180abb 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
@@ -416,7 +416,7 @@
return rewriter.notifyMatchFailure(op,
"argument type conversion failed");
}
- rewriter.applySignatureConversion(newRegion, result);
+ rewriter.applySignatureConversion(&newRegion->front(), result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
index 335805d..c886df8 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
@@ -1653,7 +1653,7 @@
}
signatureConverter.addInputs(resultType.getElementType());
- rewriter.applySignatureConversion(&region, signatureConverter,
+ rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
rewriter.replaceOp(op, linalgOp.getResults());
return success();
@@ -1706,7 +1706,7 @@
signatureConverter.addInputs(idx, convertedTy);
}
- rewriter.applySignatureConversion(&region, signatureConverter,
+ rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType,
linalgOp.getResults());
@@ -2073,8 +2073,8 @@
reduceSignConverter.addInputs(srcETy);
reduceSignConverter.addInputs(1, destETy);
reduceSignConverter.addInputs(indexETy);
- rewriter.applySignatureConversion(&reduceRegion, reduceSignConverter,
- getTypeConverter());
+ rewriter.applySignatureConversion(&reduceRegion.front(),
+ reduceSignConverter, getTypeConverter());
// Grab the terminator and use the turned value to now select the
// correct index and value.
@@ -2179,8 +2179,8 @@
scatterSignConverter.addInputs(indexETy);
scatterSignConverter.addInputs(0, sourceTy.getElementType());
scatterSignConverter.addInputs(1, sourceTy.getElementType());
- rewriter.applySignatureConversion(&scatterRegion, scatterSignConverter,
- getTypeConverter());
+ rewriter.applySignatureConversion(&scatterRegion.front(),
+ scatterSignConverter, getTypeConverter());
auto &scatterBlock = scatterRegion.front();
auto scatterTerminator = scatterBlock.getTerminator();
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
index d7b2a4e..a20f872 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
@@ -189,7 +189,7 @@
idx, getTypeConverter()->convertType(
getElementTypeOrSelf(argument.getType())));
}
- rewriter.applySignatureConversion(&region, signature_converter);
+ rewriter.applySignatureConversion(&region.front(), signature_converter);
rewriter.replaceOp(op, sortOp->getResults());
return success();
@@ -281,7 +281,7 @@
// where output[O] maps to block args #1 in linalg_ext.scatter ops.
signatureConverter.addInputs(1, argType);
signatureConverter.addInputs(0, argType);
- rewriter.applySignatureConversion(&region, signatureConverter);
+ rewriter.applySignatureConversion(&region.front(), signatureConverter);
rewriter.replaceOp(op, scatterOp->getResults());
return success();
@@ -598,7 +598,8 @@
TypeConverter::SignatureConversion signatureConverter(2);
signatureConverter.addInputs(0, input0Ty.getElementType());
signatureConverter.addInputs(1, init0Ty.getElementType());
- rewriter.applySignatureConversion(&scanOp.getRegion(), signatureConverter);
+ rewriter.applySignatureConversion(&scanOp.getRegion().front(),
+ signatureConverter);
rewriter.replaceOp(op, scanOp.getResult(0));
return success();
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
index 8f43f12..ce8ba53 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
@@ -201,7 +201,7 @@
cast<ShapedType>(val.getType()).getElementType()));
}
- rewriter.applySignatureConversion(&region, signatureConverter,
+ rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
rewriter.replaceOp(op, linalgOp.getResults());
return success();
@@ -301,7 +301,7 @@
// type for new operand number 'idx' + linalgOp.getNumInputs()
typeConverter->convertType(val.getElementType()));
}
- rewriter.applySignatureConversion(&region, signatureConverter,
+ rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
// Cast the result to the correct type.
@@ -470,7 +470,7 @@
i, cast<ShapedType>(input.getType()).getElementType());
}
- rewriter.applySignatureConversion(&region, signatureConverter,
+ rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
rewriter.replaceOp(op, linalgOp.getResults());
return success();
diff --git a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
index eea5386..06725dc 100644
--- a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
@@ -112,7 +112,7 @@
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
- rewriter.applySignatureConversion(newRegion, result);
+ rewriter.applySignatureConversion(&newRegion->front(), result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
index dc3ba5b..d9097e3 100644
--- a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
@@ -66,7 +66,7 @@
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
- rewriter.applySignatureConversion(newRegion, result);
+ rewriter.applySignatureConversion(&newRegion->front(), result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
index b7b78bc..d9541a7 100644
--- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -11,17 +11,14 @@
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "llvm/ADT/SmallPtrSet.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
namespace Torch = mlir::torch::Torch;
namespace TorchConversion = mlir::torch::TorchConversion;
@@ -103,8 +100,18 @@
if (isa<TensorType>(ty))
return possibleTorchTensor;
+ if (auto defining = dyn_cast_or_null<TorchConversion::FromBuiltinTensorOp>(
+ possibleTorchTensor.getDefiningOp())) {
+ return defining.getOperand();
+ }
+
Torch::ValueTensorType vtensorType = cast<Torch::ValueTensorType>(ty);
TensorType builtinTy = vtensorType.toBuiltinTensor();
+ if (auto intTy = dyn_cast<IntegerType>(builtinTy.getElementType())) {
+ builtinTy =
+ builtinTy.clone(builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
+ }
+
return builder.create<TorchConversion::ToBuiltinTensorOp>(
possibleTorchTensor.getLoc(), builtinTy, possibleTorchTensor);
}
@@ -357,6 +364,11 @@
builtinTensorType = tType;
} else if (auto vtType = dyn_cast<Torch::ValueTensorType>(torchType)) {
builtinTensorType = vtType.toBuiltinTensor();
+ if (auto intTy =
+ dyn_cast<IntegerType>(builtinTensorType.getElementType())) {
+ builtinTensorType = builtinTensorType.clone(
+ builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
+ }
} else {
return emitError(loc) << "unsupported immutable tensor argument: "
<< torchType;
diff --git a/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
index 9256519..801b49a 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
@@ -2,15 +2,6 @@
// Check that the auto input conversion pipeline uses this plugin.
-// CHECK-LABEL: util.func public @simple_add_torch
-// CHECK: arith.addf
-func.func @simple_add_torch(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
- %int1 = torch.constant.int 1
- %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
- return %0 : !torch.vtensor<[2],f32>
-}
-
-// -----
// CHECK-LABEL: util.func public @simple_add_onnx
// CHECK: arith.addi
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
index 1ad40bc..fad4e7c 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
@@ -8,8 +8,8 @@
%zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
%bit_width = torch.constant.int 4
%group_size = torch.constant.int 2
- // CHECK: %[[TOBUILTIN:.*]] = torch_c.to_builtin_tensor %[[C0]] : !torch.vtensor<[8,4],ui8> -> tensor<8x4xi8>
- // CHECK: %[[BITCAST:.*]] = flow.tensor.bitcast %[[TOBUILTIN]] : tensor<8x4xi8> -> tensor<8x8xi4>
+ // CHECK: %[[TOBUILTIN:.*]] = torch_c.to_builtin_tensor %[[C0]] : !torch.vtensor<[8,4],ui8> -> tensor<8x4xui8>
+ // CHECK: %[[BITCAST:.*]] = flow.tensor.bitcast %[[TOBUILTIN]] : tensor<8x4xui8> -> tensor<8x8xi4>
// CHECK: %[[TOTORCH:.*]] = torch_c.from_builtin_tensor %[[BITCAST]] : tensor<8x8xi4> -> !torch.vtensor<[8,8],ui4>
%output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
return %output : !torch.vtensor<[1,1,8],f16>
diff --git a/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt b/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
index 7a99114..2b96c50 100644
--- a/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
@@ -85,8 +85,8 @@
TD_FILE
"${TORCH_MLIR_ROOT_DIR}/include/torch-mlir/Dialect/Torch/IR/TorchOps.td"
OUTS
- -gen-dialect-decls Dialect/Torch/IR/TorchDialect.h.inc
- -gen-dialect-defs Dialect/Torch/IR/TorchDialect.cpp.inc
+ -gen-dialect-decls -dialect=torch Dialect/Torch/IR/TorchDialect.h.inc
+ -gen-dialect-defs -dialect=torch Dialect/Torch/IR/TorchDialect.cpp.inc
-gen-op-decls Dialect/Torch/IR/TorchOps.h.inc
-gen-op-defs Dialect/Torch/IR/TorchOps.cpp.inc
)