Change VM dialect to generate and use prefixed accessors (#11444)
This is in preparation for upstream deprecation of raw accessor formats.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 24a1a85..c1ea014 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -4082,7 +4082,7 @@
return setOp.emitError() << "parent func op not found in cache.";
}
bool move =
- vmAnalysis.value().get().isMove(setOp.value(), setOp.getOperation());
+ vmAnalysis.value().get().isMove(setOp.getValue(), setOp.getOperation());
StringRef callee =
move ? "iree_vm_list_set_ref_move" : "iree_vm_list_set_ref_retain";
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
index ee05bb1..1255449 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -16,10 +16,7 @@
def VM_Dialect : Dialect {
let name = "vm";
let cppNamespace = "::mlir::iree_compiler::IREE::VM";
- // TODO(benvanik): change to kEmitAccessorPrefix_Prefixed once the op encoder
- // tablegen goo supports it. Currently it requires accessors with the same
- // name as the ODS fields.
- let emitAccessorPrefix = kEmitAccessorPrefix_Both;
+ let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
let summary = [{
A dialect representing operations against an abstract virtual machine.
@@ -90,8 +87,13 @@
// return success();
// }
-class VM_EncEncodeExpr<code evalExpr> {
+class VM_EncEncodeExpr<code evalExpr, list<string> parameters = []> {
+ // A code snippet that can potentially contain placeholders (e.g., `{0}`).
code expr = evalExpr;
+ // If so, `params` are expected to provide the substitutions for these
+ // placeholders. TableGen backends can process the parameters before
+ // plugging them in though.
+ list<string> params = parameters;
}
class VM_EncConstI8<int value> : VM_EncEncodeExpr<
@@ -101,15 +103,15 @@
VM_OPC opcode = thisOpcode;
}
class VM_EncFuncAttr<string name> : VM_EncEncodeExpr<
- "e.encodeSymbolOrdinal(syms, " # name # "())">;
+ "e.encodeSymbolOrdinal(syms, {0}())", [name]>;
class VM_EncGlobalAttr<string name> : VM_EncEncodeExpr<
- "e.encodeSymbolOrdinal(syms, " # name # "())">;
+ "e.encodeSymbolOrdinal(syms, {0}())", [name]>;
class VM_EncRodataAttr<string name> : VM_EncEncodeExpr<
- "e.encodeSymbolOrdinal(syms, " # name # "())">;
+ "e.encodeSymbolOrdinal(syms, {0}())", [name]>;
class VM_EncType<string expr> : VM_EncEncodeExpr<
"e.encodeType(" # expr # ")">;
class VM_EncTypeOf<string name> : VM_EncEncodeExpr<
- "e.encodeType(" # name # "())">;
+ "e.encodeType({0}())", [name]>;
class VM_EncPrimitiveAttr<string name, int thisBitwidth> : VM_EncEncodeExpr<
"e.encodePrimitiveAttr(getOperation()->getAttrOfType<Attribute>(\"" # name # "\"))"> {
int bitwidth = thisBitwidth;
@@ -121,15 +123,15 @@
class VM_EncStrAttr<string name> : VM_EncEncodeExpr<
"e.encodeStrAttr(getOperation()->getAttrOfType<StringAttr>(\"" # name # "\"))">;
class VM_EncBranch<string blockName, string operandsName, int successorIndex> : VM_EncEncodeExpr<
- "e.encodeBranch(" # blockName # "(), " # operandsName # "(), " # successorIndex # ")">;
+ "e.encodeBranch({0}(), " # operandsName # "(), " # successorIndex # ")", [blockName]>;
class VM_EncOperand<string name, int ordinal> : VM_EncEncodeExpr<
- "e.encodeOperand(" # name # "(), " # ordinal # ")">;
+ "e.encodeOperand({0}(), " # ordinal # ")", [name]>;
class VM_EncVariadicOperands<string name> : VM_EncEncodeExpr<
- "e.encodeOperands(" # name # "())">;
+ "e.encodeOperands({0}())", [name]>;
class VM_EncResult<string name> : VM_EncEncodeExpr<
- "e.encodeResult(" # name # "())">;
+ "e.encodeResult({0}())", [name]>;
class VM_EncVariadicResults<string name> : VM_EncEncodeExpr<
- "e.encodeResults(" # name # "())">;
+ "e.encodeResults({0}())", [name]>;
def VM_SerializableOpInterface : OpInterface<"VMSerializableOp"> {
let description = [{
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 11e4b2f..2c3c950 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -2090,23 +2090,25 @@
// * xor high bits of lhs and rhs
auto zero = rewriter.createOrFold<ConstFOp>(loc, 0);
auto lhsPositive =
- rewriter.createOrFold<CmpGTEFOp>(loc, i32Type, op.lhs(), zero);
+ rewriter.createOrFold<CmpGTEFOp>(loc, i32Type, op.getLhs(), zero);
auto rhsPositive =
- rewriter.createOrFold<CmpGTEFOp>(loc, i32Type, op.rhs(), zero);
+ rewriter.createOrFold<CmpGTEFOp>(loc, i32Type, op.getRhs(), zero);
auto signsNotEqual = rewriter.createOrFold<IREE::VM::CmpNEI32Op>(
loc, i32Type, lhsPositive, rhsPositive);
// If signs differ, perform a direct comparison of `lhs == rhs`.
auto *directComparisonBlock = rewriter.createBlock(continuationBlock);
auto exactEqual =
- rewriter.createOrFold<CmpEQFOp>(loc, i32Type, op.lhs(), op.rhs());
+ rewriter.createOrFold<CmpEQFOp>(loc, i32Type, op.getLhs(), op.getRhs());
rewriter.createOrFold<IREE::VM::BranchOp>(loc, continuationBlock,
exactEqual);
// ...else, perform a full ULP-based comparison.
auto *ulpComparisonBlock = rewriter.createBlock(continuationBlock);
- auto lhsInt = rewriter.createOrFold<BitcastFToIOp>(loc, i32Type, op.lhs());
- auto rhsInt = rewriter.createOrFold<BitcastFToIOp>(loc, i32Type, op.rhs());
+ auto lhsInt =
+ rewriter.createOrFold<BitcastFToIOp>(loc, i32Type, op.getLhs());
+ auto rhsInt =
+ rewriter.createOrFold<BitcastFToIOp>(loc, i32Type, op.getRhs());
auto signedUlpsDiff =
rewriter.createOrFold<SubIOp>(loc, i32Type, lhsInt, rhsInt);
auto absUlpsDiff =
@@ -2133,7 +2135,7 @@
template <typename T>
static OpFoldResult foldCmpEQNearOp(T op, ArrayRef<Attribute> operands) {
- if (op.lhs() == op.rhs()) {
+ if (op.getLhs() == op.getRhs()) {
// x ~ x = true
return oneOfType(op.getType());
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 0acbedf..ca46f84 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -405,7 +405,7 @@
auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, global);
assert(symbolOp);
auto globalOp = dyn_cast<T>(symbolOp);
- if (globalOp.is_mutable()) {
+ if (globalOp.getIsMutable()) {
effects.emplace_back(MemoryEffects::Read::get());
}
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index 66411ad..d40b2c4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -1513,7 +1513,7 @@
let encoding = [
VM_EncOpcode<VM_OPC_ListAlloc>,
- VM_EncType<"result().getType().cast<IREE::VM::RefType>().getObjectType().cast<IREE::VM::ListType>().getElementType()">,
+ VM_EncType<"getResult().getType().cast<IREE::VM::RefType>().getObjectType().cast<IREE::VM::ListType>().getElementType()">,
VM_EncOperand<"initial_capacity", 0>,
VM_EncResult<"result">,
];
@@ -3584,8 +3584,8 @@
let encoding = [
VM_EncOpcode<VM_OPC_CondBranch>,
VM_EncOperand<"condition", 0>,
- VM_EncBranch<"getTrueDest", "getTrueOperands", 0>,
- VM_EncBranch<"getFalseDest", "getFalseOperands", 1>,
+ VM_EncBranch<"trueDest", "getTrueOperands", 0>,
+ VM_EncBranch<"falseDest", "getFalseOperands", 1>,
];
let builders = [
diff --git a/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp b/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp
index 9425b26..e16310c 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Tools/VMOpEncoderGen.cpp
@@ -5,10 +5,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
-#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
@@ -60,12 +60,25 @@
}
os << " if (";
- interleave(
- encodingExprs, os,
- [&](Record *encodingExpr) {
- os << formatv("failed({0})", encodingExpr->getValueAsString("expr"));
- },
- " ||\n ");
+ auto printOneCondition = [&](Record *encodingExpr) {
+ StringRef expr = encodingExpr->getValueAsString("expr");
+ std::vector<StringRef> params =
+ encodingExpr->getValueAsListOfStrings("params");
+ assert(params.size() <= 1);
+
+ // Note the following relies on the fact that only encoding expressions
+ // involving operands/results have one parameter. It's a bit inflexible,
+ // but it works for now and we can change when the extra flexibility is
+ // really needed.
+ std::string param;
+ if (params.size() == 1) {
+ param = "get" + llvm::convertToCamelFromSnakeCase(params.front(), true);
+ } else {
+ param = expr;
+ }
+ os << formatv("failed({0})", formatv(expr.data(), param));
+ };
+ interleave(encodingExprs, os, printOneCondition, " ||\n ");
os << ") {\n";
os << " return emitOpError() << \"failed to encode (internal)\";\n";
os << " }\n";