| // Copyright 2022 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| #include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h" |
| #include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" |
| #include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
| #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "llvm/ADT/SetVector.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace LinalgExt { |
| |
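// Linearizes a 2-D (y, x) coordinate into a row-major array whose innermost
// dimension has extent dimx.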
| static inline int index(int y, int x, int dimy, int dimx) { |
| return (x + dimx * y); |
| } |
| |
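// Linearizes a 4-D (z, y, x, w) coordinate into a row-major array. dimz is
// not needed by the computation; it is kept only so call sites can spell out
// the full logical shape.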
| static inline int index(int z, int y, int x, int w, int dimz, int dimy, |
| int dimx, int dimw) { |
| return (w + dimw * (x + dimx * (y + dimy * z))); |
| } |
| |
| static bool hasAllOneValues(DenseIntElementsAttr attr) { |
| return llvm::all_of(attr, [](APInt element) { return element.isOne(); }); |
| } |
| |
| // TODO: Make this a user-settable parameter once we have support |
| // for more tile sizes |
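// With a 3x3 kernel, an output tile size of 6 implies an 8x8 input tile,
// i.e. the F(6, 3) Winograd algorithm.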
| static constexpr int64_t outputTileSize = 6; |
| |
| /// This function computes the Winograd filter transform when |
| /// the filter is known to be a constant. Specifically, this |
| /// function computes matmul(G, matmul(F, transpose(G))) where |
| /// F is a tile of the convolution filter of size m x m |
/// (single input channel, single output channel) and G has
/// shape (m + r - 1) x m, where r is the output tile size and
/// (m + r - 1) is the input tile size.
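/// For example, with m = 3 and r = 6, G is the 8 x 3 matrix
/// G_6x6_3x3 and each 3 x 3 filter slice F[:, :, c, k] expands
/// into an 8 x 8 tile; the folded result is laid out as a
/// (64, ic, oc) tensor.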
/// Since the input tile size and kernel size are compile-time
/// constants, the time complexity of this function is O(ic * oc),
/// where ic and oc are the number of input and output channels,
/// respectively. Folding can therefore be slow when both channel
/// counts are large.
| /// TODO: Codegen this as a kernel and run once at initialization |
| static DenseElementsAttr foldFilterTransform( |
| ArrayRef<int64_t> shape, int64_t inputTileSize, int64_t kernelSize, |
| Type outputType, const float *G, bool isSplat, float splatValue, |
| DenseElementsAttr::iterator_range<APFloat> &input, FloatType floatType) { |
  const int64_t kh = shape[0];
  const int64_t kw = shape[1];
  const int64_t ic = shape[2];
  const int64_t oc = shape[3];
| const int64_t numElements = inputTileSize * inputTileSize * ic * oc; |
| SmallVector<APFloat> output(numElements, APFloat(0.0f)); |
| for (int d0 = 0; d0 < inputTileSize; d0++) { |
| for (int d1 = 0; d1 < inputTileSize; d1++) { |
| for (int d2 = 0; d2 < ic; d2++) { |
| for (int d3 = 0; d3 < oc; d3++) { |
| APFloat accum(0.0f); |
| for (int d4 = 0; d4 < kernelSize; d4++) { |
| for (int d5 = 0; d5 < kernelSize; d5++) { |
              APFloat ival(splatValue);
              if (!isSplat) {
                ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)];
              }
              // APFloat arithmetic requires matching semantics, so widen
              // non-f32 elements (e.g. f16) to f32 before accumulating.
              if (&ival.getSemantics() != &APFloat::IEEEsingle()) {
                bool losesInfo;
                ival.convert(APFloat::IEEEsingle(),
                             APFloat::rmNearestTiesToEven, &losesInfo);
              }
              int idx0 = index(d0, d4, inputTileSize, kernelSize);
              int idx1 = index(d1, d5, inputTileSize, kernelSize);
              accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]);
| } |
| } |
| int odx = index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc); |
| output[odx] = accum; |
| if (floatType.isF16()) { |
| bool losesInfo; |
| output[odx].convert(APFloat::IEEEhalf(), |
| APFloat::rmNearestTiesToEven, &losesInfo); |
| } |
| } |
| } |
| } |
| } |
| return DenseElementsAttr::get(outputType, output); |
| } |
| |
| namespace { |
| |
| class FoldWinogradFilterTransform final |
| : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, |
| PatternRewriter &rewriter) const override { |
| // Check that kernel size = 3x3 |
| Value kernel = convOp.getInputs()[1]; |
    auto kernelType = kernel.getType().dyn_cast<ShapedType>();
| if (!kernelType) |
| return failure(); |
| ArrayRef<int64_t> kernelShape = kernelType.getShape(); |
| const int64_t kh = kernelShape[0]; |
| const int64_t kw = kernelShape[1]; |
| if ((kh != 3) || (kw != 3)) |
| return failure(); |
| const int64_t kernelSize = kh; |
| const int64_t inputTileSize = outputTileSize + kernelSize - 1; |
| |
| DenseIntOrFPElementsAttr kernelAttr; |
| if (!matchPattern(kernel, m_Constant(&kernelAttr))) { |
| return failure(); |
| } |
| |
| Operation *constOp = kernel.getDefiningOp(); |
| ShapedType type = constOp->getResult(0).getType().cast<ShapedType>(); |
    auto elemType = type.getElementType().dyn_cast<FloatType>();
    if (!elemType)
      return failure();
| ArrayRef<int64_t> shape = type.getShape(); |
| DenseElementsAttr::iterator_range<APFloat> nonSplatValues = |
| kernelAttr.getValues<APFloat>(); |
| bool isSplat = kernelAttr.isSplat(); |
| float splatValue{0.0}; |
| if (isSplat) { |
| splatValue = kernelAttr.getSplatValue<APFloat>().convertToFloat(); |
| } |
| SmallVector<int64_t> resultShape{inputTileSize * inputTileSize, shape[2], |
| shape[3]}; |
| auto resultType = RankedTensorType::get(resultShape, elemType); |
| auto foldedKernelAttr = |
| foldFilterTransform(shape, inputTileSize, kernelSize, resultType, |
| IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat, |
| splatValue, nonSplatValues, elemType); |
| rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, foldedKernelAttr); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| static Value |
| createCollapse(Value tensor, Location loc, PatternRewriter &rewriter, |
| SmallVectorImpl<int64_t> &outputShape, |
| SmallVectorImpl<ReassociationIndices> &reassociations) { |
| auto tensorType = tensor.getType().cast<ShapedType>(); |
| auto elementTy = tensorType.getElementType(); |
| auto resultType = RankedTensorType::get(outputShape, elementTy); |
| return rewriter.create<tensor::CollapseShapeOp>(loc, resultType, tensor, |
| reassociations); |
| } |
| |
| static Value |
| createExpand(Value tensor, Location loc, PatternRewriter &rewriter, |
| SmallVectorImpl<int64_t> &outputShape, |
| SmallVectorImpl<ReassociationIndices> &reassociations) { |
| auto tensorType = tensor.getType().cast<ShapedType>(); |
| auto elementTy = tensorType.getElementType(); |
| auto resultType = RankedTensorType::get(outputShape, elementTy); |
| return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor, |
| reassociations); |
| } |
| |
| namespace { |
| |
| /// Convert conv2d to a sequence of ops that implement the |
| /// Winograd transformation. The Winograd transformation |
/// is parameterized by the output tile size (r). The larger
| /// the tile size, the greater the computational savings, |
| /// but this comes at the cost of accuracy. |
| /// |
| /// For now, we restrict this transform to convolutions |
| /// where the filter size = 3x3, though extensions to larger |
| /// filter sizes are possible. We refer to the |
| /// filter size as (m). The input tile size (i) is defined as |
| /// m + r - 1. For a given output tile size, the Winograd |
| /// transformation defines 3 constant matrices: |
| /// |
| /// B: i x i [used in input transform] |
/// G: i x m [used in the filter transform]
| /// A: i x r [used in output transform] |
| /// |
| /// The choice of these matrices is not unique and affects |
| /// the accuracy of the approach. |
| /// |
| /// Given a convolution of the form |
| /// |
| /// y = conv2d(x, f) |
| /// |
| /// where x: (N, H, W, C) |
| /// f: (H, W, C, F) |
| /// |
| /// this pattern converts the convolution to the following |
| /// sequence: |
| /// |
| /// f_winograd = winograd.filter_transform(f) [folded] |
| /// x_winograd = winograd.input_transform(x) |
| /// x_winograd_c = collapse(x_winograd) |
| /// y_winograd = batch_matmul(x_winograd_c, f_winograd) |
| /// y_winograd_e = expand(y_winograd) |
| /// y_padded = winograd.output_transform(y_winograd_e) |
| /// y = extract_slice(y_padded) |
| /// |
| /// where the dimensions of the tensors above are: |
| /// |
| /// f_winograd: (i * i, C, F) |
| /// x_winograd: (i, i, N, H', W', C) |
| /// x_winograd_c: (i * i, N * H' * W', C) |
| /// y_winograd: (i * i, N * H' * W', F) |
| /// y_winograd_e: (i, i, N, H', W', F) |
| /// y_padded: (N, r * H', r * W', F) |
| /// |
| /// H': ceil((H - m + 1) / r) |
| /// W': ceil((W - m + 1) / r) |
| /// |
| /// The winograd input transform extracts a tile of the input |
| /// of size i x i and computes matmul(transpose(B), matmul(tile(x), B)). |
/// The winograd filter transform extracts a tile of the filter
/// of size m x m and computes matmul(G, matmul(tile(f), transpose(G))).
/// These two are then combined using elementwise multiplication
/// (which becomes a batch matmul when combining over multiple channels).
/// The winograd output transform extracts a tile of the result of size
/// i x i and computes matmul(transpose(A), matmul(tile(y_winograd_e), A)).
| /// |
| /// For more information and additional references, |
| /// see here: |
| /// |
| /// https://github.com/nod-ai/MLIRWinogradTalk/blob/main/MLIRSummit2022.Nodai.Menon.pdf |
| /// |
| class ConvertConv2DNhwcHwcf final |
| : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, |
| PatternRewriter &rewriter) const override { |
| // Check that strides = 1 |
| if (!hasAllOneValues(convOp.getStrides())) |
| return failure(); |
| |
| // Check that dilations = 1 |
| if (!hasAllOneValues(convOp.getDilations())) |
| return failure(); |
| |
| // Check that kernel has been constant folded (by validating rank = 3) |
| Value kernel = convOp.getInputs()[1]; |
    auto kernelType = kernel.getType().dyn_cast<ShapedType>();
| if (!kernelType) |
| return failure(); |
| Type elementType = kernelType.getElementType(); |
| ArrayRef<int64_t> kernelShape = kernelType.getShape(); |
| if (kernelShape.size() != 3) |
| return failure(); |
| |
| const int64_t kernelSize = 3; |
| const int64_t inputTileSize = outputTileSize + kernelSize - 1; |
| |
| // Create winograd input transform op |
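    // Tiles the input: (N, H, W, C) -> (i, i, N, H', W', C).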
| Location loc = convOp.getLoc(); |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getZeroAttr(elementType)); |
| Value input = convOp.getInputs()[0]; |
    auto inputType = input.getType().dyn_cast<ShapedType>();
| if (!inputType) |
| return failure(); |
| ArrayRef<int64_t> inputShape = inputType.getShape(); |
| if (llvm::any_of(inputShape, ShapedType::isDynamic)) |
| return failure(); |
    assert(inputShape.size() == 4 && "expected 4-d NHWC input");
| |
| SmallVector<int64_t, 2> imageDimensions = {1, 2}; |
| const size_t numImageDims = imageDimensions.size(); |
| SmallVector<int64_t> resultShape(6, inputTileSize); |
| llvm::SmallSetVector<int64_t, 2> imageDimensionsSet(imageDimensions.begin(), |
| imageDimensions.end()); |
    for (size_t i = 0; i < inputShape.size(); i++) {
      const size_t outputIndex = i + numImageDims;
      if (!imageDimensionsSet.contains(i)) {
| resultShape[outputIndex] = inputShape[i]; |
| } else { |
| resultShape[outputIndex] = |
| std::ceil((float)(inputShape[i] - kernelSize + 1) / outputTileSize); |
| } |
| } |
| Value emptyTensor = |
| rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType); |
| auto winogradInputOp = |
| rewriter.create<IREE::LinalgExt::WinogradInputTransformOp>( |
| loc, emptyTensor.getType(), ValueRange{input}, |
| ValueRange{emptyTensor}, outputTileSize, kernelSize, |
| imageDimensions); |
| Value winogradInput = winogradInputOp.getResult()[0]; |
| |
| // Add collapse shape |
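    // (i, i, N, H', W', C) -> (i * i, N * H' * W', C)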
| SmallVector<int64_t> collapsedShape = { |
| resultShape[0] * resultShape[1], |
| resultShape[2] * resultShape[3] * resultShape[4], resultShape[5]}; |
| SmallVector<ReassociationIndices> reassociations = {{0, 1}, {2, 3, 4}, {5}}; |
| Value collapsedWinogradInput = createCollapse( |
| winogradInput, loc, rewriter, collapsedShape, reassociations); |
| |
| // Add BatchMatmulOp |
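    // (i * i, N * H' * W', C) x (i * i, C, F) -> (i * i, N * H' * W', F)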
| SmallVector<int64_t> bmmShape(collapsedShape.begin(), collapsedShape.end()); |
| Value output = convOp.getOutputs()[0]; |
| auto outputType = output.getType().cast<RankedTensorType>(); |
| ArrayRef<int64_t> outputShape = outputType.getShape(); |
| bmmShape[2] = outputShape[3]; |
| auto bmmOutputType = RankedTensorType::get(bmmShape, elementType); |
| emptyTensor = rewriter.create<tensor::EmptyOp>(loc, bmmShape, elementType); |
| auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, |
| ValueRange{emptyTensor}); |
| auto bmmOp = rewriter.create<linalg::BatchMatmulOp>( |
| loc, bmmOutputType, ValueRange({collapsedWinogradInput, kernel}), |
| ValueRange({fillOp.result()})); |
| Value bmmResult = bmmOp.getResult(0); |
| |
| // Add expand shape |
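    // (i * i, N * H' * W', F) -> (i, i, N, H', W', F)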
| SmallVector<int64_t> expandedShape = {resultShape[0], resultShape[1], |
| resultShape[2], resultShape[3], |
| resultShape[4], outputShape[3]}; |
| reassociations = {{0, 1}, {2, 3, 4}, {5}}; |
| Value expandedBmmResult = |
| createExpand(bmmResult, loc, rewriter, expandedShape, reassociations); |
| |
| // Convert back into original domain |
| SmallVector<int64_t> paddedResultShape(outputShape.size(), 0); |
    for (size_t i = 0; i < outputShape.size(); i++) {
| if (!imageDimensionsSet.contains(i)) { |
| paddedResultShape[i] = outputShape[i]; |
| } else { |
| paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize; |
| } |
| } |
| emptyTensor = |
| rewriter.create<tensor::EmptyOp>(loc, paddedResultShape, elementType); |
| auto winogradOutputOp = |
| rewriter.create<IREE::LinalgExt::WinogradOutputTransformOp>( |
| loc, emptyTensor.getType(), ValueRange{expandedBmmResult}, |
| ValueRange{emptyTensor}, outputTileSize, kernelSize, |
| imageDimensions); |
| Value paddedOutput = winogradOutputOp.getResult()[0]; |
| |
| // Extract slice |
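    // Crop the padded (N, r * H', r * W', F) result to the exact output shape.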
| SmallVector<OpFoldResult> offsets(outputShape.size(), |
| rewriter.getIndexAttr(0)); |
| SmallVector<OpFoldResult> strides(outputShape.size(), |
| rewriter.getIndexAttr(1)); |
| SmallVector<OpFoldResult> sizes; |
    for (size_t i = 0; i < outputShape.size(); i++)
| sizes.push_back(rewriter.getIndexAttr(outputShape[i])); |
| auto winogradOutput = rewriter.create<tensor::ExtractSliceOp>( |
| loc, outputType, paddedOutput, offsets, sizes, strides); |
| |
    // Replace the op through the rewriter so the pattern driver is notified
    // and the now-dead convolution can be erased.
    rewriter.replaceOp(convOp, winogradOutput.getResult());
| return success(); |
| } |
| }; |
| |
| struct ConvertConv2DToWinogradPass |
| : ConvertConv2DToWinogradBase<ConvertConv2DToWinogradPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry |
| .insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>(); |
| } |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| RewritePatternSet patterns(&getContext()); |
| patterns.insert<FoldWinogradFilterTransform, ConvertConv2DNhwcHwcf>( |
| context); |
| if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| |
| } // namespace |
| |
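// Creates the conversion pass. A minimal usage sketch (assuming an existing
// PassManager `pm` over a module containing func.func ops; the nesting choice
// is illustrative, not prescribed by this file):
//
//   pm.addNestedPass<func::FuncOp>(
//       IREE::LinalgExt::createConvertConv2DToWinogradPass());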
| std::unique_ptr<Pass> createConvertConv2DToWinogradPass() { |
| return std::make_unique<ConvertConv2DToWinogradPass>(); |
| } |
| |
| } // namespace LinalgExt |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |