blob: 567273891ccc0ad1e858d8c9df6fcd8c0125018d [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- IREEToSPIRVPass.cpp --------------------------------------*- C++ -*-===//
//
// Pass to translate IREE executables into the SPIR-V dialect for Vulkan.
//
//===----------------------------------------------------------------------===//
#include "compiler/Translation/SPIRV/IREEToSPIRVPass.h"
#include "compiler/Translation/SPIRV/IREEIndexComputation.h"
#include "compiler/Translation/SPIRV/IREEToSPIRV.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
namespace mlir {
namespace iree_compiler {
namespace {

/// Module pass that lowers the exported dispatch functions of an IREE
/// executable module into a SPIR-V dialect module.
class IREEToSPIRVPass : public ModulePass<IREEToSPIRVPass> {
 private:
  /// Pass entry point; overrides ModulePass::runOnModule.
  void runOnModule() override;
};

}  // namespace
/// Lowers every exported dispatch function of the module to SPIR-V.
///
/// Steps:
///  1. Build one `spv.module` with the Logical addressing model and the
///     GLSL450 memory model. It carries the `Shader` capability and the
///     `SPV_KHR_storage_buffer_storage_class` extension.
///  2. For each FuncOp carrying the `iree.executable.export` attribute,
///     propagate index computations through its body.
///  3. Emit the function into that SPIR-V module with the code generator.
///
/// Non-exported functions are skipped. Any propagation or codegen failure
/// aborts the pass via signalPassFailure().
void IREEToSPIRVPass::runOnModule() {
  auto module = getModule();
  // All new top-level ops (the spv.module) are built in the module's body
  // region.
  OpBuilder builder(module.getBodyRegion());

  // Initialize the index computation. Each op kind is listed exactly once; a
  // duplicate entry would either be dead or registered twice.
  IndexPropagationList<IndexPropagationOp<ConstantOp>,
                       // IREE-specific ops:
                       IndexPropagationOp<IREE::ReturnOp>,
                       IREELoadIndexPropagation, IREEStoreIndexPropagation,
                       // Standard dialect unary elementwise ops:
                       NoBroadcastPwOpIndexPropagation<SIToFPOp>,
                       NoBroadcastPwOpIndexPropagation<SignExtendIOp>,
                       // Standard dialect binary elementwise ops:
                       NoBroadcastPwOpIndexPropagation<AddFOp>,
                       NoBroadcastPwOpIndexPropagation<AddIOp>,
                       NoBroadcastPwOpIndexPropagation<AndOp>,
                       NoBroadcastPwOpIndexPropagation<CmpFOp>,
                       NoBroadcastPwOpIndexPropagation<CmpIOp>,
                       NoBroadcastPwOpIndexPropagation<DivFOp>,
                       NoBroadcastPwOpIndexPropagation<DivISOp>,
                       NoBroadcastPwOpIndexPropagation<DivIUOp>,
                       NoBroadcastPwOpIndexPropagation<MulFOp>,
                       NoBroadcastPwOpIndexPropagation<MulIOp>,
                       NoBroadcastPwOpIndexPropagation<OrOp>,
                       NoBroadcastPwOpIndexPropagation<RemFOp>,
                       NoBroadcastPwOpIndexPropagation<RemISOp>,
                       NoBroadcastPwOpIndexPropagation<RemIUOp>,
                       NoBroadcastPwOpIndexPropagation<SubFOp>,
                       NoBroadcastPwOpIndexPropagation<SubIOp>,
                       NoBroadcastPwOpIndexPropagation<TruncateIOp>,
                       NoBroadcastPwOpIndexPropagation<XOrOp>,
                       NoBroadcastPwOpIndexPropagation<ZeroExtendIOp>,
                       // XLA unary elementwise ops:
                       NoBroadcastPwOpIndexPropagation<xla_hlo::AbsOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::CeilOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::ConvertOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::CosOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::ExpOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::FloorOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::LogOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::NegOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::RsqrtOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::SignOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::TanhOp>,
                       // XLA binary elementwise ops:
                       NoBroadcastPwOpIndexPropagation<xla_hlo::AddOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::AndOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::DivOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::MaxOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::MinOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::MulOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::SubOp>,
                       // XLA other ops:
                       // TODO(ravishankarm): conv, dot.
                       // TODO(ravishankarm): gather.
                       // TODO(ravishankarm): pad.
                       // TODO(ravishankarm): slice.
                       NoBroadcastPwOpIndexPropagation<xla_hlo::CopyOp>,
                       ReshapeOpIndexPropagation<xla_hlo::ReshapeOp>,
                       NoBroadcastPwOpIndexPropagation<xla_hlo::SelectOp>,
                       XLABroadcastOpIndexPropagation,
                       XLABroadcastInDimOpIndexPropagation,
                       XLAReverseOpIndexPropagation,
                       XLATransposeOpIndexPropagation>
      indexPropagation;

  // Initialize the spir-v codegenerator.
  SPIRVCodegen<
      ConstantOpSPIRVLowering,
      // IREE-specific ops:
      IREELoadOpSPIRVLowering, IREEReturnOpSPIRVLowering,
      IREEStoreOpSPIRVLowering,
      // Standard dialect unary elementwise ops:
      // Standard dialect binary elementwise ops:
      SPIRVPwOpLowering<AddFOp, spirv::FAddOp>,
      SPIRVPwOpLowering<DivFOp, spirv::FDivOp>,
      SPIRVPwOpLowering<MulFOp, spirv::FMulOp>,
      SPIRVPwOpLowering<SubFOp, spirv::FSubOp>,
      SPIRVPwOpLowering<AddIOp, spirv::IAddOp>,
      SPIRVPwOpLowering<DivISOp, spirv::SDivOp>,
      SPIRVPwOpLowering<MulIOp, spirv::IMulOp>,
      SPIRVPwOpLowering<SubIOp, spirv::ISubOp>,
      // XLA unary elementwise ops:
      SPIRVPwOpLowering<xla_hlo::AbsOp, spirv::GLSLSAbsOp, spirv::GLSLFAbsOp>,
      SPIRVPwOpLowering<xla_hlo::CeilOp, spirv::GLSLCeilOp>,
      // TODO(ravishankarm): xla_hlo::ConvertOp
      SPIRVPwOpLowering<xla_hlo::CosOp, spirv::GLSLCosOp>,
      SPIRVPwOpLowering<xla_hlo::ExpOp, spirv::GLSLExpOp>,
      SPIRVPwOpLowering<xla_hlo::FloorOp, spirv::GLSLFloorOp>,
      SPIRVPwOpLowering<xla_hlo::LogOp, spirv::GLSLLogOp>,
      SPIRVPwOpLowering<xla_hlo::NegOp, spirv::FNegateOp>,
      SPIRVPwOpLowering<xla_hlo::RsqrtOp, spirv::GLSLInverseSqrtOp>,
      SPIRVPwOpLowering<xla_hlo::SignOp, spirv::GLSLSSignOp,
                        spirv::GLSLFSignOp>,
      SPIRVPwOpLowering<xla_hlo::TanhOp, spirv::GLSLTanhOp>,
      // XLA binary elementwise ops:
      SPIRVPwOpLowering<xla_hlo::AddOp, spirv::IAddOp, spirv::FAddOp>,
      SPIRVPwOpLowering<xla_hlo::AndOp, spirv::LogicalAndOp>,
      SPIRVPwOpLowering<xla_hlo::DivOp, spirv::FDivOp>,
      SPIRVPwOpLowering<xla_hlo::MaxOp, spirv::GLSLSMaxOp, spirv::GLSLFMaxOp>,
      SPIRVPwOpLowering<xla_hlo::MinOp, spirv::GLSLSMinOp, spirv::GLSLFMinOp>,
      SPIRVPwOpLowering<xla_hlo::MulOp, spirv::IMulOp, spirv::FMulOp>,
      SPIRVPwOpLowering<xla_hlo::SubOp, spirv::ISubOp, spirv::FSubOp>,
      // XLA other ops:
      CmpFOpSPIRVLowering,
      SPIRVPwOpLowering<xla_hlo::SelectOp, spirv::SelectOp>,
      SPIRVIndexOpLowering<xla_hlo::BroadcastOp>,
      SPIRVIndexOpLowering<xla_hlo::BroadcastInDimOp>,
      SPIRVIndexOpLowering<xla_hlo::CopyOp>,
      SPIRVIndexOpLowering<xla_hlo::ReshapeOp>,
      SPIRVIndexOpLowering<xla_hlo::ReverseOp>,
      SPIRVIndexOpLowering<xla_hlo::TransposeOp>>
      spirvCodegen;

  // Create a spirv.module Op.
  auto spvModule = builder.create<spirv::ModuleOp>(
      module.getLoc(),
      builder.getI32IntegerAttr(
          static_cast<int32_t>(spirv::AddressingModel::Logical)),
      builder.getI32IntegerAttr(
          static_cast<int32_t>(spirv::MemoryModel::GLSL450)));
  SmallVector<StringRef, 2> caps;
  caps.push_back(spirv::stringifyCapability(spirv::Capability::Shader));
  spvModule.setAttr("capabilities", builder.getStrArrayAttr(caps));
  SmallVector<StringRef, 2> exts;
  exts.push_back("SPV_KHR_storage_buffer_storage_class");
  spvModule.setAttr("extensions", builder.getStrArrayAttr(exts));

  for (auto funcOp : module.getOps<FuncOp>()) {
    // TODO(ravishankarm): FuncOps in executable that are not dispatch functions
    // are not lowered to SPIR-V. Fix this limitation.
    if (!funcOp.getAttr("iree.executable.export")) continue;

    // Index computations are tracked per function, so each function gets a
    // fresh cache.
    IndexComputationCache indexMap;
    if (failed(indexPropagation.propagate(funcOp.getBody(), indexMap))) {
      return signalPassFailure();
    }

    ValueCache valueCache;
    AffineExprCodegen affineExprCodegen(spvModule, indexMap);
    if (failed(spirvCodegen.codegen(spvModule, funcOp, affineExprCodegen,
                                    valueCache))) {
      return signalPassFailure();
    }
  }
}
std::unique_ptr<OpPassBase<ModuleOp>> createIREEToSPIRVPass() {
return std::make_unique<IREEToSPIRVPass>();
}
// Static registration of IREEToSPIRVPass under the `convert-iree-to-spirv`
// pass argument, with a one-line description.
static PassRegistration<IREEToSPIRVPass> passRegistration(
    /*arg=*/"convert-iree-to-spirv",
    /*description=*/"Convert IREE dispatch functions to SPIR-V dialect");
} // namespace iree_compiler
} // namespace mlir