@@ -936,35 +936,35 @@ async def gather_tasks(self, n, coro):
936
936
937
937
async def test_barrier (self ):
938
938
barrier = asyncio .Barrier (self .N )
939
- self .assertTrue (barrier .filling )
939
+ self .assertTrue (barrier ._is_filling () )
940
940
with self .assertRaisesRegex (
941
941
TypeError ,
942
942
"object Barrier can't be used in 'await' expression" ,
943
943
):
944
944
await barrier
945
945
946
- self .assertTrue (barrier .filling )
946
+ self .assertTrue (barrier ._is_filling () )
947
947
948
948
async def test_repr (self ):
949
949
barrier = asyncio .Barrier (self .N )
950
950
951
951
self .assertTrue (RGX_REPR .match (repr (barrier )))
952
952
953
953
incr = 2
954
- barrier ._count_wait += incr
954
+ barrier ._count += incr
955
955
self .assertTrue (RGX_REPR .match (repr (barrier )))
956
956
self .assertTrue (f"wait:{ incr } /{ self .N } " in repr (barrier ))
957
- self .assertTrue (f"block:0/ { self . N } " in repr (barrier ))
957
+ self .assertTrue (f"block" not in repr (barrier ))
958
958
959
959
barrier ._set_filling ()
960
960
self .assertTrue (RGX_REPR .match (repr (barrier )))
961
961
self .assertTrue (repr (barrier ).endswith ('state:0]>' ))
962
962
963
963
barrier ._count_block += incr
964
- barrier ._count_wait -= incr
964
+ barrier ._count -= incr
965
965
self .assertTrue (RGX_REPR .match (repr (barrier )))
966
- self .assertTrue (f"wait:0/{ self .N } " in repr (barrier ))
967
966
self .assertTrue (f"block:{ incr } /{ self .N } " in repr (barrier ))
967
+ self .assertTrue (f"wait" not in repr (barrier ))
968
968
969
969
barrier ._set_draining ()
970
970
self .assertTrue (RGX_REPR .match (repr (barrier )))
@@ -992,7 +992,7 @@ async def test_barrier_parties(self):
992
992
self .assertRaises (ValueError , lambda : asyncio .Barrier (- 4 ))
993
993
994
994
async def test_context_manager (self ):
995
- self .N = 2
995
+ self .N = 3
996
996
barrier = asyncio .Barrier (self .N )
997
997
results = []
998
998
@@ -1004,50 +1004,58 @@ async def coro():
1004
1004
1005
1005
self .assertListEqual (sorted (results ), list (range (self .N )))
1006
1006
self .assertEqual (barrier .n_waiting , 0 )
1007
- self .assertFalse (barrier .broken )
1007
+ self .assertFalse (barrier ._is_broken () )
1008
1008
1009
1009
async def test_filling_one_task (self ):
1010
1010
barrier = asyncio .Barrier (1 )
1011
- results = []
1012
1011
1013
1012
async def f ():
1014
1013
async with barrier as i :
1015
- results . append ( i )
1014
+ return True
1016
1015
1017
- await f ()
1016
+ ret = await f ()
1018
1017
1019
- self .assertEqual (len (results ), 1 )
1020
- self .assertEqual (results , [0 ])
1018
+ self .assertTrue (ret )
1021
1019
self .assertEqual (barrier .n_waiting , 0 )
1022
- self .assertFalse (barrier .broken )
1020
+ self .assertFalse (barrier ._is_broken () )
1023
1021
1024
1022
async def test_filling_one_task_twice (self ):
1025
1023
barrier = asyncio .Barrier (1 )
1026
1024
1027
- r1 = asyncio .create_task (barrier .wait ())
1025
+ t1 = asyncio .create_task (barrier .wait ())
1028
1026
await asyncio .sleep (0 )
1029
1027
self .assertEqual (barrier .n_waiting , 0 )
1030
1028
1031
- r2 = asyncio .create_task (barrier .wait ())
1029
+ t2 = asyncio .create_task (barrier .wait ())
1032
1030
await asyncio .sleep (0 )
1033
1031
1034
- self .assertEqual (r1 .result (), r2 .result ())
1035
- self .assertEqual (r1 .done (), r2 .done ())
1032
+ self .assertEqual (t1 .result (), t2 .result ())
1033
+ self .assertEqual (t1 .done (), t2 .done ())
1036
1034
1037
1035
self .assertEqual (barrier .n_waiting , 0 )
1038
- self .assertFalse (barrier .broken )
1036
+ self .assertFalse (barrier ._is_broken () )
1039
1037
1040
1038
async def test_filling_task_by_task (self ):
1041
1039
self .N = 3
1042
1040
barrier = asyncio .Barrier (self .N )
1043
1041
1044
- async def coro ():
1045
- await barrier .wait ()
1042
+ t1 = asyncio .create_task (barrier .wait ())
1043
+ await asyncio .sleep (0 )
1044
+ self .assertEqual (barrier .n_waiting , 1 )
1045
+ self .assertTrue (barrier ._is_filling ())
1046
1046
1047
- await self .gather_tasks (self .N , coro )
1047
+ t2 = asyncio .create_task (barrier .wait ())
1048
+ await asyncio .sleep (0 )
1049
+ self .assertEqual (barrier .n_waiting , 2 )
1050
+ self .assertTrue (barrier ._is_filling ())
1051
+
1052
+ t3 = asyncio .create_task (barrier .wait ())
1053
+ await asyncio .sleep (0 )
1054
+
1055
+ await asyncio .wait ([t1 , t2 , t3 ])
1048
1056
1049
1057
self .assertEqual (barrier .n_waiting , 0 )
1050
- self .assertFalse (barrier .broken )
1058
+ self .assertFalse (barrier ._is_broken () )
1051
1059
1052
1060
async def test_filling_tasks_wait_twice (self ):
1053
1061
barrier = asyncio .Barrier (self .N )
@@ -1067,7 +1075,7 @@ async def coro():
1067
1075
self .assertEqual (results .count (False ), self .N )
1068
1076
1069
1077
self .assertEqual (barrier .n_waiting , 0 )
1070
- self .assertFalse (barrier .broken )
1078
+ self .assertFalse (barrier ._is_broken () )
1071
1079
1072
1080
async def test_filling_tasks_check_return_value (self ):
1073
1081
barrier = asyncio .Barrier (self .N )
@@ -1091,7 +1099,7 @@ async def coro():
1091
1099
self .assertListEqual (sorted (res ), list (range (self .N )))
1092
1100
1093
1101
self .assertEqual (barrier .n_waiting , 0 )
1094
- self .assertFalse (barrier .broken )
1102
+ self .assertFalse (barrier ._is_broken () )
1095
1103
1096
1104
async def test_draining_state (self ):
1097
1105
barrier = asyncio .Barrier (self .N )
@@ -1100,7 +1108,7 @@ async def test_draining_state(self):
1100
1108
async def coro ():
1101
1109
async with barrier :
1102
1110
# barrier state change to filling for the last task release
1103
- results .append (barrier .draining )
1111
+ results .append (barrier ._is_draining () )
1104
1112
1105
1113
await self .gather_tasks (self .N , coro )
1106
1114
@@ -1109,7 +1117,7 @@ async def coro():
1109
1117
self .assertTrue (all (results [:self .N - 1 ]))
1110
1118
1111
1119
self .assertEqual (barrier .n_waiting , 0 )
1112
- self .assertFalse (barrier .broken )
1120
+ self .assertFalse (barrier ._is_broken () )
1113
1121
1114
1122
async def test_blocking_tasks_while_draining (self ):
1115
1123
rewait = 2
@@ -1166,6 +1174,8 @@ async def coro():
1166
1174
t1 .cancel ()
1167
1175
await asyncio .sleep (0 )
1168
1176
self .assertEqual (barrier .n_waiting , 1 )
1177
+ with self .assertRaises (asyncio .CancelledError ):
1178
+ await t1
1169
1179
self .assertTrue (t1 .cancelled ())
1170
1180
1171
1181
t3 = asyncio .create_task (coro ())
@@ -1179,7 +1189,7 @@ async def coro():
1179
1189
self .assertTrue (all (results ))
1180
1190
1181
1191
self .assertEqual (barrier .n_waiting , 0 )
1182
- self .assertFalse (barrier .broken )
1192
+ self .assertFalse (barrier ._is_broken () )
1183
1193
1184
1194
async def test_draining_check_action (self ):
1185
1195
async def action_task ():
@@ -1200,7 +1210,37 @@ async def coro():
1200
1210
self .assertTrue (all (results ))
1201
1211
1202
1212
self .assertEqual (barrier .n_waiting , 0 )
1203
- self .assertFalse (barrier .broken )
1213
+ self .assertFalse (barrier ._is_broken ())
1214
+
1215
+ async def test_draining_check_cancelled_action (self ):
1216
+ async def action_coro ():
1217
+ asyncio .current_task ().cancel ()
1218
+ await asyncio .sleep (0 )
1219
+
1220
+ barrier = asyncio .Barrier (self .N , action = action_coro )
1221
+ results = []
1222
+ results1 = []
1223
+ results2 = []
1224
+
1225
+ async def coro ():
1226
+ try :
1227
+ await barrier .wait ()
1228
+ results .append (True )
1229
+ except asyncio .CancelledError :
1230
+ results1 .append (True )
1231
+ except asyncio .BrokenBarrierError :
1232
+ results2 .append (True )
1233
+
1234
+ #with self.assertRaises(asyncio.CancelledError):
1235
+ res , t = await self .gather_tasks (self .N , coro )
1236
+
1237
+ self .assertEqual (len (results ), 0 )
1238
+ self .assertEqual (results1 , [True ])
1239
+ self .assertEqual (len (results2 ), self .N - 1 )
1240
+ self .assertTrue (all (results2 ))
1241
+
1242
+ self .assertEqual (barrier .n_waiting , 0 )
1243
+ self .assertTrue (barrier ._is_broken ())
1204
1244
1205
1245
async def test_draining_check_error_on_action (self ):
1206
1246
ERROR = ZeroDivisionError
@@ -1214,21 +1254,20 @@ async def raise_except():
1214
1254
1215
1255
async def coro ():
1216
1256
try :
1217
- ret = await barrier .wait ()
1257
+ await barrier .wait ()
1218
1258
except ERROR :
1219
1259
results1 .append (False )
1220
1260
except asyncio .BrokenBarrierError :
1221
1261
results2 .append (True )
1222
1262
1223
1263
await self .gather_tasks (self .N , coro )
1224
1264
1225
- self .assertEqual (len (results1 ), 1 )
1226
- self .assertFalse (results1 [0 ])
1265
+ self .assertEqual (results1 , [False ])
1227
1266
self .assertEqual (len (results2 ), self .N - 1 )
1228
1267
self .assertTrue (all (results2 ))
1229
1268
1230
1269
self .assertEqual (barrier .n_waiting , 0 )
1231
- self .assertTrue (barrier .broken )
1270
+ self .assertTrue (barrier ._is_broken () )
1232
1271
1233
1272
async def test_reset_barrier (self ):
1234
1273
barrier = asyncio .Barrier (1 )
@@ -1237,7 +1276,7 @@ async def test_reset_barrier(self):
1237
1276
await asyncio .sleep (0 )
1238
1277
1239
1278
self .assertEqual (barrier .n_waiting , 0 )
1240
- self .assertFalse (barrier .broken )
1279
+ self .assertFalse (barrier ._is_broken () )
1241
1280
1242
1281
async def test_reset_barrier_while_tasks_waiting (self ):
1243
1282
barrier = asyncio .Barrier (self .N )
@@ -1263,8 +1302,8 @@ async def coro_reset():
1263
1302
self .assertEqual (len (results ), self .N - 1 )
1264
1303
self .assertTrue (all (results ))
1265
1304
self .assertEqual (barrier .n_waiting , 0 )
1266
- self .assertFalse (barrier .resetting )
1267
- self .assertFalse (barrier .broken )
1305
+ self .assertFalse (barrier ._is_resetting () )
1306
+ self .assertFalse (barrier ._is_broken () )
1268
1307
1269
1308
async def test_reset_barrier_when_tasks_half_draining (self ):
1270
1309
barrier = asyncio .Barrier (self .N )
@@ -1287,8 +1326,8 @@ async def coro():
1287
1326
1288
1327
self .assertEqual (results1 , [True ]* rest_of_tasks )
1289
1328
self .assertEqual (barrier .n_waiting , 0 )
1290
- self .assertFalse (barrier .resetting )
1291
- self .assertFalse (barrier .broken )
1329
+ self .assertFalse (barrier ._is_resetting () )
1330
+ self .assertFalse (barrier ._is_broken () )
1292
1331
1293
1332
async def test_reset_barrier_when_tasks_half_draining_half_blocking (self ):
1294
1333
barrier = asyncio .Barrier (self .N )
@@ -1324,8 +1363,8 @@ async def coro():
1324
1363
self .assertEqual (results1 , [True ]* blocking_tasks )
1325
1364
self .assertEqual (results2 , [])
1326
1365
self .assertEqual (barrier .n_waiting , 0 )
1327
- self .assertFalse (barrier .resetting )
1328
- self .assertFalse (barrier .broken )
1366
+ self .assertFalse (barrier ._is_resetting () )
1367
+ self .assertFalse (barrier ._is_broken () )
1329
1368
1330
1369
async def test_reset_barrier_while_tasks_waiting_and_waiting_again (self ):
1331
1370
barrier = asyncio .Barrier (self .N )
@@ -1356,7 +1395,7 @@ async def coro2():
1356
1395
1357
1396
await asyncio .gather (* tasks )
1358
1397
1359
- self .assertFalse (barrier .broken )
1398
+ self .assertFalse (barrier ._is_broken () )
1360
1399
self .assertEqual (len (results1 ), self .N - 1 )
1361
1400
self .assertTrue (all (results1 ))
1362
1401
self .assertEqual (len (results2 ), self .N )
@@ -1401,7 +1440,7 @@ async def coro():
1401
1440
1402
1441
await self .gather_tasks (self .N , coro )
1403
1442
1404
- self .assertFalse (barrier .broken )
1443
+ self .assertFalse (barrier ._is_broken () )
1405
1444
self .assertTrue (all (results1 ))
1406
1445
self .assertEqual (len (results1 ), self .N - 1 )
1407
1446
self .assertEqual (len (results2 ), 0 )
@@ -1417,7 +1456,7 @@ async def test_abort_barrier(self):
1417
1456
await asyncio .sleep (0 )
1418
1457
1419
1458
self .assertEqual (barrier .n_waiting , 0 )
1420
- self .assertTrue (barrier .broken )
1459
+ self .assertTrue (barrier ._is_broken () )
1421
1460
1422
1461
async def test_abort_barrier_when_tasks_half_draining_half_blocking (self ):
1423
1462
barrier = asyncio .Barrier (self .N )
@@ -1444,11 +1483,11 @@ async def coro():
1444
1483
1445
1484
await self .gather_tasks (self .N , coro )
1446
1485
1447
- self .assertTrue (barrier .broken )
1486
+ self .assertTrue (barrier ._is_broken () )
1448
1487
self .assertEqual (results1 , [True ]* blocking_tasks )
1449
1488
self .assertEqual (results2 , [True ]* (self .N - blocking_tasks - 1 ))
1450
1489
self .assertEqual (barrier .n_waiting , 0 )
1451
- self .assertFalse (barrier .resetting )
1490
+ self .assertFalse (barrier ._is_resetting () )
1452
1491
1453
1492
async def test_abort_barrier_when_exception (self ):
1454
1493
# test from threading.Barrier: see `lock_tests.test_reset`
@@ -1470,7 +1509,7 @@ async def coro():
1470
1509
1471
1510
await self .gather_tasks (self .N , coro )
1472
1511
1473
- self .assertTrue (barrier .broken )
1512
+ self .assertTrue (barrier ._is_broken () )
1474
1513
self .assertEqual (len (results1 ), 0 )
1475
1514
self .assertEqual (len (results2 ), self .N - 1 )
1476
1515
self .assertTrue (all (results2 ))
@@ -1508,7 +1547,7 @@ async def coro():
1508
1547
1509
1548
await self .gather_tasks (self .N , coro )
1510
1549
1511
- self .assertFalse (barrier1 .broken )
1550
+ self .assertFalse (barrier1 ._is_broken () )
1512
1551
self .assertEqual (len (results1 ), 0 )
1513
1552
self .assertEqual (len (results2 ), self .N - 1 )
1514
1553
self .assertTrue (all (results2 ))
0 commit comments