blob: 451168805145865f90a40f6fe9d32dd45afb6e8a [file] [log] [blame] [edit]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/VMVX/Conversion/VMVXToVM/ConvertVMVXToVM.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Dialect/VMVX/IR/VMVXOps.h"
#include "iree/compiler/Dialect/VMVX/IR/VMVXTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::iree_compiler {
namespace {
// Conversion pattern that erases any op it is rooted on, without producing
// replacement values. Reserve it for ops that are legalized away as part of
// lowering (i.e. tagging or metadata ops that have no representation in the
// VM dialect).
class EraseNonVMOp : public ConversionPattern {
public:
  // `rootName` selects the op name this pattern matches; the pattern is
  // registered with the lowest benefit (0).
  EraseNonVMOp(StringRef rootName, MLIRContext *ctx)
      : ConversionPattern(rootName, /*benefit=*/0, ctx) {}

  // Always succeeds: the matched op is simply erased. The converted operands
  // are not needed for that, so they are ignored.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
// VMVX -> VM import conversion base for generic ops.
// Handles signatures with integers, VM types, or simple buffers.
// TODO: This is a big mess and doesn't support the generality we need. Redo
// it.
//
// Subclasses implement getImportFqName() to choose the fully-qualified name of
// the `vm.import` backing an op. The default emitCall() then lowers the op to a
// call of that import through rewriteToCall().
template <typename T>
class VMVXImportOpConversion : public OpConversionPattern<T> {
public:
  VMVXImportOpConversion(MLIRContext *context, SymbolTable &importSymbols,
                         TypeConverter &typeConverter)
      : OpConversionPattern<T>(context), importSymbols(importSymbols),
        typeConverter(typeConverter) {}

  // Looks up the import named by getImportFqName(op) in `importSymbols` and
  // replaces `op` with the values produced by emitCall().
  // Fails (emitting an error on `op`) when the import cannot be found, and
  // fails silently when emitCall() does not produce results.
  LogicalResult
  matchAndRewrite(T op, typename T::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::string importFqName = getImportFqName(op);
    auto importOp =
        importSymbols.template lookup<IREE::VM::ImportOp>(importFqName);
    if (!importOp) {
      op.emitError() << "failed to resolve VM function import for "
                     << importFqName;
      return failure();
    }
    auto results = emitCall(op, adaptor, importOp, rewriter);
    if (!results.has_value())
      return failure();
    rewriter.replaceOp(op, results.value());
    return success();
  }

protected:
  // Returns the fully-qualified `vm.import` name implementing `op`.
  virtual std::string getImportFqName(T op) const = 0;

  // Emits the call to `importOp` for `op` and returns the replacement values,
  // or std::nullopt on failure. The default forwards to rewriteToCall() with
  // the type converter supplied at construction.
  virtual std::optional<SmallVector<Value>>
  emitCall(T op, typename T::Adaptor adaptor, IREE::VM::ImportOp importOp,
           ConversionPatternRewriter &rewriter) const {
    return rewriteToCall(op, adaptor, importOp, typeConverter, rewriter);
  }

  // Returns "x<bits>" for `elementType` (or for its element type when it is a
  // shaped type). Only the storage width is encoded. The element type must be
  // an integer or float type (see Type::getIntOrFloatBitWidth).
  std::string getSizedTypeStr(Type elementType) const {
    return "x" + std::to_string(getStorageBitWidth(unwrapElementType(elementType)));
  }

  // Returns "<prefix><bits>" for `type` (or for its element type when it is a
  // shaped type). The prefix is:
  //   "f" for float types,
  //   "i" for signless integers, or "u" when `forceUnsigned` is set,
  //   "x" otherwise.
  // The element type must be an integer or float type.
  std::string getTypedTypeStr(Type type, bool forceUnsigned = false) const {
    Type elementType = unwrapElementType(type);
    std::string typePrefix = "x";
    if (isa<FloatType>(elementType)) {
      typePrefix = "f";
    } else if (elementType.isSignlessInteger()) {
      typePrefix = forceUnsigned ? "u" : "i";
    }
    return typePrefix + std::to_string(getStorageBitWidth(elementType));
  }

private:
  // Returns the element type of a ShapedType, or `type` unchanged otherwise.
  static Type unwrapElementType(Type type) {
    if (auto shapedType = dyn_cast<ShapedType>(type))
      return shapedType.getElementType();
    return type;
  }

  // Returns the bit width used in import names for the int/float `type`.
  // i1 is widened to 8 bits to match the VM type conversion. This single
  // helper keeps the sized and typed naming schemes in agreement.
  static int getStorageBitWidth(Type type) {
    int bitWidth = type.getIntOrFloatBitWidth();
    return bitWidth == 1 ? 8 : bitWidth;
  }

  SymbolTable &importSymbols;
  TypeConverter &typeConverter;
};
// Lowers vmvx.binary to the import `vmvx.<opcode>.<rank>d.<type>`.
// The rank comes from the number of LHS strides. The type suffix comes from
// getTypedTypeStr(), e.g. "f32" or "i32".
class BinaryOpConversion : public VMVXImportOpConversion<IREE::VMVX::BinaryOp> {
public:
  using VMVXImportOpConversion::VMVXImportOpConversion;

  std::string getImportFqName(IREE::VMVX::BinaryOp op) const override {
    auto opcode = op.getOpcode();
    const int rank = op.getLhsStrides().size();
    return "vmvx." + std::string(opcode.begin(), opcode.end()) + "." +
           std::to_string(rank) + "d." + getTypedTypeStr(op.getElementType());
  }
};
// Converts the vmvx.copy op to an appropriate typed import:
// `vmvx.copy.<rank>d.x<bits>`. The rank is the number of input strides. The
// element is described only by its bit width (getSizedTypeStr).
class CopyOpConversion : public VMVXImportOpConversion<IREE::VMVX::CopyOp> {
public:
  using VMVXImportOpConversion::VMVXImportOpConversion;

  std::string getImportFqName(IREE::VMVX::CopyOp op) const override {
    const int rank = op.getInStrides().size();
    return "vmvx.copy." + std::to_string(rank) + "d." +
           getSizedTypeStr(op.getElementType());
  }
};
// Converts the vmvx.fill2d op to an appropriate typed import:
// `vmvx.fill.2d.x<bits>`. The width is that of the scalar fill value.
class Fill2DOpConversion : public VMVXImportOpConversion<IREE::VMVX::Fill2DOp> {
public:
  using VMVXImportOpConversion::VMVXImportOpConversion;

  std::string getImportFqName(IREE::VMVX::Fill2DOp op) const override {
    const Type scalarType = op.getScalar().getType();
    return "vmvx.fill.2d." + getSizedTypeStr(scalarType);
  }
};
// Lowers vmvx.unary to the import `vmvx.<opcode>.<rank>d.<type>`.
// The rank is the number of input strides. The type suffix comes from
// getTypedTypeStr().
class UnaryOpConversion : public VMVXImportOpConversion<IREE::VMVX::UnaryOp> {
public:
  using VMVXImportOpConversion::VMVXImportOpConversion;

  std::string getImportFqName(IREE::VMVX::UnaryOp op) const override {
    auto opcode = op.getOpcode();
    const int rank = op.getInStrides().size();
    return "vmvx." + std::string(opcode.begin(), opcode.end()) + "." +
           std::to_string(rank) + "d." + getTypedTypeStr(op.getElementType());
  }
};
} // namespace
// Registers the patterns that lower VMVX ops (binary, copy, fill2d, unary) to
// calls of VM imports looked up in `importSymbols`. `conversionTarget` is
// accepted for interface compatibility but is not modified here.
void populateVMVXToVMPatterns(MLIRContext *context,
                              ConversionTarget &conversionTarget,
                              TypeConverter &typeConverter,
                              SymbolTable &importSymbols,
                              RewritePatternSet &patterns) {
  patterns.add<BinaryOpConversion, CopyOpConversion, Fill2DOpConversion,
               UnaryOpConversion>(context, importSymbols, typeConverter);
}
} // namespace mlir::iree_compiler