Add canonicalizer that reorders unary elementwise ops and shape manipulation ops. (#14494)

The pass identifies shape manipulation operations (`reshape`,
`transpose`, `broadcast`) that feed into unary elementwise operations
and swaps their order. Because elementwise ops process each element
independently and leave the input shape untouched, the reordering does
not change the result. The swap is only performed when the elementwise
op is the shape op's sole user, so no other consumers are disturbed.
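For example (a minimal sketch with illustrative shapes):

```mlir
// Before: the transpose feeds the unary elementwise op.
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
%1 = stablehlo.abs %0 : tensor<3x2xf32>

// After: the elementwise op is applied first, to the untransposed operand.
%0 = stablehlo.abs %arg0 : tensor<2x3xf32>
%1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
```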
This prepares for a future PR in which consecutive shape manipulation
operations can be collapsed into a single operation, as sketched below.
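For instance, two `reshape` ops that are separated by an elementwise op
today become adjacent after this pass (again, shapes are illustrative):

```mlir
// Before this pass: the reshapes are separated by the elementwise op.
%0 = stablehlo.reshape %arg0 : (tensor<12xf32>) -> tensor<3x4xf32>
%1 = stablehlo.abs %0 : tensor<3x4xf32>
%2 = stablehlo.reshape %1 : (tensor<3x4xf32>) -> tensor<4x3xf32>

// After this pass: the reshapes are adjacent, and a future pattern could
// fold them into a single reshape.
%0 = stablehlo.abs %arg0 : tensor<12xf32>
%1 = stablehlo.reshape %0 : (tensor<12xf32>) -> tensor<3x4xf32>
%2 = stablehlo.reshape %1 : (tensor<3x4xf32>) -> tensor<4x3xf32>
```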
`broadcast_in_dim` is not handled in this PR because the equivalent
reordering is already performed for it by the existing
`StablehloToStablehlo` pass.
In the future, we might want to merge `Canonicalization.cpp` and
`StablehloToStablehlo.cpp`, or define the split between them more
clearly, as their functions are currently almost indistinguishable in
some cases.
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
index dce7977..744a7ee 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp
@@ -1048,6 +1048,54 @@
   }
 };
 
+struct ReorderElementwiseAndShapeOp final
+    : OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1) {
+      return rewriter.notifyMatchFailure(op, "expected to be unary.");
+    }
+
+    auto *definingOp = op->getOperand(0).getDefiningOp();
+    if (!definingOp) {
+      return rewriter.notifyMatchFailure(
+          op, "expected to have a defining op before the elementwise op.");
+    }
+
+    if (!isa<mlir::stablehlo::ReshapeOp, mlir::stablehlo::TransposeOp,
+             mlir::stablehlo::BroadcastOp>(definingOp)) {
+      return rewriter.notifyMatchFailure(
+          op, "defining operation of unexpected type.");
+    }
+
+    // Bail out if the shape op has other uses: rewiring its operand and
+    // result type in place would clobber those uses.
+    if (!definingOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          op, "defining op has multiple uses.");
+    }
+
+    Value input = definingOp->getOperand(0);
+    Value result = op->getResult(0);
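+    // The type the elementwise op will produce once it consumes the shape
+    // op's operand: the operand's shape with the old result's element type.
+    // This keeps element-type-changing ops such as `convert` correct.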
+    auto intermediateType = cast<ShapedType>(input.getType())
+                                .clone(getElementTypeOrSelf(result.getType()));
+
+    // Reorder the operation and rewire the inputs/outputs.
+    op->moveBefore(definingOp);
+    definingOp->getResult(0).setType(result.getType());
+    rewriter.replaceAllUsesWith(result, definingOp->getResult(0));
+    result.setType(intermediateType);
+    op->setOperands(input);
+    definingOp->setOperands(result);
+    return success();
+  }
+};
+
 struct StableHLOCanonicalize final
     : impl::StableHLOCanonicalizeBase<StableHLOCanonicalize> {
   void runOnOperation() override {
@@ -1087,5 +1135,6 @@
       ReshapeOpCanon, TransposeOpCanon,
       // Types.
       ZeroExtentTensorCanon>(context, benefit);
+  patterns->add<ReorderElementwiseAndShapeOp>(context);
 }
 } // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
index f0dc19a..bda6c3c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir
@@ -703,3 +703,36 @@
   }
   return %3#0, %3#1 : tensor<i32>, tensor<75x0xf32>
 }
+
+// -----
+
+func.func @push_shape_ops_to_end(%arg0 : tensor<12xf32>) -> tensor<3x4x2x1xf32> {
+  %0 = stablehlo.reshape %arg0 : (tensor<12xf32>) -> tensor<3x4xf32>
+  %1 = stablehlo.broadcast %0, sizes = [1, 2] : (tensor<3x4xf32>) -> tensor<1x2x3x4xf32>
+  %2 = stablehlo.cosine %1 : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
+  %3 = stablehlo.transpose %2, dims = [2, 3, 1, 0] : (tensor<1x2x3x4xf32>) -> tensor<3x4x2x1xf32>
+  %4 = stablehlo.abs %3 : (tensor<3x4x2x1xf32>) -> tensor<3x4x2x1xf32>
+  return %4 : tensor<3x4x2x1xf32>
+}
+
+// CHECK-LABEL: @push_shape_ops_to_end
+// CHECK: %[[COS:.+]] = stablehlo.cosine %arg0 : tensor<12xf32>
+// CHECK: %[[ABS:.+]] = stablehlo.abs %[[COS]] : tensor<12xf32>
+// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[ABS]] : (tensor<12xf32>) -> tensor<3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = stablehlo.broadcast %[[RESHAPE]], sizes = [1, 2] : (tensor<3x4xf32>) -> tensor<1x2x3x4xf32>
+// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[BROADCAST]], dims = [2, 3, 1, 0] : (tensor<1x2x3x4xf32>) -> tensor<3x4x2x1xf32>
+// CHECK: return %[[TRANSPOSE]]
+
+// -----
+
+func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> {
+  %0 = stablehlo.reshape %arg0 : (tensor<3x4xi32>) -> tensor<12xi32>
+  %1 = stablehlo.convert %0 : (tensor<12xi32>) -> tensor<12xi64>
+  return %1 : tensor<12xi64>
+}
+
+// CHECK-LABEL: @reorder_with_type_change
+// CHECK: %[[CONVERT:.+]] = stablehlo.convert %arg0 : (tensor<3x4xi32>) -> tensor<3x4xi64>
+// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %[[CONVERT]] : (tensor<3x4xi64>) -> tensor<12xi64>
+// CHECK: return %[[RESHAPE]]
+