Splitting per-workgroup params from per-dispatch params. (#8507)
This allows for sharing the dispatch state across all workgroups.
The states are now packed such that they fit in single cache lines.
diff --git a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index f5f270d..dc74bea 100644
--- a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -95,10 +95,10 @@
// Matches the field order in iree_hal_executable_environment_v0_t.
enum class EnvironmentField {
- constants = 0,
- import_thunk = 1,
- imports = 2,
- processor = 3,
+ constants,
+ import_thunk,
+ imports,
+ processor,
};
// Returns a Type representing iree_hal_executable_environment_v0_t.
@@ -138,63 +138,117 @@
return structType;
}
- // Returns a Type representing iree_hal_vec3_t.
- static Type getVec3Type(MLIRContext *context) {
- auto uint32Type = IntegerType::get(context, 32);
- return LLVM::LLVMArrayType::get(uint32Type, 3);
- }
-
// Matches the field order in iree_hal_executable_dispatch_state_v0_t.
- enum class StateField {
- workgroup_count = 0,
- workgroup_size = 1,
- push_constant_count = 2,
- push_constants = 3,
- binding_count = 4,
- binding_ptrs = 5,
- binding_lengths = 6,
- processor_id = 7,
- environment = 8,
+ enum class DispatchStateField {
+ /*uint32_t*/ workgroup_size_x,
+ /*uint32_t*/ workgroup_size_y,
+ /*uint16_t*/ workgroup_size_z,
+ /*uint16_t*/ push_constant_count,
+ /*uint32_t*/ workgroup_count_x,
+ /*uint32_t*/ workgroup_count_y,
+ /*uint16_t*/ workgroup_count_z,
+ /*uint16_t*/ binding_count,
+ /*intptr_t*/ push_constants,
+ /*intptr_t*/ binding_ptrs,
+ /*intptr_t*/ binding_lengths,
};
+ friend DispatchStateField operator+(DispatchStateField lhs, int32_t rhs) {
+ return static_cast<DispatchStateField>(static_cast<int32_t>(lhs) + rhs);
+ }
// Returns a Type representing iree_hal_executable_dispatch_state_v0_t.
static LLVM::LLVMStructType getDispatchStateType(
- MLIRContext *context, LLVMTypeConverter *typeConverter,
- LLVM::LLVMStructType environmentType) {
+ MLIRContext *context, LLVMTypeConverter *typeConverter) {
auto structType = LLVM::LLVMStructType::getIdentified(
context, "iree_hal_executable_dispatch_state_v0_t");
if (structType.isInitialized()) return structType;
auto indexType = typeConverter->convertType(IndexType::get(context));
auto int8Type = IntegerType::get(context, 8);
+ auto uint16Type = IntegerType::get(context, 16);
auto uint32Type = IntegerType::get(context, 32);
auto int8PtrType = LLVM::LLVMPointerType::get(int8Type);
auto uint32PtrType = LLVM::LLVMPointerType::get(uint32Type);
- auto vec3Type = getVec3Type(context);
SmallVector<Type, 4> fieldTypes;
- // iree_hal_vec3_t workgroup_count;
- // iree_hal_vec3_t workgroup_size;
- fieldTypes.push_back(vec3Type);
- fieldTypes.push_back(vec3Type);
+ // uint32_t workgroup_size_x;
+ // uint32_t workgroup_size_y;
+ // uint16_t workgroup_size_z;
+ fieldTypes.push_back(uint32Type);
+ fieldTypes.push_back(uint32Type);
+ fieldTypes.push_back(uint16Type);
- // size_t push_constant_count;
+ // uint16_t push_constant_count;
+ fieldTypes.push_back(uint16Type);
+
+ // uint32_t workgroup_count_x;
+ // uint32_t workgroup_count_y;
+ // uint16_t workgroup_count_z;
+ fieldTypes.push_back(uint32Type);
+ fieldTypes.push_back(uint32Type);
+ fieldTypes.push_back(uint16Type);
+
+ // uint16_t binding_count;
+ fieldTypes.push_back(uint16Type);
+
// const uint32_t * push_constants;
- fieldTypes.push_back(indexType);
fieldTypes.push_back(uint32PtrType);
-
- // size_t binding_count;
// void *const * binding_ptrs;
// const size_t * binding_lengths;
- fieldTypes.push_back(indexType);
fieldTypes.push_back(LLVM::LLVMPointerType::get(int8PtrType));
fieldTypes.push_back(LLVM::LLVMPointerType::get(indexType));
+ LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false);
+ assert(succeeded(bodySet) &&
+ "could not set the body of an identified struct");
+ (void)bodySet;
+
+ return structType;
+ }
+
+ enum class WorkgroupStateField {
+ /*uint32_t*/ workgroup_id_x = 0,
+ /*uint32_t*/ workgroup_id_y,
+ /*uint16_t*/ workgroup_id_z,
+ /*uint16_t*/ reserved,
+ /*uint32_t*/ processor_id,
+ /*intptr_t*/ local_memory,
+ /*uint32_t*/ local_memory_size,
+ };
+ friend WorkgroupStateField operator+(WorkgroupStateField lhs, int32_t rhs) {
+ return static_cast<WorkgroupStateField>(static_cast<int32_t>(lhs) + rhs);
+ }
+
+ // Returns a Type representing iree_hal_executable_workgroup_state_v0_t.
+ static LLVM::LLVMStructType getWorkgroupStateType(
+ MLIRContext *context, LLVMTypeConverter *typeConverter) {
+ auto structType = LLVM::LLVMStructType::getIdentified(
+ context, "iree_hal_executable_workgroup_state_v0_t");
+ if (structType.isInitialized()) return structType;
+
+ auto int8Type = IntegerType::get(context, 8);
+ auto uint16Type = IntegerType::get(context, 16);
+ auto uint32Type = IntegerType::get(context, 32);
+ auto int8PtrType = LLVM::LLVMPointerType::get(int8Type);
+ SmallVector<Type, 4> fieldTypes;
+
+ // uint32_t workgroup_id_x;
+ // uint32_t workgroup_id_y;
+ // uint16_t workgroup_id_z;
+ fieldTypes.push_back(uint32Type);
+ fieldTypes.push_back(uint32Type);
+ fieldTypes.push_back(uint16Type);
+
+ // uint16_t reserved;
+ fieldTypes.push_back(uint16Type);
+
// uint32_t processor_id;
fieldTypes.push_back(uint32Type);
- // const iree_hal_executable_environment_v0_t* environment;
- fieldTypes.push_back(LLVM::LLVMPointerType::get(environmentType));
+ // void* local_memory;
+ // uint32_t local_memory_size;
+ fieldTypes.push_back(LLVM::LLVMPointerType::get(int8PtrType));
+ fieldTypes.push_back(uint32Type);
LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false);
assert(succeeded(bodySet) &&
@@ -209,18 +263,28 @@
// `iree/hal/local/executable_library.h`.
static SmallVector<Type, 5> getInputTypes(MLIRContext *context,
LLVMTypeConverter *typeConverter) {
+ auto environmentType = LLVM::LLVMStructType::getIdentified(
+ context, "iree_hal_executable_environment_v0_t");
+ assert(environmentType &&
+ "environment type must be defined by ConvertToLLVM");
auto dispatchStateType = LLVM::LLVMStructType::getIdentified(
context, "iree_hal_executable_dispatch_state_v0_t");
assert(dispatchStateType &&
"dispatch state type must be defined by ConvertToLLVM");
+ auto workgroupStateType = LLVM::LLVMStructType::getIdentified(
+ context, "iree_hal_executable_workgroup_state_v0_t");
+ assert(workgroupStateType &&
+ "workgroup state type must be defined by ConvertToLLVM");
return SmallVector<Type, 5>{
+ // const iree_hal_executable_environment_v0_t* IREE_RESTRICT
+ // environment
+ LLVM::LLVMPointerType::get(environmentType),
// const iree_hal_executable_dispatch_state_v0_t* IREE_RESTRICT
// dispatch_state
LLVM::LLVMPointerType::get(dispatchStateType),
- // const iree_hal_vec3_t* IREE_RESTRICT workgroup_id
- LLVM::LLVMPointerType::get(getVec3Type(context)),
- // void* IREE_RESTRICT local_memory
- LLVM::LLVMPointerType::get(IntegerType::get(context, 8)),
+ // const iree_hal_executable_workgroup_state_v0_t* IREE_RESTRICT
+ // workgroup_state
+ LLVM::LLVMPointerType::get(workgroupStateType),
};
}
@@ -231,54 +295,57 @@
processorType(getProcessorType(funcOp.getContext(), typeConverter)),
environmentType(getEnvironmentType(funcOp.getContext(), typeConverter,
processorType)),
- dispatchStateType(getDispatchStateType(
- funcOp.getContext(), typeConverter, environmentType)) {}
+ dispatchStateType(
+ getDispatchStateType(funcOp.getContext(), typeConverter)),
+ workgroupStateType(
+ getWorkgroupStateType(funcOp.getContext(), typeConverter)) {}
LLVM::LLVMFuncOp getFuncOp() { return funcOp; }
// Loads the workgroup_id[dim] value (XYZ) and casts it to |resultType|.
Value loadWorkgroupID(Location loc, int32_t dim, Type resultType,
OpBuilder &builder) {
- auto workgroupIdPtrValue = funcOp.getArgument(1);
- auto workgroupIdValue =
- builder.createOrFold<LLVM::LoadOp>(loc, workgroupIdPtrValue);
- auto dimValue = builder.createOrFold<LLVM::ExtractValueOp>(
- loc, builder.getIntegerType(32), workgroupIdValue,
- builder.getI64ArrayAttr({dim}));
+ auto dimValue =
+ loadFieldValue(loc, WorkgroupStateField::workgroup_id_x + dim, builder);
return castValueToType(loc, dimValue, resultType, builder);
}
// Loads the workgroup_count[dim] value (XYZ) and casts it to |resultType|.
Value loadWorkgroupCount(Location loc, int32_t dim, Type resultType,
OpBuilder &builder) {
- auto workgroupCountValue =
- loadFieldValue(loc, StateField::workgroup_count, builder);
- auto dimValue = builder.createOrFold<LLVM::ExtractValueOp>(
- loc, builder.getIntegerType(32), workgroupCountValue,
- builder.getI64ArrayAttr(dim));
+ auto dimValue = loadFieldValue(
+ loc, DispatchStateField::workgroup_count_x + dim, builder);
return castValueToType(loc, dimValue, resultType, builder);
}
// Loads the workgroup_size[dim] value (XYZ) and casts it to |resultType|.
Value loadWorkgroupSize(Location loc, int32_t dim, Type resultType,
OpBuilder &builder) {
- auto workgroupSizeValue =
- loadFieldValue(loc, StateField::workgroup_size, builder);
- auto dimValue = builder.createOrFold<LLVM::ExtractValueOp>(
- loc, builder.getIntegerType(32), workgroupSizeValue,
- builder.getI64ArrayAttr(dim));
+ auto dimValue = loadFieldValue(
+ loc, DispatchStateField::workgroup_size_x + dim, builder);
return castValueToType(loc, dimValue, resultType, builder);
}
+ // Returns the total number of bytes available in workgroup local memory.
+ // This may be larger than the requested size.
+ Value loadWorkgroupLocalMemorySize(Location loc, OpBuilder &builder) {
+ auto value =
+ loadFieldValue(loc, WorkgroupStateField::local_memory_size, builder);
+ return castValueToType(loc, value,
+ typeConverter->convertType(builder.getIndexType()),
+ builder);
+ }
+
// Loads the base pointer of the workgroup local memory.
// Note that this may be NULL if no workgroup local memory was requested.
Value loadWorkgroupLocalMemoryPtr(Location loc, OpBuilder &builder) {
- return funcOp.getArgument(2);
+ return loadFieldValue(loc, WorkgroupStateField::local_memory, builder);
}
// Returns the total push constant count as an index-converted type.
Value loadPushConstantCount(Location loc, OpBuilder &builder) {
- auto value = loadFieldValue(loc, StateField::push_constant_count, builder);
+ auto value =
+ loadFieldValue(loc, DispatchStateField::push_constant_count, builder);
return castValueToType(loc, value,
typeConverter->convertType(builder.getIndexType()),
builder);
@@ -288,7 +355,7 @@
Value loadPushConstant(Location loc, int64_t offset, Type resultType,
OpBuilder &builder) {
auto constantsPtrValue =
- loadFieldValue(loc, StateField::push_constants, builder);
+ loadFieldValue(loc, DispatchStateField::push_constants, builder);
auto offsetValue = getIndexValue(loc, offset, builder);
Value constantPtrValue = builder.create<LLVM::GEPOp>(
loc, constantsPtrValue.getType(), constantsPtrValue, offsetValue);
@@ -298,7 +365,8 @@
// Returns the total binding count as an index-converted type.
Value loadBindingCount(Location loc, OpBuilder &builder) {
- auto value = loadFieldValue(loc, StateField::binding_count, builder);
+ auto value =
+ loadFieldValue(loc, DispatchStateField::binding_count, builder);
return castValueToType(loc, value,
typeConverter->convertType(builder.getIndexType()),
builder);
@@ -308,7 +376,8 @@
// Equivalent to:
// int8_t** base_ptr = &state->binding_ptrs[ordinal];
Value loadBindingPtr(Location loc, int64_t ordinal, OpBuilder &builder) {
- auto ptrsPtrValue = loadFieldValue(loc, StateField::binding_ptrs, builder);
+ auto ptrsPtrValue =
+ loadFieldValue(loc, DispatchStateField::binding_ptrs, builder);
auto ordinalValue = getIndexValue(loc, ordinal, builder);
auto elementPtrValue = builder.createOrFold<LLVM::GEPOp>(
loc, ptrsPtrValue.getType(), ptrsPtrValue, ordinalValue);
@@ -318,7 +387,7 @@
// Loads the byte length of the binding |ordinal| as an index-converted type.
Value loadBindingLength(Location loc, int64_t ordinal, OpBuilder &builder) {
auto lengthsPtrValue =
- loadFieldValue(loc, StateField::binding_lengths, builder);
+ loadFieldValue(loc, DispatchStateField::binding_lengths, builder);
auto ordinalValue = getIndexValue(loc, ordinal, builder);
auto elementPtrValue = builder.createOrFold<LLVM::GEPOp>(
loc, lengthsPtrValue.getType(), lengthsPtrValue, ordinalValue);
@@ -397,7 +466,7 @@
// Equivalent to:
// uint32_t processor_id = state->processor_id;
Value loadProcessorID(Location loc, OpBuilder &builder) {
- return loadFieldValue(loc, StateField::processor_id, builder);
+ return loadFieldValue(loc, WorkgroupStateField::processor_id, builder);
}
// Loads a processor information data field at the given index.
@@ -469,18 +538,22 @@
}
private:
- Value loadFieldValue(Location loc, StateField field, OpBuilder &builder) {
- Value statePtrValue = funcOp.getArgument(0);
- Value stateValue = builder.createOrFold<LLVM::LoadOp>(loc, statePtrValue);
- Type fieldType = dispatchStateType.getBody()[(int)field];
- return builder.createOrFold<LLVM::ExtractValueOp>(
- loc, fieldType, stateValue, builder.getI64ArrayAttr((int)field));
+ Value getIndexValue(Location loc, int64_t value, OpBuilder &builder) {
+ return builder.createOrFold<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(builder.getIndexType()),
+ builder.getI64IntegerAttr(value));
+ }
+
+ Value castValueToType(Location loc, Value value, Type resultType,
+ OpBuilder &builder) {
+ // NOTE: we should handle more cases here (and proper sign extension).
+ if (value.getType() == resultType) return value;
+ return builder.createOrFold<LLVM::ZExtOp>(loc, resultType, value);
}
Value loadFieldValue(Location loc, EnvironmentField field,
OpBuilder &builder) {
- Value environmentPtrValue =
- loadFieldValue(loc, StateField::environment, builder);
+ auto environmentPtrValue = funcOp.getArgument(0);
Value environmentValue =
builder.create<LLVM::LoadOp>(loc, environmentPtrValue);
Type fieldType = environmentType.getBody()[(int)field];
@@ -496,17 +569,22 @@
loc, fieldType, processorValue, builder.getI64ArrayAttr((int)field));
}
- Value getIndexValue(Location loc, int64_t value, OpBuilder &builder) {
- return builder.createOrFold<LLVM::ConstantOp>(
- loc, typeConverter->convertType(builder.getIndexType()),
- builder.getI64IntegerAttr(value));
+ Value loadFieldValue(Location loc, DispatchStateField field,
+ OpBuilder &builder) {
+ Value statePtrValue = funcOp.getArgument(1);
+ Value stateValue = builder.createOrFold<LLVM::LoadOp>(loc, statePtrValue);
+ Type fieldType = dispatchStateType.getBody()[(int)field];
+ return builder.createOrFold<LLVM::ExtractValueOp>(
+ loc, fieldType, stateValue, builder.getI64ArrayAttr((int)field));
}
- Value castValueToType(Location loc, Value value, Type resultType,
- OpBuilder &builder) {
- // NOTE: we should handle more cases here (and proper sign extension).
- if (value.getType() == resultType) return value;
- return builder.createOrFold<LLVM::ZExtOp>(loc, resultType, value);
+ Value loadFieldValue(Location loc, WorkgroupStateField field,
+ OpBuilder &builder) {
+ Value statePtrValue = funcOp.getArgument(2);  // %arg2: workgroup_state
+ Value stateValue = builder.createOrFold<LLVM::LoadOp>(loc, statePtrValue);
+ Type fieldType = workgroupStateType.getBody()[(int)field];  // workgroup (not dispatch) struct layout
+ return builder.createOrFold<LLVM::ExtractValueOp>(
+ loc, fieldType, stateValue, builder.getI64ArrayAttr((int)field));
}
LLVM::LLVMFuncOp funcOp;
@@ -514,6 +592,7 @@
LLVM::LLVMStructType processorType;
LLVM::LLVMStructType environmentType;
LLVM::LLVMStructType dispatchStateType;
+ LLVM::LLVMStructType workgroupStateType;
};
/// Converts Standard MLIR FuncOps to LLVMFuncOps matching the IREE HAL ABI.
@@ -590,6 +669,17 @@
return failure();
}
+ // Tag all arguments so LLVM can reason about our exports it otherwise
+ // cannot analyze. We do this early on so that MLIR-based LLVM transforms
+ // can use the attributes.
+ // (%arg0: environment, %arg1: dispatch_state, %arg2: workgroup_state)
+ for (unsigned i = 0; i <= 2; ++i) {
+ llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getNoAliasAttrName(),
+ rewriter.getUnitAttr());
+ llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(),
+ rewriter.getI64IntegerAttr(16));
+ }
+
// Add default zero return value.
// TODO(ataei): do something meaningful with the return value; non-zero will
// have the runtime bail out with an error.
diff --git a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir
index 29e62f9..d5faa03 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_bindings.mlir
@@ -7,16 +7,16 @@
// CHECK-DAG: %[[C72:.+]] = llvm.mlir.constant(72 : index) : i64
%c72 = arith.constant 72 : index
- // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE:.+]]>>
- // CHECK: %[[PC:.+]] = llvm.extractvalue %[[STATE]][3] : !llvm.struct<[[DISPATCH_STATE_TYPE]]>
+ // CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE:.+]]>>
+ // CHECK: %[[PC:.+]] = llvm.extractvalue %[[STATE]][8]
// CHECK: %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[DIM_PTR:.+]] = llvm.getelementptr %[[PC]][%[[C2]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
// CHECK: %[[DIM_I32:.+]] = llvm.load %[[DIM_PTR]] : !llvm.ptr<i32>
// CHECK: %[[DIM:.+]] = llvm.zext %[[DIM_I32]] : i32 to i64
%dim = hal.interface.constant.load[2] : index
- // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE]]>>
- // CHECK: %[[BINDING_PTRS:.+]] = llvm.extractvalue %[[STATE]][5] : !llvm.struct<[[DISPATCH_STATE_TYPE]]>
+ // CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr<struct<[[DISPATCH_STATE_TYPE]]>>
+ // CHECK: %[[BINDING_PTRS:.+]] = llvm.extractvalue %[[STATE]][9]
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ARRAY_PTR:.+]] = llvm.getelementptr %[[BINDING_PTRS]][%[[C1]]] : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[BASE_PTR_I8:.+]] = llvm.load %[[ARRAY_PTR]] : !llvm.ptr<ptr<i8>>
diff --git a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir
index 0017396..72acff1 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_constants.mlir
@@ -4,8 +4,8 @@
// CHECK-LABEL: llvm.func internal @constant_values
func @constant_values() {
- // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
- // CHECK: %[[PTR_BASE:.+]] = llvm.extractvalue %[[STATE]][3]
+ // CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[PTR_BASE:.+]] = llvm.extractvalue %[[STATE]][8]
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1
// CHECK: %[[VPTR:.+]] = llvm.getelementptr %[[PTR_BASE]][%[[C1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
// CHECK: %[[V32:.+]] = llvm.load %[[VPTR]] : !llvm.ptr<i32>
diff --git a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_workgroup_info.mlir b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_workgroup_info.mlir
index 8fef088..19f1fb2 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/hal_interface_workgroup_info.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/hal_interface_workgroup_info.mlir
@@ -4,9 +4,9 @@
// CHECK-LABEL: llvm.func internal @workgroup_id
func @workgroup_id() {
- // CHECK: %[[PTR:.+]] = llvm.load %arg1 : !llvm.ptr<array<3 x i32>>
- // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[PTR]][2] : !llvm.array<3 x i32>
- // CHECK: %[[Z64:.+]] = llvm.zext %[[Z32]] : i32 to i64
+ // CHECK: %[[STATE:.+]] = llvm.load %arg2 : !llvm.ptr<struct<"iree_hal_executable_workgroup_state_v0_t"
+ // CHECK: %[[Z16:.+]] = llvm.extractvalue %[[STATE]][2]
+ // CHECK: %[[Z64:.+]] = llvm.zext %[[Z16]] : i16 to i64
%workgroup_id_z = hal.interface.workgroup.id[2] : index
// CHECK-NEXT: llvm.call @sink(%[[Z64]])
%val = arith.index_cast %workgroup_id_z : index to i64
@@ -20,10 +20,9 @@
// CHECK-LABEL: llvm.func internal @workgroup_size
func @workgroup_size() {
- // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
- // CHECK: %[[SIZE_PTR:.+]] = llvm.extractvalue %[[STATE]][1] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t"
- // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[SIZE_PTR]][2] : !llvm.array<3 x i32>
- // CHECK: %[[Z64:.+]] = llvm.zext %[[Z32]] : i32 to i64
+ // CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[Z16:.+]] = llvm.extractvalue %[[STATE]][2]
+ // CHECK: %[[Z64:.+]] = llvm.zext %[[Z16]] : i16 to i64
%workgroup_size_z = hal.interface.workgroup.size[2] : index
// CHECK-NEXT: llvm.call @sink(%[[Z64]])
%val = arith.index_cast %workgroup_size_z : index to i64
@@ -37,10 +36,9 @@
// CHECK-LABEL: llvm.func internal @workgroup_count
func @workgroup_count() {
- // CHECK: %[[STATE:.+]] = llvm.load %arg0 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
- // CHECK: %[[COUNT_PTR:.+]] = llvm.extractvalue %[[STATE]][0] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t"
- // CHECK: %[[Z32:.+]] = llvm.extractvalue %[[COUNT_PTR]][2] : !llvm.array<3 x i32>
- // CHECK: %[[Z64:.+]] = llvm.zext %[[Z32]] : i32 to i64
+ // CHECK: %[[STATE:.+]] = llvm.load %arg1 : !llvm.ptr<struct<"iree_hal_executable_dispatch_state_v0_t"
+ // CHECK: %[[Z16:.+]] = llvm.extractvalue %[[STATE]][6]
+ // CHECK: %[[Z64:.+]] = llvm.zext %[[Z16]] : i16 to i64
%workgroup_count_z = hal.interface.workgroup.count[2] : index
// CHECK-NEXT: llvm.call @sink(%[[Z64]])
%val = arith.index_cast %workgroup_count_z : index to i64
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index d5fced2..953ed56 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -275,6 +275,7 @@
}
} break;
}
+ auto align16 = llvm::Attribute::getWithAlignment(context, llvm::Align(16));
for (auto entryPointOp :
variantOp.getBlock().getOps<ExecutableEntryPointOp>()) {
// Find the matching function in the LLVM module.
@@ -282,6 +283,19 @@
llvmFunc->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
llvmFunc->setDSOLocal(true);
+ // Tag the function parameters in case they got removed during conversion.
+ // (%arg0: environment, %arg1: dispatch_state, %arg2: workgroup_state)
+ for (unsigned i = 0; i <= 2; ++i) {
+ llvmFunc->addParamAttr(
+ i, llvm::Attribute::getWithByRefType(
+ context, llvmFunc->getArg(i)
+ ->getType()
+ ->getNonOpaquePointerElementType()));
+ llvmFunc->addParamAttr(i, llvm::Attribute::NonNull);
+ llvmFunc->addParamAttr(i, llvm::Attribute::NoAlias);
+ llvmFunc->addParamAttr(i, align16);
+ }
+
// Optionally entry points may specify that they require workgroup local
// memory. We fetch that value here and plumb it through so the runtime
// knows how much memory to reserve and pass in.
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
index c41e8ee..14749a8 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LibraryBuilder.cpp
@@ -65,25 +65,6 @@
return type;
}
-// %struct.anon = type { i32, i32, i32 }
-// %union.iree_hal_vec3_t = type { %struct.anon }
-static llvm::StructType *makeVec3Type(llvm::LLVMContext &context) {
- if (auto *existingType =
- llvm::StructType::getTypeByName(context, "iree_hal_vec3_t")) {
- return existingType;
- }
- auto *i32Type = llvm::IntegerType::getInt32Ty(context);
- auto *type = llvm::StructType::create(context,
- {
- i32Type,
- i32Type,
- i32Type,
- },
- "iree_hal_vec3_t",
- /*isPacked=*/false);
- return type;
-}
-
// %struct.iree_hal_executable_dispatch_state_v0_t = type {
// ...
// }
@@ -94,20 +75,30 @@
return type;
}
-// i32 (%struct.iree_hal_executable_dispatch_state_v0_t*,
-// %union.iree_hal_vec3_t*,
+// %struct.iree_hal_executable_workgroup_state_v0_t = type {
+// ...
+// }
+static llvm::StructType *makeWorkgroupStateType(llvm::LLVMContext &context) {
+ auto *type = llvm::StructType::getTypeByName(
+ context, "iree_hal_executable_workgroup_state_v0_t");
+ assert(type && "state type must be defined by ConvertToLLVM");
+ return type;
+}
+
+// i32 (%struct.iree_hal_executable_environment_v0_t*,
+// %struct.iree_hal_executable_dispatch_state_v0_t*,
-// i8*)
+// %struct.iree_hal_executable_workgroup_state_v0_t*)
static llvm::FunctionType *makeDispatchFunctionType(
llvm::LLVMContext &context) {
+ auto *environmentType = makeEnvironmentType(context);
auto *dispatchStateType = makeDispatchStateType(context);
- auto *i8Type = llvm::IntegerType::getInt8Ty(context);
+ auto *workgroupStateType = makeWorkgroupStateType(context);
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
- auto *vec3Type = llvm::ArrayType::get(i32Type, 3);
return llvm::FunctionType::get(i32Type,
{
+ environmentType->getPointerTo(),
dispatchStateType->getPointerTo(),
- vec3Type->getPointerTo(),
- i8Type->getPointerTo(),
+ workgroupStateType->getPointerTo(),
},
/*isVarArg=*/false);
}
diff --git a/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp b/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
index 734801a..c3fd120 100644
--- a/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
+++ b/iree/compiler/Dialect/Modules/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
@@ -53,9 +53,9 @@
/// %local_memory: !vmvx.buffer,
/// %constants: !vmvx.buffer,
/// %bindings: !util.list<!vmvx.buffer>,
-/// %workgroup_x: index,
-/// %workgroup_y: index,
-/// %workgroup_z: index,
+/// %workgroup_id_x: index,
+/// %workgroup_id_y: index,
+/// %workgroup_id_z: index,
/// %workgroup_size_x: index,
/// %workgroup_size_y: index,
/// %workgroup_size_z: index,
@@ -81,9 +81,9 @@
/*local_memory=*/memRefI8Type,
/*constants=*/memRefI32Type,
/*bindings=*/bindingsType,
- /*workgroup_x=*/indexType,
- /*workgroup_y=*/indexType,
- /*workgroup_z=*/indexType,
+ /*workgroup_id_x=*/indexType,
+ /*workgroup_id_y=*/indexType,
+ /*workgroup_id_z=*/indexType,
/*workgroup_size_x=*/indexType,
/*workgroup_size_y=*/indexType,
/*workgroup_size_z=*/indexType,
diff --git a/iree/hal/local/elf/elf_module_test_main.c b/iree/hal/local/elf/elf_module_test_main.c
index a6a2c10..b9d515d 100644
--- a/iree/hal/local/elf/elf_module_test_main.c
+++ b/iree/hal/local/elf/elf_module_test_main.c
@@ -113,22 +113,26 @@
arg1,
ret0,
};
- iree_hal_vec3_t workgroup_count = {{1, 1, 1}};
- iree_hal_vec3_t workgroup_size = {{1, 1, 1}};
- iree_hal_executable_dispatch_state_v0_t dispatch_state;
- memset(&dispatch_state, 0, sizeof(dispatch_state));
- dispatch_state.workgroup_count = workgroup_count;
- dispatch_state.workgroup_size = workgroup_size;
- dispatch_state.binding_count = 1;
- dispatch_state.binding_lengths = binding_lengths;
- dispatch_state.binding_ptrs = binding_ptrs;
- dispatch_state.processor_id = iree_cpu_query_processor_id();
- dispatch_state.environment = &environment;
- iree_hal_vec3_t workgroup_id = {{0, 0, 0}};
- void* local_memory = NULL;
+ const iree_hal_executable_dispatch_state_v0_t dispatch_state = {
+ .workgroup_size_x = 1,
+ .workgroup_size_y = 1,
+ .workgroup_size_z = 1,
+ .workgroup_count_x = 1,
+ .workgroup_count_y = 1,
+ .workgroup_count_z = 1,
+ .binding_count = 1,
+ .binding_lengths = binding_lengths,
+ .binding_ptrs = binding_ptrs,
+ };
+ const iree_hal_executable_workgroup_state_v0_t workgroup_state = {
+ .workgroup_id_x = 0,
+ .workgroup_id_y = 0,
+ .workgroup_id_z = 0,
+ .processor_id = iree_cpu_query_processor_id(),
+ };
int ret = iree_elf_call_i_ppp((const void*)library.v0->exports.ptrs[0],
- (void*)&dispatch_state, (void*)&workgroup_id,
- local_memory);
+ (void*)&environment, (void*)&dispatch_state,
+ (void*)&workgroup_state);
if (ret != 0) {
return iree_make_status(IREE_STATUS_INTERNAL,
"dispatch function returned failure: %d", ret);
diff --git a/iree/hal/local/elf/testdata/elementwise_mul_arm_32.so b/iree/hal/local/elf/testdata/elementwise_mul_arm_32.so
index ac55988..44258d8 100644
--- a/iree/hal/local/elf/testdata/elementwise_mul_arm_32.so
+++ b/iree/hal/local/elf/testdata/elementwise_mul_arm_32.so
Binary files differ
diff --git a/iree/hal/local/elf/testdata/elementwise_mul_arm_64.so b/iree/hal/local/elf/testdata/elementwise_mul_arm_64.so
index b5bcd5c..4cb65a9 100644
--- a/iree/hal/local/elf/testdata/elementwise_mul_arm_64.so
+++ b/iree/hal/local/elf/testdata/elementwise_mul_arm_64.so
Binary files differ
diff --git a/iree/hal/local/elf/testdata/elementwise_mul_riscv_32.so b/iree/hal/local/elf/testdata/elementwise_mul_riscv_32.so
index 962e60b..66e66de 100644
--- a/iree/hal/local/elf/testdata/elementwise_mul_riscv_32.so
+++ b/iree/hal/local/elf/testdata/elementwise_mul_riscv_32.so
Binary files differ
diff --git a/iree/hal/local/elf/testdata/elementwise_mul_riscv_64.so b/iree/hal/local/elf/testdata/elementwise_mul_riscv_64.so
index af976d2..543a68c 100644
--- a/iree/hal/local/elf/testdata/elementwise_mul_riscv_64.so
+++ b/iree/hal/local/elf/testdata/elementwise_mul_riscv_64.so
Binary files differ
diff --git a/iree/hal/local/elf/testdata/elementwise_mul_x86_32.so b/iree/hal/local/elf/testdata/elementwise_mul_x86_32.so
index 41a2ca3..d2a7408 100644
--- a/iree/hal/local/elf/testdata/elementwise_mul_x86_32.so
+++ b/iree/hal/local/elf/testdata/elementwise_mul_x86_32.so
Binary files differ
diff --git a/iree/hal/local/elf/testdata/elementwise_mul_x86_64.so b/iree/hal/local/elf/testdata/elementwise_mul_x86_64.so
index cebbb85..173a7a1 100644
--- a/iree/hal/local/elf/testdata/elementwise_mul_x86_64.so
+++ b/iree/hal/local/elf/testdata/elementwise_mul_x86_64.so
Binary files differ
diff --git a/iree/hal/local/elf/testdata/generate.sh b/iree/hal/local/elf/testdata/generate.sh
old mode 100644
new mode 100755
diff --git a/iree/hal/local/executable_library.h b/iree/hal/local/executable_library.h
index 5f72737..962a03b 100644
--- a/iree/hal/local/executable_library.h
+++ b/iree/hal/local/executable_library.h
@@ -258,56 +258,96 @@
iree_hal_processor_v0_t processor;
} iree_hal_executable_environment_v0_t;
-typedef union iree_hal_vec3_t {
- struct {
- uint32_t x;
- uint32_t y;
- uint32_t z;
- };
- uint32_t value[3];
-} iree_hal_vec3_t;
-
// Read-only per-dispatch state passed to each workgroup in a dispatch.
+//
+// We layout to try to fit everything commonly used into the first cache line
+// (on archs with 64-bit pointers; 32-bit fits in a single line).
+//
+// For workgroup dimensions we allow the full 32-bit range on X and Y as those
+// are the primary distribution dimensions. Z is the coarsest control and is
+// usually in the 1-16 range; any higher and it can pessimize scheduling. Almost
+// all GPUs also have this limitation (max Z of 65K) for the same reason.
typedef struct iree_hal_executable_dispatch_state_v0_t {
+ // Workgroup size chosen for the dispatch. For compilation modes where the
+ // workgroup size is constant this may be ignored.
+ uint32_t workgroup_size_x;
+ uint32_t workgroup_size_y;
+ uint16_t workgroup_size_z;
+
+ // Total number of available 4 byte push constant values in |push_constants|.
+ uint16_t push_constant_count;
+
// Total workgroup count for the dispatch. This is sourced from either the
// original dispatch call (for iree_hal_command_buffer_dispatch) or the
// indirection buffer (for iree_hal_command_buffer_dispatch_indirect).
- iree_hal_vec3_t workgroup_count;
- // Workgroup size chosen for the dispatch. For compilation modes where the
- // workgroup size is constant this may be ignored.
- iree_hal_vec3_t workgroup_size;
-
- // Total number of available 4 byte push constant values in |push_constants|.
- size_t push_constant_count;
- // |push_constant_count| values.
- const uint32_t* push_constants;
+ uint32_t workgroup_count_x;
+ uint32_t workgroup_count_y;
+ uint16_t workgroup_count_z;
// Total number of binding base pointers in |binding_ptrs| and
// |binding_lengths|. The set is packed densely based on which bindings are
// used (known at compile-time).
- size_t binding_count;
+ uint16_t binding_count;
+
+ // |push_constant_count| values.
+ const uint32_t* push_constants;
// Base pointers to each binding buffer.
void* const* binding_ptrs;
// The length of each binding in bytes, 1:1 with |binding_ptrs|.
const size_t* binding_lengths;
+ // NOTE: the above fields are frequently accessed and should be kept together
+ // to ensure cache-friendly behavior. The first instructions every dispatch
+ // executes are loads from the fields and we want to avoid a cascade of
+ // cache misses. Less-frequently used fields can follow.
+} iree_hal_executable_dispatch_state_v0_t;
+static_assert(sizeof(iree_hal_executable_dispatch_state_v0_t) <= 64,
+ "try keeping dispatch state small enough to fit in a cache line");
+
+// Read-only per-workgroup state passed to each workgroup in a dispatch.
+//
+// We lay out the fields to try to fit everything commonly used into the first
+// cache line (on archs with 64-bit pointers; 32-bit fits in a single line).
+typedef struct iree_hal_executable_workgroup_state_v0_t {
+ // Workgroup ID of the currently executing workgroup.
+ // This is in the range of 0-workgroup_count and each unique workgroup is to
+ // perform workgroup_size invocations.
+ uint32_t workgroup_id_x;
+ uint32_t workgroup_id_y;
+ uint16_t workgroup_id_z;
+
+ // Reserved for future use.
+ uint16_t reserved;
+
// Logical processor identifier used to index into processor info fields.
// Depending on the implementation this may be an ordinal, a bitfield, or an
// opaque unique identifier.
+ //
+ // NOTE: we could steal bits from the |processor_id| if needed; today the ID
+ // is the global ID but it really only needs to be within the current node
+ // (8-bits, or 16-bit for single-node thousand-core future proofing).
uint32_t processor_id;
- // Optional executable environment information.
- const iree_hal_executable_environment_v0_t* environment;
-} iree_hal_executable_dispatch_state_v0_t;
+
+ // Scratch memory available for use by the workgroup.
+ // Requires a non-zero value to be specified for |local_memory_pages|; at
+ // least the size specified will be available. This memory is transient and
+ // exclusive to the workgroup. The provided pointer may be NULL if no
+ // workgroup local memory was requested.
+ void* local_memory;
+ // Total number of bytes available in |local_memory|. This may be larger than
+ // the requested amount.
+ uint32_t local_memory_size;
+
+ // +4 trailing bytes of free space
+} iree_hal_executable_workgroup_state_v0_t;
+static_assert(
+ sizeof(iree_hal_executable_workgroup_state_v0_t) <= 64,
+ "try keeping workgroup state small enough to fit in a cache line");
// Function signature of exported executable entry points.
-// The same |dispatch_state| is passed to all workgroups in a dispatch while
-// |workgroup_id| and |local_memory| will vary for each workgroup.
-//
-// If a non-zero value was specified for |local_memory_page| then scratch memory
-// will be available for use by the invocation of at least the size specified.
-// This memory is transient and exclusive to the workgroup. The provided pointer
-// may be NULL if no workgroup local memory was requested and otherwise will
-// point to memory of the size specified.
+// The same |environment| is passed to all dispatches.
+// The same |dispatch_state| is passed to all workgroups within a dispatch.
+// A unique |workgroup_state| is passed to every workgroup within a dispatch.
//
// Returns 0 on success and non-zero on failure. Failures will cause device loss
// and should only be used to communicate serious issues that should abort all
@@ -316,8 +356,9 @@
// caught and only that they are not harmful - clamping byte ranges and never
// returning a failure is sufficient.
typedef int (*iree_hal_executable_dispatch_v0_t)(
+ const iree_hal_executable_environment_v0_t* environment,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, void* local_memory);
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state);
// Bytes per page of workgroup local memory.
// This is chosen to match the common page size of devices.
diff --git a/iree/hal/local/executable_library_benchmark.c b/iree/hal/local/executable_library_benchmark.c
index c91f5be..593b2b7 100644
--- a/iree/hal/local/executable_library_benchmark.c
+++ b/iree/hal/local/executable_library_benchmark.c
@@ -233,23 +233,18 @@
}
// Setup dispatch state.
- iree_hal_executable_dispatch_state_v0_t dispatch_state = {
- .workgroup_count = {{
- .x = FLAG_workgroup_count_x,
- .y = FLAG_workgroup_count_y,
- .z = FLAG_workgroup_count_z,
- }},
- .workgroup_size = {{
- .x = FLAG_workgroup_size_x,
- .y = FLAG_workgroup_size_y,
- .z = FLAG_workgroup_size_z,
- }},
+ const iree_hal_executable_dispatch_state_v0_t dispatch_state = {
+ .workgroup_count_x = FLAG_workgroup_count_x,
+ .workgroup_count_y = FLAG_workgroup_count_y,
+ .workgroup_count_z = FLAG_workgroup_count_z,
+ .workgroup_size_x = FLAG_workgroup_size_x,
+ .workgroup_size_y = FLAG_workgroup_size_y,
+ .workgroup_size_z = FLAG_workgroup_size_z,
.push_constant_count = dispatch_params.push_constant_count,
.push_constants = &dispatch_params.push_constants[0].ui32,
.binding_count = dispatch_params.binding_count,
.binding_ptrs = binding_ptrs,
.binding_lengths = binding_lengths,
- .environment = &local_executable->environment,
};
// Execute benchmark the workgroup invocation.
@@ -260,7 +255,7 @@
int64_t dispatch_count = 0;
while (iree_benchmark_keep_running(benchmark_state, /*batch_count=*/1)) {
IREE_RETURN_IF_ERROR(iree_hal_local_executable_issue_dispatch_inline(
- local_executable, FLAG_entry_point, &dispatch_state, local_memory));
+ local_executable, FLAG_entry_point, &dispatch_state, 0, local_memory));
++dispatch_count;
}
@@ -268,8 +263,8 @@
// invocations dispatched. That gives us both total dispatch and single
// invocation times in the reporter output.
int64_t total_invocations =
- dispatch_count * dispatch_state.workgroup_count.x *
- dispatch_state.workgroup_count.y * dispatch_state.workgroup_count.z;
+ dispatch_count * dispatch_state.workgroup_count_x *
+ dispatch_state.workgroup_count_y * dispatch_state.workgroup_count_z;
iree_benchmark_set_items_processed(benchmark_state, total_invocations);
// Deallocate buffers.
diff --git a/iree/hal/local/executable_library_demo.c b/iree/hal/local/executable_library_demo.c
index 038a1ef..af18875 100644
--- a/iree/hal/local/executable_library_demo.c
+++ b/iree/hal/local/executable_library_demo.c
@@ -24,20 +24,23 @@
// This is a simple scalar addition:
// binding[1] = binding[0] + push_constant[0]
static int dispatch_tile_a(
+ const iree_hal_executable_environment_v0_t* environment,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, void* local_memory) {
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state) {
const dispatch_tile_a_push_constants_t* push_constants =
(const dispatch_tile_a_push_constants_t*)dispatch_state->push_constants;
const float* src = ((const float*)dispatch_state->binding_ptrs[0]);
float* dst = ((float*)dispatch_state->binding_ptrs[1]);
- dst[workgroup_id->x] = src[workgroup_id->x] + push_constants->f0;
+ const uint32_t x = workgroup_state->workgroup_id_x;
+ dst[x] = src[x] + push_constants->f0;
return 0;
}
// Just another entry point.
static int dispatch_tile_b(
+ const iree_hal_executable_environment_v0_t* environment,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, void* local_memory) {
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state) {
return 0;
}
diff --git a/iree/hal/local/executable_library_test.c b/iree/hal/local/executable_library_test.c
index 69b7c39..9bcff7b 100644
--- a/iree/hal/local/executable_library_test.c
+++ b/iree/hal/local/executable_library_test.c
@@ -81,24 +81,30 @@
library.v0->exports.ptrs[0];
// Dispatch each workgroup with the same state.
- iree_hal_executable_dispatch_state_v0_t dispatch_state = {
- .workgroup_count = {{4, 1, 1}},
- .workgroup_size = {{1, 1, 1}},
+ const iree_hal_executable_dispatch_state_v0_t dispatch_state = {
+ .workgroup_count_x = 4,
+ .workgroup_count_y = 1,
+ .workgroup_count_z = 1,
+ .workgroup_size_x = 1,
+ .workgroup_size_y = 1,
+ .workgroup_size_z = 1,
.push_constant_count = IREE_ARRAYSIZE(push_constants.values),
.push_constants = push_constants.values,
.binding_count = IREE_ARRAYSIZE(binding_ptrs),
.binding_ptrs = binding_ptrs,
.binding_lengths = binding_lengths,
- .processor_id = iree_cpu_query_processor_id(),
- .environment = &environment,
};
- for (uint32_t z = 0; z < dispatch_state.workgroup_count.z; ++z) {
- for (uint32_t y = 0; y < dispatch_state.workgroup_count.y; ++y) {
- for (uint32_t x = 0; x < dispatch_state.workgroup_count.x; ++x) {
+ iree_hal_executable_workgroup_state_v0_t workgroup_state = {
+ .processor_id = iree_cpu_query_processor_id(),
+ };
+ for (uint32_t z = 0; z < dispatch_state.workgroup_count_z; ++z) {
+ workgroup_state.workgroup_id_z = z;
+ for (uint32_t y = 0; y < dispatch_state.workgroup_count_y; ++y) {
+ workgroup_state.workgroup_id_y = y;
+ for (uint32_t x = 0; x < dispatch_state.workgroup_count_x; ++x) {
+ workgroup_state.workgroup_id_x = x;
// Invoke the workgroup (x, y, z).
- iree_hal_vec3_t workgroup_id = {{x, y, z}};
- int ret = entry_fn_ptr(&dispatch_state, &workgroup_id,
- /*local_memory=*/NULL);
+ int ret = entry_fn_ptr(&environment, &dispatch_state, &workgroup_state);
IREE_ASSERT_EQ(
ret, 0,
"if we have bounds checking enabled the executable will signal "
diff --git a/iree/hal/local/inline_command_buffer.c b/iree/hal/local/inline_command_buffer.c
index a64ef51..adc426e 100644
--- a/iree/hal/local/inline_command_buffer.c
+++ b/iree/hal/local/inline_command_buffer.c
@@ -56,10 +56,12 @@
// Cached and initialized dispatch state reused for all dispatches.
// Individual dispatches must populate the dynamically changing fields like
// push_constant_count and binding_count.
- iree_hal_executable_dispatch_state_v0_t dispatch_state;
+ iree_alignas(64) iree_hal_executable_dispatch_state_v0_t dispatch_state;
// An opaque tag used to reduce the cost of processor ID queries.
iree_cpu_processor_tag_t processor_tag;
+ // Guess at the current processor ID.
+ iree_cpu_processor_id_t processor_id;
} state;
} iree_hal_inline_command_buffer_t;
@@ -160,9 +162,8 @@
// Updates the cached processor ID field in the command buffer.
static void iree_hal_inline_command_buffer_update_processor_id(
iree_hal_inline_command_buffer_t* command_buffer) {
- iree_cpu_requery_processor_id(
- &command_buffer->state.processor_tag,
- &command_buffer->state.dispatch_state.processor_id);
+ iree_cpu_requery_processor_id(&command_buffer->state.processor_tag,
+ &command_buffer->state.processor_id);
}
static iree_status_t iree_hal_inline_command_buffer_begin(
@@ -420,15 +421,14 @@
iree_hal_executable_dispatch_state_v0_t* dispatch_state =
&command_buffer->state.dispatch_state;
- dispatch_state->environment = &local_executable->environment;
// TODO(benvanik): expose on API or keep fixed on executable.
- dispatch_state->workgroup_size.x = 1;
- dispatch_state->workgroup_size.y = 1;
- dispatch_state->workgroup_size.z = 1;
- dispatch_state->workgroup_count.x = workgroup_x;
- dispatch_state->workgroup_count.y = workgroup_y;
- dispatch_state->workgroup_count.z = workgroup_z;
+ dispatch_state->workgroup_size_x = 1;
+ dispatch_state->workgroup_size_y = 1;
+ dispatch_state->workgroup_size_z = 1;
+ dispatch_state->workgroup_count_x = workgroup_x;
+ dispatch_state->workgroup_count_y = workgroup_y;
+ dispatch_state->workgroup_count_z = workgroup_z;
// Push constants are pulled directly from the command buffer state, but we
// only allow the dispatch to read what we know is initialized based on the
@@ -483,7 +483,8 @@
iree_fpu_state_t fpu_state =
iree_fpu_state_push(IREE_FPU_STATE_FLAG_FLUSH_DENORMALS_TO_ZERO);
iree_status_t status = iree_hal_local_executable_issue_dispatch_inline(
- local_executable, entry_point, dispatch_state, local_memory);
+ local_executable, entry_point, dispatch_state,
+ command_buffer->state.processor_id, local_memory);
iree_fpu_state_pop(fpu_state);
if (local_memory.data) {
@@ -492,6 +493,15 @@
return status;
}
+typedef union iree_hal_vec3_t {
+ struct {
+ uint32_t x;
+ uint32_t y;
+ uint32_t z;
+ };
+ uint32_t value[3];
+} iree_hal_vec3_t;
+
static iree_status_t iree_hal_inline_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
diff --git a/iree/hal/local/loaders/embedded_library_loader.c b/iree/hal/local/loaders/embedded_library_loader.c
index d6786ef..017579e 100644
--- a/iree/hal/local/loaders/embedded_library_loader.c
+++ b/iree/hal/local/loaders/embedded_library_loader.c
@@ -246,7 +246,7 @@
static iree_status_t iree_hal_elf_executable_issue_call(
iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, iree_byte_span_t local_memory) {
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state) {
iree_hal_elf_executable_t* executable =
(iree_hal_elf_executable_t*)base_executable;
const iree_hal_executable_library_v0_t* library = executable->library.v0;
@@ -275,9 +275,9 @@
}
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- int ret =
- iree_elf_call_i_ppp(library->exports.ptrs[ordinal], (void*)dispatch_state,
- (void*)workgroup_id, (void*)local_memory.data);
+ int ret = iree_elf_call_i_ppp(library->exports.ptrs[ordinal],
+ (void*)&base_executable->environment,
+ (void*)dispatch_state, (void*)workgroup_state);
IREE_TRACE_ZONE_END(z0);
diff --git a/iree/hal/local/loaders/static_library_loader.c b/iree/hal/local/loaders/static_library_loader.c
index ea7f49f..e123938 100644
--- a/iree/hal/local/loaders/static_library_loader.c
+++ b/iree/hal/local/loaders/static_library_loader.c
@@ -119,7 +119,7 @@
static iree_status_t iree_hal_static_executable_issue_call(
iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, iree_byte_span_t local_memory) {
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state) {
iree_hal_static_executable_t* executable =
(iree_hal_static_executable_t*)base_executable;
const iree_hal_executable_library_v0_t* library = executable->library.v0;
@@ -148,8 +148,8 @@
}
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- int ret = library->exports.ptrs[ordinal](dispatch_state, workgroup_id,
- local_memory.data);
+ int ret = library->exports.ptrs[ordinal](&base_executable->environment,
+ dispatch_state, workgroup_state);
IREE_TRACE_ZONE_END(z0);
diff --git a/iree/hal/local/loaders/system_library_loader.c b/iree/hal/local/loaders/system_library_loader.c
index a0aaf79..b18acc6 100644
--- a/iree/hal/local/loaders/system_library_loader.c
+++ b/iree/hal/local/loaders/system_library_loader.c
@@ -354,7 +354,7 @@
static iree_status_t iree_hal_system_executable_issue_call(
iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, iree_byte_span_t local_memory) {
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state) {
iree_hal_system_executable_t* executable =
(iree_hal_system_executable_t*)base_executable;
const iree_hal_executable_library_v0_t* library = executable->library.v0;
@@ -383,8 +383,8 @@
}
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
- int ret = library->exports.ptrs[ordinal](dispatch_state, workgroup_id,
- local_memory.data);
+ int ret = library->exports.ptrs[ordinal](&base_executable->environment,
+ dispatch_state, workgroup_state);
IREE_TRACE_ZONE_END(z0);
diff --git a/iree/hal/local/loaders/vmvx_module_loader.c b/iree/hal/local/loaders/vmvx_module_loader.c
index 89a0770..04ce5be 100644
--- a/iree/hal/local/loaders/vmvx_module_loader.c
+++ b/iree/hal/local/loaders/vmvx_module_loader.c
@@ -249,7 +249,7 @@
static iree_status_t iree_hal_vmvx_executable_issue_call(
iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, iree_byte_span_t local_memory) {
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state) {
iree_hal_vmvx_executable_t* executable =
(iree_hal_vmvx_executable_t*)base_executable;
@@ -313,7 +313,9 @@
iree_vm_buffer_t local_memory_buffer;
iree_vm_buffer_initialize(
IREE_VM_BUFFER_ACCESS_MUTABLE | IREE_VM_BUFFER_ACCESS_ORIGIN_HOST,
- local_memory, iree_allocator_null(), &local_memory_buffer);
+ iree_make_byte_span(workgroup_state->local_memory,
+ workgroup_state->local_memory_size),
+ iree_allocator_null(), &local_memory_buffer);
iree_vm_buffer_retain(&local_memory_buffer); // for call
// Map the push constant memory directly from the dispatch state.
@@ -333,9 +335,9 @@
// %local_memory: !vmvx.buffer,
// %constants: !vmvx.buffer,
// %bindings: !util.list<!vmvx.buffer>,
- // %workgroup_x: index,
- // %workgroup_y: index,
- // %workgroup_z: index,
+ // %workgroup_id_x: index,
+ // %workgroup_id_y: index,
+ // %workgroup_id_z: index,
// %workgroup_size_x: index,
// %workgroup_size_y: index,
// %workgroup_size_z: index,
@@ -350,9 +352,9 @@
iree_vm_ref_t local_memory;
iree_vm_ref_t constants;
iree_vm_ref_t bindings;
- uint32_t workgroup_x;
- uint32_t workgroup_y;
- uint32_t workgroup_z;
+ uint32_t workgroup_id_x;
+ uint32_t workgroup_id_y;
+ uint32_t workgroup_id_z;
uint32_t workgroup_size_x;
uint32_t workgroup_size_y;
uint32_t workgroup_size_z;
@@ -378,15 +380,15 @@
.ptr = binding_list,
.offsetof_counter = 0,
},
- .workgroup_x = workgroup_id->x,
- .workgroup_y = workgroup_id->y,
- .workgroup_z = workgroup_id->z,
- .workgroup_size_x = dispatch_state->workgroup_size.x,
- .workgroup_size_y = dispatch_state->workgroup_size.y,
- .workgroup_size_z = dispatch_state->workgroup_size.z,
- .workgroup_count_x = dispatch_state->workgroup_count.x,
- .workgroup_count_y = dispatch_state->workgroup_count.y,
- .workgroup_count_z = dispatch_state->workgroup_count.z,
+ .workgroup_id_x = workgroup_state->workgroup_id_x,
+ .workgroup_id_y = workgroup_state->workgroup_id_y,
+ .workgroup_id_z = workgroup_state->workgroup_id_z,
+ .workgroup_size_x = dispatch_state->workgroup_size_x,
+ .workgroup_size_y = dispatch_state->workgroup_size_y,
+ .workgroup_size_z = dispatch_state->workgroup_size_z,
+ .workgroup_count_x = dispatch_state->workgroup_count_x,
+ .workgroup_count_y = dispatch_state->workgroup_count_y,
+ .workgroup_count_z = dispatch_state->workgroup_count_z,
};
// On-stack stack. We really do abuse the stack too much here.
diff --git a/iree/hal/local/local_executable.c b/iree/hal/local/local_executable.c
index cdcf828..1fd92ec 100644
--- a/iree/hal/local/local_executable.c
+++ b/iree/hal/local/local_executable.c
@@ -52,44 +52,52 @@
iree_status_t iree_hal_local_executable_issue_call(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, iree_byte_span_t local_memory) {
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state) {
IREE_ASSERT_ARGUMENT(executable);
IREE_ASSERT_ARGUMENT(dispatch_state);
- IREE_ASSERT_ARGUMENT(workgroup_id);
+ IREE_ASSERT_ARGUMENT(workgroup_state);
return ((const iree_hal_local_executable_vtable_t*)
executable->resource.vtable)
- ->issue_call(executable, ordinal, dispatch_state, workgroup_id,
- local_memory);
+ ->issue_call(executable, ordinal, dispatch_state, workgroup_state);
}
iree_status_t iree_hal_local_executable_issue_dispatch_inline(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- iree_byte_span_t local_memory) {
+ uint32_t processor_id, iree_byte_span_t local_memory) {
IREE_TRACE_ZONE_BEGIN(z0);
// TODO(benvanik): annotate with executable name to calculate total time.
- const iree_hal_vec3_t workgroup_count = dispatch_state->workgroup_count;
+ const uint32_t workgroup_count_x = dispatch_state->workgroup_count_x;
+ const uint32_t workgroup_count_y = dispatch_state->workgroup_count_y;
+ const uint32_t workgroup_count_z = dispatch_state->workgroup_count_z;
#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
char xyz_string[32];
int xyz_string_length =
snprintf(xyz_string, IREE_ARRAYSIZE(xyz_string), "%ux%ux%u",
- workgroup_count.x, workgroup_count.y, workgroup_count.z);
+ workgroup_count_x, workgroup_count_y, workgroup_count_z);
IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(z0, xyz_string, xyz_string_length);
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION
iree_status_t status = iree_ok_status();
- iree_hal_vec3_t workgroup_id;
- for (workgroup_id.z = 0; workgroup_id.z < workgroup_count.z;
- ++workgroup_id.z) {
- for (workgroup_id.y = 0; workgroup_id.y < workgroup_count.y;
- ++workgroup_id.y) {
- for (workgroup_id.x = 0; workgroup_id.x < workgroup_count.x;
- ++workgroup_id.x) {
+ iree_alignas(64) iree_hal_executable_workgroup_state_v0_t workgroup_state = {
+ .workgroup_id_x = 0,
+ .workgroup_id_y = 0,
+ .workgroup_id_z = 0,
+ .processor_id = processor_id,
+ .local_memory = local_memory.data,
+ .local_memory_size = (uint32_t)local_memory.data_length,  // 32-bit field
+ };
+ for (uint32_t z = 0; z < workgroup_count_z; ++z) {
+ workgroup_state.workgroup_id_z = z;
+ for (uint32_t y = 0; y < workgroup_count_y; ++y) {
+ workgroup_state.workgroup_id_y = y;
+ for (uint32_t x = 0; x < workgroup_count_x; ++x) {
+ workgroup_state.workgroup_id_x = x;
status = iree_hal_local_executable_issue_call(
- executable, ordinal, dispatch_state, &workgroup_id, local_memory);
+ executable, ordinal, dispatch_state, &workgroup_state);
if (!iree_status_is_ok(status)) break;
}
}
diff --git a/iree/hal/local/local_executable.h b/iree/hal/local/local_executable.h
index 7993295..d9a42e4 100644
--- a/iree/hal/local/local_executable.h
+++ b/iree/hal/local/local_executable.h
@@ -38,7 +38,7 @@
iree_status_t(IREE_API_PTR* issue_call)(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, iree_byte_span_t local_memory);
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state);
} iree_hal_local_executable_vtable_t;
// Initializes the local executable base type.
@@ -62,12 +62,12 @@
iree_status_t iree_hal_local_executable_issue_call(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- const iree_hal_vec3_t* workgroup_id, iree_byte_span_t local_memory);
+ const iree_hal_executable_workgroup_state_v0_t* workgroup_state);
iree_status_t iree_hal_local_executable_issue_dispatch_inline(
iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
- iree_byte_span_t local_memory);
+ uint32_t processor_id, iree_byte_span_t local_memory);
#ifdef __cplusplus
} // extern "C"
diff --git a/iree/hal/local/task_command_buffer.c b/iree/hal/local/task_command_buffer.c
index 63efb51..9412063 100644
--- a/iree/hal/local/task_command_buffer.c
+++ b/iree/hal/local/task_command_buffer.c
@@ -820,32 +820,41 @@
(const iree_hal_cmd_dispatch_t*)user_context;
IREE_TRACE_ZONE_BEGIN(z0);
- iree_hal_executable_dispatch_state_v0_t state;
- memset(&state, 0, sizeof(state));
- memcpy(state.workgroup_count.value, tile_context->workgroup_count,
- sizeof(state.workgroup_count));
- memcpy(state.workgroup_size.value, tile_context->workgroup_size,
- sizeof(state.workgroup_size));
-
+ // We could share this across all workgroups in a dispatch and reduce cache
+ // pressure as all cores would be hitting the same hot read-only cache line.
+ // It'd grow the size of iree_hal_cmd_dispatch_t by a few dozen bytes, though,
+ // and so we'd need some profiling to see if it's worth it (fixed command
+ // buffer cost vs potential for saving a cache miss or two).
+ iree_alignas(64) iree_hal_executable_dispatch_state_v0_t dispatch_state = {
+ .workgroup_size_x = tile_context->workgroup_size[0],
+ .workgroup_size_y = tile_context->workgroup_size[1],
+ .workgroup_size_z = tile_context->workgroup_size[2],
+ .push_constant_count = cmd->push_constant_count,
+ .workgroup_count_x = tile_context->workgroup_count[0],
+ .workgroup_count_y = tile_context->workgroup_count[1],
+ .workgroup_count_z = tile_context->workgroup_count[2],
+ .binding_count = cmd->binding_count,
+ };
uint8_t* cmd_ptr = (uint8_t*)cmd + sizeof(*cmd);
+ dispatch_state.push_constants = (uint32_t*)cmd_ptr;
+ cmd_ptr += cmd->push_constant_count * sizeof(*dispatch_state.push_constants);
+ dispatch_state.binding_ptrs = (void**)cmd_ptr;
+ cmd_ptr += cmd->binding_count * sizeof(*dispatch_state.binding_ptrs);
+ dispatch_state.binding_lengths = (size_t*)cmd_ptr;
+ cmd_ptr += cmd->binding_count * sizeof(*dispatch_state.binding_lengths);
- state.push_constant_count = cmd->push_constant_count;
- state.push_constants = (uint32_t*)cmd_ptr;
- cmd_ptr += cmd->push_constant_count * sizeof(*state.push_constants);
-
- state.binding_count = cmd->binding_count;
- state.binding_ptrs = (void**)cmd_ptr;
- cmd_ptr += cmd->binding_count * sizeof(*state.binding_ptrs);
- state.binding_lengths = (size_t*)cmd_ptr;
- cmd_ptr += cmd->binding_count * sizeof(*state.binding_lengths);
-
- state.processor_id = tile_context->processor_id;
- state.environment = &cmd->executable->environment;
-
+ const iree_alignas(64)
+ iree_hal_executable_workgroup_state_v0_t workgroup_state = {
+ .workgroup_id_x = tile_context->workgroup_xyz[0],
+ .workgroup_id_y = tile_context->workgroup_xyz[1],
+ .workgroup_id_z = tile_context->workgroup_xyz[2],
+ .reserved = 0,
+ .processor_id = tile_context->processor_id,
+ .local_memory = tile_context->local_memory.data,
+ .local_memory_size = (uint32_t)tile_context->local_memory.data_length,
+ };
iree_status_t status = iree_hal_local_executable_issue_call(
- cmd->executable, cmd->ordinal, &state,
- (const iree_hal_vec3_t*)tile_context->workgroup_xyz,
- tile_context->local_memory);
+ cmd->executable, cmd->ordinal, &dispatch_state, &workgroup_state);
IREE_TRACE_ZONE_END(z0);
return status;
diff --git a/iree/task/task.h b/iree/task/task.h
index 9ae2290..aeef180 100644
--- a/iree/task/task.h
+++ b/iree/task/task.h
@@ -532,6 +532,10 @@
// TODO(benvanik): workgroup index to amortize calculating linear offsets.
// (like gl_GlobalInvocationID)
+ // Opaque ID of the processor executing the tile.
+ // May be slightly out of date or 0 if the processor could not be queried.
+ iree_cpu_processor_id_t processor_id;
+
// Tile-local memory that is pinned to each worker ensuring no cache
// thrashing. Aligned to at least the natural pointer size of the machine.
// Contents are (today) undefined upon entry.
@@ -539,10 +543,6 @@
// Shared statistics counters for the dispatch shard.
iree_task_dispatch_statistics_t* statistics;
-
- // Opaque ID of the processor executing the tile.
- // May be slightly out of date or 0 if the processor could not be queried.
- iree_cpu_processor_id_t processor_id;
} iree_task_tile_context_t;
typedef struct iree_task_dispatch_t iree_task_dispatch_t;