[ModelBuilder] Reduce benchmark set up boiler-plate code

PiperOrigin-RevId: 303182422
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp
index 140d370..6a745fb 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp
@@ -27,16 +27,15 @@
   return AffineMap::get(3, 0, results);
 }
 
-// Helper method to build matrix-matrix-transposed multiplication.
-template <unsigned M, unsigned N, unsigned K, unsigned I>
+// Helper method to build a NxN matrix-matrix-transpose multiplication function
+// using vector dialect that runs I times to amortize any calling overhead.
+template <unsigned N, unsigned I>
 void buildMatMat(ModelBuilder &mb, StringLiteral fn) {
   auto f32 = mb.f32;
-  auto mkVectorType = mb.getVectorType({M, K}, f32);
-  auto typeA = mb.getMemRefType({}, mkVectorType);
-  auto knVectorType = mb.getVectorType({K, N}, f32);
-  auto typeB = mb.getMemRefType({}, knVectorType);
-  auto mnVectorType = mb.getVectorType({M, N}, f32);
-  auto typeC = mb.getMemRefType({}, mnVectorType);
+  auto nnVectorType = mb.getVectorType({N, N}, f32);
+  auto typeA = mb.getMemRefType({}, nnVectorType);
+  auto typeB = typeA;
+  auto typeC = typeA;
 
   auto f = mb.makeFunction(fn, {}, {typeA, typeB, typeC});
   OpBuilder b(&f.getBody());
@@ -58,7 +57,7 @@
   iterator_types.push_back(mb.getStringAttr("parallel"));
   iterator_types.push_back(mb.getStringAttr("reduction"));
 
-  // Loop I times over the kernel to reduce the JIT's overhead.
+  // Loop I times over the kernel to amortize calling overhead.
   auto loop =
       b.create<loop::ForOp>(f.getLoc(), std_constant_index(0),
                             std_constant_index(I), std_constant_index(1));
@@ -77,42 +76,42 @@
 }
 
 // Benchmark method.
-template <unsigned M, unsigned N, unsigned K>
-void testMatMulUsingVectors(benchmark::State &state, StringLiteral funcName,
-                            bool measureBuild) {
+template <unsigned N, bool MeasureBuild>
+void BM_MxMT_UsingVector(benchmark::State &state) {
   // Prepare arguments beforehand.
-  auto oneInit = [](unsigned idx, Vector2D<M, N, float> *ptr) {
+  auto oneInit = [](unsigned idx, Vector2D<N, N, float> *ptr) {
     float *p = reinterpret_cast<float *>(ptr + idx);
-    for (unsigned i = 0; i < M * N; ++i) p[i] = 1.0f;
+    for (unsigned i = 0; i < N * N; ++i) p[i] = 1.0f;
   };
-  auto incInit = [](unsigned idx, Vector2D<M, N, float> *ptr) {
+  auto incInit = [](unsigned idx, Vector2D<N, N, float> *ptr) {
     float *p = reinterpret_cast<float *>(ptr + idx);
-    for (unsigned i = 0; i < M * N; ++i) p[i] = 1.0f + i;
+    for (unsigned i = 0; i < N * N; ++i) p[i] = 1.0f + i;
   };
-  auto zeroInit = [](unsigned idx, Vector2D<M, N, float> *ptr) {
+  auto zeroInit = [](unsigned idx, Vector2D<N, N, float> *ptr) {
     float *p = reinterpret_cast<float *>(ptr + idx);
-    for (unsigned i = 0; i < M * N; ++i) p[i] = 0.0f;
+    for (unsigned i = 0; i < N * N; ++i) p[i] = 0.0f;
   };
-  auto A = makeInitializedStridedMemRefDescriptor<Vector2D<M, N, float>, 1>(
+  auto A = makeInitializedStridedMemRefDescriptor<Vector2D<N, N, float>, 1>(
       {1}, oneInit);
-  auto B = makeInitializedStridedMemRefDescriptor<Vector2D<M, N, float>, 1>(
+  auto B = makeInitializedStridedMemRefDescriptor<Vector2D<N, N, float>, 1>(
       {1}, incInit);
-  auto C = makeInitializedStridedMemRefDescriptor<Vector2D<M, N, float>, 1>(
+  auto C = makeInitializedStridedMemRefDescriptor<Vector2D<N, N, float>, 1>(
       {1}, zeroInit);
   auto *bufferA = A.get();
   auto *bufferB = B.get();
   auto *bufferC = C.get();
   void *args[3] = {&bufferA, &bufferB, &bufferC};
+  StringLiteral funcName = "mat_mul_trans";
   const std::string kFuncAdapterName =
       (llvm::Twine("_mlir_ciface_") + funcName).str();
 
-  if (measureBuild) {
+  if (MeasureBuild) {
     // If this is a build-time benchmark, build, compile, and execute
     // the function inside the timed loop, building a fresh new function
     // in each iteration to get the full JIT time (keep I == 1 here).
     for (auto _ : state) {
       ModelBuilder builder;
-      buildMatMat<M, N, K, 1>(builder, funcName);
+      buildMatMat<N, 1>(builder, funcName);
       ModelRunner runner(builder.getModuleRef());
       runner.compile(CompilationOptions());
       auto err = runner.engine->invoke(kFuncAdapterName,
@@ -125,7 +124,7 @@
     // the same function inside the loop to focus on actual runtime
     // (set I == 1000 here to amortize calling overhead).
     ModelBuilder builder;
-    buildMatMat<M, N, K, 1000>(builder, funcName);
+    buildMatMat<N, 1000>(builder, funcName);
     ModelRunner runner(builder.getModuleRef());
     runner.compile(CompilationOptions());
     auto err =
@@ -140,59 +139,17 @@
 }
 
 //
-// Benchmark drivers (build).
+// Benchmark drivers (build and run).
 //
 
-static void BM_Build_MatMul_1_1(benchmark::State &state) {
-  testMatMulUsingVectors<1, 1, 1>(state, "test_matmul_1_1_1", true);
-}
-BENCHMARK(BM_Build_MatMul_1_1);
+#define JIT true
+#define RUN false
+#define BENCHMARK_MAT_MUL_TRANS(SZ_N)                 \
+  BENCHMARK_TEMPLATE(BM_MxMT_UsingVector, SZ_N, JIT); \
+  BENCHMARK_TEMPLATE(BM_MxMT_UsingVector, SZ_N, RUN);
 
-static void BM_Build_MatMul_2_2(benchmark::State &state) {
-  testMatMulUsingVectors<2, 2, 2>(state, "test_matmul_2_2_2", true);
-}
-BENCHMARK(BM_Build_MatMul_2_2);
-
-static void BM_Build_MatMul_4_4(benchmark::State &state) {
-  testMatMulUsingVectors<4, 4, 4>(state, "test_matmul_4_4_4", true);
-}
-BENCHMARK(BM_Build_MatMul_4_4);
-
-static void BM_Build_MatMul_8_8(benchmark::State &state) {
-  testMatMulUsingVectors<8, 8, 8>(state, "test_matmul_8_8_8", true);
-}
-BENCHMARK(BM_Build_MatMul_8_8);
-
-static void BM_Build_MatMul_16_16(benchmark::State &state) {
-  testMatMulUsingVectors<16, 16, 16>(state, "test_matmul_16_16_16", true);
-}
-BENCHMARK(BM_Build_MatMul_16_16);
-
-//
-// Benchmark drivers (run).
-//
-
-static void BM_Run1000_MatMul_1_1(benchmark::State &state) {
-  testMatMulUsingVectors<1, 1, 1>(state, "test_matmul_1_1_1", false);
-}
-BENCHMARK(BM_Run1000_MatMul_1_1);
-
-static void BM_Run1000_MatMul_2_2(benchmark::State &state) {
-  testMatMulUsingVectors<2, 2, 2>(state, "test_matmul_2_2_2", false);
-}
-BENCHMARK(BM_Run1000_MatMul_2_2);
-
-static void BM_Run1000_MatMul_4_4(benchmark::State &state) {
-  testMatMulUsingVectors<4, 4, 4>(state, "test_matmul_4_4_4", false);
-}
-BENCHMARK(BM_Run1000_MatMul_4_4);
-
-static void BM_Run1000_MatMul_8_8(benchmark::State &state) {
-  testMatMulUsingVectors<8, 8, 8>(state, "test_matmul_8_8_8", false);
-}
-BENCHMARK(BM_Run1000_MatMul_8_8);
-
-static void BM_Run1000_MatMul_16_16(benchmark::State &state) {
-  testMatMulUsingVectors<16, 16, 16>(state, "test_matmul_16_16_16", false);
-}
-BENCHMARK(BM_Run1000_MatMul_16_16);
+BENCHMARK_MAT_MUL_TRANS(1);
+BENCHMARK_MAT_MUL_TRANS(2);
+BENCHMARK_MAT_MUL_TRANS(4);
+BENCHMARK_MAT_MUL_TRANS(8);
+BENCHMARK_MAT_MUL_TRANS(16);
diff --git a/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp b/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp
index c59f026..f79caf2 100644
--- a/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp
+++ b/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp
@@ -31,8 +31,8 @@
   return AffineMap::get(2, 0, results);
 }
 
-// Helper method to build a NxN matrix-vector multiplication
-// that runs I times to amortize any calling overhead.
+// Helper method to build a NxN matrix-vector multiplication function
+// using vector dialect that runs I times to amortize any calling overhead.
 template <unsigned N, unsigned I>
 void buildMatMat(ModelBuilder &mb, StringLiteral fn) {
   auto f32 = mb.f32;
@@ -80,9 +80,8 @@
 }
 
 // Benchmark method.
-template <unsigned N>
-void testMatVecUsingVectors(benchmark::State &state, StringLiteral funcName,
-                            bool measureBuild) {
+template <unsigned N, bool MeasureBuild>
+void BM_MxV_UsingVector(benchmark::State &state) {
   // Prepare arguments beforehand.
   auto incInit = [](unsigned idx, Vector2D<N, N, float> *ptr) {
     float *p = reinterpret_cast<float *>(ptr + idx);
@@ -106,10 +105,11 @@
   auto *bufferB = B.get();
   auto *bufferC = C.get();
   void *args[3] = {&bufferA, &bufferB, &bufferC};
+  StringLiteral funcName = "matvec_mult";
   const std::string kFuncAdapterName =
       (llvm::Twine("_mlir_ciface_") + funcName).str();
 
-  if (measureBuild) {
+  if (MeasureBuild) {
     // If this is a build-time benchmark, build, compile, and execute
     // the function inside the timed loop, building a fresh new function
     // in each iteration to get the full JIT time (keep I == 1 here).
@@ -143,69 +143,18 @@
 }
 
 //
-// Benchmark drivers (build).
+// Benchmark drivers (build and run).
 //
 
-static void BM_Build_MatVec_1(benchmark::State &state) {
-  testMatVecUsingVectors<1>(state, "test_matvec_1", true);
-}
-BENCHMARK(BM_Build_MatVec_1);
+#define JIT true
+#define RUN false
+#define BENCHMARK_MAT_VEC(SZ_N)                      \
+  BENCHMARK_TEMPLATE(BM_MxV_UsingVector, SZ_N, JIT); \
+  BENCHMARK_TEMPLATE(BM_MxV_UsingVector, SZ_N, RUN);
 
-static void BM_Build_MatVec_2(benchmark::State &state) {
-  testMatVecUsingVectors<2>(state, "test_matvec_2", true);
-}
-BENCHMARK(BM_Build_MatVec_2);
-
-static void BM_Build_MatVec_4(benchmark::State &state) {
-  testMatVecUsingVectors<4>(state, "test_matvec_4", true);
-}
-BENCHMARK(BM_Build_MatVec_4);
-
-static void BM_Build_MatVec_8(benchmark::State &state) {
-  testMatVecUsingVectors<8>(state, "test_matvec_8", true);
-}
-BENCHMARK(BM_Build_MatVec_8);
-
-static void BM_Build_MatVec_16(benchmark::State &state) {
-  testMatVecUsingVectors<16>(state, "test_matvec_16", true);
-}
-BENCHMARK(BM_Build_MatVec_16);
-
-static void BM_Build_MatVec_32(benchmark::State &state) {
-  testMatVecUsingVectors<32>(state, "test_matvec_32", true);
-}
-BENCHMARK(BM_Build_MatVec_32);
-
-//
-// Benchmark drivers (run).
-//
-
-static void BM_Run1000_MatVec_1(benchmark::State &state) {
-  testMatVecUsingVectors<1>(state, "test_matvec_1", false);
-}
-BENCHMARK(BM_Run1000_MatVec_1);
-
-static void BM_Run1000_MatVec_2(benchmark::State &state) {
-  testMatVecUsingVectors<2>(state, "test_matvec_2", false);
-}
-BENCHMARK(BM_Run1000_MatVec_2);
-
-static void BM_Run1000_MatVec_4(benchmark::State &state) {
-  testMatVecUsingVectors<4>(state, "test_matvec_4", false);
-}
-BENCHMARK(BM_Run1000_MatVec_4);
-
-static void BM_Run1000_MatVec_8(benchmark::State &state) {
-  testMatVecUsingVectors<8>(state, "test_matvec_8", false);
-}
-BENCHMARK(BM_Run1000_MatVec_8);
-
-static void BM_Run1000_MatVec_16(benchmark::State &state) {
-  testMatVecUsingVectors<16>(state, "test_matvec_16", false);
-}
-BENCHMARK(BM_Run1000_MatVec_16);
-
-static void BM_Run1000_MatVec_32(benchmark::State &state) {
-  testMatVecUsingVectors<32>(state, "test_matvec_32", false);
-}
-BENCHMARK(BM_Run1000_MatVec_32);
+BENCHMARK_MAT_VEC(1);
+BENCHMARK_MAT_VEC(2);
+BENCHMARK_MAT_VEC(4);
+BENCHMARK_MAT_VEC(8);
+BENCHMARK_MAT_VEC(16);
+BENCHMARK_MAT_VEC(32);