| //===- MLIRRunnerUtils.h - Utils for debugging MLIR CPU execution ---------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| // This file is copied from llvm-project/mlir/test/mlir-cpu-runner/include. |
| // It sits at an incorrect location in the MLIR repository because support for |
| // integration with external libraries is not considered core atm. |
| // When this file lives at a more reasonable location in MLIR, we can remove |
| // this unnecessary duplication. |
| |
| #ifndef MLIR_CPU_RUNNER_MLIRUTILS_H_ |
| #define MLIR_CPU_RUNNER_MLIRUTILS_H_ |
| |
| #include <assert.h> |
| |
| #include <cstdint> |
| #include <iostream> |
| |
| #ifdef _WIN32 |
| #ifndef MLIR_RUNNER_UTILS_EXPORT |
| #ifdef mlir_runner_utils_EXPORTS |
| /* We are building this library */ |
| #define MLIR_RUNNER_UTILS_EXPORT __declspec(dllexport) |
| #else |
| /* We are using this library */ |
| #define MLIR_RUNNER_UTILS_EXPORT __declspec(dllimport) |
| #endif // mlir_runner_utils_EXPORTS |
| #endif // MLIR_RUNNER_UTILS_EXPORT |
| #else |
| #define MLIR_RUNNER_UTILS_EXPORT |
| #endif // _WIN32 |
| |
| template <typename T, int N> |
| struct StridedMemRefType; |
| template <typename StreamType, typename T, int N> |
| void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V); |
| |
| template <int N> |
| void dropFront(int64_t arr[N], int64_t *res) { |
| for (unsigned i = 1; i < N; ++i) *(res + i - 1) = arr[i]; |
| } |
| |
| /// StridedMemRef descriptor type with static rank. |
| template <typename T, int N> |
| struct StridedMemRefType { |
| T *basePtr; |
| T *data; |
| int64_t offset; |
| int64_t sizes[N]; |
| int64_t strides[N]; |
| // This operator[] is extremely slow and only for sugaring purposes. |
| StridedMemRefType<T, N - 1> operator[](int64_t idx) { |
| StridedMemRefType<T, N - 1> res; |
| res.basePtr = basePtr; |
| res.data = data; |
| res.offset = offset + idx * strides[0]; |
| dropFront<N>(sizes, res.sizes); |
| dropFront<N>(strides, res.strides); |
| return res; |
| } |
| }; |
| |
| /// StridedMemRef descriptor type specialized for rank 1. |
| template <typename T> |
| struct StridedMemRefType<T, 1> { |
| T *basePtr; |
| T *data; |
| int64_t offset; |
| int64_t sizes[1]; |
| int64_t strides[1]; |
| T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); } |
| }; |
| |
| /// StridedMemRef descriptor type specialized for rank 0. |
| template <typename T> |
| struct StridedMemRefType<T, 0> { |
| T *basePtr; |
| T *data; |
| int64_t offset; |
| }; |
| |
| // Unranked MemRef |
| template <typename T> |
| struct UnrankedMemRefType { |
| int64_t rank; |
| void *descriptor; |
| }; |
| |
| template <typename StreamType, typename T, int N> |
| void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) { |
| static_assert(N > 0, "Expected N > 0"); |
| os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << N |
| << " offset = " << V.offset << " sizes = [" << V.sizes[0]; |
| for (unsigned i = 1; i < N; ++i) os << ", " << V.sizes[i]; |
| os << "] strides = [" << V.strides[0]; |
| for (unsigned i = 1; i < N; ++i) os << ", " << V.strides[i]; |
| os << "]"; |
| } |
| |
| template <typename StreamType, typename T> |
| void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) { |
| os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = 0" |
| << " offset = " << V.offset; |
| } |
| |
| template <typename T, typename StreamType> |
| void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) { |
| os << "Unranked Memref rank = " << V.rank << " " |
| << "descriptor@ = " << reinterpret_cast<void *>(V.descriptor) << "\n"; |
| } |
| |
| template <typename T, int Dim, int... Dims> |
| struct Vector { |
| Vector<T, Dims...> vector[Dim]; |
| }; |
| template <typename T, int Dim> |
| struct Vector<T, Dim> { |
| T vector[Dim]; |
| }; |
| |
| template <int D1, typename T> |
| using Vector1D = Vector<T, D1>; |
| template <int D1, int D2, typename T> |
| using Vector2D = Vector<T, D1, D2>; |
| template <int D1, int D2, int D3, typename T> |
| using Vector3D = Vector<T, D1, D2, D3>; |
| template <int D1, int D2, int D3, int D4, typename T> |
| using Vector4D = Vector<T, D1, D2, D3, D4>; |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Templated instantiation follows. |
| //////////////////////////////////////////////////////////////////////////////// |
| namespace impl { |
| template <typename T, int M, int... Dims> |
| std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v); |
| |
| template <int... Dims> |
| struct StaticSizeMult { |
| static constexpr int value = 1; |
| }; |
| |
| template <int N, int... Dims> |
| struct StaticSizeMult<N, Dims...> { |
| static constexpr int value = N * StaticSizeMult<Dims...>::value; |
| }; |
| |
| static inline void printSpace(std::ostream &os, int count) { |
| for (int i = 0; i < count; ++i) { |
| os << ' '; |
| } |
| } |
| |
| template <typename T, int M, int... Dims> |
| struct VectorDataPrinter { |
| static void print(std::ostream &os, const Vector<T, M, Dims...> &val); |
| }; |
| |
| template <typename T, int M, int... Dims> |
| void VectorDataPrinter<T, M, Dims...>::print(std::ostream &os, |
| const Vector<T, M, Dims...> &val) { |
| static_assert(M > 0, "0 dimensioned tensor"); |
| static_assert(sizeof(val) == M * StaticSizeMult<Dims...>::value * sizeof(T), |
| "Incorrect vector size!"); |
| // First |
| os << "(" << val.vector[0]; |
| if (M > 1) os << ", "; |
| if (sizeof...(Dims) > 1) os << "\n"; |
| // Kernel |
| for (unsigned i = 1; i + 1 < M; ++i) { |
| printSpace(os, 2 * sizeof...(Dims)); |
| os << val.vector[i] << ", "; |
| if (sizeof...(Dims) > 1) os << "\n"; |
| } |
| // Last |
| if (M > 1) { |
| printSpace(os, sizeof...(Dims)); |
| os << val.vector[M - 1]; |
| } |
| os << ")"; |
| } |
| |
| template <typename T, int M, int... Dims> |
| std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) { |
| VectorDataPrinter<T, M, Dims...>::print(os, v); |
| return os; |
| } |
| |
| template <typename T, int N> |
| struct MemRefDataPrinter { |
| static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, |
| int64_t *sizes, int64_t *strides); |
| static void printFirst(std::ostream &os, T *base, int64_t rank, |
| int64_t offset, int64_t *sizes, int64_t *strides); |
| static void printLast(std::ostream &os, T *base, int64_t rank, int64_t offset, |
| int64_t *sizes, int64_t *strides); |
| }; |
| |
| template <typename T> |
| struct MemRefDataPrinter<T, 0> { |
| static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, |
| int64_t *sizes = nullptr, int64_t *strides = nullptr); |
| }; |
| |
| template <typename T, int N> |
| void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base, |
| int64_t rank, int64_t offset, |
| int64_t *sizes, int64_t *strides) { |
| os << "["; |
| MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset, sizes + 1, |
| strides + 1); |
| // If single element, close square bracket and return early. |
| if (sizes[0] <= 1) { |
| os << "]"; |
| return; |
| } |
| os << ", "; |
| if (N > 1) os << "\n"; |
| } |
| |
| template <typename T, int N> |
| void MemRefDataPrinter<T, N>::print(std::ostream &os, T *base, int64_t rank, |
| int64_t offset, int64_t *sizes, |
| int64_t *strides) { |
| printFirst(os, base, rank, offset, sizes, strides); |
| for (unsigned i = 1; i + 1 < sizes[0]; ++i) { |
| printSpace(os, rank - N + 1); |
| MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * strides[0], |
| sizes + 1, strides + 1); |
| os << ", "; |
| if (N > 1) os << "\n"; |
| } |
| if (sizes[0] <= 1) return; |
| printLast(os, base, rank, offset, sizes, strides); |
| } |
| |
| template <typename T, int N> |
| void MemRefDataPrinter<T, N>::printLast(std::ostream &os, T *base, int64_t rank, |
| int64_t offset, int64_t *sizes, |
| int64_t *strides) { |
| printSpace(os, rank - N + 1); |
| MemRefDataPrinter<T, N - 1>::print(os, base, rank, |
| offset + (sizes[0] - 1) * (*strides), |
| sizes + 1, strides + 1); |
| os << "]"; |
| } |
| |
| template <typename T> |
| void MemRefDataPrinter<T, 0>::print(std::ostream &os, T *base, int64_t rank, |
| int64_t offset, int64_t *sizes, |
| int64_t *strides) { |
| os << base[offset]; |
| } |
| |
| template <typename T, int N> |
| void printMemRef(StridedMemRefType<T, N> &M) { |
| static_assert(N > 0, "Expected N > 0"); |
| printMemRefMetaData(std::cout, M); |
| std::cout << " data = " << std::endl; |
| MemRefDataPrinter<T, N>::print(std::cout, M.data, N, M.offset, M.sizes, |
| M.strides); |
| std::cout << std::endl; |
| } |
| |
| template <typename T> |
| void printMemRef(StridedMemRefType<T, 0> &M) { |
| printMemRefMetaData(std::cout, M); |
| std::cout << " data = " << std::endl; |
| std::cout << "["; |
| MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset); |
| std::cout << "]" << std::endl; |
| } |
| } // namespace impl |
| |
| #endif // MLIR_CPU_RUNNER_MLIRUTILS_H_ |