[NFC] Simplify type checks with isa predicates (#16935)
For more context on isa predicates, see:
https://github.com/llvm/llvm-project/pull/83753.
Also clean up some surrounding casts/isas.
diff --git a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp
index 992a8e7..54bf385 100644
--- a/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
@@ -119,12 +120,10 @@
// (any sign-op, or an integral abs-op).
// TODO(peiming, ajcbik): these all can potentially be optimized by applying
// value transform on sparse_tenosr.value memref
- if (isa<mlir::stablehlo::SignOp>(op) || isa<mlir::stablehlo::NegOp>(op) ||
+ if (isa<mlir::stablehlo::SignOp, mlir::stablehlo::NegOp>(op) ||
(isa<mlir::stablehlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
- isa<chlo::AsinOp>(op) || isa<chlo::AsinhOp>(op) ||
- isa<chlo::AtanOp>(op) || isa<chlo::AtanhOp>(op) ||
- isa<chlo::BesselI1eOp>(op) || isa<chlo::SinhOp>(op) ||
- isa<chlo::TanOp>(op)) {
+ isa<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp, chlo::AtanhOp,
+ chlo::BesselI1eOp, chlo::SinhOp, chlo::TanOp>(op)) {
if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
!sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
return Value();
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index ed8af82..1977ceb 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -162,8 +162,7 @@
// TODO(benvanik): always pass in a signal fence? could be useful if we
// want to allow for async work using fences that's not device-related.
const bool haveTensorResults =
- llvm::any_of(oldImportType.getResults(),
- [](Type type) { return llvm::isa<TensorType>(type); });
+ llvm::any_of(oldImportType.getResults(), llvm::IsaPred<TensorType>);
if (!haveTensorResults && !hasSideEffects) {
// No tensors returned from import - pass in an immediate signal.
signalFence = entryBuilder.create<IREE::Util::NullOp>(
diff --git a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
index 579d564..54094f8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BufferizationAnalysis.cpp
@@ -365,12 +365,9 @@
/// - all `tensor.extract_slice` operations dominate the `tensor.insert_slice`
/// op.
static void hasDestructiveUpdatePattern(Value source, BufferizationPlan &plan) {
- auto isUpdateOp = [](Operation *op) {
- return isa<tensor::InsertSliceOp, vector::TransferWriteOp>(op);
- };
- auto isReadOp = [](Operation *op) {
- return isa<tensor::ExtractSliceOp, vector::TransferReadOp>(op);
- };
+ auto isUpdateOp =
+ llvm::IsaPred<tensor::InsertSliceOp, vector::TransferWriteOp>;
+ auto isReadOp = llvm::IsaPred<tensor::ExtractSliceOp, vector::TransferReadOp>;
auto getDest = [](Operation *op) -> Value {
if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
return insertSliceOp.getDest();
@@ -583,13 +580,11 @@
bufferization::AllocTensorOp>(
[&](Operation *op) { return success(); })
.Default([&](Operation *op) -> LogicalResult {
- if (llvm::any_of(op->getOperands(),
- [](Value v) {
- return llvm::isa<RankedTensorType>(v.getType());
- }) ||
- llvm::any_of(op->getResultTypes(), [](Type t) {
- return llvm::isa<RankedTensorType>(t);
- })) {
+ if (llvm::any_of(
+ op->getOperands(),
+ [](Value v) { return isa<RankedTensorType>(v.getType()); }) ||
+ llvm::any_of(op->getResultTypes(),
+ llvm::IsaPred<RankedTensorType>)) {
return op->emitOpError("unhandled tensor operation");
}
return success();
@@ -609,7 +604,7 @@
return;
}
if (auto vectorWriteOp = dyn_cast<vector::TransferWriteOp>(updateOp)) {
- if (llvm::isa<RankedTensorType>(vectorWriteOp.getSource().getType())) {
+ if (isa<RankedTensorType>(vectorWriteOp.getSource().getType())) {
hasDestructiveUpdatePattern(vectorWriteOp.getSource(), plan);
}
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 6e793e8..36ebc2b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -212,14 +212,11 @@
llvm::append_range(values, op->getResults());
// First check if any of them are vector values.
- if (llvm::none_of(values, [](Value value) -> bool {
- return isa<VectorValue>(value);
- })) {
+ if (llvm::none_of(values, llvm::IsaPred<VectorValue>))
return false;
- }
// Check if all operands and results of this operation have a layout.
- return llvm::all_of(values, [&](Value value) -> bool {
+ return llvm::all_of(values, [&analysis](Value value) {
auto vectorValue = dyn_cast<VectorValue>(value);
return !vectorValue || analysis.getLayout<Attribute>(vectorValue);
});
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
index 705910f..b4e89eb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp
@@ -177,20 +177,16 @@
// instead of just being a zerofill.
ForwardSliceOptions forwardOptions;
forwardOptions.filter = [&](Operation *op) -> bool {
- return llvm::any_of(op->getResultTypes(),
- [](Type t) { return isa<VectorType>(t); });
+ return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
};
BackwardSliceOptions backwardOptions;
backwardOptions.filter = [&](Operation *op) -> bool {
- return llvm::any_of(op->getOperandTypes(),
- [](Type t) { return isa<VectorType>(t); });
+ return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
};
SetVector<Operation *> slice =
getSlice(transfer, backwardOptions, forwardOptions);
- if (llvm::any_of(slice, [](Operation *op) {
- return llvm::isa<vector::ContractionOp>(op);
- })) {
+ if (llvm::any_of(slice, llvm::IsaPred<vector::ContractionOp>)) {
return success();
}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 0722b12..707fa6a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -65,9 +65,9 @@
if (!forall.getMapping().has_value()) {
return false;
}
- return llvm::any_of(forall.getMapping().value(), [](Attribute attr) {
- return isa<gpu::GPUThreadMappingAttr>(attr);
- });
+ return llvm::any_of(*forall.getMapping(),
+ llvm::IsaPred<gpu::GPUThreadMappingAttr>);
+ ;
}
// All pipelines that use this allocation function distribute scf.forall ops
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index f63180e..4c58624 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -17,6 +17,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
@@ -452,16 +453,14 @@
using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<memref::LoadOp>(op); });
+ OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<memref::LoadOp>);
if (!operand)
return failure();
auto load = operand->get().getDefiningOp<memref::LoadOp>();
unsigned operandIndex = operand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
- SmallVector<Value> indices(load.getIndices().begin(),
- load.getIndices().end());
+ auto indices = llvm::to_vector_of<Value>(load.getIndices());
if (!indices.empty())
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
index 22741fa..0c8f4c2 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
@@ -271,26 +272,11 @@
/// Returns the first of `exprs` which is of the type `T`.
template <typename T>
static AffineExpr getAffineExprOfType(ArrayRef<AffineExpr> exprs) {
- for (auto expr : exprs) {
- if (isa<T>(expr))
- return expr;
- }
+ if (auto it = llvm::find_if(exprs, llvm::IsaPred<T>); it != exprs.end())
+ return *it;
return nullptr;
}
-/// Returns true if the `expr` is on of the types in {`T1`, `T2`, `T3...`}.
-template <typename T>
-static bool isaAffineExprOfType(AffineExpr expr) {
- return isa<T>(expr);
-}
-template <typename T1, typename T2, typename... T3>
-static bool isaAffineExprOfType(AffineExpr expr) {
- if (isa<T1>(expr)) {
- return true;
- }
- return isaAffineExprOfType<T2, T3...>(expr);
-}
-
/// Returns a Value that represents the value for symbol or dim expr for the map
/// in the `applyOp`.
static Value getValueForDimOrSymbol(affine::AffineApplyOp applyOp,
@@ -387,7 +373,7 @@
// The other expression must be the undistributed `lb`.
AffineExpr lbExpr =
(offsetExpr == expr.getLHS() ? expr.getRHS() : expr.getLHS());
- if (isaAffineExprOfType<AffineDimExpr, AffineSymbolExpr>(lbExpr)) {
+ if (isa<AffineDimExpr, AffineSymbolExpr>(lbExpr)) {
Value v = getValueForDimOrSymbol(applyOp, lbExpr);
if (!v) {
return failure();
@@ -541,7 +527,7 @@
private:
LogicalResult processSentinel(AffineExpr e,
SmallVectorImpl<AffineExpr> &sentinels) {
- if (isaAffineExprOfType<AffineDimExpr, AffineSymbolExpr>(e)) {
+ if (isa<AffineDimExpr, AffineSymbolExpr>(e)) {
sentinels.push_back(e);
return success();
} else if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
@@ -1172,9 +1158,7 @@
backwardSlice.set_union(tmpBackwardSlice);
}
- return llvm::any_of(backwardSlice, [](Operation *op) {
- return llvm::isa<linalg::LinalgOp>(op);
- });
+ return llvm::any_of(backwardSlice, llvm::IsaPred<linalg::LinalgOp>);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
index 4ea0993..bb150a8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.cpp
@@ -51,9 +51,9 @@
Value value) {
// Check if `v` is defined outside of `regionOp`.
auto isOutside = [&](Value v) {
- if (llvm::isa<OpResult>(v))
+ if (isa<OpResult>(v))
return !regionOp->isAncestor(v.getDefiningOp());
- assert(v.isa<BlockArgument>() && "expected bbArg");
+ assert(isa<BlockArgument>(v) && "expected bbArg");
// DispatchRegionOp does not have block arguments.
return true;
};
@@ -167,9 +167,9 @@
rewriter.inlineRegionBefore(regionOp.getWorkgroupCount(),
workgroupsOp.getWorkgroupCount(),
workgroupsOp.getWorkgroupCount().begin());
- mlir::makeRegionIsolatedFromAbove(
- rewriter, workgroupsOp.getWorkgroupCount(),
- [](Operation *op) { return isa<arith::ConstantOp>(op); });
+ mlir::makeRegionIsolatedFromAbove(rewriter,
+ workgroupsOp.getWorkgroupCount(),
+ llvm::IsaPred<arith::ConstantOp>);
}
IRMapping bvm;
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index cc7933c..d6f40fb 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -276,9 +276,8 @@
// TODO(ravishankarm): Maybe make `set_encoding` have pad semantics that can be
// explicitly broken down if needed.
static bool isPadUsedInSetEncoding(tensor::PadOp padOp) {
- return llvm::any_of(padOp->getUsers(), [](Operation *user) {
- return isa<IREE::LinalgExt::SetEncodingOp>(user);
- });
+ return llvm::any_of(padOp->getUsers(),
+ llvm::IsaPred<IREE::LinalgExt::SetEncodingOp>);
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp
index c80bc84..38864c8 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/InitializeEmptyTensors.cpp
@@ -32,10 +32,8 @@
/// `flow.tensor.*` op.
static bool shouldBeConvertedToFlowTensorOp(tensor::EmptyOp emptyTensorOp) {
return !(llvm::all_of(emptyTensorOp->getUsers(),
- [](Operation *user) -> bool {
- return isa<linalg::LinalgOp, LinalgExt::LinalgExtOp,
- tensor::PackOp, tensor::UnPackOp>(user);
- }) ||
+ llvm::IsaPred<linalg::LinalgOp, LinalgExt::LinalgExtOp,
+ tensor::PackOp, tensor::UnPackOp>) ||
emptyTensorOp->getParentOfType<Flow::DispatchWorkgroupsOp>());
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 08f5408..b08cd74 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -23,8 +23,10 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Transforms/TopologicalSortUtils.h"
@@ -172,8 +174,7 @@
// needs to have a single region with a single block. This seems
// unnecessary for IREEs use case. For now avoid this assert by bailing if
// any operands are block arguments.
- if (llvm::any_of(op->getOperands(),
- [](Value v) { return llvm::isa<BlockArgument>(v); })) {
+ if (llvm::any_of(op->getOperands(), llvm::IsaPred<BlockArgument>)) {
auto parentOp = op->getParentOp();
if (parentOp->getNumRegions() != 1 ||
parentOp->getRegion(0).getBlocks().size() != 1) {
@@ -186,8 +187,7 @@
for (Value initOperand : linalgOp.getDpsInits()) {
mlir::getBackwardSlice(initOperand, &slice, options);
}
- return llvm::any_of(
- slice, [](Operation *op) { return isa<tensor::ExtractOp>(op); });
+ return llvm::any_of(slice, llvm::IsaPred<tensor::ExtractOp>);
}
return false;
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 23c8399..1b07667 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -3011,16 +3011,12 @@
// static
SmallVector<unsigned>
CmdDispatchOp::makeResourceToArgMap(mlir::FunctionOpInterface funcOp) {
- unsigned operandCount =
- llvm::count_if(funcOp.getArgumentTypes(), [](Type type) {
- return llvm::isa<IREE::Stream::BindingType>(type);
- });
+ unsigned operandCount = llvm::count_if(
+ funcOp.getArgumentTypes(), llvm::IsaPred<IREE::Stream::BindingType>);
SmallVector<unsigned> map(operandCount);
- unsigned operandIdx = 0;
- for (auto it : llvm::enumerate(funcOp.getArgumentTypes())) {
- unsigned argIdx = it.index();
- auto argType = it.value();
- if (llvm::isa<IREE::Stream::BindingType>(argType)) {
+ size_t operandIdx = 0;
+ for (auto [argIdx, argType] : llvm::enumerate(funcOp.getArgumentTypes())) {
+ if (isa<IREE::Stream::BindingType>(argType)) {
map[operandIdx++] = argIdx;
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
index 7350487..46c09fa 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ConvertToStream.cpp
@@ -116,18 +116,17 @@
// the stream ops that capture encodings and shapes.
static bool doesOperationNeedWrapping(Operation *op) {
return llvm::any_of(op->getOperands(),
- [&](Value operand) {
+ [](Value operand) {
if (!llvm::isa<TensorType>(operand.getType()))
return false;
return !isa_and_nonnull<TensorExportOp>(
operand.getDefiningOp());
}) ||
- llvm::any_of(op->getResults(), [&](Value result) {
- if (!llvm::isa<TensorType>(result.getType()))
+ llvm::any_of(op->getResults(), [](Value result) {
+ if (!isa<TensorType>(result.getType()))
return false;
- return !llvm::all_of(result.getUsers(), [&](Operation *user) {
- return isa<TensorImportOp>(user);
- });
+ return !llvm::all_of(result.getUsers(),
+ llvm::IsaPred<TensorImportOp>);
});
}
@@ -148,12 +147,11 @@
rewriter.setInsertionPoint(op);
for (auto [oldOperand, newOperand] :
llvm::zip_equal(op->getOperands(), operands)) {
- if (!llvm::isa<IREE::Stream::ResourceType>(newOperand.getType()) &&
- !llvm::isa<TensorType>(newOperand.getType())) {
+ if (!isa<IREE::Stream::ResourceType, TensorType>(newOperand.getType())) {
newOperands.push_back(newOperand);
continue;
}
- auto tensorType = llvm::dyn_cast<TensorType>(oldOperand.getType());
+ auto tensorType = dyn_cast<TensorType>(oldOperand.getType());
assert(tensorType && "must have a tensor type to map to a resource");
auto dynamicDims = IREE::Util::buildDynamicDimsForValue(
@@ -166,7 +164,7 @@
// Import into resources from tensor results produced by the op.
rewriter.setInsertionPointAfter(op);
for (auto result : op->getResults()) {
- auto tensorType = llvm::dyn_cast<TensorType>(result.getType());
+ auto tensorType = dyn_cast<TensorType>(result.getType());
if (!tensorType)
continue;
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
index cef8edb..f5ebb30 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
@@ -83,9 +83,8 @@
int getNumRefArguments() {
assert(originalFunctionType.has_value());
- return llvm::count_if(
- originalFunctionType.value().getInputs(),
- [](Type inputType) { return isa<IREE::VM::RefType>(inputType); });
+ return llvm::count_if(originalFunctionType.value().getInputs(),
+ llvm::IsaPred<IREE::VM::RefType>);
}
int getNumLocalRefs() { return getNumRefRegisters() - getNumRefArguments(); }
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 01d5dac..d6b431f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -109,11 +109,9 @@
bool isGlobalStoreOp(Operation *op) const {
// TODO(benvanik): trait/interface to make this more generic?
- return isa<IREE::VM::GlobalStoreI32Op>(op) ||
- isa<IREE::VM::GlobalStoreI64Op>(op) ||
- isa<IREE::VM::GlobalStoreF32Op>(op) ||
- isa<IREE::VM::GlobalStoreF64Op>(op) ||
- isa<IREE::VM::GlobalStoreRefOp>(op);
+ return isa<IREE::VM::GlobalStoreI32Op, IREE::VM::GlobalStoreI64Op,
+ IREE::VM::GlobalStoreF32Op, IREE::VM::GlobalStoreF64Op,
+ IREE::VM::GlobalStoreRefOp>(op);
}
};