Merge pull request #4791 from google/benvanik-llvm-abi-new
This switches the runtime and generated code to use the new executable
library signature while still routing all outputs through the flatbuffers.
Future changes will start generating the library metadata structures.
Progress on #3580.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
index d54aa67..e1030c4 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
+++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
@@ -56,118 +56,164 @@
// versions in the same compiled output.
class HALDispatchABI {
public:
- static constexpr int kIndexPackedBuffer = 0;
- static constexpr int kIndexPushConstant = 1;
- static constexpr int kIndexWorkGroupId = 2;
- static constexpr int kIndexWorkGroupCount = 3;
- static constexpr int kIndexWorkGroupSize = 4;
+ // Builds the LLVM type mirroring iree_hal_vec3_t: an array of three i32s.
+ static Type getVec3Type(MLIRContext *context) {
+ return LLVM::LLVMArrayType::get(IntegerType::get(context, 32), 3);
+ }
+
+ // Matches the field order in iree_hal_executable_dispatch_state_v0_t; each value is the field's index in the struct body.
+ enum class Field {
+ workgroup_count = 0,  // iree_hal_vec3_t
+ workgroup_size = 1,  // iree_hal_vec3_t
+ push_constant_count = 2,  // size_t
+ push_constants = 3,  // const uint32_t *
+ binding_count = 4,  // size_t
+ binding_ptrs = 5,  // void *const *
+ binding_lengths = 6,  // const size_t *
+ };
+
+ // Returns a Type representing iree_hal_executable_dispatch_state_v0_t.
+ //
+ // The struct is looked up as an identified LLVM struct by name. If it already
+ // has a body it is returned as-is; otherwise the body is populated here.
+ // The order of the fields pushed below must stay in sync with the Field enum,
+ // which is used to index into the struct body.
+ static LLVM::LLVMStructType getDispatchStateType(
+ MLIRContext *context, LLVMTypeConverter *typeConverter) {
+ auto structType = LLVM::LLVMStructType::getIdentified(
+ context, "iree_hal_executable_dispatch_state_v0_t");
+ if (structType.isInitialized()) return structType;
+
+ // size_t is modeled with the converted index type.
+ auto indexType = typeConverter->convertType(IndexType::get(context));
+ auto int8Type = IntegerType::get(context, 8);
+ auto uint32Type = IntegerType::get(context, 32);
+ auto vec3Type = getVec3Type(context);
+ // Inline capacity covers all 7 fields so the vector never spills to the heap.
+ SmallVector<Type, 7> fieldTypes;
+
+ // iree_hal_vec3_t workgroup_count;
+ // iree_hal_vec3_t workgroup_size;
+ fieldTypes.push_back(vec3Type);
+ fieldTypes.push_back(vec3Type);
+
+ // size_t push_constant_count;
+ // const uint32_t * push_constants;
+ fieldTypes.push_back(indexType);
+ fieldTypes.push_back(LLVM::LLVMPointerType::get(uint32Type));
+
+ // size_t binding_count;
+ // void *const * binding_ptrs;
+ // const size_t * binding_lengths;
+ fieldTypes.push_back(indexType);
+ fieldTypes.push_back(
+ LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(int8Type)));
+ fieldTypes.push_back(LLVM::LLVMPointerType::get(indexType));
+
+ LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false);
+ assert(succeeded(bodySet) &&
+ "could not set the body of an identified struct");
+ // Keeps bodySet "used" when asserts are compiled out.
+ (void)bodySet;
+
+ return structType;
+ }
// Returns the types of the LLVM function inputs for the ABI.
// This matches the signature of `iree_hal_executable_dispatch_v0_t` in
// `iree/hal/local/executable_library.h`: argument 0 is a pointer to the
// dispatch state struct and argument 1 is a pointer to the workgroup ID vec3.
- static SmallVector<Type, 5> getInputTypes(MLIRContext *context) {
- // func foo(%packed_buffer_args: !llvm.ptr<!llvm.ptr<i8>>,
- // %push_constant: !llvm.ptr<i32>,
- // workgroup_id[3]: !llvm.ptr<!llvm.array<i32, 3>>,
- // workgroup_count[3]: !llvm.ptr<!llvm.array<i32, 3>>,
- // workgroup_size[3]: !llvm.ptr<!llvm.array<i32, 3>>)
- auto indexTy = IntegerType::get(context, 32);
+ static SmallVector<Type, 5> getInputTypes(MLIRContext *context,
+ LLVMTypeConverter *typeConverter) {
return SmallVector<Type, 5>{
- // %packed_buffer_args: !llvm.ptr<!llvm.ptr<i8>>
+ // const iree_hal_executable_dispatch_state_v0_t* IREE_RESTRICT
+ // dispatch_state
LLVM::LLVMPointerType::get(
- LLVM::LLVMPointerType::get(IntegerType::get(context, 8))),
- // %push_constant: !llvm.ptr<i32>
- LLVM::LLVMPointerType::get(indexTy),
- // %workgroup_id[3]: !llvm.ptr<!llvm.array<i32, 3>>
- LLVM::LLVMPointerType::get(LLVM::LLVMArrayType::get(indexTy, 3)),
- // %workgroup_count[3]: !llvm.ptr<!llvm.array<i32, 3>>
- LLVM::LLVMPointerType::get(LLVM::LLVMArrayType::get(indexTy, 3)),
- // %workgroup_size[3]: !llvm.ptr<!llvm.array<i32, 3>>
- LLVM::LLVMPointerType::get(LLVM::LLVMArrayType::get(indexTy, 3))};
+ getDispatchStateType(context, typeConverter)),
+ // const iree_hal_vec3_t* IREE_RESTRICT workgroup_id
+ LLVM::LLVMPointerType::get(getVec3Type(context)),
+ };
}
explicit HALDispatchABI(LLVM::LLVMFuncOp &funcOp,
LLVMTypeConverter *typeConverter)
- : funcOp(funcOp), typeConverter(typeConverter) {}
+ : funcOp(funcOp),
+ typeConverter(typeConverter),
+ dispatchStateType(  // resolved once here; loadFieldValue reads field types from it
+ getDispatchStateType(funcOp.getContext(), typeConverter)) {}
// Reads component |dim| (XYZ) of the workgroup ID and casts it to
// |resultType|. The ID is the vec3 pointed to by function argument 1.
Value loadWorkgroupID(Location loc, int32_t dim, Type resultType,
OpBuilder &builder) {
- auto xyzArrayPtr = funcOp.getArgument(kIndexWorkGroupId);
- auto xyzArrayValue = builder.createOrFold<LLVM::LoadOp>(loc, xyzArrayPtr);
+ auto idVec3Value =
+ builder.createOrFold<LLVM::LoadOp>(loc, funcOp.getArgument(1));
+ auto i32Type = builder.getIntegerType(32);
auto dimValue = builder.createOrFold<LLVM::ExtractValueOp>(
- loc, builder.getIntegerType(32), xyzArrayValue,
- builder.getI32ArrayAttr({dim}));
+ loc, i32Type, idVec3Value, builder.getI64ArrayAttr({dim}));
return castValueToType(loc, dimValue, resultType, builder);
}
// Reads component |dim| (XYZ) of the workgroup count held in the dispatch
// state and casts it to |resultType|.
Value loadWorkgroupCount(Location loc, int32_t dim, Type resultType,
OpBuilder &builder) {
- auto xyzArrayPtr = funcOp.getArgument(kIndexWorkGroupCount);
- auto xyzArrayValue = builder.createOrFold<LLVM::LoadOp>(loc, xyzArrayPtr);
+ auto countVec3Value = loadFieldValue(loc, Field::workgroup_count, builder);
+ auto i32Type = builder.getIntegerType(32);
auto dimValue = builder.createOrFold<LLVM::ExtractValueOp>(
- loc, builder.getIntegerType(32), xyzArrayValue,
- builder.getI32ArrayAttr({dim}));
+ loc, i32Type, countVec3Value, builder.getI64ArrayAttr(dim));
return castValueToType(loc, dimValue, resultType, builder);
}
// Reads component |dim| (XYZ) of the workgroup size held in the dispatch
// state and casts it to |resultType|.
Value loadWorkgroupSize(Location loc, int32_t dim, Type resultType,
OpBuilder &builder) {
- auto xyzArrayPtr = funcOp.getArgument(kIndexWorkGroupSize);
- auto xyzArrayValue = builder.createOrFold<LLVM::LoadOp>(loc, xyzArrayPtr);
+ auto sizeVec3Value = loadFieldValue(loc, Field::workgroup_size, builder);
+ auto i32Type = builder.getIntegerType(32);
auto dimValue = builder.createOrFold<LLVM::ExtractValueOp>(
- loc, builder.getIntegerType(32), xyzArrayValue,
- builder.getI32ArrayAttr({dim}));
+ loc, i32Type, sizeVec3Value, builder.getI64ArrayAttr(dim));
return castValueToType(loc, dimValue, resultType, builder);
}
// Reads push_constant_count from the dispatch state and returns it as the
// index-converted type.
Value loadPushConstantCount(Location loc, OpBuilder &builder) {
- // TODO(#3580): switch to the executable_library ABI.
- assert(false && "not yet implemented");
- return {};
+ auto countValue = loadFieldValue(loc, Field::push_constant_count, builder);
+ auto indexType = typeConverter->convertType(builder.getIndexType());
+ return castValueToType(loc, countValue, indexType, builder);
}
// Reads the push constant at element |offset| of the dispatch state's
// push_constants array and casts it to |resultType|.
Value loadPushConstant(Location loc, int64_t offset, Type resultType,
OpBuilder &builder) {
- Value offsetValue = builder.create<LLVM::DialectCastOp>(
- loc, typeConverter->convertType(builder.getIndexType()),
- builder.create<ConstantIndexOp>(loc, offset));
- Value pushConstantPtrValue = builder.create<LLVM::GEPOp>(
- loc, funcOp.getArgument(kIndexPushConstant).getType(),
- funcOp.getArgument(kIndexPushConstant), offsetValue);
- Value pushConstantValue =
- builder.create<LLVM::LoadOp>(loc, pushConstantPtrValue);
- return castValueToType(loc, pushConstantValue, resultType, builder);
+ auto arrayPtrValue = loadFieldValue(loc, Field::push_constants, builder);
+ auto indexValue = getIndexValue(loc, offset, builder);
+ Value elementPtrValue = builder.create<LLVM::GEPOp>(
+ loc, arrayPtrValue.getType(), arrayPtrValue, indexValue);
+ Value loadedValue = builder.create<LLVM::LoadOp>(loc, elementPtrValue);
+ return castValueToType(loc, loadedValue, resultType, builder);
}
// Reads binding_count from the dispatch state and returns it as the
// index-converted type.
Value loadBindingCount(Location loc, OpBuilder &builder) {
- // TODO(#3580): switch to the executable_library ABI.
- assert(false && "not yet implemented");
- return {};
+ auto countValue = loadFieldValue(loc, Field::binding_count, builder);
+ auto indexType = typeConverter->convertType(builder.getIndexType());
+ return castValueToType(loc, countValue, indexType, builder);
}
// Loads the base pointer of the binding |ordinal| as an `i8*`.
// Equivalent to:
// int8_t* base_ptr = state->binding_ptrs[ordinal];
// The element is loaded here, so callers must not load through the result
// again to obtain the buffer base pointer.
Value loadBindingPtr(Location loc, int64_t ordinal, OpBuilder &builder) {
- Value ordinalValue = builder.createOrFold<LLVM::DialectCastOp>(
- loc, typeConverter->convertType(builder.getIndexType()),
- builder.create<ConstantIndexOp>(loc, ordinal));
- return builder.createOrFold<LLVM::GEPOp>(
- loc, funcOp.getArgument(kIndexPackedBuffer).getType(),
- funcOp.getArgument(kIndexPackedBuffer), ordinalValue);
+ auto ptrsPtrValue = loadFieldValue(loc, Field::binding_ptrs, builder);
+ auto ordinalValue = getIndexValue(loc, ordinal, builder);
+ auto elementPtrValue = builder.createOrFold<LLVM::GEPOp>(
+ loc, ptrsPtrValue.getType(), ptrsPtrValue, ordinalValue);
+ return builder.createOrFold<LLVM::LoadOp>(loc, elementPtrValue);
}
// Returns the byte length of binding |ordinal|, read from the dispatch
// state's binding_lengths array, as an index-converted type.
Value loadBindingLength(Location loc, int64_t ordinal, OpBuilder &builder) {
- // TODO(#3580): switch to the executable_library ABI.
- assert(false && "not yet implemented");
- return {};
+ auto lengthArrayPtr = loadFieldValue(loc, Field::binding_lengths, builder);
+ auto indexValue = getIndexValue(loc, ordinal, builder);
+ auto lengthPtr = builder.createOrFold<LLVM::GEPOp>(
+ loc, lengthArrayPtr.getType(), lengthArrayPtr, indexValue);
+ return builder.createOrFold<LLVM::LoadOp>(loc, lengthPtr);
}
// Loads a binding as a constructed MemRefDescriptor.
@@ -176,8 +222,7 @@
Value baseOffsetValue, MemRefType memRefType,
OpBuilder &builder) {
// Load the base buffer pointer in the appropriate type (f32*, etc).
- Value opaqueBasePtrValue = loadBindingPtr(loc, ordinal, builder);
- Value basePtrValue = builder.create<LLVM::LoadOp>(loc, opaqueBasePtrValue);
+ Value basePtrValue = loadBindingPtr(loc, ordinal, builder);
// Adjust by baseOffset (if needed).
if (baseOffsetValue) {
@@ -211,6 +256,20 @@
}
private:
+ // Loads the dispatch state struct through function argument 0 and extracts
+ // |field| from it; the result has that field's type from the struct body.
+ Value loadFieldValue(Location loc, Field field, OpBuilder &builder) {
+ const int fieldIndex = static_cast<int>(field);
+ auto stateValue =
+ builder.createOrFold<LLVM::LoadOp>(loc, funcOp.getArgument(0));
+ return builder.createOrFold<LLVM::ExtractValueOp>(
+ loc, dispatchStateType.getBody()[fieldIndex], stateValue,
+ builder.getI64ArrayAttr(fieldIndex));
+ }
+
+ // Materializes |value| as a constant index and casts it to the
+ // index-converted LLVM type.
+ Value getIndexValue(Location loc, int64_t value, OpBuilder &builder) {
+ Value constantValue = builder.createOrFold<ConstantIndexOp>(loc, value);
+ auto indexType = typeConverter->convertType(builder.getIndexType());
+ return builder.createOrFold<LLVM::DialectCastOp>(loc, indexType,
+ constantValue);
+ }
+
Value castValueToType(Location loc, Value value, Type resultType,
OpBuilder &builder) {
// NOTE: we should handle more cases here (and proper sign extension).
@@ -220,6 +279,7 @@
LLVM::LLVMFuncOp funcOp;
LLVMTypeConverter *typeConverter;
+ LLVM::LLVMStructType dispatchStateType;
};
/// Converts Standard MLIR FuncOps to LLVMFuncOps matching the IREE HAL ABI.
@@ -268,7 +328,8 @@
// Convert the function signature to take the HAL ABI LLVM pointers.
TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
MLIRContext *context = rewriter.getContext();
- auto abiInputTypes = HALDispatchABI::getInputTypes(context);
+ auto abiInputTypes =
+ HALDispatchABI::getInputTypes(context, getTypeConverter());
signatureConverter.addInputs(abiInputTypes);
// Copy all attributes onto the LLVM function except the ones handled by
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
deleted file mode 100644
index beb1d63..0000000
--- a/iree/compiler/Conversion/LinalgToLLVM/test/convert_to_llvm.mlir
+++ /dev/null
@@ -1,158 +0,0 @@
-// RUN: iree-opt -iree-codegen-convert-to-llvm -cse -split-input-file %s | IreeFileCheck %s
-
-// CHECK_LABEL: @convert_dynamic_shape
-func @convert_dynamic_shape() {
- %c0 = constant 0 : index
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
- %1 = hal.interface.load.constant offset = 0 : index
- %2 = hal.interface.load.constant offset = 1 : index
- %3 = shapex.make_ranked_shape %1, %2 : (index, index) -> !shapex.ranked_shape<[?,?]>
- %6 = shapex.tie_shape %0, %3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
- %7 = load %6[%c0, %c0] : memref<?x?xf32>
- %8 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
- %9 = shapex.tie_shape %8, %3 : memref<?x?xf32>, !shapex.ranked_shape<[?,?]>
- store %7, %8[%c0, %c0] : memref<?x?xf32>
- return
-}
-hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
-}
-// CHECK: llvm.func @convert_dynamic_shape(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>, %[[WORKGROUP_ID:.+]]: !llvm.ptr<array<3 x i32>>, %[[WORKGROUP_COUNT:.+]]: !llvm.ptr<array<3 x i32>>, %[[WORKGROUP_SIZE:.+]]: !llvm.ptr<array<3 x i32>>) {
-// CHECK: %[[CONST0:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[MEMREF0_PTR_PTR:.+]] = llvm.getelementptr %[[ARG0]][%[[CONST0]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF0_PTR:.+]] = llvm.load %[[MEMREF0_PTR_PTR]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF0_BASE_PTR:.+]] = llvm.bitcast %[[MEMREF0_PTR]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-// CHECK: %[[MEMREF_DESC:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_1:.+]] = llvm.insertvalue %[[MEMREF0_BASE_PTR]], %[[MEMREF_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_2:.+]] = llvm.insertvalue %[[MEMREF0_BASE_PTR]], %[[MEMREF0_1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[DIM0_PTR:.+]] = llvm.getelementptr %[[ARG1]][%[[CONST0]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK: %[[DIM0:.+]] = llvm.load %[[DIM0_PTR]] : !llvm.ptr<i32>
-// CHECK: %[[DIM0_i64:.+]] = llvm.zext %[[DIM0]] : i32 to i64
-// CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[DIM1_PTR:.+]] = llvm.getelementptr %[[ARG1]][%[[CONST1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK: %[[DIM1:.+]] = llvm.load %[[DIM1_PTR]] : !llvm.ptr<i32>
-// CHECK: %[[DIM1_i64:.+]] = llvm.zext %[[DIM1]] : i32 to i64
-// CHECK: %[[MEMREF0_3:.+]] = llvm.insertvalue %[[DIM0_i64]], %[[MEMREF0_2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_4:.+]] = llvm.insertvalue %[[DIM1_i64]], %[[MEMREF0_3]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_5:.+]] = llvm.insertvalue %[[CONST1]], %[[MEMREF0_4]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_STRIDE_1:.+]] = llvm.extractvalue %[[MEMREF0_5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_SIZE_1:.+]] = llvm.extractvalue %[[MEMREF0_5]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_STRIDE_0:.+]] = llvm.mul %[[MEMREF0_STRIDE_1]], %[[MEMREF0_SIZE_1]] : i64
-// CHECK: %[[MEMREF0_6:.+]] = llvm.insertvalue %[[MEMREF0_STRIDE_0]], %[[MEMREF0_5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_BASE_PTR_1:.+]] = llvm.extractvalue %[[MEMREF0_6]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_STRIDE0_0:.+]] = llvm.extractvalue %[[MEMREF0_6]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_00_BASE:.+]] = llvm.mul %[[CONST0]], %[[MEMREF0_STRIDE0_0]] : i64
-// CHECK: %[[MEMREF0_00_OFFSET:.+]] = llvm.add %[[MEMREF0_00_BASE]], %[[CONST0]] : i64
-// CHECK: %[[MEMREF0_00_PTR:.+]] = llvm.getelementptr %[[MEMREF0_BASE_PTR_1]][%[[MEMREF0_00_OFFSET]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[MEMREF0_00:.+]] = llvm.load %[[MEMREF0_00_PTR]] : !llvm.ptr<f32>
-// CHECK: %[[MEMREF1_PTR_PTR:.+]] = llvm.getelementptr %[[ARG0]][%[[CONST1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF1_PTR:.+]] = llvm.load %[[MEMREF1_PTR_PTR]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF1_BASE_PTR:.+]] = llvm.bitcast %[[MEMREF1_PTR]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-// CHECK: %[[MEMREF1_0:.+]] = llvm.insertvalue %[[MEMREF1_BASE_PTR]], %[[MEMREF_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_1:.+]] = llvm.insertvalue %[[MEMREF1_BASE_PTR]], %[[MEMREF1_0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_BASE:.+]] = llvm.extractvalue %[[MEMREF1_1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_STRIDE_0:.+]] = llvm.extractvalue %[[MEMREF1_1]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_00_OFFSET:.+]] = llvm.mul %[[CONST0]], %[[MEMREF1_STRIDE_0]] : i64
-// CHECK: %[[MEMREF1_00_ADDRS:.+]] = llvm.add %[[MEMREF1_00_OFFSET]], %[[CONST0]] : i64
-// CHECK: %[[MEMREF1_00_PTR:.+]] = llvm.getelementptr %[[MEMREF1_BASE]][%[[MEMREF1_00_ADDRS]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: llvm.store %[[MEMREF0_00]], %[[MEMREF1_00_PTR]] : !llvm.ptr<f32>
-// CHECK: llvm.return
-
-// -----
-
-// CHECK_LABEL: @convert_dynamic_shape2
-func @convert_dynamic_shape2() {
- %c0 = constant 0 : index
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io2::@arg0} : memref<2x?xf32>
- %1 = hal.interface.load.constant offset = 0 : index
- %2 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[2,?]>
- %3 = shapex.tie_shape %0, %2 : memref<2x?xf32>, !shapex.ranked_shape<[2,?]>
- %4 = load %3[%c0, %c0] : memref<2x?xf32>
- %5 = iree.placeholder for "interface buffer" {binding = @legacy_io2::@ret0} : memref<2x?xf32>
- %9 = shapex.tie_shape %5, %2 : memref<2x?xf32>, !shapex.ranked_shape<[2,?]>
- store %4, %9[%c0, %c0] : memref<2x?xf32>
- return
-}
-hal.interface @legacy_io2 attributes {push_constants = 1 : i32, sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
-}
-// CHECK: llvm.func @convert_dynamic_shape2(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>, %[[WORKGROUP_ID:.+]]: !llvm.ptr<array<3 x i32>>, %[[WORKGROUP_COUNT:.+]]: !llvm.ptr<array<3 x i32>>, %[[WORKGROUP_SIZE:.+]]: !llvm.ptr<array<3 x i32>>) {
-// CHECK: %[[CONST0:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[MEMREF0_PTR_PTR:.+]] = llvm.getelementptr %[[ARG0]][%[[CONST0]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF0_PTR:.+]] = llvm.load %[[MEMREF0_PTR_PTR]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF0_BASE_PTR:.+]] = llvm.bitcast %[[MEMREF0_PTR]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-// CHECK: %[[MEMREF_DESC:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_1:.+]] = llvm.insertvalue %[[MEMREF0_BASE_PTR]], %[[MEMREF_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_2:.+]] = llvm.insertvalue %[[MEMREF0_BASE_PTR]], %[[MEMREF0_1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[DIM1_PTR:.+]] = llvm.getelementptr %[[ARG1]][%[[CONST0]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK: %[[DIM1:.+]] = llvm.load %[[DIM1_PTR]] : !llvm.ptr<i32>
-// CHECK: %[[DIM1_i64:.+]] = llvm.zext %[[DIM1]] : i32 to i64
-// CHECK: %[[CONST2:.+]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: %[[MEMREF0_3:.+]] = llvm.insertvalue %[[CONST2]], %[[MEMREF0_2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_4:.+]] = llvm.insertvalue %[[DIM1_i64]], %[[MEMREF0_3]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[MEMREF0_5:.+]] = llvm.insertvalue %[[CONST1]], %[[MEMREF0_4]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_STRIDE_1:.+]] = llvm.extractvalue %[[MEMREF0_5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_SIZE_1:.+]] = llvm.extractvalue %[[MEMREF0_5]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_STRIDE_0:.+]] = llvm.mul %[[MEMREF0_STRIDE_1]], %[[MEMREF0_SIZE_1]] : i64
-// CHECK: %[[MEMREF0_6:.+]] = llvm.insertvalue %[[MEMREF0_STRIDE_0]], %[[MEMREF0_5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_BASE_PTR_1:.+]] = llvm.extractvalue %[[MEMREF0_6]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_STRIDE_0_0:.+]] = llvm.extractvalue %[[MEMREF0_6]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF0_00_ADDRS:.+]] = llvm.mul %[[CONST0]], %[[MEMREF0_STRIDE_0_0]] : i64
-// CHECK: %[[MEMREF0_00_INDEX:.+]] = llvm.add %[[MEMREF0_00_ADDRS]], %[[CONST0]] : i64
-// CHECK: %[[MEMREF0_00_PTR:.+]] = llvm.getelementptr %19[%[[MEMREF0_00_INDEX]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[MEMREF0_00:.+]] = llvm.load %[[MEMREF0_00_PTR]] : !llvm.ptr<f32>
-// CHECK: %[[MEMREF1_PTR_PTR:.+]] = llvm.getelementptr %[[ARG0]][%[[CONST1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF1_PTR:.+]] = llvm.load %[[MEMREF1_PTR_PTR]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[MEMREF1_BASE_PTR:.+]] = llvm.bitcast %[[MEMREF1_PTR]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-// CHECK: %[[MEMREF1_1:.+]] = llvm.insertvalue %[[MEMREF1_BASE_PTR]], %[[MEMREF_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_2:.+]] = llvm.insertvalue %[[MEMREF1_BASE_PTR]], %[[MEMREF1_1]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_3:.+]] = llvm.insertvalue %[[CONST2]], %[[MEMREF1_2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_4:.+]] = llvm.insertvalue %[[DIM1_i64]], %[[MEMREF1_3]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_5:.+]] = llvm.insertvalue %[[CONST1]], %[[MEMREF1_4]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_STRIDE1:.+]] = llvm.extractvalue %[[MEMREF1_5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_SIZE1:.+]] = llvm.extractvalue %[[MEMREF1_5]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_STRIDE0:.+]] = llvm.mul %[[MEMREF1_STRIDE1]], %[[MEMREF1_SIZE1]] : i64
-// CHECK: %[[MEMREF1_6:.+]] = llvm.insertvalue %[[MEMREF1_STRIDE0]], %[[MEMREF1_5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_BASE_PTR:.+]] = llvm.extractvalue %[[MEMREF1_6]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_STRIDE0_0:.+]] = llvm.extractvalue %[[MEMREF1_6]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[MEMREF1_00_OFFSET:.+]] = llvm.mul %[[CONST0]], %[[MEMREF1_STRIDE0_0]] : i64
-// CHECK: %[[MEMREF1_00_INDEX:.+]] = llvm.add %[[MEMREF1_00_OFFSET]], %[[CONST0]] : i64
-// CHECK: %[[MEMREF1_00_PTR:.+]] = llvm.getelementptr %[[MEMREF1_BASE_PTR]][%[[MEMREF1_00_INDEX]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: llvm.store %[[MEMREF0_00]], %[[MEMREF1_00_PTR]] : !llvm.ptr<f32>
-// CHECK: llvm.return
-
-// -----
-
-// CHECK_LABEL: @distribute_lookup
-func @distribute_lookup() {
- %0 = iree.placeholder for "interface buffer" {binding = @legacy_io3::@arg0} : memref<2x2x2xf32>
- %1 = hal.interface.workgroup.id[0] : index
- %2 = hal.interface.workgroup.id[1] : index
- %3 = hal.interface.workgroup.id[2] : index
- %4 = load %0[%1, %2, %3] : memref<2x2x2xf32>
- %5 = iree.placeholder for "interface buffer" {binding = @legacy_io3::@ret0} : memref<2x2x2xf32>
- store %4, %5[%1, %2, %3] : memref<2x2x2xf32>
- return
-}
-hal.interface @legacy_io3 attributes {push_constants = 1 : i32, sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
-}
-// CHECK: llvm.func @distribute_lookup(%[[ARG0:.+]]: !llvm.ptr<ptr<i8>>, %[[ARG1:.+]]: !llvm.ptr<i32>, %[[WORKGROUP_ID:.+]]: !llvm.ptr<array<3 x i32>>, %[[WORKGROUP_COUNT:.+]]: !llvm.ptr<array<3 x i32>>, %[[WORKGROUP_SIZE:.+]]: !llvm.ptr<array<3 x i32>>) {
-// CHECK: %[[CONST0:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[CONST2:.+]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: %[[CONST4:.+]] = llvm.mlir.constant(4 : index) : i64
-// CHECK: %[[WORKGROUP_ID_DATA_0:.+]] = llvm.load %[[WORKGROUP_ID]] : !llvm.ptr<array<3 x i32>>
-// CHECK: %[[WORKGROUP_ID_Z:.+]] = llvm.extractvalue %[[WORKGROUP_ID_DATA_0]][0 : i32] : !llvm.array<3 x i32>
-// CHECK: %[[WORKGROUP_ID_Z_i64:.+]] = llvm.zext %[[WORKGROUP_ID_Z]] : i32 to i64
-// CHECK: %[[WORKGROUP_ID_DATA_1:.+]] = llvm.load %[[WORKGROUP_ID]] : !llvm.ptr<array<3 x i32>>
-// CHECK: %[[WORKGROUP_ID_Y:.+]] = llvm.extractvalue %[[WORKGROUP_ID_DATA_1]][1 : i32] : !llvm.array<3 x i32>
-// CHECK: %[[WORKGROUP_ID_Y_i64:.+]] = llvm.zext %[[WORKGROUP_ID_Y]] : i32 to i64
-// CHECK: %[[WORKGROUP_ID_DATA_2:.+]] = llvm.load %[[WORKGROUP_ID]] : !llvm.ptr<array<3 x i32>>
-// CHECK: %[[WORKGROUP_ID_X:.+]] = llvm.extractvalue %[[WORKGROUP_ID_DATA_2]][2 : i32] : !llvm.array<3 x i32>
-// CHECK: %[[WORKGROUP_ID_X_i64:.+]] = llvm.zext %[[WORKGROUP_ID_X]] : i32 to i64
-// CHECK: %[[LOAD_STRIDE_Z:.+]] = llvm.mul %[[WORKGROUP_ID_Z_i64]], %[[CONST4]]
-// CHECK: %[[LOAD_STRIDE_Y:.+]] llvm.mul %[[WORKGROUP_ID_Y_i64]], %[[CONST2]]
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir
index 2838662..b37e652 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir
@@ -4,8 +4,10 @@
func @binding_ptrs() {
// CHECK-DAG: %[[C72:.+]] = llvm.mlir.constant(72 : index) : i64
%c72 = constant 72 : index
- // CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(1 : index) : i64
- // CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %arg0[%[[C1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
+ // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t", (array<3 x i32>, array<3 x i32>, i64, ptr<i32>, i64, ptr<ptr<i8>>, ptr<i64>)>>
+ // CHECK: %[[BINDING_PTRS:.+]] = llvm.extractvalue %[[STATE]][5]
+ // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %[[BINDING_PTRS]][%[[C1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[BASE_PTR_I8:.+]] = llvm.load %[[ARRAY_PTR]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[BUFFER_I8:.+]] = llvm.getelementptr %[[BASE_PTR_I8]][%[[C72]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
// CHECK: %[[BUFFER_F32:.+]] = llvm.bitcast %[[BUFFER_I8]] : !llvm.ptr<i8> to !llvm.ptr<f32>
@@ -21,3 +23,37 @@
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @tie_shape
+func @tie_shape() {
+ %c72 = constant 72 : index
+ // ...
+ // CHECK: %[[DYN_MEMREF_T1:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DYN_MEMREF_T2:.+]] = llvm.insertvalue %{{.+}}, %[[DYN_MEMREF_T1]][0]
+ // CHECK: %[[DYN_MEMREF:.+]] = llvm.insertvalue %{{.+}}, %[[DYN_MEMREF_T2]][1]
+ %memref = hal.interface.binding.subspan @io::@ret0[%c72] : memref<?x2xf32>
+ // ...
+ // The loaded pointer is matched with a wildcard rather than a hard-coded SSA
+ // number so the test does not depend on value numbering.
+ // CHECK: %[[CDIM0_I32:.+]] = llvm.load %{{.+}} : !llvm.ptr<i32>
+ // CHECK: %[[CDIM0:.+]] = llvm.zext %[[CDIM0_I32]] : i32 to i64
+ %dim = hal.interface.load.constant offset = 0 : index
+ %shape = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2]>
+ // CHECK: %[[MEMREF_T0:.+]] = llvm.insertvalue %[[CDIM0]], %[[DYN_MEMREF]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : index) : i64
+ // CHECK: %[[MEMREF_T1:.+]] = llvm.insertvalue %[[C2]], %[[MEMREF_T0]][3, 1]
+ // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: %[[TIED_MEMREF:.+]] = llvm.insertvalue %[[C1]], %[[MEMREF_T1]][4, 1]
+ // CHECK: %[[STRIDE1:.+]] = llvm.extractvalue %[[TIED_MEMREF]][4, 1]
+ // CHECK: %[[DIM1:.+]] = llvm.extractvalue %[[TIED_MEMREF]][3, 1]
+ // CHECK: %[[STRIDE0:.+]] = llvm.mul %[[STRIDE1]], %[[DIM1]] : i64
+ // CHECK: %[[FINAL_MEMREF:.+]] = llvm.insertvalue %[[STRIDE0]], %[[TIED_MEMREF]][4, 0]
+ %tied_memref = shapex.tie_shape %memref, %shape : memref<?x2xf32>, !shapex.ranked_shape<[?,2]>
+ // CHECK-NEXT: "test.sink"(%[[FINAL_MEMREF]])
+ "test.sink"(%tied_memref) : (memref<?x2xf32>) -> ()
+ return
+}
+hal.interface @io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
+}
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_constants.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_constants.mlir
index c0ee985..5956f65 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_constants.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_constants.mlir
@@ -2,8 +2,10 @@
// CHECK-LABEL: llvm.func @constant_values
func @constant_values() {
+ // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[PTR_BASE:.+]] = llvm.extractvalue %[[STATE]][3]
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : index) : i64
- // CHECK: %[[VPTR:.+]] = llvm.getelementptr %arg1[%[[C1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+ // CHECK: %[[VPTR:.+]] = llvm.getelementptr %[[PTR_BASE]][%[[C1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
// CHECK: %[[V32:.+]] = llvm.load %[[VPTR]] : !llvm.ptr<i32>
// CHECK: %[[V64:.+]] = llvm.zext %[[V32]] : i32 to i64
%v1 = hal.interface.load.constant offset = 1 : index
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_workgroup_info.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_workgroup_info.mlir
index 239a791..60430e0 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_workgroup_info.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_workgroup_info.mlir
@@ -2,8 +2,8 @@
// CHECK-LABEL: llvm.func @workgroup_id
func @workgroup_id() {
- // CHECK: %[[PTR:.+]] = llvm.load %arg2 : !llvm.ptr<array<3 x i32>>
- // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[PTR]][2 : i32] : !llvm.array<3 x i32>
+ // CHECK: %[[PTR:.+]] = llvm.load %arg1 : !llvm.ptr<array<3 x i32>>
+ // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[PTR]][2] : !llvm.array<3 x i32>
// CHECK: %[[Z64:.+]] = llvm.zext %[[Z32]] : i32 to i64
%workgroup_id_z = hal.interface.workgroup.id[2] : index
// CHECK-NEXT: "test.sink"(%[[Z64]])
@@ -15,8 +15,9 @@
// CHECK-LABEL: llvm.func @workgroup_size
func @workgroup_size() {
- // CHECK: %[[PTR:.+]] = llvm.load %arg4 : !llvm.ptr<array<3 x i32>>
- // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[PTR]][2 : i32] : !llvm.array<3 x i32>
+ // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[SIZE_PTR:.+]] = llvm.extractvalue %[[STATE]][1] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[SIZE_PTR]][2] : !llvm.array<3 x i32>
// CHECK: %[[Z64:.+]] = llvm.zext %[[Z32]] : i32 to i64
%workgroup_size_z = hal.interface.workgroup.size[2] : index
// CHECK-NEXT: "test.sink"(%[[Z64]])
@@ -28,8 +29,9 @@
// CHECK-LABEL: llvm.func @workgroup_count
func @workgroup_count() {
- // CHECK: %[[PTR:.+]] = llvm.load %arg3 : !llvm.ptr<array<3 x i32>>
- // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[PTR]][2 : i32] : !llvm.array<3 x i32>
+ // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[COUNT_PTR:.+]] = llvm.extractvalue %[[STATE]][0] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[COUNT_PTR]][2] : !llvm.array<3 x i32>
// CHECK: %[[Z64:.+]] = llvm.zext %[[Z32]] : i32 to i64
%workgroup_count_z = hal.interface.workgroup.count[2] : index
// CHECK-NEXT: "test.sink"(%[[Z64]])
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/plan_conv_loop_order.mlir
similarity index 100%
rename from iree/compiler/Conversion/LinalgToLLVM/test/plan-conv-loop-order.mlir
rename to iree/compiler/Conversion/LinalgToLLVM/test/plan_conv_loop_order.mlir
diff --git a/iree/hal/local/executable_library.h b/iree/hal/local/executable_library.h
index a8ec8b2..44dafad 100644
--- a/iree/hal/local/executable_library.h
+++ b/iree/hal/local/executable_library.h
@@ -25,6 +25,57 @@
#include <stdint.h>
//===----------------------------------------------------------------------===//
+// Common utilities included to reduce dependencies
+//===----------------------------------------------------------------------===//
+
+// `restrict` keyword, not supported by some older compilers.
+// We define our own macro in case dependencies use `restrict` differently.
+#if defined(_MSC_VER) && _MSC_VER >= 1900
+#define IREE_RESTRICT __restrict
+#elif defined(_MSC_VER)
+#define IREE_RESTRICT
+#else
+#define IREE_RESTRICT restrict
+#endif // _MSC_VER
+
+//===----------------------------------------------------------------------===//
+// Runtime feature support metadata
+//===----------------------------------------------------------------------===//
+
+// Defines a bitfield of features that the library requires or supports.
+enum iree_hal_executable_library_feature_e {
+ IREE_HAL_EXECUTABLE_LIBRARY_FEATURE_NONE = 0u,
+ // TODO(benvanik): declare features for debugging/coverage/printf/etc.
+ // These will control which symbols are injected into the library at runtime.
+};
+typedef uint32_t iree_hal_executable_library_features_t;
+
+// Defines a set of supported sanitizers that libraries may be compiled with.
+ // Loaders can use this declaration to check whether the library is
+// compatible with the hosting environment for cases where the sanitizer
+// requires host support.
+enum iree_hal_executable_library_sanitizer_kind_e {
+ IREE_HAL_EXECUTABLE_LIBRARY_SANITIZER_NONE = 0u,
+ // Indicates the library is compiled to use AddressSanitizer:
+ // https://clang.llvm.org/docs/AddressSanitizer.html
+ // Equivalent compiler flag: -fsanitize=address
+ IREE_HAL_EXECUTABLE_LIBRARY_SANITIZER_ADDRESS = 1u,
+ // Indicates the library is compiled to use MemorySanitizer:
+ // https://clang.llvm.org/docs/MemorySanitizer.html
+ // Equivalent compiler flag: -fsanitize=memory
+ IREE_HAL_EXECUTABLE_LIBRARY_SANITIZER_MEMORY = 2u,
+ // Indicates the library is compiled to use ThreadSanitizer:
+ // https://clang.llvm.org/docs/ThreadSanitizer.html
+ // Equivalent compiler flag: -fsanitize=thread
+ IREE_HAL_EXECUTABLE_LIBRARY_SANITIZER_THREAD = 3u,
+ // Indicates the library is compiled to use UndefinedBehaviorSanitizer:
+ // https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html
+ // Equivalent compiler flag: -fsanitize=undefined
+ IREE_HAL_EXECUTABLE_LIBRARY_SANITIZER_UNDEFINED = 4u,
+};
+typedef uint32_t iree_hal_executable_library_sanitizer_kind_t;
+
+//===----------------------------------------------------------------------===//
// Versioning and interface querying
//===----------------------------------------------------------------------===//
@@ -47,9 +98,14 @@
// Version of the API this library was built with, which was likely the value
// of IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION.
iree_hal_executable_library_version_t version;
-
// Name used for logging/diagnostics.
const char* name;
+ // Bitfield of features required/supported by this executable.
+ iree_hal_executable_library_features_t features;
+ // Which sanitizer the library is compiled to use, if any.
+ // Libraries meant for use with a particular sanitizer are only usable
+ // with hosting code that is using the same sanitizer.
+ iree_hal_executable_library_sanitizer_kind_t sanitizer;
} iree_hal_executable_library_header_t;
// Exported function from dynamic libraries for querying library information.
@@ -68,10 +124,11 @@
// IREE_HAL_EXECUTABLE_LIBRARY_VERSION_0
//===----------------------------------------------------------------------===//
-// Read-only per-dispatch state passed to each tile in a dispatch.
+// TBD: do not use this yet.
typedef struct {
- uint32_t reserved;
-} iree_hal_executable_dispatch_state_v0_t;
+ size_t import_count;
+ void* import_fns;
+} iree_hal_executable_import_table_v0_t;
typedef union {
struct {
@@ -82,33 +139,47 @@
uint32_t value[3];
} iree_hal_vec3_t;
-#if defined(_MSC_VER)
-typedef __declspec(
- align(16)) const uint32_t* iree_hal_executable_push_constants_ptr_t;
-#else
-typedef const uint32_t* iree_hal_executable_push_constants_ptr_t
- __attribute__((align_value(16)));
-#endif // MSVC
+// Read-only per-dispatch state passed to each workgroup in a dispatch.
+typedef struct {
+ // Total workgroup count for the dispatch. This is sourced from either the
+ // original dispatch call (for iree_hal_command_buffer_dispatch) or the
+ // indirection buffer (for iree_hal_command_buffer_dispatch_indirect).
+ iree_hal_vec3_t workgroup_count;
+ // Workgroup size chosen for the dispatch. For compilation modes where the
+ // workgroup size is constant this may be ignored.
+ iree_hal_vec3_t workgroup_size;
-typedef void* iree_hal_executable_binding_ptr_t;
+ // Total number of available 4 byte push constant values in |push_constants|.
+ size_t push_constant_count;
+ // |push_constant_count| values.
+ const uint32_t* push_constants;
+
+ // Total number of binding base pointers in |binding_ptrs| and
+ // |binding_lengths|. The set is packed densely based on which bindings are
+ // used (known at compile-time).
+ size_t binding_count;
+ // Base pointers to each binding buffer.
+ void* const* binding_ptrs;
+ // The length of each binding in bytes, 1:1 with |binding_ptrs|.
+ const size_t* binding_lengths;
+
+ // Optional imported functions available for use within the executable.
+ const iree_hal_executable_import_table_v0_t* imports;
+} iree_hal_executable_dispatch_state_v0_t;
// Function signature of exported executable entry points.
-// The same |state| is passed to all tiles in a dispatch, with other arguments
-// such as |workgroup_id| varying per-tile (counting to the |workgroup_count|).
-// Each tile represents |workgroup_size| local invocations in the global
-// |workgroup_count| grid.
+// The same |dispatch_state| is passed to all workgroups in a dispatch while
+// |workgroup_id| will vary for each workgroup.
//
-// 0 or more push constants are available at |push_constants| with the count
-// being determined by the sidechannel information provided by the compiler.
-//
-// The |bindings| list is a dense set of pointers to I/O data with the count and
-// ordering determined by the compiler.
-typedef void (*iree_hal_executable_dispatch_v0_t)(
- const iree_hal_executable_dispatch_state_v0_t* state,
- const iree_hal_vec3_t* workgroup_id, const iree_hal_vec3_t* workgroup_size,
- const iree_hal_vec3_t* workgroup_count,
- const iree_hal_executable_push_constants_ptr_t push_constants,
- const iree_hal_executable_binding_ptr_t* bindings);
+// Returns 0 on success and non-zero on failure. Failures will cause device loss
+// and should only be used to communicate serious issues that should abort all
+// execution within the current device. Buffer overflows are a good example of
+// a useful failure though the HAL does not mandate that all overflows are
+// caught and only that they are not harmful - clamping byte ranges and never
+// returning a failure is sufficient.
+typedef int (*iree_hal_executable_dispatch_v0_t)(
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id);
// Structure used for v0 library interfaces.
// The entire structure is designed to be read-only and able to live embedded in
@@ -126,22 +197,21 @@
// The total number of entry points available in the library. Bounds all of
// the tables below.
uint32_t entry_point_count;
-
// Table of export function entry points matching the ordinals defined during
// library generation. The runtime will use this table to map the ordinals to
// function pointers for execution.
const iree_hal_executable_dispatch_v0_t* entry_points;
-
// Optional table of export function entry point names 1:1 with entry_points.
// These names are only used for tracing/debugging and can be omitted to save
// binary size.
const char** entry_point_names;
-
// Optional table of entry point tags that describe the entry point in a
// human-readable format useful for verbose logging. The string values, when
// present, may be attached to tracing/debugging events related to the entry
// point.
const char** entry_point_tags;
+
+ // TODO(benvanik): optional import declarations.
} iree_hal_executable_library_v0_t;
#endif // IREE_HAL_LOCAL_EXECUTABLE_LIBRARY_H_
diff --git a/iree/hal/local/loaders/legacy_library_loader.cc b/iree/hal/local/loaders/legacy_library_loader.cc
index bc3256b..7e8dcb3 100644
--- a/iree/hal/local/loaders/legacy_library_loader.cc
+++ b/iree/hal/local/loaders/legacy_library_loader.cc
@@ -88,12 +88,6 @@
// iree_hal_legacy_executable_t
//===----------------------------------------------------------------------===//
-typedef void (*iree_hal_legacy_executable_fn_ptr_t)(void* const*,
- const uint32_t*,
- const uint32_t*,
- const uint32_t*,
- const uint32_t*);
-
typedef struct {
iree_hal_local_executable_t base;
@@ -110,7 +104,7 @@
// Resolved entry points from the dynamic library.
iree_host_size_t entry_fn_count;
- iree_hal_legacy_executable_fn_ptr_t entry_fns[];
+ iree_hal_executable_dispatch_v0_t entry_fns[];
} iree_hal_legacy_executable_t;
extern const iree_hal_local_executable_vtable_t
@@ -204,7 +198,7 @@
"symbol %s not exported by the dynamic library, check visibility",
entry_point_str);
}
- executable->entry_fns[i] = (iree_hal_legacy_executable_fn_ptr_t)symbol;
+ executable->entry_fns[i] = (iree_hal_executable_dispatch_v0_t)symbol;
}
return iree_ok_status();
}
@@ -306,7 +300,8 @@
static iree_status_t iree_hal_legacy_executable_issue_call(
iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
- const iree_hal_local_executable_call_t* call) {
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id) {
iree_hal_legacy_executable_t* executable =
(iree_hal_legacy_executable_t*)base_executable;
@@ -327,10 +322,7 @@
entry_point_name.size);
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- executable->entry_fns[ordinal](call->bindings, call->push_constants,
- (const uint32_t*)&call->workgroup_id,
- (const uint32_t*)&call->workgroup_count,
- (const uint32_t*)&call->workgroup_size);
+ executable->entry_fns[ordinal](dispatch_state, workgroup_id);
IREE_TRACE_ZONE_END(z0);
diff --git a/iree/hal/local/loaders/system_library_loader.c b/iree/hal/local/loaders/system_library_loader.c
index 00a0005..c707891 100644
--- a/iree/hal/local/loaders/system_library_loader.c
+++ b/iree/hal/local/loaders/system_library_loader.c
@@ -88,7 +88,8 @@
static iree_status_t iree_hal_system_executable_issue_call(
iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
- const iree_hal_local_executable_call_t* call) {
+ const iree_hal_executable_dispatch_state_v0_t* IREE_RESTRICT dispatch_state,
+ const iree_hal_vec3_t* IREE_RESTRICT workgroup_id) {
iree_hal_system_executable_t* executable =
(iree_hal_system_executable_t*)base_executable;
@@ -108,13 +109,16 @@
entry_point_name.size);
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- executable->library.v0->entry_points[ordinal](
- call->state, &call->workgroup_id, &call->workgroup_size,
- &call->workgroup_count, call->push_constants, call->bindings);
+ int ret = executable->library.v0->entry_points[ordinal](dispatch_state,
+ workgroup_id);
IREE_TRACE_ZONE_END(z0);
- return iree_ok_status();
+ return ret == 0 ? iree_ok_status()
+ : iree_make_status(
+ IREE_STATUS_INTERNAL,
+ "executable entry point returned catastrophic error %d",
+ ret);
}
static const iree_hal_local_executable_vtable_t
diff --git a/iree/hal/local/loaders/vmla_module_loader.cc b/iree/hal/local/loaders/vmla_module_loader.cc
index 83eca6e..5cc4479 100644
--- a/iree/hal/local/loaders/vmla_module_loader.cc
+++ b/iree/hal/local/loaders/vmla_module_loader.cc
@@ -160,7 +160,8 @@
static iree_status_t iree_hal_vmla_executable_issue_call(
iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
- const iree_hal_local_executable_call_t* call) {
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id) {
iree_hal_vmla_executable_t* executable =
(iree_hal_vmla_executable_t*)base_executable;
@@ -195,17 +196,19 @@
/*element_type=*/NULL,
/*interface*/ 1 + /*workgroup_xyz[3]*/ 3, &input_list));
iree_vm_list_push_ref_retain(input_list, &interface_ref);
- iree_vm_value_t workgroup_id_x = iree_vm_value_make_i32(call->workgroup_id.x);
- iree_vm_value_t workgroup_id_y = iree_vm_value_make_i32(call->workgroup_id.y);
- iree_vm_value_t workgroup_id_z = iree_vm_value_make_i32(call->workgroup_id.z);
+ iree_vm_value_t workgroup_id_x = iree_vm_value_make_i32(workgroup_id->x);
+ iree_vm_value_t workgroup_id_y = iree_vm_value_make_i32(workgroup_id->y);
+ iree_vm_value_t workgroup_id_z = iree_vm_value_make_i32(workgroup_id->z);
iree_vm_list_push_value(input_list, &workgroup_id_x);
iree_vm_list_push_value(input_list, &workgroup_id_y);
iree_vm_list_push_value(input_list, &workgroup_id_z);
iree_hal_local_executable_layout_t* local_layout =
executable->base.executable_layouts[ordinal];
- IREE_CHECK_OK(interface.SetConstants(
- absl::MakeConstSpan(call->push_constants, local_layout->push_constants)));
+ IREE_CHECK_EQ(local_layout->push_constants,
+ dispatch_state->push_constant_count);
+ IREE_CHECK_OK(interface.SetConstants(absl::MakeConstSpan(
+ dispatch_state->push_constants, dispatch_state->push_constant_count)));
for (iree_host_size_t set_ordinal = 0;
set_ordinal < local_layout->set_layout_count; ++set_ordinal) {
@@ -214,7 +217,8 @@
local_layout->set_layouts[set_ordinal]);
for (iree_host_size_t i = 0; i < local_set_layout->binding_count; ++i) {
auto buffer_or = iree::hal::vmla::Buffer::WrapMutable(
- call->bindings[i], call->binding_lengths[i], iree_allocator_null());
+ dispatch_state->binding_ptrs[i], dispatch_state->binding_lengths[i],
+ iree_allocator_null());
if (!buffer_or.ok()) {
IREE_CHECK_OK(std::move(buffer_or).status());
}
diff --git a/iree/hal/local/local_executable.c b/iree/hal/local/local_executable.c
index 60c1846..9dcddb0 100644
--- a/iree/hal/local/local_executable.c
+++ b/iree/hal/local/local_executable.c
@@ -49,10 +49,12 @@
iree_status_t iree_hal_local_executable_issue_call(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
- const iree_hal_local_executable_call_t* call) {
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id) {
IREE_ASSERT_ARGUMENT(executable);
- IREE_ASSERT_ARGUMENT(call);
+ IREE_ASSERT_ARGUMENT(dispatch_state);
+ IREE_ASSERT_ARGUMENT(workgroup_id);
return ((const iree_hal_local_executable_vtable_t*)
executable->resource.vtable)
- ->issue_call(executable, ordinal, call);
+ ->issue_call(executable, ordinal, dispatch_state, workgroup_id);
}
diff --git a/iree/hal/local/local_executable.h b/iree/hal/local/local_executable.h
index 7503bc3..a86d26c 100644
--- a/iree/hal/local/local_executable.h
+++ b/iree/hal/local/local_executable.h
@@ -25,16 +25,6 @@
#endif // __cplusplus
typedef struct {
- const iree_hal_executable_dispatch_state_v0_t* state;
- iree_hal_vec3_t workgroup_id;
- iree_hal_vec3_t workgroup_size;
- iree_hal_vec3_t workgroup_count;
- iree_hal_executable_push_constants_ptr_t push_constants;
- const iree_hal_executable_binding_ptr_t* bindings;
- const iree_device_size_t* binding_lengths;
-} iree_hal_local_executable_call_t;
-
-typedef struct {
iree_hal_resource_t resource;
iree_allocator_t host_allocator;
iree_host_size_t executable_layout_count;
@@ -46,7 +36,8 @@
iree_status_t(IREE_API_PTR* issue_call)(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
- const iree_hal_local_executable_call_t* call);
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id);
} iree_hal_local_executable_vtable_t;
// Callers must allocate memory for |target_executable_layouts| with at least
@@ -67,7 +58,8 @@
iree_status_t iree_hal_local_executable_issue_call(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
- const iree_hal_local_executable_call_t* call);
+ const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
+ const iree_hal_vec3_t* workgroup_id);
#ifdef __cplusplus
} // extern "C"
diff --git a/iree/hal/local/task_command_buffer.c b/iree/hal/local/task_command_buffer.c
index 498fc11..9d84fb3 100644
--- a/iree/hal/local/task_command_buffer.c
+++ b/iree/hal/local/task_command_buffer.c
@@ -82,9 +82,8 @@
// represent the fully-translated binding data pointer.
// TODO(benvanik): support proper mapping semantics and track the
// iree_hal_buffer_mapping_t and map/unmap where appropriate.
- iree_hal_executable_binding_ptr_t
- bindings[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
- IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
+ void* bindings[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
+ IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
iree_device_size_t
binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
@@ -706,9 +705,7 @@
iree_task_dispatch_t task;
iree_hal_local_executable_t* executable;
iree_host_size_t ordinal;
- iree_hal_executable_binding_ptr_t* IREE_RESTRICT bindings;
- iree_device_size_t* IREE_RESTRICT binding_lengths;
- uint32_t* IREE_RESTRICT push_constants;
+ iree_hal_executable_dispatch_state_v0_t state;
} iree_hal_cmd_dispatch_t;
static iree_status_t iree_hal_cmd_dispatch_tile(
@@ -718,24 +715,9 @@
(const iree_hal_cmd_dispatch_t*)user_context;
IREE_TRACE_ZONE_BEGIN(z0);
- iree_hal_executable_dispatch_state_v0_t state;
- // TODO(benvanik): wire up device state (imports, etc) and cache on the
- // command buffer for reuse across all tiles.
-
- iree_hal_local_executable_call_t call = {
- .state = &state,
- .push_constants = cmd->push_constants,
- .bindings = cmd->bindings,
- .binding_lengths = cmd->binding_lengths,
- };
- memcpy(call.workgroup_id.value, tile_context->workgroup_xyz,
- sizeof(iree_hal_vec3_t));
- memcpy(call.workgroup_size.value, tile_context->workgroup_size,
- sizeof(iree_hal_vec3_t));
- memcpy(call.workgroup_count.value, tile_context->workgroup_count,
- sizeof(iree_hal_vec3_t));
iree_status_t status = iree_hal_local_executable_issue_call(
- cmd->executable, cmd->ordinal, &call);
+ cmd->executable, cmd->ordinal, &cmd->state,
+ (const iree_hal_vec3_t*)tile_context->workgroup_xyz);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -761,7 +743,7 @@
iree_hal_cmd_dispatch_t* cmd = NULL;
iree_host_size_t total_cmd_size =
sizeof(*cmd) + push_constant_count * sizeof(uint32_t) +
- used_binding_count * sizeof(iree_hal_executable_binding_ptr_t) +
+ used_binding_count * sizeof(void*) +
used_binding_count * sizeof(iree_device_size_t);
IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->arena,
total_cmd_size, (void**)&cmd));
@@ -769,20 +751,26 @@
cmd->executable = local_executable;
cmd->ordinal = entry_point;
- uint32_t workgroup_count[3] = {workgroup_x, workgroup_y, workgroup_z};
+ const uint32_t workgroup_count[3] = {workgroup_x, workgroup_y, workgroup_z};
// TODO(benvanik): expose on API or keep fixed on executable.
- uint32_t workgroup_size[3] = {1, 1, 1};
+ const uint32_t workgroup_size[3] = {1, 1, 1};
iree_task_dispatch_initialize(command_buffer->scope,
iree_task_make_dispatch_closure(
iree_hal_cmd_dispatch_tile, (uintptr_t)cmd),
workgroup_size, workgroup_count, &cmd->task);
+ iree_hal_executable_dispatch_state_v0_t* state = &cmd->state;
+ memcpy(&state->workgroup_size, workgroup_size, sizeof(iree_hal_vec3_t));
+ memcpy(&state->workgroup_count, workgroup_count, sizeof(iree_hal_vec3_t));
+
// Copy only the push constant range used by the executable.
uint8_t* cmd_ptr = (uint8_t*)cmd + sizeof(*cmd);
- cmd->push_constants = (uint32_t*)cmd_ptr;
- memcpy(cmd->push_constants, command_buffer->state.push_constants,
- push_constant_count * sizeof(*cmd->push_constants));
- cmd_ptr += push_constant_count * sizeof(*cmd->push_constants);
+ uint32_t* push_constants = (uint32_t*)cmd_ptr;
+ memcpy(push_constants, command_buffer->state.push_constants,
+ push_constant_count * sizeof(*push_constants));
+ cmd_ptr += push_constant_count * sizeof(*push_constants);
+ state->push_constant_count = push_constant_count;
+ state->push_constants = push_constants;
// Produce the dense binding list based on the declared bindings used.
// This allows us to change the descriptor sets and bindings counts supported
@@ -792,24 +780,26 @@
// Note that we are just directly setting the binding data pointers here with
// no ownership/retaining/etc - it's part of the HAL contract that buffers are
// kept valid for the duration they may be in use.
- cmd->bindings = (iree_hal_executable_binding_ptr_t*)cmd_ptr;
- cmd_ptr += used_binding_count * sizeof(*cmd->bindings);
- cmd->binding_lengths = (iree_device_size_t*)cmd_ptr;
- cmd_ptr += used_binding_count * sizeof(*cmd->binding_lengths);
+ state->binding_count = used_binding_count;
+ void** binding_ptrs = (void**)cmd_ptr;
+ cmd_ptr += used_binding_count * sizeof(*binding_ptrs);
+ size_t* binding_lengths = (size_t*)cmd_ptr;
+ cmd_ptr += used_binding_count * sizeof(*binding_lengths);
iree_host_size_t binding_base = 0;
for (iree_host_size_t i = 0; i < used_binding_count; ++i) {
int mask_offset = iree_math_count_trailing_zeros_u64(used_binding_mask);
int binding_ordinal = binding_base + mask_offset;
binding_base += mask_offset + 1;
used_binding_mask = iree_shr(used_binding_mask, mask_offset + 1);
- cmd->bindings[i] = command_buffer->state.bindings[binding_ordinal];
- cmd->binding_lengths[i] =
- command_buffer->state.binding_lengths[binding_ordinal];
- if (!cmd->bindings[i]) {
+ binding_ptrs[i] = command_buffer->state.bindings[binding_ordinal];
+ binding_lengths[i] = command_buffer->state.binding_lengths[binding_ordinal];
+ if (!binding_ptrs[i]) {
return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
"(flat) binding %d is NULL", binding_ordinal);
}
}
+ state->binding_ptrs = binding_ptrs;
+ state->binding_lengths = binding_lengths;
*out_cmd = cmd;
return iree_hal_task_command_buffer_emit_execution_task(command_buffer,