|  | # Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | # See https://llvm.org/LICENSE.txt for license information. | 
|  | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | import argparse | 
|  | import jaxlib.mlir.ir as mlir_ir | 
|  | import jax._src.interpreters.mlir as mlir | 
|  | import multiprocessing | 
|  | import os | 
|  | import re | 
|  |  | 
parser = argparse.ArgumentParser(
    prog="triage_jaxtest.py", description="Triage the jax tests"
)
parser.add_argument(
    "-l",
    "--logdir",
    default="/tmp/jaxtest",
    help="directory holding one sub-directory of logs per test",
)
# NOTE(review): --delete is parsed but never read anywhere in this script.
parser.add_argument("-d", "--delete", default=False)
# type=int rejects non-numeric values up front with a clear argparse error
# instead of failing later inside int() in a worker setup.
parser.add_argument(
    "-j",
    "--jobs",
    type=int,
    default=None,
    help="number of worker processes (default: let multiprocessing.Pool decide)",
)
args = parser.parse_args()

# Every sub-directory of the log directory is one test. Stray regular files are
# skipped: later code lists each test as a directory (which would raise
# NotADirectoryError), and such files must not be counted as passing tests.
tests = {
    name
    for name in os.listdir(args.logdir)
    if os.path.isdir(os.path.join(args.logdir, name))
}
|  |  | 
|  |  | 
def filter_to_failures(tests):
    """Return the failing tests, sorted by name.

    A test counts as failing when its directory under ``args.logdir`` contains
    an ``error.txt`` file or a ``CRASH_MARKER`` file.
    """
    failure_markers = ("error.txt", "CRASH_MARKER")

    def _failed(test):
        entries = os.listdir(f"{args.logdir}/{test}")
        return any(marker in entries for marker in failure_markers)

    return sorted(test for test in tests if _failed(test))
|  |  | 
|  |  | 
def check_custom_call(errortxt, mlirbc, __):
    """Flag tests whose error text or module text mentions ``stablehlo.custom_call``."""
    return any("stablehlo.custom_call" in text for text in (errortxt, mlirbc))
|  |  | 
|  |  | 
def check_load_ui(errortxt, _, __):
    """Flag the `flow.tensor.load` verifier error that rejects a ``ui32`` result."""
    signature = "'flow.tensor.load' op result #0 must be index or signless integer or floating-point or complex-type or vector of any type values, but got 'ui32'"
    return signature in errortxt
|  |  | 
|  |  | 
def check_splat_ui(errortxt, _, __):
    """Flag `flow.tensor.splat` type-mismatch errors whose text also contains ``xui``."""
    if "xui" not in errortxt:
        return False
    signature = "'flow.tensor.splat' op failed to verify that value type matches element type of result"
    return signature in errortxt
|  |  | 
|  |  | 
def check_degenerate_tensor(_, mlirbc, __):
    """Heuristically flag modules with a zero-sized tensor dimension.

    Matches a leading ``tensor<0x`` or any ``x0x`` substring anywhere in the module text.
    """
    zero_dim_markers = ("tensor<0x", "x0x")
    return any(marker in mlirbc for marker in zero_dim_markers)
|  |  | 
|  |  | 
def check_topk_bf16(_, mlirbc, __):
    """Flag modules that mention both ``bf16`` and ``hlo.top_k`` anywhere in their text."""
    return all(token in mlirbc for token in ("bf16", "hlo.top_k"))
|  |  | 
|  |  | 
def check_cross_replica(errortxt, _, __):
    """Flag errors that mention ``cross-replica``."""
    marker = "cross-replica"
    return marker in errortxt
|  |  | 
|  |  | 
def check_collective(errortxt, _, __):
    """Flag errors that mention collective ops or report them as unimplemented."""
    signatures = (
        "stablehlo.collective",
        "UNIMPLEMENTED; collectives not implemented",
    )
    return any(sig in errortxt for sig in signatures)
|  |  | 
|  |  | 
def check_sort_shape(errortxt, _, __):
    """Flag the `iree_linalg_ext.sort` operand-shape verifier error."""
    signature = "'iree_linalg_ext.sort' op expected operand 1 to have same shape as other operands"
    return signature in errortxt
|  |  | 
|  |  | 
def check_reverse_i1(_, mlirbc, __):
    """Flag a module line that contains both ``stablehlo.reverse`` and ``xui``.

    NOTE(review): despite the ``i1`` in the name, this matches ``xui``
    (presumably unsigned element types), not ``i1``; confirm which was intended.
    """
    return any(
        "stablehlo.reverse" in line and "xui" in line for line in mlirbc.split("\n")
    )
|  |  | 
|  |  | 
def check_complex(errortxt, mlirbc, __):
    """Flag tests whose error text or module text contains a ``complex<`` type."""
    return any("complex<" in text for text in (errortxt, mlirbc))
|  |  | 
|  |  | 
def check_timeout(errortxt, _, __):
    """Match the doubled ``ABORTED: ABORTED`` XlaRuntimeError (presumably a timeout, per the name)."""
    signature = "jaxlib.xla_extension.XlaRuntimeError: ABORTED: ABORTED"
    return signature in errortxt
|  |  | 
|  |  | 
def check_rng_bit_i8(_, mlirbc, __):
    """Flag a ``stablehlo.rng_bit_generator`` line that also contains ``i8``.

    The plain substring test also matches ``ui8``.
    """
    for row in mlirbc.split("\n"):
        if "stablehlo.rng_bit_generator" not in row:
            continue
        if "i8" in row:
            return True
    return False
|  |  | 
|  |  | 
def check_min_max_f16(errortxt, _, __):
    """Flag f16 min/max lowering failures.

    Matches a missing ``fminf`` symbol, or any error line that mentions an
    ``llvm.intr.vector.reduce.fmax``/``fmin`` intrinsic together with ``f16``.
    """
    if "undefined symbol: fminf" in errortxt:
        return True
    reductions = ("llvm.intr.vector.reduce.fmax", "llvm.intr.vector.reduce.fmin")
    return any(
        "f16" in line and any(op in line for op in reductions)
        for line in errortxt.split("\n")
    )
|  |  | 
|  |  | 
def check_scatter_ui(errortxt, _, __):
    """Flag an `iree_linalg_ext.scatter` `outs` type error on a line that also contains ``xui``."""
    required = (
        "iree_linalg_ext.scatter",
        "expected type of `outs` operand #0",
        "xui",
    )
    return any(
        all(token in line for token in required) for line in errortxt.split("\n")
    )
|  |  | 
|  |  | 
def check_bitcast_bf16(errortxt, _, __):
    """Flag `arith.bitcast` operand-type errors whose text mentions ``bf16``."""
    return all(token in errortxt for token in ("bf16", "`arith.bitcast` op operand type"))
|  |  | 
|  |  | 
def check_constant_bf16(errortxt, _, __):
    """Flag FloatAttr/constant type-mismatch errors."""
    signature = "FloatAttr does not match expected type of the constant"
    return signature in errortxt
|  |  | 
|  |  | 
def check_triangular_solve(errortxt, _, __):
    """Flag errors that mention ``stablehlo.triangular_solve``."""
    op_name = "stablehlo.triangular_solve"
    return op_name in errortxt
|  |  | 
|  |  | 
def check_cholesky(errortxt, _, __):
    """Flag errors that mention ``stablehlo.cholesky``."""
    op_name = "stablehlo.cholesky"
    return op_name in errortxt
|  |  | 
|  |  | 
def check_fft(_, mlirbc, __):
    """Flag modules that contain a ``stablehlo.fft`` op."""
    op_name = "stablehlo.fft"
    return op_name in mlirbc
|  |  | 
|  |  | 
def check_schedule_allocation(errortxt, _, __):
    """Flag pipeline failures raised from the ``ScheduleAllocation`` pass."""
    signature = "Pipeline failed while executing [`ScheduleAllocation`"
    return signature in errortxt
|  |  | 
|  |  | 
def check_dot_i1(_, mlirbc, __):
    """Flag a module line that has both a ``stablehlo.dot*`` op and a static-shape ``i1`` tensor type."""
    for row in mlirbc.split("\n"):
        if re.search("stablehlo.dot", row) and re.search("tensor<([0-9]*x)*i1>", row):
            return True
    return False
|  |  | 
|  |  | 
def check_vectorize(errortxt, _, __):
    """Flag the `arith.truncf` verifier error on a 0-d ``vector<f32>`` operand."""
    signature = "arith.truncf' op operand #0 must be floating-point-like, but got 'vector<f32>"
    return signature in errortxt
|  |  | 
|  |  | 
def check_roundeven(errortxt, _, __):
    """Flag errors that mention ``roundeven``."""
    needle = "roundeven"
    return needle in errortxt
|  |  | 
|  |  | 
def check_numerical(errortxt, _, __):
    """Flag numerical-comparison failures (``Mismatched elements``)."""
    needle = "Mismatched elements"
    return needle in errortxt
|  |  | 
|  |  | 
def check_compilation(errortxt, _, __):
    """Flag errors that reference the PJRT plugin's ``api_impl.cc:1085`` source location."""
    location = "iree/integrations/pjrt/common/api_impl.cc:1085"
    return location in errortxt
|  |  | 
|  |  | 
def check_scatter(errortxt, _, __):
    """Flag the `iree_linalg_ext.scatter` indices/update dim#0 shape-mismatch error."""
    signature = "'iree_linalg_ext.scatter' op mismatch in shape of indices and update value at dim#0"
    return signature in errortxt
|  |  | 
|  |  | 
def check_shape_cast(errortxt, _, __):
    """Flag the `vector.shape_cast` element-count verifier error."""
    signature = "'vector.shape_cast' op source/result number of elements must match"
    return signature in errortxt
|  |  | 
|  |  | 
def check_scatter_crash(_, mlirbc, runtime_crash):
    """Flag runtime crashes in modules that contain a ``stablehlo.scatter`` op.

    Returns ``False`` when there is no scatter; otherwise returns ``runtime_crash`` unchanged.
    """
    if "stablehlo.scatter" not in mlirbc:
        return False
    return runtime_crash
|  |  | 
|  |  | 
def check_eigen_decomposition(errortxt, _, __):
    """Flag JAX's "nonsymmetric eigendecomposition is CPU-only" error."""
    signature = "Nonsymmetric eigendecomposition is only implemented on the CPU backend"
    return signature in errortxt
|  |  | 
|  |  | 
def check_jax_unimplemented(errortxt, _, __):
    """Flag errors for a primitive that has no MLIR translation rule."""
    signature = "NotImplementedError: MLIR translation rule for primitive"
    return signature in errortxt
|  |  | 
|  |  | 
def check_serialize_exe(errortxt, _, __):
    """Flag the unimplemented ``PJRT_Executable_Serialize`` error."""
    signature = "UNIMPLEMENTED; PJRT_Executable_Serialize"
    return signature in errortxt
|  |  | 
|  |  | 
def check_optimized_prgrm(errortxt, _, __):
    """Flag the unimplemented ``PJRT_Executable_OptimizedProgram`` error.

    NOTE(review): same logic as ``check_optimized_program``; one of the two is likely redundant.
    """
    signature = "UNIMPLEMENTED; PJRT_Executable_OptimizedProgram"
    return signature in errortxt
|  |  | 
|  |  | 
def check_optimized_program(errortxt, _, __):
    """Flag the unimplemented ``PJRT_Executable_OptimizedProgram`` error."""
    signature = "UNIMPLEMENTED; PJRT_Executable_OptimizedProgram"
    return signature in errortxt
|  |  | 
|  |  | 
def check_donation(errortxt, _, __):
    """Flag the "buffer donation not implemented for iree_cpu" error."""
    signature = "Donation is not implemented for iree_cpu"
    return signature in errortxt
|  |  | 
|  |  | 
def check_semaphore_overload(errortxt, _, __):
    """Flag the non-monotonic semaphore value ``OUT_OF_RANGE`` error."""
    signature = "OUT_OF_RANGE; semaphore values must be monotonically increasing;"
    return signature in errortxt
|  |  | 
|  |  | 
def check_python_callback(errortxt, _, __):
    """Flag the unsupported ``EmitPythonCallback`` error."""
    signature = "ValueError: `EmitPythonCallback` not supported"
    return signature in errortxt
|  |  | 
|  |  | 
def check_complex_convolution(errortxt, mlirbc, __):
    """Flag complex-typed convolution/dot failures.

    Matches a failed legalization of ``complex.constant``, or a module line
    that has a complex float tensor type together with a
    ``stablehlo.convolution`` or ``stablehlo.dot*`` op.
    """
    if "failed to legalize operation 'complex.constant'" in errortxt:
        return True

    for row in mlirbc.split("\n"):
        if not re.search("tensor<([0-9]*x)*complex<f[0-9]*>>", row):
            continue
        if re.search("stablehlo.convolution", row) or re.search("stablehlo.dot", row):
            return True
    return False
|  |  | 
|  |  | 
def check_subspan(errortxt, _, __):
    """Flag the failed legalization of an illegal ``hal.interface.binding.subspan`` op."""
    signature = "failed to legalize operation 'hal.interface.binding.subspan' that was explicitly marked illegal"
    return signature in errortxt
|  |  | 
|  |  | 
def check_from_tensor(errortxt, _, __):
    """Flag the unhandled ``tensor.from_elements`` error."""
    signature = "error: 'tensor.from_elements' op unhandled tensor operation"
    return signature in errortxt
|  |  | 
|  |  | 
def check_unknown_backend(errortxt, _, __):
    """Flag the ``RuntimeError: Unknown backend`` error."""
    signature = "RuntimeError: Unknown backend"
    return signature in errortxt
|  |  | 
|  |  | 
def check_unsigned_topk(_, mlirbc, __):
    """Flag a module line that has both a ``chlo.top_k`` op and an ``xui`` type."""
    return any("xui" in line and "chlo.top_k" in line for line in mlirbc.split("\n"))
|  |  | 
|  |  | 
def check_runtime_crash(__, _, runtime_crash):
    """Return the runtime-crash flag (the third argument) unchanged."""
    crashed = runtime_crash
    return crashed
|  |  | 
|  |  | 
def check_aborted(errortxt, _, __):
    """Flag any error text containing ``ABORTED`` (case-sensitive)."""
    needle = "ABORTED"
    return needle in errortxt
|  |  | 
|  |  | 
def check_bounds_indexing(errortxt, _, __):
    """Flag out-of-bounds indexing errors."""
    signature = "out-of-bounds indexing for array of shape"
    return signature in errortxt
|  |  | 
|  |  | 
def check_nan_correctness(errortxt, _, __):
    """Flag NaN-location mismatches reported by the numerical comparison."""
    signature = "nan location mismatch"
    return signature in errortxt
|  |  | 
|  |  | 
def check_pointer_mismatch(errortxt, _, __):
    """Flag errors that mention ``unsafe_buffer_pointer()``."""
    signature = "unsafe_buffer_pointer()"
    return signature in errortxt
|  |  | 
|  |  | 
def check_select_and_scatter(errortxt, _, __):
    """Flag the failed legalization of ``stablehlo.select_and_scatter``."""
    signature = "failed to legalize operation 'stablehlo.select_and_scatter'"
    return signature in errortxt
|  |  | 
|  |  | 
def check_degenerate_scatter(errortxt, _, __):
    """Flag the `iree_linalg_ext.scatter` operand #2 type-constraint error."""
    signature = "'iree_linalg_ext.scatter' op operand #2 must be ranked tensor or memref of any type values"
    return signature in errortxt
|  |  | 
|  |  | 
def check_cost_analysis(errortxt, _, __):
    """Flag errors that mention ``cost_analysis()``."""
    signature = "cost_analysis()"
    return signature in errortxt
|  |  | 
|  |  | 
def check_invalid_option(errortxt, _, __):
    """Flag the rejected ``invalid_key`` compile-option error."""
    signature = "No such compile option: 'invalid_key'"
    return signature in errortxt
|  |  | 
|  |  | 
def check_inf_mismatch(errortxt, _, __):
    """Flag inf-location mismatches reported by the numerical comparison."""
    signature = "inf location mismatch"
    return signature in errortxt
|  |  | 
|  |  | 
def check_shape_assertion(errortxt, _, __):
    """Flag a traceback line that compares a ``.shape`` with ``assertEqual``."""
    return any(
        "assertEqual" in line and ".shape" in line for line in errortxt.split("\n")
    )
|  |  | 
|  |  | 
def check_vector_contract(errortxt, _, __):
    """Flag the `vector.contract` lhs/rhs element-type verifier error."""
    signature = "'vector.contract' op failed to verify that lhs and rhs have same element type"
    return signature in errortxt
|  |  | 
|  |  | 
def check_subbyte_read(errortxt, _, __):
    """Flag errors about indexing opaque / sub-byte element types."""
    signature = "opaque and sub-byte aligned element types cannot be indexed"
    return signature in errortxt
|  |  | 
|  |  | 
def check_buffer_usage(errortxt, _, __):
    """Flag runtime errors about unsupported or unrequested buffer usage."""
    signatures = (
        "requested buffer usage is not supported",
        "tensor requested usage was not specified when the buffer",
        "PERMISSION_DENIED; requested usage was not specified when the buffer was allocated; buffer allows DISPATCH_INDIRECT_PARAMS",
    )
    return any(sig in errortxt for sig in signatures)
|  |  | 
|  |  | 
def check_subbyte_singleton(errortxt, _, __):
    """Flag errors about a buffer without an integral number of bytes."""
    signature = "does not have integral number of total bytes"
    return signature in errortxt
|  |  | 
|  |  | 
def check_max_arg(errortxt, _, __):
    """Flag Python's ``max() arg is an empty sequence`` error."""
    signature = "max() arg is an empty sequence"
    return signature in errortxt
|  |  | 
|  |  | 
def check_double_support(errortxt, _, __):
    """Flag the f32-expected-but-got-f64 element-type mismatch."""
    signature = "expected f32 (21000020) but have f64 (21000040)"
    return signature in errortxt
|  |  | 
|  |  | 
def check_stablehlo_degenerate(_, mlirbc, __):
    """Flag a module line with a ``stablehlo`` op and a ``x0x`` / ``<0x`` zero-sized dimension."""
    for row in mlirbc.split("\n"):
        if "stablehlo" not in row:
            continue
        if "x0x" in row or "<0x" in row:
            return True
    return False
|  |  | 
|  |  | 
def check_stablehlo_allreduce(errortxt, _, __):
    """Flag the failed legalization of ``stablehlo.all_reduce``."""
    signature = "failed to legalize operation 'stablehlo.all_reduce'"
    return signature in errortxt
|  |  | 
|  |  | 
def check_dot_shape(errortxt, _, __):
    """Flag an error line where an inferred shape conflicts with an op's declared return type."""
    return any(
        "error: inferred shape" in line
        and "is incompatible with return type of operation " in line
        for line in errortxt.split("\n")
    )
|  |  | 
|  |  | 
# Ordered map from triage category label to a predicate
# ``check(errortxt, mlir_text, runtime_crash) -> bool``.
# triage_test() walks this dict in insertion order and returns the FIRST key
# whose predicate is truthy, so ORDER MATTERS: specific failure signatures must
# stay above the broad fallbacks at the bottom.
# "?????" in a URL presumably means no tracking issue has been filed yet.
KnownChecks = {
    "https://github.com/iree-org/iree/issues/14255 (detensoring)": check_from_tensor,
    "https://github.com/iree-org/iree/issues/????? (unknown)": check_jax_unimplemented,
    "https://github.com/iree-org/iree/issues/13726 (collective)": check_collective,
    "https://github.com/iree-org/iree/issues/12410 (custom call)": check_custom_call,
    "https://github.com/iree-org/iree/issues/11018 (triangle)": check_triangular_solve,
    "https://github.com/iree-org/iree/issues/12263 (fft)": check_fft,
    "https://github.com/iree-org/iree/issues/14072 (complex convolution)": check_complex_convolution,
    "https://github.com/iree-org/iree/issues/10816 (cholesky)": check_cholesky,
    "https://github.com/iree-org/iree/issues/11761 (rng bit gen i8)": check_rng_bit_i8,
    "https://github.com/iree-org/iree/issues/????? (eigen decomp)": check_eigen_decomposition,
    "https://github.com/iree-org/iree/issues/13579 (scatter ui)": check_scatter_ui,
    "https://github.com/iree-org/iree/issues/13725 (cross repl)": check_cross_replica,
    "https://github.com/iree-org/iree/issues/13493 (dot i1)": check_dot_i1,
    "https://github.com/iree-org/iree/issues/13522 (roundeven)": check_roundeven,
    "https://github.com/iree-org/iree/issues/13577 (max/min f16)": check_min_max_f16,
    "https://github.com/iree-org/iree/issues/13523 (scatter)": check_scatter,
    # Must precede the generic "Runtime Crash" entry below, which matches any crash.
    "https://github.com/iree-org/iree/issues/13580 (scatter crash)": check_scatter_crash,
    "https://github.com/iree-org/iree/issues/14079 (shape_cast)": check_shape_cast,
    "https://github.com/iree-org/iree/issues/????? (optimized prgrm)": check_optimized_program,
    "https://github.com/iree-org/iree/issues/????? (donation)": check_donation,
    "https://github.com/iree-org/iree/issues/????? (python callback)": check_python_callback,
    "https://github.com/iree-org/iree/issues/????? (subspan)": check_subspan,
    "https://github.com/iree-org/iree/issues/14098 (unsigned topk)": check_unsigned_topk,
    "https://github.com/iree-org/iree/issues/????? (bounds indexing)": check_bounds_indexing,
    "https://github.com/iree-org/iree/issues/????? (nan correctness)": check_nan_correctness,
    "https://github.com/iree-org/iree/issues/????? (pointer mismatch)": check_pointer_mismatch,
    "https://github.com/iree-org/iree/issues/10841 (select and scatter)": check_select_and_scatter,
    "https://github.com/iree-org/iree/issues/????? (degenerate scatter)": check_degenerate_scatter,
    "https://github.com/iree-org/iree/issues/????? (cost analysis)": check_cost_analysis,
    "https://github.com/iree-org/iree/issues/????? (invalid option)": check_invalid_option,
    "https://github.com/iree-org/iree/issues/????? (inf mismatch)": check_inf_mismatch,
    "https://github.com/iree-org/iree/issues/????? (shape assertion)": check_shape_assertion,
    "https://github.com/iree-org/iree/issues/????? (vector contract)": check_vector_contract,
    "https://github.com/iree-org/iree/issues/????? (subbyte indexed)": check_subbyte_read,
    "https://github.com/iree-org/iree/issues/????? (buffer usage)": check_buffer_usage,
    "https://github.com/iree-org/iree/issues/????? (subbyte singleton)": check_subbyte_singleton,
    "https://github.com/iree-org/iree/issues/????? (max arg)": check_max_arg,
    "https://github.com/iree-org/iree/issues/????? (double support)": check_double_support,
    "https://github.com/iree-org/iree/issues/????? (zero extent)": check_stablehlo_degenerate,
    "https://github.com/iree-org/iree/issues/????? (all reduce)": check_stablehlo_allreduce,
    "https://github.com/iree-org/iree/issues/????? (stablehlo dot_general)": check_dot_shape,
    "(unknown backend)": check_unknown_backend,
    "(semaphore)": check_semaphore_overload,
    # Broad fallbacks: these match large classes of failures, keep them last.
    "Aborted (possible timeout)": check_aborted,
    "Runtime Crash": check_runtime_crash,
    "Compilation Failure": check_compilation,
    "Numerical Failures": check_numerical,
    # Catch-all that always matches, so every triaged test lands in some
    # category listed here.
    "Untriaged": lambda _, __, ___: True,
}
|  |  | 
|  |  | 
|  | def triage_test(test): | 
|  | files = sorted(os.listdir(f"{args.logdir}/{test}")) | 
|  | # Load the error.txt if it is available. | 
|  | error = "" | 
|  | if "error.txt" in files: | 
|  | with open(f"{args.logdir}/{test}/error.txt") as f: | 
|  | error = "".join(f.readlines()) | 
|  |  | 
|  | # Load the last bytecode file that was attempted to be compiled: | 
|  | mlirbc_count = len([f for f in files if "mlirbc" in f]) | 
|  | mlirbc_name = f"{mlirbc_count - 1}-program.mlirbc" | 
|  | vmfb_name = f"{mlirbc_count - 1}-program.vmfb" | 
|  |  | 
|  | runtime_crash = "CRASH_MARKER" in files | 
|  |  | 
|  | mlirbc = "" | 
|  | if mlirbc_count > 0: | 
|  | with mlir.make_ir_context() as ctx: | 
|  | with open(f"{args.logdir}/{test}/{mlirbc_name}", "rb") as f: | 
|  | mlirbc = f.read() | 
|  | mlirbc = str(mlir_ir.Module.parse(mlirbc)) | 
|  |  | 
|  | for checkname in KnownChecks: | 
|  | if KnownChecks[checkname](error, mlirbc, runtime_crash): | 
|  | return checkname | 
|  |  | 
|  | return "unknown error" | 
|  |  | 
|  |  | 
def filter_error_mapping(tests):
    """Triage every test in parallel and map each test name to its category.

    Runs ``triage_test`` over *tests* in a ``multiprocessing.Pool`` sized by
    ``--jobs``. When ``--jobs`` is unset, ``None`` lets the pool choose its default size.
    """
    workers = int(args.jobs) if args.jobs else args.jobs
    with multiprocessing.Pool(workers) as pool:
        categories = pool.map(triage_test, tests)
    return dict(zip(tests, categories))
|  |  | 
|  |  | 
def generate_summary(mapping, checks=None):
    """Group test names by triage category.

    Args:
      mapping: dict of test name -> category name (as produced by triage_test).
      checks: ordered iterable of category names to pre-seed with empty lists,
        so categories with no tests still appear. Defaults to ``KnownChecks``.

    Returns:
      dict of category name -> list of test names, in ``mapping`` iteration
      order. Categories not in *checks* (e.g. triage_test's
      ``"unknown error"`` fallback) are appended after the pre-seeded ones
      instead of raising KeyError.
    """
    if checks is None:
        checks = KnownChecks
    summary = {name: [] for name in checks}
    for test, category in mapping.items():
        summary.setdefault(category, []).append(test)
    return summary
|  |  | 
|  |  | 
def print_summary(summary, num_tests=None, num_failing=None):
    """Print an aligned ``name : count`` line per category, then pass/fail totals.

    Args:
      summary: dict of category name -> list of test names.
      num_tests: total number of tests; defaults to ``len(tests)`` (module global).
      num_failing: number of failing tests; defaults to ``len(failing)``
        (module global).
    """
    if num_tests is None:
        num_tests = len(tests)
    if num_failing is None:
        num_failing = len(failing)

    passstr = "Passing"
    failstr = "Failing"
    # Include the totals labels in the width so every line stays aligned even
    # when all category names are shorter than them.
    maxlen = max([len(passstr), len(failstr)] + [len(error) for error in summary])
    for error, names in summary.items():
        print(f"{error:<{maxlen}} : {len(names)}")

    print(f"{passstr:<{maxlen}} : {num_tests - num_failing}")
    print(f"{failstr:<{maxlen}} : {num_failing}")
|  |  | 
|  |  | 
if __name__ == "__main__":
    # Guarded so multiprocessing workers that re-import this module (the
    # "spawn"/"forkserver" start methods) do not re-run the triage and try to
    # start nested pools. `failing` stays a module global for print_summary.
    failing = filter_to_failures(tests)
    mapping = filter_error_mapping(failing)
    summary = generate_summary(mapping)
    print_summary(summary)