diff --git a/arango/request.py b/arango/request.py index fa7a7483..d80e324f 100644 --- a/arango/request.py +++ b/arango/request.py @@ -2,13 +2,22 @@ from typing import Any, MutableMapping, Optional -from arango.typings import Fields, Headers, Params - - -def normalize_headers(headers: Optional[Headers]) -> Headers: +from arango.typings import DriverFlags, Fields, Headers, Params +from arango.version import __version__ + + +def normalize_headers( + headers: Optional[Headers], driver_flags: Optional[DriverFlags] = None +) -> Headers: + flags = "" + if driver_flags is not None: + for flag in driver_flags: + flags = flags + flag + ";" + driver_header = "python-arango/" + __version__ + " (" + flags + ")" normalized_headers: Headers = { "charset": "utf-8", "content-type": "application/json", + "x-arango-driver": driver_header, } if headers is not None: for key, value in headers.items(): @@ -53,6 +62,8 @@ class Request: :type exclusive: str | [str] | None :param deserialize: Whether the response body can be deserialized. :type deserialize: bool + :param driver_flags: List of flags for the driver + :type driver_flags: list :ivar method: HTTP method in lowercase (e.g. "post"). :vartype method: str @@ -74,6 +85,8 @@ class Request: :vartype exclusive: str | [str] | None :ivar deserialize: Whether the response body can be deserialized. :vartype deserialize: bool + :ivar driver_flags: List of flags for the driver + :vartype driver_flags: list """ __slots__ = ( @@ -86,6 +99,7 @@ class Request: "write", "exclusive", "deserialize", + "driver_flags", ) def __init__( @@ -99,13 +113,15 @@ def __init__( write: Optional[Fields] = None, exclusive: Optional[Fields] = None, deserialize: bool = True, + driver_flags: Optional[DriverFlags] = None, ) -> None: self.method = method self.endpoint = endpoint - self.headers: Headers = normalize_headers(headers) + self.headers: Headers = normalize_headers(headers, driver_flags) self.params: MutableMapping[str, str] = normalize_params(params) self.data = data self.read = read self.write = write self.exclusive = exclusive self.deserialize = deserialize + self.driver_flags = driver_flags diff --git a/arango/typings.py b/arango/typings.py index ed685d99..8d49e3fd 100644 --- a/arango/typings.py +++ b/arango/typings.py @@ -7,3 +7,4 @@ Params = MutableMapping[str, Union[bool, int, str]] Headers = MutableMapping[str, str] Fields = Union[str, Sequence[str]] +DriverFlags = List[str] diff --git a/tests/test_request.py b/tests/test_request.py index 616a388c..256c9a68 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,7 +1,7 @@ from arango.request import Request -def test_request_no_data(): +def test_request_no_data() -> None: request = Request( method="post", endpoint="/_api/test", @@ -11,15 +11,13 @@ def test_request_no_data(): assert request.method == "post" assert request.endpoint == "/_api/test" assert request.params == {"bool": "1"} - assert request.headers == { - "charset": "utf-8", - "content-type": "application/json", - "foo": "bar", - } + assert request.headers["charset"] == "utf-8" + assert request.headers["content-type"] == "application/json" + assert request.headers["foo"] == "bar" assert request.data is None -def test_request_string_data(): +def test_request_string_data() -> None: request = Request( method="post", endpoint="/_api/test", @@ -30,15 +28,13 @@ def test_request_string_data(): assert request.method == "post" assert request.endpoint == "/_api/test" assert request.params == {"bool": "1"} - assert request.headers == { - "charset": "utf-8", - "content-type": "application/json", - "foo": "bar", - } + assert request.headers["charset"] == "utf-8" + assert request.headers["content-type"] == "application/json" + assert request.headers["foo"] == "bar" assert request.data == "test" -def test_request_json_data(): +def test_request_json_data() -> None: request = Request( method="post", endpoint="/_api/test", @@ -49,15 +45,13 @@ def test_request_json_data(): assert request.method == "post" assert request.endpoint == "/_api/test" assert request.params == {"bool": "1"} - assert request.headers == { - "charset": "utf-8", - "content-type": "application/json", - "foo": "bar", - } + assert request.headers["charset"] == "utf-8" + assert request.headers["content-type"] == "application/json" + assert request.headers["foo"] == "bar" assert request.data == {"baz": "qux"} -def test_request_transaction_data(): +def test_request_transaction_data() -> None: request = Request( method="post", endpoint="/_api/test", @@ -68,9 +62,7 @@ def test_request_transaction_data(): assert request.method == "post" assert request.endpoint == "/_api/test" assert request.params == {"bool": "1"} - assert request.headers == { - "charset": "utf-8", - "content-type": "application/json", - "foo": "bar", - } + assert request.headers["charset"] == "utf-8" + assert request.headers["content-type"] == "application/json" + assert request.headers["foo"] == "bar" assert request.data == {"baz": "qux"}