[ModelBuilder] Build a func for MNIST on tensors

This revision demonstrates how to build a full MNIST on tensors.
For now this can only emit IR but not run because we are missing a buffer allocation pass on Linalg tensors.

PiperOrigin-RevId: 293660377
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp
index 6389719..b1d4278 100644
--- a/experimental/ModelBuilder/ModelBuilder.cpp
+++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -32,6 +32,11 @@
       i8(IntegerType::get(8, &ctx)),
       f32(FloatType::getF32(&ctx)) {}
 
+Value mlir::ModelBuilder::constant_f32(float v) {
+  return constant_float(llvm::APFloat(v),
+                        FloatType::getF32(ScopedContext::getContext()));
+}
+
 FuncOp mlir::ModelBuilder::makeFunction(StringRef name, ArrayRef<Type> results,
                                         ArrayRef<Type> args, bool declOnly) {
   auto function =
@@ -46,6 +51,20 @@
   return MemRefType::get(shape, elementType, {});
 }
 
+RankedTensorType mlir::ModelBuilder::getRankedTensorType(
+    ArrayRef<int64_t> shape, Type elementType) {
+  return RankedTensorType::get(shape, elementType);
+}
+
+Value mlir::ModelBuilder::fusedBiasTanh(ValueHandle x, ValueHandle bias) {
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  using edsc::intrinsics::tanh;
+  assert(x.getType().isF32() && bias.getType().isF32() && "f32 expected");
+  ValueHandle half(constant_f32(0.5f));
+  return x + half * tanh((x + bias) * half) + half;
+}
+
 ValueHandle mlir::ModelBuilder::FCBiasTanh(std::array<Value, 3> fcArgs,
                                            Value biasValueArg) {
   //==========================================================================//
@@ -62,31 +81,36 @@
   AffineExpr i, j;
   bindDims(&ctx, i, j);
 
-  // Define the pointwise computation:
-  //   `0.5f * tanh(0.5f * (x + bias)) + 0.5f`
-  // This assumes ValueHandle captures an MLIR Value with a proper type
-  // (in this case `f32`)
-  auto opBuilder = [this](const ValueHandle &x,
-                          const ValueHandle &bias) -> Value {
-    using edsc::op::operator+;
-    using edsc::op::operator*;
-    using edsc::intrinsics::tanh;
-
-    // `0.5f * tanh(0.5f * (x + bias)) + 0.5f`
-    auto half = constant_float(llvm::APFloat(0.5f), f32);
-    return x + half * tanh((x + bias) * half) + half;
-  };
-
   // Emit a linalg.generic op that implements pointwise with `opBuilder` for:
   //   `0.5f * tanh(0.5f * (x + bias)) + 0.5f`
   //
   // This performs the (inplace) computation:
-  //   `SO[i, j] <- pointwise(SBias[j], SO[i, j])`
+  //   `o[i, j] <- pointwise(bias[j], o[i, j])`
   //
-  // in which SBias is broadcast along `i`.
-  ValueHandle Bias(biasValueArg);
-  StructuredIndexed SO(O), SBias(Bias);
-  linalg_pointwise(opBuilder, SO({i, j}), SBias({j}), SO({i, j}));
+  // in which bias is broadcast along `i`.
+  StructuredIndexed o(O), bias(biasValueArg);
+  linalg_pointwise(fusedBiasTanh, o({i, j}), bias({j}), o({i, j}));
 
   return O;
 }
+
+Value ModelBuilder::FCBiasTanhTensors(RankedTensorType outputTensorType,
+                                      std::array<Value, 2> fcArgs,
+                                      Value biasValueArg) {
+  //==========================================================================//
+  // Layer 1: FC
+  //==========================================================================//
+  ValueHandle I(fcArgs[0]), W(fcArgs[1]);
+  Value O2 = linalg_matmul(I, W, outputTensorType)->getResult(0);
+
+  //==========================================================================//
+  // Layer 2: BiasAddTanh Block
+  //==========================================================================//
+  ValueHandle Bias(biasValueArg);
+  AffineExpr i, j;
+  bindDims(&ctx, i, j);
+  // in-place with explicit bias broadcast
+  StructuredIndexed o2(O2), bias(Bias), o3Type(outputTensorType);
+  return linalg_pointwise(fusedBiasTanh, o2({i, j}), bias({j}), o3Type({i, j}))
+      ->getResult(0);
+}
diff --git a/experimental/ModelBuilder/ModelBuilder.h b/experimental/ModelBuilder/ModelBuilder.h
index 2b7b718..e9c0a7b7 100644
--- a/experimental/ModelBuilder/ModelBuilder.h
+++ b/experimental/ModelBuilder/ModelBuilder.h
@@ -80,6 +80,9 @@
 
   OwningModuleRef &getModuleRef() { return module; }
 
+  // Build the MLIR representation for an f32 constant.
+  static Value constant_f32(float v);
+
   // Build an MLIR FuncOp that will be callable after JIT compilation occured.
   FuncOp makeFunction(StringRef name, ArrayRef<Type> results = {},
                       ArrayRef<Type> args = {}, bool declOnly = false);
@@ -91,12 +94,30 @@
   // per-need basis.
   MemRefType getMemRefType(ArrayRef<int64_t> shape, Type elementType);
 
-  // FCBiasTanh implements:
+  // Build an MLIR RankedTensorType with a base `elementType` and a `shape` that
+  // can be any mix of static and dynamic values. For now this only supports a
+  // dense and contiguous layout.
+  // In the future, this can be extended to support more advanced layouts, on a
+  // per-need basis.
+  RankedTensorType getRankedTensorType(ArrayRef<int64_t> shape,
+                                       Type elementType);
+
+  // Build the MLIR representation for:
   //   1. fc(I, W, O)
   //   2. pointwise(O, bias) in-place with explicit bias broadcast to compute:
   //      `0.5f * tanh(0.5f * (x + bias)) + 0.5f`
   // Returns O.
+  // Version with a MemRef output argument.
   ValueHandle FCBiasTanh(std::array<Value, 3> fcArgs, Value biasValueArg);
+  // Version with a RankedTensor result.
+  Value FCBiasTanhTensors(RankedTensorType outputTensorType,
+                          std::array<Value, 2> fcArgs, Value biasValueArg);
+
+  // Build the MLIR representation for:
+  //   `0.5f * tanh(0.5f * (x + bias)) + 0.5f`
+  // This assumes `x` and `bias` capture scalar MLIR values of type f32.
+  // This is used as a region builder when constructing e.g. a pointwise op.
+  static Value fusedBiasTanh(ValueHandle x, ValueHandle bias);
 
  protected:
   static thread_local MLIRContext ctx;
diff --git a/experimental/ModelBuilder/test/TestMNISTJIT.cpp b/experimental/ModelBuilder/test/TestMNISTJIT.cpp
index 8abde92..0b582ba 100644
--- a/experimental/ModelBuilder/test/TestMNISTJIT.cpp
+++ b/experimental/ModelBuilder/test/TestMNISTJIT.cpp
@@ -12,11 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// RUN: test-mnist-jit | IreeFileCheck %s
+// RUN: test-mnist-jit 2>&1 | IreeFileCheck %s
 
 #include "experimental/ModelBuilder/MemRefUtils.h"
 #include "experimental/ModelBuilder/ModelBuilder.h"
 #include "experimental/ModelBuilder/ModelRunner.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/StandardTypes.h"
 
 using namespace mlir;
 
@@ -65,8 +67,8 @@
       alloc(modelBuilder.getMemRefType({-1, W2}, f32), batchSize);
   Value outputBlock3 = func.getArgument(1);
 
-  auto zero = constant_float(llvm::APFloat(0.0f), f32);
-  auto someVal = constant_float(llvm::APFloat(0.1123f), f32);
+  ValueHandle zero(modelBuilder.constant_f32(0.0f));
+  ValueHandle someVal(modelBuilder.constant_f32(0.1123f));
   linalg_fill(h1Weights, someVal);
   linalg_fill(h2Weights, someVal);
   linalg_fill(h3Weights, someVal);
@@ -94,20 +96,81 @@
   (ret());
 }
 
+// Helper function to build a func `funcName` that takes a tensor for the input
+// in the form of a `tensor<?x784xf32>` as well as static tensors for all the
+// weights and biases.
+//
+// This is the counterpart of `buildMNIST` which builds a similar model on
+// buffers.
+void buildMNISTOnTensors(ModelBuilder &modelBuilder, StringLiteral funcName,
+                         int64_t B, int64_t W0, int64_t W1, int64_t W2,
+                         int64_t W3) {
+  auto f32 = modelBuilder.f32;
+  auto inputType = modelBuilder.getRankedTensorType({B, W0}, f32);
+  auto h1WeightsType = modelBuilder.getRankedTensorType({W0, W1}, f32);
+  auto h2WeightsType = modelBuilder.getRankedTensorType({W1, W2}, f32);
+  auto h3WeightsType = modelBuilder.getRankedTensorType({W2, W3}, f32);
+  auto bias1Type = modelBuilder.getRankedTensorType({W1}, f32);
+  auto bias2Type = modelBuilder.getRankedTensorType({W2}, f32);
+  auto bias3Type = modelBuilder.getRankedTensorType({W3}, f32);
+  auto outputType = modelBuilder.getRankedTensorType({B, W3}, f32);
+  auto func = modelBuilder.makeFunction(
+      funcName, {outputType},
+      {inputType, h1WeightsType, h2WeightsType, h3WeightsType, bias1Type,
+       bias2Type, bias3Type});
+  Value input = func.getArgument(0);
+  Value h1Weights = func.getArgument(1);
+  Value h2Weights = func.getArgument(2);
+  Value h3Weights = func.getArgument(3);
+  Value bias1 = func.getArgument(4);
+  Value bias2 = func.getArgument(5);
+  Value bias3 = func.getArgument(6);
+
+  // 2. Fill the body (3 blocks of FCBiasTanh), alloc everything manually atm.
+  OpBuilder b(&func.getBody());
+  ScopedContext scope(b, func.getLoc());
+
+  auto outputBlock1Type = modelBuilder.getRankedTensorType({B, W1}, f32);
+  auto outputBlock1 = modelBuilder.FCBiasTanhTensors(outputBlock1Type,
+                                                     {input, h1Weights}, bias1);
+  auto outputBlock2Type = modelBuilder.getRankedTensorType({B, W2}, f32);
+  auto outputBlock2 = modelBuilder.FCBiasTanhTensors(
+      outputBlock2Type, {outputBlock1, h2Weights}, bias2);
+  auto outputBlock3Type = outputType;
+  auto outputBlock3 = modelBuilder.FCBiasTanhTensors(
+      outputBlock3Type, {outputBlock2, h3Weights}, bias3);
+  // Extra parentheses guard against the most vexing parse.
+  (ret(outputBlock3));
+}
+
 int main() {
-  constexpr StringLiteral kFuncName = "test_mnist_jit";
   constexpr unsigned B = 3, W0 = 784, W1 = 256, W2 = 256, W3 = 10;
 
-  // 1. Build a func "mnist" that takes a memref<?x784xf32> buffer
-  //    (use batch size M=3 in this example)
   ModelBuilder modelBuilder;
-  buildMNIST(modelBuilder, kFuncName, B, W0, W1, W2, W3);
+  // 1. Build a func "test_mnist_jit_tensors".
+  constexpr StringLiteral kFuncTensorsName = "test_mnist_jit_tensors";
+  buildMNISTOnTensors(modelBuilder, kFuncTensorsName, ShapedType::kDynamicSize,
+                      W0, W1, W2, W3);
+  // 1.b. Dump the function for testing and erase it: we can't compile it to
+  // buffers for now.
+  modelBuilder.getModuleRef()->dump();
+  SymbolTable::lookupNearestSymbolFrom(
+      modelBuilder.getModuleRef()->getOperation(), kFuncTensorsName)
+      ->erase();
 
-  // 2. Compile the function.
+  // 2. Build a separate func "test_mnist_jit_buffers" that takes a
+  // memref<?x784xf32> buffer
+  //    (use batch size M=3 in this example)
+  // In the future, when we can lower the function built in 1. to buffers we
+  // will.
+  constexpr StringLiteral kFuncBuffersName = "test_mnist_jit_buffers";
+  buildMNIST(modelBuilder, kFuncBuffersName, B, W0, W1, W2, W3);
+
+  // 3. Compile the function.
   ModelRunner runner(modelBuilder.getModuleRef());
   runner.compile(/*llvmOptLevel=*/3, /*llcOptLevel=*/3);
 
-  // 3. Allocate data within data structures that interoperate with the MLIR ABI
+  // 4. Allocate data within data structures that interoperate with the MLIR ABI
   // conventions used by codegen.
   auto inputLinearInit = [](unsigned idx, float *ptr) { *ptr = 0.032460f; };
   ManagedUnrankedMemRefDescriptor inputBuffer =
@@ -116,17 +179,55 @@
   ManagedUnrankedMemRefDescriptor outputBuffer =
       makeInitializedUnrankedDescriptor<float>({B, W3}, outputLinearInit);
 
-  // 4. Call the funcOp name `kFuncName` with arguments.
+  // 5. Call the funcOp name `kFuncBuffersName` with arguments.
   void *args[2] = {&inputBuffer->descriptor, &outputBuffer->descriptor};
-  auto error =
-      runner.engine->invoke(kFuncName, llvm::MutableArrayRef<void *>{args});
+  auto error = runner.engine->invoke(kFuncBuffersName,
+                                     llvm::MutableArrayRef<void *>{args});
 
-  // 5. Dump content of output buffer for testing with FileCheck.
+  // 6. Dump content of output buffer for testing with FileCheck.
   if (!error)
     ::impl::printMemRef(
         *static_cast<StridedMemRefType<float, 2> *>(outputBuffer->descriptor));
 }
 
+// For now, we can only dump the IR for `test_mnist_jit_tensors`.
+// Once buffer allocation is implemented we will only have an execution test.
+//
+// CHECK: func @test_mnist_jit_tensors
+//
+// Matmul
+// CHECK: linalg.generic
+// CHECK:   tensor<?x784xf32>, tensor<784x256xf32> -> tensor<?x256xf32>
+//
+// Pointwise
+// CHECK: linalg.generic
+// CHECK:   addf
+// CHECK:   mulf
+// CHECK:   tanh
+// CHECK:   mulf
+// CHECK:   addf
+// CHECK:   addf
+// CHECK:   tensor<?x256xf32>, tensor<256xf32> -> tensor<?x256xf32>
+//
+// Matmul
+// CHECK: linalg.generic
+// CHECK:   tensor<?x256xf32>, tensor<256x256xf32> -> tensor<?x256xf32>
+//
+// Pointwise
+// CHECK: linalg.generic
+// CHECK:   tensor<?x256xf32>, tensor<256xf32> -> tensor<?x256xf32>
+//
+// Matmul
+// CHECK: linalg.generic
+// CHECK:   tensor<?x256xf32>, tensor<256x10xf32> -> tensor<?x10xf32>
+//
+// Pointwise
+// CHECK: linalg.generic
+// CHECK:   tensor<?x10xf32>, tensor<10xf32> -> tensor<?x10xf32>
+// CHECK:   return {{.*}} : tensor<?x10xf32>
+
+// Execution test for `test_mnist_jit_buffers`.
+//
 // CHECK: Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 10]
 // CHECK-SAME: strides = [10, 1] data =
 // clang-format off