Dropping the VMLA compiler. (#5903)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b1ac598..7e09869 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -162,7 +162,6 @@
WASM-LLVM-AOT
Metal-SPIRV
Vulkan-SPIRV
- VMLA
VMVX
)
diff --git a/iree/compiler/Conversion/HLOToHLO/BUILD b/iree/compiler/Conversion/HLOToHLO/BUILD
index fb1661c..e874421 100644
--- a/iree/compiler/Conversion/HLOToHLO/BUILD
+++ b/iree/compiler/Conversion/HLOToHLO/BUILD
@@ -21,7 +21,6 @@
cc_library(
name = "HLOToHLO",
srcs = [
- "Convert1x1ConvToDot.cpp",
"DecomposeHLOClamp.cpp",
"DemoteF32ToF16.cpp",
],
diff --git a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
index ec9a9ad..acd6747 100644
--- a/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToHLO/CMakeLists.txt
@@ -16,7 +16,6 @@
HDRS
"Passes.h"
SRCS
- "Convert1x1ConvToDot.cpp"
"DecomposeHLOClamp.cpp"
"DemoteF32ToF16.cpp"
DEPS
diff --git a/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp b/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp
deleted file mode 100644
index 2a142f0..0000000
--- a/iree/compiler/Conversion/HLOToHLO/Convert1x1ConvToDot.cpp
+++ /dev/null
@@ -1,149 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Rewrites an n-d (n, d1, d2, d3, ..., ci) * (1, 1, 1, ..., ci, co)
-// as (n * d1 * d2 * d3, ..., ci) . (ci, co)
-// TODO(#4876): this pattern should be replaced by a pattern that converts
-// linalg.conv to linalg.matmul.
-class Convert1x1ConvolutionToDotOp : public OpRewritePattern<mhlo::ConvOp> {
- public:
- using OpRewritePattern<mhlo::ConvOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mhlo::ConvOp op,
- PatternRewriter &rewriter) const override {
- // Only 1x1 convolution no groups will match.
- if (op.feature_group_count() != 1) return failure();
-
- Value input = op.lhs();
- Value filter = op.rhs();
- Value output = op.getResult();
- auto inputShapeType = input.getType().dyn_cast_or_null<RankedTensorType>();
- auto filterShapeType =
- filter.getType().dyn_cast_or_null<RankedTensorType>();
- auto outputShapeType =
- output.getType().dyn_cast_or_null<RankedTensorType>();
-
- if (!inputShapeType || !filterShapeType || !outputShapeType) {
- return failure();
- }
-
- auto inputShape = inputShapeType.getShape();
- auto filterShape = filterShapeType.getShape();
-
- auto inputBatchDim =
- op.dimension_numbers().input_batch_dimension().getInt();
- auto inputFeatureDim =
- op.dimension_numbers().input_feature_dimension().getInt();
- auto kernelInputFeatureDim =
- op.dimension_numbers().kernel_input_feature_dimension().getInt();
- auto kernelOutputFeatureDim =
- op.dimension_numbers().kernel_output_feature_dimension().getInt();
-
- // Match input (n, d1, d2, ..., ci) format
- if (inputFeatureDim != (inputShape.size() - 1) || inputBatchDim != 0) {
- return failure();
- }
-
- // Match filter (k1, k2, ..., ci, co) format
- if (kernelInputFeatureDim != (filterShape.size() - 2) ||
- kernelOutputFeatureDim != (filterShape.size() - 1)) {
- return failure();
- }
-
- // Check 1x1x... kernel spatial size.
- for (auto dim : op.dimension_numbers().kernel_spatial_dimensions()) {
- if (filterShape[dim.getZExtValue()] != 1) return failure();
- }
-
- // Check dilation & strides are ones.
- if (op.window_strides()) {
- for (auto stride : op.window_strides()->getValues<int64_t>()) {
- if (stride != 1) return failure();
- }
- }
- if (op.rhs_dilation()) {
- for (auto dilation : op.rhs_dilation()->getValues<int64_t>()) {
- if (dilation != 1) return failure();
- }
- }
-
- int64_t spatialSize = inputShape[0];
- for (auto dim : op.dimension_numbers().input_spatial_dimensions()) {
- spatialSize *= inputShape[dim.getZExtValue()];
- }
-
- Type reshapedInputType =
- RankedTensorType::get({spatialSize, inputShape[inputFeatureDim]},
- inputShapeType.getElementType());
- Type reshapedFilterTYpe =
- RankedTensorType::get({filterShape[kernelInputFeatureDim],
- filterShape[kernelOutputFeatureDim]},
- filterShapeType.getElementType());
- Type dotResultType = RankedTensorType::get(
- {spatialSize, filterShape[kernelOutputFeatureDim]},
- outputShapeType.getElementType());
-
- Value reshapedInput =
- rewriter.create<mhlo::ReshapeOp>(op.getLoc(), reshapedInputType, input);
- Value reshapedFilter = rewriter.create<mhlo::ReshapeOp>(
- op.getLoc(), reshapedFilterTYpe, filter);
-
- Value dotResult = rewriter.create<mhlo::DotOp>(
- op.getLoc(), dotResultType, reshapedInput, reshapedFilter,
- rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"}));
-
- Value reshapedResult = rewriter.create<mhlo::ReshapeOp>(
- op.getLoc(), outputShapeType, dotResult);
-
- rewriter.replaceOp(op, reshapedResult);
-
- return success();
- }
-};
-
-struct Convert1x1ConvToDotPass
- : public PassWrapper<Convert1x1ConvToDotPass, FunctionPass> {
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<mhlo::MhloDialect>();
- }
-
- void runOnFunction() override {
- MLIRContext *context = &getContext();
- OwningRewritePatternList patterns(&getContext());
- patterns.insert<Convert1x1ConvolutionToDotOp>(context);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- }
-};
-} // namespace
-
-std::unique_ptr<OperationPass<FuncOp>> createConvert1x1ConvToDotPass() {
- return std::make_unique<Convert1x1ConvToDotPass>();
-}
-
-static PassRegistration<Convert1x1ConvToDotPass> pass(
- "iree-codegen-convert-1x1-conv-to-dot",
- "Convert mhlo.convolution ops with 1x1 kernels into mhlo.dot ops");
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Conversion/HLOToHLO/Passes.h b/iree/compiler/Conversion/HLOToHLO/Passes.h
index 0822ec7..5edbcb2 100644
--- a/iree/compiler/Conversion/HLOToHLO/Passes.h
+++ b/iree/compiler/Conversion/HLOToHLO/Passes.h
@@ -30,10 +30,6 @@
namespace mlir {
namespace iree_compiler {
-/// Creates a pass to convert mhlo.convolution ops with 1x1 kernels into
-/// mhlo.dot ops.
-std::unique_ptr<OperationPass<FuncOp>> createConvert1x1ConvToDotPass();
-
/// Creates a pass to decompose XLA-HLO clamp ops into primitive ops.
std::unique_ptr<OperationPass<FuncOp>> createDecomposeHLOClampPass();
diff --git a/iree/compiler/Conversion/HLOToHLO/test/BUILD b/iree/compiler/Conversion/HLOToHLO/test/BUILD
index ab0ba6b..979ba9b 100644
--- a/iree/compiler/Conversion/HLOToHLO/test/BUILD
+++ b/iree/compiler/Conversion/HLOToHLO/test/BUILD
@@ -27,7 +27,6 @@
name = "lit",
srcs = enforce_glob(
[
- "conv1x12dot.mlir",
"f32Tof16.mlir",
],
include = ["*.mlir"],
diff --git a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt
index 50035da..53f22fc 100644
--- a/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/HLOToHLO/test/CMakeLists.txt
@@ -14,7 +14,6 @@
NAME
lit
SRCS
- "conv1x12dot.mlir"
"f32Tof16.mlir"
DATA
iree::tools::IreeFileCheck
diff --git a/iree/compiler/Conversion/HLOToHLO/test/conv1x12dot.mlir b/iree/compiler/Conversion/HLOToHLO/test/conv1x12dot.mlir
deleted file mode 100644
index 002fe3e..0000000
--- a/iree/compiler/Conversion/HLOToHLO/test/conv1x12dot.mlir
+++ /dev/null
@@ -1,28 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-codegen-convert-1x1-conv-to-dot %s | IreeFileCheck %s
-
-// CHECK: @conv_1x1(%[[INPUT:.+]]: tensor<2x4x5x2xf32>, %[[FILTER:.+]]: tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32>
-func @conv_1x1(%arg0: tensor<2x4x5x2xf32>, %arg1: tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32> {
- // CHECK: %[[RESHAPED_INPUT:.+]] = "mhlo.reshape"(%[[INPUT]]) : (tensor<2x4x5x2xf32>) -> tensor<40x2xf32>
- // CHECK: %[[RESHAPED_FILTER:.+]] = "mhlo.reshape"(%[[FILTER]]) : (tensor<1x1x2x7xf32>) -> tensor<2x7xf32>
- // CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot"(%[[RESHAPED_INPUT]], %[[RESHAPED_FILTER]]) {precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<40x2xf32>, tensor<2x7xf32>) -> tensor<40x7xf32>
- // CEHCK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<40x7xf32>) -> tensor<2x4x5x7xf32>
- %0 = "mhlo.convolution"(%arg0, %arg1) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
- feature_group_count = 1 : i64,
- padding = dense<0> : tensor<2x2xi64>,
- rhs_dilation = dense<1> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<1x1x2x7xf32>) -> tensor<2x4x5x7xf32>
- return %0 : tensor<2x4x5x7xf32>
-}
-
-
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
deleted file mode 100644
index 5081805..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_cmake_extra_content(
- content = """
-if(NOT "${IREE_TARGET_BACKEND_VMLA}")
- return()
-endif()
-""",
-)
-
-cc_library(
- name = "VMLA",
- srcs = [
- "VMLATarget.cpp",
- ],
- hdrs = [
- "VMLATarget.h",
- ],
- deps = [
- "//iree/base/internal:flatcc",
- "//iree/compiler/Dialect/Flow/IR",
- "//iree/compiler/Dialect/HAL/Target",
- "//iree/compiler/Dialect/VM/Conversion",
- "//iree/compiler/Dialect/VM/IR",
- "//iree/compiler/Dialect/VM/Target/Bytecode",
- "//iree/compiler/Dialect/VM/Transforms",
- "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
- "//iree/compiler/Dialect/VMLA/Transforms",
- "//iree/compiler/Utils",
- "//iree/schemas:vmla_executable_def_c_fbs",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:Support",
- ],
-)
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
deleted file mode 100644
index d24651c..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt
+++ /dev/null
@@ -1,43 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/HAL/Target/VMLA/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-if(NOT "${IREE_TARGET_BACKEND_VMLA}")
- return()
-endif()
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- VMLA
- HDRS
- "VMLATarget.h"
- SRCS
- "VMLATarget.cpp"
- DEPS
- LLVMSupport
- MLIRIR
- MLIRPass
- MLIRSupport
- iree::base::internal::flatcc
- iree::compiler::Dialect::Flow::IR
- iree::compiler::Dialect::HAL::Target
- iree::compiler::Dialect::VM::Conversion
- iree::compiler::Dialect::VM::IR
- iree::compiler::Dialect::VM::Target::Bytecode
- iree::compiler::Dialect::VM::Transforms
- iree::compiler::Dialect::VMLA::IR::VMLADialect
- iree::compiler::Dialect::VMLA::Transforms
- iree::compiler::Utils
- iree::schemas::vmla_executable_def_c_fbs
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
deleted file mode 100644
index 080fd6f..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ /dev/null
@@ -1,159 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h"
-
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
-#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
-#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
-#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
-#include "iree/compiler/Utils/FlatbufferUtils.h"
-#include "iree/schemas/vmla_executable_def_builder.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LogicalResult.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-VMLATargetOptions getVMLATargetOptionsFromFlags() {
- VMLATargetOptions targetOptions;
- // TODO(benvanik): flags.
- return targetOptions;
-}
-
-class VMLATargetBackend final : public TargetBackend {
- public:
- VMLATargetBackend(VMLATargetOptions options) : options_(std::move(options)) {}
-
- std::string name() const override { return "vmla"; }
- std::string filter_pattern() const override { return "vmla"; }
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<VM::VMDialect, VMLA::VMLADialect>();
- }
-
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
- OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
- IREE::VMLA::buildVMLATransformPassPipeline(nestedModulePM);
-
- // TODO(#614): remove this when the std->vm conversion isn't looking for
- // iree.module.export.
- nestedModulePM.addPass(IREE::VM::createMarkPublicSymbolsExportedPass());
-
- IREE::VM::buildVMTransformPassPipeline(
- nestedModulePM, IREE::VM::getTargetOptionsFromFlags());
- }
-
- LogicalResult linkExecutables(mlir::ModuleOp moduleOp) override {
- OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
-
- auto sourceExecutableOps =
- llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
- if (sourceExecutableOps.size() <= 1) return success();
-
- // Create our new "linked" hal.executable.
- std::string linkedExecutableName = llvm::formatv("vmla_linked_{1}", name());
- auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
- moduleOp.getLoc(), linkedExecutableName);
- linkedExecutableOp.setVisibility(
- sourceExecutableOps.front().getVisibility());
-
- // Add our VMLA hal.executable.target with an empty module.
- builder.setInsertionPointToStart(linkedExecutableOp.getBody());
- auto linkedTargetOp = builder.create<IREE::HAL::ExecutableTargetOp>(
- moduleOp.getLoc(), name(), filter_pattern());
- builder.setInsertionPoint(&linkedTargetOp.getBlock().back());
- auto linkedModuleOp = builder.create<ModuleOp>(moduleOp.getLoc());
-
- // Add an empty vm.module to that module (as our vm.funcs must live in it).
- builder.setInsertionPointToStart(linkedModuleOp.getBody());
- builder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(), "linked_module");
-
- // Try linking together all executables in moduleOp.
- return linkExecutablesInto(
- moduleOp, sourceExecutableOps, linkedExecutableOp, linkedTargetOp,
- [](mlir::ModuleOp moduleOp) {
- return *moduleOp.getOps<IREE::VM::ModuleOp>().begin();
- },
- builder);
- }
-
- LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
- OpBuilder &executableBuilder) override {
- FlatbufferBuilder builder;
- iree_VMLAExecutableDef_start_as_root(builder);
-
- // Serialize the VM module to bytes directly into a flatbuffer.
- IREE::VM::BytecodeTargetOptions bytecodeOptions;
- auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) {
- return succeeded(translateModuleToBytecode(targetOp.getInnerModule(),
- bytecodeOptions, stream));
- });
- if (!dataRef) {
- return targetOp.emitError() << "failed to serialize converted VM module";
- }
-
- // Pack the executable definition and get the bytes with the proper header.
- // The header is used to verify the contents at runtime.
- iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef);
- iree_VMLAExecutableDef_end_as_root(builder);
-
- // Add the binary data to the target executable.
- // NOTE: this snapshots the flatbuffer builder data at the time it is called
- // and future changes will not be observed.
- auto binaryOp = executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
- targetOp.getLoc(), targetOp.sym_name(),
- executableBuilder.getStringAttr("VMLA"),
- builder.getBufferAttr(executableBuilder.getContext()));
- binaryOp.mime_typeAttr(
- executableBuilder.getStringAttr("application/x-flatbuffers"));
- return success();
- }
-
- std::array<Value, 3> calculateDispatchWorkgroupCount(
- Location loc, IREE::HAL::ExecutableOp executableOp,
- IREE::HAL::ExecutableEntryPointOp entryPointOp, ValueRange workload,
- OpBuilder &builder) override {
- // For now we are not tiling and just dispatch everything as 1,1,1.
- auto constantOne = builder.createOrFold<mlir::ConstantIndexOp>(loc, 1);
- return {constantOne, constantOne, constantOne};
- }
-
- private:
- VMLATargetOptions options_;
-};
-
-void registerVMLATargetBackends(
- std::function<VMLATargetOptions()> queryOptions) {
- getVMLATargetOptionsFromFlags();
- static TargetBackendRegistration registration("vmla", [=]() {
- return std::make_unique<VMLATargetBackend>(queryOptions());
- });
-}
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h
deleted file mode 100644
index 945a6e0..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_HAL_TARGET_VMLA_VMLATARGET_H_
-#define IREE_COMPILER_DIALECT_HAL_TARGET_VMLA_VMLATARGET_H_
-
-#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace HAL {
-
-// Options controlling the VM/LA translation.
-struct VMLATargetOptions {
- // TODO(benvanik): target configuration.
-};
-
-// Returns a VMLATargetOptions struct initialized with the
-// --iree-hal-vm-la-* flags.
-VMLATargetOptions getVMLATargetOptionsFromFlags();
-
-// Registers the VMLA backends.
-void registerVMLATargetBackends(
- std::function<VMLATargetOptions()> queryOptions);
-
-} // namespace HAL
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_HAL_TARGET_VMLA_VMLATARGET_H_
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD
deleted file mode 100644
index fa884d6..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "i1_types.mlir",
- "linking.mlir",
- "smoketest.mlir",
- ],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt
deleted file mode 100644
index 4a238ee..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/CMakeLists.txt
+++ /dev/null
@@ -1,25 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/HAL/Target/VMLA/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "i1_types.mlir"
- "linking.mlir"
- "smoketest.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
deleted file mode 100644
index b129310..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false link-executables=false},canonicalize' -iree-hal-target-backends=vmla %s | IreeFileCheck %s
-
-// CHECK-LABEL: @i1_op_usage(%arg0: !hal.buffer) -> !hal.buffer
-func @i1_op_usage(%arg0: tensor<4xi1>) -> tensor<4xi1> {
- %c4 = constant 4 : index
- // CHECK: %0 = iree.byte_buffer.constant : !iree.byte_buffer = dense<[1, 0, 1, 0]> : tensor<4xi8>
- %cst = constant dense<[true, false, true, false]> : tensor<4xi1>
- %0 = flow.ex.stream.fragment(%c4, %arg0, %cst) : (index, tensor<4xi1>, tensor<4xi1>) -> (tensor<4xi1>) =
- (%arg1: index, %arg2: tensor<4xi1>, %arg3: tensor<4xi1>) -> (tensor<4xi1>) {
- %1 = flow.dispatch @i1_op_usage_ex_dispatch_0::@i1_op_usage_ex_dispatch_0[%arg1](%arg2, %arg3) : (tensor<4xi1>, tensor<4xi1>) -> (tensor<4xi1>)
- flow.return %1 : tensor<4xi1>
- }
- return %0 : tensor<4xi1>
-}
-
-// CHECK: hal.executable @i1_op_usage_ex_dispatch_0
-// CHECK: hal.executable.target @vmla
-// CHECK: hal.executable.entry_point @i1_op_usage_ex_dispatch_0 attributes {
-// CHECK-SAME: interface = @legacy_io
-// CHECK-SAME: ordinal = 0 : index
-// CHECK-SAME: signature = (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
-flow.executable @i1_op_usage_ex_dispatch_0 attributes {sym_visibility = "private"} {
- flow.dispatch.entry @i1_op_usage_ex_dispatch_0
- // CHECK: vm.module @module
- module {
- // CHECK: vm.rodata {{.+}} dense<[0, 0, 1, 0]> : tensor<4xi8>
- // CHECK: vm.func @i1_op_usage_ex_dispatch_0
- func @i1_op_usage_ex_dispatch_0(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
- %0 = mhlo.and %arg0, %arg1 : tensor<4xi1>
- %cst = mhlo.constant dense<[false, false, true, false]> : tensor<4xi1>
- %1 = mhlo.and %0, %cst : tensor<4xi1>
- return %1 : tensor<4xi1>
- }
- }
-}
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir
deleted file mode 100644
index 5b43c72..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir
+++ /dev/null
@@ -1,294 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-hal-link-executables -iree-hal-target-backends=vmla %s | IreeFileCheck %s
-
-module {
- hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_0 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {
- vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
- vm.return
- }
- vm.export @dispatch_0
- }
- }
- }
- }
- hal.executable @dispatch_1 attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_1 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {
- vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
- vm.return
- }
- vm.export @dispatch_1
- }
- }
- }
- }
- hal.executable @dispatch_2 attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_2 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {
- vm.func @dispatch_2(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) {
- vm.return
- }
- vm.export @dispatch_2
- }
- }
- }
- }
- func @main() -> () {
- %device = hal.ex.shared_device : !hal.device
- %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
- %c1 = constant 1 : index
- hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_2::@vmla::@dispatch_2) workgroups([%c1, %c1, %c1])
- return
- }
-}
-
-// All executables (including their interfaces and entry points) should be linked together into @linked_vmla
-// CHECK-NOT: hal.executable @dispatch_0
-// CHECK-NOT: hal.executable @dispatch_1
-// CHECK-NOT: hal.executable @dispatch_2
-// CHECK: hal.executable @vmla_linked_1 attributes {sym_visibility = "private"} {
-// CHECK-NEXT: hal.interface @io_0 {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.interface @io_1 {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @io_0, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
-// CHECK-NEXT: hal.executable.entry_point @dispatch_1 attributes {interface = @io_0, ordinal = 1 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
-// CHECK-NEXT: hal.executable.entry_point @dispatch_2 attributes {interface = @io_1, ordinal = 2 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
-// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @linked_module {
-// CHECK-NEXT: vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @dispatch_0
-// CHECK-NEXT: vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @dispatch_1
-// CHECK-NEXT: vm.func @dispatch_2(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) {
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @dispatch_2
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-//
-// CHECK: func @main() {
-// CHECK: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1])
-// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1])
-// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_2) workgroups([%c1, %c1, %c1])
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-
-// -----
-
-module {
- hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_0 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {
- vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
- vm.return
- }
- vm.export @dispatch_0
- }
- }
- }
- hal.executable.target @othertarget, filter="othertarget" {
- module {
- }
- }
- }
- hal.executable @dispatch_1 attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_1 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {
- vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32) {
- vm.return
- }
- vm.export @dispatch_1
- }
- }
- }
- hal.executable.target @othertarget, filter="othertarget" {
- module {
- }
- }
- }
- func @main() -> () {
- %device = hal.ex.shared_device : !hal.device
- %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
- hal.device.switch<%device : !hal.device>
- #hal.device.match.id<"vmla">(%arg1 = %cmd : !hal.command_buffer) {
- %c1 = constant 1 : index
- hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_0::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1])
- hal.return
- },
- #hal.device.match.id<"othertarget">(%arg1 = %cmd : !hal.command_buffer) {
- %c1 = constant 1 : index
- hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_0::@otherdispatch::@dispatch_0) workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_1::@otherdispatch::@dispatch_1) workgroups([%c1, %c1, %c1])
- hal.return
- }
- return
- }
-}
-
-// VMLA target should be pulled out from both executables
-// CHECK: hal.executable @vmla_linked_1 attributes {sym_visibility = "private"} {
-// CHECK-NEXT: hal.interface @io_0 {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.interface @io_1 {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @io_0, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
-// CHECK-NEXT: hal.executable.entry_point @dispatch_1 attributes {interface = @io_1, ordinal = 1 : index, signature = (tensor<1x1xf32>) -> tensor<1x1xf32>}
-// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @linked_module {
-// CHECK-NEXT: vm.func @dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @dispatch_0
-// CHECK-NEXT: vm.func @dispatch_1(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32) {
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @dispatch_1
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-// CHECK-NEXT: }
-//
-// @dispatch_0/1 should remain, with just @othertarget
-// CHECK: hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
-// CHECK: hal.interface @io
-// CHECK: hal.executable.target @othertarget, filter="othertarget"
-// CHECK: hal.executable @dispatch_1 attributes {sym_visibility = "private"} {
-// CHECK: hal.interface @io
-// CHECK: hal.executable.target @othertarget, filter="othertarget"
-//
-// CHECK: func @main() {
-// CHECK: hal.device.switch<%device : !hal.device>
-// CHECK-NEXT: #hal.device.match.id<"vmla">(%arg0 = %cmd : !hal.command_buffer) {
-// CHECK-NEXT: %c1 = constant 1 : index
-// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1])
-// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1])
-// CHECK-NEXT: hal.return
-// CHECK-NEXT: },
-// CHECK-NEXT: #hal.device.match.id<"othertarget">(%arg0 = %cmd : !hal.command_buffer) {
-// CHECK-NEXT: %c1 = constant 1 : index
-// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> target(@dispatch_0::@otherdispatch::@dispatch_0) workgroups([%c1, %c1, %c1])
-// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> target(@dispatch_1::@otherdispatch::@dispatch_1) workgroups([%c1, %c1, %c1])
-// CHECK-NEXT: hal.return
-// CHECK-NEXT: }
-// CHECK-NEXT: return
-// CHECK-NEXT: }
-
-// -----
-
-module {
- hal.executable @dispatch_0 attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_0 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {}
- }
- }
- }
- hal.executable @dispatch_1 attributes {sym_visibility = "private"} {
- hal.interface @io attributes {push_constants = 2 : index} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_1 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {}
- }
- }
- }
- hal.executable @dispatch_2 attributes {sym_visibility = "private"} {
- hal.interface @io attributes {push_constants = 2 : index} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vmla, filter="vmla" {
- hal.executable.entry_point @dispatch_2 attributes {interface = @io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>}
- module {
- vm.module @module {}
- }
- }
- }
-}
-
-// Interfaces with different numbers of push constants should remain separate.
-// CHECK-NOT: hal.executable @dispatch_0
-// CHECK-NOT: hal.executable @dispatch_1
-// CHECK-NOT: hal.executable @dispatch_2
-// CHECK: hal.executable @vmla_linked_1 attributes {sym_visibility = "private"} {
-// CHECK-NEXT: hal.interface @io_0 {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.interface @io_1 attributes {push_constants = 2 : index} {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
deleted file mode 100644
index 92b6caf..0000000
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ /dev/null
@@ -1,136 +0,0 @@
-// RUN: iree-opt -split-input-file -pass-pipeline='iree-hal-transformation-pipeline{serialize-executables=false},canonicalize' -iree-hal-target-backends=vmla %s | IreeFileCheck %s
-
-flow.executable @simpleMath_ex_dispatch_0 {
- flow.dispatch.entry @simpleMath_rgn_dispatch_0 attributes {
- workload = 4 : index
- }
- module {
- func @simpleMath_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> {
- %0 = mhlo.add %arg0, %arg0 : tensor<4xf32>
- return %0 : tensor<4xf32>
- }
- }
-}
-
-// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0
-// CHECK-NEXT: hal.interface @legacy_io {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<4xf32>) -> tensor<4xf32>}
-// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @module {
-// CHECK-DAG: vm.import @vmla.interface.binding(%interface : !vm.ref<!vmla.interface>, %set : i32, %binding : i32) -> !vm.ref<!vmla.buffer>
-// CHECK-DAG: vm.import @vmla.buffer.alloc(%byte_length : i32) -> !vm.ref<!vmla.buffer>
-// CHECK-DAG: vm.import @vmla.buffer.view(%src : !vm.ref<!vmla.buffer>, %byte_offset : i32, %byte_length : i32) -> !vm.ref<!vmla.buffer>
-// CHECK-DAG: vm.import @vmla.buffer.copy(%src : !vm.ref<!vmla.buffer>, %src_byte_offset : i32, %dst : !vm.ref<!vmla.buffer>, %dst_byte_offset : i32, %byte_length : i32)
-// CHECK-DAG: vm.import @vmla.add.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-// CHECK: vm.func @simpleMath_rgn_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-// CHECK-DAG: %zero = vm.const.i32.zero : i32
-// CHECK-DAG: %c16 = vm.const.i32 16 : i32
-// CHECK-DAG: %c1 = vm.const.i32 1 : i32
-// CHECK-NEXT: %ref = vm.call @vmla.interface.binding(%arg0, %zero, %zero) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_0 = vm.call @vmla.buffer.view(%ref, %zero, %c16) : (!vm.ref<!vmla.buffer>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_1 = vm.call @vmla.buffer.alloc(%c16) : (i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call @vmla.add.f32(%ref_0, %ref_0, %ref_1) : (!vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>) -> ()
-// CHECK-NEXT: %ref_2 = vm.call @vmla.interface.binding(%arg0, %zero, %c1) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call @vmla.buffer.copy(%ref_1, %zero, %ref_2, %zero, %c16) : (!vm.ref<!vmla.buffer>, i32, !vm.ref<!vmla.buffer>, i32, i32) -> ()
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @simpleMath_rgn_dispatch_0
-
-// -----
-
-flow.executable @shaped_dispatch {
- flow.dispatch.entry @entry
- module {
- func @entry(%arg0: tensor<4x?xf32>, %arg1 : index) -> tensor<4x?xf32> {
- %0 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[4,?]>
- %1 = shapex.tie_shape %arg0, %0 : tensor<4x?xf32>, !shapex.ranked_shape<[4,?]>
- %2 = mhlo.add %1, %1 : tensor<4x?xf32>
- %3 = shapex.tie_shape %2, %0 : tensor<4x?xf32>, !shapex.ranked_shape<[4,?]>
- return %3 : tensor<4x?xf32>
- }
- }
-}
-
-// CHECK-LABEL: hal.executable @shaped_dispatch
-// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 1 : index} {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>}
-// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @module {
-// CHECK: vm.func @entry(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-// CHECK-DAG: %zero = vm.const.i32.zero : i32
-// CHECK-DAG: %c16 = vm.const.i32 16 : i32
-// CHECK-DAG: %c1 = vm.const.i32 1 : i32
-// CHECK-NEXT: %0 = vm.call @vmla.interface.const(%arg0, %zero) : (!vm.ref<!vmla.interface>, i32) -> i32
-// CHECK-NEXT: %ref = vm.call @vmla.interface.binding(%arg0, %zero, %zero) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %1 = vm.mul.i32 %0, %c16 : i32
-// CHECK-NEXT: %ref_0 = vm.call @vmla.buffer.view(%ref, %zero, %1) : (!vm.ref<!vmla.buffer>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_1 = vm.call @vmla.buffer.alloc(%1) : (i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call @vmla.add.f32(%ref_0, %ref_0, %ref_1) : (!vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>) -> ()
-// CHECK-NEXT: %ref_2 = vm.call @vmla.interface.binding(%arg0, %zero, %c1) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call @vmla.buffer.copy(%ref_1, %zero, %ref_2, %zero, %1) : (!vm.ref<!vmla.buffer>, i32, !vm.ref<!vmla.buffer>, i32, i32) -> ()
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @entry
-
-// -----
-
-flow.executable @reduction_ex_dispatch_0 {
- flow.dispatch.entry @reduction_ex_dispatch_0 attributes {workload = 4 : index}
- module {
- func @reduction_ex_dispatch_0(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
- %cst = constant dense<0.000000e+00> : tensor<f32>
- %0 = "mhlo.reduce"(%arg0, %cst) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
- %1 = mhlo.add %arg1, %arg2 : tensor<f32>
- "mhlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
- return %0 : tensor<4xf32>
- }
- }
-}
-
-// CHECK-LABEL: hal.executable @reduction_ex_dispatch_0
-// CHECK-NEXT: hal.interface @legacy_io {
-// CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
-// CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
-// CHECK-NEXT: }
-// CHECK-NEXT: hal.executable.target @vmla, filter="vmla" {
-// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<4x8xf32>) -> tensor<4xf32>}
-// CHECK-NEXT: module {
-// CHECK-NEXT: vm.module @module {
-// CHECK-DAG: vm.import @vmla.interface.binding(%interface : !vm.ref<!vmla.interface>, %set : i32, %binding : i32) -> !vm.ref<!vmla.buffer>
-// CHECK-DAG: vm.import @vmla.buffer.const(%value : !vm.buffer) -> !vm.ref<!vmla.buffer>
-// CHECK-DAG: vm.import @vmla.buffer.alloc(%byte_length : i32) -> !vm.ref<!vmla.buffer>
-// CHECK-DAG: vm.import @vmla.buffer.view(%src : !vm.ref<!vmla.buffer>, %byte_offset : i32, %byte_length : i32) -> !vm.ref<!vmla.buffer>
-// CHECK-DAG: vm.import @vmla.buffer.copy(%src : !vm.ref<!vmla.buffer>, %src_byte_offset : i32, %dst : !vm.ref<!vmla.buffer>, %dst_byte_offset : i32, %byte_length : i32)
-// CHECK-DAG: vm.import @vmla.buffer.fill(%value : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-// CHECK-DAG: vm.import @vmla.reduce.sum.f32(%src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ..., %dimension : i32, %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...)
-// CHECK-DAG: vm.rodata @reduction_ex_dispatch_0_const dense<0.000000e+00> : tensor<1xf32>
-// CHECK: vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-// CHECK-DAG: %zero = vm.const.i32.zero : i32
-// CHECK-DAG: %c128 = vm.const.i32 128 : i32
-// CHECK-DAG: %c16 = vm.const.i32 16 : i32
-// CHECK-DAG: %c4 = vm.const.i32 4 : i32
-// CHECK-DAG: %c8 = vm.const.i32 8 : i32
-// CHECK-DAG: %c1 = vm.const.i32 1 : i32
-// CHECK-NEXT: %reduction_ex_dispatch_0_const = vm.const.ref.rodata @reduction_ex_dispatch_0_const : !vm.buffer
-// CHECK-NEXT: %ref = vm.call @vmla.buffer.const(%reduction_ex_dispatch_0_const) : (!vm.buffer) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_0 = vm.call @vmla.buffer.alloc(%c4) : (i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call @vmla.buffer.fill(%ref, %ref_0) : (!vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>) -> ()
-// CHECK-NEXT: %ref_1 = vm.call @vmla.interface.binding(%arg0, %zero, %zero) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_2 = vm.call @vmla.buffer.view(%ref_1, %zero, %c128) : (!vm.ref<!vmla.buffer>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: %ref_3 = vm.call @vmla.buffer.alloc(%c16) : (i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call.variadic @vmla.reduce.sum.f32(%ref_2, [%c4, %c8], %ref_0, [], %c1, %ref_3, [%c4]) : (!vm.ref<!vmla.buffer>, i32 ..., !vm.ref<!vmla.buffer>, i32 ..., i32, !vm.ref<!vmla.buffer>, i32 ...)
-// CHECK-NEXT: %ref_4 = vm.call @vmla.interface.binding(%arg0, %zero, %c1) : (!vm.ref<!vmla.interface>, i32, i32) -> !vm.ref<!vmla.buffer>
-// CHECK-NEXT: vm.call @vmla.buffer.copy(%ref_3, %zero, %ref_4, %zero, %c16) : (!vm.ref<!vmla.buffer>, i32, !vm.ref<!vmla.buffer>, i32, i32) -> ()
-// CHECK-NEXT: vm.return
-// CHECK-NEXT: }
-// CHECK-NEXT: vm.export @reduction_ex_dispatch_0
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD
deleted file mode 100644
index 17d9247..0000000
--- a/iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "VMLAShapeBuilder",
- srcs = [
- "VMLAShapeBuilder.cpp",
- ],
- hdrs = [
- "VMLAShapeBuilder.h",
- ],
- deps = [
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/VMLA/IR",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- ],
-)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt
deleted file mode 100644
index a4f88f8..0000000
--- a/iree/compiler/Dialect/Shape/Plugins/VMLA/CMakeLists.txt
+++ /dev/null
@@ -1,28 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/Shape/Plugins/VMLA/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- VMLAShapeBuilder
- HDRS
- "VMLAShapeBuilder.h"
- SRCS
- "VMLAShapeBuilder.cpp"
- DEPS
- LLVMSupport
- MLIRIR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::VMLA::IR
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp
deleted file mode 100644
index 10fa04d..0000000
--- a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.cpp
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h"
-
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "llvm/ADT/BitVector.h"
-#include "llvm/ADT/Optional.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
-
-using namespace mlir::iree_compiler::Shape;
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-namespace {
-
-Value rewriteBatchMatMulPseudoOp(RankedShapeType resultShape,
- BatchMatMulPseudoOp op, OpBuilder &builder) {
- auto lhsShape = builder.create<GetRankedShapeOp>(op.getLoc(), op.lhs());
- auto rhsShape = builder.create<GetRankedShapeOp>(op.getLoc(), op.rhs());
- SmallVector<Value, 6> extents;
- // Batch dimension (already been established to match between both operands,
- // so arbitrarily use the LHS).
- extents.push_back(builder.create<RankedDimOp>(op.getLoc(), lhsShape, 0));
- // RHS free dimension.
- extents.push_back(builder.create<RankedDimOp>(op.getLoc(), rhsShape, 1));
- // LHS free dimension.
- extents.push_back(builder.create<RankedDimOp>(op.getLoc(), lhsShape, 1));
- // Due to a quirk of MakeRankedShapeOp, we only pass in the dynamic dims.
- // So prune them down here.
- SmallVector<Value, 6> onlyDynamicExtents;
- for (int i = 0; i < 3; i++) {
- if (resultShape.isDimDynamic(i)) {
- onlyDynamicExtents.push_back(extents[i]);
- }
- }
- return builder.create<MakeRankedShapeOp>(op.getLoc(), resultShape,
- onlyDynamicExtents);
-}
-
-} // namespace
-
-void populateVMLACustomOpShapeBuilder(CustomOpShapeBuilderList &builders) {
- auto &b = builders.make<CallbackCustomOpShapeBuilder>();
- b.insertOpRankedShapeBuilder<BatchMatMulPseudoOp>(rewriteBatchMatMulPseudoOp);
-}
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h b/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h
deleted file mode 100644
index a80de01..0000000
--- a/iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
-#define IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-// Creates a custom op shape builder for VMLA ops that are not otherwise
-// supported through traits or other declarative means.
-void populateVMLACustomOpShapeBuilder(
- iree_compiler::Shape::CustomOpShapeBuilderList &builders);
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_SHAPE_IR_VMLASHAPEBUILDER_H_
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD
deleted file mode 100644
index 24fd287..0000000
--- a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- ["custom_ops.mlir"],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt
deleted file mode 100644
index 3c6404e..0000000
--- a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/Shape/Plugins/VMLA/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "custom_ops.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir b/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir
deleted file mode 100644
index 70ba1cb..0000000
--- a/iree/compiler/Dialect/Shape/Plugins/VMLA/test/custom_ops.mlir
+++ /dev/null
@@ -1,19 +0,0 @@
-// RUN: iree-opt -split-input-file -verify-diagnostics -iree-shape-materialize-calculations %s | IreeFileCheck %s
-
-// -----
-// CHECK-LABEL: func @batch.matmul.pseudo
-func @batch.matmul.pseudo(
- %lhs: tensor<?x?x?xf32>, %rhs: tensor<?x?x?xf32>,
- %lhsShape: !shapex.ranked_shape<[?,?,?]>, %rhsShape: !shapex.ranked_shape<[?,?,?]>
-) -> !shapex.ranked_shape<[?,?,?]> {
- %lhsTied = shapex.tie_shape %lhs, %lhsShape : tensor<?x?x?xf32>, !shapex.ranked_shape<[?,?,?]>
- %rhsTied = shapex.tie_shape %rhs, %rhsShape : tensor<?x?x?xf32>, !shapex.ranked_shape<[?,?,?]>
- %0 = "vmla.batch.matmul.pseudo"(%lhsTied, %rhsTied) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
- // CHECK-DAG: %[[BATCH:.+]] = shapex.ranked_dim %arg2[0]
- // CHECK-DAG: %[[FLHS:.+]] = shapex.ranked_dim %arg2[1]
- // CHECK-DAG: %[[FRHS:.+]] = shapex.ranked_dim %arg3[1]
- // CHECK-DAG: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[BATCH]], %[[FRHS]], %[[FLHS]]
- // CHECK-DAG: return %[[SHAPE]]
- %1 = shapex.get_ranked_shape %0 : tensor<?x?x?xf32> -> !shapex.ranked_shape<[?,?,?]>
- return %1 : !shapex.ranked_shape<[?,?,?]>
-}
diff --git a/iree/compiler/Dialect/Shape/Transforms/BUILD b/iree/compiler/Dialect/Shape/Transforms/BUILD
index 4c3b105..5f11f4f 100644
--- a/iree/compiler/Dialect/Shape/Transforms/BUILD
+++ b/iree/compiler/Dialect/Shape/Transforms/BUILD
@@ -36,7 +36,6 @@
],
deps = [
"//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Plugins/VMLA:VMLAShapeBuilder",
"//iree/compiler/Dialect/Shape/Plugins/XLA:XlaHloShapeBuilder",
"//iree/compiler/Dialect/Shape/Utils:TypeConversion",
"//iree/compiler/Utils",
diff --git a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
index 61989fb..1ad3803 100644
--- a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
@@ -36,7 +36,6 @@
MLIRTensor
MLIRTransforms
iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Plugins::VMLA::VMLAShapeBuilder
iree::compiler::Dialect::Shape::Plugins::XLA::XlaHloShapeBuilder
iree::compiler::Dialect::Shape::Utils::TypeConversion
iree::compiler::Utils
diff --git a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
index 6b0927a..07f59df 100644
--- a/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
+++ b/iree/compiler/Dialect/Shape/Transforms/MaterializeShapeCalculations.cpp
@@ -17,7 +17,6 @@
#include "iree/compiler/Dialect/Shape/IR/ShapeInterface.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "iree/compiler/Dialect/Shape/Plugins/VMLA/VMLAShapeBuilder.h"
#include "iree/compiler/Dialect/Shape/Plugins/XLA/XlaHloShapeBuilder.h"
#include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
#include "iree/compiler/Utils/PatternUtils.h"
@@ -48,7 +47,6 @@
static CustomOpShapeBuilderList globalBuilders = ([]() {
CustomOpShapeBuilderList builders;
mhlo::populateXlaHloCustomOpShapeBuilder(builders);
- IREE::VMLA::populateVMLACustomOpShapeBuilder(builders);
return builders;
})();
return &globalBuilders;
diff --git a/iree/compiler/Dialect/VMLA/BUILD b/iree/compiler/Dialect/VMLA/BUILD
deleted file mode 100644
index 34b653c..0000000
--- a/iree/compiler/Dialect/VMLA/BUILD
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//build_tools/embed_data:build_defs.bzl", "c_embed_data")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-c_embed_data(
- name = "vmla_imports",
- srcs = ["vmla.imports.mlir"],
- c_file_output = "vmla.imports.c",
- flatten = True,
- h_file_output = "vmla.imports.h",
- identifier = "iree_vmla_imports",
-)
diff --git a/iree/compiler/Dialect/VMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/CMakeLists.txt
deleted file mode 100644
index 81371fc..0000000
--- a/iree/compiler/Dialect/VMLA/CMakeLists.txt
+++ /dev/null
@@ -1,28 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_c_embed_data(
- NAME
- vmla_imports
- SRCS
- "vmla.imports.mlir"
- C_FILE_OUTPUT
- "vmla.imports.c"
- H_FILE_OUTPUT
- "vmla.imports.h"
- IDENTIFIER
- "iree_vmla_imports"
- FLATTEN
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/BUILD b/iree/compiler/Dialect/VMLA/Conversion/BUILD
deleted file mode 100644
index d1e955d..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Conversion",
- srcs = [
- "ConversionTarget.cpp",
- "TypeConverter.cpp",
- ],
- hdrs = [
- "ConversionTarget.h",
- "TypeConverter.h",
- ],
- deps = [
- "//iree/compiler/Dialect/IREE/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/VMLA/IR",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:TensorDialect",
- "@llvm-project//mlir:Transforms",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
deleted file mode 100644
index b44f51c..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/CMakeLists.txt
+++ /dev/null
@@ -1,33 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Conversion
- HDRS
- "ConversionTarget.h"
- "TypeConverter.h"
- SRCS
- "ConversionTarget.cpp"
- "TypeConverter.cpp"
- DEPS
- MLIRIR
- MLIRStandard
- MLIRTensor
- MLIRTransforms
- iree::compiler::Dialect::IREE::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::VMLA::IR
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
deleted file mode 100644
index b12e8aa..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ /dev/null
@@ -1,302 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATraits.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-using Shape::buildOrFindRankedShapeForValue;
-
-VMLAConversionTarget::VMLAConversionTarget(MLIRContext *context,
- TypeConverter &typeConverter)
- : ConversionTarget(*context), typeConverter(typeConverter) {
- // The VMLA dialect expects both standard ops and the VMLA ops (in case some
- // conversion has already happened).
- addLegalOp<ModuleOp>();
- addLegalDialect("vmla");
- // Pseudo-ops are illegal.
- // If we end up with a lot of these, consider using an "is pseudo" trait.
- addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
- addIllegalOp<IREE::VMLA::SortPseudoOp>();
- addIllegalOp<IREE::VMLA::FftPseudoOp>();
- addIllegalOp<IREE::VMLA::IfftPseudoOp>();
- addIllegalOp<IREE::VMLA::RfftPseudoOp>();
- addIllegalOp<IREE::VMLA::IrfftPseudoOp>();
-
- // Allow other ops to pass through so long as their type is valid (not a
- // tensor, basically).
- markUnknownOpDynamicallyLegal();
- addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- return typeConverter.isSignatureLegal(op.getType()) &&
- typeConverter.isLegal(&op.getBody());
- });
- addDynamicallyLegalOp<ConstantOp>(
- [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
- addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
- return llvm::all_of(op.getOperandTypes(), [&typeConverter](Type t) {
- return typeConverter.isLegal(t);
- });
- });
-}
-
-bool VMLAConversionTarget::isDynamicallyLegal(Operation *op) const {
- // Short-circuit test that bails on the first illegal type.
- const auto isTypeIllegal = [&](Type type) {
- return !typeConverter.isLegal(type);
- };
- return !(llvm::any_of(op->getOperandTypes(), isTypeIllegal) ||
- llvm::any_of(op->getResultTypes(), isTypeIllegal));
-}
-
-static Attribute convertAttribute(Attribute srcAttribute) {
- auto *context = srcAttribute.getContext();
- Type attrType = srcAttribute.getType();
- auto elementsAttr = srcAttribute.dyn_cast<ElementsAttr>();
- auto tensorType = attrType.dyn_cast<RankedTensorType>();
- auto indexType = IndexType::get(context);
- auto i64Type = IntegerType::get(context, 64);
- // Detect and convert index and i64 tensor attributes to i32 since these
- // invariably must be imported as some kind of VM constant, and the VM is
- // 32bit only.
- // TODO(laurenzo): Remove the i64 match once the HLO ops are defined in terms
- // of index for shape components (vs i64).
- if (elementsAttr && tensorType &&
- (tensorType.getElementType() == i64Type ||
- tensorType.getElementType() == indexType)) {
- auto i32Type = IntegerType::get(context, 32);
- using func_type = APInt(const APInt &);
- return elementsAttr.mapValues(
- i32Type, llvm::function_ref<func_type>([](const APInt &in) -> APInt {
- int64_t inValue = in.getSExtValue();
- return APInt(32, inValue, true);
- }));
- }
-
- return srcAttribute;
-}
-
-// static
-LogicalResult VMLAConversionTarget::applyDefaultBufferRewrite(
- Operation *srcOp, ArrayRef<Value> operands, VMLAOpSemantics semantics,
- StringRef dstOpName, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
- OperationState state{srcOp->getLoc(), dstOpName};
- for (auto srcAttrPair : srcOp->getAttrs()) {
- state.addAttribute(srcAttrPair.first, convertAttribute(srcAttrPair.second));
- }
-
- auto *dstOperation = state.name.getAbstractOperation();
- auto *opInterface = dstOperation->getInterface<IREE::VMLA::VMLAOp>();
-
- // Allow the op to get at any of the type information it requires. For
- // example, if the op may later need to know the type of the elements in a
- // type-erased buffer it can stash the original tensor type as an attribute.
- if (opInterface) {
- auto convertTensorType = [](Type type) -> Type {
- if (auto tensorType = type.dyn_cast<TensorType>()) {
- return VMLATypeConverter::convertTensorTypeToVMLAType(tensorType);
- }
- return type;
- };
- auto operandTypes = llvm::to_vector<4>(
- llvm::map_range(srcOp->getOperandTypes(),
- [&](Type type) { return convertTensorType(type); }));
- auto resultTypes = llvm::to_vector<4>(
- llvm::map_range(srcOp->getResultTypes(),
- [&](Type type) { return convertTensorType(type); }));
- opInterface->extractTypeAttributes(state, operandTypes, resultTypes);
- }
-
- // Until MLIR supports unsigned types we need to sidechannel this to the
- // VMLA->VM conversion that really needs to know.
- switch (semantics) {
- default:
- break;
- case VMLAOpSemantics::kForceUnsigned:
- state.addAttribute("force_unsigned", UnitAttr::get(srcOp->getContext()));
- break;
- }
-
- // Add all input operands.
- for (auto srcDstOperand : llvm::zip(srcOp->getOperands(), operands)) {
- auto srcOperand = std::get<0>(srcDstOperand);
- auto dstOperand = std::get<1>(srcDstOperand);
- if (auto tensorType =
- srcOperand.getType().template dyn_cast<TensorType>()) {
- // Some ops also require shape information.
- state.addOperands({dstOperand});
- if (dstOperation->hasTrait<OpTrait::IREE::VMLA::IncludeShapes>()) {
- Value operandShape = getTensorShape(srcOp->getLoc(), srcOperand,
- typeConverter, rewriter);
- if (!operandShape) {
- return srcOp->emitError() << "failed to get operand tensor shape";
- }
- state.addOperands({operandShape});
- }
- } else {
- // Normal pass-through operand.
- state.addOperands({dstOperand});
- }
- }
-
- // Allocate output buffers for tensors returned by the op. We'll append these
- // to the operands in order (as is convention here).
- SmallVector<Value, 4> allocatedBuffers;
- for (auto srcResult : srcOp->getResults()) {
- if (auto tensorType = srcResult.getType().template dyn_cast<TensorType>()) {
- auto dstBuffer = allocateOutputBuffer(srcOp->getLoc(), srcResult,
- typeConverter, rewriter);
- if (!dstBuffer) {
- return srcOp->emitError()
- << "failed to allocate output buffer for tensor result";
- }
- state.addOperands({dstBuffer});
- allocatedBuffers.push_back(dstBuffer);
- if (dstOperation->hasTrait<OpTrait::IREE::VMLA::IncludeShapes>()) {
- Value resultShape =
- getTensorShape(srcOp->getLoc(), srcResult, typeConverter, rewriter);
- if (!resultShape) {
- return srcOp->emitError() << "failed to get operand tensor shape";
- }
- state.addOperands({resultShape});
- }
- } else {
- // Normal pass-through result.
- state.addTypes({srcResult.getType()});
- }
- }
-
- // Rebuild the result list and replace the op ensuring that all original op
- // results are represented in order even if we changed them to out params.
- auto *dstOp = rewriter.createOperation(state);
- auto dstResults = llvm::to_vector<4>(dstOp->getResults());
- SmallVector<Value, 4> resultValues;
- for (auto resultType : srcOp->getResultTypes()) {
- if (resultType.template isa<TensorType>()) {
- resultValues.push_back(allocatedBuffers.front());
- allocatedBuffers.erase(allocatedBuffers.begin());
- } else {
- resultValues.push_back(dstResults.front());
- dstResults.erase(dstResults.begin());
- }
- }
- rewriter.replaceOp(srcOp, resultValues);
- return success();
-}
-
-// static
-Value VMLAConversionTarget::getTensorShape(
- Location loc, Value originalValue, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
- return buildOrFindRankedShapeForValue(loc, originalValue,
- rewriter.getIndexType(), rewriter);
-}
-
-// static
-Value VMLAConversionTarget::getBufferOffset(
- Location loc, Value tensorValue, Value indicesValue,
- TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
- auto indicesType = indicesValue.getType().cast<ShapedType>();
- SmallVector<Value, 4> indices(indicesType.getNumElements());
- for (int i = 0; i < indicesType.getNumElements(); ++i) {
- auto extractIndex = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, i);
- indices[i] = rewriter.createOrFold<mlir::tensor::ExtractOp>(
- loc, indicesValue, ValueRange{extractIndex});
- }
- return getBufferOffset(loc, tensorValue, indices, typeConverter, rewriter);
-}
-
-// static
-Value VMLAConversionTarget::getBufferOffset(
- Location loc, Value tensorValue, ValueRange indices,
- TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
- // Element type byte length as the base.
- auto tensorType = tensorValue.getType().cast<ShapedType>();
- auto elementType = tensorType.getElementType();
- auto elementSize = rewriter.createOrFold<mlir::ConstantIndexOp>(
- loc, VMLATypeConverter::getRoundedElementByteWidth(elementType));
-
- auto shape = getTensorShape(loc, tensorValue, typeConverter, rewriter);
- if (!shape) {
- return nullptr;
- }
- Value offset = rewriter.createOrFold<mlir::ConstantIndexOp>(loc, 0);
- for (int i = 0; i < tensorType.getRank(); ++i) {
- auto axisOffset = indices[i];
- for (int j = i + 1; j < tensorType.getRank(); ++j) {
- auto dim = rewriter.createOrFold<Shape::RankedDimOp>(
- loc, rewriter.getIntegerType(32), shape, j);
- axisOffset = rewriter.createOrFold<mlir::MulIOp>(loc, axisOffset, dim);
- }
- offset = rewriter.createOrFold<mlir::AddIOp>(loc, offset, axisOffset);
- }
- return rewriter.createOrFold<mlir::MulIOp>(loc, offset, elementSize);
-}
-
-// static
-Value VMLAConversionTarget::getBufferLength(
- Location loc, Value tensorValue, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
- // Element type byte length as the base.
- auto tensorType = tensorValue.getType().cast<ShapedType>();
- auto elementType = tensorType.getElementType();
- auto elementSize = rewriter.createOrFold<mlir::ConstantIndexOp>(
- loc, VMLATypeConverter::getRoundedElementByteWidth(elementType));
-
- auto shape = getTensorShape(loc, tensorValue, typeConverter, rewriter);
- if (!shape) return nullptr;
- auto dims =
- rewriter.create<Shape::RankedDimsOp>(loc, rewriter.getIndexType(), shape);
- Value length = elementSize;
- for (auto dim : dims.getResults()) {
- length = rewriter.createOrFold<mlir::MulIOp>(loc, length, dim);
- }
- return length;
-}
-
-// static
-Value VMLAConversionTarget::allocateOutputBuffer(
- Location loc, Value originalValue, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
- // Compute the required buffer size. Since we are always dense (right now)
- // this is just normal x*y*z*...
- Value byteLength =
- getBufferLength(loc, originalValue, typeConverter, rewriter);
- if (!byteLength) {
- return nullptr;
- }
-
- // Allocate the buffer of the required size.
- // The caller can then use the buffer instead of the original SSA value.
- return rewriter.createOrFold<IREE::VMLA::BufferAllocOp>(
- loc, IREE::VMLA::BufferType::get(rewriter.getContext()), byteLength);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
deleted file mode 100644
index 5d28d67..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h
+++ /dev/null
@@ -1,104 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_CONVERSION_CONVERSIONTARGET_H_
-#define IREE_COMPILER_DIALECT_VMLA_CONVERSION_CONVERSIONTARGET_H_
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-enum class VMLAOpSemantics {
- kDefault = 0,
- // Forces integers to be treated as unsigned integers.
- kForceUnsigned,
-};
-
-// A conversion target for the VMLA dialect that ensures that tensor types are
-// fully removed. Conversions targeting the VMLA dialect should always use this.
-class VMLAConversionTarget : public ConversionTarget {
- public:
- VMLAConversionTarget(MLIRContext *context, TypeConverter &typeConverter);
-
- // Attempts to rewrite an op that may use tensor values into an op using VMLA
- // buffers. See VMLAOpConversion for more information.
- static LogicalResult applyDefaultBufferRewrite(
- Operation *srcOp, ArrayRef<Value> operands, VMLAOpSemantics semantics,
- StringRef dstOpName, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter);
-
- // Returns the shape of the |originalValue| tensor as an SSA ranked shape.
- static Value getTensorShape(Location loc, Value originalValue,
- TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter);
-
- // Returns the offset, in bytes, of an index within a linearized dense buffer.
- static Value getBufferOffset(Location loc, Value tensorValue,
- Value indicesValue, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter);
- static Value getBufferOffset(Location loc, Value tensorValue,
- ValueRange indices, TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter);
-
- // Returns the length, in bytes, of a linearized dense buffer.
- static Value getBufferLength(Location loc, Value tensorValue,
- TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter);
-
- // Allocates a VMLA buffer for an output operand of an op.
- // Returns a buffer allocated with the appropriate size for storing the value.
- // Callers must replace uses of |originalValue| with the returned value.
- static Value allocateOutputBuffer(Location loc, Value originalValue,
- TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter);
-
- private:
- bool isDynamicallyLegal(Operation *op) const override;
-
- TypeConverter &typeConverter;
-};
-
-// VMLA tensor-to-buffer conversion utility.
-// This can be used by dialects to model custom op conversion from a dialect
-// that uses the MLIR tensor type to the IREE VMLA buffer type. At this point
-// during conversion the source values will be TensorType and the target values
-// will be IREE::VMLA::BufferTypes. Any static information available about the
-// tensor (such as static dimensions, element type, layout, etc) are extracted
-// here and lowered as expanded values.
-template <typename SRC, typename DST,
- VMLAOpSemantics semantics = VMLAOpSemantics::kDefault>
-class VMLAOpConversion : public OpConversionPattern<SRC> {
- public:
- using OpConversionPattern<SRC>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- SRC srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- return VMLAConversionTarget::applyDefaultBufferRewrite(
- srcOp, operands, semantics, DST::getOperationName(),
- *this->getTypeConverter(), rewriter);
- }
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_CONVERSION_CONVERSIONTARGET_H_
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/BUILD
deleted file mode 100644
index 0f4e3fb..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/BUILD
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "HALToVMLA",
- srcs = [
- "ConvertHALToVMLA.cpp",
- ],
- hdrs = [
- "ConvertHALToVMLA.h",
- ],
- deps = [
- "//iree/compiler/Dialect/HAL/IR",
- "//iree/compiler/Dialect/IREE/IR",
- "//iree/compiler/Dialect/VMLA/Conversion",
- "//iree/compiler/Dialect/VMLA/IR",
- "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Transforms",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/CMakeLists.txt
deleted file mode 100644
index 1a7f476..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/CMakeLists.txt
+++ /dev/null
@@ -1,33 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- HALToVMLA
- HDRS
- "ConvertHALToVMLA.h"
- SRCS
- "ConvertHALToVMLA.cpp"
- DEPS
- MLIRIR
- MLIRPass
- MLIRStandard
- MLIRTransforms
- iree::compiler::Dialect::HAL::IR
- iree::compiler::Dialect::IREE::IR
- iree::compiler::Dialect::VMLA::Conversion
- iree::compiler::Dialect::VMLA::IR
- iree::compiler::Dialect::VMLA::IR::VMLADialect
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp
deleted file mode 100644
index 6a1d27f..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp
+++ /dev/null
@@ -1,156 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.h"
-
-#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct InterfaceOpEraser : public OpConversionPattern<IREE::HAL::InterfaceOp> {
- using OpConversionPattern<IREE::HAL::InterfaceOp>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceOp interfaceOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.eraseOp(interfaceOp);
- return success();
- }
-};
-
-struct InterfaceLoadConstantOpConversion
- : public OpConversionPattern<IREE::HAL::InterfaceLoadConstantOp> {
- InterfaceLoadConstantOpConversion(MLIRContext *context,
- TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceLoadConstantOp loadOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- // Find the vmla.interface argument to the function.
- auto interfaceArg = loadOp->getParentOfType<FuncOp>().getArgument(0);
- assert(interfaceArg &&
- interfaceArg.getType().isa<IREE::VMLA::InterfaceType>() &&
- "exported VMLA functions require vmla.interface ops as their only "
- "argument");
-
- IREE::HAL::InterfaceLoadConstantOp::Adaptor newOperands(operands);
- rewriter.replaceOpWithNewOp<IREE::VMLA::InterfaceConstOp>(
- loadOp, typeConverter.convertType(loadOp.getResult().getType()),
- interfaceArg, loadOp.offsetAttr());
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct InterfaceLoadTensorOpConversion
- : public OpConversionPattern<IREE::HAL::InterfaceLoadTensorOp> {
- InterfaceLoadTensorOpConversion(MLIRContext *context,
- TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceLoadTensorOp loadOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- // Find the vmla.interface argument to the function.
- auto interfaceArg = loadOp->getParentOfType<FuncOp>().getArgument(0);
- assert(interfaceArg &&
- interfaceArg.getType().isa<IREE::VMLA::InterfaceType>() &&
- "exported VMLA functions require vmla.interface ops as their only "
- "argument");
- auto bindingOp = loadOp.queryBindingOp();
-
- IREE::HAL::InterfaceLoadTensorOp::Adaptor newOperands(operands);
- auto bufferOp = rewriter.create<IREE::VMLA::InterfaceBindingOp>(
- loadOp.getLoc(), IREE::VMLA::BufferType::get(loadOp.getContext()),
- interfaceArg, bindingOp.set().getZExtValue(),
- bindingOp.binding().getZExtValue());
- auto byteLengthValue = VMLAConversionTarget::getBufferLength(
- loadOp.getLoc(), loadOp.result(), typeConverter, rewriter);
- if (!byteLengthValue) return failure();
- rewriter.replaceOpWithNewOp<IREE::VMLA::BufferViewOp>(
- loadOp, IREE::VMLA::BufferType::get(loadOp.getContext()),
- bufferOp.result(), newOperands.offset(), byteLengthValue);
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct InterfaceStoreTensorOpConversion
- : public OpConversionPattern<IREE::HAL::InterfaceStoreTensorOp> {
- InterfaceStoreTensorOpConversion(MLIRContext *context,
- TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::HAL::InterfaceStoreTensorOp storeOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- // Find the vmla.interface argument to the function.
- auto interfaceArg = storeOp->getParentOfType<FuncOp>().getArgument(0);
- assert(interfaceArg.getType().isa<IREE::VMLA::InterfaceType>() &&
- "exported VMLA functions require vmla.interface ops as their only "
- "argument");
- auto bindingOp = storeOp.queryBindingOp();
-
- IREE::HAL::InterfaceStoreTensorOp::Adaptor newOperands(operands);
- auto bufferOp = rewriter.create<IREE::VMLA::InterfaceBindingOp>(
- storeOp.getLoc(), IREE::VMLA::BufferType::get(storeOp.getContext()),
- interfaceArg, bindingOp.set().getZExtValue(),
- bindingOp.binding().getZExtValue());
-
- auto zeroValue =
- rewriter.createOrFold<mlir::ConstantIndexOp>(storeOp.getLoc(), 0);
- auto byteLengthValue = VMLAConversionTarget::getBufferLength(
- storeOp.getLoc(), storeOp.operand(), typeConverter, rewriter);
- rewriter.create<IREE::VMLA::BufferCopyOp>(
- storeOp.getLoc(), newOperands.operand(), zeroValue, bufferOp,
- newOperands.offset(), byteLengthValue);
- rewriter.replaceOp(storeOp, {});
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-} // namespace
-
-void populateHALToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter) {
- patterns.insert<InterfaceOpEraser>(context);
- patterns.insert<InterfaceLoadConstantOpConversion>(context, typeConverter);
- patterns.insert<InterfaceLoadTensorOpConversion>(context, typeConverter);
- patterns.insert<InterfaceStoreTensorOpConversion>(context, typeConverter);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.h b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.h
deleted file mode 100644
index ef9c766..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_CONVERSION_HALTOVMLA_CONVERTHALTOVMLA_H_
-#define IREE_COMPILER_DIALECT_VMLA_CONVERSION_HALTOVMLA_CONVERTHALTOVMLA_H_
-
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Populates conversion patterns from the IREE HAL dialect interface to the
-// VMLA dialect interface.
-void populateHALToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_CONVERSION_HALTOVMLA_CONVERTHALTOVMLA_H_
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD
deleted file mode 100644
index 40dc09d..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- ["interface_ops.mlir"],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt
deleted file mode 100644
index 42cc9b2..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "interface_ops.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/interface_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/interface_ops.mlir
deleted file mode 100644
index b8aa7b2..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/test/interface_ops.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: func @inc_rgn_dispatch_0
-// CHECK-SAME: (%[[INTERFACE:.+]]: !vmla.interface
-func @inc_rgn_dispatch_0() attributes {iree.module.export} {
- // CHECK-DAG: %[[C0:.+]] = constant 0
- // CHECK-DAG: %[[C4:.+]] = constant 4
- %c0 = constant 0 : index
- // CHECK-DAG: %[[CST1:.+]] = vmla.constant dense<1.000000e+00> : tensor<f32> -> !vmla.buffer
- %cst = constant dense<1.000000e+00> : tensor<f32>
- // CHECK-NEXT: %[[SET0BINDING0:.+]] = vmla.interface.binding %[[INTERFACE]] {binding = 0 : i32, set = 0 : i32} : !vmla.buffer
- // CHECK-NEXT: %[[ARG0:.+]] = vmla.buffer.view %[[SET0BINDING0]][%[[C0]]], byte_length = %[[C4]] : !vmla.buffer
- %0 = hal.interface.load.tensor @io::@arg0, offset = %c0 : tensor<f32>
- // CHECK-NEXT: %[[TEMP:.+]] = vmla.buffer.alloc byte_length = %[[C4]] : !vmla.buffer
- // CHECK-NEXT: vmla.add %[[ARG0]], %[[CST1]], out %[[TEMP]] : f32
- %1 = mhlo.add %0, %cst : tensor<f32>
- // CHECK-NEXT: %[[SET0BINDING1:.+]] = vmla.interface.binding %[[INTERFACE]] {binding = 1 : i32, set = 0 : i32} : !vmla.buffer
- // CHECK-NEXT: vmla.buffer.copy %[[TEMP]][%[[C0]]], out %[[SET0BINDING1]][%[[C0]]], byte_length = %[[C4]]
- hal.interface.store.tensor %1, @io::@ret0, offset = %c0 : tensor<f32>
- return
-}
-func private @inc_rgn_dispatch_0_impl(%arg0: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
- %cst = constant dense<1.000000e+00> : tensor<f32>
- %0 = mhlo.add %arg0, %cst : tensor<f32>
- return %0 : tensor<f32>
-}
-hal.interface @io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
deleted file mode 100644
index 3a45201..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "HLOToVMLA",
- srcs = [
- "ConvertConvOps.cpp",
- "ConvertHLOToVMLA.cpp",
- "ConvertReductionOps.cpp",
- ],
- hdrs = [
- "ConvertHLOToVMLA.h",
- ],
- deps = [
- "//iree/compiler/Dialect/IREE/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/VMLA/Conversion",
- "//iree/compiler/Dialect/VMLA/IR",
- "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Transforms",
- "@mlir-hlo//:hlo",
- "@mlir-hlo//:legalize_to_linalg",
- "@mlir-hlo//:legalize_to_standard",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
deleted file mode 100644
index 0df1146..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/CMakeLists.txt
+++ /dev/null
@@ -1,37 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- HLOToVMLA
- HDRS
- "ConvertHLOToVMLA.h"
- SRCS
- "ConvertConvOps.cpp"
- "ConvertHLOToVMLA.cpp"
- "ConvertReductionOps.cpp"
- DEPS
- LLVMSupport
- MLIRIR
- MLIRPass
- MLIRStandard
- MLIRTransforms
- iree::compiler::Dialect::IREE::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::VMLA::Conversion
- iree::compiler::Dialect::VMLA::IR
- iree::compiler::Dialect::VMLA::IR::VMLADialect
- tensorflow::mlir_hlo
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
deleted file mode 100644
index 0eeda13..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertConvOps.cpp
+++ /dev/null
@@ -1,169 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitVector.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct VMLAConvOpConverter : public OpConversionPattern<mhlo::ConvOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mhlo::ConvOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (op.dimension_numbers()) {
- const auto dimensionNumbers = op.dimension_numbers();
- const int inputSpatialRank =
- std::distance(dimensionNumbers.input_spatial_dimensions().begin(),
- dimensionNumbers.input_spatial_dimensions().end());
-
- if (inputSpatialRank != 2) {
- op.emitWarning() << "Only lowering 2D conv is supported";
- return failure();
- }
- // Input storage order is N,spatial_dims...,Ci.
- if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
- dimensionNumbers.input_feature_dimension().getInt() !=
- (inputSpatialRank + 1)) {
- op.emitWarning()
- << "Could not lower conv op due to inconsistant storage type";
- return failure();
- }
-
- const int kernelSpatialRank =
- std::distance(dimensionNumbers.kernel_spatial_dimensions().begin(),
- dimensionNumbers.kernel_spatial_dimensions().end());
- // Filter storage order is spatial_dims...,C, Co.
- if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
- kernelSpatialRank ||
- dimensionNumbers.kernel_output_feature_dimension().getInt() !=
- (kernelSpatialRank + 1))
- return failure();
-
- const int outputSpatialRank =
- std::distance(dimensionNumbers.output_spatial_dimensions().begin(),
- dimensionNumbers.output_spatial_dimensions().end());
- // Output storage order is N,spatial_dims..,Co.
- if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
- dimensionNumbers.output_feature_dimension().getInt() !=
- (outputSpatialRank + 1))
- return failure();
-
- if (inputSpatialRank != outputSpatialRank ||
- inputSpatialRank != kernelSpatialRank)
- return failure();
-
- auto inputSpatialDim =
- dimensionNumbers.input_spatial_dimensions().begin();
- auto kernelSpatialDim =
- dimensionNumbers.kernel_spatial_dimensions().begin();
- auto outputSpatialDim =
- dimensionNumbers.output_spatial_dimensions().begin();
- // Check spatial dims are ordred correctly.
- for (int i = 0; i < inputSpatialRank; ++i) {
- const int dim = i + 1;
- if ((*inputSpatialDim++).getZExtValue() != dim ||
- (*outputSpatialDim++).getZExtValue() != dim ||
- (*kernelSpatialDim++).getZExtValue() != i)
- return failure();
- }
- }
-
- auto inputShape = VMLAConversionTarget::getTensorShape(
- op.getLoc(), op.lhs(), *getTypeConverter(), rewriter);
- auto filterShape = VMLAConversionTarget::getTensorShape(
- op.getLoc(), op.rhs(), *getTypeConverter(), rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- op.getLoc(), op.getResult(), *getTypeConverter(), rewriter);
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- op.getLoc(), op.getResult(), *getTypeConverter(), rewriter);
-
- auto lhsType =
- TypeAttr::get(op.lhs().getType().cast<ShapedType>().getElementType());
- auto rhsType =
- TypeAttr::get(op.lhs().getType().cast<ShapedType>().getElementType());
-
- SmallVector<int32_t, 4> windowStrides{1, 1};
- SmallVector<int32_t, 4> padding{0, 0, 0, 0};
- SmallVector<int32_t, 4> lhsDilation{1, 1};
- SmallVector<int32_t, 4> rhsDilation{1, 1};
- int32_t featureGroupCount = op.feature_group_count();
- int32_t batchGroupCount = op.batch_group_count();
-
- auto fill_optional = [](auto filed, SmallVector<int32_t, 4> *vec) {
- if (filed.hasValue()) {
- int index = 0;
- for (auto attribute : filed.getValue()) {
- (*vec)[index++] = attribute.getZExtValue();
- }
- }
- };
-
- fill_optional(op.window_strides(), &windowStrides);
- fill_optional(op.padding(), &padding);
- fill_optional(op.lhs_dilation(), &lhsDilation);
- fill_optional(op.rhs_dilation(), &rhsDilation);
-
- if (batchGroupCount != 1) {
- op.emitWarning() << "Batch group convoution isn't supported";
- return failure();
- }
-
- rewriter.create<IREE::VMLA::ConvOp>(
- op.getLoc(), op.lhs(), inputShape, op.rhs(), filterShape, dst, dstShape,
- rewriter.getI32VectorAttr(windowStrides),
- rewriter.getI32VectorAttr(padding),
- rewriter.getI32VectorAttr(lhsDilation),
- rewriter.getI32VectorAttr(rhsDilation),
- rewriter.getI32IntegerAttr(featureGroupCount),
- rewriter.getI32IntegerAttr(batchGroupCount), lhsType, rhsType, rhsType);
-
- rewriter.replaceOp(op, dst);
-
- return success();
- }
-};
-
-} // namespace
-
-void populateHLOConvToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter) {
- patterns.insert<VMLAConvOpConverter>(typeConverter, context);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
deleted file mode 100644
index 52b3543..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ /dev/null
@@ -1,1003 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-void populateHLOConvToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter);
-void populateHLODotToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter);
-void populateHLOReductionToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter);
-
-namespace {
-
-// Clones operand[0] and returns the result.
-// This models the value semantics of XLA. We expect previous passes to elide
-// identity ops when possible and only check for trivial single use ops here.
-template <typename SRC>
-struct IdentityOpConversion : public OpConversionPattern<SRC> {
- using OpConversionPattern<SRC>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- SRC srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- // mhlo::DynamicReshape has multiple operands, so we cannot just say
- // `getOperand()`. But `getOperand(0)` doesn't work for the other
- // single-operand ops. So use the raw Operation to get the operand.
- if (srcOp.getOperation()->getOperand(0).hasOneUse()) {
- // Can directly pass through the input buffer as we don't need to clone
- // for other users.
- rewriter.replaceOp(srcOp, operands[0]);
- return success();
- } else {
- // More than one user of the operand exist and we need to ensure they
- // keep a valid snapshot of the buffer.
- rewriter.replaceOpWithNewOp<IREE::VMLA::BufferCloneOp>(
- srcOp, IREE::VMLA::BufferType::get(rewriter.getContext()),
- operands[0]);
- return success();
- }
- }
-};
-
-// Converts a shapex.ranked_broadcast_in_dim op to either a broadcast or a tile
-// depending on the input shape.
-//
-// We assume that mhlo.broadcast_in_dim and mhlo.dynamic_broadcast_in_dim
-// have been legalized into that op.
-//
-// Note that shapex.ranked_broadcast_in_dim is not strictly speaking an HLO op,
-// but we would like HLO to eventually have something like it, and the shapex
-// dialect is currently where we have it stuffed.
-struct BroadcastInDimOpConversion
- : public OpConversionPattern<Shape::RankedBroadcastInDimOp> {
- using OpConversionPattern<Shape::RankedBroadcastInDimOp>::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- Shape::RankedBroadcastInDimOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto srcShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.operand(), *getTypeConverter(), rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
-
- auto tensorType = srcOp.operand().getType().cast<TensorType>();
- if (tensorType.getRank() == 0) {
- // Broadcast of a scalar value.
- rewriter.create<IREE::VMLA::BroadcastOp>(
- srcOp.getLoc(), operands[0], srcShape, dst, dstShape,
- TypeAttr::get(tensorType.getElementType()));
- } else {
- // Tiling a non-scalar value by first broadcasting the shape to
- // include degenerate dimensions that tile will duplicate.
- auto dstRsType = dstShape.getType().dyn_cast<Shape::RankedShapeType>();
- if (!dstRsType) {
- srcOp.emitWarning() << "currently only operates on ranked tensors";
- return failure();
- }
- SmallVector<int64_t, 4> broadcastDims;
- if (srcOp.broadcast_dimensions()) {
- auto srcBroadcastDims = srcOp.broadcast_dimensions();
- for (const auto &broadcastDim : srcBroadcastDims) {
- broadcastDims.push_back(broadcastDim.getSExtValue());
- }
- }
-
- auto broadcastedShape = Shape::buildDegenerateBroadcastRankedShape(
- srcShape, dstRsType.getRank(), broadcastDims, rewriter);
- if (!broadcastedShape) {
- srcOp.emitWarning("unsupported shape type for degenerate broadcast");
- return failure();
- }
- rewriter.create<IREE::VMLA::TileOp>(
- srcOp.getLoc(), operands[0], broadcastedShape, dst, dstShape,
- TypeAttr::get(tensorType.getElementType()));
- }
-
- rewriter.replaceOp(srcOp, {dst});
- return success();
- }
-};
-
-struct IotaOpConversion : public OpConversionPattern<Shape::IotaOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(
- Shape::IotaOp op, ArrayRef<Value> operandValues,
- ConversionPatternRewriter &rewriter) const override {
- auto resultTy = op.getResult().getType().cast<ShapedType>();
-
- int32_t elementSize = VMLATypeConverter::getRoundedElementByteWidth(
- resultTy.getElementType());
- auto elementSizeValue =
- rewriter.createOrFold<mlir::ConstantIndexOp>(op.getLoc(), elementSize);
-
- auto shapeDim0 = rewriter.createOrFold<Shape::RankedDimOp>(
- op.getLoc(), rewriter.getIndexType(), op.getOperand(),
- rewriter.getI64IntegerAttr(0));
-
- auto bufferSize = rewriter.createOrFold<mlir::MulIOp>(
- op.getLoc(), elementSizeValue, shapeDim0);
-
- auto dst = rewriter.createOrFold<IREE::VMLA::BufferAllocOp>(
- op.getLoc(), IREE::VMLA::BufferType::get(rewriter.getContext()),
- bufferSize);
-
- rewriter.createOrFold<IREE::VMLA::IotaOp>(
- op.getLoc(), dst, TypeAttr::get(resultTy.getElementType()));
- rewriter.replaceOp(op, {dst});
-
- return success();
- }
-};
-
-struct CanonicalizeBroadcastOp : public OpRewritePattern<mhlo::BroadcastOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(mhlo::BroadcastOp op,
- PatternRewriter &rewriter) const override {
- SmallVector<int64_t, 6> broadcastDimensions;
- RankedTensorType inputType =
- op.getOperand().getType().cast<RankedTensorType>();
- RankedTensorType outputType =
- op.getResult().getType().cast<RankedTensorType>();
- for (int outputDim = outputType.getRank() - inputType.getRank(),
- outputRank = outputType.getRank();
- outputDim < outputRank; outputDim++) {
- broadcastDimensions.push_back(outputDim);
- }
- // TODO(silvasean): move this helper to DenseIntElementsAttr.
- auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
- auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
- rewriter.getIntegerType(64));
- return DenseIntElementsAttr::get(type, integers);
- };
- rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
- op, op.getType(), op.getOperand(),
- make1DElementsAttr(broadcastDimensions));
- return success();
- }
-};
-
-// Converts a concat into a set of copies into the destination buffer.
-struct ConcatenateOpConversion
- : public OpConversionPattern<mhlo::ConcatenateOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mhlo::ConcatenateOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto zero = rewriter.createOrFold<mlir::ConstantIndexOp>(srcOp.getLoc(), 0);
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
-
- auto finalType = srcOp.getResult().getType().cast<TensorType>();
- int rank = finalType.getRank();
- llvm::SmallVector<Value, 4> srcIndices(rank, zero);
- llvm::SmallVector<Value, 4> dstIndices(rank, zero);
- auto concatDimension = srcOp.dimension();
- for (auto srcDstOperand : llvm::zip(srcOp.val(), operands)) {
- Value tensorOperand, bufferOperand;
- std::tie(tensorOperand, bufferOperand) = srcDstOperand;
-
- auto srcShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), tensorOperand, *getTypeConverter(), rewriter);
- SmallVector<Value, 4> lengths(rank);
- for (int i = 0; i < rank; ++i) {
- lengths[i] = rewriter.createOrFold<Shape::RankedDimOp>(
- srcOp.getLoc(), rewriter.getIndexType(), srcShape, i);
- }
-
- rewriter.create<IREE::VMLA::CopyOp>(
- srcOp.getLoc(), bufferOperand, srcShape, srcIndices, dst, dstShape,
- dstIndices, lengths,
- TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
-
- dstIndices[concatDimension] = rewriter.createOrFold<mlir::AddIOp>(
- srcOp.getLoc(), dstIndices[concatDimension],
- lengths[concatDimension]);
- }
-
- rewriter.replaceOp(srcOp, {dst});
- return success();
- }
-};
-
-// Lowers a subset of gathers along axis 0 that are really just a slice and
-// reshape.
-// TODO(ataei): Move this to vmla.gather lowering.
-struct GatherOpConversion : public OpConversionPattern<mhlo::GatherOp> {
- using OpConversionPattern::OpConversionPattern;
-
- // TODO(gcmn): This only handles a minimal number of cases. When XLA
- // redefines gather to be simpler, lower it properly.
- LogicalResult matchAndRewrite(
- mhlo::GatherOp gatherOp, ArrayRef<Value> operandValues,
- ConversionPatternRewriter &rewriter) const override {
- mhlo::GatherOp::Adaptor operands(operandValues);
- auto dimension_numbers = gatherOp.dimension_numbers();
- if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
- gatherOp.emitRemark()
- << "couldn't lower gather with index_vector_dim != 0";
- return failure();
- }
- if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
- dimension_numbers.start_index_map()
- .getValue(0)
- .cast<IntegerAttr>()
- .getValue() != 0) {
- gatherOp.emitRemark()
- << "couldn't lower gather with start_index_map != [0]";
- return failure();
- }
- if (dimension_numbers.collapsed_slice_dims().getType().getRank() != 1 ||
- dimension_numbers.collapsed_slice_dims()
- .getValue(0)
- .cast<IntegerAttr>()
- .getValue() != 0) {
- gatherOp.emitRemark()
- << "couldn't lower gather with collapsed_dims != [0]";
- return failure();
- }
-
- auto resultType = gatherOp.getResult().getType().cast<RankedTensorType>();
- if (dimension_numbers.offset_dims().getType().getNumElements() !=
- resultType.getRank()) {
- gatherOp.emitRemark() << "couldn't lower gather with offset_dims != "
- "[0,...,rank of output]";
- return failure();
- }
- for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
- if (it.index() != it.value()) {
- gatherOp.emitRemark() << "couldn't lower gather with offset_dims != "
- "[0,...,rank of output]";
- return failure();
- }
- }
-
- for (auto it : llvm::enumerate(resultType.getShape())) {
- if (gatherOp.slice_sizes()
- .getValue(it.index() + 1)
- .cast<IntegerAttr>()
- .getValue() != it.value()) {
- gatherOp.emitRemark()
- << "couldn't lower gather with slice_sizes not [1] + final shape";
- return failure();
- }
- }
-
- auto srcShape = VMLAConversionTarget::getTensorShape(
- gatherOp.getLoc(), gatherOp.operand(), *getTypeConverter(), rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- gatherOp.getLoc(), gatherOp.getResult(), *getTypeConverter(), rewriter);
-
- auto srcRsType = srcShape.getType().dyn_cast<Shape::RankedShapeType>();
- if (!srcRsType) {
- gatherOp.emitWarning() << "currently only operates on ranked tensors";
- return failure();
- }
-
- // Broadcast the dst shape to the src rank by prepending degenerate
- // dimensions.
- SmallVector<int64_t, 1> emptyBroadcastDims;
- dstShape = Shape::buildDegenerateBroadcastRankedShape(
- dstShape, srcRsType.getRank(), emptyBroadcastDims, rewriter);
- if (!dstShape) {
- gatherOp.emitWarning("unsupported shape type for degenerate broadcast");
- return failure();
- }
-
- auto inputType = gatherOp.operand().getType().cast<RankedTensorType>();
- auto startIndicesType =
- gatherOp.start_indices().getType().cast<ShapedType>();
- int rank = inputType.getRank();
- SmallVector<Value, 4> srcIndices(rank);
- SmallVector<Value, 4> dstIndices(rank);
- SmallVector<Value, 4> lengths(rank);
- Value zero =
- rewriter.createOrFold<mlir::ConstantIndexOp>(gatherOp.getLoc(), 0);
- for (int i = 0; i < rank; ++i) {
- if (i < startIndicesType.getNumElements()) {
- auto srcIndexByteOffset = rewriter.createOrFold<mlir::ConstantIndexOp>(
- gatherOp.getLoc(), i * sizeof(int32_t));
- srcIndices[i] = rewriter.createOrFold<IndexCastOp>(
- gatherOp.getLoc(), rewriter.getIndexType(),
- rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
- gatherOp.getLoc(), rewriter.getIntegerType(32),
- operands.start_indices(), srcIndexByteOffset));
- } else {
- // Pad missing dimensions to zero offsets.
- srcIndices[i] = zero;
- }
- dstIndices[i] = zero;
- lengths[i] = rewriter.createOrFold<mlir::ConstantIndexOp>(
- gatherOp.getLoc(),
- gatherOp.slice_sizes().getValue<int64_t>({static_cast<uint64_t>(i)}));
- }
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- gatherOp.getLoc(), gatherOp.getResult(), *getTypeConverter(), rewriter);
- rewriter.create<IREE::VMLA::CopyOp>(
- gatherOp.getLoc(), operands.operand(), srcShape, srcIndices, dst,
- dstShape, dstIndices, lengths,
- TypeAttr::get(inputType.getElementType()));
- rewriter.replaceOp(gatherOp, {dst});
- return success();
- }
-};
-
-// Converts a static slice op to a copy (if the source must be preserved).
-struct SliceOpConversion : public OpConversionPattern<mhlo::SliceOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mhlo::SliceOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto isNotOne = [](APInt stride) { return stride != 1; };
- if (llvm::any_of(srcOp.strides(), isNotOne)) {
- srcOp.emitWarning()
- << "Could not lower slice op with non-singular strides";
- return failure();
- }
-
- // TODO(benvanik): if the source is only used by this op then replace with
- // a vmla.buffer.view op.
-
- auto srcShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.operand(), *getTypeConverter(), rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
-
- int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
- SmallVector<Value, 4> srcIndices(rank);
- SmallVector<Value, 4> dstIndices(rank);
- SmallVector<Value, 4> lengths(rank);
- Value zero =
- rewriter.createOrFold<mlir::ConstantIndexOp>(srcOp.getLoc(), 0);
- for (int i = 0; i < rank; ++i) {
- uint64_t ui = static_cast<uint64_t>(i);
- srcIndices[i] = rewriter.createOrFold<mlir::ConstantIndexOp>(
- srcOp.getLoc(), srcOp.start_indices().getValue<int64_t>({ui}));
- dstIndices[i] = zero;
- lengths[i] = rewriter.createOrFold<mlir::ConstantIndexOp>(
- srcOp.getLoc(), srcOp.limit_indices().getValue<int64_t>({ui}) -
- srcOp.start_indices().getValue<int64_t>({ui}));
- }
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- rewriter.create<IREE::VMLA::CopyOp>(
- srcOp.getLoc(), operands[0], srcShape, srcIndices, dst, dstShape,
- dstIndices, lengths,
- TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
- rewriter.replaceOp(srcOp, {dst});
- return success();
- }
-};
-
-// This lowering converts a subset of the XLA ScatterOp to a VMLA equivalent.
-// If attempts to handle the simplest case with updates along a single preceding
-// batch dimension and support both scattered updates per-element and per-slice.
-// It does not support swizzling / reordering slices.
-struct ScatterOpConversion : public OpConversionPattern<mhlo::ScatterOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mhlo::ScatterOp scatterOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto dimension_numbers = scatterOp.scatter_dimension_numbers();
- int rank = scatterOp.getType().cast<ShapedType>().getRank();
-
- // We should only update along the upper most batch dimension.
- if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 1) {
- scatterOp.emitRemark()
- << "couldn't lower scatter with index_vector_dim != 1";
- return failure();
- }
-
- if (dimension_numbers.scatter_dims_to_operand_dims().getType().getRank() !=
- 1) {
- return rewriter.notifyMatchFailure(
- scatterOp,
- "couldn't lower scatter with scatter_dims_to_operand_dims with non "
- "rank-1");
- }
-
- // We assume the scatter to operand dims occurs in a normal order. If
- // support for other orders is needed a transpose should be processed on
- // the update.
- for (auto pair : llvm::enumerate(
- dimension_numbers.scatter_dims_to_operand_dims().getIntValues())) {
- if (pair.index() != pair.value()) {
- return rewriter.notifyMatchFailure(
- scatterOp,
- "couldn't lower scatter with scatter_dims_to_operand_dims "
- "!= [0, 1, ..., n]");
- }
- }
-
- if (dimension_numbers.inserted_window_dims().getType().getRank() != 1) {
- return rewriter.notifyMatchFailure(
- scatterOp,
- "couldn't lower scatter with inserted_window_dims with non rank-1");
- }
-
- // Inserted window dims only occurs in normal order and all sources should
- // only support these values.
- for (auto pair : llvm::enumerate(
- dimension_numbers.inserted_window_dims().getIntValues())) {
- if (pair.index() != pair.value()) {
- return rewriter.notifyMatchFailure(
- scatterOp,
- "couldn't lower scatter with inserted_window_dims "
- "!= [0, 1, ..., n]");
- }
- }
-
- if (dimension_numbers.update_window_dims().getType().getRank() != 1) {
- return rewriter.notifyMatchFailure(
- scatterOp,
- "couldn't lower scatter with update_window_dims with non rank-1");
- }
-
- for (auto pair : llvm::enumerate(
- dimension_numbers.update_window_dims().getIntValues())) {
- if ((pair.index() + 1) != pair.value()) {
- return rewriter.notifyMatchFailure(
- scatterOp,
- "couldn't lower scatter with update_window_dims != [1, 2, ..., n]");
- }
- }
-
- auto src = scatterOp.operand();
- auto indices = scatterOp.scatter_indices();
- auto update = scatterOp.updates();
- auto result = scatterOp.getResult();
-
- auto srcShape = VMLAConversionTarget::getTensorShape(
- scatterOp.getLoc(), src, *getTypeConverter(), rewriter);
- auto indicesShape = VMLAConversionTarget::getTensorShape(
- scatterOp.getLoc(), indices, *getTypeConverter(), rewriter);
- auto updateShape = VMLAConversionTarget::getTensorShape(
- scatterOp.getLoc(), update, *getTypeConverter(), rewriter);
- auto resultShape = VMLAConversionTarget::getTensorShape(
- scatterOp.getLoc(), result, *getTypeConverter(), rewriter);
-
- auto dstShape = VMLAConversionTarget::getTensorShape(
- scatterOp.getLoc(), scatterOp.getResult(), *getTypeConverter(),
- rewriter);
-
- SmallVector<Value, 4> lengths(rank);
- for (int i = 0; i < rank; ++i) {
- lengths[i] = rewriter.createOrFold<Shape::RankedDimOp>(
- scatterOp.getLoc(), rewriter.getIndexType(), srcShape, i);
- }
-
- // Verify the update computation is only a scatter write and does not
- // perform an other update.
- // TODO(suderman): Handle other numeric updates.
- auto &firstBlock = scatterOp.getRegion().front();
- if (!isa<mhlo::ReturnOp>(firstBlock.front()) ||
- firstBlock.front().getOperand(0) != firstBlock.getArgument(1)) {
- return rewriter.notifyMatchFailure(
- scatterOp, "scatter update is not solely a write.");
- }
-
- // Copy the source contents. The copy can be optimized in the future.
- Value zero =
- rewriter.createOrFold<mlir::ConstantIndexOp>(scatterOp.getLoc(), 0);
- llvm::SmallVector<Value, 4> srcOffset(rank, zero);
- llvm::SmallVector<Value, 4> dstOffset(rank, zero);
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- scatterOp.getLoc(), scatterOp.getResult(), *getTypeConverter(),
- rewriter);
- rewriter.create<IREE::VMLA::CopyOp>(
- scatterOp.getLoc(), src, resultShape, srcOffset, dst, dstShape,
- dstOffset, lengths,
- TypeAttr::get(scatterOp.getType().cast<ShapedType>().getElementType()));
-
- rewriter.create<IREE::VMLA::ScatterOp>(
- scatterOp.getLoc(), update, updateShape, indices, indicesShape, dst,
- dstShape,
- TypeAttr::get(scatterOp.getType().cast<ShapedType>().getElementType()));
-
- rewriter.replaceOp(scatterOp, {dst});
-
- return success();
- }
-};
-
-// Converts a dynamic slice op to a copy (if the source must be preserved).
-struct DynamicSliceOpConversion
- : public OpConversionPattern<mhlo::DynamicSliceOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mhlo::DynamicSliceOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- mhlo::DynamicSliceOp::Adaptor operands(rawOperands);
- // TODO(benvanik): if the source is only used by this op then replace with
- // a vmla.buffer.view op.
-
- auto srcShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.operand(), *getTypeConverter(), rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.result(), *getTypeConverter(), rewriter);
-
- int rank = srcOp.operand().getType().cast<ShapedType>().getRank();
- SmallVector<Value, 4> srcIndices(rank);
- SmallVector<Value, 4> dstIndices(rank);
- SmallVector<Value, 4> lengths(rank);
- Value zero =
- rewriter.createOrFold<mlir::ConstantIndexOp>(srcOp.getLoc(), 0);
- for (int i = 0; i < rank; ++i) {
- srcIndices[i] = rewriter.createOrFold<IndexCastOp>(
- srcOp.getLoc(), rewriter.getIndexType(),
- rewriter.createOrFold<IREE::VMLA::BufferLoadI32Op>(
- srcOp.getLoc(), rewriter.getIntegerType(32),
- operands.start_indices()[i],
- rewriter.createOrFold<mlir::ConstantIndexOp>(srcOp.getLoc(), 0)));
- dstIndices[i] = zero;
- lengths[i] = rewriter.createOrFold<mlir::ConstantIndexOp>(
- srcOp.getLoc(),
- srcOp.slice_sizes().getValue<int64_t>({static_cast<uint64_t>(i)}));
- }
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- rewriter.create<IREE::VMLA::CopyOp>(
- srcOp.getLoc(), operands.operand(), srcShape, srcIndices, dst, dstShape,
- dstIndices, lengths,
- TypeAttr::get(srcOp.getType().cast<ShapedType>().getElementType()));
- rewriter.replaceOp(srcOp, {dst});
- return success();
- }
-};
-
-struct CompareOpConversion : public OpConversionPattern<mhlo::CompareOp> {
- CompareOpConversion(TypeConverter &typeConverter, MLIRContext *context)
- : OpConversionPattern(typeConverter, context,
- /*benefit=*/9999) {}
-
- LogicalResult matchAndRewrite(
- mhlo::CompareOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto linputType = srcOp.lhs().getType().dyn_cast<ShapedType>();
- auto rinputType = srcOp.rhs().getType().dyn_cast<ShapedType>();
- if (!linputType || !rinputType) return failure();
-
- IREE::VMLA::CmpPredicate predicate = IREE::VMLA::CmpPredicate::EQ;
- auto comparisonDirection = srcOp.comparison_direction();
- auto comparePredicate =
- llvm::StringSwitch<Optional<CmpIPredicate>>(comparisonDirection)
- .Case("EQ", CmpIPredicate::eq)
- .Case("NE", CmpIPredicate::ne)
- .Case("LT", CmpIPredicate::slt)
- .Case("LE", CmpIPredicate::sle)
- .Case("GT", CmpIPredicate::sgt)
- .Case("GE", CmpIPredicate::sge)
- .Default(llvm::None);
- if (!comparePredicate.hasValue()) return failure();
-
- auto predicateValue = comparePredicate.getValue();
- switch (predicateValue) {
- case CmpIPredicate::eq:
- predicate = IREE::VMLA::CmpPredicate::EQ;
- break;
- case CmpIPredicate::ne:
- predicate = IREE::VMLA::CmpPredicate::NE;
- break;
- case CmpIPredicate::slt:
- predicate = IREE::VMLA::CmpPredicate::LT;
- break;
- case CmpIPredicate::sle:
- predicate = IREE::VMLA::CmpPredicate::LE;
- break;
- case CmpIPredicate::sgt:
- predicate = IREE::VMLA::CmpPredicate::GT;
- break;
- case CmpIPredicate::sge:
- predicate = IREE::VMLA::CmpPredicate::GE;
- break;
- default:
- llvm_unreachable("unhandled comparison predicate");
- return failure();
- }
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- auto newOp = rewriter.create<IREE::VMLA::CmpOp>(
- srcOp.getLoc(), predicate, rawOperands[0], rawOperands[1], dst,
- TypeAttr::get(linputType.getElementType()));
- rewriter.replaceOp(srcOp, newOp.dst());
- return success();
- }
-};
-
-struct FiniteOpConversion : public OpConversionPattern<mhlo::IsFiniteOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mhlo::IsFiniteOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto inputType =
- srcOp.getOperand().getType().cast<ShapedType>().getElementType();
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- rewriter.createOrFold<IREE::VMLA::FiniteOp>(srcOp.getLoc(), rawOperands[0],
- dst, TypeAttr::get(inputType));
- rewriter.replaceOp(srcOp, {dst});
- return success();
- }
-};
-
-struct SortOpConversion : public OpConversionPattern<IREE::VMLA::SortPseudoOp> {
- SortOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::VMLA::SortPseudoOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto inputType =
- srcOp.getOperand().getType().cast<ShapedType>().getElementType();
- auto src = rawOperands[0];
- auto src_shape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.value(), typeConverter, rewriter);
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
- rewriter.createOrFold<IREE::VMLA::SortOp>(srcOp.getLoc(), src, src_shape,
- dst, TypeAttr::get(inputType));
- rewriter.replaceOp(srcOp, {dst});
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct FftOpConversion : public OpConversionPattern<IREE::VMLA::FftPseudoOp> {
- FftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::VMLA::FftPseudoOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto input_shape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
-
- auto real_input_type = srcOp.getOperand(0).getType().cast<ShapedType>();
- auto imag_input_type = srcOp.getOperand(1).getType().cast<ShapedType>();
-
- // The input type/shape should match for the real and imag components.
- if (real_input_type != imag_input_type) {
- srcOp.emitWarning() << "real and imag should have matching types";
- return failure();
- }
-
- auto real_out = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(0), typeConverter, rewriter);
- auto imag_out = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(1), typeConverter, rewriter);
-
- rewriter.createOrFold<IREE::VMLA::FftOp>(
- srcOp.getLoc(), rawOperands[0], input_shape, rawOperands[1],
- input_shape, real_out, imag_out,
- TypeAttr::get(real_input_type.getElementType()));
-
- rewriter.replaceOp(srcOp, {real_out, imag_out});
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct IfftOpConversion : public OpConversionPattern<IREE::VMLA::IfftPseudoOp> {
- IfftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::VMLA::IfftPseudoOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto input_shape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
-
- auto real_input_type = srcOp.getOperand(0).getType().cast<ShapedType>();
- auto imag_input_type = srcOp.getOperand(1).getType().cast<ShapedType>();
-
- // The input type/shape should match for the real and imag components.
- if (real_input_type != imag_input_type) {
- srcOp.emitWarning() << "real and imag should have matching types";
- return failure();
- }
-
- auto real_out = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(0), typeConverter, rewriter);
- auto imag_out = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(1), typeConverter, rewriter);
-
- rewriter.createOrFold<IREE::VMLA::IfftOp>(
- srcOp.getLoc(), rawOperands[0], input_shape, rawOperands[1],
- input_shape, real_out, imag_out,
- TypeAttr::get(real_input_type.getElementType()));
-
- rewriter.replaceOp(srcOp, {real_out, imag_out});
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct RfftOpConversion : public OpConversionPattern<IREE::VMLA::RfftPseudoOp> {
- RfftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::VMLA::RfftPseudoOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto input_type = srcOp.getOperand().getType().cast<ShapedType>();
- auto input_shape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
-
- auto real_out = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(0), typeConverter, rewriter);
- auto imag_out = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(1), typeConverter, rewriter);
-
- rewriter.createOrFold<IREE::VMLA::RfftOp>(
- srcOp.getLoc(), rawOperands[0], input_shape, real_out, imag_out,
- TypeAttr::get(input_type.getElementType()));
-
- rewriter.replaceOp(srcOp, {real_out, imag_out});
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct IrfftOpConversion
- : public OpConversionPattern<IREE::VMLA::IrfftPseudoOp> {
- IrfftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- IREE::VMLA::IrfftPseudoOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto real_input_type = srcOp.getOperand(0).getType().cast<ShapedType>();
- auto imag_input_type = srcOp.getOperand(1).getType().cast<ShapedType>();
-
- // The input type/shape should match for the real and imag components.
- if (real_input_type != imag_input_type) {
- srcOp.emitWarning() << "real and imag should have matching types";
- return failure();
- }
-
- auto input_shape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
- auto real_out = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
-
- rewriter.createOrFold<IREE::VMLA::IrfftOp>(
- srcOp.getLoc(), rawOperands[0], input_shape, rawOperands[1],
- input_shape, real_out, TypeAttr::get(real_input_type.getElementType()));
-
- rewriter.replaceOp(srcOp, {real_out});
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- mhlo::ConvertOp srcOp, ArrayRef<Value> rawOperands,
- ConversionPatternRewriter &rewriter) const override {
- auto srcType = srcOp.operand().getType().cast<ShapedType>();
- auto dstType = srcOp.getResult().getType().cast<ShapedType>();
-
- // The mhlo.convert op can have the same src and dst element types, in
- // which case it just represents a static structural annotation of a shape
- // change, so it is just an identity op at runtime.
- if (srcType.getElementType() == dstType.getElementType()) {
- return IdentityOpConversion<mhlo::ConvertOp>{rewriter.getContext()}
- .matchAndRewrite(srcOp, rawOperands, rewriter);
- }
-
- // VMLA does not support tensors of i1. tensor<*xi1> will be converted to
- // tensor<*xi8>.
- if ((srcType.getElementTypeBitWidth() == 1 &&
- dstType.getElementTypeBitWidth() == 8) ||
- (srcType.getElementTypeBitWidth() == 8 &&
- dstType.getElementTypeBitWidth() == 1)) {
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- auto bitMask = rewriter.createOrFold<mlir::ConstantIntOp>(
- srcOp.getLoc(), 1, rewriter.getI32Type());
- rewriter.createOrFold<IREE::VMLA::AndBroadcastOp>(
- srcOp.getLoc(), rawOperands[0], bitMask, dst,
- TypeAttr::get(rewriter.getIntegerType(8)), false);
- rewriter.replaceOp(srcOp, {dst});
- } else {
- return VMLAConversionTarget::applyDefaultBufferRewrite(
- srcOp, rawOperands, VMLAOpSemantics::kDefault,
- IREE::VMLA::ConvertOp::getOperationName(), *getTypeConverter(),
- rewriter);
- }
-
- return success();
- }
-};
-
-} // namespace
-
-void populateHLOToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter) {
- // mhlo.convolution.
- populateHLOConvToVMLAPatterns(context, patterns, typeConverter);
-
- // mhlo.reduce and mhlo.reduce_window.
- populateHLOReductionToVMLAPatterns(context, patterns, typeConverter);
-
- // vmla.batch.matmul.pseudo
- patterns.insert<VMLAOpConversion<IREE::VMLA::BatchMatMulPseudoOp,
- IREE::VMLA::BatchMatMulOp>>(typeConverter,
- context);
-
- // vmla.sort.pseudo
- patterns.insert<SortOpConversion>(context, typeConverter);
-
- // vmla.fft.pseudo, vmla.ifft.pseudo, vmla.rfft.pseudo, vmla.irfft.pseudo
- patterns.insert<FftOpConversion, IfftOpConversion, RfftOpConversion,
- IrfftOpConversion>(context, typeConverter);
- // Simple 1:1 conversion patterns using the automated trait-based converter.
- // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
- patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::SubOp, IREE::VMLA::SubOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::DivOp, IREE::VMLA::DivOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::MulOp, IREE::VMLA::MulOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::PowOp, IREE::VMLA::PowOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::RemOp, IREE::VMLA::RemOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::ShiftLeftOp, IREE::VMLA::ShlOp>>(
- typeConverter, context);
- patterns.insert<
- VMLAOpConversion<mhlo::ShiftRightArithmeticOp, IREE::VMLA::ShrOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::ShiftRightLogicalOp, IREE::VMLA::ShrOp,
- VMLAOpSemantics::kForceUnsigned>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::AndOp, IREE::VMLA::AndOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::OrOp, IREE::VMLA::OrOp>>(typeConverter,
- context);
- patterns.insert<VMLAOpConversion<mhlo::XorOp, IREE::VMLA::XorOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::NotOp, IREE::VMLA::NotOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::ExpOp, IREE::VMLA::ExpOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::LogOp, IREE::VMLA::LogOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::CeilOp, IREE::VMLA::CeilOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::FloorOp, IREE::VMLA::FloorOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::RoundOp, IREE::VMLA::RoundOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::RsqrtOp, IREE::VMLA::RsqrtOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::SqrtOp, IREE::VMLA::SqrtOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::CosOp, IREE::VMLA::CosOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::SinOp, IREE::VMLA::SinOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::TanhOp, IREE::VMLA::TanhOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::Atan2Op, IREE::VMLA::Atan2Op>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::SelectOp, IREE::VMLA::SelectOp>>(
- typeConverter, context);
- patterns.insert<ConvertOpConversion>(typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::ReverseOp, IREE::VMLA::ReverseOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::TransposeOp, IREE::VMLA::TransposeOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::PadOp, IREE::VMLA::PadOp>>(
- typeConverter, context);
- patterns
- .insert<VMLAOpConversion<mhlo::TorchIndexSelectOp, IREE::VMLA::GatherOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::AbsOp, IREE::VMLA::AbsOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::NegOp, IREE::VMLA::NegOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::MaxOp, IREE::VMLA::MaxOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::MinOp, IREE::VMLA::MinOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mhlo::ClampOp, IREE::VMLA::ClampOp>>(
- typeConverter, context);
-
- patterns.insert<CompareOpConversion>(typeConverter, context);
- patterns.insert<FiniteOpConversion>(typeConverter, context);
-
- // Ops that are only used for type information that we erase. We can elide
- // these entirely by just passing on their input values.
- patterns.insert<IdentityOpConversion<mhlo::BitcastConvertOp>>(context);
- patterns.insert<IdentityOpConversion<mhlo::CopyOp>>(context);
- patterns.insert<IdentityOpConversion<mhlo::ReshapeOp>>(context);
- patterns.insert<IdentityOpConversion<mhlo::DynamicReshapeOp>>(context);
-
- // Conversions that don't have a 1:1 mapping, mostly involving buffer views
- // or transfers.
- patterns.insert<BroadcastInDimOpConversion>(typeConverter, context);
- patterns.insert<ConcatenateOpConversion>(typeConverter, context);
- patterns.insert<GatherOpConversion>(typeConverter, context);
- patterns.insert<ScatterOpConversion>(typeConverter, context);
- patterns.insert<SliceOpConversion>(typeConverter, context);
- patterns.insert<DynamicSliceOpConversion>(typeConverter, context);
- patterns.insert<IotaOpConversion>(context);
-
- // Tensor-level canonicalizations to reduce the op surface area of the
- // runtime.
- patterns.insert<CanonicalizeBroadcastOp>(context);
-
- // We rely on some additional HLO->std patterns and assume they
- // have been run already. In case they haven't we provide them here (useful
- // for standalone conversion testing). We run them last so that other patterns
- // have a chance to handle the HLO conversions first.
- mhlo::PopulateMhloToStdPatterns(&patterns, context);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h
deleted file mode 100644
index 2d51a25..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_CONVERSION_HLOTOVMLA_CONVERTHLOTOVMLA_H_
-#define IREE_COMPILER_DIALECT_VMLA_CONVERSION_HLOTOVMLA_CONVERTHLOTOVMLA_H_
-
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Populates conversion patterns from the XLA HLO dialect to the VMLA dialect.
-void populateHLOToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_CONVERSION_HLOTOVMLA_CONVERTHLOTOVMLA_H_
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
deleted file mode 100644
index a9877d2..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertReductionOps.cpp
+++ /dev/null
@@ -1,340 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-// Converts a simple mhlo.reduce op that performs independent individual
-// computations into a set of mhlo.reduce ops. This is an intermediate
-// conversion that may make it possible to use the much faster builtin VMLA
-// reduction ops.
-//
-// Only supports single dimensional reductions and assumes that unrolling has
-// been performed prior to conversion.
-struct SplitIndependentReductionOpConversion
- : public OpConversionPattern<mhlo::ReduceOp> {
- SplitIndependentReductionOpConversion(MLIRContext *context,
- TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- mhlo::ReduceOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (srcOp.dimensions().getNumElements() > 1) {
- srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
- return failure();
- } else if (srcOp.body().getBlocks().size() > 1) {
- // Control flow within the computation is not supported; bail to fallback.
- return failure();
- }
- auto &block = srcOp.body().getBlocks().front();
- mhlo::ReduceOp::Adaptor newOperands(operands);
- SmallVector<Value, 4> setResults;
- for (auto &op : block) {
- if (op.hasTrait<OpTrait::IsTerminator>()) {
- continue;
- } else if (op.getOperands().size() != 2) {
- // Only binary ops are supported for builtins.
- return failure();
- }
-
- // Determine which argument set this op is acting on. For the builtins we
- // only support ops that act within a single set.
- // Our arguments are expanded tuples like <lhs0, lhs1>, <rhs0, rhs1>, so
- // this index gets the set offset.
- int opSetIndex =
- std::distance(block.args_begin(),
- llvm::find(block.getArguments(), op.getOperand(0)));
-
- for (auto operand : op.getOperands()) {
- if (operand.getDefiningOp() != nullptr) {
- // Operand comes from another op within the block; unsupported.
- return failure();
- }
- int operandSetIndex =
- std::distance(block.args_begin(),
- llvm::find(block.getArguments(), operand)) %
- newOperands.inputs().size();
- if (operandSetIndex != opSetIndex) {
- // Operand is not coming from the same set as the other operands of
- // this op; unsupported.
- return failure();
- }
- }
- for (auto result : op.getResults()) {
- for (auto *user : result.getUsers()) {
- if (!user->hasTrait<OpTrait::IsTerminator>()) {
- // Result is not directly returned from the block; unsupported.
- return failure();
- }
- }
- }
-
- // Create the new op for this set.
- Value operandArg = srcOp.inputs()[opSetIndex];
- Value initArg = srcOp.init_values()[opSetIndex];
- auto splitOp = rewriter.create<mhlo::ReduceOp>(
- op.getLoc(), ValueRange{operandArg}, ValueRange{initArg},
- srcOp.dimensionsAttr());
- auto *splitBlock = new Block();
- splitOp.body().getBlocks().push_back(splitBlock);
- OpBuilder splitBuilder = OpBuilder::atBlockEnd(splitBlock);
- BlockAndValueMapping mapping;
- for (auto operand : op.getOperands()) {
- mapping.map(operand, splitBlock->addArgument(operand.getType()));
- }
- Operation *splitComputeOp = splitBuilder.clone(op, mapping);
- splitBuilder.create<mhlo::ReturnOp>(
- srcOp.getLoc(), ValueRange{*splitComputeOp->getResults().begin()});
- setResults.push_back(*splitOp.getResults().begin());
- }
-
- rewriter.replaceOp(srcOp, setResults);
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-// Converts an mhlo.reduce with a single op to a builtin reduce op.
-// This is meant to pair with the SplitIndependentReductionOpConversion that
-// tries to unfuse/divide combined reductions. If this cannot match then the
-// fallback path will be used and a VM loop will be emitted (slower, but can
-// perform any reduction).
-//
-// Only supports single dimensional reductions and assumes that unrolling has
-// been performed prior to conversion.
-struct BuiltinReduceOpConversion : public OpConversionPattern<mhlo::ReduceOp> {
- BuiltinReduceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context, /*benefit=*/1000),
- typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- mhlo::ReduceOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (srcOp.dimensions().getNumElements() > 1) {
- srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
- return failure();
- } else if (srcOp.body().getBlocks().size() > 1) {
- // Control flow within the computation is not supported; bail to fallback.
- return failure();
- } else if (srcOp.body().front().getOperations().size() > 2) {
- // Require splitting first.
- return failure();
- }
-
- auto operand = operands[0];
- auto operandShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.inputs()[0], typeConverter, rewriter);
- auto initValue = operands[1];
- auto initValueShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.init_values()[0], typeConverter, rewriter);
- int dimension = srcOp.dimensions().getValue<IntegerAttr>({0}).getInt();
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResults()[0], typeConverter, rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.getResults()[0], typeConverter, rewriter);
- auto elementType =
- srcOp.inputs()[0].getType().cast<ShapedType>().getElementType();
-
- auto &computeOp = *srcOp.body().front().begin();
- if (isa<mlir::AddIOp>(computeOp) || isa<mlir::AddFOp>(computeOp) ||
- isa<mhlo::AddOp>(computeOp)) {
- rewriter.create<IREE::VMLA::ReduceSumOp>(
- srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
- rewriter.getI32IntegerAttr(dimension), dst, dstShape,
- TypeAttr::get(elementType));
- } else if (isa<mhlo::MinOp>(computeOp)) {
- rewriter.create<IREE::VMLA::ReduceMinOp>(
- srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
- rewriter.getI32IntegerAttr(dimension), dst, dstShape,
- TypeAttr::get(elementType));
- } else if (isa<mhlo::MaxOp>(computeOp)) {
- rewriter.create<IREE::VMLA::ReduceMaxOp>(
- srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
- rewriter.getI32IntegerAttr(dimension), dst, dstShape,
- TypeAttr::get(elementType));
- } else if (isa<mhlo::AndOp>(computeOp)) {
- rewriter.create<IREE::VMLA::ReduceAndOp>(
- srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
- rewriter.getI32IntegerAttr(dimension), dst, dstShape,
- TypeAttr::get(elementType));
- } else if (isa<mhlo::OrOp>(computeOp)) {
- rewriter.create<IREE::VMLA::ReduceOrOp>(
- srcOp.getLoc(), operand, operandShape, initValue, initValueShape,
- rewriter.getI32IntegerAttr(dimension), dst, dstShape,
- TypeAttr::get(elementType));
- } else {
- computeOp.emitRemark() << "unsupported builtin reduction operation";
- return failure();
- }
-
- rewriter.replaceOp(srcOp, {dst});
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-// Converts a generic mhlo.reduce to a VM loop.
-//
-// Only supports single dimensional reductions and assumes that unrolling has
-// been performed prior to conversion.
-struct GenericReduceOpConversion : public OpConversionPattern<mhlo::ReduceOp> {
- GenericReduceOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context), typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- mhlo::ReduceOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (srcOp.dimensions().getNumElements() > 1) {
- srcOp.emitOpError() << "multi-dimensional reductions must be unrolled";
- return failure();
- }
-
- // TODO(benvanik): emit VM loop around computation.
- srcOp.emitOpError() << "generic reduction lowering not yet implemented";
- return failure();
- }
-
- TypeConverter &typeConverter;
-};
-
-struct BuiltinPoolingOpConversion
- : public OpConversionPattern<mhlo::ReduceWindowOp> {
- BuiltinPoolingOpConversion(MLIRContext *context, TypeConverter &typeConverter)
- : OpConversionPattern(context, /*benefit=*/1000),
- typeConverter(typeConverter) {}
-
- LogicalResult matchAndRewrite(
- mhlo::ReduceWindowOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- // ReduceWindow body is always a single block region.For simple pattern
- // matching, the reduce-window region needs to have 1 operation for each
- // output + 1 operation for the return (exactly).
- Block &body = srcOp.body().front();
- if (!llvm::hasNItems(body, srcOp.getNumResults() + 1)) {
- // Require splitting first.
- return failure();
- }
-
- SmallVector<int32_t, 4> windowDimensions;
- for (const auto &value : srcOp.window_dimensions().getIntValues())
- windowDimensions.push_back(value.getSExtValue());
- int rank = windowDimensions.size();
- SmallVector<int32_t, 4> windowStrides(rank, 1);
- SmallVector<int32_t, 4> padding(rank, 0);
- for (unsigned i = 0; i < rank; ++i) {
- if (srcOp.window_strides())
- windowStrides[i] = srcOp.window_stridesAttr().getValue<int64_t>(i);
- if (srcOp.padding())
- padding[i] = srcOp.paddingAttr().getValue<int64_t>({i, 0});
- }
-
- int numInputs = srcOp.inputs().size();
- ArrayRef<Value> inputs = operands.take_front(numInputs);
- ArrayRef<Value> initValues = operands.drop_front(numInputs);
-
- SmallVector<Value> newValues;
- for (unsigned i = 0; i < numInputs; ++i) {
- auto inputShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.inputs()[i], typeConverter, rewriter);
- auto initValueShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.init_values()[i], typeConverter, rewriter);
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(i), typeConverter, rewriter);
- auto dstShape = VMLAConversionTarget::getTensorShape(
- srcOp.getLoc(), srcOp.getResult(i), typeConverter, rewriter);
- auto elementType =
- srcOp.inputs()[i].getType().cast<ShapedType>().getElementType();
-
- Operation *computeOp = srcOp.getReductionOp(i);
- if (!computeOp) {
- srcOp.emitRemark() << "unsupported builtin reduction operation";
- return failure();
- }
-
- Value input = inputs[i];
- Value initValue = initValues[i];
-
- if (isa<mlir::AddIOp>(computeOp) || isa<mlir::AddFOp>(computeOp) ||
- isa<mhlo::AddOp>(computeOp)) {
- rewriter.create<IREE::VMLA::PoolingSumOp>(
- srcOp.getLoc(), input, inputShape, initValue, initValueShape, dst,
- dstShape, TypeAttr::get(elementType),
- rewriter.getI32VectorAttr(windowDimensions),
- rewriter.getI32VectorAttr(windowStrides),
- rewriter.getI32VectorAttr(padding));
- } else if (isa<mhlo::MinOp>(computeOp)) {
- rewriter.create<IREE::VMLA::PoolingMinOp>(
- srcOp.getLoc(), input, inputShape, initValue, initValueShape, dst,
- dstShape, TypeAttr::get(elementType),
- rewriter.getI32VectorAttr(windowDimensions),
- rewriter.getI32VectorAttr(windowStrides),
- rewriter.getI32VectorAttr(padding));
- } else if (isa<mhlo::MaxOp>(computeOp)) {
- rewriter.create<IREE::VMLA::PoolingMaxOp>(
- srcOp.getLoc(), input, inputShape, initValue, initValueShape, dst,
- dstShape, TypeAttr::get(elementType),
- rewriter.getI32VectorAttr(windowDimensions),
- rewriter.getI32VectorAttr(windowStrides),
- rewriter.getI32VectorAttr(padding));
- } else {
- computeOp->emitRemark() << "unsupported builtin reduction operation";
- return failure();
- }
- newValues.push_back(dst);
- }
-
- rewriter.replaceOp(srcOp, newValues);
- return success();
- }
-
- TypeConverter &typeConverter;
-};
-
-} // namespace
-
-void populateHLOReductionToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter) {
- patterns.insert<SplitIndependentReductionOpConversion>(context,
- typeConverter);
- patterns.insert<BuiltinReduceOpConversion>(context, typeConverter);
- patterns.insert<BuiltinPoolingOpConversion>(context, typeConverter);
- patterns.insert<GenericReduceOpConversion>(context, typeConverter);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD
deleted file mode 100644
index 3cbab8b..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "broadcast_in_dim.mlir",
- "concatenate.mlir",
- "conv.mlir",
- "convert.mlir",
- "dynamic_slice.mlir",
- "fft.mlir",
- "math_ops.mlir",
- "reduce.mlir",
- "reduce_window.mlir",
- "reshape.mlir",
- "scatter.mlir",
- "slice.mlir",
- "sort.mlir",
- "transpose.mlir",
- ],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt
deleted file mode 100644
index fbd6949..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/CMakeLists.txt
+++ /dev/null
@@ -1,36 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "broadcast_in_dim.mlir"
- "concatenate.mlir"
- "conv.mlir"
- "convert.mlir"
- "dynamic_slice.mlir"
- "fft.mlir"
- "math_ops.mlir"
- "reduce.mlir"
- "reduce_window.mlir"
- "reshape.mlir"
- "scatter.mlir"
- "slice.mlir"
- "sort.mlir"
- "transpose.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
deleted file mode 100644
index ea5578a..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/broadcast_in_dim.mlir
+++ /dev/null
@@ -1,33 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @broadcast_in_dim_2D_3D
-func private @broadcast_in_dim_2D_3D() -> tensor<3x2x4xi32> {
- %rs3_2_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
- %input = constant dense<[[1, 2, 3, 4], [5, 6, 7, 8]]> : tensor<2x4xi32>
- // CHECK-DAG: %[[SRC:.+]] = vmla.constant
- // CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2,4]>
- // CHECK-DAG: %[[DST_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
- // CHECK-DAG: %[[DST_SIZE:.+]] = constant 96 : index
- // CHECK-DAG: %[[DST:.+]] = vmla.buffer.alloc byte_length = %[[DST_SIZE]] : !vmla.buffer
- // CHECK-DAG: vmla.tile %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[1,2,4]>), out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[3,2,4]>) : i32
- %0 = "shapex.ranked_broadcast_in_dim"(%input, %rs3_2_4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>, !shapex.ranked_shape<[3,2,4]>) -> tensor<3x2x4xi32>
- // CHECK-NEXT: return %[[DST]] : !vmla.buffer
- return %0 : tensor<3x2x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @broadcast_in_dim_3D_scalar
-func private @broadcast_in_dim_3D_scalar() -> tensor<3x2x4xi32> {
- %rs3_2_4 = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
- // CHECK-DAG: %[[SRC:.+]] = vmla.constant
- // CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[]>
- // CHECK-DAG: %[[DST_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[3,2,4]>
- // CHECK-DAG: %[[DST_SIZE:.+]] = constant 96 : index
- %input = constant dense<42> : tensor<i32>
- // CHECK-NEXT: %[[DST:.+]] = vmla.buffer.alloc byte_length = %[[DST_SIZE]] : !vmla.buffer
- // CHECK-NEXT: vmla.broadcast %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[]>), out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[3,2,4]>) : i32
- %0 = "shapex.ranked_broadcast_in_dim"(%input, %rs3_2_4) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, !shapex.ranked_shape<[3,2,4]>) -> tensor<3x2x4xi32>
- // CHECK-NEXT: return %[[DST]] : !vmla.buffer
- return %0 : tensor<3x2x4xi32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir
deleted file mode 100644
index f72e8bc..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/concatenate.mlir
+++ /dev/null
@@ -1,98 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @concatenate_0
-func private @concatenate_0(%arg0 : tensor<2x2xi32>) -> (tensor<2x5xi32>) {
- // CHECK-SAME: %[[ARG0:.+]]:
- // CHECK-DAG: %[[ARG1:.+]] = vmla.constant {{.+}} tensor<2x3xi32>
- %c0 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c40 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG0]](%rs2_2 : !shapex.ranked_shape<[2,2]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_5 : !shapex.ranked_shape<[2,5]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c2, %c2] : i32
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG1]](%rs2_3 : !shapex.ranked_shape<[2,3]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_5 : !shapex.ranked_shape<[2,5]>),
- // CHECK-SAME: dst_indices = [%c0, %c2], lengths = [%c2, %c3] : i32
- %0 = "mhlo.concatenate"(%arg0, %c0) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
- // CHECK-NEXT: return %[[DST]]
- return %0: tensor<2x5xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @concatenate_1
-func private @concatenate_1(%arg0: tensor<2x3xi32>) -> (tensor<2x5xi32>) {
- // CHECK-SAME: %[[ARG0:.+]]:
- // CHECK-DAG: %[[ARG1:.+]] = vmla.constant {{.+}} tensor<2x2xi32>
- %c0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c40 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG0]](%rs2_3 : !shapex.ranked_shape<[2,3]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_5 : !shapex.ranked_shape<[2,5]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c2, %c3] : i32
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG1]](%rs2_2 : !shapex.ranked_shape<[2,2]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_5 : !shapex.ranked_shape<[2,5]>),
- // CHECK-SAME: dst_indices = [%c0, %c3], lengths = [%c2, %c2] : i32
- %0 = "mhlo.concatenate"(%arg0, %c0) {dimension = 1} : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32>
- // CHECK-NEXT: return %[[DST]]
- return %0: tensor<2x5xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @concatenate_2
-func private @concatenate_2(%arg0: tensor<2x2xi32>) -> (tensor<2x7xi32>) {
- // CHECK-SAME: %[[ARG0:.+]]:
- // CHECK-DAG: %[[ARG1:.+]] = vmla.constant {{.+}} tensor<2x3xi32>
- %c0 = constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
- // CHECK-DAG: %[[ARG2:.+]] = vmla.constant {{.+}} tensor<2x2xi32>
- %c1 = constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c56 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG0]](%rs2_2 : !shapex.ranked_shape<[2,2]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_7 : !shapex.ranked_shape<[2,7]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c2, %c2] : i32
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG1]](%rs2_3 : !shapex.ranked_shape<[2,3]>),
- // CEHCK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_7 : !shapex.ranked_shape<[2,7]>),
- // CHECK-SAME: dst_indices = [%c0, %c2], lengths = [%c2, %c3] : i32
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG2]](%rs2_2 : !shapex.ranked_shape<[2,2]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_7 : !shapex.ranked_shape<[2,7]>),
- // CHECK-SAME: dst_indices = [%c0, %c5], lengths = [%c2, %c2] : i32
- %0 = "mhlo.concatenate"(%arg0, %c0, %c1) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x7xi32>
- // CHECK-NEXT: return %[[DST]]
- return %0: tensor<2x7xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @concatenate_3
-func private @concatenate_3(%arg0: tensor<2x2xi32>) -> (tensor<4x2xi32>) {
- // CHECK-SAME: %[[ARG0:.+]]:
- // CHECK-DAG: %[[ARG1:.+]] = vmla.constant {{.+}} tensor<2x2xi32>
- %c0 = constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c32 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG0]](%rs2_2 : !shapex.ranked_shape<[2,2]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs4_2 : !shapex.ranked_shape<[4,2]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c2, %c2] : i32
- // CHECK: vmla.copy
- // CHECK-SAME: %[[ARG1]](%rs2_2 : !shapex.ranked_shape<[2,2]>),
- // CHECK-SAME: src_indices = [%c0, %c0],
- // CHECK-SAME: out %[[DST]](%rs4_2 : !shapex.ranked_shape<[4,2]>),
- // CHECK-SAME: dst_indices = [%c2, %c0], lengths = [%c2, %c2] : i32
- %0 = "mhlo.concatenate"(%arg0, %c0) {dimension = 0} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<4x2xi32>
- // CHECK-NEXT: return %[[DST]]
- return %0: tensor<4x2xi32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/conv.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/conv.mlir
deleted file mode 100644
index 82543f0..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/conv.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @conv
-func private @conv(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32> {
- // CHECK: vmla.conv
- // CHECK-SAME: {batch_group_count = 1 : i32,
- // CHECK-SAME: feature_group_count = 1 : i32,
- // CHECK-SAME: lhs_dilation = dense<1> : vector<2xi32>,
- // CHECK-SAME: padding = dense<[1, 2, 2, 2]> : vector<4xi32>,
- // CHECK-SAME: rhs_dilation = dense<1> : vector<2xi32>,
- // CHECK-SAME: window_strides = dense<1> : vector<2xi32>}
- %2 = "mhlo.convolution"(%arg0, %arg1) {
- batch_group_count = 1 : i64,
- dimension_numbers = {
- input_batch_dimension = 0 : i64,
- input_feature_dimension = 3 : i64,
- input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
- kernel_input_feature_dimension = 2 : i64,
- kernel_output_feature_dimension = 3 : i64,
- kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
- output_batch_dimension = 0 : i64,
- output_feature_dimension = 3 : i64,
- output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
- feature_group_count = 1 : i64,
- rhs_dilation = dense<1> : tensor<2xi64>,
- lhs_dilation = dense<1> : tensor<2xi64>,
- padding = dense<[[1, 2],[2, 2]]> : tensor<2x2xi64>,
- window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x5x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32>
- return %2: tensor<1x2x3x1xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/convert.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/convert.mlir
deleted file mode 100644
index 8abb0f5..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/convert.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion %s | IreeFileCheck %s
-
-// CHECK-LABEL: func private @basic
-func private @basic(%arg0 : tensor<5xf32>) -> (tensor<5xi32>) {
- // CHECK: vmla.convert
- %0 = "mhlo.convert"(%arg0) : (tensor<5xf32>) -> tensor<5xi32>
- return %0 : tensor<5xi32>
-}
-
-// CHECK-LABEL: func private @noop
-func private @noop(%arg0 : tensor<?xf32>) -> (tensor<5xf32>) {
- // CHECK: return %arg0
- %0 = "mhlo.convert"(%arg0) : (tensor<?xf32>) -> tensor<5xf32>
- return %0 : tensor<5xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir
deleted file mode 100644
index 5b070c2..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/dynamic_slice.mlir
+++ /dev/null
@@ -1,111 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @slice_whole_buffer
-// CHECK-SAME: %[[SRC_IDX_1:.+]]: !vmla.buffer, %[[SRC_IDX_2:.+]]: !vmla.buffer
-func private @slice_whole_buffer(%src_idx_1 : tensor<i64>, %src_idx_2 : tensor<i64>) -> tensor<3x4xi32> {
- // CHECK: %[[SRC:.+]] = vmla.constant
- %input = constant dense<[
- [01, 02, 03, 04],
- [05, 06, 07, 08],
- [09, 10, 11, 12]
- ]> : tensor<3x4xi32>
- // CHECK-DAG: %[[SRC_INDEX_0_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_1]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_0:.+]] = index_cast %[[SRC_INDEX_0_I32]]
- // CHECK-DAG: %[[SRC_INDEX_1_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_2]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_1:.+]] = index_cast %[[SRC_INDEX_1_I32]]
- // CHECK-DAG: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c48 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: src_indices = [%[[SRC_INDEX_0]], %[[SRC_INDEX_1]]],
- // CHECK-SAME: out %[[DST]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c3, %c4] : i32
- %result = "mhlo.dynamic-slice"(%input, %src_idx_1, %src_idx_2) {
- slice_sizes = dense<[3, 4]> : tensor<2xi64>
- } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<3x4xi32>
- // CHECK-NEXT: return %[[DST]]
- return %result : tensor<3x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @slice_whole_stride
-// CHECK-SAME: %[[SRC_IDX_1:.+]]: !vmla.buffer, %[[SRC_IDX_2:.+]]: !vmla.buffer
-func private @slice_whole_stride(%src_idx_1 : tensor<i64>, %src_idx_2 : tensor<i64>) -> tensor<1x4xi32> {
- // CHECK: %[[SRC:.+]] = vmla.constant
- %input = constant dense<[
- [01, 02, 03, 04],
- [05, 06, 07, 08],
- [09, 10, 11, 12]
- ]> : tensor<3x4xi32>
- // CHECK-DAG: %[[SRC_INDEX_0_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_1]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_0:.+]] = index_cast %[[SRC_INDEX_0_I32]]
- // CHECK-DAG: %[[SRC_INDEX_1_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_2]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_1:.+]] = index_cast %[[SRC_INDEX_1_I32]]
- // CHECK-DAG: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c16 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: src_indices = [%[[SRC_INDEX_0]], %[[SRC_INDEX_1]]],
- // CHECK-SAME: out %[[DST]](%rs1_4 : !shapex.ranked_shape<[1,4]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c1, %c4] : i32
- %result = "mhlo.dynamic-slice"(%input, %src_idx_1, %src_idx_2) {
- slice_sizes = dense<[1, 4]> : tensor<2xi64>
- } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
- // CHECK-NEXT: return %[[DST]]
- return %result : tensor<1x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @slice_stride_part
-// CHECK-SAME: %[[SRC_IDX_1:.+]]: !vmla.buffer, %[[SRC_IDX_2:.+]]: !vmla.buffer
-func private @slice_stride_part(%src_idx_1 : tensor<i64>, %src_idx_2 : tensor<i64>) -> tensor<1x2xi32> {
- // CHECK: %[[SRC:.+]] = vmla.constant
- %input = constant dense<[
- [01, 02, 03, 04],
- [05, 06, 07, 08],
- [09, 10, 11, 12]
- ]> : tensor<3x4xi32>
- // CHECK-DAG: %[[SRC_INDEX_0_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_1]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_0:.+]] = index_cast %[[SRC_INDEX_0_I32]]
- // CHECK-DAG: %[[SRC_INDEX_1_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_2]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_1:.+]] = index_cast %[[SRC_INDEX_1_I32]]
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c8 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: src_indices = [%[[SRC_INDEX_0]], %[[SRC_INDEX_1]]],
- // CHECK-SAME: out %[[DST]](%rs1_2 : !shapex.ranked_shape<[1,2]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c1, %c2] : i32
- %result = "mhlo.dynamic-slice"(%input, %src_idx_1, %src_idx_2) {
- slice_sizes = dense<[1, 2]> : tensor<2xi64>
- } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x2xi32>
- // CHECK-NEXT: return %[[DST]]
- return %result : tensor<1x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @slice_multi_stride
-// CHECK-SAME: %[[SRC_IDX_1:.+]]: !vmla.buffer, %[[SRC_IDX_2:.+]]: !vmla.buffer
-func private @slice_multi_stride(%src_idx_1 : tensor<i64>, %src_idx_2 : tensor<i64>) -> tensor<2x4xi32> {
- // CHECK: %[[SRC:.+]] = vmla.constant
- %input = constant dense<[
- [01, 02, 03, 04],
- [05, 06, 07, 08],
- [09, 10, 11, 12]
- ]> : tensor<3x4xi32>
- // CHECK-DAG: %[[SRC_INDEX_0_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_1]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_0:.+]] = index_cast %[[SRC_INDEX_0_I32]]
- // CHECK-DAG: %[[SRC_INDEX_1_I32:.+]] = vmla.buffer.load.i32 %[[SRC_IDX_2]][%c0] : i32
- // CHECK-DAG: %[[SRC_INDEX_1:.+]] = index_cast %[[SRC_INDEX_1_I32]]
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c32 : !vmla.buffer
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: src_indices = [%[[SRC_INDEX_0]], %[[SRC_INDEX_1]]],
- // CHECK-SAME: out %[[DST]](%rs2_4 : !shapex.ranked_shape<[2,4]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c2, %c4] : i32
- %result = "mhlo.dynamic-slice"(%input, %src_idx_1, %src_idx_2) {
- slice_sizes = dense<[2, 4]> : tensor<2xi64>
- } : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x4xi32>
- // CHECK-NEXT: return %[[DST]]
- return %result : tensor<2x4xi32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
deleted file mode 100644
index 0991c79..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
+++ /dev/null
@@ -1,41 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-func private @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
- // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
- // CHECK-DAG: [[C32:%.+]] = constant 32 : index
- // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
- // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
- // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32
- %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
- return %real, %imag : tensor<8xf32>, tensor<8xf32>
-}
-
-func private @ifft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
- // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
- // CHECK-DAG: [[C32:%.+]] = constant 32 : index
- // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
- // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
- // CHECK-NEXT: vmla.ifft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32
- %real, %imag = "vmla.ifft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
- return %real, %imag : tensor<8xf32>, tensor<8xf32>
-}
-
-func private @rfft(%arg0: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
- // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
- // CHECK-DAG: [[C20:%.+]] = constant 20 : index
- // CHECK: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer
- // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C20]] : !vmla.buffer
- // CHECK-NEXT: vmla.rfft %arg0([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32
- %real, %imag = "vmla.rfft.pseudo"(%arg0) : (tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>)
- return %real, %imag : tensor<5xf32>, tensor<5xf32>
-}
-
-func private @irfft(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<8xf32> {
- // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[5]>
- // CHECK-DAG: [[C32:%.+]] = constant 32 : index
- // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
- // CHECK-NEXT: vmla.irfft %arg0([[RS]] : !shapex.ranked_shape<[5]>), %arg1([[RS]] : !shapex.ranked_shape<[5]>), out [[OUTBUF1]] : f32
- %real = "vmla.irfft.pseudo"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> (tensor<8xf32>)
- return %real : tensor<8xf32>
-}
-
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
deleted file mode 100644
index bdbd687..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/math_ops.mlir
+++ /dev/null
@@ -1,60 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @abs_scalar
-func private @abs_scalar(%arg0 : tensor<f32>) -> tensor<f32> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 4
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.abs %arg0, out %[[BUF]] : f32
- %0 = "mhlo.abs"(%arg0) : (tensor<f32>) -> tensor<f32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<f32>
-}
-
-// -----
-
-// CHECK-LABEL: @abs_tensor
-func private @abs_tensor(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 16
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.abs %arg0, out %[[BUF]] : f32
- %0 = "mhlo.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @clamp
-func private @clamp(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 16
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.clamp %arg0, %arg1, %arg2, out %[[BUF]] : f32
- %0 = "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @finite
-func private @finite(%arg0 : tensor<4xf32>) -> tensor<4xi1> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 4
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.finite %arg0, out %[[BUF]] : f32
- %0 = "mhlo.is_finite"(%arg0) : (tensor<4xf32>) -> tensor<4xi1>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @not
-func private @not(%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 4
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.not %arg0, out %[[BUF]] : i8
- %0 = "mhlo.not"(%arg0) : (tensor<4xi8>) -> tensor<4xi8>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xi8>
-}
-
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir
deleted file mode 100644
index 86f0fff..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce.mlir
+++ /dev/null
@@ -1,57 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -cse %s | IreeFileCheck %s
-
-// CHECK-LABEL: @single_reduction
-func private @single_reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
- // CHECK-DAG: %[[INIT:.+]] = vmla.constant dense<0.000000e+00> : tensor<f32> -> !vmla.buffer
- %cst = constant dense<0.000000e+00> : tensor<f32>
- // CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8]>
- // CHECK-DAG: %[[INIT_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[]>
- // CHECK-DAG: %[[DST:.+]] = vmla.buffer.alloc
- // CHECK-DAG: %[[DST_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4]>
- // CHECK-NEXT: vmla.reduce.sum
- // CEHCK-SAME: %arg0(%[[SRC_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: %[[INIT]](%[[INIT_SHAPE]] : !shapex.ranked_shape<[]>),
- // CHECK-SAME: out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[4]>)
- // CHECK-SaME: {dimension = 1 : i32} : f32
- %0 = "mhlo.reduce"(%arg0, %cst) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %1 = mhlo.add %arg1, %arg2 : tensor<f32>
- "mhlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
- // CHECK-NEXT: return %[[DST]] : !vmla.buffer
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @multi_reduction
-func private @multi_reduction(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
- // CHECK-DAG: %[[CST0:.+]] = vmla.constant dense<0.000000e+00> : tensor<f32> -> !vmla.buffer
- %0 = constant dense<0.000000e+00> : tensor<f32>
- // CHECK-DAG: %[[CST1:.+]] = vmla.constant dense<1.000000e+00> : tensor<f32> -> !vmla.buffer
- %1 = constant dense<1.000000e+00> : tensor<f32>
- // CHECK-DAG: %[[INPUT_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,8]>
- // CHECK-DAG: %[[SCALAR_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[]>
- // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4]>
- // CHECK-DAG: %[[RET_SIZE:.+]] = muli
- // CHECK-DAG: %[[RET0:.+]] = vmla.buffer.alloc byte_length = %[[RET_SIZE]] : !vmla.buffer
- // CHECK-NEXT: vmla.reduce.sum
- // CEHCK-SAME: %arg0(%[[INPUT_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: %[[CST0]](%[[SCALAR_SHAPE]] : !shapex.ranked_shape<[]>),
- // CHECK-SAME: out %[[RET0]](%[[RESULT_SHAPE]] : !shapex.ranked_shape<[4]>)
- // CHECK-SaME: {dimension = 1 : i32} : f32
- // CHECK-NEXT: %[[RET1:.+]] = vmla.buffer.alloc byte_length = %[[RET_SIZE]] : !vmla.buffer
- // CHECK-NEXT: vmla.reduce.sum
- // CEHCK-SAME: %arg1(%[[INPUT_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: %[[CST1]](%[[SCALAR_SHAPE]] : !shapex.ranked_shape<[]>),
- // CHECK-SAME: out %[[RET1]](%[[RESULT_SHAPE]] : !shapex.ranked_shape<[4]>)
- // CHECK-SaME: {dimension = 1 : i32} : f32
- %2, %3 = "mhlo.reduce"(%arg0, %arg1, %0, %1) ( {
- ^bb0(%arg0_lhs : tensor<f32>, %arg1_lhs : tensor<f32>, %arg0_rhs : tensor<f32>, %arg1_rhs : tensor<f32>):
- %4 = mhlo.add %arg0_lhs, %arg0_rhs : tensor<f32>
- %5 = mhlo.add %arg1_lhs, %arg1_rhs : tensor<f32>
- "mhlo.return"(%4, %5) : (tensor<f32>, tensor<f32>) -> ()
- }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<4x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4xf32>, tensor<4xf32>)
- // CHECK-NEXT: return %[[RET0]], %[[RET1]] : !vmla.buffer, !vmla.buffer
- return %2, %3 : tensor<4xf32>, tensor<4xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir
deleted file mode 100644
index 55e8751..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reduce_window.mlir
+++ /dev/null
@@ -1,90 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -cse %s -verify-diagnostics | IreeFileCheck %s
-
-// CHECK-LABEL: @pooling_max
-func private @pooling_max(%arg0: tensor<1x4x6x1xf32>) -> tensor<1x2x2x1xf32> {
- // CHECK: vmla.pooling.max
- %cst = constant dense<0.000000e+00> : tensor<f32>
- %0 = "mhlo.reduce_window"(%arg0, %cst) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %1 = mhlo.maximum %arg1, %arg2 : tensor<f32>
- "mhlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
- window_strides = dense<1> : tensor<4xi64>
- } : (tensor<1x4x6x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
- return %0 : tensor<1x2x2x1xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @pooling_min
-func private @pooling_min(%arg0: tensor<1x4x6x1xi32>) -> tensor<1x2x2x1xi32> {
- // CHECK: vmla.pooling.min
- %cst = constant dense<0> : tensor<i32>
- %0 = "mhlo.reduce_window"(%arg0, %cst) ( {
- ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
- %1 = mhlo.minimum %arg1, %arg2 : tensor<i32>
- "mhlo.return"(%1) : (tensor<i32>) -> ()
- }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>,
- window_strides = dense<1> : tensor<4xi64>
- } : (tensor<1x4x6x1xi32>, tensor<i32>) -> tensor<1x2x2x1xi32>
- return %0 : tensor<1x2x2x1xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @pooling_sum
-func private @pooling_sum(%arg0: tensor<4x6xf32>) -> tensor<3x4xf32> {
- // CHECK: vmla.pooling.sum
- %cst = constant dense<0.000000e+00> : tensor<f32>
- %0 = "mhlo.reduce_window"(%arg0, %cst) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %1 = mhlo.add %arg1, %arg2 : tensor<f32>
- "mhlo.return"(%1) : (tensor<f32>) -> ()
- }) {window_dimensions = dense<[2, 3]> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>,
- padding = dense<[[1, 0], [2, 0]]> : tensor<2x2xi64>
- } : (tensor<4x6xf32>, tensor<f32>) -> tensor<3x4xf32>
- return %0 : tensor<3x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @pooling_sum_min
-func private @pooling_sum_min(%arg0: tensor<4x6xf32>) -> (tensor<3x4xf32>, tensor<3x4xf32>) {
- // CHECK: vmla.pooling.sum
- // CHECK: vmla.pooling.min
- %cst = constant dense<0.000000e+00> : tensor<f32>
- %0:2 = "mhlo.reduce_window"(%arg0, %arg0, %cst, %cst) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
- %1 = mhlo.add %arg1, %arg3 : tensor<f32>
- %2 = mhlo.minimum %arg2, %arg4 : tensor<f32>
- "mhlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> ()
- }) {window_dimensions = dense<[2, 3]> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>,
- padding = dense<[[1, 0], [2, 0]]> : tensor<2x2xi64>
- } : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<f32>, tensor<f32>) -> (tensor<3x4xf32>, tensor<3x4xf32>)
- return %0#0, %0#1 : tensor<3x4xf32>, tensor<3x4xf32>
-}
-
-// -----
-
-// Specify the module explicitly to anchor the conversion failure message.
-// expected-error@+1 {{conversion to the VMLA dialect failed}}
-module {
-
- func private @pooling_sum_min_fail(%arg0: tensor<4x6xf32>) -> (tensor<3x4xf32>, tensor<3x4xf32>) {
- %cst = constant dense<0.000000e+00> : tensor<f32>
- // expected-remark @+2 {{unsupported builtin reduction operation}}
- // expected-error @+1 {{failed to legalize operation 'mhlo.reduce_window' that was explicitly marked illegal}}
- %0:2 = "mhlo.reduce_window"(%arg0, %arg0, %cst, %cst) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
- %1 = mhlo.add %arg1, %arg2 : tensor<f32>
- %2 = mhlo.minimum %arg3, %arg4 : tensor<f32>
- "mhlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> ()
- }) {window_dimensions = dense<[2, 3]> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>,
- padding = dense<[[1, 0], [2, 0]]> : tensor<2x2xi64>
- } : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<f32>, tensor<f32>) -> (tensor<3x4xf32>, tensor<3x4xf32>)
- return %0#0, %0#1 : tensor<3x4xf32>, tensor<3x4xf32>
- }
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reshape.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reshape.mlir
deleted file mode 100644
index 1583074..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/reshape.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @reshape_bypass
-func private @reshape_bypass(%arg0 : tensor<3x2xi32>) -> tensor<6xi32> {
- // CHECK-NEXT: return %arg0
- %0 = "mhlo.reshape"(%arg0) : (tensor<3x2xi32>) -> tensor<6xi32>
- return %0 : tensor<6xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @reshape_copy
-func private @reshape_copy(%arg0 : tensor<3x2xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
- // CHECK-NEXT: %0 = vmla.buffer.clone %arg0 : !vmla.buffer
- %0 = "mhlo.reshape"(%arg0) : (tensor<3x2xi32>) -> tensor<6xi32>
- // CHECK-NEXT: return %arg0, %0
- return %arg0, %0 : tensor<3x2xi32>, tensor<6xi32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/scatter.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/scatter.mlir
deleted file mode 100644
index 384054c..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/scatter.mlir
+++ /dev/null
@@ -1,78 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s -verify-diagnostics | IreeFileCheck %s
-
-// CHECK-LABEL: @scatter_update_1D(
-// CHECK-SAME: [[SRC:arg0]]: !vmla.buffer,
-// CHECK-SAME: [[INDICES:arg1]]: !vmla.buffer,
-// CHECK-SAME: [[UPDATES:arg2]]: !vmla.buffer
-func private @scatter_update_1D(%arg0: tensor<8xi32>, %arg1: tensor<3x1xi32>, %arg2: tensor<3xi32>) -> tensor<8xi32> {
- // CHECK: [[BUFFER:%.+]] = vmla.buffer.alloc
- // CHECK: vmla.copy
- // CHECK-SAME: [[BUFFER]]
- // CHECK: vmla.scatter
- // CHECK-SAME: [[BUFFER]]
- %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
- ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
- "mhlo.return"(%arg4) : (tensor<i32>) -> ()
- }) {
- indices_are_sorted = false,
- scatter_dimension_numbers = {
- index_vector_dim = 1 : i64,
- inserted_window_dims = dense<0> : tensor<1xi64>,
- scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
- update_window_dims = dense<[]> : tensor<0xi64>
- },
- unique_indices = false
- } : (tensor<8xi32>, tensor<3x1xi32>, tensor<3xi32>) -> tensor<8xi32>
- // CHECK: return [[BUFFER]]
- return %0 : tensor<8xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @scatter_update_2D
-func private @scatter_update_2D(%arg0: tensor<4x3xi32>, %arg1: tensor<3x2xi32>, %arg2: tensor<3xi32>) -> tensor<4x3xi32> {
- // CHECK: [[BUFFER:%.+]] = vmla.buffer.alloc
- // CHECK: vmla.copy
- // CHECK-SAME: [[BUFFER]]
- // CHECK: vmla.scatter
- // CHECK-SAME: [[BUFFER]]
- %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
- ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
- "mhlo.return"(%arg4) : (tensor<i32>) -> ()
- }) {
- indices_are_sorted = false,
- scatter_dimension_numbers = {
- index_vector_dim = 1 : i64,
- inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
- scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
- update_window_dims = dense<[]> : tensor<0xi64>},
- unique_indices = false
- } : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32>
- // CHECK: return [[BUFFER]]
- return %0 : tensor<4x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @scatter_update_2D_slice
-func private @scatter_update_2D_slice(%arg0: tensor<4x3xi32>, %arg1: tensor<3x1xi32>, %arg2: tensor<3x3xi32>) -> tensor<4x3xi32> {
- // CHECK: [[BUFFER:%.+]] = vmla.buffer.alloc
- // CHECK: vmla.copy
- // CHECK-SAME: [[BUFFER]]
- // CHECK: vmla.scatter
- // CHECK-SAME: [[BUFFER]]
- %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
- ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
- "mhlo.return"(%arg4) : (tensor<i32>) -> ()
- }) {
- indices_are_sorted = false,
- scatter_dimension_numbers = {
- index_vector_dim = 1 : i64,
- inserted_window_dims = dense<0> : tensor<1xi64>,
- scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
- update_window_dims = dense<1> : tensor<1xi64>},
- unique_indices = false
- } : (tensor<4x3xi32>, tensor<3x1xi32>, tensor<3x3xi32>) -> tensor<4x3xi32>
- // CHECK: return [[BUFFER]]
- return %0 : tensor<4x3xi32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir
deleted file mode 100644
index 0c37693..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/slice.mlir
+++ /dev/null
@@ -1,59 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @slice_whole_stride
-func private @slice_whole_stride(%arg0 : tensor<3x4xi32>) -> tensor<1x4xi32> {
- // CHECK-SAME: %[[SRC:.+]]:
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c16
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: src_indices = [%c1, %c0],
- // CHECK-SAME: out %[[DST]](%rs1_4 : !shapex.ranked_shape<[1,4]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c1, %c4] : i32
- %result = "mhlo.slice"(%arg0) {
- start_indices = dense<[1, 0]> : tensor<2xi64>,
- limit_indices = dense<[2, 4]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
- } : (tensor<3x4xi32>) -> tensor<1x4xi32>
- // CHECK-NEXT: return %[[DST]]
- return %result : tensor<1x4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @slice_stride_part
-func private @slice_stride_part(%arg0 : tensor<3x4xi32>) -> tensor<1x2xi32> {
- // CHECK-SAME: %[[SRC:.+]]:
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c8
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: src_indices = [%c1, %c1],
- // CHECK-SAME: out %[[DST]](%rs1_2 : !shapex.ranked_shape<[1,2]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c1, %c2] : i32
- %result = "mhlo.slice"(%arg0) {
- start_indices = dense<[1, 1]> : tensor<2xi64>,
- limit_indices = dense<[2, 3]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
- } : (tensor<3x4xi32>) -> tensor<1x2xi32>
- // CHECK-NEXT: return %[[DST]]
- return %result : tensor<1x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @slice_multi_stride
-func private @slice_multi_stride(%arg0: tensor<3x4xi32>) -> tensor<2x4xi32> {
- // CHECK-SAME: %[[SRC:.+]]:
- // CHECK: %[[DST:.+]] = vmla.buffer.alloc byte_length = %c32
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%rs3_4 : !shapex.ranked_shape<[3,4]>),
- // CHECK-SAME: src_indices = [%c1, %c0],
- // CHECK-SAME: out %[[DST]](%rs2_4 : !shapex.ranked_shape<[2,4]>),
- // CHECK-SAME: dst_indices = [%c0, %c0], lengths = [%c2, %c4] : i32
- %result = "mhlo.slice"(%arg0) {
- start_indices = dense<[1, 0]> : tensor<2xi64>,
- limit_indices = dense<[3, 4]> : tensor<2xi64>,
- strides = dense<1> : tensor<2xi64>
- } : (tensor<3x4xi32>) -> tensor<2x4xi32>
- // CHECK-NEXT: return %[[DST]]
- return %result : tensor<2x4xi32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
deleted file mode 100644
index 826cfc8..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
+++ /dev/null
@@ -1,37 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-func private @sort1D(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // CHECK-DAG: [[C16:%.+]] = constant 16 : index
- // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4]>
- // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
- // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4]>), out [[BL]] : f32
- // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
- // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4]>) {batch_dims = 0 : i64, dim = 0 : i64} : f32
- %sort = "mhlo.sort"(%arg0) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
- "mhlo.return"(%compare) : (tensor<i1>) -> ()
- }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
-
- // CHECK: return [[BUF]] : !vmla.buffer
- return %sort : tensor<4xf32>
-}
-
-
-// CHECK-LABEL: func private @sort2D
-func private @sort2D(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
- // CHECK-DAG: [[C64:%.+]] = constant 64 : index
- // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,4]>
- // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
- // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BL]] : f32
- // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
- // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4,4]>) {batch_dims = 1 : i64, dim = 1 : i64} : f32
- %sort = "mhlo.sort"(%arg0) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
- "mhlo.return"(%compare) : (tensor<i1>) -> ()
- }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
-
- // CHECK: return [[BUF]] : !vmla.buffer
- return %sort : tensor<4x4xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/transpose.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/transpose.mlir
deleted file mode 100644
index 19784d7..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/transpose.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @transpose
-func private @transpose() -> tensor<24x7x10xf32> {
- // CHECK-DAG: %[[SRC:.+]] = vmla.constant
- %input = constant dense<1.0> : tensor<7x24x10xf32>
- // CHECK-DAG: %[[SRC_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[7,24,10]>
- // CHECK-DAG: %[[DST_SHAPE:.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[24,7,10]>
- // CHECK-DAG: %[[DST_SIZE:.+]] = constant 6720 : index
- // CHECK-DAG: %[[DST:.+]] = vmla.buffer.alloc byte_length = %[[DST_SIZE]] : !vmla.buffer
- // CHECK-NEXT: vmla.transpose
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[7,24,10]>),
- // CHECK-SAME: out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[24,7,10]>)
- // CHECK-SAME: {permutation = dense<[1, 0, 2]> : tensor<3xi32>} : f32
- %0 = "mhlo.transpose"(%input) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<7x24x10xf32>) -> tensor<24x7x10xf32>
- // CHECK-NEXT: return %[[DST]]
- return %0 : tensor<24x7x10xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
deleted file mode 100644
index 58933a4..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "StandardToVMLA",
- srcs = [
- "ConvertStandardToVMLA.cpp",
- ],
- hdrs = [
- "ConvertStandardToVMLA.h",
- ],
- deps = [
- "//iree/compiler/Dialect/IREE/IR",
- "//iree/compiler/Dialect/VMLA/Conversion",
- "//iree/compiler/Dialect/VMLA/IR",
- "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:MathDialect",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:Transforms",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
deleted file mode 100644
index dd2d3bf..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/CMakeLists.txt
+++ /dev/null
@@ -1,34 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- StandardToVMLA
- HDRS
- "ConvertStandardToVMLA.h"
- SRCS
- "ConvertStandardToVMLA.cpp"
- DEPS
- MLIRIR
- MLIRMath
- MLIRPass
- MLIRStandard
- MLIRSupport
- MLIRTransforms
- iree::compiler::Dialect::IREE::IR
- iree::compiler::Dialect::VMLA::Conversion
- iree::compiler::Dialect::VMLA::IR
- iree::compiler::Dialect::VMLA::IR::VMLADialect
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
deleted file mode 100644
index 1a76817..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.cpp
+++ /dev/null
@@ -1,283 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h"
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-struct ConstantOpConversion
- : public VMLAOpConversion<mlir::ConstantOp, IREE::VMLA::BufferConstOp> {
- using VMLAOpConversion::VMLAOpConversion;
-
- LogicalResult matchAndRewrite(
- mlir::ConstantOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto value = srcOp.value().dyn_cast<ElementsAttr>();
- if (!value) return failure();
-
- if (value.getType().getElementType().isInteger(1)) {
- value = value.mapValues(rewriter.getIntegerType(8),
- llvm::function_ref<APInt(const APInt &val)>(
- [](const APInt &val) -> APInt {
- return APInt(8, val.getBoolValue());
- }));
- }
-
- rewriter.replaceOpWithNewOp<IREE::VMLA::ConstantOp>(srcOp, value);
- return success();
- }
-};
-
-struct CmpIOpConversion
- : public VMLAOpConversion<mlir::CmpIOp, IREE::VMLA::CmpOp> {
- using VMLAOpConversion::VMLAOpConversion;
-
- LogicalResult matchAndRewrite(
- mlir::CmpIOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto inputType = srcOp.lhs().getType().dyn_cast<ShapedType>();
- if (!inputType) return failure();
-
- IREE::VMLA::CmpPredicate predicate = IREE::VMLA::CmpPredicate::EQ;
- bool forceUnsigned = false;
- switch (srcOp.predicate()) {
- case CmpIPredicate::eq:
- predicate = IREE::VMLA::CmpPredicate::EQ;
- break;
- case CmpIPredicate::ne:
- predicate = IREE::VMLA::CmpPredicate::NE;
- break;
- case CmpIPredicate::slt:
- predicate = IREE::VMLA::CmpPredicate::LT;
- break;
- case CmpIPredicate::sle:
- predicate = IREE::VMLA::CmpPredicate::LE;
- break;
- case CmpIPredicate::sgt:
- predicate = IREE::VMLA::CmpPredicate::GT;
- break;
- case CmpIPredicate::sge:
- predicate = IREE::VMLA::CmpPredicate::GE;
- break;
- case CmpIPredicate::ult:
- predicate = IREE::VMLA::CmpPredicate::LT;
- forceUnsigned = true;
- break;
- case CmpIPredicate::ule:
- predicate = IREE::VMLA::CmpPredicate::LE;
- forceUnsigned = true;
- break;
- case CmpIPredicate::ugt:
- predicate = IREE::VMLA::CmpPredicate::GT;
- forceUnsigned = true;
- break;
- case CmpIPredicate::uge:
- predicate = IREE::VMLA::CmpPredicate::GE;
- forceUnsigned = true;
- break;
- default:
- llvm_unreachable("unhandled comparison predicate");
- return failure();
- }
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- auto newOp = rewriter.create<IREE::VMLA::CmpOp>(
- srcOp.getLoc(), predicate, operands[0], operands[1], dst,
- TypeAttr::get(inputType.getElementType()));
- if (forceUnsigned) {
- newOp->setAttr("force_unsigned", UnitAttr::get(rewriter.getContext()));
- }
- rewriter.replaceOp(srcOp, newOp.dst());
- return success();
- }
-};
-
-class CmpFOpConversion
- : public VMLAOpConversion<mlir::CmpFOp, IREE::VMLA::CmpOp> {
- public:
- using VMLAOpConversion::VMLAOpConversion;
-
- LogicalResult matchAndRewrite(
- mlir::CmpFOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto inputType = srcOp.lhs().getType().dyn_cast<ShapedType>();
- if (!inputType) return failure();
-
- // NOTE: the std.cmpf semantics are practically undefined. We explicitly
- // match the HLO semantics (that get lowered to the expected case values
- // here). In the future as new ML-focused intermediate dialects are built we
- // can reevaluate what we support here.
- //
- // Rules:
- // https://stackoverflow.com/questions/8627331/what-does-ordered-unordered-comparison-mean
- IREE::VMLA::CmpPredicate predicate = IREE::VMLA::CmpPredicate::EQ;
- switch (srcOp.getPredicate()) {
- case CmpFPredicate::OEQ:
- predicate = IREE::VMLA::CmpPredicate::EQ;
- break;
- case CmpFPredicate::UNE:
- predicate = IREE::VMLA::CmpPredicate::NE;
- break;
- case CmpFPredicate::OLT:
- predicate = IREE::VMLA::CmpPredicate::LT;
- break;
- case CmpFPredicate::OLE:
- predicate = IREE::VMLA::CmpPredicate::LE;
- break;
- case CmpFPredicate::OGT:
- predicate = IREE::VMLA::CmpPredicate::GT;
- break;
- case CmpFPredicate::OGE:
- predicate = IREE::VMLA::CmpPredicate::GE;
- break;
- default:
- llvm_unreachable("unhandled comparison predicate");
- return failure();
- }
-
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- auto newOp = rewriter.create<IREE::VMLA::CmpOp>(
- srcOp.getLoc(), predicate, operands[0], operands[1], dst,
- TypeAttr::get(inputType.getElementType()));
- rewriter.replaceOp(srcOp, newOp.dst());
- return success();
- }
-};
-
-class ZeroExtendIOpConversion
- : public VMLAOpConversion<mlir::ZeroExtendIOp, IREE::VMLA::CmpOp> {
- public:
- using VMLAOpConversion::VMLAOpConversion;
-
- LogicalResult matchAndRewrite(
- mlir::ZeroExtendIOp srcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto srcType = srcOp.getOperand().getType().dyn_cast<ShapedType>();
- auto dstType = srcOp.getResult().getType().dyn_cast<ShapedType>();
- if (!srcType || !dstType) return failure();
- if ((srcType.getElementTypeBitWidth() == 1 &&
- dstType.getElementTypeBitWidth() == 8) ||
- (srcType.getElementTypeBitWidth() == 8 &&
- dstType.getElementTypeBitWidth() == 1)) {
- auto dst = VMLAConversionTarget::allocateOutputBuffer(
- srcOp.getLoc(), srcOp.getResult(), *getTypeConverter(), rewriter);
- auto bitMask = rewriter.createOrFold<mlir::ConstantIntOp>(
- srcOp.getLoc(), 1, rewriter.getI32Type());
- rewriter.createOrFold<IREE::VMLA::AndBroadcastOp>(
- srcOp.getLoc(), operands[0], bitMask, dst,
- TypeAttr::get(rewriter.getIntegerType(8)), false);
- rewriter.replaceOp(srcOp, {dst});
- return success();
- } else {
- // Unhandled.
- return failure();
- }
- }
-};
-
-} // namespace
-
-void populateStandardToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter) {
- patterns.insert<ConstantOpConversion>(typeConverter, context);
- patterns.insert<CmpIOpConversion>(typeConverter, context);
- patterns.insert<CmpFOpConversion>(typeConverter, context);
- patterns.insert<ZeroExtendIOpConversion>(typeConverter, context);
-
- patterns.insert<VMLAOpConversion<mlir::ReturnOp, mlir::ReturnOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::AddIOp, IREE::VMLA::AddOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::AddFOp, IREE::VMLA::AddOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::SubIOp, IREE::VMLA::SubOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::SubFOp, IREE::VMLA::SubOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::MulIOp, IREE::VMLA::MulOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::MulFOp, IREE::VMLA::MulOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::SignedDivIOp, IREE::VMLA::DivOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::UnsignedDivIOp, IREE::VMLA::DivOp,
- VMLAOpSemantics::kForceUnsigned>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::DivFOp, IREE::VMLA::DivOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::AbsFOp, IREE::VMLA::AbsOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::SignedRemIOp, IREE::VMLA::RemOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::UnsignedRemIOp, IREE::VMLA::RemOp,
- VMLAOpSemantics::kForceUnsigned>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::RemFOp, IREE::VMLA::RemOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::math::LogOp, IREE::VMLA::LogOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::math::ExpOp, IREE::VMLA::ExpOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::math::SqrtOp, IREE::VMLA::SqrtOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::math::CosOp, IREE::VMLA::CosOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::math::TanhOp, IREE::VMLA::TanhOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::NegFOp, IREE::VMLA::NegOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::AndOp, IREE::VMLA::AndOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::OrOp, IREE::VMLA::OrOp>>(typeConverter,
- context);
- patterns.insert<VMLAOpConversion<mlir::XOrOp, IREE::VMLA::XorOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::ShiftLeftOp, IREE::VMLA::ShlOp>>(
- typeConverter, context);
- patterns
- .insert<VMLAOpConversion<mlir::SignedShiftRightOp, IREE::VMLA::ShrOp>>(
- typeConverter, context);
- patterns
- .insert<VMLAOpConversion<mlir::UnsignedShiftRightOp, IREE::VMLA::ShrOp,
- VMLAOpSemantics::kForceUnsigned>>(typeConverter,
- context);
- patterns.insert<VMLAOpConversion<mlir::CeilFOp, IREE::VMLA::CeilOp>>(
- typeConverter, context);
- patterns.insert<VMLAOpConversion<mlir::SelectOp, IREE::VMLA::SelectOp>>(
- typeConverter, context);
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h
deleted file mode 100644
index a65d1ce..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_CONVERSION_STANDARDTOVMLA_CONVERTSTANDARDTOVMLA_H_
-#define IREE_COMPILER_DIALECT_VMLA_CONVERSION_STANDARDTOVMLA_CONVERTSTANDARDTOVMLA_H_
-
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Populates conversion patterns from the std dialect to the VMLA dialect.
-void populateStandardToVMLAPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns,
- TypeConverter &typeConverter);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_CONVERSION_STANDARDTOVMLA_CONVERTSTANDARDTOVMLA_H_
- // // NOLINT
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD
deleted file mode 100644
index 80a7bd4..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "comparison_ops.mlir",
- "constant_ops.mlir",
- "math_ops.mlir",
- ],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt
deleted file mode 100644
index ed4a341..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/CMakeLists.txt
+++ /dev/null
@@ -1,25 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "comparison_ops.mlir"
- "constant_ops.mlir"
- "math_ops.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/comparison_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/comparison_ops.mlir
deleted file mode 100644
index bd516c8..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/comparison_ops.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @cmp_i
-func private @cmp_i(%arg0 : tensor<4xi32>, %arg1 : tensor<4xi32>) -> tensor<4xi1> {
- // CHECK: %[[BUF_SZ:.+]] = constant 4
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.cmp GE, %arg0, %arg1, out %[[BUF]] : i32
- %0 = cmpi sge, %arg0, %arg1 : tensor<4xi32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @cmp_f
-func private @cmp_f(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xi1> {
- // CHECK: %[[BUF_SZ:.+]] = constant 4
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.cmp GE, %arg0, %arg1, out %[[BUF]] : f32
- %0 = cmpf oge, %arg0, %arg1 : tensor<4xf32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xi1>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/constant_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/constant_ops.mlir
deleted file mode 100644
index d4b204c..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/constant_ops.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @constant_scalar
-func private @constant_scalar() -> tensor<i16> {
- // CHECK: = vmla.constant dense<12345> : tensor<i16> -> !vmla.buffer
- %0 = constant dense<12345> : tensor<i16>
- return %0 : tensor<i16>
-}
-
-// -----
-
-// CHECK-LABEL: @constant_tensor
-func private @constant_tensor() -> tensor<4xf32> {
- // CHECK: = vmla.constant dense<[-1.000000e+00, -2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> -> !vmla.buffer
- %0 = constant dense<[-1.0, -2.0, 3.0, 4.0]> : tensor<4xf32>
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @constant_tensor_bool
-func private @constant_tensor_bool() -> tensor<4xi1> {
- // CHECK: = vmla.constant dense<[0, 1, 1, 0]> : tensor<4xi8> -> !vmla.buffer
- %0 = constant dense<[false, true, true, false]> : tensor<4xi1>
- return %0 : tensor<4xi1>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/math_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/math_ops.mlir
deleted file mode 100644
index dde8e6a..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/test/math_ops.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
-
-// CHECK-LABEL: @absf
-func private @absf(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 16
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.abs %arg0, out %[[BUF]] : f32
- %0 = absf %arg0 : tensor<4xf32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @shr_signed
-func private @shr_signed(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 16
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.shr %arg0, %arg0, out %[[BUF]] : i32
- %0 = shift_right_signed %arg0, %arg0 : tensor<4xi32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @shr_unsigned
-func private @shr_unsigned(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
- // CHECK-NEXT: %[[BUF_SZ:.+]] = constant 16
- // CHECK-NEXT: %[[BUF:.+]] = vmla.buffer.alloc byte_length = %[[BUF_SZ]] : !vmla.buffer
- // CHECK-NEXT: vmla.shr %arg0, %arg0, out %[[BUF]] {force_unsigned} : i32
- %0 = shift_right_unsigned %arg0, %arg0 : tensor<4xi32>
- // CHECK-NEXT: return %[[BUF]]
- return %0 : tensor<4xi32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.cpp
deleted file mode 100644
index ce35638..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "mlir/IR/BuiltinTypes.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-VMLATypeConverter::VMLATypeConverter() {
- addConversion([](Type type) -> Type {
- if (type.isInteger(1)) {
- // Widen i1 to i8.
- return IntegerType::get(type.getContext(), 8);
- }
- return type;
- });
-
- addConversion([](TensorType type) {
- // TODO(benvanik): composite-type conversion (buffer + dynamic dims).
- return IREE::VMLA::BufferType::get(type.getContext());
- });
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h b/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h
deleted file mode 100644
index 0567194..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_CONVERSION_TYPECONVERTER_H_
-#define IREE_COMPILER_DIALECT_VMLA_CONVERSION_TYPECONVERTER_H_
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-class VMLATypeConverter : public TypeConverter {
- public:
- VMLATypeConverter();
-
- // Returns the number of bytes an element of the given type occupies
- // post-conversion. For example, the size of i1 would be '1 byte'.
- static int32_t getRoundedElementByteWidth(Type type) {
- return (type.getIntOrFloatBitWidth() + 8 - 1) / 8;
- }
-
- // Converts a `tensor` type with an arbitrary element size to one supported by
- // VMLA. For example, `tensor<4x8xi1>` is converted to `tensor<4x8xi8>`.
- static TensorType convertTensorTypeToVMLAType(TensorType sourceType) {
- assert(sourceType.hasRank() && "only ranked tensors are supported");
- Type sourceElementType = sourceType.getElementType();
- Type targetElementType = sourceElementType;
- if (auto sourceIntType = sourceElementType.dyn_cast<IntegerType>()) {
- int32_t targetByteWidth = getRoundedElementByteWidth(sourceElementType);
- targetElementType =
- IntegerType::get(sourceElementType.getContext(), targetByteWidth * 8);
- }
- return RankedTensorType::get(sourceType.getShape(), targetElementType);
- }
-
- // TODO(benvanik): signature conversion for output buffers.
-};
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_CONVERSION_TYPECONVERTER_H_
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD
deleted file mode 100644
index 9f61f7c..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "VMLAToVM",
- srcs = [
- "ConvertVMLAToVM.cpp",
- ],
- hdrs = [
- "ConvertVMLAToVM.h",
- ],
- deps = [
- "//iree/compiler/Dialect/IREE/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/VM/Conversion",
- "//iree/compiler/Dialect/VM/Conversion/StandardToVM",
- "//iree/compiler/Dialect/VM/IR",
- "//iree/compiler/Dialect/VMLA:vmla_imports",
- "//iree/compiler/Dialect/VMLA/Conversion",
- "//iree/compiler/Dialect/VMLA/IR",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Transforms",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/CMakeLists.txt
deleted file mode 100644
index fd85f6e..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/CMakeLists.txt
+++ /dev/null
@@ -1,36 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- VMLAToVM
- HDRS
- "ConvertVMLAToVM.h"
- SRCS
- "ConvertVMLAToVM.cpp"
- DEPS
- MLIRIR
- MLIRPass
- MLIRStandard
- MLIRTransforms
- iree::compiler::Dialect::IREE::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::VM::Conversion
- iree::compiler::Dialect::VM::Conversion::StandardToVM
- iree::compiler::Dialect::VM::IR
- iree::compiler::Dialect::VMLA::Conversion
- iree::compiler::Dialect::VMLA::IR
- iree::compiler::Dialect::VMLA::vmla_imports
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
deleted file mode 100644
index 45712d3..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ /dev/null
@@ -1,442 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.h"
-
-#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/VM/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
-#include "iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.h"
-#include "iree/compiler/Dialect/VM/Conversion/TypeConverter.h"
-#include "iree/compiler/Dialect/VM/IR/VMOps.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "iree/compiler/Dialect/VMLA/vmla.imports.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace {
-
-// Erases an op. This should only be used for ops that are legalized away
-// as part of lowering (i.e. tagging or metadata ops that are unrepresentable
-// in the VM dialect).
-class EraseNonVMOp : public ConversionPattern {
- public:
- EraseNonVMOp(StringRef rootName, MLIRContext *ctx)
- : ConversionPattern(rootName, 0, ctx) {}
-
- LogicalResult matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-// When converting to the VM, it is safe to remove any identity tie_shape
-// ops that remain.
-class ElideTieShapeOp : public OpConversionPattern<Shape::TieShapeOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(
- Shape::TieShapeOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOp(op, operands[0]);
- return success();
- }
-};
-
-// VMLA -> VM import conversion base for generic ops.
-// Handles signatures with integers, VM types, or simple buffers.
-template <typename T, typename Adaptor = typename T::Adaptor>
-class VMLAImportOpConversion : public OpConversionPattern<T> {
- public:
- VMLAImportOpConversion(MLIRContext *context, SymbolTable &importSymbols,
- TypeConverter &typeConverter, StringRef importName)
- : OpConversionPattern<T>(context),
- importSymbols(importSymbols),
- typeConverter(typeConverter),
- importName(importName) {}
-
- LogicalResult matchAndRewrite(
- T op, llvm::ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- std::string importFqName = importName + getImportSuffix(op);
- auto importOp =
- importSymbols.template lookup<IREE::VM::ImportOp>(importFqName);
- if (!importOp) {
- op.emitError() << "failed to resolve VM function import for "
- << importFqName;
- return failure();
- }
- assert(importOp);
- return rewriteToCall(op, Adaptor{operands}, importOp, typeConverter,
- rewriter);
- }
-
- protected:
- virtual std::string getImportSuffix(T op) const { return ""; }
-
- std::string getSizedTypeStr(Type elementType) const {
- int bitWidth = elementType.getIntOrFloatBitWidth();
- // Widen i1 -> i8 to match the VM type conversion.
- if (bitWidth == 1) {
- bitWidth = 8;
- }
- return "x" + std::to_string(bitWidth);
- }
-
- std::string getTypedTypeStr(Type type, bool forceUnsigned = false) const {
- Type elementType = type;
- auto shapedType = type.dyn_cast<ShapedType>();
- if (shapedType) {
- elementType = shapedType.getElementType();
- }
-
- std::string typePrefix = "x";
- if (elementType.isa<FloatType>()) {
- typePrefix = "f";
- } else if (elementType.isSignlessInteger()) {
- typePrefix = forceUnsigned ? "u" : "i";
- }
-
- int bitWidth = elementType.getIntOrFloatBitWidth();
- // Widen i1 -> i8 to match the VM type conversion.
- if (bitWidth == 1) {
- bitWidth = 8;
- }
- return typePrefix + std::to_string(bitWidth);
- }
-
- private:
- SymbolTable &importSymbols;
- TypeConverter &typeConverter;
- std::string importName;
-};
-#define VMLA_IMPORT_OP(op_type, op_mnemonic) \
- patterns.insert<VMLAImportOpConversion<op_type>>( \
- context, importSymbols, typeConverter, op_mnemonic);
-
-// VMLA -> VM import conversion for ops using sized operands (foo.xNN).
-// This will use only the bit-width of the element type to add a .xNN suffix to
-// the op name. Assumes the element type is valid.
-template <typename T>
-class VMLASizedImportOpConversion : public VMLAImportOpConversion<T> {
- public:
- using VMLAImportOpConversion<T>::VMLAImportOpConversion;
-
- std::string getImportSuffix(T op) const override {
- return std::string(".") + this->getSizedTypeStr(op.element_type());
- }
-};
-#define VMLA_SIZED_IMPORT_OP(op_type, op_mnemonic) \
- patterns.insert<VMLASizedImportOpConversion<op_type>>( \
- context, importSymbols, typeConverter, op_mnemonic);
-
-// VMLA -> VM import conversion for ops using typed operands (foo.fNN, etc).
-// This will use the element type to add a type-specific suffix to the op name.
-// Assumes the element type is valid.
-template <typename T>
-class VMLATypedImportOpConversion : public VMLAImportOpConversion<T> {
- public:
- using VMLAImportOpConversion<T>::VMLAImportOpConversion;
-
- std::string getImportSuffix(T op) const override {
- bool forceUnsigned =
- !!static_cast<Operation *>(op)->getAttr("forceUnsigned");
- return "." + this->getTypedTypeStr(op.element_type(), forceUnsigned);
- }
-};
-#define VMLA_TYPED_IMPORT_OP(op_type, op_mnemonic) \
- patterns.insert<VMLATypedImportOpConversion<op_type>>( \
- context, importSymbols, typeConverter, op_mnemonic);
-
-class VMLAConstantOpConversion
- : public OpConversionPattern<IREE::VMLA::ConstantOp> {
- public:
- VMLAConstantOpConversion(MLIRContext *context,
- TypeConverter & /*typeConverter*/)
- : OpConversionPattern(context) {}
-
- LogicalResult matchAndRewrite(
- IREE::VMLA::ConstantOp op, llvm::ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- if (auto splatAttr = op.value().dyn_cast<SplatElementsAttr>()) {
- // Encode just a single splat element and use a buffer fill.
- auto rodataValue = rewriter.createOrFold<IREE::VM::RodataInlineOp>(
- op.getLoc(),
- IREE::VM::RefType::get(IREE::VM::BufferType::get(op.getContext())),
- DenseElementsAttr::get(
- RankedTensorType::get({1}, splatAttr.getSplatValue().getType()),
- splatAttr.getSplatValue()));
- auto fillValue = rewriter.createOrFold<IREE::VMLA::BufferConstOp>(
- op.getLoc(), IREE::VMLA::BufferType::get(op.getContext()),
- rodataValue);
- auto bufferLengthValue = rewriter.createOrFold<mlir::ConstantIndexOp>(
- op.getLoc(), splatAttr.getType().cast<ShapedType>().getNumElements() *
- VMLATypeConverter::getRoundedElementByteWidth(
- splatAttr.getSplatValue().getType()));
- auto bufferValue = rewriter.createOrFold<IREE::VMLA::BufferAllocOp>(
- op.getLoc(), IREE::VMLA::BufferType::get(op.getContext()),
- bufferLengthValue);
- rewriter.create<IREE::VMLA::BufferFillOp>(op.getLoc(), fillValue,
- bufferValue);
- rewriter.replaceOp(op, bufferValue);
- } else {
- // Encode constant data into a rodata segment. These will eventually get
- // deduped and combined.
- auto rodataValue = rewriter.createOrFold<IREE::VM::RodataInlineOp>(
- op.getLoc(),
- IREE::VM::RefType::get(IREE::VM::BufferType::get(op.getContext())),
- op.value());
- rewriter.replaceOpWithNewOp<IREE::VMLA::BufferConstOp>(
- op, IREE::VMLA::BufferType::get(op.getContext()), rodataValue);
- }
- return success();
- }
-
- private:
- // TODO(b/145839814): find a name that's unique or make the rewriter support
- // assigning unique names.
- int allocateUniqueId(Operation *context) const {
- if (uniqueContext != context) {
- uniqueContext = context;
- uniqueCounter = 0;
- }
- return uniqueCounter++;
- }
- mutable Operation *uniqueContext = nullptr;
- mutable int uniqueCounter = 0;
-};
-
-class VMLAConvertImportOpConversion
- : public VMLAImportOpConversion<IREE::VMLA::ConvertOp> {
- public:
- using VMLAImportOpConversion<IREE::VMLA::ConvertOp>::VMLAImportOpConversion;
-
- std::string getImportSuffix(IREE::VMLA::ConvertOp op) const override {
- return std::string(".") + getTypedTypeStr(op.src_type()) +
- std::string(".") + getTypedTypeStr(op.dst_type());
- }
-};
-
-class VMLABatchMatMulImportOpConversion
- : public VMLAImportOpConversion<IREE::VMLA::BatchMatMulOp> {
- public:
- using VMLAImportOpConversion<
- IREE::VMLA::BatchMatMulOp>::VMLAImportOpConversion;
-
- std::string getImportSuffix(IREE::VMLA::BatchMatMulOp op) const override {
- return std::string(".") + getTypedTypeStr(op.lhs_type()) +
- getTypedTypeStr(op.rhs_type()) + std::string(".") +
- getTypedTypeStr(op.dst_type());
- }
-};
-
-class VMLAConvImportOpConversion
- : public VMLAImportOpConversion<IREE::VMLA::ConvOp> {
- public:
- using VMLAImportOpConversion<IREE::VMLA::ConvOp>::VMLAImportOpConversion;
-
- std::string getImportSuffix(IREE::VMLA::ConvOp op) const override {
- return std::string(".") + getTypedTypeStr(op.input_type()) +
- getTypedTypeStr(op.filter_type()) + std::string(".") +
- getTypedTypeStr(op.dst_type());
- }
-};
-
-template <typename FftOp>
-class VMLAFftImportOpConversion : public VMLAImportOpConversion<FftOp> {
- public:
- using VMLAImportOpConversion<FftOp>::VMLAImportOpConversion;
-
- std::string getImportSuffix(FftOp op) const override {
- return std::string(".") + this->getTypedTypeStr(op.element_type());
- }
-};
-} // namespace
-
-void populateVMLAToVMPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- SymbolTable &importSymbols,
- OwningRewritePatternList &patterns) {
- patterns.insert<VMLAConstantOpConversion>(context, typeConverter);
- patterns.insert<EraseNonVMOp>(Shape::ConstRankedShapeOp::getOperationName(),
- context);
- patterns.insert<EraseNonVMOp>(Shape::MakeRankedShapeOp::getOperationName(),
- context);
- patterns.insert<ElideTieShapeOp>(context);
-
- VMLA_IMPORT_OP(IREE::VMLA::BufferConstOp, "vmla.buffer.const");
- VMLA_IMPORT_OP(IREE::VMLA::BufferAllocOp, "vmla.buffer.alloc");
- VMLA_IMPORT_OP(IREE::VMLA::BufferCloneOp, "vmla.buffer.clone");
- VMLA_IMPORT_OP(IREE::VMLA::BufferByteLengthOp, "vmla.buffer.byte_length");
- VMLA_IMPORT_OP(IREE::VMLA::BufferViewOp, "vmla.buffer.view");
- VMLA_IMPORT_OP(IREE::VMLA::BufferCopyOp, "vmla.buffer.copy");
- VMLA_IMPORT_OP(IREE::VMLA::BufferFillOp, "vmla.buffer.fill");
- VMLA_IMPORT_OP(IREE::VMLA::BufferLoadI32Op, "vmla.buffer.load.i32");
-
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::CmpOp, "vmla.cmp");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::SelectOp, "vmla.select");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::FiniteOp, "vmla.finite");
-
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::CopyOp, "vmla.copy");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::TransposeOp, "vmla.transpose");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::ReverseOp, "vmla.reverse");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::PadOp, "vmla.pad");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::GatherOp, "vmla.gather");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::ScatterOp, "vmla.scatter");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::BroadcastOp, "vmla.broadcast");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::IotaOp, "vmla.iota");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::TileOp, "vmla.tile");
-
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::NotOp, "vmla.not");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::AndOp, "vmla.and");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::AndBroadcastOp, "vmla.and.broadcast");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::OrOp, "vmla.or");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::XorOp, "vmla.xor");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::XorBroadcastOp, "vmla.xor.broadcast");
- VMLA_SIZED_IMPORT_OP(IREE::VMLA::ShlOp, "vmla.shl");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ShrOp, "vmla.shr");
-
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::AddOp, "vmla.add");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::SubOp, "vmla.sub");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::AbsOp, "vmla.abs");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::NegOp, "vmla.neg");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::MulOp, "vmla.mul");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::DivOp, "vmla.div");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::RemOp, "vmla.rem");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::PowOp, "vmla.pow");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ExpOp, "vmla.exp");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::LogOp, "vmla.log");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::RsqrtOp, "vmla.rsqrt");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::SqrtOp, "vmla.sqrt");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::CosOp, "vmla.cos");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::SinOp, "vmla.sin");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::TanhOp, "vmla.tanh");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::Atan2Op, "vmla.atan2");
-
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::MinOp, "vmla.min");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::MaxOp, "vmla.max");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ClampOp, "vmla.clamp");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::FloorOp, "vmla.floor");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::CeilOp, "vmla.ceil");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::RoundOp, "vmla.round");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::SortOp, "vmla.sort");
-
- patterns.insert<VMLAConvertImportOpConversion>(context, importSymbols,
- typeConverter, "vmla.convert");
- patterns.insert<VMLABatchMatMulImportOpConversion>(
- context, importSymbols, typeConverter, "vmla.batch.matmul");
- patterns.insert<VMLAConvImportOpConversion>(context, importSymbols,
- typeConverter, "vmla.conv");
- patterns.insert<VMLAFftImportOpConversion<IREE::VMLA::FftOp>>(
- context, importSymbols, typeConverter, "vmla.fft");
- patterns.insert<VMLAFftImportOpConversion<IREE::VMLA::IfftOp>>(
- context, importSymbols, typeConverter, "vmla.ifft");
- patterns.insert<VMLAFftImportOpConversion<IREE::VMLA::RfftOp>>(
- context, importSymbols, typeConverter, "vmla.rfft");
- patterns.insert<VMLAFftImportOpConversion<IREE::VMLA::IrfftOp>>(
- context, importSymbols, typeConverter, "vmla.irfft");
-
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceSumOp, "vmla.reduce.sum");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMinOp, "vmla.reduce.min");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceMaxOp, "vmla.reduce.max");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceAndOp, "vmla.reduce.and");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::ReduceOrOp, "vmla.reduce.or");
-
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingSumOp, "vmla.pooling.sum");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMinOp, "vmla.pooling.min");
- VMLA_TYPED_IMPORT_OP(IREE::VMLA::PoolingMaxOp, "vmla.pooling.max");
-
- VMLA_IMPORT_OP(IREE::VMLA::InterfaceConstOp, "vmla.interface.const");
- VMLA_IMPORT_OP(IREE::VMLA::InterfaceBindingOp, "vmla.interface.binding");
-}
-
-namespace {
-
-// A pass converting the IREE flow dialect into the IREE VMLA dialect.
-class ConvertVMLAToVMPass
- : public PassWrapper<ConvertVMLAToVMPass, OperationPass<ModuleOp>> {
- public:
- explicit ConvertVMLAToVMPass(IREE::VM::TargetOptions targetOptions)
- : targetOptions_(targetOptions) {}
-
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<IREEDialect, IREE::VM::VMDialect>();
- }
-
- void runOnOperation() override {
- auto *context = &getContext();
-
- VMConversionTarget conversionTarget(context);
- IREE::VM::TypeConverter typeConverter(targetOptions_);
-
- mlir::ModuleOp outerModuleOp, innerModuleOp;
- std::tie(outerModuleOp, innerModuleOp) =
- VMConversionTarget::nestModuleForConversion(getOperation());
-
- (void)appendImportModule(StringRef(iree_vmla_imports_create()->data,
- iree_vmla_imports_create()->size),
- innerModuleOp);
-
- OwningRewritePatternList conversionPatterns(&getContext());
- populateStandardToVMPatterns(context, typeConverter, conversionPatterns);
-
- SymbolTable importSymbols(innerModuleOp);
- populateVMLAToVMPatterns(context, typeConverter, importSymbols,
- conversionPatterns);
-
- // Usually shape conversion patterns come in at a higher level, but for
- // this standalone pass, they must be provided directly.
- Shape::populateFoldConversionPatterns(&getContext(), conversionPatterns);
-
- if (failed(applyPartialConversion(outerModuleOp, conversionTarget,
- std::move(conversionPatterns)))) {
- outerModuleOp.emitError() << "conversion to vm.module failed";
- return signalPassFailure();
- }
- }
-
- private:
- IREE::VM::TargetOptions targetOptions_;
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> createConvertVMLAToVMPass(
- IREE::VM::TargetOptions targetOptions) {
- return std::make_unique<ConvertVMLAToVMPass>(targetOptions);
-}
-
-static PassRegistration<ConvertVMLAToVMPass> pass(
- "iree-convert-vmla-to-vm",
- "Convert the IREE VMLA dialect to the IREE VM dialect", [] {
- auto options = IREE::VM::getTargetOptionsFromFlags();
- return std::make_unique<ConvertVMLAToVMPass>(options);
- });
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.h b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.h
deleted file mode 100644
index db39320..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.h
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_CONVERSION_VMLATOVM_CONVERTVMLATOVM_H_
-#define IREE_COMPILER_DIALECT_VMLA_CONVERSION_VMLATOVM_CONVERTVMLATOVM_H_
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-// Populates conversion patterns from the VMLA dialect to the VM dialect.
-void populateVMLAToVMPatterns(MLIRContext *context,
- TypeConverter &typeConverter,
- SymbolTable &importSymbols,
- OwningRewritePatternList &patterns);
-
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_CONVERSION_VMLATOVM_CONVERTVMLATOVM_H_
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD
deleted file mode 100644
index 368d102..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "constant_ops.mlir",
- "conversion.mlir",
- ],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt
deleted file mode 100644
index e60fbfe..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/CMakeLists.txt
+++ /dev/null
@@ -1,24 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "constant_ops.mlir"
- "conversion.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/constant_ops.mlir b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/constant_ops.mlir
deleted file mode 100644
index 8635072..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/constant_ops.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-convert-vmla-to-vm -cse %s | IreeFileCheck %s
-
-// CHECK-LABEL: vm.func @denseConstant
-func @denseConstant() -> !vmla.buffer {
- // CHECK-NEXT: [[RODATA:%.+]] = vm.rodata.inline : !vm.buffer = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32
- // CHECK-NEXT: = vm.call @vmla.buffer.const([[RODATA]]) : (!vm.buffer) -> !vm.ref<!vmla.buffer>
- %0 = vmla.constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32> -> !vmla.buffer
- return %0 : !vmla.buffer
-}
-
-// -----
-
-// CHECK-LABEL: @splatConstant
-func @splatConstant() -> !vmla.buffer {
- // CHECK-NEXT: [[RODATA:%.+]] = vm.rodata.inline : !vm.buffer = dense<0.176776692> : tensor<1xf32>
- // CHECK-NEXT: [[SPLATTED:%.+]] = vm.call @vmla.buffer.const([[RODATA]]) : (!vm.buffer) -> !vm.ref<!vmla.buffer>
- // CHECK-NEXT: [[LENGTH:%.+]] = vm.const.i32 2359296 : i32
- // CHECK-NEXT: [[RESULT:%.+]] = vm.call @vmla.buffer.alloc([[LENGTH]]) : (i32) -> !vm.ref<!vmla.buffer>
- // CHECK-NEXT: vm.call @vmla.buffer.fill([[SPLATTED]], [[RESULT]]) : (!vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>) -> ()
- %0 = vmla.constant dense<0.176776692> : tensor<1x4x384x384xf32> -> !vmla.buffer
- // CHECK-NEXT: vm.return [[RESULT]] : !vm.ref<!vmla.buffer>
- return %0 : !vmla.buffer
-}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/conversion.mlir b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/conversion.mlir
deleted file mode 100644
index fa3aae2..0000000
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/test/conversion.mlir
+++ /dev/null
@@ -1,73 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-convert-vmla-to-vm -cse %s | IreeFileCheck %s
-
-// CHECK-LABEL: vm.func @bufferImport
-func @bufferImport() -> !vmla.buffer {
- %c0 = std.constant 1 : index
- // CHECK: = vm.call @vmla.buffer.alloc(%c1) : (i32) -> !vm.ref<!vmla.buffer>
- %0 = vmla.buffer.alloc byte_length = %c0 : !vmla.buffer
- return %0 : !vmla.buffer
-}
-
-// -----
-
-// CHECK-LABEL: vm.func @typedImport
-func @typedImport(%arg0 : !vmla.buffer, %arg1 : !vmla.buffer) {
- // CHECK-NEXT: %c1 = vm.const.i32 1 : i32
- // CHECK-NEXT: vm.call @vmla.cmp.f32(%c1, %arg0, %arg0, %arg1) : (i32, !vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>, !vm.ref<!vmla.buffer>) -> ()
- vmla.cmp "NE", %arg0, %arg0, out %arg1 : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: vm.func @sizedImport
-func @sizedImport(%arg0 : !vmla.buffer, %arg1 : !vmla.buffer) {
- // CHECK-NEXT: vm.call @vmla.select.x32(%arg0, %arg0, %arg0, %arg1)
- vmla.select %arg0, %arg0, %arg0, out %arg1 : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: vm.func @shapeExpansion
-// CHECK-SAME: %arg0: !vm.ref<!vmla.buffer>, %arg1: i32, %arg2: !vm.ref<!vmla.buffer>, %arg3: i32
-func @shapeExpansion(%arg0 : !vmla.buffer, %arg1 : index, %arg2 : !vmla.buffer, %arg3 : index) {
- %rs0 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[4,?,8]>
- %rs1 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?,4,8]>
- // CHECK-DAG: %c1 = vm.const.i32 1 : i32
- // CHECK-DAG: %c4 = vm.const.i32 4 : i32
- // CHECK-DAG: %c8 = vm.const.i32 8 : i32
- // CHECK-NEXT: vm.call.variadic @vmla.transpose.x16(%arg0, [%c4, %arg1, %c8], [%c1], %arg2, [%arg3, %c4, %c8]) : (!vm.ref<!vmla.buffer>, i32 ..., i32 ..., !vm.ref<!vmla.buffer>, i32 ...)
- vmla.transpose %arg0(%rs0 : !shapex.ranked_shape<[4,?,8]>),
- out %arg2(%rs1 : !shapex.ranked_shape<[?,4,8]>)
- {permutation = dense<[1]> : tensor<1xi32>} : i16
- return
-}
-
-// -----
-
-// CHECK-LABEL: vm.func @convert
-func @convert(%arg0 : !vmla.buffer, %arg1 : !vmla.buffer) {
- // CHECK-NEXT: vm.call @vmla.convert.f32.i8(%arg0, %arg1)
- vmla.convert %arg0, out %arg1 : f32 -> i8
- return
-}
-
-// -----
-
-// CHECK-LABEL: vm.func @batch_matmul
-func @batch_matmul(
- %lhs : !vmla.buffer,
- %lhs_dim : index,
- %rhs : !vmla.buffer,
- %rhs_dim : index,
- %dst : !vmla.buffer) {
- %lhs_shape = shapex.make_ranked_shape %lhs_dim : (index) -> !shapex.ranked_shape<[3,4,?]>
- %rhs_shape = shapex.make_ranked_shape %rhs_dim : (index) -> !shapex.ranked_shape<[3,?,4]>
- %dst_shape = shapex.const_ranked_shape : !shapex.ranked_shape<[3,4,4]>
- // CHECK: vm.call.variadic @vmla.batch.matmul.f32f32.f32(%arg0, [%c3, %c4, %arg1], %arg2, [%c3, %arg3, %c4], %arg4, [%c3, %c4, %c4])
- vmla.batch.matmul %lhs(%lhs_shape : !shapex.ranked_shape<[3,4,?]>) : f32,
- %rhs(%rhs_shape : !shapex.ranked_shape<[3,?,4]>) : f32,
- out %dst(%dst_shape : !shapex.ranked_shape<[3,4,4]>) : f32
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/IR/BUILD b/iree/compiler/Dialect/VMLA/IR/BUILD
deleted file mode 100644
index d648708..0000000
--- a/iree/compiler/Dialect/VMLA/IR/BUILD
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//build_tools/bazel:iree_tablegen_doc.bzl", "iree_tablegen_doc")
-load("//build_tools/bazel:tblgen.bzl", "gentbl_cc_library")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-exports_files(["VMLAOps.td"])
-
-filegroup(
- name = "td_files",
- srcs = enforce_glob(
- [
- "VMLABase.td",
- "VMLAOps.td",
- ],
- include = ["*.td"],
- ),
-)
-
-cc_library(
- name = "IR",
- srcs = [
- "VMLAEnums.cpp.inc",
- "VMLAOpInterface.cpp.inc",
- "VMLAOps.cpp",
- "VMLATypes.cpp",
- ],
- hdrs = [
- "VMLAEnums.h.inc",
- "VMLAOpInterface.h.inc",
- "VMLAOps.h",
- "VMLAOps.h.inc",
- "VMLATraits.h",
- "VMLATypes.h",
- ],
- textual_hdrs = [
- "VMLAOps.cpp.inc",
- ],
- deps = [
- ":VMLAEnumsGen",
- ":VMLAOpInterfaceGen",
- ":VMLAOpsGen",
- "//iree/compiler/Dialect/IREE/IR",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/VM/IR",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:SideEffects",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TransformUtils",
- "@llvm-project//mlir:Translation",
- ],
-)
-
-cc_library(
- name = "VMLADialect",
- srcs = ["VMLADialect.cpp"],
- hdrs = ["VMLADialect.h"],
- deps = [
- ":IR",
- "//iree/compiler/Dialect/VM/Conversion",
- "//iree/compiler/Dialect/VMLA:vmla_imports",
- "//iree/compiler/Dialect/VMLA/Conversion/VMLAToVM",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Parser",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:TransformUtils",
- ],
-)
-
-gentbl_cc_library(
- name = "VMLAEnumsGen",
- tbl_outs = [
- (
- ["-gen-enum-decls"],
- "VMLAEnums.h.inc",
- ),
- (
- ["-gen-enum-defs"],
- "VMLAEnums.cpp.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "VMLABase.td",
- td_srcs = [
- ":td_files",
- "//iree/compiler/Dialect/IREE/IR:td_files",
- "//iree/compiler/Dialect/Shape/IR:td_files",
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:StdOpsTdFiles",
- ],
-)
-
-gentbl_cc_library(
- name = "VMLAOpsGen",
- tbl_outs = [
- (
- ["-gen-op-decls"],
- "VMLAOps.h.inc",
- ),
- (
- ["-gen-op-defs"],
- "VMLAOps.cpp.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "VMLAOps.td",
- td_srcs = [
- ":td_files",
- "//iree/compiler/Dialect/IREE/IR:td_files",
- "//iree/compiler/Dialect/Shape/IR:td_files",
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:StdOpsTdFiles",
- ],
-)
-
-gentbl_cc_library(
- name = "VMLAOpInterfaceGen",
- tbl_outs = [
- (
- ["-gen-op-interface-decls"],
- "VMLAOpInterface.h.inc",
- ),
- (
- ["-gen-op-interface-defs"],
- "VMLAOpInterface.cpp.inc",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "VMLABase.td",
- td_srcs = [
- ":td_files",
- "//iree/compiler/Dialect/IREE/IR:td_files",
- "//iree/compiler/Dialect/Shape/IR:td_files",
- "@llvm-project//mlir:OpBaseTdFiles",
- ],
-)
-
-iree_tablegen_doc(
- name = "VMLADialecDocGen",
- tbl_outs = [
- (
- ["-gen-dialect-doc"],
- "VMLADialect.md",
- ),
- ],
- tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "VMLAOps.td",
- td_srcs = [
- ":td_files",
- "//iree/compiler/Dialect/IREE/IR:td_files",
- "//iree/compiler/Dialect/Shape/IR:td_files",
- "@llvm-project//mlir:OpBaseTdFiles",
- "@llvm-project//mlir:StdOpsTdFiles",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt b/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt
deleted file mode 100644
index c33ab9f..0000000
--- a/iree/compiler/Dialect/VMLA/IR/CMakeLists.txt
+++ /dev/null
@@ -1,104 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/IR/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- IR
- HDRS
- "VMLAEnums.h.inc"
- "VMLAOpInterface.h.inc"
- "VMLAOps.h"
- "VMLAOps.h.inc"
- "VMLATraits.h"
- "VMLATypes.h"
- TEXTUAL_HDRS
- "VMLAOps.cpp.inc"
- SRCS
- "VMLAEnums.cpp.inc"
- "VMLAOpInterface.cpp.inc"
- "VMLAOps.cpp"
- "VMLATypes.cpp"
- DEPS
- LLVMSupport
- MLIRIR
- MLIRSideEffectInterfaces
- MLIRStandard
- MLIRSupport
- MLIRTransformUtils
- MLIRTranslation
- iree::compiler::Dialect::IREE::IR
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::VM::IR
- PUBLIC
-)
-
-iree_cc_library(
- NAME
- VMLADialect
- HDRS
- "VMLADialect.h"
- SRCS
- "VMLADialect.cpp"
- DEPS
- ::IR
- LLVMSupport
- MLIRIR
- MLIRParser
- MLIRStandard
- MLIRSupport
- MLIRTransformUtils
- iree::compiler::Dialect::VM::Conversion
- iree::compiler::Dialect::VMLA::Conversion::VMLAToVM
- iree::compiler::Dialect::VMLA::vmla_imports
- PUBLIC
-)
-
-iree_tablegen_library(
- NAME
- VMLAEnumsGen
- TD_FILE
- "VMLABase.td"
- OUTS
- -gen-enum-decls VMLAEnums.h.inc
- -gen-enum-defs VMLAEnums.cpp.inc
-)
-
-iree_tablegen_library(
- NAME
- VMLAOpsGen
- TD_FILE
- "VMLAOps.td"
- OUTS
- -gen-op-decls VMLAOps.h.inc
- -gen-op-defs VMLAOps.cpp.inc
-)
-
-iree_tablegen_library(
- NAME
- VMLAOpInterfaceGen
- TD_FILE
- "VMLABase.td"
- OUTS
- -gen-op-interface-decls VMLAOpInterface.h.inc
- -gen-op-interface-defs VMLAOpInterface.cpp.inc
-)
-
-iree_tablegen_doc(
- NAME
- VMLADialecDocGen
- TD_FILE
- "VMLAOps.td"
- OUTS
- -gen-dialect-doc VMLADialect.md
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLABase.td b/iree/compiler/Dialect/VMLA/IR/VMLABase.td
deleted file mode 100644
index 538192e..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLABase.td
+++ /dev/null
@@ -1,220 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_DIALECT_VMLA_BASE
-#define IREE_DIALECT_VMLA_BASE
-
-include "iree/compiler/Dialect/IREE/IR/IREEBase.td"
-include "iree/compiler/Dialect/Shape/IR/ShapeBase.td"
-
-//===----------------------------------------------------------------------===//
-// IREE VMLA (Virtual Machine-based Linear Algebra) dialect
-//===----------------------------------------------------------------------===//
-
-def VMLA_Dialect : Dialect {
- let name = "vmla";
- let cppNamespace = "::mlir::iree_compiler::IREE::VMLA";
-
- let summary = [{
- A dialect representing operations against the IREE VM-based backend.
- }];
- let description = [{
- This is a reference dialect representing a simple IREE VM-based linear
- algebra module that is used as a library at runtime. The ops in this dialect
- map (roughly) 1:1 with the exported functions in the runtime module.
-
- See `vmla.imports.mlir` for the full list of exported functions.
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA enums
-//===----------------------------------------------------------------------===//
-
-def VMLA_CmpPredicate_EQ : I32EnumAttrCase<"EQ", 0>;
-def VMLA_CmpPredicate_NE : I32EnumAttrCase<"NE", 1>;
-def VMLA_CmpPredicate_LT : I32EnumAttrCase<"LT", 2>;
-def VMLA_CmpPredicate_LE : I32EnumAttrCase<"LE", 3>;
-def VMLA_CmpPredicate_GT : I32EnumAttrCase<"GT", 4>;
-def VMLA_CmpPredicate_GE : I32EnumAttrCase<"GE", 5>;
-def VMLA_CmpPredicateAttr :
- I32EnumAttr<"CmpPredicate", "IREE VMLA comparison op predicate", [
- VMLA_CmpPredicate_EQ,
- VMLA_CmpPredicate_NE,
- VMLA_CmpPredicate_LT,
- VMLA_CmpPredicate_LE,
- VMLA_CmpPredicate_GT,
- VMLA_CmpPredicate_GE,
- ]> {
- let cppNamespace = "::mlir::iree_compiler::IREE::VMLA";
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA types
-//===----------------------------------------------------------------------===//
-
-def VMLA_DeviceSize : TypeAlias<Index>;
-def VMLA_DeviceSizeAttr : IREE_IndexAttrBase<"size_t">;
-
-def VMLA_HostSize : TypeAlias<Index>;
-def VMLA_HostSizeAttr : IREE_IndexAttrBase<"size_t">;
-
-def VMLA_Index : TypeAlias<Index>;
-
-def VMLA_Shape : TypeAlias<Shape_RankedShape>;
-
-def VMLA_HostBuffer : AnyTypeOf<[
- ByteBufferType,
- MutableByteBufferType,
-]>;
-
-def VMLA_Buffer : DialectType<
- VMLA_Dialect,
- CPred<"$_self.isa<IREE::VMLA::BufferType>()">,
- "buffer"> {
- let description = [{
- A lightweight unshaped byte buffer.
- }];
- let builderCall = "$_builder.getType<IREE::VMLA::BufferType>()";
-}
-
-def VMLA_Interface : DialectType<
- VMLA_Dialect,
- CPred<"$_self.isa<IREE::VMLA::InterfaceType>()">,
- "interface"> {
- let description = [{
- Binding and parameter interface (derived from `hal.interface`).
- }];
-
- let builderCall = "$_builder.getType<IREE::VMLA::InterfaceType>()";
-}
-
-// TODO(benvanik): figure out how to get constraints to work.
-// def VMLA_AnyTypeAttr : Confined<TypeAttr, [I8Attr, I16Attr, I32Attr, F32Attr]> {
-// let constBuilderCall = ?;
-// let defaultValue = ?;
-// }
-// def VMLA_FloatTypeAttr : Confined<TypeAttr, [F32Attr]> {
-// let constBuilderCall = ?;
-// let defaultValue = ?;
-// }
-def VMLA_AnyTypeAttr : TypeAttrBase<"Type", "any type attribute">;
-def VMLA_FloatTypeAttr : TypeAttrBase<"Type", "any type attribute">;
-
-//===----------------------------------------------------------------------===//
-// VMLA op traits
-//===----------------------------------------------------------------------===//
-
-// Operations with this trait require shapes be provided for all buffers.
-// For example, if the original HLO op had an `%arg : tensor<4x?xf32>`, adding
-// this trait will have the converted op contain both a `%arg : !vmla.buffer`
-// and an `%arg_shape : shapex.ranked_shape<[4,?]>`.
-def VMLA_IncludeShapes : NativeOpTrait<"IREE::VMLA::IncludeShapes">;
-
-//===----------------------------------------------------------------------===//
-// Base VMLA op classes
-//===----------------------------------------------------------------------===//
-
-def VMLA_OpInterface : OpInterface<"VMLAOp"> {
- let description = [{
- Interface for VMLA ops that can be used to customize the lowering.
- This is required as there is not a way to get reflection information about
- ops.
- }];
-
- let methods = [
- StaticInterfaceMethod<
- "Extracts type information attributes that may be required by the op.",
- "void", "extractTypeAttributes", (ins "OperationState &":$state, "ArrayRef<Type>":$operandTypes, "ArrayRef<Type>":$resultTypes), [{
- ConcreteOp::extractTypeAttributes(state, operandTypes, resultTypes);
- }], [{
- return; // default to no extraction
- }]
- >,
- ];
-}
-
-class VMLA_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<VMLA_Dialect, mnemonic, !listconcat(traits, [])> {
- // TODO(benvanik): use new tablegen printer/parser.
- // let parser = [{ return parse$cppClass(parser, &result); }];
- // let printer = [{ return print$cppClass(p, *this); }];
-}
-
-class VMLA_ElementTypeOp<string mnemonic, list<OpTrait> traits = []> :
- VMLA_Op<mnemonic, !listconcat(traits, [VMLA_OpInterface])> {
- let extraClassDeclaration = [{
- static void extractTypeAttributes(OperationState &state, ArrayRef<Type> operandTypes, ArrayRef<Type> resultTypes) {
- state.addAttribute("element_type", TypeAttr::get(resultTypes[0].cast<ShapedType>().getElementType()));
- }
- }];
-}
-
-class VMLA_UnaryOp<string mnemonic, Attr typeAttr, list<OpTrait> traits = []> :
- VMLA_ElementTypeOp<mnemonic, traits> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Buffer:$dst,
- typeAttr:$element_type,
- // TODO(benvanik): remove once unsigned types are in MLIR.
- UnitAttr:$forceUnsigned
- );
-
- let assemblyFormat = "$src`,` `out` $dst attr-dict `:` $element_type";
-}
-
-class VMLA_BinaryOp<string mnemonic, Attr typeAttr, list<OpTrait> traits = []>
- : VMLA_ElementTypeOp<mnemonic, traits> {
- let arguments = (ins
- VMLA_Buffer:$lhs,
- VMLA_Buffer:$rhs,
- VMLA_Buffer:$dst,
- typeAttr:$element_type,
- // TODO(benvanik): remove once unsigned types are in MLIR.
- UnitAttr:$forceUnsigned
- );
-
- let assemblyFormat = "$lhs`,` $rhs`,` `out` $dst attr-dict `:` $element_type";
-}
-
-class VMLA_BinaryBroadcastOp<string mnemonic, Attr typeAttr, list<OpTrait> traits = []>
- : VMLA_ElementTypeOp<mnemonic, traits> {
- let arguments = (ins
- VMLA_Buffer:$lhs,
- I32:$rhs,
- VMLA_Buffer:$dst,
- typeAttr:$element_type,
- // TODO(benvanik): remove once unsigned types are in MLIR.
- UnitAttr:$forceUnsigned
- );
-
- let assemblyFormat = "$lhs`,` $rhs`,` `out` $dst attr-dict `:` $element_type";
-}
-
-class VMLA_TernaryOp<string mnemonic, Attr typeAttr, list<OpTrait> traits = []>
- : VMLA_ElementTypeOp<mnemonic, traits> {
- let arguments = (ins
- VMLA_Buffer:$a,
- VMLA_Buffer:$b,
- VMLA_Buffer:$c,
- VMLA_Buffer:$dst,
- typeAttr:$element_type,
- // TODO(benvanik): remove once unsigned types are in MLIR.
- UnitAttr:$forceUnsigned
- );
-
- let assemblyFormat = "$a`,` $b`,` $c`,` `out` $dst attr-dict `:` $element_type";
-}
-
-#endif // IREE_DIALECT_VMLA_BASE
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLADialect.cpp b/iree/compiler/Dialect/VMLA/IR/VMLADialect.cpp
deleted file mode 100644
index 2fbc4a0..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLADialect.cpp
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-
-#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "iree/compiler/Dialect/VMLA/vmla.imports.h"
-#include "llvm/Support/SourceMgr.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Parser.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-namespace {
-
-class VMLAToVMConversionInterface : public VMConversionDialectInterface {
- public:
- using VMConversionDialectInterface::VMConversionDialectInterface;
-
- OwningModuleRef parseVMImportModule() const override {
- return mlir::parseSourceString(StringRef(iree_vmla_imports_create()->data,
- iree_vmla_imports_create()->size),
- getDialect()->getContext());
- }
-
- void populateVMConversionPatterns(
- SymbolTable &importSymbols, OwningRewritePatternList &patterns,
- TypeConverter &typeConverter) const override {
- populateVMLAToVMPatterns(getDialect()->getContext(), typeConverter,
- importSymbols, patterns);
- }
-};
-
-} // namespace
-
-VMLADialect::VMLADialect(MLIRContext *context)
- : Dialect(getDialectNamespace(), context, TypeID::get<VMLADialect>()) {
- addInterfaces<VMLAToVMConversionInterface>();
-
- addTypes<BufferType, InterfaceType>();
-
-#define GET_OP_LIST
- addOperations<
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.cpp.inc"
- >();
-}
-
-//===----------------------------------------------------------------------===//
-// Type printing and parsing
-//===----------------------------------------------------------------------===//
-
-Type VMLADialect::parseType(DialectAsmParser &parser) const {
- StringRef typeName;
- if (parser.parseKeyword(&typeName)) return Type();
- auto type = llvm::StringSwitch<Type>(typeName)
- .Case("buffer", BufferType::get(getContext()))
- .Case("interface", InterfaceType::get(getContext()))
- .Default(nullptr);
- if (!type) {
- parser.emitError(parser.getCurrentLocation())
- << "unknown VMLA type: " << typeName;
- }
- return type;
-}
-
-void VMLADialect::printType(Type type, DialectAsmPrinter &p) const {
- if (type.isa<BufferType>()) {
- p << "buffer";
- } else if (type.isa<InterfaceType>()) {
- p << "interface";
- } else {
- llvm_unreachable("unknown VMLA type");
- }
-}
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLADialect.h b/iree/compiler/Dialect/VMLA/IR/VMLADialect.h
deleted file mode 100644
index c567b66..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLADialect.h
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_IR_VMLADIALECT_H_
-#define IREE_COMPILER_DIALECT_VMLA_IR_VMLADIALECT_H_
-
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-class VMLADialect : public Dialect {
- public:
- explicit VMLADialect(MLIRContext *context);
- static StringRef getDialectNamespace() { return "vmla"; }
-
- Type parseType(DialectAsmParser &parser) const override;
- void printType(Type type, DialectAsmPrinter &p) const override;
-};
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_IR_VMLADIALECT_H_
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.cpp b/iree/compiler/Dialect/VMLA/IR/VMLAOps.cpp
deleted file mode 100644
index eefc949..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/SMLoc.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/TypeUtilities.h"
-
-//===----------------------------------------------------------------------===//
-// TableGen definitions (intentionally last)
-//===----------------------------------------------------------------------===//
-
-#define GET_OP_CLASSES
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.cpp.inc"
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.h b/iree/compiler/Dialect/VMLA/IR/VMLAOps.h
deleted file mode 100644
index cf14f85..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.h
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_IR_VMLAOPS_H_
-#define IREE_COMPILER_DIALECT_VMLA_IR_VMLAOPS_H_
-
-#include <cstdint>
-
-#include "iree/compiler/Dialect/IREE/IR/IREETraits.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATraits.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-
-#define GET_OP_CLASSES
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h.inc"
-
-#endif // IREE_COMPILER_DIALECT_VMLA_IR_VMLAOPS_H_
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
deleted file mode 100644
index 9642297..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ /dev/null
@@ -1,845 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_DIALECT_VMLA_OPS
-#define IREE_DIALECT_VMLA_OPS
-
-include "iree/compiler/Dialect/VMLA/IR/VMLABase.td"
-include "mlir/IR/OpAsmInterface.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-
-class VMLA_PureOp<string mnemonic, list<OpTrait> traits = []> :
- VMLA_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: pseudo ops
-//===----------------------------------------------------------------------===//
-
-def VMLA_ConstantOp : VMLA_PureOp<"constant"> {
- let summary = [{constant buffer declaration}];
- let description = [{
- A pseudo-op used to represent a buffer with constant contents. This is later
- expanded into VM ops and the vmla.buffer.const op.
- }];
-
- let arguments = (ins
- ElementsAttr:$value
- );
- let results = (outs
- VMLA_Buffer:$result
- );
-
- let builders = [
- OpBuilder<(ins "ElementsAttr":$value),
- [{
- build($_builder, $_state, IREE::VMLA::BufferType::get($_builder.getContext()),
- value);
- }]>,
- ];
-
- let assemblyFormat = "attr-dict $value `->` type($result)";
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: buffer manipulation
-//===----------------------------------------------------------------------===//
-
-def VMLA_BufferConstOp : VMLA_PureOp<"buffer.const"> {
- let arguments = (ins
- VMLA_HostBuffer:$value
- );
- let results = (outs
- VMLA_Buffer:$result
- );
-
- let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)";
-}
-
-def VMLA_BufferAllocOp : VMLA_Op<"buffer.alloc"> {
- let arguments = (ins
- VMLA_DeviceSize:$byte_length
- );
- let results = (outs
- VMLA_Buffer:$result
- );
-
- let assemblyFormat = [{
- `byte_length` `=` $byte_length attr-dict `:` type($result)
- }];
-}
-
-def VMLA_BufferCloneOp : VMLA_Op<"buffer.clone"> {
- let arguments = (ins
- VMLA_Buffer:$src
- );
- let results = (outs
- VMLA_Buffer:$result
- );
-
- let assemblyFormat = "$src attr-dict `:` type($result)";
-}
-
-def VMLA_BufferByteLengthOp : VMLA_PureOp<"buffer.byte_length"> {
- let arguments = (ins
- VMLA_Buffer:$value
- );
- let results = (outs
- VMLA_DeviceSize:$result
- );
-
- let assemblyFormat = "$value attr-dict `:` type($result)";
-}
-
-def VMLA_BufferViewOp : VMLA_PureOp<"buffer.view"> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_DeviceSize:$byte_offset,
- VMLA_DeviceSize:$byte_length
- );
- let results = (outs
- VMLA_Buffer:$result
- );
-
- let assemblyFormat = [{
- $src`[`$byte_offset`]``,` `byte_length` `=` $byte_length
- attr-dict `:` type($result)
- }];
-}
-
-def VMLA_BufferCopyOp : VMLA_Op<"buffer.copy"> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_DeviceSize:$src_byte_offset,
- VMLA_Buffer:$dst,
- VMLA_DeviceSize:$dst_byte_offset,
- VMLA_DeviceSize:$byte_length
- );
-
- let assemblyFormat = [{
- $src`[`$src_byte_offset`]``,`
- `out` $dst`[`$dst_byte_offset`]``,` `byte_length` `=` $byte_length
- attr-dict
- }];
-}
-
-def VMLA_BufferFillOp : VMLA_Op<"buffer.fill"> {
- let arguments = (ins
- VMLA_Buffer:$value,
- VMLA_Buffer:$dst
- );
-
- let assemblyFormat = "$value`,` `out` $dst attr-dict";
-}
-
-def VMLA_BufferLoadI32Op : VMLA_PureOp<"buffer.load.i32"> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_DeviceSize:$byte_offset
- );
- let results = (outs
- I32:$result
- );
-
- let assemblyFormat = "$src`[`$byte_offset`]` attr-dict `:` type($result)";
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: comparison
-//===----------------------------------------------------------------------===//
-
-def VMLA_CmpOp : VMLA_ElementTypeOp<"cmp"> {
- let arguments = (ins
- VMLA_CmpPredicateAttr:$predicate,
- VMLA_Buffer:$lhs,
- VMLA_Buffer:$rhs,
- VMLA_Buffer:$dst,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $predicate`,` $lhs`,` $rhs`,` `out` $dst attr-dict `:` $element_type
- }];
-}
-
-def VMLA_SelectOp : VMLA_ElementTypeOp<"select"> {
- let arguments = (ins
- VMLA_Buffer:$cond,
- VMLA_Buffer:$lhs,
- VMLA_Buffer:$rhs,
- VMLA_Buffer:$dst,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $cond`,` $lhs`,` $rhs`,` `out` $dst attr-dict `:` $element_type
- }];
-}
-
-def VMLA_FiniteOp : VMLA_Op<"finite", [VMLA_OpInterface]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Buffer:$dst,
- VMLA_FloatTypeAttr:$element_type
- );
-
- let assemblyFormat = "$src`,` `out` $dst attr-dict `:` $element_type";
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: shape/structure
-//===----------------------------------------------------------------------===//
-
-def VMLA_CopyOp : VMLA_ElementTypeOp<"copy", [
- VMLA_IncludeShapes,
- SameVariadicOperandSize,
- ]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- Variadic<VMLA_Index>:$src_indices,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- Variadic<VMLA_Index>:$dst_indices,
- Variadic<VMLA_Index>:$lengths,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- (`src_indices` `=` `[` $src_indices^ `]``,`)?
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)`
- (`,` `dst_indices` `=` `[` $dst_indices^ `]`)?
- (`,` `lengths` `=` `[` $lengths^ `]`)? attr-dict `:` $element_type
- }];
-}
-
-def VMLA_TransposeOp : VMLA_ElementTypeOp<"transpose", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- ElementsAttr:$permutation,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_ReverseOp : VMLA_ElementTypeOp<"reverse", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- ElementsAttr:$dimensions,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_PadOp : VMLA_ElementTypeOp<"pad", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$value,
- VMLA_Shape:$value_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- ElementsAttr:$edge_padding_low,
- ElementsAttr:$edge_padding_high,
- ElementsAttr:$interior_padding,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- $value`(`$value_shape `:` type($value_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_BroadcastOp : VMLA_ElementTypeOp<"broadcast", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_IotaOp : VMLA_ElementTypeOp<"iota"> {
- let arguments = (ins
- VMLA_Buffer:$dst,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- `out` $dst attr-dict `:` $element_type
- }];
-}
-
-def VMLA_TileOp : VMLA_ElementTypeOp<"tile", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_GatherOp : VMLA_ElementTypeOp<"gather", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$indices,
- VMLA_Shape:$indices_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- I64Attr:$dim,
- I64Attr:$batch_dims,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- $indices`(`$indices_shape `:` type($indices_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_ScatterOp : VMLA_ElementTypeOp<"scatter", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$indices,
- VMLA_Shape:$indices_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- $indices`(`$indices_shape `:` type($indices_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: bit manipulation
-//===----------------------------------------------------------------------===//
-
-def VMLA_NotOp : VMLA_UnaryOp<"not", VMLA_AnyTypeAttr>;
-def VMLA_AndOp : VMLA_BinaryOp<"and", VMLA_AnyTypeAttr>;
-def VMLA_AndBroadcastOp : VMLA_BinaryBroadcastOp<"and.broadcast", VMLA_AnyTypeAttr>;
-def VMLA_OrOp : VMLA_BinaryOp<"or", VMLA_AnyTypeAttr>;
-def VMLA_XorOp : VMLA_BinaryOp<"xor", VMLA_AnyTypeAttr>;
-def VMLA_XorBroadcastOp : VMLA_BinaryBroadcastOp<"xor.broadcast", VMLA_AnyTypeAttr>;
-def VMLA_ShlOp : VMLA_BinaryOp<"shl", VMLA_AnyTypeAttr>;
-def VMLA_ShrOp : VMLA_BinaryOp<"shr", VMLA_AnyTypeAttr>;
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: arithmetic
-//===----------------------------------------------------------------------===//
-
-def VMLA_AddOp : VMLA_BinaryOp<"add", VMLA_AnyTypeAttr>;
-def VMLA_SubOp : VMLA_BinaryOp<"sub", VMLA_AnyTypeAttr>;
-def VMLA_AbsOp : VMLA_UnaryOp<"abs", VMLA_AnyTypeAttr>;
-def VMLA_NegOp : VMLA_UnaryOp<"neg", VMLA_AnyTypeAttr>;
-def VMLA_MulOp : VMLA_BinaryOp<"mul", VMLA_AnyTypeAttr>;
-def VMLA_DivOp : VMLA_BinaryOp<"div", VMLA_AnyTypeAttr>;
-def VMLA_RemOp : VMLA_BinaryOp<"rem", VMLA_AnyTypeAttr>;
-def VMLA_PowOp : VMLA_BinaryOp<"pow", VMLA_FloatTypeAttr>;
-def VMLA_ExpOp : VMLA_UnaryOp<"exp", VMLA_FloatTypeAttr>;
-def VMLA_LogOp : VMLA_UnaryOp<"log", VMLA_FloatTypeAttr>;
-def VMLA_RsqrtOp : VMLA_UnaryOp<"rsqrt", VMLA_FloatTypeAttr>;
-def VMLA_SqrtOp : VMLA_UnaryOp<"sqrt", VMLA_FloatTypeAttr>;
-def VMLA_CosOp : VMLA_UnaryOp<"cos", VMLA_FloatTypeAttr>;
-def VMLA_SinOp : VMLA_UnaryOp<"sin", VMLA_FloatTypeAttr>;
-def VMLA_TanhOp : VMLA_UnaryOp<"tanh", VMLA_FloatTypeAttr>;
-def VMLA_Atan2Op : VMLA_BinaryOp<"atan2", VMLA_FloatTypeAttr>;
-
-def VMLA_MinOp : VMLA_BinaryOp<"min", VMLA_AnyTypeAttr>;
-def VMLA_MaxOp : VMLA_BinaryOp<"max", VMLA_AnyTypeAttr>;
-def VMLA_ClampOp : VMLA_TernaryOp<"clamp", VMLA_AnyTypeAttr>;
-def VMLA_FloorOp : VMLA_UnaryOp<"floor", VMLA_FloatTypeAttr>;
-def VMLA_CeilOp : VMLA_UnaryOp<"ceil", VMLA_FloatTypeAttr>;
-def VMLA_RoundOp : VMLA_UnaryOp<"round", VMLA_FloatTypeAttr>;
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: conversion
-//===----------------------------------------------------------------------===//
-
-def VMLA_ConvertOp : VMLA_Op<"convert", [VMLA_OpInterface]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Buffer:$dst,
- VMLA_AnyTypeAttr:$src_type,
- VMLA_AnyTypeAttr:$dst_type
- );
-
- let extraClassDeclaration = [{
- static void extractTypeAttributes(OperationState &state, ArrayRef<Type> operandTypes, ArrayRef<Type> resultTypes) {
- state.addAttribute("src_type", TypeAttr::get(operandTypes[0]));
- state.addAttribute("dst_type", TypeAttr::get(resultTypes[0]));
- }
- }];
-
- let assemblyFormat = [{
- $src`,` `out` $dst attr-dict `:` $src_type `->` $dst_type
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: Convolution
-//===----------------------------------------------------------------------===//
-
-// Handles both Convolutions and Transpose convolutions.
-def VLMA_ConvOp : VMLA_Op<"conv", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$input,
- VMLA_Shape:$input_shape,
- VMLA_Buffer:$filter,
- VMLA_Shape:$filter_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- I32ElementsAttr:$window_strides,
- I32ElementsAttr:$padding,
- I32ElementsAttr:$lhs_dilation,
- I32ElementsAttr:$rhs_dilation,
- I32Attr:$feature_group_count,
- I32Attr:$batch_group_count,
- VMLA_FloatTypeAttr:$input_type,
- VMLA_FloatTypeAttr:$filter_type,
- VMLA_FloatTypeAttr:$dst_type
- );
-
- let extraClassDeclaration = [{
- static void extractTypeAttributes(OperationState &state, ArrayRef<Type> operandTypes, ArrayRef<Type> resultTypes) {
- state.addAttribute("input_type", TypeAttr::get(operandTypes[0]));
- state.addAttribute("filter_type", TypeAttr::get(operandTypes[1]));
- state.addAttribute("dst_type", TypeAttr::get(resultTypes[0]));
- }
- }];
-
- let assemblyFormat = [{
- $input`(`$input_shape `:` type($input_shape)`)` `:` $input_type`,`
- $filter`(`$filter_shape `:` type($filter_shape)`)` `:` $filter_type`,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` `:` $dst_type attr-dict
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: Sorting
-//===----------------------------------------------------------------------===//
-
-def VMLA_SortPseudoOp : VMLA_Op<"sort.pseudo"> {
- let summary = "Tensor-level pseudo-op of VMLA::SortOp.";
- let description = [{
- This is a tensor-level version of VMLA::SortOp, to facilitate
- the lowering process.
-
- This operation generates a sorted index list along the last dimension,
- performing batch-wise along all other dimensions.
- }];
- let arguments = (ins
- AnyTensor:$value
- );
- let results = (outs
- I32Tensor:$dst
- );
-
- let assemblyFormat = [{
- $value attr-dict `:` `(`type($value)`)` `->` type($dst)
- }];
-}
-
-def VMLA_SortOp : VMLA_ElementTypeOp<"sort", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$dst,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- `out` $dst attr-dict `:` $element_type
- }];
-}
-
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: GEMM/GEMV
-//===----------------------------------------------------------------------===//
-
-def VMLA_BatchMatMulPseudoOp : VMLA_Op<"batch.matmul.pseudo"> {
- let summary = "Tensor-level pseudo-op of VMLA::BatchMatMulOp.";
- let description = [{
- This is a tensor-level version of VMLA::BatchMatMulOp, to facilitate
- the lowering process.
-
- All operands are rank-3 with the following dimension structure:
- - lhs = [B, FLHS, C]
- - rhs = [B, FRHS, C]
- - dst = [B, FRHS, FLHS]
- Where:
- - B = batch dimension
- - C = contracting dimension
- - FLHS and FRHS are the free dimensions of each operand
-
- To put this in terms closer to the mathematics of matrix multiplication,
- if we ignore the leading B dimension and focus on what is mathematically an
- MxKxN matmul, then this corresponds to:
- - lhs = [M, K] = [LHSROWS, K]
- - rhs = [N, K] = [RHSCOLS, K]
- - dst = [N, M] = [RHSCOLS, LHSROWS]
- Note that dst is transposed from what one would expect.
- This is due to an implementation detail of this op in the runtime.
- This op is backed by an invocation of the Ruy matrix multiplication library,
- which prefers its matrices in this layout (in matrix terminology:
- lhs = row-major, rhs = column-major, dst = column-major).
- We insert the relevant transposes as needed in the compiler.
- }];
- let arguments = (ins
- AnyTensor:$lhs,
- AnyTensor:$rhs
- );
- let results = (outs
- AnyTensor:$dst
- );
-
- let assemblyFormat = [{
- $lhs`,` $rhs attr-dict `:`
- `(`type($lhs)`,` type($rhs)`)` `->` type($dst)
- }];
-}
-
-def VMLA_BatchMatMulOp : VMLA_Op<"batch.matmul", [VMLA_OpInterface, VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$lhs,
- VMLA_Shape:$lhs_shape,
- VMLA_Buffer:$rhs,
- VMLA_Shape:$rhs_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_FloatTypeAttr:$lhs_type,
- VMLA_FloatTypeAttr:$rhs_type,
- VMLA_FloatTypeAttr:$dst_type
- );
-
- let extraClassDeclaration = [{
- static void extractTypeAttributes(OperationState &state, ArrayRef<Type> operandTypes, ArrayRef<Type> resultTypes) {
- state.addAttribute("lhs_type", TypeAttr::get(operandTypes[0]));
- state.addAttribute("rhs_type", TypeAttr::get(operandTypes[1]));
- state.addAttribute("dst_type", TypeAttr::get(resultTypes[0]));
- }
- }];
-
- let assemblyFormat = [{
- $lhs`(`$lhs_shape `:` type($lhs_shape)`)` `:` $lhs_type`,`
- $rhs`(`$rhs_shape `:` type($rhs_shape)`)` `:` $rhs_type`,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` `:` $dst_type attr-dict
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: reduction
-//===----------------------------------------------------------------------===//
-
-class VMLA_ReduceOp<string mnemonic, list<OpTrait> traits = []> :
- VMLA_ElementTypeOp<mnemonic, !listconcat(traits, [VMLA_IncludeShapes])> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$init,
- VMLA_Shape:$init_shape,
- I32Attr:$dimension,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- $init`(`$init_shape `:` type($init_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_ReduceSumOp : VMLA_ReduceOp<"reduce.sum">;
-def VMLA_ReduceMinOp : VMLA_ReduceOp<"reduce.min">;
-def VMLA_ReduceMaxOp : VMLA_ReduceOp<"reduce.max">;
-def VMLA_ReduceAndOp : VMLA_ReduceOp<"reduce.and">;
-def VMLA_ReduceOrOp : VMLA_ReduceOp<"reduce.or">;
-
-class VMLA_PoolingOp<string mnemonic, list<OpTrait> traits = []> :
- VMLA_ElementTypeOp<mnemonic, !listconcat(traits, [VMLA_IncludeShapes])> {
- let arguments = (ins
- VMLA_Buffer:$src,
- VMLA_Shape:$src_shape,
- VMLA_Buffer:$init,
- VMLA_Shape:$init_shape,
- VMLA_Buffer:$dst,
- VMLA_Shape:$dst_shape,
- VMLA_AnyTypeAttr:$element_type,
- I32ElementsAttr:$window_dimensions,
- I32ElementsAttr:$window_strides,
- I32ElementsAttr:$padding
- );
-
- let assemblyFormat = [{
- $src`(`$src_shape `:` type($src_shape)`)``,`
- $init`(`$init_shape `:` type($init_shape)`)``,`
- `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
- }];
-}
-
-def VMLA_PoolingSumOp : VMLA_PoolingOp<"pooling.sum">;
-def VMLA_PoolingMinOp : VMLA_PoolingOp<"pooling.min">;
-def VMLA_PoolingMaxOp : VMLA_PoolingOp<"pooling.max">;
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: ABI
-//===----------------------------------------------------------------------===//
-
-def VMLA_InterfaceConstOp :
- VMLA_PureOp<"interface.const", [VMLA_OpInterface]> {
- let arguments = (ins
- VMLA_Interface:$interface,
- IREE_IndexAttr:$offset
- );
- let results = (outs
- AnyTypeOf<[I32, VMLA_Index]>:$result
- );
-
- let assemblyFormat = [{
- $interface attr-dict `:` type($result)
- }];
-}
-
-def VMLA_InterfaceBindingOp :
- VMLA_PureOp<"interface.binding", [VMLA_OpInterface]> {
- let arguments = (ins
- VMLA_Interface:$interface,
- I32Attr:$set,
- I32Attr:$binding
- );
- let results = (outs
- VMLA_Buffer:$result
- );
-
- let assemblyFormat = [{
- $interface attr-dict `:` type($result)
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: FFT
-//===----------------------------------------------------------------------===//
-
-def VMLA_FftPseudoOp : VMLA_Op<"fft.pseudo"> {
- let summary = "pseudo-op of VMLA::FftOp.";
- let description = [{
- This is a tensor-level version of VMLA::FftOp, to facilitate
- the lowering process.
-
- The op takes two tensors, representing the real and imaginary components of
- a complex number and performs a Fast Fourier transform
- (https://en.wikipedia.org/wiki/Fast_Fourier_transform) on them. It returns
- two tensors as output, since the output of an FFT is also a complex number.
-
- }];
- let arguments = (ins
- AnyTensor:$real_in,
- AnyTensor:$imag_in
- );
- let results = (outs
- AnyTensor:$real_out,
- AnyTensor:$imag_out
- );
-
- let assemblyFormat = [{
- $real_in`,` $imag_in attr-dict `:` `(`type($real_in)`,` type($imag_in)`)`
- `->` `(`type($real_out)`,` type($imag_out)`)`
- }];
-
-}
-
-def VMLA_IfftPseudoOp : VMLA_Op<"ifft.pseudo"> {
- let summary = "pseudo-op of VMLA::IfftOp.";
- let description = [{
- This is a tensor-level version of VMLA::IfftOp, to facilitate
- the lowering process.
-
- The op takes two tensors, representing the real and imaginary components of
- a complex number and performs an Inverse Fast Fourier transform
- (https://en.wikipedia.org/wiki/Fast_Fourier_transform) on them. It returns
- two tensors as output, since the output of an IFFT is also a complex number.
- }];
- let arguments = (ins
- AnyTensor:$real_in,
- AnyTensor:$imag_in
- );
- let results = (outs
- AnyTensor:$real_out,
- AnyTensor:$imag_out
- );
-
- let assemblyFormat = [{
- $real_in`,` $imag_in attr-dict `:` `(`type($real_in)`,` type($imag_in)`)`
- `->` `(`type($real_out)`,` type($imag_out)`)`
- }];
-}
-
-def VMLA_RfftPseudoOp : VMLA_Op<"rfft.pseudo"> {
- let summary = "pseudo-op of VMLA::RfftOp.";
- let description = [{
- This is a tensor-level version of VMLA::RfftOp, to facilitate
- the lowering process.
-
- The op takes a tensor and performs a Real Fast Fourier
- (https://en.wikipedia.org/wiki/Fast_Fourier_transform) on them. It returns
- two tensors as output, since the output of an RFFT is also a complex number.
- }];
- let arguments = (ins
- AnyTensor:$real_in
- );
- let results = (outs
- AnyTensor:$real_out,
- AnyTensor:$imag_out
- );
-
- let assemblyFormat = [{
- $real_in attr-dict `:` `(`type($real_in)`)`
- `->` `(`type($real_out)`,` type($imag_out)`)`
- }];
-}
-
-def VMLA_IrfftPseudoOp : VMLA_Op<"irfft.pseudo"> {
- let summary = "pseudo-op of VMLA::IrfftOp.";
- let description = [{
- This is a tensor-level version of VMLA::IrfftOp, to facilitate
- the lowering process.
-
- The op takes two tensors, representing the real and imaginary components of
- a complex number and performs an Inverse Real Fast Fourier transform
- (https://en.wikipedia.org/wiki/Fast_Fourier_transform) on them. It returns a
- single tensor as output, since the output of an IRFFT is a set of real
- numbers.
- }];
- let arguments = (ins
- AnyTensor:$real_in,
- AnyTensor:$imag_in
- );
- let results = (outs
- AnyTensor:$real_out
- );
-
- let assemblyFormat = [{
- $real_in`,` $imag_in attr-dict `:` `(`type($real_in)`,` type($imag_in)`)`
- `->` `(`type($real_out)`)`
- }];
-}
-
-def VMLA_FftOp : VMLA_ElementTypeOp<"fft", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$real_in,
- VMLA_Shape:$real_in_shape,
- VMLA_Buffer:$imag_in,
- VMLA_Shape:$imag_in_shape,
- VMLA_Buffer:$real_out,
- VMLA_Buffer:$imag_out,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
- $imag_in`(`$imag_in_shape `:` type($imag_in_shape)`)` `,`
- `out` $real_out `,` $imag_out attr-dict `:` $element_type
- }];
-}
-
-def VMLA_IfftOp : VMLA_ElementTypeOp<"ifft", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$real_in,
- VMLA_Shape:$real_in_shape,
- VMLA_Buffer:$imag_in,
- VMLA_Shape:$imag_in_shape,
- VMLA_Buffer:$real_out,
- VMLA_Buffer:$imag_out,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
- $imag_in`(`$imag_in_shape `:` type($imag_in_shape)`)` `,`
- `out` $real_out `,` $imag_out attr-dict `:` $element_type
- }];
-}
-
-def VMLA_RfftOp : VMLA_ElementTypeOp<"rfft", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$real_in,
- VMLA_Shape:$real_in_shape,
- VMLA_Buffer:$real_out,
- VMLA_Buffer:$imag_out,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
- `out` $real_out `,` $imag_out attr-dict `:` $element_type
- }];
-}
-
-def VMLA_IrfftOp : VMLA_ElementTypeOp<"irfft", [VMLA_IncludeShapes]> {
- let arguments = (ins
- VMLA_Buffer:$real_in,
- VMLA_Shape:$real_in_shape,
- VMLA_Buffer:$imag_in,
- VMLA_Shape:$imag_in_shape,
- VMLA_Buffer:$real_out,
- VMLA_AnyTypeAttr:$element_type
- );
-
- let assemblyFormat = [{
- $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
- $imag_in`(`$imag_in_shape `:` type($imag_in_shape)`)` `,`
- `out` $real_out attr-dict `:` $element_type
- }];
-}
-
-
-#endif // IREE_DIALECT_VMLA_OPS
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLATraits.h b/iree/compiler/Dialect/VMLA/IR/VMLATraits.h
deleted file mode 100644
index 410779b..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLATraits.h
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_IR_VMLATRAITS_H_
-#define IREE_COMPILER_DIALECT_VMLA_IR_VMLATRAITS_H_
-
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace OpTrait {
-namespace IREE {
-namespace VMLA {
-
-template <typename ConcreteType>
-class IncludeShapes : public OpTrait::TraitBase<ConcreteType, IncludeShapes> {
- public:
- static LogicalResult verifyTrait(Operation *op) { return success(); }
-};
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace OpTrait
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_IR_VMLATRAITS_H_
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLATypes.cpp b/iree/compiler/Dialect/VMLA/IR/VMLATypes.cpp
deleted file mode 100644
index 786c5f3..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLATypes.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-
-#include "llvm/ADT/StringExtras.h"
-
-// Order matters:
-#include "iree/compiler/Dialect/VMLA/IR/VMLAEnums.cpp.inc"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOpInterface.cpp.inc"
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLATypes.h b/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
deleted file mode 100644
index 0f410c3..0000000
--- a/iree/compiler/Dialect/VMLA/IR/VMLATypes.h
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_IR_VMLATYPES_H_
-#define IREE_COMPILER_DIALECT_VMLA_IR_VMLATYPES_H_
-
-#include <cstdint>
-
-#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSwitch.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/TypeSupport.h"
-#include "mlir/IR/Types.h"
-#include "mlir/Support/LLVM.h"
-
-// Order matters.
-#include "iree/compiler/Dialect/VMLA/IR/VMLAEnums.h.inc"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOpInterface.h.inc"
-
-//===----------------------------------------------------------------------===//
-// RefObject types
-//===----------------------------------------------------------------------===//
-
-class BufferType : public Type::TypeBase<BufferType, Type, TypeStorage> {
- public:
- using Base::Base;
-};
-
-class InterfaceType : public Type::TypeBase<InterfaceType, Type, TypeStorage> {
- public:
- using Base::Base;
-};
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_IR_VMLATYPES_H_
diff --git a/iree/compiler/Dialect/VMLA/IR/test/BUILD b/iree/compiler/Dialect/VMLA/IR/test/BUILD
deleted file mode 100644
index 228bc60..0000000
--- a/iree/compiler/Dialect/VMLA/IR/test/BUILD
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "buffer_ops.mlir",
- "conv_reduction_ops.mlir",
- "general_ops.mlir",
- "shape_structure_ops.mlir",
- ],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt
deleted file mode 100644
index a392ee0..0000000
--- a/iree/compiler/Dialect/VMLA/IR/test/CMakeLists.txt
+++ /dev/null
@@ -1,26 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/IR/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "buffer_ops.mlir"
- "conv_reduction_ops.mlir"
- "general_ops.mlir"
- "shape_structure_ops.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/IR/test/buffer_ops.mlir b/iree/compiler/Dialect/VMLA/IR/test/buffer_ops.mlir
deleted file mode 100644
index b83011e..0000000
--- a/iree/compiler/Dialect/VMLA/IR/test/buffer_ops.mlir
+++ /dev/null
@@ -1,99 +0,0 @@
-// Tests the printing/parsing of the VMLA dialect buffer ops.
-
-// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-
-// CHECK-LABEL: vmla_buffer_const
-// CHECK-SAME: %[[VALUE:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_const(%value : !iree.byte_buffer) {
- // CHECK: vmla.buffer.const %[[VALUE]] : !iree.byte_buffer -> !vmla.buffer
- %result = vmla.buffer.const %value : !iree.byte_buffer -> !vmla.buffer
- return
-}
-
-// -----
-
-// CHECK-LABEL: vmla_buffer_alloc
-// CHECK-SAME: %[[LENGTH:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_alloc(%byte_length : index) {
- // CHECK: vmla.buffer.alloc byte_length = %[[LENGTH]] : !vmla.buffer
- %result = vmla.buffer.alloc byte_length = %byte_length : !vmla.buffer
- return
-}
-
-// -----
-
-// CHECK-LABEL: vmla_buffer_clone
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_clone(%src : !vmla.buffer) {
- // CHECK: vmla.buffer.clone %[[SRC]] : !vmla.buffer
- %result = vmla.buffer.clone %src : !vmla.buffer
- return
-}
-
-// -----
-
-// CHECK-LABEL: vmla_buffer_byte_length
-// CHECK-SAME: %[[VALUE:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_byte_length(%value : !vmla.buffer) {
- // CHECK: vmla.buffer.byte_length %[[VALUE]] : index
- %result = vmla.buffer.byte_length %value : index
- return
-}
-
-// -----
-
-// CHECK-LABEL: vmla_buffer_view
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[OFFSET:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[LENGTH:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_view(%src : !vmla.buffer,
- %byte_offset : index,
- %byte_length : index) {
- // CHECK: vmla.buffer.view %[[SRC]][%[[OFFSET]]],
- // CHECK-SAME: byte_length = %[[LENGTH]] : !vmla.buffer
- %result = vmla.buffer.view %src[%byte_offset],
- byte_length = %byte_length : !vmla.buffer
- return
-}
-
-// -----
-
-// CHECK-LABEL: vmla_buffer_copy
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_OFFSET:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_OFFSET:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[LENGTH:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_copy(%src : !vmla.buffer,
- %src_byte_offset : index,
- %dst : !vmla.buffer,
- %dst_byte_offset : index,
- %byte_length : index) {
- // CHECK: vmla.buffer.copy %[[SRC]][%[[SRC_OFFSET]]],
- // CHECK-SAME: out %[[DST]][%[[DST_OFFSET]]], byte_length = %[[LENGTH]]
- vmla.buffer.copy %src[%src_byte_offset],
- out %dst[%dst_byte_offset], byte_length = %byte_length
- return
-}
-
-// -----
-
-// CHECK-LABEL: vmla_buffer_fill
-// CHECK-SAME: %[[VALUE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_fill(%src : !vmla.buffer, %dst : !vmla.buffer) {
- // CHECK: vmla.buffer.fill %[[VALUE]], out %[[DST]]
- vmla.buffer.fill %src, out %dst
- return
-}
-
-// -----
-
-// CHECK-LABEL: vmla_buffer_load_i32
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[OFFSET:[a-zA-Z0-9$._-]+]]
-func @vmla_buffer_load_i32(%src : !vmla.buffer, %byte_offset : index) {
- // CHECK: vmla.buffer.load.i32 %[[SRC]][%[[OFFSET]]] : i32
- %result = vmla.buffer.load.i32 %src[%byte_offset] : i32
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/IR/test/conv_reduction_ops.mlir b/iree/compiler/Dialect/VMLA/IR/test/conv_reduction_ops.mlir
deleted file mode 100644
index 2e2b9fc..0000000
--- a/iree/compiler/Dialect/VMLA/IR/test/conv_reduction_ops.mlir
+++ /dev/null
@@ -1,97 +0,0 @@
-// Tests the printing/parsing of the VMLA dialect ops.
-
-// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-
-// CHECK-LABEL: @vmla_conv
-// CHECK-SAME: %[[INPUT:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[INPUT_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[FILTER:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[FILTER_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_conv(%input : !vmla.buffer,
- %input_shape : !shapex.ranked_shape<[1,4,5,2]>,
- %filter : !vmla.buffer,
- %filter_shape : !shapex.ranked_shape<[3,2,2,1]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[1,2,3,1]>) {
- // CHECK: vmla.conv
- // CHECK-SAME: %[[INPUT]](%[[INPUT_SHAPE]] :
- // CHECK-SAME: !shapex.ranked_shape<[1,4,5,2]>) : f16,
- // CHECK-SAME: %[[FILTER]](%[[FILTER_SHAPE]] :
- // CHECK-SAME: !shapex.ranked_shape<[3,2,2,1]>) : f16,
- // CHECK-SAME: out %[[DST]](%[[DST_SHAPE]] :
- // CHECK-SAME: !shapex.ranked_shape<[1,2,3,1]>) : f16
- // CHECK-SAME: {batch_group_count = 1 : i32,
- // CHECK-SAME: feature_group_count = 1 : i32,
- // CHECK-SAME: lhs_dilation = dense<1> : vector<2xi32>,
- // CHECK-SAME: padding = dense<[1, 2, 2, 2]> : vector<4xi32>,
- // CHECK-SAME: rhs_dilation = dense<1> : vector<2xi32>,
- // CHECK-SAME: window_strides = dense<1> : vector<2xi32>}
- vmla.conv %input(%input_shape : !shapex.ranked_shape<[1,4,5,2]>) : f16,
- %filter(%filter_shape : !shapex.ranked_shape<[3,2,2,1]>) : f16,
- out %dst(%dst_shape : !shapex.ranked_shape<[1,2,3,1]>) : f16
- {batch_group_count = 1 : i32,
- feature_group_count = 1 : i32,
- lhs_dilation = dense<1> : vector<2xi32>,
- padding = dense<[1, 2, 2, 2]> : vector<4xi32>,
- rhs_dilation = dense<1> : vector<2xi32>,
- window_strides = dense<1> : vector<2xi32>}
- return
-}
-
-// CHECK-LABEL: @vmla_reduce
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[INIT:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[INIT_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_reduce(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[4,8]>,
- %init : !vmla.buffer,
- %init_shape : !shapex.ranked_shape<[]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[4]>) {
- // CHECK-NEXT: vmla.reduce.sum
- // CEHCK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: %[[INIT]](%[[INIT_SHAPE]] : !shapex.ranked_shape<[]>),
- // CHECK-SAME: out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[4]>)
- // CHECK-SaME: {dimension = 1 : i32} : f16
- vmla.reduce.sum %src(%src_shape : !shapex.ranked_shape<[4,8]>),
- %init(%init_shape : !shapex.ranked_shape<[]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[4]>)
- {dimension = 1 : i32} : f16
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_pooling
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[INIT:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[INIT_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_pooling(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[32,32]>,
- %init : !vmla.buffer,
- %init_shape : !shapex.ranked_shape<[]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[16,16]>) {
- // CHECK-NEXT: vmla.pooling.min
- // CEHCK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[32,32]>),
- // CHECK-SAME: %[[INIT]](%[[INIT_SHAPE]] : !shapex.ranked_shape<[]>),
- // CHECK-SAME: out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[16,16]>)
- // CHECK-SAME: {padding = dense<0> : tensor<i32>,
- // CHECK-SaME: window_dimensions = dense<[2,2]> : tensor<2xi32>,
- // CHECK-SaME: window_strides = dense<[2,2]> : tensor<2xi32>} : f16
- vmla.pooling.min %src(%src_shape : !shapex.ranked_shape<[32,32]>),
- %init(%init_shape : !shapex.ranked_shape<[]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[16,16]>)
- {padding = dense<0> : tensor<i32>,
- window_dimensions = dense<[2,2]> : tensor<2xi32>,
- window_strides = dense<[2,2]> : tensor<2xi32>} : f16
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/IR/test/general_ops.mlir b/iree/compiler/Dialect/VMLA/IR/test/general_ops.mlir
deleted file mode 100644
index c249c6d..0000000
--- a/iree/compiler/Dialect/VMLA/IR/test/general_ops.mlir
+++ /dev/null
@@ -1,142 +0,0 @@
-// Tests the printing/parsing of the VMLA dialect ops.
-
-// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-
-// CHECK-LABEL: @unaryOp
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-func @unaryOp(%src : !vmla.buffer, %dst : !vmla.buffer) {
- // CHECK: vmla.log %[[SRC]], out %[[DST]] : f32
- vmla.log %src, out %dst : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @binaryOp
-// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[RHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-func @binaryOp(%lhs : !vmla.buffer, %rhs : !vmla.buffer, %dst : !vmla.buffer) {
- // CHECK: vmla.atan2 %[[LHS]], %[[RHS]], out %[[DST]] : f32
- vmla.atan2 %lhs, %rhs, out %dst : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @ternaryOp
-// CHECK-SAME: %[[A:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[B:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[C:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-func @ternaryOp(%a : !vmla.buffer, %b : !vmla.buffer, %c : !vmla.buffer,
- %dst : !vmla.buffer) {
- // CHECK: vmla.clamp %[[A]], %[[B]], %[[C]], out %[[DST]] : f32
- vmla.clamp %a, %b, %c, out %dst : f32
- return
-}
-
-
-// -----
-
-// CHECK-LABEL: @vmla_convert
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-func @vmla_convert(%src : !vmla.buffer, %dst : !vmla.buffer) {
- // CHECK: vmla.convert %[[SRC]], out %[[DST]] : f32 -> i8
- vmla.convert %src, out %dst : f32 -> i8
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_batch_matmul_pseudo
-// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[RHS:[a-zA-Z0-9$._-]+]]
-func @vmla_batch_matmul_pseudo(%lhs : tensor<32x256x128xf32>,
- %rhs : tensor<32x1x128xf32>) {
- // CHECK: vmla.batch.matmul.pseudo %[[LHS]], %[[RHS]] :
- // CHECK-SAME: (tensor<32x256x128xf32>, tensor<32x1x128xf32>) ->
- // CHECK-SAME: tensor<32x1x256xf32>
- %dst = vmla.batch.matmul.pseudo %lhs, %rhs :
- (tensor<32x256x128xf32>, tensor<32x1x128xf32>) -> tensor<32x1x256xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_batch_matmul
-// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[LHS_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[RHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[RHS_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_batch_matmul(%lhs : !vmla.buffer,
- %lhs_shape : !shapex.ranked_shape<[8,4,4]>,
- %rhs : !vmla.buffer,
- %rhs_shape : !shapex.ranked_shape<[8,1,4]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[8,1,4]>) {
- // CHECK: vmla.batch.matmul
- // CHECK-SAME: %[[LHS]](%[[LHS_SHAPE]] : !shapex.ranked_shape<[8,4,4]>) : f32,
- // CHECK-SAME: %[[RHS]](%[[RHS_SHAPE]] : !shapex.ranked_shape<[8,1,4]>) : f32,
- // CHECK-SAME: out
- // CHECK-SAME: %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[8,1,4]>) : f32
- vmla.batch.matmul %lhs(%lhs_shape : !shapex.ranked_shape<[8,4,4]>) : f32,
- %rhs(%rhs_shape : !shapex.ranked_shape<[8,1,4]>) : f32,
- out %dst(%dst_shape : !shapex.ranked_shape<[8,1,4]>) : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_cmp
-// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[RHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-func @vmla_cmp(%lhs : !vmla.buffer, %rhs : !vmla.buffer, %dst : !vmla.buffer) {
- // CHECK: vmla.cmp NE, %[[LHS]], %[[RHS]], out %[[DST]] : f16
- vmla.cmp NE, %lhs, %rhs, out %dst : f16
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_select
-// CHECK-SAME: %[[COND:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[LHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[RHS:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-func @vmla_select(%cond : !vmla.buffer,
- %lhs : !vmla.buffer,
- %rhs : !vmla.buffer,
- %dst : !vmla.buffer) {
- // CHECK: vmla.select %[[COND]], %[[LHS]], %[[RHS]], out %[[DST]] : f16
- vmla.select %cond, %lhs, %rhs, out %dst : f16
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_interface_const
-// CHECK-SAME: %[[INTERFACE:[a-zA-Z0-9$._-]+]]
-func @vmla_interface_const(%interface : !vmla.interface) {
- // CHECK: vmla.interface.const %[[INTERFACE]]
- // CHECK-SAME: {offset = 3 : index} : i32
- vmla.interface.const %interface {offset = 3 : index} : i32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_interface_binding
-// CHECK-SAME: %[[INTERFACE:[a-zA-Z0-9$._-]+]]
-func @vmla_interface_binding(%interface : !vmla.interface) {
- // CHECK: vmla.interface.binding %[[INTERFACE]]
- // CHECK-SAME: {binding = 0 : i32, set = 0 : i32} : !vmla.buffer
- vmla.interface.binding %interface
- {binding = 0 : i32, set = 0 : i32} : !vmla.buffer
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/IR/test/shape_structure_ops.mlir b/iree/compiler/Dialect/VMLA/IR/test/shape_structure_ops.mlir
deleted file mode 100644
index 45daa48..0000000
--- a/iree/compiler/Dialect/VMLA/IR/test/shape_structure_ops.mlir
+++ /dev/null
@@ -1,202 +0,0 @@
-// Tests the printing/parsing of the VMLA dialect ops.
-
-// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-
-// CHECK-LABEL: @vmla_copy
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_INDEX_0:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_INDEX_1:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_INDEX_0:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_INDEX_1:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[LENGTH_0:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[LENGTH_1:[a-zA-Z0-9$._-]+]]
-func @vmla_copy(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[64]>,
- %src_index_0 : index,
- %src_index_1 : index,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[32]>,
- %dst_index_0 : index,
- %dst_index_1 : index,
- %length_0 : index,
- %length_1 : index) {
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[64]>),
- // CHECK-SAME; src_indices = [%[[SRC_INDEX_0]], %[[SRC_INDEX_1]]],
- // CHECK-SAME: out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[32]>),
- // CHECK-SAME: dst_indices = [%[[DST_INDEX_0]], %[[DST_INDEX_1]]],
- // CHECK-SAME: lengths = [%[[LENGTH_0]], %[[LENGTH_1]]] : i32
- vmla.copy %src(%src_shape : !shapex.ranked_shape<[64]>),
- src_indices = [%src_index_0, %src_index_1],
- out %dst(%dst_shape : !shapex.ranked_shape<[32]>),
- dst_indices = [%dst_index_0, %dst_index_1],
- lengths = [%length_0, %length_1] : i32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_copy_no_variadic
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_copy_no_variadic(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[64]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[32]>) {
- // CHECK: vmla.copy
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[64]>),
- // CHECK-SAME: out %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[32]>)
- // CHECK-SAME: : i32
- vmla.copy %src(%src_shape : !shapex.ranked_shape<[64]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[32]>) : i32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_transpose
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_transpose(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[64,32,32,10]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[64,10,32,32]>) {
- // CHECK: vmla.transpose
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[64,32,32,10]>),
- // CHECK-SAME: out
- // CHECK-SAME: %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[64,10,32,32]>)
- // CHECK-SAME: {permutation = dense<[0, 3, 2, 1]> : tensor<4xi32>} : f32
- vmla.transpose %src(%src_shape : !shapex.ranked_shape<[64,32,32,10]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[64,10,32,32]>)
- {permutation = dense<[0, 3, 2, 1]> : tensor<4xi32>} : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_reverse
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_reverse(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[4,8]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[4,8]>) {
- // CHECK: vmla.reverse
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: out
- // CHECK-SAME: %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[4,8]>)
- // CHECK-SAME: {dimensions = dense<1> : tensor<1xi32>} : f32
- vmla.reverse %src(%src_shape : !shapex.ranked_shape<[4,8]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[4,8]>)
- {dimensions = dense<1> : tensor<1xi32>} : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_pad
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[VALUE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[VALUE_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_pad(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[4,8]>,
- %value : !vmla.buffer,
- %value_shape : !shapex.ranked_shape<[4,8]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[4,8]>) {
- // CHECK: vmla.pad
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: %[[VALUE]](%[[VALUE_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: out
- // CHECK-SAME: %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[4,8]>)
- // CHECK-SAME: {edge_padding_high = dense<2> : tensor<i32>,
- // CHECK-SAME: edge_padding_low = dense<2> : tensor<i32>,
- // CHECK-SAME: interior_padding = dense<0> : tensor<i32>} : f32
- vmla.pad %src(%src_shape : !shapex.ranked_shape<[4,8]>),
- %value(%value_shape : !shapex.ranked_shape<[4,8]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[4,8]>)
- {edge_padding_high = dense<2> : tensor<i32>,
- edge_padding_low = dense<2> : tensor<i32>,
- interior_padding = dense<0> : tensor<i32>} : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_broadcast
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_broadcast(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[4,8]>) {
- // CHECK: vmla.broadcast
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[]>),
- // CHECK-SAME: out
- // CHECK-SAME: %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[4,8]>) : f32
- vmla.broadcast %src(%src_shape : !shapex.ranked_shape<[]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[4,8]>) : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_tile
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_tile(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[4]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[4,8]>) {
- // CHECK: vmla.tile
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[4]>),
- // CHECK-SAME: out
- // CHECK-SAME: %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[4,8]>) : f32
- vmla.tile %src(%src_shape : !shapex.ranked_shape<[4]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[4,8]>) : f32
- return
-}
-
-// -----
-
-// CHECK-LABEL: @vmla_gather
-// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[SRC_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[INDICES:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[INDICES_SHAPE:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST:[a-zA-Z0-9$._-]+]]
-// CHECK-SAME: %[[DST_SHAPE:[a-zA-Z0-9$._-]+]]
-func @vmla_gather(%src : !vmla.buffer,
- %src_shape : !shapex.ranked_shape<[4,8]>,
- %indices : !vmla.buffer,
- %indices_shape : !shapex.ranked_shape<[4,8]>,
- %dst : !vmla.buffer,
- %dst_shape : !shapex.ranked_shape<[4,8]>) {
- // CHECK: vmla.gather
- // CHECK-SAME: %[[SRC]](%[[SRC_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: %[[INDICES]](%[[INDICES_SHAPE]] : !shapex.ranked_shape<[4,8]>),
- // CHECK-SAME: out
- // CHECK-SAME: %[[DST]](%[[DST_SHAPE]] : !shapex.ranked_shape<[4,8]>)
- // CHECK-SAME: {batch_dims = 2 : i64, dim = 1 : i64} : f32
- vmla.gather %src(%src_shape : !shapex.ranked_shape<[4,8]>),
- %indices(%indices_shape : !shapex.ranked_shape<[4,8]>),
- out %dst(%dst_shape : !shapex.ranked_shape<[4,8]>)
- {batch_dims = 2 : i64, dim = 1 : i64} : f32
- return
-}
diff --git a/iree/compiler/Dialect/VMLA/README.md b/iree/compiler/Dialect/VMLA/README.md
deleted file mode 100644
index cf59c54..0000000
--- a/iree/compiler/Dialect/VMLA/README.md
+++ /dev/null
@@ -1,167 +0,0 @@
-# VMLA (Virtual Machine-based Linear Algebra)
-
-This dialect is designed to closely model XLA HLO ops in a way that is easy to
-map to execution on the IREE VM. The changes involve using byte buffers instead
-of tensors, propagating shape information and converting shape math to simple
-integer arithmetic, and legalizing types to supported values (such as 1bit bools
-to 8bit integers of 0 or 1).
-
-## Adding an Op
-
-As with other VM modules, VMLA ops are declared in
-[vmla.imports.mlir](/iree/compiler/Dialect/VMLA/vmla.imports.mlir). These
-declarations are what enable the compiler and runtime side to talk to each
-other. It's helpful to start here to think about the information you need to
-communicate to the runtime prior to writing conversions. As a general rule, try
-to avoid communicating anything not strictly required for a correct
-implementation; instead, perform more work on the compiler side if it allows
-simpler ops to be implemented at runtime. For example, if there's an attribute
-on the op that selects between two different implementations, instead of
-plumbing that attribute through to runtime and switching there, one should
-implement the conversion to lower it into two ops. This makes it easier to
-reduce binary sizes, get accurate profiles at runtime, etc.
-
-TLDR:
-
-1. Add a `vm.import` to
- [vmla.imports.mlir](/iree/compiler/Dialect/VMLA/vmla.imports.mlir).
-2. Add an MLIR op def to
- [VMLAOps.td](/iree/compiler/Dialect/VMLA/IR/VMLAOps.td).
-3. Add a conversion from the source dialect like
- [HLOToVMLA](/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/).
-4. Add a conversion to the `vm.import` in
- [VMLAToVM](/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/).
-5. Add the runtime C++ kernel thunk to
- [vmla_module.cc](/iree/hal/vmla/vmla_module.cc).
-6. Declare the kernel in [op_kernels.h](/iree/modules/vmla/op_kernels.h) and add a
- reference implementation in
- [op_kernels_generic.h](/iree/modules/vmla/op_kernels_generic.h).
-
-### Declaring the Op
-
-See the file comments in
-[vmla.imports.mlir](/iree/compiler/Dialect/VMLA/vmla.imports.mlir) for style and
-naming conventions. Note that **the suffix naming convention is load-bearing and
-must be followed**.
-
-Add a new `vm.import` corresponding to the op you want to add. Try to group it
-with existing related ops in the file.
-
-If the op does not need to know type information then always prefer to use `xN`
-as the op suffix (like `copy.x8`, which copies 8-bit elements, as copies do not
-care that the bits they are copying are ints or floats).
-
-Almost all ops use output argument buffers (such as `%dst`). Only ops that
-return references to in-place contents of existing input buffers should return
-values (such as `buffer.view`).
-
-If shape information is required (and it's a good idea to make sure it
-absolutely is) then add shapes following the buffer they are related to, for
-example: `vm.import @transpose.x8(%src : !vm.ref<!vmla.buffer>, %src_shape : i32
-...)`
-
-### Adding the Op Tablegen Description
-
-Once the op is declared you can add the tablegen op def in
-[VMLAOps.td](/iree/compiler/Dialect/VMLA/IR/VMLAOps.td). Match the order and
-grouping in this file with the `vmla.imports.mlir` file to make moving between
-the two easier.
-
-The automated conversion helper uses names and order to match the op defs with
-the `vm.import` declarations. Make sure the names of all argument values and
-attributes match those in the declaration.
-
-Many ops can be expressed with `VMLA_UnaryOp`/`VMLA_BinaryOp`/etc classes such
-as `VMLA_AddOp`. These will automatically get their `lhs`/`rhs`/`dst` and fan
-out to the given type group. For example, use `VMLA_AnyTypeAttr` will allow both
-integers and floats of various bit depths, while `VMLA_FloatTypeAttr` will only
-allow floating-point values. These should match to which suffixes you defined in
-the import; for example if you only have `foo.f32` declared to indicate that the
-op only operates on floating-point values then use `VMLA_FloatTypeAttr`).
-
-For ops that don't fit the unary/binary/etc form you can use the
-`VMLA_ElementTypeOp` class to get at least get the automated type suffix
-conversion. These expect an argument of `VMLA_*TypeAttr:$element_type` to store
-the appropriate type from the result value and it will be populated
-automatically.
-
-For ops that require shapes you must add the `VMLA_IncludeShapes` trait to tell
-the automated conversion helper to insert shape information. Again, really try
-to avoid passing shape information if possible (use element counts/etc that you
-can derive from buffer sizes at runtime, if needed).
-
-Finally, some ops like `VMLA_MatMulOp` may use multiple types and may need to
-provide their own type extraction and suffix creation logic. For these add the
-`VMLA_OpInterface` trait and define an `extractTypeAttributes` function.
-
-### Converting to the VMLA Op
-
-There are two conversion required: one from the source dialect to your new VMLA
-op and one from the VMLA op to the VM import call.
-
-See [HLOToVMLA](/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/) for examples
-of the former. Most ops can use the `VMLAOpConversion` helper to automatically
-convert between ops so long as they match in values and attributes (for example,
-`mhlo.add` can be trivially converted to `vmla.add`). Examples of more complex
-ops that may require additional IR to be emitted or attributes to be mapped can
-be seen in there as well.
-
-You can add tests for your conversion as needed under `test/` in the appropriate
-dialect-specific conversion folder.
-
-### Converting to the VM Import
-
-If your new op is defined well then the conversion from VMLA to VM should be
-straightforward. Many ops can use `VMLA_*_IMPORT_OP` macros to perform the
-conversion automatically. See
-[VMLAToVM](/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/) for examples.
-
-* Ops that have no suffix (like `vmla.buffer.fill`) use `VMLA_IMPORT_OP`.
-* Ops that only require the bit-depth (like `vmla.copy.x8`) use
- `VMLA_SIZED_IMPORT_OP`.
-* Ops that require type information (like `vmla.cmp.f32`) use
- `VMLA_TYPED_IMPORT_OP`.
-
-Custom conversions can be performed as well but try to avoid that.
-
-You can add tests for your conversion under the `test/` path and are encouraged
-to do so particularly if not using the `VMLA_*_IMPORT_OP` macros.
-
-### Add the Runtime Kernel
-
-[vmla_module.cc](/iree/hal/vmla/vmla_module.cc) contains the runtime companion
-of the `vmla.imports.mlir` file mapping from the VM calls to C++. Again add your
-function in here in the same place as you did in the other files. Follow the
-example of other functions in the file for how to declare arguments, how to add
-the `IREE_TRACE_SCOPE` line, etc.
-
-There are some helpers such as `IREE_VMLA_BINARY_OP` that match the equivalents
-in the tablegen file such that if your op can usually be just a single line.
-
-The thunks in this file just call one of the kernels defined in the
-[op_kernels.h](/iree/modules/vmla/op_kernels.h) file. These kernels are designed to
-be standalone from the VM code and take effectively just pointers and lists of
-values. The job of the `vmla_module.cc` thunk is to unwrap the VM arguments and
-pass them to these functions.
-
-Declare your new kernel in the header without its implementation. If your kernel
-needs to keep state at runtime you can follow what `MatMul` does with the
-`RuntimeState` struct, however it is strongly discouraged and almost never
-required, so avoid if possible. One way to avoid it is to make your op take any
-scratch memory it may require as an argument and generate the IR during
-conversion. This ensures that we can optimize things on the compiler-side
-instead of forcing the runtime to deal with things.
-
-Finally, implement the kernel in
-[op_kernels_generic.h](/iree/modules/vmla/op_kernels_generic.h). Try to keep it
-simple and readable. These are reference kernels and don't need to be fast,
-however all of our tests use them and as such they shouldn't be so slow as to
-prevent tests from running in a reasonable time. Use your judgement or be
-willing to have someone file a bug telling you to make them faster if they are
-terribly slow :)
-
-Tests for the kernels can be added to
-[op_kernels_test.cc](/iree/modules/vmla/op_kernels_test.cc). The thunks in
-`vmla_module.cc` are best tested via end-to-end tests using `iree-run-mlir` as
-what you really want to ensure is that the compiler is emitting calls that match
-the runtime side and the only way to do this is to actually compile and run.
diff --git a/iree/compiler/Dialect/VMLA/Transforms/BUILD b/iree/compiler/Dialect/VMLA/Transforms/BUILD
deleted file mode 100644
index 7435c5e..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/BUILD
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-cc_library(
- name = "Transforms",
- srcs = [
- "Conversion.cpp",
- "Passes.cpp",
- "PreConversionLowering.cpp",
- "UnrollReductions.cpp",
- ],
- hdrs = [
- "Passes.h",
- ],
- deps = [
- "//iree/compiler/Conversion/HLOToHLO",
- "//iree/compiler/Dialect/HAL/IR:HALDialect",
- "//iree/compiler/Dialect/IREE/Transforms",
- "//iree/compiler/Dialect/Shape/IR",
- "//iree/compiler/Dialect/Shape/Transforms",
- "//iree/compiler/Dialect/VMLA/Conversion",
- "//iree/compiler/Dialect/VMLA/Conversion/HALToVMLA",
- "//iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA",
- "//iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA",
- "//iree/compiler/Dialect/VMLA/IR",
- "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
- "@llvm-project//llvm:Support",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Pass",
- "@llvm-project//mlir:StandardOps",
- "@llvm-project//mlir:Support",
- "@llvm-project//mlir:Transforms",
- "@mlir-hlo//:hlo",
- "@mlir-hlo//:lhlo_fuse_linalg",
- "@mlir-hlo//:mhlo_to_mhlo_lowering_patterns",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
deleted file mode 100644
index fcd64e5..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,45 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Transforms/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_cc_library(
- NAME
- Transforms
- HDRS
- "Passes.h"
- SRCS
- "Conversion.cpp"
- "Passes.cpp"
- "PreConversionLowering.cpp"
- "UnrollReductions.cpp"
- DEPS
- LLVMSupport
- MLIRIR
- MLIRPass
- MLIRStandard
- MLIRSupport
- MLIRTransforms
- iree::compiler::Conversion::HLOToHLO
- iree::compiler::Dialect::HAL::IR::HALDialect
- iree::compiler::Dialect::IREE::Transforms
- iree::compiler::Dialect::Shape::IR
- iree::compiler::Dialect::Shape::Transforms
- iree::compiler::Dialect::VMLA::Conversion
- iree::compiler::Dialect::VMLA::Conversion::HALToVMLA
- iree::compiler::Dialect::VMLA::Conversion::HLOToVMLA
- iree::compiler::Dialect::VMLA::Conversion::StandardToVMLA
- iree::compiler::Dialect::VMLA::IR
- iree::compiler::Dialect::VMLA::IR::VMLADialect
- tensorflow::mlir_hlo
- PUBLIC
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp b/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
deleted file mode 100644
index cfdab50..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/Conversion.cpp
+++ /dev/null
@@ -1,146 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Patterns.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/StandardToVMLA/ConvertStandardToVMLA.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/TypeConverter.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-// Rewrites entry functions to have a vmla.interface and an XYZ workgroup ID.
-// The runtime will provide these values during invocation.
-static LogicalResult insertInterfacesToEntryPoints(mlir::ModuleOp moduleOp) {
- for (auto funcOp : moduleOp.getOps<FuncOp>()) {
- if (!funcOp.isPublic()) {
- continue;
- }
- auto originalType = funcOp.getType();
- if (originalType.getNumInputs() != 0 || originalType.getNumResults() != 0) {
- return funcOp.emitError() << "exported functions must have no I/O";
- }
- auto interfaceType = IREE::VMLA::InterfaceType::get(moduleOp.getContext());
- auto indexType = IndexType::get(moduleOp.getContext());
- auto newType =
- FunctionType::get(moduleOp.getContext(),
- {interfaceType, indexType, indexType, indexType}, {});
- funcOp.setType(newType);
- funcOp.front().addArguments(
- {interfaceType, indexType, indexType, indexType});
- }
- return success();
-}
-
-// Runs conversion with registered input dialects.
-class ConversionPass
- : public PassWrapper<ConversionPass, OperationPass<mlir::ModuleOp>> {
- public:
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<ShapeDialect, IREE::VMLA::VMLADialect>();
- }
-
- void runOnOperation() override {
- // First insert vmla.interface arguments to all exported functions.
- // The conversions require that the interface argument is present in order
- // to properly retrieve buffer bindings.
- if (failed(insertInterfacesToEntryPoints(getOperation()))) {
- return signalPassFailure();
- }
-
- auto *context = &getContext();
- VMLATypeConverter typeConverter;
- VMLAConversionTarget conversionTarget(context, typeConverter);
-
- // Ensure all input dialects go away.
- conversionTarget.addIllegalDialect<mhlo::MhloDialect>();
- conversionTarget.addIllegalDialect<IREE::HAL::HALDialect>();
-
- OwningRewritePatternList conversionPatterns(&getContext());
- populateStandardToVMLAPatterns(context, conversionPatterns, typeConverter);
- populateHLOToVMLAPatterns(context, conversionPatterns, typeConverter);
- populateHALToVMLAPatterns(context, conversionPatterns, typeConverter);
-
- // Ensure FuncOp signatures are updated.
- populateFuncOpTypeConversionPattern(conversionPatterns, typeConverter);
-
- // We allow the shape dialect to persist, making specific dim queries
- // illegal (which allows them to fold away). These patterns allow dimension
- // queries to convert properly, but they do not allow the introduction
- // of new shaped tensors.
- Shape::populateFoldConversionPatterns(&getContext(), conversionPatterns);
- conversionTarget.addLegalDialect<ShapeDialect>();
- // Since all inputs are converted to buffers, must trigger the TieShape
- // type conversion if the result type is illegal.
- conversionTarget.addDynamicallyLegalOp<Shape::TieShapeOp>(
- [](Shape::TieShapeOp op) {
- return op.result().getType().isa<BufferType>();
- });
- conversionTarget.addIllegalOp<Shape::RankedDimOp>();
- conversionTarget.addIllegalOp<Shape::RankedDimsOp>();
- // XLA ops use tensors of extents, so we tend to launder back to
- // !shapex.ranked_shape for most shape-related things. This is a problem
- // because we don't have a lowering for the ops going back and forth between
- // tensors of extents and !shapex.ranked_shape. So we mark this op as
- // illegal and rely on our fold of `from_extent_tensor(to_extent_tensor(x))
- // -> x` to eliminate these ops. Setting it illegal here triggers that fold.
- // This is skating on thin ice.
- // TODO(silvasean): Legalize ToExtentTensorOp and FromExtentTensorOp.
- conversionTarget.addIllegalOp<Shape::FromExtentTensorOp>();
- // IotaOp and RankedBroadcastInDimOp is an logically something that should
- // be an mhlo op (or in a dialect at a similar level of abstraction), but
- // since it isn't technically in that dialect, we need to special-case mark
- // it as illegal here.
- // TODO(silvasean): Reconcile the dialect layering here.
- conversionTarget.addIllegalOp<Shape::IotaOp>();
- conversionTarget.addIllegalOp<Shape::RankedBroadcastInDimOp>();
-
- if (failed(applyPartialConversion(getOperation(), conversionTarget,
- std::move(conversionPatterns)))) {
- getOperation().emitError() << "conversion to the VMLA dialect failed";
- return signalPassFailure();
- }
- }
-};
-
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createConversionPass() {
- return std::make_unique<ConversionPass>();
-}
-
-static PassRegistration<ConversionPass> pass(
- "iree-vmla-conversion",
- "Converts from various dialects to the VMLA dialect");
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp b/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
deleted file mode 100644
index 384d050..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.cpp
+++ /dev/null
@@ -1,127 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
-
-#include <memory>
-
-#include "iree/compiler/Conversion/HLOToHLO/Passes.h"
-#include "iree/compiler/Dialect/IREE/Transforms/Passes.h"
-#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-void buildVMLATransformPassPipeline(OpPassManager &passManager) {
- passManager.addPass(createCanonicalizerPass());
-
- // ---------------------------------------------------------------------------
- // Inline and flatten structured control flow to our CFG.
- // ---------------------------------------------------------------------------
- passManager.addNestedPass<FuncOp>(mhlo::createLegalizeControlFlowPass());
-
- // Perform inlining and cleanup after CFG manipulation.
- passManager.addPass(createInlinerPass());
- passManager.addPass(createSymbolDCEPass());
- passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
- passManager.addNestedPass<FuncOp>(createCSEPass());
-
- // ---------------------------------------------------------------------------
- // Tensor-level rewrites.
- // At this point, the computation is in tensor-level CFG form.
- // There are no specific requirements on shape-related calculations at this
- // point yet, so general tensor->tensor transformations in preparation
- // for later conversion steps should go here.
- // ---------------------------------------------------------------------------
- // Legalize input types.
- // TODO(benvanik): legalize input.
- // passManager.addPass(IREE::VMLA::createLegalizeInputTypesPass());
-
- // TODO(benvanik): preserve these hints during conversion.
- passManager.addNestedPass<FuncOp>(createDropCompilerHintsPass());
-
- // Unroll multi-dimensional reductions to one reduction per dimension.
- passManager.addNestedPass<FuncOp>(createUnrollReductionsPass());
-
- // Converts mhlo.convolution ops with 1x1 kernels into mhlo.dot ops.
- passManager.addNestedPass<FuncOp>(createConvert1x1ConvToDotPass());
-
- // Tensor-level pattern-based lowerings. Thrown into one pass for simplicity.
- passManager.addNestedPass<FuncOp>(createPreConversionLoweringPass());
-
- // Clean up the IR before going into shape-materialized IR.
- passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
-
- // ---------------------------------------------------------------------------
- // Shape calculation.
- // Pre-conditions:
- // - All transformations altering the tensor-level shapes have been done.
- // - "Root" dynamic tensors all pass through a single shapex.tie_shape
- // use which associates them to their shape.
- // - Loose, non-associated shapex.get_ranked_shape ops can exist anywhere
- // and will be resolved.
- // Post-conditions:
- // - All dynamic tensors bridge through a shapex.tie_shape op with the
- // appropriate shape.
- // - No shapex.get_ranked_shape ops exist.
- // - Shape folding and canonicalization has been done.
- // ---------------------------------------------------------------------------
- passManager.addNestedPass<FuncOp>(Shape::createTieDynamicShapesPass());
- passManager.addNestedPass<FuncOp>(
- Shape::createMaterializeShapeCalculationsPass());
- passManager.addNestedPass<FuncOp>(Shape::createHoistShapeCalculationsPass());
-
- // ---------------------------------------------------------------------------
- // VMLA conversion.
- // Performs lowering from tensor-level to VMLA-level ops/types and on to the
- // VM dialect.
- // Pre-conditions:
- // - All tensors with dynamic dimensions must have a tie_shape use which
- // associates them with the SSA values providing the missing dims.
- // - Functions must be in CFG form.
- // - Any non-trivial tensor-level transformations have already been done.
- // - No shapex.get_ranked_shape ops can exist (or be introduced).
- // Post-conditions:
- // - All ops and types have been fully lowered to the VM dialect.
- // ---------------------------------------------------------------------------
- passManager.addNestedPass<FuncOp>(createCSEPass());
- passManager.addPass(createConversionPass());
-
- // ---------------------------------------------------------------------------
- // Cleanup identity ops that clutter up the IR and canonicalize.
- // ---------------------------------------------------------------------------
- passManager.addNestedPass<FuncOp>(createCSEPass());
- passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
-
- // TODO(benvanik): run symbol DCE pass.
-}
-
-void createVMLATransformPassPipeline() {
- PassPipelineRegistration<> transformPassPipeline(
- "iree-vmla-transformation-pipeline",
- "Runs the full IREE VMLA dialect transformation pipeline",
- [](OpPassManager &passManager) {
- buildVMLATransformPassPipeline(passManager);
- });
-}
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Transforms/Passes.h b/iree/compiler/Dialect/VMLA/Transforms/Passes.h
deleted file mode 100644
index df75022..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/Passes.h
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_COMPILER_DIALECT_VMLA_TRANSFORMS_PASSES_H_
-#define IREE_COMPILER_DIALECT_VMLA_TRANSFORMS_PASSES_H_
-
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "llvm/ADT/StringMap.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/LLVM.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-//===----------------------------------------------------------------------===//
-// Helpers
-//===----------------------------------------------------------------------===//
-
-// Adds a set of passes to the given pass manager that run the required VMLA
-// transforms in the canonical order.
-//
-// Most translation code should prefer to use this instead of manually adding
-// the passes themselves to ensure that expected pass ordering is observed.
-//
-// The expected usage is:
-// <run conversion from TF/HLO/etc to flow>
-// buildVMLATransformPassPipeline & run
-// <serialize VM module>
-void buildVMLATransformPassPipeline(OpPassManager &passManager);
-
-void createVMLATransformPassPipeline();
-
-//===----------------------------------------------------------------------===//
-// Input canonicalization and legalization
-//===----------------------------------------------------------------------===//
-
-// Unrolls multi-dimensional reduction operations into reductions along each
-// dimension, from innermost to outermost.
-std::unique_ptr<OperationPass<FuncOp>> createUnrollReductionsPass();
-
-// Tensor-level pattern-based lowerings. Thrown into one pass for simplicity.
-std::unique_ptr<OperationPass<FuncOp>> createPreConversionLoweringPass();
-
-//===----------------------------------------------------------------------===//
-// Dialect conversion
-//===----------------------------------------------------------------------===//
-
-// Converts from various dialects (standard, HLO, etc) to the VMLA dialect.
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createConversionPass();
-
-//===----------------------------------------------------------------------===//
-// Register all Passes
-//===----------------------------------------------------------------------===//
-
-inline void registerVMLAPasses() {
- createVMLATransformPassPipeline();
- createUnrollReductionsPass();
- createConversionPass();
- createPreConversionLoweringPass();
-}
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_DIALECT_VMLA_TRANSFORMS_PASSES_H_
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
deleted file mode 100644
index 1afdbd2..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ /dev/null
@@ -1,529 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
-#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitVector.h"
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-namespace {
-
-// Removes no-op transpose.
-//
-// TODO(silvasean): This is a temporary workaround after upstream MLIR
-// (https://reviews.llvm.org/D95991) changed canoncalization to bail out on
-// different types. Figure out a better way to handle type specialization style
-// canonicalization in general.
-struct CanonicalizeTranspose : public OpRewritePattern<mhlo::TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(mhlo::TransposeOp op,
- PatternRewriter &rewriter) const override {
- for (auto it : llvm::enumerate(op.permutation().getValues<APInt>())) {
- if (it.index() != it.value()) return failure();
- }
- rewriter.replaceOp(op, op.operand());
- return success();
- }
-};
-
-// Convert instances of `mhlo.dot` to `mhlo.dot_general`.
-//
-// TODO(silvasean): This logically is part of a future HLO client -> HLO server
-// type of pass in the mhlo dialect proper.
-struct LowerDotOp : public OpRewritePattern<mhlo::DotOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(mhlo::DotOp op,
- PatternRewriter &rewriter) const override {
- Value lhs = op.lhs();
- Value rhs = op.rhs();
- RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
- RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
- if (!lhsType || !rhsType) {
- return failure();
- }
- if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
- return failure();
- }
- // TODO(silvasean): Move this helper to MLIR core.
- auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
- auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
- rewriter.getIntegerType(64));
- return DenseIntElementsAttr::get(type, integers);
- };
- auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
- /*lhs_batching_dimensions=*/make1DElementsAttr({}),
- /*rhs_batching_dimensions=*/make1DElementsAttr({}),
- /*lhs_contracting_dimensions=*/make1DElementsAttr({1}),
- /*rhs_contracting_dimensions=*/make1DElementsAttr({0}),
- rewriter.getContext());
- rewriter.replaceOpWithNewOp<mhlo::DotGeneralOp>(
- op, op.getType(), lhs, rhs, dimensionNumbers,
- op.precision_config().hasValue() ? op.precision_config().getValue()
- : nullptr);
- return success();
- }
-};
-
-// Inserts transposes on the operands of DotGeneralOp's such that the resulting
-// batch dimensions are all the leading dimensions and all the contracting
-// dimensions are all the trailing dimensions.
-//
-// Furthermore, all batch, contracting, and free dimensions are flattened into
-// single dimensions, with an appropriate reshape back to the original
-// dimensions.
-//
-// This results in a very simple corresponding VMLA op in the runtime.
-// [1 batch dimension, 1 free dimension, 1 contracting dimension].
-//
-// The result doesn't have a DotGeneralOp, but rather a
-// VMLA::BatchMatMulPseudoOp which represents this transformation.
-//
-// TODO(silvasean): Move this to a "prepare" pass and test separately.
-struct LowerDotGeneralOp : public OpRewritePattern<mhlo::DotGeneralOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
- PatternRewriter &rewriter) const override {
- Value lhs = op.lhs();
- Value rhs = op.rhs();
- RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
- RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
- RankedTensorType dstType =
- op.getResult().getType().dyn_cast<RankedTensorType>();
- if (!lhsType || !rhsType || !dstType) {
- return rewriter.notifyMatchFailure(op, "requires ranked types");
- }
- mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
- auto extract1DVector = [](DenseIntElementsAttr elements) {
- SmallVector<int64_t, 6> ret;
- for (const APInt &element : elements) {
- ret.push_back(element.getLimitedValue());
- }
- return ret;
- };
- auto lhsBatchingDims =
- extract1DVector(dimNumbers.lhs_batching_dimensions());
- auto rhsBatchingDims =
- extract1DVector(dimNumbers.rhs_batching_dimensions());
- auto lhsContractingDims =
- extract1DVector(dimNumbers.lhs_contracting_dimensions());
- auto rhsContractingDims =
- extract1DVector(dimNumbers.rhs_contracting_dimensions());
- // TODO(silvasean): Move this helper to MLIR core.
- auto make1DElementsAttr = [&rewriter](ArrayRef<int64_t> integers) {
- auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
- rewriter.getIntegerType(64));
- return DenseIntElementsAttr::get(type, integers);
- };
- auto totalElements = [&](ArrayRef<Value> extents) {
- Value numElements = rewriter.create<mlir::ConstantOp>(
- op.getLoc(), IntegerAttr::get(rewriter.getIndexType(), 1));
- for (Value extent : extents) {
- numElements =
- rewriter.create<mlir::MulIOp>(op.getLoc(), numElements, extent);
- }
- return numElements;
- };
- auto handleOneSide = [&](ArrayRef<int64_t> batchingDims,
- ArrayRef<int64_t> contractingDims, Value &value,
- RankedTensorType &type,
- SmallVectorImpl<int64_t> &outFreeDims,
- SmallVectorImpl<Value> &outFreeDimExtents,
- SmallVectorImpl<Value> &outBatchingDimExtents) {
- outBatchingDimExtents.clear();
- RankedTensorType untransposedType = type;
- Type elementType = type.getElementType();
- SmallVector<int64_t, 6> permutation;
- llvm::BitVector freeDims(untransposedType.getRank(), true);
- SmallVector<Value, 6> contractingDimExtents;
- Value valueShape =
- rewriter.create<Shape::GetRankedShapeOp>(op.getLoc(), value);
- auto getExtentValue = [&](int64_t dim) {
- return rewriter.create<Shape::RankedDimOp>(op.getLoc(), valueShape,
- dim);
- };
- for (auto dims : {batchingDims, contractingDims}) {
- for (int64_t dim : dims) {
- freeDims.reset(dim);
- }
- }
- for (int64_t dim : batchingDims) {
- permutation.push_back(dim);
- outBatchingDimExtents.push_back(getExtentValue(dim));
- }
- for (int64_t dim : freeDims.set_bits()) {
- permutation.push_back(dim);
- outFreeDims.push_back(dim);
- outFreeDimExtents.push_back(getExtentValue(dim));
- }
- for (int64_t dim : contractingDims) {
- permutation.push_back(dim);
- contractingDimExtents.push_back(getExtentValue(dim));
- }
- // Construct the type that the transpose will result in.
- SmallVector<int64_t, 6> transposeStaticShape;
- for (int64_t index : permutation) {
- (void)index;
- transposeStaticShape.push_back(-1);
- }
- auto transposeType =
- RankedTensorType::get(transposeStaticShape, elementType);
- auto transpose = rewriter.create<mhlo::TransposeOp>(
- op.getLoc(), transposeType, value, make1DElementsAttr(permutation));
-
- SmallVector<Value, 6> reshapeShape;
- reshapeShape.push_back(totalElements(outBatchingDimExtents));
- reshapeShape.push_back(totalElements(outFreeDimExtents));
- reshapeShape.push_back(totalElements(contractingDimExtents));
- auto reshapeType = RankedTensorType::get(
- {static_cast<int64_t>(-1), static_cast<int64_t>(-1),
- static_cast<int64_t>(-1)},
- elementType);
- auto reshapeRankedShape = rewriter.create<Shape::MakeRankedShapeOp>(
- op.getLoc(),
- Shape::RankedShapeType::get(reshapeType.getShape(),
- rewriter.getContext()),
- reshapeShape);
- auto reshapeShapeExtentTensor = rewriter.create<Shape::ToExtentTensorOp>(
- op.getLoc(), reshapeRankedShape);
- value = rewriter.create<mhlo::DynamicReshapeOp>(
- op.getLoc(), reshapeType, transpose, reshapeShapeExtentTensor);
- };
- SmallVector<Value, 6> batchingDimExtents;
- SmallVector<int64_t, 6> lhsFreeDims;
- SmallVector<Value, 6> lhsFreeDimExtents;
- handleOneSide(lhsBatchingDims, lhsContractingDims, lhs, lhsType,
- lhsFreeDims, lhsFreeDimExtents, batchingDimExtents);
- SmallVector<int64_t, 6> rhsFreeDims;
- SmallVector<Value, 6> rhsFreeDimExtents;
- handleOneSide(rhsBatchingDims, rhsContractingDims, rhs, rhsType,
- rhsFreeDims, rhsFreeDimExtents, batchingDimExtents);
-
- auto dstStaticShape = llvm::to_vector<6>(
- llvm::makeArrayRef({static_cast<int64_t>(-1), static_cast<int64_t>(-1),
- static_cast<int64_t>(-1)}));
- auto dstElementType = dstType.getElementType();
- Value dst = rewriter.create<IREE::VMLA::BatchMatMulPseudoOp>(
- op.getLoc(), RankedTensorType::get(dstStaticShape, dstElementType), lhs,
- rhs);
- RankedTensorType transposeType = RankedTensorType::get(
- {dstStaticShape[0], dstStaticShape[2], dstStaticShape[1]},
- dstElementType);
- auto transpose = rewriter.create<mhlo::TransposeOp>(
- op.getLoc(), transposeType, dst, make1DElementsAttr({0, 2, 1}));
- auto reshapeShape = batchingDimExtents;
- reshapeShape.append(lhsFreeDimExtents.begin(), lhsFreeDimExtents.end());
- reshapeShape.append(rhsFreeDimExtents.begin(), rhsFreeDimExtents.end());
- SmallVector<int64_t, 6> reshapeStaticShape;
- for (int i = 0, e = batchingDimExtents.size() + lhsFreeDimExtents.size() +
- rhsFreeDimExtents.size();
- i < e; i++) {
- reshapeStaticShape.push_back(-1);
- }
- auto reshapeRankedShape = rewriter.create<Shape::MakeRankedShapeOp>(
- op.getLoc(),
- Shape::RankedShapeType::get(reshapeStaticShape, rewriter.getContext()),
- reshapeShape);
- auto reshapeShapeExtentTensor = rewriter.create<Shape::ToExtentTensorOp>(
- op.getLoc(), reshapeRankedShape);
- rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
- op, op.getType(), transpose, reshapeShapeExtentTensor);
- return success();
- }
-};
-
-class LowerBroadcastInDimOp : public OpRewritePattern<mhlo::BroadcastInDimOp> {
- public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(mhlo::BroadcastInDimOp op,
- PatternRewriter &rewriter) const override {
- auto type = op.getType().cast<RankedTensorType>();
- auto shapeType =
- Shape::RankedShapeType::get(type.getShape(), rewriter.getContext());
- auto shape =
- rewriter.create<Shape::ConstRankedShapeOp>(op.getLoc(), shapeType);
- rewriter.replaceOpWithNewOp<Shape::RankedBroadcastInDimOp>(
- op, op.getType(), op.operand(), shape, op.broadcast_dimensions());
- return success();
- }
-};
-
-// Lower mhlo::BroadcastOp via mhlo::BroadcastInDimOp.
-class LowerBroadcastOp : public OpRewritePattern<mhlo::BroadcastOp> {
- public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(mhlo::BroadcastOp op,
- PatternRewriter &rewriter) const override {
- auto type = op.getOperand().getType().cast<RankedTensorType>();
- auto resultType = op.getType().cast<RankedTensorType>();
- auto broadcastDimensions = llvm::to_vector<6>(llvm::seq<int64_t>(
- resultType.getRank() - type.getRank(), resultType.getRank()));
- rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
- op, op.getType(), op.getOperand(),
- rewriter.getI64TensorAttr(broadcastDimensions));
- return success();
- }
-};
-
-// Lower mhlo::SortOp to an pseudo SortOp in the VMLA dialect. This
-// pseudo op generates a set of ordered indices for that array along the last
-// dimension. Then using a torch_index_select the values can be reordered to
-// support arbitrary inputs.
-//
-// TODO(suderman): This lowering only covers the case of ascending values, we
-// should support a separate descending value case by having separate
-// SortAscending and SortDescending operations.
-class LowerSortOp : public OpRewritePattern<mhlo::SortOp> {
- public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(mhlo::SortOp op,
- PatternRewriter &rewriter) const override {
- auto operandTy = op.getOperand(0).getType().cast<RankedTensorType>();
- bool lastDimension =
- (op.dimension() == -1) || (op.dimension() == (operandTy.getRank() - 1));
-
- // TODO(suderman): Add transpose to sort along the last dimension.
- if (!lastDimension) return failure();
-
- auto &comparator = op.comparator();
- auto &block = comparator.getBlocks().front();
- auto &operations = block.getOperations();
- auto comparison = dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
-
- // First verify that the block is purely a return of a comparison. This
- // handles sorting a single tensor of values.
- if (!comparison) return failure();
-
- auto returnOp =
- dyn_cast_or_null<mhlo::ReturnOp>(&(*(++operations.begin())));
- if (!returnOp) return failure();
-
- if (returnOp.getOperand(0) != comparison.getResult()) return failure();
-
- // Determine which operands being compared.
- auto lhs = comparison.getOperand(0);
- auto rhs = comparison.getOperand(1);
- auto lhsIndex = -1;
- auto rhsIndex = -1;
- for (auto arg : llvm::enumerate(block.getArguments())) {
- if (arg.value() == lhs) lhsIndex = arg.index();
- if (arg.value() == rhs) rhsIndex = arg.index();
- }
-
- // This should never happen but best to check.
- if (lhsIndex == -1) return failure();
- if (rhsIndex == -1) return failure();
-
- // They should not be the same.
- if (lhsIndex == rhsIndex) return failure();
-
- // Comparisons need to pull from same Sort operand..
- auto lhsOperand = lhsIndex / 2;
- auto rhsOperand = rhsIndex / 2;
- if (lhsOperand != rhsOperand) return failure();
-
- // Must be GT, GE, LT, or LE.
- auto isGt = comparison.comparison_direction() == "GT" ||
- comparison.comparison_direction() == "GE";
- auto isLt = comparison.comparison_direction() == "LT" ||
- comparison.comparison_direction() == "LE";
- if (!isGt && !isLt) return failure();
-
- bool operandParity = lhsIndex > rhsIndex;
- auto isAscending = operandParity ^ isGt;
- // TODO(suderman): Add support for descended sorting.
- if (!isAscending) return failure();
-
- auto operand = op.getOperand(lhsOperand);
- auto sortedIndices = rewriter.create<VMLA::SortPseudoOp>(
- op.getLoc(),
- RankedTensorType::get(operandTy.getShape(), rewriter.getI32Type()),
- operand);
-
- llvm::SmallVector<Value, 6> sortedResults;
- for (auto operand : op.getOperands()) {
- auto tensorTy = operand.getType().cast<RankedTensorType>();
- auto gathered = rewriter.create<mhlo::TorchIndexSelectOp>(
- op.getLoc(), tensorTy, operand, sortedIndices,
- /**dim=*/operandTy.getRank() - 1,
- /**batch_dims=*/operandTy.getRank() - 1);
- sortedResults.push_back(gathered);
- }
-
- rewriter.replaceOp(op, sortedResults);
- return success();
- }
-};
-
-class LowerFftOp : public OpRewritePattern<mhlo::FftOp> {
- public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(mhlo::FftOp op,
- PatternRewriter &rewriter) const override {
- auto tensor_type = op.operand().getType().cast<RankedTensorType>();
- auto fft_type = op.fft_type();
-
- if (fft_type == "RFFT") {
- return ReplaceRfft(op, tensor_type, op.getOperand(), rewriter);
- }
-
- auto real = rewriter.create<mhlo::RealOp>(op.getLoc(), op.getOperand());
- auto imag = rewriter.create<mhlo::ImagOp>(op.getLoc(), op.getOperand());
-
- if (fft_type == "FFT") {
- return ReplaceFftOpComplextoComplex<VMLA::FftPseudoOp>(
- op, tensor_type, real, imag, rewriter);
-
- } else if (fft_type == "IFFT") {
- return ReplaceFftOpComplextoComplex<VMLA::IfftPseudoOp>(
- op, tensor_type, real, imag, rewriter);
-
- } else if (fft_type == "IRFFT") {
- return ReplaceIrfft(op, tensor_type, real, imag, rewriter);
- }
- return rewriter.notifyMatchFailure(op, "FFT type not recognized");
- }
-
- private:
- template <typename T>
- LogicalResult ReplaceFftOpComplextoComplex(mhlo::FftOp op,
- RankedTensorType tensor_type,
- mhlo::RealOp real,
- mhlo::ImagOp imag,
- PatternRewriter &rewriter) const {
- auto results = rewriter.create<T>(op.getLoc(), real.getType(),
- imag.getType(), real, imag);
- auto complex_result = rewriter.create<mhlo::ComplexOp>(
- op.getLoc(), op.getType(), results.real_out(), results.imag_out());
- rewriter.replaceOp(op, {complex_result});
- return success();
- }
-
- LogicalResult ReplaceRfft(mhlo::FftOp op, RankedTensorType input_tensor_type,
- mlir::Value real, PatternRewriter &rewriter) const {
- RankedTensorType new_type =
- RankedTensorType::get(op.getType().cast<ShapedType>().getShape(),
- input_tensor_type.getElementType());
-
- auto results = rewriter.create<VMLA::RfftPseudoOp>(op.getLoc(), new_type,
- new_type, real);
- auto complex_result = rewriter.create<mhlo::ComplexOp>(
- op.getLoc(), op.getType(), results.real_out(), results.imag_out());
- rewriter.replaceOp(op, {complex_result});
- return success();
- }
-
- LogicalResult ReplaceIrfft(mhlo::FftOp op, RankedTensorType input_tensor_type,
- mhlo::RealOp real, mhlo::ImagOp imag,
- PatternRewriter &rewriter) const {
- auto results = rewriter.create<VMLA::IrfftPseudoOp>(
- op.getLoc(), op.getType(), real, imag);
- rewriter.replaceOp(op, {results});
- return success();
- }
-};
-
-class PreConversionLoweringPass
- : public PassWrapper<PreConversionLoweringPass, OperationPass<FuncOp>> {
- public:
- void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<ShapeDialect, IREE::VMLA::VMLADialect>();
- }
-
- void runOnOperation() override {
- MLIRContext *context = &getContext();
-
- // These patterns should be run greedily as they are not dialect
- // conversions.
- OwningRewritePatternList greedyPatterns(&getContext());
- mhlo::PopulateComplexLoweringPatterns(context, &greedyPatterns);
- if (failed(applyPatternsAndFoldGreedily(getOperation(),
- std::move(greedyPatterns)))) {
- return signalPassFailure();
- }
-
- OwningRewritePatternList patterns(&getContext());
- ConversionTarget target(*context);
- target.addLegalDialect<StandardOpsDialect>();
- target.addLegalDialect<IREE::VMLA::VMLADialect>();
- target.addLegalDialect<mhlo::MhloDialect>();
- target.addLegalDialect<ShapeDialect>();
-
- target.addIllegalOp<mhlo::DotGeneralOp>();
- patterns.insert<LowerDotGeneralOp>(context);
- target.addIllegalOp<mhlo::DotOp>();
- patterns.insert<LowerDotOp>(context);
- target.addIllegalOp<mhlo::BroadcastInDimOp>();
- patterns.insert<LowerBroadcastInDimOp>(context);
- target.addIllegalOp<mhlo::BroadcastOp>();
- patterns.insert<LowerBroadcastOp>(context);
- target.addIllegalOp<mhlo::SortOp>();
- patterns.insert<LowerSortOp>(context);
- target.addIllegalOp<mhlo::FftOp>();
- patterns.insert<LowerFftOp>(context);
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns)))) {
- return signalPassFailure();
- }
-
- {
- OwningRewritePatternList greedyPatterns(&getContext());
- greedyPatterns.insert<CanonicalizeTranspose>(context);
- if (failed(applyPatternsAndFoldGreedily(getOperation(),
- std::move(greedyPatterns)))) {
- return signalPassFailure();
- }
- }
- }
-};
-
-static PassRegistration<PreConversionLoweringPass> pass(
- "iree-vmla-pre-conversion-lowering",
- "Tensor-level pattern-based lowerings.");
-
-} // namespace
-
-std::unique_ptr<OperationPass<FuncOp>> createPreConversionLoweringPass() {
- return std::make_unique<PreConversionLoweringPass>();
-}
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp b/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
deleted file mode 100644
index 3208947..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/UnrollReductions.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace IREE {
-namespace VMLA {
-
-namespace {
-
-// Unrolls a multi-dimensional mhlo.reduce op into one mhlo.reduce op per
-// dimension. The XLA operation semantics state that this is a valid
-// transformation.
-void unrollReduceOp(mhlo::ReduceOp reduceOp) {
- // Create one op per dimension being reduced.
- // We'll do this by chaining the original input through with the temporary
- // reduction results. The results we end up with will be the originally
- // requested shape and we can just substitute them.
- SmallVector<int64_t, 4> sortedDimensions{
- reduceOp.dimensions().getValues<int64_t>()};
- llvm::sort(sortedDimensions,
- [](int64_t a, int64_t b) { return (a - b) > 0; });
-
- // Insert at the same place as the original op.
- OpBuilder builder(reduceOp);
- SmallVector<Value, 4> temps{reduceOp.inputs()};
- for (int64_t dimension : sortedDimensions) {
- // Create the new reduction using the results of the previous operation.
- auto singleAttrType =
- RankedTensorType::get({1}, builder.getIntegerType(64));
- auto singleReduceOp = builder.create<mhlo::ReduceOp>(
- reduceOp.getLoc(), temps, reduceOp.init_values(),
- DenseIntElementsAttr::get(singleAttrType, {dimension}));
- BlockAndValueMapping mapping;
- reduceOp.body().cloneInto(&singleReduceOp.body(), mapping);
- temps = singleReduceOp.getResults();
- }
-
- // Replace uses of the existing results with the new results.
- reduceOp.replaceAllUsesWith(temps);
-
- // Erase original op.
- reduceOp.erase();
-}
-
-} // namespace
-
-class UnrollReductionsPass
- : public PassWrapper<UnrollReductionsPass, FunctionPass> {
- public:
- void runOnFunction() override {
- for (auto &block : getFunction()) {
- auto reduceOps = llvm::to_vector<4>(block.getOps<mhlo::ReduceOp>());
- for (auto reduceOp : reduceOps) {
- if (reduceOp.dimensions().getNumElements() > 1) {
- unrollReduceOp(reduceOp);
- }
- }
- }
- }
-};
-
-std::unique_ptr<OperationPass<FuncOp>> createUnrollReductionsPass() {
- return std::make_unique<UnrollReductionsPass>();
-}
-
-static PassRegistration<UnrollReductionsPass> pass(
- "iree-vmla-unroll-reductions",
- "Unrolls multi-dimensional reductions to one reduction per dimension.");
-
-} // namespace VMLA
-} // namespace IREE
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/BUILD b/iree/compiler/Dialect/VMLA/Transforms/test/BUILD
deleted file mode 100644
index 8e124ba..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/test/BUILD
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-load("//iree:lit_test.bzl", "iree_lit_test_suite")
-load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-iree_lit_test_suite(
- name = "lit",
- srcs = enforce_glob(
- [
- "pre_conversion_lowering.mlir",
- "transformation.mlir",
- "unroll_reductions.mlir",
- ],
- include = ["*.mlir"],
- ),
- data = [
- "//iree/tools:IreeFileCheck",
- "//iree/tools:iree-opt",
- ],
-)
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt
deleted file mode 100644
index 4a84f85..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/test/CMakeLists.txt
+++ /dev/null
@@ -1,25 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# iree/compiler/Dialect/VMLA/Transforms/test/BUILD #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_lit_test_suite(
- NAME
- lit
- SRCS
- "pre_conversion_lowering.mlir"
- "transformation.mlir"
- "unroll_reductions.mlir"
- DATA
- iree::tools::IreeFileCheck
- iree::tools::iree-opt
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
deleted file mode 100644
index 1e0e48b..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
+++ /dev/null
@@ -1,136 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering %s | IreeFileCheck %s
-
-// -----
-
-// CHECK-LABEL: func @dot_general_float
-func @dot_general_float(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
- // CHECK: vmla.batch.matmul
- %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {
- lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
- lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
- rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
- rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>
- }} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
- return %0 : tensor<3x5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @dot_mixed_element_types
-func @dot_mixed_element_types(%arg0: tensor<2x3xi8>, %arg1: tensor<3x2xi16>) -> tensor<2x2xi32> {
- // CHECK: vmla.batch.matmul.pseudo %{{[a-zA-Z0-9$._-]+}}, %{{[a-zA-Z0-9$._-]+}} : (tensor<?x?x?xi8>, tensor<?x?x?xi16>) -> tensor<?x?x?xi32>
- %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xi8>, tensor<3x2xi16>) -> tensor<2x2xi32>
- return %0 : tensor<2x2xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func private @sort
-func private @sort(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
- // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 0 : i64, dim = 0 : i64}
- %sort = "mhlo.sort"(%arg0) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
- "mhlo.return"(%compare) : (tensor<i1>) -> ()
- }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
-
- // CHECK: return [[GATHER]]
- return %sort : tensor<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func private @sort
-func private @sort_2d(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
- // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
- // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 1 : i64, dim = 1 : i64}
- %sort = "mhlo.sort"(%arg0) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
- "mhlo.return"(%compare) : (tensor<i1>) -> ()
- }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
-
- // CHECK return [[GATHER]]
- return %sort : tensor<4x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @broadcast_in_dim
-func @broadcast_in_dim(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
- // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs4_3)
- %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
- return %0 : tensor<4x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @ranked_broadcast_in_dim
-func @ranked_broadcast_in_dim(%arg0: tensor<3xf32>) -> tensor<5x6x3xf32> {
- // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs5_6_3)
- %0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[5, 6]> : tensor<2xi64>} : (tensor<3xf32>) -> tensor<5x6x3xf32>
- return %0 : tensor<5x6x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func private @fft
-func private @fft(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
- // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0)
- // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0)
- // CHECK-DAG: [[REAL_OUT:%.+]], [[IMAG_OUT:%.+]] = vmla.fft.pseudo [[REAL]], [[IMAG]]
- // CHECK: "mhlo.complex"([[REAL_OUT]], [[IMAG_OUT]])
- %0 = "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
- return %0 : tensor<8xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func private @ifft
-func private @ifft(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
- // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0)
- // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0)
- // CHECK-DAG: [[REAL_OUT:%.+]], [[IMAG_OUT:%.+]] = vmla.ifft.pseudo [[REAL]], [[IMAG]]
- // CHECK: "mhlo.complex"([[REAL_OUT]], [[IMAG_OUT]])
- %0 = "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
- return %0 : tensor<8xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func private @rfft
-func private @rfft(%arg0: tensor<8xf32>) -> tensor<5xcomplex<f32>> {
- // CHECK-DAG: [[REAL_OUT:%.+]], [[IMAG_OUT:%.+]] = vmla.rfft.pseudo %arg0
- // CHECK: "mhlo.complex"([[REAL_OUT]], [[IMAG_OUT]])
- %0 = "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32>) -> tensor<5xcomplex<f32>>
- return %0 : tensor<5xcomplex<f32>>
-}
-
-// -----
-
-// CHECK-LABEL: func private @irfft
-func private @irfft(%arg0: tensor<5xcomplex<f32>>) -> tensor<8xf32> {
- // CHECK-DAG: [[REAL:%.+]] = "mhlo.real"(%arg0)
- // CHECK-DAG: [[IMAG:%.+]] = "mhlo.imag"(%arg0)
- // CHECK-DAG: [[REAL_OUT:%.+]] = vmla.irfft.pseudo [[REAL]], [[IMAG]]
- %0 = "mhlo.fft"(%arg0) {fft_length = dense<5> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex<f32>>) -> tensor<8xf32>
- return %0 : tensor<8xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @complex_multiply
-func @complex_multiply(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
- // CHECK-NOT: "mhlo.complex"
- %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
-
- // CHECK-DAG: [[V1:%.+]] = mhlo.multiply %arg0, %arg0
- // CHECK-DAG: [[V2:%.+]] = mhlo.multiply %arg1, %arg1
- // CHECK-DAG: [[V3:%.+]] = mhlo.subtract [[V1]], [[V2]]
- %1 = "mhlo.multiply"(%0, %0) : (tensor<3xcomplex<f32>>, tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>>
- %2 = "mhlo.real"(%1) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
-
- // CHECK: return [[V3]]
- return %2 : tensor<3xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir
deleted file mode 100644
index 5e7a54b..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/test/transformation.mlir
+++ /dev/null
@@ -1,29 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-transformation-pipeline %s | IreeFileCheck %s
-
-func @simpleMath_rgn_dispatch_0() {
- %c0 = constant 0 : index
- %0 = hal.interface.load.tensor @io::@arg0, offset = %c0 : tensor<4xf32>
- %1 = call @simpleMath_rgn_dispatch_0_impl(%0) : (tensor<4xf32>) -> tensor<4xf32>
- hal.interface.store.tensor %1, @io::@ret0, offset = %c0 : tensor<4xf32>
- return
-}
-func private @simpleMath_rgn_dispatch_0_impl(%arg0: tensor<4xf32>) -> tensor<4xf32> {
- %0 = mhlo.add %arg0, %arg0 : tensor<4xf32>
- return %0 : tensor<4xf32>
-}
-hal.interface @io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
-}
-
-// CHECK: func @simpleMath_rgn_dispatch_0(%arg0: !vmla.interface, %arg1: index, %arg2: index, %arg3: index) {
-// CHECK-DAG: %c0 = constant 0 : index
-// CHECK-DAG: %c16 = constant 16 : index
-// CHECK-NEXT: %0 = vmla.interface.binding %arg0 {binding = 0 : i32, set = 0 : i32} : !vmla.buffer
-// CHECK-NEXT: %1 = vmla.buffer.view %0[%c0], byte_length = %c16 : !vmla.buffer
-// CHECK-NEXT: %2 = vmla.buffer.alloc byte_length = %c16 : !vmla.buffer
-// CHECK-NEXT: vmla.add %1, %1, out %2 : f32
-// CHECK-NEXT: %3 = vmla.interface.binding %arg0 {binding = 1 : i32, set = 0 : i32} : !vmla.buffer
-// CHECK-NEXT: vmla.buffer.copy %2[%c0], out %3[%c0], byte_length = %c16
-// CHECK-NEXT: return
-// CHECK-NEXT: }
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/unroll_reductions.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/unroll_reductions.mlir
deleted file mode 100644
index 8b89122..0000000
--- a/iree/compiler/Dialect/VMLA/Transforms/test/unroll_reductions.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-vmla-unroll-reductions -cse %s | IreeFileCheck %s
-
-// CHECK-LABEL: func @unrolled_reduction
-func @unrolled_reduction(%arg0: tensor<4x2x8xf32>) -> tensor<4xf32> {
- // CHECK-DAG: %[[INITIAL:.+]] = constant dense<0.000000e+00> : tensor<f32>
- %cst = constant dense<0.000000e+00> : tensor<f32>
- // CHECK-NEXT: %[[TEMP:.+]] = "mhlo.reduce"(%arg0, %[[INITIAL]]) ( {
- // CHECK-NEXT: ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- // CHECK-NEXT: %2 = mhlo.add %arg1, %arg2 : tensor<f32>
- // CHECK-NEXT: "mhlo.return"(%2) : (tensor<f32>) -> ()
- // CHECK-NEXT: }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x2x8xf32>, tensor<f32>) -> tensor<4x2xf32>
- // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.reduce"(%[[TEMP]], %[[INITIAL]]) ( {
- // CHECK-NEXT: ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- // CHECK-NEXT: %2 = mhlo.add %arg1, %arg2 : tensor<f32>
- // CHECK-NEXT: "mhlo.return"(%2) : (tensor<f32>) -> ()
- // CHECK-NEXT: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x2xf32>, tensor<f32>) -> tensor<4xf32>
- %0 = "mhlo.reduce"(%arg0, %cst) ( {
- ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
- %1 = mhlo.add %arg1, %arg2 : tensor<f32>
- "mhlo.return"(%1) : (tensor<f32>) -> ()
- }) {dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2x8xf32>, tensor<f32>) -> tensor<4xf32>
- // CHECK-NEXT: return %[[RESULT]]
- return %0 : tensor<4xf32>
-}
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
deleted file mode 100644
index 598afb3..0000000
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ /dev/null
@@ -1,625 +0,0 @@
-// IREE VMLA (Virtual Machine-based Linear Algebra) runtime module imports.
-//
-// This is embedded in the compiler binary and inserted into any module
-// containing VMLA dialect ops (vmla.*) that is lowered to the VM dialect.
-//
-// Element types are embedded in the function. The convention used:
-// * 'x': don't-care, bit-depth only.
-// * 'i': signed integer
-// * 'u': unsigned integer
-// * 'f': IREE float
-//
-// The native module does not need shapes in many cases and only ops that
-// actually use the shape information take it as arguments.
-//
-// When adding methods try to first reuse existing ones. For example, unrolling
-// a memcpy to a sequence of vmla.buffer.copy calls (or a loop of them) instead
-// of adding a my_batch_copy method.
-vm.module @vmla {
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: ABI
-//===----------------------------------------------------------------------===//
-
-vm.import @interface.const(
- %interface : !vm.ref<!vmla.interface>,
- %offset : i32
-) -> i32
-attributes {nosideeffects}
-
-vm.import @interface.binding(
- %interface : !vm.ref<!vmla.interface>,
- %set : i32,
- %binding : i32
-) -> !vm.ref<!vmla.buffer>
-attributes {nosideeffects}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: buffer manipulation
-//===----------------------------------------------------------------------===//
-
-vm.import @buffer.const(
- %value : !vm.buffer
-) -> !vm.ref<!vmla.buffer>
-attributes {nosideeffects}
-
-vm.import @buffer.alloc(
- %byte_length : i32
-) -> !vm.ref<!vmla.buffer>
-attributes {nosideeffects}
-
-vm.import @buffer.clone(
- %src : !vm.ref<!vmla.buffer>
-) -> !vm.ref<!vmla.buffer>
-attributes {nosideeffects}
-
-vm.import @buffer.byte_length(
- %value : !vm.ref<!vmla.buffer>
-) -> i32
-attributes {nosideeffects}
-
-vm.import @buffer.view(
- %src : !vm.ref<!vmla.buffer>,
- %byte_offset : i32,
- %byte_length : i32
-) -> !vm.ref<!vmla.buffer>
-attributes {nosideeffects}
-
-vm.import @buffer.copy(
- %src : !vm.ref<!vmla.buffer>, %src_byte_offset : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_byte_offset : i32,
- %byte_length : i32
-)
-
-vm.import @buffer.fill(
- %value : !vm.ref<!vmla.buffer>,
- %dst : !vm.ref<!vmla.buffer>
-)
-
-vm.import @buffer.load.i32(
- %src : !vm.ref<!vmla.buffer>,
- %byte_offset : i32
-) -> i32
-attributes {nosideeffects}
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: comparison
-//===----------------------------------------------------------------------===//
-
-vm.import @cmp.i8(%predicate : i32, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @cmp.i16(%predicate : i32, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @cmp.i32(%predicate : i32, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @cmp.f32(%predicate : i32, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-
-vm.import @select.x8(%cond : !vm.ref<!vmla.buffer>, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @select.x16(%cond : !vm.ref<!vmla.buffer>, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @select.x32(%cond : !vm.ref<!vmla.buffer>, %lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-
-vm.import @finite.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: shape/structure
-//===----------------------------------------------------------------------===//
-
-// TODO(benvanik): do the copies with buffer.copy instead and leave the offset
-// calculations in the IR for the compiler to simplify.
-vm.import @copy.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
- %lengths : i32 ...
-)
-vm.import @copy.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
- %lengths : i32 ...
-)
-vm.import @copy.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ..., %src_indices : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ..., %dst_indices : i32 ...,
- %lengths : i32 ...
-)
-
-vm.import @transpose.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %permutation : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @transpose.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %permutation : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @transpose.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %permutation : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @reverse.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dimensions : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reverse.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dimensions : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reverse.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dimensions : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @pad.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %value : !vm.ref<!vmla.buffer>, %value_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %edge_padding_low : i32 ...,
- %edge_padding_high : i32 ...,
- %interior_padding : i32 ...
-)
-vm.import @pad.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %value : !vm.ref<!vmla.buffer>, %value_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %edge_padding_low : i32 ..., %edge_padding_high : i32 ...,
- %interior_padding : i32 ...
-)
-vm.import @pad.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %value : !vm.ref<!vmla.buffer>, %value_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %edge_padding_low : i32 ..., %edge_padding_high : i32 ...,
- %interior_padding : i32 ...
-)
-vm.import @gather.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %indices : !vm.ref<!vmla.buffer>, %indices_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %dim : i32, %batch_dims : i32
-)
-vm.import @gather.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %indices : !vm.ref<!vmla.buffer>, %indices_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %dim : i32, %batch_dims : i32
-)
-vm.import @gather.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %indices : !vm.ref<!vmla.buffer>, %indices_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %dim : i32, %batch_dims : i32
-)
- vm.import @scatter.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %indices : !vm.ref<!vmla.buffer>, %indices_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @scatter.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %indices : !vm.ref<!vmla.buffer>, %indices_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @scatter.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %indices : !vm.ref<!vmla.buffer>, %indices_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @broadcast.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @broadcast.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @broadcast.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @iota.i8(%dst : !vm.ref<!vmla.buffer>)
-vm.import @iota.i16(%dst : !vm.ref<!vmla.buffer>)
-vm.import @iota.i32(%dst : !vm.ref<!vmla.buffer>)
-vm.import @iota.f32(%dst : !vm.ref<!vmla.buffer>)
-
-vm.import @tile.x8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @tile.x16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @tile.x32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: bit manipulation
-//===----------------------------------------------------------------------===//
-
-vm.import @not.x8(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @not.x16(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @not.x32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @and.x8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @and.x16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @and.x32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @and.broadcast.x8(%lhs : !vm.ref<!vmla.buffer>, %rhs : i32, %dst : !vm.ref<!vmla.buffer>)
-vm.import @and.broadcast.x16(%lhs : !vm.ref<!vmla.buffer>, %rhs : i32, %dst : !vm.ref<!vmla.buffer>)
-vm.import @and.broadcast.x32(%lhs : !vm.ref<!vmla.buffer>, %rhs : i32, %dst : !vm.ref<!vmla.buffer>)
-vm.import @or.x8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @or.x16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @or.x32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @xor.x8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @xor.x16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @xor.x32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @xor.broadcast.x8(%lhs : !vm.ref<!vmla.buffer>, %rhs : i32, %dst : !vm.ref<!vmla.buffer>)
-vm.import @xor.broadcast.x16(%lhs : !vm.ref<!vmla.buffer>, %rhs : i32, %dst : !vm.ref<!vmla.buffer>)
-vm.import @xor.broadcast.x32(%lhs : !vm.ref<!vmla.buffer>, %rhs : i32, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shl.x8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shl.x16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shl.x32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shr.u8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shr.u16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shr.u32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shr.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shr.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @shr.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: arithmetic
-//===----------------------------------------------------------------------===//
-
-vm.import @add.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @add.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @add.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @add.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @sub.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @sub.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @sub.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @sub.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @abs.i8(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @abs.i16(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @abs.i32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @abs.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @neg.i8(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @neg.i16(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @neg.i32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @neg.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @mul.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @mul.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @mul.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @mul.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @div.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @div.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @div.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @div.u8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @div.u16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @div.u32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @div.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rem.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rem.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rem.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rem.u8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rem.u16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rem.u32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rem.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @pow.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @exp.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @log.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @rsqrt.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @sqrt.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @cos.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @sin.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @tanh.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @atan2.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-
-vm.import @min.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @min.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @min.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @min.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @max.i8(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @max.i16(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @max.i32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @max.f32(%lhs : !vm.ref<!vmla.buffer>, %rhs : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @clamp.i8(%min : !vm.ref<!vmla.buffer>, %value : !vm.ref<!vmla.buffer>, %max : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @clamp.i16(%min : !vm.ref<!vmla.buffer>, %value : !vm.ref<!vmla.buffer>, %max : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @clamp.i32(%min : !vm.ref<!vmla.buffer>, %value : !vm.ref<!vmla.buffer>, %max : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @clamp.f32(%min : !vm.ref<!vmla.buffer>, %value : !vm.ref<!vmla.buffer>, %max : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @floor.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @ceil.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @round.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-
-
-vm.import @sort.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>)
-vm.import @sort.i16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>)
-vm.import @sort.i32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>)
-vm.import @sort.f32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>)
-
-vm.import @fft.f32(
- %real_src : !vm.ref<!vmla.buffer>, %real_src_shape : i32 ...,
- %imag_src : !vm.ref<!vmla.buffer>, %imag_src_shape : i32 ...,
- %real_dst : !vm.ref<!vmla.buffer>,
- %imag_dst : !vm.ref<!vmla.buffer>)
-
-vm.import @ifft.f32(
- %real_src : !vm.ref<!vmla.buffer>, %real_src_shape : i32 ...,
- %imag_src : !vm.ref<!vmla.buffer>, %imag_src_shape : i32 ...,
- %real_dst : !vm.ref<!vmla.buffer>,
- %imag_dst : !vm.ref<!vmla.buffer>)
-
-vm.import @rfft.f32(
- %real_src : !vm.ref<!vmla.buffer>, %real_src_shape : i32 ...,
- %real_dst : !vm.ref<!vmla.buffer>,
- %imag_dst : !vm.ref<!vmla.buffer>)
-
-vm.import @irfft.f32(
- %real_src : !vm.ref<!vmla.buffer>, %real_src_shape : i32 ...,
- %imag_src : !vm.ref<!vmla.buffer>, %imag_src_shape : i32 ...,
- %real_dst : !vm.ref<!vmla.buffer>)
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: conversion
-//===----------------------------------------------------------------------===//
-
-vm.import @convert.i8.i16(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i8.i32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i8.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i16.i8(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i16.i32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i16.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i32.i8(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i32.i16(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.i32.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.f32.i8(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.f32.i16(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-vm.import @convert.f32.i32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: Convolution
-//===----------------------------------------------------------------------===//
-
-vm.import @conv.f32f32.f32(
- %input: !vm.ref<!vmla.buffer>, %input_shape: i32 ...,
- %filter: !vm.ref<!vmla.buffer>, %filter_shape: i32 ...,
- %dst: !vm.ref<!vmla.buffer>, %dst_shape: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...,
- %lhs_dilation: i32 ...,
- %rhs_dilation: i32 ...,
- %feature_group_count: i32,
- %batch_group_count: i32
-)
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: GEMM/GEMV
-//===----------------------------------------------------------------------===//
-
-vm.import @batch.matmul.f32f32.f32(
- %lhs : !vm.ref<!vmla.buffer>, %lhs_shape : i32 ...,
- %rhs : !vm.ref<!vmla.buffer>, %rhs_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @batch.matmul.i32i32.i32(
- %lhs : !vm.ref<!vmla.buffer>, %lhs_shape : i32 ...,
- %rhs : !vm.ref<!vmla.buffer>, %rhs_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @batch.matmul.i8i8.i32(
- %lhs : !vm.ref<!vmla.buffer>, %lhs_shape : i8 ...,
- %rhs : !vm.ref<!vmla.buffer>, %rhs_shape : i8 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @batch.matmul.i16i16.i32(
- %lhs : !vm.ref<!vmla.buffer>, %lhs_shape : i16 ...,
- %rhs : !vm.ref<!vmla.buffer>, %rhs_shape : i16 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-//===----------------------------------------------------------------------===//
-// VMLA Ops: reduction
-//===----------------------------------------------------------------------===//
-
-vm.import @reduce.sum.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.sum.i16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.sum.i32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.sum.f32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @reduce.min.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.min.i16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.min.i32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.min.f32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @reduce.max.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.max.i16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.max.i32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-vm.import @reduce.max.f32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @reduce.and.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @reduce.or.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dimension : i32,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...
-)
-
-vm.import @pooling.sum.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.sum.i16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.sum.i32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.sum.f32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-
-vm.import @pooling.min.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.min.i16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.min.i32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.min.f32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-
-vm.import @pooling.max.i8(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.max.i16(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.max.i32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-vm.import @pooling.max.f32(
- %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
- %init : !vm.ref<!vmla.buffer>, %init_shape : i32 ...,
- %dst : !vm.ref<!vmla.buffer>, %dst_shape : i32 ...,
- %window_dimensions: i32 ...,
- %window_strides: i32 ...,
- %padding: i32 ...
-)
-
-} // module
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD
index c2e57d0..4fc782c 100644
--- a/iree/schemas/BUILD
+++ b/iree/schemas/BUILD
@@ -58,12 +58,6 @@
flatcc_args = FLATCC_ARGS,
)
-iree_flatbuffer_c_library(
- name = "vmla_executable_def_c_fbs",
- srcs = ["vmla_executable_def.fbs"],
- flatcc_args = FLATCC_ARGS,
-)
-
iree_build_test(
name = "schema_build_test",
targets = [
@@ -71,6 +65,5 @@
":dylib_executable_def_c_fbs",
":metal_executable_def_c_fbs",
":spirv_executable_def_c_fbs",
- ":vmla_executable_def_c_fbs",
],
)
diff --git a/iree/schemas/CMakeLists.txt b/iree/schemas/CMakeLists.txt
index 8dd5049..bc31d95 100644
--- a/iree/schemas/CMakeLists.txt
+++ b/iree/schemas/CMakeLists.txt
@@ -75,17 +75,4 @@
PUBLIC
)
-flatbuffer_c_library(
- NAME
- vmla_executable_def_c_fbs
- SRCS
- "vmla_executable_def.fbs"
- FLATCC_ARGS
- "--reader"
- "--builder"
- "--verifier"
- "--json"
- PUBLIC
-)
-
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/schemas/vmla_executable_def.fbs b/iree/schemas/vmla_executable_def.fbs
deleted file mode 100644
index e85972c..0000000
--- a/iree/schemas/vmla_executable_def.fbs
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-namespace iree;
-
-// 'VMLA Executable'.
-file_identifier "VMLA";
-file_extension "vmla";
-
-// A VMLA (VM-based Linear Algebra) executable module.
-// This executable contains the VM bytecode module and additional metadata
-// required at runtime to associate the descriptor set bindings with the VM
-// exported functions.
-table VMLAExecutableDef {
- // Embedded BytecodeModuleDef flatbuffer.
- // We embed the entire flatbuffer contents opaquely to allow for easier
- // manipulation of the files (such that we can just slice out the bytes and
- // dump them to a file, etc).
- bytecode_module:[ubyte];
-}
-
-root_type VMLAExecutableDef;
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 736ec21..7ff2cda 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -98,8 +98,6 @@
"//iree/compiler/Dialect/VM/Analysis",
"//iree/compiler/Dialect/VM/IR",
"//iree/compiler/Dialect/VM/Transforms",
- "//iree/compiler/Dialect/VMLA/IR:VMLADialect",
- "//iree/compiler/Dialect/VMLA/Transforms",
"//iree/compiler/Dialect/Vulkan/IR",
"//iree/compiler/Translation:IREEVM",
"@llvm-project//mlir:IR",
@@ -176,7 +174,6 @@
"IREE_HAVE_CUDA_TARGET",
"IREE_HAVE_LLVMAOT_TARGET",
"IREE_HAVE_METALSPIRV_TARGET",
- "IREE_HAVE_VMLA_TARGET",
"IREE_HAVE_VMVX_TARGET",
"IREE_HAVE_VULKANSPIRV_TARGET",
],
@@ -184,7 +181,6 @@
"//iree/compiler/Dialect/HAL/Target/CUDA",
"//iree/compiler/Dialect/HAL/Target/LLVM",
"//iree/compiler/Dialect/HAL/Target/MetalSPIRV",
- "//iree/compiler/Dialect/HAL/Target/VMLA",
"//iree/compiler/Dialect/HAL/Target/VMVX",
"//iree/compiler/Dialect/HAL/Target/VulkanSPIRV",
],
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index 08de1af..b8bed15 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -29,10 +29,6 @@
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::MetalSPIRV)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_METALSPIRV_TARGET")
endif()
-if("${IREE_TARGET_BACKEND_VMLA}")
- list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VMLA)
- list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VMLA_TARGET")
-endif()
if("${IREE_TARGET_BACKEND_VMVX}")
list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VMVX)
list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VMVX_TARGET")
@@ -172,8 +168,6 @@
iree::compiler::Dialect::VM::Analysis
iree::compiler::Dialect::VM::IR
iree::compiler::Dialect::VM::Transforms
- iree::compiler::Dialect::VMLA::IR::VMLADialect
- iree::compiler::Dialect::VMLA::Transforms
iree::compiler::Dialect::Vulkan::IR
iree::compiler::Translation::IREEVM
PUBLIC
diff --git a/iree/tools/init_iree_dialects.h b/iree/tools/init_iree_dialects.h
index 42a39bc..6475322 100644
--- a/iree/tools/init_iree_dialects.h
+++ b/iree/tools/init_iree_dialects.h
@@ -26,7 +26,6 @@
#include "iree/compiler/Dialect/Modules/VMVX/IR/VMVXDialect.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
-#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "mlir/IR/Dialect.h"
@@ -41,7 +40,6 @@
ShapeDialect,
IREEDialect,
IREE::VM::VMDialect,
- IREE::VMLA::VMLADialect,
IREE::VMVX::VMVXDialect,
IREE::Vulkan::VulkanDialect>();
// clang-format on
diff --git a/iree/tools/init_iree_passes.h b/iree/tools/init_iree_passes.h
index cea5739..b52f61b 100644
--- a/iree/tools/init_iree_passes.h
+++ b/iree/tools/init_iree_passes.h
@@ -33,7 +33,6 @@
#include "iree/compiler/Dialect/Shape/Transforms/Passes.h"
#include "iree/compiler/Dialect/VM/Analysis/TestPasses.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
-#include "iree/compiler/Dialect/VMLA/Transforms/Passes.h"
#include "iree/compiler/Translation/IREEVM.h"
namespace mlir {
@@ -56,7 +55,6 @@
IREE::VM::registerVMPasses();
IREE::VM::registerVMAnalysisTestPasses();
IREE::VM::registerVMTestPasses();
- IREE::VMLA::registerVMLAPasses();
IREE::VMVX::registerVMVXPasses();
registerIREEVMTransformPassPipeline();
}
diff --git a/iree/tools/init_targets.cc b/iree/tools/init_targets.cc
index 903ce78..0143d06 100644
--- a/iree/tools/init_targets.cc
+++ b/iree/tools/init_targets.cc
@@ -23,9 +23,6 @@
#ifdef IREE_HAVE_METALSPIRV_TARGET
#include "iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.h"
#endif // IREE_HAVE_METALSPIRV_TARGET
-#ifdef IREE_HAVE_VMLA_TARGET
-#include "iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.h"
-#endif // IREE_HAVE_VMLA_TARGET
#ifdef IREE_HAVE_VMVX_TARGET
#include "iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.h"
#endif // IREE_HAVE_VMVX_TARGET
@@ -55,10 +52,6 @@
IREE::HAL::registerMetalSPIRVTargetBackends(
[]() { return IREE::HAL::getMetalSPIRVTargetOptionsFromFlags(); });
#endif // IREE_HAVE_METALSPIRV_TARGET
-#ifdef IREE_HAVE_VMLA_TARGET
- IREE::HAL::registerVMLATargetBackends(
- []() { return IREE::HAL::getVMLATargetOptionsFromFlags(); });
-#endif // IREE_HAVE_VMLA_TARGET
#ifdef IREE_HAVE_VMVX_TARGET
IREE::HAL::registerVMVXTargetBackends(
[]() { return IREE::HAL::getVMVXTargetOptionsFromFlags(); });