blob: 0cd3df1d8b51a768490162456eec3149f200d897 [file]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_COMPILER_UTILS_PERMUTATION_H_
#define IREE_COMPILER_UTILS_PERMUTATION_H_
#include <cassert>
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>

#include "llvm/ADT/ADL.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Support/LLVM.h"
namespace mlir::iree_compiler {
// Writes `values[permutation[k]]` to the output for every entry `k` of
// `permutation`, in order.
// Example: values = (1, 2, 3), permutation = (2, 1, 0)
// output = (3, 2, 1).
//
// `valuesBegin` is indexed with operator[], so `ValuesIt` must support random
// access. The permutation may not be longer than [valuesBegin, valuesEnd);
// this is only checked in builds with assertions enabled.
// TODO: make applyPermutation at mlir/Dialect/Utils/IndexingUtils.h in MLIR
// generic and use it instead.
template <typename ValuesIt, typename PermutationRange, typename OutIt>
void permute(ValuesIt valuesBegin, ValuesIt valuesEnd,
             PermutationRange &&permutation, OutIt outBegin) {
  // std::distance is signed while adl_size is unsigned. Reject a negative
  // distance explicitly before casting, otherwise a reversed iterator pair
  // would wrap around to a huge value and slip past the size check.
  [[maybe_unused]] const auto numValues =
      std::distance(valuesBegin, valuesEnd);
  assert(numValues >= 0 &&
         static_cast<size_t>(numValues) >= llvm::adl_size(permutation) &&
         "permutation must not be longer than the value range");
  llvm::transform(permutation, outBegin,
                  [valuesBegin](auto index) { return valuesBegin[index]; });
}
// Range-based convenience overload: resolves the begin/end iterators of
// `values` and delegates to the iterator-pair overload of permute.
template <typename ValuesRange, typename PermutationRange, typename OutIt>
void permute(ValuesRange &&values, PermutationRange &&permutation,
             OutIt outBegin) {
  auto first = llvm::adl_begin(std::forward<ValuesRange>(values));
  auto last = llvm::adl_end(std::forward<ValuesRange>(values));
  permute(first, last, permutation, outBegin);
}
// Returns a new vector whose k-th element is `values[permutation[k]]`.
template <typename T, typename Index>
SmallVector<T> permute(ArrayRef<T> values, ArrayRef<Index> permutation) {
  SmallVector<T> res;
  // Exactly one element is appended per permutation entry, so size the
  // storage once up front instead of growing while inserting.
  res.reserve(permutation.size());
  permute(values, permutation, std::back_inserter(res));
  return res;
}
// Check if the range is a sequence of numbers starting from 0 and increasing
// by one at every step. An empty range is considered an identity permutation.
// Example: (0, 1, 2, 3).
// TODO: Make the isIdentityPermutation in MLIR more generic to not only
// accept int64_t and delete this.
template <typename Range>
bool isIdentityPermutation(Range &&range) {
  using ValueType = std::decay_t<decltype(*std::begin(range))>;
  ValueType expected = static_cast<ValueType>(0);
  // Walk the range explicitly rather than through an algorithm predicate that
  // mutates captured state: the check depends on visiting the elements
  // strictly in order, which a plain loop guarantees.
  for (ValueType value : range) {
    if (!(value == expected)) {
      return false;
    }
    ++expected;
  }
  return true;
}
// Make a permutation that moves src to dst: the output lists the indices
// 0 .. size - 1 with `src` taken out of its own slot and placed at output
// position `dst`; all other indices keep their relative order.
// Example with size = 5, src = 1, dst = 3.
// output = (0, 2, 3, 1, 4).
// Example with size = 2, src = 0, dst = 1.
// output = (1, 0).
// The result is written through `outBegin`, one assignment per element.
template <typename T, typename OutIt>
void makeMovePermutation(T size, T src, T dst, OutIt outBegin) {
  assert(src < size && dst < size && size > static_cast<T>(0));
  T numWritten = 0;
  for (T index = 0; index < size; ++index) {
    // `src` is never emitted at its own slot; it is placed separately below.
    if (index == src) {
      continue;
    }
    // The output position has reached `dst`: this is where `src` goes.
    if (numWritten == dst) {
      *outBegin = src;
      ++outBegin;
      ++numWritten;
    }
    *outBegin = index;
    ++outBegin;
    ++numWritten;
  }
  // `src` has not been placed yet (its destination is the trailing slot).
  if (numWritten != size) {
    *outBegin = src;
    ++outBegin;
  }
}
// Convenience overload that builds the move permutation for (size, src, dst)
// and returns it as a vector instead of writing through an output iterator.
template <typename T>
SmallVector<T> makeMovePermutation(T size, T src, T dst) {
  SmallVector<T> result;
  result.reserve(size);
  auto sink = std::back_inserter(result);
  makeMovePermutation(size, src, dst, sink);
  return result;
}
} // namespace mlir::iree_compiler
#endif // IREE_COMPILER_UTILS_PERMUTATION_H_