Use dataclasses and enums in generate_e2e_matmul_tests.py (#7441)

Just some improvements to how we use Python in this file. Should be a no-op change.

Fixes #7431.
diff --git a/iree/test/e2e/regression/generate_e2e_matmul_tests.py b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
index 9fec3e6..bc983c3 100644
--- a/iree/test/e2e/regression/generate_e2e_matmul_tests.py
+++ b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
@@ -4,116 +4,206 @@
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""iree_generated_check_test generator for end-to-end matrix multiplication.
+"""iree_generated_trace_runner_test generator for e2e matmul tests.
 """
 
 import argparse
-import random
 import os
 import yaml
 import re
 
-
-# Returns lists of shapes as (M, K, N) tuples.
-# For example (M, K, 1) is a matrix*vector product, and (M, 1, N) is an outer
-# product.
-def get_test_shapes():
-  return {
-      "small": [  # Small sizes, square matrices
-          (x, x, x) for x in range(1, 40)
-      ] + [
-          # Small sizes, slightly rectangular matrices
-          (2, 3, 4),
-          (8, 7, 6),
-          (15, 16, 17),
-          (14, 19, 23),
-          (31, 33, 32),
-          (25, 41, 35),
-          # Small sizes, involving vectors (i.e. most rectangular cases)
-          (10, 1, 1),
-          (1, 10, 1),
-          (1, 1, 10),
-          (1, 10, 10),
-          (10, 1, 10),
-          (10, 10, 1),
-          # Small sizes, involving other very small dimensions just above 1
-          (13, 14, 2),
-          (3, 17, 12),
-          (21, 4, 18),
-          # Medium sizes, square matrices
-          (100, 100, 100),
-          # Medium sizes, slightly rectangular matrices
-          (101, 102, 103),
-          # Medium sizes, involving vectors (i.e. most rectangular cases)
-          (10000, 1, 1),
-          (1, 10000, 1),
-          (1, 1, 10000),
-          (1, 1000, 1000),
-          (1000, 1, 1000),
-          (1000, 1000, 1),
-          # Medium sizes, involving other very small dimensions just above 1
-          (1300, 1300, 2),
-          (1300, 1300, 3),
-          (1300, 1300, 4),
-      ],
-      "large": [
-          # Large sizes, powers of two
-          (256, 256, 512),
-          (512, 512, 128),
-          (1024, 512, 512),
-          (512, 1024, 512),
-          # Large sizes, powers of two minus one
-          (127, 63, 511),
-          # Large sizes, powers of two plus one
-          (129, 65, 513),
-          # Large sizes, misc.
-          (200, 300, 400),
-          (123, 456, 789),
-          (500, 500, 50),
-          # Be conservative in adding larger shapes. They can result in
-          # high latency tests. If you have to, consider splitting them
-          # out in a way that constrains the latency impact, e.g. by
-          # running on fewer backends/drivers or with fewer generators
-          # (see get_test_generators).
-      ]
-  }
+import enum
+import dataclasses
 
 
-# Returns lists of 'generators', which are tuples of the form
-# (lhs_generator, rhs_generator, acc_generator, dynamicity)
-# The first 3 entries specify how to generate test input data.
-# The dynamicity entry chooses between static, dynamic or mixed shapes.
-#
-# TODO (Issue #7431): turn into enum and dataclass.
-def get_test_generators():
-  return {
-      "small": [
-          # Generators using simple matrices for ease of numerical debugging.
-          # They don't add significant test coverage (all bugs are hit by
-          # tests using random matrices anyway). They are only here to make
-          # the bulk of our debugging easier.
-          ("identity", "identity", "zero", "dynamic"),
-          ("random", "identity", "zero", "dynamic"),
-          ("identity", "random", "zero", "dynamic"),
-          ("identity", "identity", "random", "dynamic"),
-          # Generators using general random matrices
-          ("random", "random", "random", "dynamic"),
-          ("random", "random", "random", "static"),
-          ("random", "random", "random", "mixed"),
-      ],
-      "large": [
-          # Fewer generators are used for large shapes, to limit the
-          # latency impact. Most bugs are going to be caught on small
-          # shapes anyway.
-          ("random", "random", "random", "dynamic"),
-          ("random", "random", "random", "static"),
-      ]
-  }
+# Data type of matrix entries. The string values must match MLIR data types.
+# This is a superset of the values accepted for the --lhs_rhs_type= flag,
+# as this also includes accumulator-specific types like i32.
+@enum.unique
+class MatrixElemTypeId(enum.Enum):
+  I8 = "i8"
+  I32 = "i32"
+  F32 = "f32"
+
+
+# Enumerates the collections of shapes that we can generate tests for.
+# The values are the accepted values for the --shapes= flag.
+@enum.unique
+class ShapesId(enum.Enum):
+  SMALL = "small"
+  LARGE = "large"
+
+
+# Enumerates ways to construct MLIR tensor types.
+@enum.unique
+class Dynamicity(enum.Enum):
+  DYNAMIC = "dynamic"  # Use '?' everywhere. Example: tensor<?x?xf32>.
+  STATIC = "static"  # Use fixed values everywhere. Example: tensor<4x6xf32>.
+  MIXED = "mixed"  # Randomly mix '?' and values. Example: tensor<?x4xf32>.
+
+
+# Enumerates ways to initialize matrix buffer contents.
+@enum.unique
+class MatrixGenerator(enum.Enum):
+  ZERO = "zero"  # Fill with zeros
+  IDENTITY = "identity"  # Make an identity matrix (generalized to any shape).
+  RANDOM = "random"  # Fill with (deterministic) pseudorandom values.
+
+
+# Describes the shape of a matrix multiplication in the usual convention:
+# the LHS is {m}x{k}, the RHS is {k}x{n}, the accumulator/result is {m}x{n}.
+@dataclasses.dataclass
+class TestShape:
+  m: int
+  k: int
+  n: int
+
+
+# Describes how to construct MLIR tensor types and how to initialize buffer
+# contents for a test case (for an already given TestShape, and already given
+# matrix element data types).
+@dataclasses.dataclass
+class TestGenerator:
+  lhs: MatrixGenerator
+  rhs: MatrixGenerator
+  acc: MatrixGenerator
+  dynamicity: Dynamicity
+
+
+# Returns the list of TestShape's to use for the collection of shapes
+# identified by shapes_id.
+def get_test_shapes(shapes_id: ShapesId):
+  if shapes_id == ShapesId.SMALL:
+    return [  # Small sizes, square matrices
+        TestShape(m=x, k=x, n=x) for x in range(1, 40)
+    ] + [
+        # Small sizes, slightly rectangular matrices
+        TestShape(m=2, k=3, n=4),
+        TestShape(m=8, k=7, n=6),
+        TestShape(m=15, k=16, n=17),
+        TestShape(m=14, k=19, n=23),
+        TestShape(m=31, k=33, n=32),
+        TestShape(m=25, k=41, n=35),
+        # Small sizes, involving vectors (i.e. most rectangular cases)
+        TestShape(m=10, k=1, n=1),
+        TestShape(m=1, k=10, n=1),
+        TestShape(m=1, k=1, n=10),
+        TestShape(m=1, k=10, n=10),
+        TestShape(m=10, k=1, n=10),
+        TestShape(m=10, k=10, n=1),
+        # Small sizes, involving other very small dimensions just above 1
+        TestShape(m=13, k=14, n=2),
+        TestShape(m=3, k=17, n=12),
+        TestShape(m=21, k=4, n=18),
+        # Medium sizes, square matrices
+        TestShape(m=100, k=100, n=100),
+        # Medium sizes, slightly rectangular matrices
+        TestShape(m=101, k=102, n=103),
+        # Medium sizes, involving vectors (i.e. most rectangular cases)
+        TestShape(m=10000, k=1, n=1),
+        TestShape(m=1, k=10000, n=1),
+        TestShape(m=1, k=1, n=10000),
+        TestShape(m=1, k=1000, n=1000),
+        TestShape(m=1000, k=1, n=1000),
+        TestShape(m=1000, k=1000, n=1),
+        # Medium sizes, involving other very small dimensions just above 1
+        TestShape(m=1300, k=1300, n=2),
+        TestShape(m=1300, k=1300, n=3),
+        TestShape(m=1300, k=1300, n=4),
+    ]
+  if shapes_id == ShapesId.LARGE:
+    return [
+        # Large sizes, powers of two
+        TestShape(m=256, k=256, n=512),
+        TestShape(m=512, k=512, n=128),
+        TestShape(m=1024, k=512, n=512),
+        TestShape(m=512, k=1024, n=512),
+        # Large sizes, powers of two minus one
+        TestShape(m=127, k=63, n=511),
+        # Large sizes, powers of two plus one
+        TestShape(m=129, k=65, n=513),
+        # Large sizes, misc.
+        TestShape(m=200, k=300, n=400),
+        TestShape(m=123, k=456, n=789),
+        TestShape(m=500, k=500, n=50),
+        # Be conservative in adding larger shapes. They can result in
+        # high latency tests. If you have to, consider splitting them
+        # out in a way that constrains the latency impact, e.g. by
+        # running on fewer backends/drivers or with fewer generators
+        # (see get_test_generators).
+    ]
+  raise ValueError(shapes_id)
+
+
+# Returns the list of TestGenerator's to use for the collection of shapes
+# identified by shapes_id.
+def get_test_generators(shapes_id: ShapesId):
+  if shapes_id == ShapesId.SMALL:
+    return [
+        # Generators using simple matrices for ease of numerical debugging.
+        # They don't add significant test coverage (all bugs are hit by
+        # tests using random matrices anyway). They are only here to make
+        # the bulk of our debugging easier.
+        TestGenerator(lhs=MatrixGenerator.IDENTITY,
+                      rhs=MatrixGenerator.IDENTITY,
+                      acc=MatrixGenerator.ZERO,
+                      dynamicity=Dynamicity.DYNAMIC),
+        TestGenerator(lhs=MatrixGenerator.RANDOM,
+                      rhs=MatrixGenerator.IDENTITY,
+                      acc=MatrixGenerator.ZERO,
+                      dynamicity=Dynamicity.DYNAMIC),
+        TestGenerator(lhs=MatrixGenerator.IDENTITY,
+                      rhs=MatrixGenerator.RANDOM,
+                      acc=MatrixGenerator.ZERO,
+                      dynamicity=Dynamicity.DYNAMIC),
+        TestGenerator(lhs=MatrixGenerator.IDENTITY,
+                      rhs=MatrixGenerator.IDENTITY,
+                      acc=MatrixGenerator.RANDOM,
+                      dynamicity=Dynamicity.DYNAMIC),
+        # Generators using general random matrices
+        TestGenerator(lhs=MatrixGenerator.RANDOM,
+                      rhs=MatrixGenerator.RANDOM,
+                      acc=MatrixGenerator.RANDOM,
+                      dynamicity=Dynamicity.DYNAMIC),
+        TestGenerator(lhs=MatrixGenerator.RANDOM,
+                      rhs=MatrixGenerator.RANDOM,
+                      acc=MatrixGenerator.RANDOM,
+                      dynamicity=Dynamicity.STATIC),
+        TestGenerator(lhs=MatrixGenerator.RANDOM,
+                      rhs=MatrixGenerator.RANDOM,
+                      acc=MatrixGenerator.RANDOM,
+                      dynamicity=Dynamicity.MIXED),
+    ]
+  if shapes_id == ShapesId.LARGE:
+    return [
+        # Fewer generators are used for large shapes, to limit the
+        # latency impact. Most bugs are going to be caught on small
+        # shapes anyway.
+        TestGenerator(lhs=MatrixGenerator.RANDOM,
+                      rhs=MatrixGenerator.RANDOM,
+                      acc=MatrixGenerator.RANDOM,
+                      dynamicity=Dynamicity.DYNAMIC),
+        TestGenerator(lhs=MatrixGenerator.RANDOM,
+                      rhs=MatrixGenerator.RANDOM,
+                      acc=MatrixGenerator.RANDOM,
+                      dynamicity=Dynamicity.STATIC),
+    ]
+  raise ValueError(shapes_id)
 
 
 # Generates a name for a test function in the generated MLIR code.
-def function_name(lhs_rhs_type, accum_type, shape, gen):
-  return f"{lhs_rhs_type}_{gen[3]}_{gen[0]}_{shape[0]}x{shape[1]}_times_{gen[1]}_{shape[1]}x{shape[2]}_plus_{gen[2]}_{accum_type}"
+def function_name(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId,
+                  shape: TestShape, gen: TestGenerator):
+  dyn = gen.dynamicity.value
+  lhs_g = gen.lhs.value
+  rhs_g = gen.rhs.value
+  acc_g = gen.acc.value
+  input_t = lhs_rhs_type.value
+  acc_t = acc_type.value
+  m = shape.m
+  k = shape.k
+  n = shape.n
+  return f"{input_t}_{dyn}_{lhs_g}_{m}x{k}_times_{rhs_g}_{k}x{n}_plus_{acc_g}_{acc_t}"
 
 
 # Intentionally fixed seed! We want full reproducibility here, both across runs
@@ -125,12 +215,12 @@
 
 # Generates a compile-time MLIR size value, i.e. either a fixed positive integer
 # or a '?' depending on dynamicity.
-def static_size(x, dynamicity):
-  if dynamicity == "dynamic":
+def static_size(x: int, dynamicity: Dynamicity):
+  if dynamicity == Dynamicity.DYNAMIC:
     return "?"
-  elif dynamicity == "static":
+  elif dynamicity == Dynamicity.STATIC:
     return x
-  elif dynamicity == "mixed":
+  elif dynamicity == Dynamicity.MIXED:
     global local_pseudorandom_state
     # Same as C++ std::minstd_rand.
     # Using a local pseudorandom generator implementation ensures that it's
@@ -144,17 +234,18 @@
 # Generates a test function in the generated MLIR code.
 # The generated function will take the same arguments as linalg.matmul and
 # will just call linalg.matmul with them, returning its result.
-def generate_function(func_name, lhs_rhs_type, accum_type, shape, gen):
-  (m, k, n) = shape
-  lhs_m = static_size(m, gen[3])
-  lhs_k = static_size(k, gen[3])
-  rhs_k = static_size(k, gen[3])
-  rhs_n = static_size(n, gen[3])
-  acc_m = static_size(m, gen[3])
-  acc_n = static_size(n, gen[3])
-  lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type}>"
-  rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type}>"
-  acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{accum_type}>"
+def generate_function(func_name: str, lhs_rhs_type: MatrixElemTypeId,
+                      acc_type: MatrixElemTypeId, shape: TestShape,
+                      gen: TestGenerator):
+  lhs_m = static_size(shape.m, gen.dynamicity)
+  lhs_k = static_size(shape.k, gen.dynamicity)
+  rhs_k = static_size(shape.k, gen.dynamicity)
+  rhs_n = static_size(shape.n, gen.dynamicity)
+  acc_m = static_size(shape.m, gen.dynamicity)
+  acc_n = static_size(shape.n, gen.dynamicity)
+  lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type.value}>"
+  rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type.value}>"
+  acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{acc_type.value}>"
   return (
       f"func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
       f"  %result = linalg.matmul ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
@@ -170,12 +261,12 @@
 
 
 # Generates a contents_generator tag to use in the output trace.
-def contents_generator_tag(generator):
-  if generator == "zero":
+def contents_generator_tag(generator: MatrixGenerator):
+  if generator == MatrixGenerator.ZERO:
     return ""
-  elif generator == "identity":
+  elif generator == MatrixGenerator.IDENTITY:
     return "!tag:iree:identity_matrix"
-  elif generator == "random":
+  elif generator == MatrixGenerator.RANDOM:
     global pseudorandom_generator_seed
     pseudorandom_generator_seed = pseudorandom_generator_seed + 1
     return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}"
@@ -185,11 +276,13 @@
 
 # Generate a matrix function argument in the output trace, as a dictionary
 # to be passed to yaml.dump.
-def generate_trace_matrix_arg(matrix_shape, element_type, generator):
+def generate_trace_matrix_arg(matrix_shape: list,
+                              element_type: MatrixElemTypeId,
+                              generator: MatrixGenerator):
   result = {
       "type": "hal.buffer_view",
       "shape": matrix_shape,
-      "element_type": element_type,
+      "element_type": element_type.value,
   }
   generator_tag = contents_generator_tag(generator)
   if generator_tag:
@@ -199,12 +292,14 @@
 
 # Generates the output trace for a testcase i.e. a single test function call,
 # as a dictionary to be passed to yaml.dump.
-def generate_trace(func_name, lhs_rhs_type, acc_type, shape, gen):
-  (m, k, n) = shape
-  lhs_arg = generate_trace_matrix_arg([m, k], lhs_rhs_type, gen[0])
-  rhs_arg = generate_trace_matrix_arg([k, n], lhs_rhs_type, gen[1])
-  acc_arg = generate_trace_matrix_arg([m, n], acc_type, gen[2])
-  result_arg = generate_trace_matrix_arg([m, n], acc_type, "zero")
+def generate_trace(func_name: str, lhs_rhs_type: MatrixElemTypeId,
+                   acc_type: MatrixElemTypeId, shape: TestShape,
+                   gen: TestGenerator):
+  lhs_arg = generate_trace_matrix_arg([shape.m, shape.k], lhs_rhs_type, gen.lhs)
+  rhs_arg = generate_trace_matrix_arg([shape.k, shape.n], lhs_rhs_type, gen.rhs)
+  acc_arg = generate_trace_matrix_arg([shape.m, shape.n], acc_type, gen.acc)
+  result_arg = generate_trace_matrix_arg([shape.m, shape.n], acc_type,
+                                         MatrixGenerator.ZERO)
   return {
       "type": "call",
       "function": "module." + func_name,
@@ -218,14 +313,13 @@
 
 
 # Generates all output files' contents as strings.
-def generate(args):
+def generate(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId,
+             shapes_id: ShapesId):
   functions = {}
   traces = []
-  lhs_rhs_type = args.lhs_rhs_type
-  accum_type = 'i32' if lhs_rhs_type == 'i8' else lhs_rhs_type
-  for shape in get_test_shapes()[args.shapes]:
-    for gen in get_test_generators()[args.shapes]:
-      func_name = function_name(lhs_rhs_type, accum_type, shape, gen)
+  for shape in get_test_shapes(shapes_id):
+    for gen in get_test_generators(shapes_id):
+      func_name = function_name(lhs_rhs_type, acc_type, shape, gen)
       # Different testcases may differ only by runtime parameters but
       # share the same code. For example, dynamic-shapes testcases
       # share the same code involing tensor<?x?xf32> even though the runtime
@@ -233,9 +327,9 @@
       # generate_function conditionally, and generate_trace unconditionally.
       if func_name not in functions:
         functions[func_name] = generate_function(func_name, lhs_rhs_type,
-                                                 accum_type, shape, gen)
+                                                 acc_type, shape, gen)
       traces.append(
-          generate_trace(func_name, lhs_rhs_type, accum_type, shape, gen))
+          generate_trace(func_name, lhs_rhs_type, acc_type, shape, gen))
   return (functions, traces)
 
 
@@ -256,7 +350,7 @@
                       required=True)
   parser.add_argument("--shapes",
                       type=str,
-                      choices=["small", "large"],
+                      choices=[s.value for s in ShapesId],
                       help="Collection of matrix shapes to test",
                       required=True)
   parser.add_argument(
@@ -308,8 +402,22 @@
     file.write(processed_yaml)
 
 
+# For now, the accumulator type can always be inferred from the input LHS/RHS
+# type, so we do that. That is temporary: eventually there will be cases
+# where the same input types are used with different accumulator types, e.g.
+# f16 inputs with both f16 and f32 accumulator.
+def infer_acc_type(lhs_rhs_type: MatrixElemTypeId):
+  if lhs_rhs_type == MatrixElemTypeId.I8:
+    return MatrixElemTypeId.I32
+  else:
+    return lhs_rhs_type
+
+
 def main(args):
-  (functions, traces) = generate(args)
+  lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type)
+  acc_type = infer_acc_type(lhs_rhs_type)
+  shapes_id = ShapesId(args.shapes)
+  (functions, traces) = generate(lhs_rhs_type, acc_type, shapes_id)
   write_code_file(functions, args.output_code)
   write_trace_file(traces, args.output_trace, args.module_path)