Merge pull request #6114 from rsuderman:main-to-google
PiperOrigin-RevId: 377579548
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 0920faf..0e47f9b 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,15 +4,15 @@
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest
88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
-528af9c01dc4080331184ece3ced3be6beaad47f third_party/llvm-bazel
-c89dff5855bb32d47751cce087537c2b12a90f1b third_party/llvm-project
+12a96e87d2eb377b7884aae56773bca50f12e7c7 third_party/llvm-bazel
+b109172d993edacd9853a8bbb8128a94da014399 third_party/llvm-project
108a78da82049553b41c7a0f5987c67d5006af8d third_party/mlir-emitc
-fe42a08fc93830f06150b93bcb764e2386ca3110 third_party/mlir-hlo
+8b3a75ea25ceca9070938ced4b5909042765ea0c third_party/mlir-hlo
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
b42009b3b9d4ca35bc703f5310eedc74f584be58 third_party/stblib
-8be2640d44c8b4012e7863c93eea3ccb14f75151 third_party/tensorflow
+1d497e6419ca142acd188632c228c11d7c074a34 third_party/tensorflow
f03b677ffa0fd96fcf859c32e79b740fac7dd59e third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
diff --git a/integrations/tensorflow/e2e/keras/applications/BUILD b/integrations/tensorflow/e2e/keras/applications/BUILD
index 9fce0f0..cd360b9 100644
--- a/integrations/tensorflow/e2e/keras/applications/BUILD
+++ b/integrations/tensorflow/e2e/keras/applications/BUILD
@@ -103,8 +103,15 @@
failing_configurations = [
# Frequently OOMs
{
- "target_backends": "tflite",
- "model": "VGG19",
+ "target_backends": [
+ "tflite",
+ "iree_llvmaot",
+ "iree_vulkan",
+ ],
+ "model": [
+ "VGG16",
+ "VGG19",
+ ],
},
],
matrix = {
diff --git a/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt b/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt
index 7760534..448e0e4 100644
--- a/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt
+++ b/integrations/tensorflow/e2e/keras/applications/CMakeLists.txt
@@ -27,7 +27,12 @@
"MobileNet;MobileNetV2;ResNet50;VGG16;VGG19"
"tf;tflite;iree_llvmaot;iree_vulkan"
FAILING_CONFIGURATIONS
+ ",,,VGG16,tflite"
",,,VGG19,tflite"
+ ",,,VGG16,iree_llvmaot"
+ ",,,VGG19,iree_llvmaot"
+ ",,,VGG16,iree_vulkan"
+ ",,,VGG19,iree_vulkan"
LABELS
"manual"
)
diff --git a/integrations/tensorflow/iree_tf_compiler/BUILD b/integrations/tensorflow/iree_tf_compiler/BUILD
index 03f6647..793693e 100644
--- a/integrations/tensorflow/iree_tf_compiler/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/BUILD
@@ -103,6 +103,7 @@
"//iree_tf_compiler/MHLO",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
diff --git a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
index b9bf77a..986b490 100644
--- a/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/MHLO/BUILD
@@ -37,9 +37,12 @@
"@iree//iree/compiler/Dialect/Flow/IR",
"@iree//iree/compiler/Dialect/Flow/Transforms",
"@iree//iree/compiler/Dialect/IREE/IR",
+ "@iree//iree/compiler/Dialect/Shape/Conversion",
+ "@iree//iree/compiler/Dialect/Shape/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFToStandard",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index a745d9d..fc63547 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -46,6 +46,8 @@
"@iree//iree/compiler/Dialect/Shape/Transforms",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
diff --git a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp
index 6615a66..f5e75a5 100644
--- a/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp
+++ b/iree/compiler/Conversion/Common/BufferAllocViewCleanUpPass.cpp
@@ -44,14 +44,15 @@
/// !flow.dispatch.tensor<readonly:864xf32>
/// %0 = flow.dispatch.tensor.load %subspan :
/// !flow.dispatch.tensor<readonly:864xf32> -> tensor<864xf32>
-struct FoldReshapeIntoInterfaceTensorLoad
- : OpRewritePattern<linalg::TensorReshapeOp> {
- using OpRewritePattern::OpRewritePattern;
+template <typename TensorReshapeOp>
+struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(linalg::TensorReshapeOp reshapeOp,
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
auto loadOp =
- reshapeOp.src().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+ reshapeOp.src()
+ .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
if (!loadOp) return failure();
// Make sure we are loading the full incoming subspan. Otherwise we cannot
@@ -61,11 +62,14 @@
return failure();
auto subspanOp =
- loadOp.source().getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+ loadOp.source()
+ .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
if (!subspanOp) return failure();
auto newSubspanType = IREE::Flow::DispatchTensorType::get(
- subspanOp.getType().cast<IREE::Flow::DispatchTensorType>().getAccess(),
+ subspanOp.getType()
+ .template cast<IREE::Flow::DispatchTensorType>()
+ .getAccess(),
reshapeOp.getResultType());
Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
@@ -101,8 +105,10 @@
: public PassWrapper<BufferAllocViewCleanUpPass, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());
- patterns.insert<FoldReshapeIntoInterfaceTensorLoad, RemoveDeadMemAllocs>(
- &getContext());
+ patterns.insert<
+ FoldReshapeIntoInterfaceTensorLoad<linalg::TensorCollapseShapeOp>,
+ FoldReshapeIntoInterfaceTensorLoad<linalg::TensorExpandShapeOp>,
+ RemoveDeadMemAllocs>(&getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
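Context for the templating above: upstream MLIR split the old bidirectional `linalg.tensor_reshape` into `linalg.tensor_collapse_shape` and `linalg.tensor_expand_shape`, so the fold pattern now has to be instantiated once per op. A minimal sketch of the two new forms, reusing the shapes from the canonicalize_interface_load_store.mlir test further down (the function name is illustrative, not part of this patch):

    func @reshape_split_sketch(%t: tensor<3x3x1x96xf32>) -> tensor<3x3x96xf32> {
      // Collapse folds dimensions together; expand splits them back apart.
      %collapsed = linalg.tensor_collapse_shape %t [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
      %expanded = linalg.tensor_expand_shape %collapsed [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
      return %expanded : tensor<3x3x96xf32>
    }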
diff --git a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
index d0766db..ffbd37d 100644
--- a/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -85,9 +85,14 @@
if (!definingOp) return false;
return TypeSwitch<Operation *, bool>(definingOp)
.Case<ConstantOp>([&](ConstantOp constantOp) { return true; })
- .Case<linalg::TensorReshapeOp>([&](linalg::TensorReshapeOp reshapeOp) {
- return isFromReadOnlyTensor(reshapeOp.src());
- })
+ .Case<linalg::TensorExpandShapeOp>(
+ [&](linalg::TensorExpandShapeOp reshapeOp) {
+ return isFromReadOnlyTensor(reshapeOp.src());
+ })
+ .Case<linalg::TensorCollapseShapeOp>(
+ [&](linalg::TensorCollapseShapeOp reshapeOp) {
+ return isFromReadOnlyTensor(reshapeOp.src());
+ })
.Case<SubTensorOp>([&](SubTensorOp subTensorOp) {
return isFromReadOnlyTensor(subTensorOp.source());
})
@@ -108,7 +113,10 @@
// TODO(ravishankarm): Maybe this is too aggressive, might have to switch this
// to have a white-list instead of blacklist.
for (Operation *user : op->getUsers()) {
- if (isa<IREE::Flow::DispatchTensorStoreOp, linalg::TensorReshapeOp>(user))
+ if (isa<IREE::Flow::DispatchTensorStoreOp, linalg::TensorCollapseShapeOp>(
+ user) ||
+ isa<IREE::Flow::DispatchTensorStoreOp, linalg::TensorExpandShapeOp>(
+ user))
return false;
}
return true;
@@ -249,7 +257,8 @@
storeOp.getMixedStrides().empty())) {
SmallVector<Value> mappedTensors = plan.getTensorsMappedToSameSet(value);
for (auto v : mappedTensors) {
- if (v.getDefiningOp<linalg::TensorReshapeOp>()) return false;
+ if (v.getDefiningOp<linalg::TensorCollapseShapeOp>()) return false;
+ if (v.getDefiningOp<linalg::TensorExpandShapeOp>()) return false;
}
}
@@ -482,8 +491,13 @@
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
return analyseLinalgOps(linalgOp, plan);
})
- .Case<linalg::TensorReshapeOp>(
- [&](linalg::TensorReshapeOp tensorReshapeOp) {
+ .Case<linalg::TensorCollapseShapeOp>(
+ [&](linalg::TensorCollapseShapeOp tensorReshapeOp) {
+ return analyseSingleOperandResultOp(
+ tensorReshapeOp.src(), tensorReshapeOp.result(), plan);
+ })
+ .Case<linalg::TensorExpandShapeOp>(
+ [&](linalg::TensorExpandShapeOp tensorReshapeOp) {
return analyseSingleOperandResultOp(
tensorReshapeOp.src(), tensorReshapeOp.result(), plan);
})
@@ -681,16 +695,31 @@
return subview;
}
-/// Gets the reverse of a `linalg.tensor_reshape` op to get a memref type that
-/// can be used for in-place computation of the result of a disaptch region.
-static Value getReverseOfReshapeOp(OpBuilder &b,
- linalg::TensorReshapeOp reshapeOp,
- Value resultBuffer) {
+/// Gets the reverse of a `linalg.tensor_expand_shape` op to get a memref type
+/// that can be used for in-place computation of the result of a dispatch
+/// region.
+static Value getReverseOfExpandShapeOp(OpBuilder &b,
+ linalg::TensorExpandShapeOp expandOp,
+ Value resultBuffer) {
auto memrefType = getMemrefTypeForTensor(
- reshapeOp.getSrcType(), {},
+ expandOp.getSrcType(), {},
resultBuffer.getType().cast<MemRefType>().getMemorySpaceAsInt());
- return b.create<linalg::ReshapeOp>(reshapeOp.getLoc(), memrefType,
- resultBuffer, reshapeOp.reassociation());
+ return b.create<linalg::CollapseShapeOp>(
+ expandOp.getLoc(), memrefType, resultBuffer, expandOp.reassociation());
+}
+
+/// Gets the reverse of a `linalg.tensor_collapse_shape` op to get a memref type
+/// that can be used for in-place computation of the result of a dispatch
+/// region.
+static Value getReverseOfCollapseShapeOp(
+ OpBuilder &b, linalg::TensorCollapseShapeOp collapseOp,
+ Value resultBuffer) {
+ auto memrefType = getMemrefTypeForTensor(
+ collapseOp.getSrcType(), {},
+ resultBuffer.getType().cast<MemRefType>().getMemorySpaceAsInt());
+ return b.create<linalg::ExpandShapeOp>(collapseOp.getLoc(), memrefType,
+ resultBuffer,
+ collapseOp.reassociation());
}
/// Gets the reverse of a `tensor.cast` op to get a memref type that
@@ -763,9 +792,14 @@
.Case<scf::ForOp, linalg::LinalgOp, SubTensorInsertOp,
vector::TransferWriteOp>(
[&](auto op) { return resultBuffer; })
- .Case<linalg::TensorReshapeOp>(
- [&](linalg::TensorReshapeOp reshapeOp) {
- return getReverseOfReshapeOp(b, reshapeOp, resultBuffer);
+ .Case<linalg::TensorExpandShapeOp>(
+ [&](linalg::TensorExpandShapeOp expandOp) {
+ return getReverseOfExpandShapeOp(b, expandOp, resultBuffer);
+ })
+ .Case<linalg::TensorCollapseShapeOp>(
+ [&](linalg::TensorCollapseShapeOp collapseOp) {
+ return getReverseOfCollapseShapeOp(b, collapseOp,
+ resultBuffer);
})
.Case<tensor::CastOp>([&](tensor::CastOp castOp) {
return getReverseOfCastOp(b, castOp, resultBuffer);
@@ -807,10 +841,10 @@
loadOp.getMixedSizes(), loadOp.getMixedStrides());
}
-/// Converts a `linalg.tensor_reshape` operation to a `linalg.reshape`
+/// Converts a `linalg.tensor_expand_shape` operation to a `linalg.expand_shape`
/// operation with the result aliasing the buffer for the operand.
static Value getAliasingBufferForResult(OpBuilder &b,
- linalg::TensorReshapeOp op,
+ linalg::TensorExpandShapeOp op,
BlockAndValueMapping &bvm) {
Location loc = op.getLoc();
Value srcTensor = op.src();
@@ -821,7 +855,27 @@
MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
auto reshapeResultType = getMemrefTypeForTensor(
resultTensorType, {}, inputBufferType.getMemorySpaceAsInt());
- Value bufferReshape = b.create<linalg::ReshapeOp>(
+ Value bufferReshape = b.create<linalg::ExpandShapeOp>(
+ loc, reshapeResultType, inputBuffer, op.reassociation());
+ return bufferReshape;
+}
+
+/// Converts a `linalg.tensor_collapse_shape` operation to a
+/// `linalg.collapse_shape` operation with the result aliasing the buffer for
+/// the operand.
+static Value getAliasingBufferForResult(OpBuilder &b,
+ linalg::TensorCollapseShapeOp op,
+ BlockAndValueMapping &bvm) {
+ Location loc = op.getLoc();
+ Value srcTensor = op.src();
+ RankedTensorType resultTensorType = op.getResultType();
+ Value inputBuffer = bvm.lookup(srcTensor);
+
+ // Create the reshape op.
+ MemRefType inputBufferType = inputBuffer.getType().cast<MemRefType>();
+ auto reshapeResultType = getMemrefTypeForTensor(
+ resultTensorType, {}, inputBufferType.getMemorySpaceAsInt());
+ Value bufferReshape = b.create<linalg::CollapseShapeOp>(
loc, reshapeResultType, inputBuffer, op.reassociation());
return bufferReshape;
}
@@ -866,7 +920,12 @@
static SmallVector<Value, 4> getAliasingBuffersForResults(
OpBuilder &b, Operation *op, BlockAndValueMapping &bvm) {
return TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
- .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorReshapeOp,
+ .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorCollapseShapeOp,
+ SubTensorOp, tensor::CastOp>(
+ [&](auto singleResultOp) -> SmallVector<Value, 4> {
+ return {getAliasingBufferForResult(b, singleResultOp, bvm)};
+ })
+ .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorExpandShapeOp,
SubTensorOp, tensor::CastOp>(
[&](auto singleResultOp) -> SmallVector<Value, 4> {
return {getAliasingBufferForResult(b, singleResultOp, bvm)};
@@ -1301,7 +1360,21 @@
}
return convertScfForOp(b, forOp, bvm, plan);
})
- .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorReshapeOp,
+ .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorCollapseShapeOp,
+ SubTensorOp, tensor::CastOp>([&](auto aliasingOp) {
+ auto aliasingBuffers =
+ getAliasingBuffersForResults(b, aliasingOp, bvm);
+ if (failed(getOrAllocateResultBuffers(
+ b, aliasingOp, aliasingOp->getOperand(0), aliasingBuffers,
+ bvm, plan, allocationFn))) {
+ return failure();
+ }
+ copyFromAliasingBufferToResultBuffer(
+ b, aliasingOp->getLoc(), aliasingOp->getOperand(0),
+ aliasingOp->getResult(0), aliasingBuffers, bvm, plan);
+ return success();
+ })
+ .Case<IREE::Flow::DispatchTensorLoadOp, linalg::TensorExpandShapeOp,
SubTensorOp, tensor::CastOp>([&](auto aliasingOp) {
auto aliasingBuffers =
getAliasingBuffersForResults(b, aliasingOp, bvm);
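The `getReverseOf*` helpers introduced above exist so a reshaped dispatch result can still be computed in place: a tensor-level collapse feeding a store is mirrored by an expand-shape view of the result buffer, and vice versa. A minimal sketch of the buffer-side view, matching the 3x4 -> 12 case exercised in linalg_bufferize.mlir below (the function and buffer names are illustrative, not part of this patch):

    func @reverse_reshape_sketch(%ret0: memref<12xi32>) -> memref<3x4xi32> {
      // The tensor-level collapse (tensor<3x4xi32> into tensor<12xi32>) that feeds
      // flow.dispatch.tensor.store is reversed here: the 12xi32 result buffer is
      // viewed as 3x4xi32 so the producing linalg op can write into it directly.
      %view = linalg.expand_shape %ret0 [[0, 1]] : memref<12xi32> into memref<3x4xi32>
      return %view : memref<3x4xi32>
    }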
diff --git a/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir b/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir
index ff8df97..2870012 100644
--- a/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir
+++ b/iree/compiler/Conversion/Common/test/canonicalize_interface_load_store.mlir
@@ -10,8 +10,8 @@
%2 = hal.interface.binding.subspan @interface_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x3x96xf32>
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG]], {{.*}} : !flow.dispatch.tensor<readonly:3x3x96xf32> -> tensor<3x3x96xf32>
%3 = flow.dispatch.tensor.load %1, offsets=[], sizes =[], strides=[] : !flow.dispatch.tensor<readonly:3x3x1x96xf32> -> tensor<3x3x1x96xf32>
- %4 = linalg.tensor_reshape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
- %5 = linalg.tensor_reshape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
+ %4 = linalg.tensor_collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
+ %5 = linalg.tensor_expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
// CHECK: flow.dispatch.tensor.store %[[LOAD]], {{.*}}
flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, %c1, %c1], strides = [%c1, %c1, %c1] : tensor<3x3x96xf32> -> !flow.dispatch.tensor<writeonly:3x3x96xf32>
return
@@ -34,9 +34,10 @@
%1 = hal.interface.binding.subspan @interface_io::@arg0[%c0] : !flow.dispatch.tensor<readonly:6x3x1x96xf32>
%2 = hal.interface.binding.subspan @interface_io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x3x96xf32>
%3 = flow.dispatch.tensor.load %1, offsets = [%c3, %c0, %c0, %c0], sizes = [%c3, %c3, %c1, %c96], strides = [%c1, %c1, %c1, %c1] : !flow.dispatch.tensor<readonly:6x3x1x96xf32> -> tensor<3x3x1x96xf32>
- // CHECK-COUNT-2: linalg.tensor_reshape
- %4 = linalg.tensor_reshape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
- %5 = linalg.tensor_reshape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
+ // CHECK: linalg.tensor_collapse_shape
+ // CHECK: linalg.tensor_expand_shape
+ %4 = linalg.tensor_collapse_shape %3 [[0, 1, 2, 3]] : tensor<3x3x1x96xf32> into tensor<864xf32>
+ %5 = linalg.tensor_expand_shape %4 [[0, 1, 2]] : tensor<864xf32> into tensor<3x3x96xf32>
flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, %c1, %c1], strides = [%c1, %c1, %c1] : tensor<3x3x96xf32> -> !flow.dispatch.tensor<writeonly:3x3x96xf32>
return
}
diff --git a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
index 0d88745..0ba2b4f 100644
--- a/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Conversion/Common/test/linalg_bufferize.mlir
@@ -644,7 +644,7 @@
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
- %3 = linalg.tensor_reshape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+ %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
flow.dispatch.tensor.store %3, %1, offsets = [], sizes = [], strides = [] : tensor<3x4xi32> -> !flow.dispatch.tensor<writeonly:3x4xi32>
return
}
@@ -655,7 +655,7 @@
// CHECK: func @reshape_simple()
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0
-// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.+]] = linalg.expand_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK: linalg.copy(%[[RESHAPE]], %[[RET0]])
// -----
@@ -669,7 +669,7 @@
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:12xi32>
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
- %3 = linalg.tensor_reshape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+ %3 = linalg.tensor_expand_shape %2 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
%4 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
%5 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -690,7 +690,7 @@
// CHECK: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<12xi32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<3x4xi32>
-// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.+]] = linalg.expand_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : memref<3x4xi32>)
// CHECK-SAME: outs(%[[RET0]] : memref<3x4xi32>)
@@ -707,7 +707,7 @@
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
%2 = hal.interface.binding.subspan @io::@ret1[%c0] : !flow.dispatch.tensor<writeonly:3x4xi32>
%3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:12xi32> -> tensor<12xi32>
- %4 = linalg.tensor_reshape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
+ %4 = linalg.tensor_expand_shape %3 [[0, 1]] : tensor<12xi32> into tensor<3x4xi32>
%5 = linalg.init_tensor [3, 4] : tensor<3x4xi32>
%6 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -731,7 +731,7 @@
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<12xi32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<3x4xi32>
// CHECK-DAG: %[[RET1:.+]] = hal.interface.binding.subspan @io::@ret1[%[[C0]]] : memref<3x4xi32>
-// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.+]] = linalg.expand_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : memref<3x4xi32>)
// CHECK-SAME: outs(%[[RET0]] : memref<3x4xi32>)
@@ -757,7 +757,7 @@
%5 = addi %arg0, %arg0 : i32
linalg.yield %5 : i32
} -> tensor<3x4xi32>
- %5 = linalg.tensor_reshape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
+ %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<3x4xi32> into tensor<12xi32>
flow.dispatch.tensor.store %5, %1, offsets = [], sizes = [], strides = [] : tensor<12xi32> -> !flow.dispatch.tensor<writeonly:12xi32>
return
}
@@ -769,7 +769,7 @@
// CHECK: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]] : memref<3x4xi32>
// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]] : memref<12xi32>
-// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[RET0]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.+]] = linalg.expand_shape %[[RET0]] {{\[}}[0, 1]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : memref<3x4xi32>)
// CHECK-SAME: outs(%[[RESHAPE]] : memref<3x4xi32>)
@@ -786,7 +786,7 @@
%1 = hal.interface.binding.subspan @io::@arg1[%c0] : !flow.dispatch.tensor<readonly:2x3xf32>
%2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:1x3xf32>
%3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:1x1x2xf32> -> tensor<1x1x2xf32>
- %4 = linalg.tensor_reshape %3 [[0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
+ %4 = linalg.tensor_collapse_shape %3 [[0, 1], [2]] : tensor<1x1x2xf32> into tensor<1x2xf32>
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
@@ -819,7 +819,7 @@
// CHECK-LABEL: func @dot_general_lowering()
// CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan @io::@arg0
// CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan @io::@arg1
-// CHECK-DAG: %[[RESHAPE_LHS:.+]] = linalg.reshape %[[LHS]]
+// CHECK-DAG: %[[RESHAPE_LHS:.+]] = linalg.collapse_shape %[[LHS]]
// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0
// CHECK: scf.for %[[IV0:.+]] = {{.+}} {
// CHECK: scf.for %[[IV1:.+]] = {{.+}} {
@@ -1063,7 +1063,7 @@
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : !flow.dispatch.tensor<readonly:1x5x3x1xf32>
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:5x5xf32>
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:1x5x3x1xf32> -> tensor<1x5x3x1xf32>
- %3 = linalg.tensor_reshape %2 [[0, 1], [2, 3]] : tensor<1x5x3x1xf32> into tensor<5x3xf32>
+ %3 = linalg.tensor_collapse_shape %2 [[0, 1], [2, 3]] : tensor<1x5x3x1xf32> into tensor<5x3xf32>
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
@@ -1097,7 +1097,7 @@
// CHECK-DAG: %[[RHS:.+]] = memref.buffer_cast %[[CONSTANT]]
// CHECK-DAG: %[[LHS_INPUT:.+]] = hal.interface.binding.subspan @io::@arg0[%{{.+}}] : memref<1x5x3x1xf32>
// CHECK-DAG: %[[RETURN:.+]] = hal.interface.binding.subspan @io::@ret0[%{{.+}}] : memref<5x5xf32>
-// CHECK: %[[LHS:.+]] = linalg.reshape %[[LHS_INPUT]]
+// CHECK: %[[LHS:.+]] = linalg.collapse_shape %[[LHS_INPUT]]
// CHECK: scf.for %[[IV0:.+]] =
// CHECK: scf.for %[[IV1:.+]] =
// CHECK-DAG: %[[LHS_SUBVIEW:.+]] = memref.subview %[[LHS]][%[[IV0]], 0]
@@ -1260,7 +1260,7 @@
%0 = hal.interface.binding.subspan @io::@ro0[%c0] : !flow.dispatch.tensor<readonly:?x?xf32>
%1 = hal.interface.binding.subspan @io::@wo0[%c0] : !flow.dispatch.tensor<writeonly:?xf32>
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
- %3 = linalg.tensor_reshape %2 [[0, 1]]
+ %3 = linalg.tensor_collapse_shape %2 [[0, 1]]
: tensor<?x?xf32> into tensor<?xf32>
%4 = memref.dim %3, %c0 : tensor<?xf32>
%5 = linalg.init_tensor [%4] : tensor<?xf32>
@@ -1278,7 +1278,7 @@
// CHECK-LABEL: func @reshape_read_only
// CHECK-DAG: %[[INPUT:.+]] = hal.interface.binding.subspan @io::@ro0
// CHECK-DAG: %[[OUTPUT:.+]] = hal.interface.binding.subspan @io::@wo0
-// CHECK: %[[RESHAPE:.+]] = linalg.reshape %[[INPUT]]
+// CHECK: %[[RESHAPE:.+]] = linalg.collapse_shape %[[INPUT]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : memref<?xf32>)
// CHECK-SAME: outs(%[[OUTPUT]] : memref<?xf32>)
@@ -1800,7 +1800,7 @@
linalg.yield %cst : f32
} : tensor<2x?xf32> to tensor<4x4xf32>
%15 = linalg.init_tensor [4, 4] : tensor<4x4xf32>
- %16 = linalg.fill(%15, %cst) : tensor<4x4xf32>, f32 -> tensor<4x4xf32>
+ %16 = linalg.fill(%15, %cst) : tensor<4x4xf32>, f32 -> tensor<4x4xf32>
%17 = linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%13, %14 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%16 : tensor<4x4xf32>) -> tensor<4x4xf32>
%18 = subtensor %17[0, 0] [%7, %9] [1, 1] : tensor<4x4xf32> to tensor<?x?xf32>
flow.dispatch.tensor.store %18, %2, offsets = [%arg0, %arg1], sizes = [%11, %12], strides = [1, 1] : tensor<?x?xf32> -> !flow.dispatch.tensor<writeonly:?x?xf32>
@@ -1877,17 +1877,17 @@
%13 = flow.dispatch.tensor.load %0, offsets = [0, %9, %11, 0], sizes = [1, %10, %12, 8], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x8xf32> -> tensor<1x?x?x8xf32>
%14 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 8, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x8x32xf32> -> tensor<3x3x8x4xf32>
%15 = linalg.init_tensor [1, 16, 16, 4] : tensor<1x16x16x4xf32>
- %16 = linalg.fill(%15, %cst) {__internal_linalg_transform__ = "workgroup"} : tensor<1x16x16x4xf32>, f32 -> tensor<1x16x16x4xf32>
+ %16 = linalg.fill(%15, %cst) {__internal_linalg_transform__ = "workgroup"} : tensor<1x16x16x4xf32>, f32 -> tensor<1x16x16x4xf32>
%17 = linalg.init_tensor [1, 16, 16, 3, 3, 8] : tensor<1x16x16x3x3x8xf32>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 2 + d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<1x?x?x8xf32>) outs(%17 : tensor<1x16x16x3x3x8xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
} -> tensor<1x16x16x3x3x8xf32>
- %19 = linalg.tensor_reshape %18 [[0, 1, 2], [3, 4, 5]] : tensor<1x16x16x3x3x8xf32> into tensor<256x72xf32>
- %20 = linalg.tensor_reshape %14 [[0, 1, 2], [3]] : tensor<3x3x8x4xf32> into tensor<72x4xf32>
- %21 = linalg.tensor_reshape %16 [[0, 1, 2], [3]] : tensor<1x16x16x4xf32> into tensor<256x4xf32>
+ %19 = linalg.tensor_collapse_shape %18 [[0, 1, 2], [3, 4, 5]] : tensor<1x16x16x3x3x8xf32> into tensor<256x72xf32>
+ %20 = linalg.tensor_collapse_shape %14 [[0, 1, 2], [3]] : tensor<3x3x8x4xf32> into tensor<72x4xf32>
+ %21 = linalg.tensor_collapse_shape %16 [[0, 1, 2], [3]] : tensor<1x16x16x4xf32> into tensor<256x4xf32>
%22 = linalg.matmul ins(%19, %20 : tensor<256x72xf32>, tensor<72x4xf32>) outs(%21 : tensor<256x4xf32>) -> tensor<256x4xf32>
- %23 = linalg.tensor_reshape %22 [[0, 1, 2], [3]] : tensor<256x4xf32> into tensor<1x16x16x4xf32>
+ %23 = linalg.tensor_expand_shape %22 [[0, 1, 2], [3]] : tensor<256x4xf32> into tensor<1x16x16x4xf32>
%24 = tensor.cast %23 : tensor<1x16x16x4xf32> to tensor<1x?x?x?xf32>
flow.dispatch.tensor.store %24, %2, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %c16, %c16, %c4], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
}
@@ -1917,9 +1917,9 @@
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[ARG0_SV]]
// CHECK-SAME: outs(%[[ALLOC_ARG0]]
-// CHECK-DAG: %[[ALLOC_ARG0_RESHAPE:.+]] = linalg.reshape %[[ALLOC_ARG0]]
-// CHECK-DAG: %[[ALLOC_ARG1_RESHAPE:.+]] = linalg.reshape %[[ALLOC_ARG1]]
-// CHECK-DAG: %[[ALLOC_RET0_RESHAPE:.+]] = linalg.reshape %[[ALLOC_RET0]]
+// CHECK-DAG: %[[ALLOC_ARG0_RESHAPE:.+]] = linalg.collapse_shape %[[ALLOC_ARG0]]
+// CHECK-DAG: %[[ALLOC_ARG1_RESHAPE:.+]] = linalg.collapse_shape %[[ALLOC_ARG1]]
+// CHECK-DAG: %[[ALLOC_RET0_RESHAPE:.+]] = linalg.collapse_shape %[[ALLOC_RET0]]
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[ALLOC_ARG0_RESHAPE]], %[[ALLOC_ARG1_RESHAPE]]
// CHECK-SAME: outs(%[[ALLOC_RET0_RESHAPE]]
@@ -1948,8 +1948,8 @@
%7 = flow.dispatch.tensor.load %0, offsets = [0, %arg0], sizes = [%d0, %6], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
%9 = flow.dispatch.tensor.load %1, offsets = [0, %arg0], sizes = [%d0, %6], strides = [1, 1] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
%13 = linalg.init_tensor [%6] : tensor<?xi32>
- %14 = linalg.fill(%13, %c-2147483648_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32>
- %17 = linalg.fill(%13, %c0_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32>
+ %14 = linalg.fill(%13, %c-2147483648_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32>
+ %17 = linalg.fill(%13, %c0_i32) {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} : tensor<?xi32>, i32 -> tensor<?xi32>
%18:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%7, %9 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%14, %17 : tensor<?xi32>, tensor<?xi32>) attrs = {__internal_linalg_transform__ = "workgroup", lowering.config = {tileSizes = [[128]]}} {
^bb0(%arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
%19 = cmpi sge, %arg1, %arg3 : i32
@@ -2028,7 +2028,7 @@
%9 = affine.min #map2(%arg1)
%10 = flow.dispatch.tensor.load %1, offsets = [0, %arg1], sizes = [144, %9], strides = [1, 1] : !flow.dispatch.tensor<readonly:144x370xf32> -> tensor<144x?xf32>
%11 = linalg.init_tensor [%7, %9] : tensor<?x?xf32>
- %12 = linalg.fill(%11, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+ %12 = linalg.fill(%11, %cst) {__internal_linalg_transform__ = "workgroup", lowering.config = #config0} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
%13 = scf.for %arg2 = %c0 to %c250 step %c32 iter_args(%arg3 = %12) -> (tensor<?x?xf32>) {
%14 = scf.for %arg4 = %c0 to %c370 step %c32 iter_args(%arg5 = %arg3) -> (tensor<?x?xf32>) {
%15 = scf.for %arg6 = %c0 to %c144 step %c24 iter_args(%arg7 = %arg5) -> (tensor<?x?xf32>) {
diff --git a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
index d907171..e7a97ec 100644
--- a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
@@ -28,9 +28,10 @@
namespace mlir {
namespace iree_compiler {
-
namespace {
+using linalg::LinalgOp;
+
/// Pass to fuse linalg on tensor operations as well as fusion of hal.interface*
/// operations with linalg.tensor_reshape operation.
struct FusionOfTensorOpsPass
@@ -58,7 +59,7 @@
if (!clEnableFusionWithReductionOps) {
auto consumerOp = consumer.getOwner();
if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(consumerOp) &&
- dyn_cast<linalg::LinalgOp>(consumerOp).getNumReductionLoops()) {
+ dyn_cast<LinalgOp>(consumerOp).getNumReductionLoops()) {
return false;
}
}
@@ -77,8 +78,14 @@
// to the consumer linalg op.
linalg::ControlElementwiseOpsFusionFn foldReshapeBetweenLinalgFn =
[](const OpResult &producer, const OpOperand &consumer) {
- auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
- return reshapeOp.src().getDefiningOp<linalg::LinalgOp>() != nullptr;
+ auto collapseOp =
+ producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
+ if (collapseOp)
+ return collapseOp.src().getDefiningOp<LinalgOp>() != nullptr;
+ auto expandOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
+ if (expandOp)
+ return expandOp.src().getDefiningOp<LinalgOp>() != nullptr;
+ return false;
};
linalg::populateElementwiseOpsFusionPatterns(
fusionPatterns,
@@ -92,7 +99,9 @@
OwningRewritePatternList reshapeCanonicalizations(&getContext());
linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
reshapeCanonicalizations);
- linalg::TensorReshapeOp::getCanonicalizationPatterns(
+ linalg::TensorCollapseShapeOp::getCanonicalizationPatterns(
+ reshapeCanonicalizations, context);
+ linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
reshapeCanonicalizations, context);
(void)applyPatternsAndFoldGreedily(op->getRegions(),
std::move(reshapeCanonicalizations));
@@ -100,8 +109,10 @@
// Push the remaining reshapes down the graphs.
OwningRewritePatternList pushReshapePatterns(&getContext());
linalg::populatePushReshapeOpsPatterns(pushReshapePatterns);
- linalg::TensorReshapeOp::getCanonicalizationPatterns(pushReshapePatterns,
- context);
+ linalg::TensorCollapseShapeOp::getCanonicalizationPatterns(
+ pushReshapePatterns, context);
+ linalg::TensorExpandShapeOp::getCanonicalizationPatterns(
+ pushReshapePatterns, context);
(void)applyPatternsAndFoldGreedily(op->getRegions(),
std::move(pushReshapePatterns));
}
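The replacement `foldReshapeBetweenLinalgFn` above only allows folding a reshape into its linalg consumer when the reshape's own source is produced by a linalg op; any other producer now simply returns false. A minimal sketch of IR that passes the new check (function name illustrative, not part of this patch):

    func @fold_allowed_sketch(%init: tensor<3x4xf32>) -> tensor<12xf32> {
      %cst = constant 0.0 : f32
      // The collapse's source is itself a linalg op (the fill), so the control
      // function returns true and the reshape may be folded into a downstream
      // linalg consumer.
      %fill = linalg.fill(%init, %cst) : tensor<3x4xf32>, f32 -> tensor<3x4xf32>
      %0 = linalg.tensor_collapse_shape %fill [[0, 1]] : tensor<3x4xf32> into tensor<12xf32>
      return %0 : tensor<12xf32>
    }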
diff --git a/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp b/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp
index a2cdc1a..bb50042 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp
+++ b/iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp
@@ -62,18 +62,18 @@
Value output = convOp.getOutput(0);
auto loc = convOp.getLoc();
- Value reshapedInput = rewriter.create<linalg::TensorReshapeOp>(
+ Value reshapedInput = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedInputType, input, reassociationIndices);
- Value reshapedFilter = rewriter.create<linalg::TensorReshapeOp>(
+ Value reshapedFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedFilterType, filter, reassociationIndices);
- Value reshapedOutput = rewriter.create<linalg::TensorReshapeOp>(
+ Value reshapedOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedOutputType, output, reassociationIndices);
auto matmulResult = rewriter.create<linalg::MatmulOp>(
loc, reshapedOutputType, ArrayRef<Value>{reshapedInput, reshapedFilter},
ArrayRef<Value>{reshapedOutput});
- auto reshapedResult = rewriter.create<linalg::TensorReshapeOp>(
+ auto reshapedResult = rewriter.create<linalg::TensorExpandShapeOp>(
loc, outputShapeType, matmulResult.getResults()[0],
reassociationIndices);
diff --git a/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp b/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp
index 3ee6078..4862acc 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp
+++ b/iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp
@@ -131,14 +131,15 @@
RankedTensorType::get({outputShape[1] * outputShape[2], outputShape[3]},
outputShapeType.getElementType());
- Value reshapedImg2ColTensor = rewriter.create<linalg::TensorReshapeOp>(
- loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
- img2ColTensorReassociationIndices);
+ Value reshapedImg2ColTensor =
+ rewriter.create<linalg::TensorCollapseShapeOp>(
+ loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
+ img2ColTensorReassociationIndices);
- Value reshapedFilter = rewriter.create<linalg::TensorReshapeOp>(
+ Value reshapedFilter = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedFilterType, filter, filterAndOutputReassociationIndices);
- Value reshapedOutput = rewriter.create<linalg::TensorReshapeOp>(
+ Value reshapedOutput = rewriter.create<linalg::TensorCollapseShapeOp>(
loc, reshapedOutputType, output, filterAndOutputReassociationIndices);
auto matmulResult = rewriter.create<linalg::MatmulOp>(
@@ -146,7 +147,7 @@
ArrayRef<Value>{reshapedImg2ColTensor, reshapedFilter},
ArrayRef<Value>{reshapedOutput});
- auto reshapedResult = rewriter.create<linalg::TensorReshapeOp>(
+ auto reshapedResult = rewriter.create<linalg::TensorExpandShapeOp>(
loc, outputShapeType, matmulResult.getResults()[0],
filterAndOutputReassociationIndices);
diff --git a/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir b/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir
index 067c66d..dfdd0c1 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir
+++ b/iree/compiler/Conversion/LinalgToLinalg/test/conv1x1_to_matmul.mlir
@@ -12,9 +12,9 @@
// CHECK: %[[INPUT:.+]]: tensor<1x4x5x2xf32>
// CHECK: %[[FILTER:.+]]: tensor<1x1x2x7xf32>
// CHECK: %[[OTUPUT:.+]] = linalg.init_tensor [1, 4, 5, 7] : tensor<1x4x5x7xf32>
-// CHECK: %[[RESHAPED_INPUT:.+]] = linalg.tensor_reshape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
-// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_reshape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
-// CHECK: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_reshape %[[OTUPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
+// CHECK: %[[RESHAPED_INPUT:.+]] = linalg.tensor_collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x2xf32> into tensor<20x2xf32>
+// CHECK: %[[RESHAPED_FILTER:.+]] = linalg.tensor_collapse_shape %[[FILTER]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x2x7xf32> into tensor<2x7xf32>
+// CHECK: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_collapse_shape %[[OTUPUT]] {{\[}}[0, 1, 2], [3]] : tensor<1x4x5x7xf32> into tensor<20x7xf32>
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]] : tensor<20x2xf32>, tensor<2x7xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<20x7xf32>)
-// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
+// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<20x7xf32> into tensor<1x4x5x7xf32>
// CHECK: return %[[RESULT]]
diff --git a/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir b/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir
index 7100ec7..e406535 100644
--- a/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir
+++ b/iree/compiler/Conversion/LinalgToLinalg/test/conv2d_to_img2col.mlir
@@ -20,14 +20,14 @@
// CHECK-SAME: #[[MAP1]]
// CHECK: ^bb0(%[[IN_DATA:.+]]: f32, %[[OUT_DATA:.+]]: f32)
// CHECK: linalg.yield %[[IN_DATA]] : f32
-// CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = linalg.tensor_reshape %[[COL_TENSOR]]
+// CHECK-DAG: %[[RESHAPED_INIT_COL_TENSOR:.+]] = linalg.tensor_collapse_shape %[[COL_TENSOR]]
// CHECK-SAME: [0, 1, 2], [3, 4, 5]
// CHECK-SAME: tensor<1x14x14x3x3x4xf32> into tensor<196x36xf32>
-// CHECK-DAG: %[[RESHAPED_FILTER:.+]] = linalg.tensor_reshape %[[FILTER]]
+// CHECK-DAG: %[[RESHAPED_FILTER:.+]] = linalg.tensor_collapse_shape %[[FILTER]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK-SAME: tensor<3x3x4x16xf32> into tensor<36x16xf32>
-// CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_reshape %[[OUTPUT]]
+// CHECK-DAG: %[[RESHAPED_OUTPUT:.+]] = linalg.tensor_collapse_shape %[[OUTPUT]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK: %[[MATMUL_RESULT:.+]] = linalg.matmul ins(%[[RESHAPED_INIT_COL_TENSOR]], %[[RESHAPED_FILTER]] : tensor<196x36xf32>, tensor<36x16xf32>) outs(%[[RESHAPED_OUTPUT]] : tensor<196x16xf32>)
-// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[MATMUL_RESULT]] {{\[}}[0, 1, 2], [3]] : tensor<196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: return %[[RESULT]]
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index d5d1837..36c496a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -653,7 +653,7 @@
// Reshape ops are treated legal since they just change the way the underlying
// buffer is viewed. These are legalized downstream. They become no ops when
// lowering to SPIR-V since the SPIR-V code uses linearized arrays.
- target.addLegalOp<linalg::ReshapeOp>();
+ target.addLegalOp<linalg::CollapseShapeOp, linalg::ExpandShapeOp>();
// Let the rest fall through.
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
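Both memref reshape ops remain legal here for the same reason the old `linalg.reshape` did: each only reinterprets an existing buffer and becomes a no-op once the SPIR-V lowering linearizes arrays. A minimal sketch of the two view-only forms, with shapes borrowed from the tile_and_vectorize_conv.mlir test later in this patch (function name illustrative, not part of this patch):

    func @memref_reshape_sketch(%filter: memref<3x3x1x96xf32>) -> memref<3x3x96xf32> {
      // Neither op moves data; each yields a differently shaped view of %filter.
      %flat = linalg.collapse_shape %filter [[0, 1, 2, 3]] : memref<3x3x1x96xf32> into memref<864xf32>
      %view = linalg.expand_shape %flat [[0, 1, 2]] : memref<864xf32> into memref<3x3x96xf32>
      return %view : memref<3x3x96xf32>
    }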
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
index b351349..130e1c6 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp
@@ -313,9 +313,10 @@
/// - tensor_to_memref can become a no-op since tensors are lowered to
/// !spv.array.
/// - unrealized_conversion_cast with the same source and target type.
- patterns
- .insert<FoldAsNoOp<linalg::ReshapeOp>, FoldAsNoOp<memref::BufferCastOp>,
- RemoveIdentityConversionCast>(typeConverter, context);
+ patterns.insert<
+ FoldAsNoOp<linalg::CollapseShapeOp>, FoldAsNoOp<linalg::ExpandShapeOp>,
+ FoldAsNoOp<memref::BufferCastOp>, RemoveIdentityConversionCast>(
+ typeConverter, context);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir
index 6d21f9c..d660072 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir
@@ -109,8 +109,8 @@
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<1x113x113x96xf32>
%1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<3x3x1x96xf32>
%2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<1x56x56x96xf32>
- %3 = linalg.reshape %1 [[0, 1, 2, 3]] : memref<3x3x1x96xf32> into memref<864xf32>
- %4 = linalg.reshape %3 [[0, 1, 2]] : memref<864xf32> into memref<3x3x96xf32>
+ %3 = linalg.collapse_shape %1 [[0, 1, 2, 3]] : memref<3x3x1x96xf32> into memref<864xf32>
+ %4 = linalg.expand_shape %3 [[0, 1, 2]] : memref<864xf32> into memref<3x3x96xf32>
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%workgroup_size_z = hal.interface.workgroup.size[2] : index
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
index 3eb7143..6654559 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
@@ -130,13 +130,14 @@
/// Converts linalg.tensor_reshape operations into flow.tensor.reshape
/// operations.
+template <typename TensorReshapeOp>
struct LinalgTensorReshapeToFlowTensorReshape
- : public OpRewritePattern<linalg::TensorReshapeOp> {
- using OpRewritePattern<linalg::TensorReshapeOp>::OpRewritePattern;
+ : public OpRewritePattern<TensorReshapeOp> {
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(linalg::TensorReshapeOp reshapeOp,
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
- if (reshapeOp->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
+ if (reshapeOp->template getParentOfType<Flow::DispatchWorkgroupsOp>()) {
return failure();
}
SmallVector<SmallVector<Value>> outputShape;
@@ -263,9 +264,10 @@
MLIRContext *context = funcOp->getContext();
context->allowUnregisteredDialects(true);
RewritePatternSet patterns(&getContext());
- patterns.insert<LinalgTensorReshapeToFlowTensorReshape,
- SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(
- context);
+ patterns.insert<
+ LinalgTensorReshapeToFlowTensorReshape<linalg::TensorCollapseShapeOp>,
+ LinalgTensorReshapeToFlowTensorReshape<linalg::TensorExpandShapeOp>,
+ SubTensorInsertToTensorUpdate, SubTensorToTensorSlice>(context);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index dea0b8b..f1a8875 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -261,7 +261,9 @@
}
static bool isAlwaysFusedIntoDispatchOp(Operation *op) {
- return isDispatchableOp(op) && isa<linalg::TensorReshapeOp, SubTensorOp>(op);
+ return isDispatchableOp(op) &&
+ (isa<linalg::TensorCollapseShapeOp, SubTensorOp>(op) ||
+ isa<linalg::TensorExpandShapeOp, SubTensorOp>(op));
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
index 03b5ec5..3360b10 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
@@ -117,9 +117,9 @@
func @tensor_reshape(%arg0 : tensor<?x4x?x5x?x6xf32>, %arg1 : tensor<20x?x40xf32>)
-> (tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>)
{
- %0 = linalg.tensor_reshape %arg0 [[0, 1, 2], [3], [4, 5]]
+ %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4, 5]]
: tensor<?x4x?x5x?x6xf32> into tensor<?x5x?xf32>
- %1 = linalg.tensor_reshape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
+ %1 = linalg.tensor_expand_shape %arg1 [[0, 1], [2, 3], [4, 5, 6]]
: tensor<20x?x40xf32> into tensor<5x4x?x4x2x4x5xf32>
return %0, %1 : tensor<?x5x?xf32>, tensor<5x4x?x4x2x4x5xf32>
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 6266c07..ee0cdcd 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -212,7 +212,7 @@
func @fuse_reshape_op(%arg0: tensor<?x?xf32>) -> tensor<?xf32>
{
- %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+ %0 = linalg.tensor_collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
@@ -228,7 +228,7 @@
// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: !flow.dispatch.tensor<writeonly:?xf32>
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG1]], {{.*}}
-// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[LOAD]] {{\[}}[0, 1]]
+// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[LOAD]] {{\[}}[0, 1]]
// CHECK: flow.dispatch.tensor.store %[[RESHAPE]], %[[ARG2]], {{.*}}
// -----
@@ -284,7 +284,7 @@
%cst = constant 0.0 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
- %0 = linalg.tensor_reshape %lhs [[0, 1]]
+ %0 = linalg.tensor_expand_shape %lhs [[0, 1]]
: tensor<?xf32> into tensor<?x4xf32>
%m = memref.dim %0, %c0 : tensor<?x4xf32>
%n1 = memref.dim %rhs1, %c1 : tensor<4x?xf32>
@@ -554,14 +554,14 @@
func @inline_dag_1(
%arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
%arg3 : index) -> tensor<?xf32> {
- %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+ %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
%1 = subtensor %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
- %2 = linalg.tensor_reshape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
- %3 = linalg.tensor_reshape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %2 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %3 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
%4 = subtensor %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
- %5 = linalg.tensor_reshape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
%6 = subtensor %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
- %7 = linalg.tensor_reshape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
%8 = linalg.init_tensor [%arg3] : tensor<?xf32>
%9 = linalg.generic {
indexing_maps = [#map0, #map1, #map1, #map1, #map1, #map1],
@@ -591,14 +591,14 @@
// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-// CHECK: %[[OP1:.+]] = linalg.tensor_reshape %[[LEAF2]]
-// CHECK: %[[OP2:.+]] = linalg.tensor_reshape %[[LEAF1]]
+// CHECK: %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
+// CHECK: %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
// CHECK: %[[OP3:.+]] = subtensor %[[OP1]][0, 0]
// CHECK: %[[OP4:.+]] = subtensor %[[OP1]][0, 10]
// CHECK: %[[OP5:.+]] = subtensor %[[OP1]][0, 20]
-// CHECK: %[[OP6:.+]] = linalg.tensor_reshape %[[OP3]]
-// CHECK: %[[OP7:.+]] = linalg.tensor_reshape %[[OP4]]
-// CHECK: %[[OP8:.+]] = linalg.tensor_reshape %[[OP5]]
+// CHECK: %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
+// CHECK: %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
+// CHECK: %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
// -----
@@ -607,16 +607,16 @@
func @inline_dag_2(
%arg0: tensor<?xf32>, %arg1 : tensor<1x?xf32>, %arg2 : tensor<i32>,
%arg3 : index) -> tensor<?xf32> {
- %0 = linalg.tensor_reshape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+ %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
%1 = subtensor %0[0, 20] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
- %2 = linalg.tensor_reshape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %2 = linalg.tensor_collapse_shape %arg1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
br ^bb1
^bb1:
- %3 = linalg.tensor_reshape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %3 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
%4 = subtensor %0[0, 10] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
- %5 = linalg.tensor_reshape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %5 = linalg.tensor_collapse_shape %4 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
%6 = subtensor %0[0, 0] [1, %arg3] [1, 1] : tensor<1x?xf32> to tensor<1x?xf32>
- %7 = linalg.tensor_reshape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+ %7 = linalg.tensor_collapse_shape %6 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
%8 = linalg.init_tensor [%arg3] : tensor<?xf32>
%9 = linalg.generic {
indexing_maps = [#map0, #map1, #map1, #map1, #map1, #map1],
@@ -646,14 +646,14 @@
// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-// CHECK: %[[OP1:.+]] = linalg.tensor_reshape %[[LEAF2]]
-// CHECK: %[[OP2:.+]] = linalg.tensor_reshape %[[LEAF1]]
+// CHECK: %[[OP1:.+]] = linalg.tensor_expand_shape %[[LEAF2]]
+// CHECK: %[[OP2:.+]] = linalg.tensor_collapse_shape %[[LEAF1]]
// CHECK: %[[OP3:.+]] = subtensor %[[OP1]][0, 0]
// CHECK: %[[OP4:.+]] = subtensor %[[OP1]][0, 10]
// CHECK: %[[OP5:.+]] = subtensor %[[OP1]][0, 20]
-// CHECK: %[[OP6:.+]] = linalg.tensor_reshape %[[OP3]]
-// CHECK: %[[OP7:.+]] = linalg.tensor_reshape %[[OP4]]
-// CHECK: %[[OP8:.+]] = linalg.tensor_reshape %[[OP5]]
+// CHECK: %[[OP6:.+]] = linalg.tensor_collapse_shape %[[OP3]]
+// CHECK: %[[OP7:.+]] = linalg.tensor_collapse_shape %[[OP4]]
+// CHECK: %[[OP8:.+]] = linalg.tensor_collapse_shape %[[OP5]]
// -----
diff --git a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
index 5ec9871..96e8fe1 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
+++ b/iree/compiler/Dialect/Modules/VMVX/Conversion/StandardToVMVX/ConvertStandardToVMVX.cpp
@@ -68,7 +68,8 @@
OwningRewritePatternList &patterns,
TypeConverter &typeConverter) {
// We type/shape erase memrefs as we lower so there is no need for reshapes.
- patterns.insert<FoldAsNoOp<linalg::ReshapeOp>>(typeConverter, context);
+ patterns.insert<FoldAsNoOp<linalg::CollapseShapeOp>>(typeConverter, context);
+ patterns.insert<FoldAsNoOp<linalg::ExpandShapeOp>>(typeConverter, context);
patterns.insert<RemoveIdentityConversionCast>(typeConverter, context);
}
diff --git a/third_party/llvm-bazel b/third_party/llvm-bazel
index 528af9c..12a96e8 160000
--- a/third_party/llvm-bazel
+++ b/third_party/llvm-bazel
@@ -1 +1 @@
-Subproject commit 528af9c01dc4080331184ece3ced3be6beaad47f
+Subproject commit 12a96e87d2eb377b7884aae56773bca50f12e7c7
diff --git a/third_party/llvm-project b/third_party/llvm-project
index c89dff5..b109172 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit c89dff5855bb32d47751cce087537c2b12a90f1b
+Subproject commit b109172d993edacd9853a8bbb8128a94da014399
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index fe42a08..8b3a75e 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit fe42a08fc93830f06150b93bcb764e2386ca3110
+Subproject commit 8b3a75ea25ceca9070938ced4b5909042765ea0c
diff --git a/third_party/tensorflow b/third_party/tensorflow
index 8be2640..1d497e6 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 8be2640d44c8b4012e7863c93eea3ccb14f75151
+Subproject commit 1d497e6419ca142acd188632c228c11d7c074a34