Adapt to function op change in SPIR-V dialect
SPIR-V dialect changes to use spv.func to represent functions.
PiperOrigin-RevId: 294947689
diff --git a/iree/compiler/Translation/SPIRV/ReductionCodegen/BUILD b/iree/compiler/Translation/SPIRV/ReductionCodegen/BUILD
index 5566307..5fb973b 100644
--- a/iree/compiler/Translation/SPIRV/ReductionCodegen/BUILD
+++ b/iree/compiler/Translation/SPIRV/ReductionCodegen/BUILD
@@ -31,6 +31,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SPIRVDialect",
+ "@llvm-project//mlir:SPIRVLowering",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
"@org_tensorflow//tensorflow/compiler/mlir/xla:hlo",
diff --git a/iree/compiler/Translation/SPIRV/ReductionCodegen/ReductionFnLowering.cpp b/iree/compiler/Translation/SPIRV/ReductionCodegen/ReductionFnLowering.cpp
index 77b6a53..e10d587 100644
--- a/iree/compiler/Translation/SPIRV/ReductionCodegen/ReductionFnLowering.cpp
+++ b/iree/compiler/Translation/SPIRV/ReductionCodegen/ReductionFnLowering.cpp
@@ -20,6 +20,7 @@
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Function.h"
@@ -33,7 +34,7 @@
namespace {
/// Type converter for legalization of reduction apply function.
-class SPIRVReductionTypeConverter : public TypeConverter {
+class SPIRVReductionTypeConverter : public SPIRVTypeConverter {
public:
Type convertType(Type t) override;
};
@@ -177,8 +178,8 @@
//===----------------------------------------------------------------------===//
LogicalResult lowerReductionApplyFunction(MLIRContext *context,
ArrayRef<Operation *> fns) {
- OwningRewritePatternList patterns;
SPIRVReductionTypeConverter typeConverter;
+ OwningRewritePatternList patterns;
patterns
.insert<ReductionApplyFnConversion,
ReductionOpConversion<xla_hlo::MinOp, spirv::AtomicSMinOp>,
@@ -186,10 +187,10 @@
ReductionOpConversion<AddIOp, spirv::AtomicIAddOp>,
ReturnOpConversion<IREE::ReturnOp>, ReturnOpConversion<ReturnOp>>(
context, typeConverter);
+ populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+
ConversionTarget target(*context);
target.addLegalDialect<spirv::SPIRVDialect>();
- target.addDynamicallyLegalOp<FuncOp>(
- [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
if (failed(applyPartialConversion(fns, target, patterns))) {
return failure();
}
diff --git a/iree/compiler/Translation/SPIRV/ReductionCodegen/test/ops.mlir b/iree/compiler/Translation/SPIRV/ReductionCodegen/test/ops.mlir
index c32d81c..909f6ad 100644
--- a/iree/compiler/Translation/SPIRV/ReductionCodegen/test/ops.mlir
+++ b/iree/compiler/Translation/SPIRV/ReductionCodegen/test/ops.mlir
@@ -1,6 +1,6 @@
// RUN: iree-opt -iree-spirv-reduction-fn-lowering -o - %s | IreeFileCheck %s
-// CHECK-LABEL: func @reduction_max_apply
+// CHECK-LABEL: spv.func @reduction_max_apply
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: i32
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<i32, StorageBuffer>
func @reduction_max_apply(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
@@ -9,7 +9,7 @@
iree.return %0 : tensor<i32>
}
-// CHECK-LABEL: func @reduction_min_apply
+// CHECK-LABEL: spv.func @reduction_min_apply
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: i32
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<i32, StorageBuffer>
func @reduction_min_apply(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
@@ -18,7 +18,7 @@
iree.return %0 : tensor<i32>
}
-// CHECK-LABEL: func @reduction_iadd_apply
+// CHECK-LABEL: spv.func @reduction_iadd_apply
func @reduction_iadd_apply(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK: spv.AtomicIAdd
%0 = std.addi %arg0, %arg1 : tensor<i32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp
index 915b687..ca7bb66 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp
@@ -135,8 +135,7 @@
}
auto entryFnType = builder.getFunctionType(entryFnArgTypes, ArrayRef<Type>());
- auto entryFn = builder.create<FuncOp>(loc, fn.getName(), entryFnType,
- ArrayRef<NamedAttribute>());
+ auto entryFn = builder.create<spirv::FuncOp>(loc, fn.getName(), entryFnType);
entryFn.addEntryBlock();
SmallVector<int32_t, 3> workGroupSize;
@@ -301,7 +300,7 @@
}
LogicalResult SPIRVCodegenImpl::lowerFunction(
- OpBuilder &builder, FuncOp fn, FuncOp entryFn,
+ OpBuilder &builder, FuncOp fn, spirv::FuncOp entryFn,
TensorIndexToScalarValueMap &valueCache) {
if (failed(createLaunchGuard(builder, fn))) {
return failure();
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h
index 22d45bf..6a35a39 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h
@@ -277,7 +277,8 @@
Value origArg, AffineMap indexMap);
/// Lowers the body of the function in the original dialect to SPIR-V dialect.
- LogicalResult lowerFunction(OpBuilder &builder, FuncOp fn, FuncOp entryFn,
+ LogicalResult lowerFunction(OpBuilder &builder, FuncOp fn,
+ spirv::FuncOp entryFn,
TensorIndexToScalarValueMap &valueCache);
/// Method to lower the operations within the dispatch function.
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/adjust_integer_width.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/adjust_integer_width.mlir
index 23a692f..ebb3f77 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/adjust_integer_width.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/adjust_integer_width.mlir
@@ -7,7 +7,7 @@
// CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
- func @foo_i64(%arg0 : i64, %arg1 : i64) -> () {
+ spv.func @foo_i64(%arg0 : i64, %arg1 : i64) -> () "None" {
// CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
@@ -32,7 +32,7 @@
// CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i16 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i16 [0]>, StorageBuffer>
- func @foo_i16(%arg0 : i16, %arg1 : i16) -> () {
+ spv.func @foo_i16(%arg0 : i16, %arg1 : i16) -> () "None" {
// CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
@@ -56,7 +56,7 @@
// CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i8 [0]>, StorageBuffer>
- func @foo_i8(%arg0 : i8, %arg1 : i8) -> () {
+ spv.func @foo_i8(%arg0 : i8, %arg1 : i8) -> () "None" {
// CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
@@ -80,7 +80,7 @@
// CHECK: spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
spv.globalVariable @constant_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
- func @foo_i1(%arg0 : i1, %arg1 : i1) -> () {
+ spv.func @foo_i1(%arg0 : i1, %arg1 : i1) -> () "None" {
// CHECK: spv._address_of {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.AccessChain {{.*}} : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
@@ -105,7 +105,7 @@
spv.globalVariable @globalInvocationID built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
spv.globalVariable @arg_0 bind(0, 0) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
spv.globalVariable @arg_1 bind(0, 1) : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
- func @add(%arg0: i64, %arg1: i64) -> () {
+ spv.func @add(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -118,7 +118,7 @@
spv.Return
}
- func @sub(%arg0: i64, %arg1: i64) -> () {
+ spv.func @sub(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -132,7 +132,7 @@
spv.Return
}
- func @mul(%arg0: i64, %arg1: i64) -> () {
+ spv.func @mul(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -145,7 +145,7 @@
spv.Return
}
- func @sdiv(%arg0: i64, %arg1: i64) -> () {
+ spv.func @sdiv(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -158,7 +158,7 @@
spv.Return
}
- func @smod(%arg0: i64, %arg1: i64) -> () {
+ spv.func @smod(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -171,7 +171,7 @@
spv.Return
}
- func @srem(%arg0: i64, %arg1: i64) -> () {
+ spv.func @srem(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -184,7 +184,7 @@
spv.Return
}
- func @udiv(%arg0: i64, %arg1: i64) -> () {
+ spv.func @udiv(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -197,7 +197,7 @@
spv.Return
}
- func @umod(%arg0: i64, %arg1: i64) -> () {
+ spv.func @umod(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -210,7 +210,7 @@
spv.Return
}
- func @abs(%arg0: i64, %arg1: i64) -> () {
+ spv.func @abs(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -223,7 +223,7 @@
spv.Return
}
- func @smax(%arg0: i64, %arg1: i64) -> () {
+ spv.func @smax(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -236,7 +236,7 @@
spv.Return
}
- func @smin(%arg0: i64, %arg1: i64) -> () {
+ spv.func @smin(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -249,7 +249,7 @@
spv.Return
}
- func @sign(%arg0: i64, %arg1: i64) -> () {
+ spv.func @sign(%arg0: i64, %arg1: i64) -> () "None" {
%0 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
@@ -262,7 +262,7 @@
spv.Return
}
- func @constant_i64(%arg1: i64) -> () {
+ spv.func @constant_i64(%arg1: i64) -> () "None" {
// CHECK: spv.constant 1337 : i32
%0 = spv.constant 1337 : i64
%1 = spv.constant 0 : i32
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir
index 26bfedd..30dd16f 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
// CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
-// CHECK: func @mul_1D
+// CHECK: spv.func @mul_1D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<4 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<4 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG2:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<4 x f32 [4]> [0]>, StorageBuffer>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir
index 672bd76..d8f1979 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir
@@ -3,7 +3,7 @@
module {
// CHECK:spv.module "Logical" "GLSL450"
// CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK: func [[FN:@broadcast_2D_3D]]
+ // CHECK: spv.func [[FN:@broadcast_2D_3D]]
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
@@ -25,7 +25,7 @@
module {
// CHECK:spv.module "Logical" "GLSL450"
// CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK: func [[FN:@broadcast_scalar_3D]]
+ // CHECK: spv.func [[FN:@broadcast_scalar_3D]]
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
func @broadcast_scalar_3D(%arg0: memref<i32>, %arg1: memref<3x12x42xi32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir
index 85c4c2a..80378e9 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir
@@ -3,7 +3,7 @@
module {
// CHECK:spv.module "Logical" "GLSL450"
// CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK: func [[FN:@broadcast_in_dim_2D_3D]]
+ // CHECK: spv.func [[FN:@broadcast_in_dim_2D_3D]]
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
func @broadcast_in_dim_2D_3D(%arg0: memref<12x42xi32>, %arg1: memref<3x12x42xi32>)
@@ -20,7 +20,7 @@
module {
// CHECK:spv.module "Logical" "GLSL450"
// CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK: func [[FN:@broadcast_in_dim_scalar_3D]]
+ // CHECK: spv.func [[FN:@broadcast_in_dim_scalar_3D]]
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
func @broadcast_in_dim_scalar_3D(%arg0: memref<i32>, %arg1: memref<3x12x42xi32>)
@@ -65,7 +65,7 @@
// -----
module {
- // CHECK: func @const_int_nonsplat
+ // CHECK: spv.func @const_int_nonsplat
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1008 x i32 [4]> [0]>, StorageBuffer>
func @const_int_nonsplat(%arg0: memref<2x12x42xi32>)
attributes {iree.executable.export, iree.executable.workload = dense<[42, 12, 2]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir
index d5c0e62..79510a0 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir
@@ -2,7 +2,7 @@
module {
// CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK: func @concatenate
+ // CHECK: spv.func @concatenate
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<64 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
func @concatenate(%arg0: memref<1x64xf32>, %arg1 : memref<1x10xf32>, %arg2 : memref<1x74xf32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir
index 85a284e..5012f3e 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir
@@ -3,7 +3,7 @@
module {
// CHECK:spv.module "Logical" "GLSL450"
// CHECK-DAG: spv.globalVariable [[GLOBALIDVAR:@.*]] built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
- // CHECK: func [[FN:@simple_load_store]]
+ // CHECK: spv.func [[FN:@simple_load_store]]
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: {{spirv|spv}}.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir
index 80581a1..23f0ec9 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
module {
- // CHECK-LABEL: func @extract_element
+ // CHECK-LABEL: spv.func @extract_element
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
func @extract_element(%arg0: memref<i1>, %arg1: memref<i1>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir
index 67ae2f2..528c852 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv %s | IreeFileCheck %s
module {
- // CHECK-LABEL: func @foo
+ // CHECK-LABEL: spv.func @foo
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<50 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
func @foo(%arg0: memref<5x1x10xf32>, %arg1: memref<i64>, %arg2: memref<1x10xf32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir
index a5f9ece..c8c9909 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
module {
- // CHECK: func @pad_zero_interior
+ // CHECK: spv.func @pad_zero_interior
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<216 x f32 [4]> [0]>, StorageBuffer>
func @pad_zero_interior(%arg0 : memref<12x4xf32>, %arg1 : memref<18x12xf32>)
@@ -26,7 +26,7 @@
// -----
module {
- // CHECK: func @pad_no_op
+ // CHECK: spv.func @pad_no_op
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
func @pad_no_op(%arg0 : memref<12x4xf32>, %arg1 : memref<12x4xf32>)
@@ -50,7 +50,7 @@
// -----
module {
- // CHECK: func @pad_zero_interior
+ // CHECK: spv.func @pad_zero_interior
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<522 x f32 [4]> [0]>, StorageBuffer>
func @pad_zero_interior(%arg0 : memref<12x4xf32>, %arg1 : memref<29x18xf32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir
index bbed0dd..666c153 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
module {
- // CHECK: func @reshape_2D_2D
+ // CHECK: spv.func @reshape_2D_2D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_2D_2D(%arg0: memref<24x21xi32>, %arg1: memref<12x42xi32>)
@@ -20,7 +20,7 @@
// -----
module {
- // CHECK: func @reshape_3D_2D
+ // CHECK: spv.func @reshape_3D_2D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_3D_2D(%arg0: memref<4x6x21xi32>, %arg1: memref<12x42xi32>)
@@ -39,7 +39,7 @@
// -----
module {
- // CHECK: func @reshape_2D_3D
+ // CHECK: spv.func @reshape_2D_3D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_2D_3D(%arg0: memref<24x21xi32>, %arg1: memref<12x6x7xi32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir
index f7c03f6..49f4935 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir
@@ -2,7 +2,7 @@
module {
// CHECK-LABEL: spv.module
- // CHECK: func @reshape_4D_3D
+ // CHECK: spv.func @reshape_4D_3D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_4D_3D(%arg0: memref<12x42x1xi32>, %arg1: memref<12x42xi32>)
@@ -22,7 +22,7 @@
module {
// CHECK-LABEL: spv.module
- // CHECK: func @reshape_4D_2D
+ // CHECK: spv.func @reshape_4D_2D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1: memref<12x42xi32>)
@@ -42,7 +42,7 @@
module {
// CHECK-LABEL: spv.module
- // CHECK: func @reshape_2D_4D
+ // CHECK: spv.func @reshape_2D_4D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1: memref<12x42x1x1xi32>)
@@ -62,7 +62,7 @@
module {
// CHECK-LABEL: spv.module
- // CHECK: func @reshape_2D_4D
+ // CHECK: spv.func @reshape_2D_4D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1: memref<12x1x1x42xi32>)
@@ -82,7 +82,7 @@
module {
// CHECK-LABEL: spv.module
- // CHECK: func @reshape_2D_4D
+ // CHECK: spv.func @reshape_2D_4D
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
func @reshape_2D_4D(%arg0: memref<12x1x1x42xi32>, %arg1: memref<12x42xi32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir
index 4ce0160..d5013c4 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
module {
- // CHECK: func @reverse_2d
+ // CHECK: spv.func @reverse_2d
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
func @reverse_2d(%arg0: memref<12x12xf32>, %arg1 : memref<12x12xf32>)
@@ -20,7 +20,7 @@
// -----
module {
- // CHECK: func @reverse_3d
+ // CHECK: spv.func @reverse_3d
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<27 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<27 x f32 [4]> [0]>, StorageBuffer>
func @reverse_3d(%arg0: memref<3x3x3xf32>, %arg1 : memref<3x3x3xf32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir
index b3f5c64..a556c92 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
module {
- // CHECK: func @slice_unit_stride
+ // CHECK: spv.func @slice_unit_stride
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<36 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<6 x f32 [4]> [0]>, StorageBuffer>
func @slice_unit_stride(%arg0: memref<6x6xf32>, %arg1: memref<2x3xf32>)
@@ -20,7 +20,7 @@
// -----
module {
- // CHECK: func @slice_non_unit_stride
+ // CHECK: spv.func @slice_non_unit_stride
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<36 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<6 x f32 [4]> [0]>, StorageBuffer>
func @slice_non_unit_stride(%arg0: memref<6x6xf32>, %arg1: memref<2x3xf32>)
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir
index a851057..5311af3 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
module {
- // CHECK: func @reduction_entry
+ // CHECK: spv.func @reduction_entry
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<5 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
@@ -23,7 +23,7 @@
// -----
module {
- // CHECK: func @reduction_2D_dim0_entry
+ // CHECK: spv.func @reduction_2D_dim0_entry
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<20 x i32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
// CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<4 x i32 [4]> [0]>, StorageBuffer>
@@ -45,7 +45,7 @@
// -----
module {
- // CHECK: func @reduction_2D_dim1_entry
+ // CHECK: spv.func @reduction_2D_dim1_entry
// CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<5 x i32 [4]> [0]>, StorageBuffer>
func @reduction_2D_dim1_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<5xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 1 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[4, 5, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
// CHECK: [[GLOBALIDPTR:%[a-zA-Z0-9_]*]] = spv._address_of @globalInvocationID
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir
index aef1ac3..a6cf7af 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
module {
- // CHECK: func @transpose_add
+ // CHECK: spv.func @transpose_add
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
func @transpose_add(%arg0: memref<12x12xf32>, %arg1: memref<12x12xf32>)
diff --git a/iree/compiler/Utils/IREECodegenUtils.cpp b/iree/compiler/Utils/IREECodegenUtils.cpp
index b0b1f3b..1b65cb9 100644
--- a/iree/compiler/Utils/IREECodegenUtils.cpp
+++ b/iree/compiler/Utils/IREECodegenUtils.cpp
@@ -18,16 +18,16 @@
namespace iree_compiler {
/// Gets the launch size associated with the dispatch function.
-LogicalResult getLegacyLaunchSize(FuncOp funcOp,
+LogicalResult getLegacyLaunchSize(Operation *funcOp,
SmallVectorImpl<int64_t> &launchSize) {
- if (!funcOp.getAttr("iree.executable.export")) {
- return funcOp.emitError(
+ if (!funcOp->getAttr("iree.executable.export")) {
+ return funcOp->emitError(
"expected operation to be in dispatch function to get launch size");
}
auto workloadAttr =
- funcOp.getAttrOfType<DenseElementsAttr>("iree.executable.workload");
+ funcOp->getAttrOfType<DenseElementsAttr>("iree.executable.workload");
if (!workloadAttr) {
- return funcOp.emitError(
+ return funcOp->emitError(
"unable to find workload size, missing attribute "
"iree.executable.workload in dispatch function");
}
@@ -49,16 +49,16 @@
/// Gets the workgroup size.
template <typename intType>
-LogicalResult getLegacyWorkGroupSize(FuncOp funcOp,
+LogicalResult getLegacyWorkGroupSize(Operation *funcOp,
SmallVectorImpl<intType> &workGroupSize) {
- if (!funcOp.getAttr("iree.executable.export")) {
- return funcOp.emitError(
+ if (!funcOp->getAttr("iree.executable.export")) {
+ return funcOp->emitError(
"expected operation to be in dispatch function to get launch size");
}
- auto workGroupSizeAttr =
- funcOp.getAttrOfType<DenseElementsAttr>("iree.executable.workgroup_size");
+ auto workGroupSizeAttr = funcOp->getAttrOfType<DenseElementsAttr>(
+ "iree.executable.workgroup_size");
if (!workGroupSizeAttr) {
- return funcOp.emitError(
+ return funcOp->emitError(
"unable to find workload size, missing attribute "
"iree.executable.workload in dispatch function");
}
@@ -70,9 +70,9 @@
}
template LogicalResult getLegacyWorkGroupSize<int32_t>(
- FuncOp funcOp, SmallVectorImpl<int32_t> &workGroupSize);
+ Operation *funcOp, SmallVectorImpl<int32_t> &workGroupSize);
template LogicalResult getLegacyWorkGroupSize<int64_t>(
- FuncOp funcOp, SmallVectorImpl<int64_t> &workGroupSize);
+ Operation *funcOp, SmallVectorImpl<int64_t> &workGroupSize);
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Utils/IREECodegenUtils.h b/iree/compiler/Utils/IREECodegenUtils.h
index 908de04..3f82e60 100644
--- a/iree/compiler/Utils/IREECodegenUtils.h
+++ b/iree/compiler/Utils/IREECodegenUtils.h
@@ -26,12 +26,12 @@
// TODO(ravishankarm): remove this; it does not work with dynamic shapes.
/// Gets the launch size associated with the dispatch function.
-LogicalResult getLegacyLaunchSize(FuncOp funcOp,
+LogicalResult getLegacyLaunchSize(Operation *funcOp,
SmallVectorImpl<int64_t> &launchSize);
/// Gets the workgroup size. Has to be a static constant.
template <typename intType>
-LogicalResult getLegacyWorkGroupSize(FuncOp funcOp,
+LogicalResult getLegacyWorkGroupSize(Operation *funcOp,
SmallVectorImpl<intType> &workGroupSize);
} // namespace iree_compiler