blob: 899c4fc2ffb66da955b65c41c3cb16ce73061fef [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
using namespace IREE::LinalgExt;
/// Materializes the operand list as the Values those operands currently
/// reference, keeping the original order.
OpOperandVector::operator SmallVector<Value>() {
  SmallVector<Value> values;
  values.reserve(this->size());
  for (OpOperand *operand : *this)
    values.push_back(operand->get());
  return values;
}
/// Verifies the invariants shared by all ops implementing LinalgExtOp.
///
/// - If the op produces no results, it must have buffer semantics.
/// - If the op produces results:
///   - it must have tensor semantics;
///   - it must have exactly one result per `outs` operand;
///   - result #i must have the same type as `outs` operand #i.
///
/// Returns failure (after emitting an op error) on the first violation found.
LogicalResult
IREE::LinalgExt::detail::verifyLinalgExtOpInterface(Operation *op) {
  LinalgExtOp linalgExtOp = cast<LinalgExtOp>(op);
  size_t numResults = op->getNumResults();

  // Result-less form: the op works on buffers.
  if (numResults == 0) {
    if (linalgExtOp.hasBufferSemantics())
      return success();
    return linalgExtOp.emitOpError(
        "expected inputs and outputs to be MemRefType or scalar");
  }

  // Value-producing form: the op works on tensors and its results mirror
  // the `outs` operands one-to-one.
  if (!linalgExtOp.hasTensorSemantics())
    return linalgExtOp.emitOpError(
        "expected inputs and outputs to be RankedTensorType or scalar");

  auto outs = linalgExtOp.outputs();
  if (numResults != outs.size())
    return linalgExtOp.emitOpError(
        "expected number of outputs to be same as the number of results");

  for (size_t idx = 0; idx < numResults; ++idx) {
    Type resultType = op->getResult(idx).getType();
    Type outputType = outs[idx].getType();
    if (resultType == outputType)
      continue;
    return linalgExtOp.emitOpError("expected type of `outs` operand #")
           << idx << " " << outputType << " to be same as result type "
           << resultType;
  }
  return success();
}
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOpInterfaces.cpp.inc" // IWYU pragma: export
/// Appends to `dimVals` one index Value per dimension of the shaped type `t`.
/// `t` is expected to be the type of `v`.
/// - A dynamic dimension becomes a `DimOpTy` that queries that dimension of `v`.
/// - A static dimension becomes an `arith::ConstantIndexOp` holding its extent.
/// New ops are created with `b` at `loc`.
template <typename Ty, typename DimOpTy>
static void getDimValues(OpBuilder &b, Location loc, Value v, Ty t,
                         SmallVector<Value> &dimVals) {
  auto shape = t.getShape();
  for (size_t i = 0, rank = shape.size(); i < rank; ++i) {
    auto extent = shape[i];
    if (!ShapedType::isDynamic(extent)) {
      // Static extent: no need to inspect `v`, emit the constant directly.
      dimVals.push_back(b.create<arith::ConstantIndexOp>(loc, extent));
      continue;
    }
    dimVals.push_back(b.create<DimOpTy>(loc, v, i));
  }
}
/// Reifies the shape of each `outs` operand and appends one entry per operand,
/// in order, to `reifiedReturnShapes`.
///
/// - Ranked tensor and memref outputs get one index Value per dimension,
///   built by `getDimValues` with `b`.
/// - Integer/index/float (scalar) outputs get an empty dimension list.
/// - Any other output type emits an op error and returns failure. Entries
///   already appended for earlier outputs are left in place.
LogicalResult LinalgExtOp::reifyResultShapes(
    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  Operation *op = getOperation();
  Location loc = op->getLoc();
  for (Value output : outputs()) {
    Type outputType = output.getType();
    SmallVector<Value> dims;
    if (auto tensorType = outputType.dyn_cast<RankedTensorType>())
      getDimValues<RankedTensorType, tensor::DimOp>(b, loc, output,
                                                    tensorType, dims);
    else if (auto bufferType = outputType.dyn_cast<MemRefType>())
      getDimValues<MemRefType, memref::DimOp>(b, loc, output, bufferType,
                                              dims);
    else if (!outputType.isIntOrIndexOrFloat())
      return op->emitOpError(
          "invalid type for output operand, expected tensor, "
          "memref or scalar type");
    // Scalars fall through with `dims` still empty (rank 0).
    reifiedReturnShapes.emplace_back(std::move(dims));
  }
  return success();
}