blob: f5bc7c1ad34ccb4df1788d8e9d9263453f626d56 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/IREE/IR/IREETypes.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLADialect.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLAOps.h"
#include "iree/compiler/Dialect/VMLA/IR/VMLATypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
namespace {
/// Converts a 2D `mhlo::ConvOp` into an `IREE::VMLA::ConvOp`.
///
/// When the op carries `dimension_numbers`, the pattern only matches the
/// layout validated below:
///   * input  : N, spatial_dims..., Ci
///   * filter : spatial_dims..., Ci, Co
///   * output : N, spatial_dims..., Co
/// Ops with `batch_group_count != 1` are rejected. All validation is done
/// before any IR is created so that a failed match leaves nothing behind.
struct VMLAConvOpConverter : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Validates `op` and, on success, replaces it with a `vmla.conv` writing
  /// into a freshly allocated output buffer.
  ///
  /// `operands` holds the already-remapped lhs/rhs values (in that order);
  /// these are what get passed to the new op as its buffer operands.
  /// Returns failure() (after emitting a warning where the original did) if
  /// the convolution is not in a supported form.
  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    if (op.dimension_numbers()) {
      const auto dimensionNumbers = op.dimension_numbers();
      const int inputSpatialRank =
          std::distance(dimensionNumbers.input_spatial_dimensions().begin(),
                        dimensionNumbers.input_spatial_dimensions().end());
      if (inputSpatialRank != 2) {
        op.emitWarning() << "Only lowering 2D conv is supported";
        return failure();
      }

      // Input storage order is N,spatial_dims...,Ci.
      if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
          dimensionNumbers.input_feature_dimension().getInt() !=
              (inputSpatialRank + 1)) {
        op.emitWarning()
            << "Could not lower conv op due to inconsistant storage type";
        return failure();
      }

      const int kernelSpatialRank =
          std::distance(dimensionNumbers.kernel_spatial_dimensions().begin(),
                        dimensionNumbers.kernel_spatial_dimensions().end());
      // Filter storage order is spatial_dims...,C, Co.
      if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
              kernelSpatialRank ||
          dimensionNumbers.kernel_output_feature_dimension().getInt() !=
              (kernelSpatialRank + 1))
        return failure();

      const int outputSpatialRank =
          std::distance(dimensionNumbers.output_spatial_dimensions().begin(),
                        dimensionNumbers.output_spatial_dimensions().end());
      // Output storage order is N,spatial_dims..,Co.
      if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
          dimensionNumbers.output_feature_dimension().getInt() !=
              (outputSpatialRank + 1))
        return failure();

      if (inputSpatialRank != outputSpatialRank ||
          inputSpatialRank != kernelSpatialRank)
        return failure();

      auto inputSpatialDim =
          dimensionNumbers.input_spatial_dimensions().begin();
      auto kernelSpatialDim =
          dimensionNumbers.kernel_spatial_dimensions().begin();
      auto outputSpatialDim =
          dimensionNumbers.output_spatial_dimensions().begin();
      // Check spatial dims are ordered correctly: input/output spatial dims
      // start at 1 (after the batch dim), kernel spatial dims start at 0.
      for (int i = 0; i < inputSpatialRank; ++i) {
        const int dim = i + 1;
        if ((*inputSpatialDim++).getZExtValue() != dim ||
            (*outputSpatialDim++).getZExtValue() != dim ||
            (*kernelSpatialDim++).getZExtValue() != i)
          return failure();
      }
    }

    // Reject unsupported grouping before creating any IR.
    const int32_t batchGroupCount = op.batch_group_count();
    if (batchGroupCount != 1) {
      op.emitWarning() << "Batch group convoution isn't supported";
      return failure();
    }
    const int32_t featureGroupCount = op.feature_group_count();

    // Defaults for a 2D convolution; overwritten by the optional attributes
    // when those are present.
    SmallVector<int32_t, 4> windowStrides{1, 1};
    SmallVector<int32_t, 4> padding{0, 0, 0, 0};
    SmallVector<int32_t, 4> lhsDilation{1, 1};
    SmallVector<int32_t, 4> rhsDilation{1, 1};

    // Copies the values of an optional elements attribute over the defaults
    // in |vec|. Fails instead of indexing past the end of |vec| when the
    // attribute holds more values than the 2D defaults provide slots for
    // (possible when `dimension_numbers` is absent and no rank check ran).
    auto fillOptional = [](auto field,
                           SmallVector<int32_t, 4> *vec) -> LogicalResult {
      if (!field.hasValue()) return success();
      size_t index = 0;
      for (auto attribute : field.getValue()) {
        if (index >= vec->size()) return failure();
        (*vec)[index++] = attribute.getZExtValue();
      }
      return success();
    };
    if (failed(fillOptional(op.window_strides(), &windowStrides)) ||
        failed(fillOptional(op.padding(), &padding)) ||
        failed(fillOptional(op.lhs_dilation(), &lhsDilation)) ||
        failed(fillOptional(op.rhs_dilation(), &rhsDilation))) {
      op.emitWarning()
          << "Could not lower conv op due to unexpected attribute size";
      return failure();
    }

    // Shapes are derived from the original (tensor-typed) values.
    auto inputShape = VMLAConversionTarget::getTensorShape(
        op.getLoc(), op.lhs(), *getTypeConverter(), rewriter);
    auto filterShape = VMLAConversionTarget::getTensorShape(
        op.getLoc(), op.rhs(), *getTypeConverter(), rewriter);
    auto dstShape = VMLAConversionTarget::getTensorShape(
        op.getLoc(), op.getResult(), *getTypeConverter(), rewriter);
    auto dst = VMLAConversionTarget::allocateOutputBuffer(
        op.getLoc(), op.getResult(), *getTypeConverter(), rewriter);

    // Each element-type attribute is taken from the value it describes.
    auto lhsType =
        TypeAttr::get(op.lhs().getType().cast<ShapedType>().getElementType());
    auto rhsType =
        TypeAttr::get(op.rhs().getType().cast<ShapedType>().getElementType());
    auto dstType = TypeAttr::get(
        op.getResult().getType().cast<ShapedType>().getElementType());

    // Use the remapped operands for the buffer inputs of the new op.
    rewriter.create<IREE::VMLA::ConvOp>(
        op.getLoc(), operands[0], inputShape, operands[1], filterShape, dst,
        dstShape, rewriter.getI32VectorAttr(windowStrides),
        rewriter.getI32VectorAttr(padding),
        rewriter.getI32VectorAttr(lhsDilation),
        rewriter.getI32VectorAttr(rhsDilation),
        rewriter.getI32IntegerAttr(featureGroupCount),
        rewriter.getI32IntegerAttr(batchGroupCount), lhsType, rhsType, dstType);

    rewriter.replaceOp(op, dst);
    return success();
  }
};
} // namespace
// Adds the VMLAConvOpConverter pattern (which matches mhlo::ConvOp) to
// |patterns|, constructing it with |typeConverter| and |context|.
void populateHLOConvToVMLAPatterns(MLIRContext *context,
OwningRewritePatternList &patterns,
TypeConverter &typeConverter) {
patterns.insert<VMLAConvOpConverter>(typeConverter, context);
}
} // namespace iree_compiler
} // namespace mlir