blob: 758d66fe85324e4654b7076b45fd6ccdbe9f0b74 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef IREE_BASE_SHAPE_H_
#define IREE_BASE_SHAPE_H_
#include <array>
#include <cstring>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/meta/type_traits.h"
#include "absl/types/span.h"
#include "iree/base/logging.h"
#include "iree/base/status.h"
namespace iree {
// Upper bound on the rank of any Shape (shape.size() <= kMaxRank). Capping the
// rank lets dimensions live in fixed inline storage with no dynamic
// allocation; tensors of higher rank are rare.
constexpr int kMaxRank{5};

// Per-dimension indices (Index) and lengths (Length) of a tensor. Both are the
// same fixed-capacity std::array<int, kMaxRank> type.
using Index = std::array<int, kMaxRank>;
using Length = Index;
// Represents the number of elements in multiple dimensions.
// Can be rank-0 (scalar) to rank-kMaxRank. Tries to match the API of
// std::vector and can be converted to a Span via subspan().
// Dimensions are stored inline in a fixed-size array, so a Shape never
// allocates.
//
// https://www.tensorflow.org/guide/tensors#shape
class Shape {
 public:
  using size_type = int;
  // Sentinel length for subspan() meaning "through the end". With the signed
  // size_type this evaluates to -1.
  static constexpr size_type npos = ~(size_type(0));  // NOLINT
  using iterator = int*;
  using const_iterator = const int*;

  // Constructs a rank-0 (scalar) shape.
  Shape() = default;
  // Constructs a shape from the `size` dimension values starting at `values`.
  // Defined out of line; presumably enforces size <= kMaxRank -- verify there.
  Shape(const int* values, int size);
  // Implicit so shapes can be brace-initialized, e.g. `Shape{2, 3}`.
  Shape(std::initializer_list<int> values)
      : Shape(values.begin(), values.size()) {}
  explicit Shape(absl::Span<const int> values)
      : Shape(values.data(), values.size()) {}

  template <typename Iterator>
  using EnableIfForwardIterator = absl::enable_if_t<std::is_convertible<
      typename std::iterator_traits<Iterator>::iterator_category,
      std::forward_iterator_tag>::value>;

  // Constructs a shape from the dimension values in [first, last).
  // Terminates (QCHECK) if the range holds more than kMaxRank values.
  template <typename Iterator, EnableIfForwardIterator<Iterator>* = nullptr>
  Shape(Iterator first, Iterator last) {
    // Validate the element count before narrowing it to size_type so that an
    // oversized range cannot wrap around and slip past the capacity check.
    const auto count = std::distance(first, last);
    QCHECK_LE(count, kMaxRank);
    rank_ = static_cast<size_type>(count);
    for (int i = 0; first != last; ++i, static_cast<void>(++first)) {
      value_[i] = *first;
    }
  }

  // Returns a string representation of the given shape.
  std::string DebugString() const;

  // Size (aka 'rank') of the shape, counting the number of dimensions.
  constexpr size_type size() const noexcept { return rank_; }
  // Whether the shape is rank-0 (scalar).
  constexpr bool empty() const noexcept { return rank_ == 0; }

  // Returns the total elements in the tensor shape.
  // Returns 0 if the tensor shape is not complete and 1 if the shape is a
  // scalar value.
  int element_count() const;

  // Resolves an axis in [-R,R) to the real axis value and verifies the range.
  StatusOr<int> ResolveAxis(int axis) const;

  // Compares two shapes for equality: same rank and identical dimension
  // values. Only the first rank_ entries of each shape take part.
  static bool Equal(const Shape& a, const Shape& b) {
    return a.rank_ == b.rank_ &&
           std::memcmp(a.value_, b.value_, a.rank_ * sizeof(value_[0])) == 0;
  }

  // Accesses dimension `i`, which must be in [0, size()). Bounds are checked
  // in debug builds only.
  int& operator[](size_type i) noexcept {
    DCHECK_GE(i, 0);
    DCHECK_LT(i, rank_);
    return value_[i];
  }
  const int& operator[](size_type i) const noexcept {
    DCHECK_GE(i, 0);
    DCHECK_LT(i, rank_);
    return value_[i];
  }

  // Returns the first / last dimension by value (unlike std::vector, which
  // returns references). The shape must not be empty; this is checked in
  // debug builds only.
  int front() const noexcept {
    DCHECK_GE(rank_, 1);
    return value_[0];
  }
  int back() const noexcept {
    DCHECK_GE(rank_, 1);
    return value_[rank_ - 1];
  }

  // NOTE(review): begin()/end() are const yet return mutable `int*`, so a
  // const Shape can be modified through them. This is kept for compatibility
  // with existing callers; prefer cbegin()/cend() for read-only iteration.
  constexpr iterator begin() const noexcept {
    return const_cast<iterator>(value_);
  }
  // Uses pointer arithmetic rather than &value_[rank_] so that no element is
  // subscripted one past the end of value_ when rank_ == kMaxRank.
  constexpr iterator end() const noexcept {
    return const_cast<iterator>(value_ + rank_);
  }
  constexpr const_iterator cbegin() const noexcept { return value_; }
  constexpr const_iterator cend() const noexcept { return value_ + rank_; }

  // Returns a read-only view of up to `len` dimensions starting at `pos`.
  // Defined out of line; presumably `npos` means "through the end", as in
  // absl::Span::subspan -- verify there.
  absl::Span<const int> subspan(size_type pos = 0, size_type len = npos) const;
  // Returns a read-only view of all dimensions (equivalent to subspan()).
  // Unlike std::vector::data(), this returns a Span rather than a pointer.
  absl::Span<const int> data() const { return subspan(); }

  // std::vector-style mutators, defined out of line. They presumably enforce
  // the kMaxRank capacity -- verify against the definitions.
  void push_back(int dim);
  void insert(iterator pos, int dim);
  void erase(iterator pos);
  // Removes all dimensions, leaving a rank-0 (scalar) shape.
  void clear() { rank_ = 0; }

 private:
  // Number of valid leading entries in value_.
  size_type rank_ = 0;
  // Dimension storage; entries at index >= rank_ are unused. Zero-initialized
  // so default-constructed (and subsequently copied) shapes never carry
  // indeterminate values.
  int value_[kMaxRank] = {};
};
// Two shapes are equal when they have the same rank and the same dimension
// values (see Shape::Equal).
inline bool operator==(const Shape& a, const Shape& b) {
  return Shape::Equal(a, b);
}
inline bool operator!=(const Shape& a, const Shape& b) { return !(a == b); }

// Writes the shape's DebugString() representation to `stream` and returns it
// for chaining.
inline std::ostream& operator<<(std::ostream& stream, const Shape& shape) {
  return stream << shape.DebugString();
}
} // namespace iree
#endif // IREE_BASE_SHAPE_H_