Merge pull request #5622 from ThomasRaoux:main-to-google
PiperOrigin-RevId: 370538063
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a8ae5a8..de28b15 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,6 +44,7 @@
option(IREE_BUILD_TESTS "Builds IREE unit tests." ON)
option(IREE_BUILD_DOCS "Builds IREE docs." OFF)
option(IREE_BUILD_SAMPLES "Builds IREE sample projects." ON)
+option(IREE_BUILD_EMBEDDING_SAMPLES "Builds IREE embedding sample projects. The compiler needs to be available." OFF)
option(IREE_BUILD_TENSORFLOW_COMPILER "Builds TensorFlow compiler frontend." OFF)
option(IREE_BUILD_TFLITE_COMPILER "Builds the TFLite compiler frontend." OFF)
@@ -472,6 +473,10 @@
add_subdirectory(iree/samples)
endif()
+if(${IREE_BUILD_EMBEDDING_SAMPLES} AND NOT ${IREE_BUILD_SAMPLES})
+ add_subdirectory(iree/samples/simple_embedding)
+endif()
+
if(${IREE_BUILD_EXPERIMENTAL_MODEL_BUILDER})
add_subdirectory(experimental/ModelBuilder)
endif()
diff --git a/build_tools/mako/benchmark_modules_on_android.py b/build_tools/mako/benchmark_modules_on_android.py
index 1990135..f93e686 100644
--- a/build_tools/mako/benchmark_modules_on_android.py
+++ b/build_tools/mako/benchmark_modules_on_android.py
@@ -58,12 +58,11 @@
def benchmark(module_name, flagfile_name, target) -> str:
samples = []
- driver = target.get_driver()
cmd = [
- "adb", "shell", "LD_LIBRARY_PATH=/data/local/tmp", "taskset", "80",
- f"{DEVICE_ROOT}/iree-benchmark-module",
+ "adb", "shell", "LD_LIBRARY_PATH=/data/local/tmp", "taskset",
+ target.taskset, f"{DEVICE_ROOT}/iree-benchmark-module",
f"--flagfile={DEVICE_ROOT}/{flagfile_name}",
- f"--module_file={DEVICE_ROOT}/{module_name}", f"--driver={driver}",
+ f"--module_file={DEVICE_ROOT}/{module_name}", f"--driver={target.driver}",
"--benchmark_repetitions=10"
] + target.runtime_flags
print(f"Running cmd: {' '.join(cmd)}")
diff --git a/build_tools/mako/compile_android_modules.py b/build_tools/mako/compile_android_modules.py
index 3cf4fb5..198d9a5 100644
--- a/build_tools/mako/compile_android_modules.py
+++ b/build_tools/mako/compile_android_modules.py
@@ -34,12 +34,14 @@
module_name = configuration.get_module_name(model_benchmark.name,
phone.name, target.mako_tag)
print(f"Generating {module_name} ...")
- subprocess.run(args=[
- IREE_TRANSLATE_PATH, model_benchmark.model_path,
- "--iree-mlir-to-vm-bytecode-module",
- f"--iree-hal-target-backends={target.name}", "-o", module_name
- ] + target.compilation_flags,
- check=True)
+ subprocess.run(
+ args=[
+ IREE_TRANSLATE_PATH, model_benchmark.model_path,
+ "--iree-mlir-to-vm-bytecode-module",
+ f"--iree-hal-target-backends={target.hal_target_backend}", "-o",
+ module_name
+ ] + target.compilation_flags,
+ check=True)
if __name__ == "__main__":
diff --git a/build_tools/mako/config/mobile-bert-pixel4.config b/build_tools/mako/config/mobile-bert-pixel4.config
index 5722c2a..98f32ef 100644
--- a/build_tools/mako/config/mobile-bert-pixel4.config
+++ b/build_tools/mako/config/mobile-bert-pixel4.config
@@ -27,11 +27,14 @@
type: NUMERIC
}
-# Three metrics, define the names for y-axis values of both run and aggregate
-# charts.
+# Define the names for y-axis values of both run and aggregate charts.
metric_info_list: {
value_key: "cpu"
- label: "DYLib_AOT"
+ label: "DYLib_AOT-1-thread"
+}
+metric_info_list: {
+ value_key: "cpu3t"
+ label: "DYLib_AOT-3-thread"
}
metric_info_list: {
value_key: "vmla"
diff --git a/build_tools/mako/config/mobile-bert-s20.config b/build_tools/mako/config/mobile-bert-s20.config
index 71990e7..8bdaa94 100644
--- a/build_tools/mako/config/mobile-bert-s20.config
+++ b/build_tools/mako/config/mobile-bert-s20.config
@@ -27,11 +27,14 @@
type: NUMERIC
}
-# Three metrics, define the names for y-axis values of both run and aggregate
-# charts.
+# Define the names for y-axis values of both run and aggregate charts.
metric_info_list: {
value_key: "cpu"
- label: "DYLib_AOT"
+ label: "DYLib_AOT-1-thread"
+}
+metric_info_list: {
+ value_key: "cpu3t"
+ label: "DYLib_AOT-3-thread"
}
metric_info_list: {
value_key: "vmla"
diff --git a/build_tools/mako/config/mobilenet-v2-pixel4.config b/build_tools/mako/config/mobilenet-v2-pixel4.config
index 3efd677..3d25ce5 100644
--- a/build_tools/mako/config/mobilenet-v2-pixel4.config
+++ b/build_tools/mako/config/mobilenet-v2-pixel4.config
@@ -27,11 +27,14 @@
type: NUMERIC
}
-# Three metrics, define the names for y-axis values of both run and aggregate
-# charts.
+# Define the names for y-axis values of both run and aggregate charts.
metric_info_list: {
value_key: "cpu"
- label: "DYLib_AOT"
+ label: "DYLib_AOT-1-thread"
+}
+metric_info_list: {
+ value_key: "cpu3t"
+ label: "DYLib_AOT-3-thread"
}
metric_info_list: {
value_key: "vmla"
diff --git a/build_tools/mako/config/mobilenet-v2-s20.config b/build_tools/mako/config/mobilenet-v2-s20.config
index 2cd286d..057bd31 100644
--- a/build_tools/mako/config/mobilenet-v2-s20.config
+++ b/build_tools/mako/config/mobilenet-v2-s20.config
@@ -27,11 +27,14 @@
type: NUMERIC
}
-# Three metrics, define the names for y-axis values of both run and aggregate
-# charts.
+# Define the names for y-axis values of both run and aggregate charts.
metric_info_list: {
value_key: "cpu"
- label: "DYLib_AOT"
+ label: "DYLib_AOT-1-thread"
+}
+metric_info_list: {
+ value_key: "cpu3t"
+ label: "DYLib_AOT-3-thread"
}
metric_info_list: {
value_key: "vmla"
diff --git a/build_tools/mako/configuration.py b/build_tools/mako/configuration.py
index 858e5f3..d6ea3ff 100644
--- a/build_tools/mako/configuration.py
+++ b/build_tools/mako/configuration.py
@@ -18,7 +18,9 @@
"""Information of a target backend.
Attributes:
- name: The target name used in iree-translate, e.g., vulkan-spirv.
+ driver: The driver used in iree-benchmark-module, e.g., vulkan.
+ hal_target_backend: The target name used in iree-translate, e.g., vulkan-spirv.
+ taskset: The CPU affinity mask passed to `taskset` when benchmarking the IREE module, e.g., 80.
mako_tag: The value_key in Mako config. This will be used in Mako metric
info, which should match to the config.
compilation_flags: Addition compilation flags. This is useful to target
@@ -28,25 +30,23 @@
"""
def __init__(self,
- name,
+ driver,
+ hal_target_backend,
+ taskset,
mako_tag,
compilation_flags=None,
runtime_flags=None):
- if "_" in name:
- raise ValueError("The target name contains invalid char '_'")
if compilation_flags is None:
compilation_flags = []
if runtime_flags is None:
runtime_flags = []
- self.name = name
+ self.driver = driver
+ self.hal_target_backend = hal_target_backend
+ self.taskset = taskset
self.mako_tag = mako_tag
self.compilation_flags = compilation_flags
self.runtime_flags = runtime_flags
- def get_driver(self) -> str:
- """ Returns a string indicates the driver of the target."""
- return self.name.split("-")[0]
-
def add_batch_flag(self, size):
self.compilation_flags.append(
f"--iree-hal-benchmark-dispatch-repeat-count={size}")
@@ -100,23 +100,43 @@
if batch_config is None:
batch_config = []
targets = [
- TargetInfo(name="vmla", mako_tag="vmla"),
- TargetInfo(name="dylib-llvm-aot",
- mako_tag="cpu",
- compilation_flags=[
- "--iree-llvm-target-triple=aarch64-none-linux-android29",
- "-iree-flow-inline-constants-max-byte-length=2048",
- "-iree-flow-dispatch-formation-enable-operand-fusion"
- ],
- runtime_flags=["--dylib_worker_count=1"]),
TargetInfo(
- name="vulkan-spirv",
+ driver="vmla",
+ hal_target_backend="vmla",
+ taskset="80",
+ mako_tag="vmla"),
+ TargetInfo(
+ driver="dylib-sync",
+ hal_target_backend="dylib-llvm-aot",
+ taskset="80",
+ mako_tag="cpu",
+ compilation_flags=[
+ "--iree-llvm-target-triple=aarch64-none-linux-android29",
+ "--iree-flow-inline-constants-max-byte-length=2048",
+ "--iree-flow-dispatch-formation-enable-operand-fusion"
+ ]),
+ TargetInfo(
+ driver="dylib",
+ hal_target_backend="dylib-llvm-aot",
+ taskset="f0",
+ mako_tag="cpu3t",
+ compilation_flags=[
+ "--iree-llvm-target-triple=aarch64-none-linux-android29",
+ "--iree-flow-inline-constants-max-byte-length=2048",
+ "--iree-flow-dispatch-formation-enable-operand-fusion"
+ ],
+ runtime_flags=[
+ "--dylib_worker_count=3",
+ ]),
+ TargetInfo(
+ driver="vulkan",
+ hal_target_backend="vulkan-spirv",
+ taskset="80",
mako_tag="vlk",
compilation_flags=[
"--iree-vulkan-target-triple=qualcomm-adreno640-unknown-android10",
- "-iree-flow-inline-constants-max-byte-length=2048",
- "-iree-flow-dispatch-formation-enable-operand-fusion",
- "-iree-flow-tile-and-distribute-elementwise-ops"
+ "--iree-flow-inline-constants-max-byte-length=2048",
+ "--iree-flow-dispatch-formation-enable-operand-fusion"
])
]
targets = [elem for elem in targets if elem.mako_tag not in skipped_target]
@@ -132,24 +152,44 @@
if batch_config is None:
batch_config = []
targets = [
- TargetInfo(name="vmla", mako_tag="vmla"),
- TargetInfo(name="dylib-llvm-aot",
- mako_tag="cpu",
- compilation_flags=[
- "--iree-llvm-target-triple=aarch64-none-linux-android29",
- "-iree-flow-inline-constants-max-byte-length=2048",
- "-iree-flow-dispatch-formation-enable-operand-fusion"
- ],
- runtime_flags=["--dylib_worker_count=1"]),
TargetInfo(
- name="vulkan-spirv",
+ driver="vmla",
+ hal_target_backend="vmla",
+ taskset="80",
+ mako_tag="vmla"),
+ TargetInfo(
+ driver="dylib-sync",
+ hal_target_backend="dylib-llvm-aot",
+ taskset="80",
+ mako_tag="cpu",
+ compilation_flags=[
+ "--iree-llvm-target-triple=aarch64-none-linux-android29",
+ "--iree-flow-inline-constants-max-byte-length=2048",
+ "--iree-flow-dispatch-formation-enable-operand-fusion"
+ ]),
+ TargetInfo(
+ driver="dylib",
+ hal_target_backend="dylib-llvm-aot",
+ taskset="f0",
+ mako_tag="cpu3t",
+ compilation_flags=[
+ "--iree-llvm-target-triple=aarch64-none-linux-android29",
+ "--iree-flow-inline-constants-max-byte-length=2048",
+ "--iree-flow-dispatch-formation-enable-operand-fusion"
+ ],
+ runtime_flags=[
+ "--dylib_worker_count=3",
+ ]),
+ TargetInfo(
+ driver="vulkan",
+ hal_target_backend="vulkan-spirv",
+ taskset="80",
mako_tag="vlk",
compilation_flags=[
"--iree-vulkan-target-triple=valhall-g77-unknown-android10",
# TODO(GH-5330): Revisit the number or delete the flag.
- "-iree-flow-inline-constants-max-byte-length=16",
- "-iree-flow-dispatch-formation-enable-operand-fusion",
- "-iree-flow-tile-and-distribute-elementwise-ops",
+ "--iree-flow-inline-constants-max-byte-length=16",
+ "--iree-flow-dispatch-formation-enable-operand-fusion"
])
]
targets = [elem for elem in targets if elem.mako_tag not in skipped_target]
@@ -166,25 +206,20 @@
MODEL_BENCHMARKS = [
ModelBenchmarkInfo(
name="mobile-bert",
- model_artifacts_name=
- "iree-mobile-bert-artifacts-6fe4616e0ab9958eb18f368960a31276f1362029.tar.gz",
+ model_artifacts_name="iree-mobile-bert-artifacts-6fe4616e0ab9958eb18f368960a31276f1362029.tar.gz",
model_path="tmp/iree/modules/MobileBertSquad/iree_input.mlir",
- flagfile_path=
- "tmp/iree/modules/MobileBertSquad/iree_vmla/traces/serving_default/flagfile",
+ flagfile_path="tmp/iree/modules/MobileBertSquad/iree_vmla/traces/serving_default/flagfile",
phones=[
- PhoneBenchmarkInfo(name="Pixel4",
- benchmark_key="5538704950034432",
- targets=get_pixel4_default_target_list(
- skipped_target=["cpu2", "vlk2"],
- batch_config={"cpu": 8})),
- PhoneBenchmarkInfo(name="S20",
- benchmark_key="4699630718681088",
- targets=get_s20_default_target_list(
- skipped_target=["cpu2", "vlk2"],
- batch_config={
- "cpu": 8,
- "vlk": 16
- })),
+ PhoneBenchmarkInfo(
+ name="Pixel4",
+ benchmark_key="5538704950034432",
+ targets=get_pixel4_default_target_list(
+ skipped_target=["cpu2", "vlk2"],)),
+ PhoneBenchmarkInfo(
+ name="S20",
+ benchmark_key="4699630718681088",
+ targets=get_s20_default_target_list(
+ skipped_target=["cpu2", "vlk2"],)),
]),
ModelBenchmarkInfo(
name="mobilenet-v2",
@@ -192,20 +227,15 @@
model_path="mobilenet-v2/iree_input.mlir",
flagfile_path="mobilenet-v2/flagfile",
phones=[
- PhoneBenchmarkInfo(name="Pixel4",
- benchmark_key="6338759231537152",
- targets=get_pixel4_default_target_list(
- skipped_target=["vlk2"],
- batch_config={
- "cpu": 16,
- })),
+ PhoneBenchmarkInfo(
+ name="Pixel4",
+ benchmark_key="6338759231537152",
+ targets=get_pixel4_default_target_list(
+ skipped_target=["vlk2"])),
PhoneBenchmarkInfo(
name="S20",
benchmark_key="5618403088793600",
- targets=get_s20_default_target_list(batch_config={
- "cpu": 16,
- "vlk": 64,
- })),
+ targets=get_s20_default_target_list()),
]),
ModelBenchmarkInfo(
name="mobilebert-f16",
diff --git a/integrations/tensorflow/build_tools/testdata/generate_errors_module.py b/integrations/tensorflow/build_tools/testdata/generate_errors_module.py
new file mode 100644
index 0000000..cea0547
--- /dev/null
+++ b/integrations/tensorflow/build_tools/testdata/generate_errors_module.py
@@ -0,0 +1,54 @@
+# Lint as: python3
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Generates a sample saved model for exercising importer error reporting.
+
+Usage:
+ generate_errors_module.py /tmp/errors.sm
+
+This can then be fed into iree-tf-import to process it:
+
+Fully convert to IREE input (run all import passes):
+ iree-tf-import /tmp/errors.sm
+
+Import only (useful for crafting test cases for the import pipeline):
+ iree-tf-import -o /dev/null -save-temp-tf-input=- /tmp/errors.sm
+
+Can be further lightly pre-processed via:
+ | iree-tf-opt --tf-standard-pipeline
+"""
+
+import sys
+
+import tensorflow as tf
+
+
+class ErrorsModule(tf.Module):
+
+ @tf.function(input_signature=[tf.TensorSpec([16], tf.float32)])
+ def string_op(self, a):
+ tf.print(a)
+ return a
+
+
+try:
+ file_name = sys.argv[1]
+except IndexError:
+ print("Expected output file name")
+ sys.exit(1)
+
+m = ErrorsModule()
+tf.saved_model.save(
+ m, file_name, options=tf.saved_model.SaveOptions(save_debug_info=True))
+print(f"Saved to {file_name}")
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index 2342967..c7a9002 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -26,6 +26,7 @@
"LowerExportedFunctions.cpp",
"LowerGlobalTensors.cpp",
"Passes.cpp",
+ "PrettifyDebugInfo.cpp",
"PropagateResourceCasts.cpp",
"SavedModelToIreeABI.cpp",
"StripAsserts.cpp",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
index aa94ac3..a34f563 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
@@ -65,6 +65,10 @@
// functions for any saved model exported functions.
std::unique_ptr<OperationPass<ModuleOp>> createSavedModelToIREEABIPass();
+// Simplifies TensorFlow debug info to make it easier to read. Currently a
+// no-op placeholder: the simplification algorithm is still a TODO.
+std::unique_ptr<OperationPass<ModuleOp>> createPrettifyDebugInfoPass();
+
// Push resource casts forward to better propagate resource related shapes.
std::unique_ptr<OperationPass<ModuleOp>> createPropagateResourceCastsPass();
@@ -92,6 +96,7 @@
createFlattenTuplesInCFGPass();
createLowerGlobalTensorsPass();
createLowerExportedFunctionsPass();
+ createPrettifyDebugInfoPass();
createPropagateResourceCastsPass();
createSavedModelToIREEABIPass();
createStripAssertsPass();
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/PrettifyDebugInfo.cpp b/integrations/tensorflow/iree_tf_compiler/TF/PrettifyDebugInfo.cpp
new file mode 100644
index 0000000..71e7374
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TF/PrettifyDebugInfo.cpp
@@ -0,0 +1,47 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TF/Passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace iree_integrations {
+namespace TF {
+
+class PrettifyDebugInfoPass
+ : public PassWrapper<PrettifyDebugInfoPass, OperationPass<ModuleOp>> {
+ public:
+ void runOnOperation() override {
+ // TODO: Finish algorithm for simplifying TF debug info.
+ // auto moduleOp = getOperation();
+ // moduleOp.walk([&](Operation *op) {
+ // Location loc = op->getLoc();
+ // if (auto callSite = loc.dyn_cast<CallSiteLoc>()) {
+ // callSite.getCallee().dump();
+ // }
+ // });
+ }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createPrettifyDebugInfoPass() {
+ return std::make_unique<PrettifyDebugInfoPass>();
+}
+
+static PassRegistration<PrettifyDebugInfoPass> modulePass(
+ "iree-tf-prettify-debug-info",
+ "Simplifies TF debug info to make it easier to look at");
+
+} // namespace TF
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
index 6887dd7..426bf3d 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-tflite-main.cpp
@@ -17,6 +17,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AsmState.h"
@@ -76,6 +77,9 @@
MLIRContext context(registry);
context.loadAllAvailableDialects();
+ llvm::SourceMgr sourceMgr;
+ mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+
// Load input buffer.
std::string errorMessage;
auto inputFile = openInputFile(inputPath, &errorMessage);
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 10b672a..e48ed4d 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -178,6 +178,9 @@
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
+ llvm::SourceMgr sourceMgr;
+ mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+
auto status =
ConvertHloToMlirHlo(module.get(), hloProto.mutable_hlo_module());
if (!status.ok()) {
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp
index 5ac7224..4e8e90b 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-tf-import-main.cpp
@@ -24,6 +24,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
@@ -160,6 +161,11 @@
llvm::cl::desc("Save the resultant IR to this file (useful for saving an "
"intermediate in a pipeline)"),
llvm::cl::init(""));
+ static llvm::cl::opt<bool> prettifyTfDebugInfo(
+ "prettify-tf-debug-info",
+ llvm::cl::desc("Prettifies TF debug information to make it easier "
+ "to look at"),
+ llvm::cl::init(true));
// Register any command line options.
registerAsmPrinterCLOptions();
@@ -172,6 +178,10 @@
MLIRContext context(registry);
context.loadAllAvailableDialects();
+
+ llvm::SourceMgr sourceMgr;
+ mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+
OwningModuleRef module;
auto saveToFile = [&](llvm::StringRef savePath) -> LogicalResult {
@@ -210,6 +220,10 @@
PassManager pm(&context, PassManager::Nesting::Implicit);
applyPassManagerCLOptions(pm);
+ if (prettifyTfDebugInfo) {
+ pm.addPass(iree_integrations::TF::createPrettifyDebugInfoPass());
+ }
+
iree_integrations::TF::buildTFImportPassPipeline(pm);
if (failed(pm.run(*module))) {
llvm::errs()
diff --git a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
index 02a3df6..37d9421 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.cpp
@@ -33,41 +33,38 @@
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated);
static llvm::cl::opt<int> matmulWorkgroupTileSize(
- "iree-codegen-linalg-to-llvm-kernel-dispatch-matmul-workgroup-tile-size",
+ "iree-codegen-llvm-matmul-workgroup-size",
llvm::cl::desc(
"linalg.matmul tile size for workgroups spliting of M, N dimension"),
llvm::cl::init(64));
-
static llvm::cl::opt<int> matmulL1TileSize(
- "iree-codegen-linalg-to-llvm-kernel-dispatch-matmul-l1-tile-size",
+ "iree-codegen-llvm-matmul-l1-size",
llvm::cl::desc(
- "linalg.matmul tile size for workgroups spliting of M, N dimension"),
+ "linalg.matmul tile size for L1 spliting of M, N, K dimension"),
llvm::cl::init(32));
-
static llvm::cl::opt<int> matmulL2TileSize(
- "iree-codegen-linalg-to-llvm-kernel-dispatch-matmul-l2-tile-size",
- llvm::cl::desc(
- "linalg.matmul tile size for workgroups spliting of M, N dimension"),
- llvm::cl::init(4));
+ "iree-codegen-llvm-matmul-vector-size",
+ llvm::cl::desc("linalg.matmul vector tile size"), llvm::cl::init(4));
static llvm::cl::opt<int> batchMatmulWorkgroupTileSize(
- "iree-codegen-linalg-to-llvm-kernel-dispatch-batch-matmul-workgroup-tile-"
- "size",
- llvm::cl::desc(
- "linalg.matmul tile size for workgroups spliting of M, N dimension"),
+ "iree-codegen-llvm-batch-matmul-workgroup-size",
+ llvm::cl::desc("linalg.batch_matmul tile size for workgroups spliting of "
+ "M, N dimension"),
llvm::cl::init(32));
-
static llvm::cl::opt<int> batchMatmulL1TileSize(
- "iree-codegen-linalg-to-llvm-kernel-dispatch-batch-matmul-l1-tile-size",
+ "iree-codegen-llvm-batch-matmul-l1-size",
llvm::cl::desc(
- "linalg.matmul tile size for workgroups spliting of M, N dimension"),
+ "linalg.batch_matmul tile size for L1 spliting of M, N, K dimensions"),
llvm::cl::init(16));
-
static llvm::cl::opt<int> batchMatmulL2TileSize(
- "iree-codegen-linalg-to-llvm-kernel-dispatch-batch-matmul-l2-tile-size",
+ "iree-codegen-llvm-batch-matmul-vector-size",
+ llvm::cl::desc("linalg.batch_matmul vector tile size"), llvm::cl::init(4));
+
+static llvm::cl::opt<int> genericOpsWorkgroupTileSize(
+ "iree-codegen-llvm-generic-ops-workgroup-size",
llvm::cl::desc(
- "linalg.matmul tile size for workgroups spliting of M, N dimension"),
- llvm::cl::init(4));
+ "linalg.generic and linalg.indexed_generic workgroup tile size"),
+ llvm::cl::init(128));
namespace {
template <TilingLevel tilingLevel>
@@ -104,6 +101,31 @@
}
}
+ if (isa<linalg::GenericOp>(op)) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+ switch (tilingLevel) {
+ case TilingLevel::WorkGroupTiles: {
+ llvm::SmallVector<int64_t, 4> workgroupTileSizes;
+ int iterationRank = linalgOp.iterator_types().size();
+ for (int i = 0; i < std::min(iterationRank, 3); ++i) {
+ auto iteratorType = linalgOp.iterator_types()[i];
+ if (iteratorType.cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName()) {
+ workgroupTileSizes.push_back(genericOpsWorkgroupTileSize);
+ } else {
+ // Don't tile workgroup across reduction dimensions.
+ workgroupTileSizes.push_back(0);
+ }
+ }
+ return workgroupTileSizes;
+ }
+ // TODO(ataei): Set the parameters when we enable vectorization.
+ case TilingLevel::Level1Tiles:
+ case TilingLevel::Level2Tiles:
+ return {1, 1, 1};
+ }
+ }
+
return {1, 1, 1};
}
} // namespace
@@ -129,6 +151,17 @@
#undef DEFINE_TILE_SIZE_FN
+bool isDispatchOp(Operation *op) {
+ if (auto contractionOp = dyn_cast<linalg::ContractionOpInterface>(op)) {
+ if (contractionOp.isRowMajorMatmul() ||
+ contractionOp.isRowMajorBatchMatmul()) {
+ return true;
+ }
+ }
+ if (isa<linalg::GenericOp>(op)) return true;
+ return false;
+}
+
Optional<LaunchConfig> initCPULaunchConfig(
MLIRContext *context, const linalg::LinalgDependenceGraph &dependenceGraph,
ArrayRef<linalg::LinalgOp> linalgOps) {
@@ -136,28 +169,21 @@
Optional<linalg::LinalgOp> rootOperation = llvm::None;
for (auto linalgOp : linalgOps) {
- if (auto contractionOp =
- dyn_cast<linalg::ContractionOpInterface>(linalgOp.getOperation())) {
- if (!contractionOp.isRowMajorMatmul() &&
- !contractionOp.isRowMajorBatchMatmul()) {
- continue;
- }
- if (rootOperation) {
- contractionOp.emitError(
- "unhandled multiple root operations in dispatch region");
- return llvm::None;
- }
- rootOperation = linalgOp;
- SmallVector<int64_t, 4> opTileSizes;
- if (!clLLVMTileSizes.empty()) {
- opTileSizes.assign(clLLVMTileSizes.begin(), clLLVMTileSizes.end());
- } else {
- opTileSizes = getTileSizes<TilingLevel::WorkGroupTiles>(contractionOp);
- }
- config.setTileSizes(contractionOp, opTileSizes, 0);
- config.setRootOperation(contractionOp);
- continue;
+ if (!isDispatchOp(linalgOp)) continue;
+ if (rootOperation) {
+ linalgOp.emitError(
+ "unhandled multiple root operations in dispatch region");
+ return llvm::None;
}
+ rootOperation = linalgOp;
+ SmallVector<int64_t, 4> opTileSizes;
+ if (!clLLVMTileSizes.empty()) {
+ opTileSizes.assign(clLLVMTileSizes.begin(), clLLVMTileSizes.end());
+ } else {
+ opTileSizes = getTileSizes<TilingLevel::WorkGroupTiles>(linalgOp);
+ }
+ config.setTileSizes(linalgOp, opTileSizes, 0);
+ config.setRootOperation(linalgOp);
}
if (!rootOperation) {
return config;
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index 1d9bb84..6b2ff1b 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/Common/Passes.h"
+
+#include "iree/compiler/Conversion/Common/Attributes.h"
#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Conversion/LinalgToLLVM/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
@@ -60,6 +61,7 @@
LLVMCodegenOptions options) {
passManager.addPass(createMaterializeCPULaunchConfigurationPass());
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
+ nestedModulePM.addPass(createCanonicalizerPass());
// TODO(ataei): We want to enable when tensor -> vector pass is fully
// supported which requires first moving vector-tiling before this step.
if (options.useLinalgOnTensorsToVectors) {
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir
index 93fada9..e7da53c 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir
@@ -125,6 +125,8 @@
}
}
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
// CHECK: hal.executable @add
// CHECK: hal.executable.entry_point @add
// CHECK-NEXT: ^{{[a-zA-Z0-9_]+}}(
@@ -132,4 +134,6 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK: hal.return %[[C1]], %[[C1]], %[[C1]]
+// CHECK-DAG: %[[WGX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[WGY:.+]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+// CHECK: hal.return %[[WGX]], %[[WGY]], %[[C1]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 553d3fa..d087f53 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -59,13 +59,6 @@
"Enable fusing operand producers during dispatch region formation"),
llvm::cl::init(false));
-// TODO(#5045): Tile and distribute on CPU causes performance regressions. This
-// needs to be addressed before this can be turned on by default.
-static llvm::cl::opt<bool> clTileAndDistributeElementwiseOps(
- "iree-flow-tile-and-distribute-elementwise-ops",
- llvm::cl::desc("Enable tile and distribute on elementwise operations"),
- llvm::cl::init(false));
-
static const char kRootOpAttr[] = "__root_op__";
static const char kFusionGroupsAttr[] = "__fused_op__";
@@ -1061,9 +1054,7 @@
context->allowUnregisteredDialects(true);
unsigned numRoots = decideFusableLinalgOps(funcOp);
- if (clTileAndDistributeElementwiseOps) {
- makeElementwiseOpsRootOps<linalg::GenericOp>(funcOp, numRoots);
- }
+ makeElementwiseOpsRootOps<linalg::GenericOp>(funcOp, numRoots);
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n";
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index b0f869f..b1a9a42 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -53,6 +53,11 @@
"matmul ops pass."),
llvm::cl::init(true));
+static llvm::cl::opt<bool> clEnableConvToImg2Col(
+ "iree-flow-enable-conv-img2col-transform",
+ llvm::cl::desc("Enable converting convolution ops to img2col form."),
+ llvm::cl::init(false));
+
namespace mlir {
namespace iree_compiler {
namespace IREE {
@@ -207,6 +212,10 @@
passManager.addNestedPass<FuncOp>(
mlir::iree_compiler::createConvert1x1ConvToMatmulPass());
}
+ if (clEnableConvToImg2Col) {
+ passManager.addNestedPass<FuncOp>(
+ mlir::iree_compiler::createConvertConv2DToImg2ColPass());
+ }
passManager.addNestedPass<FuncOp>(
mlir::createConvertElementwiseToLinalgPass());
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 860105c..7dceaf6 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -271,8 +271,10 @@
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK-DAG: %[[D3:.+]] = memref.dim %[[ARG0]], %[[C3]]
-// CHECK: %[[D4:.+]] = affine.apply #[[MAP0]]()[%[[D0]], %[[D1]]]
-// CHECK: flow.dispatch.workgroups[%[[D3]], %[[D2]], %[[D4]]]
+// CHECK-DAG: %[[WG_SIZE_2:.+]] = flow.dispatch.workgroup.size[2] : index
+// CHECK-DAG: %[[WG_ID_2:.+]] = flow.dispatch.workgroup.id[2] : index
+// CHECK-DAG: flow.dispatch.workgroups[%[[D3]], %[[D2]], %[[D1]]]
+// CHECK-DAG: %[[D4:.+]] = affine.apply #[[MAP0]]()[%[[WG_ID_2]], %[[WG_SIZE_2]]]
// -----
@@ -582,16 +584,16 @@
// CHECK-NOT: linalg.
// CHECK-NOT: subtensor
// CHECK: flow.dispatch.workgroups
-// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:i32>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x?xf32>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?xf32>
+// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:i32>
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:?xf32>
// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-// CHECK: %[[OP1:.+]] = linalg.tensor_reshape %[[LEAF3]]
-// CHECK: %[[OP2:.+]] = linalg.tensor_reshape %[[LEAF2]]
+// CHECK: %[[OP1:.+]] = linalg.tensor_reshape %[[LEAF2]]
+// CHECK: %[[OP2:.+]] = linalg.tensor_reshape %[[LEAF1]]
// CHECK: %[[OP3:.+]] = subtensor %[[OP1]][0, 0]
// CHECK: %[[OP4:.+]] = subtensor %[[OP1]][0, 10]
// CHECK: %[[OP5:.+]] = subtensor %[[OP1]][0, 20]
@@ -640,17 +642,17 @@
// CHECK: subtensor %[[OP1]]
// CHECK: linalg.tensor_reshape %[[ARG1]]
// CHECK: flow.dispatch.workgroups
-// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:i32>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x?xf32>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?xf32>
-// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x?xf32>
+// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:1x?xf32>
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:i32>
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:?xf32>
-// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
+// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG7]], {{.*}}
-// CHECK: %[[OP1:.+]] = subtensor %[[LEAF3]][0, 0]
-// CHECK: %[[OP2:.+]] = subtensor %[[LEAF3]][0, 10]
+// CHECK: %[[OP1:.+]] = subtensor %[[LEAF2]][0, 0]
+// CHECK: %[[OP2:.+]] = subtensor %[[LEAF2]][0, 10]
// CHECK: %[[OP3:.+]] = linalg.tensor_reshape %[[LEAF1]]
// CHECK: %[[OP4:.+]] = linalg.tensor_reshape %[[OP1]]
// CHECK: %[[OP5:.+]] = linalg.tensor_reshape %[[OP2]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir
index 4facda6..6410715 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-dispatch-linalg-on-tensors-pass -iree-flow-tile-and-distribute-elementwise-ops -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-dispatch-linalg-on-tensors-pass -canonicalize -cse %s | IreeFileCheck %s
func @tile_generic_op_alone(%A: tensor<?x?xf32>, %B: tensor<?xf32>) -> tensor<?x?xf32> {
%c0 = constant 0 : index
diff --git a/iree/hal/local/elf/platform/generic.c b/iree/hal/local/elf/platform/generic.c
index 1f5fc6c..c1df242 100644
--- a/iree/hal/local/elf/platform/generic.c
+++ b/iree/hal/local/elf/platform/generic.c
@@ -94,8 +94,10 @@
// IREE_ELF_CLEAR_CACHE can be defined externally to override this default
// behavior.
#if !defined(IREE_ELF_CLEAR_CACHE)
+// Explicitly enable for GCC, which has had this since 4.x but does not seem
+// to advertise it via __has_builtin. Only reached if __has_builtin is defined.
#if defined __has_builtin
-#if __has_builtin(__builtin___clear_cache)
+#if __has_builtin(__builtin___clear_cache) || defined(__GNUC__)
#define IREE_ELF_CLEAR_CACHE(start, end) __builtin___clear_cache(start, end)
#endif // __builtin___clear_cache
#endif // __has_builtin
diff --git a/iree/hal/local/elf/platform/linux.c b/iree/hal/local/elf/platform/linux.c
index 5e30563..394e7a1 100644
--- a/iree/hal/local/elf/platform/linux.c
+++ b/iree/hal/local/elf/platform/linux.c
@@ -144,8 +144,10 @@
// IREE_ELF_CLEAR_CACHE can be defined externally to override this default
// behavior.
#if !defined(IREE_ELF_CLEAR_CACHE)
+// Explicitly enable for GCC, which has had this since 4.x but does not seem
+// to advertise it via __has_builtin. Only reached if __has_builtin is defined.
#if defined __has_builtin
-#if __has_builtin(__builtin___clear_cache)
+#if __has_builtin(__builtin___clear_cache) || defined(__GNUC__)
#define IREE_ELF_CLEAR_CACHE(start, end) __builtin___clear_cache(start, end)
#endif // __builtin___clear_cache
#endif // __has_builtin
diff --git a/iree/samples/simple_embedding/BUILD b/iree/samples/simple_embedding/BUILD
index b0d9ce8..6c10039 100644
--- a/iree/samples/simple_embedding/BUILD
+++ b/iree/samples/simple_embedding/BUILD
@@ -14,6 +14,8 @@
load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
load("//iree/tools:compilation.bzl", "iree_bytecode_module")
+load("//build_tools/bazel:run_binary_test.bzl", "run_binary_test")
+load("//build_tools/embed_data:build_defs.bzl", "c_embed_data")
package(
default_visibility = ["//visibility:public"],
@@ -21,6 +23,42 @@
licenses = ["notice"], # Apache 2.0
)
+# The prebuilt simple_embedding_test-llvm-aot_rv64 bytecode module is built with
+# RISC-V 64-bit cross-compile toolchain (with RISCV_TOOLCHAIN_ROOT defined):
+# iree-translate \
+# -iree-mlir-to-vm-bytecode-module \
+# -iree-hal-target-backends=dylib-llvm-aot \
+# -iree-llvm-target-triple=riscv64 \
+# -iree-llvm-target-cpu=generic-rv64 \
+# -iree-llvm-target-cpu-features="+m,+a,+f,+d,+c" \
+# -iree-llvm-target-abi=lp64d \
+# iree/samples/simple_embedding/simple_embedding_test.mlir \
+# -o /tmp/simple_embedding_test-llvm-aot_rv64.vmfb
+c_embed_data(
+ name = "simple_embedding_test_llvm_aot_rv64",
+ srcs = [
+ "data/simple_embedding_test-llvm-aot_rv64.vmfb",
+ ],
+ c_file_output = "simple_embedding_test_llvm_aot_rv64.c",
+ flatten = True,
+ h_file_output = "simple_embedding_test_llvm_aot_rv64.h",
+)
+
+cc_binary(
+ name = "simple_embedding_run",
+ srcs = ["simple_embedding_run.c"],
+ deps = [
+ ":simple_embedding_test_bytecode_module_c",
+ ":simple_embedding_test_llvm_aot_rv64",
+ "//iree/base:api",
+ "//iree/hal:api",
+ "//iree/hal/drivers",
+ "//iree/modules/hal",
+ "//iree/vm",
+ "//iree/vm:bytecode_module",
+ ],
+)
+
iree_cmake_extra_content(
content = """
if(NOT ${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT} OR NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}
@@ -28,11 +66,13 @@
return()
endif()
""",
+ inline = True,
)
iree_bytecode_module(
name = "simple_embedding_test_bytecode_module",
src = "simple_embedding_test.mlir",
+ c_output = True,
cc_namespace = "iree::samples",
flags = [
"-iree-mlir-to-vm-bytecode-module",
@@ -41,23 +81,29 @@
],
)
-cc_test(
- name = "simple_embedding_test",
- srcs = ["simple_embedding_test.cc"],
- deps = [
- ":simple_embedding_test_bytecode_module_cc",
- "//iree/base:api",
- "//iree/base:logging",
- "//iree/hal:api",
- "//iree/hal/dylib/registration",
- "//iree/hal/vulkan/registration",
- "//iree/modules/hal",
- "//iree/testing:gtest",
- "//iree/testing:gtest_main",
- "//iree/vm",
- "//iree/vm:bytecode_module",
- "//iree/vm:cc",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
- ],
+# Bytecode module is built without cross-compilation. Bypass the following test if it
+# is not built for the host machine.
+
+iree_cmake_extra_content(
+ content = """
+# Bytecode is built without cross-compilation. Bypass the following test if it
+# is not built for the host machine.
+
+if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "amd64|x86_64|AMD64")
+ return()
+endif()
+""",
+ inline = True,
+)
+
+run_binary_test(
+ name = "simple_embedding_dylib_test",
+ args = ["dylib"],
+ test_binary = ":simple_embedding_run",
+)
+
+run_binary_test(
+ name = "simple_embedding_vulkan_test",
+ args = ["vulkan"],
+ test_binary = ":simple_embedding_run",
)
diff --git a/iree/samples/simple_embedding/CMakeLists.txt b/iree/samples/simple_embedding/CMakeLists.txt
index 767a81f..f8d0237 100644
--- a/iree/samples/simple_embedding/CMakeLists.txt
+++ b/iree/samples/simple_embedding/CMakeLists.txt
@@ -8,13 +8,42 @@
# To disable autogeneration for this file entirely, delete this header. #
################################################################################
+iree_add_all_subdirs()
+
+iree_c_embed_data(
+ NAME
+ simple_embedding_test_llvm_aot_rv64
+ SRCS
+ "data/simple_embedding_test-llvm-aot_rv64.vmfb"
+ C_FILE_OUTPUT
+ "simple_embedding_test_llvm_aot_rv64.c"
+ H_FILE_OUTPUT
+ "simple_embedding_test_llvm_aot_rv64.h"
+ FLATTEN
+ PUBLIC
+)
+
+iree_cc_binary(
+ NAME
+ simple_embedding_run
+ SRCS
+ "simple_embedding_run.c"
+ DEPS
+ ::simple_embedding_test_bytecode_module_c
+ ::simple_embedding_test_llvm_aot_rv64
+ iree::base::api
+ iree::hal::api
+ iree::hal::drivers
+ iree::modules::hal
+ iree::vm
+ iree::vm::bytecode_module
+)
+
if(NOT ${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT} OR NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}
OR NOT ${IREE_HAL_DRIVER_DYLIB} OR NOT ${IREE_HAL_DRIVER_VULKAN})
return()
endif()
-iree_add_all_subdirs()
-
iree_bytecode_module(
NAME
simple_embedding_test_bytecode_module
@@ -22,6 +51,7 @@
"simple_embedding_test.mlir"
CC_NAMESPACE
"iree::samples"
+ C_OUTPUT
FLAGS
"-iree-mlir-to-vm-bytecode-module"
"-iree-hal-target-backends=dylib-llvm-aot"
@@ -29,26 +59,29 @@
PUBLIC
)
-iree_cc_test(
+# Bytecode is built without cross-compilation. Bypass the following test if it
+# is not built for the host machine.
+
+if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "amd64|x86_64|AMD64")
+ return()
+endif()
+
+iree_run_binary_test(
NAME
- simple_embedding_test
- SRCS
- "simple_embedding_test.cc"
- DEPS
- ::simple_embedding_test_bytecode_module_cc
- absl::span
- absl::strings
- iree::base::api
- iree::base::logging
- iree::hal::api
- iree::hal::dylib::registration
- iree::hal::vulkan::registration
- iree::modules::hal
- iree::testing::gtest
- iree::testing::gtest_main
- iree::vm
- iree::vm::bytecode_module
- iree::vm::cc
+ "simple_embedding_dylib_test"
+ ARGS
+ "dylib"
+ TEST_BINARY
+ ::simple_embedding_run
+)
+
+iree_run_binary_test(
+ NAME
+ "simple_embedding_vulkan_test"
+ ARGS
+ "vulkan"
+ TEST_BINARY
+ ::simple_embedding_run
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/samples/simple_embedding/data/simple_embedding_test-llvm-aot_rv64.vmfb b/iree/samples/simple_embedding/data/simple_embedding_test-llvm-aot_rv64.vmfb
new file mode 100644
index 0000000..a5df535
--- /dev/null
+++ b/iree/samples/simple_embedding/data/simple_embedding_test-llvm-aot_rv64.vmfb
Binary files differ
diff --git a/iree/samples/simple_embedding/simple_embedding_run.c b/iree/samples/simple_embedding/simple_embedding_run.c
new file mode 100644
index 0000000..14104f1
--- /dev/null
+++ b/iree/samples/simple_embedding/simple_embedding_run.c
@@ -0,0 +1,241 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// An example of setting up the HAL module to run a simple pointwise array
+// multiplication with the HAL driver named on the command line (e.g. "dylib"
+// or "vulkan").
+#include <stdio.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#if IREE_ARCH_RISCV_64
+#include "iree/hal/dylib/registration/driver_module.h"
+#else
+#include "iree/hal/drivers/init.h"
+#endif
+
+#include "iree/modules/hal/hal_module.h"
+#include "iree/vm/api.h"
+#include "iree/vm/bytecode_module.h"
+
+// Compiled module embedded here to avoid file IO:
+#if IREE_ARCH_RISCV_64
+#include "iree/samples/simple_embedding/simple_embedding_test_llvm_aot_rv64.h"
+#else
+#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module_c.h"
+#endif
+
+// Runs "module.simple_mul" from the embedded module on the HAL driver named
+// |hal_driver_name| and checks that every output element equals 8.0f.
+//
+// Returns OK on success; otherwise the failing IREE status, NOT_FOUND if no
+// result buffer view is returned, or UNKNOWN on a value mismatch. Early error
+// returns skip releasing the objects created so far and rely on process exit
+// for cleanup.
+iree_status_t Run(const char* hal_driver_name) {
+  // TODO(benvanik): move to instance-based registration.
+  IREE_RETURN_IF_ERROR(iree_hal_module_register_types());
+
+  iree_vm_instance_t* instance = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_vm_instance_create(iree_allocator_system(), &instance));
+
+#if IREE_ARCH_RISCV_64
+  // Only register dylib HAL driver
+  IREE_RETURN_IF_ERROR(iree_hal_dylib_driver_module_register(
+      iree_hal_driver_registry_default()));
+#else
+  // Register all drivers so it can be selected by the driver name.
+  IREE_RETURN_IF_ERROR(iree_hal_register_all_available_drivers(
+      iree_hal_driver_registry_default()));
+#endif
+
+  // Create the hal driver from the name. The driver name can be assigned as a
+  // hard-coded char array such as "dylib" as well.
+  iree_hal_driver_t* driver = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
+      iree_hal_driver_registry_default(),
+      iree_make_cstring_view(hal_driver_name), iree_allocator_system(),
+      &driver));
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_driver_create_default_device(
+      driver, iree_allocator_system(), &device));
+  iree_vm_module_t* hal_module = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_module_create(device, iree_allocator_system(), &hal_module));
+  iree_hal_driver_release(driver);
+
+  // Load bytecode module from the embedded data.
+#if IREE_ARCH_RISCV_64
+  const struct iree_file_toc_t* module_file_toc =
+      simple_embedding_test_llvm_aot_rv64_create();
+#else
+  // Note the setup here only supports native build. The bytecode is not built
+  // for the cross-compile execution. The code can be compiled but it will
+  // hit runtime error in a cross-compile environment.
+  const struct iree_file_toc_t* module_file_toc =
+      simple_embedding_test_bytecode_module_c_create();
+#endif
+
+  iree_vm_module_t* bytecode_module = NULL;
+  iree_const_byte_span_t module_data = {(const uint8_t*)module_file_toc->data,
+                                        module_file_toc->size};
+  IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
+      module_data, iree_allocator_null(), iree_allocator_system(),
+      &bytecode_module));
+
+  // Allocate a context that will hold the module state across invocations.
+  iree_vm_context_t* context = NULL;
+  iree_vm_module_t* modules[] = {hal_module, bytecode_module};
+  IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
+      instance, &modules[0], IREE_ARRAYSIZE(modules), iree_allocator_system(),
+      &context));
+  iree_vm_module_release(hal_module);
+  iree_vm_module_release(bytecode_module);
+
+  // Lookup the entry point function.
+  // Note that we use the synchronous variant which operates on pure type/shape
+  // erased buffers.
+  const char kMainFunctionName[] = "module.simple_mul";
+  iree_vm_function_t main_function;
+  IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
+      context, iree_make_cstring_view(kMainFunctionName), &main_function));
+
+  // Allocate buffers that can be mapped on the CPU and that can also be used
+  // on the device. Not all devices support this, but the ones we have now do.
+  const int kElementCount = 4;
+  iree_hal_buffer_t* arg0_buffer = NULL;
+  iree_hal_buffer_t* arg1_buffer = NULL;
+  iree_hal_memory_type_t input_memory_type =
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+      iree_hal_device_allocator(device), input_memory_type,
+      IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, &arg0_buffer));
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+      iree_hal_device_allocator(device), input_memory_type,
+      IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, &arg1_buffer));
+
+  // Populate initial values for 4 * 2 = 8.
+  const float kFloat4 = 4.0f;
+  const float kFloat2 = 2.0f;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_fill(arg0_buffer, 0, IREE_WHOLE_BUFFER,
+                                            &kFloat4, sizeof(float)));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_fill(arg1_buffer, 0, IREE_WHOLE_BUFFER,
+                                            &kFloat2, sizeof(float)));
+
+  // Wrap buffers in shaped buffer views.
+  iree_hal_dim_t shape[1] = {kElementCount};
+  iree_hal_buffer_view_t* arg0_buffer_view = NULL;
+  iree_hal_buffer_view_t* arg1_buffer_view = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+      arg0_buffer, IREE_HAL_ELEMENT_TYPE_FLOAT_32, shape, IREE_ARRAYSIZE(shape),
+      &arg0_buffer_view));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+      arg1_buffer, IREE_HAL_ELEMENT_TYPE_FLOAT_32, shape, IREE_ARRAYSIZE(shape),
+      &arg1_buffer_view));
+  iree_hal_buffer_release(arg0_buffer);
+  iree_hal_buffer_release(arg1_buffer);
+
+  // Setup call inputs with our buffers.
+  iree_vm_list_t* inputs = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/NULL,
+                                           /*capacity=*/2,
+                                           iree_allocator_system(), &inputs),
+                       "can't allocate input vm list");
+
+  iree_vm_ref_t arg0_buffer_view_ref =
+      iree_hal_buffer_view_move_ref(arg0_buffer_view);
+  iree_vm_ref_t arg1_buffer_view_ref =
+      iree_hal_buffer_view_move_ref(arg1_buffer_view);
+  IREE_RETURN_IF_ERROR(
+      iree_vm_list_push_ref_move(inputs, &arg0_buffer_view_ref));
+  IREE_RETURN_IF_ERROR(
+      iree_vm_list_push_ref_move(inputs, &arg1_buffer_view_ref));
+
+  // Prepare outputs list to accept the results from the invocation.
+  // The list is created with capacity for the single result buffer view.
+  iree_vm_list_t* outputs = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/NULL,
+                                           /*capacity=*/1,
+                                           iree_allocator_system(), &outputs),
+                       "can't allocate output vm list");
+
+  // Synchronously invoke the function.
+  IREE_RETURN_IF_ERROR(iree_vm_invoke(context, main_function,
+                                      /*policy=*/NULL, inputs, outputs,
+                                      iree_allocator_system()));
+
+  // Get the result buffers from the invocation.
+  iree_hal_buffer_view_t* ret_buffer_view =
+      (iree_hal_buffer_view_t*)iree_vm_list_get_ref_deref(
+          outputs, 0, iree_hal_buffer_view_get_descriptor());
+  if (ret_buffer_view == NULL) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "can't find return buffer view");
+  }
+
+  // Read back the results and ensure we got the right values.
+  iree_hal_buffer_mapping_t mapped_memory;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+      iree_hal_buffer_view_buffer(ret_buffer_view), IREE_HAL_MEMORY_ACCESS_READ,
+      0, IREE_WHOLE_BUFFER, &mapped_memory));
+  // Record a mismatch instead of returning early so the mapping and the
+  // objects below are still released on that path. The index is unsigned to
+  // match the element count computed from data_length.
+  const float* result_data = (const float*)mapped_memory.contents.data;
+  iree_host_size_t result_count =
+      mapped_memory.contents.data_length / sizeof(float);
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < result_count; ++i) {
+    if (result_data[i] != 8.0f) {
+      status = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
+      break;
+    }
+  }
+  iree_hal_buffer_unmap_range(&mapped_memory);
+
+  iree_vm_list_release(inputs);
+  iree_vm_list_release(outputs);
+  iree_hal_device_release(device);
+  iree_vm_context_release(context);
+  iree_vm_instance_release(instance);
+  return status;
+}
+
+// Runs the sample on the HAL driver named by argv[1]; returns 0 on success and
+// -1 on a usage error or failure.
+int main(int argc, char** argv) {
+  if (argc < 2) {
+    fprintf(stderr, "usage: simple_embedding_run <HAL driver name>\n");
+    return -1;
+  }
+  const char* hal_driver_name = argv[1];
+  iree_status_t result = Run(hal_driver_name);
+  if (!iree_status_is_ok(result)) {
+    char* message = NULL;
+    iree_host_size_t message_length = 0;
+    if (iree_status_to_string(result, &message, &message_length)) {
+      fprintf(stderr, "simple_embedding_run failed: %s\n", message);
+      iree_allocator_free(iree_allocator_system(), message);
+    } else {
+      fprintf(stderr, "simple_embedding_run failed\n");
+    }
+    // Free the status payload now that it has been reported.
+    iree_status_ignore(result);
+    return -1;
+  }
+  printf("simple_embedding_run passed\n");
+  return 0;
+}
diff --git a/iree/samples/simple_embedding/simple_embedding_test.cc b/iree/samples/simple_embedding/simple_embedding_test.cc
deleted file mode 100644
index 17db37b..0000000
--- a/iree/samples/simple_embedding/simple_embedding_test.cc
+++ /dev/null
@@ -1,213 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "absl/strings/str_replace.h"
-#include "absl/types/span.h"
-#include "iree/base/api.h"
-#include "iree/base/logging.h"
-#include "iree/hal/api.h"
-#include "iree/hal/dylib/registration/driver_module.h"
-#include "iree/hal/vulkan/registration/driver_module.h"
-#include "iree/modules/hal/hal_module.h"
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-#include "iree/vm/api.h"
-#include "iree/vm/bytecode_module.h"
-#include "iree/vm/ref_cc.h"
-
-// Compiled module embedded here to avoid file IO:
-#include "iree/samples/simple_embedding/simple_embedding_test_bytecode_module.h"
-
-namespace iree {
-namespace samples {
-namespace {
-
-struct TestParams {
- // HAL driver to use for the test.
- std::string driver_name;
-};
-
-std::ostream& operator<<(std::ostream& os, const TestParams& params) {
- return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}});
-}
-
-std::vector<TestParams> GetDriverTestParams() {
- // The test file was compiled for DyLib+Vulkan, so test on each driver.
- std::vector<TestParams> test_params;
-
- IREE_CHECK_OK(iree_hal_dylib_driver_module_register(
- iree_hal_driver_registry_default()));
- TestParams dylib_params;
- dylib_params.driver_name = "dylib";
- test_params.push_back(std::move(dylib_params));
-
- IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(
- iree_hal_driver_registry_default()));
- TestParams vulkan_params;
- vulkan_params.driver_name = "vulkan";
- test_params.push_back(std::move(vulkan_params));
-
- return test_params;
-}
-
-class SimpleEmbeddingTest : public ::testing::Test,
- public ::testing::WithParamInterface<TestParams> {};
-
-TEST_P(SimpleEmbeddingTest, RunOnce) {
- // TODO(benvanik): move to instance-based registration.
- IREE_ASSERT_OK(iree_hal_module_register_types());
-
- iree_vm_instance_t* instance = nullptr;
- IREE_ASSERT_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
-
- // Create the driver/device as defined by the test and setup the HAL module.
- const auto& driver_name = GetParam().driver_name;
- IREE_LOG(INFO) << "Creating driver '" << driver_name << "'...";
- iree_hal_driver_t* driver = nullptr;
- IREE_ASSERT_OK(iree_hal_driver_registry_try_create_by_name(
- iree_hal_driver_registry_default(),
- iree_string_view_t{driver_name.data(), driver_name.size()},
- iree_allocator_system(), &driver));
- iree_hal_device_t* device = nullptr;
- IREE_ASSERT_OK(iree_hal_driver_create_default_device(
- driver, iree_allocator_system(), &device));
- iree_vm_module_t* hal_module = nullptr;
- IREE_ASSERT_OK(
- iree_hal_module_create(device, iree_allocator_system(), &hal_module));
- iree_hal_driver_release(driver);
-
- // Load bytecode module from the embedded data.
- IREE_LOG(INFO) << "Loading simple_module_test.mlir...";
- const auto* module_file_toc = simple_embedding_test_bytecode_module_create();
- iree_vm_module_t* bytecode_module = nullptr;
- IREE_ASSERT_OK(iree_vm_bytecode_module_create(
- iree_const_byte_span_t{
- reinterpret_cast<const uint8_t*>(module_file_toc->data),
- module_file_toc->size},
- iree_allocator_null(), iree_allocator_system(), &bytecode_module));
-
- // Allocate a context that will hold the module state across invocations.
- iree_vm_context_t* context = nullptr;
- std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
- IREE_ASSERT_OK(iree_vm_context_create_with_modules(
- instance, modules.data(), modules.size(), iree_allocator_system(),
- &context));
- IREE_LOG(INFO) << "Module loaded and context is ready for use";
- iree_vm_module_release(hal_module);
- iree_vm_module_release(bytecode_module);
-
- // Lookup the entry point function.
- // Note that we use the synchronous variant which operates on pure type/shape
- // erased buffers.
- const char kMainFunctionName[] = "module.simple_mul";
- iree_vm_function_t main_function;
- IREE_ASSERT_OK(iree_vm_context_resolve_function(
- context, iree_make_cstring_view(kMainFunctionName), &main_function))
- << "Exported function '" << kMainFunctionName << "' not found";
-
- // Allocate buffers that can be mapped on the CPU and that can also be used
- // on the device. Not all devices support this, but the ones we have now do.
- IREE_LOG(INFO) << "Creating I/O buffers...";
- constexpr int kElementCount = 4;
- iree_hal_buffer_t* arg0_buffer = nullptr;
- iree_hal_buffer_t* arg1_buffer = nullptr;
- IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
- iree_hal_device_allocator(device),
- iree_hal_memory_type_t(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
- IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, &arg0_buffer));
- IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
- iree_hal_device_allocator(device),
- iree_hal_memory_type_t(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
- IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
- IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, &arg1_buffer));
-
- // Populate initial values for 4 * 2 = 8.
- float kFloat4 = 4.0f;
- float kFloat2 = 2.0f;
- IREE_ASSERT_OK(iree_hal_buffer_fill(arg0_buffer, 0, IREE_WHOLE_BUFFER,
- &kFloat4, sizeof(float)));
- IREE_ASSERT_OK(iree_hal_buffer_fill(arg1_buffer, 0, IREE_WHOLE_BUFFER,
- &kFloat2, sizeof(float)));
-
- // Wrap buffers in shaped buffer views.
- iree_hal_dim_t shape[1] = {kElementCount};
- iree_hal_buffer_view_t* arg0_buffer_view = nullptr;
- iree_hal_buffer_view_t* arg1_buffer_view = nullptr;
- IREE_ASSERT_OK(iree_hal_buffer_view_create(
- arg0_buffer, IREE_HAL_ELEMENT_TYPE_FLOAT_32, shape, IREE_ARRAYSIZE(shape),
- &arg0_buffer_view));
- IREE_ASSERT_OK(iree_hal_buffer_view_create(
- arg1_buffer, IREE_HAL_ELEMENT_TYPE_FLOAT_32, shape, IREE_ARRAYSIZE(shape),
- &arg1_buffer_view));
- iree_hal_buffer_release(arg0_buffer);
- iree_hal_buffer_release(arg1_buffer);
-
- // Setup call inputs with our buffers.
- // TODO(benvanik): make a macro/magic.
- vm::ref<iree_vm_list_t> inputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 2,
- iree_allocator_system(), &inputs));
- auto arg0_buffer_view_ref = iree_hal_buffer_view_move_ref(arg0_buffer_view);
- auto arg1_buffer_view_ref = iree_hal_buffer_view_move_ref(arg1_buffer_view);
- IREE_ASSERT_OK(
- iree_vm_list_push_ref_move(inputs.get(), &arg0_buffer_view_ref));
- IREE_ASSERT_OK(
- iree_vm_list_push_ref_move(inputs.get(), &arg1_buffer_view_ref));
-
- // Prepare outputs list to accept the results from the invocation.
- vm::ref<iree_vm_list_t> outputs;
- IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 1,
- iree_allocator_system(), &outputs));
-
- // Synchronously invoke the function.
- IREE_LOG(INFO) << "Calling " << kMainFunctionName << "...";
- IREE_ASSERT_OK(iree_vm_invoke(context, main_function,
- /*policy=*/nullptr, inputs.get(), outputs.get(),
- iree_allocator_system()));
-
- // Get the result buffers from the invocation.
- IREE_LOG(INFO) << "Retrieving results...";
- auto* ret_buffer_view =
- reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
- outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
- ASSERT_NE(nullptr, ret_buffer_view);
-
- // Read back the results and ensure we got the right values.
- IREE_LOG(INFO) << "Reading back results...";
- iree_hal_buffer_mapping_t mapped_memory;
- IREE_ASSERT_OK(iree_hal_buffer_map_range(
- iree_hal_buffer_view_buffer(ret_buffer_view), IREE_HAL_MEMORY_ACCESS_READ,
- 0, IREE_WHOLE_BUFFER, &mapped_memory));
- ASSERT_THAT(absl::Span<const float>(
- reinterpret_cast<const float*>(mapped_memory.contents.data),
- mapped_memory.contents.data_length / sizeof(float)),
- ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f}));
- iree_hal_buffer_unmap_range(&mapped_memory);
- IREE_LOG(INFO) << "Results match!";
-
- inputs.reset();
- outputs.reset();
- iree_hal_device_release(device);
- iree_vm_context_release(context);
- iree_vm_instance_release(instance);
-}
-
-INSTANTIATE_TEST_SUITE_P(AllDrivers, SimpleEmbeddingTest,
- ::testing::ValuesIn(GetDriverTestParams()),
- ::testing::PrintToStringParamName());
-
-} // namespace
-} // namespace samples
-} // namespace iree
diff --git a/iree/test/microbenchmarks/mhlo_conv.mlir b/iree/test/microbenchmarks/mhlo_conv.mlir
index 81c1f11..61cf9c5 100644
--- a/iree/test/microbenchmarks/mhlo_conv.mlir
+++ b/iree/test/microbenchmarks/mhlo_conv.mlir
@@ -1,7 +1,12 @@
-// Naming convention @conv_{input-spatial-dim}_{output-spatial-dim}_{filter-size}-{outputsizexinputsize}
-
-// The following ops sampled from MobileVision V1 (mobilenet_v1_100_224)
-// https://github.com/google/iree/blob/main/integrations/tensorflow/e2e/slim_vision_models/slim_vision_model_test.py#L34
+//===----------------------------------------------------------------------===//
+// Pointwise convolution ops.
+// Naming convention: '_'.join(
+// [conv,
+// {input-spatial-dims},
+// {output-spatial-dims},
+// {filter-dims},
+// {kernel-input-features x kernel-output-features}])
+//===----------------------------------------------------------------------===//
func @conv_244_112_3x3_3x32() -> tensor<1x112x112x32xf32> attributes { iree.module.export } {
%input = iree.unfoldable_constant dense<1.0> : tensor<1x224x224x3xf32>
@@ -18,7 +23,7 @@
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
- },
+ },
feature_group_count = 1 : i64,
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
@@ -31,7 +36,7 @@
%input = iree.unfoldable_constant dense<1.0> : tensor<1x112x112x32xf32>
%filter = iree.unfoldable_constant dense<1.0> : tensor<1x1x32x64xf32>
%0 = "mhlo.convolution"(%input, %filter) {
- batch_group_count = 1 : i64,
+ batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
@@ -74,3 +79,86 @@
} : (tensor<1x7x7x1024xf32>, tensor<1x1x1024x1024xf32>) -> tensor<1x7x7x1024xf32>
return %0 : tensor<1x7x7x1024xf32>
}
+
+//===----------------------------------------------------------------------===//
+// Depthwise convolution ops.
+// Naming convention: '_'.join(
+// [depthwise_conv,
+// {input-spatial-dim},
+// {output-spatial-dim},
+// {filter-size},
+// {kernel-input-features x kernel-output-features},
+// {feature-group}])
+//===----------------------------------------------------------------------===//
+
+func @depthwise_conv_15x1_1x1_15x1_1x1024_1024() -> tensor<1x1x1x1024xf32> attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<1.0> : tensor<1x15x1x1024xf32>
+  %filter = iree.unfoldable_constant dense<1.0> : tensor<15x1x1x1024xf32>
+  %res = "mhlo.convolution"(%input, %filter) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1024 : i64,  // one group per input channel
+    padding = dense<0> : tensor<2x2xi64>,  // no padding: 15 - 15 + 1 = 1 output row
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<1x15x1x1024xf32>, tensor<15x1x1x1024xf32>) -> tensor<1x1x1x1024xf32>
+  return %res : tensor<1x1x1x1024xf32>
+}
+
+func @depthwise_conv_15x1_1x1_15x1_1x512_512() -> tensor<1x1x1x512xf32> attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<1.0> : tensor<1x15x1x512xf32>
+  %filter = iree.unfoldable_constant dense<1.0> : tensor<15x1x1x512xf32>
+  %res = "mhlo.convolution"(%input, %filter) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 512 : i64,  // one group per input channel
+    padding = dense<0> : tensor<2x2xi64>,  // no padding: 15 - 15 + 1 = 1 output row
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<1x15x1x512xf32>, tensor<15x1x1x512xf32>) -> tensor<1x1x1x512xf32>
+  return %res : tensor<1x1x1x512xf32>
+}
+
+func @depthwise_conv_16x1_2x1_15x1_1x512_512() -> tensor<1x2x1x512xf32> attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<1.0> : tensor<1x16x1x512xf32>
+  %filter = iree.unfoldable_constant dense<1.0> : tensor<15x1x1x512xf32>
+  %res = "mhlo.convolution"(%input, %filter) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 512 : i64,  // one group per input channel
+    padding = dense<0> : tensor<2x2xi64>,  // no padding: 16 - 15 + 1 = 2 output rows
+    rhs_dilation = dense<1> : tensor<2xi64>,
+    window_strides = dense<1> : tensor<2xi64>
+  } : (tensor<1x16x1x512xf32>, tensor<15x1x1x512xf32>) -> tensor<1x2x1x512xf32>
+  return %res : tensor<1x2x1x512xf32>
+}
diff --git a/iree/test/microbenchmarks/mhlo_dot.mlir b/iree/test/microbenchmarks/mhlo_dot.mlir
index 4abb01d..a3614ee 100644
--- a/iree/test/microbenchmarks/mhlo_dot.mlir
+++ b/iree/test/microbenchmarks/mhlo_dot.mlir
@@ -1,5 +1,6 @@
-// The following ops are sampled from mobile_bert
-// https://github.com/google/iree/blob/main/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+//===----------------------------------------------------------------------===//
+// O(N^3) matmul ops.
+//===----------------------------------------------------------------------===//
func @dot_384x384x512() -> tensor<384x512xf32> attributes { iree.module.export } {
%lhs = iree.unfoldable_constant dense<1.0> : tensor<384x384xf32>
@@ -30,7 +31,7 @@
}
func @dot_384x512x2() -> tensor<384x2xf32> attributes { iree.module.export } {
- %lhs = iree.unfoldable_constant dense<1.0> : tensor<384x512xf32>
+ %lhs = iree.unfoldable_constant dense<1.0> : tensor<384x512xf32>
%rhs = iree.unfoldable_constant dense<1.0> : tensor<512x2xf32>
%0 = "mhlo.dot"(%lhs, %rhs) : (tensor<384x512xf32>, tensor<512x2xf32>) -> tensor<384x2xf32>
return %0 : tensor<384x2xf32>
@@ -42,3 +43,133 @@
%0 = "mhlo.dot"(%lhs, %rhs) : (tensor<384x384xf32>, tensor<384x32xf32>) -> tensor<384x32xf32>
return %0 : tensor<384x32xf32>
}
+
+//===----------------------------------------------------------------------===//
+// O(N^2) matmul ops (small-M, matrix-vector-like shapes).
+//===----------------------------------------------------------------------===//
+
+func @dot_1x1024x1024() -> tensor<1x1024xf32> attributes { iree.module.export } {  // M=1, K=1024, N=1024
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x1024xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<1024x1024xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32>
+ return %y : tensor<1x1024xf32>
+}
+
+func @dot_1x1024x2048() -> tensor<1x2048xf32> attributes { iree.module.export } {  // M=1, K=1024, N=2048
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x1024xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<1024x2048xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x1024xf32>, tensor<1024x2048xf32>) -> tensor<1x2048xf32>
+ return %y : tensor<1x2048xf32>
+}
+
+func @dot_1x1024x3072() -> tensor<1x3072xf32> attributes { iree.module.export } {  // M=1, K=1024, N=3072
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x1024xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<1024x3072xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x1024xf32>, tensor<1024x3072xf32>) -> tensor<1x3072xf32>
+ return %y : tensor<1x3072xf32>
+}
+
+func @dot_1x1024x512() -> tensor<1x512xf32> attributes { iree.module.export } {  // M=1, K=1024, N=512
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x1024xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<1024x512xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x1024xf32>, tensor<1024x512xf32>) -> tensor<1x512xf32>
+ return %y : tensor<1x512xf32>
+}
+
+func @dot_1x128x2() -> tensor<1x2xf32> attributes { iree.module.export } {  // M=1, K=128, N=2
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x128xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<128x2xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x128xf32>, tensor<128x2xf32>) -> tensor<1x2xf32>
+ return %y : tensor<1x2xf32>
+}
+
+func @dot_1x256x512() -> tensor<1x512xf32> attributes { iree.module.export } {  // M=1, K=256, N=512
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x256xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<256x512xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x256xf32>, tensor<256x512xf32>) -> tensor<1x512xf32>
+ return %y : tensor<1x512xf32>
+}
+
+func @dot_1x3072x1024() -> tensor<1x1024xf32> attributes { iree.module.export } {  // M=1, K=3072, N=1024
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x3072xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<3072x1024xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x3072xf32>, tensor<3072x1024xf32>) -> tensor<1x1024xf32>
+ return %y : tensor<1x1024xf32>
+}
+
+func @dot_1x3072x512() -> tensor<1x512xf32> attributes { iree.module.export } {  // M=1, K=3072, N=512
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x3072xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<3072x512xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x3072xf32>, tensor<3072x512xf32>) -> tensor<1x512xf32>
+ return %y : tensor<1x512xf32>
+}
+
+func @dot_1x512x1024() -> tensor<1x1024xf32> attributes { iree.module.export } {  // M=1, K=512, N=1024
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x512xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<512x1024xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x512xf32>, tensor<512x1024xf32>) -> tensor<1x1024xf32>
+ return %y : tensor<1x1024xf32>
+}
+
+func @dot_1x512x3072() -> tensor<1x3072xf32> attributes { iree.module.export } {  // M=1, K=512, N=3072
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x512xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<512x3072xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x512xf32>, tensor<512x3072xf32>) -> tensor<1x3072xf32>
+ return %y : tensor<1x3072xf32>
+}
+
+func @dot_1x512x512() -> tensor<1x512xf32> attributes { iree.module.export } {  // M=1, K=512, N=512
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x512xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<512x512xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x512xf32>, tensor<512x512xf32>) -> tensor<1x512xf32>
+ return %y : tensor<1x512xf32>
+}
+
+func @dot_1x528x128() -> tensor<1x128xf32> attributes { iree.module.export } {  // M=1, K=528, N=128
+ %x = iree.unfoldable_constant dense<1.0> : tensor<1x528xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<528x128xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<1x528xf32>, tensor<528x128xf32>) -> tensor<1x128xf32>
+ return %y : tensor<1x128xf32>
+}
+
+func @dot_2x3072x512() -> tensor<2x512xf32> attributes { iree.module.export } {  // M=2, K=3072, N=512
+ %x = iree.unfoldable_constant dense<1.0> : tensor<2x3072xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<3072x512xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<2x3072xf32>, tensor<3072x512xf32>) -> tensor<2x512xf32>
+ return %y : tensor<2x512xf32>
+}
+
+func @dot_2x512x1024() -> tensor<2x1024xf32> attributes { iree.module.export } {  // M=2, K=512, N=1024
+ %x = iree.unfoldable_constant dense<1.0> : tensor<2x512xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<512x1024xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<2x512xf32>, tensor<512x1024xf32>) -> tensor<2x1024xf32>
+ return %y : tensor<2x1024xf32>
+}
+
+func @dot_2x512x3072() -> tensor<2x3072xf32> attributes { iree.module.export } {  // M=2, K=512, N=3072
+ %x = iree.unfoldable_constant dense<1.0> : tensor<2x512xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<512x3072xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<2x512xf32>, tensor<512x3072xf32>) -> tensor<2x3072xf32>
+ return %y : tensor<2x3072xf32>
+}
+
+func @dot_2x512x512() -> tensor<2x512xf32> attributes { iree.module.export } {  // M=2, K=512, N=512
+ %x = iree.unfoldable_constant dense<1.0> : tensor<2x512xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<512x512xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<2x512xf32>, tensor<512x512xf32>) -> tensor<2x512xf32>
+ return %y : tensor<2x512xf32>
+}
+
+func @dot_2x528x512() -> tensor<2x512xf32> attributes { iree.module.export } {  // M=2, K=528, N=512
+ %x = iree.unfoldable_constant dense<1.0> : tensor<2x528xf32>
+ %w = iree.unfoldable_constant dense<1.0> : tensor<528x512xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<2x528xf32>, tensor<528x512xf32>) -> tensor<2x512xf32>
+ return %y : tensor<2x512xf32>
+}
+
+func @dot_6x513x128() -> tensor<6x128xf32> attributes { iree.module.export } {  // M=6, K=513, N=128
+ %x = iree.unfoldable_constant dense<1.0> : tensor<6x513xf32>  // presumably mirrors @rfft_abs_6x1024's 6x513 output — confirm
+ %w = iree.unfoldable_constant dense<1.0> : tensor<513x128xf32>
+ %y = "mhlo.dot"(%x, %w) : (tensor<6x513xf32>, tensor<513x128xf32>) -> tensor<6x128xf32>
+ return %y : tensor<6x128xf32>
+}
diff --git a/iree/test/microbenchmarks/mhlo_fft_abs.mlir b/iree/test/microbenchmarks/mhlo_fft_abs.mlir
new file mode 100644
index 0000000..1286151
--- /dev/null
+++ b/iree/test/microbenchmarks/mhlo_fft_abs.mlir
@@ -0,0 +1,13 @@
+//===----------------------------------------------------------------------===//
+// Real FFT (RFFT) followed by abs (magnitude) ops.
+//===----------------------------------------------------------------------===//
+
+func @rfft_abs_6x1024() -> tensor<6x513xf32> attributes { iree.module.export } {
+ %signal = iree.unfoldable_constant dense<1.0> : tensor<6x1024xf32>  // 6 rows of 1024 real samples
+ %spectrum = "mhlo.fft"(%signal) {
+ fft_length = dense<1024> : tensor<1xi64>,
+ fft_type = "RFFT"
+ } : (tensor<6x1024xf32>) -> tensor<6x513xcomplex<f32>>  // real FFT keeps 1024 / 2 + 1 = 513 bins
+ %magnitude = "mhlo.abs"(%spectrum) : (tensor<6x513xcomplex<f32>>) -> tensor<6x513xf32>
+ return %magnitude : tensor<6x513xf32>
+}
diff --git a/scripts/check_tabs.sh b/scripts/check_tabs.sh
index bdb1eac..683e9c0 100755
--- a/scripts/check_tabs.sh
+++ b/scripts/check_tabs.sh
@@ -29,6 +29,7 @@
"\.fb$"
"\.jar$"
"\.so$"
+ "\.vmfb$"
)
# Join on |