Adapt mhlo.conv to Linalg on buffers patterns to be on tensors. (#4748)
Lowers a conv op to a static-rank conv op. The general conv op is only
supported on memrefs. This limits conv op support to a specific data
layout and only works for a few ranks, i.e., 1-D, 2-D, and 3-D. This
separates the lowering into normal conv and depthwise conv. The former
is lowered to Linalg on tensors, and the latter keeps its existing
behavior, i.e., lowering to Linalg on buffers.
Adds support for TC conv ops in the later GPU passes.
Depends on https://reviews.llvm.org/D96038
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 1086851..d972351 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -325,13 +325,6 @@
ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
- llvm::SmallVector<Attribute, 4> strides;
- if (auto windowStrides = op.window_strides()) {
- auto range = windowStrides->getAttributeValues();
- strides.append(range.begin(), range.end());
- }
- auto stridesArg = ArrayAttr::get(op.getContext(), strides);
-
// TODO(ataei): Only support dilated convolution for now. We need to consider
// LHS dilation for deconvolution cases.
llvm::SmallVector<Attribute, 4> dilation;
@@ -339,7 +332,6 @@
auto range = rhsDilation->getAttributeValues();
dilation.append(range.begin(), range.end());
}
- auto dilationArg = ArrayAttr::get(op.getContext(), dilation);
// Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr();
@@ -361,92 +353,88 @@
shape[op.dimension_numbers().kernel_input_feature_dimension().getInt()];
auto groupSize =
shape[op.dimension_numbers().kernel_output_feature_dimension().getInt()];
- // Depthwise conv path...
- if (op.feature_group_count() > 1u && op.feature_group_count() == numGroups) {
- // Lowering depthwise convolution to linalg.generic op. The idea is to use
- // the group convolution formulation to perform the separable depthwise
- // convolution as the following, given an n-dimensional input x and filter w
- // the direct convolution operation can be written as:
- // y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
- // x[n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn]
- // * w[k1, k2, ...kn, ci, co])
-
- // TODO(ataei): Support dilation.
- if (llvm::any_of(dilation, [](Attribute attr) {
- return (attr.dyn_cast<IntegerAttr>().getInt() != 1);
- })) {
- return failure();
- }
-
- SmallVector<AffineExpr, 4> inputExprs;
- SmallVector<AffineExpr, 4> filterExprs;
- SmallVector<AffineExpr, 4> outputExprs;
-
- const auto spatialDims =
- llvm::size(op.dimension_numbers().input_spatial_dimensions());
- const int d1Index = 1;
- const int coIndex = d1Index + spatialDims;
- const int ciIndex = coIndex + 1;
- const int k1Index = ciIndex + 1;
- // n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn
- inputExprs.push_back(rewriter.getAffineDimExpr(0));
- for (int i = 0; i < spatialDims; ++i) {
- if (op.window_stridesAttr()) {
- auto stride = op.window_stridesAttr().getValue<APInt>(i);
- inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
- stride.getZExtValue() +
- rewriter.getAffineDimExpr(k1Index + i));
- } else {
- inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
- rewriter.getAffineDimExpr(k1Index + i));
- }
- }
- inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
-
- // k1, k2, ...kn, ci, co
- for (int i = 0; i < spatialDims; ++i) {
- filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
- }
- filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
- filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));
-
- // n, d1, d2, ....dn, ci * groupSize + co
- outputExprs.push_back(rewriter.getAffineDimExpr(0));
- for (int i = 0; i < spatialDims; ++i) {
- outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
- }
- outputExprs.push_back(rewriter.getAffineDimExpr(ciIndex) * groupSize +
- rewriter.getAffineDimExpr(coIndex));
-
- // nloops = |d| + |k| + |{n, ci, co}|
- int nloops = spatialDims * 2 + 3;
- SmallVector<AffineMap, 4> indexingMaps;
- indexingMaps.emplace_back(AffineMap::get(
- nloops, /*symbolCount=*/0, inputExprs, rewriter.getContext()));
- indexingMaps.emplace_back(AffineMap::get(
- nloops, /*symbolCount=*/0, filterExprs, rewriter.getContext()));
- indexingMaps.emplace_back(AffineMap::get(
- nloops, /*symbolCount=*/0, outputExprs, rewriter.getContext()));
-
- Location loc = op.getLoc();
-
- SmallVector<StringRef, 3> loopAttributeTypes(spatialDims + 3, "parallel");
- loopAttributeTypes.append(spatialDims, "reduction");
- rewriter.create<linalg::GenericOp>(
- loc,
- /*resultTensorTypes=*/ArrayRef<Type>{},
- /*inputs=*/inputBuffers,
- /*outputs=*/resultBuffers, indexingMaps, loopAttributeTypes,
- [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
- Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
- nestedBuilder.create<linalg::YieldOp>(loc, add);
- });
- } else {
- rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1],
- inputBuffers[0], resultBuffers[0],
- stridesArg, dilationArg, padding);
+ if (op.feature_group_count() <= 1u || op.feature_group_count() != numGroups) {
+ return failure();
}
+ // Lowering depthwise convolution to linalg.generic op. The idea is to use
+ // the group convolution formulation to perform the separable depthwise
+ // convolution as the following, given an n-dimensional input x and filter w
+ // the direct convolution operation can be written as:
+ // y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
+ // x[n, d1 * stride1 + k1, d2 * stride2 + k2, ...dn * striden + kn]
+ // * w[k1, k2, ...kn, ci, co])
+
+ // TODO(ataei): Support dilation.
+ if (llvm::any_of(dilation, [](Attribute attr) {
+ return (attr.dyn_cast<IntegerAttr>().getInt() != 1);
+ })) {
+ return failure();
+ }
+
+ SmallVector<AffineExpr, 4> inputExprs;
+ SmallVector<AffineExpr, 4> filterExprs;
+ SmallVector<AffineExpr, 4> outputExprs;
+
+ const auto spatialDims =
+ llvm::size(op.dimension_numbers().input_spatial_dimensions());
+ const int d1Index = 1;
+ const int coIndex = d1Index + spatialDims;
+ const int ciIndex = coIndex + 1;
+ const int k1Index = ciIndex + 1;
+ // n, d1 * stride1 + k1, d2 * stride2 + k2, ...dn * striden + kn
+ inputExprs.push_back(rewriter.getAffineDimExpr(0));
+ for (int i = 0; i < spatialDims; ++i) {
+ if (op.window_stridesAttr()) {
+ auto stride = op.window_stridesAttr().getValue<APInt>(i);
+ inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
+ stride.getZExtValue() +
+ rewriter.getAffineDimExpr(k1Index + i));
+ } else {
+ inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
+ rewriter.getAffineDimExpr(k1Index + i));
+ }
+ }
+ inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+
+ // k1, k2, ...kn, ci, co
+ for (int i = 0; i < spatialDims; ++i) {
+ filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
+ }
+ filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+ filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));
+
+ // n, d1, d2, ....dn, ci * groupSize + co
+ outputExprs.push_back(rewriter.getAffineDimExpr(0));
+ for (int i = 0; i < spatialDims; ++i) {
+ outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
+ }
+ outputExprs.push_back(rewriter.getAffineDimExpr(ciIndex) * groupSize +
+ rewriter.getAffineDimExpr(coIndex));
+
+ // nloops = |d| + |k| + |{n, ci, co}|
+ int nloops = spatialDims * 2 + 3;
+ SmallVector<AffineMap, 4> indexingMaps;
+ indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+ inputExprs, rewriter.getContext()));
+ indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+ filterExprs, rewriter.getContext()));
+ indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+ outputExprs, rewriter.getContext()));
+
+ Location loc = op.getLoc();
+
+ SmallVector<StringRef, 3> loopAttributeTypes(spatialDims + 3, "parallel");
+ loopAttributeTypes.append(spatialDims, "reduction");
+ rewriter.create<linalg::GenericOp>(
+ loc,
+ /*resultTensorTypes=*/ArrayRef<Type>{},
+ /*inputs=*/inputBuffers,
+ /*outputs=*/resultBuffers, indexingMaps, loopAttributeTypes,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
+ Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
+ nestedBuilder.create<linalg::YieldOp>(loc, add);
+ });
return success();
}
@@ -838,20 +826,26 @@
}
};
-/// Converts linalg.matmul-ish on tensors to linalg.matmul-ish on buffers.
+/// Converts a linalg named op on tensors to linalg named op on buffers.
template <typename LinalgOpTy>
-struct MatmulOnTensorConversion
- : public ConvertToLinalgBufferOp<MatmulOnTensorConversion<LinalgOpTy>,
+struct NamedOpConversion
+ : public ConvertToLinalgBufferOp<NamedOpConversion<LinalgOpTy>,
LinalgOpTy> {
- using ConvertToLinalgBufferOp<MatmulOnTensorConversion<LinalgOpTy>,
+ using ConvertToLinalgBufferOp<NamedOpConversion<LinalgOpTy>,
LinalgOpTy>::ConvertToLinalgBufferOp;
LogicalResult apply(LinalgOpTy op, ArrayRef<Value> inputBuffers,
ArrayRef<Value> resultBuffers,
ConversionPatternRewriter &rewriter) const {
if (!op.hasTensorSemantics()) return failure();
- // The last one is a init tensor.
- rewriter.create<LinalgOpTy>(
- op.getLoc(), inputBuffers.drop_back(op.getNumResults()), resultBuffers);
+ auto linalgOp = cast<linalg::LinalgOp>(op.getOperation());
+ SmallVector<Value, 8> newOperands;
+ newOperands.append(inputBuffers.begin(),
+ inputBuffers.end() - op.getNumResults());
+ newOperands.append(resultBuffers.begin(), resultBuffers.end());
+ auto otherOperands = linalgOp.getAssumedNonShapedOperands();
+ newOperands.append(otherOperands.begin(), otherOperands.end());
+ Location loc = op.getLoc();
+ linalgOp.clone(rewriter, loc, /*resultTypes=*/TypeRange{}, newOperands);
return success();
}
};
@@ -1244,12 +1238,16 @@
void populateHLOToLinalgOnBuffersConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
TensorToBufferMap const &resultTensorToBufferMap) {
- patterns.insert<ConvOpConversion, ConcatenateOpConversion,
- FillOpOnTensorConversion, InitTensorOpConversion,
+ patterns.insert<ConvOpConversion, DepthwiseConvOpConversion,
+ ConcatenateOpConversion, FillOpOnTensorConversion,
+ InitTensorOpConversion,
LinalgOpOnTensorConversion<linalg::GenericOp>,
LinalgOpOnTensorConversion<linalg::IndexedGenericOp>,
- MatmulOnTensorConversion<linalg::MatmulOp>,
- MatmulOnTensorConversion<linalg::BatchMatmulOp>,
+ NamedOpConversion<linalg::ConvInputNWCFilterWCFOp>,
+ NamedOpConversion<linalg::ConvInputNHWCFilterHWCFOp>,
+ NamedOpConversion<linalg::ConvInputNDHWCFilterDHWCFOp>,
+ NamedOpConversion<linalg::MatmulOp>,
+ NamedOpConversion<linalg::BatchMatmulOp>,
PadTensorOpConversion, ReduceWindowOpConversion,
SubTensorOpConversion, TensorReshapeOpConversion>(
context, resultTensorToBufferMap);
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index d088a14..c145e90 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -251,6 +251,156 @@
};
} // namespace
+//===----------------------------------------------------------------------===//
+// mhlo.conv conversion patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+static bool isDepthwiseConv(mhlo::ConvOp op) {
+ auto shape = op.rhs().getType().cast<ShapedType>().getShape();
+ auto numGroups =
+ shape[op.dimension_numbers().kernel_input_feature_dimension().getInt()];
+ return op.feature_group_count() > 1u && op.feature_group_count() == numGroups;
+}
+
+/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
+/// follows a canonical form:
+///
+/// * Input dimensions have order: (batch_count, spatial_dims,
+/// input_channel_count).
+/// * Filter dimensions have order: (spatial_dims, input_channel_count,
+/// output_channel_count).
+/// * Output dimensions have order: (batch_count, spatial_dims,
+/// output_channel_count).
+static bool hasCanonicalDimensionNumbers(
+ const mhlo::ConvDimensionNumbers &dimensionNumbers) {
+ const int inputSpatialRank =
+ llvm::size(dimensionNumbers.input_spatial_dimensions());
+ // The dimensions for input should follow the order of
+ // batch_count, spatial_dims..., input_feature_count.
+ if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
+ dimensionNumbers.input_feature_dimension().getInt() !=
+ (inputSpatialRank + 1)) {
+ return false;
+ }
+
+ const int kernelSpatialRank =
+ llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+ // The dimensions for filter should follow the order of
+ // spatial_dims..., input_feature_count, num_output_feature_count.
+ if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
+ kernelSpatialRank ||
+ dimensionNumbers.kernel_output_feature_dimension().getInt() !=
+ (kernelSpatialRank + 1)) {
+ return false;
+ }
+
+ const int outputSpatialRank =
+ llvm::size(dimensionNumbers.output_spatial_dimensions());
+ // The dimensions for output should follow the order of
+ // batch_count, spatial_dims.., output_feature_count.
+ if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
+ dimensionNumbers.output_feature_dimension().getInt() !=
+ (outputSpatialRank + 1)) {
+ return false;
+ }
+
+ if (inputSpatialRank != outputSpatialRank ||
+ inputSpatialRank != kernelSpatialRank) {
+ return false;
+ }
+
+ auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin();
+ auto kernelSpatialDim = dimensionNumbers.kernel_spatial_dimensions().begin();
+ auto outputSpatialDim = dimensionNumbers.output_spatial_dimensions().begin();
+ // Check spatial dims are ordered correctly.
+ for (int i = 0; i < inputSpatialRank; ++i) {
+ const int dim = i + 1;
+ if ((*inputSpatialDim++).getZExtValue() != dim ||
+ (*outputSpatialDim++).getZExtValue() != dim ||
+ (*kernelSpatialDim++).getZExtValue() != i) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+/// Converts mhlo.conv operation to linalg named op. This only covers normal
+/// convolution cases. The op must have canonical dimension numbers. Depthwise
+/// convolution and pointwise convolution are not handled in the conversion.
+struct NormalConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
+ using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::ConvOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
+ if (isDepthwiseConv(op)) return failure();
+
+ mhlo::ConvOp::Adaptor adaptor(args);
+ Location loc = op.getLoc();
+ Value input = adaptor.lhs();
+ Value filter = adaptor.rhs();
+ auto resultType = op.getResult().getType().cast<ShapedType>();
+ int rank = resultType.getRank();
+
+ // Check if padding is zero.
+ DenseIntElementsAttr padding = op.paddingAttr();
+ if (padding &&
+ (!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
+ return rewriter.notifyMatchFailure(op, "expected no padding");
+ }
+
+ // The output shape is N spatial_dims F.
+ SmallVector<Value, 8> dynSizes;
+ for (int i = 0, e = rank - 1; i < e; ++i) {
+ if (!resultType.isDynamicDim(i)) continue;
+ dynSizes.push_back(rewriter.create<DimOp>(loc, input, i));
+ }
+ if (resultType.isDynamicDim(rank - 1)) {
+ dynSizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
+ }
+ Value initTensor = rewriter.create<linalg::InitTensorOp>(
+ loc, dynSizes, resultType.getShape(), resultType.getElementType());
+ auto zeroAttr = rewriter.getZeroAttr(resultType.getElementType());
+ Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
+ Value zeroTensor =
+ rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
+ linalg::LinalgOp res;
+ Attribute strides = op.window_stridesAttr();
+ // TODO(ataei): Only support dilated kernel right now. We need to consider
+ // input dilation for deconvolution cases.
+ Attribute dilations = op.rhs_dilationAttr();
+ switch (rank) {
+ case 3: {
+ res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
+ loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+ dilations, strides);
+ break;
+ }
+ case 4: {
+ res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
+ loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+ dilations, strides);
+ break;
+ }
+ case 5: {
+ res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
+ loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+ dilations, strides);
+ break;
+ }
+ default:
+ return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
+ }
+ rewriter.replaceOp(op, res.getOperation()->getResults());
+ return success();
+ }
+};
+} // namespace
+
struct ConvertHLOToLinalgOnTensorsPass
: public PassWrapper<ConvertHLOToLinalgOnTensorsPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -295,7 +445,8 @@
MLIRContext *context, OwningRewritePatternList &patterns) {
mhlo::populateHLOToLinalgConversionPattern(context, &patterns);
patterns.insert<TorchIndexSelectOpConversion, SliceOpConversion,
- ConstOpConversion, PadOpConversion>(context);
+ ConstOpConversion, PadOpConversion, NormalConvOpConversion>(
+ context);
}
std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnTensorsPass() {
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir b/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
index 4006f02..008b380 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
@@ -1,123 +1,161 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+// RUN: iree-opt -iree-codegen-hlo-to-linalg-on-tensors -canonicalize %s | IreeFileCheck %s
-module {
- // CHECK-LABEL: func @conv
- func @conv() {
- %c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<3x5x5x3xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x3x4xf32>
- // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) {
- // CHECK-SAME: dilations = [1, 2]
- // CHECK-SAME: padding = dense<[
- // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
- // CHECK-SAME: strides = [2, 1]}
- %2 = "mhlo.convolution"(%1, %0) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
- },
- feature_group_count = 1 : i64,
- padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
- rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
- window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
- hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<3x5x5x4xf32>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
+// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
+ -> tensor<?x?x?xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 2 : i64,
+ input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+ kernel_input_feature_dimension = 1 : i64,
+ kernel_output_feature_dimension = 2 : i64,
+ kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 2 : i64,
+ output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<[[0], [0]]> : tensor<2x1xi64>,
+ rhs_dilation = dense<1> : tensor<1xi64>,
+ window_strides = dense<1> : tensor<1xi64>
+ } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
}
-// -----
-
-module {
- func @depthwise_conv() {
- %c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x4x5x2xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x2x3xf32>
- %2 = "mhlo.convolution"(%0, %1) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
- },
- feature_group_count = 2 : i64,
- padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
- hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x3x4x6xf32>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
+// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
+ -> tensor<?x?x?x?xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)>
-// CHECK: func @depthwise_conv()
-// CHECK: linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]
-// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
-// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<2x3x4x6xf32>)
-// CHECK: mulf
-// CHECK: addf
-// -----
-
-module {
- func @depthwise_conv_multiplier_1() {
- %c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x113x113x96xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<3x3x1x96xf32>
- %2 = "mhlo.convolution"(%0, %1) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
- },
- feature_group_count = 96 : i64,
- padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
- hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<1x56x56x96xf32>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
+// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
+ -> tensor<?x?x?x?x?xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 4 : i64,
+ input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+ kernel_input_feature_dimension = 3 : i64,
+ kernel_output_feature_dimension = 4 : i64,
+ kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 4 : i64,
+ output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
+ rhs_dilation = dense<1> : tensor<3xi64>,
+ window_strides = dense<1> : tensor<3xi64>
+ } : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-// CHECK: func @depthwise_conv_multiplier_1()
-// CHECK: linalg.fill
-// CHECK: %[[FILTER:.+]] = linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] : memref<3x3x1x96xf32> into memref<3x3x96xf32>
-// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%{{.+}}, %[[FILTER]] : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%{{.+}} : memref<1x56x56x96xf32>)
+
+// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
+func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>)
+ -> tensor<1x2x4x3xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
+ return %0 : tensor<1x2x4x3xf32>
+}
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir b/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir
new file mode 100644
index 0000000..d68f749
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir
@@ -0,0 +1,83 @@
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+
+module {
+ func @depthwise_conv() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x4x5x2xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x2x3xf32>
+ %2 = "mhlo.convolution"(%0, %1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 2 : i64,  // Grouped conv (> 1 group); the checks after this module expect a linalg.generic on memrefs.
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
+ hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x3x4x6xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)>
+// CHECK: func @depthwise_conv()
+// CHECK: linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<2x3x4x6xf32>)
+// CHECK: mulf
+// CHECK: addf
+
+// -----
+
+module {
+ func @depthwise_conv_multiplier_1() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x113x113x96xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<3x3x1x96xf32>
+ %2 = "mhlo.convolution"(%0, %1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 96 : i64,  // One group per input feature (dim 3 of %0 is 96); checks expect the named depthwise op.
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
+ hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<1x56x56x96xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK: func @depthwise_conv_multiplier_1()
+// CHECK: linalg.fill
+// CHECK: %[[FILTER:.+]] = linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] : memref<3x3x1x96xf32> into memref<3x3x96xf32>
+// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%{{.+}}, %[[FILTER]] : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%{{.+}} : memref<1x56x56x96xf32>)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index bf27373..2ec5408 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -771,22 +771,25 @@
OwningRewritePatternList patterns;
- patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
- MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
- MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
- MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
- MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
- MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
- MapLinalgOpToLocalInvocationId<linalg::FillOp>,
- MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
- MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
- MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
- MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
- MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
- MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
- MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>,
- RemoveLinalgRange, SerializeParallelLoopPattern>(
- context, options.useLinalgOnTensors);
+ patterns.insert<
+ MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvInputNWCFilterWCFOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvInputNHWCFilterHWCFOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvInputNDHWCFilterDHWCFOp>,
+ MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
+ MapLinalgOpToLocalInvocationId<linalg::FillOp>,
+ MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
+ MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>, RemoveLinalgRange,
+ SerializeParallelLoopPattern>(context, options.useLinalgOnTensors);
FrozenRewritePatternList frozenPatterns(std::move(patterns));
for (FuncOp funcOp : getOperation().getOps<FuncOp>()) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 1c4bcd5..885dbe3 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -350,13 +350,27 @@
std::array<int64_t, 3> workgroupSize;
};
-static LogicalResult getMaliSpecificConfig(linalg::ConvOp op,
+/// Tries to pick a Mali-specific configuration for the conv-like `op`. Bails
+/// out early when the shapes are dynamic, the op is not an NHWC conv, or
+/// dimension 3 of the input is not a multiple of 4.
+template <typename ConvOpTy>
+static LogicalResult getMaliSpecificConfig(ConvOpTy op,
TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
- auto inputType = op.getInput(1).getType().cast<MemRefType>();
- auto outputType = op.getOutputBufferTypes()[0].cast<MemRefType>();
+ // linalg.conv orders its operands (filter, input, output), whereas the named
+ // conv ops take ins(input, filter). Pick the image operand accordingly so the
+ // channel check below inspects the input rather than the filter.
+ const unsigned inputIndex = isa<linalg::ConvOp>(op.getOperation()) ? 1 : 0;
+ auto inputType =
+ op.getInput(inputIndex).getType().template cast<MemRefType>();
+ auto outputType = op.getOutputBufferTypes()[0].template cast<MemRefType>();
if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
return failure();
+ // Only support NHWC conv.
+ if (!isa<linalg::ConvOp, linalg::ConvInputNHWCFilterHWCFOp>(
+ op.getOperation())) {
+ return failure();
+ }
bool isInputTilable = inputType.getDimSize(3) % 4 == 0;
if (!isInputTilable) return failure();
@@ -406,12 +412,11 @@
return failure();
}
-template <>
-LogicalResult getOpLaunchConfig(linalg::ConvOp op,
- const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
+template <typename T>
+LogicalResult getConvOpLaunchConfig(T op, const spirv::TargetEnv &targetEnv,
+ const SPIRVCodegenOptions &options,
+ TileSizesListType &tileSizes,
+ LaunchConfigInfo &config) {
if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
return success();
@@ -428,6 +433,22 @@
return success();
}
+#define GET_CONV_LAUNCH_CONFIG(opType) \
+ template <> \
+ LogicalResult getOpLaunchConfig( \
+ opType op, const spirv::TargetEnv &targetEnv, \
+ const SPIRVCodegenOptions &options, TileSizesListType &tileSizes, \
+ LaunchConfigInfo &config) { \
+ return getConvOpLaunchConfig(op, targetEnv, options, tileSizes, config); \
+ }
+
+GET_CONV_LAUNCH_CONFIG(linalg::ConvOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNWCFilterWCFOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNHWCFilterHWCFOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp)
+
+#undef GET_CONV_LAUNCH_CONFIG
+
static LogicalResult getMaliSpecificConfig(
linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
@@ -583,6 +604,9 @@
DISPATCH(linalg::BatchMatmulOp)
DISPATCH(linalg::ConvOp)
DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCOp)
+ DISPATCH(linalg::ConvInputNWCFilterWCFOp)
+ DISPATCH(linalg::ConvInputNHWCFilterHWCFOp)
+ DISPATCH(linalg::ConvInputNDHWCFilterDHWCFOp)
DISPATCH(linalg::MatmulOp)
DISPATCH(linalg::PoolingMaxOp)
DISPATCH(linalg::PoolingMinOp)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index c893944..04b9219 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -159,13 +159,14 @@
// 2) Maybe there are better alternatives for handling filter like using
// different storage classes, since for inference workloads these are model
// constants. This is TBD.
+template <typename ConvOpTy>
struct PromoteConvSubviewsPattern
- : public linalg::LinalgPromotionPattern<linalg::ConvOp> {
+ : public linalg::LinalgPromotionPattern<ConvOpTy> {
PromoteConvSubviewsPattern(MLIRContext *context,
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
- : linalg::LinalgPromotionPattern<linalg::ConvOp>(
+ : linalg::LinalgPromotionPattern<ConvOpTy>(
context,
options.setOperandsToPromote({1}).setUseFullTileBuffers(
{false, false}),
@@ -175,7 +176,11 @@
static void populatePromotionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
- patterns.insert<PromoteMatmulSubviewsPattern, PromoteConvSubviewsPattern>(
+ patterns.insert<
+ PromoteMatmulSubviewsPattern, PromoteConvSubviewsPattern<linalg::ConvOp>,
+ PromoteConvSubviewsPattern<linalg::ConvInputNWCFilterWCFOp>,
+ PromoteConvSubviewsPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ PromoteConvSubviewsPattern<linalg::ConvInputNDHWCFilterDHWCFOp>>(
context,
linalg::LinalgPromotionOptions()
.setAllocationDeallocationFns(allocateWorkgroupMemory,
@@ -313,6 +318,9 @@
patterns.insert<
linalg::LinalgTilingPattern<linalg::ConvOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
@@ -431,8 +439,12 @@
.setInterchange(loopOrder)
.setTileSizeComputationFunction(getTileSizeFn);
- patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
- context, convTilingOptions, marker);
+ patterns
+ .insert<linalg::LinalgTilingPattern<linalg::ConvOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>>(
+ context, convTilingOptions, marker);
}
//====---------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index d314304..9f6c443 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -85,7 +85,11 @@
linalg::LinalgDependenceGraph::DependenceType::RAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::BatchMatmulOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
- ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvOp,
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNWCFilterWCFOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNHWCFilterHWCFOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNDHWCFilterDHWCFOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index f937cd6..d201d0e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -354,9 +354,12 @@
%26 = dim %arg2, %c3 : memref<?x?x?x?xf32>
%27 = subview %arg2[%arg3, %arg4, %arg5, 0] [%23, %24, %25, %26] [1, 1, 1, 1]
: memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
- linalg.conv(%arg0, %21, %27)
- {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]}
- : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32, #map5>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ __internal_linalg_transform__ = "workgroup",
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%21, %arg0 : memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32>)
+ outs(%27 : memref<?x?x?x?xf32, #map5>)
scf.yield
}
return
@@ -407,7 +410,7 @@
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
-// CHECK-NOT: linalg.conv
+// CHECK-NOT: linalg.conv_2d_input_nhwc_filter_hwcf
// -----
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index 5278cdd..9e0fa44 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -1,29 +1,71 @@
// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-codegen-split-dispatch-function -verify-diagnostics %s | IreeFileCheck %s
module {
- // CHECK: func @kernel_fusable_fill_conv_ops
+ // CHECK: func @kernel_fusable_fill_conv1d_ops
// CHECK: linalg.fill
// CHECK-NOT: return
- // CHECK: linalg.conv
+ // CHECK: linalg.conv_1d_input_nwc_filter_wcf
// CHECK: return
- func @kernel_fusable_fill_conv_ops()
+ func @kernel_fusable_fill_conv1d_ops()
attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
- %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,512]>
+ %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x512xf32>, !shapex.ranked_shape<[?,3,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x512x1xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x512xf32>
+ %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x512xf32>, !shapex.ranked_shape<[?,1,512]>
+ linalg.fill(%ts2, %cst) : memref<?x1x512xf32>, f32
+ linalg.conv_1d_input_nwc_filter_wcf {
+ dilations = dense<1> : tensor<1xi64>,
+ strides = dense<2> : tensor<1xi64>}
+ ins(%ts1, %1 : memref<?x3x512xf32>, memref<3x512x1xf32>)
+ outs(%ts2 : memref<?x1x512xf32>)
+ return
+ }
+ func private @kernel_fusable_fill_conv_ops_num_workgroups__
+ (!shapex.ranked_shape<[?,3,512]>, !shapex.ranked_shape<[3,512,1]>,
+ !shapex.ranked_shape<[?,1,512]>) -> (index, index, index)
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func @kernel_fusable_fill_conv2d_ops
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+ // CHECK: return
+
+ func @kernel_fusable_fill_conv2d_ops()
+ attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %dim = hal.interface.load.constant offset = 0 : index
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
%shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
%ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
- linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x512xf32>)
return
}
func private @kernel_fusable_fill_conv_ops_num_workgroups__
- (!shapex.ranked_shape<[?,2,2,512]>, !shapex.ranked_shape<[3,3,512,1]>,
+ (!shapex.ranked_shape<[?,3,3,512]>, !shapex.ranked_shape<[3,3,512,1]>,
!shapex.ranked_shape<[?,1,1,512]>) -> (index, index, index)
hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -35,6 +77,44 @@
// -----
module {
+ // CHECK: func @kernel_fusable_fill_conv3d_ops
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
+ // CHECK: return
+
+ func @kernel_fusable_fill_conv3d_ops()
+ attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %dim = hal.interface.load.constant offset = 0 : index
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,3,512]>
+ %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,1,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,3,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x3x512x1xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x1x512xf32>
+ %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,1,512]>
+ linalg.fill(%ts2, %cst) : memref<?x1x1x1x512xf32>, f32
+ linalg.conv_3d_input_ndhwc_filter_dhwcf {
+ dilations = dense<1> : tensor<3xi64>,
+ strides = dense<2> : tensor<3xi64>}
+ ins(%ts1, %1 : memref<?x3x3x3x512xf32>, memref<3x3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x1x512xf32>)
+ return
+ }
+ func private @kernel_fusable_fill_conv_ops_num_workgroups__
+ (!shapex.ranked_shape<[?,3,3,3,512]>, !shapex.ranked_shape<[3,3,3,512,1]>,
+ !shapex.ranked_shape<[?,1,1,1,512]>) -> (index, index, index)
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// -----
+
+module {
// CHECK: func @kernel_fusable_fill_matmul_ops
// CHECK: linalg.fill
// CHECK-NOT: return
@@ -118,29 +198,35 @@
// CHECK: %[[DIM:.+]] = hal.interface.load.constant
// CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
// CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+ // CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
// CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
// CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
// CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
- // CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+ // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+ // CHECK-SAME: ins(%[[TS1]], %[[IN2]] : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ // CHECK-SAME: outs(%[[TS2]] : memref<?x1x1x512xf32>)
// CHECK: return
func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
- %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
%shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
%ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
- linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x512xf32>)
linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
return
}
- func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+ func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,3,3,512]>,
!shapex.ranked_shape<[3,3,512,1]>,
!shapex.ranked_shape<[?,1,1,512]>)
-> (index, index, index)
@@ -160,12 +246,14 @@
// CHECK: %[[DIM:.+]] = hal.interface.load.constant
// CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
// CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
-// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
// CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
// CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
// CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
-// CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: ins(%[[TS1]], %[[IN2]] : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+// CHECK-SAME: outs(%[[TS2]] : memref<?x1x1x512xf32>)
// CHECK: return
// CHECK: func private @[[NUM_WORKGROUPS_FN2]]
@@ -197,10 +285,10 @@
%c0 = constant 0 : index
%c1 = constant 1 : index
%dim = hal.interface.load.constant offset = 0 : index
- %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
%shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
%ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
@@ -208,10 +296,14 @@
scf.parallel (%iv) = (%c0) to (%c1) step (%c1) {
scf.yield
}
- linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x512xf32>)
return
}
- func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+ func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,3,3,512]>,
!shapex.ranked_shape<[3,3,512,1]>,
!shapex.ranked_shape<[?,1,1,512]>)
-> (index, index, index)
@@ -231,10 +323,14 @@
// CHECK-LABEL: @kernel()
func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x3x3x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x1x1x512xf32>
- linalg.conv(%1, %0, %2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<1x2x2x512xf32>, memref<1x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%2 : memref<1x1x1x512xf32>)
return
}
// CHECK-LABEL: @kernel__num_workgroups__
@@ -256,12 +352,16 @@
// expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}}
func @kernel() {
%cst = constant 0.000000e+00 : f32
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x3x3x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x1x1x512xf32>
linalg.fill(%2, %cst) : memref<1x1x1x512xf32>, f32
"some_op"() : () -> ()
- linalg.conv(%1, %0, %2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<1x2x2x512xf32>, memref<1x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%2 : memref<1x1x1x512xf32>)
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
diff --git a/iree/test/e2e/llvm_specific/conv.mlir b/iree/test/e2e/llvm_specific/conv.mlir
index 6ac1719..5e29fa2 100644
--- a/iree/test/e2e/llvm_specific/conv.mlir
+++ b/iree/test/e2e/llvm_specific/conv.mlir
@@ -205,3 +205,53 @@
[105921.0, 107874.0, 109827.0, 111780.0, 113733.0, 115686.0]]]]> : tensor<2x3x3x6xf32>) : tensor<2x3x3x6xf32>
return
}
+
+func @conv_1d() {
+ %inputs = iree.unfoldable_constant dense<2.0> : tensor<3x8x1xf32>
+ %weights = iree.unfoldable_constant dense<2.0> : tensor<3x1x1xf32>  // 3-wide kernel, 1 input feature, 1 output feature.
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 2 : i64,
+ input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+ kernel_input_feature_dimension = 1 : i64,
+ kernel_output_feature_dimension = 2 : i64,
+ kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 2 : i64,
+ output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<1x2xi64>,
+ rhs_dilation = dense<1> : tensor<1xi64>,
+ window_strides = dense<1> : tensor<1xi64>
+ } : (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>
+ check.expect_almost_eq_const(%res, dense<12.0> : tensor<3x6x1xf32>) : tensor<3x6x1xf32>  // 3 kernel taps * (2.0 * 2.0) = 12.0 per output element.
+ return
+}
+
+func @conv_3d() {
+ %inputs = iree.unfoldable_constant dense<1.0> : tensor<2x8x8x8x3xf32>
+ %weights = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3x2xf32>  // 2x2x2 kernel, 3 input features, 2 output features.
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 4 : i64,
+ input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+ kernel_input_feature_dimension = 3 : i64,
+ kernel_output_feature_dimension = 4 : i64,
+ kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 4 : i64,
+ output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<3x2xi64>,
+ rhs_dilation = dense<1> : tensor<3xi64>,
+ window_strides = dense<1> : tensor<3xi64>
+ } : (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>
+ check.expect_almost_eq_const(%res, dense<24.0> : tensor<2x7x7x7x2xf32>) : tensor<2x7x7x7x2xf32>  // 2*2*2 kernel taps * 3 input features * (1.0 * 1.0) = 24.0.
+ return
+}
diff --git a/iree/test/e2e/vulkan_specific/conv.mlir b/iree/test/e2e/vulkan_specific/conv.mlir
index e2dc59d..a1f273b 100644
--- a/iree/test/e2e/vulkan_specific/conv.mlir
+++ b/iree/test/e2e/vulkan_specific/conv.mlir
@@ -51,14 +51,14 @@
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>}
@@ -81,3 +81,53 @@
: tensor<1x3x4x3xf32>) : tensor<1x3x4x3xf32>
return
}
+
+func @conv_1d() {  // 1-D convolution of constant-filled operands: zero padding, unit stride and dilation.
+  %lhs = iree.unfoldable_constant dense<2.0> : tensor<3x8x1xf32>  // batch, spatial, feature
+  %rhs = iree.unfoldable_constant dense<2.0> : tensor<3x1x1xf32>  // spatial, in-feature, out-feature
+  %conv = "mhlo.convolution"(%lhs, %rhs) {
+      window_strides = dense<1> : tensor<1xi64>,
+      rhs_dilation = dense<1> : tensor<1xi64>,
+      padding = dense<0> : tensor<1x2xi64>,
+      feature_group_count = 1 : i64,
+      batch_group_count = 1 : i64,
+      dimension_numbers = {
+        input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+        input_feature_dimension = 2 : i64,
+        input_batch_dimension = 0 : i64,
+        kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+        kernel_output_feature_dimension = 2 : i64,
+        kernel_input_feature_dimension = 1 : i64,
+        output_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+        output_feature_dimension = 2 : i64,
+        output_batch_dimension = 0 : i64
+      }
+    } : (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>  // 8 - 3 + 1 = 6 output positions
+  check.expect_almost_eq_const(%conv, dense<12.0> : tensor<3x6x1xf32>) : tensor<3x6x1xf32>  // 3 taps * 1 feature * (2.0 * 2.0)
+  return
+}
+
+func @conv_3d() {  // 3-D convolution of constant-filled operands: zero padding, unit stride and dilation.
+  %lhs = iree.unfoldable_constant dense<1.0> : tensor<2x8x8x8x3xf32>  // batch, 3 spatial dims, feature
+  %rhs = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3x2xf32>  // 3 spatial dims, in-feature, out-feature
+  %conv = "mhlo.convolution"(%lhs, %rhs) {
+      window_strides = dense<1> : tensor<3xi64>,
+      rhs_dilation = dense<1> : tensor<3xi64>,
+      padding = dense<0> : tensor<3x2xi64>,
+      feature_group_count = 1 : i64,
+      batch_group_count = 1 : i64,
+      dimension_numbers = {
+        input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+        input_feature_dimension = 4 : i64,
+        input_batch_dimension = 0 : i64,
+        kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+        kernel_output_feature_dimension = 4 : i64,
+        kernel_input_feature_dimension = 3 : i64,
+        output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+        output_feature_dimension = 4 : i64,
+        output_batch_dimension = 0 : i64
+      }
+    } : (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>  // 8 - 2 + 1 = 7 per spatial dim
+  check.expect_almost_eq_const(%conv, dense<24.0> : tensor<2x7x7x7x2xf32>) : tensor<2x7x7x7x2xf32>  // 2*2*2 taps * 3 features * (1.0 * 1.0)
+  return
+}
diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir
index bfdea3c..7fb1008 100644
--- a/iree/test/e2e/xla_ops/convolution.mlir
+++ b/iree/test/e2e/xla_ops/convolution.mlir
@@ -341,3 +341,64 @@
[105921.0, 107874.0, 109827.0, 111780.0, 113733.0, 115686.0]]]]> : tensor<2x3x3x6xf32>) : tensor<2x3x3x6xf32>
return
}
+
+func @conv2d_1452x2223_dilated_valid() {  // 2-D convolution of a 1x4x5x2 input with a 2x2x2x3 kernel (the shapes in the name), with kernel dilation and zero padding.
+ %inputs = iree.unfoldable_constant dense<
+ [[[[0.09762701, 0.43037874],
+ [ 0.20552675, 0.08976637],
+ [-0.1526904, 0.29178822],
+ [-0.12482557, 0.78354603],
+ [ 0.92732555, -0.23311697]],
+ [[ 0.5834501, 0.05778984],
+ [ 0.13608912, 0.85119325],
+ [-0.85792786, -0.8257414 ],
+ [-0.9595632, 0.6652397 ],
+ [ 0.5563135, 0.74002427]],
+ [[ 0.9572367, 0.59831715],
+ [-0.07704128, 0.56105834],
+ [-0.76345116, 0.27984205],
+ [-0.71329343, 0.88933784],
+ [ 0.04369664, -0.17067613]],
+ [[-0.47088876, 0.5484674 ],
+ [-0.08769934, 0.1368679 ],
+ [-0.9624204, 0.23527099],
+ [ 0.22419144, 0.23386799],
+ [ 0.8874962, 0.3636406 ]]]]> : tensor<1x4x5x2xf32>  // batch x spatial x spatial x feature (per dimension_numbers below)
+ %weights = iree.unfoldable_constant dense<
+ [[[[-0.2809842, -0.12593609, 0.3952624 ],
+ [-0.8795491, 0.33353344, 0.34127575]],
+ [[-0.5792349, -0.7421474, -0.3691433 ],
+ [-0.27257845, 0.14039354, -0.12279698]]],
+ [[[ 0.9767477, -0.79591036, -0.5822465 ],
+ [-0.677381, 0.30621666, -0.4934168 ]],
+ [[-0.06737845, -0.5111488, -0.68206084],
+ [-0.7792497, 0.31265917, -0.7236341 ]]]]> : tensor<2x2x2x3xf32>  // spatial x spatial x in-feature x out-feature (per dimension_numbers below)
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<2x2xi64>,  // no padding on either spatial dim
+ rhs_dilation = dense<[2, 1]> : tensor<2xi64>,  // kernel dilated by 2 along the first spatial dim only: effective extent 3x2
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>  // output spatial dims: 4 - 3 + 1 = 2 and 5 - 2 + 1 = 4
+ check.expect_almost_eq_const(%res, dense<
+ [[[[-0.45181108, -0.37253797, -1.1074474 ],
+ [-0.74972206, 0.8691965, 0.21864426],
+ [-1.9352274, 1.6551838, 0.13848126],
+ [-2.296763, 0.32046723, -0.02542188]],
+ [[-1.4578199, 0.59465677, 0.0599021 ],
+ [-0.3617443, 1.4647548, 1.2320882 ],
+ [ 0.04506956, 1.4347346, -0.22625303],
+ [-1.122044, -0.41301775, -1.5628793 ]]]]> : tensor<1x2x4x3xf32>) : tensor<1x2x4x3xf32>  // NOTE(review): expected values presumably come from a reference implementation; their origin is not shown here
+ return
+}