[GlobalOptimization] Lift `linalg.generic` ops to `linalg.batch_matmul/linalg.batch_vecmat/linalg.batch_matvec` (#15339)

This adds a pass that lifts generalized batch matmul `linalg.generic` ops
into the corresponding linalg named ops. Any element type promotion inside
the `linalg.generic` body is rewritten as a standalone producer op, so the
operands of the new linalg named op already have the promoted element type.
This allows ukernels to handle unsigned types, since the unsignedness of
the inputs can be inferred from the producer op.
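
For example, the batch vecmat case from the new lit test (the surrounding
`tensor.empty`/`linalg.fill` ops are elided and SSA names shortened here), a
generalized contraction with mixed-signedness promotion in its body:

```mlir
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                                      affine_map<(d0, d1, d2) -> (d0, d1)>],
                     iterator_types = ["parallel", "parallel", "reduction"]}
                     ins(%arg0, %arg1 : tensor<32x128xi16>, tensor<11008x32x128xi4>)
                     outs(%1 : tensor<11008x32xi32>) {
^bb0(%in: i16, %in_0: i4, %out: i32):
  %3 = arith.extsi %in : i16 to i32
  %4 = arith.extui %in_0 : i4 to i32
  %5 = arith.muli %3, %4 : i32
  %6 = arith.addi %5, %out : i32
  linalg.yield %6 : i32
} -> tensor<11008x32xi32>
```

is lifted to tensor-level extension producers feeding a named op, with
transposes where the operand/result layouts require them (output as checked in
the test, which also runs `--canonicalize --cse`):

```mlir
%lhs = arith.extsi %arg0 : tensor<32x128xi16> to tensor<32x128xi32>
%t   = linalg.transpose ins(%arg1 : tensor<11008x32x128xi4>)
         outs(%empty0 : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
%rhs = arith.extui %t : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
%vm  = linalg.batch_vecmat ins(%lhs, %rhs : tensor<32x128xi32>, tensor<32x128x11008xi32>)
         outs(%fill : tensor<32x11008xi32>) -> tensor<32x11008xi32>
%res = linalg.transpose ins(%vm : tensor<32x11008xi32>)
         outs(%empty1 : tensor<11008x32xi32>) permutation = [1, 0]
```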

---------

Co-authored-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
index 7bb6437..5abc9b4 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel
@@ -49,6 +49,7 @@
         "EraseUnusedLinalgOperands.cpp",
         "ExpandVectors.cpp",
         "FuseDequantizationMatmul.cpp",
+        "LiftGenericToTransposeBatchMatmul.cpp",
         "MaterializeHomogeneousEncodings.cpp",
         "Passes.cpp",
         "RemoveZeroExtentTensors.cpp",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
index ea706e1..4fe1abe 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt
@@ -45,6 +45,7 @@
     "EraseUnusedLinalgOperands.cpp"
     "ExpandVectors.cpp"
     "FuseDequantizationMatmul.cpp"
+    "LiftGenericToTransposeBatchMatmul.cpp"
     "MaterializeHomogeneousEncodings.cpp"
     "Passes.cpp"
     "RemoveZeroExtentTensors.cpp"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp
new file mode 100644
index 0000000..a1b911b
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/LiftGenericToTransposeBatchMatmul.cpp
@@ -0,0 +1,410 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/GlobalOptimization/PassDetail.h"
+#include "iree/compiler/GlobalOptimization/Passes.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <sstream>
+
+#define DEBUG_TYPE "iree-global-opt-lift-generic-to-transpose-batch-matmul"
+
+namespace mlir {
+namespace iree_compiler {
+namespace GlobalOptimization {
+
+namespace {
+
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         isa<BlockArgument>(op->getOperand(0));
+}
+
+bool isCastOrInputBlockArgument(Value input, int64_t numInputs) {
+  if (!input.isa<BlockArgument>()) {
+    Operation *castOp0 = input.getDefiningOp();
+    if (!castOp0 || !isCastOfBlockArgument(castOp0)) {
+      return false;
+    }
+    return castOp0->getOperand(0).cast<BlockArgument>().getArgNumber() !=
+           numInputs;
+  } else {
+    return input.cast<BlockArgument>().getArgNumber() != numInputs;
+  }
+}
+
+static bool isBlockArgumentAtIndex(Value input, int64_t index) {
+  return input.isa<BlockArgument>() &&
+         input.cast<BlockArgument>().getArgNumber() == index;
+}
+
+// Returns success if the linalg::GenericOp has a matmul-like body. This does
+// not check the indexing maps.
+//
+// This function looks for a body like:
+// ```mlir
+// ^bb0(%in: !lhs, %in_0: !rhs, %out: !out):
+//   %3 = arith.extui %in : !lhs to !out
+//   %4 = arith.extsi %in_0 : !rhs to !out
+//   %5 = arith.muli %3, %4 : !out
+//   %6 = arith.addi %5, %out : !out
+//   linalg.yield %6 : !out
+// ```
+// Ensuring the following conditions:
+//    1) linalg.yield result comes from an arith.add op, accumulating on %out
+//    2) The other arith.add operand comes from arith.mul
+//    3) Both arith.mul operands are either block input arguments, or produced
+//       by a `CastOpInterface` of a block input argument
+static LogicalResult hasMatmulBody(RewriterBase &rewriter,
+                                   linalg::GenericOp genericOp) {
+  int numInputs = genericOp.getNumDpsInputs();
+  if (numInputs != 2) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "op does not have exactly 2 inputs\n");
+  }
+  int numOutputs = genericOp.getNumDpsInits();
+  if (numOutputs != 1) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "op does not have exactly 1 output\n");
+  }
+
+  auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
+  Value yieldedValue = yieldOp->getOperand(0);
+
+  // Check that yielded value is an arith.add op, and is accumulating
+  Operation *addOp = yieldedValue.getDefiningOp();
+  if (!addOp) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "linalg.yield operand has no defining op\n");
+  }
+  if (!isa<arith::AddFOp, arith::AddIOp>(*addOp)) {
+    return rewriter.notifyMatchFailure(genericOp, "no arith.add body op\n");
+  }
+  Value add0 = addOp->getOperand(0);
+  Value add1 = addOp->getOperand(1);
+  if (!isBlockArgumentAtIndex(add0, numInputs) &&
+      !isBlockArgumentAtIndex(add1, numInputs)) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "arith.add body op not accumulating on output\n");
+  }
+
+  // Check that the producer of the add is an arith.mul op
+  Operation *mulOp =
+      add0.isa<BlockArgument>() ? add1.getDefiningOp() : add0.getDefiningOp();
+  if (!mulOp) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "arith.add operand has no defining op\n");
+  }
+  if (!isa<arith::MulFOp, arith::MulIOp>(*mulOp)) {
+    return rewriter.notifyMatchFailure(genericOp, "no arith.mul body op\n");
+  }
+
+  // Check that each arith.mul operand is an input block argument or a cast
+  // of one
+  if (!isCastOrInputBlockArgument(mulOp->getOperand(0), numInputs) ||
+      !isCastOrInputBlockArgument(mulOp->getOperand(1), numInputs)) {
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "arith.mul operands are not CastOpInterface or BlockArgument\n");
+  }
+  return success();
+}
+
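+// Returns `input` unchanged when `perm` is empty or the identity permutation;
+// otherwise materializes a linalg.transpose of `input` into a new
+// tensor.empty.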
+static Value transposeTensor(Location loc, PatternRewriter &rewriter,
+                             Value input, SmallVector<int64_t> perm) {
+  if (perm.empty()) {
+    return input;
+  }
+  if (llvm::all_of(llvm::enumerate(perm),
+                   [](auto idx) { return idx.index() == idx.value(); })) {
+    return input;
+  }
+  auto inputType = cast<RankedTensorType>(input.getType());
+  SmallVector<OpFoldResult> inputMixedSizes =
+      tensor::getMixedSizes(rewriter, loc, input);
+  SmallVector<OpFoldResult> newInputMixedSizes =
+      applyPermutation(inputMixedSizes, perm);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, newInputMixedSizes,
+                                                inputType.getElementType());
+  return rewriter.create<linalg::TransposeOp>(loc, input, init, perm)
+      .getResults()[0];
+}
+
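+// If the element types of `input` and the genericOp result differ, finds the
+// body CastOpInterface op consuming the block argument at `inputIdx` and
+// recreates it as a tensor-level op on `input`. Returns failure if no such
+// cast exists.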
+static FailureOr<Value> castTensor(Location loc, PatternRewriter &rewriter,
+                                   linalg::GenericOp genericOp,
+                                   int64_t inputIdx, Value input) {
+  Value output = genericOp.getResults()[0];
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto outputType = cast<RankedTensorType>(output.getType());
+  if (inputType.getElementType() == outputType.getElementType()) {
+    return input;
+  }
+  auto castedType =
+      RankedTensorType::get(inputType.getShape(), outputType.getElementType());
+  for (auto bodyOp : genericOp.getBody()->getOps<CastOpInterface>()) {
+    Value castInput = bodyOp->getOperand(0);
+    if (isBlockArgumentAtIndex(castInput, inputIdx)) {
+      return rewriter
+          .create(bodyOp->getLoc(), bodyOp->getName().getIdentifier(), input,
+                  castedType, bodyOp->getAttrs())
+          ->getResult(0);
+    }
+  }
+  return failure();
+}
+
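+// Rewrites `genericOp` into the named op `OpTy`: transposes the inputs by
+// `lhsPerm`/`rhsPerm`, hoists any body casts to tensor-level producer ops,
+// accumulates into a new zero-filled init, and transposes the named op result
+// by `outPerm` back to the layout of the original result.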
+template <typename OpTy>
+static LogicalResult
+liftGenericOp(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+              SmallVector<int64_t> lhsPerm, SmallVector<int64_t> rhsPerm,
+              SmallVector<int64_t> outPerm) {
+  static_assert((std::is_same<OpTy, linalg::BatchVecmatOp>::value ||
+                 std::is_same<OpTy, linalg::BatchMatvecOp>::value ||
+                 std::is_same<OpTy, linalg::BatchMatmulOp>::value) &&
+                "expected only BatchVecmatOp, BatchMatvecOp, or BatchMatmulOp");
+  Location loc = genericOp.getLoc();
+  Value transposedLhs =
+      transposeTensor(loc, rewriter, genericOp.getInputs()[0], lhsPerm);
+  Value transposedRhs =
+      transposeTensor(loc, rewriter, genericOp.getInputs()[1], rhsPerm);
+  FailureOr<Value> extendedLhs =
+      castTensor(loc, rewriter, genericOp, 0, transposedLhs);
+  FailureOr<Value> extendedRhs =
+      castTensor(loc, rewriter, genericOp, 1, transposedRhs);
+  if (failed(extendedLhs) || failed(extendedRhs)) {
+    return failure();
+  }
+  Value genericInit = genericOp.getDpsInitOperand(0)->get();
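+  // The named op init takes the generic init sizes permuted into the named op
+  // layout; since `outPerm` maps the named op result back to the generic
+  // result, its inverse is applied here.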
+  SmallVector<OpFoldResult> genericMixedSizes =
+      tensor::getMixedSizes(rewriter, loc, genericInit);
+  SmallVector<OpFoldResult> batchMixedSizes =
+      applyPermutation(genericMixedSizes, invertPermutationVector(outPerm));
+  Value out = genericOp.getResults()[0];
+  auto outType = cast<RankedTensorType>(out.getType());
+  Value batchEmpty = rewriter.create<tensor::EmptyOp>(loc, batchMixedSizes,
+                                                      outType.getElementType());
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(outType.getElementType()));
+  Value batchInit =
+      rewriter.create<linalg::FillOp>(loc, zero, batchEmpty).getResult(0);
+  OpTy batchOp = rewriter.create<OpTy>(
+      loc, TypeRange{batchInit.getType()},
+      ValueRange{extendedLhs.value(), extendedRhs.value()},
+      ValueRange{batchInit});
+  Value newOut = transposeTensor(loc, rewriter, batchOp.getResult(0), outPerm);
+  rewriter.replaceOp(genericOp, newOut);
+  return success();
+}
+
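+// Lifts `genericOp` to linalg.batch_vecmat when the contraction has exactly
+// one batch, n, and k dimension and no m dimension, computing the operand and
+// result permutations from the indexing maps.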
+static LogicalResult
+liftToBatchVecmat(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+                  linalg::ContractionDimensions contractionDims) {
+  if (contractionDims.batch.size() != 1 || contractionDims.m.size() != 0 ||
+      contractionDims.n.size() != 1 || contractionDims.k.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "expected batch vecmat contraction dims\n\n");
+  }
+  AffineMap vecMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+  AffineMap matMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(1));
+  AffineMap outMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  if (vecMap.getNumResults() != 2 || matMap.getNumResults() != 3 ||
+      outMap.getNumResults() != 2) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "wrong numResults for indexing maps\n\n");
+  }
+
+  auto getResultIndex = [&](AffineMap map, int64_t dimIndex) {
+    return *(map.getResultPosition(rewriter.getAffineDimExpr(dimIndex)));
+  };
+  // Permutation from GenericOp lhs shape to BatchVecmatOp lhs shape
+  SmallVector<int64_t> vecPerm;
+  vecPerm.push_back(getResultIndex(vecMap, contractionDims.batch[0]));
+  vecPerm.push_back(getResultIndex(vecMap, contractionDims.k[0]));
+  // Permutation from GenericOp rhs shape to BatchVecmatOp rhs shape
+  SmallVector<int64_t> matPerm;
+  matPerm.push_back(getResultIndex(matMap, contractionDims.batch[0]));
+  matPerm.push_back(getResultIndex(matMap, contractionDims.k[0]));
+  matPerm.push_back(getResultIndex(matMap, contractionDims.n[0]));
+  // Permutation from BatchVecmatOp result shape to GenericOp result shape
+  SmallVector<int64_t> outPerm;
+  outPerm.push_back(getResultIndex(outMap, contractionDims.batch[0]));
+  outPerm.push_back(getResultIndex(outMap, contractionDims.n[0]));
+  outPerm = invertPermutationVector(outPerm);
+  return liftGenericOp<linalg::BatchVecmatOp>(rewriter, genericOp, vecPerm,
+                                              matPerm, outPerm);
+}
+
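+// Lifts `genericOp` to linalg.batch_matvec when the contraction has exactly
+// one batch, m, and k dimension and no n dimension.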
+static LogicalResult
+liftToBatchMatvec(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+                  linalg::ContractionDimensions contractionDims) {
+  if (contractionDims.batch.size() != 1 || contractionDims.m.size() != 1 ||
+      contractionDims.n.size() != 0 || contractionDims.k.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "expected batch matvec contraction dims\n\n");
+  }
+  AffineMap matMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+  AffineMap vecMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(1));
+  AffineMap outMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  if (vecMap.getNumResults() != 2 || matMap.getNumResults() != 3 ||
+      outMap.getNumResults() != 2) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "wrong numResults for indexing maps\n\n");
+  }
+
+  auto getResultIndex = [&](AffineMap map, int64_t dimIndex) {
+    return *(map.getResultPosition(rewriter.getAffineDimExpr(dimIndex)));
+  };
+  // Permutation from GenericOp lhs shape to BatchMatvecOp lhs shape
+  SmallVector<int64_t> matPerm;
+  matPerm.push_back(getResultIndex(matMap, contractionDims.batch[0]));
+  matPerm.push_back(getResultIndex(matMap, contractionDims.m[0]));
+  matPerm.push_back(getResultIndex(matMap, contractionDims.k[0]));
+  // Permutation from GenericOp rhs shape to BatchMatvecOp rhs shape
+  SmallVector<int64_t> vecPerm;
+  vecPerm.push_back(getResultIndex(vecMap, contractionDims.batch[0]));
+  vecPerm.push_back(getResultIndex(vecMap, contractionDims.k[0]));
+  // Permutation from BatchMatvecOp result shape to GenericOp result shape
+  SmallVector<int64_t> outPerm;
+  outPerm.push_back(getResultIndex(outMap, contractionDims.batch[0]));
+  outPerm.push_back(getResultIndex(outMap, contractionDims.m[0]));
+  outPerm = invertPermutationVector(outPerm);
+  return liftGenericOp<linalg::BatchMatvecOp>(rewriter, genericOp, matPerm,
+                                              vecPerm, outPerm);
+}
+
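+// Lifts `genericOp` to linalg.batch_matmul when the contraction has exactly
+// one batch, m, n, and k dimension.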
+static LogicalResult
+liftToBatchMatmul(PatternRewriter &rewriter, linalg::GenericOp genericOp,
+                  linalg::ContractionDimensions contractionDims) {
+  if (contractionDims.batch.size() != 1 || contractionDims.m.size() != 1 ||
+      contractionDims.n.size() != 1 || contractionDims.k.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "expected batch matmul contraction dims\n\n");
+  }
+  AffineMap lhsMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0));
+  AffineMap rhsMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(1));
+  AffineMap outMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+  if (lhsMap.getNumResults() != 3 || rhsMap.getNumResults() != 3 ||
+      outMap.getNumResults() != 3) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "wrong numResults for indexing maps\n\n");
+  }
+
+  auto getResultIndex = [&](AffineMap map, int64_t dimIndex) {
+    return *(map.getResultPosition(rewriter.getAffineDimExpr(dimIndex)));
+  };
+  // Permutation from GenericOp lhs shape to BatchMatmulOp lhs shape
+  SmallVector<int64_t> lhsPerm;
+  lhsPerm.push_back(getResultIndex(lhsMap, contractionDims.batch[0]));
+  lhsPerm.push_back(getResultIndex(lhsMap, contractionDims.m[0]));
+  lhsPerm.push_back(getResultIndex(lhsMap, contractionDims.k[0]));
+  // Permutation from GenericOp rhs shape to BatchMatmulOp rhs shape
+  SmallVector<int64_t> rhsPerm;
+  rhsPerm.push_back(getResultIndex(rhsMap, contractionDims.batch[0]));
+  rhsPerm.push_back(getResultIndex(rhsMap, contractionDims.k[0]));
+  rhsPerm.push_back(getResultIndex(rhsMap, contractionDims.n[0]));
+  // Permutation from BatchMatmulOp result shape to GenericOp result shape
+  SmallVector<int64_t> outPerm;
+  outPerm.push_back(getResultIndex(outMap, contractionDims.batch[0]));
+  outPerm.push_back(getResultIndex(outMap, contractionDims.m[0]));
+  outPerm.push_back(getResultIndex(outMap, contractionDims.n[0]));
+  outPerm = invertPermutationVector(outPerm);
+  return liftGenericOp<linalg::BatchMatmulOp>(rewriter, genericOp, lhsPerm,
+                                              rhsPerm, outPerm);
+}
+
+// Converts a linalg.generic op to linalg.batch_matmul, linalg.batch_matvec,
+// or linalg.batch_vecmat, plus linalg.transpose ops on the inputs and/or
+// result
+class LiftGenericToTransposeBatchMatmul
+    : public OpRewritePattern<linalg::GenericOp> {
+public:
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<linalg::ContractionDimensions> contractionDims =
+        linalg::inferContractionDims(genericOp);
+    if (failed(contractionDims)) {
+      return rewriter.notifyMatchFailure(
+          genericOp, "failed to infer contraction dims\n\n");
+    }
+
+    auto lhsType =
+        dyn_cast<RankedTensorType>(genericOp.getOperands()[0].getType());
+    auto rhsType =
+        dyn_cast<RankedTensorType>(genericOp.getOperands()[1].getType());
+    auto outType =
+        dyn_cast<RankedTensorType>(genericOp.getResults()[0].getType());
+    if (!lhsType || !rhsType || !outType) {
+      return rewriter.notifyMatchFailure(
+          genericOp, "Operands do not have RankedTensorType\n\n");
+    }
+
+    if (failed(hasMatmulBody(rewriter, genericOp))) {
+      return rewriter.notifyMatchFailure(
+          genericOp, "genericOp does not have a matmul body\n\n");
+    }
+
+    // TODO(#15373) Support non-batch cases
+    if (succeeded(liftToBatchVecmat(rewriter, genericOp, *contractionDims))) {
+      return success();
+    }
+    if (succeeded(liftToBatchMatvec(rewriter, genericOp, *contractionDims))) {
+      return success();
+    }
+    if (succeeded(liftToBatchMatmul(rewriter, genericOp, *contractionDims))) {
+      return success();
+    }
+    return failure();
+  }
+};
+
+struct LiftGenericToTransposeBatchMatmulPass
+    : public LiftGenericToTransposeBatchMatmulBase<
+          LiftGenericToTransposeBatchMatmulPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<LiftGenericToTransposeBatchMatmul>(context);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> createLiftGenericToTransposeBatchMatmulPass() {
+  return std::make_unique<LiftGenericToTransposeBatchMatmulPass>();
+}
+
+} // namespace GlobalOptimization
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index f75f916..f56b0c3 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -54,7 +54,8 @@
       .addPass(createRemoveZeroExtentTensorsPass)
       .addPass(createDetachElementwiseFromNamedOpsPass)
       .addPass(mlir::createLinalgNamedOpConversionPass)
-      .addPass(createConvert1X1FilterConv2DToMatmulPass);
+      .addPass(createConvert1X1FilterConv2DToMatmulPass)
+      .addPass(createLiftGenericToTransposeBatchMatmulPass);
   mainPassManager.addPass(createEraseUnusedLinalgOperands());
 
   // Expand tensor shapes into SSA values and optimize the whole program.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.h b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
index 37f9898..bd6af04 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.h
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.h
@@ -75,6 +75,10 @@
 // Sets encoding for tensors to allow tiled execution of operations.
 std::unique_ptr<Pass> createSetEncodingPass();
 
+// Converts linalg.generic ops to linalg.batch_matmul, linalg.batch_matvec, or
+// linalg.batch_vecmat, possibly with transposes on operands/result.
+std::unique_ptr<Pass> createLiftGenericToTransposeBatchMatmulPass();
+
 void registerGlobalOptimizationPipeline();
 
 } // namespace GlobalOptimization
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.td b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
index b7a2fe5..babe27b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.td
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.td
@@ -60,4 +60,10 @@
   let constructor = "mlir::iree_compiler::GlobalOptimization::createSetEncodingPass()";
 }
 
+def LiftGenericToTransposeBatchMatmul:
+    Pass<"iree-global-opt-lift-generic-to-tranpose-batch-matmul", ""> {
+  let summary = "Convert linalg.generic ops to linalg.batch_matmul, possibly with transposes on operands/result.";
+  let constructor = "mlir::iree_compiler::GlobalOptimization::createLiftGenericToTransposeBatchMatmulPass()";
+}
+
 #endif // IREE_COMPILER_GLOBALOPTIMIZATION_PASSES
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
index fe7bef2..9c7122b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/BUILD.bazel
@@ -20,6 +20,7 @@
             "detach_elementwise_from_named_ops.mlir",
             "expand_vectors.mlir",
             "fuse_dequantization_matmul.mlir",
+            "lift_generic_to_transpose_batch_matmul.mlir",
             "materialize_homogeneous_encodings.mlir",
             "remove_zero_extent_tensors.mlir",
             "set_encoding.mlir",
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
index 561cb72..39c8459 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/CMakeLists.txt
@@ -18,6 +18,7 @@
     "detach_elementwise_from_named_ops.mlir"
     "expand_vectors.mlir"
     "fuse_dequantization_matmul.mlir"
+    "lift_generic_to_transpose_batch_matmul.mlir"
     "materialize_homogeneous_encodings.mlir"
     "remove_zero_extent_tensors.mlir"
     "set_encoding.mlir"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir
new file mode 100644
index 0000000..815f815
--- /dev/null
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/lift_generic_to_transpose_batch_matmul.mlir
@@ -0,0 +1,162 @@
+// RUN: iree-opt --iree-global-opt-lift-generic-to-transpose-batch-matmul --canonicalize --cse --split-input-file %s | FileCheck %s
+
+module {
+  func.func @raise_batch_vecmat(%arg0: tensor<32x128xi16>, %arg1: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tensor.empty() : tensor<11008x32xi32>
+    %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, 
+                                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>], 
+                         iterator_types = ["parallel", "parallel", "reduction"]} 
+                         ins(%arg0, %arg1 : tensor<32x128xi16>, tensor<11008x32x128xi4>) 
+                         outs(%1 : tensor<11008x32xi32>) {
+    ^bb0(%in: i16, %in_0: i4, %out: i32):
+      %3 = arith.extsi %in : i16 to i32
+      %4 = arith.extui %in_0 : i4 to i32
+      %5 = arith.muli %3, %4 : i32
+      %6 = arith.addi %5, %out : i32
+      linalg.yield %6 : i32
+    } -> tensor<11008x32xi32>
+    return %2 : tensor<11008x32xi32>
+  }
+}
+
+//      CHECK:  func @raise_batch_vecmat(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<32x128xi16>, %[[ARG1:.+]]: tensor<11008x32x128xi4>
+//  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
+//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
+//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
+//  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty() : tensor<32x128x11008xi4>
+//  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x32x128xi4>) outs(%[[INIT1]] : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
+//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[ARG0]] : tensor<32x128xi16> to tensor<32x128xi32>
+//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE0]] : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
+//      CHECK:  %[[VECMAT:.+]] = linalg.batch_vecmat ins(%[[EXTSI]], %[[EXTUI]] : tensor<32x128xi32>, tensor<32x128x11008xi32>) outs(%[[FILL]] : tensor<32x11008xi32>)
+//      CHECK:  %[[INIT2:.+]] = tensor.empty() : tensor<11008x32xi32>
+//      CHECK:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[VECMAT]] : tensor<32x11008xi32>) outs(%[[INIT2]] : tensor<11008x32xi32>) permutation = [1, 0]
+//      CHECK:  return %[[TRANSPOSE1]]
+
+// -----
+
+module {
+  func.func @raise_batch_matvec(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<128x32xi16>) -> tensor<11008x32xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tensor.empty() : tensor<11008x32xi32>
+    %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
+                                          affine_map<(d0, d1, d2) -> (d2, d1)>, 
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>], 
+                         iterator_types = ["parallel", "parallel", "reduction"]} 
+                         ins(%arg0, %arg1 : tensor<11008x32x128xi4>, tensor<128x32xi16>) 
+                         outs(%1 : tensor<11008x32xi32>) {
+    ^bb0(%in: i4, %in_0: i16, %out: i32):
+      %3 = arith.extui %in : i4 to i32
+      %4 = arith.extsi %in_0 : i16 to i32
+      %5 = arith.muli %3, %4 : i32
+      %6 = arith.addi %5, %out : i32
+      linalg.yield %6 : i32
+    } -> tensor<11008x32xi32>
+    return %2 : tensor<11008x32xi32>
+  }
+}
+
+//      CHECK:  func @raise_batch_matvec(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<11008x32x128xi4>, %[[ARG1:.+]]: tensor<128x32xi16>
+//  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
+//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty() : tensor<32x11008xi32>
+//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x11008xi32>)
+//  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty() : tensor<32x11008x128xi4>
+//  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<11008x32x128xi4>) outs(%[[INIT1]] : tensor<32x11008x128xi4>) permutation = [1, 0, 2]
+//  CHECK-DAG:  %[[INIT2:.+]] = tensor.empty() : tensor<32x128xi16>
+//  CHECK-DAG:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<128x32xi16>) outs(%[[INIT2]] : tensor<32x128xi16>) permutation = [1, 0]
+//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE0]] : tensor<32x11008x128xi4> to tensor<32x11008x128xi32>
+//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE1]] : tensor<32x128xi16> to tensor<32x128xi32>
+//      CHECK:  %[[MATMUL:.+]] = linalg.batch_matvec ins(%[[EXTUI]], %[[EXTSI]] : tensor<32x11008x128xi32>, tensor<32x128xi32>) outs(%[[FILL]] : tensor<32x11008xi32>)
+//      CHECK:  %[[INIT3:.+]] = tensor.empty() : tensor<11008x32xi32>
+//      CHECK:  %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<32x11008xi32>) outs(%[[INIT3]] : tensor<11008x32xi32>) permutation = [1, 0]
+//      CHECK:  return %[[TRANSPOSE2]]
+
+// -----
+
+module {
+  func.func @raise_batch_matmul(%arg0: tensor<8x32x128xi16>, %arg1: tensor<11008x32x128xi4>) -> tensor<11008x32x8xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tensor.empty() : tensor<11008x32x8xi32>
+    %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32x8xi32>) -> tensor<11008x32x8xi32>
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>, 
+                                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, 
+                                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>], 
+                         iterator_types = ["parallel", "parallel", "reduction", "parallel"]} 
+                         ins(%arg0, %arg1 : tensor<8x32x128xi16>, tensor<11008x32x128xi4>) 
+                         outs(%1 : tensor<11008x32x8xi32>) {
+    ^bb0(%in: i16, %in_0: i4, %out: i32):
+      %3 = arith.extsi %in : i16 to i32
+      %4 = arith.extui %in_0 : i4 to i32
+      %5 = arith.muli %3, %4 : i32
+      %6 = arith.addi %5, %out : i32
+      linalg.yield %6 : i32
+    } -> tensor<11008x32x8xi32>
+    return %2 : tensor<11008x32x8xi32>
+  }
+}
+
+//      CHECK:  func @raise_batch_matmul(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<8x32x128xi16>, %[[ARG1:.+]]: tensor<11008x32x128xi4>
+//  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
+//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty() : tensor<32x8x11008xi32>
+//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<32x8x11008xi32>)
+//  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty() : tensor<32x8x128xi16>
+//  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<8x32x128xi16>) outs(%[[INIT1]] : tensor<32x8x128xi16>) permutation = [1, 0, 2]
+//  CHECK-DAG:  %[[INIT2:.+]] = tensor.empty() : tensor<32x128x11008xi4>
+//  CHECK-DAG:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x32x128xi4>) outs(%[[INIT2]] : tensor<32x128x11008xi4>) permutation = [1, 2, 0]
+//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE0]] : tensor<32x8x128xi16> to tensor<32x8x128xi32>
+//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE1]] : tensor<32x128x11008xi4> to tensor<32x128x11008xi32>
+//      CHECK:  %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[EXTSI]], %[[EXTUI]] : tensor<32x8x128xi32>, tensor<32x128x11008xi32>) outs(%[[FILL]] : tensor<32x8x11008xi32>)
+//      CHECK:  %[[INIT3:.+]] = tensor.empty() : tensor<11008x32x8xi32>
+//      CHECK:  %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<32x8x11008xi32>) outs(%[[INIT3]] : tensor<11008x32x8xi32>) permutation = [2, 0, 1]
+//      CHECK:  return %[[TRANSPOSE2]]
+
+// -----
+
+module {
+  func.func @raise_batch_matmul_dyn(%arg0: tensor<8x?x128xi16>, %arg1: tensor<11008x?x128xi4>) -> tensor<11008x?x8xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %c1 = arith.constant 1 : index
+    %dim = tensor.dim %arg0, %c1 : tensor<8x?x128xi16>
+    %0 = tensor.empty(%dim) : tensor<11008x?x8xi32>
+    %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x?x8xi32>) -> tensor<11008x?x8xi32>
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>, 
+                                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, 
+                                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>], 
+                         iterator_types = ["parallel", "parallel", "reduction", "parallel"]} 
+                         ins(%arg0, %arg1 : tensor<8x?x128xi16>, tensor<11008x?x128xi4>) 
+                         outs(%1 : tensor<11008x?x8xi32>) {
+    ^bb0(%in: i16, %in_0: i4, %out: i32):
+      %3 = arith.extsi %in : i16 to i32
+      %4 = arith.extui %in_0 : i4 to i32
+      %5 = arith.muli %3, %4 : i32
+      %6 = arith.addi %5, %out : i32
+      linalg.yield %6 : i32
+    } -> tensor<11008x?x8xi32>
+    return %2 : tensor<11008x?x8xi32>
+  }
+}
+
+//      CHECK:  func @raise_batch_matmul_dyn(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<8x?x128xi16>, %[[ARG1:.+]]: tensor<11008x?x128xi4>
+//  CHECK-DAG:  %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:  %[[CST:.+]] = arith.constant 0 : i32
+//  CHECK-DAG:  %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<8x?x128xi16>
+//  CHECK-DAG:  %[[INIT0:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x11008xi32>
+//  CHECK-DAG:  %[[FILL:.+]] = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT0]] : tensor<?x8x11008xi32>)
+//  CHECK-DAG:  %[[INIT1:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x8x128xi16>
+//  CHECK-DAG:  %[[TRANSPOSE0:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<8x?x128xi16>) outs(%[[INIT1]] : tensor<?x8x128xi16>) permutation = [1, 0, 2]
+//  CHECK-DAG:  %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<11008x?x128xi4>
+//  CHECK-DAG:  %[[INIT2:.+]] = tensor.empty(%[[DIM1]]) : tensor<?x128x11008xi4>
+//  CHECK-DAG:  %[[TRANSPOSE1:.+]] = linalg.transpose ins(%[[ARG1]] : tensor<11008x?x128xi4>) outs(%[[INIT2]] : tensor<?x128x11008xi4>) permutation = [1, 2, 0]
+//  CHECK-DAG:  %[[EXTSI:.+]] = arith.extsi %[[TRANSPOSE0]] : tensor<?x8x128xi16> to tensor<?x8x128xi32>
+//  CHECK-DAG:  %[[EXTUI:.+]] = arith.extui %[[TRANSPOSE1]] : tensor<?x128x11008xi4> to tensor<?x128x11008xi32>
+//      CHECK:  %[[MATMUL:.+]] = linalg.batch_matmul ins(%[[EXTSI]], %[[EXTUI]] : tensor<?x8x128xi32>, tensor<?x128x11008xi32>) outs(%[[FILL]] : tensor<?x8x11008xi32>)
+//      CHECK:  %[[INIT3:.+]] = tensor.empty(%[[DIM0]]) : tensor<11008x?x8xi32>
+//      CHECK:  %[[TRANSPOSE2:.+]] = linalg.transpose ins(%[[MATMUL]] : tensor<?x8x11008xi32>) outs(%[[INIT3]] : tensor<11008x?x8xi32>) permutation = [2, 0, 1]
+//      CHECK:  return %[[TRANSPOSE2]]