blob: 52a08e000cb9c38fe04b56a4c64c5b2d55c393bb [file] [log] [blame]
// Copyright Microsoft and CHERIoT Contributors.
// SPDX-License-Identifier: MIT
#include <cstdint>
#define TEST_NAME "Thread pool"
#include "tests.hh"
#include <cheri.hh>
#include <cheriot-atomic.hh>
#include <switcher.h>
#include <thread.h>
#include <thread_pool.h>
// Number of thread-pool jobs that have run so far. The two jobs queued at the
// start of test_thread_pool each increment it (with interrupts disabled) and
// the main thread polls it until it reaches 2.
int counter;
using CHERI::with_interrupts_disabled;
using namespace thread_pool;
// Set by compartment_error_handler once it has accepted the expected
// interrupt on the expected thread; checked by the interrupted worker job.
cheriot::atomic<bool> errorHandled = false;
// Set by the main thread after it has called switcher_interrupt_thread; the
// worker job yields in a loop until it observes this flag.
cheriot::atomic<bool> interruptStarted = false;
// Thread ID of the worker job that is meant to be interrupted. Written by
// that job and compared against thread_id_get() in the error handler.
cheriot::atomic<int> interruptThreadNumber;
/**
 * Error handler for this test compartment.
 *
 * Logs the invocation, then decides how to recover:
 *
 *  - Any cause other than 25 is unexpected and forces an unwind.
 *  - Cause 25 on a thread other than the one recorded in
 *    `interruptThreadNumber` is logged as a misdelivered interrupt and also
 *    forces an unwind.
 *  - Cause 25 on the recorded thread sets `errorHandled` and resumes the
 *    interrupted code by installing the saved context.
 *
 * `mtval` is not inspected.
 */
extern "C" ErrorRecoveryBehaviour
compartment_error_handler(ErrorState *frame, size_t mcause, size_t mtval)
{
	// The only cause this test resumes from. The log messages below treat it
	// as the explicit (software) thread interrupt.
	constexpr size_t ExpectedInterruptCause = 25;

	debug_log("Thread {} error handler invoked with mcause {}. PCC: {}",
	          thread_id_get(),
	          mcause,
	          frame->pcc);

	// Assume the worst; only the fully expected case upgrades this.
	ErrorRecoveryBehaviour outcome = ErrorRecoveryBehaviour::ForceUnwind;

	if (mcause == ExpectedInterruptCause)
	{
		if (thread_id_get() != interruptThreadNumber)
		{
			debug_log("Explicit thread interrupt delivered on the wrong "
			          "thread (thread {}, "
			          "expected {})",
			          thread_id_get(),
			          interruptThreadNumber.load());
		}
		else
		{
			errorHandled = true;
			debug_log("Expected software interrupt, installing context");
			outcome = ErrorRecoveryBehaviour::InstallContext;
		}
	}

	return outcome;
}
/**
 * Exercises the thread pool and the switcher's thread-interrupt interface.
 *
 * The test proceeds in four phases:
 *
 *  1. Queue a stateless and a stateful job and wait (with a bounded number of
 *     sleeps) until both have incremented `counter`.
 *  2. Queue a job that checks the thread ID seen from a pool thread.
 *  3. Check that interrupting the current thread, a null thread and an invalid
 *     thread pointer all fail, then interrupt a pool thread and check that the
 *     error handler ran on it.
 *  4. Rendezvous with two further pool jobs on a barrier so that the pool is
 *     quiescent before returning.
 *
 * Every check is made through TEST. Returns 0.
 */
int test_thread_pool()
{
	// We can't share stack variables, so create a heap allocation that we can
	// capture as an explicit pointer.
	int *heapInt = new (malloc(sizeof(int))) int(0);
	TEST(thread_id_get() == 1,
	     "Thread id of main thread should be 1, is {}",
	     thread_id_get());
	// Run a simple stateless callback that increments a global in the thread
	// pool. This demonstrates that we can correctly capture a stateless
	// function and pass it to the worker thread.
	async([]() {
		with_interrupts_disabled([]() {
			counter++;
			debug_log("Calling stateless function from thread pool");
		});
	});
	// Run a callback that captures the heap pointer by value and increments
	// both the global counter and the heap-allocated integer.
	async([=]() {
		with_interrupts_disabled([=]() {
			debug_log(
			  "Calling stateful function from thread pool with {} captured",
			  heapInt);
			counter++;
			(*heapInt)++;
		});
	});
	debug_log("Counter: {}", counter);
	debug_log("heapInt: {}", *heapInt);
	// Sleep until both jobs above have run, giving up after a bounded number
	// of attempts.
	int sleeps = 0;
	while (counter < 2)
	{
		Timeout t{1};
		thread_sleep(&t);
		// Count every sleep: without this the bound below can never trip
		// and a stalled thread pool would make this loop spin forever.
		sleeps++;
		TEST(sleeps < 100, "Gave up after too many sleeps");
	}
	debug_log("Yielded {} times for the thread pool to run our jobs", sleeps);
	TEST(counter == 2, "Counter is {}, should be 2", counter);
	TEST(*heapInt == 1, "Heap-allocated integer is {}, should be 1", *heapInt);
	debug_log("Freeing heap int: {}", heapInt);
	free(heapInt);
	// Check the thread ID as seen from a pool thread: two reads must agree
	// and must not be the main thread's ID.
	async([]() {
		auto fast = thread_id_get();
		auto slow = thread_id_get();
		TEST(fast == slow,
		     "Thread ID is different in fast ({}) and slow ({}) accessors",
		     fast,
		     slow);
		TEST(fast != 1, "Thread ID for thread pool thread should not be 1");
	});
	CHERI::Capability<void> mainThread{switcher_current_thread()};
	TEST(mainThread.is_sealed(), "Thread should be sealed: {}", mainThread);
	TEST(mainThread.type() == 10,
	     "Thread should be sealed with otype 10: {}",
	     mainThread);
	// None of these is a valid target for an interrupt.
	TEST(!switcher_interrupt_thread(mainThread),
	     "Interrupting the current thread should fail");
	TEST(!switcher_interrupt_thread(nullptr),
	     "Interrupting null thread should fail");
	TEST(!switcher_interrupt_thread(&sleeps),
	     "Interrupting invalid thread should fail");
	// Statics rather than locals so that the pool job can reach them without
	// capturing anything from this stack frame.
	static void *asyncThread;
	static bool  interrupted;
	async([=]() mutable {
		// Publish which thread is expected to take the interrupt, and the
		// handle that the main thread needs in order to deliver it.
		interruptThreadNumber = thread_id_get();
		asyncThread           = switcher_current_thread();
		while (!interruptStarted)
		{
			yield();
		}
		TEST(errorHandled,
		     "Worker thread was not interrupted from higher-priority one");
		interrupted = true;
	});
	// Give the worker a few ticks to publish its thread handle.
	for (int i = 0; i < 3; i++)
	{
		if (!asyncThread)
		{
			Timeout t{1};
			thread_sleep(&t);
		}
	}
	TEST(asyncThread, "Worker thread did not provide thread pointer");
	debug_log("Interrupting other thread");
	bool ret         = switcher_interrupt_thread(asyncThread);
	interruptStarted = true;
	TEST(ret, "Interrupting worker thread failed: {}", ret);
	Timeout t{3};
	thread_sleep(&t);
	TEST(interrupted, "Worker thread was not interrupted");

	// Three participants: two pool jobs plus this thread.
	static cheriot::atomic<uint32_t> barrier{3};
	auto                             barrierWait = []() {
		// Use the value *after* the decrement, so that the last participant
		// to arrive observes 0 and wakes everyone else. With a
		// post-decrement nobody ever sees 0 and no notification is sent.
		uint32_t value = --barrier;
		if (value == 0)
		{
			barrier.notify_all();
		}
		while (value != 0)
		{
			barrier.wait(value);
			value = barrier;
		}
	};
	// Make sure that the thread pool threads have both finished.
	async(barrierWait);
	async(barrierWait);
	barrierWait();
	debug_log("Thread pool quiesced");
	return 0;
}