Create templated version of Swizzle.
Use in ConvS8K3x1D48.
Change-Id: Ia60680debc4622383800ce2927faa7e060fbbd1b
diff --git a/tflm/opt/conv_s8_3x1_d48.cc b/tflm/opt/conv_s8_3x1_d48.cc
index 70a23b0..b924e24 100644
--- a/tflm/opt/conv_s8_3x1_d48.cc
+++ b/tflm/opt/conv_s8_3x1_d48.cc
@@ -63,9 +63,9 @@
int32_t bias[48 * 4];
int32_t mult[48 * 4];
int32_t shft[48 * 4];
- Swizzle(bias_data, bias, 48);
- Swizzle(output_multiplier, mult, 48);
- Swizzle(output_shift, shft, 48, true);
+ Swizzle<48>(bias_data, bias);
+ Swizzle<48>(output_multiplier, mult);
+ Swizzle<48, true>(output_shift, shft);
int8_t juggled_filter_data[48 / 8][3][1][48 / 4][8][4];
Filter_N_H_W_M<48>(filter_data, juggled_filter_data[0][0][0][0][0], 3, 1, 48);
diff --git a/tflm/opt/conv_util.h b/tflm/opt/conv_util.h
index 34f3857..a5151c3 100644
--- a/tflm/opt/conv_util.h
+++ b/tflm/opt/conv_util.h
@@ -130,6 +130,33 @@
}
}
+template <int N, bool negate=false>  // N: number of input elements; negate: flip the sign of every output.
+inline void Swizzle(const int32_t* input, int32_t* output) {  // Expands N int32 inputs into N * 4 swizzled int32 outputs.
+ const int32_t(&in)[N] = *(int32_t(*)[N])input;  // View input as an N-element array (the C-style cast drops const, the reference restores it).
+ int32_t(&out)[N * 4] = *(int32_t(*)[N * 4]) output;  // View output as an N * 4 element array.
+ // Convert to accumulator swizzle pattern: each group of 8 inputs fills 32 consecutive outputs.
+ for (int i = 0; i < N / 8; ++i) {  // NOTE(review): only whole groups of 8 are written; presumably N is a multiple of 8 - confirm.
+ int32_t* out0 = out + i * 32 + 0;  // Fills outputs [0, 8) of this group.
+ int32_t* out1 = out + i * 32 + 16;  // Fills outputs [16, 24) of this group.
+ int32_t* out2 = out + i * 32 + 8;  // Fills outputs [8, 16) of this group.
+ int32_t* out3 = out + i * 32 + 24;  // Fills outputs [24, 32) of this group.
+ for (int j = 0; j < 4; ++j) {  // Four passes, each replicating the same 8 inputs.
+ const int32_t* p_in = in + i * 8;  // Rewind to the first input of this group on every pass.
+ for (int k = 0; k < 2; ++k) {  // Two rounds of four reads consume the 8 inputs.
+ *out0++ = *p_in++;  // Inputs 0 and 4 of the group.
+ *out1++ = *p_in++;  // Inputs 1 and 5 of the group.
+ *out2++ = *p_in++;  // Inputs 2 and 6 of the group.
+ *out3++ = *p_in++;  // Inputs 3 and 7 of the group.
+ }
+ }
+ }
+ if (negate) {  // Compile-time flag: when set, negate all N * 4 outputs in place.
+ for (int i = 0; i < N * 4; ++i) {
+ out[i] = -out[i];
+ }
+ }
+}
+
// Runs strip-mined output pipeline (without bias addition) in place on
// registers.
#define INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(result, mult, shft, output_min, \