blob: 0bd88d326ab27bd4678ab737bd56e89bf6d6c9b8 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/base/shape.h"
#include "iree/base/status.h"
#include "iree/base/status_matchers.h"
#include "iree/testing/gtest.h"
namespace iree {
namespace {
using ::testing::ElementsAre;
// Tests that a default-constructed Shape behaves as a 0-D (scalar) shape: it
// has no dimensions, yet still describes exactly one element.
TEST(ShapeTest, Scalar) {
  Shape scalar_shape;

  EXPECT_TRUE(scalar_shape.empty());
  EXPECT_EQ(0, scalar_shape.size());
  EXPECT_EQ(0, scalar_shape.subspan().size());
  // The product over zero dimensions is one element.
  EXPECT_EQ(1, scalar_shape.element_count());
  EXPECT_EQ(scalar_shape, scalar_shape);

  // Walking the dimensions must visit nothing.
  for (auto it = scalar_shape.begin(); it != scalar_shape.end(); ++it) {
    FAIL() << "Should have no dimensions, have: " << *it;
  }
  // Mutable and const iterator ranges must both be empty.
  EXPECT_EQ(scalar_shape.begin(), scalar_shape.end());
  EXPECT_EQ(scalar_shape.cbegin(), scalar_shape.cend());

  // Clearing a shape that is already empty leaves it empty.
  scalar_shape.clear();
  EXPECT_EQ(0, scalar_shape.size());
}
// Tests every supported way of constructing a Shape: default, initializer
// list, (pointer, count), iterator range, and span.
TEST(ShapeTest, NonScalarConstruction) {
  // Default construction and an empty initializer list both give rank 0.
  EXPECT_EQ(0, Shape().size());
  EXPECT_EQ(0, Shape({}).size());
  EXPECT_THAT(Shape({}).subspan(), ElementsAre());

  // Initializer lists set both the rank and the dimension values, in order.
  EXPECT_EQ(1, Shape({10}).size());
  EXPECT_THAT(Shape({10}).subspan(), ElementsAre(10));
  EXPECT_EQ(4, Shape({10, 20, 30, 40}).size());
  EXPECT_THAT(Shape({10, 20, 30, 40}).subspan(), ElementsAre(10, 20, 30, 40));

  // An empty external buffer yields rank 0 through each buffer overload.
  std::vector<int> no_dims;
  EXPECT_EQ(0, Shape(no_dims.data(), no_dims.size()).size());
  EXPECT_EQ(0, Shape(no_dims.begin(), no_dims.end()).size());
  EXPECT_EQ(0, Shape(absl::MakeConstSpan(no_dims)).size());

  // A populated external buffer reproduces its dimensions in order.
  std::vector<int> four_dims = {10, 20, 30, 40};
  EXPECT_THAT(Shape(four_dims.begin(), four_dims.end()).subspan(),
              ElementsAre(10, 20, 30, 40));
  EXPECT_THAT(Shape(absl::MakeConstSpan(four_dims)).subspan(),
              ElementsAre(10, 20, 30, 40));
}
// Tests read access on a rank-4 shape: size and element count, comparison,
// iteration, indexing, and front()/back().
TEST(ShapeTest, NonScalarAccess) {
  Shape dims = {1, 2, 3, 4};

  EXPECT_FALSE(dims.empty());
  EXPECT_EQ(4, dims.size());
  EXPECT_EQ(1 * 2 * 3 * 4, dims.element_count());

  // A shape equals itself; the same values in reverse order must not.
  EXPECT_EQ(dims, dims);
  EXPECT_NE(dims, Shape({4, 3, 2, 1}));

  // The span view and range iteration both yield dimensions in order.
  EXPECT_THAT(dims.subspan(), ElementsAre(1, 2, 3, 4));
  std::vector<int> visited;
  for (int value : dims) {
    visited.push_back(value);
  }
  EXPECT_THAT(visited, ElementsAre(1, 2, 3, 4));

  // operator[] is zero-based; in this shape dimension i holds i + 1.
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(i + 1, dims[i]);
  }
  EXPECT_EQ(1, dims.front());
  EXPECT_EQ(4, dims.back());
}
// Tests that push_back() appends a dimension after the existing ones,
// increasing the rank by one each time.
TEST(ShapeTest, PushBack) {
  Shape grown;
  EXPECT_EQ(0, grown.size());

  // First append: the lone dimension is both the front and the back.
  grown.push_back(10);
  EXPECT_EQ(1, grown.size());
  EXPECT_EQ(10, grown[0]);
  EXPECT_EQ(10, grown.front());
  EXPECT_EQ(10, grown.back());
  EXPECT_THAT(grown.subspan(), ElementsAre(10));

  // Second append lands after the first, which is left in place.
  grown.push_back(20);
  EXPECT_EQ(2, grown.size());
  EXPECT_EQ(10, grown[0]);
  EXPECT_EQ(20, grown[1]);
  EXPECT_EQ(10, grown.front());
  EXPECT_EQ(20, grown.back());
  EXPECT_THAT(grown.subspan(), ElementsAre(10, 20));
}
// Tests insert() at the front, back, and middle, first on a shape grown from
// empty and then on one built from an initializer list.
TEST(ShapeTest, Insert) {
  Shape built;
  EXPECT_EQ(0, built.size());

  // On an empty shape begin() == end(), so this becomes the only dimension.
  built.insert(built.begin(), 20);
  EXPECT_THAT(built.subspan(), ElementsAre(20));

  // Prepend before the current first dimension.
  built.insert(built.begin(), 10);
  EXPECT_THAT(built.subspan(), ElementsAre(10, 20));

  // Inserting at end() appends.
  built.insert(built.end(), 40);
  EXPECT_THAT(built.subspan(), ElementsAre(10, 20, 40));

  // Insert before the dimension currently at index 2 (the 40).
  built.insert(built.begin() + 2, 30);
  EXPECT_THAT(built.subspan(), ElementsAre(10, 20, 30, 40));

  // Prepending to a pre-populated shape shifts existing dimensions right.
  Shape prefixed{72, 4};
  prefixed.insert(prefixed.begin(), 144);
  EXPECT_THAT(prefixed.subspan(), ElementsAre(144, 72, 4));
}
// Tests erase() at the front, back, and middle of a shape until it is empty.
// As with standard containers, end() is one past the last dimension and is
// not a valid erase position; the last dimension is addressed as end() - 1.
TEST(ShapeTest, Erase) {
  Shape shape = {1, 2, 3, 4};
  EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4));

  // Remove the first dimension.
  shape.erase(shape.begin());
  EXPECT_THAT(shape.subspan(), ElementsAre(2, 3, 4));

  // Remove the last dimension (erasing end() itself would be out of range).
  shape.erase(shape.end() - 1);
  EXPECT_THAT(shape.subspan(), ElementsAre(2, 3));

  // Remove the dimension at index 1.
  shape.erase(shape.begin() + 1);
  EXPECT_THAT(shape.subspan(), ElementsAre(2));

  // Remove the only remaining dimension, leaving a rank-0 shape.
  shape.erase(shape.end() - 1);
  EXPECT_THAT(shape.subspan(), ElementsAre());
}
// Tests that clear() drops every dimension whatever the starting rank,
// including on an already-empty shape and after reassignment.
TEST(ShapeTest, Clear) {
  Shape dims;
  // EXPECT_EQ is non-fatal, so this check can be repeated after each mutation.
  auto expect_rank0 = [&dims]() { EXPECT_EQ(0, dims.size()); };

  expect_rank0();
  dims.clear();
  expect_rank0();

  dims = Shape({1});
  dims.clear();
  expect_rank0();

  dims = Shape({1, 2, 3, 4});
  dims.clear();
  expect_rank0();
}
// Tests the DebugString() format: dimensions inside square brackets,
// separated by commas with no spaces.
TEST(ShapeTest, DebugString) {
  Shape rank0({});
  Shape rank1({1});
  Shape rank2({1, 2});

  EXPECT_EQ("[]", rank0.DebugString());
  EXPECT_EQ("[1]", rank1.DebugString());
  EXPECT_EQ("[1,2]", rank2.DebugString());
}
// Tests that element_count() is the product of all dimensions.
TEST(ShapeTest, ElementCount) {
  // Rank 0: the empty product is one element.
  EXPECT_EQ(1, Shape({}).element_count());

  // Rank 1: the count is the lone dimension, including zero.
  EXPECT_EQ(0, Shape({0}).element_count());
  EXPECT_EQ(1, Shape({1}).element_count());

  // Higher ranks: unit dimensions do not change the product, wherever they
  // appear.
  EXPECT_EQ(2 * 1, Shape({2, 1}).element_count());
  EXPECT_EQ(2 * 5, Shape({2, 5}).element_count());
  EXPECT_EQ(72 * 1 * 128, Shape({72, 1, 128}).element_count());
  EXPECT_EQ(1 * 72 * 128, Shape({1, 72, 128}).element_count());

  // A partially specified shape (one dimension is -1) reports zero elements
  // rather than a negative product.
  EXPECT_EQ(0, Shape({1, -1, 2, 3}).element_count());
}
// Tests that a non-negative axis below the rank resolves to itself, and that
// an axis equal to the rank is rejected as an invalid argument.
TEST(ShapeTest, ResolveAxis) {
  int resolved = -1;

  Shape rank1({0});
  ASSERT_OK_AND_ASSIGN(resolved, rank1.ResolveAxis(0));
  EXPECT_EQ(0, resolved);

  Shape rank3({0, 1, 2});
  ASSERT_OK_AND_ASSIGN(resolved, rank3.ResolveAxis(1));
  EXPECT_EQ(1, resolved);
  ASSERT_OK_AND_ASSIGN(resolved, rank3.ResolveAxis(2));
  EXPECT_EQ(2, resolved);

  // One past the last axis is out of range.
  EXPECT_TRUE(IsInvalidArgument(rank3.ResolveAxis(3).status()));
}
// Tests that negative axes count back from the end: -rank maps to axis 0,
// -1 maps to the last axis, and anything below -rank is rejected.
TEST(ShapeTest, ResolveAxisNegative) {
  Shape rank3({0, 1, 2});
  int resolved = -1;

  ASSERT_OK_AND_ASSIGN(resolved, rank3.ResolveAxis(-3));
  EXPECT_EQ(0, resolved);
  ASSERT_OK_AND_ASSIGN(resolved, rank3.ResolveAxis(-2));
  EXPECT_EQ(1, resolved);
  ASSERT_OK_AND_ASSIGN(resolved, rank3.ResolveAxis(-1));
  EXPECT_EQ(2, resolved);

  // -(rank + 1) is one step too far back.
  EXPECT_TRUE(IsInvalidArgument(rank3.ResolveAxis(-4).status()));
}
// Tests the rank-0 case: both axis 0 and axis -1 resolve to 0, while axis 1
// is rejected as an invalid argument.
TEST(ShapeTest, ResolveAxisScalar) {
  Shape scalar({});
  int resolved = -1;

  ASSERT_OK_AND_ASSIGN(resolved, scalar.ResolveAxis(0));
  EXPECT_EQ(0, resolved);
  ASSERT_OK_AND_ASSIGN(resolved, scalar.ResolveAxis(-1));
  EXPECT_EQ(0, resolved);

  EXPECT_TRUE(IsInvalidArgument(scalar.ResolveAxis(1).status()));
}
// Tests operator== and operator!= across identical shapes, rank mismatches,
// and same-rank value mismatches.
TEST(ShapeTest, Equality) {
  // Independently constructed shapes with identical contents compare equal
  // at every rank, including rank 0.
  EXPECT_EQ(Shape({}), Shape({}));
  EXPECT_EQ(Shape({0}), Shape({0}));
  EXPECT_EQ(Shape({1}), Shape({1}));
  EXPECT_EQ(Shape({1, 2}), Shape({1, 2}));

  // Shapes of different rank differ, in either argument order.
  EXPECT_NE(Shape({}), Shape({1}));
  EXPECT_NE(Shape({1}), Shape({}));

  // Same rank but different values differ; a -1 dimension is not treated as
  // matching anything.
  EXPECT_NE(Shape({-1}), Shape({1}));
  EXPECT_NE(Shape({1}), Shape({2}));
  EXPECT_NE(Shape({1, 2}), Shape({3, 4}));
}
} // namespace
} // namespace iree