Integrate llvm/llvm-project@27ac46e6bea2 (#17662)

Updated to llvm/llvm-project@27ac46e6bea2
* Replaced uses of MLIR's `MathExtras.h` with LLVM's `MathExtras.h`
* Updated `applySignatureConversion` call sites to pass the region's entry block (`&region.front()`) instead of the region

Updated to openxla/stablehlo@dd48ec5
* `chlo.minimum_broadcast_shapes` op was removed
  https://github.com/openxla/stablehlo/pull/2287
* `chlo.dynamic_reshape` op was removed
  https://github.com/openxla/stablehlo/pull/2286
* Scatter dimension numbers gained input and scatter-indices batching dimensions; existing call sites now forward them
  https://github.com/openxla/stablehlo/pull/2259

Updated to llvm/torch-mlir@77d7f64

---------

Co-authored-by: hanhanW <hanhan0912@gmail.com>
Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Quinn Dawkins <quinn@nod-labs.com>
Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com>
diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
index 432fcf2..8dc2d8d 100644
--- a/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
@@ -18,7 +18,6 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "stablehlo/dialect/BroadcastUtils.h"
@@ -443,38 +442,6 @@
   }
 };
 
-struct ConvertDynamicReshapeOp final
-    : OpRewritePattern<mlir::chlo::DynamicReshapeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mlir::chlo::DynamicReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    TypedValue<TensorType> tensor = op.getOperand();
-    TypedValue<RankedTensorType> shape = op.getOutputShape();
-
-    auto shapeTy = cast<ShapedType>(shape.getType());
-    auto resultTy = cast<ShapedType>(op.getType());
-
-    Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
-    Value numEls = rewriter.create<shape::NumElementsOp>(loc, inputShape);
-    Value cstr =
-        rewriter.create<mlir::stablehlo::CstrReshapableOp>(loc, numEls, shape);
-    rewriter.replaceOpWithNewOp<shape::AssumingOp>(
-        op, cstr, [&](OpBuilder &b, Location l) {
-          Value computedShape =
-              b.create<mlir::stablehlo::ComputeReshapeShapeOp>(l, shapeTy,
-                                                               numEls, shape);
-          SmallVector<Value> result;
-          result.push_back(b.create<mlir::stablehlo::DynamicReshapeOp>(
-              l, resultTy, tensor, computedShape));
-          return result;
-        });
-
-    return success();
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Decomposition Patterns.
 //===----------------------------------------------------------------------===//
@@ -2192,7 +2159,6 @@
       ConversionTarget conversionTarget(getContext());
       RewritePatternSet conversionPatterns(ctx);
       conversionTarget.addIllegalDialect<chlo::ChloDialect>();
-      conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
       conversionTarget.addLegalDialect<
           mlir::stablehlo::StablehloDialect, mlir::arith::ArithDialect,
           mlir::shape::ShapeDialect, mlir::scf::SCFDialect,
@@ -2239,9 +2205,7 @@
       context, patterns, 10);
   populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
       context, patterns, 5);
-  patterns
-      ->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
-          context);
+  patterns->add<ConvertConstantLikeOp, ConvertSelectOp>(context);
 }
 
 static void populateDecompositionPatterns(MLIRContext *context,
diff --git a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
index da22c9e..4970994 100644
--- a/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/Preprocessing/StableHLOToStableHLO.cpp
@@ -601,7 +601,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdateWindowDims,
-        dimNumbers.getInsertedWindowDims(),
+        dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         dimNumbers.getIndexVectorDim() + 1);
 
@@ -700,7 +701,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdatedWindowDims,
-        dimNumbers.getInsertedWindowDims(),
+        dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         /*indexVectorDim=*/1);
 
@@ -801,7 +803,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdatedWindowDims,
-        dimNumbers.getInsertedWindowDims(),
+        dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         /*indexVectorDim=*/indexVectorDim);
 
@@ -939,6 +942,8 @@
 
     auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
         op.getContext(), newUpdatedWindowDims, newInsertedWindowDims,
+        dimNumbers.getInputBatchingDims(),
+        dimNumbers.getScatterIndicesBatchingDims(),
         dimNumbers.getScatterDimsToOperandDims(),
         /*indexVectorDim=*/1);
 
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
index 1103337..0180abb 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
@@ -416,7 +416,7 @@
         return rewriter.notifyMatchFailure(op,
                                            "argument type conversion failed");
       }
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
     Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
index 335805d..c886df8 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalg.cpp
@@ -1653,7 +1653,7 @@
     }
     signatureConverter.addInputs(resultType.getElementType());
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     rewriter.replaceOp(op, linalgOp.getResults());
     return success();
@@ -1706,7 +1706,7 @@
       signatureConverter.addInputs(idx, convertedTy);
     }
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType,
                                                         linalgOp.getResults());
@@ -2073,8 +2073,8 @@
     reduceSignConverter.addInputs(srcETy);
     reduceSignConverter.addInputs(1, destETy);
     reduceSignConverter.addInputs(indexETy);
-    rewriter.applySignatureConversion(&reduceRegion, reduceSignConverter,
-                                      getTypeConverter());
+    rewriter.applySignatureConversion(&reduceRegion.front(),
+                                      reduceSignConverter, getTypeConverter());
 
     // Grab the terminator and use the turned value to now select the
     // correct index and value.
@@ -2179,8 +2179,8 @@
     scatterSignConverter.addInputs(indexETy);
     scatterSignConverter.addInputs(0, sourceTy.getElementType());
     scatterSignConverter.addInputs(1, sourceTy.getElementType());
-    rewriter.applySignatureConversion(&scatterRegion, scatterSignConverter,
-                                      getTypeConverter());
+    rewriter.applySignatureConversion(&scatterRegion.front(),
+                                      scatterSignConverter, getTypeConverter());
 
     auto &scatterBlock = scatterRegion.front();
     auto scatterTerminator = scatterBlock.getTerminator();
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
index d7b2a4e..a20f872 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
@@ -189,7 +189,7 @@
           idx, getTypeConverter()->convertType(
                    getElementTypeOrSelf(argument.getType())));
     }
-    rewriter.applySignatureConversion(&region, signature_converter);
+    rewriter.applySignatureConversion(&region.front(), signature_converter);
 
     rewriter.replaceOp(op, sortOp->getResults());
     return success();
@@ -281,7 +281,7 @@
     // where output[O] maps to block args #1 in linalg_ext.scatter ops.
     signatureConverter.addInputs(1, argType);
     signatureConverter.addInputs(0, argType);
-    rewriter.applySignatureConversion(&region, signatureConverter);
+    rewriter.applySignatureConversion(&region.front(), signatureConverter);
 
     rewriter.replaceOp(op, scatterOp->getResults());
     return success();
@@ -598,7 +598,8 @@
     TypeConverter::SignatureConversion signatureConverter(2);
     signatureConverter.addInputs(0, input0Ty.getElementType());
     signatureConverter.addInputs(1, init0Ty.getElementType());
-    rewriter.applySignatureConversion(&scanOp.getRegion(), signatureConverter);
+    rewriter.applySignatureConversion(&scanOp.getRegion().front(),
+                                      signatureConverter);
 
     rewriter.replaceOp(op, scanOp.getResult(0));
     return success();
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
index 8f43f12..ce8ba53 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgReduce.cpp
@@ -201,7 +201,7 @@
               cast<ShapedType>(val.getType()).getElementType()));
     }
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     rewriter.replaceOp(op, linalgOp.getResults());
     return success();
@@ -301,7 +301,7 @@
           // type for new operand number 'idx' + linalgOp.getNumInputs()
           typeConverter->convertType(val.getElementType()));
     }
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
 
     // Cast the result to the correct type.
@@ -470,7 +470,7 @@
           i, cast<ShapedType>(input.getType()).getElementType());
     }
 
-    rewriter.applySignatureConversion(&region, signatureConverter,
+    rewriter.applySignatureConversion(&region.front(), signatureConverter,
                                       getTypeConverter());
     rewriter.replaceOp(op, linalgOp.getResults());
     return success();
diff --git a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
index eea5386..06725dc 100644
--- a/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
@@ -112,7 +112,7 @@
       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
       (void)getTypeConverter()->convertSignatureArgs(
           newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
     Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
index dc3ba5b..d9097e3 100644
--- a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
@@ -66,7 +66,7 @@
       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
       (void)getTypeConverter()->convertSignatureArgs(
           newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
     Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
index b7b78bc..d9541a7 100644
--- a/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
@@ -11,17 +11,14 @@
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
-#include "llvm/ADT/SmallPtrSet.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 
 namespace Torch = mlir::torch::Torch;
 namespace TorchConversion = mlir::torch::TorchConversion;
@@ -103,8 +100,18 @@
   if (isa<TensorType>(ty))
     return possibleTorchTensor;
 
+  if (auto defining = dyn_cast_or_null<TorchConversion::FromBuiltinTensorOp>(
+          possibleTorchTensor.getDefiningOp())) {
+    return defining.getOperand();
+  }
+
   Torch::ValueTensorType vtensorType = cast<Torch::ValueTensorType>(ty);
   TensorType builtinTy = vtensorType.toBuiltinTensor();
+  if (auto intTy = dyn_cast<IntegerType>(builtinTy.getElementType())) {
+    builtinTy =
+        builtinTy.clone(builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
+  }
+
   return builder.create<TorchConversion::ToBuiltinTensorOp>(
       possibleTorchTensor.getLoc(), builtinTy, possibleTorchTensor);
 }
@@ -357,6 +364,11 @@
     builtinTensorType = tType;
   } else if (auto vtType = dyn_cast<Torch::ValueTensorType>(torchType)) {
     builtinTensorType = vtType.toBuiltinTensor();
+    if (auto intTy =
+            dyn_cast<IntegerType>(builtinTensorType.getElementType())) {
+      builtinTensorType = builtinTensorType.clone(
+          builder.getIntegerType(intTy.getIntOrFloatBitWidth()));
+    }
   } else {
     return emitError(loc) << "unsupported immutable tensor argument: "
                           << torchType;
diff --git a/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir b/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
index 9256519..801b49a 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/auto_input_conversion.mlir
@@ -2,15 +2,6 @@
 
 // Check that the auto input conversion pipeline uses this plugin.
 
-// CHECK-LABEL: util.func public @simple_add_torch
-// CHECK:  arith.addf
-func.func @simple_add_torch(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32>
-  return %0 : !torch.vtensor<[2],f32>
-}
-
-// -----
 
 // CHECK-LABEL: util.func public @simple_add_onnx
 // CHECK:  arith.addi
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
index 1ad40bc..fad4e7c 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/bitcast_quant_tensor.mlir
@@ -8,8 +8,8 @@
   %zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
   %bit_width = torch.constant.int 4
   %group_size = torch.constant.int 2
-  // CHECK: %[[TOBUILTIN:.*]] = torch_c.to_builtin_tensor %[[C0]] : !torch.vtensor<[8,4],ui8> -> tensor<8x4xi8>
-  // CHECK: %[[BITCAST:.*]] = flow.tensor.bitcast %[[TOBUILTIN]] : tensor<8x4xi8> -> tensor<8x8xi4>
+  // CHECK: %[[TOBUILTIN:.*]] = torch_c.to_builtin_tensor %[[C0]] : !torch.vtensor<[8,4],ui8> -> tensor<8x4xui8>
+  // CHECK: %[[BITCAST:.*]] = flow.tensor.bitcast %[[TOBUILTIN]] : tensor<8x4xui8> -> tensor<8x8xi4>
   // CHECK: %[[TOTORCH:.*]] = torch_c.from_builtin_tensor %[[BITCAST]] : tensor<8x8xi4> -> !torch.vtensor<[8,8],ui4>
   %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
   return %output : !torch.vtensor<[1,1,8],f16>
diff --git a/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt b/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
index 7a99114..2b96c50 100644
--- a/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
@@ -85,8 +85,8 @@
   TD_FILE
     "${TORCH_MLIR_ROOT_DIR}/include/torch-mlir/Dialect/Torch/IR/TorchOps.td"
   OUTS
-    -gen-dialect-decls Dialect/Torch/IR/TorchDialect.h.inc
-    -gen-dialect-defs Dialect/Torch/IR/TorchDialect.cpp.inc
+    -gen-dialect-decls -dialect=torch Dialect/Torch/IR/TorchDialect.h.inc
+    -gen-dialect-defs -dialect=torch Dialect/Torch/IR/TorchDialect.cpp.inc
     -gen-op-decls Dialect/Torch/IR/TorchOps.h.inc
     -gen-op-defs Dialect/Torch/IR/TorchOps.cpp.inc
 )
diff --git a/compiler/pyproject.toml b/compiler/pyproject.toml
index 527e76b..a7ed98f 100644
--- a/compiler/pyproject.toml
+++ b/compiler/pyproject.toml
@@ -13,5 +13,6 @@
     "packaging",
     "pybind11>=2.10.1",
     "PyYAML",
+    "sympy",
 ]
 build-backend = "setuptools.build_meta"
diff --git a/compiler/setup.py b/compiler/setup.py
index 6c4ac22..393786e 100644
--- a/compiler/setup.py
+++ b/compiler/setup.py
@@ -464,6 +464,7 @@
     install_requires=[
         "numpy",
         "PyYAML",
+        "sympy",
     ],
     extras_require={
         "onnx": [
diff --git a/compiler/src/iree/compiler/API/BUILD.bazel b/compiler/src/iree/compiler/API/BUILD.bazel
index 374982b..3076404 100644
--- a/compiler/src/iree/compiler/API/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/BUILD.bazel
@@ -43,6 +43,7 @@
         "@llvm-project//mlir:CAPIPDL",
         "@llvm-project//mlir:CAPITransformDialect",
         "@llvm-project//mlir:CAPITransformDialectTransforms",
+        "@llvm-project//mlir:CAPITransforms",
     ],
 )
 
diff --git a/compiler/src/iree/compiler/API/CMakeLists.txt b/compiler/src/iree/compiler/API/CMakeLists.txt
index 1ff2e9e..f404793 100644
--- a/compiler/src/iree/compiler/API/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/CMakeLists.txt
@@ -23,6 +23,7 @@
     MLIRCAPIPDL
     MLIRCAPITransformDialect
     MLIRCAPITransformDialectTransforms
+    MLIRCAPITransforms
     iree::compiler::API::Internal::CompilerDriver
     iree::compiler::API::Internal::IREECompileToolEntryPoint
     iree::compiler::API::Internal::IREEMLIRLSPServerToolEntryPoint
@@ -70,6 +71,7 @@
   obj.MLIRCAPIGPU
   obj.MLIRCAPILinalg
   obj.MLIRCAPIPDL
+  obj.MLIRCAPITransforms
   obj.MLIRCAPITransformDialect
   obj.MLIRCAPITransformDialectTransforms
   iree_compiler_API_Internal_CompilerDriver.objects
diff --git a/compiler/src/iree/compiler/API/api_exports.c b/compiler/src/iree/compiler/API/api_exports.c
index 09d8fa3..4904c65 100644
--- a/compiler/src/iree/compiler/API/api_exports.c
+++ b/compiler/src/iree/compiler/API/api_exports.c
@@ -126,6 +126,7 @@
 extern void mlirAffineMulExprGet();
 extern void mlirAffineSymbolExprGet();
 extern void mlirAffineSymbolExprGetPosition();
+extern void mlirApplyPatternsAndFoldGreedily();
 extern void mlirArrayAttrGet();
 extern void mlirArrayAttrGetElement();
 extern void mlirArrayAttrGetNumElements();
@@ -158,6 +159,7 @@
 extern void mlirAttributeIsAElements();
 extern void mlirAttributeIsAFlatSymbolRef();
 extern void mlirAttributeIsAFloat();
+extern void mlirAttributeIsAGPUObjectAttr();
 extern void mlirAttributeIsAInteger();
 extern void mlirAttributeIsAIntegerSet();
 extern void mlirAttributeIsALocation();
@@ -181,6 +183,7 @@
 extern void mlirBlockDestroy();
 extern void mlirBlockDetach();
 extern void mlirBlockEqual();
+extern void mlirBlockEraseArgument();
 extern void mlirBlockGetArgument();
 extern void mlirBlockGetFirstOperation();
 extern void mlirBlockGetNextInRegion();
@@ -349,12 +352,20 @@
 extern void mlirFloatAttrGetValueDouble();
 extern void mlirFloatTF32TypeGetTypeID();
 extern void mlirFloatTypeGetWidth();
+extern void mlirFreezeRewritePattern();
+extern void mlirFrozenRewritePatternSetDestroy();
 extern void mlirFunctionTypeGet();
 extern void mlirFunctionTypeGetInput();
 extern void mlirFunctionTypeGetNumInputs();
 extern void mlirFunctionTypeGetNumResults();
 extern void mlirFunctionTypeGetResult();
 extern void mlirFunctionTypeGetTypeID();
+extern void mlirGPUObjectAttrGet();
+extern void mlirGPUObjectAttrGetFormat();
+extern void mlirGPUObjectAttrGetObject();
+extern void mlirGPUObjectAttrGetProperties();
+extern void mlirGPUObjectAttrGetTarget();
+extern void mlirGPUObjectAttrHasProperties();
 extern void mlirGetDialectHandle__iree_input__();
 extern void mlirGetDialectHandle__transform__();
 extern void mlirIREELinalgTransformRegisterPasses();
@@ -399,6 +410,7 @@
 extern void mlirIntegerTypeIsUnsigned();
 extern void mlirIntegerTypeSignedGet();
 extern void mlirIntegerTypeUnsignedGet();
+extern void mlirIsCurrentDebugType();
 extern void mlirIsGlobalDebugEnabled();
 extern void mlirLinalgFillBuiltinNamedOpRegion();
 extern void mlirLlvmThreadPoolCreate();
@@ -422,6 +434,7 @@
 extern void mlirMemRefTypeGetMemorySpace();
 extern void mlirMemRefTypeGetStridesAndOffset();
 extern void mlirMemRefTypeGetTypeID();
+extern void mlirMergeSymbolsIntoFromClone();
 extern void mlirModuleCreateEmpty();
 extern void mlirModuleCreateParse();
 extern void mlirModuleDestroy();
@@ -516,6 +529,8 @@
 extern void mlirOperationWriteBytecodeWithConfig();
 extern void mlirPDLAttributeTypeGet();
 extern void mlirPDLOperationTypeGet();
+extern void mlirPDLPatternModuleDestroy();
+extern void mlirPDLPatternModuleFromModule();
 extern void mlirPDLRangeTypeGet();
 extern void mlirPDLRangeTypeGetElementType();
 extern void mlirPDLTypeTypeGet();
@@ -547,6 +562,9 @@
 extern void mlirRegionTakeBody();
 extern void mlirRegisterGPUPasses();
 extern void mlirRegisterLinalgPasses();
+extern void mlirRewritePatternSetFromPDLPatternModule();
+extern void mlirSetGlobalDebugType();
+extern void mlirSetGlobalDebugTypes();
 extern void mlirShapedTypeGetDimSize();
 extern void mlirShapedTypeGetDynamicSize();
 extern void mlirShapedTypeGetDynamicStrideOrOffset();
@@ -822,6 +840,7 @@
   x += (uintptr_t)&mlirAffineMulExprGet;
   x += (uintptr_t)&mlirAffineSymbolExprGet;
   x += (uintptr_t)&mlirAffineSymbolExprGetPosition;
+  x += (uintptr_t)&mlirApplyPatternsAndFoldGreedily;
   x += (uintptr_t)&mlirArrayAttrGet;
   x += (uintptr_t)&mlirArrayAttrGetElement;
   x += (uintptr_t)&mlirArrayAttrGetNumElements;
@@ -854,6 +873,7 @@
   x += (uintptr_t)&mlirAttributeIsAElements;
   x += (uintptr_t)&mlirAttributeIsAFlatSymbolRef;
   x += (uintptr_t)&mlirAttributeIsAFloat;
+  x += (uintptr_t)&mlirAttributeIsAGPUObjectAttr;
   x += (uintptr_t)&mlirAttributeIsAInteger;
   x += (uintptr_t)&mlirAttributeIsAIntegerSet;
   x += (uintptr_t)&mlirAttributeIsALocation;
@@ -877,6 +897,7 @@
   x += (uintptr_t)&mlirBlockDestroy;
   x += (uintptr_t)&mlirBlockDetach;
   x += (uintptr_t)&mlirBlockEqual;
+  x += (uintptr_t)&mlirBlockEraseArgument;
   x += (uintptr_t)&mlirBlockGetArgument;
   x += (uintptr_t)&mlirBlockGetFirstOperation;
   x += (uintptr_t)&mlirBlockGetNextInRegion;
@@ -1045,12 +1066,20 @@
   x += (uintptr_t)&mlirFloatAttrGetValueDouble;
   x += (uintptr_t)&mlirFloatTF32TypeGetTypeID;
   x += (uintptr_t)&mlirFloatTypeGetWidth;
+  x += (uintptr_t)&mlirFreezeRewritePattern;
+  x += (uintptr_t)&mlirFrozenRewritePatternSetDestroy;
   x += (uintptr_t)&mlirFunctionTypeGet;
   x += (uintptr_t)&mlirFunctionTypeGetInput;
   x += (uintptr_t)&mlirFunctionTypeGetNumInputs;
   x += (uintptr_t)&mlirFunctionTypeGetNumResults;
   x += (uintptr_t)&mlirFunctionTypeGetResult;
   x += (uintptr_t)&mlirFunctionTypeGetTypeID;
+  x += (uintptr_t)&mlirGPUObjectAttrGet;
+  x += (uintptr_t)&mlirGPUObjectAttrGetFormat;
+  x += (uintptr_t)&mlirGPUObjectAttrGetObject;
+  x += (uintptr_t)&mlirGPUObjectAttrGetProperties;
+  x += (uintptr_t)&mlirGPUObjectAttrGetTarget;
+  x += (uintptr_t)&mlirGPUObjectAttrHasProperties;
   x += (uintptr_t)&mlirGetDialectHandle__iree_input__;
   x += (uintptr_t)&mlirGetDialectHandle__transform__;
   x += (uintptr_t)&mlirIREELinalgTransformRegisterPasses;
@@ -1095,6 +1124,7 @@
   x += (uintptr_t)&mlirIntegerTypeIsUnsigned;
   x += (uintptr_t)&mlirIntegerTypeSignedGet;
   x += (uintptr_t)&mlirIntegerTypeUnsignedGet;
+  x += (uintptr_t)&mlirIsCurrentDebugType;
   x += (uintptr_t)&mlirIsGlobalDebugEnabled;
   x += (uintptr_t)&mlirLinalgFillBuiltinNamedOpRegion;
   x += (uintptr_t)&mlirLlvmThreadPoolCreate;
@@ -1118,6 +1148,7 @@
   x += (uintptr_t)&mlirMemRefTypeGetMemorySpace;
   x += (uintptr_t)&mlirMemRefTypeGetStridesAndOffset;
   x += (uintptr_t)&mlirMemRefTypeGetTypeID;
+  x += (uintptr_t)&mlirMergeSymbolsIntoFromClone;
   x += (uintptr_t)&mlirModuleCreateEmpty;
   x += (uintptr_t)&mlirModuleCreateParse;
   x += (uintptr_t)&mlirModuleDestroy;
@@ -1212,6 +1243,8 @@
   x += (uintptr_t)&mlirOperationWriteBytecodeWithConfig;
   x += (uintptr_t)&mlirPDLAttributeTypeGet;
   x += (uintptr_t)&mlirPDLOperationTypeGet;
+  x += (uintptr_t)&mlirPDLPatternModuleDestroy;
+  x += (uintptr_t)&mlirPDLPatternModuleFromModule;
   x += (uintptr_t)&mlirPDLRangeTypeGet;
   x += (uintptr_t)&mlirPDLRangeTypeGetElementType;
   x += (uintptr_t)&mlirPDLTypeTypeGet;
@@ -1243,6 +1276,9 @@
   x += (uintptr_t)&mlirRegionTakeBody;
   x += (uintptr_t)&mlirRegisterGPUPasses;
   x += (uintptr_t)&mlirRegisterLinalgPasses;
+  x += (uintptr_t)&mlirRewritePatternSetFromPDLPatternModule;
+  x += (uintptr_t)&mlirSetGlobalDebugType;
+  x += (uintptr_t)&mlirSetGlobalDebugTypes;
   x += (uintptr_t)&mlirShapedTypeGetDimSize;
   x += (uintptr_t)&mlirShapedTypeGetDynamicSize;
   x += (uintptr_t)&mlirShapedTypeGetDynamicStrideOrOffset;
diff --git a/compiler/src/iree/compiler/API/api_exports.def b/compiler/src/iree/compiler/API/api_exports.def
index 9858b0a..0c104a7 100644
--- a/compiler/src/iree/compiler/API/api_exports.def
+++ b/compiler/src/iree/compiler/API/api_exports.def
@@ -118,6 +118,7 @@
   mlirAffineMulExprGet
   mlirAffineSymbolExprGet
   mlirAffineSymbolExprGetPosition
+  mlirApplyPatternsAndFoldGreedily
   mlirArrayAttrGet
   mlirArrayAttrGetElement
   mlirArrayAttrGetNumElements
@@ -150,6 +151,7 @@
   mlirAttributeIsAElements
   mlirAttributeIsAFlatSymbolRef
   mlirAttributeIsAFloat
+  mlirAttributeIsAGPUObjectAttr
   mlirAttributeIsAInteger
   mlirAttributeIsAIntegerSet
   mlirAttributeIsALocation
@@ -173,6 +175,7 @@
   mlirBlockDestroy
   mlirBlockDetach
   mlirBlockEqual
+  mlirBlockEraseArgument
   mlirBlockGetArgument
   mlirBlockGetFirstOperation
   mlirBlockGetNextInRegion
@@ -341,12 +344,20 @@
   mlirFloatAttrGetValueDouble
   mlirFloatTF32TypeGetTypeID
   mlirFloatTypeGetWidth
+  mlirFreezeRewritePattern
+  mlirFrozenRewritePatternSetDestroy
   mlirFunctionTypeGet
   mlirFunctionTypeGetInput
   mlirFunctionTypeGetNumInputs
   mlirFunctionTypeGetNumResults
   mlirFunctionTypeGetResult
   mlirFunctionTypeGetTypeID
+  mlirGPUObjectAttrGet
+  mlirGPUObjectAttrGetFormat
+  mlirGPUObjectAttrGetObject
+  mlirGPUObjectAttrGetProperties
+  mlirGPUObjectAttrGetTarget
+  mlirGPUObjectAttrHasProperties
   mlirGetDialectHandle__iree_input__
   mlirGetDialectHandle__transform__
   mlirIREELinalgTransformRegisterPasses
@@ -391,6 +402,7 @@
   mlirIntegerTypeIsUnsigned
   mlirIntegerTypeSignedGet
   mlirIntegerTypeUnsignedGet
+  mlirIsCurrentDebugType
   mlirIsGlobalDebugEnabled
   mlirLinalgFillBuiltinNamedOpRegion
   mlirLlvmThreadPoolCreate
@@ -414,6 +426,7 @@
   mlirMemRefTypeGetMemorySpace
   mlirMemRefTypeGetStridesAndOffset
   mlirMemRefTypeGetTypeID
+  mlirMergeSymbolsIntoFromClone
   mlirModuleCreateEmpty
   mlirModuleCreateParse
   mlirModuleDestroy
@@ -508,6 +521,8 @@
   mlirOperationWriteBytecodeWithConfig
   mlirPDLAttributeTypeGet
   mlirPDLOperationTypeGet
+  mlirPDLPatternModuleDestroy
+  mlirPDLPatternModuleFromModule
   mlirPDLRangeTypeGet
   mlirPDLRangeTypeGetElementType
   mlirPDLTypeTypeGet
@@ -539,6 +554,9 @@
   mlirRegionTakeBody
   mlirRegisterGPUPasses
   mlirRegisterLinalgPasses
+  mlirRewritePatternSetFromPDLPatternModule
+  mlirSetGlobalDebugType
+  mlirSetGlobalDebugTypes
   mlirShapedTypeGetDimSize
   mlirShapedTypeGetDynamicSize
   mlirShapedTypeGetDynamicStrideOrOffset
diff --git a/compiler/src/iree/compiler/API/api_exports.ld b/compiler/src/iree/compiler/API/api_exports.ld
index 48d6b32..f810a0b 100644
--- a/compiler/src/iree/compiler/API/api_exports.ld
+++ b/compiler/src/iree/compiler/API/api_exports.ld
@@ -119,6 +119,7 @@
     mlirAffineMulExprGet;
     mlirAffineSymbolExprGet;
     mlirAffineSymbolExprGetPosition;
+    mlirApplyPatternsAndFoldGreedily;
     mlirArrayAttrGet;
     mlirArrayAttrGetElement;
     mlirArrayAttrGetNumElements;
@@ -151,6 +152,7 @@
     mlirAttributeIsAElements;
     mlirAttributeIsAFlatSymbolRef;
     mlirAttributeIsAFloat;
+    mlirAttributeIsAGPUObjectAttr;
     mlirAttributeIsAInteger;
     mlirAttributeIsAIntegerSet;
     mlirAttributeIsALocation;
@@ -174,6 +176,7 @@
     mlirBlockDestroy;
     mlirBlockDetach;
     mlirBlockEqual;
+    mlirBlockEraseArgument;
     mlirBlockGetArgument;
     mlirBlockGetFirstOperation;
     mlirBlockGetNextInRegion;
@@ -342,12 +345,20 @@
     mlirFloatAttrGetValueDouble;
     mlirFloatTF32TypeGetTypeID;
     mlirFloatTypeGetWidth;
+    mlirFreezeRewritePattern;
+    mlirFrozenRewritePatternSetDestroy;
     mlirFunctionTypeGet;
     mlirFunctionTypeGetInput;
     mlirFunctionTypeGetNumInputs;
     mlirFunctionTypeGetNumResults;
     mlirFunctionTypeGetResult;
     mlirFunctionTypeGetTypeID;
+    mlirGPUObjectAttrGet;
+    mlirGPUObjectAttrGetFormat;
+    mlirGPUObjectAttrGetObject;
+    mlirGPUObjectAttrGetProperties;
+    mlirGPUObjectAttrGetTarget;
+    mlirGPUObjectAttrHasProperties;
     mlirGetDialectHandle__iree_input__;
     mlirGetDialectHandle__transform__;
     mlirIREELinalgTransformRegisterPasses;
@@ -392,6 +403,7 @@
     mlirIntegerTypeIsUnsigned;
     mlirIntegerTypeSignedGet;
     mlirIntegerTypeUnsignedGet;
+    mlirIsCurrentDebugType;
     mlirIsGlobalDebugEnabled;
     mlirLinalgFillBuiltinNamedOpRegion;
     mlirLlvmThreadPoolCreate;
@@ -415,6 +427,7 @@
     mlirMemRefTypeGetMemorySpace;
     mlirMemRefTypeGetStridesAndOffset;
     mlirMemRefTypeGetTypeID;
+    mlirMergeSymbolsIntoFromClone;
     mlirModuleCreateEmpty;
     mlirModuleCreateParse;
     mlirModuleDestroy;
@@ -509,6 +522,8 @@
     mlirOperationWriteBytecodeWithConfig;
     mlirPDLAttributeTypeGet;
     mlirPDLOperationTypeGet;
+    mlirPDLPatternModuleDestroy;
+    mlirPDLPatternModuleFromModule;
     mlirPDLRangeTypeGet;
     mlirPDLRangeTypeGetElementType;
     mlirPDLTypeTypeGet;
@@ -540,6 +555,9 @@
     mlirRegionTakeBody;
     mlirRegisterGPUPasses;
     mlirRegisterLinalgPasses;
+    mlirRewritePatternSetFromPDLPatternModule;
+    mlirSetGlobalDebugType;
+    mlirSetGlobalDebugTypes;
     mlirShapedTypeGetDimSize;
     mlirShapedTypeGetDynamicSize;
     mlirShapedTypeGetDynamicStrideOrOffset;
diff --git a/compiler/src/iree/compiler/API/api_exports.macos.lst b/compiler/src/iree/compiler/API/api_exports.macos.lst
index a6487e0..8a47e30 100644
--- a/compiler/src/iree/compiler/API/api_exports.macos.lst
+++ b/compiler/src/iree/compiler/API/api_exports.macos.lst
@@ -117,6 +117,7 @@
 _mlirAffineMulExprGet
 _mlirAffineSymbolExprGet
 _mlirAffineSymbolExprGetPosition
+_mlirApplyPatternsAndFoldGreedily
 _mlirArrayAttrGet
 _mlirArrayAttrGetElement
 _mlirArrayAttrGetNumElements
@@ -149,6 +150,7 @@
 _mlirAttributeIsAElements
 _mlirAttributeIsAFlatSymbolRef
 _mlirAttributeIsAFloat
+_mlirAttributeIsAGPUObjectAttr
 _mlirAttributeIsAInteger
 _mlirAttributeIsAIntegerSet
 _mlirAttributeIsALocation
@@ -172,6 +174,7 @@
 _mlirBlockDestroy
 _mlirBlockDetach
 _mlirBlockEqual
+_mlirBlockEraseArgument
 _mlirBlockGetArgument
 _mlirBlockGetFirstOperation
 _mlirBlockGetNextInRegion
@@ -340,12 +343,20 @@
 _mlirFloatAttrGetValueDouble
 _mlirFloatTF32TypeGetTypeID
 _mlirFloatTypeGetWidth
+_mlirFreezeRewritePattern
+_mlirFrozenRewritePatternSetDestroy
 _mlirFunctionTypeGet
 _mlirFunctionTypeGetInput
 _mlirFunctionTypeGetNumInputs
 _mlirFunctionTypeGetNumResults
 _mlirFunctionTypeGetResult
 _mlirFunctionTypeGetTypeID
+_mlirGPUObjectAttrGet
+_mlirGPUObjectAttrGetFormat
+_mlirGPUObjectAttrGetObject
+_mlirGPUObjectAttrGetProperties
+_mlirGPUObjectAttrGetTarget
+_mlirGPUObjectAttrHasProperties
 _mlirGetDialectHandle__iree_input__
 _mlirGetDialectHandle__transform__
 _mlirIREELinalgTransformRegisterPasses
@@ -390,6 +401,7 @@
 _mlirIntegerTypeIsUnsigned
 _mlirIntegerTypeSignedGet
 _mlirIntegerTypeUnsignedGet
+_mlirIsCurrentDebugType
 _mlirIsGlobalDebugEnabled
 _mlirLinalgFillBuiltinNamedOpRegion
 _mlirLlvmThreadPoolCreate
@@ -413,6 +425,7 @@
 _mlirMemRefTypeGetMemorySpace
 _mlirMemRefTypeGetStridesAndOffset
 _mlirMemRefTypeGetTypeID
+_mlirMergeSymbolsIntoFromClone
 _mlirModuleCreateEmpty
 _mlirModuleCreateParse
 _mlirModuleDestroy
@@ -507,6 +520,8 @@
 _mlirOperationWriteBytecodeWithConfig
 _mlirPDLAttributeTypeGet
 _mlirPDLOperationTypeGet
+_mlirPDLPatternModuleDestroy
+_mlirPDLPatternModuleFromModule
 _mlirPDLRangeTypeGet
 _mlirPDLRangeTypeGetElementType
 _mlirPDLTypeTypeGet
@@ -538,6 +553,9 @@
 _mlirRegionTakeBody
 _mlirRegisterGPUPasses
 _mlirRegisterLinalgPasses
+_mlirRewritePatternSetFromPDLPatternModule
+_mlirSetGlobalDebugType
+_mlirSetGlobalDebugTypes
 _mlirShapedTypeGetDimSize
 _mlirShapedTypeGetDynamicSize
 _mlirShapedTypeGetDynamicStrideOrOffset
diff --git a/compiler/src/iree/compiler/API/generate_exports.py b/compiler/src/iree/compiler/API/generate_exports.py
index 5200325..68f1c43 100755
--- a/compiler/src/iree/compiler/API/generate_exports.py
+++ b/compiler/src/iree/compiler/API/generate_exports.py
@@ -56,6 +56,7 @@
     "Interfaces.h",
     "IR.h",
     "Pass.h",
+    "Rewrite.h",
     "Support.h",
     "Transforms.h",
     "Dialect/GPU.h",
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
index b4d7bd7..3c52c16 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
@@ -147,7 +147,7 @@
       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
       (void)getTypeConverter()->convertSignatureArgs(
           newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
 
     Operation *newOp = rewriter.create(state);
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
index 4272150..13958d5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
@@ -172,7 +172,8 @@
                                            "argument type conversion failed");
       }
 
-      rewriter.applySignatureConversion(newRegion, result, typeConverter);
+      rewriter.applySignatureConversion(&newRegion->front(), result,
+                                        typeConverter);
     }
 
     Operation *newOp = rewriter.create(state);
diff --git a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
index 1388ea6..3bd452d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
@@ -4,18 +4,15 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include <numeric>
-
 #include "iree/compiler/Codegen/Common/PassDetail.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir::iree_compiler {
@@ -208,7 +205,7 @@
             VectorType::get({numElements}, iterType.getElementType());
         castTypes.push_back(shapeCastType);
         auto targetType =
-            VectorType::get({mlir::ceilDiv(totalBits, 32)},
+            VectorType::get({llvm::divideCeilSigned(totalBits, 32)},
                             rewriter.getIntegerType(
                                 std::min(static_cast<int64_t>(32), totalBits)));
         targetTypes.push_back(targetType);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
index d748a29..4610c54 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
@@ -5,7 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include <algorithm>
-#include <numeric>
 
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
@@ -13,6 +12,7 @@
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/MarkerUtils.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -20,7 +20,6 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-codegen-gpu-distribute-shared-memory-copy"
@@ -349,7 +348,7 @@
       ubCstOp.value() < 0 || stepCstOp.value() < 0)
     return 0;
   int64_t tripCount =
-      mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
+      llvm::divideCeil(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
   return tripCount;
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 2917e66..9edb2d9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -226,7 +226,8 @@
       }
       signatureConverter.addInputs(index, legalizedArgType.value());
     }
-    rewriter.applySignatureConversion(&modifiedOpRegion, signatureConverter);
+    rewriter.applySignatureConversion(&modifiedOpRegion.front(),
+                                      signatureConverter);
 
     // 6. Introduce scalar conversion operations to convert back to the
     // original scalar type.
@@ -368,7 +369,8 @@
     }
     signatureConverter.addInputs(0, legalizedArgType.value());
     signatureConverter.addInputs(1, legalizedArgType.value());
-    rewriter.applySignatureConversion(&modifiedOpRegion, signatureConverter);
+    rewriter.applySignatureConversion(&modifiedOpRegion.front(),
+                                      signatureConverter);
 
     {
       // Introduce scalar conversion operations to convert back to the original
@@ -444,7 +446,8 @@
       }
       signatureConverter.addInputs(index, legalizedArgType.value());
     }
-    rewriter.applySignatureConversion(&modifiedOpRegion, signatureConverter);
+    rewriter.applySignatureConversion(&modifiedOpRegion.front(),
+                                      signatureConverter);
 
     {
       // Introduce scalar conversion operations to convert back to the original
@@ -536,7 +539,8 @@
         doSignatureConversion |= argType != legalizedType;
       }
       if (doSignatureConversion) {
-        rewriter.applySignatureConversion(&newOpRegion, signatureConverter);
+        rewriter.applySignatureConversion(&newOpRegion.front(),
+                                          signatureConverter);
       }
     }
     rewriter.replaceOp(op, newOp->getResults());
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir
index 6914f1f..e11573a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir
@@ -180,30 +180,13 @@
   %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x4x32x8xf32> -> tensor<128x256xf32>
   return %0 : tensor<128x256xf32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 8)>
 // CHECK:      func.func @CKck_to_KC
 // CHECK-SAME:   %[[IN:[A-Za-z0-9]+]]:
 // CHECK-SAME:   %[[OUT:[A-Za-z0-9]+]]:
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C8:.+]] = arith.constant 8 : index
-// CHECK-DAG:    %[[C32:.+]] = arith.constant 32 : index
-// CHECK-DAG:    %[[C128:.+]] = arith.constant 128 : index
-// CHECK-DAG:    %[[C256:.+]] = arith.constant 256 : index
-// CHECK:        %[[RES0:.+]] = scf.for %[[K:.+]] = %[[C0]] to %[[C128]] step %[[C32]]
-// CHECK-SAME:     iter_args(%[[ITER0:.+]] = %[[OUT]])
-// CHECK:          %[[RES1:.+]] = scf.for %[[C:.+]] = %[[C0]] to %[[C256]] step %[[C8]]
-// CHECK-SAME:       iter_args(%[[ITER1:.+]] = %[[ITER0]])
-// CHECK-DAG:        %[[IN_K:.+]] = affine.apply #[[MAP0]](%[[K]])
-// CHECK-DAG:        %[[IN_C:.+]] = affine.apply #[[MAP1]](%[[C]])
-// CHECK:            %[[IN_SLICE:.+]] = tensor.extract_slice %[[IN]][%[[IN_C]], %[[IN_K]], 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:            %[[TILE:.+]] = tensor.extract_slice %[[IN_SLICE]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32>
-// CHECK:            %[[INSERT:.+]] = tensor.insert_slice %[[TILE]] into %[[ITER1]][%[[K]], %[[C]]] [32, 8] [1, 1]
-// CHECK:            scf.yield %[[INSERT]]
-// CHECK:          }
-// CHECK:          scf.yield %[[RES1]]
-// CHECK:        }
-// CHECK:        return %[[RES0]]
+// CHECK:        %[[TRANSP:.+]] = linalg.transpose ins(%[[IN]]
+// CHECK:        %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRANSP]] {{.+}} : tensor<4x32x32x8xf32> into tensor<128x256xf32>
+// CHECK:        %[[RES:.+]] = linalg.copy ins(%[[COLLAPSED]]
+// CHECK:        return %[[RES]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index ceb11bb..fb529d6 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -26,7 +27,6 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/MathExtras.h"
 
 #define DEBUG_TYPE "iree-codegen-gpu-transforms"
 
@@ -48,7 +48,7 @@
 
   int64_t tripCount = 1;
   for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
-    tripCount *= mlir::ceilDiv((ub - lb), step);
+    tripCount *= llvm::divideCeilSigned(ub - lb, step);
   }
   return tripCount;
 }
@@ -464,6 +464,7 @@
     }
     return llvm::to_vector(llvm::seq(static_cast<int64_t>(0), rank));
   };
+  Value laneId = newForallOp.getInductionVar(0);
 
   // LHS slice offsets.
   int64_t lhsOuterRank = mmaOp.getLhsOuterRank();
@@ -476,9 +477,8 @@
   SmallVector<int64_t> lhsPermutation = getOrInferPermutationOfRank(
       mmaOp.getLhsPermutation(), mmaOp.getLhsInnerShape().size());
   if (failed(mmaOp.getKind().populateOperandOffsetsSizesStrides(
-          rewriter, loc, IREE::GPU::MMAFragment::Lhs,
-          *newForallOp.getSingleInductionVar(), lhsPermutation, lhsOffsets,
-          lhsSizes, lhsStrides))) {
+          rewriter, loc, IREE::GPU::MMAFragment::Lhs, laneId, lhsPermutation,
+          lhsOffsets, lhsSizes, lhsStrides))) {
     return failure();
   }
   // Extract the rank-reduced slice of the lhs based on the expected inner
@@ -497,9 +497,8 @@
   SmallVector<int64_t> rhsPermutation = getOrInferPermutationOfRank(
       mmaOp.getRhsPermutation(), mmaOp.getRhsInnerShape().size());
   if (failed(mmaOp.getKind().populateOperandOffsetsSizesStrides(
-          rewriter, loc, IREE::GPU::MMAFragment::Rhs,
-          *newForallOp.getSingleInductionVar(), rhsPermutation, rhsOffsets,
-          rhsSizes, rhsStrides))) {
+          rewriter, loc, IREE::GPU::MMAFragment::Rhs, laneId, rhsPermutation,
+          rhsOffsets, rhsSizes, rhsStrides))) {
     return failure();
   }
   // Extract the rank-reduced slice of the rhs based on the expected inner
@@ -518,9 +517,8 @@
   SmallVector<int64_t> accPermutation = getOrInferPermutationOfRank(
       mmaOp.getAccPermutation(), mmaOp.getAccInnerShape().size());
   if (failed(mmaOp.getKind().populateOperandOffsetsSizesStrides(
-          rewriter, loc, IREE::GPU::MMAFragment::Acc,
-          *newForallOp.getSingleInductionVar(), accPermutation, accOffsets,
-          accSizes, accStrides))) {
+          rewriter, loc, IREE::GPU::MMAFragment::Acc, laneId, accPermutation,
+          accOffsets, accSizes, accStrides))) {
     return failure();
   }
   // Extract the rank-reduced slice of the accumulator based on the expected
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp
index d13358d..624361f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastAddressSpaceFunction.cpp
@@ -4,23 +4,14 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Codegen/LLVMGPU/KernelConfig.h"
 #include "iree/compiler/Codegen/LLVMGPU/PassDetail.h"
 #include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
-#include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "llvm/Support/raw_ostream.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/MathExtras.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
 
 #define DEBUG_TYPE "iree-llvmgpu-cast-address-space-function"
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp
index c4b98e5..544c73f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/PrefetchSharedMemoryCopy.cpp
@@ -12,6 +12,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,7 +25,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/MathExtras.h"
 
 #define DEBUG_TYPE "iree-codegen-llvmgpu-prefetch-shared-memory-copy"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -279,7 +279,7 @@
     ub = *ubCst;
     step = *stepCst;
 
-    int64_t numIters = mlir::ceilDiv(ub - lb, step);
+    int64_t numIters = llvm::divideCeilSigned(ub - lb, step);
     if (numIters <= 2)
       return failure();
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
index 709d635..c168da6 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
@@ -319,7 +319,7 @@
       signatureConverter.addInputs(index, arg.getType());
     }
     // Creates a new function with the update signature.
-    rewriter.applySignatureConversion(&funcOp.getFunctionBody(),
+    rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
                                       signatureConverter);
 
     // Creates a new function with the update signature.
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp
index bc992dd..fcba371 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/AbstractGemmLikeStrategy.cpp
@@ -11,8 +11,8 @@
 #include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Support/MathExtras.h"
 
 using namespace mlir;
 
@@ -134,7 +134,7 @@
     numWarps = SmallVector<int64_t>(clNumWarps.begin(), clNumWarps.end());
   } else {
     numWarps = numThreads;
-    numWarps[0] = mlir::ceilDiv(numWarps[0], getSubgroupSize());
+    numWarps[0] = llvm::divideCeil(numWarps[0], getSubgroupSize());
   }
   if (clUseAsyncCopies.getNumOccurrences())
     useAsyncCopies = clUseAsyncCopies;
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
index 259c86f..7d64630 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
@@ -10,8 +10,8 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Support/MathExtras.h"
 
 using namespace mlir;
 
@@ -126,7 +126,7 @@
       llvm::zip(copySizes, maybeCopyMapping->numThreads), [](auto &&pair) {
         int64_t size, numThreads;
         std::tie(size, numThreads) = pair;
-        return mlir::ceilDiv(size, numThreads);
+        return llvm::divideCeilSigned(size, numThreads);
       }));
   SmallVector<Attribute> allThreadMappings{linearId2(ctx), linearId1(ctx),
                                            linearId0(ctx)};
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.cpp
index 56194bf..8d1c2c3 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.cpp
@@ -4,12 +4,10 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h"
+#include "iree/compiler/Codegen/TransformStrategies/GPU/MappingInfo.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Support/MathExtras.h"
 
 using namespace mlir;
 
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
index 14385c5..99c41d0 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
@@ -13,9 +13,7 @@
 #include "iree/compiler/Codegen/TransformStrategies/GPU/Common.h"
 #include "iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h"
 #include "llvm/Support/raw_ostream.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/MathExtras.h"
 
 namespace mlir::iree_compiler::gpu {
 
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
index 40385df..770091b 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
@@ -23,6 +23,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
@@ -38,7 +39,6 @@
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/MathExtras.h"
 
 using namespace mlir;
 
@@ -321,7 +321,7 @@
   // the future.
   if (strategy.pipelineDepth * strategy.reductionTileSize > strategy.k()) {
     strategy.pipelineDepth =
-        mlir::floorDiv(strategy.k(), strategy.reductionTileSize);
+        llvm::divideFloorSigned(strategy.k(), strategy.reductionTileSize);
   }
 }
 
@@ -572,7 +572,7 @@
   // the future.
   if (strategy.pipelineDepth * strategy.reductionTileSize > strategy.k()) {
     strategy.pipelineDepth =
-        mlir::floorDiv(strategy.k(), strategy.reductionTileSize);
+        llvm::divideFloorSigned(strategy.k(), strategy.reductionTileSize);
   }
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
index 8acf32c..77700c0 100644
--- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
@@ -6,34 +6,22 @@
 
 #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
 
-#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/MathExtras.h"
 
 namespace mlir::iree_compiler::IREE::Encoding {
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index bb4f817..3ee627a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -6,7 +6,6 @@
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
@@ -14,14 +13,12 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -33,14 +30,12 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/MathExtras.h"
 
 #include <cstdint>
 #include <optional>
@@ -840,7 +835,8 @@
       resultShape[tiledDim] = ShapedType::kDynamic;
       continue;
     }
-    resultShape[tiledDim] = ceilDiv(resultShape[tiledDim], innerTileSizes[idx]);
+    resultShape[tiledDim] =
+        llvm::divideCeil(resultShape[tiledDim], innerTileSizes[idx]);
   }
 
   // Swap tile loops if outer_dims_perm is available.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index c74b8ca..5eef577 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -1608,7 +1608,7 @@
       signatureConverter.addInputs(arg.index(), convertedType);
     }
 
-    rewriter.applySignatureConversion(&funcOp.getFunctionBody(),
+    rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
                                       signatureConverter);
 
     // Creates a new function with the updated signature.
diff --git a/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp b/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp
index 92dd2fb..d59a0ba 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp
@@ -172,7 +172,7 @@
       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
       (void)getTypeConverter()->convertSignatureArgs(
           newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
 
     Operation *newOp = rewriter.create(state);
diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index 4fd3d76..f6162eb 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -468,7 +468,7 @@
       TypeConverter::SignatureConversion result(newRegion->getNumArguments());
       (void)getTypeConverter()->convertSignatureArgs(
           newRegion->getArgumentTypes(), result);
-      rewriter.applySignatureConversion(newRegion, result);
+      rewriter.applySignatureConversion(&newRegion->front(), result);
     }
     Operation *newOp = rewriter.create(state);
     rewriter.replaceOp(op, newOp->getResults());