Integrate LLVM at llvm/llvm-project@274f12a44c60

Updates LLVM usage to match
[274f12a44c60](https://github.com/llvm/llvm-project/commit/274f12a44c60)

PiperOrigin-RevId: 410826501
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index f470115..7cb55a4 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,7 +4,7 @@
 aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
 88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
 acd6f6f014c25e46363e718381e0b35205df2d83 third_party/libyaml
-8909dc5ebe8ad39f1743131eb70df402d796acab third_party/llvm-project
+274f12a44c606ecd20152f3e63c4f186793d9a8c third_party/llvm-project
 af14e1ded33c3164d4418c5d234b5b346b6d017c third_party/mlir-hlo
 3f701faace7addc75d16dea8a6cd769fa5b3f260 third_party/musl
 4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
diff --git a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index 4a79629..4fd88e3 100644
--- a/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -237,8 +237,10 @@
                             << *funcOp << "\n");
 
     // 5. Perform bufferization.
-    linalg::comprehensive_bufferize::BufferizationState state(
-        aliasInfo, *allocationFn, bvm);
+    linalg::comprehensive_bufferize::BufferizationState state(aliasInfo,
+                                                              *allocationFn);
+    // Merge `bvm` into `state`.
+    for (auto it : bvm.getValueMap()) state.mapValue(it.first, it.second);
     for (Operation *op : ops)
       if (failed(linalg::comprehensive_bufferize::bufferizeOp(
               op, state, /*bufferizedFunctionTypes=*/nullptr)))
diff --git a/iree/compiler/Codegen/Common/VectorizeConv.cpp b/iree/compiler/Codegen/Common/VectorizeConv.cpp
index 11b6736..164662a 100644
--- a/iree/compiler/Codegen/Common/VectorizeConv.cpp
+++ b/iree/compiler/Codegen/Common/VectorizeConv.cpp
@@ -216,12 +216,12 @@
   }
 };
 
-/// Vectorizes linalg.depthwise_conv2D_nhw for a single GPU
-/// invocation. Therefore, the linalg.depthwise_conv2D_nhw op
+/// Vectorizes linalg.depthwise_conv_2d_nhwc_hwc for a single GPU
+/// invocation. Therefore, the linalg.depthwise_conv_2d_nhwc_hwc op
 /// should have a very specific form; other patterns are expected to tile and
 /// distribute larger convolutions into this form for a single GPU invocation.
 ///
-/// The linalg.depthwise_conv2D_nhw op should follow:
+/// The linalg.depthwise_conv_2d_nhwc_hwc op should follow:
 /// - Filter: HfWfC format
 /// - Input : NHiWiC format
 /// - Output: NHoWoC format
@@ -237,10 +237,10 @@
 /// Channel is requried to be a multiple of 4 so that we can process them with
 /// load4/store4, which is native to GPUs.
 struct VectorizeLinalgDepthwiseConv
-    : OpRewritePattern<linalg::DepthwiseConv2DNhwOp> {
+    : OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwOp convOp,
+  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
                                 PatternRewriter &rewriter) const override {
     LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n");
 
diff --git a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
index fdf8333..ca70a3c 100644
--- a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
+++ b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
@@ -114,7 +114,7 @@
 }
 
 //      CHECK: #[[SIZE_MAP:.+]] = affine_map<()[s0, s1, s2, s3] -> (((s0 * s1) * s2) * s3)>
-//      CHECK: #[[OFFSET_MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7] -> (((s4 * s5 + s6) * s2 + s3) * s0 + s1 + s7 floordiv 4)>
+//      CHECK: #[[OFFSET_MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7] -> (s1 + (s3 + (s6 + s4 * s5) * s2) * s0 + s7 floordiv 4)>
 //      CHECK: func @store_subspan_with_all_dynamic_dim
 // CHECK-SAME: (%[[VALUE:.+]]: f32, %[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
 //      CHECK:   %[[C0:.+]] = arith.constant 0 : index
@@ -142,7 +142,7 @@
 }
 
 //      CHECK: #[[SIZE_MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) * 32)>
-//      CHECK: #[[OFFSET_MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((s3 * 4 + s4) * s1 + s2) * 8 + s0 + s5 floordiv 4)>
+//      CHECK: #[[OFFSET_MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0 + s2 * 8 + ((s3 * 4 + s4) * s1) * 8 + s5 floordiv 4)>
 //      CHECK: func @store_subspan_with_mixed_dynamic_dim
 // CHECK-SAME: (%[[VALUE:.+]]: f32, %[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
 //      CHECK:   %[[C0:.+]] = arith.constant 0 : index
@@ -210,7 +210,7 @@
 
 
 //      CHECK: #[[SIZE_MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
-//      CHECK: #[[INDEX_MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> ((s2 * s3 + s4) * s0 + s1)>
+//      CHECK: #[[INDEX_MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 + (s4 + s2 * s3) * s0)>
 // CHECK: func @load_store_alloca_dynamic
 // CHECK-SAME: (%[[VAL:.+]]: f32, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index)
 //      CHECK:   %[[SIZE:.+]] = affine.apply #[[SIZE_MAP]]()[%[[DIM0]], %[[DIM1]], %[[DIM2]]]
diff --git a/iree/compiler/Codegen/Common/test/vectorize_linalg_conv.mlir b/iree/compiler/Codegen/Common/test/vectorize_linalg_conv.mlir
index d733619..5b27db2 100644
--- a/iree/compiler/Codegen/Common/test/vectorize_linalg_conv.mlir
+++ b/iree/compiler/Codegen/Common/test/vectorize_linalg_conv.mlir
@@ -115,7 +115,7 @@
 // -----
 
 func @vectorize_depthwise_conv(%input: memref<1x3x3x8xf32>, %filter: memref<1x1x8xf32>, %output: memref<1x2x2x8xf32>) {
-  linalg.depthwise_conv2D_nhw {dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%input, %filter : memref<1x3x3x8xf32>, memref<1x1x8xf32>) outs(%output : memref<1x2x2x8xf32>)
+  linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%input, %filter : memref<1x3x3x8xf32>, memref<1x1x8xf32>) outs(%output : memref<1x2x2x8xf32>)
   return
 }
 
@@ -180,8 +180,8 @@
 
 // CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_height
 func @do_not_vectorize_depthwise_conv_with_non_1_filter_height(%input: memref<1x2x3x4xf32>, %filter: memref<2x1x4xf32>, %output: memref<1x1x2x4xf32>) {
-  // CHECK: linalg.depthwise_conv2D_nhw
-  linalg.depthwise_conv2D_nhw {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+  linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : memref<1x2x3x4xf32>, memref<2x1x4xf32>)
     outs(%output : memref<1x1x2x4xf32>)
   return
@@ -191,8 +191,8 @@
 
 // CHECK-LABEL: func @do_not_vectorize_depthwise_conv_with_non_1_filter_width
 func @do_not_vectorize_depthwise_conv_with_non_1_filter_width(%input: memref<1x1x4x4xf32>, %filter: memref<1x2x4xf32>, %output: memref<1x1x2x4xf32>) {
-  // CHECK: linalg.depthwise_conv2D_nhw
-  linalg.depthwise_conv2D_nhw {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+  linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : memref<1x1x4x4xf32>, memref<1x2x4xf32>)
     outs(%output : memref<1x1x2x4xf32>)
   return
@@ -238,7 +238,7 @@
 // -----
 
 func @vectorize_depthwise_conv(%input: tensor<1x3x3x8xf32>, %filter: tensor<1x1x8xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
-  %0 = linalg.depthwise_conv2D_nhw {dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<2> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter : tensor<1x3x3x8xf32>, tensor<1x1x8xf32>)
     outs(%init : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
   return %0 : tensor<1x2x2x8xf32>
diff --git a/iree/compiler/Codegen/LLVMCPU/BUILD b/iree/compiler/Codegen/LLVMCPU/BUILD
index 97c33b6..953be7c 100644
--- a/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -44,6 +44,7 @@
         "@llvm-project//mlir:AffineToStandardTransforms",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithmeticToLLVM",
+        "@llvm-project//mlir:ArithmeticTransforms",
         "@llvm-project//mlir:CFGTransforms",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 3ba3f6e..8ae0b0d 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -11,6 +11,7 @@
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
@@ -175,6 +176,7 @@
   passManager.addPass(createFoldTensorExtractOpPass());
 
   // (HAL, IREE, Linalg, STD) -> LLVM
+  passManager.addNestedPass<FuncOp>(arith::createArithmeticExpandOpsPass());
   passManager.addNestedPass<FuncOp>(createStdExpandOpsPass());
   passManager.addPass(createConvertToLLVMPass());
 
diff --git a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
index f48dc87..fa20497 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/materialize_launch_configuration.mlir
@@ -837,7 +837,7 @@
               %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 96, s0)>(%arg2)[%workgroup_size_x]
               %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
               %24 = linalg.fill(%cst, %23) {__internal_linalg_transform__ = "workgroup"} : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv2D_nhw {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
               flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x80x80x96xf32>
             }
           }
@@ -862,7 +862,7 @@
 //  CHECK-DAG:     %[[D1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]]
 //  CHECK-DAG:     %[[D2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]]
 //      CHECK:     hal.return %[[D0]], %[[D1]], %[[D2]]
-//      CHECK:     linalg.depthwise_conv2D_nhw
+//      CHECK:     linalg.depthwise_conv_2d_nhwc_hwc
 //  CHECK-NOT:       lowering.config
 
 // -----
@@ -1030,7 +1030,7 @@
               %16 = affine.min affine_map<(d0) -> (-d0 + 7, 2)>(%arg0)
               %17 = linalg.init_tensor [1, %16, %c7, %c64] : tensor<1x?x?x?xf32>
               %18 = linalg.fill(%cst, %17) {__internal_linalg_transform__ = "workgroup"} : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32> 
-              %19 = linalg.depthwise_conv2D_nhw {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%12, %14 : tensor<1x?x?x?xf32>, tensor<5x5x?xf32>) outs(%18 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+              %19 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%12, %14 : tensor<1x?x?x?xf32>, tensor<5x5x?xf32>) outs(%18 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
               flow.dispatch.tensor.store %19, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %15, %c7, %c64], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x7x7x576xf32>
             }
           }
diff --git a/iree/compiler/Codegen/LLVMCPU/test/tile_fuse_and_vectorize.mlir b/iree/compiler/Codegen/LLVMCPU/test/tile_fuse_and_vectorize.mlir
index 0bcc089..367fca7 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/tile_fuse_and_vectorize.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/tile_fuse_and_vectorize.mlir
@@ -197,8 +197,8 @@
         %26 = arith.divf %25, %23 : f32
         %27 = arith.addf %26, %arg6 : f32
         %28 = arith.addf %27, %cst_1 : f32
-        %29 = minf %28, %cst_2 : f32
-        %30 = maxf %29, %cst : f32
+        %29 = arith.minf %28, %cst_2 : f32
+        %30 = arith.maxf %29, %cst : f32
         %31 = arith.mulf %30, %cst_3 : f32
         %32 = arith.mulf %31, %27 : f32
         linalg.yield %32 : f32
diff --git a/iree/compiler/Codegen/LLVMGPU/BUILD b/iree/compiler/Codegen/LLVMGPU/BUILD
index bedf42a..46eb194 100644
--- a/iree/compiler/Codegen/LLVMGPU/BUILD
+++ b/iree/compiler/Codegen/LLVMGPU/BUILD
@@ -47,6 +47,7 @@
         "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:AffineToStandard",
         "@llvm-project//mlir:ArithmeticToLLVM",
+        "@llvm-project//mlir:ArithmeticTransforms",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUToNVVMTransforms",
         "@llvm-project//mlir:GPUToROCDLTransforms",
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 8425178..e8acf40 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -80,8 +80,8 @@
   bool fusedOpSupported = true;
   entryPoint.walk([&fusedOpSupported](linalg::GenericOp linalgOp) {
     for (Operation &fusedOp : linalgOp.getOps()) {
-      if (!isa<arith::AddFOp, arith::MulFOp, MaxFOp, MinFOp, linalg::YieldOp,
-               arith::DivFOp>(fusedOp)) {
+      if (!isa<arith::AddFOp, arith::MulFOp, arith::MaxFOp, arith::MinFOp,
+          linalg::YieldOp, arith::DivFOp>(fusedOp)) {
         fusedOpSupported = false;
         break;
       }
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
index cd8e857..08bab9c 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
@@ -168,28 +168,29 @@
           .setDistributionOptions(invocationDistributionOptions);
 
   MLIRContext *context = patterns.getContext();
-  patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
-                  linalg::LinalgTilingPattern<linalg::FillOp>,
-                  linalg::LinalgTilingPattern<linalg::CopyOp>,
-                  linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
-                  linalg::LinalgTilingPattern<linalg::GenericOp>,
-                  linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>,
-                  IREE::LinalgExt::TiledOpInterfaceTilingPattern>(
-      context, tilingOptions,
-      linalg::LinalgTransformationFilter(
-          {Identifier::get(getWorkgroupKTiledMarker(), context),
-           Identifier::get(getWorkgroupMemoryMarker(), context)},
-          Identifier::get(getVectorizeMarker(), context))
-          .addFilter([](Operation *op) {
-            // FFT doesn't support second level of tiling yet.
-            return success(!isa<IREE::LinalgExt::FftOp>(op));
-          })
-          .setMatchByDefault());
+  patterns
+      .insert<linalg::LinalgTilingPattern<linalg::MatmulOp>,
+              linalg::LinalgTilingPattern<linalg::FillOp>,
+              linalg::LinalgTilingPattern<linalg::CopyOp>,
+              linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
+              linalg::LinalgTilingPattern<linalg::GenericOp>,
+              linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
+              linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcHwcOp>,
+              linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcHwcmOp>,
+              linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
+              linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
+              linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>,
+              IREE::LinalgExt::TiledOpInterfaceTilingPattern>(
+          context, tilingOptions,
+          linalg::LinalgTransformationFilter(
+              {Identifier::get(getWorkgroupKTiledMarker(), context),
+               Identifier::get(getWorkgroupMemoryMarker(), context)},
+              Identifier::get(getVectorizeMarker(), context))
+              .addFilter([](Operation *op) {
+                // FFT doesn't support second level of tiling yet.
+                return success(!isa<IREE::LinalgExt::FftOp>(op));
+              })
+              .setMatchByDefault());
 }
 
 static LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst) {
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index ccadf62..f4f1a9b 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
@@ -149,6 +150,7 @@
   pm.addNestedPass<FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<FuncOp>(createCSEPass());
 
+  pm.addNestedPass<FuncOp>(arith::createArithmeticExpandOpsPass());
   pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
   pm.addPass(createLowerAffinePass());
 
diff --git a/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
index 877cac7..1635c3f 100644
--- a/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
@@ -38,7 +38,7 @@
         return setConvOpConfig(op, subgroupSize,
                                /*bestTilingFactor=*/32);
       })
-      .Case<linalg::DepthwiseConv2DNhwOp>([subgroupSize](auto op) {
+      .Case<linalg::DepthwiseConv2DNhwcHwcOp>([subgroupSize](auto op) {
         return setConvOpConfig(op, subgroupSize,
                                /*bestTilingFactor=*/16);
       })
diff --git a/iree/compiler/Codegen/SPIRV/BUILD b/iree/compiler/Codegen/SPIRV/BUILD
index ec773b2..4a6e3dc 100644
--- a/iree/compiler/Codegen/SPIRV/BUILD
+++ b/iree/compiler/Codegen/SPIRV/BUILD
@@ -53,6 +53,7 @@
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithmeticDialect",
         "@llvm-project//mlir:ArithmeticToSPIRV",
+        "@llvm-project//mlir:ArithmeticTransforms",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUToSPIRV",
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 1413a0d..dff4182 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -137,7 +137,7 @@
   // Tiling along reduction dimensions
   if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp)) {
     tileSizes.push_back({0, 0, 0, 0, 1, 1, 4});
-  } else if (isa<linalg::DepthwiseConv2DNhwOp>(linalgOp)) {
+  } else if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(linalgOp)) {
     tileSizes.push_back({0, 0, 0, 0, 1, 1});
   } else {
     return success();
@@ -551,7 +551,7 @@
         // If unsuccessful, try to tile and distribute.
         return setDefaultOpConfig(limits, op);
       })
-      .Case<linalg::Conv2DNhwcHwcfOp, linalg::DepthwiseConv2DNhwOp>(
+      .Case<linalg::Conv2DNhwcHwcfOp, linalg::DepthwiseConv2DNhwcHwcOp>(
           [limits](auto op) {
             // Try to tile and vectorize first. It's common to see 32 threads
             // per subgroup for GPUs.
diff --git a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
index 9577d43..ced4698 100644
--- a/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
@@ -44,7 +44,7 @@
         return setConvOpConfig(op, subgroupSize,
                                /*bestTilingFactor=*/16);
       })
-      .Case<linalg::DepthwiseConv2DNhwOp>([subgroupSize](auto op) {
+      .Case<linalg::DepthwiseConv2DNhwcHwcOp>([subgroupSize](auto op) {
         return setConvOpConfig(op, subgroupSize,
                                /*bestTilingFactor=*/16);
       })
diff --git a/iree/compiler/Codegen/SPIRV/Passes.cpp b/iree/compiler/Codegen/SPIRV/Passes.cpp
index b3b3306..60448c4 100644
--- a/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -20,6 +20,7 @@
 #include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -66,6 +67,7 @@
   // subview ops.
   pm.addPass(memref::createFoldSubViewOpsPass());
   pm.addNestedPass<FuncOp>(Shape::createFoldDimOverShapeCarryingOpPass());
+  pm.addNestedPass<FuncOp>(arith::createArithmeticExpandOpsPass());
   pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
index f647bcf..613f9a5 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVTileAndDistribute.cpp
@@ -96,23 +96,24 @@
 
   SmallVector<StringRef, 2> matchMarkers = {getWorkgroupMemoryMarker()};
 
-  patterns.insert<linalg::LinalgTilingPattern<linalg::CopyOp>,
-                  linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>,
-                  linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcOp>,
-                  linalg::LinalgTilingPattern<linalg::FillOp>,
-                  linalg::LinalgTilingPattern<linalg::GenericOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
-                  linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>>(
-      context, tilingOptions,
-      getLinalgMatchAndReplaceMarker(matchMarkers, getVectorizeMarker(),
-                                     context)
-          .setMatchByDefault());
+  patterns
+      .insert<linalg::LinalgTilingPattern<linalg::CopyOp>,
+              linalg::LinalgTilingPattern<linalg::Conv1DNwcWcfOp>,
+              linalg::LinalgTilingPattern<linalg::Conv3DNdhwcDhwcfOp>,
+              linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcHwcmOp>,
+              linalg::LinalgTilingPattern<linalg::FillOp>,
+              linalg::LinalgTilingPattern<linalg::GenericOp>,
+              linalg::LinalgTilingPattern<linalg::PoolingNhwcMaxOp>,
+              linalg::LinalgTilingPattern<linalg::PoolingNhwcMinOp>,
+              linalg::LinalgTilingPattern<linalg::PoolingNhwcSumOp>>(
+          context, tilingOptions,
+          getLinalgMatchAndReplaceMarker(matchMarkers, getVectorizeMarker(),
+                                         context)
+              .setMatchByDefault());
 
   patterns.insert<linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
                   linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>,
+                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcHwcOp>,
                   linalg::LinalgTilingPattern<linalg::MatmulOp>>(
       context, tilingOptions,
       getLinalgMatchAndReplaceMarker(matchMarkers, getTileReductionMarker(),
@@ -143,7 +144,7 @@
 
   patterns.insert<linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
                   linalg::LinalgTilingPattern<linalg::Conv2DNhwcHwcfOp>,
-                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwOp>,
+                  linalg::LinalgTilingPattern<linalg::DepthwiseConv2DNhwcHwcOp>,
                   linalg::LinalgTilingPattern<linalg::MatmulOp>>(
       context, tilingOptions, marker);
 }
diff --git a/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir b/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
index 6649ee4..54b4aee 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_adreno_conv.mlir
@@ -337,7 +337,7 @@
               %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 144, s0)>(%arg2)[%workgroup_size_x]
               %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
               %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv2D_nhw {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
               flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
             }
           }
@@ -367,7 +367,7 @@
 // CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
 
 //      CHECK: func @dwconv_28x28x144()
-//      CHECK:   linalg.depthwise_conv2D_nhw
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
@@ -431,7 +431,7 @@
               %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 8, s0)>(%arg2)[%workgroup_size_x]
               %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
               %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv2D_nhw {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
               flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x4x4x8xf32>
             }
           }
@@ -460,5 +460,5 @@
 // CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
 
 //      CHECK: func @dwconv_4x4x8()
-//      CHECK:   linalg.depthwise_conv2D_nhw
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
diff --git a/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir b/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
index 102b5ef..960ec5f 100644
--- a/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/config_mali_conv.mlir
@@ -335,7 +335,7 @@
               %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 144, s0)>(%arg2)[%workgroup_size_x]
               %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
               %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
-              %25 = linalg.depthwise_conv2D_nhw {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
               flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x28x28x144xf32>
             }
           }
@@ -365,7 +365,7 @@
 // CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]
 
 //      CHECK: func @dwconv_28x28x144()
-//      CHECK:   linalg.depthwise_conv2D_nhw
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
 // -----
@@ -430,7 +430,7 @@
               %22 = affine.min affine_map<(d0)[s0] -> (-d0 + 8, s0)>(%arg2)[%workgroup_size_x]
               %23 = linalg.init_tensor [1, %20, %21, %22] : tensor<1x?x?x?xf32>
               %24 = linalg.fill(%cst, %23) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32> 
-              %25 = linalg.depthwise_conv2D_nhw {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
+              %25 = linalg.depthwise_conv_2d_nhwc_hwc {__internal_linalg_transform__ = "workgroup", dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%14, %16 : tensor<1x?x?x?xf32>, tensor<3x3x?xf32>) outs(%24 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
               flow.dispatch.tensor.store %25, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %17, %18, %19], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x1x2x8xf32>
             }
           }
@@ -459,6 +459,6 @@
 // CHECK-NEXT:   hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z]]
 
 //      CHECK: func @dwconv_1x2x8()
-//      CHECK:   linalg.depthwise_conv2D_nhw
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME:     lowering.config = #[[CONFIG]]
 
diff --git a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
index fdbe7b6..a5443a1 100644
--- a/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
+++ b/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_conv.mlir
@@ -160,7 +160,7 @@
               %19 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 56)>(%arg1)[%workgroup_size_y]
               %20 = memref.subview %2[0, %arg0, %arg1, %arg2] [1, %18, %19, %15] [1, 1, 1, 1] : memref<1x56x56x96xf32> to memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>
               linalg.fill(%cst, %20) {lowering.config = #config} : f32, memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>
-              linalg.depthwise_conv2D_nhw {lowering.config = #config, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%16, %17 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>>, memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>>) outs(%20 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>)
+              linalg.depthwise_conv_2d_nhwc_hwc {lowering.config = #config, dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%16, %17 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 1225824 + s0 + d1 * 10848 + d2 * 96 + d3)>>, memref<3x3x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 288 + s0 + d1 * 96 + d2)>>) outs(%20 : memref<1x?x?x?xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 301056 + s0 + d1 * 5376 + d2 * 96 + d3)>>)
             }
           }
         }
@@ -180,7 +180,7 @@
 // For linalg.fill
 // CHECK: vector.transfer_write
 
-// For linalg.depthwise_conv2D_nhw
+// For linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK: vector.transfer_read
 
 // check tiling loop along filter height/width and input channel
@@ -194,5 +194,5 @@
 
 // CHECK-COUNT-2: scf.yield
 
-// For linalg.depthwise_conv2D_nhw
+// For linalg.depthwise_conv_2d_nhwc_hwc
 // CHECK: vector.transfer_write
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
index 92b7f63..4fca00e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertConv2DToImg2ColPass.cpp
@@ -170,11 +170,11 @@
 // by transposing both input filter so channles are outer most the computation
 // is a batched matrix-vector product.
 class DepthwiseConv2DNHWCHWCImg2ColMatmulConversion
-    : public OpRewritePattern<linalg::DepthwiseConv2DNhwOp> {
+    : public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
  public:
-  using OpRewritePattern<linalg::DepthwiseConv2DNhwOp>::OpRewritePattern;
+  using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwOp convOp,
+  LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType inputTensorType =
         convOp.getInputOperand(0)->get().getType().dyn_cast<RankedTensorType>();
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir b/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir
index 5f65869..268e87a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/conv2d_to_img2col.mlir
@@ -34,7 +34,7 @@
 // -----
 
 func @depthwise_conv_hwc_114x16x3(%input: tensor<1x114x114x16xf32>, %filter: tensor<3x3x16xf32>, %output: tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32> {
-    %0 = linalg.depthwise_conv2D_nhw {
+    %0 = linalg.depthwise_conv_2d_nhwc_hwc {
       dilations = dense<1> : tensor<2xi64>,
       strides = dense<1> : tensor<2xi64>
     } ins(%input, %filter : tensor<1x114x114x16xf32>, tensor<3x3x16xf32>) outs(%output : tensor<1x112x112x16xf32>) -> tensor<1x112x112x16xf32>
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index c949988..daa32cc 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -423,14 +423,14 @@
   %cst = arith.constant 0.000000e+00 : f32
   %1 = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
   %2 = linalg.fill(%cst, %1) : f32, tensor<1x56x56x96xf32> -> tensor<1x56x56x96xf32>
-  %4 = linalg.depthwise_conv2D_nhw {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) outs(%2 : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+  %4 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>) outs(%2 : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
   return %4 : tensor<1x56x56x96xf32>
 }
 
 // CHECK-LABEL: func @depthwise_conv2d
 // CHECK: scf.for
 // CHECK: scf.for
-// CHECK: linalg.depthwise_conv2D_nhw
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwc
 
 // -----
 
diff --git a/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD b/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
index b5f7480..0656d11 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
+++ b/iree/compiler/Dialect/Modules/VMVX/Transforms/BUILD
@@ -39,6 +39,7 @@
         "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:AffineToStandardTransforms",
         "@llvm-project//mlir:AffineTransforms",
+        "@llvm-project//mlir:ArithmeticTransforms",
         "@llvm-project//mlir:CFGTransforms",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgInterfaces",
diff --git a/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp b/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
index 869948d..bc05716 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Modules/VMVX/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
@@ -57,6 +58,7 @@
   nestedModulePM.addNestedPass<FuncOp>(createCSEPass());
   nestedModulePM.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
   nestedModulePM.addNestedPass<FuncOp>(createCanonicalizerPass());
+  nestedModulePM.addNestedPass<FuncOp>(arith::createArithmeticExpandOpsPass());
   nestedModulePM.addNestedPass<FuncOp>(createStdExpandOpsPass());
 
   // Handle tensor-type constants.
diff --git a/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index 0280563..11442ab 100644
--- a/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -85,7 +85,7 @@
 // Example:
 //  %min = util.range.min %0, %1 : index
 // ->
-//  %min = minui %0, %1 : index
+//  %min = arith.minui %0, %1 : index
 template <typename RangeOpT, typename StdOpT>
 struct ExpandSimpleRangeOp : public OpRewritePattern<RangeOpT> {
   using OpRewritePattern<RangeOpT>::OpRewritePattern;
@@ -147,13 +147,13 @@
 
 void RangeMinOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
-  results.insert<ExpandSimpleRangeOp<RangeMinOp, mlir::MinUIOp>>(context);
+  results.insert<ExpandSimpleRangeOp<RangeMinOp, arith::MinUIOp>>(context);
   results.insert<SimplifyUniformRangeOp<RangeMinOp, INT64_MAX, xmin>>(context);
 }
 
 void RangeMaxOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
-  results.insert<ExpandSimpleRangeOp<RangeMaxOp, mlir::MaxUIOp>>(context);
+  results.insert<ExpandSimpleRangeOp<RangeMaxOp, arith::MaxUIOp>>(context);
   results.insert<SimplifyUniformRangeOp<RangeMaxOp, INT64_MIN, xmax>>(context);
 }
 
@@ -227,10 +227,10 @@
         rewriter.getIntegerAttr(op.max().getType(),
                                 constantMax - constantMin + 1),
         op.max().getType());
-    min = min ? rewriter.create<mlir::MinUIOp>(op.getLoc(), min, constantMinOp)
+    min = min ? rewriter.create<arith::MinUIOp>(op.getLoc(), min, constantMinOp)
                     .getResult()
               : constantMinOp.getResult();
-    max = max ? rewriter.create<mlir::MaxUIOp>(op.getLoc(), max, constantMaxOp)
+    max = max ? rewriter.create<arith::MaxUIOp>(op.getLoc(), max, constantMaxOp)
                     .getResult()
               : constantMaxOp.getResult();
 
@@ -252,8 +252,8 @@
                               rewriter);
     } else if (op.offsets().size() == 2) {
       // Two ranges turn into min/max.
-      minValue = rewriter.create<mlir::MinUIOp>(loc, op.offsets().front(),
-                                                op.offsets().back());
+      minValue = rewriter.create<arith::MinUIOp>(loc, op.offsets().front(),
+                                                 op.offsets().back());
       auto one = rewriter.create<arith::ConstantOp>(
           loc, rewriter.getIntegerAttr(op.min().getType(), 1),
           op.min().getType());
@@ -261,7 +261,7 @@
                                  op.lengths().front(), one, rewriter);
       auto endRhs = makeRangeEnd(loc, op.offsets().back(), op.lengths().back(),
                                  one, rewriter);
-      maxValue = rewriter.create<mlir::MaxUIOp>(loc, endLhs, endRhs);
+      maxValue = rewriter.create<arith::MaxUIOp>(loc, endLhs, endRhs);
     }
     if (!minValue || !maxValue) return failure();
     rewriter.replaceOp(op, {minValue, maxValue});
diff --git a/iree/compiler/Dialect/Util/IR/test/range_folding.mlir b/iree/compiler/Dialect/Util/IR/test/range_folding.mlir
index 8d84f09..33b11b9 100644
--- a/iree/compiler/Dialect/Util/IR/test/range_folding.mlir
+++ b/iree/compiler/Dialect/Util/IR/test/range_folding.mlir
@@ -20,7 +20,7 @@
 
 // CHECK-LABEL: @rangeMinExpand
 func @rangeMinExpand(%arg0: index, %arg1: index) -> index {
-  // CHECK: %[[MIN:.+]] = minui %arg0, %arg1 : index
+  // CHECK: %[[MIN:.+]] = arith.minui %arg0, %arg1 : index
   %0 = util.range.min %arg0, %arg1 : index
   // CHECK: return %[[MIN]]
   return %0 : index
@@ -63,8 +63,8 @@
   %c3 = arith.constant 3 : index
   // CHECK: %[[RANGE_MAX_EXC:.+]] = arith.addi %arg0, %arg1
   // CHECK: %[[RANGE_MAX_INC:.+]] = arith.subi %[[RANGE_MAX_EXC]], %c1
-  // CHECK: %[[RANGE_MIN:.+]] = minui %arg0, %c1
-  // CHECK: %[[RANGE_MAX:.+]] = maxui %[[RANGE_MAX_INC]], %c4
+  // CHECK: %[[RANGE_MIN:.+]] = arith.minui %arg0, %c1
+  // CHECK: %[[RANGE_MAX:.+]] = arith.maxui %[[RANGE_MAX_INC]], %c4
   %0:2 = util.range.extents [%c1 for %c2], [%arg0 for %arg1], [%c2 for %c3] : index
   // CHECK: return %[[RANGE_MIN]], %[[RANGE_MAX]]
   return %0#0, %0#1 : index, index
@@ -85,12 +85,12 @@
 
 // CHECK-LABEL: @rangeExtentsExpand2
 func @rangeExtentsExpand2(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> (index, index) {
-  // CHECK: %[[RANGE_MIN:.+]] = minui %arg0, %arg2
+  // CHECK: %[[RANGE_MIN:.+]] = arith.minui %arg0, %arg2
   // CHECK: %[[RANGE0_MAX_EXC:.+]] = arith.addi %arg0, %arg1
   // CHECK: %[[RANGE0_MAX_INC:.+]] = arith.subi %[[RANGE0_MAX_EXC]], %c1
   // CHECK: %[[RANGE1_MAX_EXC:.+]] = arith.addi %arg2, %arg3
   // CHECK: %[[RANGE1_MAX_INC:.+]] = arith.subi %[[RANGE1_MAX_EXC]], %c1
-  // CHECK: %[[RANGE_MAX:.+]] = maxui %[[RANGE0_MAX_INC]], %[[RANGE1_MAX_INC]]
+  // CHECK: %[[RANGE_MAX:.+]] = arith.maxui %[[RANGE0_MAX_INC]], %[[RANGE1_MAX_INC]]
   %0:2 = util.range.extents [%arg0 for %arg1], [%arg2 for %arg3] : index
   // CHECK: return %[[RANGE_MIN]], %[[RANGE_MAX]]
   return %0#0, %0#1 : index, index
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 8909dc5..274f12a 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 8909dc5ebe8ad39f1743131eb70df402d796acab
+Subproject commit 274f12a44c606ecd20152f3e63c4f186793d9a8c