blob: 0fe497269f50411e08e6e02e94d275d2e8d9c195 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cinttypes>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <random>
// Diagnostic dump switch. Define PRINT_INPUTS_AND_OUTPUTS to a non-zero value
// (on the compiler command line, or by uncommenting the next line) to have
// main() print every input, reference and kernel output matrix.
//#define PRINT_INPUTS_AND_OUTPUTS (1)
#if !defined(PRINT_INPUTS_AND_OUTPUTS)
#define PRINT_INPUTS_AND_OUTPUTS (0)
#endif

// Kernel under test, implemented in assembly elsewhere (C linkage).
// As exercised by main(): `lhs` and `rhs_t` each hold `count` 4x4 int8
// matrices of 16 contiguous elements; for every matrix b the kernel is
// expected to write
//   out[b*16 + j*4 + i] = sum_k lhs[b*16 + j*4 + k] * rhs_t[b*16 + i*4 + k]
// (lhs times the transpose of rhs_t) and nothing beyond out[count*16 - 1].
extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs,
                                   const int8_t *rhs_t, std::size_t count);
namespace {

// Every operand and result is a 4x4 matrix of 16 contiguous elements;
// matrix b starts at element b * kMatrixElems.
constexpr std::size_t kDim = 4;
constexpr std::size_t kMatrixElems = kDim * kDim;
// Number of lhs/rhs_t matrix pairs handed to the kernel.
constexpr std::size_t kMatrixCount = 37;
// Total int8 elements in each input buffer.
constexpr std::size_t kInputElems = kMatrixCount * kMatrixElems;
// Output buffers hold one extra "guard" matrix after the real results.
constexpr std::size_t kOutputElems = kInputElems + kMatrixElems;
// Sentinel stored in the guard matrix; the kernel must leave it untouched.
constexpr int32_t kGuardValue = 1337;

// Prints `count` 4x4 matrices from `data` under the heading `title`, each
// element right-aligned in `width` columns and followed by ','. When
// `expected` is non-null, an element that differs from expected[] is
// followed by '/' instead, so mismatches stand out.
template <typename T>
void print_matrices(const char *title, const T *data, std::size_t count,
                    int width, const int32_t *expected = nullptr) {
  printf("%s", title);
  for (std::size_t b = 0; b < count; b++) {
    // size_t printed via %lu + cast: passing size_t to %d is undefined, and
    // %zu may be missing from minimal embedded printf implementations.
    printf("b = %lu:\n", static_cast<unsigned long>(b));
    for (std::size_t j = 0; j < kDim; j++) {
      printf(" ");
      for (std::size_t i = 0; i < kDim; i++) {
        const std::size_t idx = i + kDim * j + b * kMatrixElems;
        const bool same = expected == nullptr || data[idx] == expected[idx];
        printf("%*d%c", width, static_cast<int>(data[idx]), same ? ',' : '/');
      }
      printf("\n");
    }
    printf("\n");
  }
}

}  // namespace

// Self-checking test for vector_matmul4_asm. It runs the kernel on
// kMatrixCount pseudo-random int8 matrix pairs and recomputes each product in
// scalar C++. It then compares every output element, including a trailing
// guard matrix that the kernel must not overwrite.
// Returns 0 on success, otherwise the number of mismatched elements clamped
// to 255.
extern "C" int main(void) {
  int8_t lhs[kInputElems];
  int8_t rhs_t[kInputElems];
  int32_t result[kOutputElems];
  int32_t golden[kOutputElems];

  // Default-seeded engine: inputs are identical on every run (for a given
  // standard library).
  std::default_random_engine generator;
  // uniform_int_distribution is undefined for char-sized types such as
  // int8_t, so draw ints in the int8 range and narrow them.
  std::uniform_int_distribution<int> distribution(INT8_MIN, INT8_MAX);
  for (std::size_t i = 0; i < kInputElems; i++) {
    lhs[i] = static_cast<int8_t>(distribution(generator));
    rhs_t[i] = static_cast<int8_t>(distribution(generator));
  }
  // One extra guard matrix to ensure the assembly doesn't go past the end.
  for (std::size_t i = kInputElems; i < kOutputElems; i++) {
    result[i] = kGuardValue;
    golden[i] = kGuardValue;
  }

  vector_matmul4_asm(result, lhs, rhs_t, kMatrixCount);

  // Scalar reference for each pair: out[j][i] = sum_k lhs[j][k] * rhs_t[i][k],
  // i.e. lhs times the transpose of rhs_t.
  for (std::size_t b = 0; b < kMatrixCount; b++) {
    const std::size_t base = b * kMatrixElems;
    for (std::size_t j = 0; j < kDim; j++) {
      for (std::size_t i = 0; i < kDim; i++) {
        int32_t acc = 0;
        for (std::size_t k = 0; k < kDim; k++) {
          acc += lhs[base + j * kDim + k] * rhs_t[base + i * kDim + k];
        }
        golden[base + j * kDim + i] = acc;
      }
    }
  }

  // Compare every element, guard matrix included.
  std::size_t errors = 0;
  for (std::size_t n = 0; n < kOutputElems; n++) {
    if (result[n] != golden[n]) {
      errors++;
    }
  }

  if (PRINT_INPUTS_AND_OUTPUTS) {
    print_matrices("lhs:\n", lhs, kMatrixCount, 5);
    print_matrices("rhs_t:\n", rhs_t, kMatrixCount, 5);
    print_matrices("golden:\n", golden, kMatrixCount + 1, 7);
    print_matrices("\nresults:\n", result, kMatrixCount + 1, 7, golden);
    printf("\n%lu errors\n", static_cast<unsigned long>(errors));
  }

  // Exit statuses are commonly truncated to their low 8 bits (POSIX wait()).
  // An unclamped count such as 256 would therefore read as success, so the
  // count is clamped to 255.
  return errors > 255 ? 255 : static_cast<int>(errors);
}