Remove use of iree.executable.workload attribute from SPIR-V translation.

With on-going work on dynamic shapes, use of attributes to specify the
workload is not feasible. Instead treat the shape of the outputs as
the launch size.  Also move helper functions to get launch size,
workgroup size, etc. into translation folder, deprecating
IREECodegenUtils.h/.cpp

PiperOrigin-RevId: 295999426
diff --git a/iree/compiler/Translation/CMakeLists.txt b/iree/compiler/Translation/CMakeLists.txt
index 9dc0dc7..e08c48c 100644
--- a/iree/compiler/Translation/CMakeLists.txt
+++ b/iree/compiler/Translation/CMakeLists.txt
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+add_subdirectory(CodegenUtils)
 add_subdirectory(Interpreter)
 add_subdirectory(SPIRV)
 add_subdirectory(XLAToLinalg)
diff --git a/iree/compiler/Translation/CodegenUtils/BUILD b/iree/compiler/Translation/CodegenUtils/BUILD
new file mode 100644
index 0000000..5c1dc39
--- /dev/null
+++ b/iree/compiler/Translation/CodegenUtils/BUILD
@@ -0,0 +1,38 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Utilities for working with IREE MLIR types.
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+cc_library(
+    name = "CodegenUtils",
+    srcs = [
+        "CodegenUtils.cpp",
+    ],
+    hdrs = [
+        "CodegenUtils.h",
+    ],
+    deps = [
+        "//iree/compiler/Dialect/Flow/Utils",
+        "//iree/compiler/Dialect/IREE/IR",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+    ],
+    alwayslink = 1,
+)
diff --git a/iree/compiler/Translation/CodegenUtils/CMakeLists.txt b/iree/compiler/Translation/CodegenUtils/CMakeLists.txt
new file mode 100644
index 0000000..125089a
--- /dev/null
+++ b/iree/compiler/Translation/CodegenUtils/CMakeLists.txt
@@ -0,0 +1,30 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+iree_cc_library(
+  NAME
+    CodegenUtils
+  HDRS
+    "CodegenUtils.h"
+  SRCS
+    "CodegenUtils.cpp"
+  DEPS
+    MLIRIR
+    MLIRStandardOps
+    MLIRSupport
+    iree::compiler::Dialect::Flow::Utils
+    iree::compiler::Dialect::IREE::IR
+  ALWAYSLINK
+  PUBLIC
+)
diff --git a/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp b/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp
new file mode 100644
index 0000000..2b92b93
--- /dev/null
+++ b/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp
@@ -0,0 +1,130 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
+
+#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector) {
+  if (vector.empty()) return vector;
+  auto numTrailingOnes = 0;
+  for (unsigned i = vector.size() - 1; i > 0; --i) {
+    if (vector[i] != 1) {
+      break;
+    }
+    numTrailingOnes++;
+  }
+  return vector.drop_back(numTrailingOnes);
+}
+
+bool isDispatchFunction(FuncOp funcOp) {
+  return funcOp.getAttr("iree.executable.export") != nullptr;
+}
+
+/// Helper function to check shapes are equal.
+static bool areShapesEqual(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) {
+  return lhs == rhs;
+}
+
+/// Get the shape to use for a type. For now this is returning shapes as static
+/// value.
+// TODO(ravishankarm) : Modify this to return the Values to use for the extent.
+static LogicalResult getExtentFromStoreOpSrc(Operation *storeOp,
+                                             SmallVectorImpl<int64_t> &extent) {
+  extent.clear();
+  extent.resize(3, 1);
+  Value srcVal = storeOp->getOperand(0);
+  if (srcVal.getType().isIntOrFloat()) return success();
+  auto workload = dyn_cast_or_null<ConstantOp>(
+      IREE::Flow::calculateWorkload(srcVal.getDefiningOp(), srcVal)
+          .getDefiningOp());
+  if (!workload) return failure();
+  auto extentVal = workload.getValue().cast<DenseIntElementsAttr>();
+  for (auto val : enumerate(extentVal.getValues<APInt>()))
+    extent[val.index()] = val.value().getSExtValue();
+  workload.erase();
+  return success();
+}
+
+LogicalResult getLaunchSize(FuncOp funcOp,
+                            SmallVectorImpl<int64_t> &launchSize) {
+  auto &body = funcOp.getBody();
+  if (!mlir::has_single_element(body)) {
+    return funcOp.emitError(
+        "unhandled multiple blocks within dispatch function");
+  }
+  SmallVector<Operation *, 1> storeOperations;
+  auto storeOps = body.front().getOps<IREE::StoreOutputOp>();
+  storeOperations.assign(storeOps.begin(), storeOps.end());
+  auto storeReduceOps = body.front().getOps<IREE::StoreReduceOp>();
+  storeOperations.append(storeReduceOps.begin(), storeReduceOps.end());
+  if (storeOperations.begin() == storeOperations.end()) {
+    return funcOp.emitError(
+        "expected dispatch function to have at least one iree.store_output "
+        "instruction");
+  }
+
+  Operation *firstStoreOp = *storeOperations.begin();
+  if (failed(getExtentFromStoreOpSrc(firstStoreOp, launchSize))) {
+    return firstStoreOp->emitError("unhandled type of the output tensor");
+  }
+  for (auto it = std::next(storeOperations.begin()), ie = storeOperations.end();
+       it != ie; ++it) {
+    SmallVector<int64_t, 3> checkShape;
+    Operation *storeOp = *it;
+    if (failed(getExtentFromStoreOpSrc(storeOp, checkShape))) {
+      return storeOp->emitError("unhandled type of the output tensor");
+    }
+    if (!areShapesEqual(launchSize, checkShape)) {
+      return storeOp->emitError("mismatch in shapes of the output tensors");
+    }
+  }
+  return success();
+}
+
+/// Gets the workgroup size.
+template <typename intType>
+LogicalResult getWorkGroupSize(FuncOp funcOp,
+                               SmallVectorImpl<intType> &workGroupSize) {
+  if (!funcOp.getAttr("iree.executable.export")) {
+    return funcOp.emitError(
+        "expected operation to be in dispatch function to get launch size");
+  }
+  auto workGroupSizeAttr =
+      funcOp.getAttrOfType<DenseElementsAttr>("iree.executable.workgroup_size");
+  if (!workGroupSizeAttr) {
+    return funcOp.emitError(
+        "unable to find workload size, missing attribute "
+        "iree.executable.workload in dispatch function");
+  }
+  workGroupSize.clear();
+  for (auto value : workGroupSizeAttr.getValues<APInt>()) {
+    workGroupSize.push_back(value.getSExtValue());
+  }
+  return success();
+}
+
+template LogicalResult getWorkGroupSize<int32_t>(
+    FuncOp funcOp, SmallVectorImpl<int32_t> &workGroupSize);
+template LogicalResult getWorkGroupSize<int64_t>(
+    FuncOp funcOp, SmallVectorImpl<int64_t> &workGroupSize);
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Translation/CodegenUtils/CodegenUtils.h b/iree/compiler/Translation/CodegenUtils/CodegenUtils.h
new file mode 100644
index 0000000..bc5be2a
--- /dev/null
+++ b/iree/compiler/Translation/CodegenUtils/CodegenUtils.h
@@ -0,0 +1,43 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_COMPILER_TRANSLATION_CODEGENUTILS_CODEGENUTILS_H
+#define IREE_COMPILER_TRANSLATION_CODEGENUTILS_CODEGENUTILS_H
+
+#include "mlir/IR/Function.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Drop trailing ones.
+ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector);
+
+/// Checks that a given function is a dispatch function.
+bool isDispatchFunction(FuncOp funcOp);
+
+/// The launch size is the size of the outputs of the kernel. For now all
+/// outputs have to be the same shape and static shaped.
+LogicalResult getLaunchSize(FuncOp funcOp,
+                            SmallVectorImpl<int64_t> &launchSize);
+
+/// Gets the workgroup size. Has to be a static constant.
+template <typename intType>
+LogicalResult getWorkGroupSize(FuncOp funcOp,
+                               SmallVectorImpl<intType> &workGroupSize);
+
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_COMPILER_TRANSLATION_CODEGENUTILS_CODEGENUTILS_H
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/BUILD b/iree/compiler/Translation/SPIRV/IndexComputation/BUILD
index 25a9d61..da35230 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/BUILD
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/BUILD
@@ -60,8 +60,8 @@
     ],
     deps = [
         "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Translation/CodegenUtils",
         "//iree/compiler/Translation/SPIRV/IndexComputation:IndexComputationAttrGen",
-        "//iree/compiler/Utils",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/CMakeLists.txt b/iree/compiler/Translation/SPIRV/IndexComputation/CMakeLists.txt
index a94c898..760a1d2 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/CMakeLists.txt
@@ -48,8 +48,8 @@
     MLIRSPIRV
     MLIRStandardOps
     iree::compiler::Dialect::IREE::IR
+    iree::compiler::Translation::CodegenUtils
     iree::compiler::Translation::SPIRV::IndexComputation::IndexComputationAttrGen
-    iree::compiler::Utils
     tensorflow::mlir_xla
   ALWAYSLINK
   PUBLIC
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.cpp b/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.cpp
index 9199416..3d7b72b 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.cpp
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.cpp
@@ -49,7 +49,7 @@
   }
 
   SmallVector<int64_t, 3> launchSize;
-  if (failed(getLegacyLaunchSize(funcOp, launchSize))) {
+  if (failed(getLaunchSize(funcOp, launchSize))) {
     return failure();
   }
 
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.h b/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.h
index 3a1ef6d..0919ea6 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.h
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/IREEIndexComputation.h
@@ -20,8 +20,9 @@
 #ifndef IREE_COMPILER_TRANSLATION_SPIRV_INDEXCOMPUTATION_IREEINDEXCOMP_H
 #define IREE_COMPILER_TRANSLATION_SPIRV_INDEXCOMPUTATION_IREEINDEXCOMP_H
 
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
 #include "iree/compiler/Translation/SPIRV/IndexComputation/XLAIndexPropagation.h"
-#include "iree/compiler/Utils/IREECodegenUtils.h"
 #include "mlir/IR/Function.h"
 
 namespace mlir {
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputation.h b/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputation.h
index 236d84b..dbbcf66 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputation.h
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputation.h
@@ -21,8 +21,8 @@
 #ifndef IREE_COMPILER_TRANSLATION_SPIRV_INDEXCOMPUTATION_INDEXCOMPUTATION_H
 #define IREE_COMPILER_TRANSLATION_SPIRV_INDEXCOMPUTATION_INDEXCOMPUTATION_H
 
+#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
 #include "iree/compiler/Translation/SPIRV/IndexComputation/IndexComputationAttribute.h"
-#include "iree/compiler/Utils/IREECodegenUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
@@ -298,7 +298,7 @@
     // Set the attribute for the number of launch dims.
     auto funcOp = region.getParentOfType<FuncOp>();
     SmallVector<int64_t, 3> launchSize;
-    if (failed(getLegacyLaunchSize(funcOp, launchSize))) {
+    if (failed(getLaunchSize(funcOp, launchSize))) {
       return emitError(region.getLoc(),
                        "expected region of index propagation to be in dispatch "
                        "function to get launch size");
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast.mlir
index 3017f54..5d26e92 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast.mlir
@@ -9,7 +9,7 @@
   // CHECK-SAME: result_index
   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
   func @broadcast_2D_3D(%arg0: memref<12x42xi32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     // CHECK: xla_hlo.broadcast
     // CHECK-SAME: iree.index_computation_info
@@ -34,7 +34,7 @@
   // CHECK-SAME: result_index
   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0)>]
   func @broadcast_scalar_3D(%arg0: memref<i32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<i32>) : tensor<i32>
     // CHECK: xla_hlo.broadcast
     // CHECK-SAME: iree.index_computation_info
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast_in_dim.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast_in_dim.mlir
index 0d21047..a6b01c2 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast_in_dim.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/broadcast_in_dim.mlir
@@ -9,7 +9,7 @@
   // CHECK-SAME: result_index
   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
   func @broadcast_in_dim_2D_3D(%arg0: memref<12x42xi32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     // CHECK: xla_hlo.broadcast_in_dim
     // CHECK-SAME: iree.index_computation_info
@@ -34,7 +34,7 @@
   // CHECK-SAME: result_index
   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0)>]
   func @broadcast_in_dim_scalar_3D(%arg0: memref<i32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<i32>) : tensor<i32>
     // CHECK: xla_hlo.broadcast_in_dim
     // CHECK-SAME: iree.index_computation_info
@@ -52,20 +52,20 @@
 
 module {
   func @const_float_splat(%arg0: memref<12x42xf32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: constant
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
     // CHECK-SAME: []
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1)>]
     %0 = constant dense<1.0> : tensor<12xf32>
     // CHECK: xla_hlo.broadcast_in_dim
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %1 = "xla_hlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<12xf32>) -> tensor<12x42xf32>
     iree.store_output(%1 : tensor<12x42xf32>, %arg0 : memref<12x42xf32>)
     iree.return
@@ -76,20 +76,20 @@
 
 module {
   func @const_int_splat(%arg0: memref<12x42xi32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: constant
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
     // CHECK-SAME: []
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1)>]
     %0 = constant dense<42> : tensor<12xi32>
     // CHECK: xla_hlo.broadcast_in_dim
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %1 = "xla_hlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<12xi32>) -> tensor<12x42xi32>
     iree.store_output(%1 : tensor<12x42xi32>, %arg0 : memref<12x42xi32>)
     iree.return
@@ -100,7 +100,7 @@
 
 module {
   func @const_int_nonsplat(%arg0: memref<2x12x42xi32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 2]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: constant
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
@@ -124,7 +124,7 @@
 
 module {
   func @zero_element_1dtensor(%arg0 : memref<f32>, %arg1 : memref<4xf32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<f32>) : tensor<f32>
     // CHECK: xla_hlo.broadcast_in_dim
     // CHECK-SAME: iree.index_computation_info
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/concatenate.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/concatenate.mlir
index 0dd53db..7d14b70 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/concatenate.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/concatenate.mlir
@@ -7,24 +7,24 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (0, d1)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0, d0)>]
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: memref<1x10xf32>
   // CHECK-SAME: iree.index_computation_info
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (0, d1 - 64)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0, d0 - 64)>]
   func @concatenate(%arg0: memref<1x64xf32>, %arg1 : memref<1x10xf32>, %arg2 : memref<1x74xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[1, 74]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<1x64xf32>) : tensor<1x64xf32>
     %1 = iree.load_input(%arg1 : memref<1x10xf32>) : tensor<1x10xf32>
     // CHECK: xla_hlo.concatenate
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (0, d1)>]
-    // CHECK-SAME: [affine_map<(d0, d1) -> (0, d1 - 64)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0, d0 - 64)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (0, d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0, d0)>]
     %2 = "xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<1x64xf32>, tensor<1x10xf32>) -> tensor<1x74xf32>
     iree.store_output(%2 : tensor<1x74xf32>, %arg2 : memref<1x74xf32>)
     iree.return
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/copy.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/copy.mlir
index 93bb5f6..961c0da 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/copy.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/copy.mlir
@@ -7,22 +7,22 @@
    // CHECK-SAME: operand_indices
    // CHECK-SAME: []
    // CHECK-SAME: result_index
-   // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
   func @simple_load_store(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: iree.load_input
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     // CHECK: xla_hlo.copy
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %1 = "xla_hlo.copy"(%0) : (tensor<12x42xi32>) -> tensor<12x42xi32>
     iree.store_output(%1 : tensor<12x42xi32>, %arg1 : memref<12x42xi32>)
     iree.return
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/gather.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/gather.mlir
index 66860da..98269d5 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/gather.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/gather.mlir
@@ -17,7 +17,7 @@
   // CHECK-SAME: iree.symbol_number_info
   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0)>, 0 : i32]
   func @foo(%arg0: memref<5x1x10xf32>, %arg1: memref<i64>, %arg2: memref<1x10xf32>)
-  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[10, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<5x1x10xf32>) : tensor<5x1x10xf32>
     %1 = iree.load_input(%arg1 : memref<i64>) : tensor<i64>
     %2 = "xla_hlo.gather"(%0, %1) {
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/pad.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/pad.mlir
index fde7bdf..248575e 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/pad.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/pad.mlir
@@ -7,18 +7,18 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d1 - 4, d0 - 5)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 - 4, d0 - 5)>]
   func @pad_zero_interior(%arg0 : memref<12x4xf32>, %arg1 : memref<18x12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12, 18, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x4xf32>) : tensor<12x4xf32>
     %1 = constant dense<0.0> : tensor<f32>
     // CHECK: xla_hlo.pad
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1 - 4, d0 - 5)>]
-    // CHECK-SAME: [affine_map<(d0, d1) -> (0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 - 4, d0 - 5)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %2 = "xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 3]> : tensor<2xi64>, edge_padding_low = dense<[4, 5]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<12x4xf32>, tensor<f32>) -> tensor<18x12xf32>
     iree.store_output(%2 : tensor<18x12xf32>, %arg1 : memref<18x12xf32>)
     iree.return
@@ -33,18 +33,18 @@
   // CHECK-SAME: iree.index_computation_info
   // CHECK-SAME: operand_indices = []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
   func @pad_no_op(%arg0 : memref<12x4xf32>, %arg1 : memref<12x4xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[4, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x4xf32>) : tensor<12x4xf32>
     %1 = constant dense<0.0> : tensor<f32>
     // CHECK: xla_hlo.pad
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
-    // CHECK-SAME: [affine_map<(d0, d1) -> (0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %2 = "xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[0, 0]> : tensor<2xi64>, edge_padding_low = dense<[0, 0]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<12x4xf32>, tensor<f32>) -> tensor<12x4xf32>
     iree.store_output(%2 : tensor<12x4xf32>, %arg1 : memref<12x4xf32>)
     iree.return
@@ -60,18 +60,18 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d1 floordiv 2 - 2, (d0 - 5) floordiv 3)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 floordiv 2 - 2, (d0 - 5) floordiv 3)>]
   func @pad_with_stride(%arg0 : memref<12x4xf32>, %arg1 : memref<29x18xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[18, 29, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x4xf32>) : tensor<12x4xf32>
     %1 = constant dense<0.0> : tensor<f32>
     // CHECK: xla_hlo.pad
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1 floordiv 2 - 2, (d0 - 5) floordiv 3)>]
-    // CHECK-SAME: [affine_map<(d0, d1) -> (0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 floordiv 2 - 2, (d0 - 5) floordiv 3)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %2 = "xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 3]> : tensor<2xi64>, edge_padding_low = dense<[4, 5]> : tensor<2xi64>, interior_padding = dense<[1, 2]> : tensor<2xi64>} : (tensor<12x4xf32>, tensor<f32>) -> tensor<29x18xf32>
     iree.store_output(%2 : tensor<29x18xf32>, %arg1 : memref<29x18xf32>)
     iree.return
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/reverse.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/reverse.mlir
index db34482..2c5cce7 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/reverse.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/reverse.mlir
@@ -7,16 +7,16 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (-d1 + 11, -d0 + 11)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (-d1 + 11, -d0 + 11)>]
   func @reverse_2d(%arg0: memref<12x12xf32>, %arg1 : memref<12x12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12, 12]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x12xf32>) : tensor<12x12xf32>
     // CHECK: xla_hlo.reverse
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (-d1 + 11, -d0 + 11)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (-d1 + 11, -d0 + 11)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %1 = "xla_hlo.reverse"(%0) {dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<12x12xf32>) -> tensor<12x12xf32>
     iree.store_output(%1 : tensor<12x12xf32>, %arg1 : memref<12x12xf32>)
     iree.return
@@ -34,7 +34,7 @@
   // CHECK-SAME: result_index
   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d2, -d1 + 2, d0)>]
   func @reverse_3d(%arg0: memref<3x3x3xf32>, %arg1 : memref<3x3x3xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[3, 3, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<3x3x3xf32>) : tensor<3x3x3xf32>
     // CHECK: xla_hlo.reverse
     // CHECK-SAME: iree.index_computation_info
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/slice.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/slice.mlir
index ea5c501..e5cd2d1 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/slice.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/slice.mlir
@@ -7,16 +7,16 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d0 floordiv 3 + 2, d0 mod 3 + 1)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 + 2, d0 + 1)>]
   func @slice_unit_stride(%arg0: memref<6x6xf32>, %arg1: memref<2x3xf32>)
-  attributes {iree.executable.export, iree.executable.workload = dense<[6, 1]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<6x6xf32>) : tensor<6x6xf32>
     // CHECK: xla_hlo.slice
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d0 floordiv 3 + 2, d0 mod 3 + 1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 + 2, d0 + 1)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d0 floordiv 3, d0 mod 3)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %1 = "xla_hlo.slice"(%0) {start_indices = dense<[2, 1]> : tensor<2xi64>, limit_indices = dense<[4, 4]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64>} : (tensor<6x6xf32>) -> tensor<2x3xf32>
     iree.store_output(%1 : tensor<2x3xf32>, %arg1 : memref<2x3xf32>)
     iree.return
@@ -32,16 +32,16 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d0 floordiv 3 + 2, (d0 mod 3) * 2 + 1)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 + 2,  d0 * 2 + 1)>]
   func @slice_non_unit_stride(%arg0: memref<6x6xf32>, %arg1: memref<2x3xf32>)
-  attributes {iree.executable.export, iree.executable.workload = dense<[6, 1]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<6x6xf32>) : tensor<6x6xf32>
     // CHECK: xla_hlo.slice
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d0 floordiv 3 + 2, (d0 mod 3) * 2 + 1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1 + 2, d0 * 2 + 1)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d0 floordiv 3, d0 mod 3)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %1 = "xla_hlo.slice"(%0) {start_indices = dense<[2, 1]> : tensor<2xi64>, limit_indices = dense<[4, 6]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<6x6xf32>) -> tensor<2x3xf32>
     iree.store_output(%1 : tensor<2x3xf32>, %arg1 : memref<2x3xf32>)
     iree.return
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/store_reduce.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/store_reduce.mlir
index d9c0501..cbe345c 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/store_reduce.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/store_reduce.mlir
@@ -14,7 +14,7 @@
   // CHECK-SAME: []
   // CHECK-SAME: result_index
   // CHECK-SAME: [affine_map<(d0, d1, d2) -> (0)>]
-  func @reduction_entry(%arg0: memref<5xi32>, %arg1: memref<i32>, %arg2: memref<i32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[5, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  func @reduction_entry(%arg0: memref<5xi32>, %arg1: memref<i32>, %arg2: memref<i32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<5xi32>)  : tensor<5xi32>
     iree.store_reduce(%0 : tensor<5xi32>, %arg2 : memref<i32>, @reduction_apply)
     iree.return
@@ -30,14 +30,14 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
   // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: memref<4xi32>
   // CHECK-SAME: iree.index_computation_info
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d0)>]
-  func @reduction_2D_dim0_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<4xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[4, 5, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d0)>]
+  func @reduction_2D_dim0_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<4xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<5x4xi32>)  : tensor<5x4xi32>
     iree.store_reduce(%0 : tensor<5x4xi32>, %arg2 : memref<4xi32>, @reduction_apply)
     iree.return
@@ -53,14 +53,14 @@
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
   // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: memref<5xi32>
   // CHECK-SAME: iree.index_computation_info
   // CHECK-SAME: operand_indices
   // CHECK-SAME: []
   // CHECK-SAME: result_index
-  // CHECK-SAME: [affine_map<(d0, d1) -> (d1)>]
-  func @reduction_2D_dim1_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<5xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 1 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[4, 5, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1)>]
+  func @reduction_2D_dim1_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<5xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 1 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<5x4xi32>)  : tensor<5x4xi32>
     iree.store_reduce(%0 : tensor<5x4xi32>, %arg2 : memref<5xi32>, @reduction_apply)
     iree.return
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/test/transpose_add.mlir b/iree/compiler/Translation/SPIRV/IndexComputation/test/transpose_add.mlir
index 60c2e25..853a474 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/test/transpose_add.mlir
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/test/transpose_add.mlir
@@ -7,38 +7,38 @@
  // CHECK-SAME: operand_indices
  // CHECK-SAME: []
  // CHECK-SAME: result_index
- // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+ // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
  // CHECK-SAME: operand_indices
  // CHECK-SAME: []
  // CHECK-SAME: result_index
- // CHECK-SAME: [affine_map<(d0, d1) -> (d0, d1)>]
+ // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d0, d1)>]
  func @transpose_add(%arg0: memref<12x12xf32>, %arg1: memref<12x12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: iree.load_input
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d0, d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d0, d1)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d0, d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d0, d1)>]
     %0 = iree.load_input(%arg0 : memref<12x12xf32>) : tensor<12x12xf32>
     // CHECK: xla_hlo.transpose
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d0, d1)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d0, d1)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %1 = "xla_hlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<12x12xf32>) -> tensor<12x12xf32>
     // CHECK: xla_hlo.add
     // CHECK-SAME: iree.index_computation_info
     // CHECK-SAME: operand_indices
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     // CHECK-SAME: result_index
-    // CHECK-SAME: [affine_map<(d0, d1) -> (d1, d0)>]
+    // CHECK-SAME: [affine_map<(d0, d1, d2) -> (d1, d0)>]
     %2 = xla_hlo.add %0, %1 : tensor<12x12xf32>
     iree.store_output(%2 : tensor<12x12xf32>, %arg1 : memref<12x12xf32>)
     iree.return
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
index 9677f61..cf03821 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/BUILD
@@ -29,8 +29,8 @@
     ],
     deps = [
         "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Translation/CodegenUtils",
         "//iree/compiler/Translation/XLAToLinalg:IREELinalgTensorToBuffer",
-        "//iree/compiler/Utils",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:AffineOps",
         "@llvm-project//mlir:AffineToStandardTransforms",
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
index a18b6fa..83f9713 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/CMakeLists.txt
@@ -43,8 +43,8 @@
     MLIRSupport
     MLIRTransforms
     iree::compiler::Dialect::IREE::IR
+    iree::compiler::Translation::CodegenUtils
     iree::compiler::Translation::XLAToLinalg::IREELinalgTensorToBuffer
-    iree::compiler::Utils
     tensorflow::mlir_xla
   ALWAYSLINK
   PUBLIC
diff --git a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
index dfe6090..128a9e9 100644
--- a/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
+++ b/iree/compiler/Translation/SPIRV/LinalgToSPIRV/LowerToSPIRV.cpp
@@ -18,9 +18,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
 #include "iree/compiler/Translation/SPIRV/LinalgToSPIRV/Passes.h"
 #include "iree/compiler/Translation/XLAToLinalg/LinalgTensorToBuffer.h"
-#include "iree/compiler/Utils/IREECodegenUtils.h"
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
 #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
@@ -49,18 +49,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-static ArrayRef<int64_t> dropTrailingOnes(ArrayRef<int64_t> vector) {
-  if (vector.empty()) return vector;
-  auto numTrailingOnes = 0;
-  for (unsigned i = vector.size() - 1; i > 0; --i) {
-    if (vector[i] != 1) {
-      break;
-    }
-    numTrailingOnes++;
-  }
-  return vector.drop_back(numTrailingOnes);
-}
-
 namespace {
 /// These options are only for testing purposes. For actual execution with IREE,
 /// these are computed by IREE/Backends automatically.
@@ -109,7 +97,7 @@
     FuncOp funcOp = getFunction();
     SmallVector<int64_t, 3> workGroupSizeVec;
     workGroupSizeVec.reserve(3);
-    if (failed(getLegacyWorkGroupSize(funcOp, workGroupSizeVec))) return;
+    if (failed(getWorkGroupSize(funcOp, workGroupSizeVec))) return;
     ArrayRef<int64_t> workGroupSize = dropTrailingOnes(workGroupSizeVec);
 
     OpBuilder builder(funcOp);
@@ -162,29 +150,16 @@
     FuncOp funcOp = getFunction();
     SmallVector<int64_t, 3> workGroupSizeVec;
     workGroupSizeVec.reserve(3);
-    if (failed(getLegacyWorkGroupSize(funcOp, workGroupSizeVec))) return;
+    if (failed(getWorkGroupSize(funcOp, workGroupSizeVec))) return;
     ArrayRef<int64_t> workGroupSize = dropTrailingOnes(workGroupSizeVec);
 
-    // While we can use any valid input for numWorkGroups, there might be
-    // canonicalizations that can be used if the workgroup size is passed
-    // accurately. For now compute the workgroup size based on the workload and
-    // workgroup size.
-    // TODO(ravishankarm): This assumes this is static for now. To handle
-    // dynamic cases, generate the IR that corresponds to the operations here.
-    SmallVector<int64_t, 3> workLoad;
-    workLoad.reserve(3);
-    if (failed(getLegacyLaunchSize(funcOp, workLoad))) {
-      funcOp.emitError("unable to retrieve workload size in dispatch function");
-      return signalPassFailure();
-    }
-    workLoad.resize(workGroupSize.size());
-
+    // For now just use number of workgroups to be [1, 1, 1]. The loop to GPU
+    // lowering doesnt use the value of number of workgroups in the codegen
+    // itself, but rather only uses this in the gpu.launch op which is
+    // irrelevant for IREE.
+    // TODO(ravishankarm): Fix the GPU lowering to allow not using gpu.launch at
+    // all.
     SmallVector<int64_t, 3> numWorkGroups(workGroupSize.size(), 1);
-    for (auto index : llvm::seq<unsigned>(0, workGroupSize.size())) {
-      numWorkGroups[index] = workLoad[index] / workGroupSize[index];
-      numWorkGroups[index] +=
-          static_cast<bool>(workLoad[index] % workGroupSize[index]);
-    }
 
     SmallVector<Value, 3> numWorkGroupsVal, workGroupSizeVal;
     numWorkGroupsVal.reserve(3);
@@ -210,7 +185,7 @@
     ModuleOp moduleOp = getModule();
     FuncOp funcOp = nullptr;
     auto walkResult = moduleOp.walk([&funcOp](FuncOp fOp) -> WalkResult {
-      if (fOp.getAttr("iree.executable.export")) {
+      if (isDispatchFunction(fOp)) {
         if (funcOp) return WalkResult::interrupt();
         funcOp = fOp;
       }
@@ -258,7 +233,7 @@
     SPIRVTypeConverter typeConverter;
     OwningRewritePatternList patterns;
     SmallVector<int32_t, 3> workGroupSize;
-    if (failed(getLegacyWorkGroupSize(funcOp, workGroupSize))) return;
+    if (failed(getWorkGroupSize(funcOp, workGroupSize))) return;
 
     // Set spv.entry_point_abi on each kernel functions to drive SPIR-V CodeGen.
     // This is required because SPIR-V CodeGen's contract.
@@ -292,7 +267,7 @@
       : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
   void runOnFunction() {
     FuncOp funcOp = getFunction();
-    if (!funcOp.getAttr("iree.executable.export")) return;
+    if (!isDispatchFunction(funcOp)) return;
 
     if (workGroupSize.empty()) {
       // By default look at the number of "parallel" loops in the generic op.
diff --git a/iree/compiler/Translation/SPIRV/ReductionCodegen/test/simple.mlir b/iree/compiler/Translation/SPIRV/ReductionCodegen/test/simple.mlir
index 985171d..903485a 100644
--- a/iree/compiler/Translation/SPIRV/ReductionCodegen/test/simple.mlir
+++ b/iree/compiler/Translation/SPIRV/ReductionCodegen/test/simple.mlir
@@ -6,7 +6,7 @@
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: memref<i32>
   // CHECK-SAME: [[ARG2:%[a-zA-Z0-9]*]]: memref<4xi32> {iree.executable.reduction.output}
   // CHECK-SAME: iree.executable.reduction.apply = [[APPLYFN:@[a-zA-Z0-9_]*]]
-  func @reduction_entry(memref<5x4xi32>, memref<i32>, memref<4xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[4, 5, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32}
+  func @reduction_entry(memref<5x4xi32>, memref<i32>, memref<4xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32}
   // CHECK: [[TENSOR:%.*]] = iree.load_input([[ARG0]] : memref<5x4xi32>)  : tensor<5x4xi32>
   // CHECK: iree.store_reduce([[TENSOR]] : tensor<5x4xi32>, [[ARG2]] : memref<4xi32>, [[APPLYFN]])
 
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/BUILD b/iree/compiler/Translation/SPIRV/XLAToSPIRV/BUILD
index b11906a..200b440 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/BUILD
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/BUILD
@@ -34,10 +34,10 @@
     ],
     deps = [
         "//iree/compiler/Dialect/IREE/IR",
+        "//iree/compiler/Translation/CodegenUtils",
         "//iree/compiler/Translation/SPIRV/IndexComputation",
         "//iree/compiler/Translation/SPIRV/Passes",
         "//iree/compiler/Translation/SPIRV/ReductionCodegen",
-        "//iree/compiler/Utils",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:AffineOps",
         "@llvm-project//mlir:AffineToStandardTransforms",
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/CMakeLists.txt b/iree/compiler/Translation/SPIRV/XLAToSPIRV/CMakeLists.txt
index 5199a5a..109844c 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/CMakeLists.txt
@@ -41,10 +41,10 @@
     MLIRSupport
     MLIRTransforms
     iree::compiler::Dialect::IREE::IR
+    iree::compiler::Translation::CodegenUtils
     iree::compiler::Translation::SPIRV::IndexComputation
     iree::compiler::Translation::SPIRV::Passes
     iree::compiler::Translation::SPIRV::ReductionCodegen
-    iree::compiler::Utils
     tensorflow::mlir_xla
   ALWAYSLINK
   PUBLIC
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp
index ca7bb66..692618e 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.cpp
@@ -139,7 +139,7 @@
   entryFn.addEntryBlock();
 
   SmallVector<int32_t, 3> workGroupSize;
-  if (failed(getLegacyWorkGroupSize(fn, workGroupSize))) {
+  if (failed(getWorkGroupSize(fn, workGroupSize))) {
     return failure();
   }
   auto entryFnAttr =
@@ -188,7 +188,7 @@
                                                   FuncOp fn) {
   // First check that the global invocation id is in bounds.
   SmallVector<int64_t, 3> launchSize;
-  if (failed(getLegacyLaunchSize(fn, launchSize))) {
+  if (failed(getLaunchSize(fn, launchSize))) {
     return failure();
   }
   auto loc = fn.getLoc();
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h
index 6a35a39..2d2ed7f 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/SPIRVLowering.h
@@ -20,9 +20,9 @@
 #ifndef IREE_COMPILER_TRANSLATION_SPIRV_XLATOSPIRV_SPIRVLOWERING_H
 #define IREE_COMPILER_TRANSLATION_SPIRV_XLATOSPIRV_SPIRVLOWERING_H
 
+#include "iree/compiler/Translation/CodegenUtils/CodegenUtils.h"
 #include "iree/compiler/Translation/SPIRV/IndexComputation/IndexComputationAttribute.h"
 #include "iree/compiler/Translation/SPIRV/XLAToSPIRV/TensorIndexToScalarValueMap.h"
-#include "iree/compiler/Utils/IREECodegenUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir
index 30dd16f..c02f4de 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/arithmetic_ops.mlir
@@ -6,7 +6,7 @@
 // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<4 x f32 [4]> [0]>, StorageBuffer>
 // CHECK-SAME: [[ARG2:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<4 x f32 [4]> [0]>, StorageBuffer>
 func @mul_1D(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   // CHECK: [[GLOBALIDPTR:%.*]] = spv._address_of [[GLOBALIDVAR]]
   // CHECK: [[GLOBALID:%.*]] = spv.Load "Input" [[GLOBALIDPTR]]
   // CHECK: [[GLOBALIDX:%.*]] = spv.CompositeExtract [[GLOBALID]][0 : i32]
@@ -31,7 +31,7 @@
 // -----
 
 func @frem(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FRem
@@ -43,7 +43,7 @@
 // -----
 
 func @srem(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi32>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.SRem
@@ -55,7 +55,7 @@
 // -----
 
 func @srem(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi32>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.SRem
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir
index d8f1979..1d7360a 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast.mlir
@@ -8,7 +8,7 @@
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
 
   func @broadcast_2D_3D(%arg0: memref<12x42xi32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
@@ -29,7 +29,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
   func @broadcast_scalar_3D(%arg0: memref<i32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<i32>) : tensor<i32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir
index 80378e9..00666f0 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/broadcast_in_dim.mlir
@@ -7,7 +7,7 @@
   // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
   func @broadcast_in_dim_2D_3D(%arg0: memref<12x42xi32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = "xla_hlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<12x42xi32>) -> tensor<3x12x42xi32>
     iree.store_output(%1 : tensor<3x12x42xi32>, %arg1 : memref<3x12x42xi32>)
@@ -24,7 +24,7 @@
   // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1512 x i32 [4]> [0]>, StorageBuffer>
   func @broadcast_in_dim_scalar_3D(%arg0: memref<i32>, %arg1: memref<3x12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<i32>) : tensor<i32>
@@ -40,7 +40,7 @@
 
 module {
   func @const_float_splat(%arg0: memref<12x42xf32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: spv.constant 1.000000e+00 : f32
     %0 = constant dense<1.0> : tensor<12xf32>
     %1 = "xla_hlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<12xf32>) -> tensor<12x42xf32>
@@ -53,7 +53,7 @@
 
 module {
   func @const_int_splat(%arg0: memref<12x42xi32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: spv.constant 42 : i32
     %0 = constant dense<42> : tensor<12xi32>
     %1 = "xla_hlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0]> : tensor<1xi64>} : (tensor<12xi32>) -> tensor<12x42xi32>
@@ -68,7 +68,7 @@
   // CHECK: spv.func @const_int_nonsplat
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<1008 x i32 [4]> [0]>, StorageBuffer>
   func @const_int_nonsplat(%arg0: memref<2x12x42xi32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 2]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[CST:%.*]] = spv.constant dense<[42, 21]>
     // CHECK: [[VAR:%.*]] = spv.Variable init([[CST]]) : !spv.ptr<!spv.array<2 x i32 [4]>, Function>
     // CHECK: [[LOADPTR:%.*]] = spv.AccessChain [[VAR]]
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/compare.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/compare.mlir
index 73f842c..8f64a4b 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/compare.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/compare.mlir
@@ -1,7 +1,7 @@
 // RUN: iree-opt -split-input-file -iree-index-computation -simplify-spirv-affine-exprs=false -convert-iree-to-spirv -verify-diagnostics -o - %s | IreeFileCheck %s
 
 func @ieq(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.IEqual
@@ -13,7 +13,7 @@
 // -----
 
 func @ineq(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.INotEqual
@@ -25,7 +25,7 @@
 // -----
 
 func @islt(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.SLessThan
@@ -37,7 +37,7 @@
 // -----
 
 func @isle(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.SLessThanEqual
@@ -49,7 +49,7 @@
 // -----
 
 func @isgt(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.SGreaterThan
@@ -61,7 +61,7 @@
 // -----
 
 func @isge(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi32>) : tensor<4xi32>
   %1 = iree.load_input(%arg1 : memref<4xi32>) : tensor<4xi32>
   // CHECK: spv.SGreaterThanEqual
@@ -73,7 +73,7 @@
 // -----
 
 func @oeq(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FOrdEqual
@@ -85,7 +85,7 @@
 // -----
 
 func @oge(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FOrdGreaterThanEqual
@@ -97,7 +97,7 @@
 // -----
 
 func @ogt(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FOrdGreaterThan
@@ -109,7 +109,7 @@
 // -----
 
 func @ole(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FOrdLessThanEqual
@@ -121,7 +121,7 @@
 // -----
 
 func @olt(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FOrdLessThan
@@ -133,7 +133,7 @@
 // -----
 
 func @ueq(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FUnordEqual
@@ -145,7 +145,7 @@
 // -----
 
 func @uge(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FUnordGreaterThanEqual
@@ -157,7 +157,7 @@
 // -----
 
 func @ugt(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FUnordGreaterThan
@@ -169,7 +169,7 @@
 // -----
 
 func @ule(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FUnordLessThanEqual
@@ -181,7 +181,7 @@
 // -----
 
 func @ult(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xf32>) : tensor<4xf32>
   %1 = iree.load_input(%arg1 : memref<4xf32>) : tensor<4xf32>
   // CHECK: spv.FUnordLessThan
@@ -193,7 +193,7 @@
 // -----
 
 func @beq(%arg0: memref<4xi1>, %arg1: memref<4xi1>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi1>) : tensor<4xi1>
   %1 = iree.load_input(%arg1 : memref<4xi1>) : tensor<4xi1>
   // CHECK: spv.LogicalEqual
@@ -205,7 +205,7 @@
 // -----
 
 func @bneq(%arg0: memref<4xi1>, %arg1: memref<4xi1>, %arg2: memref<4xi1>)
-attributes  {iree.executable.export, iree.executable.workload = dense<[4, 1, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
   %0 = iree.load_input(%arg0 : memref<4xi1>) : tensor<4xi1>
   %1 = iree.load_input(%arg1 : memref<4xi1>) : tensor<4xi1>
   // CHECK: spv.LogicalNotEqual
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir
index 79510a0..6870aa3 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/concatenate.mlir
@@ -6,17 +6,17 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<64 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
   func @concatenate(%arg0: memref<1x64xf32>, %arg1 : memref<1x10xf32>, %arg2 : memref<1x74xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[1, 74]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[GLOBALIDPTR:%.*]] = spv._address_of [[GLOBALIDVAR]]
     // CHECK: [[GLOBALID:%.*]] = spv.Load "Input" [[GLOBALIDPTR]]
-    // CHECK: [[GLOBALIDY:%.*]] = spv.CompositeExtract [[GLOBALID]][1 : i32]
+    // CHECK: [[GLOBALIDX:%.*]] = spv.CompositeExtract [[GLOBALID]][0 : i32]
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[INPUTVAL0:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
     // CHECK: [[ARG1LOADPTR:%.*]] = spv.AccessChain [[ARG1]]
     // CHECK: [[INPUTVAL1:%.*]] = spv.Load "StorageBuffer" [[ARG1LOADPTR]] : f32
     // CHECK: [[TRUE:%.*]] = spv.constant true
     // CHECK: [[SIXTY_FOUR:%.*]] = spv.constant 64 : i32
-    // CHECK: [[CHECK:%.*]] = spv.SGreaterThanEqual [[GLOBALIDY]], [[SIXTY_FOUR]] : i32
+    // CHECK: [[CHECK:%.*]] = spv.SGreaterThanEqual [[GLOBALIDX]], [[SIXTY_FOUR]] : i32
     // CHECK: [[COND:%.*]] = spv.LogicalAnd [[TRUE]], [[CHECK]] : i1
     // CHECK: [[RESULT:%.*]] = spv.Select [[COND]], [[INPUTVAL1]], [[INPUTVAL0]] : i1, f32
     %0 = iree.load_input(%arg0 : memref<1x64xf32>) : tensor<1x64xf32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/constant.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/constant.mlir
index a94c6ce..9328f00 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/constant.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/constant.mlir
@@ -2,7 +2,7 @@
 
 module {
   func @const_f32(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[3, 2]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[CONST:%.*]] = spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]>
     // CHECK: [[VAR:%.*]] = spv.Variable init([[CONST]])
     // CHECK: [[NUMPTR:%.*]] = spv.AccessChain [[VAR]]
@@ -19,7 +19,7 @@
 
 module {
   func @splat_const_f32(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[3, 2]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: spv.constant 1.000000e+00 : f32
     %0 = iree.load_input(%arg0 : memref<2x3xf32>) : tensor<2x3xf32>
     %1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
@@ -33,7 +33,7 @@
 
 module {
   func @const_i32(%arg0: memref<2x3xi32>, %arg1: memref<2x3xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[3, 2]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[CONST:%.*]] = spv.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
     // CHECK: [[VAR:%.*]] = spv.Variable init([[CONST]])
     // CHECK: [[NUMPTR:%.*]] = spv.AccessChain [[VAR]]
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/convert.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/convert.mlir
index b7c12bb..9969849 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/convert.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/convert.mlir
@@ -2,7 +2,7 @@
 
 module {
   func @convert_f2f_nop(%arg0: memref<12xf32>, %arg1 : memref<12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12]> : tensor<1xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12xf32>) : tensor<12xf32>
     // CHECK-NOT: spv.FConvert
     %1 = "xla_hlo.convert"(%0) : (tensor<12xf32>) -> tensor<12xf32>
@@ -15,7 +15,7 @@
 
 module {
   func @convert_f2f(%arg0: memref<12xf32>, %arg1 : memref<12xf16>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12]> : tensor<1xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12xf32>) : tensor<12xf32>
     // CHECK: spv.FConvert {{%.*}} f32 to f16
     %1 = "xla_hlo.convert"(%0) : (tensor<12xf32>) -> tensor<12xf16>
@@ -28,7 +28,7 @@
 
 module {
   func @convert_i2i_nop(%arg0: memref<12xi32>, %arg1 : memref<12xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12]> : tensor<1xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12xi32>) : tensor<12xi32>
     // CHECK-NOT: spv.SConvert
     %1 = "xla_hlo.convert"(%0) : (tensor<12xi32>) -> tensor<12xi32>
@@ -41,7 +41,7 @@
 
 module {
   func @convert_i2i(%arg0: memref<12xi32>, %arg1 : memref<12xi16>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12]> : tensor<1xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12xi32>) : tensor<12xi32>
     // CHECK: spv.SConvert {{%.*}} i32 to i16
     %1 = "xla_hlo.convert"(%0) : (tensor<12xi32>) -> tensor<12xi16>
@@ -54,7 +54,7 @@
 
 module {
   func @convert_i2f(%arg0: memref<12xi32>, %arg1 : memref<12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12]> : tensor<1xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12xi32>) : tensor<12xi32>
     // CHECK: spv.ConvertSToF
     %1 = "xla_hlo.convert"(%0) : (tensor<12xi32>) -> tensor<12xf32>
@@ -67,7 +67,7 @@
 
 module {
   func @convert_f2i_nop(%arg0: memref<12xf32>, %arg1 : memref<12xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12]> : tensor<1xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12xf32>) : tensor<12xf32>
     // CHECK: spv.ConvertFToS
     %1 = "xla_hlo.convert"(%0) : (tensor<12xf32>) -> tensor<12xi32>
@@ -80,7 +80,7 @@
 
 module {
   func @convert_b2i(%arg0: memref<12xi1>, %arg1 : memref<12xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12]> : tensor<1xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[VAL0:%.*]] = spv.Load "StorageBuffer" %{{.*}} : i1
     // CHECK: [[ZERO:%.*]] = spv.constant 0 : i32
     // CHECK: [[ONE:%.*]] = spv.constant 1 : i32
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir
index 5012f3e..425ca20 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/copy.mlir
@@ -8,7 +8,7 @@
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: {{spirv|spv}}.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}
   func @simple_load_store(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[GLOBALIDPTR:%.*]] = spv._address_of [[GLOBALIDVAR]]
     // CHECK: [[GLOBALID:%.*]] = spv.Load "Input" [[GLOBALIDPTR]]
     // CHECK: [[GLOBALIDX:%.*]] = spv.CompositeExtract [[GLOBALID]][0 : i32]
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/exp_test.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/exp_test.mlir
index 8c32f38..ea51f14 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/exp_test.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/exp_test.mlir
@@ -2,7 +2,7 @@
 
 module {
   func @scalar_rgn_dispatch_0(%arg0: memref<f32>)
-    attributes  {iree.executable.export, iree.executable.workload = dense<1> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %cst = constant dense<1.000000e+00> : tensor<f32>
     //CHECK: {{%.*}} = spv.GLSL.Exp {{%.*}} : f32
     %0 = "xla_hlo.exp"(%cst) : (tensor<f32>) -> tensor<f32>
@@ -15,7 +15,7 @@
 
 module {
   func @exp(%arg0: memref<12x42xf32>, %arg2 : memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: {{%.*}} = spv.GLSL.Exp {{%.*}} : f32
     %2 = "xla_hlo.exp"(%0) : (tensor<12x42xf32>) -> tensor<12x42xf32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir
index 23f0ec9..8ca5a78 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/extract_element.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<i1 [0]>, StorageBuffer>
   func @extract_element(%arg0: memref<i1>, %arg1: memref<i1>)
-    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<1> : tensor<3xi32>, iree.num_dims = 3 : i32, iree.ordinal = 0 : i32} {
+    attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.num_dims = 3 : i32, iree.ordinal = 0 : i32} {
     %0 = "iree.load_input"(%arg0) : (memref<i1>) -> tensor<i1>
     // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
     // CHECK: {{%.*}} = spv.AccessChain [[ARG0]]{{\[}}[[ZERO1]]{{\]}}
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir
index 528c852..90ccb00 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/gather.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<50 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<i64 [0]>, StorageBuffer>
   func @foo(%arg0: memref<5x1x10xf32>, %arg1: memref<i64>, %arg2: memref<1x10xf32>)
-  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[10, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     // CHECK: [[ZERO1:%.*]] = spv.constant 0
     // CHECK: [[LOAD_ADDRESS_ARG1:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[ZERO1]]{{\]}}
     // CHECK: [[INDEXI64:%.*]] = spv.Load {{".*"}} [[LOAD_ADDRESS_ARG1]]
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/max.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/max.mlir
index cd9c8b2..f715e2d 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/max.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/max.mlir
@@ -2,7 +2,7 @@
 
 module {
   func @maxf(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2 : memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.GLSL.FMax [[VAL1:%.*]], [[VAL2:%.*]] : f32
@@ -16,7 +16,7 @@
 
 module {
   func @maxi(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>, %arg2 : memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = iree.load_input(%arg1 : memref<12x42xi32>) : tensor<12x42xi32>
     //CHECK: [[COMPARE:%.*]] = spv.GLSL.SMax [[VAL1:%.*]], [[VAL2:%.*]] : i32
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir
index c8c9909..ab02c05 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/pad.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<216 x f32 [4]> [0]>, StorageBuffer>
   func @pad_zero_interior(%arg0 : memref<12x4xf32>, %arg1 : memref<18x12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12, 18, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[INPUTVAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
     %0 = iree.load_input(%arg0 : memref<12x4xf32>) : tensor<12x4xf32>
@@ -30,7 +30,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
   func @pad_no_op(%arg0 : memref<12x4xf32>, %arg1 : memref<12x4xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[4, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
 
   // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[INPUTVAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
@@ -54,7 +54,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<522 x f32 [4]> [0]>, StorageBuffer>
   func @pad_zero_interior(%arg0 : memref<12x4xf32>, %arg1 : memref<29x18xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[18, 29, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[INPUTVAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
     %0 = iree.load_input(%arg0 : memref<12x4xf32>) : tensor<12x4xf32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir
index 666c153..d94987d 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_2D_2D(%arg0: memref<24x21xi32>, %arg1: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<24x21xi32>) : tensor<24x21xi32>
@@ -24,7 +24,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_3D_2D(%arg0: memref<4x6x21xi32>, %arg1: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<4x6x21xi32>) : tensor<4x6x21xi32>
@@ -43,7 +43,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_2D_3D(%arg0: memref<24x21xi32>, %arg1: memref<12x6x7xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<24x21xi32>) : tensor<24x21xi32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir
index 49f4935..a934a78 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reshape_dropdims.mlir
@@ -6,7 +6,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_4D_3D(%arg0: memref<12x42x1xi32>, %arg1: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<12x42x1xi32>) : tensor<12x42x1xi32>
@@ -26,7 +26,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<12x42x1x1xi32>) : tensor<12x42x1x1xi32>
@@ -46,7 +46,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1: memref<12x42x1x1xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
@@ -66,7 +66,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1: memref<12x1x1x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
@@ -86,7 +86,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<504 x i32 [4]> [0]>, StorageBuffer>
   func @reshape_2D_4D(%arg0: memref<12x1x1x42xi32>, %arg1: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     %0 = iree.load_input(%arg0 : memref<12x1x1x42xi32>) : tensor<12x1x1x42xi32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir
index d5013c4..4721e58 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/reverse.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
   func @reverse_2d(%arg0: memref<12x12xf32>, %arg1 : memref<12x12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12, 12]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]]  = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
     %0 = iree.load_input(%arg0 : memref<12x12xf32>) : tensor<12x12xf32>
@@ -24,7 +24,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<27 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<27 x f32 [4]> [0]>, StorageBuffer>
   func @reverse_3d(%arg0: memref<3x3x3xf32>, %arg1 : memref<3x3x3xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[3, 3, 3]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL:%.*]]  = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
     %0 = iree.load_input(%arg0 : memref<3x3x3xf32>) : tensor<3x3x3xf32>
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/select.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/select.mlir
index 8ac5038..cd612c9 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/select.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/select.mlir
@@ -2,7 +2,7 @@
 
 module {
   func @select_ford_ge(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FOrdGreaterThanEqual {{%.*}}, {{%.*}}
@@ -18,7 +18,7 @@
 
 module {
   func @select_ford_eq(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FOrdEqual {{%.*}}, {{%.*}}
@@ -33,7 +33,7 @@
 
 module {
   func @select_ford_gt(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FOrdGreaterThan {{%.*}}, {{%.*}}
@@ -48,7 +48,7 @@
 
 module {
   func @select_ford_lt(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FOrdLessThan {{%.*}}, {{%.*}}
@@ -63,7 +63,7 @@
 
 module {
   func @select_ford_le(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FOrdLessThanEqual {{%.*}}, {{%.*}}
@@ -78,7 +78,7 @@
 
 module {
   func @select_ford_ne(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FOrdNotEqual {{%.*}}, {{%.*}}
@@ -93,7 +93,7 @@
 
 module {
   func @select_funord_eq(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FUnordEqual {{%.*}}, {{%.*}}
@@ -108,7 +108,7 @@
 
 module {
   func @select_funord_ge(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FUnordGreaterThanEqual {{%.*}}, {{%.*}}
@@ -123,7 +123,7 @@
 
 module {
   func @select_funord_gt(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FUnordGreaterThan {{%.*}}, {{%.*}}
@@ -138,7 +138,7 @@
 
 module {
   func @select_funord_lt(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FUnordLessThan {{%.*}}, {{%.*}}
@@ -153,7 +153,7 @@
 
 module {
   func @select_funord_le(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FUnordLessThanEqual {{%.*}}, {{%.*}}
@@ -168,7 +168,7 @@
 
 module {
   func @select_funord_ne(%arg0: memref<12x42xf32>, %arg1: memref<12x42xf32>, %arg2: memref<12x42xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xf32>) : tensor<12x42xf32>
     %1 = iree.load_input(%arg1 : memref<12x42xf32>) : tensor<12x42xf32>
     //CHECK: [[COMPARE:%.*]] = spv.FUnordNotEqual {{%.*}}, {{%.*}}
@@ -183,7 +183,7 @@
 
 module {
   func @select_int_eq(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>, %arg2: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = iree.load_input(%arg1 : memref<12x42xi32>) : tensor<12x42xi32>
     //CHECK: [[COMPARE:%.*]] = spv.IEqual {{%.*}}, {{%.*}}
@@ -198,7 +198,7 @@
 
 module {
   func @select_int_ne(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>, %arg2: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = iree.load_input(%arg1 : memref<12x42xi32>) : tensor<12x42xi32>
     //CHECK: [[COMPARE:%.*]] = spv.INotEqual {{%.*}}, {{%.*}}
@@ -213,7 +213,7 @@
 
 module {
   func @select_int_lt(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>, %arg2: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = iree.load_input(%arg1 : memref<12x42xi32>) : tensor<12x42xi32>
     //CHECK: [[COMPARE:%.*]] = spv.SLessThan {{%.*}}, {{%.*}}
@@ -228,7 +228,7 @@
 
 module {
   func @select_int_le(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>, %arg2: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = iree.load_input(%arg1 : memref<12x42xi32>) : tensor<12x42xi32>
     //CHECK: [[COMPARE:%.*]] = spv.SLessThanEqual {{%.*}}, {{%.*}}
@@ -243,7 +243,7 @@
 
 module {
   func @select_int_ge(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>, %arg2: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = iree.load_input(%arg1 : memref<12x42xi32>) : tensor<12x42xi32>
     //CHECK: [[COMPARE:%.*]] = spv.SGreaterThanEqual {{%.*}}, {{%.*}}
@@ -259,7 +259,7 @@
 
 module {
   func @select_int_gt(%arg0: memref<12x42xi32>, %arg1: memref<12x42xi32>, %arg2: memref<12x42xi32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[42, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<12x42xi32>) : tensor<12x42xi32>
     %1 = iree.load_input(%arg1 : memref<12x42xi32>) : tensor<12x42xi32>
     //CHECK: [[COMPARE:%.*]] = spv.SGreaterThan {{%.*}}, {{%.*}}
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir
index a556c92..07b3b90 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/slice.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<36 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<6 x f32 [4]> [0]>, StorageBuffer>
   func @slice_unit_stride(%arg0: memref<6x6xf32>, %arg1: memref<2x3xf32>)
-  attributes {iree.executable.export, iree.executable.workload = dense<[6, 1]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL0:%.*]]  = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
     // CHECK: [[ARG1STOREPTR:%.*]] = spv.AccessChain [[ARG1]]
@@ -24,7 +24,7 @@
   // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<36 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<6 x f32 [4]> [0]>, StorageBuffer>
   func @slice_non_unit_stride(%arg0: memref<6x6xf32>, %arg1: memref<2x3xf32>)
-  attributes {iree.executable.export, iree.executable.workload = dense<[6, 1]> : tensor<2xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL0:%.*]]  = spv.Load "StorageBuffer" [[ARG0LOADPTR]] : f32
     // CHECK: [[ARG1STOREPTR:%.*]] = spv.AccessChain [[ARG1]]
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir
index 5311af3..9654bb1 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/store_reduce.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<5 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
-  func @reduction_entry(%arg0: memref<5xi32>, %arg1: memref<i32>, %arg2: memref<i32> {iree.executable.reduction.output}) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[5, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  func @reduction_entry(%arg0: memref<5xi32>, %arg1: memref<i32>, %arg2: memref<i32> {iree.executable.reduction.output}) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<5xi32>)  : tensor<5xi32>
     // CHECK: [[LOADPTR:%[a-zA-Z0-9_]*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[LOADVAL:%[a-zA-Z0-9_]*]]  = spv.Load "StorageBuffer" [[LOADPTR]]
@@ -27,7 +27,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<20 x i32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<4 x i32 [4]> [0]>, StorageBuffer>
-  func @reduction_2D_dim0_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<4xi32> {iree.executable.reduction.output}) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[4, 5, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  func @reduction_2D_dim0_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<4xi32> {iree.executable.reduction.output}) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 0 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     %0 = iree.load_input(%arg0 : memref<5x4xi32>)  : tensor<5x4xi32>
     // CHECK: [[LOADPTR:%[a-zA-Z0-9_]*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[LOADVAL:%[a-zA-Z0-9_]*]]  = spv.Load "StorageBuffer" [[LOADPTR]]
@@ -47,7 +47,7 @@
 module {
   // CHECK: spv.func @reduction_2D_dim1_entry
   // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]: !spv.ptr<!spv.struct<!spv.array<5 x i32 [4]> [0]>, StorageBuffer>
-  func @reduction_2D_dim1_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<5xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 1 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.executable.workload = dense<[4, 5, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  func @reduction_2D_dim1_entry(%arg0: memref<5x4xi32>, %arg1: memref<i32>, %arg2: memref<5xi32>) attributes {iree.executable.export, iree.executable.reduction, iree.executable.reduction.apply = @reduction_apply, iree.executable.reduction.dimension = 1 : i32, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi64>, iree.ordinal = 0 : i32} {
     // CHECK: [[GLOBALIDPTR:%[a-zA-Z0-9_]*]] = spv._address_of @globalInvocationID
     // CHECK: [[GLOBALID:%[a-zA-Z0-9_]*]] = spv.Load "Input" [[GLOBALIDPTR]] : vector<3xi32>
     // CHECK: [[GLOBALIDY:%[a-zA-Z0-9_]*]] = spv.CompositeExtract [[GLOBALID]]{{\[}}1 : i32{{\]}}
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir
index a6cf7af..331490e 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/test/transpose_add.mlir
@@ -5,7 +5,7 @@
   // CHECK-SAME: [[ARG0:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
   // CHECK-SAME: [[ARG1:%[a-zA-Z0-9]*]]: !spv.ptr<!spv.struct<!spv.array<144 x f32 [4]> [0]>, StorageBuffer>
   func @transpose_add(%arg0: memref<12x12xf32>, %arg1: memref<12x12xf32>)
-  attributes  {iree.executable.export, iree.executable.workload = dense<[12, 12, 1]> : tensor<3xi32>, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
+  attributes  {iree.executable.export, iree.executable.workgroup_size = dense<[32, 1, 1]> : tensor<3xi32>, iree.ordinal = 0 : i32} {
     // CHECK: [[ARG0LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
     // CHECK: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[ARG0LOADPTR]]
     // CHECK: [[ARG1LOADPTR:%.*]] = spv.AccessChain [[ARG0]]
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD
index a76ef85..dca2b51 100644
--- a/iree/compiler/Utils/BUILD
+++ b/iree/compiler/Utils/BUILD
@@ -23,12 +23,10 @@
     name = "Utils",
     srcs = [
         "GraphUtils.cpp",
-        "IREECodegenUtils.cpp",
         "TypeConversionUtils.cpp",
     ],
     hdrs = [
         "GraphUtils.h",
-        "IREECodegenUtils.h",
         "TypeConversionUtils.h",
     ],
     deps = [
diff --git a/iree/compiler/Utils/CMakeLists.txt b/iree/compiler/Utils/CMakeLists.txt
index 3094d47..aa87771 100644
--- a/iree/compiler/Utils/CMakeLists.txt
+++ b/iree/compiler/Utils/CMakeLists.txt
@@ -17,11 +17,9 @@
     Utils
   HDRS
     "GraphUtils.h"
-    "IREECodegenUtils.h"
     "TypeConversionUtils.h"
   SRCS
     "GraphUtils.cpp"
-    "IREECodegenUtils.cpp"
     "TypeConversionUtils.cpp"
   DEPS
     LLVMSupport
diff --git a/iree/compiler/Utils/IREECodegenUtils.cpp b/iree/compiler/Utils/IREECodegenUtils.cpp
deleted file mode 100644
index 1b65cb9..0000000
--- a/iree/compiler/Utils/IREECodegenUtils.cpp
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Utils/IREECodegenUtils.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-/// Gets the launch size associated with the dispatch function.
-LogicalResult getLegacyLaunchSize(Operation *funcOp,
-                                  SmallVectorImpl<int64_t> &launchSize) {
-  if (!funcOp->getAttr("iree.executable.export")) {
-    return funcOp->emitError(
-        "expected operation to be in dispatch function to get launch size");
-  }
-  auto workloadAttr =
-      funcOp->getAttrOfType<DenseElementsAttr>("iree.executable.workload");
-  if (!workloadAttr) {
-    return funcOp->emitError(
-        "unable to find workload size, missing attribute "
-        "iree.executable.workload in dispatch function");
-  }
-  launchSize.clear();
-  for (auto value : workloadAttr.getValues<APInt>()) {
-    launchSize.push_back(value.getSExtValue());
-  }
-  // Drop trailing ones.
-  auto dropFrom = launchSize.size() - 1;
-  while (dropFrom > 0 && launchSize[dropFrom] == 1) {
-    --dropFrom;
-  }
-  if (dropFrom > 0) {
-    launchSize.erase(std::next(launchSize.begin(), dropFrom + 1),
-                     launchSize.end());
-  }
-  return success();
-}
-
-/// Gets the workgroup size.
-template <typename intType>
-LogicalResult getLegacyWorkGroupSize(Operation *funcOp,
-                                     SmallVectorImpl<intType> &workGroupSize) {
-  if (!funcOp->getAttr("iree.executable.export")) {
-    return funcOp->emitError(
-        "expected operation to be in dispatch function to get launch size");
-  }
-  auto workGroupSizeAttr = funcOp->getAttrOfType<DenseElementsAttr>(
-      "iree.executable.workgroup_size");
-  if (!workGroupSizeAttr) {
-    return funcOp->emitError(
-        "unable to find workload size, missing attribute "
-        "iree.executable.workload in dispatch function");
-  }
-  workGroupSize.clear();
-  for (auto value : workGroupSizeAttr.getValues<APInt>()) {
-    workGroupSize.push_back(value.getSExtValue());
-  }
-  return success();
-}
-
-template LogicalResult getLegacyWorkGroupSize<int32_t>(
-    Operation *funcOp, SmallVectorImpl<int32_t> &workGroupSize);
-template LogicalResult getLegacyWorkGroupSize<int64_t>(
-    Operation *funcOp, SmallVectorImpl<int64_t> &workGroupSize);
-
-}  // namespace iree_compiler
-}  // namespace mlir
diff --git a/iree/compiler/Utils/IREECodegenUtils.h b/iree/compiler/Utils/IREECodegenUtils.h
deleted file mode 100644
index 3f82e60..0000000
--- a/iree/compiler/Utils/IREECodegenUtils.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_UTILS_IREECODEGENUTILS_H
-#define IREE_COMPILER_UTILS_IREECODEGENUTILS_H
-
-#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
-#include "mlir/IR/Function.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// WARNING: this file is deprecated and will be removed soon. Do not use.
-
-// TODO(ravishankarm): remove this; it does not work with dynamic shapes.
-/// Gets the launch size associated with the dispatch function.
-LogicalResult getLegacyLaunchSize(Operation *funcOp,
-                                  SmallVectorImpl<int64_t> &launchSize);
-
-/// Gets the workgroup size. Has to be a static constant.
-template <typename intType>
-LogicalResult getLegacyWorkGroupSize(Operation *funcOp,
-                                     SmallVectorImpl<intType> &workGroupSize);
-
-}  // namespace iree_compiler
-}  // namespace mlir
-
-#endif  // IREE_COMPILER_UTILS_IREECODEGENUTILS_H