From de7f08d76ffb1771b0ff3871cc57a8eb48978b15 Mon Sep 17 00:00:00 2001 From: Jac Fitzgerald Date: Sat, 18 Feb 2023 12:24:23 -0800 Subject: [PATCH 1/4] add query-tagging attribute to connection --- tableauserverclient/models/connection_item.py | 2 ++ tableauserverclient/server/request_factory.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tableauserverclient/models/connection_item.py b/tableauserverclient/models/connection_item.py index 3b2255a3b..3d0885206 100644 --- a/tableauserverclient/models/connection_item.py +++ b/tableauserverclient/models/connection_item.py @@ -17,6 +17,7 @@ def __init__(self): self.server_port: Optional[str] = None self.username: Optional[str] = None self.connection_credentials: Optional[ConnectionCredentials] = None + self.query_tagging: bool = None @property def datasource_id(self) -> Optional[str]: @@ -52,6 +53,7 @@ def from_response(cls, resp, ns) -> List["ConnectionItem"]: connection_item.server_address = connection_xml.get("serverAddress", None) connection_item.server_port = connection_xml.get("serverPort", None) connection_item.username = connection_xml.get("userName", None) + connection_item.query_tagging = connection_xml.get("queryTaggingEnabled", None) datasource_elem = connection_xml.find(".//t:datasource", namespaces=ns) if datasource_elem is not None: connection_item._datasource_id = datasource_elem.get("id", None) diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 14dc7606e..b19c3cc56 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1002,6 +1002,8 @@ def update_req(self, xml_request: ET.Element, connection_item: "ConnectionItem") connection_element.attrib["password"] = connection_item.password if connection_item.embed_password is not None: connection_element.attrib["embedPassword"] = str(connection_item.embed_password).lower() + if connection_item.query_tagging is not None: + connection_element.attrib["queryTaggingEnabled"] = str(connection_item.query_tagging).lower() class TaskRequest(object): From e782c29d761c8f5e87ebae4593b5b1893ae63ca3 Mon Sep 17 00:00:00 2001 From: Jac Fitzgerald Date: Wed, 8 Mar 2023 15:06:52 -0800 Subject: [PATCH 2/4] add explanation for why it doesn't work on hyper --- tableauserverclient/models/connection_item.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tableauserverclient/models/connection_item.py b/tableauserverclient/models/connection_item.py index 3d0885206..0789dbb2b 100644 --- a/tableauserverclient/models/connection_item.py +++ b/tableauserverclient/models/connection_item.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional from defusedxml.ElementTree import fromstring @@ -17,7 +18,7 @@ def __init__(self): self.server_port: Optional[str] = None self.username: Optional[str] = None self.connection_credentials: Optional[ConnectionCredentials] = None - self.query_tagging: bool = None + self._query_tagging: bool = None @property def datasource_id(self) -> Optional[str]: @@ -35,6 +36,19 @@ def id(self) -> Optional[str]: def connection_type(self) -> Optional[str]: return self._connection_type + @property + def query_tagging(self) -> Optional[bool]: + return self._query_tagging + + @query_tagging.setter + def query_tagging(self, value: Optional[bool]): + # if connection type = hyper, Snowflake, or Teradata, we can't change this value: it is always true + if self._connection_type in ["hyper", "snowflake", "teradata"]: + logger = logging.getLogger("tableauserverclient.models.connection_item") + logger.debug("Cannot update value: Query tagging is always enabled for {} connections".format(self._connection_type)) + return + self._query_tagging = value + def __repr__(self): return "".format( **self.__dict__ @@ -53,7 +67,7 @@ def from_response(cls, resp, ns) -> List["ConnectionItem"]: connection_item.server_address = connection_xml.get("serverAddress", None) connection_item.server_port = connection_xml.get("serverPort", None) connection_item.username = connection_xml.get("userName", None) - connection_item.query_tagging = connection_xml.get("queryTaggingEnabled", None) + connection_item._query_tagging = connection_xml.get("queryTaggingEnabled", None) datasource_elem = connection_xml.find(".//t:datasource", namespaces=ns) if datasource_elem is not None: connection_item._datasource_id = datasource_elem.get("id", None) From 11e0a2e6ac9c7cf1a30fff9a2934a8b4cafe8e2a Mon Sep 17 00:00:00 2001 From: Jac Fitzgerald Date: Wed, 8 Mar 2023 15:17:19 -0800 Subject: [PATCH 3/4] format/type --- tableauserverclient/models/connection_item.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tableauserverclient/models/connection_item.py b/tableauserverclient/models/connection_item.py index 0789dbb2b..420a0c199 100644 --- a/tableauserverclient/models/connection_item.py +++ b/tableauserverclient/models/connection_item.py @@ -18,7 +18,7 @@ def __init__(self): self.server_port: Optional[str] = None self.username: Optional[str] = None self.connection_credentials: Optional[ConnectionCredentials] = None - self._query_tagging: bool = None + self._query_tagging: Optional[bool] = None @property def datasource_id(self) -> Optional[str]: @@ -45,7 +45,9 @@ def query_tagging(self, value: Optional[bool]): # if connection type = hyper, Snowflake, or Teradata, we can't change this value: it is always true if self._connection_type in ["hyper", "snowflake", "teradata"]: logger = logging.getLogger("tableauserverclient.models.connection_item") - logger.debug("Cannot update value: Query tagging is always enabled for {} connections".format(self._connection_type)) + logger.debug( + "Cannot update value: Query tagging is always enabled for {} connections".format(self._connection_type) + ) return self._query_tagging = value From 24fb52fd7ad87250ef9b2c14d2b26bbfd9312845 Mon Sep 17 00:00:00 2001 From: Jac Fitzgerald Date: Wed, 8 Mar 2023 20:44:55 -0800 Subject: [PATCH 4/4] string_to_bool fix --- tableauserverclient/models/connection_item.py | 6 ++-- test/test_connection_.py | 34 +++++++++++++++++++ test/test_datasource_model.py | 11 +++++- 3 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 test/test_connection_.py diff --git a/tableauserverclient/models/connection_item.py b/tableauserverclient/models/connection_item.py index 420a0c199..c17421c07 100644 --- a/tableauserverclient/models/connection_item.py +++ b/tableauserverclient/models/connection_item.py @@ -4,6 +4,7 @@ from defusedxml.ElementTree import fromstring from .connection_credentials import ConnectionCredentials +from .property_decorators import property_is_boolean class ConnectionItem(object): @@ -41,6 +42,7 @@ def query_tagging(self) -> Optional[bool]: return self._query_tagging @query_tagging.setter + @property_is_boolean def query_tagging(self, value: Optional[bool]): # if connection type = hyper, Snowflake, or Teradata, we can't change this value: it is always true if self._connection_type in ["hyper", "snowflake", "teradata"]: @@ -69,7 +71,7 @@ def from_response(cls, resp, ns) -> List["ConnectionItem"]: connection_item.server_address = connection_xml.get("serverAddress", None) connection_item.server_port = connection_xml.get("serverPort", None) connection_item.username = connection_xml.get("userName", None) - connection_item._query_tagging = connection_xml.get("queryTaggingEnabled", None) + connection_item._query_tagging = string_to_bool(connection_xml.get("queryTaggingEnabled", None)) datasource_elem = connection_xml.find(".//t:datasource", namespaces=ns) if datasource_elem is not None: connection_item._datasource_id = datasource_elem.get("id", None) @@ -111,4 +113,4 @@ def from_xml_element(cls, parsed_response, ns) -> List["ConnectionItem"]: # Used to convert string represented boolean to a boolean type def string_to_bool(s: str) -> bool: - return s.lower() == "true" + return s is not None and s.lower() == "true" diff --git a/test/test_connection_.py b/test/test_connection_.py new file mode 100644 index 000000000..47b796ebe --- /dev/null +++ b/test/test_connection_.py @@ -0,0 +1,34 @@ +import unittest +import tableauserverclient as TSC + + +class DatasourceModelTests(unittest.TestCase): + def test_require_boolean_query_tag_fails(self): + conn = TSC.ConnectionItem() + conn._connection_type = "postgres" + with self.assertRaises(ValueError): + conn.query_tagging = "no" + + def test_set_query_tag_normal_conn(self): + conn = TSC.ConnectionItem() + conn._connection_type = "postgres" + conn.query_tagging = True + self.assertEqual(conn.query_tagging, True) + + def test_ignore_query_tag_for_hyper(self): + conn = TSC.ConnectionItem() + conn._connection_type = "hyper" + conn.query_tagging = True + self.assertEqual(conn.query_tagging, None) + + def test_ignore_query_tag_for_teradata(self): + conn = TSC.ConnectionItem() + conn._connection_type = "teradata" + conn.query_tagging = True + self.assertEqual(conn.query_tagging, None) + + def test_ignore_query_tag_for_snowflake(self): + conn = TSC.ConnectionItem() + conn._connection_type = "snowflake" + conn.query_tagging = True + self.assertEqual(conn.query_tagging, None) diff --git a/test/test_datasource_model.py b/test/test_datasource_model.py index 81a26b068..2360574ec 100644 --- a/test/test_datasource_model.py +++ b/test/test_datasource_model.py @@ -1,5 +1,4 @@ import unittest - import tableauserverclient as TSC @@ -9,3 +8,13 @@ def test_invalid_project_id(self): datasource = TSC.DatasourceItem("10") with self.assertRaises(ValueError): datasource.project_id = None + + def test_require_boolean_flag_bridge_fail(self): + datasource = TSC.DatasourceItem("10") + with self.assertRaises(ValueError): + datasource.use_remote_query_agent = "yes" + + def test_require_boolean_flag_bridge_ok(self): + datasource = TSC.DatasourceItem("10") + datasource.use_remote_query_agent = True + self.assertEqual(datasource.use_remote_query_agent, True)