8000 wait for connection success when start a mqtt client · emqx/mcp-python-sdk@17cc30e · GitHub
[go: up one dir, main page]

Skip to content

Commit 17cc30e

Browse files
committed
wait for connection success when start a mqtt client
1 parent cc3a0ea commit 17cc30e

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

src/mcp/client/mqtt.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,31 @@ def __init__(self,
108108
def get_presence_topic(self) -> str:
109109
return mqtt_topic.get_client_presence_topic(self.mcp_client_id)
110110

111-
def start(self):
111+
async def start(self, timeout: timedelta | None = None) -> bool | str:
112+
connect_result = self.connect()
112113
def do_start():
113-
self.connect()
114114
self.client.loop_forever()
115115
try:
116116
asyncio.create_task(anyio_to_thread.run_sync(do_start))
117+
if connect_result and connect_result != mqtt.MQTT_ERR_SUCCESS:
118+
logger.error(f"Failed to connect to MQTT broker, error code: {connect_result}")
119+
return mqtt.error_string(connect_result)
120+
# test if the client is connected and wait until it is connected
121+
if timeout:
122+
while not self.is_connected():
123+
await asyncio.sleep(0.1)
124+
if timeout.total_seconds() <= 0:
125+
logger.error(f"Timeout while waiting for MQTT client to connect, reason: {self.get_last_connect_fail_reason()}")
126+
return self.get_last_connect_fail_reason() or "timeout"
127+
timeout -= timedelta(seconds=0.1)
128+
return True
117129
except asyncio.CancelledError:
118130
logger.debug("MQTT transport (MCP client) got cancelled")
131+
return "cancelled"
119132
except Exception as exc:
120133
logger.error(f"MQTT transport (MCP client) failed: {exc}")
121134
traceback.print_exc()
135+
return "error"
122136

123137
def get_session(self, server_name: ServerName) -> MqttClientSession | None:
124138
return self.client_sessions.get(server_name, None)
@@ -277,8 +291,8 @@ def _create_session(
277291
)
278292

279293
def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None):
294+
super()._on_connect(client, userdata, connect_flags, reason_code, properties)
280295
if reason_code == 0:
281-
super()._on_connect(client, userdata, connect_flags, reason_code, properties)
282296
## Subscribe to the MCP server's presence topic
283297
client.subscribe(mqtt_topic.get_server_presence_topic('+', self.server_name_filter), qos=QOS)
284298

src/mcp/server/mqtt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def get_presence_topic(self) -> str:
5050
return mqtt_topic.get_server_presence_topic(self.server_id, self.server_name)
5151

5252
def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None):
53+
super()._on_connect(client, userdata, connect_flags, reason_code, properties)
5354
if reason_code == 0:
54-
super()._on_connect(client, userdata, connect_flags, reason_code, properties)
5555
if properties and hasattr(properties, "UserProperty"):
5656
user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore
5757
if MCP_SERVER_NAME in user_properties:

src/mcp/shared/mqtt.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(self,
6565
disconnected_msg: types.JSONRPCMessage | None = None,
6666
disconnected_msg_retain: bool = True):
6767
self._read_stream_writers = {}
68+
self._last_connect_fail_reason = None
6869
self.mqtt_clientid = mqtt_clientid
6970
self.mcp_component_type = mcp_component_type
7071
self.mqtt_options = mqtt_options
@@ -74,7 +75,11 @@ def __init__(self,
7475
callback_api_version=CallbackAPIVersion.VERSION2,
7576
client_id=mqtt_clientid, protocol=mqtt.MQTTv5,
7677
userdata={},
77-
transport=mqtt_options.transport, reconnect_on_failure=True
78+
transport=mqtt_options.transport,
79+
reconnect_on_failure=True
80+
)
81+
client.reconnect_delay_set(
82+
min_delay=1, max_delay=120
7883
)
7984
client.username_pw_set(mqtt_options.username, mqtt_options.password.get_secret_value() if mqtt_options.password else None)
8085
if mqtt_options.tls_enabled:
@@ -123,12 +128,13 @@ async def __aexit__(
123128
self._task_group.cancel_scope.cancel()
124129
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
125130

126-
def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None):
131+
def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code: ReasonCode, properties: Properties | None):
127132
if reason_code == 0:
128133
logger.debug(f"Connected to MQTT broker_host at {self.mqtt_options.host}:{self.mqtt_options.port}")
129134
self.assert_property(properties, "RetainAvailable", 1)
130135
self.assert_property(properties, "WildcardSubscriptionAvailable", 1)
131136
else:
137+
self._last_connect_fail_reason = reason_code
132138
logger.error(f"Failed to connect, return code {reason_code}")
133139

134140
def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
@@ -138,6 +144,12 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int,
138144
reason_code_list: list[ReasonCode], properties: Properties | None):
139145
pass
140146

147+
def is_connected(self) -> bool:
148+
return self.client.is_connected()
149+
150+
def get_last_connect_fail_reason(self) -> ReasonCode | None:
151+
return self._last_connect_fail_reason
152+
141153
def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage | None,
142154
retain: bool = False):
143155
props = self.get_publish_properties()

0 commit comments

Comments
 (0)
0