Add pass to convert memref to memref of vector. (#3028)
Add a new pass that converts memrefs of scalars into memrefs of vectors so that memory accesses can be lowered to vector load/stores.
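For example (illustrative, following the resource_copy test added in this change, and omitting the bitcast/shape_cast pairs that canonicalization folds away), a 128-bit transfer from a scalar memref such as

  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x4096xf32>
  %v = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<1x4xf32>

becomes a plain load from a memref of vectors:

  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x1024xvector<4xf32>>
  %v = load %0[%c0, %c0] : memref<4096x1024xvector<4xf32>>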
diff --git a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
index 9c8a1f6..5b6bbf2 100644
--- a/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
+++ b/experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
@@ -82,6 +82,9 @@
pm.addPass(mlir::createLegalizeStdOpsForSPIRVLoweringPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
+ pm.addPass(mlir::iree_compiler::createVectorizeMemref());
+ pm.addPass(mlir::createCanonicalizerPass());
+ pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::iree_compiler::createConvertToSPIRVPass());
auto &spirvModulePM = pm.nest<mlir::spirv::ModuleOp>();
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index ebb0e6a..6a30713 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -31,6 +31,7 @@
"SplitDispatchFunctionPass.cpp",
"Utils.cpp",
"VectorToGPUPass.cpp",
+ "VectorizeMemref.cpp",
],
hdrs = [
"Attributes.h",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 77323e0..8e34e87 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -35,6 +35,7 @@
"SplitDispatchFunctionPass.cpp"
"Utils.cpp"
"VectorToGPUPass.cpp"
+ "VectorizeMemref.cpp"
DEPS
LLVMSupport
MLIRAffineOps
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index b08890e..72872f5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -65,6 +65,11 @@
/// Pass to apply tiling and vectorization transformations on linalg::MatMulOp.
std::unique_ptr<FunctionPass> createMatMulTileAndVectorizeGPUPass();
+/// Convert memrefs of scalars into memrefs of vectors of an efficient size.
+/// This allows memory accesses to be converted to vector load/stores in
+/// SPIR-V without requiring pointer bitcasts.
+std::unique_ptr<OperationPass<ModuleOp>> createVectorizeMemref();
+
/// Populates passes needed to lower a XLA HLO op to SPIR-V dialect via the
/// structured ops path. The pass manager `pm` here operates on the module
/// within the IREE::HAL::ExecutableOp. The `workGroupSize` can be used to
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
new file mode 100644
index 0000000..7ff1530
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
@@ -0,0 +1,375 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===----------------------------------------------------------------------===//
+//
+// Pass to convert memref into memref of vector.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/StandardTypes.h"
+
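+// Width in bits of the vector accesses we want to generate, and the
+// corresponding number of 32-bit elements per vector (4 for 128 bits).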
+constexpr int kVectorizationSizeInBits = 128;
+constexpr int kVecSize = kVectorizationSizeInBits / (sizeof(float) * 8);
+
+namespace mlir {
+namespace iree_compiler {
+
+/// Returns true if all uses are vector transfer read/write operations. If so,
+/// also populates `uses` with the users of the memref.
+static bool getUsesIfAllTransferOp(Value v,
+ SmallVectorImpl<Operation *> &uses) {
+ assert(uses.empty() && "expected uses to be empty");
+ for (Operation *userOp : v.getUsers()) {
+ if (isa<DeallocOp>(userOp)) continue;
+ // Only vectorize memref used by vector transfer ops.
+ if (!isa<vector::TransferReadOp, vector::TransferWriteOp>(userOp)) {
+ uses.clear();
+ return false;
+ }
+ uses.push_back(userOp);
+ }
+ return true;
+}
+
+/// Returns true if the type is a memref that can be vectorized into a memref
+/// of 128-bit vectors (e.g. vector<4xf32>). If so, also populates `uses` with
+/// the users of the memref.
+static bool isMemRefAndVectorizable(Value v,
+ SmallVectorImpl<Operation *> &uses) {
+ auto memrefType = v.getType().dyn_cast<MemRefType>();
+  // To be vectorizable, the memref must have a scalar element type and a
+  // static innermost dimension that is a multiple of the vectorization size.
+ return memrefType && !memrefType.getElementType().isa<VectorType>() &&
+ (kVectorizationSizeInBits % memrefType.getElementTypeBitWidth() ==
+ 0) &&
+ !ShapedType::isDynamic(memrefType.getShape().back()) &&
+ ((memrefType.getElementTypeBitWidth() * memrefType.getShape().back()) %
+ kVectorizationSizeInBits ==
+ 0) &&
+ getUsesIfAllTransferOp(v, uses);
+}
+
+/// Returns the bitwidth of a scalar or vector type.
+static Optional<unsigned> getBitWidth(Type type) {
+ if (type.isIntOrFloat()) {
+ return type.getIntOrFloatBitWidth();
+ } else if (type.isa<VectorType>()) {
+ auto vecType = type.cast<VectorType>();
+ auto elementType = vecType.getElementType();
+ return elementType.getIntOrFloatBitWidth() * vecType.getNumElements();
+ }
+ return {};
+}
+
+namespace {
+/// Analyzes the uses of a memref to decide whether it should be vectorized.
+/// Currently a memref is vectorized only if all of its uses are
+/// vector.transfer_read/vector.transfer_write operations.
+class MemRefUsageAnalysis {
+ public:
+ explicit MemRefUsageAnalysis(mlir::Operation *);
+
+  // Returns true if the memref should be converted to a memref of vectors.
+ bool vectorizeMemRef(Value v) const { return vectorize.count(v); }
+ // Returns true if the transfer operation needs to be updated during memref
+ // vectorization.
+ bool transferConvert(Operation *op) const { return transferOps.count(op); }
+
+ private:
+ void analyzeFunc(FuncOp funcOp);
+ void analyzeAlloc(AllocOp allocOp);
+ void analyzePlaceholder(IREE::PlaceholderOp placeholderOp);
+ llvm::DenseSet<Value> vectorize;
+ llvm::DenseSet<Operation *> transferOps;
+};
+
+MemRefUsageAnalysis::MemRefUsageAnalysis(mlir::Operation *op) {
+ op->walk([&](Operation *op) {
+ if (auto func = dyn_cast<FuncOp>(op)) analyzeFunc(func);
+ if (auto alloc = dyn_cast<AllocOp>(op)) analyzeAlloc(alloc);
+ if (auto placeholder = dyn_cast<IREE::PlaceholderOp>(op))
+ analyzePlaceholder(placeholder);
+ });
+}
+
+void MemRefUsageAnalysis::analyzeFunc(FuncOp funcOp) {
+ for (Value arg : funcOp.getArguments()) {
+ SmallVector<Operation *, 4> vectorUses;
+ if (isMemRefAndVectorizable(arg, vectorUses)) {
+ vectorize.insert(arg);
+ transferOps.insert(vectorUses.begin(), vectorUses.end());
+ }
+ }
+}
+
+void MemRefUsageAnalysis::analyzePlaceholder(
+ IREE::PlaceholderOp placeholderOp) {
+ SmallVector<Operation *, 4> vectorUses;
+ if (isMemRefAndVectorizable(placeholderOp, vectorUses)) {
+ vectorize.insert(placeholderOp);
+ transferOps.insert(vectorUses.begin(), vectorUses.end());
+ }
+}
+
+void MemRefUsageAnalysis::analyzeAlloc(AllocOp allocOp) {
+ SmallVector<Operation *, 4> vectorUses;
+ if (isMemRefAndVectorizable(allocOp, vectorUses)) {
+ vectorize.insert(allocOp);
+ transferOps.insert(vectorUses.begin(), vectorUses.end());
+ }
+}
+
+template <typename OpTy>
+class MemRefConversionPattern : public OpConversionPattern<OpTy> {
+ public:
+ MemRefConversionPattern<OpTy>(MLIRContext *context,
+ const MemRefUsageAnalysis &memrefUsageAnalysis)
+ : OpConversionPattern<OpTy>::OpConversionPattern(context),
+ memrefUsageAnalysis(memrefUsageAnalysis) {}
+
+ protected:
+ const MemRefUsageAnalysis &memrefUsageAnalysis;
+};
+
+class ProcessFuncArg final : public MemRefConversionPattern<FuncOp> {
+ public:
+ using MemRefConversionPattern<FuncOp>::MemRefConversionPattern;
+ LogicalResult matchAndRewrite(
+ FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+class ProcessTransferRead final
+ : public MemRefConversionPattern<vector::TransferReadOp> {
+ public:
+ using MemRefConversionPattern<
+ vector::TransferReadOp>::MemRefConversionPattern;
+ LogicalResult matchAndRewrite(
+ vector::TransferReadOp read, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!memrefUsageAnalysis.transferConvert(read)) return failure();
+ vector::TransferReadOp::Adaptor adaptor(operands);
+ Value memref = adaptor.memref();
+ Location loc = read.getLoc();
+ Optional<unsigned> vecMemrefElemSize =
+ getBitWidth(memref.getType().cast<MemRefType>().getElementType());
+ Optional<unsigned> readElemSize =
+ getBitWidth(read.getMemRefType().getElementType());
+ Optional<unsigned> readVecSize = getBitWidth(read.getVectorType());
+ if (!vecMemrefElemSize || !readElemSize || !readVecSize) return failure();
+ unsigned ratio = *vecMemrefElemSize / *readElemSize;
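+    // Each vector element of the converted memref packs `ratio` original
+    // scalar elements, so the innermost index is scaled down accordingly.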
+ SmallVector<Value, 4> indices(adaptor.indices().begin(),
+ adaptor.indices().end());
+ indices.back() = rewriter.create<SignedDivIOp>(
+ loc, indices.back(), rewriter.create<ConstantIndexOp>(loc, ratio));
+    // If the transfer_read can be replaced by a load after vectorization, use
+    // a LoadOp and cast back to the original vector type.
+ if (*vecMemrefElemSize == *readVecSize) {
+ Type elemType = memref.getType().cast<MemRefType>().getElementType();
+ Value newLoad = rewriter.create<LoadOp>(loc, elemType, memref, indices);
+ Type serializedVecType =
+ VectorType::get(read.getVectorType().getNumElements(),
+ read.getVectorType().getElementType());
+ newLoad =
+ rewriter.create<vector::BitCastOp>(loc, serializedVecType, newLoad);
+ newLoad = rewriter.create<vector::ShapeCastOp>(loc, read.getVectorType(),
+ newLoad);
+ rewriter.replaceOp(read, newLoad);
+ } else {
+ Value newRead = rewriter.create<vector::TransferReadOp>(
+ loc, read.getVectorType(), memref, indices);
+ rewriter.replaceOp(read, newRead);
+ }
+ return success();
+ }
+};
+
+class ProcessTransferWrite final
+ : public MemRefConversionPattern<vector::TransferWriteOp> {
+ public:
+ using MemRefConversionPattern<
+ vector::TransferWriteOp>::MemRefConversionPattern;
+ LogicalResult matchAndRewrite(
+ vector::TransferWriteOp write, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!memrefUsageAnalysis.transferConvert(write)) return failure();
+ vector::TransferWriteOp::Adaptor adaptor(operands);
+ Value memref = adaptor.memref();
+ Location loc = write.getLoc();
+ Optional<unsigned> vecMemrefElemSize =
+ getBitWidth(memref.getType().cast<MemRefType>().getElementType());
+ Optional<unsigned> writeElemSize =
+ getBitWidth(write.getMemRefType().getElementType());
+ Optional<unsigned> writeVecSize = getBitWidth(write.getVectorType());
+ if (!vecMemrefElemSize || !writeElemSize || !writeVecSize) return failure();
+ unsigned ratio = *vecMemrefElemSize / *writeElemSize;
+ SmallVector<Value, 4> indices(adaptor.indices());
+ indices.back() = rewriter.create<SignedDivIOp>(
+ loc, indices.back(), rewriter.create<ConstantIndexOp>(loc, ratio));
+    // If the transfer_write can be replaced by a store after vectorization,
+    // cast the original value and use a StoreOp.
+ if (*vecMemrefElemSize == *writeVecSize) {
+ Type serializedVecType =
+ VectorType::get(write.getVectorType().getNumElements(),
+ write.getVectorType().getElementType());
+ Value data = rewriter.create<vector::ShapeCastOp>(loc, serializedVecType,
+ adaptor.vector());
+ data = rewriter.create<vector::BitCastOp>(
+ loc, memref.getType().cast<MemRefType>().getElementType(), data);
+ rewriter.create<StoreOp>(loc, data, memref, indices);
+ } else {
+ rewriter.create<vector::TransferWriteOp>(loc, adaptor.vector(), memref,
+ indices);
+ }
+ rewriter.eraseOp(write);
+ return success();
+ }
+};
+
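+/// Computes the vectorized memref type; e.g. memref<128x32xf32, 3> becomes
+/// memref<128x8xvector<4xf32>, 3>. Returns None if the innermost dimension is
+/// not divisible by the vector size.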
+static Optional<MemRefType> getVectorizedMemRefType(
+ ConversionPatternRewriter &rewriter, MemRefType type) {
+ unsigned elemSize = type.getElementTypeBitWidth();
+ unsigned vecSize = kVectorizationSizeInBits / elemSize;
+  // Normalize the element type to a 32-bit type so the result is always a
+  // kVecSize x 32-bit (128-bit) vector.
+ Type newElemType = type.getElementType().isa<IntegerType>()
+ ? rewriter.getI32Type().cast<Type>()
+ : rewriter.getF32Type().cast<Type>();
+ Type vecType = VectorType::get(kVecSize, newElemType);
+ SmallVector<int64_t, 2> newShape(type.getShape().begin(),
+ type.getShape().end());
+ if (newShape.back() % vecSize != 0) return {};
+ newShape.back() = newShape.back() / vecSize;
+ return MemRefType::get(newShape, vecType, {}, type.getMemorySpace());
+}
+
+class ProcessAlloc final : public MemRefConversionPattern<AllocOp> {
+ public:
+ using MemRefConversionPattern<AllocOp>::MemRefConversionPattern;
+ LogicalResult matchAndRewrite(
+ AllocOp alloc, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = getVectorizedMemRefType(rewriter, alloc.getType());
+ if (!memrefType) return failure();
+ Value newAlloc =
+ rewriter.create<AllocOp>(alloc.getLoc(), *memrefType, alloc.value());
+ rewriter.replaceOp(alloc, newAlloc);
+ return success();
+ }
+};
+
+class ProcessIreeBinding final
+ : public MemRefConversionPattern<IREE::PlaceholderOp> {
+ public:
+ using MemRefConversionPattern<IREE::PlaceholderOp>::MemRefConversionPattern;
+ LogicalResult matchAndRewrite(
+ IREE::PlaceholderOp placeholder, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefType = placeholder.getType().dyn_cast<MemRefType>();
+ if (!memrefType) return failure();
+ auto vecMemRef = getVectorizedMemRefType(rewriter, memrefType);
+ if (!vecMemRef) return failure();
+ ValueRange dummyOperands;
+ Value newPlaceholder = rewriter.create<IREE::PlaceholderOp>(
+ placeholder.getLoc(), *vecMemRef, dummyOperands,
+ placeholder.getAttrs());
+ rewriter.replaceOp(placeholder, newPlaceholder);
+ return success();
+ }
+};
+
+class VectorizeMemRefPass final
+ : public PassWrapper<VectorizeMemRefPass, OperationPass<ModuleOp>> {
+ void runOnOperation() override;
+
+ private:
+ MemRefUsageAnalysis *memrefUsageAnalysis = nullptr;
+};
+} // namespace
+
+LogicalResult ProcessFuncArg::matchAndRewrite(
+ FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ TypeConverter::SignatureConversion signatureConverter(
+ funcOp.getType().getNumInputs());
+ TypeConverter typeConverter;
+ for (const auto &arg : llvm::enumerate(funcOp.getArguments())) {
+ if (memrefUsageAnalysis.vectorizeMemRef(arg.value())) {
+ if (auto memrefType = getVectorizedMemRefType(
+ rewriter, arg.value().getType().cast<MemRefType>())) {
+ signatureConverter.addInputs(arg.index(), *memrefType);
+ continue;
+ }
+ }
+ signatureConverter.addInputs(arg.index(), arg.value().getType());
+ }
+  // Rewrite the function body to use the converted argument types.
+ if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter,
+ &signatureConverter)))
+ return failure();
+
+  // Update the function signature in place.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ });
+ return success();
+}
+
+void VectorizeMemRefPass::runOnOperation() {
+ // Uses the signature conversion methodology of the dialect conversion
+ // framework to implement the conversion.
+ ModuleOp module = getOperation();
+ MLIRContext *context = &getContext();
+ memrefUsageAnalysis = &getAnalysis<MemRefUsageAnalysis>();
+
+ OwningRewritePatternList patterns;
+ patterns.insert<ProcessFuncArg, ProcessTransferRead, ProcessTransferWrite,
+ ProcessAlloc, ProcessIreeBinding>(context,
+ *memrefUsageAnalysis);
+
+ ConversionTarget target(*context);
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ return llvm::all_of(op.getArguments(), [&](Value arg) {
+ return !memrefUsageAnalysis->vectorizeMemRef(arg);
+ });
+ });
+ target.addDynamicallyLegalOp<AllocOp>([&](AllocOp alloc) {
+ return !memrefUsageAnalysis->vectorizeMemRef(alloc);
+ });
+ target.addDynamicallyLegalOp<IREE::PlaceholderOp>(
+ [&](IREE::PlaceholderOp placeholder) {
+ return !memrefUsageAnalysis->vectorizeMemRef(placeholder);
+ });
+ target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+ if (isa<vector::TransferWriteOp, vector::TransferReadOp>(op))
+ return !memrefUsageAnalysis->transferConvert(op);
+ return true;
+ });
+ if (failed(applyPartialConversion(module, target, patterns)))
+ return signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createVectorizeMemref() {
+ return std::make_unique<VectorizeMemRefPass>();
+}
+
+static PassRegistration<VectorizeMemRefPass> pass(
+ "iree-spirv-vectorize-memref",
+ "Vectorize memref arguments and allocations");
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vectorization.mlir
new file mode 100644
index 0000000..9a87e8c
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vectorization.mlir
@@ -0,0 +1,61 @@
+// RUN: iree-opt -split-input-file -iree-spirv-vectorize-memref -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @copy
+// CHECK-SAME: (%[[ARG0:.+]]: memref<4096x1024xvector<4xf32>>
+// CHECK: %[[ALLOC:.+]] = alloc() : memref<128x8xvector<4xf32>, 3>
+// CHECK: %[[V:.+]] = load %[[ARG0]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf32>>
+// CHECK: store %[[V]], %[[ALLOC]][%{{.*}}, %{{.*}}] : memref<128x8xvector<4xf32>, 3>
+// CHECK: %[[MAT:.+]] = vector.transfer_read %[[ARG0]][%{{.*}}, %{{.*}}], %{{.*}} : memref<4096x1024xvector<4xf32>>, vector<32x8xf32>
+// CHECK: vector.transfer_write %[[MAT]], %[[ALLOC]][%{{.*}}, %{{.*}}] : vector<32x8xf32>, memref<128x8xvector<4xf32>, 3>
+// CHECK: dealloc %[[ALLOC]] : memref<128x8xvector<4xf32>, 3>
+func @copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
+ %cst = constant 0.000000e+00 : f32
+ %0 = alloc() : memref<128x32xf32, 3>
+ %v = vector.transfer_read %arg0[%x, %y], %cst : memref<4096x4096xf32>, vector<1x4xf32>
+ vector.transfer_write %v, %0[%x, %y] : vector<1x4xf32>, memref<128x32xf32, 3>
+ %mat = vector.transfer_read %arg0[%x, %y], %cst : memref<4096x4096xf32>, vector<32x8xf32>
+ vector.transfer_write %mat, %0[%x, %y] : vector<32x8xf32>, memref<128x32xf32, 3>
+ dealloc %0 : memref<128x32xf32, 3>
+ return
+}
+
+// -----
+
+// Test that the memref is not vectorized if used by scalar load or store.
+// CHECK-LABEL: func @copy
+// CHECK-SAME: %[[ARG0:.+]]: memref<4096x4096xf32>
+func @copy(%arg0: memref<4096x4096xf32>, %x: index, %y: index) {
+ %cst = constant 0.000000e+00 : f32
+ %0 = alloc() : memref<128x32xf32, 3>
+ %s = load %arg0[%x, %y] : memref<4096x4096xf32>
+ store %s, %0[%x, %y] : memref<128x32xf32, 3>
+ dealloc %0 : memref<128x32xf32, 3>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @resource_copy
+// CHECK: %[[A:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x1024xvector<4xf32>>
+// CHECK: %[[B:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x1024xvector<4xf32>>
+// CHECK: %[[V:.+]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf32>>
+// CHECK: store %[[V]], %[[B]][%{{.*}}, %{{.*}}] : memref<4096x1024xvector<4xf32>>
+// CHECK: %[[MAT:.+]] = vector.transfer_read %[[A]][%{{.*}}, %{{.*}}], %{{.*}} : memref<4096x1024xvector<4xf32>>, vector<32x8xf32>
+// CHECK: vector.transfer_write %[[MAT]], %[[B]][%{{.*}}, %{{.*}}] {{.*}} : vector<32x8xf32>, memref<4096x1024xvector<4xf32>>
+func @resource_copy() {
+ %cst = constant 0.000000e+00 : f32
+ %c0 = constant 0 : index
+ %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<4096x4096xf32>
+ %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<4096x4096xf32>
+ %v = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<1x4xf32>
+ vector.transfer_write %v, %1[%c0, %c0] : vector<1x4xf32>, memref<4096x4096xf32>
+ %mat = vector.transfer_read %0[%c0, %c0], %cst : memref<4096x4096xf32>, vector<32x8xf32>
+ vector.transfer_write %mat, %1[%c0, %c0] : vector<32x8xf32>, memref<4096x4096xf32>
+ return
+}
+
+hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write"
+}
+
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index 74eaf71..2978c69 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -41,6 +41,7 @@
createSplitDispatchFunctionPass();
createVectorToGPUPass();
createMatMulTileAndVectorizeGPUPass();
+ createVectorizeMemref();
return true;
}();
(void)init_once;