| // Copyright 2022 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h" |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Builders.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace LinalgExt { |
| |
| Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) { |
| ShapedType type = v.getType().cast<ShapedType>(); |
| if (!type.isDynamicDim(dim)) { |
| return builder.create<arith::ConstantIndexOp>(loc, type.getDimSize(dim)); |
| } |
| return TypeSwitch<Type, Value>(v.getType()) |
| .Case<RankedTensorType>([&](RankedTensorType t) -> Value { |
| return builder.create<tensor::DimOp>(loc, v, dim); |
| }) |
| .Case<MemRefType>([&](MemRefType t) -> Value { |
| return builder.create<memref::DimOp>(loc, v, dim); |
| }); |
| } |
| |
| OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) { |
| auto t = v.getType().cast<ShapedType>(); |
| if (t.isDynamicDim(dim)) { |
| return getDimValue(builder, loc, v, dim); |
| } |
| return builder.getI64IntegerAttr(t.getDimSize(dim)); |
| } |
| |
| SmallVector<OpFoldResult> getDims(OpBuilder &builder, Location loc, |
| Value shapedTypeValue) { |
| return llvm::map_to_vector( |
| llvm::seq<int64_t>( |
| 0, shapedTypeValue.getType().cast<ShapedType>().getRank()), |
| [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); }); |
| } |
| |
| SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, |
| int64_t rank) { |
| SmallVector<int64_t> interchangeVector; |
| interchangeVector.reserve(dimsPos.size()); |
| // First map dims and their position. For example, dims_pos = [2, 0] will map |
| // to: |
| // [ |
| // [ key: 2, value: 0] |
| // [ key: 0, value: 1] |
| // ] |
| // where key is the idx in dims_pos while value its position in dims_pos. |
| DenseMap<int64_t, int64_t> dimsAndPosMapping; |
| for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) |
| dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx; |
| |
| // Scan the position in order and insert the value in the map |
| // to compute the interchange vector. |
| for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) { |
| if (dimsAndPosMapping.count(dimsIdx)) |
| interchangeVector.push_back(dimsAndPosMapping[dimsIdx]); |
| } |
| return interchangeVector; |
| } |
| |
| Value createValueFrom2DConstant(const float *val, int64_t rows, int64_t cols, |
| Location loc, RewriterBase &rewriter) { |
| ArrayRef<float> vector(val, rows * cols); |
| SmallVector<int64_t> shape{rows, cols}; |
| return rewriter.create<arith::ConstantOp>( |
| loc, DenseFPElementsAttr::get( |
| RankedTensorType::get(shape, rewriter.getF32Type()), vector)); |
| } |
| |
| SmallVector<int64_t> asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) { |
| SmallVector<int64_t> result; |
| for (auto o : ofrs) { |
| // Have to do this first, as getConstantIntValue special-cases constants. |
| if (o.dyn_cast<Value>()) |
| result.push_back(ShapedType::kDynamic); |
| else |
| result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic)); |
| } |
| return result; |
| } |
| |
| FailureOr<SmallVector<OpFoldResult>> |
| getInnerTileSizesOfr(OpBuilder &rewriter, Location loc, |
| RankedTensorType tensorType, |
| const MaterializeEncodingInfo &materializeEncodingInfo, |
| MaterializeEncodingValueFn materializeEncodingValueFn) { |
| ArrayRef<int64_t> staticTileSizes = materializeEncodingInfo.innerTileSizes; |
| if (llvm::all_of(staticTileSizes, |
| [](int64_t i) { return !ShapedType::isDynamic(i); })) { |
| return getAsOpFoldResult(rewriter.getI64ArrayAttr(staticTileSizes)); |
| } |
| assert(materializeEncodingValueFn && |
| "When dynamic tile sizes are generated, a MaterializeEncodingValueFn " |
| "should be provided."); |
| |
| FailureOr<MaterializeEncodingValueInfo> materializeEncodingValueInfo = |
| materializeEncodingValueFn(tensorType, rewriter, loc); |
| if (failed(materializeEncodingValueInfo)) { |
| return failure(); |
| } |
| ArrayRef<Value> innerTileSizeValues = |
| materializeEncodingValueInfo->innerTileSizes; |
| |
| SmallVector<OpFoldResult> result(staticTileSizes.size()); |
| for (size_t i = 0; i < result.size(); ++i) { |
| if (staticTileSizes[i] == ShapedType::kDynamic) { |
| result[i] = innerTileSizeValues[i]; |
| } else if (tensorType.isDynamicDim(i)) { |
| result[i] = |
| rewriter.create<arith::ConstantIndexOp>(loc, staticTileSizes[i]) |
| .getResult(); |
| } else { |
| result[i] = rewriter.getI64IntegerAttr(staticTileSizes[i]); |
| } |
| } |
| return result; |
| } |
| |
| } // namespace LinalgExt |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |