// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===---------------------------------------------------------------------===//
// Pass to materialize tensor encodings based on target information.
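//
// The pass rewrites `set_encoding` / `unset_encoding` ops into `tensor.pack`
// / `tensor.unpack` ops whose tile sizes come from the target-specific
// `MaterializeEncodingFn`, and converts ops that consume encoded tensors
// (contractions, fills, elementwise generics, dispatch tensor loads/stores)
// onto the materialized, packed layouts.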
//===---------------------------------------------------------------------===//
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
namespace mlir::iree_compiler {
using namespace IREE::LinalgExt;
using IREE::HAL::ExecutableTargetAttr;
using IREE::LinalgExt::getEncodingAttr;
//===---------------------------------------------------------------------===//
// Utility methods
//===---------------------------------------------------------------------===//
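/// Drops the encoding from the result type and clones `op` with the converted
/// operands.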
static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
ValueRange convertedInputOperands,
ValueRange convertedOutputOperands) {
SmallVector<Value> operands;
operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
operands.append(convertedOutputOperands.begin(),
convertedOutputOperands.end());
return mlir::clone(builder, op,
{dropEncoding(cast<RankedTensorType>(
convertedOutputOperands[0].getType()))},
operands);
}
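/// Returns the inner tile sizes of `materializeEncodingInfo` as
/// `OpFoldResult`s, querying `materializeEncodingValueFn` for any tile size
/// that is dynamic.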
static FailureOr<SmallVector<OpFoldResult>>
getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
RankedTensorType tensorType,
const MaterializeEncodingInfo &materializeEncodingInfo,
MaterializeEncodingValueFn materializeEncodingValueFn) {
ArrayRef<int64_t> staticTileSizes = materializeEncodingInfo.innerTileSizes;
if (llvm::all_of(staticTileSizes,
[](int64_t i) { return !ShapedType::isDynamic(i); })) {
return getAsOpFoldResult(rewriter.getI64ArrayAttr(staticTileSizes));
}
assert(materializeEncodingValueFn &&
"When dynamic tile sizes are generated, a MaterializeEncodingValueFn "
"should be provided.");
FailureOr<MaterializeEncodingValueInfo> materializeEncodingValueInfo =
materializeEncodingValueFn(tensorType, rewriter, loc);
if (failed(materializeEncodingValueInfo)) {
return failure();
}
ArrayRef<Value> innerTileSizeValues =
materializeEncodingValueInfo->innerTileSizes;
SmallVector<OpFoldResult> result(staticTileSizes.size());
for (size_t i = 0; i < result.size(); ++i) {
if (ShapedType::isDynamic(staticTileSizes[i])) {
result[i] = innerTileSizeValues[i];
} else if (tensorType.isDynamicDim(i)) {
result[i] =
rewriter.create<arith::ConstantIndexOp>(loc, staticTileSizes[i])
.getResult();
} else {
result[i] = rewriter.getI64IntegerAttr(staticTileSizes[i]);
}
}
return result;
}
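/// Returns the expanded type for a vector operand of a (batched)
/// matvec/vecmat, inserting unit dims so the operand matches the canonical
/// (batch_)mmt4d operand shape. Populates `ri` with the reassociation indices
/// for the `tensor.expand_shape`.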
static RankedTensorType getExpandedType(RankedTensorType type, bool isBatched,
bool isTransposed,
SmallVectorImpl<ReassociationIndices> &ri) {
if (!isBatched) {
ri.assign({{0, 1}, {2, 3}});
if (!isTransposed) {
return RankedTensorType::get(
{1, type.getDimSize(0), 1, type.getDimSize(1)},
type.getElementType());
}
return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
type.getElementType());
}
ri.assign({{0}, {1, 2}, {3, 4}});
if (!isTransposed) {
return RankedTensorType::get(
{type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
type.getElementType());
}
return RankedTensorType::get(
{type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
type.getElementType());
}
/// Given an input Value and a desired output element type, create an
/// element-wise linalg::GenericOp that zero-extends (`arith.extui`) the input
/// to the output element type, and return its result.
static Value createElementWiseExtUIOp(RewriterBase &rewriter, Value input,
Location loc, Type outElemType) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<AffineMap> maps(
2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
utils::IteratorType::parallel);
auto castedType = inputType.clone(outElemType);
SmallVector<OpFoldResult> inputMixedSizes =
tensor::getMixedSizes(rewriter, loc, input);
Value init =
rewriter.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
return rewriter
.create<linalg::GenericOp>(
loc, castedType, input, init, maps, iteratorTypes,
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
Value castRes =
b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
->getResult(0);
b.create<linalg::YieldOp>(nestedLoc, castRes);
})
.getResult(0);
}
/// If needed, expand the input Value and return the result with the canonical
/// mmt4d input shape. If the input element type is unsigned, create a producer
/// linalg::GenericOp on the input that unsigned-extends the input to the
/// output element type. This extension is required to preserve the
/// unsignedness information on the input for ukernels.
static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
RewriterBase &rewriter,
SmallVectorImpl<ReassociationIndices> &ri,
ArrayRef<Type> elemTypes, int operandIdx) {
assert(linalgOp.getNumDpsInputs() == 2);
assert(linalgOp.getNumDpsInits() == 1);
auto cDims = linalg::inferContractionDims(linalgOp);
Location loc = linalgOp->getLoc();
Value expandedValue = value;
// For a vecmat, any operand other than the RHS is a vector (similarly, for a
// matvec, any operand other than the LHS) and must be expanded to the
// canonical mmt4d shape.
if ((cDims->m.empty() && operandIdx != 1) ||
(cDims->n.empty() && operandIdx != 0)) {
auto type = cast<RankedTensorType>(value.getType());
RankedTensorType newType = getExpandedType(
type, /*isBatched=*/!cDims->batch.empty(),
/*isTransposed=*/operandIdx == 2 && cDims->n.empty(), ri);
expandedValue =
rewriter.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
}
if (elemTypes[operandIdx].isUnsignedInteger()) {
return createElementWiseExtUIOp(rewriter, expandedValue, loc,
elemTypes.back());
}
return expandedValue;
}
//===---------------------------------------------------------------------===//
// Methods to convert `set_encoding` and `unset_encoding` operations
// to `pack` and `unpack` operations respectively.
//===---------------------------------------------------------------------===//
/// Utility method to get the optional padding value to use with the pack
/// operation if the source is defined using a `tensor.pad` operation. Note
/// `source` is passed by reference; it is updated to use the source of the
/// pad operation.
static std::optional<Value> getPaddingValue(Value &source) {
auto padOp = source.getDefiningOp<tensor::PadOp>();
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) {
return std::nullopt;
}
Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue) {
return std::nullopt;
}
source = padOp.getSource();
return constantPaddingValue;
}
/// Utility method to convert a `set_encoding` op to a `pack` operation. The
/// padding value is either taken from a producer `tensor.pad` or materialized
/// as a zero constant. The source is taken as input so that these can be used
/// with `OpConversionPattern`s.
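///
/// A sketch of the rewrite (the encoding attribute and the 8x4 inner tiles
/// are illustrative; actual tile sizes are chosen per target):
///   %0 = iree_linalg_ext.set_encoding %src
///       : tensor<?x?xf32> -> tensor<?x?xf32, #encoding>
/// becomes
///   %0 = tensor.pack %src padding_value(%zero : f32)
///       inner_dims_pos = [0, 1] inner_tiles = [8, 4]
///       into %empty : tensor<?x?xf32> -> tensor<?x?x8x4xf32>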
static FailureOr<tensor::PackOp> lowerSetEncodingOpToPackOp(
RewriterBase &rewriter, SetEncodingOp encodingOp, Value source,
MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType resultType = encodingOp.getResultType();
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(resultType);
if (failed(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
}
// Create `tensor.empty` operation for the result of the pack operation.
Location loc = encodingOp.getLoc();
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
auto encoding = getEncodingAttr(resultType);
if (!encoding) {
return failure();
}
std::optional<Value> paddingValue;
if (encoding.getRoundDimsToArray().empty()) {
paddingValue = getPaddingValue(source);
} else {
paddingValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(resultType.getElementType()));
}
SmallVector<OpFoldResult> sourceDims =
tensor::getMixedSizes(rewriter, loc, source);
SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
rewriter, loc, sourceDims, *innerTileSizesOfr,
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
resultType.getElementType());
return rewriter.create<tensor::PackOp>(
loc, source, emptyOp, materializeEncodingInfo->innerDimsPos,
*innerTileSizesOfr, paddingValue, materializeEncodingInfo->outerDimsPerm);
}
/// Utility method to convert an `unset_encoding` op to an `unpack` operation.
/// The source is taken as input so that these can be used with
/// `OpConversionPattern`s.
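///
/// A sketch of the rewrite (tile sizes illustrative):
///   %0 = iree_linalg_ext.unset_encoding %packed
///       : tensor<?x?xf32, #encoding> -> tensor<?x?xf32>
/// becomes
///   %0 = tensor.unpack %packed
///       inner_dims_pos = [0, 1] inner_tiles = [8, 4]
///       into %empty : tensor<?x?x8x4xf32> -> tensor<?x?xf32>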
static FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
RewriterBase &rewriter, UnsetEncodingOp encodingOp, Value packedValue,
MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType sourceType = encodingOp.getSourceType();
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(sourceType);
if (failed(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
}
// Create a `tensor.empty` op for the result of the unpack operation.
Location loc = encodingOp.getLoc();
SmallVector<OpFoldResult> resultDims =
tensor::getMixedSizes(rewriter, loc, encodingOp.getSource());
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
sourceType.getElementType());
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
getInnerTileSizesOfr(rewriter, loc, sourceType, *materializeEncodingInfo,
materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
return rewriter.create<tensor::UnPackOp>(
loc, packedValue, emptyOp, materializeEncodingInfo->innerDimsPos,
*innerTileSizesOfr, materializeEncodingInfo->outerDimsPerm);
}
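/// Utility method to lower an `upper_bound_tile_size` op to constants. Inner
/// tile sizes that remain dynamic after materialization are bounded by 16,
/// and dims without an inner tile get a tile size of 1 so no padding is
/// introduced for them.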
static FailureOr<SmallVector<Value>> lowerUpperBoundTileSizeOpToConstants(
RewriterBase &rewriter, UpperBoundTileSizeOp upperBoundTileSizeOp,
MaterializeEncodingFn materializeEncodingFn) {
Location loc = upperBoundTileSizeOp.getLoc();
RankedTensorType tensorType = upperBoundTileSizeOp.getTensorType();
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(tensorType);
if (failed(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(upperBoundTileSizeOp,
"unhandled source encoding");
}
ArrayRef<int64_t> innerTileSizes = materializeEncodingInfo->innerTileSizes;
ArrayRef<int64_t> innerDimsPos = materializeEncodingInfo->innerDimsPos;
SmallVector<Value> results(tensorType.getRank());
for (unsigned i = 0; i < innerTileSizes.size(); ++i) {
int64_t tileSize = innerTileSizes[i];
if (ShapedType::isDynamic(tileSize)) {
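// 16 is a conservative upper bound for tile sizes that are only known at
// runtime.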
tileSize = 16;
}
results[innerDimsPos[i]] =
rewriter.create<arith::ConstantIndexOp>(loc, tileSize);
}
// For the dims that have no inner tiles, use 1 as tile size to avoid padding.
for (unsigned i = 0; i < results.size(); ++i) {
if (!results[i]) {
results[i] = rewriter.create<arith::ConstantIndexOp>(loc, 1);
}
}
return results;
}
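/// Utility method to convert a contraction linalg op whose operands carry
/// encodings into a `linalg.mmt4d` (or `linalg.batch_mmt4d`) op on the
/// materialized operands, e.g. (shapes illustrative):
///   linalg.matmul ins(%lhs, %rhs) outs(%acc)   // operands carry encodings
/// becomes
///   linalg.mmt4d ins(%lhs, %rhs : tensor<?x?x8x4xf32>, tensor<?x?x8x4xf32>)
///       outs(%acc : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32>
/// If the encoding cannot be materialized for the target, the encodings are
/// dropped and the op is cloned on the converted operands instead.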
static FailureOr<Operation *>
lowerContractionOpWithEncoding(RewriterBase &rewriter,
linalg::LinalgOp linalgOp, ValueRange operands,
MaterializeEncodingFn materializeEncodingFn) {
if (!linalgOp.hasPureTensorSemantics())
return failure();
auto inputs = linalgOp.getDpsInputOperands();
auto outputs = linalgOp.getDpsInits();
auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
auto resultType = cast<RankedTensorType>(outputs[0].getType());
auto lhsEncoding = getEncodingAttr(lhsType);
auto rhsEncoding = getEncodingAttr(rhsType);
auto resultEncoding = getEncodingAttr(resultType);
if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
return failure();
}
if (lhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS ||
rhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS ||
resultEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RESULT) {
return failure();
}
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(getOriginalTypeWithEncoding(
cast<RankedTensorType>(linalgOp->getResultTypes()[0])));
Operation *result;
if (failed(materializeEncodingInfo)) {
result = dropEncodingAndCloneOp(rewriter, linalgOp,
operands.take_front(inputs.size()),
operands.drop_front(inputs.size()));
} else {
auto elemTypes = llvm::map_to_vector(
lhsEncoding.getElementTypes().getValue(),
[](Attribute a) { return cast<TypeAttr>(a).getValue(); });
SmallVector<ReassociationIndices> ri;
Value newLhs =
getMmt4dOperand(operands[0], linalgOp, rewriter, ri, elemTypes,
/*operandIdx=*/0);
Value newRhs =
getMmt4dOperand(operands[1], linalgOp, rewriter, ri, elemTypes,
/*operandIdx=*/1);
Value newResult =
getMmt4dOperand(operands[2], linalgOp, rewriter, ri, elemTypes,
/*operandIdx=*/2);
Type newResultType = newResult.getType();
auto cDims = IREE::LinalgExt::getEncodingContractionDims(lhsEncoding);
if (cDims->batch.empty()) {
result = rewriter.create<linalg::Mmt4DOp>(
linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
ValueRange{newResult});
} else {
result = rewriter.create<linalg::BatchMmt4DOp>(
linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
ValueRange{newResult});
}
if (!ri.empty()) {
result = rewriter.create<tensor::CollapseShapeOp>(
linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
}
}
return result;
}
/// Utility method to convert `tensor.empty` with encoding to a `tensor.empty`
/// of the materialized type.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
ValueRange convertedOperands,
MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
auto emptyType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
auto resultType =
getOriginalTypeWithEncoding(emptyType).clone(emptyType.getElementType());
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(resultType);
Location loc = emptyOp.getLoc();
if (failed(materializeEncodingInfo)) {
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
loc, emptyOp.getMixedSizes(), resultType.getElementType());
return newEmptyOp;
}
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
emptyOp, "failed to generate runtime tile size query");
}
SmallVector<OpFoldResult> sourceDims = emptyOp.getMixedSizes();
(void)foldDynamicIndexList(sourceDims);
SmallVector<OpFoldResult> newShape =
tensor::PackOp::getResultShape(rewriter, loc, sourceDims, *innerTileSizesOfr,
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm);
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
loc, newShape, resultType.getElementType());
return newEmptyOp;
}
/// Utility method to convert from a linalg::LinalgOp on `tensor` types with
/// encodings to a linalg::LinalgOp on the materialized type. The currently
/// supported op types are:
/// - linalg::LinalgOp that `isaContractionOpInterface`
/// - linalg::FillOp
/// - element-wise linalg::GenericOp with single input and output
static FailureOr<Operation *> lowerOpWithEncoding(
RewriterBase &rewriter, linalg::LinalgOp linalgOp,
ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
MaterializeEncodingFn materializeEncodingFn, MaterializeEncodingValueFn) {
if (linalg::isaContractionOpInterface(linalgOp)) {
SmallVector<Value> operands;
operands.append(convertedInputOperands.begin(),
convertedInputOperands.end());
operands.append(convertedOutputOperands.begin(),
convertedOutputOperands.end());
return lowerContractionOpWithEncoding(rewriter, linalgOp, operands,
materializeEncodingFn);
}
return TypeSwitch<Operation *, FailureOr<Operation *>>(linalgOp)
.Case<linalg::FillOp>(
[&](linalg::FillOp fillOp) -> FailureOr<Operation *> {
if (!fillOp.hasPureTensorSemantics())
return failure();
Operation *materializedFillOp = rewriter.create<linalg::FillOp>(
fillOp.getLoc(), convertedOutputOperands[0].getType(),
convertedInputOperands, convertedOutputOperands);
return materializedFillOp;
})
.Case<linalg::GenericOp>([&](linalg::GenericOp genericOp)
-> FailureOr<Operation *> {
if (!genericOp.hasPureTensorSemantics() || !isElementwise(genericOp) ||
genericOp.getNumDpsInputs() != 1 ||
genericOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(
genericOp, "linalg.generic op is not elementwise "
"with single input and single output");
}
if (!llvm::all_of(genericOp.getIndexingMapsArray(),
[](AffineMap m) { return m.isIdentity(); })) {
return rewriter.notifyMatchFailure(
genericOp, "indexing maps are not all identity maps");
}
auto convertedResultType =
cast<RankedTensorType>(convertedOutputOperands[0].getType());
SmallVector<AffineMap> maps(
2, AffineMap::getMultiDimIdentityMap(convertedResultType.getRank(),
rewriter.getContext()));
SmallVector<utils::IteratorType> iteratorTypes(
convertedResultType.getRank(), utils::IteratorType::parallel);
auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
genericOp.getLoc(), convertedResultType, convertedInputOperands,
convertedOutputOperands, maps, iteratorTypes,
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
rewriter.inlineRegionBefore(genericOp.getRegion(),
materializedGenericOp.getRegion(),
materializedGenericOp.getRegion().begin());
return materializedGenericOp.getOperation();
})
.Default([](Operation *op) { return failure(); });
}
/// For a `dispatchTensorType` that binds a `RankedTensorType` with encoding,
/// returns the materialized shape of the `dispatchTensorType`. The
/// dynamic dimensions of the `dispatchTensorType` are provided in
/// `dynamicDims`.
static FailureOr<SmallVector<OpFoldResult>> getPackedDimsForDispatchTensor(
OpBuilder &builder, Location loc,
const MaterializeEncodingTypeConverter &typeConverter,
IREE::Flow::DispatchTensorType dispatchTensorType, ValueRange dynamicDims,
MaterializeEncodingValueFn materializeEncodingValueFn) {
auto boundTensorType =
llvm::dyn_cast<RankedTensorType>(dispatchTensorType.getBoundType());
if (!boundTensorType) {
return failure();
}
RankedTensorType originalTensorType =
getOriginalTypeWithEncoding(boundTensorType);
MaterializeEncodingFn materializeEncodingFn =
typeConverter.getMaterializeEncodingFn();
FailureOr<MaterializeEncodingInfo> encodingInfo =
materializeEncodingFn(boundTensorType);
if (failed(encodingInfo)) {
return failure();
}
SmallVector<OpFoldResult> targetShape =
getMixedValues(originalTensorType.getShape(), dynamicDims, builder);
auto innerTileSizes =
getInnerTileSizesOfr(builder, loc, originalTensorType, *encodingInfo,
materializeEncodingValueFn);
if (failed(innerTileSizes)) {
return failure();
}
SmallVector<OpFoldResult> convertedTargetShape =
tensor::PackOp::getResultShape(builder, loc, targetShape, *innerTileSizes,
encodingInfo->innerDimsPos,
encodingInfo->outerDimsPerm);
return convertedTargetShape;
}
/// For a `dispatchTensorType` that binds a `RankedTensorType` with encoding,
/// returns the dynamic dimensions of the materialized shape of the
/// `dispatchTensorType`. The dynamic dimensions of the `dispatchTensorType` are
/// provided in `dynamicDims`.
static FailureOr<SmallVector<Value>> getPackedDynamicDimsForDispatchTensor(
OpBuilder &builder, Location loc,
const MaterializeEncodingTypeConverter &typeConverter,
IREE::Flow::DispatchTensorType dispatchTensorType, ValueRange dynamicDims,
MaterializeEncodingValueFn materializeEncodingValueFn) {
FailureOr<SmallVector<OpFoldResult>> convertedTargetShape =
getPackedDimsForDispatchTensor(builder, loc, typeConverter,
dispatchTensorType, dynamicDims,
materializeEncodingValueFn);
if (failed(convertedTargetShape)) {
return failure();
}
SmallVector<int64_t> convertedStaticTargetShape;
SmallVector<Value> convertedDynamicTargetShape;
dispatchIndexOpFoldResults(convertedTargetShape.value(),
convertedDynamicTargetShape,
convertedStaticTargetShape);
return convertedDynamicTargetShape;
}
namespace {
/// Pattern to materialize the encoding for `hal.interface.binding.subspan`
/// operations.
struct MaterializeInterfaceBindingEncoding
: public OpMaterializeEncodingPattern<
IREE::HAL::InterfaceBindingSubspanOp> {
using OpMaterializeEncodingPattern<
IREE::HAL::InterfaceBindingSubspanOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp subspanOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultType = llvm::dyn_cast<IREE::Flow::DispatchTensorType>(
subspanOp.getResult().getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
subspanOp, "expected result type to be !flow.dispatch.tensor");
}
auto boundTensorType =
llvm::dyn_cast<RankedTensorType>(resultType.getBoundType());
if (!boundTensorType) {
return rewriter.notifyMatchFailure(
subspanOp, "bound type is not a RankedTensorType");
}
auto convertedBoundType = getTypeConverter()->convertType(boundTensorType);
if (convertedBoundType == boundTensorType) {
return rewriter.notifyMatchFailure(subspanOp, "bound type already valid");
}
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
// Get the dynamic dims of the target.
Location loc = subspanOp.getLoc();
SmallVector<Value> newDynamicDims = subspanOp.getDynamicDims();
FailureOr<SmallVector<Value>> convertedDynamicDims =
getPackedDynamicDimsForDispatchTensor(
rewriter, loc, *typeConverter, resultType,
subspanOp.getDynamicDims(), this->materializeEncodingValueFn);
// If the conversion failed, the encoding is unsupported by the target and is
// dropped; keep the original dynamic dims in that case.
if (succeeded(convertedDynamicDims)) {
newDynamicDims = convertedDynamicDims.value();
}
auto newResultType = IREE::Flow::DispatchTensorType::get(
resultType.getAccess(), convertedBoundType);
rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>(
subspanOp, newResultType, subspanOp.getSet(), subspanOp.getBinding(),
subspanOp.getDescriptorType(), subspanOp.getByteOffset(),
newDynamicDims, subspanOp.getAlignmentAttr(),
subspanOp.getDescriptorFlagsAttr());
return success();
}
};
/// Pattern to convert `flow.dispatch.tensor.load` operations when
/// materializing the encoding.
struct MaterializeFlowDispatchTensorLoadOp
: public OpMaterializeEncodingPattern<IREE::Flow::DispatchTensorLoadOp> {
using OpMaterializeEncodingPattern<
IREE::Flow::DispatchTensorLoadOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::Flow::DispatchTensorLoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle operations where the load covers the entire
// `!flow.dispatch.tensor` type.
// TODO(ravishankarm): Relax this for partial loads.
if (!loadOp.isLoadOfWholeSource()) {
return rewriter.notifyMatchFailure(loadOp, "unhandled partial loads");
}
auto sourceType = loadOp.getSourceType();
auto boundTensorType = cast<RankedTensorType>(sourceType.getBoundType());
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
if (typeConverter->convertType(boundTensorType) == boundTensorType) {
return rewriter.notifyMatchFailure(loadOp, "bound type already valid");
}
Location loc = loadOp.getLoc();
SmallVector<OpFoldResult> newMixedSizes = getMixedValues(
boundTensorType.getShape(), loadOp.getSourceDims(), rewriter);
FailureOr<SmallVector<OpFoldResult>> convertedMixedSizes =
getPackedDimsForDispatchTensor(rewriter, loc, *typeConverter,
sourceType, loadOp.getSourceDims(),
this->materializeEncodingValueFn);
if (succeeded(convertedMixedSizes)) {
newMixedSizes = convertedMixedSizes.value();
}
SmallVector<OpFoldResult> newOffsets(newMixedSizes.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> newStrides(newMixedSizes.size(),
rewriter.getIndexAttr(1));
SmallVector<int64_t> newStaticDims;
SmallVector<Value> newDynamicDims;
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
loadOp, adaptor.getSource(), newDynamicDims, newOffsets, newMixedSizes,
newStrides);
return success();
}
};
/// Pattern to convert `flow.dispatch.tensor.store` operation when
/// materializing the encoding.
struct MaterializeFlowDispatchTensorStoreOp
: public OpMaterializeEncodingPattern<IREE::Flow::DispatchTensorStoreOp> {
using OpMaterializeEncodingPattern<
IREE::Flow::DispatchTensorStoreOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle operations where the store covers the entire
// `!flow.dispatch.tensor` type.
// TODO(ravishankarm): Relax this for partial stores.
if (!storeOp.isStoreToWholeTarget()) {
return rewriter.notifyMatchFailure(storeOp, "unhandled partial stores");
}
auto targetType = storeOp.getTargetType();
auto boundTensorType = cast<RankedTensorType>(targetType.getBoundType());
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
if (typeConverter->convertType(boundTensorType) == boundTensorType) {
return rewriter.notifyMatchFailure(storeOp, "bound type already valid");
}
Location loc = storeOp.getLoc();
SmallVector<OpFoldResult> newMixedSizes = getMixedValues(
boundTensorType.getShape(), storeOp.getTargetDims(), rewriter);
FailureOr<SmallVector<OpFoldResult>> convertedMixedSizes =
getPackedDimsForDispatchTensor(rewriter, loc, *typeConverter,
targetType, storeOp.getTargetDims(),
this->materializeEncodingValueFn);
if (succeeded(convertedMixedSizes)) {
newMixedSizes = convertedMixedSizes.value();
}
SmallVector<OpFoldResult> newOffsets(newMixedSizes.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> newStrides(newMixedSizes.size(),
rewriter.getIndexAttr(1));
SmallVector<int64_t> newStaticDims;
SmallVector<Value> newDynamicDims;
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
storeOp, adaptor.getValue(), adaptor.getTarget(), newDynamicDims,
newOffsets, newMixedSizes, newStrides);
return success();
}
};
//===---------------------------------------------------------------------===//
// Patterns to lower ops with encodings. These are written as
// dialect conversion patterns for now. These are just drivers around
// the core conversion utilities.
//===---------------------------------------------------------------------===//
/// Convert `set_encoding` op to `pack` op.
struct SetEncodingOpToPackOpConversion
: public OpMaterializeEncodingPattern<SetEncodingOp> {
using OpMaterializeEncodingPattern<
SetEncodingOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(SetEncodingOp encodingOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MaterializeEncodingFn materializeEncodingFn =
static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter())
->getMaterializeEncodingFn();
auto packOp = lowerSetEncodingOpToPackOp(
rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
this->materializeEncodingValueFn);
if (failed(packOp)) {
Value result = adaptor.getSource();
Type targetType =
getTypeConverter()->convertType(encodingOp.getResultType());
if (targetType != result.getType()) {
result = rewriter.create<tensor::CastOp>(encodingOp.getLoc(),
targetType, result);
}
rewriter.replaceOp(encodingOp, result);
return success();
}
rewriter.replaceOp(encodingOp, packOp->getResult());
return success();
}
};
/// Convert `unset_encoding` op to `unpack` op.
struct UnsetEncodingOpToUnPackOpConversion
: public OpMaterializeEncodingPattern<UnsetEncodingOp> {
using OpMaterializeEncodingPattern<
UnsetEncodingOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(UnsetEncodingOp encodingOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MaterializeEncodingFn materializeEncodingFn =
static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter())
->getMaterializeEncodingFn();
auto unpackOp = lowerUnsetEncodingToUnpackOp(
rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
this->materializeEncodingValueFn);
if (failed(unpackOp)) {
Value result = adaptor.getSource();
Type targetType =
getTypeConverter()->convertType(encodingOp.getResultType());
if (targetType != result.getType()) {
result = rewriter.create<tensor::CastOp>(encodingOp.getLoc(),
targetType, result);
}
rewriter.replaceOp(encodingOp, result);
return success();
}
rewriter.replaceOp(encodingOp, unpackOp->getResult());
return success();
}
};
/// Convert `upper_bound_tile_size` op to `constant` op. If the
/// `materializeEncodingFn` fails, the pattern falls back to replacing all
/// results with the constant 1 (i.e. no tiling).
struct UpperBoundTileSizeToConstantOpConversion
: public OpRewritePattern<UpperBoundTileSizeOp> {
UpperBoundTileSizeToConstantOpConversion(
MLIRContext *context, MaterializeEncodingFn materializeEncodingFn)
: OpRewritePattern<UpperBoundTileSizeOp>(context),
materializeEncodingFn(materializeEncodingFn) {}
LogicalResult matchAndRewrite(UpperBoundTileSizeOp upperBoundTileSizeOp,
PatternRewriter &rewriter) const override {
auto constants = lowerUpperBoundTileSizeOpToConstants(
rewriter, upperBoundTileSizeOp, materializeEncodingFn);
if (failed(constants)) {
SmallVector<Value> results(upperBoundTileSizeOp.getNumResults(),
rewriter.create<arith::ConstantIndexOp>(
upperBoundTileSizeOp.getLoc(), 1));
rewriter.replaceOp(upperBoundTileSizeOp, results);
return success();
}
rewriter.replaceOp(upperBoundTileSizeOp, *constants);
return success();
}
MaterializeEncodingFn materializeEncodingFn;
};
/// Generic pattern to convert an operation that is in destination-passing
/// style.
template <typename OpTy>
struct MaterializeDPSOperation : public OpMaterializeEncodingPattern<OpTy> {
using OpMaterializeEncodingPattern<OpTy>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(OpTy dpsOp, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MaterializeEncodingFn materializeEncodingFn =
static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter())
->getMaterializeEncodingFn();
FailureOr<Operation *> convertedOp = lowerOpWithEncoding(
rewriter, dpsOp, adaptor.getInputs(), adaptor.getOutputs(),
materializeEncodingFn, this->materializeEncodingValueFn);
if (failed(convertedOp)) {
return failure();
}
rewriter.replaceOp(dpsOp, convertedOp.value()->getResults());
return success();
}
};
/// Generic pattern to convert an operation.
template <typename OpTy>
struct MaterializeOperation : public OpMaterializeEncodingPattern<OpTy> {
using OpMaterializeEncodingPattern<OpTy>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MaterializeEncodingFn materializeEncodingFn =
static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter())
->getMaterializeEncodingFn();
FailureOr<Operation *> convertedOp = lowerOpWithEncoding(
rewriter, op, adaptor.getOperands(), materializeEncodingFn,
this->materializeEncodingValueFn);
if (failed(convertedOp))
return failure();
SmallVector<Value> replacements;
for (auto [type, res] : llvm::zip_equal(
op->getResultTypes(), convertedOp.value()->getResults())) {
Type targetType = this->getTypeConverter()->convertType(type);
if (targetType == res.getType()) {
replacements.push_back(res);
} else {
replacements.push_back(
rewriter.create<tensor::CastOp>(op.getLoc(), targetType, res));
}
}
rewriter.replaceOp(op, replacements);
return success();
}
};
/// Pattern to convert contraction operations.
class MaterializeContractionOp : public OpInterfaceConversionPattern<
mlir::linalg::ContractionOpInterface> {
public:
MaterializeContractionOp(
MLIRContext *context,
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn = {},
PatternBenefit benefit = 1)
: OpInterfaceConversionPattern<mlir::linalg::ContractionOpInterface>(
typeConverter, context, benefit),
materializeEncodingValueFn(materializeEncodingValueFn) {}
LogicalResult
matchAndRewrite(mlir::linalg::ContractionOpInterface op,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MaterializeEncodingFn materializeEncodingFn =
static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter())
->getMaterializeEncodingFn();
auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
if (!linalgOp || operands.size() != 3) {
return failure();
}
FailureOr<Operation *> convertedOp = lowerOpWithEncoding(
rewriter, linalgOp, operands.take_front(2), operands.take_back(1),
materializeEncodingFn, this->materializeEncodingValueFn);
if (failed(convertedOp)) {
return failure();
}
rewriter.replaceOp(op.getOperation(), convertedOp.value()->getResult(0));
return success();
}
protected:
const MaterializeEncodingValueFn materializeEncodingValueFn;
};
} // namespace
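/// Populates `patterns` and `target` for materializing encodings into
/// pack/unpack ops: adds the type conversion for `!flow.dispatch.tensor`
/// types whose bound tensor type carries an encoding, marks
/// `hal.interface.binding.subspan` ops legal only when their type needs no
/// conversion, and registers the conversion patterns defined above.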
void populateMaterializeEncodingIntoPackUnPackPatterns(
RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target,
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
MLIRContext *context = patterns.getContext();
typeConverter.addConversion(
[&typeConverter](IREE::Flow::DispatchTensorType dispatchTensorType) {
Type boundType = dispatchTensorType.getBoundType();
Type convertedBoundType = typeConverter.convertType(boundType);
if (convertedBoundType == boundType) {
return dispatchTensorType;
}
return IREE::Flow::DispatchTensorType::get(
dispatchTensorType.getAccess(), convertedBoundType);
});
target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp>(
[&typeConverter](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
auto resultType = llvm::dyn_cast<IREE::Flow::DispatchTensorType>(
subspanOp.getResult().getType());
// Mark types that are not `Flow::DispatchTensorType` as legal.
if (!resultType)
return true;
return resultType == typeConverter.convertType(resultType);
});
// Add all patterns for converting from encoded type to the materialized
// type.
patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
MaterializeDPSOperation<linalg::GenericOp>,
MaterializeOperation<tensor::EmptyOp>,
MaterializeContractionOp, SetEncodingOpToPackOpConversion,
UnsetEncodingOpToUnPackOpConversion>(
patterns.getContext(), typeConverter, materializeEncodingValueFn);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
patterns.insert<MaterializeFlowDispatchTensorLoadOp,
MaterializeFlowDispatchTensorStoreOp,
MaterializeInterfaceBindingEncoding>(
context, typeConverter, materializeEncodingValueFn);
}
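/// Populates `patterns` with the lowering of `upper_bound_tile_size` ops to
/// `arith.constant` index ops.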
void populateMaterializeUpperBoundTileSizePatterns(
RewritePatternSet &patterns, MaterializeEncodingFn materializeEncodingFn) {
patterns.insert<UpperBoundTileSizeToConstantOpConversion>(
patterns.getContext(), materializeEncodingFn);
}
} // namespace mlir::iree_compiler