blob: b2adbd9fce2b4741cabd740c8e3b79d5d2c8621c [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.h"
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "iree/compiler/Dialect/VMLA/vmla.imports.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Erases an op. This should only be used for ops that are legalized away
// as part of lowering (i.e. tagging or metadata ops that are unrepresentable
// in the VM dialect).
class EraseNonVMOp : public ConversionPattern {
 public:
  EraseNonVMOp(StringRef rootName, MLIRContext *context)
      : ConversionPattern(rootName, /*benefit=*/0, context) {}

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Unconditionally drop the op; any remaining uses must be handled by
    // other patterns before conversion completes.
    rewriter.eraseOp(op);
    return success();
  }
};
// When converting to the VM, it is safe to remove any identity tie_shape
// ops that remain.
class ElideTieShapeOp : public OpConversionPattern<Shape::TieShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      Shape::TieShapeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Forward the tied operand directly; the shape association carries no
    // meaning in the VM dialect.
    Value tiedOperand = operands[0];
    rewriter.replaceOp(op, tiedOperand);
    return success();
  }
};
// VMLA -> VM import conversion base for generic ops.
// Handles signatures with integers, VM types, or simple buffers.
//
// Subclasses may override getImportSuffix to select a type- or size-specific
// import variant (e.g. "vmla.add" -> "vmla.add.f32").
template <typename T, typename Adaptor = typename T::Adaptor>
class VMLAImportOpConversion : public OpConversionPattern<T> {
 public:
  // |importSymbols| must contain the vm.import ops for the VMLA module and
  // must outlive this pattern.
  VMLAImportOpConversion(MLIRContext *context, SymbolTable &importSymbols,
                         TypeConverter &typeConverter, StringRef importName)
      : OpConversionPattern<T>(context),
        importSymbols(importSymbols),
        typeConverter(typeConverter),
        importName(importName) {}

  LogicalResult matchAndRewrite(
      T op, llvm::ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Fully-qualified import name = base name + optional per-op suffix.
    std::string importFqName = importName + getImportSuffix(op);
    auto importOp =
        importSymbols.template lookup<IREE::VM::ImportOp>(importFqName);
    if (!importOp) {
      op.emitError() << "failed to resolve VM function import for "
                     << importFqName;
      return failure();
    }
    // NOTE: the redundant `assert(importOp)` that followed the early-return
    // above has been removed; it could never fire.
    return rewriteToCall(op, Adaptor{operands}, importOp, typeConverter,
                         rewriter);
  }

 protected:
  // Returns the suffix (including any leading '.') appended to the import
  // name for |op|. Default is no suffix.
  virtual std::string getImportSuffix(T op) const { return ""; }

  // Returns an "xNN" string for the bit width of |elementType|.
  // Assumes |elementType| is an int or float type.
  std::string getSizedTypeStr(Type elementType) const {
    int bitWidth = elementType.getIntOrFloatBitWidth();
    // Widen i1 -> i8 to match the VM type conversion.
    if (bitWidth == 1) {
      bitWidth = 8;
    }
    return "x" + std::to_string(bitWidth);
  }

  // Returns an "fNN"/"iNN"/"uNN"/"xNN" string for |type| (or its element
  // type when |type| is shaped). |forceUnsigned| turns "iNN" into "uNN".
  std::string getTypedTypeStr(Type type, bool forceUnsigned = false) const {
    Type elementType = type;
    auto shapedType = type.dyn_cast<ShapedType>();
    if (shapedType) {
      elementType = shapedType.getElementType();
    }

    std::string typePrefix = "x";
    if (elementType.isa<FloatType>()) {
      typePrefix = "f";
    } else if (elementType.isSignlessInteger()) {
      typePrefix = forceUnsigned ? "u" : "i";
    }
    int bitWidth = elementType.getIntOrFloatBitWidth();
    // Widen i1 -> i8 to match the VM type conversion.
    if (bitWidth == 1) {
      bitWidth = 8;
    }
    return typePrefix + std::to_string(bitWidth);
  }

 private:
  SymbolTable &importSymbols;
  TypeConverter &typeConverter;
  std::string importName;
};
// Registers a pattern lowering |op_type| to the VM import named
// |op_mnemonic| with no type/size suffix. Expects `patterns`, `context`,
// `importSymbols`, and `typeConverter` to be in scope at the expansion site.
#define VMLA_IMPORT_OP(op_type, op_mnemonic) \
patterns.insert<VMLAImportOpConversion<op_type>>( \
context, importSymbols, typeConverter, op_mnemonic);
// VMLA -> VM import conversion for ops using sized operands (foo.xNN).
// This will use only the bit-width of the element type to add a .xNN suffix
// to the op name. Assumes the element type is valid.
template <typename T>
class VMLASizedImportOpConversion : public VMLAImportOpConversion<T> {
 public:
  using VMLAImportOpConversion<T>::VMLAImportOpConversion;

  std::string getImportSuffix(T op) const override {
    // Suffix is purely size-based, e.g. ".x32".
    return "." + this->getSizedTypeStr(op.element_type());
  }
};
// Registers a pattern lowering |op_type| to the VM import named
// |op_mnemonic| with a ".xNN" bit-width suffix. Expects `patterns`,
// `context`, `importSymbols`, and `typeConverter` to be in scope.
#define VMLA_SIZED_IMPORT_OP(op_type, op_mnemonic) \
patterns.insert<VMLASizedImportOpConversion<op_type>>( \
context, importSymbols, typeConverter, op_mnemonic);
// VMLA -> VM import conversion for ops using typed operands (foo.fNN, etc).
// This will use the element type to add a type-specific suffix to the op
// name. Assumes the element type is valid.
template <typename T>
class VMLATypedImportOpConversion : public VMLAImportOpConversion<T> {
 public:
  using VMLAImportOpConversion<T>::VMLAImportOpConversion;

  std::string getImportSuffix(T op) const override {
    // Presence of the "forceUnsigned" attribute switches iNN -> uNN.
    auto *operation = static_cast<Operation *>(op);
    bool forceUnsigned =
        static_cast<bool>(operation->getAttr("forceUnsigned"));
    return "." + this->getTypedTypeStr(op.element_type(), forceUnsigned);
  }
};
// Registers a pattern lowering |op_type| to the VM import named
// |op_mnemonic| with a type-specific ".fNN"/".iNN"/".uNN" suffix. Expects
// `patterns`, `context`, `importSymbols`, and `typeConverter` to be in scope.
#define VMLA_TYPED_IMPORT_OP(op_type, op_mnemonic) \
patterns.insert<VMLATypedImportOpConversion<op_type>>( \
context, importSymbols, typeConverter, op_mnemonic);
// Lowers vmla.constant ops to rodata-backed buffers.
// Splat constants are encoded as a single element plus a runtime buffer
// fill; dense constants are embedded directly as rodata.
class VMLAConstantOpConversion
    : public OpConversionPattern<IREE::VMLA::ConstantOp> {
 public:
  VMLAConstantOpConversion(MLIRContext *context,
                           TypeConverter & /*typeConverter*/)
      : OpConversionPattern(context) {}

  LogicalResult matchAndRewrite(
      IREE::VMLA::ConstantOp op, llvm::ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    if (auto splatAttr = op.value().dyn_cast<SplatElementsAttr>()) {
      // Encode just a single splat element and use a buffer fill.
      auto rodataValue = rewriter.createOrFold<IREE::VM::RodataInlineOp>(
          op.getLoc(),
          IREE::VM::RefType::get(IREE::ByteBufferType::get(op.getContext())),
          DenseElementsAttr::get(
              RankedTensorType::get({1}, splatAttr.getSplatValue().getType()),
              splatAttr.getSplatValue()));
      auto fillValue = rewriter.createOrFold<IREE::VMLA::BufferConstOp>(
          op.getLoc(), IREE::VMLA::BufferType::get(op.getContext()),
          rodataValue);
      // Allocate the full-size buffer: element count * rounded byte width
      // (i1 widens to a full byte per the VM type conversion).
      auto bufferLengthValue = rewriter.createOrFold<mlir::ConstantIndexOp>(
          op.getLoc(), splatAttr.getType().cast<ShapedType>().getNumElements() *
                           VMLATypeConverter::getRoundedElementByteWidth(
                               splatAttr.getSplatValue().getType()));
      auto bufferValue = rewriter.createOrFold<IREE::VMLA::BufferAllocOp>(
          op.getLoc(), IREE::VMLA::BufferType::get(op.getContext()),
          bufferLengthValue);
      rewriter.create<IREE::VMLA::BufferFillOp>(op.getLoc(), fillValue,
                                                bufferValue);
      rewriter.replaceOp(op, bufferValue);
    } else {
      // Encode constant data into a rodata segment. These will eventually get
      // deduped and combined.
      auto rodataValue = rewriter.createOrFold<IREE::VM::RodataInlineOp>(
          op.getLoc(),
          IREE::VM::RefType::get(IREE::ByteBufferType::get(op.getContext())),
          op.value());
      rewriter.replaceOpWithNewOp<IREE::VMLA::BufferConstOp>(
          op, IREE::VMLA::BufferType::get(op.getContext()), rodataValue);
    }
    return success();
  }
  // NOTE: an unused allocateUniqueId helper (with mutable uniqueContext/
  // uniqueCounter members) was removed; nothing referenced it.
};
// Conversion for vmla.convert: the import variant is selected by both the
// source and destination element types, e.g. "vmla.convert.f32.i8".
class VMLAConvertImportOpConversion
    : public VMLAImportOpConversion<IREE::VMLA::ConvertOp> {
 public:
  using VMLAImportOpConversion<IREE::VMLA::ConvertOp>::VMLAImportOpConversion;

  std::string getImportSuffix(IREE::VMLA::ConvertOp op) const override {
    std::string suffix = ".";
    suffix += getTypedTypeStr(op.src_type());
    suffix += ".";
    suffix += getTypedTypeStr(op.dst_type());
    return suffix;
  }
};
// Conversion for vmla.batch.matmul: the import variant is selected by the
// lhs/rhs element types plus the destination element type, producing a
// suffix of the form ".f32f32.f32".
class VMLABatchMatMulImportOpConversion
    : public VMLAImportOpConversion<IREE::VMLA::BatchMatMulOp> {
 public:
  using VMLAImportOpConversion<
      IREE::VMLA::BatchMatMulOp>::VMLAImportOpConversion;

  std::string getImportSuffix(IREE::VMLA::BatchMatMulOp op) const override {
    std::string suffix = ".";
    suffix += getTypedTypeStr(op.lhs_type());
    suffix += getTypedTypeStr(op.rhs_type());
    suffix += ".";
    suffix += getTypedTypeStr(op.dst_type());
    return suffix;
  }
};
// Conversion for vmla.conv: the import variant is selected by the
// input/filter element types plus the destination element type, producing a
// suffix of the form ".f32f32.f32".
class VMLAConvImportOpConversion
    : public VMLAImportOpConversion<IREE::VMLA::ConvOp> {
 public:
  using VMLAImportOpConversion<IREE::VMLA::ConvOp>::VMLAImportOpConversion;

  std::string getImportSuffix(IREE::VMLA::ConvOp op) const override {
    std::string suffix = ".";
    suffix += getTypedTypeStr(op.input_type());
    suffix += getTypedTypeStr(op.filter_type());
    suffix += ".";
    suffix += getTypedTypeStr(op.dst_type());
    return suffix;
  }
};
// Conversion for vmla.fft ops: the import variant is selected by the real
// element type, e.g. "vmla.fft.f32".
class VMLAFftImportOpConversion
    : public VMLAImportOpConversion<IREE::VMLA::FftOp> {
 public:
  using VMLAImportOpConversion<IREE::VMLA::FftOp>::VMLAImportOpConversion;

  std::string getImportSuffix(IREE::VMLA::FftOp op) const override {
    std::string suffix = ".";
    suffix += getTypedTypeStr(op.real_element_type());
    return suffix;
  }
};
} // namespace
// Appends all patterns lowering VMLA dialect ops to calls on the VM import
// module. |importSymbols| must contain the vm.import ops for the VMLA module
// and must outlive the applied patterns.
void populateVMLAToVMPatterns(MLIRContext *context,
                              TypeConverter &typeConverter,
                              SymbolTable &importSymbols,
                              OwningRewritePatternList &patterns) {
  patterns.insert<VMLAConstantOpConversion>(context, typeConverter);
  // Shape metadata ops are unrepresentable in the VM dialect: const/make
  // ranked-shape ops are erased and identity tie_shape ops are forwarded.
  patterns.insert<EraseNonVMOp>(Shape::ConstRankedShapeOp::getOperationName(),
                                context);
  patterns.insert<EraseNonVMOp>(Shape::MakeRankedShapeOp::getOperationName(),
                                context);
  patterns.insert<ElideTieShapeOp>(context);
  // Buffer manipulation ops (untyped imports).
  VMLA_IMPORT_OP(IREE::VMLA::BufferConstOp, "vmla.buffer.const");
  VMLA_IMPORT_OP(IREE::VMLA::BufferAllocOp, "vmla.buffer.alloc");
  VMLA_IMPORT_OP(IREE::VMLA::BufferCloneOp, "vmla.buffer.clone");
  VMLA_IMPORT_OP(IREE::VMLA::BufferByteLengthOp, "vmla.buffer.byte_length");
  VMLA_IMPORT_OP(IREE::VMLA::BufferViewOp, "vmla.buffer.view");
  VMLA_IMPORT_OP(IREE::VMLA::BufferCopyOp, "vmla.buffer.copy");
  VMLA_IMPORT_OP(IREE::VMLA::BufferFillOp, "vmla.buffer.fill");
  VMLA_IMPORT_OP(IREE::VMLA::BufferLoadI32Op, "vmla.buffer.load.i32");
  // Comparison/selection ops.
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::CmpOp, "vmla.cmp");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::SelectOp, "vmla.select");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::FiniteOp, "vmla.finite");
  // Data movement/shaping ops (suffix depends only on element bit width).
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::CopyOp, "vmla.copy");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::TransposeOp, "vmla.transpose");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::ReverseOp, "vmla.reverse");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::PadOp, "vmla.pad");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::GatherOp, "vmla.gather");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::ScatterOp, "vmla.scatter");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::BroadcastOp, "vmla.broadcast");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::IotaOp, "vmla.iota");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::TileOp, "vmla.tile");
  // Bitwise/shift ops.
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::NotOp, "vmla.not");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::AndOp, "vmla.and");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::AndBroadcastOp, "vmla.and.broadcast");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::OrOp, "vmla.or");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::XorOp, "vmla.xor");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::XorBroadcastOp, "vmla.xor.broadcast");
  VMLA_SIZED_IMPORT_OP(IREE::VMLA::ShlOp, "vmla.shl");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ShrOp, "vmla.shr");
  // Element-wise arithmetic/math ops (suffix encodes element type).
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::AddOp, "vmla.add");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::SubOp, "vmla.sub");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::AbsOp, "vmla.abs");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::NegOp, "vmla.neg");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::MulOp, "vmla.mul");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::DivOp, "vmla.div");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::RemOp, "vmla.rem");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::PowOp, "vmla.pow");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ExpOp, "vmla.exp");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::LogOp, "vmla.log");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::RsqrtOp, "vmla.rsqrt");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::SqrtOp, "vmla.sqrt");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::CosOp, "vmla.cos");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::SinOp, "vmla.sin");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::TanhOp, "vmla.tanh");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::Atan2Op, "vmla.atan2");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::MinOp, "vmla.min");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::MaxOp, "vmla.max");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ClampOp, "vmla.clamp");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::FloorOp, "vmla.floor");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::CeilOp, "vmla.ceil");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::RoundOp, "vmla.round");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::SortOp, "vmla.sort");
  // Ops whose import suffix depends on multiple types use dedicated
  // conversion classes rather than the macros above.
  patterns.insert<VMLAConvertImportOpConversion>(context, importSymbols,
                                                 typeConverter, "vmla.convert");
  patterns.insert<VMLABatchMatMulImportOpConversion>(
      context, importSymbols, typeConverter, "vmla.batch.matmul");
  patterns.insert<VMLAConvImportOpConversion>(context, importSymbols,
                                              typeConverter, "vmla.conv");
  patterns.insert<VMLAFftImportOpConversion>(context, importSymbols,
                                             typeConverter, "vmla.fft");
  // Reduction and pooling ops.
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceSumOp, "vmla.reduce.sum");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMinOp, "vmla.reduce.min");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMaxOp, "vmla.reduce.max");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceAndOp, "vmla.reduce.and");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceOrOp, "vmla.reduce.or");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingSumOp, "vmla.pooling.sum");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMinOp, "vmla.pooling.min");
  VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMaxOp, "vmla.pooling.max");
  // Interface (I/O binding) ops.
  VMLA_IMPORT_OP(IREE::VMLA::InterfaceConstOp, "vmla.interface.const");
  VMLA_IMPORT_OP(IREE::VMLA::InterfaceBindingOp, "vmla.interface.binding");
}
namespace {
// A pass converting the IREE VMLA dialect into the IREE VM dialect.
// (The previous comment incorrectly said "flow dialect into VMLA".)
class ConvertVMLAToVMPass
    : public PassWrapper<ConvertVMLAToVMPass, OperationPass<ModuleOp>> {
 public:
  explicit ConvertVMLAToVMPass(IREE::VM::TargetOptions targetOptions)
      : targetOptions_(targetOptions) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREEDialect, IREE::VM::VMDialect>();
  }

  void runOnOperation() override {
    auto *context = &getContext();

    VMConversionTarget conversionTarget(context);
    IREE::VM::TypeConverter typeConverter(targetOptions_);

    mlir::ModuleOp outerModuleOp, innerModuleOp;
    std::tie(outerModuleOp, innerModuleOp) =
        VMConversionTarget::nestModuleForConversion(getOperation());

    // Append the vm.import declarations for the VMLA runtime module so the
    // patterns below can resolve them by symbol. Fetch the embedded imports
    // file once rather than calling vmla_imports_create() twice.
    auto *importsFile = vmla_imports_create();
    appendImportModule(StringRef(importsFile->data, importsFile->size),
                       innerModuleOp);

    OwningRewritePatternList conversionPatterns;
    populateStandardToVMPatterns(context, typeConverter, conversionPatterns);

    SymbolTable importSymbols(innerModuleOp);
    populateVMLAToVMPatterns(context, typeConverter, importSymbols,
                             conversionPatterns);

    // Usually shape conversion patterns come in at a higher level, but for
    // this standalone pass, they must be provided directly.
    Shape::populateFoldConversionPatterns(context, conversionPatterns);

    if (failed(applyPartialConversion(outerModuleOp, conversionTarget,
                                      std::move(conversionPatterns)))) {
      outerModuleOp.emitError() << "conversion to vm.module failed";
      return signalPassFailure();
    }
  }

 private:
  IREE::VM::TargetOptions targetOptions_;
};
} // namespace
// Creates a pass instance that lowers the VMLA dialect to the VM dialect
// using |targetOptions| for VM type conversion.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVMLAToVMPass(
    IREE::VM::TargetOptions targetOptions) {
  auto pass = std::make_unique<ConvertVMLAToVMPass>(targetOptions);
  return pass;
}
// Registers the pass with the global registry so it is reachable from
// command-line pass pipelines.
static PassRegistration<ConvertVMLAToVMPass> pass(
    "iree-convert-vmla-to-vm",
    "Convert the IREE VMLA dialect to the IREE VM dialect", [] {
      // Pull VM target options from flags at pass construction time.
      return std::make_unique<ConvertVMLAToVMPass>(
          IREE::VM::getTargetOptionsFromFlags());
    });
} // namespace iree_compiler
} // namespace mlir