From 6fa08b0a522917471e339c04d606057d97b655e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Mon, 23 Jun 2025 11:21:00 +0200 Subject: [PATCH 1/2] add a helper to validate a final heap type --- Modules/hashlib.h | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/Modules/hashlib.h b/Modules/hashlib.h index 9a7e72f34a7f9d..c2f132a0da45ac 100644 --- a/Modules/hashlib.h +++ b/Modules/hashlib.h @@ -1,6 +1,32 @@ /* Common code for use by all hashlib related modules. */ -#include "pycore_lock.h" // PyMutex +#include "pycore_lock.h" // PyMutex +#include "pycore_moduleobject.h" // _PyModule_GetDef() + +#ifndef NDEBUG +/* + * Assert that a type cannot be subclassed and that + * its associated module definition matches 'moddef'. + * + * Use this helper to ensure that _PyType_GetModuleState() can be safely used. + */ +static inline void +_Py_hashlib_check_exported_type(PyTypeObject *type, PyModuleDef *moddef) +{ + assert(type != NULL); + assert(moddef != NULL); + /* ensure that the type is a final heap type */ + assert(PyType_Check(type)); + assert(type->tp_flags & Py_TPFLAGS_HEAPTYPE); + assert(!(type->tp_flags & Py_TPFLAGS_BASETYPE)); + /* ensure that the associated module definition matches 'moddef' */ + PyHeapTypeObject *ht = (PyHeapTypeObject *)type; + assert(ht->ht_module != NULL); + assert(moddef == _PyModule_GetDef(ht->ht_module)); +} +#else +#define _Py_hashlib_check_exported_type(_TYPE, _MODDEF) +#endif /* * Given a PyObject* obj, fill in the Py_buffer* viewp with the result From 0abef75371ed4d01ccd5cf8001b6e3110d8445ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Mon, 23 Jun 2025 12:41:40 +0200 Subject: [PATCH 2/2] HACL*: fortify DEBUG checks for fetching the module state --- Modules/blake2module.c | 42 ++++++++-------- Modules/hashlib.h | 4 +- Modules/hmacmodule.c | 19 ++++---- Modules/md5module.c | 45 ++++++++++------- Modules/sha1module.c | 47 +++++++++++------- Modules/sha2module.c | 106 ++++++++++++++++++++++++----------------- Modules/sha3module.c | 46 +++++++++++------- 7 files changed, 186 insertions(+), 123 deletions(-) diff --git a/Modules/blake2module.c b/Modules/blake2module.c index 6c4349ac06bb8a..645b4da469cd22 100644 --- a/Modules/blake2module.c +++ b/Modules/blake2module.c @@ -15,10 +15,11 @@ #endif #include "Python.h" +#include "pycore_moduleobject.h" // _PyModule_GetState() +#include "pycore_strhex.h" // _Py_strhex() +#include "pycore_typeobject.h" // _PyType_GetModuleState() + #include "hashlib.h" -#include "pycore_strhex.h" // _Py_strhex() -#include "pycore_typeobject.h" -#include "pycore_moduleobject.h" // QUICK CPU AUTODETECTION // @@ -67,6 +68,7 @@ // MODULE TYPE SLOTS +static struct PyModuleDef blake2module_def; static PyType_Spec blake2b_type_spec; static PyType_Spec blake2s_type_spec; @@ -78,23 +80,24 @@ typedef struct { PyTypeObject *blake2s_type; bool can_run_simd128; bool can_run_simd256; -} Blake2State; +} blake2module_state; -static inline Blake2State * -blake2_get_state(PyObject *module) +static inline blake2module_state * +get_blake2module_state(PyObject *module) { void *state = _PyModule_GetState(module); assert(state != NULL); - return (Blake2State *)state; + return (blake2module_state *)state; } #if defined(HACL_CAN_COMPILE_SIMD128) || defined(HACL_CAN_COMPILE_SIMD256) -static inline Blake2State * -blake2_get_state_from_type(PyTypeObject *module) +static inline blake2module_state * +get_blake2module_state_by_cls(PyTypeObject *cls) { - void *state = _PyType_GetModuleState(module); + _Py_hashlib_check_exported_type(cls, &blake2module_def); + void *state = _PyType_GetModuleState(cls); assert(state != NULL); - return (Blake2State *)state; + return (blake2module_state *)state; } #endif @@ -105,7 +108,7 @@ static struct PyMethodDef blake2mod_functions[] = { static int _blake2_traverse(PyObject *module, visitproc visit, void *arg) { - Blake2State *state = blake2_get_state(module); + blake2module_state *state = get_blake2module_state(module); Py_VISIT(state->blake2b_type); Py_VISIT(state->blake2s_type); return 0; @@ -114,7 +117,7 @@ _blake2_traverse(PyObject *module, visitproc visit, void *arg) static int _blake2_clear(PyObject *module) { - Blake2State *state = blake2_get_state(module); + blake2module_state *state = get_blake2module_state(module); Py_CLEAR(state->blake2b_type); Py_CLEAR(state->blake2s_type); return 0; @@ -127,7 +130,7 @@ _blake2_free(void *module) } static void -blake2module_init_cpu_features(Blake2State *state) +blake2module_init_cpu_features(blake2module_state *state) { /* This must be kept in sync with hmacmodule_init_cpu_features() * in hmacmodule.c */ @@ -205,7 +208,7 @@ blake2module_init_cpu_features(Blake2State *state) static int blake2_exec(PyObject *m) { - Blake2State *st = blake2_get_state(m); + blake2module_state *st = get_blake2module_state(m); blake2module_init_cpu_features(st); #define ADD_INT(DICT, NAME, VALUE) \ @@ -285,11 +288,11 @@ static PyModuleDef_Slot _blake2_slots[] = { {0, NULL} }; -static struct PyModuleDef blake2_module = { +static struct PyModuleDef blake2module_def = { .m_base = PyModuleDef_HEAD_INIT, .m_name = "_blake2", .m_doc = blake2mod__doc__, - .m_size = sizeof(Blake2State), + .m_size = sizeof(blake2module_state), .m_methods = blake2mod_functions, .m_slots = _blake2_slots, .m_traverse = _blake2_traverse, @@ -300,7 +303,7 @@ static struct PyModuleDef blake2_module = { PyMODINIT_FUNC PyInit__blake2(void) { - return PyModuleDef_Init(&blake2_module); + return PyModuleDef_Init(&blake2module_def); } // IMPLEMENTATION OF METHODS @@ -333,7 +336,7 @@ static inline blake2_impl type_to_impl(PyTypeObject *type) { #if defined(HACL_CAN_COMPILE_SIMD128) || defined(HACL_CAN_COMPILE_SIMD256) - Blake2State *st = blake2_get_state_from_type(type); + blake2module_state *st = get_blake2module_state_by_cls(type); #endif if (!strcmp(type->tp_name, blake2b_type_spec.name)) { #if HACL_CAN_COMPILE_SIMD256 @@ -385,6 +388,7 @@ class _blake2.blake2s "Blake2Object *" "&PyType_Type" static Blake2Object * new_Blake2Object(PyTypeObject *type) { + _Py_hashlib_check_exported_type(type, &blake2module_def); Blake2Object *self = PyObject_GC_New(Blake2Object, type); if (self == NULL) { return NULL; diff --git a/Modules/hashlib.h b/Modules/hashlib.h index c2f132a0da45ac..2d0c236450d37d 100644 --- a/Modules/hashlib.h +++ b/Modules/hashlib.h @@ -22,7 +22,9 @@ _Py_hashlib_check_exported_type(PyTypeObject *type, PyModuleDef *moddef) /* ensure that the associated module definition matches 'moddef' */ PyHeapTypeObject *ht = (PyHeapTypeObject *)type; assert(ht->ht_module != NULL); - assert(moddef == _PyModule_GetDef(ht->ht_module)); + PyModuleDef *ht_moddef = _PyModule_GetDef(ht->ht_module); + assert(ht_moddef != NULL); + assert(ht_moddef == moddef); } #else #define _Py_hashlib_check_exported_type(_TYPE, _MODDEF) diff --git a/Modules/hmacmodule.c b/Modules/hmacmodule.c index e7a5ccbb19b45c..f22191a8fb523c 100644 --- a/Modules/hmacmodule.c +++ b/Modules/hmacmodule.c @@ -18,7 +18,9 @@ #include "Python.h" #include "pycore_hashtable.h" +#include "pycore_moduleobject.h" // _PyModule_GetState() #include "pycore_strhex.h" // _Py_strhex() +#include "pycore_typeobject.h" // _PyType_GetModuleState() /* * Taken from blake2module.c. In the future, detection of SIMD support @@ -250,6 +252,8 @@ typedef struct py_hmac_hinfo { // --- HMAC module state ------------------------------------------------------ +static struct PyModuleDef hmacmodule_def; + typedef struct hmacmodule_state { _Py_hashtable_t *hinfo_table; PyObject *unknown_hash_error; @@ -265,7 +269,7 @@ typedef struct hmacmodule_state { static inline hmacmodule_state * get_hmacmodule_state(PyObject *module) { - void *state = PyModule_GetState(module); + void *state = _PyModule_GetState(module); assert(state != NULL); return (hmacmodule_state *)state; } @@ -273,7 +277,8 @@ get_hmacmodule_state(PyObject *module) static inline hmacmodule_state * get_hmacmodule_state_by_cls(PyTypeObject *cls) { - void *state = PyType_GetModuleState(cls); + _Py_hashlib_check_exported_type(cls, &hmacmodule_def); + void *state = _PyType_GetModuleState(cls); assert(state != NULL); return (hmacmodule_state *)state; } @@ -301,13 +306,11 @@ typedef struct HMACObject { /*[clinic input] module _hmac -class _hmac.HMAC "HMACObject *" "clinic_state()->hmac_type" +class _hmac.HMAC "HMACObject *" "&PyType_Type" [clinic start generated code]*/ -/*[clinic end generated code: output=da39a3ee5e6b4b0d input=c8bab73fde49ba8a]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=72bc06d6dc634770]*/ -#define clinic_state() (get_hmacmodule_state_by_cls(Py_TYPE(self))) #include "clinic/hmacmodule.c.h" -#undef clinic_state // --- Helpers ---------------------------------------------------------------- // @@ -1683,7 +1686,7 @@ static struct PyModuleDef_Slot hmacmodule_slots[] = { {0, NULL} /* sentinel */ }; -static struct PyModuleDef _hmacmodule = { +static struct PyModuleDef hmacmodule_def = { PyModuleDef_HEAD_INIT, .m_name = "_hmac", .m_size = sizeof(hmacmodule_state), @@ -1697,5 +1700,5 @@ static struct PyModuleDef _hmacmodule = { PyMODINIT_FUNC PyInit__hmac(void) { - return PyModuleDef_Init(&_hmacmodule); + return PyModuleDef_Init(&hmacmodule_def); } diff --git a/Modules/md5module.c b/Modules/md5module.c index 8b6dd4a8195dfb..5a761831efb695 100644 --- a/Modules/md5module.c +++ b/Modules/md5module.c @@ -22,7 +22,9 @@ #endif #include "Python.h" -#include "pycore_strhex.h" // _Py_strhex() +#include "pycore_moduleobject.h" // _PyModule_GetState() +#include "pycore_strhex.h" // _Py_strhex() +#include "pycore_typeobject.h" // _PyType_GetModuleState() #include "hashlib.h" @@ -44,16 +46,27 @@ typedef struct { // --- Module state ----------------------------------------------------------- +static struct PyModuleDef md5module_def; + typedef struct { - PyTypeObject* md5_type; -} MD5State; + PyTypeObject *md5_type; +} md5module_state; + +static inline md5module_state * +get_md5module_state(PyObject *module) +{ + void *state = _PyModule_GetState(module); + assert(state != NULL); + return (md5module_state *)state; +} -static inline MD5State* -md5_get_state(PyObject *module) +static inline md5module_state * +get_md5module_state_by_cls(PyTypeObject *cls) { - void *state = PyModule_GetState(module); + _Py_hashlib_check_exported_type(cls, &md5module_def); + void *state = _PyType_GetModuleState(cls); assert(state != NULL); - return (MD5State *)state; + return (md5module_state *)state; } // --- Module clinic configuration -------------------------------------------- @@ -69,7 +82,7 @@ class MD5Type "MD5object *" "&PyType_Type" // --- MD5 object interface --------------------------------------------------- static MD5object * -newMD5object(MD5State * st) +newMD5object(md5module_state *st) { MD5object *md5 = PyObject_GC_New(MD5object, st->md5_type); if (!md5) { @@ -115,7 +128,7 @@ static PyObject * MD5Type_copy_impl(MD5object *self, PyTypeObject *cls) /*[clinic end generated code: output=bf055e08244bf5ee input=d89087dcfb2a8620]*/ { - MD5State *st = PyType_GetModuleState(cls); + md5module_state *st = get_md5module_state_by_cls(cls); MD5object *newobj; if ((newobj = newMD5object(st)) == NULL) { @@ -288,7 +301,7 @@ _md5_md5_impl(PyObject *module, PyObject *data, int usedforsecurity, GET_BUFFER_VIEW_OR_ERROUT(string, &buf); } - MD5State *st = md5_get_state(module); + md5module_state *st = get_md5module_state(module); if ((new = newMD5object(st)) == NULL) { if (string) { PyBuffer_Release(&buf); @@ -329,7 +342,7 @@ static struct PyMethodDef MD5_functions[] = { static int _md5_traverse(PyObject *module, visitproc visit, void *arg) { - MD5State *state = md5_get_state(module); + md5module_state *state = get_md5module_state(module); Py_VISIT(state->md5_type); return 0; } @@ -337,7 +350,7 @@ _md5_traverse(PyObject *module, visitproc visit, void *arg) static int _md5_clear(PyObject *module) { - MD5State *state = md5_get_state(module); + md5module_state *state = get_md5module_state(module); Py_CLEAR(state->md5_type); return 0; } @@ -352,7 +365,7 @@ _md5_free(void *module) static int md5_exec(PyObject *m) { - MD5State *st = md5_get_state(m); + md5module_state *st = get_md5module_state(m); st->md5_type = (PyTypeObject *)PyType_FromModuleAndSpec( m, &md5_type_spec, NULL); @@ -375,10 +388,10 @@ static PyModuleDef_Slot _md5_slots[] = { }; -static struct PyModuleDef _md5module = { +static struct PyModuleDef md5module_def = { PyModuleDef_HEAD_INIT, .m_name = "_md5", - .m_size = sizeof(MD5State), + .m_size = sizeof(md5module_state), .m_methods = MD5_functions, .m_slots = _md5_slots, .m_traverse = _md5_traverse, @@ -389,5 +402,5 @@ static struct PyModuleDef _md5module = { PyMODINIT_FUNC PyInit__md5(void) { - return PyModuleDef_Init(&_md5module); + return PyModuleDef_Init(&md5module_def); } diff --git a/Modules/sha1module.c b/Modules/sha1module.c index faa9dcccc5755b..e5f6749a028785 100644 --- a/Modules/sha1module.c +++ b/Modules/sha1module.c @@ -20,9 +20,11 @@ #endif #include "Python.h" +#include "pycore_moduleobject.h" // _PyModule_GetState() +#include "pycore_strhex.h" // _Py_strhex() +#include "pycore_typeobject.h" // _PyType_GetModuleState() + #include "hashlib.h" -#include "pycore_strhex.h" // _Py_strhex() -#include "pycore_typeobject.h" // _PyType_GetModuleState() #include "_hacl/Hacl_Hash_SHA1.h" @@ -42,16 +44,27 @@ typedef struct { // --- Module state ----------------------------------------------------------- +static struct PyModuleDef sha1module_def; + typedef struct { - PyTypeObject* sha1_type; -} SHA1State; + PyTypeObject *sha1_type; +} sha1module_state; + +static inline sha1module_state * +get_sha1module_state(PyObject *module) +{ + void *state = _PyModule_GetState(module); + assert(state != NULL); + return (sha1module_state *)state; +} -static inline SHA1State* -sha1_get_state(PyObject *module) +static inline sha1module_state * +get_sha1module_state_by_cls(PyTypeObject *cls) { - void *state = PyModule_GetState(module); + _Py_hashlib_check_exported_type(cls, &sha1module_def); + void *state = _PyType_GetModuleState(cls); assert(state != NULL); - return (SHA1State *)state; + return (sha1module_state *)state; } // --- Module clinic configuration -------------------------------------------- @@ -67,7 +80,7 @@ class SHA1Type "SHA1object *" "&PyType_Type" // --- SHA-1 object interface configuration ----------------------------------- static SHA1object * -newSHA1object(SHA1State *st) +newSHA1object(sha1module_state *st) { SHA1object *sha = PyObject_GC_New(SHA1object, st->sha1_type); if (sha == NULL) { @@ -117,7 +130,7 @@ static PyObject * SHA1Type_copy_impl(SHA1object *self, PyTypeObject *cls) /*[clinic end generated code: output=b32d4461ce8bc7a7 input=6c22e66fcc34c58e]*/ { - SHA1State *st = _PyType_GetModuleState(cls); + sha1module_state *st = get_sha1module_state_by_cls(cls); SHA1object *newobj; if ((newobj = newSHA1object(st)) == NULL) { @@ -288,7 +301,7 @@ _sha1_sha1_impl(PyObject *module, PyObject *data, int usedforsecurity, GET_BUFFER_VIEW_OR_ERROUT(string, &buf); } - SHA1State *st = sha1_get_state(module); + sha1module_state *st = get_sha1module_state(module); if ((new = newSHA1object(st)) == NULL) { if (string) { PyBuffer_Release(&buf); @@ -329,7 +342,7 @@ static struct PyMethodDef SHA1_functions[] = { static int _sha1_traverse(PyObject *module, visitproc visit, void *arg) { - SHA1State *state = sha1_get_state(module); + sha1module_state *state = get_sha1module_state(module); Py_VISIT(state->sha1_type); return 0; } @@ -337,7 +350,7 @@ _sha1_traverse(PyObject *module, visitproc visit, void *arg) static int _sha1_clear(PyObject *module) { - SHA1State *state = sha1_get_state(module); + sha1module_state *state = get_sha1module_state(module); Py_CLEAR(state->sha1_type); return 0; } @@ -351,7 +364,7 @@ _sha1_free(void *module) static int _sha1_exec(PyObject *module) { - SHA1State* st = sha1_get_state(module); + sha1module_state *st = get_sha1module_state(module); st->sha1_type = (PyTypeObject *)PyType_FromModuleAndSpec( module, &sha1_type_spec, NULL); @@ -381,10 +394,10 @@ static PyModuleDef_Slot _sha1_slots[] = { {0, NULL} }; -static struct PyModuleDef _sha1module = { +static struct PyModuleDef sha1module_def = { PyModuleDef_HEAD_INIT, .m_name = "_sha1", - .m_size = sizeof(SHA1State), + .m_size = sizeof(sha1module_state), .m_methods = SHA1_functions, .m_slots = _sha1_slots, .m_traverse = _sha1_traverse, @@ -395,5 +408,5 @@ static struct PyModuleDef _sha1module = { PyMODINIT_FUNC PyInit__sha1(void) { - return PyModuleDef_Init(&_sha1module); + return PyModuleDef_Init(&sha1module_def); } diff --git a/Modules/sha2module.c b/Modules/sha2module.c index 36300ba899fd44..c2f3faa1eea7e4 100644 --- a/Modules/sha2module.c +++ b/Modules/sha2module.c @@ -21,9 +21,9 @@ #endif #include "Python.h" -#include "pycore_moduleobject.h" // _PyModule_GetState() -#include "pycore_typeobject.h" // _PyType_GetModuleState() -#include "pycore_strhex.h" // _Py_strhex() +#include "pycore_moduleobject.h" // _PyModule_GetState() +#include "pycore_strhex.h" // _Py_strhex() +#include "pycore_typeobject.h" // _PyType_GetModuleState() #include "hashlib.h" @@ -57,21 +57,32 @@ typedef struct { // --- Module state ----------------------------------------------------------- +static struct PyModuleDef sha2module_def; + /* We shall use run-time type information in the remainder of this module to * tell apart SHA2-224 and SHA2-256 */ typedef struct { - PyTypeObject* sha224_type; - PyTypeObject* sha256_type; - PyTypeObject* sha384_type; - PyTypeObject* sha512_type; -} sha2_state; + PyTypeObject *sha224_type; + PyTypeObject *sha256_type; + PyTypeObject *sha384_type; + PyTypeObject *sha512_type; +} sha2module_state; -static inline sha2_state* -sha2_get_state(PyObject *module) +static inline sha2module_state * +get_sha2module_state(PyObject *module) { void *state = _PyModule_GetState(module); assert(state != NULL); - return (sha2_state *)state; + return (sha2module_state *)state; +} + +static inline sha2module_state * +get_sha2module_state_by_cls(PyTypeObject *cls) +{ + _Py_hashlib_check_exported_type(cls, &sha2module_def); + void *state = _PyType_GetModuleState(cls); + assert(state != NULL); + return (sha2module_state *)state; } // --- Module clinic configuration -------------------------------------------- @@ -112,7 +123,7 @@ SHA512copy(SHA512object *src, SHA512object *dest) } static SHA256object * -newSHA224object(sha2_state *state) +newSHA224object(sha2module_state *state) { SHA256object *sha = PyObject_GC_New(SHA256object, state->sha224_type); if (!sha) { @@ -125,7 +136,7 @@ newSHA224object(sha2_state *state) } static SHA256object * -newSHA256object(sha2_state *state) +newSHA256object(sha2module_state *state) { SHA256object *sha = PyObject_GC_New(SHA256object, state->sha256_type); if (!sha) { @@ -138,7 +149,7 @@ newSHA256object(sha2_state *state) } static SHA512object * -newSHA384object(sha2_state *state) +newSHA384object(sha2module_state *state) { SHA512object *sha = PyObject_GC_New(SHA512object, state->sha384_type); if (!sha) { @@ -151,7 +162,7 @@ newSHA384object(sha2_state *state) } static SHA512object * -newSHA512object(sha2_state *state) +newSHA512object(sha2module_state *state) { SHA512object *sha = PyObject_GC_New(SHA512object, state->sha512_type); if (!sha) { @@ -247,27 +258,31 @@ update_512(Hacl_Hash_SHA2_state_t_512 *state, uint8_t *buf, Py_ssize_t len) /*[clinic input] SHA256Type.copy - cls:defining_class + cls: defining_class Return a copy of the hash object. [clinic start generated code]*/ static PyObject * SHA256Type_copy_impl(SHA256object *self, PyTypeObject *cls) -/*[clinic end generated code: output=fabd515577805cd3 input=3137146fcb88e212]*/ +/*[clinic end generated code: output=fabd515577805cd3 input=333d446f6c240abe]*/ { int rc; SHA256object *newobj; - sha2_state *state = _PyType_GetModuleState(cls); - if (Py_IS_TYPE(self, state->sha256_type)) { - if ((newobj = newSHA256object(state)) == NULL) { - return NULL; - } + sha2module_state *state = get_sha2module_state_by_cls(cls); + + if (cls == state->sha256_type) { + /* for now, we want to be sure that we are using final types */ + assert(Py_IS_TYPE(self, state->sha256_type)); + newobj = newSHA256object(state); } else { - if ((newobj = newSHA224object(state)) == NULL) { - return NULL; - } + /* for now, we want to be sure that we are using final types */ + assert(Py_IS_TYPE(self, state->sha224_type)); + newobj = newSHA224object(state); + } + if (newobj == NULL) { + return NULL; } HASHLIB_ACQUIRE_LOCK(self); @@ -294,17 +309,20 @@ SHA512Type_copy_impl(SHA512object *self, PyTypeObject *cls) { int rc; SHA512object *newobj; - sha2_state *state = _PyType_GetModuleState(cls); + sha2module_state *state = get_sha2module_state_by_cls(cls); - if (Py_IS_TYPE((PyObject*)self, state->sha512_type)) { - if ((newobj = newSHA512object(state)) == NULL) { - return NULL; - } + if (cls == state->sha512_type) { + /* for now, we want to be sure that we are using final types */ + assert(Py_IS_TYPE(self, state->sha512_type)); + newobj = newSHA512object(state); } else { - if ((newobj = newSHA384object(state)) == NULL) { - return NULL; - } + /* for now, we want to be sure that we are using final types */ + assert(Py_IS_TYPE(self, state->sha384_type)); + newobj = newSHA384object(state); + } + if (newobj == NULL) { + return NULL; } HASHLIB_ACQUIRE_LOCK(self); @@ -531,8 +549,6 @@ static PyType_Slot sha512_type_slots[] = { {0,0} }; -// Using _PyType_GetModuleState() on these types is safe since they -// cannot be subclassed: they don't have the Py_TPFLAGS_BASETYPE flag. static PyType_Spec sha224_type_spec = { .name = "_sha2.SHA224Type", .basicsize = sizeof(SHA256object), @@ -593,7 +609,7 @@ _sha2_sha256_impl(PyObject *module, PyObject *data, int usedforsecurity, GET_BUFFER_VIEW_OR_ERROUT(string, &buf); } - sha2_state *state = sha2_get_state(module); + sha2module_state *state = get_sha2module_state(module); SHA256object *new; if ((new = newSHA256object(state)) == NULL) { @@ -652,7 +668,7 @@ _sha2_sha224_impl(PyObject *module, PyObject *data, int usedforsecurity, GET_BUFFER_VIEW_OR_ERROUT(string, &buf); } - sha2_state *state = sha2_get_state(module); + sha2module_state *state = get_sha2module_state(module); SHA256object *new; if ((new = newSHA224object(state)) == NULL) { if (string) { @@ -707,7 +723,7 @@ _sha2_sha512_impl(PyObject *module, PyObject *data, int usedforsecurity, return NULL; } - sha2_state *state = sha2_get_state(module); + sha2module_state *state = get_sha2module_state(module); if (string) { GET_BUFFER_VIEW_OR_ERROUT(string, &buf); @@ -766,7 +782,7 @@ _sha2_sha384_impl(PyObject *module, PyObject *data, int usedforsecurity, return NULL; } - sha2_state *state = sha2_get_state(module); + sha2module_state *state = get_sha2module_state(module); if (string) { GET_BUFFER_VIEW_OR_ERROUT(string, &buf); @@ -815,7 +831,7 @@ static struct PyMethodDef SHA2_functions[] = { static int _sha2_traverse(PyObject *module, visitproc visit, void *arg) { - sha2_state *state = sha2_get_state(module); + sha2module_state *state = get_sha2module_state(module); Py_VISIT(state->sha224_type); Py_VISIT(state->sha256_type); Py_VISIT(state->sha384_type); @@ -826,7 +842,7 @@ _sha2_traverse(PyObject *module, visitproc visit, void *arg) static int _sha2_clear(PyObject *module) { - sha2_state *state = sha2_get_state(module); + sha2module_state *state = get_sha2module_state(module); Py_CLEAR(state->sha224_type); Py_CLEAR(state->sha256_type); Py_CLEAR(state->sha384_type); @@ -843,7 +859,7 @@ _sha2_free(void *module) /* Initialize this module. */ static int sha2_exec(PyObject *module) { - sha2_state *state = sha2_get_state(module); + sha2module_state *state = get_sha2module_state(module); state->sha224_type = (PyTypeObject *)PyType_FromModuleAndSpec( module, &sha224_type_spec, NULL); @@ -896,10 +912,10 @@ static PyModuleDef_Slot _sha2_slots[] = { {0, NULL} }; -static struct PyModuleDef _sha2module = { +static struct PyModuleDef sha2module_def = { PyModuleDef_HEAD_INIT, .m_name = "_sha2", - .m_size = sizeof(sha2_state), + .m_size = sizeof(sha2module_state), .m_methods = SHA2_functions, .m_slots = _sha2_slots, .m_traverse = _sha2_traverse, @@ -910,5 +926,5 @@ static struct PyModuleDef _sha2module = { PyMODINIT_FUNC PyInit__sha2(void) { - return PyModuleDef_Init(&_sha2module); + return PyModuleDef_Init(&sha2module_def); } diff --git a/Modules/sha3module.c b/Modules/sha3module.c index face90a6094beb..9e8b759bbcf18a 100644 --- a/Modules/sha3module.c +++ b/Modules/sha3module.c @@ -21,8 +21,10 @@ #endif #include "Python.h" -#include "pycore_strhex.h" // _Py_strhex() -#include "pycore_typeobject.h" // _PyType_GetModuleState() +#include "pycore_moduleobject.h" // _PyModule_GetState() +#include "pycore_strhex.h" // _Py_strhex() +#include "pycore_typeobject.h" // _PyType_GetModuleState() + #include "hashlib.h" #include "_hacl/Hacl_Hash_SHA3.h" @@ -42,6 +44,8 @@ // --- Module state ----------------------------------------------------------- +static struct PyModuleDef sha3module_def; + typedef struct { PyTypeObject *sha3_224_type; PyTypeObject *sha3_256_type; @@ -49,14 +53,23 @@ typedef struct { PyTypeObject *sha3_512_type; PyTypeObject *shake_128_type; PyTypeObject *shake_256_type; -} SHA3State; +} sha3module_state; + +static inline sha3module_state * +get_sha3module_state(PyObject *module) +{ + void *state = _PyModule_GetState(module); + assert(state != NULL); + return (sha3module_state *)state; +} -static inline SHA3State* -sha3_get_state(PyObject *module) +static inline sha3module_state * +get_sha3module_state_by_cls(PyTypeObject *cls) { - void *state = PyModule_GetState(module); + _Py_hashlib_check_exported_type(cls, &sha3module_def); + void *state = _PyType_GetModuleState(cls); assert(state != NULL); - return (SHA3State *)state; + return (sha3module_state *)state; } // --- Module objects --------------------------------------------------------- @@ -90,6 +103,7 @@ class _sha3.shake_256 "SHA3object *" "&PyType_Type" static SHA3object * newSHA3object(PyTypeObject *type) { + _Py_hashlib_check_exported_type(type, &sha3module_def); SHA3object *newobj = PyObject_GC_New(SHA3object, type); if (newobj == NULL) { return NULL; @@ -142,14 +156,12 @@ py_sha3_new_impl(PyTypeObject *type, PyObject *data_obj, int usedforsecurity, } Py_buffer buf = {NULL, NULL}; - SHA3State *state = _PyType_GetModuleState(type); SHA3object *self = newSHA3object(type); if (self == NULL) { goto error; } - assert(state != NULL); - + sha3module_state *state = get_sha3module_state_by_cls(type); if (type == state->sha3_224_type) { self->hash_state = Hacl_Hash_SHA3_malloc(Spec_Hash_Definitions_SHA3_224); } @@ -349,7 +361,7 @@ SHA3_get_name(PyObject *self, void *Py_UNUSED(closure)) { PyTypeObject *type = Py_TYPE(self); - SHA3State *state = _PyType_GetModuleState(type); + sha3module_state *state = get_sha3module_state_by_cls(type); assert(state != NULL); if (type == state->sha3_224_type) { @@ -617,7 +629,7 @@ SHA3_TYPE_SPEC(SHAKE256_spec, "shake_256", SHAKE256slots); static int _sha3_traverse(PyObject *module, visitproc visit, void *arg) { - SHA3State *state = sha3_get_state(module); + sha3module_state *state = get_sha3module_state(module); Py_VISIT(state->sha3_224_type); Py_VISIT(state->sha3_256_type); Py_VISIT(state->sha3_384_type); @@ -630,7 +642,7 @@ _sha3_traverse(PyObject *module, visitproc visit, void *arg) static int _sha3_clear(PyObject *module) { - SHA3State *state = sha3_get_state(module); + sha3module_state *state = get_sha3module_state(module); Py_CLEAR(state->sha3_224_type); Py_CLEAR(state->sha3_256_type); Py_CLEAR(state->sha3_384_type); @@ -649,7 +661,7 @@ _sha3_free(void *module) static int _sha3_exec(PyObject *m) { - SHA3State *st = sha3_get_state(m); + sha3module_state *st = get_sha3module_state(m); #define init_sha3type(type, typespec) \ do { \ @@ -689,10 +701,10 @@ static PyModuleDef_Slot _sha3_slots[] = { }; /* Initialize this module. */ -static struct PyModuleDef _sha3module = { +static struct PyModuleDef sha3module_def = { PyModuleDef_HEAD_INIT, .m_name = "_sha3", - .m_size = sizeof(SHA3State), + .m_size = sizeof(sha3module_state), .m_slots = _sha3_slots, .m_traverse = _sha3_traverse, .m_clear = _sha3_clear, @@ -703,5 +715,5 @@ static struct PyModuleDef _sha3module = { PyMODINIT_FUNC PyInit__sha3(void) { - return PyModuleDef_Init(&_sha3module); + return PyModuleDef_Init(&sha3module_def); }