[Codegen] Upgrade Common, SPIRV, VMVX to free create functions. NFC. (#21879)

The OpBuilder::create<OpTy>(...) builder methods are deprecated in
favor of the free OpTy::create(builder, ...) functions; see
https://mlir.llvm.org/deprecation/ and
https://discourse.llvm.org/t/psa-opty-create-now-with-100-more-tab-complete/87339.

The main benefit of the free create functions is better tab completion
in LSPs/IDEs: the op type is spelled first, so the language server can
complete the create call.

I'm splitting the upgrade into chunks by project directory; this chunk
covers the Codegen Common, SPIRV, and VMVX directories.
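
For reference, every hunk below applies the same mechanical rewrite. A
condensed sketch, using ops that appear in this diff (names like
rewriter, b, loc, and lhs stand in for whatever is in scope at each
call site):

    // Before: deprecated OpBuilder member function.
    Value m = rewriter.create<tensor::DimOp>(loc, lhs, 0);

    // After: free create function, with the builder as the first
    // argument; all other arguments are unchanged.
    Value m = tensor::DimOp::create(rewriter, loc, lhs, 0);

    // The same rewrite applies inside body-builder lambdas, where the
    // builder is a plain OpBuilder& rather than a rewriter:
    [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
      Value result = arith::AddFOp::create(b, nestedLoc, args[0], args[1]);
      linalg::YieldOp::create(b, nestedLoc, result);
    }

Rewriter helpers such as rewriter.replaceOpWithNewOp<OpTy>(...) are
left unchanged throughout.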
diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
index 1abfea0..217ff46 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
@@ -167,14 +167,14 @@
   auto outputType = RankedTensorType::get(
       staticOutputShape, tensorType.getElementType(), tensorType.getEncoding());
 
-  auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc, outputType, v, reassociation, outputShape);
+  auto expandShapeOp = tensor::ExpandShapeOp::create(
+      rewriter, loc, outputType, v, reassociation, outputShape);
   Value barrier = rewriter
                       .create<IREE::Util::OptimizationBarrierOp>(
                           loc, expandShapeOp.getResult())
                       .getResult(0);
-  auto collapseShapeOp = rewriter.create<tensor::CollapseShapeOp>(
-      loc, tensorType, barrier, reassociation);
+  auto collapseShapeOp = tensor::CollapseShapeOp::create(
+      rewriter, loc, tensorType, barrier, reassociation);
   return ReshapeOps{expandShapeOp, collapseShapeOp};
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/BubbleUpOrdinalOps.cpp b/compiler/src/iree/compiler/Codegen/Common/BubbleUpOrdinalOps.cpp
index 42aa481..cea3e25 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BubbleUpOrdinalOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BubbleUpOrdinalOps.cpp
@@ -56,11 +56,10 @@
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(sourceCastOp);
     Location loc = ordinalOp.getLoc();
-    Value reverseCastOp = rewriter.create<CastOpTy>(
-        loc, rewriter.getIndexType(), sourceCastOp.getIn());
-    Value newOrdinalOp =
-        rewriter.create<IREE::TensorExt::DispatchWorkloadOrdinalOp>(
-            loc, reverseCastOp, ordinalOp.getOrdinal());
+    Value reverseCastOp = CastOpTy::create(
+        rewriter, loc, rewriter.getIndexType(), sourceCastOp.getIn());
+    Value newOrdinalOp = IREE::TensorExt::DispatchWorkloadOrdinalOp::create(
+        rewriter, loc, reverseCastOp, ordinalOp.getOrdinal());
     rewriter.replaceOp(sourceCastOp, newOrdinalOp);
     rewriter.replaceOp(ordinalOp, newOrdinalOp);
     return success();
diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizeDispatchTensorLoadStore.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizeDispatchTensorLoadStore.cpp
index e4695d9..24cbce4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BufferizeDispatchTensorLoadStore.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BufferizeDispatchTensorLoadStore.cpp
@@ -38,8 +38,8 @@
   MemRefType subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
       tensorType.getShape(), cast<MemRefType>(subspanMemref.getType()), offsets,
       sizes, strides);
-  return rewriter.create<memref::SubViewOp>(
-      loc, subviewMemRefType, subspanMemref, offsets, sizes, strides);
+  return memref::SubViewOp::create(rewriter, loc, subviewMemRefType,
+                                   subspanMemref, offsets, sizes, strides);
 }
 
 static void
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp
index 1e01c42..e6e9ce4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPULowerToUKernels.cpp
@@ -239,26 +239,26 @@
   flags |= IREE_UK_FLAG_MMT4D_ALLOW_GENERIC_FALLBACK_TILE_FUNCTION;
 
   Location loc = op.getLoc();
-  Value m = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-  Value n = rewriter.create<tensor::DimOp>(loc, rhs, 0);
-  Value k = rewriter.create<tensor::DimOp>(loc, rhs, 1);
+  Value m = tensor::DimOp::create(rewriter, loc, lhs, 0);
+  Value n = tensor::DimOp::create(rewriter, loc, rhs, 0);
+  Value k = tensor::DimOp::create(rewriter, loc, rhs, 1);
 
   auto getDimAsI32 = [](RewriterBase &rewriter, Location loc, Value value,
                         int dim) -> Value {
-    return rewriter.create<arith::IndexCastOp>(
-        loc, rewriter.getI32Type(),
-        rewriter.create<tensor::DimOp>(loc, value, dim));
+    return arith::IndexCastOp::create(
+        rewriter, loc, rewriter.getI32Type(),
+        tensor::DimOp::create(rewriter, loc, value, dim));
   };
   Value m0 = getDimAsI32(rewriter, loc, lhs, 2);
   Value n0 = getDimAsI32(rewriter, loc, rhs, 2);
   Value k0 = getDimAsI32(rewriter, loc, rhs, 3);
-  Value flagsVal = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32IntegerAttr(flags));
+  Value flagsVal = arith::ConstantOp::create(rewriter, loc,
+                                             rewriter.getI32IntegerAttr(flags));
   auto fn = getFnNameAndDefAttrs(ukernelName, rewriter, targetAttr);
   SmallVector<Type> returnTypes =
       getUKernelGenericReturnTypes(targetAttr, outType);
-  auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
-      loc, returnTypes, fn.name, ValueRange{lhs, rhs}, out,
+  auto genericMicroKernelOp = IREE::Codegen::UKernelGenericOp::create(
+      rewriter, loc, returnTypes, fn.name, ValueRange{lhs, rhs}, out,
       ValueRange{m, n, k, m0, n0, k0, flagsVal},
       /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
       /*num_strided_outer_dims=*/1);
@@ -344,8 +344,8 @@
   Value paddingVal = op.getPaddingValue();
   // If the pack op didn't have a padding_value attribute, default to 0.
   if (!paddingVal) {
-    paddingVal =
-        rewriter.create<arith::ConstantOp>(loc, i64, rewriter.getZeroAttr(i64));
+    paddingVal = arith::ConstantOp::create(rewriter, loc, i64,
+                                           rewriter.getZeroAttr(i64));
   }
   int paddingValBitWidth = paddingVal.getType().getIntOrFloatBitWidth();
   // Non-integer element types get bitcast to integer of same bit width.
@@ -355,7 +355,7 @@
       return rewriter.notifyMatchFailure(op, "no integer type with this width");
     }
     paddingVal =
-        rewriter.create<arith::BitcastOp>(loc, sameWidthIntType, paddingVal);
+        arith::BitcastOp::create(rewriter, loc, sameWidthIntType, paddingVal);
   }
   // Element types > 64bits could be supported, when the padding value is a
   // repeating 64-bit pattern. For now, we leave this as not-yet-implemented.
@@ -366,21 +366,21 @@
   // Integers narrower than 64 bit get extended to 64 bits, it doesn't matter
   // how, as the high bits are unused.
   if (paddingValBitWidth < 64) {
-    paddingVal = rewriter.create<arith::ExtUIOp>(loc, i64, paddingVal);
+    paddingVal = arith::ExtUIOp::create(rewriter, loc, i64, paddingVal);
   }
-  Value in_size0 = rewriter.create<tensor::DimOp>(loc, in, 0);
-  Value in_size1 = rewriter.create<tensor::DimOp>(loc, in, 1);
-  Value out_size0 = rewriter.create<tensor::DimOp>(loc, out, 0);
-  Value out_size1 = rewriter.create<tensor::DimOp>(loc, out, 1);
-  Value out_size2 = rewriter.create<tensor::DimOp>(loc, out, 2);
-  Value out_size3 = rewriter.create<tensor::DimOp>(loc, out, 3);
-  Value flagsVal = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32IntegerAttr(flags));
+  Value in_size0 = tensor::DimOp::create(rewriter, loc, in, 0);
+  Value in_size1 = tensor::DimOp::create(rewriter, loc, in, 1);
+  Value out_size0 = tensor::DimOp::create(rewriter, loc, out, 0);
+  Value out_size1 = tensor::DimOp::create(rewriter, loc, out, 1);
+  Value out_size2 = tensor::DimOp::create(rewriter, loc, out, 2);
+  Value out_size3 = tensor::DimOp::create(rewriter, loc, out, 3);
+  Value flagsVal = arith::ConstantOp::create(rewriter, loc,
+                                             rewriter.getI32IntegerAttr(flags));
   auto fn = getFnNameAndDefAttrs(ukernelName, rewriter, targetAttr);
   SmallVector<Type> returnTypes =
       getUKernelGenericReturnTypes(targetAttr, outType);
-  auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
-      loc, returnTypes, fn.name, in, out,
+  auto genericMicroKernelOp = IREE::Codegen::UKernelGenericOp::create(
+      rewriter, loc, returnTypes, fn.name, in, out,
       ValueRange{in_size0, in_size1, out_size0, out_size1, out_size2, out_size3,
                  paddingVal, flagsVal},
       /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
@@ -456,19 +456,19 @@
   }
 
   Location loc = op.getLoc();
-  Value in_size0 = rewriter.create<tensor::DimOp>(loc, in, 0);
-  Value in_size1 = rewriter.create<tensor::DimOp>(loc, in, 1);
-  Value in_size2 = rewriter.create<tensor::DimOp>(loc, in, 2);
-  Value in_size3 = rewriter.create<tensor::DimOp>(loc, in, 3);
-  Value out_size0 = rewriter.create<tensor::DimOp>(loc, out, 0);
-  Value out_size1 = rewriter.create<tensor::DimOp>(loc, out, 1);
-  Value flagsVal = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32IntegerAttr(flags));
+  Value in_size0 = tensor::DimOp::create(rewriter, loc, in, 0);
+  Value in_size1 = tensor::DimOp::create(rewriter, loc, in, 1);
+  Value in_size2 = tensor::DimOp::create(rewriter, loc, in, 2);
+  Value in_size3 = tensor::DimOp::create(rewriter, loc, in, 3);
+  Value out_size0 = tensor::DimOp::create(rewriter, loc, out, 0);
+  Value out_size1 = tensor::DimOp::create(rewriter, loc, out, 1);
+  Value flagsVal = arith::ConstantOp::create(rewriter, loc,
+                                             rewriter.getI32IntegerAttr(flags));
   auto fn = getFnNameAndDefAttrs(ukernelName, rewriter, targetAttr);
   SmallVector<Type> returnTypes =
       getUKernelGenericReturnTypes(targetAttr, outType);
-  auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
-      loc, returnTypes, fn.name, in, out,
+  auto genericMicroKernelOp = IREE::Codegen::UKernelGenericOp::create(
+      rewriter, loc, returnTypes, fn.name, in, out,
       ValueRange{in_size0, in_size1, in_size2, in_size3, out_size0, out_size1,
                  flagsVal},
       /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
@@ -577,7 +577,7 @@
   SmallVector<Value> inputValues;
   Location loc = op.getLoc();
   for (int64_t i : tensorType.getShape()) {
-    inputValues.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
+    inputValues.push_back(arith::ConstantIndexOp::create(rewriter, loc, i));
   }
   uint32_t flagForUserAndOperandTypes =
       getFlagForUserAndOperandTypes(encoding, encoding.getElementTypesArray());
@@ -586,11 +586,11 @@
   if (!flagForUserAndOperandTypes || !flagForIndex) {
     return rewriter.notifyMatchFailure(op, "unhandled encoding");
   }
-  inputValues.push_back(rewriter.create<arith::ConstantIntOp>(
-      loc, flagForUserAndOperandTypes | flagForIndex, 32));
+  inputValues.push_back(arith::ConstantIntOp::create(
+      rewriter, loc, flagForUserAndOperandTypes | flagForIndex, 32));
   auto fn = getFnNameAndDefAttrs(ukernelName, rewriter, targetAttr);
-  auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
-      loc, resultTypes, fn.name, inputValues, /*outs=*/ValueRange{},
+  auto genericMicroKernelOp = IREE::Codegen::UKernelGenericOp::create(
+      rewriter, loc, resultTypes, fn.name, inputValues, /*outs=*/ValueRange{},
       /*other_operands=*/ValueRange{},
       /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
       /*strided_dims=*/nullptr);
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp
index dce38ed..9f2a569 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPrepareUkernels.cpp
@@ -233,9 +233,9 @@
           rhs.getLoc(), "rhs producer should be reduced, but reduction failed");
     }
 
-    auto mmt4DOp = rewriter.create<linalg::Mmt4DOp>(
-        loc, reducedOut.getType(), ValueRange{reducedLhs, reducedRhs},
-        ValueRange{reducedOut});
+    auto mmt4DOp = linalg::Mmt4DOp::create(rewriter, loc, reducedOut.getType(),
+                                           ValueRange{reducedLhs, reducedRhs},
+                                           ValueRange{reducedOut});
 
     auto loweringConfig = getLoweringConfig<IREE::CPU::LoweringConfigAttr>(op);
     if (loweringConfig) {
@@ -308,9 +308,9 @@
     auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, packOp.getDest(), reducedDestType);
 
-    auto newPackOp = rewriter.create<linalg::PackOp>(
-        loc, reducedSrc, reducedDest, newInnerDimsPos, packOp.getMixedTiles(),
-        packOp.getPaddingValue(), newOuterDimsPerm);
+    auto newPackOp = linalg::PackOp::create(
+        rewriter, loc, reducedSrc, reducedDest, newInnerDimsPos,
+        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
 
     auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
         rewriter, loc, newPackOp.getResult(), packOp.getDest());
@@ -386,9 +386,9 @@
     auto reducedDest = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, unpackOp.getDest(), reducedDestType);
 
-    auto newUnpackOp = rewriter.create<linalg::UnPackOp>(
-        loc, reducedSrc, reducedDest, newInnerDimsPos, unpackOp.getMixedTiles(),
-        newOuterDimsPerm);
+    auto newUnpackOp = linalg::UnPackOp::create(
+        rewriter, loc, reducedSrc, reducedDest, newInnerDimsPos,
+        unpackOp.getMixedTiles(), newOuterDimsPerm);
 
     auto insertSliceOp = tensor::createCanonicalRankReducingInsertSliceOp(
         rewriter, loc, newUnpackOp.getResult(), unpackOp.getDest());
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPropagateDataLayout.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPropagateDataLayout.cpp
index 99e5eb7..6f4466e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPropagateDataLayout.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUPropagateDataLayout.cpp
@@ -126,10 +126,10 @@
     }
 
     Location loc = op.getLoc();
-    auto newDestOp = rewriter.create<tensor::EmptyOp>(
-        loc, destShape, emptyOp.getType().getElementType());
-    auto newUnpackOp = rewriter.create<linalg::UnPackOp>(
-        loc, collapseOp.getSrc(), newDestOp, innerDimPos, innerTiles);
+    auto newDestOp = tensor::EmptyOp::create(
+        rewriter, loc, destShape, emptyOp.getType().getElementType());
+    auto newUnpackOp = linalg::UnPackOp::create(
+        rewriter, loc, collapseOp.getSrc(), newDestOp, innerDimPos, innerTiles);
     SmallVector<ReassociationIndices> newRi;
     for (int64_t i = 0, e = op.getDestRank(); i < e; ++i) {
       if (i == outerRi[0]) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
index 82f7f02..725e877 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CombineLayoutTransformation.cpp
@@ -134,10 +134,12 @@
 
   auto indexTransformBuilder =
       [&](ArrayRef<BlockArgument> srcIndices) -> SmallVector<Value> {
-    auto linearizeIndexOp = rewriter.create<affine::AffineLinearizeIndexOp>(
-        mapScatterOp->getLoc(), srcIndices, srcDims, /*disjoint=*/true);
-    auto delinearizeIndexOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        mapScatterOp->getLoc(), linearizeIndexOp.getResult(), resultDims,
+    auto linearizeIndexOp = affine::AffineLinearizeIndexOp::create(
+        rewriter, mapScatterOp->getLoc(), srcIndices, srcDims,
+        /*disjoint=*/true);
+    auto delinearizeIndexOp = affine::AffineDelinearizeIndexOp::create(
+        rewriter, mapScatterOp->getLoc(), linearizeIndexOp.getResult(),
+        resultDims,
         /*hasOuterBound=*/true);
     return delinearizeIndexOp->getResults();
   };
@@ -221,7 +223,7 @@
             .create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, srcIdx,
                                    boundValue)
             ->getResult(0);
-    mask = rewriter.create<arith::AndIOp>(loc, mask, isOutOfBounds);
+    mask = arith::AndIOp::create(rewriter, loc, mask, isOutOfBounds);
   }
   rewriter.modifyOpInPlace(yieldOp, [&]() {
     yieldOp->setOperand(yieldOp->getNumOperands() - 1, mask);
@@ -240,8 +242,8 @@
   DistributionConfig distConfig = distConfigs[distributionLevel];
   SmallVector<OpFoldResult> steps =
       getAsIndexOpFoldResult(rewriter.getContext(), distConfig.tileSizes);
-  rewriter.create<scf::ForallOp>(
-      loc, lbs, ubs, steps, /*outputs=*/ValueRange(),
+  scf::ForallOp::create(
+      rewriter, loc, lbs, ubs, steps, /*outputs=*/ValueRange(),
       rewriter.getArrayAttr(distConfig.mapping),
       /*bodyBuilder=*/[&](OpBuilder &b, Location nestedLoc, ValueRange ivs) {
         SmallVector<OpFoldResult> nestedLbs(ivs);
@@ -260,7 +262,7 @@
           buildNestedDistributionLoops(
               rewriter, nestedLoc, distributionLevel + 1, nestedLbs, nestedUbs,
               distConfigs, innerLoopBuilder);
-          b.create<scf::InParallelOp>(nestedLoc);
+          scf::InParallelOp::create(b, nestedLoc);
           return;
         }
         // Otherwise, tile to one, and generate the inner loop body.
@@ -268,11 +270,11 @@
             getValueOrCreateConstantIndexOp(b, nestedLoc, nestedLbs);
         SmallVector<Value> nestedUbVals =
             getValueOrCreateConstantIndexOp(b, nestedLoc, nestedUbs);
-        Value one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
+        Value one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
         SmallVector<Value> unitSteps(nestedLbs.size(), one);
         scf::buildLoopNest(rewriter, nestedLoc, nestedLbVals, nestedUbVals,
                            unitSteps, innerLoopBuilder);
-        b.create<scf::InParallelOp>(nestedLoc);
+        scf::InParallelOp::create(b, nestedLoc);
       });
 }
 
@@ -324,11 +326,12 @@
     Value mask = storeIndices.pop_back_val();
     // Create the store to the outputBuffer.
     auto thenBuilder = [&](OpBuilder &nestedBuilder, Location ifLoc) {
-      nestedBuilder.create<memref::StoreOp>(
-          ifLoc, padOp.getConstantPaddingValue(), outputBuffer, storeIndices);
-      nestedBuilder.create<scf::YieldOp>(ifLoc);
+      memref::StoreOp::create(nestedBuilder, ifLoc,
+                              padOp.getConstantPaddingValue(), outputBuffer,
+                              storeIndices);
+      scf::YieldOp::create(nestedBuilder, ifLoc);
     };
-    b.create<scf::IfOp>(loopLoc, mask, thenBuilder);
+    scf::IfOp::create(b, loopLoc, mask, thenBuilder);
   };
 
   // Distribute the padding of each dimension separately. This causes some
@@ -439,7 +442,7 @@
   Type elementType = cast<RankedTensorType>(root.getType()).getElementType();
   SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(rewriter, loc, root);
   Value mapScatterDest =
-      rewriter.create<tensor::EmptyOp>(loc, sizes, elementType);
+      tensor::EmptyOp::create(rewriter, loc, sizes, elementType);
   auto mapScatterOp = MapScatterOp::createIdentityMapScatter(
       rewriter, loc, root, mapScatterDest);
   rewriter.replaceUsesWithIf(
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
index 2679d4a..75bfea3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
@@ -80,27 +80,28 @@
   // contraction op.
   SmallVector<OpFoldResult> mixedSizes =
       tensor::getMixedSizes(rewriter, loc, outputOperand);
-  Value initOp = rewriter.create<tensor::EmptyOp>(loc, mixedSizes, elementType);
-  Value zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getZeroAttr(elementType));
-  Value fill = rewriter.create<linalg::FillOp>(loc, zero, initOp).result();
+  Value initOp =
+      tensor::EmptyOp::create(rewriter, loc, mixedSizes, elementType);
+  Value zero = arith::ConstantOp::create(rewriter, loc,
+                                         rewriter.getZeroAttr(elementType));
+  Value fill = linalg::FillOp::create(rewriter, loc, zero, initOp).result();
 
   // Update the contraction op to use the new zero tensor as output operand.
   rewriter.modifyOpInPlace(dpsOp, [&]() { dpsOp.setDpsInitOperand(0, fill); });
 
   // Create a generic op to add back the original output tensor operand.
   rewriter.setInsertionPointAfter(dpsOp);
-  auto genericOp = rewriter.create<linalg::GenericOp>(
-      loc, outputType, ValueRange{dpsOp->getResult(0), outputOperand},
+  auto genericOp = linalg::GenericOp::create(
+      rewriter, loc, outputType, ValueRange{dpsOp->getResult(0), outputOperand},
       ValueRange{initOp}, maps, iterators,
       [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
         Value result;
         if (llvm::isa<FloatType>(elementType)) {
-          result = b.create<arith::AddFOp>(nestedLoc, args[0], args[1]);
+          result = arith::AddFOp::create(b, nestedLoc, args[0], args[1]);
         } else {
-          result = b.create<arith::AddIOp>(nestedLoc, args[0], args[1]);
+          result = arith::AddIOp::create(b, nestedLoc, args[0], args[1]);
         }
-        b.create<linalg::YieldOp>(nestedLoc, result);
+        linalg::YieldOp::create(b, nestedLoc, result);
       });
   dpsOp->getResult(0).replaceAllUsesExcept(genericOp->getResult(0), genericOp);
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
index 241d9bb..7264a25 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ArithToF32.cpp
@@ -51,11 +51,11 @@
     return nullptr;
 
   if (inputETy.getIntOrFloatBitWidth() > eTy.getIntOrFloatBitWidth()) {
-    return builder.create<arith::TruncFOp>(loc, type, inputs[0]);
+    return arith::TruncFOp::create(builder, loc, type, inputs[0]);
   }
 
   if (inputETy.getIntOrFloatBitWidth() < eTy.getIntOrFloatBitWidth()) {
-    return builder.create<arith::ExtFOp>(loc, type, inputs[0]);
+    return arith::ExtFOp::create(builder, loc, type, inputs[0]);
   }
 
   return nullptr;
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
index 55fcb12..57e80e6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertBf16ToUInt16Buffers.cpp
@@ -256,7 +256,7 @@
 
 Value materializeArithBitcast(OpBuilder &builder, Type resultTy,
                               mlir::ValueRange inputs, mlir::Location loc) {
-  return builder.create<arith::BitcastOp>(loc, resultTy, inputs);
+  return arith::BitcastOp::create(builder, loc, resultTy, inputs);
 }
 
 static void populateIreeBf16EmulationPatterns(RewritePatternSet &patterns,
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index e5d118e..1d5800c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -78,8 +78,8 @@
     OpBuilder &b, IREE::TensorExt::DispatchTensorStoreOp storeOp) {
   // Clone the offset, size and stride values. They will be CSE-ed later.
   SliceAndDynamicDims clonedVals = cloneOffsetsSizesAndStrides(b, storeOp);
-  Value tensorLoadOp = b.create<IREE::TensorExt::DispatchTensorLoadOp>(
-      storeOp.getLoc(),
+  Value tensorLoadOp = IREE::TensorExt::DispatchTensorLoadOp::create(
+      b, storeOp.getLoc(),
       llvm::cast<RankedTensorType>(storeOp.getValue().getType()),
       storeOp.getTarget(), clonedVals.dynamicDims, clonedVals.offsets,
       clonedVals.sizes, clonedVals.strides);
@@ -95,17 +95,17 @@
   using ReverseReshapeOpTy = typename std::conditional<
       std::is_same<TensorReshapeOpTy, tensor::CollapseShapeOp>::value,
       tensor::ExpandShapeOp, tensor::CollapseShapeOp>::type;
-  return b.create<ReverseReshapeOpTy>(reshapeOp.getLoc(),
-                                      reshapeOp.getSrcType(), resultBuffer,
-                                      reshapeOp.getReassociationIndices());
+  return ReverseReshapeOpTy::create(b, reshapeOp.getLoc(),
+                                    reshapeOp.getSrcType(), resultBuffer,
+                                    reshapeOp.getReassociationIndices());
 }
 
 /// Gets the reverse of a `tensor.cast` op to get a memref type that
 /// can be used for in-place computation of the result of a dispatch region.
 static Value getReverseOfCastOp(OpBuilder &b, tensor::CastOp castOp,
                                 Value resultBuffer) {
-  return b.create<tensor::CastOp>(castOp.getLoc(), castOp.getSource().getType(),
-                                  resultBuffer);
+  return tensor::CastOp::create(b, castOp.getLoc(),
+                                castOp.getSource().getType(), resultBuffer);
 }
 
 /// Returns a tied result value given the operand. If no such result exists,
@@ -400,8 +400,8 @@
   Location loc = genericOp.getLoc();
   SmallVector<utils::IteratorType> iterTypes(genericOp.getNumLoops(),
                                              utils::IteratorType::parallel);
-  auto newOp = rewriter.create<linalg::GenericOp>(
-      loc, newResultTypes, newInputs, newOutputs, maps, iterTypes,
+  auto newOp = linalg::GenericOp::create(
+      rewriter, loc, newResultTypes, newInputs, newOutputs, maps, iterTypes,
       /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
   rewriter.inlineRegionBefore(genericOp.getRegion(), newOp.getRegion(),
                               newOp.getRegion().begin());
@@ -492,8 +492,8 @@
 
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPointAfter(emptyOp);
-    auto allocTensor = b.create<bufferization::AllocTensorOp>(
-        emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
+    auto allocTensor = bufferization::AllocTensorOp::create(
+        b, emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
     emptyOp.replaceAllUsesWith(allocTensor.getResult());
   });
 
@@ -522,11 +522,11 @@
       TypedAttr scalarAttr = attr.getValues<TypedAttr>()[0];
 
       modifiedOutput = true;
-      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, type.getShape(), type.getElementType());
-      Value cstOp = rewriter.create<arith::ConstantOp>(loc, scalarAttr);
+      Value emptyTensor = tensor::EmptyOp::create(
+          rewriter, loc, type.getShape(), type.getElementType());
+      Value cstOp = arith::ConstantOp::create(rewriter, loc, scalarAttr);
       Value fillOp =
-          rewriter.create<linalg::FillOp>(loc, cstOp, emptyTensor).result();
+          linalg::FillOp::create(rewriter, loc, cstOp, emptyTensor).result();
       op->setOperand(opOperand.getOperandNumber(), fillOp);
     }
     if (!modifiedOutput) {
@@ -583,8 +583,8 @@
       auto yieldedVal = yieldOp.getOperand(resultNumber);
       SliceAndDynamicDims sliceAndDynamicDims =
           cloneOffsetsSizesAndStrides(rewriter, storeOp);
-      rewriter.create<IREE::TensorExt::DispatchTensorStoreOp>(
-          storeOp.getLoc(), yieldedVal, storeOp.getTarget(),
+      IREE::TensorExt::DispatchTensorStoreOp::create(
+          rewriter, storeOp.getLoc(), yieldedVal, storeOp.getTarget(),
           sliceAndDynamicDims.dynamicDims, sliceAndDynamicDims.offsets,
           sliceAndDynamicDims.sizes, sliceAndDynamicDims.strides);
     };
diff --git a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
index fa18d0f..97d3a1a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
@@ -167,8 +167,8 @@
 
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
-  return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
-                                                       dest, offsets, strides);
+  return vector::InsertStridedSliceOp::create(rewriter, loc, destVecTy, src,
+                                              dest, offsets, strides);
 }
 
 /// Extract `sliceNumElements` from source `vector` at `extractOffset`,
@@ -195,8 +195,9 @@
   assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
          "vector element must be a valid sub-byte type");
   auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
-  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
-      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
+  auto emptyByteVector = arith::ConstantOp::create(
+      rewriter, loc,
+      VectorType::get({emulatedPerContainerElem}, vectorElementType),
       rewriter.getZeroAttr(
           VectorType::get({emulatedPerContainerElem}, vectorElementType)));
   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
@@ -216,16 +217,17 @@
           upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
       "expected input and output number of bits to match");
   if (trueValue.getType() != downcastType) {
-    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
+    trueValue =
+        vector::BitCastOp::create(builder, loc, downcastType, trueValue);
   }
   if (falseValue.getType() != downcastType) {
     falseValue =
-        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
+        vector::BitCastOp::create(builder, loc, downcastType, falseValue);
   }
   Value selectedType =
-      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
+      arith::SelectOp::create(builder, loc, mask, trueValue, falseValue);
   // Upcast the selected value to the new type.
-  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
+  return vector::BitCastOp::create(builder, loc, upcastType, selectedType);
 }
 
 /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
@@ -248,8 +250,8 @@
 
   // Create an atomic load-modify-write region using
   // `memref.generic_atomic_rmw`.
-  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
-      loc, linearizedMemref, ValueRange{storeIdx});
+  auto atomicOp = memref::GenericAtomicRMWOp::create(
+      builder, loc, linearizedMemref, ValueRange{storeIdx});
   Value origValue = atomicOp.getCurrentValue();
 
   OpBuilder::InsertionGuard guard(builder);
@@ -258,16 +260,16 @@
   // Load the original value from memory, and cast it to the original element
   // type.
   auto oneElemVecType = VectorType::get({1}, origValue.getType());
-  Value origVecValue = builder.create<vector::FromElementsOp>(
-      loc, oneElemVecType, ValueRange{origValue});
+  Value origVecValue = vector::FromElementsOp::create(
+      builder, loc, oneElemVecType, ValueRange{origValue});
 
   // Construct the final masked value and yield it.
   Value maskedValue =
       downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                               oneElemVecType, mask, valueToStore, origVecValue);
   auto scalarMaskedValue =
-      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
-  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
+      vector::ExtractOp::create(builder, loc, maskedValue, 0);
+  memref::AtomicYieldOp::create(builder, loc, scalarMaskedValue);
 }
 
 /// Generate a non-atomic read-modify-write sequence for storing to the emulated
@@ -279,16 +281,17 @@
 
   auto oneElemVecType =
       VectorType::get({1}, linearizedMemref.getType().getElementType());
-  Value origVecValue = builder.create<vector::LoadOp>(
-      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
-  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
-                                                   origVecValue);
+  Value origVecValue =
+      vector::LoadOp::create(builder, loc, oneElemVecType, linearizedMemref,
+                             ValueRange{linearizedIndex});
+  origVecValue = vector::BitCastOp::create(builder, loc, valueToStore.getType(),
+                                           origVecValue);
 
   Value maskedValue =
       downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                               oneElemVecType, mask, valueToStore, origVecValue);
-  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
-                                  linearizedIndex);
+  vector::StoreOp::create(builder, loc, maskedValue, linearizedMemref,
+                          linearizedIndex);
 }
 
 // Emulate `vector.store` using a multi-byte container type.
@@ -382,7 +385,7 @@
     bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
 
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
@@ -418,8 +421,8 @@
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
-      auto bitCast = rewriter.create<vector::BitCastOp>(
-          loc, VectorType::get(numElements, containerElemTy),
+      auto bitCast = vector::BitCastOp::create(
+          rewriter, loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           op, bitCast.getResult(), memrefBase,
@@ -486,8 +489,9 @@
         std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                     *foldedNumFrontPadElems, true);
       }
-      auto frontMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+      auto frontMask = arith::ConstantOp::create(
+          rewriter, loc,
+          DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
       currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
       auto value =
@@ -505,9 +509,9 @@
 
     // Increment the destination index by 1 to align to the emulated width
     // boundary.
-    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    currentDestIndex = rewriter.create<arith::AddIOp>(
-        loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    auto constantOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
+    currentDestIndex = arith::AddIOp::create(
+        rewriter, loc, rewriter.getIndexType(), currentDestIndex, constantOne);
 
     // 2. Full width store for the inner output bytes.
     // After the previous step, the store address is aligned to the emulated
@@ -526,15 +530,15 @@
       auto storeType = VectorType::get(
           {originType.getNumElements() / emulatedPerContainerElem},
           memrefElemType);
-      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
-                                                        fullWidthStorePart);
-      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
-                                       currentDestIndex);
+      auto bitCast = vector::BitCastOp::create(rewriter, loc, storeType,
+                                               fullWidthStorePart);
+      vector::StoreOp::create(rewriter, loc, bitCast.getResult(), memrefBase,
+                              currentDestIndex);
 
       currentSourceIndex += numNonFullWidthElements;
-      currentDestIndex = rewriter.create<arith::AddIOp>(
-          loc, rewriter.getIndexType(), currentDestIndex,
-          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+      currentDestIndex = arith::AddIOp::create(
+          rewriter, loc, rewriter.getIndexType(), currentDestIndex,
+          arith::ConstantIndexOp::create(rewriter, loc, fullWidthStoreSize));
     }
 
     // 3. Partial width store for the trailing output byte.
@@ -549,8 +553,9 @@
       // Generate back mask.
       auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
-      auto backMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+      auto backMask = arith::ConstantOp::create(
+          rewriter, loc,
+          DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
       storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                 cast<VectorValue>(subWidthStorePart), backMask.getResult());
diff --git a/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.cpp b/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.cpp
index 40fd181..e98b974 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ExtractAddressComputation.cpp
@@ -37,8 +37,8 @@
                                     memref::LoadOp loadOp, Value srcMemRef,
                                     ArrayRef<Value> indices) {
   Location loc = loadOp.getLoc();
-  return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
-                                         loadOp.getNontemporal());
+  return memref::LoadOp::create(rewriter, loc, srcMemRef, indices,
+                                loadOp.getNontemporal());
 }
 
 SmallVector<OpFoldResult> getLoadOpViewSizeForEachDim(RewriterBase &rewriter,
@@ -65,9 +65,8 @@
                                       memref::StoreOp storeOp, Value srcMemRef,
                                       ArrayRef<Value> indices) {
   Location loc = storeOp.getLoc();
-  return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
-                                          srcMemRef, indices,
-                                          storeOp.getNontemporal());
+  return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
+                                 srcMemRef, indices, storeOp.getNontemporal());
 }
 
 SmallVector<OpFoldResult>
diff --git a/compiler/src/iree/compiler/Codegen/Common/FastMathPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/FastMathPatterns.cpp
index 7cca3ca..6d6f80c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FastMathPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FastMathPatterns.cpp
@@ -32,45 +32,45 @@
 
     // Create constants.
     Type f32Type = rewriter.getF32Type();
-    auto oneF = rewriter.create<arith::ConstantOp>(
-        loc, f32Type, rewriter.getF32FloatAttr(1.0f));
+    auto oneF = arith::ConstantOp::create(rewriter, loc, f32Type,
+                                          rewriter.getF32FloatAttr(1.0f));
 
     // Get abs value.
-    Value ax = rewriter.create<math::AbsFOp>(loc, input);
+    Value ax = math::AbsFOp::create(rewriter, loc, input);
 
     // Create comparison for |x| < 1.0.
-    Value cmp = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
-                                               ax, oneF);
+    Value cmp = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OLT,
+                                      ax, oneF);
 
     // Create if statement.
-    auto ifOp = rewriter.create<scf::IfOp>(loc, resultType, cmp, true);
+    auto ifOp = scf::IfOp::create(rewriter, loc, resultType, cmp, true);
 
     // --- Then region (|x| < 1.0) ---
     {
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
       // Define polynomial coefficients for |x| < 1.0.
-      auto c1_0 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(-0x1.268bc2p-11f));
-      auto c1_1 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.420828p-8f));
-      auto c1_2 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(-0x1.b5937p-6f));
-      auto c1_3 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.ce077cp-4f));
-      auto c1_4 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(-0x1.81266p-2f));
-      auto c1_5 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.06eba0p-3f));
+      auto c1_0 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(-0x1.268bc2p-11f));
+      auto c1_1 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.420828p-8f));
+      auto c1_2 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(-0x1.b5937p-6f));
+      auto c1_3 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.ce077cp-4f));
+      auto c1_4 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(-0x1.81266p-2f));
+      auto c1_5 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.06eba0p-3f));
 
-      Value t = rewriter.create<arith::MulFOp>(loc, ax, ax);
-      Value mad1 = rewriter.create<math::FmaOp>(loc, t, c1_0, c1_1);
-      Value mad2 = rewriter.create<math::FmaOp>(loc, t, mad1, c1_2);
-      Value mad3 = rewriter.create<math::FmaOp>(loc, t, mad2, c1_3);
-      Value mad4 = rewriter.create<math::FmaOp>(loc, t, mad3, c1_4);
-      Value p = rewriter.create<math::FmaOp>(loc, t, mad4, c1_5);
-      Value result = rewriter.create<math::FmaOp>(loc, ax, p, ax);
-      rewriter.create<scf::YieldOp>(loc, result);
+      Value t = arith::MulFOp::create(rewriter, loc, ax, ax);
+      Value mad1 = math::FmaOp::create(rewriter, loc, t, c1_0, c1_1);
+      Value mad2 = math::FmaOp::create(rewriter, loc, t, mad1, c1_2);
+      Value mad3 = math::FmaOp::create(rewriter, loc, t, mad2, c1_3);
+      Value mad4 = math::FmaOp::create(rewriter, loc, t, mad3, c1_4);
+      Value p = math::FmaOp::create(rewriter, loc, t, mad4, c1_5);
+      Value result = math::FmaOp::create(rewriter, loc, ax, p, ax);
+      scf::YieldOp::create(rewriter, loc, result);
     } // End then region.
 
     // --- Else region (|x| >= 1.0) ---
@@ -79,38 +79,38 @@
       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
 
       // Define polynomial coefficients for |x| >= 1.0
-      auto c2_0 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.1d3156p-16f));
-      auto c2_1 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(-0x1.8d129p-12f));
-      auto c2_2 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.f9a6d2p-9f));
-      auto c2_3 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(-0x1.8c3164p-6f));
-      auto c2_4 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.b4e9c8p-4f));
-      auto c2_5 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.4515fap-1f));
-      auto c2_6 = rewriter.create<arith::ConstantOp>(
-          loc, f32Type, rewriter.getF32FloatAttr(0x1.078e50p-3f));
+      auto c2_0 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.1d3156p-16f));
+      auto c2_1 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(-0x1.8d129p-12f));
+      auto c2_2 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.f9a6d2p-9f));
+      auto c2_3 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(-0x1.8c3164p-6f));
+      auto c2_4 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.b4e9c8p-4f));
+      auto c2_5 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.4515fap-1f));
+      auto c2_6 = arith::ConstantOp::create(
+          rewriter, loc, f32Type, rewriter.getF32FloatAttr(0x1.078e50p-3f));
 
-      Value mad5 = rewriter.create<math::FmaOp>(loc, ax, c2_0, c2_1);
-      Value mad6 = rewriter.create<math::FmaOp>(loc, ax, mad5, c2_2);
-      Value mad7 = rewriter.create<math::FmaOp>(loc, ax, mad6, c2_3);
-      Value mad8 = rewriter.create<math::FmaOp>(loc, ax, mad7, c2_4);
-      Value mad9 = rewriter.create<math::FmaOp>(loc, ax, mad8, c2_5);
-      Value mad10 = rewriter.create<math::FmaOp>(loc, ax, mad9, c2_6);
+      Value mad5 = math::FmaOp::create(rewriter, loc, ax, c2_0, c2_1);
+      Value mad6 = math::FmaOp::create(rewriter, loc, ax, mad5, c2_2);
+      Value mad7 = math::FmaOp::create(rewriter, loc, ax, mad6, c2_3);
+      Value mad8 = math::FmaOp::create(rewriter, loc, ax, mad7, c2_4);
+      Value mad9 = math::FmaOp::create(rewriter, loc, ax, mad8, c2_5);
+      Value mad10 = math::FmaOp::create(rewriter, loc, ax, mad9, c2_6);
       // In the C code, there's an extra fma(ax, p, ax) here, which seems
       // incorrect based on the standard erf approximation formula and leads to
       // values > 1. The typical approximation leads directly to the exponent
-      // term. Value p2 = rewriter.create<math::FmaOp>(loc, ax, mad10, ax); //
+      // term. Value p2 = math::FmaOp::create(rewriter, loc, ax, mad10, ax); //
       // Original line based on C code.
       Value p2 = mad10; // Corrected based on typical erf formula structure for
                         // |x| >= 1
-      Value negP2 = rewriter.create<arith::NegFOp>(loc, p2);
-      Value expNegP2 = rewriter.create<math::ExpOp>(loc, negP2);
-      Value result2 = rewriter.create<arith::SubFOp>(loc, oneF, expNegP2);
-      rewriter.create<scf::YieldOp>(loc, result2);
+      Value negP2 = arith::NegFOp::create(rewriter, loc, p2);
+      Value expNegP2 = math::ExpOp::create(rewriter, loc, negP2);
+      Value result2 = arith::SubFOp::create(rewriter, loc, oneF, expNegP2);
+      scf::YieldOp::create(rewriter, loc, result2);
     } // End else region
 
     // Set insertion point after the if.
@@ -118,7 +118,7 @@
 
     // Restore the sign: BUILTIN_COPYSIGN_F32(ret, x)
     Value finalResult =
-        rewriter.create<math::CopySignOp>(loc, ifOp.getResult(0), input);
+        math::CopySignOp::create(rewriter, loc, ifOp.getResult(0), input);
     // Replace the original op with our implementation.
     rewriter.replaceOp(op, finalResult);
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/FissionTransferOpsInControlFlow.cpp b/compiler/src/iree/compiler/Codegen/Common/FissionTransferOpsInControlFlow.cpp
index 6c72bd6..381d807 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FissionTransferOpsInControlFlow.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FissionTransferOpsInControlFlow.cpp
@@ -53,8 +53,8 @@
   auto memrefType = MemRefType::get(memrefShape, vectorType.getElementType(),
                                     AffineMap{}, privateAddrSpaceAttr);
 
-  return rewriter.create<memref::AllocaOp>(
-      loc, memrefType,
+  return memref::AllocaOp::create(
+      rewriter, loc, memrefType,
       ValueRange{getValueOrCreateConstantIndexOp(rewriter, loc, allocaSize)});
 }
 
@@ -62,10 +62,10 @@
 /// normalized into step of one in order to access the correct element from the
 /// alloca. %index = (%loop_index - %loop_lower_bound) / %loop_step
 static Value createMemrefAccessIndex(IRRewriter &rewriter, scf::ForOp forOp) {
-  auto subIOp = rewriter.create<arith::SubIOp>(
-      forOp.getLoc(), forOp.getInductionVar(), forOp.getLowerBound());
+  auto subIOp = arith::SubIOp::create(
+      rewriter, forOp.getLoc(), forOp.getInductionVar(), forOp.getLowerBound());
   auto divUIOp =
-      rewriter.create<arith::DivUIOp>(forOp.getLoc(), subIOp, forOp.getStep());
+      arith::DivUIOp::create(rewriter, forOp.getLoc(), subIOp, forOp.getStep());
   return divUIOp.getResult();
 }
 
@@ -90,7 +90,7 @@
   rewriter.setInsertionPoint(readLoop.getBody()->getTerminator());
   auto allocaIndex = createMemrefAccessIndex(rewriter, readLoop);
   auto constantZero =
-      rewriter.create<arith::ConstantIndexOp>(readLoop.getLoc(), 0);
+      arith::ConstantIndexOp::create(rewriter, readLoop.getLoc(), 0);
 
   // Store 'transfer_read' results into the corresponding 'alloca'.
   for (size_t i = 0; i < allocaOps.size(); i++) {
@@ -100,8 +100,8 @@
 
     SmallVector<Value> indices = {allocaIndex};
     indices.append(allocaOp.getType().getShape().size() - 1, constantZero);
-    rewriter.create<vector::TransferWriteOp>(readOp.getLoc(), readOp, allocaOp,
-                                             indices);
+    vector::TransferWriteOp::create(rewriter, readOp.getLoc(), readOp, allocaOp,
+                                    indices);
   }
 
   LDBG() << "Read loop: \n" << readLoop << "\n";
@@ -119,7 +119,7 @@
   rewriter.setInsertionPointToStart(writeLoop.getBody());
   auto allocaIndex = createMemrefAccessIndex(rewriter, writeLoop);
   auto constantZero =
-      rewriter.create<arith::ConstantIndexOp>(writeLoop.getLoc(), 0);
+      arith::ConstantIndexOp::create(rewriter, writeLoop.getLoc(), 0);
   for (size_t i = 0; i < allocaOps.size(); i++) {
     memref::AllocaOp allocaOp = allocaOps[i];
     auto readOp = cast<vector::TransferReadOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index 085f9d4..0523c40 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -126,7 +126,7 @@
                                           OpBuilder &builder) {
   if (type.hasStaticShape()) {
     assert(dynamicDims.empty());
-    return builder.create<arith::ConstantIndexOp>(loc, type.getNumElements());
+    return arith::ConstantIndexOp::create(builder, loc, type.getNumElements());
   }
 
   int64_t numSymbols = 0;
@@ -289,9 +289,9 @@
         MemRefType::get(staticShape, oldType.getElementType(),
                         MemRefLayoutAttrInterface(), oldType.getMemorySpace());
 
-    auto newOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto newOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-        subspanOp.getLoc(), newType, subspanOp.getLayout(),
+    auto newOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    auto newOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+        rewriter, subspanOp.getLoc(), newType, subspanOp.getLayout(),
         subspanOp.getBinding(), newOffset, dynamicShape,
         subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
 
@@ -305,9 +305,9 @@
                         {}, newType, elementOffset, linearShapeWithoutOffset,
                         stride))
               : nullptr;
-      replacement = rewriter.create<memref::SubViewOp>(
-          loc, returnType, newOp, elementOffset, linearShapeWithoutOffset,
-          OpFoldResult(rewriter.getIndexAttr(1)));
+      replacement = memref::SubViewOp::create(
+          rewriter, loc, returnType, newOp, elementOffset,
+          linearShapeWithoutOffset, OpFoldResult(rewriter.getIndexAttr(1)));
     }
 
     rewriter.replaceOp(subspanOp, replacement);
@@ -395,7 +395,8 @@
         if (ShapedType::isDynamic(shape[i])) {
           dims.push_back(dynamicDims[dynamicDimIndex++]);
         } else {
-          dims.push_back(builder.create<arith::ConstantIndexOp>(loc, shape[i]));
+          dims.push_back(
+              arith::ConstantIndexOp::create(builder, loc, shape[i]));
         }
       }
     };
@@ -409,7 +410,7 @@
     } else {
       if (sourceType.hasStaticShape()) {
         for (int64_t dim : sourceType.getShape()) {
-          dims.push_back(builder.create<arith::ConstantIndexOp>(loc, dim));
+          dims.push_back(arith::ConstantIndexOp::create(builder, loc, dim));
         }
       } else {
         return nullptr;
@@ -424,8 +425,8 @@
 
   Value linearIndex = indices.front();
   for (int i = 1; i < indices.size(); ++i) {
-    linearIndex = builder.create<affine::AffineApplyOp>(
-        loc, mulAddMap, ValueRange{linearIndex, dims[i], indices[i]});
+    linearIndex = affine::AffineApplyOp::create(
+        builder, loc, mulAddMap, ValueRange{linearIndex, dims[i], indices[i]});
   }
   return linearIndex;
 }
@@ -451,9 +452,9 @@
         rewriter, op.getLoc(), op.getMixedOffsets());
     Value linearOffset =
         linearizeIndices(op.getSource(), offsets, op.getLoc(), rewriter);
-    Value stride = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
-    Value newSubView = rewriter.create<memref::SubViewOp>(
-        op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
+    Value stride = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
+    Value newSubView = memref::SubViewOp::create(
+        rewriter, op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
         ValueRange({size}), ValueRange({stride}));
     rewriter.replaceOpWithNewOp<memref::CastOp>(op, neededResultType,
                                                 newSubView);
diff --git a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp
index 3e99579..cf8c908 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefs.cpp
@@ -57,7 +57,7 @@
   }
   if (auto constant = dyn_cast<AffineConstantExpr>(result))
     return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
-  return builder.create<affine::AffineApplyOp>(loc, result, dynamicPart)
+  return affine::AffineApplyOp::create(builder, loc, result, dynamicPart)
       .getResult();
 }
 
@@ -74,7 +74,7 @@
     OpBuilder::InsertionGuard g(rewriter);
     setInsertionPointToStart(rewriter, source);
     newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
   }
 
   auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
@@ -127,8 +127,8 @@
 static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                       OpFoldResult in) {
   if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
-    return rewriter.create<arith::ConstantIndexOp>(
-        loc, cast<IntegerAttr>(offsetAttr).getInt());
+    return arith::ConstantIndexOp::create(
+        rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
   }
   return cast<Value>(in);
 }
@@ -144,8 +144,8 @@
                               getAsOpFoldResult(indices));
 
   return std::make_pair(
-      rewriter.create<memref::ReinterpretCastOp>(
-          loc, source,
+      memref::ReinterpretCastOp::create(
+          rewriter, loc, source,
           /* offset = */ offset,
           /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
           /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
@@ -190,45 +190,47 @@
 static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                       Value offset) {
   if constexpr (std::is_same_v<T, memref::LoadOp>) {
-    auto newLoad = rewriter.create<memref::LoadOp>(
-        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+    auto newLoad =
+        memref::LoadOp::create(rewriter, op->getLoc(), op->getResultTypes(),
+                               flatMemref, ValueRange{offset});
     newLoad->setAttrs(op->getAttrs());
     rewriter.replaceOp(op, newLoad.getResult());
   } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
-    auto newLoad = rewriter.create<vector::LoadOp>(
-        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+    auto newLoad =
+        vector::LoadOp::create(rewriter, op->getLoc(), op->getResultTypes(),
+                               flatMemref, ValueRange{offset});
     newLoad->setAttrs(op->getAttrs());
     rewriter.replaceOp(op, newLoad.getResult());
   } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
-    auto newStore = rewriter.create<memref::StoreOp>(
-        op->getLoc(), op->getOperands().front(), flatMemref,
-        ValueRange{offset});
+    auto newStore = memref::StoreOp::create(rewriter, op->getLoc(),
+                                            op->getOperands().front(),
+                                            flatMemref, ValueRange{offset});
     newStore->setAttrs(op->getAttrs());
     rewriter.replaceOp(op, newStore);
   } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
-    auto newStore = rewriter.create<vector::StoreOp>(
-        op->getLoc(), op->getOperands().front(), flatMemref,
-        ValueRange{offset});
+    auto newStore = vector::StoreOp::create(rewriter, op->getLoc(),
+                                            op->getOperands().front(),
+                                            flatMemref, ValueRange{offset});
     newStore->setAttrs(op->getAttrs());
     rewriter.replaceOp(op, newStore);
   } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
-    auto newTransferRead = rewriter.create<vector::TransferReadOp>(
-        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+    auto newTransferRead = vector::TransferReadOp::create(
+        rewriter, op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
         op.getPadding());
     rewriter.replaceOp(op, newTransferRead.getResult());
   } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
-    auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
-        op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
+    auto newTransferWrite = vector::TransferWriteOp::create(
+        rewriter, op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
     rewriter.replaceOp(op, newTransferWrite);
   } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
-    auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
-        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+    auto newMaskedLoad = vector::MaskedLoadOp::create(
+        rewriter, op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
         op.getMask(), op.getPassThru());
     newMaskedLoad->setAttrs(op->getAttrs());
     rewriter.replaceOp(op, newMaskedLoad.getResult());
   } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
-    auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
-        op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
+    auto newMaskedStore = vector::MaskedStoreOp::create(
+        rewriter, op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
         op.getValueToStore());
     newMaskedStore->setAttrs(op->getAttrs());
     rewriter.replaceOp(op, newMaskedStore);
diff --git a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
index c9987fe..ae22df7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/FoldTensorSubsetIntoVectorTransferOps.cpp
@@ -105,10 +105,10 @@
     for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
       OpFoldResult offset =
           extractOp.getMixedOffsets()[it.index() + rankReduced];
-      newIndices.push_back(rewriter.create<arith::AddIOp>(
-          xferOp->getLoc(), it.value(),
-          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
-                                          offset)));
+      newIndices.push_back(
+          arith::AddIOp::create(rewriter, xferOp->getLoc(), it.value(),
+                                getValueOrCreateConstantIndexOp(
+                                    rewriter, extractOp.getLoc(), offset)));
     }
     SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
@@ -316,8 +316,8 @@
 
     Location loc = extractSliceOp.getLoc();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
-    auto init = rewriter.create<tensor::EmptyOp>(
-        loc, mixedSizes, extractSliceOp.getType().getElementType());
+    auto init = tensor::EmptyOp::create(
+        rewriter, loc, mixedSizes, extractSliceOp.getType().getElementType());
 
     auto indices = xferOp.getIndices();
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
index 8bff2e1..3daeee0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ForOpCanonicalizationPass.cpp
@@ -161,9 +161,9 @@
     if (iteratorFolded.empty())
       return failure();
 
-    auto newLoop = rewriter.create<scf::ForOp>(
-        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), initArgs);
+    auto newLoop =
+        scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                           forOp.getUpperBound(), forOp.getStep(), initArgs);
 
     SmallVector<Value> newReturnVals = transferBody(
         forOp.getBody(), newLoop.getBody(), returnValues, rewriter);
@@ -238,16 +238,16 @@
     for (auto [index, castType, targetType] :
          llvm::zip_equal(ivIndices, castTypes, targetTypes)) {
       Value oldValue = ivInitValues[index];
-      Value shapeCast = rewriter.create<vector::ShapeCastOp>(
-          oldValue.getLoc(), castType, oldValue);
-      ivInitValues[index] = rewriter.create<vector::BitCastOp>(
-          oldValue.getLoc(), targetType, shapeCast);
+      Value shapeCast = vector::ShapeCastOp::create(rewriter, oldValue.getLoc(),
+                                                    castType, oldValue);
+      ivInitValues[index] = vector::BitCastOp::create(
+          rewriter, oldValue.getLoc(), targetType, shapeCast);
     }
 
     // Create a new loop with the casted init values. This also creates
     // induction variables with proper type.
-    auto newLoop = rewriter.create<scf::ForOp>(
-        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+    auto newLoop = scf::ForOp::create(
+        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), ivInitValues);
 
     // Move all operations to the new for op. This also replaces block
@@ -261,9 +261,9 @@
          llvm::zip_equal(ivIndices, castTypes, ivTypes)) {
       Value newIv = newLoop.getRegionIterArgs()[index];
       auto bitcastOp =
-          rewriter.create<vector::BitCastOp>(newIv.getLoc(), castType, newIv);
-      auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
-          newIv.getLoc(), origType, bitcastOp);
+          vector::BitCastOp::create(rewriter, newIv.getLoc(), castType, newIv);
+      auto shapeCastOp = vector::ShapeCastOp::create(rewriter, newIv.getLoc(),
+                                                     origType, bitcastOp);
       // Replace all uses of the new induction variable with a bitcast. We need
       // to exclude the bitcast op itself given it also uses the induction
       // variable.
@@ -279,10 +279,10 @@
     for (auto [index, castType, targetType] :
          llvm::zip_equal(ivIndices, castTypes, targetTypes)) {
       Value oldRet = ivRetValues[index];
-      Value shapeCast = rewriter.create<vector::ShapeCastOp>(oldRet.getLoc(),
-                                                             castType, oldRet);
-      ivRetValues[index] = rewriter.create<vector::BitCastOp>(
-          oldRet.getLoc(), targetType, shapeCast);
+      Value shapeCast = vector::ShapeCastOp::create(rewriter, oldRet.getLoc(),
+                                                    castType, oldRet);
+      ivRetValues[index] = vector::BitCastOp::create(rewriter, oldRet.getLoc(),
+                                                     targetType, shapeCast);
     }
     yieldOp->setOperands(ivRetValues);
 
@@ -295,10 +295,10 @@
     for (auto [index, castType, origType] :
          llvm::zip_equal(ivIndices, castTypes, ivTypes)) {
       Value oldRet = forRetValues[index];
-      Value bitCast =
-          rewriter.create<vector::BitCastOp>(oldRet.getLoc(), castType, oldRet);
-      forRetValues[index] = rewriter.create<vector::ShapeCastOp>(
-          oldRet.getLoc(), origType, bitCast);
+      Value bitCast = vector::BitCastOp::create(rewriter, oldRet.getLoc(),
+                                                castType, oldRet);
+      forRetValues[index] = vector::ShapeCastOp::create(
+          rewriter, oldRet.getLoc(), origType, bitCast);
     }
 
     rewriter.replaceOp(forOp, forRetValues);
diff --git a/compiler/src/iree/compiler/Codegen/Common/ForallToFor.cpp b/compiler/src/iree/compiler/Codegen/Common/ForallToFor.cpp
index 722f9e9..3863176 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ForallToFor.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ForallToFor.cpp
@@ -70,10 +70,10 @@
       auto strides =
           llvm::map_to_vector(parallelInsert.getStrides(),
                               [&](Value v) { return map.lookupOrDefault(v); });
-      auto insertSlice = builder.create<tensor::InsertSliceOp>(
-          parallelInsert.getLoc(), source, dest, offsets, sizes, strides,
-          parallelInsert.getStaticOffsets(), parallelInsert.getStaticSizes(),
-          parallelInsert.getStaticStrides());
+      auto insertSlice = tensor::InsertSliceOp::create(
+          builder, parallelInsert.getLoc(), source, dest, offsets, sizes,
+          strides, parallelInsert.getStaticOffsets(),
+          parallelInsert.getStaticSizes(), parallelInsert.getStaticStrides());
       yieldedValues.push_back(insertSlice.getResult());
     }
     return yieldedValues;
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index 4f90260..f8f2610 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -148,8 +148,9 @@
     // Create a zero vector with the full distributed vector shape for
     // accumulating unrolled contraction results.
     auto tileType = VectorType::get(distShape, resultType.getElementType());
-    Value zero = rewriter.create<arith::ConstantOp>(
-        contractOp.getLoc(), tileType, rewriter.getZeroAttr(tileType));
+    Value zero =
+        arith::ConstantOp::create(rewriter, contractOp.getLoc(), tileType,
+                                  rewriter.getZeroAttr(tileType));
     VectorValue finalTile = cast<VectorValue>(zero);
     LLVM_DEBUG(llvm::dbgs() << "init tile: " << finalTile << "\n");
 
@@ -195,7 +196,7 @@
 
       // Get the slice of the accumulator in this batch.
       Value accSlice =
-          rewriter.create<vector::ExtractOp>(loc, acc, resultBatchOffsets);
+          vector::ExtractOp::create(rewriter, loc, acc, resultBatchOffsets);
 
       // Get the k batch size for LHS and RHS vector.
       std::optional<int64_t> kBatch =
@@ -221,14 +222,14 @@
                << llvm::interleaved_array(rhsBatchOffsets);
 
         Value lhsSlice =
-            rewriter.create<vector::ExtractOp>(loc, lhs, lhsBatchOffsets);
+            vector::ExtractOp::create(rewriter, loc, lhs, lhsBatchOffsets);
         Value rhsSlice =
-            rewriter.create<vector::ExtractOp>(loc, rhs, rhsBatchOffsets);
+            vector::ExtractOp::create(rewriter, loc, rhs, rhsBatchOffsets);
         accSlice =
             computeMMA(rewriter, loc, mmaKind, lhsSlice, rhsSlice, accSlice);
       }
-      finalTile = rewriter.create<vector::InsertOp>(loc, accSlice, finalTile,
-                                                    resultBatchOffsets);
+      finalTile = vector::InsertOp::create(rewriter, loc, accSlice, finalTile,
+                                           resultBatchOffsets);
     }
 
     replaceOpWithDistributedValues(rewriter, contractOp, finalTile);
@@ -286,18 +287,18 @@
     // Get the storage vector types that each thread is in charge of.
     auto [aVectorType, bVectorType, cVectorType] = mmaKind.getABCVectorTypes();
     Value aCast =
-        builder.create<vector::ShapeCastOp>(a.getLoc(), aVectorType, a);
+        vector::ShapeCastOp::create(builder, a.getLoc(), aVectorType, a);
     Value bCast =
-        builder.create<vector::ShapeCastOp>(b.getLoc(), bVectorType, b);
+        vector::ShapeCastOp::create(builder, b.getLoc(), bVectorType, b);
     Value cCast =
-        builder.create<vector::ShapeCastOp>(c.getLoc(), cVectorType, c);
+        vector::ShapeCastOp::create(builder, c.getLoc(), cVectorType, c);
     SmallVector<Value> results;
     [[maybe_unused]] LogicalResult createdMmaOp =
         mmaKind.buildUnderlyingOperations(builder, loc, {aCast, bCast}, {cCast},
                                           results);
     assert(succeeded(createdMmaOp) && "Should never fail to construct mma op");
-    return builder.create<vector::ShapeCastOp>(c.getLoc(), c.getType(),
-                                               results[0]);
+    return vector::ShapeCastOp::create(builder, c.getLoc(), c.getType(),
+                                       results[0]);
   }
 };
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/DecomposeHorizontallyFusedGemms.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/DecomposeHorizontallyFusedGemms.cpp
index 485629d..a2a388b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/DecomposeHorizontallyFusedGemms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/DecomposeHorizontallyFusedGemms.cpp
@@ -147,9 +147,9 @@
         inputs, [](OpOperand *operand) { return operand->get(); });
     SmallVector<Value> initVals = llvm::map_to_vector(
         inits, [](OpOperand *operand) { return operand->get(); });
-    auto newOp = rewriter.create<linalg::GenericOp>(
-        linalgOp.getLoc(), TypeRange{inits[0]->get().getType()}, inputVals,
-        initVals, indexingMaps, iteratorTypes,
+    auto newOp = linalg::GenericOp::create(
+        rewriter, linalgOp.getLoc(), TypeRange{inits[0]->get().getType()},
+        inputVals, initVals, indexingMaps, iteratorTypes,
         [&](OpBuilder &b, Location loc, ValueRange blockArgs) {
           Block *oldBody = linalgOp.getBlock();
           usedInputs.insert(resultNumber + linalgOp.getNumDpsInputs());
@@ -166,7 +166,7 @@
             b.clone(*usedOperation, regionMapping);
           }
 
-          b.create<linalg::YieldOp>(loc, regionMapping.lookup(result));
+          linalg::YieldOp::create(b, loc, regionMapping.lookup(result));
         });
 
     // If on decomposition any dims are unused, propagating lowering config isn't
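
The rewrite looks the same inside region-builder callbacks, as in the `linalg::GenericOp` hunk above: the callback's own builder `b` (not the pattern's `rewriter`) becomes the first argument. A toy sketch of that shape, again with stand-in types rather than the MLIR classes:

  #include <functional>
  #include <iostream>

  struct Builder {
    template <typename OpTy, typename... Args>
    OpTy create(Args &&...args) {
      return OpTy{args...};
    }
  };

  struct YieldOp {
    int value;
    static YieldOp create(Builder &b, int v) { return b.create<YieldOp>(v); }
  };

  // Stand-in for an op builder that takes a body-building callback, the
  // way linalg::GenericOp's builder takes a region builder.
  void buildWithBody(const std::function<void(Builder &, int)> &bodyBuilder) {
    Builder inner; // real MLIR positions one builder inside the region
    bodyBuilder(inner, 42);
  }

  int main() {
    buildWithBody([](Builder &b, int v) {
      // The callback's builder leads the argument list.
      YieldOp y = YieldOp::create(b, v);
      std::cout << y.value << "\n"; // prints 42
    });
  }
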
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyPaddingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyPaddingLevel.cpp
index c7547e7..c65c7ae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyPaddingLevel.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyPaddingLevel.cpp
@@ -247,8 +247,8 @@
       assert(succeeded(reductionDimInfo) &&
              "obtained with confirmation earlier");
       for (auto &&dimInfo : reductionDimInfo.value()) {
-        Value redDimSize = rewriter.create<tensor::DimOp>(
-            paddedOp.getLoc(), dimInfo.operand, dimInfo.operandDim);
+        Value redDimSize = tensor::DimOp::create(
+            rewriter, paddedOp.getLoc(), dimInfo.operand, dimInfo.operandDim);
         reductionDimSizes.push_back({dimInfo.loopIndex, redDimSize});
       }
 
@@ -264,12 +264,13 @@
                                            redDimIndex, redDimSize);
         conds.push_back(cond);
       }
-      Value reductionIdentityValue = rewriter.create<arith::ConstantOp>(
-          paddedOp.getLoc(), reductionIdentity.value());
+      Value reductionIdentityValue = arith::ConstantOp::create(
+          rewriter, paddedOp.getLoc(), reductionIdentity.value());
       assert(conds.size() > 0);
       Value cond = conds[0];
       for (Value nxtCond : llvm::drop_begin(conds, 1)) {
-        cond = rewriter.create<arith::AndIOp>(paddedOp.getLoc(), cond, nxtCond);
+        cond =
+            arith::AndIOp::create(rewriter, paddedOp.getLoc(), cond, nxtCond);
       }
 
       // Find the reduction op operand that is reduced with the carried output.
@@ -310,10 +311,10 @@
         sizes[i] = getAsOpFoldResult(v);
     }
 
-    Value out = rewriter.create<tensor::EmptyOp>(
-        paddedOp.getLoc(), sizes, getElementTypeOrSelf(tensorTy));
-    auto copied = rewriter.create<linalg::CopyOp>(paddedOp.getLoc(),
-                                                  padOp.getResult(), out);
+    Value out = tensor::EmptyOp::create(rewriter, paddedOp.getLoc(), sizes,
+                                        getElementTypeOrSelf(tensorTy));
+    auto copied = linalg::CopyOp::create(rewriter, paddedOp.getLoc(),
+                                         padOp.getResult(), out);
     rewriter.replaceUsesWithIf(padOp.getResult(), copied.getResult(0),
                                [&](OpOperand &opOperand) {
                                  return users.contains(opOperand.getOwner());
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUBubbleResourceCasts.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUBubbleResourceCasts.cpp
index b27b486..315fdc8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUBubbleResourceCasts.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUBubbleResourceCasts.cpp
@@ -78,8 +78,9 @@
               }
 
               rewriter.setInsertionPoint(extract);
-              auto newCast = rewriter.create<IREE::GPU::BufferResourceCastOp>(
-                  loc, extract.getSource().getType(), extract.getSource());
+              auto newCast = IREE::GPU::BufferResourceCastOp::create(
+                  rewriter, loc, extract.getSource().getType(),
+                  extract.getSource());
               extract.getSourceMutable().assign(newCast);
               return true;
             })
@@ -89,8 +90,8 @@
               }
 
               rewriter.setInsertionPoint(expand);
-              auto newCast = rewriter.create<IREE::GPU::BufferResourceCastOp>(
-                  loc, expand.getSrcType(), expand.getSrc());
+              auto newCast = IREE::GPU::BufferResourceCastOp::create(
+                  rewriter, loc, expand.getSrcType(), expand.getSrc());
               expand.getSrcMutable().assign(newCast);
               return true;
             })
@@ -101,9 +102,8 @@
                   }
 
                   rewriter.setInsertionPoint(collapse);
-                  auto newCast =
-                      rewriter.create<IREE::GPU::BufferResourceCastOp>(
-                          loc, collapse.getSrcType(), collapse.getSrc());
+                  auto newCast = IREE::GPU::BufferResourceCastOp::create(
+                      rewriter, loc, collapse.getSrcType(), collapse.getSrc());
                   collapse.getSrcMutable().assign(newCast);
                   return true;
                 })
@@ -113,8 +113,8 @@
               }
 
               rewriter.setInsertionPoint(pad);
-              auto newCast = rewriter.create<IREE::GPU::BufferResourceCastOp>(
-                  loc, pad.getSourceType(), pad.getSource());
+              auto newCast = IREE::GPU::BufferResourceCastOp::create(
+                  rewriter, loc, pad.getSourceType(), pad.getSource());
               pad.getSourceMutable().assign(newCast);
               return true;
             })
@@ -126,8 +126,9 @@
               rewriter.setInsertionPoint(linalgOp);
               // Only propagate to input operands.
               for (auto inputOperand : linalgOp.getDpsInputOperands()) {
-                auto newCast = rewriter.create<IREE::GPU::BufferResourceCastOp>(
-                    loc, inputOperand->get().getType(), inputOperand->get());
+                auto newCast = IREE::GPU::BufferResourceCastOp::create(
+                    rewriter, loc, inputOperand->get().getType(),
+                    inputOperand->get());
                 inputOperand->assign(newCast);
               }
               return true;
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineValueBarriers.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineValueBarriers.cpp
index 9e17a65..5094f8f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineValueBarriers.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCombineValueBarriers.cpp
@@ -81,7 +81,7 @@
                            barrierOp.getInputs().end());
   }
   auto combinedBarrierOp =
-      rewriter.create<IREE::GPU::ValueBarrierOp>(loc, barrierOperands);
+      IREE::GPU::ValueBarrierOp::create(rewriter, loc, barrierOperands);
 
   // Replace all uses of the previous barrier with new barrier.
   int resultNumber = 0;
@@ -194,8 +194,8 @@
   barrierOperands.append(barrierB.getOperands().begin(),
                          barrierB.getOperands().end());
 
-  auto combinedBarrierOp = rewriter.create<IREE::GPU::ValueBarrierOp>(
-      barrierB.getLoc(), barrierOperands);
+  auto combinedBarrierOp = IREE::GPU::ValueBarrierOp::create(
+      rewriter, barrierB.getLoc(), barrierOperands);
 
   int numOperandsA = barrierA.getNumOperands();
   int numOperandsB = barrierB.getNumOperands();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp
index ee66d68..efefceb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUCreateFastSlowPath.cpp
@@ -80,7 +80,7 @@
 
   // Build the condition for the scf.if op: all pad sizes are zero.
   Location loc = padOp.getLoc();
-  Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value cstZero = arith::ConstantIndexOp::create(rewriter, loc, 0);
   SmallVector<Value> eqZeroCmpVals;
   for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
     if (auto padValue = dyn_cast<Value>(pad)) {
@@ -90,14 +90,14 @@
       padSizeOps.insert(padValue.getDefiningOp());
     }
     if (!isZero(pad)) {
-      eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::eq,
+      eqZeroCmpVals.push_back(arith::CmpIOp::create(
+          rewriter, loc, arith::CmpIPredicate::eq,
           getValueOrCreateConstantIndexOp(rewriter, loc, pad), cstZero));
     }
   }
   Value ifCond = eqZeroCmpVals.front();
   for (Value cmp : llvm::ArrayRef(eqZeroCmpVals).drop_front())
-    ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
+    ifCond = arith::AndIOp::create(rewriter, loc, ifCond, cmp);
 
   SmallVector<Operation *> cloneOps;
   for (Operation *op : allOps) {
@@ -118,15 +118,15 @@
         builder.clone(*op, bvm);
       }
     }
-    builder.create<scf::YieldOp>(loc);
+    scf::YieldOp::create(builder, loc);
   };
   auto elseBuilder = [&](OpBuilder &builder, Location loc) {
     IRMapping bvm;
     for (Operation *op : cloneOps)
       builder.clone(*op, bvm);
-    builder.create<scf::YieldOp>(loc);
+    scf::YieldOp::create(builder, loc);
   };
-  rewriter.create<scf::IfOp>(padOp.getLoc(), ifCond, thenBuilder, elseBuilder);
+  scf::IfOp::create(rewriter, padOp.getLoc(), ifCond, thenBuilder, elseBuilder);
 
   // All of these ops have been cloned to both regions. Erase them now.
   for (Operation *op : llvm::reverse(cloneOps))
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
index 6a8b503..2abb7e1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
@@ -45,7 +45,7 @@
 
   // Create an early zero index value for replacements.
   Location loc = target->getLoc();
-  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
     diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp
index 85099df..956b9f2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp
@@ -36,8 +36,8 @@
                                          memref::CopyOp copy) {
   SmallVector<Attribute> mapping = {gpu::GPUThreadMappingAttr::get(
       rewriter.getContext(), gpu::MappingId::LinearDim0)};
-  scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
-      copy.getLoc(), ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
+  scf::ForallOp newForallOp = scf::ForallOp::create(
+      rewriter, copy.getLoc(), ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
       ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)},
       ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)},
       /*outputs=*/ValueRange(), /*mapping=*/rewriter.getArrayAttr(mapping));
@@ -73,8 +73,8 @@
   }
   mapping = llvm::to_vector(llvm::reverse(mapping));
 
-  scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
-      copy.getLoc(), lowerBounds, upperBounds, tileSizes,
+  scf::ForallOp newForallOp = scf::ForallOp::create(
+      rewriter, copy.getLoc(), lowerBounds, upperBounds, tileSizes,
       /*outputs=*/ValueRange(), /*mapping=*/rewriter.getArrayAttr(mapping));
 
   rewriter.setInsertionPointToStart(newForallOp.getBody());
@@ -108,10 +108,10 @@
   SmallVector<OpFoldResult> offsets =
       getAsOpFoldResult(newForallOp.getInductionVars());
   SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-  Value sourceTile = rewriter.create<memref::SubViewOp>(
-      loc, copy.getSource(), offsets, sizes, strides);
-  Value targetTile = rewriter.create<memref::SubViewOp>(
-      loc, copy.getTarget(), offsets, sizes, strides);
+  Value sourceTile = memref::SubViewOp::create(rewriter, loc, copy.getSource(),
+                                               offsets, sizes, strides);
+  Value targetTile = memref::SubViewOp::create(rewriter, loc, copy.getTarget(),
+                                               offsets, sizes, strides);
   rewriter.replaceOpWithNewOp<memref::CopyOp>(copy, sourceTile, targetTile);
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
index 63bff76..28b769d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp
@@ -122,17 +122,17 @@
   // can use a lower bound of 0 and keep the loop bounds static. This helps
   // simplify later loop folding patterns without an `affine.linearize_index` op
   // to help with inferring int ranges.
-  Value lb = perfectlyDivides ? rewriter.create<arith::ConstantIndexOp>(loc, 0)
+  Value lb = perfectlyDivides ? arith::ConstantIndexOp::create(rewriter, loc, 0)
                               : flatId;
   Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, totalLoopTripCount);
   Value step =
-      rewriter.create<arith::ConstantIndexOp>(loc, flatTotalNumWorkers);
+      arith::ConstantIndexOp::create(rewriter, loc, flatTotalNumWorkers);
   // We need to add barriers before and after the distributed loop because the
   // loop might have reads/writes to shared memory that can have a different
   // layout compared to the rest of the program.
-  rewriter.create<gpu::BarrierOp>(loc);
-  auto forLoop = rewriter.create<scf::ForOp>(loc, lb, ub, step, ValueRange{});
-  rewriter.create<gpu::BarrierOp>(loc);
+  gpu::BarrierOp::create(rewriter, loc);
+  auto forLoop = scf::ForOp::create(rewriter, loc, lb, ub, step, ValueRange{});
+  gpu::BarrierOp::create(rewriter, loc);
   Block *loopBody = forLoop.getBody();
 
   // Get the replacement IDs for the forall iterator ids.
@@ -145,8 +145,8 @@
 
   // We require a descending relative mapping, so we can reuse the upper bound
   // sizes directly.
-  auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-      loc, newFlatProducerId, delinSizes);
+  auto delinearize = affine::AffineDelinearizeIndexOp::create(
+      rewriter, loc, newFlatProducerId, delinSizes);
 
   SmallVector<Value> newBlockArgs = delinearize.getResults();
 
@@ -213,8 +213,9 @@
   SmallVector<int64_t> threadGridBasis = {workgroupSize[2], workgroupSize[1],
                                           workgroupSize[0]};
 
-  Value linearThreadIdVal = rewriter.create<affine::AffineLinearizeIndexOp>(
-      funcOp.getLoc(), threadGrid, threadGridBasis, /*disjoint=*/true);
+  Value linearThreadIdVal = affine::AffineLinearizeIndexOp::create(
+      rewriter, funcOp.getLoc(), threadGrid, threadGridBasis,
+      /*disjoint=*/true);
   for (auto forall : forallOps) {
     rewriter.setInsertionPoint(forall);
     if (failed(resolveGPUMappedForallOp(rewriter, forall, linearThreadIdVal,
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp
index 5c733c5..8edd8ec 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeScfFor.cpp
@@ -63,14 +63,14 @@
     const std::array<gpu::Dimension, 3> symDims = {
         gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
     gpu::Dimension symDim = symDims[numDimAttr.getInt()];
-    auto idOp = rewriter.create<gpu::ThreadIdOp>(loc, indexType, symDim);
-    Value count = useBlockDims
-                      ? rewriter.create<gpu::BlockDimOp>(loc, indexType, symDim)
-                            .getResult()
-                      : rewriter
-                            .create<arith::ConstantIndexOp>(
-                                loc, workgroupSize[numDimAttr.getInt()])
-                            .getResult();
+    auto idOp = gpu::ThreadIdOp::create(rewriter, loc, indexType, symDim);
+    Value count =
+        useBlockDims ? gpu::BlockDimOp::create(rewriter, loc, indexType, symDim)
+                           .getResult()
+                     : arith::ConstantIndexOp::create(
+                           rewriter, loc,
+                           workgroupSize[numDimAttr.getInt()])
+                           .getResult();
 
     MLIRContext *context = getContext();
     AffineExpr sym0, sym1, sym2;
@@ -78,11 +78,11 @@
     auto mulAddMap = AffineMap::get(0, 3, {sym0 * sym1 + sym2}, context);
     auto mulMap = AffineMap::get(0, 2, {sym0 * sym1}, context);
 
-    auto newLb = rewriter.create<affine::AffineApplyOp>(
-        loc, mulAddMap,
+    auto newLb = affine::AffineApplyOp::create(
+        rewriter, loc, mulAddMap,
         ValueRange{idOp, forOp.getStep(), forOp.getLowerBound()});
-    auto newStep = rewriter.create<affine::AffineApplyOp>(
-        loc, mulMap, ValueRange{count, forOp.getStep()});
+    auto newStep = affine::AffineApplyOp::create(
+        rewriter, loc, mulMap, ValueRange{count, forOp.getStep()});
 
     forOp.getLowerBoundMutable().assign(newLb);
     forOp.getStepMutable().assign(newStep);
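
One detail worth calling out from the ternary in the GPUDistributeScfFor.cpp hunk above: the static `create` composes in expression position exactly like the member spelling, including chained accessors such as `.getResult()`, so mixing old and new spellings across the two arms is never necessary. A toy sketch:

  #include <iostream>

  struct Builder {
    template <typename OpTy, typename... Args>
    OpTy create(Args &&...args) {
      return OpTy{args...};
    }
  };

  struct DimOp {
    int result;
    int getResult() const { return result; }
    static DimOp create(Builder &b, int v) { return b.create<DimOp>(v); }
  };

  int main() {
    Builder rewriter;
    bool useBlockDims = false;
    // Both arms chain .getResult() off the static create.
    int count = useBlockDims ? DimOp::create(rewriter, 64).getResult()
                             : DimOp::create(rewriter, 128).getResult();
    std::cout << count << "\n"; // prints 128
  }
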
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
index 8ae7b33..dfaea2f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp
@@ -77,10 +77,10 @@
         for (unsigned i = 0; i < rank - 1; i++) {
           int64_t t = (rank - i) <= kNumGPUDims ? 1 : 0;
           tileSizesVal.push_back(
-              builder.create<arith::ConstantIndexOp>(operation->getLoc(), t));
+              arith::ConstantIndexOp::create(builder, operation->getLoc(), t));
         }
-        tileSizesVal.push_back(builder.create<arith::ConstantIndexOp>(
-            operation->getLoc(), copyTileSize));
+        tileSizesVal.push_back(arith::ConstantIndexOp::create(
+            builder, operation->getLoc(), copyTileSize));
         return tileSizesVal;
       };
   auto getCopyThreadProcInfoFn =
@@ -167,8 +167,8 @@
         std::optional<SmallVector<int64_t>> staticSize =
             getTileToDistributableSize(copyOp, flatWorkgroupSize);
         for (int64_t dim : *staticSize) {
-          tileSizesVal.push_back(
-              builder.create<arith::ConstantIndexOp>(operation->getLoc(), dim));
+          tileSizesVal.push_back(arith::ConstantIndexOp::create(
+              builder, operation->getLoc(), dim));
         }
         return tileSizesVal;
       };
@@ -201,13 +201,13 @@
     delinSizes.push_back(numThreadsDim);
   }
   ValueRange dims =
-      b.create<affine::AffineDelinearizeIndexOp>(loc, flatThreadId, delinSizes)
+      affine::AffineDelinearizeIndexOp::create(b, loc, flatThreadId, delinSizes)
           .getResults();
 
   for (auto [dimId, numThreadsDim] : llvm::zip_equal(dims, delinSizes)) {
     linalg::ProcInfo info;
     info.procId = dimId;
-    info.nprocs = b.create<arith::ConstantIndexOp>(loc, numThreadsDim);
+    info.nprocs = arith::ConstantIndexOp::create(b, loc, numThreadsDim);
     info.distributionMethod =
         linalg::DistributionMethod::CyclicNumProcsEqNumIters;
     infos.push_back(info);
@@ -239,8 +239,8 @@
           return tileSizesVal;
         SmallVector<int64_t> staticSize = getNativeDstShape(copyOp);
         for (int64_t dim : staticSize) {
-          tileSizesVal.push_back(
-              builder.create<arith::ConstantIndexOp>(operation->getLoc(), dim));
+          tileSizesVal.push_back(arith::ConstantIndexOp::create(
+              builder, operation->getLoc(), dim));
         }
         return tileSizesVal;
       };
@@ -292,13 +292,13 @@
   OpBuilder b(funcOp.getFunctionBody());
   Type indexType = b.getIndexType();
   Value threadX =
-      b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::x);
+      gpu::ThreadIdOp::create(b, funcOp.getLoc(), indexType, gpu::Dimension::x);
   Value threadY =
-      b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::y);
+      gpu::ThreadIdOp::create(b, funcOp.getLoc(), indexType, gpu::Dimension::y);
   Value threadZ =
-      b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::z);
-  Value flatThreadId = b.create<affine::AffineLinearizeIndexOp>(
-      funcOp.getLoc(), ValueRange{threadZ, threadY, threadX},
+      gpu::ThreadIdOp::create(b, funcOp.getLoc(), indexType, gpu::Dimension::z);
+  Value flatThreadId = affine::AffineLinearizeIndexOp::create(
+      b, funcOp.getLoc(), ValueRange{threadZ, threadY, threadX},
       ArrayRef<int64_t>{workgroupSize[2], workgroupSize[1], workgroupSize[0]},
       /*disjoint=*/true);
   return flatThreadId;
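
Several of these hunks flatten thread IDs with affine.linearize_index and split them back with affine.delinearize_index; the disjoint flag asserts each index stays below its basis entry. For readers new to those ops, a sketch of the row-major arithmetic they implement (the basis values below are hypothetical workgroup sizes):

  #include <array>
  #include <iostream>

  // Matches affine.linearize_index: the last index is fastest-varying.
  int linearize(std::array<int, 3> idx, std::array<int, 3> basis) {
    int flat = 0;
    for (int i = 0; i < 3; ++i)
      flat = flat * basis[i] + idx[i];
    return flat;
  }

  // Inverse of the above, matching affine.delinearize_index.
  std::array<int, 3> delinearize(int flat, std::array<int, 3> basis) {
    std::array<int, 3> idx{};
    for (int i = 2; i >= 0; --i) {
      idx[i] = flat % basis[i];
      flat /= basis[i];
    }
    return idx;
  }

  int main() {
    std::array<int, 3> basis = {4, 8, 32};  // stand-in {z, y, x} sizes
    int flat = linearize({1, 2, 3}, basis); // 1*8*32 + 2*32 + 3 = 323
    std::cout << flat << "\n";
    auto back = delinearize(flat, basis);
    std::cout << back[0] << " " << back[1] << " " << back[2] << "\n"; // 1 2 3
  }
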
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
index 5365a5a..3cdb3aa 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributionPatterns.cpp
@@ -45,8 +45,8 @@
     Type elementType = constant.getType().getElementType();
     auto vectorType =
         VectorType::get(layout.getDistributedShape(), elementType);
-    auto distributedOp = rewriter.create<arith::ConstantOp>(
-        constantOp.getLoc(), vectorType,
+    auto distributedOp = arith::ConstantOp::create(
+        rewriter, constantOp.getLoc(), vectorType,
         SplatElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()));
     replaceOpWithDistributedValues(rewriter, constantOp,
                                    distributedOp->getResult(0));
@@ -176,9 +176,9 @@
       newInitArgs.push_back(initArg);
     }
 
-    auto newForOp = rewriter.create<scf::ForOp>(
-        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newInitArgs);
+    auto newForOp =
+        scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                           forOp.getUpperBound(), forOp.getStep(), newInitArgs);
     newForOp->setAttrs(forOp->getAttrs());
     Block *loopBody = newForOp.getBody();
 
@@ -225,7 +225,7 @@
     // Since this operation has no results, we can directly replace it using
     // the standard API.
     auto distributedYieldOp =
-        rewriter.create<scf::YieldOp>(yieldOp.getLoc(), operands);
+        scf::YieldOp::create(rewriter, yieldOp.getLoc(), operands);
     rewriter.replaceOp(yieldOp, distributedYieldOp);
     return success();
   }
@@ -240,8 +240,8 @@
     for (auto [bbArg, oldInit] : llvm::zip_equal(bbArgs, oldInits)) {
       Value val = bbArg;
       if (auto oldVectorInit = dyn_cast<VectorValue>(oldInit)) {
-        val = rewriter.create<IREE::VectorExt::ToSIMDOp>(
-            oldVectorInit.getLoc(), oldVectorInit.getType(), val);
+        val = IREE::VectorExt::ToSIMDOp::create(
+            rewriter, oldVectorInit.getLoc(), oldVectorInit.getType(), val);
       }
       replacements.push_back(val);
     }
@@ -316,8 +316,8 @@
     VectorType distributedType = VectorType::get(distributedShape, elementType);
 
     // Simply distribute all operands and results.
-    VectorValue distributed = rewriter.create<vector::GatherOp>(
-        gatherOp.getLoc(), distributedType, gatherOp.getBase(),
+    VectorValue distributed = vector::GatherOp::create(
+        rewriter, gatherOp.getLoc(), distributedType, gatherOp.getBase(),
         gatherOp.getOffsets(),
         getDistributed(rewriter, indexVec, indicesLayout),
         getDistributed(rewriter, mask, maskLayout),
@@ -344,9 +344,9 @@
     VectorValue source = extractOp.getVector();
     VectorLayoutInterface sourceLayout = signature[source];
 
-    Value distributed = rewriter.create<vector::ExtractOp>(
-        extractOp.getLoc(), getDistributed(rewriter, source, sourceLayout),
-        ArrayRef<int64_t>{});
+    Value distributed = vector::ExtractOp::create(
+        rewriter, extractOp.getLoc(),
+        getDistributed(rewriter, source, sourceLayout), ArrayRef<int64_t>{});
 
     replaceOpWithDistributedValues(rewriter, extractOp, distributed);
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToGlobalLoads.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToGlobalLoads.cpp
index 4dee661..ba2199f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToGlobalLoads.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToGlobalLoads.cpp
@@ -95,37 +95,37 @@
         copy, "Cannot proceed: cannot handle copying residual elements.");
   }
 
-  Value subgroupId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
-  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, nullptr);
+  Value subgroupId = gpu::SubgroupIdOp::create(rewriter, loc, nullptr);
+  Value laneId = gpu::LaneIdOp::create(rewriter, loc, nullptr);
 
   auto sourceType = cast<MemRefType>(copy.getOperand(0).getType());
   auto localType = cast<MemRefType>(copy.getOutputs().front().getType());
 
   auto getGlobalGatherIndex = [&](Value sgIdVal, Value lIdVal,
                                   Value indVar) -> Value {
-    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    return rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, ValueRange{sgIdVal, indVar, lIdVal, zero},
+    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    return affine::AffineLinearizeIndexOp::create(
+        rewriter, loc, ValueRange{sgIdVal, indVar, lIdVal, zero},
         ArrayRef<int64_t>{numSubgroups, numCopiesPerThread, subgroupSize,
                           elementsPerCopy},
         /*disjoint=*/true);
   };
 
   auto getSubgroupStoreBaseIndex = [&](Value sgIdVal, Value indVar) -> Value {
-    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
     return getGlobalGatherIndex(sgIdVal, zero, indVar);
   };
 
   // Build a for loop skeleton:
-  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
   auto upperBound =
-      rewriter.create<arith::ConstantIndexOp>(loc, numCopiesPerThread);
-  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      arith::ConstantIndexOp::create(rewriter, loc, numCopiesPerThread);
+  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
   scf::ForOp forOp =
-      rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+      scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
 
   auto delinearizeIndex = [&](Value index, ArrayRef<int64_t> shape) {
-    return rewriter.create<affine::AffineDelinearizeIndexOp>(loc, index, shape)
+    return affine::AffineDelinearizeIndexOp::create(rewriter, loc, index, shape)
         .getMultiIndex();
   };
 
@@ -142,8 +142,8 @@
         getSubgroupStoreBaseIndex(subgroupId, inductionVar);
     ValueRange delinearizedLocalIndices =
         delinearizeIndex(linearizedBaseIndices, localType.getShape());
-    rewriter.create<IREE::GPU::GlobalLoadDMAOp>(
-        loc, copy.getOperand(0), delinearizedGlobalIndices,
+    IREE::GPU::GlobalLoadDMAOp::create(
+        rewriter, loc, copy.getOperand(0), delinearizedGlobalIndices,
         copy.getOutputs()[0], delinearizedLocalIndices);
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
index 9118017..3339216 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
@@ -56,7 +56,7 @@
   // Tiling argmax ukernel is also set to enforce this structure.
   const int kReductionDim = op.getNumLoops() - 1;
   Value reductionDimSize =
-      rewriter.create<tensor::DimOp>(loc, input, kReductionDim);
+      tensor::DimOp::create(rewriter, loc, input, kReductionDim);
   bool isPureArgmax = op.getResults()[0].use_empty();
   StringRef kernelName = ukernelAttr.getName();
   SmallVector<Type> resultTypes;
@@ -65,11 +65,11 @@
   Type valType = val.getType();
   outputs = {val, index};
   resultTypes = {valType, indexType};
-  Value writeMaxValueFlag = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI1Type(), rewriter.getBoolAttr(!isPureArgmax));
+  Value writeMaxValueFlag = arith::ConstantOp::create(
+      rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(!isPureArgmax));
 
-  auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
-      loc, resultTypes, kernelName, ValueRange{input}, outputs,
+  auto genericMicroKernelOp = IREE::Codegen::UKernelGenericOp::create(
+      rewriter, loc, resultTypes, kernelName, ValueRange{input}, outputs,
       ValueRange{reductionDimSize, writeMaxValueFlag},
       ukernelAttr.getDefAttrs(), /*num_strided_outer_dims=*/0);
   return cast<IREE::Codegen::UKernelOpInterface>(
@@ -122,10 +122,10 @@
   if (!sharedMemoryBytes) {
     IREE::Codegen::NullPointerType nullPointerType =
         IREE::Codegen::NullPointerType::get(rewriter.getContext());
-    return rewriter.create<IREE::Codegen::NullPointerOp>(loc, nullPointerType);
+    return IREE::Codegen::NullPointerOp::create(rewriter, loc, nullPointerType);
   }
   auto allocOp =
-      rewriter.create<bufferization::AllocTensorOp>(loc, tensorType, dynSizes);
+      bufferization::AllocTensorOp::create(rewriter, loc, tensorType, dynSizes);
   Attribute sharedAddrSpace = gpu::AddressSpaceAttr::get(
       rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
   allocOp.setMemorySpaceAttr(sharedAddrSpace);
@@ -195,15 +195,15 @@
     Location loc = op->getLoc();
     Type I32Type = rewriter.getI32Type();
     auto castIndexToI32 = [&](Value val) {
-      return rewriter.create<arith::IndexCastOp>(loc, I32Type, val);
+      return arith::IndexCastOp::create(rewriter, loc, I32Type, val);
     };
     auto constI32 = [&](int val) {
-      return rewriter.create<arith::ConstantIntOp>(loc, I32Type, val);
+      return arith::ConstantIntOp::create(rewriter, loc, I32Type, val);
     };
     int64_t sharedMemoryBytes = ukernelAttr.getSharedMemoryBytes();
     auto sharedMemory = createSharedMemory(rewriter, loc, sharedMemoryBytes);
     Value k = castIndexToI32(
-        rewriter.create<tensor::DimOp>(op.getLoc(), op.getInputs()[0], 1));
+        tensor::DimOp::create(rewriter, op.getLoc(), op.getInputs()[0], 1));
     Value intrinsicsM = constI32(mma.getIntrinsicsM());
     Value subgroupsM = constI32(mma.getSubgroupsM());
     Value intrinsicsN = constI32(mma.getIntrinsicsN());
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index e866f23..e07904a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -66,8 +66,8 @@
     int64_t elementCount = vectorLayout.getElementTile()[i];
     Location loc = offset.getLoc();
     SmallVector<Value> ids = {
-        warpIndices[i], b.create<arith::ConstantIndexOp>(loc, batchOffsets[i]),
-        b.create<arith::ConstantIndexOp>(loc, outerVectorOffsets[i]),
+        warpIndices[i], arith::ConstantIndexOp::create(b, loc, batchOffsets[i]),
+        arith::ConstantIndexOp::create(b, loc, outerVectorOffsets[i]),
         threadIndices[i], offset};
     // The order in which a vector dimension is "tiled" is
     // subgroups -> batches -> outer vectors -> threads -> elements
@@ -83,7 +83,7 @@
     if (std::optional<int64_t> offsetConst = getConstantIntValue(offset))
       disjoint = *offsetConst < elementCount;
     slicedIndices[pos] =
-        b.create<affine::AffineLinearizeIndexOp>(loc, ids, sizes, disjoint);
+        affine::AffineLinearizeIndexOp::create(b, loc, ids, sizes, disjoint);
   }
   return slicedIndices;
 }
@@ -141,7 +141,7 @@
   VectorType interleavedPackedType =
       VectorType::get(interleavedPackedShape, val.getType().getElementType());
   VectorValue interleavedPackedShaped =
-      rewriter.create<vector::ShapeCastOp>(loc, interleavedPackedType, val);
+      vector::ShapeCastOp::create(rewriter, loc, interleavedPackedType, val);
 
   // 0 1 2 3 4 5 ---> 0 2 4 1 3 5
   SmallVector<int64_t> perm;
@@ -151,8 +151,8 @@
       perm.push_back(tileGroupIdx * layout.getRank() + undistributedDim);
     }
   }
-  return rewriter.create<vector::TransposeOp>(loc, interleavedPackedShaped,
-                                              perm);
+  return vector::TransposeOp::create(rewriter, loc, interleavedPackedShaped,
+                                     perm);
 }
 
 /// Given a distributed vector that has B1XB2xO1XO2xE1XE2,
@@ -175,8 +175,8 @@
   }
   VectorType unpackedType = VectorType::get(
       unpackedShape, deinterleavedPacked.getType().getElementType());
-  return rewriter.create<vector::ShapeCastOp>(loc, unpackedType,
-                                              deinterleavedPacked);
+  return vector::ShapeCastOp::create(rewriter, loc, unpackedType,
+                                     deinterleavedPacked);
 }
 
 /// Given a distributed vector that has [B1xO1xE1]x[B2xO2xE2],
@@ -197,7 +197,7 @@
   VectorType nonInterleavedPackedType = VectorType::get(
       nonInterleavedPackedShape, val.getType().getElementType());
   VectorValue nonInterleavedPackedShaped =
-      rewriter.create<vector::ShapeCastOp>(loc, nonInterleavedPackedType, val);
+      vector::ShapeCastOp::create(rewriter, loc, nonInterleavedPackedType, val);
   // 0 1 2 3 4 5 ---> 0 3 1 4 2 5
   SmallVector<int64_t> perm;
   perm.reserve(layout.getRank() * 3);
@@ -206,8 +206,8 @@
       perm.push_back(tileGroupIdx + 3 * undistributedDim);
     }
   }
-  return rewriter.create<vector::TransposeOp>(loc, nonInterleavedPackedShaped,
-                                              perm);
+  return vector::TransposeOp::create(rewriter, loc, nonInterleavedPackedShaped,
+                                     perm);
 }
 
 /// Computes the warp and thread indices for the given vector layout from a
@@ -237,8 +237,9 @@
   SmallVector<int64_t> sliceMaskOffsets =
       getDistributedTransferOffsetsFromNestedLayout(offsets, vectorLayout);
   SmallVector<int64_t> strides(vectorLayout.getElementTile().size(), 1);
-  VectorValue slicedMask = rewriter.create<vector::ExtractStridedSliceOp>(
-      loc, mask, sliceMaskOffsets, vectorLayout.getElementTile(), strides);
+  VectorValue slicedMask = vector::ExtractStridedSliceOp::create(
+      rewriter, loc, mask, sliceMaskOffsets, vectorLayout.getElementTile(),
+      strides);
   return slicedMask;
 }
 
@@ -249,8 +250,9 @@
   SmallVector<int64_t> sliceMaskOffsets =
       getDistributedTransferOffsetsFromNestedLayout(offsets, vectorLayout);
   SmallVector<int64_t> strides(vectorLayout.getElementTile().size(), 1);
-  VectorValue slicedIndexVec = rewriter.create<vector::ExtractStridedSliceOp>(
-      loc, indexVec, sliceMaskOffsets, vectorLayout.getElementTile(), strides);
+  VectorValue slicedIndexVec = vector::ExtractStridedSliceOp::create(
+      rewriter, loc, indexVec, sliceMaskOffsets, vectorLayout.getElementTile(),
+      strides);
   return slicedIndexVec;
 }
 
@@ -275,21 +277,21 @@
   auto transposePerm = llvm::to_vector_of<int64_t>(slicedDims);
   transposePerm.append(remaningDims);
   auto transposed =
-      rewriter.create<vector::TransposeOp>(loc, val, transposePerm);
+      vector::TransposeOp::create(rewriter, loc, val, transposePerm);
 
   SmallVector<int64_t> extractedPos(slicedDims.size(), 0);
   auto sliced =
-      rewriter.create<vector::ExtractOp>(loc, transposed, extractedPos);
+      vector::ExtractOp::create(rewriter, loc, transposed, extractedPos);
   return cast<VectorValue>(sliced.getResult());
 }
 
 static VectorValue extractSliceAsVector(RewriterBase &rewriter, Location loc,
                                         Value src, ArrayRef<int64_t> offsets) {
-  Value slice = rewriter.create<vector::ExtractOp>(loc, src, offsets);
+  Value slice = vector::ExtractOp::create(rewriter, loc, src, offsets);
   // Promote the slicedVector to 0-d vector if it is a scalar.
   if (!isa<VectorType>(slice.getType())) {
     auto promotedType = VectorType::get({}, getElementTypeOrSelf(slice));
-    slice = rewriter.create<vector::BroadcastOp>(loc, promotedType, slice);
+    slice = vector::BroadcastOp::create(rewriter, loc, promotedType, slice);
   }
   return cast<VectorValue>(slice);
 }
@@ -349,8 +351,9 @@
 
     // Initialize the full distributed vector for unrolling the batch/outer
     // vector dimensions.
-    Value zero = rewriter.create<arith::ConstantOp>(
-        readOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+    Value zero =
+        arith::ConstantOp::create(rewriter, readOp.getLoc(), vectorType,
+                                  rewriter.getZeroAttr(vectorType));
     VectorValue acc = cast<VectorValue>(zero);
 
     SmallVector<Value> warpIndices, threadIndices;
@@ -390,10 +393,10 @@
                                            maskOffsets, maskLayout, mask);
       }
 
-      VectorValue slicedRead = rewriter.create<vector::TransferReadOp>(
-          readOp.getLoc(), innerVectorType, readOp.getBase(), slicedIndices,
-          readOp.getPermutationMapAttr(), readOp.getPadding(), slicedMask,
-          readOp.getInBoundsAttr());
+      VectorValue slicedRead = vector::TransferReadOp::create(
+          rewriter, readOp.getLoc(), innerVectorType, readOp.getBase(),
+          slicedIndices, readOp.getPermutationMapAttr(), readOp.getPadding(),
+          slicedMask, readOp.getInBoundsAttr());
 
       if (acc.getType().getRank() == 0) {
         // TODO: This should really be a folding pattern in
@@ -401,8 +404,8 @@
         // support 0-d vectors...
         acc = slicedRead;
       } else {
-        acc = rewriter.create<vector::InsertStridedSliceOp>(
-            readOp.getLoc(), slicedRead, acc, offsets, strides);
+        acc = vector::InsertStridedSliceOp::create(
+            rewriter, readOp.getLoc(), slicedRead, acc, offsets, strides);
       }
     }
 
@@ -502,10 +505,10 @@
                                            maskOffsets, maskLayout, mask);
       }
 
-      rewriter.create<vector::TransferWriteOp>(
-          writeOp.getLoc(), slicedVector, writeOp.getBase(), slicedIndices,
-          writeOp.getPermutationMapAttr(), slicedMask,
-          writeOp.getInBoundsAttr());
+      vector::TransferWriteOp::create(rewriter, writeOp.getLoc(), slicedVector,
+                                      writeOp.getBase(), slicedIndices,
+                                      writeOp.getPermutationMapAttr(),
+                                      slicedMask, writeOp.getInBoundsAttr());
     }
 
     rewriter.eraseOp(writeOp);
@@ -584,8 +587,9 @@
 
     // Initialize the full distributed vector for unrolling the batch/outer
     // vector dimensions.
-    Value zero = rewriter.create<arith::ConstantOp>(
-        gatherOp.getLoc(), vectorType, rewriter.getZeroAttr(vectorType));
+    Value zero =
+        arith::ConstantOp::create(rewriter, gatherOp.getLoc(), vectorType,
+                                  rewriter.getZeroAttr(vectorType));
     VectorValue acc = cast<VectorValue>(zero);
 
     SmallVector<Value> warpIndices, threadIndices;
@@ -655,12 +659,11 @@
                                            maskOffsets, maskLayout, mask);
       }
 
-      VectorValue slicedGather =
-          rewriter.create<IREE::VectorExt::TransferGatherOp>(
-              gatherOp.getLoc(), innerVectorType, gatherOp.getBase(),
-              slicedIndices, slicedIndexVecs, gatherOp.getIndexed(),
-              gatherOp.getIndexedMaps(), gatherOp.getPermutationMapAttr(),
-              gatherOp.getPadding(), slicedMask, gatherOp.getInBoundsAttr());
+      VectorValue slicedGather = IREE::VectorExt::TransferGatherOp::create(
+          rewriter, gatherOp.getLoc(), innerVectorType, gatherOp.getBase(),
+          slicedIndices, slicedIndexVecs, gatherOp.getIndexed(),
+          gatherOp.getIndexedMaps(), gatherOp.getPermutationMapAttr(),
+          gatherOp.getPadding(), slicedMask, gatherOp.getInBoundsAttr());
 
       if (acc.getType().getRank() == 0) {
         // TODO: This should really be a folding pattern in
@@ -668,8 +671,8 @@
         // support 0-d vectors...
         acc = slicedGather;
       } else {
-        acc = rewriter.create<vector::InsertStridedSliceOp>(
-            gatherOp.getLoc(), slicedGather, acc, offsets, strides);
+        acc = vector::InsertStridedSliceOp::create(
+            rewriter, gatherOp.getLoc(), slicedGather, acc, offsets, strides);
       }
     }
 
@@ -721,7 +724,7 @@
     Value distributedVector = getDistributed(rewriter, input, vectorLayout);
 
     Location loc = mapScatterOp.getLoc();
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
     SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
     SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
     for (auto [idx, offsets] :
@@ -762,8 +765,8 @@
           // the mapping is achieved by adding the corresponding block argument
           // to the sliced index.
           BlockArgument newTransformationIdx = newIndices[i - rankDiff];
-          replacementIdx = rewriter.create<arith::AddIOp>(
-              loc, newTransformationIdx, replacementIdx);
+          replacementIdx = arith::AddIOp::create(
+              rewriter, loc, newTransformationIdx, replacementIdx);
         }
         return replacementIndices;
       };
@@ -802,16 +805,16 @@
 
   VectorType broadcastedVecType =
       VectorType::get(leadingBroadcastShape, getElementTypeOrSelf(source));
-  VectorValue broadcasted = rewriter.create<vector::BroadcastOp>(
-      source.getLoc(), broadcastedVecType, source);
+  VectorValue broadcasted = vector::BroadcastOp::create(
+      rewriter, source.getLoc(), broadcastedVecType, source);
 
   // Transpose the broadcasted dims to the right place.
   SmallVector<int64_t> inversePerm = invertPermutationVector(perm);
   if (isIdentityPermutation(inversePerm)) {
     return broadcasted;
   }
-  return rewriter.create<vector::TransposeOp>(source.getLoc(), broadcasted,
-                                              inversePerm);
+  return vector::TransposeOp::create(rewriter, source.getLoc(), broadcasted,
+                                     inversePerm);
 }
 
 struct DistributeBroadcast final : OpDistributionPattern<vector::BroadcastOp> {
@@ -947,7 +950,7 @@
           loc, rewriter, multiReduceOp.getKind(), disSrc.getType());
 
       disSrc = cast<VectorValue>(
-          rewriter.create<arith::SelectOp>(loc, mask, disSrc, passThruSrc)
+          arith::SelectOp::create(rewriter, loc, mask, disSrc, passThruSrc)
               .getResult());
     }
 
@@ -965,8 +968,8 @@
     }
     Value localInit = getCombiningIdentityValue(
         loc, rewriter, multiReduceOp.getKind(), disAcc.getType());
-    Value localReduction = rewriter.create<vector::MultiDimReductionOp>(
-        loc, disSrc, localInit, distributedReductionMask,
+    Value localReduction = vector::MultiDimReductionOp::create(
+        rewriter, loc, disSrc, localInit, distributedReductionMask,
         multiReduceOp.getKind());
 
     // TODO: As per current upstream lowering implementations, there is no point
@@ -982,7 +985,7 @@
       // Broadcast scalar accumulator to vector.
       VectorType vecType = VectorType::get(ArrayRef{int64_t(1)}, elemTy);
       locallyReduced =
-          rewriter.create<vector::BroadcastOp>(loc, vecType, localReduction);
+          vector::BroadcastOp::create(rewriter, loc, vecType, localReduction);
     }
 
     assert(locallyReduced && "result should have been a vector");
@@ -998,8 +1001,8 @@
       int64_t numElements = shaped.getNumElements();
       SmallVector<int64_t> flatShape(1, numElements);
       VectorType flatVecType = VectorType::get(flatShape, elemTy);
-      VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
-                                                              locallyReduced);
+      VectorValue flat = vector::ShapeCastOp::create(rewriter, loc, flatVecType,
+                                                     locallyReduced);
 
       // Do inter-thread/warp reduce.
       FailureOr<VectorValue> threadReducedFlat = doThreadReduction(
@@ -1010,15 +1013,15 @@
 
       // Do reduction against accumulator, which needs to be done after thread
       // reduction.
-      threadReduced = rewriter.create<vector::ShapeCastOp>(
-          loc, shaped, threadReducedFlat.value());
+      threadReduced = vector::ShapeCastOp::create(rewriter, loc, shaped,
+                                                  threadReducedFlat.value());
     }
 
     if (!accVector) {
       // Broadcast the scalar (e.g., f32) to a vector type (e.g., vector<f32>)
       // because the following implementation requires the operand to be a
       // vector.
-      disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
+      disAcc = vector::BroadcastOp::create(rewriter, loc, shaped, disAcc);
     }
 
     bool hasSubgroupReductions =
@@ -1036,8 +1039,8 @@
       if (resVector) {
         replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
       } else {
-        Value accReducedVal = rewriter.create<vector::ExtractOp>(
-            loc, accReduction, ArrayRef{int64_t(0)});
+        Value accReducedVal = vector::ExtractOp::create(
+            rewriter, loc, accReduction, ArrayRef{int64_t(0)});
         replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
       }
       return success();
@@ -1059,12 +1062,12 @@
     int64_t numElements = flatVecType.getNumElements();
     Location loc = flat.getLoc();
 
-    auto constOp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(flatVecType));
+    auto constOp = arith::ConstantOp::create(rewriter, loc,
+                                             rewriter.getZeroAttr(flatVecType));
     auto res = llvm::cast<VectorValue>(constOp.getResult());
 
     for (unsigned i = 0; i < numElements; ++i) {
-      Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);
+      Value extracted = vector::ExtractOp::create(rewriter, loc, flat, i);
       // Reduce across all reduction dimensions 1-by-1.
       for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
         if (reductionMask[i]) {
@@ -1073,26 +1076,26 @@
           assert(offset <= std::numeric_limits<uint32_t>::max() &&
                  width <= std::numeric_limits<uint32_t>::max());
 
-          extracted = rewriter.create<gpu::SubgroupReduceOp>(
-              loc, extracted, combiningKindToAllReduce(kind),
+          extracted = gpu::SubgroupReduceOp::create(
+              rewriter, loc, extracted, combiningKindToAllReduce(kind),
               /*uniform=*/false, /*cluster_size=*/width,
               /*cluster_stride=*/offset);
         }
       }
 
-      res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
+      res = vector::InsertOp::create(rewriter, loc, extracted, res, i);
     }
     return res;
   }
 
   Value getBufferForSubgroupReduction(RewriterBase &rewriter, MemRefType memTy,
                                       Value val) const {
-    auto alloc = rewriter.create<memref::AllocOp>(val.getLoc(), memTy);
+    auto alloc = memref::AllocOp::create(rewriter, val.getLoc(), memTy);
     // Insert gpu.barrier to make sure the previous iteration of the batch
     // loop has fully read the subgroup partial reductions.
     // TODO: We should be only creating a barrier if this buffer is going to be
     // reused.
-    rewriter.create<gpu::BarrierOp>(val.getLoc());
+    gpu::BarrierOp::create(rewriter, val.getLoc());
     return alloc;
   }
 
@@ -1164,12 +1167,12 @@
                                   VectorValue valueToWrite, Value buffer,
                                   NestedLayoutAttr srcLayout,
                                   ArrayRef<int64_t> reductionDims) const {
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
     VectorType unDistributedType = valueToWrite.getType();
     SmallVector<Value> indices(unDistributedType.getRank(), c0);
     SmallVector<bool> inBounds(unDistributedType.getRank(), true);
-    auto write = rewriter.create<vector::TransferWriteOp>(
-        loc, valueToWrite, buffer, indices, inBounds);
+    auto write = vector::TransferWriteOp::create(rewriter, loc, valueToWrite,
+                                                 buffer, indices, inBounds);
     // Set the layout signature for the write.
     // We need to set the layout on the srcVector/first operand.
     auto subgroupTileLens =
@@ -1210,14 +1213,14 @@
                                                getElementTypeOrSelf(buffer));
     auto readTy = VectorType::get(readLayout.getUndistributedShape(),
                                   getElementTypeOrSelf(buffer));
-    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
     auto inBounds = rewriter.getBoolArrayAttr(
         SmallVector<bool>(readLayout.getRank(), true));
-    auto mask = rewriter.create<vector::CreateMaskOp>(
-        loc, readTy.clone(rewriter.getI1Type()),
+    auto mask = vector::CreateMaskOp::create(
+        rewriter, loc, readTy.clone(rewriter.getI1Type()),
         memref::getMixedSizes(rewriter, loc, buffer));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        loc,
+    auto read = vector::TransferReadOp::create(
+        rewriter, loc,
         /*vectorType=*/readTy,
         /*source=*/buffer,
         /*indices=*/SmallVector<Value>(readLayout.getRank(), zero),
@@ -1232,8 +1235,8 @@
     // different subgroups.
     // Since the data was distributed to every thread, it will
     // form a gpu.subgroup_reduce operation later.
-    auto secondReduction = rewriter.create<vector::MultiDimReductionOp>(
-        loc, kind, read, acc, reductionDims);
+    auto secondReduction = vector::MultiDimReductionOp::create(
+        rewriter, loc, kind, read, acc, reductionDims);
     if (resLayout) {
       setSignatureForRedistribution(rewriter, secondReduction,
                                     {readLayout, resLayout}, {resLayout});
@@ -1279,8 +1282,8 @@
     }
     VectorType partialReducedDistributedType = VectorType::get(
         partialReducedDistributedShape, srcVector.getType().getElementType());
-    Value isoRankThreadReduced = rewriter.create<vector::ShapeCastOp>(
-        loc, partialReducedDistributedType, threadReduced);
+    Value isoRankThreadReduced = vector::ShapeCastOp::create(
+        rewriter, loc, partialReducedDistributedType, threadReduced);
 
     SmallVector<int64_t> preDistrShape =
         srcLayout.getUndistributedPackedShape();
@@ -1295,8 +1298,8 @@
     }
     auto unDistributedType = VectorType::get(
         partialReductionShape, srcVector.getType().getElementType());
-    VectorValue valueToWrite = rewriter.create<IREE::VectorExt::ToSIMDOp>(
-        loc, unDistributedType, isoRankThreadReduced);
+    VectorValue valueToWrite = IREE::VectorExt::ToSIMDOp::create(
+        rewriter, loc, unDistributedType, isoRankThreadReduced);
 
     auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
         rewriter.getContext(), gpu::AddressSpace::Workgroup));
@@ -1308,7 +1311,7 @@
     writePartialResultToBuffer(rewriter, loc, valueToWrite, alloc, srcLayout,
                                reductionDims);
     // Wait for writes to buffer to finish.
-    rewriter.create<gpu::BarrierOp>(loc);
+    gpu::BarrierOp::create(rewriter, loc);
     return doSubgroupReductionFromBuffer(rewriter, loc, alloc, srcLayout,
                                          resLayout, reductionDims, kind, acc);
   }
@@ -1447,7 +1450,7 @@
     } else {
       VectorType vecType = VectorType::get(ArrayRef{int64_t(1)}, accElemTy);
       localContractValue =
-          rewriter.create<vector::BroadcastOp>(loc, vecType, localContract);
+          vector::BroadcastOp::create(rewriter, loc, vecType, localContract);
     }
 
     assert(localContractValue && "result should have been a vector");
@@ -1495,17 +1498,17 @@
     VectorType partialReducedDistributedType =
         VectorType::get(reductionLayout.getDistributedShape(),
                         localContractValue.getType().getElementType());
-    Value shapeCasted = rewriter.create<vector::ShapeCastOp>(
-        loc, partialReducedDistributedType, localContractValue);
+    Value shapeCasted = vector::ShapeCastOp::create(
+        rewriter, loc, partialReducedDistributedType, localContractValue);
     VectorType unDistributedType =
         VectorType::get(reductionLayout.getUndistributedShape(),
                         localContractValue.getType().getElementType());
-    Value undistrLocalReduced = rewriter.create<IREE::VectorExt::ToSIMDOp>(
-        loc, unDistributedType, shapeCasted);
+    Value undistrLocalReduced = IREE::VectorExt::ToSIMDOp::create(
+        rewriter, loc, unDistributedType, shapeCasted);
 
     // Create the partial reduction
-    auto partialReduction = rewriter.create<vector::MultiDimReductionOp>(
-        loc, contractOp.getKind(), undistrLocalReduced, acc,
+    auto partialReduction = vector::MultiDimReductionOp::create(
+        rewriter, loc, contractOp.getKind(), undistrLocalReduced, acc,
         partialReductionDims);
     if (resVector) {
       setSignatureForRedistribution(rewriter, partialReduction,
@@ -1557,8 +1560,9 @@
     Value localInit = getCombiningIdentityValue(
         loc, rewriter, contractOp.getKind(), acc.getType());
 
-    auto localContractOp = rewriter.create<vector::ContractionOp>(
-        loc, lhs, rhs, localInit, rewriter.getAffineMapArrayAttr(newMaps),
+    auto localContractOp = vector::ContractionOp::create(
+        rewriter, loc, lhs, rhs, localInit,
+        rewriter.getAffineMapArrayAttr(newMaps),
         rewriter.getArrayAttr(newIterators), contractOp.getKind());
     localContractOp->setDiscardableAttrs(
         contractOp->getDiscardableAttrDictionary());
@@ -1614,8 +1618,8 @@
         permutation.push_back(it + (i * rank));
       }
     }
-    VectorValue transposed = rewriter.create<vector::TransposeOp>(
-        transposeOp.getLoc(), input, permutation);
+    VectorValue transposed = vector::TransposeOp::create(
+        rewriter, transposeOp.getLoc(), input, permutation);
     replaceOpWithDistributedValues(rewriter, transposeOp, transposed);
     return success();
   }
@@ -1683,8 +1687,9 @@
       interleavePermutation[2 * i + 1] = i + rank;
     }
 
-    auto interleaved = rewriter.create<vector::TransposeOp>(
-        loc, getDistributed(rewriter, input, layoutA), interleavePermutation);
+    auto interleaved = vector::TransposeOp::create(
+        rewriter, loc, getDistributed(rewriter, input, layoutA),
+        interleavePermutation);
 
     // Shape cast to match the new layout.
 
@@ -1694,14 +1699,14 @@
         transposedShapeB, interleaved.getResultVectorType().getElementType());
 
     auto reshaped =
-        rewriter.create<vector::ShapeCastOp>(loc, reshapedType, interleaved);
+        vector::ShapeCastOp::create(rewriter, loc, reshapedType, interleaved);
 
     // Inverse transpose to preserve original order.
     SmallVector<int64_t> invertedPermutation =
         invertPermutationVector(interleavePermutation);
 
-    auto layouted = rewriter.create<vector::TransposeOp>(loc, reshaped,
-                                                         invertedPermutation);
+    auto layouted = vector::TransposeOp::create(rewriter, loc, reshaped,
+                                                invertedPermutation);
 
     replaceOpWithDistributedValues(rewriter, toLayoutOp, layouted.getResult());
     return success();
@@ -1832,20 +1837,20 @@
     }
     VectorType offsetType =
         VectorType::get({distributedLen}, builder.getIndexType());
-    auto constOffset = builder.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(offsetType, offsets));
+    auto constOffset = arith::ConstantOp::create(
+        builder, loc, DenseElementsAttr::get(offsetType, offsets));
     Value finalOffset = constOffset;
     for (const DimInfo &dimInfo : distributedDims) {
       assert(dimInfo.dimIdx.has_value());
       if (dimInfo.dimStride != 0) {
         auto strideVal =
-            builder.create<arith::ConstantIndexOp>(loc, dimInfo.dimStride);
-        auto dimIdxOffsetPerElem = builder.create<arith::MulIOp>(
-            loc, strideVal, dimInfo.dimIdx.value());
-        auto dimIdxOffset = builder.create<vector::BroadcastOp>(
-            loc, offsetType, dimIdxOffsetPerElem);
+            arith::ConstantIndexOp::create(builder, loc, dimInfo.dimStride);
+        auto dimIdxOffsetPerElem = arith::MulIOp::create(
+            builder, loc, strideVal, dimInfo.dimIdx.value());
+        auto dimIdxOffset = vector::BroadcastOp::create(
+            builder, loc, offsetType, dimIdxOffsetPerElem);
         finalOffset =
-            builder.create<arith::AddIOp>(loc, finalOffset, dimIdxOffset);
+            arith::AddIOp::create(builder, loc, finalOffset, dimIdxOffset);
       }
     }
     return cast<VectorValue>(finalOffset);
@@ -1898,8 +1903,8 @@
         rewriter, loc, distributedDims, distributedElements, originalElements);
     VectorType finalSlicedStepOpType =
         VectorType::get({distributedShape}, result.getType().getElementType());
-    auto finalSlicedStepOp = rewriter.create<vector::ShapeCastOp>(
-        loc, finalSlicedStepOpType, slicedStepOp);
+    auto finalSlicedStepOp = vector::ShapeCastOp::create(
+        rewriter, loc, finalSlicedStepOpType, slicedStepOp);
     replaceOpWithDistributedValues(rewriter, stepOp, {finalSlicedStepOp});
     return success();
   }
@@ -1920,8 +1925,8 @@
   constexpr int64_t threadIdx = 3;
   constexpr int64_t elementIdx = 4;
   SmallVector<Value> bounds;
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
 
   for (auto [unDistributedDim, upperBound] : llvm::enumerate(upperBounds)) {
     SmallVector<int64_t> undistributedShape =
@@ -1931,10 +1936,10 @@
                                       undistributedShape[elementIdx]};
     int64_t elementPerThread = ShapedType::getNumElements(distrShape);
     auto allValid =
-        rewriter.create<arith::ConstantIndexOp>(loc, elementPerThread);
+        arith::ConstantIndexOp::create(rewriter, loc, elementPerThread);
     int64_t elementTileSize = distrShape.back();
     auto elementTileLastIdx =
-        rewriter.create<arith::ConstantIndexOp>(loc, elementTileSize - 1);
+        arith::ConstantIndexOp::create(rewriter, loc, elementTileSize - 1);
 
     // A special condition: if the pre-distribution bounds match
     // the mask dimension length, then the distributed bounds
@@ -1950,10 +1955,9 @@
         continue;
       }
     }
-    auto lastValidIdx = rewriter.create<arith::SubIOp>(loc, upperBound, one);
-    auto delineraizedLastValidIdx =
-        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, lastValidIdx,
-                                                          undistributedShape);
+    auto lastValidIdx = arith::SubIOp::create(rewriter, loc, upperBound, one);
+    auto delineraizedLastValidIdx = affine::AffineDelinearizeIndexOp::create(
+        rewriter, loc, lastValidIdx, undistributedShape);
     SmallVector<Value> packedLastValidIdx =
         delineraizedLastValidIdx.getResults();
 
@@ -1961,24 +1965,22 @@
     // Every [vtid] less than [vtid that encounters last valid element] should
     // have an all-valid element tile
     auto linearizedLastValidIdxPreThreads =
-        rewriter.create<affine::AffineLinearizeIndexOp>(
-            loc,
+        affine::AffineLinearizeIndexOp::create(
+            rewriter, loc,
             ValueRange{packedLastValidIdx[batchIdx],
                        packedLastValidIdx[outerIdx], elementTileLastIdx},
             distrShape);
     // Bound is defined as lastIdx + 1;
-    auto distrUpperBoundPreThreads = rewriter.create<arith::AddIOp>(
-        loc, linearizedLastValidIdxPreThreads, one);
+    auto distrUpperBoundPreThreads = arith::AddIOp::create(
+        rewriter, loc, linearizedLastValidIdxPreThreads, one);
 
-    auto linearizedLastValidIdx =
-        rewriter.create<affine::AffineLinearizeIndexOp>(
-            loc,
-            ValueRange{packedLastValidIdx[batchIdx],
-                       packedLastValidIdx[outerIdx],
-                       packedLastValidIdx[elementIdx]},
-            distrShape);
+    auto linearizedLastValidIdx = affine::AffineLinearizeIndexOp::create(
+        rewriter, loc,
+        ValueRange{packedLastValidIdx[batchIdx], packedLastValidIdx[outerIdx],
+                   packedLastValidIdx[elementIdx]},
+        distrShape);
     auto distrUpperBound =
-        rewriter.create<arith::AddIOp>(loc, linearizedLastValidIdx, one);
+        arith::AddIOp::create(rewriter, loc, linearizedLastValidIdx, one);
 
     // The following code constructs a selection tree
     // that in effect follows the code:
@@ -1997,34 +1999,34 @@
     //     [u1][u2][u4]
 
     // tid == u3
-    auto cmpBoundTidEq = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, threadIndices[unDistributedDim],
-        packedLastValidIdx[threadIdx]);
+    auto cmpBoundTidEq = arith::CmpIOp::create(
+        rewriter, loc, arith::CmpIPredicate::eq,
+        threadIndices[unDistributedDim], packedLastValidIdx[threadIdx]);
     // tid < u3
-    auto cmpBoundTidSlt = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, threadIndices[unDistributedDim],
-        packedLastValidIdx[threadIdx]);
+    auto cmpBoundTidSlt = arith::CmpIOp::create(
+        rewriter, loc, arith::CmpIPredicate::slt,
+        threadIndices[unDistributedDim], packedLastValidIdx[threadIdx]);
     // sg == u0
-    auto cmpBoundSgEq = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, subgroupIndices[unDistributedDim],
-        packedLastValidIdx[subgroupIdx]);
+    auto cmpBoundSgEq = arith::CmpIOp::create(
+        rewriter, loc, arith::CmpIPredicate::eq,
+        subgroupIndices[unDistributedDim], packedLastValidIdx[subgroupIdx]);
     // sg < u0
-    auto cmpBoundSgSlt = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, subgroupIndices[unDistributedDim],
-        packedLastValidIdx[subgroupIdx]);
+    auto cmpBoundSgSlt = arith::CmpIOp::create(
+        rewriter, loc, arith::CmpIPredicate::slt,
+        subgroupIndices[unDistributedDim], packedLastValidIdx[subgroupIdx]);
 
     // selectTid0 = tid < u3 ? [u1][u2][max] : all invalid
-    auto selectTid0 = rewriter.create<arith::SelectOp>(
-        loc, cmpBoundTidSlt, distrUpperBoundPreThreads, zero);
+    auto selectTid0 = arith::SelectOp::create(rewriter, loc, cmpBoundTidSlt,
+                                              distrUpperBoundPreThreads, zero);
     // selectTid1 = tid == u3 ? [u1][u2][u4] : selectTid0
-    auto selectTid1 = rewriter.create<arith::SelectOp>(
-        loc, cmpBoundTidEq, distrUpperBound, selectTid0);
+    auto selectTid1 = arith::SelectOp::create(rewriter, loc, cmpBoundTidEq,
+                                              distrUpperBound, selectTid0);
     // selectSg0 = sg < u0 ? all valid : all invalid
     auto selectSg0 =
-        rewriter.create<arith::SelectOp>(loc, cmpBoundSgSlt, allValid, zero);
+        arith::SelectOp::create(rewriter, loc, cmpBoundSgSlt, allValid, zero);
     // selectSg1 = sg == u0 ? selectTid1 : selectSg0
-    auto selectSg1 = rewriter.create<arith::SelectOp>(loc, cmpBoundSgEq,
-                                                      selectTid1, selectSg0);
+    auto selectSg1 = arith::SelectOp::create(rewriter, loc, cmpBoundSgEq,
+                                             selectTid1, selectSg0);
     bounds.push_back(selectSg1);
   }
   return bounds;
@@ -2064,8 +2066,8 @@
     Type elemType = maskOp.getType().getElementType();
     auto distrUnpackedType =
         VectorType::get(resultLayout.getDistributedUnpackedShape(), elemType);
-    auto distrMask = rewriter.create<vector::CreateMaskOp>(
-        loc, distrUnpackedType, distributedBounds);
+    auto distrMask = vector::CreateMaskOp::create(
+        rewriter, loc, distrUnpackedType, distributedBounds);
     VectorValue interleavedDistrMask =
         getInterleavedPackedForm(rewriter, distrMask, resultLayout);
     replaceOpWithDistributedValues(rewriter, maskOp, {interleavedDistrMask});
@@ -2104,7 +2106,7 @@
 
     SmallVector<Value> constOperands;
     for (int64_t size : maskOp.getMaskDimSizes()) {
-      Value index = rewriter.create<arith::ConstantIndexOp>(loc, size);
+      Value index = arith::ConstantIndexOp::create(rewriter, loc, size);
       constOperands.push_back(index);
     }
 
@@ -2115,8 +2117,8 @@
     Type elemType = maskOp.getType().getElementType();
     auto distrUnpackedType =
         VectorType::get(resultLayout.getDistributedUnpackedShape(), elemType);
-    auto distrMask = rewriter.create<vector::CreateMaskOp>(
-        loc, distrUnpackedType, distributedBounds);
+    auto distrMask = vector::CreateMaskOp::create(
+        rewriter, loc, distrUnpackedType, distributedBounds);
     VectorValue interleavedDistrMask =
         getInterleavedPackedForm(rewriter, distrMask, resultLayout);
     replaceOpWithDistributedValues(rewriter, maskOp, {interleavedDistrMask});
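
For reference, every hunk in this chunk applies the same mechanical rewrite. A
minimal sketch on a hypothetical op (sum, lhs, rhs, and loc are illustrative
placeholders, not taken from this diff):

  // Deprecated builder method: the op type is a template argument.
  Value sum = rewriter.create<arith::AddIOp>(loc, lhs, rhs);
  // Free create function: the builder moves into the argument list.
  Value sum = arith::AddIOp::create(rewriter, loc, lhs, rhs);
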
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
index 3e639e8..1a346f2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPackToIntrinsics.cpp
@@ -191,26 +191,26 @@
         packOp.getMixedTiles(), packOp.getInnerDimsPos(),
         packOp.getOuterDimsPerm());
 
-    auto packedDest = rewriter.create<linalg::PackOp>(
-        loc, forOp.getInitArgs()[tiedResultIdx], input,
+    auto packedDest = linalg::PackOp::create(
+        rewriter, loc, forOp.getInitArgs()[tiedResultIdx], input,
         packOp.getInnerDimsPos(), packOp.getMixedTiles(),
         packOp.getPaddingValue(), packOp.getOuterDimsPerm());
 
     auto packOpValues = llvm::to_vector_of<Value>(forOp.getInitArgs());
     packOpValues[tiedResultIdx] = packedDest.getResult();
-    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
-        loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
-        packOpValues);
+    scf::ForOp newForOp = scf::ForOp::create(
+        rewriter, loc, forOp.getLowerBound(), forOp.getUpperBound(),
+        forOp.getStep(), packOpValues);
 
     // Destination tensor for the new unpackOp, based on the shape of the
     // original tensor that got packed, to help unpack into unaligned shapes and
     // drop padding added by the packOp.
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, packOp.getSourceType().getShape(),
+    Value empty = tensor::EmptyOp::create(
+        rewriter, loc, packOp.getSourceType().getShape(),
         packOp.getSourceType().getElementType());
 
-    auto unpackedOutput = rewriter.create<linalg::UnPackOp>(
-        loc, newForOp.getResults()[tiedResultIdx], empty,
+    auto unpackedOutput = linalg::UnPackOp::create(
+        rewriter, loc, newForOp.getResults()[tiedResultIdx], empty,
         unpackOp.getInnerDimsPos(), unpackOp.getMixedTiles(),
         unpackOp.getOuterDimsPerm());
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp
index 1abfe51..68c7886 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.cpp
@@ -123,15 +123,17 @@
         llvm::cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
             vectorShapeCollapse, sourceType, subViewOffsets, subViewSizes,
             subViewStrides));
-    Value subView = rewriter.create<memref::SubViewOp>(
-        loc, resultType, source, subViewOffsets, subViewSizes, subViewStrides);
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value readCollapse = rewriter.create<vector::TransferReadOp>(
-        loc, vectorTypeCollapse, subView, ValueRange{c0, c0}, newidentityMap,
-        transferReadOp.getPadding(), transferReadOp.getMask(), newInBoundsAttr);
+    Value subView =
+        memref::SubViewOp::create(rewriter, loc, resultType, source,
+                                  subViewOffsets, subViewSizes, subViewStrides);
+    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    Value readCollapse = vector::TransferReadOp::create(
+        rewriter, loc, vectorTypeCollapse, subView, ValueRange{c0, c0},
+        newidentityMap, transferReadOp.getPadding(), transferReadOp.getMask(),
+        newInBoundsAttr);
 
-    Value readBroadcast = rewriter.create<vector::BroadcastOp>(
-        loc, vectorTypeBroadcast, readCollapse);
+    Value readBroadcast = vector::BroadcastOp::create(
+        rewriter, loc, vectorTypeBroadcast, readCollapse);
     SmallVector<int64_t> transposePermutation;
     for (int i = 0; i < vectorType.getRank(); i++) {
       if (i == vectorType.getRank() - 2)
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp
index 65e0c04..457cf3f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPipelining.cpp
@@ -80,20 +80,20 @@
   // Create srcElement Value based on the pred.
   // The next few lines generate the code below:
   // srcElement = (pred) ? prevSrcElements : 0;
-  Value dstElements =
-      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
+  Value dstElements = arith::ConstantOp::create(
+      rewriter, loc, asyncCopyOp.getDstElementsAttr());
   Value originalSrcElement =
       asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
-  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
   auto srcElements =
-      rewriter.create<arith::SelectOp>(loc, pred, originalSrcElement, c0Index);
+      arith::SelectOp::create(rewriter, loc, pred, originalSrcElement, c0Index);
   int64_t sizeInBytes =
       (asyncCopyOp.getDst().getType().getElementTypeBitWidth() *
        asyncCopyOp.getDstElements().getZExtValue()) /
       8;
   UnitAttr bypassL1 = sizeInBytes == 16 ? rewriter.getUnitAttr() : UnitAttr();
-  auto asyncCopyZfillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
-      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
+  auto asyncCopyZfillOp = nvgpu::DeviceAsyncCopyOp::create(
+      rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
       asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
       asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
       bypassL1);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp
index 94e8ba0..0809e83 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp
@@ -35,9 +35,9 @@
   auto tensorType = cast<RankedTensorType>(v.getType());
   SmallVector<OpFoldResult> mixedSizes = tensor::getMixedSizes(builder, loc, v);
 
-  Value empty = builder.create<tensor::EmptyOp>(loc, mixedSizes,
-                                                tensorType.getElementType());
-  auto copy = builder.create<linalg::CopyOp>(loc, v, empty);
+  Value empty = tensor::EmptyOp::create(builder, loc, mixedSizes,
+                                        tensorType.getElementType());
+  auto copy = linalg::CopyOp::create(builder, loc, v, empty);
 
   if (useDirectLoad) {
     setLoweringConfig(
@@ -85,16 +85,16 @@
   for (auto [idx, size] : llvm::enumerate(tensorType.getShape())) {
     if (ShapedType::isDynamic(size)) {
       dynamicSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, valToMakeShared, idx));
+          tensor::DimOp::create(rewriter, loc, valToMakeShared, idx));
     }
   }
   Attribute addressSpace = gpu::AddressSpaceAttr::get(
       rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
-  auto alloc = rewriter.create<bufferization::AllocTensorOp>(loc, tensorType,
-                                                             dynamicSizes);
+  auto alloc = bufferization::AllocTensorOp::create(rewriter, loc, tensorType,
+                                                    dynamicSizes);
   alloc.setMemorySpaceAttr(addressSpace);
   auto copy =
-      rewriter.create<linalg::CopyOp>(loc, valToMakeShared, alloc.getResult());
+      linalg::CopyOp::create(rewriter, loc, valToMakeShared, alloc.getResult());
 
   Value replacement = copy.getResult(0);
   // If an extract_slice is present, we make it consume the new copy.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp
index ac2cbad..d902f27 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReduceBankConflicts.cpp
@@ -65,11 +65,12 @@
   IRRewriter rewriter(context);
   rewriter.setInsertionPoint(allocOp);
   Location loc = allocOp.getLoc();
-  Value paddedAlloc = rewriter.create<memref::AllocOp>(loc, allocType);
+  Value paddedAlloc = memref::AllocOp::create(rewriter, loc, allocType);
   SmallVector<int64_t> offsets(shape.size(), 0);
   SmallVector<int64_t> strides(shape.size(), 1);
-  Value subview = rewriter.create<memref::SubViewOp>(
-      loc, paddedAlloc, offsets, allocOp.getType().getShape(), strides);
+  Value subview =
+      memref::SubViewOp::create(rewriter, loc, paddedAlloc, offsets,
+                                allocOp.getType().getShape(), strides);
   replaceMemrefUsesAndPropagateType(rewriter, loc, allocOp, subview);
   rewriter.eraseOp(allocOp);
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReuseSharedMemoryAllocs.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReuseSharedMemoryAllocs.cpp
index bde8a7a..1f684a2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReuseSharedMemoryAllocs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUReuseSharedMemoryAllocs.cpp
@@ -254,7 +254,7 @@
         // Add a barrier if the `otherLiveness` comes before `liveness`.
         if (dominanceInfo.dominates(otherLiveness.first, liveness.first)) {
           builder.setInsertionPoint(liveness.first);
-          builder.create<gpu::BarrierOp>(liveness.first->getLoc());
+          gpu::BarrierOp::create(builder, liveness.first->getLoc());
           break;
         }
       }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
index 28bf3ab..df30006 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTensorAlloc.cpp
@@ -116,8 +116,9 @@
 
     rewriter.setInsertionPoint(linalgOp);
     std::optional<Attribute> memorySpace = allocOp.getMemorySpace();
-    auto newAllocOp = rewriter.create<bufferization::AllocTensorOp>(
-        allocOp.getLoc(), allocOp.getType(), allocOp.getDynamicSizes(),
+    auto newAllocOp = bufferization::AllocTensorOp::create(
+        rewriter, allocOp.getLoc(), allocOp.getType(),
+        allocOp.getDynamicSizes(),
         /*copy=*/Value(),
         memorySpace ? cast<IntegerAttr>(*memorySpace) : IntegerAttr());
     rewriter.modifyOpInPlace(linalgOp, [&]() {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp
index 97eb95b..29e0855 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUTileAndConvertConvToMatmul.cpp
@@ -82,8 +82,8 @@
   newIndexingMaps.push_back(filterMap);
   newIndexingMaps.push_back(indexingMaps[2]);
   // Create the new contraction op and replace the old convolution op.
-  auto newOp = rewriter.create<linalg::GenericOp>(
-      linalgOp.getLoc(), linalgOp.getDpsInits().getType(),
+  auto newOp = linalg::GenericOp::create(
+      rewriter, linalgOp.getLoc(), linalgOp.getDpsInits().getType(),
       linalgOp.getDpsInputs(), linalgOp.getDpsInits(), newIndexingMaps,
       linalgOp.getIteratorTypesArray(), /*bodyBuild=*/nullptr,
       getPrunedAttributeList(linalgOp));
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
index fdaf942..ce704c8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorAlloc.cpp
@@ -72,22 +72,22 @@
       RankedTensorType::get(vectorType.getShape(), vectorType.getElementType(),
                             sharedMemoryAddrSpace);
   // Vectors are always statically shaped.
-  auto allocTensorOp = b.create<bufferization::AllocTensorOp>(
-      loc, tensorType, ValueRange{}, Value());
+  auto allocTensorOp = bufferization::AllocTensorOp::create(
+      b, loc, tensorType, ValueRange{}, Value());
   allocTensorOp.setMemorySpaceAttr(sharedMemoryAddrSpace);
 
-  Value c0 = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value c0 = arith::ConstantIndexOp::create(b, loc, 0);
   SmallVector<Value> indices(vectorType.getRank(), c0);
   SmallVector<bool> inBounds(vectorType.getRank(), true);
-  Value copied = b.create<vector::TransferWriteOp>(loc, vector, allocTensorOp,
-                                                   indices, inBounds)
+  Value copied = vector::TransferWriteOp::create(b, loc, vector, allocTensorOp,
+                                                 indices, inBounds)
                      .getResult();
   return copied;
 }
 
 static Value readVectorFromTensor(OpBuilder &b, VectorType vectorType,
                                   Value tensor) {
-  Value c0 = b.create<arith::ConstantIndexOp>(tensor.getLoc(), 0);
+  Value c0 = arith::ConstantIndexOp::create(b, tensor.getLoc(), 0);
   SmallVector<Value> indices(vectorType.getRank(), c0);
   SmallVector<bool> inBounds(vectorType.getRank(), true);
   return b
@@ -119,7 +119,7 @@
       // reads in the previous iteration of a loop. We set this barrier
       // at the start of this block.
       builder.setInsertionPointToStart(op->getBlock());
-      builder.create<gpu::BarrierOp>(op->getLoc());
+      gpu::BarrierOp::create(builder, op->getLoc());
 
       // Promote both of the input operands, excluding the accumulator.
       builder.setInsertionPoint(op);
@@ -132,7 +132,7 @@
 
       // Synchronize after the write to shared memory before we read from it.
       auto synced =
-          builder.create<IREE::GPU::ValueBarrierOp>(op->getLoc(), *ret);
+          IREE::GPU::ValueBarrierOp::create(builder, op->getLoc(), *ret);
 
       VectorType inputTy = cast<VectorType>(op.getType());
       Value read = readVectorFromTensor(builder, inputTy, synced.getResult(0));
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index de4fe0e..5dba175 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -139,8 +139,8 @@
   SmallVector<int64_t> distributedShape = layout.getDistributedShape();
   VectorType distributedType =
       VectorType::get(distributedShape, value.getType().getElementType());
-  auto toSIMT = rewriter.create<IREE::VectorExt::ToSIMTOp>(
-      value.getLoc(), distributedType, value);
+  auto toSIMT = IREE::VectorExt::ToSIMTOp::create(rewriter, value.getLoc(),
+                                                  distributedType, value);
   return toSIMT.getResult();
 }
 
@@ -154,8 +154,8 @@
       auto oldResult = cast<VectorValue>(opResult);
       // Create a toSIMD op to convert the value back to SIMD form.
       rewriter.setInsertionPointAfterValue(oldResult);
-      Value toSIMD = rewriter.create<IREE::VectorExt::ToSIMDOp>(
-          oldResult.getLoc(), oldResult.getType(), replacement);
+      Value toSIMD = IREE::VectorExt::ToSIMDOp::create(
+          rewriter, oldResult.getLoc(), oldResult.getType(), replacement);
       // Add to replacements.
       replacement = toSIMD;
     }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
index 12d537c..8fec674 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/VectorReductionToGPU.cpp
@@ -53,7 +53,7 @@
     memrefType = MemRefType::get({1}, type, MemRefLayoutAttrInterface{},
                                  addressSpaceAttr);
   }
-  return builder.create<memref::AllocOp>(loc, memrefType);
+  return memref::AllocOp::create(builder, loc, memrefType);
 }
 
 /// Returns true if the given op is a memref.load from a uniform buffer or
@@ -177,7 +177,7 @@
       return failure();
 
     rewriter.setInsertionPointAfter(warpOp);
-    (void)rewriter.create<gpu::BarrierOp>(barrierOp.getLoc());
+    (void)gpu::BarrierOp::create(rewriter, barrierOp.getLoc());
     rewriter.eraseOp(barrierOp);
     return success();
   }
@@ -189,9 +189,9 @@
   assert((val.getType().isF32() || val.getType().isInteger(32)) &&
          "unsupported shuffle type");
   Type i32Type = builder.getIntegerType(32);
-  Value srcIdxI32 = builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
-  Value warpSzI32 = builder.create<arith::ConstantOp>(
-      loc, builder.getIntegerAttr(i32Type, warpSz));
+  Value srcIdxI32 = arith::IndexCastOp::create(builder, loc, i32Type, srcIdx);
+  Value warpSzI32 = arith::ConstantOp::create(
+      builder, loc, builder.getIntegerAttr(i32Type, warpSz));
   Value result = builder
                      .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
                                              gpu::ShuffleMode::IDX)
@@ -237,11 +237,11 @@
     const int groupSize = workgroupSize[0];
     Location loc = funcOp.getLoc();
     OpBuilder builder(funcOp);
-    auto threadX = builder.create<gpu::ThreadIdOp>(loc, builder.getIndexType(),
-                                                   gpu::Dimension::x);
-    auto cstGroupSize = builder.create<arith::ConstantIndexOp>(loc, groupSize);
-    auto warpOp = builder.create<gpu::WarpExecuteOnLane0Op>(
-        loc, TypeRange(), threadX.getResult(), groupSize);
+    auto threadX = gpu::ThreadIdOp::create(builder, loc, builder.getIndexType(),
+                                           gpu::Dimension::x);
+    auto cstGroupSize = arith::ConstantIndexOp::create(builder, loc, groupSize);
+    auto warpOp = gpu::WarpExecuteOnLane0Op::create(
+        builder, loc, TypeRange(), threadX.getResult(), groupSize);
     warpOp.getWarpRegion().takeBody(funcOp.getFunctionBody());
     Block &newBlock = funcOp.getFunctionBody().emplaceBlock();
     threadX->moveBefore(&newBlock, newBlock.end());
@@ -250,7 +250,7 @@
     warpOp.getWarpRegion().getBlocks().back().back().moveBefore(&newBlock,
                                                                 newBlock.end());
     builder.setInsertionPointToEnd(&warpOp.getWarpRegion().getBlocks().back());
-    builder.create<gpu::YieldOp>(loc);
+    gpu::YieldOp::create(builder, loc);
 
     debugPrint(funcOp, "after step #2: wrapping code with the warp execute op");
 
@@ -308,7 +308,7 @@
       options.warpAllocationFn = allocateGlobalSharedMemory;
       options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
                                         gpu::WarpExecuteOnLane0Op warpOp) {
-        builder.create<gpu::BarrierOp>(loc);
+        gpu::BarrierOp::create(builder, loc);
       };
       vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
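
The chained builder call left as context above would take the same shape with
the free function. A sketch using only the operands visible in the surrounding
hunk (the tail of the original chain is outside this hunk's context):

  auto shuffle = gpu::ShuffleOp::create(builder, loc, val, srcIdxI32,
                                        warpSzI32, gpu::ShuffleMode::IDX);
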
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp
index 3b3dfe9..c567745 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp
@@ -34,10 +34,10 @@
                                                  Value workgroupCountX,
                                                  Value workgroupCountY) {
   Value linearized =
-      b.create<arith::MulIOp>(loc, workgroupIdY, workgroupCountX);
-  linearized = b.create<arith::AddIOp>(loc, linearized, workgroupIdX);
-  Value newX = b.create<arith::DivUIOp>(loc, linearized, workgroupCountY);
-  Value newY = b.create<arith::RemUIOp>(loc, linearized, workgroupCountY);
+      arith::MulIOp::create(b, loc, workgroupIdY, workgroupCountX);
+  linearized = arith::AddIOp::create(b, loc, linearized, workgroupIdX);
+  Value newX = arith::DivUIOp::create(b, loc, linearized, workgroupCountY);
+  Value newY = arith::RemUIOp::create(b, loc, linearized, workgroupCountY);
   return {newX, newY};
 }
 
@@ -55,17 +55,17 @@
                << "Using static workgroup counts: X = " << workgroupCounts[0]
                << ", Y = " << workgroupCounts[1] << "\n");
     Value workgroupCountX =
-        builder.create<arith::ConstantIndexOp>(loc, workgroupCounts[0]);
+        arith::ConstantIndexOp::create(builder, loc, workgroupCounts[0]);
     Value workgroupCountY =
-        builder.create<arith::ConstantIndexOp>(loc, workgroupCounts[1]);
+        arith::ConstantIndexOp::create(builder, loc, workgroupCounts[1]);
     return {workgroupCountX, workgroupCountY};
   }
 
   LLVM_DEBUG(llvm::dbgs() << "Using dynamic workgroup counts\n");
   Value dynamicCountX =
-      builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0, xBound);
+      IREE::HAL::InterfaceWorkgroupCountOp::create(builder, loc, 0, xBound);
   Value dynamicCountY =
-      builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1, yBound);
+      IREE::HAL::InterfaceWorkgroupCountOp::create(builder, loc, 1, yBound);
   return {dynamicCountX, dynamicCountY};
 }
 
@@ -101,10 +101,10 @@
   // that to RAUW the old ones. This way we don't have to worry about
   // picking the exact insertion points that do not violate dominance between
   // their defs and users.
-  Value workgroupIdX = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
-      funcOp.getLoc(), 0, oldXId.getUpperBound());
-  Value workgroupIdY = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
-      funcOp.getLoc(), 1, oldYId.getUpperBound());
+  Value workgroupIdX = IREE::HAL::InterfaceWorkgroupIDOp::create(
+      builder, funcOp.getLoc(), 0, oldXId.getUpperBound());
+  Value workgroupIdY = IREE::HAL::InterfaceWorkgroupIDOp::create(
+      builder, funcOp.getLoc(), 1, oldYId.getUpperBound());
   auto [workgroupCntX, workgroupCntY] = getWorkgroupCountsXY(
       builder, funcOp, oldXId.getUpperBound(), oldYId.getUpperBound());
   Value newWorkgroupIdX;
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
index bad4774..f8cfd88 100644
--- a/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEComprehensiveBufferizePass.cpp
@@ -65,7 +65,7 @@
       type = MemRefType::get(type.getShape(), type.getElementType(),
                              type.getLayout());
   }
-  return builder.create<memref::AllocOp>(loc, type, dynamicSizes).getResult();
+  return memref::AllocOp::create(builder, loc, type, dynamicSizes).getResult();
 }
 static LogicalResult defaultMemCpyFn(OpBuilder &builder, Location loc,
                                      Value from, Value to) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp
index ce95b61..d5618c9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEExpandStridedMetadata.cpp
@@ -202,11 +202,11 @@
       newBufferType = MemRefType::get(
           staticLinearShape, memRefType.getElementType(),
           MemRefLayoutAttrInterface(), memRefType.getMemorySpace());
-      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-      newBinding = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-          loc, newBufferType, binding.getLayoutAttr(), binding.getBindingAttr(),
-          zero, dynamicLinearShape, binding.getAlignmentAttr(),
-          binding.getDescriptorFlagsAttr());
+      Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+      newBinding = IREE::HAL::InterfaceBindingSubspanOp::create(
+          rewriter, loc, newBufferType, binding.getLayoutAttr(),
+          binding.getBindingAttr(), zero, dynamicLinearShape,
+          binding.getAlignmentAttr(), binding.getDescriptorFlagsAttr());
     }
     SmallVector<Value> results;
     results.reserve(memRefType.getRank() * 2 + 2);
@@ -215,8 +215,8 @@
       if (newBufferType == baseBufferType) {
         results.push_back(newBinding);
       } else {
-        Value reinterpretCast = rewriter.create<memref::ReinterpretCastOp>(
-            loc, baseBufferType, newBinding, /*offset=*/0,
+        Value reinterpretCast = memref::ReinterpretCastOp::create(
+            rewriter, loc, baseBufferType, newBinding, /*offset=*/0,
             /*sizes=*/ArrayRef<int64_t>(),
             /*strides=*/ArrayRef<int64_t>());
         results.push_back(reinterpretCast);
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREEInjectAssumeAlignment.cpp b/compiler/src/iree/compiler/Codegen/Common/IREEInjectAssumeAlignment.cpp
index 0dd833f..85ac629 100644
--- a/compiler/src/iree/compiler/Codegen/Common/IREEInjectAssumeAlignment.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/IREEInjectAssumeAlignment.cpp
@@ -41,8 +41,8 @@
   }
   Location loc = op.getLoc();
   rewriter.setInsertionPointAfter(op);
-  auto alignOp = rewriter.create<memref::AssumeAlignmentOp>(
-      loc, op.getResult(), op.calculateAlignment().value());
+  auto alignOp = memref::AssumeAlignmentOp::create(
+      rewriter, loc, op.getResult(), op.calculateAlignment().value());
   rewriter.replaceAllUsesExcept(op.getResult(), alignOp.getResult(), alignOp);
   return success();
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/InstrumentMemoryAccesses.cpp b/compiler/src/iree/compiler/Codegen/Common/InstrumentMemoryAccesses.cpp
index d399182..5e0962e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/InstrumentMemoryAccesses.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/InstrumentMemoryAccesses.cpp
@@ -42,41 +42,37 @@
           .Case<memref::LoadOp>([&](auto loadOp) {
             OpBuilder builder(loadOp);
             builder.setInsertionPointAfter(loadOp);
-            auto instrumentOp =
-                builder.create<IREE::HAL::InstrumentMemoryLoadOp>(
-                    loadOp.getLoc(), loadOp.getResult().getType(), buffer,
-                    workgroupKey, loadOp.getResult(), loadOp.getMemRef(),
-                    loadOp.getIndices());
+            auto instrumentOp = IREE::HAL::InstrumentMemoryLoadOp::create(
+                builder, loadOp.getLoc(), loadOp.getResult().getType(), buffer,
+                workgroupKey, loadOp.getResult(), loadOp.getMemRef(),
+                loadOp.getIndices());
             loadOp.getResult().replaceAllUsesExcept(instrumentOp.getResult(),
                                                     instrumentOp);
           })
           .Case<memref::StoreOp>([&](auto storeOp) {
             OpBuilder builder(storeOp);
-            auto instrumentOp =
-                builder.create<IREE::HAL::InstrumentMemoryStoreOp>(
-                    storeOp.getLoc(), storeOp.getValueToStore().getType(),
-                    buffer, workgroupKey, storeOp.getValueToStore(),
-                    storeOp.getMemRef(), storeOp.getIndices());
+            auto instrumentOp = IREE::HAL::InstrumentMemoryStoreOp::create(
+                builder, storeOp.getLoc(), storeOp.getValueToStore().getType(),
+                buffer, workgroupKey, storeOp.getValueToStore(),
+                storeOp.getMemRef(), storeOp.getIndices());
             storeOp.getValueMutable().assign(instrumentOp.getResult());
           })
           .Case<vector::LoadOp>([&](auto loadOp) {
             OpBuilder builder(loadOp);
             builder.setInsertionPointAfter(loadOp);
-            auto instrumentOp =
-                builder.create<IREE::HAL::InstrumentMemoryLoadOp>(
-                    loadOp.getLoc(), loadOp.getVectorType(), buffer,
-                    workgroupKey, loadOp.getResult(), loadOp.getBase(),
-                    loadOp.getIndices());
+            auto instrumentOp = IREE::HAL::InstrumentMemoryLoadOp::create(
+                builder, loadOp.getLoc(), loadOp.getVectorType(), buffer,
+                workgroupKey, loadOp.getResult(), loadOp.getBase(),
+                loadOp.getIndices());
             loadOp.getResult().replaceAllUsesExcept(instrumentOp.getResult(),
                                                     instrumentOp);
           })
           .Case<vector::StoreOp>([&](auto storeOp) {
             OpBuilder builder(storeOp);
-            auto instrumentOp =
-                builder.create<IREE::HAL::InstrumentMemoryStoreOp>(
-                    storeOp.getLoc(), storeOp.getVectorType(), buffer,
-                    workgroupKey, storeOp.getValueToStore(), storeOp.getBase(),
-                    storeOp.getIndices());
+            auto instrumentOp = IREE::HAL::InstrumentMemoryStoreOp::create(
+                builder, storeOp.getLoc(), storeOp.getVectorType(), buffer,
+                workgroupKey, storeOp.getValueToStore(), storeOp.getBase(),
+                storeOp.getIndices());
             storeOp.getValueToStoreMutable().assign(instrumentOp.getResult());
           })
           .Default([&](Operation *) {});
diff --git a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
index ab55fa2..baf7514 100644
--- a/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/LinkTuningSpecsPass.cpp
@@ -282,10 +282,10 @@
   //   transform.yield %arg0 : !transform.any_op
   // }
 
-  return builder.create<NamedSequenceOp>(loc, name, TypeAttr::get(specType),
-                                         /*sym_visibility=*/StringAttr{},
-                                         /*arg_attrs=*/ArrayAttr{},
-                                         /*res_attrs=*/ArrayAttr{});
+  return NamedSequenceOp::create(builder, loc, name, TypeAttr::get(specType),
+                                 /*sym_visibility=*/StringAttr{},
+                                 /*arg_attrs=*/ArrayAttr{},
+                                 /*res_attrs=*/ArrayAttr{});
 }
 
 static FailureOr<NamedSequenceOp>
@@ -363,7 +363,7 @@
                   .front();
   }
 
-  builder.create<transform::YieldOp>(loc, operand);
+  transform::YieldOp::create(builder, loc, operand);
 
   if (failed(mlir::verify(module))) {
     return module.emitError("Linked tuning spec failed to verify");
@@ -442,13 +442,13 @@
   Block *body = builder.createBlock(&region, region.begin(),
                                     newEntryPoint.getArgumentTypes(), loc);
   builder.setInsertionPointToStart(body);
-  auto mergedForeachMatch = builder.create<ForeachMatchOp>(
-      loc, resultTypes, newEntryPoint.getArgument(0),
+  auto mergedForeachMatch = ForeachMatchOp::create(
+      builder, loc, resultTypes, newEntryPoint.getArgument(0),
       /* forwarded_inputs = */ ValueRange(),
       /* restrictRoot = */ nullptr, /* flattenResults = */ nullptr,
       builder.getArrayAttr(mergedMatchers),
       builder.getArrayAttr(mergedActions));
-  builder.create<transform::YieldOp>(loc, mergedForeachMatch->getResult(0));
+  transform::YieldOp::create(builder, loc, mergedForeachMatch->getResult(0));
 
   // Step 3: Remove the original inner modules after merging.
   for (auto innerModule :
diff --git a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
index 2f412fe..0d94832 100644
--- a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
@@ -56,7 +56,7 @@
     if (!inputTy || !CastOpTy::areCastCompatible(inputTy, resultType)) {
       return Value();
     }
-    return builder.create<CastOpTy>(loc, resultType, input).getResult();
+    return CastOpTy::create(builder, loc, resultType, input).getResult();
   });
   converter.addTargetMaterialization([](OpBuilder &builder, TargetTy resultType,
                                         ValueRange inputs,
@@ -69,7 +69,7 @@
     if (!inputTy || !CastOpTy::areCastCompatible(inputTy, resultType)) {
       return Value();
     }
-    return builder.create<CastOpTy>(loc, resultType, input).getResult();
+    return CastOpTy::create(builder, loc, resultType, input).getResult();
   });
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp
index a59fecf..911f083 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp
@@ -61,15 +61,15 @@
     return rewriter.notifyMatchFailure(
         encodingOp, "failed to generate runtime tile size query");
   }
-  Value paddingValue = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getZeroAttr(resultType.getElementType()));
+  Value paddingValue = arith::ConstantOp::create(
+      rewriter, loc, rewriter.getZeroAttr(resultType.getElementType()));
   SmallVector<OpFoldResult> sourceDims =
       tensor::getMixedSizes(rewriter, loc, source);
   SmallVector<OpFoldResult> resultDims = linalg::PackOp::getResultShape(
       rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
       encodingInfo.outerDimsPerm);
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
-                                                  resultType.getElementType());
+  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, resultDims,
+                                         resultType.getElementType());
   return rewriter
       .create<linalg::PackOp>(loc, source, emptyOp, encodingInfo.innerDimsPos,
                               *innerTileSizesOfr, paddingValue,
@@ -94,8 +94,8 @@
   SmallVector<OpFoldResult> resultDims =
       getMixedValues(encodingOp.getResultType().getShape(),
                      encodingOp.getResultDims(), rewriter);
-  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
-                                                  sourceType.getElementType());
+  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, resultDims,
+                                         sourceType.getElementType());
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
       typeConverter.getInnerTileSizesOfr(rewriter, loc, sourceType,
                                          encodingInfo);
@@ -141,8 +141,8 @@
       rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
       encodingInfo.outerDimsPerm);
   newShape = getSwizzledShape(newShape, encodingInfo);
-  Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, newShape, emptyType.getElementType());
+  Operation *newEmptyOp = tensor::EmptyOp::create(rewriter, loc, newShape,
+                                                  emptyType.getElementType());
   return newEmptyOp;
 }
 
@@ -415,9 +415,10 @@
             convertedResultType.getShape(),
             cast<RankedTensorType>(t).getElementType());
       });
-  auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
-      genericOp.getLoc(), convertedResultTypes, convertedInputOperands,
-      convertedOutputOperands, packedIndexingMaps, iteratorTypes,
+  auto materializedGenericOp = linalg::GenericOp::create(
+      rewriter, genericOp.getLoc(), convertedResultTypes,
+      convertedInputOperands, convertedOutputOperands, packedIndexingMaps,
+      iteratorTypes,
       /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
   rewriter.inlineRegionBefore(genericOp.getRegion(),
                               materializedGenericOp.getRegion(),
@@ -447,8 +448,8 @@
   return TypeSwitch<Operation *, FailureOr<Operation *>>(linalgOp)
       .Case<linalg::FillOp>(
           [&](linalg::FillOp fillOp) -> FailureOr<Operation *> {
-            Operation *materializedFillOp = rewriter.create<linalg::FillOp>(
-                fillOp.getLoc(), convertedOutputOperands[0].getType(),
+            Operation *materializedFillOp = linalg::FillOp::create(
+                rewriter, fillOp.getLoc(), convertedOutputOperands[0].getType(),
                 convertedInputOperands, convertedOutputOperands);
             return materializedFillOp;
           })
@@ -755,8 +756,8 @@
 
     SmallVector<ReassociationIndices> reassociation =
         getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
-    auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandShapeType, packedValue.value(), reassociation);
+    auto expandShapeOp = tensor::ExpandShapeOp::create(
+        rewriter, loc, expandShapeType, packedValue.value(), reassociation);
 
     SmallVector<int64_t> transposePerm =
         llvm::to_vector(llvm::seq<int64_t>(0, origRank));
@@ -767,10 +768,11 @@
         tensor::getMixedSizes(rewriter, loc, expandShapeOp.getResult());
     applyPermutationToVector(transposeResultDims, transposePerm);
 
-    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, transposeResultDims, encodingOp.getSourceType().getElementType());
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, expandShapeOp, emptyTensor, transposePerm);
+    auto emptyTensor =
+        tensor::EmptyOp::create(rewriter, loc, transposeResultDims,
+                                encodingOp.getSourceType().getElementType());
+    auto transposeOp = linalg::TransposeOp::create(rewriter, loc, expandShapeOp,
+                                                   emptyTensor, transposePerm);
     rewriter.replaceOp(encodingOp, transposeOp->getResult(0));
 
     return success();
@@ -808,8 +810,9 @@
       for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
         emptyShape.push_back(rewriter.getIndexAttr(i));
       }
-      auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, emptyShape, unsetEncodingOp.getSourceType().getElementType());
+      auto emptyTensor = tensor::EmptyOp::create(
+          rewriter, loc, emptyShape,
+          unsetEncodingOp.getSourceType().getElementType());
 
       SmallVector<int64_t> transposePerm =
           llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
@@ -817,8 +820,9 @@
         transposePerm.push_back(targetRank + perm);
       }
       auto invertedTransposePerm = invertPermutationVector(transposePerm);
-      auto transposeOp = rewriter.create<linalg::TransposeOp>(
-          loc, adaptor.getSource(), emptyTensor, invertedTransposePerm);
+      auto transposeOp =
+          linalg::TransposeOp::create(rewriter, loc, adaptor.getSource(),
+                                      emptyTensor, invertedTransposePerm);
 
       SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
           targetRank, encodingInfo.swizzle->expandShape);
@@ -828,8 +832,9 @@
                             encodingInfo.innerTileSizes.end());
       RankedTensorType unpackSrcType =
           unsetEncodingOp.getResultType().clone(unpackSrcShape);
-      unpackSrc = rewriter.create<tensor::CollapseShapeOp>(
-          loc, unpackSrcType, transposeOp->getResult(0), reassociation);
+      unpackSrc = tensor::CollapseShapeOp::create(rewriter, loc, unpackSrcType,
+                                                  transposeOp->getResult(0),
+                                                  reassociation);
     }
 
     auto unpackedValue = lowerUnsetEncodingToUnpackOp(rewriter, unsetEncodingOp,
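
One reason the swap stays NFC: `OpTy::create` returns the concrete op class,
not a plain `Operation *`, so typed accessors and `->` chaining in the hunks
above compile unchanged. For example, straight from this file:

    auto transposeOp = linalg::TransposeOp::create(rewriter, loc, expandShapeOp,
                                                   emptyTensor, transposePerm);
    rewriter.replaceOp(encodingOp, transposeOp->getResult(0));
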
diff --git a/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp
index 18f38d2..86d7897 100644
--- a/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/NormalizeLoopBounds.cpp
@@ -154,8 +154,8 @@
   }
 
   rewriter.setInsertionPointAfter(forallOp);
-  auto newLoop = rewriter.create<scf::ForallOp>(
-      rewriter.getUnknownLoc(), newLoopParams->lowerBounds,
+  auto newLoop = scf::ForallOp::create(
+      rewriter, rewriter.getUnknownLoc(), newLoopParams->lowerBounds,
       newLoopParams->upperBounds, newLoopParams->steps, forallOp.getOutputs(),
       forallOp.getMapping());
   rewriter.eraseOp(newLoop.getTerminator());
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
index a007c4f..3bc565a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
@@ -337,8 +337,8 @@
            "GenericVectorization.cpp::FoldMaskedTransferRAW for information");
 
     // Materialize the padding with a constant.
-    auto padVal = rewriter.create<vector::BroadcastOp>(
-        rPad.getLoc(), valToStore.getType(), rPad);
+    auto padVal = vector::BroadcastOp::create(rewriter, rPad.getLoc(),
+                                              valToStore.getType(), rPad);
     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, wMask, valToStore, padVal);
     return success();
   }
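
`rewriter.replaceOpWithNewOp<...>` call sites are left alone throughout this
chunk; the deprecation targets `OpBuilder::create`, and the rewriter helper
keeps its template form, as in the unchanged line above:

    // Unchanged: a rewriter helper, not an OpBuilder::create call.
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, wMask, valToStore, padVal);
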
diff --git a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
index ad49b26..84f4d08 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/PadDynamicAlloc.cpp
@@ -95,11 +95,11 @@
   MemRefType allocType = MemRefType::get(shape, elType, AffineMap(),
                                          allocOp.getType().getMemorySpace());
   Location loc = allocOp.getLoc();
-  Value paddedAlloc = rewriter.create<AllocLikeOp>(loc, allocType);
+  Value paddedAlloc = AllocLikeOp::create(rewriter, loc, allocType);
   SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
   SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
-  Value subview = rewriter.create<memref::SubViewOp>(loc, paddedAlloc, offsets,
-                                                     sizes, strides);
+  Value subview = memref::SubViewOp::create(rewriter, loc, paddedAlloc, offsets,
+                                            sizes, strides);
   replaceMemrefUsesAndPropagateType(rewriter, loc, allocOp, subview);
   rewriter.eraseOp(allocOp);
   return success();
diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp
index fde88eb..1ec3343 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/PropagateConstantOffsets.cpp
@@ -81,10 +81,10 @@
 
     AffineMap newMap =
         AffineMap::get(map.getNumDims(), map.getNumSymbols(), addExpr.getLHS());
-    Value newApply = rewriter.create<affine::AffineApplyOp>(
-        apply.getLoc(), newMap, apply.getOperands());
-    Value offset =
-        rewriter.create<arith::ConstantIndexOp>(apply.getLoc(), constantOffset);
+    Value newApply = affine::AffineApplyOp::create(rewriter, apply.getLoc(),
+                                                   newMap, apply.getOperands());
+    Value offset = arith::ConstantIndexOp::create(rewriter, apply.getLoc(),
+                                                  constantOffset);
     rewriter.replaceOpWithNewOp<arith::AddIOp>(
         apply, newApply, offset, arith::IntegerOverflowFlags::nsw);
     return success();
@@ -159,7 +159,7 @@
     auto getZero = [&]() {
       if (zero)
         return zero;
-      zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+      zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
       return zero;
     };
     bool didReplace = false;
@@ -220,9 +220,10 @@
 
     rewriter.setInsertionPointAfter(op);
     Value offset =
-        rewriter.create<arith::ConstantIndexOp>(op.getLoc(), runningOffset);
-    auto addOp = rewriter.create<arith::AddIOp>(
-        op.getLoc(), op.getResult(), offset, arith::IntegerOverflowFlags::nsw);
+        arith::ConstantIndexOp::create(rewriter, op.getLoc(), runningOffset);
+    auto addOp =
+        arith::AddIOp::create(rewriter, op.getLoc(), op.getResult(), offset,
+                              arith::IntegerOverflowFlags::nsw);
     rewriter.replaceAllUsesExcept(op, addOp, addOp);
     return success();
   }
@@ -253,7 +254,7 @@
     auto getZero = [&]() {
       if (zero)
         return zero;
-      zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+      zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
       return zero;
     };
 
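Lazily materialized constants migrate the same way; only the creation
expression inside the caching lambda changes. A sketch of the `getZero` idiom
above (captures of `rewriter` and `op` assumed):

    Value zero;  // created at most once, cached across uses
    auto getZero = [&]() {
      if (zero)
        return zero;
      zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
      return zero;
    };
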
diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
index 5331c57..93fd7f6 100644
--- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp
@@ -260,17 +260,20 @@
     auto expandedDestType =
         cast<RankedTensorType>(forallOutputs[tiedResultIdx].getType())
             .clone(expandedDestShape);
-    auto expandedDest = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandedDestType, forallOutputs[tiedResultIdx], reIndices);
+    auto expandedDest =
+        tensor::ExpandShapeOp::create(rewriter, loc, expandedDestType,
+                                      forallOutputs[tiedResultIdx], reIndices);
 
     forallOutputs[tiedResultIdx] = expandedDest;
 
-    scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
-        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
-        forallOp.getMixedStep(), forallOutputs, forallOp.getMappingAttr());
+    scf::ForallOp newForallOp = scf::ForallOp::create(
+        rewriter, loc, forallOp.getMixedLowerBound(),
+        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), forallOutputs,
+        forallOp.getMappingAttr());
 
-    auto collapsedResultOp = rewriter.create<tensor::CollapseShapeOp>(
-        loc, cast<ShapedType>(forallOp->getResult(tiedResultIdx).getType()),
+    auto collapsedResultOp = tensor::CollapseShapeOp::create(
+        rewriter, loc,
+        cast<ShapedType>(forallOp->getResult(tiedResultIdx).getType()),
         newForallOp->getResult(tiedResultIdx), reIndices);
 
     // Merge the old scf.forall block which has the expanded users into the new
@@ -353,8 +356,8 @@
     dispatchIndexOpFoldResults(newMixedSizes, sliceSourceDynamicSizes,
                                sliceSourceStaticSizes);
 
-    Value newBitcast = rewriter.create<IREE::TensorExt::BitCastOp>(
-        bitcastOp.getLoc(), newBitcastType, sliceOp.getSource(),
+    Value newBitcast = IREE::TensorExt::BitCastOp::create(
+        rewriter, bitcastOp.getLoc(), newBitcastType, sliceOp.getSource(),
         sliceSourceDynamicSizes, sliceSourceDynamicSizes);
     SmallVector<int64_t> newSizes(sliceOp.getStaticSizes());
     newSizes.back() = newInnerSize;
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
index e96b73f..049638d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp
@@ -605,17 +605,17 @@
     return forallOp->emitOpError("failed to lower split reduction modifier op");
   }
 
-  auto procIdOp = rewriter.create<IREE::HAL::InterfaceWorkgroupIDOp>(
-      loc, static_cast<unsigned>(delinearizeFrom));
-  auto nTotalProcsOp = rewriter.create<IREE::HAL::InterfaceWorkgroupCountOp>(
-      loc, static_cast<unsigned>(delinearizeFrom));
+  auto procIdOp = IREE::HAL::InterfaceWorkgroupIDOp::create(
+      rewriter, loc, static_cast<unsigned>(delinearizeFrom));
+  auto nTotalProcsOp = IREE::HAL::InterfaceWorkgroupCountOp::create(
+      rewriter, loc, static_cast<unsigned>(delinearizeFrom));
   OpFoldResult nTotalProcs = nTotalProcsOp.getResult();
   OpFoldResult origNProcs = affine::makeComposedFoldedAffineApply(
       rewriter, loc, s0.floorDiv(s1), {nTotalProcs, nSplitProcs});
   SmallVector<OpFoldResult> basis = numIters;
   basis.push_back(origNProcs);
-  auto delinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-      loc, procIdOp.getResult(), basis);
+  auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
+      rewriter, loc, procIdOp.getResult(), basis);
 
   Value workgroupIdReplacement = delinearizeOp.getResults().back();
 
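Results of the free create feed `Value`/`OpFoldResult`-based helpers such as
`affine::makeComposedFoldedAffineApply` exactly as before. From the hunk
above:

    auto nTotalProcsOp = IREE::HAL::InterfaceWorkgroupCountOp::create(
        rewriter, loc, static_cast<unsigned>(delinearizeFrom));
    OpFoldResult nTotalProcs = nTotalProcsOp.getResult();
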
diff --git a/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp
index 8105452..2f4777d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ReshapePatterns.cpp
@@ -128,8 +128,8 @@
     auto newSubspanType = IREE::TensorExt::DispatchTensorType::get(
         tensorAccess, reshapeOp.getResultType());
 
-    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+    Value newSubspanOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+        rewriter, subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(),
         collapsedDynamicShape, subspanOp.getAlignmentAttr(),
         subspanOp.getDescriptorFlagsAttr());
@@ -229,8 +229,8 @@
                                expandedStaticDims);
 
     Value newSubspanOp;
-    newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+    newSubspanOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+        rewriter, subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicDims,
         subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
 
@@ -313,8 +313,8 @@
     auto newSubspanType = IREE::TensorExt::DispatchTensorType::get(
         tensorAccess, reshapeSrc.getType());
 
-    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+    Value newSubspanOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+        rewriter, subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(),
         collapsedDynamicShape, subspanOp.getAlignmentAttr(),
         subspanOp.getDescriptorFlagsAttr());
@@ -497,8 +497,8 @@
     auto newSubspanType = IREE::TensorExt::DispatchTensorType::get(
         tensorAccess, reshapeSrcType.cloneWith(
                           newSubspanShape, reshapeSrcType.getElementType()));
-    auto newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+    auto newSubspanOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+        rewriter, subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicShape,
         subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
 
@@ -795,8 +795,8 @@
     {
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.setInsertionPointAfter(subspanOp);
-      newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-          subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+      newSubspanOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+          rewriter, subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
           subspanOp.getBinding(), subspanOp.getByteOffset(),
           newSubspanDynamicDims, subspanOp.getAlignmentAttr(),
           subspanOp.getDescriptorFlagsAttr());
@@ -872,8 +872,8 @@
     // Byte offset and byte alignment remain the same after folding the cast.
     // Simply create a new binding with the new type.
     rewriter.setInsertionPoint(subspanOp);
-    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+    Value newSubspanOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+        rewriter, subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(),
         subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
         subspanOp.getDescriptorFlagsAttr());
@@ -942,8 +942,8 @@
         subspanType.getAccess(), newSubspanTensorType);
 
     rewriter.setInsertionPointAfter(subspanOp);
-    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+    Value newSubspanOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+        rewriter, subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(),
         subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
         subspanOp.getDescriptorFlagsAttr());
@@ -983,7 +983,7 @@
   rewriter.setInsertionPointAfterValue(memref);
   Location loc = tensorToMemrefOp.getLoc();
   Value collapsedMemref =
-      rewriter.create<memref::CollapseShapeOp>(loc, memref, reassociations);
+      memref::CollapseShapeOp::create(rewriter, loc, memref, reassociations);
   rewriter.modifyOpInPlace(tensorToMemrefOp, [&]() {
     tensorToMemrefOp.getBufferMutable().assign(collapsedMemref);
   });
@@ -1018,8 +1018,9 @@
   OpBuilder::InsertionGuard g(rewriter);
   dynamicValues.push_back(memref);
   setInsertionPointAfterLastValue(rewriter, dynamicValues);
-  Value expandedMemref = rewriter.create<memref::ExpandShapeOp>(
-      loc, *expandedMemrefType, memref, reassociations, mixedOutputShape);
+  Value expandedMemref =
+      memref::ExpandShapeOp::create(rewriter, loc, *expandedMemrefType, memref,
+                                    reassociations, mixedOutputShape);
   rewriter.modifyOpInPlace(tensorToMemrefOp, [&]() {
     tensorToMemrefOp.getBufferMutable().assign(expandedMemref);
   });
diff --git a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
index 3ace270..bc2f6ae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp
@@ -37,14 +37,15 @@
   if (auto add = v.getDefiningOp<arith::AddIOp>()) {
     llvm::APInt constant;
     if (matchPattern(add.getRhs(), m_ConstantInt(&constant))) {
-      Value combined = rewriter.create<arith::ConstantIndexOp>(
-          add.getLoc(), offset + constant.getSExtValue());
-      return rewriter.create<arith::AddIOp>(add.getLoc(), add.getLhs(),
-                                            combined, add.getOverflowFlags());
+      Value combined = arith::ConstantIndexOp::create(
+          rewriter, add.getLoc(), offset + constant.getSExtValue());
+      return arith::AddIOp::create(rewriter, add.getLoc(), add.getLhs(),
+                                   combined, add.getOverflowFlags());
     }
   }
-  Value offsetVal = rewriter.create<arith::ConstantIndexOp>(v.getLoc(), offset);
-  return rewriter.create<arith::AddIOp>(v.getLoc(), v, offsetVal);
+  Value offsetVal =
+      arith::ConstantIndexOp::create(rewriter, v.getLoc(), offset);
+  return arith::AddIOp::create(rewriter, v.getLoc(), v, offsetVal);
 }
 
 /// Swizzles vector.load(iree_codegen.swizzle_hint, offset). The
@@ -75,8 +76,8 @@
       VectorType::get({accessWidth}, type.getElementType());
 
   // ~ vector.undef, overwritten by unrolling.
-  Value replacement = rewriter.create<arith::ConstantOp>(
-      hintLoc, type, rewriter.getZeroAttr(type));
+  Value replacement = arith::ConstantOp::create(rewriter, hintLoc, type,
+                                                rewriter.getZeroAttr(type));
 
   // Load type = vector<C>, k = accessWidth
   // i ranges over [0, C) in steps of k; each i is the offset into the vector
   // of a contiguous group of swizzled elements.
         rewriter, hintLoc,
         hintOp.getSwizzle().swizzleOffset(rewriter, hintOp.getLoc(),
                                           newBaseOffset, hintOp.getOperand()));
-    auto subLoad = rewriter.create<vector::LoadOp>(
-        load.getLoc(), swizzledLoadType, load.getBase(), newOffset);
+    auto subLoad = vector::LoadOp::create(
+        rewriter, load.getLoc(), swizzledLoadType, load.getBase(), newOffset);
 
-    replacement = rewriter.create<vector::InsertStridedSliceOp>(
-        load.getLoc(), subLoad, replacement, ArrayRef<int64_t>{i},
+    replacement = vector::InsertStridedSliceOp::create(
+        rewriter, load.getLoc(), subLoad, replacement, ArrayRef<int64_t>{i},
         ArrayRef<int64_t>{1});
   }
   rewriter.replaceOp(load, replacement);
@@ -121,8 +122,8 @@
   // i ranges over [0, C) in steps of k; each i is the offset into the vector
   // of a contiguous group of swizzled elements.
   for (int64_t i = 0; i < storeWidth; i += accessWidth) {
-    Value subVec = rewriter.create<vector::ExtractStridedSliceOp>(
-        store.getLoc(), store.getValueToStore(), ArrayRef<int64_t>{i},
+    Value subVec = vector::ExtractStridedSliceOp::create(
+        rewriter, store.getLoc(), store.getValueToStore(), ArrayRef<int64_t>{i},
         ArrayRef<int64_t>{accessWidth}, ArrayRef<int64_t>{1});
     Value newBaseOffset = createOrFoldNewStaticAdd(rewriter, memrefOffset, i);
 
@@ -130,8 +131,8 @@
         rewriter, hintLoc,
         hintOp.getSwizzle().swizzleOffset(rewriter, hintOp.getLoc(),
                                           newBaseOffset, hintOp.getOperand()));
-    rewriter.create<vector::StoreOp>(store.getLoc(), subVec, store.getBase(),
-                                     newOffset);
+    vector::StoreOp::create(rewriter, store.getLoc(), subVec, store.getBase(),
+                            newOffset);
   }
   rewriter.eraseOp(store);
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/SpecializeExports.cpp b/compiler/src/iree/compiler/Codegen/Common/SpecializeExports.cpp
index a2d9848..05ada9b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/SpecializeExports.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/SpecializeExports.cpp
@@ -289,7 +289,7 @@
       builder.setInsertionPointToStart(newCondition);
 
       Value exportCondition =
-          builder.create<arith::ConstantIntOp>(loc, builder.getI1Type(), 1);
+          arith::ConstantIntOp::create(builder, loc, builder.getI1Type(), 1);
 
       for (auto [range, assumedSize] :
            llvm::zip(specializationRange, workloadMapping)) {
@@ -300,32 +300,32 @@
         // +1 for the device.
         Value workload =
             newCondition->getArgument(assumedSize.workloadOrdinal + 1);
-        Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+        Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
 
         if (range.getUmin().has_value()) {
-          Value uminVal = builder.create<arith::ConstantIndexOp>(
-              loc, range.getUmin().value());
-          Value cmp = builder.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::ule, uminVal, workload);
+          Value uminVal = arith::ConstantIndexOp::create(
+              builder, loc, range.getUmin().value());
+          Value cmp = arith::CmpIOp::create(
+              builder, loc, arith::CmpIPredicate::ule, uminVal, workload);
           exportCondition =
-              builder.create<arith::AndIOp>(loc, cmp, exportCondition);
+              arith::AndIOp::create(builder, loc, cmp, exportCondition);
         }
         if (range.getUmax().has_value()) {
-          Value umaxVal = builder.create<arith::ConstantIndexOp>(
-              loc, range.getUmax().value());
-          Value cmp = builder.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::uge, umaxVal, workload);
+          Value umaxVal = arith::ConstantIndexOp::create(
+              builder, loc, range.getUmax().value());
+          Value cmp = arith::CmpIOp::create(
+              builder, loc, arith::CmpIPredicate::uge, umaxVal, workload);
           exportCondition =
-              builder.create<arith::AndIOp>(loc, cmp, exportCondition);
+              arith::AndIOp::create(builder, loc, cmp, exportCondition);
         }
         if (range.getUdiv().has_value()) {
-          Value udivVal = builder.create<arith::ConstantIndexOp>(
-              loc, range.getUdiv().value());
-          Value rem = builder.create<arith::RemUIOp>(loc, workload, udivVal);
-          Value cmp = builder.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::eq, rem, zero);
+          Value udivVal = arith::ConstantIndexOp::create(
+              builder, loc, range.getUdiv().value());
+          Value rem = arith::RemUIOp::create(builder, loc, workload, udivVal);
+          Value cmp = arith::CmpIOp::create(
+              builder, loc, arith::CmpIPredicate::eq, rem, zero);
           exportCondition =
-              builder.create<arith::AndIOp>(loc, cmp, exportCondition);
+              arith::AndIOp::create(builder, loc, cmp, exportCondition);
         }
 
         if (auto originalAssumeOp = llvm::dyn_cast<IREE::Util::AssumeIntOp>(
@@ -364,7 +364,7 @@
         }
       }
 
-      builder.create<IREE::HAL::ReturnOp>(loc, exportCondition);
+      IREE::HAL::ReturnOp::create(builder, loc, exportCondition);
     }
     // Current function is still the original function, just with a new symbol
     // name.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp
index ed748f7..ce78f2a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TensorToVectorVectorizePad.cpp
@@ -134,8 +134,8 @@
       auto srcDimSize =
           rewriter.createOrFold<tensor::DimOp>(loc, padOp.getSource(), i);
       auto lb = getAsIndexValue(lowPads[i], rewriter, loc);
-      auto ub = rewriter.create<affine::AffineApplyOp>(
-          loc, addMap, ValueRange{lb, srcDimSize});
+      auto ub = affine::AffineApplyOp::create(rewriter, loc, addMap,
+                                              ValueRange{lb, srcDimSize});
       paddedDimLBs[i] = lb;
       paddedDimUBs[i] = ub;
     }
@@ -198,31 +198,32 @@
 
       // Need to subtract the low padding to get the index into the source.
       for (int dim : paddedDimIndices) {
-        readIndices[dim] = rewriter.create<affine::AffineApplyOp>(
-            loc, subMap, ValueRange{valueIndices[dim], paddedDimLBs[dim]});
+        readIndices[dim] = affine::AffineApplyOp::create(
+            rewriter, loc, subMap,
+            ValueRange{valueIndices[dim], paddedDimLBs[dim]});
       }
 
-      auto ifOp = rewriter.create<scf::IfOp>(
-          loc, condition,
+      auto ifOp = scf::IfOp::create(
+          rewriter, loc, condition,
           [&](OpBuilder builder, Location Loc) {
-            Value read = builder.create<vector::TransferReadOp>(
-                loc, sliceVectorType, padOp.getSource(), readIndices,
+            Value read = vector::TransferReadOp::create(
+                builder, loc, sliceVectorType, padOp.getSource(), readIndices,
                 paddingValue, llvm::ArrayRef(inBounds));
-            builder.create<scf::YieldOp>(loc, read);
+            scf::YieldOp::create(builder, loc, read);
           },
           [&](OpBuilder builder, Location Loc) {
-            builder.create<scf::YieldOp>(loc, cstSliceVector);
+            scf::YieldOp::create(builder, loc, cstSliceVector);
           });
 
       // Insert this slice back to the full vector.
-      fullVector = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, ifOp.getResult(0), fullVector,
+      fullVector = vector::InsertStridedSliceOp::create(
+          rewriter, loc, ifOp.getResult(0), fullVector,
           llvm::ArrayRef(staticIndices).take_back(fullVectorType.getRank()),
           staticStrides);
     }
 
-    Value fullTensor = rewriter.create<tensor::EmptyOp>(
-        loc, paddedTensorShape, elementType, ValueRange());
+    Value fullTensor = tensor::EmptyOp::create(rewriter, loc, paddedTensorShape,
+                                               elementType, ValueRange());
     valueIndices.assign(tensorRank, zeroIndex);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         padOp, fullVector, fullTensor, valueIndices);
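
Two details in this file are worth calling out. Creates inside region-builder
callbacks use the callback's builder rather than the outer rewriter, and
`rewriter.createOrFold<...>` calls (e.g. the `tensor::DimOp` above) are left
as-is in this chunk. A sketch of the callback shape, with hypothetical yield
values:

    auto ifOp = scf::IfOp::create(
        rewriter, loc, condition,
        [&](OpBuilder &b, Location l) {
          scf::YieldOp::create(b, l, thenValue);  // hypothetical thenValue
        },
        [&](OpBuilder &b, Location l) {
          scf::YieldOp::create(b, l, elseValue);  // hypothetical elseValue
        });
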
diff --git a/compiler/src/iree/compiler/Codegen/Common/TestPartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TestPartitionableLoopsInterface.cpp
index 329547f..cfee586 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TestPartitionableLoopsInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TestPartitionableLoopsInterface.cpp
@@ -39,8 +39,8 @@
     auto type =
         RankedTensorType::get(partitionableLoops.size(), rewriter.getI32Type());
     auto constantAttr = DenseIntElementsAttr::get(type, partitionableLoops);
-    rewriter.create<IREE::Util::UnfoldableConstantOp>(interfaceOp.getLoc(),
-                                                      constantAttr);
+    IREE::Util::UnfoldableConstantOp::create(rewriter, interfaceOp.getLoc(),
+                                             constantAttr);
     rewriter.modifyOpInPlace(interfaceOp,
                              [&] { interfaceOp->removeAttr(kAttributeName); });
     return success();
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index 0c84a4f..9fa160c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -199,7 +199,7 @@
     }
     numWorkgroups.push_back(numTileAlongDim);
   }
-  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
   numWorkgroups.resize(workgroupCountOp.getNumResults(), one);
   rewriter.replaceOp(workgroupCountOp, numWorkgroups);
   return success();
@@ -366,7 +366,7 @@
     // Check if tile sizes are deduced from the configuration. If so use
     // those.
     return llvm::map_to_vector(tileSizes, [&](int64_t ts) -> Value {
-      return builder.create<arith::ConstantIndexOp>(op->getLoc(), ts);
+      return arith::ConstantIndexOp::create(builder, op->getLoc(), ts);
     });
   };
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
index 3118560..8ea7ea3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp
@@ -169,12 +169,12 @@
     Value lbVal = getValueOrCreateConstantIndexOp(builder, loc, lb);
     Value ubVal = getValueOrCreateConstantIndexOp(builder, loc, ub);
     Value stepVal = getValueOrCreateConstantIndexOp(builder, loc, step);
-    auto loop = builder.create<scf::ForOp>(
-        loc, lbVal, ubVal, stepVal, ValueRange{},
+    auto loop = scf::ForOp::create(
+        builder, loc, lbVal, ubVal, stepVal, ValueRange{},
         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
             ValueRange /*iterArgs*/) {
           sizes[index] = createBoundedTileSize(iv, tileSizeVals[index], ub);
-          builder.create<scf::YieldOp>(loc);
+          scf::YieldOp::create(builder, loc);
         });
     offsets[index] = loop.getInductionVar();
     loops.push_back(loop);
@@ -232,8 +232,8 @@
         "failed to create tiled iree_tensor_ext.dispatch.tensor.store op");
   }
 
-  rewriter.create<IREE::TensorExt::DispatchTensorStoreOp>(
-      storeOp.getLoc(), tiledValue, storeOp.getTarget(),
+  IREE::TensorExt::DispatchTensorStoreOp::create(
+      rewriter, storeOp.getLoc(), tiledValue, storeOp.getTarget(),
       clonedSliceAndVals.dynamicDims, combinedOffsets, combinedSizes,
       combinedStrides);
   rewriter.eraseOp(storeOp);
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 9839bbc..aad2280 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -240,12 +240,12 @@
                                      "Non tensor type operand to copy");
   }
   rewriter.setInsertionPoint(target);
-  Value empty = rewriter.create<tensor::EmptyOp>(
-      target->getLoc(),
+  Value empty = tensor::EmptyOp::create(
+      rewriter, target->getLoc(),
       tensor::getMixedSizes(rewriter, target->getLoc(), operand),
       tensorType.getElementType());
   Operation *copy =
-      rewriter.create<linalg::CopyOp>(target->getLoc(), operand, empty);
+      linalg::CopyOp::create(rewriter, target->getLoc(), operand, empty);
   target->setOperand(operandIndex, copy->getResult(0));
   results.push_back(copy);
   return DiagnosedSilenceableFailure::success();
@@ -345,17 +345,17 @@
   OpFoldResult one = rewriter.getIndexAttr(1);
 
   // Step 3. Create a new parallel loop with a single mapping id.
-  auto newForallOp = rewriter.create<scf::ForallOp>(
-      loc, ArrayRef<OpFoldResult>{zero}, ArrayRef<OpFoldResult>{newUpperBound},
-      ArrayRef<OpFoldResult>{one}, forallOp.getOutputs(),
-      rewriter.getArrayAttr({flatMapping}));
+  auto newForallOp = scf::ForallOp::create(
+      rewriter, loc, ArrayRef<OpFoldResult>{zero},
+      ArrayRef<OpFoldResult>{newUpperBound}, ArrayRef<OpFoldResult>{one},
+      forallOp.getOutputs(), rewriter.getArrayAttr({flatMapping}));
 
   rewriter.setInsertionPointToStart(newForallOp.getBody());
   Value linearId = newForallOp.getInductionVar(0);
 
   // Step 4. Delinearize the flat ID to the original basis.
-  auto ids = rewriter.create<affine::AffineDelinearizeIndexOp>(
-      loc, linearId, forallOp.getMixedUpperBound());
+  auto ids = affine::AffineDelinearizeIndexOp::create(
+      rewriter, loc, linearId, forallOp.getMixedUpperBound());
 
   // Step 5. Inline the region of the original forall op.
   SmallVector<Value> newArgs(ids.getResults());
@@ -431,7 +431,7 @@
   for (auto attr : {bX, bY, bZ}) {
     if (!llvm::is_contained(blockMapping, attr)) {
       blockMapping.push_back(attr);
-      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      one = one ? one : arith::ConstantIndexOp::create(rewriter, loc, 1);
       numBlocks.push_back(one);
     }
   }
@@ -452,9 +452,9 @@
     auto idx = static_cast<int64_t>(
         llvm::cast<gpu::GPUBlockMappingAttr>(attr).getBlock());
     workgroupIdOps.push_back(
-        rewriter.create<HAL::InterfaceWorkgroupIDOp>(loc, idx));
+        HAL::InterfaceWorkgroupIDOp::create(rewriter, loc, idx));
     workgroupCountOps.push_back(
-        rewriter.create<HAL::InterfaceWorkgroupCountOp>(loc, idx));
+        HAL::InterfaceWorkgroupCountOp::create(rewriter, loc, idx));
   }
   bvm.map(forallOp.getInductionVars(), workgroupIdOps);
   bvm.map(forallOp.getUpperBound(rewriter), workgroupCountOps);
@@ -781,7 +781,7 @@
   // as an OpDSL named op. However, IREE-specific patterns to clean up spurious
   // post-bufferization copies do not trigger properly.
   // So we keep using `createLinalgCopyOp` which builds a GenericOp.
-  // builder.create<linalg::CopyOp>(loc, from, to);
+  // linalg::CopyOp::create(builder, loc, from, to);
   mlir::iree_compiler::createLinalgCopyOp(builder, loc, from, to);
   return success();
 }
@@ -811,15 +811,15 @@
     needsBarrier = true;
   }
   if (needsBarrier)
-    builder.create<gpu::BarrierOp>(loc);
+    gpu::BarrierOp::create(builder, loc);
   // TODO: ideally we should use linalg.copy which was recently reintroduced
   // as an OpDSL named op. However, IREE-specific patterns to clean up spurious
   // post-bufferization copies do not trigger properly.
   // So we keep using `createLinalgCopyOp` which builds a GenericOp.
-  // builder.create<linalg::CopyOp>(loc, from, to);
+  // linalg::CopyOp::create(builder, loc, from, to);
   mlir::iree_compiler::createLinalgCopyOp(builder, loc, from, to);
   if (needsBarrier)
-    builder.create<gpu::BarrierOp>(loc);
+    gpu::BarrierOp::create(builder, loc);
   return success();
 }
 
@@ -1081,7 +1081,7 @@
   //
   // lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) % subgroup_size;
   Value laneId =
-      rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::x);
+      gpu::ThreadIdOp::create(rewriter, target.getLoc(), gpu::Dimension::x);
   int64_t subgroupSize = getSubgroupSize();
 
   populateGPUDistributionPatterns(patterns);
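
Zero-operand ops reduce to just the builder plus the location, and this file
also updates the commented-out `linalg::CopyOp` sample lines so future
copy/paste picks up the new spelling. From the barrier helper above:

    if (needsBarrier)
      gpu::BarrierOp::create(builder, loc);  // no operands, no result needed
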
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
index b3adfa9..d9023f1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -325,11 +325,12 @@
       shape, expandShapeOp.getResultType().getElementType());
 
   // Create a new ExtractSliceOp and ExpandShapeOp.
-  Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-      loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
-  auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
-      sizes);
+  Value newSliceOp =
+      tensor::ExtractSliceOp::create(rewriter, loc, expandShapeOp.getSrc(),
+                                     newOffsets, newLengths, newStrides);
+  auto newExpandShapeOp = tensor::ExpandShapeOp::create(
+      rewriter, loc, resultType, newSliceOp,
+      expandShapeOp.getReassociationIndices(), sizes);
   rewriter.replaceOp(sliceOp, newExpandShapeOp);
   return success();
 }
@@ -542,8 +543,9 @@
         for (auto dimIdx : reassocIndices) {
           expandedBasis.push_back(rewriter.getIndexAttr(srcShape[dimIdx]));
         }
-        auto delinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-            sliceOp.getLoc(), cast<Value>(collapsedOffset), expandedBasis);
+        auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
+            rewriter, sliceOp.getLoc(), cast<Value>(collapsedOffset),
+            expandedBasis);
         createdOps.push_back(delinearizeOp);
         ValueRange offsets = delinearizeOp.getResults();
         expandedOffsets.append(offsets.begin(), offsets.end());
@@ -666,9 +668,9 @@
                            groupExpandedOffsets.rend());
   }
 
-  Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-      collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
-      expandedSizes, expandedStrides);
+  Value newSliceOp = tensor::ExtractSliceOp::create(
+      rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
+      expandedOffsets, expandedSizes, expandedStrides);
   rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
       sliceOp, sliceOp.getResultType(), newSliceOp,
       collapseShapeOp.getReassociationIndices());
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
index 1ad1f2f..bcef32f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
@@ -53,9 +53,9 @@
     unsigned sourceBitWidth = sourceType.getIntOrFloatBitWidth();
     unsigned destBitWidth = targetType.getIntOrFloatBitWidth();
     if (sourceBitWidth > destBitWidth) {
-      return b.create<arith::TruncIOp>(loc, targetType, source);
+      return arith::TruncIOp::create(b, loc, targetType, source);
     } else {
-      return b.create<arith::ExtUIOp>(loc, targetType, source);
+      return arith::ExtUIOp::create(b, loc, targetType, source);
     }
   }
   return nullptr;
@@ -310,8 +310,8 @@
   matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Location loc = extractOp.getLoc();
-    Value newExtract = rewriter.create<tensor::ExtractOp>(
-        loc, adaptor.getTensor(), adaptor.getIndices());
+    Value newExtract = tensor::ExtractOp::create(
+        rewriter, loc, adaptor.getTensor(), adaptor.getIndices());
     Value replacement = convertElementType(
         rewriter, loc, extractOp.getResult().getType(), newExtract);
     rewriter.replaceOp(extractOp, replacement);
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
index 0e2e74b..4b3f253 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
@@ -241,7 +241,7 @@
   // Create a resolution operation. This conflict should be handled later by
   // someone else, not this analysis.
   Operation *resolveOp =
-      builder.create<IREE::VectorExt::ToLayoutOp>(input.getLoc(), input, rhs);
+      IREE::VectorExt::ToLayoutOp::create(builder, input.getLoc(), input, rhs);
   Value resolvedValue = resolveOp->getResult(0);
   opOperand.set(resolvedValue);
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp
index 8be2a3e..81a2a01 100644
--- a/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp
@@ -24,9 +24,9 @@
     if (copyOp.hasPureTensorSemantics()) {
       return failure();
     }
-    rewriter.create<memref::CopyOp>(copyOp.getLoc(),
-                                    copyOp.getDpsInputOperand(0)->get(),
-                                    copyOp.getDpsInitOperand(0)->get());
+    memref::CopyOp::create(rewriter, copyOp.getLoc(),
+                           copyOp.getDpsInputOperand(0)->get(),
+                           copyOp.getDpsInitOperand(0)->get());
     rewriter.eraseOp(copyOp);
     return success();
   }
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 0b79b2c..4d7756b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -99,15 +99,16 @@
   if (!isIndirect) {
     std::string name =
         llvm::formatv("__resource_var_{}_{}_", resource.set, resource.binding);
-    variable = builder.create<spirv::GlobalVariableOp>(
-        loc, globalVariableType, name, resource.set, resource.binding);
+    variable = spirv::GlobalVariableOp::create(
+        builder, loc, globalVariableType, name, resource.set, resource.binding);
     if (resource.aliased)
       variable->setAttr("aliased", builder.getUnitAttr());
   } else {
     std::string name =
         llvm::formatv("__resource_var_indirect_{}_", resource.set);
-    variable = builder.create<spirv::GlobalVariableOp>(
-        loc, globalVariableType, name, kIndirectBindingsSetIndex, resource.set);
+    variable = spirv::GlobalVariableOp::create(builder, loc, globalVariableType,
+                                               name, kIndirectBindingsSetIndex,
+                                               resource.set);
   }
   assert(variable);
 
@@ -314,29 +315,30 @@
 
         auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx);
         if (min.has_value() && max.has_value()) {
-          Value minConst = rewriter.create<spirv::ConstantOp>(
-              loc, i32Type, rewriter.getI32IntegerAttr(*min));
-          Value maxConst = rewriter.create<spirv::ConstantOp>(
-              loc, i32Type, rewriter.getI32IntegerAttr(*max));
-          Value minBound =
-              rewriter.create<spirv::UGreaterThanEqualOp>(loc, value, minConst);
-          rewriter.create<spirv::KHRAssumeTrueOp>(loc, minBound);
+          Value minConst = spirv::ConstantOp::create(
+              rewriter, loc, i32Type, rewriter.getI32IntegerAttr(*min));
+          Value maxConst = spirv::ConstantOp::create(
+              rewriter, loc, i32Type, rewriter.getI32IntegerAttr(*max));
+          Value minBound = spirv::UGreaterThanEqualOp::create(rewriter, loc,
+                                                              value, minConst);
+          spirv::KHRAssumeTrueOp::create(rewriter, loc, minBound);
           Value maxBound =
-              rewriter.create<spirv::ULessThanEqualOp>(loc, value, maxConst);
-          rewriter.create<spirv::KHRAssumeTrueOp>(loc, maxBound);
+              spirv::ULessThanEqualOp::create(rewriter, loc, value, maxConst);
+          spirv::KHRAssumeTrueOp::create(rewriter, loc, maxBound);
         }
 
         std::optional<uint64_t> divisibility =
             assumeOp.getUnionedUnsignedDivisor(opIdx);
         if (divisibility.has_value() && *divisibility > 1) {
-          Value divisor = rewriter.create<spirv::ConstantOp>(
-              loc, i32Type, rewriter.getI32IntegerAttr(*divisibility));
-          Value zero = rewriter.create<spirv::ConstantOp>(
-              loc, i32Type, rewriter.getI32IntegerAttr(0));
-          Value lowPart = rewriter.create<spirv::UModOp>(loc, value, divisor);
+          Value divisor = spirv::ConstantOp::create(
+              rewriter, loc, i32Type,
+              rewriter.getI32IntegerAttr(*divisibility));
+          Value zero = spirv::ConstantOp::create(rewriter, loc, i32Type,
+                                                 rewriter.getI32IntegerAttr(0));
+          Value lowPart = spirv::UModOp::create(rewriter, loc, value, divisor);
           Value dividesExactly =
-              rewriter.create<spirv::IEqualOp>(loc, lowPart, zero);
-          rewriter.create<spirv::KHRAssumeTrueOp>(loc, dividesExactly);
+              spirv::IEqualOp::create(rewriter, loc, lowPart, zero);
+          spirv::KHRAssumeTrueOp::create(rewriter, loc, dividesExactly);
         }
       }
     }
@@ -374,8 +376,8 @@
     auto i32Type = rewriter.getIntegerType(32);
     Value spirvBuiltin =
         spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
-    Value spirvId = rewriter.create<spirv::CompositeExtractOp>(
-        spirvBuiltin.getLoc(), i32Type, spirvBuiltin,
+    Value spirvId = spirv::CompositeExtractOp::create(
+        rewriter, spirvBuiltin.getLoc(), i32Type, spirvBuiltin,
         rewriter.getI32ArrayAttr({index}));
 
     // Cast if the index type is not 32-bit.
@@ -383,8 +385,8 @@
         *this->template getTypeConverter<SPIRVTypeConverter>();
     Type indexType = typeConverter.getIndexType();
     if (indexType != i32Type) {
-      spirvId = rewriter.create<spirv::UConvertOp>(spirvId.getLoc(), indexType,
-                                                   spirvId);
+      spirvId = spirv::UConvertOp::create(rewriter, spirvId.getLoc(), indexType,
+                                          spirvId);
     }
     rewriter.replaceOp(op, spirvId);
     return success();
@@ -447,20 +449,20 @@
     }
 
     Location loc = subspanOp.getLoc();
-    Value globalAddr = rewriter.create<spirv::AddressOfOp>(loc, varOp);
+    Value globalAddr = spirv::AddressOfOp::create(rewriter, loc, varOp);
     auto i32Ty = rewriter.getI32Type();
-    Value idx = rewriter.create<spirv::ConstantOp>(
-        loc, i32Ty, rewriter.getI32IntegerAttr(info.binding));
-    auto ptr = rewriter.create<spirv::AccessChainOp>(loc, globalAddr, idx);
-    auto addr = rewriter.create<spirv::LoadOp>(loc, ptr);
+    Value idx = spirv::ConstantOp::create(
+        rewriter, loc, i32Ty, rewriter.getI32IntegerAttr(info.binding));
+    auto ptr = spirv::AccessChainOp::create(rewriter, loc, globalAddr, idx);
+    auto addr = spirv::LoadOp::create(rewriter, loc, ptr);
     assert(cast<spirv::PointerType>(addr.getType()).getStorageClass() ==
                spirv::StorageClass::PhysicalStorageBuffer &&
            "Expected a physical storage buffer pointer");
 
     // Bitcast the pointer to the correct pointer type. This is allowed for
     // physical storage buffer addresses.
-    Value ptrInt = rewriter.create<spirv::ConvertPtrToUOp>(
-        loc, rewriter.getI64Type(), addr);
+    Value ptrInt = spirv::ConvertPtrToUOp::create(rewriter, loc,
+                                                  rewriter.getI64Type(), addr);
     rewriter.replaceOpWithNewOp<spirv::ConvertUToPtrOp>(subspanOp,
                                                         convertedType, ptrInt);
     return success();
@@ -774,8 +776,8 @@
 
   // Collect all SPIR-V ops into a spirv.module.
   OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
-  auto spvModule = builder.create<spirv::ModuleOp>(
-      moduleOp.getLoc(), addressingModel, spirv::MemoryModel::GLSL450);
+  auto spvModule = spirv::ModuleOp::create(
+      builder, moduleOp.getLoc(), addressingModel, spirv::MemoryModel::GLSL450);
   Block *body = spvModule.getBody();
   Dialect *spvDialect = spvModule->getDialect();
   for (Operation &op : llvm::make_early_inc_range(*moduleOp.getBody())) {
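
The same spelling holds across dialects, including creating the top-level
spirv.module itself; `OpBuilder` factory helpers like
`OpBuilder::atBlockBegin` are untouched. From the hunk above:

    OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
    auto spvModule = spirv::ModuleOp::create(builder, moduleOp.getLoc(),
                                             addressingModel,
                                             spirv::MemoryModel::GLSL450);
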
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index b594601..731d71c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -67,8 +67,8 @@
   MemRefType allocType =
       MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
                       AffineMap(), workgroupSpace);
-  auto allocOp = builder.create<memref::AllocOp>(
-      loc, allocType, dynamicSizes, builder.getI64IntegerAttr(alignment));
+  auto allocOp = memref::AllocOp::create(builder, loc, allocType, dynamicSizes,
+                                         builder.getI64IntegerAttr(alignment));
   return allocOp.getResult();
 }
 
@@ -81,8 +81,9 @@
       spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass::Function);
   MemRefType allocType = MemRefType::get(
       memRefType.getShape(), memRefType.getElementType(), {}, *space);
-  auto allocaOp = builder.create<memref::AllocaOp>(
-      loc, allocType, dynamicSizes, builder.getI64IntegerAttr(alignment));
+  auto allocaOp =
+      memref::AllocaOp::create(builder, loc, allocType, dynamicSizes,
+                               builder.getI64IntegerAttr(alignment));
   return allocaOp.getResult();
 }
 
@@ -94,11 +95,11 @@
   bool needsBarrier = hasSharedMemoryAddressSpace(fromType) ||
                       hasSharedMemoryAddressSpace(toType);
   if (needsBarrier)
-    builder.create<gpu::BarrierOp>(loc);
-  Operation *copy = builder.create<memref::CopyOp>(loc, from, to);
+    gpu::BarrierOp::create(builder, loc);
+  Operation *copy = memref::CopyOp::create(builder, loc, from, to);
   if (needsBarrier) {
     setMarker(copy, getCopyToWorkgroupMemoryMarker());
-    builder.create<gpu::BarrierOp>(loc);
+    gpu::BarrierOp::create(builder, loc);
   }
   return success();
 }
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp
index c303872..3e139fa 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVBreakDownLargeVector.cpp
@@ -101,33 +101,33 @@
     int64_t srcElemIndex = extractBitOffset / srcElemBitwidth;
     int64_t srcElemOffset = extractBitOffset % srcElemBitwidth;
 
-    Value srcElement = rewriter.create<vector::ExtractOp>(
-        extractOp.getLoc(), bitCastOp.getSource(),
+    Value srcElement = vector::ExtractOp::create(
+        rewriter, extractOp.getLoc(), bitCastOp.getSource(),
         ArrayRef<int64_t>{srcElemIndex});
 
-    Value result = rewriter.create<arith::ConstantOp>(
-        extractOp.getLoc(), extOp.getType(),
-        rewriter.getZeroAttr(extOp.getType()));
+    Value result =
+        arith::ConstantOp::create(rewriter, extractOp.getLoc(), extOp.getType(),
+                                  rewriter.getZeroAttr(extOp.getType()));
 
     // Extract each element assuming little-endian encoding: lower bits
     // correspond to earlier elements.
     auto dstElemType = cast<VectorType>(extOp.getType()).getElementType();
-    auto mask = rewriter.create<arith::ConstantOp>(
-        extOp.getLoc(), dstElemType,
+    auto mask = arith::ConstantOp::create(
+        rewriter, extOp.getLoc(), dstElemType,
         rewriter.getIntegerAttr(dstElemType, (1u << midElemBitwidth) - 1));
     int64_t shrSize = srcElemOffset;
     for (int i = 0; i < extractDstType.getNumElements(); ++i) {
       // Each time we extract midElemBitwidth bits from srcElement. We do that
       // by shift right first and then and a mask.
-      Value shrVal = rewriter.create<arith::ConstantOp>(
-          extractOp.getLoc(), dstElemType,
+      Value shrVal = arith::ConstantOp::create(
+          rewriter, extractOp.getLoc(), dstElemType,
           rewriter.getIntegerAttr(dstElemType, shrSize));
-      Value shr = rewriter.create<arith::ShRUIOp>(extractOp.getLoc(),
-                                                  srcElement, shrVal);
+      Value shr = arith::ShRUIOp::create(rewriter, extractOp.getLoc(),
+                                         srcElement, shrVal);
       Value elem =
-          rewriter.create<arith::AndIOp>(extractOp.getLoc(), shr, mask);
-      result = rewriter.create<vector::InsertOp>(extractOp.getLoc(), elem,
-                                                 result, i);
+          arith::AndIOp::create(rewriter, extractOp.getLoc(), shr, mask);
+      result = vector::InsertOp::create(rewriter, extractOp.getLoc(), elem,
+                                        result, i);
       shrSize += midElemBitwidth;
     }
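
The loop in the hunk above peels midElemBitwidth-bit fields out of a wider source element, little-endian, by shifting right and then masking. A standalone sketch of that arithmetic, with illustrative widths (four 4-bit fields packed into a 16-bit element):

#include <cstdint>
#include <iostream>

int main() {
  const uint16_t src = 0x4321;   // fields, low bits first: 1, 2, 3, 4
  const unsigned fieldBits = 4;  // midElemBitwidth in the pass
  const uint16_t mask = (1u << fieldBits) - 1;

  unsigned shr = 0;              // srcElemOffset starts the walk
  for (int i = 0; i < 4; ++i) {
    uint16_t elem = (src >> shr) & mask;  // shift right, then AND the mask
    std::cout << "element " << i << " = " << elem << "\n";  // 1, 2, 3, 4
    shr += fieldBits;            // advance one field per iteration
  }
}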
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp
index e45f717..bd6be3d 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp
@@ -104,8 +104,8 @@
     }
 
     if (!newArgs.empty()) {
-      auto newOp = rewriter.create<IREE::Util::AssumeIntOp>(
-          op.getLoc(), newArgs, newAssumptions);
+      auto newOp = IREE::Util::AssumeIntOp::create(rewriter, op.getLoc(),
+                                                   newArgs, newAssumptions);
       LLVM_DEBUG(llvm::dbgs()
                  << "WideIntegerEmulation: new op: " << newOp << "\n");
 
@@ -117,10 +117,10 @@
         Type newType = getTypeConverter()->convertType(
             op.getResult(replacementLoc).getType());
         if (auto vecType = dyn_cast_if_present<VectorType>(newType)) {
-          Value zeros = rewriter.create<arith::ConstantOp>(
-              op.getLoc(), newType, rewriter.getZeroAttr(newType));
-          replacement = rewriter.create<vector::InsertOp>(
-              op.getLoc(), result, zeros, ArrayRef<int64_t>{0});
+          Value zeros = arith::ConstantOp::create(
+              rewriter, op.getLoc(), newType, rewriter.getZeroAttr(newType));
+          replacement = vector::InsertOp::create(rewriter, op.getLoc(), result,
+                                                 zeros, ArrayRef<int64_t>{0});
         }
         replacements[replacementLoc] = replacement;
       }
@@ -195,8 +195,8 @@
     // Shape cast results.
     for (auto [oldResult, newResult] :
          llvm::zip_equal(op->getResults(), newOp->getResults())) {
-      Value cast = rewriter.create<vector::ShapeCastOp>(
-          loc, oldResult.getType(), newResult);
+      Value cast = vector::ShapeCastOp::create(rewriter, loc,
+                                               oldResult.getType(), newResult);
       rewriter.replaceAllUsesWith(oldResult, cast);
     }
     return success();
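
For context on the vector::InsertOp hunk above: under wide-integer emulation an i64 value is carried as two 32-bit words, so a scalar replacement is placed into element 0 (the low word) of a zero-filled two-element vector. A standalone sketch of that layout (plain C++, illustrative values; leaving the high word zero assumes the value is known to fit in 32 bits, which is presumably what the surrounding assume-int rewrite tracks):

#include <cstdint>
#include <iostream>

int main() {
  uint64_t wide = 0xDEADBEEFull;          // a value known to fit in 32 bits
  uint32_t words[2] = {0, 0};             // zero-filled vector<2xi32> analogue
  words[0] = static_cast<uint32_t>(wide); // insert the low word at index 0
  std::cout << std::hex << words[0] << " " << words[1] << "\n"; // deadbeef 0
}
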
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp
index 3f20bbb..fdfb2cd 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEraseStorageBufferStaticShape.cpp
@@ -82,11 +82,11 @@
 
   SmallVector<Value, 1> dynamicDims;
   assert(subspanOp.getDynamicDims().empty());
-  dynamicDims.push_back(rewriter.create<arith::ConstantIndexOp>(
-      subspanOp.getLoc(), oldType.getNumElements()));
+  dynamicDims.push_back(arith::ConstantIndexOp::create(
+      rewriter, subspanOp.getLoc(), oldType.getNumElements()));
 
-  auto newOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-      subspanOp.getLoc(), newType, subspanOp.getLayoutAttr(),
+  auto newOp = IREE::HAL::InterfaceBindingSubspanOp::create(
+      rewriter, subspanOp.getLoc(), newType, subspanOp.getLayoutAttr(),
       subspanOp.getBindingAttr(), subspanOp.getByteOffset(), dynamicDims,
       subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
index aa5c73a..4c3f84e 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
@@ -167,8 +167,8 @@
     OpBuilder moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
 
     // Create our new "linked" hal.executable.
-    auto linkedExecutableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
-        moduleOp.getLoc(), linkedExecutableName);
+    auto linkedExecutableOp = IREE::HAL::ExecutableOp::create(
+        moduleBuilder, moduleOp.getLoc(), linkedExecutableName);
     linkedExecutableOp.setVisibility(
         sourceExecutableOps.front().getVisibility());
     OpBuilder executableBuilder =
@@ -180,11 +180,10 @@
           executableTargetAttrs.size() == 1
               ? attr.getSymbolNameFragment()
               : llvm::formatv("{}_{}", attr.getSymbolNameFragment(), index);
-      auto linkedTargetOp =
-          executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
-              moduleOp.getLoc(), linkedVariantName, attr);
+      auto linkedTargetOp = IREE::HAL::ExecutableVariantOp::create(
+          executableBuilder, moduleOp.getLoc(), linkedVariantName, attr);
       auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
-      targetBuilder.create<mlir::ModuleOp>(moduleOp.getLoc());
+      mlir::ModuleOp::create(targetBuilder, moduleOp.getLoc());
 
       auto mergeModuleFn = [](mlir::ModuleOp sourceInnerModule,
                               mlir::ModuleOp linkedInnerModule,
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp
index 57508cb..d01eab9 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVMaterializeExecutableConditions.cpp
@@ -203,20 +203,21 @@
   TypedAttr zeroAttr = builder.getZeroAttr(i32Type);
 
   auto buildQueryOp = [&](const char *key, uint32_t value, Value result) {
-    auto queryOp = builder.create<IREE::HAL::DeviceQueryOp>(
-        loc, boolType, i32Type, device, builder.getStringAttr("hal.dispatch"),
-        builder.getStringAttr(key), zeroAttr);
-    auto zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
-    auto val = builder.create<arith::ConstantIntOp>(loc, value, 32);
-    auto andOp = builder.create<arith::AndIOp>(loc, queryOp.getValue(), val);
-    auto cmpOp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
-                                               andOp, zero);
+    auto queryOp = IREE::HAL::DeviceQueryOp::create(
+        builder, loc, boolType, i32Type, device,
+        builder.getStringAttr("hal.dispatch"), builder.getStringAttr(key),
+        zeroAttr);
+    auto zero = arith::ConstantIntOp::create(builder, loc, 0, 32);
+    auto val = arith::ConstantIntOp::create(builder, loc, value, 32);
+    auto andOp = arith::AndIOp::create(builder, loc, queryOp.getValue(), val);
+    auto cmpOp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne,
+                                       andOp, zero);
     // Verify that 1) the query succeeds and 2) the capability is supported.
-    auto ok = builder.create<arith::AndIOp>(loc, queryOp.getOk(), cmpOp);
-    return builder.create<arith::AndIOp>(loc, result, ok).getResult();
+    auto ok = arith::AndIOp::create(builder, loc, queryOp.getOk(), cmpOp);
+    return arith::AndIOp::create(builder, loc, result, ok).getResult();
   };
 
-  Value result = builder.create<arith::ConstantIntOp>(loc, true, 1);
+  Value result = arith::ConstantIntOp::create(builder, loc, true, 1);
   if (features.computeFloat) {
     result =
         buildQueryOp("compute.bitwidths.fp", features.computeFloat, result);
@@ -239,7 +240,7 @@
   if (features.address) {
     result = buildQueryOp("address.mode", features.address, result);
   }
-  builder.create<IREE::HAL::ReturnOp>(loc, result);
+  IREE::HAL::ReturnOp::create(builder, loc, result);
 }
 
 // Returns the device queries as a list of unique keys.
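
The buildQueryOp lambda rewritten above folds each capability check into one running i1: the hal.device.query must report ok, and the queried bitfield must share at least one bit with the requested value. A host-side sketch of that predicate in plain C++ (the struct and feature values are illustrative, not IREE runtime API):

#include <cstdint>
#include <iostream>

struct QueryResult {
  bool ok;        // queryOp.getOk()
  uint32_t value; // queryOp.getValue()
};

// One buildQueryOp step: (value & wanted) != 0, ANDed with ok, then ANDed
// into the running result.
bool foldQuery(bool result, QueryResult q, uint32_t wanted) {
  bool supported = (q.value & wanted) != 0; // andOp + cmpOp (ne zero)
  return result && (q.ok && supported);     // ok && cmp, then && result
}

int main() {
  bool result = true; // seeded by the arith.constant true
  result = foldQuery(result, {true, 0b0110}, 0b0100); // e.g. compute bitwidths
  result = foldQuery(result, {true, 0b0001}, 0b0001); // e.g. storage widths
  std::cout << (result ? "supported" : "unsupported") << "\n"; // supported
}
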
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
index 4f4a20c..d947d7c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTileAndVectorizeToCooperativeOps.cpp
@@ -150,7 +150,7 @@
     tileSizes.resize(
         std::min(cast<linalg::LinalgOp>(op).getNumParallelLoops(), 3u));
     return llvm::map_to_vector(tileSizes, [&](int64_t v) -> Value {
-      return builder.create<arith::ConstantIndexOp>(op->getLoc(), v);
+      return arith::ConstantIndexOp::create(builder, op->getLoc(), v);
     });
   };
   auto tilingOptions = linalg::LinalgTilingOptions()
@@ -312,8 +312,8 @@
     if (!foundTranspose)
       return failure();
 
-    Value res = rewriter.create<vector::ContractionOp>(
-        loc, newSources[0], newSources[1], newSources[2],
+    Value res = vector::ContractionOp::create(
+        rewriter, loc, newSources[0], newSources[1], newSources[2],
         rewriter.getAffineMapArrayAttr(newMaps), op.getIteratorTypes());
     rewriter.replaceOp(op, res);
     return success();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
index 83d2e70..518f359 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
@@ -383,8 +383,8 @@
     // If the transfer_read can be replaced by a load after vectorization, use
     // LoadOp and cast back to the original type.
     if (*vectorMemrefElemSize == *readVecSize) {
-      Value newLoad = rewriter.create<memref::LoadOp>(
-          loc, memrefVectorType, adaptor.getBase(), indices.value());
+      Value newLoad = memref::LoadOp::create(
+          rewriter, loc, memrefVectorType, adaptor.getBase(), indices.value());
       rewriter.replaceOpWithNewOp<vector::BitCastOp>(read, readVectorType,
                                                      newLoad);
       return success();
@@ -412,11 +412,11 @@
                         memrefVectorType.getElementType());
 
     for (int i = 0; i < vectorCount; ++i) {
-      Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
-      indices->back() = rewriter.create<affine::AffineApplyOp>(
-          loc, addMap, ValueRange{oldIndex, iVal});
+      Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i);
+      indices->back() = affine::AffineApplyOp::create(
+          rewriter, loc, addMap, ValueRange{oldIndex, iVal});
       vectors.push_back(
-          rewriter.create<memref::LoadOp>(loc, adaptor.getBase(), *indices));
+          memref::LoadOp::create(rewriter, loc, adaptor.getBase(), *indices));
     }
 
     // If there are only two component vectors, we can use ShuffleOp, which is a
@@ -424,8 +424,8 @@
     if (vectorCount == 2) {
       SmallVector<int64_t> seqIndices =
           llvm::to_vector(llvm::seq<int64_t>(readVectorType.getNumElements()));
-      auto ShuffleOp = rewriter.create<vector::ShuffleOp>(
-          loc, vectors[0], vectors[1], seqIndices);
+      auto ShuffleOp = vector::ShuffleOp::create(rewriter, loc, vectors[0],
+                                                 vectors[1], seqIndices);
       rewriter.replaceOpWithNewOp<vector::BitCastOp>(read, readVectorType,
                                                      ShuffleOp);
       return success();
@@ -434,12 +434,12 @@
     SmallVector<int64_t> offsets(combinedType.getRank(), 0);
     SmallVector<int64_t> strides(combinedType.getRank(), 1);
 
-    Value newVector = rewriter.create<arith::ConstantOp>(
-        loc, combinedType, rewriter.getZeroAttr(combinedType));
+    Value newVector = arith::ConstantOp::create(
+        rewriter, loc, combinedType, rewriter.getZeroAttr(combinedType));
     for (int i = 0; i < vectorCount; ++i) {
       offsets.back() = i * memrefVectorType.getNumElements();
-      newVector = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vectors[i], newVector, offsets, strides);
+      newVector = vector::InsertStridedSliceOp::create(
+          rewriter, loc, vectors[i], newVector, offsets, strides);
     }
 
     rewriter.replaceOp(read, newVector);
@@ -486,8 +486,8 @@
     // If the transfer_write can be replaced by a store after vectorization, cast
     // the original value and use StoreOp.
     if (*vectorMemrefElemSize == *writeVecSize) {
-      Value data = rewriter.create<vector::BitCastOp>(
-          loc, memrefVectorType, adaptor.getValueToStore());
+      Value data = vector::BitCastOp::create(rewriter, loc, memrefVectorType,
+                                             adaptor.getValueToStore());
       rewriter.replaceOpWithNewOp<memref::StoreOp>(
           write, data, adaptor.getBase(), indices.value());
       return success();
@@ -516,15 +516,15 @@
 
     for (int i = 0; i < vectorCount; ++i) {
       offsets.back() = i * memrefVectorType.getNumElements();
-      auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, adaptor.getValueToStore(), offsets, sizes, strides);
+      auto slice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, adaptor.getValueToStore(), offsets, sizes, strides);
       auto component =
-          rewriter.create<vector::BitCastOp>(loc, memrefVectorType, slice);
-      Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
-      indices->back() = rewriter.create<affine::AffineApplyOp>(
-          loc, addMap, ValueRange{oldIndex, iVal});
-      rewriter.create<memref::StoreOp>(loc, component, adaptor.getBase(),
-                                       *indices);
+          vector::BitCastOp::create(rewriter, loc, memrefVectorType, slice);
+      Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i);
+      indices->back() = affine::AffineApplyOp::create(
+          rewriter, loc, addMap, ValueRange{oldIndex, iVal});
+      memref::StoreOp::create(rewriter, loc, component, adaptor.getBase(),
+                              *indices);
     }
 
     rewriter.eraseOp(write);
@@ -614,10 +614,10 @@
   auto divMap = AffineMap::get(0, 2, {sym0.floorDiv(sym1)}, context);
 
   unsigned ratio = *vectorMemrefElemSize / *scalarMemrefElemSize;
-  Value valueRatio = rewriter.create<arith::ConstantIndexOp>(loc, ratio);
+  Value valueRatio = arith::ConstantIndexOp::create(rewriter, loc, ratio);
   auto newIndices = llvm::to_vector(indices);
-  newIndices.back() = rewriter.create<affine::AffineApplyOp>(
-      loc, divMap, ValueRange{indices.back(), valueRatio});
+  newIndices.back() = affine::AffineApplyOp::create(
+      rewriter, loc, divMap, ValueRange{indices.back(), valueRatio});
   return newIndices;
 }
 
@@ -777,14 +777,14 @@
     auto thenBuilder = [&](OpBuilder &b, Location loc) {
       Value thenRes = thenConditionBuilder(b, loc);
       if (thenRes) {
-        b.create<scf::YieldOp>(loc, thenRes);
+        scf::YieldOp::create(b, loc, thenRes);
       } else {
-        b.create<scf::YieldOp>(loc);
+        scf::YieldOp::create(b, loc);
       }
     };
-    auto ifOp = b.create<scf::IfOp>(loc, maybeMaskBit,
-                                    /*thenBuilder=*/thenBuilder,
-                                    /*elseBuilder=*/elseConditionBuilder);
+    auto ifOp = scf::IfOp::create(b, loc, maybeMaskBit,
+                                  /*thenBuilder=*/thenBuilder,
+                                  /*elseBuilder=*/elseConditionBuilder);
 
     return !ifOp.getNumResults() ? Value() : ifOp->getResult(0);
   }
@@ -812,8 +812,8 @@
     if (vectorType.getRank() == 0) {
       Value maybeMaskBit;
       if (maybeMask) {
-        maybeMaskBit = rewriter.create<vector::ExtractOp>(loc, maybeMask,
-                                                          ArrayRef<int64_t>{0});
+        maybeMaskBit = vector::ExtractOp::create(rewriter, loc, maybeMask,
+                                                 ArrayRef<int64_t>{0});
       }
 
       auto thenCond = [&](OpBuilder &b, Location loc) {
@@ -822,7 +822,7 @@
             .getResult();
       };
       auto elseCond = [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, readOp.getPadding());
+        scf::YieldOp::create(b, loc, readOp.getPadding());
       };
 
       Value scalar = predicateMaybeMaskedScalarTransfer(
@@ -843,8 +843,8 @@
     auto indices = llvm::to_vector(readOp.getIndices());
     Value oldIndex = indices[dimPos];
 
-    Value newVector = rewriter.create<arith::ConstantOp>(
-        loc, vectorType, rewriter.getZeroAttr(vectorType));
+    Value newVector = arith::ConstantOp::create(
+        rewriter, loc, vectorType, rewriter.getZeroAttr(vectorType));
     for (int i = 0; i < vectorType.getDimSize(0); ++i) {
       // Extract the mask bit for this value if present.
       Value maybeMaskBit;
@@ -852,24 +852,25 @@
         // The result vector is 1-D and we have a projected permutation, meaning
         // we can just extract the mask bit using the same index as the loaded
         // vector.
-        maybeMaskBit = rewriter.create<vector::ExtractOp>(loc, maybeMask,
-                                                          ArrayRef<int64_t>{i});
+        maybeMaskBit = vector::ExtractOp::create(rewriter, loc, maybeMask,
+                                                 ArrayRef<int64_t>{i});
       }
 
-      Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
+      Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i);
       auto thenCond = [&](OpBuilder &b, Location loc) {
-        indices[dimPos] = b.create<affine::AffineApplyOp>(
-            loc, addMap, ValueRange{oldIndex, iVal});
-        Value scalar = b.create<memref::LoadOp>(loc, readOp.getBase(), indices);
+        indices[dimPos] = affine::AffineApplyOp::create(
+            b, loc, addMap, ValueRange{oldIndex, iVal});
+        Value scalar =
+            memref::LoadOp::create(b, loc, readOp.getBase(), indices);
         return scalar;
       };
       auto elseCond = [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, readOp.getPadding());
+        scf::YieldOp::create(b, loc, readOp.getPadding());
       };
 
       Value scalar = predicateMaybeMaskedScalarTransfer(
           rewriter, loc, maybeMaskBit, thenCond, elseCond);
-      newVector = rewriter.create<vector::InsertOp>(loc, scalar, newVector, i);
+      newVector = vector::InsertOp::create(rewriter, loc, scalar, newVector, i);
     }
     rewriter.replaceOp(readOp, newVector);
     return success();
@@ -887,8 +888,8 @@
 
     Location loc = loadOp.getLoc();
     if (vectorType.getRank() == 0) {
-      Value scalar = rewriter.create<memref::LoadOp>(loc, loadOp.getBase(),
-                                                     loadOp.getIndices());
+      Value scalar = memref::LoadOp::create(rewriter, loc, loadOp.getBase(),
+                                            loadOp.getIndices());
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vectorType,
                                                        scalar);
       return success();
@@ -905,15 +906,15 @@
     auto indices = llvm::to_vector(loadOp.getIndices());
     Value oldIndex = indices[dimPos];
 
-    Value newVector = rewriter.create<arith::ConstantOp>(
-        loc, vectorType, rewriter.getZeroAttr(vectorType));
+    Value newVector = arith::ConstantOp::create(
+        rewriter, loc, vectorType, rewriter.getZeroAttr(vectorType));
     for (int i = 0; i < vectorType.getDimSize(0); ++i) {
-      Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
-      indices[dimPos] = rewriter.create<affine::AffineApplyOp>(
-          loc, addMap, ValueRange{oldIndex, iVal});
+      Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i);
+      indices[dimPos] = affine::AffineApplyOp::create(
+          rewriter, loc, addMap, ValueRange{oldIndex, iVal});
       Value scalar =
-          rewriter.create<memref::LoadOp>(loc, loadOp.getBase(), indices);
-      newVector = rewriter.create<vector::InsertOp>(loc, scalar, newVector, i);
+          memref::LoadOp::create(rewriter, loc, loadOp.getBase(), indices);
+      newVector = vector::InsertOp::create(rewriter, loc, scalar, newVector, i);
     }
     rewriter.replaceOp(loadOp, newVector);
     return success();
@@ -937,14 +938,14 @@
 
       Value maybeMaskBit;
       if (maybeMask) {
-        maybeMaskBit = rewriter.create<vector::ExtractOp>(loc, maybeMask,
-                                                          ArrayRef<int64_t>{0});
+        maybeMaskBit = vector::ExtractOp::create(rewriter, loc, maybeMask,
+                                                 ArrayRef<int64_t>{0});
       }
 
       auto thenCond = [&](OpBuilder &b, Location loc) {
-        Value scalar = b.create<vector::ExtractOp>(loc, writeOp.getVector());
-        b.create<memref::StoreOp>(loc, scalar, writeOp.getBase(),
-                                  writeOp.getIndices());
+        Value scalar = vector::ExtractOp::create(b, loc, writeOp.getVector());
+        memref::StoreOp::create(b, loc, scalar, writeOp.getBase(),
+                                writeOp.getIndices());
         return Value();
       };
 
@@ -970,16 +971,17 @@
         // The result vector is 1-D and we have a projected permutation, meaning
         // we can just extract the mask bit using the same index as the written
         // vector.
-        maybeMaskBit = rewriter.create<vector::ExtractOp>(loc, maybeMask,
-                                                          ArrayRef<int64_t>{i});
+        maybeMaskBit = vector::ExtractOp::create(rewriter, loc, maybeMask,
+                                                 ArrayRef<int64_t>{i});
       }
 
-      Value iVal = rewriter.create<arith::ConstantIndexOp>(loc, i);
+      Value iVal = arith::ConstantIndexOp::create(rewriter, loc, i);
       auto thenCond = [&](OpBuilder &b, Location loc) {
-        indices[dimPos] = b.create<affine::AffineApplyOp>(
-            loc, addMap, ValueRange{oldIndex, iVal});
-        Value scalar = b.create<vector::ExtractOp>(loc, writeOp.getVector(), i);
-        b.create<memref::StoreOp>(loc, scalar, writeOp.getBase(), indices);
+        indices[dimPos] = affine::AffineApplyOp::create(
+            b, loc, addMap, ValueRange{oldIndex, iVal});
+        Value scalar =
+            vector::ExtractOp::create(b, loc, writeOp.getVector(), i);
+        memref::StoreOp::create(b, loc, scalar, writeOp.getBase(), indices);
         return Value();
       };
       (void)predicateMaybeMaskedScalarTransfer(rewriter, loc, maybeMaskBit,
@@ -1020,19 +1022,19 @@
 
     Location loc = maskOp.getLoc();
     Value maskBit =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
+        arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
     for (auto [idx, size] :
          llvm::zip_equal(extractOp.getMixedPosition(), maskOp.getOperands())) {
       Value idxVal;
       if (auto attr = dyn_cast<Attribute>(idx)) {
-        idxVal = rewriter.create<arith::ConstantIndexOp>(
-            loc, dyn_cast<IntegerAttr>(attr).getInt());
+        idxVal = arith::ConstantIndexOp::create(
+            rewriter, loc, dyn_cast<IntegerAttr>(attr).getInt());
       } else {
         idxVal = dyn_cast<Value>(idx);
       }
-      Value cmpIdx = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, idxVal, size);
-      maskBit = rewriter.create<arith::AndIOp>(loc, cmpIdx, maskBit);
+      Value cmpIdx = arith::CmpIOp::create(
+          rewriter, loc, arith::CmpIPredicate::slt, idxVal, size);
+      maskBit = arith::AndIOp::create(rewriter, loc, cmpIdx, maskBit);
     }
     rewriter.replaceOp(extractOp, maskBit);
     return success();
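
A recurring piece in the file above: once the memref element type widens from scalars to short vectors, the innermost index must be floor-divided by the element-size ratio (the divMap/valueRatio hunk). A standalone sketch of that index math, with illustrative sizes:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const unsigned vectorElemBits = 128; // e.g. a vector<4xf32> element
  const unsigned scalarElemBits = 32;  // e.g. an f32 element
  const unsigned ratio = vectorElemBits / scalarElemBits; // 4

  std::vector<int64_t> indices = {2, 7, 13}; // innermost index is last
  // affine floorDiv(sym0, sym1) for non-negative indices is plain division.
  indices.back() = indices.back() / static_cast<int64_t>(ratio);
  std::cout << indices[0] << " " << indices[1] << " " << indices[2] << "\n";
  // -> 2 7 3: scalar element 13 lives in 4-wide vector number 3.
}
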
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
index fb5e694..2685c41 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Utils.cpp
@@ -67,7 +67,7 @@
   linalg::TileSizeComputationFunction computeFn =
       [tileSizes](OpBuilder &builder, Operation *op) {
         auto range = llvm::map_range(*tileSizes, [&](int64_t size) -> Value {
-          return builder.create<arith::ConstantIndexOp>(op->getLoc(), size);
+          return arith::ConstantIndexOp::create(builder, op->getLoc(), size);
         });
         return llvm::to_vector(range);
       };
@@ -101,8 +101,8 @@
   std::array<gpu::Dimension, kNumGPUDims> dimAttr{
       gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z};
   Type indexType = builder.getIndexType();
-  return {builder.create<GPUIdOp>(loc, indexType, dimAttr[dim]),
-          builder.create<GPUCountOp>(loc, indexType, dimAttr[dim]),
+  return {GPUIdOp::create(builder, loc, indexType, dimAttr[dim]),
+          GPUCountOp::create(builder, loc, indexType, dimAttr[dim]),
           linalg::DistributionMethod::Cyclic};
 }
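
The distribution hunk above returns a (processor id, processor count) pair per GPU dimension with DistributionMethod::Cyclic; the loop shape that corresponds to is the classic strided loop, sketched below with illustrative trip counts:

#include <iostream>

int main() {
  const int count = 4;     // GPUCountOp analogue, e.g. workgroups along x
  const int numTiles = 10;
  for (int id = 0; id < count; ++id) {  // GPUIdOp analogue per processor
    std::cout << "proc " << id << ":";
    for (int i = id; i < numTiles; i += count) // cyclic: start at id, stride by count
      std::cout << " tile" << i;
    std::cout << "\n";
  }
}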
 
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp
index 704c817..f09e916 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp
@@ -34,8 +34,8 @@
     // Create our new "linked" hal.executable.
     std::string linkedExecutableName =
         llvm::formatv("{}_linked_{}", moduleName, "vmvx");
-    auto linkedExecutableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
-        moduleOp.getLoc(), linkedExecutableName);
+    auto linkedExecutableOp = IREE::HAL::ExecutableOp::create(
+        moduleBuilder, moduleOp.getLoc(), linkedExecutableName);
     linkedExecutableOp.setVisibility(
         sourceExecutableOps.front().getVisibility());
     auto executableBuilder =
@@ -50,16 +50,15 @@
               ? targetAttr.getSymbolNameFragment()
               : llvm::formatv("{}_{}", targetAttr.getSymbolNameFragment(),
                               index);
-      auto linkedTargetOp =
-          executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
-              moduleOp.getLoc(), linkedVariantName, targetAttr);
+      auto linkedTargetOp = IREE::HAL::ExecutableVariantOp::create(
+          executableBuilder, moduleOp.getLoc(), linkedVariantName, targetAttr);
       auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
-      auto linkedModuleOp = targetBuilder.create<ModuleOp>(moduleOp.getLoc());
+      auto linkedModuleOp = ModuleOp::create(targetBuilder, moduleOp.getLoc());
 
       // Add an empty vm.module to that module as our vm.funcs must live in it.
       auto nestedBuilder = OpBuilder::atBlockBegin(linkedModuleOp.getBody());
-      nestedBuilder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(),
-                                               "linked_module");
+      IREE::VM::ModuleOp::create(nestedBuilder, moduleOp.getLoc(),
+                                 "linked_module");
 
       auto mergeModuleFn = [](mlir::ModuleOp sourceInnerModule,
                               mlir::ModuleOp linkedInnerModule,
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp
index bfb42ea..7c98661 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXLowerLinalgMicrokernels.cpp
@@ -50,7 +50,7 @@
   for (Value &stride : strides) {
     if (!stride) {
       if (!zero) {
-        zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+        zero = arith::ConstantIndexOp::create(builder, loc, 0);
       }
       stride = zero;
     }
@@ -65,7 +65,7 @@
   Value padValue;
   while (indices.size() < minRank) {
     if (!padValue) {
-      padValue = builder.create<arith::ConstantIndexOp>(loc, padIndex);
+      padValue = arith::ConstantIndexOp::create(builder, loc, padIndex);
     }
     indices.insert(indices.begin(), padValue);
   }
@@ -195,9 +195,9 @@
       sizeStrideTypes.push_back(indexType);
     }
 
-    auto op = builder.create<IREE::VMVX::GetBufferDescriptorOp>(
-        loc, builder.getType<IREE::Util::BufferType>(), builder.getIndexType(),
-        sizeStrideTypes, sizeStrideTypes, buffer);
+    auto op = IREE::VMVX::GetBufferDescriptorOp::create(
+        builder, loc, builder.getType<IREE::Util::BufferType>(),
+        builder.getIndexType(), sizeStrideTypes, sizeStrideTypes, buffer);
 
     desc->baseBuffer = op.getBaseBuffer();
     desc->offset = op.getOffset();
@@ -308,8 +308,8 @@
 
     switch (selection.opType) {
     case OpType::GenericBinary: {
-      rewriter.create<IREE::VMVX::BinaryOp>(
-          loc, rewriter.getStringAttr(selection.opcode),
+      IREE::VMVX::BinaryOp::create(
+          rewriter, loc, rewriter.getStringAttr(selection.opcode),
           // LHS
           params.in0Buffer, operands.first.bufferDesc->offset,
           params.in0Strides,
@@ -411,8 +411,8 @@
 
     switch (selection.opType) {
     case OpType::GenericUnary: {
-      rewriter.create<IREE::VMVX::UnaryOp>(
-          loc, rewriter.getStringAttr(selection.opcode),
+      IREE::VMVX::UnaryOp::create(
+          rewriter, loc, rewriter.getStringAttr(selection.opcode),
           // IN
           params.inBuffer, operand.bufferDesc->offset, params.inStrides,
           // OUT
@@ -508,16 +508,15 @@
     leftPadToRank(loc, outStrides, 2, 0, rewriter);
     leftPadToRank(loc, sizes, 2, 1, rewriter);
 
-    rewriter.create<IREE::VMVX::CopyOp>(
-        loc,
-        // IN
-        inBuffer, in.bufferDesc->offset, inStrides,
-        // OUT
-        outBuffer, out.bufferDesc->offset, outStrides,
-        // Sizes
-        sizes,
-        // Element type.
-        in.bufferDesc->getElementTypeAttr());
+    IREE::VMVX::CopyOp::create(rewriter, loc,
+                               // IN
+                               inBuffer, in.bufferDesc->offset, inStrides,
+                               // OUT
+                               outBuffer, out.bufferDesc->offset, outStrides,
+                               // Sizes
+                               sizes,
+                               // Element type.
+                               in.bufferDesc->getElementTypeAttr());
   }
 };
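
Last, the leftPadToRank-style helpers at the top of this file (used by the CopyOp lowering just above, which pads strides with 0 and sizes with 1 up to rank 2) prepend constants so every VMVX op sees a fixed minimum rank. A minimal sketch of that padding, with illustrative values:

#include <cstdint>
#include <iostream>
#include <vector>

// Analogue of leftPadToRank: prepend pad values (outermost dims) until the
// list reaches minRank, mirroring indices.insert(indices.begin(), padValue).
void leftPadToRank(std::vector<int64_t> &vals, size_t minRank, int64_t pad) {
  while (vals.size() < minRank)
    vals.insert(vals.begin(), pad);
}

int main() {
  std::vector<int64_t> strides = {1}; // a rank-1 buffer
  std::vector<int64_t> sizes = {16};
  leftPadToRank(strides, 2, 0); // strides pad with 0, as in the CopyOp case
  leftPadToRank(sizes, 2, 1);   // sizes pad with 1
  std::cout << strides[0] << "," << strides[1] << "  " // 0,1
            << sizes[0] << "," << sizes[1] << "\n";    // 1,16
}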