[ModelBuilder] Build a func for MNIST on tensors This revision demonstrates how to build a full MNIST on tensors. For now this can only emit IR but not run because we are missing a buffer allocation pass on Linalg tensors. PiperOrigin-RevId: 293660377
diff --git a/experimental/ModelBuilder/ModelBuilder.cpp b/experimental/ModelBuilder/ModelBuilder.cpp index 6389719..b1d4278 100644 --- a/experimental/ModelBuilder/ModelBuilder.cpp +++ b/experimental/ModelBuilder/ModelBuilder.cpp
@@ -32,6 +32,11 @@ i8(IntegerType::get(8, &ctx)), f32(FloatType::getF32(&ctx)) {} +Value mlir::ModelBuilder::constant_f32(float v) { + return constant_float(llvm::APFloat(v), + FloatType::getF32(ScopedContext::getContext())); +} + FuncOp mlir::ModelBuilder::makeFunction(StringRef name, ArrayRef<Type> results, ArrayRef<Type> args, bool declOnly) { auto function = @@ -46,6 +51,20 @@ return MemRefType::get(shape, elementType, {}); } +RankedTensorType mlir::ModelBuilder::getRankedTensorType( + ArrayRef<int64_t> shape, Type elementType) { + return RankedTensorType::get(shape, elementType); +} + +Value mlir::ModelBuilder::fusedBiasTanh(ValueHandle x, ValueHandle bias) { + using edsc::op::operator+; + using edsc::op::operator*; + using edsc::intrinsics::tanh; + assert(x.getType().isF32() && bias.getType().isF32() && "f32 expected"); + ValueHandle half(constant_f32(0.5f)); + return x + half * tanh((x + bias) * half) + half; +} + ValueHandle mlir::ModelBuilder::FCBiasTanh(std::array<Value, 3> fcArgs, Value biasValueArg) { //==========================================================================// @@ -62,31 +81,36 @@ AffineExpr i, j; bindDims(&ctx, i, j); - // Define the pointwise computation: - // `0.5f * tanh(0.5f * (x + bias)) + 0.5f` - // This assumes ValueHandle captures an MLIR Value with a proper type - // (in this case `f32`) - auto opBuilder = [this](const ValueHandle &x, - const ValueHandle &bias) -> Value { - using edsc::op::operator+; - using edsc::op::operator*; - using edsc::intrinsics::tanh; - - // `0.5f * tanh(0.5f * (x + bias)) + 0.5f` - auto half = constant_float(llvm::APFloat(0.5f), f32); - return x + half * tanh((x + bias) * half) + half; - }; - // Emit a linalg.generic op that implements pointwise with `opBuilder` for: // `0.5f * tanh(0.5f * (x + bias)) + 0.5f` // // This performs the (inplace) computation: - // `SO[i, j] <- pointwise(SBias[j], SO[i, j])` + // `o[i, j] <- pointwise(bias[j], o[i, j])` // - // in which SBias is broadcast along `i`. 
- ValueHandle Bias(biasValueArg); - StructuredIndexed SO(O), SBias(Bias); - linalg_pointwise(opBuilder, SO({i, j}), SBias({j}), SO({i, j})); + // in which bias is broadcast along `i`. + StructuredIndexed o(O), bias(biasValueArg); + linalg_pointwise(fusedBiasTanh, o({i, j}), bias({j}), o({i, j})); return O; } + +Value ModelBuilder::FCBiasTanhTensors(RankedTensorType outputTensorType, + std::array<Value, 2> fcArgs, + Value biasValueArg) { + //==========================================================================// + // Layer 1: FC + //==========================================================================// + ValueHandle I(fcArgs[0]), W(fcArgs[1]); + Value O2 = linalg_matmul(I, W, outputTensorType)->getResult(0); + + //==========================================================================// + // Layer 2: BiasAddTanh Block + //==========================================================================// + ValueHandle Bias(biasValueArg); + AffineExpr i, j; + bindDims(&ctx, i, j); + // in-place with explicit bias broadcast + StructuredIndexed o2(O2), bias(Bias), o3Type(outputTensorType); + return linalg_pointwise(fusedBiasTanh, o2({i, j}), bias({j}), o3Type({i, j})) + ->getResult(0); +}
diff --git a/experimental/ModelBuilder/ModelBuilder.h b/experimental/ModelBuilder/ModelBuilder.h index 2b7b718..e9c0a7b7 100644 --- a/experimental/ModelBuilder/ModelBuilder.h +++ b/experimental/ModelBuilder/ModelBuilder.h
@@ -80,6 +80,9 @@ OwningModuleRef &getModuleRef() { return module; } + // Build the MLIR representation for an f32 constant. + static Value constant_f32(float v); + // Build an MLIR FuncOp that will be callable after JIT compilation occured. FuncOp makeFunction(StringRef name, ArrayRef<Type> results = {}, ArrayRef<Type> args = {}, bool declOnly = false); @@ -91,12 +94,30 @@ // per-need basis. MemRefType getMemRefType(ArrayRef<int64_t> shape, Type elementType); - // FCBiasTanh implements: + // Build an MLIR RankedTensorType with a base `elementType` and a `shape` that + // can be any mix of static and dynamic values. For now this only supports a + // dense and contiguous layout. + // In the future, this can be extended to support more advanced layouts, on a + // per-need basis. + RankedTensorType getRankedTensorType(ArrayRef<int64_t> shape, + Type elementType); + + // Build the MLIR representation for: // 1. fc(I, W, O) // 2. pointwise(O, bias) in-place with explicit bias broadcast to compute: // `0.5f * tanh(0.5f * (x + bias)) + 0.5f` // Returns O. + // Version with a MemRef output argument. ValueHandle FCBiasTanh(std::array<Value, 3> fcArgs, Value biasValueArg); + // Version with a RankedTensor result. + Value FCBiasTanhTensors(RankedTensorType outputTensorType, + std::array<Value, 2> fcArgs, Value biasValueArg); + + // Build the MLIR representation for: + // `0.5f * tanh(0.5f * (x + bias)) + 0.5f` + // This assumes `x` and `bias` capture scalar MLIR values of type f32. + // This is used as a region builder when constructing e.g. a pointwise op. + static Value fusedBiasTanh(ValueHandle x, ValueHandle bias); protected: static thread_local MLIRContext ctx;
diff --git a/experimental/ModelBuilder/test/TestMNISTJIT.cpp b/experimental/ModelBuilder/test/TestMNISTJIT.cpp index 8abde92..0b582ba 100644 --- a/experimental/ModelBuilder/test/TestMNISTJIT.cpp +++ b/experimental/ModelBuilder/test/TestMNISTJIT.cpp
@@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: test-mnist-jit | IreeFileCheck %s +// RUN: test-mnist-jit 2>&1 | IreeFileCheck %s #include "experimental/ModelBuilder/MemRefUtils.h" #include "experimental/ModelBuilder/ModelBuilder.h" #include "experimental/ModelBuilder/ModelRunner.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" using namespace mlir; @@ -65,8 +67,8 @@ alloc(modelBuilder.getMemRefType({-1, W2}, f32), batchSize); Value outputBlock3 = func.getArgument(1); - auto zero = constant_float(llvm::APFloat(0.0f), f32); - auto someVal = constant_float(llvm::APFloat(0.1123f), f32); + ValueHandle zero(modelBuilder.constant_f32(0.0f)); + ValueHandle someVal(modelBuilder.constant_f32(0.1123f)); linalg_fill(h1Weights, someVal); linalg_fill(h2Weights, someVal); linalg_fill(h3Weights, someVal); @@ -94,20 +96,81 @@ (ret()); } +// Helper function to build a func `funcName` that takes tensors for the input +// in the form of a `tensor<?x784xf32>` as well as static tensors for all the +// weights and biases. +// +// This is the counterpart of `buildMNIST` which builds a similar model on +// buffers.
+void buildMNISTOnTensors(ModelBuilder &modelBuilder, StringLiteral funcName, + int64_t B, int64_t W0, int64_t W1, int64_t W2, + int64_t W3) { + auto f32 = modelBuilder.f32; + auto inputType = modelBuilder.getRankedTensorType({B, W0}, f32); + auto h1WeightsType = modelBuilder.getRankedTensorType({W0, W1}, f32); + auto h2WeightsType = modelBuilder.getRankedTensorType({W1, W2}, f32); + auto h3WeightsType = modelBuilder.getRankedTensorType({W2, W3}, f32); + auto bias1Type = modelBuilder.getRankedTensorType({W1}, f32); + auto bias2Type = modelBuilder.getRankedTensorType({W2}, f32); + auto bias3Type = modelBuilder.getRankedTensorType({W3}, f32); + auto outputType = modelBuilder.getRankedTensorType({B, W3}, f32); + auto func = modelBuilder.makeFunction( + funcName, {outputType}, + {inputType, h1WeightsType, h2WeightsType, h3WeightsType, bias1Type, + bias2Type, bias3Type}); + Value input = func.getArgument(0); + Value h1Weights = func.getArgument(1); + Value h2Weights = func.getArgument(2); + Value h3Weights = func.getArgument(3); + Value bias1 = func.getArgument(4); + Value bias2 = func.getArgument(5); + Value bias3 = func.getArgument(6); + + // 2. Fill the body (3 blocks of FCBiasTanh), alloc everything manually atm. + OpBuilder b(&func.getBody()); + ScopedContext scope(b, func.getLoc()); + + auto outputBlock1Type = modelBuilder.getRankedTensorType({B, W1}, f32); + auto outputBlock1 = modelBuilder.FCBiasTanhTensors(outputBlock1Type, + {input, h1Weights}, bias1); + auto outputBlock2Type = modelBuilder.getRankedTensorType({B, W2}, f32); + auto outputBlock2 = modelBuilder.FCBiasTanhTensors( + outputBlock2Type, {outputBlock1, h2Weights}, bias2); + auto outputBlock3Type = outputType; + auto outputBlock3 = modelBuilder.FCBiasTanhTensors( + outputBlock3Type, {outputBlock2, h3Weights}, bias3); + // Vexing parses. 
+ (ret(outputBlock3)); +} + int main() { - constexpr StringLiteral kFuncName = "test_mnist_jit"; constexpr unsigned B = 3, W0 = 784, W1 = 256, W2 = 256, W3 = 10; - // 1. Build a func "mnist" that takes a memref<?x784xf32> buffer - // (use batch size M=3 in this example) ModelBuilder modelBuilder; - buildMNIST(modelBuilder, kFuncName, B, W0, W1, W2, W3); + // 1. Build a func "test_mnist_jit_tensors". + constexpr StringLiteral kFuncTensorsName = "test_mnist_jit_tensors"; + buildMNISTOnTensors(modelBuilder, kFuncTensorsName, ShapedType::kDynamicSize, + W0, W1, W2, W3); + // 1.b. Dump the function for testing and erase it: we can't compile it to + // buffers for now. + modelBuilder.getModuleRef()->dump(); + SymbolTable::lookupNearestSymbolFrom( + modelBuilder.getModuleRef()->getOperation(), kFuncTensorsName) + ->erase(); - // 2. Compile the function. + // 2. Build a separate func "test_mnist_jit_buffers" that takes a + // memref<?x784xf32> buffer + // (use batch size M=3 in this example) + // In the future, when we can lower the function built in 1. to buffers we + // will. + constexpr StringLiteral kFuncBuffersName = "test_mnist_jit_buffers"; + buildMNIST(modelBuilder, kFuncBuffersName, B, W0, W1, W2, W3); + + // 3. Compile the function. ModelRunner runner(modelBuilder.getModuleRef()); runner.compile(/*llvmOptLevel=*/3, /*llcOptLevel=*/3); - // 3. Allocate data within data structures that interoperate with the MLIR ABI + // 4. Allocate data within data structures that interoperate with the MLIR ABI // conventions used by codegen. auto inputLinearInit = [](unsigned idx, float *ptr) { *ptr = 0.032460f; }; ManagedUnrankedMemRefDescriptor inputBuffer = @@ -116,17 +179,55 @@ ManagedUnrankedMemRefDescriptor outputBuffer = makeInitializedUnrankedDescriptor<float>({B, W3}, outputLinearInit); - // 4. Call the funcOp name `kFuncName` with arguments. + // 5. Call the funcOp name `kFuncBuffersName` with arguments. 
void *args[2] = {&inputBuffer->descriptor, &outputBuffer->descriptor}; - auto error = - runner.engine->invoke(kFuncName, llvm::MutableArrayRef<void *>{args}); + auto error = runner.engine->invoke(kFuncBuffersName, + llvm::MutableArrayRef<void *>{args}); - // 5. Dump content of output buffer for testing with FileCheck. + // 6. Dump content of output buffer for testing with FileCheck. if (!error) ::impl::printMemRef( *static_cast<StridedMemRefType<float, 2> *>(outputBuffer->descriptor)); } +// For now, we can only dump the IR for `test_mnist_jit_tensors`. +// Once buffer allocation is implemented we will only have an execution test. +// +// CHECK: func @test_mnist_jit_tensors +// +// Matmul +// CHECK: linalg.generic +// CHECK: tensor<?x784xf32>, tensor<784x256xf32> -> tensor<?x256xf32> +// +// Pointwise +// CHECK: linalg.generic +// CHECK: addf +// CHECK: mulf +// CHECK: tanh +// CHECK: mulf +// CHECK: addf +// CHECK: addf +// CHECK: tensor<?x256xf32>, tensor<256xf32> -> tensor<?x256xf32> +// +// Matmul +// CHECK: linalg.generic +// CHECK: tensor<?x256xf32>, tensor<256x256xf32> -> tensor<?x256xf32> +// +// Pointwise +// CHECK: linalg.generic +// CHECK: tensor<?x256xf32>, tensor<256xf32> -> tensor<?x256xf32> +// +// Matmul +// CHECK: linalg.generic +// CHECK: tensor<?x256xf32>, tensor<256x10xf32> -> tensor<?x10xf32> +// +// Pointwise +// CHECK: linalg.generic +// CHECK: tensor<?x10xf32>, tensor<10xf32> -> tensor<?x10xf32> +// CHECK: return {{.*}} : tensor<?x10xf32> + +// Execution test for `test_mnist_jit_buffers`. +// // CHECK: Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 10] // CHECK-SAME: strides = [10, 1] data = // clang-format off