Integrate llvm/llvm-project@27ac46e6bea2 (#17662)

Updated to llvm/llvm-project@27ac46e6bea2
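* Replaced uses of MLIR's `MathExtras.h` with LLVM's `MathExtras.h`
* Updated `applySignatureConversion` call sites to pass the region's entry
  block (`&region.front()`) instead of the region itself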

Updated to openxla/stablehlo@dd48ec5
* `chlo.minimum_broadcast_shapes` op was removed
  https://github.com/openxla/stablehlo/pull/2287
* `chlo.dynamic_reshape` op was removed
  https://github.com/openxla/stablehlo/pull/2286
* Passed the new input and scatter-indices batching dims when constructing
  `ScatterDimensionNumbersAttr`
  https://github.com/openxla/stablehlo/pull/2259

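Updated to llvm/torch-mlir@77d7f64
* `FuncConversion.cpp`: convert Torch value tensors to builtin tensors with
  signless integer element types, and reuse the operand of
  `torch_c.from_builtin_tensor` instead of inserting a new
  `torch_c.to_builtin_tensor`
* Removed the `simple_add_torch` case from `auto_input_conversion.mlir`
* Updated `bitcast_quant_tensor.mlir` to expect `ui8` builtin tensors from
  `torch_c.to_builtin_tensor`
* Passed `-dialect=torch` to the Torch dialect tablegen decl/def generators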

---------

Co-authored-by: hanhanW <hanhan0912@gmail.com>
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Quinn Dawkins <quinn@nod-labs.com>
Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
index 432fcf2..8dc2d8d 100644
--- a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
@@ -18,7 +18,6 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "stablehlo/dialect/BroadcastUtils.h"
@@ -443,38 +442,6 @@
   }
 };
 
-struct ConvertDynamicReshapeOp final
-    : OpRewritePattern<mlir::chlo::DynamicReshapeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mlir::chlo::DynamicReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    TypedValue<TensorType> tensor = op.getOperand();
-    TypedValue<RankedTensorType> shape = op.getOutputShape();
-
-    auto shapeTy = cast<ShapedType>(shape.getType());
-    auto resultTy = cast<ShapedType>(op.getType());
-
-    Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
-    Value numEls = rewriter.create<shape::NumElementsOp>(loc, inputShape);
-    Value cstr =
-        rewriter.create<mlir::stablehlo::CstrReshapableOp>(loc, numEls, shape);
-    rewriter.replaceOpWithNewOp<shape::AssumingOp>(
-        op, cstr, [&](OpBuilder &b, Location l) {
-          Value computedShape =
-              b.create<mlir::stablehlo::ComputeReshapeShapeOp>(l, shapeTy,
-                                                               numEls, shape);
-          SmallVector<Value> result;
-          result.push_back(b.create<mlir::stablehlo::DynamicReshapeOp>(
-              l, resultTy, tensor, computedShape));
-          return result;
-        });
-
-    return success();
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Decomposition Patterns.
 //===----------------------------------------------------------------------===//
@@ -2192,7 +2159,6 @@
       ConversionTarget conversionTarget(getContext());
       RewritePatternSet conversionPatterns(ctx);
       conversionTarget.addIllegalDialect<chlo::ChloDialect>();
-      conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
       conversionTarget.addLegalDialect<
           mlir::stablehlo::StablehloDialect, mlir::arith::ArithDialect,
           mlir::shape::ShapeDialect, mlir::scf::SCFDialect,
@@ -2239,9 +2205,7 @@
       context, patterns, 10);
   populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
       context, patterns, 5);
-  patterns
-      ->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
-          context);
+  patterns->add<ConvertConstantLikeOp, ConvertSelectOp>(context);
 }
 
 static void populateDecompositionPatterns(MLIRContext *context,
diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
index da22c9e..4970994 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
@@ -601,7 +601,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdateWindowDims,
-        dimNumbers.getInsertedWindowDims(),
+        dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         dimNumbers.getIndexVectorDim() + 1);
 
@@ -700,7 +701,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdatedWindowDims,
-        dimNumbers.getInsertedWindowDims(),
+        dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         /*indexVectorDim=*/1);
 
@@ -801,7 +803,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdatedWindowDims,
-        dimNumbers.getInsertedWindowDims(),
+        dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         /*indexVectorDim=*/indexVectorDim);
 
@@ -939,6 +942,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdatedWindowDims, newInsertedWindowDims,
+        dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         /*indexVectorDim=*/1);
 
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
index 1103337..0180abb 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
@@ -416,7 +416,7 @@
         return rewriter.notifyMatchFailure(op,
                                            "argument type conversion failed");
       }
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
     Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
index 335805d..c886df8 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
@@ -1653,7 +1653,7 @@
     }
     signatureConverter.addInputs(resultType.getElementType());
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     rewriter.replaceOp(op, linalgOp.getResults());
     return success();
@@ -1706,7 +1706,7 @@
       signatureConverter.addInputs(idx, convertedTy);
     }
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType,
                                                         linalgOp.getResults());
@@ -2073,8 +2073,8 @@
     reduceSignConverter.addInputs(srcETy);
     reduceSignConverter.addInputs(1, destETy);
     reduceSignConverter.addInputs(indexETy);
-    rewriter.applySignatureConversion(&reduceRegion, reduceSignConverter,
-                                      getTypeConverter());
+    rewriter.applySignatureConversion(&reduceRegion.front(),
+                                      reduceSignConverter, getTypeConverter());
 
     // Grab the terminator and use the turned value to now select the
     // correct index and value.
@@ -2179,8 +2179,8 @@
     scatterSignConverter.addInputs(indexETy);
     scatterSignConverter.addInputs(0, sourceTy.getElementType());
     scatterSignConverter.addInputs(1, sourceTy.getElementType());
-    rewriter.applySignatureConversion(&scatterRegion, scatterSignConverter,
-                                      getTypeConverter());
+    rewriter.applySignatureConversion(&scatterRegion.front(),
+                                      scatterSignConverter, getTypeConverter());
 
     auto &scatterBlock = scatterRegion.front();
     auto scatterTerminator = scatterBlock.getTerminator();
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
index d7b2a4e..a20f872 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
@@ -189,7 +189,7 @@
           idx, getTypeConverter()->convertType(
                    getElementTypeOrSelf(argument.getType())));
     }
-    rewriter.applySignatureConversion(&region, signature_converter);
+    rewriter.applySignatureConversion(&region.front(), signature_converter);
 
     rewriter.replaceOp(op, sortOp->getResults());
     return success();
@@ -281,7 +281,7 @@
     // where output[O] maps to block args #1 in linalg_ext.scatter ops.
     signatureConverter.addInputs(1, argType);
     signatureConverter.addInputs(0, argType);
-    rewriter.applySignatureConversion(&region, signatureConverter);
+    rewriter.applySignatureConversion(&region.front(), signatureConverter);
 
     rewriter.replaceOp(op, scatterOp->getResults());
     return success();
@@ -598,7 +598,8 @@
     TypeConverter::SignatureConversion signatureConverter(2);
     signatureConverter.addInputs(0, input0Ty.getElementType());
     signatureConverter.addInputs(1, init0Ty.getElementType());
-    rewriter.applySignatureConversion(&scanOp.getRegion(), signatureConverter);
+    rewriter.applySignatureConversion(&scanOp.getRegion().front(),
+                                      signatureConverter);
 
     rewriter.replaceOp(op, scanOp.getResult(0));
     return success();
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
index 8f43f12..ce8ba53 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
@@ -201,7 +201,7 @@
               cast<ShapedType>(val.getType()).getElementType()));
     }
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     rewriter.replaceOp(op, linalgOp.getResults());
     return success();
@@ -301,7 +301,7 @@
           // type for new operand number 'idx' + linalgOp.getNumInputs()
           typeConverter->convertType(val.getElementType()));
     }
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
 
     // Cast the result to the correct type.
@@ -470,7 +470,7 @@
           i, cast<ShapedType>(input.getType()).getElementType());
     }
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     rewriter.replaceOp(op, linalgOp.getResults());
     return success();
diff --git a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
index eea5386..06725dc 100644
--- a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
@@ -112,7 +112,7 @@
       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
       (void)getTypeConverter()->convertSignatureArgs(
           newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
     Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
index dc3ba5b..d9097e3 100644
--- a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
@@ -66,7 +66,7 @@
       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
       (void)getTypeConverter()->convertSignatureArgs(
           newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
     Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
index b7b78bc..d9541a7 100644
--- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -11,17 +11,14 @@
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "llvm/ADT/SmallPtrSet.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 
 namespace Torch = mlir::torch::Torch;
 namespace TorchConversion = mlir::torch::TorchConversion;
@@ -103,8 +100,18 @@
   if (isa<TensorType>(ty))
     return possibleTorchTensor;
 
+  if (auto defining = dyn_cast_or_null<TorchConversion::FromBuiltinTensorOp>(
+          possibleTorchTensor.getDefiningOp())) {
+    return defining.getOperand();
+  }
+
   Torch::ValueTensorType vtensorType = cast<Torch::ValueTensorType>(ty);
   TensorType builtinTy = vtensorType.toBuiltinTensor();
+  if (auto intTy = dyn_cast<IntegerType>(builtinTy.getElementType())) {
+    builtinTy =
+        builtinTy.clone(builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
+  }
+
   return builder.create<TorchConversion::ToBuiltinTensorOp>(
       possibleTorchTensor.getLoc(), builtinTy, possibleTorchTensor);
 }
@@ -357,6 +364,11 @@
     builtinTensorType = tType;
   } else if (auto vtType = dyn_cast<Torch::ValueTensorType>(torchType)) {
     builtinTensorType = vtType.toBuiltinTensor();
+    if (auto intTy =
+            dyn_cast<IntegerType>(builtinTensorType.getElementType())) {
+      builtinTensorType = builtinTensorType.clone(
+          builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
+    }
   } else {
     return emitError(loc) << "unsupported immutable tensor argument: "
                           << torchType;
diff --git a/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
index 9256519..801b49a 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
@@ -2,15 +2,6 @@
 
 // Check that the auto input conversion pipeline uses this plugin.
 
-// CHECK-LABEL: util.func public @simple_add_torch
-// CHECK:  arith.addf
-func.func @simple_add_torch(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
-  return %0 : !torch.vtensor<[2],f32>
-}
-
-// -----
 
 // CHECK-LABEL: util.func public @simple_add_onnx
 // CHECK:  arith.addi
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
index 1ad40bc..fad4e7c 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
@@ -8,8 +8,8 @@
   %zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
   %bit_width = torch.constant.int 4
   %group_size = torch.constant.int 2
-  // CHECK: %[[TOBUILTIN:.*]] = torch_c.to_builtin_tensor %[[C0]] : !torch.vtensor<[8,4],ui8> -> tensor<8x4xi8>
-  // CHECK: %[[BITCAST:.*]] = flow.tensor.bitcast %[[TOBUILTIN]] : tensor<8x4xi8> -> tensor<8x8xi4>
+  // CHECK: %[[TOBUILTIN:.*]] = torch_c.to_builtin_tensor %[[C0]] : !torch.vtensor<[8,4],ui8> -> tensor<8x4xui8>
+  // CHECK: %[[BITCAST:.*]] = flow.tensor.bitcast %[[TOBUILTIN]] : tensor<8x4xui8> -> tensor<8x8xi4>
   // CHECK: %[[TOTORCH:.*]] = torch_c.from_builtin_tensor %[[BITCAST]] : tensor<8x8xi4> -> !torch.vtensor<[8,8],ui4>
   %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
   return %output : !torch.vtensor<[1,1,8],f16>
diff --git a/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt b/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
index 7a99114..2b96c50 100644
--- a/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
@@ -85,8 +85,8 @@
   TD_FILE
     "${TORCH_MLIR_ROOT_DIR}/include/torch-mlir/Dialect/Torch/IR/TorchOps.td"
   OUTS
-    -gen-dialect-decls Dialect/Torch/IR/TorchDialect.h.inc
-    -gen-dialect-defs Dialect/Torch/IR/TorchDialect.cpp.inc
+    -gen-dialect-decls -dialect=torch Dialect/Torch/IR/TorchDialect.h.inc
+    -gen-dialect-defs -dialect=torch Dialect/Torch/IR/TorchDialect.cpp.inc
     -gen-op-decls Dialect/Torch/IR/TorchOps.h.inc
     -gen-op-defs Dialect/Torch/IR/TorchOps.cpp.inc
 )