blob: 9914e5e1613e84a2d797234e257dc3df8550ed6e [file]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===---------------------------------------------------------------------===//
// Pass to materialize the encoding of tensor based on target information.
//===---------------------------------------------------------------------===//
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#define DEBUG_TYPE "iree-codegen-materialize-encoding"
namespace mlir::iree_compiler {
using IREE::Codegen::MaterializeEncodingInfo;
using IREE::Codegen::TileSwizzle;
//===---------------------------------------------------------------------===//
// Methods to convert `set_encoding` and `unset_encoding` operations
// to `pack` and `unpack` operations respectively.
//===---------------------------------------------------------------------===//
FailureOr<Value> lowerSetEncodingOpToPackOp(
    RewriterBase &rewriter, IREE::Encoding::SetEncodingOp encodingOp,
    Value source, const MaterializeEncodingTypeConverter &typeConverter) {
  RankedTensorType resultType = encodingOp.getResultType();
  MaterializeEncodingInfo info = typeConverter.getEncodingInfo(resultType);
  // Identity layouts materialize to nothing; forward the source directly.
  if (IREE::Codegen::isIdentityLayout(info)) {
    return source;
  }
  Location loc = encodingOp.getLoc();
  // Query the (possibly runtime) inner tile sizes for this layout.
  FailureOr<SmallVector<OpFoldResult>> innerTiles =
      typeConverter.getInnerTileSizesOfr(rewriter, loc, resultType, info);
  if (failed(innerTiles)) {
    return rewriter.notifyMatchFailure(
        encodingOp, "failed to generate runtime tile size query");
  }
  // Pad with zeros of the element type where tiles do not divide the source.
  Value zeroPad = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(resultType.getElementType()));
  // Compute the packed shape and create the destination tensor for the pack.
  SmallVector<OpFoldResult> srcDims =
      tensor::getMixedSizes(rewriter, loc, source);
  SmallVector<OpFoldResult> packedDims = linalg::PackOp::getResultShape(
      rewriter, loc, srcDims, *innerTiles, info.innerDimsPos,
      info.outerDimsPerm);
  auto destTensor = tensor::EmptyOp::create(rewriter, loc, packedDims,
                                            resultType.getElementType());
  return linalg::PackOp::create(rewriter, loc, source, destTensor,
                                info.innerDimsPos, *innerTiles, zeroPad,
                                info.outerDimsPerm)
      .getResult();
}
FailureOr<Value> lowerUnsetEncodingToUnpackOp(
    RewriterBase &rewriter, IREE::Encoding::UnsetEncodingOp encodingOp,
    Value packedValue, const MaterializeEncodingTypeConverter &typeConverter) {
  RankedTensorType sourceType = encodingOp.getSourceType();
  MaterializeEncodingInfo encodingInfo =
      typeConverter.getEncodingInfo(sourceType);
  // Shortcut to avoid creating new operations.
  if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
    return packedValue;
  }
  Location loc = encodingOp.getLoc();
  // Query the inner tile sizes *before* creating any operations. The original
  // ordering created the `tensor.empty` destination first, leaving a dead op
  // behind when the tile size query failed. This also mirrors the ordering in
  // lowerSetEncodingOpToPackOp.
  FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
      typeConverter.getInnerTileSizesOfr(rewriter, loc, sourceType,
                                         encodingInfo);
  if (failed(innerTileSizesOfr)) {
    return rewriter.notifyMatchFailure(
        encodingOp, "failed to generate runtime tile size query");
  }
  // Create a `tensor.empty` destination for the unpack operation.
  SmallVector<OpFoldResult> resultDims =
      getMixedValues(encodingOp.getResultType().getShape(),
                     encodingOp.getResultDims(), rewriter);
  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, resultDims,
                                         sourceType.getElementType());
  return linalg::UnPackOp::create(rewriter, loc, packedValue, emptyOp,
                                  encodingInfo.innerDimsPos, *innerTileSizesOfr,
                                  encodingInfo.outerDimsPerm)
      .getResult();
}
/// Utility method to convert `tensor.empty` with encoding to a `tensor.empty`
/// of the materialized type.
/// Converts a `tensor.empty` op carrying an encoding into a `tensor.empty` of
/// the materialized (packed and possibly swizzled) type.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
                    ValueRange convertedOperands,
                    const MaterializeEncodingTypeConverter &typeConverter) {
  auto tensorType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
  MaterializeEncodingInfo info = typeConverter.getEncodingInfo(tensorType);
  Location loc = emptyOp.getLoc();
  // Identity layout: recreate the op with the encoding-free element type.
  if (IREE::Codegen::isIdentityLayout(info)) {
    return tensor::EmptyOp::create(rewriter, loc, emptyOp.getMixedSizes(),
                                   tensorType.getElementType())
        .getOperation();
  }
  FailureOr<SmallVector<OpFoldResult>> innerTiles =
      typeConverter.getInnerTileSizesOfr(rewriter, loc, tensorType, info);
  if (failed(innerTiles)) {
    return rewriter.notifyMatchFailure(
        emptyOp, "failed to generate runtime tile size query");
  }
  // Fold statically-known sizes before computing the packed shape.
  SmallVector<OpFoldResult> srcSizes = emptyOp.getMixedSizes();
  (void)foldDynamicIndexList(srcSizes);
  SmallVector<OpFoldResult> packedShape = linalg::PackOp::getResultShape(
      rewriter, loc, srcSizes, *innerTiles, info.innerDimsPos,
      info.outerDimsPerm);
  // Account for tile swizzling (expand + transpose) in the final shape.
  packedShape = getSwizzledShape(packedShape, info);
  return tensor::EmptyOp::create(rewriter, loc, packedShape,
                                 tensorType.getElementType())
      .getOperation();
}
namespace {
/// Pattern to materialize the encoding for `hal.interface.binding.subspan`
/// operations.
/// Pattern to materialize the encoding for `hal.interface.binding.subspan`
/// operations. Rewrites the subspan's dispatch tensor type (and its dynamic
/// dims) to the converted, materialized type.
struct MaterializeInterfaceBindingEncoding
    : public OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> {
  using OpConversionPattern<
      IREE::HAL::InterfaceBindingSubspanOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp subspanOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only subspans producing a dispatch tensor are relevant to this pattern.
    auto origResultType = dyn_cast<IREE::TensorExt::DispatchTensorType>(
        subspanOp.getResult().getType());
    if (!origResultType) {
      return rewriter.notifyMatchFailure(
          subspanOp,
          "expected result type to be !iree_tensor_ext.dispatch.tensor");
    }
    auto origBoundTensorType =
        dyn_cast<RankedTensorType>(origResultType.getBoundType());
    if (!origBoundTensorType) {
      return rewriter.notifyMatchFailure(
          subspanOp, "bound type is not a RankedTensorType");
    }
    auto typeConverter = getTypeConverter<MaterializeEncodingTypeConverter>();
    auto convertedResultType =
        typeConverter->convertType<IREE::TensorExt::DispatchTensorType>(
            origResultType);
    if (!convertedResultType) {
      return rewriter.notifyMatchFailure(subspanOp,
                                         "expected converted result type to be "
                                         "!iree_tensor_ext.dispatch.tensor");
    }
    // No-op conversion: the type is already in its materialized form.
    if (origResultType == convertedResultType) {
      return rewriter.notifyMatchFailure(
          subspanOp, "DispatchTensorType type already valid");
    }
    // Get the dynamic dims of the target.
    // TODO(hanchung): We only have getOffsetsSizesStrides interface method that
    // handles all three together. It would be cleaner to have a separate method
    // to get dynamic sizes only.
    Location loc = subspanOp.getLoc();
    ValueRange origDynamicDims = subspanOp.getDynamicDims();
    SmallVector<OpFoldResult> origSizes = getMixedValues(
        origBoundTensorType.getShape(), origDynamicDims, rewriter);
    // Offsets/strides are placeholders (0/1): only the sizes are consumed
    // below. NOTE(review): these vectors are sized by the dynamic-dim count
    // rather than the tensor rank — confirm getOffsetsSizesStrides expects
    // this shape.
    SmallVector<OpFoldResult> origOffsets(origDynamicDims.size(),
                                          rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> origStrides(origDynamicDims.size(),
                                          rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
    if (failed(typeConverter->getOffsetsSizesStrides(
            rewriter, loc, origResultType, origDynamicDims, origOffsets,
            origSizes, origStrides, newOffsets, newSizes, newStrides))) {
      return failure();
    }
    // Split the converted sizes back into static and dynamic components for
    // the replacement subspan op.
    SmallVector<int64_t> newStaticDims;
    SmallVector<Value> newDynamicDims;
    dispatchIndexOpFoldResults(newSizes, newDynamicDims, newStaticDims);
    rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>(
        subspanOp, convertedResultType, subspanOp.getLayout(),
        subspanOp.getBinding(), subspanOp.getByteOffset(), newDynamicDims,
        subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
    return success();
  }
};
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.load` operation when
/// materializing the encoding.
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.load` operation when
/// materializing the encoding: the load is recreated with offsets, sizes and
/// strides translated into the materialized layout.
struct MaterializeTensorExtDispatchTensorLoadOp
    : public OpConversionPattern<IREE::TensorExt::DispatchTensorLoadOp> {
  using OpConversionPattern<
      IREE::TensorExt::DispatchTensorLoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IREE::TensorExt::DispatchTensorLoadOp loadOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sourceType = loadOp.getSourceType();
    auto boundTensorType = cast<RankedTensorType>(sourceType.getBoundType());
    auto typeConverter = getTypeConverter<MaterializeEncodingTypeConverter>();
    // Nothing to do when the bound type is unchanged under conversion.
    if (typeConverter->convertType(boundTensorType) == boundTensorType) {
      return rewriter.notifyMatchFailure(loadOp, "bound type already valid");
    }
    // Translate offsets/sizes/strides into the materialized layout.
    SmallVector<OpFoldResult> newOffsets, newMixedSizes, newStrides;
    if (failed(typeConverter->getOffsetsSizesStrides(
            rewriter, loadOp.getLoc(), sourceType, loadOp.getSourceDims(),
            loadOp.getMixedOffsets(), loadOp.getMixedSizes(),
            loadOp.getMixedStrides(), newOffsets, newMixedSizes, newStrides))) {
      return failure();
    }
    // Split the converted sizes into static/dynamic parts and recreate the
    // load against the (already converted) source.
    SmallVector<int64_t> newStaticDims;
    SmallVector<Value> newDynamicDims;
    dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
    rewriter.replaceOpWithNewOp<IREE::TensorExt::DispatchTensorLoadOp>(
        loadOp, adaptor.getSource(), newDynamicDims, newOffsets, newMixedSizes,
        newStrides);
    return success();
  }
};
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.store` operation when
/// materializing the encoding.
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.store` operation when
/// materializing the encoding: the store is recreated with offsets, sizes and
/// strides translated into the materialized layout.
struct MaterializeTensorExtDispatchTensorStoreOp
    : public OpConversionPattern<IREE::TensorExt::DispatchTensorStoreOp> {
  using OpConversionPattern<
      IREE::TensorExt::DispatchTensorStoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IREE::TensorExt::DispatchTensorStoreOp storeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto targetType = storeOp.getTargetType();
    auto boundTensorType = cast<RankedTensorType>(targetType.getBoundType());
    auto typeConverter = getTypeConverter<MaterializeEncodingTypeConverter>();
    // Nothing to do when the bound type is unchanged under conversion.
    if (typeConverter->convertType(boundTensorType) == boundTensorType) {
      return rewriter.notifyMatchFailure(storeOp, "bound type already valid");
    }
    // Translate offsets/sizes/strides into the materialized layout.
    SmallVector<OpFoldResult> newOffsets, newMixedSizes, newStrides;
    if (failed(typeConverter->getOffsetsSizesStrides(
            rewriter, storeOp.getLoc(), targetType, storeOp.getTargetDims(),
            storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
            storeOp.getMixedStrides(), newOffsets, newMixedSizes,
            newStrides))) {
      return failure();
    }
    // Split the converted sizes into static/dynamic parts and recreate the
    // store against the (already converted) value and target.
    SmallVector<int64_t> newStaticDims;
    SmallVector<Value> newDynamicDims;
    dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
    rewriter.replaceOpWithNewOp<IREE::TensorExt::DispatchTensorStoreOp>(
        storeOp, adaptor.getValue(), adaptor.getTarget(), newDynamicDims,
        newOffsets, newMixedSizes, newStrides);
    return success();
  }
};
//===---------------------------------------------------------------------===//
// Patterns for layout transfers. They decompose load/store ops into
// set_encoding/unset_encoding + load/store, if the converted types mismatch.
//===---------------------------------------------------------------------===//
/// Returns the value that brings `src` to `destType` by inserting the necessary
/// encoding ops.
/// Returns the value that brings `src` to `destType` by inserting the
/// necessary unset_encoding and/or set_encoding operations.
static Value generateEncodingTransferOps(RewriterBase &rewriter, Value src,
                                         ArrayRef<Value> dynamicDims,
                                         RankedTensorType destType) {
  auto srcType = cast<RankedTensorType>(src.getType());
  // Nothing to do when the types already agree.
  if (srcType == destType) {
    return src;
  }
  Value result = src;
  // Strip the source encoding first, if one is present.
  if (srcType.getEncoding()) {
    result = IREE::Encoding::UnsetEncodingOp::create(
        rewriter, src.getLoc(), srcType.dropEncoding(), result, dynamicDims,
        /*encoding_dims=*/ValueRange{});
  }
  // Then attach the destination encoding, if one is expected.
  if (destType.getEncoding()) {
    result = IREE::Encoding::SetEncodingOp::create(
        rewriter, src.getLoc(), destType, result,
        /*encoding_dims=*/ValueRange{});
  }
  return result;
}
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.load` operation when
/// materializing the encoding.
/// Pattern to decompose `iree_tensor_ext.dispatch.tensor.load` into a plain
/// load followed by encoding transfer ops when the source and result
/// encodings mismatch after type conversion.
struct DecomposeMismatchEncodingTensorLoadOp
    : public OpRewritePattern<IREE::TensorExt::DispatchTensorLoadOp> {
  using OpRewritePattern<
      IREE::TensorExt::DispatchTensorLoadOp>::OpRewritePattern;

  DecomposeMismatchEncodingTensorLoadOp(
      MaterializeEncodingTypeConverter &converter, MLIRContext *ctx,
      PatternBenefit benefit = 0)
      : OpRewritePattern<IREE::TensorExt::DispatchTensorLoadOp>(ctx, benefit),
        typeConverter(converter) {}

  LogicalResult matchAndRewrite(IREE::TensorExt::DispatchTensorLoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    // Partial loads would need offset/size-aware handling; bail out.
    if (!loadOp.isLoadOfWholeSource()) {
      return rewriter.notifyMatchFailure(loadOp, "unhandled partial loads");
    }
    IREE::TensorExt::DispatchTensorType srcType = loadOp.getSourceType();
    auto boundTensorType = dyn_cast<RankedTensorType>(srcType.getBoundType());
    if (!boundTensorType) {
      return rewriter.notifyMatchFailure(
          loadOp, "source bound type is not a RankedTensorType");
    }
    // Only decompose if there's an encoding involved. If neither the source
    // nor the destination has an encoding, this pattern should not match.
    // This can happen when isLoadOfWholeSource() returns true but the load
    // reshapes the tensor (e.g., loading a 4D tensor from a 5D source).
    RankedTensorType destType = loadOp.getResult().getType();
    if (!boundTensorType.getEncoding() && !destType.getEncoding()) {
      return rewriter.notifyMatchFailure(
          loadOp, "no encoding involved in source or destination");
    }
    // We have to check the bound type from converted DispatchTensorType because
    // it is what we'll see in encoding materialization. E.g.,
    // GPUPaddingResolver converts RankedTensorType into the same type, but it
    // creates different IREE::TensorExt::DispatchTensorType that may have
    // larger tensor shape for bound type.
    auto convertedSrcType =
        typeConverter.convertType<IREE::TensorExt::DispatchTensorType>(srcType);
    if (typeConverter.convertType(convertedSrcType.getBoundType()) ==
        typeConverter.convertType(destType)) {
      return rewriter.notifyMatchFailure(
          loadOp, "the source type and the result type match after conversion");
    }
    LDBG() << "Performance warning: decomposing mismatched encoding load op: "
           << loadOp;
    // Load the whole source at its bound type, then bridge the encoding gap
    // with unset_encoding/set_encoding ops.
    Location loc = loadOp.getLoc();
    Value result = IREE::TensorExt::DispatchTensorLoadOp::create(
        rewriter, loc, boundTensorType, loadOp.getSource(),
        loadOp.getSourceDims(), loadOp.getMixedOffsets(),
        loadOp.getMixedSizes(), loadOp.getMixedStrides());
    SmallVector<Value> dynamicDims = llvm::to_vector(loadOp.getSizes());
    result =
        generateEncodingTransferOps(rewriter, result, dynamicDims, destType);
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  // Used to compare source/result types in their materialized forms.
  MaterializeEncodingTypeConverter &typeConverter;
};
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.store` operation when
/// materializing the encoding.
/// Pattern to decompose `iree_tensor_ext.dispatch.tensor.store` into encoding
/// transfer ops followed by a plain store when the value and target encodings
/// mismatch after type conversion.
struct DecomposeMismatchEncodingTensorStoreOp
    : public OpRewritePattern<IREE::TensorExt::DispatchTensorStoreOp> {
  using OpRewritePattern<
      IREE::TensorExt::DispatchTensorStoreOp>::OpRewritePattern;

  DecomposeMismatchEncodingTensorStoreOp(
      MaterializeEncodingTypeConverter &converter, MLIRContext *ctx,
      PatternBenefit benefit = 0)
      : OpRewritePattern<IREE::TensorExt::DispatchTensorStoreOp>(ctx, benefit),
        typeConverter(converter) {}

  LogicalResult matchAndRewrite(IREE::TensorExt::DispatchTensorStoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    // Partial stores would need offset/size-aware handling; bail out.
    if (!storeOp.isStoreToWholeTarget()) {
      return rewriter.notifyMatchFailure(storeOp, "unhandled partial stores");
    }
    IREE::TensorExt::DispatchTensorType targetType = storeOp.getTargetType();
    auto boundTensorType =
        dyn_cast<RankedTensorType>(targetType.getBoundType());
    if (!boundTensorType) {
      return rewriter.notifyMatchFailure(
          storeOp, "target bound type is not a RankedTensorType");
    }
    // Only decompose if there's an encoding involved. If neither the value
    // nor the target has an encoding, this pattern should not match.
    // This can happen when isStoreToWholeTarget() returns true but the store
    // reshapes the tensor (e.g., storing a 4D tensor to a 5D target).
    RankedTensorType valueType = storeOp.getValue().getType();
    if (!boundTensorType.getEncoding() && !valueType.getEncoding()) {
      return rewriter.notifyMatchFailure(
          storeOp, "no encoding involved in value or target");
    }
    // Similar to DecomposeMismatchEncodingTensorLoadOp, we have to check with
    // the bound type from converted DispatchTensorType.
    auto convertedTargetType =
        typeConverter.convertType<IREE::TensorExt::DispatchTensorType>(
            targetType);
    if (typeConverter.convertType(convertedTargetType.getBoundType()) ==
        typeConverter.convertType(valueType)) {
      return rewriter.notifyMatchFailure(
          storeOp, "the value type and the target type match");
    }
    LDBG() << "Performance warning: decomposing mismatched encoding store op: "
           << storeOp;
    // Bridge the encoding gap on the stored value, then store the whole
    // target at its bound type.
    Location loc = storeOp.getLoc();
    Value valueToStore = storeOp.getValue();
    SmallVector<Value> dynamicDims = llvm::to_vector(storeOp.getSizes());
    valueToStore = generateEncodingTransferOps(rewriter, valueToStore,
                                               dynamicDims, boundTensorType);
    IREE::TensorExt::DispatchTensorStoreOp::create(
        rewriter, loc, valueToStore, storeOp.getTarget(),
        storeOp.getTargetDims(), storeOp.getMixedOffsets(),
        storeOp.getMixedSizes(), storeOp.getMixedStrides());
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  // Used to compare value/target types in their materialized forms.
  MaterializeEncodingTypeConverter &typeConverter;
};
//===---------------------------------------------------------------------===//
// Patterns to lower ops with encodings. These are written as
// dialect conversion patterns for now. These are just drivers around
// the core conversion utilities.
//===---------------------------------------------------------------------===//
/// Generic pattern to convert an operation.
/// Generic pattern that converts `OpTy` by delegating to the matching
/// lowerOpWithEncoding overload.
template <typename OpTy>
struct MaterializeOperation : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *typeConverter =
        this->template getTypeConverter<MaterializeEncodingTypeConverter>();
    FailureOr<Operation *> materialized = lowerOpWithEncoding(
        rewriter, op, adaptor.getOperands(), *typeConverter);
    if (failed(materialized)) {
      return failure();
    }
    rewriter.replaceOp(op, *materialized);
    return success();
  }
};
/// Pattern that re-creates `util.optimization_barrier` with type-converted
/// operands when any operand carries an encoded tensor type.
struct MaterializeOptimizationBarrierOp
    : public OpConversionPattern<IREE::Util::OptimizationBarrierOp> {
  using OpConversionPattern<
      IREE::Util::OptimizationBarrierOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IREE::Util::OptimizationBarrierOp op,
                  IREE::Util::OptimizationBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite barriers that have at least one encoded tensor operand.
    auto hasEncodedTensor = [](Type type) -> bool {
      auto tensorType = dyn_cast<RankedTensorType>(type);
      return tensorType && tensorType.getEncoding();
    };
    if (!llvm::any_of(op.getOperandTypes(), hasEncodedTensor)) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<IREE::Util::OptimizationBarrierOp>(
        op, adaptor.getOperands());
    return success();
  }
};
/// Builds reassociation indices where the first `outerDims` dimensions map
/// one-to-one and each entry of `expandShape` groups its expanded dims.
static SmallVector<ReassociationIndices>
getReassociationIndices(int outerDims,
                        const TileSwizzle::ExpandShapeType &expandShape) {
  SmallVector<ReassociationIndices> reassociation;
  int64_t nextDim = 0;
  // Outer dimensions are not expanded: a single index each.
  for (int i = 0; i < outerDims; ++i) {
    reassociation.push_back({nextDim++});
  }
  // Each swizzle group corresponds to one original (collapsed) dimension.
  for (const auto &group : expandShape) {
    ReassociationIndices indices;
    for (size_t i = 0, e = group.size(); i < e; ++i) {
      indices.push_back(nextDim++);
    }
    reassociation.push_back(std::move(indices));
  }
  return reassociation;
}
/// Convert iree_linalg_ext.set_encoding op to pack + tile swizzling ops. We use
/// expand_shape + linalg.transpose to represent a tile swizzling op.
/// Convert iree_linalg_ext.set_encoding op to pack + tile swizzling ops. We use
/// expand_shape + linalg.transpose to represent a tile swizzling op.
struct SetEncodingOpLoweringConversion
    : public OpConversionPattern<IREE::Encoding::SetEncodingOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(IREE::Encoding::SetEncodingOp encodingOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Rank-0 tensors cannot be packed; just forward the converted source.
    if (encodingOp.getSource().getType().getRank() == 0) {
      rewriter.replaceOp(encodingOp, adaptor.getSource());
      return success();
    }
    auto converter = getTypeConverter<MaterializeEncodingTypeConverter>();
    auto packedValue = lowerSetEncodingOpToPackOp(
        rewriter, encodingOp, adaptor.getSource(), *converter);
    // On failure, fall back to forwarding the source unchanged.
    if (failed(packedValue)) {
      rewriter.replaceOp(encodingOp, adaptor.getSource());
      return success();
    }
    MaterializeEncodingInfo encodingInfo =
        converter->getEncodingInfo(encodingOp.getResultType());
    // No swizzle: the pack op alone materializes the encoding.
    if (!encodingInfo.swizzle) {
      rewriter.replaceOp(encodingOp, packedValue.value());
      return success();
    }
    Location loc = encodingOp.getLoc();
    // Create expand_shape op to tile the innermost two dimensions.
    int origRank = encodingOp.getSourceType().getRank();
    SmallVector<int64_t> expandShapeShape(
        cast<ShapedType>(packedValue->getType())
            .getShape()
            .take_front(origRank));
    expandShapeShape.append(
        getExpandedTileShape(encodingInfo.swizzle->expandShape));
    RankedTensorType expandShapeType =
        encodingOp.getSourceType().clone(expandShapeShape);
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
    auto expandShapeOp = tensor::ExpandShapeOp::create(
        rewriter, loc, expandShapeType, packedValue.value(), reassociation);
    // The transpose permutes only the expanded (tile) dimensions; the outer
    // `origRank` dimensions keep their order.
    SmallVector<int64_t> transposePerm =
        llvm::to_vector(llvm::seq<int64_t>(0, origRank));
    for (auto perm : encodingInfo.swizzle->permutation) {
      transposePerm.push_back(origRank + perm);
    }
    SmallVector<OpFoldResult> transposeResultDims =
        tensor::getMixedSizes(rewriter, loc, expandShapeOp.getResult());
    applyPermutationToVector(transposeResultDims, transposePerm);
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, transposeResultDims,
                                encodingOp.getSourceType().getElementType());
    auto transposeOp = linalg::TransposeOp::create(rewriter, loc, expandShapeOp,
                                                   emptyTensor, transposePerm);
    rewriter.replaceOp(encodingOp, transposeOp->getResult(0));
    return success();
  }
};
/// Convert iree_linalg_ext.unset_encoding op to the inverse of the swizzling
/// (transpose + collapse_shape) followed by an unpack op.
struct UnsetEncodingOpLoweringConversion
    : public OpConversionPattern<IREE::Encoding::UnsetEncodingOp> {
  using OpConversionPattern<
      IREE::Encoding::UnsetEncodingOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IREE::Encoding::UnsetEncodingOp unsetEncodingOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto converter = getTypeConverter<MaterializeEncodingTypeConverter>();
    MaterializeEncodingInfo encodingInfo =
        converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
    // Identity layout: nothing to undo; forward the converted source.
    if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
      rewriter.replaceOp(unsetEncodingOp, adaptor.getSource());
      return success();
    }
    Location loc = unsetEncodingOp.getLoc();
    Value unpackSrc = adaptor.getSource();
    if (encodingInfo.swizzle) {
      // Undo the swizzle first: inverse-transpose the expanded tile dims,
      // then collapse them back into the inner tile dimensions that
      // linalg.unpack expects.
      int targetRank = unsetEncodingOp.getResultType().getRank();
      auto srcConvertedType =
          cast<RankedTensorType>(adaptor.getSource().getType());
      SmallVector<OpFoldResult> emptyShape =
          tensor::getMixedSizes(rewriter, loc, adaptor.getSource());
      emptyShape.resize(targetRank);
      for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
        emptyShape.push_back(rewriter.getIndexAttr(i));
      }
      auto emptyTensor = tensor::EmptyOp::create(
          rewriter, loc, emptyShape,
          unsetEncodingOp.getSourceType().getElementType());
      // Build the forward permutation (as in SetEncoding lowering) and invert
      // it to restore the original tile order.
      SmallVector<int64_t> transposePerm =
          llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
      for (auto perm : encodingInfo.swizzle->permutation) {
        transposePerm.push_back(targetRank + perm);
      }
      auto invertedTransposePerm = invertPermutationVector(transposePerm);
      auto transposeOp =
          linalg::TransposeOp::create(rewriter, loc, adaptor.getSource(),
                                      emptyTensor, invertedTransposePerm);
      SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
          targetRank, encodingInfo.swizzle->expandShape);
      SmallVector<int64_t> unpackSrcShape(
          srcConvertedType.getShape().take_front(targetRank));
      unpackSrcShape.append(encodingInfo.innerTileSizes.begin(),
                            encodingInfo.innerTileSizes.end());
      RankedTensorType unpackSrcType =
          unsetEncodingOp.getResultType().clone(unpackSrcShape);
      unpackSrc = tensor::CollapseShapeOp::create(rewriter, loc, unpackSrcType,
                                                  transposeOp->getResult(0),
                                                  reassociation);
    }
    auto unpackedValue = lowerUnsetEncodingToUnpackOp(rewriter, unsetEncodingOp,
                                                      unpackSrc, *converter);
    // On failure, fall back to forwarding the source unchanged.
    if (failed(unpackedValue)) {
      rewriter.replaceOp(unsetEncodingOp, adaptor.getSource());
      return success();
    }
    rewriter.replaceOp(unsetEncodingOp, unpackedValue.value());
    return success();
  }
};
/// Pattern to rewrite linalg::LinalgOp by materializing its encoding using the
/// provided LayoutMaterializerAttr.
/// Pattern to rewrite linalg::LinalgOp by materializing its encoding using the
/// provided LayoutMaterializerAttr.
class MaterializeLinalgOp
    : public OpInterfaceConversionPattern<linalg::LinalgOp> {
public:
  MaterializeLinalgOp(const MaterializeEncodingTypeConverter &typeConverter,
                      MLIRContext *context, PatternBenefit benefit = 1)
      : OpInterfaceConversionPattern<linalg::LinalgOp>(typeConverter, context,
                                                       benefit) {}

  LogicalResult
  matchAndRewrite(linalg::LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto converter = getTypeConverter<MaterializeEncodingTypeConverter>();
    // Result types follow the DPS inits, converted through the encoding-aware
    // type converter.
    SmallVector<Type> newResultTypes;
    for (Value init : op.getDpsInits()) {
      newResultTypes.push_back(converter->convertType(init.getType()));
    }
    // Delegate the actual lowering to the layout attribute.
    IREE::Encoding::LayoutMaterializerAttr layoutAttr =
        converter->getLayoutAttr();
    Operation *materialized =
        layoutAttr.lowerOp(rewriter, op, newResultTypes, operands);
    rewriter.replaceOp(op, materialized->getResults());
    return success();
  }
};
/// Returns true if `type` is a RankedTensorType that carries an encoding
/// attribute.
static bool isRankedTensorTypeWithEncoding(Type type) {
  auto rankedTensorType = dyn_cast<RankedTensorType>(type);
  // Single boolean expression instead of the redundant `cond ? true : false`.
  return rankedTensorType && rankedTensorType.getEncoding();
}
/// Pattern that re-creates `func.return` with its operands converted to the
/// materialized (encoding-free) types.
struct MaterializeFuncReturnOp final
    : public OpConversionPattern<func::ReturnOp> {
  using OpConversionPattern<func::ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The adaptor already carries the type-converted operands.
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
    return success();
  }
};
} // namespace
void populateDecomposeMismatchedLayoutLoadStoreOpsPatterns(
    RewritePatternSet &patterns,
    MaterializeEncodingTypeConverter &typeConverter) {
  // Both decomposition patterns take the converter by reference so they can
  // compare types in their materialized forms.
  MLIRContext *context = patterns.getContext();
  patterns.insert<DecomposeMismatchEncodingTensorLoadOp>(typeConverter,
                                                         context);
  patterns.insert<DecomposeMismatchEncodingTensorStoreOp>(typeConverter,
                                                          context);
}
/// Configures `target` legality and registers all encoding-materialization
/// conversion patterns into `patterns`. An op involving a dispatch tensor
/// type is legal iff the type converter leaves that type unchanged.
void populateMaterializeEncodingPatterns(
    RewritePatternSet &patterns, MaterializeEncodingConversionTarget &target,
    MaterializeEncodingTypeConverter &typeConverter) {
  MLIRContext *context = patterns.getContext();
  target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp>(
      [&typeConverter](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
        auto resultType = dyn_cast<IREE::TensorExt::DispatchTensorType>(
            subspanOp.getResult().getType());
        // For types that are not `TensorExt::DispatchTensorType` mark as legal.
        if (!resultType) {
          return true;
        }
        return resultType == typeConverter.convertType(resultType);
      });
  // Encoding ops must always be lowered away by this conversion.
  target.addIllegalOp<IREE::Encoding::SetEncodingOp,
                      IREE::Encoding::UnsetEncodingOp>();
  target.addDynamicallyLegalOp<IREE::TensorExt::DispatchTensorStoreOp>(
      [&typeConverter](IREE::TensorExt::DispatchTensorStoreOp storeOp) {
        auto resultType = dyn_cast<IREE::TensorExt::DispatchTensorType>(
            storeOp.getTargetType());
        // For types that are not `TensorExt::DispatchTensorType` mark as legal.
        if (!resultType) {
          return true;
        }
        return resultType == typeConverter.convertType(resultType);
      });
  target.addDynamicallyLegalOp<IREE::TensorExt::DispatchTensorLoadOp>(
      [&typeConverter](IREE::TensorExt::DispatchTensorLoadOp loadOp) {
        auto resultType = dyn_cast<IREE::TensorExt::DispatchTensorType>(
            loadOp.getSourceType());
        // For types that are not `TensorExt::DispatchTensorType` mark as legal.
        if (!resultType) {
          return true;
        }
        return resultType == typeConverter.convertType(resultType);
      });
  // `func.return` is legal once no operand carries an encoded tensor type.
  target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp returnOp) {
    return llvm::none_of(returnOp.getOperandTypes(),
                         isRankedTensorTypeWithEncoding);
  });
  patterns.insert<MaterializeLinalgOp, SetEncodingOpLoweringConversion,
                  UnsetEncodingOpLoweringConversion,
                  MaterializeOperation<tensor::EmptyOp>,
                  MaterializeOptimizationBarrierOp,
                  MaterializeTensorExtDispatchTensorLoadOp,
                  MaterializeTensorExtDispatchTensorStoreOp,
                  MaterializeInterfaceBindingEncoding, MaterializeFuncReturnOp>(
      typeConverter, context);
} // Dropped the stray ';' that followed the function's closing brace.
} // namespace mlir::iree_compiler