[Codegen][CPU] Add MaterializeEncoding conversions for parallel generic ops (#18071)
This PR is needed after https://github.com/iree-org/iree/pull/18063, since
broadcasting ops can now carry encoded tensors. It extends the
materialization pattern for `linalg.generic` ops to handle any fully
parallel `linalg.generic` op (previously, only unary generic ops with
identity indexing maps were supported). It also extends the type converter
to handle `bcast_map` by using the actual encoded tensor type instead of
the `original_type`.
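
For illustration, a condensed version of the @broadcast_K lit test added
below (the encoding attributes and the region body are elided here): a
fully parallel broadcasting generic such as

  %14 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%8 : tensor<2x64xf32, #encoding_bcast>)
      outs(%13 : tensor<2x128x64xf32, #encoding>)

now materializes into a generic over the packed tensors, with the input's
indexing map rewritten into the packed iteration domain:

  affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
      ins(... : tensor<2x4x16xf32>) outs(... : tensor<2x4x128x16x1xf32>)

where the two extra iteration dims come from the inner tiles of the packed
output.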
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
index e9c07d2..6dd1983 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
@@ -2789,3 +2789,226 @@
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[BATCH_MMT4D]] {{.*}} : tensor<32x1x128x1x32xi32> into tensor<32x128x32xi32>
// CHECK-DAG: %[[UNPACK_DEST:.+]] = tensor.empty() : tensor<4096x32xi32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[COLLAPSE]] outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [32] into %[[UNPACK_DEST]] : tensor<32x128x32xi32> -> tensor<4096x32xi32>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>,
+ #hal.descriptor_set.binding<2, storage_buffer>,
+ #hal.descriptor_set.binding<3, storage_buffer>
+ ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @dequantization() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+ %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ %7 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>> -> tensor<2x128x64xi8, #encoding>
+ %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+ %9 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+ %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+ %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2x128x64xi8, #encoding>, tensor<2x64xf32, #encoding_bcast>, tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+ ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
+ %21 = arith.extui %in : i8 to i32
+ %22 = arith.uitofp %21 : i32 to f32
+ %23 = arith.subf %22, %in_1 : f32
+ %24 = arith.mulf %23, %in_0 : f32
+ linalg.yield %24 : f32
+ } -> tensor<2x128x64xf32, #encoding>
+ flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-LABEL: func.func @dequantization()
+// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>>
+// CHECK-DAG: %[[LHS_SCALES_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
+// CHECK-DAG: %[[LHS_ZPS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(2) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
+// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(3) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>> -> tensor<2x4x128x16x1xi8>
+// CHECK-DAG: %[[LHS_SCALES:.+]] = flow.dispatch.tensor.load %[[LHS_SCALES_BINDING]], offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
+// CHECK-DAG: %[[LHS_ZPS:.+]] = flow.dispatch.tensor.load %[[LHS_ZPS_BINDING]], offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
+// CHECK-DAG: %[[EMPTY_LHS:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+// CHECK-DAG: %[[LHS_DEQUANT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]], #[[$MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[LHS]], %[[LHS_SCALES]], %[[LHS_ZPS]] : tensor<2x4x128x16x1xi8>, tensor<2x4x16xf32>, tensor<2x4x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY_LHS]] : tensor<2x4x128x16x1xf32>)
+// CHECK: arith.extui
+// CHECK: arith.uitofp
+// CHECK: arith.subf
+// CHECK: arith.mulf
+// CHECK: flow.dispatch.tensor.store %[[LHS_DEQUANT]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_batch() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x64xf32, #encoding_bcast>>
+ %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x64xf32, #encoding_bcast>> -> tensor<128x64xf32, #encoding_bcast>
+ %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+ %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<128x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x128x64xf32, #encoding>
+ flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_batch()
+// CHECK-DAG: %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<4x128x16x1xf32>>
+// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0, 0], sizes = [4, 128, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x128x16x1xf32>> -> tensor<4x128x16x1xf32>
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+// CHECK-DAG: %[[BROADCAST:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<4x128x16x1xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4x128x16x1xf32>)
+// CHECK: flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_M() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x128xf32, #encoding_bcast>>
+ %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128xf32, #encoding_bcast>> -> tensor<2x128xf32, #encoding_bcast>
+ %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+ %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<2x128xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x128x64xf32, #encoding>
+ flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_M()
+// CHECK-DAG: %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x128x1xf32>>
+// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0], sizes = [2, 128, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128x1xf32>> -> tensor<2x128x1xf32>
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+// CHECK-DAG: %[[BROADCAST:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<2x128x1xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4x128x16x1xf32>)
+// CHECK: flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_N() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+ %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+ %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+ %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x128x64xf32, #encoding>
+ flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_N()
+// CHECK-DAG: %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x64x1xf32>>
+// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x8x64x16x1xf32>>
+// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0], sizes = [2, 64, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64x1xf32>> -> tensor<2x64x1xf32>
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x16x1xf32>
+// CHECK-DAG: %[[BROADCAST:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<2x64x1xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x8x64x16x1xf32>)
+// CHECK: flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 8, 64, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x8x64x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x8x64x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_K() attributes {
+ hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+ %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+ %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+ %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<2x128x64xf32, #encoding>
+ flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+ return
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_K()
+// CHECK-DAG: %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
+// CHECK-DAG: %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+// CHECK-DAG: %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+// CHECK-DAG: %[[BROADCAST:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INPUT]] : tensor<2x4x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x4x128x16x1xf32>)
+// CHECK: flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index 9d60646..ce42de1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -113,9 +113,7 @@
return dropEncoding(tensorType);
}
return cast<RankedTensorType>(tensor::PackOp::inferPackedType(
- getOriginalTypeWithEncoding(maybeTransposedTensorType)
- .clone(tensorType.getElementType()),
- materializeEncodingInfo->innerTileSizes,
+ maybeTransposedTensorType, materializeEncodingInfo->innerTileSizes,
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm));
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index c3fec88..99fb0d3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -22,9 +22,9 @@
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
-#include "mlir/Support/LLVM.h"
namespace mlir::iree_compiler {
@@ -294,8 +294,8 @@
}
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
- materializeEncodingFn(getOriginalTypeWithEncoding(
- cast<RankedTensorType>(linalgOp->getResultTypes()[0])));
+ materializeEncodingFn(
+ cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
Operation *result;
if (failed(materializeEncodingInfo)) {
@@ -345,21 +345,19 @@
MaterializeEncodingFn materializeEncodingFn,
MaterializeEncodingValueFn materializeEncodingValueFn) {
auto emptyType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
- auto resultType =
- getOriginalTypeWithEncoding(emptyType).clone(emptyType.getElementType());
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
- materializeEncodingFn(resultType);
+ materializeEncodingFn(emptyType);
Location loc = emptyOp.getLoc();
if (failed(materializeEncodingInfo)) {
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
- loc, emptyOp.getMixedSizes(), resultType.getElementType());
+ loc, emptyOp.getMixedSizes(), emptyType.getElementType());
return newEmptyOp;
}
if (isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
transposeInPlace(*materializeEncodingInfo);
}
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
- getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
+ getInnerTileSizesOfr(rewriter, loc, emptyType, *materializeEncodingInfo,
materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
@@ -372,16 +370,110 @@
materializeEncodingInfo->innerDimsPos,
materializeEncodingInfo->outerDimsPerm);
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
- loc, newShape, resultType.getElementType());
+ loc, newShape, emptyType.getElementType());
return newEmptyOp;
}
+/// Converts a linalg::GenericOp with encoded inputs into the packed domain.
+/// The `genericOp` must have all parallel iterator types and a single output
+/// with an identity indexing map.
+static FailureOr<Operation *>
+lowerGenericOpWithEncoding(RewriterBase &rewriter, linalg::GenericOp genericOp,
+ ValueRange convertedInputOperands,
+ ValueRange convertedOutputOperands,
+ MaterializeEncodingFn materializeEncodingFn) {
+ if (!genericOp.hasPureTensorSemantics()) {
+ return failure();
+ }
+ if (genericOp.getNumReductionLoops() != 0) {
+ return rewriter.notifyMatchFailure(genericOp, "Loops are not all parallel");
+ }
+ if (genericOp.getNumDpsInits() != 1) {
+ return rewriter.notifyMatchFailure(genericOp, "Not only 1 init operand");
+ }
+ OpOperand *outputOperand = genericOp.getDpsInitOperand(0);
+ AffineMap outputMap = genericOp.getMatchingIndexingMap(outputOperand);
+ if (!outputMap.isIdentity()) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "Output indexing map is not identity");
+ }
+ FailureOr<MaterializeEncodingInfo> outMaterializeEncodingInfo =
+ materializeEncodingFn(
+ cast<RankedTensorType>(outputOperand->get().getType()));
+ if (failed(outMaterializeEncodingInfo)) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "MaterializeEncodingInfo failed for output");
+ }
+
+ auto convertedResultType =
+ cast<RankedTensorType>(convertedOutputOperands[0].getType());
+ SmallVector<utils::IteratorType> iteratorTypes(convertedResultType.getRank(),
+ utils::IteratorType::parallel);
+ // Compute the new indexing maps for the packed layout. This assumes that
+ // the output map is identity, and that all iterator types are parallel.
+ SmallVector<int64_t> outInnerDimsPos =
+ outMaterializeEncodingInfo->innerDimsPos;
+ SmallVector<int64_t> outInverseOuterDimsPerm =
+ invertPermutationVector(outMaterializeEncodingInfo->outerDimsPerm);
+ SmallVector<AffineMap> packedIndexingMaps;
+ for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+ FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+ materializeEncodingFn(
+ cast<RankedTensorType>(inputOperand->get().getType()));
+ if (failed(materializeEncodingInfo)) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "MaterializeEncodingInfo failed for input");
+ }
+ SmallVector<int64_t> innerDimsPos = materializeEncodingInfo->innerDimsPos;
+ SmallVector<int64_t> outerDimsPerm = materializeEncodingInfo->outerDimsPerm;
+ AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
+    // Permute the input map results by the input's outer_dims_perm, then map
+    // each dim through the inverse of the output's outer_dims_perm.
+ SmallVector<int64_t> packedResultDims = llvm::map_to_vector(
+ applyPermutation(inputMap.getResults(), outerDimsPerm),
+ [&](AffineExpr expr) {
+ auto dimExpr = cast<AffineDimExpr>(expr);
+ return outInverseOuterDimsPerm[dimExpr.getPosition()];
+ });
+ // Add new dims for the inner tiles, taking the dim position from the
+ // corresponding inner tile of the init operand.
+  for (int64_t pos : innerDimsPos) {
+ auto dimPos = cast<AffineDimExpr>(inputMap.getResult(pos)).getPosition();
+ for (auto [tileIdx, outDim] : llvm::enumerate(outInnerDimsPos)) {
+ if (dimPos == outDim) {
+ packedResultDims.push_back(outputMap.getNumDims() + tileIdx);
+ }
+ }
+ }
+ // Create the packed indexing map.
+ SmallVector<AffineExpr> packedResultExprs =
+ llvm::map_to_vector(packedResultDims, [&](int64_t dim) {
+ return rewriter.getAffineDimExpr(dim);
+ });
+ auto packedInputMap = AffineMap::get(
+ /*dimCount=*/iteratorTypes.size(), /*symbolCount=*/0, packedResultExprs,
+ rewriter.getContext());
+ packedIndexingMaps.push_back(packedInputMap);
+ }
+ // Create the new packed identity map for the output.
+ packedIndexingMaps.push_back(
+ rewriter.getMultiDimIdentityMap(convertedResultType.getRank()));
+ auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
+ genericOp.getLoc(), convertedResultType, convertedInputOperands,
+ convertedOutputOperands, packedIndexingMaps, iteratorTypes,
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ rewriter.inlineRegionBefore(genericOp.getRegion(),
+ materializedGenericOp.getRegion(),
+ materializedGenericOp.getRegion().begin());
+ return materializedGenericOp.getOperation();
+}
+
/// Utility method to convert from a linalg::LinalgOp on `tensor` types with
/// encodings to a linalg::LinalgOp on the materialized type. The current
/// supported op types are:
/// - linalg::LinalgOp that `isaContractionOpInterface`
/// - linalg::FillOp
-/// - element-wise linalg::GenericOp with single input and output
+/// - linalg::GenericOp with parallel iterators and a single output
static FailureOr<Operation *> lowerOpWithEncoding(
RewriterBase &rewriter, linalg::LinalgOp linalgOp,
ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
@@ -406,36 +498,12 @@
convertedInputOperands, convertedOutputOperands);
return materializedFillOp;
})
- .Case<linalg::GenericOp>([&](linalg::GenericOp genericOp)
- -> FailureOr<Operation *> {
- if (!genericOp.hasPureTensorSemantics() || !isElementwise(genericOp) ||
- genericOp.getNumDpsInputs() != 1 ||
- genericOp.getNumDpsInits() != 1) {
- return rewriter.notifyMatchFailure(
- genericOp, "linalg.generic op is not elementwise "
- "with single input and single output");
- }
- if (!llvm::all_of(genericOp.getIndexingMapsArray(),
- [](AffineMap m) { return m.isIdentity(); })) {
- return rewriter.notifyMatchFailure(
- genericOp, "indexing maps are not all identity maps");
- }
- auto convertedResultType =
- cast<RankedTensorType>(convertedOutputOperands[0].getType());
- SmallVector<AffineMap> maps(
- 2, AffineMap::getMultiDimIdentityMap(convertedResultType.getRank(),
- rewriter.getContext()));
- SmallVector<utils::IteratorType> iteratorTypes(
- convertedResultType.getRank(), utils::IteratorType::parallel);
- auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
- genericOp.getLoc(), convertedResultType, convertedInputOperands,
- convertedOutputOperands, maps, iteratorTypes,
- /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
- rewriter.inlineRegionBefore(genericOp.getRegion(),
- materializedGenericOp.getRegion(),
- materializedGenericOp.getRegion().begin());
- return materializedGenericOp.getOperation();
- })
+ .Case<linalg::GenericOp>(
+ [&](linalg::GenericOp genericOp) -> FailureOr<Operation *> {
+ return lowerGenericOpWithEncoding(
+ rewriter, genericOp, convertedInputOperands,
+ convertedOutputOperands, materializeEncodingFn);
+ })
.Default([](Operation *op) { return failure(); });
}
@@ -454,9 +522,6 @@
return failure();
}
- RankedTensorType originalTensorType =
- getOriginalTypeWithEncoding(boundTensorType);
-
MaterializeEncodingFn materializeEncodingFn =
typeConverter.getMaterializeEncodingFn();
FailureOr<MaterializeEncodingInfo> encodingInfo =
@@ -469,10 +534,9 @@
}
SmallVector<OpFoldResult> targetShape =
- getMixedValues(originalTensorType.getShape(), dynamicDims, builder);
- auto innerTileSizes =
- getInnerTileSizesOfr(builder, loc, originalTensorType, *encodingInfo,
- materializeEncodingValueFn);
+ getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
+ auto innerTileSizes = getInnerTileSizesOfr(
+ builder, loc, boundTensorType, *encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizes)) {
return failure();
}