blob: 67e1d794b8026babdce5280b3c62da5e32d11f01 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Conversion/PassDetail.h"
#include "iree/compiler/Conversion/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace mlir {
namespace iree_compiler {
namespace {
/// Returns true if `contract` has a shape and element types that can be
/// mapped to a SPIR-V cooperative matrix multiply-add.
bool isLegalVectorContract(vector::ContractionOp contract) {
  // Masked contractions cannot be lowered.
  if (llvm::size(contract.masks()) != 0) return false;

  auto lhsType = contract.lhs().getType().cast<VectorType>();
  auto rhsType = contract.rhs().getType().cast<VectorType>();
  auto accType = contract.acc().getType().cast<VectorType>();

  // (M, N, K) sizes of the matmul, taken from the operand vector shapes.
  std::tuple<int, int, int> dims(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                 lhsType.getDimSize(1));

  // Check if the matrix type can be supported as a cooperative matrix.
  // Currently we have hardcoded checks for what Turing hardware supports.
  // TODO(thomasraoux): Add device information to be able to query what the
  // device supports.
  bool isInt8Case = lhsType.getElementType().isInteger(8) &&
                    rhsType.getElementType().isInteger(8) &&
                    accType.getElementType().isInteger(32);
  if (isInt8Case) {
    return dims == std::make_tuple(8, 8, 32) ||
           dims == std::make_tuple(16, 16, 32) ||
           dims == std::make_tuple(16, 8, 32);
  }

  bool isF16Case =
      lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32());
  if (isF16Case) {
    return dims == std::make_tuple(8, 8, 16) ||
           dims == std::make_tuple(16, 16, 16) ||
           dims == std::make_tuple(16, 8, 16);
  }

  return false;
}
/// Returns true if `op` can participate in a chain of operations lowered to
/// SPIR-V cooperative matrix operations.
bool supportsCooperativeMatrix(Operation *op) {
  // Transfers and the loop constructs carrying them through iterations are
  // always fine.
  if (isa<vector::TransferReadOp, vector::TransferWriteOp, scf::ForOp,
          scf::YieldOp>(op))
    return true;
  // Contractions are supported when their shape/element types map onto a
  // hardware cooperative matrix configuration.
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return isLegalVectorContract(contract);
  // We only support minimal set of operations right now. We can trivially
  // extend to ALU instructions supporting Cooperative Matrix in SPIR-V spec.
  // We also need to extend to control flow operations, Alloca, etc...
  // TODO(thomasraoux): extend support to more complex chain of instructions.
  return false;
}
/// Analysis computing the set of operations that should be lowered using
/// SPIR-V cooperative matrix types. Starting from each vector.contract, the
/// full use-def slice of vector-producing operations is inspected; the chain
/// is marked only if every operation in it supports cooperative matrices.
class CooperativeMatrixAnalysis {
 public:
  explicit CooperativeMatrixAnalysis(mlir::Operation *op) {
    auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(op));
    // Bail out when the target does not advertise the NV cooperative matrix
    // capability/extension; the analysis result stays empty.
    if (!targetEnv.allows(spirv::Capability::CooperativeMatrixNV) ||
        !targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix))
      return;

    op->walk([&](vector::ContractionOp contract) {
      // Slice filter: follow operations that produce a vector result (or no
      // result at all), but stop at constants and allocations.
      auto producesVector = [](Operation *candidate) {
        if (isa<ConstantOp, memref::AllocOp>(candidate)) return false;
        if (llvm::any_of(candidate->getResultTypes(),
                         [](Type t) { return t.isa<VectorType>(); }))
          return true;
        return candidate->getNumResults() == 0;
      };
      auto dependentOps =
          getSlice(contract.getOperation(), producesVector, producesVector);
      // If any instruction cannot use cooperative matrix, drop the whole
      // chain. In the future we can introduce a "bitcast" type of conversion
      // to allow the same value to be used both as a cooperative matrix and
      // as an array.
      if (!llvm::all_of(dependentOps, supportsCooperativeMatrix)) return;
      // All the dependent instructions can use the cooperative matrix type;
      // mark the whole chain of operations.
      usesCooperativeMatrix.insert(contract.getOperation());
      usesCooperativeMatrix.insert(dependentOps.begin(), dependentOps.end());
    });
  }

  /// Returns true if the operation should be lowered using operations on
  /// cooperative matrix type.
  bool usesCooperativeMatrixType(mlir::Operation *op) const {
    return usesCooperativeMatrix.count(op);
  }

 private:
  // Operations determined to belong to a cooperative-matrix-compatible chain.
  llvm::DenseSet<mlir::Operation *> usesCooperativeMatrix;
};
/// Convert subgroup level vector transfer to SPIR-V cooperative
/// matrix load/store if those are supported.
/// TODO(thomasraoux): Move to MLIR core once this is stable.
template <typename OpTy>
class TransferToCoopMatLoadStore final : public OpConversionPattern<OpTy> {
 public:
  TransferToCoopMatLoadStore(
      MLIRContext *context, SPIRVTypeConverter &converter,
      const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis)
      : OpConversionPattern<OpTy>(converter, context),
        cooperativeMatrixAnalysis(cooperativeMatrixAnalysis) {}

  /// Matches an unmasked, in-bounds, rank-2, minor-identity transfer on a
  /// statically-strided memref and rewrites it to a cooperative matrix
  /// load/store (dispatched through replaceTransferOp below).
  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite ops the analysis marked as part of a cooperative matrix
    // compatible chain.
    if (!cooperativeMatrixAnalysis.usesCooperativeMatrixType(op))
      return failure();
    // Masked transfers have no cooperative matrix equivalent.
    if (op.mask()) return failure();
    // Only memref sources are handled (not tensors).
    auto memrefType = op.getShapedType().template dyn_cast<MemRefType>();
    if (!memrefType) return failure();
    // Cooperative matrices are 2-D; bail on any other rank.
    auto vecType = op.getVectorType();
    if (vecType.getRank() != 2) return failure();
    // TODO(thomasraoux): use column major operand when TransferRead +
    // TransposeOp.
    if (!op.permutation_map().isMinorIdentity()) return failure();
    // Every dimension must be in-bounds: the load/store reads the full tile.
    if (op.in_bounds() &&
        llvm::any_of(op.in_bounds()->template cast<ArrayAttr>(),
                     [](mlir::Attribute dimInBounds) {
                       return !dimInBounds.cast<BoolAttr>().getValue();
                     }))
      return failure();
    // The row stride must be a compile-time constant to emit it as an
    // spv.constant operand.
    int64_t offset = 0;
    SmallVector<int64_t, 2> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return failure();
    auto stride = strides[0];
    if (BaseMemRefType::isDynamicStrideOrOffset(stride)) return failure();
    auto loc = op.getLoc();
    typename OpTy::Adaptor adaptor(operands, op->getAttrDictionary());
    // Cooperative matrix type matching the transfer's vector shape.
    auto matType = spirv::CooperativeMatrixNVType::get(
        vecType.getElementType(), spirv::Scope::Subgroup, vecType.getDimSize(0),
        vecType.getDimSize(1));
    // SPIR-V pointer to the first accessed element.
    Value ptr = spirv::getElementPtr(
        *this->template getTypeConverter<SPIRVTypeConverter>(), memrefType,
        adaptor.source(), adaptor.indices(), loc, rewriter);
    auto int32Type = rewriter.getI32Type();
    auto strideValue = rewriter.create<spirv::ConstantOp>(
        loc, int32Type, IntegerAttr::get(int32Type, stride));
    // Row-major only for now: the column-major flag is hardcoded to false.
    auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
    replaceTransferOp(op, adaptor, matType, ptr, strideValue, coloumnMajor,
                      rewriter);
    return success();
  }

 private:
  /// Generates the right load/store instruction and replaces the transfer op.
  /// Specialized below for vector.transfer_read and vector.transfer_write.
  void replaceTransferOp(OpTy originalOp, typename OpTy::Adaptor newInputs,
                         Type matType, Value bufferPtr, Value strideValue,
                         Value coloumnMajor,
                         ConversionPatternRewriter &rewriter) const;

  const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis;
};
/// Specialization for vector.transfer_read: replace the read with a
/// cooperative matrix load producing `matType`.
template <>
void TransferToCoopMatLoadStore<vector::TransferReadOp>::replaceTransferOp(
    vector::TransferReadOp originalOp,
    vector::TransferReadOp::Adaptor newInputs, Type matType, Value bufferPtr,
    Value strideValue, Value coloumnMajor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<spirv::CooperativeMatrixLoadNVOp>(
      originalOp, matType, bufferPtr, strideValue, coloumnMajor,
      spirv::MemoryAccessAttr());
}
/// Specialization for vector.transfer_write: emit a cooperative matrix store
/// of the converted vector value. The write has no results, so the original
/// op is simply erased rather than replaced.
template <>
void TransferToCoopMatLoadStore<vector::TransferWriteOp>::replaceTransferOp(
    vector::TransferWriteOp originalOp,
    vector::TransferWriteOp::Adaptor newInputs, Type matType, Value bufferPtr,
    Value strideValue, Value coloumnMajor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.create<spirv::CooperativeMatrixStoreNVOp>(
      originalOp.getLoc(), bufferPtr, newInputs.vector(), strideValue,
      coloumnMajor, spirv::MemoryAccessAttr());
  rewriter.eraseOp(originalOp);
}
/// Converts subgroup level vector contract to SPIR-V cooperative
/// matrix matmuladd.
class VectorContractToCoopMatmul final
    : public OpConversionPattern<vector::ContractionOp> {
 public:
  VectorContractToCoopMatmul(
      MLIRContext *context, SPIRVTypeConverter &converter,
      const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis)
      : OpConversionPattern(converter, context),
        cooperativeMatrixAnalysis(cooperativeMatrixAnalysis) {}

  /// Rewrites a row-major matmul-shaped vector.contract into a cooperative
  /// matrix multiply-add on the already-converted operands.
  LogicalResult matchAndRewrite(
      vector::ContractionOp contractOp, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite contracts the analysis marked for cooperative matrix use.
    if (!cooperativeMatrixAnalysis.usesCooperativeMatrixType(contractOp))
      return failure();
    // Masked contractions are not handled.
    if (llvm::size(contractOp.masks()) != 0) return failure();

    // Check that this is a matmul operation: two parallel dimensions
    // followed by a single reduction dimension.
    auto iterators = contractOp.iterator_types().getValue();
    bool isMatmul = isParallelIterator(iterators[0]) &&
                    isParallelIterator(iterators[1]) &&
                    isReductionIterator(iterators[2]);
    if (!isMatmul) return failure();

    // Column major matmul should have been lowered to Transpose+contract
    // by this point. Transpose can be handled by load/store operations.
    if (!isRowMajorMatmul(contractOp.indexing_maps())) return failure();

    vector::ContractionOp::Adaptor adaptor(operands);
    Value matA = adaptor.lhs();
    Value matB = adaptor.rhs();
    Value matC = adaptor.acc();
    // The result type matches the accumulator's cooperative matrix type.
    rewriter.replaceOpWithNewOp<spirv::CooperativeMatrixMulAddNVOp>(
        contractOp, matC.getType(), matA, matB, matC);
    return success();
  }

 private:
  const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis;
};
/// Pass lowering subgroup-level vector transfer/contract operations to SPIR-V
/// cooperative matrix load/store/muladd operations where the analysis permits.
struct LinalgToSPIRVVectorToCooperativeMatrixPass final
    : public LinalgToSPIRVVectorToCooperativeMatrixBase<
          LinalgToSPIRVVectorToCooperativeMatrixPass> {
  void runOnOperation() override {
    FuncOp funcOp = getOperation();
    auto targetAttr = spirv::lookupTargetEnv(funcOp);
    SPIRVTypeConverter typeConverter(targetAttr);
    // Flatten static-shaped memrefs to rank-1 dynamic memrefs, then let the
    // default SPIR-V converter handle the flattened type.
    typeConverter.addConversion(
        [&typeConverter](MemRefType type) -> Optional<Type> {
          if (!type.hasStaticShape()) return llvm::None;
          auto flattenType =
              MemRefType::get(ShapedType::kDynamicSize, type.getElementType(),
                              type.getAffineMaps(), type.getMemorySpace());
          return typeConverter.convertType(flattenType);
        });
    // Map rank-2 vector types to subgroup-scope cooperative matrix types;
    // other ranks fall through to the default conversion.
    typeConverter.addConversion([](VectorType type) -> Optional<Type> {
      if (type.getRank() != 2) return llvm::None;
      return spirv::CooperativeMatrixNVType::get(
          type.getElementType(), spirv::Scope::Subgroup, type.getDimSize(0),
          type.getDimSize(1));
    });
    // Bridge remaining type mismatches between converted and unconverted ops
    // with unrealized_conversion_cast ops (marked legal below).
    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) {
      auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
      return Optional<Value>(cast.getResult(0));
    };
    typeConverter.addSourceMaterialization(addUnrealizedCast);
    typeConverter.addTargetMaterialization(addUnrealizedCast);
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    auto &analysis = getAnalysis<CooperativeMatrixAnalysis>();
    patterns.add<TransferToCoopMatLoadStore<vector::TransferReadOp>,
                 TransferToCoopMatLoadStore<vector::TransferWriteOp>,
                 VectorContractToCoopMatmul>(context, typeConverter, analysis);
    std::unique_ptr<ConversionTarget> target =
        SPIRVConversionTarget::get(targetAttr);
    target->addLegalOp<UnrealizedConversionCastOp>();
    // Partial conversion: ops not marked by the analysis are left untouched.
    if (failed(applyPartialConversion(funcOp, *target, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace
/// Factory creating the vector-to-cooperative-matrix conversion pass.
std::unique_ptr<OperationPass<FuncOp>>
createLinalgToSPIRVVectorToCooperativeMatrixPass() {
  auto pass =
      std::make_unique<LinalgToSPIRVVectorToCooperativeMatrixPass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir