diff --git a/tableaudocumentapi/__init__.py b/tableaudocumentapi/__init__.py index c4e98c6..5e79216 100644 --- a/tableaudocumentapi/__init__.py +++ b/tableaudocumentapi/__init__.py @@ -1,5 +1,5 @@ __version__ = '0.0.1' __VERSION__ = __version__ from .connection import Connection -from .datasource import Datasource +from .datasource import Datasource, ConnectionParser from .workbook import Workbook diff --git a/tableaudocumentapi/datasource.py b/tableaudocumentapi/datasource.py index 3a03e1e..ee7428d 100644 --- a/tableaudocumentapi/datasource.py +++ b/tableaudocumentapi/datasource.py @@ -6,6 +6,25 @@ import xml.etree.ElementTree as ET from tableaudocumentapi import Connection +class ConnectionParser(object): + + def __init__(self, datasource_xml, version): + self._dsxml = datasource_xml + self._dsversion = version + + def _extract_federated_connections(self): + return list(map(Connection,self._dsxml.findall('.//named-connections/named-connection/*'))) + + def _extract_legacy_connection(self): + return Connection(self._dsxml.find('connection')) + + def get_connections(self): + if float(self._dsversion) < 10: + connections = self._extract_legacy_connection() + else: + connections = self._extract_federated_connections() + return connections + class Datasource(object): """ @@ -28,7 +47,8 @@ def __init__(self, dsxml, filename=None): self._datasourceTree = ET.ElementTree(self._datasourceXML) self._name = self._datasourceXML.get('name') or self._datasourceXML.get('formatted-name') # TDS files don't have a name attribute self._version = self._datasourceXML.get('version') - self._connection = Connection(self._datasourceXML.find('connection')) + self._connection_parser = ConnectionParser(self._datasourceXML, version=self._version) + self._connection = self._connection_parser.get_connections() @classmethod def from_file(cls, filename): diff --git a/test.py b/test.py index 7766c3b..eb20751 100644 --- a/test.py +++ b/test.py @@ -3,28 +3,20 @@ import os import xml.etree.ElementTree as ET -from tableaudocumentapi import Workbook, Datasource, Connection - -TABLEAU_93_WORKBOOK = ''' - - - - - - - -''' - -TABLEAU_93_TDS = ''' - - - -''' +from tableaudocumentapi import Workbook, Datasource, Connection, ConnectionParser + + +TABLEAU_93_WORKBOOK = '''''' + +TABLEAU_93_TDS = '''''' + +TABLEAU_10_TDS = '''''' + +TABLEAU_10_WORKBOOK = '''''' TABLEAU_CONNECTION_XML = ET.fromstring( '''''') - class HelperMethodTests(unittest.TestCase): def test_is_valid_file_with_valid_inputs(self): @@ -38,6 +30,23 @@ def test_is_valid_file_with_invalid_inputs(self): self.assertFalse(Workbook._is_valid_file('file2.twb3')) +class ConnectionParserTests(unittest.TestCase): + + def test_can_extract_legacy_connection(self): + parser = ConnectionParser(ET.fromstring(TABLEAU_93_TDS), '9.2') + connection = parser.get_connections() + self.assertIsInstance(connection, Connection) + self.assertEqual(connection.dbname, 'TestV1') + + + def test_can_extract_federated_connections(self): + parser = ConnectionParser(ET.fromstring(TABLEAU_10_TDS), '10.0') + connections = parser.get_connections() + self.assertIsInstance(connections, list) + self.assertIsInstance(connections[0], Connection) + self.assertEqual(connections[0].dbname, 'testv1') + + class ConnectionModelTests(unittest.TestCase): def setUp(self): @@ -114,5 +123,34 @@ def test_can_update_datasource_connection_and_save(self): self.assertEqual(new_wb.datasources[0].connection.dbname, 'newdb.test.tsi.lan') +class WorkbookModelV10Tests(unittest.TestCase): + + def setUp(self): + self.workbook_file = io.FileIO('testv10.twb', 'w') + self.workbook_file.write(TABLEAU_10_WORKBOOK.encode('utf8')) + self.workbook_file.seek(0) + + def tearDown(self): + self.workbook_file.close() + os.unlink(self.workbook_file.name) + + def test_can_extract_datasourceV10(self): + wb = Workbook(self.workbook_file.name) + self.assertEqual(len(wb.datasources), 1) + self.assertEqual(len(wb.datasources[0].connection), 2) + self.assertIsInstance(wb.datasources[0].connection, list) + self.assertIsInstance(wb.datasources[0], Datasource) + self.assertEqual(wb.datasources[0].name, + 'federated.1s4nxn20cywkdv13ql0yk0g1mpdx') + + def test_can_update_datasource_connection_and_saveV10(self): + original_wb = Workbook(self.workbook_file.name) + original_wb.datasources[0].connection[0].dbname = 'newdb.test.tsi.lan' + + original_wb.save() + + new_wb = Workbook(self.workbook_file.name) + self.assertEqual(new_wb.datasources[0].connection[0].dbname, 'newdb.test.tsi.lan') + if __name__ == '__main__': unittest.main()