[EmitC] Use signature conversions to drop block arguments with ref type (#18508)

Block arguments are emitted as additional local variables in the EmitC
emitter and branch operands get transferred to these with assignments.
Thus block arguments with `VM::RefType` need to be specifically handled
to correctly update their ref count and therefore need to be deleted
from the blocks.

Before this change these block arguments would survive the dialect
conversion and were manually removed in a cleanup walk over the IR.

Now these arguments get directly dropped during the dialect conversion,
which saves a walk over the IR and more importantly removes state from
the type converter.

---------

Signed-off-by: Simon Camphausen <simon.camphausen@iml.fraunhofer.de>
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 1255bb9..71eaea6 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -46,8 +46,7 @@
 };
 
 LogicalResult convertFuncOp(IREE::VM::FuncOp funcOp,
-                            const IREE::VM::EmitCTypeConverter &typeConverter,
-                            SmallVector<BlockArgument> &blockArgsToRemove) {
+                            const IREE::VM::EmitCTypeConverter &typeConverter) {
   auto ctx = funcOp.getContext();
   auto loc = funcOp.getLoc();
 
@@ -131,15 +130,6 @@
     funcAnalysis.cacheLocalRef(i + numRefArgs, refPtr);
   }
 
-  for (Block &block : llvm::drop_begin(newFuncOp.getBlocks(), 1)) {
-    for (BlockArgument blockArg : block.getArguments()) {
-      if (!llvm::isa<IREE::VM::RefType>(blockArg.getType())) {
-        continue;
-      }
-      blockArgsToRemove.push_back(blockArg);
-    }
-  }
-
   if (failed(
           funcOp.replaceAllSymbolUses(builder.getStringAttr(name), moduleOp)))
     return funcOp.emitError() << "unable to update symbol name in module";
@@ -147,21 +137,6 @@
   return success();
 }
 
-/// Remove block arguments
-LogicalResult
-removeBlockArguments(IREE::VM::ModuleOp moduleOp,
-                     SmallVector<BlockArgument> &blockArgsToRemove) {
-  for (auto &blockArg : blockArgsToRemove) {
-    assert(isa<IREE::VM::RefType>(blockArg.getType()));
-    assert(blockArg.use_empty());
-    Block *block = blockArg.getOwner();
-
-    block->eraseArgument(blockArg.getArgNumber());
-  }
-
-  return success();
-}
-
 std::optional<std::string> buildFunctionName(IREE::VM::ModuleOp &moduleOp,
                                              IREE::VM::ImportOp &importOp) {
   auto callingConvention = makeImportCallingConventionString(importOp);
@@ -1557,23 +1532,61 @@
   LogicalResult
   matchAndRewrite(mlir::emitc::FuncOp funcOp, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TypeConverter::SignatureConversion signatureConverter(
-        funcOp.getFunctionType().getNumInputs());
-    for (const auto &arg : llvm::enumerate(funcOp.getArguments())) {
-      Type convertedType =
-          getTypeConverter()->convertType(arg.value().getType());
-      signatureConverter.addInputs(arg.index(), convertedType);
+    // Entry block arguments, i.e. function arguments get converted 1:1.
+    // VM::RefType arguments get replaced by iree_vm_ref_t*.
+    {
+      Block &block = funcOp.getBlocks().front();
+      TypeConverter::SignatureConversion signatureConversion(
+          block.getNumArguments());
+
+      for (const auto &[index, arg] : llvm::enumerate(block.getArguments())) {
+        Type convertedType = getTypeConverter()->convertType(arg.getType());
+        signatureConversion.addInputs(index, convertedType);
+      }
+
+      rewriter.applySignatureConversion(&block, signatureConversion);
+
+      rewriter.modifyOpInPlace(funcOp, [&] {
+        funcOp.setType(
+            rewriter.getFunctionType(signatureConversion.getConvertedTypes(),
+                                     funcOp.getFunctionType().getResults()));
+      });
     }
 
-    rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(),
-                                      signatureConverter);
+    // Non-entry block arguments are handled differently between numeric types
+    // and VM::RefType.
+    {
+      for (Block &block : llvm::make_early_inc_range(
+               llvm::drop_begin(funcOp.getBlocks(), 1))) {
+        TypeConverter::SignatureConversion signatureConversion(
+            block.getNumArguments());
 
-    // Creates a new function with the updated signature.
-    rewriter.modifyOpInPlace(funcOp, [&] {
-      funcOp.setType(
-          rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
-                                   funcOp.getFunctionType().getResults()));
-    });
+        for (const auto &[index, arg] : llvm::enumerate(block.getArguments())) {
+          if (isa<IREE::VM::RefType>(arg.getType())) {
+            // VM::RefType arguments are dropped and their uses are replaced.
+            // The replacement values are determined by the register allocation
+            // pass.
+            Value ref = getModuleAnalysis().lookupRef(arg);
+            signatureConversion.remapInput(index, ref);
+          } else {
+            // Numerically typed arguments are kept as block arguments. These
+            // are automatically handled later in the emitter.
+            signatureConversion.addInputs(index, arg.getType());
+          }
+        }
+
+        Block *newBlock =
+            rewriter.applySignatureConversion(&block, signatureConversion);
+
+        // The signatureConversion stores a mapping from the original block
+        // argument index to the replacement value. This information is needed
+        // in the conversion of branch ops to correctly map from branch operands
+        // to the replacement values.
+        getModuleAnalysis().lookupFunction(funcOp).cacheBlockConversion(
+            newBlock, signatureConversion);
+      }
+    }
+
     return success();
   }
 };
@@ -2084,8 +2097,8 @@
     }
 
     builder.setInsertionPointToEnd(condBlock);
-    builder.create<IREE::VM::CondBranchOp>(location, conditionI1, failureBlock,
-                                           continuationBlock);
+    builder.create<cf::CondBranchOp>(location, conditionI1, failureBlock,
+                                     continuationBlock);
 
     builder.setInsertionPointToStart(continuationBlock);
   }
@@ -3142,21 +3155,27 @@
 
     Block *destDispatch;
     {
+      auto funcOp =
+          op.getOperation()->template getParentOfType<mlir::emitc::FuncOp>();
+      auto &funcAnalysis = getModuleAnalysis().lookupFunction(funcOp);
+      auto &signatureConversion = funcAnalysis.lookupBlockConversion(dest);
+
       OpBuilder::InsertionGuard guard(rewriter);
       destDispatch = rewriter.createBlock(dest);
 
       IRMapping refMapping;
-      for (auto [operand, blockArg] :
-           llvm::zip_equal(op.getOperands(), dest->getArguments())) {
+      for (auto [index, operand] : llvm::enumerate(op.getOperands())) {
         if (isNotRefOperand(operand)) {
           continue;
         }
 
+        Value blockArgRef =
+            signatureConversion.getInputMapping(index)->replacementValue;
+
         assert(isa<IREE::VM::RefType>(operand.getType()));
-        assert(isa<IREE::VM::RefType>(blockArg.getType()));
+        assert(isa<emitc::PointerType>(blockArgRef.getType()));
 
         Value operandRef = getModuleAnalysis().lookupRef(operand);
-        Value blockArgRef = getModuleAnalysis().lookupRef(blockArg);
 
         refMapping.map(operandRef, blockArgRef);
       }
@@ -4303,8 +4322,8 @@
     }
 
     rewriter.setInsertionPointToEnd(condBlock);
-    rewriter.create<IREE::VM::CondBranchOp>(loc, invalidType, failureBlock,
-                                            continuationBlock);
+    rewriter.create<cf::CondBranchOp>(loc, invalidType, failureBlock,
+                                      continuationBlock);
 
     rewriter.replaceOp(getOp, ref);
 
@@ -4729,9 +4748,8 @@
     // reference emitc.func ops with the correct calling convention during the
     // conversion.
     SmallVector<IREE::VM::FuncOp> funcsToRemove;
-    SmallVector<BlockArgument> blockArgsToRemove;
     for (auto funcOp : module.getOps<IREE::VM::FuncOp>()) {
-      if (failed(convertFuncOp(funcOp, typeConverter, blockArgsToRemove))) {
+      if (failed(convertFuncOp(funcOp, typeConverter))) {
         return signalPassFailure();
       }
       funcsToRemove.push_back(funcOp);
@@ -4761,7 +4779,8 @@
 
     target.addDynamicallyLegalOp<mlir::emitc::FuncOp>(
         [&](mlir::emitc::FuncOp op) {
-          return typeConverter.isSignatureLegal(op.getFunctionType());
+          return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+                 typeConverter.isLegal(&op.getFunctionBody());
         });
 
     // Structural ops
@@ -4776,26 +4795,6 @@
       return signalPassFailure();
     }
 
-    // Remove unused block arguments from refs
-    if (failed(removeBlockArguments(module, blockArgsToRemove))) {
-      return signalPassFailure();
-    }
-
-    SetVector<Operation *> &materializations =
-        typeConverter.sourceMaterializations;
-
-    module.walk([&materializations](Operation *op) {
-      // Remove dead basic block arguments
-      if (materializations.contains(op)) {
-        assert(isa<emitc::VariableOp>(op));
-        assert(op->use_empty());
-
-        materializations.remove(op);
-        op->erase();
-        return;
-      }
-    });
-
     if (failed(createModuleStructure(module, typeConverter))) {
       return signalPassFailure();
     }
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.cpp
index 608bbe1..195feba 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.cpp
@@ -28,30 +28,6 @@
     auto input = cast<TypedValue<IREE::VM::RefType>>(inputs[0]);
     return analysis.lookupRef(input);
   });
-
-  // We need a source materialization for refs because after running
-  // `applyFullConversion` there would be references to the original
-  // IREE::VM::Ref values in unused basic block arguments. As these are unused
-  // anyway we create dummy ops which get deleted after the conversion has
-  // finished.
-  addSourceMaterialization([this](OpBuilder &builder, IREE::VM::RefType type,
-                                  ValueRange inputs, Location loc) -> Value {
-    assert(inputs.size() == 1);
-    assert(isa<emitc::PointerType>(inputs[0].getType()));
-
-    Type objectType = IREE::VM::OpaqueType::get(builder.getContext());
-    Type refType = IREE::VM::RefType::get(objectType);
-
-    auto ctx = builder.getContext();
-    auto op = builder.create<emitc::VariableOp>(
-        /*location=*/loc,
-        /*resultType=*/refType,
-        /*value=*/emitc::OpaqueAttr::get(ctx, ""));
-
-    sourceMaterializations.insert(op.getOperation());
-
-    return op.getResult();
-  });
 }
 
 Type EmitCTypeConverter::convertTypeAsNonPointer(Type type) const {
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
index 9ac4a1d..1c5343f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/EmitCTypeConverter.h
@@ -25,7 +25,6 @@
   emitc::PointerType convertTypeAsPointer(Type type) const;
   emitc::OpaqueType convertTypeAsCType(Type type) const;
 
-  SetVector<Operation *> sourceMaterializations;
   mutable ModuleAnalysis analysis;
 };
 
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
index cbfcc92..a99b991 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/VMAnalysis.h
@@ -16,6 +16,7 @@
 #include "iree/compiler/Dialect/VM/Utils/CallingConvention.h"
 #include "iree/compiler/Dialect/VM/Utils/TypeTable.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir::iree_compiler::IREE::VM {
 
@@ -34,6 +35,7 @@
     originalFunctionType = funcOp.getFunctionType();
     callingConvention = makeCallingConventionString(funcOp).value();
     refs = DenseMap<int64_t, Value>{};
+    blockMap = DenseMap<Block *, TypeConverter::SignatureConversion>{};
   }
   FuncAnalysis(mlir::emitc::FuncOp funcOp) {
     originalFunctionType = funcOp.getFunctionType();
@@ -107,12 +109,28 @@
     refs.value()[ordinal] = ref;
   }
 
+  void cacheBlockConversion(Block *block,
+                            TypeConverter::SignatureConversion conversion) {
+    assert(blockMap.has_value());
+    assert(!blockMap.value().count(block) &&
+           "block conversion was already cached");
+    blockMap.value().try_emplace(block, conversion);
+  }
+
   Value lookupLocalRef(int64_t ordinal) {
     assert(refs.has_value());
     assert(refs.value().count(ordinal) && "ref not found in cache");
     return refs.value()[ordinal];
   }
 
+  const TypeConverter::SignatureConversion &
+  lookupBlockConversion(Block *block) const {
+    assert(blockMap.has_value());
+    assert(blockMap.value().count(block) &&
+           "block conversion not found in cache");
+    return blockMap.value().at(block);
+  }
+
   bool hasLocalRefs() { return refs.has_value(); }
 
   DenseMap<int64_t, Value> &localRefs() {
@@ -128,6 +146,7 @@
   std::optional<std::string> callingConvention;
   std::optional<std::string> exportName;
   std::optional<bool> emitAtEnd;
+  std::optional<DenseMap<Block *, TypeConverter::SignatureConversion>> blockMap;
 };
 
 struct ModuleAnalysis {