Initial commit: Add new iree_pydm dialect and supporting code to compile to it. (#6907)

* Initial commit: Add new iree_pydm dialect and supporting code to compile to it.

* Short for "python data-model", which is the closest thing to an authoritative spec as exists.
* Attempts to be a faithful modeling of Python (currently incomplete) which also has been engineered to be realized on the IREE VM.
* Aim is for Python program extraction which can preserve full generality while also benefiting (greatly) from any present type information.
* Trying to be a counter-balance to previous approaches which tended to make "not quite Python" uncanny valleys: it will be a while before this is complete but it attempts to model the dynamism of the language.
* Inspired by my previous work on npcomp but a complete rewrite with lessons learned incorporated. Conceptually builds on the npcomp version in several key ways:
  * Models the Python type system with both generality and specialization in mind (npcomp required specialization just making a "not quite Python").
  * Leverages knowledge of the types and constraints of the IREE VM, paying attention to having an IR which should lower without much fuss.
  * Aims to be self-contained, not mixing and matching unrelated MLIR dialects and types (currently just incorporates parts of SCF and CFG operations).
  * More general Python importer bridge, with an interop model based on intrinsics and supporting a runtime library largely authored in the language itself.
  * Implements general flow control.
* I am landing this here, as an IREE public dialect because:
  * It is already setup for other LLVM projects to take a source dep on it (allowing others to compile their Python DSLs to IREE). NPComp already does this.
  * This is the only place that also allows the IREE compiler to lower it to run on its VM.
  * It is already plumbed through to IREE's (new) public C and Python compiler API distributables.
* There are three forks of work from here:
  1. Finish modeling the dialect and importer (needs for/while support, type/op support for containers and iterators).
  2. Implement first pass source dialect optimizations, namely type propagation, and simplification away from dynamic forms where possible.
  3. Lowerings to IREE VM. The types and ops are laid out so that this can hopefully be a mostly 1:1, relatively simple lowering.
* I expect getting everything to connect is roughly a week of coding on each category.
* Requires upstream: https://reviews.llvm.org/D108898
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 99fb60b..9d0b7178 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -121,6 +121,7 @@
         "//iree/compiler/InputConversion/TOSA",
         "//iree/compiler/Translation:IREEVM",
         "//llvm-external-projects/iree-dialects:IREEDialect",
+        "//llvm-external-projects/iree-dialects:IREEPyDMDialect",
         "@llvm-project//mlir:IR",
     ],
 )
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 956752c..144986f 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -196,6 +196,7 @@
       "init_iree_dialects.h"
       "init_iree_passes.h"
     DEPS
+      IREEDialectsIREEPyDMDialect
       IREEDialectsIREEDialect
       MLIRIR
       iree::compiler::Bindings::Native::Transforms
diff --git a/iree/tools/init_iree_dialects.h b/iree/tools/init_iree_dialects.h
index 85ad810..c5d1f6d 100644
--- a/iree/tools/init_iree_dialects.h
+++ b/iree/tools/init_iree_dialects.h
@@ -13,6 +13,7 @@
 #define IREE_TOOLS_INIT_IREE_DIALECTS_H_
 
 #include "iree-dialects/Dialect/IREE/IREEDialect.h"
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
@@ -37,7 +38,8 @@
                   IREE::VM::VMDialect,
                   IREE::VMVX::VMVXDialect,
                   IREE::Vulkan::VulkanDialect,
-                  mlir::iree::IREEDialect>();
+                  mlir::iree::IREEDialect,
+                  mlir::iree_pydm::IREEPyDMDialect>();
   // clang-format on
 }
 
diff --git a/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt b/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
index dc01e86..f0253a1 100644
--- a/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
+++ b/llvm-external-projects/iree-compiler-api/python/CMakeLists.txt
@@ -1,5 +1,9 @@
 include(AddMLIRPython)
 
+# Specifies that all MLIR packages are co-located under npcomp.
+# TODO: Add an upstream cmake param for this vs having a global here.
+add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=iree.compiler.")
+
 ################################################################################
 # Sources
 ################################################################################
diff --git a/llvm-external-projects/iree-dialects/BUILD b/llvm-external-projects/iree-dialects/BUILD
index 429aab7..1531434 100644
--- a/llvm-external-projects/iree-dialects/BUILD
+++ b/llvm-external-projects/iree-dialects/BUILD
@@ -10,12 +10,18 @@
 
 filegroup(
     name = "TdFilegroup",
-    srcs = glob(["include/iree-dialects/Dialect/IREE/*.td"]),
+    srcs = glob([
+        "include/iree-dialects/Dialect/IREE/*.td",
+        "include/iree-dialects/Dialect/IREEPyDM/IR/*.td",
+    ]),
 )
 
 td_library(
     name = "TdFiles",
-    srcs = glob(["include/iree-dialects/Dialect/IREE/*.td"]),
+    srcs = glob([
+        "include/iree-dialects/Dialect/IREE/*.td",
+        "include/iree-dialects/Dialect/IREEPyDM/IR/*.td",
+    ]),
     includes = ["include"],
     deps = [
         "@llvm-project//mlir:OpBaseTdFiles",
@@ -23,6 +29,10 @@
     ],
 )
 
+################################################################################
+# IREE dialect
+################################################################################
+
 gentbl_cc_library(
     name = "IREEOpsIncGen",
     strip_include_prefix = "include",
@@ -72,6 +82,69 @@
     ],
 )
 
+################################################################################
+# IREEPyDM Dialect
+################################################################################
+
+gentbl_cc_library(
+    name = "IREEPyDMOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-dialect-decls"],
+            "include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsDialect.h.inc",
+        ),
+        (
+            ["-gen-dialect-defs"],
+            "include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsDialect.cpp.inc",
+        ),
+        (
+            ["-gen-op-decls"],
+            "include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp.inc",
+        ),
+        (
+            ["-gen-typedef-decls"],
+            "include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsTypes.h.inc",
+        ),
+        (
+            ["-gen-typedef-defs"],
+            "include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsTypes.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.td",
+    deps = [
+        ":TdFiles",
+        "@llvm-project//mlir:CallInterfacesTdFiles",
+        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
+    ],
+)
+
+cc_library(
+    name = "IREEPyDMDialect",
+    srcs = glob([
+        "lib/Dialect/IREEPyDM/IR/*.cpp",
+    ]),
+    hdrs = glob(["include/iree-dialects/Dialect/IREEPyDM/IR/*.h"]),
+    includes = ["include"],
+    deps = [
+        ":IREEPyDMOpsIncGen",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:CallOpInterfaces",
+        "@llvm-project//mlir:ControlFlowInterfaces",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+################################################################################
+# CAPI
+################################################################################
+
 cc_library(
     name = "CAPI",
     srcs = [
@@ -82,6 +155,7 @@
     ],
     deps = [
         ":IREEDialect",
+        ":IREEPyDMDialect",
         "@llvm-project//mlir:CAPIIR",
     ],
 )
diff --git a/llvm-external-projects/iree-dialects/CMakeLists.txt b/llvm-external-projects/iree-dialects/CMakeLists.txt
index c24e553..3f3da9c 100644
--- a/llvm-external-projects/iree-dialects/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/CMakeLists.txt
@@ -64,6 +64,7 @@
 add_subdirectory(include)
 add_subdirectory(lib)
 add_subdirectory(test)
+add_subdirectory(tools)
 
 if(MLIR_ENABLE_BINDINGS_PYTHON)
   add_subdirectory(python)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
index 4738007..ac17a04 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Dialects.h
@@ -7,14 +7,48 @@
 #ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H
 #define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_DIALECTS_H
 
+#include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
+//===----------------------------------------------------------------------===//
+// IREEDialect
+//===----------------------------------------------------------------------===//
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREE, iree);
 
+//===----------------------------------------------------------------------===//
+// IREEPyDMDialect
+//===----------------------------------------------------------------------===//
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IREEPyDM, iree_pydm);
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDMPrimitiveType(MlirType type);
+
+#define IREEPYDM_DECLARE_NULLARY_TYPE(Name)                         \
+  MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDM##Name(MlirType type); \
+  MLIR_CAPI_EXPORTED MlirType mlirIREEPyDM##Name##TypeGet(MlirContext ctx);
+
+IREEPYDM_DECLARE_NULLARY_TYPE(Bool)
+IREEPYDM_DECLARE_NULLARY_TYPE(Bytes)
+IREEPYDM_DECLARE_NULLARY_TYPE(Integer)
+IREEPYDM_DECLARE_NULLARY_TYPE(ExceptionResult)
+IREEPYDM_DECLARE_NULLARY_TYPE(List)
+IREEPYDM_DECLARE_NULLARY_TYPE(None)
+IREEPYDM_DECLARE_NULLARY_TYPE(Real)
+IREEPYDM_DECLARE_NULLARY_TYPE(Str)
+IREEPYDM_DECLARE_NULLARY_TYPE(Tuple)
+IREEPYDM_DECLARE_NULLARY_TYPE(Type)
+
+#undef IREEPYDM_DECLARE_NULLARY_TYPE
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAIREEPyDMObject(MlirType type);
+MLIR_CAPI_EXPORTED MlirType mlirIREEPyDMObjectTypeGet(MlirContext context,
+                                                      MlirType primitive);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
new file mode 100644
index 0000000..696f6ad
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects-c/Utils.h
@@ -0,0 +1,26 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_UTILS_H
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_UTILS_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// TODO: Upstream C/Python APIs for symbol table.
+// Looks up the referrent operation with the given flat symbol, starting from
+// a specific op.
+MLIR_CAPI_EXPORTED MlirOperation
+ireeLookupNearestSymbolFrom(MlirOperation fromOp, MlirAttribute symbolRefAttr);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_C_UTILS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
index 952be9f..61df04e 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IREE)
+add_subdirectory(IREEPyDM)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt
new file mode 100644
index 0000000..f33061b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/CMakeLists.txt
new file mode 100644
index 0000000..5d94ff9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_mlir_dialect(IREEPyDMOps iree_pydm)
+add_mlir_doc(IREEPyDMDialect IREEPyDMDialect IREEPyDM/ -gen-dialect-doc)
+add_mlir_doc(IREEPyDMOps IREEPyDMOps IREEPyDM/ -gen-op-doc)
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMBase.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMBase.td
new file mode 100644
index 0000000..40071d7
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMBase.td
@@ -0,0 +1,41 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_BASE_TD
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_BASE_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def IREEPyDM_Dialect : Dialect {
+  let name = "iree_pydm";
+  let summary = "Python data model as expressible for compilation to IREE";
+  let description = [{
+    Provides an implementation of the Python Data Model
+    (https://docs.python.org/3/reference/datamodel.html) as adapted to run
+    on the IREE VM.
+
+    This dialect aims for correctness of the subset of the Python Data Model
+    supported by IREE, with future work focused on completeness. Goals:
+      - Provide a suitable set of types and ops for trivially parsing Python
+        ASTs to this dialect, in a similar fashion as the CPython interpreter
+        parses to bytecode.
+      - Focused on embedded Python program extraction, where subsets of
+        programs are compiled from a running Python instance for later
+        hermetic execution.
+      - Makes IR decisions conducive to progress typeification, enabling
+        optimization benefits compared to fully untyped programs.
+  }];
+  let cppNamespace = "::mlir::iree_pydm";
+}
+
+class IREEPyDM_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<IREEPyDM_Dialect, mnemonic, traits>;
+class IREEPyDM_PureOp<string mnemonic, list<OpTrait> traits = []> :
+    Op<IREEPyDM_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])>;
+class IREEPyDM_TypeDef<string name> : TypeDef<IREEPyDM_Dialect, name>;
+
+#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_BASE_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h
new file mode 100644
index 0000000..fb714cb
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h
@@ -0,0 +1,45 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace iree_pydm {
+
+/// Base class for all unboxed primitive types.
+class PrimitiveType : public mlir::Type {
+ public:
+  using mlir::Type::Type;
+  static bool classof(Type type);
+};
+
+}  // namespace iree_pydm
+}  // namespace mlir
+
+// Include generated dialect code (this comment blocks clang-format from
+// clobbering order).
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsDialect.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsTypes.h.inc"
+
+namespace mlir {
+namespace iree_pydm {
+
+inline bool PrimitiveType::classof(Type type) {
+  // Must corresponds with each subclass.
+  return type.isa<BoolType, BytesType, IntegerType, ExceptionResultType,
+                  ListType, NoneType, RealType, StrType, TupleType, TypeType>();
+}
+
+}  // namespace iree_pydm
+}  // namespace mlir
+
+#endif  // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.td
new file mode 100644
index 0000000..3bc6167
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.td
@@ -0,0 +1,225 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_TD
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_TD
+
+include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMBase.td"
+
+//===----------------------------------------------------------------------===//
+// Unboxed Primitive Types
+//===----------------------------------------------------------------------===//
+
+// Declare a new primitive type.
+// When adding a new one, update the PrimitiveType::classof method in
+// IREEPyDMDialect.h.
+class IREEPyDM_PrimitiveTypeDef<string name> :
+    TypeDef<IREEPyDM_Dialect, name, /*traits=*/[],
+    /*baseCppClass=*/"::mlir::iree_pydm::PrimitiveType"> {
+}
+
+def IREEPyDM_AnyPrimitiveType : Type<
+    CPred<"$_self.isa<::mlir::iree_pydm::PrimitiveType>()">,
+    "unboxed primitive type",
+    "::mlir::iree_pydm::PrimitiveType">;
+
+def IREEPyDM_BoolType : IREEPyDM_PrimitiveTypeDef<"Bool"> {
+  let mnemonic = "bool";
+
+  let summary = "Type of bool values";
+
+  let description = [{
+    Represents boolean types in the data model, with values 'True' and 'False'.
+    Note that the data model considers the bool type to be a subtype of
+    integer, which is important during numeric promotion.
+  }];
+}
+
+def IREEPyDM_BytesType : IREEPyDM_PrimitiveTypeDef<"Bytes"> {
+  let mnemonic = "bytes";
+
+  let summary = "Type representing byte strings";
+
+  let description = [{
+    Represent Pythong 'bytes'.
+  }];
+}
+
+def IREEPyDM_ExceptionResultType : IREEPyDM_PrimitiveTypeDef<"ExceptionResult"> {
+  let mnemonic = "exception_result";
+
+  let summary = "Either successful or exceptional result";
+
+  let description = [{
+    The exception result connotes a logical success/failure state and can
+    also carry additional user-level exception data. It is used as a return
+    value from functions and many failable operations. Boxing a successful
+    exception result produces a None object. Boxing a failed result produces
+    an exception object.
+  }];
+}
+
+def IREEPyDM_IntegerType : IREEPyDM_PrimitiveTypeDef<"Integer"> {
+  let mnemonic = "integer";
+
+  let summary = "Type of integer values";
+
+  let description = [{
+    Represents the `numbers.Integral` type in the data model. At this abstract
+    level, it should be considered conformant with the data model (i.e.
+    unbounded). However, compiler flags will generally be used to interpret
+    this type in a more bounded fashion (i32, i64, etc).
+  }];
+}
+
+def IREEPyDM_ListType : IREEPyDM_PrimitiveTypeDef<"List"> {
+  let mnemonic = "list";
+
+  let summary = "Mutable sequence of boxed values";
+
+  let description = [{
+    Corresponds to the "Lists" type in the data model.
+  }];
+}
+
+def IREEPyDM_NoneType : IREEPyDM_PrimitiveTypeDef<"None"> {
+  let mnemonic = "none";
+
+  let summary = "Type of the single 'None' value";
+
+  let description = [{
+    Represents the 'None' type in the standard type hierarchy.
+  }];
+}
+
+def IREEPyDM_RealType : IREEPyDM_PrimitiveTypeDef<"Real"> {
+  let mnemonic = "real";
+
+  let summary = "Type of floating point values";
+
+  let description = [{
+    Represents the `numbers.Real` type in the data model. At this abstract
+    level, it should be considered conformant with the data model (i.e.
+    double precision). However, compiler flags will generally be used to
+    interpret this type in a more bounded fashion (f32).
+  }];
+}
+
+def IREEPyDM_StrType : IREEPyDM_PrimitiveTypeDef<"Str"> {
+  let mnemonic = "str";
+
+  let summary = "Type representing unicode strings";
+
+  let description = [{
+    Corresponds to the "Strings" type in the data model.
+  }];
+}
+
+def IREEPyDM_TupleType : IREEPyDM_PrimitiveTypeDef<"Tuple"> {
+  let mnemonic = "tuple";
+
+  let summary = "Immutable sequence of boxed values";
+
+  let description = [{
+    Corresponds to the "Tuples" type in the data model.
+  }];
+}
+
+def IREEPyDM_TypeType : IREEPyDM_PrimitiveTypeDef<"Type"> {
+  let mnemonic = "type";
+
+  let summary = "Type associated with a value";
+
+  let description = [{
+    Corresponds to the Python `Type` class. It is considered a primitive because
+    the data model cannot be represented without it.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Boxed objects
+//===----------------------------------------------------------------------===//
+
+def IREEPyDM_ObjectType : IREEPyDM_TypeDef<"Object"> {
+  let mnemonic = "object";
+
+  let summary = "Core data type having an identity, type and value";
+
+  let description = [{
+    In terms of a typical Python runtime, objects are the primary data type.
+    An object can represent every primitive and user defined type and value
+    in the system. The act of converting a primitive to an object is called
+    boxing, and doing so gives it an identity. Objects can be unboxed to
+    specific primitive types.
+
+    The system will function dynamically if specified completely in
+    terms of untyped object types. Objects can be parameterized with a specific
+    primitive type to support progressive typeification.
+  }];
+
+  let parameters = (ins
+    "::mlir::iree_pydm::PrimitiveType":$primitiveType
+  );
+
+  let printer = [{
+    $_printer << getMnemonic();
+    if (getImpl()->primitiveType)
+      $_printer << "<" << getImpl()->primitiveType << ">";
+  }];
+
+  let parser = [{
+    if (parser.parseOptionalLess())
+      return get($_ctxt, nullptr);
+
+    Type t;
+    if ($_parser.parseType(t))
+      return Type();
+    if ($_parser.parseGreater())
+      return Type();
+    if (auto primitiveType = t.dyn_cast<PrimitiveType>())
+      return get($_ctxt, primitiveType);
+    else {
+      $_parser.emitError(
+          $_parser.getNameLoc(), "expected a primitive type");
+      return Type();
+    }
+  }];
+
+  let extraClassDeclaration = [{
+    static bool isGenericObjectType(Type t) {
+      if (auto objectType = t.dyn_cast_or_null<ObjectType>())
+        return !objectType.getPrimitiveType();
+      return false;
+    }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Predicates and aggregate definitions
+//===----------------------------------------------------------------------===//
+
+def IREEPyDM_PrimitiveType : Type<CPred<
+  "$_self.isa<::mlir::iree_pydm::PrimitiveType>()">,
+  "Python unboxed primitive type">;
+
+def IREEPyDM_AnyValueType : AnyTypeOf<[
+  IREEPyDM_ObjectType,
+  IREEPyDM_PrimitiveType,
+], "Python boxed or unboxed value">;
+
+def IREEPyDM_GenericObjectType : Type<
+    CPred<"::mlir::iree_pydm::ObjectType::isGenericObjectType($_self)">,
+    "generic object",
+    "::mlir::iree_pydm::ObjectType">,
+    BuildableType<"$_builder.getType<::mlir::iree_pydm::ObjectType>(nullptr)">;
+
+// TODO: Upstream this. Missing from OpBase.td.
+def IREEPyDM_FlatSymbolRefArrayAttr :
+    TypedArrayAttrBase<FlatSymbolRefAttr, "flat symbol ref array attribute"> {
+  let constBuilderCall = ?;
+}
+
+#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_DIALECT_TD
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h
new file mode 100644
index 0000000..d08b0d2
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h
@@ -0,0 +1,23 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_IREEPYDM_IR_OPS_H
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_IREEPYDM_IR_OPS_H
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h.inc"
+
+#endif  // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_IREEPYDM_IR_OPS_H
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.td
new file mode 100644
index 0000000..2a887f9
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.td
@@ -0,0 +1,525 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_OPS_TD
+#define IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_OPS_TD
+
+include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Functions
+//===----------------------------------------------------------------------===//
+
+// TODO: Move arg attributes to arguments, and generally rework the free_vars,
+// cell_vars attributes (with better modeling, they may not be needed at all).
+def IREEPyDM_FuncOp : IREEPyDM_Op<"func", [
+    IsolatedFromAbove,
+    FunctionLike,
+    CallableOpInterface,
+    Symbol]> {
+  let summary = "Python func";
+  let description = [{
+    Python functions map arguments to results and have the following additional
+    characteristics:
+      - Have arguments that are either !object or primitive types. Typical user
+        funcs will just be !object based unless if refined.
+      - Returns a (LogicalResult, value). The LogicalResult will be expanded in
+        the future to be a full exception record, and the value is an
+        object/primitive (which may be a tuple/sequence if unpacking).
+      - Have an explicit list of free and cell variable names which can be
+        accessed from within the function.
+      - ... other things as needed.
+      - Entry block with arguments matching the function arguments.
+
+    These functions are CFG based, functioning as a Block with a single implicit
+    exception handler which matches all exceptions and exits with a failing
+    status on failure.
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttr:$type,
+                       StrArrayAttr:$arg_names,
+                       StrArrayAttr:$free_vars,
+                       StrArrayAttr:$cell_vars,
+                       OptionalAttr<StrAttr>:$sym_visibility);
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    /// Add an entry block to an empty function and set up the block arguments
+    /// to match the signature of the function.
+    Block *addEntryBlock();
+
+    Block *addBlock() {
+      assert(!empty() && "function should at least have an entry block");
+      push_back(new Block());
+      return &back();
+    }
+
+    /// Returns the type of this function.
+    FunctionType getType() {
+      return getOperation()->getAttrOfType<TypeAttr>(getTypeAttrName())
+          .getValue()
+          .cast<FunctionType>();
+    }
+
+    /// Hook for OpTrait::FunctionLike, returns the number of function
+    /// arguments. Depends on the type attribute being correct as checked by
+    /// verifyType.
+    unsigned getNumFuncArguments() { return getType().getInputs().size(); }
+
+    /// Hook for OpTrait::FunctionLike, returns the number of function results.
+    /// Depends on the type attribute being correct as checked by verifyType.
+    unsigned getNumFuncResults() { return getType().getResults().size(); }
+
+    /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
+    /// attribute is present. This can check for preconditions of the
+    /// getNumArguments hook not failing.
+    LogicalResult verifyType();
+
+    Region *getCallableRegion() { return &body(); }
+    ArrayRef<Type> getCallableResults() {
+      return getType().getResults();
+    }
+  }];
+
+  let parser = [{ return ::parseFuncOp(parser, result); }];
+  let printer = [{ return ::print(*this, p); }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def IREEPyDM_ReturnOp : IREEPyDM_Op<"return", [
+    NoSideEffect,
+    HasParent<"FuncOp">,
+    ReturnLike,
+    Terminator]> {
+  let summary = "Successful return from a Python function";
+  let description = [{
+    Returns a value from a Python function.
+  }];
+
+  let arguments = (ins IREEPyDM_AnyValueType:$value);
+  let assemblyFormat = [{
+    $value `:` type($value) attr-dict
+  }];
+}
+
+def IREEPyDM_RaiseOnFailureOp : IREEPyDM_Op<"raise_on_failure", [
+    HasParent<"FuncOp">]> {
+  let summary = "Performs a non-local exit on failure of an ExceptionResult";
+  let description = [{
+    This op handles the vast majority of cases where a failure simply needs
+    to be propagated to the next most frame (typically returning it from
+    a function).
+
+    If the `exc_result` represents a failure, control will not proceed
+    past this operation.
+  }];
+
+  let arguments = (ins IREEPyDM_ExceptionResultType:$exc_result);
+  let assemblyFormat = [{
+    $exc_result `:` type($exc_result) attr-dict
+  }];
+}
+
+def IREEPyDM_CallOp : IREEPyDM_Op<"call", [
+    CallOpInterface, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Call a `func` op";
+  let description = [{
+    This is the most primitive call operation that all other static calls
+    decay into.
+
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$callee,
+                       Variadic<IREEPyDM_AnyValueType>:$operands);
+  let results = (outs
+      IREEPyDM_ExceptionResultType:$exc_result,
+      IREEPyDM_AnyValueType:$result);
+
+  let extraClassDeclaration = [{
+    StringRef getCallee() { return callee(); }
+    FunctionType getCalleeType();
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    operand_iterator arg_operand_begin() { return operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+
+    /// Return the callee of this operation.
+    CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+    }
+  }];
+
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+}
+
+def IREEPyDM_PatternMatchCallOp : IREEPyDM_Op<"pattern_match_call", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Combines multiple functions together for generic dispatch";
+  let description = [{
+    This op enables specialized arity and type based dispatch via simple
+    pattern matching. It is generally used by the implementation to provide
+    both fully generic, runtime based implementations with the option to
+    select a specialization.
+
+    Alternatives are split into two groups:
+      - 'generic': Provides a maximally generic fallback implementation.
+        Generally an alternative will be matched purely based on arity and
+        structure of arguments. First match wins.
+      - 'specific': Provides a specific, strongly typed implementation. Matching
+        is by full type signature. First match wins.
+
+    Generally, during compilation, such calls will decay into regular calls
+    to one of the referenced, backing functions. How this is
+    done depends on phase:
+      - During canonicalization: A match to a specific function will be taken
+        immediately.
+      - After type inference: Any leftovers must be matched to a generic (by
+        an explicit pass) for runtime resolution.
+  }];
+
+  let arguments = (ins IREEPyDM_FlatSymbolRefArrayAttr:$generic_match,
+                       IREEPyDM_FlatSymbolRefArrayAttr:$specific_match,
+                       Variadic<IREEPyDM_AnyValueType>:$operands);
+  let results = (outs
+      IREEPyDM_ExceptionResultType:$exc_result,
+      IREEPyDM_AnyValueType:$result);
+
+  let assemblyFormat = [{
+    `(` $operands `)` `:` functional-type($operands, results)
+    `matching` `generic` $generic_match `specific` $specific_match
+    attr-dict
+  }];
+}
+
+def IREEPyDM_DynamicCallOp : IREEPyDM_PureOp<"dynamic_call", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Dynamic dispatch to a `func`.";
+  let description = [{
+    Performs full dynamic dispatch to a function. Most imports start in this
+    form and are progressively refined to something more specific as more
+    information is known. In general, for well-formed programs, there should
+    not be any `dynamic_call`s left at the lowest levels of the compiler.
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$callee,
+                       Variadic<IREEPyDM_AnyValueType>:$operands);
+  let results = (outs
+      IREEPyDM_ExceptionResultType:$exc_result,
+      IREEPyDM_ObjectType:$result);
+
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Boxing/unboxing
+//===----------------------------------------------------------------------===//
+
+def IREEPyDM_BoxOp : IREEPyDM_PureOp<"box"> {
+  let summary = "Boxes a primitive into an object";
+  let description = [{
+    Given a PrimitiveType, boxes it into an appropriate !object, establishing
+    identity. For value typed primitives, it is always safe to box, but for
+    reference primitives, the providence must be tracked and the original boxed
+    value used (vs boxing a new one). Failure to do so will result in aliased
+    objects.
+  }];
+
+  let arguments = (ins
+    IREEPyDM_AnyPrimitiveType:$primitive
+  );
+  let results = (outs
+    IREEPyDM_ObjectType:$object
+  );
+
+  let assemblyFormat = [{
+    $primitive `:` type($primitive)  `->` type($object) attr-dict
+  }];
+}
+
+def IREEPyDM_UnboxOp : IREEPyDM_PureOp<"unbox"> {
+  let summary = "Unboxes an object to a specific primitive type";
+  let description = [{
+    Unboxes to a primitive, returning a failure result and a default initialized
+    primitive value on failure to unbox.
+  }];
+
+  let arguments = (ins
+    IREEPyDM_ObjectType:$object
+  );
+  let results = (outs
+    IREEPyDM_ExceptionResultType:$status,
+    IREEPyDM_AnyPrimitiveType:$primitive
+  );
+
+  let assemblyFormat = [{
+    $object `:` type($object) `->` type($primitive) attr-dict
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Name access
+// While the CPython nomenclature is a bit inconsistent, we follow it where
+// it makes sense. Names are split into a few categories:
+//   - Free Variables: Local variables that are not visible outside of the
+//     containing (actual, not lexical) function. CPython bytecode refers
+//     to operations on these as "FAST" (i.e. LOAD_FAST, STORE_FAST, DEL_FAST).
+//   - Cell Variables: Variables that exist as part of the function closure
+//     and are resolved through a level of indirection (cell).
+//   - Global Variables: Variables that are resolved via the function's
+//     globals().
+// At this level, all operations are done by name (at the CPython instruction
+// level, it is by ordinal).
+//===----------------------------------------------------------------------===//
+
+// TODO: Consider adding a dedicated free_var op to bind free variables and
+// move the string attribute there.
+def IREEPyDM_LoadFreeVarOp : IREEPyDM_PureOp<"load_free_var"> {
+  let summary = "Loads a boxed free variable";
+  let description = [{
+    Loads the boxed object for a free variable. This is not failable at
+    runtime, and if the slot is not initialized, it will contain a special
+    NotInitialized primitive which cannot convert to anything else.
+
+    When importing Python programs, they must use boxed accessors to free
+    variable storage. Further analysis and transformation can promote these
+    to unboxed access.
+  }];
+
+  let arguments = (ins StrAttr:$name);
+  let results = (outs IREEPyDM_ObjectType:$value);
+
+  let assemblyFormat = [{
+    $name `->` type($value) attr-dict
+  }];
+}
+
+def IREEPyDM_StoreFreeVarOp : IREEPyDM_Op<"store_free_var"> {
+  let summary = "Stores a boxed free variable";
+  let description = [{
+    Stores a boxed value to a free variable slot.
+  }];
+
+  let arguments = (ins StrAttr:$name, IREEPyDM_ObjectType:$value);
+
+  let assemblyFormat = [{
+    $name `,` $value `:` type($value) attr-dict
+  }];
+}
+
+def IREEPyDM_LoadFreeVarUnboxedOp : IREEPyDM_PureOp<"load_free_var_unboxed"> {
+  let summary = "Loads an unboxed free variable";
+  let description = [{
+    Loads an unboxed value from a free variable slot. This will not be
+    emitted in regular programs but can be used in intrinsics and by
+    optimizations to remove boxing when types are guaranteed. It is a
+    program error to load a mismatched type and will produce undefined
+    behavior.
+  }];
+
+  let arguments = (ins StrAttr:$name);
+  let results = (outs IREEPyDM_PrimitiveType:$value);
+
+  let assemblyFormat = [{
+    $name `->` type($value) attr-dict
+  }];
+}
+
+def IREEPyDM_StoreFreeVarUnboxedOp : IREEPyDM_Op<"store_free_var_unboxed"> {
+  let summary = "Stores an unboxed free variable";
+  let description = [{
+    Stores an unboxed value to a free variable slot.
+  }];
+
+  let arguments = (ins StrAttr:$name, IREEPyDM_PrimitiveType:$value);
+
+  let assemblyFormat = [{
+    $name `,` $value `:` type($value) attr-dict
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Value constructors
+//===----------------------------------------------------------------------===//
+
+def IREEPyDM_ConstantOp : IREEPyDM_PureOp<"constant", [ConstantLike]> {
+  let summary = "Constants for value types";
+  let description = [{
+    This op supports immutable value types that have direct coding as MLIR
+    attributes:
+      IntType -> IntegerAttr<i64>
+      FloatType -> FloatAttr<double>
+      StrType -> StringAttr
+      BytesType -> BytesAttr
+      BoolType -> IntegerAttr<i1>
+  }];
+
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs AnyType);
+
+  let assemblyFormat = [{
+    $value `->` type(results) attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    Attribute getValue() { return (*this)->getAttr("value"); }
+  }];
+}
+
+def IREEPyDM_NoneOp : IREEPyDM_PureOp<"none"> {
+  let summary = "Gets the singleton NoneType primitive value";
+  let results = (outs IREEPyDM_NoneType:$value);
+  let assemblyFormat = [{
+    `->` type($value) attr-dict
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Type conversion
+//===----------------------------------------------------------------------===//
+
+def IREEPyDM_AsBoolOp : IREEPyDM_PureOp<"as_bool"> {
+  let summary = "Evaluates an arbitrary value for its truthiness";
+  let arguments = (ins IREEPyDM_AnyValueType:$value);
+  let results = (outs IREEPyDM_BoolType);
+  let assemblyFormat = [{
+    $value `:` type($value) `->` type(results) attr-dict
+  }];
+}
+
+def IREEPyDM_BoolToPredOp : IREEPyDM_PureOp<"bool_to_pred"> {
+  let summary = "Extracts a pred (i1) value from a BoolType";
+  let description = [{
+    This dialect does not use native MLIR IntegerTypes. This is used for
+    bridging to other dialects.
+  }];
+  let arguments = (ins IREEPyDM_BoolType:$value);
+  let results = (outs I1);
+  let assemblyFormat = [{
+    $value `:` type($value) `->` type(results) attr-dict
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Control flow
+//===----------------------------------------------------------------------===//
+
+def IREEPyDM_SelectOp : IREEPyDM_PureOp<"select", [
+  AllTypesMatch<["true_value", "false_value", "result"]>
+    ]> {
+  let summary = "Select a true or false value based on condition";
+  let arguments = (ins
+      IREEPyDM_BoolType:$condition,
+      IREEPyDM_AnyValueType:$true_value,
+      IREEPyDM_AnyValueType:$false_value);
+  let results = (outs IREEPyDM_AnyValueType:$result);
+  let assemblyFormat = [{
+    $condition `,` $true_value `,` $false_value `:` type($result) attr-dict
+  }];
+}
+
+def IREEPyDM_ExprStatementDiscardOp : IREEPyDM_Op<"expr_statement_discard"> {
+  let summary = "Anchors an expression evaluated as a statement";
+  let description = [{
+    This op is used early during import to keep a statement-evaluated expression
+    live until more effect information is available to anchor it properly.
+  }];
+  let arguments = (ins IREEPyDM_AnyValueType:$value);
+  let assemblyFormat = [{
+    $value `:` type($value) attr-dict
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Computation
+//===----------------------------------------------------------------------===//
+
+def IREEPyDM_DynamicBinaryPromoteOp : IREEPyDM_PureOp<"dynamic_binary_promote"> {
+  let summary = "Promotes two arguments that may be of numeric types";
+  let description = [{
+    Takes two values of arbitrary type which may be input to a following
+    binary arithmetic operation. If they are both numeric, returns the two
+    arguments, promoted to the bounding, common numeric type. Otherwise,
+    returns the inputs as-is.
+
+    This op will always produce boxed, type erased values since that is always
+    legal for runtime evaluation. However, for known input types, it will
+    canonicalize to various identities or fixed numerics, making the trailing
+    boxing trivial to type propagate past.
+
+    Note that this operation could be implemented with language-level pattern
+    matching, but it is both very frequently used and complicated from a
+    type system perspective. As such, we make it a black box with custom
+    canonicalization and lowering.
+  }];
+
+  let arguments = (ins
+    IREEPyDM_AnyValueType:$left,
+    IREEPyDM_AnyValueType:$right);
+  let results = (outs
+    IREEPyDM_GenericObjectType:$left_prime,
+    IREEPyDM_GenericObjectType:$right_prime);
+  let assemblyFormat = [{
+    $left `,` $right `:` type($left) `,` type($right) attr-dict
+  }];
+}
+
+def IREEPyDM_ApplyBinaryOp : IREEPyDM_PureOp<"apply_binary"> {
+  let summary = "Applies one of Python's binary operations";
+  let description = [{
+    The operation to apply is specified as per the data model:
+      https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
+    Here it is presented with leading and trailing double underscores (i.e.
+    "add", "sub", etc).
+
+    Numeric types must be promoted to a common type prior to application.
+  }];
+  let arguments = (ins
+    StrAttr:$dunder_name,
+    IREEPyDM_AnyValueType:$left,
+    IREEPyDM_AnyValueType:$right);
+  let results = (outs IREEPyDM_AnyValueType:$result);
+  let assemblyFormat = [{
+    $dunder_name `,` $left `,` $right `:` type(operands) `->` type(results) attr-dict
+  }];
+}
+
+def IREEPyDM_ApplyCompareOp : IREEPyDM_PureOp<"apply_compare"> {
+  let summary = "Performs a binary comparison";
+  let description = [{
+    Performs a comparison between two operands.
+
+    Op name is based on the dunder name of the rich comparison ops in the
+    data model:
+      "lt", "le", "eq", "ne", "gt", "ge"
+    With extensions for those that do not have a dunder name:
+      "is", "isnot", "in", "notin"
+
+    Numeric types must be promoted to a common type prior to application.
+  }];
+  let arguments = (ins
+    StrAttr:$dunder_name,
+    IREEPyDM_AnyValueType:$left,
+    IREEPyDM_AnyValueType:$right);
+  let results = (outs IREEPyDM_BoolType:$result);
+  let assemblyFormat = [{
+    $dunder_name `,` $left `,` $right `:` type(operands) attr-dict
+  }];
+}
+
+#endif // IREE_LLVM_EXTERNAL_PROJECTS_IREE_DIALECTS_DIALECT_IREEPYDM_IR_OPS_TD
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
index 928d6f7..fc3317e 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/CMakeLists.txt
@@ -1,7 +1,10 @@
 add_mlir_public_c_api_library(IREEDialectsCAPI
   Dialects.cpp
+  Utils.cpp
   LINK_LIBS PUBLIC
+  MLIRIR
   IREEDialectsIREEDialect
+  IREEDialectsIREEPyDMDialect
 )
 
 iree_dialects_target_includes(IREEDialectsCAPI)
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
index 924fbbc..650a2c3 100644
--- a/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Dialects.cpp
@@ -7,6 +7,50 @@
 #include "iree-dialects-c/Dialects.h"
 
 #include "iree-dialects/Dialect/IREE/IREEDialect.h"
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
 #include "mlir/CAPI/Registration.h"
 
+//===----------------------------------------------------------------------===//
+// IREEDialect
+//===----------------------------------------------------------------------===//
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IREE, iree, mlir::iree::IREEDialect)
+
+//===----------------------------------------------------------------------===//
+// IREEPyDMDialect
+//===----------------------------------------------------------------------===//
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IREEPyDM, iree_pydm,
+                                      mlir::iree_pydm::IREEPyDMDialect)
+
+bool mlirTypeIsAIREEPyDMPrimitiveType(MlirType type) {
+  return unwrap(type).isa<mlir::iree_pydm::PrimitiveType>();
+}
+
+#define IREEPYDM_DEFINE_NULLARY_TYPE(Name)                      \
+  bool mlirTypeIsAIREEPyDM##Name(MlirType type) {               \
+    return unwrap(type).isa<mlir::iree_pydm::Name##Type>();     \
+  }                                                             \
+  MlirType mlirIREEPyDM##Name##TypeGet(MlirContext ctx) {       \
+    return wrap(mlir::iree_pydm::Name##Type::get(unwrap(ctx))); \
+  }
+
+IREEPYDM_DEFINE_NULLARY_TYPE(Bool)
+IREEPYDM_DEFINE_NULLARY_TYPE(Bytes)
+IREEPYDM_DEFINE_NULLARY_TYPE(Integer)
+IREEPYDM_DEFINE_NULLARY_TYPE(ExceptionResult)
+IREEPYDM_DEFINE_NULLARY_TYPE(List)
+IREEPYDM_DEFINE_NULLARY_TYPE(None)
+IREEPYDM_DEFINE_NULLARY_TYPE(Real)
+IREEPYDM_DEFINE_NULLARY_TYPE(Str)
+IREEPYDM_DEFINE_NULLARY_TYPE(Tuple)
+IREEPYDM_DEFINE_NULLARY_TYPE(Type)
+
+bool mlirTypeIsAIREEPyDMObject(MlirType type) {
+  return unwrap(type).isa<mlir::iree_pydm::ObjectType>();
+}
+
+MlirType mlirIREEPyDMObjectTypeGet(MlirContext ctx, MlirType primitive) {
+  auto cppType = unwrap(primitive).cast<mlir::iree_pydm::PrimitiveType>();
+  return wrap(mlir::iree_pydm::ObjectType::get(unwrap(ctx), cppType));
+}
diff --git a/llvm-external-projects/iree-dialects/lib/CAPI/Utils.cpp b/llvm-external-projects/iree-dialects/lib/CAPI/Utils.cpp
new file mode 100644
index 0000000..d704f2b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/CAPI/Utils.cpp
@@ -0,0 +1,20 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects-c/Utils.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/SymbolTable.h"
+
+using namespace mlir;
+
+MlirOperation ireeLookupNearestSymbolFrom(MlirOperation fromOp,
+                                          MlirAttribute symbolRefAttr) {
+  auto symbolRefAttrCpp = unwrap(symbolRefAttr).cast<SymbolRefAttr>();
+  return wrap(
+      SymbolTable::lookupNearestSymbolFrom(unwrap(fromOp), symbolRefAttrCpp));
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
index 952be9f..61df04e 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IREE)
+add_subdirectory(IREEPyDM)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/CMakeLists.txt
new file mode 100644
index 0000000..f33061b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
new file mode 100644
index 0000000..cb991b8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(IREEDialectsIREEPyDMDialect
+  IREEPyDMDialect.cpp
+  IREEPyDMOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${IREE_DIALECTS_SOURCE_DIR}/include
+
+  DEPENDS
+  MLIRIREEPyDMOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+)
+
+iree_dialects_target_includes(IREEDialectsIREEPyDMDialect)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMDialect.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMDialect.cpp
new file mode 100644
index 0000000..50043ae
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMDialect.cpp
@@ -0,0 +1,43 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsDialect.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsTypes.cpp.inc"
+
+void IREEPyDMDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOpsTypes.cpp.inc"
+      >();
+  addOperations<
+#define GET_OP_LIST
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp.inc"
+      >();
+}
+
+Type IREEPyDMDialect::parseType(DialectAsmParser &parser) const {
+  StringRef typeTag;
+  Type genType;
+  if (succeeded(parser.parseKeyword(&typeTag)))
+    generatedTypeParser(getContext(), parser, typeTag, genType);
+  return genType;
+}
+
+void IREEPyDMDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  (void)generatedTypePrinter(type, printer);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp
new file mode 100644
index 0000000..1bb1fa4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp
@@ -0,0 +1,149 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.h"
+
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+using namespace mlir::iree_pydm;
+
+using PyCallOp = mlir::iree_pydm::CallOp;
+using PyFuncOp = mlir::iree_pydm::FuncOp;
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PyFuncOp::verifyType() {
+  // TODO: Enforce arg/result invariants.
+  return success();
+}
+
+static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
+  auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
+                          ArrayRef<Type> results,
+                          function_like_impl::VariadicFlag, std::string &) {
+    return builder.getFunctionType(argTypes, results);
+  };
+
+  return function_like_impl::parseFunctionLikeOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+static void print(PyFuncOp op, OpAsmPrinter &p) {
+  FunctionType fnType = op.getType();
+  function_like_impl::printFunctionLikeOp(
+      p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
+}
+
+static LogicalResult verify(PyFuncOp op) {
+  // TODO: Enforce invariants.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PatternMatchCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PatternMatchCallOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto verifySymbols = [&](ArrayAttr symbols) -> LogicalResult {
+    for (auto symbolAttr : symbols) {
+      auto symbol = symbolAttr.cast<FlatSymbolRefAttr>();
+      PyFuncOp fn =
+          symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, symbol);
+      if (!fn)
+        return emitOpError() << "'" << symbol.getValue()
+                             << "' does not reference a valid function";
+    }
+    return success();
+  };
+  auto genericsAttr = (*this)->getAttrOfType<ArrayAttr>("generic_match");
+  if (!genericsAttr)
+    return emitOpError(
+        "requires a 'generic_match' array of symbol reference attributes");
+  if (failed(verifySymbols(genericsAttr))) return failure();
+
+  auto specificsAttr = (*this)->getAttrOfType<ArrayAttr>("specific_match");
+  if (!specificsAttr)
+    return emitOpError(
+        "requires a 'specific_match' array of symbol reference attributes");
+  if (failed(verifySymbols(specificsAttr))) return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PyCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  PyFuncOp fn = symbolTable.lookupNearestSymbolFrom<PyFuncOp>(*this, fnAttr);
+  if (!fn)
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
+
+  // Verify that the operand and result types match the callee.
+  auto fnType = fn.getType();
+  if (fnType.getNumInputs() != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i).getType() != fnType.getInput(i)) {
+      return emitOpError("operand type mismatch: expected operand type ")
+             << fnType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+    }
+  }
+
+  if (fnType.getNumResults() != getNumResults())
+    return emitOpError("incorrect number of results for callee");
+
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+    if (getResult(i).getType() != fnType.getResult(i)) {
+      auto diag = emitOpError("result type mismatch at index ") << i;
+      diag.attachNote() << "      op result types: " << getResultTypes();
+      diag.attachNote() << "function result types: " << fnType.getResults();
+      return diag;
+    }
+  }
+
+  return success();
+}
+
+FunctionType PyCallOp::getCalleeType() {
+  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// DynamicCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DynamicCallOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  Operation *fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr);
+  if (!fn || !isa<PyFuncOp>(fn))
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
+  return success();
+}
+
+#define GET_OP_CLASSES
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.cpp.inc"
diff --git a/llvm-external-projects/iree-dialects/python/CMakeLists.txt b/llvm-external-projects/iree-dialects/python/CMakeLists.txt
index bfc56ef..d91b274 100644
--- a/llvm-external-projects/iree-dialects/python/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/python/CMakeLists.txt
@@ -19,6 +19,18 @@
   DIALECT_NAME iree
 )
 
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT IREEDialectsPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/IreePyDmBinding.td
+  SOURCES
+    dialects/_iree_pydm_ops_ext.py
+    dialects/iree_pydm/__init__.py
+  SOURCES_GLOB
+    dialects/iree_pydm/importer/*.py
+  DIALECT_NAME iree_pydm
+)
+
 ################################################################################
 # Extensions
 ################################################################################
@@ -44,6 +56,7 @@
   # TODO: Core is now implicitly building/registering all dialects, increasing
   # build burden by ~5x. Make it stop.
   MLIRPythonSources.Core
+  MLIRPythonSources.Dialects.std
   IREEDialectsPythonSources
   IREEDialectsPythonExtensions
 )
diff --git a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
index 743605c..8bf5864 100644
--- a/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
+++ b/llvm-external-projects/iree-dialects/python/IREEDialectsModule.cpp
@@ -5,17 +5,50 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree-dialects-c/Dialects.h"
+#include "iree-dialects-c/Utils.h"
 #include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/Diagnostics.h"
 #include "mlir-c/Registration.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
 namespace py = pybind11;
+using namespace mlir::python::adaptors;
 
 PYBIND11_MODULE(_ireeDialects, m) {
   m.doc() = "iree-dialects main python extension";
 
+  auto irModule = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+  auto typeClass = irModule.attr("Type");
+
+  //===--------------------------------------------------------------------===//
+  // Utils
+  //===--------------------------------------------------------------------===//
+
   m.def(
-      "register_iree_dialect",
+      "lookup_nearest_symbol_from",
+      [](MlirOperation fromOp, MlirAttribute symbol) {
+        if (!mlirAttributeIsASymbolRef(symbol)) {
+          throw std::invalid_argument("expected a SymbolRefAttr");
+        }
+        return ireeLookupNearestSymbolFrom(fromOp, symbol);
+      },
+      py::arg("fromOp"), py::arg("symbol"));
+
+  // TODO: Upstream this into the main Python bindings.
+  m.def(
+      "emit_error",
+      [](MlirLocation loc, std::string message) {
+        mlirEmitError(loc, message.c_str());
+      },
+      py::arg("loc"), py::arg("message"));
+
+  //===--------------------------------------------------------------------===//
+  // IREEDialect
+  //===--------------------------------------------------------------------===//
+  auto iree_m = m.def_submodule("iree");
+  iree_m.def(
+      "register_dialect",
       [](MlirContext context, bool load) {
         MlirDialectHandle handle = mlirGetDialectHandle__iree__();
         mlirDialectHandleRegisterDialect(handle, context);
@@ -23,5 +56,62 @@
           mlirDialectHandleLoadDialect(handle, context);
         }
       },
-      py::arg("context"), py::arg("load") = true);
+      py::arg("context") = py::none(), py::arg("load") = true);
+
+  //===--------------------------------------------------------------------===//
+  // IREEPyDMDialect
+  //===--------------------------------------------------------------------===//
+  auto iree_pydm_m = m.def_submodule("iree_pydm");
+
+  iree_pydm_m.def(
+      "register_dialect",
+      [](MlirContext context, bool load) {
+        MlirDialectHandle handle = mlirGetDialectHandle__iree_pydm__();
+        mlirDialectHandleRegisterDialect(handle, context);
+        if (load) {
+          mlirDialectHandleLoadDialect(handle, context);
+        }
+      },
+      py::arg("context") = py::none(), py::arg("load") = true);
+
+#define DEFINE_IREEPYDM_NULLARY_TYPE(Name)                                 \
+  mlir_type_subclass(iree_pydm_m, #Name "Type", mlirTypeIsAIREEPyDM##Name, \
+                     typeClass)                                            \
+      .def_classmethod(                                                    \
+          "get",                                                           \
+          [](py::object cls, MlirContext context) {                        \
+            return cls(mlirIREEPyDM##Name##TypeGet(context));              \
+          },                                                               \
+          py::arg("cls"), py::arg("context") = py::none());
+
+  DEFINE_IREEPYDM_NULLARY_TYPE(Bool)
+  DEFINE_IREEPYDM_NULLARY_TYPE(Bytes)
+  DEFINE_IREEPYDM_NULLARY_TYPE(ExceptionResult)
+  DEFINE_IREEPYDM_NULLARY_TYPE(Integer)
+  DEFINE_IREEPYDM_NULLARY_TYPE(List)
+  DEFINE_IREEPYDM_NULLARY_TYPE(None)
+  DEFINE_IREEPYDM_NULLARY_TYPE(Real)
+  DEFINE_IREEPYDM_NULLARY_TYPE(Str)
+  DEFINE_IREEPYDM_NULLARY_TYPE(Tuple)
+  DEFINE_IREEPYDM_NULLARY_TYPE(Type)
+
+  mlir_type_subclass(iree_pydm_m, "ObjectType", mlirTypeIsAIREEPyDMObject,
+                     typeClass)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirContext context) {
+            return cls(mlirIREEPyDMObjectTypeGet(context, {nullptr}));
+          },
+          py::arg("cls"), py::arg("context") = py::none())
+      .def_classmethod(
+          "get_typed",
+          [](py::object cls, MlirType type) {
+            if (!mlirTypeIsAIREEPyDMPrimitiveType(type)) {
+              throw std::invalid_argument(
+                  "expected a primitive type when constructing object");
+            }
+            MlirContext context = mlirTypeGetContext(type);
+            return cls(mlirIREEPyDMObjectTypeGet(context, type));
+          },
+          py::arg("cls"), py::arg("type"));
 }
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/IreePyDmBinding.td b/llvm-external-projects/iree-dialects/python/mlir/dialects/IreePyDmBinding.td
new file mode 100644
index 0000000..900c7d1
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/IreePyDmBinding.td
@@ -0,0 +1,13 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef PYTHON_BINDINGS_IREE_PYDM_OPS
+#define PYTHON_BINDINGS_IREE_PYDM_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMOps.td"
+
+#endif // PYTHON_BINDINGS_IREE_PYDM_OPS
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/_iree_pydm_ops_ext.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/_iree_pydm_ops_ext.py
new file mode 100644
index 0000000..5155972
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/_iree_pydm_ops_ext.py
@@ -0,0 +1,49 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# pytype: disable=attribute-error
+
+from .. import ir
+
+
+class FuncOp:
+  """Specialization for the func op class."""
+
+  @property
+  def body(self):
+    return self.regions[0]
+
+  @property
+  def type(self):
+    return ir.FunctionType(ir.TypeAttr(self.attributes["type"]).value)
+
+  @property
+  def py_return_type(self) -> ir.Type:
+    return self.type.results[1]
+
+  @property
+  def entry_block(self):
+    return self.regions[0].blocks[0]
+
+  # TODO: Why aren't these getters being auto-generated?
+  @property
+  def arg_names(self) -> ir.ArrayAttr:
+    return ir.ArrayAttr(self.attributes["arg_names"])
+
+  @property
+  def free_vars(self) -> ir.ArrayAttr:
+    return ir.ArrayAttr(self.attributes["free_vars"])
+
+  @property
+  def cell_vars(self) -> ir.ArrayAttr:
+    return ir.ArrayAttr(self.attributes["cell_vars"])
+
+  def add_entry_block(self):
+    """Add an entry block to the function body using the function signature to
+    infer block arguments. Returns the newly created block.
+    """
+    self.body.blocks.append(*self.type.inputs)
+    return self.body.blocks[0]
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree.py
index 8e51281..c70532b 100644
--- a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree.py
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree.py
@@ -5,4 +5,4 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._iree_ops_gen import *
-from .._mlir_libs._ireeDialects import register_iree_dialect
+from .._mlir_libs._ireeDialects.iree import *
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/__init__.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/__init__.py
new file mode 100644
index 0000000..8c00c0d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._iree_pydm_ops_gen import *
+from ..._mlir_libs._ireeDialects.iree_pydm import *
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/__init__.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/__init__.py
new file mode 100644
index 0000000..8faf46e
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .util import ImportContext, ImportHooks, ImportStage
+from .importer import Importer
+from .intrinsic_def import def_ir_macro_intrinsic, def_pattern_call_intrinsic, def_pyfunc_intrinsic
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/builtins_intrinsics.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/builtins_intrinsics.py
new file mode 100644
index 0000000..be5d0ac
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/builtins_intrinsics.py
@@ -0,0 +1,19 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+from .util import ImportContext, Intrinsic
+from ... import iree_pydm as d
+from .... import ir
+
+
+@Intrinsic.make_singleton
+class print(Intrinsic):
+
+  # TODO: Obviously not right.
+  def emit_immediate(self, ic: ImportContext) -> ir.Value:
+    return d.NoneOp(d.NoneType.get(ic.context), ip=ic.ip, loc=ic.loc).result
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/importer.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/importer.py
new file mode 100644
index 0000000..9cae4d6
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/importer.py
@@ -0,0 +1,702 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional, Sequence, Tuple, List, Union
+
+import ast
+import inspect
+import logging
+import sys
+import textwrap
+
+from .util import DefaultImportHooks, ImportContext, ImportHooks, ImportStage, Intrinsic
+
+from ... import iree_pydm as d
+from ... import std as std_d
+from .... import ir
+
+
+class Importer:
+  """Imports a Python construct into IR."""
+  __slots__ = [
+      "ic",
+      "hooks",
+  ]
+
+  def __init__(self, ic: ImportContext, hooks: Optional[ImportHooks] = None):
+    self.ic = ic
+    self.hooks = hooks or DefaultImportHooks()
+
+  def import_global_function(self,
+                             f,
+                             *,
+                             symbol: Optional[str] = None,
+                             visibility: Optional[str] = None) -> d.FuncOp:
+    """Imports a live Python global function.
+
+    This is just a placeholder of the simplest possible thing until a proper,
+    general mechanism is created.
+    """
+    ic = self.ic
+    filename, root_node = _get_function_ast(f)
+    fd_node = root_node.body[0]  # pytype: disable=attribute-error
+    self.ic.set_file_line_col(filename, fd_node.lineno, fd_node.col_offset)
+
+    if not symbol:
+      symbol = fd_node.name
+
+    # Define the function.
+    # TODO: Much more needs to be done here (arg/result mapping, etc)
+    logging.debug(":::::::")
+    logging.debug("::: Importing global function %s:\n%s", symbol,
+                  ast.dump(fd_node, include_attributes=False))
+
+    # Main import uses a FunctionContext but we aren't ready to create it yet.
+    dummy_stage = ImportStage(ic, self.hooks)
+
+    # Things we are short-cutting by inspecting the live function:
+    #   - freevars
+    #   - cellvars
+    #   - globals
+    #   - arg definitions
+    #   - annotation parsing
+    # Also, since this is just a toy right now, sticking to pos params.
+    code_object = f.__code__
+    with ic.scoped_ip(ir.InsertionPoint(ic.module.body)) as ip, ip, ic.loc:
+      f_signature = inspect.signature(f)
+      f_params = f_signature.parameters
+      arg_names = list(f_params.keys())
+      var_names = list(code_object.co_varnames)
+      f_input_types = [
+          self.hooks.resolve_annotation_to_type(dummy_stage, p.annotation)
+          for p in f_params.values()
+      ]
+      f_arg_names = ir.ArrayAttr.get(
+          [ir.StringAttr.get(name) for name in arg_names])
+      f_var_names = ir.ArrayAttr.get(
+          [ir.StringAttr.get(name) for name in var_names])
+      f_return_type = self.hooks.resolve_annotation_to_type(
+          dummy_stage, f_signature.return_annotation)
+      ir_f_type = ir.FunctionType.get(
+          f_input_types, [d.ExceptionResultType.get(), f_return_type],
+          context=ic.context)
+      f_op = d.FuncOp(
+          ir.StringAttr.get(symbol),
+          type=ir.TypeAttr.get(ir_f_type),
+          arg_names=f_arg_names,
+          free_vars=f_var_names,
+          cell_vars=ir.ArrayAttr.get([]),
+          sym_visibility=ir.StringAttr.get(visibility) if visibility else None)
+      entry_block = f_op.add_entry_block()
+
+    fctx = FunctionContext(self.ic,
+                           self.hooks,
+                           f_op,
+                           filename=filename,
+                           host_closure_vars=inspect.getclosurevars(f))
+    body_importer = FunctionDefBodyImporter(fctx)
+    with ic.scoped_ip(ir.InsertionPoint(entry_block)):
+      body_importer.import_body(fd_node)
+    return f_op
+
+
+class FunctionContext(ImportStage):
+  """Represents a function import in progress.
+
+  Note that construction of the outer FuncOp is performed externally. This
+  allows for multiple modes of operation:
+    - Bootstrapping a func from a live Python callable (via inspection)
+    - Parsing a function declaration purely from AST
+  """
+  __slots__ = [
+      "f_op",
+      "filename",
+      "arg_names",
+      "free_vars",
+      "cell_vars",
+      "host_closure_vars",
+  ]
+
+  def __init__(self,
+               ic: ImportContext,
+               hooks: ImportHooks,
+               f_op: d.FuncOp,
+               *,
+               host_closure_vars: Optional[inspect.ClosureVars] = None,
+               filename: str = "<anonymous>"):
+    super().__init__(ic, hooks)
+    self.f_op = f_op
+    self.host_closure_vars = host_closure_vars
+    self.filename = filename
+
+    # Keep sets of free and cell var names so that we know what kinds of
+    # loads to issue.
+    self.arg_names = set(
+        [ir.StringAttr(attr).value for attr in self.f_op.arg_names])
+    self.free_vars = set(
+        [ir.StringAttr(attr).value for attr in self.f_op.free_vars])
+    self.cell_vars = set(
+        [ir.StringAttr(attr).value for attr in self.f_op.cell_vars])
+
+  def update_loc(self, ast_node):
+    self.ic.set_file_line_col(self.filename, ast_node.lineno,
+                              ast_node.col_offset)
+
+  def cast_to_return_type(self, value: ir.Value) -> ir.Value:
+    """Casts an arbitrary value to the declared function return type."""
+    ic = self.ic
+    input_type = value.type
+    return_type = self.f_op.py_return_type
+    if input_type == return_type:
+      return value
+    if d.ObjectType.isinstance(return_type):
+      # Function returns a boxed value.
+      if d.ObjectType.isinstance(input_type):
+        # Already an object type but annotated differently. Something has
+        # gone wrong.
+        ic.abort(f"function declared return type {return_type} "
+                 f"is incompatible with actual return type {input_type}")
+      return ic.box(value)
+    else:
+      # Function returns a primitive value.
+      return ic.unbox(return_type, value)
+
+
+class BaseNodeVisitor(ast.NodeVisitor):
+  """Base class of a node visitor that aborts on unhandled nodes."""
+  IMPORTER_TYPE = "<unknown>"
+  __slots__ = [
+      "fctx",
+  ]
+
+  def __init__(self, fctx: FunctionContext):
+    super().__init__()
+    self.fctx = fctx
+
+  def visit(self, node):
+    self.fctx.update_loc(node)
+    return super().visit(node)
+
+  def generic_visit(self, ast_node: ast.AST):
+    logging.debug("UNHANDLED NODE: %s", ast.dump(ast_node))
+    self.fctx.ic.abort(f"unhandled python {self.IMPORTER_TYPE} "
+                       f"AST node {ast_node.__class__.__name__}: {ast_node}")
+
+
+class FunctionDefBodyImporter(BaseNodeVisitor):
+  """AST visitor for importing a function's statements.
+  Handles nodes that are direct children of a FunctionDef.
+  """
+  IMPORTER_TYPE = "statement"
+  __slots__ = [
+      "successor_block",
+      "_last_was_return",
+  ]
+
+  def __init__(self,
+               fctx: FunctionContext,
+               *,
+               successor_block: Optional[ir.Block] = None):
+    super().__init__(fctx)
+    self.successor_block = successor_block
+    self._last_was_return = False
+
+  def import_body(self, ast_fd: ast.FunctionDef):
+    ic = self.fctx.ic
+    # Function prologue: Initialize arguments.
+    for arg_index, arg_name in enumerate(
+        [ir.StringAttr(attr).value for attr in self.fctx.f_op.arg_names]):
+      self.initialize_argument(arg_index, arg_name)
+    # Import statements.
+    self.import_block(ast_fd.body)
+
+  def import_block(self, stmts: Sequence[ast.AST]):
+    ic = self.fctx.ic
+    for ast_stmt in stmts:
+      self._last_was_return = False
+      logging.debug("STMT: %s", ast.dump(ast_stmt, include_attributes=True))
+      self.visit(ast_stmt)
+    if not self._last_was_return:
+      with ic.ip, ic.loc:
+        # Add a default terminator.
+        if self.successor_block:
+          # Branch to the successor.
+          std_d.BranchOp([], dest=self.successor_block)
+        else:
+          # Return from function.
+          none_value = d.NoneOp(d.NoneType.get()).result
+          d.ReturnOp(none_value)
+
+  def initialize_argument(self, index, name):
+    fctx = self.fctx
+    ic = fctx.ic
+    entry_block = fctx.f_op.entry_block
+    arg_value = entry_block.arguments[index]
+    arg_value = ic.box(arg_value)
+    with ic.loc, ic.ip:
+      d.StoreFreeVarOp(ir.StringAttr.get(name), arg_value)
+
+  def visit_Pass(self, ast_node):
+    pass
+
+  def visit_Return(self, ast_node):
+    ic = self.fctx.ic
+    with ic.loc, ic.ip:
+      expr = ExpressionImporter(self.fctx)
+      expr.visit(ast_node.value)
+      d.ReturnOp(self.fctx.cast_to_return_type(expr.get_immediate()))
+      self._last_was_return = True
+
+  def visit_Assign(self, node: ast.Assign):
+    fctx = self.fctx
+    ic = fctx.ic
+    expr = ExpressionImporter(fctx)
+    expr.visit(node.value)
+    for target in node.targets:
+      fctx.update_loc(target)
+      target_ctx = target.ctx  # pytype: disable=attribute-error
+      if not isinstance(target_ctx, ast.Store):
+        # TODO: Del, AugStore, etc
+        ic.abort(
+            f"unsupported assignment context type {target_ctx.__class__.__name__}"
+        )
+
+      # TODO: Support assignment to non-free slots.
+      boxed = ic.box(expr.get_immediate())
+      with ic.loc, ic.ip:
+        target_id = target.id  # pytype: disable=attribute-error
+        d.StoreFreeVarOp(ir.StringAttr.get(target_id), boxed)
+
+  def visit_Expr(self, node: ast.Expr):
+    fctx = self.fctx
+    ic = fctx.ic
+
+    expr = ExpressionImporter(fctx)
+    expr.visit(node.value)
+    with ic.loc, ic.ip:
+      d.ExprStatementDiscardOp(expr.get_immediate())
+
+  def visit_If(self, node: ast.If):
+    fctx = self.fctx
+    ic = fctx.ic
+    # Emit the test.
+    test_expr = ExpressionImporter(fctx)
+    test_expr.visit(node.test)
+
+    # We create a successor block that a non terminating block will branch to.
+    predecessor_block = ic.ip.block
+    successor_block = predecessor_block.create_after()
+
+    with ic.ip, ic.loc:
+      test_bool = d.AsBoolOp(d.BoolType.get(), test_expr.get_immediate()).result
+      test_pred = d.BoolToPredOp(ir.IntegerType.get_signless(1),
+                                 test_bool).result
+
+    # Emit the false block
+    if not node.orelse:
+      # Else just jumps to the successor.
+      false_block = successor_block
+    else:
+      # Emit the false body.
+      false_block = predecessor_block.create_after()
+      with ic.scoped_ip(ir.InsertionPoint(false_block)):
+        else_importer = FunctionDefBodyImporter(fctx,
+                                                successor_block=successor_block)
+        else_importer.import_block(node.orelse)
+    # Emit the true body.
+    true_block = predecessor_block.create_after()
+    with ic.scoped_ip(ir.InsertionPoint(true_block)):
+      body_importer = FunctionDefBodyImporter(fctx,
+                                              successor_block=successor_block)
+      body_importer.import_block(node.body)
+
+    # Now that we have true/false blocks, emit the cond_br in the original
+    # block.
+    fctx.update_loc(node)
+    with ic.ip, ic.loc:
+      std_d.CondBranchOp(condition=test_pred,
+                         trueDestOperands=[],
+                         falseDestOperands=[],
+                         trueDest=true_block,
+                         falseDest=false_block)
+
+    # And emission continues here.
+    ic.reset_ip(ir.InsertionPoint(successor_block))
+
+
+ExpressionResult = Union[Intrinsic, ir.Value]
+
+
+class ExpressionImporter(BaseNodeVisitor):
+  """Imports expression nodes.
+  Visitor methods must either raise an exception or call _set_result.
+  """
+  IMPORTER_TYPE = "expression"
+  __slots__ = [
+      "_result",
+  ]
+
+  def __init__(self, fctx: FunctionContext):
+    super().__init__(fctx)
+    self._result: Optional[ExpressionResult] = None
+
+  def visit(self, node):
+    super().visit(node)
+    assert self._result is not None, (
+        f"ExpressionImporter did not assign a value ({ast.dump(node)})")
+
+  def get_immediate(self) -> ir.Value:
+    """Gets the expression result by emitting it as an immediate value."""
+    if isinstance(self._result, ir.Value):
+      return self._result
+    else:
+      # Intrinsic.
+      return self._result.emit_immediate(self.fctx.ic)
+
+  def get_call_result(self, args: Sequence[ir.Value]) -> ir.Value:
+    """Perfoms a call against the expression result, returning the value."""
+    if isinstance(self._result, ir.Value):
+      return self.fctx.ic.abort(
+          f"TODO: User defined function call not supported")
+    else:
+      # Intrinsic.
+      return self._result.emit_call(self.fctx, args=args, keywords=[])
+
+  def get_static_attribute(self, attr_name: str) -> ExpressionResult:
+    fctx = self.fctx
+    ic = fctx.ic
+    if isinstance(self._result, ir.Value):
+      # Immediate.
+      ic.abort(f"TODO: Runtime attribute resolution NYI")
+    else:
+      # Intrinsic.
+      resolved = self._result.resolve_static_getattr(ic, attr_name)
+      if resolved is Intrinsic.UNBOUND_VALUE:
+        ic.abort(f"attribute {attr_name} not found for compile time intrinsic "
+                 f"{self._result}")
+      return resolved
+
+  def _set_result(self, result: ExpressionResult):
+    assert not self._result
+    assert isinstance(result, (Intrinsic, ir.Value)), (
+        f"Not an ExpressionResult: is a {result.__class__} ({result})")
+    self._result = result
+
+  def visit_Name(self, node: ast.Name):
+    fctx = self.fctx
+    ic = fctx.ic
+    if not isinstance(node.ctx, ast.Load):
+      # Note that the other context types (Store, Del, Star) cannot appear
+      # in expressions.
+      fctx.abort(f"Unsupported expression name context type %s")
+
+    # Handle free variables (also includes args).
+    with ic.loc:
+      if node.id in self.fctx.free_vars:
+        self._set_result(
+            d.LoadFreeVarOp(d.ObjectType.get(),
+                            ir.StringAttr.get(node.id),
+                            ip=ic.ip).result)
+        return
+
+    # Fall-back to global resolution.
+    resolved = fctx.hooks.resolve_global(
+        fctx, node.id, host_closure_vars=fctx.host_closure_vars)
+    if resolved == Intrinsic.UNBOUND_VALUE:
+      ic.abort(f"could not resolve global {node.id}")
+    self._set_result(resolved)
+
+  def visit_Attribute(self, node: ast.Attribute):
+    sub_eval = ExpressionImporter(self.fctx)
+    sub_eval.visit(node.value)
+    self._set_result(sub_eval.get_static_attribute(node.attr))
+
+  def visit_BinOp(self, node: ast.BinOp):
+    fctx = self.fctx
+    ic = fctx.ic
+    op = node.op
+    for op_type, dunder_name in _AST_BINOP_TYPE_TO_DUNDER:
+      if isinstance(op, op_type):
+        break
+    else:
+      ic.abort(f"unsupported binary operation {op}")
+
+    left = ExpressionImporter(fctx)
+    left.visit(node.left)
+    right = ExpressionImporter(fctx)
+    right.visit(node.right)
+    fctx.update_loc(node)
+
+    with ic.loc, ic.ip:
+      object_type = d.ObjectType.get()
+      # TODO: There are some exceptions to blanket binary promotion:
+      #   - Truediv has its own promotion rules
+      #   - Shl, Shr are different
+      left_prime, right_prime = d.DynamicBinaryPromoteOp(
+          object_type, object_type, left.get_immediate(),
+          right.get_immediate()).results
+      result = d.ApplyBinaryOp(object_type, ir.StringAttr.get(dunder_name),
+                               left_prime, right_prime).result
+      self._set_result(result)
+
+  def visit_BoolOp(self, node):
+    fctx = self.fctx
+    ic = fctx.ic
+    if isinstance(node.op, ast.And):
+      return_first_true = False
+    elif isinstance(node.op, ast.Or):
+      return_first_true = True
+    else:
+      ic.abort(f"unknown bool op {ast.dump(node.op)}")
+
+    def emit_next(next_nodes):
+      next_node = next_nodes[0]
+      next_nodes = next_nodes[1:]
+
+      # Evaluate sub-expression.
+      sub_expression = ExpressionImporter(fctx)
+      sub_expression.visit(next_node)
+      next_value = sub_expression.get_immediate()
+      if not next_nodes:
+        return next_value
+
+      bool_value = d.AsBoolOp(d.BoolType.get(), next_value, ip=ic.ip).result
+      condition_value = d.BoolToPredOp(ir.IntegerType.get_signless(1),
+                                       bool_value,
+                                       ip=ic.ip).result
+      # TODO: See if we can re-organize this to not force boxing through the
+      # if.
+      if_op, then_ip, else_ip = ic.scf_IfOp([d.ObjectType.get()],
+                                            condition_value, True)
+      # Short-circuit return case.
+      with ic.scoped_ip(then_ip if return_first_true else else_ip):
+        next_value_casted = ic.box(next_value)
+        ic.scf_YieldOp([next_value_casted])
+
+      # Nested evaluate next case.
+      with ic.scoped_ip(else_ip if return_first_true else then_ip):
+        nested_value = emit_next(next_nodes)
+        nested_value_casted = next_value_casted = ic.box(nested_value)
+        ic.scf_YieldOp([nested_value_casted])
+
+      return if_op.result
+
+    with ic.loc:
+      self._set_result(emit_next(node.values))
+
+  def visit_Compare(self, node: ast.Compare):
+    # Short-circuit comparison (degenerates to binary comparison when just
+    # two operands).
+    fctx = self.fctx
+    ic = fctx.ic
+    false_value = ic.emit_constant(False)
+
+    def emit_next(left_value, comparisons):
+      operation, right_node = comparisons[0]
+      comparisons = comparisons[1:]
+
+      # Determine operation type.
+      for (ast_type, op_name, reflective_op_name,
+           needs_promotion) in _AST_COMPAREOP_TYPE_TO_INFO:
+        if isinstance(operation, ast_type):
+          break
+      else:
+        ic.abort(f"unsupported comparison op: {operation}")
+
+      # Lazy evaluate the right.
+      right_expr = ExpressionImporter(fctx)
+      right_expr.visit(right_node)
+      right_value = right_expr.get_immediate()
+      with ic.ip, ic.loc:
+        object_type = d.ObjectType.get()
+        # Promote if needed.
+        if needs_promotion:
+          left_prime, right_prime = d.DynamicBinaryPromoteOp(
+              object_type, object_type, left_value, right_value).results
+        else:
+          left_prime = left_value
+          right_prime = right_expr.get_immediate()
+
+        # Apply comparison.
+        compare_result = d.ApplyCompareOp(d.BoolType.get(),
+                                          ir.StringAttr.get(op_name),
+                                          left_prime, right_prime).result
+      # Terminate by yielding the final compare result.
+      if not comparisons:
+        return compare_result
+
+      # Emit if for short circuit eval.
+      # Since this is an 'and', all else clauses yield a false value.
+      with ic.ip, ic.loc:
+        compare_result_i1 = d.BoolToPredOp(ir.IntegerType.get_signless(1),
+                                           compare_result).result
+        if_op, then_ip, else_ip = ic.scf_IfOp([d.BoolType.get()],
+                                              compare_result_i1, True)
+      # Build the else clause.
+      with ic.scoped_ip(else_ip):
+        ic.scf_YieldOp([false_value])
+
+      # Build the then clause.
+      with ic.scoped_ip(then_ip):
+        nested_result = emit_next(right_value, comparisons)
+        ic.scf_YieldOp([nested_result])
+
+      return if_op.result
+
+    # Compute left and recurse for lazy evaluation.
+    left_expr = ExpressionImporter(fctx)
+    left_expr.visit(node.left)
+    self._set_result(
+        emit_next(left_expr.get_immediate(),
+                  list(zip(node.ops, node.comparators))))
+
+  def visit_Call(self, node: ast.Call):
+    fctx = self.fctx
+    ic = fctx.ic
+    func_expr = ExpressionImporter(fctx)
+    func_expr.visit(node.func)
+
+    args = []
+    for ast_arg in node.args:
+      arg_expr = ExpressionImporter(fctx)
+      arg_expr.visit(ast_arg)
+      args.append(arg_expr.get_immediate())
+
+    if node.keywords:
+      ic.abort(f"TODO: keyword calls are not yet supported")
+    fctx.update_loc(node)
+    self._set_result(func_expr.get_call_result(args=args))
+
+  def visit_UnaryOp(self, node: ast.UnaryOp):
+    fctx = self.fctx
+    ic = fctx.ic
+    with ic.ip, ic.loc:
+      op = node.op
+
+      # Evaluate sub-expression.
+      sub_expression = ExpressionImporter(fctx)
+      sub_expression.visit(node.operand)
+      fctx.update_loc(node)
+      operand_value = sub_expression.get_immediate()
+
+      if isinstance(op, ast.Not):
+        # Special handling for logical-not.
+        bool_value = d.AsBoolOp(d.BoolType.get(), operand_value).result
+        true_value = ic.emit_constant(True)
+        false_value = ic.emit_constant(False)
+        self._set_result(
+            d.SelectOp(d.BoolType.get(), bool_value, false_value,
+                       true_value).result)
+      else:
+        ic.abort(f"Unknown unary op {ast.dump(op)}")
+
+  def visit_IfExp(self, node: ast.IfExp):
+    fctx = self.fctx
+    ic = fctx.ic
+
+    # Evaluate test sub-expression.
+    sub_expression = ExpressionImporter(fctx)
+    sub_expression.visit(node.test)
+    fctx.update_loc(node)
+    test_value = sub_expression.get_immediate()
+
+    # Interpret as bool.
+    test_bool = d.AsBoolOp(d.BoolType.get(), test_value, ip=ic.ip,
+                           loc=ic.loc).result
+    test_pred = d.BoolToPredOp(ir.IntegerType.get_signless(1),
+                               test_bool,
+                               ip=ic.ip,
+                               loc=ic.loc).result
+
+    # TODO: There is a hazard here if then and else refine to different
+    # boxed types. Needs a derefine cast. Also we are boxing to type erased
+    # types to satisfy scf.if verifier. Do something better.
+    if_op, then_ip, else_ip = ic.scf_IfOp([d.ObjectType.get(ic.context)],
+                                          test_pred, True)
+    # Build the then clause
+    with ic.scoped_ip(then_ip):
+      # Evaluate the true clause within the if body.
+      sub_expression = ExpressionImporter(fctx)
+      sub_expression.visit(node.body)
+      then_result = sub_expression.get_immediate()
+      ic.scf_YieldOp([ic.box(then_result, to_typed=False)])
+
+    # Build the then clause.
+    with ic.scoped_ip(else_ip):
+      sub_expression = ExpressionImporter(fctx)
+      sub_expression.visit(node.orelse)
+      orelse_result = sub_expression.get_immediate()
+      ic.scf_YieldOp([ic.box(orelse_result, to_typed=False)])
+
+    self._set_result(if_op.result)
+
+  if sys.version_info < (3, 8, 0):
+    # <3.8 breaks these out into separate AST classes.
+    def visit_Num(self, ast_node):
+      self._set_result(self.fctx.ic.emit_constant(ast_node.n))
+
+    def visit_Str(self, ast_node):
+      self._set_result(self.fctx.ic.emit_constant(ast_node.s))
+
+    def visit_Bytes(self, ast_node):
+      self._set_result(self.fctx.ic.emit_constant(ast_node.s))
+
+    def visit_NameConstant(self, ast_node):
+      self._set_result(self.fctx.ic.emit_constant(ast_node.value))
+
+    def visit_Ellipsis(self, ast_node):
+      self._set_result(self.fctx.ic.emit_constant(...))
+  else:
+    # >= 3.8
+    def visit_Constant(self, ast_node):
+      self._set_result(self.fctx.ic.emit_constant(ast_node.value))
+
+
+def _get_function_ast(f) -> Tuple[str, ast.AST]:
+  filename = inspect.getsourcefile(f)
+  source_lines, start_lineno = inspect.getsourcelines(f)
+  source = "".join(source_lines)
+  source = textwrap.dedent(source)
+  ast_root = ast.parse(source, filename=filename)
+  ast.increment_lineno(ast_root, start_lineno - 1)
+  return filename, ast_root
+
+
+# Maps an AST type (from BinOp.op) to the dunder name in the Python data
+# model.
+_AST_BINOP_TYPE_TO_DUNDER = (
+    (ast.Add, "add"),
+    (ast.Sub, "sub"),
+    (ast.Mult, "mul"),
+    (ast.MatMult, "matmul"),
+    (ast.Div, "truediv"),
+    (ast.FloorDiv, "floordiv"),
+    (ast.Mod, "mod"),
+    (ast.Pow, "pow"),
+    (ast.LShift, "lshift"),
+    (ast.RShift, "rshift"),
+    (ast.BitAnd, "and"),
+    (ast.BitOr, "or"),
+    (ast.BitXor, "xor"),
+)
+
+# Maps AST Compare op type. Fields are:
+#   [0] = ast type
+#   [1] = op name (root of the dunder name for rich compare ops)
+#   [2] = reflective op name
+#   [3] = whether numeric promotion should take place
+_AST_COMPAREOP_TYPE_TO_INFO = (
+    (ast.Lt, "lt", "gte", True),
+    (ast.LtE, "le", "gt", True),
+    (ast.Eq, "eq", "eq", True),
+    (ast.NotEq, "ne", "ne", True),
+    (ast.Gt, "gt", "le", True),
+    (ast.GtE, "ge", "lt", True),
+    (ast.Is, "is", "is", False),
+    (ast.IsNot, "isnot", "isnot", False),
+    (ast.In, "in", "in", False),
+    (ast.NotIn, "notin", "notin", False),
+)
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/intrinsic_def.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/intrinsic_def.py
new file mode 100644
index 0000000..ee1a702
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/intrinsic_def.py
@@ -0,0 +1,147 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Any, Optional, Sequence
+
+import functools
+
+from .importer import Importer
+from .util import DefaultImportHooks, ImportStage, Intrinsic, FuncProvidingIntrinsic
+
+from ... import iree_pydm as d
+from .... import ir
+
+
+def def_pyfunc_intrinsic(f=None, *, symbol: Optional[str] = None):
+  """Defines an intrinsic function that will be included in the module."""
+  if f is None:
+    return functools.partial(def_pyfunc_intrinsic, symbol=symbol)
+
+  if symbol is None:
+    symbol = f.__name__
+
+  class PyIntrinsicFunc(FuncProvidingIntrinsic):
+    """The intrinsic which will compile the func and emit calls to it."""
+
+    def get_provided_func_symbol(self, stage: ImportStage) -> str:
+      ic = stage.ic
+      symbol_attr = ir.FlatSymbolRefAttr.get(symbol, context=ic.context)
+      existing = ic.lookup_symbol(symbol_attr)
+      if not existing:
+        _import_global_function(stage, f, symbol=symbol)
+      return symbol
+
+    def emit_call(self, stage: ImportStage, args: Sequence[ir.Value],
+                  keywords: Sequence[Any]) -> ir.Value:
+      ic = stage.ic
+      if keywords:
+        ic.abort(f"{self} only supports positional arguments")
+      resolved_symbol = self.get_provided_func_symbol(stage)
+      with ic.ip, ic.loc:
+        exc_result, call_result = d.DynamicCallOp(
+            d.ExceptionResultType.get(), d.ObjectType.get(),
+            ir.FlatSymbolRefAttr.get(resolved_symbol), args).results
+        d.RaiseOnFailureOp(exc_result)
+        return call_result
+
+    def __repr__(self):
+      return f"<py intrinsic {symbol}>"
+
+  return PyIntrinsicFunc()
+
+
+def def_ir_macro_intrinsic(f=None):
+  """Defines an IR macro intrinsic.
+
+  The decorated function must take as positional arguments the
+  ImportStage followed by *`ir.Value` instances corresponding with the
+  call and return a single `ir.Value`.
+
+  The function will be evaluated in an MLIR with context including
+  context, location and ip.
+  """
+  if f is None:
+    return functools.partial(def_ir_macro_intrinsic)
+
+  class IrIntrinsicMacro(Intrinsic):
+
+    def emit_call(self, stage: ImportStage, args: Sequence[ir.Value],
+                  keywords: Sequence[Any]) -> ir.Value:
+      ic = stage.ic
+      if keywords:
+        ic.abort(f"{self} only supports positional arguments")
+
+      # TODO: Apply pre-conditions on number of arguments, etc, for nicer
+      # error messages.
+      with ic.loc, ic.ip:
+        result = f(stage, *args)
+        if not isinstance(result, ir.Value):
+          ic.abort(f"compiler intrinsic macro must return an IR Value: {f}")
+        return result
+
+    def __repr__(self):
+      return f"<IR macro {self}>"
+
+  return IrIntrinsicMacro()
+
+
+def def_pattern_call_intrinsic(match_generic: Sequence[Any] = (),
+                               match_specific: Sequence[Any] = ()):
+  """Defines a multi-function call intrinsic."""
+
+  def _extract_symbol_intrinsics(
+      matches: Sequence[Any]) -> Sequence[FuncProvidingIntrinsic]:
+    names = []
+    for m in matches:
+      assert isinstance(m, FuncProvidingIntrinsic), (
+          f"Match functions for a def_multi_func_intrinsic must be "
+          f"a FuncProvidingIntrinsic. Got: {m}")
+      names.append(m)
+    return names
+
+  generic_intrinsics = _extract_symbol_intrinsics(match_generic)
+  specific_intrinsics = _extract_symbol_intrinsics(match_specific)
+
+  class IrPatternCallIntrinsic(Intrinsic):
+
+    def emit_call(self, stage: ImportStage, args: Sequence[ir.Value],
+                  keywords: Sequence[Any]) -> ir.Value:
+      ic = stage.ic
+      if keywords:
+        ic.abort(f"{self} only supports positional arguments")
+
+      generic_symbol_names = [
+          i.get_provided_func_symbol(stage) for i in generic_intrinsics
+      ]
+      specific_symbol_names = [
+          i.get_provided_func_symbol(stage) for i in specific_intrinsics
+      ]
+
+      with ic.ip, ic.loc:
+        generic_attrs = ir.ArrayAttr.get(
+            [ir.FlatSymbolRefAttr.get(s) for s in generic_symbol_names])
+        specific_attrs = ir.ArrayAttr.get(
+            [ir.FlatSymbolRefAttr.get(s) for s in specific_symbol_names])
+        exc_result, call_result = d.PatternMatchCallOp(
+            d.ExceptionResultType.get(), d.ObjectType.get(), generic_attrs,
+            specific_attrs, args).results
+        d.RaiseOnFailureOp(exc_result)
+        return call_result
+
+    def __repr__(self):
+      return (f"<pattern call generic={generic_intrinsics}, "
+              f"specific={specific_intrinsics}>")
+
+  return IrPatternCallIntrinsic()
+
+
+def _import_global_function(parent_stage: ImportStage, f, *,
+                            symbol: str) -> d.FuncOp:
+  """In a fresh import context, import a global function."""
+  # Note that we are bringing out own hooks, since intrinsics are compiled with
+  # defaults.
+  importer = Importer(parent_stage.ic, hooks=DefaultImportHooks())
+  return importer.import_global_function(f, symbol=symbol, visibility="private")
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/private_intrinsics.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/private_intrinsics.py
new file mode 100644
index 0000000..14c0fd4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/private_intrinsics.py
@@ -0,0 +1,7 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .util import Intrinsic
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/test_util.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/test_util.py
new file mode 100644
index 0000000..2be914c
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/test_util.py
@@ -0,0 +1,17 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from . import *
+
+
+def test_import_global(f):
+  """Imports a global function and prints corresponding IR."""
+  print("// -----")
+  ic = ImportContext()
+  imp = Importer(ic)
+  imp.import_global_function(f)
+  print(ic.module.operation)
+  return f
diff --git a/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/util.py b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/util.py
new file mode 100644
index 0000000..5fcdc5f
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/python/mlir/dialects/iree_pydm/importer/util.py
@@ -0,0 +1,390 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from types import ModuleType
+from typing import Any, Mapping, Optional, Sequence, Union
+
+import contextlib
+import inspect
+
+from ... import iree_pydm as d
+from .... import ir
+# TODO: Upstream emit_error and use that instead.
+from ...._mlir_libs._ireeDialects import emit_error as _emit_error, lookup_nearest_symbol_from as _lookup_nearest_symbol_from
+
+
+class EmittedError(Exception):
+  """Exception subclass that indicates an error diagnostic has been emitted.
+  By throwing, this lets us abort and handle at a higher level so as not
+  to duplicate diagnostics.
+  """
+
+  def __init__(self, loc: ir.Location, message: str):
+    super().__init__(loc, message)
+
+  @property
+  def loc(self) -> ir.Location:
+    return self.args[0]
+
+  @property
+  def message(self) -> str:
+    return self.args[1]
+
+
+class UserReportableError(Exception):
+  """Used to raise an error with a message that should be reported to the user.
+  Raising this error indicates that the error message is well formed and
+  makes sense without a traceback.
+  """
+
+  def __init__(self, message: str):
+    super().__init__(message)
+
+  @property
+  def message(self) -> str:
+    return self.args[0]
+
+
+class ImportContext:
+  """Context for importing Python structures into IR."""
+
+  def __init__(self,
+               *,
+               context: Optional[ir.Context] = None,
+               module: Optional[ir.Module] = None):
+    self.context = context if context else create_context()
+    self.loc = ir.Location.unknown(context=self.context)
+    if module:
+      self.module = module
+    else:
+      self.module = ir.Module.create(self.loc)
+    self._ip_stack = []
+
+  def __str__(self):
+    return str(self.module)
+
+  def set_file_line_col(self, file: str, line: int, col: int):
+    self.loc = ir.Location.file(file, line, col, context=self.context)
+
+  @contextlib.contextmanager
+  def scoped_ip(self, scoped_ip: ir.InsertionPoint):
+    self.push_ip(scoped_ip)
+    try:
+      yield scoped_ip
+    finally:
+      self.pop_ip()
+
+  def push_ip(self, scoped_ip: ir.InsertionPoint):
+    self._ip_stack.append(scoped_ip)
+
+  def pop_ip(self):
+    assert self._ip_stack, "Mismatched push_ip/pop_ip: stack is empty on pop"
+    del self._ip_stack[-1]
+
+  @property
+  def ip(self):
+    assert self._ip_stack, "InsertionPoint requested but stack is empty"
+    return self._ip_stack[-1]
+
+  def reset_ip(self, ip: ir.InsertionPoint):
+    """Resets the TOS insertion point.
+
+    This is needed if splitting exection across blocks.
+    """
+    assert self._ip_stack, "InsertionPoint requested but stack is empty"
+    self._ip_stack[-1] = ip
+
+  def abort(self, message: str):
+    """Emits an error diagnostic and raises an exception to abort."""
+    loc = self.loc
+    _emit_error(loc, message)
+    raise EmittedError(loc, message)
+
+  def lookup_symbol(self, symbol_attr):
+    return _lookup_nearest_symbol_from(self.module.operation, symbol_attr)
+
+  def box(self, value: ir.Value, to_typed: Optional[bool] = True) -> ir.Value:
+    """Boxes a value if necessary."""
+    with self.ip, self.loc:
+      t = value.type
+      if d.ObjectType.isinstance(t):
+        # Already boxed.
+        return value
+      boxed_type = d.ObjectType.get_typed(t) if to_typed else d.ObjectType.get()
+      return d.BoxOp(boxed_type, value).result
+
+  def unbox(self, to_type: ir.Type, value: ir.Value) -> ir.Value:
+    with self.ip, self.loc:
+      exc_result, unboxed = d.UnboxOp(d.ExceptionResultType.get(), to_type,
+                                      value).results
+      d.RaiseOnFailureOp(exc_result)
+      return unboxed
+
+  def emit_constant(self, value: Any) -> ir.Value:
+    """Emits a constant for a supported Python value."""
+    # Handle the various primitives directly.
+    with self.loc, self.ip:
+      if value is None:
+        return d.NoneOp(d.NoneType.get()).result
+      elif value is True or value is False:
+        return d.ConstantOp(
+            d.BoolType.get(),
+            ir.IntegerAttr.get(ir.IntegerType.get_signless(1),
+                               1 if value else 0)).result
+      elif isinstance(value, int):
+        return d.ConstantOp(
+            d.IntegerType.get(),
+            ir.IntegerAttr.get(ir.IntegerType.get_signed(64), value)).result
+      elif isinstance(value, float):
+        return d.ConstantOp(d.RealType.get(),
+                            ir.FloatAttr.get(ir.F64Type.get(), value)).result
+      elif isinstance(value, str):
+        return d.ConstantOp(d.StrType.get(), ir.StringAttr.get(value)).result
+      elif isinstance(value, bytes):
+        return d.ConstantOp(d.BytesType.get(), ir.StringAttr.get(value)).result
+    self.abort(
+        f"unsupported Python constant value '{value}' (an {value.__class__}))")
+
+  # TODO: Map the SCF dialect properly upstream.
+  def scf_IfOp(self, results, condition: ir.Value, with_else_region: bool):
+    """Creates an SCF if op.
+    Returns:
+      (if_op, then_ip, else_ip) if with_else_region, otherwise (if_op, then_ip)
+    """
+    op = ir.Operation.create("scf.if",
+                             results=results,
+                             operands=[condition],
+                             regions=2 if with_else_region else 1,
+                             loc=self.loc,
+                             ip=self.ip)
+    then_region = op.regions[0]
+    then_block = then_region.blocks.append()
+    if with_else_region:
+      else_region = op.regions[1]
+      else_block = else_region.blocks.append()
+      return op, ir.InsertionPoint(then_block), ir.InsertionPoint(else_block)
+    else:
+      return op, ir.InsertionPoint(then_block)
+
+  def scf_YieldOp(self, operands):
+    return ir.Operation.create("scf.yield",
+                               operands=operands,
+                               loc=self.loc,
+                               ip=self.ip)
+
+
+class ImportStage:
+  """Base class for activities representing isolated import activity.
+
+  This is used, for example, to isolate activities, targeting the same
+  module but different functions.
+  """
+  __slots__ = [
+      "ic",
+      "hooks",
+  ]
+
+  def __init__(self, ic: ImportContext, hooks: "ImportHooks"):
+    self.ic = ic
+    self.hooks = hooks
+
+
+class _UnboundValue:
+
+  def __repr__(self):
+    return "<UnboundValue>"
+
+
+def _get_module_type():
+  import abc  # Not special - just standard.
+  return type(abc)
+
+
+_ModuleType = _get_module_type()
+
+
+class Intrinsic:
+  """An object that controls its own interaction with the AST and IR.
+
+  Intrinsics are typically returned as a result of evaluating globals in the
+  hosting Python process. They have methods on them for controlling how
+  evaluation and IR emission should proceed. They can also implenent
+  __call__, __getattr__, etc in order to support dual use, either in the host
+  process or the compiled process.
+  """
+  UNBOUND_VALUE = _UnboundValue()
+
+  def resolve_static_getattr(self, stage: ImportStage,
+                             attr_name: str) -> "ResolveOutcome":
+    return Intrinsic.UNBOUND_VALUE
+
+  def emit_call(self, stage: ImportStage, args: Sequence[ir.Value],
+                keywords: Sequence[Any]) -> ir.Value:
+    stage.ic.abort(f"the compiler intrinsic {self} does not support calls")
+
+  def emit_immediate(self, stage: ImportStage) -> ir.Value:
+    """Emits this object as an immediate value.
+
+    On failure, abort with error.
+    """
+    stage.ic.abort(
+        f"the compiler intrinsic {self} can not be serialized as a value")
+
+  @staticmethod
+  def make_singleton(cls) -> "Intrinsic":
+    """Class decorator to instantiate a singleton intrinsic class."""
+    assert issubclass(cls, Intrinsic)
+    return cls()
+
+
+ResolveOutcome = Union[_UnboundValue, Intrinsic, ir.Value]
+
+
+class FuncProvidingIntrinsic(Intrinsic):
+  """An intrinsic which provides an IR function in some way.
+
+  This provides an additional entry point for retrieving the provided
+  function symbol.
+  """
+
+  def get_provided_func_symbol(self, stage: ImportStage) -> str:
+    raise NotImplementedError()
+
+
+class ImportHooks:
+  """Hooks for customizing the import process."""
+
+  def resolve_annotation_to_type(self, stage: ImportStage, annot) -> ir.Type:
+    """Resolves a live, function annotation to a type.
+
+    TODO: This currently has some dependency on whether crossing a
+    ref-providence boundary. May need to untangle.
+    """
+    return d.ObjectType.get(context=stage.ic.context)
+
+  def resolve_global(
+      self,
+      stage: ImportStage,
+      name: str,
+      *,
+      host_closure_vars: Optional[inspect.ClosureVars] = None
+  ) -> ResolveOutcome:
+    """Resolves a global name.
+
+    By default, this returns NO_VALUE, indicating that the global cannot
+    be found.
+
+    Typical implementations will consult the provided 'globals' dict and
+    make a decision on a type of Intrinsic to return, bridging the host
+    runtime namespace to what the compiler should consider. There are many
+    strategies for doing this, each providing unique features and user
+    experiences.
+    """
+    return Intrinsic.UNBOUND_VALUE
+
+
+class DefaultImportHooks(ImportHooks):
+  """Hooks that provide some default behavior.
+
+  This has not been fully thought through yet with respect to layering
+  for real users. This may just become for testing.
+  """
+
+  def resolve_annotation_to_type(self, stage: ImportStage, annot) -> ir.Type:
+    """Resolves a live, function annotation to a type.
+
+    TODO: This currently has some dependency on whether crossing a
+    ref-providence boundary. May need to untangle.
+    """
+    ic = stage.ic
+    # Handle built-in primitive mappings.
+    with ic.context:
+      # Value types.
+      if annot is bool:
+        return d.BoolType.get()
+      if annot is int:
+        return d.IntegerType.get()
+      if annot is float:
+        return d.RealType.get()
+      if annot is None:
+        return d.NoneType.get()
+      if annot is inspect.Signature.empty:
+        # Special value for return annotations to signal no annotation.
+        return d.ObjectType.get()
+
+      # Reference types. We always box these across function boundaries
+      # to preserve providence.
+      # TODO: Better heuristic?
+      # TODO: Support typing annotations, not just raw types.
+      if annot is str:
+        return d.ObjectType.get_typed(d.StrType.get())
+      if annot is list:
+        return d.ObjectType.get_typed(d.ListType.get())
+      if annot is tuple:
+        return d.ObjectType.get_typed(d.TupleType.get())
+      if annot is type:
+        return d.ObjectType.get_typed(d.TypeType.get())
+
+    # Fall-back.
+    return super().resolve_annotation_to_type(stage, annot)
+
+  def resolve_global(
+      self,
+      stage: ImportStage,
+      name: str,
+      *,
+      host_closure_vars: Optional[inspect.ClosureVars] = None) -> Any:
+    ic = stage.ic
+    root = PassthroughModuleIntrinsic(
+        host_closure_vars.globals if host_closure_vars else dict())
+    found = root.resolve_static_getattr(stage, name)
+    if found is not Intrinsic.UNBOUND_VALUE:
+      return found
+    # Resolve against builtins.
+    if (not name in host_closure_vars.builtins):  # pytype: disable=attribute-error
+      return Intrinsic.UNBOUND_VALUE
+
+    from . import builtins_intrinsics
+    if hasattr(builtins_intrinsics, name):
+      return getattr(builtins_intrinsics, name)
+    ic.abort(f"builtin {name} is defined for the host Python but not yet"
+             f"implemented for this compiler")
+
+
+def create_context() -> ir.Context:
+  context = ir.Context()
+  d.register_dialect(context)
+  return context
+
+
+class PassthroughModuleIntrinsic(Intrinsic):
+  """Represents a host Python module, returning intrinsics for sub modules."""
+
+  def __init__(self, m: Union[Mapping[str, Any], ModuleType]):
+    self.m = m
+
+  def resolve_static_getattr(self, stage: ImportStage,
+                             attr_name: str) -> ResolveOutcome:
+    ic = stage.ic
+    m = self.m
+    try:
+      if isinstance(m, dict):
+        child = m[attr_name]
+      else:
+        child = getattr(m, attr_name)
+    except KeyError:
+      return Intrinsic.UNBOUND_VALUE
+    except AttributeError:
+      return Intrinsic.UNBOUND_VALUE
+
+    # We only return values that are modules or intrinsics.
+    if isinstance(child, _ModuleType):
+      return PassthroughModuleIntrinsic(child)
+    elif isinstance(child, Intrinsic):
+      return child
+    else:
+      ic.abort(f"when resolving '{attr_name}' against module {m}, "
+               f"encountered an unsupported type ({child.__class__})")
diff --git a/llvm-external-projects/iree-dialects/test/CMakeLists.txt b/llvm-external-projects/iree-dialects/test/CMakeLists.txt
index 45f1daa..cdff103 100644
--- a/llvm-external-projects/iree-dialects/test/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/test/CMakeLists.txt
@@ -11,6 +11,7 @@
 
 set(IREE_DIALECTS_TEST_DEPENDS
         FileCheck count not
+        iree-dialects-opt
         )
 
 if(MLIR_ENABLE_BINDINGS_PYTHON)
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/assignment.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/assignment.py
new file mode 100644
index 0000000..8b32f08
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/assignment.py
@@ -0,0 +1,14 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL: @assign_free_var_not_arg
+# CHECK: %[[CST:.*]] = iree_pydm.constant 1
+# CHECK: %[[BOXED:.*]] = iree_pydm.box %[[CST]] : !iree_pydm.integer -> !iree_pydm.object<!iree_pydm.integer>
+# CHECK: iree_pydm.store_free_var "x", %[[BOXED]]
+@test_import_global
+def assign_free_var_not_arg():
+  x = 1
+  return x
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/binary.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/binary.py
new file mode 100644
index 0000000..888536b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/binary.py
@@ -0,0 +1,98 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL: @binary_add
+# CHECK: %[[L:.*]] = iree_pydm.load_free_var "a"
+# CHECK: %[[R:.*]] = iree_pydm.load_free_var "b"
+# CHECK: %[[LP:.*]], %[[RP:.*]] = iree_pydm.dynamic_binary_promote %[[L]], %[[R]]
+# CHECK: iree_pydm.apply_binary "add", %[[LP]], %[[RP]]
+@test_import_global
+def binary_add(a, b):
+  return a + b
+
+
+# CHECK-LABEL: @binary_sub
+# CHECK: iree_pydm.apply_binary "sub"
+@test_import_global
+def binary_sub(a, b):
+  return a - b
+
+
+# CHECK-LABEL: @binary_mul
+# CHECK: iree_pydm.apply_binary "mul"
+@test_import_global
+def binary_mul(a, b):
+  return a * b
+
+
+# CHECK-LABEL: @binary_matmul
+# CHECK: iree_pydm.apply_binary "matmul"
+@test_import_global
+def binary_matmul(a, b):
+  return a @ b
+
+
+# CHECK-LABEL: @binary_truediv
+# CHECK: iree_pydm.apply_binary "truediv"
+@test_import_global
+def binary_truediv(a, b):
+  return a / b
+
+
+# CHECK-LABEL: @binary_floordiv
+# CHECK: iree_pydm.apply_binary "floordiv"
+@test_import_global
+def binary_floordiv(a, b):
+  return a // b
+
+
+# CHECK-LABEL: @binary_mod
+# CHECK: iree_pydm.apply_binary "mod"
+@test_import_global
+def binary_mod(a, b):
+  return a % b
+
+
+# CHECK-LABEL: @binary_pow
+# CHECK: iree_pydm.apply_binary "pow"
+@test_import_global
+def binary_pow(a, b):
+  return a**b
+
+
+# CHECK-LABEL: @binary_lshift
+# CHECK: iree_pydm.apply_binary "lshift"
+@test_import_global
+def binary_lshift(a, b):
+  return a << b
+
+
+# CHECK-LABEL: @binary_rshift
+# CHECK: iree_pydm.apply_binary "rshift"
+@test_import_global
+def binary_rshift(a, b):
+  return a >> b
+
+
+# CHECK-LABEL: @binary_and
+# CHECK: iree_pydm.apply_binary "and"
+@test_import_global
+def binary_and(a, b):
+  return a & b
+
+
+# CHECK-LABEL: @binary_or
+# CHECK: iree_pydm.apply_binary "or"
+@test_import_global
+def binary_or(a, b):
+  return a | b
+
+
+# CHECK-LABEL: @binary_xor
+# CHECK: iree_pydm.apply_binary "xor"
+@test_import_global
+def binary_xor(a, b):
+  return a ^ b
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/booleans.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/booleans.py
new file mode 100644
index 0000000..63461b4
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/booleans.py
@@ -0,0 +1,87 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL: @logical_and
+# CHECK: %[[XVAL:.*]] = iree_pydm.load_free_var "x"
+# CHECK: %[[XBOOL:.*]] = iree_pydm.as_bool %[[XVAL]]
+# CHECK: %[[XPRED:.*]] = iree_pydm.bool_to_pred %[[XBOOL]]
+# CHECK: %[[R1:.*]] = scf.if %[[XPRED]] {{.*}} {
+# CHECK:   %[[YVAL:.*]] = iree_pydm.load_free_var "y"
+# CHECK:   %[[YBOOL:.*]] = iree_pydm.as_bool %[[YVAL]]
+# CHECK:   %[[YPRED:.*]] = iree_pydm.bool_to_pred %[[YBOOL]]
+# CHECK:   %[[R2:.*]] = scf.if %[[YPRED]] {{.*}} {
+# CHECK:     %[[ZVAL:.*]] = iree_pydm.load_free_var "z"
+# CHECK:     scf.yield %[[ZVAL]]
+# CHECK:   } else {
+# CHECK:     scf.yield %[[YVAL]]
+# CHECK:   }
+# CHECK:   scf.yield %[[R2]]
+# CHECK: } else {
+# CHECK:   scf.yield %[[XVAL]]
+# CHECK: }
+@test_import_global
+def logical_and():
+  x = 1
+  y = 0
+  z = 2
+  return x and y and z
+
+
+# # CHECK-LABEL: @logical_or
+# CHECK: %[[XVAL:.*]] = iree_pydm.load_free_var "x"
+# CHECK: %[[XBOOL:.*]] = iree_pydm.as_bool %[[XVAL]]
+# CHECK: %[[XPRED:.*]] = iree_pydm.bool_to_pred %[[XBOOL]]
+# CHECK: %[[R1:.*]] = scf.if %[[XPRED]] {{.*}} {
+# CHECK:   scf.yield %[[XVAL]]
+# CHECK: } else {
+# CHECK:   %[[YVAL:.*]] = iree_pydm.load_free_var "y"
+# CHECK:   %[[YBOOL:.*]] = iree_pydm.as_bool %[[YVAL]]
+# CHECK:   %[[YPRED:.*]] = iree_pydm.bool_to_pred %[[YBOOL]]
+# CHECK:   %[[R2:.*]] = scf.if %[[YPRED]] {{.*}} {
+# CHECK:     scf.yield %[[YVAL]]
+# CHECK:   } else {
+# CHECK:     %[[ZVAL:.*]] = iree_pydm.load_free_var "z"
+# CHECK:     scf.yield %[[ZVAL]]
+# CHECK:   }
+# CHECK:   scf.yield %[[R2]]
+# CHECK: }
+@test_import_global
+def logical_or():
+  x = 0
+  y = 1
+  z = 2
+  return x or y or z
+
+
+# CHECK-LABEL: func @logical_not
+# CHECK: %[[XVAL:.*]] = iree_pydm.load_free_var "x"
+# CHECK: %[[XBOOL:.*]] = iree_pydm.as_bool %[[XVAL]]
+# CHECK: %[[T:.*]] = iree_pydm.constant true
+# CHECK: %[[F:.*]] = iree_pydm.constant false
+# CHECK: %[[R:.*]] = iree_pydm.select %[[XBOOL]], %[[F]], %[[T]]
+@test_import_global
+def logical_not():
+  x = 1
+  return not x
+
+
+# CHECK-LABEL: func @conditional
+# CHECK: %[[XVAL:.*]] = iree_pydm.load_free_var "x"
+# CHECK: %[[XBOOL:.*]] = iree_pydm.as_bool %[[XVAL]]
+# CHECK: %[[XPRED:.*]] = iree_pydm.bool_to_pred %[[XBOOL]]
+# CHECK: %[[R1:.*]] = scf.if %[[XPRED]] {{.*}} {
+# CHECK:   %[[TWOVAL:.*]] = iree_pydm.constant 2
+# CHECK:   %[[TWOBOXED:.*]] = iree_pydm.box %[[TWOVAL]] : !iree_pydm.integer -> !iree_pydm.object
+# CHECK:   scf.yield %[[TWOBOXED]]
+# CHECK: } else {
+# CHECK:   %[[THREEVAL:.*]] = iree_pydm.constant 3
+# CHECK:   %[[THREEBOXED:.*]] = iree_pydm.box %[[THREEVAL]] : !iree_pydm.integer -> !iree_pydm.object
+# CHECK:   scf.yield %[[THREEBOXED]]
+# CHECK: }
+@test_import_global
+def conditional():
+  x = 1
+  return 2 if x else 3
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/comparison.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/comparison.py
new file mode 100644
index 0000000..213c9f5
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/comparison.py
@@ -0,0 +1,157 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+# pytype: disable=invalid-directive
+# pytype: disable=unsupported-operands
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL: func @binary_lt_
+# CHECK-DAG: %[[L:.*]] = iree_pydm.load_free_var "x"
+# CHECK-DAG: %[[R:.*]] = iree_pydm.load_free_var "y"
+# CHECK: %[[LP:.*]], %[[RP:.*]] = iree_pydm.dynamic_binary_promote %[[L]], %[[R]]
+# CHECK: iree_pydm.apply_compare "lt", %[[LP]], %[[RP]]
+@test_import_global
+def binary_lt_():
+  x = 1
+  y = 2
+  return x < y
+
+
+# CHECK-LABEL: func @binary_gt_
+# CHECK: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "gt"
+@test_import_global
+def binary_gt_():
+  x = 1
+  y = 2
+  return x > y
+
+
+# CHECK-LABEL: func @binary_lte_
+# CHECK: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "le"
+@test_import_global
+def binary_lte_():
+  x = 1
+  y = 2
+  return x <= y
+
+
+# CHECK-LABEL: func @binary_gte_
+# CHECK: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "ge"
+@test_import_global
+def binary_gte_():
+  x = 1
+  y = 2
+  return x >= y
+
+
+# CHECK-LABEL: func @binary_eq_
+# CHECK: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "eq"
+@test_import_global
+def binary_eq_():
+  x = 1
+  y = 2
+  return x == y
+
+
+# CHECK-LABEL: func @binary_neq_
+# CHECK: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "ne"
+@test_import_global
+def binary_neq_():
+  x = 1
+  y = 2
+  return x != y
+
+
+# CHECK-LABEL: func @binary_is_
+# CHECK-NOT: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "is"
+@test_import_global
+def binary_is_():
+  x = 1
+  y = 2
+  return x is y
+
+
+# CHECK-LABEL: func @binary_is_not_
+# CHECK-NOT: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "isnot"
+@test_import_global
+def binary_is_not_():
+  x = 1
+  y = 2
+  return x is not y
+
+
+# CHECK-LABEL: func @binary_in_
+# CHECK-NOT: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "in"
+@test_import_global
+def binary_in_():
+  x = 1
+  y = 2
+  return x in y
+
+
+# CHECK-LABEL: func @binary_not_in_
+# CHECK-NOT: iree_pydm.dynamic_binary_promote
+# CHECK: iree_pydm.apply_compare "notin"
+@test_import_global
+def binary_not_in_():
+  x = 1
+  y = 2
+  return x not in y
+
+
+# CHECK-LABEL: @short_circuit
+# CHECK-DAG: %[[FALSE:.*]] = iree_pydm.constant false
+# CHECK-DAG: %[[X:.*]] = iree_pydm.load_free_var "x"
+# CHECK-DAG: %[[Y:.*]] = iree_pydm.load_free_var "y"
+# CHECK: %[[XP:.*]], %[[YP:.*]] = iree_pydm.dynamic_binary_promote %[[X]], %[[Y]]
+# CHECK: %[[R1:.*]] = iree_pydm.apply_compare "lt", %[[XP]], %[[YP]]
+# CHECK: %[[RP1:.*]] = iree_pydm.bool_to_pred %[[R1]]
+# CHECK: %[[RESULT:.*]] = scf.if %[[RP1]] {{.*}} {
+# CHECK:   %[[Z:.*]] = iree_pydm.load_free_var "z"
+# NOTE: Promotion happens on original loaded values, not already promoted
+# values.
+# CHECK:   %[[YP1:.*]], %[[ZP1:.*]] = iree_pydm.dynamic_binary_promote %[[Y]], %[[Z]]
+# CHECK:   %[[R2:.*]] = iree_pydm.apply_compare "eq", %[[YP1]], %[[ZP1]]
+# CHECK:   %[[RP2:.*]] = iree_pydm.bool_to_pred %[[R2]]
+# CHECK:   %[[RESULT1:.*]] = scf.if %[[RP2]] {{.*}} {
+# CHECK:     %[[OMEGA:.*]] = iree_pydm.load_free_var "omega"
+# CHECK:     %[[ZP2:.*]], %[[OMEGAP2:.*]] = iree_pydm.dynamic_binary_promote %[[Z]], %[[OMEGA]]
+# CHECK:     %[[R3:.*]] = iree_pydm.apply_compare "ge", %[[ZP2]], %[[OMEGAP2]]
+# CHECK:     scf.yield %[[R3]]
+# CHECK:   } else {
+# CHECK:     scf.yield %[[FALSE]]
+# CHECK:   }
+# CHECK:   scf.yield %[[RESULT1]]
+# CHECK: } else {
+# CHECK:   scf.yield %[[FALSE]]
+# CHECK: }
+@test_import_global
+def short_circuit():
+  x = 1
+  y = 2
+  z = 3
+  omega = 5
+  return x < y == z >= omega
+
+
+# CHECK-LABEL: nested_short_circuit_expression
+# Verify that the nested expression is evaluated in the context of the if.
+# CHECK: scf.if {{.*}} {
+# CHECK:   iree_pydm.apply_binary "add"
+# CHECK: } else {
+@test_import_global
+def nested_short_circuit_expression():
+  x = 1
+  y = 2
+  z = 3
+  return x < y == (z + 6)
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/constants.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/constants.py
new file mode 100644
index 0000000..1a2c268
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/constants.py
@@ -0,0 +1,53 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL: @const_integer
+# CHECK: iree_pydm.constant 1 : si64 -> !iree_pydm.integer
+@test_import_global
+def const_integer():
+  return 1
+
+
+# CHECK-LABEL: @const_float
+# CHECK: iree_pydm.constant 2.200000e+00 : f64 -> !iree_pydm.real
+@test_import_global
+def const_float():
+  return 2.2
+
+
+# CHECK-LABEL: @const_str
+# CHECK: iree_pydm.constant "Hello" -> !iree_pydm.str
+@test_import_global
+def const_str():
+  return "Hello"
+
+
+# CHECK-LABEL: @const_bytes
+# CHECK: iree_pydm.constant "Bonjour" -> !iree_pydm.bytes
+@test_import_global
+def const_bytes():
+  return b"Bonjour"
+
+
+# CHECK-LABEL: @const_none
+# CHECK: iree_pydm.none -> !iree_pydm.none
+@test_import_global
+def const_none():
+  return None
+
+
+# CHECK-LABEL: @const_true
+# CHECK: iree_pydm.constant true -> !iree_pydm.bool
+@test_import_global
+def const_true():
+  return True
+
+
+# CHECK-LABEL: @const_false
+# CHECK: iree_pydm.constant false -> !iree_pydm.bool
+@test_import_global
+def const_false():
+  return False
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/flow_control.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/flow_control.py
new file mode 100644
index 0000000..7101dde
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/flow_control.py
@@ -0,0 +1,79 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL: @simple_if
+# CHECK: %[[COND:.*]] = iree_pydm.load_free_var "cond"
+# CHECK: %[[COND_BOOL:.*]] = iree_pydm.as_bool %[[COND]]
+# CHECK: %[[COND_PRED:.*]] = iree_pydm.bool_to_pred %[[COND_BOOL]]
+# CHECK: cond_br %2, ^bb1, ^bb2
+# CHECK: ^bb1:
+# CHECK: %[[A:.*]] = iree_pydm.load_free_var "a"
+# CHECK: return %[[A]]
+# CHECK: ^bb2:
+# CHECK: %[[B:.*]] = iree_pydm.load_free_var "b"
+# CHECK: return %[[B]]
+@test_import_global
+def simple_if(cond, a, b):
+  if cond:
+    return a
+  else:
+    return b
+
+
+# CHECK-LABEL: @if_fallthrough
+# CHECK: cond_br {{.*}}, ^bb1, ^bb2
+# CHECK: ^bb1:
+# CHECK: br ^bb3
+# CHECK: ^bb2:
+# CHECK: br ^bb3
+# CHECK: ^bb3:
+# CHECK: iree_pydm.return
+@test_import_global
+def if_fallthrough(cond, a, b):
+  if cond:
+    c = a
+  else:
+    c = b
+  return c
+
+
+# CHECK-LABEL: @if_noelse
+# CHECK: cond_br {{.*}}, ^bb1, ^bb2
+# CHECK: ^bb1:
+# CHECK: br ^bb2
+# CHECK: ^bb2:
+# CHECK: iree_pydm.return
+@test_import_global
+def if_noelse(cond, a, b):
+  c = 1
+  if cond:
+    c = a
+  return c
+
+
+# CHECK-LABEL: @if_elif
+# CHECK: cond_br {{.*}}, ^bb1, ^bb2
+# CHECK: ^bb1:
+# CHECK: br ^bb6
+# CHECK: ^bb2:
+# CHECK: cond_br {{.*}}, ^bb3, ^bb4
+# CHECK: ^bb3:
+# CHECK: br ^bb5
+# CHECK: ^bb4:
+# CHECK: br ^bb5
+# CHECK: ^bb5:
+# CHECK: br ^bb6
+# CHECK: ^bb6:
+# CHECK: iree_pydm.return
+@test_import_global
+def if_elif(cond, a, b):
+  if cond:
+    c = a
+  elif b:
+    c = 2
+  else:
+    c = 3
+  return c
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/function_def.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/function_def.py
new file mode 100644
index 0000000..fc251347
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/function_def.py
@@ -0,0 +1,21 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL: iree_pydm.func @fully_typed_with_return
+# CHECK-SAME: (%arg0: !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.integer)
+# CHECK-SAME: attributes {arg_names = ["a"], cell_vars = [], free_vars = ["a"]}
+# CHECK: iree_pydm.return {{.*}} : !iree_pydm.integer
+@test_import_global
+def fully_typed_with_return(a: int) -> int:
+  return a
+
+
+# CHECK-LABEL: iree_pydm.func @no_return
+# CHECK: %[[NONE:.*]] = iree_pydm.none
+# CHECK: iree_pydm.return %[[NONE]]
+@test_import_global
+def no_return():
+  pass
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/intrinsics.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/intrinsics.py
new file mode 100644
index 0000000..b561c45
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/intrinsics.py
@@ -0,0 +1,105 @@
+# RUN: %PYTHON %s | iree-dialects-opt -split-input-file | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer import *
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+from mlir.dialects import iree_pydm as d
+from mlir import ir
+
+################################################################################
+# Pyfunc intrinsics
+################################################################################
+
+
+@def_pyfunc_intrinsic(symbol="__return_one")
+def intrinsic_return_one() -> int:
+  return 1
+
+
+@def_pyfunc_intrinsic(symbol="__return_first_true")
+def intrinsic_return_first_true(a: int, b: int) -> int:
+  return a or b
+
+
+# CHECK-LABEL: @test_intrinsic_function_no_args
+# CHECK: iree_pydm.dynamic_call @__return_one() : () -> (!iree_pydm.exception_result, !iree_pydm.object)
+# CHECK: iree_pydm.func private @__return_one()
+@test_import_global
+def test_intrinsic_function_no_args():
+  value = intrinsic_return_one()
+  return value
+
+
+# CHECK-LABEL: @test_intrinsic_function_double_call
+# No need to check anything: verifier will fail if double emitted.
+@test_import_global
+def test_intrinsic_function_double_call():
+  value = intrinsic_return_one()
+  value2 = intrinsic_return_one()
+  return value
+
+
+# CHECK-LABEL: @test_intrinsic_function_args
+# CHECK: %[[ZERO:.*]] = iree_pydm.constant 0 : si64 -> !iree_pydm.integer
+# CHECK: %[[ONE:.*]] = iree_pydm.constant 1 : si64 -> !iree_pydm.integer
+# CHECK: iree_pydm.dynamic_call @__return_first_true(%[[ZERO]], %[[ONE]]) : (!iree_pydm.integer, !iree_pydm.integer) -> (!iree_pydm.exception_result, !iree_pydm.object)
+# CHECK: iree_pydm.func private @__return_first_true
+@test_import_global
+def test_intrinsic_function_args():
+  value = intrinsic_return_first_true(0, 1)
+  return value
+
+
+################################################################################
+# IR macro intrinsics
+################################################################################
+
+
+@def_ir_macro_intrinsic
+def macro_return_none(stage: ImportStage) -> ir.Value:
+  return d.NoneOp(d.NoneType.get()).result
+
+
+# Boxing isn't load bearing here: It is just something we can do/test.
+@def_ir_macro_intrinsic
+def macro_box_arg(stage: ImportStage, arg: ir.Value) -> ir.Value:
+  return stage.ic.box(arg)
+
+
+# CHECK-LABEL: @test_intrinsic_macro_no_args
+# CHECK: %[[ONE:.*]] = iree_pydm.constant 1
+# CHECK: iree_pydm.box %[[ONE]] : !iree_pydm.integer -> !iree_pydm.object<!iree_pydm.integer>
+@test_import_global
+def test_intrinsic_macro_no_args() -> int:
+  return macro_box_arg(1)
+
+
+################################################################################
+# Test multi func intrinsic.
+# There is nothing special about a logical not. It is just something we can
+# test.
+################################################################################
+@def_pyfunc_intrinsic(symbol="__logical_not_bool")
+def logical_not_bool(x: bool) -> bool:
+  return not x
+
+
+@def_pyfunc_intrinsic(symbol="__logical_not_generic")
+def logical_not_generic(x):
+  return not x
+
+
+logical_not = def_pattern_call_intrinsic(match_generic=[logical_not_generic],
+                                         match_specific=[logical_not_bool])
+
+
+# CHECK-LABEL: @test_pattern_call
+# CHECK: %[[TRUE:.*]] = iree_pydm.constant true
+# CHECK: iree_pydm.pattern_match_call(%[[TRUE]]) : (!iree_pydm.bool) -> (!iree_pydm.exception_result, !iree_pydm.object)
+# CHECK-SAME:   matching generic [@__logical_not_generic] specific [@__logical_not_bool]
+# CHECK-DAG: iree_pydm.func private @__logical_not_generic
+# CHECK-DAG: iree_pydm.func private @__logical_not_bool
+@test_import_global
+def test_pattern_call():
+  return logical_not(True)
diff --git a/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/structural.py b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/structural.py
new file mode 100644
index 0000000..35a6ac0
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/test/python/iree_pydm/importer/structural.py
@@ -0,0 +1,12 @@
+# RUN: %PYTHON %s | iree-dialects-opt | FileCheck --enable-var-scope --dump-input-filter=all %s
+
+from typing import List
+from mlir.dialects.iree_pydm.importer.test_util import *
+
+
+# CHECK-LABEL @expr_statement
+# CHECK: %[[XVAL:.*]] = iree_pydm.load_free_var "x"
+# CHECK: iree_pydm.expr_statement_discard %[[XVAL]]
+@test_import_global
+def expr_statement(x: int):
+  x
diff --git a/llvm-external-projects/iree-dialects/test/python/smoketest.py b/llvm-external-projects/iree-dialects/test/python/smoketest.py
index 82f0262..f92b132 100644
--- a/llvm-external-projects/iree-dialects/test/python/smoketest.py
+++ b/llvm-external-projects/iree-dialects/test/python/smoketest.py
@@ -2,6 +2,13 @@
 
 import mlir.ir
 from mlir.dialects import iree
+from mlir.dialects import iree_pydm
 
 with mlir.ir.Context() as ctx:
-  iree.register_iree_dialect(ctx)
+  iree.register_dialect()
+  iree_pydm.register_dialect()
+
+  # iree_pydm types.
+  bool_t = iree_pydm.BoolType.get()
+  typed_object_t = iree_pydm.ObjectType.get_typed(bool_t)
+  untyped_object_t = iree_pydm.ObjectType.get()
diff --git a/llvm-external-projects/iree-dialects/tools/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/CMakeLists.txt
new file mode 100644
index 0000000..3bed1e2
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/tools/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(iree-dialects-opt)
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
new file mode 100644
index 0000000..be4c269
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/CMakeLists.txt
@@ -0,0 +1,18 @@
+set(LIBS
+  MLIRDialect
+  MLIROptLib
+  MLIRSCF
+  MLIRSCFTransforms
+  MLIRStandard
+  MLIRTransforms
+  IREEDialectsIREEDialect
+  IREEDialectsIREEPyDMDialect
+)
+
+add_llvm_tool(iree-dialects-opt
+  iree-dialects-opt.cpp
+
+  DEPENDS
+  ${LIBS}
+)
+target_link_libraries(iree-dialects-opt PRIVATE ${LIBS})
diff --git a/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
new file mode 100644
index 0000000..afc7e09
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/tools/iree-dialects-opt/iree-dialects-opt.cpp
@@ -0,0 +1,36 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/IREE/IREEDialect.h"
+#include "iree-dialects/Dialect/IREEPyDM/IR/IREEPyDMDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/InitAllPasses.h"
+#include "mlir/Support/MlirOptMain.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+int main(int argc, char **argv) {
+  registerAsmPrinterCLOptions();
+  registerMLIRContextCLOptions();
+
+  registerTransformsPasses();
+  registerSCFPasses();
+
+  DialectRegistry registry;
+  registry.insert<
+      // Local dialects
+      mlir::iree::IREEDialect, mlir::iree_pydm::IREEPyDMDialect,
+      // Upstream dialects
+      mlir::StandardOpsDialect, mlir::scf::SCFDialect>();
+
+  return mlir::asMainReturnCode(
+      mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
+                        /*preloadDialectsInContext=*/false));
+}