Add CMake build for ModelBuilder project
Update bazel file to allow bazel_to_cmake to work. Add few workarounds
for dynamic library not working with the script.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d9d802d..73e6c68 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -269,6 +269,11 @@
set(CMAKE_BUILD_TYPE "Debug")
endif()
+ # experimental model builder uses vulkan runner.
+ if(${IREE_BUILD_EXPERIMENTAL})
+ set(MLIR_VULKAN_RUNNER_ENABLED ON)
+ endif()
+
add_subdirectory("${llvm_monorepo_path}/llvm" "third_party/llvm-project/llvm" EXCLUDE_FROM_ALL)
# Reset CMAKE_BUILD_TYPE to its previous setting
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 00ee03c..225bc30 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -49,6 +49,9 @@
"@llvm-project//mlir:StandardToSPIRVConversions": [
"MLIRStandardToSPIRVTransforms"
],
+ "@llvm-project//mlir:GPUToVulkanTransforms": [
+ "MLIRGPUtoVulkanTransforms"
+ ],
"@llvm-project//mlir:mlir_c_runner_utils": ["MLIRExecutionEngine"],
"@llvm-project//mlir:mlir-translate": ["mlir-translate"],
"@llvm-project//mlir:MlirTableGenMain": ["MLIRTableGen"],
diff --git a/experimental/CMakeLists.txt b/experimental/CMakeLists.txt
index d7282f2..92557d4 100644
--- a/experimental/CMakeLists.txt
+++ b/experimental/CMakeLists.txt
@@ -15,3 +15,5 @@
# NOTE: not all projects require a cmake build. If you are adding a directory
# that is only expected to build with bazel (such as something depending on
# TensorFlow) you can ignore cmake.
+#
+iree_add_all_subdirs()
diff --git a/experimental/ModelBuilder/BUILD b/experimental/ModelBuilder/BUILD
index f70c44a..d74dfe3 100644
--- a/experimental/ModelBuilder/BUILD
+++ b/experimental/ModelBuilder/BUILD
@@ -23,6 +23,7 @@
hdrs = ["ModelBuilder.h"],
deps = [
"@llvm-project//mlir:Affine",
+ "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
diff --git a/experimental/ModelBuilder/CMakeLists.txt b/experimental/ModelBuilder/CMakeLists.txt
new file mode 100644
index 0000000..e6a21c5
--- /dev/null
+++ b/experimental/ModelBuilder/CMakeLists.txt
@@ -0,0 +1,88 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ ModelBuilder
+ HDRS
+ "ModelBuilder.h"
+ SRCS
+ "ModelBuilder.cpp"
+ DEPS
+ MLIRAffineOps
+ MLIREDSC
+ MLIRGPU
+ MLIRIR
+ MLIRLinalgOps
+ MLIRLinalgTransforms
+ MLIRSCF
+ MLIRSCFTransforms
+ MLIRSPIRV
+ MLIRStandardOps
+ MLIRVector
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ VulkanLaunchWrapper
+ HDRS
+ "VulkanWrapperPass.h"
+ SRCS
+ "VulkanWrapperPass.cpp"
+ DEPS
+ LLVMSupport
+ MLIRStandardToLLVM
+ MLIRIR
+ MLIRLLVMIR
+ MLIRPass
+ MLIRSPIRV
+ MLIRSPIRVSerialization
+ MLIRStandardOps
+ MLIRSupport
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ ModelRunner
+ HDRS
+ "MemRefUtils.h"
+ "ModelRunner.h"
+ SRCS
+ "ModelRunner.cpp"
+ DEPS
+ ::ModelBuilder
+ LLVMSupport
+ MLIRExecutionEngine
+ MLIRGPU
+ MLIRGPUToSPIRVTransforms
+ MLIRGPUtoVulkanTransforms
+ MLIRSPIRVTransforms
+ MLIRIR
+ MLIRLinalgToLLVM
+ MLIRLinalgTransforms
+ MLIRPass
+ MLIRSPIRV
+ MLIRStandardToLLVM
+ MLIRStandardToSPIRVTransforms
+ MLIRTargetLLVMIR
+ MLIRTransformUtils
+ MLIRVector
+ MLIRVectorToLLVM
+ MLIRVectorToSCF
+ PUBLIC
+)
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp
index 315d79b..a4a9062 100644
--- a/experimental/ModelBuilder/ModelBuilder.cpp
+++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/InitAllDialects.h"
using namespace mlir;
using namespace mlir::edsc;
@@ -29,6 +30,18 @@
thread_local MLIRContext mlir::ModelBuilder::ctx;
+void ModelBuilder::registerAllDialects() {
+ registerDialect<AffineDialect>();
+ registerDialect<gpu::GPUDialect>();
+ registerDialect<LLVM::LLVMDialect>();
+ registerDialect<linalg::LinalgDialect>();
+ registerDialect<scf::SCFDialect>();
+ registerDialect<omp::OpenMPDialect>();
+ registerDialect<spirv::SPIRVDialect>();
+ registerDialect<StandardOpsDialect>();
+ registerDialect<vector::VectorDialect>();
+}
+
mlir::ModelBuilder::ModelBuilder()
: OpBuilder(&ctx),
module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx))),
diff --git a/experimental/ModelBuilder/ModelBuilder.h b/experimental/ModelBuilder/ModelBuilder.h
index 8014903..35ad4bd 100644
--- a/experimental/ModelBuilder/ModelBuilder.h
+++ b/experimental/ModelBuilder/ModelBuilder.h
@@ -134,6 +134,9 @@
// SymbolTable as well as uniqued MLIR types.
ModelBuilder();
+ // Register all the dialects used by ModelBuilder.
+ static void registerAllDialects();
+
// Return a reference to the underlying module.
OwningModuleRef &getModuleRef() { return module; }
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index ad27616..6a84782 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -31,16 +31,16 @@
data = [
# runtime libraries
"@llvm-project//mlir:tools/libvulkan-runtime-wrappers.so",
- ":runtime-support.so",
+ "runtime-support.so",
# Tests.
":test-dot-prod",
- ":test-mnist-jit",
- ":test-simple-jit",
- ":test-simple-jit-vulkan",
- ":test-simple-mlir",
- ":test-vector-transfers-jit",
- ":test-matmul-vulkan",
- ":test-vec-to-gpu",
+ "test-mnist-jit",
+ "test-simple-jit",
+ "test-simple-jit-vulkan",
+ "test-simple-mlir",
+ "test-vector-transfers-jit",
+ "test-matmul-vulkan",
+ "test-vec-to-gpu",
# FileChecker.
"//iree/tools:IreeFileCheck",
],
@@ -53,7 +53,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFTransforms",
@@ -68,7 +67,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFTransforms",
@@ -81,7 +79,6 @@
deps = [
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
- "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFTransforms",
@@ -97,7 +94,6 @@
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFTransforms",
@@ -112,7 +108,6 @@
"//experimental/ModelBuilder:ModelRunner",
"//iree/base:initializer",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SPIRVDialect",
@@ -131,7 +126,6 @@
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:ExecutionEngine",
"@llvm-project//mlir:ExecutionEngineUtils",
"@llvm-project//mlir:GPUToSPIRVTransforms",
@@ -169,7 +163,6 @@
"//iree/compiler/Conversion/CodegenUtils",
"//iree/compiler/Conversion/LinalgToSPIRV",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:GPUToVulkanTransforms",
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:IR",
@@ -195,7 +188,6 @@
deps = [
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
- "@llvm-project//mlir:AllPassesAndDialects",
],
)
@@ -208,8 +200,7 @@
deps = [
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
- "@com_google_benchmark//:benchmark:benchmark_main",
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@com_google_benchmark//:benchmark",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:IR",
],
@@ -222,8 +213,7 @@
deps = [
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
- "@com_google_benchmark//:benchmark:benchmark_main",
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@com_google_benchmark//:benchmark",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:IR",
],
@@ -237,8 +227,7 @@
deps = [
"//experimental/ModelBuilder",
"//experimental/ModelBuilder:ModelRunner",
- "@com_google_benchmark//:benchmark:benchmark_main",
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@com_google_benchmark//:benchmark",
"@llvm-project//mlir:EDSC",
"@llvm-project//mlir:IR",
],
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp
index adffcd6..d3e743b 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp
@@ -138,6 +138,13 @@
}
}
+int main(int argc, char **argv) {
+ mlir::ModelBuilder::registerAllDialects();
+ ::benchmark::Initialize(&argc, argv);
+ if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1;
+ ::benchmark::RunSpecifiedBenchmarks();
+}
+
//
// Benchmark drivers (build).
//
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp
index 7905f99..37bf443 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp
@@ -125,6 +125,13 @@
}
}
+int main(int argc, char **argv) {
+ mlir::ModelBuilder::registerAllDialects();
+ ::benchmark::Initialize(&argc, argv);
+ if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1;
+ ::benchmark::RunSpecifiedBenchmarks();
+}
+
//
// Benchmark drivers (build and run).
//
diff --git a/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp b/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp
index 30d9b80..b71542e 100644
--- a/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp
+++ b/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp
@@ -129,6 +129,13 @@
}
}
+int main(int argc, char **argv) {
+ mlir::ModelBuilder::registerAllDialects();
+ ::benchmark::Initialize(&argc, argv);
+ if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1;
+ ::benchmark::RunSpecifiedBenchmarks();
+}
+
//
// Benchmark drivers (build and run).
//
diff --git a/experimental/ModelBuilder/test/CMakeLists.txt b/experimental/ModelBuilder/test/CMakeLists.txt
new file mode 100644
index 0000000..94dbcb0
--- /dev/null
+++ b/experimental/ModelBuilder/test/CMakeLists.txt
@@ -0,0 +1,275 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_add_all_subdirs()
+
+file(GLOB _GLOB_X_CPP LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS *.cpp)
+file(GLOB _GLOB_BENCHX LIST_DIRECTORIES false RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} CONFIGURE_DEPENDS Bench*)
+list(REMOVE_ITEM _GLOB_X_CPP ${_GLOB_BENCHX})
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_CPP}"
+ DATA
+ test-dot-prod
+ test-matmul-vulkan
+ test-mnist-jit
+ test-simple-jit
+ test-simple-jit-vulkan
+ test-simple-mlir
+ test-vec-to-gpu
+ test-vector-transfers-jit
+ iree::tools::IreeFileCheck
+)
+
+iree_cc_binary(
+ NAME
+ test-dot-prod
+ OUT
+ test-dot-prod
+ SRCS
+ "TestDotProdJIT.cpp"
+ DEPS
+ LLVMSupport
+ MLIRAllDialects
+ MLIREDSC
+ MLIRIR
+ MLIRSCFTransforms
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+iree_cc_binary(
+ NAME
+ test-vector-transfers-jit
+ OUT
+ test-vector-transfers-jit
+ SRCS
+ "TestVectorTransfersJIT.cpp"
+ DEPS
+ runtime-support.so
+ LLVMSupport
+ MLIRAllDialects
+ MLIREDSC
+ MLIRIR
+ MLIRSCFTransforms
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+iree_cc_binary(
+ NAME
+ test-mnist-jit
+ OUT
+ test-mnist-jit
+ SRCS
+ "TestMNISTJIT.cpp"
+ DEPS
+ MLIRAllDialects
+ MLIREDSC
+ MLIRIR
+ MLIRSCFTransforms
+ MLIRmlir_runner_utils
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+iree_cc_binary(
+ NAME
+ test-simple-jit
+ OUT
+ test-simple-jit
+ SRCS
+ "TestSimpleJIT.cpp"
+ DEPS
+ LLVMSupport
+ MLIRAllDialects
+ MLIREDSC
+ MLIRIR
+ MLIRSCFTransforms
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+iree_cc_binary(
+ NAME
+ test-simple-jit-vulkan
+ OUT
+ test-simple-jit-vulkan
+ SRCS
+ "TestSimpleJITVulkan.cpp"
+ DEPS
+ LLVMSupport
+ MLIRAllDialects
+ MLIRIR
+ MLIRParser
+ MLIRSPIRV
+ MLIRmlir_runner_utils
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+ iree::base::initializer
+ iree::hal::llvmjit::llvmjit_driver_module
+ iree::hal::vmla::vmla_driver_module
+ iree::hal::vulkan::vulkan_driver_module
+ vulkan-runtime-wrappers
+)
+
+iree_cc_binary(
+ NAME
+ test-matmul-vulkan
+ OUT
+ test-matmul-vulkan
+ SRCS
+ "TestMatMulVulkan.cpp"
+ DEPS
+ LLVMSupport
+ MLIRAllDialects
+ MLIRExecutionEngine
+ MLIRGPU
+ MLIRGPUToSPIRVTransforms
+ MLIRGPUToVulkanTransforms
+ MLIRIR
+ MLIRLinalgOps
+ MLIRLinalgToLLVM
+ MLIRLinalgTransforms
+ MLIRParser
+ MLIRPass
+ MLIRSPIRV
+ MLIRStandardToLLVM
+ MLIRStandardToSPIRVTransforms
+ MLIRSupport
+ MLIRTargetLLVMIR
+ MLIRTransformUtils
+ MLIRVectorToLLVM
+ MLIRmlir_runner_utils
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+ experimental::ModelBuilder::VulkanLaunchWrapper
+ iree::base::initializer
+ iree::compiler::Conversion::LinalgToSPIRV
+ iree::hal::llvmjit::llvmjit_driver_module
+ iree::hal::vmla::vmla_driver_module
+ iree::hal::vulkan::vulkan_driver_module
+ vulkan-runtime-wrappers
+)
+
+iree_cc_binary(
+ NAME
+ test-vec-to-gpu
+ OUT
+ test-vec-to-gpu
+ SRCS
+ "TestVectorToGPU.cpp"
+ DEPS
+ LLVMSupport
+ MLIRAllDialects
+ MLIRExecutionEngine
+ MLIRGPU
+ MLIRGPUToVulkanTransforms
+ MLIRIR
+ MLIRLinalgOps
+ MLIRLinalgToLLVM
+ MLIRLinalgTransforms
+ MLIRParser
+ MLIRPass
+ MLIRSPIRV
+ MLIRStandardToLLVM
+ MLIRStandardToSPIRVTransforms
+ MLIRTransformUtils
+ MLIRVector
+ MLIRmlir_runner_utils
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+ experimental::ModelBuilder::VulkanLaunchWrapper
+ iree::base::initializer
+ iree::compiler::Conversion::CodegenUtils
+ iree::compiler::Conversion::LinalgToSPIRV
+ iree::hal::llvmjit::llvmjit_driver_module
+ iree::hal::vmla::vmla_driver_module
+ iree::hal::vulkan::vulkan_driver_module
+ vulkan-runtime-wrappers
+)
+
+iree_cc_binary(
+ NAME
+ test-simple-mlir
+ OUT
+ test-simple-mlir
+ SRCS
+ "TestSimpleMLIR.cpp"
+ DEPS
+ MLIRAllDialects
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+iree_cc_binary(
+ NAME
+ bench-matvec-vector-jit
+ OUT
+ bench-matvec-vector-jit
+ SRCS
+ "BenchMatVecVectorJIT.cpp"
+ DEPS
+ MLIRAllDialects
+ MLIREDSC
+ MLIRIR
+ benchmark
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+iree_cc_binary(
+ NAME
+ bench-matmul-vector-jit
+ OUT
+ bench-matmul-vector-jit
+ SRCS
+ "BenchMatMulVectorJIT.cpp"
+ DEPS
+ MLIRAllDialects
+ MLIREDSC
+ MLIRIR
+ benchmark
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+iree_cc_binary(
+ NAME
+ bench-matmul-vector-column-major-llvm-intrinsics-jit
+ OUT
+ bench-matmul-vector-column-major-llvm-intrinsics-jit
+ SRCS
+ "BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp"
+ DEPS
+ MLIRAllDialects
+ MLIREDSC
+ MLIRIR
+ benchmark
+ experimental::ModelBuilder
+ experimental::ModelBuilder::ModelRunner
+)
+
+# TODO(thomasraoux): Fix bazel_to_cmake tto support shared library
+iree_cc_library(
+NAME
+ runtime-support.so
+ OUT
+ runtime-support.so
+ DEPS
+ MLIRExecutionEngine
+ SHARED
+)
diff --git a/experimental/ModelBuilder/test/TestDotProdJIT.cpp b/experimental/ModelBuilder/test/TestDotProdJIT.cpp
index e05a856..e64f9c7 100644
--- a/experimental/ModelBuilder/test/TestDotProdJIT.cpp
+++ b/experimental/ModelBuilder/test/TestDotProdJIT.cpp
@@ -100,6 +100,7 @@
}
int main(int argc, char **argv) {
+ ModelBuilder::registerAllDialects();
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv, "TestDotProd\n");
DotProdOnVectors();
diff --git a/experimental/ModelBuilder/test/TestMNISTJIT.cpp b/experimental/ModelBuilder/test/TestMNISTJIT.cpp
index 49676e2..f657ff6 100644
--- a/experimental/ModelBuilder/test/TestMNISTJIT.cpp
+++ b/experimental/ModelBuilder/test/TestMNISTJIT.cpp
@@ -145,6 +145,7 @@
}
int main() {
+ ModelBuilder::registerAllDialects();
constexpr unsigned B = 3, W0 = 784, W1 = 256, W2 = 256, W3 = 10;
ModelBuilder modelBuilder;
diff --git a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
index b53bc81..e16a8a2 100644
--- a/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestMatMulVulkan.cpp
@@ -149,6 +149,7 @@
}
int main(int argc, char **argv) {
+ ModelBuilder::registerAllDialects();
iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
diff --git a/experimental/ModelBuilder/test/TestSimpleJIT.cpp b/experimental/ModelBuilder/test/TestSimpleJIT.cpp
index 9a2b29c..9c6f389 100644
--- a/experimental/ModelBuilder/test/TestSimpleJIT.cpp
+++ b/experimental/ModelBuilder/test/TestSimpleJIT.cpp
@@ -179,6 +179,7 @@
}
int main(int argc, char **argv) {
+ ModelBuilder::registerAllDialects();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
diff --git a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
index 0e3d388..9c43915 100644
--- a/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
+++ b/experimental/ModelBuilder/test/TestSimpleJITVulkan.cpp
@@ -118,6 +118,7 @@
}
int main(int argc, char **argv) {
+ ModelBuilder::registerAllDialects();
iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
diff --git a/experimental/ModelBuilder/test/TestSimpleMLIR.cpp b/experimental/ModelBuilder/test/TestSimpleMLIR.cpp
index 9901972..d046590 100644
--- a/experimental/ModelBuilder/test/TestSimpleMLIR.cpp
+++ b/experimental/ModelBuilder/test/TestSimpleMLIR.cpp
@@ -132,6 +132,7 @@
}
int main(int argc, char **argv) {
+ ModelBuilder::registerAllDialects();
testValueVectorAdd();
testMemRefVectorAdd();
}
diff --git a/experimental/ModelBuilder/test/TestVectorToGPU.cpp b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
index 40e3690..69f8354 100644
--- a/experimental/ModelBuilder/test/TestVectorToGPU.cpp
+++ b/experimental/ModelBuilder/test/TestVectorToGPU.cpp
@@ -249,6 +249,7 @@
}
int main(int argc, char **argv) {
+ ModelBuilder::registerAllDialects();
iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
diff --git a/experimental/ModelBuilder/test/TestVectorTransfersJIT.cpp b/experimental/ModelBuilder/test/TestVectorTransfersJIT.cpp
index 7c0eeca..761f681 100644
--- a/experimental/ModelBuilder/test/TestVectorTransfersJIT.cpp
+++ b/experimental/ModelBuilder/test/TestVectorTransfersJIT.cpp
@@ -126,6 +126,7 @@
}
int main(int argc, char **argv) {
+ ModelBuilder::registerAllDialects();
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv, "TestVectorTransfers\n");