Integrate llvm-project and bump dependencies. (#10103)

* llvm-project: 2c3ca3b684bb2b188d977d47548e79dc559fb8ad
* mlir-hlo: b30f16819dd99be5d00c65a458ab9de12e7b8d13
* tensorflow: 49f97f135a2e1d5d22e60d2a80ec668d53f9708a

Extra changes (minimal before/after sketches of the syntax migrations, distilled from the test updates below, follow this list):

* math::AbsOp -> math::AbsFOp (math.abs -> math.absf on floats)
* LLVM global access now requires a symbol table; thread a SymbolTableCollection through AddressOfOp::getGlobal()
* TableGen library build fix: iree-tblgen now links MLIRTblgenLib
* LLVM::CallOp::getResult() and LLVM::ExtractValueOp builder signatures changed; extract positions are now plain integers instead of an I64ArrayAttr
* transform.sequence syntax change: an explicit failures(...) mode is required
* Relax over-constrained ordering in lit tests (CHECK -> CHECK-DAG)
* Update SPIR-V code after the memory space changes; memrefs now carry #spv.storage_class attributes
* Fix mhlo.reshape printing for the new custom assembly format
* Fix MHLO type conversion for rank-0 tensors by materializing scalars with tensor.from_elements
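
The floating-point absolute value op was renamed upstream, so `math.abs` on floats becomes `math.absf` (illustrative excerpt mirroring the test updates in this diff):

```mlir
func.func @abs(%x: f32) -> f32 {
  // Before this bump: %y = math.abs %x : f32
  %y = math.absf %x : f32
  return %y : f32
}
```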
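
`transform.sequence` now requires an explicit failure-propagation mode; the bare form no longer parses. A sketch matching the updated tests:

```mlir
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  // Before this bump: transform.sequence %arg0 { ... }
  transform.sequence %arg0 failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["func.func"]} in %arg1
    transform.iree.apply_patterns %0 { canonicalization }
  }
}
```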
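
SPIR-V memory spaces are now spelled as `#spv.storage_class` attributes directly on memref types, with createMapMemRefStorageClassPass() running before SPIR-V conversion. Excerpt from the updated binding tests:

```mlir
func.func @binding() {
  // Before this bump the binding type was plain memref<4x4xf32> and the
  // storage class was attached as a numeric memory space during conversion.
  %0 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer)
      : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
  return
}
```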
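
`#spv.coop_matrix_props` now takes the symbolic scope enum rather than a raw integer (attribute excerpt, as it appears in the target environments below):

```mlir
// Before this bump: scope = 3 : i32
#spv.coop_matrix_props<
  a_type = f16, b_type = f16, c_type = f16, k_size = 16,
  m_size = 16, n_size = 16, result_type = f16, scope = <Subgroup>>
```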
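
`mhlo.reshape` gained a custom assembly format, so FileCheck patterns matching the old generic quoted form had to be updated:

```mlir
func.func @reshape(%s: tensor<15xf32>) -> tensor<3x5xf32> {
  // Before this bump: %r = "mhlo.reshape"(%s) : (tensor<15xf32>) -> tensor<3x5xf32>
  %r = mhlo.reshape %s : (tensor<15xf32>) -> tensor<3x5xf32>
  return %r : tensor<3x5xf32>
}
```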

Co-authored-by: Lei Zhang <antiagainst@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 5289f88..e35dbb1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -83,13 +83,10 @@
 //  CHECK-DAG:     %[[STEP_X:.+]] = affine.apply #[[MAP1]]()[%[[WG_COUNT_X]]]
 //      CHECK:     scf.for %[[IV1:.+]] = %[[LB_X]] to %[[N]] step %[[STEP_X]]
 //  CHECK-DAG:       %[[TILESIZE_M:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]]]
-//  CHECK-DAG:       %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
-// CHECK-SAME:           offsets = [%[[IV0]], 0], sizes = [%[[TILESIZE_M]], %[[K]]]
+//  CHECK-DAG:       %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]], offsets = [%[[IV0]], 0], sizes = [%[[TILESIZE_M]], %[[K]]]
 //  CHECK-DAG:       %[[TILESIZE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]]
-//  CHECK-DAG:       %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
-// CHECK-SAME:           offsets = [0, %[[IV1]]], sizes = [%[[K]], %[[TILESIZE_N]]]
-//  CHECK-DAG:       %[[INIT:.+]] = flow.dispatch.tensor.load %[[INIT_BINDING]]
-// CHECK-SAME:           offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_M]], %[[TILESIZE_N]]]
+//  CHECK-DAG:       %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]], offsets = [0, %[[IV1]]], sizes = [%[[K]], %[[TILESIZE_N]]]
+//  CHECK-DAG:       %[[INIT:.+]] = flow.dispatch.tensor.load %[[INIT_BINDING]], offsets = [%[[IV0]], %[[IV1]]], sizes = [%[[TILESIZE_M]], %[[TILESIZE_N]]]
 //      CHECK:       %[[GEMM:.+]] = linalg.matmul
 // CHECK-SAME:           ins(%[[LHS]], %[[RHS]] :
 // CHECK-SAME:           outs(%[[INIT]] :
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
index 0ca1ed5..a29f97b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
@@ -10,7 +10,7 @@
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
-  transform.sequence %arg0 {
+  transform.sequence %arg0 failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["func.func"]} in %arg1
     transform.iree.apply_patterns %0 { canonicalization }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index a4bcfa5..c8dc24a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -495,10 +495,9 @@
   Value loadProcessorData(Location loc, int64_t index, OpBuilder &builder) {
     // Load the value; it should always be in bounds.
     Value dataArrayValue = loadFieldValue(loc, ProcessorField::data, builder);
-    Type elementType =
-        dataArrayValue.getType().cast<LLVM::LLVMArrayType>().getElementType();
-    Value dataValue = builder.create<LLVM::ExtractValueOp>(
-        loc, elementType, dataArrayValue, builder.getI64ArrayAttr(index));
+    SmallVector<int64_t, 1> position = {index};
+    Value dataValue =
+        builder.create<LLVM::ExtractValueOp>(loc, dataArrayValue, position);
     return dataValue;
   }
 
@@ -555,7 +554,7 @@
                                          /*import_func_ptr=*/importPtrValue,
                                          /*import_params=*/params,
                                      });
-    return callOp.getResult(0);
+    return callOp.getResult();
   }
 
  private:
@@ -577,35 +576,35 @@
     auto environmentPtrValue = funcOp.getArgument(0);
     Value environmentValue =
         builder.create<LLVM::LoadOp>(loc, environmentPtrValue);
-    Type fieldType = environmentType.getBody()[(int)field];
-    return builder.createOrFold<LLVM::ExtractValueOp>(
-        loc, fieldType, environmentValue, builder.getI64ArrayAttr((int)field));
+    SmallVector<int64_t, 1> position = {int64_t(field)};
+    return builder.createOrFold<LLVM::ExtractValueOp>(loc, environmentValue,
+                                                      position);
   }
 
   Value loadFieldValue(Location loc, ProcessorField field, OpBuilder &builder) {
     Value processorValue =
         loadFieldValue(loc, EnvironmentField::processor, builder);
-    Type fieldType = processorType.getBody()[(int)field];
-    return builder.createOrFold<LLVM::ExtractValueOp>(
-        loc, fieldType, processorValue, builder.getI64ArrayAttr((int)field));
+    SmallVector<int64_t, 1> position = {int64_t(field)};
+    return builder.createOrFold<LLVM::ExtractValueOp>(loc, processorValue,
+                                                      position);
   }
 
   Value loadFieldValue(Location loc, DispatchStateField field,
                        OpBuilder &builder) {
     Value statePtrValue = funcOp.getArgument(1);
     Value stateValue = builder.createOrFold<LLVM::LoadOp>(loc, statePtrValue);
-    Type fieldType = dispatchStateType.getBody()[(int)field];
-    return builder.createOrFold<LLVM::ExtractValueOp>(
-        loc, fieldType, stateValue, builder.getI64ArrayAttr((int)field));
+    SmallVector<int64_t, 1> position = {int64_t(field)};
+    return builder.createOrFold<LLVM::ExtractValueOp>(loc, stateValue,
+                                                      position);
   }
 
   Value loadFieldValue(Location loc, WorkgroupStateField field,
                        OpBuilder &builder) {
     Value statePtrValue = funcOp.getArgument(2);
     Value stateValue = builder.createOrFold<LLVM::LoadOp>(loc, statePtrValue);
-    Type fieldType = dispatchStateType.getBody()[(int)field];
-    return builder.createOrFold<LLVM::ExtractValueOp>(
-        loc, fieldType, stateValue, builder.getI64ArrayAttr((int)field));
+    SmallVector<int64_t, 1> position = {int64_t(field)};
+    return builder.createOrFold<LLVM::ExtractValueOp>(loc, stateValue,
+                                                      position);
   }
 
   LLVM::LLVMFuncOp funcOp;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
index e87abd9..5301830 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
@@ -887,9 +887,9 @@
     // Extract result vectors from the asm op.
     SmallVector<Value> resVec;
     for (int i = 0; i < kernel.accRegs; ++i) {
-      resVec.push_back(rewriter.create<LLVM::ExtractValueOp>(
-          loc, getAccRegVectorType(), asmOp.getRes(),
-          rewriter.getI64ArrayAttr({i})));
+      SmallVector<int64_t, 1> position = {i};
+      resVec.push_back(
+          rewriter.create<LLVM::ExtractValueOp>(loc, asmOp.getRes(), position));
     }
     return resVec;
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/emit_vectorization_remarks.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/emit_vectorization_remarks.mlir
index 7dd520a..5e913bc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/emit_vectorization_remarks.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/emit_vectorization_remarks.mlir
@@ -15,7 +15,7 @@
       ins(%arg0 : tensor<?x?xf32>)
       outs(%2 : tensor<?x?xf32>) {
     ^bb0(%arg1: f32, %arg2: f32):
-      %4 = math.abs %arg1 : f32
+      %4 = math.absf %arg1 : f32
       linalg.yield %4 : f32
     } -> tensor<?x?xf32>
     return %3 : tensor<?x?xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index b0adade..1bd7cee 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -26,12 +26,13 @@
 namespace iree_compiler {
 
 void ConvertToDynamicSharedMemory(ModuleOp moduleOp) {
+  SymbolTableCollection symbolTableCollection;
  // Collect all the addressOfOps to static shared memory globals.
   SmallVector<LLVM::AddressOfOp> addressOfOps;
   moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) {
     // Check that the global associated with this addressOfOp has shared memory
     // space.
-    if (addressOfOp.getGlobal().getAddrSpace() == 3)
+    if (addressOfOp.getGlobal(symbolTableCollection).getAddrSpace() == 3)
       addressOfOps.push_back(addressOfOp);
   });
   if (addressOfOps.size() == 0) return;
@@ -50,7 +51,7 @@
   for (auto addressOfOpsIt : llvm::enumerate(addressOfOps)) {
     uint32_t offset = 0;
     auto addressOfOp = addressOfOpsIt.value();
-    auto globalOp = addressOfOp.getGlobal();
+    auto globalOp = addressOfOp.getGlobal(symbolTableCollection);
     if (globalMemoryOffsetMap.count(globalOp)) {
       offset = globalMemoryOffsetMap[globalOp];
     } else {
@@ -424,7 +425,7 @@
 }
 
 void populateScalarizeMathOps(RewritePatternSet &patterns) {
-  patterns.add<ScalarizeMathOp<math::SqrtOp>, ScalarizeMathOp<math::AbsOp>,
+  patterns.add<ScalarizeMathOp<math::SqrtOp>, ScalarizeMathOp<math::AbsFOp>,
                ScalarizeMathOp<math::AtanOp>, ScalarizeMathOp<math::Atan2Op>,
                ScalarizeMathOp<math::CeilOp>, ScalarizeMathOp<math::CosOp>,
                ScalarizeMathOp<math::ExpOp>, ScalarizeMathOp<math::ExpM1Op>,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
index 756a21a..b6326a4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/linalg_transform.mlir
@@ -60,11 +60,11 @@
           // FOREACH-TO-GPU: scf.if %[[COND2]] {
           // FOREACH-TO-GPU:   affine.min #{{.*}}()[%[[TIDX]]]
           // FOREACH-TO-GPU:   affine.min #{{.*}}()[%[[TIDY]]]
-          // FOREACH-TO-GPU:   affine.apply #{{.*}}()[%[[TIDX]]]
-          // FOREACH-TO-GPU:   %[[svA:.*]] = memref.subview {{.*}} : memref<250x500xf32> to memref<?x500xf32
-          // FOREACH-TO-GPU:   affine.apply #{{.*}}()[%[[TIDY]]]
-          // FOREACH-TO-GPU:   %[[svB:.*]] = memref.subview {{.*}} : memref<500x1020xf32> to memref<500x?xf32
-          // FOREACH-TO-GPU:   %[[svC:.*]] = memref.subview {{.*}} : memref<250x1020xf32> to memref<?x?xf32
+          // FOREACH-TO-GPU-DAG:   affine.apply #{{.*}}()[%[[TIDX]]]
+          // FOREACH-TO-GPU-DAG:   %[[svA:.*]] = memref.subview {{.*}} : memref<250x500xf32> to memref<?x500xf32
+          // FOREACH-TO-GPU-DAG:   affine.apply #{{.*}}()[%[[TIDY]]]
+          // FOREACH-TO-GPU-DAG:   %[[svB:.*]] = memref.subview {{.*}} : memref<500x1020xf32> to memref<500x?xf32
+          // FOREACH-TO-GPU-DAG:   %[[svC:.*]] = memref.subview {{.*}} : memref<250x1020xf32> to memref<?x?xf32
           // FOREACH-TO-GPU:   linalg.matmul ins(%[[svA]], %[[svB]] : memref<?x500xf32{{.*}}>, memref<500x?xf32{{.*}}>) outs(%[[svC]] : memref<?x?xf32{{.*}}>)
           // FOREACH-TO-GPU: }
           // FOREACH-TO-GPU: gpu.barrier
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index eedd582..29ec95e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -18,8 +18,11 @@
 #include "iree/compiler/Codegen/Passes.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -42,10 +45,10 @@
                                                      MemRefType memRefType,
                                                      ValueRange dynamicSizes,
                                                      unsigned alignment) {
-  auto storageClass = SPIRVTypeConverter::getMemorySpaceForStorageClass(
-      spirv::StorageClass::Workgroup);
+  Optional<unsigned> space =
+      spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass::Workgroup);
   MemRefType allocType = MemRefType::get(
-      memRefType.getShape(), memRefType.getElementType(), {}, storageClass);
+      memRefType.getShape(), memRefType.getElementType(), {}, *space);
   return builder
       .create<memref::AllocOp>(loc, allocType, dynamicSizes,
                                builder.getI64IntegerAttr(alignment))
@@ -57,10 +60,10 @@
                                                     MemRefType memRefType,
                                                     ValueRange dynamicSizes,
                                                     unsigned alignment) {
-  auto storageClass = SPIRVTypeConverter::getMemorySpaceForStorageClass(
-      spirv::StorageClass::Function);
+  Optional<unsigned> space =
+      spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass::Function);
   MemRefType allocType = MemRefType::get(
-      memRefType.getShape(), memRefType.getElementType(), {}, storageClass);
+      memRefType.getShape(), memRefType.getElementType(), {}, *space);
   return builder
       .create<memref::AllocaOp>(loc, allocType, dynamicSizes,
                                 builder.getI64IntegerAttr(alignment))
@@ -182,6 +185,7 @@
   pm.addPass(createCanonicalizerPass());
   pm.addPass(createCSEPass());
 
+  pm.addPass(createMapMemRefStorageClassPass());
   pm.addPass(createConvertToSPIRVPass());
 
   OpPassManager &spirvPM = pm.nest<spirv::ModuleOp>();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
index 8534bd7..0ebfd29 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
@@ -9,9 +9,11 @@
 #include "iree/compiler/Codegen/SPIRV/Utils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -237,9 +239,11 @@
          // In IREE all MemRefs originate from subspan ops, which should
           // have identity layout.
           if (!type.getLayout().isIdentity()) return llvm::None;
-          auto flattenedType =
-              MemRefType::get(ShapedType::kDynamicSize, type.getElementType(),
-                              AffineMap(), type.getMemorySpace());
+          auto storage = spirv::mapMemorySpaceToVulkanStorageClass(
+              type.getMemorySpaceAsInt());
+          auto flattenedType = MemRefType::get(
+              ShapedType::kDynamicSize, type.getElementType(), AffineMap(),
+              spirv::StorageClassAttr::get(type.getContext(), *storage));
           return typeConverter.convertType(flattenedType);
         });
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
index c618b78..d433249 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_nvidia_matmul_cooperative_ops.mlir
@@ -25,13 +25,13 @@
         cooperative_matrix_properties_nv = [
           #spv.coop_matrix_props<
             a_type = i8, b_type = i8, c_type = i32, k_size = 32,
-            m_size = 8, n_size = 8, result_type = i32, scope = 3 : i32>,
+            m_size = 8, n_size = 8, result_type = i32, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f16, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f16, scope = 3 : i32>,
+            m_size = 16, n_size = 16, result_type = f16, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f32, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f32, scope = 3 : i32>
+            m_size = 16, n_size = 16, result_type = f32, scope = <Subgroup>>
         ],
         max_compute_shared_memory_size = 49152,
         max_compute_workgroup_invocations = 1024,
@@ -115,13 +115,13 @@
         cooperative_matrix_properties_nv = [
           #spv.coop_matrix_props<
             a_type = i8, b_type = i8, c_type = i32, k_size = 32,
-            m_size = 8, n_size = 8, result_type = i32, scope = 3 : i32>,
+            m_size = 8, n_size = 8, result_type = i32, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f16, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f16, scope = 3 : i32>,
+            m_size = 16, n_size = 16, result_type = f16, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f32, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f32, scope = 3 : i32>
+            m_size = 16, n_size = 16, result_type = f32, scope = <Subgroup>>
         ],
         max_compute_shared_memory_size = 49152,
         max_compute_workgroup_invocations = 1024,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
index a942785..98f7293 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir
@@ -59,25 +59,25 @@
         // Same type
         // CHECK: spv.mlir.addressof @[[ARG0]]
         // CHECK: spv.mlir.addressof @[[ARG0]]
-        %0 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32>
-        %1 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32>
+        %0 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
+        %1 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
 
         // Different type
         // CHECK: spv.mlir.addressof @[[ARG1_0]]
         // CHECK: spv.mlir.addressof @[[ARG1_1]]
-        %2 = hal.interface.binding.subspan set(1) binding(3) type(storage_buffer) : memref<4x4xf32>
-        %3 = hal.interface.binding.subspan set(1) binding(3) type(storage_buffer) : memref<4xvector<4xf32>>
+        %2 = hal.interface.binding.subspan set(1) binding(3) type(storage_buffer) : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
+        %3 = hal.interface.binding.subspan set(1) binding(3) type(storage_buffer) : memref<4xvector<4xf32>, #spv.storage_class<StorageBuffer>>
 
         // CHECK: spv.mlir.addressof @[[RET0]]
-        %4 = hal.interface.binding.subspan set(3) binding(4) type(storage_buffer) : memref<4x4xf32>
+        %4 = hal.interface.binding.subspan set(3) binding(4) type(storage_buffer) : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
 
-        %5 = memref.load %0[%c0, %c0] : memref<4x4xf32>
-        %6 = memref.load %1[%c0, %c0] : memref<4x4xf32>
+        %5 = memref.load %0[%c0, %c0] : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
+        %6 = memref.load %1[%c0, %c0] : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
 
-        %7 = memref.load %2[%c0, %c0] : memref<4x4xf32>
-        %8 = memref.load %3[%c0] : memref<4xvector<4xf32>>
+        %7 = memref.load %2[%c0, %c0] : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
+        %8 = memref.load %3[%c0] : memref<4xvector<4xf32>, #spv.storage_class<StorageBuffer>>
 
-        %9 = memref.load %4[%c0, %c0] : memref<4x4xf32>
+        %9 = memref.load %4[%c0, %c0] : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
 
         return
       }
@@ -116,11 +116,11 @@
         // CHECK: spv.mlir.addressof @[[FUNC1_ARG]]
         // CHECK: spv.mlir.addressof @[[FUNC1_RET]]
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32>
-        %1 = hal.interface.binding.subspan set(3) binding(4) type(storage_buffer) : memref<4xvector<4xf32>>
+        %0 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
+        %1 = hal.interface.binding.subspan set(3) binding(4) type(storage_buffer) : memref<4xvector<4xf32>, #spv.storage_class<StorageBuffer>>
 
-        %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
-        %3 = memref.load %1[%c0] : memref<4xvector<4xf32>>
+        %2 = memref.load %0[%c0, %c0] : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
+        %3 = memref.load %1[%c0] : memref<4xvector<4xf32>, #spv.storage_class<StorageBuffer>>
 
         return
       }
@@ -130,11 +130,11 @@
         // CHECK: spv.mlir.addressof @[[FUNC2_ARG]]
         // CHECK: spv.mlir.addressof @[[FUNC2_RET]]
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32> // Same type as previous function
-        %1 = hal.interface.binding.subspan set(3) binding(4) type(storage_buffer) : memref<4x4xf32> // Different type as previous function
+        %0 = hal.interface.binding.subspan set(1) binding(2) type(storage_buffer) : memref<4x4xf32, #spv.storage_class<StorageBuffer>> // Same type as previous function
+        %1 = hal.interface.binding.subspan set(3) binding(4) type(storage_buffer) : memref<4x4xf32, #spv.storage_class<StorageBuffer>> // Different type from previous function
 
-        %2 = memref.load %0[%c0, %c0] : memref<4x4xf32>
-        %3 = memref.load %1[%c0, %c0] : memref<4x4xf32>
+        %2 = memref.load %0[%c0, %c0] : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
+        %3 = memref.load %1[%c0, %c0] : memref<4x4xf32, #spv.storage_class<StorageBuffer>>
 
         return
       }
@@ -160,13 +160,13 @@
     builtin.module {
       func.func @interface_binding() {
         %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<8x5xf32>
-        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<5xf32>
-        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<8x5xf32>
+        %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<8x5xf32, #spv.storage_class<StorageBuffer>>
+        %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<5xf32, #spv.storage_class<StorageBuffer>>
+        %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) : memref<8x5xf32, #spv.storage_class<StorageBuffer>>
 
-        %3 = memref.load %0[%c0, %c0] : memref<8x5xf32>
-        %4 = memref.load %1[%c0] : memref<5xf32>
-        %5 = memref.load %2[%c0, %c0] : memref<8x5xf32>
+        %3 = memref.load %0[%c0, %c0] : memref<8x5xf32, #spv.storage_class<StorageBuffer>>
+        %4 = memref.load %1[%c0] : memref<5xf32, #spv.storage_class<StorageBuffer>>
+        %5 = memref.load %2[%c0, %c0] : memref<8x5xf32, #spv.storage_class<StorageBuffer>>
 
         return
       }
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
index 3cdeaf4..f02240b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_cooperative_ops.mlir
@@ -20,13 +20,13 @@
         cooperative_matrix_properties_nv = [
           #spv.coop_matrix_props<
             a_type = i8, b_type = i8, c_type = i32, k_size = 32,
-            m_size = 8, n_size = 8, result_type = i32, scope = 3 : i32>,
+            m_size = 8, n_size = 8, result_type = i32, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f16, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f16, scope = 3 : i32>,
+            m_size = 16, n_size = 16, result_type = f16, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f32, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f32, scope = 3 : i32>
+            m_size = 16, n_size = 16, result_type = f32, scope = <Subgroup>>
         ],
         max_compute_shared_memory_size = 49152,
         max_compute_workgroup_invocations = 1024,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
index 1bcfcf1..400d6a4 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/pipeline_matmul_promotion.mlir
@@ -57,18 +57,18 @@
 //   CHECK-COUNT-5: spv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
 
 //           CHECK: spv.mlir.loop
-//           CHECK:   spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+//           CHECK:   spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
 //   CHECK-COUNT-5:   spv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-//           CHECK:   spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+//           CHECK:   spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
 
 //  CHECK-COUNT-64:   spv.Load "Workgroup" %{{.+}} : vector<4xf32>
 // CHECK-COUNT-128:   spv.GL.Fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
 //   CHECK-COUNT-5:   spv.Load "StorageBuffer" %{{.+}} : vector<4xf32>
 //           CHECK:   spv.mlir.merge
 
-//           CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+//           CHECK: spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
 //   CHECK-COUNT-5: spv.Store "Workgroup" %{{.+}}, %{{.+}} : vector<4xf32>
-//           CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+//           CHECK: spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
 
 //  CHECK-COUNT-64: spv.Load "Workgroup" %{{.+}} : vector<4xf32>
 // CHECK-COUNT-128: spv.GL.Fma %{{.+}}, %{{.+}}, %{{.+}} : vector<4xf32>
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
index 5e98d0a..6900c66 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_vectorize_to_cooperative_ops.mlir
@@ -21,13 +21,13 @@
         cooperative_matrix_properties_nv = [
           #spv.coop_matrix_props<
             a_type = i8, b_type = i8, c_type = i32, k_size = 32,
-            m_size = 8, n_size = 8, result_type = i32, scope = 3 : i32>,
+            m_size = 8, n_size = 8, result_type = i32, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f16, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f16, scope = 3 : i32>,
+            m_size = 16, n_size = 16, result_type = f16, scope = <Subgroup>>,
           #spv.coop_matrix_props<
             a_type = f16, b_type = f16, c_type = f32, k_size = 16,
-            m_size = 16, n_size = 16, result_type = f32, scope = 3 : i32>
+            m_size = 16, n_size = 16, result_type = f32, scope = <Subgroup>>
         ],
         max_compute_shared_memory_size = 49152,
         max_compute_workgroup_invocations = 1024,
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
index b8e213a..8d54c34 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -782,7 +782,7 @@
 
     ```mlir
     %c = flow.tensor.constant tensor<2x2xf32> -> tensor<?x?xf32>
-    %res = math.abs %c : tensor<?x?xf32>
+    %res = math.absf %c : tensor<?x?xf32>
     ```
   }];
   let arguments = (ins ElementsAttr:$value);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
index 6e2fea5..b73df33 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/transform_dispatch_region_formation.mlir
@@ -15,7 +15,7 @@
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
-  transform.sequence %arg0 {
+  transform.sequence %arg0 failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["tensor.extract_slice"]} in %arg1
     transform.iree.wrap_in_dispatch_region %0
@@ -46,7 +46,7 @@
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
-  transform.sequence %arg0 {
+  transform.sequence %arg0 failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
     %dispatch_op = transform.iree.wrap_in_dispatch_region %0
@@ -79,7 +79,7 @@
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
-  transform.sequence %arg0 {
+  transform.sequence %arg0 failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
     %dispatch_op = transform.iree.wrap_in_dispatch_region %0
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp
index 38ba920..3028739 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp
@@ -86,7 +86,7 @@
                               TypeConverter &typeConverter,
                               RewritePatternSet &patterns) {
   patterns.insert<
-      UnaryArithmeticOpConversion<math::AbsOp, IREE::VM::AbsF32Op,
+      UnaryArithmeticOpConversion<math::AbsFOp, IREE::VM::AbsF32Op,
                                   IREE::VM::AbsF64Op>,
       UnaryArithmeticOpConversion<math::CeilOp, IREE::VM::CeilF32Op,
                                   IREE::VM::CeilF64Op>,
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/arithmetic_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/arithmetic_ops.mlir
index 87aac0f..8479425 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/arithmetic_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/MathToVM/test/arithmetic_ops.mlir
@@ -71,7 +71,7 @@
   %15 = math.erf %14 : f32
 
   // CHECK: vm.abs.f32
-  %16 = math.abs %14 : f32
+  %16 = math.absf %14 : f32
 
   // CHECK: vm.ceil.f32
   %17 = math.ceil %14 : f32
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index 9496906..12e4a08 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -1114,7 +1114,7 @@
 
   // Floating-point arithmetic ops.
   patterns
-      .insert<UnaryArithmeticOpConversion<math::AbsOp, IREE::VM::AbsF32Op,
+      .insert<UnaryArithmeticOpConversion<math::AbsFOp, IREE::VM::AbsF32Op,
                                           IREE::VM::AbsF64Op>,
               BinaryArithmeticOpConversion<arith::AddFOp, IREE::VM::AddF32Op,
                                            IREE::VM::AddF64Op>,
diff --git a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
index 330ec6c..6e3fd39 100644
--- a/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
+++ b/compiler/src/iree/compiler/Dialect/Vulkan/Utils/test/target_env_conversion.mlir
@@ -12,7 +12,7 @@
 // DEFAULT: #spv.target_env<#spv.vce<v1.3, [Shader, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], [SPV_KHR_storage_buffer_storage_class]>, SwiftShader:CPU, #spv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 4, cooperative_matrix_properties_nv = []>>
 // ADRENO: #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, Qualcomm:IntegratedGPU, #spv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
 // MALI: #spv.target_env<#spv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, ARM:IntegratedGPU, #spv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512, 512, 512], subgroup_size = 16, cooperative_matrix_properties_nv = []>>
-// TURINGT4: #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], cooperative_matrix_properties_nv = [#spv.coop_matrix_props<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = 3 : i32>, #spv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = 3 : i32>, #spv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = 3 : i32>]>>
+// TURINGT4: #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU, #spv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], cooperative_matrix_properties_nv = [#spv.coop_matrix_props<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, scope = <Subgroup>>, #spv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, scope = <Subgroup>>, #spv.coop_matrix_props<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, scope = <Subgroup>>]>>
 // AMD5700XT: #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, AMD:DiscreteGPU, #spv.resource_limits<max_compute_shared_memory_size = 65536, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], subgroup_size = 64, cooperative_matrix_properties_nv = []>>
 // M1: #spv.target_env<#spv.vce<v1.3, [Shader, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, Apple:IntegratedGPU, #spv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 1024], cooperative_matrix_properties_nv = []>>
 
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 5e4d685..2732010 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -293,6 +293,19 @@
   }
 };
 
+llvm::Optional<Value> scalarToTensor(OpBuilder &builder, Type /*type*/,
+                                     ValueRange inputs, Location loc) {
+  assert(inputs.size() == 1);
+  if (inputs.front().getType().isa<ShapedType>()) {
+    return llvm::None;
+  }
+  return builder
+      .create<tensor::FromElementsOp>(
+          loc, RankedTensorType::get({}, inputs.front().getType()),
+          inputs.front())
+      .getResult();
+}
+
 struct ConvertMHLOToLinalgOnTensorsPass
     : public ConvertMHLOToLinalgOnTensorsBase<
           ConvertMHLOToLinalgOnTensorsPass> {
@@ -307,6 +320,7 @@
     MLIRContext *context = &getContext();
 
     auto typeConverter = mhlo::createHloToLinalgSignedIntegerConverter();
+    typeConverter->addArgumentMaterialization(scalarToTensor);
     // NOTE: not using corresponding setupMHLOToFlowPatterns because the entire
     // MHLO dialects are marked illegal by this pass.
     // TODO: Collapse/rework all of these patterns once the consolidation
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
index bc3508f..bed9898 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
@@ -240,7 +240,7 @@
//                Concatenate and reshape the output.
 // CHECK:         %[[CON:.+]] = "mhlo.concatenate"(%[[Z0]], %[[Z1]]) {dimension = 0 : i64} : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
 // CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[CON]][0] [15] [1] : tensor<16xf32> to tensor<15xf32>
-// CHECK:         %[[RES:.+]] = "mhlo.reshape"(%[[SLICE]]) : (tensor<15xf32>) -> tensor<3x5xf32>
+// CHECK:         %[[RES:.+]] = mhlo.reshape %[[SLICE]] : (tensor<15xf32>) -> tensor<3x5xf32>
 // CHECK:         return %[[RES]]
 
 // -----
@@ -254,10 +254,10 @@
 }
 
 // CHECK-LABEL: func.func @scatter_rank0
-// CHECK-DAG: %[[RE_I:.+]] = "mhlo.reshape"(%arg1) : (tensor<2xi32>) -> tensor<1x2xi32>
-// CHECK-DAG: %[[RE_U:.+]] = "mhlo.reshape"(%arg2) : (tensor<i32>) -> tensor<1xi32>
-// CHECK:     %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[RE_I]], %[[RE_U]])
-// CHECK:       mhlo.return %arg4
+// CHECK-DAG: %[[RE_I:.+]] = mhlo.reshape %{{.*}} : (tensor<2xi32>) -> tensor<1x2xi32>
+// CHECK-DAG: %[[RE_U:.+]] = mhlo.reshape %{{.*}} : (tensor<i32>) -> tensor<1xi32>
+// CHECK:     %[[SCATTER:.+]] = "mhlo.scatter"(%{{.*}}, %[[RE_I]], %[[RE_U]])
+// CHECK:       mhlo.return %{{.*}}
 
 // -----
 
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir
index 4f69764..84274a9 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir
@@ -14,10 +14,10 @@
 }
 
 // CHECK: dot_general_to_dot(%[[ARG0:.+]]: tensor<1x32x128x4xf32>, %[[ARG1:.+]]: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x32x128x4xf32>) -> tensor<32x512xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<128x4x8x64xf32>) -> tensor<512x512xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0]] : (tensor<1x32x128x4xf32>) -> tensor<32x512xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = mhlo.reshape %[[ARG1]] : (tensor<128x4x8x64xf32>) -> tensor<512x512xf32>
 // CHECK: %[[DOT:.+]] = "mhlo.dot"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
-// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT]]) : (tensor<32x512xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: %[[RESULT:.+]] = mhlo.reshape %[[DOT]] : (tensor<32x512xf32>) -> tensor<1x32x8x64xf32>
 // CHECK: return %[[RESULT]] : tensor<1x32x8x64xf32>
 
 // -----
@@ -35,10 +35,10 @@
   return %0 : tensor<1x8x32x32xf32>
 }
 // CHECK: dot_general_to_dot_general_rank_reduced(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0]] : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = mhlo.reshape %[[ARG1]] : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
 // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
-// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[RESULT:.+]] = mhlo.reshape %[[DOT_RESULT]] : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
 
 // -----
@@ -57,10 +57,10 @@
   return %0 : tensor<1x8x32x32xf32>
 }
 // CHECK: dot_general_to_dot_general_rank_reduced_attribute(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0]] : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = mhlo.reshape %[[ARG1]] : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
 // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]]) {{{.*}}, unknown_attribute_to_propagate
-// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[RESULT:.+]] = mhlo.reshape %[[DOT_RESULT]] : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
 
 // -----
@@ -79,10 +79,10 @@
 }
 // CHECK: dot_general_to_dot_general_rank_reduced_a_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x64x32xf32>) -> tensor<1x8x32x64xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RSSHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0_RESHAPED_TR]] : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RSSHAPED:.+]] = mhlo.reshape %[[ARG1]] : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
 // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RSSHAPED]])
-// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[RESULT:.+]] = mhlo.reshape %[[DOT_RESULT]] : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 
 // -----
 
@@ -100,10 +100,10 @@
 }
 // CHECK: dot_general_to_dot_general_rank_reduced_b_transposed(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x32x64xf32>) -> tensor<1x8x64x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1_RESHAPED_TR]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0]] : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = mhlo.reshape %[[ARG1_RESHAPED_TR]] : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
 // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
-// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[RESULT:.+]] = mhlo.reshape %[[DOT_RESULT]] : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
 
 
@@ -124,10 +124,10 @@
 // CHECK: dot_general_to_dot_general_rank_reduced_ab_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x64x32xf32>) -> tensor<1x8x32x64xf32>
 // CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]]) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x8x32x64xf32>) -> tensor<1x8x64x32xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
-// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1_RESHAPED_TR]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0_RESHAPED_TR]] : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = mhlo.reshape %[[ARG1_RESHAPED_TR]] : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
 // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
-// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[RESULT:.+]] = mhlo.reshape %[[DOT_RESULT]] : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
 // CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
 
 // -----
@@ -152,10 +152,10 @@
 // CHECK-SAME:      permutation = dense<[0, 2, 1, 3]>
 // CHECK:         %[[ARG1_TRANSPOSED:.+]] = "mhlo.transpose"(%[[ARG1]])
 // CHECK-SAME:      permutation = dense<[0, 2, 3, 1]>
-// CHECK:         %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_TRANSPOSED]]) : (tensor<1x8x1x64xf32>) -> tensor<8x1x64xf32>
-// CHECK:         %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1_TRANSPOSED]]) : (tensor<1x8x64x512xf32>) -> tensor<8x64x512xf32>
+// CHECK:         %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0_TRANSPOSED]] : (tensor<1x8x1x64xf32>) -> tensor<8x1x64xf32>
+// CHECK:         %[[ARG1_RESHAPED:.+]] = mhlo.reshape %[[ARG1_TRANSPOSED]] : (tensor<1x8x64x512xf32>) -> tensor<8x64x512xf32>
 // CHECK:         %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
-// CHECK:         %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x1x512xf32>) -> tensor<1x8x1x512xf32>
+// CHECK:         %[[RESULT:.+]] = mhlo.reshape %[[DOT_RESULT]] : (tensor<8x1x512xf32>) -> tensor<1x8x1x512xf32>
 // CHECK:         return %[[RESULT]] : tensor<1x8x1x512xf32>
 
 // -----
@@ -180,11 +180,11 @@
 // CHECK: %[[ARG1_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG1]])
 // CHECK-SAME: {permutation = dense<[1, 2, 0]> : tensor<3xi64>}
 // CHECK-SAME: (tensor<309x4x36xf32>) -> tensor<4x36x309xf32>
-// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0_RESHAPED_TR]])
+// CHECK: %[[ARG0_RESHAPED:.+]] = mhlo.reshape %[[ARG0_RESHAPED_TR]]
 // CHECK-SAME: (tensor<4x64x155x36xf32>) -> tensor<4x9920x36xf32>
 // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED_TR]])
 // CHECK-SAME: (tensor<4x9920x36xf32>, tensor<4x36x309xf32>) -> tensor<4x9920x309xf32>
-// CHECK: "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<4x9920x309xf32>) -> tensor<4x64x155x309xf32>
+// CHECK: mhlo.reshape %[[DOT_RESULT]] : (tensor<4x9920x309xf32>) -> tensor<4x64x155x309xf32>
 
 // -----
 
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/transformation_pipeline.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/transformation_pipeline.mlir
index 2077472..aa248ea 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/transformation_pipeline.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/transformation_pipeline.mlir
@@ -93,8 +93,8 @@
 // CHECK-NEXT:     %0 = linalg.init_tensor [4] : tensor<4xf32>
 // CHECK-NEXT:     %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32>
 // CHECK-NEXT:     %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x8xf32>) outs(%1 : tensor<4xf32>) {
-// CHECK-NEXT:     ^bb0(%arg1: f32, %arg2: f32):
-// CHECK-NEXT:       %3 = arith.addf %arg1, %arg2 : f32
+// CHECK-NEXT:     ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32):
+// CHECK-NEXT:       %3 = arith.addf %[[ARG2]], %[[ARG1]] : f32
 // CHECK-NEXT:       linalg.yield %3 : f32
 // CHECK-NEXT:     } -> tensor<4xf32>
 // CHECK-NEXT:     return %2 : tensor<4xf32>
diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py
index 1e7d307..bde6ae9 100644
--- a/runtime/bindings/python/tests/vm_test.py
+++ b/runtime/bindings/python/tests/vm_test.py
@@ -46,7 +46,7 @@
   binary = iree.compiler.compile_str(
       """
       func.func @dynamic_abs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
-        %0 = math.abs %arg0 : tensor<?x?xf32>
+        %0 = math.absf %arg0 : tensor<?x?xf32>
         return %0 : tensor<?x?xf32>
       }
       """,
diff --git a/runtime/src/iree/hal/cts/testdata/command_buffer_dispatch_test.mlir b/runtime/src/iree/hal/cts/testdata/command_buffer_dispatch_test.mlir
index b19795c..529fe43 100644
--- a/runtime/src/iree/hal/cts/testdata/command_buffer_dispatch_test.mlir
+++ b/runtime/src/iree/hal/cts/testdata/command_buffer_dispatch_test.mlir
@@ -1,7 +1,7 @@
 // Bootstrapped from this source IR:
 //
 // func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
-//   %result = math.abs %input : tensor<f32>
+//   %result = math.absf %input : tensor<f32>
 //   return %result : tensor<f32>
 // }
 
@@ -29,7 +29,7 @@
       %3 = linalg.init_tensor [] : tensor<f32>
       %4 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2 : tensor<f32>) outs(%3 : tensor<f32>) {
       ^bb0(%arg0: f32, %arg1: f32):
-        %5 = math.abs %arg0 : f32
+        %5 = math.absf %arg0 : f32
         linalg.yield %5 : f32
       } -> tensor<f32>
       flow.dispatch.tensor.store %4, %1, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
diff --git a/runtime/src/iree/hal/cts/testdata/executable_cache_test.mlir b/runtime/src/iree/hal/cts/testdata/executable_cache_test.mlir
index b19795c..529fe43 100644
--- a/runtime/src/iree/hal/cts/testdata/executable_cache_test.mlir
+++ b/runtime/src/iree/hal/cts/testdata/executable_cache_test.mlir
@@ -1,7 +1,7 @@
 // Bootstrapped from this source IR:
 //
 // func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
-//   %result = math.abs %input : tensor<f32>
+//   %result = math.absf %input : tensor<f32>
 //   return %result : tensor<f32>
 // }
 
@@ -29,7 +29,7 @@
       %3 = linalg.init_tensor [] : tensor<f32>
       %4 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%2 : tensor<f32>) outs(%3 : tensor<f32>) {
       ^bb0(%arg0: f32, %arg1: f32):
-        %5 = math.abs %arg0 : f32
+        %5 = math.absf %arg0 : f32
         linalg.yield %5 : f32
       } -> tensor<f32>
       flow.dispatch.tensor.store %4, %1, offsets = [], sizes = [], strides = [] : tensor<f32> -> !flow.dispatch.tensor<writeonly:f32>
diff --git a/samples/models/simple_abs.mlir b/samples/models/simple_abs.mlir
index 3f35d6a..9a6e516 100644
--- a/samples/models/simple_abs.mlir
+++ b/samples/models/simple_abs.mlir
@@ -1,4 +1,4 @@
 func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
-  %result = math.abs %input : tensor<f32>
+  %result = math.absf %input : tensor<f32>
   return %result : tensor<f32>
 }
diff --git a/tests/compiler_driver/executable_benchmarks.mlir b/tests/compiler_driver/executable_benchmarks.mlir
index 2cf1b9c..b1ee0ec 100644
--- a/tests/compiler_driver/executable_benchmarks.mlir
+++ b/tests/compiler_driver/executable_benchmarks.mlir
@@ -10,7 +10,7 @@
 // at files and that's harder cross-platform).
 
 func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
-  %result = math.abs %input : tensor<f32>
+  %result = math.absf %input : tensor<f32>
   return %result : tensor<f32>
 }
 
diff --git a/third_party/llvm-project b/third_party/llvm-project
index bd81e8a..8741b1a 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit bd81e8a5e8220a3cda4260e2ddffbd2f88a7f990
+Subproject commit 8741b1a9c42b72e7758c272c94cde0e6f72c7eb3
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index b5b1ad8..c7e9ad2 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit b5b1ad8f1aa11023b4f32d0d744490364040ad68
+Subproject commit c7e9ad27d419231db90e4ef9bbbe989ccf5c5538
diff --git a/tools/BUILD b/tools/BUILD
index e224faf..32b7a77 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -206,6 +206,7 @@
     name = "iree-tblgen",
     srcs = [
         "//compiler/src/iree/compiler/Dialect/VM/Tools:GenSrcs",
+        "@llvm-project//mlir:tools/mlir-tblgen/mlir-tblgen.cpp",
     ],
     tags = ["hostonly"],
     deps = [
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 11a5f4b..5a8e12d 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -184,6 +184,7 @@
       LLVMTableGen
       MLIRSupport
       MLIRTableGen
+      MLIRTblgenLib
       iree::compiler::Utils
     HOSTONLY
   )
diff --git a/tools/test/benchmark_flags.txt b/tools/test/benchmark_flags.txt
index cd57590..c624c96 100644
--- a/tools/test/benchmark_flags.txt
+++ b/tools/test/benchmark_flags.txt
@@ -13,7 +13,7 @@
   // LIST-BENCHMARKS: BM_foo2
   func.func @foo2() -> tensor<4xf32> {
     %input = util.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
-    %result = math.abs %input : tensor<4xf32>
+    %result = math.absf %input : tensor<4xf32>
     return %result : tensor<4xf32>
   }
 }
diff --git a/tools/test/iree-benchmark-module.mlir b/tools/test/iree-benchmark-module.mlir
index 4b423d3..67988f4 100644
--- a/tools/test/iree-benchmark-module.mlir
+++ b/tools/test/iree-benchmark-module.mlir
@@ -4,6 +4,6 @@
 
 // CHECK-LABEL: BM_abs
 func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
-  %result = math.abs %input : tensor<f32>
+  %result = math.absf %input : tensor<f32>
   return %result : tensor<f32>
 }
diff --git a/tools/test/iree-run-mlir.mlir b/tools/test/iree-run-mlir.mlir
index a82ab74..692815a 100644
--- a/tools/test/iree-run-mlir.mlir
+++ b/tools/test/iree-run-mlir.mlir
@@ -4,7 +4,7 @@
 
 // CHECK-LABEL: EXEC @abs
 func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
-  %result = math.abs %input : tensor<f32>
+  %result = math.absf %input : tensor<f32>
   return %result : tensor<f32>
 }
 // CHECK: f32=2
diff --git a/tools/test/iree-run-module.mlir b/tools/test/iree-run-module.mlir
index fb1b2b7..f72e794 100644
--- a/tools/test/iree-run-module.mlir
+++ b/tools/test/iree-run-module.mlir
@@ -4,7 +4,7 @@
 
 // CHECK-LABEL: EXEC @abs
 func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
-  %result = math.abs %input : tensor<f32>
+  %result = math.absf %input : tensor<f32>
   return %result : tensor<f32>
 }
 // CHECK: f32=2
diff --git a/tools/test/multiple_exported_functions.mlir b/tools/test/multiple_exported_functions.mlir
index 2250867..dec7cec 100644
--- a/tools/test/multiple_exported_functions.mlir
+++ b/tools/test/multiple_exported_functions.mlir
@@ -9,7 +9,7 @@
   }
   func.func @foo2() -> tensor<4xf32> {
     %input = util.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>
-    %result = math.abs %input : tensor<4xf32>
+    %result = math.absf %input : tensor<4xf32>
     return %result : tensor<4xf32>
   }
 }