Update and re-enable the BERT test

Fixed the dot_dimension_numbers attributes to use the new #mhlo.dot custom attribute syntax.

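For reference, the attribute encoding changes from the old dictionary-attribute
form to the #mhlo.dot<...> form, abridged from the hunks below:

    // Old dictionary-attribute form:
    dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
                             lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
                             rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
                             rhs_contracting_dimensions = dense<3> : tensor<1xi64>}

    // New #mhlo.dot custom attribute form:
    dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1],
                                      lhs_contracting_dimensions = [3],
                                      rhs_batching_dimensions = [0, 1],
                                      rhs_contracting_dimensions = [3]>
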
PiperOrigin-RevId: 398770021
diff --git a/iree/test/e2e/models/BUILD b/iree/test/e2e/models/BUILD
index 021058d..1732ac0 100644
--- a/iree/test/e2e/models/BUILD
+++ b/iree/test/e2e/models/BUILD
@@ -16,12 +16,8 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
-# TODO(b/200955828): this test needs an update.
-DISABLED_TESTS = [
-    "bert_encoder_unrolled_fake_weights.mlir",
-]
-
 CHECK_FRAMEWORK_TESTS = [
+    "bert_encoder_unrolled_fake_weights.mlir",
     "mobilenetv3_fake_weights.mlir",
 ]
 
@@ -40,7 +36,7 @@
         ],
         include =
             ["*.mlir"],
-        exclude = CHECK_FRAMEWORK_TESTS + DISABLED_TESTS,
+        exclude = CHECK_FRAMEWORK_TESTS,
     ),
     data = [
         "//iree/tools:IreeFileCheck",
diff --git a/iree/test/e2e/models/CMakeLists.txt b/iree/test/e2e/models/CMakeLists.txt
index d761db8..3975477 100644
--- a/iree/test/e2e/models/CMakeLists.txt
+++ b/iree/test/e2e/models/CMakeLists.txt
@@ -33,6 +33,7 @@
   NAME
     check_linalg_on_tensors_dylib-llvm-aot_dylib
   SRCS
+    "bert_encoder_unrolled_fake_weights.mlir"
     "mobilenetv3_fake_weights.mlir"
   TARGET_BACKEND
     "dylib-llvm-aot"
@@ -46,6 +47,7 @@
   NAME
     check_linalg_on_tensors_vulkan-spirv_vulkan
   SRCS
+    "bert_encoder_unrolled_fake_weights.mlir"
     "mobilenetv3_fake_weights.mlir"
   TARGET_BACKEND
     "vulkan-spirv"
@@ -59,6 +61,7 @@
   NAME
     check_linalg_on_tensors_cuda_cuda
   SRCS
+    "bert_encoder_unrolled_fake_weights.mlir"
     "mobilenetv3_fake_weights.mlir"
   TARGET_BACKEND
     "cuda"
diff --git a/iree/test/e2e/models/bert_encoder_unrolled_fake_weights.mlir b/iree/test/e2e/models/bert_encoder_unrolled_fake_weights.mlir
index 7d01f2f..6860dc0 100644
--- a/iree/test/e2e/models/bert_encoder_unrolled_fake_weights.mlir
+++ b/iree/test/e2e/models/bert_encoder_unrolled_fake_weights.mlir
@@ -3405,7 +3405,7 @@
     %2284 = mhlo.add %2282, %2283 : tensor<384x128xf32>
     %2285 = "mhlo.reshape"(%2284) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %2286 = "mhlo.transpose"(%2285) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %2287 = "mhlo.dot_general"(%2286, %2281) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %2287 = "mhlo.dot_general"(%2286, %2281) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %2288 = mhlo.multiply %2287, %1114 : tensor<1x4x384x384xf32>
     %2289 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %2290 = mhlo.add %2288, %2289 : tensor<1x4x384x384xf32>
@@ -3424,7 +3424,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %2296 = "mhlo.broadcast_in_dim"(%2295) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %2297 = mhlo.divide %2294, %2296 : tensor<1x4x384x384xf32>
-    %2298 = "mhlo.dot_general"(%2297, %2267) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %2298 = "mhlo.dot_general"(%2297, %2267) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %2299 = "mhlo.transpose"(%2298) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %2300 = "mhlo.reshape"(%2299) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %2301 = "mhlo.dot"(%2300, %1132) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -3543,7 +3543,7 @@
     %2414 = mhlo.add %2412, %2413 : tensor<384x128xf32>
     %2415 = "mhlo.reshape"(%2414) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %2416 = "mhlo.transpose"(%2415) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %2417 = "mhlo.dot_general"(%2416, %2411) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %2417 = "mhlo.dot_general"(%2416, %2411) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %2418 = mhlo.multiply %2417, %1114 : tensor<1x4x384x384xf32>
     %2419 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %2420 = mhlo.add %2418, %2419 : tensor<1x4x384x384xf32>
@@ -3562,7 +3562,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %2426 = "mhlo.broadcast_in_dim"(%2425) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %2427 = mhlo.divide %2424, %2426 : tensor<1x4x384x384xf32>
-    %2428 = "mhlo.dot_general"(%2427, %2397) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %2428 = "mhlo.dot_general"(%2427, %2397) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %2429 = "mhlo.transpose"(%2428) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %2430 = "mhlo.reshape"(%2429) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %2431 = "mhlo.dot"(%2430, %1178) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -3681,7 +3681,7 @@
     %2544 = mhlo.add %2542, %2543 : tensor<384x128xf32>
     %2545 = "mhlo.reshape"(%2544) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %2546 = "mhlo.transpose"(%2545) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %2547 = "mhlo.dot_general"(%2546, %2541) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %2547 = "mhlo.dot_general"(%2546, %2541) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %2548 = mhlo.multiply %2547, %1114 : tensor<1x4x384x384xf32>
     %2549 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %2550 = mhlo.add %2548, %2549 : tensor<1x4x384x384xf32>
@@ -3700,7 +3700,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %2556 = "mhlo.broadcast_in_dim"(%2555) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %2557 = mhlo.divide %2554, %2556 : tensor<1x4x384x384xf32>
-    %2558 = "mhlo.dot_general"(%2557, %2527) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %2558 = "mhlo.dot_general"(%2557, %2527) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %2559 = "mhlo.transpose"(%2558) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %2560 = "mhlo.reshape"(%2559) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %2561 = "mhlo.dot"(%2560, %1684) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -3819,7 +3819,7 @@
     %2674 = mhlo.add %2672, %2673 : tensor<384x128xf32>
     %2675 = "mhlo.reshape"(%2674) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %2676 = "mhlo.transpose"(%2675) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %2677 = "mhlo.dot_general"(%2676, %2671) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %2677 = "mhlo.dot_general"(%2676, %2671) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %2678 = mhlo.multiply %2677, %1114 : tensor<1x4x384x384xf32>
     %2679 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %2680 = mhlo.add %2678, %2679 : tensor<1x4x384x384xf32>
@@ -3838,7 +3838,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %2686 = "mhlo.broadcast_in_dim"(%2685) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %2687 = mhlo.divide %2684, %2686 : tensor<1x4x384x384xf32>
-    %2688 = "mhlo.dot_general"(%2687, %2657) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %2688 = "mhlo.dot_general"(%2687, %2657) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %2689 = "mhlo.transpose"(%2688) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %2690 = "mhlo.reshape"(%2689) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %2691 = "mhlo.dot"(%2690, %1914) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -3957,7 +3957,7 @@
     %2804 = mhlo.add %2802, %2803 : tensor<384x128xf32>
     %2805 = "mhlo.reshape"(%2804) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %2806 = "mhlo.transpose"(%2805) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %2807 = "mhlo.dot_general"(%2806, %2801) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %2807 = "mhlo.dot_general"(%2806, %2801) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %2808 = mhlo.multiply %2807, %1114 : tensor<1x4x384x384xf32>
     %2809 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %2810 = mhlo.add %2808, %2809 : tensor<1x4x384x384xf32>
@@ -3976,7 +3976,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %2816 = "mhlo.broadcast_in_dim"(%2815) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %2817 = mhlo.divide %2814, %2816 : tensor<1x4x384x384xf32>
-    %2818 = "mhlo.dot_general"(%2817, %2787) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %2818 = "mhlo.dot_general"(%2817, %2787) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %2819 = "mhlo.transpose"(%2818) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %2820 = "mhlo.reshape"(%2819) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %2821 = "mhlo.dot"(%2820, %1960) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -4095,7 +4095,7 @@
     %2934 = mhlo.add %2932, %2933 : tensor<384x128xf32>
     %2935 = "mhlo.reshape"(%2934) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %2936 = "mhlo.transpose"(%2935) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %2937 = "mhlo.dot_general"(%2936, %2931) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %2937 = "mhlo.dot_general"(%2936, %2931) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %2938 = mhlo.multiply %2937, %1114 : tensor<1x4x384x384xf32>
     %2939 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %2940 = mhlo.add %2938, %2939 : tensor<1x4x384x384xf32>
@@ -4114,7 +4114,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %2946 = "mhlo.broadcast_in_dim"(%2945) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %2947 = mhlo.divide %2944, %2946 : tensor<1x4x384x384xf32>
-    %2948 = "mhlo.dot_general"(%2947, %2917) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %2948 = "mhlo.dot_general"(%2947, %2917) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %2949 = "mhlo.transpose"(%2948) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %2950 = "mhlo.reshape"(%2949) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %2951 = "mhlo.dot"(%2950, %2006) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -4233,7 +4233,7 @@
     %3064 = mhlo.add %3062, %3063 : tensor<384x128xf32>
     %3065 = "mhlo.reshape"(%3064) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3066 = "mhlo.transpose"(%3065) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3067 = "mhlo.dot_general"(%3066, %3061) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3067 = "mhlo.dot_general"(%3066, %3061) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3068 = mhlo.multiply %3067, %1114 : tensor<1x4x384x384xf32>
     %3069 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3070 = mhlo.add %3068, %3069 : tensor<1x4x384x384xf32>
@@ -4252,7 +4252,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3076 = "mhlo.broadcast_in_dim"(%3075) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3077 = mhlo.divide %3074, %3076 : tensor<1x4x384x384xf32>
-    %3078 = "mhlo.dot_general"(%3077, %3047) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3078 = "mhlo.dot_general"(%3077, %3047) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3079 = "mhlo.transpose"(%3078) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3080 = "mhlo.reshape"(%3079) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3081 = "mhlo.dot"(%3080, %2052) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -4371,7 +4371,7 @@
     %3194 = mhlo.add %3192, %3193 : tensor<384x128xf32>
     %3195 = "mhlo.reshape"(%3194) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3196 = "mhlo.transpose"(%3195) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3197 = "mhlo.dot_general"(%3196, %3191) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3197 = "mhlo.dot_general"(%3196, %3191) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3198 = mhlo.multiply %3197, %1114 : tensor<1x4x384x384xf32>
     %3199 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3200 = mhlo.add %3198, %3199 : tensor<1x4x384x384xf32>
@@ -4390,7 +4390,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3206 = "mhlo.broadcast_in_dim"(%3205) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3207 = mhlo.divide %3204, %3206 : tensor<1x4x384x384xf32>
-    %3208 = "mhlo.dot_general"(%3207, %3177) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3208 = "mhlo.dot_general"(%3207, %3177) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3209 = "mhlo.transpose"(%3208) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3210 = "mhlo.reshape"(%3209) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3211 = "mhlo.dot"(%3210, %2098) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -4509,7 +4509,7 @@
     %3324 = mhlo.add %3322, %3323 : tensor<384x128xf32>
     %3325 = "mhlo.reshape"(%3324) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3326 = "mhlo.transpose"(%3325) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3327 = "mhlo.dot_general"(%3326, %3321) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3327 = "mhlo.dot_general"(%3326, %3321) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3328 = mhlo.multiply %3327, %1114 : tensor<1x4x384x384xf32>
     %3329 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3330 = mhlo.add %3328, %3329 : tensor<1x4x384x384xf32>
@@ -4528,7 +4528,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3336 = "mhlo.broadcast_in_dim"(%3335) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3337 = mhlo.divide %3334, %3336 : tensor<1x4x384x384xf32>
-    %3338 = "mhlo.dot_general"(%3337, %3307) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3338 = "mhlo.dot_general"(%3337, %3307) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3339 = "mhlo.transpose"(%3338) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3340 = "mhlo.reshape"(%3339) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3341 = "mhlo.dot"(%3340, %2144) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -4647,7 +4647,7 @@
     %3454 = mhlo.add %3452, %3453 : tensor<384x128xf32>
     %3455 = "mhlo.reshape"(%3454) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3456 = "mhlo.transpose"(%3455) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3457 = "mhlo.dot_general"(%3456, %3451) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3457 = "mhlo.dot_general"(%3456, %3451) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3458 = mhlo.multiply %3457, %1114 : tensor<1x4x384x384xf32>
     %3459 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3460 = mhlo.add %3458, %3459 : tensor<1x4x384x384xf32>
@@ -4666,7 +4666,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3466 = "mhlo.broadcast_in_dim"(%3465) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3467 = mhlo.divide %3464, %3466 : tensor<1x4x384x384xf32>
-    %3468 = "mhlo.dot_general"(%3467, %3437) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3468 = "mhlo.dot_general"(%3467, %3437) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3469 = "mhlo.transpose"(%3468) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3470 = "mhlo.reshape"(%3469) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3471 = "mhlo.dot"(%3470, %2190) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -4785,7 +4785,7 @@
     %3584 = mhlo.add %3582, %3583 : tensor<384x128xf32>
     %3585 = "mhlo.reshape"(%3584) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3586 = "mhlo.transpose"(%3585) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3587 = "mhlo.dot_general"(%3586, %3581) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3587 = "mhlo.dot_general"(%3586, %3581) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3588 = mhlo.multiply %3587, %1114 : tensor<1x4x384x384xf32>
     %3589 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3590 = mhlo.add %3588, %3589 : tensor<1x4x384x384xf32>
@@ -4804,7 +4804,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3596 = "mhlo.broadcast_in_dim"(%3595) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3597 = mhlo.divide %3594, %3596 : tensor<1x4x384x384xf32>
-    %3598 = "mhlo.dot_general"(%3597, %3567) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3598 = "mhlo.dot_general"(%3597, %3567) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3599 = "mhlo.transpose"(%3598) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3600 = "mhlo.reshape"(%3599) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3601 = "mhlo.dot"(%3600, %1224) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -4923,7 +4923,7 @@
     %3714 = mhlo.add %3712, %3713 : tensor<384x128xf32>
     %3715 = "mhlo.reshape"(%3714) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3716 = "mhlo.transpose"(%3715) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3717 = "mhlo.dot_general"(%3716, %3711) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3717 = "mhlo.dot_general"(%3716, %3711) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3718 = mhlo.multiply %3717, %1114 : tensor<1x4x384x384xf32>
     %3719 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3720 = mhlo.add %3718, %3719 : tensor<1x4x384x384xf32>
@@ -4942,7 +4942,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3726 = "mhlo.broadcast_in_dim"(%3725) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3727 = mhlo.divide %3724, %3726 : tensor<1x4x384x384xf32>
-    %3728 = "mhlo.dot_general"(%3727, %3697) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3728 = "mhlo.dot_general"(%3727, %3697) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3729 = "mhlo.transpose"(%3728) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3730 = "mhlo.reshape"(%3729) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3731 = "mhlo.dot"(%3730, %1270) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -5061,7 +5061,7 @@
     %3844 = mhlo.add %3842, %3843 : tensor<384x128xf32>
     %3845 = "mhlo.reshape"(%3844) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3846 = "mhlo.transpose"(%3845) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3847 = "mhlo.dot_general"(%3846, %3841) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3847 = "mhlo.dot_general"(%3846, %3841) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3848 = mhlo.multiply %3847, %1114 : tensor<1x4x384x384xf32>
     %3849 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3850 = mhlo.add %3848, %3849 : tensor<1x4x384x384xf32>
@@ -5080,7 +5080,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3856 = "mhlo.broadcast_in_dim"(%3855) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3857 = mhlo.divide %3854, %3856 : tensor<1x4x384x384xf32>
-    %3858 = "mhlo.dot_general"(%3857, %3827) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3858 = "mhlo.dot_general"(%3857, %3827) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3859 = "mhlo.transpose"(%3858) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3860 = "mhlo.reshape"(%3859) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3861 = "mhlo.dot"(%3860, %1316) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -5199,7 +5199,7 @@
     %3974 = mhlo.add %3972, %3973 : tensor<384x128xf32>
     %3975 = "mhlo.reshape"(%3974) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %3976 = "mhlo.transpose"(%3975) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %3977 = "mhlo.dot_general"(%3976, %3971) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %3977 = "mhlo.dot_general"(%3976, %3971) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %3978 = mhlo.multiply %3977, %1114 : tensor<1x4x384x384xf32>
     %3979 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %3980 = mhlo.add %3978, %3979 : tensor<1x4x384x384xf32>
@@ -5218,7 +5218,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %3986 = "mhlo.broadcast_in_dim"(%3985) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %3987 = mhlo.divide %3984, %3986 : tensor<1x4x384x384xf32>
-    %3988 = "mhlo.dot_general"(%3987, %3957) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %3988 = "mhlo.dot_general"(%3987, %3957) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %3989 = "mhlo.transpose"(%3988) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %3990 = "mhlo.reshape"(%3989) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %3991 = "mhlo.dot"(%3990, %1362) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -5337,7 +5337,7 @@
     %4104 = mhlo.add %4102, %4103 : tensor<384x128xf32>
     %4105 = "mhlo.reshape"(%4104) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %4106 = "mhlo.transpose"(%4105) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %4107 = "mhlo.dot_general"(%4106, %4101) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %4107 = "mhlo.dot_general"(%4106, %4101) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %4108 = mhlo.multiply %4107, %1114 : tensor<1x4x384x384xf32>
     %4109 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %4110 = mhlo.add %4108, %4109 : tensor<1x4x384x384xf32>
@@ -5356,7 +5356,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %4116 = "mhlo.broadcast_in_dim"(%4115) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %4117 = mhlo.divide %4114, %4116 : tensor<1x4x384x384xf32>
-    %4118 = "mhlo.dot_general"(%4117, %4087) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %4118 = "mhlo.dot_general"(%4117, %4087) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %4119 = "mhlo.transpose"(%4118) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %4120 = "mhlo.reshape"(%4119) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %4121 = "mhlo.dot"(%4120, %1408) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -5475,7 +5475,7 @@
     %4234 = mhlo.add %4232, %4233 : tensor<384x128xf32>
     %4235 = "mhlo.reshape"(%4234) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %4236 = "mhlo.transpose"(%4235) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %4237 = "mhlo.dot_general"(%4236, %4231) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %4237 = "mhlo.dot_general"(%4236, %4231) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %4238 = mhlo.multiply %4237, %1114 : tensor<1x4x384x384xf32>
     %4239 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %4240 = mhlo.add %4238, %4239 : tensor<1x4x384x384xf32>
@@ -5494,7 +5494,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %4246 = "mhlo.broadcast_in_dim"(%4245) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %4247 = mhlo.divide %4244, %4246 : tensor<1x4x384x384xf32>
-    %4248 = "mhlo.dot_general"(%4247, %4217) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %4248 = "mhlo.dot_general"(%4247, %4217) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %4249 = "mhlo.transpose"(%4248) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %4250 = "mhlo.reshape"(%4249) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %4251 = "mhlo.dot"(%4250, %1454) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -5613,7 +5613,7 @@
     %4364 = mhlo.add %4362, %4363 : tensor<384x128xf32>
     %4365 = "mhlo.reshape"(%4364) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %4366 = "mhlo.transpose"(%4365) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %4367 = "mhlo.dot_general"(%4366, %4361) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %4367 = "mhlo.dot_general"(%4366, %4361) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %4368 = mhlo.multiply %4367, %1114 : tensor<1x4x384x384xf32>
     %4369 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %4370 = mhlo.add %4368, %4369 : tensor<1x4x384x384xf32>
@@ -5632,7 +5632,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %4376 = "mhlo.broadcast_in_dim"(%4375) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %4377 = mhlo.divide %4374, %4376 : tensor<1x4x384x384xf32>
-    %4378 = "mhlo.dot_general"(%4377, %4347) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %4378 = "mhlo.dot_general"(%4377, %4347) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %4379 = "mhlo.transpose"(%4378) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %4380 = "mhlo.reshape"(%4379) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %4381 = "mhlo.dot"(%4380, %1500) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -5751,7 +5751,7 @@
     %4494 = mhlo.add %4492, %4493 : tensor<384x128xf32>
     %4495 = "mhlo.reshape"(%4494) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %4496 = "mhlo.transpose"(%4495) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %4497 = "mhlo.dot_general"(%4496, %4491) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %4497 = "mhlo.dot_general"(%4496, %4491) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %4498 = mhlo.multiply %4497, %1114 : tensor<1x4x384x384xf32>
     %4499 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %4500 = mhlo.add %4498, %4499 : tensor<1x4x384x384xf32>
@@ -5770,7 +5770,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %4506 = "mhlo.broadcast_in_dim"(%4505) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %4507 = mhlo.divide %4504, %4506 : tensor<1x4x384x384xf32>
-    %4508 = "mhlo.dot_general"(%4507, %4477) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %4508 = "mhlo.dot_general"(%4507, %4477) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %4509 = "mhlo.transpose"(%4508) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %4510 = "mhlo.reshape"(%4509) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %4511 = "mhlo.dot"(%4510, %1546) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -5889,7 +5889,7 @@
     %4624 = mhlo.add %4622, %4623 : tensor<384x128xf32>
     %4625 = "mhlo.reshape"(%4624) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %4626 = "mhlo.transpose"(%4625) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %4627 = "mhlo.dot_general"(%4626, %4621) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %4627 = "mhlo.dot_general"(%4626, %4621) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %4628 = mhlo.multiply %4627, %1114 : tensor<1x4x384x384xf32>
     %4629 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %4630 = mhlo.add %4628, %4629 : tensor<1x4x384x384xf32>
@@ -5908,7 +5908,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %4636 = "mhlo.broadcast_in_dim"(%4635) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %4637 = mhlo.divide %4634, %4636 : tensor<1x4x384x384xf32>
-    %4638 = "mhlo.dot_general"(%4637, %4607) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %4638 = "mhlo.dot_general"(%4637, %4607) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %4639 = "mhlo.transpose"(%4638) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %4640 = "mhlo.reshape"(%4639) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %4641 = "mhlo.dot"(%4640, %1592) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -6027,7 +6027,7 @@
     %4754 = mhlo.add %4752, %4753 : tensor<384x128xf32>
     %4755 = "mhlo.reshape"(%4754) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %4756 = "mhlo.transpose"(%4755) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %4757 = "mhlo.dot_general"(%4756, %4751) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %4757 = "mhlo.dot_general"(%4756, %4751) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %4758 = mhlo.multiply %4757, %1114 : tensor<1x4x384x384xf32>
     %4759 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %4760 = mhlo.add %4758, %4759 : tensor<1x4x384x384xf32>
@@ -6046,7 +6046,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %4766 = "mhlo.broadcast_in_dim"(%4765) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %4767 = mhlo.divide %4764, %4766 : tensor<1x4x384x384xf32>
-    %4768 = "mhlo.dot_general"(%4767, %4737) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %4768 = "mhlo.dot_general"(%4767, %4737) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %4769 = "mhlo.transpose"(%4768) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %4770 = "mhlo.reshape"(%4769) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %4771 = "mhlo.dot"(%4770, %1638) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -6165,7 +6165,7 @@
     %4884 = mhlo.add %4882, %4883 : tensor<384x128xf32>
     %4885 = "mhlo.reshape"(%4884) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %4886 = "mhlo.transpose"(%4885) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %4887 = "mhlo.dot_general"(%4886, %4881) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %4887 = "mhlo.dot_general"(%4886, %4881) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %4888 = mhlo.multiply %4887, %1114 : tensor<1x4x384x384xf32>
     %4889 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %4890 = mhlo.add %4888, %4889 : tensor<1x4x384x384xf32>
@@ -6184,7 +6184,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %4896 = "mhlo.broadcast_in_dim"(%4895) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %4897 = mhlo.divide %4894, %4896 : tensor<1x4x384x384xf32>
-    %4898 = "mhlo.dot_general"(%4897, %4867) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %4898 = "mhlo.dot_general"(%4897, %4867) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %4899 = "mhlo.transpose"(%4898) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %4900 = "mhlo.reshape"(%4899) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %4901 = "mhlo.dot"(%4900, %1730) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -6303,7 +6303,7 @@
     %5014 = mhlo.add %5012, %5013 : tensor<384x128xf32>
     %5015 = "mhlo.reshape"(%5014) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %5016 = "mhlo.transpose"(%5015) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %5017 = "mhlo.dot_general"(%5016, %5011) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %5017 = "mhlo.dot_general"(%5016, %5011) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %5018 = mhlo.multiply %5017, %1114 : tensor<1x4x384x384xf32>
     %5019 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %5020 = mhlo.add %5018, %5019 : tensor<1x4x384x384xf32>
@@ -6322,7 +6322,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %5026 = "mhlo.broadcast_in_dim"(%5025) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %5027 = mhlo.divide %5024, %5026 : tensor<1x4x384x384xf32>
-    %5028 = "mhlo.dot_general"(%5027, %4997) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %5028 = "mhlo.dot_general"(%5027, %4997) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %5029 = "mhlo.transpose"(%5028) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %5030 = "mhlo.reshape"(%5029) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %5031 = "mhlo.dot"(%5030, %1776) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -6441,7 +6441,7 @@
     %5144 = mhlo.add %5142, %5143 : tensor<384x128xf32>
     %5145 = "mhlo.reshape"(%5144) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %5146 = "mhlo.transpose"(%5145) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %5147 = "mhlo.dot_general"(%5146, %5141) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %5147 = "mhlo.dot_general"(%5146, %5141) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %5148 = mhlo.multiply %5147, %1114 : tensor<1x4x384x384xf32>
     %5149 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %5150 = mhlo.add %5148, %5149 : tensor<1x4x384x384xf32>
@@ -6460,7 +6460,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %5156 = "mhlo.broadcast_in_dim"(%5155) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %5157 = mhlo.divide %5154, %5156 : tensor<1x4x384x384xf32>
-    %5158 = "mhlo.dot_general"(%5157, %5127) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %5158 = "mhlo.dot_general"(%5157, %5127) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %5159 = "mhlo.transpose"(%5158) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %5160 = "mhlo.reshape"(%5159) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %5161 = "mhlo.dot"(%5160, %1822) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>
@@ -6579,7 +6579,7 @@
     %5274 = mhlo.add %5272, %5273 : tensor<384x128xf32>
     %5275 = "mhlo.reshape"(%5274) : (tensor<384x128xf32>) -> tensor<1x384x4x32xf32>
     %5276 = "mhlo.transpose"(%5275) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x384x4x32xf32>) -> tensor<1x4x384x32xf32>
-    %5277 = "mhlo.dot_general"(%5276, %5271) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<3> : tensor<1xi64>}} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
+    %5277 = "mhlo.dot_general"(%5276, %5271) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [3]>} : (tensor<1x4x384x32xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x384xf32>
     %5278 = mhlo.multiply %5277, %1114 : tensor<1x4x384x384xf32>
     %5279 = "mhlo.broadcast_in_dim"(%2254) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x384x384xf32>) -> tensor<1x4x384x384xf32>
     %5280 = mhlo.add %5278, %5279 : tensor<1x4x384x384xf32>
@@ -6598,7 +6598,7 @@
     }) {dimensions = dense<3> : tensor<1xi64>} : (tensor<1x4x384x384xf32>, tensor<f32>) -> tensor<1x4x384xf32>
     %5286 = "mhlo.broadcast_in_dim"(%5285) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x384xf32>) -> tensor<1x4x384x384xf32>
     %5287 = mhlo.divide %5284, %5286 : tensor<1x4x384x384xf32>
-    %5288 = "mhlo.dot_general"(%5287, %5257) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<3> : tensor<1xi64>, rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
+    %5288 = "mhlo.dot_general"(%5287, %5257) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_batching_dimensions = [0, 1], rhs_contracting_dimensions = [2]>} : (tensor<1x4x384x384xf32>, tensor<1x4x384x32xf32>) -> tensor<1x4x384x32xf32>
     %5289 = "mhlo.transpose"(%5288) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>} : (tensor<1x4x384x32xf32>) -> tensor<1x384x4x32xf32>
     %5290 = "mhlo.reshape"(%5289) : (tensor<1x384x4x32xf32>) -> tensor<384x128xf32>
     %5291 = "mhlo.dot"(%5290, %1868) : (tensor<384x128xf32>, tensor<128x128xf32>) -> tensor<384x128xf32>