Integrate LLVM at llvm/llvm-project@55f0b3370871

Updates LLVM usage to match
[55f0b3370871](https://github.com/llvm/llvm-project/commit/55f0b3370871)

PiperOrigin-RevId: 398424211
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 8b2bb0f..8342f29 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,7 +4,7 @@
 aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
 acd6f6f014c25e46363e718381e0b35205df2d83 third_party/libyaml
-f5b8f1247cd9d1b18b7b95f6f197d4d654597529 third_party/llvm-project
+55f0b337087136554122f942fea951a357bc4a49 third_party/llvm-project
 fc63b03f73f5781f145df3e261a4dc26638df192 third_party/mlir-hlo
 3f701faace7addc75d16dea8a6cd769fa5b3f260 third_party/musl
 4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 847d2cd..8e612ad 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -80,6 +80,7 @@
 # keep sorted
 TFLITE_FAILING = [
     "concat_test.py",
+    "control_flow_test.py",
     "einsum_dynamic_test.py",
     "einsum_static_test.py",
     "einsum_vector_test.py",
diff --git a/integrations/tensorflow/e2e/keras/layers/BUILD b/integrations/tensorflow/e2e/keras/layers/BUILD
index 57b41eb..e41cec5 100644
--- a/integrations/tensorflow/e2e/keras/layers/BUILD
+++ b/integrations/tensorflow/e2e/keras/layers/BUILD
@@ -127,6 +127,8 @@
             "AveragePooling3D",
             "Conv3DTranspose",
             "ConvLSTM2D",
+            "GRU",
+            "LSTM",
             "Softmax",
             "MaxPool3D",
             "ZeroPadding3D",
diff --git a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index bee9daa..9ca64f2 100644
--- a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -378,6 +378,9 @@
     if (castOp->getNumOperands() != 1) return failure();
 
     Value input = operands.front();
+    // We only want to handle cases where the cast op handles memref types.
+    if (!input.getType().isa<ShapedType>()) return failure();
+
     if (!isRankZeroOrOneMemRef(input.getType())) {
       return rewriter.notifyMatchFailure(
           castOp, "expected converted memref of rank <= 1");
@@ -514,9 +517,15 @@
           return isRankZeroOrOneMemRef(subspanOp.getType());
         });
     target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp loadOp) {
+      // TODO: Explicitly allow allocation ops for now. Need to properly
+      // flatten.
+      if (isa<memref::AllocOp>(loadOp.memref().getDefiningOp())) return true;
       return isRankZeroOrOneMemRef(loadOp.getMemRefType());
     });
     target.addDynamicallyLegalOp<memref::StoreOp>([](memref::StoreOp storeOp) {
+      // TODO: Explicitly allow allocation ops for now. Need to properly
+      // flatten.
+      if (isa<memref::AllocOp>(storeOp.memref().getDefiningOp())) return true;
       return isRankZeroOrOneMemRef(storeOp.getMemRefType());
     });
     target.addDynamicallyLegalOp<vector::TransferReadOp>(
@@ -531,8 +540,11 @@
         });
     target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
         [](UnrealizedConversionCastOp castOp) {
-          return castOp->getNumOperands() == 1 &&
-                 isRankZeroOrOneMemRef(castOp->getOperandTypes().front());
+          if (castOp->getNumOperands() != 1) return false;
+
+          Type inputType = castOp->getOperandTypes().front();
+          return !inputType.isa<ShapedType>() ||
+                 isRankZeroOrOneMemRef(inputType);
         });
 
     // Use partial conversion here so that we can ignore allocations created by
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
index 3247002..90e1164 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
@@ -17,6 +17,7 @@
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -25,6 +26,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#define DEBUG_TYPE "spirv-vectorize-load-store"
+
 constexpr int kMaxVectorNumBits = 128;
 constexpr int kMaxVectorNumElements = 4;
 
@@ -61,24 +64,25 @@
   return {};
 }
 
-// Calculates the vector size we want to use based on the memref uses.
-static unsigned calculateMemrefVecSize(SmallVectorImpl<Operation *> &uses) {
-  unsigned minSize = kMaxVectorNumBits;
+// Calculates the vector bit count we want to use based on the memref uses.
+static unsigned calculateMemRefVectorNumBits(
+    SmallVectorImpl<Operation *> &uses) {
+  unsigned minBits = kMaxVectorNumBits;
   for (Operation *op : uses) {
     auto transferOp = dyn_cast<VectorTransferOpInterface>(op);
     if (!transferOp) return 0;
     Optional<unsigned> transferSize = getBitWidth(transferOp.getVectorType());
     if (!transferSize) return 0;
-    minSize = std::min(minSize, *transferSize);
+    minBits = std::min(minBits, *transferSize);
   }
-  return minSize;
+  return minBits;
 }
 
-/// If the memref is vectorizable return the vector size we want to use,
+/// If the memref is vectorizable return the vector bit count we want to use,
 /// otherwise return 0. If it returns a value greater than 0 it also returns the
 /// memref uses.
-static unsigned isMemRefAndVectorizable(Value value,
-                                        SmallVectorImpl<Operation *> &uses) {
+static unsigned isMemRefVectorizable(Value value,
+                                     SmallVectorImpl<Operation *> &uses) {
   auto memrefType = value.getType().dyn_cast<MemRefType>();
 
   // Require scalar element type
@@ -93,9 +97,17 @@
   // buffer.
   if (memrefType.getShape().back() % 2 != 0) return 0;
 
-  if (kMaxVectorNumBits % memrefType.getElementTypeBitWidth() != 0) return 0;
+  unsigned elementNumBits = memrefType.getElementTypeBitWidth();
+  if (kMaxVectorNumBits % elementNumBits != 0) return 0;
 
-  if (getUsesIfAllTransferOp(value, uses)) return calculateMemrefVecSize(uses);
+  if (getUsesIfAllTransferOp(value, uses)) {
+    unsigned vectorBits = calculateMemRefVectorNumBits(uses);
+    unsigned vectorSize = vectorBits / elementNumBits;
+    // Again make sure we don't have vectors of odd numbers.
+    if (vectorSize % 2 != 0) return 0;
+    return vectorBits;
+  }
+
   return 0;
 }
 
@@ -142,7 +154,7 @@
 
 void MemRefUsageAnalysis::analyzeMemRefValue(Value value) {
   SmallVector<Operation *, 4> vectorUses;
-  if (unsigned vectorSize = isMemRefAndVectorizable(value, vectorUses)) {
+  if (unsigned vectorSize = isMemRefVectorizable(value, vectorUses)) {
     valueToVectorBitsMap.insert(std::make_pair(value, vectorSize));
     transferOps.insert(vectorUses.begin(), vectorUses.end());
   }
@@ -346,7 +358,10 @@
            !ShapedType::isDynamic(memrefType.getShape().back()));
 
     auto vecMemRef = getVectorizedMemRefType(rewriter, bindingOp.getResult());
-    if (!vecMemRef) return failure();
+    if (!vecMemRef) {
+      return rewriter.notifyMatchFailure(bindingOp,
+                                         "cannot get vectorized memref type");
+    }
     rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>(
         bindingOp, *vecMemRef, bindingOp.binding(), bindingOp.byte_offset(),
         bindingOp.byte_length(), bindingOp.dynamic_dims());
diff --git a/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir b/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
index 3bf7e59..4bfc6e7 100644
--- a/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/vectorize_load_store.mlir
@@ -216,7 +216,6 @@
   %c3 = constant 3: index
   %f0 = constant 0.0 : f32
   %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<20xf32>
-  %1 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<f32>
   %2 = hal.interface.binding.subspan @io::@ret1[%c0] : memref<20xf32>
   // CHECK-DAG: %[[INDEX0:.+]] = constant 3 : index
   // CHECK-DAG: %[[INDEX1:.+]] = constant 4 : index
diff --git a/third_party/llvm-project b/third_party/llvm-project
index f5b8f12..55f0b33 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit f5b8f1247cd9d1b18b7b95f6f197d4d654597529
+Subproject commit 55f0b337087136554122f942fea951a357bc4a49