blob: e67543b01b673b1186ffaca9ef41e1cdb3ffeb02 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <jni.h>
#include <vector>
#include "bindings/java/com/google/iree/native/context_wrapper.h"
#include "bindings/java/com/google/iree/native/function_wrapper.h"
#include "bindings/java/com/google/iree/native/instance_wrapper.h"
#include "bindings/java/com/google/iree/native/module_wrapper.h"
#include "iree/base/logging.h"
#define JNI_FUNC extern "C" JNIEXPORT
#define JNI_PREFIX(METHOD) Java_com_google_iree_Context_##METHOD
using iree::java::ContextWrapper;
using iree::java::FunctionWrapper;
using iree::java::InstanceWrapper;
using iree::java::ModuleWrapper;
namespace {
// Returns a pointer to the native IREE context stored by the ContextWrapper
// object.
static ContextWrapper* GetContextWrapper(JNIEnv* env, jobject obj) {
jclass clazz = env->GetObjectClass(obj);
IREE_CHECK(clazz);
jfieldID field = env->GetFieldID(clazz, "nativeAddress", "J");
IREE_CHECK(field);
return reinterpret_cast<ContextWrapper*>(env->GetLongField(obj, field));
}
std::vector<ModuleWrapper*> GetModuleWrappersFromAdresses(
JNIEnv* env, jlongArray moduleAddresses) {
// Get the addresses of the ModuleWrappers.
jsize modules_size = env->GetArrayLength(moduleAddresses);
std::vector<int64_t> module_addresses(modules_size);
env->GetLongArrayRegion(moduleAddresses, 0, modules_size,
reinterpret_cast<jlong*>(module_addresses.data()));
// Convert the addresses to ModuleWrappers.
std::vector<ModuleWrapper*> modules(modules_size);
for (int i = 0; i < modules_size; i++) {
modules[i] = (ModuleWrapper*)module_addresses[i];
}
return modules;
}
} // namespace
JNI_FUNC jlong JNI_PREFIX(nativeNew)(JNIEnv* env, jobject thiz) {
return reinterpret_cast<jlong>(new ContextWrapper());
}
JNI_FUNC void JNI_PREFIX(nativeFree)(JNIEnv* env, jobject thiz, jlong handle) {
ContextWrapper* context = GetContextWrapper(env, thiz);
IREE_CHECK_NE(context, nullptr);
delete context;
}
JNI_FUNC jint JNI_PREFIX(nativeCreate)(JNIEnv* env, jobject thiz,
jlong instanceAddress) {
ContextWrapper* context = GetContextWrapper(env, thiz);
IREE_CHECK_NE(context, nullptr);
auto instance = (InstanceWrapper*)instanceAddress;
auto status = context->Create(*instance);
return (jint)status.code();
}
JNI_FUNC jint JNI_PREFIX(nativeCreateWithModules)(JNIEnv* env, jobject thiz,
jlong instanceAddress,
jlongArray moduleAddresses) {
ContextWrapper* context = GetContextWrapper(env, thiz);
IREE_CHECK_NE(context, nullptr);
auto instance = (InstanceWrapper*)instanceAddress;
auto modules = GetModuleWrappersFromAdresses(env, moduleAddresses);
auto status = context->CreateWithModules(*instance, modules);
return (jint)status.code();
}
JNI_FUNC jint JNI_PREFIX(nativeRegisterModules)(JNIEnv* env, jobject thiz,
jlongArray moduleAddresses) {
ContextWrapper* context = GetContextWrapper(env, thiz);
IREE_CHECK_NE(context, nullptr);
auto modules = GetModuleWrappersFromAdresses(env, moduleAddresses);
auto status = context->RegisterModules(modules);
return (jint)status.code();
}
JNI_FUNC jint JNI_PREFIX(nativeResolveFunction)(JNIEnv* env, jobject thiz,
jlong functionAddress,
jstring name) {
ContextWrapper* context = GetContextWrapper(env, thiz);
IREE_CHECK_NE(context, nullptr);
auto function = (FunctionWrapper*)functionAddress;
const char* native_name = env->GetStringUTFChars(name, /*isCopy=*/nullptr);
auto status = context->ResolveFunction(
iree_string_view_t{native_name, strlen(native_name)}, function);
env->ReleaseStringUTFChars(name, native_name);
return (jint)status.code();
}
JNI_FUNC jint JNI_PREFIX(nativeInvokeFunction)(JNIEnv* env, jobject thiz,
jlong functionAddress,
jobjectArray inputs,
jint inputElementCount,
jobject output) {
ContextWrapper* context = GetContextWrapper(env, thiz);
IREE_CHECK_NE(context, nullptr);
const jsize inputs_size = env->GetArrayLength(inputs);
std::vector<float*> native_inputs(inputs_size);
for (int i = 0; i < inputs_size; i++) {
jobject input = env->GetObjectArrayElement(inputs, i);
float* native_input = (float*)env->GetDirectBufferAddress(input);
native_inputs[i] = native_input;
}
auto function = (FunctionWrapper*)functionAddress;
float* native_output = (float*)env->GetDirectBufferAddress(output);
auto status = context->InvokeFunction(*function, native_inputs,
(int)inputElementCount, native_output);
return (jint)status.code();
}
JNI_FUNC jint JNI_PREFIX(nativeGetId)(JNIEnv* env, jobject thiz) {
ContextWrapper* context = GetContextWrapper(env, thiz);
IREE_CHECK_NE(context, nullptr);
int context_id = context->id();
return (jint)context_id;
}