// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/WinogradConstants.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {

/// Linearizes the 2-d index (y, x) for a row-major array of shape
/// (dimy, dimx).
static inline int index(int y, int x, int dimy, int dimx) {
  return (x + dimx * y);
}

/// Linearizes the 4-d index (z, y, x, w) for a row-major array of shape
/// (dimz, dimy, dimx, dimw).
static inline int index(int z, int y, int x, int w, int dimz, int dimy,
                        int dimx, int dimw) {
  return (w + dimw * (x + dimx * (y + dimy * z)));
}

static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(attr, [](APInt element) { return element.isOne(); });
}

// TODO: Make this a user-settable parameter once we have support
// for more tile sizes
static constexpr int64_t outputTileSize = 6;

/// This function computes the Winograd filter transform when
/// the filter is known to be a constant. Specifically, this
/// function computes matmul(G, matmul(F, transpose(G))) where
/// F is a tile of the convolution filter of size m x m
/// (single input channel, single output channel) and G has
/// shape (m + r - 1) x m, where r is the output tile size and
/// (m + r - 1) is the input tile size.
/// The time complexity of this function is O(ic * oc),
/// where ic is the number of input channels and oc is the
/// number of output channels, since the input tile size and kernel
/// size are constants. For large ic and oc, this function is
/// therefore time intensive.
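/// As an illustrative example (not part of the algorithm itself): for the
/// configuration used by this pass (kernel size m = 3, output tile size
/// r = 6, input tile size i = 8), G is an 8 x 3 matrix, each filter tile F
/// is 3 x 3, and matmul(G, matmul(F, transpose(G))) is the 8 x 8 transformed
/// tile. One such tile is computed for every (input channel, output channel)
/// pair, which is where the O(ic * oc) cost comes from.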
/// TODO: Codegen this as a kernel and run once at initialization
static DenseElementsAttr
foldFilterTransform(ArrayRef<int64_t> shape, int64_t inputTileSize,
                    int64_t kernelSize, Type outputType, const float *G,
                    bool isSplat, float splatValue,
                    DenseElementsAttr::iterator_range<APFloat> &input,
                    FloatType floatType, bool isNchw) {
  const int kh = isNchw ? shape[2] : shape[0];
  const int kw = isNchw ? shape[3] : shape[1];
  const int ic = isNchw ? shape[1] : shape[2];
  const int oc = isNchw ? shape[0] : shape[3];
  const int64_t numElements = inputTileSize * inputTileSize * ic * oc;
  SmallVector<APFloat> output(numElements, APFloat(0.0f));
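  // Each output element is one entry of matmul(G, matmul(F, transpose(G)))
  // for the kernel tile F that belongs to a given (input channel, output
  // channel) pair.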
  for (int d0 = 0; d0 < inputTileSize; d0++) {
    for (int d1 = 0; d1 < inputTileSize; d1++) {
      for (int d2 = 0; d2 < ic; d2++) {
        for (int d3 = 0; d3 < oc; d3++) {
          APFloat accum(0.0f);
          for (int d4 = 0; d4 < kernelSize; d4++) {
            for (int d5 = 0; d5 < kernelSize; d5++) {
              APFloat ival(splatValue);
              if (!isSplat) {
                if (!isNchw) {
                  ival = input[index(d4, d5, d2, d3, kh, kw, ic, oc)];
                } else {
                  ival = input[index(d3, d2, d4, d5, oc, ic, kh, kw)];
                }
              }
              int idx0 = index(d0, d4, inputTileSize, kernelSize);
              int idx1 = index(d1, d5, inputTileSize, kernelSize);
              accum = accum + APFloat(G[idx0]) * ival * APFloat(G[idx1]);
            }
          }
          int odx =
              index(d0, d1, d2, d3, inputTileSize, inputTileSize, ic, oc);
          output[odx] = accum;
          if (floatType.isF16()) {
            bool losesInfo;
            output[odx].convert(APFloat::IEEEhalf(),
                                APFloat::rmNearestTiesToEven, &losesInfo);
          }
        }
      }
    }
  }
  return DenseElementsAttr::get(outputType, output);
}

template <typename T>
static bool hasValidStridesAndDilations(Operation *op) {
  auto convOp = dyn_cast<T>(op);
  // Check that strides = 1
  if (!hasAllOneValues(convOp.getStrides()))
    return false;

  // Check that dilations = 1
  if (!hasAllOneValues(convOp.getDilations()))
    return false;
  return true;
}

static bool isValidConv2d(Operation *op, bool &isNchw) {
  isNchw = isa<linalg::Conv2DNchwFchwOp>(op);
  const bool isNhwc = isa<linalg::Conv2DNhwcHwcfOp>(op);
  if (!(isNchw || isNhwc))
    return false;
  return (isNchw ? hasValidStridesAndDilations<linalg::Conv2DNchwFchwOp>(op)
                 : hasValidStridesAndDilations<linalg::Conv2DNhwcHwcfOp>(op));
}

namespace {

template <typename ConvOp>
class FoldWinogradFilterTransform final : public OpRewritePattern<ConvOp> {
public:
  using OpRewritePattern<ConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvOp convOp,
                                PatternRewriter &rewriter) const override {

    bool isNchw;
    if (!isValidConv2d(convOp, isNchw))
      return failure();

    // Check that kernel size = 3x3
    Value kernel = convOp.getInputs()[1];
    auto kernelType = kernel.getType().dyn_cast<ShapedType>();
    if (!kernelType)
      return failure();
    ArrayRef<int64_t> kernelShape = kernelType.getShape();
    if (kernelShape.size() != 4)
      return failure();
    const int64_t kh = isNchw ? kernelShape[2] : kernelShape[0];
    const int64_t kw = isNchw ? kernelShape[3] : kernelShape[1];
    if ((kh != 3) || (kw != 3))
      return failure();
    const int64_t kernelSize = kh;
    const int64_t inputTileSize = outputTileSize + kernelSize - 1;

    DenseIntOrFPElementsAttr kernelAttr;
    if (!matchPattern(kernel, m_Constant(&kernelAttr))) {
      return failure();
    }

    Operation *constOp = kernel.getDefiningOp();
    ShapedType type = constOp->getResult(0).getType().cast<ShapedType>();
    auto elemType = type.getElementType().dyn_cast<FloatType>();
    if (!elemType)
      return failure();
    ArrayRef<int64_t> shape = type.getShape();
    DenseElementsAttr::iterator_range<APFloat> nonSplatValues =
        kernelAttr.getValues<APFloat>();
    bool isSplat = kernelAttr.isSplat();
    float splatValue{0.0};
    if (isSplat) {
      splatValue = kernelAttr.getSplatValue<APFloat>().convertToFloat();
    }
    SmallVector<int64_t> resultShape{inputTileSize * inputTileSize, shape[2],
                                     shape[3]};
    if (isNchw) {
      resultShape[1] = shape[1];
      resultShape[2] = shape[0];
    }
    auto resultType = RankedTensorType::get(resultShape, elemType);
    auto foldedKernelAttr =
        foldFilterTransform(shape, inputTileSize, kernelSize, resultType,
                            IREE::LinalgExt::Winograd::G_6x6_3x3, isSplat,
                            splatValue, nonSplatValues, elemType, isNchw);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, foldedKernelAttr);
    return success();
  }
};

} // namespace

/// Creates a tensor.collapse_shape op that collapses `tensor` to
/// `outputShape` using the given reassociation indices.
static Value
createCollapse(Value tensor, Location loc, PatternRewriter &rewriter,
               SmallVectorImpl<int64_t> &outputShape,
               SmallVectorImpl<ReassociationIndices> &reassociations) {
  auto tensorType = tensor.getType().cast<ShapedType>();
  auto elementTy = tensorType.getElementType();
  auto resultType = RankedTensorType::get(outputShape, elementTy);
  return rewriter.create<tensor::CollapseShapeOp>(loc, resultType, tensor,
                                                  reassociations);
}

/// Creates a tensor.expand_shape op that expands `tensor` to `outputShape`
/// using the given reassociation indices.
static Value
createExpand(Value tensor, Location loc, PatternRewriter &rewriter,
             SmallVectorImpl<int64_t> &outputShape,
             SmallVectorImpl<ReassociationIndices> &reassociations) {
  auto tensorType = tensor.getType().cast<ShapedType>();
  auto elementTy = tensorType.getElementType();
  auto resultType = RankedTensorType::get(outputShape, elementTy);
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociations);
}

namespace {

/// Convert conv2d to a sequence of ops that implement the
/// Winograd transformation. The Winograd transformation
/// is parameterized by the output tile size (r). The larger
/// the tile size, the greater the computational savings,
/// but this comes at the cost of accuracy.
///
/// For now, we restrict this transform to convolutions
/// where the filter size = 3x3, though extensions to larger
/// filter sizes are possible. We refer to the
/// filter size as (m). The input tile size (i) is defined as
/// m + r - 1. For a given output tile size, the Winograd
/// transformation defines 3 constant matrices:
///
///    B: i x i [used in the input transform]
///    G: i x m [used in the filter transform]
///    A: i x r [used in the output transform]
///
/// The choice of these matrices is not unique and affects
/// the accuracy of the approach.
///
/// Given a convolution of the form
///
///    y = conv2d(x, f)
///
/// where x: (N, H, W, C) | (N, C, H, W)
///       f: (m, m, C, F) | (F, C, m, m)
///
/// this pattern converts the convolution to the following
/// sequence:
///
///    f_winograd = winograd.filter_transform(f) [folded]
///    x_winograd = winograd.input_transform(x)
///    x_winograd_c = collapse(x_winograd)
///    y_winograd = batch_matmul(x_winograd_c, f_winograd)
///    y_winograd_e = expand(y_winograd)
///    y_padded = winograd.output_transform(y_winograd_e)
///    y = extract_slice(y_padded)
///
/// where the dimensions of the tensors above are:
///
///    f_winograd: (i * i, C, F)
///    x_winograd: (i, i, N, H', W', C)
///    x_winograd_c: (i * i, N * H' * W', C)
///    y_winograd: (i * i, N * H' * W', F)
///    y_winograd_e: (i, i, N, H', W', F)
///    y_padded: (N, r * H', r * W', F) | (N, F, r * H', r * W')
///
///    H': ceil((H - m + 1) / r)
///    W': ceil((W - m + 1) / r)
///
/// The Winograd input transform extracts a tile of the input
/// of size i x i and computes matmul(transpose(B), matmul(tile(x), B)).
/// The Winograd filter transform extracts a tile of the filter
/// of size m x m and computes matmul(G, matmul(tile(f), transpose(G))).
/// These two are then combined using elementwise multiplication
/// (which becomes a batch matmul when combining over multiple channels).
/// The Winograd output transform extracts a tile of the result of size
/// i x i and computes matmul(transpose(A), matmul(tile(y_winograd_e), A)).
///
/// For more information and additional references,
/// see here:
///
/// https://github.com/nod-ai/MLIRWinogradTalk/blob/main/MLIRSummit2022.Nodai.Menon.pdf
///
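/// As an illustrative example (derived from the shapes above, not part of
/// the pass itself): with m = 3, r = 6, and therefore i = 8, a convolution
/// with x: (1, 33, 33, 4) and f: (3, 3, 4, 16) has
/// H' = W' = ceil(31 / 6) = 6, and the intermediate tensors are
///
///    f_winograd:   (64, 4, 16)
///    x_winograd:   (8, 8, 1, 6, 6, 4)
///    x_winograd_c: (64, 36, 4)
///    y_winograd:   (64, 36, 16)
///    y_winograd_e: (8, 8, 1, 6, 6, 16)
///    y_padded:     (1, 36, 36, 16)
///    y:            (1, 31, 31, 16)
///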
template <typename ConvOp>
class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
public:
  using OpRewritePattern<ConvOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvOp convOp,
                                PatternRewriter &rewriter) const override {

    bool isNchw;
    if (!isValidConv2d(convOp, isNchw))
      return failure();

    // Check that kernel has been constant folded (by validating rank = 3)
    Value kernel = convOp.getInputs()[1];
    auto kernelType = kernel.getType().dyn_cast<ShapedType>();
    if (!kernelType)
      return failure();
    Type elementType = kernelType.getElementType();
    ArrayRef<int64_t> kernelShape = kernelType.getShape();
    if (kernelShape.size() != 3)
      return failure();

    const int64_t kernelSize = 3;
    const int64_t inputTileSize = outputTileSize + kernelSize - 1;

    // Create winograd input transform op
    Location loc = convOp.getLoc();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    Value input = convOp.getInputs()[0];
    auto inputType = input.getType().dyn_cast<ShapedType>();
    if (!inputType)
      return failure();
    SmallVector<int64_t> inputShape(inputType.getShape());
    if (llvm::any_of(inputShape, ShapedType::isDynamic))
      return failure();
    assert(inputShape.size() == 4);
    if (isNchw) {
      permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(inputShape);
    }

    const std::array<int64_t, 2> nhwcImageDimensions{1, 2};
    const std::array<int64_t, 2> nchwImageDimensions{2, 3};
    const size_t numImageDims = nhwcImageDimensions.size();
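    // The input transform result has shape (i, i, N, H', W', C) (see the
    // doc comment above): the first two dimensions are the input tile
    // dimensions and the remaining four mirror the (NHWC-permuted) input,
    // with the image dimensions replaced by the number of tiles H' and W'.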
    SmallVector<int64_t> resultShape(6, inputTileSize);
    llvm::SmallSetVector<int64_t, 2> imageDimensionsSet(
        nhwcImageDimensions.begin(), nhwcImageDimensions.end());
    for (int i = 0; i < inputShape.size(); i++) {
      int outputIndex = i + numImageDims;
      if (!imageDimensionsSet.contains(i)) {
        resultShape[outputIndex] = inputShape[i];
      } else {
        resultShape[outputIndex] =
            std::ceil((float)(inputShape[i] - kernelSize + 1) / outputTileSize);
      }
    }
    Value emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
    auto &imageDimensions = isNchw ? nchwImageDimensions : nhwcImageDimensions;
    auto winogradInputOp =
        rewriter.create<IREE::LinalgExt::WinogradInputTransformOp>(
            loc, emptyTensor.getType(), ValueRange{input},
            ValueRange{emptyTensor}, outputTileSize, kernelSize,
            imageDimensions);
    Value winogradInput = winogradInputOp.getResult()[0];

    // Add collapse shape
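    // Collapsing (i, i, N, H', W', C) to (i * i, N * H' * W', C) lets the
    // per-tile multiplication with the folded filter be expressed as a
    // single batch matmul below.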
    SmallVector<int64_t> collapsedShape = {
        resultShape[0] * resultShape[1],
        resultShape[2] * resultShape[3] * resultShape[4], resultShape[5]};
    SmallVector<ReassociationIndices> reassociations = {{0, 1}, {2, 3, 4}, {5}};
    Value collapsedWinogradInput = createCollapse(
        winogradInput, loc, rewriter, collapsedShape, reassociations);

    // Add BatchMatmulOp
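    // The kernel operand is the folded filter transform of shape
    // (i * i, C, F), so the batch matmul produces (i * i, N * H' * W', F).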
    SmallVector<int64_t> bmmShape(collapsedShape.begin(), collapsedShape.end());
    Value output = convOp.getOutputs()[0];
    auto outputType = output.getType().cast<RankedTensorType>();
    SmallVector<int64_t> outputShape(outputType.getShape());
    if (isNchw) {
      permute<IREE::LinalgExt::Permutation::NCHW_TO_NHWC>(outputShape);
    }
    bmmShape[2] = outputShape[3];
    auto bmmOutputType = RankedTensorType::get(bmmShape, elementType);
    emptyTensor = rewriter.create<tensor::EmptyOp>(loc, bmmShape, elementType);
    auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zero},
                                                  ValueRange{emptyTensor});
    auto bmmOp = rewriter.create<linalg::BatchMatmulOp>(
        loc, bmmOutputType, ValueRange({collapsedWinogradInput, kernel}),
        ValueRange({fillOp.result()}));
    Value bmmResult = bmmOp.getResult(0);

    // Add expand shape
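    // Expand (i * i, N * H' * W', F) back to (i, i, N, H', W', F) so the
    // output transform can recover the spatial tiling.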
    SmallVector<int64_t> expandedShape = {resultShape[0], resultShape[1],
                                          resultShape[2], resultShape[3],
                                          resultShape[4], outputShape[3]};
    reassociations = {{0, 1}, {2, 3, 4}, {5}};
    Value expandedBmmResult =
        createExpand(bmmResult, loc, rewriter, expandedShape, reassociations);

    // Convert back into original domain
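    // The output transform produces the padded result
    // (N, r * H', r * W', F) (or its NCHW equivalent); the exact convolution
    // result is recovered by the extract_slice below.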
    SmallVector<int64_t> paddedResultShape(outputShape.size(), 0);
    for (int i = 0; i < outputShape.size(); i++) {
      if (!imageDimensionsSet.contains(i)) {
        paddedResultShape[i] = outputShape[i];
      } else {
        paddedResultShape[i] = resultShape[i + numImageDims] * outputTileSize;
      }
    }
    if (isNchw) {
      permute<IREE::LinalgExt::Permutation::NHWC_TO_NCHW>(paddedResultShape);
    }
    emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, paddedResultShape, elementType);
    auto winogradOutputOp =
        rewriter.create<IREE::LinalgExt::WinogradOutputTransformOp>(
            loc, emptyTensor.getType(), ValueRange{expandedBmmResult},
            ValueRange{emptyTensor}, outputTileSize, kernelSize,
            imageDimensions);
    Value paddedOutput = winogradOutputOp.getResult()[0];

    // Extract slice
    SmallVector<OpFoldResult> offsets(outputShape.size(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(outputShape.size(),
                                      rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes;
    for (const int64_t shape : outputType.getShape())
      sizes.push_back(rewriter.getIndexAttr(shape));
    auto winogradOutput = rewriter.create<tensor::ExtractSliceOp>(
        loc, outputType, paddedOutput, offsets, sizes, strides);

    rewriter.replaceOp(convOp, winogradOutput.getResult());
    return success();
  }
};

struct ConvertConv2DToWinogradPass
    : ConvertConv2DToWinogradBase<ConvertConv2DToWinogradPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>();
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.insert<FoldWinogradFilterTransform<linalg::Conv2DNchwFchwOp>,
                    FoldWinogradFilterTransform<linalg::Conv2DNhwcHwcfOp>,
                    ConvertConvToWinograd<linalg::Conv2DNhwcHwcfOp>,
                    ConvertConvToWinograd<linalg::Conv2DNchwFchwOp>>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

} // namespace

std::unique_ptr<Pass> createConvertConv2DToWinogradPass() {
  return std::make_unique<ConvertConv2DToWinogradPass>();
}

} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir