[GlobalOptimization] Lift `linalg.generic` ops to `linalg.batch_matmul/linalg.batch_vecmat/linalg.batch_matvec` (#15339)
This adds a pass that lifts generalized batch matmul `linalg.generic` ops
into the corresponding linalg named ops. Any element type promotion in
the `linalg.generic` op is rewritten as a producer op, and the element
types of the new linalg named op are the already-promoted types. This
allows ukernels with unsigned types, since the unsignedness of the
inputs can be inferred from the producer op.
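
For example (a minimal sketch; the shapes, element types, and indexing maps
are taken from the first lit test added below, and the `%fill`/`%init*`
values stand in for the `tensor.empty`/`linalg.fill` ops elided here), a
quantized vecmat-like generic is rewritten roughly as:

```mlir
// Before: element type promotion happens inside the generic body.
%2 = linalg.generic {...}
    ins(%arg0, %arg1 : tensor<32x128xi16>, tensor<11008x32x128xi4>)
    outs(%1 : tensor<11008x32xi32>) {
^bb0(%in: i16, %in_0: i4, %out: i32):
  %3 = arith.extsi %in : i16 to i32
  %4 = arith.extui %in_0 : i4 to i32
  %5 = arith.muli %3, %4 : i32
  %6 = arith.addi %5, %out : i32
  linalg.yield %6 : i32
} -> tensor<11008x32xi32>

// After: transposes and tensor-level promotion feed a named batch op.
%t   = linalg.transpose ins(%arg1 : tensor<11008x32x128xi4>)
         outs(%init0 : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
%lhs = arith.extsi %arg0 : tensor<32x128xi16> to tensor<32x128xi32>
%rhs = arith.extui %t : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
%vm  = linalg.batch_vecmat
         ins(%lhs, %rhs : tensor<32x128xi32>, tensor<32x128x11008xi32>)
         outs(%fill : tensor<32x11008xi32>) -> tensor<32x11008xi32>
%res = linalg.transpose ins(%vm : tensor<32x11008xi32>)
         outs(%init1 : tensor<11008x32xi32>) permutation = [1, 0]
```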
---------
Co-authored-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index 7bb6437..5abc9b4 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -49,6 +49,7 @@
"EraseUnusedLinalgOperands.cpp",
"ExpandVectors.cpp",
"FuseDequantizationMatmul.cpp",
+ "LiftGenericToTransposeBatchMatmul.cpp",
"MaterializeHomogeneousEncodings.cpp",
"Passes.cpp",
"RemoveZeroExtentTensors.cpp",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index ea706e1..4fe1abe 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -45,6 +45,7 @@
"EraseUnusedLinalgOperands.cpp"
"ExpandVectors.cpp"
"FuseDequantizationMatmul.cpp"
+ "LiftGenericToTransposeBatchMatmul.cpp"
"MaterializeHomogeneousEncodings.cpp"
"Passes.cpp"
"RemoveZeroExtentTensors.cpp"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp
new file mode 100644
index 0000000..a1b911b
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp
@@ -0,0 +1,410 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/GlobalOptimization/PassDetail.h"
+#include "iree/compiler/GlobalOptimization/Passes.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <sstream>
+
+#define DEBUG_TYPE "iree-global-opt-lift-generic-to-transpose-batch-matmul"
+
+namespace mlir {
+namespace iree_compiler {
+namespace GlobalOptimization {
+
+namespace {
+
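+// Returns true if `op` implements CastOpInterface and its single operand is a
+// block argument.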
+bool isCastOfBlockArgument(Operation *op) {
+ return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+ isa<BlockArgument>(op->getOperand(0));
+}
+
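+// Returns true if `input` is a block argument, or a cast of a block argument,
+// other than the output block argument at index `numInputs`.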
+bool isCastOrInputBlockArgument(Value input, int64_t numInputs) {
+ if (!input.isa<BlockArgument>()) {
+ Operation *castOp0 = input.getDefiningOp();
+ if (!castOp0 || !isCastOfBlockArgument(castOp0)) {
+ return false;
+ }
+ return castOp0->getOperand(0).cast<BlockArgument>().getArgNumber() !=
+ numInputs;
+ } else {
+ return input.cast<BlockArgument>().getArgNumber() != numInputs;
+ }
+}
+
+static bool isBlockArgumentAtIndex(Value input, int64_t index) {
+ return input.isa<BlockArgument>() &&
+ input.cast<BlockArgument>().getArgNumber() == index;
+}
+
+// Returns success if the linalg::GenericOp has a matmul-like body. This does
+// not check the indexing maps.
+//
+// This function looks for a body like:
+// ```mlir
+// ^bb0(%in: !lhs, %in_0: !rhs, %out: !out):
+// %3 = arith.extui %in : !lhs to !out
+// %4 = arith.extsi %in_0 : !rhs to !out
+// %5 = arith.muli %3, %4 : !out
+// %6 = arith.addi %5, %out : !out
+// linalg.yield %6 : !out
+// ```
+// Ensuring the following conditions:
+// 1) linalg.yield result comes from an arith.add op, accumulating on %out
+// 2) The other arith.add operand comes from arith.mul
+// 3) Both arith.mul operands are either block input arguments, or produced
+// by a `CastOpInterface` of a block input argument
+static LogicalResult hasMatmulBody(RewriterBase &rewriter,
+ linalg::GenericOp genericOp) {
+ int numInputs = genericOp.getNumDpsInputs();
+ if (numInputs != 2) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "op does not have exactly 2 inputs\n");
+ }
+ int numOutputs = genericOp.getNumDpsInits();
+ if (numOutputs != 1) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "op does not have exactly 1 output\n");
+ }
+
+ auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
+ Value yieldedValue = yieldOp->getOperand(0);
+
+ // Check that yielded value is an arith.add op, and is accumulating
+ Operation *addOp = yieldedValue.getDefiningOp();
+ if (!addOp) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "linalg.yield operand has no defining op\n");
+ }
+ if (!isa<arith::AddFOp, arith::AddIOp>(*addOp)) {
+ return rewriter.notifyMatchFailure(genericOp, "no arith.add body op\n");
+ }
+ Value add0 = addOp->getOperand(0);
+ Value add1 = addOp->getOperand(1);
+  if (!isBlockArgumentAtIndex(add0, numInputs) &&
+      !isBlockArgumentAtIndex(add1, numInputs)) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "arith.add body op not accumulating on output\n");
+  }
+
+ // Check that the producer of the add is an arith.mul op
+ Operation *mulOp =
+ add0.isa<BlockArgument>() ? add1.getDefiningOp() : add0.getDefiningOp();
+ if (!mulOp) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "arith.add operand has no defining op\n");
+ }
+ if (!isa<arith::MulFOp, arith::MulIOp>(*mulOp)) {
+ return rewriter.notifyMatchFailure(genericOp, "no arith.mul body op\n");
+ }
+
+ // Check that non block args come from arith.ext ops
+ if (!isCastOrInputBlockArgument(mulOp->getOperand(0), numInputs) ||
+ !isCastOrInputBlockArgument(mulOp->getOperand(1), numInputs)) {
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "arith.mul operands are not CastOpInterface or BlockArgument\n");
+ }
+ return success();
+}
+
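+// Creates a linalg.transpose of `input` with permutation `perm`, or returns
+// `input` unchanged when `perm` is empty or the identity permutation.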
+static Value transposeTensor(Location loc, PatternRewriter &rewriter,
+ Value input, SmallVector<int64_t> perm) {
+ if (!perm.size()) {
+ return input;
+ }
+ if (llvm::all_of(llvm::enumerate(perm),
+ [](auto idx) { return idx.index() == idx.value(); })) {
+ return input;
+ }
+ auto inputType = cast<RankedTensorType>(input.getType());
+ SmallVector<OpFoldResult> inputMixedSizes =
+ tensor::getMixedSizes(rewriter, loc, input);
+ SmallVector<OpFoldResult> newInputMixedSizes =
+ applyPermutation(inputMixedSizes, perm);
+ Value init = rewriter.create<tensor::EmptyOp>(loc, newInputMixedSizes,
+ inputType.getElementType());
+ return rewriter.create<linalg::TransposeOp>(loc, input, init, perm)
+ .getResults()[0];
+}
+
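+// Hoists the element type promotion applied to the input at `inputIdx` inside
+// the body of `genericOp` into a tensor-level cast of `input`. Returns `input`
+// unchanged when the element types already match, and failure when no body
+// cast of the corresponding block argument is found.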
+static FailureOr<Value> castTensor(Location loc, PatternRewriter &rewriter,
+ linalg::GenericOp genericOp,
+ int64_t inputIdx, Value input) {
+ Value output = genericOp.getResults()[0];
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto outputType = cast<RankedTensorType>(output.getType());
+ if (inputType.getElementType() == outputType.getElementType()) {
+ return input;
+ }
+ auto castedType =
+ RankedTensorType::get(inputType.getShape(), outputType.getElementType());
+ for (auto bodyOp : genericOp.getBody()->getOps<CastOpInterface>()) {
+ Value castInput = bodyOp->getOperand(0);
+ if (isBlockArgumentAtIndex(castInput, inputIdx)) {
+ return rewriter
+ .create(bodyOp->getLoc(), bodyOp->getName().getIdentifier(), input,
+ castedType, bodyOp->getAttrs())
+ ->getResult(0);
+ }
+ }
+ return failure();
+}
+
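+// Replaces `genericOp` with a named batch op of type `OpTy`: the operands are
+// transposed by `lhsPerm`/`rhsPerm`, any body element type promotion is
+// hoisted onto the transposed operands, the named op accumulates into a new
+// zero-filled init, and its result is transposed by `outPerm` back to the
+// original output layout.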
+template <typename OpTy>
+static LogicalResult
+liftGenericOp(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+ SmallVector<int64_t> lhsPerm, SmallVector<int64_t> rhsPerm,
+ SmallVector<int64_t> outPerm) {
+ static_assert((std::is_same<OpTy, linalg::BatchVecmatOp>::value ||
+ std::is_same<OpTy, linalg::BatchMatvecOp>::value ||
+ std::is_same<OpTy, linalg::BatchMatmulOp>::value) &&
+ "expected only BatchVecmatOp, BatchMatvecOp, or BatchMatmulOp");
+ Location loc = genericOp.getLoc();
+ Value transposedLhs =
+ transposeTensor(loc, rewriter, genericOp.getInputs()[0], lhsPerm);
+ Value transposedRhs =
+ transposeTensor(loc, rewriter, genericOp.getInputs()[1], rhsPerm);
+ FailureOr<Value> extendedLhs =
+ castTensor(loc, rewriter, genericOp, 0, transposedLhs);
+ FailureOr<Value> extendedRhs =
+ castTensor(loc, rewriter, genericOp, 1, transposedRhs);
+ if (failed(extendedLhs) || failed(extendedRhs)) {
+ return failure();
+ }
+ Value genericInit = genericOp.getDpsInitOperand(0)->get();
+ SmallVector<OpFoldResult> genericMixedSizes =
+ tensor::getMixedSizes(rewriter, loc, genericInit);
+ SmallVector<OpFoldResult> batchMixedSizes =
+ applyPermutation(genericMixedSizes, invertPermutationVector(outPerm));
+ Value out = genericOp.getResults()[0];
+ auto outType = cast<RankedTensorType>(out.getType());
+ Value batchEmpty = rewriter.create<tensor::EmptyOp>(loc, batchMixedSizes,
+ outType.getElementType());
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(outType.getElementType()));
+ Value batchInit =
+ rewriter.create<linalg::FillOp>(loc, zero, batchEmpty).getResult(0);
+ OpTy batchOp = rewriter.create<OpTy>(
+ loc, TypeRange{batchInit.getType()},
+ ValueRange{extendedLhs.value(), extendedRhs.value()},
+ ValueRange{batchInit});
+ Value newOut = transposeTensor(loc, rewriter, batchOp.getResult(0), outPerm);
+ rewriter.replaceOp(genericOp, newOut);
+ return success();
+}
+
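+// Matches a generic op with batch vecmat contraction dims (one batch, n, and k
+// dim, no m dim) and computes the operand/result permutations needed to lift
+// it to linalg.batch_vecmat.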
+static LogicalResult
+liftToBatchVecmat(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+ linalg::ContractionDimensions contractionDims) {
+ if (contractionDims.batch.size() != 1 || contractionDims.m.size() != 0 ||
+ contractionDims.n.size() != 1 || contractionDims.k.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "expected batch vecmat contraction dims\n\n");
+ }
+ AffineMap vecMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+ AffineMap matMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(1));
+ AffineMap outMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+ if (vecMap.getNumResults() != 2 || matMap.getNumResults() != 3 ||
+ outMap.getNumResults() != 2) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "wrong numResults for indexing maps\n\n");
+ }
+
+ auto getResultIndex = [&](AffineMap map, int64_t dimIndex) {
+ return *(map.getResultPosition(rewriter.getAffineDimExpr(dimIndex)));
+ };
+ // Permutation from GenericOp lhs shape to BatchVecmatOp lhs shape
+ SmallVector<int64_t> vecPerm;
+ vecPerm.push_back(getResultIndex(vecMap, contractionDims.batch[0]));
+ vecPerm.push_back(getResultIndex(vecMap, contractionDims.k[0]));
+ // Permutation from GenericOp rhs shape to BatchVecmatOp rhs shape
+ SmallVector<int64_t> matPerm;
+ matPerm.push_back(getResultIndex(matMap, contractionDims.batch[0]));
+ matPerm.push_back(getResultIndex(matMap, contractionDims.k[0]));
+ matPerm.push_back(getResultIndex(matMap, contractionDims.n[0]));
+ // Permutation from BatchVecmatOp result shape to GenericOp result shape
+ SmallVector<int64_t> outPerm;
+ outPerm.push_back(getResultIndex(outMap, contractionDims.batch[0]));
+ outPerm.push_back(getResultIndex(outMap, contractionDims.n[0]));
+ outPerm = invertPermutationVector(outPerm);
+ return liftGenericOp<linalg::BatchVecmatOp>(rewriter, genericOp, vecPerm,
+ matPerm, outPerm);
+}
+
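+// Matches a generic op with batch matvec contraction dims (one batch, m, and k
+// dim, no n dim) and computes the operand/result permutations needed to lift
+// it to linalg.batch_matvec.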
+static LogicalResult
+liftToBatchMatvec(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+ linalg::ContractionDimensions contractionDims) {
+ if (contractionDims.batch.size() != 1 || contractionDims.m.size() != 1 ||
+ contractionDims.n.size() != 0 || contractionDims.k.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "expected batch matvec contraction dims\n\n");
+ }
+ AffineMap matMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+ AffineMap vecMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(1));
+ AffineMap outMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+ if (vecMap.getNumResults() != 2 || matMap.getNumResults() != 3 ||
+ outMap.getNumResults() != 2) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "wrong numResults for indexing maps\n\n");
+ }
+
+ auto getResultIndex = [&](AffineMap map, int64_t dimIndex) {
+ return *(map.getResultPosition(rewriter.getAffineDimExpr(dimIndex)));
+ };
+ // Permutation from GenericOp lhs shape to BatchMatvecOp lhs shape
+ SmallVector<int64_t> matPerm;
+ matPerm.push_back(getResultIndex(matMap, contractionDims.batch[0]));
+ matPerm.push_back(getResultIndex(matMap, contractionDims.m[0]));
+ matPerm.push_back(getResultIndex(matMap, contractionDims.k[0]));
+ // Permutation from GenericOp rhs shape to BatchMatvecOp rhs shape
+ SmallVector<int64_t> vecPerm;
+ vecPerm.push_back(getResultIndex(vecMap, contractionDims.batch[0]));
+ vecPerm.push_back(getResultIndex(vecMap, contractionDims.k[0]));
+ // Permutation from BatchMatvecOp result shape to GenericOp result shape
+ SmallVector<int64_t> outPerm;
+ outPerm.push_back(getResultIndex(outMap, contractionDims.batch[0]));
+ outPerm.push_back(getResultIndex(outMap, contractionDims.m[0]));
+ outPerm = invertPermutationVector(outPerm);
+ return liftGenericOp<linalg::BatchMatvecOp>(rewriter, genericOp, matPerm,
+ vecPerm, outPerm);
+}
+
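+// Matches a generic op with batch matmul contraction dims (one batch, m, n,
+// and k dim) and computes the operand/result permutations needed to lift it to
+// linalg.batch_matmul.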
+static LogicalResult
+liftToBatchMatmul(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+ linalg::ContractionDimensions contractionDims) {
+ if (contractionDims.batch.size() != 1 || contractionDims.m.size() != 1 ||
+ contractionDims.n.size() != 1 || contractionDims.k.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "expected batch matmul contraction dims\n\n");
+ }
+ AffineMap lhsMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+ AffineMap rhsMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(1));
+ AffineMap outMap =
+ genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+ if (lhsMap.getNumResults() != 3 || rhsMap.getNumResults() != 3 ||
+ outMap.getNumResults() != 3) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "wrong numResults for indexing maps\n\n");
+ }
+
+ auto getResultIndex = [&](AffineMap map, int64_t dimIndex) {
+ return *(map.getResultPosition(rewriter.getAffineDimExpr(dimIndex)));
+ };
+ // Permutation from GenericOp lhs shape to BatchMatmulOp lhs shape
+ SmallVector<int64_t> lhsPerm;
+ lhsPerm.push_back(getResultIndex(lhsMap, contractionDims.batch[0]));
+ lhsPerm.push_back(getResultIndex(lhsMap, contractionDims.m[0]));
+ lhsPerm.push_back(getResultIndex(lhsMap, contractionDims.k[0]));
+ // Permutation from GenericOp rhs shape to BatchMatmulOp rhs shape
+ SmallVector<int64_t> rhsPerm;
+ rhsPerm.push_back(getResultIndex(rhsMap, contractionDims.batch[0]));
+ rhsPerm.push_back(getResultIndex(rhsMap, contractionDims.k[0]));
+ rhsPerm.push_back(getResultIndex(rhsMap, contractionDims.n[0]));
+ // Permutation from BatchMatmulOp result shape to GenericOp result shape
+ SmallVector<int64_t> outPerm;
+ outPerm.push_back(getResultIndex(outMap, contractionDims.batch[0]));
+ outPerm.push_back(getResultIndex(outMap, contractionDims.m[0]));
+ outPerm.push_back(getResultIndex(outMap, contractionDims.n[0]));
+ outPerm = invertPermutationVector(outPerm);
+ return liftGenericOp<linalg::BatchMatmulOp>(rewriter, genericOp, lhsPerm,
+ rhsPerm, outPerm);
+}
+
+// Converts linalg.generic op to linalg.batch_matmul, linalg.batch_matvec,
+// or linalg.batch_vecmat, plus linalg.transpose ops on the inputs
+class LiftGenericToTransposeBatchMatmul
+ : public OpRewritePattern<linalg::GenericOp> {
+public:
+ using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ FailureOr<linalg::ContractionDimensions> contractionDims =
+ linalg::inferContractionDims(genericOp);
+ if (failed(contractionDims)) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "failed to infer contraction dims\n\n");
+ }
+
+ auto lhsType =
+ dyn_cast<RankedTensorType>(genericOp.getOperands()[0].getType());
+ auto rhsType =
+ dyn_cast<RankedTensorType>(genericOp.getOperands()[1].getType());
+ auto outType =
+ dyn_cast<RankedTensorType>(genericOp.getResults()[0].getType());
+ if (!lhsType || !rhsType || !outType) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "Operands do not have RankedTensorType\n\n");
+ }
+
+ if (failed(hasMatmulBody(rewriter, genericOp))) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "genericOp does not have a matmul body\n\n");
+ }
+
+ // TODO(#15373) Support non-batch cases
+    if (succeeded(liftToBatchVecmat(rewriter, genericOp, *contractionDims))) {
+      return success();
+    }
+    if (succeeded(liftToBatchMatvec(rewriter, genericOp, *contractionDims))) {
+      return success();
+    }
+    if (succeeded(liftToBatchMatmul(rewriter, genericOp, *contractionDims))) {
+      return success();
+    }
+ return failure();
+ }
+};
+
+struct LiftGenericToTransposeBatchMatmulPass
+ : public LiftGenericToTransposeBatchMatmulBase<
+ LiftGenericToTransposeBatchMatmulPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<LiftGenericToTransposeBatchMatmul>(context);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> createLiftGenericToTransposeBatchMatmulPass() {
+ return std::make_unique<LiftGenericToTransposeBatchMatmulPass>();
+}
+
+} // namespace GlobalOptimization
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index f75f916..f56b0c3 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -54,7 +54,8 @@
.addPass(createRemoveZeroExtentTensorsPass)
.addPass(createDetachElementwiseFromNamedOpsPass)
.addPass(mlir::createLinalgNamedOpConversionPass)
- .addPass(createConvert1X1FilterConv2DToMatmulPass);
+ .addPass(createConvert1X1FilterConv2DToMatmulPass)
+ .addPass(createLiftGenericToTransposeBatchMatmulPass);
mainPassManager.addPass(createEraseUnusedLinalgOperands());
// Expand tensor shapes into SSA values and optimize the whole program.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index 37f9898..bd6af04 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -75,6 +75,10 @@
// Sets encoding for tensors to allow tiled execution of operations.
std::unique_ptr<Pass> createSetEncodingPass();
+// Convert linalg.generic ops to linalg.batch_matmul, linalg.batch_matvec, or
+// linalg.batch_vecmat, possibly with transposes on operands/result.
+std::unique_ptr<Pass> createLiftGenericToTransposeBatchMatmulPass();
+
void registerGlobalOptimizationPipeline();
} // namespace GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index b7a2fe5..babe27b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -60,4 +60,10 @@
let constructor = "mlir::iree_compiler::GlobalOptimization::createSetEncodingPass()";
}
+def LiftGenericToTransposeBatchMatmul:
+  Pass<"iree-global-opt-lift-generic-to-transpose-batch-matmul", ""> {
+  let summary = "Convert linalg.generic ops to linalg.batch_matmul, linalg.batch_matvec, or linalg.batch_vecmat, possibly with transposes on operands/result.";
+ let constructor = "mlir::iree_compiler::GlobalOptimization::createLiftGenericToTransposeBatchMatmulPass()";
+}
+
#endif // IREE_COMPILER_GLOBALOPTIMIZATION_PASSES
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
index fe7bef2..9c7122b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
@@ -20,6 +20,7 @@
"detach_elementwise_from_named_ops.mlir",
"expand_vectors.mlir",
"fuse_dequantization_matmul.mlir",
+ "lift_generic_to_transpose_batch_matmul.mlir",
"materialize_homogeneous_encodings.mlir",
"remove_zero_extent_tensors.mlir",
"set_encoding.mlir",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
index 561cb72..39c8459 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
@@ -18,6 +18,7 @@
"detach_elementwise_from_named_ops.mlir"
"expand_vectors.mlir"
"fuse_dequantization_matmul.mlir"
+ "lift_generic_to_transpose_batch_matmul.mlir"
"materialize_homogeneous_encodings.mlir"
"remove_zero_extent_tensors.mlir"
"set_encoding.mlir"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir
new file mode 100644
index 0000000..815f815
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir
@@ -0,0 +1,162 @@
+// RUN: iree-opt --iree-global-opt-lift-generic-to-transpose-batch-matmul --canonicalize --cse --split-input-file %s | FileCheck %s
+
+module {
+ func.func @raise_batch_vecmat(%arg0: tensor<32x128xi16>, %arg1: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %0 = tensor.empty() : tensor<11008x32xi32>
+ %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<32x128xi16>, tensor<11008x32x128xi4>)
+ outs(%1 : tensor<11008x32xi32>) {
+ ^bb0(%in: i16, %in_0: i4, %out: i32):
+ %3 = arith.extsi %in : i16 to i32
+ %4 = arith.extui %in_0 : i4 to i32
+ %5 = arith.muli %3, %4 : i32
+ %6 = arith.addi %5, %out : i32
+ linalg.yield %6 : i32
+ } -> tensor<11008x32xi32>
+ return %2 : tensor<11008x32xi32>
+ }
+}
+
+// CHECK: func @raise_batch_vecmat(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<32x128xi16>, %[[ARG1:.+]]: tensor<11008x32x128xi4>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<32x128x11008xi4>
+// CHECK-DAG: %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x32x128xi4>) outs(%[[INIT1]] : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
+// CHECK-DAG: %[[EXTSI:.+]] = arith.extsi %[[ARG0]] : tensor<32x128xi16> to tensor<32x128xi32>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE0]] : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
+// CHECK: %[[VECMAT:.+]] = linalg.batch_vecmat ins(%[[EXTSI]], %[[EXTUI]] : tensor<32x128xi32>, tensor<32x128x11008xi32>) outs(%[[FILL]] : tensor<32x11008xi32>)
+// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<11008x32xi32>
+// CHECK: %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[VECMAT]] : tensor<32x11008xi32>) outs(%[[INIT2]] : tensor<11008x32xi32>) permutation = [1, 0]
+// CHECK: return %[[TRANSPOSE1]]
+
+// -----
+
+module {
+ func.func @raise_batch_matvec(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<128x32xi16>) -> tensor<11008x32xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %0 = tensor.empty() : tensor<11008x32xi32>
+ %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<11008x32x128xi4>, tensor<128x32xi16>)
+ outs(%1 : tensor<11008x32xi32>) {
+ ^bb0(%in: i4, %in_0: i16, %out: i32):
+ %3 = arith.extui %in : i4 to i32
+ %4 = arith.extsi %in_0 : i16 to i32
+ %5 = arith.muli %3, %4 : i32
+ %6 = arith.addi %5, %out : i32
+ linalg.yield %6 : i32
+ } -> tensor<11008x32xi32>
+ return %2 : tensor<11008x32xi32>
+ }
+}
+
+// CHECK: func @raise_batch_matvec(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<11008x32x128xi4>, %[[ARG1:.+]]: tensor<128x32xi16>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<32x11008x128xi4>
+// CHECK-DAG: %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<11008x32x128xi4>) outs(%[[INIT1]] : tensor<32x11008x128xi4>) permutation = [1, 0, 2]
+// CHECK-DAG: %[[INIT2:.+]] = tensor.empty() : tensor<32x128xi16>
+// CHECK-DAG: %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<128x32xi16>) outs(%[[INIT2]] : tensor<32x128xi16>) permutation = [1, 0]
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE0]] : tensor<32x11008x128xi4> to tensor<32x11008x128xi32>
+// CHECK-DAG: %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE1]] : tensor<32x128xi16> to tensor<32x128xi32>
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matvec ins(%[[EXTUI]], %[[EXTSI]] : tensor<32x11008x128xi32>, tensor<32x128xi32>) outs(%[[FILL]] : tensor<32x11008xi32>)
+// CHECK: %[[INIT3:.+]] = tensor.empty() : tensor<11008x32xi32>
+// CHECK: %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<32x11008xi32>) outs(%[[INIT3]] : tensor<11008x32xi32>) permutation = [1, 0]
+// CHECK: return %[[TRANSPOSE2]]
+
+// -----
+
+module {
+ func.func @raise_batch_matmul(%arg0: tensor<8x32x128xi16>, %arg1: tensor<11008x32x128xi4>) -> tensor<11008x32x8xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %0 = tensor.empty() : tensor<11008x32x8xi32>
+ %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32x8xi32>) -> tensor<11008x32x8xi32>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+ ins(%arg0, %arg1 : tensor<8x32x128xi16>, tensor<11008x32x128xi4>)
+ outs(%1 : tensor<11008x32x8xi32>) {
+ ^bb0(%in: i16, %in_0: i4, %out: i32):
+ %3 = arith.extsi %in : i16 to i32
+ %4 = arith.extui %in_0 : i4 to i32
+ %5 = arith.muli %3, %4 : i32
+ %6 = arith.addi %5, %out : i32
+ linalg.yield %6 : i32
+ } -> tensor<11008x32x8xi32>
+ return %2 : tensor<11008x32x8xi32>
+ }
+}
+
+// CHECK: func @raise_batch_matmul(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x32x128xi16>, %[[ARG1:.+]]: tensor<11008x32x128xi4>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<32x8x11008xi32>
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x8x11008xi32>)
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<32x8x128xi16>
+// CHECK-DAG: %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<8x32x128xi16>) outs(%[[INIT1]] : tensor<32x8x128xi16>) permutation = [1, 0, 2]
+// CHECK-DAG: %[[INIT2:.+]] = tensor.empty() : tensor<32x128x11008xi4>
+// CHECK-DAG: %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x32x128xi4>) outs(%[[INIT2]] : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
+// CHECK-DAG: %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE0]] : tensor<32x8x128xi16> to tensor<32x8x128xi32>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE1]] : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[EXTSI]], %[[EXTUI]] : tensor<32x8x128xi32>, tensor<32x128x11008xi32>) outs(%[[FILL]] : tensor<32x8x11008xi32>)
+// CHECK: %[[INIT3:.+]] = tensor.empty() : tensor<11008x32x8xi32>
+// CHECK: %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<32x8x11008xi32>) outs(%[[INIT3]] : tensor<11008x32x8xi32>) permutation = [2, 0, 1]
+// CHECK: return %[[TRANSPOSE2]]
+
+// -----
+
+module {
+ func.func @raise_batch_matmul_dyn(%arg0: tensor<8x?x128xi16>, %arg1: tensor<11008x?x128xi4>) -> tensor<11008x?x8xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : index
+ %dim = tensor.dim %arg0, %c1 : tensor<8x?x128xi16>
+ %0 = tensor.empty(%dim) : tensor<11008x?x8xi32>
+ %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x?x8xi32>) -> tensor<11008x?x8xi32>
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+ ins(%arg0, %arg1 : tensor<8x?x128xi16>, tensor<11008x?x128xi4>)
+ outs(%1 : tensor<11008x?x8xi32>) {
+ ^bb0(%in: i16, %in_0: i4, %out: i32):
+ %3 = arith.extsi %in : i16 to i32
+ %4 = arith.extui %in_0 : i4 to i32
+ %5 = arith.muli %3, %4 : i32
+ %6 = arith.addi %5, %out : i32
+ linalg.yield %6 : i32
+ } -> tensor<11008x?x8xi32>
+ return %2 : tensor<11008x?x8xi32>
+ }
+}
+
+// CHECK: func @raise_batch_matmul_dyn(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x?x128xi16>, %[[ARG1:.+]]: tensor<11008x?x128xi4>
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<8x?x128xi16>
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x11008xi32>
+// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<?x8x11008xi32>)
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x128xi16>
+// CHECK-DAG: %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<8x?x128xi16>) outs(%[[INIT1]] : tensor<?x8x128xi16>) permutation = [1, 0, 2]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<11008x?x128xi4>
+// CHECK-DAG: %[[INIT2:.+]] = tensor.empty(%[[DIM1]]) : tensor<?x128x11008xi4>
+// CHECK-DAG: %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x?x128xi4>) outs(%[[INIT2]] : tensor<?x128x11008xi4>) permutation = [1, 2, 0]
+// CHECK-DAG: %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE0]] : tensor<?x8x128xi16> to tensor<?x8x128xi32>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE1]] : tensor<?x128x11008xi4> to tensor<?x128x11008xi32>
+// CHECK: %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[EXTSI]], %[[EXTUI]] : tensor<?x8x128xi32>, tensor<?x128x11008xi32>) outs(%[[FILL]] : tensor<?x8x11008xi32>)
+// CHECK: %[[INIT3:.+]] = tensor.empty(%[[DIM0]]) : tensor<11008x?x8xi32>
+// CHECK: %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<?x8x11008xi32>) outs(%[[INIT3]] : tensor<11008x?x8xi32>) permutation = [2, 0, 1]
+// CHECK: return %[[TRANSPOSE2]]