[Stream] Improve the function signature verifier for CmdDispatchOp (#15886)
This adds additional verification to the exported function within
a dispatch that the binding counts and uniform operand types are
consistent with the dispatch op.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 0d006b1..ff321ba 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -2858,7 +2858,45 @@
return failure();
}
- // TODO(benvanik): verify that the target function has matching operands.
+ // Verify that the exported function, if present, matches the signature of
+ // the dispatch.
+ auto funcOp = exportOp.lookupFunctionRef();
+ if (!funcOp) {
+ return success();
+ }
+
+ TypeRange uniformTypes = getUniformOperands().getTypes();
+ int64_t numResources = getResources().size();
+
+ auto entryPointType = funcOp.getFunctionType();
+ SmallVector<Type> uniformEntryPointTypes;
+ int64_t bindingCounts = 0;
+ for (auto entryPointArg : entryPointType.getInputs()) {
+ if (isa<IREE::Stream::BindingType>(entryPointArg)) {
+ bindingCounts++;
+ } else {
+ uniformEntryPointTypes.push_back(entryPointArg);
+ }
+ }
+ if (uniformTypes.size() != uniformEntryPointTypes.size()) {
+ return emitOpError("function type mismatch; expected ")
+ << uniformTypes.size()
+ << " uniform arguments on exported function, but has "
+ << uniformEntryPointTypes.size();
+ }
+ if (numResources != bindingCounts) {
+ return emitOpError("function type mismatch; expected ")
+ << numResources
+ << " binding arguments on exported function, but has "
+ << bindingCounts;
+ }
+ for (auto [expectedType, actualType] :
+ llvm::zip_equal(uniformTypes, uniformEntryPointTypes)) {
+ if (expectedType != actualType) {
+ return emitOpError("uniform dispatch argument type mismatch: expected ")
+ << expectedType << " but got " << actualType;
+ }
+ }
}
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
index 87abff0..985e3de 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
@@ -67,3 +67,34 @@
}
}
}
+
+// -----
+
+stream.executable private @executable {
+ stream.executable.export public @dispatch
+ builtin.module {
+ func.func @dispatch(%arg0: !stream.binding, %arg1: index) {
+ %c0 = arith.constant 0 : index
+ %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<?x5x64xf32>>{%arg1}
+ return
+ }
+ }
+}
+
+func.func @cmdDispatchExecutableSignatureMismatch(%arg0: !stream.resource<transient>,
+ %arg1: index,
+ %arg2: !stream.resource<external>,
+ %arg3: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c128 = arith.constant 128 : index
+ %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) {
+ // expected-error @+1 {{function type mismatch; expected 2 binding arguments on exported function, but has 1}}
+ stream.cmd.dispatch {@executable::@dispatch}[%c1](%c2 : index) {
+ ro %arg4[%c0 for %c128] : !stream.resource<transient>{%arg1},
+ wo %arg5[%c0 for %c128] : !stream.resource<external>{%arg3}
+ }
+ } => !stream.timepoint
+ return %0 : !stream.timepoint
+}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir
index c43d679..feb6913 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir
@@ -36,8 +36,8 @@
// CHECK-SAME: %arg0: i32,
// CHECK-SAME: %arg1: index {stream.alignment = 4 : index, stream.values = [20 : index, 40 : index]},
// CHECK-SAME: %arg2: i1 {stream.values = [false, true]},
- // CHECK-SAME: %arg3: f32)
- func.func @dispatch(%arg0: i32, %arg1: index, %arg2: i1, %arg3: f32) {
+ // CHECK-SAME: %arg3: f32
+ func.func @dispatch(%arg0: i32, %arg1: index, %arg2: i1, %arg3: f32, %binding: !stream.binding) {
return
}
}
@@ -83,7 +83,7 @@
// CHECK-SAME: %arg2: index {stream.values = [4096 : index, 4097 : index]},
// CHECK-SAME: %arg3: index {stream.alignment = 16 : index, stream.values = [1200 : index, 5232 : index]}
// CHECK-SAME: %arg4: index {stream.alignment = 1024 : index, stream.values = [1024 : index, 2048 : index]}
- func.func @dispatch(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+ func.func @dispatch(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %binding: !stream.binding) {
return
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir
index 5cb2c3b..4f053e9 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir
@@ -231,36 +231,41 @@
// -----
-stream.executable private @ex5 {
- // CHECK-LABEL: @device_complex_f32_bitcast
- stream.executable.export public @device_complex_f32_bitcast
+stream.executable private @ex6 {
+ // CHECK-LABEL: @device_complex_f64_bitcast
+ stream.executable.export public @device_complex_f64_bitcast
builtin.module {
- // CHECK-LABEL: func.func @device_complex_f32
- // CHECK-SAME: (%[[DEV_REAL_I32:.+]]: i32, %[[DEV_IMAG_I32:.+]]: i32, %arg2: !stream.binding)
- func.func @device_complex_f32_bitcast(%arg0: complex<f32>, %arg1: !stream.binding) {
- // CHECK-DAG: %[[DEV_REAL_F32:.+]] = arith.bitcast %[[DEV_REAL_I32]] : i32 to f32
- // CHECK-DAG: %[[DEV_IMAG_F32:.+]] = arith.bitcast %[[DEV_IMAG_I32]] : i32 to f32
- // CHECK-DAG: %[[DEV_COMPLEX:.+]] = complex.create %[[DEV_REAL_F32]], %[[DEV_IMAG_F32]]
+ // CHECK-LABEL: func.func @device_complex_f64
+ // CHECK-SAME: (%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %arg4: !stream.binding)
+ func.func @device_complex_f64_bitcast(%arg0: complex<f64>, %arg1: !stream.binding) {
+ // CHECK-COUNT-2: arith.bitcast {{.*}} : i64 to f64
+ // CHECK: %[[DEV_COMPLEX:.+]] = complex.create
// CHECK-NEXT: util.optimization_barrier %[[DEV_COMPLEX]]
- util.optimization_barrier %arg0 : complex<f32>
+ util.optimization_barrier %arg0 : complex<f64>
return
}
}
}
// CHECK-LABEL: func.func @host_complex_bitcast
-func.func @host_complex_bitcast(%arg0: complex<f32>) -> !stream.timepoint {
+func.func @host_complex_bitcast(%arg0: complex<f64>) -> !stream.timepoint {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%0 = stream.resource.alloc uninitialized : !stream.resource<external>{%c128}
// CHECK-DAG: %[[HOST_REAL_F32:.+]] = complex.re %arg0
// CHECK-DAG: %[[HOST_IMAG_F32:.+]] = complex.im %arg0
- // CHECK-DAG: %[[HOST_REAL_I32:.+]] = arith.bitcast %[[HOST_REAL_F32]] : f32 to i32
- // CHECK-DAG: %[[HOST_IMAG_I32:.+]] = arith.bitcast %[[HOST_IMAG_F32]] : f32 to i32
- %1 = complex.bitcast %arg0 : complex<f32> to i64
+ // CHECK-DAG: %[[HOST_REAL_I32:.+]] = arith.bitcast %[[HOST_REAL_F32]] : f64 to i64
+ // CHECK-DAG: %[[H_REAL_LOWER_I32:.+]] = arith.trunci %[[HOST_REAL_I32]] : i64 to i32
+ // CHECK-DAG: %[[H_REAL_UPPER_SHF:.+]] = arith.shrui %[[HOST_REAL_I32]], {{.*}} : i64
+ // CHECK-DAG: %[[H_REAL_UPPER_I32:.+]] = arith.trunci %[[H_REAL_UPPER_SHF]] : i64 to i32
+ // CHECK-DAG: %[[HOST_IMAG_I32:.+]] = arith.bitcast %[[HOST_IMAG_F32]] : f64 to i64
+ // CHECK-DAG: %[[H_IMAG_LOWER_I32:.+]] = arith.trunci %[[HOST_IMAG_I32]] : i64 to i32
+ // CHECK-DAG: %[[H_IMAG_UPPER_SHF:.+]] = arith.shrui %[[HOST_IMAG_I32]], {{.*}} : i64
+ // CHECK-DAG: %[[H_IMAG_UPPER_I32:.+]] = arith.trunci %[[H_IMAG_UPPER_SHF]] : i64 to i32
%2 = stream.cmd.execute with(%0 as %arg1: !stream.resource<external>{%c128}) {
- // CHECK: stream.cmd.dispatch {{.+}}(%[[HOST_REAL_I32]], %[[HOST_IMAG_I32]] : i32, i32)
- stream.cmd.dispatch @ex5::@device_complex_f32_bitcast[%c1, %c1, %c1](%1 : i64) {
+ // CHECK: stream.cmd.dispatch {{.+}}(%[[H_REAL_LOWER_I32]], %[[H_REAL_UPPER_I32]],
+ // CHECK-SAME: %[[H_IMAG_LOWER_I32]], %[[H_IMAG_UPPER_I32]] : i32, i32, i32, i32)
+ stream.cmd.dispatch @ex6::@device_complex_f64_bitcast[%c1, %c1, %c1](%arg0 : complex<f64>) {
wo %arg1[%c0 for %c128] : !stream.resource<external>{%c128}
}
} => !stream.timepoint