Integrate LLVM at llvm/llvm-project@bcad20bc6591
Updates LLVM usage to match
[bcad20bc6591](https://github.com/llvm/llvm-project/commit/bcad20bc6591)
PiperOrigin-RevId: 407237921
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 3f13478..67ee5f4 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,7 +4,7 @@
aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
acd6f6f014c25e46363e718381e0b35205df2d83 third_party/libyaml
-97a1570d8c31dc3bff12dd77b1ee824e1872bb69 third_party/llvm-project
+bcad20bc6591c8b503923402038c735a77373f99 third_party/llvm-project
47c417090dc5192e1da5206cc99bd6f145a2bc32 third_party/mlir-hlo
3f701faace7addc75d16dea8a6cd769fa5b3f260 third_party/musl
4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
diff --git a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index ff2375a..7f86260 100644
--- a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -132,7 +132,7 @@
using OpConversionPattern<AllocOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- AllocOpTy allocOp, ArrayRef<Value> operands,
+ AllocOpTy allocOp, typename AllocOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldType = allocOp.getType().template dyn_cast<MemRefType>();
if (!oldType || !oldType.getLayout().isIdentity()) return failure();
@@ -163,7 +163,7 @@
}
LogicalResult matchAndRewrite(
- memref::GlobalOp globalOp, ArrayRef<Value> operands,
+ memref::GlobalOp globalOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldType = globalOp.type().dyn_cast<MemRefType>();
if (!oldType || !oldType.getLayout().isIdentity()) return failure();
@@ -189,7 +189,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- memref::GetGlobalOp getOp, ArrayRef<Value> operands,
+ memref::GetGlobalOp getOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldType = getOp.getType().dyn_cast<MemRefType>();
if (!oldType || !oldType.getLayout().isIdentity()) return failure();
@@ -213,7 +213,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceBindingSubspanOp subspanOp, ArrayRef<Value> operands,
+ IREE::HAL::InterfaceBindingSubspanOp subspanOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldType = subspanOp.getType().dyn_cast<MemRefType>();
// IREE subspan ops only use memref types with the default identity
@@ -419,11 +419,11 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- UnrealizedConversionCastOp castOp, ArrayRef<Value> operands,
+ UnrealizedConversionCastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (castOp->getNumOperands() != 1) return failure();
- Value input = operands.front();
+ Value input = adaptor.getOperands().front();
// We only want to handle cases where the cast op handles memref types.
if (!input.getType().isa<BaseMemRefType>()) return failure();
diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
index 0d2f57d..dd54ba9 100644
--- a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -343,7 +343,7 @@
using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- InterfaceOpTy op, ArrayRef<Value> operands,
+ InterfaceOpTy op, typename InterfaceOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type i32Type = rewriter.getI32Type();
diff --git a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
index 4f9253a..e5f5546 100644
--- a/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
@@ -161,7 +161,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceLoadConstantOp loadOp, ArrayRef<Value> operands,
+ IREE::HAL::InterfaceLoadConstantOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(#1519): hal.interface.load.constant should point to the
// hal.interface op.
@@ -194,7 +194,7 @@
using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- InterfaceOpTy op, ArrayRef<Value> operands,
+ InterfaceOpTy op, typename InterfaceOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int32_t index = static_cast<int32_t>(op.dimension().getSExtValue());
auto i32Type = rewriter.getIntegerType(32);
@@ -219,8 +219,7 @@
interfaceToResourceVars(interfaceToResourceVars) {}
LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceBindingSubspanOp interfaceOp,
- ArrayRef<Value> operands,
+ IREE::HAL::InterfaceBindingSubspanOp interfaceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (interfaceOp.use_empty()) {
rewriter.eraseOp(interfaceOp);
@@ -252,9 +251,9 @@
struct FoldAsNoOp final : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- OpTy op, ArrayRef<Value> operands,
+ OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOp(op, operands);
+ rewriter.replaceOp(op, adaptor.getOperands());
return success();
}
};
@@ -265,11 +264,12 @@
: public OpConversionPattern<UnrealizedConversionCastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- UnrealizedConversionCastOp op, ArrayRef<Value> operands,
+ UnrealizedConversionCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op->getNumOperands() == 1 && op->getNumResults() == 1 &&
- operands.front().getType() == op->getResultTypes().front()) {
- rewriter.replaceOp(op, operands);
+ adaptor.getOperands().front().getType() ==
+ op->getResultTypes().front()) {
+ rewriter.replaceOp(op, adaptor.getOperands());
return success();
}
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
index 470170e..170f1bc 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorToCooperativeOps.cpp
@@ -173,10 +173,10 @@
using OpConversionPattern<SrcOpType>::OpConversionPattern;
LogicalResult matchAndRewrite(
- SrcOpType op, ArrayRef<Value> operands,
+ SrcOpType op, typename SrcOpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
- for (Value operand : operands) {
+ for (Value operand : adaptor.getOperands()) {
if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
return failure();
}
@@ -185,7 +185,7 @@
if (op->getNumResults() != 1) return failure();
auto matType = this->typeConverter->convertType(op.getType());
- rewriter.replaceOpWithNewOp<DstOpType>(op, matType, operands);
+ rewriter.replaceOpWithNewOp<DstOpType>(op, matType, adaptor.getOperands());
return success();
}
};
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
index 73bbe6a..4367451 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorizeLoadStore.cpp
@@ -178,7 +178,7 @@
public:
using MemRefConversionPattern<FuncOp>::MemRefConversionPattern;
LogicalResult matchAndRewrite(
- FuncOp funcOp, ArrayRef<Value> operands,
+ FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -188,7 +188,7 @@
using MemRefConversionPattern<
vector::TransferReadOp>::MemRefConversionPattern;
LogicalResult matchAndRewrite(
- vector::TransferReadOp read, ArrayRef<Value> operands,
+ vector::TransferReadOp read, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!memrefUsageAnalysis.transferConvert(read)) {
return rewriter.notifyMatchFailure(
@@ -196,8 +196,6 @@
}
Location loc = read.getLoc();
- vector::TransferReadOp::Adaptor adaptor(operands,
- read->getAttrDictionary());
auto scalarMemrefType = read.source().getType().dyn_cast<MemRefType>();
auto vectorMemrefType = adaptor.source().getType().dyn_cast<MemRefType>();
@@ -246,7 +244,7 @@
using MemRefConversionPattern<
vector::TransferWriteOp>::MemRefConversionPattern;
LogicalResult matchAndRewrite(
- vector::TransferWriteOp write, ArrayRef<Value> operands,
+ vector::TransferWriteOp write, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!memrefUsageAnalysis.transferConvert(write)) {
return rewriter.notifyMatchFailure(
@@ -254,8 +252,6 @@
}
Location loc = write.getLoc();
- vector::TransferWriteOp::Adaptor adaptor(operands,
- write->getAttrDictionary());
auto scalarMemrefType = write.source().getType().dyn_cast<MemRefType>();
auto vectorMemrefType = adaptor.source().getType().dyn_cast<MemRefType>();
@@ -333,7 +329,7 @@
public:
using MemRefConversionPattern<memref::AllocOp>::MemRefConversionPattern;
LogicalResult matchAndRewrite(
- memref::AllocOp alloc, ArrayRef<Value> operands,
+ memref::AllocOp alloc, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto memrefType = getVectorizedMemRefType(rewriter, alloc.getResult());
if (!memrefType) return failure();
@@ -350,7 +346,7 @@
IREE::HAL::InterfaceBindingSubspanOp>::MemRefConversionPattern;
LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceBindingSubspanOp bindingOp, ArrayRef<Value> operands,
+ IREE::HAL::InterfaceBindingSubspanOp bindingOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto memrefType = bindingOp.getType().dyn_cast<MemRefType>();
if (!memrefType) return failure();
@@ -460,7 +456,7 @@
} // namespace
LogicalResult ProcessFuncArg::matchAndRewrite(
- FuncOp funcOp, ArrayRef<Value> operands,
+ FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
TypeConverter::SignatureConversion signatureConverter(
funcOp.getType().getNumInputs());
diff --git a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
index 68c60b6..3aa6d34 100644
--- a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp
@@ -58,7 +58,7 @@
// static
LogicalResult HALConversionTarget::applyDefaultBufferRewrite(
- Operation *srcOp, ArrayRef<Value> operands, StringRef dstOpName,
+ Operation *srcOp, ValueRange operands, StringRef dstOpName,
TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
OperationState state{srcOp->getLoc(), dstOpName};
state.addAttributes(srcOp->getAttrs());
diff --git a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
index 8f5c576..88869ad 100644
--- a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
+++ b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h
@@ -26,7 +26,7 @@
// Attempts to rewrite an op that may use tensor values into an op using HAL
// buffers. See HALOpConversion for more information.
static LogicalResult applyDefaultBufferRewrite(
- Operation *srcOp, ArrayRef<Value> operands, StringRef dstOpName,
+ Operation *srcOp, ValueRange operands, StringRef dstOpName,
TypeConverter &typeConverter, ConversionPatternRewriter &rewriter);
};
@@ -53,10 +53,11 @@
: OpConversionPattern<SRC>(context), typeConverter(typeConverter) {}
LogicalResult matchAndRewrite(
- SRC srcOp, ArrayRef<Value> operands,
+ SRC srcOp, typename SRC::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return HALConversionTarget::applyDefaultBufferRewrite(
- srcOp, operands, DST::getOperationName(), typeConverter, rewriter);
+ srcOp, adaptor.getOperands(), DST::getOperationName(), typeConverter,
+ rewriter);
}
protected:
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertGlobalOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertGlobalOps.cpp
index 748821a..26eec19 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertGlobalOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertGlobalOps.cpp
@@ -45,7 +45,7 @@
: OpConversionPattern(ctx), converter(converter) {}
LogicalResult matchAndRewrite(
- IREE::Util::GlobalOp globalOp, llvm::ArrayRef<Value> newOperands,
+ IREE::Util::GlobalOp globalOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): multiple converted type results to multiple globals.
Optional<Attribute> initialValue = globalOp.initial_value();
@@ -83,7 +83,7 @@
: OpConversionPattern(ctx), converter(converter) {}
LogicalResult matchAndRewrite(
- IREE::Util::GlobalAddressOp addressOp, llvm::ArrayRef<Value> newOperands,
+ IREE::Util::GlobalAddressOp addressOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): multiple converted type results to multiple globals.
rewriter.replaceOpWithNewOp<IREE::Util::GlobalAddressOp>(
@@ -103,7 +103,7 @@
: OpConversionPattern(ctx), converter(converter) {}
LogicalResult matchAndRewrite(
- IREE::Util::GlobalLoadOp loadOp, llvm::ArrayRef<Value> newOperands,
+ IREE::Util::GlobalLoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): multiple converted type results to multiple globals.
rewriter.replaceOpWithNewOp<IREE::Util::GlobalLoadOp>(
@@ -123,14 +123,12 @@
: OpConversionPattern(ctx), converter(converter) {}
LogicalResult matchAndRewrite(
- IREE::Util::GlobalLoadIndirectOp loadOp,
- llvm::ArrayRef<Value> newOperands,
+ IREE::Util::GlobalLoadIndirectOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::Util::GlobalLoadIndirectOp::Adaptor operands(newOperands);
// TODO(benvanik): multiple converted type results to multiple globals.
rewriter.replaceOpWithNewOp<IREE::Util::GlobalLoadIndirectOp>(
loadOp, converter.convertType(loadOp.result().getType()),
- operands.global());
+ adaptor.global());
return success();
}
@@ -166,9 +164,8 @@
: OpConversionPattern(ctx), converter(converter) {}
LogicalResult matchAndRewrite(
- IREE::Util::GlobalStoreOp storeOp, llvm::ArrayRef<Value> newOperands,
+ IREE::Util::GlobalStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::Util::GlobalStoreOp::Adaptor operands(newOperands);
auto globalOp = storeOp.getGlobalOp();
if (!globalOp) return failure();
@@ -177,7 +174,7 @@
return rewriter.notifyMatchFailure(storeOp, "illegal global op type");
}
Value storeValue = implicitCastGlobalStore(
- storeOp.getLoc(), operands.value(), globalType, rewriter);
+ storeOp.getLoc(), adaptor.value(), globalType, rewriter);
if (!storeValue) {
return rewriter.notifyMatchFailure(storeOp,
"mismatched store and global type");
@@ -199,15 +196,12 @@
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(
- IREE::Util::GlobalStoreIndirectOp storeOp,
- llvm::ArrayRef<Value> newOperands,
+ IREE::Util::GlobalStoreIndirectOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::Util::GlobalStoreIndirectOp::Adaptor operands(newOperands);
-
Type globalType =
- operands.global().getType().cast<IREE::Util::PtrType>().getTargetType();
+ adaptor.global().getType().cast<IREE::Util::PtrType>().getTargetType();
Value storeValue = implicitCastGlobalStore(
- storeOp.getLoc(), operands.value(), globalType, rewriter);
+ storeOp.getLoc(), adaptor.value(), globalType, rewriter);
if (!storeValue) {
return rewriter.notifyMatchFailure(storeOp,
"mismatched store and global type");
@@ -215,7 +209,7 @@
// TODO(benvanik): multiple converted type results to multiple globals.
rewriter.replaceOpWithNewOp<IREE::Util::GlobalStoreIndirectOp>(
- storeOp, storeValue, operands.global());
+ storeOp, storeValue, adaptor.global());
return success();
}
};
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
index 19a0398..f0da141 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -1171,11 +1171,8 @@
using OpConversionPattern<
IREE::Flow::ExStreamFragmentOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- IREE::Flow::ExStreamFragmentOp streamOp, ArrayRef<Value> newOperands,
+ IREE::Flow::ExStreamFragmentOp streamOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::Flow::ExStreamFragmentOp::Adaptor adaptor(
- newOperands, streamOp->getAttrDictionary());
-
auto valueAliases = computeValueAliases(streamOp);
auto livenessIntervals = computeLivenessIntervals(streamOp, valueAliases);
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
index 3a75f1f..40490b9 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp
@@ -31,17 +31,15 @@
: OpConversionPattern(ctx), converter(converter) {}
LogicalResult matchAndRewrite(
- IREE::Flow::TensorLoadOp loadOp, llvm::ArrayRef<Value> newOperands,
+ IREE::Flow::TensorLoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::Flow::TensorLoadOp::Adaptor operands(newOperands,
- loadOp->getAttrDictionary());
auto source = IREE::HAL::TensorRewriteAdaptor::getChecked(
- loadOp.getLoc(), loadOp.source(), operands.source(), rewriter);
+ loadOp.getLoc(), loadOp.source(), adaptor.source(), rewriter);
if (!source.hasValue()) {
return loadOp.emitOpError() << "cannot create adaptor for source";
}
- auto sourceOffset = source->computeOffset(operands.indices());
+ auto sourceOffset = source->computeOffset(adaptor.indices());
rewriter.replaceOpWithNewOp<IREE::HAL::BufferLoadOp>(
loadOp, converter.convertType(loadOp.result().getType()),
source->getBuffer(), sourceOffset);
@@ -59,21 +57,19 @@
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(
- IREE::Flow::TensorStoreOp storeOp, llvm::ArrayRef<Value> newOperands,
+ IREE::Flow::TensorStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::Flow::TensorStoreOp::Adaptor operands(newOperands,
- storeOp->getAttrDictionary());
auto target = IREE::HAL::TensorRewriteAdaptor::getChecked(
- storeOp.getLoc(), storeOp.target(), operands.target(), rewriter);
+ storeOp.getLoc(), storeOp.target(), adaptor.target(), rewriter);
if (!target.hasValue()) {
return storeOp.emitOpError() << "cannot create adaptor for target";
}
- auto targetOffset = target->computeOffset(operands.indices());
+ auto targetOffset = target->computeOffset(adaptor.indices());
rewriter.create<IREE::HAL::BufferStoreOp>(
- storeOp.getLoc(), operands.value(), target->getBuffer(), targetOffset);
- rewriter.replaceOp(storeOp, {operands.value()});
+ storeOp.getLoc(), adaptor.value(), target->getBuffer(), targetOffset);
+ rewriter.replaceOp(storeOp, {adaptor.value()});
return success();
}
};
@@ -85,11 +81,11 @@
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(
- IREE::Flow::TensorTraceOp traceOp, llvm::ArrayRef<Value> rawOperands,
+ IREE::Flow::TensorTraceOp traceOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = traceOp.getLoc();
SmallVector<Value, 4> bufferViews;
- for (auto operand : llvm::enumerate(rawOperands)) {
+ for (auto operand : llvm::enumerate(adaptor.getOperands())) {
auto adaptor = IREE::HAL::TensorRewriteAdaptor::get(
loc, traceOp.getOperand(operand.index()), operand.value(), rewriter);
bufferViews.emplace_back(adaptor.getBufferView());
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
index 55a4a52..fdc5479 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertAllocatorOps.cpp
@@ -57,22 +57,21 @@
}
LogicalResult matchAndRewrite(
- IREE::HAL::AllocatorTryMapOp op, llvm::ArrayRef<Value> rawOperands,
+ IREE::HAL::AllocatorTryMapOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::HAL::AllocatorTryMapOp::Adaptor operands(rawOperands);
auto callOp = rewriter.create<IREE::VM::CallOp>(
op.getLoc(), importOp.getName(),
ArrayRef<Type>{getTypeConverter()->convertType(op.result().getType())},
ArrayRef<Value>{
- operands.allocator(),
+ adaptor.allocator(),
rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(), /*try=*/1),
rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(),
op.memory_typesAttr()),
rewriter.createOrFold<IREE::VM::ConstI32Op>(op.getLoc(),
op.buffer_usageAttr()),
- operands.source(),
- operands.offset(),
- operands.length(),
+ adaptor.source(),
+ adaptor.offset(),
+ adaptor.length(),
});
copyImportAttrs(importOp, callOp);
auto result = callOp.results().front();
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp
index bb91412..47af8af 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp
@@ -29,7 +29,7 @@
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(
- mlir::arith::ConstantOp constantOp, llvm::ArrayRef<Value> newOperands,
+ mlir::arith::ConstantOp constantOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!constantOp.getType().isa<TensorType>()) return failure();
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp
index 61f1e7e..ff7cffc 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp
@@ -26,10 +26,10 @@
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- Shape::TieShapeOp op, llvm::ArrayRef<Value> operands,
+ Shape::TieShapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<Shape::TieShapeOp>(op, operands[0],
- operands[1]);
+ rewriter.replaceOpWithNewOp<Shape::TieShapeOp>(op, adaptor.getOperands()[0],
+ adaptor.getOperands()[1]);
return success();
}
};
@@ -42,22 +42,22 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- tensor::DimOp dimOp, llvm::ArrayRef<Value> rawOperands,
+ tensor::DimOp dimOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- tensor::DimOp::Adaptor operands(rawOperands);
if (!IREE::HAL::TensorRewriteAdaptor::isValidNewType(
- operands.source().getType())) {
+ adaptor.source().getType())) {
return failure();
}
- auto adaptor = IREE::HAL::TensorRewriteAdaptor::get(
- dimOp.getLoc(), dimOp.source(), operands.source(), rewriter);
+ auto rewriteAdaptor = IREE::HAL::TensorRewriteAdaptor::get(
+ dimOp.getLoc(), dimOp.source(), adaptor.source(), rewriter);
Optional<int64_t> index = dimOp.getConstantIndex();
assert(index.hasValue() && "expect constant index in `std.dim` operation");
auto dimIndex = rewriter.getIndexAttr(index.getValue());
rewriter.replaceOpWithNewOp<IREE::HAL::BufferViewDimOp>(
- dimOp, dimOp.getResult().getType(), adaptor.getBufferView(), dimIndex);
+ dimOp, dimOp.getResult().getType(), rewriteAdaptor.getBufferView(),
+ dimIndex);
return success();
}
};
@@ -69,17 +69,18 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- RankOp rankOp, llvm::ArrayRef<Value> rawOperands,
+ RankOp rankOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!IREE::HAL::TensorRewriteAdaptor::isValidNewType(
- rawOperands[0].getType())) {
+ adaptor.getOperands()[0].getType())) {
return failure();
}
- auto adaptor = IREE::HAL::TensorRewriteAdaptor::get(
- rankOp.getLoc(), rankOp.getOperand(), rawOperands[0], rewriter);
+ auto rewriteAdaptor = IREE::HAL::TensorRewriteAdaptor::get(
+ rankOp.getLoc(), rankOp.getOperand(), adaptor.getOperands()[0],
+ rewriter);
rewriter.replaceOpWithNewOp<IREE::HAL::BufferViewRankOp>(
- rankOp, rankOp.getResult().getType(), adaptor.getBufferView());
+ rankOp, rankOp.getResult().getType(), rewriteAdaptor.getBufferView());
return success();
}
};
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
index 634d2d4..b965ee3 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
@@ -25,24 +25,22 @@
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- IREE::HAL::TensorCastOp op, llvm::ArrayRef<Value> rawOperands,
+ IREE::HAL::TensorCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::HAL::TensorCastOpAdaptor newOperands(
- rawOperands, op.getOperation()->getAttrDictionary());
Value newValue = {};
auto targetType = op.target().getType();
if (targetType.isa<TensorType>()) {
// HAL type -> tensor<...>
- newValue = newOperands.source();
+ newValue = adaptor.source();
} else if (targetType.isa<IREE::HAL::BufferType>()) {
// tensor<...> -> !hal.buffer
- auto adaptor = IREE::HAL::TensorRewriteAdaptor::get(
- op.getLoc(), op.source(), newOperands.source(), rewriter);
- newValue = adaptor.getBuffer();
+ auto rewriteAdaptor = IREE::HAL::TensorRewriteAdaptor::get(
+ op.getLoc(), op.source(), adaptor.source(), rewriter);
+ newValue = rewriteAdaptor.getBuffer();
} else if (targetType.isa<IREE::HAL::BufferViewType>()) {
// tensor<...> -> !hal.buffer_view
- auto adaptor = IREE::HAL::TensorRewriteAdaptor::get(
- op.getLoc(), op.source(), newOperands.source(), rewriter);
+ auto rewriteAdaptor = IREE::HAL::TensorRewriteAdaptor::get(
+ op.getLoc(), op.source(), adaptor.source(), rewriter);
// Note that the buffer view cannot just be returned here: it's backing
// buffer will be correct, but the cast may be doing a metadata change,
@@ -54,12 +52,13 @@
if (auto sourceType =
originalValue.getType().dyn_cast<RankedTensorType>()) {
auto shapeDims = getShapeDims(rewriter, op.getLoc(), sourceType,
- newOperands.source_dims());
+ adaptor.source_dims());
newValue = rewriter.create<IREE::HAL::BufferViewCreateOp>(
- op.getLoc(), adaptor.getBuffer(), adaptor.getElementType(),
- adaptor.getEncodingType(), shapeDims);
+ op.getLoc(), rewriteAdaptor.getBuffer(),
+ rewriteAdaptor.getElementType(), rewriteAdaptor.getEncodingType(),
+ shapeDims);
} else {
- newValue = adaptor.getBufferView();
+ newValue = rewriteAdaptor.getBufferView();
}
}
if (!newValue) {
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp
index 0084f74..e65015d 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStructuralOps.cpp
@@ -25,7 +25,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mlir::FuncOp funcOp, llvm::ArrayRef<Value> operands,
+ mlir::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto &typeConverter = *getTypeConverter();
@@ -68,9 +68,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mlir::CallOp op, llvm::ArrayRef<Value> operands,
+ mlir::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- mlir::CallOpAdaptor adaptor(operands);
SmallVector<Type, 4> resultTypes;
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
resultTypes))) {
@@ -87,9 +86,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mlir::BranchOp op, llvm::ArrayRef<Value> operands,
+ mlir::BranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- mlir::BranchOpAdaptor adaptor(operands);
rewriter.replaceOpWithNewOp<mlir::BranchOp>(op, op.dest(),
adaptor.destOperands());
return success();
@@ -101,10 +99,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mlir::CondBranchOp op, llvm::ArrayRef<Value> operands,
+ mlir::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- mlir::CondBranchOpAdaptor adaptor(operands,
- op.getOperation()->getAttrDictionary());
rewriter.replaceOpWithNewOp<mlir::CondBranchOp>(
op, adaptor.condition(), op.trueDest(), adaptor.trueDestOperands(),
op.falseDest(), adaptor.falseDestOperands());
@@ -117,9 +113,10 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mlir::ReturnOp returnOp, llvm::ArrayRef<Value> operands,
+ mlir::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<mlir::ReturnOp>(returnOp, operands);
+ rewriter.replaceOpWithNewOp<mlir::ReturnOp>(returnOp,
+ adaptor.getOperands());
return success();
}
};
@@ -129,9 +126,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mlir::SelectOp selectOp, llvm::ArrayRef<Value> operands,
+ mlir::SelectOp selectOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- mlir::SelectOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<mlir::SelectOp>(selectOp, adaptor.condition(),
adaptor.true_value(),
adaptor.false_value());
diff --git a/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h b/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
index 3fcf841..db83bb9 100644
--- a/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
+++ b/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h
@@ -17,7 +17,7 @@
struct GenericConvertTypesPattern : public OpConversionPattern<T> {
using OpConversionPattern<T>::OpConversionPattern;
LogicalResult matchAndRewrite(
- T op, llvm::ArrayRef<Value> newOperands,
+ T op, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
for (auto oldType : op.getOperation()->getResultTypes()) {
@@ -29,8 +29,8 @@
// resultTypes.append(newTypes);
resultTypes.push_back(newTypes.front());
}
- auto newOp = rewriter.create<T>(op.getLoc(), resultTypes, newOperands,
- op->getAttrs());
+ auto newOp = rewriter.create<T>(op.getLoc(), resultTypes,
+ adaptor.getOperands(), op->getAttrs());
rewriter.replaceOp(op, newOp->getResults());
return success();
}
diff --git a/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
index 219cb3a..21992c3 100644
--- a/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
+++ b/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
@@ -144,9 +144,9 @@
}
LogicalResult matchAndRewrite(
- T op, llvm::ArrayRef<Value> operands,
+ T op, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto results = rewriteToCall(op, Adaptor{operands}, importOp,
+ auto results = rewriteToCall(op, adaptor, importOp,
*this->getTypeConverter(), rewriter);
if (!results.hasValue()) return failure();
rewriter.replaceOp(op, results.getValue());
diff --git a/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp b/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp
index ad5371d..bfaabd4 100644
--- a/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/MathToVM/ConvertMathToVM.cpp
@@ -29,19 +29,19 @@
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- SrcOpTy srcOp, typename SrcOpTy::Adaptor srcAdaptor,
+ SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): support vectors.
if (srcOp.result().getType().template isa<VectorType>()) return failure();
- switch (srcAdaptor.operand().getType().getIntOrFloatBitWidth()) {
+ switch (adaptor.operand().getType().getIntOrFloatBitWidth()) {
case 32:
rewriter.replaceOpWithNewOp<Dst32OpTy>(
- srcOp, srcAdaptor.operand().getType(), srcAdaptor.operand());
+ srcOp, adaptor.operand().getType(), adaptor.operand());
break;
case 64:
rewriter.replaceOpWithNewOp<Dst64OpTy>(
- srcOp, srcAdaptor.operand().getType(), srcAdaptor.operand());
+ srcOp, adaptor.operand().getType(), adaptor.operand());
break;
default:
llvm_unreachable("invalid target type");
@@ -55,21 +55,19 @@
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- SrcOpTy srcOp, typename SrcOpTy::Adaptor srcAdaptor,
+ SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO(benvanik): support vectors.
if (srcOp.result().getType().template isa<VectorType>()) return failure();
- switch (srcAdaptor.lhs().getType().getIntOrFloatBitWidth()) {
+ switch (adaptor.lhs().getType().getIntOrFloatBitWidth()) {
case 32:
- rewriter.replaceOpWithNewOp<Dst32OpTy>(
- srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(),
- srcAdaptor.rhs());
+ rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, adaptor.lhs().getType(),
+ adaptor.lhs(), adaptor.rhs());
break;
case 64:
- rewriter.replaceOpWithNewOp<Dst64OpTy>(
- srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(),
- srcAdaptor.rhs());
+ rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, adaptor.lhs().getType(),
+ adaptor.lhs(), adaptor.rhs());
break;
default:
llvm_unreachable("invalid target type");
diff --git a/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp b/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
index 2a27ad6..5eca5e1 100644
--- a/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/MemRefToVM/ConvertMemRefToVM.cpp
@@ -31,9 +31,9 @@
struct FoldAsNoOp final : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- OpTy op, ArrayRef<Value> operands,
+ OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOp(op, operands);
+ rewriter.replaceOp(op, adaptor.getOperands());
return success();
}
};
@@ -82,9 +82,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- memref::GlobalOp globalOp, ArrayRef<Value> rawOperands,
+ memref::GlobalOp globalOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::GlobalOpAdaptor operands(rawOperands);
if (!isRankZeroOrOneMemRef(globalOp.type())) {
return rewriter.notifyMatchFailure(
globalOp,
@@ -112,9 +111,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- memref::GetGlobalOp getOp, ArrayRef<Value> rawOperands,
+ memref::GetGlobalOp getOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::GetGlobalOpAdaptor operands(rawOperands);
if (!isRankZeroOrOneMemRef(getOp.result().getType())) {
return rewriter.notifyMatchFailure(
getOp, "only rank-0 and rank-1 memrefs are supported; flatten first");
@@ -130,9 +128,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- memref::LoadOp loadOp, ArrayRef<Value> rawOperands,
+ memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::LoadOpAdaptor operands(rawOperands);
if (!isRankZeroOrOneMemRef(loadOp.memref().getType())) {
return rewriter.notifyMatchFailure(
loadOp,
@@ -147,35 +144,35 @@
if (integerType.isInteger(1) || integerType.isInteger(8)) {
if (integerType.isSigned() || integerType.isSignless()) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8SOp>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
} else {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI8UOp>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
}
} else if (integerType.isInteger(16)) {
if (integerType.isSigned() || integerType.isSignless()) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16SOp>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
} else {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI16UOp>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
}
} else if (integerType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI32Op>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
} else if (integerType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadI64Op>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
} else {
return rewriter.notifyMatchFailure(
loadOp, "invalid integer buffer element type");
}
} else if (oldType.isF32()) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF32Op>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
} else if (oldType.isF64()) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferLoadF64Op>(
- loadOp, newType, operands.memref(), byteOffset);
+ loadOp, newType, adaptor.memref(), byteOffset);
} else {
return rewriter.notifyMatchFailure(loadOp,
"invalid float buffer element type");
@@ -189,9 +186,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- memref::StoreOp storeOp, ArrayRef<Value> rawOperands,
+ memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- memref::StoreOpAdaptor operands(rawOperands);
if (!isRankZeroOrOneMemRef(storeOp.memref().getType())) {
return rewriter.notifyMatchFailure(
storeOp,
@@ -203,22 +199,22 @@
getTypeConverter()->convertType(rewriter.getIndexType()), rewriter);
if (oldType.isInteger(1) || oldType.isInteger(8)) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI8Op>(
- storeOp, operands.memref(), byteOffset, operands.value());
+ storeOp, adaptor.memref(), byteOffset, adaptor.value());
} else if (oldType.isInteger(16)) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI16Op>(
- storeOp, operands.memref(), byteOffset, operands.value());
+ storeOp, adaptor.memref(), byteOffset, adaptor.value());
} else if (oldType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI32Op>(
- storeOp, operands.memref(), byteOffset, operands.value());
+ storeOp, adaptor.memref(), byteOffset, adaptor.value());
} else if (oldType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreI64Op>(
- storeOp, operands.memref(), byteOffset, operands.value());
+ storeOp, adaptor.memref(), byteOffset, adaptor.value());
} else if (oldType.isF32()) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF32Op>(
- storeOp, operands.memref(), byteOffset, operands.value());
+ storeOp, adaptor.memref(), byteOffset, adaptor.value());
} else if (oldType.isF64()) {
rewriter.replaceOpWithNewOp<IREE::VM::BufferStoreF64Op>(
- storeOp, operands.memref(), byteOffset, operands.value());
+ storeOp, adaptor.memref(), byteOffset, adaptor.value());
} else {
return rewriter.notifyMatchFailure(storeOp,
"invalid buffer element type");
diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index a4f25b0..79f26cb 100644
--- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -28,7 +28,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- ModuleOp srcOp, ArrayRef<Value> operands,
+ ModuleOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Do not attempt to convert the top level module.
// This mechanism can only support rewriting non top-level modules.
@@ -66,7 +66,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- FuncOp srcOp, ArrayRef<Value> operands,
+ FuncOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FunctionType srcFuncType = srcOp.getType();
TypeConverter::SignatureConversion signatureConversion(
@@ -137,9 +137,10 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mlir::ReturnOp srcOp, ArrayRef<Value> operands,
+ mlir::ReturnOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp, operands);
+ rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp,
+ adaptor.getOperands());
return success();
}
};
@@ -151,7 +152,7 @@
TypeConverter &typeConverter;
LogicalResult matchAndRewrite(
- arith::ConstantOp srcOp, ArrayRef<Value> operands,
+ arith::ConstantOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto targetType = typeConverter.convertType(srcOp.getType());
if (!targetType) {
@@ -221,50 +222,49 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::CmpIOp srcOp, ArrayRef<Value> operands,
+ arith::CmpIOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::CmpIOp::Adaptor srcAdaptor(operands);
auto returnType = rewriter.getIntegerType(32);
switch (srcOp.getPredicate()) {
case arith::CmpIPredicate::eq:
rewriter.replaceOpWithNewOp<IREE::VM::CmpEQI32Op>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::ne:
rewriter.replaceOpWithNewOp<IREE::VM::CmpNEI32Op>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::slt:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32SOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::sle:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32SOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::sgt:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32SOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::sge:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32SOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::ult:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTI32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::ule:
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEI32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::ugt:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTI32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
case arith::CmpIPredicate::uge:
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEI32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
return success();
default:
return failure();
@@ -276,9 +276,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::CmpFOp srcOp, ArrayRef<Value> operands,
+ arith::CmpFOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::CmpFOp::Adaptor srcAdaptor(operands);
auto returnType = rewriter.getIntegerType(32);
switch (srcOp.getPredicate()) {
case arith::CmpFPredicate::AlwaysFalse: // 0
@@ -291,9 +290,9 @@
rewriter.replaceOpWithNewOp<IREE::VM::OrI32Op>(
srcOp, returnType,
rewriter.createOrFold<IREE::VM::CmpNaNF32Op>(
- srcOp.getLoc(), returnType, srcAdaptor.lhs()),
+ srcOp.getLoc(), returnType, adaptor.lhs()),
rewriter.createOrFold<IREE::VM::CmpNaNF32Op>(
- srcOp.getLoc(), returnType, srcAdaptor.rhs()));
+ srcOp.getLoc(), returnType, adaptor.rhs()));
break;
case arith::CmpFPredicate::ORD: // !(isnan(lhs) || isnan(rhs))
rewriter.replaceOpWithNewOp<IREE::VM::XorI32Op>(
@@ -302,57 +301,57 @@
rewriter.createOrFold<IREE::VM::AndI32Op>(
srcOp.getLoc(), returnType,
rewriter.createOrFold<IREE::VM::CmpNaNF32Op>(
- srcOp.getLoc(), returnType, srcAdaptor.lhs()),
+ srcOp.getLoc(), returnType, adaptor.lhs()),
rewriter.createOrFold<IREE::VM::CmpNaNF32Op>(
- srcOp.getLoc(), returnType, srcAdaptor.rhs())));
+ srcOp.getLoc(), returnType, adaptor.rhs())));
break;
case arith::CmpFPredicate::OEQ: // ordered and equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpEQF32OOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::OGT: // ordered and greater than
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTF32OOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::OGE: // ordered and greater than or equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEF32OOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::OLT: // ordered and less than
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTF32OOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::OLE: // ordered and less than or equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEF32OOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::ONE: // ordered and not equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpNEF32OOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::UEQ: // unordered or equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpEQF32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::UGT: // unordered or greater than
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTF32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::UGE: // unordered or greater than or equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpGTEF32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::ULT: // unordered or less than
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTF32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::ULE: // unordered or less than or equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpLTEF32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
case arith::CmpFPredicate::UNE: // unordered or not equal
rewriter.replaceOpWithNewOp<IREE::VM::CmpNEF32UOp>(
- srcOp, returnType, srcAdaptor.lhs(), srcAdaptor.rhs());
+ srcOp, returnType, adaptor.lhs(), adaptor.rhs());
break;
default:
return rewriter.notifyMatchFailure(srcOp,
@@ -367,17 +366,16 @@
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- SrcOpTy srcOp, ArrayRef<Value> operands,
+ SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- typename SrcOpTy::Adaptor srcAdaptor(operands);
- switch (srcAdaptor.operand().getType().getIntOrFloatBitWidth()) {
+ switch (adaptor.operand().getType().getIntOrFloatBitWidth()) {
case 32:
rewriter.replaceOpWithNewOp<Dst32OpTy>(
- srcOp, srcAdaptor.operand().getType(), srcAdaptor.operand());
+ srcOp, adaptor.operand().getType(), adaptor.operand());
break;
case 64:
rewriter.replaceOpWithNewOp<Dst64OpTy>(
- srcOp, srcAdaptor.operand().getType(), srcAdaptor.operand());
+ srcOp, adaptor.operand().getType(), adaptor.operand());
break;
default:
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
@@ -391,19 +389,16 @@
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- SrcOpTy srcOp, ArrayRef<Value> operands,
+ SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- typename SrcOpTy::Adaptor srcAdaptor(operands);
- switch (srcAdaptor.lhs().getType().getIntOrFloatBitWidth()) {
+ switch (adaptor.lhs().getType().getIntOrFloatBitWidth()) {
case 32:
- rewriter.replaceOpWithNewOp<Dst32OpTy>(
- srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(),
- srcAdaptor.rhs());
+ rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, adaptor.lhs().getType(),
+ adaptor.lhs(), adaptor.rhs());
break;
case 64:
- rewriter.replaceOpWithNewOp<Dst64OpTy>(
- srcOp, srcAdaptor.lhs().getType(), srcAdaptor.lhs(),
- srcAdaptor.rhs());
+ rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, adaptor.lhs().getType(),
+ adaptor.lhs(), adaptor.rhs());
break;
default:
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
@@ -417,23 +412,22 @@
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- SrcOpTy srcOp, ArrayRef<Value> operands,
+ SrcOpTy srcOp, typename SrcOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- typename SrcOpTy::Adaptor srcAdaptor(operands);
- Value amount = srcAdaptor.rhs();
+ Value amount = adaptor.rhs();
if (amount.getType().getIntOrFloatBitWidth() > 32) {
// Shift amounts are always 32-bit in the VM.
amount = rewriter.createOrFold<arith::TruncIOp>(
srcOp.getLoc(), rewriter.getI32Type(), amount);
}
- switch (srcAdaptor.lhs().getType().getIntOrFloatBitWidth()) {
+ switch (adaptor.lhs().getType().getIntOrFloatBitWidth()) {
case 32:
rewriter.replaceOpWithNewOp<Dst32OpTy>(srcOp, srcOp.getType(),
- srcAdaptor.lhs(), amount);
+ adaptor.lhs(), amount);
break;
case 64:
rewriter.replaceOpWithNewOp<Dst64OpTy>(srcOp, srcOp.getType(),
- srcAdaptor.lhs(), amount);
+ adaptor.lhs(), amount);
break;
default:
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
@@ -447,9 +441,9 @@
using OpConversionPattern<StdOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- StdOp srcOp, ArrayRef<Value> operands,
+ StdOp srcOp, typename StdOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOp(srcOp, operands);
+ rewriter.replaceOp(srcOp, adaptor.getOperands());
return success();
}
};
@@ -458,20 +452,18 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::IndexCastOp srcOp, ArrayRef<Value> rawOperands,
+ arith::IndexCastOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::IndexCastOpAdaptor operands(rawOperands);
- auto srcType = operands.in().getType();
+ auto srcType = adaptor.in().getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType == dstType) {
- rewriter.replaceOp(srcOp, rawOperands);
+ rewriter.replaceOp(srcOp, adaptor.getOperands());
} else if (srcType.getIntOrFloatBitWidth() <
dstType.getIntOrFloatBitWidth()) {
- rewriter.replaceOpWithNewOp<arith::ExtUIOp>(srcOp, dstType,
- operands.in());
+ rewriter.replaceOpWithNewOp<arith::ExtUIOp>(srcOp, dstType, adaptor.in());
} else {
rewriter.replaceOpWithNewOp<arith::TruncIOp>(srcOp, dstType,
- operands.in());
+ adaptor.in());
}
return success();
}
@@ -481,9 +473,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::ExtUIOp srcOp, ArrayRef<Value> rawOperands,
+ arith::ExtUIOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::ExtUIOpAdaptor operands(rawOperands);
auto srcType = srcOp.in().getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType.isInteger(1) && dstType.isInteger(32)) {
@@ -492,17 +483,17 @@
// NOTE: this may not be required - if we know that the i1 is never able
// to have more than bit 0 manipulated then this is wasted work.
rewriter.replaceOpWithNewOp<IREE::VM::AndI32Op>(
- srcOp, dstType, operands.in(),
+ srcOp, dstType, adaptor.in(),
rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1));
} else if (srcType.isInteger(8) && dstType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::ExtI8I32UOp>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(16) && dstType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::ExtI16I32UOp>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(32) && dstType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::ExtI32I64UOp>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else {
// TODO(benvanik): we should be building a sequence of extensions for
// things like i8 -> i64.
@@ -516,20 +507,19 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::ExtSIOp srcOp, ArrayRef<Value> rawOperands,
+ arith::ExtSIOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::ExtSIOpAdaptor operands(rawOperands);
auto srcType = srcOp.in().getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType.isInteger(8) && dstType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::ExtI8I32SOp>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(16) && dstType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::ExtI16I32SOp>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(32) && dstType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::ExtI32I64SOp>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else {
// TODO(benvanik): we should be building a sequence of extensions for
// things like i8 -> i64.
@@ -543,9 +533,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::TruncIOp srcOp, ArrayRef<Value> rawOperands,
+ arith::TruncIOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::TruncIOpAdaptor operands(rawOperands);
auto srcType = srcOp.in().getType();
auto resultType = srcOp.getResult().getType();
auto dstType = getTypeConverter()->convertType(resultType);
@@ -553,7 +542,7 @@
// i1 is represented as i32, so just mask off the bit and truncate as
// normal. Note that if we started as i64 we need to first get that into
// an i32 that we can work with.
- auto value = operands.in();
+ auto value = adaptor.in();
if (srcType.isInteger(64)) {
value = rewriter.createOrFold<IREE::VM::TruncI64I32Op>(srcOp.getLoc(),
dstType, value);
@@ -563,19 +552,19 @@
rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1));
} else if (srcType.isInteger(32) && resultType.isInteger(8)) {
rewriter.replaceOpWithNewOp<IREE::VM::TruncI32I8Op>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(32) && resultType.isInteger(16)) {
rewriter.replaceOpWithNewOp<IREE::VM::TruncI32I16Op>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(64) && resultType.isInteger(8)) {
rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I8Op>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(64) && resultType.isInteger(16)) {
rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I16Op>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else if (srcType.isInteger(64) && resultType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::TruncI64I32Op>(srcOp, dstType,
- operands.in());
+ adaptor.in());
} else {
return rewriter.notifyMatchFailure(srcOp, "unsupported truncation");
}
@@ -587,15 +576,14 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::SIToFPOp srcOp, ArrayRef<Value> operands,
+ arith::SIToFPOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::SIToFPOpAdaptor srcAdaptor(operands);
- auto srcType = operands[0].getType();
+ auto srcType = adaptor.getOperands()[0].getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType.isSignlessInteger(32) || srcType.isSignedInteger(32)) {
if (dstType.isF32()) {
- rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::CastSI32F32Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
return success();
}
}
@@ -607,15 +595,14 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::UIToFPOp srcOp, ArrayRef<Value> operands,
+ arith::UIToFPOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::UIToFPOpAdaptor srcAdaptor(operands);
- auto srcType = operands[0].getType();
+ auto srcType = adaptor.getOperands()[0].getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType.isUnsignedInteger(32)) {
if (dstType.isF32()) {
- rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::CastUI32F32Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
return success();
}
}
@@ -627,15 +614,14 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::FPToSIOp srcOp, ArrayRef<Value> operands,
+ arith::FPToSIOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::FPToSIOpAdaptor srcAdaptor(operands);
- auto srcType = operands[0].getType();
+ auto srcType = adaptor.getOperands()[0].getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType.isF32()) {
if (dstType.isSignlessInteger(32) || dstType.isSignedInteger(32)) {
- rewriter.replaceOpWithNewOp<IREE::VM::CastF32SI32Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::CastF32SI32Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
return success();
}
}
@@ -647,15 +633,14 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::FPToUIOp srcOp, ArrayRef<Value> operands,
+ arith::FPToUIOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- arith::FPToUIOpAdaptor srcAdaptor(operands);
- auto srcType = operands[0].getType();
+ auto srcType = adaptor.getOperands()[0].getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType.isF32()) {
if (srcType.isUnsignedInteger(32)) {
- rewriter.replaceOpWithNewOp<IREE::VM::CastF32UI32Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::CastF32UI32Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
return success();
}
}
@@ -667,22 +652,22 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- arith::BitcastOp srcOp, ArrayRef<Value> operands,
+ arith::BitcastOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcType = operands[0].getType();
+ auto srcType = adaptor.getOperands()[0].getType();
auto dstType = getTypeConverter()->convertType(srcOp.getResult().getType());
if (srcType.isF32() && dstType.isInteger(32)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BitcastF32I32Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::BitcastF32I32Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
} else if (srcType.isInteger(32) && dstType.isF32()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BitcastI32F32Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::BitcastI32F32Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
} else if (srcType.isF64() && dstType.isInteger(64)) {
- rewriter.replaceOpWithNewOp<IREE::VM::BitcastF64I64Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::BitcastF64I64Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
} else if (srcType.isInteger(64) && dstType.isF64()) {
- rewriter.replaceOpWithNewOp<IREE::VM::BitcastI64F64Op>(srcOp, dstType,
- operands[0]);
+ rewriter.replaceOpWithNewOp<IREE::VM::BitcastI64F64Op>(
+ srcOp, dstType, adaptor.getOperands()[0]);
} else {
return rewriter.notifyMatchFailure(srcOp, "unsupported bitcast");
}
@@ -693,34 +678,33 @@
class SelectOpConversion : public OpConversionPattern<SelectOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- SelectOp srcOp, ArrayRef<Value> operands,
+ SelectOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SelectOp::Adaptor srcAdaptor(operands);
- auto valueType = srcAdaptor.true_value().getType();
+ auto valueType = adaptor.true_value().getType();
if (valueType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::SelectI32Op>(
- srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
- srcAdaptor.false_value());
+ srcOp, valueType, adaptor.condition(), adaptor.true_value(),
+ adaptor.false_value());
return success();
} else if (valueType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::SelectI64Op>(
- srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
- srcAdaptor.false_value());
+ srcOp, valueType, adaptor.condition(), adaptor.true_value(),
+ adaptor.false_value());
return success();
} else if (valueType.isF32()) {
rewriter.replaceOpWithNewOp<IREE::VM::SelectF32Op>(
- srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
- srcAdaptor.false_value());
+ srcOp, valueType, adaptor.condition(), adaptor.true_value(),
+ adaptor.false_value());
return success();
} else if (valueType.isF64()) {
rewriter.replaceOpWithNewOp<IREE::VM::SelectF64Op>(
- srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
- srcAdaptor.false_value());
+ srcOp, valueType, adaptor.condition(), adaptor.true_value(),
+ adaptor.false_value());
return success();
} else if (valueType.isa<IREE::VM::RefType>()) {
rewriter.replaceOpWithNewOp<IREE::VM::SelectRefOp>(
- srcOp, valueType, srcAdaptor.condition(), srcAdaptor.true_value(),
- srcAdaptor.false_value());
+ srcOp, valueType, adaptor.condition(), adaptor.true_value(),
+ adaptor.false_value());
return success();
} else {
return rewriter.notifyMatchFailure(srcOp,
@@ -733,9 +717,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- AssertOp srcOp, ArrayRef<Value> newOperands,
+ AssertOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- AssertOpAdaptor operands(newOperands, srcOp->getAttrDictionary());
auto status = rewriter.create<IREE::VM::ConstI32Op>(
srcOp.getLoc(),
rewriter.getIntegerAttr(
@@ -743,10 +726,10 @@
static_cast<int32_t>(IREE::Util::StatusCode::FailedPrecondition)));
// TODO(benvanik): invert cond_fail instead.
auto invertedCondition = rewriter.createOrFold<IREE::VM::XorI32Op>(
- srcOp.getLoc(), operands.arg().getType(), operands.arg(),
+ srcOp.getLoc(), adaptor.arg().getType(), adaptor.arg(),
rewriter.createOrFold<IREE::VM::ConstI32Op>(srcOp.getLoc(), 1));
rewriter.replaceOpWithNewOp<IREE::VM::CondFailOp>(
- srcOp, invertedCondition, status, operands.msg().getValue());
+ srcOp, invertedCondition, status, adaptor.msg().getValue());
return success();
}
};
@@ -755,10 +738,10 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- BranchOp srcOp, ArrayRef<Value> operands,
+ BranchOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>(srcOp, srcOp.getDest(),
- operands);
+ adaptor.getOperands());
return success();
}
};
@@ -767,13 +750,12 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- CondBranchOp srcOp, ArrayRef<Value> operands,
+ CondBranchOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Block *trueDest = srcOp.getTrueDest();
rewriter.replaceOpWithNewOp<IREE::VM::CondBranchOp>(
- srcOp, operands[0], trueDest,
- operands.slice(1, trueDest->getNumArguments()), srcOp.getFalseDest(),
- operands.slice(1 + trueDest->getNumArguments()));
+ srcOp, adaptor.getCondition(), trueDest, adaptor.getTrueDestOperands(),
+ srcOp.getFalseDest(), adaptor.getFalseDestOperands());
return success();
}
};
@@ -782,9 +764,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- CallOp srcOp, ArrayRef<Value> operands,
+ CallOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- CallOp::Adaptor srcAdaptor(operands);
// Convert function result types. The conversion framework will ensure
// that the callee has been equivalently converted.
SmallVector<Type, 4> resultTypes;
@@ -796,7 +777,7 @@
resultTypes.push_back(resultType);
}
rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
- srcOp, srcOp.getCallee(), resultTypes, srcAdaptor.operands());
+ srcOp, srcOp.getCallee(), resultTypes, adaptor.operands());
return success();
}
diff --git a/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStatusOps.cpp b/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStatusOps.cpp
index c57db63..26f1b9c 100644
--- a/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStatusOps.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStatusOps.cpp
@@ -18,12 +18,11 @@
: OpConversionPattern(context) {}
LogicalResult matchAndRewrite(
- IREE::Util::StatusCheckOkOp op, llvm::ArrayRef<Value> newOperands,
+ IREE::Util::StatusCheckOkOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREE::Util::StatusCheckOkOp::Adaptor operands(newOperands);
// If status value is non-zero, fail.
rewriter.replaceOpWithNewOp<IREE::VM::CondFailOp>(
- op, operands.status(), op.message().getValueOr(""));
+ op, adaptor.status(), op.message().getValueOr(""));
return success();
}
};
diff --git a/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index c815be2..b0d75da 100644
--- a/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -90,9 +90,8 @@
using OpConversionPattern<
IREEPublic::BufferViewToTensorOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- IREEPublic::BufferViewToTensorOp srcOp, ArrayRef<Value> operands,
+ IREEPublic::BufferViewToTensorOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREEPublic::BufferViewToTensorOpAdaptor adaptor(operands);
Type resultType = typeConverter->convertType(srcOp.target().getType());
if (!resultType) return failure();
rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>(
@@ -106,9 +105,8 @@
using OpConversionPattern<
IREEPublic::TensorToBufferViewOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- IREEPublic::TensorToBufferViewOp srcOp, ArrayRef<Value> operands,
+ IREEPublic::TensorToBufferViewOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- IREEPublic::TensorToBufferViewOpAdaptor adaptor(operands);
Type resultType = typeConverter->convertType(srcOp.target().getType());
if (!resultType) return failure();
rewriter.replaceOpWithNewOp<IREE::HAL::TensorCastOp>(
@@ -120,7 +118,7 @@
class BuiltinFuncOpPattern : public OpConversionPattern<FuncOp> {
using OpConversionPattern<FuncOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- FuncOp srcOp, ArrayRef<Value> operands,
+ FuncOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FunctionType srcFuncType = srcOp.getType();
TypeConverter::SignatureConversion signatureConversion(
@@ -178,7 +176,7 @@
class GlobalOpPattern : public OpConversionPattern<IREEPublic::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- IREEPublic::GlobalOp srcOp, ArrayRef<Value> operands,
+ IREEPublic::GlobalOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type newType = typeConverter->convertType(srcOp.type());
if (!newType) return failure();
diff --git a/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp b/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
index 1e50720..5e98d4d 100644
--- a/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
+++ b/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
@@ -310,7 +310,7 @@
: public OpConversionPattern<chlo::ConstantLikeOp> {
using OpConversionPattern<chlo::ConstantLikeOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- chlo::ConstantLikeOp op, ArrayRef<Value> operands,
+ chlo::ConstantLikeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().cast<RankedTensorType>();
if (!resultTy.hasRank())
@@ -322,14 +322,12 @@
return success();
}
- chlo::ConstantLikeOpAdaptor transformed(operands);
Location loc = op.getLoc();
int resultRank = resultTy.getRank();
SmallVector<Extent> resultExtents;
resultExtents.reserve(resultRank);
- appendExtents(rewriter, loc, resultExtents, transformed.operand(),
- resultTy);
+ appendExtents(rewriter, loc, resultExtents, adaptor.operand(), resultTy);
auto resultTy0D = RankedTensorType::get({}, resultTy.getElementType());
Value scalarConst = rewriter.create<mhlo::ConstOp>(
@@ -596,15 +594,14 @@
using OpConversionPattern<chlo::BroadcastSelectOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- chlo::BroadcastSelectOp op, ArrayRef<Value> operands,
+ chlo::BroadcastSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- chlo::BroadcastSelectOp::Adaptor transformed(operands);
Location loc = op.getLoc();
// Only support ranked operands.
- Value pred = transformed.pred();
- Value thenValue = transformed.on_true();
- Value elseValue = transformed.on_false();
+ Value pred = adaptor.pred();
+ Value thenValue = adaptor.on_true();
+ Value elseValue = adaptor.on_false();
auto predType = pred.getType().dyn_cast<RankedTensorType>();
auto thenType = thenValue.getType().dyn_cast<RankedTensorType>();
auto elseType = elseValue.getType().dyn_cast<RankedTensorType>();
@@ -713,12 +710,11 @@
using OpConversionPattern<mhlo::DynamicReshapeOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::DynamicReshapeOp op, ArrayRef<Value> rawOperands,
+ mhlo::DynamicReshapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- mhlo::DynamicReshapeOpAdaptor operands(rawOperands);
- Value input = operands.operand();
- Value outputShape = operands.output_shape();
+ Value input = adaptor.operand();
+ Value outputShape = adaptor.output_shape();
auto outputShapeType = outputShape.getType().dyn_cast<RankedTensorType>();
auto resultType = typeConverter->convertType(op.getType())
.dyn_cast_or_null<RankedTensorType>();
diff --git a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
index 15ed9af..756b9bb 100644
--- a/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ b/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
@@ -124,7 +124,7 @@
struct LinalgExtRegionHLOOpConversion : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
- OpTy op, ArrayRef<Value> args,
+ OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (!isInBodyOfLinalgExtOps(op)) return failure();
TensorType origRetType = op.getType().template dyn_cast<TensorType>();
@@ -132,8 +132,8 @@
SmallVector<Value> scalarArgs;
Type newRetType = getElementTypeOrSelf(
this->typeConverter->convertType(origRetType.getElementType()));
- Value result =
- lmhlo::HloOpToStdScalarOp::map<OpTy>(op, newRetType, args, &rewriter);
+ Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
+ op, newRetType, adaptor.getOperands(), &rewriter);
rewriter.replaceOp(op, result);
return success();
}
@@ -143,10 +143,10 @@
: public OpConversionPattern<mhlo::ReturnOp> {
using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ReturnOp op, ArrayRef<Value> args,
+ mhlo::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (!isInBodyOfLinalgExtOps(op)) return failure();
- rewriter.replaceOpWithNewOp<linalg_ext::YieldOp>(op, args);
+ rewriter.replaceOpWithNewOp<linalg_ext::YieldOp>(op, adaptor.getOperands());
return success();
}
};
@@ -159,11 +159,12 @@
using OpConversionPattern<mhlo::SortOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::SortOp mhloSortOp, ArrayRef<Value> args,
+ mhlo::SortOp mhloSortOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto sortOp = rewriter.create<linalg_ext::SortOp>(
mhloSortOp.getLoc(), mhloSortOp.getResultTypes(),
- /*inputs=*/ValueRange{}, args, mhloSortOp.dimensionAttr());
+ /*inputs=*/ValueRange{}, adaptor.getOperands(),
+ mhloSortOp.dimensionAttr());
rewriter.inlineRegionBefore(mhloSortOp.comparator(), sortOp.region(),
sortOp.region().begin());
Region ®ion = sortOp.region();
@@ -226,8 +227,7 @@
return true;
}
- static SmallVector<int64_t> getTiedResultOperandIndices(
- ArrayRef<Value> args) {
+ static SmallVector<int64_t> getTiedResultOperandIndices(ValueRange operands) {
// Mark linalg_ext.scatter::orinigal as readwrite tensor.
return {0};
}
@@ -273,12 +273,11 @@
}
LogicalResult matchAndRewrite(
- mhlo::ScatterOp op, ArrayRef<Value> args,
+ mhlo::ScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (!hasCanonicalDimensionNumbers(op)) return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- mhlo::ScatterOpAdaptor adaptor(args);
Value original = adaptor.operand();
Value indices = adaptor.scatter_indices();
@@ -389,10 +388,9 @@
}
LogicalResult matchAndRewrite(
- mhlo::FftOp op, ArrayRef<Value> args,
+ mhlo::FftOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Only handle 2^n fft length.
- mhlo::FftOpAdaptor adaptor(args);
auto operandType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
if (!operandType || !operandType.hasStaticShape()) {
return failure();
@@ -451,23 +449,24 @@
using OpConversionPattern<mhlo::ReverseOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ReverseOp op, ArrayRef<Value> args,
+ mhlo::ReverseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- auto ty = args[0].getType().dyn_cast<RankedTensorType>();
+ auto ty = adaptor.getOperands()[0].getType().dyn_cast<RankedTensorType>();
if (!ty) return failure();
Location loc = op.getLoc();
SmallVector<Value> dynSizes;
for (auto en : llvm::enumerate(ty.getShape())) {
if (en.value() == ShapedType::kDynamicSize) {
- dynSizes.push_back(
- rewriter.create<tensor::DimOp>(loc, args[0], en.index()));
+ dynSizes.push_back(rewriter.create<tensor::DimOp>(
+ loc, adaptor.getOperands()[0], en.index()));
}
}
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, dynSizes, ty.getShape(), ty.getElementType());
rewriter.replaceOpWithNewOp<linalg_ext::ReverseOp>(
- op, op->getResultTypes(), args, initTensor, op.dimensions());
+ op, op->getResultTypes(), adaptor.getOperands(), initTensor,
+ op.dimensions());
return success();
}
};
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index 969f0a5..b1289fa 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -58,7 +58,7 @@
using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ConcatenateOp op, ArrayRef<Value> args,
+ mhlo::ConcatenateOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultType = this->typeConverter->convertType(op.getResult().getType())
.dyn_cast<RankedTensorType>();
@@ -73,11 +73,12 @@
SmallVector<Value, 3> offsets, sizes, strides;
for (int i = 0; i < rank; ++i) {
offsets.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
- sizes.push_back(rewriter.create<tensor::DimOp>(loc, args[0], i));
+ sizes.push_back(
+ rewriter.create<tensor::DimOp>(loc, adaptor.getOperands()[0], i));
strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 1));
}
Value resultDimSize = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- for (auto arg : args) {
+ for (auto arg : adaptor.getOperands()) {
auto size = rewriter.create<tensor::DimOp>(loc, arg, dim);
resultDimSize = rewriter.create<arith::AddIOp>(loc, resultDimSize, size);
}
@@ -90,7 +91,7 @@
rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
Value accBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- for (auto arg : args) {
+ for (auto arg : adaptor.getOperands()) {
offsets[dim] = accBound;
sizes[dim] = rewriter.create<tensor::DimOp>(loc, arg, dim);
result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
@@ -167,14 +168,13 @@
using OpConversionPattern<mhlo::FftOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::FftOp op, ArrayRef<Value> args,
+ mhlo::FftOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.fft_type() != "RFFT") {
return rewriter.notifyMatchFailure(op,
"non RFFT types are supported yet");
}
- mhlo::FftOpAdaptor adaptor(args);
auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
if (!inputType || !inputType.hasStaticShape() || inputType.getRank() > 2) {
return rewriter.notifyMatchFailure(op, "only static 1D or 2D dft ops");
@@ -253,7 +253,7 @@
struct ConstOpConversion : public OpConversionPattern<mhlo::ConstOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- mhlo::ConstOp op, ArrayRef<Value> /*operands*/,
+ mhlo::ConstOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto valueAttr = op.value();
Type oldElType = valueAttr.getType().getElementType();
diff --git a/iree/compiler/Utils/PatternUtils.h b/iree/compiler/Utils/PatternUtils.h
index 476f0c3..50b0e9f 100644
--- a/iree/compiler/Utils/PatternUtils.h
+++ b/iree/compiler/Utils/PatternUtils.h
@@ -66,9 +66,9 @@
PatternBenefit benefit)
: OpConversionPattern<OpTy>(context, benefit), f(f) {}
LogicalResult matchAndRewrite(
- OpTy op, ArrayRef<Value> operands,
+ OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- return f(op, typename OpTy::Adaptor(operands), rewriter);
+ return f(op, adaptor, rewriter);
}
GenericOpRewritePattern<OpTy> f;
};
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 308702d..5e37fce 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -165,7 +165,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::AllocFreeVarOp srcOp, ArrayRef<Value> operands,
+ pydm_d::AllocFreeVarOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: We may want to initialize the list structurally in some way.
// This will fail either way on read from unassigned variable, but we need
@@ -275,9 +275,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::BoolToPredOp srcOp, ArrayRef<Value> operands,
+ pydm_d::BoolToPredOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- pydm_d::BoolToPredOp::Adaptor adaptor(operands);
rewriter.replaceOp(srcOp, adaptor.value());
return success();
}
@@ -287,7 +286,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::BoxOp srcOp, ArrayRef<Value> operands,
+ pydm_d::BoxOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto origType =
@@ -297,7 +296,7 @@
"not a PythonTypeInterface type");
auto typeCode = origType.getTypeCode();
auto list = createObjectList(loc, rewriter, static_cast<int>(typeCode),
- operands[0]);
+ adaptor.getOperands()[0]);
rewriter.replaceOp(srcOp, list);
return success();
}
@@ -307,9 +306,8 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::CallOp srcOp, ArrayRef<Value> operands,
+ pydm_d::CallOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- pydm_d::CallOp::Adaptor adaptor(operands);
SmallVector<Type> resultTypes;
if (failed(getTypeConverter()->convertTypes(srcOp.getResultTypes(),
resultTypes))) {
@@ -389,7 +387,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::FuncOp srcOp, ArrayRef<Value> operands,
+ pydm_d::FuncOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FunctionType srcFuncType = srcOp.getType();
TypeConverter::SignatureConversion signatureConversion(
@@ -464,7 +462,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::LoadVarOp srcOp, ArrayRef<Value> operands,
+ pydm_d::LoadVarOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto resultType =
@@ -472,7 +470,7 @@
if (!resultType)
return rewriter.notifyMatchFailure(
srcOp, "could not convert load_var result type");
- auto list = operands[0];
+ auto list = adaptor.getOperands()[0];
auto index1 =
rewriter.create<arith_d::ConstantOp>(loc, rewriter.getIndexAttr(1));
rewriter.replaceOpWithNewOp<iree_d::ListGetOp>(srcOp, resultType, list,
@@ -488,7 +486,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::NoneOp srcOp, ArrayRef<Value> operands,
+ pydm_d::NoneOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type i32 = rewriter.getI32Type();
rewriter.replaceOpWithNewOp<arith_d::ConstantOp>(
@@ -505,11 +503,11 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::RaiseOnFailureOp srcOp, ArrayRef<Value> operands,
+ pydm_d::RaiseOnFailureOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
- Value status = operands[0];
+ Value status = adaptor.getOperands()[0];
// Get the containing function return type so that we can create a
// suitable null return value.
auto parentFunc = srcOp->getParentOfType<builtin_d::FuncOp>();
@@ -545,13 +543,13 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::ReturnOp srcOp, ArrayRef<Value> operands,
+ pydm_d::ReturnOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
auto zeroResult = rewriter.create<arith_d::ConstantOp>(
loc, rewriter.getI32IntegerAttr(0));
rewriter.replaceOpWithNewOp<std_d::ReturnOp>(
- srcOp, ValueRange{zeroResult, operands[0]});
+ srcOp, ValueRange{zeroResult, adaptor.getOperands()[0]});
return success();
}
};
@@ -560,7 +558,7 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::StoreVarOp srcOp, ArrayRef<Value> operands,
+ pydm_d::StoreVarOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
@@ -571,8 +569,8 @@
"not a python type for value()");
int typeCode = static_cast<int>(origStoreType.getTypeCode());
- auto list = operands[0];
- auto newValue = operands[1];
+ auto list = adaptor.getOperands()[0];
+ auto newValue = adaptor.getOperands()[1];
resetObjectList(loc, rewriter, list, typeCode, newValue);
rewriter.eraseOp(srcOp);
return success();
@@ -583,10 +581,10 @@
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- pydm_d::UnboxOp srcOp, ArrayRef<Value> operands,
+ pydm_d::UnboxOp srcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = srcOp.getLoc();
- auto list = operands[0];
+ auto list = adaptor.getOperands()[0];
// Target exception result type.
Type statusType = getTypeConverter()->convertType(srcOp.status().getType());
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 97a1570..bcad20b 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 97a1570d8c31dc3bff12dd77b1ee824e1872bb69
+Subproject commit bcad20bc6591c8b503923402038c735a77373f99