[CUDA-Codegen] Add support for constant tensor (#5966)

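Add the createTensorConstantBufferizePass and createFoldTensorExtractOpPass
passes to the LinalgToLLVMGPU pass pipeline to support lowering inline constant
tensors. The FoldTensorExtractOp pass is moved from LinalgToLLVM to Common so
that it can be used by both pipelines.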
diff --git a/iree/compiler/Conversion/Common/BUILD b/iree/compiler/Conversion/Common/BUILD
index bf7192f..3bd8795 100644
--- a/iree/compiler/Conversion/Common/BUILD
+++ b/iree/compiler/Conversion/Common/BUILD
@@ -12,18 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+load("//build_tools/bazel:tblgen.bzl", "gentbl_cc_library")
+
 package(
     default_visibility = ["//visibility:public"],
     features = ["layering_check"],
     licenses = ["notice"],  # Apache 2.0
 )
 
+gentbl_cc_library(
+    name = "FoldTensorExtractOpIncGen",
+    tbl_outs = [
+        (
+            ["-gen-rewriters"],
+            "FoldTensorExtractOp.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "FoldTensorExtractOp.td",
+    td_srcs = [
+        "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:MemRefOpsTdFiles",
+        "@llvm-project//mlir:TensorOpsTdFiles",
+    ],
+)
+
 cc_library(
     name = "Common",
     srcs = [
         "BufferAllocViewCleanUpPass.cpp",
         "DemoteF32ToF16.cpp",
         "FlattenMemRefSubspanPass.cpp",
+        "FoldTensorExtractOpPass.cpp",
         "ForOpCanonicalizationPass.cpp",
         "LaunchConfig.cpp",
         "LinalgBufferizePass.cpp",
@@ -39,6 +59,7 @@
     ],
     deps = [
         "//iree/compiler/Conversion/CodegenUtils",
+        "//iree/compiler/Conversion/Common:FoldTensorExtractOpIncGen",
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/IR",
         "//iree/compiler/Dialect/IREE/IR",
@@ -58,6 +79,7 @@
         "@llvm-project//mlir:SideEffectInterfaces",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorOps",
     ],
diff --git a/iree/compiler/Conversion/Common/CMakeLists.txt b/iree/compiler/Conversion/Common/CMakeLists.txt
index 9514518..7816f9c 100644
--- a/iree/compiler/Conversion/Common/CMakeLists.txt
+++ b/iree/compiler/Conversion/Common/CMakeLists.txt
@@ -10,6 +10,15 @@
 
 iree_add_all_subdirs()
 
+iree_tablegen_library(
+  NAME
+    FoldTensorExtractOpIncGen
+  TD_FILE
+    "FoldTensorExtractOp.td"
+  OUTS
+    -gen-rewriters FoldTensorExtractOp.cpp.inc
+)
+
 iree_cc_library(
   NAME
     Common
@@ -21,6 +30,7 @@
     "BufferAllocViewCleanUpPass.cpp"
     "DemoteF32ToF16.cpp"
     "FlattenMemRefSubspanPass.cpp"
+    "FoldTensorExtractOpPass.cpp"
     "ForOpCanonicalizationPass.cpp"
     "LaunchConfig.cpp"
     "LinalgBufferizePass.cpp"
@@ -43,9 +53,11 @@
     MLIRSideEffectInterfaces
     MLIRStandard
     MLIRSupport
+    MLIRTensor
     MLIRTransforms
     MLIRVector
     iree::compiler::Conversion::CodegenUtils
+    iree::compiler::Conversion::Common::FoldTensorExtractOpIncGen
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.td b/iree/compiler/Conversion/Common/FoldTensorExtractOp.td
similarity index 81%
rename from iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.td
rename to iree/compiler/Conversion/Common/FoldTensorExtractOp.td
index 1f6d0b0..7312267 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.td
+++ b/iree/compiler/Conversion/Common/FoldTensorExtractOp.td
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef IREE_COMPILER_CONVERSION_LINALGTOLLVM_FOLDTENSOREXTRACTOP
-#define IREE_COMPILER_CONVERSION_LINALGTOLLVM_FOLDTENSOREXTRACTOP
+#ifndef IREE_COMPILER_CONVERSION_COMMON_FOLDTENSOREXTRACTOP
+#define IREE_COMPILER_CONVERSION_COMMON_FOLDTENSOREXTRACTOP
 
 include "mlir/Dialect/MemRef/IR/MemRefOps.td"
 include "mlir/Dialect/Tensor/IR/TensorOps.td"
@@ -23,4 +23,4 @@
 def : Pat<(Tensor_ExtractOp (TensorLoadOp $value), $indices),
           (LoadOp $value, $indices)>;
 
-#endif // IREE_COMPILER_CONVERSION_LINALGTOLLVM_FOLDTENSOREXTRACTOP
+#endif // IREE_COMPILER_CONVERSION_COMMON_FOLDTENSOREXTRACTOP
diff --git a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp b/iree/compiler/Conversion/Common/FoldTensorExtractOpPass.cpp
similarity index 95%
rename from iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp
rename to iree/compiler/Conversion/Common/FoldTensorExtractOpPass.cpp
index 026dd95..133ef12 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOpPass.cpp
+++ b/iree/compiler/Conversion/Common/FoldTensorExtractOpPass.cpp
@@ -11,7 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
+#include "iree/compiler/Conversion/Common/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -23,7 +23,7 @@
 namespace iree_compiler {
 
 namespace {
-#include "iree/compiler/Conversion/LinalgToLLVM/FoldTensorExtractOp.cpp.inc"
+#include "iree/compiler/Conversion/Common/FoldTensorExtractOp.cpp.inc"
 }
 
 namespace {
diff --git a/iree/compiler/Conversion/Common/Passes.h b/iree/compiler/Conversion/Common/Passes.h
index 9cd0624..bd29a23 100644
--- a/iree/compiler/Conversion/Common/Passes.h
+++ b/iree/compiler/Conversion/Common/Passes.h
@@ -71,5 +71,10 @@
 std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
 createSetNumWorkgroupsPass(ArrayRef<int64_t> workgroupSize = {});
 
+/// After running the upstream TensorConstantBufferize pass, remove tensor_loads
+/// introduced for use only in tensor_extract. These can be folded to use a load
+/// of the created memref object that holds the constant values.
+std::unique_ptr<OperationPass<>> createFoldTensorExtractOpPass();
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Conversion/Common/test/BUILD b/iree/compiler/Conversion/Common/test/BUILD
index 0d8f1a2..f1b956a 100644
--- a/iree/compiler/Conversion/Common/test/BUILD
+++ b/iree/compiler/Conversion/Common/test/BUILD
@@ -31,6 +31,7 @@
             "canonicalize_interface_load_store.mlir",
             "f32Tof16.mlir",
             "flatten_memref_subspan.mlir",
+            "fold_tensor_extract_op.mlir",
             "forop_canonicalization.mlir",
             "linalg_bufferize.mlir",
             "remove_dead_allocs.mlir",
diff --git a/iree/compiler/Conversion/Common/test/CMakeLists.txt b/iree/compiler/Conversion/Common/test/CMakeLists.txt
index 3ae07b0..25b8b05 100644
--- a/iree/compiler/Conversion/Common/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/Common/test/CMakeLists.txt
@@ -18,6 +18,7 @@
     "canonicalize_interface_load_store.mlir"
     "f32Tof16.mlir"
     "flatten_memref_subspan.mlir"
+    "fold_tensor_extract_op.mlir"
     "forop_canonicalization.mlir"
     "linalg_bufferize.mlir"
     "remove_dead_allocs.mlir"
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/fold_tensor_extract_op.mlir b/iree/compiler/Conversion/Common/test/fold_tensor_extract_op.mlir
similarity index 100%
rename from iree/compiler/Conversion/LinalgToLLVM/test/fold_tensor_extract_op.mlir
rename to iree/compiler/Conversion/Common/test/fold_tensor_extract_op.mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index b708cc8..567c85e 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -12,36 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-load("//build_tools/bazel:tblgen.bzl", "gentbl_cc_library")
-
 package(
     default_visibility = ["//visibility:public"],
     features = ["layering_check"],
     licenses = ["notice"],  # Apache 2.0
 )
 
-gentbl_cc_library(
-    name = "FoldTensorExtractOpIncGen",
-    tbl_outs = [
-        (
-            ["-gen-rewriters"],
-            "FoldTensorExtractOp.cpp.inc",
-        ),
-    ],
-    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "FoldTensorExtractOp.td",
-    td_srcs = [
-        "@llvm-project//mlir:OpBaseTdFiles",
-        "@llvm-project//mlir:MemRefOpsTdFiles",
-        "@llvm-project//mlir:TensorOpsTdFiles",
-    ],
-)
-
 cc_library(
     name = "LinalgToLLVM",
     srcs = [
         "ConvertToLLVM.cpp",
-        "FoldTensorExtractOpPass.cpp",
         "KernelDispatch.cpp",
         "LLVMCodeGenOptions.cpp",
         "LinalgTileAndVectorizePass.cpp",
@@ -60,7 +40,6 @@
     deps = [
         "//iree/compiler/Conversion/CodegenUtils",
         "//iree/compiler/Conversion/Common",
-        "//iree/compiler/Conversion/LinalgToLLVM:FoldTensorExtractOpIncGen",
         "//iree/compiler/Conversion/VectorToLLVM",
         "//iree/compiler/Dialect/Flow/IR",
         "//iree/compiler/Dialect/HAL/IR",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index c42d4fc..2fdfd52 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -10,15 +10,6 @@
 
 iree_add_all_subdirs()
 
-iree_tablegen_library(
-  NAME
-    FoldTensorExtractOpIncGen
-  TD_FILE
-    "FoldTensorExtractOp.td"
-  OUTS
-    -gen-rewriters FoldTensorExtractOp.cpp.inc
-)
-
 iree_cc_library(
   NAME
     LinalgToLLVM
@@ -28,7 +19,6 @@
     "Passes.h"
   SRCS
     "ConvertToLLVM.cpp"
-    "FoldTensorExtractOpPass.cpp"
     "KernelDispatch.cpp"
     "LLVMCodeGenOptions.cpp"
     "LinalgTileAndVectorizePass.cpp"
@@ -66,7 +56,6 @@
     MLIRVectorToSCF
     iree::compiler::Conversion::CodegenUtils
     iree::compiler::Conversion::Common
-    iree::compiler::Conversion::LinalgToLLVM::FoldTensorExtractOpIncGen
     iree::compiler::Conversion::VectorToLLVM
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::HAL::IR
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 147a650..3ce7581 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -50,11 +50,6 @@
 /// Pass to convert Linalg ops into vector operations.
 std::unique_ptr<FunctionPass> createLinalgVectorizePass();
 
-/// After running the upstream TensorConstantBufferize pass, remove tensor_loads
-/// introduced for use only in tensor_extract. These can be folded to use a load
-/// of the created memref object that holds the constant values.
-std::unique_ptr<OperationPass<>> createFoldTensorExtractOpPass();
-
 //===----------------------------------------------------------------------===//
 // Pass Pipelines for CPU Lowering
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
index 1f861d7..c0fc200 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/BUILD
@@ -30,7 +30,6 @@
             "hal_interface_bindings.mlir",
             "hal_interface_constants.mlir",
             "hal_interface_workgroup_info.mlir",
-            "fold_tensor_extract_op.mlir",
             "linalg_vectorize.mlir",
             "materialize_launch_configuration.mlir",
             "matmul_vectorization.mlir",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
index 811ea5a..7b0821f 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/CMakeLists.txt
@@ -14,7 +14,6 @@
   NAME
     lit
   SRCS
-    "fold_tensor_extract_op.mlir"
     "hal_interface_bindings.mlir"
     "hal_interface_constants.mlir"
     "hal_interface_workgroup_info.mlir"
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD b/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
index 58f9f9d..a101363 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/BUILD
@@ -57,6 +57,7 @@
         "@llvm-project//mlir:ROCDLDialect",
         "@llvm-project//mlir:SCFToStandard",
         "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:StandardOpsTransforms",
         "@llvm-project//mlir:StandardToSPIRV",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorOps",
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
index 687eb59..71a1a72 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/CMakeLists.txt
@@ -42,6 +42,7 @@
     MLIRROCDLIR
     MLIRSCFToStandard
     MLIRStandard
+    MLIRStandardOpsTransforms
     MLIRStandardToLLVM
     MLIRStandardToSPIRV
     MLIRTransforms
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
index ccea7dd..80192a0 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/Passes.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassOptions.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -59,6 +60,10 @@
   pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
   pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCSEPass());
 
+  // Handle tensor-type constants.
+  pm.addNestedPass<ModuleOp>(createTensorConstantBufferizePass());
+  pm.addNestedPass<ModuleOp>(createFoldTensorExtractOpPass());
+
   // SCF -> STD
   pm.nest<ModuleOp>().addNestedPass<FuncOp>(createLowerToCFGPass());
   pm.nest<ModuleOp>().addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir b/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
index 3a94b8e..50659e0 100644
--- a/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVMGPU/test/nvvm_pipeline_test.mlir
@@ -181,3 +181,41 @@
 //         CHECK:   lvm.fmul %{{.*}}, %{{.*}}  : f32
 //         CHECK:   llvm.fadd %{{.*}}, %{{.*}}  : f32
 //         CHECK:   llvm.store {{.*}} : !llvm.ptr<f32>
+
+// -----
+
+hal.executable @simpleMath_ex_dispatch_0 {
+  hal.interface @io {
+    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+    hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+  }
+  hal.executable.target @cuda, filter="cuda" {
+  hal.executable.entry_point @add_dispatch_0 attributes {interface = @io, ordinal = 0 : index}
+  module  {
+    func @add_dispatch_0() {
+      %c0 = constant 0 : index
+      %0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:16xf32>
+      %2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:16xf32>
+      %3 = linalg.init_tensor [16] : tensor<16xf32>
+      %4 = flow.dispatch.tensor.load %0, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:16xf32> -> tensor<16xf32>
+      %5 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32>
+      %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<16xf32>, tensor<16xf32>) outs(%3 : tensor<16xf32>) {
+      ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):  // no predecessors
+          %7 = addf %arg0, %arg1 : f32
+          linalg.yield %7 : f32
+        } -> tensor<16xf32>
+        flow.dispatch.tensor.store %6, %2, offsets=[], sizes=[], strides=[] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:16xf32>
+        return
+      }
+      hal.interface @io attributes {sym_visibility = "private"} {
+        hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+        hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+      }
+    }
+  }
+}
+
+// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
+//       CHECK:   hal.executable.target @cuda, filter="cuda" {
+//       CHECK:   llvm.mlir.global private constant @{{.*}}(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01, 1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01, 1.600000e+01]> : tensor<16xf32>)
+//       CHECK:   llvm.fadd