// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "experimental/ModelBuilder/ModelBuilder.h"
#include "mlir/Dialect/Affine/EDSC/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::ops;
using namespace mlir::edsc::intrinsics;

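// One MLIRContext per thread: ModelBuilder instances created on different
// threads get their own context and thus never share IR objects.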
thread_local MLIRContext mlir::ModelBuilder::ctx;
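
// Creates an empty module at an unknown location, a symbol table rooted at
// that module, and caches a few frequently used element types.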
mlir::ModelBuilder::ModelBuilder()
    : OpBuilder(&ctx),
      module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx))),
      symbolTable(*module),
      loc(module->getLoc()),
      i8(IntegerType::get(8, &ctx)),
      f32(FloatType::getF32(&ctx)),
      f64(FloatType::getF64(&ctx)) {}

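// Note: the EDSC `std_constant_float` helpers below pick up their insertion
// point and location from the enclosing ScopedContext, so both constant
// builders must be called while such a scope is active.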
Value mlir::ModelBuilder::constant_f32(float v) {
  return std_constant_float(llvm::APFloat(v),
                            FloatType::getF32(ScopedContext::getContext()));
}

Value mlir::ModelBuilder::constant_f64(double v) {
  return std_constant_float(llvm::APFloat(v),
                            FloatType::getF64(ScopedContext::getContext()));
}

FuncOp mlir::ModelBuilder::makeFunction(StringRef name, ArrayRef<Type> results,
                                        ArrayRef<Type> args,
                                        bool emitCInterface, bool declOnly) {
  auto function =
      FuncOp::create(loc, name, FunctionType::get(args, results, &ctx));
  if (!declOnly) function.addEntryBlock();
  if (emitCInterface)
    function.setAttr("llvm.emit_c_interface",
                     mlir::UnitAttr::get(getContext()));
  module->push_back(function);
  return function;
}

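// A minimal usage sketch (all names below are hypothetical and assume the
// EDSC ScopedContext API of this MLIR vintage):
//   ModelBuilder mb;
//   FuncOp f = mb.makeFunction(
//       "forward", {}, {mb.getMemRefType({4, 8}, mb.f32)});
//   OpBuilder b(f.getBody());
//   edsc::ScopedContext scope(b, f.getLoc());
//   // ... emit ops via the helpers below, then terminate with std_ret().
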
VectorType mlir::ModelBuilder::getVectorType(ArrayRef<int64_t> shape,
                                             Type elementalType) {
  return VectorType::get(shape, elementalType);
}

MemRefType mlir::ModelBuilder::getMemRefType(ArrayRef<int64_t> shape,
                                             Type elementType,
                                             unsigned addressSpace) {
  return MemRefType::get(shape, elementType, {}, addressSpace);
}

RankedTensorType mlir::ModelBuilder::getRankedTensorType(
    ArrayRef<int64_t> shape, Type elementType) {
  return RankedTensorType::get(shape, elementType);
}

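// Emits the scalar computation `x + 0.5 * tanh(0.5 * (x + bias)) + 0.5`.
// Since `sigmoid(z) == 0.5 * tanh(0.5 * z) + 0.5`, this is exactly
// `x + sigmoid(x + bias)` expressed through the C library's tanhf.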
Value mlir::ModelBuilder::fusedBiasTanh(ValueHandle x, ValueHandle bias) {
  using edsc::op::operator+;
  using edsc::op::operator*;
  using edsc::intrinsics::std_call;
  assert(x.getType().isF32() && bias.getType().isF32() && "f32 expected");
  ValueHandle half(constant_f32(0.5f));
  return x + half * call_tanhf((x + bias) * half) + half;
}

ValueHandle mlir::ModelBuilder::FCBiasTanh(std::array<Value, 3> fcArgs,
                                           Value biasValueArg) {
  //==========================================================================//
  // Layer 1: FC
  //==========================================================================//
  ValueHandle I(fcArgs[0]), W(fcArgs[1]), O(fcArgs[2]);
  // Emit a linalg.generic op that implements matmul:
  linalg_generic_matmul(I, W, O);

  //==========================================================================//
  // Layer 2: BiasAddTanh Block
  //==========================================================================//
  // Build and capture AffineExprs i and j for building index expressions.
  AffineExpr i, j;
  bindDims(&ctx, i, j);
  // Emit a linalg.generic pointwise op that applies `fusedBiasTanh`, i.e.
  //   `x + 0.5f * tanh(0.5f * (x + bias)) + 0.5f`
  //
  // This performs the (inplace) computation:
  //   `o[i, j] <- pointwise(bias[j], o[i, j])`
  //
  // in which bias is broadcast along `i`.
  StructuredIndexed o(O), bias(biasValueArg);
  linalg_generic_pointwise(fusedBiasTanh, o({i, j}), bias({j}), o({i, j}));
  return O;
}

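// Tensor variant of the same two layers: instead of updating a preallocated
// output buffer in place, each linalg.generic op returns a fresh tensor
// value of `outputTensorType`.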
Value ModelBuilder::FCBiasTanhTensors(RankedTensorType outputTensorType,
                                      std::array<Value, 2> fcArgs,
                                      Value biasValueArg) {
  //==========================================================================//
  // Layer 1: FC
  //==========================================================================//
  ValueHandle I(fcArgs[0]), W(fcArgs[1]);
  Value O2 = linalg_generic_matmul(I, W, outputTensorType)->getResult(0);

  //==========================================================================//
  // Layer 2: BiasAddTanh Block
  //==========================================================================//
  ValueHandle Bias(biasValueArg);
  AffineExpr i, j;
  bindDims(&ctx, i, j);
  // Pointwise op with explicit broadcast of bias along `i`.
  StructuredIndexed o2(O2), bias(Bias), o3Type(outputTensorType);
  return linalg_generic_pointwise(fusedBiasTanh, o2({i, j}), bias({j}),
                                  o3Type({i, j}))
      ->getResult(0);
}

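// Builds a call to the C math library's `tanhf`; the callee is declared in
// the module on first use and must be resolved at JIT/link time.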
ValueHandle ModelBuilder::call_tanhf(Value v) {
  assert(v.getType().isF32() && "f32 expected");
  return ValueHandle(
      emitCallToRegisteredSymbol("tanhf", v.getType(), v)->getResult(0));
}

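// Lookup-or-declare: find `functionName` in the nearest symbol table; if it
// is not there yet, insert a bare declaration at the start of the module,
// then emit a std.call to it.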
Operation *ModelBuilder::emitCallToRegisteredSymbol(StringRef functionName,
                                                    ArrayRef<Type> returnTypes,
                                                    ValueRange values) {
  auto &builder = ScopedContext::getBuilder();
  auto funcOp =
      builder.getInsertionBlock()->getParent()->getParentOfType<FuncOp>();
  Operation *func = SymbolTable::lookupNearestSymbolFrom(funcOp, functionName);
  if (!func) {
    OpBuilder::InsertionGuard insertGuard(builder);
    auto module = funcOp.getParentOfType<ModuleOp>();
    builder.setInsertionPointToStart(module.getBody());
    func = builder.create<FuncOp>(
        module.getLoc(), functionName,
        FunctionType::get(SmallVector<Type, 4>(values.getTypes()), returnTypes,
                          builder.getContext()),
        ArrayRef<NamedAttribute>{});
  }
  return std_call(builder.getSymbolRefAttr(func), returnTypes, values)
      .getOperation();
}