[Stream] Improve the function signature verifier for CmdDispatchOp (#15886)

This adds additional verification to the exported function within
a dispatch that the binding counts and uniform operand types are
consistent with the dispatch op.
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 0d006b1..ff321ba 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -2858,7 +2858,45 @@
       return failure();
     }
 
-    // TODO(benvanik): verify that the target function has matching operands.
+    // Verify that the exported function, if present, matches the signature of
+    // the dispatch.
+    auto funcOp = exportOp.lookupFunctionRef();
+    if (!funcOp) {
+      return success();
+    }
+
+    TypeRange uniformTypes = getUniformOperands().getTypes();
+    int64_t numResources = getResources().size();
+
+    auto entryPointType = funcOp.getFunctionType();
+    SmallVector<Type> uniformEntryPointTypes;
+    int64_t bindingCounts = 0;
+    for (auto entryPointArg : entryPointType.getInputs()) {
+      if (isa<IREE::Stream::BindingType>(entryPointArg)) {
+        bindingCounts++;
+      } else {
+        uniformEntryPointTypes.push_back(entryPointArg);
+      }
+    }
+    if (uniformTypes.size() != uniformEntryPointTypes.size()) {
+      return emitOpError("function type mismatch; expected ")
+             << uniformTypes.size()
+             << " uniform arguments on exported function, but has "
+             << uniformEntryPointTypes.size();
+    }
+    if (numResources != bindingCounts) {
+      return emitOpError("function type mismatch; expected ")
+             << numResources
+             << " binding arguments on exported function, but has "
+             << bindingCounts;
+    }
+    for (auto [expectedType, actualType] :
+         llvm::zip_equal(uniformTypes, uniformEntryPointTypes)) {
+      if (expectedType != actualType) {
+        return emitOpError("uniform dispatch argument type mismatch: expected ")
+               << expectedType << " but got " << actualType;
+      }
+    }
   }
   return success();
 }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
index 87abff0..985e3de 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/executable_ops.mlir
@@ -67,3 +67,34 @@
     }
   }
 }
+
+// -----
+
+stream.executable private @executable {
+  stream.executable.export public @dispatch
+  builtin.module {
+    func.func @dispatch(%arg0: !stream.binding, %arg1: index) {
+      %c0 = arith.constant 0 : index
+      %0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !flow.dispatch.tensor<readwrite:tensor<?x5x64xf32>>{%arg1}
+      return
+    }
+  }
+}
+
+func.func @cmdDispatchExecutableSignatureMismatch(%arg0: !stream.resource<transient>,
+                                                  %arg1: index,
+                                                  %arg2: !stream.resource<external>,
+                                                  %arg3: index) -> !stream.timepoint {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c128 = arith.constant 128 : index
+  %0 = stream.cmd.execute with(%arg0 as %arg4: !stream.resource<transient>{%arg1}, %arg2 as %arg5: !stream.resource<external>{%arg3}) {
+    // expected-error @+1 {{function type mismatch; expected 2 binding arguments on exported function, but has 1}}
+    stream.cmd.dispatch {@executable::@dispatch}[%c1](%c2 : index) {
+      ro %arg4[%c0 for %c128] : !stream.resource<transient>{%arg1},
+      wo %arg5[%c0 for %c128] : !stream.resource<external>{%arg3}
+    }
+  } => !stream.timepoint
+  return %0 : !stream.timepoint
+}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir
index c43d679..feb6913 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_dispatch_arguments.mlir
@@ -36,8 +36,8 @@
     // CHECK-SAME: %arg0: i32,
     // CHECK-SAME: %arg1: index {stream.alignment = 4 : index, stream.values = [20 : index, 40 : index]},
     // CHECK-SAME: %arg2: i1 {stream.values = [false, true]},
-    // CHECK-SAME: %arg3: f32)
-    func.func @dispatch(%arg0: i32, %arg1: index, %arg2: i1, %arg3: f32) {
+    // CHECK-SAME: %arg3: f32
+    func.func @dispatch(%arg0: i32, %arg1: index, %arg2: i1, %arg3: f32, %binding: !stream.binding) {
       return
     }
   }
@@ -83,7 +83,7 @@
     // CHECK-SAME: %arg2: index {stream.values = [4096 : index, 4097 : index]},
     // CHECK-SAME: %arg3: index {stream.alignment = 16 : index, stream.values = [1200 : index, 5232 : index]}
     // CHECK-SAME: %arg4: index {stream.alignment = 1024 : index, stream.values = [1024 : index, 2048 : index]}
-    func.func @dispatch(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+    func.func @dispatch(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %binding: !stream.binding) {
       return
     }
   }
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir
index 5cb2c3b..4f053e9 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/pack_dispatch_operands.mlir
@@ -231,36 +231,41 @@
 
 // -----
 
-stream.executable private @ex5 {
-  // CHECK-LABEL: @device_complex_f32_bitcast
-  stream.executable.export public @device_complex_f32_bitcast
+stream.executable private @ex6 {
+  // CHECK-LABEL: @device_complex_f64_bitcast
+  stream.executable.export public @device_complex_f64_bitcast
   builtin.module {
-    // CHECK-LABEL: func.func @device_complex_f32
-    // CHECK-SAME: (%[[DEV_REAL_I32:.+]]: i32, %[[DEV_IMAG_I32:.+]]: i32, %arg2: !stream.binding)
-    func.func @device_complex_f32_bitcast(%arg0: complex<f32>, %arg1: !stream.binding) {
-      // CHECK-DAG: %[[DEV_REAL_F32:.+]] = arith.bitcast %[[DEV_REAL_I32]] : i32 to f32
-      // CHECK-DAG: %[[DEV_IMAG_F32:.+]] = arith.bitcast %[[DEV_IMAG_I32]] : i32 to f32
-      // CHECK-DAG: %[[DEV_COMPLEX:.+]] = complex.create %[[DEV_REAL_F32]], %[[DEV_IMAG_F32]]
+    // CHECK-LABEL: func.func @device_complex_f64
+    // CHECK-SAME: (%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32, %arg4: !stream.binding)
+    func.func @device_complex_f64_bitcast(%arg0: complex<f64>, %arg1: !stream.binding) {
+      // CHECK-COUNT-2: arith.bitcast {{.*}} : i64 to f64
+      // CHECK: %[[DEV_COMPLEX:.+]] = complex.create
       // CHECK-NEXT: util.optimization_barrier %[[DEV_COMPLEX]]
-      util.optimization_barrier %arg0 : complex<f32>
+      util.optimization_barrier %arg0 : complex<f64>
       return
     }
   }
 }
 // CHECK-LABEL: func.func @host_complex_bitcast
-func.func @host_complex_bitcast(%arg0: complex<f32>) -> !stream.timepoint {
+func.func @host_complex_bitcast(%arg0: complex<f64>) -> !stream.timepoint {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c128 = arith.constant 128 : index
   %0 = stream.resource.alloc uninitialized : !stream.resource<external>{%c128}
   // CHECK-DAG: %[[HOST_REAL_F32:.+]] = complex.re %arg0
   // CHECK-DAG: %[[HOST_IMAG_F32:.+]] = complex.im %arg0
-  // CHECK-DAG: %[[HOST_REAL_I32:.+]] = arith.bitcast %[[HOST_REAL_F32]] : f32 to i32
-  // CHECK-DAG: %[[HOST_IMAG_I32:.+]] = arith.bitcast %[[HOST_IMAG_F32]] : f32 to i32
-  %1 = complex.bitcast %arg0 : complex<f32> to i64
+  // CHECK-DAG: %[[HOST_REAL_I32:.+]] = arith.bitcast %[[HOST_REAL_F32]] : f64 to i64
+  // CHECK-DAG: %[[H_REAL_LOWER_I32:.+]] = arith.trunci %[[HOST_REAL_I32]] : i64 to i32
+  // CHECK-DAG: %[[H_REAL_UPPER_SHF:.+]] = arith.shrui %[[HOST_REAL_I32]], {{.*}} : i64
+  // CHECK-DAG: %[[H_REAL_UPPER_I32:.+]] = arith.trunci %[[H_REAL_UPPER_SHF]] : i64 to i32
+  // CHECK-DAG: %[[HOST_IMAG_I32:.+]] = arith.bitcast %[[HOST_IMAG_F32]] : f64 to i64
+  // CHECK-DAG: %[[H_IMAG_LOWER_I32:.+]] = arith.trunci %[[HOST_IMAG_I32]] : i64 to i32
+  // CHECK-DAG: %[[H_IMAG_UPPER_SHF:.+]] = arith.shrui %[[HOST_IMAG_I32]], {{.*}} : i64
+  // CHECK-DAG: %[[H_IMAG_UPPER_I32:.+]] = arith.trunci %[[H_IMAG_UPPER_SHF]] : i64 to i32
   %2 = stream.cmd.execute with(%0 as %arg1: !stream.resource<external>{%c128}) {
-    // CHECK: stream.cmd.dispatch {{.+}}(%[[HOST_REAL_I32]], %[[HOST_IMAG_I32]] : i32, i32)
-    stream.cmd.dispatch @ex5::@device_complex_f32_bitcast[%c1, %c1, %c1](%1 : i64) {
+    // CHECK: stream.cmd.dispatch {{.+}}(%[[H_REAL_LOWER_I32]], %[[H_REAL_UPPER_I32]],
+    // CHECK-SAME: %[[H_IMAG_LOWER_I32]], %[[H_IMAG_UPPER_I32]] : i32, i32, i32, i32)
+    stream.cmd.dispatch @ex6::@device_complex_f64_bitcast[%c1, %c1, %c1](%arg0 : complex<f64>) {
       wo %arg1[%c0 for %c128] : !stream.resource<external>{%c128}
     }
   } => !stream.timepoint