Integrate llvm/llvm-project@be9c91843bab (#13296)

* Updated submodule to llvm/llvm-project@be9c91843bab
* Updated submodule to tensorflow/mlir-hlo@4d28523
* Updated codebase to use explicit cast after llvm/llvm-project@6089d61
* Updated codebase constant op builders after llvm/llvm-project@00e3566
* Removed a test that represented an expected missed vectorization path

---------

Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index 49f0c6a..faa275f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -505,7 +505,7 @@
       if (!attr.isSplat()) continue;
       auto type = attr.getType().dyn_cast<RankedTensorType>();
       if (!type) continue;
-      Attribute scalarAttr = attr.getValues<Attribute>()[0];
+      TypedAttr scalarAttr = attr.getValues<TypedAttr>()[0];
 
       modifiedOutput = true;
       Value emptyTensor = rewriter.create<tensor::EmptyOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp
index fb27df4..fc9542e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp
@@ -187,9 +187,10 @@
   // Limit promotion to matmul and batch matmul, there may be generic
   // ops with more batch dimensions we didn't distribute and therefore
   // cannot find a higher bound.
-  return success(linalg::isaContractionOpInterface(op) &&
-                 linalgOp.getNumParallelLoops() >= 2 &&
-                 linalgOp.getNumParallelLoops() <= 3);
+  return success(
+      linalg::isaContractionOpInterface(cast<linalg::LinalgOp>(op)) &&
+      linalgOp.getNumParallelLoops() >= 2 &&
+      linalgOp.getNumParallelLoops() <= 3);
 }
 
 }  // namespace
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index ac7b010..fd12e4b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -650,7 +650,7 @@
     return forallOp->emitError("mapping must be present");
   SmallVector<Attribute> blockMapping =
       llvm::to_vector(forallOp.getMapping()->getValue());
-  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
+  if (llvm::any_of(blockMapping, [](Attribute map) {
         return !map.isa<gpu::GPUBlockMappingAttr>();
       })) {
     return forallOp->emitError("mapping must be #gpu.block<x/y/z/>");
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 8266932..0d942a1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -137,8 +137,8 @@
     auto newAttrType = RankedTensorType::get(attrType.getShape(),
                                              legalizedElementType.value());
     auto newAttr = DenseElementsAttr::get(newAttrType, legalizedValues);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constantOp, newAttr,
-                                                   newAttrType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constantOp, newAttrType,
+                                                   newAttr);
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index bc9e88a..524f13f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -472,51 +472,6 @@
 
 // -----
 
-func.func @fill_matmul_exp() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %c0 = arith.constant 0 : index
-  %c33 = arith.constant 33 : index
-  %c49 = arith.constant 49 : index
-  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(32) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<33x16xf32>>
-  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(32) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<16x49xf32>>
-  %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(32) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<33x49xf32>>
-  %workgroup_id_x = hal.interface.workgroup.id[0] : index
-  %workgroup_count_x = hal.interface.workgroup.count[0] : index
-  %workgroup_id_y = hal.interface.workgroup.id[1] : index
-  %workgroup_count_y = hal.interface.workgroup.count[1] : index
-  %3 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_y]
-  %4 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_y]
-  scf.for %arg0 = %3 to %c33 step %4 {
-    %5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]
-    %6 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_count_x]
-    scf.for %arg1 = %5 to %c49 step %6 {
-      %7 = affine.min affine_map<(d0) -> (16, -d0 + 33)>(%arg0)
-      %8 = affine.min affine_map<(d0) -> (16, -d0 + 49)>(%arg1)
-      %9 = tensor.empty(%7, %8) : tensor<?x?xf32>
-      %10 = affine.min affine_map<(d0) -> (-d0 + 33, 16)>(%arg0)
-      %11 = flow.dispatch.tensor.load %0, offsets = [%arg0, 0], sizes = [%10, 16], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<33x16xf32>> -> tensor<?x16xf32>
-      %12 = affine.min affine_map<(d0) -> (-d0 + 49, 16)>(%arg1)
-      %13 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [16, %12], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<16x49xf32>> -> tensor<16x?xf32>
-      %14 = tensor.empty(%10, %12) : tensor<?x?xf32>
-      %15 = linalg.fill ins(%cst : f32) outs(%14 : tensor<?x?xf32>) -> tensor<?x?xf32>
-      %16 = linalg.matmul ins(%11, %13 : tensor<?x16xf32>, tensor<16x?xf32>) outs(%15 : tensor<?x?xf32>) -> tensor<?x?xf32>
-      %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%16 : tensor<?x?xf32>) outs(%9 : tensor<?x?xf32>) {
-      ^bb0(%arg2: f32, %arg3: f32):
-        %18 = math.exp %arg2 : f32
-        linalg.yield %18 : f32
-      } -> tensor<?x?xf32>
-      flow.dispatch.tensor.store %17, %2, offsets = [%arg0, %arg1], sizes = [%7, %8], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<33x49xf32>>
-    }
-  }
-  return
-}
-// CHECK-LABEL: func.func @fill_matmul_exp()
-//       CHECK:   %[[MATMUL:.+]] = linalg.matmul
-//       CHECK:   linalg.generic
-//  CHECK-SAME:       outs(%[[MATMUL]]
-
-// -----
-
 func.func @cumsum__2x2x2x2x2x2x2() {
   %cst = arith.constant dense<0.000000e+00> : tensor<2x2x2x2x2x2x2xf32>
   %c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 70f949d..099ba22 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -760,10 +760,11 @@
 
     Type wideType = rewriter.getIntegerType(64);
     // Shift amount necessary to extract the high bits from widened result.
-    Attribute shiftValAttr = rewriter.getI64IntegerAttr(32);
+    TypedAttr shiftValAttr = rewriter.getI64IntegerAttr(32);
     if (auto vecTy = resultType.dyn_cast<VectorType>()) {
       wideType = VectorType::get(vecTy.getShape(), wideType);
-      shiftValAttr = SplatElementsAttr::get(wideType, shiftValAttr);
+      shiftValAttr =
+          SplatElementsAttr::get(cast<ShapedType>(wideType), shiftValAttr);
     }
     Value shiftVal = rewriter.create<arith::ConstantOp>(loc, shiftValAttr);
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index 9caf23a..2282595 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -839,12 +839,13 @@
     parallelTileSizes.push_back(sz);
   }
   SmallVector<int64_t> reductionTileSizes;
-  splitParallelAndReductionTiles(op.getOperation(), parallelTileSizes,
-                                 reductionTileSizes);
-
-  setVectorSizesForDynamicShapes(op.getOperation(), vecPreProcStrategy,
+  splitParallelAndReductionTiles(cast<linalg::LinalgOp>(op.getOperation()),
                                  parallelTileSizes, reductionTileSizes);
 
+  setVectorSizesForDynamicShapes(cast<linalg::LinalgOp>(op.getOperation()),
+                                 vecPreProcStrategy, parallelTileSizes,
+                                 reductionTileSizes);
+
   TileSizesListType newTileSizes;
   // Copy all the tile size levels except the workgroup one which will be split
   // into parallel and reduction.
@@ -880,8 +881,8 @@
       getMaxVectorTileSize(0, K, workgroupTileSizes.back(), vectorSize));
 
   SmallVector<int64_t> reductionTileSizes;
-  splitParallelAndReductionTiles(op.getOperation(), parallelTileSizes,
-                                 reductionTileSizes);
+  splitParallelAndReductionTiles(cast<linalg::LinalgOp>(op.getOperation()),
+                                 parallelTileSizes, reductionTileSizes);
 
   TileSizesListType tileSizes;
   tileSizes.emplace_back(flowTileSizes.begin(), flowTileSizes.end());
@@ -1092,8 +1093,8 @@
 
   SmallVector<int64_t> parallelTileSizes = getL1TileSizes();
   SmallVector<int64_t> reductionTileSizes;
-  splitParallelAndReductionTiles(mmt4dOp.getOperation(), parallelTileSizes,
-                                 reductionTileSizes);
+  splitParallelAndReductionTiles(cast<linalg::LinalgOp>(mmt4dOp.getOperation()),
+                                 parallelTileSizes, reductionTileSizes);
 
   TileSizesListType tileSizes = {getWorkgroupTileSizes(), parallelTileSizes,
                                  reductionTileSizes};
@@ -1665,8 +1666,9 @@
 
 static LogicalResult setConvNhwcRootConfigImpl(func::FuncOp entryPointFn,
                                                linalg::LinalgOp convOp) {
-  int64_t vectorSize =
-      getVectorSize(entryPointFn, convOp.getDpsInitOperand(0)->get().getType());
+  int64_t vectorSize = getVectorSize(
+      entryPointFn,
+      cast<ShapedType>(convOp.getDpsInitOperand(0)->get().getType()));
   SmallVector<int64_t> targetTileSizes =
       getConvWorkgroupSizes(entryPointFn, convOp, vectorSize);
   return setConvRootConfig(entryPointFn, convOp, targetTileSizes, vectorSize);
@@ -1706,8 +1708,9 @@
 /// operations.
 static LogicalResult setConvNchwRootConfigImpl(func::FuncOp entryPointFn,
                                                linalg::LinalgOp convOp) {
-  int64_t vectorSize =
-      getVectorSize(entryPointFn, convOp.getDpsInitOperand(0)->get().getType());
+  int64_t vectorSize = getVectorSize(
+      entryPointFn,
+      cast<ShapedType>(convOp.getDpsInitOperand(0)->get().getType()));
   SmallVector<int64_t> targetTileSizes = {1, vectorSize * 2, 1, 8, 8, 1, 1};
   return setConvRootConfig(entryPointFn, convOp, targetTileSizes, vectorSize);
 }
@@ -1731,8 +1734,8 @@
 /// operations.
 static LogicalResult setRootConfig(func::FuncOp entryPointFn,
                                    linalg::DepthwiseConv2DNhwcHwcOp convOp) {
-  int64_t vectorSize =
-      getVectorSize(entryPointFn, convOp.getResult(0).getType());
+  int64_t vectorSize = getVectorSize(
+      entryPointFn, cast<ShapedType>(convOp.getResult(0).getType()));
   SmallVector<int64_t> targetTileSizes =
       getConvWorkgroupSizes(entryPointFn, convOp, vectorSize);
   return setConvRootConfig(entryPointFn, convOp, targetTileSizes, vectorSize);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
index 30e5a63..7e98448 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
@@ -160,7 +160,7 @@
   // If the pack op didn't have a padding_value attribute, default to 0.
   if (!paddingVal) {
     paddingVal =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(i64), i64);
+        rewriter.create<arith::ConstantOp>(loc, i64, rewriter.getZeroAttr(i64));
   }
   int paddingValBitWidth = paddingVal.getType().getIntOrFloatBitWidth();
   // Non-integer element types get bitcast to integer of same bit width.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp
index 1f9e161..76151fb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp
@@ -124,8 +124,8 @@
 
   // 2) Apply splitReduction on the single vector-length array.
   // splitReduction already replaces the op.
-  FailureOr<linalg::SplitReductionResult> splitRes =
-      splitReduction(rewriter, tileResFirst->tiledOps.back(), fn);
+  FailureOr<linalg::SplitReductionResult> splitRes = splitReduction(
+      rewriter, cast<linalg::LinalgOp>(tileResFirst->tiledOps.back()), fn);
   if (failed(splitRes)) {
     LLVM_DEBUG(llvm::dbgs() << "failed on step 2 (SplitReduction)\n");
     return success();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
index e98be0c..e4e40c0 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuse.cpp
@@ -77,7 +77,7 @@
   SmallVector<OpResult> yieldedValuesToOrigValues;
   SmallVector<Operation *> tiledOps;
   FailureOr<scf::SCFTilingResult> tilingResult =
-      scf::tileUsingSCFForOp(rewriter, rootOp, options);
+      scf::tileUsingSCFForOp(rewriter, cast<TilingInterface>(rootOp), options);
   if (failed(tilingResult)) {
     return failure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp
index 634559d..6d11daa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp
@@ -39,7 +39,7 @@
   auto funcOp = getOperation();
 
   // Per-function lowering pipeline.
-  auto vectorTransposeLowering = vector::VectorTransposeLowering::Shuffle;
+  auto vectorTransposeLowering = vector::VectorTransposeLowering::Shuffle1D;
   auto vectorMultiReductionLowering =
       vector::VectorMultiReductionLowering::InnerReduction;
   auto vectorContractLowering = vector::VectorContractLowering::OuterProduct;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
index c14f97c..4c80a35 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
@@ -941,7 +941,7 @@
     VectorType flatAccVectorType =
         VectorType::get({accType.getNumElements()}, accType.getElementType());
     ;
-    Attribute resultInitializer;
+    TypedAttr resultInitializer;
     if (accElemType.isSignlessInteger()) {
       resultInitializer = DenseIntElementsAttr::get(flatAccVectorType, 0);
     } else if (accElemType.isF32()) {
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
index cfa6c97..0306b85 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AMDConfig.cpp
@@ -89,7 +89,8 @@
     const int multipler = 32 / bitwidth;
     bool hasPaddedInput = convOp.image().getDefiningOp<tensor::PadOp>();
     const int bestTilingFactor = (hasPaddedInput ? 16 : 32) * multipler;
-    return setConvOpConfig(rootOp, subgroupSize, bestTilingFactor);
+    return setConvOpConfig(cast<linalg::LinalgOp>(rootOp), subgroupSize,
+                           bestTilingFactor);
   }
 
   return failure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
index 6c1294c..0ef67bc 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AdrenoConfig.cpp
@@ -58,7 +58,8 @@
     linalg::detail::ConvolutionDimensions convDims;
     linalg::detail::isConvolutionInterfaceImpl(rootOp, &convDims);
     const int bestTilingFactor = (convDims.depth.empty() ? 32 : 16) * multipler;
-    return setConvOpConfig(rootOp, subgroupSize, bestTilingFactor);
+    return setConvOpConfig(cast<linalg::LinalgOp>(rootOp), subgroupSize,
+                           bestTilingFactor);
   }
 
   return failure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
index 39fe1d3..a63ab56 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/AppleConfig.cpp
@@ -56,7 +56,8 @@
     if (bitwidth > 32) return failure();
     const int multipler = 32 / bitwidth;
     const int bestTilingFactor = 16 * multipler;
-    return setConvOpConfig(rootOp, subgroupSize, bestTilingFactor);
+    return setConvOpConfig(cast<linalg::LinalgOp>(rootOp), subgroupSize,
+                           bestTilingFactor);
   }
 
   return failure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
index 4cdbbe0..dfef1a2 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/MaliConfig.cpp
@@ -60,7 +60,8 @@
     const int multipler = 32 / bitwidth;
     bool hasPaddedInput = convOp.image().getDefiningOp<tensor::PadOp>();
     const int bestTilingFactor = (hasPaddedInput ? 8 : 16) * multipler;
-    return setConvOpConfig(rootOp, subgroupSize, bestTilingFactor);
+    return setConvOpConfig(cast<linalg::LinalgOp>(rootOp), subgroupSize,
+                           bestTilingFactor);
   }
 
   return failure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
index e80ac7e..c0dceac 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndPromote.cpp
@@ -319,7 +319,7 @@
   // If there are no fused elementwise ops, we can avoid promoting C matrix.
   if (linalgOps.size() <= 1) return success();
 
-  linalg::LinalgOp matmulOp = linalgOps.front();
+  auto matmulOp = cast<linalg::LinalgOp>(linalgOps.front());
   auto genericOp = cast<linalg::GenericOp>(*linalgOps.back());
 
   auto matmulType =
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
index b13ce0b..b2b96df 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
@@ -385,7 +385,7 @@
 
 // List of identity elements by operation.
 // https://en.wikipedia.org/wiki/Identity_element
-static Attribute getCombiningKindIdentity(OpBuilder &builder,
+static TypedAttr getCombiningKindIdentity(OpBuilder &builder,
                                           vector::CombiningKind combiningKind,
                                           Type type) {
   switch (combiningKind) {
@@ -419,7 +419,7 @@
       return builder.getFloatAttr(type, negInfApFloat);
     }
   }
-  return Attribute();
+  return TypedAttr();
 }
 
 /// Compute the value on a single thread to get per lane reduction value.
@@ -463,11 +463,11 @@
   } else {
     // In cases where vecSize < unrollCount, we would pad the vector
     // with identity elements until it's total bit size is 32.
-    Attribute identityAttr =
+    TypedAttr identityAttr =
         getCombiningKindIdentity(builder, kind, elementType);
     identityAttr = DenseElementsAttr::get(unrolledLaneValType, identityAttr);
-    Value identity = builder.create<arith::ConstantOp>(loc, identityAttr,
-                                                       unrolledLaneValType);
+    Value identity = builder.create<arith::ConstantOp>(loc, unrolledLaneValType,
+                                                       identityAttr);
     perLaneReduction = builder.create<vector::InsertStridedSliceOp>(
         loc, input, identity, /*offsets=*/ArrayRef<int64_t>{0},
         /*strides=*/ArrayRef<int64_t>{1});
@@ -484,13 +484,13 @@
   if (vectorType) {
     elementType = vectorType.getElementType();
   }
-  Attribute identityAttr = getCombiningKindIdentity(builder, kind, elementType);
+  TypedAttr identityAttr = getCombiningKindIdentity(builder, kind, elementType);
   if (vectorType) {
     identityAttr = DenseElementsAttr::get(vectorType, identityAttr);
   }
   assert(identityAttr && "Unknown identity value for the reduction");
   Value identity =
-      builder.create<arith::ConstantOp>(loc, identityAttr, identityType);
+      builder.create<arith::ConstantOp>(loc, identityType, identityAttr);
   return identity;
 }
 
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index eb01e9f..cae4aa6 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -235,7 +235,7 @@
       }
 
       modified = true;
-      targetGlobal.setInitialValueAttr(value);
+      targetGlobal.setInitialValueAttr(cast<TypedAttr>(value));
     }
 
     // Delete any ops noted for pruning.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
index 3165824..03bb15a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowDialect.cpp
@@ -79,7 +79,7 @@
 Operation *FlowDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                             Type type, Location loc) {
   if (arith::ConstantOp::isBuildableWith(value, type))
-    return builder.create<arith::ConstantOp>(loc, type, value);
+    return builder.create<arith::ConstantOp>(loc, type, cast<TypedAttr>(value));
   return nullptr;
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 0e24d89..0571d12 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -749,19 +749,20 @@
   if (!value) return {};
   if (auto target = operands.getTarget().dyn_cast_or_null<ElementsAttr>()) {
     // Store into the constant target tensor.
-    if (target.getType().getRank() == 0) {
-      return DenseElementsAttr::get(target.getType(), {value});
+    auto targetType = cast<ShapedType>(target.getType());
+    if (targetType.getRank() == 0) {
+      return DenseElementsAttr::get(targetType, {value});
     }
     if (llvm::count(operands.getIndices(), nullptr) == 0) {
       uint64_t offset = getFlattenedIndex(
-          target.getType(),
+          targetType,
           llvm::to_vector<4>(
               llvm::map_range(operands.getIndices(), [](Attribute value) {
                 return value.cast<IntegerAttr>().getValue().getZExtValue();
               })));
       SmallVector<Attribute, 16> newContents(target.getValues<Attribute>());
       newContents[offset] = value;
-      return DenseElementsAttr::get(target.getType(), newContents);
+      return DenseElementsAttr::get(targetType, newContents);
     }
   }
   return {};
@@ -834,7 +835,8 @@
 // Slices tensor from start to (start + length) exclusively at dim.
 static ElementsAttr tensorSlice(ElementsAttr tensor, uint64_t dim,
                                 uint64_t start, uint64_t length) {
-  auto shape = llvm::to_vector<4>(tensor.getType().getShape());
+  auto tensorType = cast<ShapedType>(tensor.getType());
+  auto shape = llvm::to_vector<4>(tensorType.getShape());
   if (length == shape[dim]) {
     // No need to slice.
     return tensor;
@@ -851,7 +853,7 @@
                       /*init=*/1, /*op=*/std::multiplies<int64_t>());
   int64_t num = length * step / shape[dim];
   for (int64_t offset = step / shape[dim] * start,
-               numElements = tensor.getType().getNumElements();
+               numElements = tensorType.getNumElements();
        offset < numElements; offset += step) {
     newContents.append(valuesBegin + offset, valuesBegin + offset + num);
   }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp
index 27d3f53..969807d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DetachElementwiseFromNamedOps.cpp
@@ -154,7 +154,7 @@
       Type elementType = resultType.getElementType();
       Value emptyTensorOp = rewriter.create<tensor::EmptyOp>(
           loc, resultType.getShape(), elementType);
-      Attribute constValue;
+      TypedAttr constValue;
       if (elementType.isa<IntegerType>()) {
         constValue = rewriter.getIntegerAttr(
             elementType, attr.template getSplatValue<APInt>());
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp
index b55747d..d5d9200 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DumpDispatchGraph.cpp
@@ -209,9 +209,9 @@
     // Elide "big" elements attributes.
     auto elements = attr.dyn_cast<ElementsAttr>();
     if (elements && elements.getNumElements() > largeAttrLimit) {
-      os << std::string(elements.getType().getRank(), '[') << "..."
-         << std::string(elements.getType().getRank(), ']') << " : "
-         << elements.getType();
+      auto type = cast<ShapedType>(elements.getType());
+      os << std::string(type.getRank(), '[') << "..."
+         << std::string(type.getRank(), ']') << " : " << type;
       return;
     }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp
index d49fd9a..7cc17bb 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp
@@ -20,13 +20,13 @@
 
 /// Returns a zero value attribute based on the `elementType`.
 /// Returns failure, when the type is not handled.
-static FailureOr<Attribute> getZero(OpBuilder &builder, Location loc,
+static FailureOr<TypedAttr> getZero(OpBuilder &builder, Location loc,
                                     Type elementType) {
   if (auto intType = elementType.dyn_cast<IntegerType>()) {
-    return builder.getIntegerAttr(intType, 0);
+    return cast<TypedAttr>(builder.getIntegerAttr(intType, 0));
   }
   if (auto floatType = elementType.dyn_cast<FloatType>()) {
-    return builder.getFloatAttr(floatType, 0.0);
+    return cast<TypedAttr>(builder.getFloatAttr(floatType, 0.0));
   }
   return failure();
 }
@@ -48,7 +48,7 @@
     RankedTensorType resultType = emptyTensorOp.getType();
     Type elementType = resultType.getElementType();
     Location loc = emptyTensorOp.getLoc();
-    FailureOr<Attribute> zero = getZero(rewriter, loc, elementType);
+    FailureOr<TypedAttr> zero = getZero(rewriter, loc, elementType);
     if (failed(zero)) {
       return rewriter.notifyMatchFailure(
           emptyTensorOp, "unable to get zero value for element type");
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
index a8426e1..463d87a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SetEncoding.cpp
@@ -47,17 +47,17 @@
 /// Returns a constant 0 of type `elementType`.
 static FailureOr<Value> getZero(OpBuilder &builder, Location loc,
                                 Type elementType) {
-  Attribute zeroVal =
-      TypeSwitch<Type, Attribute>(elementType)
+  TypedAttr zeroVal =
+      TypeSwitch<Type, TypedAttr>(elementType)
           .Case<FloatType>([&](FloatType floatType) -> Attribute {
-            return builder.getFloatAttr(floatType, 0);
+            return cast<TypedAttr>(builder.getFloatAttr(floatType, 0));
           })
           .Case<IntegerType>([&](IntegerType intType) -> Attribute {
-            return builder.getIntegerAttr(intType, 0);
+            return cast<TypedAttr>(builder.getIntegerAttr(intType, 0));
           })
           .Default([](Type type) { return nullptr; });
   if (!zeroVal) return failure();
-  return builder.create<arith::ConstantOp>(loc, zeroVal, elementType)
+  return builder.create<arith::ConstantOp>(loc, elementType, zeroVal)
       .getResult();
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
index 63c742a..2cdcec2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
@@ -36,9 +36,9 @@
         ArrayRef<Value>{
             adaptor.getAllocator(),
             rewriter.createOrFold<IREE::VM::ConstI32Op>(
-                op.getLoc(), op.getMemoryTypesAttr()),
+                op.getLoc(), op.getMemoryTypesAttr().getInt()),
             rewriter.createOrFold<IREE::VM::ConstI32Op>(
-                op.getLoc(), op.getBufferUsageAttr()),
+                op.getLoc(), op.getBufferUsageAttr().getInt()),
             adaptor.getSource(),
             castToImportType(adaptor.getOffset(), rewriter.getI64Type(),
                              rewriter),
@@ -76,9 +76,9 @@
             adaptor.getAllocator(),
             rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(), /*try=*/1),
             rewriter.createOrFold<IREE::VM::ConstI32Op>(
-                op.getLoc(), op.getMemoryTypesAttr()),
+                op.getLoc(), op.getMemoryTypesAttr().getInt()),
             rewriter.createOrFold<IREE::VM::ConstI32Op>(
-                op.getLoc(), op.getBufferUsageAttr()),
+                op.getLoc(), op.getBufferUsageAttr().getInt()),
             adaptor.getSource(),
             castToImportType(adaptor.getOffset(), rewriter.getI64Type(),
                              rewriter),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
index 5977ad0..8aafdfd 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
@@ -37,7 +37,7 @@
     auto queryOp = rewriter.create<IREE::HAL::DeviceQueryOp>(
         op.getLoc(), rewriter.getI1Type(), rewriter.getI64Type(),
         adaptor.getDevice(), op.getCategoryAttr(), op.getKeyAttr(),
-        Attribute{});
+        TypedAttr{});
     auto ok = queryOp.getOk().cast<Value>();
     auto value = queryOp.getValue();
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index fd295b4..f36173a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -996,7 +996,8 @@
   ArrayRef<Type> argTypes = getArgumentTypes();
   ArrayRef<Type> resultTypes = getResultTypes();
   mlir::function_interface_impl::printFunctionSignature(
-      p, op, argTypes, /*isVariadic=*/false, resultTypes);
+      p, cast<FunctionOpInterface>(op), argTypes, /*isVariadic=*/false,
+      resultTypes);
   p << " as ";
   if (resultTypes.size() != 1) p << '(';
   llvm::interleaveComma(getKeys().getValue(), p,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index 2f8acd7..51bf863 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -769,10 +769,10 @@
     // Empty returns false (no conditions match).
     return builder.create<arith::ConstantIntOp>(loc, /*value=*/0, /*width=*/1);
   }
-  auto conditionValues =
-      llvm::map_range(getConditions(), [&](MatchAttrInterface attr) {
-        return attr.buildConditionExpression(loc, value, builder);
-      });
+  auto conditionValues = llvm::map_range(getConditions(), [&](Attribute attr) {
+    return attr.cast<MatchAttrInterface>().buildConditionExpression(loc, value,
+                                                                    builder);
+  });
   Value resultValue;
   for (auto conditionValue : conditionValues) {
     resultValue = resultValue ? builder.createOrFold<arith::OrIOp>(
@@ -798,10 +798,10 @@
     // Empty returns true (all 0 conditions match).
     return builder.create<arith::ConstantIntOp>(loc, /*value=*/1, /*width=*/1);
   }
-  auto conditionValues =
-      llvm::map_range(getConditions(), [&](MatchAttrInterface attr) {
-        return attr.buildConditionExpression(loc, value, builder);
-      });
+  auto conditionValues = llvm::map_range(getConditions(), [&](Attribute attr) {
+    return attr.cast<MatchAttrInterface>().buildConditionExpression(loc, value,
+                                                                    builder);
+  });
   Value resultValue;
   for (auto conditionValue : conditionValues) {
     resultValue = resultValue ? builder.createOrFold<arith::AndIOp>(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 2b26df2..1e274ea 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -49,7 +49,7 @@
   // Analyzed minimum binding sizes.
   SmallVector<Binding> bindings;
   // Push constant operands that are known constant. May be null if dynamic.
-  SmallVector<Attribute> uniformOperands;
+  SmallVector<TypedAttr> uniformOperands;
 };
 
 using DispatchParamsMap =
@@ -97,9 +97,9 @@
                             resourceLengthInt.getSExtValue()});
       }
 
-      SmallVector<Attribute> uniformOperands;
+      SmallVector<TypedAttr> uniformOperands;
       for (auto operand : dispatchOp.getUniformOperands()) {
-        Attribute uniformOperand;
+        TypedAttr uniformOperand;
         if (!matchPattern(operand, m_Constant(&uniformOperand))) {
           // Non-constant uniform operand; skip the dispatch.
           // TODO(benvanik): extract information from the executable annotations
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index e85612d..928eb7e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -434,8 +434,8 @@
 
     uint64_t dimIdx = sizeOp.getDimension().getZExtValue();
     auto dimAttr = workgroupSizeAttr[dimIdx];
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(sizeOp, dimAttr,
-                                                   rewriter.getIndexType());
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        sizeOp, rewriter.getIndexType(), cast<TypedAttr>(dimAttr));
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
index da53892..559a8a2 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamDialect.cpp
@@ -115,7 +115,7 @@
     return builder.create<mlir::func::ConstantOp>(
         loc, type, value.cast<FlatSymbolRefAttr>());
   } else if (arith::ConstantOp::isBuildableWith(value, type)) {
-    return builder.create<arith::ConstantOp>(loc, type, value);
+    return builder.create<arith::ConstantOp>(loc, type, cast<TypedAttr>(value));
   } else if (value.isa<IREE::Stream::TimepointAttr>()) {
     return builder.create<IREE::Stream::TimepointImmediateOp>(loc);
   }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index ef44afc..20f9370 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1037,7 +1037,7 @@
 // to be emulated - if we can avoid that here that's a big win. Some HAL
 // implementations (such as Metal) only support 8-bit fills and anything larger
 // needs to be implemented as well.
-static Attribute tryNarrowPatternBits(Attribute patternAttr) {
+static TypedAttr tryNarrowPatternBits(TypedAttr patternAttr) {
   // Get the old pattern bitcast to an APInt. Splats are bitwise operations
   // and we don't care what the value originally was.
   APInt oldPattern;
@@ -1066,7 +1066,7 @@
   LogicalResult matchAndRewrite(TensorSplatOp splatOp,
                                 PatternRewriter &rewriter) const override {
     // Try narrowing the pattern.
-    Attribute oldPatternAttr;
+    TypedAttr oldPatternAttr;
     if (!matchPattern(splatOp.getValue(), m_Constant(&oldPatternAttr))) {
       return failure();
     }
@@ -1159,7 +1159,7 @@
   LogicalResult matchAndRewrite(TensorFillOp fillOp,
                                 PatternRewriter &rewriter) const override {
     // Try narrowing the pattern.
-    Attribute oldPatternAttr;
+    TypedAttr oldPatternAttr;
     if (!matchPattern(fillOp.getValue(), m_Constant(&oldPatternAttr))) {
       return failure();
     }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
index 0b72cfa..e3c8b0f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
@@ -95,7 +95,7 @@
     SmallVector<TypedAttr> values;
     for (auto dispatchOp : dispatchOps) {
       auto operand = dispatchOp.getUniformOperands()[idx];
-      Attribute constantValue;
+      TypedAttr constantValue;
       matchPattern(operand, m_Constant(&constantValue));
       values.push_back(constantValue);
       set.locs.insert(operand.getLoc());
@@ -114,7 +114,7 @@
 // Builds a tensor<SITExOPERANDxTYPE> constant attribute.
 // This should probably be vector<> but that dialect has some issues with
 // expressing basic multi-dimension loads :/
-static Attribute buildConstantSetAttr(ConstantSet &set, OpBuilder &builder) {
+static TypedAttr buildConstantSetAttr(ConstantSet &set, OpBuilder &builder) {
   // TODO(benvanik): better definition of variable-width integers across HAL.
   // HACK: we can't handle index types in a few of the codegen backends (vulkan
   // at least); we convert index -> i32 here but we should probably have a
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
index 79327cf..ca73e02 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.cpp
@@ -472,8 +472,9 @@
           SerializableDenseElementsAttrModel, DenseIntOrFPElementsAttr> {
   int64_t getStorageSize(Attribute baseAttr) const {
     auto attr = baseAttr.cast<ElementsAttr>();
-    return attr.getNumElements() * IREE::Util::getRoundedElementByteWidth(
-                                       attr.getType().getElementType());
+    return attr.getNumElements() *
+           IREE::Util::getRoundedElementByteWidth(
+               cast<ShapedType>(attr.getType()).getElementType());
   }
 
   LogicalResult serializeToVector(Attribute baseAttr,
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp
index 7c38095..e890cce 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilDialect.cpp
@@ -96,7 +96,7 @@
 Operation *UtilDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                             Type type, Location loc) {
   if (arith::ConstantOp::isBuildableWith(value, type)) {
-    return builder.create<arith::ConstantOp>(loc, value, type);
+    return builder.create<arith::ConstantOp>(loc, type, cast<TypedAttr>(value));
   }
   return nullptr;
 }
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index e230d6b..3b9d0fd 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -756,7 +756,7 @@
   }];
 
   let verify = [{
-    return IREE::Util::detail::verifyTiedOp($_op);
+    return IREE::Util::detail::verifyTiedOp(cast<TiedOpInterface>($_op));
   }];
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index 2be2f10..27a823a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -138,9 +138,8 @@
     }
     if (constantValue != initialValue) {
       operands.insert(rewriter.create<arith::ConstantOp>(
-          op.getLoc(),
-          rewriter.getIntegerAttr(op.getResult().getType(), constantValue),
-          op.getResult().getType()));
+          op.getLoc(), op.getResult().getType(),
+          rewriter.getIntegerAttr(op.getResult().getType(), constantValue)));
     }
     rewriter.replaceOpWithNewOp<OpT>(op, op.getResult().getType(),
                                      operands.takeVector());
@@ -176,7 +175,7 @@
   return makeRangeEnd(
       loc, offset, length,
       builder.create<arith::ConstantOp>(
-          loc, builder.getIntegerAttr(offset.getType(), 1), offset.getType()),
+          loc, offset.getType(), builder.getIntegerAttr(offset.getType(), 1)),
       builder);
 }
 
@@ -224,14 +223,12 @@
     // Min/max with constant ranges. This allows for normal folding to happen
     // downstream of the op.
     auto constantMinOp = rewriter.create<arith::ConstantOp>(
-        op.getLoc(),
-        rewriter.getIntegerAttr(op.getMin().getType(), constantMin),
-        op.getMin().getType());
+        op.getLoc(), op.getMin().getType(),
+        rewriter.getIntegerAttr(op.getMin().getType(), constantMin));
     auto constantMaxOp = rewriter.create<arith::ConstantOp>(
-        op.getLoc(),
+        op.getLoc(), op.getMax().getType(),
         rewriter.getIntegerAttr(op.getMax().getType(),
-                                constantMax - constantMin + 1),
-        op.getMax().getType());
+                                constantMax - constantMin + 1));
     min = min ? rewriter.create<arith::MinUIOp>(op.getLoc(), min, constantMinOp)
                     .getResult()
               : constantMinOp.getResult();
@@ -260,8 +257,8 @@
       minValue = rewriter.create<arith::MinUIOp>(loc, op.getOffsets().front(),
                                                  op.getOffsets().back());
       auto one = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(op.getMin().getType(), 1),
-          op.getMin().getType());
+          loc, op.getMin().getType(),
+          rewriter.getIntegerAttr(op.getMin().getType(), 1));
       auto endLhs = makeRangeEnd(loc, op.getOffsets().front(),
                                  op.getLengths().front(), one, rewriter);
       auto endRhs = makeRangeEnd(loc, op.getOffsets().back(),
@@ -416,8 +413,8 @@
   using OpRewritePattern<IREE::Util::UnfoldableConstantOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(UnfoldableConstantOp op,
                                 PatternRewriter &rewriter) const override {
-    auto stdConst =
-        rewriter.create<arith::ConstantOp>(op.getLoc(), op.getValue());
+    auto stdConst = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), cast<TypedAttr>(op.getValue()));
     rewriter.replaceOpWithNewOp<OptimizationBarrierOp>(op,
                                                        stdConst.getResult());
     return success();
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
index 851a750..7dafc4b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
@@ -259,7 +259,7 @@
                                     OpBuilder &builder) {
   if (arith::ConstantOp::isBuildableWith(attr, type)) {
     // Common case fast-path.
-    return builder.create<arith::ConstantOp>(loc, type, attr);
+    return builder.create<arith::ConstantOp>(loc, type, cast<TypedAttr>(attr));
   } else if (mlir::func::ConstantOp::isBuildableWith(attr, type)) {
     return builder.create<mlir::func::ConstantOp>(
         loc, type, attr.cast<FlatSymbolRefAttr>());
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp
index d65feb4..759fbc5 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp
@@ -363,8 +363,9 @@
   // themselves.
   if (arith::ConstantOp::isBuildableWith(constantValue.attr,
                                          constantValue.type)) {
-    op = builder.create<arith::ConstantOp>(
-        constantValue.loc.value(), constantValue.attr, constantValue.type);
+    op = builder.create<arith::ConstantOp>(constantValue.loc.value(),
+                                           constantValue.type,
+                                           cast<TypedAttr>(constantValue.attr));
   }
 
   // Try the attr and type dialects to see if they can materialize.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
index b69aa55..f462bfa 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
@@ -137,7 +137,8 @@
     elementValues.reserve(elementsAttr.getNumElements());
     for (auto intAttr : elementsAttr.getValues<Attribute>()) {
       elementValues.push_back(rewriter.createOrFold<mlir::arith::ConstantOp>(
-          loc, elementsAttr.getType().getElementType(), intAttr));
+          loc, elementsAttr.getType().getElementType(),
+          cast<TypedAttr>(intAttr)));
     }
     return elementValues;
   }
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
index d773b5d..0567a57 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -115,7 +115,7 @@
 class VM_EncTypeOf<string name> : VM_EncEncodeExpr<
     "e.encodeType({0}())", [name]>;
 class VM_EncPrimitiveAttr<string name, int thisBitwidth> : VM_EncEncodeExpr<
-    "e.encodePrimitiveAttr(getOperation()->getAttrOfType<Attribute>(\"" # name # "\"))"> {
+    "e.encodePrimitiveAttr(getOperation()->getAttrOfType<TypedAttr>(\"" # name # "\"))"> {
   int bitwidth = thisBitwidth;
 }
 class VM_EncPrimitiveArrayAttr<string name, int thisBitwidth> : VM_EncEncodeExpr<
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp
index d919419..fa38364 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMDialect.cpp
@@ -261,26 +261,29 @@
 
 Operation *VMDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                           Type type, Location loc) {
-  if (ConstI32Op::isBuildableWith(value, type)) {
-    auto convertedValue = ConstI32Op::convertConstValue(value);
+  auto typedValue = dyn_cast<TypedAttr>(value);
+  if (!typedValue) return nullptr;
+
+  if (ConstI32Op::isBuildableWith(typedValue, type)) {
+    auto convertedValue = ConstI32Op::convertConstValue(typedValue);
     if (convertedValue.cast<IntegerAttr>().getValue() == 0) {
       return builder.create<VM::ConstI32ZeroOp>(loc);
     }
     return builder.create<VM::ConstI32Op>(loc, convertedValue);
-  } else if (ConstI64Op::isBuildableWith(value, type)) {
-    auto convertedValue = ConstI64Op::convertConstValue(value);
+  } else if (ConstI64Op::isBuildableWith(typedValue, type)) {
+    auto convertedValue = ConstI64Op::convertConstValue(typedValue);
     if (convertedValue.cast<IntegerAttr>().getValue() == 0) {
       return builder.create<VM::ConstI64ZeroOp>(loc);
     }
     return builder.create<VM::ConstI64Op>(loc, convertedValue);
-  } else if (ConstF32Op::isBuildableWith(value, type)) {
-    auto convertedValue = ConstF32Op::convertConstValue(value);
+  } else if (ConstF32Op::isBuildableWith(typedValue, type)) {
+    auto convertedValue = ConstF32Op::convertConstValue(typedValue);
     if (convertedValue.cast<FloatAttr>().getValue().isZero()) {
       return builder.create<VM::ConstF32ZeroOp>(loc);
     }
     return builder.create<VM::ConstF32Op>(loc, convertedValue);
-  } else if (ConstF64Op::isBuildableWith(value, type)) {
-    auto convertedValue = ConstF64Op::convertConstValue(value);
+  } else if (ConstF64Op::isBuildableWith(typedValue, type)) {
+    auto convertedValue = ConstF64Op::convertConstValue(typedValue);
     if (convertedValue.cast<FloatAttr>().getValue().isZero()) {
       return builder.create<VM::ConstF64ZeroOp>(loc);
     }
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 7cce235..d288c5a 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -486,7 +486,7 @@
     return DenseElementsAttr::get(operand.getType(), elementResult);
   } else if (auto operand = rawOperand.dyn_cast_or_null<ElementsAttr>()) {
     return operand.cast<DenseIntOrFPElementsAttr>().mapValues(
-        operand.getType().getElementType(),
+        cast<ShapedType>(operand.getType()).getElementType(),
         llvm::function_ref<ElementValueT(const ElementValueT &)>(
             [&](const ElementValueT &value) { return calculate(value); }));
   }
@@ -506,7 +506,7 @@
     return DenseElementsAttr::get(operand.getType(), elementResult);
   } else if (auto operand = rawOperand.dyn_cast_or_null<ElementsAttr>()) {
     return operand.cast<DenseIntOrFPElementsAttr>().mapValues(
-        operand.getType().getElementType(),
+        cast<ShapedType>(operand.getType()).getElementType(),
         llvm::function_ref<APInt(const APFloat &)>([&](const APFloat &value) {
           return calculate(value).bitcastToAPInt();
         }));
@@ -521,7 +521,7 @@
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
               std::function<ElementValueT(ElementValueT, ElementValueT)>>
-static Attribute constFoldBinaryOp(Attribute rawLhs, Attribute rawRhs,
+static TypedAttr constFoldBinaryOp(Attribute rawLhs, Attribute rawRhs,
                                    const CalculationT &calculate) {
   if (auto lhs = rawLhs.dyn_cast_or_null<AttrElementT>()) {
     auto rhs = rawRhs.dyn_cast_or_null<AttrElementT>();
@@ -550,7 +550,7 @@
       ++lhsIt;
       ++rhsIt;
     }
-    return DenseElementsAttr::get(lhs.getType(), resultAttrs);
+    return DenseElementsAttr::get(cast<ShapedType>(lhs.getType()), resultAttrs);
   }
   return {};
 }
@@ -602,7 +602,7 @@
       ++bIt;
       ++cIt;
     }
-    return DenseElementsAttr::get(a.getType(), resultAttrs);
+    return DenseElementsAttr::get(cast<ShapedType>(a.getType()), resultAttrs);
   }
   return {};
 }
@@ -2128,7 +2128,7 @@
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
               std::function<ElementValueT(ElementValueT, ElementValueT)>>
-static Attribute constFoldBinaryCmpFOp(Attribute rawLhs, Attribute rawRhs,
+static TypedAttr constFoldBinaryCmpFOp(Attribute rawLhs, Attribute rawRhs,
                                        const CalculationT &calculate) {
   if (auto lhs = rawLhs.dyn_cast_or_null<AttrElementT>()) {
     auto rhs = rawRhs.dyn_cast_or_null<AttrElementT>();
@@ -2143,8 +2143,9 @@
         lhs.getSplatValue<Attribute>(), rhs.getSplatValue<Attribute>(),
         calculate);
     if (!elementResult) return {};
-    return DenseElementsAttr::get(IntegerType::get(lhs.getContext(), 32),
-                                  elementResult);
+    auto resultType = lhs.getType().clone(
+        std::nullopt, IntegerType::get(lhs.getContext(), 32));
+    return DenseElementsAttr::get(resultType, elementResult);
   } else if (auto lhs = rawLhs.dyn_cast_or_null<ElementsAttr>()) {
     auto rhs = rawRhs.dyn_cast_or_null<ElementsAttr>();
     if (!rhs || lhs.getType() != rhs.getType()) return {};
@@ -2158,8 +2159,9 @@
       ++lhsIt;
       ++rhsIt;
     }
-    return DenseElementsAttr::get(IntegerType::get(lhs.getContext(), 32),
-                                  resultAttrs);
+    auto resultType = lhs.getShapedType().clone(
+        std::nullopt, IntegerType::get(lhs.getContext(), 32));
+    return DenseElementsAttr::get(resultType, resultAttrs);
   }
   return {};
 }
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 5647d62..34d4a4d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -501,7 +501,7 @@
   } else if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
     return intAttr.getType().isInteger(SZ);
   } else if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) {
-    return elementsAttr.getType().getElementType().isInteger(SZ);
+    return elementsAttr.getShapedType().getElementType().isInteger(SZ);
   }
   return false;
 }
@@ -520,14 +520,14 @@
   if (auto floatAttr = value.dyn_cast<FloatAttr>()) {
     elementType = floatAttr.getType();
   } else if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) {
-    elementType = elementsAttr.getType().getElementType();
+    elementType = elementsAttr.getShapedType().getElementType();
   }
   if (!elementType) return false;
   return elementType.getIntOrFloatBitWidth() == SZ;
 }
 
 template <int SZ>
-static Attribute convertConstIntegerValue(TypedAttr value) {
+static TypedAttr convertConstIntegerValue(TypedAttr value) {
   assert(isConstIntegerBuildableWith<SZ>(value, value.getType()));
   Builder builder(value.getContext());
   auto integerType = builder.getIntegerType(SZ);
@@ -552,7 +552,7 @@
     }
   }
   assert(false && "unexpected attribute type");
-  return Attribute();
+  return TypedAttr();
 }
 
 static FloatType getFloatType(int bitwidth, MLIRContext *context) {
@@ -570,7 +570,7 @@
 }
 
 template <int SZ>
-static Attribute convertConstFloatValue(TypedAttr value) {
+static TypedAttr convertConstFloatValue(TypedAttr value) {
   assert(isConstFloatBuildableWith<SZ>(value, value.getType()));
   Builder builder(value.getContext());
   auto floatType = getFloatType(SZ, value.getContext());
@@ -589,7 +589,7 @@
     }
   }
   assert(false && "unexpected attribute type");
-  return Attribute();
+  return TypedAttr();
 }
 
 // static
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
index 8256031..3d6b954 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
@@ -153,7 +153,7 @@
         failed(writeUint16(value.getNumElements()))) {
       return currentOp_->emitOpError() << "integer array size out of bounds";
     }
-    for (auto el : value.getValues<Attribute>()) {
+    for (auto el : value.getValues<TypedAttr>()) {
       if (failed(encodePrimitiveAttr(el))) {
         return currentOp_->emitOpError() << "failed to encode element " << el;
       }
diff --git a/compiler/src/iree/compiler/InputConversion/Common/QuantizedConvToConv.cpp b/compiler/src/iree/compiler/InputConversion/Common/QuantizedConvToConv.cpp
index def3ede..ad2836e 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/QuantizedConvToConv.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/QuantizedConvToConv.cpp
@@ -41,7 +41,7 @@
   Value empty =
       builder.create<tensor::EmptyOp>(ty.getShape(), ty.getElementType(), dyn);
 
-  Attribute attr = builder.getZeroAttr(ty.getElementType());
+  TypedAttr attr = builder.getZeroAttr(ty.getElementType());
   Value cnst = builder.create<arith::ConstantOp>(attr);
   return builder.create<linalg::FillOp>(ValueRange{cnst}, ValueRange{empty})
       .result();
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index 07a5948..974e04e 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -455,7 +455,7 @@
     Value emptyTensorOutputIndices = rewriter.create<mlir::tensor::EmptyOp>(
         loc, mixedSizes, indicesElementType);
     // Initialize indices to 0 and values to negative infinity
-    Attribute negInfAttr;
+    TypedAttr negInfAttr;
     if (auto intType = valueElementType.dyn_cast<IntegerType>()) {
       negInfAttr = rewriter.getIntegerAttr(
           intType, APInt::getSignedMinValue(intType.getWidth()));
@@ -466,7 +466,7 @@
       negInfAttr = rewriter.getFloatAttr(valueElementType, negApFloat);
     }
     Value negInf = rewriter.create<arith::ConstantOp>(loc, negInfAttr);
-    Attribute posInfAttr = rewriter.getIntegerAttr(
+    TypedAttr posInfAttr = rewriter.getIntegerAttr(
         indicesElementType, APInt::getSignedMaxValue(32));
     Value posInf = rewriter.create<arith::ConstantOp>(loc, posInfAttr);
     Value negInfTensor =
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h b/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h
index 43e3592..2a78aa1 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h
@@ -311,7 +311,7 @@
   if (VectorType vecType = t.dyn_cast<VectorType>()) {
     v = SplatElementsAttr::get(vecType, v);
   }
-  return b->create<arith::ConstantOp>(loc, t, v);
+  return b->create<arith::ConstantOp>(loc, t, cast<TypedAttr>(v));
 }
 
 template <typename PredicateType>
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
index 07b470e..b797f98 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp
@@ -1558,8 +1558,8 @@
     }
 
     SmallVector<OpFoldResult, 3> startIndices, sizes;
-    Type originalStartIndexType =
-        dynamicSliceOp.getStartIndices().front().getType();
+    auto originalStartIndexType =
+        dynamicSliceOp.getStartIndices().front().getType().cast<ShapedType>();
     for (const auto& en : llvm::enumerate(
              llvm::zip(adaptor.getStartIndices(),
                        dynamicSliceOp.getSliceSizes().getValues<int64_t>()))) {
@@ -1636,9 +1636,9 @@
       // By mhlo.DynamicUpdateSlice definition:
       //   `start_indices[i] = clamp(start_indices[i],
       //       0, operand.dimension_size[i] - update.dimension_size[i])`
-      Value startIndex =
-          extractIndexFromTensor(rewriter, loc, en.value(),
-                                 op.getStartIndices()[en.index()].getType());
+      Value startIndex = extractIndexFromTensor(
+          rewriter, loc, en.value(),
+          cast<ShapedType>(op.getStartIndices()[en.index()].getType()));
       Value ub = rewriter.create<arith::ConstantIndexOp>(
           loc, operandType.getDimSize(en.index()) -
                    updateType.getDimSize(en.index()));
@@ -1724,7 +1724,7 @@
     for (Value operand : llvm::drop_begin(adaptor.getOperands(), 1)) {
       coercedOperands.push_back(coerceTensorShape(
           rewriter, loc, cast<TypedValue<ShapedType>>(operand),
-          operand0.getType()));
+          cast<ShapedType>(operand0.getType())));
     }
     Value output = rewriter.create<tensor::EmptyOp>(
         loc, tensor::getMixedSizes(rewriter, loc, operand0),
@@ -2208,8 +2208,8 @@
 
     // We have interior padding, which can be lowered to tensor.insert_slice.
     // Start by filling a result-sized tensor with the pad value.
-    auto emptyTensor =
-        getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
+    auto emptyTensor = getEmptyTensorFor(
+        rewriter, loc, cast<ShapedType>(resultType), op, adaptor.getOperands());
     auto fill =
         rewriter.create<linalg::FillOp>(loc, paddingVal, emptyTensor).result();
 
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp
index 589415e..e0d100b 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp
@@ -120,7 +120,7 @@
       if (getRank(input) == maxRank) {
         mappedInputs.push_back(coerceTensorShape(
             rewriter, loc, cast<TypedValue<ShapedType>>(input),
-            emptyTensor.getType()));
+            cast<ShapedType>(emptyTensor.getType())));
         scalarInputs.push_back(nullptr);
       } else {
         scalarInputs.push_back(rewriter.create<tensor::ExtractOp>(loc, input));
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
index a909b21..fb1b8fc 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp
@@ -111,8 +111,9 @@
       initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
 
       SmallVector<Value, 8> dynShape = getReduceOpEmptyTensorDynSizes(
-          rewriter, loc, operand, resultType, reductionDims);
-      auto emptyTensor = getEmptyTensor(rewriter, loc, resultType, dynShape);
+          rewriter, loc, operand, cast<ShapedType>(resultType), reductionDims);
+      auto emptyTensor =
+          getEmptyTensor(rewriter, loc, cast<ShapedType>(resultType), dynShape);
       Value filledTensor =
           rewriter.create<linalg::FillOp>(loc, initValue, emptyTensor).result();
       outputs.push_back(filledTensor);
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index 82b304b..fa1ef18 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -617,12 +617,12 @@
 
     // Return the output type.
     ShapedType getOutputType() {
-      return getOutput().getType();
+      return getOutput().getType().cast<ShapedType>();
     }
 
     // Return the input type.
     ShapedType getInputType() {
-      return getInput().getType();
+      return getInput().getType().cast<ShapedType>();
     }
 
     // Return the output shape.
@@ -764,12 +764,12 @@
 
     // Return the output type.
     ShapedType getOutputType() {
-      return getOutput().getType();
+      return getOutput().getType().cast<ShapedType>();
     }
 
     // Return the input type.
     ShapedType getInputType() {
-      return getInput().getType();
+      return getInput().getType().cast<ShapedType>();
     }
 
     // Return the output shape.
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
index 812f0df..70aba5c 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h
@@ -263,7 +263,7 @@
     // needs more investigation.
     rewriter.startRootUpdate(op);
     std::optional<linalg::LinalgOp> promotedOp =
-        promoteSubViews(rewriter, op, options);
+        promoteSubViews(rewriter, cast<linalg::LinalgOp>(op), options);
     if (!promotedOp) {
       rewriter.cancelRootUpdate(op);
       return op->emitError("subview promotion failed");
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp
index b1163b8..0680be4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertConv2DToWinograd.cpp
@@ -62,7 +62,7 @@
 /// TODO: Codegen this as a kernel and run once at initialization
 static DenseElementsAttr
 foldFilterTransform(ArrayRef<int64_t> shape, int64_t inputTileSize,
-                    int64_t kernelSize, Type outputType, const float *G,
+                    int64_t kernelSize, ShapedType outputType, const float *G,
                     bool isSplat, float splatValue,
                     DenseElementsAttr::iterator_range<APFloat> &input,
                     FloatType floatType, bool isNchw) {
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp
index c99d4e8..e8a3807 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/SplitReduction.cpp
@@ -149,7 +149,7 @@
 
   // Initialize indices to positive infinity and values to negative infinity
   // for a top (maxk) comparison.
-  Attribute negInfAttr;
+  TypedAttr negInfAttr;
   if (auto intType = valueElementType.dyn_cast<IntegerType>()) {
     negInfAttr = rewriter.getIntegerAttr(
         intType, APInt::getSignedMinValue(intType.getWidth()));
@@ -160,7 +160,7 @@
     negInfAttr = rewriter.getFloatAttr(valueElementType, negApFloat);
   }
   Value negInf = rewriter.create<arith::ConstantOp>(loc, negInfAttr);
-  Attribute posInfAttr =
+  TypedAttr posInfAttr =
       rewriter.getIntegerAttr(indicesElementType, APInt::getSignedMaxValue(32));
   Value posInf = rewriter.create<arith::ConstantOp>(loc, posInfAttr);
   Value negInfTensor =
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
index 3f38e01..6723b97 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.cpp
@@ -53,7 +53,8 @@
     // Apply the pattern.
     SimplePatternRewriter rewriter(target);
     FailureOr<LinalgExt::FusionResult> result =
-        pattern.returningMatchAndRewrite(target, rewriter);
+        pattern.returningMatchAndRewrite(cast<TilingInterface>(target),
+                                         rewriter);
     if (failed(result))
       return emitDefaultDefiniteFailure(target);
 
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
index 25ee7bb..177b43e 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/linalg_transform/single-tiling-full-script.mlir
@@ -33,7 +33,7 @@
     lowering_strategy = "outerproduct"
       : (!pdl.operation) -> !pdl.operation
   %func_e_3 = transform.vector.lower_transpose %func_e_2
-    lowering_strategy = "shuffle"
+    lowering_strategy = "shuffle_1d"
       : (!pdl.operation) -> !pdl.operation
 
   lower_to_llvm %module_op1 : (!pdl.operation) -> !pdl.operation
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 5e70d1a..be9c918 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 5e70d1adf85fcdd695f5d796cb0617a8fdf90f82
+Subproject commit be9c91843bab5bb46574c27836bfcd9ad6fc9ef5
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 34c53ad..4d28523 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 34c53ad70a3cd1a64f04aff8627badc5310a7970
+Subproject commit 4d28523ba6ad600f50900f9048b06056276bf9db