8000 feat: speed up the service registry (#1174) · python-zeroconf/python-zeroconf@360ceb2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 360ceb2

Browse files
authored
feat: speed up the service registry (#1174)
1 parent bb496a1 commit 360ceb2

File tree

5 files changed

+119
-30
lines changed

5 files changed

+119
-30
lines changed

src/zeroconf/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
745745
# goodbye packets for the address records
746746

747747
assert info.server is not None
748-
entries = self.registry.async_get_infos_server(info.server)
748+
entries = self.registry.async_get_infos_server(info.server.lower())
749749
broadcast_addresses = not bool(entries)
750750
return asyncio.ensure_future(
751751
self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses)

src/zeroconf/_handlers.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ def _add_service_type_enumeration_query_answers(
255255
answer_set[dns_pointer] = set()
256256

257257
def _add_pointer_answers(
258-
self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
258+
self, lower_name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float
259259
) -> None:
260260
"""Answer PTR/ANY question."""
261-
for service in self.registry.async_get_infos_type(name):
261+
for service in self.registry.async_get_infos_type(lower_name):
262262
# Add recommended additional answers according to
263263
# https://tools.ietf.org/html/rfc6763#section-12.1.
264264
dns_pointer = service.dns_pointer(created=now)
@@ -270,14 +270,14 @@ def _add_pointer_answers(
270270

271271
def _add_address_answers(
272272
self,
273-
name: str,
273+
lower_name: str,
274274
answer_set: _AnswerWithAdditionalsType,
275275
known_answers: DNSRRSet,
276276
now: float,
277277
type_: int,
278278
) -> None:
279279
"""Answer A/AAAA/ANY question."""
280-
for service in self.registry.async_get_infos_server(name):
280+
for service in self.registry.async_get_infos_server(lower_name):
281281
answers: List[DNSAddress] = []
282282
additionals: Set[DNSRecord] = set()
283283
seen_types: Set[int] = set()
@@ -305,21 +305,22 @@ def _answer_question(
305305
now: float,
306306
) -> _AnswerWithAdditionalsType:
307307
answer_set: _AnswerWithAdditionalsType = {}
308+
question_lower_name = question.name.lower()
308309

309-
if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME:
310+
if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
310311
self._add_service_type_enumeration_query_answers(answer_set, known_answers, now)
311312
return answer_set
312313

313314
type_ = question.type
314315

315316
if type_ in (_TYPE_PTR, _TYPE_ANY):
316-
self._add_pointer_answers(question.name, answer_set, known_answers, now)
317+
self._add_pointer_answers(question_lower_name, answer_set, known_answers, now)
317318

318319
if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
319-
self._add_address_answers(question.name, answer_set, known_answers, now, type_)
320+
self._add_address_answers(question_lower_name, answer_set, known_answers, now, type_)
320321

321322
if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
322-
service = self.registry.async_get_info_name(question.name)
323+
service = self.registry.async_get_info_name(question_lower_name)
323324
if service is not None:
324325
if type_ in (_TYPE_SRV, _TYPE_ANY):
325326
# Add recommended additional answers according to

src/zeroconf/_services/registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def async_get_service_infos(self) -> List[ServiceInfo]:
6060

6161
def async_get_info_name(self, name: str) -> Optional[ServiceInfo]:
6262
"""Return all ServiceInfo for the name."""
63-
return self._services.get(name.lower())
63+
return self._services.get(name)
6464

6565
def async_get_types(self) -> List[str]:
6666
"""Return all types."""
@@ -76,7 +76,7 @@ def async_get_infos_server(self, server: str) -> List[ServiceInfo]:
7676

7777
def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]:
7878
"""Return all ServiceInfo matching the index."""
79-
return [self._services[name] for name in records.get(key.lower(), [])]
79+
return [self._services[name] for name in records.get(key, [])]
8080

8181
def _add(self, info: ServiceInfo) -> None:
8282
"""Add a new service under the lock."""

tests/services/test_registry.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -110,22 +110,3 @@ def test_lookups_upper_case_by_lower_case(self):
110110
assert registry.async_get_infos_type(type_.lower()) == [info]
111111
assert registry.async_get_infos_server("ash-2.local.") == [info]
112112
assert registry.async_get_types() == [type_.lower()]
113-
114-
def test_lookups_lower_case_by_upper_case(self):
115-
type_ = "_test-srvc-type._tcp.local."
116-
name = "xxxyyy"
117-
registration_name = f"{name}.{type_}"
118-
119-
desc = {'path': '/~paulsm/'}
120-
info = ServiceInfo(
121-
type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
122-
)
123-
124-
registry = r.ServiceRegistry()
125-
registry.async_add(info)
126-
127-
assert registry.async_get_service_infos() == [info]
128-
assert registry.async_get_info_name(registration_name.upper()) == info
129-
assert registry.async_get_infos_type(type_.upper()) == [info]
130-
assert registry.async_get_infos_server("ASH-2.local.") == [info]
131-
assert registry.async_get_types() == [type_]

tests/test_handlers.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,32 @@ def test_aaaa_query():
340340
zc.close()
341341

342342

343+
@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
344+
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
345+
def test_aaaa_query_upper_case():
346+
"""Test that queries for AAAA records work and should respond right away with an upper case name."""
347+
zc = Zeroconf(interfaces=['127.0.0.1'])
348+
type_ = "_knownaaaservice._tcp.local."
349+
name = "knownname"
350+
registration_name = f"{name}.{type_}"
351+
desc = {'path': '/~paulsm/'}
352+
server_name = "ash-2.local."
353+
ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
354+
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
355+
zc.registry.async_add(info)
356+
357+
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
358+
question = r.DNSQuestion(server_name.upper(), const._TYPE_AAAA, const._CLASS_IN)
359+
generated.add_question(question)
360+
packets = generated.packets()
361+
question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
362+
mcast_answers = list(question_answers.mcast_now)
363+
assert mcast_answers[0].address == ipv6_address # type: ignore[attr-defined]
364+
# unregister
365+
zc.registry.async_remove(info)
366+
zc.close()
367+
368+
343369
@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
344370
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
345371
def test_a_and_aaaa_record_fate_sharing():
@@ -481,6 +507,48 @@ async def test_probe_answered_immediately():
481507
zc.close()
482508

483509

510+
@pytest.mark.asyncio
511+
async def test_probe_answered_immediately_with_uppercase_name():
512+
"""Verify probes are responded to immediately with an uppercase name."""
513+
# instantiate a zeroconf instance
514+
zc = Zeroconf(interfaces=['127.0.0.1'])
515+
516+
# service definition
517+
type_ = "_test-srvc-type._tcp.local."
518+
name = "xxxyyy"
519+
registration_name = f"{name}.{type_}"
520+
desc = {'path': '/~paulsm/'}
521+
info = ServiceInfo(
522+
type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
523+
)
524+
zc.registry.async_add(info)
525+
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
526+
question = r.DNSQuestion(info.type.upper(), const._TYPE_PTR, const._CLASS_IN)
527+
query.add_question(question)
528+
query.add_authorative_answer(info.dns_pointer())
529+
question_answers = zc.query_handler.async_response(
530+
[r.DNSIncoming(packet) for packet in query.packets()], False
531+
)
532+
assert not question_answers.ucast
533+
assert not question_answers.mcast_aggregate
534+
assert not question_answers.mcast_aggregate_last_second
535+
assert question_answers.mcast_now
536+
537+
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
538+
question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
539+
question.unicast = True
540+
query.add_question(question)
541+
query.add_authorative_answer(info.dns_pointer())
542+
question_answers = zc.query_handler.async_response(
543+
[r.DNSIncoming(packet) for packet in query.packets()], False
544+
)
545+
assert question_answers.ucast
546+
assert question_answers.mcast_now
547+
assert not question_answers.mcast_aggregate
548+
assert not question_answers.mcast_aggregate_last_second
549+
zc.close()
550+
551+
484552
def test_qu_response():
485553
"""Handle multicast incoming with the QU bit set."""
486554
# instantiate a zeroconf instance
@@ -842,6 +910,45 @@ def test_known_answer_supression_service_type_enumeration_query():
842910
zc.close()
843911

844912

913+
def test_upper_case_enumeration_query():
914+
zc = Zeroconf(interfaces=['127.0.0.1'])
915+
type_ = "_otherknown._tcp.local."
916+
name = "knownname"
917+
registration_name = f"{name}.{type_}"
918+
desc = {'path': '/~paulsm/'}
919+
server_name = "ash-2.local."
920+
info = ServiceInfo(
921+
type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
922+
)
923+
zc.registry.async_add(info)
924+
925+
type_2 = "_otherknown2._tcp.local."
926+
name = "knownname"
927+
registration_name2 = f"{name}.{type_2}"
928+
desc = {'path': '/~paulsm/'}
929+
server_name2 = "ash-3.local."
930+
info2 = ServiceInfo(
931+
type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
932+
)
D216 933+
zc.registry.async_add(info2)
934+
_clear_cache(zc)
935+
936+
# Test PTR supression
937+
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
938+
question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME.upper(), const._TYPE_PTR, const._CLASS_IN)
939+
generated.add_question(question)
940+
packets = generated.packets()
941+
question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False)
942+
assert not question_answers.ucast
943+
assert not question_answers.mcast_now
944+
assert question_answers.mcast_aggregate
945+
assert not question_answers.mcast_aggregate_last_second
946+
# unregister
947+
zc.registry.async_remove(info)
948+
zc.registry.async_remove(info2)
949+
zc.close()
950+
951+
845952
# This test uses asyncio because it needs to access the cache directly
846953
# which is not threadsafe
847954
@pytest.mark.asyncio

0 commit comments

Comments
 (0)
0