8000 Add util.assert_never for static checks of must-be-unreachable code · Simplexum/python-bitcointx@809d2fe · GitHub
[go: up one dir, main page]

10000
Skip to content

Commit 809d2fe

Browse files
committed
Add util.assert_never for static checks of must-be-unreachable code
1 parent f875d12 commit 809d2fe

File tree

3 files changed

+44
-29
lines changed

3 files changed

+44
-29
lines changed

bitcointx/core/psbt.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from ..wallet import CCoinExtPubKey
4545

4646
from ..util import (
47-
ensure_isinstance, no_bool_use_as_property,
47+
ensure_isinstance, no_bool_use_as_property, assert_never,
4848
ClassMappingDispatcher, activate_class_dispatcher
4949
)
5050

@@ -135,13 +135,8 @@ class PSBT_OutKeyType(Enum):
135135
('key_type', int), ('key_data', bytes), ('value', bytes)
136136
])
137137

138-
T_KeyTypeEnum_Type = Union[
139-
Type[PSBT_GlobalKeyType],
140-
Type[PSBT_OutKeyType],
141-
Type[PSBT_InKeyType],
142-
]
143-
144-
T_KeyTypeEnum = Union[PSBT_GlobalKeyType, PSBT_OutKeyType, PSBT_InKeyType]
138+
T_KeyTypeEnum = TypeVar(
139+
'T_KeyTypeEnum', PSBT_GlobalKeyType, PSBT_OutKeyType, PSBT_InKeyType)
145140

146141

147142
def proprietary_field_repr(
@@ -296,7 +291,7 @@ def merge_unknown_fields(
296291
def read_psbt_keymap(
297292
f: ByteStream_Type,
298293
keys_seen: Set[bytes],
299-
keys_enum_class: T_KeyTypeEnum_Type,
294+
keys_enum_class: Type[T_KeyTypeEnum],
300295
proprietary_fields: Dict[bytes, List[PSBT_ProprietaryTypeData]],
301296
unknown_fields: List[PSBT_UnknownTypeData]
302297
) -> Generator[Tuple[T_KeyTypeEnum, bytes, bytes], None, None]:
@@ -1295,10 +1290,7 @@ def check_witness_and_nonwitness_utxo_in_sync(
12951290
ensure_empty_key_data(key_type, key_data, descr(''))
12961291
proof_of_reserves_commitment = value
12971292
else:
1298-
raise AssertionError(
1299-
f'If key type {key_type} is recognized, '
1300-
f'it must be handled, and this statement '
1301-
f'should not be reached.')
1293+
assert_never(key_type)
13021294

13031295
# non_witness_utxo is preferred over witness_utxo for `utxo` kwarg
13041296
# because non_witness_utxo is a full transaction,
@@ -1646,13 +1638,13 @@ def descr(msg: str) -> str:
16461638
read_psbt_keymap(f, keys_seen, PSBT_OutKeyType,
16471639
proprietary_fields, unknown_fields):
16481640

1649-
if key_type == PSBT_OutKeyType.REDEEM_SCRIPT:
1641+
if key_type is PSBT_OutKeyType.REDEEM_SCRIPT:
16501642
ensure_empty_key_data(key_type, key_data, descr(''))
16511643
redeem_script = CScript(value)
1652-
elif key_type == PSBT_OutKeyType.WITNESS_SCRIPT:
1644+
elif key_type is PSBT_OutKeyType.WITNESS_SCRIPT:
16531645
ensure_empty_key_data(key_type, key_data, descr(''))
16541646
witness_script = CScript(value)
1655-
elif key_type == PSBT_OutKeyType.BIP32_DERIVATION:
1647+
elif key_type is PSBT_OutKeyType.BIP32_DERIVATION:
16561648
pub = CPubKey(key_data)
16571649
if not pub.is_fullyvalid():
16581650
raise SerializationError(
@@ -1662,6 +1654,8 @@ def descr(msg: str) -> str:
16621654
("duplicate keys should have been catched "
16631655
"inside read_psbt_keymap()")
16641656
derivation_map[pub] = PSBT_KeyDerivationInfo.deserialize(value)
1657+
else:
1658+
assert_never(key_type)
16651659

16661660
return cls(redeem_script=redeem_script, witness_script=witness_script,
16671661
derivation_map=derivation_map,
@@ -2107,10 +2101,10 @@ def stream_deserialize(cls: Type[T_PartiallySignedTransaction],
21072101
read_psbt_keymap(f, keys_seen, PSBT_GlobalKeyType,
21082102
proprietary_fields, unknown_fields):
21092103

2110-
if key_type == PSBT_GlobalKeyType.UNSIGNED_TX:
2104+
if key_type is PSBT_GlobalKeyType.UNSIGNED_TX:
21112105
ensure_empty_key_data(key_type, key_data)
21122106
unsigned_tx = CTransaction.deserialize(value)
2113-
elif key_type == PSBT_GlobalKeyType.XPUB:
2107+
elif key_type is PSBT_GlobalKeyType.XPUB:
21142108
if key_data[:4] != CCoinExtPubKey.base58_prefix:
21152109
raise ValueError(
21162110
f'One of global xpubs has unknown prefix: expected '
@@ -2121,17 +2115,14 @@ def stream_deserialize(cls: Type[T_PartiallySignedTransaction],
21212115
("duplicate keys should have been catched "
21222116
"inside read_psbt_keymap()")
21232117
xpubs[xpub] = PSBT_KeyDerivationInfo.deserialize(value)
2124-
elif key_type == PSBT_GlobalKeyType.VERSION:
2118+
elif key_type is PSBT_GlobalKeyType.VERSION:
21252119
ensure_empty_key_data(key_type, key_data)
21262120
if len(value) != 4:
21272121
raise SerializationError(
21282122
f'Incorrect data length for {key_type.name}')
21292123
version = struct.unpack(b'<I', value)[0]
21302124
else:
2131-
raise AssertionError(
2132-
f'If key type {key_type} is present in PSBT_GLOBAL_KEYS, '
2133-
f'it must be handled, and this statement '
2134-
f'should not be reached.')
2125+
assert_never(key_type)
21352126

21362127
if unsigned_tx is None:
21372128
raise ValueError(

bitcointx/tests/test_psbt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -679,20 +679,20 @@ def get_input_key_types(psbt_bytes: bytes) -> Set[PSBT_InKeyType]:
679679
assert magic == CoreCoinParams.PSBT_MAGIC_HEADER_BYTES
680680
unsigned_tx = None
681681
keys_seen: Set[bytes] = set()
682-
for key_type, key_data, value in \
682+
for key_type_g, key_data, value in \
683683
read_psbt_keymap(f, keys_seen, PSBT_GlobalKeyType,
684684
OrderedDict(), list()):
685-
if key_type == PSBT_GlobalKeyType.UNSIGNED_TX:
685+
if key_type_g == PSBT_GlobalKeyType.UNSIGNED_TX:
686686
unsigned_tx = CTransaction.deserialize(value)
687687
assert unsigned_tx
688688
keys_seen = set()
689689
key_types_seen: Set[PSBT_InKeyType] = set()
690690
assert len(unsigned_tx.vin) == 1
691-
for key_type, key_data, value in \
691+
for key_type_in, key_data, value in \
692692
read_psbt_keymap(f, keys_seen, PSBT_InKeyType,
693693
OrderedDict(), list()):
694-
assert isinstance(key_type, PSBT_InKeyType)
695-
key_types_seen.add(key_type)
694+
assert isinstance(key_type_in, PSBT_InKeyType)
695+
key_types_seen.add(key_type_in)
696696

697697
return key_types_seen
698698

bitcointx/util.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
has_contextvars = False
2020

2121
import functools
22+
from enum import Enum
2223
from types import FunctionType
2324
from abc import ABCMeta, ABC, abstractmethod
2425
from typing import (
2526
Type, Set, Tuple, List, Dict, Union, Any, Callable, Iterable, Optional,
26-
TypeVar, Generic, cast
27+
TypeVar, Generic, cast, NoReturn
2728
)
2829

2930
_secp256k1_library_path: Optional[str] = None
@@ -471,6 +472,28 @@ def ensure_isinstance(var: object,
471472
raise TypeError(msg)
472473

473474

475+
def assert_never(x: NoReturn) -> NoReturn:
476+
"""For use with static checking. The checker such as mypy will raise
477+
error if the statement `assert_never(...)` is reached. At runtime,
478+
an `AssertionError` will be raised.
479+
Useful to ensure that all variants of Enum is handled.
480+
Might become useful in other ways, and because of this, the message
481+
for `AssertionError` at runtime can differ on actual type of the argument.
482+
For full control of the message, just pass a string as the argument.
483+
"""
484+
485+
if isinstance(x, Enum):
486+
msg = f'Enum {x} is not handled'
487+
elif isinstance(x, str):
488+
msg = x
489+
elif isinstance(x, type):
490+
msg = f'{x.__name__} is not handled'
491+
else:
492+
msg = f'{x.__class__.__name__} is not handled'
493+
494+
raise AssertionError(msg)
495+
496+
474497
class ReadOnlyFieldGuard(ABC):
475498
"""A unique class that is used as a guard type for ReadOnlyField.
476499
It cannot be instantiated at runtime, and the static check will also
@@ -604,6 +627,7 @@ def set_dispatcher_class(self, identity: str,
604627
'ClassMappingDispatcher',
605628
'classgetter',
606629
'ensure_isinstance',
630+
'assert_never',
607631
'ReadOnlyField',
608632
'WriteableField',
609633
'ContextVarsCompat',

0 commit comments

Comments
 (0)
0