Begin implementation of the 'raw' function ABI.
* Adds signature support for it.
* Adds documentation.
* This is all very draft-y and open to further discussion.
PiperOrigin-RevId: 285480154
diff --git a/docs/function_abi.md b/docs/function_abi.md
index 1c0c8b5..0296a6a 100644
--- a/docs/function_abi.md
+++ b/docs/function_abi.md
@@ -24,10 +24,170 @@
## ABIs
+### Raw Function ABI
+
+All exported functions implement the raw function ABI, which defines the
+metadata and calling convention for marshalling inputs and results to their
+underlying implementations.
+
+*Attributes:*
+
+* `fv` = 1 (current version of the raw function ABI)
+* `f` = encoded raw function signature (see below)
+* `fbr` = result buffer allocation function name (optional)
+
+The reflection metadata documented here augments the underlying type system such
+that host language bindings can interop as needed. This additional metadata is
+needed in most dynamic cases because the compiled assets operate on fundamental
+types with most characteristics type erased away (think: `void*` level things vs
+high-level `ShapedBuffer` level things).
+
+#### Grammar
+
+The signature is implemented in terms of the SignatureBuilder, using tagged
+Integer and Spans.
+
+```text
+signature ::= 'I' length-prefixed(type-sequence)
+ 'R' length-prefixed(type-sequence)
+
+type-sequence ::= (arg-result-type)*
+arg-result-type ::= buffer-type | ref-object-type
+buffer-type ::= 'B' length-prefixed(scalar-type? dim*)
+scalar-type ::= 't' (
+ '0' # IEEE float32 (default if not specified)
+ | '1' # IEEE float16
+ | '2' # IEEE float64
+ | '3' # Google bfloat16
+ | '4' # Signed int8
+ | '5' # Signed int16
+ | '6' # Signed int32
+ | '7' # Signed int64
+ | '8' # Unsigned int8
+ | '9' # Unsigned int16
+ | '10' # Unsigned int32
+ | '11' # Unsigned int64
+ )
+dim :: = 'd' integer # -1 indicates a dynamic dim
+ref-object-type ::= 'O' length-prefixed() # Details TBD
+
+
+# Lexical primitives
+integer ::= -?[0-9]+
+length ::= [0-9]+
+# The `length` encodes the length in bytes of `production`, plus 1 for the '!'.
+length-prefixed(production) ::= length '!' production
+any-byte-sequence ::= <any byte sequence>
+```
+
+#### Interpretation and Rationale
+
+##### Memory layout
+
+The astute reader will note that the above metadata is insufficient to determine
+the memory layout of a buffer. The reason is that any more specific details than
+this (contiguity, strides, alignment, etc) can actually only be known once the
+actual compute devices have been enumerated and the resulting matrix of
+conversions is more dynamic than can be expressed in something as static as a
+function signature. The above formulation is an input to an additional runtime
+oracle which produces appropriate full buffer descriptions.
+
+While the exact implementation is host-language specific, consider the following
+more detailed set of declarations that may exist in such a binding layer:
+
+```c++
+// Inspired heavily by the Py_buffer type.
+// See: https://docs.python.org/3/c-api/buffer.html
+struct BufferDescription {
+ ScalarType element_type;
+ // For contiguous arrays, this is is the length of the underlying memory.
+ // For non-contiguous, this is the size of the buffer if it were copied
+ // to a contiguous representation.
+ size_t len;
+ // Number of dims and strides.
+ size_t ndim;
+ int* shape;
+ int* strides;
+};
+
+// Mirrors the 'buffer-type' production in the above grammar.
+struct SignatureBufferType;
+
+// Oracle which combines signature metadata with a user-provided, materialized
+// BufferDescription to derive a BufferDescription that is compatible for
+// invocation. Returns an updated buffer description if the original is
+// not compatible or fully specified.
+// This can be used in a couple of ways:
+// a) On function invocation to determine whether a provided buffer can be
+// used as-is or needs to be converted (copied).
+// b) To provide a factory function to the host language to create a
+// compatible buffer.
+optional<BufferDescription> BufferDescriptionOracle(
+ DeviceContext*, SignatureBufferType, BufferDescription)
+ throws UnsupportedBufferException;
+```
+
+The above scheme should allow host-language and device coordination with respect
+to buffer layout. For the moment, the responsibility to convert the buffer to a
+compatible memory layout is on the host-language binding. However, often it is
+the most efficient to schedule this for execution on a device. In the future, it
+is anticipated that there will be a built-in pathway for scheduling such a
+conversion (which would allow pipelinining and offload of buffer conversions).
+
+##### Deferred result allocation
+
+In general, exported functions accept pre-allocated results that should be
+mutated. For the simplest cases, such results can be `null` and retrieved upon
+completion of the function. This, however, puts severe limitations on the
+ability to pipeline. For fully specified signatures (no dynamic shapes), the
+`BufferDescriptionOracle` and the signature is sufficient to pre-allocate
+appropriate results, which allows chains of result-producing invocations to be
+pipelined.
+
+If, however, a `buffer-type` is not fully specified, the compiler may emit a
+special *result allocator* function, which will be referenced in the `fbr`
+attribute. Such a function would have a signature like this:
+
+```c++
+tuple<buffer> __allocate_results(tuple<int> dynamic_dims);
+```
+
+Such a function takes a tuple of all dynamic buffer dims in the function input
+signature and returns a tuple of allocated buffers for each dynamic result. Note
+that it may not be possible to fully allocate results in this fashion (i.e. if
+the result layout is data dependent), in which case a null buffer is returned
+for that slot (and the host library would need to await on the invocation to get
+the fully populated result).
+
+A similar mechanism will need to be created at some future point for
+under-specified results of other (non-buffer) types.
+
+##### Contiguity hinting
+
+Commonly in some kinds of dataflows, the compiler needs to be free to internally
+toggle buffer continuity (i.e. C/row-major, Fortran/col-major, etc). In many
+cases, such toggling does not naturally escape through the exported function
+boundaries, in which case, there is no ABI impact. However, it is anticipated
+that there is benefit to letting the toggle propagate through the exported ABI
+boundary, in which case, the `buffer-type` will likely be extended with a
+contiguity hint indicating the preference. When combined with the buffer
+description oracle and in-pipeline conversion features described above, this
+could yield a powerful mechanism for dynamically and efficiently managing such
+transitions.
+
+Such an enhancement would almost certainly necessitate a major version bump in
+the ABI and would be logical to implement once the advanced features above are
+functional.
+
### Structured Index Path ABI
-* `abi` = `sip`
-* Current `abiv` version = 1
+Functions may support the SIP ABI if their input and result tuples logically map
+onto "structures" (nested sequence/dicts).
+
+*Attributes:*
+
+* `sipv` = 1 (current SIP ABI version)
+* `sip` = encoded SIP signature (see below)
This ABI maps a raw, linear sequence of inputs and results onto an input and
result "structure" -- which in this context refers to a nested assembly of
diff --git a/iree/base/signature_mangle.cc b/iree/base/signature_mangle.cc
index 8ed9e2c..6aa041f 100644
--- a/iree/base/signature_mangle.cc
+++ b/iree/base/signature_mangle.cc
@@ -21,6 +21,30 @@
namespace iree {
// -----------------------------------------------------------------------------
+// AbiConstants
+// -----------------------------------------------------------------------------
+
+const std::array<size_t, 12> AbiConstants::kScalarTypeSize = {
+ 4, // kIeeeFloat32 = 0,
+ 2, // kIeeeFloat16 = 1,
+ 8, // kIeeeFloat64 = 2,
+ 2, // kGoogleBfloat16 = 3,
+ 1, // kSint8 = 4,
+ 2, // kSint16 = 5,
+ 4, // kSint32 = 6,
+ 8, // kSint64 = 7,
+ 1, // kUint8 = 8,
+ 2, // kUint16 = 9,
+ 4, // kUint32 = 10,
+ 8, // kUint64 = 11,
+};
+
+const std::array<const char*, 12> AbiConstants::kScalarTypeNames = {
+ "float32", "float16", "float64", "bfloat16", "sint8", "sint16",
+ "sint32", "sint64", "uint8", "uint16", "uint32", "uint64",
+};
+
+// -----------------------------------------------------------------------------
// SignatureBuilder and SignatureParser
// -----------------------------------------------------------------------------
@@ -114,6 +138,98 @@
}
// -----------------------------------------------------------------------------
+// RawSignatureMangler
+// -----------------------------------------------------------------------------
+
+SignatureBuilder RawSignatureMangler::ToFunctionSignature(
+ RawSignatureMangler& inputs, RawSignatureMangler& results) {
+ SignatureBuilder func_builder;
+ inputs.builder_.AppendTo(func_builder, 'I');
+ results.builder_.AppendTo(func_builder, 'R');
+ return func_builder;
+}
+
+void RawSignatureMangler::AddAnyReference() {
+ // A more constrained ref object would have a non empty span.
+ builder_.Span(absl::string_view(), 'O');
+}
+
+void RawSignatureMangler::AddShapedNDBuffer(
+ AbiConstants::ScalarType element_type, absl::Span<int> shape) {
+ SignatureBuilder item_builder;
+ // Fields:
+ // 't': scalar type code
+ // 'd': shape dimension
+ if (static_cast<unsigned>(element_type) != 0) {
+ item_builder.Integer(static_cast<unsigned>(element_type), 't');
+ }
+ for (int d : shape) {
+ item_builder.Integer(d, 'd');
+ }
+ item_builder.AppendTo(builder_, 'B');
+}
+
+// -----------------------------------------------------------------------------
+// RawSignatureParser
+// -----------------------------------------------------------------------------
+
+void RawSignatureParser::Description::ToString(std::string& s) const {
+ switch (type) {
+ case Type::kBuffer: {
+ const char* scalar_type_name = "!BADTYPE!";
+ unsigned scalar_type_u = static_cast<unsigned>(buffer.scalar_type);
+ if (scalar_type_u >= 0 &&
+ scalar_type_u <= AbiConstants::kScalarTypeNames.size()) {
+ scalar_type_name = AbiConstants::kScalarTypeNames[static_cast<unsigned>(
+ scalar_type_u)];
+ }
+ absl::StrAppend(&s, "Buffer<", scalar_type_name, "[");
+ for (size_t i = 0; i < dims.size(); ++i) {
+ if (i > 0) s.push_back('x');
+ if (dims[i] >= 0) {
+ absl::StrAppend(&s, dims[i]);
+ } else {
+ s.push_back('?');
+ }
+ }
+ absl::StrAppend(&s, "]>");
+ break;
+ }
+ case Type::kRefObject:
+ absl::StrAppend(&s, "RefObject<?>");
+ break;
+ default:
+ absl::StrAppend(&s, "!UNKNOWN!");
+ }
+}
+
+absl::optional<std::string> RawSignatureParser::FunctionSignatureToString(
+ absl::string_view signature) {
+ std::string s;
+
+ bool print_sep = false;
+ auto visitor = [&print_sep, &s](const Description& d) {
+ if (print_sep) {
+ s.append(", ");
+ }
+ d.ToString(s);
+ print_sep = true;
+ };
+ s.push_back('(');
+ VisitInputs(signature, visitor);
+ s.append(") -> (");
+ print_sep = false;
+ VisitResults(signature, visitor);
+ s.push_back(')');
+
+ if (!GetError()) {
+ return s;
+ } else {
+ return absl::nullopt;
+ }
+}
+
+// -----------------------------------------------------------------------------
// SipSignatureMangler
// -----------------------------------------------------------------------------
diff --git a/iree/base/signature_mangle.h b/iree/base/signature_mangle.h
index b3c5c11..a8088a9 100644
--- a/iree/base/signature_mangle.h
+++ b/iree/base/signature_mangle.h
@@ -15,6 +15,7 @@
#ifndef IREE_BASE_SIGNATURE_MANGLE_H_
#define IREE_BASE_SIGNATURE_MANGLE_H_
+#include <array>
#include <cassert>
#include <map>
#include <string>
@@ -27,6 +28,38 @@
// Name mangling/demangling for function and type signatures.
namespace iree {
+namespace AbiConstants {
+
+// Canonical integer mappings are maintained for core scalar type codes
+// since they change infrequently and are used everywhere.
+// Generally, favor adding a custom type vs extending this arbitrarily.
+enum class ScalarType : unsigned {
+ kIeeeFloat32 = 0,
+ kIeeeFloat16 = 1,
+ kIeeeFloat64 = 2,
+ kGoogleBfloat16 = 3,
+ kSint8 = 4,
+ kSint16 = 5,
+ kSint32 = 6,
+ kSint64 = 7,
+ kUint8 = 8,
+ kUint16 = 9,
+ kUint32 = 10,
+ kUint64 = 11,
+ kMaxScalarType = 11,
+};
+
+// Array that maps ScalarType codes to the size in bytes.
+extern const std::array<size_t,
+ static_cast<unsigned>(ScalarType::kMaxScalarType) + 1>
+ kScalarTypeSize;
+
+extern const std::array<const char*,
+ static_cast<unsigned>(ScalarType::kMaxScalarType) + 1>
+ kScalarTypeNames;
+
+} // namespace AbiConstants
+
// Builds up a signature string from components.
// The signature syntax is a sequence of Integer or Span fields:
// integer_tag ::= '_' | [a-z]
@@ -117,6 +150,167 @@
char next_tag_;
};
+// -----------------------------------------------------------------------------
+// Raw signatures
+// -----------------------------------------------------------------------------
+
+// Mangles raw function signatures.
+// See function_abi.md.
+class RawSignatureMangler {
+ public:
+ // Combines mangled input and result signatures into a function signature.
+ static SignatureBuilder ToFunctionSignature(RawSignatureMangler& inputs,
+ RawSignatureMangler& results);
+
+ // Adds an unconstrained reference-type object.
+ void AddAnyReference();
+
+ // Adds a shaped nd buffer operand with the given element type and shape.
+ // Unknown dims should be -1.
+ // This is the common case for external interfacing and requires a fully
+ // ranked shape.
+ void AddShapedNDBuffer(AbiConstants::ScalarType element_type,
+ absl::Span<int> shape);
+
+ const SignatureBuilder& builder() const { return builder_; }
+
+ private:
+ SignatureBuilder builder_;
+};
+
+// Parses function signatures generated by RawSignatureMangler.
+class RawSignatureParser {
+ public:
+ enum class Type {
+ kBuffer = 0,
+ kRefObject = 1,
+ };
+
+ struct Description {
+ Type type;
+
+ // For shaped types, this is the corresponding dims.
+ absl::InlinedVector<int, 7> dims;
+
+ union {
+ // Further details for Type == kBuffer.
+ struct {
+ AbiConstants::ScalarType scalar_type;
+ } buffer;
+ };
+
+ // Human readable description.
+ void ToString(std::string& s) const;
+ };
+
+ using Visitor = std::function<void(const Description&)>;
+
+ void VisitInputs(absl::string_view signature, Visitor visitor) {
+ SignatureParser sp(signature);
+ if (!sp.SeekTag('I')) {
+ SetError("Inputs span not found");
+ return;
+ }
+ auto nested = sp.nested();
+ return Visit(visitor, nested);
+ }
+
+ void VisitResults(absl::string_view signature, Visitor visitor) {
+ SignatureParser sp(signature);
+ if (!sp.SeekTag('R')) {
+ SetError("Results span not found");
+ return;
+ }
+ auto nested = sp.nested();
+ return Visit(visitor, nested);
+ }
+
+ // Produces a human readable function signature from the encoded form.
+ // Does not return a value on error.
+ absl::optional<std::string> FunctionSignatureToString(
+ absl::string_view signature);
+
+ // If the parser is in an error state, accesses the error.
+ const absl::optional<std::string>& GetError() { return error_; }
+ void SetError(std::string error) {
+ if (!error_) error_ = std::move(error);
+ }
+
+ private:
+ void Visit(Visitor& v, SignatureParser& item_parser) {
+ Description d;
+ while (!item_parser.end_or_error() && !error_) {
+ // Reset shared fields.
+ d.dims.clear();
+
+ switch (item_parser.tag()) {
+ case 'B':
+ if (!FillBuffer(d, SignatureParser(item_parser.nested()))) {
+ return;
+ }
+ break;
+ case 'O':
+ if (!FillRefObject(d, SignatureParser(item_parser.nested()))) {
+ return;
+ }
+ break;
+ default:
+ SetError("Unrecognized raw tag");
+ return;
+ }
+
+ v(d);
+ item_parser.Next();
+ }
+ }
+
+ bool FillBuffer(Description& d, SignatureParser p) {
+ d.type = Type::kBuffer;
+ d.buffer.scalar_type = AbiConstants::ScalarType::kIeeeFloat32; // Default
+ while (!p.end_or_error()) {
+ switch (p.tag()) {
+ case 't':
+ if (p.ival() < 0 ||
+ p.ival() >
+ static_cast<int>(AbiConstants::ScalarType::kMaxScalarType)) {
+ SetError("Illegal ScalarType code");
+ return false;
+ }
+ d.buffer.scalar_type =
+ static_cast<AbiConstants::ScalarType>(p.ival());
+ break;
+ case 'd':
+ d.dims.push_back(p.ival());
+ break;
+ default:
+ SetError("Unrecognized buffer field tag");
+ return false;
+ }
+ p.Next();
+ }
+ return true;
+ }
+
+ bool FillRefObject(Description& d, SignatureParser p) {
+ d.type = Type::kRefObject;
+ while (!p.end_or_error()) {
+ switch (p.tag()) {
+ default:
+ SetError("Unrecognized ref object field tag");
+ return false;
+ }
+ p.Next();
+ }
+ return true;
+ }
+
+ absl::optional<std::string> error_;
+};
+
+// -----------------------------------------------------------------------------
+// Sip signatures
+// -----------------------------------------------------------------------------
+
// Mangles function signatures according to the Sip (Structured Index Path) V1
// scheme.
//
diff --git a/iree/base/signature_mangle_test.cc b/iree/base/signature_mangle_test.cc
index d132a3a..1524695 100644
--- a/iree/base/signature_mangle_test.cc
+++ b/iree/base/signature_mangle_test.cc
@@ -209,6 +209,68 @@
EXPECT_EQ(SignatureParser::Type::kEnd, sp1.Next());
}
+// -----------------------------------------------------------------------------
+// Raw signatures
+// -----------------------------------------------------------------------------
+
+TEST(RawSignatureManglerTest, DefaultBuffer) {
+ RawSignatureMangler sm;
+ sm.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat32, {});
+ EXPECT_EQ("B1!", sm.builder().encoded());
+}
+
+TEST(RawSignatureManglerTest, FullBuffer) {
+ RawSignatureMangler sm;
+ std::vector<int> dims = {-1, 128, 64};
+ sm.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat64,
+ absl::MakeSpan(dims));
+ EXPECT_EQ("B13!t2d-1d128d64", sm.builder().encoded());
+}
+
+TEST(RawSignatureManglerTest, AnyRef) {
+ RawSignatureMangler sm;
+ sm.AddAnyReference();
+ EXPECT_EQ("O1!", sm.builder().encoded());
+}
+
+TEST(RawSignatureParserTest, EmptySignature) {
+ RawSignatureMangler inputs;
+ RawSignatureMangler results;
+
+ auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results);
+ RawSignatureParser p;
+ auto s = p.FunctionSignatureToString(sig.encoded());
+ ASSERT_TRUE(s) << *p.GetError();
+ EXPECT_EQ("() -> ()", *s);
+}
+
+TEST(RawSignatureParserTest, AllTypes) {
+ RawSignatureMangler inputs;
+ inputs.AddAnyReference();
+ std::vector<int> dims = {-1, 128, 64};
+ inputs.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat32,
+ absl::MakeSpan(dims));
+ RawSignatureMangler results;
+ std::vector<int> dims2 = {32, -1, 64};
+ results.AddShapedNDBuffer(AbiConstants::ScalarType::kUint64,
+ absl::MakeSpan(dims2));
+
+ auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results);
+ EXPECT_EQ("I18!O1!B11!d-1d128d64R17!B13!t11d32d-1d64", sig.encoded());
+
+ RawSignatureParser p;
+ auto s = p.FunctionSignatureToString(sig.encoded());
+ ASSERT_TRUE(s) << *p.GetError();
+ EXPECT_EQ(
+ "(RefObject<?>, Buffer<float32[?x128x64]>) -> "
+ "(Buffer<uint64[32x?x64]>)",
+ *s);
+}
+
+// -----------------------------------------------------------------------------
+// Sip signatures
+// -----------------------------------------------------------------------------
+
TEST_F(SipSignatureTest, NoInputsResults) {
const char kExpectedInputs[] = R"()";
const char kExpectedResults[] = R"()";