Merge pull request #7265 from MaheshRavishankar:main-to-google

PiperOrigin-RevId: 401101859
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index ae39185..d49cf01 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -690,6 +690,30 @@
       .def_property_readonly(
           "stashed_flatbuffer_blob",
           [](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
+      .def_property_readonly(
+          "function_names",
+          [](VmModule& self) {
+            py::list names;
+            iree_vm_module_signature_t sig =
+                iree_vm_module_signature(self.raw_ptr());
+            for (size_t ordinal = 0; ordinal < sig.export_function_count;
+                 ++ordinal) {
+              iree_vm_function_t f;
+              iree_string_view_t linkage_name;
+              auto status = iree_vm_module_lookup_function_by_ordinal(
+                  self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f,
+                  &linkage_name);
+              if (iree_status_is_not_found(status)) {
+                iree_status_ignore(status);
+                break;
+              }
+              CheckApiStatus(status, "Error enumerating module");
+              iree_string_view_t fname = iree_vm_function_name(&f);
+              py::str name(fname.data, fname.size);
+              names.append(name);
+            }
+            return names;
+          })
       .def("__repr__", [](VmModule& self) {
         std::string repr("<VmModule ");
         iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
diff --git a/build_tools/benchmarks/run_benchmarks_on_android.py b/build_tools/benchmarks/run_benchmarks_on_android.py
index fa51dae..1c13a9c 100755
--- a/build_tools/benchmarks/run_benchmarks_on_android.py
+++ b/build_tools/benchmarks/run_benchmarks_on_android.py
@@ -216,6 +216,7 @@
 def filter_benchmarks_for_category(benchmark_category_dir: str,
                                    cpu_target_arch: str,
                                    gpu_target_arch: str,
+                                   driver_filter: Optional[str],
                                    verbose: bool = False) -> Sequence[str]:
   """Filters benchmarks in a specific category for the given device.
 
@@ -223,6 +224,7 @@
   - benchmark_category_dir: the directory to a specific benchmark category.
   - cpu_target_arch: CPU target architecture.
   - gpu_target_arch: GPU target architecture.
+  - driver_filter: only run benchmarks for the given driver if not None.
 
   Returns:
   - A list containing all matched benchmark cases' directories.
@@ -242,10 +244,16 @@
       continue
 
     iree_driver, target_arch, bench_mode = segments
+    iree_driver = iree_driver[len("iree-"):].lower()
     target_arch = target_arch.lower()
-    # We can choose this benchmark if it matches the CPU/GPU architecture.
-    should_choose = (target_arch == cpu_target_arch or
-                     target_arch == gpu_target_arch)
+
+    # We can choose this benchmark if it matches the driver and CPU/GPU
+    # architecture.
+    matched_driver = (
+        driver_filter is None or iree_driver == driver_filter.lower())
+    matched_arch = (
+        target_arch == cpu_target_arch or target_arch == gpu_target_arch)
+    should_choose = matched_driver and matched_arch
     if should_choose:
       matched_benchmarks.append(root)
 
@@ -373,6 +381,7 @@
 def filter_and_run_benchmarks(
     device_info: AndroidDeviceInfo,
     root_build_dir: str,
+    driver_filter: Optional[str],
     normal_benchmark_tool: str,
     traced_benchmark_tool: Optional[str],
     trace_capture_tool: Optional[str],
@@ -382,6 +391,7 @@
   Args:
   - device_info: an AndroidDeviceInfo object.
   - root_build_dir: the root build directory.
+  - driver_filter: only run benchmarks for the given driver if not None.
   - normal_benchmark_tool: the path to the normal benchmark tool.
   - traced_benchmark_tool: the path to the tracing-enabled benchmark tool.
   - trace_capture_tool: the path to the tool for collecting captured traces.
@@ -400,6 +410,7 @@
         benchmark_category_dir=benchmark_category_dir,
         cpu_target_arch=cpu_target_arch,
         gpu_target_arch=gpu_target_arch,
+        driver_filter=driver_filter,
         verbose=verbose)
     run_results = run_benchmarks_for_category(
         device_info=device_info,
@@ -454,6 +465,11 @@
                       type=check_exe_path,
                       default=None,
                       help="Path to the tool for collecting captured traces")
+  parser.add_argument(
+      "--driver",
+      type=str,
+      default=None,
+      help="Only run benchmarks for a specific driver, e.g., 'vulkan'")
   parser.add_argument("-o",
                       dest="output",
                       default=None,
@@ -464,9 +480,8 @@
   parser.add_argument(
       "--no-clean",
       action="store_true",
-      help=
-      "Do not clean up the temporary directory used for benchmarking on the Android device"
-  )
+      help="Do not clean up the temporary directory used for "
+      "benchmarking on the Android device")
   parser.add_argument("--verbose",
                       action="store_true",
                       help="Print internal information during execution")
@@ -499,10 +514,17 @@
           (args.trace_capture_tool is not None):
     execute_cmd_and_get_output(["adb", "forward", "tcp:8086", "tcp:8086"])
 
+    args.traced_benchmark_tool = os.path.realpath(args.traced_benchmark_tool)
+    args.trace_capture_tool = os.path.realpath(args.trace_capture_tool)
+
   results, captures = filter_and_run_benchmarks(
-      device_info, args.build_dir, os.path.realpath(args.normal_benchmark_tool),
-      os.path.realpath(args.traced_benchmark_tool),
-      os.path.realpath(args.trace_capture_tool), args.verbose)
+      device_info=device_info,
+      root_build_dir=args.build_dir,
+      driver_filter=args.driver,
+      normal_benchmark_tool=os.path.realpath(args.normal_benchmark_tool),
+      traced_benchmark_tool=args.traced_benchmark_tool,
+      trace_capture_tool=args.trace_capture_tool,
+      verbose=args.verbose)
 
   if args.output is not None:
     with open(args.output, "w") as f:
diff --git a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 8297bdd..40f9e05 100644
--- a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -543,7 +543,11 @@
         cast<IREE::HAL::InterfaceBindingSubspanOp>(op).queryBindingOp();
     IREE::HAL::InterfaceBindingSubspanOpAdaptor newOperands(
         operands, op->getAttrDictionary());
-    MemRefType memRefType = op->getResult(0).getType().cast<MemRefType>();
+    MemRefType memRefType = op->getResult(0).getType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return rewriter.notifyMatchFailure(
+          op,
+          "failed to convert interface.binding.subspan result to memref type");
     auto memRefDesc = abi.loadBinding(
         op->getLoc(), interfaceBindingOp.binding().getZExtValue(),
         newOperands.byte_offset(), memRefType, newOperands.dynamic_dims(),
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
index a95d79c..aa1a3c6 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndVectorizeLinalgTensorOps.cpp
@@ -158,13 +158,15 @@
     }
   }
 
-  // TODO: This should be a folding of Add into Contract in core but while
-  // they live in different dialects, it is not possible without unnatural
-  // dependencies.
-  funcOp.walk([&](Operation *op) {
-    if (auto contract = canonicalizeContractionAdd(op))
-      op->replaceAllUsesWith(contract);
-  });
+  {
+    // Fold consumer add ops into the contraction op itself.
+    RewritePatternSet canonicalizationPatterns(context);
+    vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
+                                                       context);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(canonicalizationPatterns));
+  }
+
   // Apply vector specific operation lowering.
   {
     vector::VectorTransformsOptions vectorTransformsOptions =
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
index b9b2a97..e0df89c 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorization.cpp
@@ -238,13 +238,14 @@
                                        std::move(vectorizationPatterns));
   }
 
-  // TODO: This should be a folding of Add into Contract in core but while they
-  // live in different dialects, it is not possible without unnatural
-  // dependencies.
-  funcOp.walk([&](Operation *op) {
-    if (auto contract = canonicalizeContractionAdd(op))
-      op->replaceAllUsesWith(contract);
-  });
+  {
+    // Fold consumer add ops into the contraction op itself.
+    RewritePatternSet canonicalizationPatterns(context);
+    vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
+                                                       context);
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(canonicalizationPatterns));
+  }
 
   if (enableVectorContractToAarch64Asm) {
     RewritePatternSet vectorToAArch64AsmPatterns(context);
diff --git a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
index 23b8590..ab5b9e5 100644
--- a/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorization.cpp
@@ -86,13 +86,13 @@
       populateVectorizationPatterns(vectorizationPatterns);
       (void)applyPatternsAndFoldGreedily(funcOp,
                                          std::move(vectorizationPatterns));
-      // TODO: This should be a folding of Add into Contract in core but while
-      // they live in different dialects, it is not possible without unnatural
-      // dependencies.
-      funcOp.walk([&](Operation *op) {
-        if (auto contract = canonicalizeContractionAdd(op))
-          op->replaceAllUsesWith(contract);
-      });
+
+      // Fold consumer add ops into the contraction op itself.
+      RewritePatternSet canonicalizationPatterns(context);
+      vector::ContractionOp::getCanonicalizationPatterns(
+          canonicalizationPatterns, context);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canonicalizationPatterns));
 
       RewritePatternSet vectorUnrollPatterns(context);
       populateVectorUnrollPatterns(vectorUnrollPatterns);
diff --git a/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 99243bf..514d916 100644
--- a/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -107,11 +107,13 @@
   pm.addPass(createTensorConstantBufferizePass());
   pm.addPass(createFoldTensorExtractOpPass());
 
+  pm.addNestedPass<FuncOp>(createLLVMGPUVectorLoweringPass());
+
   // SCF -> STD
   pm.addNestedPass<FuncOp>(createLowerToCFGPass());
   pm.addNestedPass<FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<FuncOp>(createCSEPass());
-  pm.addNestedPass<FuncOp>(createLLVMGPUVectorLoweringPass());
+
   pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
   pm.addPass(createLowerAffinePass());
 
diff --git a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
index 525b0cc..8557c11 100644
--- a/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
+++ b/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -226,6 +226,13 @@
                                                   vectorizationPatterns);
       (void)applyPatternsAndFoldGreedily(funcOp,
                                          std::move(vectorizationPatterns));
+
+      // Fold consumer add ops into the contraction op itself.
+      RewritePatternSet canonicalizationPatterns(context);
+      vector::ContractionOp::getCanonicalizationPatterns(
+          canonicalizationPatterns, context);
+      (void)applyPatternsAndFoldGreedily(funcOp,
+                                         std::move(canonicalizationPatterns));
     }
 
     LLVM_DEBUG({
@@ -234,13 +241,6 @@
       llvm::dbgs() << "\n\n";
     });
 
-    // TODO: This should be a folding of Add into Contract in core but while
-    // they live in different dialects, it is not possible without unnatural
-    // dependencies.
-    funcOp.walk([&](Operation *op) {
-      if (auto contract = canonicalizeContractionAdd(op))
-        op->replaceAllUsesWith(contract);
-    });
 
     {
       RewritePatternSet vectorUnrollPatterns(funcOp.getContext());
diff --git a/iree/compiler/Codegen/Transforms/Transforms.cpp b/iree/compiler/Codegen/Transforms/Transforms.cpp
index 5a4778f..b4b6056 100644
--- a/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -69,50 +69,5 @@
   return success();
 }
 
-/// Return a fused vector::ContractionOp which represents a patterns such as:
-///
-/// ```mlir
-///    %c0 = vector.constant 0: ...
-///    %c = vector.contract %a, %b, %c0: ...
-///    %e = add %c, %d: ...
-/// ```
-///
-/// by:
-///
-/// ```mlir
-///    %e = vector.contract %a, %b, %d: ...
-/// ```
-///
-/// Return null if the canonicalization does not apply.
-// TODO: This should be a folding of Add into Contract in core but while they
-// live in different dialects, it is not possible without unnatural
-// dependencies.
-vector::ContractionOp canonicalizeContractionAdd(Operation *op) {
-  if (!isa<AddIOp, AddFOp>(op)) return nullptr;
-
-  OpBuilder builder(op);
-  auto canonicalize = [](OpBuilder &b, Value maybeContraction,
-                         Value otherOperand) -> vector::ContractionOp {
-    vector::ContractionOp contractionOp =
-        dyn_cast_or_null<vector::ContractionOp>(
-            maybeContraction.getDefiningOp());
-    if (!contractionOp) return nullptr;
-    if (auto maybeZero =
-            dyn_cast_or_null<ConstantOp>(contractionOp.acc().getDefiningOp())) {
-      if (maybeZero.value() == b.getZeroAttr(contractionOp.acc().getType())) {
-        BlockAndValueMapping bvm;
-        bvm.map(contractionOp.acc(), otherOperand);
-        return cast<vector::ContractionOp>(b.clone(*contractionOp, bvm));
-      }
-    }
-    return nullptr;
-  };
-
-  Value a = op->getOperand(0), b = op->getOperand(1);
-  vector::ContractionOp contract = canonicalize(builder, a, b);
-  contract = contract ? contract : canonicalize(builder, b, a);
-  return contract;
-}
-
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Codegen/Transforms/Transforms.h b/iree/compiler/Codegen/Transforms/Transforms.h
index 1b30ff2..e8270de 100644
--- a/iree/compiler/Codegen/Transforms/Transforms.h
+++ b/iree/compiler/Codegen/Transforms/Transforms.h
@@ -32,26 +32,6 @@
     OpBuilder &builder, FuncOp funcOp,
     WorkgroupCountRegionBuilder regionBuilder);
 
-/// Return a fused vector::ContractionOp which represents a patterns such as:
-///
-/// ```mlir
-///    %c0 = vector.constant 0: ...
-///    %c = vector.contract %a, %b, %c0: ...
-///    %e = add %c, %d: ...
-/// ```
-///
-/// by:
-///
-/// ```mlir
-///    %e = vector.contract %a, %b, %d: ...
-/// ```
-///
-/// Return null if the canonicalization does not apply.
-// TODO: This should be a folding of Add into Contract in core but while they
-// live in different dialects, it is not possible without unnatural
-// dependencies.
-vector::ContractionOp canonicalizeContractionAdd(Operation *op);
-
 /// Insert patterns to perform folding of AffineMinOp by matching the pattern
 /// generated by tile and distribute. Try to fold a affine.min op by matching
 /// the following form:
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 30c7932..b375029 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -239,11 +239,31 @@
 static std::pair<IREE::Flow::DispatchWorkgroupsOp, Operation *>
 buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc,
                                         ArrayRef<Value> count, Operation *op) {
+  SmallVector<Value> operands, operandDims;
+  SmallVector<int64_t> tiedOperands;
+  if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
+    // Handle tensor.insert_slice in a special manner. This op is actually two
+    // steps:
+    // 1) Copy over the dest tensor to the result,
+    // 2) Update the overwritten part of the result with the source.
+    // To actually make this work, the dispatch region needs the `dest` and
+    // result to be tied operands. This is somehow special. It might fall out
+    // naturally, but not sure how. For now, just do it by construction.
+    operands.push_back(insertSliceOp.dest());
+    ReifiedRankedShapedTypeDims resultShapes;
+    (void)insertSliceOp.reifyResultShapes(rewriter, resultShapes);
+    auto destType = insertSliceOp.dest().getType().cast<ShapedType>();
+    for (auto shape : enumerate(destType.getShape())) {
+      if (shape.value() != ShapedType::kDynamicSize) {
+        continue;
+      }
+      operandDims.push_back(resultShapes[0][shape.index()]);
+    }
+    tiedOperands.push_back(0);
+  }
   auto dispatchOp = rewriter.create<IREE::Flow::DispatchWorkgroupsOp>(
-      loc, count, op->getResultTypes(), /*result_dims=*/ValueRange{},
-      /*operands=*/ValueRange{},
-      /*operand_dims=*/ValueRange{},
-      /*tied_operands=*/ArrayRef<int64_t>{});
+      loc, count, op->getResultTypes(), /*result_dims=*/ValueRange{}, operands,
+      operandDims, tiedOperands);
   Region &region = dispatchOp.body();
   Block *block = &region.front();
   Operation *clonedOp;
@@ -423,15 +443,13 @@
   return orderedOps;
 }
 
-/// Computes the values that will be eventually be used within the dispatch
+/// Computes the values that will eventually be used within the dispatch
 /// workgroup op but defined outside the op after all clonable operations are
-/// cloned into the region. Returns (by reference) the clonable operations too,
-/// in order in which they can be cloned within the region to satisfy use-def
-/// relationships between them.
+/// cloned into the region.
 static void getUsedValuesDefinedAboveAfterCloningOps(
-    IREE::Flow::DispatchWorkgroupsOp dispatchOp,
-    llvm::SetVector<Value> &valuesDefinedAbove,
-    llvm::SmallVector<Operation *> &clonedOps) {
+    OpBuilder &builder, IREE::Flow::DispatchWorkgroupsOp dispatchOp,
+    llvm::SetVector<Value> &valuesDefinedAbove) {
+  llvm::SmallVector<Operation *> clonedOps;
   llvm::SetVector<Value> visited;
   SmallVector<Value, 4> worklist;
   worklist.assign(valuesDefinedAbove.begin(), valuesDefinedAbove.end());
@@ -452,6 +470,21 @@
   // The cloned operations form a DAG. Return the cloned operations so the
   // leaves come first, and can be cloned in-order into the dispatch region.
   clonedOps = orderOperations(clonedOps);
+
+  for (auto clonedOp : reverse(clonedOps)) {
+    Operation *clone = builder.clone(*clonedOp);
+    for (auto result : llvm::enumerate(clonedOp->getResults())) {
+      result.value().replaceUsesWithIf(
+          clone->getResult(result.index()), [&](OpOperand &use) {
+            return use.getOwner()
+                       ->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ==
+                   dispatchOp;
+          });
+      valuesDefinedAbove.remove(result.value());
+    }
+    builder.setInsertionPoint(clone);
+  }
+
   // Reverse the values. This is not for correctness, but more for readability
   // of the IR.
   llvm::SetVector<Value> reversedValues;
@@ -459,86 +492,115 @@
   std::swap(reversedValues, valuesDefinedAbove);
 }
 
+/// Returns the tied operand for the given `resultArg`. Returns nullptr if error
+/// or not found.
+static BlockArgument getTiedOperandBlockArgument(BlockArgument resultArg) {
+  auto resultArgType =
+      resultArg.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+  if (!resultArgType ||
+      resultArgType.getAccess() != IREE::Flow::TensorAccess::WriteOnly) {
+    return nullptr;
+  }
+  // Each output block argument should just have one use.
+  if (!resultArg.hasOneUse()) return nullptr;
+
+  // And that's a flow.dispatch.output.store op.
+  auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(
+      (*resultArg.getUses().begin()).getOwner());
+  if (!storeOp) return nullptr;
+
+  Operation *tieOp = storeOp.value().getDefiningOp();
+  if (!tieOp) return nullptr;
+
+  // TODO(antiagainst): use TiedOpInterface here instead of hardcoding ops
+  // when it's available in MLIR core in some form.
+  BlockArgument tiedArg =
+      TypeSwitch<Operation *, BlockArgument>(tieOp)
+          .Case<tensor::InsertSliceOp>([&](tensor::InsertSliceOp insertOp)
+                                           -> BlockArgument {
+            auto loadOp =
+                insertOp.dest()
+                    .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+            if (!loadOp) return nullptr;
+            return loadOp.source().dyn_cast<BlockArgument>();
+          })
+          .Case<IREE::Flow::DispatchTensorLoadOp>(
+              [&](auto loadOp) -> BlockArgument {
+                // Check that there is a single use and that the source is a
+                // block argument. Single use can potentially be relaxed.
+                auto loadArg =
+                    loadOp.source().template dyn_cast<BlockArgument>();
+                if (!loadArg || !loadArg.hasOneUse()) {
+                  return nullptr;
+                }
+                return loadArg;
+              })
+          .Case<linalg::LinalgOp,
+                linalg_ext::LinalgExtOp>([&](auto linalgLikeOp)
+                                             -> BlockArgument {
+            unsigned resultIndex =
+                storeOp.value().cast<OpResult>().getResultNumber();
+            auto loadOp =
+                linalgLikeOp.getOutputTensorOperands()[resultIndex]
+                    ->get()
+                    .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+            if (!loadOp) return nullptr;
+            return loadOp.source().template dyn_cast<BlockArgument>();
+          })
+          .Default([&](Operation *) -> BlockArgument { return nullptr; });
+
+  if (!tiedArg) {
+    return nullptr;
+  }
+
+  // Check that the type of the tied argument candidate and the type of the
+  // output match and that the tied argument is readonly.
+  auto type = tiedArg.getType().dyn_cast<IREE::Flow::DispatchTensorType>();
+  if (!type || type.getAccess() != IREE::Flow::TensorAccess::ReadOnly ||
+      type.getElementType() != resultArgType.getElementType() ||
+      llvm::any_of(llvm::zip(type.getShape(), resultArgType.getShape()),
+                   [](std::tuple<int64_t, int64_t> sizes) {
+                     return std::get<0>(sizes) !=
+                                IREE::Flow::DispatchTensorType::kDynamicSize &&
+                            std::get<1>(sizes) !=
+                                IREE::Flow::DispatchTensorType::kDynamicSize &&
+                            std::get<0>(sizes) != std::get<1>(sizes);
+                   })) {
+    return nullptr;
+  }
+  return tiedArg;
+}
+
 /// Modifies `dispatchOp` to attach operand-result tie information when
 /// possible.
 static void tryToTieOperandsAndResults(
     IREE::Flow::DispatchWorkgroupsOp dispatchOp) {
   Block *block = dispatchOp.getBody(0);
-  unsigned numResults = dispatchOp.getNumResults();
-  auto inputs = block->getArguments().drop_back(numResults);
-  auto outputs = block->getArguments().take_back(numResults);
+  unsigned numOperands = dispatchOp.getODSOperandIndexAndLength(1).second;
 
-  // Returns the tied operand for the given `resultArg`. Returns nullptr
-  // if error or not found.
-  auto getTiedOperandBlockArgument =
-      [](BlockArgument resultArg) -> BlockArgument {
-    // Each output block argument should just have one use.
-    if (!llvm::hasSingleElement(resultArg.getUses())) return nullptr;
-
-    // And that's a flow.dispatch.output.store op.
-    auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(
-        (*resultArg.getUses().begin()).getOwner());
-    if (!storeOp) return nullptr;
-
-    Operation *tieOp = storeOp.value().getDefiningOp();
-    if (!tieOp) return nullptr;
-
-    // TODO(antiagainst): use TiedOpInterface here instead of hardcoding ops
-    // when it's available in MLIR core in some form.
-    BlockArgument tiedArg =
-        TypeSwitch<Operation *, BlockArgument>(tieOp)
-            .Case<tensor::InsertSliceOp>(
-                [&](tensor::InsertSliceOp insertOp) -> BlockArgument {
-                  auto loadOp = insertOp.dest()
-                                    .template getDefiningOp<
-                                        IREE::Flow::DispatchTensorLoadOp>();
-                  if (!loadOp) return nullptr;
-                  return loadOp.source().dyn_cast<BlockArgument>();
-                })
-            .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>(
-                [&](auto linalgLikeOp) -> BlockArgument {
-                  unsigned resultIndex =
-                      storeOp.value().cast<OpResult>().getResultNumber();
-                  auto loadOp =
-                      linalgLikeOp.getOutputTensorOperands()[resultIndex]
-                          ->get()
-                          .template getDefiningOp<
-                              IREE::Flow::DispatchTensorLoadOp>();
-                  if (!loadOp) return nullptr;
-                  return loadOp.source().template dyn_cast<BlockArgument>();
-                })
-            .Default([&](Operation *) -> BlockArgument { return nullptr; });
-
-    return tiedArg;
-  };
-
-  SmallVector<BlockArgument, 4> tiedOperands;
-  tiedOperands.reserve(numResults);
-
-  // Collect all result argument's tied operand arguments.
-  for (BlockArgument &arg : outputs) {
-    tiedOperands.push_back(getTiedOperandBlockArgument(arg));
-  }
-
+  SmallVector<unsigned> eraseArguments;
   // Go over each result to tie operand when possible, by:
   // 1. Update the tied operand argument to take readwrite tensors.
   // 2. Erase the result argument.
   // 3. Attach the tie information to the DispatchWorkgroupsOp.
-  for (int i = outputs.size() - 1; i >= 0; --i) {
-    BlockArgument inputArg = tiedOperands[i];
-    if (!inputArg) continue;
-
-    auto oldType = inputArg.getType().cast<IREE::Flow::DispatchTensorType>();
-    inputArg.setType(IREE::Flow::DispatchTensorType::get(
+  for (auto result : llvm::enumerate(dispatchOp.getResults())) {
+    if (dispatchOp.getTiedResultOperand(result.value())) continue;
+    BlockArgument outputArgument =
+        block->getArgument(numOperands + result.index());
+    BlockArgument tiedOperandArgument =
+        getTiedOperandBlockArgument(outputArgument);
+    if (!tiedOperandArgument) continue;
+    auto oldType =
+        tiedOperandArgument.getType().cast<IREE::Flow::DispatchTensorType>();
+    tiedOperandArgument.setType(IREE::Flow::DispatchTensorType::get(
         IREE::Flow::TensorAccess::ReadWrite, oldType.getShape(),
         oldType.getElementType()));
-
-    BlockArgument outputArg = block->getArgument(inputs.size() + i);
-    outputArg.replaceAllUsesWith(inputArg);
-    block->eraseArgument(inputs.size() + i);
-
-    dispatchOp.setTiedResultOperandIndex(i, inputArg.getArgNumber());
+    outputArgument.replaceAllUsesWith(tiedOperandArgument);
+    eraseArguments.push_back(outputArgument.getArgNumber());
+    dispatchOp.setTiedResultOperandIndex(result.index(),
+                                         tiedOperandArgument.getArgNumber());
   }
+  block->eraseArguments(eraseArguments);
 }
 
 static void replaceAllUsesWithinDispatchOp(
@@ -563,78 +625,92 @@
   Location loc = dispatchOp.getLoc();
   Region &region = dispatchOp.body();
   Block &block = region.front();
-  unsigned numOldBBArgs = block.getNumArguments();
   OpBuilder b = OpBuilder::atBlockBegin(&block);
 
   llvm::SetVector<Value> valuesDefinedAbove;
-  llvm::SmallVector<Operation *> clonedOps;
   mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
   if (valuesDefinedAbove.empty()) return success();
 
-  getUsedValuesDefinedAboveAfterCloningOps(dispatchOp, valuesDefinedAbove,
-                                           clonedOps);
+  getUsedValuesDefinedAboveAfterCloningOps(b, dispatchOp, valuesDefinedAbove);
+  b.setInsertionPointToStart(&block);
 
-  BlockAndValueMapping map;
-  SmallVector<Value> toReplaceWithinRegion;
-  // Replace valuesDefinedAbove by new BB args (including the op's operands).
-  for (Value operand : valuesDefinedAbove) {
-    if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) {
-      block.addArgument(IREE::Flow::DispatchTensorType::get(
-          TensorAccess::ReadOnly, rt.getShape(), rt.getElementType()));
-    } else {
-      block.addArgument(operand.getType());
+  // Build a map from current operands to arguments.
+  std::pair<unsigned, unsigned> operandsIndexAndLength =
+      dispatchOp.getODSOperandIndexAndLength(1);
+  std::pair<unsigned, unsigned> operandDimsIndexAndLength =
+      dispatchOp.getODSOperandIndexAndLength(2);
+  llvm::DenseMap<Value, BlockArgument> operandToBBArg;
+  for (auto operand : llvm::enumerate(dispatchOp.operands())) {
+    operandToBBArg[operand.value()] = block.getArgument(operand.index());
+  }
+
+  // Of the values defined above and used in the region, add values that are not
+  // operands to the region already.
+  unsigned numOperands = operandsIndexAndLength.second;
+  unsigned numOperandDims = operandDimsIndexAndLength.second;
+  for (auto value : valuesDefinedAbove) {
+    BlockArgument bbArg = operandToBBArg.lookup(value);
+    auto tensorType = value.getType().dyn_cast<RankedTensorType>();
+    if (!bbArg) {
+      // Create a new basic block argument for this value.
+      Type bbArgType = value.getType();
+      if (tensorType) {
+        bbArgType = IREE::Flow::DispatchTensorType::get(
+            TensorAccess::ReadOnly, tensorType.getShape(),
+            tensorType.getElementType());
+      }
+      bbArg = block.insertArgument(numOperands, bbArgType, value.getLoc());
     }
 
-    Value bbArg = block.getArguments().back();
     Value repl = bbArg;
     if (bbArg.getType().isa<IREE::Flow::DispatchTensorType>()) {
+      // For arguments of type flow.dispatch.tensor, create a
+      // flow.dispatch.tensor.load to get the replacement values.
       repl = b.create<IREE::Flow::DispatchTensorLoadOp>(
-          loc, operand.getType().cast<RankedTensorType>(), bbArg);
+          loc, value.getType().cast<RankedTensorType>(), bbArg);
     }
-    map.map(operand, repl);
-    toReplaceWithinRegion.push_back(operand);
-  }
 
-  // The only existing arguments are for the outputs. Just need to add a new
-  // argument for the outputs and remap the value to use the new argument.
-  for (unsigned argNum : llvm::seq<unsigned>(0, numOldBBArgs)) {
-    BlockArgument arg = block.getArgument(argNum);
-    assert(arg.getType().isa<IREE::Flow::DispatchTensorType>());
-    arg.replaceAllUsesWith(block.addArgument(arg.getType()));
-  }
-  // Drop old BB args.
-  block.eraseArguments(
-      llvm::to_vector<4>(llvm::seq<unsigned>(0, numOldBBArgs)));
+    value.replaceUsesWithIf(repl, [&](OpOperand &use) {
+      return use.getOwner()
+                 ->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ==
+             dispatchOp;
+    });
 
-  // Clone the marked operations.
-  for (Operation *op : clonedOps) {
-    b.clone(*op, map);
-    toReplaceWithinRegion.append(op->result_begin(), op->result_end());
-  }
-
-  // Make the region isolated from above.
-  for (auto value : toReplaceWithinRegion) {
-    replaceAllUsesWithinDispatchOp(dispatchOp, value, map.lookup(value));
-  }
-
-  // Gather the dynamic dimensions for all operands.
-  SmallVector<Value, 4> operandDynamicDims;
-  OpBuilder builder(dispatchOp);
-  for (Value operand : valuesDefinedAbove) {
-    if (auto rt = operand.getType().dyn_cast<RankedTensorType>()) {
-      for (unsigned i = 0; i < rt.getRank(); ++i) {
-        if (!rt.isDynamicDim(i)) continue;
-        auto dim = builder.createOrFold<tensor::DimOp>(dispatchOp.getLoc(),
-                                                       operand, i);
-        operandDynamicDims.push_back(dim);
+    // Insert the operand if this is not already one. Also need to account for
+    // dynamic dim values for the operands.
+    if (!operandToBBArg.count(value)) {
+      dispatchOp->insertOperands(operandsIndexAndLength.first + numOperands,
+                                 {value});
+      numOperands++;
+      if (tensorType) {
+        // The dims for this operand do not exist yet. Add those.
+        OpBuilder::InsertionGuard g(b);
+        b.setInsertionPoint(dispatchOp);
+        SmallVector<Value> dynamicDims;
+        for (auto dim : llvm::enumerate(tensorType.getShape())) {
+          if (dim.value() != ShapedType::kDynamicSize) continue;
+          dynamicDims.push_back(b.createOrFold<tensor::DimOp>(
+              dispatchOp.getLoc(), value, dim.index()));
+        }
+        dispatchOp->insertOperands(
+            operandsIndexAndLength.first + numOperands + numOperandDims,
+            dynamicDims);
+        numOperandDims += dynamicDims.size();
       }
     }
   }
 
-  // Set the values captured from above as the new operands.
-  dispatchOp.operandsMutable().assign(llvm::to_vector<4>(valuesDefinedAbove));
-  dispatchOp.operand_dimsMutable().assign(operandDynamicDims);
-
+  // Update the `operand_segment_sizes`.
+  auto operandSegmentSizes = dispatchOp->getAttrOfType<DenseIntElementsAttr>(
+      dispatchOp.operand_segment_sizesAttrName());
+  auto newValues = llvm::to_vector<4>(llvm::map_range(
+      operandSegmentSizes.getValues<APInt>(),
+      [&](APInt val) -> int32_t { return val.getSExtValue(); }));
+  newValues[1] = numOperands;
+  newValues[2] = numOperandDims;
+  auto newAttr =
+      DenseIntElementsAttr::get(operandSegmentSizes.getType(), newValues);
+  dispatchOp->setAttr(dispatchOp.operand_segment_sizesAttrName(), newAttr);
   return success();
 }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 11a1a55..732502e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -457,30 +457,59 @@
 
 // -----
 
-// A subsequent pass is expected to convert linalg.fill into DMA ops.
-func @subtensor_insert(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x225x225x3xf32> {
-  %cst = constant 0.000000e+00 : f32
-  %0 = linalg.init_tensor [1, 225, 225, 3] : tensor<1x225x225x3xf32>
-  %1 = linalg.fill(%cst, %0) : f32, tensor<1x225x225x3xf32> -> tensor<1x225x225x3xf32>
-  %2 = tensor.insert_slice %arg0 into %1[0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1] : tensor<1x224x224x3xf32> into tensor<1x225x225x3xf32>
-  return %2 : tensor<1x225x225x3xf32>
+func @subtensor_insert(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index) -> tensor<?x?xf32> {
+  %0 = tensor.insert_slice %arg0 into
+      %arg1[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
 
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
 //      CHECK: func @subtensor_insert
-// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x224x224x3xf32>)
-//
-//  CHECK-NOT:   flow.dispatch.workgroups
-//      CHECK:   %[[FILL:.+]] =  linalg.fill
-//
-//      CHECK:   %[[PAD:.+]] = flow.dispatch.workgroups[{{.+}}](%[[INPUT]], %[[FILL]]) : (tensor<1x224x224x3xf32>, tensor<1x225x225x3xf32>) -> %[[FILL]] =
-// CHECK-NEXT:       (%[[SRC:.+]]: !flow.dispatch.tensor<readonly:1x224x224x3xf32>, %[[DST:.+]]: !flow.dispatch.tensor<readwrite:1x225x225x3xf32>) {
-// CHECK-NEXT:     %[[SRC_TENSOR:.+]] = flow.dispatch.tensor.load %[[SRC]], {{.*}} : !flow.dispatch.tensor<readonly:1x224x224x3xf32> -> tensor<1x224x224x3xf32>
-// CHECK-NEXT:     %[[DST_TENSOR:.+]] = flow.dispatch.tensor.load %[[DST]], {{.*}} : !flow.dispatch.tensor<readwrite:1x225x225x3xf32> -> tensor<1x225x225x3xf32>
-// CHECK-NEXT:     %[[INSERT:.+]] = tensor.insert_slice %[[SRC_TENSOR]] into %[[DST_TENSOR]][0, 0, 0, 0] [1, 224, 224, 3] [1, 1, 1, 1]
-// CHECK-NEXT:     flow.dispatch.tensor.store %[[INSERT]], %[[DST]], {{.*}} : tensor<1x225x225x3xf32> -> !flow.dispatch.tensor<readwrite:1x225x225x3xf32>
-// CHECK-NEXT:     flow.return
-//
-//      CHECK:   return %[[PAD]] : tensor<1x225x225x3xf32>
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+//  CHECK-DAG:   %[[D3:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[RESULT:.+]] = flow.dispatch.workgroups[%[[D1]], %[[D0]], %[[C1]]]
+// CHECK-SAME:       (%[[ARG1]], %[[ARG0]], %[[ARG2]], %[[ARG3]])
+// CHECK-SAME:       tensor<?x?xf32>{%[[D2]], %[[D3]]}
+// CHECK-SAME:       tensor<?x?xf32>{%[[D0]], %[[D1]]}
+// CHECK-SAME:       -> %[[ARG1]]{%[[D2]], %[[D3]]}
+// CHECK-NEXT:     %[[ARG6:.+]]: !flow.dispatch.tensor<readwrite:?x?xf32>
+// CHECK-SAME:     %[[ARG7:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
+// CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG9:[a-zA-Z0-9]+]]: index
+//  CHECK-DAG:     %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0]
+//  CHECK-DAG:     %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1]
+//  CHECK-DAG:     %[[SHAPE:.+]] = flow.dispatch.shape %[[ARG7]]
+//      CHECK:     %[[UB_Y:.+]] = shapex.ranked_dim %[[SHAPE]][0]
+//      CHECK:     %[[UB_X:.+]] = shapex.ranked_dim %[[SHAPE]][1]
+//  CHECK-DAG:     %[[WGID_X:.+]] = flow.dispatch.workgroup.id[0]
+//  CHECK-DAG:     %[[WGCOUNT_X:.+]] = flow.dispatch.workgroup.count[0]
+//  CHECK-DAG:     %[[WGID_Y:.+]] = flow.dispatch.workgroup.id[1]
+//  CHECK-DAG:     %[[WGCOUNT_Y:.+]] = flow.dispatch.workgroup.count[1]
+//  CHECK-DAG:     %[[OFFSET_Y:.+]] = affine.apply #[[MAP0]](%[[WGID_Y]])[%[[WGSIZE_Y]]]
+//  CHECK-DAG:     %[[STEP_Y:.+]] = affine.apply #[[MAP0]](%[[WGCOUNT_Y]])[%[[WGSIZE_Y]]]
+//      CHECK:     scf.for %[[ARG10:.+]] = %[[OFFSET_Y]] to %[[UB_Y]] step %[[STEP_Y]]
+//  CHECK-DAG:       %[[TILESIZE_Y:.+]] = affine.min #[[MAP1]](%[[ARG10]])[%[[WGSIZE_Y]], %[[UB_Y]]]
+//  CHECK-DAG:       %[[OFFSET_X:.+]] = affine.apply #[[MAP0]](%[[WGID_X]])[%[[WGSIZE_X]]]
+//  CHECK-DAG:       %[[STEP_X:.+]] = affine.apply #[[MAP0]](%[[WGCOUNT_X]])[%[[WGSIZE_X]]]
+//      CHECK:       scf.for %[[ARG11:.+]] = %[[OFFSET_X]] to %[[UB_X]] step %[[STEP_X]]
+//      CHECK:         %[[TILESIZE_X:.+]] = affine.min #[[MAP1]](%[[ARG11]])[%[[WGSIZE_X]], %[[UB_X]]]
+//      CHECK:         %[[LOAD_TILE:.+]] = flow.dispatch.tensor.load %[[ARG7]]
+// CHECK-SAME:             offsets = [%[[ARG10]], %[[ARG11]]], sizes = [%[[TILESIZE_Y]], %[[TILESIZE_X]]]
+//  CHECK-DAG:         %[[STORE_OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[ARG10]])[%[[ARG8]]]
+//  CHECK-DAG:         %[[STORE_OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[ARG11]])[%[[ARG9]]]
+//      CHECK:         flow.dispatch.tensor.store %[[LOAD_TILE]], %[[ARG6]]
+//      CHECK:   return %[[RESULT]]
 
 // -----
 
@@ -1096,3 +1125,40 @@
 //       CHECK:     %[[FILL:.+]] = linalg.fill_rng_2d
 //       CHECK:     linalg.matmul
 //  CHECK-SAME:       outs(%[[FILL]] : tensor<?x?xf32>)
+
+// -----
+
+func @dynamic_slice(%arg0 : i32, %arg1 : i32, %arg2 : tensor<?xi32>,
+    %arg3 : tensor<?x?xi32>) -> tensor<?x?xi32>{
+  %c0 = constant 0 : index
+  %c0_i32 = constant 0 : i32
+  %c2_i32 = constant 2 : i32
+  %5 = cmpi slt, %arg0, %c2_i32 : i32
+  %6 = select %5, %arg0, %c2_i32 : i32
+  %7 = cmpi sgt, %6, %c0_i32 : i32
+  %8 = select %7, %6, %c0_i32 : i32
+  %9 = index_cast %8 : i32 to index
+  %11 = cmpi slt, %arg1, %c0_i32 : i32
+  %12 = select %11, %arg1, %c0_i32 : i32
+  %13 = cmpi sgt, %12, %c0_i32 : i32
+  %14 = select %13, %12, %c0_i32 : i32
+  %15 = index_cast %14 : i32 to index
+  %d0 = tensor.dim %arg2, %c0 : tensor<?xi32>
+  %17 = tensor.insert_slice %arg2 into
+      %arg3[%9, %15] [1, %d0] [1, 1] : tensor<?xi32> into tensor<?x?xi32>
+  return %17 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @dynamic_slice
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: i32
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: i32
+//  CHECK-SAME:     %[[ARG2:.+]]: tensor<?xi32>
+//  CHECK-SAME:     %[[ARG3:.+]]: tensor<?x?xi32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG3]], %[[C0]]
+//   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG3]], %[[C1]]
+//       CHECK:   flow.dispatch.workgroups[%[[D0]], %[[C1]], %[[C1]]]
+//  CHECK-SAME:       tensor<?x?xi32>{%[[D1]], %[[D2]]}, tensor<?xi32>{%[[D0]]}
+//  CHECK-NEXT:     %[[ARG4:.+]]: !flow.dispatch.tensor<readwrite:?x?xi32>
+//  CHECK-SAME:     %[[ARG5:.+]]: !flow.dispatch.tensor<readonly:?xi32>
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
index fd8c3a5..2cd8a14 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -31,16 +31,117 @@
   return valueOrAttr.get<Value>();
 }
 
+//===----------------------------------------------------------------------===//
+// Interface implementations for external operations.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct InsertSliceTiledOpInterface
+    : public TiledOpInterface::ExternalModel<InsertSliceTiledOpInterface,
+                                             tensor::InsertSliceOp> {
+  SmallVector<Value> getDestinationOperands(Operation *op) const {
+    SmallVector<Value> dest;
+    dest.push_back(cast<tensor::InsertSliceOp>(op).dest());
+    return dest;
+  }
+
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    return SmallVector<StringRef>(insertSliceOp.getSourceType().getRank(),
+                                  getParallelIteratorTypeName());
+  }
+
+  SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    Value source = insertSliceOp.source();
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    Location loc = op->getLoc();
+    Value zero = b.create<ConstantIndexOp>(loc, 0);
+    Value one = b.create<ConstantIndexOp>(loc, 1);
+    SmallVector<Range> loopBounds(sourceType.getRank(),
+                                  Range{zero, nullptr, one});
+    for (auto dim :
+         llvm::seq<int64_t>(0, insertSliceOp.getSourceType().getRank())) {
+      loopBounds[dim].size = b.create<tensor::DimOp>(loc, source, dim);
+    }
+    return loopBounds;
+  }
+
+  Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+                                    ValueRange outputs,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVectorImpl<Value> &results) const {
+    auto insertOp = cast<tensor::InsertSliceOp>(op);
+    // Compute a subtensor of the source based on the offsets.
+    auto opStrides = insertOp.getMixedStrides();
+    if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
+          Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
+          return intVal && *intVal == 1;
+        })) {
+      op->emitOpError("unable to tile operation with non-unit stride");
+      return nullptr;
+    }
+    Location loc = insertOp.getLoc();
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
+    auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
+        loc, insertOp.source(), offsets, sizes, strides);
+
+    // The offsets for the insert are based on the op offsets plus the offsets of
+    // the loops passed in.
+    auto opOffsets = insertOp.getMixedOffsets();
+    auto opSizes = insertOp.getMixedSizes();
+    unsigned offsetIndex = 0;
+    ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
+    int64_t destRank = insertOp.getType().getRank();
+    SmallVector<OpFoldResult> resultOffsets(destRank);
+    SmallVector<OpFoldResult> resultSizes(destRank);
+    for (auto opOffset : llvm::enumerate(opOffsets)) {
+      // Check for rank-reducing by checking that
+      // 1) The corresponding opSize value is 1
+      // 2) The current rank of the source is not 1.
+      // Then the opOffset is for the rank-reduced dimension. Skip.
+      unsigned opOffsetIndex = opOffset.index();
+      OpFoldResult opOffsetVal = opOffset.value();
+      Optional<int64_t> opSizeVal = getConstantIntValue(opSizes[opOffsetIndex]);
+      if (offsetIndex >= sourceShape.size() ||
+          (opSizeVal && *opSizeVal == 1 && sourceShape[offsetIndex] != 1)) {
+        resultOffsets[opOffsetIndex] = opOffsetVal;
+        resultSizes[opOffsetIndex] = oneAttr;
+        continue;
+      }
+      OpFoldResult offset = offsets[offsetIndex];
+      if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
+        resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
+            *getConstantIntValue(opOffsetVal) + *getConstantIntValue(offset));
+      } else {
+        AffineMap map = AffineMap::get(
+            1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
+        resultOffsets[opOffsetIndex] =
+            b.create<AffineApplyOp>(loc, map,
+                                    ValueRange{getValue(b, loc, offset),
+                                               getValue(b, loc, opOffsetVal)})
+                .getResult();
+      }
+      resultSizes[opOffsetIndex] = sizes[offsetIndex];
+      offsetIndex++;
+    }
+    SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
+    auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+        loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes,
+        resultStrides);
+    results.push_back(tiledInsertOp.result());
+    return extractSliceOp;
+  }
+};
+}  // namespace
+
 void registerTiledOpInterfaceExternalModels(DialectRegistry &registry) {
   LLVM_DEBUG({
     llvm::dbgs() << "Adding tiled op interface for tensor.insert_slice\n";
   });
-  // TODO(ravishankarm): For now this is commented out since there are a lot of
-  // upstream bugs exposed by this. Leaving the restructuring in place, but
-  // avoiding the interface hook till those are addressed.
-  //
-  // registry.addOpInterface<tensor::InsertSliceOp,
-  // InsertSliceTiledOpInterface>();
+  registry.addOpInterface<tensor::InsertSliceOp, InsertSliceTiledOpInterface>();
 }
 
 }  // namespace linalg_ext
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
index 16aca74..ae62abe 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -235,112 +235,6 @@
                              loopBounds, 0, offsets, distributionInfo);
 }
 
-//===----------------------------------------------------------------------===//
-// Interface implementations for external operations.
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct InsertSliceTiledOpInterface
-    : public TiledOpInterface::ExternalModel<InsertSliceTiledOpInterface,
-                                             tensor::InsertSliceOp> {
-  SmallVector<Value> getDestinationOperands(Operation *op) const {
-    SmallVector<Value> dest;
-    dest.push_back(cast<tensor::InsertSliceOp>(op).dest());
-    return dest;
-  }
-
-  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    return SmallVector<StringRef>(insertSliceOp.getSourceType().getRank(),
-                                  getParallelIteratorTypeName());
-  }
-
-  SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    Value source = insertSliceOp.source();
-    RankedTensorType sourceType = insertSliceOp.getSourceType();
-    Location loc = op->getLoc();
-    Value zero = b.create<ConstantIndexOp>(loc, 0);
-    Value one = b.create<ConstantIndexOp>(loc, 1);
-    SmallVector<Range> loopBounds(sourceType.getRank(),
-                                  Range{zero, nullptr, one});
-    for (auto dim :
-         llvm::seq<int64_t>(0, insertSliceOp.getSourceType().getRank())) {
-      loopBounds[dim].size = b.create<tensor::DimOp>(loc, source, dim);
-    }
-    return loopBounds;
-  }
-
-  Operation *getTiledImplementation(Operation *op, OpBuilder &b,
-                                    ValueRange outputs,
-                                    ArrayRef<OpFoldResult> offsets,
-                                    ArrayRef<OpFoldResult> sizes,
-                                    SmallVectorImpl<Value> &results) const {
-    auto insertOp = cast<tensor::InsertSliceOp>(op);
-    // Compute a subtensor of the source based on the offsets.
-    auto opStrides = insertOp.getMixedStrides();
-    if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
-          Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
-          return intVal && *intVal == 1;
-        })) {
-      op->emitOpError("unable to tile operation with non-unit stride");
-      return nullptr;
-    }
-    Location loc = insertOp.getLoc();
-    auto oneAttr = b.getI64IntegerAttr(1);
-    SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
-    auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
-        loc, insertOp.source(), offsets, sizes, strides);
-
-    // The offsets for the insert is based on the op offsets plus the offsets of
-    // the loops passed in.
-    auto opOffsets = insertOp.getMixedOffsets();
-    auto opSizes = insertOp.getMixedSizes();
-    unsigned offsetIndex = 0;
-    ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
-    int64_t destRank = insertOp.getType().getRank();
-    SmallVector<OpFoldResult> resultOffsets(destRank);
-    SmallVector<OpFoldResult> resultSizes(destRank);
-    auto zeroAttr = b.getI64IntegerAttr(0);
-    for (auto opOffset : llvm::enumerate(opOffsets)) {
-      // Check for rank-reducing by checking that
-      // 1) The corresponding opSize value is 1
-      // 2) The current rank of the source is not 1.
-      // Then the opOffset is for the rank-reduced dimension. Skip.
-      unsigned opOffsetIndex = opOffset.index();
-      Optional<int64_t> opSizeVal = getConstantIntValue(opSizes[opOffsetIndex]);
-      if (opSizeVal && *opSizeVal == 1 && sourceShape[offsetIndex] != 1) {
-        resultOffsets[opOffsetIndex] = zeroAttr;
-        resultSizes[opOffsetIndex] = oneAttr;
-        continue;
-      }
-      OpFoldResult opOffsetVal = opOffset.value();
-      OpFoldResult offset = offsets[offsetIndex];
-      if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
-        resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
-            *getConstantIntValue(opOffsetVal) + *getConstantIntValue(offset));
-      } else {
-        AffineMap map = AffineMap::get(
-            1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
-        resultOffsets[opOffsetIndex] =
-            b.create<AffineApplyOp>(loc, map,
-                                    ValueRange{getValue(b, loc, offset),
-                                               getValue(b, loc, opOffsetVal)})
-                .getResult();
-      }
-      resultSizes[opOffsetIndex] = sizes[offsetIndex];
-      offsetIndex++;
-    }
-    SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
-    auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
-        loc, extractSliceOp.result(), outputs[0], resultOffsets, resultSizes,
-        resultStrides);
-    results.push_back(tiledInsertOp.result());
-    return extractSliceOp;
-  }
-};
-}  // namespace
-
 LogicalResult TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
     TiledOpInterface tilableOp, PatternRewriter &rewriter,
     TiledOp &result) const {
@@ -373,7 +267,6 @@
                 linalg_ext::LinalgExtDialect, memref::MemRefDialect,
                 StandardOpsDialect, tensor::TensorDialect, scf::SCFDialect>();
   }
-  LogicalResult initialize(MLIRContext *context) override;
   void runOnOperation() override;
 };
 }  // namespace
@@ -383,13 +276,6 @@
   return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
 }
 
-LogicalResult TiledOpInterfaceTilingPass::initialize(MLIRContext *context) {
-  // TODO(ravishankarm): When the interface is added during registration, remove
-  // this initialization.
-  tensor::InsertSliceOp::attachInterface<InsertSliceTiledOpInterface>(*context);
-  return success();
-}
-
 void TiledOpInterfaceTilingPass::runOnOperation() {
   FuncOp funcOp = getOperation();
   MLIRContext *context = funcOp.getContext();
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 21c3596..f632cf6 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -683,3 +683,65 @@
 // CHECK:            scf.yield %[[RES3]]
 // CHECK:          scf.yield %[[RES2]]
 // CHECK:        return %[[RES]]
+
+// -----
+
+func @dynamic_insert_slice(%arg0 : tensor<?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3] [1, %d0] [1, 1]
+      {__internal_linalg_transform__ = "tiling_input"} : tensor<?xf32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//      CHECK: func @dynamic_insert_slice(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:  %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:  %[[C10:.+]] = constant 10 : index
+//      CHECK:  %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?xf32>
+//      CHECK:  %[[RESULT:.+]] = scf.for %[[ARG4:.+]] = %[[C0]] to %[[D0]]
+// CHECK-SAME:      step %[[C10]] iter_args(%[[ARG5:.+]] = %[[ARG1]])
+//      CHECK:    %[[TILESIZE:.+]] = affine.min #[[MAP0]](%[[ARG4]])[%[[C10]], %[[D0]]]
+//      CHECK:    %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG4]]] [%[[TILESIZE]]]
+//      CHECK:    %[[OFFSET:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG3]]]
+//      CHECK:    %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT]] into %[[ARG5]]
+// CHECK-SAME:        [%[[ARG2]], %[[OFFSET]]] [1, %[[TILESIZE]]]
+//      CHECK:    scf.yield %[[INSERT]]
+//      CHECK:  return %[[RESULT]]
+
+
+// -----
+
+func @insert_slice_rank_reduced_inner(%arg0 : tensor<?xf32>,
+    %arg1 : tensor<?x?x?xf32>, %arg2: index, %arg3 : index, %arg4 : index) -> tensor<?x?x?xf32> {
+  %c0 = constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [1, %d0, 1] [1, 1, 1]
+      {__internal_linalg_transform__ = "tiling_input"} : tensor<?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+//      CHECK: func @insert_slice_rank_reduced_inner(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[LB:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[STEP:.+]] = constant 10 : index
+//      CHECK:   %[[UB:.+]] = tensor.dim %[[ARG0]]
+//      CHECK:   %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = %[[LB]]
+// CHECK-SAME:       to %[[UB]] step %[[STEP]] iter_args(%[[ARG6:.+]] = %[[ARG1]])
+//      CHECK:     %[[TILESIZE:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[STEP]], %[[UB]]]
+//      CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]]] [%[[TILESIZE]]]
+//      CHECK:     %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[IV0]])[%[[ARG3]]]
+//      CHECK:     %[[YIELD:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG6]]
+// CHECK-SAME:         [%[[ARG2]], %[[APPLY]], %[[ARG4]]] [1, %[[TILESIZE]], 1]
+//      CHECK:     scf.yield %[[YIELD]]
+//      CHECK:   return %[[RESULT]]
diff --git a/iree/samples/static_library/static_library_demo.c b/iree/samples/static_library/static_library_demo.c
index ed751db..7873498 100644
--- a/iree/samples/static_library/static_library_demo.c
+++ b/iree/samples/static_library/static_library_demo.c
@@ -82,11 +82,6 @@
   if (iree_status_is_ok(status)) {
     status = create_device_with_static_loader(&device);
   }
-  iree_vm_module_t* hal_module = NULL;
-  if (iree_status_is_ok(status)) {
-    status =
-        iree_hal_module_create(device, iree_allocator_system(), &hal_module);
-  }
 
   // Session configuration (one per loaded module to hold module state).
   iree_runtime_session_options_t session_options;
diff --git a/iree/test/e2e/xla_ops/gather.mlir b/iree/test/e2e/xla_ops/gather.mlir
index c0417c6..465c58c 100644
--- a/iree/test/e2e/xla_ops/gather.mlir
+++ b/iree/test/e2e/xla_ops/gather.mlir
@@ -18,3 +18,152 @@
   check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32>
   return
 }
+
+func @via_torch_index_select() {
+  %input = util.unfoldable_constant dense<[
+    [[01, 02, 03, 04, 05]],
+    [[06, 07, 08, 09, 10]],
+    [[11, 12, 13, 14, 15]],
+    [[16, 17, 18, 19, 20]],
+    [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
+  %start_indices = util.unfoldable_constant dense<2> : tensor<i64>
+  %res = "mhlo.gather"(%input, %start_indices) {
+    dimension_numbers = #mhlo.gather<
+      collapsed_slice_dims = [0],
+      index_vector_dim = 0,
+      offset_dims = [0, 1],
+      start_index_map = [0],
+    >,
+    slice_sizes = dense<[1, 1, 5]> : tensor<3xi64>
+  } : (tensor<5x1x5xi32>, tensor<i64>) -> tensor<1x5xi32>
+  check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32>
+  return
+}
+
+
+func @general_but_just_index_select() {
+  %operand = util.unfoldable_constant dense<[[
+    [ 0,  1,  2,  3,  4,  5,  6,  7],
+    [ 8,  9, 10, 11, 12, 13, 14, 15],
+    [16, 17, 18, 19, 20, 21, 22, 23],
+    [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32>
+  %start_indices = util.unfoldable_constant dense<[[
+      [0, 1],
+      [0, 2],
+      [0, 3],
+      [0, 0],
+      [0, 0],
+      [0, 1],
+      [0, 2],
+      [0, 3]]]> : tensor<1x8x2xi32>
+  %result = "mhlo.gather"(%operand, %start_indices) {
+    dimension_numbers = #mhlo.gather<
+      collapsed_slice_dims = [0, 1],
+      index_vector_dim = 2,
+      offset_dims = [2],
+      start_index_map = [0, 1]
+    >,
+    indices_are_sorted = false,
+    slice_sizes = dense<[1, 1, 8]> : tensor<3xi64>
+  } : (tensor<1x4x8xi32>, tensor<1x8x2xi32>) -> tensor<1x8x8xi32>
+  check.expect_eq_const(%result, dense<[[
+         [ 8,  9, 10, 11, 12, 13, 14, 15],
+         [16, 17, 18, 19, 20, 21, 22, 23],
+         [24, 25, 26, 27, 28, 29, 30, 31],
+         [ 0,  1,  2,  3,  4,  5,  6,  7],
+         [ 0,  1,  2,  3,  4,  5,  6,  7],
+         [ 8,  9, 10, 11, 12, 13, 14, 15],
+         [16, 17, 18, 19, 20, 21, 22, 23],
+         [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x8x8xi32>) : tensor<1x8x8xi32>
+  return
+}
+
+func @small_slices() {
+  %operand = util.unfoldable_constant dense<[[
+    [ 0,  1,  2,  3,  4,  5,  6,  7],
+    [ 8,  9, 10, 11, 12, 13, 14, 15],
+    [16, 17, 18, 19, 20, 21, 22, 23],
+    [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32>
+  %start_indices = util.unfoldable_constant dense<[[
+    [0, 1],
+    [0, 2],
+    [0, 3],
+    [0, 0]]]> : tensor<1x4x2xi32>
+  %result = "mhlo.gather"(%operand, %start_indices) {
+    dimension_numbers = #mhlo.gather<
+      collapsed_slice_dims = [0, 1],
+      index_vector_dim = 2,
+      offset_dims = [2],
+      start_index_map = [0, 1]
+    >,
+    indices_are_sorted = false,
+    slice_sizes = dense<[1, 1, 3]> : tensor<3xi64>
+  } : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> tensor<1x4x3xi32>
+  check.expect_eq_const(%result, dense<[[
+        [ 8,  9, 10],
+        [16, 17, 18],
+        [24, 25, 26],
+        [ 0,  1,  2]]]> : tensor<1x4x3xi32>) : tensor<1x4x3xi32>
+  return
+}
+
+func @nonstandard_offset_dims() {
+  %operand = util.unfoldable_constant dense<[[
+    [ 0,  1,  2,  3,  4,  5,  6,  7],
+    [ 8,  9, 10, 11, 12, 13, 14, 15],
+    [16, 17, 18, 19, 20, 21, 22, 23],
+    [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32>
+  %start_indices = util.unfoldable_constant dense<[[
+    [0, 1],
+    [0, 2],
+    [0, 2],
+    [0, 0]]]> : tensor<1x4x2xi32>
+  %result = "mhlo.gather"(%operand, %start_indices) {
+    dimension_numbers = #mhlo.gather<
+      collapsed_slice_dims = [0],
+      index_vector_dim = 2,
+      offset_dims = [1, 2],
+      start_index_map = [0, 1]
+    >,
+    indices_are_sorted = false,
+    slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>
+  } : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> tensor<1x2x3x4xi32>
+  check.expect_eq_const(%result, dense<[[
+      [[ 8, 16, 16,  0],
+       [ 9, 17, 17,  1],
+       [10, 18, 18,  2]],
+      [[16, 24, 24,  8],
+       [17, 25, 25,  9],
+       [18, 26, 26, 10]]]]> : tensor<1x2x3x4xi32>) : tensor<1x2x3x4xi32>
+  return
+}
+
+func @reordered_start_index() {
+  %operand = util.unfoldable_constant dense<[[
+    [[ 0,  1,  2,  3],
+     [ 4,  5,  6,  7]],
+    [[ 8,  9, 10, 11],
+     [12, 13, 14, 15]],
+    [[16, 17, 18, 19],
+     [20, 21, 22, 23]]]]> : tensor<1x3x2x4xi32>
+  %start_indices = util.unfoldable_constant dense<[
+    [0, 1, 0, 0],
+    [1, 0, 0, 0]]> : tensor<2x4xi32>
+ %result = "mhlo.gather"(%operand, %start_indices) {
+    dimension_numbers = #mhlo.gather<
+      collapsed_slice_dims = [0, 2],
+      index_vector_dim = 1,
+      offset_dims = [1, 2],
+      start_index_map = [3, 2, 0, 1]
+    >,
+    indices_are_sorted = false,
+    slice_sizes = dense<[1, 2, 1, 3]> : tensor<4xi64>
+  } : (tensor<1x3x2x4xi32>, tensor<2x4xi32>) -> tensor<2x2x3xi32>
+
+  check.expect_eq_const(%result, dense<[
+    [[ 4,  5,  6],
+     [12, 13, 14]],
+    [[ 1,  2,  3],
+     [ 9, 10, 11]]]> : tensor<2x2x3xi32>) : tensor<2x2x3xi32>
+  return
+}
diff --git a/iree/tools/utils/trace_replay.c b/iree/tools/utils/trace_replay.c
index f7e95c8..cba6e5f 100644
--- a/iree/tools/utils/trace_replay.c
+++ b/iree/tools/utils/trace_replay.c
@@ -427,6 +427,8 @@
    yaml_node_t* encoding_type_node,
    iree_hal_encoding_type_t* out_encoding_type) {
  *out_encoding_type = IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
+  if (!encoding_type_node) {
+    return iree_ok_status();
+  }
 
-  if (!encoding_type_node) return iree_ok_status();
   iree_string_view_t encoding_type_str =
diff --git a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
index 6e946e8..2be688c 100644
--- a/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
+++ b/llvm-external-projects/iree-compiler-api/include/iree-compiler-c/Compiler.h
@@ -21,6 +21,7 @@
   typedef struct name name
 
 DEFINE_C_API_STRUCT(IreeCompilerOptions, void);
+#undef DEFINE_C_API_STRUCT
 
 //===----------------------------------------------------------------------===//
 // Registration.
diff --git a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
index 3b6ae8d..d13028d 100644
--- a/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
+++ b/llvm-external-projects/iree-compiler-api/lib/CAPI/Compiler.cpp
@@ -44,7 +44,10 @@
 void ireeCompilerRegisterTargetBackends() { registerHALTargetBackends(); }
 
 IreeCompilerOptions ireeCompilerOptionsCreate() {
-  return wrap(new CompilerOptions);
+  auto options = new CompilerOptions;
+  // TODO: Make configurable.
+  options->vmTargetOptions.f32Extension = true;
+  return wrap(options);
 }
 
 void ireeCompilerOptionsDestroy(IreeCompilerOptions options) {
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index df10347..90340be 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -275,6 +275,7 @@
 cc_library(
     name = "IREEPyDMTransforms",
     srcs = glob([
+        "lib/Dialect/IREEPyDM/Transforms/*.cpp",
         "lib/Dialect/IREEPyDM/Transforms/RTL/*.cpp",
         "lib/Dialect/IREEPyDM/Transforms/ToIREE/*.cpp",
     ]),
@@ -291,6 +292,7 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
@@ -312,5 +314,6 @@
         ":IREEPyDMTransforms",
         "@llvm-project//mlir:CAPIIR",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
     ],
 )
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
index 613df95..da8701a 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
@@ -27,6 +27,28 @@
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREEPyDM, iree_pydm);
 
+#define DEFINE_C_API_STRUCT(name, storage) \
+  struct name {                            \
+    storage *ptr;                          \
+  };                                       \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(IREEPyDMSourceBundle, void);
+DEFINE_C_API_STRUCT(IREEPyDMLoweringOptions, void);
+#undef DEFINE_C_API_STRUCT
+
+/// Creates a PyDM source bundle from an ASM string.
+MLIR_CAPI_EXPORTED IREEPyDMSourceBundle
+ireePyDMSourceBundleCreateAsm(MlirStringRef asmString);
+
+/// Creates a PyDM source bundle from a file path.
+MLIR_CAPI_EXPORTED IREEPyDMSourceBundle
+ireePyDMSourceBundleCreateFile(MlirStringRef filePath);
+
+/// Destroys a created source bundle.
+MLIR_CAPI_EXPORTED void ireePyDMSourceBundleDestroy(
+    IREEPyDMSourceBundle bundle);
+
 MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDMPrimitiveType(MlirType type);
 
 #define IREEPYDM_DECLARE_NULLARY_TYPE(Name)                         \
@@ -51,9 +73,20 @@
 MLIR_CAPI_EXPORTED MlirType mlirIREEPyDMObjectTypeGet(MlirContext context,
                                                       MlirType primitive);
 
+/// Creates a lowering options struct.
+MLIR_CAPI_EXPORTED IREEPyDMLoweringOptions ireePyDMLoweringOptionsCreate();
+
+/// Sets the RTL link source bundle to the lowering options.
+MLIR_CAPI_EXPORTED void ireePyDMLoweringOptionsLinkRtl(
+    IREEPyDMLoweringOptions options, IREEPyDMSourceBundle source);
+
+/// Destroys a created lowering options struct.
+MLIR_CAPI_EXPORTED void ireePyDMLoweringOptionsDestroy(
+    IREEPyDMLoweringOptions options);
+
 /// Builds a pass pipeline which lowers the iree_pydm dialect to IREE.
 MLIR_CAPI_EXPORTED void mlirIREEPyDMBuildLowerToIREEPassPipeline(
-    MlirOpPassManager passManager);
+    MlirOpPassManager passManager, IREEPyDMLoweringOptions options);
 
 #ifdef __cplusplus
 }
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
index c3670e0..a169051 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h
@@ -7,8 +7,11 @@
 #ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
 #define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSES_H
 
+#include <memory>
+
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
 
 namespace mlir {
 
@@ -18,9 +21,25 @@
 
 namespace iree_pydm {
 
+/// References sources, either passed literally or by reference to a file.
+/// One of `asmBlob` or `asmFilePath` should be populated.
+struct SourceBundle {
+  std::shared_ptr<std::string> asmBlob;
+  Optional<std::string> asmFilePath;
+};
+
+/// Options for lowering to IREE.
+struct LowerToIREEOptions {
+  Optional<SourceBundle> linkRtlSource;
+};
+
 std::unique_ptr<OperationPass<ModuleOp>> createConvertIREEPyDMToIREEPass();
 std::unique_ptr<OperationPass<ModuleOp>> createLowerIREEPyDMToRTLPass();
-std::unique_ptr<OperationPass<ModuleOp>> createLinkIREEPyDMRTLPass();
+std::unique_ptr<OperationPass<ModuleOp>> createLinkIREEPyDMRTLPass(
+    Optional<SourceBundle> linkRtlSourceBundle = None);
+
+void buildLowerToIREEPassPipeline(OpPassManager& passManager,
+                                  const LowerToIREEOptions& options);
 
 #define GEN_PASS_REGISTRATION
 #include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h.inc"
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
index 69dd0e1..daef405 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.h
@@ -26,6 +26,10 @@
   Type getWeakIntegerType(Builder b) const;
   Type getWeakFloatType(Builder b) const;
 
+  // Whether the given type is a valid lowered type.
+  bool isTypeLegal(Type t) const;
+  bool areTypesLegal(TypeRange types) const;
+
  private:
   bool boolBits = 32;
   int weakIntegerBits = 32;
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
index 66682e5..75be3ec 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -5,7 +5,7 @@
   MLIRIR
   IREEDialectsIREEDialect
   IREEDialectsIREEPyDMDialect
-  IREEDialectsIREEPyDMToIREEPasses
+  IREEDialectsIREEPyDMPasses
 )
 
 iree_dialects_target_includes(IREEDialectsCAPI)
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index ad5e546..d316fbe 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -15,6 +15,9 @@
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Utils.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
 
 //===----------------------------------------------------------------------===//
 // IREEDialect
@@ -29,6 +32,10 @@
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IREEPyDM, iree_pydm,
                                       mlir::iree_pydm::IREEPyDMDialect)
 
+DEFINE_C_API_PTR_METHODS(IREEPyDMSourceBundle, mlir::iree_pydm::SourceBundle)
+DEFINE_C_API_PTR_METHODS(IREEPyDMLoweringOptions,
+                         mlir::iree_pydm::LowerToIREEOptions)
+
 bool mlirTypeIsAIREEPyDMPrimitiveType(MlirType type) {
   return unwrap(type).isa<mlir::iree_pydm::PrimitiveType>();
 }
@@ -66,8 +73,40 @@
   return wrap(mlir::iree_pydm::ObjectType::get(unwrap(ctx), cppType));
 }
 
-void mlirIREEPyDMBuildLowerToIREEPassPipeline(MlirOpPassManager passManager) {
+void mlirIREEPyDMBuildLowerToIREEPassPipeline(MlirOpPassManager passManager,
+                                              IREEPyDMLoweringOptions options) {
   auto *passManagerCpp = unwrap(passManager);
-  // TODO: Should be a pass pipeline, not loose passes in the C impl.
-  passManagerCpp->addPass(mlir::iree_pydm::createConvertIREEPyDMToIREEPass());
+  mlir::iree_pydm::buildLowerToIREEPassPipeline(*passManagerCpp,
+                                                *unwrap(options));
+}
+
+// SourceBundle
+IREEPyDMSourceBundle ireePyDMSourceBundleCreateAsm(MlirStringRef asmString) {
+  auto bundle = std::make_unique<mlir::iree_pydm::SourceBundle>();
+  bundle->asmBlob = std::make_shared<std::string>(unwrap(asmString));
+  return wrap(bundle.release());
+}
+
+IREEPyDMSourceBundle ireePyDMSourceBundleCreateFile(MlirStringRef filePath) {
+  auto bundle = std::make_unique<mlir::iree_pydm::SourceBundle>();
+  bundle->asmFilePath = std::string(unwrap(filePath));
+  return wrap(bundle.release());
+}
+
+void ireePyDMSourceBundleDestroy(IREEPyDMSourceBundle bundle) {
+  delete unwrap(bundle);
+}
+
+// LoweringOptions
+IREEPyDMLoweringOptions ireePyDMLoweringOptionsCreate() {
+  return wrap(new mlir::iree_pydm::LowerToIREEOptions);
+}
+
+void ireePyDMLoweringOptionsLinkRtl(IREEPyDMLoweringOptions options,
+                                    IREEPyDMSourceBundle source) {
+  unwrap(options)->linkRtlSource = *unwrap(source);
+}
+
+void ireePyDMLoweringOptionsDestroy(IREEPyDMLoweringOptions options) {
+  delete unwrap(options);
 }
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
index 37db9a5..44fe3b2 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/CMakeLists.txt
@@ -1,2 +1,15 @@
 add_subdirectory(RTL)
 add_subdirectory(ToIREE)
+
+add_mlir_library(IREEDialectsIREEPyDMPasses
+  Passes.cpp
+
+  DEPENDS
+  MLIRIREEPyDMTransformsPassesIncGen
+
+  LINK_LIBS PUBLIC
+  IREEDialectsIREEPyDMRTLPasses
+  IREEDialectsIREEPyDMToIREEPasses
+)
+
+iree_dialects_target_includes(IREEDialectsIREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
new file mode 100644
index 0000000..caaaf14
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/Passes.cpp
@@ -0,0 +1,22 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREEPyDM/Transforms/Passes.h"
+
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+void mlir::iree_pydm::buildLowerToIREEPassPipeline(
+    OpPassManager& passManager, const LowerToIREEOptions& options) {
+  // TODO: Needs to be iterative, support optimization passes, etc.
+  passManager.addPass(createLowerIREEPyDMToRTLPass());
+  if (options.linkRtlSource) {
+    passManager.addPass(createLinkIREEPyDMRTLPass(options.linkRtlSource));
+  }
+  passManager.addPass(createConvertIREEPyDMToIREEPass());
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
index 9608edf..7803d65 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/RTL/LinkRTLPass.cpp
@@ -35,22 +35,38 @@
 namespace {
 
 class LinkIREEPyDMRTLPass : public LinkIREEPyDMRTLBase<LinkIREEPyDMRTLPass> {
+ public:
+  LinkIREEPyDMRTLPass() = default;
+  LinkIREEPyDMRTLPass(Optional<SourceBundle> linkRtlSourceBundle)
+      : linkRtlSourceBundle(std::move(linkRtlSourceBundle)) {}
+
+ private:
   LogicalResult initialize(MLIRContext *context) override {
-    // Already initialized in some way.
-    if (!rtlModule && rtlFile.empty()) {
+    SourceBundle localSource;
+    if (linkRtlSourceBundle) {
+      localSource = *linkRtlSourceBundle;
+    } else {
+      // Get it from the cli options.
+      localSource.asmFilePath = rtlFile;
+    }
+
+    if (localSource.asmBlob) {
+      // Parse from inline asm.
+      auto owningOp = parseSourceString(*localSource.asmBlob, context);
+      if (!owningOp) return failure();
+      rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
+    } else if (localSource.asmFilePath) {
+      // Parse from a file.
+      auto owningOp = parseSourceFile(*localSource.asmFilePath, context);
+      if (!owningOp) return failure();
+      rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
+    } else {
       return emitError(UnknownLoc::get(context))
              << "pass " << getArgument()
              << "must be initialized with an RTL module (did you mean to "
                 "add an rtl-file option?)";
     }
 
-    if (!rtlFile.empty()) {
-      // Parse from a file.
-      auto owningOp = parseSourceFile(rtlFile, context);
-      if (!owningOp) return failure();
-      rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
-    }
-
     ModuleOp parentModule = rtlModule->get();
     // Walk the module and build a SymbolTable for each sub-module.
     parentModule->walk([&](ModuleOp importModule) {
@@ -187,11 +203,15 @@
 
   // A SymbolTable for each sub module.
   SmallVector<SymbolTable> importModules;
+
+  // ASM source of RTL modules to link (otherwise will use pass options).
+  Optional<SourceBundle> linkRtlSourceBundle;
 };
 
 }  // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::iree_pydm::createLinkIREEPyDMRTLPass() {
-  return std::make_unique<LinkIREEPyDMRTLPass>();
+mlir::iree_pydm::createLinkIREEPyDMRTLPass(
+    Optional<SourceBundle> linkRtlSourceBundle) {
+  return std::make_unique<LinkIREEPyDMRTLPass>(std::move(linkRtlSourceBundle));
 }
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
index 661a237..2d4d88d 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -40,6 +40,16 @@
     target.addLegalDialect<mlir::iree::IREEDialect>();
     target.addLegalDialect<mlir::math::MathDialect>();
     target.addLegalDialect<StandardOpsDialect>();
+
+    // Some CFG ops can be present in the original pydm program. Need to
+    // verify legality based on types.
+    target.addDynamicallyLegalOp<BranchOp>([&](BranchOp op) -> bool {
+      return typeConverter.areTypesLegal(op.getOperandTypes());
+    });
+    target.addDynamicallyLegalOp<CondBranchOp>([&](CondBranchOp op) -> bool {
+      return typeConverter.areTypesLegal(op.getOperandTypes());
+    });
+
     if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
       return signalPassFailure();
     }
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
index 6214709..9b3e1fd 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -17,6 +17,7 @@
 using namespace mlir;
 using namespace mlir::iree_pydm;
 
+namespace arith_d = mlir;
 namespace iree_d = mlir::iree;
 namespace builtin_d = mlir;
 namespace std_d = mlir;
@@ -101,6 +102,60 @@
   return list;
 }
 
+static Value castIntegerValue(Location loc, Value input,
+                              builtin_d::IntegerType resultType,
+                              OpBuilder &builder) {
+  builtin_d::IntegerType inputType =
+      input.getType().cast<builtin_d::IntegerType>();
+  if (inputType.getWidth() == resultType.getWidth()) {
+    return input;
+  } else if (inputType.getWidth() < resultType.getWidth()) {
+    return builder.create<arith_d::SignExtendIOp>(loc, resultType, input);
+  } else {
+    return builder.create<arith_d::TruncateIOp>(loc, resultType, input);
+  }
+}
+
+static Optional<arith_d::CmpIPredicate> convertIntegerComparePredicate(
+    StringAttr dunderName, bool isSigned, Builder &builder) {
+  StringRef v = dunderName.getValue();
+  if (v == "lt") {
+    return isSigned ? arith_d::CmpIPredicate::slt : arith_d::CmpIPredicate::ult;
+  } else if (v == "le") {
+    return isSigned ? arith_d::CmpIPredicate::sle : arith_d::CmpIPredicate::ule;
+  } else if (v == "eq" || v == "is") {
+    return arith_d::CmpIPredicate::eq;
+  } else if (v == "ne" || v == "isnot") {
+    return arith_d::CmpIPredicate::ne;
+  } else if (v == "gt") {
+    return isSigned ? arith_d::CmpIPredicate::sgt : arith_d::CmpIPredicate::ugt;
+  } else if (v == "ge") {
+    return isSigned ? arith_d::CmpIPredicate::sge : arith_d::CmpIPredicate::uge;
+  }
+
+  return {};
+}
+
+static Optional<arith_d::CmpFPredicate> convertFpComparePredicate(
+    StringAttr dunderName, Builder &builder) {
+  StringRef v = dunderName.getValue();
+  if (v == "lt") {
+    return arith_d::CmpFPredicate::OLT;
+  } else if (v == "le") {
+    return arith_d::CmpFPredicate::OLE;
+  } else if (v == "eq" || v == "is") {
+    return arith_d::CmpFPredicate::OEQ;
+  } else if (v == "ne" || v == "isnot") {
+    return arith_d::CmpFPredicate::ONE;
+  } else if (v == "gt") {
+    return arith_d::CmpFPredicate::OGT;
+  } else if (v == "ge") {
+    return arith_d::CmpFPredicate::OGE;
+  }
+
+  return {};
+}
+
 namespace {
 
 class AllocFreeVarOpConversion
@@ -120,6 +175,53 @@
   }
 };
 
+class ApplyCompareNumericConversion
+    : public OpConversionPattern<pydm_d::ApplyCompareOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::ApplyCompareOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type leftType = adaptor.left().getType();
+    Type rightType = adaptor.right().getType();
+    if (leftType != rightType) {
+      return rewriter.notifyMatchFailure(srcOp, "not same type operands");
+    }
+    if (leftType.isa<builtin_d::IntegerType>()) {
+      bool isSigned = true;  // TODO: Unsigned.
+      auto predicate = convertIntegerComparePredicate(adaptor.dunder_name(),
+                                                      isSigned, rewriter);
+      if (!predicate)
+        return rewriter.notifyMatchFailure(srcOp, "unsupported predicate");
+      rewriter.replaceOpWithNewOp<arith_d::CmpIOp>(
+          srcOp, *predicate, adaptor.left(), adaptor.right());
+      return success();
+    } else if (leftType.isa<builtin_d::FloatType>()) {
+      auto predicate =
+          convertFpComparePredicate(adaptor.dunder_name(), rewriter);
+      if (!predicate)
+        return rewriter.notifyMatchFailure(srcOp, "unsupported predicate");
+      rewriter.replaceOpWithNewOp<arith_d::CmpFOp>(
+          srcOp, *predicate, adaptor.left(), adaptor.right());
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(srcOp, "non numeric type");
+  }
+};
+
+class BoolToPredConversion : public OpConversionPattern<pydm_d::BoolToPredOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::BoolToPredOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    pydm_d::BoolToPredOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(srcOp, adaptor.value());
+    return success();
+  }
+};
+
 class BoxOpConversion : public OpConversionPattern<pydm_d::BoxOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -140,6 +242,95 @@
   }
 };
 
+class BranchConversion : public OpConversionPattern<std_d::BranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      std_d::BranchOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<std_d::BranchOp>(srcOp, srcOp.dest(),
+                                                 adaptor.destOperands());
+    return success();
+  }
+};
+
+class CallOpConversion : public OpConversionPattern<pydm_d::CallOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::CallOp srcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    pydm_d::CallOp::Adaptor adaptor(operands);
+    SmallVector<Type> resultTypes;
+    if (failed(getTypeConverter()->convertTypes(srcOp.getResultTypes(),
+                                                resultTypes))) {
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "result types could not be converted");
+    }
+    rewriter.replaceOpWithNewOp<std_d::CallOp>(srcOp, srcOp.callee(),
+                                               resultTypes, adaptor.operands());
+    return success();
+  }
+};
+
+class CondBranchConversion : public OpConversionPattern<std_d::CondBranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      std_d::CondBranchOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<std_d::CondBranchOp>(
+        srcOp, adaptor.condition(), srcOp.trueDest(),
+        adaptor.trueDestOperands(), srcOp.falseDest(),
+        adaptor.falseDestOperands());
+    return success();
+  }
+};
+
+class ConstantOpConversion : public OpConversionPattern<pydm_d::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::ConstantOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = srcOp.getLoc();
+    Type resultType = typeConverter->convertType(srcOp.getResult().getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(
+          srcOp, "constant type could not be converted");
+    Attribute newValue = adaptor.value();
+    // Fixup widths of integer types that may be wider/narrower than the
+    // stored attribute (which tends to be stored in high precision in pydm
+    // constants).
+    TypeSwitch<Type>(resultType)
+        .Case([&](builtin_d::IntegerType t) {
+          APInt intValue =
+              newValue.cast<IntegerAttr>().getValue().sextOrTrunc(t.getWidth());
+          newValue = rewriter.getIntegerAttr(t, intValue);
+        })
+        .Case([&](builtin_d::FloatType t) {
+          APFloat fpValue = newValue.cast<FloatAttr>().getValue();
+          if (APFloat::SemanticsToEnum(fpValue.getSemantics()) !=
+              APFloat::SemanticsToEnum(t.getFloatSemantics())) {
+            // Convert.
+            APFloat newFpValue = fpValue;
+            bool losesInfo;
+            newFpValue.convert(t.getFloatSemantics(),
+                               APFloat::rmNearestTiesToEven, &losesInfo);
+            if (losesInfo) {
+              emitWarning(loc) << "conversion of " << newValue << " to " << t
+                               << " loses information";
+            }
+            newValue = rewriter.getFloatAttr(t, newFpValue);
+          }
+        });
+
+    if (!newValue)
+      return rewriter.notifyMatchFailure(
+          srcOp, "constant cannot be represented as a standard constant");
+    rewriter.replaceOpWithNewOp<std_d::ConstantOp>(srcOp, resultType, newValue);
+    return success();
+  }
+};
+
 class FuncOpConversion : public OpConversionPattern<pydm_d::FuncOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -172,6 +363,7 @@
         convertedResultTypes);
     auto newFuncOp = rewriter.create<mlir::FuncOp>(
         srcOp.getLoc(), srcOp.getName(), newFuncType);
+    newFuncOp.setVisibility(srcOp.getVisibility());
     rewriter.inlineRegionBefore(srcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
 
@@ -187,6 +379,33 @@
   }
 };
 
+class GetTypeCodeConversion
+    : public OpConversionPattern<pydm_d::GetTypeCodeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      pydm_d::GetTypeCodeOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = srcOp.getLoc();
+    // Gets the 0'th element of the object list, optionally casting it to the
+    // converted integer type.
+    Type resultType = typeConverter->convertType(srcOp.getResult().getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "result type could not be converted");
+    Type i32Type = rewriter.getIntegerType(32);
+    Value index0 =
+        rewriter.create<std_d::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value typeCode = rewriter.create<iree_d::ListGetOp>(
+        loc, i32Type, adaptor.value(), index0);
+    rewriter.replaceOp(
+        srcOp,
+        castIntegerValue(loc, typeCode,
+                         resultType.cast<builtin_d::IntegerType>(), rewriter));
+    return success();
+  }
+};
+
 class LoadVarOpConversion : public OpConversionPattern<pydm_d::LoadVarOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -237,8 +456,8 @@
     auto loc = srcOp.getLoc();
 
     Value status = operands[0];
-    // Get the containing function return type so that we can create a suitable
-    // null return value.
+    // Get the containing function return type so that we can create a
+    // suitable null return value.
     auto parentFunc = srcOp->getParentOfType<builtin_d::FuncOp>();
     if (!parentFunc)
       return rewriter.notifyMatchFailure(srcOp, "not contained by a func");
@@ -390,10 +609,13 @@
     MLIRContext *context, TypeConverter &typeConverter,
     RewritePatternSet &patterns) {
   // Structural.
-  patterns.insert<AllocFreeVarOpConversion, BoxOpConversion, FuncOpConversion,
-                  LoadVarOpConversion, RaiseOnFailureOpConversion,
-                  ReturnOpConversion, StoreVarOpConversion, UnboxOpConversion>(
-      typeConverter, context);
+  patterns.insert<AllocFreeVarOpConversion, ApplyCompareNumericConversion,
+                  BoolToPredConversion, BoxOpConversion, BranchConversion,
+                  CallOpConversion, CondBranchConversion, ConstantOpConversion,
+                  FuncOpConversion, GetTypeCodeConversion, LoadVarOpConversion,
+                  RaiseOnFailureOpConversion, ReturnOpConversion,
+                  StoreVarOpConversion, UnboxOpConversion>(typeConverter,
+                                                           context);
 
   // Constants and constructors.
   patterns.insert<NoneOpConversion>(typeConverter, context);
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
index 9416b2a..0ee4fb7 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -37,12 +37,23 @@
     return getVariantListType(b);
   });
 
+  // Bool.
+  addConversion([&](pydm_d::BoolType t) -> Optional<Type> {
+    return builtin_d::IntegerType::get(t.getContext(), 1);
+  });
+
   // Integer type hierarchy.
   addConversion([&](pydm_d::IntegerType t) -> Optional<Type> {
     Builder b(t.getContext());
     return getWeakIntegerType(b);
   });
 
+  // Real type hierarchy.
+  addConversion([&](pydm_d::RealType t) -> Optional<Type> {
+    Builder b(t.getContext());
+    return getWeakFloatType(b);
+  });
+
   // Variable references.
   addConversion([](pydm_d::FreeVarRefType t) -> Optional<Type> {
     // Just an object record.
@@ -73,3 +84,15 @@
       return b.getF64Type();
   }
 }
+
+bool LoweringTypeConverter::isTypeLegal(Type t) const {
+  return t.isa<builtin_d::IntegerType, builtin_d::FloatType,
+               builtin_d::IndexType, iree_d::ListType>();
+}
+
+bool LoweringTypeConverter::areTypesLegal(TypeRange types) const {
+  for (Type t : types) {
+    if (!isTypeLegal(t)) return false;
+  }
+  return true;
+}
diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
index 0b235ee..067c565 100644
--- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
+++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
@@ -15,6 +15,36 @@
 namespace py = pybind11;
 using namespace mlir::python::adaptors;
 
+namespace {
+
+struct PyIREEPyDMSourceBundle {
+  PyIREEPyDMSourceBundle(IREEPyDMSourceBundle wrapped) : wrapped(wrapped) {}
+  PyIREEPyDMSourceBundle(PyIREEPyDMSourceBundle &&other)
+      : wrapped(other.wrapped) {
+    other.wrapped.ptr = nullptr;
+  }
+  PyIREEPyDMSourceBundle(const PyIREEPyDMSourceBundle &) = delete;
+  ~PyIREEPyDMSourceBundle() {
+    if (wrapped.ptr) ireePyDMSourceBundleDestroy(wrapped);
+  }
+  IREEPyDMSourceBundle wrapped;
+};
+
+struct PyIREEPyDMLoweringOptions {
+  PyIREEPyDMLoweringOptions() : wrapped(ireePyDMLoweringOptionsCreate()) {}
+  PyIREEPyDMLoweringOptions(PyIREEPyDMLoweringOptions &&other)
+      : wrapped(other.wrapped) {
+    other.wrapped.ptr = nullptr;
+  }
+  PyIREEPyDMLoweringOptions(const PyIREEPyDMLoweringOptions &) = delete;
+  ~PyIREEPyDMLoweringOptions() {
+    if (wrapped.ptr) ireePyDMLoweringOptionsDestroy(wrapped);
+  }
+  IREEPyDMLoweringOptions wrapped;
+};
+
+}  // namespace
+
 PYBIND11_MODULE(_ireeDialects, m) {
   m.doc() = "iree-dialects main python extension";
 
@@ -63,6 +93,37 @@
   //===--------------------------------------------------------------------===//
   auto iree_pydm_m = m.def_submodule("iree_pydm");
 
+  py::class_<PyIREEPyDMSourceBundle>(
+      iree_pydm_m, "SourceBundle", py::module_local(),
+      "Contains raw assembly source or a reference to a file")
+      .def_static(
+          "from_asm",
+          [](std::string asmBlob) {
+            return PyIREEPyDMSourceBundle(ireePyDMSourceBundleCreateAsm(
+                {asmBlob.data(), asmBlob.size()}));
+          },
+          py::arg("asm_blob"),
+          "Creates a SourceBundle from an ASM blob (string or bytes)")
+      .def_static(
+          "from_file",
+          [](std::string asmFile) {
+            return PyIREEPyDMSourceBundle(ireePyDMSourceBundleCreateFile(
+                {asmFile.data(), asmFile.size()}));
+          },
+          py::arg("asm_file"),
+          "Creates a SourceBundle from a file containing ASM");
+  py::class_<PyIREEPyDMLoweringOptions>(iree_pydm_m, "LoweringOptions",
+                                        py::module_local(),
+                                        "Lowering options to compile to IREE")
+      .def(py::init<>())
+      .def(
+          "link_rtl",
+          [](PyIREEPyDMLoweringOptions &self,
+             PyIREEPyDMSourceBundle &sourceBundle) {
+            ireePyDMLoweringOptionsLinkRtl(self.wrapped, sourceBundle.wrapped);
+          },
+          "Enables linking against a runtime-library module");
+
   iree_pydm_m.def(
       "register_dialect",
       [](MlirContext context, bool load) {
@@ -76,12 +137,13 @@
 
   iree_pydm_m.def(
       "build_lower_to_iree_pass_pipeline",
-      [](MlirPassManager passManager) {
+      [](MlirPassManager passManager, PyIREEPyDMLoweringOptions &options) {
         MlirOpPassManager opPassManager =
             mlirPassManagerGetAsOpPassManager(passManager);
-        mlirIREEPyDMBuildLowerToIREEPassPipeline(opPassManager);
+        mlirIREEPyDMBuildLowerToIREEPassPipeline(opPassManager,
+                                                 options.wrapped);
       },
-      py::arg("pass_manager"));
+      py::arg("pass_manager"), py::arg("link_rtl_asm") = py::none());
 
 #define DEFINE_IREEPYDM_NULLARY_TYPE(Name)                                 \
   mlir_type_subclass(iree_pydm_m, #Name "Type", mlirTypeIsAIREEPyDM##Name, \
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
index 8eb41b1..d8bfd1c 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/importer/util.py
@@ -361,8 +361,10 @@
              f"implemented for this compiler")
 
 
-def create_context() -> ir.Context:
+def create_context(*, debug: bool = False) -> ir.Context:
   context = ir.Context()
+  if debug:
+    context.enable_multithreading(False)
   d.register_dialect(context)
   return context
 
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py
index 7bea5c5..1255fc0 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/__init__.py
@@ -3,3 +3,22 @@
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import functools as _functools
+
+from typing import Sequence
+from .base import RtlBuilder, RtlModule
+
+
+def _get_std_rtl_modules() -> Sequence[RtlModule]:
+  from .modules import (
+      booleans,
+      numerics,
+  )
+  return [m.RTL_MODULE for m in (booleans, numerics)]
+
+
+STD_RTL_MODULES = _get_std_rtl_modules()
+
+# Source bundle for the standard RTL.
+get_std_rtl_source_bundle = RtlBuilder.lazy_build_source_bundle(STD_RTL_MODULES)
diff --git a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py
index 240317a..ae23206 100644
--- a/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py
+++ b/llvm-external-projects/iree-dialects/python/iree/compiler/dialects/iree_pydm/rtl/base.py
@@ -5,9 +5,10 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 """Base helpers for the RTL DSL."""
 
-from typing import List, Optional
+from typing import Callable, List, Optional, Sequence
 
 import functools
+import threading
 
 from ..importer import (
     create_context,
@@ -18,8 +19,15 @@
     ImportStage,
 )
 
-from ... import builtin as builtin_d
-from .... import (ir, passmanager, transforms as unused_transforms)
+from ... import (
+    builtin as builtin_d,
+    iree_pydm as pydm_d,
+)
+from .... import (
+    ir,
+    passmanager,
+    transforms as unused_transforms,
+)
 
 
 class RtlModule:
@@ -88,6 +96,36 @@
         ir.Location.unknown(context=self.context))
     self.module_op = self.root_module.operation
 
+  @staticmethod
+  def build_modules(rtl_modules: Sequence[RtlModule]) -> bytes:
+    """One shot build modules and return assembly."""
+    b = RtlBuilder()
+    b.emit_modules(rtl_modules)
+    b.optimize()
+    return b.root_module.operation.get_asm(binary=True, enable_debug_info=True)
+
+  @staticmethod
+  def lazy_build_source_bundle(
+      rtl_modules: Sequence[RtlModule]) -> Callable[[], pydm_d.SourceBundle]:
+    """Returns a function to lazily build RTL modules.
+
+    Modules will only be built once and cached for the life of the function.
+    Since RTL asm is typically passed unsafely to compiler passes, caching
+    forever is important.
+    """
+    rtl_modules = tuple(rtl_modules)
+    cache = []
+    lock = threading.Lock()
+
+    def get() -> pydm_d.SourceBundle:
+      with lock:
+        if not cache:
+          asm_blob = RtlBuilder.build_modules(rtl_modules)
+          cache.append(pydm_d.SourceBundle.from_asm(asm_blob))
+        return cache[0]
+
+    return get
+
   def emit_module(self, rtl_module: RtlModule):
     root_body = self.module_op.regions[0].blocks[0]
     with ir.InsertionPoint(root_body), ir.Location.unknown():
@@ -101,6 +139,10 @@
       # Getting the symbol implies exporting it into the module.
       f.get_or_create_provided_func_symbol(stage)
 
+  def emit_modules(self, rtl_modules: Sequence[RtlModule]):
+    for rtl_module in rtl_modules:
+      self.emit_module(rtl_module)
+
   def optimize(self):
     """Optimizes the RTL modules by running through stage 1 compilation."""
     with self.context:
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/constants.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/constants.mlir
new file mode 100644
index 0000000..d8fcb91
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/constants.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-dialects-opt -split-input-file -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+// CHECK-LABEL: @none_constant
+iree_pydm.func @none_constant() -> (!iree_pydm.exception_result, !iree_pydm.none) {
+  // CHECK: %[[CST0:.*]] = constant 0 : i32
+  // CHECK: %[[CST1:.*]] = constant 0 : i32
+  // CHECK: return %[[CST1]], %[[CST0]]
+  %0 = none
+  return %0 : !iree_pydm.none
+}
+
+// CHECK-LABEL: @constant_integer_trunc
+iree_pydm.func @constant_integer_trunc() -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+  // CHECK: constant -10 : i32
+  %0 = constant -10 : i64 -> !iree_pydm.integer
+  return %0 : !iree_pydm.integer
+}
+
+// CHECK-LABEL: @constant_real_trunc
+iree_pydm.func @constant_real_trunc() -> (!iree_pydm.exception_result, !iree_pydm.real) {
+  // CHECK: constant -2.000000e+00 : f32
+  %0 = constant -2.0 : f64 -> !iree_pydm.real
+  return %0 : !iree_pydm.real
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/integer_compare.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/integer_compare.mlir
new file mode 100644
index 0000000..dfd387c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/integer_compare.mlir
@@ -0,0 +1,65 @@
+// RUN: iree-dialects-opt -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+// CHECK-LABEL: @lt
+iree_pydm.func @lt(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi slt, %arg0, %arg1 : i32
+  %0 = apply_compare "lt", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @le
+iree_pydm.func @le(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi sle, %arg0, %arg1 : i32
+  %0 = apply_compare "le", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @eq
+iree_pydm.func @eq(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi eq, %arg0, %arg1 : i32
+  %0 = apply_compare "eq", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @is
+iree_pydm.func @is(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi eq, %arg0, %arg1 : i32
+  %0 = apply_compare "is", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ne
+iree_pydm.func @ne(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi ne, %arg0, %arg1 : i32
+  %0 = apply_compare "ne", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @isnot
+iree_pydm.func @isnot(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi ne, %arg0, %arg1 : i32
+  %0 = apply_compare "isnot", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @gt
+iree_pydm.func @gt(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi sgt, %arg0, %arg1 : i32
+  %0 = apply_compare "gt", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ge
+iree_pydm.func @ge(%arg0 : !iree_pydm.integer, %arg1 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpi sge, %arg0, %arg1 : i32
+  %0 = apply_compare "ge", %arg0, %arg1 : !iree_pydm.integer, !iree_pydm.integer
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/real_compare.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/real_compare.mlir
new file mode 100644
index 0000000..83601c1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/real_compare.mlir
@@ -0,0 +1,65 @@
+// RUN: iree-dialects-opt -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+// CHECK-LABEL: @lt
+iree_pydm.func @lt(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf olt, %arg0, %arg1 : f32
+  %0 = apply_compare "lt", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @le
+iree_pydm.func @le(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf ole, %arg0, %arg1 : f32
+  %0 = apply_compare "le", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @eq
+iree_pydm.func @eq(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf oeq, %arg0, %arg1 : f32
+  %0 = apply_compare "eq", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @is
+iree_pydm.func @is(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf oeq, %arg0, %arg1 : f32
+  %0 = apply_compare "is", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ne
+iree_pydm.func @ne(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf one, %arg0, %arg1 : f32
+  %0 = apply_compare "ne", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @isnot
+iree_pydm.func @isnot(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf one, %arg0, %arg1 : f32
+  %0 = apply_compare "isnot", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @gt
+iree_pydm.func @gt(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf ogt, %arg0, %arg1 : f32
+  %0 = apply_compare "gt", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
+
+// CHECK-LABEL: @ge
+iree_pydm.func @ge(%arg0 : !iree_pydm.real, %arg1 : !iree_pydm.real) -> (!iree_pydm.exception_result, !iree_pydm.bool) {
+  // CHECK: %[[R:.*]] = cmpf oge, %arg0, %arg1 : f32
+  %0 = apply_compare "ge", %arg0, %arg1 : !iree_pydm.real, !iree_pydm.real
+  // CHECK: return {{.*}}, %[[R]]
+  return %0 : !iree_pydm.bool
+}
diff --git a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
index 35566e8..46cafc7 100644
--- a/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
+++ b/llvm-external-projects/iree-dialects/test/iree_pydm/to_iree/structural.mlir
@@ -1,12 +1,27 @@
 // RUN: iree-dialects-opt -split-input-file -convert-iree-pydm-to-iree %s | FileCheck --enable-var-scope --dump-input-filter=all %s
 
-// CHECK-LABEL: @none_constant
-iree_pydm.func @none_constant() -> (!iree_pydm.exception_result, !iree_pydm.none) {
-  // CHECK: %[[CST0:.*]] = constant 0 : i32
-  // CHECK: %[[CST1:.*]] = constant 0 : i32
-  // CHECK: return %[[CST1]], %[[CST0]]
+// CHECK-LABEL: @bool_to_pred
+// NOTE: Also tests cond_br conversion.
+iree_pydm.func @bool_to_pred(%arg0 : !iree_pydm.bool) -> (!iree_pydm.exception_result, !iree_pydm.none) {
+  %0 = bool_to_pred %arg0
+  %1 = none
+  // CHECK: cond_br %arg0
+  cond_br %0, ^bb1, ^bb2
+^bb1:
+  return %1 : !iree_pydm.none
+^bb2:
+  return %1 : !iree_pydm.none
+}
+
+// -----
+// CHECK-LABEL: @br
+iree_pydm.func @br() -> (!iree_pydm.exception_result, !iree_pydm.none) {
   %0 = none
-  return %0 : !iree_pydm.none
+  // CHECK: br ^bb1({{.*}} : i32)
+  br ^bb1(%0 : !iree_pydm.none)
+  // CHECK: ^bb1(%0: i32):
+^bb1(%1 : !iree_pydm.none):
+  return %1 : !iree_pydm.none
 }
 
 // -----
@@ -109,3 +124,25 @@
   raise_on_failure %arg0 : !iree_pydm.exception_result
   return %arg1 : !iree_pydm.integer
 }
+
+// -----
+// CHECK-LABEL: @call_and_visibility
+iree_pydm.func @call_and_visibility(%arg0 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+  // CHECK: %[[R:.*]]:2 = call @callee(%arg0) : (i32) -> (i32, i32)
+  %0:2 = call @callee(%arg0) : (!iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.integer)
+  return %0#1 : !iree_pydm.integer
+}
+
+// CHECK: func private @callee
+iree_pydm.func private @callee(%arg0 : !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+  return %arg0 : !iree_pydm.integer
+}
+
+// -----
+// CHECK-LABEL: @get_type_code
+iree_pydm.func @get_type_code(%arg0 : !iree_pydm.object) -> (!iree_pydm.exception_result, !iree_pydm.integer) {
+  // CHECK: %[[c0:.*]] = constant 0 : index
+  // CHECK: %[[R:.*]] = iree.list.get %arg0[%[[c0]]] : !iree.list<!iree.variant> -> i32
+  %0 = get_type_code %arg0 : !iree_pydm.object
+  return %0 : !iree_pydm.integer
+}
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/rtl.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/rtl.py
new file mode 100644
index 0000000..9337da9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/rtl.py
@@ -0,0 +1,6 @@
+# RUN: %PYTHON %s
+
+from iree.compiler.dialects.iree_pydm import rtl
+
+# Ensures that we can compile the standard library to a SourceModule.
+print(rtl.get_std_rtl_source_bundle())