Merge pull request #4885 from KoolJBlack:main-to-google
PiperOrigin-RevId: 358275381
diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
index 33e7833..d36bc46 100644
--- a/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
+++ b/build_tools/buildkite/cmake/android/arm64-v8a/benchmark.yml
@@ -29,8 +29,6 @@
- wait
- label: "benchmark on snapdragon-855 (adreno-640) (Pixel 4)"
- # TODO(#4861): Re-enable when phone is fixed
- skip: "Phone is borked. See https://github.com/google/iree/issues/4861"
commands:
- "buildkite-agent artifact download --step build model-artifacts.tgz ./"
- "tar xzvf model-artifacts.tgz"
diff --git a/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml b/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml
index 150008d..8bd3864 100644
--- a/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml
+++ b/build_tools/buildkite/cmake/android/arm64-v8a/pipeline.yml
@@ -26,6 +26,7 @@
- wait
- label: "test on exynos-990 (mali-g77)"
+ skip: true
commands:
- "buildkite-agent artifact download --step build build-artifacts.tgz ./"
- "tar xzf build-artifacts.tgz"
@@ -38,7 +39,6 @@
env:
IREE_DOCKER_WORKDIR: "/usr/src/github/iree"
timeout_in_minutes: "15"
- skip: true
- label: "test on exynos-9820 (mali-g76)"
commands:
@@ -55,6 +55,7 @@
timeout_in_minutes: "15"
- label: "test on snapdragon-835 (adreno-540)"
+ skip: true
commands:
- "buildkite-agent artifact download --step build build-artifacts.tgz ./"
- "tar xzf build-artifacts.tgz"
@@ -69,9 +70,10 @@
timeout_in_minutes: "15"
soft_fail:
- exit_status: "*"
- skip: true
- label: "test on snapdragon-855 (adreno-640)"
+ # TODO(#4861): Re-enable when phone is fixed
+ skip: "Phone is borked. See https://github.com/google/iree/issues/4861"
commands:
- "buildkite-agent artifact download --step build build-artifacts.tgz ./"
- "tar xzf build-artifacts.tgz"
@@ -87,6 +89,7 @@
timeout_in_minutes: "15"
- label: "test on snapdragon-855 (adreno-640) (Android 11)"
+ skip: true
commands:
- "buildkite-agent artifact download --step build build-artifacts.tgz ./"
- "tar xzf build-artifacts.tgz"
@@ -102,7 +105,6 @@
branches: "main"
timeout_in_minutes: "20"
soft_fail: true
- skip: true
- label: "test on snapdragon-865 (adreno-650)"
commands:
diff --git a/iree/compiler/Conversion/HLOToHLO/BUILD b/iree/compiler/Conversion/HLOToHLO/BUILD
index e874421..fb1661c 100644
--- a/iree/compiler/Conversion/HLOToHLO/BUILD
+++ b/iree/compiler/Conversion/HLOToHLO/BUILD
@@ -21,6 +21,7 @@
cc_library(
name = "HLOToHLO",
srcs = [
+ "Convert1x1ConvToDot.cpp",
"DecomposeHLOClamp.cpp",
"DemoteF32ToF16.cpp",
],
diff --git a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
index 2b6f0ae..8c0601d 100644
--- a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
@@ -20,6 +20,7 @@
HDRS
"Passes.h"
SRCS
+ "Convert1x1ConvToDot.cpp"
"DecomposeHLOClamp.cpp"
"DemoteF32ToF16.cpp"
DEPS
diff --git a/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp b/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp
new file mode 100644
index 0000000..4bfe8ec
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp
@@ -0,0 +1,149 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Rewrites an n-d (n, d1, d2, d3, ..., ci) * (1, 1, 1, ..., ci, co)
+// as (n * d1 * d2 * d3, ..., ci) . (ci, co)
+// TODO(#4876): this pattern should be replaced by a pattern that converts
+// linalg.conv to linalg.matmul.
+class Convert1x1ConvolutionToDotOp : public OpRewritePattern<mhlo::ConvOp> {
+ public:
+ using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mhlo::ConvOp op,
+ PatternRewriter &rewriter) const override {
+ // Only 1x1 convolution no groups will match.
+ if (op.feature_group_count() != 1) return failure();
+
+ Value input = op.lhs();
+ Value filter = op.rhs();
+ Value output = op.getResult();
+ auto inputShapeType = input.getType().dyn_cast_or_null<RankedTensorType>();
+ auto filterShapeType =
+ filter.getType().dyn_cast_or_null<RankedTensorType>();
+ auto outputShapeType =
+ output.getType().dyn_cast_or_null<RankedTensorType>();
+
+ if (!inputShapeType || !filterShapeType || !outputShapeType) {
+ return failure();
+ }
+
+ auto inputShape = inputShapeType.getShape();
+ auto filterShape = filterShapeType.getShape();
+
+ auto inputBatchDim =
+ op.dimension_numbers().input_batch_dimension().getInt();
+ auto inputFeatureDim =
+ op.dimension_numbers().input_feature_dimension().getInt();
+ auto kernelInputFeatureDim =
+ op.dimension_numbers().kernel_input_feature_dimension().getInt();
+ auto kernelOutputFeatureDim =
+ op.dimension_numbers().kernel_output_feature_dimension().getInt();
+
+ // Match input (n, d1, d2, ..., ci) format
+ if (inputFeatureDim != (inputShape.size() - 1) || inputBatchDim != 0) {
+ return failure();
+ }
+
+ // Match filter (k1, k2, ..., ci, co) format
+ if (kernelInputFeatureDim != (filterShape.size() - 2) ||
+ kernelOutputFeatureDim != (filterShape.size() - 1)) {
+ return failure();
+ }
+
+ // Check 1x1x... kernel spatial size.
+ for (auto dim : op.dimension_numbers().kernel_spatial_dimensions()) {
+ if (filterShape[dim.getZExtValue()] != 1) return failure();
+ }
+
+ // Check dilation & strides are ones.
+ if (op.window_strides()) {
+ for (auto stride : op.window_strides()->getValues<int64_t>()) {
+ if (stride != 1) return failure();
+ }
+ }
+ if (op.rhs_dilation()) {
+ for (auto dilation : op.rhs_dilation()->getValues<int64_t>()) {
+ if (dilation != 1) return failure();
+ }
+ }
+
+ int64_t spatialSize = inputShape[0];
+ for (auto dim : op.dimension_numbers().input_spatial_dimensions()) {
+ spatialSize *= inputShape[dim.getZExtValue()];
+ }
+
+ Type reshapedInputType =
+ RankedTensorType::get({spatialSize, inputShape[inputFeatureDim]},
+ inputShapeType.getElementType());
+ Type reshapedFilterTYpe =
+ RankedTensorType::get({filterShape[kernelInputFeatureDim],
+ filterShape[kernelOutputFeatureDim]},
+ filterShapeType.getElementType());
+ Type dotResultType = RankedTensorType::get(
+ {spatialSize, filterShape[kernelOutputFeatureDim]},
+ outputShapeType.getElementType());
+
+ Value reshapedInput =
+ rewriter.create<mhlo::ReshapeOp>(op.getLoc(), reshapedInputType, input);
+ Value reshapedFilter = rewriter.create<mhlo::ReshapeOp>(
+ op.getLoc(), reshapedFilterTYpe, filter);
+
+ Value dotResult = rewriter.create<mhlo::DotOp>(
+ op.getLoc(), dotResultType, reshapedInput, reshapedFilter,
+ rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"}));
+
+ Value reshapedResult = rewriter.create<mhlo::ReshapeOp>(
+ op.getLoc(), outputShapeType, dotResult);
+
+ rewriter.replaceOp(op, reshapedResult);
+
+ return success();
+ }
+};
+
+struct Convert1x1ConvToDotPass
+ : public PassWrapper<Convert1x1ConvToDotPass, FunctionPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<mhlo::MhloDialect>();
+ }
+
+ void runOnFunction() override {
+ MLIRContext *context = &getContext();
+ OwningRewritePatternList patterns;
+ patterns.insert<Convert1x1ConvolutionToDotOp>(context);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createConvert1x1ConvToDotPass() {
+ return std::make_unique<Convert1x1ConvToDotPass>();
+}
+
+static PassRegistration<Convert1x1ConvToDotPass> pass(
+ "iree-codegen-convert-1x1-conv-to-dot",
+ "Convert mhlo.convolution ops with 1x1 kernels into mhlo.dot ops");
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.h b/iree/compiler/Conversion/HLOToHLO/Passes.h
index 5edbcb2..0822ec7 100644
--- a/iree/compiler/Conversion/HLOToHLO/Passes.h
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.h
@@ -30,6 +30,10 @@
namespace mlir {
namespace iree_compiler {
+/// Creates a pass to convert mhlo.convolution ops with 1x1 kernels into
+/// mhlo.dot ops.
+std::unique_ptr<OperationPass<FuncOp>> createConvert1x1ConvToDotPass();
+
/// Creates a pass to decompose XLA-HLO clamp ops into primitive ops.
std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
diff --git a/iree/compiler/Conversion/HLOToHLO/test/conv1x12dot.mlir b/iree/compiler/Conversion/HLOToHLO/test/conv1x12dot.mlir
new file mode 100644
index 0000000..002fe3e
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToHLO/test/conv1x12dot.mlir
@@ -0,0 +1,28 @@
+// RUN: iree-opt -split-input-file -iree-codegen-convert-1x1-conv-to-dot %s | IreeFileCheck %s
+
+// CHECK: @conv_1x1(%[[INPUT:.+]]: tensor<2x4x5x2xf32>, %[[FILTER:.+]]: tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32>
+func @conv_1x1(%arg0: tensor<2x4x5x2xf32>, %arg1: tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32> {
+ // CHECK: %[[RESHAPED_INPUT:.+]] = "mhlo.reshape"(%[[INPUT]]) : (tensor<2x4x5x2xf32>) -> tensor<40x2xf32>
+ // CHECK: %[[RESHAPED_FILTER:.+]] = "mhlo.reshape"(%[[FILTER]]) : (tensor<1x1x2x7xf32>) -> tensor<2x7xf32>
+ // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot"(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]]) {precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<40x2xf32>, tensor<2x7xf32>) -> tensor<40x7xf32>
+ // CEHCK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<40x7xf32>) -> tensor<2x4x5x7xf32>
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32>
+ return %0 : tensor<2x4x5x7xf32>
+}
+
+
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index c707431..d972351 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -48,6 +48,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#define DEBUG_TYPE "iree-hlo-to-linalg-on-buffers"
+
namespace mlir {
namespace iree_compiler {
@@ -57,6 +59,11 @@
// Utility functions.
// -----------------------------------------------------------------------------
+/// Returns true if the given `attr` is a splat of the given `value`.
+static bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
+ return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
+}
+
/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the last `nReduction` elements, where are "reduction"
/// attributes.
@@ -225,68 +232,11 @@
} // namespace
//===----------------------------------------------------------------------===//
-// mhlo.dot_general conversion patterns.
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Converts mhlo.dot_general operation to linalg.batchmatmul op
-struct DotGeneralOpConversion
- : public ConvertToLinalgBufferOp<DotGeneralOpConversion,
- mhlo::DotGeneralOp> {
- using ConvertToLinalgBufferOp<DotGeneralOpConversion,
- mhlo::DotGeneralOp>::ConvertToLinalgBufferOp;
- LogicalResult apply(mhlo::DotGeneralOp op, ArrayRef<Value> inputBuffers,
- ArrayRef<Value> resultBuffers,
- ConversionPatternRewriter &rewriter) const {
- auto extract1DVector = [](DenseIntElementsAttr elements) {
- SmallVector<int64_t, 6> ret;
- for (const APInt &element : elements) {
- ret.push_back(element.getLimitedValue());
- }
- return ret;
- };
- mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
- auto lhsBatchingDims =
- extract1DVector(dimNumbers.lhs_batching_dimensions());
- auto rhsBatchingDims =
- extract1DVector(dimNumbers.rhs_batching_dimensions());
- auto lhsContractingDims =
- extract1DVector(dimNumbers.lhs_contracting_dimensions());
- auto rhsContractingDims =
- extract1DVector(dimNumbers.rhs_contracting_dimensions());
- if (lhsBatchingDims.size() != 1 || lhsBatchingDims[0] != 0) {
- return rewriter.notifyMatchFailure(
- op, "expected lhs batching dimensions exactly {0}");
- }
- if (rhsBatchingDims.size() != 1 || rhsBatchingDims[0] != 0) {
- return rewriter.notifyMatchFailure(
- op, "expected rhs batching dimensions exactly {0}");
- }
- if (lhsContractingDims.size() != 1 || lhsContractingDims[0] != 2) {
- return rewriter.notifyMatchFailure(
- op, "expected lhs contracting dimensions exactly {2}");
- }
- if (rhsContractingDims.size() != 1 || rhsContractingDims[0] != 1) {
- return rewriter.notifyMatchFailure(
- op, "expected rhs contracting dimensions exactly {1}");
- }
- if (failed(zeroFillBuffer(op.getLoc(), resultBuffers[0], rewriter))) {
- return rewriter.notifyMatchFailure(op,
- "failed to zero fill result buffer");
- }
- rewriter.create<linalg::BatchMatmulOp>(op.getLoc(), inputBuffers,
- resultBuffers);
- return success();
- }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
// mhlo.convolution conversion patterns and utility functions.
//===----------------------------------------------------------------------===//
namespace {
-/// Converts mhlo.convolution operation to linalg.conv op.
+/// Converts mhlo.convolution operation to linalg.conv or linalg.generic op.
struct ConvOpConversion
: public ConvertToLinalgBufferOp<ConvOpConversion, mhlo::ConvOp> {
using ConvertToLinalgBufferOp<ConvOpConversion,
@@ -295,70 +245,85 @@
ArrayRef<Value> resultBuffers,
ConversionPatternRewriter &rewriter) const;
};
+
+/// Converts mhlo.convolution operation to linalg.depthwise_conv_nhwc op.
+struct DepthwiseConvOpConversion
+ : public ConvertToLinalgBufferOp<DepthwiseConvOpConversion, mhlo::ConvOp> {
+ using ConvertToLinalgBufferOp<DepthwiseConvOpConversion,
+ mhlo::ConvOp>::ConvertToLinalgBufferOp;
+ LogicalResult apply(mhlo::ConvOp op, ArrayRef<Value> inputBuffers,
+ ArrayRef<Value> resultBuffers,
+ ConversionPatternRewriter &rewriter) const;
+};
} // namespace
+/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
+/// follows a canonical form:
+///
+/// * Input dimensions have order: (batch_count, spatial_dims,
+/// input_channel_count).
+/// * Filter dimensions have order: (spatial_dims, input_channel_count,
+/// output_channel_count).
+/// * Output dimensions have order: (batch_count, spatial_dims,
+/// output_channel_count).
+static bool hasCanonicalDimensionNumbers(
+ const mhlo::ConvDimensionNumbers &dimensionNumbers) {
+ const int inputSpatialRank =
+ llvm::size(dimensionNumbers.input_spatial_dimensions());
+ // The dimensions for input should follow the order of
+ // batch_count, spatial_dims..., input_feature_count.
+ if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
+ dimensionNumbers.input_feature_dimension().getInt() !=
+ (inputSpatialRank + 1)) {
+ return false;
+ }
+
+ const int kernelSpatialRank =
+ llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+ // The dimensions for filter should follow the order of
+ // spatial_dims..., input_feature_count, num_output_feature_count.
+ if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
+ kernelSpatialRank ||
+ dimensionNumbers.kernel_output_feature_dimension().getInt() !=
+ (kernelSpatialRank + 1)) {
+ return false;
+ }
+
+ const int outputSpatialRank =
+ llvm::size(dimensionNumbers.output_spatial_dimensions());
+ // The dimensions for output should follow the order of
+ // batch_count, spatial_dims.., output_feature_count.
+ if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
+ dimensionNumbers.output_feature_dimension().getInt() !=
+ (outputSpatialRank + 1)) {
+ return false;
+ }
+
+ if (inputSpatialRank != outputSpatialRank ||
+ inputSpatialRank != kernelSpatialRank) {
+ return false;
+ }
+
+ auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin();
+ auto kernelSpatialDim = dimensionNumbers.kernel_spatial_dimensions().begin();
+ auto outputSpatialDim = dimensionNumbers.output_spatial_dimensions().begin();
+ // Check spatial dims are ordred correctly.
+ for (int i = 0; i < inputSpatialRank; ++i) {
+ const int dim = i + 1;
+ if ((*inputSpatialDim++).getZExtValue() != dim ||
+ (*outputSpatialDim++).getZExtValue() != dim ||
+ (*kernelSpatialDim++).getZExtValue() != i) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
LogicalResult ConvOpConversion::apply(
mhlo::ConvOp op, ArrayRef<Value> inputBuffers,
ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
- if (const auto dimensionNumbers = op.dimension_numbers()) {
- const int inputSpatialRank =
- llvm::size(dimensionNumbers.input_spatial_dimensions());
- // The dimensions for input should follow the order of
- // batch_count, spatial_dims..., input_feature_count.
- if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
- dimensionNumbers.input_feature_dimension().getInt() !=
- (inputSpatialRank + 1)) {
- return failure();
- }
-
- const int kernelSpatialRank =
- llvm::size(dimensionNumbers.kernel_spatial_dimensions());
- // The dimensions for filter should follow the order of
- // spatial_dims..., input_feature_count, num_output_feature_count.
- if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
- kernelSpatialRank ||
- dimensionNumbers.kernel_output_feature_dimension().getInt() !=
- (kernelSpatialRank + 1)) {
- return failure();
- }
-
- const int outputSpatialRank =
- llvm::size(dimensionNumbers.output_spatial_dimensions());
- // The dimensions for output should follow the order of
- // batch_count, spatial_dims.., output_feature_count.
- if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
- dimensionNumbers.output_feature_dimension().getInt() !=
- (outputSpatialRank + 1)) {
- return failure();
- }
-
- if (inputSpatialRank != outputSpatialRank ||
- inputSpatialRank != kernelSpatialRank) {
- return failure();
- }
-
- auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin();
- auto kernelSpatialDim =
- dimensionNumbers.kernel_spatial_dimensions().begin();
- auto outputSpatialDim =
- dimensionNumbers.output_spatial_dimensions().begin();
- // Check spatial dims are ordred correctly.
- for (int i = 0; i < inputSpatialRank; ++i) {
- const int dim = i + 1;
- if ((*inputSpatialDim++).getZExtValue() != dim ||
- (*outputSpatialDim++).getZExtValue() != dim ||
- (*kernelSpatialDim++).getZExtValue() != i) {
- return failure();
- }
- }
- }
-
- llvm::SmallVector<Attribute, 4> strides;
- if (auto windowStrides = op.window_strides()) {
- auto range = windowStrides->getAttributeValues();
- strides.append(range.begin(), range.end());
- }
- auto stridesArg = ArrayAttr::get(op.getContext(), strides);
+ if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
// TODO(ataei): Only support dilated convolution for now. We need to consider
// LHS dilation for deconvolution cases.
@@ -367,7 +332,6 @@
auto range = rhsDilation->getAttributeValues();
dilation.append(range.begin(), range.end());
}
- auto dilationArg = ArrayAttr::get(op.getContext(), dilation);
// Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr();
@@ -389,92 +353,171 @@
shape[op.dimension_numbers().kernel_input_feature_dimension().getInt()];
auto groupSize =
shape[op.dimension_numbers().kernel_output_feature_dimension().getInt()];
- // Depthwise conv path...
- if (op.feature_group_count() > 1u && op.feature_group_count() == numGroups) {
- // Lowering depthwise convolution to linalg.generic op. The idea is to use
- // the group convolution formulation to perform the separable depthwise
- // convolution as the following, given an n-dimensional input x and filter w
- // the direct convolution operation can be written as:
- // y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
- // x[n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn]
- // * w[k1, k2, ...kn, ci, co])
-
- // TODO(ataei): Support dilation.
- if (llvm::any_of(dilation, [](Attribute attr) {
- return (attr.dyn_cast<IntegerAttr>().getInt() != 1);
- })) {
- return failure();
- }
-
- SmallVector<AffineExpr, 4> inputExprs;
- SmallVector<AffineExpr, 4> filterExprs;
- SmallVector<AffineExpr, 4> outputExprs;
-
- const auto spatialDims =
- llvm::size(op.dimension_numbers().input_spatial_dimensions());
- const int d1Index = 1;
- const int coIndex = d1Index + spatialDims;
- const int ciIndex = coIndex + 1;
- const int k1Index = ciIndex + 1;
- // n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn
- inputExprs.push_back(rewriter.getAffineDimExpr(0));
- for (int i = 0; i < spatialDims; ++i) {
- if (op.window_stridesAttr()) {
- auto stride = op.window_stridesAttr().getValue<APInt>(i);
- inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
- stride.getZExtValue() +
- rewriter.getAffineDimExpr(k1Index + i));
- } else {
- inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
- rewriter.getAffineDimExpr(k1Index + i));
- }
- }
- inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
-
- // k1, k2, ...kn, ci, co
- for (int i = 0; i < spatialDims; ++i) {
- filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
- }
- filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
- filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));
-
- // n, d1, d2, ....dn, ci * groupSize + co
- outputExprs.push_back(rewriter.getAffineDimExpr(0));
- for (int i = 0; i < spatialDims; ++i) {
- outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
- }
- outputExprs.push_back(rewriter.getAffineDimExpr(ciIndex) * groupSize +
- rewriter.getAffineDimExpr(coIndex));
-
- // nloops = |d| + |k| + |{n, ci, co}|
- int nloops = spatialDims * 2 + 3;
- SmallVector<AffineMap, 4> indexingMaps;
- indexingMaps.emplace_back(AffineMap::get(
- nloops, /*symbolCount=*/0, inputExprs, rewriter.getContext()));
- indexingMaps.emplace_back(AffineMap::get(
- nloops, /*symbolCount=*/0, filterExprs, rewriter.getContext()));
- indexingMaps.emplace_back(AffineMap::get(
- nloops, /*symbolCount=*/0, outputExprs, rewriter.getContext()));
-
- Location loc = op.getLoc();
-
- SmallVector<StringRef, 3> loopAttributeTypes(spatialDims + 3, "parallel");
- loopAttributeTypes.append(spatialDims, "reduction");
- rewriter.create<linalg::GenericOp>(
- loc,
- /*resultTensorTypes=*/ArrayRef<Type>{},
- /*inputs=*/inputBuffers,
- /*outputs=*/resultBuffers, indexingMaps, loopAttributeTypes,
- [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
- Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
- nestedBuilder.create<linalg::YieldOp>(loc, add);
- });
- } else {
- rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1],
- inputBuffers[0], resultBuffers[0],
- stridesArg, dilationArg, padding);
+ if (op.feature_group_count() <= 1u || op.feature_group_count() != numGroups) {
+ return failure();
}
+ // Lowering depthwise convolution to linalg.generic op. The idea is to use
+ // the group convolution formulation to perform the separable depthwise
+ // convolution as the following, given an n-dimensional input x and filter w
+ // the direct convolution operation can be written as:
+ // y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
+ // x[n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn]
+ // * w[k1, k2, ...kn, ci, co])
+
+ // TODO(ataei): Support dilation.
+ if (llvm::any_of(dilation, [](Attribute attr) {
+ return (attr.dyn_cast<IntegerAttr>().getInt() != 1);
+ })) {
+ return failure();
+ }
+
+ SmallVector<AffineExpr, 4> inputExprs;
+ SmallVector<AffineExpr, 4> filterExprs;
+ SmallVector<AffineExpr, 4> outputExprs;
+
+ const auto spatialDims =
+ llvm::size(op.dimension_numbers().input_spatial_dimensions());
+ const int d1Index = 1;
+ const int coIndex = d1Index + spatialDims;
+ const int ciIndex = coIndex + 1;
+ const int k1Index = ciIndex + 1;
+ // n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn
+ inputExprs.push_back(rewriter.getAffineDimExpr(0));
+ for (int i = 0; i < spatialDims; ++i) {
+ if (op.window_stridesAttr()) {
+ auto stride = op.window_stridesAttr().getValue<APInt>(i);
+ inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
+ stride.getZExtValue() +
+ rewriter.getAffineDimExpr(k1Index + i));
+ } else {
+ inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
+ rewriter.getAffineDimExpr(k1Index + i));
+ }
+ }
+ inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+
+ // k1, k2, ...kn, ci, co
+ for (int i = 0; i < spatialDims; ++i) {
+ filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
+ }
+ filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+ filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));
+
+ // n, d1, d2, ....dn, ci * groupSize + co
+ outputExprs.push_back(rewriter.getAffineDimExpr(0));
+ for (int i = 0; i < spatialDims; ++i) {
+ outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
+ }
+ outputExprs.push_back(rewriter.getAffineDimExpr(ciIndex) * groupSize +
+ rewriter.getAffineDimExpr(coIndex));
+
+ // nloops = |d| + |k| + |{n, ci, co}|
+ int nloops = spatialDims * 2 + 3;
+ SmallVector<AffineMap, 4> indexingMaps;
+ indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+ inputExprs, rewriter.getContext()));
+ indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+ filterExprs, rewriter.getContext()));
+ indexingMaps.emplace_back(AffineMap::get(nloops, /*symbolCount=*/0,
+ outputExprs, rewriter.getContext()));
+
+ Location loc = op.getLoc();
+
+ SmallVector<StringRef, 3> loopAttributeTypes(spatialDims + 3, "parallel");
+ loopAttributeTypes.append(spatialDims, "reduction");
+ rewriter.create<linalg::GenericOp>(
+ loc,
+ /*resultTensorTypes=*/ArrayRef<Type>{},
+ /*inputs=*/inputBuffers,
+ /*outputs=*/resultBuffers, indexingMaps, loopAttributeTypes,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
+ Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
+ nestedBuilder.create<linalg::YieldOp>(loc, add);
+ });
+ return success();
+}
+
+LogicalResult DepthwiseConvOpConversion::apply(
+ mhlo::ConvOp op, ArrayRef<Value> inputBuffers,
+ ArrayRef<Value> resultBuffers, ConversionPatternRewriter &rewriter) const {
+ if (op.batch_group_count() != 1) return failure();
+
+ if (op.padding() && !isSplatValue(*op.padding(), 0)) {
+ return rewriter.notifyMatchFailure(op, "non-zero padding unsupported yet");
+ }
+
+ if ((op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1)) ||
+ (op.rhs_dilation() && !isSplatValue(*op.rhs_dilation(), 1))) {
+ return rewriter.notifyMatchFailure(op, "non-one dialation unsupported yet");
+ }
+
+ if (const mhlo::ConvDimensionNumbers &dimension_numbers =
+ op.dimension_numbers()) {
+ // Make sure that this is 2-D convolution.
+ const int spatialRank =
+ llvm::size(dimension_numbers.input_spatial_dimensions());
+ if (spatialRank != 2) {
+ return rewriter.notifyMatchFailure(op, "only support 2-D cases for now");
+ }
+
+ // Make sure that this is depthwise convolution.
+ int64_t inputFeatureDim =
+ dimension_numbers.input_feature_dimension().getInt();
+ int64_t inputFeatureCount =
+ op.lhs().getType().cast<ShapedType>().getDimSize(inputFeatureDim);
+ if (op.feature_group_count() != inputFeatureCount) {
+ return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
+ }
+
+ // Make sure that this convolution has a canonical form.
+ if (!hasCanonicalDimensionNumbers(dimension_numbers)) {
+ return rewriter.notifyMatchFailure(op, "does not have canonical form");
+ }
+ }
+
+ DenseIntElementsAttr windowStrides;
+ if (op.window_strides()) {
+ windowStrides = op.window_strides().getValue();
+ } else {
+ windowStrides = rewriter.getI64VectorAttr({1, 1});
+ }
+
+ if (failed(zeroFillBuffer(op.getLoc(), resultBuffers[0], rewriter))) {
+ return rewriter.notifyMatchFailure(op, "failed to zero fill result buffer");
+ }
+
+ // Create a Linalg reshape op that converts the filter from 4 dimensions
+ // into 3 dimensions (by droping the unit dimension). This is needed because
+ // linalg.depthwise_conv_2d_nhwc expects 3 dimensions for the filter.
+
+ auto filterDims =
+ llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());
+ if (filterDims[2] * filterDims[3] != op.feature_group_count()) {
+ return rewriter.notifyMatchFailure(
+ op, "non-one channel multiplier unsupported yet");
+ }
+ filterDims[2] = op.feature_group_count();
+ filterDims.pop_back();
+
+ MemRefType filterShape = MemRefType::get(
+ filterDims, op.getType().getElementType(), ArrayRef<AffineMap>(),
+ resultBuffers[0].getType().cast<MemRefType>().getMemorySpace());
+
+ auto getIndicesVector = [](int start, int end) {
+ return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
+ };
+
+ SmallVector<linalg::ReassociationIndices, 4> collapsedDimList = {
+ getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
+
+ Value filterBuffer = rewriter.create<linalg::ReshapeOp>(
+ op.getLoc(), filterShape, inputBuffers[1], collapsedDimList);
+
+ rewriter.create<linalg::DepthwiseConvInputNHWCFilterHWCOp>(
+ op.getLoc(), TypeRange(), ValueRange{inputBuffers[0], filterBuffer},
+ resultBuffers, windowStrides);
+
return success();
}
@@ -783,20 +826,26 @@
}
};
-/// Converts linalg.matmul-ish on tensors to linalg.matmul-ish on buffers.
+/// Converts a linalg named op on tensors to linalg named op on buffers.
template <typename LinalgOpTy>
-struct MatmulOnTensorConversion
- : public ConvertToLinalgBufferOp<MatmulOnTensorConversion<LinalgOpTy>,
+struct NamedOpConversion
+ : public ConvertToLinalgBufferOp<NamedOpConversion<LinalgOpTy>,
LinalgOpTy> {
- using ConvertToLinalgBufferOp<MatmulOnTensorConversion<LinalgOpTy>,
+ using ConvertToLinalgBufferOp<NamedOpConversion<LinalgOpTy>,
LinalgOpTy>::ConvertToLinalgBufferOp;
LogicalResult apply(LinalgOpTy op, ArrayRef<Value> inputBuffers,
ArrayRef<Value> resultBuffers,
ConversionPatternRewriter &rewriter) const {
if (!op.hasTensorSemantics()) return failure();
- // The last one is a init tensor.
- rewriter.create<LinalgOpTy>(
- op.getLoc(), inputBuffers.drop_back(op.getNumResults()), resultBuffers);
+ auto linalgOp = cast<linalg::LinalgOp>(op.getOperation());
+ SmallVector<Value, 8> newOperands;
+ newOperands.append(inputBuffers.begin(),
+ inputBuffers.end() - op.getNumResults());
+ newOperands.append(resultBuffers.begin(), resultBuffers.end());
+ auto otherOperands = linalgOp.getAssumedNonShapedOperands();
+ newOperands.append(otherOperands.begin(), otherOperands.end());
+ Location loc = op.getLoc();
+ linalgOp.clone(rewriter, loc, /*resultTypes=*/TypeRange{}, newOperands);
return success();
}
};
@@ -1189,15 +1238,23 @@
void populateHLOToLinalgOnBuffersConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
TensorToBufferMap const &resultTensorToBufferMap) {
- patterns.insert<ConvOpConversion, ConcatenateOpConversion,
- FillOpOnTensorConversion, InitTensorOpConversion,
+ patterns.insert<ConvOpConversion, DepthwiseConvOpConversion,
+ ConcatenateOpConversion, FillOpOnTensorConversion,
+ InitTensorOpConversion,
LinalgOpOnTensorConversion<linalg::GenericOp>,
LinalgOpOnTensorConversion<linalg::IndexedGenericOp>,
- MatmulOnTensorConversion<linalg::MatmulOp>,
- MatmulOnTensorConversion<linalg::BatchMatmulOp>,
+ NamedOpConversion<linalg::ConvInputNWCFilterWCFOp>,
+ NamedOpConversion<linalg::ConvInputNHWCFilterHWCFOp>,
+ NamedOpConversion<linalg::ConvInputNDHWCFilterDHWCFOp>,
+ NamedOpConversion<linalg::MatmulOp>,
+ NamedOpConversion<linalg::BatchMatmulOp>,
PadTensorOpConversion, ReduceWindowOpConversion,
SubTensorOpConversion, TensorReshapeOpConversion>(
context, resultTensorToBufferMap);
+
+ // Prefer lowering to named Linalg dpethwise convolution when possible.
+ patterns.insert<DepthwiseConvOpConversion>(context, resultTensorToBufferMap,
+ /*benefit=*/2);
}
void ConvertHLOToLinalgOnBuffersPass::runOnFunction() {
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
index 585570f..e734f26 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
@@ -252,6 +252,156 @@
};
} // namespace
+//===----------------------------------------------------------------------===//
+// mhlo.conv conversion patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+static bool isDepthwiseConv(mhlo::ConvOp op) {
+ auto shape = op.rhs().getType().cast<ShapedType>().getShape();
+ auto numGroups =
+ shape[op.dimension_numbers().kernel_input_feature_dimension().getInt()];
+ return op.feature_group_count() > 1u && op.feature_group_count() == numGroups;
+}
+
+/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
+/// follows a canonical form:
+///
+/// * Input dimensions have order: (batch_count, spatial_dims,
+/// input_channel_count).
+/// * Filter dimensions have order: (spatial_dims, input_channel_count,
+/// output_channel_count).
+/// * Output dimensions have order: (batch_count, spatial_dims,
+/// output_channel_count).
+static bool hasCanonicalDimensionNumbers(
+ const mhlo::ConvDimensionNumbers &dimensionNumbers) {
+ const int inputSpatialRank =
+ llvm::size(dimensionNumbers.input_spatial_dimensions());
+ // The dimensions for input should follow the order of
+ // batch_count, spatial_dims..., input_feature_count.
+ if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
+ dimensionNumbers.input_feature_dimension().getInt() !=
+ (inputSpatialRank + 1)) {
+ return false;
+ }
+
+ const int kernelSpatialRank =
+ llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+ // The dimensions for filter should follow the order of
+ // spatial_dims..., input_feature_count, num_output_feature_count.
+ if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
+ kernelSpatialRank ||
+ dimensionNumbers.kernel_output_feature_dimension().getInt() !=
+ (kernelSpatialRank + 1)) {
+ return false;
+ }
+
+ const int outputSpatialRank =
+ llvm::size(dimensionNumbers.output_spatial_dimensions());
+ // The dimensions for output should follow the order of
+ // batch_count, spatial_dims.., output_feature_count.
+ if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
+ dimensionNumbers.output_feature_dimension().getInt() !=
+ (outputSpatialRank + 1)) {
+ return false;
+ }
+
+ if (inputSpatialRank != outputSpatialRank ||
+ inputSpatialRank != kernelSpatialRank) {
+ return false;
+ }
+
+ auto inputSpatialDim = dimensionNumbers.input_spatial_dimensions().begin();
+ auto kernelSpatialDim = dimensionNumbers.kernel_spatial_dimensions().begin();
+ auto outputSpatialDim = dimensionNumbers.output_spatial_dimensions().begin();
+ // Check spatial dims are ordred correctly.
+ for (int i = 0; i < inputSpatialRank; ++i) {
+ const int dim = i + 1;
+ if ((*inputSpatialDim++).getZExtValue() != dim ||
+ (*outputSpatialDim++).getZExtValue() != dim ||
+ (*kernelSpatialDim++).getZExtValue() != i) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+/// Converts mhlo.conv operation to linalg named op. This only covers normal
+/// convolution cases. The op must have canonical dimension numbers. Depthwise
+/// convolution and pointwise convolution are not handled in the conversion.
+struct NormalConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
+ using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ mhlo::ConvOp op, ArrayRef<Value> args,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
+ if (isDepthwiseConv(op)) return failure();
+
+ mhlo::ConvOp::Adaptor adaptor(args);
+ Location loc = op.getLoc();
+ Value input = adaptor.lhs();
+ Value filter = adaptor.rhs();
+ auto resultType = op.getResult().getType().cast<ShapedType>();
+ int rank = resultType.getRank();
+
+ // Check if padding is zero.
+ DenseIntElementsAttr padding = op.paddingAttr();
+ if (padding &&
+ (!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
+ return rewriter.notifyMatchFailure(op, "expected no padding");
+ }
+
+ // The output shape is N spatial_dims F.
+ SmallVector<Value, 8> dynSizes;
+ for (int i = 0, e = rank - 1; i < e; ++i) {
+ if (!resultType.isDynamicDim(i)) continue;
+ dynSizes.push_back(rewriter.create<DimOp>(loc, input, i));
+ }
+ if (resultType.isDynamicDim(rank - 1)) {
+ dynSizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
+ }
+ Value initTensor = rewriter.create<linalg::InitTensorOp>(
+ loc, dynSizes, resultType.getShape(), resultType.getElementType());
+ auto zeroAttr = rewriter.getZeroAttr(resultType.getElementType());
+ Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
+ Value zeroTensor =
+ rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
+ linalg::LinalgOp res;
+ Attribute strides = op.window_stridesAttr();
+ // TODO(ataei): Only support dilated kernel right now. We need to consider
+ // input dilation for deconvolution cases.
+ Attribute dilations = op.rhs_dilationAttr();
+ switch (rank) {
+ case 3: {
+ res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
+ loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+ dilations, strides);
+ break;
+ }
+ case 4: {
+ res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
+ loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+ dilations, strides);
+ break;
+ }
+ case 5: {
+ res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
+ loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
+ dilations, strides);
+ break;
+ }
+ default:
+ return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
+ }
+ rewriter.replaceOp(op, res.getOperation()->getResults());
+ return success();
+ }
+};
+} // namespace
+
struct ConvertHLOToLinalgOnTensorsPass
: public PassWrapper<ConvertHLOToLinalgOnTensorsPass, FunctionPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
@@ -297,7 +447,8 @@
MLIRContext *context, OwningRewritePatternList &patterns) {
mhlo::populateHLOToLinalgConversionPattern(context, &patterns);
patterns.insert<TorchIndexSelectOpConversion, SliceOpConversion,
- ConstOpConversion, PadOpConversion>(context);
+ ConstOpConversion, PadOpConversion, NormalConvOpConversion>(
+ context);
}
std::unique_ptr<OperationPass<FuncOp>> createHLOToLinalgOnTensorsPass() {
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir b/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
index 5ee4a70..008b380 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
@@ -1,81 +1,161 @@
-// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+// RUN: iree-opt -iree-codegen-hlo-to-linalg-on-tensors -canonicalize %s | IreeFileCheck %s
-module {
- // CHECK: func @conv
- func @conv() {
- %c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<3x5x5x3xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x3x4xf32>
- // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) {
- // CHECK-SAME: dilations = [1, 2]
- // CHECK-SAME: padding = dense<[
- // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
- // CHECK-SAME: strides = [2, 1]}
- %2 = "mhlo.convolution"(%1, %0) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
- },
- feature_group_count = 1 : i64,
- padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
- rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
- window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
- hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<3x5x5x4xf32>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
+// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_1d_input_nwc_filter_wcf
+// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
+// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
+ -> tensor<?x?x?xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 2 : i64,
+ input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+ kernel_input_feature_dimension = 1 : i64,
+ kernel_output_feature_dimension = 2 : i64,
+ kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 2 : i64,
+ output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<[[0], [0]]> : tensor<2x1xi64>,
+ rhs_dilation = dense<1> : tensor<1xi64>,
+ window_strides = dense<1> : tensor<1xi64>
+ } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
}
-// -----
-
-module {
- func @depthwise_conv() {
- %c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x4x5x2xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x2x3xf32>
- %2 = "mhlo.convolution"(%0, %1) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
- },
- feature_group_count = 2 : i64,
- padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
- hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x3x4x6xf32>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
+// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
+ -> tensor<?x?x?x?xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)>
-// CHECK: linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]
-// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
-// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<2x3x4x6xf32>)
-// CHECK: mulf
-// CHECK: addf
+
+// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
+// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
+// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
+ -> tensor<?x?x?x?x?xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 4 : i64,
+ input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+ kernel_input_feature_dimension = 3 : i64,
+ kernel_output_feature_dimension = 4 : i64,
+ kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 4 : i64,
+ output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
+ rhs_dilation = dense<1> : tensor<3xi64>,
+ window_strides = dense<1> : tensor<3xi64>
+ } : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
+func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>)
+ -> tensor<1x2x4x3xf32> {
+ %0 = "mhlo.convolution"(%arg0, %arg1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
+ return %0 : tensor<1x2x4x3xf32>
+}
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir b/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir
new file mode 100644
index 0000000..d68f749
--- /dev/null
+++ b/iree/compiler/Conversion/HLOToLinalg/test/depthwise_conv.mlir
@@ -0,0 +1,83 @@
+// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
+
+module {
+ func @depthwise_conv() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x4x5x2xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x2x3xf32>
+ %2 = "mhlo.convolution"(%0, %1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 2 : i64,
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
+ hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x3x4x6xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)>
+// CHECK: func @depthwise_conv()
+// CHECK: linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]
+// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<2x3x4x6xf32>)
+// CHECK: mulf
+// CHECK: addf
+
+// -----
+
+module {
+ func @depthwise_conv_multiplier_1() {
+ %c0 = constant 0 : index
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x113x113x96xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<3x3x1x96xf32>
+ %2 = "mhlo.convolution"(%0, %1) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 96 : i64,
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<1> : tensor<2xi64>,
+ window_strides = dense<2> : tensor<2xi64>} : (tensor<1x113x113x96xf32>, tensor<3x3x1x96xf32>) -> tensor<1x56x56x96xf32>
+ hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<1x56x56x96xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK: func @depthwise_conv_multiplier_1()
+// CHECK: linalg.fill
+// CHECK: %[[FILTER:.+]] = linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] : memref<3x3x1x96xf32> into memref<3x3x96xf32>
+// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : tensor<2xi64>} ins(%{{.+}}, %[[FILTER]] : memref<1x113x113x96xf32>, memref<3x3x96xf32>) outs(%{{.+}} : memref<1x56x56x96xf32>)
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/dot_general.mlir b/iree/compiler/Conversion/HLOToLinalg/test/dot_general.mlir
deleted file mode 100644
index 9ab9c3e..0000000
--- a/iree/compiler/Conversion/HLOToLinalg/test/dot_general.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: iree-opt -iree-codegen-hlo-to-linalg-pipeline %s | IreeFileCheck %s
-
-module {
- // CHECK: func @dot_general
- func @dot_general() {
- %c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x2x3xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x3x4xf32>
- // CHECK: linalg.batch_matmul ins(%{{.+}}, %{{.+}} : memref<2x2x3xf32>, memref<2x3x4xf32>) outs(%{{.+}} : memref<2x2x4xf32>)
- %result ="mhlo.dot_general"(%0, %1) {
- dot_dimension_numbers = {
- lhs_batching_dimensions = dense<0> : tensor<1xi64>,
- lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
- rhs_batching_dimensions = dense<0> : tensor<1xi64>,
- rhs_contracting_dimensions = dense<1> : tensor<1xi64>
- },
- precision_config = ["DEFAULT", "DEFAULT"]
- } : (tensor<2x2x3xf32>, tensor<2x3x4xf32>) -> tensor<2x2x4xf32>
- hal.interface.store.tensor %result, @legacy_io::@ret0, offset = %c0 : tensor<2x2x4xf32>
- return
- }
- hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
- }
-}
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
index 9c259d4..de16c47 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
@@ -569,3 +569,34 @@
// CHECK-SAME: ) outs(%[[RET0]]
// CHECK-SAME: )
// CHECK: return
+
+// -----
+
+module {
+ func @dot_general() {
+ %c0 = constant 0 : index
+ %cst = constant 0.000000e+00 : f32
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x2x3xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x3x4xf32>
+ %2 = linalg.init_tensor [2, 2, 4] : tensor<2x2x4xf32>
+ %3 = linalg.fill(%2, %cst) : tensor<2x2x4xf32>, f32 -> tensor<2x2x4xf32>
+ %4 = linalg.batch_matmul ins(%0, %1 : tensor<2x2x3xf32>, tensor<2x3x4xf32>)
+ outs(%3 : tensor<2x2x4xf32>) -> tensor<2x2x4xf32>
+ hal.interface.store.tensor %4, @legacy_io::@ret0, offset = %c0 : tensor<2x2x4xf32>
+ return
+ }
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-LABEL: func @dot_general
+// CHECK-DAG: %[[RET:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x2x4xf32>
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x2x3xf32>
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<2x3x4xf32>
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK: linalg.fill(%[[RET]], %[[ZERO]]) : memref<2x2x4xf32>, f32
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<2x2x3xf32>, memref<2x3x4xf32>)
+// CHECK-SAME: outs(%[[RET]] : memref<2x2x4xf32>)
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
index cf332c7..e16e3bc 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
@@ -117,6 +117,7 @@
Shape::createMaterializeShapeCalculationsPass());
nestedModulePM.addNestedPass<FuncOp>(
Shape::createHoistShapeCalculationsPass());
+ nestedModulePM.addNestedPass<FuncOp>(createConvert1x1ConvToDotPass());
nestedModulePM.addNestedPass<FuncOp>(createDecomposeHLOClampPass());
addHLOToLinalgOnBuffersPasses(nestedModulePM);
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index bf27373..2ec5408 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -771,22 +771,25 @@
OwningRewritePatternList patterns;
- patterns.insert<MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
- MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
- MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
- MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
- MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
- MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
- MapLinalgOpToLocalInvocationId<linalg::FillOp>,
- MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
- MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
- MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
- MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
- MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
- MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
- MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>,
- RemoveLinalgRange, SerializeParallelLoopPattern>(
- context, options.useLinalgOnTensors);
+ patterns.insert<
+ MapLinalgOpToGlobalInvocationId<linalg::CopyOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::FillOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::GenericOp>,
+ MapLinalgOpToGlobalInvocationId<linalg::IndexedGenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvInputNWCFilterWCFOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvInputNHWCFilterHWCFOp>,
+ MapLinalgOpToLocalInvocationId<linalg::ConvInputNDHWCFilterDHWCFOp>,
+ MapLinalgOpToLocalInvocationId<linalg::CopyOp>,
+ MapLinalgOpToLocalInvocationId<linalg::FillOp>,
+ MapLinalgOpToLocalInvocationId<linalg::GenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::IndexedGenericOp>,
+ MapLinalgOpToLocalInvocationId<linalg::MatmulOp>,
+ MapLinalgOpToLocalInvocationId<linalg::BatchMatmulOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingMaxOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingMinOp>,
+ MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>, RemoveLinalgRange,
+ SerializeParallelLoopPattern>(context, options.useLinalgOnTensors);
FrozenRewritePatternList frozenPatterns(std::move(patterns));
for (FuncOp funcOp : getOperation().getOps<FuncOp>()) {
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index 7772e82..7b0be45 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -350,13 +350,19 @@
std::array<int64_t, 3> workgroupSize;
};
-static LogicalResult getMaliSpecificConfig(linalg::ConvOp op,
+template <typename ConvOpTy>
+static LogicalResult getMaliSpecificConfig(ConvOpTy op,
TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
- auto inputType = op.getInput(1).getType().cast<MemRefType>();
- auto outputType = op.getOutputBufferTypes()[0].cast<MemRefType>();
+ auto inputType = op.getInput(1).getType().template cast<MemRefType>();
+ auto outputType = op.getOutputBufferTypes()[0].template cast<MemRefType>();
if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
return failure();
+ // Only support NHWC conv.
+ if (!isa<linalg::ConvOp, linalg::ConvInputNHWCFilterHWCFOp>(
+ op.getOperation())) {
+ return failure();
+ }
bool isInputTilable = inputType.getDimSize(3) % 4 == 0;
if (!isInputTilable) return failure();
@@ -406,12 +412,11 @@
return failure();
}
-template <>
-LogicalResult getOpLaunchConfig(linalg::ConvOp op,
- const spirv::TargetEnv &targetEnv,
- const SPIRVCodegenOptions &options,
- TileSizesListType &tileSizes,
- LaunchConfigInfo &config) {
+template <typename T>
+LogicalResult getConvOpLaunchConfig(T op, const spirv::TargetEnv &targetEnv,
+ const SPIRVCodegenOptions &options,
+ TileSizesListType &tileSizes,
+ LaunchConfigInfo &config) {
if (targetEnv.getVendorID() == spirv::Vendor::ARM &&
succeeded(getMaliSpecificConfig(op, tileSizes, config))) {
return success();
@@ -428,6 +433,22 @@
return success();
}
+#define GET_CONV_LAUNCH_CONFIG(opType) \
+ template <> \
+ LogicalResult getOpLaunchConfig( \
+ opType op, const spirv::TargetEnv &targetEnv, \
+ const SPIRVCodegenOptions &options, TileSizesListType &tileSizes, \
+ LaunchConfigInfo &config) { \
+ return getConvOpLaunchConfig(op, targetEnv, options, tileSizes, config); \
+ }
+
+GET_CONV_LAUNCH_CONFIG(linalg::ConvOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNWCFilterWCFOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNHWCFilterHWCFOp)
+GET_CONV_LAUNCH_CONFIG(linalg::ConvInputNDHWCFilterDHWCFOp)
+
+#undef GET_CONV_LAUNCH_CONFIG
+
static LogicalResult getMaliSpecificConfig(
linalg::DepthwiseConvInputNHWCFilterHWCOp op, TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
@@ -583,6 +604,9 @@
DISPATCH(linalg::BatchMatmulOp)
DISPATCH(linalg::ConvOp)
DISPATCH(linalg::DepthwiseConvInputNHWCFilterHWCOp)
+ DISPATCH(linalg::ConvInputNWCFilterWCFOp)
+ DISPATCH(linalg::ConvInputNHWCFilterHWCFOp)
+ DISPATCH(linalg::ConvInputNDHWCFilterDHWCFOp)
DISPATCH(linalg::MatmulOp)
DISPATCH(linalg::PoolingMaxOp)
DISPATCH(linalg::PoolingMinOp)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 38f9c8a..f48fd7c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -59,16 +59,9 @@
// Utility functions
//===----------------------------------------------------------------------===//
-/// Returns a Linalg marker that replaces existing markers.
-linalg::LinalgTransformationFilter getLinalgReplaceMarker(
- StringRef maker, MLIRContext *context) {
- return linalg::LinalgTransformationFilter(ArrayRef<Identifier>(),
- Identifier::get(maker, context));
-}
-
/// Returns a Linalg marker that matches any of the `matchMarkers` and replaces
/// it with `replaceMarker`.
-linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
+static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
ArrayRef<StringRef> matchMarkers, StringRef replaceMarker,
MLIRContext *context) {
SmallVector<Identifier, 2> markers;
@@ -166,13 +159,14 @@
// 2) Maybe there are better alternatives for handling filter like using
// different storage classes, since for inference workloads these are model
// constants. This is TBD.
+template <typename ConvOpTy>
struct PromoteConvSubviewsPattern
- : public linalg::LinalgPromotionPattern<linalg::ConvOp> {
+ : public linalg::LinalgPromotionPattern<ConvOpTy> {
PromoteConvSubviewsPattern(MLIRContext *context,
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
- : linalg::LinalgPromotionPattern<linalg::ConvOp>(
+ : linalg::LinalgPromotionPattern<ConvOpTy>(
context,
options.setOperandsToPromote({1}).setUseFullTileBuffers(
{false, false}),
@@ -182,7 +176,11 @@
static void populatePromotionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
- patterns.insert<PromoteMatmulSubviewsPattern, PromoteConvSubviewsPattern>(
+ patterns.insert<
+ PromoteMatmulSubviewsPattern, PromoteConvSubviewsPattern<linalg::ConvOp>,
+ PromoteConvSubviewsPattern<linalg::ConvInputNWCFilterWCFOp>,
+ PromoteConvSubviewsPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ PromoteConvSubviewsPattern<linalg::ConvInputNDHWCFilterDHWCFOp>>(
context,
linalg::LinalgPromotionOptions()
.setAllocationDeallocationFns(allocateWorkgroupMemory,
@@ -320,6 +318,9 @@
patterns.insert<
linalg::LinalgTilingPattern<linalg::ConvOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
@@ -440,8 +441,12 @@
.setInterchange(loopOrder)
.setTileSizeComputationFunction(getTileSizeFn);
- patterns.insert<linalg::LinalgTilingPattern<linalg::ConvOp>>(
- context, convTilingOptions, marker);
+ patterns
+ .insert<linalg::LinalgTilingPattern<linalg::ConvOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
+ linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>>(
+ context, convTilingOptions, marker);
}
//====---------------------------------------------------------------------===//
@@ -641,6 +646,34 @@
applyVectorTransformation(funcOp);
}
+ // Invoke patterns to generalize linalg.depthwise_conv_2d_nhwc ops to Linalg
+ // generic ops. This can handle those cases that failed tiling and
+ // vectorization in the above.
+ // TODO(antiagainst): remove this once we have depthwise convolution
+ // vectorization applicable everywhere.
+ {
+ // Carry over the Linalg marker because it is load-bearing and affects
+ // later passes.
+ linalg::LinalgTransformationFilter marker =
+ getLinalgMatchAndReplaceMarker({getWorkgroupMarker()},
+ getWorkgroupMarker(), context);
+ marker.addFilter([](Operation *op) -> LogicalResult {
+ return success(isa<linalg::DepthwiseConvInputNHWCFilterHWCOp>(op));
+ });
+
+ OwningRewritePatternList patterns;
+ linalg::populateLinalgNamedOpsGeneralizationPatterns(context, patterns,
+ marker);
+
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "--- After generalization ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+ }
+
launchConfig.finalize(funcOp);
}
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index 7252b38..2432432 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -260,6 +260,7 @@
// - All XLA HLO ops are converted.
// - All Linalg ops are operating on buffers.
//===--------------------------------------------------------------------===//
+ pm.nest<ModuleOp>().addNestedPass<FuncOp>(createConvert1x1ConvToDotPass());
pm.nest<ModuleOp>().addNestedPass<FuncOp>(createDecomposeHLOClampPass());
addHLOToLinalgOnBuffersPasses(pm.nest<ModuleOp>());
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index d314304..9f6c443 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -85,7 +85,11 @@
linalg::LinalgDependenceGraph::DependenceType::RAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::BatchMatmulOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
- ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvOp,
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNWCFilterWCFOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNHWCFilterHWCFOp,
+ linalg::LinalgDependenceGraph::DependenceType::WAW)
+ ADD_FUSABLE_PAIR(linalg::FillOp, linalg::ConvInputNDHWCFilterDHWCFOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
ADD_FUSABLE_PAIR(linalg::FillOp, linalg::MatmulOp,
linalg::LinalgDependenceGraph::DependenceType::WAW)
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index f937cd6..d201d0e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -354,9 +354,12 @@
%26 = dim %arg2, %c3 : memref<?x?x?x?xf32>
%27 = subview %arg2[%arg3, %arg4, %arg5, 0] [%23, %24, %25, %26] [1, 1, 1, 1]
: memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
- linalg.conv(%arg0, %21, %27)
- {__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]}
- : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32, #map5>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ __internal_linalg_transform__ = "workgroup",
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%21, %arg0 : memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32>)
+ outs(%27 : memref<?x?x?x?xf32, #map5>)
scf.yield
}
return
@@ -407,7 +410,7 @@
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
-// CHECK-NOT: linalg.conv
+// CHECK-NOT: linalg.conv_2d_input_nhwc_filter_hwcf
// -----
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index b81a602..ed295f7 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -1,29 +1,71 @@
// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-codegen-split-dispatch-function -verify-diagnostics %s | IreeFileCheck %s
module {
- // CHECK: func @kernel_fusable_fill_conv_ops
+ // CHECK: func @kernel_fusable_fill_conv1d_ops
// CHECK: linalg.fill
// CHECK-NOT: return
- // CHECK: linalg.conv
+ // CHECK: linalg.conv_1d_input_nwc_filter_wcf
// CHECK: return
- func @kernel_fusable_fill_conv_ops()
+ func @kernel_fusable_fill_conv1d_ops()
attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
- %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,512]>
+ %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x512xf32>, !shapex.ranked_shape<[?,3,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x512x1xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x512xf32>
+ %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x512xf32>, !shapex.ranked_shape<[?,1,512]>
+ linalg.fill(%ts2, %cst) : memref<?x1x512xf32>, f32
+ linalg.conv_1d_input_nwc_filter_wcf {
+ dilations = dense<1> : tensor<1xi64>,
+ strides = dense<2> : tensor<1xi64>}
+ ins(%ts1, %1 : memref<?x3x512xf32>, memref<3x512x1xf32>)
+ outs(%ts2 : memref<?x1x512xf32>)
+ return
+ }
+ func private @kernel_fusable_fill_conv_ops_num_workgroups__
+ (!shapex.ranked_shape<[?,3,512]>, !shapex.ranked_shape<[3,512,1]>,
+ !shapex.ranked_shape<[?,1,512]>) -> (index, index, index)
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// -----
+
+module {
+ // CHECK: func @kernel_fusable_fill_conv2d_ops
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+ // CHECK: return
+
+ func @kernel_fusable_fill_conv2d_ops()
+ attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %dim = hal.interface.load.constant offset = 0 : index
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
%shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
%ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
- linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x512xf32>)
return
}
func private @kernel_fusable_fill_conv_ops_num_workgroups__
- (!shapex.ranked_shape<[?,2,2,512]>, !shapex.ranked_shape<[3,3,512,1]>,
+ (!shapex.ranked_shape<[?,3,3,512]>, !shapex.ranked_shape<[3,3,512,1]>,
!shapex.ranked_shape<[?,1,1,512]>) -> (index, index, index)
hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
@@ -35,6 +77,44 @@
// -----
module {
+ // CHECK: func @kernel_fusable_fill_conv3d_ops
+ // CHECK: linalg.fill
+ // CHECK-NOT: return
+ // CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
+ // CHECK: return
+
+ func @kernel_fusable_fill_conv3d_ops()
+ attributes {hal.num_workgroups_fn = @kernel_fusable_fill_conv_ops_num_workgroups__} {
+ %cst = constant 0.000000e+00 : f32
+ %dim = hal.interface.load.constant offset = 0 : index
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,3,512]>
+ %shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,1,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,3,512]>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x3x512x1xf32>
+ %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x1x512xf32>
+ %ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,1,512]>
+ linalg.fill(%ts2, %cst) : memref<?x1x1x1x512xf32>, f32
+ linalg.conv_3d_input_ndhwc_filter_dhwcf {
+ dilations = dense<1> : tensor<3xi64>,
+ strides = dense<2> : tensor<3xi64>}
+ ins(%ts1, %1 : memref<?x3x3x3x512xf32>, memref<3x3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x1x512xf32>)
+ return
+ }
+ func private @kernel_fusable_fill_conv_ops_num_workgroups__
+ (!shapex.ranked_shape<[?,3,3,3,512]>, !shapex.ranked_shape<[3,3,3,512,1]>,
+ !shapex.ranked_shape<[?,1,1,1,512]>) -> (index, index, index)
+ hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+}
+
+// -----
+
+module {
// CHECK: func @kernel_fusable_fill_matmul_ops
// CHECK: linalg.fill
// CHECK-NOT: return
@@ -118,29 +198,35 @@
// CHECK: %[[DIM:.+]] = hal.interface.load.constant
// CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
// CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+ // CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
// CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
// CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
// CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
- // CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+ // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+ // CHECK-SAME: ins(%[[TS1]], %[[IN2]] : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ // CHECK-SAME: outs(%[[TS2]] : memref<?x1x1x512xf32>)
// CHECK: return
func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
- %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
%shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
%ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
- linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x512xf32>)
linalg.fill(%ts2, %cst) : memref<?x1x1x512xf32>, f32
return
}
- func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+ func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,3,3,512]>,
!shapex.ranked_shape<[3,3,512,1]>,
!shapex.ranked_shape<[?,1,1,512]>)
-> (index, index, index)
@@ -160,12 +246,14 @@
// CHECK: %[[DIM:.+]] = hal.interface.load.constant
// CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
// CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
-// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
// CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
// CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
// CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
-// CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: ins(%[[TS1]], %[[IN2]] : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+// CHECK-SAME: outs(%[[TS2]] : memref<?x1x1x512xf32>)
// CHECK: return
// CHECK: func private @[[NUM_WORKGROUPS_FN2]]
@@ -197,10 +285,10 @@
%c0 = constant 0 : index
%c1 = constant 1 : index
%dim = hal.interface.load.constant offset = 0 : index
- %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
+ %shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,3,3,512]>
%shape2 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,1,1,512]>
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- %ts1 = shapex.tie_shape %0, %shape1 : memref<?x2x2x512xf32>, !shapex.ranked_shape<[?,2,2,512]>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x3x3x512xf32>
+ %ts1 = shapex.tie_shape %0, %shape1 : memref<?x3x3x512xf32>, !shapex.ranked_shape<[?,3,3,512]>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
%ts2 = shapex.tie_shape %2, %shape2 : memref<?x1x1x512xf32>, !shapex.ranked_shape<[?,1,1,512]>
@@ -208,10 +296,14 @@
scf.parallel (%iv) = (%c0) to (%c1) step (%c1) {
scf.yield
}
- linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%ts1, %1 : memref<?x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%ts2 : memref<?x1x1x512xf32>)
return
}
- func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+ func private @kernel__num_workgroups__(!shapex.ranked_shape<[?,3,3,512]>,
!shapex.ranked_shape<[3,3,512,1]>,
!shapex.ranked_shape<[?,1,1,512]>)
-> (index, index, index)
@@ -231,10 +323,14 @@
// CHECK-LABEL: @kernel()
func @kernel() attributes {hal.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x3x3x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x1x1x512xf32>
- linalg.conv(%1, %0, %2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<1x2x2x512xf32>, memref<1x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%2 : memref<1x1x1x512xf32>)
return
}
// CHECK-LABEL: @kernel__num_workgroups__
@@ -256,12 +352,16 @@
// expected-error @+1 {{cannot separate Linalg/Parallel ops into multiple kernels}}
func @kernel() {
%cst = constant 0.000000e+00 : f32
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x3x3x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x1x1x512xf32>
linalg.fill(%2, %cst) : memref<1x1x1x512xf32>, f32
"some_op"() : () -> ()
- linalg.conv(%1, %0, %2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<1x2x2x512xf32>, memref<1x1x1x512xf32>
+ linalg.conv_2d_input_nhwc_filter_hwcf {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins(%0, %1 : memref<1x3x3x512xf32>, memref<3x3x512x1xf32>)
+ outs(%2 : memref<1x1x1x512xf32>)
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 422c887..6379dd4 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -47,11 +47,6 @@
llvm::cl::desc("Guarantees input/output features ordered for conv kernel"),
llvm::cl::init(true));
-static llvm::cl::opt<bool> conv1x1toDot(
- "iree-flow-1x1-conv-to-dot",
- llvm::cl::desc("Rewrites mhlo.conv with 1x1 filter into mhlo.dot"),
- llvm::cl::init(true));
-
static bool isAllZero(DenseIntElementsAttr attr) {
if (!attr.isSplat()) return false;
return attr.getSplatValue<IntegerAttr>().getInt() == 0;
@@ -442,104 +437,6 @@
}
};
-// Rewrites an n-d (n, d1, d2, d3, ..., ci) * (1, 1, 1, ..., ci, co)
-// as (n * d1 * d2 * d3, ..., ci) . (ci, co)
-class Lower1x1ConvolutionToDotOp : public OpRewritePattern<mhlo::ConvOp> {
- public:
- using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mhlo::ConvOp op,
- PatternRewriter &rewriter) const override {
- // Only 1x1 convolution no groups will match.
- if (op.feature_group_count() != 1) return failure();
-
- Value input = op.lhs();
- Value filter = op.rhs();
- Value output = op.getResult();
- auto inputShapeType = input.getType().dyn_cast_or_null<RankedTensorType>();
- auto filterShapeType =
- filter.getType().dyn_cast_or_null<RankedTensorType>();
- auto outputShapeType =
- output.getType().dyn_cast_or_null<RankedTensorType>();
-
- if (!inputShapeType || !filterShapeType || !outputShapeType) {
- return failure();
- }
-
- auto inputShape = inputShapeType.getShape();
- auto filterShape = filterShapeType.getShape();
-
- auto inputBatchDim =
- op.dimension_numbers().input_batch_dimension().getInt();
- auto inputFeatureDim =
- op.dimension_numbers().input_feature_dimension().getInt();
- auto kernelInputFeatureDim =
- op.dimension_numbers().kernel_input_feature_dimension().getInt();
- auto kernelOutputFeatureDim =
- op.dimension_numbers().kernel_output_feature_dimension().getInt();
-
- // Match input (n, d1, d2, ..., ci) format
- if (inputFeatureDim != (inputShape.size() - 1) || inputBatchDim != 0) {
- return failure();
- }
-
- // Match filter (k1, k2, ..., ci, co) format
- if (kernelInputFeatureDim != (filterShape.size() - 2) ||
- kernelOutputFeatureDim != (filterShape.size() - 1)) {
- return failure();
- }
-
- // Check 1x1x... kernel spatial size.
- for (auto dim : op.dimension_numbers().kernel_spatial_dimensions()) {
- if (filterShape[dim.getZExtValue()] != 1) return failure();
- }
-
- // Check dilation & strides are ones.
- if (op.window_strides()) {
- for (auto stride : op.window_strides()->getValues<int64_t>()) {
- if (stride != 1) return failure();
- }
- }
- if (op.rhs_dilation()) {
- for (auto dilation : op.rhs_dilation()->getValues<int64_t>()) {
- if (dilation != 1) return failure();
- }
- }
-
- int64_t spatialSize = inputShape[0];
- for (auto dim : op.dimension_numbers().input_spatial_dimensions()) {
- spatialSize *= inputShape[dim.getZExtValue()];
- }
-
- Type reshapedInputType =
- RankedTensorType::get({spatialSize, inputShape[inputFeatureDim]},
- inputShapeType.getElementType());
- Type reshapedFilterTYpe =
- RankedTensorType::get({filterShape[kernelInputFeatureDim],
- filterShape[kernelOutputFeatureDim]},
- filterShapeType.getElementType());
- Type dotResultType = RankedTensorType::get(
- {spatialSize, filterShape[kernelOutputFeatureDim]},
- outputShapeType.getElementType());
-
- Value reshapedInput =
- rewriter.create<mhlo::ReshapeOp>(op.getLoc(), reshapedInputType, input);
- Value reshapedFilter = rewriter.create<mhlo::ReshapeOp>(
- op.getLoc(), reshapedFilterTYpe, filter);
-
- Value dotResult = rewriter.create<mhlo::DotOp>(
- op.getLoc(), dotResultType, reshapedInput, reshapedFilter,
- rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"}));
-
- Value reshapedResult = rewriter.create<mhlo::ReshapeOp>(
- op.getLoc(), outputShapeType, dotResult);
-
- rewriter.replaceOp(op, reshapedResult);
-
- return success();
- }
-};
-
// Adjust the shape of depthwise_conv filter where is applied by mhlo.
class AdjustDepthwiseFilterShape : public OpRewritePattern<mhlo::ConvOp> {
public:
@@ -977,9 +874,6 @@
patterns.insert<ReorderConvOpKernelDimensions>(context);
patterns.insert<ReorderConvOpOutputDimensions>(context);
}
- if (conv1x1toDot) {
- patterns.insert<Lower1x1ConvolutionToDotOp>(context);
- }
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing.mlir
index df94f47..39f5419 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-hlo-to-hlo-preprocessing --iree-flow-1x1-conv-to-dot %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-hlo-to-hlo-preprocessing %s | IreeFileCheck %s
// CHECK-LABEL: @batch_norm_inference
// CHECK-SAME: %[[X:[^:[:space:]]+]]
@@ -97,33 +97,6 @@
// -----
-// CHECK: @conv_1x1(%[[INPUT:.+]]: tensor<2x4x5x2xf32>, %[[FILTER:.+]]: tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32>
-func @conv_1x1(%arg0: tensor<2x4x5x2xf32>, %arg1: tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32> {
- // CHECK: %[[RESHAPED_INPUT:.+]] = "mhlo.reshape"(%[[INPUT]]) : (tensor<2x4x5x2xf32>) -> tensor<40x2xf32>
- // CHECK: %[[RESHAPED_FILTER:.+]] = "mhlo.reshape"(%[[FILTER]]) : (tensor<1x1x2x7xf32>) -> tensor<2x7xf32>
- // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot"(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]]) {precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<40x2xf32>, tensor<2x7xf32>) -> tensor<40x7xf32>
- // CEHCK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<40x7xf32>) -> tensor<2x4x5x7xf32>
- %0 = "mhlo.convolution"(%arg0, %arg1) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
- feature_group_count = 1 : i64,
- padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32>
- return %0 : tensor<2x4x5x7xf32>
-}
-
-// -----
-
// CHECK: @reorder_broadcast_in_dim_scalar_binary(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>)
func @reorder_broadcast_in_dim_scalar_binary(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>) {
// CHECK: %[[ADD:.*]] = mhlo.add %[[ARG0]], %[[ARG1]] : tensor<f32>
diff --git a/iree/compiler/Dialect/VM/IR/BUILD b/iree/compiler/Dialect/VM/IR/BUILD
index 78488ef..5cc6162 100644
--- a/iree/compiler/Dialect/VM/IR/BUILD
+++ b/iree/compiler/Dialect/VM/IR/BUILD
@@ -38,6 +38,7 @@
"VMOpInterface.cpp.inc",
"VMOps.cpp",
"VMOps.cpp.inc",
+ "VMStructs.cpp.inc",
"VMTypes.cpp",
],
hdrs = [
@@ -47,6 +48,7 @@
"VMOpInterface.h.inc",
"VMOps.h",
"VMOps.h.inc",
+ "VMStructs.h.inc",
"VMTraits.h",
"VMTypes.h",
],
@@ -55,6 +57,7 @@
":VMOpEncoderGen",
":VMOpInterfaceGen",
":VMOpsGen",
+ ":VMStructsGen",
"//iree/compiler/Dialect/IREE/IR",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
@@ -135,6 +138,21 @@
],
)
+gentbl(
+ name = "VMStructsGen",
+ tbl_outs = [
+ ("-gen-iree-struct-attr-decls", "VMStructs.h.inc"),
+ ("-gen-iree-struct-attr-defs", "VMStructs.cpp.inc"),
+ ],
+ tblgen = "//iree/tools:iree-tblgen",
+ td_file = "VMBase.td",
+ td_srcs = [
+ ":td_files",
+ "//iree/compiler/Dialect/IREE/IR:td_files",
+ "@llvm-project//mlir:OpBaseTdFiles",
+ ],
+)
+
iree_tablegen_doc(
name = "VMDialectDocGen",
tbl_outs = [
diff --git a/iree/compiler/Dialect/VM/IR/CMakeLists.txt b/iree/compiler/Dialect/VM/IR/CMakeLists.txt
index bfa76e9..e180228 100644
--- a/iree/compiler/Dialect/VM/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/IR/CMakeLists.txt
@@ -25,6 +25,7 @@
"VMOpInterface.h.inc"
"VMOps.h"
"VMOps.h.inc"
+ "VMStructs.h.inc"
"VMTraits.h"
"VMTypes.h"
SRCS
@@ -35,6 +36,7 @@
"VMOpInterface.cpp.inc"
"VMOps.cpp"
"VMOps.cpp.inc"
+ "VMStructs.cpp.inc"
"VMTypes.cpp"
DEPS
LLVMSupport
@@ -90,6 +92,18 @@
-gen-op-interface-defs VMOpInterface.cpp.inc
)
+iree_tablegen_library(
+ NAME
+ VMStructsGen
+ TD_FILE
+ "VMBase.td"
+ OUTS
+ -gen-iree-struct-attr-decls VMStructs.h.inc
+ -gen-iree-struct-attr-defs VMStructs.cpp.inc
+ TBLGEN
+ IREE
+)
+
iree_tablegen_doc(
NAME
VMDialectDocGen
diff --git a/iree/compiler/Dialect/VM/IR/VMBase.td b/iree/compiler/Dialect/VM/IR/VMBase.td
index dd7b575..3f404a5 100644
--- a/iree/compiler/Dialect/VM/IR/VMBase.td
+++ b/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -643,4 +643,23 @@
let constBuilderCall = "$0";
}
+//===----------------------------------------------------------------------===//
+// VM structs
+//===----------------------------------------------------------------------===//
+
+def VM_OrdinalCountsAttr :
+ IREE_StructAttr<"ordinal_counts",
+ "OrdinalCountsAttr",
+ VM_Dialect, [
+ IREE_StructFieldAttr<"import_funcs", I32Attr>,
+ IREE_StructFieldAttr<"export_funcs", I32Attr>,
+ IREE_StructFieldAttr<"internal_funcs", I32Attr>,
+ IREE_StructFieldAttr<"global_bytes", I32Attr>,
+ IREE_StructFieldAttr<"global_refs", I32Attr>,
+ IREE_StructFieldAttr<"rodatas", I32Attr>,
+ IREE_StructFieldAttr<"rwdatas", I32Attr>,
+ ]> {
+ let cppNamespace = "mlir::iree_compiler::IREE::VM";
+}
+
#endif // IREE_DIALECT_VM_BASE
diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
index 2f466fc..c467c12 100644
--- a/iree/compiler/Dialect/VM/IR/VMDialect.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMDialect.cpp
@@ -16,6 +16,7 @@
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Dialect/VM/IR/VMTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -191,6 +192,7 @@
VMDialect::VMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<VMDialect>()) {
+ addAttributes<IREE::VM::OrdinalCountsAttr>();
addTypes<IREE::VM::ListType, IREE::VM::OpaqueType, IREE::VM::RefType>();
addInterfaces<VMInlinerInterface, VMOpAsmInterface, VMFolderInterface>();
@@ -201,6 +203,27 @@
}
//===----------------------------------------------------------------------===//
+// Attribute printing and parsing
+//===----------------------------------------------------------------------===//
+
+Attribute VMDialect::parseAttribute(DialectAsmParser &parser, Type type) const {
+ StringRef attrKind;
+ if (failed(parser.parseKeyword(&attrKind))) return {};
+ if (attrKind == OrdinalCountsAttr::getKindName()) {
+ return OrdinalCountsAttr::parse(parser);
+ }
+ parser.emitError(parser.getNameLoc()) << "unknown VM attribute: " << attrKind;
+ return {};
+}
+
+void VMDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
+ TypeSwitch<Attribute>(attr)
+ .Case<OrdinalCountsAttr>([&](auto typedAttr) { typedAttr.print(p); })
+ .Default(
+ [](Attribute) { llvm_unreachable("unhandled VM attribute kind"); });
+}
+
+//===----------------------------------------------------------------------===//
// Type printing and parsing
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.h b/iree/compiler/Dialect/VM/IR/VMDialect.h
index 16fc6bc..9941752 100644
--- a/iree/compiler/Dialect/VM/IR/VMDialect.h
+++ b/iree/compiler/Dialect/VM/IR/VMDialect.h
@@ -32,6 +32,12 @@
explicit VMDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "vm"; }
+ /// Parses an attribute registered to this dialect.
+ Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
+
+ /// Prints an attribute registered to this dialect.
+ void printAttribute(Attribute attr, DialectAsmPrinter &p) const override;
+
/// Parses a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 3f07efc..be73ef9 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -42,7 +42,7 @@
let arguments = (ins
StrAttr:$sym_name,
// TODO(benvanik): add compatibility and versioning attributes.
- OptionalAttr<DictionaryAttr>:$ordinal_counts
+ OptionalAttr<VM_OrdinalCountsAttr>:$ordinal_counts
);
let regions = (region SizedRegion<1>:$body);
diff --git a/iree/compiler/Dialect/VM/IR/VMTypes.cpp b/iree/compiler/Dialect/VM/IR/VMTypes.cpp
index 393ae65..b1cc7bd 100644
--- a/iree/compiler/Dialect/VM/IR/VMTypes.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMTypes.cpp
@@ -14,12 +14,15 @@
#include "iree/compiler/Dialect/VM/IR/VMTypes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/TypeSupport.h"
// Order matters:
#include "iree/compiler/Dialect/VM/IR/VMEnums.cpp.inc"
+#include "iree/compiler/Dialect/VM/IR/VMStructs.cpp.inc"
namespace mlir {
namespace iree_compiler {
@@ -120,6 +123,57 @@
Type RefType::getObjectType() { return getImpl()->objectType; }
+//===----------------------------------------------------------------------===//
+// Attribute printing and parsing
+//===----------------------------------------------------------------------===//
+
+Attribute OrdinalCountsAttr::parse(DialectAsmParser &p) {
+ Type i32 = p.getBuilder().getIntegerType(32);
+ IntegerAttr importFuncsAttr;
+ IntegerAttr exportFuncsAttr;
+ IntegerAttr internalFuncsAttr;
+ IntegerAttr globalBytesAttr;
+ IntegerAttr globalRefsAttr;
+ IntegerAttr rodatasAttr;
+ IntegerAttr rwdatasAttr;
+ if (failed(p.parseLess()) || failed(p.parseKeyword("import_funcs")) ||
+ failed(p.parseEqual()) ||
+ failed(p.parseAttribute(importFuncsAttr, i32)) ||
+ failed(p.parseComma()) || failed(p.parseKeyword("export_funcs")) ||
+ failed(p.parseEqual()) ||
+ failed(p.parseAttribute(exportFuncsAttr, i32)) ||
+ failed(p.parseComma()) || failed(p.parseKeyword("internal_funcs")) ||
+ failed(p.parseEqual()) ||
+ failed(p.parseAttribute(internalFuncsAttr, i32)) ||
+ failed(p.parseComma()) || failed(p.parseKeyword("global_bytes")) ||
+ failed(p.parseEqual()) ||
+ failed(p.parseAttribute(globalBytesAttr, i32)) ||
+ failed(p.parseComma()) || failed(p.parseKeyword("global_refs")) ||
+ failed(p.parseEqual()) || failed(p.parseAttribute(globalRefsAttr, i32)) ||
+ failed(p.parseComma()) || failed(p.parseKeyword("rodatas")) ||
+ failed(p.parseEqual()) || failed(p.parseAttribute(rodatasAttr, i32)) ||
+ failed(p.parseComma()) || failed(p.parseKeyword("rwdatas")) ||
+ failed(p.parseEqual()) || failed(p.parseAttribute(rwdatasAttr, i32)) ||
+ failed(p.parseGreater())) {
+ return {};
+ }
+ return get(importFuncsAttr, exportFuncsAttr, internalFuncsAttr,
+ globalBytesAttr, globalRefsAttr, rodatasAttr, rwdatasAttr);
+}
+
+void OrdinalCountsAttr::print(DialectAsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << getKindName() << "<";
+ os << "import_funcs = " << import_funcs() << ", ";
+ os << "export_funcs = " << export_funcs() << ", ";
+ os << "internal_funcs = " << internal_funcs() << ", ";
+ os << "global_bytes = " << global_bytes() << ", ";
+ os << "global_refs = " << global_refs() << ", ";
+ os << "rodatas = " << rodatas() << ", ";
+ os << "rwdatas = " << rwdatas();
+ os << ">";
+}
+
} // namespace VM
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/VM/IR/VMTypes.h b/iree/compiler/Dialect/VM/IR/VMTypes.h
index 6a1586d..266c3f3 100644
--- a/iree/compiler/Dialect/VM/IR/VMTypes.h
+++ b/iree/compiler/Dialect/VM/IR/VMTypes.h
@@ -19,12 +19,14 @@
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
// Order matters.
#include "iree/compiler/Dialect/VM/IR/VMEnums.h.inc"
+#include "iree/compiler/Dialect/VM/IR/VMStructs.h.inc"
namespace mlir {
namespace iree_compiler {
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
index 12abc70..0d4205d 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -279,29 +279,17 @@
return moduleOp.emitError() << "ordinal_counts attribute not found. The "
"OrdinalAllocationPass must be run before.";
}
- DictionaryAttr ordinalCounts = moduleOp.ordinal_counts().getValue();
+ OrdinalCountsAttr ordinalCounts = moduleOp.ordinal_counts().getValue();
// Find all structural ops in the module.
std::vector<IREE::VM::ImportOp> importFuncOps;
std::vector<IREE::VM::ExportOp> exportFuncOps;
std::vector<IREE::VM::FuncOp> internalFuncOps;
std::vector<IREE::VM::RodataOp> rodataOps;
- importFuncOps.resize(ordinalCounts.get("import_funcs")
- .dyn_cast<IntegerAttr>()
- .getValue()
- .getLimitedValue());
- exportFuncOps.resize(ordinalCounts.get("export_funcs")
- .dyn_cast<IntegerAttr>()
- .getValue()
- .getLimitedValue());
- internalFuncOps.resize(ordinalCounts.get("internal_funcs")
- .dyn_cast<IntegerAttr>()
- .getValue()
- .getLimitedValue());
- rodataOps.resize(ordinalCounts.get("rodatas")
- .dyn_cast<IntegerAttr>()
- .getValue()
- .getLimitedValue());
+ importFuncOps.resize(ordinalCounts.import_funcs());
+ exportFuncOps.resize(ordinalCounts.export_funcs());
+ internalFuncOps.resize(ordinalCounts.internal_funcs());
+ rodataOps.resize(ordinalCounts.rodatas());
for (auto &op : moduleOp.getBlock().getOperations()) {
if (auto funcOp = dyn_cast<IREE::VM::FuncOp>(op)) {
@@ -458,14 +446,8 @@
auto importFuncsRef = fbb.createOffsetVecDestructive(importFuncRefs);
auto typesRef = fbb.createOffsetVecDestructive(typeRefs);
- int32_t globalRefs = ordinalCounts.get("global_refs")
- .dyn_cast<IntegerAttr>()
- .getValue()
- .getLimitedValue();
- int32_t globalBytes = ordinalCounts.get("global_bytes")
- .dyn_cast<IntegerAttr>()
- .getValue()
- .getLimitedValue();
+ int32_t globalRefs = ordinalCounts.global_refs();
+ int32_t globalBytes = ordinalCounts.global_bytes();
iree_vm_ModuleStateDef_ref_t moduleStateDef = 0;
if (globalBytes || globalRefs) {
diff --git a/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp b/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp
index 3fa29fd..dca926a 100644
--- a/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/OrdinalAllocation.cpp
@@ -78,7 +78,7 @@
// Assign byte offset values to primitive globals, ensuring that we meet
// natural alignment requirements on each size type.
int nextGlobalBytesOrdinal = 0;
- size_t globalBytes = 0;
+ int globalBytes = 0;
for (auto sizeGlobalOps : llvm::enumerate(primitiveGlobalOps)) {
size_t storageSize = sizeGlobalOps.index();
if (sizeGlobalOps.value().empty()) continue;
@@ -88,26 +88,14 @@
globalOp->setAttr("ordinal",
builder.getI32IntegerAttr(nextGlobalBytesOrdinal));
nextGlobalBytesOrdinal += storageSize;
- globalBytes =
- std::max(globalBytes, nextGlobalBytesOrdinal + storageSize);
+ globalBytes = std::max(globalBytes, nextGlobalBytesOrdinal);
}
}
// Assign ordinal counts to module op.
- getOperation().ordinal_countsAttr(builder.getDictionaryAttr(
- {{builder.getIdentifier("import_funcs"),
- builder.getI32IntegerAttr(nextImportOrdinal)},
- {builder.getIdentifier("export_funcs"),
- builder.getI32IntegerAttr(nextExportOrdinal)},
- {builder.getIdentifier("internal_funcs"),
- builder.getI32IntegerAttr(nextFuncOrdinal)},
- {builder.getIdentifier("global_bytes"),
- builder.getI32IntegerAttr(globalBytes)},
- {builder.getIdentifier("global_refs"),
- builder.getI32IntegerAttr(nextGlobalRefOrdinal)},
- {builder.getIdentifier("rodatas"),
- builder.getI32IntegerAttr(nextRodataOrdinal)},
- {builder.getIdentifier("rwdatas"), builder.getI32IntegerAttr(0)}}));
+ getOperation().ordinal_countsAttr(OrdinalCountsAttr::get(
+ nextImportOrdinal, nextExportOrdinal, nextFuncOrdinal, globalBytes,
+ nextGlobalRefOrdinal, nextRodataOrdinal, 0, &getContext()));
SymbolTable symbolTable(getOperation());
diff --git a/iree/compiler/Dialect/VM/Transforms/test/ordinal_allocation.mlir b/iree/compiler/Dialect/VM/Transforms/test/ordinal_allocation.mlir
index 9d977b7..e919749 100644
--- a/iree/compiler/Dialect/VM/Transforms/test/ordinal_allocation.mlir
+++ b/iree/compiler/Dialect/VM/Transforms/test/ordinal_allocation.mlir
@@ -1,6 +1,17 @@
// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation)' %s | IreeFileCheck %s
+// check the parser for vm.module.ordinal_counts
+// RUN: iree-opt -split-input-file -pass-pipeline='vm.module(iree-vm-ordinal-allocation)' %s | iree-opt | IreeFileCheck %s
// CHECK-LABEL: @global_address_propagation
+ // CHECK-SAME: attributes {ordinal_counts = #vm.ordinal_counts<
+ // CHECK-SAME: import_funcs = 0,
+ // CHECK-SAME: export_funcs = 0,
+ // CHECK-SAME: internal_funcs = 1,
+ // CHECK-SAME: global_bytes = 8,
+ // CHECK-SAME: global_refs = 0,
+ // CHECK-SAME: rodatas = 0,
+ // CHECK-SAME: rwdatas = 0
+ // CHECK-SAME: >}
vm.module @global_address_propagation {
// CHECK-DAG: vm.global.i32 @g0 mutable : i32 attributes {ordinal = 0 : i32}
vm.global.i32 @g0 mutable : i32
diff --git a/iree/compiler/Dialect/VMLA/Transforms/BUILD b/iree/compiler/Dialect/VMLA/Transforms/BUILD
index 74918d9..7435c5e 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/BUILD
+++ b/iree/compiler/Dialect/VMLA/Transforms/BUILD
@@ -30,6 +30,7 @@
"Passes.h",
],
deps = [
+ "//iree/compiler/Conversion/HLOToHLO",
"//iree/compiler/Dialect/HAL/IR:HALDialect",
"//iree/compiler/Dialect/IREE/Transforms",
"//iree/compiler/Dialect/Shape/IR",
diff --git a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
index edcaac8..e0ebf15 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
@@ -31,6 +31,7 @@
MLIRStandard
MLIRSupport
MLIRTransforms
+ iree::compiler::Conversion::HLOToHLO
iree::compiler::Dialect::HAL::IR::HALDialect
iree::compiler::Dialect::IREE::Transforms
iree::compiler::Dialect::Shape::IR
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
index 8061fca..384d050 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
#include <memory>
+#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
@@ -58,6 +59,9 @@
// Unroll multi-dimensional reductions to one reduction per dimension.
passManager.addNestedPass<FuncOp>(createUnrollReductionsPass());
+ // Converts mhlo.convolution ops with 1x1 kernels into mhlo.dot ops.
+ passManager.addNestedPass<FuncOp>(createConvert1x1ConvToDotPass());
+
// Tensor-level pattern-based lowerings. Thrown into one pass for simplicity.
passManager.addNestedPass<FuncOp>(createPreConversionLoweringPass());
diff --git a/iree/test/e2e/llvm_specific/conv.mlir b/iree/test/e2e/llvm_specific/conv.mlir
index 6ac1719..5e29fa2 100644
--- a/iree/test/e2e/llvm_specific/conv.mlir
+++ b/iree/test/e2e/llvm_specific/conv.mlir
@@ -205,3 +205,53 @@
[105921.0, 107874.0, 109827.0, 111780.0, 113733.0, 115686.0]]]]> : tensor<2x3x3x6xf32>) : tensor<2x3x3x6xf32>
return
}
+
+func @conv_1d() {
+ %inputs = iree.unfoldable_constant dense<2.0> : tensor<3x8x1xf32>
+ %weights = iree.unfoldable_constant dense<2.0> : tensor<3x1x1xf32>
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 2 : i64,
+ input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+ kernel_input_feature_dimension = 1 : i64,
+ kernel_output_feature_dimension = 2 : i64,
+ kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 2 : i64,
+ output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<1x2xi64>,
+ rhs_dilation = dense<1> : tensor<1xi64>,
+ window_strides = dense<1> : tensor<1xi64>
+ } : (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>
+ check.expect_almost_eq_const(%res, dense<12.0> : tensor<3x6x1xf32>) : tensor<3x6x1xf32>
+ return
+}
+
+func @conv_3d() {
+ %inputs = iree.unfoldable_constant dense<1.0> : tensor<2x8x8x8x3xf32>
+ %weights = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3x2xf32>
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 4 : i64,
+ input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+ kernel_input_feature_dimension = 3 : i64,
+ kernel_output_feature_dimension = 4 : i64,
+ kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 4 : i64,
+ output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<3x2xi64>,
+ rhs_dilation = dense<1> : tensor<3xi64>,
+ window_strides = dense<1> : tensor<3xi64>
+ } : (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>
+ check.expect_almost_eq_const(%res, dense<24.0> : tensor<2x7x7x7x2xf32>) : tensor<2x7x7x7x2xf32>
+ return
+}
diff --git a/iree/test/e2e/vulkan_specific/conv.mlir b/iree/test/e2e/vulkan_specific/conv.mlir
index e2dc59d..a1f273b 100644
--- a/iree/test/e2e/vulkan_specific/conv.mlir
+++ b/iree/test/e2e/vulkan_specific/conv.mlir
@@ -51,14 +51,14 @@
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>}
@@ -81,3 +81,53 @@
: tensor<1x3x4x3xf32>) : tensor<1x3x4x3xf32>
return
}
+
+func @conv_1d() {
+ %inputs = iree.unfoldable_constant dense<2.0> : tensor<3x8x1xf32>
+ %weights = iree.unfoldable_constant dense<2.0> : tensor<3x1x1xf32>
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 2 : i64,
+ input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
+ kernel_input_feature_dimension = 1 : i64,
+ kernel_output_feature_dimension = 2 : i64,
+ kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 2 : i64,
+ output_spatial_dimensions = dense<[1]> : tensor<1xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<1x2xi64>,
+ rhs_dilation = dense<1> : tensor<1xi64>,
+ window_strides = dense<1> : tensor<1xi64>
+ } : (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>
+ check.expect_almost_eq_const(%res, dense<12.0> : tensor<3x6x1xf32>) : tensor<3x6x1xf32>
+ return
+}
+
+func @conv_3d() {
+ %inputs = iree.unfoldable_constant dense<1.0> : tensor<2x8x8x8x3xf32>
+ %weights = iree.unfoldable_constant dense<1.0> : tensor<2x2x2x3x2xf32>
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 4 : i64,
+ input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
+ kernel_input_feature_dimension = 3 : i64,
+ kernel_output_feature_dimension = 4 : i64,
+ kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 4 : i64,
+ output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<3x2xi64>,
+ rhs_dilation = dense<1> : tensor<3xi64>,
+ window_strides = dense<1> : tensor<3xi64>
+ } : (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>
+ check.expect_almost_eq_const(%res, dense<24.0> : tensor<2x7x7x7x2xf32>) : tensor<2x7x7x7x2xf32>
+ return
+}
diff --git a/iree/test/e2e/xla_ops/convolution.mlir b/iree/test/e2e/xla_ops/convolution.mlir
index bfdea3c..7fb1008 100644
--- a/iree/test/e2e/xla_ops/convolution.mlir
+++ b/iree/test/e2e/xla_ops/convolution.mlir
@@ -341,3 +341,64 @@
[105921.0, 107874.0, 109827.0, 111780.0, 113733.0, 115686.0]]]]> : tensor<2x3x3x6xf32>) : tensor<2x3x3x6xf32>
return
}
+
+func @conv2d_1452x2223_dilated_valid() {
+ %inputs = iree.unfoldable_constant dense<
+ [[[[0.09762701, 0.43037874],
+ [ 0.20552675, 0.08976637],
+ [-0.1526904, 0.29178822],
+ [-0.12482557, 0.78354603],
+ [ 0.92732555, -0.23311697]],
+ [[ 0.5834501, 0.05778984],
+ [ 0.13608912, 0.85119325],
+ [-0.85792786, -0.8257414 ],
+ [-0.9595632, 0.6652397 ],
+ [ 0.5563135, 0.74002427]],
+ [[ 0.9572367, 0.59831715],
+ [-0.07704128, 0.56105834],
+ [-0.76345116, 0.27984205],
+ [-0.71329343, 0.88933784],
+ [ 0.04369664, -0.17067613]],
+ [[-0.47088876, 0.5484674 ],
+ [-0.08769934, 0.1368679 ],
+ [-0.9624204, 0.23527099],
+ [ 0.22419144, 0.23386799],
+ [ 0.8874962, 0.3636406 ]]]]> : tensor<1x4x5x2xf32>
+ %weights = iree.unfoldable_constant dense<
+ [[[[-0.2809842, -0.12593609, 0.3952624 ],
+ [-0.8795491, 0.33353344, 0.34127575]],
+ [[-0.5792349, -0.7421474, -0.3691433 ],
+ [-0.27257845, 0.14039354, -0.12279698]]],
+ [[[ 0.9767477, -0.79591036, -0.5822465 ],
+ [-0.677381, 0.30621666, -0.4934168 ]],
+ [[-0.06737845, -0.5111488, -0.68206084],
+ [-0.7792497, 0.31265917, -0.7236341 ]]]]> : tensor<2x2x2x3xf32>
+ %res = "mhlo.convolution"(%inputs, %weights) {
+ batch_group_count = 1 : i64,
+ dimension_numbers = {
+ input_batch_dimension = 0 : i64,
+ input_feature_dimension = 3 : i64,
+ input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+ kernel_input_feature_dimension = 2 : i64,
+ kernel_output_feature_dimension = 3 : i64,
+ kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ output_batch_dimension = 0 : i64,
+ output_feature_dimension = 3 : i64,
+ output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+ },
+ feature_group_count = 1 : i64,
+ padding = dense<0> : tensor<2x2xi64>,
+ rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
+ window_strides = dense<1> : tensor<2xi64>
+ } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
+ check.expect_almost_eq_const(%res, dense<
+ [[[[-0.45181108, -0.37253797, -1.1074474 ],
+ [-0.74972206, 0.8691965, 0.21864426],
+ [-1.9352274, 1.6551838, 0.13848126],
+ [-2.296763, 0.32046723, -0.02542188]],
+ [[-1.4578199, 0.59465677, 0.0599021 ],
+ [-0.3617443, 1.4647548, 1.2320882 ],
+ [ 0.04506956, 1.4347346, -0.22625303],
+ [-1.122044, -0.41301775, -1.5628793 ]]]]> : tensor<1x2x4x3xf32>) : tensor<1x2x4x3xf32>
+ return
+}
diff --git a/scripts/update_op_coverage.py b/scripts/update_op_coverage.py
index 642bf4d..711fb90 100755
--- a/scripts/update_op_coverage.py
+++ b/scripts/update_op_coverage.py
@@ -91,7 +91,13 @@
for t in tests:
if not t.endswith('.mlir'):
continue
- backend, op = get_backend_op_pair(t)
+ try:
+ backend, op = get_backend_op_pair(t)
+ except LookupError:
+ # Linalg on tensors are WIP; explicitly ignore them for now.
+ if "linalg_on_tensors" in t:
+ continue
+ raise
res[backend].append(op)
return res