| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Utils/CustomKernelsTargetInfo.h" |
| #include "iree/compiler/Utils/StringUtils.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/Triple.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" |
| #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
| #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
| // Returns true if `contractionOp` is of the form |
| // matrix * transposed_matrix. |
| // That is, if there are 2 parallel iterators, say M and N, 1 additive reduction |
| // iterator, say K, and the indexing maps are {{M, K}, {N, K}, {M, N}}. |
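// For example (illustrative), in the common case where the iterators appear
// in the order (M, N, K) = (d0, d1, d2), the indexing maps this matches are:
//   affine_map<(d0, d1, d2) -> (d0, d2)>  // LHS
//   affine_map<(d0, d1, d2) -> (d1, d2)>  // RHS
//   affine_map<(d0, d1, d2) -> (d0, d1)>  // Accumulator/result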
| static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) { |
| // Check that the reduction is additive. |
| if (contractionOp.kind() != vector::CombiningKind::ADD) { |
| return false; |
| } |
| // Check that there are 2 parallel and 1 reduction iterators. |
| auto iteratorTypes = contractionOp.iterator_types().getValue(); |
| if (iteratorTypes.size() != 3) { |
| return false; |
| } |
| SmallVector<int, 3> parallelIterators; |
| SmallVector<int, 3> reductionIterators; |
| for (int i = 0; i < 3; i++) { |
| if (isParallelIterator(iteratorTypes[i])) { |
| parallelIterators.push_back(i); |
| } else if (isReductionIterator(iteratorTypes[i])) { |
| reductionIterators.push_back(i); |
| } else { |
| return false; |
| } |
| } |
| if (parallelIterators.size() != 2 || reductionIterators.size() != 1) { |
| return false; |
| } |
| // Give the found iterators some idiomatic names. |
| const int MIter = parallelIterators[0]; |
| const int NIter = parallelIterators[1]; |
| const int KIter = reductionIterators[0]; |
| // Check that there are 3 indexing maps. |
| auto indexingMaps = contractionOp.indexing_maps(); |
| if (indexingMaps.size() != 3) { |
| return false; |
| } |
| // Check that the indexing maps have the expected form. |
| const int expectedMapResults[3][2] = { |
| {MIter, KIter}, {NIter, KIter}, {MIter, NIter}}; |
| for (int m = 0; m < 3; ++m) { |
| auto map = indexingMaps[m]; |
| if (map.getNumDims() != 3 || map.getNumResults() != 2) { |
| return false; |
| } |
| for (int r = 0; r < 2; ++r) { |
| int actualMapResult = |
| map.getResults()[r].cast<AffineDimExpr>().getPosition(); |
| if (actualMapResult != expectedMapResults[m][r]) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| |
| // Returns true if `contractionOp` is of the form |
| // matrix * transposed_matrix |
| // where matrix is a vector<{mSize}x{kSize}xType>, and |
| // transposed_matrix is a vector<{nSize}x{kSize}xType>. |
| // |
| // Also returns true if the above condition is met after swapping |
| // mSize<->nSize and one of these two values is 1, and `transpose` is not null. |
| // In that case, the output-param `*transpose` is set to true. Rationale: we |
| // want to use the same kernel for vector*matrix and matrix*vector. The good |
| // thing with MMT, namely |
| // |
| // A * Transpose(B) |
| // |
| // is that swapping A and B merely transposes the result: |
| // |
| // B * Transpose(A) = Transpose( A * Transpose(B) ) |
| // |
// This opens the possibility of reducing vector*matrix to matrix*vector
// by merely swapping LHS<->RHS. Why is this specific to the case where one of
// the sides is a vector? Because transposing the result is not OK in general:
// we don't want to write out the result accumulators in the wrong storage
| // order. However, when one of the two sides is a vector, so is the result |
| // accumulator, and for a vector shape (i.e. Mx1 or 1xN), storage orders do not |
| // matter. |
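// Illustrative example: the 8x8x1 matrix*vector kernels below (m0=8, k0=8,
// n0=1) also match a vector*matrix contraction whose LHS is a vector<1x8xT>
// and whose RHS is a vector<8x8xT>; in that case `*transpose` is set to true
// and the caller feeds the operands to the kernel swapped.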
| static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, |
| int64_t kSize, int64_t nSize, bool *transpose = nullptr) { |
| if (!isMatrixTimesMatrixTransposed(contractionOp)) { |
| return false; |
| } |
| VectorType lhsType = contractionOp.lhs().getType().cast<VectorType>(); |
| VectorType rhsType = contractionOp.rhs().getType().cast<VectorType>(); |
| auto lhsShape = lhsType.getShape(); |
| auto rhsShape = rhsType.getShape(); |
| if (lhsShape[1] != kSize || rhsShape[1] != kSize) { |
| return false; |
| } |
| if (lhsShape[0] == mSize && rhsShape[0] == nSize) { |
| return true; |
| } |
| if (lhsShape[0] == nSize && rhsShape[0] == mSize && transpose != nullptr) { |
| *transpose = true; |
| return true; |
| } |
| return false; |
| } |
| |
// `promotedResult` is required to be a Vector.
// If its VectorType does not have `promotedType` as its element type, or if
// the operand of the type-promotion op does not have `unpromotedType` as its
// element type, returns a null Value.
| // If `unpromotedType == promotedType`, return `promotedResult` unchanged. |
| // Otherwise, checks that `promotedResult` is defined by a type-promotion op |
| // (such as arith::ExtSIOp) promoting from `unpromotedType` to `promotedType`, |
| // and returns the input of that promotion op. |
| // Note that this only looks at the immediately defining operation, so we likely |
| // want to have earlier passes that sink widening operations as far down as |
| // possible, which is probably just good regardless. |
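// Illustrative example: given
//   %promoted = arith.extsi %input : vector<8x4xi8> to vector<8x4xi32>
// getUnpromotedInput(i8, i32, %promoted) returns %input.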
| static Value getUnpromotedInput(Type unpromotedType, Type promotedType, |
| Value promotedResult) { |
| VectorType promotedResultVectorType = |
| promotedResult.getType().cast<VectorType>(); |
| if (promotedResultVectorType.getElementType() != promotedType) { |
| return nullptr; |
| } |
| if (unpromotedType == promotedType) { |
| return promotedResult; |
| } |
| // TODO: handle promotion of floating point types. Not doing it for now as |
| // it wouldn't be exercised. |
| auto extSIOp = promotedResult.getDefiningOp<arith::ExtSIOp>(); |
| if (!extSIOp) { |
| return nullptr; |
| } |
| Value extInput = extSIOp.getIn(); |
| if (extInput.getType().cast<VectorType>().getElementType() != |
| unpromotedType) { |
| return nullptr; |
| } |
| return extInput; |
| } |
| |
| // Helper to create a 1D, contiguous slice of a 1D vector. |
| static Value extract1DSlice(PatternRewriter &rewriter, Location loc, |
| VectorType dstVecType, Value input, int position) { |
| assert(input.getType().cast<VectorType>().getRank() == 1); |
| assert(dstVecType.getRank() == 1); |
| std::array<int64_t, 1> offsets{position}; |
| std::array<int64_t, 1> strides{1}; |
| return rewriter.create<vector::ExtractStridedSliceOp>( |
| loc, input, offsets, dstVecType.getShape(), strides); |
| } |
| |
| // Helper to extract an element of a 1D vector. |
| static Value extract(PatternRewriter &rewriter, Location loc, Value input, |
| int position) { |
| VectorType vectorType = input.getType().cast<VectorType>(); |
| assert(vectorType.getRank() == 1); |
| (void)vectorType; |
| std::array<int64_t, 1> offsets{position}; |
| return rewriter.create<vector::ExtractOp>(loc, input, offsets); |
| } |
| |
| // Helper to flatten a N-dimensional vector to a 1D vector. |
| static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { |
| VectorType inputVecType = vector.getType().cast<VectorType>(); |
| VectorType dstType = VectorType::get(inputVecType.getNumElements(), |
| inputVecType.getElementType()); |
| return rewriter.create<vector::ShapeCastOp>(loc, dstType, vector); |
| } |
| |
| // Describes a kernel. This struct is kept small to separate the kernels |
| // themselves from the MLIR-specific generators consuming them |
| // (see MMTKernelGenerator). |
| // |
| // There is some redundancy among this struct's fields: see the relationships |
| // between fields that are enforced in validate(). This redundancy helps: |
// (1) Avoid having to perform divisions (a performance concern, and a
// readability concern, as we would have to care that these divisions be exact).
| // (2) Be explicit about the size of the vectors involved in the kernel's |
| // "calling convention". |
| struct MMTKernel { |
| enum class ScalarType : int8_t { None, I8, I32, F32 }; |
| // Target architecture. Needed to generate inline asm constraints. |
| CustomKernelTargetArch arch = CustomKernelTargetArch::None; |
| // Element type of the LHS vectors. |
| ScalarType lhsType = ScalarType::None; |
| // Element type of the RHS vectors. |
| ScalarType rhsType = ScalarType::None; |
| // Element type of the Accumulator and output vectors. |
| ScalarType accType = ScalarType::None; |
| // Number of rows of the LHS and Accumulator tile. |
| int8_t m0 = 0; |
| // Reduction dimension, i.e. number of columns of the LHS. |
| int8_t k0 = 0; |
// Number of rows of the RHS (note that the operation being targeted, MMT,
// is matrix multiplication with a *transposed* RHS).
int8_t n0 = 0;
// Number of LHS elements fitting in the type of register to be used for the
// LHS. This is > 1 if SIMD registers are to be used.
// Note: LHS/RHS/Accumulator may use registers of different sizes.
| int8_t lhsRegSize = 0; |
| // Number of RHS elements fitting in the type of register to be used for RHS. |
| int8_t rhsRegSize = 0; |
| // Number of Accumulator elements fitting in the type of register to be used |
| // for the accumulator. |
| int8_t accRegSize = 0; |
| // Number of registers needed to hold the LHS. |
| int8_t lhsRegs = 0; |
| // Number of registers needed to hold the RHS. |
| int8_t rhsRegs = 0; |
| // Number of registers needed to hold the Accumulator. |
| int8_t accRegs = 0; |
| // If not null, points to the inline asm code template for this kernel. |
| // Register operands for the LHS, RHS and Accumulator are to be referenced as |
| // $(lhs:<i>), $(rhs:<i>), $(acc:<i>) respectively, where i is a decimal |
| // integer specifying the i-th register for each case (numbered independently, |
| // so each starts at 0). |
| const char *asmImpl = nullptr; |
| // If not null, points to the clobbers list, i.e. the list of registers |
| // that the compiler will reserve for this inline asm block's use, in addition |
| // to the ones implicitly allocated for the declared inputs and outputs. Using |
// C inline_asm syntax: a comma-separated list of raw register names, e.g.
// "v14,v15".
| const char *asmClobbers = nullptr; |
| |
| void validate() const { |
| assert(m0 * k0 == lhsRegSize * lhsRegs); // number of elements of LHS |
| assert(n0 * k0 == rhsRegSize * rhsRegs); // number of elements of RHS |
| assert(m0 * n0 == accRegSize * accRegs); // number of elements of Accum |
| assert(lhsType != ScalarType::None); |
| assert(rhsType != ScalarType::None); |
| assert(accType != ScalarType::None); |
| } |
| }; |
| |
| // i8*i8->i32 kernel for Aarch64 NEON. |
| // |
// Historically, certain such kernels [1] required that int8 inputs not have
// the value -128, which enabled a different kernel design taking advantage
// of the narrow range to accumulate once within int16 accumulators without
// overflow. Those kernels provided a 1.5x speedup on some late-2010s
// out-of-order cores (ARM Cortex A57/A72/A73, Apple A6--A12, Samsung Exynos
// M3), but became obsolete with the +dotprod feature (ARM Cortex-A76,
// Apple A13), and were never useful on in-order ARM Cortex-A53/A55. So going
// forward, they are no longer a useful trade-off even in frameworks (such as
// TensorFlow Lite) that are designed to avoid -128 values. There is a large
// ecosystem cost in maintaining that restriction, and it wouldn't make sense
// to introduce it now in new frameworks such as MLIR or IREE, so the present
// kernel is general: it supports arbitrary int8 values and does not try to
// use such optimizations.
| // |
| // This kernel is needed because: at the moment, the codegen has multiple |
| // issues. It uses inefficient scalar memory access instructions, |
| // expands int8 values to int32, and performs slow int32*int32 multiplications: |
| // 118d8: f0 12 c0 39 ldrsb w16, [x23, #4] |
| // ... |
| // 118f4: 1b 0e 04 4e dup v27.4s, w16 |
| // ... |
| // 11900: 32 97 bb 4e mla v18.4s, v25.4s, v27.4s |
| // 11904: 57 97 bb 4e mla v23.4s, v26.4s, v27.4s |
| // |
| // |
| // [1]: |
| // https://github.com/google/ruy/blob/2d950b3bfa7ebfbe7a97ecb44b1cc4da5ac1d6f0/ruy/kernel_arm64.cc#L93 |
| MMTKernel MMTKernel_8x1x8_i8i8i32_Aarch64_Baseline_InlineAsm() { |
| MMTKernel kernel; |
| kernel.arch = CustomKernelTargetArch::Aarch64; |
| kernel.lhsType = MMTKernel::ScalarType::I8; |
| kernel.rhsType = MMTKernel::ScalarType::I8; |
| kernel.accType = MMTKernel::ScalarType::I32; |
| kernel.m0 = 8; // shape: 8x1x8, outer-product. |
| kernel.k0 = 1; // note: we would have enough registers to widen to 12x1x8 |
| kernel.n0 = 8; // if needed. |
| kernel.lhsRegSize = 8; // LHS NEON register type: int8x8 |
| kernel.rhsRegSize = 8; // RHS NEON register type: int8x8 |
| kernel.accRegSize = 4; // Accum NEON register type: int32x4 |
| kernel.lhsRegs = 1; |
| kernel.rhsRegs = 1; |
| kernel.accRegs = 16; // = 8*8/4 for 8x8 accumulators, 4 per register |
| kernel.asmImpl = R"ASM( |
| // NEON does not have instructions to multiply int8 values and accumulate |
| // into int32. This kernel sign-extends int8 to int16, then uses |
| // smlal[2] to multiply-accumulate int16 values into int32 accumulators. |
| sxtl v14.8h, $(lhs:0).8b // v14.8h = sign-extend LHS int8 to int16 |
| sxtl v15.8h, $(rhs:0).8b // v15.8h = sign-extend RHS int8 to int16 |
| smlal $(acc:0).4s, v15.4h, v14.h[0] |
| smlal2 $(acc:1).4s, v15.8h, v14.h[0] |
| smlal $(acc:2).4s, v15.4h, v14.h[1] |
| smlal2 $(acc:3).4s, v15.8h, v14.h[1] |
| smlal $(acc:4).4s, v15.4h, v14.h[2] |
| smlal2 $(acc:5).4s, v15.8h, v14.h[2] |
| smlal $(acc:6).4s, v15.4h, v14.h[3] |
| smlal2 $(acc:7).4s, v15.8h, v14.h[3] |
| smlal $(acc:8).4s, v15.4h, v14.h[4] |
| smlal2 $(acc:9).4s, v15.8h, v14.h[4] |
| smlal $(acc:10).4s, v15.4h, v14.h[5] |
| smlal2 $(acc:11).4s, v15.8h, v14.h[5] |
| smlal $(acc:12).4s, v15.4h, v14.h[6] |
| smlal2 $(acc:13).4s, v15.8h, v14.h[6] |
| smlal $(acc:14).4s, v15.4h, v14.h[7] |
| smlal2 $(acc:15).4s, v15.8h, v14.h[7] |
| )ASM"; |
| kernel.asmClobbers = "v14,v15"; |
| return kernel; |
| } |
| |
| // i8*i8->i32 kernel for Aarch64 NEON, matrix*vector |
| // |
| // This kernel is needed because: at the moment, the codegen is generating |
| // 177 instructions for this kernel (not peeled). |
| MMTKernel MMTKernel_8x8x1_i8i8i32_Aarch64_Baseline_InlineAsm() { |
| MMTKernel kernel; |
| kernel.arch = CustomKernelTargetArch::Aarch64; |
| kernel.lhsType = MMTKernel::ScalarType::I8; |
| kernel.rhsType = MMTKernel::ScalarType::I8; |
| kernel.accType = MMTKernel::ScalarType::I32; |
| kernel.m0 = 8; // shape: 8x8x1, matrix*vector |
| kernel.k0 = 8; |
| kernel.n0 = 1; |
| kernel.lhsRegSize = 16; // LHS NEON register type: int8x16 |
| kernel.rhsRegSize = 8; // RHS NEON register type: int8x8 |
| kernel.accRegSize = 4; // Accum NEON register type: int32x4 |
| kernel.lhsRegs = 4; // = 8x8/16 for 8x8 LHS elems, 16 per register |
| kernel.rhsRegs = 1; |
| kernel.accRegs = 2; // = 8/4 for 8 accumulators, 4 per register |
| kernel.asmImpl = R"ASM( |
| // This kernel multiplies int8 values into temporary int16 values in |
| // registers v8--v15, then performs additions. We can't use |
| // multiply-accumulate instructions here because of the lack of an |
| // instruction multiplying int8 values and accumulating into int32, and |
| // we prefer to avoid the overhead of sign-extending the inputs from int8 |
| // to int16 in this matrix*vector kernel where the largest matrix is the |
| // LHS. |
| ins v15.d[1], $(rhs:0).d[0] // copy 1st half of $(rhs:0) to 2nd half of v15 |
| smull v8.8h, $(lhs:0).8b, $(rhs:0).8b |
| smull2 v9.8h, $(lhs:0).16b, v15.16b |
| smull v10.8h, $(lhs:1).8b, $(rhs:0).8b |
| smull2 v11.8h, $(lhs:1).16b, v15.16b |
| smull v12.8h, $(lhs:2).8b, $(rhs:0).8b |
| smull2 v13.8h, $(lhs:2).16b, v15.16b |
| smull v14.8h, $(lhs:3).8b, $(rhs:0).8b |
| smull2 v15.8h, $(lhs:3).16b, v15.16b |
| // Now if we were able to codegen not just this MMT in isolation but |
| // a whole loop, we would diverge at this point: instead of doing the full |
| // additive reduction that the instructions below do, we would do only |
| // minimal reductions to temporary int32 accumulators |
| // (e.g. sadalp tmp.4s, v8.8h) and we would defer the rest of the work |
| // to the end of the loop. This is an example of how "MMT vector.contract" |
| // is not a perfect abstraction for "basic block of a MMT inner loop". |
| // Anyway... |
| // |
| // pairwise additions of int16 lanes to int32. |
| // So each result int32 is the sum of 2 products. |
| saddlp v8.4s, v8.8h |
| saddlp v9.4s, v9.8h |
| saddlp v10.4s, v10.8h |
| saddlp v11.4s, v11.8h |
| saddlp v12.4s, v12.8h |
| saddlp v13.4s, v13.8h |
| saddlp v14.4s, v14.8h |
| saddlp v15.4s, v15.8h |
| // pairwise additions of int32s, so each result is the sum of 4 products. |
| addp v8.4s, v8.4s, v9.4s |
| addp v10.4s, v10.4s, v11.4s |
| addp v12.4s, v12.4s, v13.4s |
| addp v14.4s, v14.4s, v15.4s |
| // pairwise additions of int32s, so each result is the sum of 8 products. |
| addp v8.4s, v8.4s, v10.4s |
| addp v12.4s, v12.4s, v14.4s |
| // Add to destination accumulators |
| add $(acc:0).4s, $(acc:0).4s, v8.4s |
| add $(acc:1).4s, $(acc:1).4s, v12.4s |
| )ASM"; |
| kernel.asmClobbers = "v8,v9,v10,v11,v12,v13,v14,v15"; |
| return kernel; |
| } |
| |
| // i8*i8->i32 kernel for Aarch64 NEON +dotprod |
| // |
| // This kernel is needed because: at the moment, codegen doesn't know how to |
| // make use of dotprod instructions. |
| MMTKernel MMTKernel_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm() { |
| MMTKernel kernel; |
| kernel.arch = CustomKernelTargetArch::Aarch64; |
| kernel.lhsType = MMTKernel::ScalarType::I8; |
| kernel.rhsType = MMTKernel::ScalarType::I8; |
| kernel.accType = MMTKernel::ScalarType::I32; |
| kernel.m0 = 8; // shape: 8x4x8. We would have enough registers to widen this |
| kernel.k0 = 4; // to 12x4x8 if needed. |
| kernel.n0 = 8; |
| kernel.lhsRegSize = 16; // LHS NEON register type: int8x16 |
| kernel.rhsRegSize = 16; // RHS NEON register type: int8x16 |
| kernel.accRegSize = 4; // Accum NEON register type: int32x4 |
| kernel.lhsRegs = 2; // = 8x4/16 for 8x4 LHS elems, 16 per register |
| kernel.rhsRegs = 2; // = 8x4/16 for 8x4 RHS elems, 16 per register |
| kernel.accRegs = 16; // = 8x8/4 for 8x8 Accum elems, 4 per register |
| kernel.asmImpl = R"ASM( |
| // Note on the operands ordering: RHS before LHS, because we want |
| // to multiply a 4x4 tile from RHS against a row-vector from LHS to |
// produce a row-vector of Accumulators, as the accumulator needs to be
// row-major.
| sdot $(acc:0).4s, $(rhs:0).16b, $(lhs:0).4b[0] |
| sdot $(acc:1).4s, $(rhs:1).16b, $(lhs:0).4b[0] |
| sdot $(acc:2).4s, $(rhs:0).16b, $(lhs:0).4b[1] |
| sdot $(acc:3).4s, $(rhs:1).16b, $(lhs:0).4b[1] |
| sdot $(acc:4).4s, $(rhs:0).16b, $(lhs:0).4b[2] |
| sdot $(acc:5).4s, $(rhs:1).16b, $(lhs:0).4b[2] |
| sdot $(acc:6).4s, $(rhs:0).16b, $(lhs:0).4b[3] |
| sdot $(acc:7).4s, $(rhs:1).16b, $(lhs:0).4b[3] |
| sdot $(acc:8).4s, $(rhs:0).16b, $(lhs:1).4b[0] |
| sdot $(acc:9).4s, $(rhs:1).16b, $(lhs:1).4b[0] |
| sdot $(acc:10).4s, $(rhs:0).16b, $(lhs:1).4b[1] |
| sdot $(acc:11).4s, $(rhs:1).16b, $(lhs:1).4b[1] |
| sdot $(acc:12).4s, $(rhs:0).16b, $(lhs:1).4b[2] |
| sdot $(acc:13).4s, $(rhs:1).16b, $(lhs:1).4b[2] |
| sdot $(acc:14).4s, $(rhs:0).16b, $(lhs:1).4b[3] |
| sdot $(acc:15).4s, $(rhs:1).16b, $(lhs:1).4b[3] |
| )ASM"; |
| return kernel; |
| } |
| |
| // i8*i8->i32 kernel for Aarch64 NEON +dotprod, matrix*vector |
| // |
| // This kernel is needed because: at the moment, codegen doesn't know how to |
| // make use of dotprod instructions. |
| MMTKernel MMTKernel_8x4x1_i8i8i32_Aarch64Dotprod_InlineAsm() { |
| MMTKernel kernel; |
| kernel.arch = CustomKernelTargetArch::Aarch64; |
| kernel.lhsType = MMTKernel::ScalarType::I8; |
| kernel.rhsType = MMTKernel::ScalarType::I8; |
| kernel.accType = MMTKernel::ScalarType::I32; |
| kernel.m0 = 8; // shape: 8x4x1. |
| kernel.k0 = 4; |
| kernel.n0 = 1; |
| kernel.lhsRegSize = 16; // LHS NEON register type: int8x16 |
| kernel.rhsRegSize = 4; // RHS NEON register type: int8x4. This is very small |
| // and forces sub-optimal codegen. This needs to be |
| // widened by peeling the surrounding loop, not by |
| // increasing the k0 of this MMT, which would change |
| // the data layout in an unwanted way. |
kernel.accRegSize = 4; // Accum NEON register type: int32x4
| kernel.lhsRegs = 2; // = 8x4/16 for 8x4 LHS elems, 16 per register |
kernel.rhsRegs = 1; // = 4/4 for 4 RHS elems, 4 per register
| kernel.accRegs = 2; // = 8/4 for 8 Accum elems, 4 per register |
| kernel.asmImpl = R"ASM( |
| sdot $(acc:0).4s, $(lhs:0).16b, $(rhs:0).4b[0] |
| sdot $(acc:1).4s, $(lhs:1).16b, $(rhs:0).4b[0] |
| )ASM"; |
| return kernel; |
| } |
| |
| // i8*i8->i32 kernel for Aarch64 NEON +i8mm |
| // |
| // This kernel is needed because: at the moment, codegen doesn't know how to |
| // make use of i8mm instructions. |
| MMTKernel MMTKernel_8x8x8_i8i8i32_Aarch64I8mm_InlineAsm() { |
| MMTKernel kernel; |
| kernel.arch = CustomKernelTargetArch::Aarch64; |
| kernel.lhsType = MMTKernel::ScalarType::I8; |
| kernel.rhsType = MMTKernel::ScalarType::I8; |
| kernel.accType = MMTKernel::ScalarType::I32; |
| kernel.m0 = 8; // shape: 8x8x8. We would have enough registers to widen this |
| kernel.k0 = 8; // to 12x8x8 if needed. |
| kernel.n0 = 8; |
| kernel.lhsRegSize = 16; // LHS NEON register type: int8x16 |
| kernel.rhsRegSize = 16; // RHS NEON register type: int8x16 |
| kernel.accRegSize = 4; // Accum NEON register type: int32x4 |
kernel.lhsRegs = 4; // = 8x8/16 for 8x8 LHS elems, 16 per register
kernel.rhsRegs = 4; // = 8x8/16 for 8x8 RHS elems, 16 per register
| kernel.accRegs = 16; // = 8x8/4 for 8x8 Accum elems, 4 per register |
| kernel.asmImpl = R"ASM( |
// What's with the horrendous shuffles (zip, uzp instructions)?
| // The smmla instruction works with a 2x2 accumulator tile. |
| // So at the moment, given the MMT vector.contract representation of |
| // the basic block, we have to perform this re-tiling to 2x2 tiles. |
| // |
| // This is not really optimized -- just provided to help shape the next |
| // stage of the discussion, which will be how we change the abstractions |
| // to resolve this. |
| // |
| // For instance, if we compiled a whole loop at once, we would only need |
| // to do so at the start and at the end of the loop. Or even without |
| // handing the whole loop to asm, we could make the vector.contract |
| // higher-dimensional to allow representing this nested tiled layout. |
| // One thing that we should keep in mind though is that this unusual |
| // 2x2 tiled layout is specific to matrix multiplication instructions. |
| // If the matmul kernel were fused into consumer ops, those would probably |
| // prefer not to deal with a 2x2 tiled layout. |
| // |
| // We also can't easily generalize from this to what will be ARMv9+SME. |
| // There, the matmul instruction will also produce a 2D matrix tile, |
// but that will be much wider, 16x16, and itself row-major, so that when
// we store it back to NEON/SVE registers, each of those will be contained
| // within one row. Even if we put multiple such 16x16 tiles side-by-side |
| // in the overall kernel, that will still be at a scale larger than |
| // individual NEON/SVE registers. |
| // |
| // Rows 0-1 of the 8x8 accumulator tile |
| zip1 v28.2d, $(acc:0).2d, $(acc:2).2d |
| zip2 v29.2d, $(acc:0).2d, $(acc:2).2d |
| zip1 v30.2d, $(acc:1).2d, $(acc:3).2d |
| zip2 v31.2d, $(acc:1).2d, $(acc:3).2d |
| smmla v28.4s, $(lhs:0).16b, $(rhs:0).16b |
| smmla v29.4s, $(lhs:0).16b, $(rhs:1).16b |
| smmla v30.4s, $(lhs:0).16b, $(rhs:2).16b |
| smmla v31.4s, $(lhs:0).16b, $(rhs:3).16b |
| uzp1 $(acc:0).2d, v28.2d, v29.2d |
| uzp1 $(acc:1).2d, v30.2d, v31.2d |
| uzp2 $(acc:2).2d, v28.2d, v29.2d |
| uzp2 $(acc:3).2d, v30.2d, v31.2d |
| // Rows 2-3 of the 8x8 accumulator tile |
| zip1 v28.2d, $(acc:4).2d, $(acc:6).2d |
| zip2 v29.2d, $(acc:4).2d, $(acc:6).2d |
| zip1 v30.2d, $(acc:5).2d, $(acc:7).2d |
| zip2 v31.2d, $(acc:5).2d, $(acc:7).2d |
| smmla v28.4s, $(lhs:1).16b, $(rhs:0).16b |
| smmla v29.4s, $(lhs:1).16b, $(rhs:1).16b |
| smmla v30.4s, $(lhs:1).16b, $(rhs:2).16b |
| smmla v31.4s, $(lhs:1).16b, $(rhs:3).16b |
| uzp1 $(acc:4).2d, v28.2d, v29.2d |
| uzp1 $(acc:5).2d, v30.2d, v31.2d |
| uzp2 $(acc:6).2d, v28.2d, v29.2d |
| uzp2 $(acc:7).2d, v30.2d, v31.2d |
| // Rows 4-5 of the 8x8 accumulator tile |
| zip1 v28.2d, $(acc:8).2d, $(acc:10).2d |
| zip2 v29.2d, $(acc:8).2d, $(acc:10).2d |
| zip1 v30.2d, $(acc:9).2d, $(acc:11).2d |
| zip2 v31.2d, $(acc:9).2d, $(acc:11).2d |
| smmla v28.4s, $(lhs:2).16b, $(rhs:0).16b |
| smmla v29.4s, $(lhs:2).16b, $(rhs:1).16b |
| smmla v30.4s, $(lhs:2).16b, $(rhs:2).16b |
| smmla v31.4s, $(lhs:2).16b, $(rhs:3).16b |
| uzp1 $(acc:8).2d, v28.2d, v29.2d |
| uzp1 $(acc:9).2d, v30.2d, v31.2d |
| uzp2 $(acc:10).2d, v28.2d, v29.2d |
| uzp2 $(acc:11).2d, v30.2d, v31.2d |
| // Rows 6-7 of the 8x8 accumulator tile |
| zip1 v28.2d, $(acc:12).2d, $(acc:14).2d |
| zip2 v29.2d, $(acc:12).2d, $(acc:14).2d |
| zip1 v30.2d, $(acc:13).2d, $(acc:15).2d |
| zip2 v31.2d, $(acc:13).2d, $(acc:15).2d |
| smmla v28.4s, $(lhs:3).16b, $(rhs:0).16b |
| smmla v29.4s, $(lhs:3).16b, $(rhs:1).16b |
| smmla v30.4s, $(lhs:3).16b, $(rhs:2).16b |
| smmla v31.4s, $(lhs:3).16b, $(rhs:3).16b |
| uzp1 $(acc:12).2d, v28.2d, v29.2d |
| uzp1 $(acc:13).2d, v30.2d, v31.2d |
| uzp2 $(acc:14).2d, v28.2d, v29.2d |
| uzp2 $(acc:15).2d, v30.2d, v31.2d |
| )ASM"; |
| kernel.asmClobbers = "v28,v29,v30,v31"; |
| return kernel; |
| } |
| |
| // TODO: |
| // i8*i8->i32 kernel for Aarch64 NEON +i8mm, matrix*vector: |
// Not implemented yet. Due to the shape of the smmla instruction, such a
// kernel would achieve only 50% utilization. It would still be somewhat faster
// than the dotprod matrix*vector kernel at the moment, because reading 64 bits
// into a NEON register is faster than reading 32 bits twice. That's a shallow
// advantage that might vanish once the vector.contract abstraction layer above
// kernels is improved. Another reason why it's not implemented yet is that it
// would have shape 8x8x1, the same as the aarch64 baseline matrix*vector i8
// kernel, so we would need a "kernel benefit" system to cleanly express the
// preference for the i8mm kernel.
| |
| // f32*f32->f32 kernel for Aarch64 NEON |
| // |
| // Note: this asm kernel isn't needed. The default vector.contract |
// lowerings already result in essentially the same code. It is included for
// now, for completeness, since we need the f32 matrix*vector kernel below
// anyway.
| MMTKernel MMTKernel_8x1x8_f32f32f32_Aarch64_Baseline_InlineAsm() { |
| MMTKernel kernel; |
| kernel.arch = CustomKernelTargetArch::Aarch64; |
| kernel.lhsType = MMTKernel::ScalarType::F32; |
| kernel.rhsType = MMTKernel::ScalarType::F32; |
| kernel.accType = MMTKernel::ScalarType::F32; |
| kernel.m0 = 8; |
| kernel.k0 = 1; |
| kernel.n0 = 8; |
| kernel.lhsRegSize = 4; |
| kernel.rhsRegSize = 4; |
| kernel.accRegSize = 4; |
| kernel.lhsRegs = 2; |
| kernel.rhsRegs = 2; |
| kernel.accRegs = 16; |
| kernel.asmImpl = R"ASM( |
| fmla $(acc:0).4s, $(rhs:0).4s, $(lhs:0).s[0] |
| fmla $(acc:1).4s, $(rhs:1).4s, $(lhs:0).s[0] |
| fmla $(acc:2).4s, $(rhs:0).4s, $(lhs:0).s[1] |
| fmla $(acc:3).4s, $(rhs:1).4s, $(lhs:0).s[1] |
| fmla $(acc:4).4s, $(rhs:0).4s, $(lhs:0).s[2] |
| fmla $(acc:5).4s, $(rhs:1).4s, $(lhs:0).s[2] |
| fmla $(acc:6).4s, $(rhs:0).4s, $(lhs:0).s[3] |
| fmla $(acc:7).4s, $(rhs:1).4s, $(lhs:0).s[3] |
| fmla $(acc:8).4s, $(rhs:0).4s, $(lhs:1).s[0] |
| fmla $(acc:9).4s, $(rhs:1).4s, $(lhs:1).s[0] |
| fmla $(acc:10).4s, $(rhs:0).4s, $(lhs:1).s[1] |
| fmla $(acc:11).4s, $(rhs:1).4s, $(lhs:1).s[1] |
| fmla $(acc:12).4s, $(rhs:0).4s, $(lhs:1).s[2] |
| fmla $(acc:13).4s, $(rhs:1).4s, $(lhs:1).s[2] |
| fmla $(acc:14).4s, $(rhs:0).4s, $(lhs:1).s[3] |
| fmla $(acc:15).4s, $(rhs:1).4s, $(lhs:1).s[3] |
| )ASM"; |
| return kernel; |
| } |
| |
| // f32*f32->f32 kernel for Aarch64 NEON, matrix*vector |
| // |
| // Note: this is about the most naive possible SIMD kernel here, and it should |
| // not be needed as this should be an easy case for codegen. Here we are very |
// limited in what we can do as an MMT vector.contract lowering: to make a
| // better kernel, we would need to peel the surrounding loop, and implement |
| // a larger vector.contract with an additional parallel iterator, accumulating |
| // into more separate registers, deferring reduction to the end of the loop. |
| // |
| // And yet, this kernel is currently needed, because at the moment this is what |
| // the codegen generates: |
| // 10d08: 90 44 c1 ac ldp q16, q17, [x4], #32 |
| // 10d0c: c6 04 00 f1 subs x6, x6, #1 |
| // 10d10: 13 42 10 6e ext v19.16b, v16.16b, v16.16b, #8 |
| // 10d14: 34 42 11 6e ext v20.16b, v17.16b, v17.16b, #8 |
| // 10d18: b2 44 40 bc ldr s18, [x5], #4 |
| // 10d1c: 40 ce 30 0e fmla v0.2s, v18.2s, v16.2s |
| // 10d20: 47 12 b0 0f fmla v7.2s, v18.2s, v16.s[1] |
| // 10d24: 46 1a b0 0f fmla v6.2s, v18.2s, v16.s[3] |
| // 10d28: 42 ce 31 0e fmla v2.2s, v18.2s, v17.2s |
| // 10d2c: 41 12 b1 0f fmla v1.2s, v18.2s, v17.s[1] |
| // 10d30: 45 ce 33 0e fmla v5.2s, v18.2s, v19.2s |
| // 10d34: 44 ce 34 0e fmla v4.2s, v18.2s, v20.2s |
| // 10d38: 43 1a b1 0f fmla v3.2s, v18.2s, v17.s[3] |
| // 10d3c: 61 fe ff 54 b.ne 0x10d08 <.text+0x770> |
| // |
| // This is effectively non-SIMD, since each of the 8 fmla here does one useful |
| // scalar multiplication (note: the ldr s18 instruction loaded one float into |
| // the first lane of v18.4s and zeroed the other 3 lanes). |
| MMTKernel MMTKernel_8x1x1_f32f32f32_Aarch64_Baseline_InlineAsm() { |
| MMTKernel kernel; |
| kernel.arch = CustomKernelTargetArch::Aarch64; |
| kernel.lhsType = MMTKernel::ScalarType::F32; |
| kernel.rhsType = MMTKernel::ScalarType::F32; |
| kernel.accType = MMTKernel::ScalarType::F32; |
| kernel.m0 = 8; |
| kernel.k0 = 1; |
| kernel.n0 = 1; |
| kernel.lhsRegSize = 4; |
| kernel.rhsRegSize = 1; |
| kernel.accRegSize = 4; |
| kernel.lhsRegs = 2; |
| kernel.rhsRegs = 1; |
| kernel.accRegs = 2; |
| kernel.asmImpl = R"ASM( |
| fmla $(acc:0).4s, $(lhs:0).4s, $(rhs:0).s[0] |
| fmla $(acc:1).4s, $(lhs:1).4s, $(rhs:0).s[0] |
| )ASM"; |
| return kernel; |
| } |
| |
| // Constructs the mlir::Type corresponding to a scalar type. |
| Type mlirType(MLIRContext *context, MMTKernel::ScalarType t) { |
| switch (t) { |
| case MMTKernel::ScalarType::None: |
| break; |
| case MMTKernel::ScalarType::I8: |
| return IntegerType::get(context, 8, IntegerType::Signless); |
| case MMTKernel::ScalarType::I32: |
| return IntegerType::get(context, 32, IntegerType::Signless); |
| case MMTKernel::ScalarType::F32: |
| return FloatType::getF32(context); |
| } |
| assert(false); |
| return Type(); |
| } |
| |
| // This class is a helper for patterns generating custom kernels based on |
| // MMTKernel structs. |
| class MMTKernelGenerator { |
| public: |
| MMTKernelGenerator(MLIRContext *context, MMTKernel kernel) |
| : context(context), kernel(kernel) { |
| kernel.validate(); |
| } |
| // Generates the kernel. Returns the output accumulator values. |
| SmallVector<Value> generate(PatternRewriter &rewriter, Location loc, |
| ArrayRef<Value> lhs, ArrayRef<Value> rhs, |
| ArrayRef<Value> acc) { |
| validateOperands(lhs, rhs, acc); |
| if (kernel.asmImpl) { |
| return generateAsm(rewriter, loc, lhs, rhs, acc); |
| } |
| // In the future we may have alternate generator paths, e.g. 1D intrinsics |
| // or other asm paths with a different interface, e.g. handling also |
| // the memory load accesses. |
| assert(false && "no implementation provided for kernel"); |
| return {}; |
| } |
| // Returns the MLIR element type (not vector type) of the LHS |
| Type getLhsType() const { return mlirType(context, kernel.lhsType); } |
| // Returns the MLIR element type (not vector type) of the RHS |
| Type getRhsType() const { return mlirType(context, kernel.rhsType); } |
| // Returns the MLIR element type (not vector type) of the Accumulator |
| Type getAccType() const { return mlirType(context, kernel.accType); } |
| // Returns the VectorType of LHS SIMD register vectors |
| VectorType getLhsRegVectorType() const { |
| return VectorType::get({kernel.lhsRegSize}, getLhsType()); |
| } |
| // Returns the VectorType of RHS SIMD register vectors |
| VectorType getRhsRegVectorType() const { |
| return VectorType::get({kernel.rhsRegSize}, getRhsType()); |
| } |
| // Returns the VectorType of Accumulator SIMD register vectors |
| VectorType getAccRegVectorType() const { |
| return VectorType::get({kernel.accRegSize}, getAccType()); |
| } |
| |
| private: |
| MLIRContext *context; |
| MMTKernel kernel; |
| |
| // Helper for generate(). Asserts sanity of the vector-of-register-vectors. |
| void validateOperands(ArrayRef<Value> lhs, ArrayRef<Value> rhs, |
| ArrayRef<Value> acc) { |
| auto validate = [](ArrayRef<Value> vals, int expectedSize, |
| VectorType expectedType) { |
| assert(vals.size() == expectedSize); |
| for (const auto &val : vals) { |
| assert(val.getType().dyn_cast<VectorType>() == expectedType); |
| (void)val; |
| } |
| (void)expectedSize; |
| (void)expectedType; |
| }; |
| validate(lhs, kernel.lhsRegs, getLhsRegVectorType()); |
| validate(rhs, kernel.rhsRegs, getRhsRegVectorType()); |
| validate(acc, kernel.accRegs, getAccRegVectorType()); |
| } |
| // Helper for generateAsmCodeAndConstraints |
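// Returns the constraint code used for each register operand of the
// generated llvm.inline_asm op. For Aarch64 this is "w", which LLVM's
// inline-asm documentation describes as a floating-point/SIMD register
// (v0-v31), i.e. the registers these NEON kernels operate on.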
| std::string getConstraintCode() const { |
| switch (kernel.arch) { |
| case CustomKernelTargetArch::Aarch64: |
| return "w"; |
| case CustomKernelTargetArch::None: |
| break; |
| } |
| assert(false && "Unhandled CustomKernelTargetArch value"); |
| return {}; |
| } |
| // Helper class to build the constraints string of an inline_asm op. |
| class Constraints { |
| private: |
| // The LLVM inline asm syntax is documented here: |
| // https://llvm.org/docs/LangRef.html#inline-assembler-expressions |
| SmallVector<std::string> inputs; |
| SmallVector<std::string> outputs; |
| SmallVector<std::string> tiedInputs; |
| SmallVector<std::string> clobbers; |
| |
| public: |
| enum class Kind { Input, InputOutput }; |
| // Add a new constraint. |
| void add(Kind kind, const std::string &constraintCode) { |
| switch (kind) { |
| case Kind::Input: |
| inputs.push_back(constraintCode); |
| return; |
| case Kind::InputOutput: |
| // An input represented by a number `i` is a tied input, tied to the |
| // i-th output. |
| tiedInputs.push_back(llvm::itostr(outputs.size())); |
| outputs.push_back(std::string("=") + constraintCode); |
| return; |
| } |
| assert(false); |
| } |
| void setClobbers(const char *rawClobbersStr) { |
| assert(clobbers.empty()); |
| if (!rawClobbersStr) { |
| return; |
| } |
| for (StringRef c : llvm::split(rawClobbersStr, ',')) { |
| clobbers.push_back(("~{" + c + "}").str()); |
| } |
| } |
| // Returns the constraints string to be passed to the inline_asm op. |
| // llvm::concat does not currently support concatenating const-qualified |
| // objects, so we can't currently const-qualify this method. |
| std::string toString() { |
| return llvm::join( |
| llvm::concat<std::string>(outputs, inputs, tiedInputs, clobbers), |
| ","); |
| } |
| }; |
| // Helper for generateAsm. Performs some pre-processing of the kernel's |
| // asmImpl. Refer to the comment on kernel::asmImpl. |
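// Illustrative example, for a hypothetical kernel with accRegs=2, lhsRegs=2,
// rhsRegs=1 and asmClobbers="v14,v15", targeting Aarch64:
//  - code substitutions: $(acc:0)->$0, $(acc:1)->$1, $(lhs:0)->$2,
//    $(lhs:1)->$3, $(rhs:0)->$4. Accumulators are processed first so that
//    they get the lowest operand numbers, matching their position as outputs
//    in the constraints string.
//  - constraints string: "=w,=w,w,w,w,0,1,~{v14},~{v15}", i.e. outputs,
//    then inputs, then tied inputs, then clobbers.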
| void generateAsmCodeAndConstraints(std::string &code, |
| std::string &constraintsString) { |
| assert(code.empty()); |
| assert(constraintsString.empty()); |
| // The LLVM inline asm syntax is documented here: |
| // https://llvm.org/docs/LangRef.html#inline-assembler-expressions |
| Constraints constraints; |
| code = kernel.asmImpl; |
| // processedIdx is the index of a register in the processed asm. |
| // Example: $5 => processedIdx == 5 |
| int processedIdx = 0; |
| auto processOperands = [&](Constraints::Kind constraintKind, |
| const char *name, int count) { |
| const std::string &constraintCode = getConstraintCode(); |
| // unprocessedIdx is the index of a register in the unprocessed asm. |
| // Example: $(lhs:1) => unprocessedIdx == 1 |
| for (int unprocessedIdx = 0; unprocessedIdx < count; |
| ++unprocessedIdx, ++processedIdx) { |
| constraints.add(constraintKind, constraintCode); |
| // Perform the code replacement for the operand. |
| // Example: $(lhs:1) => $5 |
| replaceAllSubstrsInPlace( |
| code, llvm::formatv("$({0}:{1})", name, unprocessedIdx), |
| llvm::formatv("${0}", processedIdx)); |
| } |
| }; |
| processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs); |
| processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs); |
| processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs); |
| constraints.setClobbers(kernel.asmClobbers); |
| constraintsString = constraints.toString(); |
| } |
| // Helper for generate(). Implements the asm path. |
| SmallVector<Value> generateAsm(PatternRewriter &rewriter, Location loc, |
| ArrayRef<Value> lhs, ArrayRef<Value> rhs, |
| ArrayRef<Value> acc) { |
| SmallVector<Value> inputs; |
// First the input operands. Then the input-output operands, which, as far
// as input constraints are concerned, are *tied* inputs, i.e. they refer to
// the outputs that we list earlier in the constraints string. This is why
// passing the inputs BEFORE the input-outputs here actually matches listing
// the inputs AFTER the outputs (but BEFORE the tied inputs) in the
// constraints string. Not confusing at all!
| inputs.append(lhs.begin(), lhs.end()); |
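// Note: single-element register vectors (e.g. the vector<1xf32> RHS
// registers of the 8x1x1 f32 kernel above, which has rhsRegSize = 1) are
// passed to the inline_asm op as scalars rather than as 1-element vectors.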
| for (const auto &v : rhs) { |
| if (v.getType().cast<VectorType>().getNumElements() == 1) |
| inputs.push_back(extract(rewriter, loc, v, 0)); |
| else |
| inputs.push_back(v); |
| } |
| inputs.append(acc.begin(), acc.end()); |
| // Create the inline asm op. |
| SmallVector<Type> outputOperandTypes( |
| llvm::map_range(acc, [](Value v) { return v.getType(); })); |
| auto returnType = |
| LLVM::LLVMStructType::getLiteral(context, outputOperandTypes); |
| auto dialectAttr = |
| LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT); |
| std::string code; |
| std::string constraints; |
| generateAsmCodeAndConstraints(code, constraints); |
| LLVM::InlineAsmOp asmOp = rewriter.create<LLVM::InlineAsmOp>( |
| loc, returnType, inputs, code, constraints, |
| /*has_side_effects=*/false, /*is_align_stack=*/false, dialectAttr, |
| /*operand_attrs=*/ArrayAttr()); |
| // Extract result vectors from the asm op. |
| SmallVector<Value> resVec; |
| for (int i = 0; i < kernel.accRegs; ++i) { |
| resVec.push_back(rewriter.create<LLVM::ExtractValueOp>( |
| loc, getAccRegVectorType(), asmOp.getRes(), |
| rewriter.getI64ArrayAttr({i}))); |
| } |
| return resVec; |
| } |
| }; |
| |
| /// Converts matrix-times-matrix-transposed vector.contracts, and possibly also |
| /// any type-promotion op (such as arith.extsi) on the input operands, to |
| /// a custom kernel (at the moment a llvm.inline_asm op) provided by the |
| /// MMTKernel struct. |
| /// |
/// For example, in the case of an i8*i8->i32 kernel, the IR being replaced
/// by the llvm.inline_asm op might look like:
///
/// %lhs_i32 = arith.extsi %lhs_i8 : vector<8x4xi8> to vector<8x4xi32>
/// %rhs_i32 = arith.extsi %rhs_i8 : vector<8x4xi8> to vector<8x4xi32>
| /// %result = vector.contract [...] |
| /// %lhs_i32 : vector<8x4xi32>, |
| /// %rhs_i32 : vector<8x4xi32>, |
| /// %acc_i32 : vector<8x8xi32>, |
| /// [...] |
| /// |
| class MMTCustomKernelPattern : public OpRewritePattern<vector::ContractionOp> { |
| private: |
| MMTKernel kernel; |
| |
| public: |
| MMTCustomKernelPattern(MLIRContext *context, MMTKernel kernel) |
| : OpRewritePattern<vector::ContractionOp>(context), kernel(kernel) {} |
| |
| LogicalResult matchAndRewrite(vector::ContractionOp contractionOp, |
| PatternRewriter &rewriter) const override { |
| // Check if `contractionOp` matches, and obtain the (un-promoted) input |
| // LHS and RHS vectors. |
| bool transposeKernel = false; |
| if (!matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, |
| &transposeKernel)) { |
| return failure(); |
| } |
| MMTKernelGenerator generator(rewriter.getContext(), kernel); |
| Type lhsElemType = generator.getLhsType(); |
| Type rhsElemType = generator.getRhsType(); |
| Type accElemType = generator.getAccType(); |
| VectorType accType = contractionOp.acc().getType().cast<VectorType>(); |
| if (accType.getElementType() != accElemType) { |
| return failure(); |
| } |
| Value unpromotedLhs = |
| getUnpromotedInput(lhsElemType, accElemType, contractionOp.lhs()); |
| Value unpromotedRhs = |
| getUnpromotedInput(rhsElemType, accElemType, contractionOp.rhs()); |
| if (!unpromotedLhs || !unpromotedRhs) { |
| return failure(); |
| } |
// Prepare the dense elements attribute that will be used as the initializer
| // for the destination accumulator vector, before actual values are inserted |
| // into it. We do this early because here we need to validate that the |
| // destination scalar type is one that we know how to handle. |
| VectorType flatAccVectorType = |
| VectorType::get({accType.getNumElements()}, accType.getElementType()); |
| ; |
| Attribute resultInitializer; |
| if (accElemType.isSignlessInteger()) { |
| resultInitializer = DenseIntElementsAttr::get(flatAccVectorType, 0); |
| } else if (accElemType.isF32()) { |
| resultInitializer = DenseFPElementsAttr::get(flatAccVectorType, 0.f); |
| } else { |
| return failure(); |
| } |
| // `contractionOp` matches, start rewriting it. |
| Location loc = contractionOp.getLoc(); |
| // Flatten the inputs to 1D vectors. |
| Value flatLhs = flatten(rewriter, loc, unpromotedLhs); |
| Value flatRhs = flatten(rewriter, loc, unpromotedRhs); |
| Value flatAcc = flatten(rewriter, loc, contractionOp.acc()); |
| // Slice into SIMD-register-sized 1D input vectors ready to feed to the |
| // target SIMD instructions. |
| auto sliceIntoRegVectors = [&](int regsCount, VectorType regVectorType, |
| Value src) { |
| SmallVector<Value> regVectors; |
| int regSize = regVectorType.getNumElements(); |
| for (int i = 0; i < regsCount; ++i) { |
| regVectors.push_back( |
| extract1DSlice(rewriter, loc, regVectorType, src, i * regSize)); |
| } |
| return regVectors; |
| }; |
| VectorType lhsRegVectorType = generator.getLhsRegVectorType(); |
| VectorType rhsRegVectorType = generator.getRhsRegVectorType(); |
| VectorType accRegVectorType = generator.getAccRegVectorType(); |
| Value flatLhsForKernel = transposeKernel ? flatRhs : flatLhs; |
| Value flatRhsForKernel = transposeKernel ? flatLhs : flatRhs; |
| SmallVector<Value> lhsRegVectors = |
| sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, flatLhsForKernel); |
| SmallVector<Value> rhsRegVectors = |
| sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, flatRhsForKernel); |
| SmallVector<Value> accRegVectors = |
| sliceIntoRegVectors(kernel.accRegs, accRegVectorType, flatAcc); |
| // Generate the kernel! |
| SmallVector<Value> resRegVectors = generator.generate( |
| rewriter, loc, lhsRegVectors, rhsRegVectors, accRegVectors); |
// Insert the result register vectors (each of size accRegSize) into the
// overall flattened accumulator vector (of size m0*n0), still 1D.
| Value result = rewriter.create<arith::ConstantOp>(loc, flatAccVectorType, |
| resultInitializer); |
| int accRegNumElements = accRegVectorType.getNumElements(); |
| for (int i = 0; i < kernel.accRegs; ++i) { |
| result = rewriter.create<vector::InsertStridedSliceOp>( |
| loc, resRegVectors[i], result, |
| std::array<int64_t, 1>{accRegNumElements * i}, |
| std::array<int64_t, 1>{1}); |
| } |
| // Cast the result from 1D to 2D and replace the original vector.contract. |
| rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(contractionOp, accType, |
| result); |
| return success(); |
| } |
| }; |
| |
| /// Converts matrix-times-matrix-transposed vector.contracts with |
| /// lhs and rhs inputs defined by arith.extsi promoting from i8 to i32, |
| /// |
/// %lhs_i32 = arith.extsi %lhs_i8 : vector<8x4xi8> to vector<8x4xi32>
/// %rhs_i32 = arith.extsi %rhs_i8 : vector<8x4xi8> to vector<8x4xi32>
| /// %result = vector.contract [...] |
| /// %lhs_i32 : vector<8x4xi32>, |
| /// %rhs_i32 : vector<8x4xi32>, |
| /// %acc_i32 : vector<8x8xi32>, |
| /// [...] |
| /// |
/// To vector ops reading directly from the %lhs_i8 and %rhs_i8 values
/// (bypassing the existing arith.extsi) and feeding arm_neon dot-product
/// intrinsic ops (arm_neon::Sdot2dOp), which implement the matrix
/// multiplication arithmetic using Aarch64 sdot instructions.
/// It matches the same IR as the inline-asm pattern built from
/// MMTKernel_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm.
| struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics |
| : public OpRewritePattern<vector::ContractionOp> { |
| public: |
| using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ContractionOp contractionOp, |
| PatternRewriter &rewriter) const override { |
| if (!matchMMT(contractionOp, 8, 4, 8)) { |
| return failure(); |
| } |
| |
| Type I8Type = rewriter.getIntegerType(8); |
| Type I32Type = rewriter.getIntegerType(32); |
| |
| auto acc = contractionOp.acc(); |
| auto lhs = contractionOp.lhs(); |
| auto rhs = contractionOp.rhs(); |
| if (acc.getType().cast<VectorType>().getElementType() != I32Type) { |
| return failure(); |
| } |
| |
| Value inLhs = getUnpromotedInput(I8Type, I32Type, lhs); |
| Value inRhs = getUnpromotedInput(I8Type, I32Type, rhs); |
| |
| if (!inLhs || !inRhs) return failure(); |
| |
| auto loc = contractionOp.getLoc(); |
| |
| auto int32x4VType = VectorType::get({4}, I32Type); |
| |
| std::array<Value, 16> accChunks; |
| { |
| int idx = 0; |
| for (int row = 0; row < 8; ++row) { |
| auto accRow = rewriter.create<vector::ExtractOp>( |
| loc, acc, ArrayRef<int64_t>{row}); |
| for (int col = 0; col < 8; col += 4) { |
| auto accChunk = rewriter.create<vector::ExtractStridedSliceOp>( |
| loc, accRow, ArrayRef<int64_t>{col}, ArrayRef<int64_t>{4}, |
| ArrayRef<int64_t>{1}); |
| assert(accChunk.getType() == int32x4VType); |
| accChunks[idx++] = accChunk; |
| } |
| } |
| } |
| |
| auto int8x4x4VType = VectorType::get({4, 4}, rewriter.getIntegerType(8)); |
| auto extract4x4 = [&](Value in, int rowOffset, int colOffset) { |
| auto chunk = rewriter.create<vector::ExtractStridedSliceOp>( |
| loc, in, ArrayRef<int64_t>{rowOffset, colOffset}, |
| ArrayRef<int64_t>{4, 4}, ArrayRef<int64_t>{1, 1}); |
| assert(chunk.getType() == int8x4x4VType); |
| return chunk; |
| }; |
| |
| std::array<Value, 2> lhsHalves = {extract4x4(inLhs, 0, 0), |
| extract4x4(inLhs, 4, 0)}; |
| std::array<Value, 2> rhsHalves = {extract4x4(inRhs, 0, 0), |
| extract4x4(inRhs, 4, 0)}; |
| |
| auto int8Zero4x4 = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getZeroAttr(int8x4x4VType)); |
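// The sdot helper below computes, for each i in 0..3,
// acc[i] += dot(a[i], b[lane]): the vector.shuffle replicates row `lane` of
// `b` across all four rows (int8Zero4x4 only serves as the required second
// shuffle operand and none of its rows are selected), and the 2-D arm_neon
// sdot op then accumulates row-wise 4-element dot products into `acc`.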
| auto sdot = [&](Value acc, Value a, Value b, int64_t lane) -> Value { |
| auto bReplicatedLane = rewriter.create<vector::ShuffleOp>( |
| loc, b, int8Zero4x4, ArrayRef<int64_t>{lane, lane, lane, lane}); |
| |
| return rewriter.create<arm_neon::Sdot2dOp>(loc, int32x4VType, acc, a, |
| bReplicatedLane); |
| }; |
| |
| std::array<Value, 16> dstChunks; |
| { |
| int idx = 0; |
| for (Value lhs : lhsHalves) { |
| for (int lane = 0; lane < 4; ++lane) { |
| for (Value rhs : rhsHalves) { |
| dstChunks[idx] = sdot(accChunks[idx], rhs, lhs, lane); |
| ++idx; |
| } |
| } |
| } |
| } |
| |
| // Put the results back in the accumulator |
| { |
| int idx = 0; |
| for (int row = 0; row < 8; ++row) { |
| for (int col = 0; col < 8; col += 4) { |
| acc = rewriter.create<vector::InsertStridedSliceOp>( |
| loc, dstChunks[idx++], acc, ArrayRef<int64_t>{row, col}, |
| ArrayRef<int64_t>{1}); |
| } |
| } |
| } |
| rewriter.replaceOp(contractionOp, {acc}); |
| return success(); |
| } |
| }; |
| |
| class VectorContractCustomKernelsPass |
| : public VectorContractCustomKernelsBase<VectorContractCustomKernelsPass> { |
| public: |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<vector::VectorDialect, LLVM::LLVMDialect>(); |
| if (targetInfo.has(CustomKernelTargetFeature::Intrinsics)) { |
| registry.insert<arm_neon::ArmNeonDialect>(); |
| } |
| } |
| LogicalResult initializeOptions(StringRef options) override { |
| if (failed(Pass::initializeOptions(options))) { |
| return failure(); |
| } |
| if (failed(ParseCustomKernelsTargetInfo(arch, features, targetInfo))) { |
| llvm::errs() << "Bad options `" << options << "` for pass `" |
| << getArgument() << "`\n"; |
| return failure(); |
| } |
| if (intrinsics) { |
| targetInfo.add(CustomKernelTargetFeature::Intrinsics); |
| } |
| return success(); |
| } |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| RewritePatternSet patterns(context); |
| populateVectorContractCustomKernelsPatterns(targetInfo, patterns); |
| if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| std::move(patterns)))) { |
| signalPassFailure(); |
| } |
| } |
| |
| private: |
| CustomKernelsTargetInfo targetInfo; |
| }; |
| |
| } // namespace |
| |
| void populateVectorContractCustomKernelsPatterns( |
| const CustomKernelsTargetInfo &targetInfo, RewritePatternSet &patterns) { |
| MLIRContext *context = patterns.getContext(); |
| if (targetInfo.is(CustomKernelTargetArch::Aarch64)) { |
| // TODO: add a "kernel benefit" system whereby if two kernels are available |
// for the same shape and same data types, the fastest one (i.e. the one
| // using the most powerful available SIMD instructions) is selected. |
| // This is incidentally not needed at the moment because currently no two |
| // kernels share the exact same shape and data types. |
| patterns.add<MMTCustomKernelPattern>( |
| context, MMTKernel_8x1x8_f32f32f32_Aarch64_Baseline_InlineAsm()); |
| patterns.add<MMTCustomKernelPattern>( |
| context, MMTKernel_8x1x1_f32f32f32_Aarch64_Baseline_InlineAsm()); |
| patterns.add<MMTCustomKernelPattern>( |
| context, MMTKernel_8x1x8_i8i8i32_Aarch64_Baseline_InlineAsm()); |
| patterns.add<MMTCustomKernelPattern>( |
| context, MMTKernel_8x8x1_i8i8i32_Aarch64_Baseline_InlineAsm()); |
| if (targetInfo.has(CustomKernelTargetFeature::Aarch64Dotprod)) { |
| if (targetInfo.has(CustomKernelTargetFeature::Intrinsics)) { |
| patterns.add<MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics>(context); |
| } else { |
| patterns.add<MMTCustomKernelPattern>( |
| context, MMTKernel_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm()); |
| patterns.add<MMTCustomKernelPattern>( |
| context, MMTKernel_8x4x1_i8i8i32_Aarch64Dotprod_InlineAsm()); |
| } |
| } |
| if (targetInfo.has(CustomKernelTargetFeature::Aarch64I8mm)) { |
| patterns.add<MMTCustomKernelPattern>( |
| context, MMTKernel_8x8x8_i8i8i32_Aarch64I8mm_InlineAsm()); |
| } |
| } |
| } |
| |
| std::unique_ptr<OperationPass<FuncOp>> createVectorContractCustomKernelsPass() { |
| return std::make_unique<VectorContractCustomKernelsPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |