Add HLO/Standard op -> VMLA op conversion.
For the majority of ops this is performed automatically via op traits such as 'IncludeShapes', which indicates that the runtime wants shapes for all buffers passed as arguments. A few conversions that do not map 1:1 to a VMLA op (and will need HLO->HLO rewrites first) will come in future changes.
PiperOrigin-RevId: 294359962
diff --git a/iree/compiler/Dialect/VMLA/Conversion/BUILD b/iree/compiler/Dialect/VMLA/Conversion/BUILD
index 910024c..51f5aa0 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/BUILD
@@ -29,6 +29,7 @@
],
deps = [
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/VMLA/IR",
"//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
index 6d86d85..314da4e 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
@@ -27,6 +27,7 @@
"TypeConverter.cpp"
DEPS
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::VMLA::IR
iree::compiler::Dialect::VMLA::IR::VMLADialect
MLIRIR
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index 85d0c6d..4a9e315 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -15,8 +15,14 @@
#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLATraits.h"
+#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
@@ -31,7 +37,6 @@
// The VMLA dialect expects both standard ops and the VMLA ops (in case some
// conversion has already happened).
addLegalOp<ModuleOp, ModuleTerminatorOp>();
- addLegalDialect<StandardOpsDialect>();
addLegalDialect<IREE::VMLA::VMLADialect>();
// Allow other ops to pass through so long as their type is valid (not a
@@ -54,10 +59,182 @@
// static
LogicalResult VMLAConversionTarget::applyDefaultBufferRewrite(
- Operation *srcOp, ArrayRef<Value> operands, StringRef dstOpName,
+ Operation *srcOp, ArrayRef<Value> operands, VMLAOpSemantics semantics,
+ StringRef dstOpName, TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
+ OperationState state{srcOp->getLoc(), dstOpName};
+ state.addAttributes(srcOp->getAttrs());
+
+ auto *dstOperation = state.name.getAbstractOperation();
+ auto *opInterface = dstOperation->getInterface<IREE::VMLA::VMLAOp>();
+
+ // Allow the op to get at any of the type information it requires. For
+ // example, if the op may later need to know the type of the elements in a
+ // type-erased buffer it can stash the original tensor type as an attribute.
+ if (opInterface) {
+ opInterface->extractTypeAttributes(
+ state, llvm::to_vector<4>(srcOp->getOperandTypes()),
+ llvm::to_vector<4>(srcOp->getResultTypes()));
+ }
+
+ // Until MLIR supports unsigned types we need to sidechannel this to the
+ // VMLA->VM conversion that really needs to know.
+ switch (semantics) {
+ default:
+ break;
+ case VMLAOpSemantics::kForceUnsigned:
+ state.addAttribute("force_unsigned", UnitAttr::get(srcOp->getContext()));
+ break;
+ }
+
+ // Add all input operands.
+ for (auto srcDstOperand : llvm::zip(srcOp->getOperands(), operands)) {
+ auto srcOperand = std::get<0>(srcDstOperand);
+ auto dstOperand = std::get<1>(srcDstOperand);
+ if (auto tensorType =
+ srcOperand.getType().template dyn_cast<TensorType>()) {
+ // Some ops also require shape information.
+ state.addOperands({dstOperand});
+ if (dstOperation->hasTrait<OpTrait::IREE::VMLA::IncludeShapes>()) {
+ Value operandShape = getTensorShape(srcOp->getLoc(), srcOperand,
+ typeConverter, rewriter);
+ if (!operandShape) {
+ return srcOp->emitError() << "failed to get operand tensor shape";
+ }
+ state.addOperands({operandShape});
+ }
+ } else {
+ // Normal pass-through operand.
+ state.addOperands({dstOperand});
+ }
+ }
+
+ // Allocate output buffers for tensors returned by the op. We'll append these
+ // to the operands in order (as is convention here).
+ SmallVector<Value, 4> allocatedBuffers;
+ for (auto srcResult : srcOp->getResults()) {
+ if (auto tensorType = srcResult.getType().template dyn_cast<TensorType>()) {
+ auto dstBuffer = allocateOutputBuffer(srcOp->getLoc(), srcResult,
+ typeConverter, rewriter);
+ if (!dstBuffer) {
+ return srcOp->emitError()
+ << "failed to allocate output buffer for tensor result";
+ }
+ state.addOperands({dstBuffer});
+ allocatedBuffers.push_back(dstBuffer);
+ } else {
+ // Normal pass-through result.
+ state.addTypes({srcResult.getType()});
+ }
+ }
+
+ // Rebuild the result list and replace the op ensuring that all original op
+ // results are represented in order even if we changed them to out params.
+ auto *dstOp = rewriter.createOperation(state);
+ auto dstResults = llvm::to_vector<4>(dstOp->getResults());
+ SmallVector<Value, 4> resultValues;
+ for (auto resultType : srcOp->getResultTypes()) {
+ if (resultType.template isa<TensorType>()) {
+ resultValues.push_back(allocatedBuffers.front());
+ allocatedBuffers.erase(allocatedBuffers.begin());
+ } else {
+ resultValues.push_back(dstResults.front());
+ dstResults.erase(dstResults.begin());
+ }
+ }
+ rewriter.replaceOp(srcOp, resultValues);
+ return success();
+}
+
+// static
+Value VMLAConversionTarget::getTensorShape(
+ Location loc, Value originalValue, TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
+ // TODO(benvanik): use tie_shape to find the ranked shape to use for the
+ // originalValue tensor.
+ auto originalType = originalValue.getType().cast<ShapedType>();
+ return rewriter.createOrFold<Shape::ConstRankedShapeOp>(
+ loc, Shape::RankedShapeType::get(originalType.getShape(),
+ rewriter.getIntegerType(32)));
+}
+
+// static
+Value VMLAConversionTarget::getBufferOffset(
+ Location loc, Value tensorValue, Value indicesValue,
TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
- // TODO(benvanik): implement rewriting.
- return failure();
+ auto indicesType = indicesValue.getType().cast<ShapedType>();
+ SmallVector<Value, 4> indices(indicesType.getNumElements());
+ for (int i = 0; i < indicesType.getNumElements(); ++i) {
+ auto extractIndex = rewriter.createOrFold<mlir::ConstantOp>(
+ loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(i));
+ indices[i] = rewriter.createOrFold<mlir::ExtractElementOp>(
+ loc, indicesValue, ValueRange{extractIndex});
+ }
+ return getBufferOffset(loc, tensorValue, indices, typeConverter, rewriter);
+}
+
+// static
+Value VMLAConversionTarget::getBufferOffset(
+ Location loc, Value tensorValue, ValueRange indices,
+ TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+ // Element type byte length as the base.
+ auto tensorType = tensorValue.getType().cast<ShapedType>();
+ auto elementType = tensorType.getElementType();
+ auto elementSize = rewriter.createOrFold<mlir::ConstantOp>(
+ loc, rewriter.getIntegerType(32),
+ rewriter.getI32IntegerAttr(
+ VMLATypeConverter::getRoundedElementByteWidth(elementType)));
+
+ auto shape = getTensorShape(loc, tensorValue, typeConverter, rewriter);
+ Value offset = rewriter.createOrFold<mlir::ConstantOp>(
+ loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
+ for (int i = 0; i < tensorType.getRank(); ++i) {
+ auto axisOffset = indices[i];
+ for (int j = i + 1; j < tensorType.getRank(); ++j) {
+ auto dim = rewriter.createOrFold<Shape::RankedDimOp>(loc, shape, j);
+ axisOffset = rewriter.createOrFold<mlir::MulIOp>(loc, axisOffset, dim);
+ }
+ offset = rewriter.createOrFold<mlir::AddIOp>(loc, offset, axisOffset);
+ }
+ return rewriter.createOrFold<mlir::MulIOp>(loc, offset, elementSize);
+}
+
+// static
+Value VMLAConversionTarget::getBufferLength(
+ Location loc, Value tensorValue, TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
+ // Element type byte length as the base.
+ auto tensorType = tensorValue.getType().cast<ShapedType>();
+ auto elementType = tensorType.getElementType();
+ auto elementSize = rewriter.createOrFold<mlir::ConstantOp>(
+ loc, rewriter.getIntegerType(32),
+ rewriter.getI32IntegerAttr(
+ VMLATypeConverter::getRoundedElementByteWidth(elementType)));
+
+ auto shape = getTensorShape(loc, tensorValue, typeConverter, rewriter);
+ auto dims = rewriter.create<Shape::RankedDimsOp>(loc, shape);
+ Value length = elementSize;
+ for (auto dim : dims.getResults()) {
+ length = rewriter.createOrFold<mlir::MulIOp>(loc, length, dim);
+ }
+ return length;
+}
+
+// static
+Value VMLAConversionTarget::allocateOutputBuffer(
+ Location loc, Value originalValue, TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
+ // Compute the required buffer size. Since we are always dense (right now)
+ // this is just normal x*y*z*...
+ Value byteLength =
+ getBufferLength(loc, originalValue, typeConverter, rewriter);
+
+ // Allocate the buffer of the required size.
+ // The caller can then use the buffer instead of the original SSA value.
+ return rewriter.createOrFold<IREE::VMLA::BufferAllocOp>(
+ loc,
+ IREE::RefPtrType::get(IREE::VMLA::BufferType::get(rewriter.getContext())),
+ byteLength);
}
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
index 7ab9a22..ab65502 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
@@ -18,6 +18,7 @@
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
@@ -26,6 +27,12 @@
namespace mlir {
namespace iree_compiler {
+enum class VMLAOpSemantics {
+ kDefault = 0,
+ // Forces integers to be treated as unsigned integers.
+ kForceUnsigned,
+};
+
// A conversion target for the VMLA dialect that ensures that tensor types are
// fully removed. Conversions targeting the VMLA dialect should always use this.
class VMLAConversionTarget : public ConversionTarget {
@@ -35,8 +42,34 @@
// Attempts to rewrite an op that may use tensor values into an op using VMLA
// buffers. See VMLAOpConversion for more information.
static LogicalResult applyDefaultBufferRewrite(
- Operation *srcOp, ArrayRef<Value> operands, StringRef dstOpName,
- TypeConverter &typeConverter, ConversionPatternRewriter &rewriter);
+ Operation *srcOp, ArrayRef<Value> operands, VMLAOpSemantics semantics,
+ StringRef dstOpName, TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
+
+ // Returns the shape of the |originalValue| tensor as an SSA ranked shape.
+ static Value getTensorShape(Location loc, Value originalValue,
+ TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
+
+ // Returns the offset, in bytes, of an index within a linearized dense buffer.
+ static Value getBufferOffset(Location loc, Value tensorValue,
+ Value indicesValue, TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
+ static Value getBufferOffset(Location loc, Value tensorValue,
+ ValueRange indices, TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
+
+ // Returns the length, in bytes, of a linearized dense buffer.
+ static Value getBufferLength(Location loc, Value tensorValue,
+ TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
+
+ // Allocates a VMLA buffer for an output operand of an op.
+ // Returns a buffer allocated with the appropriate size for storing the value.
+ // Callers must replace uses of |originalValue| with the returned value.
+ static Value allocateOutputBuffer(Location loc, Value originalValue,
+ TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter);
private:
bool isDynamicallyLegal(Operation *op) const override;
@@ -52,7 +85,8 @@
// will be IREE::VMLA::BufferTypes. Any static information available about the
// tensor (such as static dimensions, element type, layout, etc) are extracted
// here and lowered as expanded values.
-template <typename SRC, typename DST>
+template <typename SRC, typename DST,
+ VMLAOpSemantics semantics = VMLAOpSemantics::kDefault>
class VMLAOpConversion : public OpConversionPattern<SRC> {
public:
VMLAOpConversion(MLIRContext *context, TypeConverter &typeConverter)
@@ -62,7 +96,7 @@
SRC srcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (succeeded(VMLAConversionTarget::applyDefaultBufferRewrite(
- srcOp, operands, DST::getOperationName(), typeConverter,
+ srcOp, operands, semantics, DST::getOperationName(), typeConverter,
rewriter))) {
return OpConversionPattern<SRC>::matchSuccess();
}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
index 9d2d5e2..3d6de77 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
@@ -27,6 +27,8 @@
],
deps = [
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/Shape/IR",
+ "//iree/compiler/Dialect/VMLA/Conversion",
"//iree/compiler/Dialect/VMLA/IR",
"//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
index 3bb8e13..7ced493 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
@@ -23,6 +23,8 @@
"ConvertHLOToVMLA.cpp"
DEPS
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::Shape::IR
+ iree::compiler::Dialect::VMLA::Conversion
iree::compiler::Dialect::VMLA::IR
iree::compiler::Dialect::VMLA::IR::VMLADialect
MLIRIR
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 9e9d5c7..1b19138 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -15,6 +15,8 @@
#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
@@ -24,13 +26,47 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
namespace mlir {
namespace iree_compiler {
+namespace {
+
+// Clones operand[0] and returns the result.
+// This models the value semantics of XLA. We expect previous passes to elide
+// identity ops when possible and only check for trivial single use ops here.
+template <typename SRC>
+struct IdentityOpConversion : public OpConversionPattern<SRC> {
+ using OpConversionPattern<SRC>::OpConversionPattern;
+
+ PatternMatchResult matchAndRewrite(
+ SRC srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (srcOp.getOperand().hasOneUse()) {
+ // Can directly pass through the input buffer as we don't need to clone
+ // for other users.
+ rewriter.replaceOp(srcOp, operands[0]);
+ return this->matchSuccess();
+ } else {
+ // More than one user of the operand exist and we need to ensure they
+ // keep a valid snapshot of the buffer.
+ rewriter.replaceOpWithNewOp<IREE::VMLA::BufferCloneOp>(
+ srcOp,
+ IREE::RefPtrType::get(
+ IREE::VMLA::BufferType::get(rewriter.getContext())),
+ operands[0]);
+ return this->matchSuccess();
+ }
+ }
+};
+
+} // namespace
+
void populateHLOToVMLAPatterns(MLIRContext *context,
OwningRewritePatternList &patterns,
TypeConverter &typeConverter) {
@@ -40,7 +76,87 @@
xla_hlo::PopulateXlaToStdPatterns(&patterns, context);
xla_hlo::PopulateUnfuseBatchNormPatterns(context, &patterns);
- // TODO(benvanik): conversion patterns.
+ // Simple 1:1 conversion patterns using the automated trait-based converter.
+ // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
+ patterns.insert<VMLAOpConversion<xla_hlo::AddOp, IREE::VMLA::AddOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::SubOp, IREE::VMLA::SubOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::DivOp, IREE::VMLA::DivOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::MulOp, IREE::VMLA::MulOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::PowOp, IREE::VMLA::PowOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::RemOp, IREE::VMLA::RemOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::ShiftLeftOp, IREE::VMLA::ShlOp>>(
+ context, typeConverter);
+ patterns.insert<
+ VMLAOpConversion<xla_hlo::ShiftRightArithmeticOp, IREE::VMLA::ShrOp>>(
+ context, typeConverter);
+ patterns
+ .insert<VMLAOpConversion<xla_hlo::ShiftRightLogicalOp, IREE::VMLA::ShrOp,
+ VMLAOpSemantics::kForceUnsigned>>(context,
+ typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::AndOp, IREE::VMLA::AndOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::OrOp, IREE::VMLA::OrOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::XorOp, IREE::VMLA::XorOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::CopyOp, IREE::VMLA::BufferCloneOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::ExpOp, IREE::VMLA::ExpOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::LogOp, IREE::VMLA::LogOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::FloorOp, IREE::VMLA::FloorOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::RsqrtOp, IREE::VMLA::RsqrtOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::SqrtOp, IREE::VMLA::SqrtOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::CosOp, IREE::VMLA::CosOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::SinOp, IREE::VMLA::SinOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::TanhOp, IREE::VMLA::TanhOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::Atan2Op, IREE::VMLA::Atan2Op>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::SelectOp, IREE::VMLA::SelectOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::ConvertOp, IREE::VMLA::ConvertOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::ReverseOp, IREE::VMLA::ReverseOp>>(
+ context, typeConverter);
+ patterns
+ .insert<VMLAOpConversion<xla_hlo::TransposeOp, IREE::VMLA::TransposeOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::PadOp, IREE::VMLA::PadOp>>(
+ context, typeConverter);
+ patterns
+ .insert<VMLAOpConversion<xla_hlo::BroadcastOp, IREE::VMLA::BroadcastOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::AbsOp, IREE::VMLA::AbsOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::MaxOp, IREE::VMLA::MaxOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::MinOp, IREE::VMLA::MinOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::ClampOp, IREE::VMLA::ClampOp>>(
+ context, typeConverter);
+ patterns.insert<VMLAOpConversion<xla_hlo::DotOp, IREE::VMLA::MatMulOp>>(
+ context, typeConverter);
+
+ // Ops that are only used for type information that we erase. We can elide
+ // these entirely by just passing on their input values.
+ patterns.insert<IdentityOpConversion<xla_hlo::BitcastConvertOp>>(context);
+ patterns.insert<IdentityOpConversion<xla_hlo::ReshapeOp>>(context);
+
+ // TODO(benvanik): add missing ops:
+ // - ConvOp
}
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/constant_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/constant_ops.mlir
deleted file mode 100644
index deda099..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/constant_ops.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @todo
-func @todo() {
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
index deda099..8773ad5 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
@@ -1,6 +1,35 @@
// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-// CHECK-LABEL: @todo
-func @todo() {
- return
+// CHECK-LABEL: @abs_scalar
+func @abs_scalar(%arg0 : tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: [[BUF_SZ:%.+]] = constant 4
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.abs"(%arg0, [[BUF]]) {element_type = f32}
+ %0 = "xla_hlo.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @abs_tensor
+func @abs_tensor(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ // CHECK-NEXT: [[BUF_SZ:%.+]] = constant 16
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.abs"(%arg0, [[BUF]]) {element_type = f32}
+ %0 = "xla_hlo.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp
+func @clamp(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> {
+ // CHECK-NEXT: [[BUF_SZ:%.+]] = constant 16
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.clamp"(%arg0, %arg1, %arg2, [[BUF]]) {element_type = f32}
+ %0 = "xla_hlo.clamp"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<4xf32>
}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reshape.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reshape.mlir
new file mode 100644
index 0000000..792f086
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reshape.mlir
@@ -0,0 +1,18 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @reshape_bypass
+func @reshape_bypass(%arg0 : tensor<3x2xi32>) -> tensor<6xi32> {
+ // CHECK-NEXT: return %arg0
+ %0 = "xla_hlo.reshape"(%arg0) : (tensor<3x2xi32>) -> tensor<6xi32>
+ return %0 : tensor<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_copy
+func @reshape_copy(%arg0 : tensor<3x2xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
+ // CHECK-NEXT: %0 = "vmla.buffer.clone"(%arg0)
+ %0 = "xla_hlo.reshape"(%arg0) : (tensor<3x2xi32>) -> tensor<6xi32>
+ // CHECK-NEXT: return %arg0, %0
+ return %arg0, %0 : tensor<3x2xi32>, tensor<6xi32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/shaping_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/shaping_ops.mlir
deleted file mode 100644
index deda099..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/shaping_ops.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @todo
-func @todo() {
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/view_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/view_ops.mlir
deleted file mode 100644
index deda099..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/view_ops.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @todo
-func @todo() {
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
index 2fb18e4..8497968 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
@@ -27,6 +27,7 @@
],
deps = [
"//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/VMLA/Conversion",
"//iree/compiler/Dialect/VMLA/IR",
"//iree/compiler/Dialect/VMLA/IR:VMLADialect",
"@llvm-project//mlir:IR",
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
index 1e58dc8..672faae 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
@@ -23,6 +23,7 @@
"ConvertStandardToVMLA.cpp"
DEPS
iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::VMLA::Conversion
iree::compiler::Dialect::VMLA::IR
iree::compiler::Dialect::VMLA::IR::VMLADialect
MLIRIR
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
index c996f46..277131c 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h"
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
+#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
@@ -30,10 +31,208 @@
namespace mlir {
namespace iree_compiler {
+namespace {
+
+struct ConstantOpConversion
+ : public VMLAOpConversion<mlir::ConstantOp, IREE::VMLA::BufferConstOp> {
+ using VMLAOpConversion::VMLAOpConversion;
+
+ PatternMatchResult matchAndRewrite(
+ mlir::ConstantOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto value = srcOp.value().dyn_cast<ElementsAttr>();
+ if (!value) return matchFailure();
+ rewriter.replaceOpWithNewOp<IREE::VMLA::ConstantOp>(srcOp, value);
+ return matchSuccess();
+ }
+};
+
+struct CmpIOpConversion
+ : public VMLAOpConversion<mlir::CmpIOp, IREE::VMLA::CmpOp> {
+ using VMLAOpConversion::VMLAOpConversion;
+
+ PatternMatchResult matchAndRewrite(
+ mlir::CmpIOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto inputType = srcOp.lhs().getType().dyn_cast<ShapedType>();
+ if (!inputType) return matchFailure();
+
+ IREE::VMLA::CmpPredicate predicate = IREE::VMLA::CmpPredicate::EQ;
+ bool forceUnsigned = false;
+ switch (srcOp.predicate()) {
+ case CmpIPredicate::eq:
+ predicate = IREE::VMLA::CmpPredicate::EQ;
+ break;
+ case CmpIPredicate::ne:
+ predicate = IREE::VMLA::CmpPredicate::NE;
+ break;
+ case CmpIPredicate::slt:
+ predicate = IREE::VMLA::CmpPredicate::LT;
+ break;
+ case CmpIPredicate::sle:
+ predicate = IREE::VMLA::CmpPredicate::LE;
+ break;
+ case CmpIPredicate::sgt:
+ predicate = IREE::VMLA::CmpPredicate::GT;
+ break;
+ case CmpIPredicate::sge:
+ predicate = IREE::VMLA::CmpPredicate::GE;
+ break;
+ case CmpIPredicate::ult:
+ predicate = IREE::VMLA::CmpPredicate::LT;
+ forceUnsigned = true;
+ break;
+ case CmpIPredicate::ule:
+ predicate = IREE::VMLA::CmpPredicate::LE;
+ forceUnsigned = true;
+ break;
+ case CmpIPredicate::ugt:
+ predicate = IREE::VMLA::CmpPredicate::GT;
+ forceUnsigned = true;
+ break;
+ case CmpIPredicate::uge:
+ predicate = IREE::VMLA::CmpPredicate::GE;
+ forceUnsigned = true;
+ break;
+ default:
+ llvm_unreachable("unhandled comparison predicate");
+ return matchFailure();
+ }
+
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+ auto newOp = rewriter.create<IREE::VMLA::CmpOp>(
+ srcOp.getLoc(), static_cast<uint32_t>(predicate), operands[0],
+ operands[1], dst, TypeAttr::get(inputType.getElementType()));
+ if (forceUnsigned) {
+ newOp.setAttr("force_unsigned", UnitAttr::get(rewriter.getContext()));
+ }
+ rewriter.replaceOp(srcOp, newOp.dst());
+ return matchSuccess();
+ }
+};
+
+class CmpFOpConversion
+ : public VMLAOpConversion<mlir::CmpFOp, IREE::VMLA::CmpOp> {
+ public:
+ using VMLAOpConversion::VMLAOpConversion;
+
+ PatternMatchResult matchAndRewrite(
+ mlir::CmpFOp srcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto inputType = srcOp.lhs().getType().dyn_cast<ShapedType>();
+ if (!inputType) return matchFailure();
+
+ // NOTE: the std.cmpf semantics are practically undefined. We explicitly
+ // match the HLO semantics (that get lowered to the expected case values
+ // here). In the future as new ML-focused intermediate dialects are built we
+ // can reevaluate what we support here.
+ //
+ // Rules:
+ // https://stackoverflow.com/questions/8627331/what-does-ordered-unordered-comparison-mean
+ IREE::VMLA::CmpPredicate predicate = IREE::VMLA::CmpPredicate::EQ;
+ switch (srcOp.getPredicate()) {
+ case CmpFPredicate::OEQ:
+ predicate = IREE::VMLA::CmpPredicate::EQ;
+ break;
+ case CmpFPredicate::UNE:
+ predicate = IREE::VMLA::CmpPredicate::NE;
+ break;
+ case CmpFPredicate::OLT:
+ predicate = IREE::VMLA::CmpPredicate::LT;
+ break;
+ case CmpFPredicate::OLE:
+ predicate = IREE::VMLA::CmpPredicate::LE;
+ break;
+ case CmpFPredicate::OGT:
+ predicate = IREE::VMLA::CmpPredicate::GT;
+ break;
+ case CmpFPredicate::OGE:
+ predicate = IREE::VMLA::CmpPredicate::GE;
+ break;
+ default:
+ llvm_unreachable("unhandled comparison predicate");
+ return matchFailure();
+ }
+
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+ auto newOp = rewriter.create<IREE::VMLA::CmpOp>(
+ srcOp.getLoc(), static_cast<uint32_t>(predicate), operands[0],
+ operands[1], dst, TypeAttr::get(inputType.getElementType()));
+ rewriter.replaceOp(srcOp, newOp.dst());
+ return matchSuccess();
+ }
+};
+
+} // namespace
+
void populateStandardToVMLAPatterns(MLIRContext *context,
                                    OwningRewritePatternList &patterns,
                                    TypeConverter &typeConverter) {
-  // TODO(benvanik): conversion patterns.
+  // Patterns with custom rewrite logic.
+  patterns.insert<ConstantOpConversion, CmpIOpConversion, CmpFOpConversion>(
+      context, typeConverter);
+
+  // Direct 1:1 std op -> VMLA op conversions. All patterns share the same
+  // constructor arguments, so register them in a single variadic insert.
+  patterns.insert<
+      VMLAOpConversion<mlir::AddIOp, IREE::VMLA::AddOp>,
+      VMLAOpConversion<mlir::AddFOp, IREE::VMLA::AddOp>,
+      VMLAOpConversion<mlir::SubIOp, IREE::VMLA::SubOp>,
+      VMLAOpConversion<mlir::SubFOp, IREE::VMLA::SubOp>,
+      VMLAOpConversion<mlir::MulIOp, IREE::VMLA::MulOp>,
+      VMLAOpConversion<mlir::MulFOp, IREE::VMLA::MulOp>,
+      VMLAOpConversion<mlir::SignedDivIOp, IREE::VMLA::DivOp>,
+      VMLAOpConversion<mlir::DivFOp, IREE::VMLA::DivOp>,
+      VMLAOpConversion<mlir::AbsFOp, IREE::VMLA::AbsOp>,
+      VMLAOpConversion<mlir::SignedRemIOp, IREE::VMLA::RemOp>,
+      VMLAOpConversion<mlir::RemFOp, IREE::VMLA::RemOp>,
+      VMLAOpConversion<mlir::LogOp, IREE::VMLA::LogOp>,
+      VMLAOpConversion<mlir::ExpOp, IREE::VMLA::ExpOp>,
+      VMLAOpConversion<mlir::SqrtOp, IREE::VMLA::SqrtOp>,
+      VMLAOpConversion<mlir::CosOp, IREE::VMLA::CosOp>,
+      VMLAOpConversion<mlir::TanhOp, IREE::VMLA::TanhOp>,
+      VMLAOpConversion<mlir::NegFOp, IREE::VMLA::NegOp>,
+      VMLAOpConversion<mlir::AndOp, IREE::VMLA::AndOp>,
+      VMLAOpConversion<mlir::OrOp, IREE::VMLA::OrOp>,
+      VMLAOpConversion<mlir::XOrOp, IREE::VMLA::XorOp>,
+      VMLAOpConversion<mlir::ShiftLeftOp, IREE::VMLA::ShlOp>,
+      VMLAOpConversion<mlir::SignedShiftRightOp, IREE::VMLA::ShrOp>,
+      VMLAOpConversion<mlir::CeilFOp, IREE::VMLA::CeilOp>,
+      VMLAOpConversion<mlir::SelectOp, IREE::VMLA::SelectOp>>(context,
+                                                              typeConverter);
+
+  // Unsigned std ops map to the same VMLA ops with unsigned semantics forced
+  // via the force_unsigned attribute.
+  // TODO(benvanik): remove once unsigned types are in MLIR.
+  patterns.insert<
+      VMLAOpConversion<mlir::UnsignedDivIOp, IREE::VMLA::DivOp,
+                       VMLAOpSemantics::kForceUnsigned>,
+      VMLAOpConversion<mlir::UnsignedRemIOp, IREE::VMLA::RemOp,
+                       VMLAOpSemantics::kForceUnsigned>,
+      VMLAOpConversion<mlir::UnsignedShiftRightOp, IREE::VMLA::ShrOp,
+                       VMLAOpSemantics::kForceUnsigned>>(context,
+                                                         typeConverter);
}
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/comparison_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/comparison_ops.mlir
new file mode 100644
index 0000000..e3b94c0
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/comparison_ops.mlir
@@ -0,0 +1,23 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @cmp_i
+func @cmp_i(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi1> {
+ // CHECK: [[BUF_SZ:%.+]] = constant 4
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.cmp"(%arg0, %arg1, [[BUF]]) {element_type = i32, predicate = 5 : i32}
+ %0 = cmpi "sge", %arg0, %arg1 : tensor<4xi32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @cmp_f
+func @cmp_f(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xi1> {
+ // CHECK: [[BUF_SZ:%.+]] = constant 4
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.cmp"(%arg0, %arg1, [[BUF]]) {element_type = f32, predicate = 5 : i32}
+ %0 = cmpf "oge", %arg0, %arg1 : tensor<4xf32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<4xi1>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/constant_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/constant_ops.mlir
index deda099..14fb8b7 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/constant_ops.mlir
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/constant_ops.mlir
@@ -1,6 +1,17 @@
// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-// CHECK-LABEL: @todo
-func @todo() {
- return
+// CHECK-LABEL: @constant_scalar
+func @constant_scalar() -> tensor<i16> {
+ // CHECK: = "vmla.constant"() {value = dense<12345> : tensor<i16>}
+ %0 = constant dense<12345> : tensor<i16>
+ return %0 : tensor<i16>
+}
+
+// -----
+
+// CHECK-LABEL: @constant_tensor
+func @constant_tensor() -> tensor<4xf32> {
+ // CHECK: = "vmla.constant"() {value = dense<[-1.000000e+00, -2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32>}
+ %0 = constant dense<[-1.0, -2.0, 3.0, 4.0]> : tensor<4xf32>
+ return %0 : tensor<4xf32>
}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/math_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/math_ops.mlir
new file mode 100644
index 0000000..13c3786
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/math_ops.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: @absf
+func @absf(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ // CHECK-NEXT: [[BUF_SZ:%.+]] = constant 16
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.abs"(%arg0, [[BUF]]) {element_type = f32}
+ %0 = absf %arg0 : tensor<4xf32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @shr_signed
+func @shr_signed(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK-NEXT: [[BUF_SZ:%.+]] = constant 16
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.shr"(%arg0, %arg0, [[BUF]]) {element_type = i32}
+ %0 = shift_right_signed %arg0, %arg0 : tensor<4xi32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @shr_unsigned
+func @shr_unsigned(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK-NEXT: [[BUF_SZ:%.+]] = constant 16
+ // CHECK-NEXT: [[BUF:%.+]] = "vmla.buffer.alloc"([[BUF_SZ]])
+ // CHECK-NEXT: "vmla.shr"(%arg0, %arg0, [[BUF]]) {element_type = i32, force_unsigned}
+ %0 = shift_right_unsigned %arg0, %arg0 : tensor<4xi32>
+ // CHECK-NEXT: return [[BUF]]
+ return %0 : tensor<4xi32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.cpp
index 06b47ce..0f74cf0 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.cpp
@@ -26,6 +26,9 @@
// TODO(benvanik): composite-type conversion (buffer + dynamic dims).
return IREE::RefPtrType::get(
IREE::VMLA::BufferType::get(type.getContext()));
+ } else if (type.isInteger(1)) {
+ // Widen i1 to i8.
+ return IntegerType::get(8, type.getContext());
}
return type;
}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h b/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h
index bbbc8b2..2788748 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h
+++ b/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h
@@ -22,6 +22,12 @@
class VMLATypeConverter : public TypeConverter {
public:
+ // Returns the number of bytes an element of the given type occupies
+ // post-conversion. For example, the size of i1 would be '1 byte'.
+ static int32_t getRoundedElementByteWidth(Type type) {
+ return (type.getIntOrFloatBitWidth() + 8 - 1) / 8;
+ }
+
Type convertType(Type type) override;
// TODO(benvanik): signature conversion for output buffers.
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index a638d7d..2bbeb9c 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -161,6 +161,7 @@
VMLA_SIZED_IMPORT_OP(IREE::VMLA::TransposeOp, "vmla.transpose");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::ReverseOp, "vmla.reverse");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::PadOp, "vmla.pad");
+ VMLA_SIZED_IMPORT_OP(IREE::VMLA::BroadcastOp, "vmla.broadcast");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::TileOp, "vmla.tile");
VMLA_SIZED_IMPORT_OP(IREE::VMLA::NotOp, "vmla.not");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLABase.td b/iree/compiler/Dialect/VMLA/IR/VMLABase.td
index 1c5868c..bf325ae 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLABase.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLABase.td
@@ -149,7 +149,7 @@
VMLA_Op<mnemonic, !listconcat(traits, [VMLA_OpInterface])> {
let extraClassDeclaration = [{
static void extractTypeAttributes(OperationState &state, ArrayRef<Type> operandTypes, ArrayRef<Type> resultTypes) {
- state.addAttribute("element_type", TypeAttr::get(resultTypes[0]));
+ state.addAttribute("element_type", TypeAttr::get(resultTypes[0].cast<ShapedType>().getElementType()));
}
}];
}
@@ -177,4 +177,17 @@
);
}
+class VMLA_TernaryOp<string mnemonic, Attr typeAttr, list<OpTrait> traits = []> :
+ VMLA_ElementTypeOp<mnemonic, traits> {
+ let arguments = (ins
+ RefPtrOf<VMLA_Buffer>:$a,
+ RefPtrOf<VMLA_Buffer>:$b,
+ RefPtrOf<VMLA_Buffer>:$c,
+ RefPtrOf<VMLA_Buffer>:$dst,
+ typeAttr:$element_type,
+ // TODO(benvanik): remove once unsigned types are in MLIR. NOTE(review): the conversion code and tests use attribute name "force_unsigned"; confirm this ODS arg name ($forceUnsigned) yields the same attribute name.
+ UnitAttr:$forceUnsigned
+ );
+}
+
#endif // IREE_DIALECT_VMLA_BASE
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index 857b57e..cfaf98c 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -168,6 +168,16 @@
);
}
+def VMLA_BroadcastOp : VMLA_ElementTypeOp<"broadcast", [VMLA_IncludeShapes]> {
+ let arguments = (ins
+ VMLA_BufferRef:$src,
+ VMLA_Shape:$src_shape,
+ VMLA_BufferRef:$dst,
+ VMLA_Shape:$dst_shape,
+ VMLA_AnyTypeAttr:$element_type
+ );
+}
+
def VMLA_TileOp : VMLA_ElementTypeOp<"tile", [VMLA_IncludeShapes]> {
let arguments = (ins
VMLA_BufferRef:$src,
@@ -212,6 +222,7 @@
def VMLA_MinOp : VMLA_BinaryOp<"min", VMLA_AnyTypeAttr>;
def VMLA_MaxOp : VMLA_BinaryOp<"max", VMLA_AnyTypeAttr>;
+def VMLA_ClampOp : VMLA_TernaryOp<"clamp", VMLA_AnyTypeAttr>;
def VMLA_FloorOp : VMLA_UnaryOp<"floor", VMLA_FloatTypeAttr>;
def VMLA_CeilOp : VMLA_UnaryOp<"ceil", VMLA_FloatTypeAttr>;
diff --git a/iree/compiler/Dialect/VMLA/README.md b/iree/compiler/Dialect/VMLA/README.md
new file mode 100644
index 0000000..679f01f
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/README.md
@@ -0,0 +1,11 @@
+# VMLA (Virtual Machine-based Linear Algebra)
+
+This dialect is designed to closely model XLA HLO ops in a way that is easy to
+map to execution on the IREE VM. The changes involve using byte buffers instead
+of tensors, propagating shape information and converting shape math to simple
+integer arithmetic, and legalizing types to supported values (such as 1-bit
+bools to 8-bit integers of 0 or 1).
+
+## Adding an Op
+
+TODO(benvanik): document and show an example change.
diff --git a/iree/compiler/Dialect/VMLA/Transforms/BUILD b/iree/compiler/Dialect/VMLA/Transforms/BUILD
index a6049a5..9a90ca4 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/BUILD
+++ b/iree/compiler/Dialect/VMLA/Transforms/BUILD
@@ -27,6 +27,8 @@
"Passes.h",
],
deps = [
+ "//iree/compiler/Dialect/IREE/Transforms",
+ "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/VMLA/Conversion",
"//iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA",
"//iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA",
@@ -34,6 +36,7 @@
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
diff --git a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
index a5ef099..a9c2c5f 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
@@ -23,6 +23,8 @@
"Conversion.cpp"
"Passes.cpp"
DEPS
+ iree::compiler::Dialect::IREE::Transforms
+ iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::VMLA::Conversion
iree::compiler::Dialect::VMLA::Conversion::HLOToVMLA
iree::compiler::Dialect::VMLA::Conversion::StandardToVMLA
@@ -30,6 +32,7 @@
LLVMSupport
MLIRIR
MLIRPass
+ MLIRStandardOps
MLIRSupport
MLIRTransforms
tensorflow::mlir_xla
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
index 6de0381..0519ba1 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
#include "iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h"
#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -38,12 +40,18 @@
VMLATypeConverter typeConverter;
VMLAConversionTarget conversionTarget(context, typeConverter);
+ // Ensure all HLO goes away.
conversionTarget.addIllegalDialect<xla_hlo::XlaHloDialect>();
+ conversionTarget.addLegalDialect<ShapeDialect>();
OwningRewritePatternList conversionPatterns;
populateStandardToVMLAPatterns(context, conversionPatterns, typeConverter);
populateHLOToVMLAPatterns(context, conversionPatterns, typeConverter);
+ // Ensure FuncOp signatures are updated.
+ populateFuncOpTypeConversionPattern(conversionPatterns, context,
+ typeConverter);
+
if (failed(applyPartialConversion(getOperation(), conversionTarget,
conversionPatterns, &typeConverter))) {
getOperation().emitError() << "conversion to the VMLA dialect failed";
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
index 3ebb4d6..0ab1c4f 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
#include <memory>
+#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@@ -40,6 +41,9 @@
// TODO(benvanik): legalize input.
// passManager.addPass(IREE::VMLA::createLegalizeInputTypesPass());
+ // TODO(benvanik): preserve these hints during conversion.
+ passManager.addNestedPass<FuncOp>(createDropCompilerHintsPass());
+
// Convert from the various input dialects to the VMLA dialect.
passManager.addPass(createConversionPass());
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 9d4a01e..d699ff5 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -114,6 +114,19 @@
%interior_padding : i32 ...
)
+vm.import @broadcast.x8(
+ %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ...
+)
+vm.import @broadcast.x16(
+ %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ...
+)
+vm.import @broadcast.x32(
+ %src : !iree.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ...
+)
+
vm.import @tile.x8(
%src : !iree.ref<!vmla.buffer>, %src_shape : i32 ...,
%dst : !iree.ref<!vmla.buffer>, %dst_shape : i32 ...
@@ -201,6 +214,10 @@
vm.import @max.i16(%lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
vm.import @max.i32(%lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
vm.import @max.f32(%lhs : !iree.ref<!vmla.buffer>, %rhs : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
+vm.import @clamp.i8(%min : !iree.ref<!vmla.buffer>, %value : !iree.ref<!vmla.buffer>, %max : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
+vm.import @clamp.i16(%min : !iree.ref<!vmla.buffer>, %value : !iree.ref<!vmla.buffer>, %max : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
+vm.import @clamp.i32(%min : !iree.ref<!vmla.buffer>, %value : !iree.ref<!vmla.buffer>, %max : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
+vm.import @clamp.f32(%min : !iree.ref<!vmla.buffer>, %value : !iree.ref<!vmla.buffer>, %max : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
vm.import @floor.f32(%src : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
vm.import @ceil.f32(%src : !iree.ref<!vmla.buffer>, %dst : !iree.ref<!vmla.buffer>)
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index 4eaed92..f311eb9 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -299,8 +299,8 @@
struct Clamp {
template <typename T>
- static Status Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> min_buffer,
+ static Status Execute(absl::Span<const T> min_buffer,
+ absl::Span<const T> src_buffer,
absl::Span<const T> max_buffer,
absl::Span<T> dst_buffer);
};
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index 52fee3d..ea5978c 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -539,8 +539,8 @@
}
template <typename T>
-Status Clamp::Execute(absl::Span<const T> src_buffer,
- absl::Span<const T> min_buffer,
+Status Clamp::Execute(absl::Span<const T> min_buffer,
+ absl::Span<const T> src_buffer,
absl::Span<const T> max_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index c20e77b..b2deeda 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -439,6 +439,17 @@
IREE_VMLA_PAD_OP(PadX16, uint16_t);
IREE_VMLA_PAD_OP(PadX32, uint32_t);
+#define IREE_VMLA_BROADCAST_OP(name, type) \
+ Status name(vm::ref<iree_vmla_buffer_t>& src, iree_vmla_shape_t src_shape, \
+ vm::ref<iree_vmla_buffer_t>& dst, iree_vmla_shape_t dst_shape) { \
+ IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
+ return kernels::Broadcast::Execute<type>(src->As<type>(), \
+ dst->As<type>()); \
+ }
+ IREE_VMLA_BROADCAST_OP(BroadcastX8, uint8_t);
+ IREE_VMLA_BROADCAST_OP(BroadcastX16, uint16_t);
+ IREE_VMLA_BROADCAST_OP(BroadcastX32, uint32_t);
+
#define IREE_VMLA_TILE_OP(name, type) \
Status name(vm::ref<iree_vmla_buffer_t>& src, iree_vmla_shape_t src_shape, \
vm::ref<iree_vmla_buffer_t>& dst, iree_vmla_shape_t dst_shape) { \
@@ -532,6 +543,10 @@
IREE_VMLA_BINARY_OP(MaxI16, kernels::Max, int16_t);
IREE_VMLA_BINARY_OP(MaxI32, kernels::Max, int32_t);
IREE_VMLA_BINARY_OP(MaxF32, kernels::Max, float);
+ IREE_VMLA_TERNARY_OP(ClampI8, kernels::Clamp, int8_t);
+ IREE_VMLA_TERNARY_OP(ClampI16, kernels::Clamp, int16_t);
+ IREE_VMLA_TERNARY_OP(ClampI32, kernels::Clamp, int32_t);
+ IREE_VMLA_TERNARY_OP(ClampF32, kernels::Clamp, float);
IREE_VMLA_UNARY_OP(FloorF32, kernels::Floor, float);
IREE_VMLA_UNARY_OP(CeilF32, kernels::Ceil, float);