[ModelBuilder] Make loop usage more idiomatic. PiperOrigin-RevId: 305772984
diff --git a/experimental/ModelBuilder/ModelBuilder.h b/experimental/ModelBuilder/ModelBuilder.h index a138654..14d4218 100644 --- a/experimental/ModelBuilder/ModelBuilder.h +++ b/experimental/ModelBuilder/ModelBuilder.h
@@ -81,6 +81,10 @@ using edsc::intrinsics::StdIndexedValue; // From the Affine Dialect. using edsc::intrinsics::AffineIndexedValue; +// From the Loop Dialect. +using edsc::AffineLoopNestBuilder; +using edsc::LoopNestBuilder; +using edsc::ParallelLoopNestBuilder; // ----------------------------------------------------------------------------- // Entry point class to build a whole model declaratively with C++ EDSCs.
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp index 56fc941..76d1804 100644 --- a/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp +++ b/experimental/ModelBuilder/test/BenchMatMulVectorColumnMajorLLVMIntrinsicsJIT.cpp
@@ -30,8 +30,10 @@ return results; } -// Helper method to build matrix-matrix-transposed multiplication. -template <unsigned M, unsigned N, unsigned K, unsigned I> +// Helper method to build a matrix-matrix column-major multiplication function +// using the vector dialect and that runs ITERS times to amortize any calling +// overhead. +template <unsigned M, unsigned N, unsigned K, unsigned ITERS> void buildMatMat(ModelBuilder &mb, StringLiteral fn) { auto f32 = mb.f32; auto mkVectorType = mb.getVectorType({M, K}, f32); @@ -64,21 +66,15 @@ iterator_types.push_back(mb.getStringAttr("parallel")); iterator_types.push_back(mb.getStringAttr("reduction")); - // Loop I times over the kernel to reduce the JIT's overhead. - auto loop = - b.create<loop::ForOp>(f.getLoc(), std_constant_index(0), - std_constant_index(I), std_constant_index(1)); - - OpBuilder bodyBuilder = loop.getBodyBuilder(); - { - edsc::ScopedContext bodyScope(bodyBuilder, f.getLoc()); + // Loop ITERS times over the kernel to reduce the JIT's overhead. + StdIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); + ValueHandle i(mb.getIndexType()); + LoopNestBuilder(&i, std_constant_index(0), std_constant_index(ITERS), + std_constant_index(1))([&] { // Compute C += A x B, in column-major form, with LLVM matrix intrinsics. - StdIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), - C(f.getArgument(2)); C() = (vector_contract(*A(), *B(), *C(), mb.getAffineMapArrayAttr(accesses), mb.getArrayAttr(iterator_types))); - } - + }); std_ret(); }
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp index 35eb9ec..d19fc4f 100644 --- a/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp +++ b/experimental/ModelBuilder/test/BenchMatMulVectorJIT.cpp
@@ -28,8 +28,9 @@ } // Helper method to build a NxN matrix-matrix-transpose multiplication function -// using vector dialect that runs I times to amortize any calling overhead. -template <unsigned N, unsigned I> +// using the vector dialect and that runs ITERS times to amortize any calling +// overhead. +template <unsigned N, unsigned ITERS> void buildMatMat(ModelBuilder &mb, StringLiteral fn) { auto f32 = mb.f32; auto nnVectorType = mb.getVectorType({N, N}, f32); @@ -57,21 +58,15 @@ iterator_types.push_back(mb.getStringAttr("parallel")); iterator_types.push_back(mb.getStringAttr("reduction")); - // Loop I times over the kernel to amortize calling overhead. - auto loop = - b.create<loop::ForOp>(f.getLoc(), std_constant_index(0), - std_constant_index(I), std_constant_index(1)); - - OpBuilder bodyBuilder = loop.getBodyBuilder(); - { - edsc::ScopedContext bodyScope(bodyBuilder, f.getLoc()); + // Loop ITERS times over the kernel to reduce the JIT's overhead. + StdIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); + ValueHandle i(mb.getIndexType()); + LoopNestBuilder(&i, std_constant_index(0), std_constant_index(ITERS), + std_constant_index(1))([&] { // Compute C += A x B^T with row-wise dot-products. - StdIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), - C(f.getArgument(2)); C() = (vector_contract(*A(), *B(), *C(), mb.getAffineMapArrayAttr(accesses), mb.getArrayAttr(iterator_types))); - } - + }); std_ret(); }
diff --git a/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp b/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp index ed4f17f..bdbf2a1 100644 --- a/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp +++ b/experimental/ModelBuilder/test/BenchMatVecVectorJIT.cpp
@@ -32,8 +32,9 @@ } // Helper method to build a NxN matrix-vector multiplication function -// using vector dialect that runs I times to amortize any calling overhead. -template <unsigned N, unsigned I> +// using the vector dialect and that runs ITERS times to amortize any calling +// overhead. +template <unsigned N, unsigned ITERS> void buildMatMat(ModelBuilder &mb, StringLiteral fn) { auto f32 = mb.f32; auto nnVectorType = mb.getVectorType({N, N}, f32); @@ -61,21 +62,15 @@ iterator_types.push_back(mb.getStringAttr("parallel")); iterator_types.push_back(mb.getStringAttr("reduction")); - // Loop I times over the kernel to amortize calling overhead. - auto loop = - b.create<loop::ForOp>(f.getLoc(), std_constant_index(0), - std_constant_index(I), std_constant_index(1)); - - OpBuilder bodyBuilder = loop.getBodyBuilder(); - { - edsc::ScopedContext bodyScope(bodyBuilder, f.getLoc()); + // Loop ITERS times over the kernel to reduce the JIT's overhead. + StdIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); + ValueHandle i(mb.getIndexType()); + LoopNestBuilder(&i, std_constant_index(0), std_constant_index(ITERS), + std_constant_index(1))([&] { // Compute c += A x b. - StdIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), - C(f.getArgument(2)); C() = (vector_contract(*A(), *B(), *C(), mb.getAffineMapArrayAttr(accesses), mb.getArrayAttr(iterator_types))); - } - + }); std_ret(); }