// blob: 1d558e049369d38a405d3405adb27d43bccc3b4b [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {
/// Returns an SSA value holding the size of dimension `dim` of `v`.
///
/// `v` must have a ShapedType. A static dimension is materialized as an
/// `arith::ConstantIndexOp`; a dynamic dimension is queried with
/// `tensor::DimOp` for ranked tensors or `memref::DimOp` for memrefs.
Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) {
  auto shapedType = v.getType().cast<ShapedType>();
  if (shapedType.isDynamicDim(dim)) {
    // Only ranked tensors and memrefs are handled for the dynamic case.
    return TypeSwitch<Type, Value>(v.getType())
        .Case<RankedTensorType>([&](RankedTensorType) -> Value {
          return builder.create<tensor::DimOp>(loc, v, dim);
        })
        .Case<MemRefType>([&](MemRefType) -> Value {
          return builder.create<memref::DimOp>(loc, v, dim);
        });
  }
  // The extent is known at compile time: emit it as an index constant.
  return builder.create<arith::ConstantIndexOp>(loc,
                                                shapedType.getDimSize(dim));
}
/// Returns the size of dimension `dim` of `v` as an OpFoldResult.
///
/// A static dimension is returned as an I64 integer attribute (no IR is
/// created); a dynamic dimension is returned as the SSA value produced by
/// `getDimValue`.
OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) {
  ShapedType shapedType = v.getType().cast<ShapedType>();
  if (!shapedType.isDynamicDim(dim))
    return builder.getI64IntegerAttr(shapedType.getDimSize(dim));
  return getDimValue(builder, loc, v, dim);
}
/// Returns the sizes of all dimensions of `shapedTypeValue`, in dimension
/// order, by calling `getDim` for each of them.
SmallVector<OpFoldResult> getDims(OpBuilder &builder, Location loc,
                                  Value shapedTypeValue) {
  const int64_t rank = shapedTypeValue.getType().cast<ShapedType>().getRank();
  SmallVector<OpFoldResult> dims;
  dims.reserve(rank);
  // Visit dimensions from 0 upwards so any ops that get created appear in
  // dimension order.
  for (int64_t dim = 0; dim < rank; ++dim)
    dims.push_back(getDim(builder, loc, shapedTypeValue, dim));
  return dims;
}
/// Computes an interchange vector from `dimsPos`.
///
/// The dimensions `0 .. rank - 1` are visited in increasing order; for every
/// dimension that occurs in `dimsPos`, the index at which it occurs is
/// appended to the result. Dimensions that do not occur in `dimsPos` are
/// skipped. For example, `dimsPos = [2, 0]` with `rank = 3` gives `[1, 0]`.
SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
                                                  int64_t rank) {
  // Record where each dimension sits inside `dimsPos`. For example,
  // dimsPos = [2, 0] produces:
  //   [
  //     [ key: 2, value: 0]
  //     [ key: 0, value: 1]
  //   ]
  // i.e. the key is the dimension listed in dimsPos and the value is the
  // index of that entry within dimsPos.
  DenseMap<int64_t, int64_t> positionOfDim;
  const int64_t numEntries = dimsPos.size();
  for (int64_t position = 0; position < numEntries; ++position)
    positionOfDim[dimsPos[position]] = position;

  // Walk the dimensions in order and collect the recorded positions of the
  // ones that are present.
  SmallVector<int64_t> interchange;
  interchange.reserve(dimsPos.size());
  for (int64_t dim = 0; dim < rank; ++dim) {
    auto it = positionOfDim.find(dim);
    if (it != positionOfDim.end())
      interchange.push_back(it->second);
  }
  return interchange;
}
/// Creates an `arith::ConstantOp` holding a `rows x cols` f32 tensor.
///
/// The element data is read from `val`, which is viewed as `rows * cols`
/// contiguous floats.
Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols,
                                Location loc, RewriterBase &rewriter) {
  SmallVector<int64_t> dims{rows, cols};
  auto tensorType = RankedTensorType::get(dims, rewriter.getF32Type());
  // Non-owning view over the caller's buffer.
  ArrayRef<float> elements(val, rows * cols);
  auto constantAttr = DenseFPElementsAttr::get(tensorType, elements);
  return rewriter.create<arith::ConstantOp>(loc, constantAttr);
}
/// Converts `ofrs` into a static shape in which every entry that is an SSA
/// value becomes `ShapedType::kDynamic`.
///
/// Attribute entries contribute their constant integer value, or
/// `ShapedType::kDynamic` when no constant integer can be extracted.
SmallVector<int64_t> asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> shape;
  shape.reserve(ofrs.size());
  for (OpFoldResult ofr : ofrs) {
    // The Value check must come before getConstantIntValue, because that
    // helper special-cases constants and we want any Value to be dynamic.
    if (ofr.dyn_cast<Value>()) {
      shape.push_back(ShapedType::kDynamic);
      continue;
    }
    shape.push_back(getConstantIntValue(ofr).value_or(ShapedType::kDynamic));
  }
  return shape;
}
/// Returns the inner tile sizes of `materializeEncodingInfo` as
/// OpFoldResults.
///
/// If none of the tile sizes is dynamic, they are all returned as attributes
/// and `materializeEncodingValueFn` is not called. Otherwise
/// `materializeEncodingValueFn` (which must be non-null) is invoked to obtain
/// SSA values for the dynamic tile sizes; its failure is propagated to the
/// caller.
FailureOr<SmallVector<OpFoldResult>>
getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
                     RankedTensorType tensorType,
                     const MaterializeEncodingInfo &materializeEncodingInfo,
                     MaterializeEncodingValueFn materializeEncodingValueFn) {
  ArrayRef<int64_t> tileSizes = materializeEncodingInfo.innerTileSizes;

  // Fast path: everything is static, so no IR needs to be created.
  const bool anyDynamicTile = llvm::any_of(
      tileSizes, [](int64_t size) { return ShapedType::isDynamic(size); });
  if (!anyDynamicTile)
    return getAsOpFoldResult(rewriter.getI64ArrayAttr(tileSizes));

  assert(materializeEncodingValueFn &&
         "When dynamic tile sizes are generated, a MaterializeEncodingValueFn "
         "should be provided.");

  FailureOr<MaterializeEncodingValueInfo> valueInfo =
      materializeEncodingValueFn(tensorType, rewriter, loc);
  if (failed(valueInfo))
    return failure();
  ArrayRef<Value> dynamicTileValues = valueInfo->innerTileSizes;

  SmallVector<OpFoldResult> tiles;
  tiles.reserve(tileSizes.size());
  for (size_t i = 0, e = tileSizes.size(); i != e; ++i) {
    const int64_t tile = tileSizes[i];
    // Dynamic tile size: take the SSA value supplied by the callback.
    if (ShapedType::isDynamic(tile)) {
      tiles.push_back(dynamicTileValues[i]);
      continue;
    }
    // Static tile size: emitted as an index constant op when the tensor
    // dimension with the same index is dynamic, as an attribute otherwise.
    // NOTE(review): `i` indexes the tile-size list but is also used as a
    // tensor dimension index here — confirm this pairing is intended.
    if (tensorType.isDynamicDim(i)) {
      Value tileConstant = rewriter.create<arith::ConstantIndexOp>(loc, tile);
      tiles.push_back(tileConstant);
      continue;
    }
    tiles.push_back(rewriter.getI64IntegerAttr(tile));
  }
  return tiles;
}
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir