blob: a0119395cc477bc618b59f2c07ab51e3a5afb4e8 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- ReductionLowering.cpp ------------------------------------*-C++//-*-===//
//
// Lower reduction dispatch regions to Linalg.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Pass to lower the reduction dispatch functions to Linalg.
struct HLOReductionToLinalgPass : public ModulePass<HLOReductionToLinalgPass> {
void runOnModule() override;
};
/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the `reductionDim`-th element.
// TODO(hanchung): Use helpers in StructuredOpsUtils.h instead of hardcoded
// strings once the build system is set up.
ArrayAttr getNLoopsAttrs(Builder b, unsigned nLoops, int reductionDim) {
SmallVector<Attribute, 3> attrs(nLoops, b.getStringAttr("parallel"));
attrs[reductionDim] = b.getStringAttr("reduction");
return b.getArrayAttr(attrs);
}
/// Base class for legalization of operations within the reduction apply
/// function (and the function itself).
template <typename OpTy>
class ApplyFunctionConversion : public OpConversionPattern<OpTy> {
public:
ApplyFunctionConversion(MLIRContext *context, TypeConverter &converter,
PatternBenefit benefit = 1)
: OpConversionPattern<OpTy>(context, benefit), converter(converter) {}
protected:
TypeConverter &converter;
};
/// The apply function has a signature (lhs, rhs) -> output, all of the same
/// tensor type t. This is converted to a function with the same signature but
/// with element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
/// be converted to "(f32, f32) -> f32".
struct ReductionApplyFnConversion final : ApplyFunctionConversion<FuncOp> {
using ApplyFunctionConversion<FuncOp>::ApplyFunctionConversion;
PatternMatchResult matchAndRewrite(
FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// Reduction entry function body
//===----------------------------------------------------------------------===//
/// Adds a body with linalg ops for the reduction entry function `fn`. `fn` is
/// assumed to be empty. The dimension to reduce must be set as an attribute on
/// `fn`.
static LogicalResult addReductionEntryFnBody(OpBuilder &builder, FuncOp fn,
FuncOp applyFn) {
if (fn.getNumArguments() != 3) return failure();
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(fn.addEntryBlock());
auto src = fn.getArgument(0);
auto initVal = fn.getArgument(1);
auto dst = fn.getArgument(2);
auto srcArgType = src.getType().template cast<ShapedType>();
unsigned nInputRank = srcArgType.getRank();
if (!nInputRank) return failure();
auto loc = fn.getLoc();
int reductionDim =
fn.getAttrOfType<IntegerAttr>("iree.executable.reduction.dimension")
.getInt();
// Prepare indexing maps for linalg generic op. The elements are for src,
// initial value and dst, respectively.
SmallVector<Attribute, 2> indexingMaps;
indexingMaps.emplace_back(
AffineMapAttr::get(builder.getMultiDimIdentityMap(nInputRank)));
indexingMaps.emplace_back(AffineMapAttr::get(AffineMap::get(
nInputRank, /*symbolCount=*/0, {builder.getAffineConstantExpr(0)})));
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < nInputRank; i++) {
if (i == reductionDim) continue;
exprs.push_back(builder.getAffineDimExpr(i));
}
indexingMaps.emplace_back(
AffineMapAttr::get(AffineMap::get(nInputRank, /*symbolCount=*/0, exprs)));
SmallVector<Type, 2> resultTypes = {};
SmallVector<Value, 2> linalgOpArgs = {src, initVal, dst};
auto linalgOp = builder.create<linalg::IndexedGenericOp>(
loc, resultTypes, linalgOpArgs,
builder.getI64IntegerAttr(2), // args_in
builder.getI64IntegerAttr(1), // args_out
builder.getArrayAttr(indexingMaps),
getNLoopsAttrs(builder, nInputRank, reductionDim),
/*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);
// Add a block to the region.
auto *region = &linalgOp.region();
auto *block = builder.createBlock(region, region->end());
Type indexType = builder.getIndexType();
Type elemType = srcArgType.getElementType();
for (int i = 0; i < nInputRank; ++i) block->addArguments(indexType);
for (int i = 0; i < linalgOp.getNumInputsAndOutputs(); ++i)
block->addArguments(elemType);
auto numArgs = block->getNumArguments();
auto blockSrcArg = block->getArgument(numArgs - 3);
auto blockInitArg = block->getArgument(numArgs - 2);
auto blockDstArg = block->getArgument(numArgs - 1);
auto zero = builder.create<ConstantOp>(loc, indexType,
builder.getIntegerAttr(indexType, 0));
auto cond = builder.create<CmpIOp>(loc, CmpIPredicate::eq,
block->getArgument(reductionDim),
zero.getResult());
auto lhs = builder.create<SelectOp>(loc, cond, blockInitArg, blockDstArg);
auto rhs = blockSrcArg;
// Inline the function call into the region.
// TODO(hanchung): Use the MLIR inline pass. This requires to implement Linalg
// call interface.
BlockAndValueMapping mapper;
assert(applyFn.getNumArguments() == 2);
mapper.map(applyFn.getArgument(0), lhs);
mapper.map(applyFn.getArgument(1), rhs);
assert(mlir::has_single_element(applyFn.getBlocks()) &&
"apply function with multiple blocks is not support");
applyFn.getBody().cloneInto(block->getParent(), mapper);
Block &newBlock = *block->getParent()->getBlocks().rbegin();
block->getOperations().splice(block->end(), newBlock.getOperations());
newBlock.erase();
// Replace the terminator with linalg::YieldOp is required by linalg.generic
// op.
Operation *terminator = block->getTerminator();
builder.create<linalg::YieldOp>(loc, terminator->getOperands());
terminator->erase();
builder.setInsertionPointAfter(linalgOp);
builder.create<ReturnOp>(loc);
return success();
}
//===----------------------------------------------------------------------===//
// Apply function conversion
//===----------------------------------------------------------------------===//
PatternMatchResult ReductionApplyFnConversion::matchAndRewrite(
FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
if (fnType.getNumInputs() != 2 || fnType.getNumResults() != 1)
return matchFailure();
if (fnType.getInput(0) != fnType.getInput(1) ||
fnType.getInput(0) != fnType.getResult(0))
return matchFailure();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
auto convertedType = converter.convertType(fnType.getInput(0));
if (!convertedType) return matchFailure();
signatureConverter.addInputs(0, convertedType);
signatureConverter.addInputs(1, convertedType);
auto newFn = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFn.getBody(), newFn.end());
newFn.setType(rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
convertedType));
rewriter.applySignatureConversion(&newFn.getBody(), signatureConverter);
rewriter.eraseOp(funcOp);
return matchSuccess();
}
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
/// Operations within the apply function need to be converted to standard ops.
template <typename OpTy>
struct ReductionOpConversion final : public ApplyFunctionConversion<OpTy> {
using ApplyFunctionConversion<OpTy>::ApplyFunctionConversion;
PatternMatchResult matchAndRewrite(
OpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (operands.size() != 2) return this->matchFailure();
SmallVector<Type, 1> resultElemTypes = {operands[0].getType()};
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<OpTy>(op, resultElemTypes,
operands, &rewriter);
rewriter.replaceOp(op, opResult);
return this->matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//
static LogicalResult getApplyFunctions(ModuleOp module,
SmallVectorImpl<Operation *> &applyFns) {
auto fns = module.getOps<FuncOp>();
SymbolTable table(module);
for (auto funcOp : fns) {
if (!funcOp.getAttr("iree.executable.reduction")) continue;
auto applyFnSymRef = funcOp.template getAttrOfType<FlatSymbolRefAttr>(
"iree.executable.reduction.apply");
auto applyFn = table.lookup<FuncOp>(applyFnSymRef.getValue());
if (!applyFn) return module.emitError("can't find the apply function");
applyFns.push_back(applyFn.getOperation());
}
return success();
}
//===----------------------------------------------------------------------===//
// Pattern builder
//===----------------------------------------------------------------------===//
static LogicalResult lowerReductionApplyFnToLinalg(MLIRContext *context,
ArrayRef<Operation *> fns) {
TypeConverter converter;
converter.addConversion([](Type type) {
return type.isSignlessIntOrFloat() ? type : Optional<Type>();
});
converter.addConversion([](RankedTensorType type) {
return type.getRank() == 0 ? type.getElementType() : Optional<Type>();
});
OwningRewritePatternList patterns;
patterns.insert<ReductionApplyFnConversion,
ReductionOpConversion<xla_hlo::AddOp> >(context, converter);
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
if (failed(applyFullConversion(fns, target, patterns))) return failure();
return success();
}
static LogicalResult lowerReductionEntryFnToLinalg(MLIRContext *context,
ModuleOp module) {
OwningRewritePatternList patterns;
OpBuilder builder(module.getBodyRegion());
SymbolTable table(module);
for (auto fn : module.getOps<FuncOp>()) {
if (!fn.getAttr("iree.executable.reduction") || !fn.empty()) continue;
auto applyFnSymRef = fn.template getAttrOfType<FlatSymbolRefAttr>(
"iree.executable.reduction.apply");
auto applyFn = table.lookup<FuncOp>(applyFnSymRef.getValue());
if (failed(addReductionEntryFnBody(builder, fn, applyFn))) return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// Pass for invoking the conversion
//===----------------------------------------------------------------------===//
void HLOReductionToLinalgPass::runOnModule() {
MLIRContext *context = &getContext();
auto module = getModule();
SmallVector<Operation *, 1> applyFns;
if (failed(getApplyFunctions(module, applyFns))) return signalPassFailure();
if (failed(lowerReductionApplyFnToLinalg(context, applyFns)))
return signalPassFailure();
if (failed(lowerReductionEntryFnToLinalg(context, module)))
return signalPassFailure();
}
static PassRegistration<HLOReductionToLinalgPass> pass(
"iree-hlo-reduction-to-linalg",
"Convert the reduction dispatch functions to Linalg");
std::unique_ptr<Pass> createHLOReductionToLinalgPass() {
return std::make_unique<HLOReductionToLinalgPass>();
}
} // namespace iree_compiler
} // namespace mlir