[ModelBuilder] new example that test MLIR building and execution PiperOrigin-RevId: 303201274
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD index f34fb38..acd39e2 100644 --- a/experimental/ModelBuilder/test/BUILD +++ b/experimental/ModelBuilder/test/BUILD
@@ -30,6 +30,7 @@ data = [ ":runtime-support.so", # Tests. + ":test-dot-prod", ":test-mnist-jit", ":test-simple-jit", ":test-simple-mlir", @@ -42,6 +43,24 @@ ) cc_binary( + name = "test-dot-prod", + srcs = ["TestDotProdJIT.cpp"], + tags = [ + "noga", + ], + deps = [ + ":runtime-support.so", + "//experimental/ModelBuilder", + "//experimental/ModelBuilder:ModelRunner", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:EDSC", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LoopOpsTransforms", + ], +) + +cc_binary( name = "test-mnist-jit", srcs = ["TestMNISTJIT.cpp"], tags = [
diff --git a/experimental/ModelBuilder/test/TestDotProdJIT.cpp b/experimental/ModelBuilder/test/TestDotProdJIT.cpp new file mode 100644 index 0000000..2751a6c --- /dev/null +++ b/experimental/ModelBuilder/test/TestDotProdJIT.cpp
@@ -0,0 +1,111 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// clang-format off + +// NOLINTNEXTLINE +// RUN: test-dot-prod -runtime-support=$(dirname %s)/runtime-support.so 2>&1 | IreeFileCheck %s + +// clang-format on + +#include "experimental/ModelBuilder/MemRefUtils.h" +#include "experimental/ModelBuilder/ModelBuilder.h" +#include "experimental/ModelBuilder/ModelRunner.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" + +using namespace mlir; // NOLINT +using namespace mlir::edsc; // NOLINT +using namespace mlir::edsc::intrinsics; // NOLINT + +static llvm::cl::opt<std::string> runtimeSupport( + "runtime-support", llvm::cl::desc("Runtime support library filename"), + llvm::cl::value_desc("filename"), llvm::cl::init("-")); + +void DotProdOnVectors() { + constexpr unsigned N = 4; + + ModelBuilder modelBuilder; + // Build a func "dot_prod". + constexpr StringLiteral funcName = "dot-prod"; + auto f32 = modelBuilder.f32; + auto vectorType = modelBuilder.getVectorType(N, f32); + auto refType = modelBuilder.getMemRefType(1, vectorType); + + auto func = modelBuilder.makeFunction(funcName, {}, {refType, refType}); + + SmallVector<AffineMap, 3> accesses; + accesses.push_back(modelBuilder.getDimIdentityMap()); + accesses.push_back(accesses[0]); + accesses.push_back(AffineMap::get(1, 0, modelBuilder.getContext())); + + SmallVector<Attribute, 1> iterator_types; + iterator_types.push_back(modelBuilder.getStringAttr("reduction")); + + OpBuilder b(&func.getBody()); + ScopedContext scope(b, func.getLoc()); + ValueHandle A(func.getArgument(0)), B(func.getArgument(1)); + Value zero = ValueBuilder<ConstantIndexOp>(0); + Value A_val = std_load(A, ValueRange(zero)); + Value B_val = std_load(B, ValueRange(zero)); + Value zeroF = std_constant_float(APFloat(0.0f), f32); + Value res_val = (vector_contract(A_val, B_val, zeroF, + modelBuilder.getAffineMapArrayAttr(accesses), + modelBuilder.getArrayAttr(iterator_types))); + + (vector_print(A_val)); + (vector_print(B_val)); + (vector_print(res_val)); + + std_ret(); + + // Compile the function, pass in runtime support library + // to the execution engine for vector.print. + ModelRunner runner(modelBuilder.getModuleRef()); + runner.compile(CompilationOptions(), runtimeSupport); + + // initialize data by interoperating with the MLIR ABI by codegen. + auto inputInit1 = [](unsigned idx, Vector1D<N, float> *ptr) { + for (unsigned i = 0; i < N; ++i) ptr[idx][i] = 3.0 * i; + }; + auto inputInit2 = [](unsigned idx, Vector1D<N, float> *ptr) { + for (unsigned i = 0; i < N; ++i) ptr[idx][i] = 2.0 * i; + }; + + auto _A = makeInitializedStridedMemRefDescriptor<Vector1D<N, float>, 1>( + {N}, inputInit1); + auto _B = makeInitializedStridedMemRefDescriptor<Vector1D<N, float>, 1>( + {N}, inputInit2); + + // Call the funcOp + const std::string funAdapterName = + (llvm::Twine("_mlir_ciface_") + funcName).str(); + auto *bufferA = _A.get(); + auto *bufferB = _B.get(); + void *args[3] = {&bufferA, &bufferB}; + + // CHECK: ( 0, 3, 6, 9 ) + // CHECK: ( 0, 2, 4, 6 ) + // CHECK: 84 + auto err = + runner.engine->invoke(funAdapterName, MutableArrayRef<void *>{args}); + + if (err) llvm_unreachable("Error running function."); +} + +int main(int argc, char **argv) { + llvm::InitLLVM y(argc, argv); + llvm::cl::ParseCommandLineOptions(argc, argv, "TestDotProd\n"); + DotProdOnVectors(); +}