Integrate llvm-project and bump dependencies. (#12562)

* llvm-project: e510d0bda0876c4baa3a270dca39b95da7ec6d9e
* mlir-hlo: e86610442f58b889a57bf814d75c4b50c769c2a3
* tensorflow: 67ba341c869e30ee4a89e040cd875d12b9bc666e

Cherry-picked from LLVM:
```
commit 80074d5fc0ab3f165865b15f5bf55ffac0917bcd (HEAD -> integrate-3-8-2023, fork/integrate-3-8-2023)
Author: Matthias Springer <me@m-sp.org>
Date:   Fri Mar 10 11:25:15 2023 +0100

    [mlir][NFC] reifyResultShapes: Add extra error checking
    
    This change adds a new helper function `mlir::reifyResultShapes` that calls the corresponding interface method and also checks the result produced by the implementation when running in debug mode. Bugs due to incorrect interface implementations can be difficult to debug.
    
    This helper function also reduces the amount of code needed at call sites: the cast to `ReifyRankedShapedTypeOpInterface` is done in the helper function.
    
    Differential Revision: https://reviews.llvm.org/D145777

commit 32b15f601de173e9511f470f7423108d3154e582
Author: Matthias Springer <me@m-sp.org>
Date:   Fri Mar 10 11:24:43 2023 +0100

    [mlir][tensor/linalg] Fix bug in reifyResultShapes
    
    `reifyResultShapes` should return an IntegerAttr if and only if the corresponding dimension is static.
    
    Differential Revision: https://reviews.llvm.org/D145702

commit 894555cd6adf2e0faffe713373a266650b40bb4e
Author: David Green <david.green@arm.com>
Date:   Wed Mar 8 12:48:21 2023 +0000

    [AArch64] Fix load-insert-zero patterns with i8 and negative offsets.
    
    These should have been using the LDURBi instructions where the offset is
    negative, as reported from the reproducer in D144086.
```

Created a new commit on iree-mlir-hlo fork:


https://github.com/iree-org/iree-mhlo-fork/commit/b14e9d9b06255e4476f5698e3bfc531dec793ded
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
index a2847a1..4c985a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorPad.cpp
@@ -85,7 +85,7 @@
   // Slice out the original shape from the padded result to pass on to
   // consumers. The original linalg op is used to provide the dims for the reify
   // result shapes.
-  SmallVector<SmallVector<Value>> reifiedResultShapes;
+  SmallVector<SmallVector<OpFoldResult>> reifiedResultShapes;
   if (failed(cast<ReifyRankedShapedTypeOpInterface>(linalgOp.getOperation())
                  .reifyResultShapes(rewriter, reifiedResultShapes))) {
     return failure();
@@ -98,8 +98,7 @@
     int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> sizes;
-    for (Value v : reifiedResultShapes[resultNumber])
-      sizes.push_back(getAsOpFoldResult(v));
+    for (OpFoldResult v : reifiedResultShapes[resultNumber]) sizes.push_back(v);
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
     paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
         loc, paddedResult, offsets, sizes, strides));
@@ -148,7 +147,7 @@
 
   // Slice out the original shape from the padded result to pass on to
   // consumers.
-  SmallVector<SmallVector<Value>> reifiedResultShapes;
+  SmallVector<SmallVector<OpFoldResult>> reifiedResultShapes;
   if (failed(op.reifyResultShapes(rewriter, reifiedResultShapes))) {
     return failure();
   }
@@ -156,8 +155,7 @@
   Value paddedSubviewResults;
   int64_t rank = paddedOp.getDestRank();
   SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-  SmallVector<OpFoldResult> sizes =
-      getAsOpFoldResult(ValueRange(reifiedResultShapes[0]));
+  SmallVector<OpFoldResult> sizes = reifiedResultShapes[0];
   SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
   paddedSubviewResults = rewriter.create<tensor::ExtractSliceOp>(
       loc, paddedOp.getResult(), offsets, sizes, strides);
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp b/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp
index d77440b..bfa6bb0 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp
@@ -106,7 +106,7 @@
   Block *block = &region.front();
   Operation *terminator = block->getTerminator();
   ValueRange results = terminator->getOperands();
-  rewriter.mergeBlockBefore(block, op, blockArgs);
+  rewriter.inlineBlockBefore(block, op, blockArgs);
   rewriter.replaceOp(op, results);
   rewriter.eraseOp(terminator);
 }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp
index 73d1a90..dde4e05 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp
@@ -10,6 +10,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -177,7 +178,7 @@
     if (reshapeOp->template getParentOfType<Flow::DispatchWorkgroupsOp>()) {
       return failure();
     }
-    SmallVector<SmallVector<Value>> outputShape;
+    SmallVector<SmallVector<OpFoldResult>> outputShape;
     ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
         cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
     if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
@@ -185,10 +186,11 @@
       return failure();
     }
     SmallVector<Value> outputDynamicShapes;
-    for (auto [resultShape, outputShape] : llvm::zip_equal(
+    for (auto [resultShape, outputShp] : llvm::zip_equal(
              reshapeOp.getResultType().getShape(), outputShape[0])) {
       if (resultShape != ShapedType::kDynamic) continue;
-      outputDynamicShapes.push_back(outputShape);
+      outputDynamicShapes.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, reshapeOp.getLoc(), outputShp));
     }
     rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>(
         reshapeOp, reshapeOp.getResultType(), reshapeOp.getSrc(),
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index f88b40b..cc288ce 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -495,7 +495,7 @@
 
 LogicalResult DispatchTieShapeOp::reifyResultShapes(
     OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  SmallVector<Value> shape;
+  SmallVector<OpFoldResult> shape;
   unsigned dynamicIdx = 0;
   auto tensorType =
       getResult().getType().cast<IREE::Flow::DispatchTensorType>();
@@ -503,7 +503,7 @@
     if (dim == ShapedType::kDynamic) {
       shape.push_back(getDynamicDims()[dynamicIdx++]);
     } else {
-      shape.push_back(b.create<arith::ConstantIndexOp>(getLoc(), dim));
+      shape.push_back(b.getIndexAttr(dim));
     }
   }
   reifiedReturnShapes.push_back(shape);
@@ -635,7 +635,7 @@
 LogicalResult DispatchTensorLoadOp::reifyResultShapes(
     OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   auto mixedSizes = getMixedSizes();
-  SmallVector<Value> shape;
+  SmallVector<OpFoldResult> shape;
   if (!mixedSizes.empty()) {
     // Slicing out a tile; return the size sliced.
     shape.reserve(mixedSizes.size());
@@ -644,8 +644,7 @@
       if (droppedDims.test(mixedSize.index())) {
         continue;
       }
-      shape.push_back(
-          getValueOrCreateConstantIndexOp(b, getLoc(), mixedSize.value()));
+      shape.push_back(mixedSize.value());
     }
   } else {
     // Result size matches the source size (no slicing).
@@ -654,7 +653,7 @@
       if (dim == ShapedType::kDynamic) {
         shape.push_back(getSourceDims()[dynamicIdx++]);
       } else {
-        shape.push_back(b.create<arith::ConstantIndexOp>(getLoc(), dim));
+        shape.push_back(b.getIndexAttr(dim));
       }
     }
   }
@@ -1380,14 +1379,14 @@
 
 LogicalResult TensorTieShapeOp::reifyResultShapes(
     OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  SmallVector<Value> shape;
+  SmallVector<OpFoldResult> shape;
   unsigned dynamicIdx = 0;
   auto tensorType = getResult().getType().cast<RankedTensorType>();
   for (int64_t dim : tensorType.getShape()) {
     if (dim == ShapedType::kDynamic) {
       shape.push_back(getDynamicDims()[dynamicIdx++]);
     } else {
-      shape.push_back(b.create<arith::ConstantIndexOp>(getLoc(), dim));
+      shape.push_back(b.getIndexAttr(dim));
     }
   }
   reifiedReturnShapes.push_back(shape);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 2b8ba21..c56348c 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -69,7 +69,7 @@
   LogicalResult status = sliceOp.reifyResultShapes(builder, resultDims);
   (void)status;
   assert(succeeded(status) && "reifyResultShapes failed");
-  return llvm::to_vector(llvm::map_range(resultDims[0], [&](Value v) {
+  return llvm::to_vector(llvm::map_range(resultDims[0], [&](OpFoldResult v) {
     return Range{zero, v, one};
   }));
 }
@@ -158,7 +158,7 @@
     if (failed(reifyShapeOp.reifyResultShapes(b, dims))) return failure();
     for (int64_t i = 0; i < shapedType.getRank(); ++i)
       if (shapedType.isDynamicDim(i))
-        dynamicDims.push_back(dims[opResult.getResultNumber()][i]);
+        dynamicDims.push_back(dims[opResult.getResultNumber()][i].get<Value>());
     return success();
   }
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index b785220..89a17f9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -998,8 +998,8 @@
     // Begin/end recording and inline the execution region between them.
     auto endOp =
         rewriter.create<IREE::HAL::CommandBufferFinalizeOp>(loc, commandBuffer);
-    rewriter.mergeBlockBefore(&executeOp.getBody().front(), endOp,
-                              adaptor.getResourceOperands());
+    rewriter.inlineBlockBefore(&executeOp.getBody().front(), endOp,
+                               adaptor.getResourceOperands());
 
     // Gather wait/signal fence, which are optional.
     Value waitFence =
@@ -1032,7 +1032,7 @@
                                 OpBuilder::atBlockBegin(&bodyBlock));
 
     // Inline the serial execution region.
-    rewriter.mergeBlockBefore(&serialOp.getBody().front(), serialOp);
+    rewriter.inlineBlockBefore(&serialOp.getBody().front(), serialOp);
     rewriter.eraseOp(serialOp);
     return success();
   }
@@ -1046,7 +1046,7 @@
       ConversionPatternRewriter &rewriter) const override {
     // Inline the concurrent execution region.
     // TODO(benvanik): split barriers (event set/wait) when nesting.
-    rewriter.mergeBlockBefore(&concurrentOp.getBody().front(), concurrentOp);
+    rewriter.inlineBlockBefore(&concurrentOp.getBody().front(), concurrentOp);
     rewriter.eraseOp(concurrentOp);
     return success();
   }
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 6fae72b..7cce235 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -2829,8 +2829,9 @@
     }
 
     // Merge the successor into the current block and erase the branch.
-    rewriter.mergeBlocks(succ, opParent, op.getOperands());
+    SmallVector<Value> operands(op.getOperands());
     rewriter.eraseOp(op);
+    rewriter.mergeBlocks(succ, opParent, operands);
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
index f8268d6..e4a8e1f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -251,7 +251,7 @@
   std::pair<LogicalResult, Value> createConst(Location loc, Attribute value,
                                               OpBuilder &builder) {
     if (auto integerAttr = value.dyn_cast<IntegerAttr>()) {
-      if (integerAttr.getValue().isNullValue()) {
+      if (integerAttr.getValue().isZero()) {
         // Globals are zero-initialized by default.
         return {success(), {}};
       }
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp
index a457146..d7b3899 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Conversion/StreamToHALInline/Patterns.cpp
@@ -456,8 +456,8 @@
       IREE::Stream::CmdExecuteOp executeOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     // Inline the serial execution region.
-    rewriter.mergeBlockBefore(&executeOp.getBody().front(), executeOp,
-                              adaptor.getResourceOperands());
+    rewriter.inlineBlockBefore(&executeOp.getBody().front(), executeOp,
+                               adaptor.getResourceOperands());
     // Immediately resolve the timepoint.
     auto resolvedTimepoint =
         rewriter.create<arith::ConstantIntOp>(executeOp.getLoc(), 0, 64)
@@ -474,7 +474,7 @@
       IREE::Stream::CmdSerialOp serialOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     // Inline the serial execution region.
-    rewriter.mergeBlockBefore(&serialOp.getBody().front(), serialOp);
+    rewriter.inlineBlockBefore(&serialOp.getBody().front(), serialOp);
     rewriter.eraseOp(serialOp);
     return success();
   }
@@ -487,7 +487,7 @@
       IREE::Stream::CmdConcurrentOp concurrentOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     // Inline the concurrent execution region.
-    rewriter.mergeBlockBefore(&concurrentOp.getBody().front(), concurrentOp);
+    rewriter.inlineBlockBefore(&concurrentOp.getBody().front(), concurrentOp);
     rewriter.eraseOp(concurrentOp);
     return success();
   }
diff --git a/integrations/tensorflow/WORKSPACE b/integrations/tensorflow/WORKSPACE
index 8ff2d5b..70ec789 100644
--- a/integrations/tensorflow/WORKSPACE
+++ b/integrations/tensorflow/WORKSPACE
@@ -7,7 +7,7 @@
 
 load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
 
-TENSORFLOW_COMMIT = "eece4dba1fc65f7977085581852b0d6e6d42f04e"
+TENSORFLOW_COMMIT = "67ba341c869e30ee4a89e040cd875d12b9bc666e"
 
 git_repository(
     name = "org_tensorflow",
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index ec972be..8f5605d 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -41,10 +41,10 @@
         "@llvm-project//mlir:Transforms",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
+        "@org_tensorflow//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf",
         "@org_tensorflow//tensorflow/compiler/mlir/tosa:tf_passes",
         "@org_tensorflow//tensorflow/compiler/mlir/tosa:tf_tfl_passes",
         "@org_tensorflow//tensorflow/compiler/mlir/tosa:tfl_passes",
-        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
         "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo",
         "@stablehlo//:chlo_ops",
     ],
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index fed0926..c226280 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -56,8 +56,8 @@
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
+        "@org_tensorflow//tensorflow/compiler/mlir/tf2xla:xla_legalize_tf",
         "@org_tensorflow//tensorflow/compiler/mlir/tosa:tf_passes",
-        "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
         "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo",
         "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:all_passes",
         "@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:chlo_legalize_to_hlo",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp
index 3d1133a..a3cafca 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp
@@ -22,7 +22,7 @@
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
-#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
 
 namespace mlir {
 namespace iree_integrations {
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/LowerGlobalTensors.cpp b/integrations/tensorflow/iree_tf_compiler/TF/LowerGlobalTensors.cpp
index 54a90ac..6aa2e76 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/LowerGlobalTensors.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/LowerGlobalTensors.cpp
@@ -154,8 +154,8 @@
       return;
     }
     auto global = globalBuilder.create<mlir::ml_program::GlobalOp>(
-        globalTensor.getLoc(), name, globalTensor.getValue().getType(),
-        globalTensor.getIsMutable(), globalTensor.getValue(), nullptr);
+        globalTensor.getLoc(), name, globalTensor.getValue()->getType(),
+        globalTensor.getIsMutable(), *globalTensor.getValue(), nullptr);
     global.setPrivate();
     symbolRefMap[globalTensor] = global;
   }
diff --git a/integrations/tensorflow/test/iree_tfl_tests/README.md b/integrations/tensorflow/test/iree_tfl_tests/README.md
index 9d74f6d..87873fc 100644
--- a/integrations/tensorflow/test/iree_tfl_tests/README.md
+++ b/integrations/tensorflow/test/iree_tfl_tests/README.md
@@ -6,13 +6,19 @@
 
 |       Model        |      Status        |
 | ------------------ | ------------------ |
-person_detect        | PASS ✓
-east_text_detector   | PASS ✓
-vulkan_posenet_i8    | FAIL ✗
-cartoon_gan          | PASS ✓
-mnasnet              | PASS ✓
-gpt2                 | PASS ✓
-llvmcpu_posenet_i8   | PASS ✓
 mobilenet_v3         | PASS ✓
+llvmcpu_resnet_50_int8 | PASS ✓
+vulkan_mobilebert_tf2_quant | FAIL ✗
+cartoon_gan          | PASS ✓
+llvmcpu_mobilebert_tf2_quant | PASS ✓
+mnasnet              | PASS ✓
+person_detect        | PASS ✓
+vulkan_posenet_i8    | FAIL ✗
+east_text_detector   | PASS ✓
+gpt2                 | PASS ✓
 llvmcpu_mobilenet_v1 | PASS ✓
-vulkan_mobilenet_v1  | FAIL ✗
+llvmcpu_mobilenet_v3-large_uint8 | PASS ✓
+vulkan_mobilenet_v1  | PASS ✓
+vulkan_mobilenet_v3-large_uint8 | FAIL ✗
+llvmcpu_posenet_i8   | FAIL ✗
+vulkan_resnet_50_int8 | FAIL ✗
\ No newline at end of file
diff --git a/integrations/tensorflow/test/iree_tfl_tests/llvmcpu_posenet_i8.run b/integrations/tensorflow/test/iree_tfl_tests/llvmcpu_posenet_i8.run
index fb78e1c..ca4d4d7 100644
--- a/integrations/tensorflow/test/iree_tfl_tests/llvmcpu_posenet_i8.run
+++ b/integrations/tensorflow/test/iree_tfl_tests/llvmcpu_posenet_i8.run
@@ -1,2 +1,3 @@
 # REQUIRES: llvmcpu
 # RUN: %PYTHON -m iree_tfl_tests.posenet_i8_test --target_backend=llvmcpu --artifacts_dir=%t
+# XFAIL: *
diff --git a/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py b/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py
index 82c44f3..3da8969 100755
--- a/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py
+++ b/integrations/tensorflow/test/iree_tfl_tests/update_tflite_model_documentation.py
@@ -1,3 +1,4 @@
+#!/bin/python3
 # Copyright 2022 The IREE Authors
 #
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
index b78ba0a..493963a 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
@@ -37,12 +37,12 @@
 
 template <typename Ty, typename DimOpTy>
 static void getDimValues(OpBuilder &b, Location loc, Value v, Ty t,
-                         SmallVector<Value> &dimVals) {
+                         SmallVector<OpFoldResult> &dimVals) {
   for (auto dim : llvm::enumerate(t.getShape())) {
     if (ShapedType::isDynamic(dim.value())) {
-      dimVals.push_back(b.create<DimOpTy>(loc, v, dim.index()));
+      dimVals.push_back(b.create<DimOpTy>(loc, v, dim.index()).getResult());
     } else {
-      dimVals.push_back(b.create<arith::ConstantIndexOp>(loc, dim.value()));
+      dimVals.push_back(b.getIndexAttr(dim.value()));
     }
   }
 }
@@ -51,7 +51,7 @@
     OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   Operation *op = getOperation();
   for (auto output : getOutputs()) {
-    SmallVector<Value> dims;
+    SmallVector<OpFoldResult> dims;
     Type outputType = output.getType();
     if (auto rankedTensorType = outputType.dyn_cast<RankedTensorType>()) {
       getDimValues<RankedTensorType, tensor::DimOp>(b, op->getLoc(), output,
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 915085d..1417860 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1937,14 +1937,17 @@
   // over the tile dimensions.
   for (auto dataTileDim :
        llvm::seq<unsigned>(getInputRank(), getOutputRank() - 1)) {
-    Value ub = outputShape[0][dataTileDim];
+    Value ub = getValueOrCreateConstantIndexOp(builder, loc,
+                                               outputShape[0][dataTileDim]);
     scf::ForOp loop = builder.create<scf::ForOp>(loc, zero, ub, one);
     builder.setInsertionPointToStart(loop.getBody());
     ivVec.push_back(loop.getInductionVar());
   }
   // The body of the innermost loops does the actual data movement.
-  builder.create<scf::ForOp>(loc, zero, outputShape[0].back(), one,
-                             ValueRange{},
+  builder.create<scf::ForOp>(loc, zero,
+                             getValueOrCreateConstantIndexOp(
+                                 builder, loc, outputShape[0].back()),
+                             one, ValueRange{},
                              [&](OpBuilder &bodyBuilder, Location bodyLoc,
                                  Value iv, ValueRange regionIterArgs) {
                                ivVec.push_back(iv);
@@ -2681,8 +2684,7 @@
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPoint(getOperation());
   reifiedReturnShapes.resize(1);
-  reifiedReturnShapes[0] = getValueOrCreateConstantIndexOp(
-      builder, getLoc(), getDims(builder, getLoc(), getSource()));
+  reifiedReturnShapes[0] = getDims(builder, getLoc(), getSource());
   return success();
 }
 
@@ -2720,8 +2722,7 @@
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPoint(getOperation());
   reifiedReturnShapes.resize(1);
-  reifiedReturnShapes[0] = getValueOrCreateConstantIndexOp(
-      builder, getLoc(), getDims(builder, getLoc(), getSource()));
+  reifiedReturnShapes[0] = getDims(builder, getLoc(), getSource());
   return success();
 }
 
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
index a46512f..f55266e 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ForeachThreadToSequentialFor.cpp
@@ -64,8 +64,8 @@
   bool hasTerminator =
       !body->empty() && body->back().hasTrait<OpTrait::IsTerminator>();
   if (hasTerminator) {
-    rewriter.mergeBlockBefore(&forallOp.getRegion().front(),
-                              body->getTerminator(), bbArgsTranslated);
+    rewriter.inlineBlockBefore(&forallOp.getRegion().front(),
+                               body->getTerminator(), bbArgsTranslated);
   } else {
     rewriter.mergeBlocks(&forallOp.getRegion().front(), body, bbArgsTranslated);
   }
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 9512b6b..80074d5 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 9512b6b7c40e3249ec4db347956d63f4d84c8fc8
+Subproject commit 80074d5fc0ab3f165865b15f5bf55ffac0917bcd
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 4ceea1d..b14e9d9 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 4ceea1d9ae3c3f87071e2096c91f38da75d22242
+Subproject commit b14e9d9b06255e4476f5698e3bfc531dec793ded