| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include <cinttypes> |
| #include <climits> |
| #include <cstdint> |
| #include <cstdio> |
| #include <cstdlib> |
| #include <random> |
| |
// Uncomment to dump the inputs, the reference results and the kernel's
// results (kernel outputs that differ from the reference are followed by '/').
//#define PRINT_INPUTS_AND_OUTPUTS (1)

#ifndef PRINT_INPUTS_AND_OUTPUTS
#define PRINT_INPUTS_AND_OUTPUTS (0)
#endif

// Kernel under test. The check in main() expects it to read `count`
// consecutive 4x4 int8_t matrices (16 elements each, row-major) from each of
// `lhs` and `rhs_t`, and to write `count` 4x4 int32_t matrices to `out` with
// out[b][j][i] = sum_k lhs[b][j][k] * rhs_t[b][i][k], touching nothing past
// the last output matrix.
extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs, const int8_t *rhs_t, std::size_t count);

namespace {

// Each matrix is kDim x kDim, stored row-major in kMatrixElems elements.
constexpr std::size_t kDim = 4;
constexpr std::size_t kMatrixElems = kDim * kDim;

// Number of matrix pairs handed to the kernel in a single call.
constexpr std::size_t kMatrixCount = 37;

// Elements in each input array (and in the real part of the output).
constexpr std::size_t kInputElems = kMatrixCount * kMatrixElems;

// Output arrays carry one extra guard matrix after the real outputs.
constexpr std::size_t kOutputElems = kInputElems + kMatrixElems;

// Value stored in the guard matrix; the kernel must leave it untouched.
constexpr int32_t kGuardValue = 1337;

// Pre-fill for the real output slots. A sum of four int8*int8 products lies in
// [-65024, 65536], so INT32_MIN is never a correct result: any slot the kernel
// fails to write is guaranteed to be counted as an error.
constexpr int32_t kUnwrittenValue = INT32_MIN;

// Prints `title`, then `count` consecutive kDim x kDim matrices from `data`.
// Each value is right-aligned in a field `width` characters wide and followed
// by ','. When `reference` is non-null, a value that differs from the element
// at the same index in `reference` is followed by '/' instead.
template <typename T>
void print_matrices(const char *title, const T *data, std::size_t count,
                    int width, const T *reference = nullptr) {
  printf("%s", title);
  for (std::size_t b = 0; b < count; b++) {
    // %lu + cast: printing a size_t with %d is undefined behaviour.
    printf("b = %lu:\n", static_cast<unsigned long>(b));
    for (std::size_t j = 0; j < kDim; j++) {
      printf(" ");
      for (std::size_t i = 0; i < kDim; i++) {
        const std::size_t idx = i + kDim * j + b * kMatrixElems;
        const bool same = reference == nullptr || data[idx] == reference[idx];
        printf("%*d%c", width, static_cast<int>(data[idx]), same ? ',' : '/');
      }
      printf("\n");
    }
    printf("\n");
  }
}

}  // namespace

// Runs vector_matmul4_asm on random inputs, compares every output element (and
// the trailing guard matrix) against a scalar reference, optionally prints
// everything, and returns the number of mismatching elements (clamped to
// INT_MAX).
extern "C" int main(void) {
  int8_t lhs[kInputElems];
  int8_t rhs_t[kInputElems];
  int32_t result[kOutputElems];
  int32_t golden[kOutputElems];

  // Default-constructed engine: the same input sequence on every run with a
  // given standard library, so failures are reproducible.
  std::default_random_engine generator;
  // uniform_int_distribution is only defined for short/int/long/long long and
  // their unsigned forms; instantiating it with int8_t (a character type) is
  // undefined behaviour and rejected by some standard libraries. Draw ints in
  // the int8_t range and narrow instead.
  std::uniform_int_distribution<int> distribution(INT8_MIN, INT8_MAX);

  for (std::size_t i = 0; i < kInputElems; i++) {
    lhs[i] = static_cast<int8_t>(distribution(generator));
    rhs_t[i] = static_cast<int8_t>(distribution(generator));
  }

  // Poison the real output slots so an element the kernel skips can never
  // match the reference by accident.
  for (std::size_t i = 0; i < kInputElems; i++) {
    result[i] = kUnwrittenValue;
  }

  // One extra guard matrix to ensure the assembly doesn't go past the end.
  for (std::size_t i = kInputElems; i < kOutputElems; i++) {
    result[i] = kGuardValue;
    golden[i] = kGuardValue;
  }

  vector_matmul4_asm(result, lhs, rhs_t, kMatrixCount);

  // Scalar reference: golden[b][j][i] = sum_k lhs[b][j][k] * rhs_t[b][i][k],
  // i.e. lhs * rhs with rhs supplied in transposed (rhs_t) form.
  for (std::size_t b = 0; b < kMatrixCount; b++) {
    for (std::size_t j = 0; j < kDim; j++) {
      for (std::size_t i = 0; i < kDim; i++) {
        int32_t acc = 0;
        for (std::size_t k = 0; k < kDim; k++) {
          acc += lhs[k + j * kDim + b * kMatrixElems] *
                 rhs_t[k + i * kDim + b * kMatrixElems];
        }
        golden[i + j * kDim + b * kMatrixElems] = acc;
      }
    }
  }

  // Compare every real output and the guard matrix element by element.
  std::size_t errors = 0;
  for (std::size_t n = 0; n < kOutputElems; n++) {
    if (result[n] != golden[n]) {
      errors++;
    }
  }

  if (PRINT_INPUTS_AND_OUTPUTS) {
    print_matrices("lhs:\n", lhs, kMatrixCount, 5);
    print_matrices("rhs_t:\n", rhs_t, kMatrixCount, 5);
    print_matrices("golden:\n", golden, kOutputElems / kMatrixElems, 7);
    print_matrices("\nresults:\n", result, kOutputElems / kMatrixElems, 7,
                   golden);
    printf("\n%lu errors\n", static_cast<unsigned long>(errors));
  }

  // NOTE(review): hosted (POSIX) environments keep only the low 8 bits of the
  // exit status, so an error count that is a multiple of 256 would read as
  // success there; confirm how the harness consumes this value.
  return (errors > INT_MAX) ? INT_MAX : static_cast<int>(errors);
}