// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===---------------------------------------------------------------------===//
// Pass to materialize the encodings of tensors based on target information.
//===---------------------------------------------------------------------===//
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
namespace mlir::iree_compiler {
using IREE::Codegen::MaterializeEncodingInfo;
using IREE::Codegen::TileSwizzle;
//===---------------------------------------------------------------------===//
// Utility methods
//===---------------------------------------------------------------------===//
// Utility to apply a tile swizzle to a packed shape.
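// Illustrative example (actual swizzles are target-dependent): given
// packedShape = [?, ?, 16, 16] (two inner tiles), expandShape = [[16], [2, 8]]
// and permutation = [1, 2, 0], the expanded tile shape is [16, 2, 8] and the
// returned shape is [?, ?, 2, 8, 16].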
static SmallVector<OpFoldResult>
getSwizzledShape(ArrayRef<OpFoldResult> packedShape,
MaterializeEncodingInfo encodingInfo) {
if (packedShape.empty() || !encodingInfo.swizzle) {
return SmallVector<OpFoldResult>(packedShape);
}
int64_t srcRank = packedShape.size() - encodingInfo.innerTileSizes.size();
SmallVector<int64_t> perm = llvm::to_vector(llvm::seq<int64_t>(0, srcRank));
for (auto i : encodingInfo.swizzle->permutation) {
perm.push_back(i + srcRank);
}
SmallVector<OpFoldResult> newShape(packedShape.take_front(srcRank));
SmallVector<int64_t> expandedTileShape =
IREE::Codegen::getExpandedTileShape(encodingInfo.swizzle->expandShape);
MLIRContext *ctx = packedShape[0].getContext();
Builder b(ctx);
for (int64_t d : expandedTileShape) {
newShape.push_back(b.getIndexAttr(d));
}
applyPermutationToVector(newShape, perm);
return newShape;
}
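/// Returns the inner tile sizes of `materializeEncodingInfo` as a vector of
/// OpFoldResult: static tile sizes become attributes, while dynamic tile sizes
/// are materialized as SSA values through `materializeEncodingValueFn`.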
static FailureOr<SmallVector<OpFoldResult>>
getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
RankedTensorType tensorType,
const MaterializeEncodingInfo &materializeEncodingInfo,
MaterializeEncodingValueFn materializeEncodingValueFn) {
ArrayRef<int64_t> staticTileSizes = materializeEncodingInfo.innerTileSizes;
if (llvm::all_of(staticTileSizes,
[](int64_t i) { return !ShapedType::isDynamic(i); })) {
return getAsOpFoldResult(rewriter.getI64ArrayAttr(staticTileSizes));
}
assert(materializeEncodingValueFn &&
"When dynamic tile sizes are generated, a MaterializeEncodingValueFn "
"should be provided.");
FailureOr<MaterializeEncodingValueInfo> materializeEncodingValueInfo =
materializeEncodingValueFn(tensorType, rewriter, loc);
if (failed(materializeEncodingValueInfo)) {
return failure();
}
ArrayRef<Value> innerTileSizeValues =
materializeEncodingValueInfo->innerTileSizes;
SmallVector<OpFoldResult> result(staticTileSizes.size());
for (size_t i = 0; i < result.size(); ++i) {
if (ShapedType::isDynamic(staticTileSizes[i])) {
result[i] = innerTileSizeValues[i];
} else if (tensorType.isDynamicDim(i)) {
result[i] =
rewriter.create<arith::ConstantIndexOp>(loc, staticTileSizes[i])
.getResult();
} else {
result[i] = rewriter.getI64IntegerAttr(staticTileSizes[i]);
}
}
return result;
}
//===---------------------------------------------------------------------===//
// Methods to convert `set_encoding` and `unset_encoding` operations
// to `pack` and `unpack` operations respectively.
//===---------------------------------------------------------------------===//
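//
// Illustrative example (tile sizes and encodings are target-dependent): with
// 8x4 inner tiles on both dimensions,
//
//   %0 = iree_encoding.set_encoding %src
//       : tensor<?x?xf32> -> tensor<?x?xf32, #encoding>
//
// lowers to
//
//   %cst = arith.constant 0.0 : f32
//   %empty = tensor.empty(%d0, %d1) : tensor<?x?x8x4xf32>
//   %0 = tensor.pack %src padding_value(%cst : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %empty
//       : tensor<?x?xf32> -> tensor<?x?x8x4xf32>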
FailureOr<Value> lowerSetEncodingOpToPackOp(
RewriterBase &rewriter, IREE::Encoding::SetEncodingOp encodingOp,
Value source, const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType resultType = encodingOp.getResultType();
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(resultType);
// Shortcut to avoid creating new operations.
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return source;
}
auto encoding = IREE::Encoding::getEncodingAttr(resultType);
if (!encoding) {
return failure();
}
// Create a `tensor.empty` op for the result of the pack operation.
Location loc = encodingOp.getLoc();
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, resultType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
Value paddingValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(resultType.getElementType()));
SmallVector<OpFoldResult> sourceDims =
tensor::getMixedSizes(rewriter, loc, source);
SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
resultType.getElementType());
return rewriter
.create<tensor::PackOp>(loc, source, emptyOp, encodingInfo.innerDimsPos,
*innerTileSizesOfr, paddingValue,
encodingInfo.outerDimsPerm)
.getResult();
}
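/// Illustrative example (mirroring the pack case above): with 8x4 inner tiles,
///
///   %0 = iree_encoding.unset_encoding %packed
///       : tensor<?x?xf32, #encoding> -> tensor<?x?xf32>{%d0, %d1}
///
/// lowers to
///
///   %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
///   %0 = tensor.unpack %packed
///       inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %empty
///       : tensor<?x?x8x4xf32> -> tensor<?x?xf32>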
FailureOr<Value> lowerUnsetEncodingToUnpackOp(
RewriterBase &rewriter, IREE::Encoding::UnsetEncodingOp encodingOp,
Value packedValue, const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType sourceType = encodingOp.getSourceType();
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(sourceType);
// Shortcut to avoid creating new operations.
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return packedValue;
}
// Create a `tensor.empty` op for the result of the unpack operation.
Location loc = encodingOp.getLoc();
SmallVector<OpFoldResult> resultDims =
getMixedValues(encodingOp.getResultType().getShape(),
encodingOp.getResultDims(), rewriter);
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
sourceType.getElementType());
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, sourceType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
return rewriter
.create<tensor::UnPackOp>(loc, packedValue, emptyOp,
encodingInfo.innerDimsPos, *innerTileSizesOfr,
encodingInfo.outerDimsPerm)
.getResult();
}
/// Utility method to convert `tensor.empty` with encoding to a `tensor.empty`
/// of the materialized type.
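/// Illustrative example (tile sizes are target-dependent): with 8x4 inner
/// tiles,
///   %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32, #encoding>
/// becomes
///   %0 = tensor.empty(%p0, %p1) : tensor<?x?x8x4xf32>
/// where %p0 and %p1 are the corresponding packed outer sizes.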
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
ValueRange convertedOperands,
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
auto emptyType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(emptyType);
Location loc = emptyOp.getLoc();
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return rewriter
.create<tensor::EmptyOp>(loc, emptyOp.getMixedSizes(),
emptyType.getElementType())
.getOperation();
}
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, emptyType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
emptyOp, "failed to generate runtime tile size query");
}
SmallVector<OpFoldResult> sourceDims = emptyOp.getMixedSizes();
(void)foldDynamicIndexList(sourceDims);
SmallVector<OpFoldResult> newShape = tensor::PackOp::getResultShape(
rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm);
newShape = getSwizzledShape(newShape, encodingInfo);
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
loc, newShape, emptyType.getElementType());
return newEmptyOp;
}
/// Converts a linalg::GenericOp with encoded inputs into the packed domain.
/// The `genericOp` must have all parallel iterator types and a single output
/// with an identity indexing map.
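///
/// Illustrative example: with a rank-2 output packed into 8x4 tiles (packed
/// output rank 4 with an identity map), an input with indexing map
/// (d0, d1) -> (d1) that is packed along its only dimension with tile size 4
/// is assigned the packed-domain map (d0, d1, d2, d3) -> (d1, d3).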
static FailureOr<Operation *> lowerGenericOpWithEncoding(
RewriterBase &rewriter, linalg::GenericOp genericOp,
ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
const MaterializeEncodingTypeConverter &typeConverter) {
OpOperand *outputOperand = genericOp.getDpsInitOperand(0);
AffineMap outputMap = genericOp.getMatchingIndexingMap(outputOperand);
if (!outputMap.isIdentity()) {
return rewriter.notifyMatchFailure(genericOp,
"Output indexing map is not identity");
}
MaterializeEncodingInfo outMaterializeEncodingInfo =
typeConverter.getEncodingInfo(
cast<RankedTensorType>(outputOperand->get().getType()));
if (IREE::Codegen::isIdentityLayout(outMaterializeEncodingInfo)) {
return rewriter.notifyMatchFailure(
genericOp, "MaterializeEncodingInfo failed for output");
}
if (outMaterializeEncodingInfo.swizzle) {
return rewriter.notifyMatchFailure(
genericOp, "generic op lowering does not support swizzle yet");
}
auto convertedResultType =
cast<RankedTensorType>(convertedOutputOperands[0].getType());
SmallVector<utils::IteratorType> iteratorTypes(convertedResultType.getRank(),
utils::IteratorType::parallel);
// Compute the new indexing maps for the packed layout. This assumes that
// the output map is identity, and that all iterator types are parallel.
SmallVector<int64_t> outInnerDimsPos =
outMaterializeEncodingInfo.innerDimsPos;
SmallVector<int64_t> outInverseOuterDimsPerm =
invertPermutationVector(outMaterializeEncodingInfo.outerDimsPerm);
SmallVector<AffineMap> packedIndexingMaps;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
MaterializeEncodingInfo materializeEncodingInfo =
typeConverter.getEncodingInfo(
cast<RankedTensorType>(inputOperand->get().getType()));
if (IREE::Codegen::isIdentityLayout(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(
genericOp, "MaterializeEncodingInfo failed for input");
}
ArrayRef<int64_t> innerDimsPos = materializeEncodingInfo.innerDimsPos;
ArrayRef<int64_t> outerDimsPerm = materializeEncodingInfo.outerDimsPerm;
AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
// Permute the result dims of `inputMap` into the input's packed outer-dim
// order, then map each dim to its position in the output's packed domain.
SmallVector<int64_t> packedResultDims = llvm::map_to_vector(
applyPermutation(inputMap.getResults(), outerDimsPerm),
[&](AffineExpr expr) {
auto dimExpr = cast<AffineDimExpr>(expr);
return outInverseOuterDimsPerm[dimExpr.getPosition()];
});
// Add new dims for the inner tiles, taking the dim position from the
// corresponding inner tile of the init operand.
for (auto [idx, pos] : llvm::enumerate(innerDimsPos)) {
auto dimPos = cast<AffineDimExpr>(inputMap.getResult(pos)).getPosition();
for (auto [tileIdx, outDim] : llvm::enumerate(outInnerDimsPos)) {
if (dimPos == outDim) {
packedResultDims.push_back(outputMap.getNumDims() + tileIdx);
}
}
}
// Create the packed indexing map.
SmallVector<AffineExpr> packedResultExprs =
llvm::map_to_vector(packedResultDims, [&](int64_t dim) {
return rewriter.getAffineDimExpr(dim);
});
auto packedInputMap = AffineMap::get(
/*dimCount=*/iteratorTypes.size(), /*symbolCount=*/0, packedResultExprs,
rewriter.getContext());
packedIndexingMaps.push_back(packedInputMap);
}
// Create the new packed identity map for the output.
packedIndexingMaps.push_back(
rewriter.getMultiDimIdentityMap(convertedResultType.getRank()));
auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
genericOp.getLoc(), convertedResultType, convertedInputOperands,
convertedOutputOperands, packedIndexingMaps, iteratorTypes,
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
rewriter.inlineRegionBefore(genericOp.getRegion(),
materializedGenericOp.getRegion(),
materializedGenericOp.getRegion().begin());
return materializedGenericOp.getOperation();
}
/// Utility method to convert from a linalg::LinalgOp on `tensor` types with
/// encodings to a linalg::LinalgOp on the materialized type. The op must have
/// all-parallel iterators and a single output. The currently supported op
/// types are:
/// - linalg::FillOp
/// - linalg::GenericOp
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
ValueRange convertedInputOperands,
ValueRange convertedOutputOperands,
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn) {
if (!linalgOp.hasPureTensorSemantics()) {
return rewriter.notifyMatchFailure(linalgOp, "Not pure tensor semantics");
}
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) {
return rewriter.notifyMatchFailure(linalgOp, "Loops are not all parallel");
}
if (linalgOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(linalgOp, "Not only 1 init operand");
}
return TypeSwitch<Operation *, FailureOr<Operation *>>(linalgOp)
.Case<linalg::FillOp>(
[&](linalg::FillOp fillOp) -> FailureOr<Operation *> {
Operation *materializedFillOp = rewriter.create<linalg::FillOp>(
fillOp.getLoc(), convertedOutputOperands[0].getType(),
convertedInputOperands, convertedOutputOperands);
return materializedFillOp;
})
.Case<linalg::GenericOp>(
[&](linalg::GenericOp genericOp) -> FailureOr<Operation *> {
return lowerGenericOpWithEncoding(
rewriter, genericOp, convertedInputOperands,
convertedOutputOperands, typeConverter);
})
.Default([](Operation *op) { return failure(); });
}
/// For a `dispatchTensorType` that binds a `RankedTensorType` with encoding,
/// returns the materialized shape of the `dispatchTensorType`. The
/// dynamic dimensions of the `dispatchTensorType` are provided in
/// `dynamicDims`.
static FailureOr<SmallVector<OpFoldResult>> getPackedDimsForDispatchTensor(
OpBuilder &builder, Location loc,
const MaterializeEncodingTypeConverter &typeConverter,
IREE::Flow::DispatchTensorType dispatchTensorType, ValueRange dynamicDims,
MaterializeEncodingValueFn materializeEncodingValueFn) {
auto boundTensorType =
llvm::dyn_cast<RankedTensorType>(dispatchTensorType.getBoundType());
if (!boundTensorType) {
return failure();
}
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(boundTensorType);
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return failure();
}
SmallVector<OpFoldResult> targetShape =
getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
auto innerTileSizes = getInnerTileSizesOfr(
builder, loc, boundTensorType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizes)) {
return failure();
}
SmallVector<OpFoldResult> convertedTargetShape =
tensor::PackOp::getResultShape(builder, loc, targetShape, *innerTileSizes,
encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm);
return getSwizzledShape(convertedTargetShape, encodingInfo);
}
/// For a `dispatchTensorType` that binds a `RankedTensorType` with encoding,
/// returns the dynamic dimensions of the materialized shape of the
/// `dispatchTensorType`. The dynamic dimensions of the `dispatchTensorType` are
/// provided in `dynamicDims`.
static FailureOr<SmallVector<Value>> getPackedDynamicDimsForDispatchTensor(
OpBuilder &builder, Location loc,
const MaterializeEncodingTypeConverter &typeConverter,
IREE::Flow::DispatchTensorType dispatchTensorType, ValueRange dynamicDims,
MaterializeEncodingValueFn materializeEncodingValueFn) {
FailureOr<SmallVector<OpFoldResult>> convertedTargetShape =
getPackedDimsForDispatchTensor(builder, loc, typeConverter,
dispatchTensorType, dynamicDims,
materializeEncodingValueFn);
if (failed(convertedTargetShape)) {
return failure();
}
SmallVector<int64_t> convertedStaticTargetShape;
SmallVector<Value> convertedDynamicTargetShape;
dispatchIndexOpFoldResults(convertedTargetShape.value(),
convertedDynamicTargetShape,
convertedStaticTargetShape);
return convertedDynamicTargetShape;
}
namespace {
/// Pattern to materialize the encoding for `hal.interface.binding.subspan`
/// operations.
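///
/// Illustrative example (tile sizes are target-dependent): a subspan of type
///   !flow.dispatch.tensor<readonly:tensor<?x?xf32, #encoding>>
/// is rewritten, for 8x4 tiles, to
///   !flow.dispatch.tensor<readonly:tensor<?x?x8x4xf32>>
/// with its dynamic dims replaced by the packed outer sizes.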
struct MaterializeInterfaceBindingEncoding
: public OpMaterializeEncodingPattern<
IREE::HAL::InterfaceBindingSubspanOp> {
using OpMaterializeEncodingPattern<
IREE::HAL::InterfaceBindingSubspanOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp subspanOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultType = llvm::dyn_cast<IREE::Flow::DispatchTensorType>(
subspanOp.getResult().getType());
if (!resultType) {
return rewriter.notifyMatchFailure(
subspanOp, "expected result type to be !flow.dispatch.tensor");
}
auto boundTensorType =
llvm::dyn_cast<RankedTensorType>(resultType.getBoundType());
if (!boundTensorType) {
return rewriter.notifyMatchFailure(
subspanOp, "bound type is not a RankedTensorType");
}
auto convertedBoundType = getTypeConverter()->convertType(boundTensorType);
if (convertedBoundType == boundTensorType) {
return rewriter.notifyMatchFailure(subspanOp, "bound type already valid");
}
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
// Get the dynamic dims of the target.
Location loc = subspanOp.getLoc();
SmallVector<Value> newDynamicDims = subspanOp.getDynamicDims();
FailureOr<SmallVector<Value>> convertedDynamicDims =
getPackedDynamicDimsForDispatchTensor(
rewriter, loc, *typeConverter, resultType,
subspanOp.getDynamicDims(), this->materializeEncodingValueFn);
// Drop the encoding if the target does not support it.
if (succeeded(convertedDynamicDims)) {
newDynamicDims = convertedDynamicDims.value();
}
auto newResultType = IREE::Flow::DispatchTensorType::get(
resultType.getAccess(), convertedBoundType);
rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>(
subspanOp, newResultType, subspanOp.getLayout(), subspanOp.getBinding(),
subspanOp.getByteOffset(), newDynamicDims, subspanOp.getAlignmentAttr(),
subspanOp.getDescriptorFlagsAttr());
return success();
}
};
/// Pattern to convert `flow.dispatch.tensor.load` operations when
/// materializing the encoding.
struct MaterializeFlowDispatchTensorLoadOp
: public OpMaterializeEncodingPattern<IREE::Flow::DispatchTensorLoadOp> {
using OpMaterializeEncodingPattern<
IREE::Flow::DispatchTensorLoadOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::Flow::DispatchTensorLoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle operations where the load covers the entire
// `!flow.dispatch.tensor` type.
// TODO(ravishankarm): Relax this for partial loads.
if (!loadOp.isLoadOfWholeSource()) {
return rewriter.notifyMatchFailure(loadOp, "unhandled partial loads");
}
auto sourceType = loadOp.getSourceType();
auto boundTensorType = cast<RankedTensorType>(sourceType.getBoundType());
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
if (typeConverter->convertType(boundTensorType) == boundTensorType) {
return rewriter.notifyMatchFailure(loadOp, "bound type already valid");
}
Location loc = loadOp.getLoc();
SmallVector<OpFoldResult> newMixedSizes = getMixedValues(
boundTensorType.getShape(), loadOp.getSourceDims(), rewriter);
FailureOr<SmallVector<OpFoldResult>> convertedMixedSizes =
getPackedDimsForDispatchTensor(rewriter, loc, *typeConverter,
sourceType, loadOp.getSourceDims(),
this->materializeEncodingValueFn);
if (succeeded(convertedMixedSizes)) {
newMixedSizes = convertedMixedSizes.value();
}
SmallVector<OpFoldResult> newOffsets(newMixedSizes.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> newStrides(newMixedSizes.size(),
rewriter.getIndexAttr(1));
SmallVector<int64_t> newStaticDims;
SmallVector<Value> newDynamicDims;
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
loadOp, adaptor.getSource(), newDynamicDims, newOffsets, newMixedSizes,
newStrides);
return success();
}
};
/// Pattern to convert `flow.dispatch.tensor.store` operations when
/// materializing the encoding.
struct MaterializeFlowDispatchTensorStoreOp
: public OpMaterializeEncodingPattern<IREE::Flow::DispatchTensorStoreOp> {
using OpMaterializeEncodingPattern<
IREE::Flow::DispatchTensorStoreOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only handle operations where the store covers the entire
// `!flow.dispatch.tensor` type.
// TODO(ravishankarm): Relax this for partial stores.
if (!storeOp.isStoreToWholeTarget()) {
return rewriter.notifyMatchFailure(storeOp, "unhandled partial stores");
}
auto targetType = storeOp.getTargetType();
auto boundTensorType = cast<RankedTensorType>(targetType.getBoundType());
auto *typeConverter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
if (typeConverter->convertType(boundTensorType) == boundTensorType) {
return rewriter.notifyMatchFailure(storeOp, "bound type already valid");
}
Location loc = storeOp.getLoc();
SmallVector<OpFoldResult> newMixedSizes = getMixedValues(
boundTensorType.getShape(), storeOp.getTargetDims(), rewriter);
FailureOr<SmallVector<OpFoldResult>> convertedMixedSizes =
getPackedDimsForDispatchTensor(rewriter, loc, *typeConverter,
targetType, storeOp.getTargetDims(),
this->materializeEncodingValueFn);
if (succeeded(convertedMixedSizes)) {
newMixedSizes = convertedMixedSizes.value();
}
SmallVector<OpFoldResult> newOffsets(newMixedSizes.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> newStrides(newMixedSizes.size(),
rewriter.getIndexAttr(1));
SmallVector<int64_t> newStaticDims;
SmallVector<Value> newDynamicDims;
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
storeOp, adaptor.getValue(), adaptor.getTarget(), newDynamicDims,
newOffsets, newMixedSizes, newStrides);
return success();
}
};
//===---------------------------------------------------------------------===//
// Patterns to lower ops with encodings. For now these are written as
// dialect conversion patterns, acting as thin drivers around the core
// conversion utilities above.
//===---------------------------------------------------------------------===//
/// Generic pattern to convert an operation that is in destination-passing
/// style (DPS).
template <typename OpTy>
struct MaterializeDPSOperation : public OpMaterializeEncodingPattern<OpTy> {
using OpMaterializeEncodingPattern<OpTy>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(OpTy dpsOp, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter());
FailureOr<Operation *> convertedOp = lowerOpWithEncoding(
rewriter, dpsOp, adaptor.getInputs(), adaptor.getOutputs(), *converter,
this->materializeEncodingValueFn);
if (failed(convertedOp)) {
return failure();
}
rewriter.replaceOp(dpsOp, convertedOp.value()->getResults());
return success();
}
};
/// Generic pattern to convert an operation.
template <typename OpTy>
struct MaterializeOperation : public OpMaterializeEncodingPattern<OpTy> {
using OpMaterializeEncodingPattern<OpTy>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter());
FailureOr<Operation *> convertedOp =
lowerOpWithEncoding(rewriter, op, adaptor.getOperands(), *converter,
this->materializeEncodingValueFn);
if (failed(convertedOp))
return failure();
SmallVector<Value> replacements;
for (auto [type, res] : llvm::zip_equal(
op->getResultTypes(), convertedOp.value()->getResults())) {
Type targetType = this->getTypeConverter()->convertType(type);
replacements.push_back(
rewriter.createOrFold<tensor::CastOp>(op.getLoc(), targetType, res));
}
rewriter.replaceOp(op, replacements);
return success();
}
};
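/// Pattern to materialize `util.optimization_barrier` ops whose operand types
/// carry tensor encodings, recreating the op with the converted operands.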
struct MaterializeOptimizationBarrierOp
: public OpMaterializeEncodingPattern<IREE::Util::OptimizationBarrierOp> {
using OpMaterializeEncodingPattern<
IREE::Util::OptimizationBarrierOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::Util::OptimizationBarrierOp op,
IREE::Util::OptimizationBarrierOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (llvm::none_of(op.getOperandTypes(), [](Type type) -> bool {
auto tensorType = dyn_cast<RankedTensorType>(type);
return tensorType && tensorType.getEncoding();
})) {
return failure();
}
rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(
op, adaptor.getOperands());
return success();
}
};
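/// Returns the reassociation indices that keep the leading `outerDims`
/// dimensions unexpanded and group the remaining dimensions according to
/// `expandShape`. For example, with outerDims = 2 and
/// expandShape = [[16], [2, 8]], the result is [[0], [1], [2], [3, 4]].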
static SmallVector<ReassociationIndices>
getReassociationIndices(int outerDims,
const TileSwizzle::ExpandShapeType &expandShape) {
SmallVector<ReassociationIndices> result;
int expandedIdx = 0;
for (int i = 0; i < outerDims; ++i) {
result.push_back({expandedIdx++});
}
for (auto expandShapeDim : expandShape) {
result.push_back({});
for (int i = 0, e = expandShapeDim.size(); i < e; ++i) {
result.back().push_back(expandedIdx++);
}
}
return result;
}
/// Converts an `iree_encoding.set_encoding` op to a `tensor.pack` op plus tile
/// swizzling ops. The tile swizzle is represented with `tensor.expand_shape` +
/// `linalg.transpose` ops.
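///
/// Illustrative example (swizzles are target-dependent): with
/// expandShape = [[16], [2, 8]] and permutation = [1, 2, 0], a packed value of
/// type tensor<?x?x16x16xf32> is expanded to tensor<?x?x16x2x8xf32> and then
/// transposed with permutation [0, 1, 3, 4, 2] to yield
/// tensor<?x?x2x8x16xf32>.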
struct SetEncodingOpLoweringConversion
: public OpMaterializeEncodingPattern<IREE::Encoding::SetEncodingOp> {
using OpMaterializeEncodingPattern<
IREE::Encoding::SetEncodingOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
auto packedValue = lowerSetEncodingOpToPackOp(
rewriter, encodingOp, adaptor.getSource(), *converter,
this->materializeEncodingValueFn);
if (failed(packedValue)) {
Type targetType =
getTypeConverter()->convertType(encodingOp.getResultType());
Value result = rewriter.createOrFold<tensor::CastOp>(
encodingOp.getLoc(), targetType, adaptor.getSource());
rewriter.replaceOp(encodingOp, result);
return success();
}
MaterializeEncodingInfo encodingInfo =
converter->getEncodingInfo(encodingOp.getResultType());
if (!encodingInfo.swizzle) {
rewriter.replaceOp(encodingOp, packedValue.value());
return success();
}
Location loc = encodingOp.getLoc();
// Create an expand_shape op that expands the inner tile dimensions according
// to the swizzle.
int origRank = encodingOp.getSourceType().getRank();
SmallVector<int64_t> expandShapeShape(
cast<ShapedType>(packedValue->getType())
.getShape()
.take_front(origRank));
expandShapeShape.append(
getExpandedTileShape(encodingInfo.swizzle->expandShape));
RankedTensorType expandShapeType =
encodingOp.getSourceType().clone(expandShapeShape);
SmallVector<ReassociationIndices> reassociation =
getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, packedValue.value(), reassociation);
SmallVector<int64_t> transposePerm =
llvm::to_vector(llvm::seq<int64_t>(0, origRank));
for (auto perm : encodingInfo.swizzle->permutation) {
transposePerm.push_back(origRank + perm);
}
SmallVector<OpFoldResult> transposeResultDims =
tensor::getMixedSizes(rewriter, loc, expandShapeOp.getResult());
applyPermutationToVector(transposeResultDims, transposePerm);
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, transposeResultDims, encodingOp.getSourceType().getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, expandShapeOp, emptyTensor, transposePerm);
rewriter.replaceOp(encodingOp, transposeOp->getResult(0));
return success();
}
};
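/// Converts an `iree_encoding.unset_encoding` op to a `tensor.unpack` op. If
/// the layout has a tile swizzle, it is first undone with `linalg.transpose` +
/// `tensor.collapse_shape` ops before unpacking.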
struct UnsetEncodingOpLoweringConversion
: public OpMaterializeEncodingPattern<IREE::Encoding::UnsetEncodingOp> {
using OpMaterializeEncodingPattern<
IREE::Encoding::UnsetEncodingOp>::OpMaterializeEncodingPattern;
LogicalResult
matchAndRewrite(IREE::Encoding::UnsetEncodingOp unsetEncodingOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());
MaterializeEncodingInfo encodingInfo =
converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
Type targetType =
getTypeConverter()->convertType(unsetEncodingOp.getSourceType());
Value result = rewriter.createOrFold<tensor::CastOp>(
unsetEncodingOp.getLoc(), targetType, adaptor.getSource());
rewriter.replaceOp(unsetEncodingOp, result);
return success();
}
Location loc = unsetEncodingOp.getLoc();
Value unpackSrc = adaptor.getSource();
if (encodingInfo.swizzle) {
int targetRank = unsetEncodingOp.getResultType().getRank();
auto srcConvertedType =
cast<RankedTensorType>(adaptor.getSource().getType());
SmallVector<OpFoldResult> emptyShape =
tensor::getMixedSizes(rewriter, loc, adaptor.getSource());
emptyShape.resize(targetRank);
for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
emptyShape.push_back(rewriter.getIndexAttr(i));
}
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, emptyShape, unsetEncodingOp.getSourceType().getElementType());
SmallVector<int64_t> transposePerm =
llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
for (auto perm : encodingInfo.swizzle->permutation) {
transposePerm.push_back(targetRank + perm);
}
auto invertedTransposePerm = invertPermutationVector(transposePerm);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, adaptor.getSource(), emptyTensor, invertedTransposePerm);
SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
targetRank, encodingInfo.swizzle->expandShape);
SmallVector<int64_t> unpackSrcShape(
srcConvertedType.getShape().take_front(targetRank));
unpackSrcShape.append(encodingInfo.innerTileSizes.begin(),
encodingInfo.innerTileSizes.end());
RankedTensorType unpackSrcType =
unsetEncodingOp.getResultType().clone(unpackSrcShape);
unpackSrc = rewriter.create<tensor::CollapseShapeOp>(
loc, unpackSrcType, transposeOp->getResult(0), reassociation);
}
auto unpackedValue = lowerUnsetEncodingToUnpackOp(
rewriter, unsetEncodingOp, unpackSrc, *converter,
this->materializeEncodingValueFn);
if (failed(unpackedValue)) {
Type targetType =
getTypeConverter()->convertType(unsetEncodingOp.getResultType());
Value result = rewriter.createOrFold<tensor::CastOp>(loc, targetType,
adaptor.getSource());
rewriter.replaceOp(unsetEncodingOp, result);
return success();
}
rewriter.replaceOp(unsetEncodingOp, unpackedValue.value());
return success();
}
};
/// Pattern to convert contraction operations.
class MaterializeContractionOp
: public OpInterfaceConversionPattern<linalg::LinalgOp> {
public:
MaterializeContractionOp(
MLIRContext *context,
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn = {},
PatternBenefit benefit = 1)
: OpInterfaceConversionPattern<linalg::LinalgOp>(typeConverter, context,
benefit),
materializeEncodingValueFn(materializeEncodingValueFn) {}
LogicalResult
matchAndRewrite(linalg::LinalgOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!linalg::isaContractionOpInterface(op)) {
return rewriter.notifyMatchFailure(
op, "does not implement ContractionOpInterface");
}
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter());
IREE::Codegen::LayoutAttrInterface layoutAttr = converter->getLayoutAttr();
SmallVector<Type> convertedResTypes;
for (auto init : op.getDpsInits()) {
convertedResTypes.push_back(converter->convertType(init.getType()));
}
Operation *newOp =
layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
protected:
const MaterializeEncodingValueFn materializeEncodingValueFn;
};
} // namespace
void populateMaterializeEncodingPatterns(
RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target,
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
MLIRContext *context = patterns.getContext();
typeConverter.addConversion(
[&typeConverter](IREE::Flow::DispatchTensorType dispatchTensorType) {
Type boundType = dispatchTensorType.getBoundType();
Type convertedBoundType = typeConverter.convertType(boundType);
if (convertedBoundType == boundType) {
return dispatchTensorType;
}
return IREE::Flow::DispatchTensorType::get(
dispatchTensorType.getAccess(), convertedBoundType);
});
target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp>(
[&typeConverter](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
auto resultType = llvm::dyn_cast<IREE::Flow::DispatchTensorType>(
subspanOp.getResult().getType());
// Types that are not `Flow::DispatchTensorType` are marked legal.
if (!resultType)
return true;
return resultType == typeConverter.convertType(resultType);
});
patterns.insert<
MaterializeContractionOp, SetEncodingOpLoweringConversion,
UnsetEncodingOpLoweringConversion,
MaterializeDPSOperation<linalg::FillOp>,
MaterializeDPSOperation<linalg::GenericOp>,
MaterializeOperation<tensor::EmptyOp>, MaterializeOptimizationBarrierOp,
MaterializeFlowDispatchTensorLoadOp, MaterializeFlowDispatchTensorStoreOp,
MaterializeInterfaceBindingEncoding>(context, typeConverter,
materializeEncodingValueFn);
}
} // namespace mlir::iree_compiler