Integrate LLVM at llvm/llvm-project@ce211c505b82
Updates LLVM usage to match
[ce211c505b82](https://github.com/llvm/llvm-project/commit/ce211c505b82)
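
The bulk of this integrate adapts IREE to upstream MLIR's split of the `dim` op: dimension queries on tensor values now use the new `tensor.dim` op, while `memref.dim` is reserved for memrefs. A minimal sketch of the rename (illustrative IR, not lines from this diff):

```mlir
// Before this integrate, memref.dim also accepted tensor operands:
%d0 = memref.dim %t, %c0 : tensor<?x?xf32>
// After it, tensor-typed dimension queries use the dedicated tensor.dim op:
%d0 = tensor.dim %t, %c0 : tensor<?x?xf32>
```

The C++ changes below follow the same pattern (`builder.createOrFold<memref::DimOp>` on a tensor value becomes `builder.createOrFold<tensor::DimOp>`), and the Linalg bufferizer now rewrites `tensor.dim` to `memref.dim` once its source tensor has been mapped to a buffer.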
PiperOrigin-RevId: 382707665
diff --git a/SUBMODULE_VERSIONS.txt b/SUBMODULE_VERSIONS.txt
index 6f8700c..6bf302e 100644
--- a/SUBMODULE_VERSIONS.txt
+++ b/SUBMODULE_VERSIONS.txt
@@ -4,7 +4,7 @@
aa533abfd4232b01f9e57041d70114d5a77e6de0 third_party/googletest
88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
acd6f6f014c25e46363e718381e0b35205df2d83 third_party/libyaml
-5b8ddd2ccceb8de04bd020f286bc3ca38638ecb1 third_party/llvm-project
+ce211c505b82e5bbb68b936968d9b54608285416 third_party/llvm-project
1a4dea1387e34538ba159a56204a6982e728e337 third_party/mlir-emitc
a41d23745eb902d7093ba6eaf4902c6ec8bf12b2 third_party/mlir-hlo
4c7697dbe973ed01ae6fbec37d186ebd05982e1f third_party/pybind11
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp
index 3537399..b2b0a03 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp
@@ -94,7 +94,7 @@
target.addLegalDialect<tensor::TensorDialect>();
target.addLegalOp<mlir::CallOp>();
target.addLegalOp<mlir::tensor::CastOp>();
- target.addLegalOp<mlir::memref::DimOp>();
+ target.addLegalOp<mlir::tensor::DimOp>();
// TODO(suderman): Enable logistic op for lowering once the op is
// supported in IREE. Also, remove the numerically unstable ConvertSigmoidOp
diff --git a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
index 436927f..bcf4343 100644
--- a/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
+++ b/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -12,9 +12,9 @@
// CHECK-NEXT: %[[ARG1_DIM0:.+]] = hal.buffer_view.dim %[[ARG1]], 0 : index
// CHECK-NEXT: %[[ARG1_TENSOR:.+]] = hal.tensor.cast %[[ARG1]] : !hal.buffer_view -> tensor<?x8x8x3xf32>{%[[ARG1_DIM0]]}
// CHECK-NEXT: %[[RET_TENSOR:.+]]:2 = call @_dynamicEntry(%[[ARG0_TENSOR]], %[[ARG1_TENSOR]])
-// CHECK: %[[RET0_DIM0:.+]] = memref.dim %[[RET_TENSOR]]#0, %c0{{.*}} : tensor<?x8x8x3xf32>
+// CHECK: %[[RET0_DIM0:.+]] = tensor.dim %[[RET_TENSOR]]#0, %c0{{.*}} : tensor<?x8x8x3xf32>
// CHECK-NEXT: %[[RET0_VIEW:.+]] = hal.tensor.cast %[[RET_TENSOR]]#0 : tensor<?x8x8x3xf32>{%[[RET0_DIM0]]} -> !hal.buffer_view
-// CHECK: %[[RET1_DIM0:.+]] = memref.dim %[[RET_TENSOR]]#1, %c0{{.*}} : tensor<?x8x8x3xf32>
+// CHECK: %[[RET1_DIM0:.+]] = tensor.dim %[[RET_TENSOR]]#1, %c0{{.*}} : tensor<?x8x8x3xf32>
// CHECK-NEXT: %[[RET1_VIEW:.+]] = hal.tensor.cast %[[RET_TENSOR]]#1 : tensor<?x8x8x3xf32>{%[[RET1_DIM0]]} -> !hal.buffer_view
// CHECK-NEXT: return %[[RET0_VIEW]], %[[RET1_VIEW]] : !hal.buffer_view, !hal.buffer_view
// CHECK-NEXT: }
diff --git a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
index ad130d0..403e436 100644
--- a/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
+++ b/iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -427,7 +427,7 @@
static bool hasSingleRealUse(Value value) {
int numUsers = 0;
for (OpOperand &use : value.getUses()) {
- if (!isa<memref::DimOp>(use.getOwner())) {
+ if (!isa<memref::DimOp, tensor::DimOp>(use.getOwner())) {
numUsers++;
}
}
@@ -549,7 +549,8 @@
[&](tensor::InsertSliceOp subTensorInsertOp) {
return subTensorInsertOp.dest() == arg;
})
- .Case<memref::DimOp, scf::YieldOp>([&](auto op) { return true; })
+ .Case<memref::DimOp, scf::YieldOp, tensor::DimOp>(
+ [&](auto op) { return true; })
.Default([&](Operation *op) { return false; });
};
if (llvm::all_of(arg.getUses(), isDestructiveUpdateUses)) {
@@ -1616,7 +1617,7 @@
.Case<tensor::ExtractOp>([&](tensor::ExtractOp op) {
return convertTensorExtractOp(b, op, bvm);
})
- .Case<memref::DimOp, vector::TransferReadOp>([&](auto op) {
+ .Case<vector::TransferReadOp>([&](auto op) {
for (unsigned i : llvm::seq<unsigned>(0, op->getNumOperands())) {
Value operand = op->getOperand(i);
if (operand.getType().isa<RankedTensorType>()) {
@@ -1626,6 +1627,14 @@
}
return success();
})
+ .Case<tensor::DimOp>([&](tensor::DimOp dimOp) {
+ Value operand = dimOp.source();
+ Value remappedVal = bvm.lookupOrNull(operand);
+ Value newDimOp = b.create<memref::DimOp>(dimOp.getLoc(), remappedVal,
+ dimOp.index());
+ dimOp.replaceAllUsesWith(newDimOp);
+ return success();
+ })
.Case<scf::ForOp>([&](scf::ForOp forOp) {
// To canonicalize the `scf.for` tensor result/operand/yield value
// away, forward the init argument to the yield of the loop.
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index 07257b6..8bcfab8 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -970,8 +970,8 @@
%2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xi32>
%3 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
%4 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xi32> -> tensor<?x?xi32>
- %5 = memref.dim %3, %c0 : tensor<?x?xi32>
- %6 = memref.dim %3, %c1 : tensor<?x?xi32>
+ %5 = tensor.dim %3, %c0 : tensor<?x?xi32>
+ %6 = tensor.dim %3, %c1 : tensor<?x?xi32>
%7 = tensor.insert_slice %3 into %4[3, 4] [%5, %6] [1, 1] : tensor<?x?xi32> into tensor<?x?xi32>
flow.dispatch.tensor.store %7, %2, offsets = [], sizes = [], strides = [] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:?x?xi32>
return
@@ -1118,8 +1118,8 @@
%2 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<writeonly:?x?xf32>
%4 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = []: !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
%5 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?xi32> -> tensor<?xi32>
- %d0 = memref.dim %5, %c0 : tensor<?xi32>
- %d1 = memref.dim %4, %c1 : tensor<?x?xf32>
+ %d0 = tensor.dim %5, %c0 : tensor<?xi32>
+ %d1 = tensor.dim %4, %c1 : tensor<?x?xf32>
%3 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<?xi32>) outs(%3 : tensor<?x?xf32>) {
^bb0( %arg2: i32, %arg3: f32): // no predecessors
@@ -1203,8 +1203,8 @@
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
%6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
- %dim0 = memref.dim %2, %c0 : tensor<?x?xf32>
- %dim1 = memref.dim %2, %c1 : tensor<?x?xf32>
+ %dim0 = tensor.dim %2, %c0 : tensor<?x?xf32>
+ %dim1 = tensor.dim %2, %c1 : tensor<?x?xf32>
scf.for %arg0 = %5 to %dim0 step %6 {
%7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
%8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
@@ -1263,7 +1263,7 @@
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
%3 = linalg.tensor_collapse_shape %2 [[0, 1]]
: tensor<?x?xf32> into tensor<?xf32>
- %4 = memref.dim %3, %c0 : tensor<?xf32>
+ %4 = tensor.dim %3, %c0 : tensor<?xf32>
%5 = linalg.init_tensor [%4] : tensor<?xf32>
%6 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
@@ -1496,8 +1496,8 @@
%1 = hal.interface.binding.subspan @io::@ret0[%c0] : !flow.dispatch.tensor<readwrite:?x?x?xf32>
%2 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?x?xf32> -> tensor<?x?xf32>
%3 = flow.dispatch.tensor.load %1, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:?x?x?xf32> -> tensor<?x?x?xf32>
- %4 = memref.dim %3, %c1 : tensor<?x?x?xf32>
- %5 = memref.dim %3, %c2 : tensor<?x?x?xf32>
+ %4 = tensor.dim %3, %c1 : tensor<?x?x?xf32>
+ %5 = tensor.dim %3, %c2 : tensor<?x?x?xf32>
%6 = tensor.insert_slice %2 into %3[0, 0, 0] [1, %4, %5] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
flow.dispatch.tensor.store %6, %1, offsets = [], sizes = [], strides = [] : tensor<?x?x?xf32> -> !flow.dispatch.tensor<readwrite:?x?x?xf32>
return
@@ -2035,9 +2035,9 @@
%18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
%19 = affine.min #map5(%arg4)
%20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
- %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
+ %21 = tensor.dim %arg7, %c0 : tensor<?x?xf32>
%22 = affine.min #map6(%21, %arg2)
- %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
+ %23 = tensor.dim %arg7, %c1 : tensor<?x?xf32>
%24 = affine.min #map6(%23, %arg4)
%25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -2140,9 +2140,9 @@
%18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
%19 = affine.min #map5(%arg4)
%20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
- %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
+ %21 = tensor.dim %arg7, %c0 : tensor<?x?xf32>
%22 = affine.min #map6(%21, %arg2)
- %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
+ %23 = tensor.dim %arg7, %c1 : tensor<?x?xf32>
%24 = affine.min #map6(%23, %arg4)
%25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
@@ -2241,9 +2241,9 @@
%18 = tensor.extract_slice %8[%arg2, %arg6] [%16, %17] [1, 1] : tensor<?x144xf32> to tensor<?x?xf32>
%19 = affine.min #map5(%arg4)
%20 = tensor.extract_slice %10[%arg6, %arg4] [%17, %19] [1, 1] : tensor<144x?xf32> to tensor<?x?xf32>
- %21 = memref.dim %arg7, %c0 : tensor<?x?xf32>
+ %21 = tensor.dim %arg7, %c0 : tensor<?x?xf32>
%22 = affine.min #map6(%21, %arg2)
- %23 = memref.dim %arg7, %c1 : tensor<?x?xf32>
+ %23 = tensor.dim %arg7, %c1 : tensor<?x?xf32>
%24 = affine.min #map6(%23, %arg4)
%25 = tensor.extract_slice %arg7[%arg2, %arg4] [%22, %24] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%26 = linalg.matmul {__internal_linalg_transform__ = "workgroup_l1_tile", lowering.config = #config1} ins(%18, %20 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%25 : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 125528a..72ac6d4 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -402,12 +402,12 @@
// shapex.ranked_dim(flow.dispatch.shape(%x), %const)
// ``
struct ConvertDimOfDispatchInputLoadToDispatchShape
- : public OpRewritePattern<memref::DimOp> {
+ : public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(memref::DimOp op,
+ LogicalResult matchAndRewrite(tensor::DimOp op,
PatternRewriter &rewriter) const override {
- auto loadOp = op.memrefOrTensor().getDefiningOp<DispatchTensorLoadOp>();
+ auto loadOp = op.source().getDefiningOp<DispatchTensorLoadOp>();
if (!loadOp) return failure();
Optional<int64_t> constantIndex = op.getConstantIndex();
diff --git a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
index 5f8c37b..d2999fe 100644
--- a/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/dispatch_workgroups_folding.mlir
@@ -66,7 +66,7 @@
// CHECK-NEXT: "test.sink"(%[[DIM]]) : (index) -> ()
%tensor = flow.dispatch.tensor.load %arg0, offsets=[], sizes=[], strides=[] : !flow.dispatch.tensor<readonly:?xf32> -> tensor<?xf32>
%c0 = constant 0 : index
- %dim = memref.dim %tensor, %c0 : tensor<?xf32>
+ %dim = tensor.dim %tensor, %c0 : tensor<?xf32>
"test.sink"(%dim) : (index) -> ()
return
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
index 24bf115..b1f3687 100644
--- a/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/ConvertToFlowTensorOps.cpp
@@ -122,7 +122,7 @@
SmallVector<Value, 4> dynamicDims;
for (auto dim : llvm::enumerate(v.getType().cast<ShapedType>().getShape())) {
if (dim.value() != ShapedType::kDynamicSize) continue;
- dynamicDims.push_back(b.createOrFold<memref::DimOp>(loc, v, dim.index()));
+ dynamicDims.push_back(b.createOrFold<tensor::DimOp>(loc, v, dim.index()));
}
return dynamicDims;
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
index a90fbea..dc6e9d0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DestructiveUpdateUtils.cpp
@@ -69,7 +69,7 @@
writes.push_back(subTensorInsertOp);
continue;
}
- if (auto dimOp = dyn_cast<memref::DimOp>(u.getOwner())) {
+ if (auto dimOp = dyn_cast<tensor::DimOp>(u.getOwner())) {
continue;
}
LLVM_DEBUG(llvm::dbgs() << "found non-destructive update pattern use: "
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 5aaad87..92fa3e1 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -17,7 +17,6 @@
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -104,7 +103,7 @@
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
- memref::MemRefDialect, scf::SCFDialect, ShapeDialect>();
+ scf::SCFDialect, ShapeDialect, tensor::TensorDialect>();
}
DispatchLinalgOnTensorsPass() = default;
DispatchLinalgOnTensorsPass(const DispatchLinalgOnTensorsPass &pass) {}
@@ -665,7 +664,7 @@
if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) {
for (unsigned i = 0; i < rt.getRank(); ++i) {
if (!rt.isDynamicDim(i)) continue;
- auto dim = builder.createOrFold<memref::DimOp>(dispatchOp.getLoc(),
+ auto dim = builder.createOrFold<tensor::DimOp>(dispatchOp.getLoc(),
operand, i);
operandDynamicDims.push_back(dim);
}
@@ -715,7 +714,7 @@
static bool hasOnlyDimUses(Operation *op) {
return llvm::all_of(op->getUsers(), [&](Operation *user) {
- return isa<memref::DimOp>(user);
+ return isa<tensor::DimOp>(user);
});
}
@@ -808,7 +807,7 @@
rewriter.replaceOpWithIf(op, dispatchOp.getResults(),
[&](OpOperand &operand) {
- return !isa<memref::DimOp>(operand.getOwner());
+ return !isa<tensor::DimOp>(operand.getOwner());
});
return success();
}
@@ -865,7 +864,7 @@
SmallVector<Value> shape;
for (auto dim :
llvm::seq<int64_t>(0, v.getType().cast<ShapedType>().getRank())) {
- shape.push_back(rewriter.createOrFold<memref::DimOp>(loc, v, dim));
+ shape.push_back(rewriter.createOrFold<tensor::DimOp>(loc, v, dim));
}
return shape;
};
@@ -898,7 +897,7 @@
llvm::all_of(op->getUsers(), [](Operation *user) {
return isDispatchableOp(user) ||
user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ||
- isa<IREE::Flow::DispatchWorkgroupsOp, memref::DimOp>(user);
+ isa<IREE::Flow::DispatchWorkgroupsOp, tensor::DimOp>(user);
})) {
return failure();
}
@@ -966,7 +965,7 @@
rewriter.replaceOpWithIf(op, dispatchOp.getOperation()->getResults(),
[&](OpOperand &operand) {
Operation *user = operand.getOwner();
- return !isa<memref::DimOp>(user);
+ return !isa<tensor::DimOp>(user);
});
return success();
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp b/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
index 2bf45af..d6c5ec7 100644
--- a/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/PadTensorToSubTensorInsert.cpp
@@ -61,7 +61,7 @@
SmallVector<OpFoldResult> outputShape;
for (int64_t dim : llvm::seq<int64_t>(0, rank)) {
SmallVector<Value> mapValues;
- Value sourceDim = rewriter.createOrFold<memref::DimOp>(loc, source, dim);
+ Value sourceDim = rewriter.createOrFold<tensor::DimOp>(loc, source, dim);
mapValues.push_back(sourceDim);
sourceShape.push_back(sourceDim);
AffineExpr expr = rewriter.getAffineDimExpr(0);
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
index 53c4a24..01ada4d 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/convert_to_flow_tensor_ops.mlir
@@ -88,7 +88,7 @@
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C48:.+]] = constant 48 : index
-// CHECK-DAG: %[[DIM:.+]] = memref.dim %[[ARG0]], %[[C1]] : tensor<5x?x48xf32>
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<5x?x48xf32>
// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]][%[[C2]], %[[C3]], %[[C0]] for %[[C1]], %[[C2]], %[[C48]]]
// CHECK: %[[RESULT:.+]] = flow.tensor.reshape %[[SLICE]]
// CHECK: return %[[RESULT]]
@@ -105,7 +105,7 @@
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C513:.+]] = constant 513 : index
-// CHECK: %[[DIM:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[SLICE:.+]] = flow.tensor.slice %[[ARG0]]
// CHECK-SAME: [%[[C4]], %[[C0]] for %[[C1]], %[[C513]]]
// CHECK-SAME: : tensor<?x513xi32>{%[[DIM]]} -> tensor<1x513xi32>
@@ -146,7 +146,7 @@
// CHECK-DAG: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[C2:.+]] = constant 2
// CHECK-DAG: %[[C4:.+]] = constant 4
-// CHECK-DAG: %[[DIM0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[ARG1]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM0]]}
@@ -167,7 +167,7 @@
// CHECK-DAG: %[[C2:.+]] = constant 2
// CHECK-DAG: %[[C4:.+]] = constant 4
// CHECK-DAG: %[[RESHAPE:.+]] = flow.tensor.reshape %[[ARG1]] : tensor<4x48xf32> -> tensor<1x4x48xf32>
-// CHECK-DAG: %[[DIM:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[UPDATE:.+]] = flow.tensor.update %[[RESHAPE]], %[[ARG0]][%[[C4]], %[[C2]], %[[C0]]]
// CHECK-SAME: : tensor<1x4x48xf32> -> tensor<?x24x48xf32>{%[[DIM]]}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index e70a224..837dc6a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -48,8 +48,8 @@
func @tile_generic_op_alone(%A: tensor<?x?xf32>, %B: tensor<?xf32>) -> tensor<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
- %d0 = memref.dim %A, %c0 : tensor<?x?xf32>
- %d1 = memref.dim %A, %c1 : tensor<?x?xf32>
+ %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %A, %c1 : tensor<?x?xf32>
%0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -69,8 +69,8 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: flow.dispatch.workgroups
// CHECK-SAME: [%[[D1]], %[[D0]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[D0]], %[[D1]])
// CHECK-NEXT: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
@@ -92,8 +92,8 @@
%zero = constant 0.0 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
- %M = memref.dim %A, %c0 : tensor<?x?xf32>
- %N = memref.dim %B, %c1 : tensor<?x?xf32>
+ %M = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %N = tensor.dim %B, %c1 : tensor<?x?xf32>
%0 = linalg.init_tensor [%M, %N] : tensor<?x?xf32>
%1 = linalg.fill(%zero, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%2 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -105,8 +105,8 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: flow.dispatch.workgroups[%[[N]], %[[M]], %[[C1]]]
// CHECK-SAME: (%[[M]], %[[N]], %[[ARG0]], %[[ARG1]])
// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: index
@@ -135,9 +135,9 @@
%one = constant 1.0 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
- %M = memref.dim %A, %c0 : tensor<?x?xf32>
- %N = memref.dim %B, %c1 : tensor<?x?xf32>
- %K = memref.dim %A, %c1 : tensor<?x?xf32>
+ %M = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %N = tensor.dim %B, %c1 : tensor<?x?xf32>
+ %K = tensor.dim %A, %c1 : tensor<?x?xf32>
%0 = linalg.init_tensor [%M, %N] : tensor<?x?xf32>
%1 = linalg.fill(%zero, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%2 = linalg.init_tensor [%M, %K] : tensor<?x?xf32>
@@ -159,9 +159,9 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[K]], %[[M]], %[[C1]]]
// CHECK-SAME: (%[[ARG0]], %[[M]], %[[K]])
// CHECK-NEXT: (%[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
@@ -220,8 +220,8 @@
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[WORKLOAD:.+]] = affine.apply #[[MAP0]]()[%[[D0]], %[[D1]]]
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK-SAME: [%[[WORKLOAD]], %[[C1]], %[[C1]]](%[[ARG0]])
@@ -239,10 +239,10 @@
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
- %d0 = memref.dim %A, %c0 : tensor<?x?x?x?xf32>
- %d1 = memref.dim %A, %c1 : tensor<?x?x?x?xf32>
- %d2 = memref.dim %A, %c2 : tensor<?x?x?x?xf32>
- %d3 = memref.dim %A, %c3 : tensor<?x?x?x?xf32>
+ %d0 = tensor.dim %A, %c0 : tensor<?x?x?x?xf32>
+ %d1 = tensor.dim %A, %c1 : tensor<?x?x?x?xf32>
+ %d2 = tensor.dim %A, %c2 : tensor<?x?x?x?xf32>
+ %d3 = tensor.dim %A, %c3 : tensor<?x?x?x?xf32>
%0 = linalg.init_tensor [%d0, %d1, %d2, %d3] : tensor<?x?x?x?xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
@@ -266,10 +266,10 @@
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[D2:.+]] = memref.dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[D3:.+]] = memref.dim %[[ARG0]], %[[C3]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[D3:.+]] = tensor.dim %[[ARG0]], %[[C3]]
// CHECK-DAG: %[[WG_SISE_2:.+]] = flow.dispatch.workgroup.size[2] : index
// CHECK-DAG: %[[WG_ID_2:.+]] = flow.dispatch.workgroup.id[2] : index
// CHECK-DAG: flow.dispatch.workgroups[%[[D3]], %[[D2]], %[[D1]]]
@@ -286,14 +286,14 @@
%c1 = constant 1 : index
%0 = linalg.tensor_expand_shape %lhs [[0, 1]]
: tensor<?xf32> into tensor<?x4xf32>
- %m = memref.dim %0, %c0 : tensor<?x4xf32>
- %n1 = memref.dim %rhs1, %c1 : tensor<4x?xf32>
+ %m = tensor.dim %0, %c0 : tensor<?x4xf32>
+ %n1 = tensor.dim %rhs1, %c1 : tensor<4x?xf32>
%init1 = linalg.init_tensor [%m, %n1] : tensor<?x?xf32>
%fill1 = linalg.fill(%cst, %init1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%1 = linalg.matmul
ins(%0, %rhs1 : tensor<?x4xf32>, tensor<4x?xf32>)
outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %n2 = memref.dim %rhs2, %c1 : tensor<4x?xf32>
+ %n2 = tensor.dim %rhs2, %c1 : tensor<4x?xf32>
%init2 = linalg.init_tensor [%m, %n2] : tensor<?x?xf32>
%fill2 = linalg.fill(%cst, %init2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%2= linalg.matmul
@@ -309,12 +309,12 @@
// CHECK-SAME: %[[RHS2:[a-zA-Z0-9_]+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[M:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK-DAG: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[N1]], %[[M]], %[[C1]]]
// CHECK-SAME: (%[[M]], %[[N1]], %[[ARG0]], %[[RHS1]])
-// CHECK: %[[N2:.+]] = memref.dim %[[RHS2]], %[[C1]]
+// CHECK: %[[N2:.+]] = tensor.dim %[[RHS2]], %[[C1]]
// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[N2]], %[[M]], %[[C1]]]
// CHECK-SAME: (%[[M]], %[[N2]], %[[ARG0]], %[[RHS2]])
// CHECK: return %[[RESULT1]], %[[RESULT2]]
@@ -326,8 +326,8 @@
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = tensor.extract %arg1[] : tensor<f32>
- %1 = memref.dim %arg0, %c0 : tensor<?x?xf32>
- %2 = memref.dim %arg0, %c1 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%3 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%1)[%arg2, %arg4]
%4 = affine.apply affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>(%2)[%arg3, %arg5]
%5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
@@ -346,8 +346,8 @@
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]], %[[D1]]]
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
@@ -409,8 +409,8 @@
%f12 = constant 12.0 : f32
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
- // CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG2]], %[[C0]]
- // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG2]], %[[C1]]
+ // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+ // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
// CHECK: %[[origCC:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]](%[[ARG2]])
// CHECK-NEXT: %[[ARG3:.+]]: !flow.dispatch.tensor<readwrite:?x?xf32>
// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG3]], {{.*}}
@@ -760,7 +760,7 @@
%cmin = constant -2147483648 : i32
%c0_i32 = constant 0 : i32
%c0 = constant 0 : index
- %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
%1 = linalg.init_tensor [%0] : tensor<?xi32>
%2 = linalg.fill(%cmin, %1) : i32, tensor<?xi32> -> tensor<?xi32>
%3 = linalg.fill(%c0_i32, %1) : i32, tensor<?xi32> -> tensor<?xi32>
@@ -800,7 +800,7 @@
-> (tensor<?x10xi32>, tensor<?x10xi32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
- %0 = memref.dim %arg0, %c0 : tensor<?x10xi32>
+ %0 = tensor.dim %arg0, %c0 : tensor<?x10xi32>
%1 = linalg.init_tensor [%0, 10] : tensor<?x10xi32>
%2:2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, 10-d1)>,
@@ -876,13 +876,13 @@
%cst = constant 0.000000e+00 : f32
%0 = iree.dynamic_shape_constant dense<[[1.500000e+01, 1.400000e+01, 1.300000e+01], [1.200000e+01, 1.100000e+01, 1.000000e+01], [9.000000e+00, 8.000000e+00, 7.000000e+00], [6.000000e+00, 5.000000e+00, 4.000000e+00], [3.000000e+00, 2.000000e+00, 1.000000e+00]]> : tensor<5x3xf32> -> tensor<?x?xf32>
%1 = iree.dynamic_shape_constant dense<[[1.500000e+01, 1.400000e+01, 1.300000e+01, 1.200000e+01, 1.100000e+01], [1.000000e+01, 9.000000e+00, 8.000000e+00, 7.000000e+00, 6.000000e+00], [5.000000e+00, 4.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]]> : tensor<3x5xf32> -> tensor<?x?xf32>
- %2 = memref.dim %0, %c0 : tensor<?x?xf32>
- %3 = memref.dim %1, %c1 : tensor<?x?xf32>
+ %2 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %3 = tensor.dim %1, %c1 : tensor<?x?xf32>
%4 = linalg.init_tensor [%2, %3] : tensor<?x?xf32>
%5 = linalg.fill(%cst, %4) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%6 = linalg.matmul ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %7 = memref.dim %6, %c0 : tensor<?x?xf32>
- %8 = memref.dim %6, %c1 : tensor<?x?xf32>
+ %7 = tensor.dim %6, %c0 : tensor<?x?xf32>
+ %8 = tensor.dim %6, %c1 : tensor<?x?xf32>
%9 = hal.tensor.cast %6 : tensor<?x?xf32>{%7, %8} -> !hal.buffer_view
return %9 : !hal.buffer_view
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir
index 14ce173..620f019 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors_elementwise.mlir
@@ -3,8 +3,8 @@
func @tile_generic_op_alone(%A: tensor<?x?xf32>, %B: tensor<?xf32>) -> tensor<?x?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
- %d0 = memref.dim %A, %c0 : tensor<?x?xf32>
- %d1 = memref.dim %A, %c1 : tensor<?x?xf32>
+ %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %A, %c1 : tensor<?x?xf32>
%0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -25,8 +25,8 @@
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: flow.dispatch.workgroups
// CHECK-SAME: [%[[D1]], %[[D0]], %[[C1]]](%[[ARG0]], %[[ARG1]], %[[D0]], %[[D1]])
// CHECK-NEXT: %[[ARG2:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:?x?xf32>
@@ -67,10 +67,10 @@
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
- %d0 = memref.dim %A, %c0 : tensor<?x?x?x?xf32>
- %d1 = memref.dim %A, %c1 : tensor<?x?x?x?xf32>
- %d2 = memref.dim %A, %c2 : tensor<?x?x?x?xf32>
- %d3 = memref.dim %A, %c3 : tensor<?x?x?x?xf32>
+ %d0 = tensor.dim %A, %c0 : tensor<?x?x?x?xf32>
+ %d1 = tensor.dim %A, %c1 : tensor<?x?x?x?xf32>
+ %d2 = tensor.dim %A, %c2 : tensor<?x?x?x?xf32>
+ %d3 = tensor.dim %A, %c3 : tensor<?x?x?x?xf32>
%0 = linalg.init_tensor [%d0, %d1, %d2, %d3] : tensor<?x?x?x?xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
@@ -94,10 +94,10 @@
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[D2:.+]] = memref.dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[D3:.+]] = memref.dim %[[ARG0]], %[[C3]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[D3:.+]] = tensor.dim %[[ARG0]], %[[C3]]
// CHECK: flow.dispatch.workgroups[%[[D3]], %[[D2]], %[[D1]]]
// -----
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir
index ccd4718..af834cb 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/form_streams.mlir
@@ -347,7 +347,7 @@
func @metadata_only(%t: tensor<?xf32>) -> (tensor<?xf32>, !shapex.ranked_shape<[?]>) {
// CHECK-NOT: flow.ex.stream.fragment
%c0 = constant 0 : index
- %4 = memref.dim %t, %c0 : tensor<?xf32>
+ %4 = tensor.dim %t, %c0 : tensor<?xf32>
%5 = shapex.make_ranked_shape %4 : (index) -> !shapex.ranked_shape<[?]>
%6 = shapex.tie_shape %t, %5 : tensor<?xf32>, !shapex.ranked_shape<[?]>
return %6, %5 : tensor<?xf32>, !shapex.ranked_shape<[?]>
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 909c7fa..e9d0d7e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -156,10 +156,10 @@
func @dynamicShapeDispatch(%arg0 : tensor<7x?x24x?xf32>) -> tensor<?x?x1024xf32> {
%c1 = constant 1 : index
%c3 = constant 3 : index
- // CHECK-DAG: %[[ARG0_DIM1:.+]] = memref.dim %[[ARG0]], %c1
- %dim1 = memref.dim %arg0, %c1 : tensor<7x?x24x?xf32>
- // CHECK-DAG: %[[ARG0_DIM3:.+]] = memref.dim %[[ARG0]], %c3
- %dim3 = memref.dim %arg0, %c3 : tensor<7x?x24x?xf32>
+ // CHECK-DAG: %[[ARG0_DIM1:.+]] = tensor.dim %[[ARG0]], %c1
+ %dim1 = tensor.dim %arg0, %c1 : tensor<7x?x24x?xf32>
+ // CHECK-DAG: %[[ARG0_DIM3:.+]] = tensor.dim %[[ARG0]], %c3
+ %dim3 = tensor.dim %arg0, %c3 : tensor<7x?x24x?xf32>
// CHECK-DAG: %[[X:.+]] = constant 1024
%x = constant 1024 : index
// CHECK-DAG: %[[Y:.+]] = constant 512
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir b/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir
index 7fa66b6..e788b79 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/pad_tensor_to_tensor.mlir
@@ -23,8 +23,8 @@
// CHECK-DAG: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[C1:.+]] = constant 1
// CHECK-DAG: %[[VAL:.+]] = tensor.extract %[[ARG1]]
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP0]]()[%[[ARG3]], %[[D0]]]
// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[D1]]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[RD0]], %[[RD1]]]
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
index 3fa193d..e982fd7 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD
@@ -34,6 +34,7 @@
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp
index 6f9d52f..61f1e7e 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp
@@ -9,8 +9,8 @@
#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/ArrayRef.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -37,22 +37,20 @@
// Lowers dim operations against values that were originally tensors but have
// been converted to HAL buffer types.
class BackingBufferBufferViewDimPattern
- : public OpConversionPattern<memref::DimOp> {
+ : public OpConversionPattern<tensor::DimOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
- memref::DimOp dimOp, llvm::ArrayRef<Value> rawOperands,
+ tensor::DimOp dimOp, llvm::ArrayRef<Value> rawOperands,
ConversionPatternRewriter &rewriter) const override {
- memref::DimOp::Adaptor operands(rawOperands);
- if (!dimOp.memrefOrTensor().getType().isa<TensorType>() ||
- !IREE::HAL::TensorRewriteAdaptor::isValidNewType(
- operands.memrefOrTensor().getType())) {
+ tensor::DimOp::Adaptor operands(rawOperands);
+ if (!IREE::HAL::TensorRewriteAdaptor::isValidNewType(
+ operands.source().getType())) {
return failure();
}
auto adaptor = IREE::HAL::TensorRewriteAdaptor::get(
- dimOp.getLoc(), dimOp.memrefOrTensor(), operands.memrefOrTensor(),
- rewriter);
+ dimOp.getLoc(), dimOp.source(), operands.source(), rewriter);
Optional<int64_t> index = dimOp.getConstantIndex();
assert(index.hasValue() && "expect constant index in `std.dim` operation");
diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
index 34fa2f9..58d7601 100644
--- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -117,6 +118,7 @@
// have any types they are valid to be used on after this conversion.
conversionTarget.addIllegalOp<memref::DimOp>();
conversionTarget.addIllegalOp<mlir::RankOp>();
+ conversionTarget.addIllegalOp<tensor::DimOp>();
// We must convert away any of our casts from higher level dialects.
conversionTarget.addIllegalOp<IREE::HAL::TensorCastOp>();
diff --git a/iree/compiler/Dialect/Shape/IR/Builders.cpp b/iree/compiler/Dialect/Shape/IR/Builders.cpp
index c399e62..85afad3 100644
--- a/iree/compiler/Dialect/Shape/IR/Builders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Builders.cpp
@@ -261,7 +261,7 @@
// hopefully converted to ranked shape types.
for (unsigned i = 0; i < valueSt.getRank(); ++i) {
if (valueSt.isDynamicDim(i)) {
- result.push_back(builder.createOrFold<memref::DimOp>(loc, value, i));
+ result.push_back(builder.createOrFold<tensor::DimOp>(loc, value, i));
}
}
}
diff --git a/iree/compiler/Dialect/Shape/IR/Folders.cpp b/iree/compiler/Dialect/Shape/IR/Folders.cpp
index e808043..87570fd 100644
--- a/iree/compiler/Dialect/Shape/IR/Folders.cpp
+++ b/iree/compiler/Dialect/Shape/IR/Folders.cpp
@@ -242,7 +242,7 @@
if (auto carryingOp = dyn_cast<ShapeCarryingInterface>(use.getOwner())) {
carryingOp->setOperand(use.getOperandNumber(), operands.operand());
didAnything = true;
- } else if (auto dimOp = dyn_cast<memref::DimOp>(use.getOwner())) {
+ } else if (auto dimOp = dyn_cast<tensor::DimOp>(use.getOwner())) {
auto index = dimOp.getConstantIndex();
if (index.hasValue()) {
rewriter.replaceOpWithNewOp<RankedDimOp>(dimOp, op.shape(),
diff --git a/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp b/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
index aaba8dc..2f88c85 100644
--- a/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
+++ b/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp
@@ -293,7 +293,7 @@
for (int i = 0; i < t.getRank(); ++i) {
if (t.isDynamicDim(i)) {
// Emit a dim op.
- Value dim = builder.create<memref::DimOp>(loc, v, i);
+ Value dim = builder.create<tensor::DimOp>(loc, v, i);
extents.push_back(dim);
} else {
// Static dim.
diff --git a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
index c72c654..28d6e03 100644
--- a/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ b/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
@@ -73,12 +73,12 @@
SmallVector<Value, 3> offsets, sizes, strides;
for (int i = 0; i < rank; ++i) {
offsets.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
- sizes.push_back(rewriter.create<memref::DimOp>(loc, args[0], i));
+ sizes.push_back(rewriter.create<tensor::DimOp>(loc, args[0], i));
strides.push_back(rewriter.create<ConstantIndexOp>(loc, 1));
}
Value resultDimSize = rewriter.create<ConstantIndexOp>(loc, 0);
for (auto arg : args) {
- auto size = rewriter.create<memref::DimOp>(loc, arg, dim);
+ auto size = rewriter.create<tensor::DimOp>(loc, arg, dim);
resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
}
sizes[dim] = resultDimSize;
@@ -92,7 +92,7 @@
Value accBound = rewriter.create<ConstantIndexOp>(loc, 0);
for (auto arg : args) {
offsets[dim] = accBound;
- sizes[dim] = rewriter.create<memref::DimOp>(loc, arg, dim);
+ sizes[dim] = rewriter.create<tensor::DimOp>(loc, arg, dim);
result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
sizes, strides);
accBound = rewriter.create<AddIOp>(loc, accBound, sizes[dim]);
diff --git a/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir b/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
index 7954f8f..c23d9d8 100644
--- a/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir
@@ -24,11 +24,11 @@
// Should broadcast %arg0 -> %arg1 and assert on dynamic expansion.
// CHECK: %[[C0_0:.*]] = constant 0 : index
- // CHECK: %[[ARG0_D0:.*]] = memref.dim %arg0, %[[C0_0]]
+ // CHECK: %[[ARG0_D0:.*]] = tensor.dim %arg0, %[[C0_0]]
// CHECK: %[[C0_1:.*]] = constant 0 : index
- // CHECK: %[[ARG1_D0:.*]] = memref.dim %arg1, %[[C0_1]] : tensor<?x?xf32>
+ // CHECK: %[[ARG1_D0:.*]] = tensor.dim %arg1, %[[C0_1]] : tensor<?x?xf32>
// CHECK: %[[C1_0:.*]] = constant 1 : index
- // CHECK: %[[ARG1_D1:.*]] = memref.dim %arg1, %[[C1_0]] : tensor<?x?xf32>
+ // CHECK: %[[ARG1_D1:.*]] = tensor.dim %arg1, %[[C1_0]] : tensor<?x?xf32>
// CHECK: %[[EQ:.*]] = cmpi eq, %[[ARG0_D0]], %[[ARG1_D1]] : index
// CHECK: assert %[[EQ]], "mismatched dynamic broadcast extents"
@@ -175,7 +175,7 @@
// CHECK-LABEL: func @selectv2_broadcast_dyn_pred
func @selectv2_broadcast_dyn_pred(%arg0: tensor<?x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<?x8x8xi32> {
// CHECK: %[[C0_0:.*]] = constant 0 : index
- // CHECK: %[[DIM_PRED_0:.*]] = memref.dim %arg0, %[[C0_0]]
+ // CHECK: %[[DIM_PRED_0:.*]] = tensor.dim %arg0, %[[C0_0]]
// CHECK: %[[INIT_PRED:.*]] = linalg.init_tensor [%[[DIM_PRED_0]], 8, 8]
// CHECK: %[[BCAST_PRED:.*]] = linalg.generic
// CHECK-SAME: indexing_maps = [#map0, #map1]
@@ -189,7 +189,7 @@
// CHECK-SAME: indexing_maps = [#map3, #map1]
// CHECK-SAME: ins(%arg2 : tensor<1x1x8xi32>) outs(%[[INIT_ELSE]] : tensor<?x8x8xi32>)
// CHECK: %[[C0_1:.*]] = constant 0 : index
- // CHECK: %[[DIM_BCAST_PRED_0:.*]] = memref.dim %[[BCAST_PRED]], %[[C0_1]]
+ // CHECK: %[[DIM_BCAST_PRED_0:.*]] = tensor.dim %[[BCAST_PRED]], %[[C0_1]]
// CHECK: %[[INIT_RESULT:.*]] = linalg.init_tensor [%[[DIM_BCAST_PRED_0]], 8, 8]
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[BCAST_PRED]], %[[BCAST_THEN]], %[[BCAST_ELSE]] : tensor<?x8x8xi1>, tensor<?x8x8xi32>, tensor<?x8x8xi32>) outs(%[[INIT_RESULT]] : tensor<?x8x8xi32>)
@@ -201,7 +201,7 @@
// CHECK-LABEL: func @selectv2_broadcast_dyn_then
func @selectv2_broadcast_dyn_then(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x?x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x?x8xi32> {
// CHECK: %[[C1_0:.*]] = constant 1 : index
- // CHECK: %[[DIM_THEN_1:.*]] = memref.dim %arg1, %[[C1_0]]
+ // CHECK: %[[DIM_THEN_1:.*]] = tensor.dim %arg1, %[[C1_0]]
// CHECK: %[[INIT_PRED:.*]] = linalg.init_tensor [8, %[[DIM_THEN_1]], 8]
// CHECK: %[[BCAST_PRED:.*]] = linalg.generic
// CHECK-SAME: indexing_maps = [#map0, #map1]
@@ -215,7 +215,7 @@
// CHECK-SAME: indexing_maps = [#map3, #map1]
// CHECK-SAME: ins(%arg2 : tensor<1x1x8xi32>) outs(%[[INIT_ELSE]] : tensor<8x?x8xi32>)
// CHECK: %[[C1_1:.*]] = constant 1 : index
- // CHECK: %[[DIM_BCAST_PRED_1:.*]] = memref.dim %[[BCAST_PRED]], %[[C1_1]]
+ // CHECK: %[[DIM_BCAST_PRED_1:.*]] = tensor.dim %[[BCAST_PRED]], %[[C1_1]]
// CHECK: %[[INIT_RESULT:.*]] = linalg.init_tensor [8, %[[DIM_BCAST_PRED_1]], 8]
// CHECK: linalg.generic
// CHECK-SAME: ins(%2, %4, %6 : tensor<8x?x8xi1>, tensor<8x?x8xi32>, tensor<8x?x8xi32>) outs(%8 : tensor<8x?x8xi32>)
@@ -227,7 +227,7 @@
// CHECK-LABEL: func @selectv2_broadcast_dyn_else
func @selectv2_broadcast_dyn_else(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x?xi32>) -> tensor<8x8x?xi32> {
// CHECK: %[[C2_0:.*]] = constant 2 : index
- // CHECK: %[[DIM_ELSE_2:.*]] = memref.dim %arg2, %[[C2_0]]
+ // CHECK: %[[DIM_ELSE_2:.*]] = tensor.dim %arg2, %[[C2_0]]
// CHECK: %[[INIT_PRED:.*]] = linalg.init_tensor [8, 8, %[[DIM_ELSE_2]]]
// CHECK: %[[BCAST_PRED:.*]] = linalg.generic
// CHECK-SAME: indexing_maps = [#map0, #map1]
@@ -242,7 +242,7 @@
// CHECK-SAME: indexing_maps = [#map3, #map1]
// CHECK-SAME: ins(%arg2 : tensor<1x1x?xi32>) outs(%[[INIT_ELSE]] : tensor<8x8x?xi32>)
// CHECK: %[[C2_1:.*]] = constant 2 : index
- // CHECK: %[[DIM_BCAST_PRED_1:.*]] = memref.dim %[[BCAST_PRED]], %[[C2_1]]
+ // CHECK: %[[DIM_BCAST_PRED_1:.*]] = tensor.dim %[[BCAST_PRED]], %[[C2_1]]
// CHECK: %[[INIT_RESULT:.*]] = linalg.init_tensor [8, 8, %[[DIM_BCAST_PRED_1]]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%2, %4, %6 : tensor<8x8x?xi1>, tensor<8x8x?xi32>, tensor<8x8x?xi32>) outs(%8 : tensor<8x8x?xi32>)
@@ -254,13 +254,13 @@
// CHECK-LABEL: func @selectv2_broadcast_dyn_all
func @selectv2_broadcast_dyn_all(%arg0: tensor<?x1x1xi1>, %arg1: tensor<?x8x1xi32>, %arg2: tensor<?x1x?xi32>) -> tensor<?x8x?xi32> {
// CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[PRED_D0:.*]] = memref.dim %arg0, %[[C0]] : tensor<?x1x1xi1>
+ // CHECK: %[[PRED_D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x1x1xi1>
// CHECK: %[[C0_0:.*]] = constant 0 : index
- // CHECK: %[[THEN_D0:.*]] = memref.dim %arg1, %[[C0_0]] : tensor<?x8x1xi32>
+ // CHECK: %[[THEN_D0:.*]] = tensor.dim %arg1, %[[C0_0]] : tensor<?x8x1xi32>
// CHECK: %[[C0_1:.*]] = constant 0 : index
- // CHECK: %[[ELSE_D0:.*]] = memref.dim %arg2, %[[C0_1]] : tensor<?x1x?xi32>
+ // CHECK: %[[ELSE_D0:.*]] = tensor.dim %arg2, %[[C0_1]] : tensor<?x1x?xi32>
// CHECK: %[[C2:.*]] = constant 2 : index
- // CHECK: %[[ELSE_D2:.*]] = memref.dim %arg2, %[[C2]] : tensor<?x1x?xi32>
+ // CHECK: %[[ELSE_D2:.*]] = tensor.dim %arg2, %[[C2]] : tensor<?x1x?xi32>
// CHECK: %[[CMP_0:.*]] = cmpi eq, %[[PRED_D0]], %[[THEN_D0]] : index
// CHECK: assert %[[CMP_0]], "mismatched dynamic broadcast extents"
// CHECK: %[[CMP_1:.*]] = cmpi eq, %[[PRED_D0]], %[[ELSE_D0]] : index
@@ -416,9 +416,9 @@
// CHECK-DAG: %[[C4:.*]] = constant 4 : index
// CHECK-DAG: %[[RESULT_D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xindex>
// CHECK-DAG: %[[INDEX1:.*]] = constant 1 : index
- // CHECK-DAG: %[[ARG_D1:.*]] = memref.dim %[[INPUT]], %[[INDEX1]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %[[INPUT]], %[[INDEX1]] : tensor<4x?x3x?xi32>
// CHECK-DAG: %[[INDEX3:.*]] = constant 3 : index
- // CHECK-DAG: %[[ARG_D3:.*]] = memref.dim %[[INPUT]], %[[INDEX3]] : tensor<4x?x3x?xi32>
+ // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %[[INPUT]], %[[INDEX3]] : tensor<4x?x3x?xi32>
// CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[INPUT]] : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]}
%0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xindex>) -> tensor<12x?x?x1x?xui32>
// CHECK: %[[UNCONVERTED_RESULT:.*]] = unrealized_conversion_cast %[[RESULT]] : tensor<12x?x?x1x?xi32> to tensor<12x?x?x1x?xui32>
diff --git a/iree/compiler/InputConversion/MHLO/test/dynamic_shape.mlir b/iree/compiler/InputConversion/MHLO/test/dynamic_shape.mlir
index 73c7bda..ec32c98 100644
--- a/iree/compiler/InputConversion/MHLO/test/dynamic_shape.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/dynamic_shape.mlir
@@ -10,9 +10,9 @@
// CHECK: func @dynamic_shape
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[C0:.+]] = constant 0 : index
-// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[C1:.+]] = constant 1 : index
-// CHECK: %[[T1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[T0]], %[[T1]]]
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
diff --git a/iree/samples/custom_modules/dialect/test/conversion.mlir b/iree/samples/custom_modules/dialect/test/conversion.mlir
index a23aa72..07d2892 100644
--- a/iree/samples/custom_modules/dialect/test/conversion.mlir
+++ b/iree/samples/custom_modules/dialect/test/conversion.mlir
@@ -79,7 +79,7 @@
func @messageToTensorReturnDim(%arg0 : !custom.message) -> index {
%0 = "custom.message_to_tensor"(%arg0) : (!custom.message) -> tensor<?x4xf32>
%c0 = constant 0 : index
- %1 = memref.dim %0, %c0 : tensor<?x4xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?x4xf32>
// CHECK: [[VIEW:%.+]] = vm.call @custom.message_to_buffer(%arg0) : (!vm.ref<!custom.message>) -> !vm.ref<!hal.buffer_view>
// CHECK: [[BUFFER:%.+]] = vm.call @hal.buffer_view.buffer([[VIEW]]) : (!vm.ref<!hal.buffer_view>) -> !vm.ref<!hal.buffer>
// CHECK: %{{.*}} = vm.const.i32.zero
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 5b8ddd2..ce211c5 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 5b8ddd2ccceb8de04bd020f286bc3ca38638ecb1
+Subproject commit ce211c505b82e5bbb68b936968d9b54608285416