8000 Django-style filters by scuml · Pull Request #615 · tableau/server-client-python · GitHub
[go: up one dir, main page]

Skip to content

Django-style filters #615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ target/
# pyenv
.python-version

# poetry
poetry.lock
pyproject.toml

# celery beat schedule file
celerybeat-schedule

Expand Down
5 changes: 3 additions & 2 deletions tableauserverclient/server/endpoint/datasources_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .endpoint import Endpoint, api, parameter_added_in
from .endpoint import QuerysetEndpoint, api, parameter_added_in
from .exceptions import InternalServerError, MissingRequiredFieldError
from .permissions_endpoint import _PermissionsEndpoint
from .fileuploads_endpoint import Fileuploads
from .resource_tagger import _ResourceTagger
from .. import RequestFactory, DatasourceItem, PaginationItem, ConnectionItem
from ..query import QuerySet
from ...filesys_helpers import to_filename, make_download_path
from ...models.job_item import JobItem

Expand All @@ -21,7 +22,7 @@
logger = logging.getLogger('tableau.endpoint.datasources')


class Datasources(Endpoint):
class Datasources(QuerysetEndpoint):
def __init__(self, parent_srv):
super(Datasources, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
24 changes: 23 additions & 1 deletion tableauserverclient/server/endpoint/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .exceptions import ServerResponseError, InternalServerError, NonXMLResponseError
from functools import wraps
from xml.etree.ElementTree import ParseError

from ..query import QuerySet
import logging

try:
Expand Down Expand Up @@ -165,3 +165,25 @@ def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
return wrapper
return _decorator


class QuerysetEndpoint(Endpoint):
@api(version="2.0")
def all(self, *args, **kwargs):
queryset = QuerySet(self)
return queryset

@api(version="2.0")
def filter(self, *args, **kwargs):
queryset = QuerySet(self).filter(**kwargs)
return queryset

@api(version="2.0")
def order_by(self, *args, **kwargs):
queryset = QuerySet(self).order_by(*args)
return queryset

@api(version="2.0")
def paginate(self, **kwargs):
queryset = QuerySet(self).paginate(**kwargs)
return queryset
4 changes: 2 additions & 2 deletions tableauserverclient/server/endpoint/users_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .endpoint import Endpoint, api
from .endpoint import QuerysetEndpoint, api
from .exceptions import MissingRequiredFieldError
from .. import RequestFactory, UserItem, WorkbookItem, PaginationItem
from ..pager import Pager
Expand All @@ -9,7 +9,7 @@
logger = logging.getLogger('tableau.endpoint.users')


class Users(Endpoint):
class Users(QuerysetEndpoint):
@property
def baseurl(self):
return "{0}/sites/{1}/users".format(self.parent_srv.baseurl, self.parent_srv.site_id)
Expand Down
4 changes: 2 additions & 2 deletions tableauserverclient/server/endpoint/views_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .endpoint import Endpoint, api
from .endpoint import QuerysetEndpoint, api
from .exceptions import MissingRequiredFieldError
from .resource_tagger import _ResourceTagger
from .permissions_endpoint import _PermissionsEndpoint
Expand All @@ -10,7 +10,7 @@
logger = logging.getLogger('tableau.endpoint.views')


class Views(Endpoint):
class Views(QuerysetEndpoint):
def __init__(self, parent_srv):
super(Views, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
10 changes: 6 additions & 4 deletions tableauserverclient/server/endpoint/workbooks_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .endpoint import Endpoint, api, parameter_added_in
from .endpoint import QuerysetEndpoint, api, parameter_added_in
from .exceptions import InternalServerError, MissingRequiredFieldError
from .permissions_endpoint import _PermissionsEndpoint
from .fileuploads_endpoint import Fileuploads
Expand All @@ -21,7 +21,7 @@
logger = logging.getLogger('tableau.endpoint.workbooks')


class Workbooks(Endpoint):
class Workbooks(QuerysetEndpoint):
def __init__(self, parent_srv):
super(Workbooks, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand All @@ -37,8 +37,10 @@ def get(self, req_options=None):
logger.info('Querying all workbooks on site')
url = self.baseurl
server_response = self.get_request(url, req_options)
pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace)
all_workbook_items = WorkbookItem.from_response(server_response.content, self.parent_srv.namespace)
pagination_item = PaginationItem.from_response(
server_response.content, self.parent_srv.namespace)
all_workbook_items = WorkbookItem.from_response(
server_response.content, self.parent_srv.namespace)
return all_workbook_items, pagination_item

# Get 1 workbook
Expand Down
89 changes: 89 additions & 0 deletions tableauserverclient/server/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from .request_options import RequestOptions
from .filter import Filter
from .sort import Sort


def to_camel_case(word):
return word.split('_')[0] + ''.join(x.capitalize() or '_' for x in word.split('_')[1:])


class QuerySet:

def __init__(self, model):
self.model = model
self.request_options = RequestOptions()
self._result_cache = None
self._pagination_item = None

def __iter__(self):
self._fetch_all()
return iter(self._result_cache)

def __getitem__(self, k):
return list(self)[k]

def _fetch_all(self):
"""
Retrieve the data and store result and pagination item in cache
"""
if self._result_cache is None:
self._result_cache, self._pagination_item = self.model.get(self.request_options)

@property
def total_available(self):
self._fetch_all()
return self._pagination_item.total_available

@property
def page_number(self):
self._fetch_all()
return self._pagination_item.page_number

@property
def page_size(self):
self._fetch_all()
return self._pagination_item.page_size

def filter(self, **kwargs):
for kwarg_key, value in kwargs.items():
field_name, operator = self._parse_shorthand_filter(kwarg_key)
self.request_options.filter.add(Filter(field_name, operator, value))
return self

def order_by(self, *args):
for arg in args:
field_name, direction = self._parse_shorthand_sort(arg)
self.request_options.sort.add(Sort(field_name, direction))
return self

def paginate(self, **kwargs):
if "page_number" in kwargs:
self.request_options.pagenumber = kwargs["page_number"]
if "page_size" in kwargs:
self.request_options.pagesize = kwargs["page_size"]
return self

def _parse_shorthand_filter(self, key):
tokens = key.split("__", 1)
if len(tokens) == 1:
operator = RequestOptions.Operator.Equals
else:
operator = tokens[1]
if operator not in RequestOptions.Operator.__dict__.values():
raise ValueError("Operator `{}` is not valid.".format(operator))

field = to_camel_case(tokens[0])
if field not in RequestOptions.Field.__dict__.values():
raise ValueError("Field name `{}` is not valid.".format(field))
return (field, operator)

def _parse_shorthand_sort(self, key):
direction = RequestOptions.Direction.Asc
if key.startswith("-"):
direction = RequestOptions.Direction.Desc
key = key[1:]

key = to_camel_case(key)
if key not in RequestOptions.Field.__dict__.values():
raise ValueError("Sort key name %s is not valid.", key)
return (key, direction)
43 changes: 43 additions & 0 deletions test/test_request_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ def test_filter_equals(self):
self.assertEqual('RESTAPISample', matching_workbooks[0].name)
self.assertEqual('RESTAPISample', matching_workbooks[1].name)

def test_filter_equals_shorthand(self):
with open(FILTER_EQUALS, 'rb') as f:
response_xml = f.read().decode('utf-8')
with requests_mock.mock() as m:
m.get(self.baseurl + '/workbooks?filter=name:eq:RESTAPISample', text=response_xml)
matching_workbooks = self.server.workbooks.filter(name='RESTAPISample').order_by("name")

self.assertEqual(2, matching_workbooks.total_available)
self.assertEqual('RESTAPISample', matching_workbooks[0].name)
self.assertEqual('RESTAPISample', matching_workbooks[1].name)

def test_filter_tags_in(self):
with open(FILTER_TAGS_IN, 'rb') as f:
response_xml = f.read().decode('utf-8')
Expand All @@ -91,6 +102,22 @@ def test_filter_tags_in(self):
self.assertEqual(set(['safari']), matching_workbooks[1].tags)
self.assertEqual(set(['sample']), matching_workbooks[2].tags)

def test_filter_tags_in_shorthand(self):
with open(FILTER_TAGS_IN, 'rb') as f:
response_xml = f.read().decode('utf-8')
with requests_mock.mock() as m:
m.get(self.baseurl + '/workbooks?filter=tags:in:[sample,safari,weather]', text=response_xml)
matching_workbooks = self.server.workbooks.filter(tags__in=['sample', 'safari', 'weather'])

self.assertEqual(3, matching_workbooks.total_available)
self.assertEqual(set(['weather']), matching_workbooks[0].tags)
self.assertEqual(set(['safari']), matching_workbooks[1].tags)
self.assertEqual(set(['sample']), matching_workbooks[2].tags)

def test_invalid_shorthand_option(self):
with self.assertRaises(ValueError):
self.server.workbooks.filter(nonexistant__in=['sample', 'safari'])

def test_multiple_filter_options(self):
with open(FILTER_MULTIPLE, 'rb') as f:
response_xml = f.read().decode('utf-8')
Expand All @@ -107,3 +134,19 @@ def test_multiple_filter_options(self):
for _ in range(100):
matching_workbooks, pagination_item = self.server.workbooks.get(req_option)
self.assertEqual(3, pagination_item.total_available)

def test_multiple_filter_options_shorthand(self):
with open(FILTER_MULTIPLE, 'rb') as f:
response_xml = f.read().decode('utf-8')
# To ensure that this is deterministic, run this a few times
with requests_mock.mock() as m:
# Sometimes pep8 requires you to do things you might not otherwise do
url = ''.join((self.baseurl, '/workbooks?pageNumber=1&pageSize=100&',
'filter=name:eq:foo,tags:in:[sample,safari,weather]'))
m.get(url, text=response_xml)

for _ in range(100):
matching_workbooks = self.server.workbooks.filter(
tags__in=['sample', 'safari', 'weather'], name='foo'
)
self.assertEqual(3, matching_workbooks.total_available)
0