diff --git a/.gitignore b/.gitignore index c6ef2445..4fa6f46d 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ arango/version.py # test results *_results.txt + +# devcontainers +.devcontainer diff --git a/arango/collection.py b/arango/collection.py index 446200fb..e2dfcd2a 100644 --- a/arango/collection.py +++ b/arango/collection.py @@ -50,11 +50,13 @@ from arango.typings import Fields, Headers, Json, Jsons, Params from arango.utils import ( build_filter_conditions, + build_sort_expression, get_batches, get_doc_id, is_none_or_bool, is_none_or_int, is_none_or_str, + validate_sort_parameters, ) @@ -753,6 +755,7 @@ def find( skip: Optional[int] = None, limit: Optional[int] = None, allow_dirty_read: bool = False, + sort: Optional[Jsons] = None, ) -> Result[Cursor]: """Return all documents that match the given filters. @@ -764,13 +767,18 @@ def find( :type limit: int | None :param allow_dirty_read: Allow reads from followers in a cluster. :type allow_dirty_read: bool + :param sort: Document sort parameters + :type sort: Jsons | None :return: Document cursor. :rtype: arango.cursor.Cursor :raise arango.exceptions.DocumentGetError: If retrieval fails. + :raise arango.exceptions.SortValidationError: If sort parameters are invalid. """ assert isinstance(filters, dict), "filters must be a dict" assert is_none_or_int(skip), "skip must be a non-negative int" assert is_none_or_int(limit), "limit must be a non-negative int" + if sort: + validate_sort_parameters(sort) skip_val = skip if skip is not None else 0 limit_val = limit if limit is not None else "null" @@ -778,9 +786,9 @@ def find( FOR doc IN @@collection {build_filter_conditions(filters)} LIMIT {skip_val}, {limit_val} + {build_sort_expression(sort)} RETURN doc """ - bind_vars = {"@collection": self.name} request = Request( diff --git a/arango/exceptions.py b/arango/exceptions.py index 28295b2b..29bcdc17 100644 --- a/arango/exceptions.py +++ b/arango/exceptions.py @@ -1074,3 +1074,10 @@ class JWTRefreshError(ArangoClientError): class JWTExpiredError(ArangoClientError): """JWT token has expired.""" + + +################################### +# Parameter Validation Exceptions # +################################### +class SortValidationError(ArangoClientError): + """Invalid sort parameters.""" diff --git a/arango/utils.py b/arango/utils.py index 541f9d0c..0d128db3 100644 --- a/arango/utils.py +++ b/arango/utils.py @@ -11,8 +11,8 @@ from contextlib import contextmanager from typing import Any, Iterator, Sequence, Union -from arango.exceptions import DocumentParseError -from arango.typings import Json +from arango.exceptions import DocumentParseError, SortValidationError +from arango.typings import Json, Jsons @contextmanager @@ -126,3 +126,42 @@ def build_filter_conditions(filters: Json) -> str: conditions.append(f"doc.{field} == {json.dumps(v)}") return "FILTER " + " AND ".join(conditions) + + +def validate_sort_parameters(sort: Sequence[Json]) -> bool: + """Validate sort parameters for an AQL query. + + :param sort: Document sort parameters. + :type sort: Sequence[Json] + :return: Validation success. + :rtype: bool + :raise arango.exceptions.SortValidationError: If sort parameters are invalid. + """ + assert isinstance(sort, Sequence) + for param in sort: + if "sort_by" not in param or "sort_order" not in param: + raise SortValidationError( + "Each sort parameter must have 'sort_by' and 'sort_order'." + ) + if param["sort_order"].upper() not in ["ASC", "DESC"]: + raise SortValidationError("'sort_order' must be either 'ASC' or 'DESC'") + return True + + +def build_sort_expression(sort: Jsons | None) -> str: + """Build a sort condition for an AQL query. + + :param sort: Document sort parameters. + :type sort: Jsons | None + :return: The complete AQL sort condition. + :rtype: str + """ + if not sort: + return "" + + sort_chunks = [] + for sort_param in sort: + chunk = f"doc.{sort_param['sort_by']} {sort_param['sort_order']}" + sort_chunks.append(chunk) + + return "SORT " + ", ".join(sort_chunks) diff --git a/docs/document.rst b/docs/document.rst index 62ad0886..0f0d7d10 100644 --- a/docs/document.rst +++ b/docs/document.rst @@ -103,6 +103,12 @@ Standard documents are managed via collection API wrapper: assert student['GPA'] == 3.6 assert student['last'] == 'Kim' + # Retrieve one or more matching documents, sorted by a field. + for student in students.find({'first': 'John'}, sort=[{'sort_by': 'GPA', 'sort_order': 'DESC'}]): + assert student['_key'] == 'john' + assert student['GPA'] == 3.6 + assert student['last'] == 'Kim' + # Retrieve a document by key. students.get('john') diff --git a/tests/test_document.py b/tests/test_document.py index 37599507..7cb0a435 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -1162,6 +1162,26 @@ def test_document_find(col, bad_col, docs): # Set up test documents col.import_bulk(docs) + # Test find with sort expression (single field) + found = list(col.find({}, sort=[{"sort_by": "text", "sort_order": "ASC"}])) + assert len(found) == 6 + assert found[0]["text"] == "bar" + assert found[-1]["text"] == "foo" + + # Test find with sort expression (multiple fields) + found = list( + col.find( + {}, + sort=[ + {"sort_by": "text", "sort_order": "ASC"}, + {"sort_by": "val", "sort_order": "DESC"}, + ], + ) + ) + assert len(found) == 6 + assert found[0]["val"] == 6 + assert found[-1]["val"] == 1 + # Test find (single match) with default options found = list(col.find({"val": 2})) assert len(found) == 1