blob: fdee32e3b67ef10b4ae1f562e24206aa9dc39dfb [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-util-optimize-int-arithmetic"
using llvm::dbgs;
using namespace mlir::dataflow;
namespace mlir::iree_compiler::IREE::Util {
#define GEN_PASS_DEF_OPTIMIZEINTARITHMETICPASS
#include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc"
namespace {
// An index_cast from i64 to index is a no-op on targets where index is
// 64 bits. But on targets where index is 32bits, it is a truncate. On these
// platforms, demoting to an index is only conservatively correct if all
// operands and all results are within the unsigned 32bit bounds.
// While there is a good chance that such arithmetic that exceeds these
// bounds is simply wrong/overflow-ridden, we opt to do no harm and preseve
// the exact results. This optimization is targeted at "small" sequences
// anyway and this catches everything known to exist. If needed, this rule
// could be dropped if it is ever appropriate to unconditionally assume
// 64bit semantics.
static constexpr uint64_t SAFE_INDEX_UNSIGNED_MAX_VALUE =
std::numeric_limits<uint32_t>::max();
//===----------------------------------------------------------------------===//
// Signed -> Unsigned patterns
// Note that there is an upstream UnsignedWhenEquivalent pass but it uses
// DialectConversion and legality vs simple patterns, so we cannot use it.
// Some support code has been adapted from that pass, though.
//===----------------------------------------------------------------------===//
/// Succeeds when a value is statically non-negative in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
static bool staticallyLegalToConvertToUnsigned(DataFlowSolver &solver,
Value v) {
auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
if (!result || result->getValue().isUninitialized()) {
return false;
}
const ConstantIntRanges &range = result->getValue().getValue();
bool isNonNegative = range.smin().isNonNegative();
Type type = v.getType();
if (isa<IndexType>(type)) {
bool canSafelyTruncate =
range.umin().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE &&
range.umax().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE;
return isNonNegative && canSafelyTruncate;
} else {
return isNonNegative;
}
}
/// Succeeds if an op can be converted to its unsigned equivalent without
/// changing its semantics. This is the case when none of its openands or
/// results can be below 0 when analyzed from a signed perspective.
static LogicalResult
staticallyLegalToConvertToUnsignedOp(DataFlowSolver &solver, Operation *op) {
auto nonNegativePred = [&solver](Value v) -> bool {
bool isNonNegative = staticallyLegalToConvertToUnsigned(solver, v);
return isNonNegative;
};
return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
llvm::all_of(op->getResults(), nonNegativePred));
}
template <typename Signed, typename Unsigned>
struct ConvertOpToUnsigned : public OpRewritePattern<Signed> {
ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &solver)
: OpRewritePattern<Signed>(context), solver(solver) {}
LogicalResult matchAndRewrite(Signed op,
PatternRewriter &rewriter) const override {
if (failed(staticallyLegalToConvertToUnsignedOp(solver, op)))
return failure();
rewriter.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
op->getOperands(), op->getAttrs());
return success();
}
DataFlowSolver &solver;
};
//===----------------------------------------------------------------------===//
// Int64 -> unsigned index demotion
// Torch does a lot of indexy manipulation using scalar i64 ops. We undo these
// here and treat them as index when safe to do so. Since the casts can block
// optimizations, it can be useful to eliminate them when possible.
//===----------------------------------------------------------------------===//
// Matches IR like:
// %5 = arith.addi %0, %1 : int64
// %6 = arith.index_castui %5 : int64 to index
//
// And moves the index_castui to the producer's operands:
// %3 = arith.index_castui %0 : int64 to index
// %4 = arith.index_castui %1 : int64 to index
// %5 = arith.addi %3, %4 : index
//
struct ConvertUnsignedI64IndexCastProducerToIndex
: public OpRewritePattern<arith::IndexCastUIOp> {
ConvertUnsignedI64IndexCastProducerToIndex(MLIRContext *context,
DataFlowSolver &solver)
: OpRewritePattern(context), solver(solver) {}
LogicalResult matchAndRewrite(arith::IndexCastUIOp origIndexOp,
PatternRewriter &rewriter) const override {
Type inType = origIndexOp.getIn().getType();
Type outType = origIndexOp.getOut().getType();
if (!inType.isSignlessInteger(64) && isa<IndexType>(outType))
return failure();
Operation *producer = origIndexOp.getIn().getDefiningOp();
if (!producer)
return failure();
auto producerResult = producer->getResult(0);
if (!producerResult.hasOneUse())
return failure();
auto pred = [&](Value v) -> bool {
auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
if (!result || result->getValue().isUninitialized()) {
return false;
}
const ConstantIntRanges &range = result->getValue().getValue();
bool isInBounds =
range.umin().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE &&
range.umax().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE;
return isInBounds;
};
auto isOpStaticallyLegal = [&](Operation *op) -> bool {
return llvm::all_of(op->getOperands(), pred) &&
llvm::all_of(op->getResults(), pred);
};
if (!isa_and_present<arith::AddIOp, arith::CeilDivUIOp, arith::DivUIOp,
arith::MaxUIOp, arith::MinUIOp, arith::MulIOp,
arith::RemUIOp, arith::SubIOp>(producer))
return failure();
if (!isOpStaticallyLegal(producer))
return failure();
// Make modifications.
rewriter.modifyOpInPlace(producer, [&]() {
rewriter.setInsertionPoint(producer);
for (auto &operand : producer->getOpOperands()) {
if (operand.get().getType() != inType)
continue;
Value newOperand = rewriter.create<arith::IndexCastUIOp>(
producer->getLoc(), outType, operand.get());
operand.set(newOperand);
}
producer->getResult(0).setType(outType);
});
origIndexOp.getOut().replaceAllUsesWith(producer->getResult(0));
rewriter.eraseOp(origIndexOp);
return success();
}
DataFlowSolver &solver;
};
//===----------------------------------------------------------------------===//
// Index -> int32 assumption narrowing
// If we're narrowing `index` values to `i32`, a `util.assume.int` on `index`
// introduces unnecessary zero-extensions and truncations to/from `index`
// when introducing assumptions.
//===----------------------------------------------------------------------===//
struct RemoveIndexCastForAssumeOfI32
: public OpRewritePattern<Util::AssumeIntOp> {
RemoveIndexCastForAssumeOfI32(MLIRContext *context, DataFlowSolver &solver)
: OpRewritePattern(context), solver(solver) {}
LogicalResult matchAndRewrite(Util::AssumeIntOp op,
PatternRewriter &rewriter) const override {
llvm::SmallBitVector needNarrowing(op.getNumOperands(), false);
for (auto [idx, arg] : llvm::enumerate(op.getOperands())) {
if (!arg.getType().isIndex())
continue;
auto castOp = arg.getDefiningOp<arith::IndexCastUIOp>();
if (!castOp)
continue;
Value castIn = castOp.getIn();
Type intType = castIn.getType();
if (intType.getIntOrFloatBitWidth() > 32)
continue;
needNarrowing[idx] = true;
}
if (needNarrowing.none())
return failure();
SmallVector<Value> newArgs;
newArgs.reserve(op.getNumOperands());
for (auto [idx, arg] : llvm::enumerate(op.getOperands())) {
if (!needNarrowing[idx]) {
newArgs.push_back(arg);
continue;
}
newArgs.push_back(arg.getDefiningOp<arith::IndexCastUIOp>().getIn());
}
ArrayAttr assumptions = op.getAssumptionsAttr();
auto newOp = rewriter.create<Util::AssumeIntOp>(
op.getLoc(), ValueTypeRange<ArrayRef<Value>>{newArgs}, newArgs,
assumptions);
SmallVector<Value> replacements(newOp.getResults());
for (auto [newRes, oldRes] :
llvm::zip_equal(replacements, op.getResults())) {
if (newRes.getType() != oldRes.getType()) {
newRes = rewriter.create<arith::IndexCastUIOp>(
op.getLoc(), oldRes.getType(), newRes);
}
// Preserve assumption state.
auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldRes);
if (oldState) {
(void)solver.getOrCreateState<IntegerValueRangeLattice>(newRes)->join(
*oldState);
}
}
rewriter.replaceOp(op, replacements);
return success();
}
DataFlowSolver &solver;
};
//===----------------------------------------------------------------------===//
// scf.for induction variable range narrowing
// If the induction variable of an scf.for can be represented as an I32,
// make that change to save on registers etc.
//===----------------------------------------------------------------------===//
struct NarrowSCFForIvToI32 : public OpRewritePattern<scf::ForOp> {
NarrowSCFForIvToI32(MLIRContext *context, DataFlowSolver &solver)
: OpRewritePattern(context), solver(solver) {}
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
Location loc = forOp.getLoc();
Value iv = forOp.getInductionVar();
Type srcType = iv.getType();
if (!srcType.isIndex() && !srcType.isInteger(64))
return rewriter.notifyMatchFailure(forOp, "IV isn't an index or i64");
if (!staticallyLegalToConvertToUnsigned(solver, iv))
return rewriter.notifyMatchFailure(forOp, "IV isn't non-negative");
if (!staticallyLegalToConvertToUnsigned(solver, forOp.getStep()))
return rewriter.notifyMatchFailure(forOp, "Step isn't non-negative");
auto *ivState = solver.lookupState<IntegerValueRangeLattice>(iv);
if (ivState->getValue().getValue().smax().getActiveBits() > 31)
return rewriter.notifyMatchFailure(forOp, "IV won't fit in signed int32");
Type i32 = rewriter.getI32Type();
auto doCastDown = [&](Value v) -> Value {
if (srcType.isIndex())
return rewriter.create<arith::IndexCastUIOp>(loc, i32, v);
else
return rewriter.create<arith::TruncIOp>(loc, i32, v);
};
Value newLb = doCastDown(forOp.getLowerBound());
Value newUb = doCastDown(forOp.getUpperBound());
Value newStep = doCastDown(forOp.getStep());
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(&forOp.getRegion().front());
Value castBackOp;
if (srcType.isIndex()) {
castBackOp =
rewriter.create<arith::IndexCastUIOp>(iv.getLoc(), srcType, iv);
} else {
castBackOp = rewriter.create<arith::ExtUIOp>(iv.getLoc(), srcType, iv);
}
(void)solver.getOrCreateState<IntegerValueRangeLattice>(castBackOp)
->join(*ivState);
rewriter.replaceAllUsesExcept(iv, castBackOp, castBackOp.getDefiningOp());
}
solver.eraseState(iv);
rewriter.modifyOpInPlace(forOp, [&]() {
iv.setType(i32);
forOp.getLowerBoundMutable().assign(newLb);
forOp.getUpperBoundMutable().assign(newUb);
forOp.getStepMutable().assign(newStep);
});
return success();
}
DataFlowSolver &solver;
};
//===----------------------------------------------------------------------===//
// Divisibility
//===----------------------------------------------------------------------===//
static LogicalResult getDivisibility(DataFlowSolver &solver, Operation *op,
Value value, PatternRewriter &rewriter,
ConstantIntDivisibility &out) {
auto *div = solver.lookupState<IntegerDivisibilityLattice>(value);
if (!div || div->getValue().isUninitialized())
return rewriter.notifyMatchFailure(op,
"divisibility could not be determined");
out = div->getValue().getValue();
LLVM_DEBUG(dbgs() << " * Resolved divisibility: " << out << "\n");
return success();
}
struct RemUIDivisibilityByConstant : public OpRewritePattern<arith::RemUIOp> {
RemUIDivisibilityByConstant(MLIRContext *context, DataFlowSolver &solver)
: OpRewritePattern(context), solver(solver) {}
LogicalResult matchAndRewrite(arith::RemUIOp op,
PatternRewriter &rewriter) const override {
APInt rhsConstant;
if (!matchPattern(op.getRhs(), m_ConstantInt(&rhsConstant)))
return rewriter.notifyMatchFailure(op, "rhs is not constant");
ConstantIntDivisibility lhsDiv;
if (failed(getDivisibility(solver, op, op.getLhs(), rewriter, lhsDiv)))
return failure();
uint64_t rhsValue = rhsConstant.getZExtValue();
if (rhsValue > 0 && lhsDiv.udiv() > 0) {
if (lhsDiv.udiv() % rhsValue != 0)
return rewriter.notifyMatchFailure(op, "rhs does not divide lhs");
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, rewriter.getZeroAttr(op.getResult().getType()));
return success();
}
return failure();
}
DataFlowSolver &solver;
};
//===----------------------------------------------------------------------===//
// Affine expansion
// affine.apply expansion can fail after producing a lot of IR. Since this is
// a bad thing to be doing as part of our overall iteration, we do it as a
// preprocessing walk. This also lets it be well behaved with respect to
// error messaging, etc. We will likely replace this with a more integrated
// version at some point which can use the bounds analysis to avoid corners
// of the original.
//===----------------------------------------------------------------------===//
void expandAffineOps(Operation *rootOp) {
IRRewriter rewriter(rootOp->getContext());
rootOp->walk([&](affine::AffineApplyOp op) {
LLVM_DEBUG(dbgs() << "** Expand affine.apply: " << op << "\n");
rewriter.setInsertionPoint(op);
auto maybeExpanded =
mlir::affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
llvm::to_vector<4>(op.getOperands()));
if (!maybeExpanded) {
LLVM_DEBUG(dbgs() << "** ERROR: Failed to expand affine.apply\n");
return;
}
rewriter.replaceOp(op, *maybeExpanded);
});
}
//===----------------------------------------------------------------------===//
// General optimization patterns
//===----------------------------------------------------------------------===//
struct ElideTruncOfIndexCast : public OpRewritePattern<arith::TruncIOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
PatternRewriter &rewriter) const override {
Operation *producer = truncOp.getOperand().getDefiningOp();
if (!producer)
return failure();
if (!isa<arith::IndexCastOp, arith::IndexCastUIOp>(producer))
return failure();
rewriter.replaceOpWithNewOp<arith::IndexCastUIOp>(
truncOp, truncOp.getResult().getType(), producer->getOperand(0));
return success();
}
};
//===----------------------------------------------------------------------===//
// Pass setup
//===----------------------------------------------------------------------===//
class DataFlowListener : public RewriterBase::Listener {
public:
DataFlowListener(DataFlowSolver &s) : s(s) {}
protected:
void notifyOperationErased(Operation *op) override {
s.eraseState(s.getProgramPointAfter(op));
for (Value res : op->getResults())
s.eraseState(res);
}
void notifyOperationModified(Operation *op) override {
s.eraseState(s.getProgramPointAfter(op));
for (Value res : op->getResults()) {
s.eraseState(res);
}
}
DataFlowSolver &s;
};
class OptimizeIntArithmeticPass
: public impl::OptimizeIntArithmeticPassBase<OptimizeIntArithmeticPass> {
using Base::Base;
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
expandAffineOps(op);
DataFlowSolver solver;
// Needed to make the dead code analyis not be too conservative.
solver.load<SparseConstantPropagation>();
solver.load<DeadCodeAnalysis>();
solver.load<IntegerRangeAnalysis>();
solver.load<IntegerDivisibilityAnalysis>();
DataFlowListener listener(solver);
RewritePatternSet patterns(ctx);
// Populate upstream arith patterns.
arith::populateIntRangeOptimizationsPatterns(patterns, solver);
if (narrowToI32) {
arith::populateIntRangeNarrowingPatterns(patterns, solver, {32});
patterns.add<NarrowSCFForIvToI32, RemoveIndexCastForAssumeOfI32>(ctx,
solver);
}
// Populate canonicalization patterns.
auto arithDialect = ctx->getOrLoadDialect<arith::ArithDialect>();
for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) {
if (&name.getDialect() == arithDialect)
name.getCanonicalizationPatterns(patterns, ctx);
}
// General optimization patterns.
patterns.add<ElideTruncOfIndexCast>(ctx);
// Populate unsigned conversion patterns.
patterns.add<ConvertUnsignedI64IndexCastProducerToIndex,
ConvertOpToUnsigned<arith::CeilDivSIOp, arith::CeilDivUIOp>,
ConvertOpToUnsigned<arith::DivSIOp, arith::DivUIOp>,
ConvertOpToUnsigned<arith::FloorDivSIOp, arith::DivUIOp>,
ConvertOpToUnsigned<arith::IndexCastOp, arith::IndexCastUIOp>,
ConvertOpToUnsigned<arith::RemSIOp, arith::RemUIOp>,
ConvertOpToUnsigned<arith::MinSIOp, arith::MinUIOp>,
ConvertOpToUnsigned<arith::MaxSIOp, arith::MaxUIOp>,
ConvertOpToUnsigned<arith::ExtSIOp, arith::ExtUIOp>>(ctx,
solver);
// Populate divisibility patterns.
patterns.add<RemUIDivisibilityByConstant>(ctx, solver);
GreedyRewriteConfig config;
// Results in fewer recursive data flow flushes/cycles on modification.
config.useTopDownTraversal = false;
config.listener = &listener;
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (int i = 0;; ++i) {
LLVM_DEBUG(dbgs() << " * Starting iteration: " << i << "\n");
if (failed(solver.initializeAndRun(op))) {
emitError(op->getLoc()) << "failed to perform int range analysis";
return signalPassFailure();
}
LLVM_DEBUG(
dbgs() << " * Finished Running Solver -- Applying Patterns\n");
bool changed = false;
if (failed(applyPatternsGreedily(op, frozenPatterns, config, &changed))) {
emitError(op->getLoc())
<< "int arithmetic optimization failed to converge on iteration "
<< i;
return signalPassFailure();
}
if (!changed)
break;
}
}
};
} // namespace
} // namespace mlir::iree_compiler::IREE::Util