[Codegen] Add `DeviceMappingAttr` that maps to workgroup IDs. (#18264)

This adds a `DeviceMappingAttr` that maps to
`hal.interface.workgroup.id`s, but for use as mapping attribute with
`scf.forall`. The existing attributes upstream don't seem to match the
behavior one would expect for mapping to GPU dimensions. Over time this
should be upstreamed to MLIR (and deprecate what exists today).

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
index 21fe438..49db376 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/BUILD.bazel
@@ -38,6 +38,7 @@
         "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:SCFTdFiles",
         "@llvm-project//mlir:SideEffectInterfacesTdFiles",
         "@llvm-project//mlir:ViewLikeInterfaceTdFiles",
     ],
@@ -94,6 +95,7 @@
         "@llvm-project//mlir:InferTypeOpInterface",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformDialect",
         "@llvm-project//mlir:TransformDialectTransforms",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
index 11a3e0c..bc40433 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/CMakeLists.txt
@@ -56,6 +56,7 @@
     MLIRInferTypeOpInterface
     MLIRMemRefDialect
     MLIRParser
+    MLIRSCFDialect
     MLIRSupport
     MLIRTransformDialect
     MLIRTransformDialectTransforms
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
index f803103..b818ac8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp
@@ -364,6 +364,107 @@
 }
 
 //===----------------------------------------------------------------------===//
+// iree_codegen.workgroup_mapping
+//===----------------------------------------------------------------------===//
+
+/// Convenience builder: creates the attribute for `id` with `delinearizedDim`
+/// set to 0.
+WorkgroupMappingAttr WorkgroupMappingAttr::get(MLIRContext *context,
+                                               WorkgroupId id) {
+  return WorkgroupMappingAttr::get(context, id, /*delinearizedDim=*/0);
+}
+
+/// Verifies a single attribute: a positive `delinearizedDim` is rejected for
+/// any `id` other than `z`; every other combination is accepted.
+LogicalResult
+WorkgroupMappingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                             WorkgroupId id, int64_t delinearizedDim) {
+  if (delinearizedDim > 0 && id != WorkgroupId::IdZ) {
+    return emitError() << "illegal to use `delinearizationDim` for "
+                       << stringifyWorkgroupId(id);
+  }
+  return success();
+}
+
+/// Checks that `attrs` is a well-formed list of workgroup mappings: every
+/// element is a `WorkgroupMappingAttr`, no element is repeated, and once
+/// sorted the list reads `x`, `y`, `z:0`, `z:1`, ... without gaps. An empty
+/// list is accepted.
+LogicalResult WorkgroupMappingAttr::verifyAttrList(
+    MLIRContext *context, function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<Attribute> attrs) {
+  if (attrs.empty()) {
+    return success();
+  }
+  SmallVector<IREE::Codegen::WorkgroupMappingAttr> mappingAttrs;
+  llvm::SmallDenseSet<IREE::Codegen::WorkgroupMappingAttr, 4> attrSet;
+  for (auto attr : attrs) {
+    auto typedAttr =
+        ::mlir::dyn_cast_or_null<IREE::Codegen::WorkgroupMappingAttr>(attr);
+    // Test the cast result, not the input: a non-null attribute of another
+    // type yields a null `typedAttr` that must not reach the code below.
+    if (!typedAttr) {
+      return emitError() << "expected all the mapping attribute to be of "
+                            "`WorkgroupMappingAttr` type";
+    }
+    if (attrSet.contains(typedAttr)) {
+      return emitError() << "illegal to repeat mapping specification";
+    }
+    attrSet.insert(typedAttr);
+    mappingAttrs.push_back(typedAttr);
+  }
+
+  // Order by id first, then by delinearized dimension. Distinct attributes
+  // are expected to share an id only when that id is `z`.
+  llvm::sort(mappingAttrs, [](const IREE::Codegen::WorkgroupMappingAttr &lhs,
+                              const IREE::Codegen::WorkgroupMappingAttr &rhs) {
+    if (lhs.getId() != rhs.getId()) {
+      return lhs.getId() < rhs.getId();
+    }
+    assert(lhs.getId() == IREE::Codegen::WorkgroupId::IdZ);
+    return lhs.getDelinearizedDim() < rhs.getDelinearizedDim();
+  });
+  // First element has to be `workgroup_mapping<x>`.
+  if (mappingAttrs.front().getId() != IREE::Codegen::WorkgroupId::IdX) {
+    return emitError() << "missing `workgroup_mapping<x>`";
+  }
+  if (mappingAttrs.size() == 1) {
+    return success();
+  }
+
+  if (mappingAttrs[1].getId() != IREE::Codegen::WorkgroupId::IdY) {
+    return emitError() << "missing `workgroup_mapping<y>`";
+  }
+  if (mappingAttrs.size() == 2) {
+    return success();
+  }
+
+  // Everything after `x` and `y` must carry dimension 0, 1, 2, ... in order.
+  auto mappingAttrsRef = ArrayRef<IREE::Codegen::WorkgroupMappingAttr>(
+      mappingAttrs.begin(), mappingAttrs.end());
+  for (auto [index, attr] : llvm::enumerate(mappingAttrsRef.drop_front(2))) {
+    if (attr.getDelinearizedDim() != static_cast<int64_t>(index)) {
+      return emitError() << "missing `workgroup_mapping<z:" << index << ">`";
+    }
+  }
+  return success();
+}
+
+/// Returns the underlying value of the id enum plus the delinearized
+/// dimension.
+int64_t WorkgroupMappingAttr::getMappingId() const {
+  return llvm::to_underlying(getId()) + getDelinearizedDim();
+}
+
+/// Always reports `false` for this attribute.
+bool WorkgroupMappingAttr::isLinearMapping() const { return false; }
+
+/// Returns the same value as `getMappingId()`.
+int64_t WorkgroupMappingAttr::getRelativeIndex() const {
+  return getMappingId();
+}
+
+//===----------------------------------------------------------------------===//
 // Initialize attributes
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
index 71d54e8..ebf7e3c 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h
@@ -12,6 +12,7 @@
 
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
index cef0735..95ea6e5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td
@@ -9,8 +9,13 @@
 
 include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.td"
 include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/IR/EnumAttr.td"
 
+//===---------------------------------------------------------------------===//
+// Pass pipelines
+//===---------------------------------------------------------------------===//
+
 // List of pre-existing pipelines for translating executables.
 def CPU_Default
     : I32EnumAttrCase<"CPUDefault", 0>;
@@ -126,6 +131,73 @@
   let assemblyFormat = "``$value";
 }
 
+//===---------------------------------------------------------------------===//
+// IREE Codegen workgroup mapping attributes
+//===---------------------------------------------------------------------===//
+
+def IdX : I64EnumAttrCase<"IdX", 0, "x">;
+def IdY : I64EnumAttrCase<"IdY", 1, "y">;
+def IdZ : I64EnumAttrCase<"IdZ", 2, "z">;
+
+def WorkgroupIdEnum :
+    I64EnumAttr<"WorkgroupId", "Attribute that maps to hal.workgroup.ids", [
+      IdX, IdY, IdZ]> {
+  let cppNamespace = "::mlir::iree_compiler::IREE::Codegen";
+}
+
+def WorkgroupMappingAttr :
+    AttrDef<IREECodegen_Dialect, "WorkgroupMapping", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface>]> {
+  let mnemonic = "workgroup_mapping";
+
+  let parameters = (ins
+    EnumParameter<WorkgroupIdEnum>:$id,
+    DefaultValuedParameter<"int64_t", "0">:$delinearizedDim
+  );
+  let assemblyFormat = "`<` $id (`` `:` `` $delinearizedDim^)? `>`";
+
+  let description = [{
+    Attribute that eventually will be used to map distributed loop iterations
+    to `hal.workgroup.id`s.
+
+    The `x`,`y` and `z` values for `id` map to `hal.workgroup.id[0]`,
+    `hal.workgroup.id[1]` and `hal.workgroup.id[2]` respectively.
+
+    In addition it is possible to specify if the `z` dimension is to be
+    delinearized on mapping. For example if the list of mapping attributes is
+    `[workgroup_mapping<z:1>, workgroup_mapping<z:0>]`, then the `z` dimension
+    is delinearized to map to `workgroup_mapping<z:1>` and
+    `workgroup_mapping<z:0>`. In other words if the number of logical parallel
+    workers along the `z:0` dimension is `W`, then
+    ```
+    workgroup_mapping<z:0> = hal.workgroup.id[2] mod W,
+    workgroup_mapping<z:1> = hal.workgroup.id[2] div W
+    ```
+
+    Note: It is expected that this attribute is always used in a list of mapping
+    attributes (with a single element being a list of size 1). It is illegal for
+    a list to have `workgroup_mapping<z:a>` without `workgroup_mapping<z:b>`
+    if `a > b`. In the same way it is illegal for the list to
+    - have `workgroup_mapping<y>` but not `workgroup_mapping<x>`
+    - have `workgroup_mapping<z:*>`  but not have `workgroup_mapping<x>`
+      and `workgroup_mapping<y>`
+  }];
+  let builders = [
+    AttrBuilder<(ins "WorkgroupId":$id)>
+  ];
+  let extraClassDeclaration = [{
+    // Checks that a list of attributes is well-defined.
+    static LogicalResult verifyAttrList(::mlir::MLIRContext *context,
+      ::llvm::function_ref<::mlir::InFlightDiagnostic ()> emitError,
+      ArrayRef<Attribute> attrs);
+  }];
+
+  let genVerifyDecl = 1;
+}
+
+//===---------------------------------------------------------------------===//
+// iree_codegen.translation_info
+//===---------------------------------------------------------------------===//
 
 def IREECodegen_TranslationInfoAttr :
     AttrDef<IREECodegen_Dialect, "TranslationInfo", []> {
@@ -183,6 +255,10 @@
   let genVerifyDecl = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// iree_codegen.lowering_config
+//===---------------------------------------------------------------------===//
+
 def IREECodegen_LoweringConfigTilingLevelAttr :
   AttrDef<IREECodegen_Dialect, "LoweringConfigTilingLevel", []>
 {
@@ -278,6 +354,10 @@
   let genVerifyDecl = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// iree_codegen.compilation_info
+//===---------------------------------------------------------------------===//
+
 def IREECodegen_CompilationInfoAttr :
     AttrDef<IREECodegen_Dialect, "CompilationInfo", []> {
   let mnemonic = "compilation_info";
@@ -309,6 +389,10 @@
   let genVerifyDecl = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// iree_codegen.export_config
+//===---------------------------------------------------------------------===//
+
 def IREECodegen_ExportConfig : AttrDef<IREECodegen_Dialect, "ExportConfig", []> {
   let mnemonic = "export_config";
   let summary = "User defined workgroup size specification";
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel
index 1f335c8..fa66184 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/BUILD.bazel
@@ -22,6 +22,7 @@
             "lowering_config_attr.mlir",
             "ukernel_ops.mlir",
             "ukernel_ops_cse.mlir",
+            "workgroup_mapping_attrs.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/CMakeLists.txt
index 917dc6e..feb809a 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/CMakeLists.txt
@@ -18,6 +18,7 @@
     "lowering_config_attr.mlir"
     "ukernel_ops.mlir"
     "ukernel_ops_cse.mlir"
+    "workgroup_mapping_attrs.mlir"
   TOOLS
     FileCheck
     iree-opt
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/workgroup_mapping_attrs.mlir b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/workgroup_mapping_attrs.mlir
new file mode 100644
index 0000000..ef96a73
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/workgroup_mapping_attrs.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics --allow-unregistered-dialect %s | FileCheck %s
+
+func.func @roundtrip() {
+  "dummy.op"() {
+    workgroup_mapping = [
+      #iree_codegen.workgroup_mapping<x>,
+      #iree_codegen.workgroup_mapping<y>,
+      #iree_codegen.workgroup_mapping<z:0>,
+      #iree_codegen.workgroup_mapping<z:1>]
+  } : () -> ()
+  return
+}
+// CHECK-LABEL: func @roundtrip()
+//       CHECK:   #iree_codegen.workgroup_mapping<x>
+//  CHECK-SAME:   #iree_codegen.workgroup_mapping<y>
+//  CHECK-SAME:   #iree_codegen.workgroup_mapping<z>
+//  CHECK-SAME:   #iree_codegen.workgroup_mapping<z:1>
+
+// -----
+
+func.func @illegal_x_linearized_dim() {
+  "dummy.op"() {
+    // expected-error @+1 {{illegal to use `delinearizationDim` for x}}
+    workgroup_mapping = #iree_codegen.workgroup_mapping<x:1>
+  } : () -> ()
+  return
+}
+
+// -----
+
+func.func @illegal_y_linearized_dim() {
+  "dummy.op"() {
+    // expected-error @+1 {{illegal to use `delinearizationDim` for y}}
+    workgroup_mapping = #iree_codegen.workgroup_mapping<y:1>
+  } : () -> ()
+  return
+}