blob: 6389719bd5220260b96ffbeb414cf2fcd8cc1677 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "experimental/ModelBuilder/ModelBuilder.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::ops;
using namespace mlir::edsc::intrinsics;
// Static, per-thread MLIRContext: every ModelBuilder constructed on the same
// thread builds its IR in this one context.
thread_local MLIRContext mlir::ModelBuilder::ctx;

// Creates an empty module at an unknown location, wraps it in a symbol table,
// and caches the `i8` and `f32` types for later use.
// The initializers have data dependencies: `symbolTable` and `loc` both read
// `module`. They are listed in that dependency order.
// NOTE(review): C++ initializes members in declaration order, not list order;
// this assumes the header declares them in the same order — confirm.
mlir::ModelBuilder::ModelBuilder()
    : OpBuilder(&ctx),
      module(ModuleOp::create(UnknownLoc::get(&ctx))),
      symbolTable(*module),
      loc(module->getLoc()),
      i8(IntegerType::get(/*width=*/8, &ctx)),
      f32(FloatType::getF32(&ctx)) {}
// Creates a function named `name` with signature `(args) -> (results)` and
// appends it to this builder's module.
// When `declOnly` is false an entry block is added so a body can be emitted;
// when it is true the function is left without a body (a declaration).
// Returns the newly created function.
FuncOp mlir::ModelBuilder::makeFunction(StringRef name, ArrayRef<Type> results,
                                        ArrayRef<Type> args, bool declOnly) {
  FunctionType signature = FunctionType::get(args, results, &ctx);
  FuncOp fn = FuncOp::create(loc, name, signature);
  const bool wantsBody = !declOnly;
  if (wantsBody) {
    fn.addEntryBlock();
  }
  module->push_back(fn);
  return fn;
}
// Returns the memref type with the given `shape` and `elementType`.
// An empty affine-map composition is passed, so no explicit layout map is
// attached to the type.
MemRefType mlir::ModelBuilder::getMemRefType(ArrayRef<int64_t> shape,
                                             Type elementType) {
  MemRefType type =
      MemRefType::get(shape, elementType, /*affineMapComposition=*/{});
  return type;
}
// Emits a fully-connected layer followed by a fused bias-add + activation,
// writing the final result in place into `fcArgs[2]`.
//
// fcArgs:       {I, W, O}. Emits `linalg_matmul(I, W, O)`; O is then indexed
//               as O[i, j], so it must be rank 2.
//               NOTE(review): presumably I is MxK, W is KxN and O is MxN, and
//               the matmul accumulates into O (so O must be pre-initialized
//               by the caller) — confirm against linalg_matmul.
// biasValueArg: rank-1 bias indexed by the column index `j` and broadcast
//               along the row index `i`.
// Returns a handle to O, which holds
//   O[i, j] = 0.5f * tanh(0.5f * (O[i, j] + Bias[j])) + 0.5f
// after the pointwise op.
ValueHandle mlir::ModelBuilder::FCBiasTanh(std::array<Value, 3> fcArgs,
                                           Value biasValueArg) {
  //==========================================================================//
  // Layer 1: FC
  //==========================================================================//
  ValueHandle I(fcArgs[0]), W(fcArgs[1]), O(fcArgs[2]);
  // Emit the matmul of I and W into O.
  linalg_matmul(I, W, O);
  //==========================================================================//
  // Layer 2: BiasAddTanh Block
  //==========================================================================//
  // Build and capture AffineExpr i and j for building index expressions.
  AffineExpr i, j;
  bindDims(&ctx, i, j);
  // Define the pointwise computation:
  //   `0.5f * tanh(0.5f * (x + bias)) + 0.5f`
  // This is the tanh formulation of sigmoid(x + bias).
  // It assumes each ValueHandle captures an MLIR Value of type `f32`, because
  // the constant below is built as `f32`.
  auto opBuilder = [this](const ValueHandle &x,
                          const ValueHandle &bias) -> Value {
    using edsc::op::operator+;
    using edsc::op::operator*;
    using edsc::intrinsics::tanh;
    auto half = constant_float(llvm::APFloat(0.5f), f32);
    // Exactly the documented formula: no extra `x` term is added to the
    // activation output.
    return half * tanh((x + bias) * half) + half;
  };
  // Emit a linalg pointwise op that applies `opBuilder` to perform the
  // in-place computation:
  //   `SO[i, j] <- pointwise(SO[i, j], SBias[j])`
  // in which SBias is broadcast along `i`. The first input operand binds to
  // `x` and the second to `bias`.
  ValueHandle Bias(biasValueArg);
  StructuredIndexed SO(O), SBias(Bias);
  linalg_pointwise(opBuilder, SO({i, j}), SBias({j}), SO({i, j}));
  return O;
}