Adding `vm.br_table` op. (#15286)
This implements a 0-based dense branch table that can be lowered from
`cf.switch` which is a weird op but what we have and what comes from
`scf.index_switch`. The exact encoding may change in the future to make
interpreting faster but as structured it should be good for basic JITs.
The initial usage of this will come from entirely dense (or mostly
dense) `scf.index_switch` ops so only the most basic optimization for
offsetting the base of the table is implemented. There's no
canonicalization/folding as `cf.switch` has most of that but we should
probably add some of the same handlers in the future.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
index ad9c441..1a60825 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.cpp
@@ -520,15 +520,24 @@
 
 SmallVector<std::pair<Register, Register>, 8>
 RegisterAllocation::remapSuccessorRegisters(Operation *op, int successorIndex) {
+  // Resolve the successor block and the operands |op| forwards to it, then
+  // defer to the block-based overload, passing along the location of |op|
+  // along with them.
+  auto *targetBlock = op->getSuccessor(successorIndex);
+  auto targetOperands = cast<BranchOpInterface>(op)
+                            .getSuccessorOperands(successorIndex)
+                            .getForwardedOperands();
+  return remapSuccessorRegisters(op->getLoc(), targetBlock, targetOperands);
+}
+
+SmallVector<std::pair<Register, Register>, 8>
+RegisterAllocation::remapSuccessorRegisters(Location loc, Block *targetBlock,
+                                            OperandRange targetOperands) {
   // Compute the initial directed graph of register movements.
   // This may contain cycles ([reg 0->1], [reg 1->0], ...) that would not be
   // possible to evaluate as a direct remapping.
   SmallVector<std::pair<Register, Register>, 8> srcDstRegs;
-  auto *targetBlock = op->getSuccessor(successorIndex);
-  auto operands = cast<BranchOpInterface>(op)
-                      .getSuccessorOperands(successorIndex)
-                      .getForwardedOperands();
-  for (auto it : llvm::enumerate(operands)) {
+  for (auto it : llvm::enumerate(targetOperands)) {
     auto srcReg = mapToRegister(it.value());
     BlockArgument targetArg = targetBlock->getArgument(it.index());
     auto dstReg = mapToRegister(targetArg);
@@ -580,7 +586,7 @@
     assert(getMaxI32RegisterOrdinal() <= Register::kInt32RegisterCount &&
            "spilling i32 regs");
     if (getMaxI32RegisterOrdinal() > Register::kInt32RegisterCount) {
-      op->emitOpError() << "spilling entire i32 register address space";
+      mlir::emitError(loc) << "spilling entire i32 register address space";
     }
   }
   if (localScratchRefRegCount > 0) {
@@ -589,7 +595,7 @@
     assert(getMaxRefRegisterOrdinal() <= Register::kRefRegisterCount &&
            "spilling ref regs");
     if (getMaxRefRegisterOrdinal() > Register::kRefRegisterCount) {
-      op->emitOpError() << "spilling entire ref register address space";
+      mlir::emitError(loc) << "spilling entire ref register address space";
     }
   }
 
diff --git a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h
index 8dab791..2e345a9 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Analysis/RegisterAllocation.h
@@ -200,6 +200,11 @@
   // may have their move bit set.
   SmallVector<std::pair<Register, Register>, 8>
   remapSuccessorRegisters(Operation *op, int successorIndex);
+  // Overload taking the target block and the operands forwarded to it
+  // explicitly; |loc| is the location used for any diagnostics emitted.
+  SmallVector<std::pair<Register, Register>, 8>
+  remapSuccessorRegisters(Location loc, Block *targetBlock,
+                          OperandRange targetOperands);
 
 private:
   int maxI32RegisterOrdinal_ = -1;
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index f025d2a..dbafa45 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -1085,6 +1085,95 @@
   }
 };
 
+// Lowers cf.switch to vm.br_table (or to vm.br when only a default exists).
+//
+// vm.br_table is a dense table indexed from 0, so the case values are sorted,
+// rebased so that the smallest case becomes index 0, and interior gaps are
+// filled with branches to the default destination. Negative case values and
+// flags that do not convert to i32 are reported as match failures.
+class SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(cf::SwitchOp srcOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Special handling for default only ops: just jump to default.
+    if (srcOp.getCaseDestinations().empty()) {
+      rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>(
+          srcOp, srcOp.getDefaultDestination(), adaptor.getDefaultOperands());
+      return success();
+    }
+
+    // vm.br_table and the vm.sub.i32 used for rebasing both take an i32 index;
+    // bail out before creating any IR that would fail verification.
+    Value index = adaptor.getFlag();
+    if (!index.getType().isInteger(32)) {
+      return rewriter.notifyMatchFailure(
+          srcOp, "only i32 switch flags are supported by the VM");
+    }
+
+    // NOTE: cf.switch can have sparse indices but we cannot; instead we fill
+    // any gaps with branches to the default block. This is wasteful but keeps
+    // the runtime super simple - if we have offset or really sparse tables
+    // (default + case 10000 + case 400000) we can optimize those in the
+    // compiler by using multiple branch tables, inverse lookups via a lookup
+    // op, etc.
+    //
+    // To make this simple here we get all cases, sort them, and then walk in
+    // order while filling gaps as needed.
+    SmallVector<std::pair<int, int64_t>> caseValues;
+    for (auto [i, value] : llvm::enumerate(srcOp.getCaseValues().value())) {
+      caseValues.push_back(std::make_pair(i, value.getSExtValue()));
+    }
+    llvm::stable_sort(caseValues,
+                      [](std::pair<int, int64_t> a, std::pair<int, int64_t> b) {
+                        return a.second < b.second;
+                      });
+
+    // Sanity check negative values, which are tricky.
+    int64_t minValue = caseValues.front().second;
+    if (minValue < 0) {
+      return rewriter.notifyMatchFailure(
+          srcOp, "negative case indices are not supported by the VM (today); "
+                 "needs positive offsetting");
+    }
+
+    // If the first branch is offset from 0 then we can subtract that out to
+    // avoid holes at the start of the table.
+    if (minValue > 0) {
+      index = rewriter.create<IREE::VM::SubI32Op>(
+          srcOp.getLoc(), rewriter.getI32Type(), index,
+          rewriter.create<IREE::VM::ConstI32Op>(
+              srcOp.getLoc(), static_cast<int32_t>(minValue)));
+      for (auto &[i, value] : caseValues) {
+        value -= minValue;
+      }
+    }
+
+    // Emit each dense case, filling interior holes as needed. The converted
+    // (adaptor) operands are forwarded for every destination so that values
+    // remapped by the conversion are used, as with the default operands.
+    auto convertedCaseOperands = adaptor.getCaseOperands();
+    SmallVector<Block *> caseDestinations;
+    SmallVector<ValueRange> caseOperands;
+    int64_t nextIndex = 0;
+    for (auto [i, value] : caseValues) {
+      for (; nextIndex < value; ++nextIndex) {
+        // Hole to fill.
+        caseDestinations.push_back(srcOp.getDefaultDestination());
+        caseOperands.push_back(adaptor.getDefaultOperands());
+      }
+      caseDestinations.push_back(srcOp.getCaseDestinations()[i]);
+      caseOperands.push_back(convertedCaseOperands[i]);
+      nextIndex = value + 1;
+    }
+
+    rewriter.replaceOpWithNewOp<IREE::VM::BranchTableOp>(
+        srcOp, index, adaptor.getDefaultOperands(), caseOperands,
+        srcOp.getDefaultDestination(), caseDestinations);
+    return success();
+  }
+};
+
 } // namespace
 
 void populateStandardToVMPatterns(MLIRContext *context,
@@ -1093,9 +1167,9 @@
   patterns
       .insert<AssertOpConversion, BranchOpConversion, CallOpConversion,
               CmpI32OpConversion, CmpI64OpConversion, CmpF32OpConversion,
-              CondBranchOpConversion, ModuleOpConversion, FuncOpConversion,
-              ExternalFuncOpConversion, ReturnOpConversion, SelectOpConversion>(
-          typeConverter, context);
+              CondBranchOpConversion, SwitchOpConversion, ModuleOpConversion,
+              FuncOpConversion, ExternalFuncOpConversion, ReturnOpConversion,
+              SelectOpConversion>(typeConverter, context);
 
   // TODO(#2878): figure out how to pass the type converter in a supported way.
   // Right now if we pass the type converter as the first argument - triggering
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/control_flow_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/control_flow_ops.mlir
index fa46d3b..82354fa 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/control_flow_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/test/control_flow_ops.mlir
@@ -74,6 +74,43 @@
 }
 
 // -----
+// CHECK-LABEL: @t005_br_table
+module @t005_br_table {
+
+module {
+  // CHECK: vm.func private @my_fn
+  // CHECK-SAME: %[[FLAG:[a-zA-Z0-9$._-]+]]
+  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]
+  // CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]
+  // CHECK-SAME: %[[ARG3:[a-zA-Z0-9$._-]+]]
+  func.func @my_fn(%flag: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
+    // CHECK: %[[INDEX:.+]] = vm.sub.i32 %[[FLAG]], %c100
+    //      CHECK: vm.br_table %[[INDEX]] {
+    // CHECK-NEXT:   default: ^bb1(%[[ARG1]] : i32),
+    // CHECK-NEXT:   0: ^bb1(%[[ARG2]] : i32),
+    // CHECK-NEXT:   1: ^bb1(%[[ARG1]] : i32),
+    // CHECK-NEXT:   2: ^bb1(%[[ARG1]] : i32),
+    // CHECK-NEXT:   3: ^bb1(%[[ARG1]] : i32),
+    // CHECK-NEXT:   4: ^bb1(%[[ARG3]] : i32),
+    // CHECK-NEXT:   5: ^bb1(%[[ARG1]] : i32),
+    // CHECK-NEXT:   6: ^bb2
+    // CHECK-NEXT: }
+    cf.switch %flag : i32, [
+      default: ^bb1(%arg1 : i32),
+      104: ^bb1(%arg3 : i32),
+      100: ^bb1(%arg2 : i32),
+      106: ^bb2
+    ]
+  ^bb1(%0 : i32):
+    return %0 : i32
+  ^bb2:
+    return %arg1 : i32
+  }
+}
+
+}
+
+// -----
 // CHECK-LABEL: @t006_assert
 module @t006_assert {
 
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
index 1584742..113e12e 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp
@@ -3064,6 +3064,52 @@
   }
 };
 
+// EmitC does not support cf.switch so we turn the branch table into a long
+// sequence of conditional branches: test block i compares the index against
+// i and either branches to case i's destination or falls through to test
+// block i + 1; the last block branches unconditionally to the default.
+class BranchTableOpConversion
+    : public EmitCConversionPattern<IREE::VM::BranchTableOp> {
+  using Adaptor = IREE::VM::BranchTableOp::Adaptor;
+  using EmitCConversionPattern<IREE::VM::BranchTableOp>::EmitCConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(IREE::VM::BranchTableOp op, Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto caseDestinations = op.getCaseDestinations();
+    auto caseOperands = adaptor.getCaseOperands();
+    SmallVector<Block *> caseBlocks;
+    {
+      // Create one test block per case plus one for the default branch right
+      // after the current block. An explicit region iterator is used so this
+      // also works when the current block is the last one in the region (in
+      // which case there is no next block to insert before).
+      OpBuilder::InsertionGuard guard(rewriter);
+      Block *currentBlock = rewriter.getInsertionBlock();
+      Region *region = currentBlock->getParent();
+      Region::iterator insertPt = std::next(Region::iterator(currentBlock));
+      for (size_t i = 0; i < caseDestinations.size(); ++i)
+        caseBlocks.push_back(rewriter.createBlock(region, insertPt));
+      caseBlocks.push_back(rewriter.createBlock(region, insertPt)); // default
+    }
+    rewriter.create<IREE::VM::BranchOp>(op.getLoc(), caseBlocks.front());
+    for (size_t i = 0; i < caseDestinations.size(); ++i) {
+      rewriter.setInsertionPointToStart(caseBlocks[i]);
+      Value cmp = rewriter.create<IREE::VM::CmpEQI32Op>(
+          op.getLoc(), rewriter.getI1Type(), adaptor.getIndex(),
+          rewriter.create<IREE::VM::ConstI32Op>(op.getLoc(), i));
+      rewriter.create<IREE::VM::CondBranchOp>(
+          op.getLoc(), cmp, caseDestinations[i], caseOperands[i],
+          caseBlocks[i + 1], ValueRange{});
+    }
+    // The trailing block handles every index that matched no case.
+    rewriter.setInsertionPointToStart(caseBlocks.back());
+    rewriter.replaceOpWithNewOp<IREE::VM::BranchOp>(
+        op, op.getDefaultDestination(), adaptor.getDefaultOperands());
+    return success();
+  }
+};
+
 class ReturnOpConversion : public EmitCConversionPattern<IREE::VM::ReturnOp> {
   using Adaptor = IREE::VM::ReturnOp::Adaptor;
   using EmitCConversionPattern<IREE::VM::ReturnOp>::EmitCConversionPattern;
@@ -4306,6 +4343,7 @@
     CallOpConversion<IREE::VM::CallVariadicOp>,
     CompareRefNotZeroOpConversion,
     CondBranchOpConversion,
+    BranchTableOpConversion,
     ConstOpConversion<IREE::VM::ConstF32Op>,
     ConstOpConversion<IREE::VM::ConstF64Op>,
     ConstOpConversion<IREE::VM::ConstI32Op>,
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir
index 5ddc33b..3a2e9ec 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/test/control_flow_ops.mlir
@@ -279,6 +279,56 @@
 
 // -----
 
+// CHECK-LABEL: @my_module_br_table_empty
+vm.module @my_module {
+  vm.func @br_table_empty(%arg0: i32, %arg1: i32) -> i32 {
+    //  CHECK-NOT: vm.br_table
+    //      CHECK:  cf.br ^bb1
+    // CHECK-NEXT: ^bb1:
+    // CHECK-NEXT:  cf.br ^bb2(%arg4 : i32)
+    // CHECK-NEXT: ^bb2(%0: i32):
+    //      CHECK:  return
+    vm.br_table %arg0 {
+      default: ^bb1(%arg1 : i32)
+    }
+  ^bb1(%0 : i32):
+    // CHECK: return
+    // CHECK-NOT: vm.return
+    vm.return %0 : i32
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @my_module_br_table
+vm.module @my_module {
+  vm.func @br_table(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+    //  CHECK-NOT: vm.br_table
+    //      CHECK:  cf.br ^bb1
+    // CHECK-NEXT: ^bb1:
+    //      CHECK:  emitc.call "vm_cmp_eq_i32"
+    //      CHECK:  emitc.call "vm_cmp_nz_i32"
+    //      CHECK:  cf.cond_br %{{.+}}, ^bb5(%arg4 : i32), ^bb2
+    //      CHECK: ^bb2:
+    //      CHECK:  emitc.call "vm_cmp_eq_i32"
+    //      CHECK:  emitc.call "vm_cmp_nz_i32"
+    //      CHECK:  cf.cond_br %{{.+}}, ^bb5(%arg5 : i32), ^bb3
+    //      CHECK: ^bb3:
+    //      CHECK:  cf.br ^bb4(%arg3 : i32)
+    vm.br_table %arg0 {
+      default: ^bb1(%arg0 : i32),
+      0: ^bb2(%arg1 : i32),
+      1: ^bb2(%arg2 : i32)
+    }
+  ^bb1(%0 : i32):
+    vm.return %0 : i32
+  ^bb2(%1 : i32):
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
 vm.module @my_module {
   // CHECK-LABEL: @my_module_fail
   vm.func @fail(%arg0 : i32) {
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
index 93c507a..b7cc789 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMBase.td
@@ -126,6 +126,8 @@
     "e.encodeStrAttr(getOperation()->getAttrOfType<StringAttr>(\"" # name # "\"))">;
 class VM_EncBranch<string blockName, string operandsName, int successorIndex> : VM_EncEncodeExpr<
     "e.encodeBranch({0}(), " # operandsName # "(), " # successorIndex # ")", [blockName]>;
+class VM_EncBranchTable<string caseBlockNames, string caseOperandsName, int baseSuccessorIndex> : VM_EncEncodeExpr<
+    "e.encodeBranchTable(" # caseBlockNames # "(), " # caseOperandsName # "(), " # baseSuccessorIndex # ")">;
 class VM_EncOperand<string name, int ordinal> : VM_EncEncodeExpr<
     "e.encodeOperand({0}(), " # ordinal # ")", [name]>;
 class VM_EncVariadicOperands<string name> : VM_EncEncodeExpr<
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMFuncEncoder.h b/compiler/src/iree/compiler/Dialect/VM/IR/VMFuncEncoder.h
index 9c61414..536eb12 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMFuncEncoder.h
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMFuncEncoder.h
@@ -61,6 +61,14 @@
                                      Operation::operand_range operands,
                                      int successorIndex) = 0;
 
+  // Encodes a branch table: the case destinations in |caseSuccessors| paired
+  // element-wise with the operands forwarded to each in |caseOperands|.
+  // |baseSuccessorIndex| relates the table entries to the op's successor
+  // indices; how it is applied is up to the implementation.
+  virtual LogicalResult encodeBranchTable(SuccessorRange caseSuccessors,
+                                          OperandRangeRange caseOperands,
+                                          int baseSuccessorIndex) = 0;
+
   // Encodes an operand value (by reference).
   virtual LogicalResult encodeOperand(Value value, int ordinal) = 0;
 
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesCore.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesCore.td
index 6667afb..1757cbd 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesCore.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesCore.td
@@ -47,7 +47,7 @@
   string opcodeEnumTag = enumTag;
 }
 
-// Next available opcode: 0x83
+// Next available opcode: 0x84
 
 // Globals:
 def VM_OPC_GlobalLoadI32         : VM_OPC<0x00, "GlobalLoadI32">;
@@ -184,6 +184,7 @@
 def VM_OPC_Block                 : VM_OPC<0x79, "Block">;
 def VM_OPC_Branch                : VM_OPC<0x56, "Branch">;
 def VM_OPC_CondBranch            : VM_OPC<0x57, "CondBranch">;
+def VM_OPC_BranchTable           : VM_OPC<0x83, "BranchTable">;
 def VM_OPC_Call                  : VM_OPC<0x58, "Call">;
 def VM_OPC_CallVariadic          : VM_OPC<0x59, "CallVariadic">;
 def VM_OPC_Return                : VM_OPC<0x5A, "Return">;
@@ -347,6 +348,7 @@
 
     VM_OPC_Branch,
     VM_OPC_CondBranch,
+    VM_OPC_BranchTable,
     VM_OPC_Call,
     VM_OPC_CallVariadic,
     VM_OPC_Return,
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 28eb548..c7147a8 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -1265,6 +1265,105 @@
                             : SuccessorOperands(getFalseDestOperandsMutable());
 }
 
+// Parses the body of a vm.br_table:
+//   default: ^dest(%operands : types) (`,` index: ^dest(%operands : types))*
+// Case indices must count up from 0 with no gaps as the table is dense.
+static ParseResult parseBranchTableCases(
+    OpAsmParser &parser, Block *&defaultDestination,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
+    SmallVectorImpl<Type> &defaultOperandTypes,
+    SmallVectorImpl<Block *> &caseDestinations,
+    SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
+    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
+  if (parser.parseKeyword("default") || parser.parseColon() ||
+      parser.parseSuccessor(defaultDestination))
+    return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
+                                /*allowResultNumber=*/false) ||
+        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
+      return failure();
+  }
+  while (succeeded(parser.parseOptionalComma())) {
+    auto indexLoc = parser.getCurrentLocation();
+    int64_t index = 0;
+    if (failed(parser.parseInteger(index)))
+      return failure();
+    // Each case must carry the next index in sequence; report a violation
+    // with a diagnostic instead of failing silently.
+    const int64_t expectedIndex = static_cast<int64_t>(caseDestinations.size());
+    if (index != expectedIndex) {
+      return parser.emitError(indexLoc)
+             << "expected case index " << expectedIndex << " but got " << index
+             << "; branch table cases must be dense and in order";
+    }
+    Block *destination = nullptr;
+    SmallVector<OpAsmParser::UnresolvedOperand> operands;
+    SmallVector<Type> operandTypes;
+    if (failed(parser.parseColon()) ||
+        failed(parser.parseSuccessor(destination)))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
+                                         /*allowResultNumber=*/false)) ||
+          failed(parser.parseColonTypeList(operandTypes)) ||
+          failed(parser.parseRParen()))
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.emplace_back(operands);
+    caseOperandTypes.emplace_back(operandTypes);
+  }
+  return success();
+}
+
+// Prints the body of a vm.br_table, one entry per line: the default
+// destination first and then each case prefixed with its 0-based index.
+static void printBranchTableCases(OpAsmPrinter &p, Operation *op,
+                                  Block *defaultDestination,
+                                  OperandRange defaultOperands,
+                                  TypeRange defaultOperandTypes,
+                                  SuccessorRange caseDestinations,
+                                  OperandRangeRange caseOperands,
+                                  const TypeRangeRange &caseOperandTypes) {
+  p.increaseIndent();
+  p << "  default: ";
+  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
+  int index = 0;
+  // NOTE: the bindings are deliberately named differently from the parameters
+  // so that they do not shadow them.
+  for (auto [destination, operands, operandTypes] :
+       llvm::zip_equal(caseDestinations, caseOperands, caseOperandTypes)) {
+    p << ',';
+    p.printNewline();
+    p << (index++) << ": ";
+    p.printSuccessorAndUseList(destination, operands);
+  }
+  p.decreaseIndent();
+  p.printNewline();
+}
+
+// Successor 0 is the default destination; successor i + 1 is case i.
+SuccessorOperands BranchTableOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+                                      : getCaseOperandsMutable(index - 1));
+}
+
+// Returns the statically known successor when the index operand is a constant
+// (out-of-range values select the default) or nullptr when it is not.
+Block *BranchTableOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  SuccessorRange caseDestinations = getCaseDestinations();
+  // The index is the first operand of the op.
+  if (auto valueAttr = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
+    int64_t value = valueAttr.getValue().getSExtValue();
+    if (value < 0 || value >= static_cast<int64_t>(caseDestinations.size()))
+      return getDefaultDestination();
+    return caseDestinations[value];
+  }
+  return nullptr;
+}
+
 LogicalResult verifyFailOp(Operation *op, Value statusVal) {
   APInt status;
   if (matchPattern(statusVal, m_ConstantInt(&status))) {
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.h b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.h
index 8b6cb1a..768e778 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.h
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.h
@@ -8,6 +8,7 @@
 #define IREE_COMPILER_DIALECT_VM_IR_VMOPS_H_
 
 #include <cstdint>
+#include <numeric>
 
 #include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
 #include "iree/compiler/Dialect/VM/IR/VMDialect.h"
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index 8340c19..c1f03e6 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -3941,9 +3941,9 @@
       build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
             falseDest);
     }]>,
-  OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
-    "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands),
-  [{
+    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+      "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands),
+    [{
       build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
             falseOperands);
     }]>
@@ -4005,6 +4005,69 @@
   let hasCanonicalizer = 1;
 }
 
+def VM_BranchTableOp : VM_PureOp<"br_table", [
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+    DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
+    Terminator,
+  ]> {
+  let summary = [{branch table operation}];
+  let description = [{
+    Represents a branch table instructing execution to branch to the block with
+    the specified index. If the index is out of bounds then execution will
+    branch to the default block.
+
+    ```
+    vm.br_table %index {
+      default: ^bb1(%a : i64),
+      0: ^bb2,
+      1: ^bb3(%c : i64)
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    I32:$index,
+    Variadic<VM_AnyType>:$defaultOperands,
+    VariadicOfVariadic<VM_AnyType, "case_operand_segments">:$caseOperands,
+    DenseI32ArrayAttr:$case_operand_segments
+  );
+
+  let successors = (successor
+    AnySuccessor:$defaultDestination,
+    VariadicSuccessor<AnySuccessor>:$caseDestinations
+  );
+
+  let assemblyFormat = [{
+    $index ` ` `{` `\n`
+    custom<BranchTableCases>(
+        $defaultDestination, $defaultOperands, type($defaultOperands),
+        $caseDestinations, $caseOperands, type($caseOperands))
+    `}`
+    attr-dict
+  }];
+
+  let encoding = [
+    VM_EncOpcode<VM_OPC_BranchTable>,
+    VM_EncOperand<"index", 0>,
+    VM_EncBranch<"defaultDestination", "getDefaultOperands", 0>,
+    VM_EncBranchTable<"getCaseDestinations", "getCaseOperands", 0>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index) {
+      return getCaseOperands()[index];
+    }
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index) {
+      return getCaseOperandsMutable()[index];
+    }
+  }];
+}
+
 class VM_CallBaseOp<string mnemonic, list<Trait> traits = []> :
     VM_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_ops.mlir
index d834099..083862a 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_ops.mlir
@@ -52,6 +52,44 @@
 
 // -----
 
+// CHECK-LABEL: @br_table_empty
+vm.module @my_module {
+  vm.func @br_table_empty(%arg0: i32, %arg1: i32) -> i32 {
+    //      CHECK: vm.br_table %arg0 {
+    // CHECK-NEXT:   default: ^bb1(%arg1 : i32)
+    // CHECK-NEXT: }
+    vm.br_table %arg0 {
+      default: ^bb1(%arg1 : i32)
+    }
+  ^bb1(%0 : i32):
+    vm.return %0 : i32
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @br_table
+vm.module @my_module {
+  vm.func @br_table(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+    //      CHECK: vm.br_table %arg0 {
+    // CHECK-NEXT:   default: ^bb1(%arg0 : i32),
+    // CHECK-NEXT:   0: ^bb2(%arg1 : i32),
+    // CHECK-NEXT:   1: ^bb2(%arg2 : i32)
+    // CHECK-NEXT: }
+    vm.br_table %arg0 {
+      default: ^bb1(%arg0 : i32),
+      0: ^bb2(%arg1 : i32),
+      1: ^bb2(%arg2 : i32)
+    }
+  ^bb1(%0 : i32):
+    vm.return %0 : i32
+  ^bb2(%1 : i32):
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @call_fn
 vm.module @my_module {
   vm.import private @import_fn(%arg0 : i32) -> i32
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
index 2ffe492..e648f81 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
@@ -216,6 +216,29 @@
     return success();
   }
 
+  // Encodes a branch table as a 16-bit case count followed by one branch per
+  // case, in case order. Case i is encoded with successor index
+  // |baseSuccessorIndex| + 1 + i.
+  LogicalResult encodeBranchTable(SuccessorRange caseSuccessors,
+                                  OperandRangeRange caseOperands,
+                                  int baseSuccessorIndex) override {
+    // The count is written as a uint16; reject larger tables rather than
+    // silently truncating the count and emitting a malformed table.
+    if (caseSuccessors.size() > 0xFFFFu) {
+      return currentOp_->emitOpError()
+             << "branch table has " << caseSuccessors.size()
+             << " cases but at most 65535 can be encoded";
+    }
+    if (failed(writeUint16(static_cast<uint16_t>(caseSuccessors.size()))))
+      return failure();
+    for (auto [successor, operands] :
+         llvm::zip_equal(caseSuccessors, caseOperands)) {
+      if (failed(encodeBranch(successor, operands, ++baseSuccessorIndex)))
+        return failure();
+    }
+    return success();
+  }
+
   LogicalResult encodeOperand(Value value, int ordinal) override {
     uint16_t reg =
         registerAllocation_->mapUseToRegister(value, currentOp_, ordinal)
diff --git a/runtime/src/iree/vm/bytecode/disassembler.c b/runtime/src/iree/vm/bytecode/disassembler.c
index babdbc5..039e037 100644
--- a/runtime/src/iree/vm/bytecode/disassembler.c
+++ b/runtime/src/iree/vm/bytecode/disassembler.c
@@ -31,6 +31,9 @@
 #define VM_ParseConstI8(name) \
   OP_I8(0);                   \
   ++pc;
+#define VM_ParseConstI16(name) \
+  OP_I16(0);                   \
+  pc += 2;
 #define VM_ParseConstI32(name) \
   OP_I32(0);                   \
   pc += 4;
@@ -1588,6 +1591,35 @@
       break;
     }
 
+    // Prints the op as:
+    //   vm.br_table <index> { default: ^<pc>(<remaps>), <i>: ^<pc>(<remaps>) }
+    // The closing paren of each entry is emitted together with the prefix of
+    // the next entry (or with the final " }").
+    DISASM_OP(CORE, BranchTable) {
+      uint16_t index_reg = VM_ParseOperandRegI32("index");
+      IREE_RETURN_IF_ERROR(
+          iree_string_builder_append_cstring(b, "vm.br_table "));
+      EMIT_I32_REG_NAME(index_reg);
+      EMIT_OPTIONAL_VALUE_I32(regs->i32[index_reg]);
+      int32_t default_block_pc = VM_ParseBranchTarget("default_dest");
+      const iree_vm_register_remap_list_t* default_remap_list =
+          VM_ParseBranchOperands("default_operands");
+      IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
+          b, " { default: ^%08X(", default_block_pc));
+      EMIT_REMAP_LIST(default_remap_list);
+      uint16_t table_size = VM_ParseConstI16(table_size);
+      for (uint16_t i = 0; i < table_size; ++i) {
+        int32_t case_block_pc = VM_ParseBranchTarget("case_dest");
+        const iree_vm_register_remap_list_t* case_remap_list =
+            VM_ParseBranchOperands("case_operands");
+        IREE_RETURN_IF_ERROR(iree_string_builder_append_format(
+            b, "), %u: ^%08X(", i, case_block_pc));
+        EMIT_REMAP_LIST(case_remap_list);
+      }
+      IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(b, ") }"));
+      break;
+    }
+
     DISASM_OP(CORE, Call) {
       int32_t function_ordinal = VM_ParseFuncAttr("callee");
       const iree_vm_register_list_t* src_reg_list =
diff --git a/runtime/src/iree/vm/bytecode/dispatch.c b/runtime/src/iree/vm/bytecode/dispatch.c
index 81b773b..84fad05 100644
--- a/runtime/src/iree/vm/bytecode/dispatch.c
+++ b/runtime/src/iree/vm/bytecode/dispatch.c
@@ -1628,6 +1628,41 @@
       }
     });
 
+    // Branch table: branches to table[index] when the index is within
+    // [0, table_size) and to the default target otherwise. The default branch
+    // is decoded first, followed by the 16-bit table size and then one
+    // variable-length branch per table entry.
+    DISPATCH_OP(CORE, BranchTable, {
+      int32_t index = VM_DecOperandRegI32("index");
+      int32_t default_block_pc = VM_DecBranchTarget("default_dest");
+      const iree_vm_register_remap_list_t* default_remap_list =
+          VM_DecBranchOperands("default_operands");
+      uint16_t table_size = VM_DecConstI16("table_size");
+      if (index < 0 || index >= table_size) {
+        // Out-of-bounds index; jump to default block.
+        pc = default_block_pc + IREE_VM_BLOCK_MARKER_SIZE;  // skip block marker
+        if (IREE_UNLIKELY(default_remap_list->size > 0)) {
+          iree_vm_bytecode_dispatch_remap_branch_registers(regs_i32, regs_ref,
+                                                           default_remap_list);
+        }
+      } else {
+        // In-bounds index; decode until we hit the case and branch as the
+        // cases are all variable length and we can't directly index.
+        for (uint16_t i = 0; i < index; ++i) {
+          VM_DecBranchTarget("case_dest");
+          VM_DecBranchOperands("case_operands");
+        }
+        int32_t case_block_pc = VM_DecBranchTarget("case_dest");
+        const iree_vm_register_remap_list_t* case_remap_list =
+            VM_DecBranchOperands("case_operands");
+        pc = case_block_pc + IREE_VM_BLOCK_MARKER_SIZE;  // skip block marker
+        if (IREE_UNLIKELY(case_remap_list->size > 0)) {
+          iree_vm_bytecode_dispatch_remap_branch_registers(regs_i32, regs_ref,
+                                                           case_remap_list);
+        }
+      }
+    });
+
     DISPATCH_OP(CORE, Call, {
       int32_t function_ordinal = VM_DecFuncAttr("callee");
       const iree_vm_register_list_t* src_reg_list =
diff --git a/runtime/src/iree/vm/bytecode/dispatch_util.h b/runtime/src/iree/vm/bytecode/dispatch_util.h
index d6d377b..cf5a092 100644
--- a/runtime/src/iree/vm/bytecode/dispatch_util.h
+++ b/runtime/src/iree/vm/bytecode/dispatch_util.h
@@ -135,6 +135,9 @@
 #define VM_DecConstI8(name) \
   OP_I8(0);                 \
   ++pc;
+#define VM_DecConstI16(name) \
+  OP_I16(0);                 \
+  pc += 2;
 #define VM_DecConstI32(name) \
   OP_I32(0);                 \
   pc += 4;
diff --git a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
index 7684b2e..5359b22 100644
--- a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
+++ b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h
@@ -138,7 +138,7 @@
   IREE_VM_OP_CORE_MaxI64S = 0x80,
   IREE_VM_OP_CORE_MaxI64U = 0x81,
   IREE_VM_OP_CORE_CastAnyRef = 0x82,
-  IREE_VM_OP_CORE_RSV_0x83,
+  IREE_VM_OP_CORE_BranchTable = 0x83,
   IREE_VM_OP_CORE_RSV_0x84,
   IREE_VM_OP_CORE_RSV_0x85,
   IREE_VM_OP_CORE_RSV_0x86,
@@ -397,7 +397,7 @@
     OPC(0x80, MaxI64S) \
     OPC(0x81, MaxI64U) \
     OPC(0x82, CastAnyRef) \
-    RSV(0x83) \
+    OPC(0x83, BranchTable) \
     RSV(0x84) \
     RSV(0x85) \
     RSV(0x86) \
diff --git a/runtime/src/iree/vm/bytecode/verifier.c b/runtime/src/iree/vm/bytecode/verifier.c
index 2b6055d..3122a70 100644
--- a/runtime/src/iree/vm/bytecode/verifier.c
+++ b/runtime/src/iree/vm/bytecode/verifier.c
@@ -516,6 +516,12 @@
   uint8_t name = OP_I8(0);                 \
   (void)(name);                            \
   ++pc;
+// Range-checks and reads a 16-bit constant at pc into |name|, then advances.
+#define VM_VerifyConstI16(name)            \
+  IREE_VM_VERIFY_PC_RANGE(pc + 2, max_pc); \
+  uint16_t name = OP_I16(0);               \
+  (void)(name);                            \
+  pc += 2;
 #define VM_VerifyConstI32(name)            \
   IREE_VM_VERIFY_PC_RANGE(pc + 4, max_pc); \
   uint32_t name = OP_I32(0);               \
@@ -1548,6 +1553,20 @@
       verify_state->in_block = 0;  // terminator
     });
 
+    // Branch table: index register, default branch, 16-bit table size, then
+    // one branch (target + operands) per table entry, each verified in turn.
+    VERIFY_OP(CORE, BranchTable, {
+      VM_VerifyOperandRegI32(index);
+      VM_VerifyBranchTarget(default_dest_pc);
+      VM_VerifyBranchOperands(default_operands);
+      VM_VerifyConstI16(table_size);
+      for (uint16_t i = 0; i < table_size; ++i) {
+        VM_VerifyBranchTarget(case_dest_pc);
+        VM_VerifyBranchOperands(case_operands);
+      }
+      verify_state->in_block = 0;  // terminator
+    });
+
     VERIFY_OP(CORE, Call, {
       VM_VerifyFuncAttr(callee_ordinal);
       VM_VerifyVariadicOperandsAny(operands);
diff --git a/runtime/src/iree/vm/test/control_flow_ops.mlir b/runtime/src/iree/vm/test/control_flow_ops.mlir
index 3b857be..409c445 100644
--- a/runtime/src/iree/vm/test/control_flow_ops.mlir
+++ b/runtime/src/iree/vm/test/control_flow_ops.mlir
@@ -124,6 +124,43 @@
     vm.return
   }
 
+  vm.export @test_br_table_inbounds
+  vm.func private @test_br_table_inbounds() {
+    %c0 = vm.const.i32 0
+    %c1 = vm.const.i32 1
+    %c2 = vm.const.i32 2
+    %c1dno = util.optimization_barrier %c1 : i32
+    vm.br_table %c1dno {
+      default: ^bb1(%c2 : i32),
+      0: ^bb2(%c0 : i32),
+      1: ^bb2(%c1 : i32)
+    }
+  ^bb1(%arg0: i32):
+    vm.fail %arg0, "unreachable"
+  ^bb2(%arg1: i32):
+    vm.check.eq %arg1, %c1, "expected table[1] branch" : i32
+    vm.return
+  }
+
+  vm.export @test_br_table_outofbounds
+  vm.func private @test_br_table_outofbounds() {
+    %c0 = vm.const.i32 0
+    %c1 = vm.const.i32 1
+    %c2 = vm.const.i32 2
+    %c-1 = vm.const.i32 -1
+    %c-1dno = util.optimization_barrier %c-1 : i32
+    vm.br_table %c-1dno {
+      default: ^bb1(%c0 : i32),
+      0: ^bb2(%c1 : i32),
+      1: ^bb2(%c2 : i32)
+    }
+  ^bb1(%arg0: i32):
+    vm.check.eq %arg0, %c0, "expected default branch" : i32
+    vm.return
+  ^bb2(%arg1: i32):
+    vm.fail %arg1, "unreachable"
+  }
+
   vm.rodata private @buffer_a dense<[1]> : tensor<1xi8>
   vm.rodata private @buffer_b dense<[2]> : tensor<1xi8>
   vm.rodata private @buffer_c dense<[3]> : tensor<1xi8>