Merge google -> main (#3050)
* 2730f436 Merge pull request #3048 from GMNGeoffrey:main-to-google
* 7f8f2c51 Opt-in to the global dialect registry
* 051a9e2f Synchronize submodules
* 11d565ee Integrate LLVM at llvm/llvm-project@ffd0b31c7cba
* 833983bc Synchronize submodules
* 5ae0d624 Integrate LLVM at llvm/llvm-project@1d3d9b9cd808
* 80c32bdf Synchronize submodules
* 731139cf Integrate LLVM at llvm/llvm-project@646f19bb9dc8
* de995a98 Synchronize submodules
* 703e3782 Integrate LLVM at llvm/llvm-project@bc3d4d9ed783
* 92689b8d Adds support for invoking a function through the java api
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 0bb7bdb..ceb9894 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -4,7 +4,7 @@
a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-1d01fc100bb5bef5f5eaf92520b2e52f64ee1d6e third_party/llvm-project
+ffd0b31c7cbaa8322d2963afe6ace4e3a0889ddb third_party/llvm-project
17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
560cd8c94495409cbeacb31c0b907b5068adc45b third_party/mlir-emitc
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
diff --git a/bindings/java/com/google/iree/Context.java b/bindings/java/com/google/iree/Context.java
index 2b42c79..8201c1b 100644
--- a/bindings/java/com/google/iree/Context.java
+++ b/bindings/java/com/google/iree/Context.java
@@ -16,6 +16,7 @@
package com.google.iree;
+import java.nio.FloatBuffer;
import java.util.List;
/** An isolated execution context. */
@@ -63,6 +64,17 @@
return function;
}
+  /**
+   * Synchronously invokes {@code function} and copies its first result into {@code output}.
+   *
+   * <p>The native side reads and writes the buffers' memory directly, so every buffer must be
+   * direct (e.g. {@code ByteBuffer.allocateDirect(...).asFloatBuffer()}) and each input must hold
+   * at least {@code inputElementCount} floats. Buffer positions are ignored: data is read from and
+   * written to the start of each buffer.
+   *
+   * @param function function to invoke, e.g. as returned by {@link #resolveFunction}
+   * @param inputs one buffer per function argument
+   * @param inputElementCount number of floats passed from each input buffer
+   * @param output direct buffer that receives the function's result
+   * @throws IllegalArgumentException if a buffer is null or not direct, or an input is too small
+   * @throws Exception if the native invocation reports a non-OK status
+   */
+  public void invokeFunction(
+      Function function, FloatBuffer[] inputs, int inputElementCount, FloatBuffer output)
+      throws Exception {
+    if (inputElementCount < 0) {
+      throw new IllegalArgumentException("inputElementCount must be non-negative");
+    }
+    if (inputs == null) {
+      throw new IllegalArgumentException("inputs must not be null");
+    }
+    // Validate up front: native code would otherwise dereference a null address (non-direct
+    // buffer) or read past the end of an undersized input.
+    for (int i = 0; i < inputs.length; i++) {
+      FloatBuffer input = inputs[i];
+      if (input == null || !input.isDirect()) {
+        throw new IllegalArgumentException("inputs[" + i + "] must be a non-null direct buffer");
+      }
+      if (input.capacity() < inputElementCount) {
+        throw new IllegalArgumentException(
+            "inputs[" + i + "] holds " + input.capacity() + " floats but "
+                + inputElementCount + " are required");
+      }
+    }
+    if (output == null || !output.isDirect()) {
+      throw new IllegalArgumentException("output must be a non-null direct buffer");
+    }
+
+    Status status =
+        Status.fromCode(
+            nativeInvokeFunction(function.getNativeAddress(), inputs, inputElementCount, output));
+    if (!status.isOk()) {
+      throw status.toException("Could not invoke function");
+    }
+  }
+
public int getId() {
return nativeGetId();
}
@@ -93,6 +105,10 @@
private native int nativeResolveFunction(long functionAddress, String name);
+ // Native half of invokeFunction; the returned int is a status code passed to Status.fromCode.
+ // TODO(jennik): 'output' should be a FloatBuffer[].
+ private native int nativeInvokeFunction(
+ long functionAddress, FloatBuffer[] inputs, int inputElementCount, FloatBuffer output);
+
private native void nativeFree();
private native int nativeGetId();
diff --git a/bindings/java/com/google/iree/native/context_jni.cc b/bindings/java/com/google/iree/native/context_jni.cc
index 1234947..0cdcce2 100644
--- a/bindings/java/com/google/iree/native/context_jni.cc
+++ b/bindings/java/com/google/iree/native/context_jni.cc
@@ -118,6 +118,30 @@
return (jint)status.code();
}
+// JNI entry point for Context.nativeInvokeFunction.
+//
+// Resolves the native addresses of the direct FloatBuffers in |inputs| and
+// |output| and forwards them to ContextWrapper::InvokeFunction. Returns the
+// resulting status code. If a buffer is null or not direct, an
+// IllegalArgumentException is left pending instead; the JVM raises it in the
+// Java caller and discards the native return value.
+JNI_FUNC jint JNI_PREFIX(nativeInvokeFunction)(JNIEnv* env, jobject thiz,
+                                               jlong functionAddress,
+                                               jobjectArray inputs,
+                                               jint inputElementCount,
+                                               jobject output) {
+  ContextWrapper* context = GetContextWrapper(env, thiz);
+  CHECK_NE(context, nullptr);
+
+  // Leaves an IllegalArgumentException pending. The returned value is only a
+  // placeholder because the pending exception takes precedence in Java.
+  auto throw_illegal_argument = [env](const char* message) -> jint {
+    jclass exception_class =
+        env->FindClass("java/lang/IllegalArgumentException");
+    if (exception_class != nullptr) {
+      env->ThrowNew(exception_class, message);
+      env->DeleteLocalRef(exception_class);
+    }
+    return 0;
+  };
+
+  if (inputs == nullptr || output == nullptr) {
+    return throw_illegal_argument("inputs and output must not be null");
+  }
+
+  const jsize inputs_size = env->GetArrayLength(inputs);
+  std::vector<float*> native_inputs(inputs_size);
+  for (jsize i = 0; i < inputs_size; ++i) {
+    jobject input = env->GetObjectArrayElement(inputs, i);
+    void* address = nullptr;
+    if (input != nullptr) {
+      // Null for buffers that are not direct.
+      address = env->GetDirectBufferAddress(input);
+      // Drop the local reference immediately so large arrays cannot exhaust
+      // the local reference table. The buffer stays reachable through
+      // |inputs|, so its address remains valid for this call.
+      env->DeleteLocalRef(input);
+    }
+    if (address == nullptr) {
+      return throw_illegal_argument("every input must be a direct buffer");
+    }
+    native_inputs[i] = static_cast<float*>(address);
+  }
+
+  void* output_address = env->GetDirectBufferAddress(output);
+  if (output_address == nullptr) {
+    return throw_illegal_argument("output must be a direct buffer");
+  }
+
+  auto* function = reinterpret_cast<FunctionWrapper*>(functionAddress);
+  auto status = context->InvokeFunction(
+      *function, native_inputs, static_cast<int>(inputElementCount),
+      static_cast<float*>(output_address));
+  return static_cast<jint>(status.code());
+}
+
JNI_FUNC jint JNI_PREFIX(nativeGetId)(JNIEnv* env, jobject thiz) {
ContextWrapper* context = GetContextWrapper(env, thiz);
CHECK_NE(context, nullptr);
diff --git a/bindings/java/com/google/iree/native/context_wrapper.cc b/bindings/java/com/google/iree/native/context_wrapper.cc
index 10f2726..7eb5753 100644
--- a/bindings/java/com/google/iree/native/context_wrapper.cc
+++ b/bindings/java/com/google/iree/native/context_wrapper.cc
@@ -14,7 +14,9 @@
#include "bindings/java/com/google/iree/native/context_wrapper.h"
+#include "iree/base/api.h"
#include "iree/base/logging.h"
+#include "iree/vm/ref_cc.h"
namespace iree {
namespace java {
@@ -72,6 +74,72 @@
function_wrapper->function());
}
+// Invokes |function_wrapper| synchronously.
+//
+// Each entry of |inputs| is copied into a host-local, device-visible buffer and
+// passed as a rank-1 f32 buffer view of |input_element_count| elements. The
+// contents of the function's first result buffer are then copied into
+// |output|.
+//
+// |output| carries no size, so at most |input_element_count| floats are
+// written to it. That is the size the Java bindings and tests allocate. A
+// larger result is truncated with a warning rather than overflowing the
+// caller's buffer.
+// NOTE(review): an explicit output capacity would remove this assumption.
+Status ContextWrapper::InvokeFunction(const FunctionWrapper& function_wrapper,
+                                      const std::vector<float*>& inputs,
+                                      int input_element_count, float* output) {
+  if (input_element_count < 0 || output == nullptr) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "InvokeFunction requires a non-negative element count and a "
+              "non-null output";
+  }
+  // Byte size of each input, and the maximum number of bytes written to
+  // |output|.
+  const iree_host_size_t io_byte_length =
+      sizeof(float) * static_cast<iree_host_size_t>(input_element_count);
+
+  // The input list holds one buffer view per input, so size it by the
+  // number of inputs rather than by the per-input element count.
+  vm::ref<iree_vm_list_t> input_list;
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
+                                           inputs.size(),
+                                           iree_allocator_system(),
+                                           &input_list));
+
+  iree_hal_allocator_t* allocator = iree_hal_device_allocator(device_);
+  iree_hal_memory_type_t input_memory_type =
+      static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                                          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
+  iree_hal_buffer_usage_t input_buffer_usage =
+      static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
+                                           IREE_HAL_BUFFER_USAGE_CONSTANT);
+
+  for (float* input : inputs) {
+    // Write the input into a mappable buffer.
+    iree_hal_buffer_t* input_buffer = nullptr;
+    IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+        allocator, input_memory_type, input_buffer_usage, io_byte_length,
+        &input_buffer));
+    iree_status_t status =
+        iree_hal_buffer_write_data(input_buffer, 0, input, io_byte_length);
+
+    // Wrap the input buffer in a buffer view.
+    iree_hal_buffer_view_t* input_buffer_view = nullptr;
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_buffer_view_create(
+          input_buffer, /*shape=*/&input_element_count,
+          /*shape_rank=*/1, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+          iree_allocator_system(), &input_buffer_view);
+    }
+    // Drop our reference on every path, including failures, so the buffer
+    // cannot leak. A successfully created view keeps the buffer alive, as the
+    // release-after-create ordering relies on.
+    iree_hal_buffer_release(input_buffer);
+    IREE_RETURN_IF_ERROR(status);
+
+    // Marshal the input buffer view through the input VM variant list.
+    auto input_buffer_view_ref =
+        iree_hal_buffer_view_move_ref(input_buffer_view);
+    IREE_RETURN_IF_ERROR(
+        iree_vm_list_push_ref_move(input_list.get(), &input_buffer_view_ref));
+  }
+
+  // Prepare outputs list to accept results from the invocation.
+  // NOTE(review): this argument is an element capacity, yet 4 * sizeof(float)
+  // reads like a byte size. It is kept as-is until it is confirmed that the
+  // list grows on demand.
+  vm::ref<iree_vm_list_t> outputs;
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
+                                           4 * sizeof(float),
+                                           iree_allocator_system(), &outputs));
+
+  // Synchronously invoke the function.
+  IREE_RETURN_IF_ERROR(iree_vm_invoke(context_, *function_wrapper.function(),
+                                      /*policy=*/nullptr, input_list.get(),
+                                      outputs.get(), iree_allocator_system()));
+
+  // Read back the results into the given output buffer. Guard against a
+  // missing or non-buffer-view first result instead of dereferencing null.
+  auto* output_buffer_view =
+      reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
+          outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
+  if (output_buffer_view == nullptr) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "function did not return a buffer view as its first result";
+  }
+  auto* output_buffer = iree_hal_buffer_view_buffer(output_buffer_view);
+  iree_hal_mapped_memory_t mapped_memory;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map(output_buffer,
+                                           IREE_HAL_MEMORY_ACCESS_READ, 0,
+                                           IREE_WHOLE_BUFFER, &mapped_memory));
+  iree_host_size_t copy_length = mapped_memory.contents.data_length;
+  if (copy_length > io_byte_length) {
+    LOG(WARNING) << "InvokeFunction result is " << copy_length
+                 << " bytes but output holds " << io_byte_length
+                 << "; truncating";
+    copy_length = io_byte_length;
+  }
+  memcpy(output, mapped_memory.contents.data, copy_length);
+  iree_hal_buffer_unmap(output_buffer, &mapped_memory);
+  return OkStatus();
+}
+
int ContextWrapper::id() const { return iree_vm_context_id(context_); }
ContextWrapper::~ContextWrapper() {
diff --git a/bindings/java/com/google/iree/native/context_wrapper.h b/bindings/java/com/google/iree/native/context_wrapper.h
index 7b0cb88..dfbdfe7 100644
--- a/bindings/java/com/google/iree/native/context_wrapper.h
+++ b/bindings/java/com/google/iree/native/context_wrapper.h
@@ -38,6 +38,11 @@
Status ResolveFunction(iree_string_view_t name,
FunctionWrapper* function_wrapper);
+ // Synchronously invokes |function_wrapper|. Each pointer in |inputs| is read
+ // as |input_element_count| floats and passed as one argument; the contents of
+ // the function's first result are copied into |output|.
+ // NOTE(review): |output| has no size parameter; callers must make it large
+ // enough for the result (currently sized to |input_element_count| floats).
+ // TODO(jennik): Support other input types aside from floats.
+ Status InvokeFunction(const FunctionWrapper& function_wrapper,
+ const std::vector<float*>& inputs,
+ int input_element_count, float* output);
+
int id() const;
~ContextWrapper();
diff --git a/bindings/javatests/com/google/iree/IntegrationTest.java b/bindings/javatests/com/google/iree/IntegrationTest.java
index ea653ae..85fca5f 100644
--- a/bindings/javatests/com/google/iree/IntegrationTest.java
+++ b/bindings/javatests/com/google/iree/IntegrationTest.java
@@ -16,17 +16,22 @@
package com.google.iree;
+import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.fail;
import android.content.Context;
import android.content.res.Resources;
+import android.util.Log;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.junit.Test;
@@ -34,6 +39,8 @@
@RunWith(AndroidJUnit4.class)
public final class IntegrationTest {
+ private static final String TAG = IntegrationTest.class.getCanonicalName();
+
@Test
public void throwsExceptionWithoutNativeLib() throws Exception {
try {
@@ -65,7 +72,29 @@
Function function = ireeContext.resolveFunction(functionName);
function.printDebugString();
- // TODO(jennik): Invoke the function.
+ int elementCount = 4;
+ FloatBuffer x = ByteBuffer.allocateDirect(elementCount * /*sizeof(float)=*/4)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer()
+ .put(new float[] {4.0f, 4.0f, 4.0f, 4.0f});
+ FloatBuffer y = ByteBuffer.allocateDirect(elementCount * /*sizeof(float)=*/4)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer()
+ .put(new float[] {2.0f, 2.0f, 2.0f, 2.0f});
+ FloatBuffer[] inputs = {x, y};
+
+ // TODO(jennik): Allocate outputs in C++ rather than here.
+ FloatBuffer outputBuffer = ByteBuffer.allocateDirect(elementCount * /*sizeof(float)=*/4)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer()
+ .put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
+ ireeContext.invokeFunction(function, inputs, elementCount, outputBuffer);
+
+ float[] output = new float[elementCount];
+ outputBuffer.position(0);
+ outputBuffer.get(output);
+ Log.d(TAG, "Output: " + Arrays.toString(output));
+ assertArrayEquals(new float[] {8.0f, 8.0f, 8.0f, 8.0f}, output, 0.1f);
function.free();
module.free();
@@ -96,6 +125,29 @@
Function function = ireeContext.resolveFunction(functionName);
function.printDebugString();
+ int elementCount = 4;
+ FloatBuffer x = ByteBuffer.allocateDirect(elementCount * /*sizeof(float)=*/4)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer()
+ .put(new float[] {4.0f, 4.0f, 4.0f, 4.0f});
+ FloatBuffer y = ByteBuffer.allocateDirect(elementCount * /*sizeof(float)=*/4)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer()
+ .put(new float[] {2.0f, 2.0f, 2.0f, 2.0f});
+ FloatBuffer[] inputs = {x, y};
+
+ FloatBuffer outputBuffer = ByteBuffer.allocateDirect(elementCount * /*sizeof(float)=*/4)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer()
+ .put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
+ ireeContext.invokeFunction(function, inputs, elementCount, outputBuffer);
+
+ float[] output = new float[elementCount];
+ outputBuffer.position(0);
+ outputBuffer.get(output);
+ Log.d(TAG, "Output: " + Arrays.toString(output));
+ assertArrayEquals(new float[] {8.0f, 8.0f, 8.0f, 8.0f}, output, 0.1f);
+
function.free();
module.free();
ireeContext.free();
diff --git a/bindings/javatests/com/google/iree/integration_test.cc b/bindings/javatests/com/google/iree/integration_test.cc
index eac0905..d10d3c7 100644
--- a/bindings/javatests/com/google/iree/integration_test.cc
+++ b/bindings/javatests/com/google/iree/integration_test.cc
@@ -65,6 +65,24 @@
LOG(INFO) << "Function name: "
<< std::string(function.name().data, function.name().size);
+  float input_x[] = {2.0f, 2.0f, 2.0f, 2.0f};
+  float input_y[] = {4.0f, 4.0f, 4.0f, 4.0f};
+  std::vector<float*> input{input_x, input_y};
+  float output[4] = {0.0f, 1.0f, 2.0f, 3.0f};
+  int element_count = 4;
+
+  // Check the invocation's own status (not an earlier step's) so a failed
+  // InvokeFunction is reported instead of printing the untouched output.
+  auto invoke_status =
+      context->InvokeFunction(function, input, element_count, output);
+  if (!invoke_status.ok()) {
+    LOG(ERROR) << "Invoke function error: " << invoke_status.code();
+    return 1;
+  }
+
+  LOG(INFO) << "Function output:";
+  for (int i = 0; i < element_count; i++) {
+    LOG(INFO) << output[i];
+  }
+
return 0;
}
diff --git a/bindings/python/pyiree/compiler/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
index 6c189f8..2bae670 100644
--- a/bindings/python/pyiree/compiler/compiler.cc
+++ b/bindings/python/pyiree/compiler/compiler.cc
@@ -73,6 +73,7 @@
llvm::sys::DefaultOneShotPipeSignalHandler);
llvm::sys::PrintStackTraceOnErrorSignal("pyiree");
+ mlir::enableGlobalDialectRegistry(true);
// Register built-in MLIR dialects.
mlir::registerMlirDialects();
diff --git a/iree/samples/custom_modules/dialect/custom_opt.cc b/iree/samples/custom_modules/dialect/custom_opt.cc
index be1891a..77969ad 100644
--- a/iree/samples/custom_modules/dialect/custom_opt.cc
+++ b/iree/samples/custom_modules/dialect/custom_opt.cc
@@ -74,6 +74,7 @@
llvm::cl::init(false));
int main(int argc, char **argv) {
+ mlir::enableGlobalDialectRegistry(true);
mlir::DialectRegistry registry;
mlir::registerMlirDialects(registry);
mlir::registerMlirPasses();
diff --git a/iree/samples/custom_modules/dialect/custom_translate.cc b/iree/samples/custom_modules/dialect/custom_translate.cc
index 1cff7e5..d913857 100644
--- a/iree/samples/custom_modules/dialect/custom_translate.cc
+++ b/iree/samples/custom_modules/dialect/custom_translate.cc
@@ -56,6 +56,7 @@
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
+ mlir::enableGlobalDialectRegistry(true);
mlir::registerMlirDialects();
mlir::registerXLADialects();
diff --git a/iree/tools/opt_main.cc b/iree/tools/opt_main.cc
index 281eae8..5e6af4b 100644
--- a/iree/tools/opt_main.cc
+++ b/iree/tools/opt_main.cc
@@ -76,6 +76,7 @@
llvm::cl::init(false));
int main(int argc, char **argv) {
+ mlir::enableGlobalDialectRegistry(true);
mlir::DialectRegistry registry;
mlir::registerMlirDialects(registry);
mlir::registerMlirPasses();
diff --git a/iree/tools/run_mlir_main.cc b/iree/tools/run_mlir_main.cc
index 289ac5f..bf2b5c5 100644
--- a/iree/tools/run_mlir_main.cc
+++ b/iree/tools/run_mlir_main.cc
@@ -490,6 +490,7 @@
}
}
+ mlir::enableGlobalDialectRegistry(true);
mlir::registerMlirDialects();
mlir::iree_compiler::registerIreeDialects();
mlir::iree_compiler::registerIreeCompilerModuleDialects();
diff --git a/iree/tools/translate_main.cc b/iree/tools/translate_main.cc
index 3ecd201..6f251ed 100644
--- a/iree/tools/translate_main.cc
+++ b/iree/tools/translate_main.cc
@@ -59,6 +59,7 @@
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
+ mlir::enableGlobalDialectRegistry(true);
mlir::registerMlirDialects();
#ifdef IREE_HAVE_EMITC_DIALECT
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 1d01fc1..ffd0b31 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 1d01fc100bb5bef5f5eaf92520b2e52f64ee1d6e
+Subproject commit ffd0b31c7cbaa8322d2963afe6ace4e3a0889ddb