blob: d433f1f0d448590cd0488a2dbfcf66aa77dc9b56 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <emscripten.h>
#include <emscripten/bind.h>
#include <vector>
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "base/flatbuffer_util.h"
#include "base/init.h"
#include "base/status.h"
#include "hal/buffer_view.h"
#include "hal/buffer_view_string_util.h"
#include "hal/driver_registry.h"
#include "rt/context.h"
#include "rt/instance.h"
#include "schemas/module_def_generated.h"
#include "vm/sequencer_module.h"
namespace iree {
// Parses a list of input buffer views from a string of newline-separated
// inputs. Each non-blank line holds exactly one value, written as
//   [shape]xtype=[value]
// Example:
//   4x4xi8=0,1,2,3
//
// Literal two-character "\n" sequences (a backslash followed by 'n') are
// rewritten to real line breaks before splitting. Callers can therefore pass
// either real newlines or escaped ones. Empty and whitespace-only lines are
// skipped.
//
// |allocator| is forwarded to hal::ParseBufferViewFromString for every line.
// Returns the parsed views in input order, or the first parse error.
StatusOr<std::vector<hal::BufferView>> ParseInputs(
    absl::string_view inputs_string, hal::Allocator* allocator) {
  std::string input_lines = absl::StrReplaceAll(inputs_string, {{"\\n", "\n"}});
  std::vector<hal::BufferView> input_buffer_views;
  for (absl::string_view input_line :
       absl::StrSplit(input_lines, '\n', absl::SkipWhitespace())) {
    ASSIGN_OR_RETURN(auto input_buffer_view,
                     hal::ParseBufferViewFromString(input_line, allocator));
    // Move instead of copy. The view holds a buffer handle (presumably
    // ref-counted, like the other ref_ptr types here), and a copy would bump
    // and then drop that reference for no benefit.
    input_buffer_views.push_back(std::move(input_buffer_view));
  }
  return input_buffer_views;
}
// Runs an IREE module with the provided inputs and returns its output.
//
// |module_file_data| holds the serialized ModuleDef FlatBuffer. Ownership is
// moved into the parsed file, so a large module is not copied again.
// |inputs_string| uses the line format accepted by ParseInputs.
//
// The module runs on the default device of the "interpreter" HAL driver. Its
// exported "main" function is invoked, and the call blocks until the
// invocation completes.
//
// Returns only the FIRST result, formatted by PrintBufferViewToString with
// hal::BufferViewPrintMode::kFloatingPoint and a third argument of 1024
// (presumably an element limit; confirm against buffer_view_string_util.h).
// Returns an error status if any setup, parsing, or invocation step fails.
// Returns an internal error if there are no results or the first result has
// no buffer.
StatusOr<std::string> RunIreeModule(std::string module_file_data,
                                    absl::string_view inputs_string) {
  auto instance = make_ref<rt::Instance>();

  // Create driver and device.
  ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create(
                                    "interpreter"));
  ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice());
  RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device));
  auto policy = make_ref<rt::Policy>();
  auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy));

  // Load main module FlatBuffer. We own |module_file_data|, so hand it over
  // rather than copying it.
  ASSIGN_OR_RETURN(auto main_module_file,
                   FlatBufferFile<ModuleDef>::FromString(
                       ModuleDefIdentifier(), std::move(module_file_data)));
  ASSIGN_OR_RETURN(auto main_module,
                   vm::SequencerModule::FromFile(std::move(main_module_file)));

  // Register the main module with the context.
  RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module)));

  // Setup arguments and storage for results.
  // TODO(scotttodd): Receive main function name from JS.
  ASSIGN_OR_RETURN(auto main_function,
                   main_module->LookupFunctionByName(
                       rt::Function::Linkage::kExport, "main"));
  ASSIGN_OR_RETURN(auto arguments,
                   ParseInputs(inputs_string, device->allocator()));

  // Call into the main function.
  ASSIGN_OR_RETURN(auto invocation,
                   rt::Invocation::Create(add_ref(context), main_function,
                                          make_ref<rt::Policy>(), {},
                                          absl::MakeConstSpan(arguments)));

  // Wait until invocation completes.
  // TODO(scotttodd): make this an async callback.
  RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture()));
  ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults());

  // Format the first result as a string for the caller.
  // TODO(scotttodd): Receive output types / print mode from JS.
  // TODO(scotttodd): Return list of outputs instead of just the first (proto?)
  if (results.empty()) {
    return InternalErrorBuilder(IREE_LOC) << "Received no results";
  }
  const auto& result = results[0];
  // Reject a missing buffer BEFORE formatting. The print helper presumably
  // reads through |result.buffer|, so checking afterwards would be too late.
  if (!result.buffer) {
    return InternalErrorBuilder(IREE_LOC)
           << "result[0] unexpectedly has no buffer";
  }
  const auto print_mode = hal::BufferViewPrintMode::kFloatingPoint;
  ASSIGN_OR_RETURN(auto result_str,
                   PrintBufferViewToString(result, print_mode, 1024));
  return result_str;
}
// String-in/string-out wrapper around RunIreeModule, suitable for binding to
// JavaScript.
//
// On success, returns RunIreeModule's output string unchanged. On failure,
// the status is flattened into the returned string with an "Error: " prefix,
// so the caller always receives a plain string.
std::string RunIreeModuleEntry(std::string module_file_data,
                               std::string inputs_string) {
  // TODO(scotttodd): optimize, minimize copies
  // https://groups.google.com/d/msg/emscripten-discuss/CMfYljLWMvY/Di52WB2QAgAJ
  auto run_outcome =
      RunIreeModule(std::move(module_file_data), inputs_string);
  if (run_outcome.ok()) {
    return run_outcome.ValueOrDie();
  }
  return "Error: " + run_outcome.status().ToString();
}
// Registers embind bindings under the name "iree". RunIreeModuleEntry is
// exposed to JavaScript as `runIreeModule`. It takes and returns std::string
// values on the C++ side.
EMSCRIPTEN_BINDINGS(iree) {
emscripten::function("runIreeModule", &RunIreeModuleEntry);
}
// Program entry point. It only calls InitializeEnvironment with the
// command-line arguments and returns 0. Module execution happens later,
// through the bound `runIreeModule` function.
// NOTE(review): this presumably relies on the Emscripten runtime staying alive
// after main returns (EXIT_RUNTIME disabled). Confirm the build flags.
// NOTE(review): `main` is declared extern "C" inside namespace iree. Recent
// C++ standards restrict declaring `main` with a linkage specification, so
// consider moving it to global scope. Confirm against the toolchain in use.
extern "C" int main(int argc, char** argv) {
InitializeEnvironment(&argc, &argv);
return 0;
}
} // namespace iree