// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "experimental/ModelBuilder/ModelBuilder.h"

#include "mlir/Dialect/Affine/EDSC/Builders.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::ops;
using namespace mlir::edsc::intrinsics;

thread_local MLIRContext mlir::ModelBuilder::ctx;

void ModelBuilder::registerAllDialects() {
  // TODO: Remove; dialects are now loaded on demand in the ModelBuilder
  // constructor below.
}

mlir::ModelBuilder::ModelBuilder()
    : OpBuilder(&ctx),
      module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&ctx))),
      symbolTable(*module),
      loc(module->getLoc()),
      i8(IntegerType::get(8, &ctx)),
      f32(FloatType::getF32(&ctx)),
      f64(FloatType::getF64(&ctx)) {
  ctx.getOrLoadDialect<AffineDialect>();
  ctx.getOrLoadDialect<gpu::GPUDialect>();
  ctx.getOrLoadDialect<LLVM::LLVMDialect>();
  ctx.getOrLoadDialect<linalg::LinalgDialect>();
  ctx.getOrLoadDialect<scf::SCFDialect>();
  ctx.getOrLoadDialect<omp::OpenMPDialect>();
  ctx.getOrLoadDialect<spirv::SPIRVDialect>();
  ctx.getOrLoadDialect<StandardOpsDialect>();
  ctx.getOrLoadDialect<vector::VectorDialect>();
}
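
// Example usage (a minimal sketch; the "add_one" name and the body outline are
// illustrative, not part of this API):
//   ModelBuilder mb;
//   auto memrefTy = mb.getMemRefType({16}, mb.f32);
//   auto func = mb.makeFunction("add_one", /*results=*/{}, /*args=*/{memrefTy},
//                               MLIRFuncOpConfig().setEmitCInterface(true));
//   OpBuilder b(&func.getBody());
//   edsc::ScopedContext scope(b, func.getLoc());
//   // ... emit ops via the EDSC intrinsics, then terminate with std_ret().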

Value mlir::ModelBuilder::constant_f32(float v) {
  return std_constant_float(llvm::APFloat(v),
                            FloatType::getF32(ScopedContext::getContext()));
}

Value mlir::ModelBuilder::constant_f64(double v) {
  return std_constant_float(llvm::APFloat(v),
                            FloatType::getF64(ScopedContext::getContext()));
}

Value mlir::ModelBuilder::constant_index(int64_t v) {
  return std_constant_index(v);
}

FuncOp mlir::ModelBuilder::makeFunction(StringRef name, ArrayRef<Type> results,
                                        ArrayRef<Type> args,
                                        MLIRFuncOpConfig config) {
  FunctionType ft = FunctionType::get(args, results, &ctx);
  auto function = FuncOp::create(loc, name, ft);
  config.apply(function);
  module->push_back(function);
  return function;
}

FuncOp mlir::ModelBuilder::makeFunction(
    std::function<std::string(FunctionType)> nameBuilder,
    ArrayRef<Type> results, ArrayRef<Type> args, MLIRFuncOpConfig config) {
  FunctionType ft = FunctionType::get(args, results, &ctx);
  return makeFunction(nameBuilder(ft), results, args, config);
}
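
// Example (sketch): the nameBuilder overload derives the symbol name from the
// function type; the mangling scheme below is illustrative only.
//   auto func = mb.makeFunction(
//       [](FunctionType ft) {
//         return "kernel_" + std::to_string(ft.getNumInputs()) + "_inputs";
//       },
//       /*results=*/{}, /*args=*/argTypes, MLIRFuncOpConfig());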

static spirv::TargetEnvAttr getTargetEnv(MLIRContext *context) {
  auto triple = spirv::VerCapExtAttr::get(
      spirv::Version::V_1_0,
      {spirv::Capability::Shader, spirv::Capability::CooperativeMatrixNV,
       spirv::Capability::Int8, spirv::Capability::Float16,
       spirv::Capability::StorageUniform16,
       spirv::Capability::StorageBuffer8BitAccess,
       spirv::Capability::Float16Buffer},
      {spirv::Extension::SPV_KHR_storage_buffer_storage_class,
       spirv::Extension::SPV_NV_cooperative_matrix,
       spirv::Extension::SPV_KHR_8bit_storage,
       spirv::Extension::SPV_KHR_16bit_storage},
      context);
  return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
                                   spirv::DeviceType::Unknown,
                                   spirv::TargetEnvAttr::kUnknownDeviceID,
                                   spirv::getDefaultResourceLimits(context));
}

gpu::GPUModuleOp mlir::ModelBuilder::makeGPUModule(StringRef name) {
  // Add the required module-level attributes before creating the GPU module.
  addGPUAttr();
  OpBuilder b(&module->getBodyRegion());
  auto kernelModule = b.create<gpu::GPUModuleOp>(loc, name);
  return kernelModule;
}

void mlir::ModelBuilder::addGPUAttr() {
  // Mark the module as a GPU container module and attach the SPIR-V target
  // environment.
  module->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
                  UnitAttr::get(module->getContext()));
  spirv::TargetEnvAttr targetEnv = getTargetEnv(module->getContext());
  module->setAttr(spirv::getTargetEnvAttrName(), targetEnv);
}

gpu::GPUFuncOp mlir::ModelBuilder::makeGPUKernel(
    StringRef name, gpu::GPUModuleOp GPUModule, ArrayRef<int32_t> workgroupSize,
    ArrayRef<Type> args, ArrayRef<Type> results) {
  auto fnType = FunctionType::get(args, results, module->getContext());
  OpBuilder b(&GPUModule.body());
  auto kernelFunc = b.create<gpu::GPUFuncOp>(loc, name, fnType);
  kernelFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), b.getUnitAttr());
  kernelFunc.setAttr(
      spirv::getEntryPointABIAttrName(),
      spirv::getEntryPointABIAttr(workgroupSize, module->getContext()));
  return kernelFunc;
}
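
// Example (sketch): create a GPU module and a kernel entry point; the names,
// types, and the 32x1x1 workgroup size below are illustrative.
//   auto gpuModule = mb.makeGPUModule("kernels");
//   auto kernel = mb.makeGPUKernel("matmul", gpuModule,
//                                  /*workgroupSize=*/{32, 1, 1},
//                                  /*args=*/{memrefTy, memrefTy, memrefTy},
//                                  /*results=*/{});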

VectorType mlir::ModelBuilder::getVectorType(ArrayRef<int64_t> shape,
                                             Type elementalType) {
  return VectorType::get(shape, elementalType);
}

MemRefType mlir::ModelBuilder::getMemRefType(ArrayRef<int64_t> shape,
                                             Type elementType,
                                             unsigned addressSpace) {
  return MemRefType::get(shape, elementType, {}, addressSpace);
}

RankedTensorType mlir::ModelBuilder::getRankedTensorType(
    ArrayRef<int64_t> shape, Type elementType) {
  return RankedTensorType::get(shape, elementType);
}

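// Emits IR computing `x + 0.5f * tanhf(0.5f * (x + bias)) + 0.5f` on f32
// scalars; `tanhf` is resolved through an external function declaration (see
// call_tanhf below).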
Value mlir::ModelBuilder::fusedBiasTanh(Value x, Value bias) {
  using edsc::op::operator+;
  using edsc::op::operator*;
  using edsc::intrinsics::std_call;
  assert(x.getType().isF32() && bias.getType().isF32() && "f32 expected");
  Value half = constant_f32(0.5f);
  return x + half * call_tanhf((x + bias) * half) + half;
}

Value mlir::ModelBuilder::FCBiasTanh(std::array<Value, 3> fcArgs,
                                     Value biasValueArg) {
  //==========================================================================//
  // Layer 1: FC
  //==========================================================================//
  Value I = fcArgs[0], W = fcArgs[1], O = fcArgs[2];
  // Emit a linalg.generic op that implements matmul:
  linalg_generic_matmul(I, W, O);

  //==========================================================================//
  // Layer 2: BiasAddTanh Block
  //==========================================================================//
  // Build and capture AffineExpr i and j for building index expressions.
  AffineExpr i, j;
  bindDims(&ctx, i, j);
  // Emit a linalg.generic op that implements the pointwise computation
  //   `x + 0.5f * tanh(0.5f * (x + bias)) + 0.5f`
  // with `fusedBiasTanh` as the body builder.
  //
  // This performs the (in-place) computation:
  //   `o[i, j] <- pointwise(bias[j], o[i, j])`
  //
  // in which bias is broadcast along `i`.
  StructuredIndexed o(O), bias(biasValueArg);
  linalg_generic_pointwise(fusedBiasTanh, o({i, j}), bias({j}), o({i, j}));

  return O;
}
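
// For reference, the computation emitted above corresponds to the following
// scalar pseudocode (a sketch of the buffer, in-place semantics):
//   for (i, j, k): O[i, j] += I[i, k] * W[k, j];                  // Layer 1
//   for (i, j):    O[i, j] = fusedBiasTanh(O[i, j], bias[j]);     // Layer 2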

Value ModelBuilder::FCBiasTanhTensors(RankedTensorType outputTensorType,
                                      std::array<Value, 2> fcArgs,
                                      Value fcInitTensor, Value biasValueArg) {
  //==========================================================================//
  // Layer 1: FC
  //==========================================================================//
  Value I = fcArgs[0], W = fcArgs[1];
  Value O2 =
      linalg_generic_matmul(I, W, fcInitTensor, outputTensorType)->getResult(0);

  //==========================================================================//
  // Layer 2: BiasAddTanh Block
  //==========================================================================//
  AffineExpr i, j;
  bindDims(&ctx, i, j);
  // In-place computation with an explicit bias broadcast.
  StructuredIndexed o2(O2), bias(biasValueArg), o3Type(outputTensorType);
  return linalg_generic_pointwise(fusedBiasTanh, o2({i, j}), bias({j}),
                                  o3Type({i, j}))
      ->getResult(0);
}

Value ModelBuilder::call_tanhf(Value v) {
  assert(v.getType().isF32() && "f32 expected");
  return emitCallToRegisteredSymbol("tanhf", v.getType(), v)->getResult(0);
}

void ModelBuilder::call_print_memref_f32(Value v) {
  auto &builder = ScopedContext::getBuilderRef();
  auto loc = builder.getInsertionBlock()
                 ->getParent()
                 ->getParentOfType<FuncOp>()
                 .getLoc();
  auto elementType = v.getType().cast<MemRefType>().getElementType();
  auto unrankedType = UnrankedMemRefType::get(elementType, 0);
  auto castMemRef = builder.create<MemRefCastOp>(loc, v, unrankedType);
  if (elementType.isF32())
    emitCallToRegisteredSymbol("print_memref_f32", {}, {castMemRef});
  else
    llvm_unreachable("Incorrect argument type for print_memref_f32");
}

Operation *ModelBuilder::emitCallToRegisteredSymbol(StringRef functionName,
                                                    ArrayRef<Type> returnTypes,
                                                    ValueRange values) {
  auto &builder = ScopedContext::getBuilderRef();
  auto callerFunc =
      builder.getInsertionBlock()->getParent()->getParentOfType<FuncOp>();
  FuncOp calleeFunc =
      SymbolTable::lookupNearestSymbolFrom<FuncOp>(callerFunc, functionName);
  if (!calleeFunc) {
    OpBuilder::InsertionGuard insertGuard(builder);
    auto module = callerFunc.getParentOfType<ModuleOp>();
    builder.setInsertionPointToStart(module.getBody());
    calleeFunc = builder.create<FuncOp>(
        module.getLoc(), functionName,
        FunctionType::get(SmallVector<Type, 4>(values.getTypes()), returnTypes,
                          builder.getContext()));
    calleeFunc.setPrivate();
  }
  return std_call(calleeFunc, values);
}
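
// Example (sketch): declare (if needed) and call an external function from
// within an active edsc::ScopedContext; the "expf" symbol is illustrative.
//   Value r =
//       mb.emitCallToRegisteredSymbol("expf", v.getType(), v)->getResult(0);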

MLIRFuncOpConfig &MLIRFuncOpConfig::setNoInline(bool v) {
  noInline = v;
  return *this;
}
MLIRFuncOpConfig &MLIRFuncOpConfig::setPreferAvx512(bool v) {
  preferAvx512 = v;
  return *this;
}
MLIRFuncOpConfig &MLIRFuncOpConfig::setTargetCpu(StringRef s) {
  targetCpu = std::string(s);
  return *this;
}
MLIRFuncOpConfig &MLIRFuncOpConfig::setDeclOnly(bool v) {
  declOnly = v;
  return *this;
}
MLIRFuncOpConfig &MLIRFuncOpConfig::setEmitCInterface(bool v) {
  emitCInterface = v;
  return *this;
}

void MLIRFuncOpConfig::apply(FuncOp &f) {
  MLIRContext *ctx = f.getContext();
  SmallVector<Attribute, 8> attrs;
  if (noInline) attrs.push_back(StringAttr::get("noinline", ctx));
  if (preferAvx512)
    attrs.push_back(ArrayAttr::get({StringAttr::get("prefer-vector-width", ctx),
                                    StringAttr::get("512", ctx)},
                                   ctx));
  if (!targetCpu.empty())
    attrs.push_back(ArrayAttr::get(
        {StringAttr::get("target-cpu", ctx), StringAttr::get(targetCpu, ctx)},
        ctx));
  if (!attrs.empty()) f.setAttr("passthrough", ArrayAttr::get(attrs, ctx));

  if (emitCInterface)
    f.setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(ctx));

  if (!declOnly)
    f.addEntryBlock();
  else
    f.setPrivate();
}
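
// Example (sketch): chain the setters to configure how a function is created;
// the "skylake-avx512" target and argument names below are illustrative.
//   auto config = MLIRFuncOpConfig()
//                     .setNoInline(true)
//                     .setPreferAvx512(true)
//                     .setTargetCpu("skylake-avx512")
//                     .setEmitCInterface(true);
//   auto f = mb.makeFunction("kernel_fn", /*results=*/{}, argTypes, config);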

// -----------------------------------------------------------------------------
// EDSC extensions.
// -----------------------------------------------------------------------------
template <typename Lambda>
static SmallVector<Value, 4> valueRangeOperatorImpl(Lambda fun, ValueRange a,
                                                    ValueRange b) {
  SmallVector<Value, 4> res;
  res.reserve(std::min(a.size(), b.size()));
  for (auto it : llvm::zip(a, b))
    res.push_back(fun(std::get<0>(it), std::get<1>(it)));
  return res;
}
SmallVector<Value, 4> mlir::edsc::extensions::operator-(ValueRange a,
                                                        ValueRange b) {
  return valueRangeOperatorImpl(edsc::op::operator-, a, b);
}
SmallVector<Value, 4> mlir::edsc::extensions::operator+(ValueRange a,
                                                        ValueRange b) {
  return valueRangeOperatorImpl(edsc::op::operator+, a, b);
}
SmallVector<Value, 4> mlir::edsc::extensions::std_max(ValueRange a,
                                                      ValueRange b) {
  using edsc::op::slt;
  // Select the larger SSA value with a std.select; a C++ ternary on `Value`
  // would test handle validity at build time, not the runtime comparison.
  auto fun = [](Value va, Value vb) {
    return std_select(slt(va, vb), vb, va);
  };
  return valueRangeOperatorImpl(fun, a, b);
}
SmallVector<Value, 4> mlir::edsc::extensions::std_min(ValueRange a,
                                                      ValueRange b) {
  using edsc::op::slt;
  auto fun = [](Value va, Value vb) {
    return std_select(slt(va, vb), va, vb);
  };
  return valueRangeOperatorImpl(fun, a, b);
}
SmallVector<Value, 4> mlir::edsc::extensions::affine_max(ValueRange a,
                                                         ValueRange b) {
  // TODO(ntv): clean up when affine_max has more idiomatic builders.
  MLIRContext *ctx = ScopedContext::getContext();
  auto map = AffineMap::get(
      2, 0, {getAffineDimExpr(0, ctx), getAffineDimExpr(1, ctx)}, ctx);
  auto fun = [&](Value va, Value vb) {
    return intrinsics::affine_max(map, ValueRange{va, vb});
  };
  return valueRangeOperatorImpl(fun, a, b);
}
SmallVector<Value, 4> mlir::edsc::extensions::affine_min(ValueRange a,
                                                         ValueRange b) {
  // TODO(ntv): clean up when affine_min has more idiomatic builders.
  MLIRContext *ctx = ScopedContext::getContext();
  auto map = AffineMap::get(
      2, 0, {getAffineDimExpr(0, ctx), getAffineDimExpr(1, ctx)}, ctx);
  auto fun = [&](Value va, Value vb) {
    return intrinsics::affine_min(map, ValueRange{va, vb});
  };
  return valueRangeOperatorImpl(fun, a, b);
}
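
// Example (sketch): combine two ranges of index Values element by element;
// the variable names below are illustrative.
//   using namespace mlir::edsc::extensions;
//   SmallVector<Value, 4> shifted = ivs + offsets;        // per-element add
//   SmallVector<Value, 4> clamped = std_min(shifted, upperBounds);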