// condition_variable 6.46 KB
#pragma once

// Refuse to compile when the language standard is older than C++11.
#if __cplusplus < 201103L
#error "C++ version lower than C++11"
#endif

#include <pthread.h>

#include <system_error>
#include <chrono>
#include <utility>
#include <functional>
#include <memory>

#include "__utils.h"
#include "mutex"

// Native condition-variable type wrapped by std::condition_variable below.
#define rt_cpp_cond_var pthread_cond_t

namespace std
{

    // Outcome of a timed wait: either the wait ended before the deadline
    // (no_timeout) or the deadline was reached (timeout).
    enum class cv_status
    {
        no_timeout = 0,
        timeout = 1
    };

    class condition_variable
    {
    public:
        typedef rt_cpp_cond_var *native_handle_type;

        condition_variable(const condition_variable &) = delete;
        condition_variable &operator=(const condition_variable &) = delete;

        condition_variable() = default;

        ~condition_variable()
        {
            pthread_cond_destroy(&_m_cond);
        }

        void wait(unique_lock<mutex> &lock);

        void notify_one() noexcept
        {
            pthread_cond_signal(&_m_cond);
        }

        void notify_all() noexcept
        {
            pthread_cond_broadcast(&_m_cond);
        }

        template <class Predicate>
        void wait(unique_lock<mutex> &lock, Predicate pred)
        {
            while (!pred())
                wait(lock);
        }

        template <class Clock, class Duration>
        cv_status wait_until(unique_lock<mutex> &lock,
                             const chrono::time_point<Clock, Duration> &abs_time)
        {
            if (!lock.owns_lock())
                throw_system_error((int)errc::operation_not_permitted,
                                   "condition_variable::wailt_until: waiting on unlocked lock");
            auto secs = chrono::time_point_cast<chrono::seconds>(abs_time);
            auto nano_secs = chrono::duration_cast<chrono::nanoseconds>(abs_time - secs);

            struct timespec c_abs_time = {static_cast<time_t>(secs.time_since_epoch().count()),
                                          static_cast<long>(nano_secs.count())};

            pthread_cond_timedwait(&_m_cond, lock.mutex()->native_handle(), &c_abs_time);

            return (Clock::now() < abs_time) ? cv_status::no_timeout : cv_status::timeout;
        }

        template <class Clock, class Duration, class Predicate>
        bool wait_until(unique_lock<mutex> &lock,
                        const chrono::time_point<Clock, Duration> &abs_time,
                        Predicate pred)
        {
            while (!pred())
                if (wait_until(lock, abs_time) == cv_status::timeout)
                    return pred();
            return true;
        }

        template <class Rep, class Period>
        cv_status wait_for(unique_lock<mutex> &lock,
                           const chrono::duration<Rep, Period> &rel_time)
        {
            return wait_until(lock, real_time_clock::now() + rel_time);
        }

        template <class Rep, class Period, class Predicate>
        bool wait_for(unique_lock<mutex> &lock,
                      const chrono::duration<Rep, Period> &rel_time,
                      Predicate pred)
        {
            return wait_until(lock, real_time_clock::now() + rel_time, std::move(pred));
        }

        native_handle_type native_handle()
        {
            return &_m_cond;
        }

    private:
        rt_cpp_cond_var _m_cond = PTHREAD_COND_INITIALIZER;
    };

    // Lockable is only required to have `lock()` and `unlock()`
    class condition_variable_any
    {
    private:
        condition_variable _m_cond;
        shared_ptr<mutex> _m_mtx;

        // so that Lockable automatically unlocks when waiting and locks after waiting
        template <class Lockable>
        struct unlocker
        {
            Lockable &_m_lock;

            explicit unlocker(Lockable &lk)
                : _m_lock(lk)
            {
                _m_lock.unlock();
            }

            ~unlocker()
            {
                _m_lock.lock();
            }

            unlocker(const unlocker &) = delete;
            unlocker &operator=(const unlocker &) = delete;
        };

    public:
        condition_variable_any() : _m_mtx(std::make_shared<mutex>()) {}
        ~condition_variable_any() = default;

        condition_variable_any(const condition_variable_any &) = delete;
        condition_variable_any &operator=(const condition_variable_any &) = delete;

        void notify_one() noexcept
        {
            lock_guard<mutex> lk(*_m_mtx);
            _m_cond.notify_one();
        }

        void notify_all() noexcept
        {
            lock_guard<mutex> lk(*_m_mtx);
            _m_cond.notify_all();
        }

        template <class Lock>
        void wait(Lock &lock)
        {
            shared_ptr<mutex> mut = _m_mtx;
            unique_lock<mutex> lk(*mut);
            unlocker<Lock> auto_lk(lock); // unlock here

            unique_lock<mutex> lk2(std::move(lk));
            _m_cond.wait(lk2);
        } // mut.unlock(); lock.lock();

        template <class Lock, class Predicate>
        void wait(Lock &lock, Predicate pred)
        {
            while (!pred())
                wait(lock);
        }

        template <class Lock, class Clock, class Duration>
        cv_status wait_until(Lock &lock,
                             const chrono::time_point<Clock, Duration> &abs_time)
        {
            shared_ptr<mutex> mut = _m_mtx;
            unique_lock<mutex> lk(*mut);
            unlocker<Lock> auto_lk(lock); // unlock here

            unique_lock<mutex> lk2(std::move(lk));
            return _m_cond.wait_until(lk2, abs_time);
        }

        template <class Lock, class Clock, class Duration, class Predicate>
        bool wait_until(Lock &lock,
                        const chrono::time_point<Clock, Duration> &abs_time,
                        Predicate pred)
        {
            while (!pred())
                if (wait_until(lock, abs_time) == cv_status::timeout)
                    return pred();
            return true;
        }

        template <class Lock, class Rep, class Period>
        cv_status wait_for(Lock &lock,
                           const chrono::duration<Rep, Period> &rel_time)
        {
            return wait_until(lock, real_time_clock::now() + rel_time);
        }

        template <class Lock, class Rep, class Period, class Predicate>
        bool wait_for(Lock &lock,
                      const chrono::duration<Rep, Period> &rel_time,
                      Predicate pred)
        {
            return wait_until(lock, real_time_clock::now() + rel_time, std::move(pred));
        }
    };

    // Declaration only; the definition is not part of this header.
    // Takes the lock by value, i.e. ownership of `lk` passes to the callee.
    // NOTE(review): presumably mirrors the standard
    // std::notify_all_at_thread_exit contract - confirm against the definition.
    void notify_all_at_thread_exit(condition_variable &cond, unique_lock<mutex> lk);

} // namespace std