// blob: 88c25a50268762b0a36c02387c207884080b2534 [file] [log] [blame]
module @jit__lambda_ {
// Public entry point of the module. Takes one tensor<8xf32> argument (JAX
// argument name "x", annotated with sharding "{replicated}") and returns one
// tensor<8xf32>. All work is delegated to the private @"<lambda>" function.
func.func public @main(%arg0: tensor<8xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<8xf32> {jax.result_info = ""}) {
  // Forward the argument unchanged and hand back whatever the callee produces.
  %result = func.call @"<lambda>"(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
  func.return %result : tensor<8xf32>
}
// Body of the jitted lambda: convolves the 8-element input with the fixed
// 3-tap kernel [0.5, 1.0, 0.5] via @convolve, wrapping the input and the
// output in @Sharding custom calls that carry the sharding annotations.
func.func private @"<lambda>"(%arg0: tensor<8xf32>) -> tensor<8xf32> {
  // The kernel is a compile-time constant and does not depend on the input.
  %kernel = stablehlo.constant dense<[5.000000e-01, 1.000000e+00, 5.000000e-01]> : tensor<3xf32>
  // Annotate the input with the "{replicated}" sharding.
  %x_replicated = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = "{replicated}"} : (tensor<8xf32>) -> tensor<8xf32>
  %convolved = func.call @convolve(%x_replicated, %kernel) : (tensor<8xf32>, tensor<3xf32>) -> tensor<8xf32>
  // Annotate the result with the "{maximal device=0}" sharding.
  %out = stablehlo.custom_call @Sharding(%convolved) {mhlo.sharding = "{maximal device=0}"} : (tensor<8xf32>) -> tensor<8xf32>
  func.return %out : tensor<8xf32>
}
// Thin wrapper: passes the signal (%arg0, 8 elements) and the kernel
// (%arg1, 3 elements) straight through to @_conv and returns its result.
func.func private @convolve(%arg0: tensor<8xf32>, %arg1: tensor<3xf32>) -> tensor<8xf32> {
  %conv_out = func.call @_conv(%arg0, %arg1) : (tensor<8xf32>, tensor<3xf32>) -> tensor<8xf32>
  func.return %conv_out : tensor<8xf32>
}
// 1-D convolution of an 8-element signal (%arg0) with a 3-element kernel
// (%arg1), producing an 8-element result.
//
// Steps:
//   1. Reverse the kernel with @_flip (presumably because
//      stablehlo.convolution slides the kernel without flipping it, so the
//      flip yields a true convolution -- confirm against the StableHLO spec).
//   2. Lift both operands to rank 3 so they match the [b, f, 0] / [o, i, 0]
//      layouts the convolution op is given: 1x1x8 for the signal and 1x1x3
//      for the kernel.
//   3. Convolve with stride 1 and one element of padding on each side; the
//      result type shows the spatial length stays 8.
//   4. Drop the two leading unit dimensions again.
//
// Step 4 used to be a stablehlo.gather at the constant start index (0, 0)
// with slice sizes [1, 1, 8] and collapsed dims [0, 1], i.e. operand[0, 0, :].
// On a 1x1x8 operand that is the same as a reshape to tensor<8xf32>, so the
// two zero constants, two broadcasts, the concatenate and the gather are
// replaced by one stablehlo.reshape.
func.func private @_conv(%arg0: tensor<8xf32>, %arg1: tensor<3xf32>) -> tensor<8xf32> {
  %flipped = call @_flip(%arg1) : (tensor<3xf32>) -> tensor<3xf32>
  // Place the data along the last (spatial) axis of a rank-3 tensor.
  %lhs = stablehlo.broadcast_in_dim %arg0, dims = [2] : (tensor<8xf32>) -> tensor<1x1x8xf32>
  %rhs = stablehlo.broadcast_in_dim %flipped, dims = [2] : (tensor<3xf32>) -> tensor<1x1x3xf32>
  %conv = stablehlo.convolution(%lhs, %rhs) dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0], window = {stride = [1], pad = [[1, 1]], lhs_dilate = [1], rhs_dilate = [1], reverse = [0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x1x8xf32>, tensor<1x1x3xf32>) -> tensor<1x1x8xf32>
  // Squeeze the unit batch and feature dimensions: 1x1x8 -> 8.
  %squeezed = stablehlo.reshape %conv : (tensor<1x1x8xf32>) -> tensor<8xf32>
  return %squeezed : tensor<8xf32>
}
// Returns the 3-element input with its elements in reverse order along
// dimension 0 (the only dimension).
func.func private @_flip(%arg0: tensor<3xf32>) -> tensor<3xf32> {
  %reversed = stablehlo.reverse %arg0, dims = [0] : tensor<3xf32>
  func.return %reversed : tensor<3xf32>
}
}