blob: a8590ac4ddd1c5cb8f18dcf62694dde62c72172f [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// clang-format off
// NOLINTNEXTLINE
// RUN: test-matmul-vulkan -vulkan-wrapper=$(dirname %s)/../../../../llvm/llvm-project/mlir/tools/libvulkan-runtime-wrappers.so 2>&1 | IreeFileCheck %s
// NOLINTNEXTLINE
// RUN: test-matmul-vulkan -vulkan-wrapper=$(dirname %s)/../../../../llvm/llvm-project/mlir/tools/libvulkan-runtime-wrappers.so -use-workgroup-memory -workgroup-size=2,2 -tile-sizes=2,2 2>&1 | IreeFileCheck %s
// clang-format on
#include <string>
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "experimental/ModelBuilder/ModelBuilder.h"
#include "experimental/ModelBuilder/ModelRunner.h"
#include "experimental/ModelBuilder/VulkanWrapperPass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser.h"
#include "iree/base/initializer.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Pass/PassManager.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
static llvm::cl::opt<std::string> vulkanWrapper(
"vulkan-wrapper", llvm::cl::desc("Vulkan wrapper library"),
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
static llvm::cl::opt<bool> useWorkgroupMemory(
"use-workgroup-memory", llvm::cl::desc("Enable use of workgroup memory"),
llvm::cl::value_desc("boolean"), llvm::cl::init(false));
static llvm::cl::list<int> workgroupSize(
"workgroup-size", llvm::cl::desc("Workgroup size to use"),
llvm::cl::CommaSeparated);
static llvm::cl::list<int> tileSizes("tile-sizes",
llvm::cl::desc("Tile sizes to use"),
llvm::cl::CommaSeparated);
using namespace mlir; // NOLINT
using namespace mlir::edsc; // NOLINT
using namespace mlir::edsc::intrinsics; // NOLINT
void testMatMul() {
const int height = 4;
const int width = 4;
StringLiteral funcName = "kernel_matmul";
MLIRContext context;
ModelBuilder modelBuilder;
auto typeA = modelBuilder.getMemRefType({width, height}, modelBuilder.f32);
auto typeB = modelBuilder.getMemRefType({width, height}, modelBuilder.f32);
auto typeC = modelBuilder.getMemRefType({width, height}, modelBuilder.f32);
// 1. Build the kernel.
{
modelBuilder.addGPUAttr();
// create kernel
FuncOp kernelFunc = modelBuilder.makeFunction(
funcName, {}, {typeA, typeB, typeC}, MLIRFuncOpConfig());
OpBuilder b(&kernelFunc.getBody());
ScopedContext scope(b, kernelFunc.getLoc());
Value A = kernelFunc.getArgument(0);
Value B = kernelFunc.getArgument(1);
Value C = kernelFunc.getArgument(2);
(linalg_matmul(ValueRange{A, B}, ValueRange{C}));
std_ret();
}
// 2. Compile the function, pass in runtime support library
// to the execution engine for vector.print.
ModelRunner runner(modelBuilder.getModuleRef(),
ModelRunner::Target::GPUTarget);
CompilationOptions options;
mlir::iree_compiler::SPIRVCodegenOptions codegenOptions;
SmallVector<Type, 3> args = {typeA, typeB, typeC};
codegenOptions.workgroupSize.assign(workgroupSize.begin(),
workgroupSize.end());
codegenOptions.tileSizes.assign(tileSizes.begin(), tileSizes.end());
codegenOptions.useWorkgroupMemory = useWorkgroupMemory;
auto lowering = [&](mlir::PassManager &pm) {
pm.addPass(
mlir::iree_compiler::createLinalgTileAndFusePass(codegenOptions));
pm.addPass(mlir::iree_compiler::createConvertToGPUPass());
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::iree_compiler::createConvertToSPIRVPass());
auto &spirvModulePM = pm.nest<mlir::spirv::ModuleOp>();
spirvModulePM.addPass(mlir::createSetSpirvABIPass());
spirvModulePM.addPass(mlir::spirv::createLowerABIAttributesPass());
spirvModulePM.addPass(mlir::createCanonicalizerPass());
spirvModulePM.addPass(mlir::createCSEPass());
spirvModulePM.addPass(
mlir::spirv::createUpdateVersionCapabilityExtensionPass());
int numWorkgroupX = codegenOptions.tileSizes.empty()
? 1
: (width + codegenOptions.tileSizes[0] - 1) /
codegenOptions.tileSizes[0];
int numWorkgroupY = codegenOptions.tileSizes.size() < 2
? 1
: (height + codegenOptions.tileSizes[1] - 1) /
codegenOptions.tileSizes[1];
pm.addPass(mlir::createAddVulkanLaunchWrapperPass(
{numWorkgroupX, numWorkgroupY, 1}, args));
mlir::LowerToLLVMOptions llvmOptions = {
/*useBarePtrCallConv =*/false,
/*emitCWrappers = */ true,
/*indexBitwidth =*/mlir::kDeriveIndexBitwidthFromDataLayout};
pm.addPass(createLowerToLLVMPass(llvmOptions));
pm.addPass(mlir::createConvertVulkanLaunchFuncToVulkanCallsPass());
};
options.loweringPasses = lowering;
runner.compile(options, {vulkanWrapper});
// 3. Allocate data within data structures that interoperate with the MLIR ABI
// conventions used by codegen.
auto oneInit = [](unsigned idx, float *ptr) { ptr[idx] = 2.0f + 3 * idx; };
auto incInit = [](unsigned idx, float *ptr) { ptr[idx] = 1.0f + idx; };
auto zeroInit = [](unsigned idx, float *ptr) { ptr[idx] = 0.0f; };
auto A = makeInitializedStridedMemRefDescriptor<float, 2>({width, height},
oneInit);
auto B = makeInitializedStridedMemRefDescriptor<float, 2>({width, height},
incInit);
auto C = makeInitializedStridedMemRefDescriptor<float, 2>({width, height},
zeroInit);
// 4. Call the funcOp named `funcName`.
auto err = runner.invoke(std::string(funcName) + "_wrapper", A, B, C);
if (err) llvm_unreachable("Error running function.");
// 5. Dump content of input and output buffer for testing with FileCheck.
::impl::printMemRef(*A);
::impl::printMemRef(*B);
::impl::printMemRef(*C);
}
int main(int argc, char **argv) {
iree::Initializer::RunInitializers();
// Allow LLVM setup through command line and parse the
// test specific option for a runtime support library.
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv, "TestMatMulVulkan\n");
// clang-format off
// CHECK: Memref
// CHECK: [2, 5, 8, 11],
// CHECK: [14, 17, 20, 23],
// CHECK: [26, 29, 32, 35],
// CHECK: [38, 41, 44, 47]
// CHECK: Memref
// CHECK: [1, 2, 3, 4],
// CHECK: [5, 6, 7, 8],
// CHECK: [9, 10, 11, 12],
// CHECK: [13, 14, 15, 16]
// CHECK: Memref
// CHECK: [242, 268, 294, 320],
// CHECK: [578, 652, 726, 800],
// CHECK: [914, 1036, 1158, 1280],
// CHECK: [1250, 1420, 1590, 1760]
testMatMul();
}