[Codegen] Replace LICM with a version that checks trip count (#18679)

The upstream LICM pass does not verify that it is safe to hoist an op
out of a loop, i.e. that the loop is guaranteed to run at least one
trip. This replaces all uses of LICM in codegen with a version that
performs this verification. Other phases of the compiler should
probably switch as well.
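
As a sketch of the hazard (hand-written IR in the spirit of the
do_not_hoist_scf_for_with_unknown_trip_count test below; %lb and %ub are
arbitrary runtime values):

    // %ub may be <= %lb, so the loop may execute zero times.
    scf.for %i = %lb to %ub step %c1 {
      %v = arith.addf %cf7, %cf8 : f32
    }

Hoisting %v above the loop makes it execute even when the body never
runs. The replacement pass therefore only hoists when lb < ub can be
proven for every lower/upper bound pair.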
diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 0ec2d8c..616e393 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -117,6 +117,7 @@
         "HoistUnrolledVectorExtractInsertSlice.cpp",
         "IREEComprehensiveBufferizePass.cpp",
         "IREEExpandStridedMetadata.cpp",
+        "IREELoopInvariantCodeMotion.cpp",
         "InstrumentMemoryAccesses.cpp",
         "LowerExecutableUsingTransformDialect.cpp",
         "LowerUKernelsToCalls.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 428bb49..648805b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -108,6 +108,7 @@
     "HoistUnrolledVectorExtractInsertSlice.cpp"
     "IREEComprehensiveBufferizePass.cpp"
     "IREEExpandStridedMetadata.cpp"
+    "IREELoopInvariantCodeMotion.cpp"
     "InstrumentMemoryAccesses.cpp"
     "LowerExecutableUsingTransformDialect.cpp"
     "LowerUKernelsToCalls.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/IREELoopInvariantCodeMotion.cpp b/compiler/src/iree/compiler/Codegen/Common/IREELoopInvariantCodeMotion.cpp
new file mode 100644
index 0000000..10f8d60
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/IREELoopInvariantCodeMotion.cpp
@@ -0,0 +1,28 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Transforms/Transforms.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_IREELOOPINVARIANTCODEMOTIONPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+/// IREE loop invariant code motion (LICM) pass.
+struct IREELoopInvariantCodeMotionPass
+    : public impl::IREELoopInvariantCodeMotionPassBase<
+          IREELoopInvariantCodeMotionPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void IREELoopInvariantCodeMotionPass::runOnOperation() {
+  moveLoopInvariantCodeFromGuaranteedLoops(getOperation());
+}
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index c30dbb1..6bb6c82 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -456,6 +456,21 @@
   let summary = "Remove distributed loop with single iteration.";
 }
 
+// TODO: Replace with upstream: https://github.com/iree-org/iree/issues/18759
+def IREELoopInvariantCodeMotionPass :
+  Pass<"iree-loop-invariant-code-motion", ""> {
+  let summary = "Performs LICM on loops guaranteed to have >= 1 trip";
+  let description = [{
+    This is a mirror of the upstream LICM pass restricted to loops that are
+    guaranteed to have at least one trip. It currently only supports loops
+    that expose lower and upper bounds, as the generic loop-like interface
+    does not expose a way to query the trip count.
+
+    Additionally, code motion on `scf.forall` ops with mappings is always
+    unsafe and is explicitly disabled.
+  }];
+}
+
 def SplitFullPartialTransferPass :
     InterfacePass<"iree-codegen-split-full-partial-transfer", "mlir::FunctionOpInterface"> {
   let summary =
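
As a point of reference for the last paragraph of the description above,
a mapped scf.forall looks roughly like this (a hand-written sketch; the
mapping attribute defers the actual thread distribution, so the trip
count cannot be reasoned about until the loop is resolved):

    scf.forall (%i) in (4) {
      %v = arith.addf %cf7, %cf8 : f32
    } {mapping = [#gpu.thread<x>]}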
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 284e6cf..b00a94a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -49,6 +49,7 @@
             "hoist_unrolled_vector_extract_insert_slice.mlir",
             "iree_comprehensive_bufferize.mlir",
             "iree_expand_strided_metadata.mlir",
+            "iree_loop_invariant_code_motion.mlir",
             "lower_ukernel_to_calls.mlir",
             "materialize_encoding_into_nop.mlir",
             "materialize_user_configs.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index f75729f..fb27a4b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -45,6 +45,7 @@
     "hoist_unrolled_vector_extract_insert_slice.mlir"
     "iree_comprehensive_bufferize.mlir"
     "iree_expand_strided_metadata.mlir"
+    "iree_loop_invariant_code_motion.mlir"
     "lower_ukernel_to_calls.mlir"
     "materialize_encoding_into_nop.mlir"
     "materialize_user_configs.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_loop_invariant_code_motion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_loop_invariant_code_motion.mlir
new file mode 100644
index 0000000..8f60ee6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_loop_invariant_code_motion.mlir
@@ -0,0 +1,95 @@
+// RUN: iree-opt %s --split-input-file --iree-loop-invariant-code-motion | FileCheck %s
+
+func.func @nested_loops_code_invariant_to_both() {
+  %m = memref.alloc() : memref<10xf32>
+  %cf7 = arith.constant 7.0 : f32
+  %cf8 = arith.constant 8.0 : f32
+
+  affine.for %arg0 = 0 to 10 {
+    affine.for %arg1 = 0 to 10 {
+      %v0 = arith.addf %cf7, %cf8 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @nested_loops_code_invariant_to_both
+//       CHECK:   memref.alloc() : memref<10xf32>
+//  CHECK-NEXT:   arith.constant 7
+//  CHECK-NEXT:   arith.constant 8
+//  CHECK-NEXT:   arith.addf
+
+// -----
+
+func.func @do_not_hoist_with_unknown_trip_count(%lb: index, %ub: index) {
+  affine.for %arg1 = %lb to %ub {
+    affine.for %arg0 = 0 to 10 {
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @do_not_hoist_with_unknown_trip_count
+//  CHECK-NEXT:   affine.for
+//  CHECK-NEXT:     affine.for
+//  CHECK-NEXT:     }
+//  CHECK-NEXT:   }
+
+// -----
+
+func.func @do_not_hoist_scf_for_with_unknown_trip_count(%lb: index, %ub: index) {
+  %c1 = arith.constant 1 : index
+  %cf7 = arith.constant 7.0 : f32
+  %cf8 = arith.constant 8.0 : f32
+  scf.for %arg0 = %lb to %ub step %c1 {
+    %v0 = arith.addf %cf7, %cf8 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: @do_not_hoist_scf_for_with_unknown_trip_count
+//       CHECK:   scf.for
+//  CHECK-NEXT:     arith.addf
+//  CHECK-NEXT:   }
+
+// -----
+
+func.func @do_hoist_scf_for_with_known_trip_count() {
+  %c4 = arith.constant 4 : index
+  %c6 = arith.constant 6 : index
+  %c1 = arith.constant 1 : index
+  %cf7 = arith.constant 7.0 : f32
+  %cf8 = arith.constant 8.0 : f32
+  scf.for %arg0 = %c4 to %c6 step %c1 {
+    %v0 = arith.addf %cf7, %cf8 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: @do_hoist_scf_for_with_known_trip_count
+//       CHECK:   arith.addf
+//       CHECK:   scf.for
+
+// -----
+
+func.func @do_not_hoist_scf_while() {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf7 = arith.constant 7.0 : f32
+  %cf8 = arith.constant 8.0 : f32
+  scf.while (%iter = %c0) : (index) -> (index) {
+    %cond = arith.cmpi slt, %iter, %c4 : index
+    scf.condition(%cond) %iter : index
+  } do {
+  ^bb0(%arg1: index):
+    %v0 = arith.addf %cf7, %cf8 : f32
+    scf.yield %arg1 : index
+  }
+  return
+}
+
+// CHECK-LABEL: @do_not_hoist_scf_while
+//       CHECK:   scf.while
+//       CHECK:     scf.condition
+//       CHECK:     arith.addf
+//       CHECK:     scf.yield
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 980007c..3508e52 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -262,7 +262,7 @@
   funcPassManager.addPass(createGPUDistributePass());
 
   // Post bufferization optimizations.
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
@@ -395,7 +395,7 @@
   // TODO: This LICM instance is load bearing due to brittleness of the
   // hoisting and fusion pass, as well as a lack of a fallback distribution
   // pass.
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   {
     OptimizeTensorInsertExtractSlicesPassOptions options;
     options.foldIdentitySlices = true;
@@ -408,7 +408,7 @@
   funcPassManager.addPass(createGPUGreedilyDistributeToThreadsPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   funcPassManager.addPass(IREE::GPU::createCombineBarrierRegionsPass());
 
   // Step 6. Lower special ops and vectorize.
@@ -438,7 +438,7 @@
   // Step 9. Remaining post-bufferization optimizations/lowerings.
   funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
   funcPassManager.addPass(createUnrollAnnotatedLoopsPass());
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   if (pipelineOptions.enableReduceSharedMemoryBankConflicts) {
     GPUReduceBankConflictsPassOptions options = {};
     options.paddingBits = 64;
@@ -492,7 +492,7 @@
   funcPassManager.addPass(createGPUDistributeScfForPass(options));
 
   // Post bufferization optimizations.
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
@@ -562,7 +562,7 @@
   funcPassManager.addPass(createOptimizeVectorTransferPass());
 
   // Hoist loop invariant code to avoid pipelining it.
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   // Pipeline memory operations.
   funcPassManager.addPass(createGPUPipeliningPass());
 }
@@ -625,7 +625,7 @@
   funcPassManager.addPass(createCSEPass());
 
   // Hoist loop invariant code to avoid pipelining it.
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   // Pipeline memory operations.
   GPUPipeliningPassOptions pipelieningOptions = {};
   pipelieningOptions.epiloguePeeling = false;
@@ -692,7 +692,7 @@
   funcPassManager.addPass(createCSEPass());
 
   // Hoist loop invariant code to avoid pipelining it.
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   // Pipeline memory operations.
   GPUPipeliningPassOptions pipelieningOptions = {};
   pipelieningOptions.epiloguePeeling = false;
@@ -855,7 +855,7 @@
   // Set anchors at tensor level for vector distribution later and hoist out
   // loop invariant anchors.
   funcPassManager.addPass(createLLVMGPUConfigureTensorLayoutsPass());
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
 
   // Generalize all named ops so that we can fold away unit extent dims. By this
   // point, all tiling is finished so the tiling configurations on those ops can
@@ -936,7 +936,7 @@
     funcPassManager.addPass(createCanonicalizerPass());
     funcPassManager.addPass(createCSEPass());
   }
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
 
@@ -945,7 +945,7 @@
   funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
   funcPassManager.addPass(createOptimizeVectorTransferPass());
   funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createForOpCanonicalizationPass());
@@ -1040,13 +1040,13 @@
       .addPass(memref::createExpandStridedMetadataPass)
       // Hoist loop invariant variables to give affine decomposition pass the
       // right loop dependencies.
-      .addPass(createLoopInvariantCodeMotionPass)
+      .addPass(createIREELoopInvariantCodeMotionPass)
       // Decompose affine ops.
       .addPass(createDecomposeAffineOpsPass)
       // Get rid of the redundant computations.
       .addPass(createCSEPass)
       // Hoist the resulting decompositions.
-      .addPass(createLoopInvariantCodeMotionPass)
+      .addPass(createIREELoopInvariantCodeMotionPass)
       .addPass(createLowerAffinePass);
 }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index bfde530..66ac37b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -639,10 +639,9 @@
 // the producer's (convolution's) distributed scf.forall loop.
 // CHECK-LABEL: func @conv_nchw_fused
 //       CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<1x1x1x1xf32, #gpu.address_space<private>>
-//       CHECK:   %[[ALLOCA2:.+]] = memref.alloca() : memref<1x1x1x1xf32, #gpu.address_space<private>>
 //       CHECK:   scf.for %{{.*}} = %c0 to %c64 step %c1
 //       CHECK:     linalg.conv_2d_nchw_fchw
-//  CHECK-SAME:       outs(%[[ALLOCA2]] : memref<1x1x1x1xf32, #gpu.address_space<private>>)
+//  CHECK-SAME:       outs(%[[ALLOCA]] : memref<1x1x1x1xf32, #gpu.address_space<private>>)
 //       CHECK:   arith.addf
 //       CHECK:   arith.cmpf
 //       CHECK:   arith.select
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index b715689..30868ff 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -530,7 +530,7 @@
   funcPassManager.addPass(createOptimizeVectorTransferPass());
 
   // Hoist loop invariant code to avoid pipelining it.
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   PipeliningSchedulingStrategy schedule =
       storeStage == 0 ? PipeliningSchedulingStrategy::loadStoreStage0
                       : PipeliningSchedulingStrategy::loadGlobalStage0;
@@ -572,7 +572,7 @@
     funcPassManager.addPass(createCSEPass());
   }
 
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
 
@@ -587,7 +587,7 @@
 
   // Simplify the IR for vector distribution.
   funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
-  funcPassManager.addPass(createLoopInvariantCodeMotionPass());
+  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
   funcPassManager.addPass(createForOpCanonicalizationPass());
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index dbe350f..0b8c49c 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 
 #define DEBUG_TYPE "iree-codegen-transforms"
 
@@ -493,6 +494,52 @@
 }
 
 //===---------------------------------------------------------------------===//
+// Helper to perform LICM on loops guaranteed to have at least one trip.
+//===---------------------------------------------------------------------===//
+
+void moveLoopInvariantCodeFromGuaranteedLoops(Operation *target) {
+  // Walk all loops in the function in innermost-loop-first order. This way
+  // we first hoist invariant ops out of the inner loop and place them in the
+  // outer loop, from which they can in turn be hoisted further.
+  //
+  // Hoisting is only performed on loops with guaranteed non-zero trip counts.
+  // `scf.forall` ops with mapping attributes can never be proven to have a
+  // non-zero trip count until the loop is resolved, so they are skipped
+  // unconditionally here.
+  target->walk([&](LoopLikeOpInterface loopLike) {
+    if (auto forallOp = dyn_cast<scf::ForallOp>(*loopLike)) {
+      if (forallOp.getMapping()) {
+        return;
+      }
+    }
+
+    // Skip loops without lower/upper bounds. There is no generic way to verify
+    // whether a loop has at least one trip, so new loop types of interest can
+    // be added as needed. For example, `scf.while` needs non-trivial analysis
+    // of its condition region to prove that it has at least one trip.
+    std::optional<SmallVector<OpFoldResult>> maybeLowerBounds =
+        loopLike.getLoopLowerBounds();
+    std::optional<SmallVector<OpFoldResult>> maybeUpperBounds =
+        loopLike.getLoopUpperBounds();
+    if (!maybeLowerBounds || !maybeUpperBounds) {
+      return;
+    }
+
+    // If any lower/upper bound pair cannot be definitively verified to satisfy
+    // lb < ub, the loop may have a zero trip count.
+    for (auto [lb, ub] :
+         llvm::zip_equal(*maybeLowerBounds, *maybeUpperBounds)) {
+      if (!ValueBoundsConstraintSet::compare(lb, ValueBoundsConstraintSet::LT,
+                                             ub)) {
+        return;
+      }
+    }
+
+    moveLoopInvariantCode(loopLike);
+  });
+}
+
+//===---------------------------------------------------------------------===//
 // Patterns to fold tensor.expand/collapse_shape into
 // `hal.interface.binding.subspan`
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
index 74ea3ef..b0b64ca 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.h
@@ -157,6 +157,12 @@
     ArrayRef<OpFoldResult> workgroupCount,
     int maxWorkgroupParallelDims = kNumMaxParallelDims);
 
+/// Helper to perform LICM on loops nested within |target| that are guaranteed
+/// to have at least one trip. Additionally, `scf.forall` ops with mapping
+/// attributes are excluded, as their trip count is unknown until the loop is
+/// resolved.
+void moveLoopInvariantCodeFromGuaranteedLoops(Operation *target);
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as patterns, moved from upstream MLIR as IREE still
 // heavily relies on patterns that compose through filters.
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
index 9b4db12..00c5c9f 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
@@ -114,7 +114,7 @@
     FunctionLikeNest &funcPassManager) {
   funcPassManager.addPass(createLowerAffinePass)
       .addPass(createForOpCanonicalizationPass)
-      .addPass(createLoopInvariantCodeMotionPass);
+      .addPass(createIREELoopInvariantCodeMotionPass);
 }
 
 void buildVMVXTransformPassPipeline(OpPassManager &variantPassManager) {