[LLVMGPU] Support aggressive fusion (#11747)

Enable tests for aggressive fusion on the LLVMGPU side.
Also make the softmax case with aggressive fusion go through the fast (warp reduction) path.
Add a RematerializeParallelOps transformation that merges elementwise linalg ops back into their consumers to avoid creating large temporary arrays.
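
As an illustration, here is a minimal sketch of the kind of IR the new pass targets
(it mirrors the new rematerialize_parallel_ops.mlir lit test; the function and value
names below are made up for the example). The producer's body gets recomputed inside
its consumer, so the intermediate tensor<1x40960xf32> never needs to be materialized:

  func.func @example(%in: tensor<1x40960xf32>, %max: tensor<1xf32>,
                     %sum: tensor<1xf32>) -> tensor<1x40960xf32> {
    %empty = tensor.empty() : tensor<1x40960xf32>
    // Producer: exp(%in - %max), written to a full 1x40960 temporary.
    %exp = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0)>,
                         affine_map<(d0, d1) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%in, %max : tensor<1x40960xf32>, tensor<1xf32>)
        outs(%empty : tensor<1x40960xf32>) {
      ^bb0(%a: f32, %b: f32, %out: f32):
        %s = arith.subf %a, %b : f32
        %e = math.exp %s : f32
        linalg.yield %e : f32
    } -> tensor<1x40960xf32>
    // Consumer: divides %exp by %sum (broadcast along d1). After
    // -iree-codegen-rematerialize-parallel-ops, the subf/exp are cloned into
    // this region and the intermediate %exp is no longer materialized.
    %res = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0)>,
                         affine_map<(d0, d1) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%exp, %sum : tensor<1x40960xf32>, tensor<1xf32>)
        outs(%empty : tensor<1x40960xf32>) {
      ^bb0(%a: f32, %b: f32, %out: f32):
        %d = arith.divf %a, %b : f32
        linalg.yield %d : f32
    } -> tensor<1x40960xf32>
    return %res : tensor<1x40960xf32>
  }
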
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD b/compiler/src/iree/compiler/Codegen/Common/BUILD
index 5ca7f5c..3d74a04 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD
@@ -122,6 +122,7 @@
         "MaterializeEncodingPass.cpp",
         "OptimizeVectorTransferPass.cpp",
         "PolynomialApproximationPass.cpp",
+        "RematerializeParallelOps.cpp",
         "SplitFullPartialTransferPass.cpp",
         "TestPartitionableLoopsInterface.cpp",
         "TileAndDistributeToWorkgroupsPass.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index b59bf8b..fc3090f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -96,6 +96,7 @@
     "MaterializeEncodingPass.cpp"
     "OptimizeVectorTransferPass.cpp"
     "PolynomialApproximationPass.cpp"
+    "RematerializeParallelOps.cpp"
     "SplitFullPartialTransferPass.cpp"
     "TestPartitionableLoopsInterface.cpp"
     "TileAndDistributeToWorkgroupsPass.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp b/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
new file mode 100644
index 0000000..22c683b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/RematerializeParallelOps.cpp
@@ -0,0 +1,69 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-codegen-rematerialize-parallel-ops"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Merge elementwise operations into their consumers.
+struct MergeElementwiseOps : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter& rewriter) const override {
+    // Find the first operand that is defined by another generic op on tensors.
+    for (OpOperand& opOperand : genericOp->getOpOperands()) {
+      if (!linalg::areElementwiseOpsFusable(&opOperand)) continue;
+
+      FailureOr<Operation*> fusedOp =
+          linalg::fuseElementwiseOps(rewriter, &opOperand);
+      if (succeeded(fusedOp)) {
+        // Forward lowering config.
+        if (auto loweringAttr = getLoweringConfig(genericOp)) {
+          setLoweringConfig(fusedOp.value(), loweringAttr);
+        }
+        auto replacements =
+            fusedOp.value()->getResults().take_back(genericOp.getNumResults());
+        rewriter.replaceOp(genericOp, replacements);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+struct RematerializeParallelOpsPass
+    : public RematerializeParallelOpsBase<RematerializeParallelOpsPass> {
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    RewritePatternSet fusionPatterns(funcOp.getContext());
+    fusionPatterns.insert<MergeElementwiseOps>(funcOp.getContext());
+    if (failed(
+            applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRematerializeParallelOpsPass() {
+  return std::make_unique<RematerializeParallelOpsPass>();
+}
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD b/compiler/src/iree/compiler/Codegen/Common/test/BUILD
index 3ccaa84..a6c5356 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD
@@ -36,6 +36,7 @@
             "iree_comprehensive_bufferize.mlir",
             "pad_dynamic_alloc.mlir",
             "materialize_encoding.mlir",
+            "rematerialize_parallel_ops.mlir",
             "reduce_bank_conflicts.mlir",
             "reductions.mlir",
             "remove_dead_allocs.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index d83e099..0152b18 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -34,6 +34,7 @@
     "pad_dynamic_alloc.mlir"
     "reduce_bank_conflicts.mlir"
     "reductions.mlir"
+    "rematerialize_parallel_ops.mlir"
     "remove_dead_allocs.mlir"
     "remove_trivial_loops.mlir"
     "swizzle_workgroup.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
new file mode 100644
index 0000000..6ba0ce4
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/rematerialize_parallel_ops.mlir
@@ -0,0 +1,44 @@
+// RUN: iree-opt -iree-codegen-rematerialize-parallel-ops %s | FileCheck %s
+
+func.func @merged_reduction_parallel(%0: tensor<1x40960xf32>, %1: tensor<1xf32>, %7: tensor<1xf32>)
+  -> tensor<1x40960xf32> {
+    %2 = tensor.empty() : tensor<1x40960xf32>
+    %cst = arith.constant -3.40282347E+38 : f32
+    %8 = linalg.generic 
+    {indexing_maps = [
+        affine_map<(d0, d1) -> (d0, d1)>,
+        affine_map<(d0, d1) -> (d0)>,
+        affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%0, %1 : tensor<1x40960xf32>, tensor<1xf32>)
+        outs(%2 : tensor<1x40960xf32>) {
+      ^bb0(%in: f32, %in_2: f32, %out: f32):
+        %10 = arith.subf %in, %in_2 : f32
+        %11 = math.exp %10 : f32
+        linalg.yield %11 : f32
+      } -> (tensor<1x40960xf32>)
+    %9 = linalg.generic {
+        indexing_maps = [
+            affine_map<(d0, d1) -> (d0, d1)>,
+            affine_map<(d0, d1) -> (d0)>,
+            affine_map<(d0, d1) -> (d0, d1)>],
+            iterator_types = ["parallel", "parallel"]}
+            ins(%8, %7 : tensor<1x40960xf32>, tensor<1xf32>)
+            outs(%2 : tensor<1x40960xf32>) {
+      ^bb0(%in: f32, %in_2: f32, %out: f32):
+        %10 = arith.divf %cst, %in_2 : f32
+        %11 = arith.mulf %in, %10 : f32
+        linalg.yield %11 : f32
+      } -> tensor<1x40960xf32>
+   return %9 : tensor<1x40960xf32>
+}
+
+
+//   CHECK-LABEL: func.func @merged_reduction_parallel
+//         CHECK:   %{{.+}} = linalg.generic
+//         CHECK:     arith.subf
+//    CHECK-NEXT:     math.exp
+//    CHECK-NEXT:     arith.divf
+//    CHECK-NEXT:     arith.mulf
+//    CHECK-NEXT:     linalg.yield %{{.+}} : f32
+//         CHECK:   } -> tensor<1x40960xf32>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 4c4a9d5..55034d5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -555,7 +555,11 @@
   if (!targetInfo.hasWarpShuffle) return failure();
   if (!isa<linalg::GenericOp>(op)) return failure();
   // TODO(thomasraoux): Enable dynamic shape.
-  if (op.hasDynamicShape()) return failure();
+  bool hasDynamicShape = false;
+  entryPoint.walk([&hasDynamicShape](linalg::LinalgOp linalgOp) {
+    if (linalgOp.hasDynamicShape()) hasDynamicShape = true;
+  });
+  if (hasDynamicShape) return failure();
   SmallVector<unsigned> reductionDims;
   op.getReductionDims(reductionDims);
   if (reductionDims.size() != 1 || reductionDims[0] != op.getNumLoops() - 1)
@@ -569,11 +573,21 @@
       }))
     return failure();
 
-  // Only single combiner operations are supported for now.
-  SmallVector<Operation *, 4> combinerOps;
-  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
-      combinerOps.size() != 1)
-    return failure();
+  bool foundSingleReductionOutput = false;
+  for (int64_t i = 0, e = op.getDpsInitOperands().size(); i < e; i++) {
+    // Only single combiner operations are supported for now.
+    SmallVector<Operation *, 4> combinerOps;
+    if (matchReduction(op.getRegionOutputArgs(), i, combinerOps) &&
+        combinerOps.size() == 1) {
+      if (foundSingleReductionOutput) return failure();
+      foundSingleReductionOutput = true;
+      continue;
+    }
+    if (!op.getMatchingIndexingMap(op.getDpsInitOperand(i)).isIdentity())
+      return failure();
+  }
+  if (!foundSingleReductionOutput) return failure();
+
   Optional<int64_t> dimSize = getLinalgDimSize(op, reductionDims[0]);
   if (!dimSize || *dimSize % cudaWarpSize != 0) return failure();
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 057a240..bf93104 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -308,7 +308,9 @@
 void addGPUWarpReductionPassPipeline(OpPassManager &pm) {
   tileAndDistributeToWorkgroup(pm);
   auto &nestedModulePM = pm.nest<ModuleOp>();
-
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createRematerializeParallelOpsPass());
+  nestedModulePM.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   nestedModulePM.addNestedPass<func::FuncOp>(
       createRemoveSingleIterationLoopPass());
   nestedModulePM.addNestedPass<func::FuncOp>(createGPUTileReductionPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline.mlir
index 86397ba..d90f2b2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline.mlir
@@ -187,3 +187,140 @@
 //         CHECK:      vector.transfer_write {{.*}} : vector<4xf32>, memref<512x10240xf32>
 //         CHECK:    }
 //         CHECK:    return
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+hal.executable @softmax {
+hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
+  hal.executable.export @softmax layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device, %arg1: index, %arg2 : index):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg1, %arg2
+      hal.return %x, %y, %z : index, index, index
+    }
+  builtin.module {
+    func.func @softmax() {
+      %c0 = arith.constant 0 : index
+      %cst = arith.constant -3.40282347E+38 : f32
+      %cst_0 = arith.constant 0.000000e+00 : f32
+      %cst_1 = arith.constant 1.000000e+00 : f32
+      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<12x128x40960xf32>>
+      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
+      %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<12x128x40960xf32>> -> tensor<12x128x40960xf32>
+      %3 = tensor.empty() : tensor<12x128xf32>
+      %4 = tensor.empty() : tensor<12x128x40960xf32>
+      %5 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1], [0, 0, 4096]]>} ins(%cst : f32) outs(%3 : tensor<12x128xf32>) -> tensor<12x128xf32>
+      %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1], [0, 0, 4096]]>} ins(%cst_0 : f32) outs(%3 : tensor<12x128xf32>) -> tensor<12x128xf32>
+      %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2 : tensor<12x128x40960xf32>) outs(%5 : tensor<12x128xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1], [0, 0, 4096]]>} {
+      ^bb0(%in: f32, %out: f32):
+        %11 = arith.maxf %in, %out : f32
+        linalg.yield %11 : f32
+      } -> tensor<12x128xf32>
+      %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2, %7 : tensor<12x128x40960xf32>, tensor<12x128xf32>) outs(%4 : tensor<12x128x40960xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1], [0, 0, 4096]]>} {
+      ^bb0(%in: f32, %in_2: f32, %out: f32):
+        %11 = arith.subf %in, %in_2 : f32
+        %12 = math.exp %11 : f32
+        linalg.yield %12 : f32
+      } -> tensor<12x128x40960xf32>
+      %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%8 : tensor<12x128x40960xf32>) outs(%6 : tensor<12x128xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1], [0, 0, 4096]]>} {
+      ^bb0(%in: f32, %out: f32):
+        %11 = arith.addf %in, %out : f32
+        linalg.yield %11 : f32
+      } -> tensor<12x128xf32>
+      %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %9 : tensor<12x128x40960xf32>, tensor<12x128xf32>) outs(%4 : tensor<12x128x40960xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1], [0, 0, 4096]]>} {
+      ^bb0(%in: f32, %in_2: f32, %out: f32):
+        %11 = arith.divf %cst_1, %in_2 : f32
+        %12 = arith.mulf %in, %11 : f32
+        linalg.yield %12 : f32
+      } -> tensor<12x128x40960xf32>
+      flow.dispatch.tensor.store %10, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
+      return
+    }
+  }
+}
+}
+
+//   CHECK-LABEL:  func.func @softmax
+//         CHECK:    scf.for {{.*}} -> (vector<4xf32>) {
+//         CHECK:      vector.transfer_read {{.*}} : memref<12x128x40960xf32>, vector<4xf32>
+//         CHECK:      arith.maxf {{.*}} : vector<4xf32>
+//         CHECK:      scf.yield
+//         CHECK:    vector.reduction <maxf>, %{{.*}} : vector<4xf32> into f32
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    arith.remui
+//         CHECK:    scf.if
+//         CHECK:      memref.store {{.*}} : memref<32xf32, 3>
+//         CHECK:    }
+//         CHECK:    gpu.barrier
+//         CHECK:    arith.minui
+//         CHECK:    memref.load
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.maxf
+//         CHECK:    arith.maxf
+//         CHECK:    vector.broadcast %{{.*}} : f32 to vector<4xf32>
+//         CHECK:    scf.for {{.*}} -> (vector<4xf32>) {
+//         CHECK:      vector.transfer_read
+//         CHECK:      arith.subf
+//         CHECK:      math.exp
+//         CHECK:      arith.addf
+//         CHECK:      scf.yield
+//         CHECK:    vector.reduction <add>, %{{.*}} : vector<4xf32> into f32
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    scf.if
+//         CHECK:      memref.store {{.*}} : memref<32xf32, 3>
+//         CHECK:    }
+//         CHECK:    gpu.barrier
+//         CHECK:    memref.load
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    gpu.shuffle  xor
+//         CHECK:    arith.addf
+//         CHECK:    arith.addf
+//         CHECK:    vector.broadcast
+//         CHECK:    vector.broadcast
+//         CHECK:    arith.divf
+//         CHECK:    scf.for
+//         CHECK:      vector.transfer_read
+//         CHECK:      arith.subf
+//         CHECK:      math.exp
+//         CHECK:      arith.mulf
+//         CHECK:      vector.transfer_write
+//         CHECK:    }
+//         CHECK:    return
diff --git a/compiler/src/iree/compiler/Codegen/Passes.h b/compiler/src/iree/compiler/Codegen/Passes.h
index 3d606d5..8b1f6f2 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Passes.h
@@ -212,6 +212,10 @@
 std::unique_ptr<OperationPass<func::FuncOp>>
 createEraseHALDescriptorTypeFromMemRefPass();
 
+/// Pass to merge parallel linalg operations.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRematerializeParallelOpsPass();
+
 //----------------------------------------------------------------------------//
 // Common codegen patterns.
 //----------------------------------------------------------------------------//
diff --git a/compiler/src/iree/compiler/Codegen/Passes.td b/compiler/src/iree/compiler/Codegen/Passes.td
index 1502871..25adeee 100644
--- a/compiler/src/iree/compiler/Codegen/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Passes.td
@@ -310,6 +310,13 @@
       "mlir::iree_compiler::createEraseHALDescriptorTypeFromMemRefPass()";
 }
 
+def RematerializeParallelOps :
+    Pass<"iree-codegen-rematerialize-parallel-ops", "func::FuncOp"> {
+  let summary = "Pass to rematerialize and merge parallel ops to avoid "
+                "creating temporary allocs.";
+  let constructor = "mlir::iree_compiler::createRematerializeParallelOpsPass()";
+}
+
 //------------------------------------------------------------------------------
 // LLVMCPU
 //------------------------------------------------------------------------------
diff --git a/tests/e2e/regression/BUILD b/tests/e2e/regression/BUILD
index 6c1ecf6..ce11d07 100644
--- a/tests/e2e/regression/BUILD
+++ b/tests/e2e/regression/BUILD
@@ -56,6 +56,7 @@
             "layernorm.mlir",
             "linalg_quantized_matmul_vs_linalg_matmul.mlir",
             "lowering_config.mlir",
+            "softmax_large.mlir",
         ] + BACKEND_TESTS,
     ),
     cfg = "//tests:lit.cfg.py",
@@ -136,6 +137,27 @@
 )
 
 iree_check_single_backend_test_suite(
+    name = "aggressive_fusion_test_cuda",
+    srcs = [
+        "softmax.mlir",
+        "softmax_large.mlir",
+    ],
+    compiler_flags = [
+        "--iree-flow-enable-aggressive-fusion",
+    ],
+    driver = "cuda",
+    tags = [
+        # CUDA cuInit fails with sanitizer on.
+        "noasan",
+        "nomsan",
+        "notsan",
+        "noubsan",
+        "requires-gpu-nvidia",
+    ],
+    target_backend = "cuda",
+)
+
+iree_check_single_backend_test_suite(
     name = "disable_demote_f64_to_f32",
     srcs = [
         "disable_demote_f64_to_f32.mlir",
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index df1a0d4..d721c22 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -162,6 +162,26 @@
 
 iree_check_single_backend_test_suite(
   NAME
+    aggressive_fusion_test_cuda
+  SRCS
+    "softmax.mlir"
+    "softmax_large.mlir"
+  TARGET_BACKEND
+    "cuda"
+  DRIVER
+    "cuda"
+  COMPILER_FLAGS
+    "--iree-flow-enable-aggressive-fusion"
+  LABELS
+    "noasan"
+    "nomsan"
+    "notsan"
+    "noubsan"
+    "requires-gpu-nvidia"
+)
+
+iree_check_single_backend_test_suite(
+  NAME
     disable_demote_f64_to_f32
   SRCS
     "disable_demote_f64_to_f32.mlir"
diff --git a/tests/e2e/regression/softmax_large.mlir b/tests/e2e/regression/softmax_large.mlir
new file mode 100644
index 0000000..4a9e758
--- /dev/null
+++ b/tests/e2e/regression/softmax_large.mlir
@@ -0,0 +1,58 @@
+// Generated from this TOSA input:
+//
+// func.func @softmax() {
+//   %0 = util.unfoldable_constant dense<5.0> : tensor<12x128x40960xf32>
+//   %red = "tosa.reduce_max"(%0) {axis = 2 : i64} : (tensor<12x128x40960xf32>) -> tensor<12x128x1xf32>
+//   %sub = "tosa.sub"(%0, %red) : (tensor<12x128x40960xf32>, tensor<12x128x1xf32>) -> tensor<12x128x40960xf32>
+//   %exp = "tosa.exp"(%sub) : (tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
+//   %sum = "tosa.reduce_sum"(%exp) {axis = 2 : i64} : (tensor<12x128x40960xf32>) -> tensor<12x128x1xf32>
+//   %rec = "tosa.reciprocal"(%sum) : (tensor<12x128x1xf32>) -> tensor<12x128x1xf32>
+//   %mul = "tosa.mul"(%exp, %rec) {shift = 0 : i32} : (tensor<12x128x40960xf32>, tensor<12x128x1xf32>) -> tensor<12x128x40960xf32>
+//   check.expect_almost_eq_const(%mul, dense<2.44140625e-05> : tensor<12x128x40960xf32>) : tensor<12x128x40960xf32>
+//   return
+// }
+
+func.func @softmax() {
+  %cst = arith.constant 1.000000e+00 : f32
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %cst_1 = arith.constant -3.40282347E+38 : f32
+  %cst_2 = arith.constant dense<2.44140625e-05> : tensor<12x128x40960xf32>
+  %cst_3 = arith.constant dense<5.000000e+00> : tensor<12x128x40960xf32>
+  %0 = util.optimization_barrier %cst_3 : tensor<12x128x40960xf32>
+  %1 = tensor.empty() : tensor<12x128xf32>
+  %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<12x128xf32>) -> tensor<12x128xf32>
+  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%0 : tensor<12x128x40960xf32>) outs(%2 : tensor<12x128xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.maxf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<12x128xf32>
+  %4 = tensor.empty() : tensor<12x128x40960xf32>
+  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0, %3 : tensor<12x128x40960xf32>, tensor<12x128xf32>) outs(%4 : tensor<12x128x40960xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %11 = arith.subf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<12x128x40960xf32>
+  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%5 : tensor<12x128x40960xf32>) outs(%4 : tensor<12x128x40960xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = math.exp %arg0 : f32
+    linalg.yield %11 : f32
+  } -> tensor<12x128x40960xf32>
+  %7 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<12x128xf32>) -> tensor<12x128xf32>
+  %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%6 : tensor<12x128x40960xf32>) outs(%7 : tensor<12x128xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.addf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<12x128xf32>
+  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<12x128xf32>) outs(%1 : tensor<12x128xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %11 = arith.divf %cst, %arg0 : f32
+    linalg.yield %11 : f32
+  } -> tensor<12x128xf32>
+  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%6, %9 : tensor<12x128x40960xf32>, tensor<12x128xf32>) outs(%4 : tensor<12x128x40960xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %11 = arith.mulf %arg0, %arg1 : f32
+    linalg.yield %11 : f32
+  } -> tensor<12x128x40960xf32>
+  check.expect_almost_eq(%10, %cst_2) : tensor<12x128x40960xf32>
+  return
+}
\ No newline at end of file