@@ -552,7 +552,7 @@ def _collect_written_timers_and_add_to_deferred_inputs(
552
552
stage ,
553
553
get_buffer_callable ,
554
554
deferred_inputs # type: DefaultDict[str, _ListBuffer]
555
- ):
555
+ ):
556
556
557
557
for transform_id , timer_writes in stage .timer_pcollections :
558
558
@@ -588,7 +588,7 @@ def _add_residuals_and_channel_splits_to_deferred_inputs(
588
588
input_for_callable ,
589
589
last_sent ,
590
590
deferred_inputs # type: DefaultDict[str, _ListBuffer]
591
- ):
591
+ ):
592
592
593
593
prev_stops = {} # type: Dict[str, int]
594
594
for split in splits :
@@ -1079,7 +1079,7 @@ def __init__(self,
1079
1079
control_handler ,
1080
1080
data_plane_handler ,
1081
1081
state , # type: FnApiRunner.StateServicer
1082
- provision_info # type: fn_api_runner.ExtendedProvisionInfo
1082
+ provision_info # type: Optional[ fn_api_runner.ExtendedProvisionInfo]
1083
1083
):
1084
1084
"""Initialize a WorkerHandler.
1085
1085
@@ -1120,14 +1120,23 @@ def logging_api_service_descriptor(self):
1120
1120
raise NotImplementedError
1121
1121
1122
1122
@classmethod
1123
- def register_environment (cls , urn , payload_type ):
1123
+ def register_environment (cls ,
1124
+ urn , # type: str
1125
+ payload_type # type: Optional[Type[T]]
1126
+ ):
1127
+ # type: (...) -> Callable[[Callable[[T, FnApiRunner.StateServicer, Optional[fn_api_runner.ExtendedProvisionInfo], GrpcServer], WorkerHandler]], Callable[[T, FnApiRunner.StateServicer, Optional[fn_api_runner.ExtendedProvisionInfo], GrpcServer], WorkerHandler]]
1124
1128
def wrapper (constructor ):
1125
1129
cls ._registered_environments [urn ] = constructor , payload_type
1126
1130
return constructor
1127
1131
return wrapper
1128
1132
1129
1133
@classmethod
1130
- def create (cls , environment , state , provision_info , grpc_server ):
1134
+ def create (cls ,
1135
+ environment , # type: beam_runner_api_pb2.Environment
1136
+ state , # type: FnApiRunner.StateServicer
1137
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1138
+ grpc_server # type: GrpcServer
1139
+ ):
1131
1140
# type: (...) -> WorkerHandler
1132
1141
constructor , payload_type = cls ._registered_environments [environment .urn ]
1133
1142
return constructor (
@@ -1141,8 +1150,12 @@ def create(cls, environment, state, provision_info, grpc_server):
1141
1150
class EmbeddedWorkerHandler (WorkerHandler ):
1142
1151
"""An in-memory worker_handler for fn API control, state and data planes."""
1143
1152
1144
- def __init__ (self , unused_payload , state , provision_info ,
1145
- unused_grpc_server = None ):
1153
+ def __init__ (self ,
1154
+ unused_payload , # type: None
1155
+ state , # type: sdk_worker.StateHandler
1156
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1157
+ unused_grpc_server = None
1158
+ ):
1146
1159
super (EmbeddedWorkerHandler , self ).__init__ (
1147
1160
self , data_plane .InMemoryDataChannel (), state , provision_info )
1148
1161
self .control_conn = self # type: ignore # need Protocol to describe this
@@ -1228,7 +1241,11 @@ class GrpcServer(object):
1228
1241
1229
1242
_DEFAULT_SHUTDOWN_TIMEOUT_SECS = 5
1230
1243
1231
- def __init__ (self , state , provision_info , max_workers ):
1244
+ def __init__ (self ,
1245
+ state ,
1246
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1247
+ max_workers # type: int
1248
+ ):
1232
1249
self .state = state
1233
1250
self .provision_info = provision_info
1234
1251
self .max_workers = max_workers
@@ -1269,7 +1286,7 @@ def __init__(self, state, provision_info, max_workers):
1269
1286
1270
1287
if self .provision_info .artifact_staging_dir :
1271
1288
service = artifact_service .BeamFilesystemArtifactService (
1272
- self .provision_info .artifact_staging_dir )
1289
+ self .provision_info .artifact_staging_dir ) # type: beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer
1273
1290
else :
1274
1291
service = EmptyArtifactRetrievalService ()
1275
1292
beam_artifact_api_pb2_grpc .add_ArtifactRetrievalServiceServicer_to_server (
@@ -1316,8 +1333,8 @@ class GrpcWorkerHandler(WorkerHandler):
1316
1333
"""An grpc based worker_handler for fn API control, state and data planes."""
1317
1334
1318
1335
def __init__ (self ,
1319
- state ,
1320
- provision_info ,
1336
+ state , # type: FnApiRunner.StateServicer
1337
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1321
1338
grpc_server # type: GrpcServer
1322
1339
):
1323
1340
self ._grpc_server = grpc_server
@@ -1363,7 +1380,12 @@ def localhost_from_worker(self):
1363
1380
@WorkerHandler .register_environment (
1364
1381
common_urns .environments .EXTERNAL .urn , beam_runner_api_pb2 .ExternalPayload )
1365
1382
class ExternalWorkerHandler (GrpcWorkerHandler ):
1366
- def __init__ (self , external_payload , state , provision_info , grpc_server ):
1383
+ def __init__ (self ,
1384
+ external_payload , # type: beam_runner_api_pb2.ExternalPayload
1385
+ state , # type: FnApiRunner.StateServicer
1386
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1387
+ grpc_server # type: GrpcServer
1388
+ ):
1367
1389
super (ExternalWorkerHandler , self ).__init__ (state , provision_info ,
1368
1390
grpc_server )
1369
1391
self ._external_payload = external_payload
@@ -1388,7 +1410,12 @@ def stop_worker(self):
1388
1410
1389
1411
@WorkerHandler .register_environment (python_urns .EMBEDDED_PYTHON_GRPC , bytes )
1390
1412
class EmbeddedGrpcWorkerHandler (GrpcWorkerHandler ):
1391
- def __init__ (self , num_workers_payload , state , provision_info , grpc_server ):
1413
+ def __init__ (self ,
1414
+ num_workers_payload ,
1415
+ state , # type: FnApiRunner.StateServicer
1416
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1417
+ grpc_server # type: GrpcServer
1418
+ ):
1392
1419
super (EmbeddedGrpcWorkerHandler , self ).__init__ (state , provision_info ,
1393
1420
grpc_server )
1394
1421
self ._num_threads = int (num_workers_payload ) if num_workers_payload else 1
@@ -1413,7 +1440,12 @@ def stop_worker(self):
1413
1440
1414
1441
@WorkerHandler .register_environment (python_urns .SUBPROCESS_SDK , bytes )
1415
1442
class SubprocessSdkWorkerHandler (GrpcWorkerHandler ):
1416
- def __init__ (self , worker_command_line , state , provision_info , grpc_server ):
1443
+ def __init__ (self ,
1444
+ worker_command_line , # type: bytes
1445
+ state , # type: FnApiRunner.StateServicer
1446
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1447
+ grpc_server # type: GrpcServer
1448
+ ):
1417
1449
super (SubprocessSdkWorkerHandler , self ).__init__ (state , provision_info ,
1418
1450
grpc_server )
1419
1451
self ._worker_command_line = worker_command_line
@@ -1433,7 +1465,12 @@ def stop_worker(self):
1433
1465
@WorkerHandler .register_environment (common_urns .environments .DOCKER .urn ,
1434
1466
beam_runner_api_pb2 .DockerPayload )
1435
1467
class DockerSdkWorkerHandler (GrpcWorkerHandler ):
1436
- def __init__ (self , payload , state , provision_info , grpc_server ):
1468
+ def __init__ (self ,
1469
+ payload , # type: beam_runner_api_pb2.DockerPayload
1470
+ state , # type: FnApiRunner.StateServicer
1471
+ provision_info , # type: Optional[fn_api_runner.ExtendedProvisionInfo]
1472
+ grpc_server # type: GrpcServer
1473
+ ):
1437
1474
super (DockerSdkWorkerHandler , self ).__init__ (state , provision_info ,
1438
1475
grpc_server )
1439
1476
self ._container_image = payload .container_image
@@ -1557,7 +1594,10 @@ def close_all(self):
1557
1594
1558
1595
1559
1596
class ExtendedProvisionInfo (object ):
1560
- def __init__ (self , provision_info = None , artifact_staging_dir = None ):
1597
+ def __init__ (self ,
1598
+ provision_info = None , # type: Optional[beam_provision_api_pb2.ProvisionInfo]
1599
+ artifact_staging_dir = None
1600
+ ):
1561
1601
self .provision_info = (
1562
1602
provision_info or beam_provision_api_pb2 .ProvisionInfo ())
1563
1603
self .artifact_staging_dir = artifact_staging_dir
@@ -1813,7 +1853,7 @@ def __init__(
1813
1853
def process_bundle (self ,
1814
1854
inputs , # type: Mapping[str, _ListBuffer]
1815
1855
expected_outputs
1816
- ):
1856
+ ):
1817
1857
# type: (...) -> Tuple[beam_fn_api_pb2.InstructionResponse, List[beam_fn_api_pb2.ProcessBundleSplitResponse]]
1818
1858
part_inputs = [{} for _ in range (self ._num_workers )] # type: List[Dict[str, List]]
1819
1859
for name , input in inputs .items ():
0 commit comments