[Codegen][CPU] Add MaterializeEncoding conversions for parallel generic ops (#18071)

This PR is needed after https://github.com/iree-org/iree/pull/18063,
because broadcasting ops can now carry encoded tensors. It extends the
materialization pattern for linalg.generic ops to handle any fully
parallel linalg.generic op (previously only unary generic ops with
identity indexing maps were supported), and it extends the type
converter to handle `bcast_map` by materializing from the actual
encoded tensor type rather than the `original_type`.
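For illustration, here is a condensed before/after sketch of the broadcast
case, adapted from the `broadcast_batch` test added below. The SSA names
(`%in`, `%init`, `%packed_in`, `%packed_init`) are placeholders, and the
concrete packed shapes and inner tile sizes depend on the encoding chosen
for the target:

```mlir
// Before materialization: the broadcast operand carries #encoding_bcast,
// whose bcast_map (d0, d1, d2) -> (d1, d2) records that the operand is
// missing the batch dimension d0 of the result.
%b = linalg.generic
       {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
        iterator_types = ["parallel", "parallel", "parallel"]}
       ins(%in : tensor<128x64xf32, #encoding_bcast>)
       outs(%init : tensor<2x128x64xf32, #encoding>) {
  ^bb0(%arg0: f32, %arg1: f32):
    linalg.yield %arg0 : f32
} -> tensor<2x128x64xf32, #encoding>

// After materialization: each tensor is packed according to its own
// encoding (here into 16x1 inner tiles), and the indexing maps are
// recomputed in the packed iteration domain.
%p = linalg.generic
       {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>,
                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
       ins(%packed_in : tensor<4x128x16x1xf32>)
       outs(%packed_init : tensor<2x4x128x16x1xf32>) {
  ^bb0(%arg0: f32, %arg1: f32):
    linalg.yield %arg0 : f32
} -> tensor<2x4x128x16x1xf32>
```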

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
index e9c07d2..6dd1983 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
@@ -2789,3 +2789,226 @@
 //       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[BATCH_MMT4D]] {{.*}} : tensor<32x1x128x1x32xi32> into tensor<32x128x32xi32>
 //   CHECK-DAG:   %[[UNPACK_DEST:.+]] = tensor.empty() : tensor<4096x32xi32>
 //       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[COLLAPSE]] outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [32] into %[[UNPACK_DEST]] : tensor<32x128x32xi32> -> tensor<4096x32xi32>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>,
+    #hal.descriptor_set.binding<2, storage_buffer>,
+    #hal.descriptor_set.binding<3, storage_buffer>
+  ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @dequantization() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  %7 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128x64xi8, #encoding>> -> tensor<2x128x64xi8, #encoding>
+  %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %9 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%7, %8, %9 : tensor<2x128x64xi8, #encoding>, tensor<2x64xf32, #encoding_bcast>, tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+  ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
+    %21 = arith.extui %in : i8 to i32
+    %22 = arith.uitofp %21 : i32 to f32
+    %23 = arith.subf %22, %in_1 : f32
+    %24 = arith.mulf %23, %in_0 : f32
+    linalg.yield %24 : f32
+  } -> tensor<2x128x64xf32, #encoding>
+  flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CHECK-LABEL: func.func @dequantization()
+//   CHECK-DAG:   %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>>
+//   CHECK-DAG:   %[[LHS_SCALES_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
+//   CHECK-DAG:   %[[LHS_ZPS_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(2) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
+//   CHECK-DAG:   %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(3) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+//   CHECK-DAG:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x128x16x1xi8>> -> tensor<2x4x128x16x1xi8>
+//   CHECK-DAG:   %[[LHS_SCALES:.+]] = flow.dispatch.tensor.load %[[LHS_SCALES_BINDING]], offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
+//   CHECK-DAG:   %[[LHS_ZPS:.+]] = flow.dispatch.tensor.load %[[LHS_ZPS_BINDING]], offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
+//   CHECK-DAG:   %[[EMPTY_LHS:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+//   CHECK-DAG:   %[[LHS_DEQUANT:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP1]], #[[$MAP]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[LHS]], %[[LHS_SCALES]], %[[LHS_ZPS]] : tensor<2x4x128x16x1xi8>, tensor<2x4x16xf32>, tensor<2x4x16xf32>)
+//  CHECK-SAME:       outs(%[[EMPTY_LHS]] : tensor<2x4x128x16x1xf32>)
+//       CHECK:     arith.extui
+//       CHECK:     arith.uitofp
+//       CHECK:     arith.subf
+//       CHECK:     arith.mulf
+//       CHECK:   flow.dispatch.tensor.store %[[LHS_DEQUANT]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_batch() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x64xf32, #encoding_bcast>>
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [128, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128x64xf32, #encoding_bcast>> -> tensor<128x64xf32, #encoding_bcast>
+  %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<128x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x128x64xf32, #encoding>
+  flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_batch()
+//   CHECK-DAG:   %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<4x128x16x1xf32>>
+//   CHECK-DAG:   %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+//   CHECK-DAG:   %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0, 0], sizes = [4, 128, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4x128x16x1xf32>> -> tensor<4x128x16x1xf32>
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+//   CHECK-DAG:   %[[BROADCAST:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[INPUT]] : tensor<4x128x16x1xf32>)
+//  CHECK-SAME:       outs(%[[EMPTY]] : tensor<2x4x128x16x1xf32>)
+//       CHECK:   flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_M() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x128xf32, #encoding_bcast>>
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128xf32, #encoding_bcast>> -> tensor<2x128xf32, #encoding_bcast>
+  %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<2x128xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x128x64xf32, #encoding>
+  flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_M()
+//   CHECK-DAG:   %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x128x1xf32>>
+//   CHECK-DAG:   %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+//   CHECK-DAG:   %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0], sizes = [2, 128, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x128x1xf32>> -> tensor<2x128x1xf32>
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+//   CHECK-DAG:   %[[BROADCAST:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[INPUT]] : tensor<2x128x1xf32>)
+//  CHECK-SAME:       outs(%[[EMPTY]] : tensor<2x4x128x16x1xf32>)
+//       CHECK:   flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_N() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x128x64xf32, #encoding>
+  flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_N()
+//   CHECK-DAG:   %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x64x1xf32>>
+//   CHECK-DAG:   %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x8x64x16x1xf32>>
+//   CHECK-DAG:   %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0], sizes = [2, 64, 1], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64x1xf32>> -> tensor<2x64x1xf32>
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x16x1xf32>
+//   CHECK-DAG:   %[[BROADCAST:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[INPUT]] : tensor<2x64x1xf32>)
+//  CHECK-SAME:       outs(%[[EMPTY]] : tensor<2x8x64x16x1xf32>)
+//       CHECK:   flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 8, 64, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x8x64x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x8x64x16x1xf32>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+#encoding_bcast = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<2x128x64xf32>, user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], bcast_map = affine_map<(d0, d1, d2) -> (d0, d2)>, round_dims_to = array<i64: 16, 16, 16>>
+func.func @broadcast_K() attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512f"}>
+} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>>
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  %8 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [2, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x64xf32, #encoding_bcast>> -> tensor<2x64xf32, #encoding_bcast>
+  %13 = tensor.empty() : tensor<2x128x64xf32, #encoding>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8 : tensor<2x64xf32, #encoding_bcast>) outs(%13 : tensor<2x128x64xf32, #encoding>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x128x64xf32, #encoding>
+  flow.dispatch.tensor.store %14, %6, offsets = [0, 0, 0], sizes = [2, 128, 64], strides = [1, 1, 1] : tensor<2x128x64xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<2x128x64xf32, #encoding>>
+  return
+}
+
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @broadcast_K()
+//   CHECK-DAG:   %[[INPUT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(0) {{.*}} : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>>
+//   CHECK-DAG:   %[[RESULT_BINDING:.+]] = hal.interface.binding.subspan {{.*}} binding(1) {{.*}} : !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
+//   CHECK-DAG:   %[[INPUT:.+]] = flow.dispatch.tensor.load %[[INPUT_BINDING]], offsets = [0, 0, 0], sizes = [2, 4, 16], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4x16xf32>> -> tensor<2x4x16xf32>
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<2x4x128x16x1xf32>
+//   CHECK-DAG:   %[[BROADCAST:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[INPUT]] : tensor<2x4x16xf32>)
+//  CHECK-SAME:       outs(%[[EMPTY]] : tensor<2x4x128x16x1xf32>)
+//       CHECK:   flow.dispatch.tensor.store %[[BROADCAST]], %[[RESULT_BINDING]], offsets = [0, 0, 0, 0, 0], sizes = [2, 4, 128, 16, 1], strides = [1, 1, 1, 1, 1] : tensor<2x4x128x16x1xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x4x128x16x1xf32>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index 9d60646..ce42de1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -113,9 +113,7 @@
     return dropEncoding(tensorType);
   }
   return cast<RankedTensorType>(tensor::PackOp::inferPackedType(
-      getOriginalTypeWithEncoding(maybeTransposedTensorType)
-          .clone(tensorType.getElementType()),
-      materializeEncodingInfo->innerTileSizes,
+      maybeTransposedTensorType, materializeEncodingInfo->innerTileSizes,
       materializeEncodingInfo->innerDimsPos,
       materializeEncodingInfo->outerDimsPerm));
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index c3fec88..99fb0d3 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -22,9 +22,9 @@
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
-#include "mlir/Support/LLVM.h"
 
 namespace mlir::iree_compiler {
 
@@ -294,8 +294,8 @@
   }
 
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
-      materializeEncodingFn(getOriginalTypeWithEncoding(
-          cast<RankedTensorType>(linalgOp->getResultTypes()[0])));
+      materializeEncodingFn(
+          cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
 
   Operation *result;
   if (failed(materializeEncodingInfo)) {
@@ -345,21 +345,19 @@
                     MaterializeEncodingFn materializeEncodingFn,
                     MaterializeEncodingValueFn materializeEncodingValueFn) {
   auto emptyType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
-  auto resultType =
-      getOriginalTypeWithEncoding(emptyType).clone(emptyType.getElementType());
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
-      materializeEncodingFn(resultType);
+      materializeEncodingFn(emptyType);
   Location loc = emptyOp.getLoc();
   if (failed(materializeEncodingInfo)) {
     Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
-        loc, emptyOp.getMixedSizes(), resultType.getElementType());
+        loc, emptyOp.getMixedSizes(), emptyType.getElementType());
     return newEmptyOp;
   }
   if (isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
     transposeInPlace(*materializeEncodingInfo);
   }
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
-      getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
+      getInnerTileSizesOfr(rewriter, loc, emptyType, *materializeEncodingInfo,
                            materializeEncodingValueFn);
   if (failed(innerTileSizesOfr)) {
     return rewriter.notifyMatchFailure(
@@ -372,16 +370,110 @@
       materializeEncodingInfo->innerDimsPos,
       materializeEncodingInfo->outerDimsPerm);
   Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, newShape, resultType.getElementType());
+      loc, newShape, emptyType.getElementType());
   return newEmptyOp;
 }
 
+/// Converts a linalg::GenericOp with encoded inputs into the packed domain.
+/// The `genericOp` must have all parallel iterator types and a single output
+/// with an identity indexing map.
+static FailureOr<Operation *>
+lowerGenericOpWithEncoding(RewriterBase &rewriter, linalg::GenericOp genericOp,
+                           ValueRange convertedInputOperands,
+                           ValueRange convertedOutputOperands,
+                           MaterializeEncodingFn materializeEncodingFn) {
+  if (!genericOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+  if (genericOp.getNumReductionLoops() != 0) {
+    return rewriter.notifyMatchFailure(genericOp, "Loops are not all parallel");
+  }
+  if (genericOp.getNumDpsInits() != 1) {
+    return rewriter.notifyMatchFailure(genericOp, "Not only 1 init operand");
+  }
+  OpOperand *outputOperand = genericOp.getDpsInitOperand(0);
+  AffineMap outputMap = genericOp.getMatchingIndexingMap(outputOperand);
+  if (!outputMap.isIdentity()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "Output indexing map is not identity");
+  }
+  FailureOr<MaterializeEncodingInfo> outMaterializeEncodingInfo =
+      materializeEncodingFn(
+          cast<RankedTensorType>(outputOperand->get().getType()));
+  if (failed(outMaterializeEncodingInfo)) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "MaterializeEncodingInfo failed for output");
+  }
+
+  auto convertedResultType =
+      cast<RankedTensorType>(convertedOutputOperands[0].getType());
+  SmallVector<utils::IteratorType> iteratorTypes(convertedResultType.getRank(),
+                                                 utils::IteratorType::parallel);
+  // Compute the new indexing maps for the packed layout. This assumes that
+  // the output map is identity, and that all iterator types are parallel.
+  SmallVector<int64_t> outInnerDimsPos =
+      outMaterializeEncodingInfo->innerDimsPos;
+  SmallVector<int64_t> outInverseOuterDimsPerm =
+      invertPermutationVector(outMaterializeEncodingInfo->outerDimsPerm);
+  SmallVector<AffineMap> packedIndexingMaps;
+  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
+        materializeEncodingFn(
+            cast<RankedTensorType>(inputOperand->get().getType()));
+    if (failed(materializeEncodingInfo)) {
+      return rewriter.notifyMatchFailure(
+          genericOp, "MaterializeEncodingInfo failed for input");
+    }
+    SmallVector<int64_t> innerDimsPos = materializeEncodingInfo->innerDimsPos;
+    SmallVector<int64_t> outerDimsPerm = materializeEncodingInfo->outerDimsPerm;
+    AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
+    // Permute result dims to the input packed domain, and map dims to the
+    // output packed domain.
+    SmallVector<int64_t> packedResultDims = llvm::map_to_vector(
+        applyPermutation(inputMap.getResults(), outerDimsPerm),
+        [&](AffineExpr expr) {
+          auto dimExpr = cast<AffineDimExpr>(expr);
+          return outInverseOuterDimsPerm[dimExpr.getPosition()];
+        });
+    // Add new dims for the inner tiles, taking the dim position from the
+    // corresponding inner tile of the init operand.
+    for (auto [idx, pos] : llvm::enumerate(innerDimsPos)) {
+      auto dimPos = cast<AffineDimExpr>(inputMap.getResult(pos)).getPosition();
+      for (auto [tileIdx, outDim] : llvm::enumerate(outInnerDimsPos)) {
+        if (dimPos == outDim) {
+          packedResultDims.push_back(outputMap.getNumDims() + tileIdx);
+        }
+      }
+    }
+    // Create the packed indexing map.
+    SmallVector<AffineExpr> packedResultExprs =
+        llvm::map_to_vector(packedResultDims, [&](int64_t dim) {
+          return rewriter.getAffineDimExpr(dim);
+        });
+    auto packedInputMap = AffineMap::get(
+        /*dimCount=*/iteratorTypes.size(), /*symbolCount=*/0, packedResultExprs,
+        rewriter.getContext());
+    packedIndexingMaps.push_back(packedInputMap);
+  }
+  // Create the new packed identity map for the output.
+  packedIndexingMaps.push_back(
+      rewriter.getMultiDimIdentityMap(convertedResultType.getRank()));
+  auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
+      genericOp.getLoc(), convertedResultType, convertedInputOperands,
+      convertedOutputOperands, packedIndexingMaps, iteratorTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.inlineRegionBefore(genericOp.getRegion(),
+                              materializedGenericOp.getRegion(),
+                              materializedGenericOp.getRegion().begin());
+  return materializedGenericOp.getOperation();
+}
+
 /// Utility method to convert from a linalg::LinalgOp on `tensor` types with
 /// encodings to a linalg::LinalgOp on the materialized type. The current
 /// supported op types are:
 ///  - linalg::LinalgOp that `isaContractionOpInterface`
 ///  - linalg::FillOp
-///  - element-wise linalg::GenericOp with single input and output
+///  - linalg::GenericOp with parallel iterators and a single output
 static FailureOr<Operation *> lowerOpWithEncoding(
     RewriterBase &rewriter, linalg::LinalgOp linalgOp,
     ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
@@ -406,36 +498,12 @@
                 convertedInputOperands, convertedOutputOperands);
             return materializedFillOp;
           })
-      .Case<linalg::GenericOp>([&](linalg::GenericOp genericOp)
-                                   -> FailureOr<Operation *> {
-        if (!genericOp.hasPureTensorSemantics() || !isElementwise(genericOp) ||
-            genericOp.getNumDpsInputs() != 1 ||
-            genericOp.getNumDpsInits() != 1) {
-          return rewriter.notifyMatchFailure(
-              genericOp, "linalg.generic op is not elementwise "
-                         "with single input and single output");
-        }
-        if (!llvm::all_of(genericOp.getIndexingMapsArray(),
-                          [](AffineMap m) { return m.isIdentity(); })) {
-          return rewriter.notifyMatchFailure(
-              genericOp, "indexing maps are not all identity maps");
-        }
-        auto convertedResultType =
-            cast<RankedTensorType>(convertedOutputOperands[0].getType());
-        SmallVector<AffineMap> maps(
-            2, AffineMap::getMultiDimIdentityMap(convertedResultType.getRank(),
-                                                 rewriter.getContext()));
-        SmallVector<utils::IteratorType> iteratorTypes(
-            convertedResultType.getRank(), utils::IteratorType::parallel);
-        auto materializedGenericOp = rewriter.create<linalg::GenericOp>(
-            genericOp.getLoc(), convertedResultType, convertedInputOperands,
-            convertedOutputOperands, maps, iteratorTypes,
-            /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
-        rewriter.inlineRegionBefore(genericOp.getRegion(),
-                                    materializedGenericOp.getRegion(),
-                                    materializedGenericOp.getRegion().begin());
-        return materializedGenericOp.getOperation();
-      })
+      .Case<linalg::GenericOp>(
+          [&](linalg::GenericOp genericOp) -> FailureOr<Operation *> {
+            return lowerGenericOpWithEncoding(
+                rewriter, genericOp, convertedInputOperands,
+                convertedOutputOperands, materializeEncodingFn);
+          })
       .Default([](Operation *op) { return failure(); });
 }
 
@@ -454,9 +522,6 @@
     return failure();
   }
 
-  RankedTensorType originalTensorType =
-      getOriginalTypeWithEncoding(boundTensorType);
-
   MaterializeEncodingFn materializeEncodingFn =
       typeConverter.getMaterializeEncodingFn();
   FailureOr<MaterializeEncodingInfo> encodingInfo =
@@ -469,10 +534,9 @@
   }
 
   SmallVector<OpFoldResult> targetShape =
-      getMixedValues(originalTensorType.getShape(), dynamicDims, builder);
-  auto innerTileSizes =
-      getInnerTileSizesOfr(builder, loc, originalTensorType, *encodingInfo,
-                           materializeEncodingValueFn);
+      getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
+  auto innerTileSizes = getInnerTileSizesOfr(
+      builder, loc, boundTensorType, *encodingInfo, materializeEncodingValueFn);
   if (failed(innerTileSizes)) {
     return failure();
   }