[DataTiling] Add support for materializing elementwise ops. (#15507)
This PR adds a conversion case for elementwise `linalg.generic` ops so that
they can be materialized onto packed layouts like the other encoded ops. It
also fixes a bug where the element type of `original_type` was used for
conversion in some places: an encoded tensor's element type can differ from
that of its `original_type` (e.g., an `i32` tensor whose encoding records an
`i8` `original_type`), so the packed type must be inferred from the tensor's
own element type.
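
As a rough sketch of the new conversion (abbreviated from the test added
below; `#enc_i8` and `#enc_i32` stand in for the full LHS `BATCH_MATMUL`
encoding attributes, which differ only in element type, and the packed shapes
assume the tile sizes picked in the test):

```mlir
// Before materialization: an elementwise sign-extension over encoded tensors.
%ext = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%lhs : tensor<?x?x?xi8, #enc_i8>)
    outs(%init : tensor<?x?x?xi32, #enc_i32>) {
^bb0(%in: i8, %out: i32):
  %0 = arith.extsi %in : i8 to i32
  linalg.yield %0 : i32
} -> tensor<?x?x?xi32, #enc_i32>

// After materialization: the same scalar body over the packed layout, with
// identity indexing maps and all-parallel iterators at the packed rank.
%ext_packed = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>,
                     affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    ins(%lhs_packed : tensor<32x1x64x1x2xi8>)
    outs(%init_packed : tensor<32x1x64x1x2xi32>) {
^bb0(%in: i8, %out: i32):
  %0 = arith.extsi %in : i8 to i32
  linalg.yield %0 : i32
} -> tensor<32x1x64x1x2xi32>
```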
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
index b231108..3d00bb7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
@@ -1169,3 +1169,76 @@
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 16, 16], strides = [1, 1, 1, 1]
+
+// -----
+
+func.func @extend_batch_vecmat(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
+} {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c128 = arith.constant 128 : index
+ %c11008 = arith.constant 11008 : index
+ %c0_i8 = arith.constant 0 : i8
+ %c0_i32 = arith.constant 0 : i32
+ %0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<32x1x128xi8>
+ %1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<32x128x11008xi8>
+ %padded = tensor.pad %0 low[0, 0, 0] high[%c0, %c0, %c0] {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %c0_i8 : i8
+ } : tensor<32x1x128xi8> to tensor<?x?x?xi8>
+ %4 = iree_linalg_ext.set_encoding %padded : tensor<?x?x?xi8> -> tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>
+ %5 = tensor.empty(%c32, %c1, %c128) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>
+ %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>) outs(%5 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>) {
+ ^bb0(%in: i8, %out: i32):
+ %17 = arith.extsi %in : i8 to i32
+ linalg.yield %17 : i32
+ } -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>
+ %padded_0 = tensor.pad %1 low[0, 0, 0] high[%c0, %c0, %c0] {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %c0_i8 : i8
+ } : tensor<32x128x11008xi8> to tensor<?x?x?xi8>
+ %7 = iree_linalg_ext.set_encoding %padded_0 : tensor<?x?x?xi8> -> tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>
+ %8 = tensor.empty(%c32, %c128, %c11008) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>
+ %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7 : tensor<?x?x?xi8, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>) outs(%8 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>) {
+ ^bb0(%in: i8, %out: i32):
+ %17 = arith.extsi %in : i8 to i32
+ linalg.yield %17 : i32
+ } -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>
+ %10 = tensor.empty(%c32, %c1, %c11008) : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>
+ %11 = linalg.fill ins(%c0_i32 : i32) outs(%10 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>) -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>
+ %12 = linalg.batch_matmul ins(%6, %9 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = LHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x128xi8>>>, tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RHS, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x128x11008xi8>>>) outs(%11 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>) -> tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>>
+ %13 = iree_linalg_ext.unset_encoding %12 : tensor<?x?x?xi32, #iree_linalg_ext.encoding<user = BATCH_MATMUL, role = RESULT, element_types = [i8, i8, i32], matmul_narrow_M = 1 : index, original_type = tensor<32x1x11008xi32>>> -> tensor<?x?x?xi32>
+ %extracted_slice = tensor.extract_slice %13[0, 0, 0] [32, 1, 11008] [1, 1, 1] : tensor<?x?x?xi32> to tensor<32x1x11008xi32>
+ %16 = hal.tensor.export %extracted_slice "output 0" : tensor<32x1x11008xi32> -> !hal.buffer_view
+ return %16 : !hal.buffer_view
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func @extend_batch_vecmat(%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -> !hal.buffer_view attributes
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK: %[[LHS:.+]] = hal.tensor.import %[[ARG0]] "input 0" : !hal.buffer_view -> tensor<32x1x128xi8>
+// CHECK: %[[RHS:.+]] = hal.tensor.import %[[ARG1]] "input 1" : !hal.buffer_view -> tensor<32x128x11008xi8>
+// CHECK: %[[INIT_LHS_PACK:.+]] = tensor.empty() : tensor<32x1x64x1x2xi8>
+// CHECK: %[[LHS_PACK:.+]] = tensor.pack %[[LHS]] padding_value(%[[C0_I8]] : i8) inner_dims_pos = [1, 2] inner_tiles = [1, 2] into %[[INIT_LHS_PACK]] : tensor<32x1x128xi8> -> tensor<32x1x64x1x2xi8>
+// CHECK: %[[INIT_LHS_EXT:.+]] = tensor.empty() : tensor<32x1x64x1x2xi32>
+// CHECK: %[[LHS_EXT:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[LHS_PACK]] : tensor<32x1x64x1x2xi8>) outs(%[[INIT_LHS_EXT]] : tensor<32x1x64x1x2xi32>) {
+// CHECK-NEXT: ^bb0(%[[LHS_EXT_ARG_IN:.+]]: i8, %[[LHS_EXT_ARG_OUT:.+]]: i32):
+// CHECK-NEXT: %[[LHS_EXT_OP:.+]] = arith.extsi %[[LHS_EXT_ARG_IN]] : i8 to i32
+// CHECK-NEXT: linalg.yield %[[LHS_EXT_OP]] : i32
+// CHECK: %[[INIT_RHS_PACK:.+]] = tensor.empty() : tensor<32x688x64x16x2xi8>
+// CHECK: %[[RHS_PACK:.+]] = tensor.pack %[[RHS]] padding_value(%[[C0_I8]] : i8) outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %[[INIT_RHS_PACK]] : tensor<32x128x11008xi8> -> tensor<32x688x64x16x2xi8>
+// CHECK: %[[INIT_RHS_EXT:.+]] = tensor.empty() : tensor<32x688x64x16x2xi32>
+// CHECK: %[[RHS_EXT:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[RHS_PACK]] : tensor<32x688x64x16x2xi8>) outs(%[[INIT_RHS_EXT]] : tensor<32x688x64x16x2xi32>) {
+// CHECK-NEXT: ^bb0(%[[RHS_EXT_ARG_IN:.+]]: i8, %[[RHS_EXT_ARG_OUT:.+]]: i32):
+// CHECK-NEXT: %[[RHS_EXT_OP:.+]] = arith.extsi %[[RHS_EXT_ARG_IN]] : i8 to i32
+// CHECK-NEXT: linalg.yield %[[RHS_EXT_OP]] : i32
+// CHECK: %[[INIT_FILL:.+]] = tensor.empty() : tensor<32x1x688x1x16xi32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0_I32]] : i32) outs(%[[INIT_FILL]] : tensor<32x1x688x1x16xi32>) -> tensor<32x1x688x1x16xi32>
+// CHECK: %[[MMT4D:.+]] = linalg.batch_mmt4d ins(%[[LHS_EXT]], %[[RHS_EXT]] : tensor<32x1x64x1x2xi32>, tensor<32x688x64x16x2xi32>) outs(%[[FILL]] : tensor<32x1x688x1x16xi32>) -> tensor<32x1x688x1x16xi32>
+// CHECK: %[[INIT_UNPACK:.+]] = tensor.empty() : tensor<32x1x11008xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[MMT4D]] inner_dims_pos = [1, 2] inner_tiles = [1, 16] into %[[INIT_UNPACK]] : tensor<32x1x688x1x16xi32> -> tensor<32x1x11008xi32>
+// CHECK: %[[EXPORT:.+]] = hal.tensor.export %[[UNPACK]] "output 0" : tensor<32x1x11008xi32> -> !hal.buffer_view
+// CHECK: return %[[EXPORT]] : !hal.buffer_view
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
index 3ed53a5..a73e2d3 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/MaterializeEncoding.cpp
@@ -66,7 +66,8 @@
return dropEncoding(tensorType);
}
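+  // Note: `original_type` records the pre-encoding shape, but its element
+  // type may differ from that of the tensor being materialized, so re-apply
+  // the tensor's own element type via clone().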
return tensor::PackOp::inferPackedType(
- getOriginalTypeWithEncoding(tensorType),
+ getOriginalTypeWithEncoding(tensorType)
+ .clone(tensorType.getElementType()),
materializeEncodingInfo->innerTileSizes,
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm)
@@ -361,8 +362,9 @@
ValueRange convertedOperands,
MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
- auto resultType = getOriginalTypeWithEncoding(
- emptyOp->getResultTypes()[0].cast<RankedTensorType>());
+ auto emptyType = emptyOp->getResultTypes()[0].cast<RankedTensorType>();
+ auto resultType =
+ getOriginalTypeWithEncoding(emptyType).clone(emptyType.getElementType());
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
materializeEncodingFn(resultType);
Location loc = emptyOp.getLoc();
@@ -391,6 +393,41 @@
return newEmptyOp;
}
+/// Utility method to convert a `linalg.generic` op on tensors with encodings
+/// into the equivalent `linalg.generic` op on the materialized types.
+static FailureOr<Operation *>
+lowerOpWithEncoding(RewriterBase &rewriter, linalg::GenericOp genericOp,
+ ValueRange convertedInputOperands,
+ ValueRange convertedOutputOperands, MaterializeEncodingFn,
+ MaterializeEncodingValueFn) {
+ if (!genericOp.hasTensorSemantics() || !isElementwise(genericOp) ||
+ genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "linalg.generic op is not elementwise "
+ "with single input and single output");
+ }
+ if (!llvm::all_of(genericOp.getIndexingMapsArray(),
+ [](AffineMap m) { return m.isIdentity(); })) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "indexing maps are not all identity maps");
+ }
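+  // The converted operands already carry the materialized (packed) types, so
+  // rebuild the op with identity maps and all-parallel iterators at the
+  // packed rank; the original scalar body is moved over unchanged.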
+ auto convertedResultType =
+ convertedOutputOperands[0].getType().cast<RankedTensorType>();
+ SmallVector<AffineMap> maps(
+ 2, AffineMap::getMultiDimIdentityMap(convertedResultType.getRank(),
+ rewriter.getContext()));
+ SmallVector<utils::IteratorType> iteratorTypes(convertedResultType.getRank(),
+ utils::IteratorType::parallel);
+ auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
+ genericOp.getLoc(), convertedResultType, convertedInputOperands,
+ convertedOutputOperands, maps, iteratorTypes,
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ rewriter.inlineRegionBefore(genericOp.getRegion(),
+ materializedGenericOp.getRegion(),
+ materializedGenericOp.getRegion().begin());
+ return materializedGenericOp.getOperation();
+}
+
namespace {
//===---------------------------------------------------------------------===//
// Patterns to lower ops with encodings. These are written as
@@ -491,7 +528,7 @@
MaterializeEncodingFn materializeEncodingFn;
};
-/// Generic pattern to convert operaiton that is in Destination Passing Style.
+/// Generic pattern to convert operation that is in Destination Passing Style.
template <typename OpTy>
struct MaterializeDPSOperation : public OpMaterializeEncodingPattern<OpTy> {
using OpMaterializeEncodingPattern<OpTy>::OpMaterializeEncodingPattern;
@@ -633,6 +670,7 @@
patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
MaterializeDPSOperation<linalg::MatmulOp>,
MaterializeDPSOperation<linalg::BatchMatmulOp>,
+ MaterializeDPSOperation<linalg::GenericOp>,
MaterializeOperation<tensor::EmptyOp>,
SetEncodingOpToPackOpConversion,
UnsetEncodingOpToUnPackOpConversion>(