[Codegen][GPU] Add pass to annotate memory spaces on allocations (#18251)

Trying to infer the memory space of an allocation from within the
bufferization allocation callback is too late. This adds a rudimentary
pass that annotates the memory space in obvious situations, and then
disallows any bufferization allocation without a pre-determined memory
space (for the LLVMGPUTileAndFuse pipeline). This gives us correctness
guarantees that were somewhat hand-wavy before.

This marks every allocation that isn't explicitly annotated as shared (and
can't be obviously inferred as shared) as thread local. Any previous lowering
that violates this invariant is a bug (most likely a failure to tile an
operation).
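
For reference, a minimal sketch of the new behavior, lifted from the lit test
added in this PR: an unannotated `bufferization.alloc_tensor` used as the
`shared_outs` destination of a thread- or warp-distributed `scf.forall` gets
annotated as workgroup memory, and every other unannotated allocation defaults
to private (thread local).

```mlir
func.func @forall_shared_dest(%w : tensor<2x3xi32>) -> tensor<4x3xi32> {
  // Before the pass: no memory space on the allocation.
  %dest = bufferization.alloc_tensor() : tensor<4x3xi32>
  %res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> {
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %w into %arg1[%arg0, 0] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<4x3xi32>
    }
  } {mapping = [#gpu.warp<x>]}
  return %res : tensor<4x3xi32>
}

// After iree-codegen-gpu-infer-memory-space, the forall destination is shared:
//   %dest = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<4x3xi32>
// Any other unannotated alloc_tensor would instead get #gpu.address_space<private>.
```
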
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 4a0b879..64d0412 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -58,6 +58,7 @@
         "GPUDistributeSharedMemoryCopy.cpp",
         "GPUDistributionPatterns.cpp",
         "GPUGeneralizeNamedOps.cpp",
+        "GPUInferMemorySpace.cpp",
         "GPULowerToUKernels.cpp",
         "GPUMultiBuffering.cpp",
         "GPUNestedLayoutDistributionPatterns.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index eb51b3e..5a376fd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -56,6 +56,7 @@
     "GPUDistributeSharedMemoryCopy.cpp"
     "GPUDistributionPatterns.cpp"
     "GPUGeneralizeNamedOps.cpp"
+    "GPUInferMemorySpace.cpp"
     "GPULowerToUKernels.cpp"
     "GPUMultiBuffering.cpp"
     "GPUNestedLayoutDistributionPatterns.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp
new file mode 100644
index 0000000..64fed3b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp
@@ -0,0 +1,106 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_GPUINFERMEMORYSPACEPASS
+#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
+
+namespace {
+
+/// Pass to infer the memory spaces of unmarked `bufferization.alloc_tensor`
+/// ops. Inferring the memory space during bufferization (in the allocation
+/// function) is infeasible because it requires analyzing the surrounding loop
+/// structure. After this pass, any allocation that still lacks a memory space
+/// is treated as a compiler failure indicating that something went wrong
+/// during bufferization.
+struct GPUInferMemorySpacePass final
+    : impl::GPUInferMemorySpacePassBase<GPUInferMemorySpacePass> {
+
+  void runOnOperation() override;
+};
+
+bool isDefinitelyShared(bufferization::AllocTensorOp alloc) {
+  // An allocation can be inferred as shared if it is used as the destination
+  // of a thread- or warp-distributed `scf.forall` op. All other shared
+  // allocations are expected to be explicitly annotated in advance.
+  for (auto user : alloc->getUsers()) {
+    auto forallOp = dyn_cast<scf::ForallOp>(user);
+    if (!forallOp ||
+        !forallOpHasMappingType<gpu::GPUThreadMappingAttr,
+                                gpu::GPUWarpMappingAttr>(forallOp)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void GPUInferMemorySpacePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  FunctionOpInterface funcOp = getOperation();
+
+  gpu::AddressSpaceAttr privateAddressSpace = gpu::AddressSpaceAttr::get(
+      context, gpu::GPUDialect::getPrivateAddressSpace());
+  gpu::AddressSpaceAttr sharedAddressSpace = gpu::AddressSpaceAttr::get(
+      context, gpu::GPUDialect::getWorkgroupAddressSpace());
+
+  WalkResult res = funcOp.walk([&](bufferization::AllocTensorOp alloc) {
+    // Continue if the allocation already has a valid memory space.
+    std::optional<Attribute> currentMemSpace = alloc.getMemorySpace();
+    if (currentMemSpace.has_value()) {
+      if (currentMemSpace.value() == privateAddressSpace ||
+          currentMemSpace.value() == sharedAddressSpace) {
+        return WalkResult::advance();
+      }
+      alloc.emitOpError(
+          "unexpected gpu memory space; expected private or workgroup");
+      return WalkResult::interrupt();
+    }
+
+    /// Determining the GPU memory space must be trivial by the time this pass
+    /// runs. Because this pass runs immediately before bufferization, the
+    /// input IR is expected to mix (thread) distributed and shared contexts.
+    /// Distributed loops (`scf.forall` ops) are resolved as-is after
+    /// bufferization, with no further tiling occurring, so every tensor at
+    /// this point in the IR is assumed to be thread-local unless it is
+    /// explicitly marked as shared. This gives the following invariants:
+    ///
+    /// 1. If the alloc_tensor is already annotated with
+    ///    `#gpu.address_space<workgroup>`, or if it is used as the immediate
+    ///    destination of a thread- or warp-distributed `scf.forall` op, then
+    ///    the allocation must be shared memory.
+    /// 2. All other allocations are thread local.
+    ///
+    /// Any allocation that is supposed to be shared memory but is not
+    /// explicitly marked as such indicates a bug in earlier passes/lowerings.
+    if (isDefinitelyShared(alloc)) {
+      alloc.setMemorySpaceAttr(sharedAddressSpace);
+    } else {
+      alloc.setMemorySpaceAttr(privateAddressSpace);
+    }
+    return WalkResult::advance();
+  });
+
+  if (res.wasInterrupted()) {
+    funcOp->emitOpError("failed to set the gpu memory space for all "
+                        "`bufferization.alloc_tensor` ops");
+    return signalPassFailure();
+  }
+}
+
+} // namespace
+
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp
index 273cadf..fe47380 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp
@@ -7,6 +7,7 @@
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -19,28 +20,6 @@
 
 namespace {
 
-template <typename... Type>
-bool forallOpHasMappingType(scf::ForallOp forallOp) {
-  std::optional<ArrayAttr> mapping = forallOp.getMapping();
-  if (!mapping || mapping.value().empty()) {
-    return false;
-  }
-
-  return isa<Type...>(*mapping.value().begin());
-}
-
-template <typename... Type>
-bool operationHasParentForallOfMappingType(Operation *op) {
-  auto parentForallOp = op->getParentOfType<scf::ForallOp>();
-  while (parentForallOp) {
-    if (forallOpHasMappingType<Type...>(parentForallOp)) {
-      return true;
-    }
-    parentForallOp = parentForallOp->getParentOfType<scf::ForallOp>();
-  }
-  return false;
-}
-
 /// Pass to verify that writes only happen in distributed contexts. Code in
 /// shared contexts are executed uniformly across all threads after resolution
 /// of distributed contexts (i.e. scf.forall), thus operations with write
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index cec8ba4..f02205a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -59,6 +59,14 @@
   let summary = "Convert named Linalg ops to linalg.generic ops";
 }
 
+def GPUInferMemorySpacePass :
+    InterfacePass<"iree-codegen-gpu-infer-memory-space", "mlir::FunctionOpInterface"> {
+  let summary = "Infer and set the memory space for all alloc_tensor ops";
+  let dependentDialects = [
+    "::mlir::gpu::GPUDialect"
+  ];
+}
+
 def GPULowerToUKernelsPass :
     Pass<"iree-codegen-gpu-lower-to-ukernels", ""> {
   let summary = "Separate out parts of the IR that lower to a micro-kernel";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 5854bd5..257cbe8 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -25,6 +25,7 @@
             "gpu_distribute_scf_for.mlir",
             "gpu_distribute_shared_memory.mlir",
             "gpu_generalize_named_ops.mlir",
+            "gpu_infer_memory_space.mlir",
             "gpu_lower_to_ukernels.mlir",
             "gpu_nested_layout_contract_amdgpu.mlir",
             "gpu_nested_layout_vector_distribution.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index a61138b..a67de53 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -21,6 +21,7 @@
     "gpu_distribute_scf_for.mlir"
     "gpu_distribute_shared_memory.mlir"
     "gpu_generalize_named_ops.mlir"
+    "gpu_infer_memory_space.mlir"
     "gpu_lower_to_ukernels.mlir"
     "gpu_nested_layout_contract_amdgpu.mlir"
     "gpu_nested_layout_vector_distribution.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir
new file mode 100644
index 0000000..7d1533a
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_infer_memory_space.mlir
@@ -0,0 +1,54 @@
+// RUN: iree-opt %s --split-input-file --verify-diagnostics \
+// RUN:   --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-infer-memory-space))" | FileCheck %s
+
+func.func @write_in_lane_forall(%dest : tensor<4x3xi32>) -> tensor<4x3xi32> {
+  %alloc = bufferization.alloc_tensor() : tensor<2x3xi32>
+  %cst = arith.constant dense<0> : vector<2x3xi32>
+  %c0 = arith.constant 0 : index
+  %res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> {
+    %w = vector.transfer_write %cst, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<2x3xi32>, tensor<2x3xi32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %w into %arg1[%arg0, 0] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<4x3xi32>
+    }
+  } {mapping = [#iree_gpu.lane_id<0>]}
+  return %res : tensor<4x3xi32>
+}
+
+// CHECK: func @write_in_lane_forall
+// CHECK:   %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<private>}
+// CHECK:   vector.transfer_write %{{.*}}, %[[ALLOC]]
+
+// -----
+
+func.func @forall_shared_dest(%w : tensor<2x3xi32>) -> tensor<4x3xi32> {
+  %dest = bufferization.alloc_tensor() : tensor<4x3xi32>
+  %res = scf.forall (%arg0) in (2) shared_outs(%arg1 = %dest) -> tensor<4x3xi32> {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %w into %arg1[%arg0, 0] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<4x3xi32>
+    }
+  } {mapping = [#gpu.warp<x>]}
+  return %res : tensor<4x3xi32>
+}
+
+// CHECK: func @forall_shared_dest
+// CHECK:   %[[ALLOC:.+]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>}
+// CHECK:   scf.forall {{.*}} shared_outs(%{{.*}} = %[[ALLOC]])
+
+// -----
+
+func.func @already_annotated_alloc() -> tensor<2x3xi32> {
+  %alloc = bufferization.alloc_tensor() {memory_space = #gpu.address_space<private>} : tensor<2x3xi32>
+  return %alloc : tensor<2x3xi32>
+}
+
+// CHECK: func @already_annotated_alloc
+// CHECK:   bufferization.alloc_tensor() {memory_space = #gpu.address_space<private>}
+
+// -----
+
+// expected-error@+1 {{failed to set the gpu memory space for all `bufferization.alloc_tensor` ops}}
+func.func @unknown_memory_space() -> tensor<2x3xi32> {
+  // expected-error@+1 {{unexpected gpu memory space; expected private or workgroup}}
+  %alloc = bufferization.alloc_tensor() {memory_space = "bad"} : tensor<2x3xi32>
+  return %alloc : tensor<2x3xi32>
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 9ef45c7..35bafc7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -166,6 +166,7 @@
         "@llvm-project//mlir:ArithToLLVM",
         "@llvm-project//mlir:ArithTransforms",
         "@llvm-project//mlir:BufferizationDialect",
+        "@llvm-project//mlir:BufferizationTransforms",
         "@llvm-project//mlir:ComplexToLLVM",
         "@llvm-project//mlir:ComplexToStandard",
         "@llvm-project//mlir:ControlFlowToLLVM",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index a5d3b08..0f8b40b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -116,6 +116,7 @@
     MLIRArithToLLVM
     MLIRArithTransforms
     MLIRBufferizationDialect
+    MLIRBufferizationTransforms
     MLIRComplexToLLVM
     MLIRComplexToStandard
     MLIRControlFlowToLLVM
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index f71f616..7c71f86 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
@@ -143,20 +144,6 @@
       .getResult();
 }
 
-static FailureOr<Value> gpuWorkgroupAllocationFn(OpBuilder &builder,
-                                                 Location loc,
-                                                 MemRefType memRefType,
-                                                 ValueRange dynamicSizes,
-                                                 unsigned alignment) {
-  gpu::AddressSpaceAttr addressSpace = gpu::AddressSpaceAttr::get(
-      builder.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
-  MemRefType allocType =
-      MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
-                      AffineMap(), addressSpace);
-  return builder.create<memref::AllocOp>(loc, allocType, dynamicSizes)
-      .getResult();
-}
-
 // Barriers are only needed when copying to/from workgroup memory. The only
 // other kind of memory that can be allocated is function memory, which is local
 // to a thread.
@@ -211,10 +198,8 @@
 // Common Pass Recipes
 //===----------------------------------------------------------------------===//
 
-static void addBufferizePasses(OpPassManager &funcPassManager,
-                               bool allowPrivateAllocations = true) {
-  BufferizationOptions::AllocationFn allocationFn =
-      allowPrivateAllocations ? gpuAllocationFn : gpuWorkgroupAllocationFn;
+static void addBufferizePasses(OpPassManager &funcPassManager) {
+  BufferizationOptions::AllocationFn allocationFn = gpuAllocationFn;
   BufferizationOptions::MemCpyFn memcpyFn = gpuCopyFn;
   addIREEComprehensiveBufferizePasses(funcPassManager, allocationFn, memcpyFn);
   funcPassManager.addPass(createCanonicalizerPass());
@@ -305,6 +290,34 @@
 // Tile and Fuse
 //===---------------------------------------------------------------------===//
 
+static FailureOr<Value> gpuRequireMemSpaceAllocationFn(OpBuilder &builder,
+                                                       Location loc,
+                                                       MemRefType memRefType,
+                                                       ValueRange dynamicSizes,
+                                                       unsigned alignment) {
+  // Bail out if the memref type does not already carry a gpu address space.
+  if (!isa<gpu::AddressSpaceAttr>(memRefType.getMemorySpace())) {
+    return failure();
+  }
+  return builder.create<memref::AllocOp>(loc, memRefType, dynamicSizes)
+      .getResult();
+}
+
+static void addGPUBufferizePasses(OpPassManager &funcPassManager) {
+  funcPassManager.addPass(createEliminateEmptyTensorsPass());
+  funcPassManager.addPass(bufferization::createEmptyTensorToAllocTensorPass());
+  funcPassManager.addPass(createGPUInferMemorySpacePass());
+  BufferizationOptions::AllocationFn allocationFn =
+      gpuRequireMemSpaceAllocationFn;
+  BufferizationOptions::MemCpyFn memcpyFn = gpuCopyFn;
+  funcPassManager.addPass(
+      createIREEComprehensiveBufferizePass(allocationFn, memcpyFn));
+  addIREEPostBufferizationPasses(funcPassManager);
+
+  funcPassManager.addPass(createCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+}
+
 void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager) {
   tileAndDistributeToWorkgroup(funcPassManager,
                                /*useWARForCooperativeMatrixCodegen=*/false,
@@ -371,6 +384,10 @@
                                      /*normalizeForall=*/true}));
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
+
+  // TODO: This LICM instance is load-bearing due to the brittleness of the
+  // hoisting and fusion pass, as well as the lack of a fallback distribution
+  // pass.
   funcPassManager.addPass(createLoopInvariantCodeMotionPass());
 
   // Step 5. Greedily fuse parallel loops and hoist from serial loops.
@@ -385,7 +402,7 @@
   funcPassManager.addPass(createCleanupBufferAllocViewPass());
 
   // Step 7. Bufferize.
-  addBufferizePasses(funcPassManager, /*allowPrivateAllocations=*/true);
+  addGPUBufferizePasses(funcPassManager);
 
   // Step 8. Resolve remaining parallel loops.
   funcPassManager.addPass(createGPUVerifyDistributionPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index 4f5425a..9914509 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -558,6 +558,7 @@
 // CHECK-LABEL: func @conv_nchw_fused
 //       CHECK:   scf.for %{{.*}} = %c0 to %c64 step %c1
 //       CHECK:     linalg.conv_2d_nchw_fchw
+//  CHECK-SAME:       outs(%{{.*}} : memref<1x1x1x1xf32, #gpu.address_space<private>>)
 //       CHECK:   arith.addf
 //       CHECK:   arith.cmpf
 //       CHECK:   arith.select
diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
index b34209a..7cbddf4 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.h
@@ -39,6 +39,32 @@
 getSubgroupIdsAndCounts(OpBuilder &builder, Location loc, unsigned warpSize,
                         unsigned numDims, llvm::ArrayRef<int64_t> numSubgroups);
 
+// Indicates whether the given `scf.forall` op has a processor ID mapping of
+// the template type(s).
+template <typename... Type>
+bool forallOpHasMappingType(scf::ForallOp forallOp) {
+  std::optional<ArrayAttr> mapping = forallOp.getMapping();
+  if (!mapping || mapping.value().empty()) {
+    return false;
+  }
+
+  return isa<Type...>(*mapping.value().begin());
+}
+
+// Indicates whether an operation is within a distributed context with the
+// specified mapping type(s).
+template <typename... Type>
+bool operationHasParentForallOfMappingType(Operation *op) {
+  auto parentForallOp = op->getParentOfType<scf::ForallOp>();
+  while (parentForallOp) {
+    if (forallOpHasMappingType<Type...>(parentForallOp)) {
+      return true;
+    }
+    parentForallOp = parentForallOp->getParentOfType<scf::ForallOp>();
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // GPU vectorization
 //===----------------------------------------------------------------------===//