OpenWalnut  1.4.0
WThreadedFunction.h
00001 //---------------------------------------------------------------------------
00002 //
00003 // Project: OpenWalnut ( http://www.openwalnut.org )
00004 //
00005 // Copyright 2009 OpenWalnut Community, BSV@Uni-Leipzig and CNCF@MPI-CBS
00006 // For more information see http://www.openwalnut.org/copying
00007 //
00008 // This file is part of OpenWalnut.
00009 //
00010 // OpenWalnut is free software: you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as published by
00012 // the Free Software Foundation, either version 3 of the License, or
00013 // (at your option) any later version.
00014 //
00015 // OpenWalnut is distributed in the hope that it will be useful,
00016 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00018 // GNU Lesser General Public License for more details.
00019 //
00020 // You should have received a copy of the GNU Lesser General Public License
00021 // along with OpenWalnut. If not, see <http://www.gnu.org/licenses/>.
00022 //
00023 //---------------------------------------------------------------------------
00024 
00025 #ifndef WTHREADEDFUNCTION_H
00026 #define WTHREADEDFUNCTION_H
00027 
00028 #include <memory.h>
00029 #include <iostream>
00030 
00031 #include <string>
00032 #include <vector>
00033 #include <boost/thread.hpp>
00034 
00035 #include "WAssert.h"
00036 #include "WWorkerThread.h"
00037 #include "WSharedObject.h"
00038 
00039 
/**
 * An enum indicating the status of a multithreaded computation.
 *
 * Note: the per-value comments use Doxygen's "//!<" form, which documents the
 * enumerator on the same line (plain "//!" would attach to the following one).
 */
enum WThreadedFunctionStatus
{
    W_THREADS_INITIALIZED,      //!< the status after constructing the function
    W_THREADS_RUNNING,          //!< the threads were started
    W_THREADS_STOP_REQUESTED,   //!< a stop was requested and not all threads have stopped yet
    W_THREADS_ABORTED,          //!< at least one thread was aborted due to a stop request or an exception
    W_THREADS_FINISHED          //!< all threads completed their work successfully
};
00051 
/**
 * An enum indicating the number of threads used.
 */
enum WThreadedFunctionNbThreads
{
    W_AUTOMATIC_NB_THREADS = 0      //!< Let WThreadedFunction choose: the smallest power of two that is at least half the hardware threads, at least 1 and at most 1024
};
00059 
/**
 * \class WThreadedFunctionBase
 *
 * Abstract, non-template base class of the threaded functions (see WThreadedFunction below).
 * It holds the state that does not depend on the function type: the computation status,
 * the "all threads done" condition and the exception signal. run(), stop() and wait() are
 * pure virtual and implemented by the derived class.
 */
class WThreadedFunctionBase // NOLINT
{
    //! a type for exception signals
    typedef boost::signals2::signal< void ( WException const& ) > ExceptionSignal;

public:
    //! a type for exception callbacks
    typedef boost::function< void ( WException const& ) > ExceptionFunction;

    /**
     * Standard constructor.
     */
    WThreadedFunctionBase();

    /**
     * Virtual destructor, so that derived threaded functions are destroyed correctly
     * when deleted through a pointer to this base class.
     *
     * \note This base class does not own any threads itself; stopping the worker threads
     * is the responsibility of the derived class.
     */
    virtual ~WThreadedFunctionBase();

    /**
     * Starts the threads.
     */
    virtual void run() = 0;

    /**
     * Request all threads to stop. Returns immediately, so you might
     * have to wait() for the threads to actually finish.
     */
    virtual void stop() = 0;

    /**
     * Wait for all threads to stop.
     */
    virtual void wait() = 0;

    /**
     * Get the status of the threads.
     *
     * \return The current status (see WThreadedFunctionStatus); presumably read from m_status.
     */
    WThreadedFunctionStatus status();

    /**
     * Returns a condition that gets notified when all threads have finished, be it
     * successfully, after a stop request or after an exception.
     *
     * \return The condition indicating all threads are done (presumably m_doneCondition).
     */
    boost::shared_ptr< WCondition > getThreadsDoneCondition();

    /**
     * Subscribe a function to the exception signal.
     *
     * \note The derived WThreadedFunction fires this signal from the worker thread that caught
     * the exception, so the subscribed function must be thread-safe.
     *
     * \param func The function to subscribe.
     */
    void subscribeExceptionSignal( ExceptionFunction func );

protected:
    /**
     * WThreadedFunctionBase is non-copyable, so the copy constructor is declared but not implemented.
     */
    WThreadedFunctionBase( WThreadedFunctionBase const& ); // NOLINT

    /**
     * WThreadedFunctionBase is non-copyable, so the copy assignment operator is declared but not implemented.
     *
     * \return this function
     */
    WThreadedFunctionBase& operator = ( WThreadedFunctionBase const& );

    //! a condition that gets notified when the work is complete (notified by the derived class)
    boost::shared_ptr< WCondition > m_doneCondition;

    //! a signal for exceptions thrown by the threaded function
    ExceptionSignal m_exceptionSignal;

    //! the current status, guarded by WSharedObject's read/write tickets
    WSharedObject< WThreadedFunctionStatus > m_status;
};
00145 
00146 /**
00147  * \class WThreadedFunction
00148  *
00149  * Creates threads that computes a function in a multithreaded fashion. The template parameter
00150  * is an object that provides a function to execute. The following function needs to be implemented:
00151  *
00152  * void operator ( std::size_t id, std::size_t mx, WBoolFlag const& s );
00153  *
00154  * Here, 'id' is the number of the thread currently executing the function, ranging from
00155  * 0 to mx - 1, where 'mx' is the number of threads running. 's' is a flag that indicates
00156  * if the execution should be stopped. Make sure to check the flag often, so that the threads
00157  * can be stopped when needed.
00158  *
00159  * This class itself is NOT thread-safe, do not access it from different threads simultaneously.
00160  * Also, make sure any resources used by your function are accessed in a threadsafe manner,
00161  * as all threads share the same function object.
00162  *
00163  * Any exception thrown by your function will be caught and forwarded via the exception
00164  * signal. Beware that the signal function will be called in the executing threads, as opposed
00165  * to in your module thread. This means that the exception handler bound to the exception
00166  * signal must be threadsafe too.
00167  *
00168  * The status of the execution can be checked via the status() function. Also, when all threads
00169  * finish (due to throwing exceptions or actually successfully finishing computation ), a condition
00170  * will be notified.
00171  *
00172  * \ingroup common
00173  */
00174 template< class Function_T >
00175 class WThreadedFunction : public WThreadedFunctionBase
00176 {
00177     //! a type for exception signals
00178     typedef boost::signals2::signal< void ( WException const& ) > ExceptionSignal;
00179 
00180 public:
00181     //! a type for exception callbacks
00182     typedef boost::function< void ( WException const& ) > ExceptionFunction;
00183 
00184     /**
00185      * Creates the thread pool with a given number of threads.
00186      *
00187      * \param numThreads The number of threads to create.
00188      * \param function The function object.
00189      *
00190      * \note If the number of threads equals 0, a good number of threads will be determined by the threadpool.
00191      */
00192     WThreadedFunction( std::size_t numThreads, boost::shared_ptr< Function_T > function );
00193 
00194     /**
00195      * Destroys the thread pool and stops all threads, if any one of them is still running.
00196      *
00197      * \note Of course, the client has to make sure the threads do not work endlessly on a single job.
00198      */
00199     virtual ~WThreadedFunction();
00200 
00201     /**
00202      * Starts the threads.
00203      */
00204     virtual void run();
00205 
00206     /**
00207      * Request all threads to stop. Returns immediately, so you might
00208      * have to wait() for the threads to actually finish.
00209      */
00210     virtual void stop();
00211 
00212     /**
00213      * Wait for all threads to stop.
00214      */
00215     virtual void wait();
00216 
00217 private:
00218     /**
00219      * WThreadedFunction is non-copyable, so the copy constructor is not implemented.
00220      */
00221     WThreadedFunction( WThreadedFunction const& ); // NOLINT
00222 
00223     /**
00224      * WThreadedFunction is non-copyable, so the copy operator is not implemented.
00225      *
00226      * \return this function
00227      */
00228     WThreadedFunction& operator = ( WThreadedFunction const& );
00229 
00230     /**
00231      * This function gets subscribed to the threads' stop signals.
00232      */
00233     void handleThreadDone();
00234 
00235     /**
00236      * This function handles exceptions thrown in the worker threads.
00237      *
00238      * \param e The exception that was thrown.
00239      */
00240     void handleThreadException( WException const& e );
00241 
00242     //! the number of threads to manage
00243     std::size_t m_numThreads;
00244 
00245     //! the threads
00246     // use shared_ptr here, because WWorkerThread is non-copyable
00247     std::vector< boost::shared_ptr< WWorkerThread< Function_T > > > m_threads;
00248 
00249     //! the function object
00250     boost::shared_ptr< Function_T > m_func;
00251 
00252     //! a counter that keeps track of how many threads have finished
00253     WSharedObject< std::size_t > m_threadsDone;
00254 };
00255 
00256 template< class Function_T >
00257 WThreadedFunction< Function_T >::WThreadedFunction( std::size_t numThreads, boost::shared_ptr< Function_T > function )
00258     : WThreadedFunctionBase(),
00259       m_numThreads( numThreads ),
00260       m_threads(),
00261       m_func( function ),
00262       m_threadsDone()
00263 {
00264     if( !m_func )
00265     {
00266         throw WException( std::string( "No valid thread function pointer." ) );
00267     }
00268 
00269     // find a suitable number of threads
00270     if( m_numThreads == W_AUTOMATIC_NB_THREADS )
00271     {
00272         m_numThreads = 1;
00273         while( m_numThreads < boost::thread::hardware_concurrency() / 2 && m_numThreads < 1024 )
00274         {
00275             m_numThreads *= 2;
00276         }
00277     }
00278 
00279     // set number of finished threads to 0
00280     m_threadsDone.getWriteTicket()->get() = 0;
00281 
00282     // create threads
00283     for( std::size_t k = 0; k < m_numThreads; ++k )
00284     {
00285         boost::shared_ptr< WWorkerThread< Function_T > > t( new WWorkerThread< Function_T >( m_func, k, m_numThreads ) );
00286         t->subscribeStopSignal( boost::bind( &WThreadedFunction::handleThreadDone, this ) );
00287         t->subscribeExceptionSignal( boost::bind( &WThreadedFunction::handleThreadException, this, _1 ) );
00288         m_threads.push_back( t );
00289     }
00290 }
00291 
00292 template< class Function_T >
00293 WThreadedFunction< Function_T >::~WThreadedFunction()
00294 {
00295     stop();
00296 }
00297 
00298 template< class Function_T >
00299 void WThreadedFunction< Function_T >::run()
00300 {
00301     // set the number of finished threads to 0
00302     m_threadsDone.getWriteTicket()->get() = 0;
00303     // change status
00304     m_status.getWriteTicket()->get() = W_THREADS_RUNNING;
00305     // start threads
00306     for( std::size_t k = 0; k < m_numThreads; ++k )
00307     {
00308         m_threads[ k ]->run();
00309     }
00310 }
00311 
00312 template< class Function_T >
00313 void WThreadedFunction< Function_T >::stop()
00314 {
00315     // change status
00316     m_status.getWriteTicket()->get() = W_THREADS_STOP_REQUESTED;
00317 
00318     typename std::vector< boost::shared_ptr< WWorkerThread< Function_T > > >::iterator it;
00319     // tell the threads to stop
00320     for( it = m_threads.begin(); it != m_threads.end(); ++it )
00321     {
00322         ( *it )->requestStop();
00323     }
00324 }
00325 
00326 template< class Function_T >
00327 void WThreadedFunction< Function_T >::wait()
00328 {
00329     typename std::vector< boost::shared_ptr< WWorkerThread< Function_T > > >::iterator it;
00330     // wait for the threads to stop
00331     for( it = m_threads.begin(); it != m_threads.end(); ++it )
00332     {
00333         ( *it )->wait();
00334     }
00335 }
00336 
00337 template< class Function_T >
00338 void WThreadedFunction< Function_T >::handleThreadDone()
00339 {
00340     typedef typename WSharedObject< std::size_t >::WriteTicket WT;
00341 
00342     WT t = m_threadsDone.getWriteTicket();
00343     WAssert( t->get() < m_numThreads, "" );
00344     ++t->get();
00345     std::size_t k = t->get();
00346     t = WT();
00347 
00348     if( m_numThreads == k )
00349     {
00350         typedef typename WSharedObject< WThreadedFunctionStatus >::WriteTicket ST;
00351         ST s = m_status.getWriteTicket();
00352         if( s->get() == W_THREADS_RUNNING )
00353         {
00354             s->get() = W_THREADS_FINISHED;
00355         }
00356         else if( s->get() == W_THREADS_STOP_REQUESTED )
00357         {
00358             s->get() = W_THREADS_ABORTED;
00359         }
00360         else
00361         {
00362             throw WException( std::string( "Invalid status change." ) );
00363         }
00364         m_doneCondition->notify();
00365     }
00366 }
00367 
00368 template< class Function_T >
00369 void WThreadedFunction< Function_T >::handleThreadException( WException const& e )
00370 {
00371     // change status
00372     typedef typename WSharedObject< WThreadedFunctionStatus >::WriteTicket WT;
00373     WT w = m_status.getWriteTicket();
00374     WAssert( w->get() != W_THREADS_FINISHED &&
00375              w->get() != W_THREADS_ABORTED, "" );
00376     if( w->get() == W_THREADS_RUNNING )
00377     {
00378         w->get() = W_THREADS_STOP_REQUESTED;
00379     }
00380     // force destruction of the write ticket
00381     w = WT();
00382     // update the number of finished threads
00383     handleThreadDone();
00384 
00385     m_exceptionSignal( e );
00386 }
00387 
00388 #endif  // WTHREADEDFUNCTION_H