1
1
"""Synchronization primitives."""
2
2
3
- __all__ = ('Lock' , 'Event' , 'Condition' , 'Semaphore' , 'BoundedSemaphore' )
3
+ __all__ = ('Lock' , 'Event' , 'Condition' , 'Semaphore' , 'BoundedSemaphore' , 'Barrier' , 'BrokenBarrierError' )
4
4
5
5
import collections
6
6
@@ -418,3 +418,174 @@ def release(self):
418
418
if self ._value >= self ._bound_value :
419
419
raise ValueError ('BoundedSemaphore released too many times' )
420
420
super ().release ()
421
+
422
+
423
+ # A barrier class. Inspired in part by the pthread_barrier_* api and
424
+ # the CyclicBarrier class from Java. See
425
+ # http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and
426
+ # http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/
427
+ # CyclicBarrier.html
428
+ # for information.
429
+ # We maintain two main states, 'filling' and 'draining' enabling the barrier
430
+ # to be cyclic. Tasks are not allowed into it until it has fully drained
431
+ # since the previous cycle. In addition, a 'resetting' state exists which is
432
+ # similar to 'draining' except that tasks leave with a BrokenBarrierError,
433
+ # and a 'broken' state in which all tasks get the exception.
434
+
435
+ class Barrier (mixins ._LoopBoundMixin ):
436
+ """Asynchronous equivalent to threading.Barrier
437
+
438
+ Implements a Barrier.
439
+ Useful for synchronizing a fixed number of tasks at known synchronization
440
+ points. Tasks block on 'wait()' and are simultaneously awoken once they
441
+ have all made that call.
442
+ """
443
+
444
+ def __init__ (self , parties , action = None , * , loop = mixins ._marker ):
445
+ """Create a barrier, initialised to 'parties' tasks.
446
+ 'action' is a callable which, when supplied, will be called by one of
447
+ the tasks after they have all entered the barrier and just prior to
448
+ releasing them all.
449
+ """
450
+ super ().__init__ (loop = loop )
451
+ if parties < 1 :
452
+ raise ValueError ('parties must be > 0' )
453
+
454
+ self ._waiting = Event () # used notify all waiting tasks
455
+ self ._blocking = Event () # used block tasks while wainting tasks are draining or broken
456
+ self ._action = action
457
+ self ._parties = parties
458
+ self ._state = 0 # 0 filling, 1, draining, -1 resetting, -2 broken
459
+ self ._count = 0 # count waiting tasks
460
+
461
+ def __repr__ (self ):
462
+ res = super ().__repr__ ()
463
+ _wait = 'set' if self ._waiting .is_set () else 'unset'
464
+ _block = 'set' if self ._blocking .is_set () else 'unset'
465
+ extra = f'{ _wait } , count:{ self ._count } /{ self ._parties } , { _block } , state:{ self ._state } '
466
+ return f'<{ res [1 :- 1 ]} [{ extra } ]>'
467
+
468
+ async def wait (self ):
469
+ """Wait for the barrier.
470
+ When the specified number of tasks have started waiting, they are all
471
+ simultaneously awoken. If an 'action' was provided for the barrier, one
472
+ of the tasks will have executed that callback prior to returning.
473
+ Returns an individual index number from 0 to 'parties-1'.
474
+ """
475
+ await self ._block () # Block while the barrier drains or resets.
476
+ index = self ._count
477
+ self ._count += 1
478
+ try :
479
+ if index + 1 == self ._parties :
480
+ # We release the barrier
481
+ self ._release ()
482
+ else :
483
+ # We wait until someone releases us
484
+ await self ._wait ()
485
+ return index
486
+ finally :
487
+ self ._count -= 1
488
+ # Wake up any tasks waiting for barrier to drain.
489
+ self ._exit ()
490
+
491
+ # Block until the barrier is ready for us, or raise an exception
492
+ # if it is broken.
493
+ async def _block (self ):
494
+ if self ._state in (- 1 , 1 ):
495
+ # It is draining or resetting, wait until done
496
+ await self ._blocking .wait ()
497
+
498
+ #see if the barrier is in a broken state
499
+ if self ._state < 0 :
500
+ raise BrokenBarrierError
501
+ assert self ._state == 0 , repr (self )
502
+
503
+ # Optionally run the 'action' and release the tasks waiting
504
+ # in the barrier.
505
+ def _release (self ):
506
+ try :
507
+ if self ._action :
508
+ self ._action ()
509
+ # enter draining state
510
+ self ._state = 1
511
+ self ._blocking .clear ()
512
+ self ._waiting .set ()
513
+ except :
514
+ #an exception during the _action handler. Break and reraise
515
+ self ._state = - 2
516
+ self ._blocking .clear ()
517
+ self ._waiting .set ()
518
+ raise
519
+
520
+ # Wait in the barrier until we are released. Raise an exception
521
+ # if the barrier is reset or broken.
522
+ async def _wait (self ):
523
+ await self ._waiting .wait ()
524
+ # no timeout so
525
+ if self ._state < 0 :
526
+ raise BrokenBarrierError
527
+ assert self ._state == 1 , repr (self )
528
+
529
+ # If we are the last tasks to exit the barrier, signal any tasks
530
+ # waiting for the barrier to drain.
531
+ def _exit (self ):
532
+ if self ._count == 0 :
533
+ if self ._state == 1 :
534
+ self ._state = 0
535
+ elif self ._state == - 1 :
536
+ self ._state = 0
537
+ self ._waiting .clear ()
538
+ self ._blocking .set ()
539
+
540
+ # async def reset(self):
541
+ def reset (self ):
542
+ """Reset the barrier to the initial state.
543
+ Any tasks currently waiting will get the BrokenBarrier exception
544
+ raised.
545
+ """
546
+ if self ._count > 0 :
547
+ if self ._state in (0 , 1 ):
548
+ #reset the barrier, waking up tasks
549
+ self ._state = - 1
550
+ elif self ._state == - 2 :
551
+ #was broken, set it to reset state
552
+ #which clears when the last tasks exits
553
+ self ._state = - 1
554
+ self ._waiting .set ()
555
+ self ._blocking .clear ()
556
+ else :
557
+ self ._state = 0
558
+
559
+
560
+ # async def abort(self):
561
+ def abort (self ):
562
+ """Place the barrier into a 'broken' state.
563
+ Useful in case of error. Any currently waiting tasks and tasks
564
+ attempting to 'wait()' will have BrokenBarrierError raised.
565
+ """
566
+ self ._state = - 2
567
+ self ._waiting .set ()
568
+ self ._blocking .clear ()
569
+
570
+ @property
571
+ def parties (self ):
572
+ """Return the number of tasks required to trip the barrier."""
573
+ return self ._parties
574
+
575
+ @property
576
+ def n_waiting (self ):
577
+ """Return the number of tasks currently waiting at the barrier."""
578
+ # We don't need synchronization here since this is an ephemeral result
579
+ # anyway. It returns the correct value in the steady state.
580
+ if self ._state == 0 :
581
+ return self ._count
582
+ return 0
583
+
584
+ @property
585
+ def broken (self ):
586
+ """Return True if the barrier is in a broken state."""
587
+ return self ._state == - 2
588
+
589
+ # exception raised by the Barrier class
590
+ class BrokenBarrierError (RuntimeError ):
591
+ pass
0 commit comments