blob: 3b99924726810df141498493252d90fadd5d50c1 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- AdjustIntegerWidthPass.cpp ------------------------------*- C++ -*-===//
//
// Pass to adjust integer widths of operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Pass that rewrites SPIR-V integer types whose width is not 32 bits to i32:
/// 1) Legalizes 64-bit integer values to 32-bit integer values.
/// 2) Legalizes !spv.array containing i1 types to !spv.array of i32 types.
/// 3) Legalizes 8-bit and 16-bit integer values to 32-bit integer values;
///    loads and stores of such values are turned into accesses to the i32
///    word that contains them.
/// TODO(b/144743561): Use Int8 capability after it is well-supported.
struct AdjustIntegerWidthPass
    : public PassWrapper<AdjustIntegerWidthPass, OperationPass<>> {
  void runOnOperation() override;
};
// Recursively checks whether `type` is an IntegerType whose bit width appears
// in `widths`, or contains one through spv.struct members, spv.array elements
// or spv.ptr pointees.
bool hasIntTypeOfWidth(Type type, ArrayRef<int64_t> widths) {
  if (auto intType = type.dyn_cast<IntegerType>())
    return llvm::is_contained(widths, intType.getWidth());
  if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
    return hasIntTypeOfWidth(arrayType.getElementType(), widths);
  if (auto ptrType = type.dyn_cast<spirv::PointerType>())
    return hasIntTypeOfWidth(ptrType.getPointeeType(), widths);
  if (auto structType = type.dyn_cast<spirv::StructType>()) {
    return llvm::any_of(
        llvm::seq<unsigned>(0, structType.getNumElements()),
        [&](unsigned member) {
          return hasIntTypeOfWidth(structType.getElementType(member), widths);
        });
  }
  // Every other kind of type is reported as not containing such an integer.
  return false;
}
// Returns true if `type` is, or contains, an i1, i8, i16 or i64 integer, i.e.
// one of the integer widths this pass rewrites to i32. (Despite the name,
// these are the widths that still need legalization.)
bool hasSupportedIntegerType(Type type) {
  static constexpr int64_t kWidthsToLegalize[] = {1, 8, 16, 64};
  return hasIntTypeOfWidth(type, kWidthsToLegalize);
}
// Returns `type` with every IntegerType it contains replaced by i32.
// - An IntegerType is replaced directly.
// - spv.ptr pointees and spv.array elements are rewritten recursively. The
//   storage class and element count are kept.
// - spv.struct members are rewritten recursively, and the rebuilt struct is
//   re-decorated with the Vulkan layout.
// - Any other type is returned unchanged.
Type legalizeIntegerType(Type type) {
  if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
    Type pointeeType = legalizeIntegerType(ptrType.getPointeeType());
    return spirv::PointerType::get(pointeeType, ptrType.getStorageClass());
  }
  if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
    Type elementType = legalizeIntegerType(arrayType.getElementType());
    return spirv::ArrayType::get(elementType, arrayType.getNumElements());
  }
  if (auto structType = type.dyn_cast<spirv::StructType>()) {
    unsigned numMembers = structType.getNumElements();
    SmallVector<Type, 1> memberTypes;
    memberTypes.reserve(numMembers);
    for (unsigned member = 0; member != numMembers; ++member)
      memberTypes.push_back(
          legalizeIntegerType(structType.getElementType(member)));
    // TODO(ravishankarm): Use ABI attributes to legalize the struct type.
    return VulkanLayoutUtils::decorateType(
        spirv::StructType::get(memberTypes));
  }
  if (auto intType = type.dyn_cast<IntegerType>())
    return IntegerType::get(32, intType.getContext());
  return type;
}
/// Rewrites spirv::AccessChainOp ops whose result pointer type contains an i1,
/// i8, i16 or i64 integer so that the result points to the i32-legalized type.
/// The base pointer and the indices are reused unchanged.
struct AdjustAccessChainOp : public OpRewritePattern<spirv::AccessChainOp> {
  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::AccessChainOp op,
                                PatternRewriter &rewriter) const override {
    Type resultPtrType = op.component_ptr().getType();
    if (hasSupportedIntegerType(resultPtrType)) {
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          op, legalizeIntegerType(resultPtrType), op.base_ptr(),
          ValueRange(op.indices()));
      return success();
    }
    return failure();
  }
};
/// Rewrites spirv::AddressOfOp ops whose pointer type contains an i1, i8, i16
/// or i64 integer so that the pointer refers to the i32-legalized type. The op
/// keeps referring to the same global variable symbol.
struct AdjustAddressOfOp : public OpRewritePattern<spirv::AddressOfOp> {
  using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::AddressOfOp op,
                                PatternRewriter &rewriter) const override {
    Type pointerType = op.pointer().getType();
    if (!hasSupportedIntegerType(pointerType)) return failure();
    auto variableRef =
        SymbolRefAttr::get(op.variable(), rewriter.getContext());
    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
        op, legalizeIntegerType(pointerType), variableRef);
    return success();
  }
};
/// Rewrites spirv::GlobalVariableOp ops whose type contains an i1, i8, i16 or
/// i64 integer so that the variable holds the i32-legalized type.
///
/// The type attribute is updated in place rather than rebuilding the op from
/// its `descriptor_set`/`binding` attributes. Rebuilding would require both
/// attributes to be present, so it could not handle other variables. It would
/// also drop any other attribute the variable carries. The in-place update
/// has neither problem.
struct AdjustGlobalVariableWidth
    : public OpRewritePattern<spirv::GlobalVariableOp> {
  using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasSupportedIntegerType(op.type())) {
      return failure();
    }
    Type newType = legalizeIntegerType(op.type());
    // op.type() is backed by the op's "type" TypeAttr; swap only that.
    rewriter.updateRootInPlace(
        op, [&] { op.setAttr("type", TypeAttr::get(newType)); });
    return success();
  }
};
// Returns a new spirv::AccessChainOp that addresses the i32 word containing
// the `bits`-bit integer element addressed by `op`. When `op` has more than
// one index, the last index is divided (SDiv) by the number of elements packed
// into one i32 word, 32 / `bits`. All other indices are kept as-is. Note that
// this only works for a scalar or 1-D tensor.
Value convertToI32AccessChain(spirv::AccessChainOp op, int bits,
                              PatternRewriter &rewriter) {
  assert(32 % bits == 0 && "element bit width must evenly divide 32");
  const auto loc = op.getLoc();
  SmallVector<Value, 4> indices(op.indices().begin(), op.indices().end());
  if (indices.size() > 1) {
    // The divisor is only materialized when it is used, so the scalar case
    // does not leave a dead constant behind.
    auto i32Type = rewriter.getIntegerType(32);
    auto elementsPerWord = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, rewriter.getI32IntegerAttr(32 / bits));
    indices.back() =
        rewriter.create<spirv::SDivOp>(loc, indices.back(), elementsPerWord);
  }
  Type t = legalizeIntegerType(op.component_ptr().getType());
  return rewriter.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
}
// Returns the bit offset, inside its containing i32 word, of the `bits`-bit
// integer selected by the last operand (the last index) of `op`. The offset is
// (lastIndex % (32 / bits)) * bits. E.g., with `bits` == 8, four elements
// share one i32, so element x starts at bit (x % 4) * 8.
Value getOffsetOfInt(spirv::AccessChainOp op, int bits,
                     PatternRewriter &rewriter) {
  assert(32 % bits == 0);
  const auto loc = op.getLoc();
  Type i32Type = rewriter.getIntegerType(32);
  auto makeI32Constant = [&](int value) {
    return rewriter.create<spirv::ConstantOp>(
        loc, i32Type, rewriter.getI32IntegerAttr(value));
  };
  auto elementsPerWord = makeI32Constant(32 / bits);
  auto bitsPerElement = makeI32Constant(bits);
  Value lastIndex = op.getOperation()->getOperand(op.getNumOperands() - 1);
  auto slot = rewriter.create<spirv::SModOp>(loc, lastIndex, elementsPerWord);
  return rewriter.create<spirv::IMulOp>(loc, i32Type, slot, bitsPerElement);
}
// Rewrites a load of an i1, i8 or i16 value into a load of the i32 word that
// contains it, followed by extraction of the element's bits. i1 values are
// handled as bytes. The extracted byte is compared against zero to recover an
// i1 value.
//
// Returns a null Value, without creating any IR, when `op` does not load
// through a spirv::AccessChainOp or the access chain shape is unsupported.
// Only scalars (one index) and 1-D tensors (two indices) are handled.
//
// NOTE(review): the extracted bits are masked into the low bits of an i32
// (zero-extended, not sign-extended), so negative i8/i16 values come back as
// their unsigned bit pattern. Confirm this is what downstream users expect.
Value rewriteIntForLoadOp(spirv::LoadOp op, PatternRewriter &rewriter) {
  const auto loc = op.getLoc();
  Type valueType = op.value().getType();
  Type i32Type = rewriter.getIntegerType(32);
  // The pointer may be a block argument or come from another kind of op; bail
  // out in that case instead of asserting inside cast<>.
  auto accessChainOp =
      dyn_cast_or_null<spirv::AccessChainOp>(op.ptr().getDefiningOp());
  if (!accessChainOp) return nullptr;
  // The first index selects the element; any further index maps to the other
  // dimensions.
  auto numIndices = accessChainOp.indices().size();
  if (numIndices == 0 || numIndices > 2) return nullptr;
  int bits = hasIntTypeOfWidth(valueType, {1, 8}) ? 8 : 16;
  Value i32AccessChainOp =
      convertToI32AccessChain(accessChainOp, bits, rewriter);
  Value loadOp = rewriter.create<spirv::LoadOp>(
      loc, i32Type, i32AccessChainOp,
      op.getAttrOfType<IntegerAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      op.getAttrOfType<IntegerAttr>("alignment"));
  // For a scalar (single index) the loaded word is used directly. For a 1-D
  // access the element is first shifted down to bit 0.
  Value result = loadOp;
  if (numIndices != 1) {
    Value offset = getOffsetOfInt(accessChainOp, bits, rewriter);
    result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, i32Type,
                                                            loadOp, offset);
  }
  // Keep only the low `bits` bits of the shifted word.
  auto intMax = rewriter.create<spirv::ConstantOp>(
      loc, i32Type, rewriter.getI32IntegerAttr((1 << bits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, i32Type, result, intMax);
  // An i1 is loaded as a byte and then turned back into i1. INotEqualOp is
  // used because SConvert doesn't work for i1.
  if (hasIntTypeOfWidth(valueType, {1})) {
    Type newType = legalizeIntegerType(valueType);
    auto zero = spirv::ConstantOp::getZero(newType, loc, &rewriter);
    result = rewriter.create<spirv::INotEqualOp>(loc, valueType, result, zero)
                 .getResult();
  }
  return result;
}
/// Rewrites loads whose value contains an i1, i8, i16 or i64 integer:
/// - i64: load from the same pointer with the i32-legalized result type.
/// - i1/i8/i16: load the containing i32 word and extract the element (see
///   rewriteIntForLoadOp). i1 results are converted back to i1.
/// Loads that rewriteIntForLoadOp cannot handle are left untouched.
struct AdjustLoadOp : public OpRewritePattern<spirv::LoadOp> {
  using OpRewritePattern<spirv::LoadOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::LoadOp op,
                                PatternRewriter &rewriter) const override {
    Type valueType = op.value().getType();
    if (!hasSupportedIntegerType(valueType)) {
      return failure();
    }
    Value result;
    if (hasIntTypeOfWidth(valueType, {1, 8, 16})) {
      result = rewriteIntForLoadOp(op, rewriter);
      // Unsupported access pattern. No IR was created, so fail the match
      // instead of replacing the load with a null value.
      if (!result) return failure();
    } else {
      Type newType = legalizeIntegerType(valueType);
      auto loadOp = rewriter.create<spirv::LoadOp>(
          op.getLoc(), newType, op.ptr(),
          op.getAttrOfType<IntegerAttr>(
              spirv::attributeName<spirv::MemoryAccess>()),
          op.getAttrOfType<IntegerAttr>("alignment"));
      result = loadOp.getResult();
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
// Turns the value stored by `op` into an i32 and shifts it left by `offset`,
// so that it lands on its slot inside an i32 word.
// - An i1 value becomes 1 or 0 through a select; no masking is needed.
// - Any other value is SConvert-ed to i32 and masked with `mask`.
Value shiftStoreValue(spirv::StoreOp op, const Value &offset, const Value &mask,
                      PatternRewriter &rewriter) {
  const auto loc = op.getLoc();
  Type i32Type = rewriter.getIntegerType(32);
  Value word = op.value();
  if (!hasIntTypeOfWidth(op.value().getType(), {1})) {
    word = rewriter.create<spirv::SConvertOp>(loc, i32Type, word);
    word = rewriter.create<spirv::BitwiseAndOp>(loc, word, mask);
  } else {
    Value zero =
        spirv::ConstantOp::getZero(i32Type, loc, &rewriter).getResult();
    Value one = spirv::ConstantOp::getOne(i32Type, loc, &rewriter).getResult();
    word = rewriter.create<spirv::SelectOp>(loc, word, one, zero).getResult();
  }
  return rewriter.create<spirv::ShiftLeftLogicalOp>(loc, i32Type, word,
                                                    offset);
}
// Rewrites a store of an i1, i8 or i16 value into updates of the i32 word that
// contains the destination element. Other elements packed into the same word
// may be written concurrently by other threads, so two atomic operations are
// used instead of a plain load/modify/store:
// 1) an AtomicAnd with a mask that clears the destination bits, then
// 2) an AtomicOr with the new value shifted into the destination bits.
// Only scalars (one index) and 1-D tensors (two indices) are supported.
// Returns failure, without creating any IR, when the pointer is not produced
// by a spirv::AccessChainOp or the access chain shape is unsupported.
LogicalResult rewriteIntForStoreOp(spirv::StoreOp op,
                                   PatternRewriter &rewriter) {
  Type valueType = op.value().getType();
  Type i32Type = rewriter.getIntegerType(32);
  const auto loc = op.getLoc();
  // The pointer may be a block argument or come from another kind of op; bail
  // out in that case instead of asserting inside cast<>.
  auto accessChainOp =
      dyn_cast_or_null<spirv::AccessChainOp>(op.ptr().getDefiningOp());
  if (!accessChainOp) return failure();
  // The first index selects the element; any further index maps to the other
  // dimensions.
  auto numIndices = accessChainOp.indices().size();
  if (numIndices == 0 || numIndices > 2) return failure();
  // i1 values are stored as bytes.
  int bits = hasIntTypeOfWidth(valueType, {1, 8}) ? 8 : 16;
  auto offset = getOffsetOfInt(accessChainOp, bits, rewriter);
  // `mask` keeps the low `bits` bits. Shifting it to the destination and
  // inverting it yields the clear mask. E.g., for the second i8 in an i32 the
  // clear mask is 0xFFFF00FF.
  auto mask = rewriter.create<spirv::ConstantOp>(
      loc, i32Type, rewriter.getI32IntegerAttr((1 << bits) - 1));
  Value clearMask =
      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, i32Type, mask, offset);
  clearMask = rewriter.create<spirv::NotOp>(loc, i32Type, clearMask);
  Value storeVal = shiftStoreValue(op, offset, mask, rewriter);
  Value i32AccessChainOp =
      convertToI32AccessChain(accessChainOp, bits, rewriter);
  rewriter.create<spirv::AtomicAndOp>(loc, i32Type, i32AccessChainOp,
                                      spirv::Scope::Device,
                                      spirv::MemorySemantics::AcquireRelease,
                                      clearMask);
  rewriter.create<spirv::AtomicOrOp>(loc, i32Type, i32AccessChainOp,
                                     spirv::Scope::Device,
                                     spirv::MemorySemantics::AcquireRelease,
                                     storeVal);
  // The atomics above perform the store, and their results are not needed.
  // The original store has no results to replace, so it is simply erased.
  rewriter.eraseOp(op);
  return success();
}
/// Rewrites stores whose value contains an i1, i8, i16 or i64 integer:
/// - i1/i8/i16: atomic updates of the containing i32 word (see
///   rewriteIntForStoreOp). The pattern fails if that helper does.
/// - Otherwise: store the value after an SConvert to i32, keeping the
///   memory-access and alignment attributes.
struct AdjustStoreOp : public OpRewritePattern<spirv::StoreOp> {
  using OpRewritePattern<spirv::StoreOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::StoreOp op,
                                PatternRewriter &rewriter) const override {
    Type valueType = op.value().getType();
    if (!hasSupportedIntegerType(valueType)) return failure();
    if (hasIntTypeOfWidth(valueType, {1, 8, 16}))
      return rewriteIntForStoreOp(op, rewriter);

    Type i32Type = rewriter.getIntegerType(32);
    Value narrowed =
        rewriter.create<spirv::SConvertOp>(op.getLoc(), i32Type, op.value());
    auto memoryAccess = op.getAttrOfType<IntegerAttr>(
        spirv::attributeName<spirv::MemoryAccess>());
    auto alignment = op.getAttrOfType<IntegerAttr>("alignment");
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, op.ptr(), narrowed,
                                                memoryAccess, alignment);
    return success();
  }
};
/// Some Adjust* patterns produce SConvert ops whose source and result types
/// are identical. For example, AdjustConstantOp emits an SConvert from i32,
/// and AdjustSConvertOp then retargets it to i32. Such an SConvert is a no-op
/// and an invalid operation. It is removed by forwarding its operand directly
/// to all of its users, so no extra ops are created.
struct RemoveNopSConvertOp : public OpRewritePattern<spirv::SConvertOp> {
  using OpRewritePattern<spirv::SConvertOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::SConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.operand().getType() != op.result().getType()) return failure();
    rewriter.replaceOp(op, op.operand());
    return success();
  }
};
/// Rewrites SConvert ops whose result type is (or contains) an i8, i16 or i64
/// integer so that they produce i32 instead, keeping the same operand.
struct AdjustSConvertOp : public OpRewritePattern<spirv::SConvertOp> {
  using OpRewritePattern<spirv::SConvertOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::SConvertOp op,
                                PatternRewriter &rewriter) const override {
    bool needsWidening =
        hasIntTypeOfWidth(op.result().getType(), {8, 16, 64});
    if (!needsWidening) return failure();
    Type i32Type = rewriter.getIntegerType(32);
    rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, i32Type, op.operand());
    return success();
  }
};
/// Rewrites i8, i16 and i64 scalar integer constants into an i32 constant
/// followed by an SConvert back to the original type. That SConvert is then
/// handled by AdjustSConvertOp and RemoveNopSConvertOp.
struct AdjustConstantOp : public OpRewritePattern<spirv::ConstantOp> {
  using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::ConstantOp op,
                                PatternRewriter &rewriter) const override {
    Type constantType = op.getType();
    if (!hasIntTypeOfWidth(constantType, {8, 16, 64})) {
      return failure();
    }
    // Only constants holding an IntegerAttr are handled. Composite constants
    // (e.g. of !spv.array type, which hasIntTypeOfWidth also matches) are left
    // alone instead of reaching an "unreachable" path.
    auto attr = op.value().dyn_cast<IntegerAttr>();
    if (!attr) return failure();
    Type i32Type = rewriter.getIntegerType(32);
    auto i32Attr = IntegerAttr::get(i32Type, attr.getInt());
    Value i32cst =
        rewriter.create<spirv::ConstantOp>(op.getLoc(), i32Type, i32Attr);
    rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, constantType, i32cst);
    return success();
  }
};
/// Rewrites an integer arithmetic op whose result is (or contains) an i8, i16
/// or i64 integer so that it produces the i32-legalized type instead. The
/// operands and attributes are carried over unchanged.
template <typename OpTy>
struct AdjustIntegerArithmeticOperations : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Type resultType = op.result().getType();
    if (!hasIntTypeOfWidth(resultType, {8, 16, 64})) {
      return failure();
    }
    Type newType = legalizeIntegerType(resultType);
    ValueRange operands(op.getOperation()->getOperands());
    rewriter.replaceOpWithNewOp<OpTy>(op, newType, operands, op.getAttrs());
    return success();
  }
};
/// Rewrites spirv::SelectOp ops whose result is (or contains) an i8, i16 or i64
/// integer so that the result is i32.
///
/// i1 selects are deliberately left alone. This pass keeps boolean values as
/// i1: i1 loads produce i1 again, and i1 constants and arithmetic are not
/// widened. Widening only the result type would therefore create a select
/// whose result type disagrees with its i1 true/false operands. This matches
/// the widths handled by AdjustIntegerArithmeticOperations.
struct AdjustSelectOp : public OpRewritePattern<spirv::SelectOp> {
  using OpRewritePattern<spirv::SelectOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(spirv::SelectOp op,
                                PatternRewriter &rewriter) const override {
    Type resultType = op.result().getType();
    if (!hasIntTypeOfWidth(resultType, {8, 16, 64})) {
      return failure();
    }
    Type newType = legalizeIntegerType(resultType);
    ValueRange operands(op.getOperation()->getOperands());
    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, newType, operands,
                                                 op.getAttrs());
    return success();
  }
};
void AdjustIntegerWidthPass::runOnOperation() {
OwningRewritePatternList patterns;
patterns.insert<
// Arithmetic ops:
AdjustIntegerArithmeticOperations<spirv::GLSLSAbsOp>,
AdjustIntegerArithmeticOperations<spirv::GLSLSMaxOp>,
AdjustIntegerArithmeticOperations<spirv::GLSLSMinOp>,
AdjustIntegerArithmeticOperations<spirv::GLSLSSignOp>,
AdjustIntegerArithmeticOperations<spirv::IAddOp>,
AdjustIntegerArithmeticOperations<spirv::ISubOp>,
AdjustIntegerArithmeticOperations<spirv::IMulOp>,
AdjustIntegerArithmeticOperations<spirv::SDivOp>,
AdjustIntegerArithmeticOperations<spirv::SModOp>,
AdjustIntegerArithmeticOperations<spirv::SRemOp>,
AdjustIntegerArithmeticOperations<spirv::UDivOp>,
AdjustIntegerArithmeticOperations<spirv::UModOp>,
// Control flow ops:
AdjustSelectOp,
// Structure ops:
AdjustConstantOp,
// Others:
AdjustAccessChainOp, AdjustAddressOfOp, AdjustGlobalVariableWidth,
AdjustLoadOp, AdjustStoreOp, RemoveNopSConvertOp, AdjustSConvertOp>(
&getContext());
Operation *op = getOperation();
applyPatternsAndFoldGreedily(op->getRegions(), patterns);
}
// Registers AdjustIntegerWidthPass under the name
// "iree-spirv-adjust-integer-width".
static PassRegistration<AdjustIntegerWidthPass> adjustIntegerWidthRegistration(
    "iree-spirv-adjust-integer-width",
    "Adjust integer width from all integer types to i32 type");
} // namespace
/// Creates a new instance of the pass that legalizes SPIR-V integer widths to
/// i32. Ownership is transferred to the caller.
std::unique_ptr<Pass> createAdjustIntegerWidthPass() {
  std::unique_ptr<Pass> pass = std::make_unique<AdjustIntegerWidthPass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir