// Copyright 2023 Google, LLC.
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Heavily based on the `runtime/src/iree/hal/local/elf/elf_module_test_main.c`
// sample from IREE.

#include <stdio.h>
#include <string.h>

#include "iree/base/api.h"
#include "iree/base/internal/cpu.h"
#include "iree/hal/local/elf/elf_module.h"
#include "iree/hal/local/executable_environment.h"
#include "iree/hal/local/executable_library.h"
#include "samples/branch_mul/branch_mul_arg0.h"
#include "samples/branch_mul/branch_mul_arg1.h"
#include "samples/branch_mul/branch_mul_arg2.h"
#include "samples/branch_mul/branch_mul_c.h"
#include "samples/branch_mul/branch_mul_emitc.h"
#include "samples/branch_mul/branch_mul_expected.h"

// Output buffer handed to the dispatch as its last binding and afterwards
// compared element-by-element against `expected`. Objects with static storage
// duration are zero-initialized, so every element starts out as 0.
static int8_t ret0[3][2][256];
static iree_status_t run_test() {
  iree_hal_executable_environment_v0_t environment;
  iree_hal_executable_environment_initialize(iree_allocator_system(),
                                             &environment);

  void* query_fn_ptr = &llvm_module_linked_llvm_cpu_library_query;

  union {
    const iree_hal_executable_library_header_t** header;
    const iree_hal_executable_library_v0_t* v0;
  } library;
  library.header =
      (const iree_hal_executable_library_header_t**)iree_elf_call_p_ip(
          query_fn_ptr, IREE_HAL_EXECUTABLE_LIBRARY_VERSION_LATEST,
          &environment);
  if (library.header == NULL) {
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "library header is empty (version mismatch?)");
  }

  const iree_hal_executable_library_header_t* header = *library.header;
  if (header->version != IREE_HAL_EXECUTABLE_LIBRARY_VERSION_LATEST) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "library version error");
  }

  if (strncmp(header->name, "llvm_module_linked_llvm_cpu",
              strlen(header->name)) != 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "library name mismatches");
  }

  if (library.v0->exports.count != 6) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "entry point count mismatches");
  }

  size_t binding_lengths[4] = {
      sizeof(arg0),
      sizeof(arg1),
      sizeof(arg2),
      sizeof(ret0),
  };
  void* binding_ptrs[4] = {
      arg0,
      arg1,
      arg2,
      ret0,
  };
  const iree_hal_executable_dispatch_state_v0_t dispatch_state = {
      .workgroup_size_x = 1,
      .workgroup_size_y = 1,
      .workgroup_size_z = 1,
      .workgroup_count_x = 1,
      .workgroup_count_y = 1,
      .workgroup_count_z = 1,
      .max_concurrency = 1,
      .binding_count = 1,
      .binding_lengths = binding_lengths,
      .binding_ptrs = binding_ptrs,
  };
  const iree_hal_executable_workgroup_state_v0_t workgroup_state = {
      .workgroup_id_x = 0,
      .workgroup_id_y = 0,
      .workgroup_id_z = 0,
      .processor_id = iree_cpu_query_processor_id(),
  };
  int ret = iree_elf_call_i_ppp((const void*)library.v0->exports.ptrs[5],
                                (void*)&environment, (void*)&dispatch_state,
                                (void*)&workgroup_state);
  if (ret != 0) {
    return iree_make_status(IREE_STATUS_INTERNAL,
                            "dispatch function returned failure: %d", ret);
  }

  int mismatch_count = 0;
  iree_status_t status = iree_ok_status();
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 2; ++j) {
      for (int k = 0; k < 256; ++k) {
        int8_t out = ret0[i][j][k];
        int8_t expect = expected[i][j][k];
        if (out != expect) {
          mismatch_count++;
          fprintf(stdout, "mismatch at [%d][%d][%d] %x != %x\n", i, j, k, out,
                  expect);
          fprintf(stdout, "ins: %x %x %x\n", arg0[i][j][k], arg1[k],
                  arg2[i][j][k]);
        }
      }
    }
  }
  if (mismatch_count) {
    status = iree_make_status(
        IREE_STATUS_INTERNAL,
        "%d mismatches between actual and expected output", mismatch_count);
  }

  return status;
}

// Process entry point. Runs the dispatch test and uses the IREE status code of
// its result as the process exit code. On failure the status is printed to
// stderr and then released; the code is captured before the release.
int main() {
  iree_status_t status = run_test();
  if (iree_status_is_ok(status)) {
    return (int)iree_status_code(status);
  }
  const int exit_code = (int)iree_status_code(status);
  iree_status_fprint(stderr, status);
  iree_status_free(status);
  return exit_code;
}
