Add pass to convert memref to memref of vector. (#3028)

Add a new pass to convert memrefs of scalars to memrefs of vectors so that vector loads/stores can be used.
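
As an illustrative sketch of the intended rewrite (shapes taken from the new
test; the exact output depends on the patterns below): with 128-bit
vectorization, an f32 buffer that is only used by vector transfer ops, e.g.

  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x4096xf32>
  %v = vector.transfer_read %0[%x, %y], %cst : memref<4096x4096xf32>, vector<1x4xf32>

becomes a memref<4096x1024xvector<4xf32>>, and the 128-bit read collapses to a
plain load with the innermost index divided by the vector width (the
intermediate vector.bitcast is a no-op for f32 and omitted here):

  %c4 = constant 4 : index
  %i = divi_signed %y, %c4 : index
  %ld = load %0[%x, %i] : memref<4096x1024xvector<4xf32>>
  %v2 = vector.shape_cast %ld : vector<4xf32> to vector<1x4xf32>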
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
index 9c8a1f6..5b6bbf2 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
@@ -82,6 +82,9 @@
   pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::iree_compiler::createVectorizeMemref());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::iree_compiler::createConvertToSPIRVPass());
 
   auto &spirvModulePM = pm.nest<mlir::spirv::ModuleOp>();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index ebb0e6a..6a30713 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -31,6 +31,7 @@
         "SplitDispatchFunctionPass.cpp",
         "Utils.cpp",
         "VectorToGPUPass.cpp",
+        "VectorizeMemref.cpp",
     ],
     hdrs = [
         "Attributes.h",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 77323e0..8e34e87 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -35,6 +35,7 @@
     "SplitDispatchFunctionPass.cpp"
     "Utils.cpp"
     "VectorToGPUPass.cpp"
+    "VectorizeMemref.cpp"
   DEPS
     LLVMSupport
     MLIRAffineOps
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index b08890e..72872f5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -65,6 +65,11 @@
 /// Pass to apply tiling and vectorization transformations on linalg::MatMulOp.
 std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass();
 
+/// Converts a memref of scalars to a memref of vectors of an efficient size.
+/// This allows memory accesses to be converted to vector loads/stores in
+/// SPIR-V without pointer bitcasts.
+std::unique_ptr<OperationPass<ModuleOp>> createVectorizeMemref();
+
 /// Populates passes needed to lower an XLA HLO op to SPIR-V dialect via the
 /// structured ops path. The pass manager `pm` here operates on the module
 /// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
new file mode 100644
index 0000000..7ff1530
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
@@ -0,0 +1,375 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===----------------------------------------------------------------------===//
+//
+// Pass to convert memref into memref of vector.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/StandardTypes.h"
+
+constexpr int kVectorizationSizeInBits = 128;
+constexpr int kVecSize = kVectorizationSizeInBits / (sizeof(float) * 8);
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Returns true if all uses are transfer read/write operations. If it returns
+/// true, also returns the uses of the memref in `uses`.
+static bool getUsesIfAllTransferOp(Value v,
+                                   SmallVectorImpl<Operation *> &uses) {
+  assert(uses.empty() && "expected uses to be empty");
+  for (Operation *userOp : v.getUsers()) {
+    if (isa<DeallocOp>(userOp)) continue;
+    // Only vectorize memref used by vector transfer ops.
+    if (!isa<vector::TransferReadOp, vector::TransferWriteOp>(userOp)) {
+      uses.clear();
+      return false;
+    }
+    uses.push_back(userOp);
+  }
+  return true;
+}
+
+/// Returns true if the type is a memref that can be vectorized to
+/// vector<4xi32>/vector<4xf32>. If it returns true, also returns the uses of
+/// the memref in `uses`.
+static bool isMemRefAndVectorizable(Value v,
+                                    SmallVectorImpl<Operation *> &uses) {
+  auto memrefType = v.getType().dyn_cast<MemRefType>();
+  // To be able to vectorize the memref it needs to be a scalar memref with a
+  // static innermost dimension aligned on the vectorization size.
+  return memrefType && !memrefType.getElementType().isa<VectorType>() &&
+         (kVectorizationSizeInBits % memrefType.getElementTypeBitWidth() ==
+          0) &&
+         !ShapedType::isDynamic(memrefType.getShape().back()) &&
+         ((memrefType.getElementTypeBitWidth() * memrefType.getShape().back()) %
+              kVectorizationSizeInBits ==
+          0) &&
+         getUsesIfAllTransferOp(v, uses);
+}
+
+/// Returns the bitwidth of a scalar or vector type.
+static Optional<unsigned> getBitWidth(Type type) {
+  if (type.isIntOrFloat()) {
+    return type.getIntOrFloatBitWidth();
+  } else if (type.isa<VectorType>()) {
+    auto vecType = type.cast<VectorType>();
+    auto elementType = vecType.getElementType();
+    return elementType.getIntOrFloatBitWidth() * vecType.getNumElements();
+  }
+  return {};
+}
+
+namespace {
+/// Analyzes memref usage to decide if a memref should be vectorized. Right now
+/// the logic is to vectorize a memref only if it is used by
+/// vector.transfer_read/vector.transfer_write operations.
+class MemRefUsageAnalysis {
+ public:
+  explicit MemRefUsageAnalysis(mlir::Operation *);
+
+  // Returns true if the memref should be converted to a memref of vector.
+  bool vectorizeMemRef(Value v) const { return vectorize.count(v); }
+  // Returns true if the transfer operation needs to be updated during memref
+  // vectorization.
+  bool transferConvert(Operation *op) const { return transferOps.count(op); }
+
+ private:
+  void analyzeFunc(FuncOp funcOp);
+  void analyzeAlloc(AllocOp allocOp);
+  void analyzePlaceholder(IREE::PlaceholderOp placeholderOp);
+  llvm::DenseSet<Value> vectorize;
+  llvm::DenseSet<Operation *> transferOps;
+};
+
+MemRefUsageAnalysis::MemRefUsageAnalysis(mlir::Operation *op) {
+  op->walk([&](Operation *op) {
+    if (auto func = dyn_cast<FuncOp>(op)) analyzeFunc(func);
+    if (auto alloc = dyn_cast<AllocOp>(op)) analyzeAlloc(alloc);
+    if (auto placeholder = dyn_cast<IREE::PlaceholderOp>(op))
+      analyzePlaceholder(placeholder);
+  });
+}
+
+void MemRefUsageAnalysis::analyzeFunc(FuncOp funcOp) {
+  for (Value arg : funcOp.getArguments()) {
+    SmallVector<Operation *, 4> vectorUses;
+    if (isMemRefAndVectorizable(arg, vectorUses)) {
+      vectorize.insert(arg);
+      transferOps.insert(vectorUses.begin(), vectorUses.end());
+    }
+  }
+}
+
+void MemRefUsageAnalysis::analyzePlaceholder(
+    IREE::PlaceholderOp placeholderOp) {
+  SmallVector<Operation *, 4> vectorUses;
+  if (isMemRefAndVectorizable(placeholderOp, vectorUses)) {
+    vectorize.insert(placeholderOp);
+    transferOps.insert(vectorUses.begin(), vectorUses.end());
+  }
+}
+
+void MemRefUsageAnalysis::analyzeAlloc(AllocOp allocOp) {
+  SmallVector<Operation *, 4> vectorUses;
+  if (isMemRefAndVectorizable(allocOp, vectorUses)) {
+    vectorize.insert(allocOp);
+    transferOps.insert(vectorUses.begin(), vectorUses.end());
+  }
+}
+
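+/// Base class for conversion patterns that need access to the memref usage
+/// analysis.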
+template <typename OpTy>
+class MemRefConversionPattern : public OpConversionPattern<OpTy> {
+ public:
+  MemRefConversionPattern<OpTy>(MLIRContext *context,
+                                const MemRefUsageAnalysis &memrefUsageAnalysis)
+      : OpConversionPattern<OpTy>::OpConversionPattern(context),
+        memrefUsageAnalysis(memrefUsageAnalysis) {}
+
+ protected:
+  const MemRefUsageAnalysis &memrefUsageAnalysis;
+};
+
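+/// Rewrites a function signature so that vectorizable memref arguments use the
+/// vectorized type.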
+class ProcessFuncArg final : public MemRefConversionPattern<FuncOp> {
+ public:
+  using MemRefConversionPattern<FuncOp>::MemRefConversionPattern;
+  LogicalResult matchAndRewrite(
+      FuncOp funcOp, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override;
+};
+
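+/// Rewrites a vector.transfer_read on a vectorized memref: the innermost index
+/// is rescaled, a 128-bit read becomes a plain load followed by casts, and any
+/// other read is re-created as a transfer on the new memref.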
+class ProcessTransferRead final
+    : public MemRefConversionPattern<vector::TransferReadOp> {
+ public:
+  using MemRefConversionPattern<
+      vector::TransferReadOp>::MemRefConversionPattern;
+  LogicalResult matchAndRewrite(
+      vector::TransferReadOp read, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!memrefUsageAnalysis.transferConvert(read)) return failure();
+    vector::TransferReadOp::Adaptor adaptor(operands);
+    Value memref = adaptor.memref();
+    Location loc = read.getLoc();
+    Optional<unsigned> vecMemrefElemSize =
+        getBitWidth(memref.getType().cast<MemRefType>().getElementType());
+    Optional<unsigned> readElemSize =
+        getBitWidth(read.getMemRefType().getElementType());
+    Optional<unsigned> readVecSize = getBitWidth(read.getVectorType());
+    if (!vecMemrefElemSize || !readElemSize || !readVecSize) return failure();
+    unsigned ratio = *vecMemrefElemSize / *readElemSize;
+    SmallVector<Value, 4> indices(adaptor.indices().begin(),
+                                  adaptor.indices().end());
+    indices.back() = rewriter.create<SignedDivIOp>(
+        loc, indices.back(), rewriter.create<ConstantIndexOp>(loc, ratio));
+    // If the transfer_read can be replaced by a load after vectorization, use
+    // LoadOp and cast back to the original vector type.
+    if (*vecMemrefElemSize == *readVecSize) {
+      Type elemType = memref.getType().cast<MemRefType>().getElementType();
+      Value newLoad = rewriter.create<LoadOp>(loc, elemType, memref, indices);
+      Type serializedVecType =
+          VectorType::get(read.getVectorType().getNumElements(),
+                          read.getVectorType().getElementType());
+      newLoad =
+          rewriter.create<vector::BitCastOp>(loc, serializedVecType, newLoad);
+      newLoad = rewriter.create<vector::ShapeCastOp>(loc, read.getVectorType(),
+                                                     newLoad);
+      rewriter.replaceOp(read, newLoad);
+    } else {
+      Value newRead = rewriter.create<vector::TransferReadOp>(
+          loc, read.getVectorType(), memref, indices);
+      rewriter.replaceOp(read, newRead);
+    }
+    return success();
+  }
+};
+
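+/// Rewrites a vector.transfer_write on a vectorized memref: the innermost
+/// index is rescaled, a 128-bit write becomes casts followed by a plain store,
+/// and any other write is re-created as a transfer on the new memref.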
+class ProcessTransferWrite final
+    : public MemRefConversionPattern<vector::TransferWriteOp> {
+ public:
+  using MemRefConversionPattern<
+      vector::TransferWriteOp>::MemRefConversionPattern;
+  LogicalResult matchAndRewrite(
+      vector::TransferWriteOp write, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!memrefUsageAnalysis.transferConvert(write)) return failure();
+    vector::TransferWriteOp::Adaptor adaptor(operands);
+    Value memref = adaptor.memref();
+    Location loc = write.getLoc();
+    Optional<unsigned> vecMemrefElemSize =
+        getBitWidth(memref.getType().cast<MemRefType>().getElementType());
+    Optional<unsigned> writeElemSize =
+        getBitWidth(write.getMemRefType().getElementType());
+    Optional<unsigned> writeVecSize = getBitWidth(write.getVectorType());
+    if (!vecMemrefElemSize || !writeElemSize || !writeVecSize) return failure();
+    unsigned ratio = *vecMemrefElemSize / *writeElemSize;
+    SmallVector<Value, 4> indices(adaptor.indices());
+    indices.back() = rewriter.create<SignedDivIOp>(
+        loc, indices.back(), rewriter.create<ConstantIndexOp>(loc, ratio));
+    // If the transfer_write can be replaced by a store after vectorization,
+    // cast the original value and use StoreOp.
+    if (*vecMemrefElemSize == *writeVecSize) {
+      Type serializedVecType =
+          VectorType::get(write.getVectorType().getNumElements(),
+                          write.getVectorType().getElementType());
+      Value data = rewriter.create<vector::ShapeCastOp>(loc, serializedVecType,
+                                                        adaptor.vector());
+      data = rewriter.create<vector::BitCastOp>(
+          loc, memref.getType().cast<MemRefType>().getElementType(), data);
+      rewriter.create<StoreOp>(loc, data, memref, indices);
+    } else {
+      rewriter.create<vector::TransferWriteOp>(loc, adaptor.vector(), memref,
+                                               indices);
+    }
+    rewriter.eraseOp(write);
+    return success();
+  }
+};
+
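+/// Returns `type` rewritten as a memref of vector<4xi32> or vector<4xf32>
+/// covering the same number of bits, or None if the innermost dimension is not
+/// a multiple of the vector size.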
+static Optional<MemRefType> getVectorizedMemRefType(
+    ConversionPatternRewriter &rewriter, MemRefType type) {
+  unsigned elemSize = type.getElementTypeBitWidth();
+  unsigned vecSize = kVectorizationSizeInBits / elemSize;
+  // Pick a new 32-bit element type of the same kind (integer or float).
+  Type newElemType = type.getElementType().isa<IntegerType>()
+                         ? rewriter.getI32Type().cast<Type>()
+                         : rewriter.getF32Type().cast<Type>();
+  Type vecType = VectorType::get(kVecSize, newElemType);
+  SmallVector<int64_t, 2> newShape(type.getShape().begin(),
+                                   type.getShape().end());
+  if (newShape.back() % vecSize != 0) return {};
+  newShape.back() = newShape.back() / vecSize;
+  return MemRefType::get(newShape, vecType, {}, type.getMemorySpace());
+}
+
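+/// Rewrites an alloc of a vectorizable memref to allocate the equivalent
+/// memref of vector type.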
+class ProcessAlloc final : public MemRefConversionPattern<AllocOp> {
+ public:
+  using MemRefConversionPattern<AllocOp>::MemRefConversionPattern;
+  LogicalResult matchAndRewrite(
+      AllocOp alloc, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = getVectorizedMemRefType(rewriter, alloc.getType());
+    if (!memrefType) return failure();
+    Value newAlloc =
+        rewriter.create<AllocOp>(alloc.getLoc(), *memrefType, alloc.value());
+    rewriter.replaceOp(alloc, newAlloc);
+    return success();
+  }
+};
+
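+/// Rewrites an iree.placeholder (interface binding) to produce the equivalent
+/// memref of vector type.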
+class ProcessIreeBinding final
+    : public MemRefConversionPattern<IREE::PlaceholderOp> {
+ public:
+  using MemRefConversionPattern<IREE::PlaceholderOp>::MemRefConversionPattern;
+  LogicalResult matchAndRewrite(
+      IREE::PlaceholderOp placeholder, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = placeholder.getType().dyn_cast<MemRefType>();
+    if (!memrefType) return failure();
+    auto vecMemRef = getVectorizedMemRefType(rewriter, memrefType);
+    if (!vecMemRef) return failure();
+    ValueRange dummyOperands;
+    Value newPlaceholder = rewriter.create<IREE::PlaceholderOp>(
+        placeholder.getLoc(), *vecMemRef, dummyOperands,
+        placeholder.getAttrs());
+    rewriter.replaceOp(placeholder, newPlaceholder);
+    return success();
+  }
+};
+
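+/// Pass converting memrefs of scalars to memrefs of vectors across a module,
+/// driven by the usage analysis above.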
+class VectorizeMemRefPass final
+    : public PassWrapper<VectorizeMemRefPass, OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+
+ private:
+  MemRefUsageAnalysis *memrefUsageAnalysis = nullptr;
+};
+}  // namespace
+
+LogicalResult ProcessFuncArg::matchAndRewrite(
+    FuncOp funcOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  TypeConverter::SignatureConversion signatureConverter(
+      funcOp.getType().getNumInputs());
+  TypeConverter typeConverter;
+  for (const auto &arg : llvm::enumerate(funcOp.getArguments())) {
+    if (memrefUsageAnalysis.vectorizeMemRef(arg.value())) {
+      if (auto memrefType = getVectorizedMemRefType(
+              rewriter, arg.value().getType().cast<MemRefType>())) {
+        signatureConverter.addInputs(arg.index(), *memrefType);
+        continue;
+      }
+    }
+    signatureConverter.addInputs(arg.index(), arg.value().getType());
+  }
+  // Converts the region's block argument types to match the new signature.
+  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter,
+                                         &signatureConverter)))
+    return failure();
+
+  // Updates the function type to the new signature.
+  rewriter.updateRootInPlace(funcOp, [&] {
+    funcOp.setType(rewriter.getFunctionType(
+        signatureConverter.getConvertedTypes(), llvm::None));
+  });
+  return success();
+}
+
+void VectorizeMemRefPass::runOnOperation() {
+  // Uses the signature conversion methodology of the dialect conversion
+  // framework to implement the conversion.
+  ModuleOp module = getOperation();
+  MLIRContext *context = &getContext();
+  memrefUsageAnalysis = &getAnalysis<MemRefUsageAnalysis>();
+
+  OwningRewritePatternList patterns;
+  patterns.insert<ProcessFuncArg, ProcessTransferRead, ProcessTransferWrite,
+                  ProcessAlloc, ProcessIreeBinding>(context,
+                                                    *memrefUsageAnalysis);
+
+  ConversionTarget target(*context);
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    return llvm::all_of(op.getArguments(), [&](Value arg) {
+      return !memrefUsageAnalysis->vectorizeMemRef(arg);
+    });
+  });
+  target.addDynamicallyLegalOp<AllocOp>([&](AllocOp alloc) {
+    return !memrefUsageAnalysis->vectorizeMemRef(alloc);
+  });
+  target.addDynamicallyLegalOp<IREE::PlaceholderOp>(
+      [&](IREE::PlaceholderOp placeholder) {
+        return !memrefUsageAnalysis->vectorizeMemRef(placeholder);
+      });
+  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+    if (isa<vector::TransferWriteOp, vector::TransferReadOp>(op))
+      return !memrefUsageAnalysis->transferConvert(op);
+    return true;
+  });
+  if (failed(applyPartialConversion(module, target, patterns)))
+    return signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createVectorizeMemref() {
+  return std::make_unique<VectorizeMemRefPass>();
+}
+
+static PassRegistration<VectorizeMemRefPass> pass(
+    "iree-spirv-vectorize-memref",
+    "Vectorize memref arguments and allocations");
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vectorization.mlir
new file mode 100644
index 0000000..9a87e8c
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vectorization.mlir
@@ -0,0 +1,61 @@
+// RUN: iree-opt -split-input-file -iree-spirv-vectorize-memref -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @copy
+//  CHECK-SAME: (%[[ARG0:.+]]: memref<4096x1024xvector<4xf32>>
+//       CHECK: %[[ALLOC:.+]] = alloc() : memref<128x8xvector<4xf32>, 3>
+//       CHECK: %[[V:.+]] = load %[[ARG0]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf32>>
+//       CHECK: store %[[V]], %[[ALLOC]][%{{.*}}, %{{.*}}] : memref<128x8xvector<4xf32>, 3>
+//       CHECK: %[[MAT:.+]] = vector.transfer_read %[[ARG0]][%{{.*}}, %{{.*}}], %{{.*}} : memref<4096x1024xvector<4xf32>>, vector<32x8xf32>
+//       CHECK: vector.transfer_write %[[MAT]], %[[ALLOC]][%{{.*}}, %{{.*}}] : vector<32x8xf32>, memref<128x8xvector<4xf32>, 3>
+//       CHECK: dealloc %[[ALLOC]] : memref<128x8xvector<4xf32>, 3>
+func @copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
+  %cst = constant 0.000000e+00 : f32
+  %0 = alloc() : memref<128x32xf32, 3>
+  %v = vector.transfer_read %arg0[%x, %y], %cst : memref<4096x4096xf32>, vector<1x4xf32>
+  vector.transfer_write %v, %0[%x, %y] : vector<1x4xf32>, memref<128x32xf32, 3>
+  %mat = vector.transfer_read %arg0[%x, %y], %cst : memref<4096x4096xf32>, vector<32x8xf32>
+  vector.transfer_write %mat, %0[%x, %y] : vector<32x8xf32>, memref<128x32xf32, 3>
+  dealloc %0 : memref<128x32xf32, 3>
+  return
+}
+
+// -----
+
+// Test that the memref is not vectorized if used by scalar load or store.
+// CHECK-LABEL: func @copy
+//  CHECK-SAME: %[[ARG0:.+]]: memref<4096x4096xf32>
+func @copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
+  %cst = constant 0.000000e+00 : f32
+  %0 = alloc() : memref<128x32xf32, 3>
+  %s = load %arg0[%x, %y] : memref<4096x4096xf32>
+  store %s, %0[%x, %y] : memref<128x32xf32, 3>
+  dealloc %0 : memref<128x32xf32, 3>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @resource_copy
+//     CHECK: %[[A:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x1024xvector<4xf32>>
+//     CHECK: %[[B:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x1024xvector<4xf32>>
+//     CHECK: %[[V:.+]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf32>>
+//     CHECK: store %[[V]], %[[B]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf32>>
+//     CHECK: %[[MAT:.+]] = vector.transfer_read %[[A]][%{{.*}}, %{{.*}}], %{{.*}} : memref<4096x1024xvector<4xf32>>, vector<32x8xf32>
+//     CHECK: vector.transfer_write %[[MAT]], %[[B]][%{{.*}}, %{{.*}}] {{.*}} : vector<32x8xf32>, memref<4096x1024xvector<4xf32>>
+func @resource_copy() {
+  %cst = constant 0.000000e+00 : f32
+  %c0 = constant 0 : index
+  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x4096xf32>
+  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x4096xf32>
+  %v = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<1x4xf32>
+  vector.transfer_write %v, %1[%c0, %c0] : vector<1x4xf32>, memref<4096x4096xf32>
+  %mat = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<32x8xf32>
+  vector.transfer_write %mat, %1[%c0, %c0] : vector<32x8xf32>, memref<4096x4096xf32>
+  return
+}
+
+hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} {
+  hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
+  hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
+}
+
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 74eaf71..2978c69 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -41,6 +41,7 @@
     createSplitDispatchFunctionPass();
     createVectorToGPUPass();
     createMatMulTileAndVectorizeGPUPass();
+    createVectorizeMemref();
     return true;
   }();
   (void)init_once;