Do not fuse with elementwise operations that can't bufferize in-place. (#8526)

Currently the backend cannot bufferize in-place dispatch regions in
which a root operation (conv, etc.) is fused with an elementwise
operation whose output buffer cannot be reused for the result of the
root. Disable such fusion as a workaround.

For now only do this for the convolution cases; more might be needed
while the proper fix is worked out downstream. (The proper fix is to
"vectorize always", even if the vector size is 1.)

Issue #8411
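
To illustrate, a minimal sketch of the kind of fusion this change disables
(static shapes and value names here are illustrative; the @no_fuse_quantized
test added below exercises the dynamic-shape case). The convolution result is
i32 while the fused elementwise op yields i8, so the i8 output buffer cannot
be reused for the i32 convolution result when bufferizing in-place:

  %conv = linalg.depthwise_conv_2d_nhwc_hwc_q
      {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
      ins(%input, %filter, %input_zp, %filter_zp
          : tensor<1x113x113x64xi8>, tensor<3x3x64xi8>, i32, i32)
      outs(%acc : tensor<1x56x56x64xi32>) -> tensor<1x56x56x64xi32>
  %trunc = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%conv : tensor<1x56x56x64xi32>) outs(%out : tensor<1x56x56x64xi8>) {
    ^bb0(%b0 : i32, %b1 : i8):
      %t = arith.trunci %b0 : i32 to i8
      linalg.yield %t : i8
  } -> tensor<1x56x56x64xi8>
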
diff --git a/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index 9decf95..40eaef0 100644
--- a/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -25,6 +25,7 @@
 #include "iree/compiler/Dialect/Flow/IR/PartitionableLoopsInterface.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -308,17 +309,27 @@
     return signalPassFailure();
   }
 
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After Tile + Distribute ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
   // Apply linalg tiling optimization patterns.
   RewritePatternSet canonicalizationPatterns(context);
   linalg::populateLinalgTilingCanonicalizationPatterns(
       canonicalizationPatterns);
-  memref::populateResolveRankedShapeTypeResultDimsPatterns(
-      canonicalizationPatterns);
   if (failed(applyPatternsAndFoldGreedily(
           funcOp, std::move(canonicalizationPatterns)))) {
     return signalPassFailure();
   }
 
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After Canonicalize ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
   // Rewrite destructive updates and ensure no remaining store remains to the
   // full output.
 
@@ -331,6 +342,20 @@
         << *funcOp.getOperation();
     return signalPassFailure();
   }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "--- After Rewriting destructive updates ---\n";
+    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+    llvm::dbgs() << "\n\n";
+  });
+
+  // After rewriting destructive updates, there might be uses of compute
+  // operations only in `tensor.dim` ops. Resolve these.
+  RewritePatternSet resolveDimOps(context);
+  memref::populateResolveRankedShapeTypeResultDimsPatterns(resolveDimOps);
+  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(resolveDimOps)))) {
+    return signalPassFailure();
+  }
 }
 
 std::unique_ptr<OperationPass<FuncOp>>
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index af14cc3..333b32c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -812,6 +812,36 @@
 // Heuristics for fusing dispatchble ops with root ops using tile + fuse.
 //===----------------------------------------------------------------------===//
 
+/// For a root op -> elementwise op fusion to bufferize in-place without
+/// extra memory, the result of the root operation must be able to reuse
+/// the buffer allocated for the result of the elementwise operation. This
+/// is possible if the input and output are accessed using the same
+/// indexing map.
+// TODO: This restriction can go away if we can vectorize always, but that has
+// a long tail of tasks.
+static bool canInsOperandTieWithOutsOperand(OpOperand *insOperand) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(insOperand->getOwner());
+  if (!linalgOp) return false;
+  AffineMap insOperandIndexingMap = linalgOp.getTiedIndexingMap(insOperand);
+  auto canTieWithOutsOperand = [&](OpOperand *outsOperand) {
+    if (linalgOp.getTiedIndexingMap(outsOperand) != insOperandIndexingMap) {
+      return false;
+    }
+    // TODO(#8411): Until ops are (always) vectorized, we need to check
+    // that the element types match for the operands to be tied. For now
+    // only do this check for convolution ops, since we expect contraction
+    // ops to be vectorized.
+    auto producerOp = insOperand->get().getDefiningOp();
+    if (isa<linalg::GenericOp, linalg::ConvolutionOpInterface>(producerOp) &&
+        insOperand->get().getType().cast<ShapedType>().getElementType() !=
+            outsOperand->get().getType().cast<ShapedType>().getElementType()) {
+      return false;
+    }
+    return true;
+  };
+  return llvm::any_of(linalgOp.getOutputOperands(), canTieWithOutsOperand);
+}
+
 /// Some heuristic is needed to fuse a dispatchble op with root operations using
 /// tile + fuse. Using some heuristic, each root operation is tagged with an ID
 /// (using an IntegerAttr with name `kRootOpAttr`) and all dispatchable ops to
@@ -877,11 +907,9 @@
               consumerIndexingMap.getResults()) {
         continue;
       }
-      if (llvm::any_of(
-              consumer.getOutputOperands(), [&consumer](OpOperand *operand) {
-                return !consumer.getTiedIndexingMap(operand).isIdentity();
-              }))
+      if (!canInsOperandTieWithOutsOperand(&use)) {
         continue;
+      }
       int64_t rootNumber = getRootNumber(op);
       setRootAttribute(context, user, rootNumber);
       removeRootOpAttribute(op);
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 0755c49..d74946a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -1147,3 +1147,33 @@
 //  CHECK-DAG:   %[[D7:.+]] = tensor.dim %[[ARG0]], %[[C7]]
 //      CHECK:   flow.dispatch.workgroups[%[[D7]], %[[D5]], %[[D4]]]
 // CHECK-SAME:       (%[[ARG0]], %[[D1]], %[[D4]], %[[D5]], %[[D7]]
+
+// -----
+
+func @no_fuse_quantized(%arg0 : tensor<?x113x113x64xi8>, %arg1 : tensor<3x3x64xi8>,
+    %arg2 : i32, %arg3 : i32) -> tensor<?x56x56x64xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x113x113x64xi8>
+  %0 = linalg.init_tensor [%d0, 56, 56, 64] : tensor<?x56x56x64xi32>
+  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<?x56x56x64xi32>) -> tensor<?x56x56x64xi32>
+  %2 =  linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+      ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x113x113x64xi8>, tensor<3x3x64xi8>, i32, i32)
+      outs(%1 : tensor<?x56x56x64xi32>) -> tensor<?x56x56x64xi32>
+  %3 = linalg.init_tensor [%d0, 56, 56, 64] : tensor<?x56x56x64xi8>
+  %4 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%2 : tensor<?x56x56x64xi32>) outs(%3 : tensor<?x56x56x64xi8>) {
+    ^bb0(%b0: i32, %b1 : i8):
+      %5 = arith.trunci %b0 : i32 to i8
+      linalg.yield %5 : i8
+    } -> tensor<?x56x56x64xi8>
+  return %4 : tensor<?x56x56x64xi8>
+}
+//     CHECK: func @no_fuse_quantized
+//     CHECK:   flow.dispatch.workgroups
+//     CHECK:   linalg.depthwise_conv_2d_nhwc_hwc_q
+// CHECK-NOT:   linalg.generic
+//     CHECK:   flow.dispatch.workgroups
+//     CHECK:   linalg.generic