Add NVVM target to compiler (#4912)
Add an NVVM/CUDA target to the HAL and a lowering pipeline from Linalg to NVVM.
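
To try the new path locally, the lit tests added here exercise the main entry points. A rough sketch of the invocations (the flags and test files are the ones added in this change; requires a build configured with IREE_TARGET_BACKEND_CUDA):

  # Final conversion of a dispatch function from builtin/GPU/HAL ops to LLVM+NVVM.
  iree-opt -iree-codegen-convert-to-nvvm \
    iree/compiler/Conversion/LinalgToNVVM/test/convert_to_nvvm.mlir

  # Progressive lowering from HLO through Linalg to NVVM inside hal.executable.target.
  iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-hlo-to-nvvm-pipeline))" \
    iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir

  # End-to-end HAL pipeline with the CUDA backend; serializes PTX into the executable flatbuffer.
  iree-opt -split-input-file -iree-hal-transformation-pipeline -iree-hal-target-backends=cuda \
    iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir
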
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 243e3e6..274ab15 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -149,6 +149,7 @@
# List of all target backends to be built by default:
set(IREE_ALL_TARGET_BACKENDS
+ CUDA
DYLIB-LLVM-AOT
Metal-SPIRV
Vulkan-SPIRV
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index e85ec49..cccd722 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -51,6 +51,7 @@
"@llvm-project//mlir:MlirOptLib": ["MLIROptLib"],
"@llvm-project//mlir:VectorOps": ["MLIRVector"],
"@llvm-project//mlir:TensorDialect": ["MLIRTensor"],
+ "@llvm-project//mlir:NVVMDialect": ["MLIRNVVMIR"],
# Vulkan
"@iree_vulkan_headers//:vulkan_headers": ["Vulkan::Headers"],
# Cuda
diff --git a/build_tools/cmake/iree_copts.cmake b/build_tools/cmake/iree_copts.cmake
index fa8f56f..b673006 100644
--- a/build_tools/cmake/iree_copts.cmake
+++ b/build_tools/cmake/iree_copts.cmake
@@ -403,7 +403,8 @@
set(LLVM_ENABLE_IDE ON CACHE BOOL "" FORCE)
# TODO(ataei): Use optional build time targets selection for LLVMAOT.
-set(LLVM_TARGETS_TO_BUILD "WebAssembly;X86;ARM;AArch64;RISCV" CACHE STRING "" FORCE)
+set(LLVM_TARGETS_TO_BUILD "WebAssembly;X86;ARM;AArch64;RISCV;NVPTX"
+ CACHE STRING "" FORCE)
set(LLVM_ENABLE_PROJECTS "mlir" CACHE STRING "" FORCE)
set(LLVM_ENABLE_BINDINGS OFF CACHE BOOL "" FORCE)
diff --git a/iree/compiler/Conversion/BUILD b/iree/compiler/Conversion/BUILD
index bfa7317..555dc1a 100644
--- a/iree/compiler/Conversion/BUILD
+++ b/iree/compiler/Conversion/BUILD
@@ -26,6 +26,7 @@
deps = [
"//iree/compiler/Conversion/HLOToLinalg",
"//iree/compiler/Conversion/LinalgToLLVM",
+ "//iree/compiler/Conversion/LinalgToNVVM",
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Conversion/LinalgToVector",
],
diff --git a/iree/compiler/Conversion/CMakeLists.txt b/iree/compiler/Conversion/CMakeLists.txt
index 0a61e14..9f12d87 100644
--- a/iree/compiler/Conversion/CMakeLists.txt
+++ b/iree/compiler/Conversion/CMakeLists.txt
@@ -10,6 +10,7 @@
DEPS
iree::compiler::Conversion::HLOToLinalg
iree::compiler::Conversion::LinalgToLLVM
+ iree::compiler::Conversion::LinalgToNVVM
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Conversion::LinalgToVector
PUBLIC
diff --git a/iree/compiler/Conversion/Common/Passes.cpp b/iree/compiler/Conversion/Common/Passes.cpp
index 3856f7d..6b9d2a7 100644
--- a/iree/compiler/Conversion/Common/Passes.cpp
+++ b/iree/compiler/Conversion/Common/Passes.cpp
@@ -21,7 +21,7 @@
void addLinalgBufferizePasses(OpPassManager &passManager,
WorkgroupMemoryAllocationFn allocationFn) {
- passManager.addPass(createLinalgBufferizePass(allocationFn));
+ passManager.addNestedPass<FuncOp>(createLinalgBufferizePass(allocationFn));
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
passManager.addNestedPass<FuncOp>(createCSEPass());
passManager.addNestedPass<FuncOp>(createRemoveDeadMemAllocsPass());
diff --git a/iree/compiler/Conversion/LinalgToNVVM/BUILD b/iree/compiler/Conversion/LinalgToNVVM/BUILD
new file mode 100644
index 0000000..db0d69e
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/BUILD
@@ -0,0 +1,50 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "LinalgToNVVM",
+ srcs = [
+ "ConvertToNVVM.cpp",
+ "Passes.cpp",
+ ],
+ hdrs = [
+ "Passes.h",
+ ],
+ deps = [
+ "//iree/compiler/Conversion/CodegenUtils",
+ "//iree/compiler/Conversion/Common",
+ "//iree/compiler/Conversion/HLOToHLO",
+ "//iree/compiler/Conversion/HLOToLinalg",
+ "//iree/compiler/Dialect/HAL/IR",
+ "//iree/compiler/Dialect/IREE/IR",
+ "//iree/compiler/Dialect/Shape/Transforms",
+ "@llvm-project//mlir:GPUToNVVMTransforms",
+ "@llvm-project//mlir:GPUTransforms",
+ "@llvm-project//mlir:LLVMTransforms",
+ "@llvm-project//mlir:LinalgTransforms",
+ "@llvm-project//mlir:NVVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFToStandard",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Transforms",
+ "@mlir-hlo//:hlo",
+ "@mlir-hlo//:legalize_to_linalg",
+ ],
+)
diff --git a/iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt
new file mode 100644
index 0000000..e34768d
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Autogenerated from iree/compiler/Conversion/LinalgToNVVM/BUILD by
+# build_tools/bazel_to_cmake/bazel_to_cmake.py
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ LinalgToNVVM
+ HDRS
+ "Passes.h"
+ SRCS
+ "ConvertToNVVM.cpp"
+ "Passes.cpp"
+ DEPS
+ MLIRGPU
+ MLIRGPUToNVVMTransforms
+ MLIRLinalgTransforms
+ MLIRNVVMIR
+ MLIRPass
+ MLIRSCFToStandard
+ MLIRStandard
+ MLIRStandardToLLVM
+ MLIRTransforms
+ iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::Common
+ iree::compiler::Conversion::HLOToHLO
+ iree::compiler::Conversion::HLOToLinalg
+ iree::compiler::Dialect::HAL::IR
+ iree::compiler::Dialect::IREE::IR
+ iree::compiler::Dialect::Shape::Transforms
+ tensorflow::mlir_hlo
+ PUBLIC
+)
diff --git a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
new file mode 100644
index 0000000..69bb302
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
@@ -0,0 +1,187 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/LinalgToNVVM/Passes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+class ConvertFunc : public ConvertToLLVMPattern {
+ public:
+ explicit ConvertFunc(MLIRContext *context, LLVMTypeConverter &converter)
+ : ConvertToLLVMPattern(mlir::FuncOp::getOperationName(), context,
+ converter, 100) {}
+
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto funcOp = cast<FuncOp>(op);
+ FunctionType fnType = funcOp.getType();
+ (void)fnType;
+ if (!funcOp.isPublic()) return failure();
+
+ // illegal FuncOp must have 0 inputs.
+ assert(fnType.getNumInputs() == 0 && fnType.getNumResults() == 0);
+
+ TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
+ SmallVector<Type, 8> llvmInputTypes;
+ funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp input) {
+ auto memrefType = input.getType().cast<MemRefType>();
+ Type elType = memrefType.getElementType();
+ auto llvmType =
+ LLVM::LLVMPointerType::get(elType, memrefType.getMemorySpace());
+ llvmInputTypes.push_back(llvmType);
+ });
+ signatureConverter.addInputs(llvmInputTypes);
+
+ // Construct newFunc with all attributes except return type & symbol name.
+ SmallVector<NamedAttribute, 4> funcAttrs;
+ for (auto attr : funcOp.getAttrs()) {
+ if (attr.first == SymbolTable::getSymbolAttrName() ||
+ attr.first == mlir::impl::getTypeAttrName()) {
+ continue;
+ }
+ funcAttrs.push_back(attr);
+ }
+
+ auto llvmFuncType = LLVM::LLVMFunctionType::get(
+ LLVM::LLVMVoidType::get(rewriter.getContext()), llvmInputTypes);
+ auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+ funcOp.getLoc(), funcOp.getName(), llvmFuncType,
+ LLVM::Linkage::External, funcAttrs);
+
+ // Copy all of funcOp's operations into newFuncOp's body and perform region
+ // type conversion.
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
+ &signatureConverter)))
+ return failure();
+
+ rewriter.eraseOp(funcOp);
+ return success();
+ }
+};
+
+class ConvertIREEBindingOp : public ConvertToLLVMPattern {
+ public:
+ explicit ConvertIREEBindingOp(MLIRContext *context,
+ LLVMTypeConverter &converter)
+ : ConvertToLLVMPattern(
+ IREE::HAL::InterfaceBindingSubspanOp::getOperationName(), context,
+ converter) {}
+
+ LogicalResult matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Bail until nested under an LLVMFuncOp.
+ auto llvmFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ if (!llvmFuncOp) return failure();
+ assert(llvmFuncOp.getNumArguments() > 0);
+
+ Location loc = op->getLoc();
+ auto ireeBindingOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
+ IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor(operands);
+ MemRefType memrefType =
+ ireeBindingOp.getResult().getType().dyn_cast<MemRefType>();
+
+ // Fetch the interface binding op and extract the buffer index from void**.
+ auto symbol = SymbolTable::lookupNearestSymbolFrom(
+ op, op->getAttrOfType<SymbolRefAttr>("binding"));
+ auto interfaceBindingOp = cast<IREE::HAL::InterfaceBindingOp>(symbol);
+ Value llvmBufferBasePtr =
+ llvmFuncOp.getArgument(interfaceBindingOp.binding());
+ if (memrefType.hasStaticShape()) {
+ auto desc = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, *getTypeConverter(), memrefType, llvmBufferBasePtr);
+ rewriter.replaceOp(op, {desc});
+ } else {
+ // TODO: pull these parameters from HAL constants.
+ assert(0 && "TODO: implement dynamic shape");
+ }
+
+ return success();
+ }
+};
+
+/// A pass that replaces all occurrences of GPU device operations with their
+/// corresponding NVVM equivalent.
+///
+/// This pass only handles device code and is not meant to be run on GPU host
+/// code.
+struct ConvertToNVVMPass
+ : public PassWrapper<ConvertToNVVMPass, OperationPass<ModuleOp>> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
+ }
+ void runOnOperation() override {
+ ModuleOp m = getOperation();
+
+ // Customize the bitwidth used for device-side index computations.
+ LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
+ /*emitCWrappers =*/false,
+ /*indexBitwidth =*/64,
+ /*useAlignedAlloc =*/false};
+ LLVMTypeConverter converter(m.getContext(), options);
+ // Apply in-dialect lowering first. In-dialect lowering will replace ops
+ // which need to be lowered further, which is not supported by a single
+ // conversion pass.
+ {
+ OwningRewritePatternList patterns;
+ populateGpuRewritePatterns(m.getContext(), patterns);
+ (void)applyPatternsAndFoldGreedily(m, std::move(patterns));
+ }
+ {
+ OwningRewritePatternList llvmPatterns;
+ llvmPatterns.insert<ConvertFunc, ConvertIREEBindingOp>(m.getContext(),
+ converter);
+ populateStdToLLVMConversionPatterns(converter, llvmPatterns);
+ populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
+ LLVMConversionTarget target(getContext());
+ populateStdToLLVMFuncOpConversionPattern(converter, llvmPatterns);
+ configureGpuToNVVMConversionLegality(target);
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
+ if (isEntryPoint(funcOp)) return false;
+ return true;
+ });
+ if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
+ signalPassFailure();
+ }
+ }
+};
+
+} // anonymous namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToNVVMPass() {
+ return std::make_unique<ConvertToNVVMPass>();
+}
+
+static PassRegistration<ConvertToNVVMPass> pass(
+ "iree-codegen-convert-to-nvvm",
+ "Perform final conversion from builtin/GPU/HAL/standard dialect to LLVM "
+ "and NVVM dialects");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToNVVM/Passes.cpp b/iree/compiler/Conversion/LinalgToNVVM/Passes.cpp
new file mode 100644
index 0000000..5807951
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/Passes.cpp
@@ -0,0 +1,93 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Conversion/LinalgToNVVM/Passes.h"
+
+#include "iree/compiler/Conversion/Common/Passes.h"
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
+#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
+#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static void addLinalgToNVVMPasses(OpPassManager &pm) {
+ //===--------------------------------------------------------------------===//
+ // Initial clean up.
+ //===--------------------------------------------------------------------===//
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+
+ // TODO: This currently maps to a single thread. We should share the
+ // tile-and-distribute logic with other GPU backends.
+ // Linalg -> SCF
+ pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
+ pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+ pm.addNestedPass<FuncOp>(createCSEPass());
+
+ // SCF -> STD
+ pm.addNestedPass<FuncOp>(createLowerToCFGPass());
+ pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+ pm.addNestedPass<FuncOp>(createCSEPass());
+
+ // Strip out the debug info for the kernel, as the CUDA driver doesn't
+ // digest PTX debug info well.
+ pm.addPass(createStripDebugInfoPass());
+ // Convert to NVVM.
+ pm.addPass(createConvertToNVVMPass());
+}
+
+void buildNVVMTransformPassPipeline(OpPassManager &pm) {
+ OpPassManager &nestedModulePM = pm.nest<ModuleOp>();
+ nestedModulePM.addPass(createInlinerPass());
+
+ WorkgroupMemoryAllocationFn allocationFn =
+ [](OpBuilder &builder, Location loc, ArrayRef<int64_t> staticShape,
+ Type elementType, ArrayRef<Value> dynamicSizes) {
+ MemRefType allocType = MemRefType::get(staticShape, elementType, {}, 3);
+ return builder.create<AllocOp>(loc, allocType, dynamicSizes);
+ };
+ addLinalgBufferizePasses(nestedModulePM, allocationFn);
+
+ //===--------------------------------------------------------------------===//
+ // Convert Linalg ops to LLVM+NVVM ops.
+ //
+ // Post-conditions:
+ // - All Linalg/Loops/GPU/Affine/Standard ops are converted away.
+ // - The module contains the final llvm.module ready to be serialized.
+ //===--------------------------------------------------------------------===//
+ addLinalgToNVVMPasses(nestedModulePM);
+}
+
+static PassPipelineRegistration<> linalgToNVVMPipeline(
+ "iree-codegen-linalg-to-nvvm-pipeline",
+ "Runs the progressive lowering pipeline from Linalg to NVVM",
+ [](OpPassManager &passManager) { addLinalgToNVVMPasses(passManager); });
+
+static PassPipelineRegistration<> hloToLinalgNVVMPipeline(
+ "iree-codegen-hlo-to-nvvm-pipeline",
+ "Runs the progressive lowering pipeline from XLA HLO to Linalg to "
+ "NVVM",
+ [](OpPassManager &passManager) {
+ buildNVVMTransformPassPipeline(passManager);
+ });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToNVVM/Passes.h b/iree/compiler/Conversion/LinalgToNVVM/Passes.h
new file mode 100644
index 0000000..cb731dd
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/Passes.h
@@ -0,0 +1,34 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_CONVERSION_LINALGTONVVM_PASSES_H_
+#define IREE_COMPILER_CONVERSION_LINALGTONVVM_PASSES_H_
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Performs the final conversion to the NVVM+LLVM dialects.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToNVVMPass();
+
+/// Populates passes needed to lower an XLA HLO op to the NVVM dialect via
+/// the structured ops path. The pass manager `pm` here should operate on
+/// the module within the IREE::HAL::ExecutableOp.
+void buildNVVMTransformPassPipeline(OpPassManager &pm);
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_CONVERSION_LINALGTONVVM_PASSES_H_
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/BUILD b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
new file mode 100644
index 0000000..675ed91
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/BUILD
@@ -0,0 +1,32 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for common transforms.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
new file mode 100644
index 0000000..ecaddd9
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Autogenerated from iree/compiler/Conversion/LinalgToNVVM/test/BUILD by
+# build_tools/bazel_to_cmake/bazel_to_cmake.py
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/convert_to_nvvm.mlir b/iree/compiler/Conversion/LinalgToNVVM/test/convert_to_nvvm.mlir
new file mode 100644
index 0000000..fb3e720
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/convert_to_nvvm.mlir
@@ -0,0 +1,29 @@
+// RUN: iree-opt -iree-codegen-convert-to-nvvm %s | IreeFileCheck %s
+
+// Test that standard and GPU ops are converted to LLVM and NVVM.
+func @abs_ex_dispatch_0() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : memref<16xf32>
+ %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : memref<16xf32>
+ %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : memref<16xf32>
+ %3 = "gpu.block_id"() {dimension = "x"} : () -> index
+ %4 = "gpu.block_dim"() {dimension = "x"} : () -> index
+ %5 = "gpu.thread_id"() {dimension = "x"} : () -> index
+ %6 = muli %3, %4 : index
+ %7 = addi %6, %5 : index
+ %9 = load %1[%7] : memref<16xf32>
+ %10 = load %2[%7] : memref<16xf32>
+ %11 = addf %9, %10 : f32
+ store %11, %0[%7] : memref<16xf32>
+ return
+}
+hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+}
+
+// CHECK-LABEL: llvm.func @abs_ex_dispatch_0
+// CHECK-SAME: (%{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>)
+// CHECK: nvvm.read.ptx.sreg.tid.x
+// CHECK: llvm.fadd
diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir
new file mode 100644
index 0000000..170dee6
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir
@@ -0,0 +1,41 @@
+// RUN: iree-opt -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-hlo-to-nvvm-pipeline))" %s | IreeFileCheck %s
+
+// Verify that a simple elementwise op gets lowered successfully all the way
+// to the NVVM/LLVM dialects.
+
+hal.executable @simpleMath_ex_dispatch_0 {
+ hal.interface @legacy_io {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @cuda, filter="cuda" {
+ hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.input<16xf32>, !flow.dispatch.input<16xf32>, !flow.dispatch.output<16xf32>) -> ()}
+ module {
+ func @add_dispatch_0() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @legacy_io::@arg0[%c0] : !flow.dispatch.input<16xf32>
+ %1 = hal.interface.binding.subspan @legacy_io::@arg1[%c0] : !flow.dispatch.input<16xf32>
+ %2 = hal.interface.binding.subspan @legacy_io::@ret0[%c0] : !flow.dispatch.output<16xf32>
+ %3 = linalg.init_tensor [16] : tensor<16xf32>
+ %4 = flow.dispatch.input.load %0 : !flow.dispatch.input<16xf32> -> tensor<16xf32>
+ %5 = flow.dispatch.input.load %1 : !flow.dispatch.input<16xf32> -> tensor<16xf32>
+ %6 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%4, %5 : tensor<16xf32>, tensor<16xf32>) outs(%3 : tensor<16xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
+ %7 = addf %arg0, %arg1 : f32
+ linalg.yield %7 : f32
+ } -> tensor<16xf32>
+ flow.dispatch.output.store %6, %2 : tensor<16xf32> -> !flow.dispatch.output<16xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
+// CHECK: hal.executable.target @cuda, filter="cuda" {
+// CHECK: llvm.fadd
diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td
index 4dce36e..7d29302 100644
--- a/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -237,6 +237,7 @@
def HAL_EF_Metal : I32EnumAttrCase<"Metal", 1297370181>;
def HAL_EF_LLVM : I32EnumAttrCase<"LLVM", 1280071245>;
def HAL_EF_DyLib : I32EnumAttrCase<"DyLib", 1145850178>;
+def HAL_EF_CUDA : I32EnumAttrCase<"CUDA", 1129661505>;
def HAL_ExecutableFormatAttr :
I32EnumAttr<"ExecutableFormat", "IREE HAL Executable format", [
HAL_EF_Unspecified,
@@ -245,7 +246,8 @@
HAL_EF_VMLA,
HAL_EF_SpirV,
HAL_EF_Metal,
- HAL_EF_DyLib
+ HAL_EF_DyLib,
+ HAL_EF_CUDA
]> {
let returnType = "IREE::HAL::ExecutableFormat";
let convertFromStorage = "static_cast<IREE::HAL::ExecutableFormat>($_self.getInt())";
diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/BUILD b/iree/compiler/Dialect/HAL/Target/CUDA/BUILD
new file mode 100644
index 0000000..718d7f1
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/CUDA/BUILD
@@ -0,0 +1,56 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_cmake_extra_content(
+ content = """
+if(NOT "${IREE_TARGET_BACKEND_CUDA}")
+ return()
+endif()
+""",
+)
+
+cc_library(
+ name = "CUDA",
+ srcs = [
+ "CUDATarget.cpp",
+ ],
+ hdrs = [
+ "CUDATarget.h",
+ ],
+ deps = [
+ "//iree/base:flatcc",
+ "//iree/compiler/Conversion/LinalgToNVVM",
+ "//iree/compiler/Dialect/HAL/Target",
+ "//iree/compiler/Utils",
+ "//iree/schemas:cuda_executable_def_c_fbs",
+ "@llvm-project//llvm:Core",
+ "@llvm-project//llvm:NVPTXCodeGen",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//llvm:Target",
+ "@llvm-project//mlir:LLVMDialect",
+ "@llvm-project//mlir:NVVMDialect",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TargetLLVMIR",
+ "@llvm-project//mlir:TargetLLVMIRModuleTranslation",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/CUDA/CMakeLists.txt
new file mode 100644
index 0000000..7c7633f
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/CUDA/CMakeLists.txt
@@ -0,0 +1,33 @@
+# Autogenerated from iree/compiler/Dialect/HAL/Target/CUDA/BUILD by
+# build_tools/bazel_to_cmake/bazel_to_cmake.py
+if(NOT "${IREE_TARGET_BACKEND_CUDA}")
+ return()
+endif()
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ CUDA
+ HDRS
+ "CUDATarget.h"
+ SRCS
+ "CUDATarget.cpp"
+ DEPS
+ LLVMCore
+ LLVMNVPTXCodeGen
+ LLVMSupport
+ LLVMTarget
+ MLIRLLVMIR
+ MLIRNVVMIR
+ MLIRPass
+ MLIRSupport
+ MLIRTargetLLVMIR
+ MLIRTargetLLVMIRModuleTranslation
+ iree::base::flatcc
+ iree::compiler::Conversion::LinalgToNVVM
+ iree::compiler::Dialect::HAL::Target
+ iree::compiler::Utils
+ iree::schemas::cuda_executable_def_c_fbs
+ PUBLIC
+)
diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
new file mode 100644
index 0000000..4e92def
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
@@ -0,0 +1,201 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.h"
+
+#include "iree/compiler/Conversion/LinalgToNVVM/Passes.h"
+#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Utils/FlatbufferUtils.h"
+#include "iree/schemas/cuda_executable_def_builder.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Target/LLVMIR.h"
+#include "mlir/Target/LLVMIR/Export.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+CUDATargetOptions getCUDATargetOptionsFromFlags() {
+ CUDATargetOptions targetOptions;
+ // TODO: flags
+ return targetOptions;
+}
+
+static std::string translateModuleToISA(llvm::Module &module,
+ llvm::TargetMachine &targetMachine) {
+ std::string targetISA;
+ {
+ llvm::raw_string_ostream stream(targetISA);
+ llvm::buffer_ostream pstream(stream);
+ llvm::legacy::PassManager codegenPasses;
+ targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
+ llvm::CGFT_AssemblyFile);
+ codegenPasses.run(module);
+ }
+ return targetISA;
+}
+
+class CUDATargetBackend final : public TargetBackend {
+ public:
+ CUDATargetBackend(CUDATargetOptions options) : options_(std::move(options)) {}
+
+ std::string name() const override { return "cuda"; }
+ std::string filter_pattern() const override { return "cuda"; }
+
+ void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ buildNVVMTransformPassPipeline(passManager);
+ }
+
+ LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
+ OpBuilder &executableBuilder) override {
+ // Perform the translation in a separate context to avoid any
+ // multi-threading issues.
+ llvm::LLVMContext context;
+ mlir::registerLLVMDialectTranslation(*targetOp.getContext());
+
+ // We name our files after the executable name so that they are easy to
+ // track during compilation (logs/artifacts/etc.), as outputs (final
+ // intermediate code/binary files), and at runtime (loaded
+ // libraries/symbols/etc.).
+ auto libraryName =
+ targetOp->getParentOfType<IREE::HAL::ExecutableOp>().getName().str();
+
+ ModuleOp innerModuleOp = targetOp.getInnerModule();
+
+ // Remove all the functions that are not part of the CUDA kernel.
+ // TODO: Find a better solution to handle this.
+ auto illegalFuncOps = llvm::to_vector<4>(innerModuleOp.getOps<FuncOp>());
+ for (auto funcOp : illegalFuncOps) {
+ funcOp.erase();
+ }
+ auto halInterfaceOps =
+ llvm::to_vector<1>(innerModuleOp.getOps<IREE::HAL::InterfaceOp>());
+ for (auto halOp : halInterfaceOps) {
+ halOp.erase();
+ }
+
+ auto llvmModule =
+ mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
+ if (!llvmModule) {
+ return targetOp.emitError() << "failed to translate the MLIR LLVM "
+ "dialect to the native llvm::Module";
+ }
+ for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
+ auto *llvmFunc = llvmModule->getFunction(func.getName());
+
+ llvm::Metadata *llvmMetadata[] = {
+ llvm::ValueAsMetadata::get(llvmFunc),
+ llvm::MDString::get(llvmModule->getContext(), "kernel"),
+ llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
+ llvm::Type::getInt32Ty(llvmModule->getContext()), 1))};
+ llvm::MDNode *llvmMetadataNode =
+ llvm::MDNode::get(llvmModule->getContext(), llvmMetadata);
+ llvmModule->getOrInsertNamedMetadata("nvvm.annotations")
+ ->addOperand(llvmMetadataNode);
+ }
+
+ std::unique_ptr<llvm::TargetMachine> targetMachine;
+ {
+ llvm::Triple triple("nvptx64-nvidia-cuda");
+ std::string targetChip = "sm_35";
+ std::string features = "+ptx60";
+ std::string error;
+ const llvm::Target *target =
+ llvm::TargetRegistry::lookupTarget("", triple, error);
+ if (target == nullptr) {
+ return targetOp.emitError() << "cannot initialize target triple";
+ }
+ targetMachine.reset(target->createTargetMachine(triple.str(), targetChip,
+ features, {}, {}));
+ if (targetMachine == nullptr) {
+ return targetOp.emitError() << "cannot initialize target machine";
+ }
+ }
+
+ llvmModule->setDataLayout(targetMachine->createDataLayout());
+
+ std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine);
+ // Serialize the CUDA kernel into the binary that we will embed in the
+ // final FlatBuffer.
+ FlatbufferBuilder builder;
+ auto ptxCodeRef = flatbuffers_uint8_vec_create(
+ builder, reinterpret_cast<const uint8_t *>(targetISA.c_str()),
+ targetISA.size());
+
+ auto entryPointNames = llvm::to_vector<8>(
+ llvm::map_range(targetOp.getBlock().getOps<ExecutableEntryPointOp>(),
+ [&](auto op) { return op.getName(); }));
+ auto entryPointsRef = builder.createStringVec(entryPointNames);
+
+ iree_CUDABlockSizeDef_vec_start(builder);
+ for (auto shader : entryPointNames) {
+ // Hard-coded workgroup size.
+ iree_CUDABlockSizeDef_vec_push_create(builder, 1, 1, 1);
+ }
+ auto blockSizesRef = iree_CUDABlockSizeDef_vec_end(builder);
+
+ iree_CUDAExecutableDef_start_as_root(builder);
+ iree_CUDAExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef);
+ iree_CUDAExecutableDef_ptx_image_add(builder, ptxCodeRef);
+ iree_CUDAExecutableDef_end_as_root(builder);
+
+ // Add the binary data to the target executable.
+ executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
+ targetOp.getLoc(), targetOp.sym_name(),
+ static_cast<uint32_t>(IREE::HAL::ExecutableFormat::CUDA),
+ builder.getBufferAttr(executableBuilder.getContext()));
+
+ return success();
+ }
+
+ std::array<Value, 3> calculateDispatchWorkgroupCount(
+ Location loc, IREE::HAL::ExecutableOp executableOp,
+ IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
+ OpBuilder &builder) override {
+ // For now we are not tiling and just dispatch everything as (1, 1, 1).
+ auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
+ return {constantOne, constantOne, constantOne};
+ }
+
+ private:
+ CUDATargetOptions options_;
+};
+
+void registerCUDATargetBackends(
+ std::function<CUDATargetOptions()> queryOptions) {
+ getCUDATargetOptionsFromFlags();
+ static TargetBackendRegistration registration("cuda", [=]() {
+ LLVMInitializeNVPTXTarget();
+ LLVMInitializeNVPTXTargetMC();
+ LLVMInitializeNVPTXTargetInfo();
+ LLVMInitializeNVPTXAsmPrinter();
+ return std::make_unique<CUDATargetBackend>(queryOptions());
+ });
+}
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.h b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.h
new file mode 100644
index 0000000..d635bf2
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.h
@@ -0,0 +1,41 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_CUDA_CUDATARGET_H_
+#define IREE_COMPILER_DIALECT_HAL_TARGET_CUDA_CUDATARGET_H_
+
+#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Options controlling the CUDA translation.
+struct CUDATargetOptions {};
+
+// Returns a CUDATargetOptions struct initialized with the
+// --iree-hal-cuda-* flags.
+CUDATargetOptions getCUDATargetOptionsFromFlags();
+
+// Registers the CUDA backends.
+void registerCUDATargetBackends(
+ std::function<CUDATargetOptions()> queryOptions);
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_HAL_TARGET_CUDA_CUDATARGET_H_
diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD b/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD
new file mode 100644
index 0000000..326145d
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD
@@ -0,0 +1,30 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt
new file mode 100644
index 0000000..b45bfa1
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Autogenerated from iree/compiler/Dialect/HAL/Target/CUDA/test/BUILD by
+# build_tools/bazel_to_cmake/bazel_to_cmake.py
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_MLIR LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir
new file mode 100644
index 0000000..ceb0560
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Target/CUDA/test/smoketest.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt -split-input-file -iree-hal-transformation-pipeline -iree-hal-target-backends=cuda %s | IreeFileCheck %s
+
+
+#map = affine_map<(d0) -> (d0)>
+module {
+ flow.executable @add_dispatch_0 attributes {sym_visibility = "private"} {
+ flow.dispatch.entry @add_dispatch_0 attributes {signature = (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>, workgroup_rank = 3 : index}
+ module {
+ func @add_dispatch_0(%arg0: !flow.dispatch.input<16xf32>, %arg1: !flow.dispatch.input<16xf32>, %arg2: !flow.dispatch.output<16xf32>) {
+ %0 = linalg.init_tensor [16] : tensor<16xf32>
+ %1 = flow.dispatch.input.load %arg0 : !flow.dispatch.input<16xf32> -> tensor<16xf32>
+ %2 = flow.dispatch.input.load %arg1 : !flow.dispatch.input<16xf32> -> tensor<16xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %4 = addf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
+ } -> tensor<16xf32>
+ flow.dispatch.output.store %3, %arg2 : tensor<16xf32> -> !flow.dispatch.output<16xf32>
+ return
+ }
+ }
+ }
+ func @add(%arg0: tensor<16xf32>, %arg1: tensor<16xf32>) -> tensor<16xf32> attributes {iree.module.export, iree.reflection = {f = "I13!B4!d16B4!d16R7!B4!d16", fv = "1"}} {
+ %c1 = constant 1 : index
+ %c16 = constant 16 : index
+ %0 = flow.ex.stream.fragment(%arg2 = %c16 : index, %arg3 = %c1 : index, %arg4 = %arg0 : tensor<16xf32>, %arg5 = %arg1 : tensor<16xf32>) -> tensor<16xf32> {
+ %1 = flow.dispatch @add_dispatch_0::@add_dispatch_0[%arg2, %arg3, %arg3] (%arg4, %arg5) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
+ flow.return %1 : tensor<16xf32>
+ }
+ return %0 : tensor<16xf32>
+ }
+}
+
+// CHECK: hal.executable.binary @cuda attributes {
+// CHECK-SAME: data = dense
+// CHECK-SAME: format = 1129661505 : i32}
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 46b45fa..66f8188 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -180,8 +180,10 @@
"IREE_HAVE_VMLA_TARGET",
"IREE_HAVE_VULKANSPIRV_TARGET",
"IREE_HAVE_METALSPIRV_TARGET",
+ "IREE_HAVE_CUDA_TARGET",
],
deps = [
+ "//iree/compiler/Dialect/HAL/Target/CUDA",
"//iree/compiler/Dialect/HAL/Target/LLVM",
"//iree/compiler/Dialect/HAL/Target/MetalSPIRV",
"//iree/compiler/Dialect/HAL/Target/VMLA",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index fc141a1..73ea65f 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -37,6 +37,10 @@
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VulkanSPIRV)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VULKANSPIRV_TARGET")
endif()
+if("${IREE_TARGET_BACKEND_CUDA}")
+ list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::CUDA)
+ list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_CUDA_TARGET")
+endif()
if(IREE_ENABLE_EMITC)
set(IREE_OPT_CONDITIONAL_DEPS
diff --git a/iree/tools/init_targets.cc b/iree/tools/init_targets.cc
index 6b17648..7231886 100644
--- a/iree/tools/init_targets.cc
+++ b/iree/tools/init_targets.cc
@@ -26,6 +26,9 @@
#ifdef IREE_HAVE_VULKANSPIRV_TARGET
#include "iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.h"
#endif
+#ifdef IREE_HAVE_CUDA_TARGET
+#include "iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.h"
+#endif
namespace mlir {
namespace iree_compiler {
@@ -53,6 +56,10 @@
IREE::HAL::registerVulkanSPIRVTargetBackends(
[]() { return IREE::HAL::getVulkanSPIRVTargetOptionsFromFlags(); });
#endif
+#ifdef IREE_HAVE_CUDA_TARGET
+ IREE::HAL::registerCUDATargetBackends(
+ []() { return IREE::HAL::getCUDATargetOptionsFromFlags(); });
+#endif
return true;
}();
(void)init_once;