Remove redundant reshape checks in dot general preprocessing (#15319)

Also adds a GeneralDotRemoveBatch pattern that strips a size-1 leading
batch dimension from stablehlo.dot_general. The final reshape no longer
needs separate bookkeeping to produce the correct size: it is emitted
only when the result type actually changes, and canonicalizers should
clean up any unneeded transforms.
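
As a sketch, the batch-removal pattern rewrites IR along these lines
(illustrative shapes, not taken from the test suite):

  %0 = stablehlo.dot_general %lhs, %rhs,
      batching_dims = [0] x [0], contracting_dims = [2] x [1]
      : (tensor<1x3x4xf32>, tensor<1x4x5xf32>) -> tensor<1x3x5xf32>

into an unbatched dot_general bracketed by reshapes, with the
contracting dimensions shifted down by one:

  %l = stablehlo.reshape %lhs : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
  %r = stablehlo.reshape %rhs : (tensor<1x4x5xf32>) -> tensor<4x5xf32>
  %d = stablehlo.dot_general %l, %r, contracting_dims = [1] x [0]
      : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
  %1 = stablehlo.reshape %d : (tensor<3x5xf32>) -> tensor<1x3x5xf32>
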
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
index 6e7ecde..fca7077 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/DotGeneralToDot.cpp
@@ -148,6 +148,69 @@
   return transposeReshape(arg, loc, contractDims, outerDims, shape, rewriter);
 }
 
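+// Rewrites stablehlo.dot_general ops whose only batching dimension is a
+// leading unit dimension: the batch dim is reshaped away, an unbatched
+// dot_general is emitted, and the result is reshaped to the original type.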
+struct GeneralDotRemoveBatch final
+    : OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhsTy = cast<ShapedType>(op.getLhs().getType());
+    auto rhsTy = cast<ShapedType>(op.getRhs().getType());
+    auto ty = cast<ShapedType>(op.getType());
+
+    if (!ty.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(op, "does not have static shape");
+    }
+
+    auto dimNumbers = op.getDotDimensionNumbers();
+    if (dimNumbers.getLhsBatchingDimensions().size() != 1 ||
+        dimNumbers.getRhsBatchingDimensions().size() != 1) {
+      return rewriter.notifyMatchFailure(op, "expected one batch dimension");
+    }
+
+    if (dimNumbers.getLhsBatchingDimensions().front() != 0 ||
+        dimNumbers.getRhsBatchingDimensions().front() != 0) {
+      return rewriter.notifyMatchFailure(op, "not first dim on lhs/rhs");
+    }
+
+    if (lhsTy.getDimSize(0) != 1 || rhsTy.getDimSize(0) != 1) {
+      return rewriter.notifyMatchFailure(op, "non-unit batch size");
+    }
+
+    // Shift contracting dims down one to account for the dropped batch dim.
+    llvm::SmallVector<int64_t> newLhsContractingDims;
+    for (auto dim : dimNumbers.getLhsContractingDimensions())
+      newLhsContractingDims.push_back(dim - 1);
+
+    llvm::SmallVector<int64_t> newRhsContractingDims;
+    for (auto dim : dimNumbers.getRhsContractingDimensions())
+      newRhsContractingDims.push_back(dim - 1);
+
+    auto lhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
+        op.getLoc(), lhsTy.clone(lhsTy.getShape().drop_front()), op.getLhs());
+
+    auto rhs = rewriter.create<mlir::stablehlo::ReshapeOp>(
+        op.getLoc(), rhsTy.clone(rhsTy.getShape().drop_front()), op.getRhs());
+
+    auto newDimNumbers = mlir::stablehlo::DotDimensionNumbersAttr::get(
+        rewriter.getContext(),
+        /*lhsBatchingDimensions=*/{},
+        /*rhsBatchingDimensions=*/{},
+        /*lhsContractingDimensions=*/newLhsContractingDims,
+        /*rhsContractingDimensions=*/newRhsContractingDims);
+
+    auto dot = rewriter.create<mlir::stablehlo::DotGeneralOp>(
+        op.getLoc(), ty.clone(ty.getShape().drop_front()), lhs, rhs,
+        newDimNumbers, op.getPrecisionConfigAttr());
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::ReshapeOp>(op, ty,
+                                                            dot.getResult());
+    return success();
+  }
+};
+
 struct GeneralDotConvert final
     : OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -374,7 +437,9 @@
 void populatePreprocessingDotGeneralToDotPatterns(mlir::MLIRContext *context,
                                                   RewritePatternSet *patterns,
                                                   PatternBenefit benefit) {
-  patterns->add<GeneralDotConvert, DotVectorOptimization>(context, benefit);
+  patterns
+      ->add<GeneralDotConvert, GeneralDotRemoveBatch, DotVectorOptimization>(
+          context, benefit);
 }
 
 } // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
index 30d7416..601c4e3 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
@@ -418,19 +418,17 @@
     auto lhsNewType = cast<RankedTensorType>(lhs.getType());
     auto rhsNewType = cast<RankedTensorType>(rhs.getType());
 
-    // if lhs's shape or rhs's shape has collapsed, we need reshape the result
-    bool needReshapeResult = lhsNewType.getRank() < lhsShapeType.getRank() ||
-                             rhsNewType.getRank() < rhsShapeType.getRank();
     // Result dims are ordered: batching, then lhs parallel, then rhs parallel.
-    SmallVector<int64_t> newShape = {lhsNewType.getShape()[0],
-                                     lhsNewType.getShape()[1]};
+    SmallVector<int64_t, 3> newShape = {lhsNewType.getShape()[0]};
+
+    if (lhsNewType.getRank() > 2)
+      newShape.push_back(lhsNewType.getDimSize(1));
+
     if (rhsNewType.getRank() > 2)
       newShape.push_back(rhsNewType.getDimSize(2));
 
     TensorType newResultType =
-        needReshapeResult
-            ? RankedTensorType::get(newShape, resultType.getElementType())
-            : op.getType();
+        RankedTensorType::get(newShape, resultType.getElementType());
 
     auto newOp = rewriter.create<mlir::stablehlo::DotGeneralOp>(
         op.getLoc(), newResultType, lhs, rhs, dimensionNumbers,
@@ -446,10 +444,12 @@
     }
 
     Value result = newOp.getResult();
-    if (needReshapeResult) {
-      result = rewriter.create<mlir::stablehlo::ReshapeOp>(op.getLoc(),
-                                                           resultType, result);
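+    // Restore the original result type only when the new dot changed it.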
+    if (op.getType() != newResultType) {
+      result = rewriter.create<mlir::stablehlo::ReshapeOp>(
+          op.getLoc(), op.getType(), newOp.getResult());
     }
+
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir
index bcb2ac1..f08a02c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalize_dot_general.mlir
@@ -17,6 +17,8 @@
   return %0 : tensor<3xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @dot_general_4d
 func.func public @dot_general_4d(%arg0: tensor<1x2x3xf32> {stablehlo.sharding = ""}, %arg1: tensor<1x4x2x3xf32> {stablehlo.sharding = ""}) -> tensor<1x2x4xf32> {
   %0 = stablehlo.dot_general %arg0, %arg1,
@@ -35,3 +37,16 @@
   // CHECK-NEXT: return %[[RES]]
   return %0 : tensor<1x2x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @unary_out_channel_dot
+func.func public @unary_out_channel_dot(%arg0: tensor<1x3x4xui16>, %arg1: tensor<1x4x3xui16>) -> tensor<1xui16> {
+  // CHECK: %[[TRANS:.+]] = stablehlo.transpose %arg0, dims = [0, 2, 1]
+  // CHECK: %[[LHS:.+]] = stablehlo.reshape %[[TRANS]] : (tensor<1x4x3xui16>) -> tensor<12xui16>
+  // CHECK: %[[RHS:.+]] = stablehlo.reshape %arg1 : (tensor<1x4x3xui16>) -> tensor<12xui16>
+  // CHECK: %[[DOT:.+]] = stablehlo.dot %[[LHS]], %[[RHS]]
+  // CHECK: %[[OUT:.+]] = stablehlo.reshape %[[DOT]] : (tensor<ui16>) -> tensor<1xui16>
+  %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2, 1] x [1, 2], precision = [HIGH, HIGH] : (tensor<1x3x4xui16>, tensor<1x4x3xui16>) -> tensor<1xui16>
+  return %0 : tensor<1xui16>
+}