Switching to use tensor.pack/unpack ops for data-tiling (#12247)
benchmarks: comp-stats
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
index 2d80fbc..affad4d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertToDestinationPassingStylePass.cpp
@@ -471,6 +471,7 @@
/// of tensor.unpack op for more details.
static LogicalResult replaceUnpackEmptyWithAllocTensor(OpBuilder &b,
func::FuncOp funcOp) {
+ // TODO(hanchung): retire IREE::LinalgExt version.
funcOp.walk([&](IREE::LinalgExt::UnPackOp unpackOp) {
if (!unpackOp->hasOneUse() ||
!isa<tensor::ExtractSliceOp>(*(unpackOp->user_begin()))) {
@@ -485,6 +486,22 @@
emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
emptyOp.replaceAllUsesWith(allocTensor.getResult());
});
+
+ funcOp.walk([&](tensor::UnPackOp unpackOp) {
+ if (!unpackOp->hasOneUse() ||
+ !isa<tensor::ExtractSliceOp>(*(unpackOp->user_begin()))) {
+ return;
+ }
+ auto emptyOp = unpackOp.getDest().getDefiningOp<tensor::EmptyOp>();
+ if (!emptyOp) return;
+
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointAfter(emptyOp);
+ auto allocTensor = b.create<bufferization::AllocTensorOp>(
+ emptyOp.getLoc(), emptyOp.getType(), emptyOp.getDynamicSizes());
+ emptyOp.replaceAllUsesWith(allocTensor.getResult());
+ });
+
return success();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index 4ec5a6e..289f7a5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -56,9 +56,10 @@
auto innerTileSizes = getInnerTileSizesOfr(
builder, loc, boundTensorType, *encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizes)) return failure();
- SmallVector<OpFoldResult> convertedTargetShape = PackOp::getResultShape(
- builder, loc, targetShape, *innerTileSizes, encodingInfo->innerDimsPos,
- encodingInfo->outerDimsPerm);
+ SmallVector<OpFoldResult> convertedTargetShape =
+ tensor::PackOp::getResultShape(builder, loc, targetShape, *innerTileSizes,
+ encodingInfo->innerDimsPos,
+ encodingInfo->outerDimsPerm);
return convertedTargetShape;
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
index 47df80d..a410923 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp
@@ -112,9 +112,9 @@
return success();
}
-/// Get the materialization information from a `iree_linalg_ext.pack` operation.
+/// Get the materialization information from a `tensor.pack` operation.
static FailureOr<IREE::LinalgExt::MaterializeEncodingInfo>
-getMaterializationInfo(IREE::LinalgExt::PackOp packOp) {
+getMaterializationInfo(tensor::PackOp packOp) {
IREE::LinalgExt::MaterializeEncodingInfo encodingInfo;
SmallVector<OpFoldResult> mixedTileSizes = packOp.getMixedTiles();
encodingInfo.innerTileSizes.reserve(mixedTileSizes.size());
@@ -128,7 +128,7 @@
}
encodingInfo.innerDimsPos = llvm::to_vector(packOp.getInnerDimsPos());
encodingInfo.outerDimsPerm = llvm::to_vector(packOp.getOuterDimsPerm());
- encodingInfo.srcRank = packOp.getInputRank();
+ encodingInfo.srcRank = packOp.getSourceRank();
return encodingInfo;
}
@@ -277,11 +277,10 @@
auto innerTileSizes = getInnerTileSizesOfr(rewriter, loc, inputType,
materializeEncodingInfo, {});
if (failed(innerTileSizes)) return failure();
- SmallVector<OpFoldResult> resultShape =
- IREE::LinalgExt::PackOp::getResultShape(
- rewriter, loc, getAsOpFoldResult(workload), *innerTileSizes,
- materializeEncodingInfo.innerDimsPos,
- materializeEncodingInfo.outerDimsPerm);
+ SmallVector<OpFoldResult> resultShape = tensor::PackOp::getResultShape(
+ rewriter, loc, getAsOpFoldResult(workload), *innerTileSizes,
+ materializeEncodingInfo.innerDimsPos,
+ materializeEncodingInfo.outerDimsPerm);
resultShape.resize(materializeEncodingInfo.srcRank);
rewriter
@@ -351,14 +350,13 @@
patterns.insert<LowerDispatchWorkgroupCountForDagRootOp>(
context, tileSizes, staticLoopRanges, interchange,
partitionableLoops);
- if (auto packRootOp =
- dyn_cast_or_null<IREE::LinalgExt::PackOp>(dispatchRootOp)) {
+ if (auto packRootOp = dyn_cast_or_null<tensor::PackOp>(dispatchRootOp)) {
FailureOr<IREE::LinalgExt::MaterializeEncodingInfo> encodingInfo =
getMaterializationInfo(packRootOp);
if (failed(encodingInfo)) {
return signalPassFailure();
}
- auto tensorType = packRootOp.getInputType().cast<RankedTensorType>();
+ auto tensorType = packRootOp.getSourceType();
// The LowerDispatchWorkgroupCountFromSetEncodingOp pattern is going to
// call materializeEncodingValueFn, passing it a tensor type, expecting
// that tensor type to have a TensorEncodingAttr. The problem is that
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index c78f23b..2ac2eb1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -821,7 +821,7 @@
%16 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%arg1)[%8#1]
%17 = flow.dispatch.tensor.load %3, offsets = [%15, %16, 0, 0], sizes = [%c1, %c1, %8#0, %8#1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%6, %7, %5#0, %5#1} -> tensor<?x?x?x?xi32>
%18 = tensor.empty(%8#0, %8#1) : tensor<?x?xi32>
- %19 = iree_linalg_ext.unpack %17 inner_dims_pos = [0, 1] inner_tiles = [%8#0, %8#1] into %18 : (tensor<?x?x?x?xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+ %19 = tensor.unpack %17 inner_dims_pos = [0, 1] inner_tiles = [%8#0, %8#1] into %18 : tensor<?x?x?x?xi32> -> tensor<?x?xi32>
%extracted_slice = tensor.extract_slice %19[%13, %14] [1, 1] [1, 1] : tensor<?x?xi32> to tensor<1x1xi32>
%cast = tensor.cast %extracted_slice : tensor<1x1xi32> to tensor<?x?xi32>
flow.dispatch.tensor.store %cast, %4, offsets = [%arg0, %arg1], sizes = [%c1, %c1], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<1x1xi32>>
@@ -831,7 +831,7 @@
}
// CHECK-LABEL: func.func @non_perfect_tiling_unpack
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
// CHECK-SAME: into %[[ALLOC]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
index 68117fd..dee7671 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups.mlir
@@ -1915,9 +1915,9 @@
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [100, 250], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<100x250xf32>> -> tensor<100x250xf32>
%3 = tensor.empty() : tensor<14x64x8x4xf32>
- %4 = iree_linalg_ext.pack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %2
- padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %3
- : (tensor<100x250xf32> tensor<14x64x8x4xf32>) -> tensor<14x64x8x4xf32>
+ %4 = tensor.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %3
+ {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+ : tensor<100x250xf32> -> tensor<14x64x8x4xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [14, 64, 8, 4], strides = [1, 1, 1, 1]
: tensor<14x64x8x4xf32> -> !flow.dispatch.tensor<writeonly:tensor<14x64x8x4xf32>>
return
@@ -1960,9 +1960,9 @@
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1]
: !flow.dispatch.tensor<readonly:tensor<250x500xf32>> -> tensor<250x500xf32>
%3 = tensor.empty() : tensor<64x64x8x4xf32>
- %4 = iree_linalg_ext.pack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %2
- padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4] into %3
- : (tensor<250x500xf32> tensor<64x64x8x4xf32>) -> tensor<64x64x8x4xf32>
+ %4 = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4] into %3
+ {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+ : tensor<250x500xf32> -> tensor<64x64x8x4xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [64, 64, 8, 4], strides = [1, 1, 1, 1]
: tensor<64x64x8x4xf32> -> !flow.dispatch.tensor<writeonly:tensor<64x64x8x4xf32>>
return
@@ -2018,9 +2018,9 @@
%15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%13]
%16 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%14]
%17 = tensor.empty(%15, %16) : tensor<?x?x8x4xf32>
- %18 = iree_linalg_ext.pack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %12
- padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %17
- : (tensor<?x?xf32> tensor<?x?x8x4xf32>) -> tensor<?x?x8x4xf32>
+ %18 = tensor.pack %12 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %17
+ {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+ : tensor<?x?xf32> -> tensor<?x?x8x4xf32>
%19 = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%6]
%20 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%7]
flow.dispatch.tensor.store %18, %11, offsets = [0, 0, 0, 0], sizes = [%19, %20, 8, 4], strides = [1, 1, 1, 1]
@@ -2070,8 +2070,9 @@
%9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
%10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, 32, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x32x16xi32>>{%4, %5} -> tensor<?x?x32x16xi32>
%11 = tensor.empty(%6, %7) : tensor<?x?xi32>
- %12 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
- %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : (tensor<?x?x32x16xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+ %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11
+ {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+ : tensor<?x?x32x16xi32> -> tensor<?x?xi32>
flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
return
}
@@ -2081,7 +2082,7 @@
// CHECK-LABEL: func.func @dynamic_unpack
// CHECK: scf.for
// CHECK: scf.for
-// CHECK: iree_linalg_ext.unpack
+// CHECK: tensor.unpack
// -----
@@ -2115,8 +2116,9 @@
%9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
%10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, %c32, %c16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x?x?xi32>>{%4, %5, %c32, %c16} -> tensor<?x?x?x?xi32>
%11 = tensor.empty(%6, %7) : tensor<?x?xi32>
- %12 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
- %10 inner_dims_pos = [0, 1] inner_tiles = [%c32, %c16] into %11 : (tensor<?x?x?x?xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+ %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [%c32, %c16] into %11
+ {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
+ : tensor<?x?x?x?xi32> -> tensor<?x?xi32>
flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
return
}
@@ -2126,7 +2128,7 @@
// CHECK-LABEL: func.func @dynamic_unpack_dynamic_tile
// CHECK: scf.for
// CHECK: scf.for
-// CHECK: iree_linalg_ext.unpack
+// CHECK: tensor.unpack
// -----
@@ -2149,7 +2151,7 @@
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<128x384xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [16, 48, 8, 8], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x48x8x8xf32>> -> tensor<16x48x8x8xf32>
%3 = tensor.empty() : tensor<128x384xf32>
- %4 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} %2 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %3 : (tensor<16x48x8x8xf32> tensor<128x384xf32>) -> tensor<128x384xf32>
+ %4 = tensor.unpack %2 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %3 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>} : tensor<16x48x8x8xf32> -> tensor<128x384xf32>
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<128x384xf32>) outs(%3 : tensor<128x384xf32>) {
^bb0(%in: f32, %out: f32):
%6 = arith.addf %in, %in : f32
@@ -2165,7 +2167,7 @@
// CHECK-LABEL: func.func @unpack_elem
// CHECK: scf.for
// CHECK: scf.for
-// CHECK: iree_linalg_ext.unpack
+// CHECK: tensor.unpack
// CHECK: linalg.generic
// -----
@@ -2205,7 +2207,7 @@
%13 = tensor.empty() : tensor<12544x16xi32>
%14 = tensor.empty() : tensor<12544x16xi32>
%15:2 = vmvx.query_tile_sizes sizes(%c12544, %c16) flags(1245184) -> index, index
- %16 = iree_linalg_ext.unpack {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 16]]>} %10 inner_dims_pos = [0, 1] inner_tiles = [%15#0, %15#1] into %14 : (tensor<?x?x?x?xi32> tensor<12544x16xi32>) -> tensor<12544x16xi32>
+ %16 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [%15#0, %15#1] into %14 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[16, 16]]>} : tensor<?x?x?x?xi32> -> tensor<12544x16xi32>
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst, %16, %11, %12 : tensor<16xi32>, tensor<12544x16xi32>, tensor<12544xi32>, tensor<16xi32>) outs(%13 : tensor<12544x16xi32>) {
^bb0(%in: i32, %in_0: i32, %in_1: i32, %in_2: i32, %out: i32):
%18 = arith.muli %in_1, %c-30_i32 : i32
@@ -2225,7 +2227,7 @@
// CHECK-LABEL: func.func @dynamic_unpack_fusion
// CHECK: scf.for
// CHECK: scf.for
-// CHECK: iree_linalg_ext.unpack
+// CHECK: tensor.unpack
// CHECK: tensor.extract_slice
// CHECK: linalg.generic
@@ -2273,7 +2275,7 @@
linalg.yield %23, %25 : f32, f32
} -> (tensor<384x512xf32>, tensor<384x512xf32>)
%17 = tensor.empty() : tensor<48x512x8x1xf32>
- %18 = iree_linalg_ext.pack {encoding = #iree_linalg_ext.encoding<MATMUL_F32F32F32_LHS>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 64]]>} %16#0 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %17 : (tensor<384x512xf32> tensor<48x512x8x1xf32>) -> tensor<48x512x8x1xf32>
+ %18 = tensor.pack %16#0 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %17 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 64]]>} : tensor<384x512xf32> -> tensor<48x512x8x1xf32>
flow.dispatch.tensor.store %18, %6, offsets = [0, 0, 0, 0], sizes = [48, 512, 8, 1], strides = [1, 1, 1, 1] : tensor<48x512x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<48x512x8x1xf32>>
flow.dispatch.tensor.store %16#0, %7, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
flow.dispatch.tensor.store %16#1, %8, offsets = [0, 0], sizes = [384, 512], strides = [1, 1] : tensor<384x512xf32> -> !flow.dispatch.tensor<writeonly:tensor<384x512xf32>>
@@ -2286,7 +2288,7 @@
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[ELEM:.+]]:2 = linalg.generic
-// CHECK: %[[PACK:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK:.+]] = tensor.pack
// CHECK-DAG: flow.dispatch.tensor.store %[[PACK]], {{.*}} sizes = [8, 64, 8, 1]
// CHECK-DAG: flow.dispatch.tensor.store %[[ELEM]]#0, {{.*}} sizes = [64, 64]
// CHECK-DAG: flow.dispatch.tensor.store %[[ELEM]]#1, {{.*}} sizes = [64, 64]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
index 5efb99a..2be2eba 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -102,6 +102,7 @@
"@llvm-project//mlir:SCFToControlFlow",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:TensorDialect",
+ "@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:TosaToArith",
"@llvm-project//mlir:TransformDialect",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 91a49c1..8375bfa 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -80,6 +80,7 @@
MLIRSCFToControlFlow
MLIRSCFTransforms
MLIRTensorDialect
+ MLIRTensorTransforms
MLIRTosaDialect
MLIRTosaToArith
MLIRTransformDialect
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
index 23ca04c..21152ea 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUMaterializeEncodingPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -144,7 +145,7 @@
// dims ops.
{
RewritePatternSet patterns(context);
- populateFoldIntoPackAndUnpackOpsPatterns(patterns);
+ tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) {
operation.emitOpError("folding patterns failed");
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index b5c41f5..341a56a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -670,9 +670,9 @@
addTileAndDistributePasses(passManager);
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
nestedModulePM.addNestedPass<func::FuncOp>(
- IREE::LinalgExt::createLinalgExtVectorizationPass());
- nestedModulePM.addNestedPass<func::FuncOp>(
createVectorizePackUnPackOpsPass());
+ nestedModulePM.addNestedPass<func::FuncOp>(
+ IREE::LinalgExt::createLinalgExtVectorizationPass());
addBufferizePasses(nestedModulePM);
nestedModulePM.addNestedPass<func::FuncOp>(
createSplitFullPartialTransferPass("linalg-copy"));
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
index 38812ee..884790a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tiling_pipeline.mlir
@@ -22,7 +22,7 @@
linalg.yield %7 : f32
} -> tensor<128x384xf32>
%5 = tensor.empty() : tensor<16x384x8x1xf32>
- %6 = iree_linalg_ext.pack %4 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %5 : (tensor<128x384xf32> tensor<16x384x8x1xf32>) -> tensor<16x384x8x1xf32>
+ %6 = tensor.pack %4 inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %5 : tensor<128x384xf32> -> tensor<16x384x8x1xf32>
flow.dispatch.tensor.store %6, %1, offsets = [0, 0, 0, 0], sizes = [16, 384, 8, 1], strides = [1, 1, 1, 1] : tensor<16x384x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<16x384x8x1xf32>>
return
}
@@ -32,5 +32,4 @@
// CHECK: func.func @elem_pack
// CHECK: %[[READ:.+]] = vector.transfer_read
// CHECK: %[[ADD:.+]] = arith.addf %[[READ]], %[[READ]]
-// CHECK: %[[BCAST:.+]] = vector.broadcast %[[ADD]]
-// CHECK: vector.transfer_write %[[BCAST]], %{{.+}}
+// CHECK: vector.transfer_write %[[ADD]], %{{.+}}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
index 0d931c9..29049c4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_aarch64_launch_configuration.mlir
@@ -355,7 +355,7 @@
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x48x8x1xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [20, 40], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<20x40xf32>> -> tensor<20x40xf32>
%3 = tensor.empty() : tensor<4x48x8x1xf32>
- %4 = iree_linalg_ext.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %3 : (tensor<20x40xf32> tensor<4x48x8x1xf32>) -> tensor<4x48x8x1xf32>
+ %4 = tensor.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %3 : tensor<20x40xf32> -> tensor<4x48x8x1xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [4, 48, 8, 1], strides = [1, 1, 1, 1] : tensor<4x48x8x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x48x8x1xf32>>
return
}
@@ -366,7 +366,7 @@
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDataTiling>
// CHECK: hal.executable.export public @pack
// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: iree_linalg_ext.pack
+// CHECK: tensor.pack
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
@@ -400,7 +400,7 @@
%9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
%10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, 32, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x32x16xi32>>{%4, %5} -> tensor<?x?x32x16xi32>
%11 = tensor.empty(%6, %7) : tensor<?x?xi32>
- %12 = iree_linalg_ext.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : (tensor<?x?x32x16xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+ %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : tensor<?x?x32x16xi32> -> tensor<?x?xi32>
flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
return
}
@@ -411,5 +411,5 @@
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDataTiling>
// CHECK: hal.executable.export public @unpack_outer_dynamic
// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: iree_linalg_ext.unpack
+// CHECK: tensor.unpack
// CHECK-SAME: lowering_config = #[[CONFIG]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
index 0ef9a3f..ac523a8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_encoding.mlir
@@ -41,7 +41,7 @@
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?x8x4xf32>>{%[[TILED_OUTD0]], %[[TILED_OUTD1]]}
// CHECK: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[TILED_OUTD0]], %[[TILED_OUTD1]])
-// CHECK: %[[PACK:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK:.+]] = tensor.pack
// CHECK-SAME: %[[INPUT]] padding_value(%[[CST]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[EMPTY]]
// CHECK: flow.dispatch.tensor.store %[[PACK]], %[[OUTPUT_BINDING]]
@@ -86,7 +86,7 @@
// CHECK: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]]
// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_D0]], %[[TILED_D1]], 8, 4], strides = [1, 1, 1, 1]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[OUTD0]], %[[OUTD1]])
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[INPUT]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[INPUT]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[EMPTY]]
// CHECK-DAG: flow.dispatch.tensor.store %[[UNPACK]], %[[OUTPUT_BINDING]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir
index 2c4747d..62e9e3f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_vmvx_launch_configuration.mlir
@@ -212,7 +212,7 @@
%9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c131072) : !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
%10 = flow.dispatch.tensor.load %8, offsets = [0, 0, 0, 0], sizes = [%4, %5, 32, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x32x16xi32>>{%4, %5} -> tensor<?x?x32x16xi32>
%11 = tensor.empty(%6, %7) : tensor<?x?xi32>
- %12 = iree_linalg_ext.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : (tensor<?x?x32x16xi32> tensor<?x?xi32>) -> tensor<?x?xi32>
+ %12 = tensor.unpack %10 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %11 : tensor<?x?x32x16xi32> -> tensor<?x?xi32>
flow.dispatch.tensor.store %12, %9, offsets = [0, 0], sizes = [%6, %7], strides = [1, 1] : tensor<?x?xi32> -> !flow.dispatch.tensor<writeonly:tensor<?x?xi32>>{%6, %7}
return
}
@@ -223,5 +223,5 @@
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<VMVXDefault>
// CHECK: hal.executable.export public @unpack_outer_dynamic
// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: iree_linalg_ext.unpack
+// CHECK: tensor.unpack
// CHECK-SAME: lowering_config = #[[CONFIG]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
index 6f91452..07fae72 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/materialize_x86_64_launch_configuration.mlir
@@ -1450,7 +1450,7 @@
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x48x16x1xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [20, 40], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<20x40xf32>> -> tensor<20x40xf32>
%3 = tensor.empty() : tensor<2x48x16x1xf32>
- %4 = iree_linalg_ext.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %3 : (tensor<20x40xf32> tensor<2x48x16x1xf32>) -> tensor<2x48x16x1xf32>
+ %4 = tensor.pack %2 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %3 : tensor<20x40xf32> -> tensor<2x48x16x1xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0, 0], sizes = [2, 48, 16, 1], strides = [1, 1, 1, 1] : tensor<2x48x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x48x16x1xf32>>
return
}
@@ -1461,7 +1461,7 @@
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDataTiling>
// CHECK: hal.executable.export public @pack
// CHECK-SAME: translation_info = #[[TRANSLATION]]
-// CHECK: iree_linalg_ext.pack
+// CHECK: tensor.pack
// CHECK-SAME: lowering_config = #[[CONFIG]]
// -----
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/BUILD b/compiler/src/iree/compiler/Codegen/VMVX/BUILD
index f585bd1..800c519 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/BUILD
+++ b/compiler/src/iree/compiler/Codegen/VMVX/BUILD
@@ -49,6 +49,7 @@
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
index 776cabc..b4faa02 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/VMVX/CMakeLists.txt
@@ -34,6 +34,7 @@
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRPass
+ MLIRTensorTransforms
MLIRTransforms
iree::builtins::ukernel::exported_bits
iree::compiler::Codegen::Common
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp
index 0af5408..2c25454 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/VMVX/VMVXMaterializeEncodingPass.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -53,9 +54,10 @@
struct VMVXMaterializeEncodingPass
: public VMVXMaterializeEncodingBase<VMVXMaterializeEncodingPass> {
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<arith::ArithDialect, AffineDialect, IREE::Flow::FlowDialect,
- IREE::LinalgExt::IREELinalgExtDialect,
- IREE::VMVX::VMVXDialect>();
+ registry
+ .insert<arith::ArithDialect, AffineDialect, tensor::TensorDialect,
+ IREE::Flow::FlowDialect, IREE::LinalgExt::IREELinalgExtDialect,
+ IREE::VMVX::VMVXDialect>();
}
void runOnOperation() override;
};
@@ -101,7 +103,7 @@
// dims ops.
{
RewritePatternSet patterns(context);
- populateFoldIntoPackAndUnpackOpsPatterns(patterns);
+ tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(operation, std::move(patterns)))) {
operation.emitOpError("folding patterns failed");
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 90977cb..37f2a24 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -47,10 +48,10 @@
if (failed(materializeEncodingInfo)) {
return tensorType;
}
- return PackOp::getPackedType(tensorType,
- materializeEncodingInfo->innerTileSizes,
- materializeEncodingInfo->innerDimsPos,
- materializeEncodingInfo->outerDimsPerm)
+ return tensor::PackOp::inferPackedType(
+ tensorType, materializeEncodingInfo->innerTileSizes,
+ materializeEncodingInfo->innerDimsPos,
+ materializeEncodingInfo->outerDimsPerm)
.cast<RankedTensorType>();
}
@@ -116,7 +117,7 @@
/// Utility method to convert from `set_encoding` op to `pack` operation.
/// For now this takes a `paddingValue` as input. The source is also taken
/// as input so that these could be used with `OpConversionPatterns`.
-static FailureOr<PackOp> lowerSetEncodingOpToPackOp(
+static FailureOr<tensor::PackOp> lowerSetEncodingOpToPackOp(
RewriterBase &rewriter, SetEncodingOp encodingOp, Value source,
MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
@@ -139,14 +140,14 @@
Optional<TensorEncoding> encoding = getEncoding(resultType);
if (!encoding)
return failure();
- SmallVector<OpFoldResult> resultDims =
- PackOp::getResultShape(rewriter, loc, sourceDims, *innerTileSizesOfr,
- materializeEncodingInfo->innerDimsPos,
- materializeEncodingInfo->outerDimsPerm);
+ SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
+ rewriter, loc, sourceDims, *innerTileSizesOfr,
+ materializeEncodingInfo->innerDimsPos,
+ materializeEncodingInfo->outerDimsPerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
resultType.getElementType());
Optional<Value> paddingValue = getPaddingValue(source);
- auto packOp = rewriter.create<PackOp>(
+ auto packOp = rewriter.create<tensor::PackOp>(
loc, source, emptyOp, materializeEncodingInfo->innerDimsPos,
*innerTileSizesOfr, paddingValue, materializeEncodingInfo->outerDimsPerm);
// As we rewrite the SetEncoding and its old result tensor, which used to hold
@@ -165,7 +166,7 @@
/// Utility method to convert from `set_encoding` op to `pack` operation.
/// The source is taken as input so that these could be used with
/// `OpConversionPatterns`.
-static FailureOr<UnPackOp> lowerUnsetEncodingToUnpackOp(
+static FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
RewriterBase &rewriter, UnsetEncodingOp encodingOp, Value packedValue,
MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
@@ -188,7 +189,7 @@
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
- return rewriter.create<UnPackOp>(
+ return rewriter.create<tensor::UnPackOp>(
loc, packedValue, emptyOp, materializeEncodingInfo->innerDimsPos,
*innerTileSizesOfr, materializeEncodingInfo->outerDimsPerm);
}
@@ -303,7 +304,7 @@
if (failed(packOp))
return rewriter.notifyMatchFailure(encodingOp,
"failed to convert to pack op");
- rewriter.replaceOp(encodingOp, packOp->getResults());
+ rewriter.replaceOp(encodingOp, packOp->getResult());
return success();
}
};
@@ -327,7 +328,7 @@
if (failed(unpackOp))
return rewriter.notifyMatchFailure(encodingOp,
"failed to convert to unpack op");
- rewriter.replaceOp(encodingOp, unpackOp->getResults());
+ rewriter.replaceOp(encodingOp, unpackOp->getResult());
return success();
}
};
@@ -402,10 +403,11 @@
return signalPassFailure();
}
- // Add patterns to fold pack/unpack ops with pad/extract_slice ops.
+ // Add patterns to fold tensor.pack/unpack ops with tensor.pad/extract_slice
+ // ops.
{
RewritePatternSet patterns(context);
- populateFoldIntoPackAndUnpackOpsPatterns(patterns);
+ tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
index 0a069d4..9dec445 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/materialize_encoding.mlir
@@ -16,10 +16,10 @@
// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
// CHECK: %[[PACK_DEST:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x4xf32>
-// CHECK: %[[PACK:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[PACK_DEST]]
// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xf32>
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %[[UNPACK_DEST]]
// CHECK: return %[[UNPACK]]
// -----
@@ -30,9 +30,9 @@
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_rhs(
-// CHECK: linalg_ext.pack
+// CHECK: tensor.pack
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
-// CHECK: linalg_ext.unpack %{{.+}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
+// CHECK: tensor.unpack %{{.+}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 4]
// -----
@@ -42,9 +42,9 @@
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @pack_unpack_gemm_result(
-// CHECK: linalg_ext.pack
+// CHECK: tensor.pack
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 8]
-// CHECK: linalg_ext.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 8]
+// CHECK: tensor.unpack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [8, 8]
// -----
@@ -77,21 +77,21 @@
// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.0
// CHECK: %[[INIT_LHS:.+]] = tensor.empty() : tensor<13x63x8x4xf32>
-// CHECK: %[[PACK_LHS:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]] padding_value(%[[CST]] : f32)
// CHECK-SAME: into %[[INIT_LHS]]
// CHECK: %[[INIT_RHS:.+]] = tensor.empty() : tensor<63x63x8x4xf32>
-// CHECK: %[[PACK_RHS:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG1]] padding_value(%[[CST]] : f32)
// CHECK-SAME: into %[[INIT_RHS]]
// CHECK: %[[INIT_RESULT:.+]] = tensor.empty() : tensor<13x63x8x8xf32>
-// CHECK: %[[PACK_RESULT:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
// CHECK-SAME: %[[ARG2]] padding_value(%[[CST]] : f32)
// CHECK-SAME: into %[[INIT_RESULT]]
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[PACK_RESULT]] :
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
// CHECK: return %[[UNPACK]]
// -----
@@ -111,16 +111,16 @@
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK: %[[PACK_LHS:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK_LHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]]
-// CHECK: %[[PACK_RHS:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG1]]
-// CHECK: %[[PACK_RESULT:.+]] = iree_linalg_ext.pack
+// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
// CHECK-SAME: %[[ARG2]]
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[PACK_RESULT]] :
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
// CHECK: return %[[UNPACK]]
// -----
@@ -152,8 +152,8 @@
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[OUT_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
// CHECK-DAG: %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
-// CHECK-DAG: %[[PACK_LHS:.+]] = iree_linalg_ext.pack {{.*}}%[[ARG0]]
-// CHECK: %[[PACK_RHS:.+]] = iree_linalg_ext.pack
+// CHECK-DAG: %[[PACK_LHS:.+]] = tensor.pack {{.*}}%[[ARG0]]
+// CHECK: %[[PACK_RHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG1]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x8x8xf32>
// CHECK: %[[FILL:.+]] = linalg.fill
@@ -161,5 +161,5 @@
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[FILL]] :
-// CHECK: %[[UNPACK:.+]] = iree_linalg_ext.unpack %[[MMT4D]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]]
// CHECK: return %[[UNPACK]]