22
22
SOFTWARE.
23
23
"""
24
24
import logging
25
- from typing import (TYPE_CHECKING , Callable , Dict , Iterator , Optional , Tuple ,
26
- Type , TypeVar )
25
+ from typing import (TYPE_CHECKING , Callable , Dict , Generic , Iterator , Optional ,
26
+ Tuple , Type , TypeVar , Union , overload )
27
27
28
28
from .errors import ClientError
29
29
from .node import Node
35
35
_log = logging .getLogger (__name__ )
36
36
37
37
PlayerT = TypeVar ('PlayerT' , bound = BasePlayer )
38
+ CustomPlayerT = TypeVar ('CustomPlayerT' , bound = BasePlayer )
38
39
39
40
40
- class PlayerManager :
41
+ class PlayerManager ( Generic [ PlayerT ]) :
41
42
"""
42
43
Represents the player manager that contains all the players.
43
44
@@ -61,22 +62,22 @@ def __init__(self, client, player: Type[PlayerT]):
61
62
62
63
self .client : 'Client' = client
63
64
self ._player_cls : Type [PlayerT ] = player
64
- self .players : Dict [int , BasePlayer ] = {}
65
+ self .players : Dict [int , PlayerT ] = {}
65
66
66
67
def __len__ (self ) -> int :
67
68
return len (self .players )
68
69
69
- def __iter__ (self ) -> Iterator [Tuple [int , BasePlayer ]]:
70
+ def __iter__ (self ) -> Iterator [Tuple [int , PlayerT ]]:
70
71
""" Returns an iterator that yields a tuple of (guild_id, player). """
71
72
for guild_id , player in self .players .items ():
72
73
yield guild_id , player
73
74
74
- def values (self ) -> Iterator [BasePlayer ]:
75
+ def values (self ) -> Iterator [PlayerT ]:
75
76
""" Returns an iterator that yields only values. """
76
77
for player in self .players .values ():
77
78
yield player
78
79
79
- def find_all (self , predicate : Optional [Callable [[BasePlayer ], bool ]] = None ):
80
+ def find_all (self , predicate : Optional [Callable [[PlayerT ], bool ]] = None ):
80
81
"""
81
82
Returns a list of players that match the given predicate.
82
83
@@ -96,7 +97,7 @@ def find_all(self, predicate: Optional[Callable[[BasePlayer], bool]] = None):
96
97
97
98
return [p for p in self .players .values () if bool (predicate (p ))]
98
99
99
- def get (self , guild_id : int ) -> Optional [BasePlayer ]:
100
+ def get (self , guild_id : int ) -> Optional [PlayerT ]:
100
101
"""
101
102
Gets a player from cache.
102
103
@@ -126,13 +127,32 @@ def remove(self, guild_id: int):
126
127
player = self .players .pop (guild_id )
127
128
player .cleanup ()
128
129
130
+ @overload
131
+ def create (self ,
132
+ guild_id : int ,
133
+ * ,
134
+ region : Optional [str ] = ...,
135
+ endpoint : Optional [str ] = ...,
136
+ node : Optional [Node ] = ...) -> PlayerT :
137
+ ...
138
+
139
+ @overload
140
+ def create (self ,
141
+ guild_id : int ,
142
+ * ,
143
+ region : Optional [str ] = ...,
144
+ endpoint : Optional [str ] = ...,
145
+ node : Optional [Node ] = ...,
146
+ cls : Type [CustomPlayerT ]) -> CustomPlayerT :
147
+ ...
148
+
129
149
def create (self ,
130
150
guild_id : int ,
131
151
* ,
132
152
region : Optional [str ] = None ,
133
153
endpoint : Optional [str ] = None ,
134
154
node : Optional [Node ] = None ,
135
- cls : Optional [Type [PlayerT ]] = None ) -> BasePlayer :
155
+ cls : Optional [Type [CustomPlayerT ]] = None ) -> Union [ CustomPlayerT , PlayerT ] :
136
156
"""
137
157
Creates a player if one doesn't exist with the given information.
138
158
0 commit comments