Delete LoadStoreVectorizationPass which is unused in IREE (#5514)
We have a proper flow to vectorize element wise ops after tiling, and it
is not used in old pipeline. We can delete the pass and the option.
The registration of `createVectorizeLinalgConvPass` was missing. The
PR also adds it to init_conversions.h.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.cpp
index 21d2618..1e2be3b 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.cpp
@@ -43,11 +43,6 @@
llvm::cl::desc("Use workgroup memory in SPIR-V code generation"),
llvm::cl::init(false));
- static llvm::cl::opt<bool> clVectorizeMemref(
- "iree-spirv-enable-memref-vectorization",
- llvm::cl::desc("Vectorize memref if possible in SPIR-V code generation"),
- llvm::cl::init(false));
-
static llvm::cl::list<unsigned> clWorkgroupSizes(
"iree-spirv-workgroup-size",
llvm::cl::desc("Set workgroup size to use for SPIR-V code generation"),
@@ -68,7 +63,6 @@
options.enableVectorization =
clEnableLinalgOnTensorsSPIRV || clEnableVectorization;
options.useWorkgroupMemory = clUseWorkgroupMemory;
- options.vectorizeMemref = clVectorizeMemref;
options.usingLinalgOnTensors = clEnableLinalgOnTensorsSPIRV;
return options;
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index 5469c74..d9bc606 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -94,10 +94,6 @@
}
pm.addPass(createTileAndVectorizeInOneWorkgroupPass(options));
- if (options.vectorizeMemref) {
- pm.nest<ModuleOp>().addNestedPass<FuncOp>(
- createLoadStoreVectorizationPass());
- }
pm.nest<ModuleOp>().addPass(createCanonicalizerPass());
//===--------------------------------------------------------------------===//
diff --git a/iree/compiler/Conversion/LinalgToVector/BUILD b/iree/compiler/Conversion/LinalgToVector/BUILD
index fe5a537..f796fb5 100644
--- a/iree/compiler/Conversion/LinalgToVector/BUILD
+++ b/iree/compiler/Conversion/LinalgToVector/BUILD
@@ -21,7 +21,6 @@
cc_library(
name = "LinalgToVector",
srcs = [
- "LoadStoreVectorization.cpp",
"VectorizeConv.cpp",
],
hdrs = [
diff --git a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
index 9b766ab..83b2848 100644
--- a/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToVector/CMakeLists.txt
@@ -16,7 +16,6 @@
HDRS
"Passes.h"
SRCS
- "LoadStoreVectorization.cpp"
"VectorizeConv.cpp"
DEPS
LLVMSupport
diff --git a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp b/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
deleted file mode 100644
index 221d971..0000000
--- a/iree/compiler/Conversion/LinalgToVector/LoadStoreVectorization.cpp
+++ /dev/null
@@ -1,319 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
-#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-constexpr int kVectorizationSizeInBits = 128;
-constexpr int kVecSize = kVectorizationSizeInBits / (sizeof(float) * 8);
-
-/// Returns a VectorType in `kVectorizationSizeInBits` bits if `t` is a scalar.
-static VectorType getVecType(OpBuilder &builder, Type t) {
- if (!t.isa<IntegerType, FloatType>()) return {};
- if (t.getIntOrFloatBitWidth() != 32) return {};
- Type newElemType = t.isa<IntegerType>() ? builder.getI32Type().cast<Type>()
- : builder.getF32Type().cast<Type>();
- return VectorType::get(kVecSize, newElemType);
-}
-
-/// Returns the memref of vector converted from `type`.
-static MemRefType getVectorizedMemRefType(OpBuilder &builder, MemRefType type) {
- Type elemType = type.getElementType();
- VectorType vecType = getVecType(builder, elemType);
- if (!vecType) return {};
- unsigned elemSize = elemType.getIntOrFloatBitWidth();
- unsigned vecSize = kVectorizationSizeInBits / elemSize;
- SmallVector<int64_t, 2> newShape(type.getShape().begin(),
- type.getShape().end());
- if (newShape.empty()) return {};
- if (newShape.back() % vecSize != 0) return {};
- newShape.back() = newShape.back() / vecSize;
- return MemRefType::get(newShape, vecType, {}, type.getMemorySpaceAsInt());
-}
-
-/// Returns a vectorized `val`, ie, the result type is a VectorType.
-static Value legalizeToVectorType(OpBuilder &builder, Value val) {
- Type type = val.getType();
- if (type.isa<VectorType>()) {
- return val;
- } else if (type.isIntOrFloat()) {
- auto vecType = getVecType(builder, type);
- if (!vecType) return nullptr;
- return builder.createOrFold<vector::BroadcastOp>(val.getLoc(), vecType,
- val);
- }
- return nullptr;
-}
-
-/// Base class to vectorize std ops. If a generic op is vectorized, all the std
-/// ops in the region should be vectorized as well.
-///
-/// This base class handles the check on operands and vectorization for all the
-/// operands.
-///
-/// All derived classes implement a static apply method with the following
-/// signature:
-///
-/// ```c++
-/// LogicalResult apply(SrcOpTy op, ArrayRef<Value> args,
-/// ConversionPatternRewriter& rewriter) const;
-/// ```
-template <typename DerivedTy, typename SrcOpTy>
-struct VectorizeOpBase : public OpConversionPattern<SrcOpTy> {
- using OpConversionPattern<SrcOpTy>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- SrcOpTy op, ArrayRef<Value> args,
- ConversionPatternRewriter &rewriter) const override {
- if (llvm::all_of(args, [](Value arg) {
- return arg.getType().isIntOrIndexOrFloat();
- })) {
- return failure();
- }
- SmallVector<Value, 4> vecArgs;
- for (Value arg : args) {
- Value val = legalizeToVectorType(rewriter, arg);
- if (!val) return failure();
- vecArgs.push_back(val);
- }
- return static_cast<DerivedTy const *>(this)->apply(op, vecArgs, rewriter);
- }
-};
-
-template <typename OpTy>
-struct VectorizeElementwiseOp
- : public VectorizeOpBase<VectorizeElementwiseOp<OpTy>, OpTy> {
- using VectorizeOpBase<VectorizeElementwiseOp<OpTy>, OpTy>::VectorizeOpBase;
- LogicalResult apply(OpTy op, ArrayRef<Value> args,
- ConversionPatternRewriter &rewriter) const {
- auto vecType = getVecType(rewriter, op.getResult().getType());
- if (!vecType) return failure();
- auto newOp = rewriter.create<OpTy>(op.getLoc(), vecType, args);
- rewriter.replaceOp(op, newOp.getOperation()->getResults());
- return success();
- }
-};
-
-template <typename OpTy>
-struct VectorizeCmpOp : public VectorizeOpBase<VectorizeCmpOp<OpTy>, OpTy> {
- using VectorizeOpBase<VectorizeCmpOp<OpTy>, OpTy>::VectorizeOpBase;
- LogicalResult apply(OpTy op, ArrayRef<Value> args,
- ConversionPatternRewriter &rewriter) const {
- auto newOp =
- rewriter.create<OpTy>(op.getLoc(), op.predicate(), args[0], args[1]);
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-struct VectorizeSelectOp
- : public VectorizeOpBase<VectorizeSelectOp, mlir::SelectOp> {
- using VectorizeOpBase<VectorizeSelectOp, mlir::SelectOp>::VectorizeOpBase;
- LogicalResult apply(mlir::SelectOp op, ArrayRef<Value> args,
- ConversionPatternRewriter &rewriter) const {
- auto newOp =
- rewriter.create<SelectOp>(op.getLoc(), args[0], args[1], args[2]);
- rewriter.replaceOp(op, newOp.getResult());
- return success();
- }
-};
-
-struct VectorizeGenericOp : public OpConversionPattern<linalg::GenericOp> {
- using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(
- linalg::GenericOp genericOp, ArrayRef<Value> args,
- ConversionPatternRewriter &rewriter) const override {
- // If a generic op does not take any input, it means it's working on
- // constants and those operations do not have canonicalization patterns to
- // fold it. For now just ignore to vectorize it.
- if (genericOp.getNumInputs() == 0) {
- return failure();
- }
-
- if (llvm::any_of(genericOp.iterator_types(), [](Attribute attr) {
- return attr.cast<StringAttr>().getValue() !=
- getParallelIteratorTypeName();
- })) {
- return failure();
- }
-
- // Do not vectorize if one of the operand is 0-D or one of the operand is
- // not iterated on contiguous memory.
- for (auto map : genericOp.getIndexingMaps()) {
- if (map.getNumResults() == 0) return failure();
- AffineDimExpr innerMostExpr =
- map.getResults().back().dyn_cast<AffineDimExpr>();
- if (!innerMostExpr ||
- innerMostExpr.getPosition() != map.getNumDims() - 1) {
- return failure();
- }
- }
-
- SmallVector<IREE::PlaceholderOp, 4> operands;
- SmallVector<MemRefType, 4> vecMemRefs;
- for (auto operand : args) {
- auto op = operand.getDefiningOp<IREE::PlaceholderOp>();
- if (!op) return failure();
- if (!op.getOperation()->hasOneUse()) return failure();
- auto memrefType = op.getResult().getType().dyn_cast<MemRefType>();
- if (!memrefType) return failure();
- auto vecMemRef = getVectorizedMemRefType(rewriter, memrefType);
- if (!vecMemRef) return failure();
- operands.push_back(op);
- vecMemRefs.push_back(vecMemRef);
- }
-
- SmallVector<Value, 4> newArgs;
- for (auto it : llvm::zip(operands, vecMemRefs)) {
- IREE::PlaceholderOp placeholder = std::get<0>(it);
- MemRefType vecMemRef = std::get<1>(it);
- auto arg = rewriter.create<IREE::PlaceholderOp>(placeholder.getLoc(),
- vecMemRef, ValueRange{},
- placeholder->getAttrs());
- rewriter.replaceOp(placeholder, arg.getResult());
- newArgs.push_back(arg.getResult());
- }
- ArrayRef<Value> newArgsRef(newArgs.begin(), newArgs.end());
- auto newOp = rewriter.create<linalg::GenericOp>(
- genericOp.getLoc(), genericOp.getResultTypes(),
- /*inputs=*/newArgsRef.take_front(genericOp.getNumInputs()),
- /*outputBuffers*/ newArgsRef.take_back(genericOp.getNumOutputs()),
- genericOp.indexing_mapsAttr(), genericOp.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr, genericOp.sparseAttr());
-
- Region &newRegion = newOp.region();
- rewriter.inlineRegionBefore(genericOp.getRegion(), newRegion,
- newRegion.end());
- Block &newBlock = newOp.region().front();
- TypeConverter::SignatureConversion signatureConverter(
- newBlock.getNumArguments());
- for (auto arg : llvm::enumerate(vecMemRefs)) {
- signatureConverter.addInputs(arg.index(), arg.value().getElementType());
- }
- rewriter.applySignatureConversion(&newOp.region(), signatureConverter);
- rewriter.replaceOp(genericOp, newOp.getResults());
- return success();
- }
-};
-
-struct LoadStoreVectorizationPass
- : public PassWrapper<LoadStoreVectorizationPass, OperationPass<FuncOp>> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
- }
-
- void runOnOperation() override {
- MLIRContext *context = &getContext();
- OwningRewritePatternList patterns(&getContext());
- // clang-format off
- patterns.insert<
- VectorizeGenericOp,
- VectorizeCmpOp<CmpFOp>,
- VectorizeCmpOp<CmpIOp>,
- VectorizeSelectOp,
- VectorizeElementwiseOp<AbsFOp>,
- VectorizeElementwiseOp<AndOp>,
- VectorizeElementwiseOp<OrOp>,
- VectorizeElementwiseOp<XOrOp>,
- VectorizeElementwiseOp<AddFOp>,
- VectorizeElementwiseOp<AddIOp>,
- VectorizeElementwiseOp<CeilFOp>,
- VectorizeElementwiseOp<math::CosOp>,
- VectorizeElementwiseOp<DivFOp>,
- VectorizeElementwiseOp<math::ExpOp>,
- VectorizeElementwiseOp<FPExtOp>,
- VectorizeElementwiseOp<FPToSIOp>,
- VectorizeElementwiseOp<FPTruncOp>,
- VectorizeElementwiseOp<FloorFOp>,
- VectorizeElementwiseOp<math::LogOp>,
- VectorizeElementwiseOp<MulFOp>,
- VectorizeElementwiseOp<MulIOp>,
- VectorizeElementwiseOp<NegFOp>,
- VectorizeElementwiseOp<RemFOp>,
- VectorizeElementwiseOp<math::RsqrtOp>,
- VectorizeElementwiseOp<SIToFPOp>,
- VectorizeElementwiseOp<ShiftLeftOp>,
- VectorizeElementwiseOp<SignExtendIOp>,
- VectorizeElementwiseOp<SignedDivIOp>,
- VectorizeElementwiseOp<SignedShiftRightOp>,
- VectorizeElementwiseOp<math::SinOp>,
- VectorizeElementwiseOp<math::SqrtOp>,
- VectorizeElementwiseOp<SubFOp>,
- VectorizeElementwiseOp<SubIOp>,
- VectorizeElementwiseOp<math::TanhOp>,
- VectorizeElementwiseOp<TruncateIOp>,
- VectorizeElementwiseOp<UnsignedDivIOp>,
- VectorizeElementwiseOp<UnsignedRemIOp>,
- VectorizeElementwiseOp<UnsignedShiftRightOp>,
- VectorizeElementwiseOp<ZeroExtendIOp>>(context);
- // clang-format on
-
- ConversionTarget target(*context);
- // Mark vector dialect and plancholder op legal.
- target.addLegalDialect<vector::VectorDialect>();
- target.addLegalOp<IREE::PlaceholderOp>();
-
- // If a generic op is vectorized, it is legal.
- target.addDynamicallyLegalOp<linalg::GenericOp>([](linalg::GenericOp op) {
- if (!op.hasBufferSemantics()) return false;
- for (auto arg : op.getOperands()) {
- if (arg.getType()
- .cast<MemRefType>()
- .getElementType()
- .isSignlessIntOrFloat())
- return false;
- }
- return true;
- });
-
- // Mark all standard ops legal if they are operating on vector types.
- target.addDynamicallyLegalDialect<mlir::StandardOpsDialect,
- mlir::math::MathDialect>(
- Optional<ConversionTarget::DynamicLegalityCallbackFn>(
- [](Operation *op) {
- auto isVectorType = [](Type t) { return t.isa<VectorType>(); };
- return llvm::any_of(op->getOperandTypes(), isVectorType) ||
- llvm::any_of(op->getResultTypes(), isVectorType);
- }));
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- return signalPassFailure();
- }
-};
-} // namespace
-
-std::unique_ptr<Pass> createLoadStoreVectorizationPass() {
- return std::make_unique<LoadStoreVectorizationPass>();
-}
-
-static PassRegistration<LoadStoreVectorizationPass> pass(
- "iree-codegen-vectorize-linalg-ops", "Vectorize Linalg operations");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToVector/Passes.h b/iree/compiler/Conversion/LinalgToVector/Passes.h
index 4f7340b..cfe4f03 100644
--- a/iree/compiler/Conversion/LinalgToVector/Passes.h
+++ b/iree/compiler/Conversion/LinalgToVector/Passes.h
@@ -20,9 +20,6 @@
namespace mlir {
namespace iree_compiler {
-/// Creates a pass to vectorize Linalg operations.
-std::unique_ptr<Pass> createLoadStoreVectorizationPass();
-
/// Creates a pass to vectorize a very specific form of linalg.conv ops.
std::unique_ptr<Pass> createVectorizeLinalgConvPass();
diff --git a/iree/compiler/Conversion/LinalgToVector/test/BUILD b/iree/compiler/Conversion/LinalgToVector/test/BUILD
index 2997186..d3a6fcf 100644
--- a/iree/compiler/Conversion/LinalgToVector/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToVector/test/BUILD
@@ -28,7 +28,6 @@
srcs = enforce_glob(
[
"vectorize_linalg_conv.mlir",
- "vectorize_linalg_ops.mlir",
],
include = ["*.mlir"],
),
diff --git a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt
index b958f27..51d9623 100644
--- a/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToVector/test/CMakeLists.txt
@@ -15,7 +15,6 @@
lit
SRCS
"vectorize_linalg_conv.mlir"
- "vectorize_linalg_ops.mlir"
DATA
iree::tools::IreeFileCheck
iree::tools::iree-opt
diff --git a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir b/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
deleted file mode 100644
index e91a4fd..0000000
--- a/iree/compiler/Conversion/LinalgToVector/test/vectorize_linalg_ops.mlir
+++ /dev/null
@@ -1,174 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-codegen-vectorize-linalg-ops -canonicalize -cse %s | IreeFileCheck %s
-
-func @broadcast_add() {
- %0 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4xf32>
- %1 = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<3x4xf32>
- %2 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<3x4xf32>
- linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%0, %1 : memref<4xf32>, memref<3x4xf32>)
- outs(%2 : memref<3x4xf32>) {
- ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // no predecessors
- %3 = addf %arg0, %arg1 : f32
- linalg.yield %3 : f32
- }
- return
-}
-// CHECK-LABEL: func @broadcast_add
-// CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<1xvector<4xf32>>
-// CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<3x1xvector<4xf32>>
-// CHECK-DAG: %[[BUF2:.+]] = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<3x1xvector<4xf32>>
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[BUF0]], %[[BUF1]] :
-// CHECK-SAME: outs(%[[BUF2]] :
-// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>, %[[ARG2:.+]]: vector<4xf32>)
-// CHECK: %[[RES:.+]] = addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
-// CHECK: linalg.yield %[[RES]] : vector<4xf32>
-
-// -----
-
-func @log_plus_one() {
- %0 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4xf32>
- %c0 = constant 0 : index
- %cst = constant 1.000000e+00 : f32
- %1 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4xf32>
- linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
- ins(%1 : memref<4xf32>)
- outs(%0 : memref<4xf32>) {
- ^bb0(%arg0: f32, %arg1: f32): // no predecessors
- %2 = addf %arg0, %cst : f32
- %3 = math.log %2 : f32
- linalg.yield %3 : f32
- }
- return
-}
-// CHECK-LABEL: func @log_plus_one
-// CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<1xvector<4xf32>>
-// CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<1xvector<4xf32>>
-// CHECK-DAG: %[[CST:.+]] = constant dense<1.000000e+00> : vector<4xf32>
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[BUF0]] :
-// CHECK-SAME: outs(%[[BUF1]] :
-// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xf32>)
-// CHECK: %[[T1:.+]] = addf %[[ARG0]], %[[CST]] : vector<4xf32>
-// CHECK: %[[T2:.+]] = math.log %[[T1]] : vector<4xf32>
-// CHECK: linalg.yield %[[T2]] : vector<4xf32>
-
-// -----
-
-func @cmp_and_select() {
- %0 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4xi32>
- %1 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4xi32>
- %2 = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<4xi32>
- linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
- ins(%1, %2 : memref<4xi32>, memref<4xi32>)
- outs(%0 : memref<4xi32>) {
- ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
- %3 = cmpi sgt, %arg0, %arg1 : i32
- %4 = select %3, %arg0, %arg1 : i32
- linalg.yield %4 : i32
- }
- return
-}
-// CHECK-LABEL: func @cmp_and_select
-// CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<1xvector<4xi32>>
-// CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<1xvector<4xi32>>
-// CHECK-DAG: %[[BUF2:.+]] = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<1xvector<4xi32>>
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[BUF0]], %[[BUF1]] :
-// CHECK-SAME: outs(%[[BUF2]] :
-// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
-// CHECK: %[[T1:.+]] = cmpi sgt, %[[ARG0]], %[[ARG1]] : vector<4xi32>
-// CHECK: %[[T2:.+]] = select %[[T1]], %[[ARG0]], %[[ARG1]] : vector<4xi1>, vector<4xi32>
-// CHECK: linalg.yield %[[T2]] : vector<4xi32>
-
-// -----
-
-func @cmp_convert_mul() {
- %1 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4xf32>
- %2 = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<4xi32>
- %0 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4xi32>
- %cst = constant 0.000000e+00 : f32
- linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>],
- iterator_types = ["parallel"]}
- ins(%1, %2 : memref<4xf32>, memref<4xi32>)
- outs(%0 : memref<4xi32>) {
- ^bb0(%arg0: f32, %arg1: i32, %arg2: i32): // no predecessors
- %3 = cmpf oeq, %arg0, %cst : f32
- %4 = zexti %3 : i1 to i32
- %5 = muli %4, %arg1 : i32
- linalg.yield %5 : i32
- }
- return
-}
-// CHECK-LABEL: func @cmp_convert_mul
-// CHECK-DAG: %[[BUF0:.+]] = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<1xvector<4xf32>>
-// CHECK-DAG: %[[BUF1:.+]] = iree.placeholder for "interface buffer" {binding = @io::@arg1} : memref<1xvector<4xi32>>
-// CHECK-DAG: %[[BUF2:.+]] = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<1xvector<4xi32>>
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[BUF0]], %[[BUF1]] :
-// CHECK-SAME: outs(%[[BUF2]] :
-// CHECK: ^bb0(%[[ARG0:.+]]: vector<4xf32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>)
-// CHECK: %[[T1:.+]] = cmpf oeq, %[[ARG0]], %{{.+}} : vector<4xf32>
-// CHECK: %[[T2:.+]] = zexti %[[T1]] : vector<4xi1> to vector<4xi32>
-// CHECK: %[[T3:.+]] = muli %[[T2]], %[[ARG1]] : vector<4xi32>
-// CHECK: linalg.yield %[[T3]] : vector<4xi32>
-
-// -----
-
-func @not_contiguous() {
- %0 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4x4xf32>
- %c0 = constant 0 : index
- %cst = constant 1.000000e+00 : f32
- %1 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4x4xf32>
- linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]}
- ins(%1 : memref<4x4xf32>)
- outs(%0 : memref<4x4xf32>) {
- ^bb0(%arg0: f32, %arg1: f32): // no predecessors
- %2 = addf %arg0, %cst : f32
- linalg.yield %2 : f32
- }
- return
-}
-// CHECK-LABEL: func @not_contiguous
-// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4x4xf32>
-// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4x4xf32>
-
-// -----
-
-func @not_4s() {
- %0 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4x3xf32>
- %c0 = constant 0 : index
- %cst = constant 1.000000e+00 : f32
- %1 = iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4x3xf32>
- linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]}
- ins(%1 : memref<4x3xf32>)
- outs(%0 : memref<4x3xf32>) {
- ^bb0(%arg0: f32, %arg1: f32): // no predecessors
- %2 = addf %arg0, %cst : f32
- linalg.yield %2 : f32
- }
- return
-}
-// CHECK-LABEL: func @not_4s
-// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @io::@arg0} : memref<4x3xf32>
-// CHECK-DAG: iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4x3xf32>
-
-// -----
-
-// CHECK-LABEL: func @cst
-// CHECK: iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4xf32>
-func @cst() {
- %cst = constant 1.001000e+00 : f32
- %0 = iree.placeholder for "interface buffer" {binding = @io::@ret0} : memref<4xf32>
- linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%0 : memref<4xf32>) {
- ^bb0(%arg0: f32): // no predecessors
- %1 = math.rsqrt %cst : f32
- linalg.yield %1 : f32
- }
- return
-}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index d701559..d948fae 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -56,7 +56,7 @@
inline void registerLinalgToVectorPasses() {
static bool init_once = []() {
- createLoadStoreVectorizationPass();
+ createVectorizeLinalgConvPass();
return true;
}();
(void)init_once;
diff --git a/iree/test/e2e/vulkan_specific/BUILD b/iree/test/e2e/vulkan_specific/BUILD
index 876cc74..0f4c27f 100644
--- a/iree/test/e2e/vulkan_specific/BUILD
+++ b/iree/test/e2e/vulkan_specific/BUILD
@@ -57,18 +57,6 @@
)
iree_check_single_backend_test_suite(
- name = "check_vulkan-spirv_vulkan_vector",
- srcs = [
- "compare.mlir",
- "log_plus_one.mlir",
- "pw_add_multiwg.mlir",
- ],
- compiler_flags = ["-iree-spirv-enable-memref-vectorization"],
- driver = "vulkan",
- target_backend = "vulkan-spirv",
-)
-
-iree_check_single_backend_test_suite(
name = "check_vulkan-spirv_vulkan_vectorized_conv",
srcs = [
"vectorized_conv.mlir",
diff --git a/iree/test/e2e/vulkan_specific/CMakeLists.txt b/iree/test/e2e/vulkan_specific/CMakeLists.txt
index c5ec28c..9e94d29 100644
--- a/iree/test/e2e/vulkan_specific/CMakeLists.txt
+++ b/iree/test/e2e/vulkan_specific/CMakeLists.txt
@@ -43,21 +43,6 @@
iree_check_single_backend_test_suite(
NAME
- check_vulkan-spirv_vulkan_vector
- SRCS
- "compare.mlir"
- "log_plus_one.mlir"
- "pw_add_multiwg.mlir"
- TARGET_BACKEND
- "vulkan-spirv"
- DRIVER
- "vulkan"
- COMPILER_FLAGS
- "-iree-spirv-enable-memref-vectorization"
-)
-
-iree_check_single_backend_test_suite(
- NAME
check_vulkan-spirv_vulkan_vectorized_conv
SRCS
"vectorized_conv.mlir"