// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Translation/SPIRV/EmbeddedKernels.h"
#include "iree/compiler/Translation/SPIRV/Kernels/Kernels.h"
#include "iree/schemas/spirv_executable_def_generated.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Reads the SPIR-V code for the embedded kernel with the given file name,
// returning an empty vector if no kernel with that name is embedded.
// If the kernel source under Kernels/ is 'matmul.comp' then |kernelName| is
// 'matmul.spv', as the embedded copy has been compiled to SPIR-V.
std::vector<uint32_t> readEmbeddedKernelCode(const std::string &kernelName) {
auto *fileToc = spirv_kernels::Kernels_create();
for (int i = 0; i < spirv_kernels::Kernels_size(); ++i) {
if (std::strcmp(fileToc[i].name, kernelName.c_str()) == 0) {
std::vector<uint32_t> code;
// SPIR-V modules are streams of 32-bit words; the TOC stores raw bytes.
code.resize(fileToc[i].size / sizeof(uint32_t));
std::memcpy(code.data(), fileToc[i].data, code.size() * sizeof(uint32_t));
return code;
}
}
return {};
}
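
// For example, fetching the compiled matmul kernel from the embedded table:
//   std::vector<uint32_t> code = readEmbeddedKernelCode("matmul.spv");
//   // code.empty() indicates the kernel was not found.
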
// Adds a storage buffer binding to the descriptor set layout.
void addDescriptorSetLayoutBinding(uint32_t binding,
iree::VkDescriptorSetLayoutDefT *dsl) {
auto bindingDef = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>();
bindingDef->binding = binding;
bindingDef->descriptor_count = 1;
bindingDef->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
bindingDef->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT
dsl->bindings.push_back(std::move(bindingDef));
}

// Adds a specialization map entry for |constant_id| set to a 32-bit value.
void addSpecializationMapEntry(
uint32_t constant_id, uint32_t value,
iree::VkSpecializationInfoDefT *specializationInfoDef) {
auto specValue = std::make_unique<iree::VkSpecializationMapEntryDefT>();
specValue->constant_id = constant_id;
specValue->uint32_value = value;
specializationInfoDef->map_entries.push_back(std::move(specValue));
}
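
// Adds one specialization map entry per element of |values|, assigning
// consecutive constant IDs starting at |constant_start|. Each int value is
// reinterpreted bit-for-bit as a uint32_t; for example, values {1, 2} with
// |constant_start| 100 produce entries 100 -> 1 and 101 -> 2.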
void addSpecializationMapEntryVector(
uint32_t constant_start, const std::vector<int> &values,
iree::VkSpecializationInfoDefT *specializationInfoDef) {
for (size_t i = 0; i < values.size(); ++i) {
addSpecializationMapEntry(constant_start + i,
*reinterpret_cast<const uint32_t *>(&values[i]),
specializationInfoDef);
}
}
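
// Builds a SPIR-V executable from a well-known reduction entry function.
// |outDef| will be populated with all required information for serialization.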
LogicalResult buildReductionExecutable(ModuleOp moduleOp, FuncOp entryFuncOp,
iree::SpirVExecutableDefT *outDef) {
auto funcType = entryFuncOp.getType();
auto arg0 = funcType.getInput(0).cast<ShapedType>();
if (!arg0.getElementType().isF32()) {
// When we do other types we'll need other shaders.
return entryFuncOp.emitOpError()
<< "only floating point reduction is implemented";
}
auto applyFuncAttr = entryFuncOp.getAttrOfType<FlatSymbolRefAttr>(
"iree.executable.reduction.apply");
if (!applyFuncAttr) {
return entryFuncOp.emitOpError()
<< "missing iree.executable.reduction.apply attribute";
}
auto applyFuncOp = moduleOp.lookupSymbol(applyFuncAttr.getValue());
// TODO(benvanik): specialize (template on shapes/types/etc).
std::string kernelName = "reduce_untiled.spv";
llvm::Optional<uint32_t> operationId;
applyFuncOp->walk([&](Operation *op) {
if (isa<xla_hlo::AddOp>(op)) {
operationId = 0;
} else if (isa<xla_hlo::MaxOp>(op)) {
operationId = 1;
} else if (isa<xla_hlo::MinOp>(op)) {
operationId = 2;
}
});
if (!operationId.hasValue()) {
applyFuncOp->dump();
return applyFuncOp->emitOpError() << "unsupported reduction operator";
}
outDef->tag = "__reduce__";
outDef->entry_points = {"main"};
outDef->code = readEmbeddedKernelCode(kernelName);
// arg0, arg1, ret0
auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
pipelineLayoutDef->buffer_binding_set = 0;
auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
addDescriptorSetLayoutBinding(0, dsl.get());
addDescriptorSetLayoutBinding(1, dsl.get());
addDescriptorSetLayoutBinding(2, dsl.get());
pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl));
outDef->pipeline_layout = std::move(pipelineLayoutDef);
// See the shader source for documentation on the values of A/B/C/R.
int64_t reductionDimension =
entryFuncOp
.getAttrOfType<IntegerAttr>("iree.executable.reduction.dimension")
.getInt();
uint32_t r = arg0.getDimSize(reductionDimension);
uint32_t a = 1;
for (int i = 0; i < reductionDimension; ++i) {
a *= arg0.getDimSize(i);
}
uint32_t b = 1;
for (int i = reductionDimension + 1; i < arg0.getRank(); ++i) {
b *= arg0.getDimSize(i);
}
uint32_t c = b;
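// For example, reducing a [2, 3, 4] tensor along dimension 1 yields
// a = 2, r = 3, and b = c = 4.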
auto specializationInfoDef =
std::make_unique<iree::VkSpecializationInfoDefT>();
addSpecializationMapEntry(/*kOperationId*/ 100, operationId.getValue(),
specializationInfoDef.get());
addSpecializationMapEntry(/*kA*/ 101, a, specializationInfoDef.get());
addSpecializationMapEntry(/*kB*/ 102, b, specializationInfoDef.get());
addSpecializationMapEntry(/*kC*/ 103, c, specializationInfoDef.get());
addSpecializationMapEntry(/*kR*/ 104, r, specializationInfoDef.get());
outDef->specialization_info = std::move(specializationInfoDef);
return success();
}
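
// Builds a SPIR-V executable from a well-known conv2d entry function.
// |outDef| will be populated with all required information for serialization.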
LogicalResult buildConvExecutable(ModuleOp moduleOp, FuncOp entryFuncOp,
xla_hlo::ConvOp convOp,
iree::SpirVExecutableDefT *outDef) {
auto lhs = convOp.lhs()->getType().cast<ShapedType>();
auto rhs = convOp.rhs()->getType().cast<ShapedType>();
if (convOp.feature_group_count() != 1) {
return entryFuncOp.emitOpError()
<< "only feature group counts of 1 supported";
}
if (lhs.getRank() != 4 || rhs.getRank() != 4) {
return entryFuncOp.emitOpError() << "only Conv2d supported";
}
auto specializationInfoDef =
std::make_unique<iree::VkSpecializationInfoDefT>();
// Get the padding specializations.
{
std::vector<int> paddings;
if (convOp.padding().hasValue()) {
for (const auto &elm : convOp.padding().getValue().getIntValues()) {
paddings.push_back(elm.getSExtValue());
}
}
addSpecializationMapEntryVector(100, paddings, specializationInfoDef.get());
}
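// For example, a conv2d padded by [[1, 1], [2, 2]] flattens to {1, 1, 2, 2},
// occupying specialization constant IDs 100-103.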
// LHS (image) dimension numbers gathered as [batch, feature, spatial...];
// an NHWC input layout must yield [0, 3, 1, 2].
{
std::vector<int> lhsOrdering{
static_cast<int>(
convOp.dimension_numbers().input_batch_dimension().getInt()),
static_cast<int>(
convOp.dimension_numbers().input_feature_dimension().getInt())};
for (const auto &dim :
convOp.dimension_numbers().input_spatial_dimensions()) {
lhsOrdering.push_back(dim.getSExtValue());
}
if (lhsOrdering.size() != 4 || lhsOrdering[0] != 0 || lhsOrdering[1] != 3 ||
lhsOrdering[2] != 1 || lhsOrdering[3] != 2) {
return entryFuncOp.emitOpError() << "only NHWC tensor ordering supported";
}
// Extents in buffer order.
std::vector<int> lhsExtents(lhsOrdering.size());
for (int i = 0; i < lhs.getRank(); ++i) {
lhsExtents[i] = lhs.getDimSize(i);
}
addSpecializationMapEntryVector(110, lhsExtents,
specializationInfoDef.get());
}
// RHS (kernel) dimension numbers gathered as
// [output_feature, input_feature, spatial...];
// an HWIO kernel layout must yield [3, 2, 0, 1].
{
std::vector<int> rhsOrdering{
static_cast<int>(convOp.dimension_numbers()
.kernel_output_feature_dimension()
.getInt()),
static_cast<int>(convOp.dimension_numbers()
.kernel_input_feature_dimension()
.getInt())};
for (const auto &dim :
convOp.dimension_numbers().kernel_spatial_dimensions()) {
rhsOrdering.push_back(dim.getSExtValue());
}
if (rhsOrdering.size() != 4 || rhsOrdering[0] != 3 || rhsOrdering[1] != 2 ||
rhsOrdering[2] != 0 || rhsOrdering[3] != 1) {
return entryFuncOp.emitOpError() << "only HWIO kernel ordering supported";
}
// Extents in buffer order.
std::vector<int> rhsExtents(rhsOrdering.size());
for (int i = 0; i < rhs.getRank(); ++i) {
rhsExtents[i] = rhs.getDimSize(i);
}
addSpecializationMapEntryVector(120, rhsExtents,
specializationInfoDef.get());
}
// Result dimension numbers gathered as [batch, feature, spatial...];
// an NHWC result layout must yield [0, 3, 1, 2].
{
std::vector<int> retOrdering{
static_cast<int>(
convOp.dimension_numbers().output_batch_dimension().getInt()),
static_cast<int>(
convOp.dimension_numbers().output_feature_dimension().getInt())};
for (const auto &dim :
convOp.dimension_numbers().output_spatial_dimensions()) {
retOrdering.push_back(dim.getSExtValue());
}
if (retOrdering.size() != 4 || retOrdering[0] != 0 || retOrdering[1] != 3 ||
retOrdering[2] != 1 || retOrdering[3] != 2) {
return entryFuncOp.emitOpError() << "only HWIO kernel ordering supported";
}
}
outDef->tag = "__conv2d_nhwc__";
outDef->entry_points = {"main"};
outDef->code = readEmbeddedKernelCode("conv2d_nhwc.spv");
auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
pipelineLayoutDef->buffer_binding_set = 0;
auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
addDescriptorSetLayoutBinding(0, dsl.get());
addDescriptorSetLayoutBinding(1, dsl.get());
addDescriptorSetLayoutBinding(2, dsl.get());
pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl));
outDef->pipeline_layout = std::move(pipelineLayoutDef);
outDef->specialization_info = std::move(specializationInfoDef);
return success();
}

// Builds a SPIR-V executable from a well-known matmul (xla_hlo.dot) entry
// function.
// |outDef| will be populated with all required information for serialization.
LogicalResult buildMatMulExecutable(ModuleOp moduleOp, FuncOp entryFuncOp,
xla_hlo::DotOp dotOp,
iree::SpirVExecutableDefT *outDef) {
auto arg0 = dotOp.getOperand(0)->getType().cast<ShapedType>();
auto arg1 = dotOp.getOperand(1)->getType().cast<ShapedType>();
outDef->tag = "__matmul__";
outDef->entry_points = {"main"};
// TODO(benvanik): specialize (template on shapes/types/etc).
outDef->code = readEmbeddedKernelCode("matmul.spv");
// arg0, arg1, ret0
auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>();
pipelineLayoutDef->buffer_binding_set = 0;
auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>();
addDescriptorSetLayoutBinding(0, dsl.get());
addDescriptorSetLayoutBinding(1, dsl.get());
addDescriptorSetLayoutBinding(2, dsl.get());
pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl));
outDef->pipeline_layout = std::move(pipelineLayoutDef);
// Shapes of [arg0, arg1, ret0].
// arg0 = [b0, m, k]
// arg1 = [b0, k, n]
// ret0 = [b0, m, n]
// Note that we handle both batched (rank 3) and unbatched (rank 2).
uint32_t m = arg0.getRank() == 3 ? arg0.getDimSize(1) : arg0.getDimSize(0);
uint32_t k = arg0.getRank() == 3 ? arg0.getDimSize(2) : arg0.getDimSize(1);
uint32_t n = arg1.getRank() == 3 ? arg1.getDimSize(2) : arg1.getDimSize(1);
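// For example, a batched dot of arg0 = [4, 8, 16] with arg1 = [4, 16, 32]
// yields m = 8, k = 16, n = 32.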
auto specializationInfoDef =
std::make_unique<iree::VkSpecializationInfoDefT>();
addSpecializationMapEntry(/*kMatrixM*/ 100, m, specializationInfoDef.get());
addSpecializationMapEntry(/*kMatrixK*/ 101, k, specializationInfoDef.get());
addSpecializationMapEntry(/*kMatrixN*/ 102, n, specializationInfoDef.get());
outDef->specialization_info = std::move(specializationInfoDef);
return success();
}
} // namespace
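
// Returns true if the module was recognized as one of the well-known kernels
// (reduction, conv2d, matmul) and |outDef| was populated; build failures emit
// an error on |moduleOp| and return false.
//
// A sketch of a hypothetical call site:
//   iree::SpirVExecutableDefT executableDef;
//   if (tryEmbeddedKernelRewrite(moduleOp, &executableDef)) {
//     // Serialize executableDef instead of running full codegen.
//   }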
bool tryEmbeddedKernelRewrite(ModuleOp moduleOp,
iree::SpirVExecutableDefT *outDef) {
for (auto funcOp : moduleOp.getOps<FuncOp>()) {
if (funcOp.getAttr("iree.executable.reduction")) {
if (failed(buildReductionExecutable(moduleOp, funcOp, outDef))) {
moduleOp.emitOpError() << "failed to splat in the reduction kernel";
return false;
}
return true;
}
for (auto &block : funcOp) {
for (auto &op : block) {
if (auto convOp = dyn_cast<xla_hlo::ConvOp>(&op)) {
if (failed(buildConvExecutable(moduleOp, funcOp, convOp, outDef))) {
moduleOp.emitOpError() << "failed to splat in the conv kernel";
return false;
}
return true;
} else if (auto dotOp = dyn_cast<xla_hlo::DotOp>(&op)) {
if (failed(buildMatMulExecutable(moduleOp, funcOp, dotOp, outDef))) {
moduleOp.emitOpError() << "failed to splat in the matmul kernel";
return false;
}
return true;
}
}
}
}
return false;
}
} // namespace iree_compiler
} // namespace mlir