[ModelBuilder] new example that test MLIR building and execution
PiperOrigin-RevId: 303201274
diff --git a/experimental/ModelBuilder/test/BUILD b/experimental/ModelBuilder/test/BUILD
index f34fb38..acd39e2 100644
--- a/experimental/ModelBuilder/test/BUILD
+++ b/experimental/ModelBuilder/test/BUILD
@@ -30,6 +30,7 @@
data = [
":runtime-support.so",
# Tests.
+ ":test-dot-prod",
":test-mnist-jit",
":test-simple-jit",
":test-simple-mlir",
@@ -42,6 +43,24 @@
)
cc_binary(
+ name = "test-dot-prod",
+ srcs = ["TestDotProdJIT.cpp"],
+ tags = [
+ "noga",
+ ],
+ deps = [
+ ":runtime-support.so",
+ "//experimental/ModelBuilder",
+ "//experimental/ModelBuilder:ModelRunner",
+ "@llvm-project//llvm:support",
+ "@llvm-project//mlir:AllPassesAndDialects",
+ "@llvm-project//mlir:EDSC",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LoopOpsTransforms",
+ ],
+)
+
+cc_binary(
name = "test-mnist-jit",
srcs = ["TestMNISTJIT.cpp"],
tags = [
diff --git a/experimental/ModelBuilder/test/TestDotProdJIT.cpp b/experimental/ModelBuilder/test/TestDotProdJIT.cpp
new file mode 100644
index 0000000..2751a6c
--- /dev/null
+++ b/experimental/ModelBuilder/test/TestDotProdJIT.cpp
@@ -0,0 +1,111 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// clang-format off
+
+// NOLINTNEXTLINE
+// RUN: test-dot-prod -runtime-support=$(dirname %s)/runtime-support.so 2>&1 | IreeFileCheck %s
+
+// clang-format on
+
+#include "experimental/ModelBuilder/MemRefUtils.h"
+#include "experimental/ModelBuilder/ModelBuilder.h"
+#include "experimental/ModelBuilder/ModelRunner.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+
+using namespace mlir; // NOLINT
+using namespace mlir::edsc; // NOLINT
+using namespace mlir::edsc::intrinsics; // NOLINT
+
+static llvm::cl::opt<std::string> runtimeSupport(
+ "runtime-support", llvm::cl::desc("Runtime support library filename"),
+ llvm::cl::value_desc("filename"), llvm::cl::init("-"));
+
+void DotProdOnVectors() {
+ constexpr unsigned N = 4;
+
+ ModelBuilder modelBuilder;
+ // Build a func "dot_prod".
+ constexpr StringLiteral funcName = "dot-prod";
+ auto f32 = modelBuilder.f32;
+ auto vectorType = modelBuilder.getVectorType(N, f32);
+ auto refType = modelBuilder.getMemRefType(1, vectorType);
+
+ auto func = modelBuilder.makeFunction(funcName, {}, {refType, refType});
+
+ SmallVector<AffineMap, 3> accesses;
+ accesses.push_back(modelBuilder.getDimIdentityMap());
+ accesses.push_back(accesses[0]);
+ accesses.push_back(AffineMap::get(1, 0, modelBuilder.getContext()));
+
+ SmallVector<Attribute, 1> iterator_types;
+ iterator_types.push_back(modelBuilder.getStringAttr("reduction"));
+
+ OpBuilder b(&func.getBody());
+ ScopedContext scope(b, func.getLoc());
+ ValueHandle A(func.getArgument(0)), B(func.getArgument(1));
+ Value zero = ValueBuilder<ConstantIndexOp>(0);
+ Value A_val = std_load(A, ValueRange(zero));
+ Value B_val = std_load(B, ValueRange(zero));
+ Value zeroF = std_constant_float(APFloat(0.0f), f32);
+ Value res_val = (vector_contract(A_val, B_val, zeroF,
+ modelBuilder.getAffineMapArrayAttr(accesses),
+ modelBuilder.getArrayAttr(iterator_types)));
+
+ (vector_print(A_val));
+ (vector_print(B_val));
+ (vector_print(res_val));
+
+ std_ret();
+
+ // Compile the function, pass in runtime support library
+ // to the execution engine for vector.print.
+ ModelRunner runner(modelBuilder.getModuleRef());
+ runner.compile(CompilationOptions(), runtimeSupport);
+
+ // initialize data by interoperating with the MLIR ABI by codegen.
+ auto inputInit1 = [](unsigned idx, Vector1D<N, float> *ptr) {
+ for (unsigned i = 0; i < N; ++i) ptr[idx][i] = 3.0 * i;
+ };
+ auto inputInit2 = [](unsigned idx, Vector1D<N, float> *ptr) {
+ for (unsigned i = 0; i < N; ++i) ptr[idx][i] = 2.0 * i;
+ };
+
+ auto _A = makeInitializedStridedMemRefDescriptor<Vector1D<N, float>, 1>(
+ {N}, inputInit1);
+ auto _B = makeInitializedStridedMemRefDescriptor<Vector1D<N, float>, 1>(
+ {N}, inputInit2);
+
+ // Call the funcOp
+ const std::string funAdapterName =
+ (llvm::Twine("_mlir_ciface_") + funcName).str();
+ auto *bufferA = _A.get();
+ auto *bufferB = _B.get();
+ void *args[3] = {&bufferA, &bufferB};
+
+ // CHECK: ( 0, 3, 6, 9 )
+ // CHECK: ( 0, 2, 4, 6 )
+ // CHECK: 84
+ auto err =
+ runner.engine->invoke(funAdapterName, MutableArrayRef<void *>{args});
+
+ if (err) llvm_unreachable("Error running function.");
+}
+
+int main(int argc, char **argv) {
+ llvm::InitLLVM y(argc, argv);
+ llvm::cl::ParseCommandLineOptions(argc, argv, "TestDotProd\n");
+ DotProdOnVectors();
+}