/* mbed Microcontroller Library
 * Copyright (c) 2019 ARM Limited
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MSTD_MEMORY_
#define MSTD_MEMORY_

/* <mstd_memory>
 *
 * - includes toolchain's <memory>
 * - For ARM C 5, C++11/14 features:
 *   - std::align
 *   - std::addressof
 *   - std::uninitialized_copy_n
 *   - std::unique_ptr, std::make_unique, std::default_delete
 * - For all toolchains, C++17 backports:
 *   - mstd::uninitialized_default_construct, mstd::uninitialized_default_construct_n
 *   - mstd::uninitialized_value_construct, mstd::uninitialized_value_construct_n
 *   - mstd::uninitialized_move, mstd::uninitialized_move_n
 *   - mstd::destroy_at, mstd::destroy, mstd::destroy_n
 */

#include <memory>

#include <mstd_type_traits>
#include <mstd_utility> // std::pair
#include <mstd_iterator> // std::iterator_traits

#ifdef __CC_ARM

#include <cstddef> // size_t, ptrdiff_t
#include <_move.h> // exchange

namespace std
{
// [ptr.align]
/**
 * Align a pointer inside a buffer ([ptr.align]).
 *
 * Rounds `ptr` up to the next multiple of `alignment` and checks that `size`
 * bytes still fit into the `space` bytes that were available at `ptr`.
 *
 * @param alignment  Required alignment. Behaviour is undefined if it is not a
 *                   power of 2.
 * @param size       Number of bytes that must fit at the aligned address.
 * @param ptr        In: start of the buffer. Out: the aligned address on
 *                   success; left untouched on failure.
 * @param space      In: bytes available at `ptr`. Out: reduced by the padding
 *                   that was skipped on success; left untouched on failure.
 * @return The aligned pointer, or nullptr if `size` bytes do not fit.
 */
inline void *align(size_t alignment, size_t size, void *&ptr, size_t &space) noexcept
{
    // Reject early so that "space - size" below cannot wrap around.
    if (size > space) {
        return nullptr;
    }

    const uintptr_t addr = reinterpret_cast<uintptr_t>(ptr);
    const uintptr_t new_addr = (addr + (alignment - 1)) & ~(alignment - 1);
    const uintptr_t pad = new_addr - addr;

    // Compare the padding against the remaining room rather than computing
    // "pad + size": that sum can overflow for a huge size and report success.
    if (pad > space - size) {
        return nullptr;
    }

    space -= pad;
    ptr = reinterpret_cast<void *>(new_addr);
    return ptr;
}

// [specialized.addressof]
/**
 * Obtain the actual address of an object, even when its class overloads the
 * unary operator& ([specialized.addressof]).
 *
 * @param arg  Object whose address is wanted.
 * @return Pointer to `arg`.
 */
template <typename T>
T *addressof(T &arg) noexcept
{
    // Viewed as a char, the object's address is taken with the built-in &,
    // never with a user-provided overload. The cv-qualifiers make the cast
    // acceptable for any T.
    const volatile char &as_char = reinterpret_cast<const volatile char &>(arg);
    // Strip the qualifiers that were only added for the cast above.
    char *const raw = const_cast<char *>(&as_char);
    return reinterpret_cast<T *>(raw);
}

// [uninitialized.copy] - ARMCC has pre-C++11 uninitialized_copy
/**
 * Copy-construct `n` objects into uninitialised storage ([uninitialized.copy]).
 *
 * @param first   Start of the source sequence.
 * @param n       Number of elements to copy; nothing is done if `n <= 0`.
 * @param result  Start of the uninitialised destination storage.
 * @return Iterator one past the last object constructed.
 *
 * @note Objects already constructed are not destroyed if a later constructor
 *       throws.
 */
template <class InputIterator, class Size, class ForwardIterator>
ForwardIterator uninitialized_copy_n(InputIterator first, Size n, ForwardIterator result)
{
    typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

    for (; n > 0; ++result, (void) ++first, --n) {
        // Qualified call: argument-dependent lookup must not be able to pick
        // an addressof() from the element type's own namespace.
        ::new (static_cast<void *>(std::addressof(*result))) value_type(*first);
    }

    return result;
}

// [uninitialized.fill] - ARMCC has pre-C++11 uninitialized_fill and uninitialized_fill_n

// [unique.ptr]
namespace impl
{
    /* unique_ptr_type_helper<T, D>::type is the type used for
     * unique_ptr<T, D>::pointer.
     *
     * Base version - use T * */
    template <typename T, typename D, typename = void>
    struct unique_ptr_type_helper {
        typedef T *type;
    };

    /* if "remove_reference_t<D>::pointer" is a type, specialise to use it.
     * (void_t only yields a type when that nested name exists, so this
     * specialisation drops out otherwise and the base version is used.) */
    template <typename T, typename D>
    struct unique_ptr_type_helper<T, D, mstd::void_t<typename remove_reference_t<D>::pointer>> {
        typedef typename remove_reference_t<D>::pointer type;
    };

    /* Shorthand for unique_ptr_type_helper<T, D>::type */
    template <class T, class D>
    using unique_ptr_type_helper_t = typename unique_ptr_type_helper<T, D>::type;

    // deleter_store<D> holds the deleter for unique_ptr and provides
    // get_deleter(); unique_ptr derives from it.
    //
    // Want to eliminate storage for the deleter - could just use it as a base
    // class, for empty base optimisation, if we knew it was a class. But it could be
    // a pointer or reference. Here's a version that uses the deleter as a base;
    // it is the one selected when D is a class type.
    template<class D, typename = void>
    class deleter_store : private D {
    public:
        constexpr deleter_store() noexcept = default;
        // Construct the deleter base from any argument, forwarded as given
        template <typename _D>
        constexpr deleter_store(_D &&d) noexcept : D(std::forward<_D>(d)) { }

        // The deleter is the base-class subobject of this object
        D &get_deleter() noexcept { return static_cast<D &>(*this); }
        const D &get_deleter() const noexcept { return static_cast<const D &>(*this); }
    };

    // Here's a version that stores the deleter as a member instead; it is
    // selected whenever D is not a class type (e.g. a pointer or reference).
    template<class D>
    class deleter_store<D, enable_if_t<!is_class<D>::value>> {
        D d;
    public:
        // Value-initialises the stored deleter
        constexpr deleter_store() noexcept : d() { }
        // Inside the initialiser the name d refers to the parameter, so the
        // member d is initialised from the forwarded argument
        template <typename _D>
        constexpr deleter_store(_D &&d) noexcept : d(std::forward<_D>(d)) { }

        D &get_deleter() noexcept { return d; }
        const D &get_deleter() const noexcept { return d; }
    };
}

// [unique.ptr.dltr.dflt]
/**
 * Default deleter for a single object: releases it with `delete`
 * ([unique.ptr.dltr.dflt]).
 */
template<class T>
struct default_delete {
    constexpr default_delete() noexcept = default;

    /** Converting constructor; only available when U * converts to T *. */
    template <class U, class = std::enable_if_t<std::is_convertible<U *, T *>::value>>
    default_delete(const default_delete<U> &) noexcept { }

    /** Delete the object `ptr` points to. */
    void operator()(T *ptr) const
    {
        // Program is ill-formed if T is incomplete - generate diagnostic by breaking compilation
        // (Behaviour of raw delete of incomplete class is undefined if complete class is non-trivial, else permitted)
        static_assert(sizeof(T) > 0, "Cannot delete incomplete type");
        delete ptr;
    }
};

/**
 * Default deleter for an array: releases it with `delete[]`
 * ([unique.ptr.dltr.dflt1]).
 */
template<class T>
struct default_delete<T[]> {
    constexpr default_delete() noexcept = default;

    /**
     * Converting constructor from another *array* deleter; only available when
     * U (*)[] converts to T (*)[] (for example int[] to const int[]).
     * The parameter is default_delete<U[]> so that a single-object deleter can
     * never be turned into an array deleter.
     */
    template <class U, class = std::enable_if_t<std::is_convertible<U (*)[], T (*)[]>::value>>
    default_delete(const default_delete<U[]> &) noexcept { }

    /** Delete the array `ptr` points to; only available when U (*)[] converts to T (*)[]. */
    template <class U, class = std::enable_if_t<std::is_convertible<U (*)[], T (*)[]>::value>>
    void operator()(U *ptr) const
    {
        // Same diagnostic as the single-object deleter: refuse incomplete types
        static_assert(sizeof(U) > 0, "Cannot delete incomplete type");
        delete[] ptr;
    }
};

// [unique.ptr.single]
/**
 * Sole-ownership smart pointer for a single object ([unique.ptr.single]).
 *
 * Holds a `pointer` plus a deleter of type D. The deleter is invoked on the
 * held pointer, if it is not null, when the unique_ptr is destroyed, reset or
 * assigned to. The deleter itself lives in the impl::deleter_store<D> base,
 * which is also where get_deleter() comes from.
 *
 * @tparam T  Type of the managed object.
 * @tparam D  Deleter type; defaults to default_delete<T>.
 */
template<
    class T,
    class D = default_delete<T>
> class unique_ptr : public impl::deleter_store<D>
{
    /* True if a unique_ptr<U, E> may be converted to this type: its pointer
     * must convert to ours and U must not be an array type. */
    template <class U, class E>
    static constexpr bool is_compatible_unique_ptr()
    {
        return is_convertible<typename unique_ptr<U,E>::pointer, pointer>::value &&
                !is_array<U>::value;
    }
public:
    typedef impl::unique_ptr_type_helper_t<T, D> pointer;
    typedef T element_type;
    typedef D deleter_type;

    // [unique.ptr.single.ctor]
    // The first three constructors value-initialise the deleter, so they are
    // only available when D is default constructible and not a pointer type.

    /** Construct an empty unique_ptr. */
    template <class _D = D, typename = enable_if_t<!is_pointer<_D>::value && is_default_constructible<_D>::value>>
    constexpr unique_ptr() noexcept : impl::deleter_store<D>(), ptr_() { }

    /** Construct an empty unique_ptr from nullptr. */
    template <class _D = D, typename = enable_if_t<!is_pointer<_D>::value && is_default_constructible<_D>::value>>
    constexpr unique_ptr(nullptr_t) noexcept : unique_ptr() { }

    /** Take ownership of `ptr`. */
    template <class _D = D, typename = enable_if_t<!is_pointer<_D>::value && is_default_constructible<_D>::value>>
    explicit unique_ptr(pointer ptr) noexcept : impl::deleter_store<D>(), ptr_(ptr) { }

    /** Take ownership of `ptr`, copying the deleter from `d`. */
    template <class _D = D, typename = enable_if_t<is_copy_constructible<_D>::value>>
    unique_ptr(pointer ptr, const D &d) noexcept : impl::deleter_store<D>(d), ptr_(ptr) { }

    /** Take ownership of `ptr`, moving the deleter from `d` (D is not an lvalue reference). */
    template <class _D = D, typename = enable_if_t<is_move_constructible<_D>::value>>
    unique_ptr(pointer ptr, enable_if_t<!is_lvalue_reference<_D>::value, _D &&> d) noexcept : impl::deleter_store<D>(std::move(d)), ptr_(ptr) { }

    /** Deleted: when D is an lvalue reference, an rvalue deleter cannot be bound to it. */
    template <class _D = D, typename _A = remove_reference_t<_D>>
    unique_ptr(pointer ptr, enable_if_t<is_lvalue_reference<_D>::value, _A &&> d) = delete;

    unique_ptr(const unique_ptr &) = delete;

    /** Move constructor: takes the pointer and deleter from `u`, leaving `u` empty. */
    unique_ptr(unique_ptr &&u) noexcept : impl::deleter_store<D>(std::forward<D>(u.get_deleter())), ptr_(u.ptr_) { u.ptr_ = nullptr; }

    /**
     * Converting move constructor from a compatible unique_ptr<U, E>.
     * If D is a reference type E must be the same type, otherwise E must be
     * convertible to D.
     */
    template <class U, class E, class = enable_if_t<
                          is_compatible_unique_ptr<U, E>() &&
                          (is_reference<D>::value ? is_same<E, D>::value : is_convertible<E,D>::value)>>
    unique_ptr(unique_ptr<U, E>&& u) noexcept : impl::deleter_store<D>(std::forward<E>(u.get_deleter())), ptr_(u.release()) { }

    // [unique.ptr.single.dtor]
    /** Invoke the deleter on the held pointer, unless it is null. */
    ~unique_ptr()
    {
        if (ptr_) {
            this->get_deleter()(ptr_);
        }
    }

    // [unique.ptr.single.modifiers]
    /** Give up ownership: returns the held pointer and leaves this object empty. */
    pointer release() noexcept
    {
        return std::exchange(ptr_, nullptr);
    }

    /**
     * Replace the held pointer with `ptr`.
     * The new pointer is stored first; only then is the deleter invoked on
     * the previous pointer, if that was not null.
     */
    void reset(pointer ptr = pointer()) noexcept
    {
        pointer old = std::exchange(ptr_, ptr);
        if (old) {
            this->get_deleter()(old);
        }
    }

    /** Exchange pointer and deleter with `other`. */
    void swap(unique_ptr &other) noexcept
    {
        using std::swap;
        swap(this->get_deleter(), other.get_deleter());
        swap(ptr_, other.ptr_);
    }

    // [unique.ptr.single.asgn]
    unique_ptr &operator=(const unique_ptr &r) = delete;

    /** Move assignment: reset to the pointer released by `r`, then take over its deleter. */
    unique_ptr &operator=(unique_ptr &&r) noexcept
    {
        reset(r.release());
        this->get_deleter() = std::forward<D>(r.get_deleter());
        return *this;
    }

    /** Assigning nullptr is the same as reset(). */
    unique_ptr &operator=(nullptr_t) noexcept
    {
        reset();
        return *this;
    }

    /** Converting move assignment from a compatible unique_ptr<U, E> whose deleter is assignable to D. */
    template <class U, class E>
    enable_if_t<is_compatible_unique_ptr<U, E>() &&
        is_assignable<D &, E &&>::value,
    unique_ptr> &operator=(unique_ptr<U, E> &&u) noexcept
    {
        reset(u.release());
        this->get_deleter() = std::forward<E>(u.get_deleter());
        return *this;
    }

    // [unique.ptr.single.observers]
    pointer get() const noexcept { return ptr_; }
    pointer operator->() const noexcept { return ptr_; }
    add_lvalue_reference_t<T> operator*() const noexcept { return *ptr_; }
    // Explicit comparison, so a `pointer` type that is only contextually
    // convertible to bool works here as well
    explicit operator bool() const noexcept { return ptr_ != nullptr; }
private:
    pointer ptr_;
};

// [unique.ptr.runtime]
/**
 * Sole-ownership smart pointer for an array of unknown bound
 * ([unique.ptr.runtime]).
 *
 * Same ownership model as the single-object unique_ptr, with operator[]
 * in place of operator* and operator->. The deleter lives in the
 * impl::deleter_store<D> base, which also supplies get_deleter().
 *
 * @tparam T  Element type of the managed array.
 * @tparam D  Deleter type.
 */
template<class T, class D>
class unique_ptr<T[], D> : public impl::deleter_store<D>
{
    /* True if a value of type U may be adopted as the array pointer: U is
     * `pointer` itself, nullptr_t, or - when `pointer` is a plain
     * element_type * - a pointer whose pointee array type converts to ours
     * (i.e. only a qualification conversion). */
    template <class U>
    static constexpr bool is_compatible_pointer()
    {
        return is_same<U, pointer>::value ||
               is_same<U, nullptr_t>::value ||
               (is_same<pointer, element_type *>::value && is_pointer<U>::value &&
                is_convertible<remove_pointer_t<U> (*)[], element_type (*)[]>::value);
    }

    /* True if a unique_ptr<U, E> may be converted to this type: U is an array,
     * both sides use plain element pointers, and the element array types
     * convert. */
    template <class U, class E, class UP = unique_ptr<U,E>>
    static constexpr bool is_compatible_unique_ptr()
    {
        return is_array<U>::value &&
               is_same<pointer, element_type *>::value &&
               is_same<typename UP::pointer, typename UP::element_type *>::value &&
               is_convertible<typename UP::element_type(*)[], element_type(*)[]>::value;
    }
public:
    typedef impl::unique_ptr_type_helper_t<T, D> pointer;
    typedef T element_type;
    typedef D deleter_type;

    // [unique.ptr.runtime.ctor] / [unique.ptr.single.ctor]
    // The first three constructors value-initialise the deleter, so they are
    // only available when D is default constructible and not a pointer type.

    /** Construct an empty unique_ptr. */
    template <class _D = D, typename = enable_if_t<!is_pointer<_D>::value && is_default_constructible<_D>::value>>
    constexpr unique_ptr() noexcept : impl::deleter_store<D>(), ptr_() { }

    /** Construct an empty unique_ptr from nullptr. */
    template <class _D = D, typename = enable_if_t<!is_pointer<_D>::value && is_default_constructible<_D>::value>>
    constexpr unique_ptr(nullptr_t) noexcept : unique_ptr() { }

    /** Take ownership of the array at `ptr` (U must be a compatible pointer type). */
    template <class _D = D, typename = enable_if_t<!is_pointer<_D>::value && is_default_constructible<_D>::value>,
              class U, typename = enable_if_t<is_compatible_pointer<U>()>>
    explicit unique_ptr(U ptr) noexcept : impl::deleter_store<D>(), ptr_(ptr) { }

    /** Take ownership of the array at `ptr`, copying the deleter from `d`. */
    template <class _D = D, typename = enable_if_t<is_copy_constructible<_D>::value>,
              class U, typename = enable_if_t<is_compatible_pointer<U>()>>
    unique_ptr(U ptr, const D &d) noexcept : impl::deleter_store<D>(d), ptr_(ptr) { }

    /** Take ownership of the array at `ptr`, moving the deleter from `d` (D is not an lvalue reference). */
    template <class _D = D, typename = enable_if_t<is_move_constructible<_D>::value>,
              class U, typename = enable_if_t<is_compatible_pointer<U>()>>
    unique_ptr(U ptr, enable_if_t<!is_lvalue_reference<_D>::value, _D &&> d) noexcept : impl::deleter_store<D>(std::move(d)), ptr_(ptr) { }

    /** Deleted: when D is an lvalue reference, an rvalue deleter cannot be bound to it. */
    template <class _D = D, typename _A = remove_reference_t<_D>,
              class U, typename = enable_if_t<is_compatible_pointer<U>()>>
    unique_ptr(U ptr, enable_if_t<is_lvalue_reference<_D>::value, _A &&> d) = delete;

    unique_ptr(const unique_ptr &) = delete;

    /** Move constructor: takes the pointer and deleter from `u`, leaving `u` empty. */
    unique_ptr(unique_ptr &&u) noexcept : impl::deleter_store<D>(std::forward<D>(u.get_deleter())), ptr_(u.ptr_) { u.ptr_ = nullptr; }

    /**
     * Converting move constructor from a compatible unique_ptr<U, E>.
     * If D is a reference type E must be the same type, otherwise E must be
     * convertible to D.
     */
    template <class U, class E,
    typename = enable_if_t<is_compatible_unique_ptr<U, E>() &&
                 (is_reference<D>::value ? is_same<E,D>::value : is_convertible<E,D>::value)>>
    unique_ptr(unique_ptr<U, E>&& u) noexcept : impl::deleter_store<D>(std::forward<E>(u.get_deleter())), ptr_(u.release()) { }

    // [unique.ptr.single.dtor]
    /** Invoke the deleter on the held pointer, unless it is null. */
    ~unique_ptr()
    {
        if (ptr_) {
            this->get_deleter()(ptr_);
        }
    }

    // [unique.ptr.runtime.modifiers] / [unique.ptr.single.modifiers]
    /** Give up ownership: returns the held pointer and leaves this object empty. */
    pointer release() noexcept
    {
        return std::exchange(ptr_, nullptr);
    }

    /**
     * Replace the held pointer with `ptr`.
     * The new pointer is stored first; only then is the deleter invoked on
     * the previous pointer, if that was not null.
     */
    void reset(pointer ptr = pointer()) noexcept
    {
        pointer old = std::exchange(ptr_, ptr);
        if (old) {
            this->get_deleter()(old);
        }
    }

    /**
     * reset(nullptr): same as reset().
     * Needed as a separate overload because the deleted template below is an
     * exact match for a nullptr_t argument and would otherwise be selected in
     * preference to reset(pointer).
     */
    void reset(nullptr_t) noexcept
    {
        reset(pointer());
    }

    /** Deleted: any other argument type is rejected rather than converted to `pointer`. */
    template <class U>
    void reset(U) = delete;

    /** Exchange pointer and deleter with `other`. */
    void swap(unique_ptr &other) noexcept
    {
        using std::swap;
        swap(this->get_deleter(), other.get_deleter());
        swap(ptr_, other.ptr_);
    }

    // [unique.ptr.runtime.asgn] / [unique.ptr.single.asgn]
    unique_ptr &operator=(const unique_ptr &r) = delete;

    /** Move assignment: reset to the pointer released by `r`, then take over its deleter. */
    unique_ptr &operator=(unique_ptr &&r) noexcept
    {
        reset(r.release());
        this->get_deleter() = std::forward<D>(r.get_deleter());
        return *this;
    }

    /** Assigning nullptr is the same as reset(). */
    unique_ptr &operator=(nullptr_t) noexcept
    {
        reset();
        return *this;
    }

    /** Converting move assignment from a compatible unique_ptr<U, E> whose deleter is assignable to D. */
    template <class U, class E>
    enable_if_t<is_compatible_unique_ptr<U, E>() &&
                     is_assignable<D &, E &&>::value,
    unique_ptr> &operator=(unique_ptr<U, E> &&u) noexcept
    {
        reset(u.release());
        this->get_deleter() = std::forward<E>(u.get_deleter());
        return *this;
    }

    // [unique.ptr.runtime.observers] / [unique.ptr.single.observers]
    pointer get() const noexcept { return ptr_; }
    /** Access element `index` of the managed array; no bounds check is made. */
    T &operator[](size_t index) const { return ptr_[index]; }
    // Explicit comparison, so a `pointer` type that is only contextually
    // convertible to bool works here as well
    explicit operator bool() const noexcept { return ptr_ != nullptr; }
private:
    pointer ptr_;
};

// [unique.ptr.create]
/* Single object: allocate a T constructed from the forwarded arguments and
 * hand it to a unique_ptr<T>. Only available when T is not an array type. */
template <typename T, typename... Args>
enable_if_t<!is_array<T>::value,
unique_ptr<T>> make_unique(Args &&... args)
{
    T *const object = new T(std::forward<Args>(args)...);
    return unique_ptr<T>(object);
}

/* Array of unknown bound (T is U[]): allocate `size` value-initialised
 * elements and hand them to a unique_ptr<T>. */
template <typename T>
enable_if_t<is_array<T>::value && extent<T>::value == 0,
unique_ptr<T>> make_unique(size_t size)
{
    typedef remove_extent_t<T> element;
    element *const array = new element[size]();
    return unique_ptr<T>(array);
}

/* Array of known bound (T is U[N]): not allowed. */
template <typename T, typename... Args>
enable_if_t<extent<T>::value != 0,
void> make_unique(Args &&... args) = delete;

// [unique.ptr.special]
/* Non-member swap: forwards to the member swap(). */
template <class T, class D>
void swap(unique_ptr<T, D> &lhs, unique_ptr<T, D> &rhs) noexcept
{
    lhs.swap(rhs);
}

/* Comparisons between two unique_ptrs operate on the stored pointers. */

template <class T1, class D1, class T2, class D2>
bool operator==(const unique_ptr<T1, D1> &x, const unique_ptr<T2, D2> &y)
{
    const bool same = (x.get() == y.get());
    return same;
}

template <class T1, class D1, class T2, class D2>
bool operator!=(const unique_ptr<T1, D1> &x, const unique_ptr<T2, D2> &y)
{
    const bool differ = (x.get() != y.get());
    return differ;
}

/* Ordering uses less<> on the common type of the two pointer types.
 * The remaining ordering operators are all expressed through this one. */
template <class T1, class D1, class T2, class D2>
bool operator<(const unique_ptr<T1, D1> &x, const unique_ptr<T2, D2> &y)
{
    typedef typename unique_ptr<T1, D1>::pointer left_pointer;
    typedef typename unique_ptr<T2, D2>::pointer right_pointer;
    typedef common_type_t<left_pointer, right_pointer> CT;

    less<CT> precedes;
    return precedes(x.get(), y.get());
}

template <class T1, class D1, class T2, class D2>
bool operator<=(const unique_ptr<T1, D1> &x, const unique_ptr<T2, D2> &y)
{
    // x <= y  is  "y does not come before x"
    if (y < x) {
        return false;
    }
    return true;
}

template <class T1, class D1, class T2, class D2>
bool operator>(const unique_ptr<T1, D1> &x, const unique_ptr<T2, D2> &y)
{
    // x > y  is  "y comes before x"
    const bool y_first = (y < x);
    return y_first;
}

template <class T1, class D1, class T2, class D2>
bool operator>=(const unique_ptr<T1, D1> &x, const unique_ptr<T2, D2> &y)
{
    // x >= y  is  "x does not come before y"
    if (x < y) {
        return false;
    }
    return true;
}

/* Comparisons between a unique_ptr and nullptr, in both argument orders. */

/* Equal to nullptr when the unique_ptr tests false. */
template <class T, class D>
bool operator==(const unique_ptr<T, D> &x, nullptr_t) noexcept
{
    return !static_cast<bool>(x);
}

template <class T, class D>
bool operator==(nullptr_t, const unique_ptr<T, D> &x) noexcept
{
    return !static_cast<bool>(x);
}

/* Different from nullptr when the unique_ptr tests true. */
template <class T, class D>
bool operator!=(const unique_ptr<T, D> &x, nullptr_t) noexcept
{
    return static_cast<bool>(x);
}

template <class T, class D>
bool operator!=(nullptr_t, const unique_ptr<T, D> &x) noexcept
{
    return static_cast<bool>(x);
}

/* Ordering against nullptr uses less<pointer>. The remaining ordering
 * operators are all expressed through the two operator< overloads. */
template <class T, class D>
bool operator<(const unique_ptr<T, D> &x, nullptr_t) noexcept
{
    typedef typename unique_ptr<T, D>::pointer pointer;
    less<pointer> precedes;
    return precedes(x.get(), nullptr);
}

template <class T, class D>
bool operator<(nullptr_t, const unique_ptr<T, D> &x) noexcept
{
    typedef typename unique_ptr<T, D>::pointer pointer;
    less<pointer> precedes;
    return precedes(nullptr, x.get());
}

template <class T, class D>
bool operator>(const unique_ptr<T, D> &x, nullptr_t) noexcept
{
    // x > nullptr  is  "nullptr comes before x"
    const bool null_first = (nullptr < x);
    return null_first;
}

template <class T, class D>
bool operator>(nullptr_t, const unique_ptr<T, D> &x) noexcept
{
    // nullptr > x  is  "x comes before nullptr"
    const bool x_first = (x < nullptr);
    return x_first;
}

template <class T, class D>
bool operator<=(const unique_ptr<T, D> &x, nullptr_t) noexcept
{
    // x <= nullptr  is  "nullptr does not come before x"
    if (nullptr < x) {
        return false;
    }
    return true;
}

template <class T, class D>
bool operator<=(nullptr_t, const unique_ptr<T, D> &x) noexcept
{
    // nullptr <= x  is  "x does not come before nullptr"
    if (x < nullptr) {
        return false;
    }
    return true;
}

template <class T, class D>
bool operator>=(const unique_ptr<T, D> &x, nullptr_t) noexcept
{
    // x >= nullptr  is  "x does not come before nullptr"
    if (x < nullptr) {
        return false;
    }
    return true;
}

template <class T, class D>
bool operator>=(nullptr_t, const unique_ptr<T, D> &x) noexcept
{
    // nullptr >= x  is  "nullptr does not come before x"
    if (nullptr < x) {
        return false;
    }
    return true;
}

} // namespace std

#endif // __CC_ARM

namespace mstd {
    using std::align;
    using std::allocator;
    using std::addressof;

    // [uninitialized.construct.default] (C++17)
    /**
     * Default-initialise an object in every element of the uninitialised
     * range [first, last).
     *
     * @tparam Size  Unused. It cannot be deduced from the arguments, so it is
     *               defaulted to let `uninitialized_default_construct(first, last)`
     *               compile; it is kept only for callers that spell out both
     *               template arguments.
     * @param first  Start of the uninitialised storage.
     * @param last   End of the uninitialised storage.
     *
     * @note Objects already constructed are not destroyed if a later
     *       constructor throws.
     */
    template <class ForwardIterator, class Size = void>
    void uninitialized_default_construct(ForwardIterator first, ForwardIterator last)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        for (; first != last; ++first) {
            // No initialiser after the type: this is default-initialisation.
            // std::addressof is qualified so ADL cannot select another one.
            ::new (static_cast<void *>(std::addressof(*first))) value_type;
        }
    }

    /**
     * Default-initialise `n` objects in uninitialised storage starting at
     * `first`.
     *
     * @param first  Start of the uninitialised storage.
     * @param n      Number of objects to construct; nothing is done if `n <= 0`.
     * @return Iterator one past the last object constructed.
     *
     * @note Objects already constructed are not destroyed if a later
     *       constructor throws.
     */
    template <class ForwardIterator, class Size>
    ForwardIterator uninitialized_default_construct_n(ForwardIterator first, Size n)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        // "n > 0" (not just "n") so a negative count ends the loop at once;
        // the (void) cast keeps an overloaded comma operator out of the step.
        for (; n > 0; (void) ++first, --n) {
            // No initialiser after the type: this is default-initialisation
            ::new (static_cast<void *>(std::addressof(*first))) value_type;
        }

        return first;
    }

    // [uninitialized.construct.value] (C++17)
    /**
     * Value-initialise an object in every element of the uninitialised
     * range [first, last).
     *
     * @tparam Size  Unused. It cannot be deduced from the arguments, so it is
     *               defaulted to let `uninitialized_value_construct(first, last)`
     *               compile; it is kept only for callers that spell out both
     *               template arguments.
     * @param first  Start of the uninitialised storage.
     * @param last   End of the uninitialised storage.
     *
     * @note Objects already constructed are not destroyed if a later
     *       constructor throws.
     */
    template <class ForwardIterator, class Size = void>
    void uninitialized_value_construct(ForwardIterator first, ForwardIterator last)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        for (; first != last; ++first) {
            // Empty parentheses: this is value-initialisation.
            // std::addressof is qualified so ADL cannot select another one.
            ::new (static_cast<void *>(std::addressof(*first))) value_type();
        }
    }

    /**
     * Value-initialise `n` objects in uninitialised storage starting at
     * `first`.
     *
     * @param first  Start of the uninitialised storage.
     * @param n      Number of objects to construct; nothing is done if `n <= 0`.
     * @return Iterator one past the last object constructed.
     *
     * @note Objects already constructed are not destroyed if a later
     *       constructor throws.
     */
    template <class ForwardIterator, class Size>
    ForwardIterator uninitialized_value_construct_n(ForwardIterator first, Size n)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        // "n > 0" (not just "n") so a negative count ends the loop at once;
        // the (void) cast keeps an overloaded comma operator out of the step.
        for (; n > 0; (void) ++first, --n) {
            // Empty parentheses: this is value-initialisation
            ::new (static_cast<void *>(std::addressof(*first))) value_type();
        }

        return first;
    }

    // [uninitialized.move] (C++17)
    /**
     * Move-construct the elements of [first, last) into uninitialised storage
     * starting at `result`.
     *
     * @param first   Start of the source range.
     * @param last    End of the source range.
     * @param result  Start of the uninitialised destination storage.
     * @return Iterator one past the last object constructed.
     *
     * @note Objects already constructed are not destroyed if a later
     *       constructor throws.
     */
    template <class InputIterator, class ForwardIterator>
    ForwardIterator uninitialized_move(InputIterator first, InputIterator last, ForwardIterator result)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        for (; first != last; ++result, (void) ++first) {
            // std::move and std::addressof are qualified so that
            // argument-dependent lookup cannot select functions of the same
            // name from the element type's namespace.
            ::new (static_cast<void *>(std::addressof(*result))) value_type(std::move(*first));
        }

        return result;
    }

    /**
     * Move-construct `n` elements starting at `first` into uninitialised
     * storage starting at `result`.
     *
     * @param first   Start of the source sequence.
     * @param n       Number of elements to move; nothing is done if `n <= 0`.
     * @param result  Start of the uninitialised destination storage.
     * @return Pair of iterators: one past the last source element moved from,
     *         and one past the last object constructed.
     *
     * @note Objects already constructed are not destroyed if a later
     *       constructor throws.
     */
    template <class InputIterator, class Size, class ForwardIterator>
    std::pair<InputIterator, ForwardIterator> uninitialized_move_n(InputIterator first, Size n, ForwardIterator result)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        for (; n > 0; ++result, (void) ++first, --n) {
            // std::addressof is qualified so that argument-dependent lookup
            // cannot select an addressof() from the element type's namespace.
            ::new (static_cast<void *>(std::addressof(*result))) value_type(std::move(*first));
        }

        return { first, result };
    }

    using std::uninitialized_copy;
    using std::uninitialized_copy_n;
    using std::uninitialized_fill;
    using std::uninitialized_fill_n;

    // [specialized.destroy] (C++17)
    /**
     * Run the destructor of the object at `location`.
     * Only the destructor is called; the storage itself is not released.
     *
     * @param location  Address of the object to destroy.
     */
    template <class T>
    void destroy_at(T *location)
    {
        T &object = *location;
        object.~T();
    }

    /**
     * Run the destructor of every object in [first, last).
     * Only destructors are called; the storage itself is not released.
     *
     * @param first  Start of the range of objects to destroy.
     * @param last   End of the range of objects to destroy.
     */
    template <class ForwardIterator>
    void destroy(ForwardIterator first, ForwardIterator last)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        for (; first != last; ++first) {
            // The destructor is invoked directly instead of through an
            // unqualified destroy_at() call: for element types from namespace
            // std, such a call would also find std::destroy_at on a C++17
            // library through argument-dependent lookup and be ambiguous.
            std::addressof(*first)->~value_type();
        }
    }

    /**
     * Run the destructor of `n` objects starting at `first`.
     * Only destructors are called; the storage itself is not released.
     *
     * @param first  First object to destroy.
     * @param n      Number of objects to destroy; nothing is done if `n <= 0`.
     * @return Iterator one past the last object destroyed.
     */
    template <class ForwardIterator, class Size>
    ForwardIterator destroy_n(ForwardIterator first, Size n)
    {
        typedef typename std::iterator_traits<ForwardIterator>::value_type value_type;

        for (; n > 0; (void) ++first, --n) {
            // The destructor is invoked directly instead of through an
            // unqualified destroy_at() call: for element types from namespace
            // std, such a call would also find std::destroy_at on a C++17
            // library through argument-dependent lookup and be ambiguous.
            std::addressof(*first)->~value_type();
        }
        return first;
    }

    using std::default_delete;
    using std::unique_ptr;
    using std::make_unique;
}

#endif // MSTD_MEMORY_
