NFC: Remove usages of Value::operator* and Value::operator-> now that Value is properly value-typed.

These were temporary compatibility methods added to ease the transition to a value-typed Value.

PiperOrigin-RevId: 287913613
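
For illustration, call sites change mechanically from pointer-style access to
direct member access (a minimal sketch of the pattern in this change; `v`,
`newValue`, and the printer `p` are hypothetical names, not from the diff):

  // Before: Value behaved like a pointer, so members were reached through
  // operator-> and printing required dereferencing via operator*.
  Type t = v->getType();
  v->replaceAllUsesWith(newValue);
  p << *v;

  // After: Value is a proper value type, so plain member access works and
  // values print directly.
  Type t = v.getType();
  v.replaceAllUsesWith(newValue);
  p << v;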
diff --git a/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
index 98ea2b8..ebf873b 100644
--- a/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
+++ b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
@@ -99,14 +99,14 @@
         // XLA resource functionalization should have canonicalized everything
         // to uses of those two ops in the body of the tf_saved_model exported
         // function.
-        for (OpOperand &operand : llvm::make_early_inc_range(arg->getUses())) {
+        for (OpOperand &operand : llvm::make_early_inc_range(arg.getUses())) {
           if (auto read_variable =
                   dyn_cast<TF::ReadVariableOp>(operand.getOwner())) {
             auto load = OpBuilder(read_variable)
                             .create<IREE::Flow::VariableLoadOp>(
                                 read_variable.getLoc(),
-                                read_variable.value()->getType(), flow_sym_ref);
-            read_variable.value()->replaceAllUsesWith(load.result());
+                                read_variable.value().getType(), flow_sym_ref);
+            read_variable.value().replaceAllUsesWith(load.result());
             read_variable.erase();
             continue;
           }
@@ -128,8 +128,8 @@
         auto load =
             OpBuilder(func.getBody())
                 .create<IREE::Flow::VariableLoadOp>(
-                    global_tensor.getLoc(), arg->getType(), flow_sym_ref);
-        arg->replaceAllUsesWith(load.result());
+                    global_tensor.getLoc(), arg.getType(), flow_sym_ref);
+        arg.replaceAllUsesWith(load.result());
       }
     }
     func.eraseArguments(args_to_erase);
diff --git a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
index d457c3d..7dc574f 100644
--- a/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/HLOToFlow/ConvertHLOToFlow.cpp
@@ -51,7 +51,7 @@
                                                          tensorValue);
         }));
     rewriter.replaceOpWithNewOp<IREE::Flow::TensorUpdateOp>(
-        op, op.getResult()->getType(), op.update(), op.operand(), startIndices);
+        op, op.getResult().getType(), op.update(), op.operand(), startIndices);
     return matchSuccess();
   }
 };
diff --git a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp
index 42b3cc0..56652e0 100644
--- a/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/StandardToFlow/ConvertStandardToFlow.cpp
@@ -31,7 +31,7 @@
   using OpRewritePattern::OpRewritePattern;
   PatternMatchResult matchAndRewrite(ExtractElementOp op,
                                      PatternRewriter &rewriter) const override {
-    auto aggregateType = op.getAggregate()->getType().dyn_cast<TensorType>();
+    auto aggregateType = op.getAggregate().getType().dyn_cast<TensorType>();
     if (!aggregateType) {
       // We currently are only looking for tensor types.
       return matchFailure();
@@ -49,7 +49,7 @@
                                        ConversionTarget &conversionTarget) {
   conversionTarget.addDynamicallyLegalOp<ExtractElementOp>(
       [](ExtractElementOp op) {
-        return !op.getAggregate()->getType().isa<TensorType>();
+        return !op.getAggregate().getType().isa<TensorType>();
       });
 }
 
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 80b1b2d..583866d 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -83,7 +83,7 @@
 
   PatternMatchResult matchAndRewrite(VariableLoadOp op,
                                      PatternRewriter &rewriter) const override {
-    if (op.result()->use_empty()) {
+    if (op.result().use_empty()) {
       rewriter.eraseOp(op);
       return matchSuccess();
     }
@@ -110,7 +110,7 @@
   PatternMatchResult matchAndRewrite(VariableStoreOp op,
                                      PatternRewriter &rewriter) const override {
     if (auto loadOp =
-            dyn_cast_or_null<VariableLoadOp>(op.value()->getDefiningOp())) {
+            dyn_cast_or_null<VariableLoadOp>(op.value().getDefiningOp())) {
       if (loadOp.variable() == op.variable()) {
         rewriter.eraseOp(op);
         return matchSuccess();
@@ -148,8 +148,8 @@
 }
 
 OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
-  auto sourceType = source()->getType().cast<ShapedType>();
-  auto resultType = result()->getType().cast<ShapedType>();
+  auto sourceType = source().getType().cast<ShapedType>();
+  auto resultType = result().getType().cast<ShapedType>();
   if (sourceType.hasStaticShape() && sourceType == resultType) {
     // No-op.
     return source();
@@ -157,7 +157,7 @@
 
   // Skip intermediate reshapes.
   if (auto definingOp =
-          dyn_cast_or_null<TensorReshapeOp>(source()->getDefiningOp())) {
+          dyn_cast_or_null<TensorReshapeOp>(source().getDefiningOp())) {
     setOperand(definingOp.getOperand());
     return result();
   }
@@ -206,7 +206,7 @@
   // TODO(benvanik): only fold when shape is constant.
   if (operands[0]) {
     // Splat value is constant and we can fold the operation.
-    return SplatElementsAttr::get(result()->getType().cast<ShapedType>(),
+    return SplatElementsAttr::get(result().getType().cast<ShapedType>(),
                                   operands[0]);
   }
   return {};
@@ -244,8 +244,8 @@
                         operands[1].cast<ElementsAttr>(), indices);
   } else {
     // Replace the entire tensor when the sizes match.
-    auto updateType = update()->getType().cast<ShapedType>();
-    auto targetType = target()->getType().cast<ShapedType>();
+    auto updateType = update().getType().cast<ShapedType>();
+    auto targetType = target().getType().cast<ShapedType>();
     if (updateType.hasStaticShape() && targetType.hasStaticShape() &&
         updateType == targetType) {
       return update();
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index 19a5a7f..7e2b604 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -198,7 +198,7 @@
   p.printSymbolName(op.variable());
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) {
@@ -207,7 +207,7 @@
     return op.emitOpError() << "undefined variable: " << op.variable();
   }
   auto variableOp = dyn_cast<VariableOp>(symbolOp);
-  auto loadType = op.result()->getType();
+  auto loadType = op.result().getType();
   if (!isVariableTypeCompatible(variableOp.type(), loadType)) {
     return op.emitOpError()
            << "variable type mismatch; variable " << op.variable() << " is "
@@ -243,7 +243,7 @@
   p.printSymbolName(op.variable());
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
   p << " : ";
-  p.printType(op.value()->getType());
+  p.printType(op.value().getType());
 }
 
 static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) {
@@ -252,7 +252,7 @@
     return op.emitOpError() << "undefined variable: " << op.variable();
   }
   auto variableOp = dyn_cast<VariableOp>(symbolOp);
-  auto storeType = op.value()->getType();
+  auto storeType = op.value().getType();
   if (!isVariableTypeCompatible(variableOp.type(), storeType)) {
     return op.emitOpError()
            << "variable type mismatch; variable " << op.variable() << " is "
@@ -344,16 +344,16 @@
   p << "[";
   p.printOperand(op.workload());
   p << " : ";
-  p.printType(op.workload()->getType());
+  p.printType(op.workload().getType());
   p << "]";
 
   // Print the data argument remapping.
   p << "(";
   interleaveComma(llvm::zip(op.body().front().getArguments(), op.args()), p,
                   [&](std::tuple<BlockArgument, Value> it) {
-                    p << *std::get<0>(it) << " = " << *std::get<1>(it);
+                    p << std::get<0>(it) << " = " << std::get<1>(it);
                     p << " : ";
-                    p << std::get<1>(it)->getType();
+                    p << std::get<1>(it).getType();
                   });
   p << ")";
 
@@ -469,7 +469,7 @@
   p << "[";
   p.printOperand(op.workload());
   p << " : ";
-  p.printType(op.workload()->getType());
+  p.printType(op.workload().getType());
   p << "]";
 
   p << "(";
@@ -478,7 +478,7 @@
   if (op.getNumResults() > 0) {
     p << " : (";
     interleaveComma(op.operands(), p,
-                    [&](Value operand) { p.printType(operand->getType()); });
+                    [&](Value operand) { p.printType(operand.getType()); });
     p << ")";
     p << " -> ";
     if (op.getNumResults() > 1) p << "(";
@@ -498,7 +498,7 @@
     p << ") = ";
     p.printOperand(operand);
     p << " : ";
-    p.printType(operand->getType());
+    p.printType(operand.getType());
   });
   p << ") ";
 
@@ -565,7 +565,7 @@
   p << "[";
   p.printOperand(op.workload());
   p << " : ";
-  p.printType(op.workload()->getType());
+  p.printType(op.workload().getType());
   p << "]";
 
   p << "(";
@@ -574,7 +574,7 @@
   if (op.getNumResults() > 0) {
     p << " : (";
     interleaveComma(op.operands(), p,
-                    [&](Value operand) { p.printType(operand->getType()); });
+                    [&](Value operand) { p.printType(operand.getType()); });
     p << ")";
     p << " -> (";
     interleaveComma(op.getResultTypes(), p);
@@ -593,7 +593,7 @@
     p << ") = ";
     p.printOperand(operand);
     p << " : ";
-    p.printType(operand->getType());
+    p.printType(operand.getType());
   });
   p << ") ";
 
@@ -851,7 +851,7 @@
   p << "[";
   p.printOperand(op.workload());
   p << " : ";
-  p.printType(op.workload()->getType());
+  p.printType(op.workload().getType());
   p << "](";
   p.printOperands(op.operands());
   p << ')';
@@ -896,9 +896,9 @@
   p << op.getOperationName() << ' ';
   p.printOperand(op.source());
   p << " : ";
-  p.printType(op.source()->getType());
+  p.printType(op.source().getType());
   p << " -> ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -936,7 +936,7 @@
     p << ']';
   }
   p << " : ";
-  p.printType(op.source()->getType());
+  p.printType(op.source().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -981,7 +981,7 @@
     p << ']';
   }
   p << " : ";
-  p.printType(op.target()->getType());
+  p.printType(op.target().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -1008,7 +1008,7 @@
   p << op.getOperationName() << ' ';
   p.printOperand(op.value());
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -1034,7 +1034,7 @@
   p << op.getOperationName() << ' ';
   p.printOperand(op.operand());
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -1082,9 +1082,9 @@
   p << " for ";
   p.printOperands(op.lengths());
   p << "] : ";
-  p.printType(op.source()->getType());
+  p.printType(op.source().getType());
   p << " -> ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -1128,9 +1128,9 @@
   p << '[';
   p.printOperands(op.start_indices());
   p << "] : ";
-  p.printType(op.update()->getType());
+  p.printType(op.update().getType());
   p << " -> ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -1199,9 +1199,9 @@
   p << "(";
   interleaveComma(llvm::zip(op.body().front().getArguments(), op.args()), p,
                   [&](std::tuple<BlockArgument, Value> it) {
-                    p << *std::get<0>(it) << " = " << *std::get<1>(it);
+                    p << std::get<0>(it) << " = " << std::get<1>(it);
                     p << " : ";
-                    p << std::get<1>(it)->getType();
+                    p << std::get<1>(it).getType();
                   });
   p << ")";
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp b/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp
index 3d38268..0b65592 100644
--- a/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/AssignExecutableWorkloads.cpp
@@ -44,7 +44,7 @@
     funcOp.walk([&](DispatchOp op) {
       auto &workloadInfo = workloadInfos[op.executable()][op.entry_point()];
       if (auto constantOp =
-              dyn_cast<ConstantOp>(op.workload()->getDefiningOp())) {
+              dyn_cast<ConstantOp>(op.workload().getDefiningOp())) {
         for (auto existingWorkloadAttr : workloadInfo.staticWorkloads) {
           if (existingWorkloadAttr == constantOp.value()) {
             return;  // Already present, ignore.
diff --git a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
index 7c6c9f6..91d7959 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
@@ -67,7 +67,7 @@
 bool recursiveUntuple(Value value, Location loc, OpBuilder &builder,
                       BlockAndValueMapping *mapping,
                       llvm::SmallVectorImpl<Value> *newValues) {
-  Type type = value->getType();
+  Type type = value.getType();
   // We can return the value as is.
   if (!type.isa<TupleType>()) {
     newValues->push_back(value);
@@ -150,8 +150,8 @@
   auto newResults = newOp.getResults();
   for (auto oldResult : oldOp->getResults()) {
     llvm::SmallVector<Value, 10> subValues;
-    auto newResult = recursiveRetuple(oldResult->getType(), &newResults,
-                                      builder, oldOp->getLoc());
+    auto newResult = recursiveRetuple(oldResult.getType(), &newResults, builder,
+                                      oldOp->getLoc());
     mapping->map(oldResult, newResult);
   }
 
@@ -252,9 +252,9 @@
     auto *newBlock = builder.createBlock(&newFunction.getBody());
     for (auto oldArg : oldBlock.getArguments()) {
       llvm::SmallVector<Type, 4> newTypes;
-      untupleTypes(oldArg->getType(), &newTypes);
+      untupleTypes(oldArg.getType(), &newTypes);
 
-      Value newTuple = processTuple(oldArg->getType(), oldFunction.getLoc(),
+      Value newTuple = processTuple(oldArg.getType(), oldFunction.getLoc(),
                                     newBlock, builder);
       if (!newTuple) {
         return true;
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 2f78ccd..14a397c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -76,7 +76,7 @@
   SmallVector<Type, 8> resultTypes;
   resultTypes.append(regionOp.result_type_begin(), regionOp.result_type_end());
   for (auto newResult : newResults) {
-    resultTypes.push_back(newResult->getType());
+    resultTypes.push_back(newResult.getType());
   }
   auto newRegionOp = builder.create<DispatchRegionOp>(
       fusedLoc, resultTypes, regionOp.workload(), operands,
@@ -85,7 +85,7 @@
 
   // Replace uses of original values with the new values.
   for (int i = 0; i < regionOp.getNumResults(); ++i) {
-    regionOp.getResult(i)->replaceAllUsesWith(newRegionOp.getResult(i));
+    regionOp.getResult(i).replaceAllUsesWith(newRegionOp.getResult(i));
   }
 
   // Erase the original region.
@@ -112,9 +112,9 @@
   SmallVector<Value, 8> newRegionResults;
   for (int i = 0; i < returnOp.getNumOperands(); ++i) {
     auto resultValue = regionOp.getResult(i);
-    if (!resultValue->use_empty()) {
+    if (!resultValue.use_empty()) {
       // Still has uses so we will preserve it.
-      newReturnTypes.push_back(resultValue->getType());
+      newReturnTypes.push_back(resultValue.getType());
       newReturnValues.push_back(returnOp.getOperand(i));
       newRegionResults.push_back(resultValue);
     }
@@ -135,7 +135,7 @@
 
   // Replace uses of original values with the new values.
   for (int i = 0; i < newRegionResults.size(); ++i) {
-    newRegionResults[i]->replaceAllUsesWith(newRegionOp.getResult(i));
+    newRegionResults[i].replaceAllUsesWith(newRegionOp.getResult(i));
   }
 
   // Erase the original region.
@@ -154,16 +154,16 @@
 
 // Returns true if |value| depends in any way on |op| through any path.
 bool doesValueDependOnOperation(Value value, Operation *op) {
-  if (!value->getDefiningOp()) {
+  if (!value.getDefiningOp()) {
     return false;
-  } else if (value->getDefiningOp() == op) {
+  } else if (value.getDefiningOp() == op) {
     return true;
-  } else if (value->getDefiningOp()->getBlock() == op->getBlock() &&
-             value->getDefiningOp()->isBeforeInBlock(op)) {
+  } else if (value.getDefiningOp()->getBlock() == op->getBlock() &&
+             value.getDefiningOp()->isBeforeInBlock(op)) {
     // Can't depend on |op| as it is defined prior to it.
     return false;
   }
-  for (auto operand : value->getDefiningOp()->getOperands()) {
+  for (auto operand : value.getDefiningOp()->getOperands()) {
     if (doesValueDependOnOperation(operand, op)) {
       return true;
     }
@@ -177,7 +177,7 @@
 bool areDispatchRegionsTransitivelyDependent(DispatchRegionOp &lhs,
                                              DispatchRegionOp &rhs) {
   for (auto arg : rhs.args()) {
-    if (arg->getDefiningOp() != lhs && doesValueDependOnOperation(arg, lhs)) {
+    if (arg.getDefiningOp() != lhs && doesValueDependOnOperation(arg, lhs)) {
       // Transitively dependent - boo - can't merge yet.
       return true;
     }
@@ -257,7 +257,7 @@
     if (!didElide) {
       // Add to the lhs block.
       auto oldArg = rhs.getOperand(rhsOpIdx + 1);
-      auto newArg = lhsBlock.addArgument(oldArg->getType());
+      auto newArg = lhsBlock.addArgument(oldArg.getType());
       mapping.map(rhsBlock.getArgument(rhsOpIdx), newArg);
       newArgs.push_back(oldArg);
     }
@@ -291,7 +291,7 @@
 
   // Replace uses of original values with the new values.
   for (int i = 0; i < rhs.getNumResults(); ++i) {
-    rhs.getResult(i)->replaceAllUsesWith(
+    rhs.getResult(i).replaceAllUsesWith(
         newRegionOp.getResult(lhsReturnValues.size() + i));
   }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp
index 51bfbc4..f7d75b4 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp
@@ -95,8 +95,8 @@
         auto *nextOp = markList.pop_back_val();
         if (!currentOutsideOps.insert(nextOp)) continue;
         for (auto operand : nextOp->getOperands()) {
-          if (operand->getDefiningOp()) {
-            markList.insert(operand->getDefiningOp());
+          if (operand.getDefiningOp()) {
+            markList.insert(operand.getDefiningOp());
           }
         }
       }
@@ -130,7 +130,7 @@
       // Recursively work through the inputs of the op to pull in any
       // dependencies that we are able to (are flow ops, have no side-effects).
       for (auto operand : op->getOperands()) {
-        auto *depOp = operand->getDefiningOp();
+        auto *depOp = operand.getDefiningOp();
         if (!depOp) {
           // Op is a block arg.
           continue;
@@ -213,15 +213,15 @@
       for (auto operand : op->getOperands()) {
         if (std::find(fragmentOperands.begin(), fragmentOperands.end(),
                       operand) == fragmentOperands.end()) {
-          if (!operand->getDefiningOp() ||
-              !streamOpSet.count(operand->getDefiningOp())) {
+          if (!operand.getDefiningOp() ||
+              !streamOpSet.count(operand.getDefiningOp())) {
             fragmentOperands.push_back(operand);
           }
         }
       }
       for (auto result : op->getResults()) {
         bool onlyStreamUses = true;
-        for (auto &use : result->getUses()) {
+        for (auto &use : result.getUses()) {
           if (!streamOpSet.count(use.getOwner())) {
             onlyStreamUses = false;
             break;
@@ -229,7 +229,7 @@
         }
         if (!onlyStreamUses) {
           fragmentResults.push_back(result);
-          fragmentResultTypes.push_back(result->getType());
+          fragmentResultTypes.push_back(result.getType());
         }
       }
     }
@@ -242,7 +242,7 @@
     entryBlock->addArguments(llvm::to_vector<8>(fragmentOp.getOperandTypes()));
     BlockAndValueMapping mapping;
     for (auto arg : entryBlock->getArguments()) {
-      mapping.map(fragmentOperands[arg->getArgNumber()], arg);
+      mapping.map(fragmentOperands[arg.getArgNumber()], arg);
     }
     OpBuilder fragmentBuilder(entryBlock);
     for (auto *op : streamOps) {
@@ -257,7 +257,7 @@
          llvm::zip(fragmentResults, fragmentOp.getResults())) {
       auto oldValue = std::get<0>(resultOldNew);
       auto newValue = std::get<1>(resultOldNew);
-      oldValue->replaceAllUsesWith(newValue);
+      oldValue.replaceAllUsesWith(newValue);
     }
 
     // Erase the ops from the block now that we've cloned them.
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
index daa4e06..94d6729 100644
--- a/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyDispatchRegions.cpp
@@ -72,7 +72,7 @@
     // as we do here.
     return false;
   } else if (op->getNumResults() &&
-             !op->getResult(0)->getType().isa<ShapedType>()) {
+             !op->getResult(0).getType().isa<ShapedType>()) {
     // We don't put scalar manipulation into dispatch regions.
     return false;
   } else if (!isOpOfKnownDialect(op)) {
@@ -128,8 +128,8 @@
                      llvm::SetVector<Operation *> *subgraph) {
   // Skip ops that are used outside of the subgraph we are building.
   for (auto result : op->getResults()) {
-    if (result->use_empty() || result->hasOneUse()) continue;
-    for (auto *user : result->getUsers()) {
+    if (result.use_empty() || result.hasOneUse()) continue;
+    for (auto *user : result.getUsers()) {
       if (subgraph->count(user) == 0) {
         // Op that consumes the result is not (yet) in the subgraph.
         // For now we'll ignore these as it may represent a fork that we don't
@@ -141,7 +141,7 @@
 
   // Walk backward up to ops providing our input operands.
   for (auto operand : op->getOperands()) {
-    auto *sourceOp = operand->getDefiningOp();
+    auto *sourceOp = operand.getDefiningOp();
     if (!sourceOp) continue;
     if (subgraph->count(sourceOp) == 0) {
       if (isDispatchableOp(sourceOp, dispatchability) &&
@@ -191,7 +191,7 @@
       // Compute the workload based on the output shape.
       // When variadic all output shapes match so we can just take the first.
       auto workload = calculateWorkload(
-          &rootOp, rootOp.getResult(0)->getType().cast<ShapedType>());
+          &rootOp, rootOp.getResult(0).getType().cast<ShapedType>());
 
       // Try to build a dispatch region from this root.
       if (failed(buildDispatchRegion(func, block, workload, fusedSubgraph))) {
diff --git a/iree/compiler/Dialect/Flow/Transforms/IdentifyReductionRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/IdentifyReductionRegions.cpp
index 09ddedf..c998741 100644
--- a/iree/compiler/Dialect/Flow/Transforms/IdentifyReductionRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/IdentifyReductionRegions.cpp
@@ -60,7 +60,7 @@
   // Compute the workload based on the output shape.
   // When variadic all output shapes match so we can just take the first.
   auto workload = calculateWorkload(
-      originalOp, originalOp->getResult(0)->getType().cast<ShapedType>());
+      originalOp, originalOp->getResult(0).getType().cast<ShapedType>());
 
   // Build the region op and add it to the parent block.
   SmallVector<Type, 4> resultTypes{originalOp->getResultTypes()};
@@ -82,8 +82,7 @@
 
   // Replace usage of values with the results of the region.
   for (int i = 0; i < originalOp->getNumResults(); ++i) {
-    originalOp->getResult(i)->replaceAllUsesWith(
-        reductionRegionOp.getResult(i));
+    originalOp->getResult(i).replaceAllUsesWith(reductionRegionOp.getResult(i));
   }
 
   return success();
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index 3a61c4c..7d5f501 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -43,7 +43,7 @@
 
   // Replace uses of the existing results with the new results.
   for (int i = 0; i < regionOp.getNumResults(); ++i) {
-    regionOp.getResult(i)->replaceAllUsesWith(dispatchOp.getResult(i));
+    regionOp.getResult(i).replaceAllUsesWith(dispatchOp.getResult(i));
   }
 
   // Erase original region.
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
index f12f4c3..b9d2f2c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineReductionRegions.cpp
@@ -32,7 +32,7 @@
 SmallVector<int64_t, 4> calculateResultShape(Value input, int windowDimension) {
   SmallVector<int64_t, 4> resultShape;
   for (auto it :
-       llvm::enumerate(input->getType().cast<ShapedType>().getShape())) {
+       llvm::enumerate(input.getType().cast<ShapedType>().getShape())) {
     if (it.index() != windowDimension) {
       resultShape.push_back(it.value());
     }
@@ -86,9 +86,9 @@
   SmallVector<Type, 8> elementalResultTypes;
   for (auto arg : regionOp.initial_values()) {
     // (in0, in1) -> out0
-    elementalOperandTypes.push_back(arg->getType());
-    elementalOperandTypes.push_back(arg->getType());
-    elementalResultTypes.push_back(arg->getType());
+    elementalOperandTypes.push_back(arg.getType());
+    elementalOperandTypes.push_back(arg.getType());
+    elementalResultTypes.push_back(arg.getType());
   }
   auto elementalFunctionType = FunctionType::get(
       elementalOperandTypes, elementalResultTypes, regionOp.getContext());
@@ -106,10 +106,10 @@
   // dimension.
   SmallVector<Type, 8> allOperandTypes;
   auto inputTypes =
-      llvm::map_range(inputs, [](Value value) { return value->getType(); });
+      llvm::map_range(inputs, [](Value value) { return value.getType(); });
   allOperandTypes.append(inputTypes.begin(), inputTypes.end());
   auto initialValueTypes = llvm::map_range(
-      initialValues, [](Value value) { return value->getType(); });
+      initialValues, [](Value value) { return value.getType(); });
   allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end());
   SmallVector<Type, 4> resultTypes;
   for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) {
@@ -182,7 +182,7 @@
 
   // Replace uses of the existing results with the new results.
   for (int i = 0; i < regionOp.getNumResults(); ++i) {
-    regionOp.getResult(i)->replaceAllUsesWith(temps[i]);
+    regionOp.getResult(i).replaceAllUsesWith(temps[i]);
   }
 
   // Erase original region.
@@ -206,9 +206,9 @@
   SmallVector<Type, 8> elementalResultTypes;
   for (auto arg : regionOp.initial_values()) {
     // (in0, in1) -> out0
-    elementalOperandTypes.push_back(arg->getType());
-    elementalOperandTypes.push_back(arg->getType());
-    elementalResultTypes.push_back(arg->getType());
+    elementalOperandTypes.push_back(arg.getType());
+    elementalOperandTypes.push_back(arg.getType());
+    elementalResultTypes.push_back(arg.getType());
   }
   auto elementalFunctionType = FunctionType::get(
       elementalOperandTypes, elementalResultTypes, regionOp.getContext());
@@ -226,10 +226,10 @@
   // dimension.
   SmallVector<Type, 8> allOperandTypes;
   auto inputTypes =
-      llvm::map_range(inputs, [](Value value) { return value->getType(); });
+      llvm::map_range(inputs, [](Value value) { return value.getType(); });
   allOperandTypes.append(inputTypes.begin(), inputTypes.end());
   auto initialValueTypes = llvm::map_range(
-      initialValues, [](Value value) { return value->getType(); });
+      initialValues, [](Value value) { return value.getType(); });
   allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end());
   SmallVector<Type, 4> resultTypes;
   for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) {
@@ -324,7 +324,7 @@
 
   // Replace uses of the existing results with the new results.
   for (int i = 0; i < regionOp.getNumResults(); ++i) {
-    regionOp.getResult(i)->replaceAllUsesWith(temps[i]);
+    regionOp.getResult(i).replaceAllUsesWith(temps[i]);
   }
 
   // Erase original region.
diff --git a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
index cff2005..f5fbf62 100644
--- a/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/RematerializeDispatchConstants.cpp
@@ -77,7 +77,7 @@
   // coming from the same source operation.
   SmallPtrSet<Operation *, 4> operandOps;
   for (auto operand : sourceOp->getOperands()) {
-    operandOps.insert(operand->getDefiningOp());
+    operandOps.insert(operand.getDefiningOp());
   }
   for (auto *operandOp : operandOps) {
     recursivelyCloneOp(operandOp, builder, mapping);
@@ -99,7 +99,7 @@
 
   OpBuilder builder(targetBlock);
   builder.setInsertionPointToStart(targetBlock);
-  auto *sourceOp = sourceValue->getDefiningOp();
+  auto *sourceOp = sourceValue.getDefiningOp();
   auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping);
 
   // Return only the result matching our source value (in the case of multiple
@@ -134,7 +134,7 @@
 
   // Replace all uses of the inner operand with the new value.
   for (unsigned argIndex : argIndices) {
-    entryBlock.getArgument(argIndex)->replaceAllUsesWith(clonedValue);
+    entryBlock.getArgument(argIndex).replaceAllUsesWith(clonedValue);
   }
 
   // Remove the dispatch region args and the block args that have been
@@ -154,7 +154,7 @@
 LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) {
   Value constantValue = constantOp.getResult();
   SmallVector<DispatchRegionOp, 4> usingRegionOps;
-  for (auto *user : constantValue->getUsers()) {
+  for (auto *user : constantValue.getUsers()) {
     if (auto dispatchRegionOp = dyn_cast<DispatchRegionOp>(user)) {
       // Ensure this isn't just the workload and is used as an arg.
       if (std::find(dispatchRegionOp.args().begin(),
diff --git a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
index 797c18d..911d58b 100644
--- a/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Utils/DispatchUtils.cpp
@@ -46,13 +46,13 @@
     llvm::SetVector<Value> *escapingValues) {
   for (auto *op : opSet) {
     for (auto value : op->getOperands()) {
-      if (!llvm::is_contained(opSet, value->getDefiningOp())) {
+      if (!llvm::is_contained(opSet, value.getDefiningOp())) {
         // Op is using a value not in the ops set, ensure we capture it.
         capturedValues->insert(value);
       }
     }
     for (auto value : op->getResults()) {
-      for (auto &use : value->getUses()) {
+      for (auto &use : value.getUses()) {
         if (!llvm::is_contained(opSet, use.getOwner())) {
           // An op outside of the ops set is using the value, needs to escape.
           escapingValues->insert(value);
@@ -85,7 +85,7 @@
     return failure();
   }
   SmallVector<Type, 8> escapingTypes;
-  for (auto value : escapingValues) escapingTypes.push_back(value->getType());
+  for (auto value : escapingValues) escapingTypes.push_back(value.getType());
 
   // Build the region op and add it to the parent block.
   OpBuilder parentBuilder(parentBlock);
@@ -99,7 +99,7 @@
   OpBuilder regionBuilder(regionBlock);
   BlockAndValueMapping mapping;
   for (auto capturedValue : capturedValues) {
-    auto blockArg = regionBlock->addArgument(capturedValue->getType());
+    auto blockArg = regionBlock->addArgument(capturedValue.getType());
     mapping.map(capturedValue, blockArg);
   }
 
@@ -120,7 +120,7 @@
 
   // Replace usage of values with the results of the region.
   for (int i = 0; i < escapingValues.size(); ++i) {
-    escapingValues[i]->replaceAllUsesWith(dispatchRegionOp.getResult(i));
+    escapingValues[i].replaceAllUsesWith(dispatchRegionOp.getResult(i));
   }
 
   // Remove original ops from the parent region.
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
index 13b00f3..f85bb29 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -68,22 +68,22 @@
 
   // Compute the allocation size for the value.
   int elementSize = IREE::HAL::getRoundedElementByteWidth(
-      streamValue->getType().cast<ShapedType>().getElementType());
+      streamValue.getType().cast<ShapedType>().getElementType());
   auto shape = IREE::HAL::getShapeDims(streamValue, rewriter);
   auto allocationSize = rewriter
                             .create<IREE::HAL::AllocatorComputeSizeOp>(
-                                externalValue->getLoc(), allocator, memoryTypes,
+                                externalValue.getLoc(), allocator, memoryTypes,
                                 bufferUsage, shape, elementSize)
                             .getResult();
 
   auto buffer = rewriter
                     .create<IREE::HAL::AllocatorAllocateOp>(
-                        externalValue->getLoc(), allocator, memoryTypes,
+                        externalValue.getLoc(), allocator, memoryTypes,
                         bufferUsage, allocationSize)
                     .getResult();
 
   // TODO(benvanik): implement resource sets.
-  rewriter.create<IREE::HAL::ExDeferReleaseOp>(externalValue->getLoc(), buffer);
+  rewriter.create<IREE::HAL::ExDeferReleaseOp>(externalValue.getLoc(), buffer);
 
   return buffer;
 }
@@ -119,22 +119,22 @@
 
   // Compute the allocation size for the value.
   int elementSize = IREE::HAL::getRoundedElementByteWidth(
-      streamValue->getType().cast<ShapedType>().getElementType());
+      streamValue.getType().cast<ShapedType>().getElementType());
   auto shape = IREE::HAL::getShapeDims(streamValue, rewriter);
   auto allocationSize = rewriter
                             .create<IREE::HAL::AllocatorComputeSizeOp>(
-                                streamValue->getLoc(), allocator, memoryTypes,
+                                streamValue.getLoc(), allocator, memoryTypes,
                                 bufferUsage, shape, elementSize)
                             .getResult();
 
   auto buffer = rewriter
                     .create<IREE::HAL::AllocatorAllocateOp>(
-                        streamValue->getLoc(), allocator, memoryTypes,
+                        streamValue.getLoc(), allocator, memoryTypes,
                         bufferUsage, allocationSize)
                     .getResult();
 
   // TODO(benvanik): implement resource sets.
-  rewriter.create<IREE::HAL::ExDeferReleaseOp>(streamValue->getLoc(), buffer);
+  rewriter.create<IREE::HAL::ExDeferReleaseOp>(streamValue.getLoc(), buffer);
 
   return buffer;
 }
@@ -210,7 +210,7 @@
                                ConversionPatternRewriter &rewriter) {
   int bindingOrdinal = 0;
   auto pushBinding = [&](Value tensorValue) {
-    auto tensorType = tensorValue->getType().cast<ShapedType>();
+    auto tensorType = tensorValue.getType().cast<ShapedType>();
     auto shape = IREE::HAL::getShapeDims(tensorValue, rewriter);
     int elementSize =
         IREE::HAL::getRoundedElementByteWidth(tensorType.getElementType());
@@ -278,7 +278,7 @@
       updateOp.getLoc(), rewriter.getI32IntegerAttr(0));
 
   // Compute the size of the update range.
-  auto updateType = updateOp.update()->getType().cast<ShapedType>();
+  auto updateType = updateOp.update().getType().cast<ShapedType>();
   int elementSize =
       IREE::HAL::getRoundedElementByteWidth(updateType.getElementType());
   auto targetShape = IREE::HAL::getShapeDims(updateOp.target(), rewriter);
@@ -366,7 +366,7 @@
     // Remap non-tensor operands (such as workloads).
     auto &entryBlock = streamOp.body().front();
     for (int i = 0; i < operands.size(); ++i) {
-      if (operands[i]->getType().isa<IREE::RefPtrType>()) {
+      if (operands[i].getType().isa<IREE::RefPtrType>()) {
         bufferSet.rangeMap[entryBlock.getArgument(i)] =
             BufferRange{operands[i]};
       } else {
@@ -406,7 +406,7 @@
     // otherwise we lose access to the original values (which we need for
     // shape information).
     for (int i = 0; i < operands.size(); ++i) {
-      if (operands[i]->getType().isa<IREE::RefPtrType>()) {
+      if (operands[i].getType().isa<IREE::RefPtrType>()) {
         rewriter.replaceUsesOfBlockArgument(entryBlock.getArgument(i),
                                             operands[i]);
       }
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
index 7d658ba..257654b 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
@@ -78,14 +78,14 @@
       IREE::Flow::TensorLoadOp loadOp, llvm::ArrayRef<Value> newOperands,
       ConversionPatternRewriter &rewriter) const override {
     IREE::Flow::TensorLoadOpOperandAdaptor operands(newOperands);
-    auto sourceType = loadOp.source()->getType().cast<ShapedType>();
+    auto sourceType = loadOp.source().getType().cast<ShapedType>();
     auto sourceShape = IREE::HAL::getShapeDims(loadOp.source(), rewriter);
     auto sourceOffset =
         rewriter.createOrFold<IREE::HAL::BufferViewComputeOffsetOp>(
             loadOp.getLoc(), operands.source(), sourceShape, operands.indices(),
             IREE::HAL::getRoundedElementByteWidth(sourceType.getElementType()));
     rewriter.replaceOpWithNewOp<IREE::HAL::BufferLoadOp>(
-        loadOp, converter.convertType(loadOp.result()->getType()),
+        loadOp, converter.convertType(loadOp.result().getType()),
         operands.source(), sourceOffset);
     return matchSuccess();
   }
@@ -104,7 +104,7 @@
       IREE::Flow::TensorStoreOp storeOp, llvm::ArrayRef<Value> newOperands,
       ConversionPatternRewriter &rewriter) const override {
     IREE::Flow::TensorStoreOpOperandAdaptor operands(newOperands);
-    auto targetType = storeOp.target()->getType().cast<ShapedType>();
+    auto targetType = storeOp.target().getType().cast<ShapedType>();
     auto targetShape = IREE::HAL::getShapeDims(storeOp.target(), rewriter);
     auto targetOffset =
         rewriter.createOrFold<IREE::HAL::BufferViewComputeOffsetOp>(
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
index b0a7ed3..c11bb3f 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
@@ -100,7 +100,7 @@
       ConversionPatternRewriter &rewriter) const override {
     // TODO(benvanik): multiple converted type results to multiple variables.
     rewriter.replaceOpWithNewOp<IREE::HAL::VariableLoadOp>(
-        loadOp, converter.convertType(loadOp.result()->getType()),
+        loadOp, converter.convertType(loadOp.result().getType()),
         rewriter.getSymbolRefAttr(loadOp.variable()));
     return matchSuccess();
   }
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
index 3f2966d..d3db40a 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
@@ -38,7 +38,7 @@
     auto sizeConst = rewriter.createOrFold<mlir::ConstantOp>(
         op.getLoc(),
         rewriter.getI32IntegerAttr(
-            IREE::HAL::getRoundedElementByteWidth(op.getResult()->getType())));
+            IREE::HAL::getRoundedElementByteWidth(op.getResult().getType())));
     rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
         op, rewriter.getSymbolRefAttr(importOp), importType.getResults(),
         ArrayRef<Value>{adaptor.source_buffer(), adaptor.source_offset(),
@@ -68,7 +68,7 @@
     auto sizeConst = rewriter.createOrFold<mlir::ConstantOp>(
         op.getLoc(),
         rewriter.getI32IntegerAttr(
-            IREE::HAL::getRoundedElementByteWidth(op.value()->getType())));
+            IREE::HAL::getRoundedElementByteWidth(op.value().getType())));
     rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
         op, rewriter.getSymbolRefAttr(importOp), importType.getResults(),
         ArrayRef<Value>{adaptor.value(), adaptor.target_buffer(),
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
index 4fb6700..c9454bc 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -78,9 +78,9 @@
       return matchFailure();
     }
     for (auto memoryBarrier : op.memory_barriers()) {
-      assert(memoryBarrier->getDefiningOp());
+      assert(memoryBarrier.getDefiningOp());
       auto makeMemoryBarrierOp =
-          cast<IREE::HAL::MakeMemoryBarrierOp>(memoryBarrier->getDefiningOp());
+          cast<IREE::HAL::MakeMemoryBarrierOp>(memoryBarrier.getDefiningOp());
       callOperands.push_back(rewriter.create<mlir::ConstantOp>(
           op.getLoc(), rewriter.getI32IntegerAttr(static_cast<int32_t>(
                            makeMemoryBarrierOp.source_scope()))));
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
index fd776d5..c4d3df8 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
@@ -226,7 +226,7 @@
     rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
         cacheExecutableOp,
         rewriter.getSymbolRefAttr(cacheExecutableOp.executable()),
-        ArrayRef<Type>{cacheExecutableOp.getResult()->getType()},
+        ArrayRef<Type>{cacheExecutableOp.getResult().getType()},
         ArrayRef<Value>{operands[0]});
     return matchSuccess();
   }
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 2396b68..5297f24 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -79,7 +79,7 @@
 
   PatternMatchResult matchAndRewrite(VariableLoadOp op,
                                      PatternRewriter &rewriter) const override {
-    if (op.result()->use_empty()) {
+    if (op.result().use_empty()) {
       rewriter.eraseOp(op);
       return matchSuccess();
     }
@@ -106,7 +106,7 @@
   PatternMatchResult matchAndRewrite(VariableStoreOp op,
                                      PatternRewriter &rewriter) const override {
     if (auto loadOp =
-            dyn_cast_or_null<VariableLoadOp>(op.value()->getDefiningOp())) {
+            dyn_cast_or_null<VariableLoadOp>(op.value().getDefiningOp())) {
       if (loadOp.variable() == op.variable()) {
         rewriter.eraseOp(op);
         return matchSuccess();
@@ -139,7 +139,7 @@
   PatternMatchResult matchAndRewrite(AllocatorAllocateOp op,
                                      PatternRewriter &rewriter) const override {
     if (auto computeSizeOp = dyn_cast_or_null<AllocatorComputeSizeOp>(
-            op.allocation_size()->getDefiningOp())) {
+            op.allocation_size().getDefiningOp())) {
       if (op.memory_types() == computeSizeOp.memory_types() &&
           op.buffer_usage() == computeSizeOp.buffer_usage()) {
         rewriter.replaceOpWithNewOp<AllocatorAllocateShapedOp>(
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 14731d7..10b0ba4 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -90,7 +90,7 @@
   p << op.getOperationName();
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -129,7 +129,7 @@
   p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                      /*elidedAttrs=*/{"executable"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -224,7 +224,7 @@
   p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                      /*elidedAttrs=*/{"executable"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -248,7 +248,7 @@
   p << op.getOperationName() << ' ';
   p.printOperand(op.operand());
   p << " : ";
-  p.printType(op.operand()->getType());
+  p.printType(op.operand().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -325,7 +325,7 @@
       op.getAttrs(),
       /*elidedAttrs=*/{"source_scope", "target_scope"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -389,7 +389,7 @@
       op.getAttrs(),
       /*elidedAttrs=*/{"source_scope", "target_scope"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -560,7 +560,7 @@
   p.printSymbolName(op.variable());
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) {
@@ -569,7 +569,7 @@
     return op.emitOpError() << "undefined variable: " << op.variable();
   }
   auto variableOp = dyn_cast<VariableOp>(symbolOp);
-  auto loadType = op.result()->getType();
+  auto loadType = op.result().getType();
   if (!isVariableTypeCompatible(variableOp.type(), loadType)) {
     return op.emitOpError()
            << "variable type mismatch; variable " << op.variable() << " is "
@@ -605,7 +605,7 @@
   p.printSymbolName(op.variable());
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
   p << " : ";
-  p.printType(op.value()->getType());
+  p.printType(op.value().getType());
 }
 
 static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) {
@@ -614,7 +614,7 @@
     return op.emitOpError() << "undefined variable: " << op.variable();
   }
   auto variableOp = dyn_cast<VariableOp>(symbolOp);
-  auto storeType = op.value()->getType();
+  auto storeType = op.value().getType();
   if (!isVariableTypeCompatible(variableOp.type(), storeType)) {
     return op.emitOpError()
            << "variable type mismatch; variable " << op.variable() << " is "
@@ -761,7 +761,7 @@
       op.getAttrs(),
       /*elidedAttrs=*/{"memory_types", "buffer_usage"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -824,7 +824,7 @@
       op.getAttrs(),
       /*elidedAttrs=*/{"memory_types", "buffer_usage", "value"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
   p << " = ";
   p.printAttribute(op.value());
 }
@@ -902,7 +902,7 @@
       op.getAttrs(),
       /*elidedAttrs=*/{"memory_types", "buffer_usage", "element_size"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -945,7 +945,7 @@
   p.printOperand(op.length());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1024,7 +1024,7 @@
   p.printOperand(op.length());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
   p << " : ";
-  p.printType(op.target_buffer()->getType());
+  p.printType(op.target_buffer().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1067,7 +1067,7 @@
   p.printOperand(op.length());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
   p << " : ";
-  p.printType(op.source_buffer()->getType());
+  p.printType(op.source_buffer().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1141,7 +1141,7 @@
   p << "[";
   p.printOperand(op.source_offset());
   p << "] : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
 }
 
@@ -1181,7 +1181,7 @@
   p << "[";
   p.printOperand(op.target_offset());
   p << "] : ";
-  p.printType(op.value()->getType());
+  p.printType(op.value().getType());
   p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                      /*elidedAttrs=*/{"element_size"});
 }
@@ -1510,7 +1510,7 @@
       op.getAttrs(),
       /*elidedAttrs=*/{"modes", "command_categories"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1996,7 +1996,7 @@
   p.printOperand(op.set_layout());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2067,7 +2067,7 @@
   p.printOptionalAttrDictWithKeyword(op.getAttrs(),
                                      /*elidedAttrs=*/{"binding", "access"});
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2142,7 +2142,7 @@
   p.printOperand(op.device());
   p.printOptionalAttrDictWithKeyword(op.getAttrs());
   p << " : ";
-  p.printType(op.result()->getType());
+  p.printType(op.result().getType());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/Target/LegacyUtil.cpp b/iree/compiler/Dialect/HAL/Target/LegacyUtil.cpp
index fb09011..9ce7276 100644
--- a/iree/compiler/Dialect/HAL/Target/LegacyUtil.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LegacyUtil.cpp
@@ -58,11 +58,11 @@
   OpBuilder entryBuilder(&entryBlock);
   entryBuilder.setInsertionPointToStart(&entryBlock);
   for (auto arg : entryBlock.getArguments()) {
-    Type oldType = arg->getType();
-    arg->setType(convertTypeToMemRef(legalizeType(oldType)));
+    Type oldType = arg.getType();
+    arg.setType(convertTypeToMemRef(legalizeType(oldType)));
     auto loadInputOp = entryBuilder.create<IREE::LoadInputOp>(
         dispatchEntryOp.getLoc(), oldType, arg);
-    arg->replaceAllUsesWith(loadInputOp.getResult());
+    arg.replaceAllUsesWith(loadInputOp.getResult());
     loadInputOp.setOperand(arg);
   }
 
diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
index 741d114..6c72a2b 100644
--- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
+++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp
@@ -44,9 +44,8 @@
 SmallVector<Value, 4> getShapeDims(Value shapedValue,
                                    ConversionPatternRewriter &rewriter) {
   // TODO(benvanik): dynamic shape support.
-  return getStaticShapeDims(shapedValue->getLoc(),
-                            shapedValue->getType().cast<ShapedType>(),
-                            rewriter);
+  return getStaticShapeDims(shapedValue.getLoc(),
+                            shapedValue.getType().cast<ShapedType>(), rewriter);
 }
 
 }  // namespace HAL
diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp
index 1baa276..5c48393 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp
+++ b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp
@@ -80,11 +80,11 @@
   printer << op->getName() << '(';
   printer.printOperand(inputValue);
   printer << " : ";
-  printer.printType(inputValue->getType());
+  printer.printType(inputValue.getType());
   printer << ") ";
   printer.printOptionalAttrDict(op->getAttrs());
   printer << " : ";
-  printer.printType(outputValue->getType());
+  printer.printType(outputValue.getType());
 }
 
 //===----------------------------------------------------------------------===//
@@ -112,11 +112,11 @@
   printer << op->getName() << '(';
   printer.printOperand(inputValue);
   printer << " : ";
-  printer.printType(inputValue->getType());
+  printer.printType(inputValue.getType());
   printer << ", ";
   printer.printOperand(outputValue);
   printer << " : ";
-  printer.printType(outputValue->getType());
+  printer.printType(outputValue.getType());
   printer << ") ";
   printer.printOptionalAttrDict(op->getAttrs());
 }
@@ -148,11 +148,11 @@
   printer << op->getName() << '(';
   printer.printOperand(storeReduceOp.src());
   printer << " : ";
-  printer.printType(storeReduceOp.src()->getType());
+  printer.printType(storeReduceOp.src().getType());
   printer << ", ";
   printer.printOperand(storeReduceOp.dst());
   printer << " : ";
-  printer.printType(storeReduceOp.dst()->getType());
+  printer.printType(storeReduceOp.dst().getType());
   printer << ", ";
   printer.printAttribute(storeReduceOp.reduction_fnAttr());
   printer << ") ";
diff --git a/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp b/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
index 757af75..853c1af 100644
--- a/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
+++ b/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
@@ -217,10 +217,10 @@
 
     // Allocate arguments first from left-to-right.
     for (auto blockArg : block->getArguments()) {
-      auto reg = registerUsage.allocateRegister(blockArg->getType());
+      auto reg = registerUsage.allocateRegister(blockArg.getType());
       if (!reg.hasValue()) {
         return funcOp.emitError() << "register allocation failed for block arg "
-                                  << blockArg->getArgNumber();
+                                  << blockArg.getArgNumber();
       }
       map_[blockArg] = reg.getValue();
     }
@@ -230,7 +230,7 @@
     // makes things really hard to read. Ideally an optimization pass that
     // removes unused block arguments would prevent this from happening.
     for (auto blockArg : block->getArguments()) {
-      if (blockArg->use_empty()) {
+      if (blockArg.use_empty()) {
         registerUsage.releaseRegister(map_[blockArg]);
       }
     }
@@ -242,13 +242,13 @@
         }
       }
       for (auto result : op.getResults()) {
-        auto reg = registerUsage.allocateRegister(result->getType());
+        auto reg = registerUsage.allocateRegister(result.getType());
         if (!reg.hasValue()) {
           return op.emitError() << "register allocation failed for result "
-                                << result->cast<OpResult>()->getResultNumber();
+                                << result.cast<OpResult>().getResultNumber();
         }
         map_[result] = reg.getValue();
-        if (result->use_empty()) {
+        if (result.use_empty()) {
           registerUsage.releaseRegister(reg.getValue());
         }
       }
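
[Illustrative sketch; not a hunk in this change] BlockArgument is value-typed
as well, which is why the allocator above can copy arguments into map_ and
query them directly. A minimal sketch of the same accessors, assuming a
Block *block in scope:

  // Walk the block arguments by value; '.' access works on each copy.
  for (auto blockArg : block->getArguments()) {
    if (blockArg.use_empty()) continue;
    llvm::outs() << "arg #" << blockArg.getArgNumber() << " : "
                 << blockArg.getType() << "\n";
  }
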
diff --git a/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp b/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp
index f8f366a..782efb5 100644
--- a/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp
+++ b/iree/compiler/Dialect/VM/Analysis/ValueLiveness.cpp
@@ -67,17 +67,17 @@
     SmallVector<StringAttr, 8> valueNames;
     for (auto value : values) {
       std::string str;
-      if (auto blockArg = value->dyn_cast<BlockArgument>()) {
-        if (blockArg->getOwner()->isEntryBlock()) {
-          str = llvm::formatv("%arg{0}", blockArg->getArgNumber());
+      if (auto blockArg = value.dyn_cast<BlockArgument>()) {
+        if (blockArg.getOwner()->isEntryBlock()) {
+          str = llvm::formatv("%arg{0}", blockArg.getArgNumber());
         } else {
-          str = llvm::formatv("%bb{0}_arg{1}",
-                              blockOrdinals[blockArg->getOwner()],
-                              blockArg->getArgNumber());
+          str =
+              llvm::formatv("%bb{0}_arg{1}", blockOrdinals[blockArg.getOwner()],
+                            blockArg.getArgNumber());
         }
       } else {
         llvm::raw_string_ostream os(str);
-        value->print(os);
+        value.print(os);
         str = os.str();
       }
 
@@ -173,7 +173,7 @@
       }
       for (auto result : op.getResults()) {
         blockSets.defined.insert(result);
-        for (auto &use : result->getUses()) {
+        for (auto &use : result.getUses()) {
           if (use.getOwner()->getBlock() != &block) {
             // Value escapes this block.
             blockSets.liveOut.insert(result);
@@ -256,7 +256,7 @@
       } else {
         // Live out but not live in implies defined in the block.
         Operation *firstUse =
-            value->getDefiningOp() ? value->getDefiningOp() : &block.front();
+            value.getDefiningOp() ? value.getDefiningOp() : &block.front();
         addLiveRange(value, firstUse, &block.back());
       }
     }
@@ -265,7 +265,7 @@
     for (auto value : blockSets.liveIn) {
       if (blockSets.liveOut.count(value)) continue;
       Operation *lastUse = &block.front();
-      for (auto &use : value->getUses()) {
+      for (auto &use : value.getUses()) {
         if (use.getOwner()->getBlock() != &block) continue;
         if (lastUse->isBeforeInBlock(use.getOwner())) {
           lastUse = use.getOwner();
@@ -278,9 +278,9 @@
     for (auto value : blockSets.defined) {
       if (blockSets.liveOut.count(value)) continue;
       Operation *firstUse =
-          value->getDefiningOp() ? value->getDefiningOp() : &block.front();
+          value.getDefiningOp() ? value.getDefiningOp() : &block.front();
       Operation *lastUse = firstUse;
-      for (auto &use : value->getUses()) {
+      for (auto &use : value.getUses()) {
         if (use.getOwner()->getBlock() != &block) continue;
         if (lastUse->isBeforeInBlock(use.getOwner())) {
           lastUse = use.getOwner();
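
[Illustrative sketch; not part of this patch] Classifying a Value now uses
isa/dyn_cast on the value itself, and the dyn_cast result tests false the way
a null pointer did. A minimal sketch of the naming logic above, with a
hypothetical describe helper:

  // Render a short label for a value (hypothetical helper).
  static std::string describe(Value value) {
    std::string str;
    llvm::raw_string_ostream os(str);
    if (auto blockArg = value.dyn_cast<BlockArgument>()) {
      os << "block arg #" << blockArg.getArgNumber();
    } else {
      value.print(os);  // op results print via their defining op
    }
    return os.str();
  }
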
diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index c3fd761..51fcc8e 100644
--- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -272,7 +272,7 @@
       ConversionPatternRewriter &rewriter) const override {
     SelectOpOperandAdaptor srcAdaptor(operands);
     IntegerType requiredType = IntegerType::get(kBits, srcOp.getContext());
-    if (srcAdaptor.true_value()->getType() != requiredType)
+    if (srcAdaptor.true_value().getType() != requiredType)
       return matchFailure();
 
     rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>(
diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
index 140c66c..1e12d5b 100644
--- a/iree/compiler/Dialect/VM/IR/VMDialect.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
@@ -43,7 +43,7 @@
     llvm::raw_svector_ostream os(osBuffer);
 
     // TODO(b/143187291): tablegen this by adding a value name prefix field.
-    if (op->getResult(0)->getType().isa<VectorType>()) {
+    if (op->getResult(0).getType().isa<VectorType>()) {
       os << "v";
     }
     if (auto globalLoadOp = dyn_cast<GlobalLoadI32Op>(op)) {
@@ -66,7 +66,7 @@
       }
     } else if (auto rodataOp = dyn_cast<ConstRefRodataOp>(op)) {
       os << rodataOp.rodata();
-    } else if (op->getResult(0)->getType().isa<RefPtrType>()) {
+    } else if (op->getResult(0).getType().isa<RefPtrType>()) {
       os << "ref";
     } else if (isa<CmpEQI32Op>(op)) {
       os << "eq";
@@ -149,7 +149,7 @@
     // Replace the values directly with the return operands.
     assert(returnOp.getNumOperands() == valuesToReplace.size());
     for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
-      valuesToReplace[it.index()]->replaceAllUsesWith(it.value());
+      valuesToReplace[it.index()].replaceAllUsesWith(it.value());
     }
   }
 
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 9a34de5..ce8385c 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -203,7 +203,7 @@
 OpFoldResult ConstI32Op::fold(ArrayRef<Attribute> operands) { return value(); }
 
 OpFoldResult ConstI32ZeroOp::fold(ArrayRef<Attribute> operands) {
-  return IntegerAttr::get(getResult()->getType(), 0);
+  return IntegerAttr::get(getResult().getType(), 0);
 }
 
 OpFoldResult ConstRefZeroOp::fold(ArrayRef<Attribute> operands) {
@@ -744,10 +744,10 @@
   using OpRewritePattern<CondBranchOp>::OpRewritePattern;
   PatternMatchResult matchAndRewrite(CondBranchOp op,
                                      PatternRewriter &rewriter) const override {
-    if (!op.getCondition()->getDefiningOp()) {
+    if (!op.getCondition().getDefiningOp()) {
       return matchFailure();
     }
-    if (auto notOp = dyn_cast<NotI32Op>(op.getCondition()->getDefiningOp())) {
+    if (auto notOp = dyn_cast<NotI32Op>(op.getCondition().getDefiningOp())) {
       rewriter.replaceOpWithNewOp<CondBranchOp>(
           op, notOp.getOperand(), op.getFalseDest(), op.getFalseOperands(),
           op.getTrueDest(), op.getTrueOperands());
@@ -779,7 +779,7 @@
     // First check if the call is unused - this ensures we only do the symbol
     // lookup if we are actually going to use it.
     for (auto result : op.getResults()) {
-      if (!result->use_empty()) {
+      if (!result.use_empty()) {
         return matchFailure();
       }
     }
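
[Illustrative sketch; not part of this patch] Use lists also hang directly
off the value, so the dead-call check above needs no indirection. A minimal
sketch, with a hypothetical allResultsUnused helper:

  // True when no result of the operation has any uses (hypothetical).
  static bool allResultsUnused(Operation *op) {
    for (Value result : op->getResults())
      if (!result.use_empty()) return false;
    return true;
  }
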
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 2cfd0b5..ad1f3ed 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -524,7 +524,7 @@
   p.printSymbolName(op->getAttrOfType<FlatSymbolRefAttr>("global").getValue());
   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"global"});
   p << " : ";
-  p.printType(op->getResult(0)->getType());
+  p.printType(op->getResult(0).getType());
 }
 
 static LogicalResult verifyGlobalLoadOp(Operation *op) {
@@ -535,7 +535,7 @@
     return op->emitOpError() << "Undefined global: " << globalAttr;
   }
   auto globalType = globalOp->getAttrOfType<TypeAttr>("type");
-  auto loadType = op->getResult(0)->getType();
+  auto loadType = op->getResult(0).getType();
   if (globalType.getValue() != loadType) {
     return op->emitOpError()
            << "Global type mismatch; global " << globalAttr << " is "
@@ -566,7 +566,7 @@
   p.printOperand(op->getOperand(0));
   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"global"});
   p << " : ";
-  p.printType(op->getOperand(0)->getType());
+  p.printType(op->getOperand(0).getType());
 }
 
 static LogicalResult verifyGlobalStoreOp(Operation *op) {
@@ -577,7 +577,7 @@
     return op->emitOpError() << "Undefined global: " << globalAttr;
   }
   auto globalType = globalOp->getAttrOfType<TypeAttr>("type");
-  auto storeType = op->getOperand(0)->getType();
+  auto storeType = op->getOperand(0).getType();
   if (globalType.getValue() != storeType) {
     return op->emitOpError()
            << "Global type mismatch; global " << globalAttr << " is "
@@ -730,7 +730,7 @@
 static void printConstI32ZeroOp(OpAsmPrinter &p, ConstI32ZeroOp &op) {
   p << op.getOperationName();
   p << " : ";
-  p.printType(op.getResult()->getType());
+  p.printType(op.getResult().getType());
   p.printOptionalAttrDict(op.getAttrs());
 }
 
@@ -755,7 +755,7 @@
 static void printConstRefZeroOp(OpAsmPrinter &p, ConstRefZeroOp &op) {
   p << op.getOperationName();
   p << " : ";
-  p.printType(op.getResult()->getType());
+  p.printType(op.getResult().getType());
   p.printOptionalAttrDict(op.getAttrs());
 }
 
@@ -808,7 +808,7 @@
   p.printSymbolName(op.rodata());
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"rodata"});
   p << " : ";
-  p.printType(op.value()->getType());
+  p.printType(op.value().getType());
 }
 
 static LogicalResult verifyConstRefRodataOp(ConstRefRodataOp &op) {
@@ -860,10 +860,10 @@
 }
 
 static void printSelectOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1)
-    << ", " << *op->getOperand(2);
+  p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1)
+    << ", " << op->getOperand(2);
   p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getResult(0)->getType();
+  p << " : " << op->getResult(0).getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -883,9 +883,9 @@
 }
 
 static void printUnaryArithmeticOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << ' ' << *op->getOperand(0);
+  p << op->getName() << ' ' << op->getOperand(0);
   p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getOperand(0)->getType();
+  p << " : " << op->getOperand(0).getType();
 }
 
 static ParseResult parseBinaryArithmeticOp(OpAsmParser &parser,
@@ -901,9 +901,9 @@
 }
 
 static void printBinaryArithmeticOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1);
+  p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
   p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getResult(0)->getType();
+  p << " : " << op->getResult(0).getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -928,10 +928,10 @@
 }
 
 static void printShiftArithmeticOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << ' ' << *op->getOperand(0) << ", "
+  p << op->getName() << ' ' << op->getOperand(0) << ", "
     << op->getAttrOfType<IntegerAttr>("amount").getInt();
   p.printOptionalAttrDict(op->getAttrs(), {"amount"});
-  p << " : " << op->getResult(0)->getType();
+  p << " : " << op->getResult(0).getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -959,9 +959,9 @@
 }
 
 static void printUnaryComparisonOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << ' ' << *op->getOperand(0);
+  p << op->getName() << ' ' << op->getOperand(0);
   p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getOperand(0)->getType();
+  p << " : " << op->getOperand(0).getType();
 }
 
 static ParseResult parseBinaryComparisonOp(OpAsmParser &parser,
@@ -978,9 +978,9 @@
 }
 
 static void printBinaryComparisonOp(OpAsmPrinter &p, Operation *op) {
-  p << op->getName() << ' ' << *op->getOperand(0) << ", " << *op->getOperand(1);
+  p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
   p.printOptionalAttrDict(op->getAttrs());
-  p << " : " << op->getOperand(0)->getType();
+  p << " : " << op->getOperand(0).getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1240,7 +1240,7 @@
   p << ")";
   if (op.getNumResults() == 1) {
     p << " -> ";
-    p.printType(op.getResult(0)->getType());
+    p.printType(op.getResult(0).getType());
   } else if (op.getNumResults() > 1) {
     p << " -> (";
     interleaveComma(op.getResultTypes(), p);
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
index e5b5580d..94cbf7c 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
@@ -82,10 +82,10 @@
   }
 
   LogicalResult encodeType(Value value) override {
-    auto refPtrType = value->getType().dyn_cast<IREE::RefPtrType>();
+    auto refPtrType = value.getType().dyn_cast<IREE::RefPtrType>();
     if (!refPtrType) {
       return currentOp_->emitOpError()
-             << "type " << value->getType()
+             << "type " << value.getType()
              << " is not supported as a serialized type kind";
     }
     int typeOrdinal = typeTable_->lookup(refPtrType.getObjectType());
diff --git a/iree/compiler/Translation/Interpreter/IR/CommonOps.cpp b/iree/compiler/Translation/Interpreter/IR/CommonOps.cpp
index b0b0602..9c0308b 100644
--- a/iree/compiler/Translation/Interpreter/IR/CommonOps.cpp
+++ b/iree/compiler/Translation/Interpreter/IR/CommonOps.cpp
@@ -111,14 +111,14 @@
   p << "iree_interp.tensor_to_memref(";
   p.printOperand(op.getOperand());
   p << " : ";
-  p.printType(op.getOperand()->getType());
+  p.printType(op.getOperand().getType());
   p << ") : ";
   p.printType(op.getType());
 }
 
 OpFoldResult TensorToMemRefOp::fold(ArrayRef<Attribute> operands) {
   if (auto memrefToTensorOp = dyn_cast_or_null<IREEInterp::MemRefToTensorOp>(
-          getOperand()->getDefiningOp())) {
+          getOperand().getDefiningOp())) {
     return memrefToTensorOp.getOperand();
   }
 
@@ -127,7 +127,7 @@
 
 void TensorToMemRefOp::build(Builder *builder, OperationState &state,
                              Value arg) {
-  build(builder, state, convertTypeToMemRef(arg->getType()), arg);
+  build(builder, state, convertTypeToMemRef(arg.getType()), arg);
 }
 
 //===----------------------------------------------------------------------===//
@@ -154,14 +154,14 @@
   p << "iree_interp.memref_to_tensor(";
   p.printOperand(op.getOperand());
   p << " : ";
-  p.printType(op.getOperand()->getType());
+  p.printType(op.getOperand().getType());
   p << ") : ";
   p.printType(op.getType());
 }
 
 OpFoldResult MemRefToTensorOp::fold(ArrayRef<Attribute> operands) {
   if (auto tensorToMemRefOp = dyn_cast_or_null<IREEInterp::TensorToMemRefOp>(
-          getOperand()->getDefiningOp())) {
+          getOperand().getDefiningOp())) {
     return tensorToMemRefOp.getOperand();
   }
 
@@ -172,7 +172,7 @@
                              Value arg) {
   // TODO(gcmn) Use getTensorType from MemRefUtils when circular dependency can
   // be avoided.
-  auto memRefType = arg->getType().cast<MemRefType>();
+  auto memRefType = arg.getType().cast<MemRefType>();
   auto tensorType =
       RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
   build(builder, state, tensorType, arg);
@@ -202,14 +202,14 @@
   p << "iree_interp.scalar_to_memref(";
   p.printOperand(op.getOperand());
   p << " : ";
-  p.printType(op.getOperand()->getType());
+  p.printType(op.getOperand().getType());
   p << ") : ";
   p.printType(op.getType());
 }
 
 OpFoldResult ScalarToMemRefOp::fold(ArrayRef<Attribute> operands) {
   if (auto memrefToScalarOp = dyn_cast_or_null<IREEInterp::MemRefToScalarOp>(
-          getOperand()->getDefiningOp())) {
+          getOperand().getDefiningOp())) {
     return memrefToScalarOp.getOperand();
   }
 
@@ -218,7 +218,7 @@
 
 void ScalarToMemRefOp::build(Builder *builder, OperationState &state,
                              Value arg) {
-  build(builder, state, convertTypeToMemRef(arg->getType()), arg);
+  build(builder, state, convertTypeToMemRef(arg.getType()), arg);
 }
 
 //===----------------------------------------------------------------------===//
@@ -245,14 +245,14 @@
   p << "iree_interp.memref_to_scalar(";
   p.printOperand(op.getOperand());
   p << " : ";
-  p.printType(op.getOperand()->getType());
+  p.printType(op.getOperand().getType());
   p << ") : ";
   p.printType(op.getType());
 }
 
 OpFoldResult MemRefToScalarOp::fold(ArrayRef<Attribute> operands) {
   if (auto scalarToMemRefOp = dyn_cast_or_null<IREEInterp::ScalarToMemRefOp>(
-          getOperand()->getDefiningOp())) {
+          getOperand().getDefiningOp())) {
     return scalarToMemRefOp.getOperand();
   }
 
diff --git a/iree/compiler/Translation/Interpreter/IR/HLOps.cpp b/iree/compiler/Translation/Interpreter/IR/HLOps.cpp
index 2c4e563..bdc57eb 100644
--- a/iree/compiler/Translation/Interpreter/IR/HLOps.cpp
+++ b/iree/compiler/Translation/Interpreter/IR/HLOps.cpp
@@ -91,7 +91,7 @@
   p.printOperands(++operandRange.begin(), operandRange.end());
   p << ')';
   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
-  p << " : " << op.getCallee()->getType();
+  p << " : " << op.getCallee().getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -192,7 +192,7 @@
 OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
   // If this is the only usage, we know the clone is unnecessary.
   // TODO(b/135053584) More sophisticated analysis.
-  if (src()->hasOneUse()) return src();
+  if (src().hasOneUse()) return src();
   return {};
 }
 
@@ -205,7 +205,7 @@
   using OpRewritePattern::OpRewritePattern;
   PatternMatchResult matchAndRewrite(ConcatOp concatOp,
                                      PatternRewriter &rewriter) const override {
-    auto finalType = concatOp.getResult()->getType().cast<ShapedType>();
+    auto finalType = concatOp.getResult().getType().cast<ShapedType>();
     auto loc = concatOp.getLoc();
     std::vector<Value> dimPieces;
     auto dst =
@@ -217,7 +217,7 @@
     auto concatDimension = concatOp.dimension().getZExtValue();
     llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0);
     for (auto src : concatOp.srcs()) {
-      auto srcShape = src->getType().cast<ShapedType>().getShape();
+      auto srcShape = src.getType().cast<ShapedType>().getShape();
       auto lengths = createArrayConstant(rewriter, loc, srcShape);
       auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices);
       rewriter.create<IREEInterp::HL::CopyOp>(loc, src, srcIndices, dst,
diff --git a/iree/compiler/Translation/Interpreter/IR/OpWriters.cpp b/iree/compiler/Translation/Interpreter/IR/OpWriters.cpp
index 6a971f7..38f16b4 100644
--- a/iree/compiler/Translation/Interpreter/IR/OpWriters.cpp
+++ b/iree/compiler/Translation/Interpreter/IR/OpWriters.cpp
@@ -56,11 +56,11 @@
 LogicalResult WriteConvertOperands(Operation *op, BytecodeWriter *writer) {
   auto src = op->getOperand(0);
   RETURN_IF_FAILURE(
-      writer->WriteTypeIndex(getElementTypeOrSelf(src->getType())));
+      writer->WriteTypeIndex(getElementTypeOrSelf(src.getType())));
   RETURN_IF_FAILURE(writer->WriteLocal(src));
   auto dst = op->getOperand(1);
   RETURN_IF_FAILURE(
-      writer->WriteTypeIndex(getElementTypeOrSelf(dst->getType())));
+      writer->WriteTypeIndex(getElementTypeOrSelf(dst.getType())));
   RETURN_IF_FAILURE(writer->WriteLocal(dst));
   return success();
 }
diff --git a/iree/compiler/Translation/Interpreter/Serialization/BytecodeWriter.cpp b/iree/compiler/Translation/Interpreter/Serialization/BytecodeWriter.cpp
index 103acfb..47d4c12 100644
--- a/iree/compiler/Translation/Interpreter/Serialization/BytecodeWriter.cpp
+++ b/iree/compiler/Translation/Interpreter/Serialization/BytecodeWriter.cpp
@@ -200,7 +200,7 @@
   }
   if (ordinal > UINT16_MAX) {
     // TODO(benvanik): varints?
-    emitError(UnknownLoc::get(value->getContext()))
+    emitError(UnknownLoc::get(value.getContext()))
         << "Too many ordinals: " << ordinal
         << "; only 0-UINT16_MAX are supported";
     return llvm::None;
@@ -220,7 +220,7 @@
   }
   if (ordinal.getValue() > UINT16_MAX) {
     // TODO(benvanik): varints?
-    return emitError(UnknownLoc::get(value->getContext()))
+    return emitError(UnknownLoc::get(value.getContext()))
            << "Too many locals: " << ordinal.getValue()
            << "; only 0-UINT16_MAX are supported";
   }
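
[Illustrative sketch; not a hunk in this change] A Value carries its MLIR
context by value, so diagnostics like the ones above no longer dereference
it. A minimal sketch, with a hypothetical emitOrdinalError helper:

  // Report an out-of-range ordinal attached to a value (hypothetical).
  static InFlightDiagnostic emitOrdinalError(Value value, int ordinal) {
    return emitError(UnknownLoc::get(value.getContext()))
           << "ordinal " << ordinal << " out of range";
  }
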
diff --git a/iree/compiler/Translation/Interpreter/Transforms/ConvertFromTupleCallingConvention.cpp b/iree/compiler/Translation/Interpreter/Transforms/ConvertFromTupleCallingConvention.cpp
index c0df0c1..c51ca04 100644
--- a/iree/compiler/Translation/Interpreter/Transforms/ConvertFromTupleCallingConvention.cpp
+++ b/iree/compiler/Translation/Interpreter/Transforms/ConvertFromTupleCallingConvention.cpp
@@ -65,7 +65,7 @@
 bool recursiveUntuple(Value value, Location loc, OpBuilder &builder,
                       BlockAndValueMapping *mapping,
                       llvm::SmallVectorImpl<Value> *newValues) {
-  Type type = value->getType();
+  Type type = value.getType();
   // We can return the value as is.
   if (!type.isa<TupleType>()) {
     newValues->push_back(value);
@@ -148,8 +148,8 @@
   auto newResults = newOp.getResults();
   for (auto oldResult : oldOp->getResults()) {
     llvm::SmallVector<Value, 10> subValues;
-    auto newResult = recursiveRetuple(oldResult->getType(), &newResults,
-                                      builder, oldOp->getLoc());
+    auto newResult = recursiveRetuple(oldResult.getType(), &newResults, builder,
+                                      oldOp->getLoc());
     mapping->map(oldResult, newResult);
   }
 
@@ -250,9 +250,9 @@
     auto *newBlock = builder.createBlock(&newFunction.getBody());
     for (auto oldArg : oldBlock.getArguments()) {
       llvm::SmallVector<Type, 4> newTypes;
-      untupleTypes(oldArg->getType(), &newTypes);
+      untupleTypes(oldArg.getType(), &newTypes);
 
-      Value newTuple = processTuple(oldArg->getType(), oldFunction.getLoc(),
+      Value newTuple = processTuple(oldArg.getType(), oldFunction.getLoc(),
                                     newBlock, builder);
       if (!newTuple) {
         return true;
diff --git a/iree/compiler/Translation/Interpreter/Transforms/ConvertToMemRefCallingConvention.cpp b/iree/compiler/Translation/Interpreter/Transforms/ConvertToMemRefCallingConvention.cpp
index 57c8823..31200f6 100644
--- a/iree/compiler/Translation/Interpreter/Transforms/ConvertToMemRefCallingConvention.cpp
+++ b/iree/compiler/Translation/Interpreter/Transforms/ConvertToMemRefCallingConvention.cpp
@@ -42,7 +42,7 @@
 Value resolveValueToSourceMemRef(Value value, Operation *useOp) {
   // TODO(benvanik): implement this for real; this is naive but enough for our
   // simple load patterns.
-  auto *defInstr = value->getDefiningOp();
+  auto *defInstr = value.getDefiningOp();
   if (auto loadOp = dyn_cast_or_null<LoadOp>(defInstr)) {
     // TODO(benvanik): support views.
     return loadOp.getMemRef();
@@ -79,14 +79,14 @@
 
 bool insertLoad(BlockArgument oldArg, BlockArgument newArg, OpBuilder &builder,
                 BlockAndValueMapping *mapping) {
-  auto loc = oldArg->getOwner()->getParent()->getLoc();
+  auto loc = oldArg.getOwner()->getParent()->getLoc();
 
   // If old arg was a memref we don't need to change anything. We still need
   // to remap so that the use lists match through conversion, though.
-  if (oldArg->getType().isa<MemRefType>()) {
+  if (oldArg.getType().isa<MemRefType>()) {
     mapping->map(oldArg, newArg);
     return false;
-  } else if (oldArg->getType().isa<TensorType>()) {
+  } else if (oldArg.getType().isa<TensorType>()) {
     auto castOp = builder.create<IREEInterp::MemRefToTensorOp>(loc, newArg);
     mapping->map(oldArg, castOp.getResult());
     return false;
@@ -102,17 +102,17 @@
 bool insertLoad(Operation *oldOp, Value oldValue, Value newValue,
                 OpBuilder &builder, BlockAndValueMapping *mapping) {
   // If old value was a memref we don't need to change anything.
-  if (oldValue->getType().isa<MemRefType>()) {
+  if (oldValue.getType().isa<MemRefType>()) {
     mapping->map(oldValue, newValue);
     return false;
-  } else if (oldValue->getType().isa<TensorType>()) {
+  } else if (oldValue.getType().isa<TensorType>()) {
     auto castOp =
         builder.create<IREEInterp::MemRefToTensorOp>(oldOp->getLoc(), newValue);
     mapping->map(oldValue, castOp.getResult());
     return false;
   }
 
-  assert(newValue->getType().isa<MemRefType>());
+  assert(newValue.getType().isa<MemRefType>());
 
   // Insert the load we'll use to unbox the value.
   auto loadedValue =
@@ -132,9 +132,9 @@
   // If the previous value was already a memref we don't need to change
   // anything.
   // TODO(benvanik): ensure indices make sense.
-  if (oldValue->getType().isa<MemRefType>()) {
+  if (oldValue.getType().isa<MemRefType>()) {
     return newValue;
-  } else if (oldValue->getType().isa<TensorType>()) {
+  } else if (oldValue.getType().isa<TensorType>()) {
     auto castOp =
         builder.create<IREEInterp::TensorToMemRefOp>(oldOp->getLoc(), newValue);
     return castOp.getResult();
@@ -147,7 +147,7 @@
 
   // Allocate the memref to store the value.
   auto newStorage = builder.create<AllocOp>(
-      oldOp->getLoc(), convertTypeToMemRef(oldValue->getType()));
+      oldOp->getLoc(), convertTypeToMemRef(oldValue.getType()));
 
   // Insert the store we'll use to box the value.
   builder.create<StoreOp>(oldOp->getLoc(), newValue, newStorage,
@@ -334,7 +334,7 @@
     auto *newBlock = builder.createBlock(&newFunc.getBody());
     for (auto oldArg : oldBlock.getArguments()) {
       // Replace the block args with memrefs.
-      auto memRefType = convertTypeToMemRef(oldArg->getType());
+      auto memRefType = convertTypeToMemRef(oldArg.getType());
       if (!memRefType) return true;
       auto newArg = newBlock->addArgument(memRefType);
 
diff --git a/iree/compiler/Translation/Interpreter/Transforms/ExpandReductionsToOps.cpp b/iree/compiler/Translation/Interpreter/Transforms/ExpandReductionsToOps.cpp
index f81fc73..d43f69a 100644
--- a/iree/compiler/Translation/Interpreter/Transforms/ExpandReductionsToOps.cpp
+++ b/iree/compiler/Translation/Interpreter/Transforms/ExpandReductionsToOps.cpp
@@ -45,14 +45,14 @@
   // ops within the function.
   // TODO(b/139313439): support fused reductions.
   for (auto operand : elementOp->getOperands()) {
-    if (operand->getDefiningOp() != nullptr) {
+    if (operand.getDefiningOp() != nullptr) {
       return elementOp->emitOpError()
              << "Fused reductions are not supported (operand not sourced from "
                 "block args)";
     }
   }
   for (auto result : elementOp->getResults()) {
-    for (auto *user : result->getUsers()) {
+    for (auto *user : result.getUsers()) {
       if (!user->isKnownTerminator()) {
         return elementOp->emitOpError() << "Fused reductions are not supported "
                                            "(result used by non-terminator)";
@@ -77,7 +77,7 @@
       applyFunc.getNumArguments() / 2 + setIndex);
   Value dstArg =
       entryPointEntryBlock.getArgument(applyFunc.getNumArguments() + setIndex);
-  auto dstType = dstArg->getType().cast<ShapedType>();
+  auto dstType = dstArg.getType().cast<ShapedType>();
   Type elementType = dstType.getElementType();
   auto loc = elementOp->getLoc();
   auto dimensionAttr = entryPoint.getAttrOfType<IntegerAttr>(
diff --git a/iree/compiler/Translation/Interpreter/Transforms/LegalizeTypeStorage.cpp b/iree/compiler/Translation/Interpreter/Transforms/LegalizeTypeStorage.cpp
index 475dd7c..8cb5634 100644
--- a/iree/compiler/Translation/Interpreter/Transforms/LegalizeTypeStorage.cpp
+++ b/iree/compiler/Translation/Interpreter/Transforms/LegalizeTypeStorage.cpp
@@ -81,7 +81,7 @@
     auto *newBlock = builder.createBlock(&newFunction.getBody());
     mapping.map(&oldBlock, newBlock);
     for (auto oldArg : oldBlock.getArguments()) {
-      auto newArg = newBlock->addArgument(legalizeType(oldArg->getType()));
+      auto newArg = newBlock->addArgument(legalizeType(oldArg.getType()));
       mapping.map(oldArg, newArg);
     }
   }
diff --git a/iree/compiler/Translation/Interpreter/Transforms/LowerInterpreterDialect.cpp b/iree/compiler/Translation/Interpreter/Transforms/LowerInterpreterDialect.cpp
index 06a13c7..5aba12e 100644
--- a/iree/compiler/Translation/Interpreter/Transforms/LowerInterpreterDialect.cpp
+++ b/iree/compiler/Translation/Interpreter/Transforms/LowerInterpreterDialect.cpp
@@ -108,7 +108,7 @@
 
     SmallVector<Value, 4> replacementValues;
     for (Value result : op.getOperation()->getResults()) {
-      auto memRefType = result->getType().cast<MemRefType>();
+      auto memRefType = result.getType().cast<MemRefType>();
       if (!memRefType.hasStaticShape()) {
         // TODO(benvanik): real thing here - dynamic shaping required.
         // This should emit a shape calculation based on the operation. Most
diff --git a/iree/compiler/Translation/Interpreter/Transforms/LowerXLAToInterpreterDialect.cpp b/iree/compiler/Translation/Interpreter/Transforms/LowerXLAToInterpreterDialect.cpp
index 958a6ab..3dadff3 100644
--- a/iree/compiler/Translation/Interpreter/Transforms/LowerXLAToInterpreterDialect.cpp
+++ b/iree/compiler/Translation/Interpreter/Transforms/LowerXLAToInterpreterDialect.cpp
@@ -118,7 +118,7 @@
       xla_hlo::BroadcastInDimOp *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     auto inputValue = operands[0];
-    auto inputType = inputValue->getType().cast<MemRefType>();
+    auto inputType = inputValue.getType().cast<MemRefType>();
     auto finalType = convertTypeToMemRef(*op);
 
     // Reshape to scalar and broadcast.
@@ -194,7 +194,7 @@
     auto operand = operands[0];
     auto update = operands[1];
 
-    auto updateType = update->getType().cast<ShapedType>();
+    auto updateType = update.getType().cast<ShapedType>();
     Value lengthConstant =
         createArrayConstant(rewriter, op->getLoc(), updateType.getShape());
 
@@ -224,7 +224,7 @@
     auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset);
 
     auto copiedOperand = rewriter.create<IREEInterp::HL::CloneOp>(
-        op->getLoc(), operand->getType(), operand);
+        op->getLoc(), operand.getType(), operand);
 
     rewriter
         .create<IREEInterp::HL::CopyOp>(op->getLoc(), update, srcOffset,
@@ -244,7 +244,7 @@
       XlaOpType *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     auto val = operands[0];
-    auto inputType = val->getType().cast<MemRefType>();
+    auto inputType = val.getType().cast<MemRefType>();
     auto elementType = inputType.getElementType();
 
     if (elementType.isa<FloatType>()) {
@@ -270,7 +270,7 @@
       ConversionPatternRewriter &rewriter) const override {
     auto lhs = operands[0];
     auto rhs = operands[1];
-    auto inputType = lhs->getType().cast<MemRefType>();
+    auto inputType = lhs.getType().cast<MemRefType>();
     auto elementType = inputType.getElementType();
 
     if (elementType.isa<FloatType>()) {
@@ -303,8 +303,8 @@
     auto operand = operands[0];
     auto result = op->getResult();
 
-    auto operandType = operand->getType().cast<MemRefType>().getElementType();
-    auto resultType = result->getType().cast<ShapedType>().getElementType();
+    auto operandType = operand.getType().cast<MemRefType>().getElementType();
+    auto resultType = result.getType().cast<ShapedType>().getElementType();
 
     auto newResultType = convertTypeToMemRef(result);
 
@@ -359,7 +359,7 @@
       return matchFailure();
     }
 
-    auto resultType = gatherOp.getResult()->getType().cast<RankedTensorType>();
+    auto resultType = gatherOp.getResult().getType().cast<RankedTensorType>();
     if (dimension_numbers.offset_dims().getType().getNumElements() !=
         resultType.getRank()) {
       gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != "
@@ -385,11 +385,11 @@
       }
     }
 
-    auto inputType = gatherOp.operand()->getType().cast<RankedTensorType>();
+    auto inputType = gatherOp.operand().getType().cast<RankedTensorType>();
 
     auto startIndices =
         inputAsMemref(rewriter, gatherOp, gatherOp.start_indices());
-    auto startIndicesType = startIndices->getType().cast<MemRefType>();
+    auto startIndicesType = startIndices.getType().cast<MemRefType>();
     if (startIndicesType.getNumElements() != inputType.getRank()) {
       auto extraDims = inputType.getRank() - startIndicesType.getNumElements();
       auto elementType = startIndicesType.getElementType();
diff --git a/iree/compiler/Translation/Interpreter/Transforms/MakeExecutableABI.cpp b/iree/compiler/Translation/Interpreter/Transforms/MakeExecutableABI.cpp
index c0eb64b..fcb90b7 100644
--- a/iree/compiler/Translation/Interpreter/Transforms/MakeExecutableABI.cpp
+++ b/iree/compiler/Translation/Interpreter/Transforms/MakeExecutableABI.cpp
@@ -36,7 +36,7 @@
   OpBuilder builder(bindOp);
 
   Value newValue = nullptr;
-  auto dstType = bindOp.getResult()->getType();
+  auto dstType = bindOp.getResult().getType();
   if (dstType.isa<TensorType>()) {
     auto castOp = builder.create<IREEInterp::MemRefToTensorOp>(bindOp.getLoc(),
                                                                bindOp.src());
@@ -60,7 +60,7 @@
 LogicalResult replaceStoreOutputOp(IREE::StoreOutputOp bindOp) {
   OpBuilder builder(bindOp);
 
-  auto srcType = bindOp.src()->getType();
+  auto srcType = bindOp.src().getType();
   if (srcType.isa<MemRefType>()) {
     // Already stored into the output.
   } else if (srcType.isa<TensorType>()) {
@@ -68,7 +68,7 @@
                                                                bindOp.src());
 
     // Insert a copy to our output parameter.
-    auto dst = bindOp.dst()->getType().cast<ShapedType>();
+    auto dst = bindOp.dst().getType().cast<ShapedType>();
     if (!dst.hasStaticShape()) {
       return bindOp.emitError()
              << "Dynamic output args are not yet implemented";
diff --git a/iree/compiler/Translation/Interpreter/Utils/MemRefUtils.cpp b/iree/compiler/Translation/Interpreter/Utils/MemRefUtils.cpp
index 14fea0f..6d988e6 100644
--- a/iree/compiler/Translation/Interpreter/Utils/MemRefUtils.cpp
+++ b/iree/compiler/Translation/Interpreter/Utils/MemRefUtils.cpp
@@ -26,9 +26,9 @@
 namespace iree_compiler {
 
 Value wrapAsTensor(Value value, Operation *srcOp, OpBuilder &builder) {
-  if (srcOp->getResult(0)->getType().isa<TensorType>()) {
-    if (isa_and_nonnull<IREEInterp::TensorToMemRefOp>(value->getDefiningOp())) {
-      return value->getDefiningOp()->getOperand(0);
+  if (srcOp->getResult(0).getType().isa<TensorType>()) {
+    if (isa_and_nonnull<IREEInterp::TensorToMemRefOp>(value.getDefiningOp())) {
+      return value.getDefiningOp()->getOperand(0);
     }
     auto newOp =
         builder.create<IREEInterp::MemRefToTensorOp>(srcOp->getLoc(), value);
@@ -38,9 +38,9 @@
 }
 
 Value wrapAsMemRef(Value value, Operation *srcOp, OpBuilder &builder) {
-  if (value->getType().isa<TensorType>()) {
-    if (isa_and_nonnull<IREEInterp::MemRefToTensorOp>(value->getDefiningOp())) {
-      return value->getDefiningOp()->getOperand(0);
+  if (value.getType().isa<TensorType>()) {
+    if (isa_and_nonnull<IREEInterp::MemRefToTensorOp>(value.getDefiningOp())) {
+      return value.getDefiningOp()->getOperand(0);
     }
     auto newOp =
         builder.create<IREEInterp::TensorToMemRefOp>(srcOp->getLoc(), value);
@@ -50,13 +50,13 @@
 }
 
 Value loadAccessValue(Location location, Value operand, OpBuilder &builder) {
-  if (operand->getType().isa<MemRefType>() ||
-      operand->getType().isa<TensorType>()) {
+  if (operand.getType().isa<MemRefType>() ||
+      operand.getType().isa<TensorType>()) {
     return operand;
   }
 
-  auto memRefType = MemRefType::get({}, operand->getType());
-  if (auto loadOp = dyn_cast_or_null<LoadOp>(operand->getDefiningOp())) {
+  auto memRefType = MemRefType::get({}, operand.getType());
+  if (auto loadOp = dyn_cast_or_null<LoadOp>(operand.getDefiningOp())) {
     // TODO(benvanik): handle creating views.
     if (loadOp.getMemRefType() == memRefType) {
       return loadOp.getMemRef();
diff --git a/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp b/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp
index 0857544..91ade9c 100644
--- a/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp
+++ b/iree/compiler/Translation/SPIRV/AdjustIntegerWidthPass.cpp
@@ -96,11 +96,11 @@
   using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
   PatternMatchResult matchAndRewrite(spirv::AccessChainOp op,
                                      PatternRewriter &rewriter) const override {
-    if (!hasIntTypeOfWidth(op.component_ptr()->getType(), {1, 8, 64})) {
+    if (!hasIntTypeOfWidth(op.component_ptr().getType(), {1, 8, 64})) {
       return matchFailure();
     }
     ValueRange indices(op.indices());
-    Type newType = legalizeIntegerType(op.component_ptr()->getType());
+    Type newType = legalizeIntegerType(op.component_ptr().getType());
     rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(op, newType,
                                                       op.base_ptr(), indices);
     return matchSuccess();
@@ -113,11 +113,11 @@
   using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;
   PatternMatchResult matchAndRewrite(spirv::AddressOfOp op,
                                      PatternRewriter &rewriter) const override {
-    if (!hasIntTypeOfWidth(op.pointer()->getType(), {1, 8, 64})) {
+    if (!hasIntTypeOfWidth(op.pointer().getType(), {1, 8, 64})) {
       return matchFailure();
     }
     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
-        op, legalizeIntegerType(op.pointer()->getType()),
+        op, legalizeIntegerType(op.pointer().getType()),
         SymbolRefAttr::get(op.variable(), rewriter.getContext()));
     return matchSuccess();
   }
@@ -154,7 +154,7 @@
   if (indices.size() > 1) {
     indices.back() = rewriter.create<spirv::SDivOp>(loc, lastDim, four);
   }
-  Type t = legalizeIntegerType(op.component_ptr()->getType());
+  Type t = legalizeIntegerType(op.component_ptr().getType());
   return rewriter.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
 }
 
@@ -179,7 +179,7 @@
   using OpRewritePattern<spirv::LoadOp>::OpRewritePattern;
   PatternMatchResult matchAndRewrite(spirv::LoadOp op,
                                      PatternRewriter &rewriter) const override {
-    Type valueType = op.value()->getType();
+    Type valueType = op.value().getType();
     if (!hasIntTypeOfWidth(valueType, {1, 8, 64})) {
       return matchFailure();
     }
@@ -189,8 +189,7 @@
     const auto loc = op.getLoc();
     Value result;
     if (hasIntTypeOfWidth(valueType, {1, 8})) {
-      auto accessChainOp =
-          cast<spirv::AccessChainOp>(op.ptr()->getDefiningOp());
+      auto accessChainOp = cast<spirv::AccessChainOp>(op.ptr().getDefiningOp());
       // Only support for scalar and 1-D tensor. The first element in indices is
       // index, the remaining elements map to other dimensions.
       if (accessChainOp.indices().size() > 2) {
@@ -244,7 +243,7 @@
 // Returns the shifted 32-bit value with the given offset.
 Value shiftStoreValue(spirv::StoreOp op, Value offset,
                       PatternRewriter &rewriter) {
-  Type valueType = op.value()->getType();
+  Type valueType = op.value().getType();
   Type i32Type = rewriter.getIntegerType(32);
   const auto loc = op.getLoc();
 
@@ -279,7 +278,7 @@
 LogicalResult rewriteInt1AndInt8(spirv::StoreOp op, PatternRewriter &rewriter) {
   Type i32Type = rewriter.getIntegerType(32);
   const auto loc = op.getLoc();
-  auto accessChainOp = cast<spirv::AccessChainOp>(op.ptr()->getDefiningOp());
+  auto accessChainOp = cast<spirv::AccessChainOp>(op.ptr().getDefiningOp());
 
   // Only support for scalar and 1-D tensor. The first element in indices is
   // index, the remaining elements map to other dimensions.
@@ -321,7 +320,7 @@
 
   PatternMatchResult matchAndRewrite(spirv::StoreOp op,
                                      PatternRewriter &rewriter) const override {
-    Type valueType = op.value()->getType();
+    Type valueType = op.value().getType();
     if (!hasIntTypeOfWidth(valueType, {1, 8, 64})) {
       return matchFailure();
     }
@@ -352,8 +351,8 @@
   using OpRewritePattern<spirv::SConvertOp>::OpRewritePattern;
   PatternMatchResult matchAndRewrite(spirv::SConvertOp op,
                                      PatternRewriter &rewriter) const override {
-    Type t1 = op.operand()->getType();
-    Type t2 = op.result()->getType();
+    Type t1 = op.operand().getType();
+    Type t2 = op.result().getType();
     if (t1 != t2) return matchFailure();
     auto zero = spirv::ConstantOp::getZero(t1, op.getLoc(), &rewriter);
     rewriter.replaceOpWithNewOp<spirv::IAddOp>(op, op.operand(), zero);
@@ -366,7 +365,7 @@
   using OpRewritePattern<spirv::SConvertOp>::OpRewritePattern;
   PatternMatchResult matchAndRewrite(spirv::SConvertOp op,
                                      PatternRewriter &rewriter) const override {
-    Type t = op.result()->getType();
+    Type t = op.result().getType();
     if (!hasIntTypeOfWidth(t, {8, 64})) {
       return matchFailure();
     }
@@ -407,11 +406,11 @@
 struct AdjustIntegerArithmeticOperations : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
   PatternMatchResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
-    Type resultType = op.result()->getType();
+    Type resultType = op.result().getType();
     if (!hasIntTypeOfWidth(resultType, {64})) {
       return Pattern::matchFailure();
     }
-    Type newType = legalizeIntegerType(op.getResult()->getType());
+    Type newType = legalizeIntegerType(op.getResult().getType());
     ValueRange operands(op.getOperation()->getOperands());
     rewriter.replaceOpWithNewOp<OpTy>(op, newType, operands, op.getAttrs());
     return Pattern::matchSuccess();
diff --git a/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp b/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp
index bab753a..d8d2177 100644
--- a/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp
+++ b/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp
@@ -150,8 +150,8 @@
 LogicalResult buildConvExecutable(ModuleOp moduleOp, FuncOp entryFuncOp,
                                   xla_hlo::ConvOp convOp,
                                   iree::SpirVExecutableDefT *outDef) {
-  auto lhs = convOp.lhs()->getType().cast<ShapedType>();
-  auto rhs = convOp.rhs()->getType().cast<ShapedType>();
+  auto lhs = convOp.lhs().getType().cast<ShapedType>();
+  auto rhs = convOp.rhs().getType().cast<ShapedType>();
   if (convOp.feature_group_count() != 1) {
     return entryFuncOp.emitOpError()
            << "only feature group counts of 1 supported";
@@ -261,8 +261,8 @@
 LogicalResult buildMatMulExecutable(ModuleOp moduleOp, FuncOp entryFuncOp,
                                     xla_hlo::DotOp dotOp,
                                     iree::SpirVExecutableDefT *outDef) {
-  auto arg0 = dotOp.getOperand(0)->getType().cast<ShapedType>();
-  auto arg1 = dotOp.getOperand(1)->getType().cast<ShapedType>();
+  auto arg0 = dotOp.getOperand(0).getType().cast<ShapedType>();
+  auto arg1 = dotOp.getOperand(1).getType().cast<ShapedType>();
 
   outDef->tag = "__matmul__";
   outDef->entry_points = {"main"};
diff --git a/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp b/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp
index d57de5a..910f687 100644
--- a/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp
+++ b/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp
@@ -33,7 +33,7 @@
                                           Value value) {
   SmallVector<int64_t, 4> valueShape;
   int64_t valueNumElements = 0;
-  Type valueType = value->getType();
+  Type valueType = value.getType();
   if (auto valueShapedType = valueType.dyn_cast<ShapedType>()) {
     if (!valueShapedType.hasStaticShape()) {
       return emitError(loc, "can only handle tensor of static shape");
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation.cpp b/iree/compiler/Translation/SPIRV/IndexComputation.cpp
index de7a957..4f777df 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation.cpp
+++ b/iree/compiler/Translation/SPIRV/IndexComputation.cpp
@@ -260,8 +260,7 @@
   // clearer picture as to what index types become at the time of SPIR-V
   // lowering since they do not have an equivalent XLA-HLO representation.
   auto extractElementOp = cast<ExtractElementOp>(operation);
-  if (extractElementOp.aggregate()->getType().cast<ShapedType>().getRank() >
-      0) {
+  if (extractElementOp.aggregate().getType().cast<ShapedType>().getRank() > 0) {
     return extractElementOp.emitError(
         "unhandled index propagation for non-zero ranked tensor types");
   }
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation.h b/iree/compiler/Translation/SPIRV/IndexComputation.h
index 5a125a6..44fbd7d 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation.h
+++ b/iree/compiler/Translation/SPIRV/IndexComputation.h
@@ -106,12 +106,12 @@
       SmallVectorImpl<AffineMap> &operandIndices) const override {
     // All operands must have the same type.
     auto argRefType =
-        operation->getOperand(0)->getType().dyn_cast<RankedTensorType>();
+        operation->getOperand(0).getType().dyn_cast<RankedTensorType>();
     if (!argRefType) {
       return operation->emitError("expected operands to be of tensortype");
     }
     for (auto arg : operation->getOperands()) {
-      auto argType = arg->getType().dyn_cast<RankedTensorType>();
+      auto argType = arg.getType().dyn_cast<RankedTensorType>();
       if (!argType || argType.getShape() != argRefType.getShape()) {
         return operation->emitError("expected operands to have same shape");
       }
@@ -144,9 +144,9 @@
       SmallVectorImpl<AffineMap> &operandIndices) const override {
     Builder builder(op->getContext());
     auto resultType =
-        op->getResult(0)->getType().template dyn_cast<ShapedType>();
+        op->getResult(0).getType().template dyn_cast<ShapedType>();
     auto operandType =
-        op->getOperand(0)->getType().template dyn_cast<ShapedType>();
+        op->getOperand(0).getType().template dyn_cast<ShapedType>();
     if (!resultType || !operandType) {
       return op->emitError("expected result and operand to be shaped types");
     }
@@ -189,7 +189,7 @@
   LogicalResult propagateIndexMapImpl(
       Operation *op, DenseSet<unsigned> dimensions, AffineMap resultIndex,
       SmallVectorImpl<AffineMap> &operandIndices) const {
-    auto shaped_type = op->getOperand(0)->getType().cast<ShapedType>();
+    auto shaped_type = op->getOperand(0).getType().cast<ShapedType>();
     Builder builder(op->getContext());
     SmallVector<AffineExpr, 4> dimensionsExprs;
     for (unsigned index = 0; index < shaped_type.getRank(); ++index) {
@@ -232,7 +232,7 @@
       SmallVectorImpl<AffineMap> &operandIndices) const {
     Builder builder(op->getContext());
     SmallVector<AffineExpr, 4> exprs;
-    auto shaped_type = op->getOperand(0)->getType().cast<ShapedType>();
+    auto shaped_type = op->getOperand(0).getType().cast<ShapedType>();
     int rank = shaped_type.getRank();
     for (int i = 0; i < rank; ++i) {
       exprs.push_back(builder.getAffineDimExpr(i) * strides[i] +
diff --git a/iree/compiler/Translation/SPIRV/IndexComputationAttribute.cpp b/iree/compiler/Translation/SPIRV/IndexComputationAttribute.cpp
index a59b9f7..6812791 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputationAttribute.cpp
+++ b/iree/compiler/Translation/SPIRV/IndexComputationAttribute.cpp
@@ -167,32 +167,32 @@
 /// Gets an attribute associated with a block argument.
 template <typename T>
 T getBlockArgumentAttr(BlockArgument blockArg, StringRef attrName) {
-  auto block = blockArg->getOwner();
+  auto block = blockArg.getOwner();
   auto funcOp = dyn_cast<FuncOp>(block->getParentOp());
   if (!funcOp) {
-    emitError(blockArg->getLoc(),
+    emitError(blockArg.getLoc(),
               "unimplemented index computation for block argument when "
               "block is not in a function");
     return nullptr;
   }
-  return funcOp.getArgAttrOfType<T>(blockArg->getArgNumber(), attrName);
+  return funcOp.getArgAttrOfType<T>(blockArg.getArgNumber(), attrName);
 }
 
 /// Updates an attribute associated with a block argument
 template <typename T>
 LogicalResult setBlockArgumentAttr(BlockArgument blockArg, T updatedAttr,
                                    StringRef attrName) {
-  auto block = blockArg->getOwner();
+  auto block = blockArg.getOwner();
   auto funcOp = dyn_cast<FuncOp>(block->getParentOp());
   if (!funcOp) {
-    return emitError(blockArg->getLoc(),
+    return emitError(blockArg.getLoc(),
                      "unimplemented index computation for block argument when "
                      "block is not in a function");
   }
   auto currAttr =
-      funcOp.getArgAttrOfType<ArrayAttr>(blockArg->getArgNumber(), attrName);
+      funcOp.getArgAttrOfType<ArrayAttr>(blockArg.getArgNumber(), attrName);
   if (currAttr != updatedAttr) {
-    funcOp.setArgAttr(blockArg->getArgNumber(), attrName, updatedAttr);
+    funcOp.setArgAttr(blockArg.getArgNumber(), attrName, updatedAttr);
   }
   return success();
 }
@@ -214,7 +214,7 @@
   auto attrName = getIndexComputationAttrName();
   auto currAttr = getBlockArgumentAttr<ArrayAttr>(blockArg, attrName);
   auto updatedAttr = updateIndexComputationAttrWithResultIndex(
-      blockArg->getContext(), currAttr, resultIndexMap, 0);
+      blockArg.getContext(), currAttr, resultIndexMap, 0);
   return setBlockArgumentAttr(blockArg, updatedAttr, attrName);
 }
 
@@ -243,24 +243,23 @@
 /// Records an index map for a tensor value.
 LogicalResult addNewIndexMapForValue(Value value, AffineMap resultIndexMap) {
   // Check if the Value is a block argument or has a defining operation.
-  auto valueKind = value->getKind();
-  if (valueKind == Value::Kind::BlockArgument) {
-    return addBlockArgIndexMap(value->cast<BlockArgument>(), resultIndexMap);
+  if (value.isa<BlockArgument>()) {
+    return addBlockArgIndexMap(value.cast<BlockArgument>(), resultIndexMap);
   }
-  return addOpResultIndexMap(value->getDefiningOp(), resultIndexMap);
+  return addOpResultIndexMap(value.getDefiningOp(), resultIndexMap);
 }
 
 Optional<int64_t> addNewSymbolNumberForTensorIndex(Value value,
                                                    AffineMap index) {
-  if (value->getKind() == Value::Kind::BlockArgument ||
-      !isa<IREE::LoadInputOp>(value->getDefiningOp())) {
-    emitError(value->getLoc(),
+  if (value.isa<BlockArgument>() ||
+      !isa<IREE::LoadInputOp>(value.getDefiningOp())) {
+    emitError(value.getLoc(),
               "only result of a iree.load_input can be associated with "
               "an symbol number");
     return {};
   }
-  auto loadInputOp = cast<IREE::LoadInputOp>(value->getDefiningOp());
-  auto context = value->getContext();
+  auto loadInputOp = cast<IREE::LoadInputOp>(value.getDefiningOp());
+  auto context = value.getContext();
   auto funcOp = loadInputOp.getOperation()->getParentOfType<FuncOp>();
 
   // Find the symbol number to use. It is recorded as an attribute on the
@@ -274,11 +273,11 @@
   unsigned symbolNumber = static_cast<unsigned>(updatedNumSymbolsAttr.getInt());
 
   // Record the mapping from element at tensor index to the symbol.
-  auto srcArg = loadInputOp.src()->cast<BlockArgument>();
+  auto srcArg = loadInputOp.src().cast<BlockArgument>();
   auto attrName = getSymbolNumberAttrName();
   auto currAttr = getBlockArgumentAttr<ArrayAttr>(srcArg, attrName);
   auto updatedAttr = updateTensorIndexToSymbolNumberAttr(
-      value->getContext(), currAttr, index, symbolNumber);
+      value.getContext(), currAttr, index, symbolNumber);
   setBlockArgumentAttr(srcArg, updatedAttr, attrName);
   return symbolNumber;
 }
@@ -301,13 +300,12 @@
 }
 
 void getIndexMapsForValue(Value value, SmallVectorImpl<AffineMap> &indices) {
-  auto valueKind = value->getKind();
   auto attrName = getIndexComputationAttrName();
   ArrayAttr allIndices =
-      (valueKind == Value::Kind::BlockArgument
-           ? getBlockArgumentAttr<ArrayAttr>(value->cast<BlockArgument>(),
+      (value.isa<BlockArgument>()
+           ? getBlockArgumentAttr<ArrayAttr>(value.cast<BlockArgument>(),
                                              attrName)
-           : value->getDefiningOp()->getAttrOfType<ArrayAttr>(attrName));
+           : value.getDefiningOp()->getAttrOfType<ArrayAttr>(attrName));
   if (!allIndices) {
     return;
   }
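
[Illustrative sketch; not a hunk in this change] This file carries the one
non-mechanical rewrite in the change: dispatch on the old Value::Kind enum
becomes a standard isa/cast query on the value. A minimal before/after
sketch, with hypothetical handleBlockArg/handleOpResult helpers:

  // Before: value->getKind() == Value::Kind::BlockArgument
  // After:
  static LogicalResult dispatch(Value value) {
    if (value.isa<BlockArgument>())
      return handleBlockArg(value.cast<BlockArgument>());  // hypothetical
    return handleOpResult(value.getDefiningOp());          // hypothetical
  }
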
diff --git a/iree/compiler/Translation/SPIRV/PrepareReductionDispatch.cpp b/iree/compiler/Translation/SPIRV/PrepareReductionDispatch.cpp
index 3ae8685..b403c43 100644
--- a/iree/compiler/Translation/SPIRV/PrepareReductionDispatch.cpp
+++ b/iree/compiler/Translation/SPIRV/PrepareReductionDispatch.cpp
@@ -63,7 +63,7 @@
                 UnitAttr::get(fn.getContext()));
   auto applyFn =
       fn.getAttrOfType<FlatSymbolRefAttr>("iree.executable.reduction.apply");
-  auto srcType = src->getType().cast<MemRefType>();
+  auto srcType = src.getType().cast<MemRefType>();
   auto loc = fn.getLoc();
   auto loadInputOp = rewriter.create<IREE::LoadInputOp>(
       loc, RankedTensorType::get(srcType.getShape(), srcType.getElementType()),
@@ -81,7 +81,7 @@
     }
   };
 
-  auto shape = src->getType().cast<ShapedType>().getShape();
+  auto shape = src.getType().cast<ShapedType>().getShape();
   std::array<int32_t, 3> workload = {1, 1, 1};
   calculateWorkload(shape, workload);
   SmallVector<APInt, 3> workloadAPInt;
diff --git a/iree/compiler/Translation/SPIRV/ReductionFnLowering.cpp b/iree/compiler/Translation/SPIRV/ReductionFnLowering.cpp
index 0b7c54a..49ee423 100644
--- a/iree/compiler/Translation/SPIRV/ReductionFnLowering.cpp
+++ b/iree/compiler/Translation/SPIRV/ReductionFnLowering.cpp
@@ -164,10 +164,10 @@
   // type.
   Value ptr = operands[0];
   Value value = operands[1];
-  if (!ptr->getType().isa<spirv::PointerType>()) std::swap(ptr, value);
-  if (!ptr->getType().isa<spirv::PointerType>()) return this->matchFailure();
+  if (!ptr.getType().isa<spirv::PointerType>()) std::swap(ptr, value);
+  if (!ptr.getType().isa<spirv::PointerType>()) return this->matchFailure();
   rewriter.replaceOpWithNewOp<ReplacementOpTy>(
-      op, ptr->getType().cast<spirv::PointerType>().getPointeeType(), ptr,
+      op, ptr.getType().cast<spirv::PointerType>().getPointeeType(), ptr,
       spirv::Scope::Device, spirv::MemorySemantics::AcquireRelease, value);
   return this->matchSuccess();
 }
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp b/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp
index 4601be1..613d5c6 100644
--- a/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp
+++ b/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp
@@ -32,7 +32,7 @@
                        TensorIndexToScalarValueMap &valueCache,
                        AffineMap indexMap, Value buffer) {
   auto varPtrType =
-      buffer->getType().cast<spirv::PointerType>().getPointeeType();
+      buffer.getType().cast<spirv::PointerType>().getPointeeType();
   // The variable has to be a struct type with a single element.
   auto varStructType = varPtrType.cast<spirv::StructType>();
   assert(varStructType.getNumElements() == 1 &&
@@ -254,7 +254,7 @@
   // Add values corresponding to the symbol numbers.
   SmallVector<std::pair<AffineMap, unsigned>, 2> symbolInfo;
   index_computation_attribute::getSymbolNumberForTensorIndex(
-      origArg->cast<BlockArgument>(), symbolInfo);
+      origArg.cast<BlockArgument>(), symbolInfo);
   for (auto element : symbolInfo) {
     // Load the value at the index.
     auto val =
@@ -303,7 +303,7 @@
   }
 
   for (auto arg : fn.getArguments()) {
-    if (fn.getArgAttrOfType<UnitAttr>(arg->getArgNumber(),
+    if (fn.getArgAttrOfType<UnitAttr>(arg.getArgNumber(),
                                       "iree.executable.reduction.output")) {
       continue;
     }
@@ -347,7 +347,7 @@
     TensorIndexToScalarValueMap &valueCache) const {
   auto constOp = cast<ConstantOp>(op);
   auto attr = constOp.value().dyn_cast<DenseElementsAttr>();
-  auto resultType = constOp.getResult()->getType();
+  auto resultType = constOp.getResult().getType();
   Type resultElemType;
   if (resultType.isIntOrFloat()) {
     resultElemType = resultType;
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.h b/iree/compiler/Translation/SPIRV/SPIRVLowering.h
index 5933fea..d93b458 100644
--- a/iree/compiler/Translation/SPIRV/SPIRVLowering.h
+++ b/iree/compiler/Translation/SPIRV/SPIRVLowering.h
@@ -127,7 +127,7 @@
            "instruction");
     auto cmpSelectOp = cast<OpTy>(op);
     auto result = cmpSelectOp.getResult();
-    auto resultTy = result->getType().template dyn_cast<ShapedType>();
+    auto resultTy = result.getType().template dyn_cast<ShapedType>();
     if (!resultTy) {
       return op->emitError(
           "unhandled lowering of operations that don't return a "
@@ -144,7 +144,7 @@
                                       ArrayRef<NamedAttribute>());
     }
     auto selectOp = builder.create<spirv::SelectOp>(
-        op->getLoc(), operands[0]->getType(), cmpOp->getResult(0), operands[0],
+        op->getLoc(), operands[0].getType(), cmpOp->getResult(0), operands[0],
         operands[1]);
     valueCache.setValueAtIndex(op->getResult(0), index, selectOp.getResult());
     return success();
@@ -175,7 +175,7 @@
     }
     auto pwOp = cast<OpTy>(op);
     auto result = pwOp.getResult();
-    auto resultType = result->getType().template dyn_cast<ShapedType>();
+    auto resultType = result.getType().template dyn_cast<ShapedType>();
     if (!resultType) {
       return op->emitError(
           "unhandled lowering of operations that don't return a "
diff --git a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
index 0a3f603..92a9fe7 100644
--- a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
+++ b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
@@ -75,7 +75,7 @@
   SmallVector<AffineExpr, 4> exprs;
   for (auto i : llvm::seq<size_t>(
            broadcastDim.getType().getShape()[0],
-           operation->getResult(0)->getType().cast<ShapedType>().getRank())) {
+           operation->getResult(0).getType().cast<ShapedType>().getRank())) {
     exprs.push_back(resultIndex.getResult(i));
   }
 
@@ -105,7 +105,7 @@
   // dimension.
   int offset = 0;
   for (Value operand : op->getOperands()) {
-    auto operandType = operand->getType().cast<RankedTensorType>();
+    auto operandType = operand.getType().cast<RankedTensorType>();
     int rank = operandType.getRank();
     SmallVector<AffineExpr, 4> exprs;
     for (int i = 0; i < rank; ++i) {
@@ -151,11 +151,11 @@
   auto gatherOp = cast<xla_hlo::GatherOp>(op);
   Value startIndices = gatherOp.start_indices();
   int64_t startIndicesRank =
-      startIndices->getType().cast<ShapedType>().getRank();
+      startIndices.getType().cast<ShapedType>().getRank();
   Value operand = gatherOp.operand();
-  int64_t operandRank = operand->getType().cast<ShapedType>().getRank();
+  int64_t operandRank = operand.getType().cast<ShapedType>().getRank();
   Value result = gatherOp.getResult();
-  int64_t resultRank = result->getType().cast<ShapedType>().getRank();
+  int64_t resultRank = result.getType().cast<ShapedType>().getRank();
   ArrayRef<AffineExpr> resultExprs = resultIndex.getResults();
   int64_t indexVectorDim =
       gatherOp.dimension_numbers().index_vector_dim().getValue().getSExtValue();
@@ -266,7 +266,7 @@
 
   // Index for the tensor operand.
   SmallVector<AffineExpr, 4> exprs(
-      padOp.operand()->getType().cast<RankedTensorType>().getRank());
+      padOp.operand().getType().cast<RankedTensorType>().getRank());
   for (auto resultExpr : enumerate(resultIndex.getResults())) {
     auto i = resultExpr.index();
     int64_t padding_low = edge_padding_low.getValue<IntegerAttr>(i).getInt();
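
Note: in the pad hunk above, exprs is sized to the operand's rank rather than
the result's, since the loop maps each result-index expression back to the
corresponding operand index after peeling off edge_padding_low per dimension.
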
diff --git a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h
index c11f910..f5c2b59 100644
--- a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h
+++ b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h
@@ -65,7 +65,7 @@
       return operation->emitError("unhandled multiple return values");
     }
     auto returnValue = operation->getOperand(0);
-    auto returnType = returnValue->getType().cast<RankedTensorType>();
+    auto returnType = returnValue.getType().cast<RankedTensorType>();
     auto returnRank = returnType.getRank();
     if (returnRank > 3) {
       return operation->emitError("unhandled return tensor of dimension ")
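
Note: the arrows that remain in this hunk are deliberate: `operation` is an
Operation *, a plain pointer, so operation->getOperand(0) keeps its arrow while
the returned Value switches to dot syntax. A sketch showing both side by side
(hypothetical helper name):

    int64_t firstOperandRank(Operation *operation) {  // pointer: keeps ->
      Value returnValue = operation->getOperand(0);   // handle: uses .
      return returnValue.getType().cast<RankedTensorType>().getRank();
    }
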
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp b/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
index f6c11ed..2593438 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV.cpp
@@ -42,7 +42,7 @@
       builder.saveInsertionPoint(), loc, index.getResult(append_dim));
 
   int offset = op->getOperand(0)
-                   ->getType()
+                   .getType()
                    .cast<RankedTensorType>()
                    .getShape()[append_dim];
   Value resultVal = operands[0];
@@ -62,7 +62,7 @@
     resultVal = builder.create<spirv::SelectOp>(
         loc, cond, operands[operandIt.index()], resultVal);
     auto operandShape =
-        operandIt.value()->getType().cast<RankedTensorType>().getShape();
+        operandIt.value().getType().cast<RankedTensorType>().getShape();
     offset += operandShape[append_dim];
   }
   valueCache.setValueAtIndex(op->getResult(0), index, resultVal);
@@ -79,9 +79,9 @@
   auto convertOp = cast<xla_hlo::ConvertOp>(op);
   auto loc = convertOp.getLoc();
   auto resultElemType =
-      convertOp.getResult()->getType().dyn_cast<ShapedType>().getElementType();
+      convertOp.getResult().getType().dyn_cast<ShapedType>().getElementType();
   auto operandElemType =
-      convertOp.getOperand()->getType().dyn_cast<ShapedType>().getElementType();
+      convertOp.getOperand().getType().dyn_cast<ShapedType>().getElementType();
 
   if (resultElemType == operandElemType) {
     valueCache.setValueAtIndex(op->getResult(0), index, operands[0]);
@@ -157,7 +157,7 @@
   auto i1Type = builder.getI1Type();
   auto zero = spirv::ConstantOp::getZero(i32Type, loc, &builder);
   Value cond = spirv::ConstantOp::getOne(i1Type, loc, &builder);
-  auto operandType = padOp.operand()->getType().cast<RankedTensorType>();
+  auto operandType = padOp.operand().getType().cast<RankedTensorType>();
   if (!operandType.hasStaticShape()) {
     return padOp.emitError("pad op codegen supported only for static shapes");
   }
@@ -201,10 +201,10 @@
 
     if (paddingStride != 1) {
       // ((d_i - edge_padding_low[i]) % (interior_padding[i]+1) == 0)
-      auto t1 = builder.create<spirv::ISubOp>(loc, dimIndex->getType(),
-                                              dimIndex, edgePadding);
-      auto t2 = builder.create<spirv::SModOp>(loc, t1.getResult()->getType(),
-                                              t1, stride);
+      auto t1 = builder.create<spirv::ISubOp>(loc, dimIndex.getType(), dimIndex,
+                                              edgePadding);
+      auto t2 = builder.create<spirv::SModOp>(loc, t1.getResult().getType(), t1,
+                                              stride);
       auto checkStride = builder.create<spirv::IEqualOp>(loc, i1Type, t2, zero);
       cond =
           builder.create<spirv::LogicalAndOp>(loc, i1Type, cond, checkStride);
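
Note: the t1/t2 hunk above is larger than a one-character substitution,
apparently only because dropping the arrows shortened the lines past the wrap
point and let the formatter re-wrap the trailing arguments; the emitted
ISub/SMod/IEqual sequence is unchanged.
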
diff --git a/iree/compiler/Utils/DispatchUtils.cpp b/iree/compiler/Utils/DispatchUtils.cpp
index 04318d1..40daf19 100644
--- a/iree/compiler/Utils/DispatchUtils.cpp
+++ b/iree/compiler/Utils/DispatchUtils.cpp
@@ -66,7 +66,7 @@
   std::array<int32_t, 3> workload = {1, 1, 1};
 
   // TODO(b/139353314): lookup/calculate based on type/etc.
-  auto resultType = baseOperand->getType();
+  auto resultType = baseOperand.getType();
   if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
     if (!shapedType.hasStaticShape()) {
       op->emitOpError() << "Dynamic shapes not yet supported";
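
Note: despite its name, resultType above is the type of baseOperand; the
dyn_cast<ShapedType> plus hasStaticShape() check is what routes dynamically
shaped operands to the "Dynamic shapes not yet supported" error before any
workload is derived.
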
diff --git a/iree/compiler/Utils/GraphUtils.cpp b/iree/compiler/Utils/GraphUtils.cpp
index 09a0dbd..2c49912 100644
--- a/iree/compiler/Utils/GraphUtils.cpp
+++ b/iree/compiler/Utils/GraphUtils.cpp
@@ -30,7 +30,7 @@
   VisitFn visit = [&](Operation *op) {
     if (markedOps.count(op) > 0) return;
     for (auto result : op->getResults()) {
-      for (auto *user : result->getUsers()) {
+      for (auto *user : result.getUsers()) {
         // Don't visit ops not in our set.
         if (unsortedOps.count(user) == 0) continue;
         visit(user);
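
Note: `user` keeps its `auto *` above because Value::getUsers() still yields
Operation * even though the result itself is now a copyable handle. A sketch of
the two kinds side by side, reusing the op/visit names from this hunk:

    for (Value result : op->getResults())        // cheap handle copies
      for (Operation *user : result.getUsers())  // raw pointers, as before
        visit(user);
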
diff --git a/iree/compiler/Utils/TypeConversionUtils.cpp b/iree/compiler/Utils/TypeConversionUtils.cpp
index 7d12545..6be7eb2 100644
--- a/iree/compiler/Utils/TypeConversionUtils.cpp
+++ b/iree/compiler/Utils/TypeConversionUtils.cpp
@@ -61,7 +61,7 @@
 }
 
 MemRefType convertTypeToMemRef(Value value) {
-  return convertTypeToMemRef(value->getType());
+  return convertTypeToMemRef(value.getType());
 }
 
 }  // namespace iree_compiler
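
Note: the overload above already takes its Value parameter by value rather
than by pointer or reference; since a Value is a trivially copyable handle,
pass-by-value is the idiomatic signature for small helpers like this one.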