// thread 5.74 KB
#pragma once

// This header uses C++11 language and library features (noexcept, rvalue
// references, variadic templates, <chrono>), so stop the build early when
// the compiler is configured for an older standard.
#if __cplusplus < 201103L
#error "C++ version lower than C++11"
#endif

//#if defined(RT_USING_PTHREADS)

#include <unistd.h>
#include <pthread.h>
#include <sched.h>
#include <rtthread.h>

#include <cerrno>
#include <chrono>
#include <cstddef>
#include <functional>
#include <memory>
#include <ostream>
#include <type_traits>
#include <utility>

// Native handle type wrapped by std::thread::id in this header.
#define rt_cpp_thread_t pthread_t
// Fallback value, used only when PTHREAD_NUM_MAX has not been defined
// before this point.
#ifndef PTHREAD_NUM_MAX
#define PTHREAD_NUM_MAX 32
#endif
// Sentinel handle stored by a default-constructed thread::id.
// thread::joinable() only reports true for handles strictly below
// PTHREAD_NUM_MAX, so this value always reads as "not joinable".
#define CPP_UNJOINABLE_THREAD PTHREAD_NUM_MAX

namespace std
{
    // Defines __STDCPP_THREADS__ as the current value of __cplusplus.
    // NOTE(review): presumably present so that code testing this macro sees
    // thread support as available -- confirm against the toolchain in use.
    #define __STDCPP_THREADS__ __cplusplus

    

    // std::thread replacement built around the pthread handle type.
    //
    // A thread object stores exactly one thread::id. A default-constructed id
    // holds CPP_UNJOINABLE_THREAD, which stands for "no associated thread".
    // The destructor, join(), detach(), hardware_concurrency() and
    // start_thread() are only declared here; their definitions are not part
    // of this header.
    class thread
    {
    public:
        typedef rt_cpp_thread_t   native_handle_type;

        struct invoker_base;
        typedef shared_ptr<invoker_base> invoker_base_ptr;

        // Thin value wrapper around native_handle_type. The raw handle is
        // private; only the friends listed below can read it.
        class id
        {
            native_handle_type __cpp_thread_t;

        public:
            // Creates the id that represents "no thread".
            id() noexcept : __cpp_thread_t(CPP_UNJOINABLE_THREAD) {}

            // Wraps an existing native handle.
            explicit id(native_handle_type hid)
                : __cpp_thread_t(hid) {}
        private:
            friend class thread;
            friend class hash<thread::id>;

            friend bool operator==(thread::id x, thread::id y) noexcept;

            friend bool operator<(thread::id x, thread::id y) noexcept;

            template <class charT, class traits>
            friend basic_ostream<charT, traits>&
            operator<<(basic_ostream<charT, traits>& out, thread::id id);
        };

        // A default-constructed thread owns nothing and is not joinable.
        thread() noexcept = default;
        // Threads are move-only.
        thread(const thread&) = delete;
        thread& operator=(const thread&) = delete;
        ~thread();

        // Starts a new thread that runs f(args...).
        //
        // The callable and its arguments are packed with std::bind into a
        // heap-allocated invoker that is handed to start_thread().
        //
        // The enable_if guard removes this overload when F is (a reference
        // to) thread itself. Without it, `thread b(a);` with a non-const
        // lvalue `a` would select this template instead of the deleted copy
        // constructor and fail deep inside std::bind; type traits such as
        // is_constructible<thread, thread&> would also report true.
        template <class F, class ...Args,
                  class = typename std::enable_if<
                      !std::is_same<typename std::decay<F>::type, thread>::value
                  >::type>
        explicit thread(F&& f, Args&&... args)
        {
            start_thread(make_invoker_ptr(std::bind(
                std::forward<F>(f),
                std::forward<Args>(args)...
            )));
        }

        // Takes over the id of t; t is left holding the "no thread" id.
        thread(thread&& t) noexcept
        {
            swap(t);
        }

        // Move assignment. Assigning over a thread that is still joinable
        // calls terminate(); otherwise the ids are exchanged, which leaves t
        // with this object's previous (non-joinable) id.
        thread& operator=(thread&& t) noexcept
        {
            if (joinable())
                terminate();
            swap(t);
            return *this;
        }

        // member functions

        // Exchanges the stored ids of *this and t.
        void swap(thread& t) noexcept
        {
            std::swap(_m_thr, t._m_thr);
        }

        // True while the stored handle is below PTHREAD_NUM_MAX, i.e. it is
        // not the CPP_UNJOINABLE_THREAD sentinel.
        bool joinable() const noexcept
        {
            return (_m_thr.__cpp_thread_t < PTHREAD_NUM_MAX);
        }

        void join();

        void detach();

        // Returns a copy of the stored id.
        id get_id() const noexcept { return _m_thr; }

        // Returns the raw handle held by the stored id.
        native_handle_type native_handle() { return _m_thr.__cpp_thread_t; }

        // static members
        static unsigned hardware_concurrency() noexcept;

    private:
        id _m_thr;

        // Launches the thread for the given invoker (defined elsewhere).
        void start_thread(invoker_base_ptr b);
    public:
        // Type-erased callable handed to start_thread().
        struct invoker_base
        {
            // NOTE(review): presumably assigned by start_thread() so the
            // invoker keeps itself alive until the new thread has run it --
            // confirm in the implementation file.
            invoker_base_ptr this_ptr;

            virtual ~invoker_base() = default;

            virtual void invoke() = 0;
        };


        // Concrete invoker that stores the callable and calls it with no
        // arguments.
        template<typename Callable>
        struct invoker : public invoker_base
        {
            Callable func;

            invoker(Callable&& F) : func(std::forward<Callable>(F)) { }

            void invoke() override { func(); }
        };

        // Wraps F in a shared invoker. The invoker is instantiated with the
        // deduced Callable type as-is, so callers are expected to pass an
        // rvalue (as the constructor above does with the std::bind result).
        template <typename Callable>
        shared_ptr<invoker<Callable>> make_invoker_ptr(Callable&& F)
        {
            return std::make_shared<invoker<Callable>>(std::forward<Callable>(F));
        }


    };

    // Non-member swap for thread objects: forwards to the member swap so
    // the two objects trade their stored ids.
    inline void swap(thread& x, thread& y) noexcept
    {
        return x.swap(y);
    }


    // Two ids are equal when they wrap the same native handle value.
    inline bool operator==(thread::id x, thread::id y) noexcept
    {
        // The handles are compared directly instead of via pthread_equal.
        // POSIX says of pthread_equal: "If either t1 or t2 are not valid
        // thread IDs, the behavior is undefined."
        const thread::native_handle_type lhs = x.__cpp_thread_t;
        const thread::native_handle_type rhs = y.__cpp_thread_t;
        return lhs == rhs;
    }

    // Inequality is defined as the negation of operator==.
    inline bool operator!=(thread::id x, thread::id y) noexcept
    {
        const bool same = (x == y);
        return !same;
    }

    // Orders ids by the value of the native handle they wrap.
    inline bool operator<(thread::id x, thread::id y) noexcept
    {
        const thread::native_handle_type lhs = x.__cpp_thread_t;
        const thread::native_handle_type rhs = y.__cpp_thread_t;
        return lhs < rhs;
    }

    // x <= y holds exactly when y is not ordered before x.
    inline bool operator<=(thread::id x, thread::id y) noexcept
    {
        const bool y_before_x = (y < x);
        return !y_before_x;
    }

    // x > y holds exactly when x <= y does not.
    inline bool operator>(thread::id x, thread::id y) noexcept
    {
        if (x <= y)
            return false;
        return true;
    }

    // x >= y holds exactly when x is not ordered before y.
    inline bool operator>=(thread::id x, thread::id y) noexcept
    {
        const bool x_before_y = (x < y);
        return !x_before_y;
    }

    // Streams a thread::id: the raw handle value for a real thread, or a
    // fixed description when the id equals the default-constructed
    // ("no thread") id. Returns the stream to allow chaining.
    template <class charT, class traits>
    inline basic_ostream<charT, traits>&
    operator<<(basic_ostream<charT, traits>& out, thread::id id)
    {
        const bool no_thread = (id == thread::id());
        if (!no_thread)
        {
            out << id.__cpp_thread_t;
            return out;
        }
        out << "thread::id of a non-executing thread";
        return out;
    }

    // std::hash specialization for thread::id: hashes the wrapped native
    // handle with the standard hash for that handle type.
    template <>
    struct hash<thread::id>
    {
        using result_type = size_t;
        using argument_type = thread::id;

        size_t operator()(const thread::id& id) const noexcept
        {
            hash<rt_cpp_thread_t> native_hasher;
            return native_hasher(id.__cpp_thread_t);
        }
    };

    namespace this_thread 
    {
        // Returns the id of the calling thread, built from pthread_self().
        inline thread::id get_id() noexcept
        {
            const rt_cpp_thread_t self_handle = pthread_self();
            return thread::id(self_handle);
        }

        // Offers the processor to other threads by calling sched_yield();
        // its return value is deliberately ignored.
        inline void yield() noexcept
        {
            (void)sched_yield();
        }

        // Blocks the calling thread for at least rel_time.
        //
        // Zero or negative durations return immediately. The delay is passed
        // to rt_thread_mdelay() in whole milliseconds, so the request is
        // rounded UP to the next millisecond: plain duration_cast truncates,
        // which would turn e.g. 500us into a 0 ms delay (no sleep at all)
        // and make fractional requests finish early.
        //
        // NOTE(review): the millisecond count is passed straight to
        // rt_thread_mdelay(); very long durations may exceed the range of
        // its parameter type -- confirm against the RT-Thread API.
        template <class Rep, class Period>
        inline void sleep_for(const chrono::duration<Rep, Period>& rel_time)
        {
            if (rel_time <= rel_time.zero()) // nothing to wait for
                return;
            auto milli_secs = chrono::duration_cast<chrono::milliseconds>(rel_time);
            // duration_cast truncated toward zero; add one tick so the sleep
            // never ends before the requested time.
            if (milli_secs < rel_time)
                ++milli_secs;
            // the precision is limited by rt-thread thread API
            rt_thread_mdelay(milli_secs.count());
        }

        // Sleeps until abs_time as measured by Clock. The clock is sampled
        // once; if abs_time is not in the future the call returns at once,
        // otherwise the remaining time is handed to sleep_for().
        template <class Clock, class Duration>
        inline void sleep_until(const chrono::time_point<Clock, Duration>& abs_time)
        {
            const auto current = Clock::now();
            if (!(abs_time > current))
                return;
            sleep_for(abs_time - current);
        }

    }
}