Re-land moving tensor->flow passes. (#6651)
Reverts https://github.com/google/iree/pull/6648, rolling-forward https://github.com/google/iree/pull/6586.
First commit is the pure revert, second is a set of proposed fixes for the performance regression. The single pass (with a true/false option) is now split further:
* `createConvertTensorOpsPass`
* `createConvertLinalgTensorOpsPass(true)`
* `createConvertLinalgTensorOpsPass(false)`
---
This gets the original change closer to an NFC (but it is still not an NFC).
Before this change:
| stage | patterns |
| --- | --- |
| `buildCommonInputConversionPassPipeline` | `tensor::CastOp`, `tensor::FromElementsOp` |
| Flow, before DispatchRegionFormation | `tensor::InsertSliceOp`, `tensor::ExtractSliceOp`, <br> `linalg::TensorCollapseShapeOp`, `linalg::TensorExpandShapeOp` |
| Flow, after DispatchRegionFormation | `linalg::FillOp` |
After this change:
| stage | patterns |
| --- | --- |
| `buildCommonInputConversionPassPipeline` | (deleted) |
| Flow, before DispatchRegionFormation | `tensor::CastOp`, `tensor::FromElementsOp`, <br>`tensor::InsertSliceOp`, `tensor::ExtractSliceOp`, <br>`linalg::TensorCollapseShapeOp`, `linalg::TensorExpandShapeOp` |
| Flow, after DispatchRegionFormation | `linalg::FillOp` |
Passes on `linalg` ops are the same, but two `tensor` ops are now converted after linalg fusion, rather than at input conversion time. The first version of this change had those `linalg` ops converted during the "common input conversion pipeline".
diff --git a/iree/compiler/Dialect/Flow/Conversion/BUILD b/iree/compiler/Dialect/Flow/Conversion/BUILD
new file mode 100644
index 0000000..f27d209
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/BUILD
@@ -0,0 +1,11 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
diff --git a/iree/compiler/Dialect/Flow/Conversion/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/CMakeLists.txt
new file mode 100644
index 0000000..544fe79
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/CMakeLists.txt
@@ -0,0 +1,13 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Flow/Conversion/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/BUILD b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/BUILD
new file mode 100644
index 0000000..22002c0
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/BUILD
@@ -0,0 +1,27 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "TensorToFlow",
+ srcs = [
+ "ConvertTensorToFlow.cpp",
+ ],
+ hdrs = [
+ "ConvertTensorToFlow.h",
+ ],
+ deps = [
+ "//iree/compiler/Dialect/Flow/IR",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:TensorDialect",
+ ],
+)
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/CMakeLists.txt
new file mode 100644
index 0000000..8b427cc
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/CMakeLists.txt
@@ -0,0 +1,28 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Flow/Conversion/TensorToFlow/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_cc_library(
+ NAME
+ TensorToFlow
+ HDRS
+ "ConvertTensorToFlow.h"
+ SRCS
+ "ConvertTensorToFlow.cpp"
+ DEPS
+ MLIRIR
+ MLIRStandard
+ MLIRTensor
+ iree::compiler::Dialect::Flow::IR
+ PUBLIC
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
similarity index 68%
rename from iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
rename to iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
index 07118d9..eecfa67 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.cpp
@@ -4,28 +4,20 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h"
+
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
-#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "iree-flow-convert-to-flow-tensor-ops"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
+namespace {
+
/// An operation that uses `offsets`, `sizes` and `strides` (i.e. implements the
/// `OffsetSizeAndStrideInterface`) can be mapped to flow operations that
/// eventually map to DMA operations if the offsets/sizes/strides represent a
@@ -127,39 +119,8 @@
return dynamicDims;
}
-namespace {
-
-/// Converts linalg.tensor_reshape operations into flow.tensor.reshape
-/// operations.
-template <typename TensorReshapeOp>
-struct LinalgTensorReshapeToFlowTensorReshape
- : public OpRewritePattern<TensorReshapeOp> {
- using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
- PatternRewriter &rewriter) const override {
- if (reshapeOp->template getParentOfType<Flow::DispatchWorkgroupsOp>()) {
- return failure();
- }
- SmallVector<SmallVector<Value>> outputShape;
- if (failed(reshapeOp.reifyResultShapes(rewriter, outputShape))) {
- return failure();
- }
- SmallVector<Value> outputDynamicShapes;
- for (auto shape :
- llvm::zip(reshapeOp.getResultType().getShape(), outputShape[0])) {
- if (std::get<0>(shape) != ShapedType::kDynamicSize) continue;
- outputDynamicShapes.push_back(std::get<1>(shape));
- }
- rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>(
- reshapeOp, reshapeOp.getResultType(), reshapeOp.src(),
- outputDynamicShapes);
- return success();
- }
-};
-
-/// Convert subtensor insert operation flow.tensor.update where possible.
-struct SubTensorInsertToTensorUpdate
+/// Convert tensor.insert_slice ops into flow.tensor.update ops where possible.
+struct ConvertTensorInsertSlicePattern
: public OpRewritePattern<tensor::InsertSliceOp> {
using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
@@ -168,6 +129,7 @@
if (insertOp->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
return failure();
}
+
SmallVector<OpFoldResult, 4> offsets = insertOp.getMixedOffsets();
SmallVector<OpFoldResult, 4> sizes = insertOp.getMixedSizes();
SmallVector<OpFoldResult, 4> strides = insertOp.getMixedStrides();
@@ -176,6 +138,7 @@
dstShape)) {
return failure();
}
+
Location loc = insertOp.getLoc();
auto sourceDynamicDims = getDynamicValues(sizes);
Value source = insertOp.source();
@@ -192,7 +155,7 @@
loc, sourceType, source, sourceDynamicDims, sourceDynamicDims);
}
- auto offsetVals = getAsValues(rewriter, loc, offsets);
+ auto offsetVals = getAsValues(rewriter, loc, insertOp.getMixedOffsets());
Value dest = insertOp.dest();
auto destDynamicDims = getDynamicDimValues(rewriter, loc, dest);
rewriter.replaceOpWithNewOp<TensorUpdateOp>(
@@ -202,8 +165,8 @@
}
};
-/// Convert subtensor operation to flow.tensor.slice where possible.
-struct SubTensorToTensorSlice
+/// Convert tensor.extract_slice ops into flow.tensor.slice ops where possible.
+struct ConvertTensorExtractSlicePattern
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -212,6 +175,7 @@
if (sliceOp->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
return failure();
}
+
SmallVector<OpFoldResult, 4> offsets = sliceOp.getMixedOffsets();
SmallVector<OpFoldResult, 4> sizes = sliceOp.getMixedSizes();
SmallVector<OpFoldResult, 4> strides = sliceOp.getMixedStrides();
@@ -220,6 +184,7 @@
srcShape)) {
return failure();
}
+
Location loc = sliceOp.getLoc();
ShapedType sourceType = sliceOp.getSourceType();
@@ -251,68 +216,103 @@
}
};
-/// Converts linalg.fill ops into flow.tensor.splat ops.
-///
-/// This is expected to improve performance because we can use DMA
-/// functionalities for the fill, instead of dispatching kernels.
-struct LinalgFillToFlowTensorSplat final
- : public OpRewritePattern<linalg::FillOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ConvertTensorCastPattern : public OpRewritePattern<tensor::CastOp> {
+ using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(linalg::FillOp fillOp,
+ LogicalResult matchAndRewrite(tensor::CastOp op,
PatternRewriter &rewriter) const override {
- if (fillOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
- // Don't convert linalg.fill ops that were fused together with other ops.
+ if (op->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
return failure();
}
- SmallVector<Value, 4> dynamicDims =
- getDynamicDimValues(rewriter, fillOp.getLoc(), fillOp.output());
- rewriter.replaceOpWithNewOp<TensorSplatOp>(
- fillOp, fillOp.output().getType(), fillOp.value(), dynamicDims);
+ auto loc = op.getLoc();
+ Value input = op.getOperand();
+ ShapedType inputType = input.getType().dyn_cast<ShapedType>();
+ ShapedType resultType =
+ op.getResult().getType().dyn_cast_or_null<ShapedType>();
+ if (!inputType || !resultType || !inputType.hasRank() ||
+ !resultType.hasRank()) {
+ return rewriter.notifyMatchFailure(op, "not ranked shaped types");
+ }
+ // This should not happen, except in the context of type conversion.
+ if (inputType.getRank() != resultType.getRank()) {
+ return rewriter.notifyMatchFailure(op, "mismatched rank");
+ }
+
+ // Resolve dims to the most specific value.
+ int rank = inputType.getRank();
+ SmallVector<Value> dimSizes(rank);
+ auto resolveDimSize = [&](int position) -> Value {
+ if (!dimSizes[position]) {
+ // Find the most specific.
+ if (!inputType.isDynamicDim(position) ||
+ !resultType.isDynamicDim(position)) {
+ // Static dim.
+ int64_t dimSize = !inputType.isDynamicDim(position)
+ ? inputType.getDimSize(position)
+ : resultType.getDimSize(position);
+ dimSizes[position] = rewriter.create<ConstantIndexOp>(loc, dimSize);
+ } else {
+ // Dynamic dim.
+ dimSizes[position] =
+ rewriter.create<tensor::DimOp>(loc, input, position);
+ }
+ }
+
+ return dimSizes[position];
+ };
+
+ SmallVector<Value> sourceDynamicDims;
+ SmallVector<Value> targetDynamicDims;
+ for (int i = 0; i < rank; i++) {
+ if (inputType.isDynamicDim(i)) {
+ sourceDynamicDims.push_back(resolveDimSize(i));
+ }
+ if (resultType.isDynamicDim(i)) {
+ targetDynamicDims.push_back(resolveDimSize(i));
+ }
+ }
+
+ // TODO: Decide if this needs to be replaced with a flow.tensor.cast
+ // See https://github.com/google/iree/issues/6418
+ rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>(
+ op, resultType, input, sourceDynamicDims, targetDynamicDims);
+
return success();
}
};
-/// Converts operations that can map to flow.tensor.* operations.
-struct ConvertToFlowTensorOpsPass
- : public ConvertToFlowTensorOpsBase<ConvertToFlowTensorOpsPass> {
- ConvertToFlowTensorOpsPass(bool runBefore) {
- runBeforeDispatchRegionFormation = runBefore;
- }
- ConvertToFlowTensorOpsPass(const ConvertToFlowTensorOpsPass &that) {
- runBeforeDispatchRegionFormation = that.runBeforeDispatchRegionFormation;
- }
+struct ConvertTensorFromElementsPattern
+ : public OpRewritePattern<tensor::FromElementsOp> {
+ using OpRewritePattern<tensor::FromElementsOp>::OpRewritePattern;
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<IREE::Flow::FlowDialect, memref::MemRefDialect,
- mlir::StandardOpsDialect>();
- }
- void runOnOperation() override {
- FuncOp funcOp = getOperation();
- MLIRContext *context = funcOp->getContext();
- context->allowUnregisteredDialects(true);
- RewritePatternSet patterns(&getContext());
- if (runBeforeDispatchRegionFormation) {
- patterns.insert<
- LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
- LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>,
- SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(context);
- } else {
- patterns.insert<LinalgFillToFlowTensorSplat>(context);
+ LogicalResult matchAndRewrite(tensor::FromElementsOp op,
+ PatternRewriter &rewriter) const override {
+ // TODO: This pattern was mainly added to iron out some kinks specific to
+ // detensoring (see: https://github.com/google/iree/issues/1159). Do we need
+ // to expand this check for other uses?
+ if (op->getParentOfType<Flow::DispatchWorkgroupsOp>() ||
+ op.getType().getDimSize(0) != 1) {
+ return failure();
}
- IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
- }
+
+ auto loc = op.getLoc();
+ SmallVector<Value> dimSizes(1);
+ dimSizes[0] = rewriter.create<ConstantIndexOp>(loc, 1);
+ rewriter.replaceOpWithNewOp<IREE::Flow::TensorSplatOp>(
+ op, op.getType(), op.getOperand(0), dimSizes);
+ return success();
}
};
+
} // namespace
-std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass(
- bool runBeforeDispatchRegionFormation) {
- return std::make_unique<ConvertToFlowTensorOpsPass>(
- runBeforeDispatchRegionFormation);
+void populateTensorToFlowPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ patterns
+ .insert<ConvertTensorInsertSlicePattern, ConvertTensorExtractSlicePattern,
+ ConvertTensorCastPattern, ConvertTensorFromElementsPattern>(
+ context);
}
} // namespace Flow
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h
new file mode 100644
index 0000000..e5b4505
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h
@@ -0,0 +1,27 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_FLOW_CONVERSION_TENSORTOFLOW_CONVERTTENSORTOFLOW_H_
+#define IREE_COMPILER_DIALECT_FLOW_CONVERSION_TENSORTOFLOW_CONVERTTENSORTOFLOW_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+// Populates rewrite patterns for Tensor->Flow.
+void populateTensorToFlowPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_FLOW_CONVERSION_TENSORTOFLOW_CONVERTTENSORTOFLOW_H_
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD
new file mode 100644
index 0000000..1d1e9da
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = enforce_glob(
+ [
+ "cast.mlir",
+ "extract_slice.mlir",
+ "from_elements.mlir",
+ "insert_slice.mlir",
+ ],
+ include = ["*.mlir"],
+ ),
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-opt",
+ ],
+)
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt
new file mode 100644
index 0000000..d515c1e
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt
@@ -0,0 +1,26 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "cast.mlir"
+ "extract_slice.mlir"
+ "from_elements.mlir"
+ "insert_slice.mlir"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/cast.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/cast.mlir
new file mode 100644
index 0000000..5167626
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/cast.mlir
@@ -0,0 +1,42 @@
+// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-flow-convert-tensor-ops-pass %s | IreeFileCheck %s
+
+func @static_tensor_cast_to_dynamic(%arg0: tensor<4x4xf32>) -> tensor<?x?xf32> {
+ // CHECK-DAG: %[[C4:.*]] = constant 4 : index
+ // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<?x?xf32>{%[[C4]], %[[C4]]}
+ // CHECK: return %[[RESULT]]
+ %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+func @dynamic_tensor_cast_to_static(%arg0: tensor<?xf32>) -> tensor<4xf32> {
+ // CHECK: %[[C4:.*]] = constant 4 : index
+ // CHECK: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<?xf32>{%[[C4]]} -> tensor<4xf32>
+ // CHECK: return %[[RESULT]]
+ %0 = tensor.cast %arg0 : tensor<?xf32> to tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+func @dynamic_tensor_cast_to_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x3xf32> {
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ // CHECK-DAG: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
+ // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+ // CHECK: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<?x?xf32>{%[[D0]], %[[C3]]} -> tensor<?x3xf32>{%[[D0]]}
+ // CHECK: return %[[RESULT]]
+ %0 = tensor.cast %arg0 : tensor<?x?xf32> to tensor<?x3xf32>
+ return %0 : tensor<?x3xf32>
+}
+
+// -----
+func @tensor_cast_within_dispatch_workgroups_not_converted() -> tensor<f32> {
+ %x = constant 100 : index
+ %0 = flow.dispatch.workgroups[%x]() : () -> (tensor<f32>) = () {
+ // CHECK: = tensor.cast %[[source:.+]] : tensor<4x4xf32> to tensor<?x?xf32>
+ %1 = "test.source"() : () -> (tensor<4x4xf32>)
+ %2 = tensor.cast %1 : tensor<4x4xf32> to tensor<?x?xf32>
+ "test.sink"(%2) : (tensor<?x?xf32>) -> ()
+ flow.return
+ }
+ return %0 : tensor<f32>
+}
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/extract_slice.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/extract_slice.mlir
new file mode 100644
index 0000000..ce1bff7
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/extract_slice.mlir
@@ -0,0 +1,143 @@
+// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-flow-convert-tensor-ops-pass %s | IreeFileCheck %s
+
+func @extract_slice1(%arg0 : tensor<5x24x48xf32>) -> tensor<4xf32> {
+ %0 = tensor.extract_slice %arg0[2, 3, 4] [1, 1, 4] [1, 1, 1]
+ : tensor<5x24x48xf32> to tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+// CHECK-LABEL: func @extract_slice1(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x24x48xf32>)
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]][%[[C2]], %[[C3]], %[[C4]] for %[[C1]], %[[C1]], %[[C4]]]
+// CHECK: %[[RESULT:.+]] = flow.tensor.reshape %[[SLICE]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice2(%arg0 : tensor<5x24x48xf32>) -> tensor<2x48xf32> {
+ %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
+ : tensor<5x24x48xf32> to tensor<2x48xf32>
+ return %0 : tensor<2x48xf32>
+}
+// CHECK-LABEL: func @extract_slice2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x24x48xf32>)
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C48:.+]] = constant 48 : index
+// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]][%[[C2]], %[[C3]], %[[C0]] for %[[C1]], %[[C2]], %[[C48]]]
+// CHECK: %[[RESULT:.+]] = flow.tensor.reshape %[[SLICE]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice3(%arg0 : tensor<5x24x48xf32>) -> tensor<2x24xf32> {
+ %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 24] [1, 1, 1]
+ : tensor<5x24x48xf32> to tensor<2x24xf32>
+ return %0 : tensor<2x24xf32>
+}
+// CHECK-LABEL: func @extract_slice3
+// CHECK: tensor.extract_slice
+
+// -----
+
+func @extract_slice4(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<2x24xf32> {
+ %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 24] [1, %arg1, 1]
+ : tensor<5x24x48xf32> to tensor<2x24xf32>
+ return %0 : tensor<2x24xf32>
+}
+// CHECK-LABEL: func @extract_slice4
+// CHECK: tensor.extract_slice
+
+// -----
+
+func @extract_slice5(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<2x48xf32> {
+ %0 = tensor.extract_slice %arg0[2, %arg1, 0] [1, 2, 48] [1, 1, 1]
+ : tensor<5x24x48xf32> to tensor<2x48xf32>
+ return %0 : tensor<2x48xf32>
+}
+// CHECK-LABEL: func @extract_slice5
+// CHECK: tensor.extract_slice
+
+// -----
+
+func @extract_slice6(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<?x48xf32> {
+ %0 = tensor.extract_slice %arg0[2, 3, 0] [1, %arg1, 48] [1, 1, 1]
+ : tensor<5x24x48xf32> to tensor<?x48xf32>
+ return %0 : tensor<?x48xf32>
+}
+// CHECK-LABEL: func @extract_slice6
+// CHECK: tensor.extract_slice
+
+// -----
+
+func @extract_slice7(%arg0 : tensor<5x?x48xf32>, %arg1 : index) -> tensor<2x48xf32> {
+ %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
+ : tensor<5x?x48xf32> to tensor<2x48xf32>
+ return %0 : tensor<2x48xf32>
+}
+// CHECK-LABEL: func @extract_slice7(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x48xf32>
+// CHECK-SAME: %[[ARG1:.+]]: index)
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C48:.+]] = constant 48 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<5x?x48xf32>
+// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]][%[[C2]], %[[C3]], %[[C0]] for %[[C1]], %[[C2]], %[[C48]]]
+// CHECK: %[[RESULT:.+]] = flow.tensor.reshape %[[SLICE]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_extract_slice(%arg0: tensor<?x513xi32>) -> tensor<513xi32> {
+ %0 = tensor.extract_slice %arg0[4, 0] [1, 513] [1, 1] : tensor<?x513xi32> to tensor<513xi32>
+ return %0 : tensor<513xi32>
+}
+// CHECK-LABEL: func @rank_reducing_extract_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C513:.+]] = constant 513 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]]
+// CHECK-SAME: [%[[C4]], %[[C0]] for %[[C1]], %[[C513]]]
+// CHECK-SAME: : tensor<?x513xi32>{%[[DIM]]} -> tensor<1x513xi32>
+// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %[[SLICE]] : tensor<1x513xi32> -> tensor<513xi32>
+// CHECK: return %[[RESHAPE]] : tensor<513xi32>
+
+// -----
+
+func @rank_reducing_extract_slice_trailing_unit_dims
+ (%arg0 : tensor<1x50x20x1xf32>) -> tensor<49x20xf32> {
+ %0 = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<1x50x20x1xf32> to tensor<49x20xf32>
+ return %0 : tensor<49x20xf32>
+}
+// CHECK-LABEL: func @rank_reducing_extract_slice_trailing_unit_dims
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C49:.+]] = constant 49 : index
+// CHECK-DAG: %[[C20:.+]] = constant 20 : index
+// CHECK: %[[extract_slice:.+]] = flow.tensor.slice %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]] for %[[C1]], %[[C49]], %[[C20]], %[[C1]]] : tensor<1x50x20x1xf32> -> tensor<1x49x20x1xf32>
+// CHECK: flow.tensor.reshape %[[extract_slice]] : tensor<1x49x20x1xf32> -> tensor<49x20xf32>
+
+// -----
+
+func @extract_slice_within_dispatch_workgroups_not_converted() -> tensor<f32> {
+ %x = constant 100 : index
+ %0 = flow.dispatch.workgroups[%x]() : () -> (tensor<f32>) = () {
+ // CHECK: = tensor.extract_slice %[[source:.+]][2, 3, 4] [1, 1, 4] [1, 1, 1] : tensor<5x24x48xf32> to tensor<4xf32>
+ %1 = "test.source"() : () -> (tensor<5x24x48xf32>)
+ %2 = tensor.extract_slice %1[2, 3, 4] [1, 1, 4] [1, 1, 1]
+ : tensor<5x24x48xf32> to tensor<4xf32>
+ "test.sink"(%2) : (tensor<4xf32>) -> ()
+ flow.return
+ }
+ return %0 : tensor<f32>
+}
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
new file mode 100644
index 0000000..d380abe
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/from_elements.mlir
@@ -0,0 +1,35 @@
+// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-flow-convert-tensor-ops-pass %s | IreeFileCheck %s
+
+// CHECK: func @tensor.from_elements__to__flow.tensor.splat(%[[arg0:.*]]: i8)
+func @tensor.from_elements__to__flow.tensor.splat(%arg0: i8) -> (i8) {
+ // CHECK: %[[splat_res:.*]] = flow.tensor.splat %[[arg0]]
+ %0 = tensor.from_elements %arg0 : tensor<1xi8>
+ // CHECK: flow.tensor.load %[[splat_res]]
+ %1 = flow.tensor.load %0 : tensor<1xi8>
+ return %1 : i8
+}
+
+// -----
+// CHECK: func @tensor.from_elements__not_convertible(%[[arg0:.*]]: i8)
+func @tensor.from_elements__not_convertible(%arg0: i8) -> (i8) {
+ // CHECK: %[[c0:.*]] = constant 0
+ %c0 = constant 0 : index
+ // CHECK: %[[res:.*]] = tensor.from_elements %[[arg0]], %[[arg0]] : tensor<2xi8>
+ %0 = tensor.from_elements %arg0, %arg0 : tensor<2xi8>
+ // CHECK: flow.tensor.load %[[res]][%[[c0]]]
+ %1 = flow.tensor.load %0[%c0] : tensor<2xi8>
+ return %1 : i8
+}
+
+// -----
+func @tensor.from_elements__within_dispatch_workgroups_not_converted() -> tensor<f32> {
+ %x = constant 100 : index
+ %0 = flow.dispatch.workgroups[%x]() : () -> (tensor<f32>) = () {
+ // CHECK: = tensor.from_elements %[[source:.+]] : tensor<1xi8>
+ %1 = "test.source"() : () -> (i8)
+ %2 = tensor.from_elements %1 : tensor<1xi8>
+ "test.sink"(%2) : (tensor<1xi8>) -> ()
+ flow.return
+ }
+ return %0 : tensor<f32>
+}
diff --git a/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
new file mode 100644
index 0000000..383624e
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/insert_slice.mlir
@@ -0,0 +1,70 @@
+// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-flow-convert-tensor-ops-pass %s | IreeFileCheck %s
+
+func @insert_slice_convert
+ (%arg0 : tensor<?x24x48xf32>, %arg1 : tensor<1x4x48xf32>) ->
+ tensor<?x24x48xf32> {
+ %c0 = constant 0 : index
+ %0 = tensor.insert_slice %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
+ tensor<1x4x48xf32> into tensor<?x24x48xf32>
+ return %0 : tensor<?x24x48xf32>
+}
+// CHECK-LABEL: func @insert_slice_convert
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0
+// CHECK-DAG: %[[C2:.+]] = constant 2
+// CHECK-DAG: %[[C4:.+]] = constant 4
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[ARG1]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
+// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM0]]}
+
+// -----
+
+func @insert_slice_convert_rank_reducing
+ (%arg0 : tensor<?x24x48xf32>, %arg1 : tensor<4x48xf32>) ->
+ tensor<?x24x48xf32> {
+ %c0 = constant 0 : index
+ %0 = tensor.insert_slice %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
+ tensor<4x48xf32> into tensor<?x24x48xf32>
+ return %0 : tensor<?x24x48xf32>
+}
+// CHECK-LABEL: func @insert_slice_convert_rank_reducing
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0
+// CHECK-DAG: %[[C2:.+]] = constant 2
+// CHECK-DAG: %[[C4:.+]] = constant 4
+// CHECK-DAG: %[[RESHAPE:.+]] = flow.tensor.reshape %[[ARG1]] : tensor<4x48xf32> -> tensor<1x4x48xf32>
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[RESHAPE]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
+// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM]]}
+
+// -----
+
+func @rank_reducing_insert_slice_trailing_unit_dims
+ (%arg0 : tensor<49x20xf32>, %arg1 : tensor<1x50x20x1xf32>) -> tensor<1x50x20x1xf32> {
+ %0 = tensor.insert_slice %arg0 into %arg1[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<49x20xf32> into tensor<1x50x20x1xf32>
+ return %0 : tensor<1x50x20x1xf32>
+}
+// CHECK-LABEL: func @rank_reducing_insert_slice_trailing_unit_dims
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %{{.+}} : tensor<49x20xf32> -> tensor<1x49x20x1xf32>
+// CHECK: flow.tensor.update %[[RESHAPE]], %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]]] : tensor<1x49x20x1xf32> -> tensor<1x50x20x1xf32>
+
+
+// -----
+
+func @insert_slice_within_dispatch_workgroups_not_converted() -> tensor<f32> {
+ %x = constant 100 : index
+ %0 = flow.dispatch.workgroups[%x]() : () -> (tensor<f32>) = () {
+ // CHECK: = tensor.insert_slice %[[source2:.+]] into %[[source1:.+]][4, 2, 0] [1, 4, 48] [1, 1, 1] : tensor<1x4x48xf32> into tensor<?x24x48xf32>
+ %1 = "test.source1"() : () -> (tensor<?x24x48xf32>)
+ %2 = "test.source2"() : () -> (tensor<1x4x48xf32>)
+ %3 = tensor.insert_slice %2 into %1[4, 2, 0] [1, 4, 48] [1, 1, 1] :
+ tensor<1x4x48xf32> into tensor<?x24x48xf32>
+ "test.sink"(%3) : (tensor<?x24x48xf32>) -> ()
+ flow.return
+ }
+ return %0 : tensor<f32>
+}
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index 8b26a4b..12ff92c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -33,7 +33,8 @@
srcs = [
"ConvertConv2D1x1ToMatmulPass.cpp",
"ConvertConv2DToImg2ColPass.cpp",
- "ConvertToFlowTensorOps.cpp",
+ "ConvertLinalgTensorOps.cpp",
+ "ConvertTensorOps.cpp",
"DeduplicateExecutables.cpp",
"DestructiveUpdateUtils.cpp",
"DispatchLinalgOnTensors.cpp",
@@ -65,6 +66,7 @@
],
deps = [
":PassesIncGen",
+ "//iree/compiler/Dialect/Flow/Conversion/TensorToFlow",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/IR",
"//iree/compiler/Dialect/IREE/IR",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index e8c00f3..df5ebdb 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -30,7 +30,8 @@
SRCS
"ConvertConv2D1x1ToMatmulPass.cpp"
"ConvertConv2DToImg2ColPass.cpp"
- "ConvertToFlowTensorOps.cpp"
+ "ConvertLinalgTensorOps.cpp"
+ "ConvertTensorOps.cpp"
"DeduplicateExecutables.cpp"
"DestructiveUpdateUtils.cpp"
"DispatchLinalgOnTensors.cpp"
@@ -69,6 +70,7 @@
MLIRTensor
MLIRTransformUtils
MLIRTransforms
+ iree::compiler::Dialect::Flow::Conversion::TensorToFlow
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::IREE::IR
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
new file mode 100644
index 0000000..b89682c
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertLinalgTensorOps.cpp
@@ -0,0 +1,137 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-flow-convert-linalg-tensor-ops"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+/// Generates `tensor.dim` operations to get the dynamic sizes of a value `v`.
+static SmallVector<Value, 4> getDynamicDimValues(OpBuilder &b, Location loc,
+ Value v) {
+ SmallVector<Value, 4> dynamicDims;
+ for (auto dim : llvm::enumerate(v.getType().cast<ShapedType>().getShape())) {
+ if (dim.value() != ShapedType::kDynamicSize) continue;
+ dynamicDims.push_back(b.createOrFold<tensor::DimOp>(loc, v, dim.index()));
+ }
+ return dynamicDims;
+}
+
+/// Converts linalg.tensor_reshape operations into flow.tensor.reshape
+/// operations.
+template <typename TensorReshapeOp>
+struct LinalgTensorReshapeToFlowTensorReshape
+ : public OpRewritePattern<TensorReshapeOp> {
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ if (reshapeOp->template getParentOfType<Flow::DispatchWorkgroupsOp>()) {
+ return failure();
+ }
+ SmallVector<SmallVector<Value>> outputShape;
+ if (failed(reshapeOp.reifyResultShapes(rewriter, outputShape))) {
+ return failure();
+ }
+ SmallVector<Value> outputDynamicShapes;
+ for (auto shape :
+ llvm::zip(reshapeOp.getResultType().getShape(), outputShape[0])) {
+ if (std::get<0>(shape) != ShapedType::kDynamicSize) continue;
+ outputDynamicShapes.push_back(std::get<1>(shape));
+ }
+ rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>(
+ reshapeOp, reshapeOp.getResultType(), reshapeOp.src(),
+ outputDynamicShapes);
+ return success();
+ }
+};
+
+/// Converts linalg.fill ops into flow.tensor.splat ops.
+///
+/// This is expected to improve performance because we can use DMA
+/// functionalities for the fill, instead of dispatching kernels.
+struct LinalgFillToFlowTensorSplat final
+ : public OpRewritePattern<linalg::FillOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::FillOp fillOp,
+ PatternRewriter &rewriter) const override {
+ if (fillOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+ // Don't convert linalg.fill ops that were fused together with other ops.
+ return failure();
+ }
+
+ SmallVector<Value, 4> dynamicDims =
+ getDynamicDimValues(rewriter, fillOp.getLoc(), fillOp.output());
+ rewriter.replaceOpWithNewOp<TensorSplatOp>(
+ fillOp, fillOp.output().getType(), fillOp.value(), dynamicDims);
+ return success();
+ }
+};
+
+/// Converts linalg operations that can map to flow.tensor.* operations.
+struct ConvertLinalgTensorOpsPass
+ : public ConvertLinalgTensorOpsBase<ConvertLinalgTensorOpsPass> {
+ ConvertLinalgTensorOpsPass(bool runBefore) {
+ runBeforeDispatchRegionFormation = runBefore;
+ }
+ ConvertLinalgTensorOpsPass(const ConvertLinalgTensorOpsPass &that) {
+ runBeforeDispatchRegionFormation = that.runBeforeDispatchRegionFormation;
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Flow::FlowDialect, tensor::TensorDialect,
+ linalg::LinalgDialect, mlir::StandardOpsDialect>();
+ }
+ void runOnOperation() override {
+ FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp->getContext();
+ context->allowUnregisteredDialects(true);
+ RewritePatternSet patterns(&getContext());
+ if (runBeforeDispatchRegionFormation) {
+ patterns.insert<
+ LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
+ LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>>(
+ context);
+ } else {
+ patterns.insert<LinalgFillToFlowTensorSplat>(context);
+ }
+ IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTensorOpsPass(
+ bool runBeforeDispatchRegionFormation) {
+ return std::make_unique<ConvertLinalgTensorOpsPass>(
+ runBeforeDispatchRegionFormation);
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertTensorOps.cpp
new file mode 100644
index 0000000..f0cf6de
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertTensorOps.cpp
@@ -0,0 +1,55 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/ConvertTensorToFlow.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-flow-convert-tensor-ops"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+namespace {
+
+/// Converts operations that can map to flow.tensor.* operations.
+struct ConvertTensorOpsPass
+ : public ConvertTensorOpsBase<ConvertTensorOpsPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Flow::FlowDialect, mlir::StandardOpsDialect,
+ tensor::TensorDialect>();
+ }
+ void runOnOperation() override {
+ FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp->getContext();
+ context->allowUnregisteredDialects(true);
+ RewritePatternSet patterns(&getContext());
+ populateTensorToFlowPatterns(&getContext(), patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertTensorOpsPass() {
+ return std::make_unique<ConvertTensorOpsPass>();
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 86696f6..18c51b9 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -125,15 +125,16 @@
passManager.addNestedPass<FuncOp>(mlir::createLinalgDetensorizePass());
}
passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
+ passManager.addNestedPass<FuncOp>(IREE::Flow::createConvertTensorOpsPass());
passManager.addNestedPass<FuncOp>(
- IREE::Flow::createConvertToFlowTensorOpsPass(
+ IREE::Flow::createConvertLinalgTensorOpsPass(
/*runBeforeDispatchRegionFormation=*/true));
passManager.addNestedPass<FuncOp>(mlir::createCanonicalizerPass());
passManager.addNestedPass<FuncOp>(
IREE::Flow::createDispatchLinalgOnTensorsPass());
passManager.addPass(memref::createResolveShapedTypeResultDimsPass());
passManager.addNestedPass<FuncOp>(
- IREE::Flow::createConvertToFlowTensorOpsPass(
+ IREE::Flow::createConvertLinalgTensorOpsPass(
/*runBeforeDispatchRegionFormation=*/false));
// NOTE: required because the current dispatch-linalg-on-tensors pass
// creates a lot of dead IR that needs to be cleaned up.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 445190e..130c13f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -61,11 +61,17 @@
/// the most inner loops.
std::unique_ptr<OperationPass<FuncOp>> createInterchangeGenericOpsPass();
-// Convert operations to equivalent flow.tensor.* ops.
+// Convert tensor operations to equivalent flow.tensor.* operations
// `runBeforeDispatchRegionFormation` controls whether to run before dispatch
// region creation. If run after, it will catch operations that were left
// outside of dispatch regions and could be represented as flow.tensor.* ops.
-std::unique_ptr<OperationPass<FuncOp>> createConvertToFlowTensorOpsPass(
+std::unique_ptr<OperationPass<FuncOp>> createConvertTensorOpsPass();
+
+// Convert linalg.tensor operations to equivalent flow.tensor.* ops.
+// `runBeforeDispatchRegionFormation` controls whether to run before dispatch
+// region creation. If run after, it will catch operations that were left
+// outside of dispatch regions and could be represented as flow.tensor.* ops.
+std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgTensorOpsPass(
bool runBeforeDispatchRegionFormation = true);
// Promote I1 tensor constants to I8 tensors to match later operations.
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.td b/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 4dfc814..60f708a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -21,10 +21,16 @@
let constructor = "mlir::iree_compiler::IREE::Flow::createConvertConv2DToImg2ColPass()";
}
-def ConvertToFlowTensorOps :
- Pass<"iree-flow-convert-to-flow-tensor-ops-pass", "FuncOp"> {
- let summary = "Convert operations to equivalent flow.tensor.* operations";
- let constructor = "mlir::iree_compiler::IREE::Flow::createConvertToFlowTensorOpsPass()";
+def ConvertTensorOps :
+ Pass<"iree-flow-convert-tensor-ops-pass", "FuncOp"> {
+ let summary = "Convert tensor operations to equivalent flow.tensor.* operations";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createConvertTensorOpsPass()";
+}
+
+def ConvertLinalgTensorOps :
+ Pass<"iree-flow-convert-linalg-tensor-ops-pass", "FuncOp"> {
+ let summary = "Convert linalg operations to equivalent flow.tensor.* operations";
+ let constructor = "mlir::iree_compiler::IREE::Flow::createConvertLinalgTensorOpsPass()";
let options = [
Option<"runBeforeDispatchRegionFormation", "run-before-dispatch-region-formation",
"bool", /*default=*/"true", "Run the pass before dispatch region formation">
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/BUILD b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
index b16ee25..12c0f37 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/test/BUILD
@@ -19,8 +19,8 @@
[
"conv1x1_to_matmul.mlir",
"conv2d_to_img2col.mlir",
- "convert_to_flow_tensor_ops_after.mlir",
- "convert_to_flow_tensor_ops_before.mlir",
+ "convert_linalg_tensor_ops_after.mlir",
+ "convert_linalg_tensor_ops_before.mlir",
"deduplicate_executables.mlir",
"dispatch_linalg_on_tensors.mlir",
"dispatch_linalg_on_tensors_elementwise.mlir",
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
index 01a4294..83b449a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/test/CMakeLists.txt
@@ -16,8 +16,8 @@
SRCS
"conv1x1_to_matmul.mlir"
"conv2d_to_img2col.mlir"
- "convert_to_flow_tensor_ops_after.mlir"
- "convert_to_flow_tensor_ops_before.mlir"
+ "convert_linalg_tensor_ops_after.mlir"
+ "convert_linalg_tensor_ops_before.mlir"
"deduplicate_executables.mlir"
"dispatch_linalg_on_tensors.mlir"
"dispatch_linalg_on_tensors_elementwise.mlir"
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_after.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
similarity index 91%
rename from iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_after.mlir
rename to iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
index f126e54..e3e964e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_after.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_after.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -iree-flow-convert-to-flow-tensor-ops-pass='run-before-dispatch-region-formation=false' -canonicalize -cse -split-input-file %s | IreeFileCheck %s
+// RUN: iree-opt -iree-flow-convert-linalg-tensor-ops-pass='run-before-dispatch-region-formation=false' -canonicalize -cse -split-input-file %s | IreeFileCheck %s
func @turn_fill_into_splat(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_before.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_before.mlir
new file mode 100644
index 0000000..621912c
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_linalg_tensor_ops_before.mlir
@@ -0,0 +1,17 @@
+// RUN: iree-opt -iree-flow-convert-linalg-tensor-ops-pass -canonicalize -cse -split-input-file %s | IreeFileCheck %s
+
+func @tensor_reshape(%arg0 : tensor<?x4x?x5x?x6xf32>, %arg1 : tensor<20x?x40xf32>)
+ -> (tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>)
+{
+ %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4, 5]]
+ : tensor<?x4x?x5x?x6xf32> into tensor<?x5x?xf32>
+ %1 = linalg.tensor_expand_shape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
+ : tensor<20x?x40xf32> into tensor<5x4x?x4x2x4x5xf32>
+ return %0, %1 : tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>
+}
+// CHECK-LABEL: func @tensor_reshape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x4x?x5x?x6xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<20x?x40xf32>
+// CHECK-DAG: %[[R0:.+]] = flow.tensor.reshape %[[ARG0]]
+// CHECK-DAG: %[[R1:.+]] = flow.tensor.reshape %[[ARG1]]
+// CHECK: return %[[R0]], %[[R1]]
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_before.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_before.mlir
deleted file mode 100644
index 01ada4d..0000000
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops_before.mlir
+++ /dev/null
@@ -1,200 +0,0 @@
-// RUN: iree-opt -iree-flow-convert-to-flow-tensor-ops-pass -canonicalize -cse -split-input-file %s | IreeFileCheck %s
-
-func @subtensor1(%arg0 : tensor<5x24x48xf32>) -> tensor<4xf32> {
- %0 = tensor.extract_slice %arg0[2, 3, 4] [1, 1, 4] [1, 1, 1]
- : tensor<5x24x48xf32> to tensor<4xf32>
- return %0 : tensor<4xf32>
-}
-// CHECK-LABEL: func @subtensor1(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<5x24x48xf32>)
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]][%[[C2]], %[[C3]], %[[C4]] for %[[C1]], %[[C1]], %[[C4]]]
-// CHECK: %[[RESULT:.+]] = flow.tensor.reshape %[[SLICE]]
-// CHECK: return %[[RESULT]]
-
-// -----
-
-func @subtensor2(%arg0 : tensor<5x24x48xf32>) -> tensor<2x48xf32> {
- %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
- : tensor<5x24x48xf32> to tensor<2x48xf32>
- return %0 : tensor<2x48xf32>
-}
-// CHECK-LABEL: func @subtensor2
-// CHECK-SAME: %[[ARG0:.+]]: tensor<5x24x48xf32>)
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C48:.+]] = constant 48 : index
-// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]][%[[C2]], %[[C3]], %[[C0]] for %[[C1]], %[[C2]], %[[C48]]]
-// CHECK: %[[RESULT:.+]] = flow.tensor.reshape %[[SLICE]]
-// CHECK: return %[[RESULT]]
-
-// -----
-
-func @subtensor3(%arg0 : tensor<5x24x48xf32>) -> tensor<2x24xf32> {
- %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 24] [1, 1, 1]
- : tensor<5x24x48xf32> to tensor<2x24xf32>
- return %0 : tensor<2x24xf32>
-}
-// CHECK-LABEL: func @subtensor3
-// CHECK: tensor.extract_slice
-
-// -----
-
-func @subtensor4(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<2x24xf32> {
- %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 24] [1, %arg1, 1]
- : tensor<5x24x48xf32> to tensor<2x24xf32>
- return %0 : tensor<2x24xf32>
-}
-// CHECK-LABEL: func @subtensor4
-// CHECK: tensor.extract_slice
-
-// -----
-
-func @subtensor5(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<2x48xf32> {
- %0 = tensor.extract_slice %arg0[2, %arg1, 0] [1, 2, 48] [1, 1, 1]
- : tensor<5x24x48xf32> to tensor<2x48xf32>
- return %0 : tensor<2x48xf32>
-}
-// CHECK-LABEL: func @subtensor5
-// CHECK: tensor.extract_slice
-
-// -----
-
-func @subtensor6(%arg0 : tensor<5x24x48xf32>, %arg1 : index) -> tensor<?x48xf32> {
- %0 = tensor.extract_slice %arg0[2, 3, 0] [1, %arg1, 48] [1, 1, 1]
- : tensor<5x24x48xf32> to tensor<?x48xf32>
- return %0 : tensor<?x48xf32>
-}
-// CHECK-LABEL: func @subtensor6
-// CHECK: tensor.extract_slice
-
-// -----
-
-func @subtensor7(%arg0 : tensor<5x?x48xf32>, %arg1 : index) -> tensor<2x48xf32> {
- %0 = tensor.extract_slice %arg0[2, 3, 0] [1, 2, 48] [1, 1, 1]
- : tensor<5x?x48xf32> to tensor<2x48xf32>
- return %0 : tensor<2x48xf32>
-}
-// CHECK-LABEL: func @subtensor7(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x48xf32>
-// CHECK-SAME: %[[ARG1:.+]]: index)
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C48:.+]] = constant 48 : index
-// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<5x?x48xf32>
-// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]][%[[C2]], %[[C3]], %[[C0]] for %[[C1]], %[[C2]], %[[C48]]]
-// CHECK: %[[RESULT:.+]] = flow.tensor.reshape %[[SLICE]]
-// CHECK: return %[[RESULT]]
-
-// -----
-
-func @rank_reducing_subtensor(%arg0: tensor<?x513xi32>) -> tensor<513xi32> {
- %0 = tensor.extract_slice %arg0[4, 0] [1, 513] [1, 1] : tensor<?x513xi32> to tensor<513xi32>
- return %0 : tensor<513xi32>
-}
-// CHECK-LABEL: func @rank_reducing_subtensor
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C513:.+]] = constant 513 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]]
-// CHECK-SAME: [%[[C4]], %[[C0]] for %[[C1]], %[[C513]]]
-// CHECK-SAME: : tensor<?x513xi32>{%[[DIM]]} -> tensor<1x513xi32>
-// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %[[SLICE]] : tensor<1x513xi32> -> tensor<513xi32>
-// CHECK: return %[[RESHAPE]] : tensor<513xi32>
-
-// -----
-
-func @tensor_reshape(%arg0 : tensor<?x4x?x5x?x6xf32>, %arg1 : tensor<20x?x40xf32>)
- -> (tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>)
-{
- %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4, 5]]
- : tensor<?x4x?x5x?x6xf32> into tensor<?x5x?xf32>
- %1 = linalg.tensor_expand_shape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
- : tensor<20x?x40xf32> into tensor<5x4x?x4x2x4x5xf32>
- return %0, %1 : tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>
-}
-// CHECK-LABEL: func @tensor_reshape
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x4x?x5x?x6xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<20x?x40xf32>
-// CHECK-DAG: %[[R0:.+]] = flow.tensor.reshape %[[ARG0]]
-// CHECK-DAG: %[[R1:.+]] = flow.tensor.reshape %[[ARG1]]
-// CHECK: return %[[R0]], %[[R1]]
-
-// -----
-
-func @subtensor_insert_convert
- (%arg0 : tensor<?x24x48xf32>, %arg1 : tensor<1x4x48xf32>) ->
- tensor<?x24x48xf32> {
- %c0 = constant 0 : index
- %0 = tensor.insert_slice %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
- tensor<1x4x48xf32> into tensor<?x24x48xf32>
- return %0 : tensor<?x24x48xf32>
-}
-// CHECK-LABEL: func @subtensor_insert_convert
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[C0:.+]] = constant 0
-// CHECK-DAG: %[[C2:.+]] = constant 2
-// CHECK-DAG: %[[C4:.+]] = constant 4
-// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[ARG1]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
-// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM0]]}
-
-// -----
-
-func @subtensor_insert_convert_rank_reducing
- (%arg0 : tensor<?x24x48xf32>, %arg1 : tensor<4x48xf32>) ->
- tensor<?x24x48xf32> {
- %c0 = constant 0 : index
- %0 = tensor.insert_slice %arg1 into %arg0[4, 2, 0] [1, 4, 48] [1, 1, 1] :
- tensor<4x48xf32> into tensor<?x24x48xf32>
- return %0 : tensor<?x24x48xf32>
-}
-// CHECK-LABEL: func @subtensor_insert_convert_rank_reducing
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[C0:.+]] = constant 0
-// CHECK-DAG: %[[C2:.+]] = constant 2
-// CHECK-DAG: %[[C4:.+]] = constant 4
-// CHECK-DAG: %[[RESHAPE:.+]] = flow.tensor.reshape %[[ARG1]] : tensor<4x48xf32> -> tensor<1x4x48xf32>
-// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[RESHAPE]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
-// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM]]}
-
-// -----
-
-func @rank_reducing_subtensor_insert_trailing_unit_dims
- (%arg0 : tensor<49x20xf32>, %arg1 : tensor<1x50x20x1xf32>) -> tensor<1x50x20x1xf32> {
- %0 = tensor.insert_slice %arg0 into %arg1[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<49x20xf32> into tensor<1x50x20x1xf32>
- return %0 : tensor<1x50x20x1xf32>
-}
-// CHECK-LABEL: func @rank_reducing_subtensor_insert_trailing_unit_dims
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK: %[[RESHAPE:.+]] = flow.tensor.reshape %{{.+}} : tensor<49x20xf32> -> tensor<1x49x20x1xf32>
-// CHECK: flow.tensor.update %[[RESHAPE]], %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]]] : tensor<1x49x20x1xf32> -> tensor<1x50x20x1xf32>
-
-// -----
-
-func @rank_reducing_subtensor_trailing_unit_dims
- (%arg0 : tensor<1x50x20x1xf32>) -> tensor<49x20xf32> {
- %0 = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 49, 20, 1] [1, 1, 1, 1] : tensor<1x50x20x1xf32> to tensor<49x20xf32>
- return %0 : tensor<49x20xf32>
-}
-// CHECK-LABEL: func @rank_reducing_subtensor_trailing_unit_dims
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C49:.+]] = constant 49 : index
-// CHECK-DAG: %[[C20:.+]] = constant 20 : index
-// CHECK: %[[SUBTENSOR:.+]] = flow.tensor.slice %{{.+}}[%[[C0]], %[[C1]], %[[C0]], %[[C0]] for %[[C1]], %[[C49]], %[[C20]], %[[C1]]] : tensor<1x50x20x1xf32> -> tensor<1x49x20x1xf32>
-// CHECK: flow.tensor.reshape %[[SUBTENSOR]] : tensor<1x49x20x1xf32> -> tensor<49x20xf32>
diff --git a/iree/compiler/InputConversion/Common/BUILD b/iree/compiler/InputConversion/Common/BUILD
index 4677590..7e22c00 100644
--- a/iree/compiler/InputConversion/Common/BUILD
+++ b/iree/compiler/InputConversion/Common/BUILD
@@ -44,7 +44,6 @@
cc_library(
name = "Common",
srcs = [
- "ConvertUpstreamToIREE.cpp",
"Passes.cpp",
"TopLevelSCFToCFG.cpp",
],
@@ -54,13 +53,11 @@
deps = [
":PassHeaders",
":PassesIncGen",
- "//iree/compiler/Dialect/Flow/IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToStandard",
"@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/iree/compiler/InputConversion/Common/CMakeLists.txt b/iree/compiler/InputConversion/Common/CMakeLists.txt
index d35c59e..7c2e3a8 100644
--- a/iree/compiler/InputConversion/Common/CMakeLists.txt
+++ b/iree/compiler/InputConversion/Common/CMakeLists.txt
@@ -39,7 +39,6 @@
HDRS
"Passes.h"
SRCS
- "ConvertUpstreamToIREE.cpp"
"Passes.cpp"
"TopLevelSCFToCFG.cpp"
DEPS
@@ -50,9 +49,7 @@
MLIRSCF
MLIRSCFToStandard
MLIRStandard
- MLIRTensor
MLIRTransforms
- iree::compiler::Dialect::Flow::IR
PUBLIC
)
diff --git a/iree/compiler/InputConversion/Common/ConvertUpstreamToIREE.cpp b/iree/compiler/InputConversion/Common/ConvertUpstreamToIREE.cpp
deleted file mode 100644
index 5b7b76b..0000000
--- a/iree/compiler/InputConversion/Common/ConvertUpstreamToIREE.cpp
+++ /dev/null
@@ -1,164 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/InputConversion/Common/PassDetail.h"
-#include "iree/compiler/InputConversion/Common/Passes.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct ConvertTensorCastPattern : public OpConversionPattern<tensor::CastOp> {
- using OpConversionPattern<tensor::CastOp>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- tensor::CastOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- Value input = operands.front();
- ShapedType inputType = input.getType().dyn_cast<ShapedType>();
- ShapedType resultType =
- typeConverter->convertType(op.getType()).dyn_cast_or_null<ShapedType>();
- if (!inputType || !resultType || !inputType.hasRank() ||
- !resultType.hasRank()) {
- return rewriter.notifyMatchFailure(op, "not ranked shaped types");
- }
- // This should not happen, except in the context of type conversion.
- if (inputType.getRank() != resultType.getRank()) {
- return rewriter.notifyMatchFailure(op, "mismatched rank");
- }
-
- // Resolve dims to the most specific value.
- int rank = inputType.getRank();
- SmallVector<Value> dimSizes(rank);
- auto resolveDimSize = [&](int position) -> Value {
- if (!dimSizes[position]) {
- // Find the most specific.
- if (!inputType.isDynamicDim(position) ||
- !resultType.isDynamicDim(position)) {
- // Static dim.
- int64_t dimSize = !inputType.isDynamicDim(position)
- ? inputType.getDimSize(position)
- : resultType.getDimSize(position);
- dimSizes[position] = rewriter.create<ConstantIndexOp>(loc, dimSize);
- } else {
- // Dynamic dim.
- dimSizes[position] =
- rewriter.create<tensor::DimOp>(loc, input, position);
- }
- }
-
- return dimSizes[position];
- };
-
- SmallVector<Value> sourceDynamicDims;
- SmallVector<Value> targetDynamicDims;
- for (int i = 0; i < rank; i++) {
- if (inputType.isDynamicDim(i)) {
- sourceDynamicDims.push_back(resolveDimSize(i));
- }
- if (resultType.isDynamicDim(i)) {
- targetDynamicDims.push_back(resolveDimSize(i));
- }
- }
-
- // TODO: Decide if this needs to be replaced with a flow.tensor.cast
- // See https://github.com/google/iree/issues/6418
- rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>(
- op, resultType, input, sourceDynamicDims, targetDynamicDims);
-
- return success();
- }
-};
-
-struct ConvertTensorFromElementsPattern
- : public OpConversionPattern<tensor::FromElementsOp> {
- using OpConversionPattern<tensor::FromElementsOp>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- tensor::FromElementsOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
-
- if (shouldBeConverted(op)) {
- SmallVector<Value> dimSizes(1);
- dimSizes[0] = rewriter.create<ConstantIndexOp>(loc, 1);
- rewriter.replaceOpWithNewOp<IREE::Flow::TensorSplatOp>(
- op, op.getType(), operands.front(), dimSizes);
- }
-
- return success();
- }
-
- // TODO: This pattern was mainly added to iron out some kinks specific to
- // detensoring (see: https://github.com/google/iree/issues/1159). Do we need
- // to expand this check for other uses?
- static bool shouldBeConverted(tensor::FromElementsOp op) {
- return op.getType().getDimSize(0) == 1;
- }
-};
-} // namespace
-
-void populateConvertUpstreamToIREEPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.add<ConvertTensorCastPattern>(typeConverter, context);
- patterns.add<ConvertTensorFromElementsPattern>(typeConverter, context);
-}
-
-namespace {
-
-struct ConvertUpstreamToIREEPass
- : public ConvertUpstreamToIREEBase<ConvertUpstreamToIREEPass> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<StandardOpsDialect, tensor::TensorDialect,
- IREE::Flow::FlowDialect>();
- }
-
- void runOnOperation() override;
-};
-
-} // namespace
-
-void ConvertUpstreamToIREEPass::runOnOperation() {
- OwningRewritePatternList patterns(&getContext());
- MLIRContext *context = &getContext();
- TypeConverter typeConverter;
- typeConverter.addConversion([](Type t) { return t; });
- populateConvertUpstreamToIREEPatterns(&getContext(), typeConverter, patterns);
-
- ConversionTarget target(*context);
- target.addIllegalOp<tensor::CastOp>();
- target.addDynamicallyLegalOp<tensor::FromElementsOp>(
- [](tensor::FromElementsOp op) {
- return !ConvertTensorFromElementsPattern::shouldBeConverted(op);
- });
-
- target.addLegalDialect<StandardOpsDialect>();
- target.addLegalDialect<tensor::TensorDialect>();
- target.addLegalDialect<IREE::Flow::FlowDialect>();
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns)))) {
- return signalPassFailure();
- }
-}
-
-std::unique_ptr<OperationPass<FuncOp>> createConvertUpstreamToIREE() {
- return std::make_unique<ConvertUpstreamToIREEPass>();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/InputConversion/Common/Passes.cpp b/iree/compiler/InputConversion/Common/Passes.cpp
index 7e4e3d6..84332fd 100644
--- a/iree/compiler/InputConversion/Common/Passes.cpp
+++ b/iree/compiler/InputConversion/Common/Passes.cpp
@@ -6,39 +6,17 @@
#include "iree/compiler/InputConversion/Common/Passes.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Pass/PassOptions.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Passes.h"
-
namespace mlir {
namespace iree_compiler {
-void registerCommonConversionPassPipelines() {
- PassPipelineRegistration<> common(
- "iree-common-input-transformation-pipeline",
- "Runs the common input transformation pipeline",
- [](OpPassManager &passManager) {
- buildCommonInputConversionPassPipeline(passManager);
- });
-}
-
-// Common transformations to prepare input dialects for IREE.
-void buildCommonInputConversionPassPipeline(OpPassManager &passManager) {
- passManager.addNestedPass<FuncOp>(createConvertUpstreamToIREE());
-}
-
namespace {
#define GEN_PASS_REGISTRATION
#include "iree/compiler/InputConversion/Common/Passes.h.inc" // IWYU pragma: export
} // namespace
void registerCommonInputConversionPasses() {
- // Generated.
+ // Generated passes.
registerPasses();
-
- // Pipelines.
- registerCommonConversionPassPipelines();
}
} // namespace iree_compiler
diff --git a/iree/compiler/InputConversion/Common/Passes.h b/iree/compiler/InputConversion/Common/Passes.h
index 1a6f703..f67587a 100644
--- a/iree/compiler/InputConversion/Common/Passes.h
+++ b/iree/compiler/InputConversion/Common/Passes.h
@@ -14,28 +14,10 @@
namespace iree_compiler {
//===----------------------------------------------------------------------===//
-// Pipelines
-//===----------------------------------------------------------------------===//
-
-// Performs input legalization for specific combination of input dialects.
-void buildCommonInputConversionPassPipeline(OpPassManager &passManager);
-
-void registerCommonConversionPassPipelines();
-
-//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
std::unique_ptr<OperationPass<FuncOp>> createTopLevelSCFToCFGPass();
-std::unique_ptr<OperationPass<FuncOp>> createConvertUpstreamToIREE();
-
-//===----------------------------------------------------------------------===//
-// Patterns
-//===----------------------------------------------------------------------===//
-
-void populateConvertUpstreamToIREEPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
//===----------------------------------------------------------------------===//
// Register all Passes
diff --git a/iree/compiler/InputConversion/Common/Passes.td b/iree/compiler/InputConversion/Common/Passes.td
index 3eca76c..7c225c6 100644
--- a/iree/compiler/InputConversion/Common/Passes.td
+++ b/iree/compiler/InputConversion/Common/Passes.td
@@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#ifndef IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES
-#define IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES
+#ifndef IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
+#define IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
include "mlir/Pass/PassBase.td"
@@ -15,10 +15,4 @@
let constructor = "mlir::iree_compiler::createTopLevelSCFToCFGPass()";
}
-def ConvertUpstreamToIREE :
- Pass<"iree-convert-upstream-to-iree", "FuncOp"> {
- let summary = "Catch-all pass to convert upstream MLIR ops that (for whatever reason) we prefer to be represented differently in IREE";
- let constructor = "mlir::iree_compiler::createConvertUpstreamToIREE()";
-}
-
-#endif // IREE_COMPILER_INPUTCONVERSION_TOSA_PASSES
+#endif // IREE_COMPILER_INPUTCONVERSION_COMMON_PASSES
diff --git a/iree/compiler/InputConversion/Common/test/BUILD b/iree/compiler/InputConversion/Common/test/BUILD
index d21c9fa..d9e87bd 100644
--- a/iree/compiler/InputConversion/Common/test/BUILD
+++ b/iree/compiler/InputConversion/Common/test/BUILD
@@ -19,7 +19,6 @@
name = "lit",
srcs = enforce_glob(
[
- "convert_upstream_to_iree.mlir",
"top_level_scf_to_cfg.mlir",
],
include = ["*.mlir"],
diff --git a/iree/compiler/InputConversion/Common/test/CMakeLists.txt b/iree/compiler/InputConversion/Common/test/CMakeLists.txt
index c3e22eb..ab43294 100644
--- a/iree/compiler/InputConversion/Common/test/CMakeLists.txt
+++ b/iree/compiler/InputConversion/Common/test/CMakeLists.txt
@@ -14,7 +14,6 @@
NAME
lit
SRCS
- "convert_upstream_to_iree.mlir"
"top_level_scf_to_cfg.mlir"
DATA
iree::tools::IreeFileCheck
diff --git a/iree/compiler/InputConversion/Common/test/convert_upstream_to_iree.mlir b/iree/compiler/InputConversion/Common/test/convert_upstream_to_iree.mlir
deleted file mode 100644
index 865183a..0000000
--- a/iree/compiler/InputConversion/Common/test/convert_upstream_to_iree.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-convert-upstream-to-iree %s | IreeFileCheck %s
-
-func @static_tensor_cast_to_dynamic(%arg0: tensor<4x4xf32>) -> tensor<?x?xf32> {
- // CHECK-DAG: %[[C4_0:.*]] = constant 4 : index
- // CHECK-DAG: %[[C4_1:.*]] = constant 4 : index
- // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x4xf32> -> tensor<?x?xf32>{%[[C4_0]], %[[C4_1]]}
- // CHECK: return %[[RESULT]]
- %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
- return %0 : tensor<?x?xf32>
-}
-
-// -----
-func @dynamic_tensor_cast_to_static(%arg0: tensor<?xf32>) -> tensor<4xf32> {
- // CHECK: %[[C4_0:.*]] = constant 4 : index
- // CHECK: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<?xf32>{%[[C4_0]]} -> tensor<4xf32>
- // CHECK: return %[[RESULT]]
- %0 = tensor.cast %arg0 : tensor<?xf32> to tensor<4xf32>
- return %0 : tensor<4xf32>
-}
-
-// -----
-func @dynamic_tensor_cast_to_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x3xf32> {
- // CHECK-DAG: %[[C0:.*]] = constant 0 : index
- // CHECK-DAG: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
- // CHECK-DAG: %[[C3:.*]] = constant 3 : index
- // CHECK: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<?x?xf32>{%[[D0]], %[[C3]]} -> tensor<?x3xf32>{%[[D0]]}
- // CHECK: return %[[RESULT]]
- %0 = tensor.cast %arg0 : tensor<?x?xf32> to tensor<?x3xf32>
- return %0 : tensor<?x3xf32>
-}
-
-// -----
-// CHECK: func @tensor.from_elements__to__flow.tensor.splat(%[[arg0:.*]]: i8)
-func @tensor.from_elements__to__flow.tensor.splat(%arg0: i8) -> (i8) {
- // CHECK: %[[splat_res:.*]] = flow.tensor.splat %[[arg0]]
- %0 = tensor.from_elements %arg0 : tensor<1xi8>
- // CHECK: flow.tensor.load %[[splat_res]]
- %1 = flow.tensor.load %0 : tensor<1xi8>
- return %1 : i8
-}
-
-// -----
-// CHECK: func @tensor.from_elements__not_convertible(%[[arg0:.*]]: i8)
-func @tensor.from_elements__not_convertible(%arg0: i8) -> (i8) {
- // CHECK: %[[c0:.*]] = constant 0
- %c0 = constant 0 : index
- // CHECK: %[[res:.*]] = tensor.from_elements %[[arg0]], %[[arg0]] : tensor<2xi8>
- %0 = tensor.from_elements %arg0, %arg0 : tensor<2xi8>
- // CHECK: flow.tensor.load %[[res]][%[[c0]]]
- %1 = flow.tensor.load %0[%c0] : tensor<2xi8>
- return %1 : i8
-}
diff --git a/iree/compiler/Translation/BUILD b/iree/compiler/Translation/BUILD
index 044807c..1e2150d 100644
--- a/iree/compiler/Translation/BUILD
+++ b/iree/compiler/Translation/BUILD
@@ -29,7 +29,6 @@
"//iree/compiler/Dialect/VM/Conversion/StandardToVM",
"//iree/compiler/Dialect/VM/Target/Bytecode",
"//iree/compiler/Dialect/VM/Transforms",
- "//iree/compiler/InputConversion/Common",
"//iree/compiler/InputConversion/MHLO",
"//iree/compiler/InputConversion/TOSA",
"//iree/compiler/Utils",
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index 6b89909..4024b00 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -13,7 +13,6 @@
#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/Target/Bytecode/TranslationFlags.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
-#include "iree/compiler/InputConversion/Common/Passes.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "iree/compiler/InputConversion/TOSA/Passes.h"
#include "iree/compiler/Utils/TracingUtils.h"
@@ -180,7 +179,6 @@
break;
}
- buildCommonInputConversionPassPipeline(passManager);
IREE::Flow::buildFlowTransformPassPipeline(passManager);
IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index 2dd099d..0946c8e 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -35,6 +35,7 @@
"dynamic_torch_index_select_vector.mlir",
"globals.mlir",
"scalar.mlir",
+ "tensor_cast.mlir",
"trace_dispatch_tensors.mlir",
"unused_args.mlir",
],
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index b15bafa..3e6bc37 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -23,6 +23,7 @@
"dynamic_torch_index_select_vector.mlir"
"globals.mlir"
"scalar.mlir"
+ "tensor_cast.mlir"
"trace_dispatch_tensors.mlir"
"unused_args.mlir"
DATA
diff --git a/iree/test/e2e/regression/tensor_cast.mlir b/iree/test/e2e/regression/tensor_cast.mlir
new file mode 100644
index 0000000..65845eb
--- /dev/null
+++ b/iree/test/e2e/regression/tensor_cast.mlir
@@ -0,0 +1,10 @@
+// RUN: iree-run-mlir -iree-hal-target-backends=vmvx %s | IreeFileCheck %s
+// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=dylib-llvm-aot %s | IreeFileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv %s | IreeFileCheck %s)
+
+func @tensor_cast() -> tensor<2x?xf32> {
+ %input = iree.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+ %result = tensor.cast %input : tensor<2x3xf32> to tensor<2x?xf32>
+ return %result : tensor<2x?xf32>
+}
+// CHECK: 2x3xf32=[1 2 3][4 5 6]