diff --git a/.pylintrc b/.pylintrc index 122f99a9f1ce..41ad4bb15882 100644 --- a/.pylintrc +++ b/.pylintrc @@ -92,7 +92,7 @@ persistent=no # Minimum Python version to use for version dependent checks. Will default to # the version used to run pylint. -py-version=3.9 +py-version=3.10 # Discover python modules and packages in the file system subtree. recursive=no diff --git a/changelog/68030.changed.md b/changelog/68030.changed.md new file mode 100644 index 000000000000..d7e91505f6b6 --- /dev/null +++ b/changelog/68030.changed.md @@ -0,0 +1,3 @@ +PillarCache: reimplement using salt.cache +fix minion data cache organization/move pillar and grains to dedicated cache banks +salt.cache: allow cache.store() to set expires per key diff --git a/changelog/68030.fixed.md b/changelog/68030.fixed.md new file mode 100644 index 000000000000..8b7e184ce460 --- /dev/null +++ b/changelog/68030.fixed.md @@ -0,0 +1 @@ +salt.key: check_minion_cache performance optimization diff --git a/changelog/68039.changed.md b/changelog/68039.changed.md new file mode 100644 index 000000000000..336ccdba8422 --- /dev/null +++ b/changelog/68039.changed.md @@ -0,0 +1 @@ +Provide token storage using the salt.cache interface diff --git a/changelog/68068.added.md b/changelog/68068.added.md new file mode 100644 index 000000000000..9f635a6e939a --- /dev/null +++ b/changelog/68068.added.md @@ -0,0 +1,3 @@ +new: sqlalchemy base implementation for shared database access +new: sqlalchemy returner implementation +new: sqlalchemy cache implementation diff --git a/doc/ref/cache/all/index.rst b/doc/ref/cache/all/index.rst index 4081bf54e1c8..bf401c757a0b 100644 --- a/doc/ref/cache/all/index.rst +++ b/doc/ref/cache/all/index.rst @@ -18,3 +18,4 @@ For understanding and usage of the cache modules see the :ref:`cache` topic. localfs_key mysql_cache redis_cache + sqlalchemy diff --git a/doc/ref/cache/all/salt.cache.sqlalchemy.rst b/doc/ref/cache/all/salt.cache.sqlalchemy.rst new file mode 100644 index 000000000000..beed59bf16ff --- /dev/null +++ b/doc/ref/cache/all/salt.cache.sqlalchemy.rst @@ -0,0 +1,5 @@ +salt.cache.sqlalchemy +====================== + +.. automodule:: salt.cache.sqlalchemy + :members: diff --git a/doc/ref/returners/all/index.rst b/doc/ref/returners/all/index.rst index e7eb08dc8e1d..5cce1f07142d 100644 --- a/doc/ref/returners/all/index.rst +++ b/doc/ref/returners/all/index.rst @@ -18,4 +18,5 @@ returner modules postgres postgres_local_cache rawfile_json + sqlalchemy syslog_return diff --git a/doc/ref/returners/all/salt.returners.sqlalchemy.rst b/doc/ref/returners/all/salt.returners.sqlalchemy.rst new file mode 100644 index 000000000000..69a88fd48081 --- /dev/null +++ b/doc/ref/returners/all/salt.returners.sqlalchemy.rst @@ -0,0 +1,5 @@ +salt.returners.sqlalchemy +========================== + +.. automodule:: salt.returners.sqlalchemy + :members: diff --git a/doc/ref/runners/all/index.rst b/doc/ref/runners/all/index.rst index 3bf5b192d675..209bce221912 100644 --- a/doc/ref/runners/all/index.rst +++ b/doc/ref/runners/all/index.rst @@ -31,6 +31,7 @@ runner modules salt saltutil sdb + sqlalchemy ssh state survey diff --git a/doc/ref/runners/all/salt.runners.sqlalchemy.rst b/doc/ref/runners/all/salt.runners.sqlalchemy.rst new file mode 100644 index 000000000000..150505926b86 --- /dev/null +++ b/doc/ref/runners/all/salt.runners.sqlalchemy.rst @@ -0,0 +1,5 @@ +salt.runners.sqlalchemy +======================= + +.. 
automodule:: salt.runners.sqlalchemy + :members: diff --git a/requirements/static/ci/common.in b/requirements/static/ci/common.in index 9aad00959b83..0dc995198a69 100644 --- a/requirements/static/ci/common.in +++ b/requirements/static/ci/common.in @@ -48,3 +48,4 @@ genshi>=0.7.3 cheetah3>=3.2.2 mako wempy +sqlalchemy diff --git a/requirements/static/ci/py3.10/darwin.txt b/requirements/static/ci/py3.10/darwin.txt index 390d4555c230..7ad847fdb4a0 100644 --- a/requirements/static/ci/py3.10/darwin.txt +++ b/requirements/static/ci/py3.10/darwin.txt @@ -149,6 +149,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/darwin.in idna==3.7 @@ -482,6 +484,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -524,6 +528,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.10/darwin.txt diff --git a/requirements/static/ci/py3.10/freebsd.txt b/requirements/static/ci/py3.10/freebsd.txt index c0bafadf7e44..ab404348c4b6 100644 --- a/requirements/static/ci/py3.10/freebsd.txt +++ b/requirements/static/ci/py3.10/freebsd.txt @@ -148,6 +148,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/freebsd.in idna==3.7 @@ -487,6 +489,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -529,6 +533,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.10/freebsd.txt diff --git a/requirements/static/ci/py3.10/linux.txt b/requirements/static/ci/py3.10/linux.txt index 25813d5958a8..d5d2fc6ca752 100644 --- a/requirements/static/ci/py3.10/linux.txt +++ b/requirements/static/ci/py3.10/linux.txt @@ -164,6 +164,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy h11==0.14.0 # via httpcore hglib==2.6.2 @@ -549,6 +551,8 @@ sniffio==1.3.0 # anyio # httpcore # httpx +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -593,6 +597,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.10/linux.txt diff --git a/requirements/static/ci/py3.10/windows.txt b/requirements/static/ci/py3.10/windows.txt index ffc507c041e7..fa9807661573 100644 --- a/requirements/static/ci/py3.10/windows.txt +++ b/requirements/static/ci/py3.10/windows.txt @@ -148,6 +148,8 @@ gitpython==3.1.43 ; sys_platform == "win32" # -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy idna==3.7 # via # -c requirements/static/ci/../pkg/py3.10/windows.txt @@ -447,6 +449,8 @@ smmap==5.0.1 # via # -c requirements/static/ci/../pkg/py3.10/windows.txt # gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -474,7 
+478,9 @@ trustme==1.1.0 types-pyyaml==6.0.1 # via responses typing-extensions==4.8.0 - # via pytest-system-statistics + # via + # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.10/windows.txt diff --git a/requirements/static/ci/py3.11/darwin.txt b/requirements/static/ci/py3.11/darwin.txt index d74102d117cb..101f97a981f8 100644 --- a/requirements/static/ci/py3.11/darwin.txt +++ b/requirements/static/ci/py3.11/darwin.txt @@ -142,6 +142,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/darwin.in idna==3.7 @@ -475,6 +477,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -515,6 +519,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.11/darwin.txt diff --git a/requirements/static/ci/py3.11/freebsd.txt b/requirements/static/ci/py3.11/freebsd.txt index e137073ac32c..45e64b597cb3 100644 --- a/requirements/static/ci/py3.11/freebsd.txt +++ b/requirements/static/ci/py3.11/freebsd.txt @@ -141,6 +141,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/freebsd.in idna==3.7 @@ -481,6 +483,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -521,6 +525,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.11/freebsd.txt diff --git a/requirements/static/ci/py3.11/linux.txt b/requirements/static/ci/py3.11/linux.txt index 8dd22b7480f9..b804daa6056b 100644 --- a/requirements/static/ci/py3.11/linux.txt +++ b/requirements/static/ci/py3.11/linux.txt @@ -155,6 +155,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy h11==0.14.0 # via httpcore hglib==2.6.2 @@ -541,6 +543,8 @@ sniffio==1.3.0 # anyio # httpcore # httpx +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -583,6 +587,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.11/linux.txt diff --git a/requirements/static/ci/py3.11/windows.txt b/requirements/static/ci/py3.11/windows.txt index 8fda3a66854b..2c3fd1a2f847 100644 --- a/requirements/static/ci/py3.11/windows.txt +++ b/requirements/static/ci/py3.11/windows.txt @@ -141,6 +141,8 @@ gitpython==3.1.43 ; sys_platform == "win32" # -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy idna==3.7 # via # -c requirements/static/ci/../pkg/py3.11/windows.txt @@ -440,6 +442,8 @@ smmap==5.0.1 # via # -c requirements/static/ci/../pkg/py3.11/windows.txt # gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -465,7 +469,9 @@ trustme==1.1.0 types-pyyaml==6.0.12.12 # via responses 
typing-extensions==4.8.0 - # via pytest-system-statistics + # via + # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.11/windows.txt diff --git a/requirements/static/ci/py3.12/cloud.txt b/requirements/static/ci/py3.12/cloud.txt index 7c1c26e22295..e70990424c62 100644 --- a/requirements/static/ci/py3.12/cloud.txt +++ b/requirements/static/ci/py3.12/cloud.txt @@ -201,6 +201,10 @@ google-auth==2.27.0 # via # -c requirements/static/ci/py3.12/linux.txt # kubernetes +greenlet==3.2.3 + # via + # -c requirements/static/ci/py3.12/linux.txt + # sqlalchemy idna==3.7 # via # -c requirements/static/ci/../pkg/py3.12/linux.txt @@ -684,6 +688,10 @@ smmap==5.0.1 # via # -c requirements/static/ci/py3.12/linux.txt # gitdb +sqlalchemy==2.0.41 + # via + # -c requirements/static/ci/py3.12/linux.txt + # -r requirements/static/ci/common.in sqlparse==0.5.0 # via # -c requirements/static/ci/py3.12/linux.txt @@ -744,6 +752,7 @@ typing-extensions==4.8.0 # -c requirements/static/ci/py3.12/linux.txt # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.12/linux.txt diff --git a/requirements/static/ci/py3.12/darwin.txt b/requirements/static/ci/py3.12/darwin.txt index c134d508921f..0df48b7a5968 100644 --- a/requirements/static/ci/py3.12/darwin.txt +++ b/requirements/static/ci/py3.12/darwin.txt @@ -142,6 +142,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/darwin.in idna==3.7 @@ -475,6 +477,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -515,6 +519,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.12/darwin.txt diff --git a/requirements/static/ci/py3.12/freebsd.txt b/requirements/static/ci/py3.12/freebsd.txt index 4f1d6814c223..9f42bd72266e 100644 --- a/requirements/static/ci/py3.12/freebsd.txt +++ b/requirements/static/ci/py3.12/freebsd.txt @@ -141,6 +141,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/freebsd.in idna==3.7 @@ -481,6 +483,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -521,6 +525,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.12/freebsd.txt diff --git a/requirements/static/ci/py3.12/lint.txt b/requirements/static/ci/py3.12/lint.txt index 03dcf9dfbc20..eb5f8900caf4 100644 --- a/requirements/static/ci/py3.12/lint.txt +++ b/requirements/static/ci/py3.12/lint.txt @@ -213,6 +213,10 @@ google-auth==2.27.0 # via # -c requirements/static/ci/py3.12/linux.txt # kubernetes +greenlet==3.2.3 + # via + # -c requirements/static/ci/py3.12/linux.txt + # sqlalchemy h11==0.14.0 # via # -c requirements/static/ci/py3.12/linux.txt @@ -689,6 +693,10 @@ sniffio==1.3.0 # anyio # httpcore # httpx +sqlalchemy==2.0.41 + # via + # -c requirements/static/ci/py3.12/linux.txt + # -r requirements/static/ci/common.in sqlparse==0.5.0 # via # -c 
requirements/static/ci/py3.12/linux.txt @@ -751,6 +759,7 @@ typing-extensions==4.8.0 # via # -c requirements/static/ci/py3.12/linux.txt # napalm + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.12/linux.txt diff --git a/requirements/static/ci/py3.12/linux.txt b/requirements/static/ci/py3.12/linux.txt index 92e23afb17d2..71178bea1d8e 100644 --- a/requirements/static/ci/py3.12/linux.txt +++ b/requirements/static/ci/py3.12/linux.txt @@ -155,6 +155,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy h11==0.14.0 # via httpcore hglib==2.6.2 @@ -541,6 +543,8 @@ sniffio==1.3.0 # anyio # httpcore # httpx +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -583,6 +587,7 @@ typing-extensions==4.8.0 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.12/linux.txt diff --git a/requirements/static/ci/py3.12/windows.txt b/requirements/static/ci/py3.12/windows.txt index 9e8e6f53aaf7..66d6aa122e59 100644 --- a/requirements/static/ci/py3.12/windows.txt +++ b/requirements/static/ci/py3.12/windows.txt @@ -141,6 +141,8 @@ gitpython==3.1.43 ; sys_platform == "win32" # -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy idna==3.7 # via # -c requirements/static/ci/../pkg/py3.12/windows.txt @@ -440,6 +442,8 @@ smmap==5.0.1 # via # -c requirements/static/ci/../pkg/py3.12/windows.txt # gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -465,7 +469,9 @@ trustme==1.1.0 types-pyyaml==6.0.12.12 # via responses typing-extensions==4.8.0 - # via pytest-system-statistics + # via + # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.12/windows.txt diff --git a/requirements/static/ci/py3.13/cloud.txt b/requirements/static/ci/py3.13/cloud.txt index 71e6b177dea3..8feb5626c70c 100644 --- a/requirements/static/ci/py3.13/cloud.txt +++ b/requirements/static/ci/py3.13/cloud.txt @@ -198,6 +198,10 @@ google-auth==2.35.0 # via # -c requirements/static/ci/py3.13/linux.txt # kubernetes +greenlet==3.2.3 + # via + # -c requirements/static/ci/py3.13/linux.txt + # sqlalchemy idna==3.10 # via # -c requirements/static/ci/../pkg/py3.13/linux.txt @@ -682,6 +686,10 @@ smmap==5.0.1 # via # -c requirements/static/ci/py3.13/linux.txt # gitdb +sqlalchemy==2.0.41 + # via + # -c requirements/static/ci/py3.13/linux.txt + # -r requirements/static/ci/common.in sqlparse==0.5.1 # via # -c requirements/static/ci/py3.13/linux.txt @@ -738,6 +746,7 @@ typing-extensions==4.12.2 # -c requirements/static/ci/py3.13/linux.txt # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.13/linux.txt diff --git a/requirements/static/ci/py3.13/darwin.txt b/requirements/static/ci/py3.13/darwin.txt index e54561c2eed9..ef768be9aa60 100644 --- a/requirements/static/ci/py3.13/darwin.txt +++ b/requirements/static/ci/py3.13/darwin.txt @@ -140,6 +140,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.35.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/darwin.in idna==3.10 @@ -476,6 +478,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via 
-r requirements/static/ci/common.in sqlparse==0.5.1 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -514,6 +518,7 @@ typing-extensions==4.12.2 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.13/darwin.txt diff --git a/requirements/static/ci/py3.13/freebsd.txt b/requirements/static/ci/py3.13/freebsd.txt index 6774d058bc29..90c89f64cd92 100644 --- a/requirements/static/ci/py3.13/freebsd.txt +++ b/requirements/static/ci/py3.13/freebsd.txt @@ -139,6 +139,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.35.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/freebsd.in idna==3.10 @@ -480,6 +482,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.1 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -518,6 +522,7 @@ typing-extensions==4.12.2 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.13/freebsd.txt diff --git a/requirements/static/ci/py3.13/lint.txt b/requirements/static/ci/py3.13/lint.txt index d6a48bdf77a7..80bec9c4b130 100644 --- a/requirements/static/ci/py3.13/lint.txt +++ b/requirements/static/ci/py3.13/lint.txt @@ -209,6 +209,10 @@ google-auth==2.35.0 # via # -c requirements/static/ci/py3.13/linux.txt # kubernetes +greenlet==3.2.3 + # via + # -c requirements/static/ci/py3.13/linux.txt + # sqlalchemy h11==0.14.0 # via # -c requirements/static/ci/py3.13/linux.txt @@ -685,6 +689,10 @@ sniffio==1.3.1 # -c requirements/static/ci/py3.13/linux.txt # anyio # httpx +sqlalchemy==2.0.41 + # via + # -c requirements/static/ci/py3.13/linux.txt + # -r requirements/static/ci/common.in sqlparse==0.5.1 # via # -c requirements/static/ci/py3.13/linux.txt @@ -743,6 +751,7 @@ typing-extensions==4.12.2 # via # -c requirements/static/ci/py3.13/linux.txt # napalm + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.13/linux.txt diff --git a/requirements/static/ci/py3.13/linux.txt b/requirements/static/ci/py3.13/linux.txt index 976075585d05..c8b74bff5abc 100644 --- a/requirements/static/ci/py3.13/linux.txt +++ b/requirements/static/ci/py3.13/linux.txt @@ -153,6 +153,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.35.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy h11==0.14.0 # via httpcore hglib==2.6.2 @@ -538,6 +540,8 @@ sniffio==1.3.1 # via # anyio # httpx +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.1 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -578,6 +582,7 @@ typing-extensions==4.12.2 # via # napalm # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.13/linux.txt diff --git a/requirements/static/ci/py3.13/windows.txt b/requirements/static/ci/py3.13/windows.txt index 1cc3e72fe154..cf50a9147d09 100644 --- a/requirements/static/ci/py3.13/windows.txt +++ b/requirements/static/ci/py3.13/windows.txt @@ -142,6 +142,8 @@ gitpython==3.1.43 ; sys_platform == "win32" # -r requirements/static/ci/common.in google-auth==2.35.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy idna==3.10 # via # -c requirements/static/ci/../pkg/py3.13/windows.txt @@ -443,6 +445,8 @@ smmap==5.0.1 # via # -c requirements/static/ci/../pkg/py3.13/windows.txt # gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in 
sqlparse==0.5.1 # via -r requirements/static/ci/common.in sspilib==0.2.0 @@ -468,7 +472,9 @@ tornado==6.4.1 trustme==1.2.0 # via -r requirements/pytest.txt typing-extensions==4.12.2 - # via pytest-system-statistics + # via + # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.13/windows.txt diff --git a/requirements/static/ci/py3.9/darwin.txt b/requirements/static/ci/py3.9/darwin.txt index 63234110fa02..1f61a623e96c 100644 --- a/requirements/static/ci/py3.9/darwin.txt +++ b/requirements/static/ci/py3.9/darwin.txt @@ -149,6 +149,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/darwin.in idna==3.7 @@ -482,6 +484,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -525,6 +529,7 @@ typing-extensions==4.8.0 # napalm # pytest-shell-utilities # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.9/darwin.txt diff --git a/requirements/static/ci/py3.9/freebsd.txt b/requirements/static/ci/py3.9/freebsd.txt index 5f58710a7f15..dfa3435217e1 100644 --- a/requirements/static/ci/py3.9/freebsd.txt +++ b/requirements/static/ci/py3.9/freebsd.txt @@ -148,6 +148,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy hglib==2.6.2 # via -r requirements/static/ci/freebsd.in idna==3.7 @@ -487,6 +489,8 @@ six==1.16.0 # websocket-client smmap==5.0.1 # via gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -530,6 +534,7 @@ typing-extensions==4.8.0 # napalm # pytest-shell-utilities # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.9/freebsd.txt diff --git a/requirements/static/ci/py3.9/linux.txt b/requirements/static/ci/py3.9/linux.txt index 3d9f9151ae72..c363f289288b 100644 --- a/requirements/static/ci/py3.9/linux.txt +++ b/requirements/static/ci/py3.9/linux.txt @@ -159,6 +159,8 @@ gitpython==3.1.43 # via -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy h11==0.14.0 # via httpcore hglib==2.6.2 @@ -539,6 +541,8 @@ sniffio==1.3.0 # anyio # httpcore # httpx +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r requirements/static/ci/common.in strict-rfc3339==0.7 @@ -582,6 +586,7 @@ typing-extensions==4.8.0 # napalm # pytest-shell-utilities # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.9/linux.txt diff --git a/requirements/static/ci/py3.9/windows.txt b/requirements/static/ci/py3.9/windows.txt index 0f8bf20ef18e..0e7cbb70056a 100644 --- a/requirements/static/ci/py3.9/windows.txt +++ b/requirements/static/ci/py3.9/windows.txt @@ -148,6 +148,8 @@ gitpython==3.1.43 ; sys_platform == "win32" # -r requirements/static/ci/common.in google-auth==2.27.0 # via kubernetes +greenlet==3.2.3 + # via sqlalchemy idna==3.7 # via # -c requirements/static/ci/../pkg/py3.9/windows.txt @@ -448,6 +450,8 @@ smmap==5.0.1 # via # -c requirements/static/ci/../pkg/py3.9/windows.txt # gitdb +sqlalchemy==2.0.41 + # via -r requirements/static/ci/common.in sqlparse==0.5.0 # via -r 
requirements/static/ci/common.in strict-rfc3339==0.7 @@ -478,6 +482,7 @@ typing-extensions==4.8.0 # via # pytest-shell-utilities # pytest-system-statistics + # sqlalchemy urllib3==1.26.20 # via # -c requirements/static/ci/../pkg/py3.9/windows.txt diff --git a/salt/auth/__init__.py b/salt/auth/__init__.py index f00e84ba33a0..98101a3d8039 100644 --- a/salt/auth/__init__.py +++ b/salt/auth/__init__.py @@ -13,18 +13,18 @@ # 6. Interface to verify tokens import getpass +import hashlib import logging +import os import random import time from collections.abc import Iterable, Mapping +import salt.cache import salt.channel.client -import salt.config import salt.exceptions import salt.loader -import salt.payload import salt.utils.args -import salt.utils.dictupdate import salt.utils.files import salt.utils.minions import salt.utils.network @@ -62,6 +62,10 @@ def __init__(self, opts, ckminions=None): self.auth = salt.loader.auth(opts) self.tokens = salt.loader.eauth_tokens(opts) self._ckminions = ckminions + tokens_cluster_id = opts["eauth_tokens.cluster_id"] or opts["cluster_id"] + self.cache = salt.cache.factory( + opts, driver=opts["eauth_tokens.cache_driver"], cluster_id=tokens_cluster_id + ) @cached_property def ckminions(self): @@ -235,54 +239,168 @@ def mk_token(self, load): if groups: tdata["groups"] = groups - return self.tokens["{}.mk_token".format(self.opts["eauth_tokens"])]( - self.opts, tdata - ) + if self.opts["eauth_tokens.cache_driver"] == "rediscluster": + salt.utils.versions.warn_until( + 3010, + "The 'rediscluster' token backend has been deprecated, and will be removed " + "in the Calcium release. Please use the 'redis_cache' cache backend instead.", + ) + return self.tokens["{}.mk_token".format(self.opts["eauth_tokens"])]( + self.opts, tdata + ) + else: + hash_type = getattr(hashlib, self.opts.get("hash_type", "md5")) + new_token = str(hash_type(os.urandom(512)).hexdigest()) + tdata["token"] = new_token + try: + self.cache.store("tokens", new_token, tdata, expires=tdata["expire"]) + except salt.exceptions.SaltCacheError as err: + log.error( + "Cannot mk_token from tokens cache using %s: %s", + self.opts["eauth_tokens.cache_driver"], + err, + ) + return {} + + return tdata def get_tok(self, tok): """ Return the name associated with the token, or False if the token is not valid """ - tdata = {} - try: - tdata = self.tokens["{}.get_token".format(self.opts["eauth_tokens"])]( - self.opts, tok + if self.opts["eauth_tokens.cache_driver"] == "rediscluster": + salt.utils.versions.warn_until( + 3010, + "The 'rediscluster' token backend has been deprecated, and will be removed " + "in the Calcium release. Please use the 'redis_cache' cache backend instead.", ) - except salt.exceptions.SaltDeserializationError: - log.warning("Failed to load token %r - removing broken/empty file.", tok) - rm_tok = True - else: - if not tdata: + + tdata = {} + try: + tdata = self.tokens["{}.get_token".format(self.opts["eauth_tokens"])]( + self.opts, tok + ) + except salt.exceptions.SaltDeserializationError: + log.warning( + "Failed to load token %r - removing broken/empty file.", tok + ) + rm_tok = True + else: + if not tdata: + return {} + rm_tok = False + + if tdata.get("expire", 0) < time.time(): + # If expire isn't present in the token it's invalid and needs + # to be removed. Also, if it's present and has expired - in + # other words, the expiration is before right now, it should + # be removed. 
+ rm_tok = True + + if rm_tok: + self.rm_token(tok) return {} - rm_tok = False - if tdata.get("expire", 0) < time.time(): - # If expire isn't present in the token it's invalid and needs - # to be removed. Also, if it's present and has expired - in - # other words, the expiration is before right now, it should - # be removed. - rm_tok = True + return tdata + else: + try: + tdata = self.cache.fetch("tokens", tok) - if rm_tok: - self.rm_token(tok) - return {} + if tdata.get("expire", 0) < time.time(): + raise salt.exceptions.TokenExpiredError - return tdata + return tdata + except ( + salt.exceptions.SaltDeserializationError, + salt.exceptions.TokenExpiredError, + ): + log.warning( + "Failed to load token %r - removing broken/empty file.", tok + ) + self.rm_token(tok) + except salt.exceptions.SaltCacheError as err: + log.error( + "Cannot get token %s from tokens cache using %s: %s", + tok, + self.opts["eauth_tokens.cache_driver"], + err, + ) + return {} def list_tokens(self): """ List all tokens in eauth_tokens storage. """ - return self.tokens["{}.list_tokens".format(self.opts["eauth_tokens"])]( - self.opts - ) + if self.opts["eauth_tokens.cache_driver"] == "rediscluster": + salt.utils.versions.warn_until( + 3010, + "The 'rediscluster' token backend has been deprecated, and will be removed " + "in the Calcium release. Please use the 'redis_cache' cache backend instead.", + ) + + return self.tokens["{}.list_tokens".format(self.opts["eauth_tokens"])]( + self.opts + ) + else: + try: + return self.cache.list("tokens") + except salt.exceptions.SaltCacheError as err: + log.error( + "Cannot list tokens from tokens cache using %s: %s", + self.opts["eauth_tokens.cache_driver"], + err, + ) + return [] def rm_token(self, tok): """ Remove the given token from token storage. """ - self.tokens["{}.rm_token".format(self.opts["eauth_tokens"])](self.opts, tok) + if self.opts["eauth_tokens.cache_driver"] == "rediscluster": + salt.utils.versions.warn_until( + 3010, + "The 'rediscluster' token backend has been deprecated, and will be removed " + "in the Calcium release. Please use the 'redis_cache' cache backend instead.", + ) + + self.tokens["{}.rm_token".format(self.opts["eauth_tokens"])](self.opts, tok) + else: + try: + return self.cache.flush("tokens", tok) + except salt.exceptions.SaltCacheError as err: + log.error( + "Cannot rm token %s from tokens cache using %s: %s", + tok, + self.opts["eauth_tokens.cache_driver"], + err, + ) + return {} + + def clean_expired_tokens(self): + """ + Clean expired tokens + """ + if self.opts["eauth_tokens.cache_driver"] == "rediscluster": + salt.utils.versions.warn_until( + 3010, + "The 'rediscluster' token backend has been deprecated, and will be removed " + "in the Calcium release. 
Please use the 'redis_cache' cache backend instead.", + ) + log.debug( + "cleaning expired tokens using token driver: %s", + self.opts["eauth_tokens"], + ) + for token in self.list_tokens(): + token_data = self.get_tok(token) + if ( + "expire" not in token_data + or token_data.get("expire", 0) < time.time() + ): + self.rm_token(token) + else: + self.cache.clean_expired("tokens") def authenticate_token(self, load): """ @@ -609,6 +727,15 @@ def get_token(self, token): tdata = self._send_token_request(load) return tdata + def rm_token(self, token): + """ + Delete a token from the master + """ + load = {} + load["token"] = token + load["cmd"] = "rm_token" + self._send_token_request(load) + class AuthUser: """ diff --git a/salt/cache/__init__.py b/salt/cache/__init__.py index 8b0389d313d7..030ab9a1b2ba 100644 --- a/salt/cache/__init__.py +++ b/salt/cache/__init__.py @@ -4,12 +4,14 @@ .. versionadded:: 2016.11.0 """ +import datetime import logging import time import salt.config import salt.loader import salt.syspaths +from salt.utils.decorators import cached_property from salt.utils.odict import OrderedDict log = logging.getLogger(__name__) @@ -58,30 +60,28 @@ class Cache: def __init__(self, opts, cachedir=None, **kwargs): self.opts = opts - if cachedir is None: - self.cachedir = opts.get("cachedir", salt.syspaths.CACHE_DIR) + + if kwargs.get("driver"): + self.driver = kwargs["driver"] else: - self.cachedir = cachedir - self.driver = kwargs.get( - "driver", opts.get("cache", salt.config.DEFAULT_MASTER_OPTS["cache"]) + self.driver = opts.get("cache", salt.config.DEFAULT_MASTER_OPTS["cache"]) + + self.cachedir = kwargs["cachedir"] = cachedir or opts.get( + "cachedir", salt.syspaths.CACHE_DIR ) self._modules = None self._kwargs = kwargs - self._kwargs["cachedir"] = self.cachedir - def __lazy_init(self): - self._modules = salt.loader.cache(self.opts) - fun = f"{self.driver}.init_kwargs" - if fun in self.modules: - self._kwargs = self.modules[fun](self._kwargs) - else: - self._kwargs = {} - - @property + @cached_property def modules(self): - if self._modules is None: - self.__lazy_init() - return self._modules + return salt.loader.cache(self.opts) + + @cached_property + def kwargs(self): + try: + return self.modules[f"{self.driver}.init_kwargs"](self._kwargs) + except KeyError: + return {} def cache(self, bank, key, fun, loop_fun=None, **kwargs): """ @@ -121,7 +121,7 @@ def cache(self, bank, key, fun, loop_fun=None, **kwargs): return data - def store(self, bank, key, data): + def store(self, bank, key, data, expires=None): """ Store data using the specified module :param bank: The name of the location inside the cache which will hold the key and its associated data. :param key: The name of the key (or file inside a directory) which will hold the data. File extensions should not be used, since they will be added by the driver itself. :param data: The data which will be stored in the cache. This data should be in a format which can be serialized by msgpack. - :raises SaltCacheError: + :param expires: + How many seconds from now the data should be considered stale. + + :raises SaltCacheError: Raises an exception if cache driver detected an error accessing data in the cache backend (auth, permissions, etc). 
""" fun = f"{self.driver}.store" - return self.modules[fun](bank, key, data, **self._kwargs) + try: + return self.modules[fun](bank, key, data, expires=expires, **self.kwargs) + except TypeError: + # if the backing store doesnt natively support expiry, we handle it as a fallback + if expires: + expires_at = datetime.datetime.now().astimezone() + datetime.timedelta( + seconds=expires + ) + expires_at = int(expires_at.timestamp()) + return self.modules[fun]( + bank, key, {"data": data, "_expires": expires_at}, **self.kwargs + ) + else: + return self.modules[fun](bank, key, data, **self.kwargs) def fetch(self, bank, key): """ @@ -167,7 +183,17 @@ def fetch(self, bank, key): in the cache backend (auth, permissions, etc). """ fun = f"{self.driver}.fetch" - return self.modules[fun](bank, key, **self._kwargs) + ret = self.modules[fun](bank, key, **self.kwargs) + + # handle fallback if necessary + if isinstance(ret, dict) and set(ret.keys()) == {"data", "_expires"}: + now = datetime.datetime.now().astimezone().timestamp() + if ret["_expires"] > now: + return ret["data"] + else: + return {} + else: + return ret def updated(self, bank, key): """ @@ -191,7 +217,7 @@ def updated(self, bank, key): in the cache backend (auth, permissions, etc). """ fun = f"{self.driver}.updated" - return self.modules[fun](bank, key, **self._kwargs) + return self.modules[fun](bank, key, **self.kwargs) def flush(self, bank, key=None): """ @@ -212,7 +238,7 @@ def flush(self, bank, key=None): in the cache backend (auth, permissions, etc). """ fun = f"{self.driver}.flush" - return self.modules[fun](bank, key=key, **self._kwargs) + return self.modules[fun](bank, key=key, **self.kwargs) def list(self, bank): """ @@ -231,7 +257,7 @@ def list(self, bank): in the cache backend (auth, permissions, etc). """ fun = f"{self.driver}.list" - return self.modules[fun](bank, **self._kwargs) + return self.modules[fun](bank, **self.kwargs) def contains(self, bank, key=None): """ @@ -256,7 +282,33 @@ def contains(self, bank, key=None): in the cache backend (auth, permissions, etc). """ fun = f"{self.driver}.contains" - return self.modules[fun](bank, key, **self._kwargs) + return self.modules[fun](bank, key, **self.kwargs) + + def clean_expired(self, bank, *args, **kwargs): + """ + Clean expired keys + + :param bank: + The name of the location inside the cache which will hold the key + and its associated data. + + :raises SaltCacheError: + Raises an exception if cache driver detected an error accessing data + in the cache backend (auth, permissions, etc). + """ + # If the cache driver has a clean_expired() func, call it to clean up + # expired keys. 
+ clean_expired = f"{self.driver}.clean_expired" + if clean_expired in self.modules: + self.modules[clean_expired](bank, *args, **{**self.kwargs, **kwargs}) + else: + list_ = f"{self.driver}.list" + updated = f"{self.driver}.updated" + flush = f"{self.driver}.flush" + for key in self.modules[list_](bank, **self.kwargs): + ts = self.modules[updated](bank, key, **self.kwargs) + if ts is not None and ts <= time.time(): + self.modules[flush](bank, key, **self.kwargs) class MemCache(Cache): @@ -309,21 +361,28 @@ def fetch(self, bank, key): if self.debug: self.call += 1 now = time.time() + expires = None record = self.storage.pop((bank, key), None) # Have a cached value for the key - if record is not None and record[0] + self.expire >= now: - if self.debug: - self.hit += 1 - log.debug( - "MemCache stats (call/hit/rate): %s/%s/%s", - self.call, - self.hit, - float(self.hit) / self.call, - ) - # update atime and return - record[0] = now - self.storage[(bank, key)] = record - return record[1] + if record is not None: + if len(record) == 2: + (created_at, data) = record + elif len(record) == 3: + (created_at, expires, data) = record + + if (created_at + (expires or self.expire)) >= now: + if self.debug: + self.hit += 1 + log.debug( + "MemCache stats (call/hit/rate): %s/%s/%s", + self.call, + self.hit, + float(self.hit) / self.call, + ) + # update atime and return + record[0] = now + self.storage[(bank, key)] = record + return data # Have no value for the key or value is expired data = super().fetch(bank, key) @@ -332,18 +391,18 @@ def fetch(self, bank, key): MemCache.__cleanup(self.expire) if len(self.storage) >= self.max: self.storage.popitem(last=False) - self.storage[(bank, key)] = [now, data] + self.storage[(bank, key)] = [now, self.expire, data] return data - def store(self, bank, key, data): + def store(self, bank, key, data, expires=None): self.storage.pop((bank, key), None) - super().store(bank, key, data) + super().store(bank, key, data, expires=expires) if len(self.storage) >= self.max: if self.cleanup: MemCache.__cleanup(self.expire) if len(self.storage) >= self.max: self.storage.popitem(last=False) - self.storage[(bank, key)] = [time.time(), data] + self.storage[(bank, key)] = [time.time(), expires, data] def flush(self, bank, key=None): if key is None: diff --git a/salt/cache/localfs.py b/salt/cache/localfs.py index 96a9a13aeb41..af701b727349 100644 --- a/salt/cache/localfs.py +++ b/salt/cache/localfs.py @@ -76,6 +76,7 @@ def fetch(bank, key, cachedir): inkey = False key_file = salt.utils.path.join(cachedir, os.path.normpath(bank), f"{key}.p") if not os.path.isfile(key_file): + log.debug('Cache file "%s" does not exist', key_file) # The bank includes the full filename, and the key is inside the file key_file = salt.utils.path.join(cachedir, os.path.normpath(bank) + ".p") inkey = True diff --git a/salt/cache/sqlalchemy.py b/salt/cache/sqlalchemy.py new file mode 100644 index 000000000000..5dec5944afaa --- /dev/null +++ b/salt/cache/sqlalchemy.py @@ -0,0 +1,284 @@ +""" +Cache plugin for SQLAlchemy +""" + +import datetime +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING + +import salt.sqlalchemy +from salt.sqlalchemy import model_for + +try: + import sqlalchemy.exc + from sqlalchemy import delete, insert, or_, select, tuple_, update + from sqlalchemy.dialects.postgresql import insert as pg_insert + from sqlalchemy.sql.functions import count, now +except ImportError: + pass + + +if TYPE_CHECKING: + __opts__ = {} + __salt__: dict[str, Callable] + + +log 
= logging.getLogger(__name__) + +__virtualname__ = "sqlalchemy" + + +def __virtual__(): + """ + Confirm that SQLAlchemy is set up + """ + if not salt.sqlalchemy.orm_configured(): + salt.sqlalchemy.configure_orm(__opts__) + return __virtualname__ + + +def init_kwargs(kwargs): + """ + init kwargs + """ + cluster_id = kwargs.get("cluster_id", __opts__["cluster_id"]) + + # we use cluster_id as a pk, None/null won't work + if not cluster_id: + cluster_id = "null" + + return { + "cluster_id": cluster_id, + "expires": kwargs.get("expires"), + "engine_name": kwargs.get("engine_name"), + } + + +def fetch(bank, key, cluster_id=None, expires=None, engine_name=None): + """ + Fetch a key value. + """ + Cache = model_for("Cache", engine_name=engine_name) + with salt.sqlalchemy.ROSession(engine_name) as session: + stmt = select(Cache).where( + Cache.key == key, + Cache.bank == bank, + Cache.cluster == cluster_id, + or_( + Cache.expires_at.is_(None), + Cache.expires_at >= now(), + ), + ) + + result = session.execute(stmt).scalars().first() + + data = {} + + if result: + data = result.data + + session.commit() + + return data + + +def store(bank, key, data, expires=None, cluster_id=None, engine_name=None): + """ + Store a key value. + """ + # no expiry unless one is provided or derivable from the payload + expires_at = None + + if expires: + expires_at = datetime.datetime.now().astimezone() + datetime.timedelta( + seconds=expires + ) + elif isinstance(data, dict) and "expire" in data: + if isinstance(data["expire"], float): + # only convert if unix timestamp + expires_at = datetime.datetime.fromtimestamp(data["expire"]).isoformat() + else: + expires_at = None + + log.trace( + "storing %s:%s:%s:%s:%s", + bank, + key, + data, + expires_at, + cluster_id, + ) + + Cache = model_for("Cache", engine_name=engine_name) + with salt.sqlalchemy.Session(engine_name) as session: + if session.bind.dialect.name == "postgresql": + stmt = pg_insert(Cache).values( + key=key, + bank=bank, + data=data, + expires_at=expires_at, + cluster=cluster_id, + ) + stmt = stmt.on_conflict_do_update( + index_elements=[Cache.cluster, Cache.bank, Cache.key], + set_=dict( + data=stmt.excluded.data, + expires_at=stmt.excluded.expires_at, + created_at=now(), + ), + ) + + session.execute(stmt) + session.commit() + else: + # the default path is racy, so any implementation-specific upsert is preferred + try: + stmt = insert(Cache).values( + key=key, bank=bank, cluster=cluster_id, data=data + ) + session.execute(stmt) + session.commit() + except sqlalchemy.exc.IntegrityError: + session.rollback() + + stmt = ( + update(Cache) + .where( + Cache.key == key, + Cache.bank == bank, + Cache.cluster == cluster_id, + ) + .values(data=data) + ) + session.execute(stmt) + session.commit() + + +def flush(bank, key=None, cluster_id=None, engine_name=None, **_): + """ + Remove the key from the cache bank with all the key content. + """ + log.trace("flushing %s:%s", bank, key) + + Cache = model_for("Cache", engine_name=engine_name) + with salt.sqlalchemy.Session(engine_name) as session: + stmt = delete(Cache).where(Cache.cluster == cluster_id, Cache.bank == bank) + + if key: + stmt = stmt.where(Cache.key == key) + + session.execute(stmt) + session.commit() + + +def list(bank, cluster_id=None, engine_name=None, **_): + """ + Return an iterable object containing all entries stored in the specified + bank. 
+ """ + log.trace("listing %s, cluster: %s", bank, cluster_id) + + Cache = model_for("Cache", engine_name=engine_name) + with salt.sqlalchemy.ROSession(engine_name) as session: + stmt = ( + select(Cache.key) + .where( + Cache.cluster == cluster_id, + Cache.bank == bank, + or_( + Cache.expires_at.is_(None), + Cache.expires_at >= now(), + ), + ) + .order_by(Cache.key) + ) + keys = session.execute(stmt).scalars().all() + session.commit() + return keys + + +def contains(bank, key, cluster_id=None, engine_name=None, **_): + """ + Checks if the specified bank contains the specified key. + """ + log.trace("check if %s in %s, cluster: %s", key, bank, cluster_id) + + Cache = model_for("Cache", engine_name=engine_name) + with salt.sqlalchemy.ROSession(engine_name) as session: + if key is None: + stmt = select(count()).where( + Cache.cluster == cluster_id, + Cache.bank == bank, + or_( + Cache.expires_at.is_(None), + Cache.expires_at >= now(), + ), + ) + key = session.execute(stmt).scalars().first() + session.commit() + return key > 0 + else: + stmt = select(Cache.key).where( + Cache.cluster == cluster_id, + Cache.bank == bank, + Cache.key == key, + or_( + Cache.expires_at.is_(None), + Cache.expires_at >= now(), + ), + ) + key = session.execute(stmt).scalars().first() + session.commit() + return key is not None + + +def updated(bank, key, cluster_id=None, engine_name=None, **_): + """ + Given a bank and key, return the epoch of the created_at. + """ + log.trace("returning epoch key %s at %s, cluster: %s", key, bank, cluster_id) + + Cache = model_for("Cache", engine_name=engine_name) + with salt.sqlalchemy.ROSession(engine_name) as session: + stmt = select(Cache.created_at).where( + Cache.cluster == cluster_id, + Cache.bank == bank, + Cache.key == key, + ) + created_at = session.execute(stmt).scalars().first() + session.commit() + + if created_at: + return created_at.timestamp() + + +def clean_expired(bank, cluster_id=None, limit=None, engine_name=None, **_): + """ + Delete keys from a bank that has expired keys if the + 'expires_at' column is not null. 
+ """ + log.trace( + "sqlalchemy.clean_expired: removing expired keys at bank %s, cluster: %s", + bank, + cluster_id, + ) + + Cache = model_for("Cache", engine_name=engine_name) + with salt.sqlalchemy.Session(engine_name) as session: + subq = select(Cache.bank, Cache.key, Cache.cluster).where( + Cache.cluster == cluster_id, + Cache.bank == bank, + (Cache.expires_at.isnot(None)) & (Cache.expires_at <= now()), + ) + + if limit: + subq = subq.limit(limit) + + stmt = ( + delete(Cache) + .where(tuple_(Cache.bank, Cache.key, Cache.cluster).in_(subq)) + .returning(Cache.key) + ) + + result = session.execute(stmt) + expired = result.scalars().all() + session.commit() + return expired diff --git a/salt/client/__init__.py b/salt/client/__init__.py index 474c305df0e3..ecbfaa867b8c 100644 --- a/salt/client/__init__.py +++ b/salt/client/__init__.py @@ -1682,13 +1682,10 @@ def get_cli_event_returns( ).connected_ids() if ( self.opts["minion_data_cache"] - and salt.cache.factory(self.opts).contains( - f"minions/{id_}", "data" - ) + and salt.cache.factory(self.opts).contains("grains", id_) and connected_minions and id_ not in connected_minions ): - yield { id_: { "out": "no_return", diff --git a/salt/config/__init__.py b/salt/config/__init__.py index f8aca19fa360..cd6c7a772e0e 100644 --- a/salt/config/__init__.py +++ b/salt/config/__init__.py @@ -1020,6 +1020,40 @@ def _gather_buffer_space(): "keys.cache_driver": (type(None), str), "request_server_ttl": int, "request_server_aes_session": int, + # optional cache driver for pillar cache + "pillar.cache_driver": (type(None), str), + # optional cache driver for eauth_tokens cache + "eauth_tokens.cache_driver": (type(None), str), + # eauth tokens cluster id override + "eauth_tokens.cluster_id": (type(None), str), + # sqlalchemy settings + "sqlalchemy.dsn": (type(None), str), + "sqlalchemy.driver": (type(None), str), + "sqlalchemy.host": (type(None), str), + "sqlalchemy.port": (type(None), str), + "sqlalchemy.user": (type(None), str), + "sqlalchemy.password": (type(None), str), + "sqlalchemy.db": (type(None), str), + "sqlalchemy.engine_opts": (type(None), str), + "sqlalchemy.disable_connection_pool": bool, + "sqlalchemy.ro_dsn": (type(None), str), + "sqlalchemy.ro_host": (type(None), str), + "sqlalchemy.ro_port": (type(None), str), + "sqlalchemy.ro_user": (type(None), str), + "sqlalchemy.ro_password": (type(None), str), + "sqlalchemy.ro_db": (type(None), str), + "sqlalchemy.echo": (type(None), bool), + "sqlalchemy.slow_query_threshold": (type(None), int, float), + "sqlalchemy.slow_connect_threshold": (type(None), int, float), + "sqlalchemy.ro_engine_opts": (type(None), str), + "sqlalchemy.ro_disable_connection_pool": bool, + "sqlalchemy.partman.enabled": bool, + "sqlalchemy.partman.schema": (type(None), str), + "sqlalchemy.partman.interval": str, + "sqlalchemy.partman.jobmon": bool, + "sqlalchemy.partman.retention": (type(None), str), + "returner.sqlalchemy.max_retries": (type(None), int), + "returner.sqlalchemy.retry_delay": (type(None), int), } ) @@ -1330,6 +1364,7 @@ def _gather_buffer_space(): "encryption_algorithm": "OAEP-SHA1", "signing_algorithm": "PKCS1v15-SHA1", "keys.cache_driver": "localfs_key", + "pillar.cache_driver": None, } ) @@ -1687,6 +1722,35 @@ def _gather_buffer_space(): "keys.cache_driver": "localfs_key", "request_server_aes_session": 0, "request_server_ttl": 0, + "pillar.cache_driver": None, + "eauth_tokens.cache_driver": None, + "eauth_tokens.cluster_id": None, + "sqlalchemy.driver": None, + "sqlalchemy.dsn": None, + "sqlalchemy.host": 
None, + "sqlalchemy.port": None, + "sqlalchemy.user": None, + "sqlalchemy.password": None, + "sqlalchemy.db": None, + "sqlalchemy.engine_opts": None, + "sqlalchemy.disable_connection_pool": False, + "sqlalchemy.slow_connect_threshold": 1, + "sqlalchemy.slow_query_threshold": 1, + "sqlalchemy.ro_dsn": None, + "sqlalchemy.ro_host": None, + "sqlalchemy.ro_port": None, + "sqlalchemy.ro_user": None, + "sqlalchemy.ro_password": None, + "sqlalchemy.ro_db": None, + "sqlalchemy.ro_engine_opts": None, + "sqlalchemy.ro_disable_connection_pool": False, + "sqlalchemy.partman.enabled": False, + "sqlalchemy.partman.retention": None, + "sqlalchemy.partman.schema": "pgpartman", + "sqlalchemy.partman.interval": "weekly", + "sqlalchemy.partman.jobmon": True, + "returner.sqlalchemy.max_retries": 15, + "returner.sqlalchemy.retry_delay": 5, } ) diff --git a/salt/daemons/masterapi.py b/salt/daemons/masterapi.py index 1ab16b36e776..643f4078f2cc 100644 --- a/salt/daemons/masterapi.py +++ b/salt/daemons/masterapi.py @@ -141,10 +141,7 @@ def clean_expired_tokens(opts): Clean expired tokens from the master """ loadauth = salt.auth.LoadAuth(opts) - for tok in loadauth.list_tokens(): - token_data = loadauth.get_tok(tok) - if "expire" not in token_data or token_data.get("expire", 0) < time.time(): - loadauth.rm_token(tok) + loadauth.clean_expired_tokens() def clean_pub_auth(opts): @@ -624,7 +621,7 @@ def _mine_get(self, load, skip_verify=False): minions = _res["minions"] minion_side_acl = {} # Cache minion-side ACL for minion in minions: - mine_data = self.cache.fetch(f"minions/{minion}", "mine") + mine_data = self.cache.fetch("mine", minion) if not isinstance(mine_data, dict): continue for function in functions_allowed: @@ -675,8 +672,8 @@ def _mine(self, load, skip_verify=False): if self.opts.get("minion_data_cache", False) or self.opts.get( "enforce_mine_cache", False ): - cbank = "minions/{}".format(load["id"]) - ckey = "mine" + ckey = load["id"] + cbank = "mine" new_data = load["data"] if not load.get("clear", False): data = self.cache.fetch(cbank, ckey) @@ -694,8 +691,8 @@ def _mine_delete(self, load): if self.opts.get("minion_data_cache", False) or self.opts.get( "enforce_mine_cache", False ): - cbank = "minions/{}".format(load["id"]) - ckey = "mine" + cbank = "mine" + ckey = load["id"] try: data = self.cache.fetch(cbank, ckey) if not isinstance(data, dict): @@ -716,7 +713,7 @@ def _mine_flush(self, load, skip_verify=False): if self.opts.get("minion_data_cache", False) or self.opts.get( "enforce_mine_cache", False ): - return self.cache.flush("minions/{}".format(load["id"]), "mine") + return self.cache.flush("mine", load["id"]) return True def _file_recv(self, load): @@ -789,11 +786,7 @@ def _pillar(self, load): ) data = pillar.compile_pillar() if self.opts.get("minion_data_cache", False): - self.cache.store( - "minions/{}".format(load["id"]), - "data", - {"grains": load["grains"], "pillar": data}, - ) + self.cache.store("grains", load["id"], load["grains"]) if self.opts.get("minion_data_cache_events") is True: self.event.fire_event( {"comment": "Minion data cache refresh"}, diff --git a/salt/exceptions.py b/salt/exceptions.py index 2b04bcf453c9..2a628af0fd59 100644 --- a/salt/exceptions.py +++ b/salt/exceptions.py @@ -350,6 +350,12 @@ class TokenAuthenticationError(SaltException): """ +class TokenExpiredError(SaltException): + """ + Thrown when token is expired + """ + + class SaltDeserializationError(SaltException): """ Thrown when salt cannot deserialize data. 
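To make the new per-key expiry contract concrete: `Cache.store()` first tries to pass `expires` through to the driver, and only when the driver's `store()` rejects the keyword does it fall back to wrapping the payload in a `{"data": ..., "_expires": <unix timestamp>}` envelope that `Cache.fetch()` later unwraps, treating stale entries as a miss. A minimal, self-contained sketch of that wrap/unwrap round trip (the plain dict stands in for a driver with no native TTL support; the names here are illustrative, not Salt's API):

```python
import datetime
import time

backend = {}  # stands in for a cache driver with no native TTL support


def store(bank, key, data, expires=None):
    # Mirror of the fallback in Cache.store(): wrap the payload together
    # with an absolute expiry timestamp when the driver cannot do TTLs.
    if expires:
        expires_at = int(
            (
                datetime.datetime.now().astimezone()
                + datetime.timedelta(seconds=expires)
            ).timestamp()
        )
        backend[(bank, key)] = {"data": data, "_expires": expires_at}
    else:
        backend[(bank, key)] = data


def fetch(bank, key):
    # Mirror of Cache.fetch(): unwrap the envelope and treat stale
    # entries as a miss (an empty dict).
    ret = backend.get((bank, key), {})
    if isinstance(ret, dict) and set(ret.keys()) == {"data", "_expires"}:
        now = datetime.datetime.now().astimezone().timestamp()
        return ret["data"] if ret["_expires"] > now else {}
    return ret


store("tokens", "abc123", {"name": "saltdev"}, expires=2)
assert fetch("tokens", "abc123") == {"name": "saltdev"}
time.sleep(2.1)
assert fetch("tokens", "abc123") == {}  # expired entries read as a miss
```

The envelope shape matters here: `fetch()` only unwraps dicts whose keys are exactly `{"data", "_expires"}`, so ordinary cached dicts pass through untouched.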
diff --git a/salt/key.py b/salt/key.py index aa2b3588eba7..a2fda179e8d4 100644 --- a/salt/key.py +++ b/salt/key.py @@ -474,20 +474,28 @@ def check_minion_cache(self, preserve_minions=None): Optionally, pass in a list of minions which should have their caches preserved. To preserve all caches, set __opts__['preserve_minion_cache'] """ + if self.opts.get("preserve_minion_cache", False): + return + if preserve_minions is None: preserve_minions = [] + preserve_minions = set(preserve_minions) + keys = self.list_keys() - minions = [] - for key, val in keys.items(): - minions.extend(val) - if not self.opts.get("preserve_minion_cache", False): - # we use a new cache instance here as we dont want the key cache - cache = salt.cache.factory(self.opts) - clist = cache.list(self.ACC) - if clist: - for minion in clist: - if minion not in minions and minion not in preserve_minions: - cache.flush(f"{self.ACC}/{minion}") + + for val in keys.values(): + preserve_minions.update(val) + + # we use a new cache instance here as we don't want the key cache + cache = salt.cache.factory(self.opts) + + for bank in ["grains", "pillar"]: + clist = set(cache.list(bank)) + for minion in clist - preserve_minions: + # pillar optionally encodes pillarenv in the key as minion:$pillarenv + if ":" in minion and minion.split(":")[0] in preserve_minions: + continue + cache.flush(bank, minion) def check_master(self): """ diff --git a/salt/master.py b/salt/master.py index 3fb369f41bc2..61aa3979afcd 100644 --- a/salt/master.py +++ b/salt/master.py @@ -730,16 +730,6 @@ def _pre_flight(self): if not self.opts["fileserver_backend"]: errors.append("No fileserver backends are configured") - # Check to see if we need to create a pillar cache dir - if self.opts["pillar_cache"] and not os.path.isdir( - os.path.join(self.opts["cachedir"], "pillar_cache") - ): - try: - with salt.utils.files.set_umask(0o077): - os.mkdir(os.path.join(self.opts["cachedir"], "pillar_cache")) - except OSError: - pass - if self.opts.get("git_pillar_verify_config", True): try: git_pillars = [ @@ -1801,11 +1791,8 @@ def _pillar(self, load): data = pillar.compile_pillar() self.fs_.update_opts() if self.opts.get("minion_data_cache", False): - self.masterapi.cache.store( - "minions/{}".format(load["id"]), - "data", - {"grains": load["grains"], "pillar": data}, - ) + self.masterapi.cache.store("grains", load["id"], load["grains"]) + if self.opts.get("minion_data_cache_events") is True: self.event.fire_event( {"Minion data cache refresh": load["id"]}, diff --git a/salt/pillar/__init__.py b/salt/pillar/__init__.py index 3812de861668..1cb102c7c3cd 100644 --- a/salt/pillar/__init__.py +++ b/salt/pillar/__init__.py @@ -4,24 +4,25 @@ import collections import copy +import datetime import fnmatch import logging -import os import sys import time import traceback import tornado.gen +import salt.cache import salt.channel.client import salt.fileclient import salt.loader import salt.minion import salt.utils.args -import salt.utils.cache import salt.utils.crypt import salt.utils.data import salt.utils.dictupdate +import salt.utils.master import salt.utils.url from salt.exceptions import SaltClientError from salt.template import compile_template @@ -69,8 +70,7 @@ def get_pillar( ptype = {"remote": RemotePillar, "local": Pillar}.get(file_client, Pillar) # If local pillar and we're caching, run through the cache system first - log.debug("Determining pillar cache") - if opts["pillar_cache"]: + if opts.get("pillar_cache") or opts.get("minion_data_cache"): log.debug("get_pillar using
pillar cache with ext: %s", ext) return PillarCache( opts, @@ -393,144 +393,6 @@ def __del__(self): # pylint: enable=W1701 -class PillarCache: - """ - Return a cached pillar if it exists, otherwise cache it. - - Pillar caches are structed in two diminensions: minion_id with a dict of - saltenvs. Each saltenv contains a pillar dict - - Example data structure: - - ``` - {'minion_1': - {'base': {'pilar_key_1' 'pillar_val_1'} - } - """ - - # TODO ABC? - def __init__( - self, - opts, - grains, - minion_id, - saltenv, - ext=None, - functions=None, - pillar_override=None, - pillarenv=None, - extra_minion_data=None, - clean_cache=False, - ): - # Yes, we need all of these because we need to route to the Pillar object - # if we have no cache. This is another refactor target. - - # Go ahead and assign these because they may be needed later - self.opts = opts - self.grains = grains - self.minion_id = minion_id - self.ext = ext - self.functions = functions - self.pillar_override = pillar_override - self.pillarenv = pillarenv - self.clean_cache = clean_cache - self.extra_minion_data = extra_minion_data - - if saltenv is None: - self.saltenv = "base" - else: - self.saltenv = saltenv - - # Determine caching backend - self.cache = salt.utils.cache.CacheFactory.factory( - self.opts["pillar_cache_backend"], - self.opts["pillar_cache_ttl"], - minion_cache_path=self._minion_cache_path(minion_id), - ) - - def _minion_cache_path(self, minion_id): - """ - Return the path to the cache file for the minion. - - Used only for disk-based backends - """ - return os.path.join(self.opts["cachedir"], "pillar_cache", minion_id) - - def fetch_pillar(self): - """ - In the event of a cache miss, we need to incur the overhead of caching - a new pillar. - """ - log.debug("Pillar cache getting external pillar with ext: %s", self.ext) - fresh_pillar = Pillar( - self.opts, - self.grains, - self.minion_id, - self.saltenv, - ext=self.ext, - functions=self.functions, - pillar_override=None, - pillarenv=self.pillarenv, - extra_minion_data=self.extra_minion_data, - ) - return fresh_pillar.compile_pillar() - - def clear_pillar(self): - """ - Clear the cache - """ - self.cache.clear() - - return True - - def compile_pillar(self, *args, **kwargs): # Will likely just be pillar_dirs - if self.clean_cache: - self.clear_pillar() - log.debug( - "Scanning pillar cache for information about minion %s and pillarenv %s", - self.minion_id, - self.pillarenv, - ) - if self.opts["pillar_cache_backend"] == "memory": - cache_dict = self.cache - else: - cache_dict = self.cache._dict - - log.debug("Scanning cache: %s", cache_dict) - # Check the cache! - if self.minion_id in self.cache: # Keyed by minion_id - # TODO Compare grains, etc? - if self.pillarenv in self.cache[self.minion_id]: - # We have a cache hit! Send it back. - log.debug( - "Pillar cache hit for minion %s and pillarenv %s", - self.minion_id, - self.pillarenv, - ) - return self.cache[self.minion_id][self.pillarenv] - else: - # We found the minion but not the env. Store it. - fresh_pillar = self.fetch_pillar() - - minion_cache = self.cache[self.minion_id] - minion_cache[self.pillarenv] = fresh_pillar - self.cache[self.minion_id] = minion_cache - - log.debug( - "Pillar cache miss for pillarenv %s for minion %s", - self.pillarenv, - self.minion_id, - ) - return fresh_pillar - else: - # We haven't seen this minion yet in the cache. Store it. 
- fresh_pillar = self.fetch_pillar() - self.cache[self.minion_id] = {self.pillarenv: fresh_pillar} - log.debug("Pillar cache miss for minion %s", self.minion_id) - log.debug("Current pillar cache: %s", cache_dict) # FIXME hack! - return fresh_pillar - - class Pillar: """ Read over the pillar top files and render the pillar data @@ -1398,3 +1260,100 @@ class AsyncPillar(Pillar): def compile_pillar(self, ext=True): ret = super().compile_pillar(ext=ext) raise tornado.gen.Return(ret) + + +class PillarCache(Pillar): + """ + Return a cached pillar if it exists, otherwise cache it. + + Pillar caches are structured in two dimensions: minion_id with a dict of + saltenvs. Each saltenv contains a pillar dict. + + Example data structure: + + .. code-block:: python + + {'minion_1': + {'base': {'pillar_key_1': 'pillar_val_1'}} + } + """ + + def __init__( + self, + *args, + clean_cache=False, + pillar_override=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.clean_cache = clean_cache + self.pillar_override = pillar_override or {} + self.cache = salt.cache.factory( + self.opts, + driver=self.opts["pillar.cache_driver"], + expires=self.opts["pillar_cache_ttl"], + ) + + @property + def pillar_key(self): + if not self.opts["pillarenv"]: + return self.minion_id + else: + return f"{self.minion_id}:{self.opts['pillarenv']}" + + def cached_pillar(self): + """ + Return the cached pillar if it exists, or None + """ + return self.cache.fetch("pillar", self.pillar_key) + + def clear_pillar(self): + """ + Clear the pillar cache, if it exists + """ + return self.cache.flush("pillar", self.pillar_key) + + def compile_pillar(self, *args, **kwargs): # Will likely just be pillar_dirs + # matching to consume the same dataset + if self.clean_cache: + self.clear_pillar() + log.debug( + "Scanning pillar cache for information about minion %s and pillarenv %s", + self.minion_id, + self.opts["pillarenv"], + ) + + # if MDC is on, but not pillar cache, we never read the cache, only write to it + if self.opts["minion_data_cache"] and not self.opts["pillar_cache"]: + pillar_data = {} + else: + # Check the cache! + pillar_data = self.cached_pillar() + + if pillar_data: + log.debug( + "Pillar cache hit for minion %s and pillarenv %s", + self.minion_id, + self.opts["pillarenv"], + ) + else: + # Cache miss: compile a fresh pillar and store it. + log.debug( + "Pillar cache miss for pillarenv %s for minion %s", + self.opts["pillarenv"], + self.minion_id, + ) + pillar_data = super().compile_pillar(*args, **kwargs) + + self.cache.store("pillar", self.pillar_key, pillar_data) + + # we don't want the pillar_override baked into the cached compile_pillar from above + if self.pillar_override: + pillar_data = merge( + pillar_data, + self.pillar_override, + self.merge_strategy, + self.opts.get("renderer", "yaml"), + self.opts.get("pillar_merge_lists", False), + ) + + return pillar_data diff --git a/salt/returners/sqlalchemy.py b/salt/returners/sqlalchemy.py new file mode 100644 index 000000000000..f11d3854ef78 --- /dev/null +++ b/salt/returners/sqlalchemy.py @@ -0,0 +1,314 @@ +""" +Returner plugin for SQLAlchemy.
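+
+A sketch of how this returner might be wired up in the master config (the DSN
+and values below are illustrative, not defaults; option names are taken from
+this module and salt.sqlalchemy):
+
+.. code-block:: yaml
+
+    master_job_cache: sqlalchemy
+    event_return: sqlalchemy
+    sqlalchemy.dsn: postgresql+psycopg://salt:secret@db.example.com/salt
+    returner.sqlalchemy.max_retries: 3
+    returner.sqlalchemy.retry_delay: 1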
+""" + +import functools +import logging +import time +from collections.abc import Callable +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +import salt.exceptions +import salt.sqlalchemy +import salt.utils.jid +import salt.utils.job +from salt.sqlalchemy import model_for + +try: + import sqlalchemy.exc + from sqlalchemy import BigInteger, cast, delete, func, insert, literal, select +except ImportError: + pass + +if TYPE_CHECKING: + __opts__ = {} + __context__ = {} + __salt__: dict[str, Callable] + + +log = logging.getLogger(__name__) + +__virtualname__ = "sqlalchemy" + + +def __virtual__(): + """ + Ensure that SQLAlchemy ORM is configured and ready. + """ + if not salt.sqlalchemy.orm_configured(): + salt.sqlalchemy.configure_orm(__opts__) + return __virtualname__ + + +def retry_on_failure(f): + """ + Simple decorator to retry on OperationalError/InterfaceError + """ + + @functools.wraps(f) + def wrapper(*args, **kwargs): + """ + Wrapper function that implements retry logic for database operations. + """ + tries = __opts__["returner.sqlalchemy.max_retries"] + for _ in range(0, tries): + try: + return f(*args, **kwargs) + except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.InterfaceError): + time.sleep(__opts__["returner.sqlalchemy.retry_delay"]) + + return wrapper + + +@retry_on_failure +def returner(ret): + """ + Return data to returns in database + """ + Returns = model_for("Returns", engine_name=__context__.get("engine_name")) + with salt.sqlalchemy.Session(__context__.get("engine_name")) as session: + record = { + "cluster": __opts__["cluster_id"], + "fun": ret["fun"], + "jid": ret["jid"], + "id": ret["id"], + "success": ret.get("success", False), + "ret": ret, + } + + session.execute(insert(Returns), [record]) + session.commit() + + +@retry_on_failure +def event_return(evts, tries=None): + """ + Return event to database server + + Requires that configuration be enabled via 'event_return' + option in master config. 
+ """ + Events = model_for("Events", engine_name=__context__.get("engine_name")) + with salt.sqlalchemy.Session(__context__.get("engine_name")) as session: + records = [] + for evt in evts: + record = { + "tag": evt.get("tag", ""), + "data": evt.get("data", ""), + "cluster": __opts__["cluster_id"], + "master_id": __opts__["id"], + } + + try: + record["created_at"] = evt["data"]["_stamp"] + except (KeyError, TypeError): + pass + + records.append(record) + + session.execute(insert(Events), records) + session.commit() + + +@retry_on_failure +def save_load(jid, load, minions=None): + """ + Save the load to the specified jid id + """ + if not minions: + minions = [] + + Jids = model_for("Jids", engine_name=__context__.get("engine_name")) + with salt.sqlalchemy.Session(__context__.get("engine_name")) as session: + record = { + "jid": jid, + "load": load, + "minions": minions, + "cluster": __opts__["cluster_id"], + } + + session.execute(insert(Jids), [record]) + session.commit() + + +def save_minions(jid, minions, syndic_id=None): # pylint: disable=unused-argument + """ + Included for API consistency + """ + + +def get_load(jid): + """ + Return the load data that marks a specified jid + """ + Jids = model_for("Jids", engine_name=__context__.get("engine_name")) + with salt.sqlalchemy.ROSession(__context__.get("engine_name")) as session: + # Use to_jsonb for jsonb conversion in Postgres + stmt = select(Jids).where(Jids.jid == str(jid)) + result = session.execute(stmt).first() + load = {} + if result: + jid = result[0] + load = jid.load + load["Minions"] = jid.minions or [] + session.commit() + return load + + +def get_jid(jid): + """ + Return the information returned when the specified job id was executed + """ + Returns = model_for("Returns", engine_name=__context__.get("engine_name")) + with salt.sqlalchemy.ROSession(__context__.get("engine_name")) as session: + stmt = select(Returns.id, Returns.ret).where(Returns.jid == str(jid)) + results = session.execute(stmt).all() + + ret = {} + for result in results: + ret[result.id] = result.ret + + session.commit() + + return ret + + +def get_fun(fun): + """ + Return a dict of the last function called for all minions + """ + # this could be done with a separate table, but why? + raise salt.exceptions.SaltException( + "This is too costly to run via database at the moment, left unimplemented" + ) + + +def get_jids(last=None): + """ + Return a list of all job ids + """ + # this could be done, but why would you ever? + raise salt.exceptions.SaltException("This is too costly to run, left unimplemented") + + +def get_minions(): + """ + Return a list of minions + """ + raise salt.exceptions.SaltException("Use salt.util.minions._all_minions instead") + + +def prep_jid( + nocache=False, passed_jid=None, retry_count=0 +): # pylint: disable=unused-argument + """ + Do any work necessary to prepare a JID, including sending a custom id + Using a recursive retry approach with advisory locks to avoid table contention + Locking guaruntees a global unique jid on postgresql and mysql. 
+ """ + # this will return false for "req" salt-call jid + if salt.utils.jid.is_jid(passed_jid): + return passed_jid + + # generate a candidate JID + jid = salt.utils.jid.gen_jid(__opts__) + + try: + with salt.sqlalchemy.Session(__context__.get("engine_name")) as session: + if session.bind.dialect.name == "postgresql": + jid_expr = func.to_char(func.clock_timestamp(), "YYYYMMDDHH24MISSUS") + + lock_expr = func.pg_try_advisory_xact_lock( + cast(func.abs(func.hashtext(jid_expr)), BigInteger) + ) + elif session.bind.dialect.name == "mysql": + # Build the DATE_FORMAT(NOW(3), '%Y%m%d%H%i%s%f') expression + jid_expr = func.DATE_FORMAT(func.NOW(3), literal("%Y%m%d%H%i%s%f")) + + # Apply GET_LOCK(jid_expr, 0) for non-blocking lock attempt + lock_expr = func.GET_LOCK(jid_expr, literal(0)) + elif session.bind.dialect.name == "sqlite": + # sqlite doesn't require locking + return jid + + else: + raise salt.exceptions.SaltException("Unrecognized dialect") + + stmt = select(lock_expr.label("locked"), jid_expr.label("jid")) + + result = session.execute(stmt).one() + locked, jid = result.locked, result.jid + + if locked: + # lock acquired, return the generated jid + if session.bind.dialect.name == "mysql": + # mysql needs a manual lock release + stmt = select(func.RELEASE_LOCK(literal(jid)).label("released")) + session.execute(stmt).one() + + session.commit() + return jid + else: + # lock contention, retry with a new jid + if retry_count < 5: + return prep_jid( + nocache=nocache, passed_jid=None, retry_count=retry_count + 1 + ) + else: + log.warning( + "Maximum retry attempts reached for prep_jid lock acquisition" + ) + except Exception: # pylint: disable=broad-except + log.exception( + "Something went wrong trying to prep jid (unable to acquire lock?), falling back to salt.utils.jid.gen_jid()" + ) + + # we failed to get a unique jid, just return one + # without asserting global uniqueness + return salt.utils.jid.gen_jid(__opts__) + + +def clean_old_jobs(): + """ + Called in the master's event loop every loop_interval. Removes data older + than the configured keep_jobs_seconds setting from the database tables. + When configured, uses partitioning for efficient data lifecycle management. + + Returns: + bool: True if cleaning was performed, None if no action was taken + """ + keep_jobs_seconds = int(salt.utils.job.get_keep_jobs_seconds(__opts__)) + if keep_jobs_seconds > 0: + if __opts__.get("archive_jobs", False): + raise salt.exceptions.SaltException( + "This is unimplemented. 
Use pg_partman or other native partition handling" + ) + else: + Jids, Returns, Events = model_for( + "Jids", "Returns", "Events", engine_name=__context__.get("engine_name") + ) + ttl = datetime.now(timezone.utc) - timedelta(seconds=keep_jobs_seconds) + + with salt.sqlalchemy.Session(__context__.get("engine_name")) as session: + stmt = delete(Jids).where( + Jids.created_at < ttl, Jids.cluster == __opts__["cluster_id"] + ) + session.execute(stmt) + session.commit() + + with salt.sqlalchemy.Session(__context__.get("engine_name")) as session: + stmt = delete(Returns).where( + Returns.created_at < ttl, Returns.cluster == __opts__["cluster_id"] + ) + session.execute(stmt) + session.commit() + + with salt.sqlalchemy.Session(__context__.get("engine_name")) as session: + stmt = delete(Events).where( + Events.created_at < ttl, Events.cluster == __opts__["cluster_id"] + ) + session.execute(stmt) + session.commit() + + return True diff --git a/salt/roster/cache.py b/salt/roster/cache.py index 8545418cc23b..d135757cb231 100644 --- a/salt/roster/cache.py +++ b/salt/roster/cache.py @@ -177,7 +177,7 @@ def _load_minion(minion_id, cache): 6: sorted(ipaddress.IPv6Address(addr) for addr in grains.get("ipv6", [])), } - mine = cache.fetch(f"minions/{minion_id}", "mine") + mine = cache.fetch("mine", minion_id) return grains, pillar, addrs, mine diff --git a/salt/runners/cache.py b/salt/runners/cache.py index 89309af690f1..217f1a3b5b49 100644 --- a/salt/runners/cache.py +++ b/salt/runners/cache.py @@ -9,7 +9,6 @@ import salt.cache import salt.config import salt.fileserver.gitfs -import salt.payload import salt.pillar.git_pillar import salt.runners.winrepo import salt.utils.args @@ -466,3 +465,58 @@ def flush(bank, key=None, cachedir=None): except TypeError: cache = salt.cache.Cache(__opts__) return cache.flush(bank, key) + + +def migrate(target=None, bank=None): + """ + Migrate cache contents from the current configured/running cache to another. + + .. note:: This will NOT migrate ttl values, if set in the source cache. + + target + The configured, but not enabled (via cache/master_job_cache/cache_driver config) cache backend + + bank + If you only want to migrate a specific bank (instead of all), the name of the bank(s), csv delimited. + + CLI Examples: + + .. 
code-block:: bash + + salt-run cache.migrate target=sqlalchemy + salt-run cache.migrate target=redis bank=keys,denied_keys,master_keys + """ + # if specific drivers are not set for these, Cache will just fall back to base cache + key_cache = salt.cache.Cache(__opts__, driver=__opts__["keys.cache_driver"]) + token_cache = salt.cache.Cache( + __opts__, driver=__opts__["eauth_tokens.cache_driver"] + ) + mdc_cache = salt.cache.Cache(__opts__, driver=__opts__["pillar.cache_driver"]) + base_cache = salt.cache.Cache(__opts__) + dst_cache = salt.cache.Cache(__opts__, driver=target) + + # unfortunately there is no 'list all cache banks' in the cache api + banks = { + "keys": key_cache, + "master_keys": key_cache, + "denied_keys": key_cache, + "tokens": token_cache, + "pillar": mdc_cache, + "grains": base_cache, + "mine": base_cache, + } + + if bank: + bank = bank.split(",") + else: + bank = banks.keys() + + for _bank in bank: + cache = banks[_bank] + keys = cache.list(_bank) + log.info("bank %s: migrating %s keys", _bank, len(keys)) + for key in keys: + value = cache.fetch(_bank, key) + dst_cache.store(_bank, key, value) + + return True diff --git a/salt/runners/pillar.py b/salt/runners/pillar.py index 1968d6be610b..27b8132a2931 100644 --- a/salt/runners/pillar.py +++ b/salt/runners/pillar.py @@ -96,13 +96,14 @@ def show_pillar(minion="*", **kwargs): pillar = salt.pillar.Pillar(__opts__, grains, id_, saltenv, pillarenv=pillarenv) - compiled_pillar = pillar.compile_pillar() - return compiled_pillar + return pillar.compile_pillar() def clear_pillar_cache(minion="*", **kwargs): """ - Clears the cached values when using pillar_cache + Clears the cached values when using pillar_cache. + Returns True on success. + Returns False if pillar_cache or minion_data_cache are not enabled. .. 
versionadded:: 3003 @@ -116,7 +117,7 @@ """ - if not __opts__.get("pillar_cache"): + if not (__opts__.get("pillar_cache") or __opts__.get("minion_data_cache")): log.info("The pillar_cache is set to False or not enabled.") return False @@ -126,7 +127,6 @@ def clear_pillar_cache(minion="*", **kwargs): pillarenv = kwargs.pop("pillarenv", None) saltenv = kwargs.pop("saltenv", "base") - pillar_cache = {} for tgt in ret.get("minions", []): id_, grains, _ = salt.utils.minions.get_minion_data(tgt, __opts__) @@ -141,15 +141,7 @@ def clear_pillar_cache(minion="*", **kwargs): ) pillar.clear_pillar() - if __opts__.get("pillar_cache_backend") == "memory": - _pillar_cache = pillar.cache - else: - _pillar_cache = pillar.cache._dict - - if tgt in _pillar_cache and _pillar_cache[tgt]: - pillar_cache[tgt] = _pillar_cache.get(tgt).get(pillarenv) - - return pillar_cache + return True def show_pillar_cache(minion="*", **kwargs): @@ -168,7 +160,7 @@ def show_pillar_cache(minion="*", **kwargs): """ - if not __opts__.get("pillar_cache"): + if not (__opts__.get("pillar_cache") or __opts__.get("minion_data_cache")): log.info("The pillar_cache is set to False or not enabled.") return False @@ -179,25 +171,14 @@ def show_pillar_cache(minion="*", **kwargs): saltenv = kwargs.pop("saltenv", "base") pillar_cache = {} + for tgt in ret.get("minions", []): id_, grains, _ = salt.utils.minions.get_minion_data(tgt, __opts__) - - for key in kwargs: - grains[key] = kwargs[key] - - if grains is None: - grains = {"fqdn": minion} - + # we could use the pillar from above, but it's not pillarenv aware + pillar = salt.pillar.PillarCache( + __opts__, grains, id_, saltenv, pillarenv=pillarenv - ) - - if __opts__.get("pillar_cache_backend") == "memory": - _pillar_cache = pillar.cache - else: - _pillar_cache = pillar.cache._dict - - if tgt in _pillar_cache and _pillar_cache[tgt]: - pillar_cache[tgt] = _pillar_cache[tgt].get(pillarenv) + ).cached_pillar() + if pillar: + pillar_cache[tgt] = pillar return pillar_cache diff --git a/salt/runners/sqlalchemy.py b/salt/runners/sqlalchemy.py new file mode 100644 index 000000000000..3109284beb88 --- /dev/null +++ b/salt/runners/sqlalchemy.py @@ -0,0 +1,51 @@ +""" +Salt runner for managing SQLAlchemy database schema. + +Provides runners to create or drop all tables using the current SQLAlchemy configuration. +""" + +import logging + +import salt.sqlalchemy + +log = logging.getLogger(__name__) + + +def __virtual__(): + """ + Only load if the SQLAlchemy ORM can be configured. + Returns True if successful, False otherwise. + """ + try: + salt.sqlalchemy.configure_orm(__opts__) + return True + except Exception: # pylint: disable=broad-exception-caught + return False + + +def drop_all(target_engine=None): + """ + Drop all tables in the configured SQLAlchemy database. + + Args: + target_engine (str, optional): Name of the engine to use. If None, the default engine is used. + """ + with salt.sqlalchemy.Session(target_engine) as session: + salt.sqlalchemy.drop_all(target_engine) + session.commit() + + return True + + +def create_all(target_engine=None): + """ + Create all tables in the configured SQLAlchemy database. + + Args: + target_engine (str, optional): Name of the engine to use. If None, the default engine is used.
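+
+    CLI Example (illustrative; uses the default engine unless one is named):
+
+    .. code-block:: bash
+
+        salt-run sqlalchemy.create_all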
+ """ + with salt.sqlalchemy.Session(target_engine) as session: + salt.sqlalchemy.create_all() + session.commit() + + return True diff --git a/salt/sqlalchemy/__init__.py b/salt/sqlalchemy/__init__.py new file mode 100644 index 000000000000..cad765d42d86 --- /dev/null +++ b/salt/sqlalchemy/__init__.py @@ -0,0 +1,448 @@ +import base64 +import json +import logging +import os +import time + +try: + import sqlalchemy + import sqlalchemy.engine.url + from sqlalchemy import event, exc + from sqlalchemy.orm import scoped_session, sessionmaker + from sqlalchemy.pool import NullPool + + HAS_SQLA = True +except ImportError: + HAS_SQLA = False + +import salt.config +import salt.exceptions + +log = logging.getLogger(__name__) + +SQLA_DEFAULT_OPTS = { + k[len("sqlalchemy.") :]: v + for (k, v) in salt.config.DEFAULT_MASTER_OPTS.items() + if k.startswith("sqlalchemy") +} + +ENGINE_REGISTRY = {} + + +def orm_configured(): + """ + Check if the ORM is configured. + + Returns: + bool: True if the ORM has been configured, False otherwise + """ + return bool(ENGINE_REGISTRY) + + +def _make_engine(opts, prefix=None): + """ + Create and configure a SQLAlchemy engine instance. + + Creates a SQLAlchemy engine with appropriate connection settings, + serialization functions, and event listeners for monitoring performance. + Supports both direct DSN strings and individual connection parameters. + + Args: + opts (dict): Configuration options dictionary with connection parameters + prefix (str, optional): Prefix for connection parameters, used for read-only connections + + Returns: + Engine: Configured SQLAlchemy engine instance + + Raises: + SaltConfigurationError: When required configuration parameters are missing + """ + if not prefix: + prefix = "" + + url = None + + if opts.get(f"{prefix}dsn"): + url = opts[f"{prefix}dsn"] + + _opts = {} + + for kw in ["drivername", "host", "username", "password", "database", "port"]: + if opts.get(f"{prefix}{kw}"): + _opts[kw] = opts[f"{prefix}{kw}"] + elif prefix and kw == "drivername" and "drivername" in opts: + # if we are ro, just take the non _ro drivername if unset. it should be the same + _opts[kw] = opts["drivername"] + elif not url: + raise salt.exceptions.SaltConfigurationError( + f"Missing required config opts parameter 'sqlalchemy.{prefix}{kw}'" + ) + + for kw in ["sslmode", "sslcert", "sslkey", "sslrootcert", "sslcrl"]: + if f"sqlalchemy.{prefix}{kw}" in opts: + _opts.setdefault("query", {})[kw] = opts[f"{prefix}{kw}"] + + if url and _opts: + raise salt.exceptions.SaltConfigurationError( + "Can define dsn, or individual attributes, but not both" + ) + elif not url: + _opts.setdefault("query", {}) + url = sqlalchemy.engine.url.URL(**_opts) + + # TODO: other pool types could be useful if we were to use threading + engine_opts = {} + if opts.get(f"{prefix}engine_opts"): + try: + engine_opts = json.loads(opts[f"{prefix}engine_opts"]) + except json.JSONDecodeError: + log.error( + f"Failed to deserialize {prefix}engine_opts value (%s). 
Did you make a typo?", + opts[f"{prefix}engine_opts"], + ) + + if opts.get(f"{prefix}disable_connection_pool"): + engine_opts["poolclass"] = NullPool + + _engine = sqlalchemy.create_engine( + url, + json_serializer=_serialize, + json_deserializer=_deserialize, + **engine_opts, + ) + + if opts.get(f"{prefix}schema"): + _engine = _engine.execution_options( + schema_translate_map={None: opts[f"{prefix}schema"]} + ) + + if opts.get("echo") or os.environ.get("SQLALCHEMY_ECHO"): + logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + + @event.listens_for(_engine, "do_connect") + def do_connect(dialect, connection_record, cargs, cparams): + connection_record.info["pid"] = os.getpid() + connection_record.info["connect_start_time"] = time.time() + + @event.listens_for(_engine, "connect") + def connect(dbapi_connection, connection_record): + total = time.time() - connection_record.info["connect_start_time"] + if total >= opts["slow_connect_threshold"]: + log.error( + "%s Slow database connect exceeded threshold (%s s); total time: %f s", + _engine, + opts["slow_query_threshold"], + total, + ) + + if _engine.dialect.name == "sqlite": + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA journal_mode=WAL;") + cursor.close() + + @event.listens_for(_engine, "checkout") + def checkout(dbapi_connection, connection_record, connection_proxy): + pid = os.getpid() + if connection_record.info["pid"] != pid: + connection_record.dbapi_connection = connection_proxy.dbapi_connection = ( + None + ) + raise exc.DisconnectionError( + "Connection record belongs to pid %s, " + "attempting to check out in pid %s" + % (connection_record.info["pid"], pid) + ) + + @event.listens_for(_engine, "before_cursor_execute") + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + conn.info.setdefault("query_start_time", []).append(time.time()) + + @event.listens_for(_engine, "after_cursor_execute") + def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): + total = time.time() - conn.info["query_start_time"].pop(-1) + if ( + total >= opts["slow_query_threshold"] + and "REFRESH MATERIALIZED" not in statement.upper() + ): + log.error( + "%s Slow query exceeded threshold (%s s); total time: %f s,\nStatement: %s", + _engine, + opts["slow_query_threshold"], + total, + statement, + ) + + return _engine + + +def configure_orm(opts): + """ + Configure the SQLAlchemy ORM with the provided options. + + Initializes SQLAlchemy engine(s) using the provided Salt configuration. + Creates both read-write and read-only connections when configured. + Registers models for the created engines. 
+ + Args: + opts (dict): Salt configuration options dictionary + + Raises: + SaltException: When SQLAlchemy dependency is missing + SaltConfigurationError: When required configuration is missing + """ + if not HAS_SQLA: + raise salt.exceptions.SaltException("Missing sqlalchemy dependency") + + # this only needs to run once + if ENGINE_REGISTRY: + return + + # many engines can be defined, in addition to the default + engine_configs = {} + for key in opts.keys(): + if not key.startswith("sqlalchemy."): + continue + + _, _key = key.split(".", maxsplit=1) + + if _key.count(".") == 0 or _key.startswith("ddl") or _key.startswith("partman"): + name = "default" + engine_configs.setdefault(name, {})[_key] = opts[key] + else: + name, _key = _key.split(".", maxsplit=1) + engine_configs.setdefault(name, {})[_key] = opts[key] + + # we let certain configs in the 'default' namespace apply to all engines + for global_default in ["echo", "slow_query_threshold"]: + if global_default in engine_configs.get("default", {}): + SQLA_DEFAULT_OPTS[global_default] = engine_configs["default"][ + global_default + ] + + if not engine_configs: + raise salt.exceptions.SaltConfigurationError( + "Expected sqlalchemy configuration but got none." + ) + + for name, defined_config in engine_configs.items(): + engine_config = {**SQLA_DEFAULT_OPTS, **defined_config} + + _engine = _make_engine(engine_config, prefix=None) + _Session = scoped_session( + sessionmaker( + autocommit=False, + bind=_engine, + ) + ) + + # if configured for a readonly pool separate from the readwrite, configure it + # else just alias it to the main pool + if engine_config["ro_dsn"] or engine_config["ro_host"]: + _ro_engine = _make_engine(engine_config, prefix="ro_") + _ROSession = scoped_session( + sessionmaker( + autocommit=False, + bind=_ro_engine, + ) + ) + else: + _ro_engine = _engine + _ROSession = _Session + + ENGINE_REGISTRY[name] = {} + ENGINE_REGISTRY[name]["engine"] = _engine + ENGINE_REGISTRY[name]["session"] = _Session + ENGINE_REGISTRY[name]["ro_engine"] = _ro_engine + ENGINE_REGISTRY[name]["ro_session"] = _ROSession + + # the sqla models are behavior dependent on opts config + # note this must be a late import ; salt.sqlalchemy must be importable + # even if sqlalchemy isn't installed + from salt.sqlalchemy.models import populate_model_registry + + populate_model_registry(engine_config, name, _engine) + + +def dispose_orm(): + """ + Clean up and dispose of all SQLAlchemy engine and session resources. + + Properly disposes all engine connection pools and removes session instances + from the registry to prevent resource leaks. + + Raises: + SaltException: When SQLAlchemy dependency is missing + """ + if not HAS_SQLA: + raise salt.exceptions.SaltException("Missing sqlalchemy dependency") + + if not ENGINE_REGISTRY: + return + + for engine in list(ENGINE_REGISTRY): + log.debug("Disposing DB connection pool for %s", engine) + + ENGINE_REGISTRY[engine]["engine"].dispose() + ENGINE_REGISTRY[engine]["session"].remove() + + if "ro_engine" in ENGINE_REGISTRY[engine]: + ENGINE_REGISTRY[engine]["ro_engine"].dispose() + if "ro_session" in ENGINE_REGISTRY[engine]: + ENGINE_REGISTRY[engine]["ro_session"].remove() + + ENGINE_REGISTRY.pop(engine) + + +def reconfigure_orm(opts): + """ + Reconfigure the SQLAlchemy ORM with new options. + + Disposes of existing engine resources and then reconfigures the ORM + with the provided options. This is useful for refreshing connections + or updating configuration. 
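+
+    Example (mirrors how the functional tests refresh the ORM):
+
+    .. code-block:: python
+
+        salt.sqlalchemy.reconfigure_orm(opts)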
+ + Args: + opts (dict): Salt configuration options dictionary + """ + dispose_orm() + configure_orm(opts) + + +def Session(name=None): + """ + Get a SQLAlchemy session for database operations. + + Creates and returns a new session from the appropriate session factory + in the engine registry. Used for read-write operations on the database. + + Args: + name (str, optional): Name of the engine to use. Defaults to "default". + + Returns: + Session: A configured SQLAlchemy session object + + Raises: + SaltInvocationError: When the requested engine name is not configured + """ + if not name: + name = "default" + try: + return ENGINE_REGISTRY[name]["session"]() + except KeyError: + raise salt.exceptions.SaltInvocationError( + f"ORM not configured for '{name}' yet. Did you forget to call salt.sqlalchemy.configure_orm?" + ) + + +def ROSession(name=None): + """ + Get a read-only SQLAlchemy session for database operations. + + Creates and returns a new session from the read-only session factory + in the engine registry. Falls back to standard session if read-only + session is not available. + + Args: + name (str, optional): Name of the engine to use. Defaults to "default". + + Returns: + Session: A configured read-only SQLAlchemy session object + + Raises: + SaltInvocationError: When the requested engine name is not configured + """ + if not name: + name = "default" + try: + try: + return ENGINE_REGISTRY[name]["ro_session"]() + except KeyError: + return ENGINE_REGISTRY[name]["session"]() + except KeyError: + raise salt.exceptions.SaltInvocationError( + f"ORM not configured for '{name}' yet. Did you forget to call salt.sqlalchemy.configure_orm?" + ) + + +def _serialize(data): + """ + Serialize and base64 encode the data + Also remove NUL bytes because they make postgres jsonb unhappy + """ + # handle bytes input + if isinstance(data, bytes): + data = base64.b64encode(data).decode("ascii") + data = {"_base64": data} + + encoded = json.dumps(data).replace("\\u0000", "") + + return encoded + + +def _deserialize(data): + """ + Deserialize and base64 decode the data + """ + inflated = json.loads(data) + + if isinstance(inflated, dict) and "_base64" in inflated: + inflated = base64.b64decode(inflated["_base64"]) + + return inflated + + +def model_for(*args, **kwargs): + """ + Get SQLAlchemy models by name from the registry. + + Acts as a pass-through to the model getter in the models module. + + Args: + *args: Model names to retrieve + **kwargs: Additional arguments to pass to the model getter + + Returns: + Model or tuple of models: The requested SQLAlchemy model(s) + + Raises: + SaltInvocationError: When SQLAlchemy is not installed + """ + # pass through to the model getter + # it is important this stay in this file as importing + # salt.sqlalchemy.models requires sqlalchemy be installed + if HAS_SQLA: + from salt.sqlalchemy.models import model_for + + return model_for(*args, **kwargs) + else: + raise salt.exceptions.SaltInvocationError("SQLAlchemy must be installed") + + +def drop_all(engine_name=None): + """ + Drop all tables in the database. + + Removes all tables defined in the SQLAlchemy metadata for the specified engine. + + Args: + engine_name (str, optional): Name of the engine to use. Defaults to "default". + """ + with Session(engine_name) as session: + Base = model_for("Base", engine_name=engine_name) + Base.metadata.drop_all(session.get_bind()) + + +def create_all(engine_name=None): + """ + Create all tables in the database. 
+ + Creates all tables defined in the SQLAlchemy metadata for the specified engine. + + Args: + engine_name (str, optional): Name of the engine to use. Defaults to "default". + """ + with Session(engine_name) as session: + Base = model_for("Base", engine_name=engine_name) + Base.metadata.create_all(session.get_bind()) diff --git a/salt/sqlalchemy/models.py b/salt/sqlalchemy/models.py new file mode 100644 index 000000000000..b903c645bfbc --- /dev/null +++ b/salt/sqlalchemy/models.py @@ -0,0 +1,406 @@ +import logging +from datetime import datetime, timezone + +try: + from sqlalchemy import DDL, JSON, DateTime, Index, String, Text, event + from sqlalchemy.dialects.mysql import JSON as MySQL_JSON + from sqlalchemy.dialects.postgresql import JSONB + from sqlalchemy.engine import Dialect + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.orm import DeclarativeBase, Mapped, configure_mappers, mapped_column + from sqlalchemy.sql.expression import FunctionElement + from sqlalchemy.types import DateTime as t_DateTime + from sqlalchemy.types import TypeDecorator +except ImportError: + # stubs so below passively compiles even if sqlalchemy isn't installed + # all consuming code is gated by salt.sqlalchemy.configure_orm + TypeDecorator = object + FunctionElement = object + Dialect = None + + def DateTime(timezone=None): + pass + + t_DateTime = DateTime + + def compiles(cls, dialect=None): + def decorator(fn): + return fn + + return decorator + + +import salt.exceptions + +log = logging.getLogger(__name__) + +REGISTRY = {} + + +class DateTimeUTC(TypeDecorator): # type: ignore + """Timezone Aware DateTimeUTC. + + Ensure UTC is stored in the database and that TZ aware dates are returned for all dialects. + Accepts datetime or string (ISO8601 or dateutil formats). + """ + + impl = DateTime(timezone=True) + cache_ok = True + + @property + def python_type(self) -> type[datetime]: + return datetime + + def process_bind_param( + self, value: datetime | str | None, dialect: Dialect + ) -> datetime | None: + if value is None: + return value + if isinstance(value, str): + value = datetime.fromisoformat(value) + if not isinstance(value, datetime): + raise TypeError("created_at must be a datetime or ISO 8601 string") + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + def process_literal_param(self, value: datetime | None, dialect: Dialect) -> str: + return super().process_literal_param(value, dialect) + + def process_result_value( + self, value: datetime | None, dialect: Dialect + ) -> datetime | None: + if value is None: + return value + if isinstance(value, datetime) and value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value + + +def model_for(*names, engine_name=None): + if not engine_name: + engine_name = "default" + + models = [] + for name in names: + try: + models.append(REGISTRY[engine_name][name]) + except KeyError: + raise salt.exceptions.SaltInvocationError( + f"Unrecognized model name {name}. Did you forget to call salt.sqlalchemy.configure_orm?" 
+ ) + + if len(names) == 1: + return models[0] + else: + return tuple(models) + + +def get_json_type(engine): + if engine.dialect.name == "postgresql": + return JSONB + elif engine.dialect.name == "mysql": + return MySQL_JSON + else: + return JSON + + +def get_text_type(engine): + if engine.dialect.name == "postgresql": + return Text + else: + return String(255) + + +class utcnow(FunctionElement): + type = t_DateTime() + + +@compiles(utcnow, "postgresql") +def postgresql_utcnow(element, compiler, **kw): + """ + SQLAlchemy compiler function for PostgreSQL dialect. + + Generates the PostgreSQL-specific SQL expression for UTC timestamp. + """ + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + + +@compiles(utcnow, "mssql") +def mssql_utcnow(element, compiler, **kw): + """ + SQLAlchemy compiler function for Microsoft SQL Server dialect. + + Generates the MSSQL-specific SQL expression for UTC timestamp. + + Args: + element: The SQLAlchemy FunctionElement instance + compiler: The SQLAlchemy statement compiler + **kw: Additional keyword arguments passed to the compiler + + Returns: + str: MSSQL-specific UTC timestamp function call + """ + return "GETUTCDATE()" + + +@compiles(utcnow, "mysql") +def mysql_utcnow(element, compiler, **kw): + """ + SQLAlchemy compiler function for MySQL dialect. + + Generates the MySQL-specific SQL expression for UTC timestamp. + + Args: + element: The SQLAlchemy FunctionElement instance + compiler: The SQLAlchemy statement compiler + **kw: Additional keyword arguments passed to the compiler + + Returns: + str: MySQL-specific UTC timestamp function call + """ + return "UTC_TIMESTAMP()" + + +@compiles(utcnow, "sqlite") +def sqlite_utcnow(element, compiler, **kw): + """ + SQLAlchemy compiler function for SQLite dialect. + + Generates the SQLite-specific SQL expression for UTC timestamp. + + Args: + element: The SQLAlchemy FunctionElement instance + compiler: The SQLAlchemy statement compiler + **kw: Additional keyword arguments passed to the compiler + + Returns: + str: SQLite-specific UTC datetime function call + """ + return "datetime('now')" + + +def populate_model_registry(opts, name, engine): + """ + Creates and registers SQLAlchemy models in the global registry. + + Defines ORM models for various Salt data structures and registers them + in the REGISTRY dictionary for later retrieval. Configures PostgreSQL + table partitioning via pg_partman when applicable.
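+
+    Registered models are later looked up by name, for example (illustrative):
+
+    .. code-block:: python
+
+        Returns, Events = model_for("Returns", "Events", engine_name="default")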
+ + Args: + opts (dict): Salt configuration options dictionary + name (str): Registry name, defaults to "default" if not provided + engine: SQLAlchemy engine instance for database connections + + Returns: + None: Models are registered in the global REGISTRY dictionary + """ + if not name: + name = "default" + + is_postgres = engine.dialect.name == "postgresql" + + class Base(DeclarativeBase): + pass + + class PartmanBase(Base): + __abstract__ = True + + @classmethod + def __declare_last__(cls): + after_create_ddl = DDL( + f""" + SELECT {opts["partman.schema"]}.create_parent( + p_parent_table := '%(schema)s.{cls.__tablename__}', + p_control := 'created_at', + p_interval := '{opts["partman.interval"]}', + p_type := 'native', + p_constraint_cols := '{{jid}}', + p_jobmon := {str(opts["partman.jobmon"]).lower()} + ) + """ + ) + event.listen(cls.__table__, "after_create", after_create_ddl) + + if opts["partman.retention"]: + after_create_retention_ddl = DDL( + f""" + UPDATE {opts["partman.schema"]}.part_config + SET retention = INTERVAL '{opts["partman.retention"]}', retention_keep_table=false + WHERE parent_table = '%(schema)s.{cls.__tablename__}' + """ + ) + event.listen(cls.__table__, "after_create", after_create_retention_ddl) + + after_drop_part_config_ddl = DDL( + f""" + DELETE FROM {opts["partman.schema"]}.part_config where parent_table = '%(schema)s.{cls.__tablename__}' + """ + ) + event.listen(cls.__table__, "after_drop", after_drop_part_config_ddl) + + after_drop_template_ddl = DDL( + f""" + DROP TABLE IF EXISTS {opts["partman.schema"]}.template_{cls.__tablename__} + """ + ) + event.listen(cls.__table__, "after_drop", after_drop_template_ddl) + + # shortcircuit partman behavior if not turned on + if not opts["partman.enabled"] or not is_postgres: + PartmanBase = Base + + class Cache(Base): + __tablename__ = "cache" + table_args = [] + if is_postgres: + table_args.append( + Index( + None, + "data", + postgresql_using="gin", + postgresql_with={"FASTUPDATE": "OFF"}, + ) + ) + table_args.append({"postgresql_partition_by": "LIST (bank)"}) + __table_args__ = tuple(table_args) + + bank: Mapped[str] = mapped_column( + get_text_type(engine), primary_key=True, index=True + ) + key: Mapped[str] = mapped_column( + get_text_type(engine), primary_key=True, index=True + ) + data: Mapped[dict] = mapped_column(get_json_type(engine)) + cluster: Mapped[str] = mapped_column(get_text_type(engine), primary_key=True) + created_at: Mapped[datetime] = mapped_column( + DateTimeUTC(6), + server_default=utcnow(), + nullable=False, + ) + expires_at: Mapped[datetime | None] = mapped_column(DateTimeUTC) + + if is_postgres: + + @classmethod + def __declare_last__(cls): + ddl = DDL( + """ + CREATE TABLE IF NOT EXISTS %(table)s_default PARTITION OF %(table)s DEFAULT; + CREATE TABLE IF NOT EXISTS %(table)s_keys PARTITION OF %(table)s FOR VALUES IN ('keys', 'denied_keys', 'master_keys'); + CREATE TABLE IF NOT EXISTS %(table)s_grains PARTITION OF %(table)s FOR VALUES IN ('grains'); + CREATE TABLE IF NOT EXISTS %(table)s_pillar PARTITION OF %(table)s FOR VALUES IN ('pillar'); + CREATE TABLE IF NOT EXISTS %(table)s_tokens PARTITION OF %(table)s FOR VALUES IN ('tokens'); + CREATE TABLE IF NOT EXISTS %(table)s_mine PARTITION OF %(table)s FOR VALUES IN ('mine'); + """ + ) + event.listen(cls.__table__, "after_create", ddl) + + class Events(PartmanBase): + __tablename__ = "events" + __mapper_args__ = {"primary_key": ["created_at"]} + table_args = [] + if is_postgres: + table_args.append( + Index( + None, + "data", + 
postgresql_using="gin", + postgresql_with={"FASTUPDATE": "OFF"}, + ) + ) + + if opts["partman.enabled"]: + table_args.append({"postgresql_partition_by": "RANGE (created_at)"}) + __table_args__ = tuple(table_args) if table_args else () + + tag: Mapped[str] = mapped_column(get_text_type(engine), index=True) + data: Mapped[dict] = mapped_column(get_json_type(engine)) + master_id: Mapped[str] = mapped_column(get_text_type(engine)) + cluster: Mapped[str] = mapped_column(get_text_type(engine), nullable=True) + created_at: Mapped[datetime | None] = mapped_column( + "created_at", + DateTimeUTC, + server_default=utcnow(), + index=True, + nullable=False, + ) + + class Jids(PartmanBase): + __tablename__ = "jids" + __mapper_args__ = {"primary_key": ["created_at"]} + table_args = [] + if is_postgres: + table_args.append( + Index( + None, + "load", + postgresql_using="gin", + postgresql_with={"FASTUPDATE": "OFF"}, + ) + ) + + if opts["partman.enabled"]: + table_args.append( + { + "postgresql_partition_by": "RANGE (created_at)", + }, + ) + __table_args__ = tuple(table_args) if table_args else () + + jid: Mapped[str] = mapped_column(get_text_type(engine), index=True) + load: Mapped[dict] = mapped_column(get_json_type(engine)) + minions: Mapped[list | None] = mapped_column( + get_json_type(engine), default=list + ) + cluster: Mapped[str] = mapped_column(get_text_type(engine), nullable=True) + created_at: Mapped[datetime] = mapped_column( + "created_at", + DateTimeUTC, + server_default=utcnow(), + nullable=False, + ) + + class Returns(PartmanBase): + __tablename__ = "returns" + __mapper_args__ = {"primary_key": ["created_at"]} + table_args = [] + if is_postgres: + table_args.append( + Index( + None, + "ret", + postgresql_using="gin", + postgresql_with={"FASTUPDATE": "OFF"}, + ) + ) + if opts["partman.enabled"]: + table_args.append( + { + "postgresql_partition_by": "RANGE (created_at)", + } + ) + __table_args__ = tuple(table_args) if table_args else () + + fun: Mapped[str] = mapped_column(get_text_type(engine), index=True) + jid: Mapped[str] = mapped_column(get_text_type(engine), index=True) + id: Mapped[str] = mapped_column(get_text_type(engine), index=True) + success: Mapped[str] = mapped_column(get_text_type(engine)) + ret: Mapped[dict] = mapped_column("ret", get_json_type(engine)) + cluster: Mapped[str] = mapped_column(get_text_type(engine), nullable=True) + created_at: Mapped[datetime | None] = mapped_column( + "createD_at", + DateTimeUTC, + server_default=utcnow(), + nullable=False, + ) + + configure_mappers() + + REGISTRY.setdefault(name, {}) + REGISTRY[name]["Base"] = Base + REGISTRY[name]["Returns"] = Returns + REGISTRY[name]["Events"] = Events + REGISTRY[name]["Cache"] = Cache + REGISTRY[name]["Jids"] = Jids diff --git a/salt/utils/master.py b/salt/utils/master.py index ac2785cb2b7b..499a5e58573d 100644 --- a/salt/utils/master.py +++ b/salt/utils/master.py @@ -239,11 +239,11 @@ def _get_cached_mine_data(self, *minion_ids): ) return mine_data if not minion_ids: - minion_ids = self.cache.list("minions") + minion_ids = self.cache.list("grains") for minion_id in minion_ids: if not salt.utils.verify.valid_id(self.opts, minion_id): continue - mdata = self.cache.fetch(f"minions/{minion_id}", "mine") + mdata = self.cache.fetch("mine", minion_id) if isinstance(mdata, dict): mine_data[minion_id] = mdata return mine_data @@ -257,23 +257,24 @@ def _get_cached_minion_data(self, *minion_ids): log.debug("Skipping cached data because minion_data_cache is not enabled.") return grains, pillars if not minion_ids: - 
minion_ids = self.cache.list("minions") + minion_ids = self.cache.list("grains") for minion_id in minion_ids: if not salt.utils.verify.valid_id(self.opts, minion_id): continue - mdata = self.cache.fetch(f"minions/{minion_id}", "data") - if not isinstance(mdata, dict): - log.warning( - "cache.fetch should always return a dict. ReturnedType: %s," - " MinionId: %s", - type(mdata).__name__, - minion_id, - ) - continue - if "grains" in mdata: - grains[minion_id] = mdata["grains"] - if "pillar" in mdata: - pillars[minion_id] = mdata["pillar"] + for bank in ["grains", "pillar"]: + mdata = self.cache.fetch(bank, minion_id) + if not isinstance(mdata, dict): + log.warning( + "cache.fetch should always return a dict. ReturnedType: %s," + " MinionId: %s", + type(mdata).__name__, + minion_id, + ) + continue + if bank == "grains": + grains[minion_id] = mdata + if bank == "pillar": + pillars[minion_id] = mdata return grains, pillars def _get_live_minion_grains(self, minion_ids): @@ -545,10 +546,12 @@ def clear_cached_minion_data( else: # Unless both clear_pillar and clear_grains are True, we need # to read in the pillar/grains data since they are both stored - # in the same file, 'data.p' + # in the the minion data cache grains, pillars = self._get_cached_minion_data(*minion_ids) try: - c_minions = self.cache.list("minions") + # we operate under the assumption that grains should be sufficient + # for minion list + c_minions = self.cache.list("grains") for minion_id in minion_ids: if not salt.utils.verify.valid_id(self.opts, minion_id): continue @@ -556,29 +559,25 @@ def clear_cached_minion_data( if minion_id not in c_minions: # Cache bank for this minion does not exist. Nothing to do. continue - bank = f"minions/{minion_id}" + minion_pillar = pillars.pop(minion_id, False) minion_grains = grains.pop(minion_id, False) - if ( - (clear_pillar and clear_grains) - or (clear_pillar and not minion_grains) - or (clear_grains and not minion_pillar) - ): - # Not saving pillar or grains, so just delete the cache file - self.cache.flush(bank, "data") - elif clear_pillar and minion_grains: - self.cache.store(bank, "data", {"grains": minion_grains}) - elif clear_grains and minion_pillar: - self.cache.store(bank, "data", {"pillar": minion_pillar}) + + if clear_pillar: + self.cache.flush("pillar", minion_id) + + if clear_grains: + self.cache.flush("grains", minion_id) + if clear_mine: # Delete the whole mine file - self.cache.flush(bank, "mine") + self.cache.flush("mine", minion_id) elif clear_mine_func is not None: # Delete a specific function from the mine file - mine_data = self.cache.fetch(bank, "mine") + mine_data = self.cache.fetch("mine", minion_id) if isinstance(mine_data, dict): if mine_data.pop(clear_mine_func, False): - self.cache.store(bank, "mine", mine_data) + self.cache.store("mine", minion_id, mine_data) except OSError: return True return True diff --git a/salt/utils/minions.py b/salt/utils/minions.py index ba0eff6b688f..0d493928445c 100644 --- a/salt/utils/minions.py +++ b/salt/utils/minions.py @@ -106,15 +106,14 @@ def get_minion_data(minion, opts): if opts.get("minion_data_cache", False): cache = salt.cache.factory(opts) if minion is None: - for id_ in cache.list("minions"): - data = cache.fetch(f"minions/{id_}", "data") - if data is None: - continue + for id_ in cache.list("grains"): + grains = cache.fetch("grains", id_) + pillar = cache.fetch("pillar", id_) + if grains: + break else: - data = cache.fetch(f"minions/{minion}", "data") - if data is not None: - grains = data.get("grains", None) - pillar 
= data.get("pillar", None) + grains = cache.fetch("grains", minion) + pillar = cache.fetch("pillar", minion) return minion if minion else None, grains, pillar @@ -320,7 +319,8 @@ def _check_cache_minions( cache_enabled = self.opts.get("minion_data_cache", False) def list_cached_minions(): - return self.cache.list("minions") + # we use grains as a equivalent for minion list + return self.cache.list("grains") if greedy: if not minions: @@ -341,14 +341,13 @@ def list_cached_minions(): for id_ in cminions: if greedy and id_ not in minions: continue - mdata = self.cache.fetch(f"minions/{id_}", "data") + mdata = self.cache.fetch(search_type, id_) if mdata is None: if not greedy: minions.remove(id_) continue - search_results = mdata.get(search_type) if not salt.utils.data.subdict_match( - search_results, + mdata, expr, delimiter=delimiter, regex_match=regex_match, @@ -408,13 +407,13 @@ def _check_ipcidr_minions(self, expr, greedy, minions=None): if not minions: minions = self._pki_minions() elif cache_enabled: - minions = self.cache.list("minions") + minions = self.cache.list("grains") else: return {"minions": [], "missing": []} if cache_enabled: if greedy: - cminions = self.cache.list("minions") + cminions = self.cache.list("grains") else: cminions = minions if cminions is None: @@ -435,12 +434,11 @@ def _check_ipcidr_minions(self, expr, greedy, minions=None): minions = set(minions) for id_ in cminions: - mdata = self.cache.fetch(f"minions/{id_}", "data") - if mdata is None: + grains = self.cache.fetch("grains", id_) + if grains is None: if not greedy: minions.remove(id_) continue - grains = mdata.get("grains") if grains is None or proto not in grains: match = False elif isinstance(tgt, (ipaddress.IPv4Address, ipaddress.IPv6Address)): @@ -477,7 +475,7 @@ def _check_range_minions(self, expr, greedy): mlist.append(fn_) return {"minions": mlist, "missing": []} elif cache_enabled: - return {"minions": self.cache.list("minions"), "missing": []} + return {"minions": self.cache.list("grains"), "missing": []} else: return {"minions": [], "missing": []} @@ -670,7 +668,7 @@ def connected_ids(self, subset=None, show_ip=False): """ minions = set() if self.opts.get("minion_data_cache", False): - search = self.cache.list("minions") + search = self.cache.list("grains") if search is None: return minions addrs = salt.utils.network.local_port_tcp(int(self.opts["publish_port"])) @@ -690,16 +688,15 @@ def connected_ids(self, subset=None, show_ip=False): search = subset for id_ in search: try: - mdata = self.cache.fetch(f"minions/{id_}", "data") + grains = self.cache.fetch("grains", id_) except SaltCacheError: # If a SaltCacheError is explicitly raised during the fetch operation, # permission was denied to open the cached data.p file. Continue on as # in the releases <= 2016.3. (An explicit error raise was added in PR # #35388. See issue #36867 for more information. 
continue - if mdata is None: + if grains is None: continue - grains = mdata.get("grains", {}) for ipv4 in grains.get("ipv4", []): if ipv4 in addrs: if show_ip: diff --git a/tests/pytests/functional/cache/helpers.py b/tests/pytests/functional/cache/helpers.py index 49ddf32e406a..aa1ae74f9b01 100644 --- a/tests/pytests/functional/cache/helpers.py +++ b/tests/pytests/functional/cache/helpers.py @@ -57,7 +57,13 @@ def run_common_cache_tests(subtests, cache): assert actual_thing is new_thing else: assert actual_thing is not new_thing - assert actual_thing == new_thing + + try: + assert actual_thing == new_thing + except AssertionError: + # json storage disallows int object keys, which some storages use + new_thing["42"] = new_thing.pop(42) + assert actual_thing == new_thing with subtests.test("contains returns true if key in bank"): assert cache.contains(bank=bank, key=good_key) @@ -125,13 +131,14 @@ def run_common_cache_tests(subtests, cache): assert timestamp is None with subtests.test("Updated for key should return a reasonable time"): - before_storage = int(time.time()) + before_storage = time.time() cache.store(bank="fnord", key="updated test part 2", data="fnord") - after_storage = int(time.time()) + after_storage = time.time() timestamp = cache.updated(bank="fnord", key="updated test part 2") - assert before_storage <= timestamp <= after_storage + # the -1/+1 because mysql timestamps are janky + assert before_storage - 1 <= timestamp <= after_storage + 1 with subtests.test( "If the module raises SaltCacheError then it should make it out of updated" diff --git a/tests/pytests/functional/cache/test_mysql.py b/tests/pytests/functional/cache/test_mysql.py index c283872c08c9..8c8e2b0a771b 100644 --- a/tests/pytests/functional/cache/test_mysql.py +++ b/tests/pytests/functional/cache/test_mysql.py @@ -3,9 +3,7 @@ import pytest import salt.cache -import salt.loader from tests.pytests.functional.cache.helpers import run_common_cache_tests -from tests.support.pytest.mysql import * # pylint: disable=wildcard-import,unused-wildcard-import docker = pytest.importorskip("docker") @@ -14,24 +12,35 @@ pytestmark = [ pytest.mark.slow_test, pytest.mark.skip_if_binaries_missing("dockerd"), + pytest.mark.parametrize( + "database_backend", + [ + ("mysql-server", "5.5"), + ("mysql-server", "5.6"), + ("mysql-server", "5.7"), + ("mysql-server", "8.0"), + ("mariadb", "10.3"), + ("mariadb", "10.4"), + ("mariadb", "10.5"), + ("percona", "5.6"), + ("percona", "5.7"), + ("percona", "8.0"), + ], + ids=lambda val: f"{val[0]}-{val[1] or 'default'}", + indirect=True, + ), ] -@pytest.fixture(scope="module") -def mysql_combo(create_mysql_combo): # pylint: disable=function-redefined - create_mysql_combo.mysql_database = "salt_cache" - return create_mysql_combo - - @pytest.fixture -def cache(minion_opts, mysql_container): +def cache(minion_opts, database_backend): opts = minion_opts.copy() opts["cache"] = "mysql" opts["mysql.host"] = "127.0.0.1" - opts["mysql.port"] = mysql_container.mysql_port - opts["mysql.user"] = mysql_container.mysql_user - opts["mysql.password"] = mysql_container.mysql_passwd - opts["mysql.database"] = mysql_container.mysql_database + opts["mysql.port"] = database_backend.port + opts["mysql.user"] = database_backend.user + opts["mysql.password"] = database_backend.passwd + opts["mysql.database"] = database_backend.database opts["mysql.table_name"] = "cache" cache = salt.cache.factory(opts) return cache diff --git a/tests/pytests/functional/cache/test_sqlalchemy.py 
b/tests/pytests/functional/cache/test_sqlalchemy.py new file mode 100644 index 000000000000..45cc91549d60 --- /dev/null +++ b/tests/pytests/functional/cache/test_sqlalchemy.py @@ -0,0 +1,81 @@ +import os + +import pytest + +import salt.cache +import salt.sqlalchemy +from tests.pytests.functional.cache.helpers import run_common_cache_tests +from tests.support.pytest.database import available_databases + +sqlalchemy = pytest.importorskip("sqlalchemy") + +pytestmark = [ + pytest.mark.slow_test, + pytest.mark.parametrize( + "database_backend", + available_databases( + [ + ("mysql-server", "8.0"), + ("mariadb", "10.4"), + ("mariadb", "10.5"), + ("percona", "8.0"), + ("postgresql", "13"), + ("postgresql", "17"), + ("sqlite", None), + ] + ), + indirect=True, + ), +] + + +@pytest.fixture
def cache(master_opts, database_backend, tmp_path_factory): + opts = master_opts.copy() + opts["cache"] = "sqlalchemy" + opts["sqlalchemy.echo"] = True + + if database_backend.dialect in {"mysql", "postgresql"}: + if database_backend.dialect == "mysql": + driver = "mysql+pymysql" + elif database_backend.dialect == "postgresql": + driver = "postgresql+psycopg" + + opts["sqlalchemy.drivername"] = driver + opts["sqlalchemy.username"] = database_backend.user + opts["sqlalchemy.password"] = database_backend.passwd + opts["sqlalchemy.port"] = database_backend.port + opts["sqlalchemy.database"] = database_backend.database + opts["sqlalchemy.host"] = "0.0.0.0" + opts["sqlalchemy.disable_connection_pool"] = True + elif database_backend.dialect == "sqlite": + opts["sqlalchemy.dsn"] = "sqlite:///" + os.path.join( + tmp_path_factory.mktemp("sqlite"), "salt.db" + ) + else: + raise ValueError(f"Unsupported database backend: {database_backend}") + + salt.sqlalchemy.reconfigure_orm(opts) + salt.sqlalchemy.drop_all() + salt.sqlalchemy.create_all() + + return salt.cache.factory(opts) + + +@pytest.fixture(scope="module") +def master_opts( + salt_factories, + master_id, + master_config_defaults, + master_config_overrides, +): + factory = salt_factories.salt_master_daemon( + master_id, + defaults=master_config_defaults or None, + overrides=master_config_overrides, + ) + return factory.config.copy() + + +def test_caching(subtests, cache): + run_common_cache_tests(subtests, cache) diff --git a/tests/pytests/functional/conftest.py b/tests/pytests/functional/conftest.py index 0a8219b8f717..8c721fa1c112 100644 --- a/tests/pytests/functional/conftest.py +++ b/tests/pytests/functional/conftest.py @@ -5,6 +5,8 @@ import pytest from saltfactories.utils.functional import Loaders +from tests.support.pytest.database import * # pylint: disable=wildcard-import,unused-wildcard-import + log = logging.getLogger(__name__) diff --git a/tests/pytests/functional/modules/test_mysql.py b/tests/pytests/functional/modules/test_mysql.py index c82bba301932..633297afd6ed 100644 --- a/tests/pytests/functional/modules/test_mysql.py +++ b/tests/pytests/functional/modules/test_mysql.py @@ -10,7 +10,7 @@ import salt.modules.mysql as mysqlmod from salt.utils.versions import version_cmp -from tests.support.pytest.mysql import * # pylint: disable=wildcard-import,unused-wildcard-import +from tests.support.pytest.database import available_databases log = logging.getLogger(__name__) @@ -20,6 +20,24 @@ pytest.mark.skipif( mysqlmod.MySQLdb is None, reason="No python mysql client installed."
), + pytest.mark.parametrize( + "database_backend", + available_databases( + [ + ("mysql-server", "5.5", "MySQLdb"), + ("mysql-server", "5.6", "MySQLdb"), + ("mysql-server", "5.7", "MySQLdb"), + ("mysql-server", "8.0", "MySQLdb"), + ("mariadb", "10.3", "MySQLdb"), + ("mariadb", "10.4", "MySQLdb"), + ("mariadb", "10.5", "MySQLdb"), + ("percona", "5.6", "MySQLdb"), + ("percona", "5.7", "MySQLdb"), + ("percona", "8.0", "MySQLdb"), + ] + ), + indirect=True, + ), pytest.mark.skip_on_fips_enabled_platform, ] @@ -61,13 +79,13 @@ def __call__(self, *args, **kwargs): @pytest.fixture(scope="module") -def mysql(modules, mysql_container, loaders): +def mysql(modules, database_backend, loaders): for name in list(modules): if not name.startswith("mysql."): continue modules._dict[name] = CallWrapper( modules._dict[name], - mysql_container, + database_backend, loaders.context, ) return modules.mysql @@ -79,10 +97,10 @@ def test_query(mysql): assert ret["results"] == (("1",),) -def test_version(mysql, mysql_container): +def test_version(mysql, database_backend): ret = mysql.version() assert ret - assert mysql_container.mysql_version in ret + assert database_backend.version in ret def test_status(mysql): @@ -111,20 +129,20 @@ def test_db_create_alter_remove(mysql): assert ret -def test_user_list(mysql, mysql_combo): +def test_user_list(mysql, database_backend): ret = mysql.user_list() assert ret assert { - "User": mysql_combo.mysql_root_user, - "Host": mysql_combo.mysql_host, + "User": database_backend.root_user, + "Host": database_backend.host, } in ret -def test_user_exists(mysql, mysql_combo): +def test_user_exists(mysql, database_backend): ret = mysql.user_exists( - mysql_combo.mysql_root_user, - host=mysql_combo.mysql_host, - password=mysql_combo.mysql_passwd, + database_backend.root_user, + host=database_backend.host, + password=database_backend.passwd, ) assert ret @@ -136,15 +154,15 @@ def test_user_exists(mysql, mysql_combo): assert not ret -def test_user_info(mysql, mysql_combo): - ret = mysql.user_info(mysql_combo.mysql_root_user, host=mysql_combo.mysql_host) +def test_user_info(mysql, database_backend): + ret = mysql.user_info(database_backend.root_user, host=database_backend.host) assert ret # Check that a subset of the information # is available in the returned user information. 
expected = { - "Host": mysql_combo.mysql_host, - "User": mysql_combo.mysql_root_user, + "Host": database_backend.host, + "User": database_backend.root_user, "Select_priv": "Y", "Insert_priv": "Y", "Update_priv": "Y", @@ -201,8 +219,8 @@ def test_user_create_chpass_delete(mysql): assert ret -def test_user_grants(mysql, mysql_combo): - ret = mysql.user_grants(mysql_combo.mysql_root_user, host=mysql_combo.mysql_host) +def test_user_grants(mysql, database_backend): + ret = mysql.user_grants(database_backend.root_user, host=database_backend.host) assert ret @@ -300,22 +318,22 @@ def test_grant_add_revoke(mysql): assert ret -def test_grant_replication_replica_add_revoke(mysql, mysql_container): +def test_grant_replication_replica_add_revoke(mysql, database_backend): # The REPLICATION REPLICA grant is only available for mariadb - if "mariadb" not in mysql_container.mysql_name: + if "mariadb" not in database_backend.name: pytest.skip( "The REPLICATION REPLICA grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) # The REPLICATION REPLICA grant was added in mariadb 10.5.1 - if version_cmp(mysql_container.mysql_version, "10.5.1") < 0: + if version_cmp(database_backend.version, "10.5.1") < 0: pytest.skip( "The REPLICATION REPLICA grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) @@ -376,7 +394,7 @@ def test_grant_replication_replica_add_revoke(mysql, mysql_container): assert ret -def test_grant_replication_slave_add_revoke(mysql, mysql_container): +def test_grant_replication_slave_add_revoke(mysql, database_backend): # Create the database ret = mysql.db_create("salt") assert ret @@ -434,7 +452,7 @@ def test_grant_replication_slave_add_revoke(mysql, mysql_container): assert ret -def test_grant_replication_client_add_revoke(mysql, mysql_container): +def test_grant_replication_client_add_revoke(mysql, database_backend): # Create the database ret = mysql.db_create("salt") assert ret @@ -492,22 +510,22 @@ def test_grant_replication_client_add_revoke(mysql, mysql_container): assert ret -def test_grant_binlog_monitor_add_revoke(mysql, mysql_container): +def test_grant_binlog_monitor_add_revoke(mysql, database_backend): # The BINLOG MONITOR grant is only available for mariadb - if "mariadb" not in mysql_container.mysql_name: + if "mariadb" not in database_backend.name: pytest.skip( "The BINLOG MONITOR grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) # The BINLOG MONITOR grant was added in mariadb 10.5.2 - if version_cmp(mysql_container.mysql_version, "10.5.2") < 0: + if version_cmp(database_backend.version, "10.5.2") < 0: pytest.skip( "The BINLOG_MONITOR grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) @@ -568,22 +586,22 @@ def test_grant_binlog_monitor_add_revoke(mysql, mysql_container): assert ret -def test_grant_replica_monitor_add_revoke(mysql, mysql_container): +def test_grant_replica_monitor_add_revoke(mysql, database_backend): # The REPLICA MONITOR grant is only available for mariadb - if "mariadb" not in mysql_container.mysql_name: + if "mariadb" not in database_backend.name: pytest.skip( "The REPLICA MONITOR 
grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) # The REPLICA MONITOR grant was added in mariadb 10.5.9 - if version_cmp(mysql_container.mysql_version, "10.5.9") < 0: + if version_cmp(database_backend.version, "10.5.9") < 0: pytest.skip( "The REPLICA MONITOR grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) @@ -644,22 +662,22 @@ def test_grant_replica_monitor_add_revoke(mysql, mysql_container): assert ret -def test_grant_slave_monitor_add_revoke(mysql, mysql_container): +def test_grant_slave_monitor_add_revoke(mysql, database_backend): # The SLAVE MONITOR grant is only available for mariadb - if "mariadb" not in mysql_container.mysql_name: + if "mariadb" not in database_backend.name: pytest.skip( "The SLAVE MONITOR grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) # The SLAVE MONITOR grant was added in mariadb 10.5.9 - if version_cmp(mysql_container.mysql_version, "10.5.9") < 0: + if version_cmp(database_backend.version, "10.5.9") < 0: pytest.skip( "The SLAVE MONITOR grant is unavailable " "for the {}:{} docker image.".format( - mysql_container.mysql_name, mysql_container.mysql_version + database_backend.name, database_backend.version ) ) @@ -720,32 +738,32 @@ def test_grant_slave_monitor_add_revoke(mysql, mysql_container): assert ret -def test_plugin_add_status_remove(mysql, mysql_combo): +def test_plugin_add_status_remove(mysql, database_backend): - if "mariadb" in mysql_combo.mysql_name: + if "mariadb" in database_backend.name: plugin = "simple_password_check" else: plugin = "auth_socket" - ret = mysql.plugin_status(plugin, host=mysql_combo.mysql_host) + ret = mysql.plugin_status(plugin, host=database_backend.host) assert not ret ret = mysql.plugin_add(plugin) assert ret - ret = mysql.plugin_status(plugin, host=mysql_combo.mysql_host) + ret = mysql.plugin_status(plugin, host=database_backend.host) assert ret assert ret == "ACTIVE" ret = mysql.plugin_remove(plugin) assert ret - ret = mysql.plugin_status(plugin, host=mysql_combo.mysql_host) + ret = mysql.plugin_status(plugin, host=database_backend.host) assert not ret -def test_plugin_list(mysql, mysql_container): - if "mariadb" in mysql_container.mysql_name: +def test_plugin_list(mysql, database_backend): + if "mariadb" in database_backend.name: plugin = "simple_password_check" else: plugin = "auth_socket" diff --git a/tests/pytests/functional/returners/test_sqlalchemy.py b/tests/pytests/functional/returners/test_sqlalchemy.py new file mode 100644 index 000000000000..91db42b3be4b --- /dev/null +++ b/tests/pytests/functional/returners/test_sqlalchemy.py @@ -0,0 +1,232 @@ +import logging +import os +from datetime import datetime, timezone + +import pytest + +import salt.exceptions +import salt.loader +import salt.sqlalchemy +from salt.sqlalchemy import Session +from salt.sqlalchemy.models import model_for +from salt.utils.jid import gen_jid +from tests.support.mock import patch +from tests.support.pytest.database import available_databases + +sqlalchemy = pytest.importorskip("sqlalchemy") + +from sqlalchemy import ( # pylint: disable=3rd-party-module-not-gated + delete, + func, + select, +) + +log = logging.getLogger(__name__) + +pytestmark = [ + pytest.mark.slow_test, 
+ pytest.mark.parametrize( + "database_backend", + available_databases( + [ + ("mysql-server", "8.0"), + ("mariadb", "10.4"), + ("mariadb", "10.5"), + ("percona", "8.0"), + ("postgresql", "13"), + ("postgresql", "17"), + ("sqlite", None), + ] + ), + indirect=True, + ), +] + + +@pytest.fixture(scope="module") +def returner(master_opts, database_backend, tmp_path_factory): + opts = master_opts.copy() + opts["cache"] = "sqlalchemy" + opts["sqlalchemy.echo"] = True + + if database_backend.dialect in {"mysql", "postgresql"}: + if database_backend.dialect == "mysql": + driver = "mysql+pymysql" + elif database_backend.dialect == "postgresql": + driver = "postgresql+psycopg" + + opts["sqlalchemy.drivername"] = driver + opts["sqlalchemy.username"] = database_backend.user + opts["sqlalchemy.password"] = database_backend.passwd + opts["sqlalchemy.port"] = database_backend.port + opts["sqlalchemy.database"] = database_backend.database + opts["sqlalchemy.host"] = "0.0.0.0" + opts["sqlalchemy.disable_connection_pool"] = True + elif database_backend.dialect == "sqlite": + opts["sqlalchemy.dsn"] = "sqlite:///" + os.path.join( + tmp_path_factory.mktemp("sqlite"), "salt.db" + ) + else: + raise ValueError(f"Unsupported returner param: {database_backend}") + + salt.sqlalchemy.reconfigure_orm(opts) + salt.sqlalchemy.drop_all() + salt.sqlalchemy.create_all() + + functions = salt.loader.minion_mods(opts) + return salt.loader.returners(opts, functions) + + +def test_returner_inserts(returner): + Returns = model_for("Returns") + ret = {"fun": "test.ping", "jid": gen_jid({}), "id": "minion", "success": True} + returner["sqlalchemy.returner"](ret) + + with Session() as session: + stmt = select(func.count()).where( # pylint: disable=not-callable + Returns.jid == ret["jid"] + ) + inserted = session.execute(stmt).scalar() + assert inserted == 1 + + +def test_event_return_inserts(returner): + Events = model_for("Events") + evts = [{"tag": "test", "data": {"_stamp": str(datetime.now(timezone.utc))}}] + returner["sqlalchemy.event_return"](evts) + with Session() as session: + stmt = select(func.count()).where( # pylint: disable=not-callable + Events.tag == "test" + ) + inserted = session.execute(stmt).scalar() + assert inserted == 1 + + +def test_save_load_inserts(returner): + Jids = model_for("Jids") + jid = gen_jid({}) + load = {"foo": "bar"} + minions = ["minion1", "minion2"] + returner["sqlalchemy.save_load"](jid, load, minions) + with Session() as session: + stmt = select(func.count()).where( # pylint: disable=not-callable + Jids.jid == jid + ) + inserted = session.execute(stmt).scalar() + assert inserted == 1 + + +def test_get_load_returns(returner): + Jids = model_for("Jids") + jid = gen_jid({}) + load = {"foo": "bar"} + minions = ["minion1", "minion2"] + returner["sqlalchemy.save_load"](jid, load, minions) + result = returner["sqlalchemy.get_load"](jid) + assert isinstance(result, dict) + assert "foo" in result + + +def test_get_jid_returns(returner): + Returns = model_for("Returns") + jid = gen_jid({}) + ret = {"fun": "test.ping", "jid": jid, "id": "minion", "success": True} + returner["sqlalchemy.returner"](ret) + result = returner["sqlalchemy.get_jid"](jid) + assert isinstance(result, dict) + assert "minion" in result + + +def test_prep_jid_returns_unique(returner): + jid1 = returner["sqlalchemy.prep_jid"]() + jid2 = returner["sqlalchemy.prep_jid"]() + assert jid1 != jid2 + + +def test_save_minions_noop(returner): + # Should not raise or do anything + assert returner["sqlalchemy.save_minions"]("jid", ["minion"]) 
is None + + +def test_get_fun_raises(returner): + with pytest.raises(Exception): + returner["sqlalchemy.get_fun"]("test.ping") + + +def test_get_jids_raises(returner): + with pytest.raises(Exception): + returner["sqlalchemy.get_jids"]() + + +def test_get_minions_raises(returner): + with pytest.raises(Exception): + returner["sqlalchemy.get_minions"]() + + +def test_clean_old_jobs(master_opts, returner): + # there might be a better way to do this + opts = returner["sqlalchemy.clean_old_jobs"].__globals__["__opts__"] + + with patch.dict(opts, {"keep_jobs_seconds": 3600, "archive_jobs": True}): + with pytest.raises(salt.exceptions.SaltException): + returner["sqlalchemy.clean_old_jobs"]() + + with patch.dict( + opts, + {"keep_jobs_seconds": 3600, "cluster_id": "testcluster", "id": "testmaster"}, + ): + # Insert fake data into Jids, Returns, Events + Jids = model_for("Jids") + Returns = model_for("Returns") + Events = model_for("Events") + # + # delete all state so counts are correct + with Session() as session: + session.execute(delete(Jids)) + session.execute(delete(Returns)) + session.execute(delete(Events)) + session.commit() + + now = datetime.now(timezone.utc) + old_time = now.replace(year=now.year - 1) # definitely old enough to be deleted + + with Session() as session: + session.add( + Jids( + jid="jid1", + load={"foo": "bar"}, + minions=["minion1"], + cluster="testcluster", + created_at=old_time, + ) + ) + session.add( + Returns( + fun="test.ping", + jid="jid1", + id="minion1", + success=True, + ret={"foo": "bar"}, + cluster="testcluster", + created_at=old_time, + ) + ) + session.add( + Events( + tag="test", + data={"_stamp": str(old_time)}, + master_id="testmaster", + cluster="testcluster", + created_at=old_time, + ) + ) + session.commit() + + # Run clean_old_jobs + returner["sqlalchemy.clean_old_jobs"]() + + # Assert all old data is deleted + with Session() as session: + assert session.query(Jids).count() == 0 + assert session.query(Returns).count() == 0 + assert session.query(Events).count() == 0 diff --git a/tests/pytests/integration/modules/test_pillar.py b/tests/pytests/integration/modules/test_pillar.py index 8ebe4e1a03c9..d2231849f4e8 100644 --- a/tests/pytests/integration/modules/test_pillar.py +++ b/tests/pytests/integration/modules/test_pillar.py @@ -351,6 +351,8 @@ def test_pillar_refresh_pillar_items(salt_cli, salt_minion, key_pillar): with key_pillar(key) as key_pillar_instance: # A pillar.items call sees the pillar right away because a # refresh_pillar event is fired. 
+ # Calling refresh_pillar to update in-memory pillars + key_pillar_instance.refresh_pillar() ret = salt_cli.run("pillar.items", minion_tgt=salt_minion.id) assert ret.returncode == 0 val = ret.data diff --git a/tests/pytests/unit/auth/test_auth.py b/tests/pytests/unit/auth/test_auth.py deleted file mode 100644 index 4fc3d836426f..000000000000 --- a/tests/pytests/unit/auth/test_auth.py +++ /dev/null @@ -1,33 +0,0 @@ -import time - -import salt.auth -import salt.config - - -def test_cve_2021_3244(tmp_path): - token_dir = tmp_path / "tokens" - token_dir.mkdir() - opts = { - "extension_modules": "", - "optimization_order": [0, 1, 2], - "token_expire": 1, - "keep_acl_in_token": False, - "eauth_tokens": "localfs", - "token_dir": str(token_dir), - "token_expire_user_override": True, - "external_auth": {"auto": {"foo": []}}, - } - auth = salt.auth.LoadAuth(opts) - load = { - "eauth": "auto", - "username": "foo", - "password": "foo", - "token_expire": -1, - } - t_data = auth.mk_token(load) - assert t_data["expire"] < time.time() - token_file = token_dir / t_data["token"] - assert token_file.exists() - t_data = auth.get_tok(t_data["token"]) - assert not t_data - assert not token_file.exists() diff --git a/tests/pytests/unit/cache/test_memcache.py b/tests/pytests/unit/cache/test_memcache.py index 7a1d93743350..e52dd447a2d4 100644 --- a/tests/pytests/unit/cache/test_memcache.py +++ b/tests/pytests/unit/cache/test_memcache.py @@ -34,7 +34,13 @@ def test_fetch(cache): ret = cache.fetch("bank", "key") assert ret == "fake_data" assert salt.cache.MemCache.data == { - "fake_driver": {("bank", "key"): [0, "fake_data"]} + "fake_driver": { + ("bank", "key"): [ + 0, + cache.opts["memcache_expire_seconds"], + "fake_data", + ] + } } cache_fetch_mock.assert_called_once_with("bank", "key") cache_fetch_mock.reset_mock() @@ -44,7 +50,13 @@ def test_fetch(cache): ret = cache.fetch("bank", "key") assert ret == "fake_data" assert salt.cache.MemCache.data == { - "fake_driver": {("bank", "key"): [1, "fake_data"]} + "fake_driver": { + ("bank", "key"): [ + 1, + cache.opts["memcache_expire_seconds"], + "fake_data", + ] + } } cache_fetch_mock.assert_not_called() @@ -53,7 +65,13 @@ def test_fetch(cache): ret = cache.fetch("bank", "key") assert ret == "fake_data" assert salt.cache.MemCache.data == { - "fake_driver": {("bank", "key"): [12, "fake_data"]} + "fake_driver": { + ("bank", "key"): [ + 12, + cache.opts["memcache_expire_seconds"], + "fake_data", + ] + } } cache_fetch_mock.assert_called_once_with("bank", "key") cache_fetch_mock.reset_mock() @@ -66,9 +84,11 @@ def test_store(cache): with patch("time.time", return_value=0): cache.store("bank", "key", "fake_data") assert salt.cache.MemCache.data == { - "fake_driver": {("bank", "key"): [0, "fake_data"]} + "fake_driver": {("bank", "key"): [0, None, "fake_data"]} } - cache_store_mock.assert_called_once_with("bank", "key", "fake_data") + cache_store_mock.assert_called_once_with( + "bank", "key", "fake_data", expires=None + ) cache_store_mock.reset_mock() # Store another value. 
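Editor's note: the memcache assertions above pin down the new in-memory record layout, where each entry is a [timestamp, expires, data] triple and store() forwards an explicit expires keyword to the backing driver. Below is a minimal, self-contained sketch of that layout for orientation only; TinyMemCache and its attribute names are illustrative assumptions, not the actual salt.cache.MemCache implementation.

import time


class TinyMemCache:
    """Toy model of the [stored_at, expires, value] record layout."""

    def __init__(self, default_expire=10):
        self.default_expire = default_expire  # instance-wide TTL fallback
        self.data = {}  # (bank, key) -> [stored_at, expires, value]

    def store(self, bank, key, value, expires=None):
        # Record the per-key expiry next to the value; None means
        # "fall back to the instance-wide default on fetch".
        self.data[(bank, key)] = [time.time(), expires, value]

    def fetch(self, bank, key):
        record = self.data.get((bank, key))
        if record is None:
            return None
        stored_at, expires, value = record
        ttl = expires if expires is not None else self.default_expire
        if time.time() - stored_at >= ttl:
            # Expired entry: evict and report a miss.
            del self.data[(bank, key)]
            return None
        return value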
@@ -76,11 +96,13 @@ def test_store(cache): cache.store("bank", "key2", "fake_data2") assert salt.cache.MemCache.data == { "fake_driver": { - ("bank", "key"): [0, "fake_data"], - ("bank", "key2"): [1, "fake_data2"], + ("bank", "key"): [0, None, "fake_data"], + ("bank", "key2"): [1, None, "fake_data2"], } } - cache_store_mock.assert_called_once_with("bank", "key2", "fake_data2") + cache_store_mock.assert_called_once_with( + "bank", "key2", "fake_data2", expires=None + ) def test_flush(cache): @@ -102,10 +124,11 @@ def test_flush(cache): cache.store("bank", "key", "fake_data") assert salt.cache.MemCache.data["fake_driver"][("bank", "key")] == [ 0, + None, "fake_data", ] assert salt.cache.MemCache.data == { - "fake_driver": {("bank", "key"): [0, "fake_data"]} + "fake_driver": {("bank", "key"): [0, None, "fake_data"]} } cache.flush("bank", "key") assert salt.cache.MemCache.data == {"fake_driver": {}} @@ -124,17 +147,17 @@ def test_max_items(cache): with patch("time.time", return_value=2): cache.store("bank2", "key1", "fake_data21") assert salt.cache.MemCache.data["fake_driver"] == { - ("bank1", "key1"): [0, "fake_data11"], - ("bank1", "key2"): [1, "fake_data12"], - ("bank2", "key1"): [2, "fake_data21"], + ("bank1", "key1"): [0, None, "fake_data11"], + ("bank1", "key2"): [1, None, "fake_data12"], + ("bank2", "key1"): [2, None, "fake_data21"], } # Put one more and check the oldest was removed with patch("time.time", return_value=3): cache.store("bank2", "key2", "fake_data22") assert salt.cache.MemCache.data["fake_driver"] == { - ("bank1", "key2"): [1, "fake_data12"], - ("bank2", "key1"): [2, "fake_data21"], - ("bank2", "key2"): [3, "fake_data22"], + ("bank1", "key2"): [1, None, "fake_data12"], + ("bank2", "key1"): [2, None, "fake_data21"], + ("bank2", "key2"): [3, None, "fake_data22"], } @@ -151,16 +174,16 @@ def test_full_cleanup(cache): with patch("time.time", return_value=2): cache.store("bank2", "key1", "fake_data21") assert salt.cache.MemCache.data["fake_driver"] == { - ("bank1", "key1"): [0, "fake_data11"], - ("bank1", "key2"): [1, "fake_data12"], - ("bank2", "key1"): [2, "fake_data21"], + ("bank1", "key1"): [0, None, "fake_data11"], + ("bank1", "key2"): [1, None, "fake_data12"], + ("bank2", "key1"): [2, None, "fake_data21"], } # Put one more and check all expired was removed with patch("time.time", return_value=12): cache.store("bank2", "key2", "fake_data22") assert salt.cache.MemCache.data["fake_driver"] == { - ("bank2", "key1"): [2, "fake_data21"], - ("bank2", "key2"): [12, "fake_data22"], + ("bank2", "key1"): [2, None, "fake_data21"], + ("bank2", "key2"): [12, None, "fake_data22"], } diff --git a/tests/pytests/unit/daemons/masterapi/test_remote_funcs.py b/tests/pytests/unit/daemons/masterapi/test_remote_funcs.py index 99821a8f54a9..bcd41e9edbbd 100644 --- a/tests/pytests/unit/daemons/masterapi/test_remote_funcs.py +++ b/tests/pytests/unit/daemons/masterapi/test_remote_funcs.py @@ -1,3 +1,5 @@ +import logging + import pytest import salt.config @@ -5,6 +7,7 @@ import salt.utils.platform from tests.support.mock import MagicMock, patch +log = logging.getLogger(__name__) pytestmark = [ pytest.mark.slow_test, ] @@ -18,7 +21,7 @@ def store(self, bank, key, value): self.data[bank, key] = value def fetch(self, bank, key): - return self.data[bank, key] + return self.data.get((bank, key), None) @pytest.fixture @@ -39,7 +42,7 @@ def test_mine_get(funcs, tgt_type_key="tgt_type"): - the correct check minions method is called - the correct cache key is subsequently used """ - 
funcs.cache.store("minions/webserver", "mine", dict(ip_addr="2001:db8::1:3")) + funcs.cache.store("mine", "webserver", dict(ip_addr="2001:db8::1:3")) with patch( "salt.utils.minions.CkMinions._check_compound_minions", MagicMock(return_value=dict(minions=["webserver"], missing=[])), @@ -75,8 +78,8 @@ def test_mine_get_dict_str(funcs, tgt_type_key="tgt_type"): - the correct cache key is subsequently used """ funcs.cache.store( - "minions/webserver", "mine", + "webserver", dict(ip_addr="2001:db8::1:3", ip4_addr="127.0.0.1"), ) with patch( @@ -108,8 +111,8 @@ def test_mine_get_dict_list(funcs, tgt_type_key="tgt_type"): - the correct cache key is subsequently used """ funcs.cache.store( - "minions/webserver", "mine", + "webserver", dict(ip_addr="2001:db8::1:3", ip4_addr="127.0.0.1"), ) with patch( @@ -136,8 +139,8 @@ def test_mine_get_acl_allowed(funcs): in the client-side ACL that was stored in the mine data. """ funcs.cache.store( - "minions/webserver", "mine", + "webserver", { "ip_addr": { salt.utils.mine.MINE_ITEM_ACL_DATA: "2001:db8::1:4", @@ -174,8 +177,8 @@ def test_mine_get_acl_rejected(funcs): no data being sent back (just as if the entry wouldn't exist). """ funcs.cache.store( - "minions/webserver", "mine", + "webserver", { "ip_addr": { salt.utils.mine.MINE_ITEM_ACL_DATA: "2001:db8::1:4", diff --git a/tests/pytests/unit/pillar/test_pillar.py b/tests/pytests/unit/pillar/test_pillar.py index 763739f9193c..2a6d4fb77f97 100644 --- a/tests/pytests/unit/pillar/test_pillar.py +++ b/tests/pytests/unit/pillar/test_pillar.py @@ -137,16 +137,15 @@ def test_pillar_envs_path_substitution(env, temp_salt_minion, tmp_path): def test_pillar_get_cache_disk(temp_salt_minion, caplog): # create faked path for cache with pytest.helpers.temp_directory() as temp_path: - tmp_cachedir = Path(str(temp_path) + "/pillar_cache/") - tmp_cachedir.mkdir(parents=True) - assert tmp_cachedir.exists() - tmp_cachefile = Path(str(temp_path) + "/pillar_cache/" + temp_salt_minion.id) + opts = temp_salt_minion.config.copy() + tmp_cachefile = Path(str(temp_path)) / "pillar" / f"{temp_salt_minion.id}.p" + tmp_cachefile.parent.mkdir(parents=True) tmp_cachefile.touch() assert tmp_cachefile.exists() - opts = temp_salt_minion.config.copy() opts["pillarenv"] = None opts["pillar_cache"] = True + opts["minion_data_cache"] = False opts["cachedir"] = str(temp_path) caplog.at_level(logging.DEBUG) @@ -156,25 +155,23 @@ def test_pillar_get_cache_disk(temp_salt_minion, caplog): minion_id=temp_salt_minion.id, saltenv="base", ) - fresh_pillar = pillar.fetch_pillar() - assert not ( - f"Error reading cache file at '{tmp_cachefile}': Unpack failed: incomplete input" - in caplog.messages - ) + fresh_pillar = pillar.compile_pillar() + assert "Unpack failed: incomplete input" not in caplog.messages assert fresh_pillar == {} def test_pillar_fetch_pillar_override_skipped(temp_salt_minion, caplog): with pytest.helpers.temp_directory() as temp_path: - tmp_cachedir = Path(str(temp_path) + "/pillar_cache/") - tmp_cachedir.mkdir(parents=True) - assert tmp_cachedir.exists() - tmp_cachefile = Path(str(temp_path) + "/pillar_cache/" + temp_salt_minion.id) - assert tmp_cachefile.exists() is False + opts = temp_salt_minion.config.copy() + tmp_cachefile = Path(str(temp_path)) / "pillar" / f"{temp_salt_minion.id}.p" + tmp_cachefile.parent.mkdir(parents=True) + tmp_cachefile.touch() + assert tmp_cachefile.exists() opts = temp_salt_minion.config.copy() opts["pillarenv"] = None opts["pillar_cache"] = True + opts["minion_data_cache"] = False opts["cachedir"] = 
str(temp_path) pillar_override = {"inline_pillar": True} @@ -188,8 +185,11 @@ def test_pillar_fetch_pillar_override_skipped(temp_salt_minion, caplog): pillar_override=pillar_override, ) - fresh_pillar = pillar.fetch_pillar() - assert fresh_pillar == {} + fresh_pillar = pillar.compile_pillar() + assert "inline_pillar" in fresh_pillar + + pillar_cache = pillar.cache.fetch("pillar", f"{temp_salt_minion.id}:base") + assert "inline_pillar" not in pillar_cache def test_remote_pillar_timeout(temp_salt_minion, tmp_path): diff --git a/tests/pytests/unit/runners/test_cache.py b/tests/pytests/unit/runners/test_cache.py index 8f416c069d8b..15846cd4c1e4 100644 --- a/tests/pytests/unit/runners/test_cache.py +++ b/tests/pytests/unit/runners/test_cache.py @@ -7,22 +7,28 @@ import salt.config import salt.runners.cache as cache import salt.utils.master -from tests.support.mock import patch +from tests.support.mock import MagicMock, call, patch @pytest.fixture -def configure_loader_modules(tmp_path): - master_config = salt.config.master_config(None) - master_config.update( +def master_opts(master_opts, tmp_path): + master_opts.update( { "cache": "localfs", "pki_dir": str(tmp_path), "key_cache": True, "keys.cache_driver": "localfs_key", "__role": "master", + "eauth_tokens.cache_driver": "localfs", + "pillar.cache_driver": "localfs", } ) - return {cache: {"__opts__": master_config}} + return master_opts + + +@pytest.fixture +def configure_loader_modules(master_opts): + return {cache: {"__opts__": master_opts}} def test_grains(): @@ -44,3 +50,192 @@ def get_minion_grains(self): with patch.object(salt.utils.master, "MasterPillarUtil", MockMaster): assert cache.grains(tgt="*") == mock_data + + +def test_migrate_all_banks(master_opts): + """ + Test migrate function when migrating all banks + """ + mock_key_cache = MagicMock() + mock_token_cache = MagicMock() + mock_mdc_cache = MagicMock() + mock_base_cache = MagicMock() + mock_dst_cache = MagicMock() + + mock_key_cache.list.side_effect = [["key1", "key2"], ["key3"], ["key4"]] + mock_token_cache.list.return_value = ["token1"] + mock_mdc_cache.list.return_value = ["pillar1"] + mock_base_cache.list.side_effect = [["grain1"], ["mine1"]] + + mock_key_cache.fetch.side_effect = ["value1", "value2", "value3", "value4"] + mock_token_cache.fetch.return_value = "token_value" + mock_mdc_cache.fetch.return_value = "pillar_value" + mock_base_cache.fetch.side_effect = ["grain_value", "mine_value"] + + mock_caches = [ + mock_key_cache, + mock_token_cache, + mock_mdc_cache, + mock_base_cache, + mock_dst_cache, + ] + + with patch("salt.cache.Cache") as mock_cache_factory: + mock_cache_factory.side_effect = mock_caches + + result = cache.migrate(target="redis") + + # Assert the result is True + assert result is True + + # Assert Cache initialized with correct drivers + assert mock_cache_factory.call_count == 5 + mock_cache_factory.assert_any_call( + master_opts, driver=master_opts["keys.cache_driver"] + ) + mock_cache_factory.assert_any_call( + master_opts, driver=master_opts["eauth_tokens.cache_driver"] + ) + mock_cache_factory.assert_any_call( + master_opts, driver=master_opts["pillar.cache_driver"] + ) + mock_cache_factory.assert_any_call(master_opts) + mock_cache_factory.assert_any_call(master_opts, driver="redis") + + # Assert all banks were listed + mock_key_cache.list.assert_any_call("keys") + mock_key_cache.list.assert_any_call("master_keys") + mock_key_cache.list.assert_any_call("denied_keys") + mock_token_cache.list.assert_called_once_with("tokens") + 
mock_mdc_cache.list.assert_called_once_with("pillar") + mock_base_cache.list.assert_any_call("grains") + mock_base_cache.list.assert_any_call("mine") + + # Assert data was fetched and stored + expected_fetch_calls = [ + call("keys", "key1"), + call("keys", "key2"), + call("master_keys", "key3"), + call("denied_keys", "key4"), + ] + mock_key_cache.fetch.assert_has_calls(expected_fetch_calls, any_order=True) + mock_token_cache.fetch.assert_called_once_with("tokens", "token1") + mock_mdc_cache.fetch.assert_called_once_with("pillar", "pillar1") + + # Assert data was stored in destination cache + expected_store_calls = [ + call("keys", "key1", "value1"), + call("keys", "key2", "value2"), + call("master_keys", "key3", "value3"), + call("denied_keys", "key4", "value4"), + call("tokens", "token1", "token_value"), + call("pillar", "pillar1", "pillar_value"), + call("grains", "grain1", "grain_value"), + call("mine", "mine1", "mine_value"), + ] + mock_dst_cache.store.assert_has_calls(expected_store_calls, any_order=True) + + +def test_migrate_specific_banks(): + """ + Test migrate function when specifying specific banks + """ + mock_key_cache = MagicMock() + mock_token_cache = MagicMock() + mock_mdc_cache = MagicMock() + mock_base_cache = MagicMock() + mock_dst_cache = MagicMock() + + mock_key_cache.list.side_effect = [["key1", "key2"], ["key3"]] + + mock_key_cache.fetch.side_effect = ["value1", "value2", "value3"] + + mock_caches = [ + mock_key_cache, + mock_token_cache, + mock_mdc_cache, + mock_base_cache, + mock_dst_cache, + ] + + with patch("salt.cache.Cache") as mock_cache_factory: + mock_cache_factory.side_effect = mock_caches + + result = cache.migrate(target="redis", bank="keys,master_keys") + + # Assert the result is True + assert result is True + + # Assert Cache initialized with correct drivers + assert mock_cache_factory.call_count == 5 + + # Assert only specified banks were listed + mock_key_cache.list.assert_any_call("keys") + mock_key_cache.list.assert_any_call("master_keys") + + # Assert banks outside the requested set were NOT listed + assert call("denied_keys") not in mock_key_cache.list.call_args_list + assert not mock_token_cache.list.called + assert not mock_mdc_cache.list.called + assert not mock_base_cache.list.called + + # Assert data was fetched and stored only for specified banks + expected_fetch_calls = [ + call("keys", "key1"), + call("keys", "key2"), + call("master_keys", "key3"), + ] + mock_key_cache.fetch.assert_has_calls(expected_fetch_calls, any_order=True) + + # Assert data was stored in destination cache + expected_store_calls = [ + call("keys", "key1", "value1"), + call("keys", "key2", "value2"), + call("master_keys", "key3", "value3"), + ] + mock_dst_cache.store.assert_has_calls(expected_store_calls, any_order=True) + + +def test_migrate_empty_bank(caplog): + """ + Test migrate function with a bank that has no keys + """ + mock_key_cache = MagicMock() + mock_token_cache = MagicMock() + mock_mdc_cache = MagicMock() + mock_base_cache = MagicMock() + mock_dst_cache = MagicMock() + + # Empty list of keys + mock_key_cache.list.return_value = [] + + mock_caches = [ + mock_key_cache, + mock_token_cache, + mock_mdc_cache, + mock_base_cache, + mock_dst_cache, + ] + + with patch("salt.cache.Cache") as mock_cache_factory: + mock_cache_factory.side_effect = mock_caches + + # Set caplog to capture INFO level messages + caplog.set_level("INFO") + + result = cache.migrate(target="redis", bank="keys") + + # Assert the result is True + assert result is True + + # Assert bank was listed but found
empty + mock_key_cache.list.assert_called_once_with("keys") + + # Assert no data was fetched since bank was empty + assert not mock_key_cache.fetch.called + + # Assert no data was stored since bank was empty + assert not mock_dst_cache.store.called + + # Check that the empty migration was logged + assert "bank keys: migrating 0 keys" in caplog.text diff --git a/tests/pytests/unit/runners/test_pillar.py b/tests/pytests/unit/runners/test_pillar.py index 8dc7fa767daf..c6157704f4b1 100644 --- a/tests/pytests/unit/runners/test_pillar.py +++ b/tests/pytests/unit/runners/test_pillar.py @@ -10,7 +10,6 @@ import salt.runners.pillar as pillar_runner import salt.utils.files -import salt.utils.gitfs import salt.utils.msgpack from tests.support.mock import MagicMock, mock_open, patch @@ -18,18 +17,20 @@ @pytest.fixture -def configure_loader_modules(): - return { - pillar_runner: { - "__opts__": { - "pillar_cache": True, - "pillar_cache_backend": "disk", - "pillar_cache_ttl": 30, - "keys.cache_driver": "localfs_key", - "__role": "master", - } +def configure_loader_modules(master_opts): + master_opts.update( + { + "pillar_cache": True, + "pillar_cache_backend": "disk", + "pillar_cache_ttl": 30, + "state_top": "top.sls", + "pillar_roots": [], + "fileserver_backend": [], + "keys.cache_driver": "localfs_key", + "__role": "master", } - } + ) + return {pillar_runner: {"__opts__": master_opts}} @pytest.fixture(scope="module") @@ -43,7 +44,7 @@ def cachedir_tree(tmp_path_factory): @pytest.fixture(scope="module") def pillar_cache_dir(cachedir_tree): - pillar_cache_dir = cachedir_tree / "pillar_cache" + pillar_cache_dir = cachedir_tree / "pillar" pillar_cache_dir.mkdir() return pillar_cache_dir @@ -51,46 +52,34 @@ def pillar_cache_dir(cachedir_tree): @pytest.fixture(scope="function") def pillar_cache_files(pillar_cache_dir): MINION_ID = "test-host" - cache = { - "CacheDisk_data": { - MINION_ID: { - None: { - "this": "one", - "that": "two", - "those": ["three", "four", "five"], - } - } - }, - "CacheDisk_cachetime": {MINION_ID: 1612302460.146923}, + pillar = { + "this": "one", + "that": "two", + "those": ["three", "four", "five"], } packer = salt.utils.msgpack.Packer() - cache_contents = packer.pack(cache) + cache_contents = packer.pack(pillar) + cache_file = os.path.join(str(pillar_cache_dir), f"{MINION_ID}.p") - with salt.utils.files.fopen( - os.path.join(str(pillar_cache_dir), MINION_ID), "wb+" - ) as fp: + with salt.utils.files.fopen(cache_file, "wb+") as fp: fp.write(cache_contents) + mtime = 1612302460.146923 + os.utime(cache_file, (mtime, mtime)) MINION_ID = "another-host" cache = { - "CacheDisk_data": { - MINION_ID: { - None: { - "this": "six", - "that": "seven", - "those": ["eight", "nine", "ten"], - } - } - }, - "CacheDisk_cachetime": {MINION_ID: 1612302460.146923}, + "this": "six", + "that": "seven", + "those": ["eight", "nine", "ten"], } packer = salt.utils.msgpack.Packer() cache_contents = packer.pack(cache) - with salt.utils.files.fopen( - os.path.join(str(pillar_cache_dir), MINION_ID), "wb+" - ) as fp: + cache_file = os.path.join(str(pillar_cache_dir), f"{MINION_ID}.p") + with salt.utils.files.fopen(cache_file, "wb+") as fp: fp.write(cache_contents) + mtime = 1612302460.146923 + os.utime(cache_file, (mtime, mtime)) def test_clear_pillar_cache(cachedir_tree, pillar_cache_dir, pillar_cache_files): @@ -113,36 +102,37 @@ def test_clear_pillar_cache(cachedir_tree, pillar_cache_dir, pillar_cache_files) "salt.utils.minions.CkMinions.check_minions", MagicMock(side_effect=_CHECK_MINIONS_RETURN), ): - expected 
= { - "test-host": { - "those": ["three", "four", "five"], - "that": "two", - "this": "one", - }, - "another-host": { - "those": ["eight", "nine", "ten"], - "that": "seven", - "this": "six", - }, - } - ret = pillar_runner.show_pillar_cache() - assert ret == expected - - ret = pillar_runner.clear_pillar_cache("test-host") - assert ret == {} - - expected = { - "another-host": { - "those": ["eight", "nine", "ten"], - "that": "seven", - "this": "six", + with patch("salt.pillar.Pillar._Pillar__gather_avail", return_value={}): + expected = { + "test-host": { + "those": ["three", "four", "five"], + "that": "two", + "this": "one", + }, + "another-host": { + "those": ["eight", "nine", "ten"], + "that": "seven", + "this": "six", + }, } - } - ret = pillar_runner.show_pillar_cache() - assert ret == expected + ret = pillar_runner.show_pillar_cache() + assert ret == expected + + ret = pillar_runner.clear_pillar_cache("test-host") + assert ret is True + + expected = { + "another-host": { + "those": ["eight", "nine", "ten"], + "that": "seven", + "this": "six", + } + } + ret = pillar_runner.show_pillar_cache() + assert ret == expected - ret = pillar_runner.clear_pillar_cache() - assert ret == {} + ret = pillar_runner.clear_pillar_cache() + assert ret is True def test_show_pillar_cache(cachedir_tree, pillar_cache_dir, pillar_cache_files): @@ -161,35 +151,42 @@ def test_show_pillar_cache(cachedir_tree, pillar_cache_dir, pillar_cache_files): "salt.utils.minions.CkMinions.check_minions", MagicMock(side_effect=_CHECK_MINIONS_RETURN), ): - expected = { - "test-host": { - "those": ["three", "four", "five"], - "that": "two", - "this": "one", - }, - "another-host": { - "those": ["eight", "nine", "ten"], - "that": "seven", - "this": "six", - }, - } - ret = pillar_runner.show_pillar_cache() - assert ret == expected - - expected = { - "test-host": { - "this": "one", - "that": "two", - "those": ["three", "four", "five"], - } - } - ret = pillar_runner.show_pillar_cache("test-host") - assert ret == expected - - _EMPTY_CHECK_MINIONS_RETURN = {"minions": [], "missing": []} - with patch( - "salt.utils.minions.CkMinions.check_minions", - MagicMock(return_value=_EMPTY_CHECK_MINIONS_RETURN), - ), patch("salt.utils.atomicfile.atomic_open", mock_open()) as atomic_open_mock: - ret = pillar_runner.show_pillar_cache("fake-host") - assert ret == {} + with patch("salt.pillar.Pillar._Pillar__gather_avail", return_value={}): + with patch( + "salt.utils.minions.CkMinions.check_minions", + MagicMock(side_effect=_CHECK_MINIONS_RETURN), + ): + expected = { + "test-host": { + "those": ["three", "four", "five"], + "that": "two", + "this": "one", + }, + "another-host": { + "those": ["eight", "nine", "ten"], + "that": "seven", + "this": "six", + }, + } + ret = pillar_runner.show_pillar_cache() + assert ret == expected + + expected = { + "test-host": { + "this": "one", + "that": "two", + "those": ["three", "four", "five"], + } + } + ret = pillar_runner.show_pillar_cache("test-host") + assert ret == expected + + _EMPTY_CHECK_MINIONS_RETURN = {"minions": [], "missing": []} + with patch( + "salt.utils.minions.CkMinions.check_minions", + MagicMock(return_value=_EMPTY_CHECK_MINIONS_RETURN), + ), patch( + "salt.utils.atomicfile.atomic_open", mock_open() + ) as atomic_open_mock: + ret = pillar_runner.show_pillar_cache("fake-host") + assert ret == {} diff --git a/tests/pytests/unit/test_auth.py b/tests/pytests/unit/test_auth.py index e477ee9ae8e3..bf1999ceeadc 100644 --- a/tests/pytests/unit/test_auth.py +++ b/tests/pytests/unit/test_auth.py @@ -37,7 
+37,13 @@ def load_auth(): patcher = patch(mod, mock) patcher.start() patchers.append(patcher) - lauth = salt.auth.LoadAuth({}) # Load with empty opts + lauth = salt.auth.LoadAuth( + { + "eauth_tokens.cache_driver": None, + "eauth_tokens.cluster_id": None, + "cluster_id": None, + } + ) # Load with minimal opts try: yield lauth finally: @@ -213,14 +219,10 @@ def test_get_tok_with_broken_file_will_remove_bad_token(load_auth): fake_get_token = MagicMock( side_effect=salt.exceptions.SaltDeserializationError("hi") ) - patch_opts = patch.dict(load_auth.opts, {"eauth_tokens": "testfs"}) - patch_get_token = patch.dict( - load_auth.tokens, - {"testfs.get_token": fake_get_token}, - ) + patch_get_token = patch.object(load_auth.cache, "fetch", fake_get_token) mock_rm_token = MagicMock() patch_rm_token = patch.object(load_auth, "rm_token", mock_rm_token) - with patch_opts, patch_get_token, patch_rm_token: + with patch_get_token, patch_rm_token: expected_token = "fnord" load_auth.get_tok(expected_token) mock_rm_token.assert_called_with(expected_token) @@ -228,14 +230,10 @@ def test_get_tok_with_no_expiration_should_remove_bad_token(load_auth): fake_get_token = MagicMock(return_value={"no_expire_here": "Nope"}) - patch_opts = patch.dict(load_auth.opts, {"eauth_tokens": "testfs"}) - patch_get_token = patch.dict( - load_auth.tokens, - {"testfs.get_token": fake_get_token}, - ) + patch_get_token = patch.object(load_auth.cache, "fetch", fake_get_token) mock_rm_token = MagicMock() patch_rm_token = patch.object(load_auth, "rm_token", mock_rm_token) - with patch_opts, patch_get_token, patch_rm_token: + with patch_get_token, patch_rm_token: expected_token = "fnord" load_auth.get_tok(expected_token) mock_rm_token.assert_called_with(expected_token) @@ -243,14 +241,10 @@ def test_get_tok_with_expire_before_current_time_should_remove_token(load_auth): fake_get_token = MagicMock(return_value={"expire": time.time() - 1}) - patch_opts = patch.dict(load_auth.opts, {"eauth_tokens": "testfs"}) - patch_get_token = patch.dict( - load_auth.tokens, - {"testfs.get_token": fake_get_token}, - ) + patch_get_token = patch.object(load_auth.cache, "fetch", fake_get_token) mock_rm_token = MagicMock() patch_rm_token = patch.object(load_auth, "rm_token", mock_rm_token) - with patch_opts, patch_get_token, patch_rm_token: + with patch_get_token, patch_rm_token: expected_token = "fnord" load_auth.get_tok(expected_token) mock_rm_token.assert_called_with(expected_token) @@ -259,14 +253,10 @@ def test_get_tok_with_valid_expiration_should_return_token(load_auth): expected_token = {"expire": time.time() + 1} fake_get_token = MagicMock(return_value=expected_token) - patch_opts = patch.dict(load_auth.opts, {"eauth_tokens": "testfs"}) - patch_get_token = patch.dict( - load_auth.tokens, - {"testfs.get_token": fake_get_token}, - ) + patch_get_token = patch.object(load_auth.cache, "fetch", fake_get_token) mock_rm_token = MagicMock() patch_rm_token = patch.object(load_auth, "rm_token", mock_rm_token) - with patch_opts, patch_get_token, patch_rm_token: + with patch_get_token, patch_rm_token: token_name = "fnord" actual_token = load_auth.get_tok(token_name) mock_rm_token.assert_not_called() @@ -893,3 +883,38 @@ async def test_acl_simple_deny(auth_acl_clear_funcs, auth_acl_valid_load): assert
auth_acl_clear_funcs.ckminions.auth_check.call_args[0][0] == [ {"beta_minion": ["test.ping"]} ] + + +def test_cve_2021_3244(tmp_path): + token_dir = tmp_path + opts = { + "extension_modules": "", + "optimization_order": [0, 1, 2], + "token_expire": 1, + "keep_acl_in_token": False, + "eauth_tokens": "localfs", + "cachedir": str(token_dir), + "token_expire_user_override": True, + "external_auth": {"auto": {"foo": []}}, + "eauth_tokens.cache_driver": None, + "eauth_tokens.cluster_id": None, + "cluster_id": None, + } + auth = salt.auth.LoadAuth(opts) + load = { + "eauth": "auto", + "username": "foo", + "password": "foo", + "token_expire": -1, + } + fake_get_token = MagicMock( + side_effect=salt.exceptions.SaltDeserializationError("hi") + ) + with patch.object(auth.cache, "fetch", fake_get_token): + t_data = auth.mk_token(load) + assert t_data["expire"] < time.time() + token_file = token_dir / "tokens" / f'{t_data["token"]}.p' + assert token_file.exists() + t_data = auth.get_tok(t_data["token"]) + assert not t_data + assert not token_file.exists() diff --git a/tests/pytests/unit/test_pillar.py b/tests/pytests/unit/test_pillar.py index bcd172e25947..04ecf65c1819 100644 --- a/tests/pytests/unit/test_pillar.py +++ b/tests/pytests/unit/test_pillar.py @@ -19,7 +19,7 @@ import salt.fileclient import salt.utils.stringutils from salt.utils.files import fopen -from tests.support.mock import MagicMock, patch +from tests.support.mock import ANY, MagicMock, call, patch from tests.support.runtests import RUNTIME_VARS log = logging.getLogger(__name__) @@ -1180,8 +1180,14 @@ def test_include(tmp_path): assert compiled_pillar["sub_init_dot"] == "sub_with_init_dot_worked" -def test_compile_pillar_memory_cache(master_opts): - master_opts.update({"pillar_cache_backend": "memory", "pillar_cache_ttl": 3600}) +def test_compile_pillar_cache(master_opts): + master_opts.update( + { + "memcache_expire_seconds": 3600, + "pillar_cache_ttl": 3600, + "pillar_cache": True, + } + ) pillar = salt.pillar.PillarCache( master_opts, @@ -1192,39 +1198,36 @@ def test_compile_pillar_memory_cache(master_opts): ) with patch( - "salt.pillar.PillarCache.fetch_pillar", + "salt.pillar.Pillar.compile_pillar", side_effect=[{"foo": "bar"}, {"foo": "baz"}], ): # Run once for pillarenv base - ret = pillar.compile_pillar() - expected_cache = {"base": {"foo": "bar"}} - assert "mocked_minion" in pillar.cache - assert pillar.cache["mocked_minion"] == expected_cache + pillar.compile_pillar() + expected_cache = {("pillar", "mocked_minion:base"): [ANY, None, {"foo": "bar"}]} + assert pillar.cache.storage == expected_cache # Run a second time for pillarenv base - ret = pillar.compile_pillar() - expected_cache = {"base": {"foo": "bar"}} - assert "mocked_minion" in pillar.cache - assert pillar.cache["mocked_minion"] == expected_cache + pillar.compile_pillar() + assert pillar.cache.storage == expected_cache # Change the pillarenv - pillar.pillarenv = "dev" + pillar.opts["pillarenv"] = "dev" # Run once for pillarenv dev - ret = pillar.compile_pillar() - expected_cache = {"base": {"foo": "bar"}, "dev": {"foo": "baz"}} - assert "mocked_minion" in pillar.cache - assert pillar.cache["mocked_minion"] == expected_cache + pillar.compile_pillar() + expected_cache = { + ("pillar", "mocked_minion:base"): [ANY, None, {"foo": "bar"}], + ("pillar", "mocked_minion:dev"): [ANY, None, {"foo": "baz"}], + } + assert pillar.cache.storage == expected_cache # Run a second time for pillarenv dev - ret = pillar.compile_pillar() - expected_cache = {"base": {"foo": "bar"}, "dev": 
{"foo": "baz"}} - assert "mocked_minion" in pillar.cache - assert pillar.cache["mocked_minion"] == expected_cache + pillar.compile_pillar() + assert pillar.cache.storage == expected_cache def test_compile_pillar_disk_cache(master_opts, grains): - master_opts.update({"pillar_cache_backend": "disk", "pillar_cache_ttl": 3600}) + master_opts.update({"pillar_cache_ttl": 3600, "pillar_cache": True}) pillar = salt.pillar.PillarCache( master_opts, @@ -1233,38 +1236,46 @@ def test_compile_pillar_disk_cache(master_opts, grains): "fake_env", pillarenv="base", ) + with patch( + "salt.pillar.Pillar.compile_pillar", + side_effect=[{"foo": "bar"}, {"foo": "baz"}], + ), patch.object( + pillar.cache, "fetch", side_effect=[None, {"foo": "bar"}, None, {"foo": "baz"}] + ) as fetch_mock, patch.object( + pillar.cache, "store" + ) as store_mock: + # Run once for pillarenv base + pillar.compile_pillar() - with patch("salt.utils.cache.CacheDisk._write", MagicMock()): - with patch( - "salt.pillar.PillarCache.fetch_pillar", - side_effect=[{"foo": "bar"}, {"foo": "baz"}], - ): - # Run once for pillarenv base - ret = pillar.compile_pillar() - expected_cache = {"mocked_minion": {"base": {"foo": "bar"}}} - assert pillar.cache._dict == expected_cache - - # Run a second time for pillarenv base - ret = pillar.compile_pillar() - expected_cache = {"mocked_minion": {"base": {"foo": "bar"}}} - assert pillar.cache._dict == expected_cache - - # Change the pillarenv - pillar.pillarenv = "dev" - - # Run once for pillarenv dev - ret = pillar.compile_pillar() - expected_cache = { - "mocked_minion": {"base": {"foo": "bar"}, "dev": {"foo": "baz"}} - } - assert pillar.cache._dict == expected_cache + # Run a second time for pillarenv base + pillar.compile_pillar() - # Run a second time for pillarenv dev - ret = pillar.compile_pillar() - expected_cache = { - "mocked_minion": {"base": {"foo": "bar"}, "dev": {"foo": "baz"}} - } - assert pillar.cache._dict == expected_cache + # Change the pillarenv + pillar.opts["pillarenv"] = "dev" + + # Run once for pillarenv dev + pillar.compile_pillar() + + # Run a second time for pillarenv dev + pillar.compile_pillar() + + expected_fetches = [ + call("pillar", "mocked_minion:base"), + call("pillar", "mocked_minion:base"), + call("pillar", "mocked_minion:dev"), + call("pillar", "mocked_minion:dev"), + ] + + # Assert all calls match the pattern + fetch_mock.assert_has_calls(expected_fetches, any_order=False) + + expected_stores = [ + call("pillar", "mocked_minion:base", {"foo": "bar"}), + call("pillar", "mocked_minion:dev", {"foo": "baz"}), + ] + + # Assert all calls match the pattern + store_mock.assert_has_calls(expected_stores, any_order=False) def test_remote_pillar_bad_return(grains, tmp_pki): diff --git a/tests/pytests/unit/test_sqlalchemy.py b/tests/pytests/unit/test_sqlalchemy.py new file mode 100644 index 000000000000..e1ad2c0e1ebd --- /dev/null +++ b/tests/pytests/unit/test_sqlalchemy.py @@ -0,0 +1,602 @@ +""" + Test cases for salt.sqlalchemy +""" + +import json + +import pytest + +import salt.exceptions +import salt.sqlalchemy +from tests.support.mock import ANY, MagicMock, call, patch + + +@pytest.fixture +def mock_engine(): + """Mock SQLAlchemy engine object""" + mock_engine = MagicMock() + mock_engine.dialect.name = "postgresql" + + # Mock connection and cursor for event listeners + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_connection.cursor.return_value = mock_cursor + mock_engine.connect.return_value.__enter__.return_value = mock_connection + + return mock_engine + 
+ +@pytest.fixture +def dsn_opts(): + """Return opts dict with DSN configuration""" + return { + "sqlalchemy.dsn": "postgresql://user:pass@localhost:5432/test", + "sqlalchemy.slow_query_threshold": 0.5, + "sqlalchemy.slow_connect_threshold": 0.5, + } + + +@pytest.fixture +def connection_opts(): + """Return opts dict with connection parameters""" + return { + "sqlalchemy.drivername": "postgresql", + "sqlalchemy.host": "localhost", + "sqlalchemy.username": "salt", + "sqlalchemy.password": "salt", + "sqlalchemy.database": "salt", + "sqlalchemy.port": 5432, + "sqlalchemy.slow_query_threshold": 0.5, + "sqlalchemy.slow_connect_threshold": 0.5, + "sqlalchemy.disable_connection_pool": True, + } + + +@pytest.fixture +def ro_connection_opts(): + """Return opts dict with connection parameters including read-only configuration""" + return { + "sqlalchemy.drivername": "postgresql", + "sqlalchemy.host": "localhost", + "sqlalchemy.username": "salt", + "sqlalchemy.password": "salt", + "sqlalchemy.database": "salt", + "sqlalchemy.port": 5432, + "sqlalchemy.slow_query_threshold": 0.5, + "sqlalchemy.slow_connect_threshold": 0.5, + "sqlalchemy.ro_host": "readonly.localhost", + "sqlalchemy.ro_username": "salt_ro", + "sqlalchemy.ro_password": "salt_ro", + "sqlalchemy.ro_database": "salt", + "sqlalchemy.ro_port": 5432, + } + + +@pytest.fixture +def ro_dsn_opts(): + """Return opts dict with DSN configuration including read-only DSN""" + return { + "sqlalchemy.dsn": "postgresql://user:pass@localhost:5432/test", + "sqlalchemy.ro_dsn": "postgresql://reader:pass@readonly.localhost:5432/test", + "sqlalchemy.slow_query_threshold": 0.5, + "sqlalchemy.slow_connect_threshold": 0.5, + } + + +def strip_prefix(opts): + """ + direct calls to _make_engine do not expect prefixed opts + """ + return {k.replace("sqlalchemy.", ""): v for k, v in opts.items()} + + +def test_orm_configured(): + """ + Test the orm_configured function + """ + salt.sqlalchemy.ENGINE_REGISTRY = {} + assert salt.sqlalchemy.orm_configured() is False + salt.sqlalchemy.ENGINE_REGISTRY = {"default": {}} + assert salt.sqlalchemy.orm_configured() is True + + +def test_make_engine_with_dsn(dsn_opts): + """ + Test _make_engine function with DSN configuration + """ + stripped_dsn_opts = strip_prefix(dsn_opts) + with patch("sqlalchemy.create_engine") as mock_create_engine, patch( + "sqlalchemy.event.listens_for" + ): + mock_create_engine.return_value = MagicMock() + engine = salt.sqlalchemy._make_engine(stripped_dsn_opts) + assert engine is mock_create_engine.return_value + mock_create_engine.assert_called_once_with( + stripped_dsn_opts["dsn"], + json_serializer=ANY, + json_deserializer=ANY, + ) + + +def test_make_engine_with_connection_params(connection_opts): + """ + Test _make_engine function with connection parameters + """ + with patch("sqlalchemy.create_engine") as mock_create_engine, patch( + "sqlalchemy.engine.url.URL" + ) as mock_url, patch("sqlalchemy.event.listens_for"): + mock_url.return_value = "postgresql://salt:salt@localhost:5432/salt" + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + engine = salt.sqlalchemy._make_engine(strip_prefix(connection_opts)) + assert engine is mock_engine + mock_url.assert_called_once() + mock_create_engine.assert_called_once() + + +def test_make_engine_with_both_configurations(dsn_opts, connection_opts): + """ + Test _make_engine with both DSN and connection parameters should raise an exception + """ + invalid_opts = {**dsn_opts} + invalid_opts.update( + { + k: v + for k, v in 
connection_opts.items() + if k != "sqlalchemy.slow_query_threshold" + } + ) + + with pytest.raises(salt.exceptions.SaltConfigurationError): + salt.sqlalchemy._make_engine(strip_prefix(invalid_opts)) + + +def test_make_engine_with_missing_config(): + """ + Test _make_engine with missing required configuration should raise an exception + """ + opts = {"host": "localhost"} # Missing other required params + with pytest.raises(salt.exceptions.SaltConfigurationError): + salt.sqlalchemy._make_engine(opts) + + +def test_make_engine_with_schema(dsn_opts): + """ + Test _make_engine with schema translation + """ + with patch("sqlalchemy.create_engine") as mock_create_engine, patch( + "sqlalchemy.event.listens_for" + ): + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + salt.sqlalchemy._make_engine(strip_prefix(dsn_opts)) + + # Verify schema translation was not set + mock_engine.execution_options.assert_not_called() + + +def test_configure_orm(dsn_opts, mock_engine): + """ + Test configure_orm function + """ + with patch("salt.sqlalchemy._make_engine", return_value=mock_engine), patch( + "sqlalchemy.orm.sessionmaker" + ), patch("sqlalchemy.orm.scoped_session"), patch( + "salt.sqlalchemy.models.populate_model_registry" + ) as mock_populate: + + # Clear any existing registry + salt.sqlalchemy.ENGINE_REGISTRY = {} + + salt.sqlalchemy.configure_orm(dsn_opts) + + # Check that the engine was registered + assert "default" in salt.sqlalchemy.ENGINE_REGISTRY + assert "engine" in salt.sqlalchemy.ENGINE_REGISTRY["default"] + assert "session" in salt.sqlalchemy.ENGINE_REGISTRY["default"] + assert "ro_engine" in salt.sqlalchemy.ENGINE_REGISTRY["default"] + assert "ro_session" in salt.sqlalchemy.ENGINE_REGISTRY["default"] + + # Since dsn_opts doesn't have ro_dsn or ro_host, ro_engine should equal engine + assert ( + salt.sqlalchemy.ENGINE_REGISTRY["default"]["engine"] + is salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_engine"] + ) + assert ( + salt.sqlalchemy.ENGINE_REGISTRY["default"]["session"] + is salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_session"] + ) + + # Check that model registry was populated + mock_populate.assert_called_once() + + +def test_configure_orm_with_ro_settings(ro_connection_opts, mock_engine): + """ + Test configure_orm function with read-only configuration + """ + # Create a separate mock for the read-only engine + mock_ro_engine = MagicMock() + mock_ro_engine.dialect.name = "postgresql" + + with patch("salt.sqlalchemy._make_engine") as mock_make_engine, patch( + "sqlalchemy.orm.sessionmaker", return_value=MagicMock() + ), patch("sqlalchemy.orm.scoped_session", return_value=MagicMock()), patch( + "salt.sqlalchemy.models.populate_model_registry" + ): + + # Make _make_engine return different engines based on prefix + def side_effect_make_engine(opts, prefix=None): + return mock_ro_engine if prefix == "ro_" else mock_engine + + mock_make_engine.side_effect = side_effect_make_engine + + # Clear any existing registry + salt.sqlalchemy.ENGINE_REGISTRY = {} + + salt.sqlalchemy.configure_orm(ro_connection_opts) + + # Verify _make_engine was called twice - once for main engine, once for ro_engine + assert mock_make_engine.call_count == 2 + + # Check that the engines and sessions are separate + assert "default" in salt.sqlalchemy.ENGINE_REGISTRY + assert salt.sqlalchemy.ENGINE_REGISTRY["default"]["engine"] == mock_engine + assert salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_engine"] == mock_ro_engine + assert ( + salt.sqlalchemy.ENGINE_REGISTRY["default"]["engine"] + 
+            is not salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_engine"]
+        )
+        assert (
+            salt.sqlalchemy.ENGINE_REGISTRY["default"]["session"]
+            is not salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_session"]
+        )
+
+
+def test_configure_orm_with_ro_dsn(ro_dsn_opts, mock_engine):
+    """
+    Test configure_orm function with read-only DSN configuration
+    """
+    # Create a separate mock for the read-only engine
+    mock_ro_engine = MagicMock()
+    mock_ro_engine.dialect.name = "postgresql"
+
+    with patch("salt.sqlalchemy._make_engine") as mock_make_engine, patch(
+        "sqlalchemy.orm.sessionmaker", return_value=MagicMock()
+    ), patch("sqlalchemy.orm.scoped_session", return_value=MagicMock()), patch(
+        "salt.sqlalchemy.models.populate_model_registry"
+    ):
+
+        # Make _make_engine return different engines based on prefix
+        def side_effect_make_engine(opts, prefix=None):
+            return mock_ro_engine if prefix == "ro_" else mock_engine
+
+        mock_make_engine.side_effect = side_effect_make_engine
+
+        # Clear any existing registry
+        salt.sqlalchemy.ENGINE_REGISTRY = {}
+
+        salt.sqlalchemy.configure_orm(ro_dsn_opts)
+
+        # Verify _make_engine was called twice - once for main engine, once for ro_engine
+        assert mock_make_engine.call_count == 2
+
+        # Check that the engines and sessions are separate
+        assert "default" in salt.sqlalchemy.ENGINE_REGISTRY
+        assert salt.sqlalchemy.ENGINE_REGISTRY["default"]["engine"] == mock_engine
+        assert salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_engine"] == mock_ro_engine
+        assert (
+            salt.sqlalchemy.ENGINE_REGISTRY["default"]["engine"]
+            is not salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_engine"]
+        )
+        assert (
+            salt.sqlalchemy.ENGINE_REGISTRY["default"]["session"]
+            is not salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_session"]
+        )
+
+        # Verify that _make_engine was called with correct prefixes
+        config_with_defaults = {
+            **salt.sqlalchemy.SQLA_DEFAULT_OPTS,
+            **strip_prefix(ro_dsn_opts),
+        }
+
+        calls = [
+            call(config_with_defaults, prefix=None),
+            call(config_with_defaults, prefix="ro_"),
+        ]
+        mock_make_engine.assert_has_calls(calls, any_order=True)
+
+
+def test_configure_orm_without_ro_settings(connection_opts, mock_engine):
+    """
+    Test configure_orm function without read-only configuration
+    """
+    with patch("salt.sqlalchemy._make_engine", return_value=mock_engine), patch(
+        "sqlalchemy.orm.sessionmaker", return_value=MagicMock()
+    ), patch("sqlalchemy.orm.scoped_session", return_value=MagicMock()), patch(
+        "salt.sqlalchemy.models.populate_model_registry"
+    ):
+
+        # Clear any existing registry
+        salt.sqlalchemy.ENGINE_REGISTRY = {}
+
+        salt.sqlalchemy.configure_orm(connection_opts)
+
+        # Verify _make_engine was called only once
+        salt.sqlalchemy._make_engine.assert_called_once()
+
+        # Check that ro_engine and engine are the same instance
+        assert "default" in salt.sqlalchemy.ENGINE_REGISTRY
+        assert (
+            salt.sqlalchemy.ENGINE_REGISTRY["default"]["engine"]
+            == salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_engine"]
+        )
+        assert (
+            salt.sqlalchemy.ENGINE_REGISTRY["default"]["session"]
+            == salt.sqlalchemy.ENGINE_REGISTRY["default"]["ro_session"]
+        )
+
+
+def test_configure_orm_already_configured(dsn_opts, mock_engine):
+    """
+    Test configure_orm when already configured
+    """
+    # Simulate already configured ORM
+    salt.sqlalchemy.ENGINE_REGISTRY = {"default": {"engine": mock_engine}}
+
+    with patch("salt.sqlalchemy._make_engine") as mock_make_engine:
+        salt.sqlalchemy.configure_orm(dsn_opts)
+        mock_make_engine.assert_not_called()
+
+
+def test_dispose_orm(mock_engine):
+    """
+    Test dispose_orm function
+    """
+    mock_session = MagicMock()
+    salt.sqlalchemy.ENGINE_REGISTRY = {
+        "default": {
+            "engine": mock_engine,
+            "session": mock_session,
+            "ro_engine": mock_engine,
+            "ro_session": mock_session,
+        }
+    }
+
+    salt.sqlalchemy.dispose_orm()
+
+    # Check that engine and session were disposed
+    mock_engine.dispose.assert_called()
+    mock_session.remove.assert_called()
+
+    # Check that registry was cleared
+    assert len(salt.sqlalchemy.ENGINE_REGISTRY) == 0
+
+
+def test_reconfigure_orm(dsn_opts):
+    """
+    Test reconfigure_orm function
+    """
+    with patch("salt.sqlalchemy.dispose_orm") as mock_dispose, patch(
+        "salt.sqlalchemy.configure_orm"
+    ) as mock_configure:
+        salt.sqlalchemy.reconfigure_orm(dsn_opts)
+        mock_dispose.assert_called_once()
+        mock_configure.assert_called_once_with(dsn_opts)
+
+
+def test_session(mock_engine):
+    """
+    Test Session function
+    """
+    mock_session = MagicMock()
+    mock_session_instance = MagicMock()
+    mock_session.return_value = mock_session_instance
+
+    salt.sqlalchemy.ENGINE_REGISTRY = {
+        "default": {
+            "engine": mock_engine,
+            "session": mock_session,
+        }
+    }
+
+    session = salt.sqlalchemy.Session()
+    assert session is mock_session_instance
+    mock_session.assert_called_once()
+
+    # Test with specific engine name
+    salt.sqlalchemy.ENGINE_REGISTRY["test"] = {"session": mock_session}
+    session = salt.sqlalchemy.Session("test")
+    assert session is mock_session_instance
+
+
+def test_session_not_configured():
+    """
+    Test Session function with unconfigured ORM
+    """
+    salt.sqlalchemy.ENGINE_REGISTRY = {}
+
+    with pytest.raises(salt.exceptions.SaltInvocationError):
+        salt.sqlalchemy.Session()
+
+
+def test_ro_session(mock_engine):
+    """
+    Test ROSession function
+    """
+    mock_session = MagicMock()
+    mock_ro_session = MagicMock()
+    mock_session_instance = MagicMock()
+    mock_ro_session_instance = MagicMock()
+    mock_session.return_value = mock_session_instance
+    mock_ro_session.return_value = mock_ro_session_instance
+
+    # Test with ro_session available
+    salt.sqlalchemy.ENGINE_REGISTRY = {
+        "default": {
+            "engine": mock_engine,
+            "session": mock_session,
+            "ro_session": mock_ro_session,
+        }
+    }
+
+    session = salt.sqlalchemy.ROSession()
+    assert session is mock_ro_session_instance
+    mock_ro_session.assert_called_once()
+
+    # Test with only session available (no ro_session)
+    salt.sqlalchemy.ENGINE_REGISTRY = {
+        "test": {
+            "engine": mock_engine,
+            "session": mock_session,
+        }
+    }
+
+    session = salt.sqlalchemy.ROSession("test")
+    assert session is mock_session_instance
+
+
+def test_ro_session_not_configured():
+    """
+    Test ROSession function with unconfigured ORM
+    """
+    salt.sqlalchemy.ENGINE_REGISTRY = {}
+
+    with pytest.raises(salt.exceptions.SaltInvocationError):
+        salt.sqlalchemy.ROSession()
+
+
+def test_serialize():
+    """
+    Test _serialize function
+    """
+    # Test with normal dict
+    data = {"key": "value"}
+    serialized = salt.sqlalchemy._serialize(data)
+    assert serialized == json.dumps(data)
+
+    # Test with bytes
+    data = b"binary data"
+    serialized = salt.sqlalchemy._serialize(data)
+    expected = json.dumps({"_base64": "YmluYXJ5IGRhdGE="})
+    assert serialized == expected
+
+    # Test with NUL bytes
+    data = {"key": "value\u0000with\u0000nulls"}
+    serialized = salt.sqlalchemy._serialize(data)
+    expected = json.dumps(data).replace("\\u0000", "")
+    assert serialized == expected
+
+
+def test_deserialize():
+    """
+    Test _deserialize function
+    """
+    # Test with normal JSON
+    data = json.dumps({"key": "value"})
+    deserialized = salt.sqlalchemy._deserialize(data)
+    assert deserialized == {"key": "value"}
+
+    # Test with base64 encoded data
+    data = json.dumps({"_base64": "YmluYXJ5IGRhdGE="})
+    deserialized = salt.sqlalchemy._deserialize(data)
+    assert deserialized == b"binary data"
+
+
+def test_create_all(mock_engine):
+    """
+    Test create_all function
+    """
+    mock_session = MagicMock()
+    mock_session_instance = MagicMock()
+    mock_session.return_value = mock_session_instance
+
+    salt.sqlalchemy.ENGINE_REGISTRY = {
+        "default": {
+            "engine": mock_engine,
+            "session": mock_session,
+        }
+    }
+
+    with patch("salt.sqlalchemy.Session") as mock_session_constructor, patch(
+        "salt.sqlalchemy.models.model_for"
+    ) as mock_model_for:
+        mock_session_constructor.return_value.__enter__.return_value = (
+            mock_session_instance
+        )
+        mock_base = MagicMock()
+        mock_model_for.return_value = mock_base
+
+        salt.sqlalchemy.create_all()
+
+        mock_model_for.assert_called_once_with("Base", engine_name=None)
+        mock_base.metadata.create_all.assert_called_once_with(
+            mock_session_instance.get_bind()
+        )
+
+
+def test_drop_all(mock_engine):
+    """
+    Test drop_all function
+    """
+    mock_session = MagicMock()
+    mock_session_instance = MagicMock()
+    mock_session.return_value = mock_session_instance
+
+    salt.sqlalchemy.ENGINE_REGISTRY = {
+        "default": {
+            "engine": mock_engine,
+            "session": mock_session,
+        }
+    }
+
+    with patch("salt.sqlalchemy.Session") as mock_session_constructor, patch(
+        "salt.sqlalchemy.models.model_for"
+    ) as mock_model_for:
+        mock_session_constructor.return_value.__enter__.return_value = (
+            mock_session_instance
+        )
+        mock_base = MagicMock()
+        mock_model_for.return_value = mock_base
+
+        salt.sqlalchemy.drop_all()
+
+        mock_model_for.assert_called_once_with("Base", engine_name=None)
+        mock_base.metadata.drop_all.assert_called_once_with(
+            mock_session_instance.get_bind()
+        )
+
+
+def test_event_listeners_registered(connection_opts, mock_engine):
+    """
+    Test that event listeners are registered for engine
+    """
+    with patch("sqlalchemy.create_engine", return_value=mock_engine), patch(
+        "sqlalchemy.engine.url.URL"
+    ) as mock_url, patch("sqlalchemy.event.listens_for") as mock_listens_for:
+
+        mock_url.return_value = "postgresql://salt:salt@localhost:5432/salt"
+
+        salt.sqlalchemy._make_engine(strip_prefix(connection_opts))
+
+        # Check that event listeners were registered
+        assert mock_listens_for.call_count >= 5
+
+        # Verify specific event listeners
+        expected_events = [
+            "do_connect",
+            "connect",
+            "checkout",
+            "before_cursor_execute",
+            "after_cursor_execute",
+        ]
+
+        for event_name in expected_events:
+            event_registered = False
+            for call_args in mock_listens_for.call_args_list:
+                if call_args[0][1] == event_name:
+                    event_registered = True
+                    break
+            assert (
+                event_registered
+            ), f"Event listener for {event_name} was not registered"
diff --git a/tests/pytests/unit/utils/test_minions.py b/tests/pytests/unit/utils/test_minions.py
index e6fb63250201..f99718da2dc7 100644
--- a/tests/pytests/unit/utils/test_minions.py
+++ b/tests/pytests/unit/utils/test_minions.py
@@ -22,7 +22,7 @@ def test_connected_ids():
     )
     minion = "minion"
     ips = {"203.0.113.1", "203.0.113.2", "127.0.0.1"}
-    mdata = {"grains": {"ipv4": ips, "ipv6": []}}
+    mdata = {"ipv4": ips, "ipv6": []}
     patch_net = patch("salt.utils.network.local_port_tcp", return_value=ips)
     patch_list = patch("salt.cache.Cache.list", return_value=[minion])
     patch_fetch = patch("salt.cache.Cache.fetch", return_value=mdata)
@@ -51,8 +51,8 @@
     minion2 = "minion2"
     minion2_ip = "192.168.2.10"
     minion_ips = {"203.0.113.1", "203.0.113.2", "127.0.0.1"}
-    mdata = {"grains": {"ipv4": minion_ips, "ipv6": []}}
-    mdata2 = {"grains": {"ipv4": [minion2_ip], "ipv6": []}}
+    mdata = {"ipv4": minion_ips, "ipv6": []}
+    mdata2 = {"ipv4": [minion2_ip], "ipv6": []}
     patch_net = patch("salt.utils.network.local_port_tcp", return_value=minion_ips)
     patch_remote_net = patch(
         "salt.utils.network.remote_port_tcp", return_value={minion2_ip}
diff --git a/tests/support/pytest/database.py b/tests/support/pytest/database.py
new file mode 100644
index 000000000000..31c2a0a5bb7b
--- /dev/null
+++ b/tests/support/pytest/database.py
@@ -0,0 +1,309 @@
+import logging
+import time
+from contextlib import contextmanager
+
+import attr
+import pytest
+from pytestskipmarkers.utils import platform
+from saltfactories.utils import random_string
+
+try:
+    from docker.errors import APIError
+except ImportError:
+    APIError = OSError
+
+log = logging.getLogger(__name__)
+
+
+def _has_driver(driver_name):
+    try:
+        __import__(driver_name)
+        return True
+    except ImportError:
+        return False
+
+
+def db_param(db_name, version, driver_name=None):
+    """
+    Return a pytest.param for a database, optionally skipping if the driver is missing.
+
+    :param db_name: Logical name of the database (e.g., 'postgres', 'mysql')
+    :param version: Version tag of the database image (e.g., '17'), or None for the default
+    :param driver_name: Python module to import for the driver (e.g., 'psycopg2')
+    :return: pytest.param with conditional skip
+    """
+    if driver_name is None:
+        return pytest.param((db_name, version), id=f"{db_name}-{version or 'default'}")
+    else:
+        has_driver = _has_driver(driver_name)
+        return pytest.param(
+            (db_name, version),
+            marks=pytest.mark.skipif(
+                not has_driver, reason=f"{driver_name} not installed"
+            ),
+            id=f"{db_name}-{version or 'default'}",
+        )
+
+
+def available_databases(subset=None):
+    """
+    Return a list of pytest.param objects for known databases,
+    skipping those whose Python driver is not installed. If a
+    subset of configurations is passed, only those are considered.
+    """
+    driver_map = {
+        "sqlite": None,
+        "postgresql": "psycopg",
+        "mysql-server": "pymysql",
+        "percona": "pymysql",
+        "mariadb": "pymysql",
+    }
+
+    all_configurations = [
+        ("sqlite", None),
+        ("postgresql", "13"),
+        ("postgresql", "17"),
+        ("mysql-server", "5.5"),
+        ("mysql-server", "5.6"),
+        ("mysql-server", "5.7"),
+        ("mysql-server", "8.0"),
+        ("mariadb", "10.3"),
+        ("mariadb", "10.4"),
+        ("mariadb", "10.5"),
+        ("percona", "5.6"),
+        ("percona", "5.7"),
+        ("percona", "8.0"),
+    ]
+
+    if not subset:
+        subset = all_configurations
+
+    marks = []
+    for tup in subset:
+        if len(tup) == 3:
+            db_name, version, driver = tup
+            marks.append(db_param(db_name, version, driver_name=driver))
+        else:
+            db_name, version = tup
+            marks.append(db_param(db_name, version, driver_name=driver_map[db_name]))
+
+    return marks
+
+
+@attr.s(kw_only=True, slots=True)
+class DockerImage:
+    name = attr.ib()
+    tag = attr.ib()
+    container_id = attr.ib()
+
+    def __str__(self):
+        return f"{self.name}:{self.tag}"
+
+
+@attr.s(kw_only=True, slots=True)
+class DatabaseCombo:
+    name = attr.ib()
+    dialect = attr.ib()
+    version = attr.ib()
+    port = attr.ib(default=None)
+    host = attr.ib(default="%")
+    user = attr.ib()
+    passwd = attr.ib()
+    database = attr.ib(default=None)
+    root_user = attr.ib(default="root")
+    root_passwd = attr.ib()
+    container = attr.ib(default=None)
+    container_id = attr.ib()
+
+    @container_id.default
+    def _default_container_id(self):
+        return random_string(
+            "{}-{}-".format(
+                self.name.replace("/", "-"),
+                self.version,
+            )
+        )
+
+    @root_passwd.default
+    def _default_root_user_passwd(self):
+        return self.passwd
+
+    def get_credentials(self, **kwargs):
+        return {
+            "connection_user": kwargs.get("connection_user") or self.root_user,
+            "connection_pass": kwargs.get("connection_pass") or self.root_passwd,
+            "connection_db": kwargs.get("connection_db") or self.database,
+            "connection_port": kwargs.get("connection_port") or self.port,
+        }
+
+
+def set_container_name_before_start(container):
+    """
+    This is useful if the container has to be restarted and the old
+    container, under the same name, was left running but in a bad shape.
+    """
+    container.name = random_string("{}-".format(container.name.rsplit("-", 1)[0]))
+    container.display_name = None
+    return container
+
+
+def check_container_started(timeout_at, container, container_test):
+    sleeptime = 0.5
+    while time.time() <= timeout_at:
+        try:
+            if not container.is_running():
+                log.warning("%s is no longer running", container)
+                return False
+            ret = container_test()
+            if ret.returncode == 0:
+                break
+        except APIError:
+            log.exception("Failed to run start check")
+        time.sleep(sleeptime)
+        sleeptime *= 2
+    else:
+        return False
+    time.sleep(0.5)
+    return True
+
+
+@pytest.fixture(scope="module")
+def database_backend(request, salt_factories):
+    backend_type, version = request.param
+
+    docker_image = DockerImage(
+        name=backend_type.replace("postgresql", "postgres"),
+        tag=version,
+        container_id=random_string(f"{backend_type}-{version}-"),
+    )
+
+    if platform.is_fips_enabled():
+        if (
+            docker_image.name in ("mysql-server", "percona")
+            and docker_image.tag == "8.0"
+        ):
+            pytest.skip(f"These tests fail on {docker_image.name}:{docker_image.tag}")
+
+    if backend_type == "postgresql":
+        with make_postgresql_backend(salt_factories, docker_image) as container:
+            yield container
+    elif backend_type in ("mysql-server", "percona", "mariadb"):
+        with make_mysql_backend(salt_factories, docker_image, request) as container:
+            yield container
+    elif backend_type == "sqlite":
+        # just a stub to make sqlite act the same as ms/pg
+        yield DatabaseCombo(
+            name="sqlite", dialect="sqlite", version=version, user=None, passwd=None
+        )
+    else:
+        raise ValueError(f"Unknown backend type: {backend_type}")
+
+
+@contextmanager
+def make_postgresql_backend(salt_factories, postgresql_image):
+    postgresql_combo = DatabaseCombo(
+        name=postgresql_image.name,
+        dialect="postgresql",
+        version=postgresql_image.tag,
+        user="salt-postgres-user",
+        passwd="Pa55w0rd!",
+        database="salt",
+        container_id=postgresql_image.container_id,
+    )
+
+    container_environment = {
+        "POSTGRES_USER": postgresql_combo.user,
+        "POSTGRES_PASSWORD": postgresql_combo.passwd,
+    }
+    if postgresql_combo.database:
+        container_environment["POSTGRES_DB"] = postgresql_combo.database
+
+    container = salt_factories.get_container(
+        postgresql_combo.container_id,
+        f"{postgresql_combo.name}:{postgresql_combo.version}",
+        pull_before_start=True,
+        skip_on_pull_failure=True,
+        skip_if_docker_client_not_connectable=True,
+        container_run_kwargs={
+            "ports": {"5432/tcp": None},
+            "environment": container_environment,
+        },
+    )
+
+    def _test():
+        return container.run(
+            "psql",
+            f"--user={postgresql_combo.user}",
+            postgresql_combo.database,
+            "-c",
+            "SELECT 1",
+            environment={"PGPASSWORD": postgresql_combo.passwd},
+        )
+
+    container.before_start(set_container_name_before_start, container)
+    container.container_start_check(check_container_started, container, _test)
+    with container.started():
+        postgresql_combo.container = container
+        postgresql_combo.port = container.get_host_port_binding(
+            5432, protocol="tcp", ipv6=False
+        )
+        yield postgresql_combo
+
+
+@contextmanager
+def make_mysql_backend(salt_factories, mysql_image, request):
+    # modules.test_mysql explicitly expects no database pre-created
+    mysql_combo = DatabaseCombo(
+        name=mysql_image.name,
+        dialect="mysql",
+        version=mysql_image.tag,
+        user="salt-mysql-user",
+        passwd="Pa55w0rd!",
+        database=(
+            None
+            # the mysql module test expects no database
+            if request.module.__name__ == "tests.pytests.functional.modules.test_mysql"
+            else "salt"
+        ),
+        container_id=mysql_image.container_id,
+    )
+
+    container_environment = {
+        "MYSQL_ROOT_PASSWORD": mysql_combo.passwd,
+        "MYSQL_ROOT_HOST": mysql_combo.host,
+        "MYSQL_USER": mysql_combo.user,
+        "MYSQL_PASSWORD": mysql_combo.passwd,
+    }
+    if mysql_combo.database:
+        container_environment["MYSQL_DATABASE"] = mysql_combo.database
+
+    container = salt_factories.get_container(
+        mysql_combo.container_id,
+        "ghcr.io/saltstack/salt-ci-containers/{}:{}".format(
+            mysql_combo.name, mysql_combo.version
+        ),
+        pull_before_start=True,
+        skip_on_pull_failure=True,
+        skip_if_docker_client_not_connectable=True,
+        container_run_kwargs={
+            "ports": {"3306/tcp": None},
+            "environment": container_environment,
+        },
+    )
+
+    def _test():
+        return container.run(
+            "mysql",
+            f"--user={mysql_combo.user}",
+            f"--password={mysql_combo.passwd}",
+            "-e",
+            "SELECT 1",
+        )
+
+    container.before_start(set_container_name_before_start, container)
+    container.container_start_check(check_container_started, container, _test)
+    with container.started():
+        mysql_combo.container = container
+        mysql_combo.port = container.get_host_port_binding(
+            3306, protocol="tcp", ipv6=False
+        )
+        yield mysql_combo
diff --git a/tests/support/pytest/mysql.py b/tests/support/pytest/mysql.py
deleted file mode 100644
index 20377e3453f4..000000000000
--- a/tests/support/pytest/mysql.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import logging
-import time
-
-import attr
-import pytest
-from pytestskipmarkers.utils import platform
-from saltfactories.utils import random_string
-
-# This `pytest.importorskip` here actually works because this module
-# is imported into test modules, otherwise, the skipping would just fail
-pytest.importorskip("docker")
-import docker.errors  # isort:skip pylint: disable=3rd-party-module-not-gated
-
-log = logging.getLogger(__name__)
-
-
-@attr.s(kw_only=True, slots=True)
-class MySQLImage:
-    name = attr.ib()
-    tag = attr.ib()
-    container_id = attr.ib()
-
-    def __str__(self):
-        return f"{self.name}:{self.tag}"
-
-
-@attr.s(kw_only=True, slots=True)
-class MySQLCombo:
-    mysql_name = attr.ib()
-    mysql_version = attr.ib()
-    mysql_port = attr.ib(default=None)
-    mysql_host = attr.ib(default="%")
-    mysql_user = attr.ib()
-    mysql_passwd = attr.ib()
-    mysql_database = attr.ib(default=None)
-    mysql_root_user = attr.ib(default="root")
-    mysql_root_passwd = attr.ib()
-    container = attr.ib(default=None)
-    container_id = attr.ib()
-
-    @container_id.default
-    def _default_container_id(self):
-        return random_string(
-            "{}-{}-".format(
-                self.mysql_name.replace("/", "-"),
-                self.mysql_version,
-            )
-        )
-
-    @mysql_root_passwd.default
-    def _default_mysql_root_user_passwd(self):
-        return self.mysql_passwd
-
-    def get_credentials(self, **kwargs):
-        return {
-            "connection_user": kwargs.get("connection_user") or self.mysql_root_user,
-            "connection_pass": kwargs.get("connection_pass") or self.mysql_root_passwd,
-            "connection_db": kwargs.get("connection_db") or "mysql",
-            "connection_port": kwargs.get("connection_port") or self.mysql_port,
-        }
-
-
-def get_test_versions():
-    test_versions = []
-    name = "mysql-server"
-    for version in ("5.5", "5.6", "5.7", "8.0"):
-        test_versions.append(
-            MySQLImage(
-                name=name,
-                tag=version,
-                container_id=random_string(f"mysql-{version}-"),
-            )
-        )
-    name = "mariadb"
-    for version in ("10.3", "10.4", "10.5"):
-        test_versions.append(
-            MySQLImage(
-                name=name,
-                tag=version,
-                container_id=random_string(f"mariadb-{version}-"),
-            )
-        )
-    name = "percona"
-    for version in ("5.6", "5.7", "8.0"):
-        test_versions.append(
-            MySQLImage(
-                name=name,
-                tag=version,
-                container_id=random_string(f"percona-{version}-"),
-            )
-        )
-    return test_versions
-
-
-def get_test_version_id(value):
-    return f"container={value}"
-
-
-@pytest.fixture(scope="module", params=get_test_versions(), ids=get_test_version_id)
-def mysql_image(request):
-    return request.param
-
-
-@pytest.fixture(scope="module")
-def create_mysql_combo(mysql_image):
-    if platform.is_fips_enabled():
-        if mysql_image.name in ("mysql-server", "percona") and mysql_image.tag == "8.0":
-            pytest.skip(f"These tests fail on {mysql_image.name}:{mysql_image.tag}")
-
-    return MySQLCombo(
-        mysql_name=mysql_image.name,
-        mysql_version=mysql_image.tag,
-        mysql_user="salt-mysql-user",
-        mysql_passwd="Pa55w0rd!",
-        container_id=mysql_image.container_id,
-    )
-
-
-@pytest.fixture(scope="module")
-def mysql_combo(create_mysql_combo):
-    return create_mysql_combo
-
-
-def check_container_started(timeout_at, container, combo):
-    sleeptime = 0.5
-    while time.time() <= timeout_at:
-        try:
-            if not container.is_running():
-                log.warning("%s is no longer running", container)
-                return False
-            ret = container.run(
-                "mysql",
-                f"--user={combo.mysql_user}",
-                f"--password={combo.mysql_passwd}",
-                "-e",
-                "SELECT 1",
-            )
-            if ret.returncode == 0:
-                break
-        except docker.errors.APIError:
-            log.exception("Failed to run start check")
-        time.sleep(sleeptime)
-        sleeptime *= 2
-    else:
-        return False
-    time.sleep(0.5)
-    return True
-
-
-def set_container_name_before_start(container):
-    """
-    This is useful if the container has to be restared and the old
-    container, under the same name was left running, but in a bad shape.
-    """
-    container.name = random_string("{}-".format(container.name.rsplit("-", 1)[0]))
-    container.display_name = None
-    return container
-
-
-@pytest.fixture(scope="module")
-def mysql_container(salt_factories, mysql_combo):
-
-    container_environment = {
-        "MYSQL_ROOT_PASSWORD": mysql_combo.mysql_passwd,
-        "MYSQL_ROOT_HOST": mysql_combo.mysql_host,
-        "MYSQL_USER": mysql_combo.mysql_user,
-        "MYSQL_PASSWORD": mysql_combo.mysql_passwd,
-    }
-    if mysql_combo.mysql_database:
-        container_environment["MYSQL_DATABASE"] = mysql_combo.mysql_database
-
-    container = salt_factories.get_container(
-        mysql_combo.container_id,
-        "ghcr.io/saltstack/salt-ci-containers/{}:{}".format(
-            mysql_combo.mysql_name, mysql_combo.mysql_version
-        ),
-        pull_before_start=True,
-        skip_on_pull_failure=True,
-        skip_if_docker_client_not_connectable=True,
-        container_run_kwargs={
-            "ports": {"3306/tcp": None},
-            "environment": container_environment,
-        },
-    )
-    container.before_start(set_container_name_before_start, container)
-    container.container_start_check(check_container_started, container, mysql_combo)
-    with container.started():
-        mysql_combo.container = container
-        mysql_combo.mysql_port = container.get_host_port_binding(
-            3306, protocol="tcp", ipv6=False
-        )
-        yield mysql_combo
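The `test_serialize`/`test_deserialize` cases above pin down a small storage convention: values are stored as JSON, raw bytes travel in a `{"_base64": ...}` envelope, and escaped NUL sequences are stripped before storage. Below is a minimal standalone sketch of that convention, re-derived from the test assertions rather than copied from the salt.sqlalchemy internals:

    import base64
    import json


    def serialize(data):
        # Raw bytes cannot be represented in JSON; wrap them in a
        # {"_base64": ...} envelope, as test_serialize asserts.
        if isinstance(data, bytes):
            data = {"_base64": base64.b64encode(data).decode("ascii")}
        # Strip escaped NUL sequences, which some backends (e.g. PostgreSQL
        # text columns) refuse to store.
        return json.dumps(data).replace("\\u0000", "")


    def deserialize(blob):
        data = json.loads(blob)
        # Unwrap the base64 envelope back into bytes.
        if isinstance(data, dict) and set(data) == {"_base64"}:
            return base64.b64decode(data["_base64"])
        return data


    # Round trip matches the expectations encoded in the unit tests.
    assert serialize(b"binary data") == json.dumps({"_base64": "YmluYXJ5IGRhdGE="})
    assert deserialize(serialize(b"binary data")) == b"binary data"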
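The `database_backend` fixture added in tests/support/pytest/database.py reads a `(backend_type, version)` pair from `request.param`, so test modules are meant to drive it through indirect parametrization with the params built by `available_databases()`. A sketch of such a consumer module, under stated assumptions (the test name and the chosen subset are illustrative, and `salt_factories` is provided by the pytest-salt-factories plugin used throughout the suite):

    import pytest

    from tests.support.pytest.database import available_databases

    # Indirect parametrization: each pytest.param carries a
    # (backend_type, version) tuple into the module-scoped
    # database_backend fixture, plus a skip mark when the Python
    # driver for that database is not installed.
    pytestmark = pytest.mark.parametrize(
        "database_backend",
        available_databases(subset=[("sqlite", None), ("postgresql", "17")]),
        indirect=True,
    )


    def test_backend_is_usable(database_backend):
        # DatabaseCombo exposes the dialect plus connection details
        # (port, credentials) for building a DSN against the container.
        assert database_backend.dialect in ("sqlite", "postgresql")
        if database_backend.dialect != "sqlite":
            creds = database_backend.get_credentials()
            assert creds["connection_port"] is not None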