Extend TD strategy to support batched matmul (#14292)

Extend the transform dialect strategy for matmul to also support batched
matmul. The batch matmul path is guarded by its own flag,
`--iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy`,
and is *disabled by default*.
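
The added lit test enables it as follows (abridged from the RUN line of
the new set_transform_strategy_batch_matmul.mlir test; the pass pipeline
is the usual LLVMGPU lowering):

    // RUN: iree-opt %s --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" \
    // RUN:     --iree-codegen-llvmgpu-enable-transform-dialect-jit=1 \
    // RUN:     --iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy \
    // RUN:   | FileCheck %s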

The extension maps the batch dimension to blocks/threads along the Z
axis. Good default values will come separately after end-to-end
experimentation.
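
For example, with the default configuration the block-level step of the
emitted schedule tiles (batch, M, N) by (64, 64, 1) and maps the batch
dimension to #gpu.block<z>. A schematic fragment, adapted from the test's
DEFAULT CHECK lines (handle names are illustrative and result types are
elided):

    %forall, %tiled = transform.structured.tile_to_forall_op %batch_matmul
        num_threads [] tile_sizes [64, 64, 1]
        (mapping = [#gpu.block<z>, #gpu.block<y>, #gpu.block<x>])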

Co-authored-by: Nicolas Vasilache <ntv@google.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
index 692e632..e75c01c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
@@ -19,6 +19,7 @@
     srcs = enforce_glob(
         [
             "affinemin_canonicalization.mlir",
+            "batch_matmuls.mlir",
             "bubble_up_ordinal_ops.mlir",
             "bufferize_copy_only_dispatches.mlir",
             "canonicalize_interface_load_store.mlir",
@@ -64,6 +65,7 @@
         ],
         include = ["*.mlir"],
         exclude = [
+            "batch_matmul_match_spec.mlir",
             "convolution_match_spec.mlir",
             "reductions_codegen_spec.mlir",
             "reductions_match_spec.mlir",
@@ -73,6 +75,7 @@
     # transform dialect spec files are MLIR files that specify a transformation,
     # they need to be included as data.
     data = [
+        "batch_matmul_match_spec.mlir",
         "convolution_match_spec.mlir",
         "reductions_codegen_spec.mlir",
         "reductions_match_spec.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
index 87e8c56..b76b7ae 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "affinemin_canonicalization.mlir"
+    "batch_matmuls.mlir"
     "bubble_up_ordinal_ops.mlir"
     "bufferize_copy_only_dispatches.mlir"
     "canonicalize_interface_load_store.mlir"
@@ -61,6 +62,7 @@
     FileCheck
     iree-opt
   DATA
+    batch_matmul_match_spec.mlir
     convolution_match_spec.mlir
     reductions_codegen_spec.mlir
     reductions_match_spec.mlir
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/batch_matmul_match_spec.mlir b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmul_match_spec.mlir
new file mode 100644
index 0000000..302aabd
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmul_match_spec.mlir
@@ -0,0 +1,9 @@
+// RUN: iree-opt %s
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.iree.register_match_callbacks
+  %0:2 = transform.iree.match_callback failures(propagate) "batch_matmul"(%arg0) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  transform.iree.emit_remark "fill" at %0#0 : !transform.any_op
+  transform.iree.emit_remark "batch matmul" at %0#1 : !transform.any_op
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/batch_matmuls.mlir b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmuls.mlir
new file mode 100644
index 0000000..bfeb9a6
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/test/batch_matmuls.mlir
@@ -0,0 +1,71 @@
+// RUN: iree-opt %s --iree-transform-dialect-interpreter='transform-file-name=%p/batch_matmul_match_spec.mlir' --split-input-file --verify-diagnostics
+
+!lhs = tensor<128x80x32xf32>
+!rhs = tensor<128x32x320xf32>
+!res = tensor<128x80x320xf32>
+
+func.func @batch_matmul(%arg0: !lhs, %arg1: !rhs, %arg2: !res) -> !res {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : !res
+  // expected-remark @below {{fill}}
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !res) -> !res
+  // expected-remark @below {{batch matmul}}
+  %2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%arg0, %arg1 : !lhs, !rhs) outs(%1 : !res) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+    %3 = arith.mulf %arg3, %arg4 : f32
+    %4 = arith.addf %arg5, %3 : f32
+    linalg.yield %4 : f32
+  } -> !res
+  return %2 : !res
+}
+
+// -----
+
+!lhs = tensor<128x80x32xf32>
+!rhs = tensor<128x32x320xf32>
+!res = tensor<128x80x320xf32>
+
+func.func @batch_matmul(%arg0: !lhs, %arg1: !rhs, %arg2: !res) -> !res {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : !res
+  // expected-remark @below {{fill}}
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !res) -> !res
+  // expected-remark @below {{batch matmul}}
+  %2 = linalg.batch_matmul ins(%arg0, %arg1 : !lhs, !rhs) outs(%1 : !res) -> !res
+  return %2 : !res
+}
+
+// -----
+
+!lhs = tensor<80x128x32xf32>
+!rhs = tensor<128x32x320xf32>
+!res = tensor<80x320x128xf32>
+
+func.func @batch_matmul(%arg0: !lhs, %arg1: !rhs, %arg2: !res) -> !res {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : !res
+  // expected-remark @below {{fill}}
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !res) -> !res
+  // expected-remark @below {{batch matmul}}
+  %2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  } ins(%arg0, %arg1 : !lhs, !rhs) outs(%1 : !res) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+    %3 = arith.mulf %arg3, %arg4 : f32
+    %4 = arith.addf %arg5, %3 : f32
+    linalg.yield %4 : f32
+  } -> !res
+  return %2 : !res
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index ec16caa..91f70d3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -35,6 +35,7 @@
             "reduction_pipeline.mlir",
             "rocdl_pipeline_test.mlir",
             "set_transform_strategy.mlir",
+            "set_transform_strategy_batch_matmul.mlir",
             "set_transform_strategy_pad.mlir",
             "illegal_configuration.mlir",
             "layout_analysis_and_distribution.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 17c7ef6..731adcc 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -37,6 +37,7 @@
     "reduction_pipeline_transform.mlir"
     "rocdl_pipeline_test.mlir"
     "set_transform_strategy.mlir"
+    "set_transform_strategy_batch_matmul.mlir"
     "set_transform_strategy_pad.mlir"
     "tensor_pad.mlir"
     "tensorcore_vectorization.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir
new file mode 100644
index 0000000..38fe1ba
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy_batch_matmul.mlir
@@ -0,0 +1,231 @@
+// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" \
+// RUN:     --iree-codegen-llvmgpu-enable-transform-dialect-jit=1 --iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy |\
+// RUN:   FileCheck %s --check-prefixes=CHECK,DEFAULT
+
+// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target{test-lowering-configuration})))" \
+// RUN:     --iree-codegen-llvmgpu-enable-transform-dialect-jit=1 --iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy \
+// RUN: -td-matmul-strategy-blk-sizes=128,64,32,2 \
+// RUN: -td-matmul-strategy-reduc-size=8 \
+// RUN: -td-matmul-strategy-num-threads=32,4,1 \
+// RUN: -td-matmul-strategy-num-warps=1,4,1 \
+// RUN: -td-matmul-strategy-use-async-copies=true \
+// RUN: -td-matmul-strategy-pipeline-depth=3 \
+// RUN: -td-matmul-strategy-use-mma-sync=false \
+// RUN: -td-matmul-strategy-use-fma=true \
+// RUN:   | FileCheck %s --check-prefixes=CHECK,OPTIONS
+
+#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
+#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}>
+module attributes {hal.device.targets = [#device_target_cuda]} {
+  hal.executable private @batch_matmul_dispatch_0 {
+    hal.executable.variant public @cuda_nvptx_fb, target = #executable_target_cuda_nvptx_fb {
+      hal.executable.export public @batch_matmul_dispatch_0_generic_128x80x320x32_f32 ordinal(0) layout(#pipeline_layout) {
+      ^bb0(%arg0: !hal.device):
+        %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
+        hal.return %x, %y, %z : index, index, index
+      }
+      builtin.module {
+        func.func @batch_matmul_dispatch_0_generic_128x80x320x32_f32() {
+          %c0 = arith.constant 0 : index
+          %cst = arith.constant 0.000000e+00 : f32
+          %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x80x32xf32>>
+          %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128x32x320xf32>>
+          %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<128x80x320xf32>>
+          %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [128, 80, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<128x80x32xf32>> -> tensor<128x80x32xf32>
+          %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [128, 32, 320], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<128x32x320xf32>> -> tensor<128x32x320xf32>
+          %5 = tensor.empty() : tensor<128x80x320xf32>
+          %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32>
+          %7 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<128x80x32xf32>, tensor<128x32x320xf32>) outs(%6 : tensor<128x80x320xf32>) {
+          ^bb0(%in: f32, %in_0: f32, %out: f32):
+            %8 = arith.mulf %in, %in_0 : f32
+            %9 = arith.addf %out, %8 : f32
+            linalg.yield %9 : f32
+          } -> tensor<128x80x320xf32>
+          flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [128, 80, 320], strides = [1, 1, 1] : tensor<128x80x320xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x80x320xf32>>
+          return
+        }
+      }
+    }
+  }
+  func.func @batch_matmul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
+    %c1310720 = arith.constant 1310720 : index
+    %c5242880 = arith.constant 5242880 : index
+    %c13107200 = arith.constant 13107200 : index
+    %c0 = arith.constant 0 : index
+    %c320 = arith.constant 320 : index
+    %c553648160_i32 = arith.constant 553648160 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c128 = arith.constant 128 : index
+    %c80 = arith.constant 80 : index
+    %c32 = arith.constant 32 : index
+    hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input 0") shape([%c128, %c80, %c32]) type(%c553648160_i32) encoding(%c1_i32)
+    %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<128x80x32xf32> in !stream.resource<external>{%c1310720}
+    hal.buffer_view.assert<%arg1 : !hal.buffer_view> message("input 1") shape([%c128, %c32, %c320]) type(%c553648160_i32) encoding(%c1_i32)
+    %1 = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<128x32x320xf32> in !stream.resource<external>{%c5242880}
+    %2 = stream.resource.alloc uninitialized : !stream.resource<external>{%c13107200}
+    %3 = stream.cmd.execute with(%0 as %arg3: !stream.resource<external>{%c1310720}, %1 as %arg4: !stream.resource<external>{%c5242880}, %2 as %arg5: !stream.resource<external>{%c13107200}) {
+      stream.cmd.dispatch @batch_matmul_dispatch_0::@cuda_nvptx_fb::@batch_matmul_dispatch_0_generic_128x80x320x32_f32 {
+        ro %arg3[%c0 for %c1310720] : !stream.resource<external>{%c1310720},
+        ro %arg4[%c0 for %c5242880] : !stream.resource<external>{%c5242880},
+        wo %arg5[%c0 for %c13107200] : !stream.resource<external>{%c13107200}
+      } attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]}
+    } => !stream.timepoint
+    %4 = stream.timepoint.await %3 => %2 : !stream.resource<external>{%c13107200}
+    %5 = stream.tensor.export %4 : tensor<128x80x320xf32> in !stream.resource<external>{%c13107200} -> !hal.buffer_view
+    return %5 : !hal.buffer_view
+  }
+}
+
+
+// CHECK: transform.sequence  failures(propagate) {
+// CHECK:   transform.iree.register_match_callbacks
+// CHECK:   %[[MATCH:.+]]:2 = transform.iree.match_callback failures(propagate) "batch_matmul"
+// CHECK:   %[[FORALL:.+]], %[[TILED:.+]] = transform.structured.tile_to_forall_op %[[MATCH]]#1
+// DEFAULT:   num_threads [] tile_sizes [64, 64, 1](mapping = [#gpu.block<z>, #gpu.block<y>, #gpu.block<x>])
+// OPTIONS:   num_threads [] tile_sizes [128, 64, 32](mapping = [#gpu.block<z>, #gpu.block<y>, #gpu.block<x>])
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   %[[FUSED:.+]], %[[CONTAINING:.+]] = transform.structured.fuse_into_containing_op %[[MATCH]]#0 into %[[FORALL]]
+// CHECK:   transform.iree.populate_workgroup_count_region_using_num_threads_slice %[[FORALL]]
+// CHECK:   %[[TILED_LINALG:.+]], %[[LOOPS:.+]] = transform.structured.tile %tiled_op
+// DEFAULT:   [0, 0, 0, 16]
+// OPTIONS:   [0, 0, 0, 8]
+// CHECK:   %[[PADDED:.+]] = transform.structured.pad %tiled_linalg_op 
+// CHECK:     pack_paddings = [1, 1, 1, 1], pad_to_multiple_of = [1, 1, 1, 1], padding_dimensions = [0, 1, 2, 3]
+// CHECK:     padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]}
+// CHECK:   %[[V3:.+]] = get_producer_of_operand %[[PADDED]][2]
+// CHECK:   transform.structured.hoist_pad %{{.*}} by 1 loops
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   %[[FILL:.+]] = transform.structured.match ops{["linalg.fill"]}
+// CHECK:   apply_patterns 
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.structured.match ops{["tensor.parallel_insert_slice"]}
+// CHECK:   transform.structured.insert_slice_to_copy
+// CHECK:   %[[LHS:.+]] = get_producer_of_operand %[[PADDED]][0]
+// CHECK:   %[[RHS:.+]] = get_producer_of_operand %[[PADDED]][1]
+// CHECK:   %[[RHS_DPS:.+]] = transform.structured.rewrite_in_destination_passing_style %[[RHS]]
+
+// CHECK:   transform.structured.tile_to_forall_op %[[LHS]] 
+// DEFAULT:  num_threads [1, 32, 4] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// OPTIONS:  num_threads [1, 64, 2] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.structured.match ops{["scf.if"]}
+// CHECK:   transform.scf.take_assumed_branch %{{.*}} take_else_branch
+
+// CHECK:   transform.structured.tile_to_forall_op %[[RHS_DPS]]  
+// DEFAULT:  num_threads [8, 16, 1] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// OPTIONS:  num_threads [2, 8, 8] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// CHECK:   apply_patterns 
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+
+// CHECK:   transform.structured.tile_to_forall_op
+// DEFAULT:  num_threads [2, 64, 1] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// OPTIONS:  num_threads [1, 16, 8] tile_sizes [](mapping = [#gpu.linear<z>, #gpu.linear<y>, #gpu.linear<x>])
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+
+// CHECK:   transform.structured.tile_to_forall_op
+// DEFAULT:  num_threads [1, 2, 64] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// OPTIONS:  num_threads [1, 4, 32] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// CHECK:   apply_patterns 
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+
+// CHECK:   %forall_op_8, %tiled_op_9 = transform.structured.tile_to_forall_op %[[FILL]]
+// DEFAULT:   num_threads [1, 2, 64] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// OPTIONS:   num_threads [1, 4, 32] tile_sizes [](mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>])
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+
+// CHECK:   transform.structured.masked_vectorize
+// DEFAULT:   vector_sizes [64, 2, 4]
+// OPTIONS:   vector_sizes [128, 1, 4]
+// CHECK:   transform.structured.masked_vectorize
+// DEFAULT:   vector_sizes [32, 1, 1]
+// OPTIONS:   vector_sizes [128, 4, 4]
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.vector.lower_masked_transfers
+// CHECK:   transform.structured.vectorize
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.canonicalization
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.iree.eliminate_empty_tensors
+
+// CHECK:   transform.iree.bufferize {target_gpu}
+// CHECK:   transform.iree.erase_hal_descriptor_type_from_memref
+// CHECK:   transform.iree.apply_buffer_optimizations
+// CHECK:   transform.iree.forall_to_workgroup
+// CHECK:   transform.iree.map_nested_forall_to_gpu_threads
+// DEFAULT:  workgroup_dims = [64, 2, 1] warp_dims = [2, 2, 1]
+// OPTIONS:  workgroup_dims = [32, 4, 1] warp_dims = [1, 4, 1]
+// CHECK:   transform.iree.eliminate_gpu_barriers
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.iree.hoist_static_alloc
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.memref.fold_memref_alias_ops
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.memref.extract_address_computations
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.linalg.tiling_canonicalization
+// CHECK:     transform.apply_patterns.iree.fold_fill_into_pad
+// CHECK:     transform.apply_patterns.scf.for_loop_canonicalization
+// CHECK:     transform.apply_patterns.canonicalization
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.iree.synchronize_loop
+// CHECK:   transform.structured.hoist_redundant_vector_transfers
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.iree.apply_buffer_optimizations
+// CHECK:   %30 = transform.iree.eliminate_gpu_barriers
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.memref.fold_memref_alias_ops
+
+// CHECK:   transform.memref.multibuffer
+// DEFAULT:   factor = 2
+// OPTIONS:   factor = 3
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.vector.transfer_to_scf   max_transfer_rank = 1 full_unroll = true
+// CHECK:   apply_patterns 
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.iree.create_async_groups
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
+// CHECK:   transform.iree.pipeline_shared_memory_copies
+// DEFAULT:   depth = 2
+// OPTIONS:   depth = 3
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.vector.lower_masks
+// CHECK:   apply_patterns
+// CHECK:     transform.apply_patterns.vector.materialize_masks
+// CHECK:   apply_patterns
+// CHECK:   transform.iree.apply_licm
+// CHECK:   transform.iree.apply_cse
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
index 364e1ec..c3a957b 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Codegen/TransformStrategies/GPU/CopyMapping.h"
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -104,7 +105,8 @@
     MLIRContext *ctx, int totalNumThreads, int64_t alignment,
     ArrayRef<int64_t> copySizes, bool favorPredication,
     int64_t elementalBitWidth) {
-  assert(copySizes.size() == 2 && "only 2-D copy supported for now");
+  assert(!copySizes.empty() && copySizes.size() <= 3 &&
+         "only 1,2,3-D copies are supported for now");
   FailureOr<CopyMapping> maybeCopyMapping =
       CopyMapping::numThreadsForCopy(totalNumThreads, alignment, copySizes,
                                      favorPredication, elementalBitWidth);
@@ -117,18 +119,24 @@
         elementalBitWidth);
   }
   assert(succeeded(maybeCopyMapping) && "failed to compute copy mapping");
-  assert(maybeCopyMapping->numThreads.size() == 2 &&
-         "compute copy mapping expected size-2");
-  int64_t numThreadsY = maybeCopyMapping->numThreads[0];
-  int64_t numThreadsX = maybeCopyMapping->numThreads[1];
-  int64_t sizeY = copySizes[0];
-  int64_t sizeX = copySizes[1];
-  MappingInfo res{
-      /*numThreads=*/{numThreadsY, numThreadsX},
-      /*tilecopySizes=*/
-      {mlir::ceilDiv(sizeY, numThreadsY), mlir::ceilDiv(sizeX, numThreadsX)},
-      /*threadMapping=*/{linearIdY(ctx), linearIdX(ctx)},
-      /*vectorSize=*/maybeCopyMapping->vectorSize};
+  assert(maybeCopyMapping->numThreads.size() == copySizes.size() &&
+         "compute copy mapping expected same number of threads and copy sizes");
+
+  SmallVector<int64_t> tileSizes = llvm::to_vector(llvm::map_range(
+      llvm::zip(copySizes, maybeCopyMapping->numThreads), [](auto &&pair) {
+        int64_t size, numThreads;
+        std::tie(size, numThreads) = pair;
+        return mlir::ceilDiv(size, numThreads);
+      }));
+  SmallVector<Attribute> allThreadMappings{linearIdZ(ctx), linearIdY(ctx),
+                                           linearIdX(ctx)};
+  auto threadMapping =
+      llvm::to_vector(ArrayRef(allThreadMappings).take_back(tileSizes.size()));
+
+  MappingInfo res{/*numThreads=*/maybeCopyMapping->numThreads,
+                  /*tilecopySizes=*/tileSizes,
+                  /*threadMapping=*/threadMapping,
+                  /*vectorSize=*/maybeCopyMapping->vectorSize};
   LLVM_DEBUG(res.print(DBGS()); llvm::dbgs() << "\n");
   return res;
 }
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
index 59956ba..f92f8f2 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -36,6 +37,7 @@
 using iree_compiler::buildTileFuseDistToForallWithNumThreads;
 using iree_compiler::buildTileFuseDistToForallWithTileSizes;
 using iree_compiler::TileToForallAndFuseAndDistributeResult;
+using iree_compiler::gpu::BatchMatmulStrategy;
 using iree_compiler::gpu::buildBufferize;
 using iree_compiler::gpu::buildConvertToAsyncCopies;
 using iree_compiler::gpu::buildConvertToTensorCoreOp;
@@ -97,6 +99,46 @@
   return success();
 }
 
+LogicalResult BatchMatmulStrategy::validate(const GPUModel &gpuModel) const {
+  if (failed(MatmulStrategy::validate(gpuModel))) {
+    return failure();
+  }
+
+  if (batch() < blockTileBatch()) {
+    return emitError(UnknownLoc::get(ctx))
+           << "batch(" << batch() << ") < blockTileBatch(" << blockTileBatch()
+           << "): this is at risk of not vectorizing and is NYI";
+  }
+
+  // Only a single outermost batch dimension is currently supported.
+  if (captures.batches().size() != 1 || captures.batches().back() != 0) {
+    LDBG("--Couldn't find single outermost batch dimension\n");
+    return failure();
+  }
+
+  if (blockTileSizes.size() < 3) {
+    LDBG("--Not enough block tile sizes\n");
+    return failure();
+  }
+
+  if (numWarps.size() < 3) {
+    LDBG("--Not enough num warps\n");
+    return failure();
+  }
+
+  if (numThreads.size() < 3) {
+    LDBG("--Not enough num threads\n");
+    return failure();
+  }
+
+  if (!useFma) {
+    LDBG("--Only FMA is supported for batch matmul atm\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static std::tuple<Value, Value, Value, Value>
 buildMatmulStrategyBlockDistribution(ImplicitLocOpBuilder &b, Value variantH,
                                      const MatmulStrategy &strategy) {
@@ -130,22 +172,26 @@
                          tileResult.tiledOpH, Value(), tileResult.forallH);
 }
 
-void iree_compiler::gpu::buildMatmulTensorCoreStrategy(
-    ImplicitLocOpBuilder &b, Value variantH, const MatmulStrategy &strategy) {
-  LLVM_DEBUG(strategy.print(DBGS()));
+/// Builds the common part of the schedule for matmuls and batched matmuls.
+static void
+buildCommonMatmulLikeThreadSchedule(ImplicitLocOpBuilder &b, Value variantH,
+                                    Value fillH, Value matmulH,
+                                    const MatmulStrategy &strategy) {
+  using mlir::iree_compiler::buildLowerVectorMasksAndCleanup;
+  using mlir::iree_compiler::buildTileFuseToScfFor;
+  using namespace mlir::iree_compiler::gpu;
 
-  // Step 1. Apply block-level part of the strategy, keeps everything fused.
-  auto [fillH, matmulH, maybeTiledTrailingHBlock, forall] =
-      buildMatmulStrategyBlockDistribution(b, variantH, strategy);
-  // Tile reduction loop.
-  SmallVector<int64_t> tileSizes{0, 0, strategy.reductionTileSize};
+  // Tile the reduction loop (last in the list).
+  SmallVector<int64_t> tileSizes(strategy.captures.matmulOpSizes.size() - 1, 0);
+  tileSizes.push_back(strategy.reductionTileSize);
+
   // Avoid canonicalizing before the pad to avoid folding away the extract_slice
   // on the output needed to hoist the output pad.
   auto tileReductionResult = buildTileFuseToScfFor(
       b, variantH, matmulH, {}, getAsOpFoldResult(b.getI64ArrayAttr(tileSizes)),
       /*canonicalize=*/false);
 
-  // Step 2. Pad the matmul op.
+  // Step 2. Pad the (batch) matmul op.
   auto paddedMatmulOpH =
       buildPad(b, tileReductionResult.tiledOpH,
                strategy.getZeroPadAttrFromElementalTypes(b).getValue(),
@@ -219,3 +265,53 @@
   // Step 13. Late lowerings and cleanups.
   buildLowerVectorMasksAndCleanup(b, funcH);
 }
+
+void iree_compiler::gpu::buildMatmulTensorCoreStrategy(
+    ImplicitLocOpBuilder &b, Value variantH, const MatmulStrategy &strategy) {
+  LLVM_DEBUG(strategy.print(DBGS()));
+
+  // Step 1. Apply block-level part of the strategy, keeps everything fused.
+  auto [fillH, matmulH, maybeTiledTrailingHBlock, forall] =
+      buildMatmulStrategyBlockDistribution(b, variantH, strategy);
+  buildCommonMatmulLikeThreadSchedule(b, variantH, fillH, matmulH, strategy);
+}
+
+/// Builds the transform dialect operations distributing batch matmul across
+/// blocks according to the given strategy.
+static std::tuple<Value, Value, Value>
+buildBatchMatmulStrategyBlockDistribution(ImplicitLocOpBuilder &b,
+                                          Value variantH,
+                                          const BatchMatmulStrategy &strategy) {
+  b.create<RegisterMatchCallbacksOp>();
+  auto [fillH, bmmH] = unpackRegisteredMatchCallback<2>(
+      b, "batch_matmul", transform::FailurePropagationMode::Propagate,
+      variantH);
+
+  MappingInfo blockMapping = strategy.getBlockMapping();
+  TileToForallAndFuseAndDistributeResult tileResult =
+      buildTileFuseDistToForallWithTileSizes(
+          /*builder=*/b,
+          /*variantH=*/variantH,
+          /*rootH=*/bmmH,
+          /*opsToFuseH=*/fillH,
+          /*tileSizes=*/
+          getAsOpFoldResult(b.getI64ArrayAttr(blockMapping.tileSizes)),
+          /*threadDimMapping=*/
+          b.getArrayAttr(blockMapping.threadMapping));
+
+  // Handle the workgroup count region.
+  b.create<IREEPopulateWorkgroupCountRegionUsingNumThreadsSliceOp>(
+      tileResult.forallH);
+  return std::make_tuple(tileResult.resultingFusedOpsHandles.front(),
+                         tileResult.tiledOpH, tileResult.forallH);
+}
+
+void iree_compiler::gpu::buildBatchMatmulStrategy(
+    ImplicitLocOpBuilder &b, Value variantH,
+    const BatchMatmulStrategy &strategy) {
+  LLVM_DEBUG(strategy.print(DBGS()));
+
+  auto [fillH, matmulH, forallH] =
+      buildBatchMatmulStrategyBlockDistribution(b, variantH, strategy);
+  buildCommonMatmulLikeThreadSchedule(b, variantH, fillH, matmulH, strategy);
+}
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
index 8f2bb09..2021b18 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/MatmulTensorCoreStrategy.h
@@ -172,8 +172,133 @@
                        /*vectorSize=*/std::nullopt};
   }
 
-  void print(llvm::raw_ostream &os) const;
-  LLVM_DUMP_METHOD void dump() const;
+  void print(llvm::raw_ostream &os) const override;
+  LLVM_DUMP_METHOD void dump() const override;
+};
+
+/// An extension of the matmul strategy to batched matrix multiplications.
+class BatchMatmulStrategy : public MatmulStrategy {
+public:
+  /// Construct the default strategy, pulling options from the command-line
+  /// arguments if provided and using the defaults otherwise.
+  BatchMatmulStrategy(MLIRContext *context, const GPUModel &gpuModel,
+                      const transform_ext::MatchedMatmulCaptures &captures)
+      : MatmulStrategy(context, captures, gpuModel) {
+    initDefaultValues(gpuModel);
+  }
+
+  /// Initialize the default values of the strategy.
+  void initDefaultValues(const GPUModel &gpuModel) override {
+    // First, initialize as if this was a simple matmul.
+    MatmulStrategy::initDefaultValues(gpuModel);
+
+    // Make sure we pad along all dimensions.
+    paddingDimensions = {0, 1, 2, 3};
+    packingDimensions = {1, 1, 1, 1};
+  }
+
+  /// Check that the strategy is valid for the captures and the model.
+  LogicalResult validate(const GPUModel &gpuModel) const override;
+
+  /// Named accessors to shapes.
+  int64_t batch() const { return captures.matmulOpSizes[0]; }
+  int64_t m() const override { return captures.matmulOpSizes[1]; }
+  int64_t n() const override { return captures.matmulOpSizes[2]; }
+  int64_t k() const override { return captures.matmulOpSizes[3]; }
+
+  /// Named accessors to block tile sizes associated with shapes.
+  int64_t blockTileBatch() const { return blockTileSizes[0]; }
+  int64_t blockTileM() const override { return blockTileSizes[1]; }
+  int64_t blockTileN() const override { return blockTileSizes[2]; }
+
+  /// Number of threads to use.
+  int64_t numThreadsX() const { return numThreads[0]; }
+  int64_t numThreadsY() const { return numThreads[1]; }
+  int64_t numThreadsZ() const { return numThreads[2]; }
+
+  /// Number of warps to use.
+  int64_t numWarpsX() const override { return numWarps[0]; }
+  int64_t numWarpsY() const override { return numWarps[1]; }
+  int64_t numWarpsZ() const { return numWarps[2]; }
+
+  MappingInfo getBlockMapping() const override {
+    return MappingInfo{
+        /*numThreads=*/
+        {},
+        /*tileSizes=*/{blockTileBatch(), blockTileM(), blockTileN()},
+        /*threadMapping=*/{blockZ(ctx), blockY(ctx), blockX(ctx)},
+        /*vectorSize=*/std::nullopt};
+  }
+
+  // LHS copy is batch x M x K.
+  MappingInfo lhsCopyMapping() const override {
+    // TODO: generalize to transpositions, here and below.
+    return CopyMapping::getMappingInfo(
+        ctx, totalNumThreads(), k(),
+        {blockTileBatch(), blockTileM(), reductionTileSize},
+        /*favorPredication=*/false,
+        captures.lhsElementType.getIntOrFloatBitWidth());
+  }
+
+  // RHS copy is batch x K x N.
+  MappingInfo rhsCopyMapping() const override {
+    return CopyMapping::getMappingInfo(
+        ctx, totalNumThreads(), n(),
+        {blockTileBatch(), reductionTileSize, blockTileN()},
+        /*favorPredication=*/false,
+        captures.rhsElementType.getIntOrFloatBitWidth());
+  }
+
+  // RES copy is batch x M x N.
+  MappingInfo resCopyMapping() const override {
+    return CopyMapping::getMappingInfo(
+        ctx, totalNumThreads(), n(),
+        {blockTileBatch(), blockTileM(), blockTileN()},
+        /*favorPredication=*/false,
+        captures.outputElementType.getIntOrFloatBitWidth());
+  }
+
+  /// Validates the mapping for one of the lhs, rhs or res copies.
+  LogicalResult validateCopyMapping(const MappingInfo &mapping,
+                                    StringRef name) const {
+    int64_t threadsUsed =
+        std::accumulate(mapping.numThreads.begin(), mapping.numThreads.end(), 1,
+                        std::multiplies<int64_t>());
+    if (totalNumThreads() < threadsUsed) {
+      InFlightDiagnostic diag = emitError(UnknownLoc::get(ctx))
+                                << "too many threads used for transferring "
+                                << name;
+
+      std::string str;
+      llvm::raw_string_ostream os(str);
+      llvm::interleave(mapping.numThreads, os, " * ");
+      os << " >= " << totalNumThreads();
+      diag.attachNote() << os.str();
+      return diag;
+    }
+
+    return success();
+  }
+
+  /// Check that the mapping computed for a copy is valid.
+  LogicalResult validateLhsCopyMapping() const override {
+    return validateCopyMapping(lhsCopyMapping(), "lhs");
+  }
+  LogicalResult validateRhsCopyMapping() const override {
+    return validateCopyMapping(rhsCopyMapping(), "rhs");
+  }
+  LogicalResult validateResCopyMapping() const override {
+    return validateCopyMapping(resCopyMapping(), "result");
+  }
+
+  // Compute is of size batch x M x N.
+  MappingInfo computeMapping() const override {
+    assert(useFma && "only fma is currently supported");
+    return MappingInfo{{numThreadsZ(), numThreadsY(), numThreadsX()},
+                       {},
+                       {threadZ(ctx), threadY(ctx), threadX(ctx)},
+                       std::nullopt};
+  }
 };
 
 } // namespace gpu
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
index 4b2dfff..2b403b6 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.cpp
@@ -63,9 +63,15 @@
 llvm::cl::opt<bool> clGPUEnableTransformDialectPadStrategy(
     "iree-codegen-llvmgpu-enable-transform-dialect-pad-strategy",
     llvm::cl::desc("activate the pad strategy"), llvm::cl::init(false));
+llvm::cl::opt<bool> clGPUEnableTransformDialectBatchMatmulStrategy(
+    "iree-codegen-llvmgpu-enable-transform-dialect-batch-matmul-strategy",
+    llvm::cl::desc("activate the batch matmul strategy, additional "
+                   "configuration flags are shared with matmul"),
+    llvm::cl::init(false));
 
 // TODO: significantly better namespacing.
 using iree_compiler::gpu::AbstractGemmLikeStrategy;
+using iree_compiler::gpu::BatchMatmulStrategy;
 using iree_compiler::gpu::GPUModel;
 using iree_compiler::gpu::kCudaMaxVectorLoadBitWidth;
 using iree_compiler::gpu::MatmulStrategy;
@@ -337,6 +343,91 @@
   return strategy;
 }
 
+/// Update the strategy to make sure it can be consumed by the codegen. In
+/// particular, make sure that tile sizes are smaller than the problem sizes to
+/// actually trigger tiling and mapping to blocks and threads.
+static void failSafeOverrides(BatchMatmulStrategy &strategy,
+                              const GPUModel &gpuModel) {
+  // Configure the strategy as if for a matmul.
+  failSafeOverrides(static_cast<MatmulStrategy &>(strategy), gpuModel);
+
+  // Failsafe for blockTileBatch to avoid tiling by > size (i.e. no tiling).
+  int64_t blockTileBatch = selectLargestFailsafeValueIfNeeded(
+      /*value=*/strategy.blockTileBatch(),
+      /*limit=*/strategy.batch(),
+      /*thresholds=*/{2, 4, 8, 16, 32, 64, 128},
+      /*failSafeValues=*/{1, 2, 4, 8, 16, 32, 64});
+
+  // Override the matmul configuration to be suitable for batch matmul.
+  // Specifically, prepend the tile size for the batch dimension and force FMA.
+  strategy.blockTileSizes.insert(strategy.blockTileSizes.begin(),
+                                 blockTileBatch);
+
+  strategy.useMmaSync = false;
+  strategy.useWmma = false;
+  strategy.useFma = true;
+}
+
+/// Produce a strategy for the batch matmul characterized by the given capture
+/// list (shapes and types).
+static BatchMatmulStrategy getBatchMatmulConfig(MLIRContext *context,
+                                                MatchedMatmulCaptures &captures,
+                                                const GPUModel &gpuModel) {
+  // Command-line arguments trump everything.
+  BatchMatmulStrategy strategy(context, gpuModel, captures);
+  if (strategy.cliOptionsSpecified)
+    return strategy;
+
+  // TODO: fixed strategies and decision tree/heuristic.
+
+  failSafeOverrides(strategy, gpuModel);
+  return strategy;
+}
+
+/// Match the supported batch matmuls and set the transform dialect strategy for
+/// them.
+static LogicalResult matchAndSetBatchMatmulStrategy(func::FuncOp entryPoint,
+                                                    linalg::LinalgOp op,
+                                                    const GPUModel &gpuModel) {
+  if (!clGPUEnableTransformDialectBatchMatmulStrategy) {
+    LDBG("--Batch matmul strategy flag turned off\n");
+    return failure();
+  }
+
+  StructuredOpMatcher *fill;
+  StructuredOpMatcher *bmm;
+  transform_ext::MatchedMatmulCaptures captures;
+  transform_ext::MatcherContext matcherContext;
+  transform_ext::makeBatchMatmulMatcher(matcherContext, bmm, fill, captures,
+                                        /*mustMatchEntireFunc=*/true);
+  if (!matchPattern(op, *bmm)) {
+    LDBG("--Batch matmul strategy failed to match\n");
+    return failure();
+  }
+
+  if (captures.contractionDims.batch.size() != 1 ||
+      captures.contractionDims.m.size() != 1 ||
+      captures.contractionDims.n.size() != 1 ||
+      captures.contractionDims.k.size() != 1 || captures.batches()[0] != 0 ||
+      captures.m() != 1 || captures.n() != 2 || captures.k() != 3) {
+    LDBG("--Only support batch matmul with b, m, n, k iterator order atm\n");
+    return failure();
+  }
+
+  BatchMatmulStrategy strategy =
+      getBatchMatmulConfig(entryPoint->getContext(), captures, gpuModel);
+  if (failed(strategy.validate(gpuModel))) {
+    LDBG("--Batch matmul strategy failed to validate\n");
+    return failure();
+  }
+
+  iree_compiler::createTransformRegion(entryPoint, [&](ImplicitLocOpBuilder &b,
+                                                       Value variantH) {
+    return iree_compiler::gpu::buildBatchMatmulStrategy(b, variantH, strategy);
+  });
+  return success();
+}
+
 static LogicalResult matchAndSetMatmulStrategy(func::FuncOp entryPoint,
                                                linalg::LinalgOp op,
                                                const GPUModel &gpuModel) {
@@ -551,6 +642,11 @@
     LDBG("Activate matmul\n");
     return success();
   }
+  if (succeeded(
+          matchAndSetBatchMatmulStrategy(entryPoint, linalgOp, gpuModel))) {
+    LDBG("Activate batch matmul\n");
+    return success();
+  }
   // TODO: Add more transform dialect strategy for other kind of dispatch
   // regions.
   LDBG("No suitable strategy found\n");
diff --git a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h
index cc75416..852d5f1 100644
--- a/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h
+++ b/compiler/src/iree/compiler/Codegen/TransformStrategies/GPU/Strategies.h
@@ -17,6 +17,7 @@
 namespace gpu {
 
 /// Forward declarations of all supported strategies.
+struct BatchMatmulStrategy;
 struct MatmulStrategy;
 class PadStrategy;
 class SmallReductionStrategy;
@@ -65,6 +66,14 @@
                                    const MatmulStrategy &strategy);
 
 //===--------------------------------------------------------------------===//
+// Batch matmul strategies.
+//===--------------------------------------------------------------------===//
+/// Entry point to build the transform IR corresponding to an FMA-based strategy
+/// for linalg.fill + linalg.batch_matmul.
+void buildBatchMatmulStrategy(ImplicitLocOpBuilder &b, Value variantH,
+                              const BatchMatmulStrategy &strategy);
+
+//===--------------------------------------------------------------------===//
 // Pad strategies.
 //===--------------------------------------------------------------------===//
 /// Entry point to build the transform IR corresponding to a simple pad
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
index 8f98e30..c5d48c3 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Transforms/TransformMatchers.h
@@ -72,6 +72,17 @@
   using Base::Base;
 };
 
+/// Captures the indexing maps of the target operation.
+struct CaptureIndexingMaps : public CaptureStaticValue<SmallVector<AffineMap>> {
+  using Base::Base;
+};
+
+/// Captures the contraction dimensions of the target operation.
+struct CaptureContractionDims
+    : public CaptureStaticValue<mlir::linalg::ContractionDimensions> {
+  using Base::Base;
+};
+
 /// Captures the convolution dimensions of the target operation.
 struct CaptureConvDims
     : public CaptureStaticValue<mlir::linalg::detail::ConvolutionDimensions> {
@@ -627,6 +638,7 @@
   /// constant value.
   StructuredOpMatcher &rank(NumGreaterEqualTo minRank);
   StructuredOpMatcher &rank(NumLowerEqualTo maxRank);
+  StructuredOpMatcher &rank(NumEqualsTo exactRank);
 
   /// Adds a predicate checking that the given iteration space dimension is
   /// static/dynamic. The dimension index may be negative, in which case
@@ -667,6 +679,8 @@
   StructuredOpMatcher &rank(CaptureRank capture);
   StructuredOpMatcher &dim(int64_t dimension, CaptureDim capture);
   StructuredOpMatcher &dim(AllDims tag, CaptureDims captures);
+  StructuredOpMatcher &indexingMaps(CaptureIndexingMaps indexingMaps);
+  StructuredOpMatcher &contractionDims(CaptureContractionDims contractionDims);
   StructuredOpMatcher &convolutionDims(CaptureConvDims convDims);
 
   //===-------------------------------------------------------------------===//
@@ -863,6 +877,16 @@
   /// using block arguments in order.
   StructuredOpMatcher &passThroughOp();
 
+  /// Check if the body of the linalg op implements a contraction of the kind
+  ///   result <ReductionOpTy>= input1 <ElemOpTy> input2
+  template <typename ElemOpTy, typename ReductionOpTy>
+  StructuredOpMatcher &hasContractionBody() {
+    return hasContractionBody(
+        [](Operation *op) { return isa<ElemOpTy>(op); },
+        [](Operation *op) { return isa<ReductionOpTy>(op); },
+        ElemOpTy::getOperationName(), ReductionOpTy::getOperationName());
+  }
+
 private:
   /// Non-template implementations of nested predicate builders for inputs,
   /// outputs and results. Should not be called directly.
@@ -881,6 +905,13 @@
   // Common util for constant matcher.
   StructuredOpMatcher &input(int64_t position,
                              std::function<bool(llvm::APFloat)> floatValueFn);
+
+  /// Non-template implementation of hasContractionBody. Takes callbacks for
+  /// checking operation kinds and names for error reporting.
+  StructuredOpMatcher &
+  hasContractionBody(function_ref<bool(Operation *)> isaElemOpTy,
+                     function_ref<bool(Operation *)> isaReductionOpTy,
+                     StringRef elemOpName, StringRef reductionOpName);
 };
 
 /// Creates a matcher of an arbitrary structured op.
@@ -1009,8 +1040,36 @@
 };
 
 struct MatchedMatmulCaptures {
+  linalg::ContractionDimensions contractionDims = {};
   Type lhsElementType, rhsElementType, outputElementType;
   SmallVector<int64_t> matmulOpSizes = {};
+  SmallVector<AffineMap> indexingMaps;
+
+  /// Helper functions.
+  int64_t rank() const { return matmulOpSizes.size(); }
+  /// Return all batches.
+  ArrayRef<unsigned> batches() const { return contractionDims.batch; }
+  /// Return the most minor candidate dimension for `m`.
+  int64_t m() const { return contractionDims.m.back(); }
+  /// Return the most minor candidate dimension for `n`.
+  int64_t n() const { return contractionDims.n.back(); }
+  /// Return the most minor candidate dimension for `k`.
+  int64_t k() const { return contractionDims.k.back(); }
+  /// AffineMap for indexing into the LHS.
+  AffineMap lhsIndexing() const {
+    assert(indexingMaps.size() == 3 && "expected 3 indexing maps");
+    return indexingMaps[0];
+  }
+  /// AffineMap for indexing into the RHS.
+  AffineMap rhsIndexing() const {
+    assert(indexingMaps.size() == 3 && "expected 3 indexing maps");
+    return indexingMaps[1];
+  }
+  /// AffineMap for indexing into the RES.
+  AffineMap resIndexing() const {
+    assert(indexingMaps.size() == 3 && "expected 3 indexing maps");
+    return indexingMaps[2];
+  }
 };
 
 /// Creates a group of matchers for:
@@ -1046,6 +1105,19 @@
                        MatchedMatmulCaptures &captures,
                        bool mustMatchEntireFunc);
 
+/// Create a group of matchers for a batch matmul with a fill:
+///
+///  batch_matmul(*, *, fill())
+///
+/// and capture various useful quantities. If `mustMatchEntireFunc` is set, the
+/// matcher additionally checks if all tileable operations in the function are
+/// captured.
+void makeBatchMatmulMatcher(transform_ext::MatcherContext &matcherContext,
+                            transform_ext::StructuredOpMatcher *&bmmCapture,
+                            transform_ext::StructuredOpMatcher *&fillCapture,
+                            transform_ext::MatchedMatmulCaptures &captures,
+                            bool mustMatchEntireFunc);
+
 /// Create a group of matchers for a different code sequence of operations
 /// matching exactly a softmax operation.
 ///
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 7e2963b..ed3dfd6 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -856,6 +856,58 @@
   return emitSilenceableFailure(loc) << "failed to match";
 }
 
+/// Match callback for linalg.batch_matmul and its linalg.generic equivalent fed
+/// by a linalg.fill.
+///
+/// Input handles:
+///
+///   - the container op, must be associated with one operation.
+///
+/// Output handles:
+///
+///   - the fill op initializing the output;
+///   - the main compute op.
+static DiagnosedSilenceableFailure
+batchMatmulCallback(transform_ext::MatchCallbackResult &res, Location loc,
+                    const mlir::transform::TransformState &state,
+                    ValueRange handles) {
+  if (handles.size() != 1 ||
+      !llvm::hasSingleElement(state.getPayloadOps(handles[0]))) {
+    return emitSilenceableFailure(loc)
+           << "expected one handle to one operation";
+  }
+
+  transform_ext::StructuredOpMatcher *pattern, *fill;
+  transform_ext::MatchedMatmulCaptures ignore;
+  transform_ext::MatcherContext matcherContext;
+  transform_ext::makeBatchMatmulMatcher(matcherContext, pattern, fill, ignore,
+                                        /*mustMatchEntireFunc*/ true);
+
+  // TODO: need a mechanism for this to go around the entire IR,
+  // potentially with list matches for each group.
+  Operation *root = *state.getPayloadOps(handles[0]).begin();
+
+  WalkResult walkResult = root->walk([&](Operation *op) {
+    pattern->resetCapture();
+    if (!matchPattern(op, *pattern))
+      return WalkResult::advance();
+
+    // TODO: notify properly
+    LLVM_DEBUG({
+      DBGS() << "fill:" << fill->getCaptured() << "\n";
+      DBGS() << "pattern: " << pattern->getCaptured() << "\n";
+    });
+
+    res.addPayloadGroup({fill->getCaptured()});
+    res.addPayloadGroup({pattern->getCaptured()});
+    return WalkResult::interrupt();
+  });
+
+  if (walkResult.wasInterrupted())
+    return DiagnosedSilenceableFailure::success();
+  return emitSilenceableFailure(loc) << "failed to match batch matmul";
+}
+
 /// Match callback for a tensor.pad. Matches *the first* occurrence of such pad
 /// within an op associated with the given handle.
 ///
@@ -922,6 +974,7 @@
                             testShapedValueMatcherCallback);
   registry.registerCallback("convolution", convolutionCallback);
   registry.registerCallback("matmul", matmulCallback);
+  registry.registerCallback("batch_matmul", batchMatmulCallback);
   registry.registerCallback("pad", wrapAsEntireFuncMatch(padCallback));
   registry.registerCallback("reduction",
                             wrapAsEntireFuncMatch(reductionCallback));
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 59f085e..0d2293d 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -424,6 +424,14 @@
   });
 }
 
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::rank(NumEqualsTo exactRank) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "rank == " << exactRank.value);
+    return linalgOp.getNumLoops() == exactRank.value;
+  });
+}
+
 StringRef stringifyShapeKind(transform_ext::ShapeKind kind) {
   switch (kind) {
   case transform_ext::ShapeKind::Static:
@@ -576,16 +584,40 @@
 }
 
 transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::indexingMaps(
+    CaptureIndexingMaps indexingMaps) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture indexing maps");
+    indexingMaps.value = linalgOp.getIndexingMapsArray();
+    return true;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::contractionDims(
+    CaptureContractionDims contractionDims) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
+    LLVM_DEBUG(DBGS() << "capture contraction dimensions");
+    StringRef convMessage = linalg::detail::getMatchContractionMessage(
+        mlir::linalg::detail::isContractionInterfaceImpl(
+            linalgOp, &contractionDims.value));
+    if (convMessage.empty())
+      return true;
+    LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")");
+    return false;
+  });
+}
+
+transform_ext::StructuredOpMatcher &
 transform_ext::StructuredOpMatcher::convolutionDims(CaptureConvDims convDims) {
   return addPredicate([=](linalg::LinalgOp linalgOp) -> bool {
-    LLVM_DEBUG(DBGS() << "capture convolution dimensions\n");
+    LLVM_DEBUG(DBGS() << "capture convolution dimensions");
     StringRef convMessage = linalg::detail::getMatchConvolutionMessage(
         mlir::linalg::detail::isConvolutionInterfaceImpl(linalgOp,
                                                          &convDims.value));
     if (convMessage.empty())
       return true;
-    LLVM_DEBUG(DBGS() << "capture convolution dimensions failed: "
-                      << convMessage << "\n");
+    LLVM_DEBUG(llvm::dbgs() << " (" << convMessage << ")");
     return false;
   });
 }
@@ -1149,6 +1181,79 @@
   });
 }
 
+transform_ext::StructuredOpMatcher &
+transform_ext::StructuredOpMatcher::hasContractionBody(
+    function_ref<bool(Operation *)> isaElemOpTy,
+    function_ref<bool(Operation *)> isaReductionOpTy, StringRef elemOpName,
+    StringRef reductionOpName) {
+  return addPredicate([=](linalg::LinalgOp linalgOp) {
+    LLVM_DEBUG(DBGS() << "op region is a " << elemOpName << "/"
+                      << reductionOpName << " contraction (");
+    auto scopeExitPrinter = llvm::make_scope_exit(
+        [] { LLVM_DEBUG(llvm::dbgs() << " check failed)"); });
+
+    Block *body = linalgOp.getBlock();
+    if (!llvm::hasNItems(*body, 3)) {
+      LLVM_DEBUG(llvm::dbgs() << "three-operation body");
+      return false;
+    }
+    if (body->getNumArguments() != 3) {
+      LLVM_DEBUG(llvm::dbgs() << "three-argument block");
+      return false;
+    }
+
+    Operation *elemOp = &(*linalgOp.getBlock()->getOperations().begin());
+    Operation *reductionOp = elemOp->getNextNode();
+    Operation *yieldOp = reductionOp->getNextNode();
+    if (!isaElemOpTy(elemOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "first operation is a " << elemOpName);
+      return false;
+    }
+    if (!isaReductionOpTy(reductionOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "second operation is a " << reductionOpName);
+      return false;
+    }
+    if (yieldOp->getNumOperands() != 1) {
+      LLVM_DEBUG(llvm::dbgs() << "one value yielded");
+      return false;
+    }
+    if (yieldOp->getOperand(0).getDefiningOp() != reductionOp) {
+      LLVM_DEBUG(llvm::dbgs() << "yielded value produced by the second op");
+      return false;
+    }
+    if (elemOp->getNumOperands() != 2 || elemOp->getNumResults() != 1) {
+      LLVM_DEBUG(llvm::dbgs() << "first op has two operands and one result");
+      return false;
+    }
+    if (reductionOp->getNumOperands() != 2 ||
+        reductionOp->getNumResults() != 1) {
+      LLVM_DEBUG(llvm::dbgs() << "second op has two operands and one result");
+      return false;
+    }
+
+    SmallVector<Value, 2> expectedReductionOperands = {body->getArgument(2),
+                                                       elemOp->getResult(0)};
+    if (!llvm::equal(expectedReductionOperands, reductionOp->getOperands()) &&
+        !llvm::equal(llvm::reverse(expectedReductionOperands),
+                     reductionOp->getOperands())) {
+      LLVM_DEBUG(llvm::dbgs() << "operands of the second op");
+      return false;
+    }
+
+    ValueRange expectedElemOperands = body->getArguments().take_front(2);
+    if (!llvm::equal(expectedElemOperands, elemOp->getOperands()) &&
+        !llvm::equal(llvm::reverse(expectedElemOperands),
+                     elemOp->getOperands())) {
+      LLVM_DEBUG(llvm::dbgs() << "operands of the first op");
+      return false;
+    }
+
+    scopeExitPrinter.release();
+    LLVM_DEBUG(llvm::dbgs() << "success)");
+    return true;
+  });
+}
+
 void transform_ext::detail::debugOutputForConcreteOpMatcherConstructor(
     StringRef name) {
   LLVM_DEBUG(DBGS() << "op is a " << name << "'");
@@ -1393,6 +1498,34 @@
   trailingCapture = &trailing;
 }
 
+void transform_ext::makeBatchMatmulMatcher(
+    transform_ext::MatcherContext &matcherContext,
+    transform_ext::StructuredOpMatcher *&bmmCapture,
+    transform_ext::StructuredOpMatcher *&fillCapture,
+    transform_ext::MatchedMatmulCaptures &captures, bool mustMatchEntireFunc) {
+  auto &bmm =
+      transform_ext::m_StructuredOp<linalg::BatchMatmulOp, linalg::GenericOp>(
+          matcherContext)
+          .hasContractionBody<arith::MulFOp, arith::AddFOp>()
+          .rank(NumEqualsTo(4))
+          .dim(AllDims(), CaptureDims(captures.matmulOpSizes))
+          .dim(AllDimsExcept({-1}), utils::IteratorType::parallel)
+          .dim(-1, utils::IteratorType::reduction)
+          .contractionDims(CaptureContractionDims(captures.contractionDims))
+          .input(NumEqualsTo(2))
+          .input(0, CaptureElementType(captures.lhsElementType))
+          .input(1, CaptureElementType(captures.rhsElementType))
+          .output(0, CaptureElementType(captures.outputElementType));
+  bmmCapture = &bmm;
+
+  auto &fill = transform_ext::m_StructuredOp<linalg::FillOp>(matcherContext);
+  bmm = bmm.output(0, fill);
+  fillCapture = &fill;
+
+  if (mustMatchEntireFunc)
+    bmm = bmm.allTilableOpsCaptured<func::FuncOp>();
+}
+
 /// Match sum(%src, broadcast(%reduction))
 static void
 matchSubBroadcast(transform_ext::MatcherContext &matcherContext,