blob: e53703186e2762a80b11d8141423530c32cdf31a [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {
/// Materializes the size of dimension `dim` of the shaped value `v` as an SSA
/// value, inserting a `tensor.dim` (for tensor-typed values) or a `memref.dim`
/// (for memref-typed values) at `loc` using `builder`.
///
/// Both ranked and unranked tensors/memrefs are dispatched to the matching dim
/// op. Passing a value of any other type is a programmer error.
Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) {
  return TypeSwitch<Type, Value>(v.getType())
      .Case<TensorType>([&](TensorType) -> Value {
        return builder.create<tensor::DimOp>(loc, v, dim);
      })
      .Case<BaseMemRefType>([&](BaseMemRefType) -> Value {
        return builder.create<memref::DimOp>(loc, v, dim);
      })
      .Default([](Type) -> Value {
        // Make the precondition explicit instead of silently falling off the
        // end of the TypeSwitch with no diagnostic context.
        llvm_unreachable("getDimValue: expected a tensor or memref value");
      });
}
/// Returns the size of dimension `dim` of the shaped value `v` as an
/// OpFoldResult.
///
/// - Static dimension: an `index`-typed IntegerAttr holding the size; no IR is
///   created.
/// - Dynamic dimension: the SSA value produced by getDimValue, which creates a
///   dim op at `loc`.
///
/// `v` must have a ShapedType (the cast asserts otherwise).
OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) {
  auto t = v.getType().cast<ShapedType>();
  if (t.isDynamicDim(dim)) {
    return getDimValue(builder, loc, v, dim);
  }
  // Use an index attribute so the static case has the same type as the
  // `index` result produced by tensor.dim / memref.dim in the dynamic case.
  return builder.getIndexAttr(t.getDimSize(dim));
}
/// Collects the size of every dimension of `shapedTypeValue`, from dimension 0
/// up to rank - 1, as OpFoldResults. Each entry is produced by getDim, so
/// static sizes become attributes and dynamic sizes become dim ops created at
/// `loc`, in dimension order.
SmallVector<OpFoldResult> getDims(OpBuilder &builder, Location loc,
                                  Value shapedTypeValue) {
  const int64_t rank =
      shapedTypeValue.getType().cast<ShapedType>().getRank();
  SmallVector<OpFoldResult> dimSizes;
  dimSizes.reserve(rank);
  for (int64_t idx = 0; idx < rank; ++idx)
    dimSizes.push_back(getDim(builder, loc, shapedTypeValue, idx));
  return dimSizes;
}
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir