Merge google -> main (#3050)

* 2730f436 Merge pull request #3048 from GMNGeoffrey:main-to-google
* 7f8f2c51 Opt-in to the global dialect registry
* 051a9e2f Synchronize submodules
* 11d565ee Integrate LLVM at llvm/llvm-project@ffd0b31c7cba
* 833983bc Synchronize submodules
* 5ae0d624 Integrate LLVM at llvm/llvm-project@1d3d9b9cd808
* 80c32bdf Synchronize submodules
* 731139cf Integrate LLVM at llvm/llvm-project@646f19bb9dc8
* de995a98 Synchronize submodules
* 703e3782 Integrate LLVM at llvm/llvm-project@bc3d4d9ed783
* 92689b8d Adds support for invoking a function through the java api
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 0bb7bdb..ceb9894 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -4,7 +4,7 @@
 a5d9d0f7d368054fd1691aedf1db4116efcc233e third_party/flatbuffers
 4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
 f2fb48c3b3d79a75a88a99fba6576b25d42ec528 third_party/googletest
-1d01fc100bb5bef5f5eaf92520b2e52f64ee1d6e third_party/llvm-project
+ffd0b31c7cbaa8322d2963afe6ace4e3a0889ddb third_party/llvm-project
 17b12a4481daa150e2d1ea3ada086b551b856707 third_party/marl
 560cd8c94495409cbeacb31c0b907b5068adc45b third_party/mlir-emitc
 d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
diff --git a/bindings/java/com/google/iree/Context.java b/bindings/java/com/google/iree/Context.java
index 2b42c79..8201c1b 100644
--- a/bindings/java/com/google/iree/Context.java
+++ b/bindings/java/com/google/iree/Context.java
@@ -16,6 +16,7 @@
 
 package com.google.iree;
 
+import java.nio.FloatBuffer;
 import java.util.List;
 
 /** An isolated execution context. */
@@ -63,6 +64,42 @@
     return function;
   }
 
+  /**
+   * Synchronously invokes {@code function} and copies its first result into {@code output}.
+   *
+   * <p>Every buffer must be direct: the native code reads and writes them through their native
+   * addresses, starting at index 0 regardless of the buffer's current position.
+   *
+   * @param function the function to invoke, as returned by {@link #resolveFunction}
+   * @param inputs input buffers, each holding at least {@code inputElementCount} floats
+   * @param inputElementCount number of floats read from each input buffer
+   * @param output buffer that receives the function's first result
+   * @throws IllegalArgumentException if a buffer is not direct or an input is too small
+   * @throws Exception if the native invocation returns a non-OK status
+   */
+  public void invokeFunction(
+      Function function, FloatBuffer[] inputs, int inputElementCount, FloatBuffer output)
+      throws Exception {
+    for (FloatBuffer input : inputs) {
+      if (!input.isDirect()) {
+        throw new IllegalArgumentException("Input buffers must be direct");
+      }
+      if (input.capacity() < inputElementCount) {
+        throw new IllegalArgumentException(
+            "Input buffer holds " + input.capacity() + " floats, expected " + inputElementCount);
+      }
+    }
+    if (!output.isDirect()) {
+      throw new IllegalArgumentException("Output buffer must be direct");
+    }
+    Status status =
+        Status.fromCode(
+            nativeInvokeFunction(function.getNativeAddress(), inputs, inputElementCount, output));
+    if (!status.isOk()) {
+      throw status.toException("Could not invoke function");
+    }
+  }
+
   public int getId() {
     return nativeGetId();
   }
@@ -93,6 +130,10 @@
 
   private native int nativeResolveFunction(long functionAddress, String name);
 
+  // TODO(jennik): 'output' should be a FloatBuffer[].
+  private native int nativeInvokeFunction(
+      long functionAddress, FloatBuffer[] inputs, int inputElementCount, FloatBuffer output);
+
   private native void nativeFree();
 
   private native int nativeGetId();
diff --git a/bindings/java/com/google/iree/native/context_jni.cc b/bindings/java/com/google/iree/native/context_jni.cc
index 1234947..0cdcce2 100644
--- a/bindings/java/com/google/iree/native/context_jni.cc
+++ b/bindings/java/com/google/iree/native/context_jni.cc
@@ -118,6 +118,47 @@
   return (jint)status.code();
 }
 
+// Invokes the function at `functionAddress` with the direct FloatBuffers in
+// `inputs` (each read for `inputElementCount` floats) and writes the result
+// into the direct FloatBuffer `output`. Returns an iree::StatusCode value;
+// non-direct buffers yield kInvalidArgument.
+JNI_FUNC jint JNI_PREFIX(nativeInvokeFunction)(JNIEnv* env, jobject thiz,
+                                               jlong functionAddress,
+                                               jobjectArray inputs,
+                                               jint inputElementCount,
+                                               jobject output) {
+  ContextWrapper* context = GetContextWrapper(env, thiz);
+  CHECK_NE(context, nullptr);
+
+  const jsize inputs_size = env->GetArrayLength(inputs);
+  std::vector<float*> native_inputs(inputs_size);
+
+  for (jsize i = 0; i < inputs_size; i++) {
+    jobject input = env->GetObjectArrayElement(inputs, i);
+    auto* native_input =
+        static_cast<float*>(env->GetDirectBufferAddress(input));
+    // Delete the local reference right away: a native frame is only
+    // guaranteed room for 16 local references and there is one per input.
+    env->DeleteLocalRef(input);
+    // GetDirectBufferAddress returns null for non-direct buffers.
+    if (native_input == nullptr) {
+      return (jint)iree::StatusCode::kInvalidArgument;
+    }
+    native_inputs[i] = native_input;
+  }
+
+  auto* native_output =
+      static_cast<float*>(env->GetDirectBufferAddress(output));
+  if (native_output == nullptr) {
+    return (jint)iree::StatusCode::kInvalidArgument;
+  }
+
+  auto* function = reinterpret_cast<FunctionWrapper*>(functionAddress);
+  auto status = context->InvokeFunction(*function, native_inputs,
+                                        (int)inputElementCount, native_output);
+  return (jint)status.code();
+}
+
 JNI_FUNC jint JNI_PREFIX(nativeGetId)(JNIEnv* env, jobject thiz) {
   ContextWrapper* context = GetContextWrapper(env, thiz);
   CHECK_NE(context, nullptr);
diff --git a/bindings/java/com/google/iree/native/context_wrapper.cc b/bindings/java/com/google/iree/native/context_wrapper.cc
index 10f2726..7eb5753 100644
--- a/bindings/java/com/google/iree/native/context_wrapper.cc
+++ b/bindings/java/com/google/iree/native/context_wrapper.cc
@@ -14,7 +14,9 @@
 
 #include "bindings/java/com/google/iree/native/context_wrapper.h"
 
+#include "iree/base/api.h"
 #include "iree/base/logging.h"
+#include "iree/vm/ref_cc.h"
 
 namespace iree {
 namespace java {
@@ -72,6 +74,100 @@
                                           function_wrapper->function());
 }
 
+// Synchronously invokes `function_wrapper`, passing each entry of `inputs` as
+// a rank-1 float32 buffer view of `input_element_count` elements, and copies
+// the contents of the first result (which must be a buffer view) into
+// `output`.
+//
+// `output` carries no size of its own, so it is assumed to hold
+// `input_element_count` floats (as every current caller allocates); a larger
+// result is rejected with an out-of-range error instead of being copied.
+Status ContextWrapper::InvokeFunction(const FunctionWrapper& function_wrapper,
+                                      const std::vector<float*>& inputs,
+                                      int input_element_count, float* output) {
+  vm::ref<iree_vm_list_t> input_list;
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(
+      /*element_type=*/nullptr, input_element_count, iree_allocator_system(),
+      &input_list));
+
+  iree_hal_allocator_t* allocator = iree_hal_device_allocator(device_);
+  iree_hal_memory_type_t input_memory_type =
+      static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
+                                          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
+  iree_hal_buffer_usage_t input_buffer_usage =
+      static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
+                                           IREE_HAL_BUFFER_USAGE_CONSTANT);
+
+  for (float* input : inputs) {
+    // Write the input into a mappable buffer.
+    iree_hal_buffer_t* input_buffer = nullptr;
+    IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+        allocator, input_memory_type, input_buffer_usage,
+        sizeof(float) * input_element_count, &input_buffer));
+    iree_status_t buffer_status = iree_hal_buffer_write_data(
+        input_buffer, 0, input, input_element_count * sizeof(float));
+
+    // Wrap the input buffer in a buffer view.
+    iree_hal_buffer_view_t* input_buffer_view = nullptr;
+    if (iree_status_is_ok(buffer_status)) {
+      buffer_status = iree_hal_buffer_view_create(
+          input_buffer, /*shape=*/&input_element_count,
+          /*shape_rank=*/1, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
+          iree_allocator_system(), &input_buffer_view);
+    }
+    // Release our buffer reference before checking for errors so that a
+    // failed write or view creation does not leak the buffer. On success the
+    // buffer view keeps the buffer alive.
+    iree_hal_buffer_release(input_buffer);
+    IREE_RETURN_IF_ERROR(buffer_status);
+
+    // Marshal the input buffer view through the input VM variant list.
+    auto input_buffer_view_ref =
+        iree_hal_buffer_view_move_ref(input_buffer_view);
+    IREE_RETURN_IF_ERROR(
+        iree_vm_list_push_ref_move(input_list.get(), &input_buffer_view_ref));
+  }
+
+  // Prepare outputs list to accept results from the invocation.
+  vm::ref<iree_vm_list_t> outputs;
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
+                                           4 * sizeof(float),
+                                           iree_allocator_system(), &outputs));
+
+  // Synchronously invoke the function.
+  IREE_RETURN_IF_ERROR(iree_vm_invoke(context_, *function_wrapper.function(),
+                                      /*policy=*/nullptr, input_list.get(),
+                                      outputs.get(), iree_allocator_system()));
+
+  // Fetch the first result as a buffer view. NOTE(review): the deref is
+  // assumed to return null when that result is missing or of another type.
+  auto* output_buffer_view =
+      reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
+          outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
+  if (output_buffer_view == nullptr) {
+    return InvalidArgumentErrorBuilder(IREE_LOC)
+           << "First function result is not a buffer view";
+  }
+
+  // Read back the result into the given output buffer.
+  auto* output_buffer = iree_hal_buffer_view_buffer(output_buffer_view);
+  iree_hal_mapped_memory_t mapped_memory;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map(output_buffer,
+                                           IREE_HAL_MEMORY_ACCESS_READ, 0,
+                                           IREE_WHOLE_BUFFER, &mapped_memory));
+  const size_t result_byte_length = mapped_memory.contents.data_length;
+  const size_t output_byte_capacity = sizeof(float) * input_element_count;
+  if (result_byte_length > output_byte_capacity) {
+    iree_hal_buffer_unmap(output_buffer, &mapped_memory);
+    return OutOfRangeErrorBuilder(IREE_LOC)
+           << "Function result is " << result_byte_length
+           << " bytes but the output buffer holds " << output_byte_capacity;
+  }
+  memcpy(output, mapped_memory.contents.data, result_byte_length);
+  iree_hal_buffer_unmap(output_buffer, &mapped_memory);
+  return OkStatus();
+}
+
 int ContextWrapper::id() const { return iree_vm_context_id(context_); }
 
 ContextWrapper::~ContextWrapper() {
diff --git a/bindings/java/com/google/iree/native/context_wrapper.h b/bindings/java/com/google/iree/native/context_wrapper.h
index 7b0cb88..dfbdfe7 100644
--- a/bindings/java/com/google/iree/native/context_wrapper.h
+++ b/bindings/java/com/google/iree/native/context_wrapper.h
@@ -38,6 +38,16 @@
   Status ResolveFunction(iree_string_view_t name,
                          FunctionWrapper* function_wrapper);
 
+  // Synchronously invokes `function_wrapper`, passing each entry of `inputs`
+  // as a rank-1 float32 buffer view of `input_element_count` elements, and
+  // copies the function's first result into `output`. `output` must be large
+  // enough for that whole result; current callers size it to
+  // `input_element_count` floats.
+  // TODO(jennik): Support other input types aside from floats.
+  Status InvokeFunction(const FunctionWrapper& function_wrapper,
+                        const std::vector<float*>& inputs,
+                        int input_element_count, float* output);
+
   int id() const;
 
   ~ContextWrapper();
diff --git a/bindings/javatests/com/google/iree/IntegrationTest.java b/bindings/javatests/com/google/iree/IntegrationTest.java
index ea653ae..85fca5f 100644
--- a/bindings/javatests/com/google/iree/IntegrationTest.java
+++ b/bindings/javatests/com/google/iree/IntegrationTest.java
@@ -16,17 +16,22 @@
 
 package com.google.iree;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.fail;
 
 import android.content.Context;
 import android.content.res.Resources;
+import android.util.Log;
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import org.apache.commons.io.IOUtils;
 import org.junit.Test;
@@ -34,6 +39,8 @@
 
 @RunWith(AndroidJUnit4.class)
 public final class IntegrationTest {
+  private static final String TAG = IntegrationTest.class.getCanonicalName();
+
   @Test
   public void throwsExceptionWithoutNativeLib() throws Exception {
     try {
@@ -65,7 +72,29 @@
     Function function = ireeContext.resolveFunction(functionName);
     function.printDebugString();
 
-    // TODO(jennik): Invoke the function.
+    int elementCount = 4;
+    FloatBuffer x = ByteBuffer.allocateDirect(elementCount * Float.BYTES)
+                        .order(ByteOrder.nativeOrder())
+                        .asFloatBuffer()
+                        .put(new float[] {4.0f, 4.0f, 4.0f, 4.0f});
+    FloatBuffer y = ByteBuffer.allocateDirect(elementCount * Float.BYTES)
+                        .order(ByteOrder.nativeOrder())
+                        .asFloatBuffer()
+                        .put(new float[] {2.0f, 2.0f, 2.0f, 2.0f});
+    FloatBuffer[] inputs = {x, y};
+
+    // TODO(jennik): Allocate outputs in C++ rather than here.
+    FloatBuffer outputBuffer = ByteBuffer.allocateDirect(elementCount * Float.BYTES)
+                                   .order(ByteOrder.nativeOrder())
+                                   .asFloatBuffer()
+                                   .put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
+    ireeContext.invokeFunction(function, inputs, elementCount, outputBuffer);
+
+    float[] output = new float[elementCount];
+    outputBuffer.position(0);
+    outputBuffer.get(output);
+    Log.d(TAG, "Output: " + Arrays.toString(output));
+    assertArrayEquals(new float[] {8.0f, 8.0f, 8.0f, 8.0f}, output, 0.1f);
 
     function.free();
     module.free();
@@ -96,6 +125,29 @@
     Function function = ireeContext.resolveFunction(functionName);
     function.printDebugString();
 
+    int elementCount = 4;
+    FloatBuffer x = ByteBuffer.allocateDirect(elementCount * Float.BYTES)
+                        .order(ByteOrder.nativeOrder())
+                        .asFloatBuffer()
+                        .put(new float[] {4.0f, 4.0f, 4.0f, 4.0f});
+    FloatBuffer y = ByteBuffer.allocateDirect(elementCount * Float.BYTES)
+                        .order(ByteOrder.nativeOrder())
+                        .asFloatBuffer()
+                        .put(new float[] {2.0f, 2.0f, 2.0f, 2.0f});
+    FloatBuffer[] inputs = {x, y};
+
+    FloatBuffer outputBuffer = ByteBuffer.allocateDirect(elementCount * Float.BYTES)
+                                   .order(ByteOrder.nativeOrder())
+                                   .asFloatBuffer()
+                                   .put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
+    ireeContext.invokeFunction(function, inputs, elementCount, outputBuffer);
+
+    float[] output = new float[elementCount];
+    outputBuffer.position(0);
+    outputBuffer.get(output);
+    Log.d(TAG, "Output: " + Arrays.toString(output));
+    assertArrayEquals(new float[] {8.0f, 8.0f, 8.0f, 8.0f}, output, 0.1f);
+
     function.free();
     module.free();
     ireeContext.free();
diff --git a/bindings/javatests/com/google/iree/integration_test.cc b/bindings/javatests/com/google/iree/integration_test.cc
index eac0905..d10d3c7 100644
--- a/bindings/javatests/com/google/iree/integration_test.cc
+++ b/bindings/javatests/com/google/iree/integration_test.cc
@@ -65,6 +65,24 @@
   LOG(INFO) << "Function name: "
             << std::string(function.name().data, function.name().size);
 
+  float input_x[] = {2.0f, 2.0f, 2.0f, 2.0f};
+  float input_y[] = {4.0f, 4.0f, 4.0f, 4.0f};
+  std::vector<float*> input{input_x, input_y};
+  float output[4] = {0.0f, 1.0f, 2.0f, 3.0f};
+  int element_count = 4;
+
+  auto invoke_status =
+      context->InvokeFunction(function, input, element_count, output);
+  if (!invoke_status.ok()) {
+    LOG(ERROR) << "Invoke function error: " << invoke_status.code();
+    return 1;
+  }
+
+  LOG(INFO) << "Function output:";
+  for (int i = 0; i < element_count; i++) {
+    LOG(INFO) << output[i];
+  }
   return 0;
 }
 
diff --git a/bindings/python/pyiree/compiler/compiler.cc b/bindings/python/pyiree/compiler/compiler.cc
index 6c189f8..2bae670 100644
--- a/bindings/python/pyiree/compiler/compiler.cc
+++ b/bindings/python/pyiree/compiler/compiler.cc
@@ -73,6 +73,7 @@
       llvm::sys::DefaultOneShotPipeSignalHandler);
   llvm::sys::PrintStackTraceOnErrorSignal("pyiree");
 
+  mlir::enableGlobalDialectRegistry(true);
   // Register built-in MLIR dialects.
   mlir::registerMlirDialects();
 
diff --git a/iree/samples/custom_modules/dialect/custom_opt.cc b/iree/samples/custom_modules/dialect/custom_opt.cc
index be1891a..77969ad 100644
--- a/iree/samples/custom_modules/dialect/custom_opt.cc
+++ b/iree/samples/custom_modules/dialect/custom_opt.cc
@@ -74,6 +74,7 @@
     llvm::cl::init(false));
 
 int main(int argc, char **argv) {
+  mlir::enableGlobalDialectRegistry(true);
   mlir::DialectRegistry registry;
   mlir::registerMlirDialects(registry);
   mlir::registerMlirPasses();
diff --git a/iree/samples/custom_modules/dialect/custom_translate.cc b/iree/samples/custom_modules/dialect/custom_translate.cc
index 1cff7e5..d913857 100644
--- a/iree/samples/custom_modules/dialect/custom_translate.cc
+++ b/iree/samples/custom_modules/dialect/custom_translate.cc
@@ -56,6 +56,7 @@
 
 int main(int argc, char **argv) {
   llvm::InitLLVM y(argc, argv);
+  mlir::enableGlobalDialectRegistry(true);
 
   mlir::registerMlirDialects();
   mlir::registerXLADialects();
diff --git a/iree/tools/opt_main.cc b/iree/tools/opt_main.cc
index 281eae8..5e6af4b 100644
--- a/iree/tools/opt_main.cc
+++ b/iree/tools/opt_main.cc
@@ -76,6 +76,7 @@
     llvm::cl::init(false));
 
 int main(int argc, char **argv) {
+  mlir::enableGlobalDialectRegistry(true);
   mlir::DialectRegistry registry;
   mlir::registerMlirDialects(registry);
   mlir::registerMlirPasses();
diff --git a/iree/tools/run_mlir_main.cc b/iree/tools/run_mlir_main.cc
index 289ac5f..bf2b5c5 100644
--- a/iree/tools/run_mlir_main.cc
+++ b/iree/tools/run_mlir_main.cc
@@ -490,6 +490,7 @@
     }
   }
 
+  mlir::enableGlobalDialectRegistry(true);
   mlir::registerMlirDialects();
   mlir::iree_compiler::registerIreeDialects();
   mlir::iree_compiler::registerIreeCompilerModuleDialects();
diff --git a/iree/tools/translate_main.cc b/iree/tools/translate_main.cc
index 3ecd201..6f251ed 100644
--- a/iree/tools/translate_main.cc
+++ b/iree/tools/translate_main.cc
@@ -59,6 +59,7 @@
 
 int main(int argc, char **argv) {
   llvm::InitLLVM y(argc, argv);
+  mlir::enableGlobalDialectRegistry(true);
 
   mlir::registerMlirDialects();
 #ifdef IREE_HAVE_EMITC_DIALECT
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 1d01fc1..ffd0b31 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 1d01fc100bb5bef5f5eaf92520b2e52f64ee1d6e
+Subproject commit ffd0b31c7cbaa8322d2963afe6ace4e3a0889ddb