Simplify the conv workgroup calculation by using the output shape rather than recomputing from input shapes. PiperOrigin-RevId: 284052397
diff --git a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp index 0773784..6434866 100644 --- a/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp +++ b/iree/compiler/Dialect/Flow/Utils/WorkloadUtils.cpp
@@ -28,44 +28,6 @@ namespace IREE { namespace Flow { -std::array<int32_t, 3> convWorkload(xla_hlo::ConvOp conv) { - std::array<int32_t, 3> workload = {1, 1, 1}; - auto lhs = conv.lhs()->getType().cast<ShapedType>(); - auto rhs = conv.rhs()->getType().cast<ShapedType>(); - std::array<int32_t, 3> lhs_hw = {1, 1}; - int i = 0; - for (const auto &spatial : - conv.dimension_numbers().input_spatial_dimensions()) { - if (i > 1) { - break; - } - lhs_hw[i++] = lhs.getDimSize(spatial.getSExtValue()); - } - std::array<int32_t, 2> rhs_hw = {1, 1}; - i = 0; - for (const auto &spatial : - conv.dimension_numbers().kernel_spatial_dimensions()) { - if (i > 1) { - break; - } - rhs_hw[i++] = rhs.getDimSize(spatial.getSExtValue()); - } - std::array<int32_t, 2> padding = {0, 0}; - i = 0; - for (const auto &pad : conv.padding().getValue().getIntValues()) { - if (i > 3) { - break; - } - padding[i++ / 2] += pad.getSExtValue(); - } - // TODO(namiller): Generalize for other ranks and strides once supported. - workload[2] = - lhs.getDimSize(conv.dimension_numbers().input_batch_dimension().getInt()); - workload[1] = lhs_hw[0] - rhs_hw[0] + padding[0] + 1; - workload[0] = lhs_hw[1] - rhs_hw[1] + padding[1] + 1; - return workload; -} - Value *calculateWorkload(Operation *op, ShapedType baseOperandType) { OpBuilder builder(op); @@ -76,10 +38,19 @@ op->emitOpError() << "Dynamic shapes not yet supported"; return nullptr; } - if (auto conv = llvm::dyn_cast_or_null<xla_hlo::ConvOp>(op)) { - workload = convWorkload(conv); + auto shape = baseOperandType.getShape(); + if (auto conv = dyn_cast_or_null<xla_hlo::ConvOp>(op)) { + workload[2] = + shape[conv.dimension_numbers().output_batch_dimension().getInt()]; + int i = 0; + for (const auto &dim : + conv.dimension_numbers().output_spatial_dimensions().getIntValues()) { + if (i > 1) { + break; + } + workload[1 - i++] = shape[dim.getSExtValue()]; + } } else { - auto shape = baseOperandType.getShape(); // Drop the trailing ones from the shape. while (shape.size() > 1 && shape.back() == 1) { shape = shape.drop_back();
diff --git a/iree/compiler/Utils/DispatchUtils.cpp b/iree/compiler/Utils/DispatchUtils.cpp index 436015c..9cf84df 100644 --- a/iree/compiler/Utils/DispatchUtils.cpp +++ b/iree/compiler/Utils/DispatchUtils.cpp
@@ -37,44 +37,6 @@ namespace mlir { namespace iree_compiler { -std::array<int32_t, 3> convWorkload(xla_hlo::ConvOp conv) { - std::array<int32_t, 3> workload = {1, 1, 1}; - auto lhs = conv.lhs()->getType().cast<ShapedType>(); - auto rhs = conv.rhs()->getType().cast<ShapedType>(); - std::array<int32_t, 3> lhs_hw = {1, 1}; - int i = 0; - for (const auto &spatial : - conv.dimension_numbers().input_spatial_dimensions()) { - if (i > 1) { - break; - } - lhs_hw[i++] = lhs.getDimSize(spatial.getSExtValue()); - } - std::array<int32_t, 2> rhs_hw = {1, 1}; - i = 0; - for (const auto &spatial : - conv.dimension_numbers().kernel_spatial_dimensions()) { - if (i > 1) { - break; - } - rhs_hw[i++] = rhs.getDimSize(spatial.getSExtValue()); - } - std::array<int32_t, 2> padding = {0, 0}; - i = 0; - for (const auto &pad : conv.padding().getValue().getIntValues()) { - if (i > 3) { - break; - } - padding[i++ / 2] += pad.getSExtValue(); - } - // TODO(namiller): Generalize for other ranks and strides once supported. - workload[2] = - lhs.getDimSize(conv.dimension_numbers().input_batch_dimension().getInt()); - workload[1] = lhs_hw[0] - rhs_hw[0] + padding[0] + 1; - workload[0] = lhs_hw[1] - rhs_hw[1] + padding[1] + 1; - return workload; -} - Value *calculateWorkload(Operation *op, Value *baseOperand) { OpBuilder builder(op); @@ -82,31 +44,44 @@ // TODO(b/139353314): lookup/calculate based on type/etc. auto resultType = baseOperand->getType(); - if (auto conv = dyn_cast_or_null<xla_hlo::ConvOp>(op)) { - workload = convWorkload(conv); - } else if (auto shapedType = resultType.dyn_cast<ShapedType>()) { + if (auto shapedType = resultType.dyn_cast<ShapedType>()) { if (!shapedType.hasStaticShape()) { op->emitOpError() << "Dynamic shapes not yet supported"; return nullptr; } auto shape = shapedType.getShape(); - // Drop the trailing ones from the shape. - while (shape.size() > 1 && shape.back() == 1) { - shape = shape.drop_back(); - } - if (shape.size() <= 3) { - // Maps to XYZ (possibly with 1's for unused dimensions). - for (auto dim : enumerate(shape)) { - workload[shape.size() - 1 - dim.index()] = dim.value(); + if (auto conv = dyn_cast_or_null<xla_hlo::ConvOp>(op)) { + workload[2] = + shape[conv.dimension_numbers().output_batch_dimension().getInt()]; + int i = 0; + for (const auto &dim : conv.dimension_numbers() + .output_spatial_dimensions() + .getIntValues()) { + if (i > 1) { + break; + } + workload[1 - i++] = shape[dim.getSExtValue()]; } } else { - // Need to flatten the shape to fit XYZ. For now we just squash from LHS. - workload[2] = 1; - for (int i = 0; i < shape.size() - 2; ++i) { - workload[2] *= shape[i]; + // Drop the trailing ones from the shape. + while (shape.size() > 1 && shape.back() == 1) { + shape = shape.drop_back(); } - workload[1] = shape[shape.size() - 2]; - workload[0] = shape.back(); + if (shape.size() <= 3) { + // Maps to XYZ (possibly with 1's for unused dimensions). + for (auto dim : enumerate(shape)) { + workload[shape.size() - 1 - dim.index()] = dim.value(); + } + } else { + // Need to flatten the shape to fit XYZ. For now we just squash from + // LHS. + workload[2] = 1; + for (int i = 0; i < shape.size() - 2; ++i) { + workload[2] *= shape[i]; + } + workload[1] = shape[shape.size() - 2]; + workload[0] = shape.back(); + } } }