Generate arm intrinsics for MMT4D (#7751)

This passes all the e2e matmul tests (I added one specifically for the
MMT4D shape, which I think was missing test coverage). The option to
use intrinsics vs asm is added as a global flag to be used for
*development* as we explore these two different paths. It defaults to
using asm, since that version is actually producing the thing we want
right now. Registration of Arm Neon dialects and passes is added
unconditionally in a few places. I can look into making that
conditional if that seems worthwhile.

The generated asm is also not great. The LLVM shuffle + vdotq_s32
intrinsic is not getting switched into the version that uses lanes
directly. I wonder if the issue has to do with no way to express poison
at this level.

Here's a gist showing a lowering down to asm of the lit test here:
https://gist.github.com/GMNGeoffrey/02509944091560adf8150ceb2445cb27

In contrast to inline asm:
https://gist.github.com/GMNGeoffrey/06c4bb92708f1d3d2bc173a4ceafe5dd

Regardless, this isn't in a production path, so I think we can optimize
later.
diff --git a/iree/compiler/Codegen/LLVMCPU/BUILD b/iree/compiler/Codegen/LLVMCPU/BUILD
index 7fbbf3a..4d61f31 100644
--- a/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -44,6 +44,8 @@
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithmeticToLLVM",
         "@llvm-project//mlir:ArithmeticTransforms",
+        "@llvm-project//mlir:ArmNeon",
+        "@llvm-project//mlir:ArmNeon2dToIntr",
         "@llvm-project//mlir:CFGTransforms",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
diff --git a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index 04e25a3..bc801a5 100644
--- a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -33,6 +33,8 @@
     MLIRAnalysis
     MLIRArithmeticToLLVM
     MLIRArithmeticTransforms
+    MLIRArmNeon
+    MLIRArmNeon2dToIntr
     MLIRIR
     MLIRLLVMCommonConversion
     MLIRLLVMIR
diff --git a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index eab2e49..1e752d2 100644
--- a/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
+#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -30,6 +31,7 @@
 #include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -615,7 +617,7 @@
   ConvertToLLVMPass() = default;
   ConvertToLLVMPass(const ConvertToLLVMPass &pass) {}
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<LLVM::LLVMDialect>();
+    registry.insert<LLVM::LLVMDialect, arm_neon::ArmNeonDialect>();
   }
 
   void runOnOperation() override;
@@ -673,6 +675,7 @@
     vector::populateVectorMaskOpLoweringPatterns(patterns);
     vector::populateVectorShapeCastLoweringPatterns(patterns);
     vector::populateVectorTransposeLoweringPatterns(patterns);
+    populateConvertArmNeon2dToIntrPatterns(patterns);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       return signalPassFailure();
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
index 05dddaa..1a5e806 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
@@ -28,6 +28,15 @@
 namespace mlir {
 namespace iree_compiler {
 
+// A flag to switch between inline asm and intrinsics while we develop these two
+//  parallel paths.
+static llvm::cl::opt<bool> clUseMmt4dUseIntrinsics(
+    "iree-codegen-mmt4d-use-intrinsics",
+    llvm::cl::desc("Whether to use instrinsics when lowering vector contracts "
+                   "generated from mmt4d matmuls (as opposed to inline asm). "
+                   "Not for production use."),
+    llvm::cl::init(false));
+
 namespace {
 // Could just be linalg::TilingPattern with a ContractionOpInterface filter, but
 // that is always templated on an op.
@@ -342,6 +351,7 @@
     // just before the generic vector ops lowerings.
     CustomKernelsTargetInfo info;
     if (succeeded(InferCustomKernelsTargetInfoFromParent(funcOp, info))) {
+      info.intrinsics = clUseMmt4dUseIntrinsics;
       RewritePatternSet patterns(context);
       populateVectorContractCustomKernelsPatterns(info, patterns);
       if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 9a067f2..0131ae1 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -13,6 +13,7 @@
 #include "iree/compiler/Codegen/Sandbox/Passes.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "llvm/Support/CommandLine.h"
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -320,6 +321,7 @@
   passManager.addNestedPass<FuncOp>(arith::createArithmeticExpandOpsPass());
   passManager.addNestedPass<FuncOp>(memref::createExpandOpsPass());
   passManager.addPass(createConvertToLLVMPass());
+  passManager.addPass(createReconcileUnrealizedCastsPass());
 
   // We rely on MLIR symbol visibility being correct after this point and need
   // to mirror the LLVM linkage that was assigned during conversion.
diff --git a/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
index e094448..0881503 100644
--- a/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
@@ -11,6 +11,7 @@
 #include "llvm/ADT/Triple.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -150,6 +151,9 @@
 
 // Checks that the Value `extResult` is defined by an arith::ExtSIOp promoting
 // from `extSrcType` to `extDstType`, and returns the input of the ExtSIOp.
+// Note that this only looks at the immediately defining operation, so we likely
+// want to have earlier passes that sink widening operations as far down as
+// possible, which is probably just good regardless.
 static Value getExtSIInput(Type extSrcType, Type extDstType, Value extResult) {
   auto extSIOp = extResult.getDefiningOp<arith::ExtSIOp>();
   if (!extSIOp) {
@@ -248,9 +252,6 @@
           extract1DSlice(rewriter, loc, int32x4Type, flatAcc, position));
     }
 
-    // Start of the code that's specific to inline assembly. An intrinsics
-    // code path would diverge here.
-
     // Create the inline asm op's operands list.
     SmallVector<Value> asmOperands;
     // First the inputs operands.
@@ -299,9 +300,6 @@
           loc, int32x4Type, asmOp.getRes(), rewriter.getI64ArrayAttr({i})));
     }
 
-    // End of the code that's specific to inline assembly. An intrinsics code
-    // path would merge here.
-
     // Insert the result vectors of size 4 into the overall result vector of
     // size 64, still 1D.
     VectorType int32x64xType = VectorType::get({64}, I32Type);
@@ -321,11 +319,129 @@
   }
 };
 
+/// Converts matrix-times-matrix-transposed vector.contracts with
+/// lhs and rhs inputs defined by arith.extsi promoting from i8 to i32,
+///
+///     %lhs_i32 = arith.extsi %lhs_i8 : i8 to i32
+///     %rhs_i32 = arith.extsi %rhs_i8 : i8 to i32
+///     %result = vector.contract [...]
+///                 %lhs_i32 : vector<8x4xi32>,
+///                 %rhs_i32 : vector<8x4xi32>,
+///                 %acc_i32 : vector<8x8xi32>,
+///                 [...]
+///
+/// To vector ops reading directly from the %lhs_i8 and %rhs_i8 values
+/// (bypassing the existing arith.extsi) and passing that to a llvm.inline_asm
+/// block implementing the matrix multiplication arithmetic using Aarch64
+/// dot-product instructions (sdot).
+/// It matches the same patterns as MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm
+struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics
+    : public OpRewritePattern<vector::ContractionOp> {
+ public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractionOp,
+                                PatternRewriter &rewriter) const override {
+    if (!isMatrixTimesMatrixTransposedOfGivenShape(contractionOp, 8, 4, 8)) {
+      return failure();
+    }
+
+    Type I8Type = rewriter.getIntegerType(8);
+    Type I32Type = rewriter.getIntegerType(32);
+
+    auto acc = contractionOp.acc();
+    auto lhs = contractionOp.lhs();
+    auto rhs = contractionOp.rhs();
+    if (acc.getType().cast<VectorType>().getElementType() != I32Type) {
+      return failure();
+    }
+
+    Value inLhs = getExtSIInput(I8Type, I32Type, lhs);
+    Value inRhs = getExtSIInput(I8Type, I32Type, rhs);
+
+    if (!inLhs || !inRhs) return failure();
+
+    auto loc = contractionOp.getLoc();
+
+    auto int32x4VType = VectorType::get({4}, I32Type);
+
+    std::array<Value, 16> accChunks;
+    {
+      int idx = 0;
+      for (int row = 0; row < 8; ++row) {
+        auto accRow = rewriter.create<vector::ExtractOp>(
+            loc, acc, ArrayRef<int64_t>{row});
+        for (int col = 0; col < 8; col += 4) {
+          auto accChunk = rewriter.create<vector::ExtractStridedSliceOp>(
+              loc, accRow, ArrayRef<int64_t>{col}, ArrayRef<int64_t>{4},
+              ArrayRef<int64_t>{1});
+          assert(accChunk.getType() == int32x4VType);
+          accChunks[idx++] = accChunk;
+        }
+      }
+    }
+
+    auto int8x4x4VType = VectorType::get({4, 4}, rewriter.getIntegerType(8));
+    auto extract4x4 = [&](Value in, int rowOffset, int colOffset) {
+      auto chunk = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, in, ArrayRef<int64_t>{rowOffset, colOffset},
+          ArrayRef<int64_t>{4, 4}, ArrayRef<int64_t>{1, 1});
+      assert(chunk.getType() == int8x4x4VType);
+      return chunk;
+    };
+
+    std::array<Value, 2> lhsHalves = {extract4x4(inLhs, 0, 0),
+                                      extract4x4(inLhs, 4, 0)};
+    std::array<Value, 2> rhsHalves = {extract4x4(inRhs, 0, 0),
+                                      extract4x4(inRhs, 4, 0)};
+
+    auto int8Zero4x4 = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(int8x4x4VType));
+    auto sdot = [&](Value acc, Value a, Value b, int64_t lane) -> Value {
+      auto bReplicatedLane = rewriter.create<vector::ShuffleOp>(
+          loc, b, int8Zero4x4, ArrayRef<int64_t>{lane, lane, lane, lane});
+
+      return rewriter.create<arm_neon::Sdot2dOp>(loc, int32x4VType, acc, a,
+                                                 bReplicatedLane);
+    };
+
+    std::array<Value, 16> dstChunks;
+    {
+      int idx = 0;
+      for (Value lhs : lhsHalves) {
+        for (int lane = 0; lane < 4; ++lane) {
+          for (Value rhs : rhsHalves) {
+            dstChunks[idx] = sdot(accChunks[idx], rhs, lhs, lane);
+            ++idx;
+          }
+        }
+      }
+    }
+
+    // Put the results back in the accumulator
+    {
+      int idx = 0;
+      for (int row = 0; row < 8; ++row) {
+        for (int col = 0; col < 8; col += 4) {
+          acc = rewriter.create<vector::InsertStridedSliceOp>(
+              loc, dstChunks[idx++], acc, ArrayRef<int64_t>{row, col},
+              ArrayRef<int64_t>{1});
+        }
+      }
+    }
+    rewriter.replaceOp(contractionOp, {acc});
+    return success();
+  }
+};
+
 class VectorContractCustomKernelsPass
     : public VectorContractCustomKernelsBase<VectorContractCustomKernelsPass> {
  public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<vector::VectorDialect, LLVM::LLVMDialect>();
+    if (target_info.intrinsics) {
+      registry.insert<arm_neon::ArmNeonDialect>();
+    }
   }
   LogicalResult initializeOptions(StringRef options) override {
     if (failed(Pass::initializeOptions(options))) {
@@ -333,6 +449,7 @@
     }
     target_info.aarch64 = aarch64;
     target_info.dotprod = dotprod;
+    target_info.intrinsics = intrinsics;
     return success();
   }
   void runOnOperation() override {
@@ -355,7 +472,11 @@
     const CustomKernelsTargetInfo &target_info, RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
   if (target_info.aarch64 && target_info.dotprod) {
-    patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm>(context);
+    if (target_info.intrinsics) {
+      patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics>(context);
+    } else {
+      patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm>(context);
+    }
   }
 }
 
diff --git a/iree/compiler/Codegen/LLVMCPU/test/BUILD b/iree/compiler/Codegen/LLVMCPU/test/BUILD
index eb50dab..57e06d0 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/test/BUILD
@@ -18,6 +18,7 @@
 iree_lit_test_suite(
     name = "lit",
     srcs = enforce_glob(
+        # keep sorted
         [
             "check_ir_before_llvm_conversion.mlir",
             "hal_interface_bindings.mlir",
@@ -29,7 +30,8 @@
             "test_config_mmt4d.mlir",
             "tile_fuse_and_vectorize.mlir",
             "unfused_fma.mlir",
-            "vector_contract_custom_kernels.mlir",
+            "vector_contract_to_arm_asm.mlir",
+            "vector_contract_to_arm_intrinsics.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index e20d855..1cb2a67 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -24,7 +24,8 @@
     "test_config_mmt4d.mlir"
     "tile_fuse_and_vectorize.mlir"
     "unfused_fma.mlir"
-    "vector_contract_custom_kernels.mlir"
+    "vector_contract_to_arm_asm.mlir"
+    "vector_contract_to_arm_intrinsics.mlir"
   TOOLS
     FileCheck
     iree::tools::iree-opt
diff --git a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_custom_kernels.mlir b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_asm.mlir
similarity index 100%
rename from iree/compiler/Codegen/LLVMCPU/test/vector_contract_custom_kernels.mlir
rename to iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_asm.mlir
diff --git a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_intrinsics.mlir b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_intrinsics.mlir
new file mode 100644
index 0000000..9193ca9
--- /dev/null
+++ b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_arm_intrinsics.mlir
@@ -0,0 +1,136 @@
+// RUN: iree-opt -iree-llvmcpu-vector-contract-custom-kernels='aarch64 dotprod intrinsics' %s | FileCheck %s
+
+// CHECK-LABEL: @vector_i8i8i32matmul(
+// CHECK-SAME:          %[[LHS:[^:[:space:]]+]]
+// CHECK-SAME:          %[[RHS:[^:[:space:]]+]]
+// CHECK-SAME:          %[[ACC:[^:[:space:]]+]]
+// CHECK-DAG:       %[[ZERO:.*]]          = arith.constant dense<0> : vector<4x4xi8>
+// CHECK-DAG:       %[[ACC_ROW_0:.*]]     = vector.extract %[[ACC]][0] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_ROW_1:.*]]     = vector.extract %[[ACC]][1] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_ROW_2:.*]]     = vector.extract %[[ACC]][2] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_ROW_3:.*]]     = vector.extract %[[ACC]][3] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_ROW_4:.*]]     = vector.extract %[[ACC]][4] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_ROW_5:.*]]     = vector.extract %[[ACC]][5] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_ROW_6:.*]]     = vector.extract %[[ACC]][6] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_ROW_7:.*]]     = vector.extract %[[ACC]][7] : vector<8x8xi32>
+// CHECK-DAG:       %[[ACC_CHUNK_00:.*]]  = vector.extract_strided_slice %[[ACC_ROW_0]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_01:.*]]  = vector.extract_strided_slice %[[ACC_ROW_0]] {offsets = [4]
+// CHECK-DAG:       %[[ACC_CHUNK_02:.*]]  = vector.extract_strided_slice %[[ACC_ROW_1]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_03:.*]]  = vector.extract_strided_slice %[[ACC_ROW_1]] {offsets = [4]
+// CHECK-DAG:       %[[ACC_CHUNK_04:.*]]  = vector.extract_strided_slice %[[ACC_ROW_2]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_05:.*]]  = vector.extract_strided_slice %[[ACC_ROW_2]] {offsets = [4]
+// CHECK-DAG:       %[[ACC_CHUNK_06:.*]]  = vector.extract_strided_slice %[[ACC_ROW_3]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_07:.*]]  = vector.extract_strided_slice %[[ACC_ROW_3]] {offsets = [4]
+// CHECK-DAG:       %[[ACC_CHUNK_08:.*]]  = vector.extract_strided_slice %[[ACC_ROW_4]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_09:.*]]  = vector.extract_strided_slice %[[ACC_ROW_4]] {offsets = [4]
+// CHECK-DAG:       %[[ACC_CHUNK_10:.*]]  = vector.extract_strided_slice %[[ACC_ROW_5]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_11:.*]]  = vector.extract_strided_slice %[[ACC_ROW_5]] {offsets = [4]
+// CHECK-DAG:       %[[ACC_CHUNK_12:.*]]  = vector.extract_strided_slice %[[ACC_ROW_6]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_13:.*]]  = vector.extract_strided_slice %[[ACC_ROW_6]] {offsets = [4]
+// CHECK-DAG:       %[[ACC_CHUNK_14:.*]]  = vector.extract_strided_slice %[[ACC_ROW_7]] {offsets = [0]
+// CHECK-DAG:       %[[ACC_CHUNK_15:.*]]  = vector.extract_strided_slice %[[ACC_ROW_7]] {offsets = [4]
+// CHECK-DAG:       %[[LHS_HALF_0:.*]]    = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0]
+// CHECK-DAG:       %[[LHS_HALF_1:.*]]    = vector.extract_strided_slice %[[LHS]] {offsets = [4, 0]
+// CHECK-DAG:       %[[RHS_HALF_0:.*]]    = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0]
+// CHECK-DAG:       %[[RHS_HALF_1:.*]]    = vector.extract_strided_slice %[[RHS]] {offsets = [4, 0]
+// CHECK-DAG:       %[[LHS_CHUNK_00:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [0, 0, 0, 0]
+// CHECK-DAG:       %[[LHS_CHUNK_01:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [0, 0, 0, 0]
+// CHECK-DAG:       %[[LHS_CHUNK_02:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [1, 1, 1, 1]
+// CHECK-DAG:       %[[LHS_CHUNK_03:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [1, 1, 1, 1]
+// CHECK-DAG:       %[[LHS_CHUNK_04:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [2, 2, 2, 2]
+// CHECK-DAG:       %[[LHS_CHUNK_05:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [2, 2, 2, 2]
+// CHECK-DAG:       %[[LHS_CHUNK_06:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [3, 3, 3, 3]
+// CHECK-DAG:       %[[LHS_CHUNK_07:.*]]  = vector.shuffle %[[LHS_HALF_0]], %[[ZERO]] [3, 3, 3, 3]
+// CHECK-DAG:       %[[LHS_CHUNK_08:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [0, 0, 0, 0]
+// CHECK-DAG:       %[[LHS_CHUNK_09:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [0, 0, 0, 0]
+// CHECK-DAG:       %[[LHS_CHUNK_10:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [1, 1, 1, 1]
+// CHECK-DAG:       %[[LHS_CHUNK_11:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [1, 1, 1, 1]
+// CHECK-DAG:       %[[LHS_CHUNK_12:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [2, 2, 2, 2]
+// CHECK-DAG:       %[[LHS_CHUNK_13:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [2, 2, 2, 2]
+// CHECK-DAG:       %[[LHS_CHUNK_14:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [3, 3, 3, 3]
+// CHECK-DAG:       %[[LHS_CHUNK_15:.*]]  = vector.shuffle %[[LHS_HALF_1]], %[[ZERO]] [3, 3, 3, 3]
+// CHECK-DAG:       %[[SDOT_CHUNK_00:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_00]], %[[RHS_HALF_0]], %[[LHS_CHUNK_00]]
+// CHECK-DAG:       %[[SDOT_CHUNK_01:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_01]], %[[RHS_HALF_1]], %[[LHS_CHUNK_01]]
+// CHECK-DAG:       %[[SDOT_CHUNK_02:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_02]], %[[RHS_HALF_0]], %[[LHS_CHUNK_02]]
+// CHECK-DAG:       %[[SDOT_CHUNK_03:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_03]], %[[RHS_HALF_1]], %[[LHS_CHUNK_03]]
+// CHECK-DAG:       %[[SDOT_CHUNK_04:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_04]], %[[RHS_HALF_0]], %[[LHS_CHUNK_04]]
+// CHECK-DAG:       %[[SDOT_CHUNK_05:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_05]], %[[RHS_HALF_1]], %[[LHS_CHUNK_05]]
+// CHECK-DAG:       %[[SDOT_CHUNK_06:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_06]], %[[RHS_HALF_0]], %[[LHS_CHUNK_06]]
+// CHECK-DAG:       %[[SDOT_CHUNK_07:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_07]], %[[RHS_HALF_1]], %[[LHS_CHUNK_07]]
+// CHECK-DAG:       %[[SDOT_CHUNK_08:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_08]], %[[RHS_HALF_0]], %[[LHS_CHUNK_08]]
+// CHECK-DAG:       %[[SDOT_CHUNK_09:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_09]], %[[RHS_HALF_1]], %[[LHS_CHUNK_09]]
+// CHECK-DAG:       %[[SDOT_CHUNK_10:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_10]], %[[RHS_HALF_0]], %[[LHS_CHUNK_10]]
+// CHECK-DAG:       %[[SDOT_CHUNK_11:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_11]], %[[RHS_HALF_1]], %[[LHS_CHUNK_11]]
+// CHECK-DAG:       %[[SDOT_CHUNK_12:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_12]], %[[RHS_HALF_0]], %[[LHS_CHUNK_12]]
+// CHECK-DAG:       %[[SDOT_CHUNK_13:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_13]], %[[RHS_HALF_1]], %[[LHS_CHUNK_13]]
+// CHECK-DAG:       %[[SDOT_CHUNK_14:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_14]], %[[RHS_HALF_0]], %[[LHS_CHUNK_14]]
+// CHECK-DAG:       %[[SDOT_CHUNK_15:.*]] = arm_neon.2d.sdot %[[ACC_CHUNK_15]], %[[RHS_HALF_1]], %[[LHS_CHUNK_15]]
+// CHECK-DAG:       %[[RES_00:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_00]], %[[ACC]]    {offsets = [0, 0]
+// CHECK-DAG:       %[[RES_01:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_01]], %[[RES_00]] {offsets = [0, 4]
+// CHECK-DAG:       %[[RES_02:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_02]], %[[RES_01]] {offsets = [1, 0]
+// CHECK-DAG:       %[[RES_03:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_03]], %[[RES_02]] {offsets = [1, 4]
+// CHECK-DAG:       %[[RES_04:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_04]], %[[RES_03]] {offsets = [2, 0]
+// CHECK-DAG:       %[[RES_05:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_05]], %[[RES_04]] {offsets = [2, 4]
+// CHECK-DAG:       %[[RES_06:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_06]], %[[RES_05]] {offsets = [3, 0]
+// CHECK-DAG:       %[[RES_07:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_07]], %[[RES_06]] {offsets = [3, 4]
+// CHECK-DAG:       %[[RES_08:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_08]], %[[RES_07]] {offsets = [4, 0]
+// CHECK-DAG:       %[[RES_09:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_09]], %[[RES_08]] {offsets = [4, 4]
+// CHECK-DAG:       %[[RES_10:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_10]], %[[RES_09]] {offsets = [5, 0]
+// CHECK-DAG:       %[[RES_11:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_11]], %[[RES_10]] {offsets = [5, 4]
+// CHECK-DAG:       %[[RES_12:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_12]], %[[RES_11]] {offsets = [6, 0]
+// CHECK-DAG:       %[[RES_13:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_13]], %[[RES_12]] {offsets = [6, 4]
+// CHECK-DAG:       %[[RES_14:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_14]], %[[RES_13]] {offsets = [7, 0]
+// CHECK-DAG:       %[[RES_15:.*]]        = vector.insert_strided_slice %[[SDOT_CHUNK_15]], %[[RES_14]] {offsets = [7, 4]
+// CHECK:           return %[[RES_15]]
+func @vector_i8i8i32matmul(
+    %lhs: vector<8x4xi8>,
+    %rhs: vector<8x4xi8>,
+    %acc: vector<8x8xi32>) -> vector<8x8xi32> {
+  %lhs_wide = arith.extsi %lhs : vector<8x4xi8> to vector<8x4xi32>
+  %rhs_wide = arith.extsi %rhs : vector<8x4xi8> to vector<8x4xi32>
+  %res = vector.contract {
+      indexing_maps = [
+          affine_map<(d0, d1, d2) -> (d0, d2)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>,
+          affine_map<(d0, d1, d2) -> (d0, d1)>
+      ], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+  } %lhs_wide, %rhs_wide, %acc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+  return %res : vector<8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_f32f32f32matmul(
+func @vector_f32f32f32matmul(
+    %lhs: vector<8x4xf32>,
+    %rhs: vector<8x4xf32>,
+    %acc: vector<8x8xf32>) -> vector<8x8xf32> {
+  // CHECK: vector.contract
+  %res = vector.contract {
+      indexing_maps = [
+          affine_map<(d0, d1, d2) -> (d0, d2)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>,
+          affine_map<(d0, d1, d2) -> (d0, d1)>
+      ], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+  } %lhs, %rhs, %acc : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
+  return %res : vector<8x8xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @vector_i32i32i32matmul(
+func @vector_i32i32i32matmul(
+    %lhs: vector<8x4xi32>,
+    %rhs: vector<8x4xi32>,
+    %acc: vector<8x8xi32>) -> vector<8x8xi32> {
+  // CHECK: vector.contract
+  %res = vector.contract {
+      indexing_maps = [
+          affine_map<(d0, d1, d2) -> (d0, d2)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>,
+          affine_map<(d0, d1, d2) -> (d0, d1)>
+      ], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+  } %lhs, %rhs, %acc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+  return %res : vector<8x8xi32>
+}
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index d7a4e3b..9da0bba 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -200,6 +200,8 @@
   bool aarch64 = false;
   // Under aarch64: indicates dot-product extension (SDOT, UDOT)
   bool dotprod = false;
+  // Indicates that intrinsics should be used rather than inline asm
+  bool intrinsics = false;
 };
 
 // Populate target_info fields from the parent HAL::ExecutableVariantOp.
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index f025df1..7b90474 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -177,6 +177,9 @@
     Option<"dotprod", "dotprod", "bool",
             /*default=*/"false",
            "Under aarch64, enable kernels that use dotprod instructions">,
+    Option<"intrinsics", "intrinsics", "bool",
+            /*default=*/"false",
+           "Under aarch64, enable kernels that use dotprod instructions">,
   ];
 }
 
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
index 045cea0..7755e97 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD
@@ -57,6 +57,7 @@
         "@llvm-project//llvm:WebAssemblyCodeGen",
         "@llvm-project//llvm:X86AsmParser",
         "@llvm-project//llvm:X86CodeGen",
+        "@llvm-project//mlir:ArmNeon",
         "@llvm-project//mlir:LLVMDialect",
         "@llvm-project//mlir:LLVMToLLVMIRTranslation",
         "@llvm-project//mlir:ToLLVMIRTranslation",
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
index 6455ef5..d73ca81 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt
@@ -43,6 +43,7 @@
     LLVMWebAssemblyCodeGen
     LLVMX86AsmParser
     LLVMX86CodeGen
+    MLIRArmNeon
     MLIRLLVMIR
     MLIRLLVMToLLVMIRTranslation
     MLIRTargetLLVMIRExport
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index e4ef1f2..0d927e2 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -24,6 +24,7 @@
 #include "llvm/Linker/Linker.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/TargetSelect.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
@@ -143,7 +144,9 @@
 
   void getDependentDialects(DialectRegistry &registry) const override {
     mlir::registerLLVMDialectTranslation(registry);
-    registry.insert<IREE::Codegen::IREECodegenDialect>();
+    // TODO: make inclusion of ArmNeon conditional?
+    registry
+        .insert<IREE::Codegen::IREECodegenDialect, arm_neon::ArmNeonDialect>();
   }
 
   IREE::HAL::DeviceTargetAttr getDefaultDeviceTarget(
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index c1ea41e..c57be78 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -130,6 +130,7 @@
     "f32",
 ]]
 
+# Test asm
 [iree_generated_trace_runner_test(
     name = "e2e_matmul_mmt4d_%s_small" % lhs_rhs_type,
     generator = ":generate_e2e_matmul_tests",
@@ -177,3 +178,32 @@
     "i8",
     "f32",
 ]]
+
+# Test intrinsics. No need to run vmvx again, since it isn't affected by this
+# codegen flag.
+[iree_generated_trace_runner_test(
+    name = "e2e_matmul_mmt4d_%s_intrinsics_%s" % (lhs_rhs_type, size),
+    compiler_flags = ["--iree-codegen-mmt4d-use-intrinsics"],
+    generator = ":generate_e2e_matmul_tests",
+    generator_args = [
+        "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--shapes=%s" % size,
+    ],
+    opt_flags = [
+        "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
+    ],
+    target_backends_and_drivers = [
+        ("dylib-llvm-aot", "dylib"),
+    ],
+    target_cpu_features_variants = [
+        "default",
+        "aarch64:+dotprod",
+    ],
+    trace_runner = "//iree/tools:iree-e2e-matmul-test",
+) for lhs_rhs_type in [
+    "i8",
+    "f32",
+] for size in [
+    "small",
+    "large",
+]]
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index 1f47f4a..d3aedfd 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -220,4 +220,96 @@
     "aarch64:+dotprod"
 )
 
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_mmt4d_i8_intrinsics_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=i8"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree_tools_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "dylib-llvm-aot"
+  DRIVERS
+    "dylib"
+  COMPILER_FLAGS
+    "--iree-codegen-mmt4d-use-intrinsics"
+  OPT_FLAGS
+    "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "aarch64:+dotprod"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_mmt4d_i8_intrinsics_large
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=i8"
+    "--shapes=large"
+  TRACE_RUNNER
+    iree_tools_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "dylib-llvm-aot"
+  DRIVERS
+    "dylib"
+  COMPILER_FLAGS
+    "--iree-codegen-mmt4d-use-intrinsics"
+  OPT_FLAGS
+    "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "aarch64:+dotprod"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_mmt4d_f32_intrinsics_small
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--shapes=small"
+  TRACE_RUNNER
+    iree_tools_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "dylib-llvm-aot"
+  DRIVERS
+    "dylib"
+  COMPILER_FLAGS
+    "--iree-codegen-mmt4d-use-intrinsics"
+  OPT_FLAGS
+    "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "aarch64:+dotprod"
+)
+
+iree_generated_trace_runner_test(
+  NAME
+    e2e_matmul_mmt4d_f32_intrinsics_large
+  GENERATOR
+    "generate_e2e_matmul_tests.py"
+  GENERATOR_ARGS
+    "--lhs_rhs_type=f32"
+    "--shapes=large"
+  TRACE_RUNNER
+    iree_tools_iree-e2e-matmul-test
+  TARGET_BACKENDS
+    "dylib-llvm-aot"
+  DRIVERS
+    "dylib"
+  COMPILER_FLAGS
+    "--iree-codegen-mmt4d-use-intrinsics"
+  OPT_FLAGS
+    "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+  TARGET_CPU_FEATURES_VARIANTS
+    "default"
+    "aarch64:+dotprod"
+)
+
 ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/test/e2e/regression/generate_e2e_matmul_tests.py b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
index caffcdf..605c588 100644
--- a/iree/test/e2e/regression/generate_e2e_matmul_tests.py
+++ b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
@@ -91,6 +91,8 @@
         TestShape(m=2, k=3, n=4),
         #TestShape(m=8, k=7, n=6),
         #TestShape(m=15, k=16, n=17),
+        # Exactly the mmt4d kernel size
+        TestShape(m=8, k=4, n=8),
         TestShape(m=14, k=19, n=23),
         #TestShape(m=31, k=33, n=32),
         TestShape(m=25, k=41, n=35),
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index df1e5ec..bba42ff 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -135,6 +135,8 @@
     deps = [
         "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:AffineTransforms",
+        "@llvm-project//mlir:ArmNeon",
+        "@llvm-project//mlir:ArmNeon2dToIntr",
         "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:ConversionPasses",
         "@llvm-project//mlir:GPUDialect",
@@ -389,6 +391,7 @@
         "//iree/compiler/Translation:HALExecutable",
         "//iree/compiler/Translation:IREEVM",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArmNeonToLLVMIRTranslation",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LLVMToLLVMIRTranslation",
         "@llvm-project//mlir:Pass",
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 58cfada..3c6caff 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -275,6 +275,8 @@
       MLIRTosa
       MLIRTosaTransforms
       MLIRTransforms
+      MLIRArmNeon
+      MLIRArmNeon2dToIntr
       MLIRVector
     PUBLIC
   )
@@ -373,6 +375,7 @@
       ::init_xla_dialects
       LLVMSupport
       MLIRIR
+      MLIRArmNeonToLLVMIRTranslation
       MLIRLLVMToLLVMIRTranslation
       MLIRSCFTransforms
       MLIRPass
diff --git a/iree/tools/init_mlir_dialects.h b/iree/tools/init_mlir_dialects.h
index af7a9cb..4d9ec16 100644
--- a/iree/tools/init_mlir_dialects.h
+++ b/iree/tools/init_mlir_dialects.h
@@ -13,6 +13,7 @@
 #define IREE_TOOLS_INIT_MLIR_DIALECTS_H_
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -49,6 +50,7 @@
                   scf::SCFDialect,
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
+                  arm_neon::ArmNeonDialect,
                   StandardOpsDialect,
                   mlir::arith::ArithmeticDialect,
                   vector::VectorDialect,
diff --git a/iree/tools/init_mlir_passes.h b/iree/tools/init_mlir_passes.h
index 4d6c5f3..5635714 100644
--- a/iree/tools/init_mlir_passes.h
+++ b/iree/tools/init_mlir_passes.h
@@ -62,6 +62,9 @@
   // Linalg
   registerLinalgPasses();
 
+  // LLVM
+  registerConvertArmNeon2dToIntrPass();
+
   // MemRef
   memref::registerMemRefPasses();
 
diff --git a/iree/tools/iree_translate_lib.cc b/iree/tools/iree_translate_lib.cc
index f12725f..0fe1b8e 100644
--- a/iree/tools/iree_translate_lib.cc
+++ b/iree/tools/iree_translate_lib.cc
@@ -36,6 +36,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/Timing.h"
 #include "mlir/Support/ToolUtilities.h"
+#include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Translation.h"
 
@@ -44,6 +45,8 @@
   mlir::DialectRegistry registry;
   mlir::registerMlirDialects(registry);
   mlir::registerLLVMDialectTranslation(registry);
+  // TODO: Make this conditional?
+  mlir::registerArmNeonDialectTranslation(registry);
   mlir::registerXLADialects(registry);
   mlir::iree_compiler::registerAllPasses();
   mlir::iree_compiler::registerIreeDialects(registry);