| // RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s |
| |
| module { |
| // CHECK:spv.module Logical GLSL450 |
| // CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input> |
| // CHECK: spv.func [[FN:@broadcast_2D_3D]] |
| // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer> |
| // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer> |
| |
| func @broadcast_2D_3D(%arg0: memref<12x42xi32>, %arg1: memref<3x12x42xi32>) |
| attributes {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} { |
| // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]] |
| // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]] |
| %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32> |
| %1 = "xla_hlo.broadcast"(%0) {broadcast_sizes = dense<[3]> : tensor<1xi64>} : (tensor<12x42xi32>) -> tensor<3x12x42xi32> |
| // CHECK: [[ARG1STOREPTR:%.*]] = spv.AccessChain [[ARG1]] |
| // CHECK: spv.Store "StorageBuffer" [[ARG1STOREPTR]], [[VAL]] |
| iree.store_output(%1 : tensor<3x12x42xi32>, %arg1 : memref<3x12x42xi32>) |
| return |
| } |
| } |
| |
| // ----- |
| |
| module { |
| // CHECK:spv.module Logical GLSL450 |
| // CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input> |
| // CHECK: spv.func [[FN:@broadcast_scalar_3D]] |
| // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer> |
| // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer> |
| func @broadcast_scalar_3D(%arg0: memref<i32>, %arg1: memref<3x12x42xi32>) |
| attributes {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} { |
| // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]] |
| // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]] |
| %0 = iree.load_input(%arg0 : memref<i32>) : tensor<i32> |
| %1 = "xla_hlo.broadcast"(%0) {broadcast_sizes = dense<[3, 12, 42]>: tensor<3xi64>} : (tensor<i32>) -> tensor<3x12x42xi32> |
| // CHECK: [[ARG1STOREPTR:%.*]] = spv.AccessChain [[ARG1]] |
| // CHECK: spv.Store "StorageBuffer" [[ARG1STOREPTR]], [[VAL]] |
| iree.store_output(%1 : tensor<3x12x42xi32>, %arg1 : memref<3x12x42xi32>) |
| return |
| } |
| } |