blob: ce24817867471a5a888ae867204c95e495eafdc3 [file]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace mlir {
namespace iree_compiler {
namespace {
//===----------------------------------------------------------------------===//
// Op Conversion Patterns
//===----------------------------------------------------------------------===//
/// Converts vector transfer ops to SPIR-V cooperative matrix load/store ops.
///
/// Only unmasked, fully in-bounds vector.transfer_read / vector.transfer_write
/// ops are converted. Each must move a 2-D vector to or from a memref through
/// a minor identity permutation map. The memref's innermost dimension must
/// have unit stride. The second innermost dimension must have a static stride;
/// that stride is used as the row stride of the cooperative matrix load/store.
struct ConvertVectorTransferOp final
    : public OpInterfaceConversionPattern<VectorTransferOpInterface> {
  using OpInterfaceConversionPattern::OpInterfaceConversionPattern;

  LogicalResult matchAndRewrite(
      VectorTransferOpInterface op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only transfer_read/transfer_write are handled. Check this up front so
    // that no IR is created before the pattern decides to bail out.
    bool isRead = isa<vector::TransferReadOp>(op.getOperation());
    if (!isRead && !isa<vector::TransferWriteOp>(op.getOperation()))
      return failure();

    // Don't support masked load/store.
    if (op.getMaskType()) return failure();

    // Expect inbound access.
    if (op.in_bounds()) {
      auto inBounds = op.in_bounds()->getAsValueRange<BoolAttr>();
      if (!llvm::all_of(inBounds, [](bool v) { return v; })) return failure();
    }

    // Expect transfers over memrefs.
    auto memrefType = op.getShapedType().dyn_cast<MemRefType>();
    if (!memrefType) return failure();

    // Expect 2-D vectors.
    auto vectorType = op.getVectorType();
    if (vectorType.getRank() != 2) return failure();

    // TODO: Use column major with transposed transfer ops.
    if (!op.permutation_map().isMinorIdentity()) return failure();

    int64_t offset = 0;
    SmallVector<int64_t, 4> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return failure();
    // The minor identity map places the two vector dimensions on the two
    // innermost memref dimensions, so at least two strides are required.
    if (strides.size() < 2) return failure();
    // Elements within one matrix row are accessed contiguously, so the
    // innermost memref dimension must have unit stride (this also rejects a
    // dynamic innermost stride).
    if (strides.back() != 1) return failure();
    // The row stride is the stride of the second innermost memref dimension.
    // Using strides[0] would be wrong for memrefs of rank > 2.
    int64_t stride = strides[strides.size() - 2];
    if (ShapedType::isDynamicStrideOrOffset(stride)) return failure();

    // Resolve the cooperative matrix type before creating any IR.
    Type matType = typeConverter->convertType(vectorType);
    if (!matType) return failure();

    Location loc = op.getLoc();
    auto i32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, i32Type, IntegerAttr::get(i32Type, stride));
    // Row major layout is always used for now (see the TODO above).
    auto columnMajor = rewriter.create<spirv::ConstantOp>(
        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));

    if (isRead) {
      vector::TransferReadOp::Adaptor adaptor(operands,
                                              op->getAttrDictionary());
      Value bufferPtr = spirv::getElementPtr(
          *getTypeConverter<SPIRVTypeConverter>(), memrefType,
          adaptor.getSource(), adaptor.getIndices(), loc, rewriter);
      rewriter.replaceOpWithNewOp<spirv::CooperativeMatrixLoadNVOp>(
          op, matType, bufferPtr, strideValue, columnMajor,
          spirv::MemoryAccessAttr());
      return success();
    }

    // The isa checks at the top guarantee this is a vector.transfer_write.
    vector::TransferWriteOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Value bufferPtr = spirv::getElementPtr(
        *getTypeConverter<SPIRVTypeConverter>(), memrefType,
        adaptor.getSource(), adaptor.getIndices(), loc, rewriter);
    rewriter.create<spirv::CooperativeMatrixStoreNVOp>(
        loc, bufferPtr, adaptor.getVector(), strideValue, columnMajor,
        spirv::MemoryAccessAttr());
    rewriter.eraseOp(op);
    return success();
  }
};
/// Converts vector.contract ops to SPIR-V cooperative matrix multiply-add ops.
///
/// Only unmasked contractions are converted. They must have exactly a
/// (parallel, parallel, reduction) iteration space, additive combining, and
/// row-major matmul indexing maps. The converted accumulator's type becomes
/// the result type.
struct ConvertVectorContractOp final
    : public OpConversionPattern<vector::ContractionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      vector::ContractionOp contractOp, OpAdaptor operands,
      ConversionPatternRewriter &rewriter) const override {
    // Masked contractions have no cooperative matrix counterpart.
    if (!llvm::empty(contractOp.getMasks())) return failure();

    // Require a matmul-shaped iteration space: (parallel, parallel, reduction).
    auto iterators = contractOp.getIteratorTypes().getValue();
    bool isMatmulLike = iterators.size() == 3 &&
                        isParallelIterator(iterators[0]) &&
                        isParallelIterator(iterators[1]) &&
                        isReductionIterator(iterators[2]);
    if (!isMatmulLike) return failure();

    // Products must be accumulated with addition.
    if (contractOp.getKind() != vector::CombiningKind::ADD) return failure();

    // Column major matmuls are expected to have been lowered to transpose +
    // contract already; transposes can be folded into load/store ops.
    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr())) return failure();

    Value lhs = operands.getLhs();
    Value rhs = operands.getRhs();
    Value acc = operands.getAcc();
    rewriter.replaceOpWithNewOp<spirv::CooperativeMatrixMulAddNVOp>(
        contractOp, acc.getType(), lhs, rhs, acc);
    return success();
  }
};
/// Converts splat vector constants to constant SPIR-V cooperative matrix ops.
///
/// Only 2-D splat integer/float vector constants are handled. The splat scalar
/// is materialized with spirv.Constant and then broadcast into the cooperative
/// matrix with spirv.CompositeConstruct.
struct ConvertConstantMatrix final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      arith::ConstantOp op, OpAdaptor operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only convert 2-D vector constants.
    auto vectorType = op.getType().dyn_cast<VectorType>();
    if (!vectorType || vectorType.getRank() != 2) return failure();

    // Only convert splat integer/float vectors.
    auto values = op.getValue().dyn_cast<DenseIntOrFPElementsAttr>();
    if (!values || !values.isSplat()) return failure();

    // Resolve both target types before creating any IR. A failed conversion
    // then makes the pattern fail cleanly instead of building ops with null
    // types.
    Type elementType = values.getType().getElementType();
    Type spirvElementType = typeConverter->convertType(elementType);
    Type matType = typeConverter->convertType(vectorType);
    if (!spirvElementType || !matType) return failure();
    // The splat attribute keeps the original element type. If the converter
    // changed that type, spirv.Constant would get an attribute whose type
    // does not match its result type.
    if (spirvElementType != elementType) return failure();

    Attribute value = values.getSplatValue<Attribute>();
    Value splatValue =
        rewriter.create<spirv::ConstantOp>(op.getLoc(), spirvElementType, value);
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, matType,
                                                             splatValue);
    return success();
  }
};
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
///
/// The pattern applies only to single-result ops whose converted operands are
/// all cooperative matrices. A `DstOpType` is created with the converted
/// result type and the converted operands in their original order.
template <typename SrcOpType, typename DstOpType>
struct ConvertElementwiseOp final : public OpConversionPattern<SrcOpType> {
  using OpConversionPattern<SrcOpType>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SrcOpType op, typename SrcOpType::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ops with one result.
    if (op->getNumResults() != 1) return failure();

    // All operands should be of cooperative matrix types.
    bool allCoopMatrices =
        llvm::all_of(adaptor.getOperands(), [](Value operand) {
          return operand.getType().isa<spirv::CooperativeMatrixNVType>();
        });
    if (!allCoopMatrices) return failure();

    // Bail out instead of creating an op with a null result type.
    auto matType = this->typeConverter->convertType(op.getType());
    if (!matType) return failure();

    rewriter.replaceOpWithNewOp<DstOpType>(op, matType, adaptor.getOperands());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Main Pass
//===----------------------------------------------------------------------===//
/// Converts 2-D vector ops in a function into SPIR-V cooperative matrix ops.
///
/// The converted ops are transfers, contractions, splat constants and a subset
/// of elementwise arithmetic ops. Everything else is left for later stages.
/// Unrealized conversion casts bridge converted and unconverted values.
struct SPIRVVectorToCooperativeOpsPass final
    : public SPIRVVectorToCooperativeOpsBase<SPIRVVectorToCooperativeOpsPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    func::FuncOp funcOp = getOperation();

    spirv::TargetEnvAttr targetAttr = getSPIRVTargetEnvAttr(funcOp);
    // Both the type converter and the conversion target are built from the
    // target environment. Fail the pass rather than using a null attribute.
    if (!targetAttr) {
      funcOp.emitError("missing SPIR-V target environment attribute");
      return signalPassFailure();
    }
    SPIRVTypeConverter typeConverter(targetAttr);

    // Inject conversion rules for 2-D vector types to cooperative matrix types.
    //
    // Note that we don't perform a legality check here; we just directly
    // convert. The legality check is expected to happen when deciding whether
    // the whole pass pipeline is feasible, and also in the SPIR-V
    // ConversionTarget.
    typeConverter.addConversion(
        [&typeConverter](VectorType type) -> Optional<Type> {
          if (type.getRank() != 2) return llvm::None;
          Type elementType = typeConverter.convertType(type.getElementType());
          // Never build a cooperative matrix type with a null element type;
          // defer to the remaining conversion rules instead.
          if (!elementType) return llvm::None;
          return spirv::CooperativeMatrixNVType::get(
              elementType, spirv::Scope::Subgroup, type.getDimSize(0),
              type.getDimSize(1));
        });

    // Inject another conversion rule for MemRef types.
    //
    // This is for consistency: FlattenMemRefSubspanPass runs later and
    // flattens all MemRefs into 1-D unknown-sized ones before invoking the
    // upstream SPIR-V type converter. So in the end all MemRefs become SPIR-V
    // runtime arrays. Without the following rule, MemRefs here would be
    // converted into constant-sized arrays, which would be inconsistent.
    // Having this rule is unfortunate. It exists because cooperative matrix
    // conversion must run before FlattenMemRefSubspanPass (it needs 2-D
    // MemRefs), and because the conversions are spread across the upstream
    // and IREE repos.
    typeConverter.addConversion(
        [&typeConverter](MemRefType type) -> Optional<Type> {
          if (!type.hasStaticShape()) return llvm::None;
          // In IREE all MemRefs originate from subspan ops, which should
          // have identity layout.
          if (!type.getLayout().isIdentity()) return llvm::None;
          auto flattenedType =
              MemRefType::get(ShapedType::kDynamicSize, type.getElementType(),
                              AffineMap(), type.getMemorySpace());
          return typeConverter.convertType(flattenedType);
        });

    // Add unrealized conversion cast ops to bridge type conversions. Only the
    // cooperative matrix subset is converted here; the rest is done at a
    // later stage.
    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) {
      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
      return Optional<Value>(cast.getResult(0));
    };
    typeConverter.addSourceMaterialization(addUnrealizedCast);
    typeConverter.addTargetMaterialization(addUnrealizedCast);

    RewritePatternSet patterns(context);
    patterns.add<
        ConvertConstantMatrix, ConvertVectorContractOp, ConvertVectorTransferOp,
        // See SPV_NV_cooperative_matrix for supported element wise ops.
        ConvertElementwiseOp<arith::AddFOp, spirv::FAddOp>,
        ConvertElementwiseOp<arith::AddIOp, spirv::IAddOp>,
        ConvertElementwiseOp<arith::SubFOp, spirv::FSubOp>,
        ConvertElementwiseOp<arith::SubIOp, spirv::ISubOp>,
        ConvertElementwiseOp<arith::DivFOp, spirv::FDivOp>,
        ConvertElementwiseOp<arith::DivSIOp, spirv::SDivOp>,
        ConvertElementwiseOp<arith::DivUIOp, spirv::UDivOp>,
        ConvertElementwiseOp<arith::NegFOp, spirv::FNegateOp>>(typeConverter,
                                                               context);

    std::unique_ptr<ConversionTarget> target =
        SPIRVConversionTarget::get(targetAttr);
    target->addLegalOp<UnrealizedConversionCastOp>();
    target->addIllegalDialect<vector::VectorDialect>();

    if (failed(applyPartialConversion(funcOp, *target, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace
/// Creates the function pass that rewrites 2-D vector ops into SPIR-V
/// cooperative matrix ops.
std::unique_ptr<OperationPass<func::FuncOp>>
createSPIRVVectorToCooperativeOpsPass() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<SPIRVVectorToCooperativeOpsPass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir