[LLVMGPU] Enable IGEMM for convolutions by default (#19006)

This flips the `iree-codegen-llvmgpu-use-igemm` flag from false to true.
Convolutions that are aligned with MMA intrinsic shapes will now go down
the IGEMM path, using the TileAndFuse pipeline. Some tests are also
updated, since the default behavior has changed.
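For anyone who needs the previous behavior, the flag can still be set to false explicitly. A minimal sketch, mirroring the lit tests updated below (`input.mlir` is a placeholder for the dispatch being compiled):

```shell
# Select lowering strategies with IGEMM disabled, restoring the old default.
# The target and pass pipeline mirror the RUN lines in the updated tests;
# input.mlir is a placeholder.
iree-opt --iree-gpu-test-target=gfx940 \
  --iree-codegen-llvmgpu-use-igemm=false \
  --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" \
  input.mlir
```

The same `--iree-codegen-llvmgpu-use-igemm=false` option should also be accepted by `iree-compile` for end-to-end compilation.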

### Benchmarks ###
**Note:** These are local comparative benchmarks intended to show the
difference from flipping the flag. Do not read too much into the exact
numbers, since the benchmarking environment was probably not ideal.

- Local benchmarks show improvement (~0.5ms) for int8 punet, and a
regression for int8 VAE (~1ms on batch size 1, ~20ms on batch size 8).
- Broader microbenchmarks across a wider set of convolution shapes show an
improvement on average with IGEMM over the default path.
- The model benchmarks above were run without tuning; the VAE regressions
should be recoverable by tuning the IGEMM path.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 6271fe4..45d592f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -98,7 +98,7 @@
 static llvm::cl::opt<bool>
     clLLVMGPUUseIgemm("iree-codegen-llvmgpu-use-igemm",
                       llvm::cl::desc("Enable implicit gemm for convolutions."),
-                      llvm::cl::init(false));
+                      llvm::cl::init(true));
 namespace {
 
 using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 4e5776f..3a0d7d3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -23,7 +23,6 @@
             "amdgpu_set_anchor_layouts.mlir",
             "assign_constant_ordinals.mlir",
             "conv_pipeline_test_cuda.mlir",
-            "conv_pipeline_test_rocm.mlir",
             "convert_to_nvvm.mlir",
             "convert_to_rocdl.mlir",
             "create_async_groups.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index bc3935f..f628010 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -25,7 +25,6 @@
     "config_winograd.mlir"
     "configure_tensor_layout.mlir"
     "conv_pipeline_test_cuda.mlir"
-    "conv_pipeline_test_rocm.mlir"
     "convert_to_nvvm.mlir"
     "convert_to_rocdl.mlir"
     "create_async_groups.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 6bef11e..80d4691 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -1,5 +1,6 @@
 // RUN: iree-opt --mlir-print-local-scope --split-input-file --iree-gpu-test-target=gfx940 \
 // RUN: --iree-codegen-llvmgpu-test-tile-and-fuse-matmul=true --iree-codegen-llvmgpu-test-tile-and-fuse-vectorize=true \
+// RUN: --iree-codegen-llvmgpu-use-igemm=false \
 // RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
 
 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
index 46b8292..d848c03 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx940.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx940 --iree-codegen-llvmgpu-use-vector-distribution \
-// RUN:   --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution \
+// RUN:   --iree-codegen-llvmgpu-use-unaligned-gemm-vector-distribution --iree-codegen-llvmgpu-use-igemm=false \
 // RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
 
 // TODO: This test is still using the legacy LLVMGPU kernel config. This needs
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir
deleted file mode 100644
index b33502e..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/conv_pipeline_test_rocm.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 \
-// RUN:   --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target,canonicalize)))))' \
-// RUN:   %s | FileCheck %s
-
-#pipeline_layout = #hal.pipeline.layout<bindings = [
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>,
-  #hal.pipeline.binding<storage_buffer>
-]>
-hal.executable private @conv_nchw_dispatch_1 {
-  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
-    hal.executable.export public @conv_2d_nchw_fchw_2x320x64x64x320x3x3_f16 ordinal(0) layout(#pipeline_layout) attributes {
-      translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [16, 2, 1]>
-    } {
-    ^bb0(%arg0: !hal.device):
-      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
-      hal.return %x, %y, %z : index, index, index
-    }
-    builtin.module {
-      func.func @conv_2d_nchw_fchw_2x320x64x64x320x3x3_f16() {
-        %cst = arith.constant 0.000000e+00 : f16
-        %c0 = arith.constant 0 : index
-        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x320x130x130xf16>>
-        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<320x320x3x3xf16>>
-        %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<320xf16>>
-        %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x320x64x64xf16>>
-        %4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 320, 130, 130], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x320x130x130xf16>> -> tensor<2x320x130x130xf16>
-        %5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [320, 320, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<320x320x3x3xf16>> -> tensor<320x320x3x3xf16>
-        %6 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [320], strides = [1] : !flow.dispatch.tensor<readonly:tensor<320xf16>> -> tensor<320xf16>
-        %7 = tensor.empty() : tensor<2x320x64x64xf16>
-        %8 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 8, 64, 4, 1, 1], [0, 0, 1, 0]]>} ins(%cst : f16) outs(%7 : tensor<2x320x64x64xf16>) -> tensor<2x320x64x64xf16>
-        %9 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 8, 64, 4, 1, 1], [0, 0, 1, 0]]>, strides = dense<2> : vector<2xi64>} ins(%4, %5 : tensor<2x320x130x130xf16>, tensor<320x320x3x3xf16>) outs(%8 : tensor<2x320x64x64xf16>) -> tensor<2x320x64x64xf16>
-        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9, %6 : tensor<2x320x64x64xf16>, tensor<320xf16>) outs(%7 : tensor<2x320x64x64xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 8, 64, 4, 1, 1], [0, 0, 1, 0]]>} {
-        ^bb0(%in: f16, %in_0: f16, %out: f16):
-          %11 = arith.addf %in, %in_0 : f16
-          linalg.yield %11 : f16
-        } -> tensor<2x320x64x64xf16>
-        flow.dispatch.tensor.store %10, %3, offsets = [0, 0, 0, 0], sizes = [2, 320, 64, 64], strides = [1, 1, 1, 1] : tensor<2x320x64x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x320x64x64xf16>>
-        return
-      }
-    }
-  }
-}
-
-// TODO: This test reflects a bug related to how the convolution is bufferized
-// for the LLVMGPUVectorize pipeline, meaning these local memory allocations are
-// not desired. This test should be dropped once the extra buffers have been
-// eliminated.
-
-//   CHECK-LABEL:  func @conv_2d_nchw_fchw_2x320x64x64x320x3x3_f16
-// CHECK-COUNT-3:    memref.alloca() : memref<1x1x1x4xf16, #gpu.address_space<private>>
-// CHECK-COUNT-3:    memref.copy %{{.*}}, %{{.*}} : memref<1x1x1x4xf16, #gpu.address_space<private>> to memref<{{.*}} #hal.descriptor_type<storage_buffer>>