Simplify the conv workgroup calculation by using the output shape rather than
recomputing from input shapes.

PiperOrigin-RevId: 284052397
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
index 0773784..6434866 100644
--- a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
+++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
@@ -28,44 +28,6 @@
 namespace IREE {
 namespace Flow {
 
-std::array<int32_t, 3> convWorkload(xla_hlo::ConvOp conv) {
-  std::array<int32_t, 3> workload = {1, 1, 1};
-  auto lhs = conv.lhs()->getType().cast<ShapedType>();
-  auto rhs = conv.rhs()->getType().cast<ShapedType>();
-  std::array<int32_t, 3> lhs_hw = {1, 1};
-  int i = 0;
-  for (const auto &spatial :
-       conv.dimension_numbers().input_spatial_dimensions()) {
-    if (i > 1) {
-      break;
-    }
-    lhs_hw[i++] = lhs.getDimSize(spatial.getSExtValue());
-  }
-  std::array<int32_t, 2> rhs_hw = {1, 1};
-  i = 0;
-  for (const auto &spatial :
-       conv.dimension_numbers().kernel_spatial_dimensions()) {
-    if (i > 1) {
-      break;
-    }
-    rhs_hw[i++] = rhs.getDimSize(spatial.getSExtValue());
-  }
-  std::array<int32_t, 2> padding = {0, 0};
-  i = 0;
-  for (const auto &pad : conv.padding().getValue().getIntValues()) {
-    if (i > 3) {
-      break;
-    }
-    padding[i++ / 2] += pad.getSExtValue();
-  }
-  // TODO(namiller): Generalize for other ranks and strides once supported.
-  workload[2] =
-      lhs.getDimSize(conv.dimension_numbers().input_batch_dimension().getInt());
-  workload[1] = lhs_hw[0] - rhs_hw[0] + padding[0] + 1;
-  workload[0] = lhs_hw[1] - rhs_hw[1] + padding[1] + 1;
-  return workload;
-}
-
 Value *calculateWorkload(Operation *op, ShapedType baseOperandType) {
   OpBuilder builder(op);
 
@@ -76,10 +38,19 @@
     op->emitOpError() << "Dynamic shapes not yet supported";
     return nullptr;
   }
-  if (auto conv = llvm::dyn_cast_or_null<xla_hlo::ConvOp>(op)) {
-    workload = convWorkload(conv);
+  auto shape = baseOperandType.getShape();
+  if (auto conv = dyn_cast_or_null<xla_hlo::ConvOp>(op)) {
+    workload[2] =
+        shape[conv.dimension_numbers().output_batch_dimension().getInt()];
+    int i = 0;
+    for (const auto &dim :
+         conv.dimension_numbers().output_spatial_dimensions().getIntValues()) {
+      if (i > 1) {
+        break;
+      }
+      workload[1 - i++] = shape[dim.getSExtValue()];
+    }
   } else {
-    auto shape = baseOperandType.getShape();
     // Drop the trailing ones from the shape.
     while (shape.size() > 1 && shape.back() == 1) {
       shape = shape.drop_back();
diff --git a/iree/compiler/Utils/DispatchUtils.cpp b/iree/compiler/Utils/DispatchUtils.cpp
index 436015c..9cf84df 100644
--- a/iree/compiler/Utils/DispatchUtils.cpp
+++ b/iree/compiler/Utils/DispatchUtils.cpp
@@ -37,44 +37,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-std::array<int32_t, 3> convWorkload(xla_hlo::ConvOp conv) {
-  std::array<int32_t, 3> workload = {1, 1, 1};
-  auto lhs = conv.lhs()->getType().cast<ShapedType>();
-  auto rhs = conv.rhs()->getType().cast<ShapedType>();
-  std::array<int32_t, 3> lhs_hw = {1, 1};
-  int i = 0;
-  for (const auto &spatial :
-       conv.dimension_numbers().input_spatial_dimensions()) {
-    if (i > 1) {
-      break;
-    }
-    lhs_hw[i++] = lhs.getDimSize(spatial.getSExtValue());
-  }
-  std::array<int32_t, 2> rhs_hw = {1, 1};
-  i = 0;
-  for (const auto &spatial :
-       conv.dimension_numbers().kernel_spatial_dimensions()) {
-    if (i > 1) {
-      break;
-    }
-    rhs_hw[i++] = rhs.getDimSize(spatial.getSExtValue());
-  }
-  std::array<int32_t, 2> padding = {0, 0};
-  i = 0;
-  for (const auto &pad : conv.padding().getValue().getIntValues()) {
-    if (i > 3) {
-      break;
-    }
-    padding[i++ / 2] += pad.getSExtValue();
-  }
-  // TODO(namiller): Generalize for other ranks and strides once supported.
-  workload[2] =
-      lhs.getDimSize(conv.dimension_numbers().input_batch_dimension().getInt());
-  workload[1] = lhs_hw[0] - rhs_hw[0] + padding[0] + 1;
-  workload[0] = lhs_hw[1] - rhs_hw[1] + padding[1] + 1;
-  return workload;
-}
-
 Value *calculateWorkload(Operation *op, Value *baseOperand) {
   OpBuilder builder(op);
 
@@ -82,31 +44,44 @@
 
   // TODO(b/139353314): lookup/calculate based on type/etc.
   auto resultType = baseOperand->getType();
-  if (auto conv = dyn_cast_or_null<xla_hlo::ConvOp>(op)) {
-    workload = convWorkload(conv);
-  } else if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
+  if (auto shapedType = resultType.dyn_cast<ShapedType>()) {
     if (!shapedType.hasStaticShape()) {
       op->emitOpError() << "Dynamic shapes not yet supported";
       return nullptr;
     }
     auto shape = shapedType.getShape();
-    // Drop the trailing ones from the shape.
-    while (shape.size() > 1 && shape.back() == 1) {
-      shape = shape.drop_back();
-    }
-    if (shape.size() <= 3) {
-      // Maps to XYZ (possibly with 1's for unused dimensions).
-      for (auto dim : enumerate(shape)) {
-        workload[shape.size() - 1 - dim.index()] = dim.value();
+    if (auto conv = dyn_cast_or_null<xla_hlo::ConvOp>(op)) {
+      workload[2] =
+          shape[conv.dimension_numbers().output_batch_dimension().getInt()];
+      int i = 0;
+      for (const auto &dim : conv.dimension_numbers()
+                                 .output_spatial_dimensions()
+                                 .getIntValues()) {
+        if (i > 1) {
+          break;
+        }
+        workload[1 - i++] = shape[dim.getSExtValue()];
       }
     } else {
-      // Need to flatten the shape to fit XYZ. For now we just squash from LHS.
-      workload[2] = 1;
-      for (int i = 0; i < shape.size() - 2; ++i) {
-        workload[2] *= shape[i];
+      // Drop the trailing ones from the shape.
+      while (shape.size() > 1 && shape.back() == 1) {
+        shape = shape.drop_back();
       }
-      workload[1] = shape[shape.size() - 2];
-      workload[0] = shape.back();
+      if (shape.size() <= 3) {
+        // Maps to XYZ (possibly with 1's for unused dimensions).
+        for (auto dim : enumerate(shape)) {
+          workload[shape.size() - 1 - dim.index()] = dim.value();
+        }
+      } else {
+        // Need to flatten the shape to fit XYZ. For now we just squash from
+        // LHS.
+        workload[2] = 1;
+        for (int i = 0; i < shape.size() - 2; ++i) {
+          workload[2] *= shape[i];
+        }
+        workload[1] = shape[shape.size() - 2];
+        workload[0] = shape.back();
+      }
     }
   }