// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-vectorize-conv"

namespace mlir {
namespace iree_compiler {
namespace {
/// Vectorizes linalg.conv for a single GPU invocation. The linalg.conv op
/// must therefore have a very specific form; other patterns are expected to
/// tile and distribute larger convolutions into this form beforehand.
///
/// The linalg.conv op should follow:
/// - Filter: HfWfCiCo format
/// - Input : NHiWiCi format
/// - Output: NHoWoCo format
/// - For output:
/// - N must be 1.
/// - Ho must be 1.
/// - Co must be a multiple of 4.
/// - For filter:
/// - Hf must be 1.
/// - Wf must be 1.
/// - Ci must be 4.
/// - No dilation.
/// - No padding.
///
/// The output channel count is required to be a multiple of 4 so that we can
/// process the channels with load4/store4, which is native to GPUs. The input
/// channel count requirement of 4 exists for the same reason.
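///
/// For example, a conforming op could look like the following (an
/// illustrative sketch only; the actual operand shapes depend on how earlier
/// passes tiled and subviewed the convolution):
///
///   linalg.conv(%filter, %input, %output) {strides = [1, 1]}
///     : memref<1x1x4x8xf32>, memref<1x1x2x4xf32>, memref<1x1x2x8xf32>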
struct VectorizeLinalgConv : OpRewritePattern<linalg::ConvOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::ConvOp convOp,
PatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n");
// This pattern does not handle convolutions with dilation.
if (auto dilations = convOp.dilations()) {
auto values = dilations->getAsValueRange<IntegerAttr>();
if (llvm::any_of(values, [](const APInt &value) {
return value.getSExtValue() != 1;
})) {
return failure();
}
}
auto filterViewOp = convOp.filter().getDefiningOp<SubViewOp>();
auto inputViewOp = convOp.input().getDefiningOp<SubViewOp>();
auto outputViewOp = convOp.output().getDefiningOp<SubViewOp>();
if (!filterViewOp || !inputViewOp || !outputViewOp) return failure();
// The filter/input/output view should have static sizes to vectorize.
if (!llvm::empty(filterViewOp.getDynamicSizes()) ||
!llvm::empty(inputViewOp.getDynamicSizes()) ||
!llvm::empty(outputViewOp.getDynamicSizes())) {
return failure();
}
    // The output batch and height dimensions should be 1. If not, other
    // patterns can generate parallel loops to distribute them.
if (outputViewOp.getStaticSize(0) != 1 ||
outputViewOp.getStaticSize(1) != 1) {
return failure();
}
    // We additionally expect the filter height/width dimensions to both be 1
    // to simplify vectorization. Other patterns can generate loops to create
    // 1x1 filter subviews.
if (filterViewOp.getStaticSize(0) != 1 ||
filterViewOp.getStaticSize(1) != 1) {
return failure();
}
int64_t numInputChannels = filterViewOp.getStaticSize(2);
int64_t numOutputChannels = filterViewOp.getStaticSize(3);
if (numInputChannels != 4 || numOutputChannels % 4 != 0) return failure();
int64_t numOutputWidths = outputViewOp.getStaticSize(2);
int64_t widthStride = convOp.getStride(1);
    // This invocation handles a batch of (numOutputWidths * numOutputChannels)
    // output elements.
LLVM_DEBUG({
llvm::dbgs() << "# output width: " << numOutputWidths << "\n";
llvm::dbgs() << "# output channels: " << numOutputChannels << "\n";
llvm::dbgs() << "width stride: " << widthStride << "\n";
});
MLIRContext *context = convOp.getContext();
Location loc = convOp.getLoc();
Type elementType = filterViewOp.getType().getElementType();
auto filterVectorType =
VectorType::get({numInputChannels, numOutputChannels}, elementType);
auto vector1x4Type = VectorType::get({1, 4}, elementType);
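    // The 1x4 vector shape matches the load4/store4 granularity that the
    // channel size requirements above are designed for.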
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
// Load the entire filter subview.
SmallVector<Value, 4> filterIndices(4, zero);
Value wholeFilter = rewriter.create<vector::TransferReadOp>(
loc, filterVectorType, filterViewOp, filterIndices);
    // Get filter slices so that later we can use them for dot products with
    // the input. Both the height and width dimensions are 1, so we just need
    // to loop over the input and output channel dimensions.
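    // E.g., with numInputChannels = 4 and numOutputChannels = 8, this yields
    // eight 1x4 slices, taken at offsets [ic, 0] and [ic, 4] for ic in 0..3.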
SmallVector<SmallVector<Value, 1>, 4> filterVectors(numInputChannels);
for (int ic = 0; ic < numInputChannels; ++ic) {
auto &thisInputChannel = filterVectors[ic];
thisInputChannel.reserve(numOutputChannels / 4);
for (int oc = 0; oc < numOutputChannels / 4; ++oc) {
Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, wholeFilter, /*offsets=*/ArrayRef<int64_t>({ic, oc * 4}),
/*sizes=*/ArrayRef<int64_t>({1, 4}),
/*strides=*/ArrayRef<int64_t>({1, 1}));
thisInputChannel.push_back(slice);
}
}
    // Build indexing maps for a later vector contraction op. The three maps
    // encode a matmul-style contraction C[M, N] += A[M, K] * B[K, N]; here
    // M = K = 1 and N = 4, i.e., one 1x1 input slice times a 1x4 filter
    // slice, accumulated into a 1x4 output slice.
AffineExpr dim0 = getAffineDimExpr(0, context); // M
AffineExpr dim1 = getAffineDimExpr(1, context); // N
AffineExpr dim2 = getAffineDimExpr(2, context); // K
auto map02 = AffineMap::get(3, 0, {dim0, dim2}, context);
auto map21 = AffineMap::get(3, 0, {dim2, dim1}, context);
auto map01 = AffineMap::get(3, 0, {dim0, dim1}, context);
ArrayAttr indexingMaps =
rewriter.getAffineMapArrayAttr({map02, map21, map01});
// Also build iterator types for the vector contraction op.
ArrayAttr iterators = rewriter.getStrArrayAttr(
{getParallelIteratorTypeName(), getParallelIteratorTypeName(),
getReductionIteratorTypeName()});
    // Compute the (numOutputWidths * numOutputChannels) batch. Each output
    // vector only accumulates numInputChannels products along the reduction
    // dimension, so read in the current result from the output, compose a
    // chain of numInputChannels vector dot operations, and then write it back
    // out; see the sketch below.
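    //
    // A rough sketch of the IR emitted below, assuming f32 elements (value
    // names like %filter_i_oc are illustrative, not produced identifiers):
    //   %in  = vector.transfer_read %input[...]  : ..., vector<1x4xf32>
    //   // then, for each output channel group oc:
    //   %out = vector.transfer_read %output[...] : ..., vector<1x4xf32>
    //   // and for each input channel i:
    //   %s   = vector.extract_strided_slice %in {offsets = [0, i], ...}
    //   %out = vector.contract {...} %s, %filter_i_oc, %out
    //   vector.transfer_write %out, %output[...]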
for (int ow = 0; ow < numOutputWidths; ++ow) {
      // Read in the input vector for these 4 input channels as a batch. The
      // input vector is used for computing all output channels, so the data
      // can be reused.
SmallVector<Value, 4> inputIndices(4, zero);
inputIndices[2] = rewriter.create<ConstantIndexOp>(loc, ow * widthStride);
Value inputVector = rewriter.create<vector::TransferReadOp>(
loc, vector1x4Type, inputViewOp, inputIndices);
for (int oc = 0; oc < numOutputChannels / 4; ++oc) {
// Read in the initial value for this output vector.
SmallVector<Value, 4> outputIndices(4, zero);
outputIndices[2] = rewriter.create<ConstantIndexOp>(loc, ow);
outputIndices[3] = rewriter.create<ConstantIndexOp>(loc, oc * 4);
Value outputVector = rewriter.create<vector::TransferReadOp>(
loc, vector1x4Type, outputViewOp, outputIndices);
        // Perform a chain of dot products and accumulations.
for (int i = 0; i < numInputChannels; ++i) {
auto inputSlice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, inputVector, /*offsets=*/ArrayRef<int64_t>({0, i}),
/*sizes=*/ArrayRef<int64_t>({1, 1}),
/*strides=*/ArrayRef<int64_t>({1, 1}));
outputVector = rewriter.create<vector::ContractionOp>(
loc, inputSlice, filterVectors[i][oc], outputVector, indexingMaps,
iterators);
}
// Write out the output vector.
rewriter.create<vector::TransferWriteOp>(loc, outputVector,
outputViewOp, outputIndices);
}
}
rewriter.eraseOp(convOp);
return success();
}
};
struct VectorizeLinalgConvPass
: public PassWrapper<VectorizeLinalgConvPass, OperationPass<FuncOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
OwningRewritePatternList patterns;
patterns.insert<VectorizeLinalgConv>(context);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
void populateVectorizeLinalgConvPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns.insert<VectorizeLinalgConv>(context);
}
std::unique_ptr<Pass> createVectorizeLinalgConvPass() {
return std::make_unique<VectorizeLinalgConvPass>();
}
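
// With MLIR's standard pass registration, the registration below makes this
// pass invocable from *-opt style command-line tools (e.g., iree-opt) under
// the name given here.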
static PassRegistration<VectorizeLinalgConvPass> pass(
"iree-codegen-vectorize-linalg-conv",
"Vectorize a very specific form of linalg.conv");
} // namespace iree_compiler
} // namespace mlir