Merging hal.device.switch feature branch. Adding a new `hal.device.switch` op that allows us to nicely generate backend/device-specific logic that can be transformed into efficient code for runtime. There are a lot of folders needed to help here; however, this initial structure will let us begin work on #1168 as it exposes the ability to generate device-specific dispatch ops. Future iterations will add more device matchers to allow more fine-grained codegen. For example, matchers that look for supported device-specific extensions, minimum allocation alignment, unified vs. discrete memory, maximum descriptor set count, etc. can all be added in a generic way and efficiently lowered to runtime code. Progress towards heterogeneous execution will cause the `!hal.device` type to change to a `hal.device<ordinal>` form, allowing better compile-time evaluation of the `hal.device.switch` blocks when placement happens. For now we still only have a single runtime-selected `hal.ex.shared_device` so this cannot happen. Closes https://github.com/google/iree/pull/1393 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/1393 from google:benvanik-hal-device-switch d0348ef32a6af9940d4f9243a1cce88921bd398e PiperOrigin-RevId: 306259903
diff --git a/iree/base/api.cc b/iree/base/api.cc index 28ced9d..c62fcce 100644 --- a/iree/base/api.cc +++ b/iree/base/api.cc
@@ -199,6 +199,53 @@
   return offset;
 }
 
+// Returns true if |value| matches the wildcard |pattern| in its entirety.
+// '*' matches zero or more characters and '?' matches exactly one character;
+// all other pattern characters must match literally.
+//
+// Implemented as an iterative scan that backtracks only to the most recent
+// '*': no recursion, and no out-of-range substr() calls when |value| runs out
+// before a '?' in |pattern| (e.g. value "foo" with pattern "foo?").
+static bool MatchPattern(absl::string_view value, absl::string_view pattern) {
+  size_t value_index = 0;
+  size_t pattern_index = 0;
+  // Position just after the most recent '*' in |pattern| and the position in
+  // |value| up to which that '*' is currently assumed to extend.
+  bool has_star = false;
+  size_t star_pattern_index = 0;
+  size_t star_value_index = 0;
+  while (value_index < value.size()) {
+    const bool has_pattern = pattern_index < pattern.size();
+    if (has_pattern && pattern[pattern_index] == '*') {
+      // Record the wildcard and first try matching it against nothing.
+      has_star = true;
+      star_pattern_index = ++pattern_index;
+      star_value_index = value_index;
+    } else if (has_pattern && (pattern[pattern_index] == '?' ||
+                               pattern[pattern_index] == value[value_index])) {
+      ++pattern_index;
+      ++value_index;
+    } else if (has_star) {
+      // Mismatch: let the last '*' absorb one more character and retry.
+      pattern_index = star_pattern_index;
+      value_index = ++star_value_index;
+    } else {
+      return false;
+    }
+  }
+  // |value| is exhausted; only trailing '*' wildcards may remain.
+  while (pattern_index < pattern.size() && pattern[pattern_index] == '*') {
+    ++pattern_index;
+  }
+  return pattern_index == pattern.size();
+}
+
+IREE_API_EXPORT bool IREE_API_CALL iree_string_view_match_pattern(
+    iree_string_view_t value, iree_string_view_t pattern) {
+  return MatchPattern(absl::string_view(value.data, value.size),
+                      absl::string_view(pattern.data, pattern.size));
+}
+
 //===----------------------------------------------------------------------===//
 // iree::FileMapping
 //===----------------------------------------------------------------------===//
diff --git a/iree/base/api.h b/iree/base/api.h index a35bb4d..3d19268 100644 --- a/iree/base/api.h +++ b/iree/base/api.h
@@ -402,6 +402,17 @@ iree_string_view_t value, char split_char, iree_string_view_t* out_lhs, iree_string_view_t* out_rhs); +// Returns true if the given |value| matches |pattern| (normal * and ? rules). +// This accepts wildcards in the form of '*' and '?' for any delimited value. +// '*' will match zero or more of any character and '?' will match exactly one +// of any character. +// +// For example, +// 'foo-*-bar' matches: 'foo-123-bar', 'foo-456-789-bar' +// 'foo-10?' matches: 'foo-101', 'foo-102' +IREE_API_EXPORT bool IREE_API_CALL iree_string_view_match_pattern( + iree_string_view_t value, iree_string_view_t pattern); + #endif // IREE_API_NO_PROTOTYPES //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp index 0347ae3..2bfa8ee 100644 --- a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp +++ b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
@@ -111,7 +111,7 @@ } else if (isa<CallIndirectOp>(op)) { // Indirect calls are not supported and must first be devirtualized. return false; - } else if (isa<ReturnOp>(op)) { + } else if (isa<mlir::ReturnOp>(op)) { // TODO(benvanik): widen to all known terminators? sometimes they may // have side-effects. continue;
diff --git a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp index 4dc0cfe..11da5c9 100644 --- a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
@@ -126,7 +126,7 @@ return false; } -bool convertReturnOp(ReturnOp *op, OpBuilder &builder, +bool convertReturnOp(mlir::ReturnOp *op, OpBuilder &builder, BlockAndValueMapping *mapping) { llvm::SmallVector<Value, 10> newOperands; if (untupleAndLookupValues(op->getOperands(), &newOperands, builder, @@ -134,7 +134,7 @@ return true; } - builder.create<ReturnOp>(op->getLoc(), newOperands); + builder.create<mlir::ReturnOp>(op->getLoc(), newOperands); return false; } @@ -226,7 +226,7 @@ bool convertOperation(Operation *op, OpBuilder &builder, BlockAndValueMapping *mapping) { - if (auto returnOp = dyn_cast<ReturnOp>(op)) { + if (auto returnOp = dyn_cast<mlir::ReturnOp>(op)) { return convertReturnOp(&returnOp, builder, mapping); } else if (auto callOp = dyn_cast<CallOp>(op)) { return convertCallOp(&callOp, builder, mapping);
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp index 0dab859..4251d12 100644 --- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -39,7 +39,7 @@ namespace { // Replaces |returnOp| with a clone including |newOperands| appended. -LogicalResult appendReturnOperands(ReturnOp returnOp, +LogicalResult appendReturnOperands(IREE::Flow::ReturnOp returnOp, ArrayRef<Value> newOperands) { // Insert prior to the original return. OpBuilder builder(returnOp); @@ -49,7 +49,7 @@ operands.reserve(returnOp.getNumOperands() + newOperands.size()); operands.append(returnOp.operand_begin(), returnOp.operand_end()); operands.append(newOperands.begin(), newOperands.end()); - builder.create<ReturnOp>(returnOp.getLoc(), operands); + builder.create<IREE::Flow::ReturnOp>(returnOp.getLoc(), operands); // Remove original. returnOp.erase(); @@ -100,7 +100,7 @@ DispatchRegionOp removeUnusedResults(DispatchRegionOp regionOp) { // Find return value within the region. auto ®ionBlock = regionOp.body().getBlocks().front(); - auto returnOp = dyn_cast<ReturnOp>(regionBlock.getTerminator()); + auto returnOp = dyn_cast<IREE::Flow::ReturnOp>(regionBlock.getTerminator()); if (!returnOp) { regionBlock.getParent()->getParentOfType<FuncOp>().emitError() << "block does not contain an flow.return op"; @@ -213,7 +213,7 @@ // Find the values used as return values in the lhs. // We'll need to replace the uses in rhs with these. - auto lhsReturnOp = cast<ReturnOp>(lhsBlock.getTerminator()); + auto lhsReturnOp = cast<IREE::Flow::ReturnOp>(lhsBlock.getTerminator()); SmallVector<Value, 8> lhsReturnValues; lhsReturnValues.reserve(lhsReturnOp.getNumOperands()); lhsReturnValues.append(lhsReturnOp.operand_begin(), @@ -221,7 +221,7 @@ // Find the values used as return values in the rhs. // We'll add these to the results of the lhs region. - auto rhsReturnOp = cast<ReturnOp>(rhsBlock.getTerminator()); + auto rhsReturnOp = cast<IREE::Flow::ReturnOp>(rhsBlock.getTerminator()); SmallVector<Value, 8> rhsReturnValues; rhsReturnValues.reserve(rhsReturnOp.getNumOperands()); rhsReturnValues.append(rhsReturnOp.operand_begin(),
diff --git a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp index 592d145..40addc2 100644 --- a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp
@@ -275,7 +275,7 @@ for (auto *op : streamOps) { fragmentBuilder.clone(*op, mapping); } - fragmentBuilder.create<ReturnOp>( + fragmentBuilder.create<IREE::Flow::ReturnOp>( UnknownLoc::get(context), llvm::to_vector<8>(llvm::map_range(fragmentResults, [&](Value value) { return mapping.lookup(value);
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp index d33a735..cc59fb4 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
@@ -49,7 +49,7 @@ // Create const and return ops. auto constValue = rewriter.create<ConstantOp>(loc, immediateElements); - rewriter.create<ReturnOp>(loc, constValue.getResult()); + rewriter.create<mlir::ReturnOp>(loc, constValue.getResult()); return initializerFuncOp; }
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp index 24c4feb..4131dc2 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
@@ -25,6 +25,8 @@ OwningRewritePatternList &patterns) { patterns.insert<VMImportOpConversion<IREE::HAL::DeviceAllocatorOp>>( context, importSymbols, typeConverter, "hal.device.allocator"); + patterns.insert<VMImportOpConversion<IREE::HAL::DeviceMatchIDOp>>( + context, importSymbols, typeConverter, "hal.device.match.id"); } } // namespace iree_compiler
diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td index 04fd282..65b2aac 100644 --- a/iree/compiler/Dialect/HAL/IR/HALBase.td +++ b/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -554,6 +554,36 @@ } //===----------------------------------------------------------------------===// +// Expression matching attributes +//===----------------------------------------------------------------------===// + +def HAL_MatchAlwaysAttr : + IREE_StructAttr<"match.always", "MatchAlwaysAttr", HAL_Dialect, []> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_MatchAnyAttr : + IREE_StructAttr<"match.any", "MatchAnyAttr", HAL_Dialect, [ + IREE_StructFieldAttr<"conditions", AnyAttr>, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_MatchAllAttr : + IREE_StructAttr<"match.all", "MatchAllAttr", HAL_Dialect, [ + IREE_StructFieldAttr<"conditions", AnyAttr>, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +def HAL_DeviceMatchIDAttr : + IREE_StructAttr<"device.match.id", "DeviceMatchIDAttr", HAL_Dialect, [ + IREE_StructFieldAttr<"pattern", StrAttr>, + ]> { + let cppNamespace = "mlir::iree_compiler::IREE::HAL"; +} + +//===----------------------------------------------------------------------===// // Base HAL op classes //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp index 8bb4b14..66adfec 100644 --- a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -20,6 +20,7 @@ #include "iree/compiler/Dialect/HAL/hal.imports.h" #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" #include "llvm/Support/SourceMgr.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Parser.h" @@ -65,7 +66,8 @@ : Dialect(getDialectNamespace(), context) { addInterfaces<HALToVMConversionInterface>(); - addAttributes<DescriptorSetLayoutBindingAttr>(); + addAttributes<DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr, + MatchAllAttr, DeviceMatchIDAttr>(); addTypes<AllocatorType, BufferType, BufferViewType, CommandBufferType, DescriptorSetType, DescriptorSetLayoutType, DeviceType, EventType, @@ -88,6 +90,14 @@ if (failed(parser.parseKeyword(&attrKind))) return {}; if (attrKind == DescriptorSetLayoutBindingAttr::getKindName()) { return DescriptorSetLayoutBindingAttr::parse(parser); + } else if (attrKind == MatchAlwaysAttr::getKindName()) { + return MatchAlwaysAttr::parse(parser); + } else if (attrKind == MatchAnyAttr::getKindName()) { + return MatchAnyAttr::parse(parser); + } else if (attrKind == MatchAllAttr::getKindName()) { + return MatchAllAttr::parse(parser); + } else if (attrKind == DeviceMatchIDAttr::getKindName()) { + return DeviceMatchIDAttr::parse(parser); } parser.emitError(parser.getNameLoc()) << "unknown HAL attribute: " << attrKind; @@ -95,11 +105,12 @@ } void HALDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const { - if (auto typedAttr = attr.dyn_cast<DescriptorSetLayoutBindingAttr>()) { - typedAttr.print(p); - } else { - llvm_unreachable("unhandled HAL attribute kind"); - } + TypeSwitch<Attribute>(attr) + .Case<DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr, + MatchAllAttr, DeviceMatchIDAttr>( + [&](auto typedAttr) { typedAttr.print(p); }) + .Default( + [](Attribute) { llvm_unreachable("unhandled HAL attribute kind"); }); } 
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 7477df5..efdd27a 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -255,6 +255,22 @@ results.insert<SkipBufferViewBufferOp>(context); } +//===----------------------------------------------------------------------===// +// hal.device.switch +//===----------------------------------------------------------------------===// + +// TODO(benvanik): fold conditions with the same IR tree. +// TODO(benvanik): remove duplicate conditions. +// TODO(benvanik): fold condition expressions (any(always, ...) -> always, etc). +// TODO(benvanik): completely replace switches with just one always block. +// TODO(benvanik): remove conditions with no side-effects. + +//===----------------------------------------------------------------------===// +// hal.device.match.id +//===----------------------------------------------------------------------===// + +// TODO(benvanik): fold matches that are known true based on device config. + } // namespace HAL } // namespace IREE } // namespace iree_compiler
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 9071527..68d595d 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -962,6 +962,176 @@ } //===----------------------------------------------------------------------===// +// hal.device.switch +//===----------------------------------------------------------------------===// + +void DeviceSwitchOp::build(Builder *builder, OperationState &state, + TypeRange resultTypes, Value device, + ArrayRef<Attribute> conditions, + ArrayRef<ValueRange> conditionArgs, + ArrayRef<NamedAttribute> attributes) { + state.addOperands({device}); + state.addAttribute("conditions", builder->getArrayAttr(conditions)); + for (auto args : conditionArgs) { + state.addOperands(args); + auto *region = state.addRegion(); + auto *entryBlock = OpBuilder(region).createBlock(region); + for (auto arg : args) { + entryBlock->addArgument(arg.getType()); + } + } + state.addTypes(resultTypes); + state.addAttributes(attributes); + state.resizableOperandList = true; +} + +static ParseResult parseDeviceSwitchOp(OpAsmParser &parser, + OperationState *result) { + OpAsmParser::OperandType device; + Type deviceType; + if (failed(parser.parseLParen()) || failed(parser.parseOperand(device)) || + failed(parser.parseColonType(deviceType)) || + failed(parser.resolveOperand(device, deviceType, result->operands)) || + failed(parser.parseRParen()) || + failed(parser.parseOptionalArrowTypeList(result->types))) { + return failure(); + } + + // Parses each switch condition attribute and region, like: + // #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) { + // hal.return %c1a : i32 + // }, ... 
+ result->setOperandListToResizable(); + SmallVector<Attribute, 4> conditionAttrs; + do { + Attribute conditionAttr; + SmallVector<NamedAttribute, 1> dummyAttrs; + if (failed(parser.parseAttribute(conditionAttr, "condition", dummyAttrs)) || + failed(parser.parseLParen())) { + return failure(); + } + conditionAttrs.push_back(conditionAttr); + SmallVector<OpAsmParser::OperandType, 16> regionArgs; + SmallVector<Type, 16> regionArgTypes; + if (failed(parser.parseOptionalRParen())) { + SmallVector<OpAsmParser::OperandType, 16> regionOperands; + auto argsLoc = parser.getCurrentLocation(); + do { + // Reserve entries in the lists. + regionArgs.emplace_back(); + regionOperands.emplace_back(); + regionArgTypes.emplace_back(); + if (failed(parser.parseRegionArgument(regionArgs.back())) || + failed(parser.parseEqual()) || + failed(parser.parseOperand(regionOperands.back())) || + failed(parser.parseColonType(regionArgTypes.back()))) { + return failure(); + } + } while (succeeded(parser.parseOptionalComma())); + if (failed(parser.parseRParen()) || + failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc, + result->operands))) { + return failure(); + } + } + auto *regionBody = result->addRegion(); + if (failed(parser.parseRegion(*regionBody, regionArgs, regionArgTypes))) { + return failure(); + } + } while (succeeded(parser.parseOptionalComma())); + result->addAttribute("conditions", + ArrayAttr::get(conditionAttrs, result->getContext())); + + if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { + return failure(); + } + return success(); +} + +static void printDeviceSwitchOp(OpAsmPrinter &p, DeviceSwitchOp op) { + p << op.getOperationName() << "("; + p.printOperand(op.device()); + p << " : "; + p.printType(op.device().getType()); + p << ")"; + p.printOptionalArrowTypeList(op.getResultTypes()); + p << "\n"; + p.getStream().indent(4); + int argOffset = 0; + interleave( + llvm::zip(op.conditions(), op.condition_regions()), + 
[&](std::tuple<Attribute, Region &> it) { + auto &conditionAttr = std::get<0>(it); + auto &conditionRegion = std::get<1>(it); + p.printAttribute(conditionAttr); + p << "("; + auto regionOperands = conditionRegion.front().getArguments(); + auto regionArgs = op.args().slice(argOffset, regionOperands.size()); + argOffset += regionOperands.size(); + // TODO(benvanik): figure out how to parse with shadowing. + // p.shadowRegionArgs(conditionRegion, regionArgs); + interleaveComma(llvm::zip(regionOperands, regionArgs), p, + [&](std::tuple<BlockArgument, Value> it) { + p << std::get<0>(it) << " = " << std::get<1>(it); + p << " : "; + p << std::get<1>(it).getType(); + }); + p << ")"; + p.printRegion(conditionRegion, + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + }, + [&]() { + p << ",\n"; + p.getStream().indent(4); + }); + p.printOptionalAttrDictWithKeyword(op.getAttrs(), + /*elidedAttrs=*/{"conditions"}); +} + +static LogicalResult verifyDeviceSwitchOp(DeviceSwitchOp op) { + if (op.conditions().size() != op.condition_regions().size()) { + return op.emitOpError() << "requires conditions and regions be matched 1:1"; + } else if (op.condition_regions().empty()) { + return op.emitOpError() << "requires at least one condition"; + } + int argOffset = 0; + for (auto ®ion : op.condition_regions()) { + auto regionOperands = region.front().getArguments(); + auto regionArgs = op.args().slice(argOffset, regionOperands.size()); + argOffset += regionOperands.size(); + + for (auto it : llvm::zip(regionArgs, regionOperands)) { + auto regionArg = std::get<0>(it); + auto regionOperand = std::get<1>(it); + if (regionArg.getType() != regionOperand.getType()) { + return op.emitOpError() << "requires that regions have their arguments " + "represented in the op arg list in order (" + << regionArg.getType() + << " != " << regionOperand.getType() << ")"; + } + } + + for (auto &block : region) { + if (auto returnOp = + 
dyn_cast_or_null<IREE::HAL::ReturnOp>(block.getTerminator())) { + if (!std::equal(returnOp.getOperandTypes().begin(), + returnOp.getOperandTypes().end(), + op.getResultTypes().begin())) { + return op.emitOpError() + << "requires all regions return the same types"; + } + } + } + } + if (argOffset != op.args().size()) { + return op.emitOpError() << "requires that the total argument list matches " + "the sum of all region operands"; + } + return success(); +} + +//===----------------------------------------------------------------------===// // hal.executable //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td index 26ae91d..2e9f148 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1317,7 +1317,137 @@ ]; } +def HAL_DeviceSwitchOp : HAL_Op<"device.switch", [IsolatedFromAbove]> { + let summary = [{runtime device switch pseudo op}]; + let description = [{ + Switches between multiple regions based on the runtime device type. + The provided regions are matched against the runtime backend of the given + device and executed only when the device matches the conditions. + + Conditions can match on wildcards and be folded to enable conditions that + have similar bodies to be folded. The patterns themselves are only matched + once at startup and then the results are cached; the runtime overhead is + equivalent to a normal switch statement. In cases where the compiler can + statically identify the device type entire cases can be folded away. + + Supported conditions: + * `#hal.match...`: execute the region if the expression matches. + + Supported match expressions: + * `#hal.match.always`: always matches; useful for defaults. + * `#hal.match.any<[...]>`: matches if any of the nested expressions match. + * `#hal.match.all<[...]>`: matches only if all of the nested expressions + match. + * `#hal.device.match.id<"pattern*-?-*">`: matches against the device + identifier. The pattern is evaluated with standard file path wildcards + (`*` for zero or more characters and `?` for one character). + + If more than one condition is satisfied the first listed will be chosen. + More specific conditions should be earlier in the set. If no condition is + matched but there are return values the switch will abort at runtime. It's + strongly recommended that all switches that return values end with a trailing + `#hal.match.always` condition to handle the fallthrough case. + + Upon creation each condition region will have an empty entry block with the + specified operands available as arguments. Each region must be set up to + return the same types. + + ```mlir + %c0 = constant 0 : i32 + %c1 = constant 1 : i32 + %c2 = constant 2 : i32 + %device = ...
: !hal.device + %0 = hal.device.switch(%device : !hal.device) -> i32 + #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) { + hal.return %c1a : i32 + }, + #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%c2a = %c2 : i32) { + hal.return %c2a : i32 + }, + #hal.match.always(%c0a = %c0 : i32) { + hal.return %c0a : i32 + } + ``` + }]; + + let arguments = (ins + HAL_Device:$device, + ArrayAttr:$conditions, + Variadic<AnyType>:$args + ); + let results = (outs + Variadic<AnyType>:$results + ); + + let regions = (region VariadicRegion<AnyRegion>:$condition_regions); + + let extraClassDeclaration = [{ + /// Returns the index of the args() operand in the Operation operands list. + unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &state, TypeRange resultTypes, + Value device, ArrayRef<Attribute> conditions, + ArrayRef<ValueRange> conditionArgs, + ArrayRef<NamedAttribute> attributes = {} + }]>, + ]; + + let verifier = [{ return verifyDeviceSwitchOp(*this); }]; +} + +def HAL_ReturnOp : HAL_Op<"return", [Terminator]> { + let summary = [{return from a hal.device.switch region}]; + let description = [{ + Returns the given values from the region and back to the host code. + }]; + + let arguments = (ins + Variadic<AnyType>:$operands + ); + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + + let builders = [ + OpBuilder<[{ + Builder *builder, OperationState &result + }], [{ + build(builder, result, llvm::None); + }]>, + ]; +} + // TODO(benvanik): additional factory functions and submission ops. +// TODO(benvanik): %0 = hal.device.query %device, group, property : i32/etc + +def HAL_DeviceMatchIDOp : HAL_PureOp<"device.match.id"> { + let summary = [{returns true if the device ID matches the pattern}]; + let description = [{ + Pattern matches the device ID with the given wildcard pattern. 
+ This can be used to conditionally evaluate device-specific code when the + device is not known at compile-time. + + ```mlir + %is_match = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1 + ``` + }]; + + let arguments = (ins + HAL_Device:$device, + StrAttr:$pattern + ); + let results = (outs + I1:$result + ); + + let assemblyFormat = [{ + $device `,` `pattern` `=` `[` $pattern `]` attr-dict + `:` `(` type($device) `)` `->` type($result) + }]; +} //===----------------------------------------------------------------------===// // iree::hal::Executable
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index 37e00dc..b15fd4b 100644 --- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -134,6 +134,81 @@ os << ">"; } +// static +Attribute MatchAlwaysAttr::parse(DialectAsmParser &p) { + return get(p.getBuilder().getContext()); +} + +void MatchAlwaysAttr::print(DialectAsmPrinter &p) const { + auto &os = p.getStream(); + os << getKindName(); +} + +static ArrayAttr parseMultiMatchAttrArray(DialectAsmParser &p) { + auto b = p.getBuilder(); + SmallVector<Attribute, 4> conditionAttrs; + if (failed(p.parseLess()) || failed(p.parseLSquare())) { + return {}; + } + do { + Attribute conditionAttr; + if (failed(p.parseAttribute(conditionAttr))) { + return {}; + } + conditionAttrs.push_back(conditionAttr); + } while (succeeded(p.parseOptionalComma())); + if (failed(p.parseRSquare()) || failed(p.parseGreater())) { + return {}; + } + return b.getArrayAttr(conditionAttrs); +} + +static void printMultiMatchAttrList(ArrayAttr conditionAttrs, + DialectAsmPrinter &p) { + auto &os = p.getStream(); + os << "<["; + interleaveComma(conditionAttrs, os, + [&](Attribute condition) { os << condition; }); + os << "]>"; +} + +// static +Attribute MatchAnyAttr::parse(DialectAsmParser &p) { + return get(parseMultiMatchAttrArray(p)); +} + +void MatchAnyAttr::print(DialectAsmPrinter &p) const { + p << getKindName(); + printMultiMatchAttrList(conditions().cast<ArrayAttr>(), p); +} + +// static +Attribute MatchAllAttr::parse(DialectAsmParser &p) { + return get(parseMultiMatchAttrArray(p)); +} + +void MatchAllAttr::print(DialectAsmPrinter &p) const { + p << getKindName(); + printMultiMatchAttrList(conditions().cast<ArrayAttr>(), p); +} + +// static +Attribute DeviceMatchIDAttr::parse(DialectAsmParser &p) { + StringAttr patternAttr; + if (failed(p.parseLess()) || failed(p.parseAttribute(patternAttr)) || + failed(p.parseGreater())) { + return {}; + } + return get(patternAttr); +} + +void DeviceMatchIDAttr::print(DialectAsmPrinter &p) const { + auto &os = p.getStream(); + os << getKindName() << "<\""; + os << pattern(); + os << "\">"; +} + #include 
"iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc" } // namespace HAL
diff --git a/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir index d40e27f..932f6f7 100644 --- a/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
@@ -1,6 +1,4 @@ -// Tests printing and parsing of hal.device ops. - -// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s +// RUN: iree-opt -print-ir-after-all -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @device_allocator func @device_allocator() -> !hal.allocator { @@ -9,3 +7,48 @@ %allocator = hal.device.allocator %0 : !hal.allocator return %allocator : !hal.allocator } + +// ----- + +// CHECK-LABEL: @device_switch +func @device_switch() -> i32 { + // CHECK-DAG: %[[C0:.+]] = constant 0 + %c0 = constant 0 : i32 + // CHECK-DAG: %[[C1:.+]] = constant 1 + %c1 = constant 1 : i32 + // CHECK-DAG: %[[C2:.+]] = constant 2 + %c2 = constant 2 : i32 + // CHECK-DAG: %[[DEVICE:.+]] = "test_hal.device" + %device = "test_hal.device"() : () -> !hal.device + // CHECK: = hal.device.switch(%[[DEVICE]] : !hal.device) -> i32 + %0 = hal.device.switch(%device : !hal.device) -> i32 + // CHECK-NEXT: #hal.device.match.id<"vulkan-v1.?-*">(%[[C1A:.+]] = %[[C1]] : i32) { + #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) { + // CHECK-NEXT: hal.return %[[C1A]] : i32 + hal.return %c1a : i32 + // CHECK-NEXT: }, + }, + // CHECK-NEXT: #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%[[C2A:.+]] = %[[C2]] : i32) { + #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%c2a = %c2 : i32) { + // CHECK-NEXT: hal.return %[[C2A]] : i32 + hal.return %c2a : i32 + // CHECK-NEXT: }, + }, + // CHECK-NEXT: #hal.match.always(%[[C0A:.+]] = %[[C0]] : i32) { + #hal.match.always(%c0a = %c0 : i32) { + // CHECK-NEXT: hal.return %[[C0A]] : i32 + hal.return %c0a : i32 + // CHECK-NEXT: } + } + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: @device_matchers +// CHECK-SAME: %[[DEVICE:.+]]: !hal.device +func @device_matchers(%device : !hal.device) -> i1 { + // CHECK: = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-*"] : (!hal.device) -> i1 + %0 = hal.device.match.id %device, 
pattern = ["vulkan-*"] : (!hal.device) -> i1 + return %0 : i1 +}
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD index 325aedf..5da3750 100644 --- a/iree/compiler/Dialect/HAL/Transforms/BUILD +++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -22,6 +22,8 @@ srcs = [ "MaterializeInterfaces.cpp", "MaterializeResourceCaches.cpp", + "MemoizeDeviceQueries.cpp", + "OutlineDeviceSwitches.cpp", "Passes.cpp", "PublicAbiGeneration.cpp", "RewriteLegacyIO.cpp",
diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index ddd3cb3..9b91962 100644 --- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -20,8 +20,10 @@ HDRS "Passes.h" SRCS + "MemoizeDeviceQueries.cpp" "MaterializeInterfaces.cpp" "MaterializeResourceCaches.cpp" + "OutlineDeviceSwitches.cpp" "Passes.cpp" "PublicAbiGeneration.cpp" "RewriteLegacyIO.cpp"
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index 4ffb8df..caae5e5 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -134,7 +134,7 @@ auto layoutUsage = IREE::HAL::DescriptorSetLayoutUsageType::PushOnly; auto layoutValue = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>( loc, layoutType, deviceValue, layoutUsage, bindingsAttr); - blockBuilder.create<ReturnOp>(loc, layoutValue); + blockBuilder.create<mlir::ReturnOp>(loc, layoutValue); return variableOp; } @@ -189,7 +189,7 @@ } auto layoutValue = blockBuilder.createOrFold<ExecutableLayoutCreateOp>( loc, layoutType, deviceValue, setLayoutValues, pushConstantsAttr); - blockBuilder.create<ReturnOp>(loc, layoutValue); + blockBuilder.create<mlir::ReturnOp>(loc, layoutValue); return variableOp; } @@ -250,7 +250,7 @@ executableVariableOp.sym_name()); } - blockBuilder.create<ReturnOp>(loc, executableCacheValue); + blockBuilder.create<mlir::ReturnOp>(loc, executableCacheValue); return variableOp; }
diff --git a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp new file mode 100644 index 0000000..83b1884 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -0,0 +1,132 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "llvm/ADT/StringSet.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Replaces hal.device.match.id ops with loads of module-level variables that
+// cache the result of the match for each distinct pattern.
+//
+// For each distinct pattern a private immutable `_device_match_id_<N>` variable
+// is created along with a private `_device_match_id_<N>_initializer` function
+// that performs the match once against the shared device.
+//
+// NOTE: this implementation is just for a single active device. As we start to
+// support multiple devices we'll need to change this to be per-device.
+class MemoizeDeviceQueriesPass
+    : public PassWrapper<MemoizeDeviceQueriesPass, OperationPass<ModuleOp>> {
+ public:
+  void runOnOperation() override {
+    // Find all match ops we want to memoize and group them together.
+    // This lets us easily replace all usages of a match with a single variable.
+    // Keys are also recorded in discovery order so that the generated variable
+    // names are deterministic.
+    DenseMap<Attribute, std::vector<IREE::HAL::DeviceMatchIDOp>>
+        deviceIDMatchOps;
+    SmallVector<Attribute, 4> deviceIDMatchKeys;
+    auto moduleOp = getOperation();
+    for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+      funcOp.walk([&](IREE::HAL::DeviceMatchIDOp matchOp) {
+        auto key = matchOp.patternAttr().cast<Attribute>();
+        auto lookup = deviceIDMatchOps.try_emplace(
+            key, std::vector<IREE::HAL::DeviceMatchIDOp>{});
+        if (lookup.second) {
+          deviceIDMatchKeys.push_back(key);
+        }
+        lookup.first->second.push_back(matchOp);
+        return WalkResult::advance();
+      });
+    }
+
+    // Create each match variable and replace the uses with loads.
+    auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
+    for (auto matchKey : llvm::enumerate(deviceIDMatchKeys)) {
+      // Bind by reference: the map is not modified while we iterate and there
+      // is no need to copy the op list for every key.
+      auto &matchOps = deviceIDMatchOps[matchKey.value()];
+      auto pattern = matchOps.front().pattern();
+
+      // Merge all the locs as we are deduping the original query ops.
+      auto fusedLoc = FusedLoc::get(
+          llvm::to_vector<4>(llvm::map_range(
+              matchOps, [&](Operation *op) { return op->getLoc(); })),
+          moduleOp.getContext());
+
+      // The initializer will perform the query once and store it in the
+      // variable.
+      std::string variableName =
+          "_device_match_id_" + std::to_string(matchKey.index());
+      auto initializerOp = moduleBuilder.create<FuncOp>(
+          fusedLoc, variableName + "_initializer",
+          moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}),
+          ArrayRef<NamedAttribute>{});
+      SymbolTable::setSymbolVisibility(initializerOp,
+                                       SymbolTable::Visibility::Private);
+      auto variableOp = moduleBuilder.create<IREE::HAL::VariableOp>(
+          fusedLoc, variableName,
+          /*isMutable=*/false, initializerOp);
+      SymbolTable::setSymbolVisibility(variableOp,
+                                       SymbolTable::Visibility::Private);
+
+      auto funcBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
+      auto device =
+          funcBuilder.createOrFold<IREE::HAL::ExSharedDeviceOp>(fusedLoc);
+      auto queryOp = funcBuilder.create<IREE::HAL::DeviceMatchIDOp>(
+          fusedLoc, funcBuilder.getI1Type(), device, pattern);
+      funcBuilder.create<mlir::ReturnOp>(fusedLoc, queryOp.getResult());
+
+      // Replace each original op with a load of the memoized value. Each load
+      // keeps the location of the op it replaces rather than the fused
+      // location of all the duplicates.
+      for (auto matchOp : matchOps) {
+        OpBuilder replaceBuilder(matchOp);
+        auto loadOp = replaceBuilder.create<IREE::HAL::VariableLoadOp>(
+            matchOp.getLoc(), matchOp.getResult().getType(),
+            variableOp.getName());
+        matchOp.replaceAllUsesWith(loadOp.result());
+        matchOp.erase();
+      }
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createMemoizeDeviceQueriesPass() {
+  return std::make_unique<MemoizeDeviceQueriesPass>();
+}
+
+static PassRegistration<MemoizeDeviceQueriesPass> pass(
+    "iree-hal-memoize-device-queries",
+    "Caches hal.device.query results for use across the entire module");
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/OutlineDeviceSwitches.cpp b/iree/compiler/Dialect/HAL/Transforms/OutlineDeviceSwitches.cpp new file mode 100644 index 0000000..c4f4211 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Transforms/OutlineDeviceSwitches.cpp
@@ -0,0 +1,271 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "llvm/ADT/StringSet.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Forward declaration: compound conditions recurse into their children.
+static Value buildConditionExpression(Location loc, Value device,
+                                      Attribute conditionAttr,
+                                      OpBuilder funcBuilder);
+
+// Builds the expressions for all |conditionAttrs| in order and reduces them
+// into a single i1 value using |CombineOp| (OrOp or AndOp).
+//
+// An empty list has nothing to reduce so the constant |emptyValue| is returned
+// instead (false for "any of nothing", true for "all of nothing").
+template <typename CombineOp>
+static Value buildCombinedConditionExpression(Location loc, Value device,
+                                              ArrayAttr conditionAttrs,
+                                              bool emptyValue,
+                                              OpBuilder funcBuilder) {
+  if (conditionAttrs.getValue().empty()) {
+    return funcBuilder.createOrFold<ConstantIntOp>(loc, emptyValue ? 1 : 0, 1);
+  }
+  SmallVector<Value, 4> conditionValues;
+  for (Attribute attr : conditionAttrs.getValue()) {
+    conditionValues.push_back(
+        buildConditionExpression(loc, device, attr, funcBuilder));
+  }
+  Value resultValue = conditionValues.front();
+  for (size_t i = 1, e = conditionValues.size(); i < e; ++i) {
+    resultValue = funcBuilder.createOrFold<CombineOp>(loc, resultValue,
+                                                      conditionValues[i]);
+  }
+  return resultValue;
+}
+
+// Given a #hal.match.* expression tree returns a boolean value indicating
+// whether the expression evaluates to true.
+static Value buildConditionExpression(Location loc, Value device,
+                                      Attribute conditionAttr,
+                                      OpBuilder funcBuilder) {
+  if (conditionAttr.isa<IREE::HAL::MatchAlwaysAttr>()) {
+    // #hal.match.always -> true
+    return funcBuilder.createOrFold<ConstantIntOp>(loc, 1, 1);
+  } else if (auto matchAttr =
+                 conditionAttr.dyn_cast<IREE::HAL::MatchAnyAttr>()) {
+    // #hal.match.any<[a, b, c]> -> or(or(a, b), c)
+    return buildCombinedConditionExpression<OrOp>(
+        loc, device, matchAttr.conditions().cast<ArrayAttr>(),
+        /*emptyValue=*/false, funcBuilder);
+  } else if (auto matchAttr =
+                 conditionAttr.dyn_cast<IREE::HAL::MatchAllAttr>()) {
+    // #hal.match.all<[a, b, c]> -> and(and(a, b), c)
+    return buildCombinedConditionExpression<AndOp>(
+        loc, device, matchAttr.conditions().cast<ArrayAttr>(),
+        /*emptyValue=*/true, funcBuilder);
+  } else if (auto matchAttr =
+                 conditionAttr.dyn_cast<IREE::HAL::DeviceMatchIDAttr>()) {
+    // #hal.device.match.id<"pattern"> -> hal.device.match.id
+    return funcBuilder.createOrFold<IREE::HAL::DeviceMatchIDOp>(
+        loc, funcBuilder.getI1Type(), device, matchAttr.patternAttr());
+  }
+  llvm_unreachable("unhandled condition expression attribute");
+  return {};
+}
+
+// Outlines a condition region from a switch op into a standalone function.
+//
+// The new private function is created at |moduleBuilder|'s insertion point,
+// takes the region's entry block arguments as its parameters, and returns
+// |resultTypes|. The body of |conditionRegion| is moved (not cloned) into the
+// function, leaving the region empty.
+static FuncOp outlineConditionRegion(StringRef funcName,
+                                     Region &conditionRegion,
+                                     ArrayRef<Type> resultTypes,
+                                     OpBuilder moduleBuilder) {
+  auto &entryBlock = conditionRegion.front();
+  auto funcType = moduleBuilder.getFunctionType(
+      llvm::to_vector<4>(
+          llvm::map_range(entryBlock.getArguments(),
+                          [](BlockArgument arg) { return arg.getType(); })),
+      resultTypes);
+  auto funcOp = moduleBuilder.create<FuncOp>(
+      conditionRegion.getLoc(), funcName, funcType, ArrayRef<NamedAttribute>{});
+  funcOp.getBody().takeBody(conditionRegion);
+  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);
+
+  // Replace hal.return statements with normal std.return. This ensures that
+  // normal matchers/inlining/etc works as we continue transformation.
+  for (auto &block : funcOp.getBlocks()) {
+    if (auto returnOp = dyn_cast<IREE::HAL::ReturnOp>(block.back())) {
+      OpBuilder builder(returnOp);
+      builder.create<mlir::ReturnOp>(
+          returnOp.getLoc(), llvm::to_vector<4>(returnOp.getOperands()));
+      returnOp.erase();
+    }
+  }
+
+  return funcOp;
+}
+
+// Outlines each switch condition region into its own function and replaces the
+// switch op with conditioned calls to those functions.
+//
+// Since switch conditions are evaluated in the order they are defined we can
+// trivially turn the switch into a chain of if-else blocks.
+//   if condition_0_match:
+//     call outlined_condition_0
+//   else
+//     if condition_1_match:
+//       call outlined_condition_1
+//     else ...
+//
+// Outlined functions are named |baseFuncName|_<condition index> and created at
+// |moduleBuilder|'s insertion point; |funcBuilder| must be positioned within
+// the block containing |switchOp|.
+static void buildConditionDispatchTable(IREE::HAL::DeviceSwitchOp switchOp,
+                                        StringRef baseFuncName,
+                                        OpBuilder moduleBuilder,
+                                        OpBuilder funcBuilder) {
+  // Split the block containing the switch op such that all ops before the
+  // switch are before and the switch and the following ops are after.
+  // We'll have all of our outlined regions bounce over to the afterBlock with
+  // the results of the call and use that to replace the switch op.
+  auto *beforeBlock = funcBuilder.getBlock();
+  auto *afterBlock = beforeBlock->splitBlock(switchOp);
+  auto finalValues =
+      llvm::to_vector<4>(afterBlock->addArguments(switchOp.getResultTypes()));
+
+  // Create the blocks we'll use for all our conditions so that we can reference
+  // them when inserting the branch ops.
+  const size_t conditionCount = switchOp.condition_regions().size();
+  SmallVector<Block *, 4> conditionMatchBlocks(conditionCount);
+  SmallVector<Block *, 4> conditionFallthroughBlocks(conditionCount);
+  for (size_t i = 0; i < conditionCount; ++i) {
+    conditionMatchBlocks[i] = funcBuilder.createBlock(afterBlock);
+    conditionFallthroughBlocks[i] = funcBuilder.createBlock(afterBlock);
+  }
+
+  funcBuilder.setInsertionPoint(beforeBlock, beforeBlock->end());
+  int argOffset = 0;
+  for (auto condition : llvm::enumerate(llvm::zip(
+           switchOp.conditions().getValue(), switchOp.condition_regions()))) {
+    auto conditionAttr = std::get<0>(condition.value());
+    auto &conditionRegion = std::get<1>(condition.value());
+
+    // Get the arguments from the switch that we want to carry along in the
+    // block arguments.
+    auto regionOperands = conditionRegion.front().getArguments();
+    auto regionArgs = switchOp.args().slice(argOffset, regionOperands.size());
+    argOffset += regionOperands.size();
+
+    // Outline the region into a function.
+    std::string regionFuncName =
+        (baseFuncName + "_").str() + std::to_string(condition.index());
+    auto regionFuncOp = outlineConditionRegion(
+        regionFuncName, conditionRegion,
+        switchOp.getOperation()->getResultTypes(), moduleBuilder);
+
+    // Insert the branch based on the match. We either match and jump to a block
+    // that will call the function or don't match and need to fall through.
+    auto isMatch = buildConditionExpression(
+        switchOp.getLoc(), switchOp.device(), conditionAttr, funcBuilder);
+    auto *matchBlock = conditionMatchBlocks[condition.index()];
+    auto *fallthroughBlock = conditionFallthroughBlocks[condition.index()];
+    funcBuilder.create<CondBranchOp>(switchOp.getLoc(), isMatch, matchBlock,
+                                     fallthroughBlock);
+
+    // Block that calls the outlined function and then jumps out of the chain.
+    funcBuilder.setInsertionPointToStart(matchBlock);
+    auto matchResults =
+        funcBuilder.create<CallOp>(switchOp.getLoc(), regionFuncOp, regionArgs);
+    funcBuilder.create<BranchOp>(switchOp.getLoc(), afterBlock,
+                                 matchResults.getResults());
+
+    // Block that we enter to check the next condition.
+    funcBuilder.setInsertionPointToStart(fallthroughBlock);
+    if (condition.index() + 1 < conditionFallthroughBlocks.size()) {
+      // Just continue on - the next loop iteration for the following condition
+      // will add its IR to the block.
+    } else {
+      // Fallthrough of all expressions; die if we expected return values.
+      if (switchOp.getNumResults() > 0) {
+        funcBuilder.create<IREE::UnreachableOp>(switchOp.getLoc());
+      } else {
+        funcBuilder.create<BranchOp>(switchOp.getLoc(), afterBlock);
+      }
+    }
+  }
+
+  // Remove the switch op and replace its results with the final joined results.
+  switchOp.replaceAllUsesWith(finalValues);
+  switchOp.erase();
+}
+
+// Outlines the condition regions of every hal.device.switch op found directly
+// within the blocks of each function in the module.
+class OutlineDeviceSwitchesPass
+    : public PassWrapper<OutlineDeviceSwitchesPass, OperationPass<ModuleOp>> {
+ public:
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+    // Snapshot the functions as outlining inserts new functions into the module
+    // while we iterate.
+    auto funcOps = llvm::to_vector<16>(moduleOp.getOps<FuncOp>());
+    for (auto &funcOp : funcOps) {
+      OpBuilder moduleBuilder(funcOp);
+      moduleBuilder.setInsertionPointAfter(funcOp);
+      // Switch ordinals are unique across the whole function (not just within
+      // a single block) so that switches in different blocks never produce
+      // outlined functions with colliding symbol names.
+      unsigned switchOrdinal = 0;
+      for (auto &block : funcOp) {
+        // Snapshot the switch ops as they are erased while being outlined.
+        auto switchOps =
+            llvm::to_vector<4>(block.getOps<IREE::HAL::DeviceSwitchOp>());
+        for (auto switchOp : switchOps) {
+          std::string baseFuncName = (funcOp.getName() + "_switch_").str() +
+                                     std::to_string(switchOrdinal++);
+          OpBuilder funcBuilder(switchOp);
+          buildConditionDispatchTable(switchOp, baseFuncName, moduleBuilder,
+                                      funcBuilder);
+        }
+      }
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineDeviceSwitchesPass() {
+  return std::make_unique<OutlineDeviceSwitchesPass>();
+}
+
+static PassRegistration<OutlineDeviceSwitchesPass> pass(
+    "iree-hal-outline-device-switches",
+    "Outlines hal.device.switch condition regions");
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index d7fc72f..b74558d 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -51,6 +51,10 @@ passManager.addNestedPass<FuncOp>(createCanonicalizerPass()); passManager.addNestedPass<FuncOp>(createCSEPass()); + passManager.addPass(createOutlineDeviceSwitchesPass()); + passManager.addPass(createMemoizeDeviceQueriesPass()); + // TODO(benvanik): function deduplication to remove outlined functions. + // TODO(benvanik): run symbol DCE when all symbols have visibility defined. // Right now the global value initializers don't have proper tracking and if // we do this we lose initializers that have side effects we care about.
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h index f5dc6c1..9db239d 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.h +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -48,6 +48,16 @@
     ExecutableTargetOptions executableOptions);
 
 //===----------------------------------------------------------------------===//
+// Device management
+//===----------------------------------------------------------------------===//
+
+// Outlines hal.device.switch regions into functions called on device match.
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineDeviceSwitchesPass();
+
+// Memoizes hal.device.match.id results in variables initialized on startup.
+std::unique_ptr<OperationPass<ModuleOp>> createMemoizeDeviceQueriesPass();
+
+//===----------------------------------------------------------------------===//
 // Executable translation and optimization
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp b/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp index a03fb3e..c3bdb42 100644 --- a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp
@@ -212,7 +212,7 @@ } // Add the return. - builder.create<ReturnOp>(loc, funcResults); + builder.create<mlir::ReturnOp>(loc, funcResults); return success(); }
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir new file mode 100644 index 0000000..de7f576 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
@@ -0,0 +1,23 @@ +// RUN: iree-opt -split-input-file -iree-hal-memoize-device-queries %s | IreeFileCheck %s + +// CHECK: func @_device_match_id_0_initializer() -> i1 +// CHECK-NEXT: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device +// CHECK-NEXT: %[[IS_MATCH:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 +// CHECK-NEXT: return %[[IS_MATCH]] : i1 +// CHECK: hal.variable @_device_match_id_0 init(@_device_match_id_0_initializer) : i1 + +// CHECK: hal.variable @_device_match_id_1 +// CHECK: hal.variable @_device_match_id_2 + +// CHECK-LABEL: @device_matchers +func @device_matchers(%device : !hal.device) { + // CHECK-NEXT: = hal.variable.load @_device_match_id_0 : i1 + %0 = hal.device.match.id %device, pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 + // CHECK-NEXT: = hal.variable.load @_device_match_id_0 : i1 + %1 = hal.device.match.id %device, pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 + // CHECK-NEXT: = hal.variable.load @_device_match_id_1 : i1 + %2 = hal.device.match.id %device, pattern = ["vulkan-v2.?-*"] : (!hal.device) -> i1 + // CHECK-NEXT: = hal.variable.load @_device_match_id_2 : i1 + %3 = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1 + return +}
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/outline_device_switches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/outline_device_switches.mlir new file mode 100644 index 0000000..c2ae10f --- /dev/null +++ b/iree/compiler/Dialect/HAL/Transforms/test/outline_device_switches.mlir
@@ -0,0 +1,98 @@
+// RUN: iree-opt -split-input-file -iree-hal-outline-device-switches %s | IreeFileCheck %s
+
+// CHECK-LABEL: @simple_constants
+// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
+func @simple_constants(%device : !hal.device) -> i32 {
+  // CHECK-DAG: %[[C0:.+]] = constant 0
+  %c0 = constant 0 : i32
+  // CHECK-DAG: %[[C1:.+]] = constant 1
+  %c1 = constant 1 : i32
+  // CHECK-DAG: %[[C2:.+]] = constant 2
+  %c2 = constant 2 : i32
+  %0 = hal.device.switch(%device : !hal.device) -> i32
+    // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: cond_br %[[IS0]], ^bb1, ^bb2
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT: %[[RES0:.+]] = call @simple_constants_switch_0_0(%[[C1]]) : (i32) -> i32
+    // CHECK-NEXT: br ^bb7(%[[RES0]] : i32)
+    #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) {
+      hal.return %c1a : i32
+    },
+    // CHECK-NEXT: ^bb2:
+    // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vmla"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1:.+]] = or %[[IS1L]], %[[IS1R]] : i1
+    // CHECK-NEXT: cond_br %[[IS1]], ^bb3, ^bb4
+    // CHECK-NEXT: ^bb3:
+    // CHECK-NEXT: %[[RES1:.+]] = call @simple_constants_switch_0_1(%[[C2]]) : (i32) -> i32
+    // CHECK-NEXT: br ^bb7(%[[RES1]] : i32)
+    #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%c2a = %c2 : i32) {
+      hal.return %c2a : i32
+    },
+    // CHECK-NEXT: ^bb4:
+    // CHECK-NEXT: %[[IS2:.+]] = constant 1 : i1
+    // CHECK-NEXT: cond_br %[[IS2]], ^bb5, ^bb6
+    // CHECK-NEXT: ^bb5:
+    // CHECK-NEXT: %[[RES2:.+]] = call @simple_constants_switch_0_2(%[[C0]]) : (i32) -> i32
+    // CHECK-NEXT: br ^bb7(%[[RES2]] : i32)
+    #hal.match.always(%c0a = %c0 : i32) {
+      hal.return %c0a : i32
+    }
+    // CHECK-NEXT: ^bb6:
+    // CHECK-NEXT: iree.unreachable
+    // CHECK-NEXT: ^bb7(%[[RES:.+]]: i32):
+    // CHECK-NEXT: return %[[RES]] : i32
+  return %0 : i32
+}
+
+// CHECK: func @simple_constants_switch_0_0(%[[ARG0:.+]]: i32) -> i32
+// CHECK-NEXT: return %[[ARG0]] : i32
+// CHECK: func @simple_constants_switch_0_1(%[[ARG1:.+]]: i32) -> i32
+// CHECK-NEXT: return %[[ARG1]] : i32
+// CHECK: func @simple_constants_switch_0_2(%[[ARG2:.+]]: i32) -> i32
+// CHECK-NEXT: return %[[ARG2]] : i32
+
+// -----
+
+// CHECK-LABEL: @no_results
+// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
+func @no_results(%device : !hal.device) {
+  hal.device.switch(%device : !hal.device)
+    // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: cond_br %[[IS0]], ^bb1, ^bb2
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT: call @no_results_switch_0_0() : () -> ()
+    // CHECK-NEXT: br ^bb7
+    #hal.device.match.id<"vulkan-v1.?-*">() {
+      hal.return
+    },
+    // CHECK-NEXT: ^bb2:
+    // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vmla"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1:.+]] = or %[[IS1L]], %[[IS1R]] : i1
+    // CHECK-NEXT: cond_br %[[IS1]], ^bb3, ^bb4
+    // CHECK-NEXT: ^bb3:
+    // CHECK-NEXT: call @no_results_switch_0_1() : () -> ()
+    // CHECK-NEXT: br ^bb7
+    #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>() {
+      hal.return
+    },
+    // CHECK-NEXT: ^bb4:
+    // CHECK-NEXT: %[[IS2:.+]] = constant 1 : i1
+    // CHECK-NEXT: cond_br %[[IS2]], ^bb5, ^bb6
+    // CHECK-NEXT: ^bb5:
+    // CHECK-NEXT: call @no_results_switch_0_2() : () -> ()
+    // CHECK-NEXT: br ^bb7
+    #hal.match.always() {
+      hal.return
+    }
+    // CHECK-NEXT: ^bb6:
+    // CHECK-NEXT: br ^bb7
+    // CHECK-NEXT: ^bb7:
+    // CHECK-NEXT: return
+  return
+}
+
+// CHECK: func @no_results_switch_0_0()
+// CHECK: func @no_results_switch_0_1()
+// CHECK: func @no_results_switch_0_2()
diff --git a/iree/compiler/Dialect/HAL/hal.imports.mlir b/iree/compiler/Dialect/HAL/hal.imports.mlir index 559b00f..6ab6153 100644 --- a/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -352,6 +352,13 @@ ) -> !vm.ref<!hal.allocator> attributes {nosideeffects} +// Returns true if the device ID matches the pattern. +vm.import @device.match.id( + %device : !vm.ref<!hal.device>, + %pattern : !vm.ref<!iree.byte_buffer> +) -> i32 +attributes {nosideeffects} + //===----------------------------------------------------------------------===// // iree::hal::ExecutableCache //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/IREE/IR/IREEAttributes.h b/iree/compiler/Dialect/IREE/IR/IREEAttributes.h index 4b2ae13..c4a9043 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEAttributes.h +++ b/iree/compiler/Dialect/IREE/IR/IREEAttributes.h
@@ -32,6 +32,10 @@ namespace AttrKind { enum Kind { DescriptorSetLayoutBindingAttr = IREE::AttrKind::FIRST_HAL_ATTR, + MatchAlwaysAttr, + MatchAllAttr, + MatchAnyAttr, + DeviceMatchIDAttr, }; } // namespace AttrKind } // namespace HAL
diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.td b/iree/compiler/Dialect/IREE/IR/IREEOps.td index d200a80..f307237 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEOps.td +++ b/iree/compiler/Dialect/IREE/IR/IREEOps.td
@@ -90,4 +90,27 @@ let hasCanonicalizer = 1; } +def IREE_UnreachableOp : IREE_Op<"unreachable", [NoSideEffect, Terminator]> { + let summary = [{unreachable assertion op}]; + let description = [{ + Signals to the compiler that the parent block should not be reachable. + This may be converted into a runtime assertion, though ideally they are + stripped during translation. + + ```mlir + ^bb0: + %true = constant 1 : i1 + cond_br %true, ^bb2, ^bb1 + ^bb1: + // Indicates that this branch should never be taken. + iree.unreachable + ^bb2: + ... + + ``` + }]; + + let assemblyFormat = "attr-dict"; +} + #endif // IREE_DIALECT_IREE_OPS
diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h index 9a1629d..e4605cf 100644 --- a/iree/compiler/Dialect/IREE/IR/IREETypes.h +++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h
@@ -81,7 +81,10 @@ namespace Strings { namespace TypeKind { -enum Kind { String = IREE::TypeKind::FIRST_STRING_TYPE, StringTensor }; +enum Kind { + String = IREE::TypeKind::FIRST_STRING_TYPE, + StringTensor, +}; } // namespace TypeKind } // namespace Strings
diff --git a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp index 3734ff5..42bf726 100644 --- a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp +++ b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
@@ -91,12 +91,17 @@ }; static void emitStructClass(const StructAttr &structAttr, raw_ostream &os) { + if (!structAttr.getAllFields().empty()) { + os << formatv(R"( +namespace detail { +struct {0}Storage; +} // namespace detail +)", + structAttr.getStructClassName()); + } os << formatv(R"( // {0} -namespace detail { -struct {1}Storage; -} // namespace detail -class {1} : public mlir::Attribute::AttrBase<{1}, mlir::Attribute, detail::{1}Storage> { +class {1} : public mlir::Attribute::AttrBase<{1}, mlir::Attribute, {3}Storage> { public: using Base::Base; @@ -105,7 +110,10 @@ )", structAttr.getDescription(), structAttr.getStructClassName(), - structAttr.getStructKind()); + structAttr.getStructKind(), + structAttr.getAllFields().empty() + ? "Attribute" + : "detail::" + structAttr.getStructClassName()); if (!structAttr.getAllFields().empty()) { os << " static LogicalResult verifyConstructionInvariants(\n"; @@ -134,12 +142,14 @@ os << ");\n\n"; // Attribute return type constructor (APInt, etc). 
- os << formatv(" static {0} get(\n", structAttr.getStructClassName()); - for (auto field : structAttr.getAllFields()) { - auto type = field.getType(); - os << formatv(" {0} {1},\n", type.getReturnType(), field.getName()); + if (!structAttr.getAllFields().empty()) { + os << formatv(" static {0} get(\n", structAttr.getStructClassName()); + for (auto field : structAttr.getAllFields()) { + auto type = field.getType(); + os << formatv(" {0} {1},\n", type.getReturnType(), field.getName()); + } + os << " mlir::MLIRContext* context);\n"; } - os << " mlir::MLIRContext* context);\n"; os << R"( static Attribute parse(DialectAsmParser &p); @@ -322,11 +332,13 @@ structAttr.getAllFields().front().getName()); } - os << formatv(" return Base::get(context, AttrKind::{0},\n", + os << formatv(" return Base::get(context, AttrKind::{0}", structAttr.getStructClassName()); - os << " "; - interleaveComma(structAttr.getAllFields(), os, - [&](StructFieldAttr field) { os << field.getName(); }); + if (!structAttr.getAllFields().empty()) { + os << "\n, "; + interleaveComma(structAttr.getAllFields(), os, + [&](StructFieldAttr field) { os << field.getName(); }); + } os << ");\n"; os << "}\n\n"; @@ -412,12 +424,16 @@ } os << "\n"; - emitStorageDef(structAttr, os); - emitVerifierDef(structAttr, os); + if (!structAttr.getAllFields().empty()) { + emitStorageDef(structAttr, os); + emitVerifierDef(structAttr, os); + } emitAttrFactoryDef(structAttr, os); - emitTypedFactoryDef(structAttr, os); - for (auto field : structAttr.getAllFields()) { - emitAccessorDefs(structAttr, field, os); + if (!structAttr.getAllFields().empty()) { + emitTypedFactoryDef(structAttr, os); + for (auto field : structAttr.getAllFields()) { + emitAccessorDefs(structAttr, field, os); + } } emitWalkStorageDef(structAttr, os);
diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp index c64d620..841fa31 100644 --- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp +++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -146,11 +146,11 @@ } }; -class ReturnOpConversion : public OpConversionPattern<ReturnOp> { +class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - ReturnOp srcOp, ArrayRef<Value> operands, + mlir::ReturnOp srcOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp, operands); return success();
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp index d82ba6d..82fc1be 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -378,6 +378,7 @@ p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{ "sym_name", "is_mutable", + "initializer", "initial_value", "type", });
diff --git a/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp b/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp index 8f554b7..ce0716d 100644 --- a/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp +++ b/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp
@@ -147,8 +147,9 @@ if (!isDispatchFuncImpl(funcOp)) return op->emitError("expected operation to be within a dispatch function"); MLIRContext *context = op->getContext(); - SmallVector<int32_t, 3> workGroupSizeVec(workGroupSize.begin(), - workGroupSize.end()); + SmallVector<int32_t, 3> workGroupSizeVec(llvm::map_range( + workGroupSize, + [](int64_t value) { return static_cast<int32_t>(value); })); workGroupSizeVec.resize(3, 1); funcOp.setAttr(spirv::getEntryPointABIAttrName(), spirv::getEntryPointABIAttr(workGroupSizeVec, context));
diff --git a/iree/hal/api.cc b/iree/hal/api.cc index e66c622..08e008c 100644 --- a/iree/hal/api.cc +++ b/iree/hal/api.cc
@@ -1023,6 +1023,15 @@
   return reinterpret_cast<iree_hal_allocator_t*>(handle->allocator());
 }
 
+// Returns a view of the id string reported by the info of |device|, or an
+// empty view when |device| is null. The characters are not copied.
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_hal_device_id(iree_hal_device_t* device) {
+  if (!device) return IREE_STRING_VIEW_EMPTY;
+  const auto& device_id = reinterpret_cast<Device*>(device)->info().id();
+  return iree_string_view_t{device_id.data(), device_id.size()};
+}
+
 //===----------------------------------------------------------------------===//
 // iree::hal::Driver
 //===----------------------------------------------------------------------===//
diff --git a/iree/hal/api.h b/iree/hal/api.h index 22e58fa..89a9689 100644 --- a/iree/hal/api.h +++ b/iree/hal/api.h
@@ -960,6 +960,12 @@ IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL iree_hal_device_allocator(iree_hal_device_t* device); +// Returns the device identifier. +// This identifier may vary based on the runtime device type; for example, a +// Vulkan device may return `vulkan-v1.1` or `vulkan-v1.2-spec1`. +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_hal_device_id(iree_hal_device_t* device); + #endif // IREE_API_NO_PROTOTYPES //===----------------------------------------------------------------------===//
diff --git a/iree/hal/dawn/dawn_driver.cc b/iree/hal/dawn/dawn_driver.cc index f4b0b6f..34dd792 100644 --- a/iree/hal/dawn/dawn_driver.cc +++ b/iree/hal/dawn/dawn_driver.cc
@@ -37,9 +37,10 @@ // supported_features |= DeviceFeature::kProfiling; // TODO(scotttodd): more clever/sanitized device naming. + std::string device_id = "dawn"; std::string device_name = absl::StrCat("dawn-", adapter->GetPCIInfo().name); - return DeviceInfo(device_name, supported_features, + return DeviceInfo(device_id, device_name, supported_features, reinterpret_cast<DriverDeviceID>(adapter)); }
diff --git a/iree/hal/device_info.h b/iree/hal/device_info.h index 6d5661f..4fb5ffb 100644 --- a/iree/hal/device_info.h +++ b/iree/hal/device_info.h
@@ -63,12 +63,20 @@ // TODO(benvanik): device info (caps, physical mappings, etc). class DeviceInfo { public: - DeviceInfo(std::string name, DeviceFeatureBitfield supported_features, + DeviceInfo(std::string id, std::string name, + DeviceFeatureBitfield supported_features, DriverDeviceID device_id = 0) - : name_(std::move(name)), + : id_(std::move(id)), + name_(std::move(name)), supported_features_(supported_features), device_id_(device_id) {} + // Machine-friendly device identifier used to match the device against + // compiler-generated patterns. This should be consistent with the device IDs + // emitted by the compiler. For example: `vulkan-v1.1-spec`. + const std::string& id() const { return id_; } + + // Human-friendly device name. const std::string& name() const { return name_; } // Features supported by the device. @@ -82,6 +90,7 @@ DriverDeviceID device_id() const { return device_id_; } private: + const std::string id_; const std::string name_; const DeviceFeatureBitfield supported_features_; DriverDeviceID device_id_;
diff --git a/iree/hal/llvmjit/llvmjit_driver.cc b/iree/hal/llvmjit/llvmjit_driver.cc index 7b22ffd..da4c6b1 100644 --- a/iree/hal/llvmjit/llvmjit_driver.cc +++ b/iree/hal/llvmjit/llvmjit_driver.cc
@@ -30,7 +30,7 @@ // supported_features |= DeviceFeature::kDebugging; // supported_features |= DeviceFeature::kCoverage; // supported_features |= DeviceFeature::kProfiling; - DeviceInfo device_info("llvm", supported_features); + DeviceInfo device_info("llvmjit", "llvm", supported_features); // TODO(benvanik): device info. return device_info; }
diff --git a/iree/hal/vmla/vmla_driver.cc b/iree/hal/vmla/vmla_driver.cc index 59e2ad1..6f21f7f 100644 --- a/iree/hal/vmla/vmla_driver.cc +++ b/iree/hal/vmla/vmla_driver.cc
@@ -35,7 +35,7 @@ // supported_features |= DeviceFeature::kDebugging; // supported_features |= DeviceFeature::kCoverage; // supported_features |= DeviceFeature::kProfiling; - DeviceInfo device_info("vmla", supported_features); + DeviceInfo device_info("vmla", "vmla", supported_features); // TODO(benvanik): device info. return device_info; }
diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc index b67a762..f0da43f 100644 --- a/iree/hal/vulkan/vulkan_driver.cc +++ b/iree/hal/vulkan/vulkan_driver.cc
@@ -73,7 +73,7 @@ // supported_features |= DeviceFeature::kDebugging; // supported_features |= DeviceFeature::kCoverage; // supported_features |= DeviceFeature::kProfiling; - return DeviceInfo(std::move(name), supported_features, + return DeviceInfo("vulkan", std::move(name), supported_features, reinterpret_cast<DriverDeviceID>(physical_device)); }
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc index ca5685d..19e22fd 100644 --- a/iree/modules/hal/hal_module.cc +++ b/iree/modules/hal/hal_module.cc
@@ -590,9 +590,12 @@
       vm::ref<iree_hal_executable_t> executable, int32_t entry_point,
       uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
     IREE_TRACE_SCOPE0("HALModuleState::CommandBufferDispatch");
-    return reinterpret_cast<CommandBuffer*>(command_buffer.get())
-        ->Dispatch(reinterpret_cast<Executable*>(executable.get()), entry_point,
-                   {workgroup_x, workgroup_y, workgroup_z});
+    RETURN_IF_ERROR(FromApiStatus(
+        iree_hal_command_buffer_dispatch(command_buffer.get(), executable.get(),
+                                         entry_point, workgroup_x, workgroup_y,
+                                         workgroup_z),
+        IREE_LOC));
+    return OkStatus();
   }
 
   Status CommandBufferDispatchIndirect(
@@ -600,11 +603,12 @@
       vm::ref<iree_hal_executable_t> executable, int32_t entry_point,
       vm::ref<iree_hal_buffer_t> workgroups_buffer, int32_t workgroups_offset) {
     IREE_TRACE_SCOPE0("HALModuleState::CommandBufferDispatchIndirect");
-    return reinterpret_cast<CommandBuffer*>(command_buffer.get())
-        ->DispatchIndirect(reinterpret_cast<Executable*>(executable.get()),
-                           entry_point,
-                           reinterpret_cast<Buffer*>(workgroups_buffer.get()),
-                           workgroups_offset);
+    RETURN_IF_ERROR(
+        FromApiStatus(iree_hal_command_buffer_dispatch_indirect(
+                          command_buffer.get(), executable.get(), entry_point,
+                          workgroups_buffer.get(), workgroups_offset),
+                      IREE_LOC));
+    return OkStatus();
   }
 
   //===--------------------------------------------------------------------===//
@@ -673,6 +677,15 @@
     return vm::retain_ref(iree_hal_device_allocator(device.get()));
   }
 
+  // Returns 1 when the ID of |device| matches |pattern| and 0 otherwise.
+  StatusOr<int32_t> DeviceMatchID(vm::ref<iree_hal_device_t> device,
+                                  absl::string_view pattern) {
+    const iree_string_view_t pattern_view{pattern.data(), pattern.size()};
+    const bool is_match = iree_string_view_match_pattern(
+        iree_hal_device_id(device.get()), pattern_view);
+    return is_match ? 1 : 0;
+  }
+
   //===--------------------------------------------------------------------===//
   // iree::hal::ExecutableCache
   //===--------------------------------------------------------------------===//
@@ -831,6 +844,7 @@
 
       vm::MakeNativeFunction("device.allocator",
                              &HALModuleState::DeviceAllocator),
+      vm::MakeNativeFunction("device.match.id", &HALModuleState::DeviceMatchID),
 
       vm::MakeNativeFunction("executable_cache.create",
                              &HALModuleState::ExecutableCacheCreate),