blob: 38f1f157f663135b3cb78805b91b3bc80b2623b9 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cinttypes>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <random>
// #define PRINT_INPUTS_AND_OUTPUTS (1)
#ifndef PRINT_INPUTS_AND_OUTPUTS
#define PRINT_INPUTS_AND_OUTPUTS (0)
#endif
// Assembly kernel under test; its implementation lives outside this file.
// As exercised by main() below: `lhs` and `rhs_t` each hold `count` 4x4 int8
// matrices (16 consecutive elements per matrix), and the kernel is expected to
// write `count` 4x4 int32 matrices to `out`, where element (r, c) of each
// result is the dot product of row r of the lhs matrix with row c of the rhs_t
// matrix (i.e. lhs multiplied by the transpose of rhs_t). main() also checks
// that nothing is written past out[count * 16 - 1].
extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs,
const int8_t *rhs_t, std::size_t count);
namespace {

// Side length of each square matrix; a matrix occupies kDim * kDim
// consecutive elements, element (row, col) at index row * kDim + col.
constexpr std::size_t kDim = 4;
constexpr std::size_t kMatrixElems = kDim * kDim;
// Number of matrix pairs handed to the kernel under test.
constexpr std::size_t kNumMatrices = 37;
constexpr std::size_t kInputElems = kNumMatrices * kMatrixElems;
// One extra guard matrix follows the kernel's output region so that a write
// past the end shows up as a mismatch.
constexpr std::size_t kOutputElems = kInputElems + kMatrixElems;
constexpr int32_t kGuardValue = 1337;
// Pre-fill for the output region the kernel must overwrite. A dot product of
// four int8 pairs lies in [-65024, 65536], so this value is never a correct
// result: any element the kernel fails to write is reported deterministically
// instead of comparing indeterminate memory.
constexpr int32_t kUnwrittenSentinel = INT32_MAX;

// Scalar reference for vector_matmul4_asm. For each of `count` matrix pairs,
// out[b](row, col) = sum_k lhs[b](row, k) * rhs_t[b](col, k), i.e. lhs times
// the transpose of rhs_t.
void reference_matmul4(int32_t *out, const int8_t *lhs, const int8_t *rhs_t,
                       std::size_t count) {
  for (std::size_t b = 0; b < count; b++) {
    const int8_t *a = lhs + b * kMatrixElems;
    const int8_t *bt = rhs_t + b * kMatrixElems;
    int32_t *c = out + b * kMatrixElems;
    for (std::size_t row = 0; row < kDim; row++) {
      for (std::size_t col = 0; col < kDim; col++) {
        int32_t acc = 0;
        for (std::size_t k = 0; k < kDim; k++) {
          // int8 * int8 promotes to int, so the product cannot overflow.
          acc += a[k + row * kDim] * bt[k + col * kDim];
        }
        c[col + row * kDim] = acc;
      }
    }
  }
}

// Returns how many of the first `n` elements differ between the two arrays.
std::size_t count_mismatches(const int32_t *actual, const int32_t *expected,
                             std::size_t n) {
  std::size_t errors = 0;
  for (std::size_t i = 0; i < n; i++) {
    if (actual[i] != expected[i]) {
      errors++;
    }
  }
  return errors;
}

// Prints `count` consecutive 4x4 matrices from `data`, each value
// right-aligned in `width` columns. When `expected` is non-null, a value that
// differs from the corresponding element of `expected` is followed by '/'
// instead of ','.
template <typename T>
void print_matrices(const T *data, std::size_t count, int width,
                    const int32_t *expected = nullptr) {
  for (std::size_t b = 0; b < count; b++) {
    // size_t has no portable %d form; cast for a widely supported specifier.
    printf("b = %lu:\n", static_cast<unsigned long>(b));
    for (std::size_t row = 0; row < kDim; row++) {
      printf(" ");
      for (std::size_t col = 0; col < kDim; col++) {
        const std::size_t idx = col + kDim * row + kMatrixElems * b;
        const bool same = expected == nullptr || data[idx] == expected[idx];
        printf("%*d%c", width, static_cast<int>(data[idx]), same ? ',' : '/');
      }
      printf("\n");
    }
    printf("\n");
  }
}

}  // namespace

// Runs vector_matmul4_asm on deterministic pseudo-random inputs and compares
// it with reference_matmul4, including a trailing guard matrix that must stay
// untouched. Returns the number of mismatching elements (clamped to INT_MAX),
// so a zero exit status means the kernel passed.
extern "C" int main(void) {
  int8_t lhs[kInputElems];
  int8_t rhs_t[kInputElems];
  int32_t result[kOutputElems];
  int32_t golden[kOutputElems];

  // Default-seeded engine: the inputs are identical on every run.
  std::default_random_engine generator;
  // uniform_int_distribution is only defined for short/int/long/long long and
  // their unsigned forms (not int8_t), so draw ints and narrow them.
  std::uniform_int_distribution<int> distribution(INT8_MIN, INT8_MAX);
  for (std::size_t i = 0; i < kInputElems; i++) {
    lhs[i] = static_cast<int8_t>(distribution(generator));
    rhs_t[i] = static_cast<int8_t>(distribution(generator));
  }

  for (std::size_t i = 0; i < kInputElems; i++) {
    result[i] = kUnwrittenSentinel;
  }
  // One extra guard matrix to ensure the assembly doesn't go past the end.
  for (std::size_t i = kInputElems; i < kOutputElems; i++) {
    result[i] = kGuardValue;
    golden[i] = kGuardValue;
  }

  vector_matmul4_asm(result, lhs, rhs_t, kNumMatrices);
  reference_matmul4(golden, lhs, rhs_t, kNumMatrices);

  // The guard matrix is compared too, catching out-of-bounds writes.
  const std::size_t errors = count_mismatches(result, golden, kOutputElems);

  if (PRINT_INPUTS_AND_OUTPUTS) {
    printf("lhs:\n");
    print_matrices(lhs, kNumMatrices, 5);
    printf("rhs_t:\n");
    print_matrices(rhs_t, kNumMatrices, 5);
    printf("golden:\n");
    print_matrices(golden, kOutputElems / kMatrixElems, 7);
    printf("\nresults:\n");
    print_matrices(result, kOutputElems / kMatrixElems, 7, golden);
    printf("\n%lu errors\n", static_cast<unsigned long>(errors));
  }
  return (errors > INT_MAX) ? INT_MAX : static_cast<int>(errors);
}