/* mbed Microcontroller Library
 * Copyright (c) 2019 ARM Limited
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MSTD_ALGORITHM_
#define MSTD_ALGORITHM_

/* <mstd_algorithm>
 *
 * - provides <algorithm>
 * - For ARM C 5, standard C++11/14 features:
 *   - mstd::min, mstd::max for std::initializer_list
 *   - std::all_of, std::any_of, std::none_of
 *   - std::find_if_not
 *   - std::equal (2-range forms)
 *   - std::copy_n, std::copy_if, std::move, std::move_backward
 *   - mstd::min, mstd::max, mstd::minmax constexpr replacements
 */

#include <algorithm>

#ifdef __CC_ARM
#include <mstd_utility>
#endif

namespace mstd {
#ifdef __CC_ARM
// Really want basic min/max to be constexpr as per C++14
/** constexpr-capable two-argument max.
 *
 * @param a  first value
 * @param b  second value
 * @return   reference to b when `a < b` is true, otherwise reference to a
 */
template <typename T>
constexpr const T &max(const T &a, const T &b)
{
    // Single return expression, so the body stays a valid C++11 constexpr function.
    return (a < b) ? b : a;
}

/** constexpr-capable two-argument max using a caller-supplied ordering.
 *
 * @param a     first value
 * @param b     second value
 * @param comp  callable invoked as comp(a, b)
 * @return      reference to b when comp(a, b) is true, otherwise reference to a
 */
template <typename T, class Compare>
constexpr const T &max(const T &a, const T &b, Compare comp)
{
    // Single return expression, so the body stays a valid C++11 constexpr function.
    return (comp(a, b)) ? b : a;
}

/** constexpr-capable two-argument min.
 *
 * @param a  first value
 * @param b  second value
 * @return   reference to b when `b < a` is true, otherwise reference to a
 */
template <typename T>
constexpr const T &min(const T &a, const T &b)
{
    // Single return expression, so the body stays a valid C++11 constexpr function.
    return (b < a) ? b : a;
}

/** constexpr-capable two-argument min using a caller-supplied ordering.
 *
 * @param a     first value
 * @param b     second value
 * @param comp  callable invoked as comp(b, a)
 * @return      reference to b when comp(b, a) is true, otherwise reference to a
 */
template <typename T, class Compare>
constexpr const T &min(const T &a, const T &b, Compare comp)
{
    // Single return expression, so the body stays a valid C++11 constexpr function.
    return comp(b, a) ? b : a;
}

// Maybe sort out C++14 constexpr for these later - these are C++11 at least
/** Largest element of an initializer list, compared with operator<.
 *
 * The list type is spelled std::initializer_list so this overload does not
 * depend on a using-declaration having already placed the name in this
 * namespace.
 *
 * @param il  list of candidates; must not be empty, because the iterator
 *            returned by std::max_element is dereferenced unconditionally
 * @return    copy of the element selected by std::max_element
 */
template <typename T>
T max(std::initializer_list<T> il)
{
    return *std::max_element(il.begin(), il.end());
}

/** Largest element of an initializer list under a caller-supplied ordering.
 *
 * The list type is spelled std::initializer_list so this overload does not
 * depend on a using-declaration having already placed the name in this
 * namespace.
 *
 * @param il    list of candidates; must not be empty, because the iterator
 *              returned by std::max_element is dereferenced unconditionally
 * @param comp  comparison passed through to std::max_element
 * @return      copy of the element selected by std::max_element
 */
template <typename T, class Compare>
T max(std::initializer_list<T> il, Compare comp)
{
    return *std::max_element(il.begin(), il.end(), comp);
}

/** Smallest element of an initializer list, compared with operator<.
 *
 * The list type is spelled std::initializer_list so this overload does not
 * depend on a using-declaration having already placed the name in this
 * namespace.
 *
 * @param il  list of candidates; must not be empty, because the iterator
 *            returned by std::min_element is dereferenced unconditionally
 * @return    copy of the element selected by std::min_element
 */
template <typename T>
T min(std::initializer_list<T> il)
{
    return *std::min_element(il.begin(), il.end());
}

/** Smallest element of an initializer list under a caller-supplied ordering.
 *
 * The list type is spelled std::initializer_list so this overload does not
 * depend on a using-declaration having already placed the name in this
 * namespace.
 *
 * @param il    list of candidates; must not be empty, because the iterator
 *              returned by std::min_element is dereferenced unconditionally
 * @param comp  comparison passed through to std::min_element
 * @return      copy of the element selected by std::min_element
 */
template <typename T, class Compare>
T min(std::initializer_list<T> il, Compare comp)
{
    return *std::min_element(il.begin(), il.end(), comp);
}

// Forward declaration only; the definition of pair is not in this file.
template <typename T1, typename T2>
class pair;

/** Order two values and return references to both.
 *
 * @param a  first value
 * @param b  second value
 * @return   pair(b, a) when `b < a` is true, otherwise pair(a, b); the
 *           members are references to the arguments, not copies
 */
template <typename T>
constexpr pair<const T &, const T &> minmax(const T &a, const T &b)
{
    // One return statement only: C++11 does not allow if/else inside a
    // constexpr function body.
    return b < a ? pair<const T &, const T &>(b, a)
                 : pair<const T &, const T &>(a, b);
}

/** Order two values with a caller-supplied comparison and return references to both.
 *
 * @param a     first value
 * @param b     second value
 * @param comp  callable invoked as comp(b, a)
 * @return      pair(b, a) when comp(b, a) is true, otherwise pair(a, b); the
 *              members are references to the arguments, not copies
 */
template <typename T, class Compare>
constexpr pair<const T &, const T &> minmax(const T &a, const T &b, Compare comp)
{
    // One return statement only: C++11 does not allow if/else inside a
    // constexpr function body.
    return comp(b, a) ? pair<const T &, const T &>(b, a)
                      : pair<const T &, const T &>(a, b);
}
#else
// Toolchains other than ARM C 5: re-export the standard library versions
// instead of defining replacements.
using std::minmax;
using std::max;
using std::min;
#endif
}

#ifdef __CC_ARM

namespace std {
// [alg.all_of]
/** Test whether a predicate holds for every element of [first, last).
 *
 * @return false as soon as pred rejects an element; true if the range is
 *         exhausted (including when it is empty)
 */
template <class InputIterator, class Predicate>
bool all_of(InputIterator first, InputIterator last, Predicate pred)
{
    while (first != last) {
        if (!pred(*first)) {
            return false;
        }
        ++first;
    }
    return true;
}

// [alg.any_of]
/** Test whether a predicate holds for at least one element of [first, last).
 *
 * @return true as soon as pred accepts an element; false if the range is
 *         exhausted (including when it is empty)
 */
template <class InputIterator, class Predicate>
bool any_of(InputIterator first, InputIterator last, Predicate pred)
{
    while (first != last) {
        if (pred(*first)) {
            return true;
        }
        ++first;
    }
    return false;
}

// [alg.none_of]
/** Test whether a predicate holds for no element of [first, last).
 *
 * @return false as soon as pred accepts an element; true if the range is
 *         exhausted (including when it is empty)
 */
template <class InputIterator, class Predicate>
bool none_of(InputIterator first, InputIterator last, Predicate pred)
{
    while (first != last) {
        if (pred(*first)) {
            return false;
        }
        ++first;
    }
    return true;
}

// [alg.find]
/** Locate the first element of [first, last) that the predicate rejects.
 *
 * @return iterator to the first element for which `!pred(*it)` is true, or an
 *         iterator equal to last if there is none
 */
template<class InputIterator, class Predicate>
InputIterator find_if_not(InputIterator first, InputIterator last, Predicate pred)
{
    while (first != last) {
        if (!pred(*first)) {
            break;
        }
        ++first;
    }
    // Either stopped on the rejecting element or ran all the way to last.
    return first;
}

// [alg.equal]
namespace impl {
/** Two-range equality for a pair of random-access ranges.
 *
 * Compares the two lengths first, so ranges of different size are rejected
 * without comparing any elements; otherwise defers to the single-range
 * std::equal.
 *
 * The two unnamed tag parameters exist only to select this overload.
 *
 * @return true if both ranges have the same length and std::equal reports
 *         the elements equal
 */
template<class RandomAccessIterator1, class RandomAccessIterator2>
bool equal(RandomAccessIterator1 first1, RandomAccessIterator1 last1,
           RandomAccessIterator2 first2, RandomAccessIterator2 last2,
           std::random_access_iterator_tag,
           std::random_access_iterator_tag)
{
    if (last1 - first1 != last2 - first2) {
        return false;
    }
    // Must be qualified: an unqualified `equal` here resolves to the overloads
    // in this helper namespace, and argument-dependent lookup does not add
    // std::equal when the iterators are raw pointers.
    return std::equal(first1, last1, first2);
}

/** Two-range equality for a pair of random-access ranges, with a predicate.
 *
 * Compares the two lengths first, so ranges of different size are rejected
 * without invoking pred; otherwise defers to the single-range std::equal.
 *
 * The two unnamed tag parameters exist only to select this overload.
 *
 * @param pred  binary predicate forwarded to std::equal
 * @return      true if both ranges have the same length and std::equal
 *              reports the elements equal under pred
 */
template<class RandomAccessIterator1, class RandomAccessIterator2, class BinaryPredicate>
bool equal(RandomAccessIterator1 first1, RandomAccessIterator1 last1,
           RandomAccessIterator2 first2, RandomAccessIterator2 last2,
           BinaryPredicate pred,
           std::random_access_iterator_tag,
           std::random_access_iterator_tag)
{
    if (last1 - first1 != last2 - first2) {
        return false;
    }
    // Must be qualified: an unqualified `equal` here resolves to the overloads
    // in this helper namespace, and argument-dependent lookup does not add
    // std::equal when the iterators are raw pointers.
    return std::equal(first1, last1, first2, pred);
}

/** Two-range equality for ranges whose lengths cannot be computed up front.
 *
 * Walks both ranges in lock-step, comparing with operator==. The two unnamed
 * tag parameters exist only to select this overload.
 *
 * @return false on the first mismatching pair; otherwise true only if both
 *         ranges were exhausted at the same time
 */
template<class InputIterator1, class InputIterator2>
bool equal(InputIterator1 first1, InputIterator1 last1,
           InputIterator2 first2, InputIterator2 last2,
           input_iterator_tag,
           input_iterator_tag)
{
    while (first1 != last1 && first2 != last2) {
        if (!(*first1 == *first2)) {
            return false;
        }
        ++first1;
        ++first2;
    }
    // A leftover element in either range means the lengths differed.
    return first1 == last1 && first2 == last2;
}

/** Two-range equality with a predicate, for ranges whose lengths cannot be
 * computed up front.
 *
 * Walks both ranges in lock-step, comparing with pred. The two unnamed tag
 * parameters exist only to select this overload.
 *
 * @param pred  binary predicate invoked as pred(*first1, *first2)
 * @return      false on the first pair pred rejects; otherwise true only if
 *              both ranges were exhausted at the same time
 */
template<class InputIterator1, class InputIterator2, class BinaryPredicate>
bool equal(InputIterator1 first1, InputIterator1 last1,
           InputIterator2 first2, InputIterator2 last2,
           BinaryPredicate pred,
           input_iterator_tag,
           input_iterator_tag)
{
    while (first1 != last1 && first2 != last2) {
        if (!pred(*first1, *first2)) {
            return false;
        }
        ++first1;
        ++first2;
    }
    // A leftover element in either range means the lengths differed.
    return first1 == last1 && first2 == last2;
}
}

template<class InputIterator1, class InputIterator2>
bool equal(InputIterator1 first1, InputIterator1 last1,
           InputIterator2 first2, InputIterator2 last2)
{
    return impl::equal(first1, last1, first2, last2,
                       typename iterator_traits<InputIterator1>::iterator_category(),
                       typename iterator_traits<InputIterator2>::iterator_category());
}

template<class InputIterator1, class InputIterator2, class BinaryPredicate>
bool equal(InputIterator1 first1, InputIterator1 last1,
           InputIterator2 first2, InputIterator2 last2,
           BinaryPredicate pred)
{
    return impl::equal(first1, last1, first2, last2, pred,
                       typename iterator_traits<InputIterator1>::iterator_category(),
                       typename iterator_traits<InputIterator2>::iterator_category());

}

// [alg.copy]

namespace impl
{
/** copy_n for random-access sources: forwards to std::copy over
 * [first, first + n).
 *
 * The unnamed tag parameter exists only to select this overload.
 *
 * @return the iterator returned by std::copy
 */
template<class RandomAccessIterator, class Size, class OutputIterator>
OutputIterator copy_n(RandomAccessIterator first, Size n, OutputIterator result, random_access_iterator_tag)
{
    // No memcpy-style fast path here (the original author noted one is
    // presumably wanted); the work is left entirely to std::copy.
    const RandomAccessIterator source_end = first + n;
    return std::copy(first, source_end, result);
}

/** copy_n for sources that can only be stepped: copies one element per
 * iteration while a counter of type Size is less than n.
 *
 * The unnamed tag parameter exists only to select this overload.
 *
 * @return result advanced past the last element written
 */
template<class InputIterator, class Size, class OutputIterator>
OutputIterator copy_n(InputIterator first, Size n, OutputIterator result, input_iterator_tag)
{
    Size copied = 0;
    while (copied < n) {
        *result++ = *first++;
        ++copied;
    }
    return result;
}
}

template<class InputIterator, class Size, class OutputIterator>
OutputIterator copy_n(InputIterator first, Size n, OutputIterator result)
{
    return impl::copy_n(first, n, result,
                        typename iterator_traits<InputIterator>::iterator_category());
}

/** std::copy_if: copy the elements of [first, last) that pred accepts.
 *
 * Elements are written to result in the order they are visited.
 *
 * @return result advanced past the last element written
 */
template<class InputIterator, class OutputIterator, class Predicate>
OutputIterator copy_if(InputIterator first, InputIterator last, OutputIterator result, Predicate pred)
{
    while (first != last) {
        if (pred(*first)) {
            *result++ = *first;
        }
        ++first;
    }
    return result;
}

// [alg.move]

/** Range form of std::move: move-assign each element of [first, last) into
 * the range starting at result, front to back.
 *
 * @return result advanced past the last element written
 */
template<class InputIterator, class OutputIterator>
OutputIterator move(InputIterator first, InputIterator last, OutputIterator result)
{
    for (; first != last; ++first) {
        *result = std::move(*first);
        ++result;
    }
    return result;
}

/** std::move_backward: move-assign the elements of [first, last) into the
 * range ending at result, back to front.
 *
 * @return iterator to the last element written, i.e. the start of the
 *         destination range
 */
template<class BidirectionalIterator1, class BidirectionalIterator2>
BidirectionalIterator2 move_backward(BidirectionalIterator1 first, BidirectionalIterator1 last, BidirectionalIterator2 result)
{
    while (last != first) {
        --last;
        --result;
        *result = std::move(*last);
    }
    return result;
}

}

#endif // __CC_ARM

// Re-export the standard algorithms so callers can write mstd::name for all of them.
namespace mstd {

using std::initializer_list;

// Non-modifying sequence operations
using std::all_of;
using std::any_of;
using std::none_of;
using std::for_each;
using std::find;
using std::find_if;
using std::find_if_not;
using std::find_end;
using std::find_first_of;
using std::adjacent_find;
using std::count;
using std::count_if;
using std::mismatch;
using std::equal;
using std::search;
using std::search_n;

// Mutating sequence operations
using std::copy;
using std::copy_n;
using std::copy_if;
using std::move;
using std::move_backward;
using std::swap_ranges;
using std::iter_swap;
using std::transform;
using std::replace;
using std::replace_if;
using std::replace_copy;
using std::replace_copy_if;
using std::fill;
using std::fill_n;
using std::generate;
using std::generate_n;
using std::remove;
using std::remove_if;
using std::remove_copy;
using std::remove_copy_if;
using std::unique;
using std::unique_copy;
using std::reverse;
using std::reverse_copy;
using std::rotate;
using std::rotate_copy;
using std::partition;
using std::stable_partition;

// Sorting and related operations
using std::sort;
using std::stable_sort;
using std::partial_sort;
using std::partial_sort_copy;
using std::nth_element;
using std::lower_bound;
using std::upper_bound;
using std::equal_range;
using std::binary_search;
using std::merge;
using std::inplace_merge;
using std::includes;
using std::set_union;
using std::set_intersection;
using std::set_difference;
using std::set_symmetric_difference;
using std::push_heap;
using std::pop_heap;
using std::make_heap;
using std::sort_heap;
using std::min_element;
using std::max_element;
using std::lexicographical_compare;
using std::next_permutation;
using std::prev_permutation;

} // namespace mstd

#endif // MSTD_ALGORITHM_
