Data tiling: transpose narrow-N into narrow-M (#17446)

(This is a rebase of https://github.com/iree-org/iree/pull/16890.)

This is a generic idea in the design of matrix multiplication
implementations: the M and N dimensions play symmetrical roles, so
transposition gives us an opportunity to halve the problem space. The
immediate motivation is ukernels: in preparation for this change, we
implemented narrow ukernels only for the narrow-M cases, not narrow-N.
That is why this PR yields a 5%-9% e2e speedup on multiple ML models
with ukernels (and a > 2x speedup on matvec microbenchmarks).

The idea should be beneficial beyond ukernels, though:
* With codegen (outside of ukernels), inner unit dimensions have often
caused codegen to fall off the good vectorization paths. This
transposition moves unit (or generally smaller static) dimensions to the
outer positions, which will help with that.
* When we get to serious distribution tiling (#16410), the reduced
generality will again help greatly.

This transposition is made easier by (and was always part of the idea
behind) the RHS transposition in `mmt4d` (the `t` in `mmt4d`), since for
any matrices A and B

```
B * Transpose(A) == Transpose( A * Transpose(B) )
```

so in `mmt4d` terms

```
mmt4d(B, A) == Transpose(mmt4d(A, B))
```
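
For concreteness, here is a small standalone C++ check of this identity
(not part of the PR; plain row-major `std::vector` matrices stand in for
the tiled `mmt4d` operands, and the sizes/values are arbitrary):

```
// Minimal sketch: verifies mmt4d(B, A) == Transpose(mmt4d(A, B)), where
// mmt4d(X, Y) is modeled as X * Transpose(Y) on plain row-major matrices.
#include <cassert>
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// C[i][j] = sum_k A[i][k] * B[j][k], i.e. A times Transpose(B).
static Matrix mmt4d(const Matrix &A, const Matrix &B) {
  Matrix C(A.size(), std::vector<int>(B.size(), 0));
  for (size_t i = 0; i < A.size(); ++i)
    for (size_t j = 0; j < B.size(); ++j)
      for (size_t k = 0; k < A[0].size(); ++k)
        C[i][j] += A[i][k] * B[j][k];
  return C;
}

static Matrix transpose(const Matrix &M) {
  Matrix T(M[0].size(), std::vector<int>(M.size()));
  for (size_t i = 0; i < M.size(); ++i)
    for (size_t j = 0; j < M[0].size(); ++j)
      T[j][i] = M[i][j];
  return T;
}

int main() {
  Matrix A = {{1, 2, 3}, {4, 5, 6}}; // M=2, K=3
  Matrix B = {{7, 8, 9}};            // N=1 (narrow N), K=3
  assert(mmt4d(B, A) == transpose(mmt4d(A, B)));
  std::puts("identity holds");
}
```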

As `pack` and `unpack` already have enough generality to perform these
transpositions, we directly generate the right transposing `pack` and
`unpack` ops. An earlier plan was to generate `linalg.transpose` and
rely on a later folding pattern, but it turned out to be simpler to
directly generate the already-transposed `pack` and `unpack`.
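
Concretely, once a result is known to be narrow-N, the materialization
metadata (inner tile sizes, inner dim positions, outer dims permutation)
just has its trailing M and N entries swapped before the pack/unpack ops
are emitted. A condensed standalone sketch of that swap, mirroring the
`transposeInPlace` helper added in the diff below (`std::vector` stands
in for the real `SmallVector`-based struct):

```
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for the pack/unpack layout metadata used at materialization.
struct MaterializeEncodingInfo {
  std::vector<int64_t> innerDimsPos;
  std::vector<int64_t> innerTileSizes;
  std::vector<int64_t> outerDimsPerm;
};

static void transposeInPlace(MaterializeEncodingInfo &info) {
  // Vector (matvec-like) cases have fewer than 2 inner tiles: nothing to do.
  if (info.innerTileSizes.size() < 2)
    return;
  // Otherwise the last two entries of each array are M and N (any batch
  // dimension comes first), so swapping them transposes the layout.
  auto swapLastTwo = [](std::vector<int64_t> &a) {
    std::swap(a[a.size() - 2], a[a.size() - 1]);
  };
  swapLastTwo(info.innerDimsPos);
  swapLastTwo(info.innerTileSizes);
  swapLastTwo(info.outerDimsPerm);
}

int main() {
  // Example: a result tiled as 16x1 (narrow N = 1).
  MaterializeEncodingInfo info{{0, 1}, {16, 1}, {0, 1}};
  transposeInPlace(info);
  // The inner tiles are now {1, 16}: the narrow dimension has become M.
  return info.innerTileSizes[0] == 1 ? 0 : 1;
}
```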

A legitimate question was: should this transposition be implemented at
`SetEncoding` instead of at `MaterializeEncoding`? That would have been
simpler in some ways, but:
* The benefit of the transposition depends on the backend, so it doesn't
belong in Flow.
* `SetEncoding` must be reversible in case the backend doesn't want to do
data-tiling. The transposition would be difficult to revert, and
generally confusing in settings where it isn't wanted.
* The mmt4d-specific identity above only simplifies the transposition
because at `MaterializeEncoding` we know we are generating an `mmt4d`. We
couldn't rely on that as easily in `SetEncoding`.
* Before `MaterializeEncoding` we would have to handle `linalg.generic`,
not just named matmul ops.

benchmark-extra: x86_64-dt-only, android-cpu-dt-only

Signed-off-by: Alan Li <me@alanli.org>
Co-authored-by: Benoit Jacob <jacob.benoit.1@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
index 8597360..c1948b9 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp
@@ -498,20 +498,8 @@
   if (enumeratedTileMxNxK.empty()) {
     return failure();
   }
-  // Check if the encoding specifies static narrow sizes for the M/N dimensions.
-  // This can be used to choose a correspondingly narrow tile shape.
-  // With microkernels, we keep this logic in sync with the set of actual
-  // optimized microkernel tile functions to avoid a tile shape specialization
-  // causing a fallback to a slow generic tile function. At the moment,
-  // microkernel tile functions are only specialize for narrow M, not for narrow
-  // N. Accordingly, we leave matmulNarrowN as 0 (default) when microkernels are
-  // used. Generally it would be best to deal with narrow-N cases by transposing
-  // the whole matmul and swapping LHS<->RHS, reducing the narrow-N case to
-  // narrow-M.
   int64_t matmulNarrowM = getIntOrZero(encoding.getMatmulNarrow_M());
-  int64_t matmulNarrowN = hasUkernel(targetAttr, "mmt4d")
-                              ? 0
-                              : getIntOrZero(encoding.getMatmulNarrow_N());
+  int64_t matmulNarrowN = getIntOrZero(encoding.getMatmulNarrow_N());
   // Choose a final matmul TileMxNxK from the above-enumarated tile shapes,
   // taking narrow dimensions into account.
   TileMxNxK chosenTileMxNxK =
diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
index eaa1661..64a87b0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/CPU/test/llvmcpu_materialize_encoding.mlir
@@ -257,6 +257,45 @@
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matvec_shaped_matmul_lowering_f32f32f32_aarch64(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<16x16xf32>
+  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<16x1xf32>
+  %2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<16x1xf32>
+  %padded = tensor.pad %0 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %cst : f32
+  } : tensor<16x16xf32> to tensor<?x?xf32>
+  %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<role =  LHS, element_types = [f32, f32, f32], original_type = tensor<16x16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %padded_0 = tensor.pad %1 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %cst : f32
+  } : tensor<16x1xf32> to tensor<?x?xf32>
+  %4 = iree_encoding.set_encoding %padded_0 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<role =  RHS, element_types = [f32, f32, f32], original_type = tensor<16x1xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %padded_1 = tensor.pad %2 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %cst : f32
+  } : tensor<16x1xf32> to tensor<?x?xf32>
+  %5 = iree_encoding.set_encoding %padded_1 : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16x1xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %6 = linalg.matmul ins(%3, %4 : tensor<?x?xf32, #iree_encoding.encoding<role =  LHS, element_types = [f32, f32, f32], original_type = tensor<16x16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>, tensor<?x?xf32, #iree_encoding.encoding<role =  RHS, element_types = [f32, f32, f32], original_type = tensor<16x1xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) outs(%5 : tensor<?x?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16x1xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) -> tensor<?x?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16x1xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %7 = iree_encoding.unset_encoding %6 : tensor<?x?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16x1xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> tensor<?x?xf32>
+  %extracted_slice = tensor.extract_slice %7[0, 0] [16, 1] [1, 1] : tensor<?x?xf32> to tensor<16x1xf32>
+  %8 = hal.tensor.export %extracted_slice "output0" : tensor<16x1xf32> -> !hal.buffer_view
+  func.return %8 : !hal.buffer_view
+}
+// CHECK-LABEL: func @matvec_shaped_matmul_lowering_f32f32f32_aarch64(
+//       CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
+//  CHECK-SAME:       ins({{.*}} : tensor<1x16x1x1xf32>, tensor<2x16x8x1xf32>)
+//  CHECK-SAME:        outs({{.*}} : tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 func.func @matmul_lowering_f32f32f32_aarch64() attributes {
   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz", ukernels = "all"}>
 } {
@@ -320,6 +359,45 @@
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matvec_lowering_f32f32f32_aarch64(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {
+  hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
+} {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<16x16xf32>
+  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<16xf32>
+  %2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<16xf32>
+  %padded = tensor.pad %0 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %cst : f32
+  } : tensor<16x16xf32> to tensor<?x?xf32>
+  %3 = iree_encoding.set_encoding %padded : tensor<?x?xf32> -> tensor<?x?xf32, #iree_encoding.encoding<role =  LHS, element_types = [f32, f32, f32], original_type = tensor<16x16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>>
+  %padded_0 = tensor.pad %1 low[0] high[%c0] {
+  ^bb0(%arg3: index):
+    tensor.yield %cst : f32
+  } : tensor<16xf32> to tensor<?xf32>
+  %4 = iree_encoding.set_encoding %padded_0 : tensor<?xf32> -> tensor<?xf32, #iree_encoding.encoding<role =  RHS, element_types = [f32, f32, f32], original_type = tensor<16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>>
+  %padded_1 = tensor.pad %2 low[0] high[%c0] {
+  ^bb0(%arg3: index):
+    tensor.yield %cst : f32
+  } : tensor<16xf32> to tensor<?xf32>
+  %5 = iree_encoding.set_encoding %padded_1 : tensor<?xf32> -> tensor<?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>>
+  %6 = linalg.matvec ins(%3, %4 : tensor<?x?xf32, #iree_encoding.encoding<role =  LHS, element_types = [f32, f32, f32], original_type = tensor<16x16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>>, tensor<?xf32, #iree_encoding.encoding<role =  RHS, element_types = [f32, f32, f32], original_type = tensor<16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>>) outs(%5 : tensor<?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>>) -> tensor<?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>>
+  %7 = iree_encoding.unset_encoding %6 : tensor<?xf32, #iree_encoding.encoding<role =  RESULT, element_types = [f32, f32, f32], original_type = tensor<16xf32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>]>> -> tensor<?xf32>
+  %extracted_slice = tensor.extract_slice %7[0] [16] [1] : tensor<?xf32> to tensor<16xf32>
+  %8 = hal.tensor.export %extracted_slice "output0" : tensor<16xf32> -> !hal.buffer_view
+  func.return %8 : !hal.buffer_view
+}
+// CHECK-LABEL: func @matvec_lowering_f32f32f32_aarch64(
+//       CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
+//  CHECK-SAME:       ins({{.*}} : tensor<1x16x1x1xf32>, tensor<2x16x8x1xf32>)
+//  CHECK-SAME:        outs({{.*}} : tensor<1x2x1x8xf32>) -> tensor<1x2x1x8xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 func.func @matvec_lowering_f32f32f32_aarch64() attributes {
   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="aarch64-xyz-xyz"}>
 } {
@@ -356,18 +434,18 @@
 //       CHECK:   %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
 //  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<1x16x1x1xf32>>
 //       CHECK:   %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
-//  CHECK-SAME:       !flow.dispatch.tensor<readwrite:tensor<2x1x8x1xf32>>
+//  CHECK-SAME:       !flow.dispatch.tensor<readwrite:tensor<1x2x1x8xf32>>
 //       CHECK:   %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
 //  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [2, 16, 8, 1], strides = [1, 1, 1, 1]
 //       CHECK:   %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
 //  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [1, 16, 1, 1], strides = [1, 1, 1, 1]
 //       CHECK:   %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
-//  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [2, 1, 8, 1], strides = [1, 1, 1, 1]
+//  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [1, 2, 1, 8], strides = [1, 1, 1, 1]
 //       CHECK:   %[[MMT4D:.+]] = linalg.mmt4d
-//  CHECK-SAME:       ins(%[[LHS]], %[[RHS]] :
+//  CHECK-SAME:       ins(%[[RHS]], %[[LHS]] :
 //  CHECK-SAME:       outs(%[[OUTS]] :
 //       CHECK:   flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
-//  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [2, 1, 8, 1], strides = [1, 1, 1, 1]
+//  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [1, 2, 1, 8], strides = [1, 1, 1, 1]
 
 // -----
 
@@ -2168,10 +2246,10 @@
 //  CHECK-NEXT:       linalg.yield %[[RHS_EXT_OP]] : i32
 //       CHECK:   %[[INIT_FILL:.+]] = tensor.empty() : tensor<688x16xi32>
 //       CHECK:   %[[EXPAND_RHS:.+]] = tensor.expand_shape %[[RHS_EXT]] {{\[}}[0, 1], [2, 3]] output_shape [1, 64, 1, 2] : tensor<64x2xi32> into tensor<1x64x1x2xi32>
-//       CHECK:   %[[EXPAND_INIT:.+]] = tensor.expand_shape %[[INIT_FILL:.+]] {{\[}}[0, 1], [2, 3]] output_shape [688, 1, 16, 1] : tensor<688x16xi32> into tensor<688x1x16x1xi32>
-//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[C0_I32]] : i32) outs(%[[EXPAND_INIT]] : tensor<688x1x16x1xi32>) -> tensor<688x1x16x1xi32>
-//       CHECK:   %[[MMT4D:.+]] = linalg.mmt4d ins(%[[LHS_EXT]], %[[EXPAND_RHS]]  : tensor<688x64x16x2xi32>, tensor<1x64x1x2xi32>) outs(%[[FILL]] : tensor<688x1x16x1xi32>) -> tensor<688x1x16x1xi32>
-//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMT4D]] {{\[}}[0, 1], [2, 3]] : tensor<688x1x16x1xi32> into tensor<688x16xi32>
+//       CHECK:   %[[EXPAND_INIT:.+]] = tensor.expand_shape %[[INIT_FILL:.+]] {{\[}}[0, 1], [2, 3]] output_shape [1, 688, 1, 16] : tensor<688x16xi32> into tensor<1x688x1x16xi32>
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[C0_I32]] : i32) outs(%[[EXPAND_INIT]] : tensor<1x688x1x16xi32>) -> tensor<1x688x1x16xi32>
+//       CHECK:   %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXPAND_RHS]], %[[LHS_EXT]]  : tensor<1x64x1x2xi32>, tensor<688x64x16x2xi32>) outs(%[[FILL]] : tensor<1x688x1x16xi32>) -> tensor<1x688x1x16xi32>
+//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMT4D]] {{\[}}[0, 1], [2, 3]] : tensor<1x688x1x16xi32> into tensor<688x16xi32>
 //       CHECK:   %[[INIT_UNPACK:.+]] = tensor.empty() : tensor<11008xi32>
 //       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[COLLAPSED]] outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [16] into %[[INIT_UNPACK]] : tensor<688x16xi32> -> tensor<11008xi32>
 //       CHECK:   return %[[UNPACK]]
@@ -2242,10 +2320,10 @@
 //  CHECK-NEXT:       linalg.yield %[[RHS_EXT_OP]] : i32
 //       CHECK:   %[[INIT_FILL:.+]] = tensor.empty() : tensor<1x16xi32>
 //       CHECK:   %[[EXPAND_RHS:.+]] = tensor.expand_shape %[[RHS_EXT]] {{\[}}[0, 1], [2, 3]] output_shape [1, 64, 1, 2] : tensor<64x2xi32> into tensor<1x64x1x2xi32>
-//       CHECK:   %[[EXPAND_INIT:.+]] = tensor.expand_shape %[[INIT_FILL:.+]] {{\[}}[0, 1], [2, 3]] output_shape [1, 1, 16, 1] : tensor<1x16xi32> into tensor<1x1x16x1xi32>
-//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[C0_I32]] : i32) outs(%[[EXPAND_INIT]] : tensor<1x1x16x1xi32>) -> tensor<1x1x16x1xi32>
-//       CHECK:   %[[MMT4D:.+]] = linalg.mmt4d ins(%[[LHS_EXT]], %[[EXPAND_RHS]]  : tensor<1x64x16x2xi32>, tensor<1x64x1x2xi32>) outs(%[[FILL]] : tensor<1x1x16x1xi32>) -> tensor<1x1x16x1xi32>
-//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMT4D]] {{\[}}[0, 1], [2, 3]] : tensor<1x1x16x1xi32> into tensor<1x16xi32>
+//       CHECK:   %[[EXPAND_INIT:.+]] = tensor.expand_shape %[[INIT_FILL:.+]] {{\[}}[0, 1], [2, 3]] output_shape [1, 1, 1, 16] : tensor<1x16xi32> into tensor<1x1x1x16xi32>
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[C0_I32]] : i32) outs(%[[EXPAND_INIT]] : tensor<1x1x1x16xi32>) -> tensor<1x1x1x16xi32>
+//       CHECK:   %[[MMT4D:.+]] = linalg.mmt4d ins(%[[EXPAND_RHS]], %[[LHS_EXT]]  : tensor<1x64x1x2xi32>, tensor<1x64x16x2xi32>) outs(%[[FILL]] : tensor<1x1x1x16xi32>) -> tensor<1x1x1x16xi32>
+//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMT4D]] {{\[}}[0, 1], [2, 3]] : tensor<1x1x1x16xi32> into tensor<1x16xi32>
 //       CHECK:   %[[INIT_UNPACK:.+]] = tensor.empty() : tensor<15xi32>
 //       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[COLLAPSED]] outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [16] into %[[INIT_UNPACK]] : tensor<1x16xi32> -> tensor<15xi32>
 //       CHECK:   return %[[UNPACK]]
@@ -2326,77 +2404,38 @@
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @batch_matvec(%arg0: tensor<32x11008x128xi8>, %arg1: tensor<32x128xi8>) -> tensor<32x11008xi32> attributes {
+func.func @batch_matvec(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {
   hal.executable.target = #hal.executable.target<"xyz", "xyz", {target_triple="x86_64-xyz-xyz", cpu_features="+avx512vnni"}>
 } {
-  %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c128 = arith.constant 128 : index
-  %c11008 = arith.constant 11008 : index
-  %c0_i8 = arith.constant 0 : i8
   %c0_i32 = arith.constant 0 : i32
-  %padded = tensor.pad %arg0 low[0, 0, 0] high[%c0, %c0, %c0] {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index):
+  %c0_i8 = arith.constant 0 : i8
+  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<32x11008x128xi8>
+  %1 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<32x128xi8>
+  %2 = hal.tensor.import %arg2 "input2" : !hal.buffer_view -> tensor<32x11008xi32>
+  %padded = tensor.pad %0 low[0, 0, 0] high[%c0, %c0, %c0] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
     tensor.yield %c0_i8 : i8
   } : tensor<32x11008x128xi8> to tensor<?x?x?xi8>
-  %4 = iree_encoding.set_encoding %padded : tensor<?x?x?xi8> -> tensor<?x?x?xi8, #iree_encoding.encoding<role = LHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>
-  %5 = tensor.empty(%c32, %c11008, %c128) : tensor<?x?x?xi32, #iree_encoding.encoding<role = LHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>
-  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<?x?x?xi8, #iree_encoding.encoding<role = LHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>) outs(%5 : tensor<?x?x?xi32, #iree_encoding.encoding<role = LHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>) {
-  ^bb0(%in: i8, %out: i32):
-    %17 = arith.extsi %in : i8 to i32
-    linalg.yield %17 : i32
-  } -> tensor<?x?x?xi32, #iree_encoding.encoding<role = LHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>
-  %padded_0 = tensor.pad %arg1 low[0, 0] high[%c0, %c0] {
-  ^bb0(%arg2: index, %arg3: index):
+  %3 = iree_encoding.set_encoding %padded : tensor<?x?x?xi8> -> tensor<?x?x?xi8, #iree_encoding.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<32x11008x128xi8>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %padded_0 = tensor.pad %1 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg3: index, %arg4: index):
     tensor.yield %c0_i8 : i8
   } : tensor<32x128xi8> to tensor<?x?xi8>
-  %7 = iree_encoding.set_encoding %padded_0 : tensor<?x?xi8> -> tensor<?x?xi8, #iree_encoding.encoding<role = RHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>
-  %8 = tensor.empty(%c32, %c128) : tensor<?x?xi32, #iree_encoding.encoding<role = RHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>
-  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%7 : tensor<?x?xi8, #iree_encoding.encoding<role = RHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>) outs(%8 : tensor<?x?xi32, #iree_encoding.encoding<role = RHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>) {
-  ^bb0(%in: i8, %out: i32):
-    %17 = arith.extsi %in : i8 to i32
-    linalg.yield %17 : i32
-  } -> tensor<?x?xi32, #iree_encoding.encoding<role = RHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>
-  %10 = tensor.empty(%c32, %c11008) : tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008xi32>, user_indexing_maps = [#map, #map1, #map2]>>
-  %11 = linalg.fill ins(%c0_i32 : i32) outs(%10 : tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008xi32>, user_indexing_maps = [#map, #map1, #map2]>>) -> tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008xi32>, user_indexing_maps = [#map, #map1, #map2]>>
-  %12 = linalg.batch_matvec ins(%6, %9 : tensor<?x?x?xi32, #iree_encoding.encoding<role = LHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>, tensor<?x?xi32, #iree_encoding.encoding<role = RHS, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x128xi8>, user_indexing_maps = [#map, #map1, #map2]>>) outs(%11 : tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008xi32>, user_indexing_maps = [#map, #map1, #map2]>>) -> tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008xi32>, user_indexing_maps = [#map, #map1, #map2]>>
-  %13 = iree_encoding.unset_encoding %12 : tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], matmul_narrow_N = 1 : index, original_type = tensor<32x11008xi32>, user_indexing_maps = [#map, #map1, #map2]>> -> tensor<?x?xi32>
-  %extracted_slice = tensor.extract_slice %13[0, 0] [32, 11008] [1, 1] : tensor<?x?xi32> to tensor<32x11008xi32>
-  return %extracted_slice : tensor<32x11008xi32>
+  %4 = iree_encoding.set_encoding %padded_0 : tensor<?x?xi8> -> tensor<?x?xi8, #iree_encoding.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<32x128xi8>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %padded_1 = tensor.pad %2 low[0, 0] high[%c0, %c0] {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %c0_i32 : i32
+  } : tensor<32x11008xi32> to tensor<?x?xi32>
+  %5 = iree_encoding.set_encoding %padded_1 : tensor<?x?xi32> -> tensor<?x?xi32, #iree_encoding.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<32x11008xi32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %6 = linalg.batch_matvec ins(%3, %4 : tensor<?x?x?xi8, #iree_encoding.encoding<role =  LHS, element_types = [i8, i8, i32], original_type = tensor<32x11008x128xi8>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>, tensor<?x?xi8, #iree_encoding.encoding<role =  RHS, element_types = [i8, i8, i32], original_type = tensor<32x128xi8>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) outs(%5 : tensor<?x?xi32, #iree_encoding.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<32x11008xi32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>) -> tensor<?x?xi32, #iree_encoding.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<32x11008xi32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>>
+  %7 = iree_encoding.unset_encoding %6 : tensor<?x?xi32, #iree_encoding.encoding<role =  RESULT, element_types = [i8, i8, i32], original_type = tensor<32x11008xi32>, matmul_narrow_N = 1 : index, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>]>> -> tensor<?x?xi32>
+  %extracted_slice = tensor.extract_slice %7[0, 0] [32, 11008] [1, 1] : tensor<?x?xi32> to tensor<32x11008xi32>
+  %8 = hal.tensor.export %extracted_slice "output0" : tensor<32x11008xi32> -> !hal.buffer_view
+  func.return %8 : !hal.buffer_view
 }
 
-//   CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-//   CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func.func @batch_matvec(
-//  CHECK-SAME:   %[[LHS:.+]]: tensor<32x11008x128xi8>, %[[RHS:.+]]: tensor<32x128xi8>) -> tensor<32x11008xi32>
-//   CHECK-DAG:   %[[C0_I32:.+]] = arith.constant 0 : i32
-//   CHECK-DAG:   %[[INIT_LHS_PACK:.+]] = tensor.empty() : tensor<32x688x64x16x2xi8>
-//   CHECK-DAG:   %[[LHS_PACK:.+]] = tensor.pack %[[LHS]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 2] into %[[INIT_LHS_PACK]] : tensor<32x11008x128xi8> -> tensor<32x688x64x16x2xi8>
-//   CHECK-DAG:   %[[INIT_LHS_EXT:.+]] = tensor.empty() : tensor<32x688x64x16x2xi32>
-//       CHECK:   %[[LHS_EXT:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[LHS_PACK]] : tensor<32x688x64x16x2xi8>) outs(%[[INIT_LHS_EXT]] : tensor<32x688x64x16x2xi32>) {
-//  CHECK-NEXT:       ^bb0(%[[LHS_EXT_ARG_IN:.+]]: i8, %[[LHS_EXT_ARG_OUT:.+]]: i32):
-//  CHECK-NEXT:       %[[LHS_EXT_OP:.+]] = arith.extsi %[[LHS_EXT_ARG_IN]] : i8 to i32
-//  CHECK-NEXT:       linalg.yield %[[LHS_EXT_OP]] : i32
-//   CHECK-DAG:   %[[INIT_RHS_PACK:.+]] = tensor.empty() : tensor<32x64x2xi8>
-//   CHECK-DAG:   %[[RHS_PACK:.+]] = tensor.pack %[[RHS]] outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [2] into %[[INIT_RHS_PACK]] : tensor<32x128xi8> -> tensor<32x64x2xi8>
-//   CHECK-DAG:   %[[INIT_RHS_EXT:.+]] = tensor.empty() : tensor<32x64x2xi32>
-//       CHECK:   %[[RHS_EXT:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[RHS_PACK]] : tensor<32x64x2xi8>) outs(%[[INIT_RHS_EXT]] : tensor<32x64x2xi32>) {
-//  CHECK-NEXT:       ^bb0(%[[RHS_EXT_ARG_IN:.+]]: i8, %[[RHS_EXT_ARG_OUT:.+]]: i32):
-//  CHECK-NEXT:       %[[RHS_EXT_OP:.+]] = arith.extsi %[[RHS_EXT_ARG_IN]] : i8 to i32
-//  CHECK-NEXT:       linalg.yield %[[RHS_EXT_OP]] : i32
-//       CHECK:   %[[INIT_FILL:.+]] = tensor.empty() : tensor<32x688x16xi32>
-//       CHECK:   %[[EXPAND_RHS:.+]] = tensor.expand_shape %[[RHS_EXT]] {{\[}}[0], [1, 2], [3, 4]] output_shape [32, 1, 64, 1, 2] : tensor<32x64x2xi32> into tensor<32x1x64x1x2xi32>
-//       CHECK:   %[[EXPAND_INIT:.+]] = tensor.expand_shape %[[INIT_FILL:.+]] {{\[}}[0], [1, 2], [3, 4]] output_shape [32, 688, 1, 16, 1] : tensor<32x688x16xi32> into tensor<32x688x1x16x1xi32>
-//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[C0_I32]] : i32) outs(%[[EXPAND_INIT]] : tensor<32x688x1x16x1xi32>) -> tensor<32x688x1x16x1xi32>
-//       CHECK:   %[[MMT4D:.+]] = linalg.batch_mmt4d ins(%[[LHS_EXT]], %[[EXPAND_RHS]] : tensor<32x688x64x16x2xi32>, tensor<32x1x64x1x2xi32>) outs(%[[FILL]] : tensor<32x688x1x16x1xi32>) -> tensor<32x688x1x16x1xi32>
-//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[MMT4D]] {{\[}}[0], [1, 2], [3, 4]] : tensor<32x688x1x16x1xi32> into tensor<32x688x16xi32>
-//       CHECK:   %[[INIT_UNPACK:.+]] = tensor.empty() : tensor<32x11008xi32>
-//       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[COLLAPSED]] outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [16] into %[[INIT_UNPACK]] : tensor<32x688x16xi32> -> tensor<32x11008xi32>
-//       CHECK:   return %[[UNPACK]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
index 6cea37a..8387431 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -9,6 +9,8 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 
+#include <numeric>
+
 namespace mlir::iree_compiler {
 
 using IREE::Encoding::EncodingAttr;
@@ -16,23 +18,98 @@
 using IREE::Encoding::getEncodingAttr;
 using IREE::Encoding::getEncodingContractionDims;
 
+// If tensorType has the encoding of a matmul RESULT with narrow N, returns
+// the transposed type. Otherwise, just returns tensorType.
+static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
+  auto encoding =
+      llvm::dyn_cast_or_null<EncodingAttr>(tensorType.getEncoding());
+  if (!encoding) {
+    return tensorType;
+  }
+  if (!isNarrowNResult(encoding)) {
+    return tensorType;
+  }
+  auto newRole = encoding.getRole().getValue();
+  TypeAttr originalTypeAttr = encoding.getOriginalType();
+  RankedTensorType originalType = tensorType;
+  if (originalTypeAttr) {
+    originalType =
+        llvm::dyn_cast<RankedTensorType>(originalTypeAttr.getValue());
+  }
+  SmallVector<int64_t> newOriginalShape(originalType.getShape());
+  auto userIndexingMaps = encoding.getUserIndexingMaps();
+  SmallVector<AffineMap> maps;
+  for (auto a : userIndexingMaps) {
+    maps.push_back(cast<AffineMapAttr>(a).getAffineMap());
+  }
+  auto cDims = linalg::inferContractionDims(maps);
+  SmallVector<int64_t> newShape(tensorType.getShape());
+  SmallVector<int64_t> permIndices(maps[0].getNumDims());
+  std::iota(std::begin(permIndices), std::end(permIndices), 0);
+  // Matrix case: there are both M and N dimensions. Transposing means swapping
+  // them.
+  if (cDims->m.size() == 1 && cDims->n.size() == 1) {
+    int m = cDims->m[0];
+    int n = cDims->n[0];
+    std::swap(permIndices[m], permIndices[n]);
+    int mDim = encoding.mapDimToRoleIndex(m);
+    int nDim = encoding.mapDimToRoleIndex(n);
+    std::swap(newShape[mDim], newShape[nDim]);
+    std::swap(newOriginalShape[mDim], newOriginalShape[nDim]);
+  }
+  // Vector case: there is no N dimension to swap the M dimension with. We
+  // swap the maps themselves.
+  if (cDims->n.empty()) {
+    std::swap(maps[0], maps[1]);
+  }
+
+  // auto newRoundDimsTo = encoding.getRoundDimsToArray();
+  SmallVector<int64_t> newRoundDimsTo(encoding.getRoundDimsToArray());
+  assert(newRoundDimsTo.size() == 0 || newRoundDimsTo.size() == 3);
+  if (newRoundDimsTo.size() != 0)
+    std::swap(newRoundDimsTo[0], newRoundDimsTo[1]);
+
+  auto context = tensorType.getContext();
+  AffineMap permutation = AffineMap::getPermutationMap(permIndices, context);
+  for (auto &map : maps) {
+    map = map.compose(permutation);
+  }
+  SmallVector<Attribute> newMaps;
+  for (auto map : maps) {
+    newMaps.push_back(AffineMapAttr::get(map));
+  }
+  ArrayAttr newIndexingMaps = ArrayAttr::get(context, newMaps);
+  auto elemType = tensorType.getElementType();
+  OpBuilder builder(context);
+
+  auto newEncoding = IREE::Encoding::EncodingAttr::get(
+      context, IREE::Encoding::EncodingRoleAttr::get(context, newRole),
+      encoding.getElementTypes(),
+      TypeAttr::get(RankedTensorType::get(newOriginalShape, elemType)),
+      encoding.getMatmulNarrow_N(), encoding.getMatmulNarrow_M(),
+      newIndexingMaps, DenseI64ArrayAttr::get(context, newRoundDimsTo));
+  return RankedTensorType::get(newShape, elemType, newEncoding);
+}
+
 /// For a given tensor type with an encoding, return the materialized
 /// type to use for it. If no encoding is set, then return the tensor type
 /// itself.
 static RankedTensorType
 getMaterializedType(RankedTensorType tensorType,
                     MaterializeEncodingFn materializeEncodingFn) {
+  RankedTensorType maybeTransposedTensorType =
+      transposeIfNarrowNResult(tensorType);
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
-      materializeEncodingFn(tensorType);
+      materializeEncodingFn(maybeTransposedTensorType);
   if (failed(materializeEncodingInfo)) {
     return dropEncoding(tensorType);
   }
-  return cast<RankedTensorType>(
-      tensor::PackOp::inferPackedType(getOriginalTypeWithEncoding(tensorType)
-                                          .clone(tensorType.getElementType()),
-                                      materializeEncodingInfo->innerTileSizes,
-                                      materializeEncodingInfo->innerDimsPos,
-                                      materializeEncodingInfo->outerDimsPerm));
+  return cast<RankedTensorType>(tensor::PackOp::inferPackedType(
+      getOriginalTypeWithEncoding(maybeTransposedTensorType)
+          .clone(tensorType.getElementType()),
+      materializeEncodingInfo->innerTileSizes,
+      materializeEncodingInfo->innerDimsPos,
+      materializeEncodingInfo->outerDimsPerm));
 }
 
 MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
@@ -42,10 +119,9 @@
   addConversion([](IndexType indexType) { return indexType; });
   addConversion([](FloatType floatType) { return floatType; });
   addConversion([](MemRefType memrefType) { return memrefType; });
-  addConversion(
-      [materializeEncodingFn](RankedTensorType t) -> RankedTensorType {
-        return getMaterializedType(t, materializeEncodingFn);
-      });
+  addConversion([=](RankedTensorType t) -> RankedTensorType {
+    return getMaterializedType(t, materializeEncodingFn);
+  });
 }
 
 MaterializeEncodingConversionTarget::MaterializeEncodingConversionTarget(
@@ -127,4 +203,13 @@
   return encodingInfo;
 }
 
+bool isNarrowNResult(EncodingAttr encoding) {
+  if (encoding.getRole().getValue() != EncodingRole::RESULT) {
+    return false;
+  }
+  IntegerAttr narrowM = encoding.getMatmulNarrow_M();
+  IntegerAttr narrowN = encoding.getMatmulNarrow_N();
+  return narrowN && (!narrowM || narrowM.getInt() > narrowN.getInt());
+}
+
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index cf9e9a6..42b4438 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -99,6 +99,10 @@
 void populateMaterializeUpperBoundTileSizePatterns(
     RewritePatternSet &patterns, MaterializeEncodingFn materializeEncodingFn);
 
+// Returns true if `encoding` represents a narrow-N matmul RESULT, e.g. the
+// result of a matvec.
+bool isNarrowNResult(IREE::Encoding::EncodingAttr encoding);
+
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_
diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
index b6443d3..689e88c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -141,11 +142,12 @@
 /// the canonical mmt4d input shape. If the input element type is unsigned,
 /// create a producer Linalg::GenericOp on the input that unsigned extends the
 /// input to the output element type. This extension is required to keep the
-/// unsignedness information on the input for ukernels.
-Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
-                      RewriterBase &rewriter,
-                      SmallVectorImpl<ReassociationIndices> &ri,
-                      ArrayRef<Type> elemTypes, int operandIdx) {
+/// unsignedness information on the input for ukernels. If `transpose` is true,
+/// the `linalgOp`'s indexing maps are transposed.
+static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
+                             bool transpose, RewriterBase &rewriter,
+                             SmallVectorImpl<ReassociationIndices> &ri,
+                             ArrayRef<Type> elemTypes, int operandIdx) {
   assert(linalgOp.getNumDpsInputs() == 2);
   assert(linalgOp.getNumDpsInits() == 1);
   auto cDims = linalg::inferContractionDims(linalgOp);
@@ -158,7 +160,7 @@
     auto type = cast<RankedTensorType>(value.getType());
     RankedTensorType newType = getExpandedType(
         type, /*isBatched=*/!cDims->batch.empty(),
-        /*isTransposed=*/operandIdx == 2 && cDims->n.empty(), ri);
+        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
     expandedValue =
         rewriter.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
   }
@@ -169,6 +171,22 @@
   return expandedValue;
 }
 
+static void transposeInPlace(MaterializeEncodingInfo &info) {
+  // Vector cases: nothing to do.
+  if (info.innerTileSizes.size() < 2) {
+    return;
+  }
+  // Not a vector case, so all three arrays in `info` have size at least 2,
+  // outerDimsPerm may have size 3 if there is a batch dimension, but in all
+  // cases, the last 2 entries of each array are M and N, not batch.
+  auto transpose = [](SmallVector<int64_t> &a) {
+    std::swap(a[a.size() - 2], a[a.size() - 1]);
+  };
+  transpose(info.innerDimsPos);
+  transpose(info.innerTileSizes);
+  transpose(info.outerDimsPerm);
+}
+
 //===---------------------------------------------------------------------===//
 // Methods to convert `set_encoding` and `unset_encoding` operations
 // to `pack` and `unpack` operations respectively.
@@ -200,11 +218,18 @@
     MaterializeEncodingFn materializeEncodingFn,
     MaterializeEncodingValueFn materializeEncodingValueFn) {
   RankedTensorType resultType = encodingOp.getResultType();
+  auto encoding = getEncodingAttr(resultType);
+  if (!encoding) {
+    return failure();
+  }
   FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
       materializeEncodingFn(resultType);
   if (failed(materializeEncodingInfo)) {
     return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
   }
+  if (isNarrowNResult(encoding)) {
+    transposeInPlace(*materializeEncodingInfo);
+  }
   // Create `tensor.empty` operation for the result of the pack operation.
   Location loc = encodingOp.getLoc();
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
@@ -214,7 +239,6 @@
     return rewriter.notifyMatchFailure(
         encodingOp, "failed to generate runtime tile size query");
   }
-  auto encoding = getEncodingAttr(resultType);
   if (!encoding) {
     return failure();
   }
@@ -251,6 +275,9 @@
   if (failed(materializeEncodingInfo)) {
     return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
   }
+  if (isNarrowNResult(getEncodingAttr(sourceType))) {
+    transposeInPlace(*materializeEncodingInfo);
+  }
   // Create an `tensor.empty` for the result of the unpack operation.
   Location loc = encodingOp.getLoc();
   SmallVector<OpFoldResult> resultDims =
@@ -339,22 +366,22 @@
                                     operands.take_front(inputs.size()),
                                     operands.drop_front(inputs.size()));
   } else {
+    bool transpose = isNarrowNResult(resultEncoding);
     auto elemTypes = llvm::map_to_vector(
         lhsEncoding.getElementTypes().getValue(),
         [](Attribute a) { return cast<TypeAttr>(a).getValue(); });
     SmallVector<ReassociationIndices> ri;
-    Value newLhs =
-        getMmt4dOperand(operands[0], linalgOp, rewriter, ri, elemTypes,
-                        /*operandIdx=*/0);
-    Value newRhs =
-        getMmt4dOperand(operands[1], linalgOp, rewriter, ri, elemTypes,
-                        /*operandIdx=*/1);
+    Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, rewriter,
+                                   ri, elemTypes, /*operandIdx=*/0);
+    Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, rewriter,
+                                   ri, elemTypes, /*operandIdx=*/1);
     Value newResult =
-        getMmt4dOperand(operands[2], linalgOp, rewriter, ri, elemTypes,
-                        /*operandIdx=*/2);
-
+        getMmt4dOperand(operands[2], linalgOp, transpose, rewriter, ri,
+                        elemTypes, /*operandIdx=*/2);
+    if (transpose) {
+      std::swap(newLhs, newRhs);
+    }
     Type newResultType = newResult.getType();
-
     auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
     if (cDims->batch.empty()) {
       result = rewriter.create<linalg::Mmt4DOp>(
@@ -391,7 +418,9 @@
         loc, emptyOp.getMixedSizes(), resultType.getElementType());
     return newEmptyOp;
   }
-
+  if (isNarrowNResult(getEncodingAttr(emptyType))) {
+    transposeInPlace(*materializeEncodingInfo);
+  }
   FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr =
       getInnerTileSizesOfr(rewriter, loc, resultType, *materializeEncodingInfo,
                            materializeEncodingValueFn);
@@ -407,7 +436,6 @@
       materializeEncodingInfo->outerDimsPerm);
   Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
       loc, newShape, resultType.getElementType());
-
   return newEmptyOp;
 }
 
@@ -499,6 +527,9 @@
   if (failed(encodingInfo)) {
     return failure();
   }
+  if (isNarrowNResult(getEncodingAttr(boundTensorType))) {
+    transposeInPlace(*encodingInfo);
+  }
 
   SmallVector<OpFoldResult> targetShape =
       getMixedValues(originalTensorType.getShape(), dynamicDims, builder);
@@ -710,10 +741,10 @@
   LogicalResult
   matchAndRewrite(SetEncodingOp encodingOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
+        getTypeConverter());
     MaterializeEncodingFn materializeEncodingFn =
-        static_cast<const MaterializeEncodingTypeConverter *>(
-            getTypeConverter())
-            ->getMaterializeEncodingFn();
+        converter->getMaterializeEncodingFn();
     auto packOp = lowerSetEncodingOpToPackOp(
         rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
         this->materializeEncodingValueFn);
@@ -742,10 +773,10 @@
   LogicalResult
   matchAndRewrite(UnsetEncodingOp encodingOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
+        this->getTypeConverter());
     MaterializeEncodingFn materializeEncodingFn =
-        static_cast<const MaterializeEncodingTypeConverter *>(
-            this->getTypeConverter())
-            ->getMaterializeEncodingFn();
+        converter->getMaterializeEncodingFn();
     auto unpackOp = lowerUnsetEncodingToUnpackOp(
         rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
         this->materializeEncodingValueFn);
@@ -802,10 +833,10 @@
   LogicalResult
   matchAndRewrite(OpTy dpsOp, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
+        this->getTypeConverter());
     MaterializeEncodingFn materializeEncodingFn =
-        static_cast<const MaterializeEncodingTypeConverter *>(
-            this->getTypeConverter())
-            ->getMaterializeEncodingFn();
+        converter->getMaterializeEncodingFn();
     FailureOr<Operation *> convertedOp = lowerOpWithEncoding(
         rewriter, dpsOp, adaptor.getInputs(), adaptor.getOutputs(),
         materializeEncodingFn, this->materializeEncodingValueFn);
@@ -825,10 +856,10 @@
   LogicalResult
   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
+        this->getTypeConverter());
     MaterializeEncodingFn materializeEncodingFn =
-        static_cast<const MaterializeEncodingTypeConverter *>(
-            this->getTypeConverter())
-            ->getMaterializeEncodingFn();
+        converter->getMaterializeEncodingFn();
     FailureOr<Operation *> convertedOp = lowerOpWithEncoding(
         rewriter, op, adaptor.getOperands(), materializeEncodingFn,
         this->materializeEncodingValueFn);
@@ -868,10 +899,10 @@
   matchAndRewrite(mlir::linalg::ContractionOpInterface op,
                   ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
+        this->getTypeConverter());
     MaterializeEncodingFn materializeEncodingFn =
-        static_cast<const MaterializeEncodingTypeConverter *>(
-            this->getTypeConverter())
-            ->getMaterializeEncodingFn();
+        converter->getMaterializeEncodingFn();
     auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
     if (!linalgOp || operands.size() != 3) {
       return failure();
diff --git a/tests/e2e/linalg/BUILD.bazel b/tests/e2e/linalg/BUILD.bazel
index 73645d2..791fd28 100644
--- a/tests/e2e/linalg/BUILD.bazel
+++ b/tests/e2e/linalg/BUILD.bazel
@@ -23,6 +23,7 @@
     [
         "conv2d.mlir",
         "fp_to_subbyte.mlir",
+        "narrow_n_matmuls.mlir",
         "subbyte_to_fp.mlir",
     ],
     include = ["*.mlir"],
@@ -45,6 +46,7 @@
 VMVX_SRCS = enforce_glob(
     [
         "conv2d.mlir",
+        "narrow_n_matmuls.mlir",
     ],
     include = ["*.mlir"],
     exclude = [
@@ -65,6 +67,7 @@
     [
         "conv2d.mlir",
         "subbyte_to_fp.mlir",
+        "narrow_n_matmuls.mlir",
     ],
     include = ["*.mlir"],
     exclude = [
@@ -112,6 +115,7 @@
         "subbyte_to_fp.mlir",
         # currently only enabled on cuda as it can be slow on other backends.
         "large_linalg_matmul.mlir",
+        "narrow_n_matmuls.mlir",
     ],
     include = ["*.mlir"],
     exclude = [
diff --git a/tests/e2e/linalg/CMakeLists.txt b/tests/e2e/linalg/CMakeLists.txt
index fdd9c04..9794387 100644
--- a/tests/e2e/linalg/CMakeLists.txt
+++ b/tests/e2e/linalg/CMakeLists.txt
@@ -16,6 +16,7 @@
   SRCS
     "conv2d.mlir"
     "fp_to_subbyte.mlir"
+    "narrow_n_matmuls.mlir"
     "subbyte_to_fp.mlir"
   TARGET_BACKEND
     "llvm-cpu"
@@ -30,6 +31,7 @@
     check_vmvx_local-task
   SRCS
     "conv2d.mlir"
+    "narrow_n_matmuls.mlir"
   TARGET_BACKEND
     "vmvx"
   DRIVER
@@ -41,6 +43,7 @@
     check_vulkan-spirv_vulkan
   SRCS
     "conv2d.mlir"
+    "narrow_n_matmuls.mlir"
     "subbyte_to_fp.mlir"
   TARGET_BACKEND
     "vulkan-spirv"
@@ -81,6 +84,7 @@
     "conv2d.mlir"
     "fp_to_subbyte.mlir"
     "large_linalg_matmul.mlir"
+    "narrow_n_matmuls.mlir"
     "subbyte_to_fp.mlir"
   TARGET_BACKEND
     "cuda"
diff --git a/tests/e2e/linalg/narrow_n_matmuls.mlir b/tests/e2e/linalg/narrow_n_matmuls.mlir
new file mode 100644
index 0000000..578d7f7
--- /dev/null
+++ b/tests/e2e/linalg/narrow_n_matmuls.mlir
@@ -0,0 +1,126 @@
+// Test various forms of matmuls with narrow N, in particular matvec/batch_matvec
+// (implicitly N=1) and matmuls with N=1 and N=2.
+//
+// The reason why this needs extensive e2e testing is the transposition of
+// narrow N to narrow M in data tiling (around CPUMaterializeEncodingPass).
+// It doesn't hurt to enable this case on all backends though.
+
+func.func @matvec() {
+  %lhs = util.unfoldable_constant dense<[
+     [1, 2, 0, 5],
+     [3, 4, -1, -3],
+     [5, 6, -7, 0]
+  ]> : tensor<3x4xi8>
+  %rhs = util.unfoldable_constant dense<[-2, 3, 4, -1]> : tensor<4xi8>
+  %acc = util.unfoldable_constant dense<[1, 2, 3]> : tensor<3xi32>
+  %result = linalg.matvec ins(%lhs, %rhs : tensor<3x4xi8>, tensor<4xi8>) outs(%acc : tensor<3xi32>) -> tensor<3xi32>
+  check.expect_eq_const(%result, dense<
+    [0, 7, -17]
+  > : tensor<3xi32>) : tensor<3xi32>
+  return
+}
+
+func.func @batch_matvec() {
+  %lhs = util.unfoldable_constant dense<[[
+     [1, 2, 0, 5],
+     [3, 4, -1, -3],
+     [5, 6, -7, 0]
+  ], [
+     [-3, 1, 4, 2],
+     [-1, 0, 6, -1],
+     [1, -2, 3, -4]
+  ]]> : tensor<2x3x4xi8>
+  %rhs = util.unfoldable_constant dense<[
+    [-2, 3, 4, -1],
+    [1, 2, -5, 3]
+  ]> : tensor<2x4xi8>
+  %acc = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+  %result = linalg.batch_matvec ins(%lhs, %rhs : tensor<2x3x4xi8>, tensor<2x4xi8>) outs(%acc : tensor<2x3xi32>) -> tensor<2x3xi32>
+  check.expect_eq_const(%result, dense<[
+    [0, 7, -17],
+    [-11, -29, -24]
+  ]> : tensor<2x3xi32>) : tensor<2x3xi32>
+  return
+}
+
+func.func @matmul_narrow_n_1() {
+  %lhs = util.unfoldable_constant dense<[
+     [1, 2, 0, 5],
+     [3, 4, -1, -3],
+     [5, 6, -7, 0]
+  ]> : tensor<3x4xi8>
+  %rhs = util.unfoldable_constant dense<[[-2], [3], [4], [-1]]> : tensor<4x1xi8>
+  %acc = util.unfoldable_constant dense<[[1], [2], [3]]> : tensor<3x1xi32>
+  %result = linalg.matmul ins(%lhs, %rhs : tensor<3x4xi8>, tensor<4x1xi8>) outs(%acc : tensor<3x1xi32>) -> tensor<3x1xi32>
+  check.expect_eq_const(%result, dense<
+    [[0], [7], [-17]]
+  > : tensor<3x1xi32>) : tensor<3x1xi32>
+  return
+}
+
+func.func @batch_matmul_narrow_n_1() {
+  %lhs = util.unfoldable_constant dense<[[
+     [1, 2, 0, 5],
+     [3, 4, -1, -3],
+     [5, 6, -7, 0]
+  ], [
+     [-3, 1, 4, 2],
+     [-1, 0, 6, -1],
+     [1, -2, 3, -4]
+  ]]> : tensor<2x3x4xi8>
+  %rhs = util.unfoldable_constant dense<[
+    [[-2], [3], [4], [-1]],
+    [[1], [2], [-5], [3]]
+  ]> : tensor<2x4x1xi8>
+  %acc = util.unfoldable_constant dense<[
+    [[1], [2], [3]],
+    [[4], [5], [6]]
+  ]> : tensor<2x3x1xi32>
+  %result = linalg.batch_matmul ins(%lhs, %rhs : tensor<2x3x4xi8>, tensor<2x4x1xi8>) outs(%acc : tensor<2x3x1xi32>) -> tensor<2x3x1xi32>
+  check.expect_eq_const(%result, dense<[
+    [[0], [7], [-17]],
+    [[-11], [-29], [-24]]
+  ]> : tensor<2x3x1xi32>) : tensor<2x3x1xi32>
+  return
+}
+
+func.func @matmul_narrow_n_2() {
+  %lhs = util.unfoldable_constant dense<[
+     [1, 2, 0, 5],
+     [3, 4, -1, -3],
+     [5, 6, -7, 0]
+  ]> : tensor<3x4xi8>
+  %rhs = util.unfoldable_constant dense<[[-2, 1], [3, -1], [4, 0], [-1, 2]]> : tensor<4x2xi8>
+  %acc = util.unfoldable_constant dense<[[1, -1], [2, 0], [3, 1]]> : tensor<3x2xi32>
+  %result = linalg.matmul ins(%lhs, %rhs : tensor<3x4xi8>, tensor<4x2xi8>) outs(%acc : tensor<3x2xi32>) -> tensor<3x2xi32>
+  check.expect_eq_const(%result, dense<
+    [[0, 8], [7, -7], [-17, 0]]
+  > : tensor<3x2xi32>) : tensor<3x2xi32>
+  return
+}
+
+func.func @batch_matmul_narrow_n_2() {
+  %lhs = util.unfoldable_constant dense<[[
+     [1, 2, 0, 5],
+     [3, 4, -1, -3],
+     [5, 6, -7, 0]
+  ], [
+     [-3, 1, 4, 2],
+     [-1, 0, 6, -1],
+     [1, -2, 3, -4]
+  ]]> : tensor<2x3x4xi8>
+  %rhs = util.unfoldable_constant dense<[
+    [[-2, 0], [3, 1], [4, -1], [-1, 2]],
+    [[1, -2], [2, 3], [-5, -3], [3, 0]]
+  ]> : tensor<2x4x2xi8>
+  %acc = util.unfoldable_constant dense<[
+    [[1, -1], [2, 0], [3, 1]],
+    [[4, 2], [5, 1], [6, -1]]
+  ]> : tensor<2x3x2xi32>
+  %result = linalg.batch_matmul ins(%lhs, %rhs : tensor<2x3x4xi8>, tensor<2x4x2xi8>) outs(%acc : tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
+  check.expect_eq_const(%result, dense<[
+    [[0, 11], [7, -1], [-17, 14]],
+    [[-11, -1], [-29, -15], [-24, -18]]
+  ]> : tensor<2x3x2xi32>) : tensor<2x3x2xi32>
+  return
+}