| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include <cinttypes> |
| #include <climits> |
| #include <cstdint> |
| #include <cstdio> |
| #include <cstdlib> |
| #include <random> |
| |
| // #define PRINT_INPUTS_AND_OUTPUTS (1) |
| |
| #ifndef PRINT_INPUTS_AND_OUTPUTS |
| #define PRINT_INPUTS_AND_OUTPUTS (0) |
| #endif |
| |
| extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs, |
| const int8_t *rhs_t, std::size_t count); |
| |
| extern "C" int main(void) { |
| int8_t lhs[16 * 37]; |
| int8_t rhs_t[16 * 37]; |
| int32_t result[sizeof(lhs) + 16]; |
| int32_t golden[sizeof(lhs) + 16]; |
| std::default_random_engine generator; |
| std::uniform_int_distribution<int8_t> distribution(INT8_MIN, INT8_MAX); |
| |
| for (std::size_t i = 0; i < sizeof(lhs); i++) { |
| lhs[i] = distribution(generator); |
| rhs_t[i] = distribution(generator); |
| } |
| |
| // One extra guard matrix to ensure the assembly doesn't go past the end |
| for (std::size_t i = sizeof(lhs); i < sizeof(lhs) + 16; i++) { |
| result[i] = 1337; |
| golden[i] = 1337; |
| } |
| |
| vector_matmul4_asm(result, lhs, rhs_t, sizeof(lhs) / 16); |
| |
| for (std::size_t b = 0; b < sizeof(lhs) / 16; b++) { |
| for (int j = 0; j < 4; j++) { |
| for (int i = 0; i < 4; i++) { |
| int32_t acc = 0; |
| for (int k = 0; k < 4; k++) { |
| acc += lhs[k + j * 4 + b * 16] * rhs_t[k + i * 4 + b * 16]; |
| } |
| golden[i + j * 4 + b * 16] = acc; |
| } |
| } |
| } |
| |
| std::size_t errors = 0; |
| for (std::size_t b = 0; b < sizeof(result) / sizeof(int32_t) / 16; b++) { |
| for (int j = 0; j < 4; j++) { |
| for (int i = 0; i < 4; i++) { |
| errors += |
| result[i + 4 * j + b * 16] == golden[i + 4 * j + b * 16] ? 0 : 1; |
| } |
| } |
| } |
| |
| if (PRINT_INPUTS_AND_OUTPUTS) { |
| printf("lhs:\n"); |
| for (std::size_t b = 0; b < sizeof(lhs) / sizeof(int8_t) / 16; b++) { |
| printf("b = %d:\n", b); |
| for (int j = 0; j < 4; j++) { |
| printf(" "); |
| for (int i = 0; i < 4; i++) { |
| printf("%5d,", static_cast<int>(lhs[i + 4 * j + b * 16])); |
| } |
| printf("\n"); |
| } |
| printf("\n"); |
| } |
| |
| printf("rhs_t:\n"); |
| for (std::size_t b = 0; b < sizeof(rhs_t) / sizeof(int8_t) / 16; b++) { |
| printf("b = %d:\n", b); |
| for (int j = 0; j < 4; j++) { |
| printf(" "); |
| for (int i = 0; i < 4; i++) { |
| printf("%5d,", static_cast<int>(rhs_t[i + 4 * j + b * 16])); |
| } |
| printf("\n"); |
| } |
| printf("\n"); |
| } |
| |
| printf("golden:\n"); |
| for (std::size_t b = 0; b < sizeof(golden) / sizeof(int32_t) / 16; b++) { |
| printf("b = %d:\n", b); |
| for (int j = 0; j < 4; j++) { |
| printf(" "); |
| for (int i = 0; i < 4; i++) { |
| printf("%7d,", static_cast<int>(golden[i + 4 * j + b * 16])); |
| } |
| printf("\n"); |
| } |
| printf("\n"); |
| } |
| |
| printf("\nresults:\n"); |
| for (std::size_t b = 0; b < sizeof(result) / sizeof(int32_t) / 16; b++) { |
| printf("b = %d:\n", b); |
| for (int j = 0; j < 4; j++) { |
| printf(" "); |
| for (int i = 0; i < 4; i++) { |
| bool same = result[i + 4 * j + b * 16] == golden[i + 4 * j + b * 16]; |
| printf("%7d%c", static_cast<int>(result[i + 4 * j + b * 16]), |
| same ? ',' : '/'); |
| } |
| printf("\n"); |
| } |
| printf("\n"); |
| } |
| |
| printf("\n%d errors\n", errors); |
| } |
| |
| return (errors > INT_MAX) ? INT_MAX : static_cast<int>(errors); |
| } |