| // Copyright 2020 Google LLC | 
 | // | 
 | // Licensed under the Apache License, Version 2.0 (the "License"); | 
 | // you may not use this file except in compliance with the License. | 
 | // You may obtain a copy of the License at | 
 | // | 
 | //      https://www.apache.org/licenses/LICENSE-2.0 | 
 | // | 
 | // Unless required by applicable law or agreed to in writing, software | 
 | // distributed under the License is distributed on an "AS IS" BASIS, | 
 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
 | // See the License for the specific language governing permissions and | 
 | // limitations under the License. | 
 |  | 
 | #include "experimental/ModelBuilder/ModelBuilder.h" | 
 |  | 
 | #include "mlir/Dialect/Affine/EDSC/Builders.h" | 
 | #include "mlir/Dialect/Affine/IR/AffineOps.h" | 
 | #include "mlir/Dialect/GPU/GPUDialect.h" | 
 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" | 
 | #include "mlir/Dialect/Linalg/IR/LinalgOps.h" | 
 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" | 
 | #include "mlir/Dialect/SCF/SCF.h" | 
 | #include "mlir/Dialect/SPIRV/SPIRVDialect.h" | 
 | #include "mlir/Dialect/SPIRV/TargetAndABI.h" | 
 | #include "mlir/Dialect/Shape/IR/Shape.h" | 
 | #include "mlir/Dialect/StandardOps/IR/Ops.h" | 
 | #include "mlir/Dialect/Vector/VectorOps.h" | 
 | #include "mlir/IR/AffineExpr.h" | 
 | #include "mlir/IR/Dialect.h" | 
 | #include "mlir/IR/StandardTypes.h" | 
 | #include "mlir/IR/TypeUtilities.h" | 
 |  | 
 | using namespace mlir; | 
 | using namespace mlir::edsc; | 
 | using namespace mlir::edsc::ops; | 
 | using namespace mlir::edsc::intrinsics; | 
 |  | 
// One MLIRContext per thread, shared by all ModelBuilder instances on that
// thread (the constructor below loads dialects into it idempotently).
thread_local MLIRContext mlir::ModelBuilder::ctx;
 |  | 
// Dialect registration now happens lazily in the ModelBuilder constructor via
// MLIRContext::getOrLoadDialect, so there is nothing left to do here.
void ModelBuilder::registerAllDialects() {
  // TODO: remove.
}
 |  | 
// Creates an empty ModuleOp at an unknown location, caches frequently used
// scalar types, and loads every dialect the builder helpers below may emit
// ops from. Uses the thread_local `ctx`, so each thread builds into its own
// context.
mlir::ModelBuilder::ModelBuilder()
    : OpBuilder(&ctx),
      module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx))),
      symbolTable(*module),
      loc(module->getLoc()),
      i8(IntegerType::get(8, &ctx)),
      f32(FloatType::getF32(&ctx)),
      f64(FloatType::getF64(&ctx)) {
  // getOrLoadDialect is idempotent: constructing several ModelBuilders on the
  // same thread is safe.
  ctx.getOrLoadDialect<AffineDialect>();
  ctx.getOrLoadDialect<gpu::GPUDialect>();
  ctx.getOrLoadDialect<LLVM::LLVMDialect>();
  ctx.getOrLoadDialect<linalg::LinalgDialect>();
  ctx.getOrLoadDialect<scf::SCFDialect>();
  ctx.getOrLoadDialect<omp::OpenMPDialect>();
  ctx.getOrLoadDialect<spirv::SPIRVDialect>();
  ctx.getOrLoadDialect<StandardOpsDialect>();
  ctx.getOrLoadDialect<vector::VectorDialect>();
}
 |  | 
 | Value mlir::ModelBuilder::constant_f32(float v) { | 
 |   return std_constant_float(llvm::APFloat(v), | 
 |                             FloatType::getF32(ScopedContext::getContext())); | 
 | } | 
 |  | 
 | Value mlir::ModelBuilder::constant_f64(double v) { | 
 |   return std_constant_float(llvm::APFloat(v), | 
 |                             FloatType::getF64(ScopedContext::getContext())); | 
 | } | 
 |  | 
 | Value mlir::ModelBuilder::constant_index(int64_t v) { | 
 |   return std_constant_index(v); | 
 | } | 
 |  | 
 | FuncOp mlir::ModelBuilder::makeFunction(StringRef name, ArrayRef<Type> results, | 
 |                                         ArrayRef<Type> args, | 
 |                                         MLIRFuncOpConfig config) { | 
 |   FunctionType ft = FunctionType::get(args, results, &ctx); | 
 |   auto function = FuncOp::create(loc, name, ft); | 
 |   config.apply(function); | 
 |   module->push_back(function); | 
 |   return function; | 
 | } | 
 | FuncOp mlir::ModelBuilder::makeFunction( | 
 |     std::function<std::string(FunctionType)> nameBuilder, | 
 |     ArrayRef<Type> results, ArrayRef<Type> args, MLIRFuncOpConfig config) { | 
 |   FunctionType ft = FunctionType::get(args, results, &ctx); | 
 |   return makeFunction(nameBuilder(ft), results, args, config); | 
 | } | 
 |  | 
// Returns the spirv::TargetEnvAttr attached to GPU modules built here:
// SPIR-V 1.0 with Shader, NV cooperative-matrix, and 8/16-bit storage
// capabilities (plus matching extensions), an unknown vendor/device, and the
// default resource limits for `context`.
static spirv::TargetEnvAttr getTargetEnv(MLIRContext *context) {
  auto triple = spirv::VerCapExtAttr::get(
      spirv::Version::V_1_0,
      {spirv::Capability::Shader, spirv::Capability::CooperativeMatrixNV,
       spirv::Capability::Int8, spirv::Capability::Float16,
       spirv::Capability::StorageUniform16,
       spirv::Capability::StorageBuffer8BitAccess,
       spirv::Capability::Float16Buffer},
      {spirv::Extension::SPV_KHR_storage_buffer_storage_class,
       spirv::Extension::SPV_NV_cooperative_matrix,
       spirv::Extension::SPV_KHR_8bit_storage,
       spirv::Extension::SPV_KHR_16bit_storage},
      context);
  return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
                                   spirv::DeviceType::Unknown,
                                   spirv::TargetEnvAttr::kUnknownDeviceID,
                                   spirv::getDefaultResourceLimits(context));
}
 |  | 
 | gpu::GPUModuleOp mlir::ModelBuilder::makeGPUModule(StringRef name) { | 
 |   // Add module attributes required first. | 
 |   addGPUAttr(); | 
 |   OpBuilder b(&module->getBodyRegion()); | 
 |   auto kernelModule = b.create<gpu::GPUModuleOp>(loc, name); | 
 |   return kernelModule; | 
 | } | 
 |  | 
 | void mlir::ModelBuilder::addGPUAttr() { | 
 |   // Add module attributes required first. | 
 |   module->setAttr(gpu::GPUDialect::getContainerModuleAttrName(), | 
 |                   UnitAttr::get(module->getContext())); | 
 |   spirv::TargetEnvAttr targetEnv = getTargetEnv(module->getContext()); | 
 |   module->setAttr(spirv::getTargetEnvAttrName(), targetEnv); | 
 | } | 
 |  | 
 | gpu::GPUFuncOp mlir::ModelBuilder::makeGPUKernel( | 
 |     StringRef name, gpu::GPUModuleOp GPUModule, ArrayRef<int32_t> workgroupSize, | 
 |     ArrayRef<Type> args, ArrayRef<Type> results) { | 
 |   auto fnType = FunctionType::get(args, results, module->getContext()); | 
 |   OpBuilder b(&GPUModule.body()); | 
 |   auto kernelFunc = b.create<gpu::GPUFuncOp>(loc, name, fnType); | 
 |   kernelFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), b.getUnitAttr()); | 
 |   kernelFunc.setAttr( | 
 |       spirv::getEntryPointABIAttrName(), | 
 |       spirv::getEntryPointABIAttr(workgroupSize, module->getContext())); | 
 |   return kernelFunc; | 
 | } | 
 |  | 
 | VectorType mlir::ModelBuilder::getVectorType(ArrayRef<int64_t> shape, | 
 |                                              Type elementalType) { | 
 |   return VectorType::get(shape, elementalType); | 
 | } | 
 |  | 
 | MemRefType mlir::ModelBuilder::getMemRefType(ArrayRef<int64_t> shape, | 
 |                                              Type elementType, | 
 |                                              unsigned addressSpace) { | 
 |   return MemRefType::get(shape, elementType, {}, addressSpace); | 
 | } | 
 |  | 
 | RankedTensorType mlir::ModelBuilder::getRankedTensorType( | 
 |     ArrayRef<int64_t> shape, Type elementType) { | 
 |   return RankedTensorType::get(shape, elementType); | 
 | } | 
 |  | 
 | Value mlir::ModelBuilder::fusedBiasTanh(Value x, Value bias) { | 
 |   using edsc::op::operator+; | 
 |   using edsc::op::operator*; | 
 |   using edsc::intrinsics::std_call; | 
 |   assert(x.getType().isF32() && bias.getType().isF32() && "f32 expected"); | 
 |   Value half = constant_f32(0.5f); | 
 |   return x + half * call_tanhf((x + bias) * half) + half; | 
 | } | 
 |  | 
// Emits a fully-connected layer (matmul of I and W accumulated into O)
// followed by an in-place pointwise bias-add-tanh over O, and returns O.
Value mlir::ModelBuilder::FCBiasTanh(std::array<Value, 3> fcArgs,
                                     Value biasValueArg) {
  //==========================================================================//
  // Layer 1: FC
  //==========================================================================//
  Value I = fcArgs[0], W = fcArgs[1], O = fcArgs[2];
  // Emit a linalg.generic op that implements matmul:
  linalg_generic_matmul(I, W, O);

  //==========================================================================//
  // Layer 2: BiasAddTanh Block
  //==========================================================================//
  // Build and capture AffineExpr i and j for building index expressions.
  AffineExpr i, j;
  bindDims(&ctx, i, j);

  // Emit a linalg.generic op whose pointwise body is `fusedBiasTanh`, i.e.:
  //   `x + 0.5f * tanh(0.5f * (x + bias)) + 0.5f`
  //
  // This performs the (inplace) computation:
  //   `o[i, j] <- pointwise(bias[j], o[i, j])`
  //
  // in which bias is broadcast along `i`.
  StructuredIndexed o(O), bias(biasValueArg);
  linalg_generic_pointwise(fusedBiasTanh, o({i, j}), bias({j}), o({i, j}));

  return O;
}
 |  | 
// Tensor-semantics variant of FCBiasTanh: instead of updating a buffer in
// place, the matmul seeds from `fcInitTensor` and each layer yields a new
// tensor of `outputTensorType`. Returns the final bias-add-tanh result.
Value ModelBuilder::FCBiasTanhTensors(RankedTensorType outputTensorType,
                                      std::array<Value, 2> fcArgs,
                                      Value fcInitTensor, Value biasValueArg) {
  //==========================================================================//
  // Layer 1: FC
  //==========================================================================//
  Value I = fcArgs[0], W = fcArgs[1];
  Value O2 =
      linalg_generic_matmul(I, W, fcInitTensor, outputTensorType)->getResult(0);

  //==========================================================================//
  // Layer 2: BiasAddTanh Block
  //==========================================================================//
  AffineExpr i, j;
  bindDims(&ctx, i, j);
  // Pointwise bias-add-tanh (fusedBiasTanh) with explicit bias broadcast
  // along `i`, producing a fresh tensor of `outputTensorType`.
  StructuredIndexed o2(O2), bias(biasValueArg), o3Type(outputTensorType);
  return linalg_generic_pointwise(fusedBiasTanh, o2({i, j}), bias({j}),
                                  o3Type({i, j}))
      ->getResult(0);
}
 |  | 
 | Value ModelBuilder::call_tanhf(Value v) { | 
 |   assert(v.getType().isF32() && "f32 expected"); | 
 |   return emitCallToRegisteredSymbol("tanhf", v.getType(), v)->getResult(0); | 
 | } | 
 |  | 
 | void ModelBuilder::call_print_memref_f32(Value v) { | 
 |   auto &builder = ScopedContext::getBuilderRef(); | 
 |   auto loc = builder.getInsertionBlock() | 
 |                  ->getParent() | 
 |                  ->getParentOfType<FuncOp>() | 
 |                  .getLoc(); | 
 |   auto elementType = v.getType().cast<MemRefType>().getElementType(); | 
 |   auto unrankedType = UnrankedMemRefType::get(elementType, 0); | 
 |   auto castMemRef = builder.create<MemRefCastOp>(loc, v, unrankedType); | 
 |   if (elementType.isF32()) | 
 |     emitCallToRegisteredSymbol("print_memref_f32", {}, {castMemRef}); | 
 |   else | 
 |     llvm_unreachable("Incorrect argument type for print_memref_f32"); | 
 | } | 
 |  | 
// Emits a std.call to `functionName` with `values` as operands. If no symbol
// with that name is visible from the calling function, a private declaration
// with a FunctionType deduced from `values`/`returnTypes` is inserted at the
// start of the enclosing module first.
Operation *ModelBuilder::emitCallToRegisteredSymbol(StringRef functionName,
                                                    ArrayRef<Type> returnTypes,
                                                    ValueRange values) {
  auto &builder = ScopedContext::getBuilderRef();
  auto callerFunc =
      builder.getInsertionBlock()->getParent()->getParentOfType<FuncOp>();
  FuncOp calleeFunc =
      SymbolTable::lookupNearestSymbolFrom<FuncOp>(callerFunc, functionName);
  if (!calleeFunc) {
    // Declare the callee at module scope; the guard restores the insertion
    // point so the call itself is still emitted at the original position.
    OpBuilder::InsertionGuard insertGuard(builder);
    auto module = callerFunc.getParentOfType<ModuleOp>();
    builder.setInsertionPointToStart(module.getBody());
    calleeFunc = builder.create<FuncOp>(
        module.getLoc(), functionName,
        FunctionType::get(SmallVector<Type, 4>(values.getTypes()), returnTypes,
                          builder.getContext()));
    calleeFunc.setPrivate();
  }
  return std_call(calleeFunc, values);
}
 |  | 
// Chainable configuration setters; each returns *this to allow fluent use,
// e.g. MLIRFuncOpConfig().setNoInline(true).setEmitCInterface(true).
// The flags are consumed by MLIRFuncOpConfig::apply below.

// Adds the LLVM "noinline" passthrough attribute when applied.
MLIRFuncOpConfig &MLIRFuncOpConfig::setNoInline(bool v) {
  noInline = v;
  return *this;
}
// Adds a "prefer-vector-width"="512" passthrough attribute when applied.
MLIRFuncOpConfig &MLIRFuncOpConfig::setPreferAvx512(bool v) {
  preferAvx512 = v;
  return *this;
}
// Adds a "target-cpu"=<s> passthrough attribute when applied (if non-empty).
MLIRFuncOpConfig &MLIRFuncOpConfig::setTargetCpu(StringRef s) {
  targetCpu = std::string(s);
  return *this;
}
// When true, the function is emitted as a private declaration (no body).
MLIRFuncOpConfig &MLIRFuncOpConfig::setDeclOnly(bool v) {
  declOnly = v;
  return *this;
}
// When true, attaches "llvm.emit_c_interface" to the function.
MLIRFuncOpConfig &MLIRFuncOpConfig::setEmitCInterface(bool v) {
  emitCInterface = v;
  return *this;
}
 |  | 
 | void MLIRFuncOpConfig::apply(FuncOp &f) { | 
 |   MLIRContext *ctx = f.getContext(); | 
 |   SmallVector<Attribute, 8> attrs; | 
 |   if (noInline) attrs.push_back(StringAttr::get("noinline", ctx)); | 
 |   if (preferAvx512) | 
 |     attrs.push_back(ArrayAttr::get({StringAttr::get("prefer-vector-width", ctx), | 
 |                                     StringAttr::get("512", ctx)}, | 
 |                                    ctx)); | 
 |   if (!targetCpu.empty()) | 
 |     attrs.push_back(ArrayAttr::get( | 
 |         {StringAttr::get("target-cpu", ctx), StringAttr::get(targetCpu, ctx)}, | 
 |         ctx)); | 
 |   if (!attrs.empty()) f.setAttr("passthrough", ArrayAttr::get(attrs, ctx)); | 
 |  | 
 |   if (emitCInterface) | 
 |     f.setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(ctx)); | 
 |  | 
 |   if (!declOnly) | 
 |     f.addEntryBlock(); | 
 |   else | 
 |     f.setPrivate(); | 
 | } | 
 |  | 
 | // ----------------------------------------------------------------------------- | 
 | // EDSC extensions. | 
 | // ----------------------------------------------------------------------------- | 
 | template <typename Lambda> | 
 | static SmallVector<Value, 4> valueRangeOperatorImpl(Lambda fun, ValueRange a, | 
 |                                                     ValueRange b) { | 
 |   SmallVector<Value, 4> res; | 
 |   res.reserve(std::min(a.size(), b.size())); | 
 |   for (auto it : llvm::zip(a, b)) | 
 |     res.push_back(fun(std::get<0>(it), std::get<1>(it))); | 
 |   return res; | 
 | } | 
 | SmallVector<Value, 4> mlir::edsc::extensions::operator-(ValueRange a, | 
 |                                                         ValueRange b) { | 
 |   return valueRangeOperatorImpl(edsc::op::operator-, a, b); | 
 | } | 
 | SmallVector<Value, 4> mlir::edsc::extensions::operator+(ValueRange a, | 
 |                                                         ValueRange b) { | 
 |   return valueRangeOperatorImpl(edsc::op::operator+, a, b); | 
 | } | 
 | SmallVector<Value, 4> mlir::edsc::extensions::std_max(ValueRange a, | 
 |                                                       ValueRange b) { | 
 |   using edsc::op::slt; | 
 |   auto fun = [](Value va, Value vb) { return slt(va, vb) ? vb : va; }; | 
 |   return valueRangeOperatorImpl(fun, a, b); | 
 | } | 
 | SmallVector<Value, 4> mlir::edsc::extensions::std_min(ValueRange a, | 
 |                                                       ValueRange b) { | 
 |   using edsc::op::slt; | 
 |   auto fun = [](Value va, Value vb) { return slt(va, vb) ? va : vb; }; | 
 |   return valueRangeOperatorImpl(fun, a, b); | 
 | } | 
 | SmallVector<Value, 4> mlir::edsc::extensions::affine_max(ValueRange a, | 
 |                                                          ValueRange b) { | 
 |   // TODO(ntv): cleanup when affine_max accepts has more idiomatic builders. | 
 |   MLIRContext *ctx = ScopedContext::getContext(); | 
 |   auto map = AffineMap::get( | 
 |       2, 0, {getAffineDimExpr(0, ctx), getAffineDimExpr(1, ctx)}, ctx); | 
 |   auto fun = [&](Value va, Value vb) { | 
 |     return intrinsics::affine_max(map, ValueRange{va, vb}); | 
 |   }; | 
 |   return valueRangeOperatorImpl(fun, a, b); | 
 | } | 
 | SmallVector<Value, 4> mlir::edsc::extensions::affine_min(ValueRange a, | 
 |                                                          ValueRange b) { | 
 |   // TODO(ntv): cleanup when affine_min accepts has more idiomatic builders. | 
 |   MLIRContext *ctx = ScopedContext::getContext(); | 
 |   auto map = AffineMap::get( | 
 |       2, 0, {getAffineDimExpr(0, ctx), getAffineDimExpr(1, ctx)}, ctx); | 
 |   auto fun = [&](Value va, Value vb) { | 
 |     return intrinsics::affine_min(map, ValueRange{va, vb}); | 
 |   }; | 
 |   return valueRangeOperatorImpl(fun, a, b); | 
 | } |