LLVM integrate integrate-llvm-20231107 (#15470)

Co-authored-by: Quinn Dawkins <quinn@nod-labs.com>
Co-authored-by: Kunwar Grover <groverkss@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 601b98c..1aae604 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -274,6 +274,7 @@
         "@llvm-project//mlir:MemRefTransformOps",
         "@llvm-project//mlir:SCFTransformOps",
         "@llvm-project//mlir:TensorTransformOps",
+        "@llvm-project//mlir:TransformLoopExtension",
         "@llvm-project//mlir:VectorTransformOps",
     ],
 )
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 9e018f6..6e7488d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -227,6 +227,7 @@
     MLIRTensorTransforms
     MLIRTransformDialect
     MLIRTransformDialectTransforms
+    MLIRTransformLoopExtension
     MLIRTransforms
     MLIRVectorDialect
     MLIRVectorTransformOps
diff --git a/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp
index ec663c6..dd1faa4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CommonDialectRegistration.cpp
@@ -39,12 +39,15 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -100,7 +103,10 @@
   linalg::registerTransformDialectExtension(registry);
   memref::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
+  tensor::registerSubsetOpInterfaceExternalModels(registry);
   tensor::registerTransformDialectExtension(registry);
+  transform::registerLoopExtension(registry);
+  vector::registerSubsetOpInterfaceExternalModels(registry);
   vector::registerTransformDialectExtension(registry);
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/HoistRedundantVectorTransfers.cpp b/compiler/src/iree/compiler/Codegen/Common/HoistRedundantVectorTransfers.cpp
index 4a6b37f..085a9ed 100644
--- a/compiler/src/iree/compiler/Codegen/Common/HoistRedundantVectorTransfers.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/HoistRedundantVectorTransfers.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -31,8 +32,12 @@
 void HoistRedundantVectorTransfersPass::runOnOperation() {
   auto funcOp = getOperation();
   linalg::hoistRedundantVectorTransfers(funcOp);
-  linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
   IRRewriter rewriter(funcOp->getContext());
+  // Hoist redundant vector transfers on tensors.
+  // TODO: walking in some reverse / inside-out order would be more efficient
+  // and would capture more cases.
+  funcOp.walk(
+      [&](scf::ForOp forOp) { hoistLoopInvariantSubsets(rewriter, forOp); });
   vector::transferOpflowOpt(rewriter, funcOp);
 }
 } // namespace
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
index 25d0dc2..ae8e0ec 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BUILD.bazel
@@ -68,7 +68,10 @@
         "@llvm-project//mlir:SCFTransformOps",
         "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:TensorTransformOps",
+        "@llvm-project//mlir:TensorTransforms",
+        "@llvm-project//mlir:TransformLoopExtension",
         "@llvm-project//mlir:VectorTransformOps",
+        "@llvm-project//mlir:VectorTransforms",
     ],
 )
 
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
index c166bb4..8acf2f3 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/CMakeLists.txt
@@ -39,7 +39,10 @@
     MLIRSCFTransformOps
     MLIRTensorDialect
     MLIRTensorTransformOps
+    MLIRTensorTransforms
+    MLIRTransformLoopExtension
     MLIRVectorTransformOps
+    MLIRVectorTransforms
     iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
     iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
     iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
index 8d22768..8a7dc63 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/Interfaces.cpp
@@ -32,7 +32,10 @@
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -62,8 +65,11 @@
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   scf::registerTransformDialectExtension(registry);
   scf::registerValueBoundsOpInterfaceExternalModels(registry);
+  tensor::registerSubsetOpInterfaceExternalModels(registry);
   tensor::registerTransformDialectExtension(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+  transform::registerLoopExtension(registry);
+  vector::registerSubsetOpInterfaceExternalModels(registry);
   vector::registerTransformDialectExtension(registry);
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp
index 324ec17..3bb8bb9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp
@@ -348,7 +348,6 @@
   if (structType.isInitialized())
     return structType;
 
-  auto uint32Type = IntegerType::get(context, 32);
   auto opaquePtrType = LLVM::LLVMPointerType::get(context);
   SmallVector<Type> fieldTypes;
 
@@ -358,11 +357,9 @@
   // iree_hal_executable_import_thunk_v0_t import_thunk;
   // const iree_hal_executable_import_v0_t* import_funcs;
   // const void** import_contexts;
-  auto importThunkType = LLVM::LLVMFunctionType::get(
-      uint32Type, {opaquePtrType, opaquePtrType, opaquePtrType, opaquePtrType});
-  fieldTypes.push_back(LLVM::LLVMPointerType::get(importThunkType));
-  fieldTypes.push_back(LLVM::LLVMPointerType::get(opaquePtrType));
-  fieldTypes.push_back(LLVM::LLVMPointerType::get(opaquePtrType));
+  fieldTypes.push_back(LLVM::LLVMPointerType::get(context));
+  fieldTypes.push_back(LLVM::LLVMPointerType::get(context));
+  fieldTypes.push_back(LLVM::LLVMPointerType::get(context));
 
   // iree_hal_processor_v0_t processor;
   fieldTypes.push_back(processorType);
@@ -960,12 +957,12 @@
                  di.getPtrOf(di.getConstOf(di.getEnvironmentV0T())), builder);
   Value processorPtrValue = builder.create<LLVM::GEPOp>(
       loc, LLVM::LLVMPointerType::get(context),
-      LLVM::LLVMPointerType::get(environmentType), environmentPtrValue,
+      LLVM::LLVMPointerType::get(context), environmentPtrValue,
       LLVM::GEPArg(int32_t(EnvironmentField::processor)),
       /*inbounds=*/true);
   Value processorDataPtrValue = builder.create<LLVM::GEPOp>(
       loc, LLVM::LLVMPointerType::get(context),
-      LLVM::LLVMPointerType::get(processorType), processorPtrValue,
+      LLVM::LLVMPointerType::get(context), processorPtrValue,
       LLVM::GEPArg(int32_t(ProcessorField::data)),
       /*inbounds=*/true);
   Value updatedProcessorData =
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir
index 7050437..37b3d3e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test.mlir
@@ -41,7 +41,7 @@
 // CHECK-COUNT-2:        vector.transfer_read
 // CHECK-COUNT-4:        vector.contract
 //         CHECK:      scf.yield %{{.*}} : vector<1x4x4xf32>
-//         CHECK:    scf.yield %{{.*}} : vector<4x4xf32>
+//         CHECK:    scf.yield %{{.*}} : vector<1x4x4xf32>
 //         CHECK:    vector.transfer_write {{.*}} : vector<4x4xf32>, memref<1x112x112x64xf32, #hal.descriptor_type<storage_buffer>>
 
 // -----
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index 8d4e1d5..8112cd0 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -57,7 +57,7 @@
 // CHECK-LABEL: func.func @nhwc_conv_static_shape_f32()
 
 // No vector transfer write ops generated for the linalg.fill op: initial values are forwarded to loops.
-// CHECK-NOT: vector.transfer
+// CHECK-NOT: vector.transfer_write
 
 // Check tiling loop along filter height/width and input channel
 //      CHECK: scf.for %{{.*}} = %c0 to %c3 step %c1
@@ -439,14 +439,14 @@
 // CHECK-LABEL: func.func @nchw_conv_static_shape_f32()
 
 // No vector transfer write ops generated for the linalg.fill op: initial values are forwarded to loops.
-// CHECK-NOT: vector.transfer
+// CHECK-NOT: vector.transfer_write
 
 // Check tiling loop along input channel and filter height/width
 // TODO: enable vector hoisting
 //      CHECK: scf.for %{{.*}} = %c0 to %c1280 step %c4
-// CHECK-SAME:     -> (tensor<2x8x1x4xf32>)
+// CHECK-SAME:     -> (vector<4xf32>{{(, vector<4xf32>)+}})
 //      CHECK:   scf.for %{{.*}} = %c0 to %c3 step %c1
-// CHECK-SAME:       -> (tensor<2x8x1x4xf32>)
+// CHECK-SAME:       -> (vector<4xf32>{{(, vector<4xf32>)+}})
 //      CHECK:     scf.for %{{.*}} = %c0 to %c3 step %c1
 // CHECK-SAME:         -> (vector<4xf32>{{(, vector<4xf32>)+}})
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
index 0cd0efd..c22bc5f 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir
@@ -185,23 +185,23 @@
 //          CHECK:     %[[ISS1:.+]] = vector.insert_strided_slice %{{.+}}, %[[ISS0]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
 //          CHECK:     %[[ISS2:.+]] = vector.insert_strided_slice %{{.+}}, %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
 //          CHECK:     %[[ISS3:.+]] = vector.insert_strided_slice %{{.+}}, %[[ISS2]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
-//          CHECK:     scf.yield %arg5, %[[ISS3]], %[[ISS1]] : tensor<2x8xf16>, vector<8xf16>, vector<8xf16>
+//          CHECK:     scf.yield %arg5, %[[ISS1]], %[[ISS3]] : tensor<2x8xf16>, vector<8xf16>, vector<8xf16>
 //          CHECK:   }
 // CHECK:   %[[X0:.+]] = vector.transfer_read %[[X]]{{.+}} : tensor<2x8xf16>, vector<8xf16>
 // CHECK:   %[[X1:.+]] = vector.transfer_read %[[X]]{{.+}} : tensor<2x8xf16>, vector<8xf16>
-// CHECK:   %[[LHS0:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:   %[[LHS0:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[RHS0:.+]] = vector.extract_strided_slice %[[X0]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[DIV0:.+]] = arith.divf %[[LHS0]], %[[RHS0]]
 // CHECK:   %[[ISS0:.+]] = vector.insert_strided_slice %[[DIV0]], %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
-// CHECK:   %[[LHS1:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:   %[[LHS1:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[RHS1:.+]] = vector.extract_strided_slice %[[X0]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[DIV1:.+]] = arith.divf %[[LHS1]], %[[RHS1]]
 // CHECK:   %[[ISS1:.+]] = vector.insert_strided_slice %[[DIV1]], %[[ISS0]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
-// CHECK:   %[[LHS2:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:   %[[LHS2:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[RHS2:.+]] = vector.extract_strided_slice %[[X1]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[DIV2:.+]] = arith.divf %[[LHS2]], %[[RHS2]]
 // CHECK:   %[[ISS2:.+]] = vector.insert_strided_slice %[[DIV2]], %[[ZERO]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
-// CHECK:   %[[LHS3:.+]] = vector.extract_strided_slice %[[FOR]]#1 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK:   %[[LHS3:.+]] = vector.extract_strided_slice %[[FOR]]#2 {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[RHS3:.+]] = vector.extract_strided_slice %[[X1]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
 // CHECK:   %[[DIV3:.+]] = arith.divf %[[LHS3]], %[[RHS3]]
 // CHECK:   %[[ISS3:.+]] = vector.insert_strided_slice %[[DIV3]], %[[ISS2]] {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/BUILD.bazel
index a1597e1..991e899 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/BUILD.bazel
@@ -60,6 +60,7 @@
         "@llvm-project//mlir:SCFTransformOps",
         "@llvm-project//mlir:TensorTransforms",
         "@llvm-project//mlir:TensorTransformOps",
+        "@llvm-project//mlir:TransformLoopExtension",
         "@llvm-project//mlir:VectorTransforms",
         "@llvm-project//mlir:VectorTransformOps",
         # Other Stuff
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/CMakeLists.txt
index 2e5e1df..c96141b 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/CMakeLists.txt
@@ -54,6 +54,7 @@
     MLIRTensorTransformOps
     MLIRTensorTransforms
     MLIRTransformDialect
+    MLIRTransformLoopExtension
     MLIRVectorDialect
     MLIRVectorTransformOps
     MLIRVectorTransforms
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
index 4a122f0..080601b 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/Common/Common.cpp
@@ -14,9 +14,11 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -34,7 +36,7 @@
 using iree_compiler::IREE::transform_dialect::
     IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp;
 using transform::FuseIntoContainingOp;
-using transform::HoistRedundantTensorSubsetsOp;
+using transform::HoistLoopInvariantSubsetsOp;
 using transform::MatchOp;
 using transform::MemRefEraseDeadAllocAndStoresOp;
 using transform::MergeHandlesOp;
@@ -343,7 +345,9 @@
 
 /// Hoist redundant subet ops.
 void mlir::iree_compiler::buildHoisting(ImplicitLocOpBuilder &b, Value funcH) {
-  b.create<HoistRedundantTensorSubsetsOp>(funcH);
+  Value loops =
+      b.create<transform::MatchOp>(funcH, scf::ForOp::getOperationName());
+  b.create<HoistLoopInvariantSubsetsOp>(loops);
 }
 
 /// Bufferize and drop HAL descriptor from memref ops.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LibraryBuilder.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LibraryBuilder.cpp
index 35863d4..623bead 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LibraryBuilder.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LibraryBuilder.cpp
@@ -44,7 +44,7 @@
     return existingType;
   }
   auto *i32Type = llvm::IntegerType::getInt32Ty(context);
-  auto *i8PtrType = llvm::IntegerType::getInt8PtrTy(context);
+  auto *i8PtrType = llvm::PointerType::getUnqual(context);
   auto *type = llvm::StructType::create(context,
                                         {
                                             i32Type,
@@ -144,7 +144,7 @@
     return existingType;
   }
   auto *i32Type = llvm::IntegerType::getInt32Ty(context);
-  auto *i8PtrType = llvm::IntegerType::getInt8PtrTy(context);
+  auto *i8PtrType = llvm::PointerType::getUnqual(context);
   auto *type = llvm::StructType::create(context,
                                         {
                                             i32Type,
@@ -172,7 +172,7 @@
   auto *i32Type = llvm::IntegerType::getInt32Ty(context);
   auto *dispatchFunctionType = makeDispatchFunctionType(context);
   auto *dispatchAttrsType = makeDispatchAttrsType(context);
-  auto *i8PtrType = llvm::IntegerType::getInt8PtrTy(context);
+  auto *i8PtrType = llvm::PointerType::getUnqual(context);
   auto *srcLocType = makeSrcLocType(context);
   auto *type = llvm::StructType::create(
       context,
@@ -220,7 +220,7 @@
     return existingType;
   }
   auto *i32Type = llvm::IntegerType::getInt32Ty(context);
-  auto *i8PtrType = llvm::IntegerType::getInt8PtrTy(context);
+  auto *i8PtrType = llvm::PointerType::getUnqual(context);
   auto *type = llvm::StructType::create(context,
                                         {
                                             i32Type,
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 2556bee..14e7846 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 2556bee5cb108e85c24d50d141cb405106fc4e4d
+Subproject commit 14e7846d6e2c5c99c52ba3882e59bfb021a5f0fa