Moves the linalg_ext dialect to iree_linalg_ext under the iree-dialects project. (#7657)

Mostly NFC:

* Pre-factors iree-dialects into a consistent state wrt namespaces and directory layouts.
* Moves linalg_ext to a new dialect under iree-dialects. Does some adaptation to upstream style along the way.
* Redirects everything that was using it to use the new one.
* Enables tests for iree-dialects in the cmake CI (they were not enabled) and fixes some type prefixing that had drifted. (Will follow up with enabling them in the internal builds to guard against further regression).

Non-NFC:

* When tiling, the old pass directly generated flow.dispatch.workgroup... ops to query the current workgroup ID. We had been planning to add those ops to the input dialect, so I pulled part of that patch forward.
* Since these ops now lack lowerings, I expected some test failures that would show me where to adapt, but this appears to be dead code (outside of integration tests?). We'll see what the CI says.
* I can keep working on this patch to adapt whatever is needed.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/CMakeLists.txt
new file mode 100644
index 0000000..31167e6
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/CMakeLists.txt
@@ -0,0 +1,3 @@
+# Pull in the IR, Transforms, and Utils subdirectories of the PyDM dialect.
+add_subdirectory(IR)
+add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt
new file mode 100644
index 0000000..7ca64d8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Core PyDM dialect library: the dialect/type implementation and the op
+# implementations.
+add_mlir_library(IREEPyDMDialect
+  PyDMDialect.cpp
+  PyDMOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${IREE_DIALECTS_SOURCE_DIR}/include
+
+  # The sources #include generated .inc files; presumably this target produces
+  # them, so it must be built first.
+  DEPENDS
+  IREEPyDMIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+)
+
+iree_dialects_target_includes(IREEPyDMDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
new file mode 100644
index 0000000..5db10f7
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMDialect.cpp
@@ -0,0 +1,337 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMInterfaces.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOpInterfaces.cpp.inc"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypeInterfaces.cpp.inc"
+#define GET_TYPEDEF_CLASSES
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
+
+//------------------------------------------------------------------------------
+// Dialect implementation
+//------------------------------------------------------------------------------
+
+using BuiltinIntegerType = mlir::IntegerType;
+
+using PyBoolType = PYDM::BoolType;
+using PyConstantOp = PYDM::ConstantOp;
+using PyIntegerType = PYDM::IntegerType;
+using PyRealType = PYDM::RealType;
+
+/// Registers all generated PyDM types and operations with the dialect.
+void IREEPyDMDialect::initialize() {
+  // With GET_TYPEDEF_LIST defined, the included .inc file expands to the list
+  // of generated type classes used as template arguments.
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "iree-dialects/Dialect/PyDM/IR/PyDMTypes.cpp.inc"
+      >();
+  // Likewise, GET_OP_LIST expands to the list of generated op classes.
+  addOperations<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
+      >();
+}
+
+/// Materializes a constant op of `type` holding `value`, as requested by the
+/// folder.
+///
+/// Handles:
+///   - python bool/bytes/int/float/str types and builtin integer types via
+///     `ConstantOp`;
+///   - the python `None` type via `NoneOp`;
+///   - `ExceptionResultType` with a unit attribute via `SuccessOp`.
+/// Returns nullptr for anything else so the folder can back out gracefully.
+Operation *IREEPyDMDialect::materializeConstant(OpBuilder &builder,
+                                                Attribute value, Type type,
+                                                Location loc) {
+  // Since we support materialization of builtin types too, explicitly
+  // allow these.
+  if (type.isa<PyBoolType, BytesType, PyIntegerType, PyRealType, StrType,
+               BuiltinIntegerType>()) {
+    return builder.create<PYDM::ConstantOp>(loc, type, value);
+  }
+
+  if (type.isa<NoneType>()) {
+    return builder.create<PYDM::NoneOp>(loc, type);
+  }
+
+  if (type.isa<ExceptionResultType>() && value.isa<UnitAttr>()) {
+    return builder.create<PYDM::SuccessOp>(loc, type);
+  }
+
+  // Unsupported (type, value) combination. Returning nullptr signals
+  // "cannot materialize" to the folding infrastructure. The previous
+  // llvm_unreachable here was undefined behavior in release builds.
+  return nullptr;
+}
+
+/// Parses a PyDM type by reading its mnemonic keyword and dispatching to the
+/// generated type parser. Emits "unknown dialect type" if no keyword could be
+/// read or the generated parser does not recognize the mnemonic.
+Type IREEPyDMDialect::parseType(DialectAsmParser &parser) const {
+  StringRef mnemonic;
+  Type parsedType;
+  bool recognized = false;
+  if (succeeded(parser.parseKeyword(&mnemonic))) {
+    auto generated = generatedTypeParser(parser, mnemonic, parsedType);
+    recognized = generated.hasValue();
+    // Recognized mnemonic whose body failed to parse: bail out without an
+    // extra diagnostic (presumably the generated parser reported one).
+    if (recognized && failed(*generated)) return Type();
+  }
+
+  if (!recognized) {
+    parser.emitError(parser.getNameLoc(), "unknown dialect type");
+    return Type();
+  }
+  return parsedType;
+}
+
+/// Prints a PyDM type through the generated type printer; its result is
+/// intentionally ignored.
+void IREEPyDMDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  (void)generatedTypePrinter(type, printer);
+}
+
+//------------------------------------------------------------------------------
+// Python type implementation
+//------------------------------------------------------------------------------
+
+// BoolType: numeric category `Bool` with the single sub-type code 0.
+Optional<NumericCategory> PYDM::BoolType::getNumericCategory() const {
+  return NumericCategory::Bool;
+}
+
+Optional<int> PYDM::BoolType::getNumericSubTypeCode() const {
+  return 0;
+}
+
+BuiltinTypeCode PYDM::BoolType::getTypeCode() const {
+  NumericCategory category = *getNumericCategory();
+  int subTypeCode = *getNumericSubTypeCode();
+  return static_cast<BuiltinTypeCode>(
+      makeNumericTypeCode(category, subTypeCode));
+}
+
+// The promotion order is simply the numeric type code.
+Optional<int> PYDM::BoolType::getNumericPromotionOrder() const {
+  int typeCode = static_cast<int>(getTypeCode());
+  return typeCode;
+}
+
+StringRef PYDM::BoolType::getPythonTypeName() const {
+  return "bool";
+}
+
+// BytesType: python `bytes`.
+StringRef PYDM::BytesType::getPythonTypeName() const {
+  return "bytes";
+}
+
+BuiltinTypeCode PYDM::BytesType::getTypeCode() const {
+  return BuiltinTypeCode::Bytes;
+}
+
+// ExceptionResultType: reported under the python name "Exception".
+StringRef PYDM::ExceptionResultType::getPythonTypeName() const {
+  return "Exception";
+}
+
+BuiltinTypeCode PYDM::ExceptionResultType::getTypeCode() const {
+  return BuiltinTypeCode::ExceptionResult;
+}
+
+// IntegerType
+//
+// The optional bit width parameter encodes the integer flavor:
+//   - absent: a "weak" integer;
+//   - 0: arbitrary precision (NumericCategory::APSigned);
+//   - magnitude 8/16/32/64: fixed width. A negative value marks the integer
+//     as unsigned, since isSigned() tests for a non-negative width.
+
+/// Accepts a missing width or a width whose magnitude is 0, 8, 16, 32 or 64.
+LogicalResult PYDM::IntegerType::verify(
+    function_ref<InFlightDiagnostic()> emitError, Optional<int> bitWidth) {
+  if (!bitWidth) return success();
+  int w = abs(*bitWidth);
+  if (w == 0 || w == 8 || w == 16 || w == 32 || w == 64) return success();
+  return emitError() << "unsupported python integer bit width: " << w;
+}
+
+BuiltinTypeCode PYDM::IntegerType::getTypeCode() const {
+  return static_cast<BuiltinTypeCode>(
+      makeNumericTypeCode(*getNumericCategory(), *getNumericSubTypeCode()));
+}
+
+StringRef PYDM::IntegerType::getPythonTypeName() const { return "int"; }
+
+Optional<NumericCategory> PYDM::IntegerType::getNumericCategory() const {
+  if (isWeak()) return NumericCategory::WeakInteger;
+  if (getBitWidth() == 0) return NumericCategory::APSigned;
+  if (isSigned()) return NumericCategory::Signed;
+  return NumericCategory::Unsigned;
+}
+
+/// Returns the width-specific sub-type code within the numeric category.
+Optional<int> PYDM::IntegerType::getNumericSubTypeCode() const {
+  // Weak and arbitrary-precision (width 0) integers have no width-specific
+  // sub-type. verify() accepts width 0, so it must be handled here rather
+  // than falling into the switch default below (which is unreachable/UB).
+  if (isWeak() || getBitWidth() == 0) return 0;
+  IntegerSubTypeCode stc;
+  switch (getBitWidth()) {
+    case 8:
+      stc = IntegerSubTypeCode::Integer8;
+      break;
+    case 16:
+      stc = IntegerSubTypeCode::Integer16;
+      break;
+    case 32:
+      stc = IntegerSubTypeCode::Integer32;
+      break;
+    case 64:
+      stc = IntegerSubTypeCode::Integer64;
+      break;
+    default: {
+      llvm_unreachable("unsupported numeric bitwidth");
+    }
+  }
+  return static_cast<int>(stc);
+}
+
+Optional<int> PYDM::IntegerType::getNumericPromotionOrder() const {
+  return static_cast<int>(getTypeCode());
+}
+
+bool PYDM::IntegerType::isWeak() const { return !getImpl()->bitWidth; }
+
+/// Magnitude of the stored width. Only valid when !isWeak().
+unsigned PYDM::IntegerType::getBitWidth() const {
+  return abs(*getImpl()->bitWidth);
+}
+
+/// Signedness is encoded in the sign of the stored width. Only valid when
+/// !isWeak().
+bool PYDM::IntegerType::isSigned() const { return *getImpl()->bitWidth >= 0; }
+
+// ListType
+BuiltinTypeCode PYDM::ListType::getTypeCode() const {
+  return BuiltinTypeCode::List;
+}
+
+StringRef PYDM::ListType::getPythonTypeName() const { return "list"; }
+
+/// A list with Empty storage is never refinable. Any other list is refinable
+/// when it has no uniform element type, or when that element type is itself
+/// a refinable python type.
+bool PYDM::ListType::isRefinable() const {
+  if (getStorageClass() == CollectionStorageClass::Empty) return false;
+
+  if (!getUniformElementType()) return true;
+
+  if (auto pyType = getUniformElementType().dyn_cast<PythonTypeInterface>())
+    return pyType.isRefinable();
+
+  return false;
+}
+
+/// Returns the type used to store elements. Boxed and empty lists store
+/// generic objects; unboxed lists store their uniform element type directly.
+Type PYDM::ListType::getElementStorageType() const {
+  switch (getStorageClass()) {
+    case CollectionStorageClass::Boxed:
+    case CollectionStorageClass::Empty:
+      return ObjectType::get(getContext());
+    case CollectionStorageClass::Unboxed:
+      assert(getUniformElementType() &&
+             "unboxed list should have uniform element type");
+      return getUniformElementType();
+    default:
+      llvm_unreachable("unsupported storage class");
+      return {};
+  }
+}
+
+// NoneType
+// The type code must identify None. It was previously a copy-paste of
+// ListType and returned BuiltinTypeCode::List, which made None
+// indistinguishable from list, e.g. in union ordering checks.
+BuiltinTypeCode PYDM::NoneType::getTypeCode() const {
+  return BuiltinTypeCode::None;
+}
+
+StringRef PYDM::NoneType::getPythonTypeName() const { return "None"; }
+
+// ObjectType: python `object`, optionally parameterized by a primitive type.
+StringRef PYDM::ObjectType::getPythonTypeName() const { return "object"; }
+
+BuiltinTypeCode PYDM::ObjectType::getTypeCode() const {
+  return BuiltinTypeCode::Object;
+}
+
+/// An object is refinable when its primitive type is unknown, or when the
+/// known primitive type is itself a refinable python type.
+bool PYDM::ObjectType::isRefinable() const {
+  Type primitive = getPrimitiveType();
+  if (!primitive) return true;
+  auto pyType = primitive.dyn_cast<PythonTypeInterface>();
+  return pyType ? pyType.isRefinable() : false;
+}
+
+// RealType: python `float`, parameterized by an optional builtin float type.
+// A null float type denotes a "weak" real.
+
+/// Accepts a null float type or one of bf16/f16/f32/f64.
+LogicalResult PYDM::RealType::verify(
+    function_ref<InFlightDiagnostic()> emitError, FloatType floatType) {
+  bool isSupported =
+      !floatType ||
+      floatType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
+  if (isSupported) return success();
+  return emitError() << "unsupported Python floating point type: "
+                     << floatType;
+}
+
+StringRef PYDM::RealType::getPythonTypeName() const { return "float"; }
+
+bool PYDM::RealType::isWeak() const { return !getImpl()->floatType; }
+
+Optional<NumericCategory> PYDM::RealType::getNumericCategory() const {
+  return isWeak() ? NumericCategory::WeakReal : NumericCategory::Real;
+}
+
+/// Maps the concrete float type to its sub-type code; weak reals use 0.
+Optional<int> PYDM::RealType::getNumericSubTypeCode() const {
+  if (isWeak()) return 0;
+  Type floatType = getFloatType();
+  RealSubTypeCode subTypeCode;
+  if (floatType.isa<BFloat16Type>()) {
+    subTypeCode = RealSubTypeCode::BF16;
+  } else if (floatType.isa<Float16Type>()) {
+    subTypeCode = RealSubTypeCode::FP16;
+  } else if (floatType.isa<Float32Type>()) {
+    subTypeCode = RealSubTypeCode::FP32;
+  } else if (floatType.isa<Float64Type>()) {
+    subTypeCode = RealSubTypeCode::FP64;
+  } else {
+    llvm_unreachable("unsupported float type");
+  }
+  return static_cast<int>(subTypeCode);
+}
+
+BuiltinTypeCode PYDM::RealType::getTypeCode() const {
+  NumericCategory category = *getNumericCategory();
+  int subTypeCode = *getNumericSubTypeCode();
+  return static_cast<BuiltinTypeCode>(
+      makeNumericTypeCode(category, subTypeCode));
+}
+
+// The promotion order is simply the numeric type code.
+Optional<int> PYDM::RealType::getNumericPromotionOrder() const {
+  return static_cast<int>(getTypeCode());
+}
+
+// StrType: python `str`.
+StringRef PYDM::StrType::getPythonTypeName() const { return "str"; }
+
+BuiltinTypeCode PYDM::StrType::getTypeCode() const {
+  return BuiltinTypeCode::Str;
+}
+
+// TupleType: python `tuple`.
+StringRef PYDM::TupleType::getPythonTypeName() const { return "tuple"; }
+
+BuiltinTypeCode PYDM::TupleType::getTypeCode() const {
+  return BuiltinTypeCode::Tuple;
+}
+
+/// Tuple elements are currently always stored as generic objects.
+Type PYDM::TupleType::getElementStorageType() const {
+  // TODO: When it implements unboxed storage, switch here.
+  return ObjectType::get(getContext());
+}
+
+// TypeType: python `type`.
+StringRef PYDM::TypeType::getPythonTypeName() const { return "type"; }
+
+BuiltinTypeCode PYDM::TypeType::getTypeCode() const {
+  return BuiltinTypeCode::Type;
+}
+
+//------------------------------------------------------------------------------
+// Union type implementation
+//------------------------------------------------------------------------------
+
+/// Verifies a union's alternatives. Every alternative must be a python type,
+/// and the alternatives must appear in strictly increasing type-code order so
+/// that each union has a single normative spelling.
+///
+/// Fixes from the previous version: the function unconditionally returned
+/// failure(), so no union could ever verify. It also never advanced
+/// `lastTypeCode`, so the ordering check could never trigger.
+LogicalResult PYDM::UnionType::verify(
+    llvm::function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<Type> alternatives) {
+  int lastTypeCode = 0;
+  for (Type alternative : alternatives) {
+    auto pythonType = alternative.dyn_cast<PYDM::PythonTypeInterface>();
+    if (!pythonType) {
+      return emitError() << "expected a python type in union. got: "
+                         << alternative;
+    }
+    int thisTypeCode = static_cast<int>(pythonType.getTypeCode());
+    // TODO: This doesn't account for parameterized types.
+    if (thisTypeCode <= lastTypeCode) {
+      return emitError() << "expected total order of union to be normative. "
+                            "got out of order: "
+                         << alternative;
+    }
+    lastTypeCode = thisTypeCode;
+  }
+
+  return success();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
new file mode 100644
index 0000000..7ef7bfc
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/IR/PyDMOps.cpp
@@ -0,0 +1,810 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+using llvm::dbgs;
+
+using PyBoolType = PYDM::BoolType;
+using PyConstantOp = PYDM::ConstantOp;
+using PyIntegerType = PYDM::IntegerType;
+using PyRealType = PYDM::RealType;
+using PyCallOp = PYDM::CallOp;
+using PyFuncOp = PYDM::FuncOp;
+
+/// Catch-all verifier overload: accepts any operation unconditionally.
+static LogicalResult verify(Operation *) {
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Generic pattern to unbox any operands that are a specific object
+/// type (i.e. object<integer>).
+///
+/// If `operandIndices` is set, only operands at those positions are
+/// considered. Otherwise every operand is. Each qualifying operand is replaced
+/// in place by the `primitive` result of a newly created `UnboxOp`.
+struct UnboxOperands : public RewritePattern {
+  UnboxOperands(StringRef rootName, MLIRContext *context,
+                Optional<llvm::SmallSet<int, 4>> operandIndices = None)
+      : RewritePattern(rootName, 1, context), operandIndices(operandIndices) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> newOperands(op->getOperands());
+    Location loc = op->getLoc();
+    Type excResultType = rewriter.getType<ExceptionResultType>();
+    bool didUnbox = false;
+    for (int index = 0, count = newOperands.size(); index < count; ++index) {
+      if (!shouldConsider(index)) continue;
+      auto objectType = newOperands[index].getType().dyn_cast<ObjectType>();
+      if (!objectType) continue;
+      Type primitiveType = objectType.getPrimitiveType();
+      if (!primitiveType) continue;
+      // Unbox.
+      auto unboxOp = rewriter.create<UnboxOp>(
+          loc, TypeRange{excResultType, primitiveType}, newOperands[index]);
+      newOperands[index] = unboxOp.primitive();
+      didUnbox = true;
+    }
+
+    if (!didUnbox) return failure();
+    rewriter.updateRootInPlace(op, [&]() { op->setOperands(newOperands); });
+    return success();
+  }
+
+  /// Returns true if the operand at `index` is eligible for unboxing.
+  bool shouldConsider(int index) const {
+    return !operandIndices || operandIndices->contains(index);
+  }
+
+  Optional<llvm::SmallSet<int, 4>> operandIndices;
+};
+
+}  // namespace
+
+/// Builds a `ConstantOp` holding the zero value of `numericType`, which must
+/// be a python bool, integer or real type. Any other type trips the
+/// TypeSwitch's fall-off-the-end assertion.
+static Value getNumericZeroConstant(Location loc, Type numericType,
+                                    OpBuilder &builder) {
+  Attribute zeroAttr =
+      TypeSwitch<Type, Attribute>(numericType)
+          .Case([&](PyBoolType) -> Attribute {
+            return builder.getBoolAttr(false);
+          })
+          .Case([&](PyIntegerType) -> Attribute {
+            return builder.getI64IntegerAttr(0);
+          })
+          .Case([&](PyRealType) -> Attribute {
+            return builder.getF64FloatAttr(0.0);
+          });
+  return builder.create<PyConstantOp>(loc, numericType, zeroAttr);
+}
+
+/// Builds a python `bool` constant with value `pred`.
+static Value getBoolConstant(Location loc, bool pred, OpBuilder &builder) {
+  Type boolType = builder.getType<PyBoolType>();
+  BoolAttr predAttr = builder.getBoolAttr(pred);
+  return builder.create<PyConstantOp>(loc, boolType, predAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// Constants
+//===----------------------------------------------------------------------===//
+
+/// A constant folds to its own `value` attribute.
+OpFoldResult PyConstantOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  Attribute folded = getValue();
+  return folded;
+}
+
+/// `none` folds to a unit attribute.
+OpFoldResult NoneOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return Builder(getContext()).getUnitAttr();
+}
+
+/// `success` folds to a unit attribute.
+OpFoldResult SuccessOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return Builder(getContext()).getUnitAttr();
+}
+
+//===----------------------------------------------------------------------===//
+// Variables
+//===----------------------------------------------------------------------===//
+
+/// Uses the op's `name` attribute as the suggested SSA name for its result.
+void AllocFreeVarOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+  Value result = getResult();
+  setNameFn(result, name());
+}
+
+//===----------------------------------------------------------------------===//
+// ApplyBinaryOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrites `mul` between a builtin sequence (list or tuple) and an integer,
+/// in either operand order, into a `SequenceCloneOp`.
+struct ApplyBinaryToSequenceClone : public OpRewritePattern<ApplyBinaryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  static bool isBuiltinSequence(Value operand) {
+    return operand.getType().isa<PYDM::ListType, PYDM::TupleType>();
+  }
+  static bool isInteger(Value operand) {
+    return operand.getType().isa<PYDM::IntegerType>();
+  }
+
+  LogicalResult matchAndRewrite(ApplyBinaryOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.dunder_name() != "mul") return failure();
+    Value lhs = op.left();
+    Value rhs = op.right();
+    bool sequenceOnLeft = isBuiltinSequence(lhs) && isInteger(rhs);
+    bool sequenceOnRight =
+        !sequenceOnLeft && isInteger(lhs) && isBuiltinSequence(rhs);
+    if (!sequenceOnLeft && !sequenceOnRight) return failure();
+
+    Value sequence = sequenceOnLeft ? lhs : rhs;
+    Value count = sequenceOnLeft ? rhs : lhs;
+    rewriter.replaceOpWithNewOp<SequenceCloneOp>(op, op.getResult().getType(),
+                                                 sequence, count);
+    return success();
+  }
+};
+}  // namespace
+
+/// Canonicalization: unbox statically-typed object operands, then turn
+/// sequence-by-integer multiplication into a sequence clone.
+void ApplyBinaryOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  StringRef rootName = getOperationName();
+  patterns.add<UnboxOperands>(rootName, context);
+  patterns.add<ApplyBinaryToSequenceClone>(context);
+}
+
+/// Tries to narrow the result type from the operand types. Returns true if
+/// the result type was changed.
+bool ApplyBinaryOp::refineResultTypes() {
+  Type leftType = left().getType();
+  Type rightType = right().getType();
+  Value result = getResult();
+  auto refineTo = [&](Type candidate) {
+    if (candidate == result.getType()) return false;
+    result.setType(candidate);
+    return true;
+  };
+
+  // Both numeric types. It is only dynamically legal for statically known
+  // numeric types to be the same, in which case the result type must be the
+  // same as well.
+  auto leftPt = leftType.dyn_cast<PythonTypeInterface>();
+  auto rightPt = rightType.dyn_cast<PythonTypeInterface>();
+  bool bothNumeric = leftPt && rightPt && leftPt.getNumericPromotionOrder() &&
+                     rightPt.getNumericPromotionOrder();
+  if (bothNumeric && leftType == rightType) return refineTo(leftType);
+
+  // (list, integer) or (integer, list) refine to the list type.
+  if (dunder_name() != "mul") return false;
+  auto leftList = leftType.dyn_cast<ListType>();
+  auto rightList = rightType.dyn_cast<ListType>();
+  if (leftList && rightType.isa<IntegerType>()) return refineTo(leftList);
+  if (rightList && leftType.isa<IntegerType>()) return refineTo(rightList);
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// ApplyCompareOp
+//===----------------------------------------------------------------------===//
+
+/// Canonicalization: unbox any statically-typed object operands.
+void ApplyCompareOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  StringRef rootName = getOperationName();
+  patterns.add<UnboxOperands>(rootName, context);
+}
+
+//===----------------------------------------------------------------------===//
+// AsBoolOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// `as_bool` of a value that is already a python bool is the value itself.
+struct FoldAsBoolFromBool : public OpRewritePattern<AsBoolOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AsBoolOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.value();
+    if (!input.getType().isa<BoolType>()) return failure();
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+/// `as_bool` of a numeric value becomes `select(value == 0, false, true)`.
+struct FoldAsBoolFromNumeric : public OpRewritePattern<AsBoolOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AsBoolOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.value();
+    auto pyType = input.getType().dyn_cast<PythonTypeInterface>();
+    if (!pyType || !pyType.getNumericPromotionOrder()) return failure();
+
+    Location loc = op.getLoc();
+    Type boolType = rewriter.getType<BoolType>();
+    Value zero = getNumericZeroConstant(loc, input.getType(), rewriter);
+    Value trueConst = getBoolConstant(loc, true, rewriter);
+    Value falseConst = getBoolConstant(loc, false, rewriter);
+    Value isZero = rewriter.create<ApplyCompareOp>(
+        loc, boolType, rewriter.getStringAttr("eq"), input, zero);
+    // Zero is falsy, so select `false` when the comparison holds.
+    rewriter.replaceOpWithNewOp<SelectOp>(op, boolType, isZero, falseConst,
+                                          trueConst);
+    return success();
+  }
+};
+
+}  // namespace
+
+/// Canonicalization: fold `as_bool` over bool and numeric operands.
+void AsBoolOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<FoldAsBoolFromBool>(context);
+  patterns.add<FoldAsBoolFromNumeric>(context);
+}
+
+/// Folds an `as_bool` of a NoneType operand to a constant `false`.
+OpFoldResult AsBoolOp::fold(ArrayRef<Attribute> operands) {
+  if (!value().getType().isa<NoneType>()) return {};
+  return BoolAttr::get(getContext(), false);
+}
+
+//===----------------------------------------------------------------------===//
+// AssignSubscriptOp
+//===----------------------------------------------------------------------===//
+
+/// Canonicalization: unbox statically-typed object operands, restricted to
+/// operand positions 0 and 1.
+void AssignSubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                    MLIRContext *context) {
+  llvm::SmallSet<int, 4> unboxIndices;
+  for (int index : {0, 1}) unboxIndices.insert(index);
+  patterns.add<UnboxOperands>(getOperationName(), context, unboxIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// BoolToPredOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a constant input to itself: BoolType and I1 share the attribute
+/// form (an IntegerAttr of I1). Non-constant inputs do not fold.
+OpFoldResult BoolToPredOp::fold(ArrayRef<Attribute> operands) {
+  Attribute input = operands[0];
+  return input ? OpFoldResult(input) : OpFoldResult();
+}
+
+//===----------------------------------------------------------------------===//
+// BoxOp and UnboxOp
+//===----------------------------------------------------------------------===//
+
+/// Canonicalizes a box:
+///   - boxing a value that is already an object is a no-op and is removed;
+///   - boxing to the unparameterized object type instead boxes to
+///     object<primitive-type> and static-info-casts back to the original type.
+LogicalResult BoxOp::canonicalize(BoxOp op, PatternRewriter &rewriter) {
+  Value primitive = op.primitive();
+  Type primitiveType = primitive.getType();
+
+  // Sometimes boxes are emitted when the input is an object. Just remove.
+  if (primitiveType.isa<ObjectType>()) {
+    rewriter.replaceOp(op, primitive);
+    return success();
+  }
+
+  Type resultType = op.object().getType();
+  if (resultType != rewriter.getType<ObjectType>(nullptr)) return failure();
+
+  // Box to an appropriate type and static info cast.
+  auto refinedBox = rewriter.create<BoxOp>(
+      op.getLoc(),
+      rewriter.getType<ObjectType>(primitiveType.cast<PrimitiveType>()),
+      primitive);
+  rewriter.replaceOpWithNewOp<StaticInfoCastOp>(op, resultType, refinedBox);
+  return success();
+}
+
+/// Elides an unbox whose operand comes directly from a box of the same
+/// primitive type. The unbox is replaced by a `success` exception result and
+/// the originally boxed primitive.
+LogicalResult UnboxOp::canonicalize(UnboxOp unboxOp,
+                                    PatternRewriter &rewriter) {
+  auto boxProducer = unboxOp.object().getDefiningOp<BoxOp>();
+  if (!boxProducer) return failure();
+
+  Value boxedPrimitive = boxProducer.primitive();
+  if (boxedPrimitive.getType() != unboxOp.primitive().getType())
+    return failure();
+
+  Value successValue = rewriter.create<SuccessOp>(
+      unboxOp.getLoc(), rewriter.getType<ExceptionResultType>());
+  rewriter.replaceOp(unboxOp, {successValue, boxedPrimitive});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicBinaryPromoteOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Resolves a DynamicBinaryPromote over numeric operands to either elide
+/// or insert specific PromoteNumeric ops.
+struct ResolveNumericDynamicBinaryPromote
+    : public OpRewritePattern<DynamicBinaryPromoteOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.left();
+    Value rhs = op.right();
+    auto lhsPt = lhs.getType().dyn_cast<PythonTypeInterface>();
+    auto rhsPt = rhs.getType().dyn_cast<PythonTypeInterface>();
+    if (!lhsPt || !rhsPt) return failure();
+
+    // Only applies when both operands have a numeric promotion order.
+    Optional<int> lhsOrder = lhsPt.getNumericPromotionOrder();
+    Optional<int> rhsOrder = rhsPt.getNumericPromotionOrder();
+    if (!lhsOrder || !rhsOrder) return failure();
+
+    // Promote the lower-ordered operand to the other operand's type. Equal
+    // orders need no promotion.
+    Location loc = op.getLoc();
+    if (*lhsOrder > *rhsOrder)
+      rhs = rewriter.create<PromoteNumericOp>(loc, lhs.getType(), rhs);
+    else if (*rhsOrder > *lhsOrder)
+      lhs = rewriter.create<PromoteNumericOp>(loc, rhs.getType(), lhs);
+
+    // Need to box back to the original type (which will always be a generic
+    // object).
+    Value boxedLhs = rewriter.create<BoxOp>(loc, op.getResultTypes()[0], lhs);
+    Value boxedRhs = rewriter.create<BoxOp>(loc, op.getResultTypes()[1], rhs);
+    rewriter.replaceOp(op, {boxedLhs, boxedRhs});
+    return success();
+  }
+};
+
+/// If we statically determine one of the arguments to be a concrete, non
+/// numeric type, then the op has no meaning and is elided.
+struct ElideNonNumericDynamicBinaryPromote
+    : public OpRewritePattern<DynamicBinaryPromoteOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(DynamicBinaryPromoteOp op,
+                                PatternRewriter &rewriter) const override {
+    bool anyConcreteNonNumeric =
+        isConcreteNonNumericType(op.left().getType()) ||
+        isConcreteNonNumericType(op.right().getType());
+    if (!anyConcreteNonNumeric) return failure();
+
+    // Since DynamicBinaryPromote already returns object, and we only match
+    // non-object operands, box them back.
+    Location loc = op.getLoc();
+    auto resultTypes = op.getResultTypes();
+    Value boxedLeft = rewriter.create<BoxOp>(loc, resultTypes[0], op.left());
+    Value boxedRight = rewriter.create<BoxOp>(loc, resultTypes[1], op.right());
+    rewriter.replaceOp(op, {boxedLeft, boxedRight});
+    return success();
+  }
+
+  /// True for a python type that is not `object` and has no numeric
+  /// promotion order.
+  static bool isConcreteNonNumericType(Type t) {
+    if (t.isa<ObjectType>()) return false;
+    auto pt = t.dyn_cast<PythonTypeInterface>();
+    return pt && !pt.getNumericPromotionOrder();
+  }
+};
+
+}  // namespace
+
+/// Canonicalization: resolve numeric promotions, unbox statically-typed
+/// object operands, and elide promotions involving non-numeric types.
+void DynamicBinaryPromoteOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<ResolveNumericDynamicBinaryPromote>(context)
+      .add<UnboxOperands>(getOperationName(), context)
+      .add<ElideNonNumericDynamicBinaryPromote>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// FunctionalIfOp
+//===----------------------------------------------------------------------===//
+
+/// Names `iree_pydm` as the default dialect for this op's nested regions.
+::llvm::StringRef FunctionalIfOp::getDefaultDialect() {
+  return "iree_pydm";
+}
+
+/// An op that defines results must have an else region. Region/result type
+/// agreement is then delegated to RegionBranchOpInterface.
+static LogicalResult verify(FunctionalIfOp op) {
+  bool definesValues = op.getNumResults() != 0;
+  if (definesValues && op.elseRegion().empty())
+    return op.emitOpError("must have an else block if defining values");
+  return RegionBranchOpInterface::verifyTypes(op);
+}
+
+/// Custom parser for FunctionalIfOp. Reads, in order:
+///   <condition> [-> (result types)] <then-region> [else <else-region>]
+///   [attr-dict]
+/// The condition operand is always resolved as a python `bool`.
+static ParseResult parseFunctionalIfOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  // Create both the 'then' and 'else' regions up front; 'else' stays empty
+  // unless an `else` keyword follows the 'then' region.
+  result.regions.reserve(2);
+  Region *thenRegion = result.addRegion();
+  Region *elseRegion = result.addRegion();
+
+  auto &builder = parser.getBuilder();
+  OpAsmParser::OperandType cond;
+  Type conditionType = builder.getType<PyBoolType>();
+  if (parser.parseOperand(cond) ||
+      parser.resolveOperand(cond, conditionType, result.operands))
+    return failure();
+  // Parse optional results type list.
+  if (parser.parseOptionalArrowTypeList(result.types)) return failure();
+  // Parse the 'then' region.
+  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+  // NOTE(review): the ensureTerminator call below is disabled, while the
+  // printer omits block terminators when the op has no results. Confirm that
+  // result-less ops still round-trip, or re-enable terminator insertion.
+  // IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
+
+  // If we find an 'else' keyword then parse the 'else' region.
+  if (!parser.parseOptionalKeyword("else")) {
+    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
+      return failure();
+    // IfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
+    // result.location);
+  }
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+  return success();
+}
+
+/// Custom printer for FunctionalIfOp:
+///   <condition> [-> (result types)] <then-region> [else <else-region>]
+///   [attr-dict]
+static void print(OpAsmPrinter &p, FunctionalIfOp op) {
+  // Print yield explicitly if the op defines values.
+  bool definesValues = !op.results().empty();
+
+  p << " " << op.condition();
+  if (definesValues) p << " -> (" << op.getResultTypes() << ")";
+  p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/definesValues);
+
+  // Print the 'else' regions if it exists and has a block.
+  Region &elseRegion = op.elseRegion();
+  if (!elseRegion.empty()) {
+    p << " else";
+    p.printRegion(elseRegion, /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/definesValues);
+  }
+
+  p.printOptionalAttrDict(op->getAttrs());
+}
+
+/// Given the region at `index`, or the parent operation if `index` is None,
+/// return the successor regions. These are the regions that may be selected
+/// during the flow of control. `operands` is a set of optional attributes that
+/// correspond to a constant value for each operand, or null if that operand is
+/// not a constant.
+///
+/// When the condition may be false and there is no else region, control
+/// falls through to the parent op, which is therefore reported as a
+/// successor (the verifier guarantees the op has no results in that case).
+void FunctionalIfOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `then` and the `else` region branch back to the parent operation.
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // A constant condition selects exactly one path; otherwise both may run.
+  auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>();
+  bool mayTakeThen = !condAttr || condAttr.getValue();
+  bool mayTakeElse = !condAttr || !condAttr.getValue();
+
+  if (mayTakeThen) regions.push_back(RegionSuccessor(&thenRegion()));
+  if (mayTakeElse) {
+    Region &elseRegion = this->elseRegion();
+    if (elseRegion.empty())
+      regions.push_back(RegionSuccessor(getResults()));
+    else
+      regions.push_back(RegionSuccessor(&elseRegion));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+/// OpAsmOpInterface hook: reports `iree_pydm` as the default dialect for
+/// operations nested in the function body.
+::llvm::StringRef PyFuncOp::getDefaultDialect() {
+  return ::llvm::StringRef("iree_pydm");
+}
+
+/// Function-type verification hook; currently accepts every signature.
+LogicalResult PyFuncOp::verifyType() {
+  // TODO: Enforce arg/result invariants.
+  return success();
+}
+
+/// Parses the function via the shared FunctionLike parser, building a plain
+/// (non-variadic) FunctionType from the parsed argument and result types.
+static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
+  return function_like_impl::parseFunctionLikeOp(
+      parser, result, /*allowVariadic=*/false,
+      [](Builder &b, ArrayRef<Type> inputs, ArrayRef<Type> outputs,
+         function_like_impl::VariadicFlag, std::string &) {
+        return b.getFunctionType(inputs, outputs);
+      });
+}
+
+/// Prints the function via the shared FunctionLike printer.
+static void print(PyFuncOp op, OpAsmPrinter &p) {
+  FunctionType signature = op.getType();
+  ArrayRef<Type> inputs = signature.getInputs();
+  ArrayRef<Type> outputs = signature.getResults();
+  function_like_impl::printFunctionLikeOp(p, op, inputs,
+                                          /*isVariadic=*/false, outputs);
+}
+
+/// Op verifier; currently accepts every PyFuncOp.
+static LogicalResult verify(PyFuncOp op) {
+  // TODO: Enforce invariants.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MakeListOp
+//===----------------------------------------------------------------------===//
+
+/// Checks element kinds against the list's storage class: boxed lists hold
+/// only `object` elements, unboxed lists hold no `object` elements, and
+/// empty-storage lists hold no elements at all.
+static LogicalResult verify(MakeListOp op) {
+  CollectionStorageClass storageClass =
+      op.list().getType().cast<ListType>().getStorageClass();
+
+  // Emits `message` followed by the first element type whose object-ness
+  // differs from `wantObjects`.
+  auto checkElements = [&](bool wantObjects,
+                           const char *message) -> LogicalResult {
+    for (Value element : op.elements()) {
+      Type elementType = element.getType();
+      if (elementType.isa<ObjectType>() != wantObjects)
+        return op.emitOpError() << message << elementType;
+    }
+    return success();
+  };
+
+  if (storageClass == CollectionStorageClass::Boxed)
+    return checkElements(/*wantObjects=*/true,
+                         "making a list with boxed storage class "
+                         "must have object elements. Got: ");
+  if (storageClass == CollectionStorageClass::Unboxed)
+    return checkElements(/*wantObjects=*/false,
+                         "making a list with unboxed storage class "
+                         "must not have object elements. Got: ");
+  if (storageClass == CollectionStorageClass::Empty &&
+      !op.elements().empty())
+    return op.emitOpError()
+           << "making a list with empty storage class must have zero "
+              "elements";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+/// Registers the operand-unboxing canonicalization for NegOp.
+void NegOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                        MLIRContext *context) {
+  patterns.add<UnboxOperands>(getOperationName(), context);
+}
+
+/// Makes the result type follow the operand type. Returns true if the
+/// result type was changed.
+bool NegOp::refineResultTypes() {
+  Type operandType = value().getType();
+  Value result = getResult();
+  if (result.getType() == operandType) return false;
+  result.setType(operandType);
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatchCallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the required `generic_match` and `specific_match` array
+/// attributes exist and that every entry is a flat symbol reference that
+/// resolves (from this op) to a PyFuncOp.
+LogicalResult PatternMatchCallOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
+    for (Attribute symbolAttr : symbols) {
+      // Diagnose (rather than assert on) entries that are not symbol refs.
+      auto symbol = symbolAttr.dyn_cast<FlatSymbolRefAttr>();
+      if (!symbol)
+        return emitOpError() << "expected a flat symbol reference but got "
+                             << symbolAttr;
+      PyFuncOp fn =
+          symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol);
+      if (!fn)
+        return emitOpError() << "'" << symbol.getValue()
+                             << "' does not reference a valid function";
+    }
+    return success();
+  };
+  auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match");
+  if (!genericsAttr)
+    return emitOpError(
+        "requires a 'generic_match' array of symbol reference attributes");
+  if (failed(verifySymbols(genericsAttr))) return failure();
+
+  auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
+  if (!specificsAttr)
+    return emitOpError(
+        "requires a 'specific_match' array of symbol reference attributes");
+  if (failed(verifySymbols(specificsAttr))) return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PromoteNumericOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a numeric promotion of a constant operand:
+///   - bool -> integer: true/false become i64 1/0;
+///   - bool -> real:    true/false become f64 1.0/0.0;
+///   - integer -> real: the integer, read as signed, rounded to f64.
+/// Any other combination (or a non-constant operand) is left unfolded.
+OpFoldResult PromoteNumericOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[0]) return {};
+
+  Builder b(getContext());
+  Attribute fromAttr = operands[0];
+  return TypeSwitch<Type, OpFoldResult>(getResult().getType())
+      .Case([&](PyIntegerType) -> OpFoldResult {
+        return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
+            .Case([&](BoolAttr fromBool) -> OpFoldResult {
+              return b.getI64IntegerAttr(fromBool.getValue() ? 1 : 0);
+            })
+            .Default([](Attribute) -> OpFoldResult { return {}; });
+      })
+      .Case([&](PyRealType) -> OpFoldResult {
+        return TypeSwitch<Attribute, OpFoldResult>(fromAttr)
+            .Case([&](BoolAttr fromBool) -> OpFoldResult {
+              return b.getF64FloatAttr(fromBool.getValue() ? 1.0 : 0.0);
+            })
+            .Case([&](IntegerAttr fromInteger) -> OpFoldResult {
+              // roundToDouble works for any bit width, whereas
+              // getSExtValue() asserts on values needing more than 64 bits.
+              APInt value = fromInteger.getValue();
+              return b.getF64FloatAttr(
+                  value.roundToDouble(/*isSigned=*/true));
+            })
+            .Default([](Attribute) -> OpFoldResult { return {}; });
+      })
+      .Default([](Type) -> OpFoldResult { return {}; });
+}
+
+/// Drops a promotion whose input already has the result type.
+LogicalResult PromoteNumericOp::canonicalize(PromoteNumericOp op,
+                                             PatternRewriter &rewriter) {
+  Value input = op.input();
+  if (input.getType() != op.getResult().getType()) return failure();
+  rewriter.replaceOp(op, input);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RaiseOnFailureOp
+//===----------------------------------------------------------------------===//
+
+/// Removes the op when its single operand folds to a constant UnitAttr
+/// (a unit exception result means success, so there is nothing to raise).
+///
+/// NOTE(review): calling `erase()` from inside a fold hook breaks the fold
+/// contract, which only permits in-place updates or returning replacement
+/// values. After a successful in-place fold, MLIR's greedy rewrite driver
+/// keeps using the op (and may still hold it on its worklist), so this
+/// risks a use-after-free. Consider moving the elision into a
+/// canonicalization pattern that uses `rewriter.eraseOp(op)`. That requires
+/// declaring a canonicalizer for this op in ODS; please confirm.
+LogicalResult PYDM::RaiseOnFailureOp::fold(
+    ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
+  assert(operands.size() == 1 && "expected one fold operand");
+  // Unit exception result is success. Just elide.
+  if (operands[0] && operands[0].isa<UnitAttr>()) {
+    erase();
+    return success();
+  }
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a select to one of its arms when that choice is statically known:
+/// either both arms are the same value, or the condition is a constant
+/// BoolAttr.
+OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+  // Identical arms: the condition cannot affect the result.
+  if (true_value() == false_value()) return true_value();
+
+  // Only fold on a constant boolean condition; any other constant attribute
+  // is left alone instead of tripping a cast<> assertion.
+  auto boolAttr = operands[0].dyn_cast_or_null<BoolAttr>();
+  if (!boolAttr) return {};
+  return boolAttr.getValue() ? true_value() : false_value();
+}
+
+//===----------------------------------------------------------------------===//
+// SequenceCloneOp
+//===----------------------------------------------------------------------===//
+
+/// Makes the clone's result type follow the source sequence type. Returns
+/// true if the result type was changed.
+bool SequenceCloneOp::refineResultTypes() {
+  Type sourceType = sequence().getType();
+  Value clone = getResult();
+  if (clone.getType() == sourceType) return false;
+  clone.setType(sourceType);
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// SubscriptOp
+//===----------------------------------------------------------------------===//
+
+/// Registers the operand-unboxing canonicalization for SubscriptOp.
+void SubscriptOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  StringRef opName = getOperationName();
+  patterns.add<UnboxOperands>(opName, context);
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `callee` names a PyFuncOp whose signature matches this call
+/// exactly: same operand count and types, same result count and types.
+LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto calleeRef = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!calleeRef)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  PyFuncOp calleeOp =
+      symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, calleeRef);
+  if (!calleeOp)
+    return emitOpError() << "'" << calleeRef.getValue()
+                         << "' does not reference a valid function";
+
+  // Operands must line up one-for-one with the callee inputs.
+  FunctionType calleeType = calleeOp.getType();
+  unsigned inputCount = calleeType.getNumInputs();
+  if (inputCount != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
+  for (unsigned idx = 0; idx < inputCount; ++idx) {
+    Type expected = calleeType.getInput(idx);
+    Type provided = getOperand(idx).getType();
+    if (provided == expected) continue;
+    return emitOpError("operand type mismatch: expected operand type ")
+           << expected << ", but provided " << provided
+           << " for operand number " << idx;
+  }
+
+  // Results must line up one-for-one with the callee results.
+  unsigned resultCount = calleeType.getNumResults();
+  if (resultCount != getNumResults())
+    return emitOpError("incorrect number of results for callee");
+  for (unsigned idx = 0; idx < resultCount; ++idx) {
+    if (getResult(idx).getType() == calleeType.getResult(idx)) continue;
+    auto diag = emitOpError("result type mismatch at index ") << idx;
+    diag.attachNote() << "      op result types: " << getResultTypes();
+    diag.attachNote() << "function result types: " << calleeType.getResults();
+    return diag;
+  }
+
+  return success();
+}
+
+/// Returns the function type implied by this call site's own operand and
+/// result types.
+FunctionType PyCallOp::getCalleeType() {
+  auto operandTypes = getOperandTypes();
+  auto resultTypes = getResultTypes();
+  return FunctionType::get(getContext(), operandTypes, resultTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicCallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `callee` is present and resolves to a PyFuncOp. Unlike
+/// PyCallOp, no signature matching is performed here.
+LogicalResult DynamicCallOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto calleeRef = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!calleeRef)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  Operation *target = symbolTable.lookupNearestSymbolFrom(*this, calleeRef);
+  bool isFunction = target && isa<PyFuncOp>(target);
+  if (isFunction) return success();
+  return emitOpError() << "'" << calleeRef.getValue()
+                       << "' does not reference a valid function";
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..fd9704a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Pass sub-libraries. Optimize/ defines IREEPyDMOptimizePasses; RTL/ and
+# ToIREE/ presumably define IREEPyDMRTLPasses and IREEPyDMToIREEPasses,
+# which are linked below (confirm against those CMakeLists).
+add_subdirectory(Optimize)
+add_subdirectory(RTL)
+add_subdirectory(ToIREE)
+
+# Umbrella PyDM pass library: Passes.cpp plus the pass sub-libraries above.
+add_mlir_library(IREEPyDMPasses
+  Passes.cpp
+
+  DEPENDS
+  IREEPyDMTransformsPassesIncGen
+
+  LINK_LIBS PUBLIC
+  IREEPyDMOptimizePasses
+  IREEPyDMRTLPasses
+  IREEPyDMToIREEPasses
+  MLIRTransforms
+)
+
+# NOTE(review): project helper defined elsewhere; presumably attaches the
+# iree-dialects include directories to the target. Confirm.
+iree_dialects_target_includes(IREEPyDMPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt
new file mode 100644
index 0000000..1f7abb1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/CMakeLists.txt
@@ -0,0 +1,16 @@
+# PyDM optimization passes: weak-numeric fixation, local type propagation
+# and variables-to-SSA conversion (one source file each).
+add_mlir_library(IREEPyDMOptimizePasses
+  FixateWeakNumeric.cpp
+  LocalPropagateTypes.cpp
+  VariablesToSSA.cpp
+
+  DEPENDS
+  IREEPyDMTransformsPassesIncGen
+
+  LINK_LIBS PUBLIC
+  IREEPyDMDialect
+  IREEPyDMUtils
+  MLIRIR
+  MLIRTransformUtils
+)
+
+# NOTE(review): project helper defined elsewhere; presumably attaches the
+# iree-dialects include directories to the target. Confirm.
+iree_dialects_target_includes(IREEPyDMOptimizePasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
new file mode 100644
index 0000000..fc17764
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/FixateWeakNumeric.cpp
@@ -0,0 +1,111 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+namespace {
+
+/// Replaces weak PyDM numeric types with concrete ones throughout the IR
+/// under the pass root. Weak integers become width-32 integers and weak reals
+/// become f32-backed reals. `object` types wrapping either are rewritten to
+/// wrap the fixated type, and function signatures are updated to match.
+struct FixateWeakNumericPass
+    : public FixateWeakNumericBase<FixateWeakNumericPass> {
+  void runOnOperation() override {
+    // Retype every nested operation in place (default post-order walk).
+    getOperation()->walk([&](Operation *op) { convertOperation(op); });
+  }
+
+  /// Retypes the block arguments of every region of `op`, then its results,
+  /// and finally the signature of `op` itself when it is a PyDM function.
+  void convertOperation(Operation *op) {
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (BlockArgument arg : block.getArguments()) convertValue(arg);
+
+    for (Value result : op->getResults()) convertValue(result);
+
+    auto funcOp = llvm::dyn_cast<PYDM::FuncOp>(op);
+    if (!funcOp) return;
+    FunctionType oldSignature = funcOp.getType();
+    FunctionType newSignature = convertFunctionType(oldSignature);
+    if (newSignature != oldSignature) funcOp.setType(newSignature);
+  }
+
+  /// Overwrites the type of `value` with its fixated form.
+  void convertValue(Value value) {
+    Type fixated = convertType(value.getType());
+    value.setType(fixated);
+  }
+
+  /// Returns the fixated form of `type`, or `type` itself when nothing
+  /// needs to change.
+  Type convertType(Type type) {
+    // TODO: The specific types we promote to need to be configured by the
+    // lowering options.
+    MLIRContext *ctx = type.getContext();
+    if (auto integerType = type.dyn_cast<PYDM::IntegerType>()) {
+      if (!integerType.isWeak()) return type;
+      return PYDM::IntegerType::get(ctx, 32);
+    }
+    if (auto realType = type.dyn_cast<PYDM::RealType>()) {
+      if (!realType.isWeak()) return type;
+      return PYDM::RealType::get(ctx, mlir::Float32Type::get(ctx));
+    }
+    if (auto objectType = type.dyn_cast<PYDM::ObjectType>()) {
+      // Only objects with a known primitive type can carry a weak numeric.
+      Type wrapped = objectType.getPrimitiveType();
+      if (!wrapped) return type;
+      Type fixatedWrapped = convertType(wrapped);
+      if (fixatedWrapped == wrapped) return type;
+      return PYDM::ObjectType::get(
+          ctx, fixatedWrapped.cast<PYDM::PrimitiveType>());
+    }
+    return type;
+  }
+
+  /// Returns `ft` with all inputs and results fixated; returns `ft` itself
+  /// (no new type is created) when none of them change.
+  FunctionType convertFunctionType(FunctionType ft) {
+    bool modified = false;
+    auto convertAll = [&](ArrayRef<Type> types) {
+      SmallVector<Type> converted;
+      converted.reserve(types.size());
+      for (Type original : types) {
+        Type fixated = convertType(original);
+        if (fixated != original) modified = true;
+        converted.push_back(fixated);
+      }
+      return converted;
+    };
+    SmallVector<Type> inputs = convertAll(ft.getInputs());
+    SmallVector<Type> results = convertAll(ft.getResults());
+
+    if (!modified) return ft;
+    return FunctionType::get(ft.getContext(), inputs, results);
+  }
+};
+
+}  // namespace
+
+/// Creates an instance of the weak-numeric fixation pass.
+std::unique_ptr<OperationPass<>> PYDM::createFixateWeakNumericPass() {
+  std::unique_ptr<OperationPass<>> pass =
+      std::make_unique<FixateWeakNumericPass>();
+  return pass;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
new file mode 100644
index 0000000..4f4a73f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/LocalPropagateTypes.cpp
@@ -0,0 +1,252 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Utils/TypeInference.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+using llvm::dbgs;
+#define DEBUG_TYPE "pydm_opt"
+
+namespace {
+
+/// Iteratively refines value types within a PyDM function. Each iteration
+/// runs a selected subset of canonicalization patterns, sinks refining
+/// `StaticInfoCastOp`s into their users, lets `TypeRefinableOpInterface` ops
+/// refine their result types, and re-targets branches whose operand types
+/// no longer match their successor to permuted copies of that block.
+struct LocalPropagateTypesPass
+    : public LocalPropagateTypesBase<LocalPropagateTypesPass> {
+  void runOnOperation() override {
+    // Upper bound on iterations of the main fixpoint loop.
+    constexpr int kMaxIterations = 500;
+
+    // Prepare selected canonicalization patterns.
+    auto *context = &getContext();
+    PermutedTypePropagator propagator(context);
+
+    RewritePatternSet canonicalizePatterns(context);
+    ApplyBinaryOp::getCanonicalizationPatterns(canonicalizePatterns, context);
+    ApplyCompareOp::getCanonicalizationPatterns(canonicalizePatterns, context);
+    AsBoolOp::getCanonicalizationPatterns(canonicalizePatterns, context);
+    BoxOp::getCanonicalizationPatterns(canonicalizePatterns, context);
+    DynamicBinaryPromoteOp::getCanonicalizationPatterns(canonicalizePatterns,
+                                                        context);
+    NegOp::getCanonicalizationPatterns(canonicalizePatterns, context);
+    PromoteNumericOp::getCanonicalizationPatterns(canonicalizePatterns,
+                                                  context);
+    SubscriptOp::getCanonicalizationPatterns(canonicalizePatterns, context);
+    UnboxOp::getCanonicalizationPatterns(canonicalizePatterns, context);
+    FrozenRewritePatternSet frozenCanonicalizePatterns(
+        std::move(canonicalizePatterns));
+    GreedyRewriteConfig rewriterConfig;
+    // During the main fixpoint iteration, we cannot simplify regions because
+    // our propagator is keeping a cache of permuted blocks (we can add blocks
+    // but not remove until iteration is complete).
+    rewriterConfig.enableRegionSimplification = false;
+
+    for (int i = 0; i < kMaxIterations; ++i) {
+      LLVM_DEBUG(dbgs() << "--- Local type propagation iteration " << i
+                        << "\n");
+      if (failed(applyPatternsAndFoldGreedily(
+              getOperation(), frozenCanonicalizePatterns, rewriterConfig))) {
+        emitError(getOperation().getLoc())
+            << "failed to converge type propagation canonicalizations";
+        return signalPassFailure();
+      }
+      // Both steps must run every iteration, so they are not short-circuited.
+      bool changed = false;
+      if (sinkStaticInfoCasts()) changed = true;
+      if (refineResultTypes()) changed = true;
+      permuteRefinedBlocks(propagator);
+      if (!changed) break;
+    }
+
+    // Now that iteration is complete and we are no longer using the
+    // propagator, do one final canonicalization with region simplification
+    // enabled. This will prune out all of the excess blocks we created.
+    // Note that because we are still using a subset of dialect-specific
+    // patterns, this is less than a full canonicalization pass will do.
+    rewriterConfig.enableRegionSimplification = true;
+    if (failed(applyPatternsAndFoldGreedily(
+            getOperation(), frozenCanonicalizePatterns, rewriterConfig))) {
+      emitError(getOperation().getLoc())
+          << "failed to converge type propagation canonicalizations";
+      return signalPassFailure();
+    }
+  }
+
+  // Moving things around the CFG often creates unresolved static info casts.
+  // We sink these until they don't go any further (typically eliminating them).
+  // A cast is erased once all of its uses have been redirected to its input.
+  // Returns whether any changes were made.
+  bool sinkStaticInfoCasts() {
+    bool changed = false;
+    auto allCasts = getStaticInfoCasts();
+    for (StaticInfoCastOp castOp : allCasts) {
+      Value fromValue = castOp.value();
+      // The cast converts its operand type (`fromType`) into its result type
+      // (`toType`).
+      ObjectType fromType = fromValue.getType().dyn_cast<ObjectType>();
+      ObjectType toType = castOp.getResult().getType().dyn_cast<ObjectType>();
+      if (!fromType || !toType) {
+        LLVM_DEBUG(dbgs() << "Skipping non-object cast: " << castOp << "\n");
+        continue;
+      }
+      // We only want to sink refinements (where we know more on input). A cast
+      // from an object with no known primitive type to one with a known
+      // primitive type is a narrowing assertion; forwarding its less-refined
+      // input to users would discard information.
+      if (!fromType.getPrimitiveType() && toType.getPrimitiveType()) {
+        LLVM_DEBUG(dbgs() << "Skipping non-refinement cast: " << castOp
+                          << "\n");
+        continue;
+      }
+
+      bool eliminatedUses = true;
+      // Snapshot the uses: redirecting a use mutates the use list.
+      SmallVector<OpOperand *> uses;
+      for (OpOperand &use : castOp.getResult().getUses()) {
+        uses.push_back(&use);
+      }
+      for (OpOperand *use : uses) {
+        Operation *owner = use->getOwner();
+        // Most of our ops which accept objects are internally tolerant of
+        // receiving a refinement. Branch operands are updated directly, and
+        // permuteRefinedBlocks() fixes up the successor block types afterwards.
+        if (llvm::isa<TypeRefinableOpInterface, BranchOpInterface>(owner)) {
+          use->set(fromValue);
+          changed = true;
+          LLVM_DEBUG(dbgs() << "Sink refined type into: " << *owner << "\n");
+        } else {
+          eliminatedUses = false;
+        }
+      }
+
+      if (eliminatedUses) {
+        castOp->erase();
+      }
+    }
+    return changed;
+  }
+
+  // Asks every TypeRefinableOpInterface op to refine its result types. For
+  // each result whose type changed, existing users are rerouted through a
+  // StaticInfoCastOp back to the original type. When an object result was
+  // refined to a primitive type, the value is boxed first. Returns whether any
+  // result type changed.
+  bool refineResultTypes() {
+    bool changed = false;
+    LLVM_DEBUG(dbgs() << "-- Refining result types:\n");
+    getOperation()->walk([&](TypeRefinableOpInterface refinable) {
+      Operation *refinableOp = refinable.getOperation();
+      SmallVector<Type> originalResultTypes(refinableOp->getResultTypes());
+      LLVM_DEBUG(dbgs() << "  refineResultTypes: " << *refinableOp << "\n");
+      if (!refinable.refineResultTypes()) return;
+      LLVM_DEBUG(dbgs() << "  refineResultTypes changed results: "
+                        << *refinableOp << "\n");
+      OpBuilder builder(refinableOp);
+      builder.setInsertionPointAfter(refinableOp);
+      for (auto it :
+           llvm::zip(originalResultTypes, refinableOp->getOpResults())) {
+        Type origType = std::get<0>(it);
+        OpResult result = std::get<1>(it);
+        Type newType = result.getType();
+        if (origType == newType) continue;
+        // Insert a static info cast.
+        // In the future, we could further query the use for refinable
+        // support and skip creating the op.
+        LLVM_DEBUG(dbgs() << "    changed result type " << origType << " -> "
+                          << newType << "\n");
+
+        Value newResult = result;
+        Operation *replaceExcept = nullptr;
+        // It is possible to refine from an object (boxed) to an unboxed type.
+        // In order to keep the type algebra safe, we must box back.
+        if (origType.isa<ObjectType>() && newType.isa<PYDM::PrimitiveType>()) {
+          auto boxed = builder.create<BoxOp>(
+              refinableOp->getLoc(),
+              builder.getType<ObjectType>(newType.cast<PYDM::PrimitiveType>()),
+              newResult);
+          replaceExcept = boxed;
+          newResult = boxed;
+        }
+        auto casted = builder.create<StaticInfoCastOp>(refinableOp->getLoc(),
+                                                       origType, newResult);
+        // The first op built above consumes `result` directly and must keep
+        // doing so; every other user is rerouted through the cast.
+        if (!replaceExcept) replaceExcept = casted;
+        result.replaceAllUsesExcept(casted, replaceExcept);
+        changed = true;
+      }
+    });
+    return changed;
+  }
+
+  // We may have done type refinement on branch ops but not updated successors.
+  // We fix these up en masse by permuting the blocks using the propagator.
+  // This is not merely mechanical: by iterating in this way with a permutation
+  // cache, it is possible to resolve refinements that include type cycles in
+  // the CFG.
+  void permuteRefinedBlocks(PermutedTypePropagator &propagator) {
+    SmallVector<Block *> blocks;
+    for (auto &block : getOperation().body()) {
+      blocks.push_back(&block);
+    }
+
+    // This loop adds new blocks so must iterate a snapshot.
+    for (auto *block : blocks) {
+      auto mismatchedPredecessors =
+          propagator.findMismatchedBlockPredecessors(block);
+      if (mismatchedPredecessors.empty()) continue;
+      LLVM_DEBUG(dbgs() << "  ++ Processing block " << block << " ("
+                        << mismatchedPredecessors.size()
+                        << " mismatched predecessors)\n");
+
+      auto *parentInfo = propagator.lookupParentBlock(block);
+      for (auto &mismatch : mismatchedPredecessors) {
+        Location loc = mismatch.terminator.getLoc();
+        Block *permutation =
+            propagator.findBlockPermutation(parentInfo, mismatch.signature);
+        if (!permutation) {
+          LLVM_DEBUG(dbgs() << "  -- Creating new permutation for "
+                            << mismatch.signature << "\n");
+          // The new block takes the predecessor's argument types; each
+          // argument whose type differs is cast back to the original block's
+          // argument type before the original block's uses are remapped.
+          permutation = propagator.createBlockPermutation(
+              parentInfo, mismatch.signature.getInputs(),
+              [&](Block *newBlock, Block *origBlock,
+                  BlockAndValueMapping &mapping) {
+                OpBuilder builder(newBlock, newBlock->begin());
+                for (auto it : llvm::zip(newBlock->getArguments(),
+                                         origBlock->getArguments())) {
+                  Value newArgument = std::get<0>(it);
+                  Type newType = newArgument.getType();
+                  Value origArgument = std::get<1>(it);
+                  Type origType = origArgument.getType();
+                  if (newType != origType) {
+                    newArgument = builder.create<StaticInfoCastOp>(
+                        loc, origType, newArgument);
+                    LLVM_DEBUG(dbgs() << "  -- Adding cast " << newType
+                                      << " -> " << origType << "\n");
+                  }
+                  mapping.map(origArgument, newArgument);
+                }
+              });
+        } else {
+          LLVM_DEBUG(dbgs() << "  -- Re-using existing permutation for "
+                            << mismatch.signature << "\n");
+        }
+        mismatch.terminator->setSuccessor(permutation, mismatch.successorIndex);
+      }
+    }
+  }
+
+  // Collects all StaticInfoCastOps under the root up front, so callers may
+  // erase casts while iterating.
+  SmallVector<StaticInfoCastOp> getStaticInfoCasts() {
+    SmallVector<StaticInfoCastOp> results;
+    getOperation()->walk([&](StaticInfoCastOp op) { results.push_back(op); });
+    return results;
+  }
+};
+
+}  // namespace
+
+/// Creates an instance of the function-local type propagation pass.
+std::unique_ptr<OperationPass<PYDM::FuncOp>>
+PYDM::createLocalPropagateTypesPass() {
+  std::unique_ptr<OperationPass<PYDM::FuncOp>> pass =
+      std::make_unique<LocalPropagateTypesPass>();
+  return pass;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
new file mode 100644
index 0000000..0a0a141
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Optimize/VariablesToSSA.cpp
@@ -0,0 +1,214 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+using llvm::dbgs;
+#define DEBUG_TYPE "pydm_opt"
+
+namespace {
+
+/// Per-block bookkeeping used by VariablesToSSAPass.
+struct BlockAccessInfo {
+  // Tracks any variable, value mapping that has been hoisted to the block
+  // arguments.
+  DenseMap<Value, Value> blockArgVariableValueMap;
+
+  // Map of variable alloc value to most terminal value of the variable
+  // within the block.
+  DenseMap<Value, Value> variableValueMap;
+
+  // Loads that read a variable's incoming value (nothing earlier in the block
+  // provides it), kept in program order. A SetVector is used rather than a
+  // DenseSet (which iterates in pointer-hash order) so that the block
+  // arguments created when these loads are hoisted are ordered
+  // deterministically.
+  llvm::SetVector<Operation *> liveLoads;
+};
+
+/// Promotes free-variable accesses in a PYDM::FuncOp to SSA values.
+///
+/// Works to a fixed point over the function's blocks. On each sweep, per
+/// block:
+///   1. canonicalizeBlockVariableAccess forwards the latest stored or loaded
+///      value of a variable to later loads of it within the block;
+///   2. hoistLoadsFromBlock turns each remaining load of a variable's
+///      incoming value into a block argument and re-creates the load before
+///      the terminator of every incoming CFG edge.
+/// Afterwards, free variables whose only uses are stores into them are
+/// erased together with those stores.
+struct VariablesToSSAPass : public VariablesToSSABase<VariablesToSSAPass> {
+  void runOnOperation() override {
+    // Verify that the structure we need is valid.
+    for (auto &block : getOperation().getBody().getBlocks()) {
+      if (failed(verifyBlockIsLegal(block))) {
+        return signalPassFailure();
+      }
+    }
+
+    // Canonicalize and accumulate information about per-block accesses.
+    // Hoisting counts as a change: the loads it re-creates in predecessors
+    // may become forwardable on the next sweep.
+    DenseMap<Block *, BlockAccessInfo> blockAccessInfos;
+    for (int i = 0; i < 100; ++i) {
+      LLVM_DEBUG(dbgs() << "--- Iteration on all blocks\n");
+      bool changed = false;
+      for (auto &block : getOperation().getBody().getBlocks()) {
+        auto &info = blockAccessInfos[&block];
+        if (canonicalizeBlockVariableAccess(block, info)) changed = true;
+        if (hoistLoadsFromBlock(block, info)) changed = true;
+
+        // Invalidate internal value map and re-initialize from block arg
+        // carried values.
+        info.variableValueMap = info.blockArgVariableValueMap;
+      }
+
+      if (!changed) break;
+    }
+
+    // We should now have eliminated as many loads as possible, so we can
+    // DCE any free variable stores (free variables are not expected to
+    // escape; any use other than a store into the variable conservatively
+    // keeps it alive).
+    elideDeadFreeVarStores();
+  }
+
+  /// Erases each AllocFreeVarOp whose only uses are stores into it, along
+  /// with those stores.
+  void elideDeadFreeVarStores() {
+    SmallVector<Operation *> deadOps;
+    getOperation().walk([&](AllocFreeVarOp allocOp) {
+      Value varRef = allocOp.getResult();
+      SmallVector<Operation *> storeOps;
+      for (OpOperand &use : varRef.getUses()) {
+        // Any load or other user keeps the variable alive. So does a store
+        // that writes the reference itself as the stored value: it escapes,
+        // and a store using it twice would otherwise be queued twice.
+        auto storeOp = llvm::dyn_cast<StoreVarOp>(use.getOwner());
+        if (!storeOp || storeOp.var() != varRef || storeOp.value() == varRef)
+          return;
+        storeOps.push_back(storeOp);
+      }
+      deadOps.append(storeOps.begin(), storeOps.end());
+      deadOps.push_back(allocOp);
+    });
+
+    // Note that we cannot erase in the walk, even though it is post-order
+    // because in addition to erasing the root op, we also erase uses of it.
+    // If one of these is immediately after the root op, it is an access
+    // violation if erased during walk. Stores are queued before their alloc,
+    // so uses are always erased first.
+    for (auto *deadOp : deadOps) {
+      deadOp->erase();
+    }
+  }
+
+  // This pass must operate before any CFG operations have been performed
+  // which may cause variables to be sunk into block arguments.
+  LogicalResult verifyBlockIsLegal(Block &block) {
+    for (BlockArgument arg : block.getArguments()) {
+      if (arg.getType().isa<FreeVarRefType>()) {
+        return emitError(getOperation().getLoc())
+               << "cannot convert variables to SSA on a function which carries "
+                  "variable references across block boundaries";
+      }
+    }
+    // Inspect the last op directly (Block::getTerminator asserts on blocks
+    // without one) and never dereference a missing terminator.
+    Operation *terminator = block.empty() ? nullptr : &block.back();
+    if (!terminator ||
+        !llvm::isa<BranchOp, CondBranchOp, PYDM::ReturnOp>(terminator)) {
+      Location loc =
+          terminator ? terminator->getLoc() : getOperation().getLoc();
+      return emitError(loc) << "unsupported terminator for block";
+    }
+    return success();
+  }
+
+  // Canonicalizes variable accesses within a block such that:
+  //   - Redundant loads are eliminated (there is at most one load of a var).
+  //   - If a store dominates loads, then all loads are eliminated.
+  // Note that this is likely only viable for Free Variables, which can be
+  // treated as registers. Other variable types (once they exist), will have
+  // more involved requirements.
+  // Note that stores are never eliminated at this phase.
+  // Returns true if any load was forwarded (and erased).
+  bool canonicalizeBlockVariableAccess(Block &block, BlockAccessInfo &info) {
+    bool changed = false;
+    SmallVector<Operation *> elidedOps;
+    for (Operation &op : block) {
+      if (auto storeOp = llvm::dyn_cast<StoreVarOp>(op)) {
+        Value &currentValue = info.variableValueMap[storeOp.var()];
+        currentValue = storeOp.value();
+        LLVM_DEBUG(dbgs() << "Initialize store: " << currentValue << "\n");
+      } else if (auto loadOp = llvm::dyn_cast<LoadVarOp>(op)) {
+        Value &currentValue = info.variableValueMap[loadOp.var()];
+        if (currentValue) {
+          LLVM_DEBUG(dbgs() << "Forward load from: " << currentValue << "\n");
+          Value replacementValue = currentValue;
+          // Bridge any type difference between the known value and the
+          // type the load produced.
+          if (loadOp.getResult().getType() != currentValue.getType()) {
+            OpBuilder builder(loadOp);
+            replacementValue = builder.create<StaticInfoCastOp>(
+                loadOp.getLoc(), loadOp.getResult().getType(), currentValue);
+          }
+          loadOp.getResult().replaceAllUsesWith(replacementValue);
+          elidedOps.push_back(loadOp);
+          changed = true;
+        } else {
+          // First access of this variable in the block: it reads the
+          // incoming value, so it is a candidate for hoisting.
+          LLVM_DEBUG(dbgs() << "Initialize load: " << loadOp << "\n");
+          currentValue = loadOp.getResult();
+          info.liveLoads.insert(loadOp);
+        }
+      }
+    }
+
+    for (auto *op : elidedOps) {
+      op->erase();
+    }
+    return changed;
+  }
+
+  // Lifts any live loads into the block's phi arguments and move the
+  // load up to the predecessors. This assumes that the function is in a
+  // legal form where all allocs are done in the entry block.
+  // Returns true if any load was hoisted.
+  bool hoistLoadsFromBlock(Block &block, BlockAccessInfo &info) {
+    // Entry block: nowhere to hoist. Drop the bookkeeping so it never holds
+    // on to loads across sweeps.
+    if (block.isEntryBlock()) {
+      info.liveLoads.clear();
+      return false;
+    }
+    if (info.liveLoads.empty()) return false;
+
+    SmallVector<std::tuple<Location, Value, Type>> loadVarTypes;
+    // Redirect each live load to a block argument.
+    for (Operation *genericLoadOp : info.liveLoads) {
+      auto loadOp = llvm::cast<LoadVarOp>(genericLoadOp);
+      loadVarTypes.emplace_back(loadOp.getLoc(), loadOp.var(),
+                                loadOp.getResult().getType());
+      Value newArg = block.addArgument(loadOp.getResult().getType());
+      info.blockArgVariableValueMap[loadOp.var()] = newArg;
+      loadOp.getResult().replaceAllUsesWith(newArg);
+      loadOp->erase();
+    }
+    info.liveLoads.clear();
+
+    // Rematerialize the loads once per incoming CFG edge. Iterating the
+    // block's uses (rather than getPredecessors()) distinguishes the two
+    // edges of a cond_br whose true and false destinations are both `block`.
+    for (BlockOperand &blockUse : block.getUses()) {
+      Operation *terminator = blockUse.getOwner();
+      OpBuilder builder(terminator);
+      SmallVector<Value> newLoadValues;
+      newLoadValues.reserve(loadVarTypes.size());
+      for (auto &it : loadVarTypes) {
+        Location loc = std::get<0>(it);
+        Value varValue = std::get<1>(it);
+        Type loadType = std::get<2>(it);
+        newLoadValues.push_back(
+            builder.create<LoadVarOp>(loc, loadType, varValue));
+      }
+
+      if (auto branchOp = llvm::dyn_cast<BranchOp>(terminator)) {
+        branchOp.destOperandsMutable().append(newLoadValues);
+      } else if (auto condBranchOp = llvm::dyn_cast<CondBranchOp>(terminator)) {
+        // Successor #0 is the true destination, #1 the false destination.
+        if (blockUse.getOperandNumber() == 0) {
+          condBranchOp.trueDestOperandsMutable().append(newLoadValues);
+        } else {
+          condBranchOp.falseDestOperandsMutable().append(newLoadValues);
+        }
+      }
+    }
+    return true;
+  }
+};
+
+}  // namespace
+
+/// Creates a new VariablesToSSAPass instance; the pass is anchored on
+/// PYDM::FuncOp.
+std::unique_ptr<OperationPass<PYDM::FuncOp>> PYDM::createVariablesToSSAPass() {
+  return std::make_unique<VariablesToSSAPass>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
new file mode 100644
index 0000000..9fbfc52
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/PassDetail.h
@@ -0,0 +1,32 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
+#define IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+// Forward declarations for types the generated pass base classes (expanded
+// from GEN_PASS_CLASSES below) may refer to.
+// NOTE(review): mlir::iree::IREEDialect looks like a leftover from before the
+// move to the iree_compiler::IREE::Input dialect; confirm Passes.td no longer
+// references it before removing.
+namespace iree {
+class IREEDialect;
+}
+
+namespace iree_compiler {
+namespace IREE {
+namespace PYDM {
+
+// Anchor op of the function-level passes (OperationPass<PYDM::FuncOp>).
+class FuncOp;
+
+// Expands to the CRTP pass base classes (e.g. VariablesToSSABase) declared in
+// Passes.td.
+#define GEN_PASS_CLASSES
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
+
+}  // namespace PYDM
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
+
+#endif  // IREE_DIALECTS_DIALECT_IREEPYDM_TRANSFORMS_PASSDETAIL_H
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
new file mode 100644
index 0000000..820338c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/Passes.cpp
@@ -0,0 +1,67 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+// Cleanup pipeline for freshly imported PyDM: per PYDM::FuncOp, promote
+// variables to SSA and then run local type propagation; finally run
+// module-wide canonicalization and CSE.
+void PYDM::buildPostImportPassPipeline(OpPassManager& passManager) {
+  passManager.addNestedPass<PYDM::FuncOp>(createVariablesToSSAPass());
+  passManager.addNestedPass<PYDM::FuncOp>(createLocalPropagateTypesPass());
+  passManager.addPass(createCanonicalizerPass());
+  passManager.addPass(createCSEPass());
+}
+
+// Lowers PyDM to IREE's input dialects. Dynamic ops are first rewritten into
+// RTL calls, and the RTL is linked in when `options.linkRtlSource` is set.
+// Weak numerics are then fixated and the program is converted, followed by
+// canonicalization, symbol DCE and CSE.
+void PYDM::buildLowerToIREEPassPipeline(OpPassManager& passManager,
+                                        const LowerToIREEOptions& options) {
+  // TODO: Needs to be iterative, support optimization passes, etc.
+  passManager.addPass(createLowerIREEPyDMToRTLPass());
+  // Without a source bundle the link pass is not added at all.
+  if (options.linkRtlSource) {
+    passManager.addPass(createLinkIREEPyDMRTLPass(options.linkRtlSource));
+  }
+  // TODO: Optimization passes need to be their own pipeline.
+  passManager.addPass(createFixateWeakNumericPass());
+  passManager.addPass(createCanonicalizerPass());
+
+  // Lowering passes.
+  passManager.addPass(createConvertIREEPyDMToIREEPass());
+
+  // Cleanup.
+  passManager.addPass(createCanonicalizerPass());
+  passManager.addPass(createSymbolDCEPass());
+  passManager.addPass(createCSEPass());
+}
+
+// The generated pass registration code lives in its own namespace so that its
+// registerPasses() does not clash with PYDM::registerPasses() below, which
+// calls it.
+namespace PYDM_generated {
+namespace {
+#define GEN_PASS_REGISTRATION
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h.inc"
+}  // namespace
+}  // namespace PYDM_generated
+
+// Registers the individual PyDM passes and the two named PyDM pipelines.
+void PYDM::registerPasses() {
+  // Individual passes declared in Passes.td.
+  PYDM_generated::registerPasses();
+
+  // Named pipelines; constructing the registration object registers it.
+  PassPipelineRegistration<> postImportPassPipeline(
+      "pydm-post-import-pipeline",
+      "Runs passes to cleanup PyDM immediately post-import",
+      [](OpPassManager& pm) { buildPostImportPassPipeline(pm); });
+
+  PassPipelineRegistration<> lowerToIREEPipeline(
+      "pydm-lower-to-iree-pipeline",
+      "Runs passes to lower PyDM to IREE's input dialects",
+      [](OpPassManager& pm) {
+        // The registered pipeline always uses default options.
+        LowerToIREEOptions defaultOptions;
+        buildLowerToIREEPassPipeline(pm, defaultOptions);
+      });
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt
new file mode 100644
index 0000000..08392f8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/CMakeLists.txt
@@ -0,0 +1,17 @@
+# RTL lowering and linking passes for PyDM. MLIRStandard is linked directly
+# because LowerToRTLPass registers StandardOpsDialect as a dependent dialect.
+add_mlir_library(IREEPyDMRTLPasses
+  LinkageAnalysis.cpp
+  LinkRTLPass.cpp
+  LowerToRTLPass.cpp
+
+  DEPENDS
+  IREEPyDMTransformsPassesIncGen
+
+  LINK_LIBS PUBLIC
+  IREEInputDialect
+  IREEPyDMDialect
+  MLIRIR
+  MLIRParser
+  MLIRStandard
+  MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMRTLPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
new file mode 100644
index 0000000..f61214a
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkRTLPass.cpp
@@ -0,0 +1,215 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <memory>
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Parser.h"
+
+#define DEBUG_TYPE "iree_pydm"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+// Returns a printable name for `op` in debug output: the module's name, an
+// empty string for an unnamed module, or a placeholder for a non-module op.
+static StringRef safeModuleName(Operation *op) {
+  auto moduleOp = dyn_cast<ModuleOp>(op);
+  if (!moduleOp) return "(unknown module)";
+  auto moduleName = moduleOp.getName();
+  if (!moduleName) return StringRef("");
+  return *moduleName;
+}
+
+namespace {
+
+/// Links body-less ("extern") PYDM functions of a program module against RTL
+/// modules parsed from assembly.
+///
+/// The RTL assembly comes from the SourceBundle given at construction or, if
+/// none was given, from the `rtlFile` pass option. Each module nested in the
+/// parsed RTL is an import candidate: when it defines a symbol the program
+/// declares, all of its top-level symbols are cloned into the program as
+/// private symbols. Imported code may declare further externs, so linking
+/// repeats (bounded) until none remain.
+class LinkIREEPyDMRTLPass : public LinkIREEPyDMRTLBase<LinkIREEPyDMRTLPass> {
+ public:
+  LinkIREEPyDMRTLPass() = default;
+  explicit LinkIREEPyDMRTLPass(Optional<SourceBundle> linkRtlSourceBundle)
+      : linkRtlSourceBundle(std::move(linkRtlSourceBundle)) {}
+
+ private:
+  LogicalResult initialize(MLIRContext *context) override {
+    // If initialize() runs again, drop symbol tables that point into the
+    // previously parsed RTL module before it is replaced.
+    importModules.clear();
+
+    SourceBundle localSource;
+    if (linkRtlSourceBundle) {
+      localSource = *linkRtlSourceBundle;
+    } else if (!rtlFile.empty()) {
+      // Get it from the cli options. An unset option leaves the bundle empty
+      // so that the diagnostic below is reported.
+      localSource.asmFilePath = rtlFile;
+    }
+
+    if (localSource.asmBlob) {
+      // Parse from inline asm.
+      auto owningOp = parseSourceString(*localSource.asmBlob, context);
+      if (!owningOp) return failure();
+      rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
+    } else if (localSource.asmFilePath) {
+      // Parse from a file.
+      auto owningOp = parseSourceFile(*localSource.asmFilePath, context);
+      if (!owningOp) return failure();
+      rtlModule = std::make_shared<OwningModuleRef>(std::move(owningOp));
+    } else {
+      return emitError(UnknownLoc::get(context))
+             << "pass " << getArgument()
+             << " must be initialized with an RTL module (did you mean to "
+                "add an rtl-file option?)";
+    }
+
+    ModuleOp parentModule = rtlModule->get();
+    // Build a SymbolTable for each module nested in the RTL module. This is
+    // the default post-order walk, so every nested op is visited and modules
+    // at any depth are collected (inner modules before their parents).
+    parentModule->walk([&](ModuleOp importModule) {
+      if (importModule != parentModule) {
+        LLVM_DEBUG(llvm::dbgs() << "Loaded RTL module "
+                                << safeModuleName(importModule) << "\n");
+        importModules.emplace_back(importModule);
+      }
+    });
+
+    return success();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    SymbolTable programSymbolTable(moduleOp);
+    for (int i = 0; i < 1000; i++) {
+      auto &analysis = getAnalysis<LinkageAnalysis>();
+      if (!analysis.hasExternFuncs()) {
+        LLVM_DEBUG(llvm::dbgs() << "No extern funcs to link.\n");
+        if (i == 0) {
+          markAllAnalysesPreserved();
+        }
+        return;
+      }
+
+      // Copy the externs out: linking mutates the module, and the analysis
+      // is invalidated at the end of this round.
+      SetVector<Operation *> externFuncOps(analysis.getExternFuncOps().begin(),
+                                           analysis.getExternFuncOps().end());
+      while (!externFuncOps.empty()) {
+        auto externOp = *externFuncOps.begin();
+        if (failed(linkExtern(programSymbolTable, externOp, externFuncOps))) {
+          return signalPassFailure();
+        }
+      }
+
+      // Imported modules may have introduced new externs; re-analyze.
+      getAnalysisManager().invalidate({});
+    }
+
+    emitError(moduleOp.getLoc()) << "failed to converge when linking RTL";
+    return signalPassFailure();
+  }
+
+  /// Resolves `externOp` by inlining the first RTL module that *defines* its
+  /// symbol. A module that merely declares the symbol cannot resolve it.
+  LogicalResult linkExtern(SymbolTable &programSymbolTable, Operation *externOp,
+                           SetVector<Operation *> &externFuncOps) {
+    StringAttr probeSymbolName = SymbolTable::getSymbolName(externOp);
+    for (SymbolTable &importModule : importModules) {
+      Operation *probeImport = importModule.lookup(probeSymbolName);
+      if (!probeImport) continue;
+      auto probeSymbol = dyn_cast<SymbolOpInterface>(probeImport);
+      if (probeSymbol && probeSymbol.isDeclaration()) continue;
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Resolving extern " << probeSymbolName << " from "
+                 << safeModuleName(importModule.getOp()) << "\n");
+      return inlineImportModule(programSymbolTable, importModule,
+                                externFuncOps);
+    }
+
+    externOp->emitError() << "could not resolve extern " << probeSymbolName;
+    return failure();
+  }
+
+  // Inlines an import module into a program module. This is a relatively
+  // brute force mechanism and it requires that symbols do not collide (i.e.
+  // if the program defined the same name as an RTL export, that would be an
+  // error). It is possible to make something smarter but not clear it is
+  // necessary, given the limited scope of "linking".
+  // Only direct children of the import module are visited (each is skipped
+  // after being handled).
+  LogicalResult inlineImportModule(SymbolTable &programSymbolTable,
+                                   SymbolTable &importModule,
+                                   SetVector<Operation *> &externFuncOps) {
+    LLVM_DEBUG(llvm::dbgs() << "+++ Inlining module\n");
+    auto result = importModule.getOp()->walk<WalkOrder::PreOrder>(
+        [&](Operation *importOp) -> WalkResult {
+          if (importOp == importModule.getOp()) return WalkResult::advance();
+          if (auto symbolImportOp = dyn_cast<SymbolOpInterface>(importOp)) {
+            StringAttr name = symbolImportOp.getNameAttr();
+            Operation *existing = programSymbolTable.lookup(name);
+            if (existing) {
+              // An RTL-side declaration of a symbol the program already has
+              // (declared or defined) contributes nothing: keep the
+              // program's op rather than rejecting or replacing it.
+              if (symbolImportOp.isDeclaration() &&
+                  existing->getName() == importOp->getName())
+                return WalkResult::skip();
+              if (failed(verifyCanImport(existing, importOp)))
+                return WalkResult::interrupt();
+
+              LLVM_DEBUG(llvm::dbgs() << "*** Erasing existing import " << name
+                                      << "\n");
+              // Bookkeeping.
+              externFuncOps.remove(existing);
+              programSymbolTable.erase(existing);
+            }
+            // Clone and insert.
+            Operation *importedOp = importOp->clone();
+            SymbolTable::setSymbolVisibility(importedOp,
+                                             SymbolTable::Visibility::Private);
+            programSymbolTable.insert(importedOp);
+          } else {
+            importOp->emitWarning()
+                << "RTL module has non-importable operation ("
+                << importOp->getName() << "). Skipping.";
+          }
+          return WalkResult::skip();
+        });
+    if (result.wasInterrupted()) return failure();
+    LLVM_DEBUG(llvm::dbgs() << "--- Inlining complete\n");
+    return success();
+  }
+
+  /// Checks that `importOp` may replace the program's `existing` symbol: both
+  /// must be the same kind of operation and `existing` must be a declaration.
+  LogicalResult verifyCanImport(Operation *existing, Operation *importOp) {
+    // Must be the same type of operation.
+    if (existing->getName() != importOp->getName()) {
+      existing->emitError()
+          << "attempt to import RTL operation of different type ("
+          << importOp->getName() << " into " << existing->getName() << ")";
+      return failure();
+    }
+
+    // The program-side symbol must only be declared, not defined.
+    if (auto symbolOp = dyn_cast<SymbolOpInterface>(existing)) {
+      if (!symbolOp.isDeclaration()) {
+        return existing->emitError()
+               << "cannot import a symbol that is already defined";
+      }
+    } else {
+      return existing->emitError() << "cannot import a non-symbol";
+    }
+
+    return success();
+  }
+
+  // The parsed RTL. OwningModuleRef is move-only, while MLIR clones passes by
+  // copy construction; a shared_ptr lets clones share the parsed module.
+  std::shared_ptr<OwningModuleRef> rtlModule;
+
+  // A SymbolTable for each sub module.
+  SmallVector<SymbolTable> importModules;
+
+  // ASM source of RTL modules to link (otherwise will use pass options).
+  Optional<SourceBundle> linkRtlSourceBundle;
+};
+
+}  // namespace
+
+/// Creates the RTL linking pass. When `linkRtlSourceBundle` is empty, the
+/// pass takes the RTL location from its `rtlFile` option at initialization.
+std::unique_ptr<OperationPass<ModuleOp>> PYDM::createLinkIREEPyDMRTLPass(
+    Optional<SourceBundle> linkRtlSourceBundle) {
+  return std::make_unique<LinkIREEPyDMRTLPass>(std::move(linkRtlSourceBundle));
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
new file mode 100644
index 0000000..aee2bf2
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.cpp
@@ -0,0 +1,22 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/PyDM/Transforms/RTL/LinkageAnalysis.h"
+
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "mlir/IR/SymbolTable.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+// Records every PYDM::FuncOp without a body (an extern awaiting linkage).
+// Function bodies are pruned from the pre-order walk.
+LinkageAnalysis::LinkageAnalysis(Operation *moduleOp) {
+  moduleOp->walk<WalkOrder::PreOrder>([&](PYDM::FuncOp funcOp) -> WalkResult {
+    bool isExtern = funcOp.empty();
+    if (isExtern) externFuncOps.push_back(funcOp);
+    return WalkResult::skip();
+  });
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
new file mode 100644
index 0000000..67957a6
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/RTL/LowerToRTLPass.cpp
@@ -0,0 +1,251 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+namespace {
+
+/// Common base of the RTL function descriptors below.
+class RtlFunc {
+ protected:
+  /// Builds `(inputs...) -> (!exception_result, output)`, the signature shape
+  /// shared by RTL functions that may raise.
+  FunctionType makeRaisingSignature(Builder b, ArrayRef<Type> inputs,
+                                    Type output) {
+    SmallVector<Type, 2> resultTypes;
+    resultTypes.push_back(b.getType<PYDM::ExceptionResultType>());
+    resultTypes.push_back(output);
+    return b.getType<FunctionType>(inputs, TypeRange(resultTypes));
+  }
+};
+
+/// Returns the op named `rtlFunc.getRtlName()` in `symbolTable`. If there is
+/// none, a private PYDM::FuncOp with the RTL signature is created and
+/// inserted first.
+template <typename RtlFuncTy>
+Operation *importRtlFunc(SymbolTable &symbolTable, RtlFuncTy rtlFunc) {
+  Operation *tableOp = symbolTable.getOp();
+  OpBuilder builder(tableOp->getContext());
+  StringAttr name = builder.getStringAttr(rtlFunc.getRtlName());
+  if (Operation *existing = symbolTable.lookup(name)) return existing;
+
+  // Absent: build the function detached, then insert it into the table.
+  FunctionType signature = rtlFunc.getRtlSignature(builder);
+  OperationState state(tableOp->getLoc(), PYDM::FuncOp::getOperationName());
+  PYDM::FuncOp::build(builder, state, name, signature);
+  Operation *newFunc = Operation::create(state);
+  SymbolTable::setSymbolVisibility(newFunc, SymbolTable::Visibility::Private);
+  symbolTable.insert(newFunc);
+  return newFunc;
+}
+
+/// pydmrtl$object_as_bool: (object) -> (exception_result, bool).
+struct ObjectAsBoolFunc : public RtlFunc {
+  StringRef getRtlName() { return "pydmrtl$object_as_bool"; }
+  FunctionType getRtlSignature(Builder b) {
+    Type objectType = b.getType<PYDM::ObjectType>(nullptr);
+    Type boolType = b.getType<PYDM::BoolType>();
+    return makeRaisingSignature(b, {objectType}, boolType);
+  }
+};
+
+/// pydmrtl$dynamic_binary_promote:
+/// (object, object) -> (exception_result, tuple).
+struct DynamicBinaryPromoteFunc : public RtlFunc {
+  StringRef getRtlName() { return "pydmrtl$dynamic_binary_promote"; }
+  FunctionType getRtlSignature(Builder b) {
+    Type objectType = b.getType<PYDM::ObjectType>(nullptr);
+    Type tupleType = b.getType<PYDM::TupleType>();
+    return makeRaisingSignature(b, {objectType, objectType}, tupleType);
+  }
+};
+
+/// pydmrtl$apply_binary_${dunderName} RTL func:
+/// (object, object) -> (exception_result, object).
+class ApplyBinaryFunc : public RtlFunc {
+ public:
+  // Explicit: a bare dunder name should not silently convert to an RTL func.
+  explicit ApplyBinaryFunc(StringRef dunderName)
+      : rtlName(("pydmrtl$apply_binary_" + dunderName).str()) {}
+  StringRef getRtlName() { return rtlName; }
+  FunctionType getRtlSignature(Builder b) {
+    Type objectType = b.getType<PYDM::ObjectType>(nullptr);
+    return makeRaisingSignature(b, {objectType, objectType}, objectType);
+  }
+
+ private:
+  std::string rtlName;
+};
+
+/// pydmrtl$apply_compare_${dunderName} RTL func:
+/// (object, object) -> (exception_result, bool).
+class ApplyCompareFunc : public RtlFunc {
+ public:
+  // Explicit: a bare dunder name should not silently convert to an RTL func.
+  explicit ApplyCompareFunc(StringRef dunderName)
+      : rtlName(("pydmrtl$apply_compare_" + dunderName).str()) {}
+  StringRef getRtlName() { return rtlName; }
+  FunctionType getRtlSignature(Builder b) {
+    Type objectType = b.getType<PYDM::ObjectType>(nullptr);
+    Type boolType = b.getType<PYDM::BoolType>();
+    return makeRaisingSignature(b, {objectType, objectType}, boolType);
+  }
+
+ private:
+  std::string rtlName;
+};
+
+/// Base for patterns that replace an op with a call to an RTL function.
+///
+/// `RtlFuncTy` supplies the callee name and signature. The callee is added
+/// to `symbolTable` on first use (see importRtlFunc), and every emitted
+/// PYDM::CallOp is followed by a PYDM::RaiseOnFailureOp on its exception
+/// result.
+template <typename RtlFuncTy, typename OpTy>
+class EmitImportCallBase : public OpRewritePattern<OpTy> {
+ public:
+  explicit EmitImportCallBase(SymbolTable &symbolTable,
+                              PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>::OpRewritePattern(
+            symbolTable.getOp()->getContext(), benefit),
+        symbolTable(symbolTable) {}
+
+ protected:
+  /// Emits a call to `rtlFunc` with `inputs`. Any non-object input passed
+  /// where the signature expects an object is boxed first. Raises if the
+  /// call reports failure and returns the call's `result`.
+  Value emitImportCall(Location loc, ValueRange inputs, RtlFuncTy rtlFunc,
+                       PatternRewriter &rewriter) const {
+    StringRef rtlName = rtlFunc.getRtlName();
+    importRtlFunc<RtlFuncTy>(symbolTable, rtlFunc);
+    FunctionType signature = rtlFunc.getRtlSignature(rewriter);
+    // zip() below would silently truncate on an arity mismatch.
+    assert(inputs.size() == signature.getNumInputs() &&
+           "input count does not match the RTL function signature");
+    // A symbol reference is an attribute, so construct it as one.
+    auto symbolRef = FlatSymbolRefAttr::get(rewriter.getContext(), rtlName);
+
+    // Perform simple conversions on inputs.
+    SmallVector<Value> convertedInputs;
+    convertedInputs.reserve(inputs.size());
+    for (auto it : llvm::zip(inputs, signature.getInputs())) {
+      Value input = std::get<0>(it);
+      Type expectedType = std::get<1>(it);
+      // Detect boxing.
+      if (expectedType.isa<PYDM::ObjectType>() &&
+          !input.getType().isa<PYDM::ObjectType>()) {
+        input = rewriter.create<PYDM::BoxOp>(loc, expectedType, input);
+      }
+      convertedInputs.push_back(input);
+    }
+
+    auto callOp = rewriter.create<PYDM::CallOp>(loc, signature.getResults(),
+                                                symbolRef, convertedInputs);
+    rewriter.create<PYDM::RaiseOnFailureOp>(loc, callOp.exc_result());
+    return callOp.result();
+  }
+
+  /// Replaces `op` with a call to `rtlFunc`. A single-result op takes the call
+  /// result directly. A multi-result op is replaced by the slots of a
+  /// PYDM::DynamicUnpackOp on the call result, raising if unpacking fails.
+  void replaceOpWithCall(Operation *op, ValueRange inputs, RtlFuncTy rtlFunc,
+                         PatternRewriter &rewriter) const {
+    // Check before emitting any IR.
+    assert(op->getNumResults() != 0 && "expected op with results");
+    Value callResult =
+        emitImportCall(op->getLoc(), inputs, std::move(rtlFunc), rewriter);
+    if (op->getNumResults() == 1) {
+      // No unpack.
+      rewriter.replaceOp(op, {callResult});
+      return;
+    }
+
+    // Unpack 1 -> N.
+    SmallVector<Type> unpackTypes = {
+        rewriter.getType<PYDM::ExceptionResultType>()};
+    unpackTypes.append(op->getResultTypes().begin(),
+                       op->getResultTypes().end());
+    auto unpackOp = rewriter.create<PYDM::DynamicUnpackOp>(
+        op->getLoc(), unpackTypes, callResult);
+    rewriter.create<PYDM::RaiseOnFailureOp>(op->getLoc(),
+                                            unpackOp.exc_result());
+    rewriter.replaceOp(op, unpackOp.slots());
+  }
+
+ private:
+  SymbolTable &symbolTable;
+};
+
+/// Lowers a PYDM::ApplyBinaryOp with two object operands to a call to the
+/// matching pydmrtl$apply_binary_<dunder> RTL function.
+struct ApplyBinaryPattern
+    : public EmitImportCallBase<ApplyBinaryFunc, PYDM::ApplyBinaryOp> {
+  using EmitImportCallBase::EmitImportCallBase;
+
+  LogicalResult matchAndRewrite(PYDM::ApplyBinaryOp srcOp,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = srcOp.left();
+    Value rhs = srcOp.right();
+    bool bothObjects = lhs.getType().isa<PYDM::ObjectType>() &&
+                       rhs.getType().isa<PYDM::ObjectType>();
+    if (!bothObjects)
+      return rewriter.notifyMatchFailure(srcOp, "not (object, object) variant");
+
+    ApplyBinaryFunc rtlFunc(srcOp.dunder_name());
+    replaceOpWithCall(srcOp, {lhs, rhs}, std::move(rtlFunc), rewriter);
+    return success();
+  }
+};
+
+/// Lowers a PYDM::ApplyCompareOp with two object operands to a call to the
+/// matching pydmrtl$apply_compare_<dunder> RTL function.
+struct ApplyComparePattern
+    : public EmitImportCallBase<ApplyCompareFunc, PYDM::ApplyCompareOp> {
+  using EmitImportCallBase::EmitImportCallBase;
+
+  LogicalResult matchAndRewrite(PYDM::ApplyCompareOp srcOp,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = srcOp.left();
+    Value rhs = srcOp.right();
+    // Only the object-object comparison is lowered to the RTL.
+    bool bothObjects = lhs.getType().isa<PYDM::ObjectType>() &&
+                       rhs.getType().isa<PYDM::ObjectType>();
+    if (!bothObjects)
+      return rewriter.notifyMatchFailure(srcOp, "not (object, object) variant");
+
+    ApplyCompareFunc rtlFunc(srcOp.dunder_name());
+    replaceOpWithCall(srcOp, {lhs, rhs}, std::move(rtlFunc), rewriter);
+    return success();
+  }
+};
+
+/// Lowers every PYDM::DynamicBinaryPromoteOp to a call to
+/// pydmrtl$dynamic_binary_promote.
+struct DynamicBinaryPromotePattern
+    : public EmitImportCallBase<DynamicBinaryPromoteFunc,
+                                PYDM::DynamicBinaryPromoteOp> {
+  using EmitImportCallBase::EmitImportCallBase;
+
+  LogicalResult matchAndRewrite(PYDM::DynamicBinaryPromoteOp srcOp,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = srcOp.left();
+    Value rhs = srcOp.right();
+    replaceOpWithCall(srcOp, {lhs, rhs}, DynamicBinaryPromoteFunc(), rewriter);
+    return success();
+  }
+};
+
+/// Lowers a PYDM::AsBoolOp on an object-typed value to a call to
+/// pydmrtl$object_as_bool.
+struct ObjectAsBoolPattern
+    : public EmitImportCallBase<ObjectAsBoolFunc, PYDM::AsBoolOp> {
+  using EmitImportCallBase::EmitImportCallBase;
+
+  LogicalResult matchAndRewrite(PYDM::AsBoolOp srcOp,
+                                PatternRewriter &rewriter) const override {
+    Value operand = srcOp.value();
+    if (!operand.getType().isa<PYDM::ObjectType>())
+      return rewriter.notifyMatchFailure(srcOp, "not an !object<>");
+    replaceOpWithCall(srcOp, {operand}, ObjectAsBoolFunc(), rewriter);
+    return success();
+  }
+};
+
+/// Rewrites object-typed PyDM ops into calls to RTL functions, adding the RTL
+/// function symbols to the module as the patterns need them.
+struct LowerIREEPyDMToRTLPass
+    : public LowerIREEPyDMToRTLBase<LowerIREEPyDMToRTLPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mlir::iree_compiler::IREE::Input::IREEInputDialect,
+                    BuiltinDialect, StandardOpsDialect>();
+  }
+
+  void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
+    // Shared by all patterns; they insert RTL declarations through it.
+    SymbolTable symbolTable(moduleOp);
+
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<ApplyBinaryPattern, ApplyComparePattern,
+                    DynamicBinaryPromotePattern, ObjectAsBoolPattern>(
+        symbolTable);
+
+    GreedyRewriteConfig config;
+    LogicalResult converged =
+        applyPatternsAndFoldGreedily(moduleOp, std::move(patterns), config);
+    if (succeeded(converged)) return;
+
+    emitError(moduleOp.getLoc()) << "did not converge while lowering to rtl";
+    signalPassFailure();
+  }
+};
+
+}  // namespace
+
+/// Creates the module pass that rewrites object-typed PyDM ops into RTL calls.
+std::unique_ptr<OperationPass<ModuleOp>> PYDM::createLowerIREEPyDMToRTLPass() {
+  return std::make_unique<LowerIREEPyDMToRTLPass>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt
new file mode 100644
index 0000000..4c8a7b4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_library(IREEPyDMToIREEPasses
+  ConversionPass.cpp
+  LoweringPatterns.cpp
+  TypeConverter.cpp
+
+  DEPENDS
+  IREEPyDMTransformsPassesIncGen
+
+  LINK_LIBS PUBLIC
+  IREEInputDialect
+  IREEPyDMDialect
+  MLIRArithmetic
+  MLIRIR
+  MLIRStandard
+  MLIRTransformUtils
+)
+
+iree_dialects_target_includes(IREEPyDMToIREEPasses)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
new file mode 100644
index 0000000..f49d0b1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/ConversionPass.cpp
@@ -0,0 +1,75 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "../PassDetail.h"
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/Passes.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinDialect.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+namespace {
+
+/// Module pass that converts all PyDM dialect ops to the IREE input dialect
+/// plus builtin/arith/math/std ops, using `LoweringTypeConverter` for types.
+/// Branch, conditional-branch and select ops remain only if the type
+/// converter considers all of their operand types legal.
+struct ConvertIREEPyDMToIREEPass
+    : public ConvertIREEPyDMToIREEBase<ConvertIREEPyDMToIREEPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The lowering patterns create arith ops (constants, integer/float
+    // arithmetic, comparisons) and the target below marks arith legal, so the
+    // arith dialect must be declared as a dependency too.
+    registry.insert<mlir::iree_compiler::IREE::Input::IREEInputDialect,
+                    BuiltinDialect, StandardOpsDialect,
+                    mlir::arith::ArithmeticDialect, math::MathDialect>();
+  }
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+    auto moduleOp = getOperation();
+    LoweringTypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    populatePyDMToIREELoweringPatterns(context, typeConverter, patterns);
+
+    ConversionTarget target(*context);
+    target.addIllegalDialect<IREEPyDMDialect>();
+    target.addLegalDialect<BuiltinDialect>();
+    target
+        .addLegalDialect<mlir::iree_compiler::IREE::Input::IREEInputDialect>();
+    target.addLegalDialect<mlir::arith::ArithmeticDialect>();
+    target.addLegalDialect<mlir::math::MathDialect>();
+    target.addLegalDialect<mlir::StandardOpsDialect>();
+
+    // Some CFG ops can be present in the original pydm program. Need to
+    // verify legality based on types.
+    target.addDynamicallyLegalOp<BranchOp>([&](mlir::BranchOp op) -> bool {
+      return typeConverter.areTypesLegal(op.getOperandTypes());
+    });
+    target.addDynamicallyLegalOp<CondBranchOp>(
+        [&](mlir::CondBranchOp op) -> bool {
+          return typeConverter.areTypesLegal(op.getOperandTypes());
+        });
+
+    // Standard select can be emitted as part of CFG canonicalization.
+    target.addDynamicallyLegalOp<mlir::SelectOp>(
+        [&](mlir::SelectOp op) -> bool {
+          return typeConverter.areTypesLegal(op.getOperandTypes());
+        });
+
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+/// Factory for the module pass that converts PyDM to the IREE input dialect.
+std::unique_ptr<OperationPass<ModuleOp>>
+PYDM::createConvertIREEPyDMToIREEPass() {
+  std::unique_ptr<OperationPass<ModuleOp>> pass =
+      std::make_unique<ConvertIREEPyDMToIREEPass>();
+  return pass;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
new file mode 100644
index 0000000..0d5317d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/LoweringPatterns.cpp
@@ -0,0 +1,1288 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMOps.h"
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/Patterns.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using llvm::enumerate;
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+namespace Input = mlir::iree_compiler::IREE::Input;
+using namespace PYDM;
+
+namespace {
+
+/// Status codes used for the i32 exception-status values this lowering
+/// produces (see getSuccessStatusValue / getFailureStatusValue). `Success`
+/// (0) means no exception was raised. Each negative value identifies the
+/// Python exception class of the same name.
+/// NOTE(review): these numeric values are presumably mirrored by whatever
+/// consumes the status at runtime — confirm before renumbering.
+enum class ExceptionCode : int {
+  Success = 0,
+  StopIteration = -1,
+  StopAsyncIteration = -2,
+  RuntimeError = -3,
+  ValueError = -4,
+  NotImplementedError = -5,
+  KeyError = -6,
+  IndexError = -7,
+  AttributeError = -8,
+  TypeError = -9,
+  UnboundLocalError = -10,
+};
+
+}  // namespace
+
+/// Returns the Input-dialect list type whose elements are variants. This is
+/// the storage type of the object/variable lists created in this file.
+static Type getVariantListType(Builder &builder) {
+  auto variantType = builder.getType<Input::VariantType>();
+  return builder.getType<Input::ListType>(variantType);
+}
+
+/// Materializes a placeholder value for type `t`.
+///
+/// For `Input::ListType`, a freshly created variant list (no capacity
+/// operand) is returned. For any other type, an `arith.constant` holding
+/// `t`'s zero attribute is returned; this asserts if `t` has no zero
+/// attribute.
+static Value getNullValue(Location loc, OpBuilder &builder, Type t) {
+  if (t.isa<Input::ListType>()) {
+    // TODO: If it becomes important to optimize this, come up with a way
+    // to return an empty list without creating one.
+    return builder.create<Input::ListCreateOp>(
+        loc, getVariantListType(builder), /*capacity=*/nullptr);
+  }
+
+  Attribute zeroAttr = builder.getZeroAttr(t);
+  assert(zeroAttr && "could not get zero attr for builtin type");
+  return builder.create<arith::ConstantOp>(loc, t, zeroAttr);
+}
+
+/// Appends a new, argument-less block at the end of the region that contains
+/// the builder's current insertion block, and returns it. It is used for
+/// rarely-taken paths (e.g. index/arity failures) that are branched to from
+/// the current block.
+static Block *createSlowPathBlock(OpBuilder &builder) {
+  Region &enclosingRegion = *builder.getInsertionBlock()->getParent();
+  return builder.createBlock(&enclosingRegion, enclosingRegion.end());
+}
+
+/// Creates an i32 `arith.constant` holding `ExceptionCode::Success` (0).
+static Value getSuccessStatusValue(Location loc, OpBuilder &builder) {
+  const int successCode = static_cast<int>(ExceptionCode::Success);
+  return builder.create<arith::ConstantOp>(
+      loc, builder.getI32IntegerAttr(successCode));
+}
+
+/// Creates an i32 `arith.constant` holding the status value of `code`.
+static Value getFailureStatusValue(Location loc, OpBuilder &builder,
+                                   ExceptionCode code) {
+  IntegerAttr codeAttr = builder.getI32IntegerAttr(static_cast<int>(code));
+  return builder.create<arith::ConstantOp>(loc, codeAttr);
+}
+
+/// Creates a new, uninitialized variant list with no capacity operand.
+/// `createObjectList` fills one in via `resetObjectList`.
+static Value createUndefObjectList(Location loc, OpBuilder &builder) {
+  Type listType = getVariantListType(builder);
+  return builder.create<Input::ListCreateOp>(loc, listType,
+                                             /*capacity=*/nullptr);
+}
+
+/// Overwrites `list` so that it encodes a boxed object.
+///
+/// The list is resized to exactly two slots. Slot 0 is set to `typeCode`
+/// (materialized as an i32 constant) and slot 1 is set to `data`.
+/// This helper is file-local (`static`), like every other helper here.
+static void resetObjectList(Location loc, OpBuilder &builder, Value list,
+                            int typeCode, Value data) {
+  // Note: The list can record optional runtime state at positions > 1, so
+  // to truly reset, we have to resize. Low level optimizations should be able
+  // to elide this if it turns out to be unnecessary.
+  auto size = builder.create<arith::ConstantIndexOp>(loc, 2);
+  builder.create<Input::ListResizeOp>(loc, list, size);
+  auto index0 = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value typeCodeValue = builder.create<arith::ConstantOp>(
+      loc, builder.getI32IntegerAttr(typeCode));
+  builder.create<Input::ListSetOp>(loc, list, index0, typeCodeValue);
+  auto index1 = builder.create<arith::ConstantIndexOp>(loc, 1);
+  builder.create<Input::ListSetOp>(loc, list, index1, data);
+}
+
+/// Creates a new object list holding `typeCode` in slot 0 and `data` in
+/// slot 1.
+static Value createObjectList(Location loc, OpBuilder &builder, int typeCode,
+                              Value data) {
+  Value objectList = createUndefObjectList(loc, builder);
+  resetObjectList(loc, builder, objectList, typeCode, data);
+  return objectList;
+}
+
+/// Casts the integer-typed `input` to `resultType`.
+///
+/// The value is returned unchanged when the widths match. It is
+/// sign-extended (`arith.extsi`) when widening and truncated
+/// (`arith.trunci`) when narrowing.
+static Value castIntegerValue(Location loc, Value input,
+                              mlir::IntegerType resultType,
+                              OpBuilder &builder) {
+  unsigned fromWidth = input.getType().cast<mlir::IntegerType>().getWidth();
+  unsigned toWidth = resultType.getWidth();
+  if (fromWidth > toWidth)
+    return builder.create<arith::TruncIOp>(loc, resultType, input);
+  if (fromWidth < toWidth)
+    return builder.create<arith::ExtSIOp>(loc, resultType, input);
+  return input;
+}
+
+/// Maps a Python comparison name without underscores (e.g. "lt") to an
+/// `arith.cmpi` predicate.
+///
+/// Ordering predicates are signed or unsigned according to `isSigned`.
+/// Identity tests ("is"/"isnot") are treated as value (in)equality.
+/// Returns None for unsupported names.
+static Optional<arith::CmpIPredicate> convertIntegerComparePredicate(
+    StringAttr dunderName, bool isSigned, Builder &builder) {
+  using Pred = arith::CmpIPredicate;
+  StringRef name = dunderName.getValue();
+  if (name == "eq" || name == "is") return Pred::eq;
+  if (name == "ne" || name == "isnot") return Pred::ne;
+  if (name == "lt") return isSigned ? Pred::slt : Pred::ult;
+  if (name == "le") return isSigned ? Pred::sle : Pred::ule;
+  if (name == "gt") return isSigned ? Pred::sgt : Pred::ugt;
+  if (name == "ge") return isSigned ? Pred::sge : Pred::uge;
+  return llvm::None;
+}
+
+/// Maps a Python comparison name without underscores (e.g. "lt") to an
+/// `arith.cmpf` predicate. Returns None for unsupported names.
+///
+/// Ordering and equality use *ordered* predicates, which yield false when
+/// either operand is NaN; this matches Python float semantics.
+///
+/// Inequality ("ne"/"isnot") uses the *unordered* UNE predicate, so it is
+/// the exact negation of "eq"/"is". In Python `nan != nan` is True, whereas
+/// the ordered ONE predicate would yield false.
+static Optional<arith::CmpFPredicate> convertFpComparePredicate(
+    StringAttr dunderName, Builder &builder) {
+  StringRef v = dunderName.getValue();
+  if (v == "lt") {
+    return arith::CmpFPredicate::OLT;
+  } else if (v == "le") {
+    return arith::CmpFPredicate::OLE;
+  } else if (v == "eq" || v == "is") {
+    return arith::CmpFPredicate::OEQ;
+  } else if (v == "ne" || v == "isnot") {
+    // Unordered: true when either operand is NaN (negation of OEQ).
+    return arith::CmpFPredicate::UNE;
+  } else if (v == "gt") {
+    return arith::CmpFPredicate::OGT;
+  } else if (v == "ge") {
+    return arith::CmpFPredicate::OGE;
+  }
+
+  return {};
+}
+
+/// Boxes `convertedValue` (already type-converted) according to its original
+/// Python type `pythonType`.
+///
+/// - `!object` types are already boxed, so `convertedValue` is returned
+///   unchanged.
+/// - Types implementing PythonTypeInterface are wrapped in a new object list
+///   tagged with the interface's type code.
+/// - Any other type yields a null Value, and no diagnostic is emitted.
+static Value boxConvertedValue(Location loc, Type pythonType,
+                               Value convertedValue, OpBuilder &builder) {
+  if (pythonType.isa<PYDM::ObjectType>()) return convertedValue;
+
+  if (auto pyType = pythonType.dyn_cast<PYDM::PythonTypeInterface>()) {
+    int typeCode = static_cast<int>(pyType.getTypeCode());
+    return createObjectList(loc, builder, typeCode, convertedValue);
+  }
+  return {};
+}
+
+namespace {
+
+/// Lowers `PYDM::AllocFreeVarOp` to a new, uninitialized variant list that
+/// serves as the variable's backing storage.
+class AllocFreeVarOpConversion
+    : public OpConversionPattern<PYDM::AllocFreeVarOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::AllocFreeVarOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // TODO: We may want to initialize the list structurally in some way.
+    // This will fail either way on read from unassigned variable, but we need
+    // to see what works better for good runtime error messages.
+    Value storage = createUndefObjectList(srcOp.getLoc(), rewriter);
+    rewriter.replaceOp(srcOp, storage);
+    return success();
+  }
+};
+
+/// Lowers `PYDM::ApplyBinaryOp` to an arith op when both operands share one
+/// numeric PyDM type and the converted result type matches the converted
+/// operand type.
+///
+/// Integer ops take their signedness from the PyDM integer type. Floats
+/// support only add/mul/sub. Anything else is reported as a match failure.
+class ApplyBinaryNumericConversion
+    : public OpConversionPattern<PYDM::ApplyBinaryOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::ApplyBinaryOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.left();
+    Value rhs = adaptor.right();
+    Type pyOperandType = srcOp.left().getType();
+    Type loweredType = lhs.getType();
+    Type loweredResultType =
+        typeConverter->convertType(srcOp.result().getType());
+    bool uniformTypes = loweredResultType &&
+                        pyOperandType == srcOp.right().getType() &&
+                        loweredType == rhs.getType() &&
+                        loweredType == loweredResultType;
+    if (!uniformTypes) {
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "not same type operands/results");
+    }
+
+    Location loc = srcOp.getLoc();
+    StringRef opName = adaptor.dunder_name().getValue();
+    Value lowered;
+    if (auto pyIntType = pyOperandType.dyn_cast<PYDM::IntegerType>()) {
+      lowered = convertIntegerOp(loc, opName, lhs, rhs, pyIntType.isSigned(),
+                                 rewriter);
+    } else if (loweredType.isa<mlir::FloatType>()) {
+      lowered = convertFloatOp(loc, opName, lhs, rhs, rewriter);
+    } else {
+      return rewriter.notifyMatchFailure(srcOp, "non numeric type");
+    }
+
+    if (!lowered)
+      return rewriter.notifyMatchFailure(srcOp, "unsupported operation");
+    rewriter.replaceOp(srcOp, lowered);
+    return success();
+  }
+
+  /// Builds the arith op for the integer operator `dunderName`, or returns
+  /// null if unsupported. Right shift is arithmetic when `isSigned`, else
+  /// logical.
+  Value convertIntegerOp(Location loc, StringRef dunderName, Value left,
+                         Value right, bool isSigned,
+                         ConversionPatternRewriter &rewriter) const {
+    // TODO: matmul, truediv, floordiv, mod, divmod, pow
+    if (dunderName == "add")
+      return rewriter.create<arith::AddIOp>(loc, left, right);
+    if (dunderName == "and")
+      return rewriter.create<arith::AndIOp>(loc, left, right);
+    if (dunderName == "mul")
+      return rewriter.create<arith::MulIOp>(loc, left, right);
+    if (dunderName == "lshift")
+      return rewriter.create<arith::ShLIOp>(loc, left, right);
+    if (dunderName == "or")
+      return rewriter.create<arith::OrIOp>(loc, left, right);
+    if (dunderName == "rshift") {
+      if (!isSigned) return rewriter.create<arith::ShRUIOp>(loc, left, right);
+      return rewriter.create<arith::ShRSIOp>(loc, left, right);
+    }
+    if (dunderName == "sub")
+      return rewriter.create<arith::SubIOp>(loc, left, right);
+    if (dunderName == "xor")
+      return rewriter.create<arith::XOrIOp>(loc, left, right);
+    return nullptr;
+  }
+
+  /// Builds the arith op for the float operator `dunderName`, or returns
+  /// null if unsupported.
+  Value convertFloatOp(Location loc, StringRef dunderName, Value left,
+                       Value right, ConversionPatternRewriter &rewriter) const {
+    // TODO: matmul, truediv, floordiv, mod, divmod, pow
+    if (dunderName == "add")
+      return rewriter.create<arith::AddFOp>(loc, left, right);
+    if (dunderName == "mul")
+      return rewriter.create<arith::MulFOp>(loc, left, right);
+    if (dunderName == "sub")
+      return rewriter.create<arith::SubFOp>(loc, left, right);
+    return nullptr;
+  }
+};
+
+/// Lowers `PYDM::ApplyCompareOp` on two operands of the same converted
+/// numeric type. Floats lower to `arith.cmpf` and integers to `arith.cmpi`
+/// (currently always signed).
+class ApplyCompareNumericConversion
+    : public OpConversionPattern<PYDM::ApplyCompareOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::ApplyCompareOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.left();
+    Value rhs = adaptor.right();
+    Type operandType = lhs.getType();
+    if (operandType != rhs.getType())
+      return rewriter.notifyMatchFailure(srcOp, "not same type operands");
+
+    if (operandType.isa<mlir::FloatType>()) {
+      auto fpPredicate =
+          convertFpComparePredicate(adaptor.dunder_name(), rewriter);
+      if (!fpPredicate)
+        return rewriter.notifyMatchFailure(srcOp, "unsupported predicate");
+      rewriter.replaceOpWithNewOp<arith::CmpFOp>(srcOp, *fpPredicate, lhs,
+                                                 rhs);
+      return success();
+    }
+
+    if (!operandType.isa<mlir::IntegerType>())
+      return rewriter.notifyMatchFailure(srcOp, "non numeric type");
+
+    const bool isSigned = true;  // TODO: Unsigned.
+    auto intPredicate = convertIntegerComparePredicate(adaptor.dunder_name(),
+                                                       isSigned, rewriter);
+    if (!intPredicate)
+      return rewriter.notifyMatchFailure(srcOp, "unsupported predicate");
+    rewriter.replaceOpWithNewOp<arith::CmpIOp>(srcOp, *intPredicate, lhs, rhs);
+    return success();
+  }
+};
+
+/// Lowers `lhs[slice] = rhs` where `lhs` is a builtin PyDM list and `slice`
+/// is a PyDM integer.
+///
+/// Negative indices are normalized by adding the list size. An out-of-range
+/// index produces `ExceptionCode::IndexError` as the status result.
+/// Otherwise the value (boxed if the list's storage class requires it) is
+/// stored and `Success` is produced. Emitted control flow:
+///   entry:        cond_br (slice < 0), ltZero, check(slice)
+///   ltZero:       br check(slice + size)
+///   check(i):     cond_br (i <u size), set(i), failure
+///   set(i):       list[i] = value; br continuation(Success)
+///   failure:      br continuation(IndexError)  (block at end of region)
+///   continuation(status): remainder of the original block
+class AssignSubscriptListConversion
+    : public OpConversionPattern<PYDM::AssignSubscriptOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      PYDM::AssignSubscriptOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto pySequence = srcOp.lhs();
+    if (!pySequence.getType().isa<PYDM::ListType>())
+      return rewriter.notifyMatchFailure(srcOp, "not builtin sequence");
+    auto pySlice = srcOp.slice();
+    if (!pySlice.getType().isa<PYDM::IntegerType>())
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "slice is not static integer type");
+
+    auto loc = srcOp.getLoc();
+    auto sequence = adaptor.lhs();
+    auto slice = adaptor.slice();
+    auto indexType = rewriter.getType<IndexType>();
+    Type statusType =
+        getTypeConverter()->convertType(srcOp.exc_result().getType());
+    // Fail before creating any IR; a null type must never become a block
+    // argument below.
+    if (!statusType)
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "could not convert exc_result type");
+    Value valueToSet =
+        boxIfNecessary(loc, pySequence.getType().cast<PYDM::ListType>(),
+                       srcOp.rhs().getType(), adaptor.rhs(), rewriter);
+    if (!valueToSet) {
+      return rewriter.notifyMatchFailure(
+          srcOp, "unsupported list assignment boxing mode");
+    }
+
+    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, slice.getType());
+    Value listSizeIndex =
+        rewriter.create<Input::ListSizeOp>(loc, indexType, sequence);
+    Value listSizeInteger = rewriter.create<arith::IndexCastOp>(
+        loc, slice.getType(), listSizeIndex);
+    Block *entryBlock = rewriter.getInsertionBlock();
+    Block *continuationBlock = rewriter.splitBlock(
+        rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
+    Block *indexLtZeroBlock = rewriter.createBlock(continuationBlock);
+    Block *indexCheckBlock = rewriter.createBlock(continuationBlock);
+    indexCheckBlock->addArgument(indexType);
+    Block *setElementBlock = rewriter.createBlock(continuationBlock);
+    setElementBlock->addArgument(indexType);
+    Block *failureBlock = createSlowPathBlock(rewriter);
+    continuationBlock->addArgument(statusType);
+    rewriter.replaceOp(srcOp, continuationBlock->getArguments());
+
+    // Comparison index < 0.
+    {
+      rewriter.setInsertionPointToEnd(entryBlock);
+      Value ltZero = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, slice, zero);
+      auto sliceIndex =
+          rewriter.create<arith::IndexCastOp>(loc, indexType, slice);
+      rewriter.create<mlir::CondBranchOp>(loc, ltZero, indexLtZeroBlock,
+                                          indexCheckBlock,
+                                          ValueRange{sliceIndex});
+    }
+
+    // Handle index < 0.
+    {
+      rewriter.setInsertionPointToEnd(indexLtZeroBlock);
+      Value positiveSlice =
+          rewriter.create<arith::AddIOp>(loc, slice, listSizeInteger);
+      Value positiveSliceIndex =
+          rewriter.create<arith::IndexCastOp>(loc, indexType, positiveSlice);
+      rewriter.create<mlir::BranchOp>(loc, ValueRange{positiveSliceIndex},
+                                      indexCheckBlock);
+    }
+
+    // Index check. The unsigned comparison also rejects indices that are
+    // still negative after normalization (slice < -size).
+    {
+      rewriter.setInsertionPointToEnd(indexCheckBlock);
+      Value sliceIndex = indexCheckBlock->getArgument(0);
+      Value ltSize = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ult, sliceIndex, listSizeIndex);
+      rewriter.create<mlir::CondBranchOp>(loc, ltSize, setElementBlock,
+                                          ValueRange{sliceIndex}, failureBlock,
+                                          ValueRange{});
+    }
+
+    // Set element.
+    {
+      rewriter.setInsertionPointToEnd(setElementBlock);
+      Value successResult = getSuccessStatusValue(loc, rewriter);
+      rewriter.create<Input::ListSetOp>(
+          loc, sequence, setElementBlock->getArgument(0), valueToSet);
+      rewriter.create<mlir::BranchOp>(loc, continuationBlock,
+                                      ValueRange{successResult});
+    }
+
+    // Failure.
+    {
+      rewriter.setInsertionPointToEnd(failureBlock);
+      Value failureResult =
+          getFailureStatusValue(loc, rewriter, ExceptionCode::IndexError);
+      rewriter.create<mlir::BranchOp>(loc, continuationBlock,
+                                      ValueRange{failureResult});
+    }
+
+    return success();
+  }
+
+  /// Returns `rhs` in the representation that `listType`'s storage class
+  /// expects, or null when that storage class is not supported yet.
+  Value boxIfNecessary(Location loc, PYDM::ListType listType, Type origRhsType,
+                       Value rhs, ConversionPatternRewriter &rewriter) const {
+    switch (listType.getStorageClass()) {
+      case CollectionStorageClass::Boxed:
+      case CollectionStorageClass::Empty:
+        return boxConvertedValue(loc, origRhsType, rhs, rewriter);
+      case CollectionStorageClass::Unboxed:
+        // TODO: Implement.
+        return nullptr;
+    }
+    // Not reached for valid storage classes. Returning null (rather than
+    // falling off the end of a non-void function) makes the caller report a
+    // match failure.
+    return nullptr;
+  }
+};
+
+/// Elides `PYDM::BoolToPredOp`: the type-converted operand is used directly
+/// as the result (presumably bool already lowers to a predicate-like type).
+class BoolToPredConversion : public OpConversionPattern<PYDM::BoolToPredOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::BoolToPredOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value predicate = adaptor.value();
+    rewriter.replaceOp(srcOp, predicate);
+    return success();
+  }
+};
+
+/// Lowers `PYDM::BoxOp` by boxing the converted primitive via
+/// boxConvertedValue. Unsupported primitive types are a match failure.
+class BoxOpConversion : public OpConversionPattern<PYDM::BoxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::BoxOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type pyPrimitiveType = srcOp.primitive().getType();
+    Value boxed = boxConvertedValue(srcOp.getLoc(), pyPrimitiveType,
+                                    adaptor.primitive(), rewriter);
+    if (boxed) {
+      rewriter.replaceOp(srcOp, boxed);
+      return success();
+    }
+    return rewriter.notifyMatchFailure(srcOp,
+                                       "not a supported type for boxing");
+  }
+};
+
+/// Lowers `PYDM::CallOp` to `mlir::CallOp`. The callee is kept, result types
+/// are converted, and the type-converted operands are used.
+class CallOpConversion : public OpConversionPattern<PYDM::CallOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::CallOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> loweredResultTypes;
+    LogicalResult converted = getTypeConverter()->convertTypes(
+        srcOp.getResultTypes(), loweredResultTypes);
+    if (failed(converted)) {
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "result types could not be converted");
+    }
+    rewriter.replaceOpWithNewOp<mlir::CallOp>(
+        srcOp, srcOp.callee(), loweredResultTypes, adaptor.operands());
+    return success();
+  }
+};
+
+/// Lowers `PYDM::ConstantOp` to `arith.constant` of the converted type.
+///
+/// Integer payloads are sign-extended/truncated to the converted width, and
+/// float payloads are converted to the converted float semantics (warning on
+/// precision loss). A stored attribute of the wrong kind for the converted
+/// type is reported as a match failure instead of tripping a `cast<>`.
+class ConstantOpConversion : public OpConversionPattern<PYDM::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::ConstantOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = srcOp.getLoc();
+    Type resultType = typeConverter->convertType(srcOp.getResult().getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(
+          srcOp, "constant type could not be converted");
+    Attribute newValue = adaptor.value();
+    // Fixup widths of integer types that may be wider/narrower than the
+    // stored attribute (which tends to be stored in high precision in pydm
+    // constants). A mismatched attribute kind clears `newValue`, which is
+    // turned into a match failure below.
+    TypeSwitch<Type>(resultType)
+        .Case([&](mlir::IntegerType t) {
+          auto intAttr = newValue.dyn_cast<IntegerAttr>();
+          if (!intAttr) {
+            newValue = Attribute();
+            return;
+          }
+          APInt intValue = intAttr.getValue().sextOrTrunc(t.getWidth());
+          newValue = rewriter.getIntegerAttr(t, intValue);
+        })
+        .Case([&](mlir::FloatType t) {
+          auto fpAttr = newValue.dyn_cast<FloatAttr>();
+          if (!fpAttr) {
+            newValue = Attribute();
+            return;
+          }
+          APFloat fpValue = fpAttr.getValue();
+          if (APFloat::SemanticsToEnum(fpValue.getSemantics()) !=
+              APFloat::SemanticsToEnum(t.getFloatSemantics())) {
+            // Convert.
+            APFloat newFpValue = fpValue;
+            bool losesInfo = false;
+            newFpValue.convert(t.getFloatSemantics(),
+                               APFloat::rmNearestTiesToEven, &losesInfo);
+            if (losesInfo) {
+              emitWarning(loc) << "conversion of " << newValue << " to " << t
+                               << " loses information";
+            }
+            newValue = rewriter.getFloatAttr(t, newFpValue);
+          }
+        });
+
+    if (!newValue)
+      return rewriter.notifyMatchFailure(
+          srcOp, "constant cannot be represented as a standard constant");
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(srcOp, resultType, newValue);
+    return success();
+  }
+};
+
+/// Expands dynamic unpacking of a tuple or list by taking advantage that they
+/// both are just variant lists. A size check is emitted, with a branch to
+/// a failure block. The success block will just get each element.
+///
+/// The op's results are replaced by the continuation block's arguments (the
+/// status first, then one value per slot). Emitted control flow:
+///   entry:       cond_br (size == arity), arityMatch, error
+///   arityMatch:  br continuation(Success, list[0], ..., list[arity - 1])
+///   error:       br continuation(ValueError, <null value per slot>)
+///                (created at the end of the region by createSlowPathBlock)
+class DynamicUnpackOpConversion
+    : public OpConversionPattern<PYDM::DynamicUnpackOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::DynamicUnpackOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = srcOp.getLoc();
+    // Convert types.
+    // All types are converted before any IR is created, so a failure here
+    // leaves the IR untouched.
+    Type excResultType =
+        getTypeConverter()->convertType(srcOp.exc_result().getType());
+    if (!excResultType)
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "could not convert exc_result type");
+    int arity = srcOp.slots().size();
+    SmallVector<Type> slotTypes;
+    slotTypes.reserve(arity);
+    for (auto slot : srcOp.slots()) {
+      Type slotType = getTypeConverter()->convertType(slot.getType());
+      if (!slotType)
+        return rewriter.notifyMatchFailure(
+            srcOp, "could not convert result slot type");
+      slotTypes.push_back(slotType);
+    }
+
+    // Split the entry block.
+    // Everything from the insertion point onward moves into
+    // `continuationBlock`. The arity-match block is placed just before it,
+    // and the error block is appended at the end of the region.
+    Block *entryBlock = rewriter.getInsertionBlock();
+    Block *continuationBlock = rewriter.splitBlock(
+        rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
+    Block *arityMatchBlock = rewriter.createBlock(continuationBlock);
+    Block *errorBlock = createSlowPathBlock(rewriter);
+    // Argument order (status, then slots) must match the op's result order.
+    continuationBlock->addArguments(excResultType);
+    continuationBlock->addArguments(slotTypes);
+    rewriter.replaceOp(srcOp, continuationBlock->getArguments());
+
+    // Entry block - check arity.
+    {
+      rewriter.setInsertionPointToEnd(entryBlock);
+      auto arityValue =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(arity));
+      Value listSize = rewriter.create<Input::ListSizeOp>(
+          loc, rewriter.getIndexType(), adaptor.sequence());
+      Value arityMatch = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, arityValue, listSize);
+      rewriter.create<mlir::CondBranchOp>(loc, arityMatch, arityMatchBlock,
+                                          errorBlock);
+    }
+
+    // Arity match.
+    // Reads element i of the list as slot i, typed as the converted slot type.
+    {
+      rewriter.setInsertionPointToEnd(arityMatchBlock);
+      SmallVector<Value> branchArgs;
+      branchArgs.push_back(getSuccessStatusValue(loc, rewriter));
+      for (auto it : enumerate(slotTypes)) {
+        Value index = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIndexAttr(it.index()));
+        Value slotValue = rewriter.create<Input::ListGetOp>(
+            loc, it.value(), adaptor.sequence(), index);
+        branchArgs.push_back(slotValue);
+      }
+      rewriter.create<mlir::BranchOp>(loc, continuationBlock, branchArgs);
+    }
+
+    // Error block.
+    // Placeholder values from getNullValue keep every continuation argument
+    // defined on the failure path.
+    {
+      rewriter.setInsertionPointToEnd(errorBlock);
+      SmallVector<Value> branchArgs;
+      branchArgs.push_back(
+          getFailureStatusValue(loc, rewriter, ExceptionCode::ValueError));
+      for (Type slotType : slotTypes) {
+        branchArgs.push_back(getNullValue(loc, rewriter, slotType));
+      }
+      rewriter.create<mlir::BranchOp>(loc, continuationBlock, branchArgs);
+    }
+
+    return success();
+  }
+};
+
+/// If at this phase, there is nothing to do with a static info cast.
+/// Just drop it.
+///
+/// Uses of the cast are redirected to the adaptor's (type-converted) operand
+/// rather than the original `srcOp.value()`. That way they refer to the
+/// value produced by the conversion, not the illegal PyDM-typed value.
+class ElideStaticInfoCast : public OpConversionPattern<PYDM::StaticInfoCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::StaticInfoCastOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOp(srcOp, adaptor.value());
+    return success();
+  }
+};
+
+/// Generates a failure exception code.
+/// This is just temporary to allow some libraries to signal exceptions.
+///
+/// The op is replaced by the i32 status for `ExceptionCode::RuntimeError`.
+/// It is built with the same helper as every other status value here,
+/// rather than a duplicated magic number.
+class FailureOpConversion : public OpConversionPattern<PYDM::FailureOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::FailureOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value status = getFailureStatusValue(srcOp.getLoc(), rewriter,
+                                         ExceptionCode::RuntimeError);
+    rewriter.replaceOp(srcOp, status);
+    return success();
+  }
+};
+
+/// Lowers `PYDM::FuncOp` to `mlir::FuncOp`.
+///
+/// Argument and result types are converted, the body region is moved into
+/// the new function, and its block signatures are converted. Visibility is
+/// preserved; other attributes are dropped.
+class FuncOpConversion : public OpConversionPattern<PYDM::FuncOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::FuncOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    TypeConverter &converter = *getTypeConverter();
+    FunctionType pyFuncType = srcOp.getType();
+
+    // Record how each PyDM argument maps onto lowered argument type(s).
+    TypeConverter::SignatureConversion argConversion(srcOp.getNumArguments());
+    for (auto indexedInput : enumerate(pyFuncType.getInputs())) {
+      unsigned argIndex = indexedInput.index();
+      if (failed(converter.convertSignatureArg(argIndex, indexedInput.value(),
+                                               argConversion)))
+        return rewriter.notifyMatchFailure(srcOp, "argument failed to convert");
+    }
+
+    // Lower the result types.
+    SmallVector<Type, 1> loweredResultTypes;
+    if (failed(converter.convertTypes(pyFuncType.getResults(),
+                                      loweredResultTypes)))
+      return rewriter.notifyMatchFailure(srcOp, "results failed to convert");
+
+    // Note that attributes are dropped. Consider preserving some if needed.
+    auto loweredFuncType = mlir::FunctionType::get(
+        srcOp.getContext(), argConversion.getConvertedTypes(),
+        loweredResultTypes);
+    auto loweredFunc = rewriter.create<mlir::FuncOp>(
+        srcOp.getLoc(), srcOp.getName(), loweredFuncType);
+    loweredFunc.setVisibility(srcOp.getVisibility());
+
+    // Move the body over, then let the rewriter convert block signatures.
+    rewriter.inlineRegionBefore(srcOp.getBody(), loweredFunc.getBody(),
+                                loweredFunc.end());
+    if (failed(rewriter.convertRegionTypes(&loweredFunc.getBody(), converter,
+                                           &argConversion)))
+      return failure();
+
+    rewriter.replaceOp(srcOp, llvm::None);
+    return success();
+  }
+};
+
+/// Lowers `PYDM::GetTypeCodeOp` on a boxed object. It reads the i32 type
+/// code stored in slot 0 of the object list and casts it to the converted
+/// result width.
+class GetTypeCodeConversion : public OpConversionPattern<PYDM::GetTypeCodeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::GetTypeCodeOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type loweredType = typeConverter->convertType(srcOp.getResult().getType());
+    if (!loweredType)
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "result type could not be converted");
+
+    Location loc = srcOp.getLoc();
+    Value slotZero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value rawTypeCode = rewriter.create<Input::ListGetOp>(
+        loc, rewriter.getIntegerType(32), adaptor.value(), slotZero);
+    Value typeCode = castIntegerValue(
+        loc, rawTypeCode, loweredType.cast<mlir::IntegerType>(), rewriter);
+    rewriter.replaceOp(srcOp, typeCode);
+    return success();
+  }
+};
+
+/// Lowers `load_var` to a read of element 1 of the variable's object list,
+/// typed with the converted result type.
+class LoadVarOpConversion : public OpConversionPattern<PYDM::LoadVarOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::LoadVarOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type loadedType =
+        getTypeConverter()->convertType(srcOp.getResult().getType());
+    if (!loadedType) {
+      return rewriter.notifyMatchFailure(
+          srcOp, "could not convert load_var result type");
+    }
+
+    Value varList = adaptor.getOperands()[0];
+    // Element 0 of the object list is the type code; the value is element 1.
+    Location loc = srcOp.getLoc();
+    Value valueSlot =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+    rewriter.replaceOpWithNewOp<Input::ListGetOp>(srcOp, loadedType, varList,
+                                                  valueSlot);
+    return success();
+  }
+};
+
+/// Lowers `make_list` for lists with Boxed storage into a variant list. The
+/// list is created with capacity equal to the element count, resized to that
+/// count, and then each converted element is stored at its position.
+class MakeListOpBoxedConversion : public OpConversionPattern<PYDM::MakeListOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::MakeListOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = srcOp.getLoc();
+    auto listType = srcOp.list().getType().cast<PYDM::ListType>();
+    // Only the Boxed storage class is handled here. Every other class,
+    // including Empty, is rejected; a separate `== Empty` test is implied by
+    // `!= Boxed` and would be redundant.
+    // NOTE(review): confirm Empty-storage lists are intentionally not lowered
+    // by this pattern.
+    if (listType.getStorageClass() != CollectionStorageClass::Boxed)
+      return rewriter.notifyMatchFailure(srcOp, "unboxed list");
+    // The converted type is not used below (a variant list is always built),
+    // but it is still checked so the pattern fails cleanly when the list type
+    // has no conversion.
+    auto resultType = getTypeConverter()->convertType(listType);
+    if (!resultType)
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "could not convert result type");
+
+    auto size = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(adaptor.elements().size()));
+    auto list =
+        rewriter.create<Input::ListCreateOp>(loc, getVariantListType(rewriter),
+                                             /*capacity=*/size);
+    rewriter.create<Input::ListResizeOp>(loc, list, size);
+    for (auto it : enumerate(adaptor.elements())) {
+      auto index = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(it.index()));
+      rewriter.create<Input::ListSetOp>(loc, list, index, it.value());
+    }
+
+    rewriter.replaceOp(srcOp, ValueRange{list});
+    return success();
+  }
+};
+
+/// Lowers `make_tuple` to a variant list with one element per tuple slot,
+/// filled in slot order.
+class MakeTupleOpConversion : public OpConversionPattern<PYDM::MakeTupleOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::MakeTupleOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // The converted type only gates the match; a variant list is always built.
+    if (!getTypeConverter()->convertType(srcOp.tuple().getType()))
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "could not convert result type");
+
+    Location loc = srcOp.getLoc();
+    auto elements = adaptor.slots();
+    Value count = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(elements.size()));
+    Value tupleList =
+        rewriter.create<Input::ListCreateOp>(loc, getVariantListType(rewriter),
+                                             /*capacity=*/count);
+    rewriter.create<Input::ListResizeOp>(loc, tupleList, count);
+    for (size_t i = 0, e = elements.size(); i < e; ++i) {
+      Value slot =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
+      rewriter.create<Input::ListSetOp>(loc, tupleList, slot, elements[i]);
+    }
+
+    rewriter.replaceOp(srcOp, ValueRange{tupleList});
+    return success();
+  }
+};
+
+/// Lowers `neg` on integers to `arith.subi(0, value)`. Only applies when the
+/// converted operand is a builtin integer whose type equals the converted
+/// result type.
+class NegIntegerOpConversion : public OpConversionPattern<PYDM::NegOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      PYDM::NegOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value operand = adaptor.value();
+    Type operandType = operand.getType();
+    Type intType = getTypeConverter()->convertType(srcOp.result().getType());
+    bool isIntegerNeg =
+        operandType.isa<mlir::IntegerType>() && operandType == intType;
+    if (!isIntegerNeg)
+      return rewriter.notifyMatchFailure(srcOp, "not an integer neg");
+
+    Value zero =
+        rewriter.create<arith::ConstantIntOp>(srcOp.getLoc(), 0, intType);
+    rewriter.replaceOpWithNewOp<arith::SubIOp>(srcOp, zero, operand);
+    return success();
+  }
+};
+
+/// Lowers `none` to an `i32` constant zero. The type conversion rule for
+/// `NoneType` also produces i32, so the two must stay aligned.
+/// TODO: What we are really reaching for is a zero width type.
+class NoneOpConversion : public OpConversionPattern<PYDM::NoneOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::NoneOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type noneStorageType = rewriter.getI32Type();
+    auto zero = rewriter.getIntegerAttr(noneStorageType, 0);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(srcOp, noneStorageType,
+                                                   zero);
+    return success();
+  }
+};
+
+/// Raises an exception (failing status) on failure.
+/// This pattern matches raise_on_failure ops at function scope. Those nested
+/// within exception blocks are different.
+///
+/// The current block is split at the op. If `status` equals the success
+/// status, control continues with the remainder of the block; otherwise a
+/// slow-path block returns `(status, null)` from the enclosing function.
+class RaiseOnFailureOpConversion
+    : public OpConversionPattern<PYDM::RaiseOnFailureOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::RaiseOnFailureOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto loc = srcOp.getLoc();
+
+    Value status = adaptor.getOperands()[0];
+    // Get the containing function return type so that we can create a
+    // suitable null return value.
+    auto parentFunc = srcOp->getParentOfType<mlir::FuncOp>();
+    if (!parentFunc)
+      return rewriter.notifyMatchFailure(srcOp, "not contained by a func");
+    // The enclosing func is expected to return (status, value); result 1 is
+    // the value. Check the arity before indexing, so a func of another shape
+    // fails the match instead of indexing past its result list. This check
+    // happens before any IR is modified.
+    mlir::FunctionType parentFuncType = parentFunc.getType();
+    if (parentFuncType.getNumResults() < 2)
+      return rewriter.notifyMatchFailure(
+          srcOp, "containing func does not return (status, value)");
+    Type convertedReturnType = parentFuncType.getResult(1);
+
+    // Split the entry block.
+    Block *entryBlock = rewriter.getInsertionBlock();
+    Block *continuationBlock = rewriter.splitBlock(
+        rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
+    Block *raiseAndReturnBlock = createSlowPathBlock(rewriter);
+
+    // Branch on success conditional.
+    rewriter.setInsertionPointToEnd(entryBlock);
+    Value successValue = getSuccessStatusValue(loc, rewriter);
+    Value isSuccess = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, successValue, status);
+    rewriter.create<mlir::CondBranchOp>(loc, isSuccess, continuationBlock,
+                                        raiseAndReturnBlock);
+    rewriter.eraseOp(srcOp);
+
+    // Raise and return block: propagate the failing status with a null value.
+    rewriter.setInsertionPointToEnd(raiseAndReturnBlock);
+    auto nullReturnValue = getNullValue(loc, rewriter, convertedReturnType);
+    rewriter.create<mlir::ReturnOp>(loc, ValueRange{status, nullReturnValue});
+    return success();
+  }
+};
+
+/// Lowers `return` to a `return` of (0 exception result : i32, value). The
+/// zero exception result marks the return as successful.
+class ReturnOpConversion : public OpConversionPattern<PYDM::ReturnOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::ReturnOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value okStatus = rewriter.create<arith::ConstantOp>(
+        srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
+    Value returnedValue = adaptor.getOperands()[0];
+    rewriter.replaceOpWithNewOp<mlir::ReturnOp>(
+        srcOp, ValueRange{okStatus, returnedValue});
+    return success();
+  }
+};
+
+/// Implements sequence duplication over built-in list, tuple types.
+///
+/// `sequence_clone(seq, count)` becomes a new list of
+/// `size(seq) * (count <= 0 ? 0 : count)` elements holding `seq` repeated back
+/// to back. The copy is emitted as a CFG loop nest: the outer loop walks the
+/// destination list and, for each repetition, the inner loop copies every
+/// source element.
+class SequenceCloneBuiltinConversion
+    : public OpConversionPattern<PYDM::SequenceCloneOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::SequenceCloneOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Type origListType = srcOp.sequence().getType();
+    if (!isSupportedList(origListType))
+      return rewriter.notifyMatchFailure(srcOp, "not a builtin list or tuple");
+    if (origListType != srcOp.getResult().getType())
+      return rewriter.notifyMatchFailure(
+          srcOp, "result type differs from sequence type");
+    Type resultType = typeConverter->convertType(srcOp.getResult().getType());
+    if (!resultType) {
+      return rewriter.notifyMatchFailure(srcOp, "cannot convert result type");
+    }
+    Type listElementType =
+        typeConverter->convertType(getElementAccessType(origListType));
+    if (!listElementType) {
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "cannot convert list element type");
+    }
+
+    Value listOperand = adaptor.sequence();
+    Value countOperand = adaptor.count();
+    auto loc = srcOp.getLoc();
+    // Compute the new size, clamping count to >= 0 and construct list.
+    Type indexType = rewriter.getType<IndexType>();
+    Type listType = listOperand.getType();
+    Value subListSize =
+        rewriter.create<Input::ListSizeOp>(loc, indexType, listOperand);
+    Value countIndex =
+        rewriter.create<arith::IndexCastOp>(loc, indexType, countOperand);
+    Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value zeroInteger =
+        rewriter.create<arith::ConstantIntOp>(loc, 0, countOperand.getType());
+    // Signed compare on the integer count, so negative counts clamp to zero.
+    Value countClampsToZero = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, countOperand, zeroInteger);
+    Value clampedCountIndex = rewriter.create<mlir::SelectOp>(
+        loc, countClampsToZero, zeroIndex, countIndex);
+    Value newListSize =
+        rewriter.create<arith::MulIOp>(loc, subListSize, clampedCountIndex);
+    // The capacity is the final element count (not the repeat count). This
+    // matches the size the list is resized to immediately below.
+    Value newList =
+        rewriter.create<Input::ListCreateOp>(loc, listType, newListSize);
+    rewriter.create<Input::ListResizeOp>(loc, newList, newListSize);
+
+    // Split blocks to loop.
+    // TODO: Use a new list.copy op instead of an inner loop.
+    // OuterCond: (newListIt : index)
+    // InnerCond: (newListIt : index, subListIt: index)
+    // InnerBody: (newListIt : index, subListIt: index)
+    Block *entryBlock = rewriter.getInsertionBlock();
+    Block *continuationBlock = rewriter.splitBlock(
+        rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
+    Block *outerCond = rewriter.createBlock(continuationBlock);
+    outerCond->addArgument(indexType);
+    Block *innerCond = rewriter.createBlock(continuationBlock);
+    innerCond->addArguments({indexType, indexType});
+    Block *innerBody = rewriter.createBlock(continuationBlock);
+    innerBody->addArguments({indexType, indexType});
+
+    // Entry block.
+    {
+      rewriter.setInsertionPointToEnd(entryBlock);
+      rewriter.create<BranchOp>(loc, outerCond, ValueRange{zeroIndex});
+    }
+
+    // Outer cond: keep going while the destination still has unfilled slots.
+    {
+      rewriter.setInsertionPointToEnd(outerCond);
+      Value newListIt = outerCond->getArgument(0);
+      Value inBounds = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ult, newListIt, newListSize);
+      rewriter.create<CondBranchOp>(loc, inBounds, innerCond,
+                                    ValueRange{newListIt, zeroIndex},
+                                    continuationBlock, ValueRange{});
+    }
+
+    // Inner cond: copy one full pass over the source list.
+    {
+      rewriter.setInsertionPointToEnd(innerCond);
+      Value newListIt = innerCond->getArgument(0);
+      Value subListIt = innerCond->getArgument(1);
+      Value inBounds = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ult, subListIt, subListSize);
+      rewriter.create<CondBranchOp>(loc, inBounds, innerBody,
+                                    ValueRange{newListIt, subListIt}, outerCond,
+                                    ValueRange{newListIt});
+    }
+
+    // Inner body.
+    {
+      rewriter.setInsertionPointToEnd(innerBody);
+      Value newListIt = innerBody->getArgument(0);
+      Value subListIt = innerBody->getArgument(1);
+
+      Value elementValue = rewriter.create<Input::ListGetOp>(
+          loc, listElementType, listOperand, subListIt);
+      rewriter.create<Input::ListSetOp>(loc, newList, newListIt, elementValue);
+
+      newListIt = rewriter.create<arith::AddIOp>(loc, newListIt, oneIndex);
+      subListIt = rewriter.create<arith::AddIOp>(loc, subListIt, oneIndex);
+      rewriter.create<BranchOp>(loc, innerCond,
+                                ValueRange{newListIt, subListIt});
+    }
+
+    // Continuation.
+    {
+      rewriter.setInsertionPointToEnd(continuationBlock);
+      rewriter.replaceOp(srcOp, {newList});
+    }
+
+    return success();
+  }
+
+  /// Returns true for the PyDM sequence types this pattern handles.
+  bool isSupportedList(Type t) const {
+    // Both lists and tuples have the same physical representation and can
+    // be supported interchangeably here.
+    return t.isa<PYDM::ListType>() || t.isa<PYDM::TupleType>();
+  }
+
+  /// Returns the element storage type of a list or tuple type `t`. `t` must
+  /// satisfy isSupportedList().
+  Type getElementAccessType(Type t) const {
+    if (auto listType = t.dyn_cast<PYDM::ListType>())
+      return listType.getElementStorageType();
+    if (auto tupleType = t.dyn_cast<PYDM::TupleType>())
+      return tupleType.getElementStorageType();
+
+    llvm_unreachable("unsupported list type");
+  }
+};
+
+/// Lowers `store_var` by calling resetObjectList on the variable's object
+/// list with the new value and the type code of the stored value's Python
+/// type, then erasing the op.
+class StoreVarOpConversion : public OpConversionPattern<PYDM::StoreVarOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::StoreVarOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // The type code comes from the pre-conversion (Python) type of the value.
+    auto pyType =
+        srcOp.value().getType().dyn_cast<PYDM::PythonTypeInterface>();
+    if (!pyType) {
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "not a python type for value()");
+    }
+
+    auto operands = adaptor.getOperands();
+    Value varList = operands[0];
+    Value storedValue = operands[1];
+    int code = static_cast<int>(pyType.getTypeCode());
+    resetObjectList(srcOp.getLoc(), rewriter, varList, code, storedValue);
+    rewriter.eraseOp(srcOp);
+    return success();
+  }
+};
+
+/// Lowers the subscript operator on builtin sequence types (list, tuple)
+/// based on a statically determined scalar slice (which can be positive or
+/// negative).
+///
+/// Negative indices are rebased by adding the list size. The resulting index
+/// is bounds-checked with an unsigned compare (`ult`), so an index that is
+/// still negative after rebasing also fails the check. Out-of-range accesses
+/// yield an IndexError status and a null result.
+class SubscriptOpBuiltinSequenceConversion
+    : public OpConversionPattern<PYDM::SubscriptOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::SubscriptOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    auto pySequence = srcOp.value();
+    if (!pySequence.getType().isa<PYDM::ListType, PYDM::TupleType>())
+      return rewriter.notifyMatchFailure(srcOp, "not builtin sequence");
+    auto pySlice = srcOp.slice();
+    if (!pySlice.getType().isa<PYDM::IntegerType>())
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "slice is not static integer type");
+    Type resultType = getTypeConverter()->convertType(srcOp.result().getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "could not convert result type");
+    // The status type becomes a block-argument type below, so it must be
+    // validated before any IR is created.
+    Type statusType =
+        getTypeConverter()->convertType(srcOp.exc_result().getType());
+    if (!statusType)
+      return rewriter.notifyMatchFailure(srcOp,
+                                         "could not convert exc_result type");
+
+    auto loc = srcOp.getLoc();
+    auto slice = adaptor.slice();
+    auto indexType = rewriter.getType<IndexType>();
+
+    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, slice.getType());
+    Value listSizeIndex =
+        rewriter.create<Input::ListSizeOp>(loc, indexType, adaptor.value());
+    Value listSizeInteger = rewriter.create<arith::IndexCastOp>(
+        loc, slice.getType(), listSizeIndex);
+
+    // Split blocks:
+    //   indexLtZeroBlock
+    //   indexCheckBlock(sliceIndex : IndexType)
+    //   getElementBlock(sliceIndex : IndexType)
+    //   continuationBlock(exc_result, result)
+    //   failureBlock
+    Block *entryBlock = rewriter.getInsertionBlock();
+    Block *continuationBlock = rewriter.splitBlock(
+        rewriter.getInsertionBlock(), rewriter.getInsertionPoint());
+    Block *indexLtZeroBlock = rewriter.createBlock(continuationBlock);
+    Block *indexCheckBlock = rewriter.createBlock(continuationBlock);
+    indexCheckBlock->addArgument(indexType);
+    Block *getElementBlock = rewriter.createBlock(continuationBlock);
+    getElementBlock->addArgument(indexType);
+    Block *failureBlock = createSlowPathBlock(rewriter);
+    continuationBlock->addArguments({statusType, resultType});
+    rewriter.replaceOp(srcOp, continuationBlock->getArguments());
+
+    // Comparison index < 0.
+    {
+      rewriter.setInsertionPointToEnd(entryBlock);
+      Value ltZero = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, slice, zero);
+      auto sliceIndex =
+          rewriter.create<arith::IndexCastOp>(loc, indexType, slice);
+      rewriter.create<mlir::CondBranchOp>(loc, ltZero, indexLtZeroBlock,
+                                          indexCheckBlock,
+                                          ValueRange{sliceIndex});
+    }
+
+    // Handle index < 0 by rebasing against the list size.
+    {
+      rewriter.setInsertionPointToEnd(indexLtZeroBlock);
+      Value positiveSlice =
+          rewriter.create<arith::AddIOp>(loc, slice, listSizeInteger);
+      Value positiveSliceIndex =
+          rewriter.create<arith::IndexCastOp>(loc, indexType, positiveSlice);
+      rewriter.create<mlir::BranchOp>(loc, ValueRange{positiveSliceIndex},
+                                      indexCheckBlock);
+    }
+
+    // Index check.
+    {
+      rewriter.setInsertionPointToEnd(indexCheckBlock);
+      Value sliceIndex = indexCheckBlock->getArgument(0);
+      Value ltSize = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ult, sliceIndex, listSizeIndex);
+      rewriter.create<mlir::CondBranchOp>(loc, ltSize, getElementBlock,
+                                          ValueRange{sliceIndex}, failureBlock,
+                                          ValueRange{});
+    }
+
+    // Get element.
+    {
+      rewriter.setInsertionPointToEnd(getElementBlock);
+      Value successResult = getSuccessStatusValue(loc, rewriter);
+      Value resultValue = rewriter.create<Input::ListGetOp>(
+          loc, resultType, adaptor.value(), getElementBlock->getArgument(0));
+      rewriter.create<mlir::BranchOp>(loc, continuationBlock,
+                                      ValueRange{successResult, resultValue});
+    }
+
+    // Failure.
+    {
+      rewriter.setInsertionPointToEnd(failureBlock);
+      Value failureResult =
+          getFailureStatusValue(loc, rewriter, ExceptionCode::IndexError);
+      Value nullResult = getNullValue(loc, rewriter, resultType);
+      rewriter.create<mlir::BranchOp>(loc, continuationBlock,
+                                      ValueRange{failureResult, nullResult});
+    }
+
+    return success();
+  }
+};
+
+/// Lowers `unbox` to a type-code check on the object list followed by a load
+/// of its value element.
+///
+/// The current block is split at the op. The head compares element 0 (read as
+/// i32) against the type code of the requested primitive's Python type. On a
+/// match, element 1 is loaded with the converted primitive type and forwarded
+/// with a success status. Otherwise a slow-path block forwards a ValueError
+/// status and a null value. Both paths join in a block whose arguments
+/// replace the op's (exc_result, primitive) results.
+class UnboxOpConversion : public OpConversionPattern<PYDM::UnboxOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      PYDM::UnboxOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Location loc = srcOp.getLoc();
+    Value boxed = adaptor.getOperands()[0];
+    TypeConverter *converter = getTypeConverter();
+
+    // Converted (exception status, primitive value) result types.
+    Type excType = converter->convertType(srcOp.exc_result().getType());
+    Type primitiveType = converter->convertType(srcOp.primitive().getType());
+    if (!primitiveType || !excType)
+      return rewriter.notifyMatchFailure(
+          srcOp, "could not convert unbox result types");
+
+    // The expected type code comes from the primitive's Python type.
+    auto pyPrimitiveType =
+        srcOp.primitive().getType().dyn_cast<PYDM::PythonTypeInterface>();
+    if (!pyPrimitiveType)
+      return rewriter.notifyMatchFailure(
+          srcOp, "not a python type for primitive() unboxed result");
+    int expectedCode = static_cast<int>(pyPrimitiveType.getTypeCode());
+
+    // CFG shape: head -> {match, mismatch} -> join(exc, primitive).
+    Block *headBlock = rewriter.getInsertionBlock();
+    Block *joinBlock =
+        rewriter.splitBlock(headBlock, rewriter.getInsertionPoint());
+    Block *matchBlock = rewriter.createBlock(joinBlock);
+    Block *mismatchBlock = createSlowPathBlock(rewriter);
+    joinBlock->addArguments({excType, primitiveType});
+    rewriter.replaceOp(srcOp, joinBlock->getArguments());
+
+    // Head: compare the stored type code with the expected one.
+    rewriter.setInsertionPointToEnd(headBlock);
+    Value codeSlot =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value expectedCodeValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(expectedCode));
+    Value storedCodeValue = rewriter.create<Input::ListGetOp>(
+        loc, rewriter.getI32Type(), boxed, codeSlot);
+    Value codesMatch = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, expectedCodeValue, storedCodeValue);
+    rewriter.create<mlir::CondBranchOp>(loc, codesMatch, matchBlock,
+                                        mismatchBlock);
+
+    // Match: load the value from element 1 and report success.
+    rewriter.setInsertionPointToEnd(matchBlock);
+    Value payloadSlot =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+    Value okStatus = getSuccessStatusValue(loc, rewriter);
+    Value payload = rewriter.create<Input::ListGetOp>(loc, primitiveType,
+                                                      boxed, payloadSlot);
+    rewriter.create<mlir::BranchOp>(loc, joinBlock,
+                                    ValueRange{okStatus, payload});
+
+    // Mismatch: no coercion yet, so report a ValueError with a null value.
+    // TODO: Currently just fails - should emit a runtime call.
+    rewriter.setInsertionPointToEnd(mismatchBlock);
+    Value errStatus =
+        getFailureStatusValue(loc, rewriter, ExceptionCode::ValueError);
+    Value nullPayload = getNullValue(loc, rewriter, primitiveType);
+    rewriter.create<mlir::BranchOp>(loc, joinBlock,
+                                    ValueRange{errStatus, nullPayload});
+
+    return success();
+  }
+};
+
+//------------------------------------------------------------------------------
+// Outside pydm op conversions
+// These are largely identity conversions for CFG related standard ops, and
+// those that can be emitted as part of canonicalizations.
+//------------------------------------------------------------------------------
+
+/// Re-creates `br` with type-converted successor operands; the destination
+/// block is reused unchanged.
+class BuiltinBranchConversion : public OpConversionPattern<mlir::BranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mlir::BranchOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Block *target = srcOp.dest();
+    ValueRange convertedArgs = adaptor.destOperands();
+    rewriter.replaceOpWithNewOp<mlir::BranchOp>(srcOp, target, convertedArgs);
+    return success();
+  }
+};
+
+/// Re-creates `cond_br` with the converted condition and successor operands;
+/// both destination blocks are reused unchanged.
+class BuiltinCondBranchConversion
+    : public OpConversionPattern<mlir::CondBranchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mlir::CondBranchOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value cond = adaptor.condition();
+    Block *onTrue = srcOp.trueDest();
+    Block *onFalse = srcOp.falseDest();
+    rewriter.replaceOpWithNewOp<mlir::CondBranchOp>(
+        srcOp, cond, onTrue, adaptor.trueDestOperands(), onFalse,
+        adaptor.falseDestOperands());
+    return success();
+  }
+};
+
+/// Re-creates `select` over the type-converted condition and values.
+class BuiltinSelectConversion : public OpConversionPattern<mlir::SelectOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      mlir::SelectOp srcOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    Value cond = adaptor.condition();
+    Value ifTrue = adaptor.true_value();
+    Value ifFalse = adaptor.false_value();
+    rewriter.replaceOpWithNewOp<mlir::SelectOp>(srcOp, cond, ifTrue, ifFalse);
+    return success();
+  }
+};
+
+}  // namespace
+
+/// Adds the patterns lowering PyDM ops to the IREE input dialect to
+/// `patterns`. It also adds pass-through conversions for the non-PyDM
+/// branch/select ops. Every pattern is constructed with `typeConverter` and
+/// `context`.
+void PYDM::populatePyDMToIREELoweringPatterns(MLIRContext *context,
+                                              TypeConverter &typeConverter,
+                                              RewritePatternSet &patterns) {
+  // PyDM op conversions.
+  patterns.insert<AllocFreeVarOpConversion, ApplyBinaryNumericConversion,
+                  ApplyCompareNumericConversion,
+                  AssignSubscriptListConversion>(typeConverter, context);
+  patterns.insert<BoolToPredConversion, BoxOpConversion,
+                  MakeListOpBoxedConversion, CallOpConversion,
+                  ConstantOpConversion, DynamicUnpackOpConversion>(
+      typeConverter, context);
+  patterns.insert<ElideStaticInfoCast, FailureOpConversion, FuncOpConversion,
+                  GetTypeCodeConversion, LoadVarOpConversion,
+                  MakeTupleOpConversion>(typeConverter, context);
+  patterns.insert<NegIntegerOpConversion, RaiseOnFailureOpConversion,
+                  ReturnOpConversion, SequenceCloneBuiltinConversion,
+                  StoreVarOpConversion, SubscriptOpBuiltinSequenceConversion,
+                  UnboxOpConversion>(typeConverter, context);
+
+  // External CFG ops: re-created with converted operands.
+  patterns.insert<BuiltinBranchConversion>(typeConverter, context);
+  patterns.insert<BuiltinCondBranchConversion>(typeConverter, context);
+  patterns.insert<BuiltinSelectConversion>(typeConverter, context);
+
+  // Constants and constructors.
+  patterns.insert<NoneOpConversion>(typeConverter, context);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
new file mode 100644
index 0000000..a16fe16
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Transforms/ToIREE/TypeConverter.cpp
@@ -0,0 +1,116 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/PyDM/Transforms/ToIREE/TypeConverter.h"
+
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/PyDM/IR/PyDMDialect.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+/// Returns an IREE::Input::ListType whose element type is
+/// IREE::Input::VariantType, i.e. a list of variants.
+static Type getVariantListType(Builder &builder) {
+  Type variantType = builder.getType<IREE::Input::VariantType>();
+  return builder.getType<IREE::Input::ListType>(variantType);
+}
+
+/// Registers the PyDM -> IREE input type conversions:
+///   - NoneType and ExceptionResultType map to i32.
+///   - BoolType maps to i1.
+///   - IntegerType and RealType map to builtin integer/float types; weak
+///     types use getWeakIntegerType()/getWeakFloatType().
+///   - Object, list, tuple and free-variable-ref types map to a variant list.
+///   - A fixed set of already-legal builtin types maps to itself.
+LoweringTypeConverter::LoweringTypeConverter() {
+  addConversion([](PYDM::NoneType t) -> Optional<Type> {
+    // TODO: This should really be a zero-width opaque value in the VM. Just
+    // making it an integer now.
+    return mlir::IntegerType::get(t.getContext(), 32);
+  });
+  addConversion([](PYDM::ExceptionResultType t) -> Optional<Type> {
+    return mlir::IntegerType::get(t.getContext(), 32);
+  });
+  addConversion([](PYDM::ObjectType t) -> Optional<Type> {
+    Builder b(t.getContext());
+    return getVariantListType(b);
+  });
+
+  // Bool.
+  // NOTE(review): this hard-codes i1, whereas getBoolType() is sized by
+  // `boolBits`; confirm the two are meant to agree.
+  addConversion([](PYDM::BoolType t) -> Optional<Type> {
+    return mlir::IntegerType::get(t.getContext(), 1);
+  });
+
+  // Integer type hierarchy. Captures `this` to reach getWeakIntegerType().
+  addConversion([this](PYDM::IntegerType t) -> Optional<Type> {
+    Builder b(t.getContext());
+    if (t.isWeak()) {
+      return getWeakIntegerType(b);
+    }
+    return b.getIntegerType(t.getBitWidth());
+  });
+
+  // Real type hierarchy. Captures `this` to reach getWeakFloatType().
+  addConversion([this](PYDM::RealType t) -> Optional<Type> {
+    Builder b(t.getContext());
+    if (t.isWeak()) {
+      return getWeakFloatType(b);
+    }
+    return t.getFloatType();
+  });
+
+  // Tuple, List.
+  // TODO: Fork these based on CollectionStorageClass as they can avoid
+  // using variant lists.
+  addConversion([](PYDM::ListType t) -> Optional<Type> {
+    Builder b(t.getContext());
+    return getVariantListType(b);
+  });
+  addConversion([](PYDM::TupleType t) -> Optional<Type> {
+    Builder b(t.getContext());
+    return getVariantListType(b);
+  });
+
+  // Variable references.
+  addConversion([](PYDM::FreeVarRefType t) -> Optional<Type> {
+    // Just an object record.
+    Builder b(t.getContext());
+    return getVariantListType(b);
+  });
+
+  // Explicit conversions for allowed built-in types (avoids default conversion
+  // which can mask issues). Each type is registered exactly once; the set
+  // matches the types accepted by isTypeLegal().
+  addConversion([](mlir::IndexType t) -> Optional<Type> { return t; });
+  addConversion([](mlir::IntegerType t) -> Optional<Type> { return t; });
+  addConversion([](mlir::FloatType t) -> Optional<Type> { return t; });
+  addConversion([](IREE::Input::ListType t) -> Optional<Type> { return t; });
+}
+
+/// Returns a signless builtin integer type `boolBits` bits wide.
+Type LoweringTypeConverter::getBoolType(Builder b) const {
+  return mlir::IntegerType::get(b.getContext(), boolBits);
+}
+
+/// Returns the builtin integer type used for weak Python integers: a signless
+/// integer `weakIntegerBits` bits wide.
+Type LoweringTypeConverter::getWeakIntegerType(Builder b) const {
+  return mlir::IntegerType::get(b.getContext(), weakIntegerBits);
+}
+
+/// Returns the builtin float type used for weak Python reals, as selected by
+/// `weakFloatType` (f32 or f64).
+Type LoweringTypeConverter::getWeakFloatType(Builder b) const {
+  switch (weakFloatType) {
+    case WeakFloatType::F32:
+      return b.getF32Type();
+    case WeakFloatType::F64:
+      return b.getF64Type();
+  }
+  // Every enumerator returns above. Without this, a value outside the enum
+  // would fall off the end of a non-void function (undefined behavior, and a
+  // -Wreturn-type warning on some compilers).
+  llvm_unreachable("unhandled WeakFloatType");
+}
+
+/// Returns true if `t` is already legal after lowering: a builtin integer,
+/// float or index type, or an IREE input list.
+bool LoweringTypeConverter::isTypeLegal(Type t) const {
+  if (t.isa<mlir::IntegerType, mlir::IndexType>()) return true;
+  if (t.isa<mlir::FloatType>()) return true;
+  return t.isa<IREE::Input::ListType>();
+}
+
+/// Returns true if every type in `types` satisfies isTypeLegal(). An empty
+/// range is vacuously legal.
+bool LoweringTypeConverter::areTypesLegal(TypeRange types) const {
+  for (Type candidate : types)
+    if (!isTypeLegal(candidate))
+      return false;
+  return true;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt
new file mode 100644
index 0000000..2417731
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+# Utility library for the PyDM dialect. It currently holds the type inference
+# support implemented in TypeInference.cpp.
+add_mlir_library(IREEPyDMUtils
+  TypeInference.cpp
+
+  LINK_LIBS PUBLIC
+  IREEPyDMDialect
+  MLIRIR
+  MLIRStandard
+  MLIRTransformUtils
+)
+
+# Makes the iree-dialects include directories available to this target.
+iree_dialects_target_includes(IREEPyDMUtils)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
new file mode 100644
index 0000000..ed09f96
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/PyDM/Utils/TypeInference.cpp
@@ -0,0 +1,109 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/PyDM/Utils/TypeInference.h"
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeRange.h"
+
+using namespace mlir;
+namespace PYDM = mlir::iree_compiler::IREE::PYDM;
+using namespace PYDM;
+
+PermutedTypePropagator::PermutedBlockInfo *
+PermutedTypePropagator::addPermutedBlockToParent(ParentBlockInfo *parentInfo,
+                                                 Block *block) {
+  auto *entry = new (allocator) PermutedBlockInfo();  // From `allocator`.
+  entry->parentInfo = parentInfo;
+  entry->permutedBlock = block;
+  entry->signature = FunctionType::get(context, block->getArgumentTypes(),
+                                       TypeRange{});  // Args only, no results.
+  entry->next = parentInfo->permutationHead;  // Push onto the list front.
+  parentInfo->permutationHead = entry;
+  ++parentInfo->size;
+  return entry;
+}
+
+// Returns the parent record for `forBlock`, registering unseen blocks.
+PermutedTypePropagator::ParentBlockInfo *
+PermutedTypePropagator::lookupParentBlock(Block *forBlock) {
+  // Known blocks map to a permutation record that links back to its parent.
+  auto found = permutedBlocks.find(forBlock);
+  if (found != permutedBlocks.end()) return found->second->parentInfo;
+
+  // Unaccounted-for blocks are assumed to be parents. The parent is also
+  // recorded as its own (first) permutation and keyed in the map, so a
+  // repeated lookup takes the early return above.
+  auto *newParent = new (allocator) ParentBlockInfo();
+  newParent->parentBlock = forBlock;
+  PermutedBlockInfo *selfPermutation =
+      addPermutedBlockToParent(newParent, forBlock);
+  permutedBlocks.insert(std::make_pair(forBlock, selfPermutation));
+  return newParent;
+}
+
+// Returns the recorded permutation whose signature matches, else nullptr.
+Block *PermutedTypePropagator::findBlockPermutation(ParentBlockInfo *parentInfo,
+                                                    FunctionType signature) {
+  // Linear walk of the parent's permutation list (most recent first).
+  PermutedBlockInfo *cursor = parentInfo->permutationHead;
+  while (cursor && cursor->signature != signature) cursor = cursor->next;
+  return cursor ? cursor->permutedBlock : nullptr;
+}
+
+static bool checkAllBlockArgsMapped(Block *block,
+                                    BlockAndValueMapping &mapping) {
+  for (Value arg : block->getArguments()) {
+    if (!mapping.contains(arg)) return false;
+  }
+  return true;
+}
+
+// Clones the parent block into a new block taking `newArgumentTypes`.
+Block *PermutedTypePropagator::createBlockPermutation(
+    ParentBlockInfo *parentInfo, TypeRange newArgumentTypes,
+    BlockPermuteCallback initializeCallback) {
+  Block *parentBlock = parentInfo->parentBlock;
+  Block *newBlock = new Block();
+  newBlock->addArguments(newArgumentTypes);
+  newBlock->insertBefore(parentBlock);
+
+  BlockAndValueMapping mapping;
+  mapping.map(parentBlock, newBlock);
+  initializeCallback(newBlock, parentBlock, mapping);
+  assert(checkAllBlockArgsMapped(parentBlock, mapping) &&
+         "permuted block initializer did not map all block arguments");
+
+  // Clone every op of the parent (terminator included) into the new block.
+  for (Operation &op : *parentBlock) newBlock->push_back(op.clone(mapping));
+  // Record the permutation so lookupParentBlock(newBlock) finds this parent.
+  auto *permutedInfo = addPermutedBlockToParent(parentInfo, newBlock);
+  permutedBlocks.insert(std::make_pair(newBlock, permutedInfo));
+  return newBlock;
+}
+
+// Returns each incoming edge of `block` whose successor operand types differ
+// from the block's argument types, along with the edge's type signature.
+SmallVector<PermutedTypePropagator::BlockPredecessor>
+PermutedTypePropagator::findMismatchedBlockPredecessors(Block *block) {
+  SmallVector<BlockPredecessor> results;
+  // Iterate edges, not predecessor blocks: one terminator may target `block`
+  // more than once (e.g. both arms of a conditional branch), and each edge
+  // has its own successor index and operands.
+  for (BlockOperand &edge : block->getUses()) {
+    auto branchOp = llvm::cast<BranchOpInterface>(edge.getOwner());
+    unsigned successorIndex = edge.getOperandNumber();
+    auto successorOperands = branchOp.getSuccessorOperands(successorIndex);
+    assert(successorOperands && "expected branch with explicit operands");
+    TypeRange operandTypes(*successorOperands);
+    if (block->getArgumentTypes() != operandTypes) {
+      results.push_back(BlockPredecessor{
+          branchOp, successorIndex,
+          FunctionType::get(context, operandTypes, TypeRange{})});
+    }
+  }
+  return results;
+}