// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Utils/CustomKernelsTargetInfo.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Returns true if `contractionOp` is of the form
// matrix * transposed_matrix.
// That is, if there are 2 parallel iterators, say M and N, 1 additive reduction
// iterator, say K, and the indexing maps are {{M, K}, {N, K}, {M, N}}.
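// For example (illustrative IR, not taken from an actual input), with
// iterator_types = ["parallel", "parallel", "reduction"] this corresponds to
// the indexing maps
//   affine_map<(M, N, K) -> (M, K)>,  // LHS
//   affine_map<(M, N, K) -> (N, K)>,  // RHS (the transposed matrix)
//   affine_map<(M, N, K) -> (M, N)>   // accumulator / result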
static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) {
// Check that the reduction is additive.
if (contractionOp.kind() != vector::CombiningKind::ADD) {
return false;
}
// Check that there are 2 parallel and 1 reduction iterators.
auto iteratorTypes = contractionOp.iterator_types().getValue();
if (iteratorTypes.size() != 3) {
return false;
}
SmallVector<int, 3> parallel_iterators;
SmallVector<int, 3> reduction_iterators;
for (int i = 0; i < 3; i++) {
if (isParallelIterator(iteratorTypes[i])) {
parallel_iterators.push_back(i);
} else if (isReductionIterator(iteratorTypes[i])) {
reduction_iterators.push_back(i);
} else {
return false;
}
}
if (parallel_iterators.size() != 2 || reduction_iterators.size() != 1) {
return false;
}
// Give the found iterators some idiomatic names.
const int MIter = parallel_iterators[0];
const int NIter = parallel_iterators[1];
const int KIter = reduction_iterators[0];
// Check that there are 3 indexing maps.
auto indexingMaps = contractionOp.indexing_maps().getValue();
if (indexingMaps.size() != 3) {
return false;
}
// Check that the indexing maps have the expected form.
const int expectedMapResults[3][2] = {
{MIter, KIter}, {NIter, KIter}, {MIter, NIter}};
for (int m = 0; m < 3; ++m) {
auto map = indexingMaps[m].cast<AffineMapAttr>().getValue();
if (map.getNumDims() != 3 || map.getNumResults() != 2) {
return false;
}
for (int r = 0; r < 2; ++r) {
int actualMapResult =
map.getResults()[r].cast<AffineDimExpr>().getPosition();
if (actualMapResult != expectedMapResults[m][r]) {
return false;
}
}
}
return true;
}
// Returns true if `contractionOp` is of the form
// matrix * transposed_matrix,
// where matrix is a vector<{mSize}x{kSize}xType> and
// transposed_matrix is a vector<{nSize}x{kSize}xType>.
static bool isMatrixTimesMatrixTransposedOfGivenShape(
vector::ContractionOp contractionOp, int64_t mSize, int64_t kSize,
int64_t nSize) {
if (!isMatrixTimesMatrixTransposed(contractionOp)) {
return false;
}
VectorType lhsType = contractionOp.lhs().getType().cast<VectorType>();
VectorType rhsType = contractionOp.rhs().getType().cast<VectorType>();
auto lhsShape = lhsType.getShape();
auto rhsShape = rhsType.getShape();
if (lhsShape[0] != mSize || lhsShape[1] != kSize || rhsShape[0] != nSize ||
rhsShape[1] != kSize) {
return false;
}
return true;
}
// Checks that the Value `extResult` is defined by an arith::ExtSIOp promoting
// from `extSrcType` to `extDstType`, and returns the input of that ExtSIOp,
// or nullptr if the pattern does not match.
// Note that this only inspects the immediately defining operation, so earlier
// passes should sink widening operations as close as possible to their uses,
// which is generally beneficial anyway.
static Value getExtSIInput(Type extSrcType, Type extDstType, Value extResult) {
  auto extSIOp = extResult.getDefiningOp<arith::ExtSIOp>();
  if (!extSIOp) {
    return nullptr;
  }
  // Check that the promotion is exactly extSrcType -> extDstType.
  if (extResult.getType().cast<VectorType>().getElementType() != extDstType) {
    return nullptr;
  }
  Value extInput = extSIOp.getIn();
  if (extInput.getType().cast<VectorType>().getElementType() != extSrcType) {
    return nullptr;
  }
  return extInput;
}
// Helper to create a 1D, contiguous slice of a 1D vector.
static Value extract1DSlice(PatternRewriter &rewriter, Location loc,
VectorType dstVecType, Value input, int position) {
assert(input.getType().cast<VectorType>().getRank() == 1);
assert(dstVecType.getRank() == 1);
std::array<int64_t, 1> offsets{position};
std::array<int64_t, 1> strides{1};
return rewriter.create<vector::ExtractStridedSliceOp>(
loc, input, offsets, dstVecType.getShape(), strides);
}
// Helper to flatten a N-dimensional vector to a 1D vector.
static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) {
VectorType inputVecType = vector.getType().cast<VectorType>();
VectorType dstType = VectorType::get(inputVecType.getNumElements(),
inputVecType.getElementType());
return rewriter.create<vector::ShapeCastOp>(loc, dstType, vector);
}
/// Converts matrix-times-matrix-transposed vector.contracts whose
/// lhs and rhs inputs are defined by arith.extsi promoting from i8 to i32,
///
///     %lhs_i32 = arith.extsi %lhs_i8 : vector<8x4xi8> to vector<8x4xi32>
///     %rhs_i32 = arith.extsi %rhs_i8 : vector<8x4xi8> to vector<8x4xi32>
///     %result = vector.contract [...]
///                 %lhs_i32 : vector<8x4xi32>,
///                 %rhs_i32 : vector<8x4xi32>,
///                 %acc_i32 : vector<8x8xi32>,
///                 [...]
///
/// into vector ops that read directly from the %lhs_i8 and %rhs_i8 values
/// (bypassing the existing arith.extsi) and feed them to a llvm.inline_asm
/// block implementing the matrix multiplication arithmetic using AArch64
/// dot-product instructions (sdot).
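/// The 8x4x8 shape maps directly onto 128-bit NEON registers: each 8x4 i8
/// input is 32 bytes (two 16-byte vectors) and the 8x8 i32 accumulator is
/// 256 bytes (sixteen 4xi32 vectors), which is exactly the set of register
/// operands used by the inline asm block below.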
struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm
: OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractionOp,
PatternRewriter &rewriter) const override {
// Check if `contractionOp` matches, and obtain the un-promoted i8 input
// LHS and RHS vectors, `lhsI8` and `rhsI8`.
if (!isMatrixTimesMatrixTransposedOfGivenShape(contractionOp, 8, 4, 8)) {
return failure();
}
Type I8Type = rewriter.getIntegerType(8);
Type I32Type = rewriter.getIntegerType(32);
VectorType accType = contractionOp.acc().getType().cast<VectorType>();
if (accType.getElementType() != I32Type) {
return failure();
}
Value lhsI8 = getExtSIInput(I8Type, I32Type, contractionOp.lhs());
Value rhsI8 = getExtSIInput(I8Type, I32Type, contractionOp.rhs());
if (!lhsI8 || !rhsI8) {
return failure();
}
    // `contractionOp` matches, start rewriting it. We only reference the
    // `lhsI8` and `rhsI8` values obtained above as the inputs of the
    // arith.extsi ops, so this rewrite leaves those arith.extsi ops without
    // any users (unless something else was using them), allowing a later
    // transformation (e.g. dead-code elimination) to remove them.
Location loc = contractionOp.getLoc();
// Flatten the inputs to 1D vectors.
Value flatLhsI8 = flatten(rewriter, loc, lhsI8);
Value flatRhsI8 = flatten(rewriter, loc, rhsI8);
Value flatAcc = flatten(rewriter, loc, contractionOp.acc());
    // Extract the 16-byte (128-bit) 1D input chunks that the target SIMD
    // instructions consume directly.
SmallVector<Value> lhsVec;
SmallVector<Value> rhsVec;
VectorType vector16xi8Type = VectorType::get({16}, I8Type);
for (int position = 0; position < 8 * 4; position += 16) {
lhsVec.push_back(
extract1DSlice(rewriter, loc, vector16xi8Type, flatLhsI8, position));
rhsVec.push_back(
extract1DSlice(rewriter, loc, vector16xi8Type, flatRhsI8, position));
}
SmallVector<Value> accVec;
VectorType int32x4Type = VectorType::get({4}, I32Type);
for (int position = 0; position < 8 * 8; position += 4) {
accVec.push_back(
extract1DSlice(rewriter, loc, int32x4Type, flatAcc, position));
}
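    // Both the input and accumulator chunks are in row-major order:
    // lhsVec[i] / rhsVec[i] hold rows 4*i..4*i+3 of the respective 8x4 matrix,
    // and accVec[j] holds acc[j/2][4*(j%2) .. 4*(j%2)+3].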
// Create the inline asm op's operands list.
SmallVector<Value> asmOperands;
    // First the input operands.
asmOperands.append(lhsVec);
asmOperands.append(rhsVec);
// Then the input-output operands.
asmOperands.append(accVec);
SmallVector<Type> asmOutputOperandTypes(
llvm::map_range(accVec, [](Value v) { return v.getType(); }));
// Create the inline asm op.
auto returnType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
asmOutputOperandTypes);
auto dialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
// The LLVM inline asm syntax is documented here:
// https://llvm.org/docs/LangRef.html#inline-assembler-expressions
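    // Operand numbering in the asm string: $0..$15 are the sixteen 4xi32
    // accumulator chunks (outputs, tied to the accVec operands via the
    // "0".."15" constraints), $16 and $17 are the two 16-byte LHS chunks, and
    // $18 and $19 are the two 16-byte RHS chunks. Each
    //   sdot $acc.4s, $rhs.16b, $lhs.4b[lane]
    // accumulates into every 32-bit lane i of $acc the dot product of the i-th
    // 4-byte group of $rhs (one RHS row) with the selected 4-byte group of
    // $lhs (one LHS row), i.e. it computes one 1x4 chunk of the 8x8 result.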
LLVM::InlineAsmOp asmOp = rewriter.create<LLVM::InlineAsmOp>(
loc, returnType, asmOperands,
R"ASM(
sdot $0.4s, $18.16b, $16.4b[0]
sdot $1.4s, $19.16b, $16.4b[0]
sdot $2.4s, $18.16b, $16.4b[1]
sdot $3.4s, $19.16b, $16.4b[1]
sdot $4.4s, $18.16b, $16.4b[2]
sdot $5.4s, $19.16b, $16.4b[2]
sdot $6.4s, $18.16b, $16.4b[3]
sdot $7.4s, $19.16b, $16.4b[3]
sdot $8.4s, $18.16b, $17.4b[0]
sdot $9.4s, $19.16b, $17.4b[0]
sdot $10.4s, $18.16b, $17.4b[1]
sdot $11.4s, $19.16b, $17.4b[1]
sdot $12.4s, $18.16b, $17.4b[2]
sdot $13.4s, $19.16b, $17.4b[2]
sdot $14.4s, $18.16b, $17.4b[3]
sdot $15.4s, $19.16b, $17.4b[3]
)ASM",
"=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,w,w,w,w,0,1,2,3,4,5,6,"
"7,8,9,10,11,12,13,14,15",
/*has_side_effects=*/false, /*is_align_stack=*/false, dialectAttr,
/*operand_attrs=*/ArrayAttr());
// Extract result vectors from the asm op.
SmallVector<Value, 16> resVec;
for (int i = 0; i < 16; ++i) {
resVec.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, int32x4Type, asmOp.getRes(), rewriter.getI64ArrayAttr({i})));
}
// Insert the result vectors of size 4 into the overall result vector of
// size 64, still 1D.
VectorType int32x64xType = VectorType::get({64}, I32Type);
Value result = rewriter.create<arith::ConstantOp>(
loc, int32x64xType, DenseIntElementsAttr::get(int32x64xType, 0));
for (int i = 0; i < 16; ++i) {
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, resVec[i], result, std::array<int64_t, 1>{4 * i},
std::array<int64_t, 1>{1});
}
// Cast the result from 1D to 2D and replace the original vector.contract.
VectorType int32x8x8xType = VectorType::get({8, 8}, I32Type);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(contractionOp,
int32x8x8xType, result);
return success();
}
};
/// Converts the same matrix-times-matrix-transposed vector.contracts as
/// MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm above:
///
///     %lhs_i32 = arith.extsi %lhs_i8 : vector<8x4xi8> to vector<8x4xi32>
///     %rhs_i32 = arith.extsi %rhs_i8 : vector<8x4xi8> to vector<8x4xi32>
///     %result = vector.contract [...]
///                 %lhs_i32 : vector<8x4xi32>,
///                 %rhs_i32 : vector<8x4xi32>,
///                 %acc_i32 : vector<8x8xi32>,
///                 [...]
///
/// into vector ops that read directly from the %lhs_i8 and %rhs_i8 values
/// (bypassing the existing arith.extsi), but lowers the arithmetic to ArmNeon
/// dialect dot-product intrinsics (arm_neon::Sdot2dOp) instead of an
/// llvm.inline_asm block.
struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractionOp,
PatternRewriter &rewriter) const override {
if (!isMatrixTimesMatrixTransposedOfGivenShape(contractionOp, 8, 4, 8)) {
return failure();
}
Type I8Type = rewriter.getIntegerType(8);
Type I32Type = rewriter.getIntegerType(32);
auto acc = contractionOp.acc();
auto lhs = contractionOp.lhs();
auto rhs = contractionOp.rhs();
if (acc.getType().cast<VectorType>().getElementType() != I32Type) {
return failure();
}
Value inLhs = getExtSIInput(I8Type, I32Type, lhs);
Value inRhs = getExtSIInput(I8Type, I32Type, rhs);
if (!inLhs || !inRhs) return failure();
auto loc = contractionOp.getLoc();
auto int32x4VType = VectorType::get({4}, I32Type);
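    // Extract the 8x8 i32 accumulator into sixteen 4xi32 chunks, in row-major
    // order: accChunks[2 * row + col / 4] holds acc[row][col..col + 3]. The
    // dstChunks computed below use the same ordering.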
std::array<Value, 16> accChunks;
{
int idx = 0;
for (int row = 0; row < 8; ++row) {
auto accRow = rewriter.create<vector::ExtractOp>(
loc, acc, ArrayRef<int64_t>{row});
for (int col = 0; col < 8; col += 4) {
auto accChunk = rewriter.create<vector::ExtractStridedSliceOp>(
loc, accRow, ArrayRef<int64_t>{col}, ArrayRef<int64_t>{4},
ArrayRef<int64_t>{1});
assert(accChunk.getType() == int32x4VType);
accChunks[idx++] = accChunk;
}
}
}
auto int8x4x4VType = VectorType::get({4, 4}, rewriter.getIntegerType(8));
auto extract4x4 = [&](Value in, int rowOffset, int colOffset) {
auto chunk = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, ArrayRef<int64_t>{rowOffset, colOffset},
ArrayRef<int64_t>{4, 4}, ArrayRef<int64_t>{1, 1});
assert(chunk.getType() == int8x4x4VType);
return chunk;
};
std::array<Value, 2> lhsHalves = {extract4x4(inLhs, 0, 0),
extract4x4(inLhs, 4, 0)};
std::array<Value, 2> rhsHalves = {extract4x4(inRhs, 0, 0),
extract4x4(inRhs, 4, 0)};
auto int8Zero4x4 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(int8x4x4VType));
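    // sdot(acc, a, b, lane) computes, for each i in 0..3,
    //   acc[i] += dot(a[i][0..3], b[lane][0..3]).
    // The vector.shuffle replicates row `lane` of b across all four rows
    // (int8Zero4x4 only serves as the second operand that vector.shuffle
    // requires), so the 2-d sdot multiplies every row of `a` against that
    // single row of `b`.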
auto sdot = [&](Value acc, Value a, Value b, int64_t lane) -> Value {
auto bReplicatedLane = rewriter.create<vector::ShuffleOp>(
loc, b, int8Zero4x4, ArrayRef<int64_t>{lane, lane, lane, lane});
return rewriter.create<arm_neon::Sdot2dOp>(loc, int32x4VType, acc, a,
bReplicatedLane);
};
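    // Compute the 16 destination chunks. The loop order (lhs half, lane within
    // that half, rhs half) visits them in the same row-major order as
    // accChunks: row = 4 * lhsHalfIdx + lane, column chunk = rhsHalfIdx.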
std::array<Value, 16> dstChunks;
{
int idx = 0;
for (Value lhs : lhsHalves) {
for (int lane = 0; lane < 4; ++lane) {
for (Value rhs : rhsHalves) {
dstChunks[idx] = sdot(accChunks[idx], rhs, lhs, lane);
++idx;
}
}
}
}
    // Put the results back into the accumulator.
{
int idx = 0;
for (int row = 0; row < 8; ++row) {
for (int col = 0; col < 8; col += 4) {
acc = rewriter.create<vector::InsertStridedSliceOp>(
loc, dstChunks[idx++], acc, ArrayRef<int64_t>{row, col},
ArrayRef<int64_t>{1});
}
}
}
rewriter.replaceOp(contractionOp, {acc});
return success();
}
};
class VectorContractCustomKernelsPass
: public VectorContractCustomKernelsBase<VectorContractCustomKernelsPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect, LLVM::LLVMDialect>();
if (target_info.has(CustomKernelTargetFeature::Intrinsics)) {
registry.insert<arm_neon::ArmNeonDialect>();
}
}
LogicalResult initializeOptions(StringRef options) override {
if (failed(Pass::initializeOptions(options))) {
return failure();
}
if (failed(ParseCustomKernelsTargetInfo(arch, features, target_info))) {
return failure();
}
if (intrinsics) {
target_info.add(CustomKernelTargetFeature::Intrinsics);
}
return success();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
populateVectorContractCustomKernelsPatterns(target_info, patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
signalPassFailure();
}
}
private:
CustomKernelsTargetInfo target_info;
};
} // namespace
void populateVectorContractCustomKernelsPatterns(
const CustomKernelsTargetInfo &target_info, RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
if (target_info.has(CustomKernelTargetFeature::Aarch64Dotprod)) {
if (target_info.has(CustomKernelTargetFeature::Intrinsics)) {
patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics>(context);
} else {
patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm>(context);
}
}
}
std::unique_ptr<OperationPass<FuncOp>> createVectorContractCustomKernelsPass() {
return std::make_unique<VectorContractCustomKernelsPass>();
}
} // namespace iree_compiler
} // namespace mlir