Merging hal.device.switch feature branch.

Adding a new `hal.device.switch` op that allows us to nicely generate backend/device-specific logic that can be transformed into efficient code for runtime. There are a lot of folders needed to help here; however, this initial structure will let us begin work on #1168, as it exposes the ability to generate device-specific dispatch ops.

Future iterations will add more device matchers to allow more fine-grained codegen. For example, matchers that look for supported device-specific extensions, minimum allocation alignment, unified vs. discrete memory, maximum descriptor set count, etc. can all be added in a generic way and efficiently lowered to runtime code.

Progress towards heterogeneous execution will cause the `!hal.device` type to change to a `hal.device<ordinal>` form, allowing better compile-time evaluation of the `hal.device.switch` blocks when placement happens. For now we still only have a single runtime-selected `hal.ex.shared_device` so this cannot happen.

Closes https://github.com/google/iree/pull/1393

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/1393 from google:benvanik-hal-device-switch d0348ef32a6af9940d4f9243a1cce88921bd398e
PiperOrigin-RevId: 306259903
diff --git a/iree/base/api.cc b/iree/base/api.cc
index 28ced9d..c62fcce 100644
--- a/iree/base/api.cc
+++ b/iree/base/api.cc
@@ -199,6 +199,40 @@
   return offset;
 }
 
+// Returns true if |value| fully matches |pattern|, where '*' matches zero or
+// more characters and '?' exactly one. Iterative: on a mismatch it backtracks
+// to the most recent '*' and lets it absorb one more character of |value|.
+static bool MatchPattern(absl::string_view value, absl::string_view pattern) {
+  size_t vi = 0, pi = 0;  // Cursors into |value| and |pattern|.
+  size_t star_pi = absl::string_view::npos, star_vi = 0;  // Last '*' seen.
+  while (vi < value.size()) {
+    // '*' is tested first so a literal '*' in |value| cannot consume it.
+    if (pi < pattern.size() && pattern[pi] == '*') {
+      star_pi = pi++;
+      star_vi = vi;
+    } else if (pi < pattern.size() &&
+               (pattern[pi] == '?' || pattern[pi] == value[vi])) {
+      ++vi;
+      ++pi;
+      // Mismatch: retry with the last '*' matching one more character.
+    } else if (star_pi != absl::string_view::npos) {
+      pi = star_pi + 1;
+      vi = ++star_vi;
+    } else {
+      return false;
+    }
+  }
+  // Only trailing '*'s may remain once |value| is fully consumed.
+  while (pi < pattern.size() && pattern[pi] == '*') ++pi;
+  return pi == pattern.size();
+}
+
+IREE_API_EXPORT bool IREE_API_CALL iree_string_view_match_pattern(
+    iree_string_view_t value, iree_string_view_t pattern) {
+  return MatchPattern(absl::string_view(value.data, value.size),
+                      absl::string_view(pattern.data, pattern.size));
+}
+
 //===----------------------------------------------------------------------===//
 // iree::FileMapping
 //===----------------------------------------------------------------------===//
diff --git a/iree/base/api.h b/iree/base/api.h
index a35bb4d..3d19268 100644
--- a/iree/base/api.h
+++ b/iree/base/api.h
@@ -402,6 +402,17 @@
     iree_string_view_t value, char split_char, iree_string_view_t* out_lhs,
     iree_string_view_t* out_rhs);
 
+// Returns true if the given |value| matches |pattern| (normal * and ? rules).
+// This accepts wildcards in the form of '*' and '?' for any delimited value.
+// '*' will match zero or more of any character and '?' will match exactly one
+// of any character.
+//
+// For example,
+// 'foo-*-bar' matches: 'foo-123-bar', 'foo-456-789-bar'
+// 'foo-10?' matches: 'foo-101', 'foo-102'
+IREE_API_EXPORT bool IREE_API_CALL iree_string_view_match_pattern(
+    iree_string_view_t value, iree_string_view_t pattern);
+
 #endif  // IREE_API_NO_PROTOTYPES
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
index 0347ae3..2bfa8ee 100644
--- a/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
+++ b/iree/compiler/Dialect/Flow/Analysis/Dispatchability.cpp
@@ -111,7 +111,7 @@
       } else if (isa<CallIndirectOp>(op)) {
         // Indirect calls are not supported and must first be devirtualized.
         return false;
-      } else if (isa<ReturnOp>(op)) {
+      } else if (isa<mlir::ReturnOp>(op)) {
         // TODO(benvanik): widen to all known terminators? sometimes they may
         // have side-effects.
         continue;
diff --git a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
index 4dc0cfe..11da5c9 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FlattenTuplesInCFG.cpp
@@ -126,7 +126,7 @@
   return false;
 }
 
-bool convertReturnOp(ReturnOp *op, OpBuilder &builder,
+bool convertReturnOp(mlir::ReturnOp *op, OpBuilder &builder,
                      BlockAndValueMapping *mapping) {
   llvm::SmallVector<Value, 10> newOperands;
   if (untupleAndLookupValues(op->getOperands(), &newOperands, builder,
@@ -134,7 +134,7 @@
     return true;
   }
 
-  builder.create<ReturnOp>(op->getLoc(), newOperands);
+  builder.create<mlir::ReturnOp>(op->getLoc(), newOperands);
   return false;
 }
 
@@ -226,7 +226,7 @@
 
 bool convertOperation(Operation *op, OpBuilder &builder,
                       BlockAndValueMapping *mapping) {
-  if (auto returnOp = dyn_cast<ReturnOp>(op)) {
+  if (auto returnOp = dyn_cast<mlir::ReturnOp>(op)) {
     return convertReturnOp(&returnOp, builder, mapping);
   } else if (auto callOp = dyn_cast<CallOp>(op)) {
     return convertCallOp(&callOp, builder, mapping);
diff --git a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
index 0dab859..4251d12 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FoldCompatibleDispatchRegions.cpp
@@ -39,7 +39,7 @@
 namespace {
 
 // Replaces |returnOp| with a clone including |newOperands| appended.
-LogicalResult appendReturnOperands(ReturnOp returnOp,
+LogicalResult appendReturnOperands(IREE::Flow::ReturnOp returnOp,
                                    ArrayRef<Value> newOperands) {
   // Insert prior to the original return.
   OpBuilder builder(returnOp);
@@ -49,7 +49,7 @@
   operands.reserve(returnOp.getNumOperands() + newOperands.size());
   operands.append(returnOp.operand_begin(), returnOp.operand_end());
   operands.append(newOperands.begin(), newOperands.end());
-  builder.create<ReturnOp>(returnOp.getLoc(), operands);
+  builder.create<IREE::Flow::ReturnOp>(returnOp.getLoc(), operands);
 
   // Remove original.
   returnOp.erase();
@@ -100,7 +100,7 @@
 DispatchRegionOp removeUnusedResults(DispatchRegionOp regionOp) {
   // Find return value within the region.
   auto &regionBlock = regionOp.body().getBlocks().front();
-  auto returnOp = dyn_cast<ReturnOp>(regionBlock.getTerminator());
+  auto returnOp = dyn_cast<IREE::Flow::ReturnOp>(regionBlock.getTerminator());
   if (!returnOp) {
     regionBlock.getParent()->getParentOfType<FuncOp>().emitError()
         << "block does not contain an flow.return op";
@@ -213,7 +213,7 @@
 
   // Find the values used as return values in the lhs.
   // We'll need to replace the uses in rhs with these.
-  auto lhsReturnOp = cast<ReturnOp>(lhsBlock.getTerminator());
+  auto lhsReturnOp = cast<IREE::Flow::ReturnOp>(lhsBlock.getTerminator());
   SmallVector<Value, 8> lhsReturnValues;
   lhsReturnValues.reserve(lhsReturnOp.getNumOperands());
   lhsReturnValues.append(lhsReturnOp.operand_begin(),
@@ -221,7 +221,7 @@
 
   // Find the values used as return values in the rhs.
   // We'll add these to the results of the lhs region.
-  auto rhsReturnOp = cast<ReturnOp>(rhsBlock.getTerminator());
+  auto rhsReturnOp = cast<IREE::Flow::ReturnOp>(rhsBlock.getTerminator());
   SmallVector<Value, 8> rhsReturnValues;
   rhsReturnValues.reserve(rhsReturnOp.getNumOperands());
   rhsReturnValues.append(rhsReturnOp.operand_begin(),
diff --git a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp
index 592d145..40addc2 100644
--- a/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/FormStreams.cpp
@@ -275,7 +275,7 @@
     for (auto *op : streamOps) {
       fragmentBuilder.clone(*op, mapping);
     }
-    fragmentBuilder.create<ReturnOp>(
+    fragmentBuilder.create<IREE::Flow::ReturnOp>(
         UnknownLoc::get(context),
         llvm::to_vector<8>(llvm::map_range(fragmentResults, [&](Value value) {
           return mapping.lookup(value);
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
index d33a735..cc59fb4 100644
--- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertVariableOps.cpp
@@ -49,7 +49,7 @@
 
   // Create const and return ops.
   auto constValue = rewriter.create<ConstantOp>(loc, immediateElements);
-  rewriter.create<ReturnOp>(loc, constValue.getResult());
+  rewriter.create<mlir::ReturnOp>(loc, constValue.getResult());
   return initializerFuncOp;
 }
 
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
index 24c4feb..4131dc2 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
@@ -25,6 +25,8 @@
                                    OwningRewritePatternList &patterns) {
   patterns.insert<VMImportOpConversion<IREE::HAL::DeviceAllocatorOp>>(
       context, importSymbols, typeConverter, "hal.device.allocator");
+  patterns.insert<VMImportOpConversion<IREE::HAL::DeviceMatchIDOp>>(
+      context, importSymbols, typeConverter, "hal.device.match.id");
 }
 
 }  // namespace iree_compiler
diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td
index 04fd282..65b2aac 100644
--- a/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -554,6 +554,36 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Expression matching attributes
+//===----------------------------------------------------------------------===//
+
+def HAL_MatchAlwaysAttr :
+  IREE_StructAttr<"match.always", "MatchAlwaysAttr", HAL_Dialect, []> {
+  let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
+def HAL_MatchAnyAttr :
+  IREE_StructAttr<"match.any", "MatchAnyAttr", HAL_Dialect, [
+    IREE_StructFieldAttr<"conditions", AnyAttr>,
+  ]> {
+  let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
+def HAL_MatchAllAttr :
+  IREE_StructAttr<"match.all", "MatchAllAttr", HAL_Dialect, [
+    IREE_StructFieldAttr<"conditions", AnyAttr>,
+  ]> {
+  let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
+def HAL_DeviceMatchIDAttr :
+  IREE_StructAttr<"device.match.id", "DeviceMatchIDAttr", HAL_Dialect, [
+    IREE_StructFieldAttr<"pattern", StrAttr>,
+  ]> {
+  let cppNamespace = "mlir::iree_compiler::IREE::HAL";
+}
+
+//===----------------------------------------------------------------------===//
 // Base HAL op classes
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
index 8bb4b14..66adfec 100644
--- a/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -20,6 +20,7 @@
 #include "iree/compiler/Dialect/HAL/hal.imports.h"
 #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
 #include "llvm/Support/SourceMgr.h"
+#include "mlir/ADT/TypeSwitch.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Parser.h"
@@ -65,7 +66,8 @@
     : Dialect(getDialectNamespace(), context) {
   addInterfaces<HALToVMConversionInterface>();
 
-  addAttributes<DescriptorSetLayoutBindingAttr>();
+  addAttributes<DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr,
+                MatchAllAttr, DeviceMatchIDAttr>();
 
   addTypes<AllocatorType, BufferType, BufferViewType, CommandBufferType,
            DescriptorSetType, DescriptorSetLayoutType, DeviceType, EventType,
@@ -88,6 +90,14 @@
   if (failed(parser.parseKeyword(&attrKind))) return {};
   if (attrKind == DescriptorSetLayoutBindingAttr::getKindName()) {
     return DescriptorSetLayoutBindingAttr::parse(parser);
+  } else if (attrKind == MatchAlwaysAttr::getKindName()) {
+    return MatchAlwaysAttr::parse(parser);
+  } else if (attrKind == MatchAnyAttr::getKindName()) {
+    return MatchAnyAttr::parse(parser);
+  } else if (attrKind == MatchAllAttr::getKindName()) {
+    return MatchAllAttr::parse(parser);
+  } else if (attrKind == DeviceMatchIDAttr::getKindName()) {
+    return DeviceMatchIDAttr::parse(parser);
   }
   parser.emitError(parser.getNameLoc())
       << "unknown HAL attribute: " << attrKind;
@@ -95,11 +105,12 @@
 }
 
 void HALDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const {
-  if (auto typedAttr = attr.dyn_cast<DescriptorSetLayoutBindingAttr>()) {
-    typedAttr.print(p);
-  } else {
-    llvm_unreachable("unhandled HAL attribute kind");
-  }
+  TypeSwitch<Attribute>(attr)
+      .Case<DescriptorSetLayoutBindingAttr, MatchAlwaysAttr, MatchAnyAttr,
+            MatchAllAttr, DeviceMatchIDAttr>(
+          [&](auto typedAttr) { typedAttr.print(p); })
+      .Default(
+          [](Attribute) { llvm_unreachable("unhandled HAL attribute kind"); });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 7477df5..efdd27a 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -255,6 +255,22 @@
   results.insert<SkipBufferViewBufferOp>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// hal.device.switch
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): fold conditions with the same IR tree.
+// TODO(benvanik): remove duplicate conditions.
+// TODO(benvanik): fold condition expressions (any(always, ...) -> always, etc).
+// TODO(benvanik): completely replace switches with just one always block.
+// TODO(benvanik): remove conditions with no side-effects.
+
+//===----------------------------------------------------------------------===//
+// hal.device.match.id
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): fold matches that are known true based on device config.
+
 }  // namespace HAL
 }  // namespace IREE
 }  // namespace iree_compiler
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 9071527..68d595d 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -962,6 +962,176 @@
 }
 
 //===----------------------------------------------------------------------===//
+// hal.device.switch
+//===----------------------------------------------------------------------===//
+
+void DeviceSwitchOp::build(Builder *builder, OperationState &state,
+                           TypeRange resultTypes, Value device,
+                           ArrayRef<Attribute> conditions,
+                           ArrayRef<ValueRange> conditionArgs,
+                           ArrayRef<NamedAttribute> attributes) {
+  state.addOperands({device});
+  state.addAttribute("conditions", builder->getArrayAttr(conditions));
+  for (auto args : conditionArgs) {
+    state.addOperands(args);
+    auto *region = state.addRegion();
+    auto *entryBlock = OpBuilder(region).createBlock(region);
+    for (auto arg : args) {
+      entryBlock->addArgument(arg.getType());
+    }
+  }
+  state.addTypes(resultTypes);
+  state.addAttributes(attributes);
+  state.resizableOperandList = true;
+}
+
+static ParseResult parseDeviceSwitchOp(OpAsmParser &parser,
+                                       OperationState *result) {
+  OpAsmParser::OperandType device;
+  Type deviceType;
+  if (failed(parser.parseLParen()) || failed(parser.parseOperand(device)) ||
+      failed(parser.parseColonType(deviceType)) ||
+      failed(parser.resolveOperand(device, deviceType, result->operands)) ||
+      failed(parser.parseRParen()) ||
+      failed(parser.parseOptionalArrowTypeList(result->types))) {
+    return failure();
+  }
+
+  // Parses each switch condition attribute and region, like:
+  // #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) {
+  //   hal.return %c1a : i32
+  // }, ...
+  result->setOperandListToResizable();
+  SmallVector<Attribute, 4> conditionAttrs;
+  do {
+    Attribute conditionAttr;
+    SmallVector<NamedAttribute, 1> dummyAttrs;
+    if (failed(parser.parseAttribute(conditionAttr, "condition", dummyAttrs)) ||
+        failed(parser.parseLParen())) {
+      return failure();
+    }
+    conditionAttrs.push_back(conditionAttr);
+    SmallVector<OpAsmParser::OperandType, 16> regionArgs;
+    SmallVector<Type, 16> regionArgTypes;
+    if (failed(parser.parseOptionalRParen())) {
+      SmallVector<OpAsmParser::OperandType, 16> regionOperands;
+      auto argsLoc = parser.getCurrentLocation();
+      do {
+        // Reserve entries in the lists.
+        regionArgs.emplace_back();
+        regionOperands.emplace_back();
+        regionArgTypes.emplace_back();
+        if (failed(parser.parseRegionArgument(regionArgs.back())) ||
+            failed(parser.parseEqual()) ||
+            failed(parser.parseOperand(regionOperands.back())) ||
+            failed(parser.parseColonType(regionArgTypes.back()))) {
+          return failure();
+        }
+      } while (succeeded(parser.parseOptionalComma()));
+      if (failed(parser.parseRParen()) ||
+          failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc,
+                                        result->operands))) {
+        return failure();
+      }
+    }
+    auto *regionBody = result->addRegion();
+    if (failed(parser.parseRegion(*regionBody, regionArgs, regionArgTypes))) {
+      return failure();
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+  result->addAttribute("conditions",
+                       ArrayAttr::get(conditionAttrs, result->getContext()));
+
+  if (failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) {
+    return failure();
+  }
+  return success();
+}
+
+static void printDeviceSwitchOp(OpAsmPrinter &p, DeviceSwitchOp op) {
+  p << op.getOperationName() << "(";
+  p.printOperand(op.device());
+  p << " : ";
+  p.printType(op.device().getType());
+  p << ")";
+  p.printOptionalArrowTypeList(op.getResultTypes());
+  p << "\n";
+  p.getStream().indent(4);
+  int argOffset = 0;
+  interleave(
+      llvm::zip(op.conditions(), op.condition_regions()),
+      [&](std::tuple<Attribute, Region &> it) {
+        auto &conditionAttr = std::get<0>(it);
+        auto &conditionRegion = std::get<1>(it);
+        p.printAttribute(conditionAttr);
+        p << "(";
+        auto regionOperands = conditionRegion.front().getArguments();
+        auto regionArgs = op.args().slice(argOffset, regionOperands.size());
+        argOffset += regionOperands.size();
+        // TODO(benvanik): figure out how to parse with shadowing.
+        // p.shadowRegionArgs(conditionRegion, regionArgs);
+        interleaveComma(llvm::zip(regionOperands, regionArgs), p,
+                        [&](std::tuple<BlockArgument, Value> it) {
+                          p << std::get<0>(it) << " = " << std::get<1>(it);
+                          p << " : ";
+                          p << std::get<1>(it).getType();
+                        });
+        p << ")";
+        p.printRegion(conditionRegion,
+                      /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+      },
+      [&]() {
+        p << ",\n";
+        p.getStream().indent(4);
+      });
+  p.printOptionalAttrDictWithKeyword(op.getAttrs(),
+                                     /*elidedAttrs=*/{"conditions"});
+}
+
+// Verifies region/condition pairing, region argument types, and return types.
+static LogicalResult verifyDeviceSwitchOp(DeviceSwitchOp op) {
+  if (op.conditions().size() != op.condition_regions().size()) {
+    return op.emitOpError() << "requires conditions and regions be matched 1:1";
+  } else if (op.condition_regions().empty()) {
+    return op.emitOpError() << "requires at least one condition";
+  }
+  int argOffset = 0;
+  for (auto &region : op.condition_regions()) {
+    auto regionOperands = region.front().getArguments();
+    auto regionArgs = op.args().slice(argOffset, regionOperands.size());
+    argOffset += regionOperands.size();
+    for (auto it : llvm::zip(regionArgs, regionOperands)) {
+      auto regionArg = std::get<0>(it);
+      auto regionOperand = std::get<1>(it);
+      if (regionArg.getType() != regionOperand.getType()) {
+        return op.emitOpError() << "requires that regions have their arguments "
+                                   "represented in the op arg list in order ("
+                                << regionArg.getType()
+                                << " != " << regionOperand.getType() << ")";
+      }
+    }
+    for (auto &block : region) {
+      auto returnOp =
+          dyn_cast_or_null<IREE::HAL::ReturnOp>(block.getTerminator());
+      if (!returnOp) continue;
+      // Check counts too; std::equal alone ignores a length mismatch.
+      if (returnOp.getNumOperands() != op.getNumResults() ||
+          !std::equal(returnOp.getOperandTypes().begin(),
+                      returnOp.getOperandTypes().end(),
+                      op.getResultTypes().begin())) {
+        return op.emitOpError() << "requires all regions return the same types";
+      }
+    }
+  }
+  if (argOffset != op.args().size()) {
+    return op.emitOpError() << "requires that the total argument list matches "
+                               "the sum of all region operands";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // hal.executable
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 26ae91d..2e9f148 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1317,7 +1317,137 @@
   ];
 }
 
+def HAL_DeviceSwitchOp : HAL_Op<"device.switch", [IsolatedFromAbove]> {
+  let summary = [{runtime device switch pseudo op}];
+  let description = [{
+    Switches between multiple regions based on the runtime device type.
+    The provided regions are matched against the runtime backend of the given
+    device and executed only when the device matches the conditions.
+
+    Conditions can match on wildcards, and conditions that have similar bodies
+    can be folded together. The patterns themselves are only matched
+    once at startup and then the results are cached; the runtime overhead is
+    equivalent to a normal switch statement. In cases where the compiler can
+    statically identify the device type entire cases can be folded away.
+
+    Supported conditions:
+    * `#hal.match...`: execute the region if the expression matches.
+
+    Supported match expressions:
+    * `#hal.match.always`: always matches; useful for defaults.
+    * `#hal.match.any<[...]>`: matches if any of the nested expressions match.
+    * `#hal.match.all<[...]>`: matches only if all of the nested expressions
+      match.
+    * `#hal.device.match.id<"pattern*-?-*">`: matches against the device
+      identifier. The pattern is evaluated with standard file path wildcards
+      (`*` for zero or more characters and `?` for one character).
+
+    If more than one condition is satisfied the first listed will be chosen.
+    More specific conditions should be earlier in the set. If no condition is
+    matched but there are return values the switch will abort at runtime. It's
+    strongly recommended that all switches that return values end with a
+    trailing `#hal.match.always` condition to handle the fallthrough case.
+
+    Upon creation each condition region will have an empty entry block with the
+    specified operands available as arguments. Each region must be setup to
+    return the same types.
+
+    ```mlir
+    %c0 = constant 0 : i32
+    %c1 = constant 1 : i32
+    %c2 = constant 2 : i32
+    %device = ... : !hal.device
+    %0 = hal.device.switch(%device : !hal.device) -> i32
+      #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) {
+        hal.return %c1a : i32
+      },
+      #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%c2a = %c2 : i32) {
+        hal.return %c2a : i32
+      },
+      #hal.match.always(%c0a = %c0 : i32) {
+        hal.return %c0a : i32
+      }
+    ```
+  }];
+
+  let arguments = (ins
+    HAL_Device:$device,
+    ArrayAttr:$conditions,
+    Variadic<AnyType>:$args
+  );
+  let results = (outs
+    Variadic<AnyType>:$results
+  );
+
+  let regions = (region VariadicRegion<AnyRegion>:$condition_regions);
+
+  let extraClassDeclaration = [{
+    /// Returns the index of the args() operand in the Operation operands list.
+    unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; }
+  }];
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<[{
+      Builder *builder, OperationState &state, TypeRange resultTypes,
+      Value device, ArrayRef<Attribute> conditions,
+      ArrayRef<ValueRange> conditionArgs,
+      ArrayRef<NamedAttribute> attributes = {}
+    }]>,
+  ];
+
+  let verifier = [{ return verifyDeviceSwitchOp(*this); }];
+}
+
+def HAL_ReturnOp : HAL_Op<"return", [Terminator]> {
+  let summary = [{return from a hal.device.switch region}];
+  let description = [{
+    Returns the given values from the region and back to the host code.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$operands
+  );
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+  let builders = [
+    OpBuilder<[{
+      Builder *builder, OperationState &result
+    }], [{
+      build(builder, result, llvm::None);
+    }]>,
+  ];
+}
+
 // TODO(benvanik): additional factory functions and submission ops.
+// TODO(benvanik): %0 = hal.device.query %device, group, property : i32/etc
+
+def HAL_DeviceMatchIDOp : HAL_PureOp<"device.match.id"> {
+  let summary = [{returns true if the device ID matches the pattern}];
+  let description = [{
+    Pattern matches the device ID with the given wildcard pattern.
+    This can be used to conditionally evaluate device-specific code when the
+    device is not known at compile-time.
+
+    ```mlir
+    %is_match = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1
+    ```
+  }];
+
+  let arguments = (ins
+    HAL_Device:$device,
+    StrAttr:$pattern
+  );
+  let results = (outs
+    I1:$result
+  );
+
+  let assemblyFormat = [{
+    $device `,` `pattern` `=` `[` $pattern `]` attr-dict
+    `:` `(` type($device) `)` `->` type($result)
+  }];
+}
 
 //===----------------------------------------------------------------------===//
 // iree::hal::Executable
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index 37e00dc..b15fd4b 100644
--- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -134,6 +134,81 @@
   os << ">";
 }
 
+// static
+Attribute MatchAlwaysAttr::parse(DialectAsmParser &p) {
+  return get(p.getBuilder().getContext());
+}
+
+void MatchAlwaysAttr::print(DialectAsmPrinter &p) const {
+  auto &os = p.getStream();
+  os << getKindName();
+}
+
+static ArrayAttr parseMultiMatchAttrArray(DialectAsmParser &p) {
+  auto b = p.getBuilder();
+  SmallVector<Attribute, 4> conditionAttrs;
+  if (failed(p.parseLess()) || failed(p.parseLSquare())) {
+    return {};
+  }
+  do {
+    Attribute conditionAttr;
+    if (failed(p.parseAttribute(conditionAttr))) {
+      return {};
+    }
+    conditionAttrs.push_back(conditionAttr);
+  } while (succeeded(p.parseOptionalComma()));
+  if (failed(p.parseRSquare()) || failed(p.parseGreater())) {
+    return {};
+  }
+  return b.getArrayAttr(conditionAttrs);
+}
+
+static void printMultiMatchAttrList(ArrayAttr conditionAttrs,
+                                    DialectAsmPrinter &p) {
+  auto &os = p.getStream();
+  os << "<[";
+  interleaveComma(conditionAttrs, os,
+                  [&](Attribute condition) { os << condition; });
+  os << "]>";
+}
+
+// static
+Attribute MatchAnyAttr::parse(DialectAsmParser &p) {
+  return get(parseMultiMatchAttrArray(p));
+}
+
+void MatchAnyAttr::print(DialectAsmPrinter &p) const {
+  p << getKindName();
+  printMultiMatchAttrList(conditions().cast<ArrayAttr>(), p);
+}
+
+// static
+Attribute MatchAllAttr::parse(DialectAsmParser &p) {
+  return get(parseMultiMatchAttrArray(p));
+}
+
+void MatchAllAttr::print(DialectAsmPrinter &p) const {
+  p << getKindName();
+  printMultiMatchAttrList(conditions().cast<ArrayAttr>(), p);
+}
+
+// static
+Attribute DeviceMatchIDAttr::parse(DialectAsmParser &p) {
+  StringAttr patternAttr;
+  if (failed(p.parseLess()) || failed(p.parseAttribute(patternAttr)) ||
+      failed(p.parseGreater())) {
+    return {};
+  }
+  return get(patternAttr);
+}
+
+void DeviceMatchIDAttr::print(DialectAsmPrinter &p) const {
+  auto &os = p.getStream();
+  os << getKindName() << "<\"";
+  os << pattern();
+  os << "\">";
+}
+
 #include "iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc"
 
 }  // namespace HAL
diff --git a/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
index d40e27f..932f6f7 100644
--- a/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
@@ -1,6 +1,4 @@
-// Tests printing and parsing of hal.device ops.
-
-// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
 
 // CHECK-LABEL: @device_allocator
 func @device_allocator() -> !hal.allocator {
@@ -9,3 +7,48 @@
   %allocator = hal.device.allocator %0 : !hal.allocator
   return %allocator : !hal.allocator
 }
+
+// -----
+
+// CHECK-LABEL: @device_switch
+func @device_switch() -> i32 {
+  // CHECK-DAG: %[[C0:.+]] = constant 0
+  %c0 = constant 0 : i32
+  // CHECK-DAG: %[[C1:.+]] = constant 1
+  %c1 = constant 1 : i32
+  // CHECK-DAG: %[[C2:.+]] = constant 2
+  %c2 = constant 2 : i32
+  // CHECK-DAG: %[[DEVICE:.+]] = "test_hal.device"
+  %device = "test_hal.device"() : () -> !hal.device
+  // CHECK: = hal.device.switch(%[[DEVICE]] : !hal.device) -> i32
+  %0 = hal.device.switch(%device : !hal.device) -> i32
+    // CHECK-NEXT: #hal.device.match.id<"vulkan-v1.?-*">(%[[C1A:.+]] = %[[C1]] : i32) {
+    #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) {
+      // CHECK-NEXT: hal.return %[[C1A]] : i32
+      hal.return %c1a : i32
+      // CHECK-NEXT: },
+    },
+    // CHECK-NEXT: #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%[[C2A:.+]] = %[[C2]] : i32) {
+    #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%c2a = %c2 : i32) {
+      // CHECK-NEXT: hal.return %[[C2A]] : i32
+      hal.return %c2a : i32
+      // CHECK-NEXT: },
+    },
+    // CHECK-NEXT: #hal.match.always(%[[C0A:.+]] = %[[C0]] : i32) {
+    #hal.match.always(%c0a = %c0 : i32) {
+      // CHECK-NEXT: hal.return %[[C0A]] : i32
+      hal.return %c0a : i32
+      // CHECK-NEXT: }
+    }
+  return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @device_matchers
+// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
+func @device_matchers(%device : !hal.device) -> i1 {
+  // CHECK: = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-*"] : (!hal.device) -> i1
+  %0 = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1
+  return %0 : i1
+}
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD
index 325aedf..5da3750 100644
--- a/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -22,6 +22,8 @@
     srcs = [
         "MaterializeInterfaces.cpp",
         "MaterializeResourceCaches.cpp",
+        "MemoizeDeviceQueries.cpp",
+        "OutlineDeviceSwitches.cpp",
         "Passes.cpp",
         "PublicAbiGeneration.cpp",
         "RewriteLegacyIO.cpp",
diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index ddd3cb3..9b91962 100644
--- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -20,8 +20,10 @@
   HDRS
     "Passes.h"
   SRCS
+    "MemoizeDeviceQueries.cpp"
     "MaterializeInterfaces.cpp"
     "MaterializeResourceCaches.cpp"
+    "OutlineDeviceSwitches.cpp"
     "Passes.cpp"
     "PublicAbiGeneration.cpp"
     "RewriteLegacyIO.cpp"
diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 4ffb8df..caae5e5 100644
--- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -134,7 +134,7 @@
     auto layoutUsage = IREE::HAL::DescriptorSetLayoutUsageType::PushOnly;
     auto layoutValue = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
         loc, layoutType, deviceValue, layoutUsage, bindingsAttr);
-    blockBuilder.create<ReturnOp>(loc, layoutValue);
+    blockBuilder.create<mlir::ReturnOp>(loc, layoutValue);
 
     return variableOp;
   }
@@ -189,7 +189,7 @@
     }
     auto layoutValue = blockBuilder.createOrFold<ExecutableLayoutCreateOp>(
         loc, layoutType, deviceValue, setLayoutValues, pushConstantsAttr);
-    blockBuilder.create<ReturnOp>(loc, layoutValue);
+    blockBuilder.create<mlir::ReturnOp>(loc, layoutValue);
 
     return variableOp;
   }
@@ -250,7 +250,7 @@
                                            executableVariableOp.sym_name());
     }
 
-    blockBuilder.create<ReturnOp>(loc, executableCacheValue);
+    blockBuilder.create<mlir::ReturnOp>(loc, executableCacheValue);
 
     return variableOp;
   }
diff --git a/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
new file mode 100644
index 0000000..83b1884
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -0,0 +1,114 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "llvm/ADT/StringSet.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Memoizes hal.device.match.id queries: every distinct pattern is queried once
+// by a variable initializer and each original op becomes a variable load.
+// NOTE: this assumes a single active device; supporting multiple devices will
+// require making the variables per-device.
+class MemoizeDeviceQueriesPass
+    : public PassWrapper<MemoizeDeviceQueriesPass, OperationPass<ModuleOp>> {
+ public:
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    // Bucket the match ops by their pattern attribute. |patternKeys| keeps the
+    // order in which each distinct pattern was first encountered.
+    DenseMap<Attribute, std::vector<IREE::HAL::DeviceMatchIDOp>> opsByPattern;
+    SmallVector<Attribute, 4> patternKeys;
+    for (auto funcOp : moduleOp.getOps<FuncOp>()) {
+      funcOp.walk([&](IREE::HAL::DeviceMatchIDOp queryOp) {
+        auto patternKey = queryOp.patternAttr().cast<Attribute>();
+        auto inserted = opsByPattern.try_emplace(
+            patternKey, std::vector<IREE::HAL::DeviceMatchIDOp>{});
+        if (inserted.second) patternKeys.push_back(patternKey);
+        inserted.first->second.push_back(queryOp);
+        return WalkResult::advance();
+      });
+    }
+
+    // Emit one initializer + variable pair per pattern ahead of the existing
+    // module contents and swap the original ops for loads of that variable.
+    auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
+    for (size_t ordinal = 0; ordinal < patternKeys.size(); ++ordinal) {
+      auto &queryOps = opsByPattern[patternKeys[ordinal]];
+      auto pattern = queryOps.front().pattern();
+
+      // All deduplicated queries contribute their location to a fused one.
+      SmallVector<Location, 4> queryLocs;
+      for (auto queryOp : queryOps) queryLocs.push_back(queryOp.getLoc());
+      auto fusedLoc = FusedLoc::get(queryLocs, moduleOp.getContext());
+
+      // Private initializer function returning the query result, followed by
+      // the immutable variable that it initializes.
+      std::string variableName = "_device_match_id_" + std::to_string(ordinal);
+      auto initializerOp = moduleBuilder.create<FuncOp>(
+          fusedLoc, variableName + "_initializer",
+          moduleBuilder.getFunctionType({}, {moduleBuilder.getI1Type()}),
+          ArrayRef<NamedAttribute>{});
+      SymbolTable::setSymbolVisibility(initializerOp,
+                                       SymbolTable::Visibility::Private);
+      auto variableOp = moduleBuilder.create<IREE::HAL::VariableOp>(
+          fusedLoc, variableName,
+          /*isMutable=*/false, initializerOp);
+      SymbolTable::setSymbolVisibility(variableOp,
+                                       SymbolTable::Visibility::Private);
+
+      // The initializer body queries the shared device and returns the result.
+      auto initBuilder = OpBuilder::atBlockBegin(initializerOp.addEntryBlock());
+      auto sharedDevice =
+          initBuilder.createOrFold<IREE::HAL::ExSharedDeviceOp>(fusedLoc);
+      auto memoizedOp = initBuilder.create<IREE::HAL::DeviceMatchIDOp>(
+          fusedLoc, initBuilder.getI1Type(), sharedDevice, pattern);
+      initBuilder.create<mlir::ReturnOp>(fusedLoc, memoizedOp.getResult());
+
+      // Replace every original query with a load of the memoized value.
+      for (auto queryOp : queryOps) {
+        auto loadOp = OpBuilder(queryOp).create<IREE::HAL::VariableLoadOp>(
+            fusedLoc, queryOp.getResult().getType(), variableOp.getName());
+        queryOp.replaceAllUsesWith(loadOp.result());
+        queryOp.erase();
+      }
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createMemoizeDeviceQueriesPass() {
+  return std::make_unique<MemoizeDeviceQueriesPass>();
+}
+
+static PassRegistration<MemoizeDeviceQueriesPass> pass(
+    "iree-hal-memoize-device-queries",
+    "Caches hal.device.query results for use across the entire module");
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/OutlineDeviceSwitches.cpp b/iree/compiler/Dialect/HAL/Transforms/OutlineDeviceSwitches.cpp
new file mode 100644
index 0000000..c4f4211
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/OutlineDeviceSwitches.cpp
@@ -0,0 +1,235 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "llvm/ADT/StringSet.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+
+// Given a #hal.match.* expression tree returns a boolean (i1) value indicating
+// whether the expression evaluates to true for |device|.
+static Value buildConditionExpression(Location loc, Value device,
+                                      Attribute conditionAttr,
+                                      OpBuilder funcBuilder) {
+  // Builds the values of all nested conditions, in order, before combining.
+  auto buildNested = [&](ArrayAttr attrs) {
+    return llvm::to_vector<4>(llvm::map_range(attrs, [&](Attribute attr) {
+      return buildConditionExpression(loc, device, attr, funcBuilder);
+    }));
+  };
+  if (conditionAttr.isa<IREE::HAL::MatchAlwaysAttr>()) {
+    // #hal.match.always -> true
+    return funcBuilder.createOrFold<ConstantIntOp>(loc, 1, 1);
+  } else if (auto anyOf = conditionAttr.dyn_cast<IREE::HAL::MatchAnyAttr>()) {
+    // #hal.match.any<[a, b, c]> -> or(or(a, b), c); empty list -> false
+    auto values = buildNested(anyOf.conditions().cast<ArrayAttr>());
+    if (values.empty()) {
+      return funcBuilder.createOrFold<ConstantIntOp>(loc, 0, 1);
+    }
+    Value result = values.front();
+    for (size_t i = 1; i < values.size(); ++i) {
+      result = funcBuilder.createOrFold<OrOp>(loc, result, values[i]);
+    }
+    return result;
+  } else if (auto allOf = conditionAttr.dyn_cast<IREE::HAL::MatchAllAttr>()) {
+    // #hal.match.all<[a, b, c]> -> and(and(a, b), c); empty list -> true
+    auto values = buildNested(allOf.conditions().cast<ArrayAttr>());
+    if (values.empty()) {
+      return funcBuilder.createOrFold<ConstantIntOp>(loc, 1, 1);
+    }
+    Value result = values.front();
+    for (size_t i = 1; i < values.size(); ++i) {
+      result = funcBuilder.createOrFold<AndOp>(loc, result, values[i]);
+    }
+    return result;
+  } else if (auto matchAttr =
+                 conditionAttr.dyn_cast<IREE::HAL::DeviceMatchIDAttr>()) {
+    // #hal.device.match.id<"pattern"> -> hal.device.match.id
+    return funcBuilder.createOrFold<IREE::HAL::DeviceMatchIDOp>(
+        loc, funcBuilder.getI1Type(), device, matchAttr.patternAttr());
+  }
+  llvm_unreachable("unhandled condition expression attribute");
+  return {};
+}
+
+// Outlines a condition region from a switch op into a standalone private
+// function named |funcName|. The region's entry block arguments become the
+// function arguments and |resultTypes| the function results.
+static FuncOp outlineConditionRegion(StringRef funcName,
+                                     Region &conditionRegion,
+                                     ArrayRef<Type> resultTypes,
+                                     OpBuilder moduleBuilder) {
+  SmallVector<Type, 4> argTypes;
+  for (auto arg : conditionRegion.front().getArguments()) {
+    argTypes.push_back(arg.getType());
+  }
+  auto funcOp = moduleBuilder.create<FuncOp>(
+      conditionRegion.getLoc(), funcName,
+      moduleBuilder.getFunctionType(argTypes, resultTypes),
+      ArrayRef<NamedAttribute>{});
+  funcOp.getBody().takeBody(conditionRegion);
+  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);
+
+  // Swap hal.return terminators for std.return so that the normal
+  // matchers/inlining/etc keep working on the outlined function.
+  for (auto &block : funcOp.getBlocks()) {
+    auto halReturnOp = dyn_cast<IREE::HAL::ReturnOp>(block.back());
+    if (!halReturnOp) continue;
+    OpBuilder(halReturnOp).create<mlir::ReturnOp>(
+        halReturnOp.getLoc(), llvm::to_vector<4>(halReturnOp.getOperands()));
+    halReturnOp.erase();
+  }
+  return funcOp;
+}
+
+// Outlines each switch condition region into its own function (created via
+// |moduleBuilder| and named |baseFuncName|_N for the Nth condition) and
+// replaces |switchOp| with conditioned calls to those functions. |funcBuilder|
+// must be positioned in the block that contains |switchOp|.
+//
+// Conditions are evaluated in the order they are defined, giving a chain of:
+//   if condition_0_match:
+//     call outlined_condition_0
+//   else if condition_1_match:
+//     call outlined_condition_1
+//   else: iree.unreachable if the switch has results, otherwise continue
+static void buildConditionDispatchTable(IREE::HAL::DeviceSwitchOp switchOp,
+                                        StringRef baseFuncName,
+                                        OpBuilder moduleBuilder,
+                                        OpBuilder funcBuilder) {
+  // Split the block containing the switch op such that all ops before the
+  // switch are before and the switch and the following ops are after.
+  // We'll have all of our outlined regions bounce over to the afterBlock with
+  // the results of the call and use that to replace the switch op.
+  auto *beforeBlock = funcBuilder.getBlock();
+  auto *afterBlock = beforeBlock->splitBlock(switchOp);
+  auto finalValues =
+      llvm::to_vector<4>(afterBlock->addArguments(switchOp.getResultTypes()));
+
+  // Create the blocks we'll use for all our conditions so that we can reference
+  // them when inserting the branch ops.
+  SmallVector<Block *, 4> conditionMatchBlocks(
+      switchOp.condition_regions().size());
+  SmallVector<Block *, 4> conditionFallthroughBlocks(
+      switchOp.condition_regions().size());
+  for (int i = 0; i < conditionMatchBlocks.size(); ++i) {
+    conditionMatchBlocks[i] = funcBuilder.createBlock(afterBlock);
+    conditionFallthroughBlocks[i] = funcBuilder.createBlock(afterBlock);
+  }
+
+  funcBuilder.setInsertionPoint(beforeBlock, beforeBlock->end());
+  int argOffset = 0;
+  for (auto condition : llvm::enumerate(llvm::zip(
+           switchOp.conditions().getValue(), switchOp.condition_regions()))) {
+    auto conditionAttr = std::get<0>(condition.value());
+    auto &conditionRegion = std::get<1>(condition.value());
+
+    // Get the arguments from the switch that we want to carry along in the
+    // block arguments.
+    auto regionOperands = conditionRegion.front().getArguments();
+    auto regionArgs = switchOp.args().slice(argOffset, regionOperands.size());
+    argOffset += regionOperands.size();
+
+    // Outline the region into a function.
+    std::string regionFuncName =
+        (baseFuncName + "_").str() + std::to_string(condition.index());
+    auto regionFuncOp = outlineConditionRegion(
+        regionFuncName, conditionRegion,
+        switchOp.getOperation()->getResultTypes(), moduleBuilder);
+
+    // Insert the branch based on the match. We either match and jump to a block
+    // that will call the function or don't match and need to fall through.
+    auto isMatch = buildConditionExpression(
+        switchOp.getLoc(), switchOp.device(), conditionAttr, funcBuilder);
+    auto *matchBlock = conditionMatchBlocks[condition.index()];
+    auto *fallthroughBlock = conditionFallthroughBlocks[condition.index()];
+    funcBuilder.create<CondBranchOp>(switchOp.getLoc(), isMatch, matchBlock,
+                                     fallthroughBlock);
+
+    // Block that calls the outlined function and then jumps out of the chain.
+    funcBuilder.setInsertionPointToStart(matchBlock);
+    auto matchResults =
+        funcBuilder.create<CallOp>(switchOp.getLoc(), regionFuncOp, regionArgs);
+    funcBuilder.create<BranchOp>(switchOp.getLoc(), afterBlock,
+                                 matchResults.getResults());
+
+    // Block that we enter to check the next condition.
+    funcBuilder.setInsertionPointToStart(fallthroughBlock);
+    if (condition.index() + 1 < conditionFallthroughBlocks.size()) {
+      // Just continue on - the next loop iteration for the following condition
+      // will add its IR to the block.
+    } else {
+      // Fallthrough of all expressions; die if we expected return values.
+      if (switchOp.getNumResults() > 0) {
+        funcBuilder.create<IREE::UnreachableOp>(switchOp.getLoc());
+      } else {
+        funcBuilder.create<BranchOp>(switchOp.getLoc(), afterBlock);
+      }
+    }
+  }
+
+  // Remove the switch op and replace its results with the final joined results.
+  switchOp.replaceAllUsesWith(finalValues);
+  switchOp.erase();
+}
+
+// Outlines the hal.device.switch ops found directly in each function's blocks.
+// Switch ordinals are per-function so outlined names never collide.
+class OutlineDeviceSwitchesPass
+    : public PassWrapper<OutlineDeviceSwitchesPass, OperationPass<ModuleOp>> {
+ public:
+  void runOnOperation() override {
+    for (auto funcOp : llvm::to_vector<16>(getOperation().getOps<FuncOp>())) {
+      OpBuilder moduleBuilder(funcOp);
+      moduleBuilder.setInsertionPointAfter(funcOp);
+      int switchOrdinal = 0;
+      for (auto &block : funcOp) {
+        auto switchOps =
+            llvm::to_vector<4>(block.getOps<IREE::HAL::DeviceSwitchOp>());
+        for (auto switchOp : switchOps) {
+          std::string baseFuncName = (funcOp.getName() + "_switch_").str() +
+                                     std::to_string(switchOrdinal++);
+          buildConditionDispatchTable(switchOp, baseFuncName, moduleBuilder,
+                                      OpBuilder(switchOp));
+        }
+      }
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineDeviceSwitchesPass() {
+  return std::make_unique<OutlineDeviceSwitchesPass>();
+}
+
+static PassRegistration<OutlineDeviceSwitchesPass> pass(
+    "iree-hal-outline-device-switches",
+    "Outlines hal.device.switch condition regions");
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index d7fc72f..b74558d 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -51,6 +51,10 @@
   passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
   passManager.addNestedPass<FuncOp>(createCSEPass());
 
+  passManager.addPass(createOutlineDeviceSwitchesPass());
+  passManager.addPass(createMemoizeDeviceQueriesPass());
+  // TODO(benvanik): function deduplication to remove outlined functions.
+
   // TODO(benvanik): run symbol DCE when all symbols have visibility defined.
   // Right now the global value initializers don't have proper tracking and if
   // we do this we lose initializers that have side effects we care about.
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h
index f5dc6c1..9db239d 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -48,6 +48,16 @@
                                    ExecutableTargetOptions executableOptions);
 
 //===----------------------------------------------------------------------===//
+// Device management
+//===----------------------------------------------------------------------===//
+
+// Outlines hal.device.switch conditions into functions and inlines conditions.
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineDeviceSwitchesPass();
+
+// Memoizes hal.device.match.id queries in variables initialized on startup.
+std::unique_ptr<OperationPass<ModuleOp>> createMemoizeDeviceQueriesPass();
+
+//===----------------------------------------------------------------------===//
 // Executable translation and optimization
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp b/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp
index a03fb3e..c3bdb42 100644
--- a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp
@@ -212,7 +212,7 @@
   }
 
   // Add the return.
-  builder.create<ReturnOp>(loc, funcResults);
+  builder.create<mlir::ReturnOp>(loc, funcResults);
   return success();
 }
 
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
new file mode 100644
index 0000000..de7f576
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir
@@ -0,0 +1,23 @@
+// RUN: iree-opt -split-input-file -iree-hal-memoize-device-queries %s | IreeFileCheck %s
+
+//      CHECK: func @_device_match_id_0_initializer() -> i1
+// CHECK-NEXT:   %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device
+// CHECK-NEXT:   %[[IS_MATCH:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
+// CHECK-NEXT:   return %[[IS_MATCH]] : i1
+//      CHECK: hal.variable @_device_match_id_0 init(@_device_match_id_0_initializer) : i1
+
+// CHECK: hal.variable @_device_match_id_1
+// CHECK: hal.variable @_device_match_id_2
+
+// CHECK-LABEL: @device_matchers
+func @device_matchers(%device : !hal.device) {
+  // CHECK-NEXT: = hal.variable.load @_device_match_id_0 : i1
+  %0 = hal.device.match.id %device, pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
+  // CHECK-NEXT: = hal.variable.load @_device_match_id_0 : i1
+  %1 = hal.device.match.id %device, pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
+  // CHECK-NEXT: = hal.variable.load @_device_match_id_1 : i1
+  %2 = hal.device.match.id %device, pattern = ["vulkan-v2.?-*"] : (!hal.device) -> i1
+  // CHECK-NEXT: = hal.variable.load @_device_match_id_2 : i1
+  %3 = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1
+  return
+}
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/outline_device_switches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/outline_device_switches.mlir
new file mode 100644
index 0000000..c2ae10f
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/outline_device_switches.mlir
@@ -0,0 +1,98 @@
+// RUN: iree-opt -split-input-file -iree-hal-outline-device-switches %s | IreeFileCheck %s
+
+// CHECK-LABEL: @simple_constants
+// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
+func @simple_constants(%device : !hal.device) -> i32 {
+  // CHECK-DAG: %[[C0:.+]] = constant 0
+  %c0 = constant 0 : i32
+  // CHECK-DAG: %[[C1:.+]] = constant 1
+  %c1 = constant 1 : i32
+  // CHECK-DAG: %[[C2:.+]] = constant 2
+  %c2 = constant 2 : i32
+  %0 = hal.device.switch(%device : !hal.device) -> i32
+    // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: cond_br %[[IS0]], ^bb1, ^bb2
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT: %[[RES0:.+]] = call @simple_constants_switch_0_0(%[[C1]]) : (i32) -> i32
+    // CHECK-NEXT: br ^bb7(%[[RES0]] : i32)
+    #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) {
+      hal.return %c1a : i32
+    },
+    // CHECK-NEXT: ^bb2:
+    // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id %arg0, pattern = ["vmla"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id %arg0, pattern = ["vulkan-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1:.+]] = or %[[IS1L]], %[[IS1R]] : i1
+    // CHECK-NEXT: cond_br %[[IS1]], ^bb3, ^bb4
+    // CHECK-NEXT: ^bb3:
+    // CHECK-NEXT: %[[RES1:.+]] = call @simple_constants_switch_0_1(%[[C2]]) : (i32) -> i32
+    // CHECK-NEXT: br ^bb7(%[[RES1]] : i32)
+    #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>(%c2a = %c2 : i32) {
+      hal.return %c2a : i32
+    },
+    // CHECK-NEXT: ^bb4:
+    // CHECK-NEXT: %[[IS2:.+]] = constant 1 : i1
+    // CHECK-NEXT: cond_br %[[IS2]], ^bb5, ^bb6
+    // CHECK-NEXT: ^bb5:
+    // CHECK-NEXT: %[[RES2:.+]] = call @simple_constants_switch_0_2(%[[C0]]) : (i32) -> i32
+    // CHECK-NEXT: br ^bb7(%[[RES2]] : i32)
+    #hal.match.always(%c0a = %c0 : i32) {
+      hal.return %c0a : i32
+    }
+    // CHECK-NEXT: ^bb6:
+    // CHECK-NEXT: iree.unreachable
+  // CHECK-NEXT: ^bb7(%[[RES:.+]]: i32):
+  // CHECK-NEXT: return %[[RES]] : i32
+  return %0 : i32
+}
+
+// CHECK: func @simple_constants_switch_0_0(%arg0: i32) -> i32
+// CHECK-NEXT: return %arg0 : i32
+// CHECK: func @simple_constants_switch_0_1(%arg0: i32) -> i32
+// CHECK-NEXT: return %arg0 : i32
+// CHECK: func @simple_constants_switch_0_2(%arg0: i32) -> i32
+// CHECK-NEXT: return %arg0 : i32
+
+// -----
+
+// CHECK-LABEL: @no_results
+// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
+func @no_results(%device : !hal.device) {
+  hal.device.switch(%device : !hal.device)
+    // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: cond_br %[[IS0]], ^bb1, ^bb2
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT: call @no_results_switch_0_0() : () -> ()
+    // CHECK-NEXT: br ^bb7
+    #hal.device.match.id<"vulkan-v1.?-*">() {
+      hal.return
+    },
+    // CHECK-NEXT: ^bb2:
+    // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id %arg0, pattern = ["vmla"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id %arg0, pattern = ["vulkan-*"] : (!hal.device) -> i1
+    // CHECK-NEXT: %[[IS1:.+]] = or %[[IS1L]], %[[IS1R]] : i1
+    // CHECK-NEXT: cond_br %[[IS1]], ^bb3, ^bb4
+    // CHECK-NEXT: ^bb3:
+    // CHECK-NEXT: call @no_results_switch_0_1() : () -> ()
+    // CHECK-NEXT: br ^bb7
+    #hal.match.any<[#hal.device.match.id<"vmla">, #hal.device.match.id<"vulkan-*">]>() {
+      hal.return
+    },
+    // CHECK-NEXT: ^bb4:
+    // CHECK-NEXT: %[[IS2:.+]] = constant 1 : i1
+    // CHECK-NEXT: cond_br %[[IS2]], ^bb5, ^bb6
+    // CHECK-NEXT: ^bb5:
+    // CHECK-NEXT: call @no_results_switch_0_2() : () -> ()
+    // CHECK-NEXT: br ^bb7
+    #hal.match.always() {
+      hal.return
+    }
+    // CHECK-NEXT: ^bb6:
+    // CHECK-NEXT: br ^bb7
+  // CHECK-NEXT: ^bb7:
+  // CHECK-NEXT: return
+  return
+}
+
+// CHECK: func @no_results_switch_0_0()
+// CHECK: func @no_results_switch_0_1()
+// CHECK: func @no_results_switch_0_2()
diff --git a/iree/compiler/Dialect/HAL/hal.imports.mlir b/iree/compiler/Dialect/HAL/hal.imports.mlir
index 559b00f..6ab6153 100644
--- a/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -352,6 +352,13 @@
 ) -> !vm.ref<!hal.allocator>
 attributes {nosideeffects}
 
+// Returns true if the device ID matches the pattern.
+vm.import @device.match.id(
+  %device : !vm.ref<!hal.device>,
+  %pattern : !vm.ref<!iree.byte_buffer>
+) -> i32
+attributes {nosideeffects}
+
 //===----------------------------------------------------------------------===//
 // iree::hal::ExecutableCache
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/IREE/IR/IREEAttributes.h b/iree/compiler/Dialect/IREE/IR/IREEAttributes.h
index 4b2ae13..c4a9043 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEAttributes.h
+++ b/iree/compiler/Dialect/IREE/IR/IREEAttributes.h
@@ -32,6 +32,10 @@
 namespace AttrKind {
 enum Kind {
   DescriptorSetLayoutBindingAttr = IREE::AttrKind::FIRST_HAL_ATTR,
+  MatchAlwaysAttr,
+  MatchAllAttr,
+  MatchAnyAttr,
+  DeviceMatchIDAttr,
 };
 }  // namespace AttrKind
 }  // namespace HAL
diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.td b/iree/compiler/Dialect/IREE/IR/IREEOps.td
index d200a80..f307237 100644
--- a/iree/compiler/Dialect/IREE/IR/IREEOps.td
+++ b/iree/compiler/Dialect/IREE/IR/IREEOps.td
@@ -90,4 +90,27 @@
   let hasCanonicalizer = 1;
 }
 
+def IREE_UnreachableOp : IREE_Op<"unreachable", [NoSideEffect, Terminator]> {
+  let summary = [{unreachable assertion op}];
+  let description = [{
+    Signals to the compiler that the parent block should not be reachable.
+    This may be converted into a runtime assertion, though ideally they are
+    stripped during translation.
+
+    ```mlir
+    ^bb0:
+      %true = constant 1 : i1
+      cond_br %true, ^bb2, ^bb1
+    ^bb1:
+      // Indicates that this branch should never be taken.
+      iree.unreachable
+    ^bb2:
+      ...
+
+    ```
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif  // IREE_DIALECT_IREE_OPS
diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h
index 9a1629d..e4605cf 100644
--- a/iree/compiler/Dialect/IREE/IR/IREETypes.h
+++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h
@@ -81,7 +81,10 @@
 
 namespace Strings {
 namespace TypeKind {
-enum Kind { String = IREE::TypeKind::FIRST_STRING_TYPE, StringTensor };
+enum Kind {
+  String = IREE::TypeKind::FIRST_STRING_TYPE,
+  StringTensor,
+};
 }  // namespace TypeKind
 }  // namespace Strings
 
diff --git a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
index 3734ff5..42bf726 100644
--- a/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
+++ b/iree/compiler/Dialect/IREE/Tools/StructAttrGen.cpp
@@ -91,12 +91,17 @@
 };
 
 static void emitStructClass(const StructAttr &structAttr, raw_ostream &os) {
+  if (!structAttr.getAllFields().empty()) {
+    os << formatv(R"(
+namespace detail {
+struct {0}Storage;
+}  // namespace detail
+)",
+                  structAttr.getStructClassName());
+  }
   os << formatv(R"(
 // {0}
-namespace detail {
-struct {1}Storage;
-}  // namespace detail
-class {1} : public mlir::Attribute::AttrBase<{1}, mlir::Attribute, detail::{1}Storage> {
+class {1} : public mlir::Attribute::AttrBase<{1}, mlir::Attribute, {3}Storage> {
  public:
   using Base::Base;
 
@@ -105,7 +110,10 @@
 
 )",
                 structAttr.getDescription(), structAttr.getStructClassName(),
-                structAttr.getStructKind());
+                structAttr.getStructKind(),
+                structAttr.getAllFields().empty()
+                    ? "Attribute"
+                    : "detail::" + structAttr.getStructClassName());
 
   if (!structAttr.getAllFields().empty()) {
     os << "  static LogicalResult verifyConstructionInvariants(\n";
@@ -134,12 +142,14 @@
   os << ");\n\n";
 
   // Attribute return type constructor (APInt, etc).
-  os << formatv("  static {0} get(\n", structAttr.getStructClassName());
-  for (auto field : structAttr.getAllFields()) {
-    auto type = field.getType();
-    os << formatv("      {0} {1},\n", type.getReturnType(), field.getName());
+  if (!structAttr.getAllFields().empty()) {
+    os << formatv("  static {0} get(\n", structAttr.getStructClassName());
+    for (auto field : structAttr.getAllFields()) {
+      auto type = field.getType();
+      os << formatv("      {0} {1},\n", type.getReturnType(), field.getName());
+    }
+    os << "      mlir::MLIRContext* context);\n";
   }
-  os << "      mlir::MLIRContext* context);\n";
 
   os << R"(
   static Attribute parse(DialectAsmParser &p);
@@ -322,11 +332,13 @@
                   structAttr.getAllFields().front().getName());
   }
 
-  os << formatv("  return Base::get(context, AttrKind::{0},\n",
+  os << formatv("  return Base::get(context, AttrKind::{0}",
                 structAttr.getStructClassName());
-  os << "                   ";
-  interleaveComma(structAttr.getAllFields(), os,
-                  [&](StructFieldAttr field) { os << field.getName(); });
+  if (!structAttr.getAllFields().empty()) {
+    os << "\n,                   ";
+    interleaveComma(structAttr.getAllFields(), os,
+                    [&](StructFieldAttr field) { os << field.getName(); });
+  }
   os << ");\n";
 
   os << "}\n\n";
@@ -412,12 +424,16 @@
   }
   os << "\n";
 
-  emitStorageDef(structAttr, os);
-  emitVerifierDef(structAttr, os);
+  if (!structAttr.getAllFields().empty()) {
+    emitStorageDef(structAttr, os);
+    emitVerifierDef(structAttr, os);
+  }
   emitAttrFactoryDef(structAttr, os);
-  emitTypedFactoryDef(structAttr, os);
-  for (auto field : structAttr.getAllFields()) {
-    emitAccessorDefs(structAttr, field, os);
+  if (!structAttr.getAllFields().empty()) {
+    emitTypedFactoryDef(structAttr, os);
+    for (auto field : structAttr.getAllFields()) {
+      emitAccessorDefs(structAttr, field, os);
+    }
   }
   emitWalkStorageDef(structAttr, os);
 
diff --git a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index c64d620..841fa31 100644
--- a/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -146,11 +146,11 @@
   }
 };
 
-class ReturnOpConversion : public OpConversionPattern<ReturnOp> {
+class ReturnOpConversion : public OpConversionPattern<mlir::ReturnOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult matchAndRewrite(
-      ReturnOp srcOp, ArrayRef<Value> operands,
+      mlir::ReturnOp srcOp, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<IREE::VM::ReturnOp>(srcOp, operands);
     return success();
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp
index d82ba6d..82fc1be 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -378,6 +378,7 @@
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), /*elidedAttrs=*/{
                                          "sym_name",
                                          "is_mutable",
+                                         "initializer",
                                          "initial_value",
                                          "type",
                                      });
diff --git a/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp b/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp
index 8f554b7..ce0716d 100644
--- a/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp
+++ b/iree/compiler/Translation/CodegenUtils/CodegenUtils.cpp
@@ -147,8 +147,9 @@
   if (!isDispatchFuncImpl(funcOp))
     return op->emitError("expected operation to be within a dispatch function");
   MLIRContext *context = op->getContext();
-  SmallVector<int32_t, 3> workGroupSizeVec(workGroupSize.begin(),
-                                           workGroupSize.end());
+  SmallVector<int32_t, 3> workGroupSizeVec(llvm::map_range(
+      workGroupSize,
+      [](int64_t value) { return static_cast<int32_t>(value); }));
   workGroupSizeVec.resize(3, 1);
   funcOp.setAttr(spirv::getEntryPointABIAttrName(),
                  spirv::getEntryPointABIAttr(workGroupSizeVec, context));
diff --git a/iree/hal/api.cc b/iree/hal/api.cc
index e66c622..08e008c 100644
--- a/iree/hal/api.cc
+++ b/iree/hal/api.cc
@@ -1023,6 +1023,14 @@
   return reinterpret_cast<iree_hal_allocator_t*>(handle->allocator());
 }
 
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_hal_device_id(iree_hal_device_t* device) {
+  // Null devices have no identifier: report IREE_STRING_VIEW_EMPTY.
+  if (!device) return IREE_STRING_VIEW_EMPTY;
+  const auto& device_id = reinterpret_cast<Device*>(device)->info().id();
+  return iree_string_view_t{device_id.data(), device_id.size()};
+}
+
 //===----------------------------------------------------------------------===//
 // iree::hal::Driver
 //===----------------------------------------------------------------------===//
diff --git a/iree/hal/api.h b/iree/hal/api.h
index 22e58fa..89a9689 100644
--- a/iree/hal/api.h
+++ b/iree/hal/api.h
@@ -960,6 +960,12 @@
 IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL
 iree_hal_device_allocator(iree_hal_device_t* device);
 
+// Returns the device identifier.
+// This identifier may vary based on the runtime device type; for example, a
+// Vulkan device may return `vulkan-v1.1` or `vulkan-v1.2-spec1`.
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_hal_device_id(iree_hal_device_t* device);
+
 #endif  // IREE_API_NO_PROTOTYPES
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/hal/dawn/dawn_driver.cc b/iree/hal/dawn/dawn_driver.cc
index f4b0b6f..34dd792 100644
--- a/iree/hal/dawn/dawn_driver.cc
+++ b/iree/hal/dawn/dawn_driver.cc
@@ -37,9 +37,10 @@
   // supported_features |= DeviceFeature::kProfiling;
 
   // TODO(scotttodd): more clever/sanitized device naming.
+  std::string device_id = "dawn";
   std::string device_name = absl::StrCat("dawn-", adapter->GetPCIInfo().name);
 
-  return DeviceInfo(device_name, supported_features,
+  return DeviceInfo(device_id, device_name, supported_features,
                     reinterpret_cast<DriverDeviceID>(adapter));
 }
 
diff --git a/iree/hal/device_info.h b/iree/hal/device_info.h
index 6d5661f..4fb5ffb 100644
--- a/iree/hal/device_info.h
+++ b/iree/hal/device_info.h
@@ -63,12 +63,20 @@
 // TODO(benvanik): device info (caps, physical mappings, etc).
 class DeviceInfo {
  public:
-  DeviceInfo(std::string name, DeviceFeatureBitfield supported_features,
+  DeviceInfo(std::string id, std::string name,
+             DeviceFeatureBitfield supported_features,
              DriverDeviceID device_id = 0)
-      : name_(std::move(name)),
+      : id_(std::move(id)),
+        name_(std::move(name)),
         supported_features_(supported_features),
         device_id_(device_id) {}
 
+  // Machine-friendly device identifier used to match the device against
+  // compiler-generated patterns. This should be consistent with the device IDs
+  // emitted by the compiler. For example: `vulkan-v1.1-spec`.
+  const std::string& id() const { return id_; }
+
+  // Human-friendly device name.
   const std::string& name() const { return name_; }
 
   // Features supported by the device.
@@ -82,6 +90,7 @@
   DriverDeviceID device_id() const { return device_id_; }
 
  private:
+  const std::string id_;
   const std::string name_;
   const DeviceFeatureBitfield supported_features_;
   DriverDeviceID device_id_;
diff --git a/iree/hal/llvmjit/llvmjit_driver.cc b/iree/hal/llvmjit/llvmjit_driver.cc
index 7b22ffd..da4c6b1 100644
--- a/iree/hal/llvmjit/llvmjit_driver.cc
+++ b/iree/hal/llvmjit/llvmjit_driver.cc
@@ -30,7 +30,7 @@
   // supported_features |= DeviceFeature::kDebugging;
   // supported_features |= DeviceFeature::kCoverage;
   // supported_features |= DeviceFeature::kProfiling;
-  DeviceInfo device_info("llvm", supported_features);
+  DeviceInfo device_info("llvmjit", "llvm", supported_features);
   // TODO(benvanik): device info.
   return device_info;
 }
diff --git a/iree/hal/vmla/vmla_driver.cc b/iree/hal/vmla/vmla_driver.cc
index 59e2ad1..6f21f7f 100644
--- a/iree/hal/vmla/vmla_driver.cc
+++ b/iree/hal/vmla/vmla_driver.cc
@@ -35,7 +35,7 @@
   // supported_features |= DeviceFeature::kDebugging;
   // supported_features |= DeviceFeature::kCoverage;
   // supported_features |= DeviceFeature::kProfiling;
-  DeviceInfo device_info("vmla", supported_features);
+  DeviceInfo device_info("vmla", "vmla", supported_features);
   // TODO(benvanik): device info.
   return device_info;
 }
diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc
index b67a762..f0da43f 100644
--- a/iree/hal/vulkan/vulkan_driver.cc
+++ b/iree/hal/vulkan/vulkan_driver.cc
@@ -73,7 +73,7 @@
   // supported_features |= DeviceFeature::kDebugging;
   // supported_features |= DeviceFeature::kCoverage;
   // supported_features |= DeviceFeature::kProfiling;
-  return DeviceInfo(std::move(name), supported_features,
+  return DeviceInfo("vulkan", std::move(name), supported_features,
                     reinterpret_cast<DriverDeviceID>(physical_device));
 }
 
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc
index ca5685d..19e22fd 100644
--- a/iree/modules/hal/hal_module.cc
+++ b/iree/modules/hal/hal_module.cc
@@ -590,9 +590,12 @@
       vm::ref<iree_hal_executable_t> executable, int32_t entry_point,
       uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
     IREE_TRACE_SCOPE0("HALModuleState::CommandBufferDispatch");
-    return reinterpret_cast<CommandBuffer*>(command_buffer.get())
-        ->Dispatch(reinterpret_cast<Executable*>(executable.get()), entry_point,
-                   {workgroup_x, workgroup_y, workgroup_z});
+    RETURN_IF_ERROR(FromApiStatus(
+        iree_hal_command_buffer_dispatch(command_buffer.get(), executable.get(),
+                                         entry_point, workgroup_x, workgroup_y,
+                                         workgroup_z),
+        IREE_LOC));
+    return OkStatus();
   }
 
   Status CommandBufferDispatchIndirect(
@@ -600,11 +603,12 @@
       vm::ref<iree_hal_executable_t> executable, int32_t entry_point,
       vm::ref<iree_hal_buffer_t> workgroups_buffer, int32_t workgroups_offset) {
     IREE_TRACE_SCOPE0("HALModuleState::CommandBufferDispatchIndirect");
-    return reinterpret_cast<CommandBuffer*>(command_buffer.get())
-        ->DispatchIndirect(reinterpret_cast<Executable*>(executable.get()),
-                           entry_point,
-                           reinterpret_cast<Buffer*>(workgroups_buffer.get()),
-                           workgroups_offset);
+    RETURN_IF_ERROR(
+        FromApiStatus(iree_hal_command_buffer_dispatch_indirect(
+                          command_buffer.get(), executable.get(), entry_point,
+                          workgroups_buffer.get(), workgroups_offset),
+                      IREE_LOC));
+    return OkStatus();
   }
 
   //===--------------------------------------------------------------------===//
@@ -673,6 +677,15 @@
     return vm::retain_ref(iree_hal_device_allocator(device.get()));
   }
 
+  StatusOr<int32_t> DeviceMatchID(vm::ref<iree_hal_device_t> device,
+                                  absl::string_view pattern) {
+    // Yields 1 when the device identifier matches |pattern| and 0 otherwise.
+    const iree_string_view_t pattern_view{pattern.data(), pattern.size()};
+    const bool matches = iree_string_view_match_pattern(
+        iree_hal_device_id(device.get()), pattern_view);
+    return matches ? 1 : 0;
+  }
+
   //===--------------------------------------------------------------------===//
   // iree::hal::ExecutableCache
   //===--------------------------------------------------------------------===//
@@ -831,6 +844,7 @@
 
     vm::MakeNativeFunction("device.allocator",
                            &HALModuleState::DeviceAllocator),
+    vm::MakeNativeFunction("device.match.id", &HALModuleState::DeviceMatchID),
 
     vm::MakeNativeFunction("executable_cache.create",
                            &HALModuleState::ExecutableCacheCreate),