// blob: 23f87a75392bfb7c941a6b118af1c50ac3f25a7d [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- SPIRVLowering.cpp ---------------------------------------*- C++ -*-===//
//
// SPIR-V Code-generation for XLA-HLO Ops within IREE Dispatch functions
//
//===----------------------------------------------------------------------===//
#include "compiler/Translation/SPIRV/SPIRVLowering.h"
namespace mlir {
namespace iree_compiler {
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
// Lowers a `ConstantOp` whose value is a splat `DenseElementsAttr` into a
// single `spirv::ConstantOp` holding the splat value.
//
// The SPIR-V constant is typed with the result type itself when that type is
// an integer or float, and with the element type when the result is a
// `ShapedType`. The created value is recorded in `valueCache` for the
// constant's result at `index`.
//
// Emits an error on `op` and returns failure when the value is not a splat
// dense elements attribute, or when the result type is neither int/float nor
// shaped. The operand list is unused.
LogicalResult ConstantOpSPIRVLowering::lowerOperation(
    Operation *op, OpBuilder &builder, AffineMap index, ArrayRef<Value *>,
    ValueCache &valueCache) const {
  auto constantOp = cast<ConstantOp>(op);

  // Only splat dense-element constants are supported.
  auto splatAttr = constantOp.value().dyn_cast<DenseElementsAttr>();
  if (!splatAttr || !splatAttr.isSplat()) {
    return op->emitError(
        "unhandled constant lowering unless value is a splat dense element "
        "attribute");
  }

  // Pick the type of the scalar constant: the result type itself for
  // int/float results, otherwise the element type of the shaped result.
  auto constantType = constantOp.getResult()->getType();
  Type scalarType = constantType;
  if (!constantType.isIntOrFloat()) {
    auto shaped = constantType.dyn_cast<ShapedType>();
    if (!shaped) {
      return op->emitError("unhandled result type of constant : ")
             << constantType;
    }
    scalarType = shaped.getElementType();
  }

  // Materialize the splat value as a SPIR-V constant and cache it.
  Attribute splatValue = splatAttr.getSplatValue();
  auto loweredConstant =
      builder.create<spirv::ConstantOp>(op->getLoc(), scalarType, splatValue);
  valueCache.setOperandDstValue(constantOp.getResult(), index,
                                loweredConstant.getResult());
  return success();
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
// Lowers a `CmpFOp` to the SPIR-V floating-point comparison operation that
// corresponds to its predicate attribute.
//
// `operands` must hold exactly the two already-lowered comparison operands.
// The created operation produces an i1 result, which is recorded in
// `valueCache` for result 0 of `op` at `index`.
//
// Emits an error on `op` and returns failure when the operand count is not
// two, when the predicate attribute is missing, or when the predicate has no
// entry in the dispatch table below.
LogicalResult CmpFOpSPIRVLowering::lowerOperation(
    Operation *op, OpBuilder &builder, AffineMap index,
    ArrayRef<Value *> operands, ValueCache &valueCache) const {
  if (operands.size() != 2) {
    return op->emitError("expected two operands in spir-v lowering of CmpFOp");
  }

  auto predicateAttr =
      op->getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
  if (!predicateAttr) {
    return op->emitError("expected CmpFOp to contain ")
           << CmpFOp::getPredicateAttrName() << " attribute";
  }

  auto i1Type = builder.getI1Type();
  auto loc = op->getLoc();
  Value *lhs = operands[0];
  Value *rhs = operands[1];
  Operation *comparison = nullptr;

  // Each table entry maps one CmpFPredicate onto the SPIR-V op to create.
  switch (static_cast<CmpFPredicate>(predicateAttr.getInt())) {
#define LOWER_PREDICATE(predicate, spirvOpTy)                        \
  case predicate:                                                    \
    comparison = builder.create<spirvOpTy>(loc, i1Type, lhs, rhs);   \
    break;

    LOWER_PREDICATE(CmpFPredicate::OEQ, spirv::FOrdEqualOp)
    LOWER_PREDICATE(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp)
    LOWER_PREDICATE(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp)
    LOWER_PREDICATE(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp)
    LOWER_PREDICATE(CmpFPredicate::OLT, spirv::FOrdLessThanOp)
    LOWER_PREDICATE(CmpFPredicate::ONE, spirv::FOrdNotEqualOp)
    LOWER_PREDICATE(CmpFPredicate::UEQ, spirv::FUnordEqualOp)
    LOWER_PREDICATE(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp)
    LOWER_PREDICATE(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp)
    LOWER_PREDICATE(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp)
    LOWER_PREDICATE(CmpFPredicate::ULT, spirv::FUnordLessThanOp)
    LOWER_PREDICATE(CmpFPredicate::UNE, spirv::FUnordNotEqualOp)

#undef LOWER_PREDICATE
    default:
      return op->emitError("unhandled predicate attribute for SPIR-V lowering");
  }

  valueCache.setOperandDstValue(op->getResult(0), index,
                                comparison->getResult(0));
  return success();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
// Lowers a `ReturnOp` by storing the scalar computed for the returned tensor
// into the output buffer and then emitting a `spirv::ReturnOp`.
//
// The index of the element to store is obtained from `affineExprCodegen` for
// the returned tensor; the scalar to store is looked up in `valueCache` for
// that tensor and index; the destination pointer comes from
// `genPointerOffset` applied to the single entry of `outputBuffers`.
// `inputBuffers` is not used by this lowering.
//
// Emits an error on the return op and returns failure when:
//   - the return op does not have exactly one operand,
//   - the returned tensor does not have exactly one index, or
//   - `outputBuffers` does not contain exactly one buffer.
LogicalResult ReturnOpSPIRVLowering::lowerOperation(
    Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen,
    ValueCache &valueCache,
    DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers,
    ArrayRef<spirv::GlobalVariableOp> outputBuffers) const {
  auto returnOp = cast<ReturnOp>(op);
  if (returnOp.getNumOperands() != 1) {
    return returnOp.emitError(
        "unhandled lowering of return statement with multiple returns");
  }

  auto returnTensor = returnOp.getOperand(0);
  auto indices = affineExprCodegen.getIndices(returnTensor);
  if (indices.size() != 1) {
    return returnOp.emitError(
        "expected to compute a single element of the return tensor");
  }

  // Report a diagnostic instead of asserting: an assert is compiled out in
  // release builds, which would leave the `outputBuffers[0]` access below
  // unchecked. This also matches how the other preconditions are handled.
  if (outputBuffers.size() != 1) {
    return returnOp.emitError("expected a single output buffer");
  }

  auto var = outputBuffers[0];
  auto ptr = genPointerOffset(builder, returnOp.getLoc(), affineExprCodegen,
                              indices[0], var);
  auto scalarVal = valueCache.getOperandDstValue(returnTensor, indices[0]);
  builder.create<spirv::StoreOp>(returnOp.getLoc(), ptr, scalarVal,
                                 /*memory_access = */ nullptr,
                                 /*alignment = */ nullptr);
  builder.create<spirv::ReturnOp>(returnOp.getLoc());
  return success();
}
} // namespace iree_compiler
} // namespace mlir