chore: address last comments of PR#87 by GarrettWu · Pull Request #102 · googleapis/python-bigquery-dataframes · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions bigframes/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
)
logger = logging.getLogger(__name__)

_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"


class BqConnectionManager:
"""Manager to handle operations with BQ connections."""
Expand All @@ -46,6 +44,23 @@ def __init__(
self._bq_connection_client = bq_connection_client
self._cloud_resource_manager_client = cloud_resource_manager_client

@classmethod
def resolve_full_connection_name(
    cls, connection_name: str, default_project: str, default_location: str
) -> str:
    """Expand a connection name to <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.

    The name is split on "." and any missing leading components are filled
    in from the defaults:

    * ``<project>.<location>.<connection_id>`` is returned unchanged.
    * ``<location>.<connection_id>`` gets ``default_project`` prepended.
    * ``<connection_id>`` gets ``default_project`` and ``default_location``
      prepended.

    No default connection id is applied here; the caller must supply one.

    Args:
        connection_name: Connection name made of one to three non-empty,
            dot-separated components.
        default_project: Project used when the name does not include one.
        default_location: Location used when the name does not include one.

    Returns:
        The fully qualified connection name.

    Raises:
        ValueError: If the name has more than three components, or if any
            component is empty (for example ``""``, ``"eu."`` or ``"a..b"``).
    """
    parts = connection_name.split(".")

    # An empty component can never be a valid project, location or
    # connection id, so reject it instead of returning a malformed name.
    if len(parts) > 3 or not all(parts):
        raise ValueError(f"Invalid connection name format: {connection_name}.")

    # Only leading components may be omitted: one part lacks project and
    # location, two parts lack the project, three parts lack nothing.
    defaults = [default_project, default_location]
    return ".".join(defaults[: 3 - len(parts)] + parts)

def create_bq_connection(
self, project_id: str, location: str, connection_id: str, iam_role: str
):
Expand Down Expand Up @@ -164,25 +179,3 @@ def _get_service_account_if_connection_exists(
pass

return service_account


def get_connection_name_full(
    connection_name: Optional[str], default_project: str, default_location: str
) -> str:
    """Retrieve the full connection name of the form <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.

    Missing leading components are filled in from the defaults; when no
    connection name is given at all, the default bigframes connection id is
    used together with the default project and location.

    Args:
        connection_name: ``None``/empty, or a name made of one to three
            non-empty, dot-separated components.
        default_project: Project used when the name does not include one.
        default_location: Location used when the name does not include one.

    Returns:
        The fully qualified connection name.

    Raises:
        ValueError: If the name has more than three components, or if any
            component is empty (for example ``"eu."`` or ``"a..b"``).
    """
    # Treat "" like None: an empty name would otherwise yield a malformed
    # "<project>.<location>." result with no connection id.
    if not connection_name:
        return (
            f"{default_project}.{default_location}.{_BIGFRAMES_DEFAULT_CONNECTION_ID}"
        )

    parts = connection_name.split(".")

    # Reject names with too many components or with an empty component,
    # neither of which can identify a connection.
    if len(parts) > 3 or not all(parts):
        raise ValueError(f"Invalid connection name format: {connection_name}.")

    # Prepend only the defaults that the given name is missing.
    defaults = [default_project, default_location]
    return ".".join(defaults[: 3 - len(parts)] + parts)
16 changes: 8 additions & 8 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)

connection_name = connection_name or self.session._bq_connection
self.connection_name = clients.get_connection_name_full(
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
connection_name,
default_project=self.session._project,
default_location=self.session._location,
)

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bqml_model_factory = globals.bqml_model_factory()
self._bqml_model: core.BqmlModel = self._create_bqml_model()

Expand Down Expand Up @@ -188,17 +188,17 @@ def __init__(
connection_name: Optional[str] = None,
):
self.session = session or bpd.get_global_session()
self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)

connection_name = connection_name or self.session._bq_connection
self.connection_name = clients.get_connection_name_full(
self.connection_name = self._bq_connection_manager.resolve_full_connection_name(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea behind keeping it a classmethod was to call BqConnectionManager.resolve_full_connection_name(...) directly. Calling it via an instance feels a bit odd, but it works.

connection_name,
default_project=self.session._project,
default_location=self.session._location,
)

self._bq_connection_manager = clients.BqConnectionManager(
self.session.bqconnectionclient, self.session.resourcemanagerclient
)
self._bqml_model_factory = globals.bqml_model_factory()
self._bqml_model: core.BqmlModel = self._create_bqml_model()

Expand Down
2 changes: 1 addition & 1 deletion bigframes/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def remote_function(
if not bigquery_connection:
bigquery_connection = session._bq_connection # type: ignore

bigquery_connection = clients.get_connection_name_full(
bigquery_connection = clients.BqConnectionManager.resolve_full_connection_name(
bigquery_connection,
default_project=dataset_ref.project,
default_location=bq_location,
Expand Down
4 changes: 3 additions & 1 deletion bigframes/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@
_BIGQUERYCONNECTION_REGIONAL_ENDPOINT = "{location}-bigqueryconnection.googleapis.com"
_BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "{location}-bigquerystorage.googleapis.com"

_BIGFRAMES_DEFAULT_CONNECTION_ID = "bigframes-default-connection"

_MAX_CLUSTER_COLUMNS = 4

# TODO(swast): Need to connect to regional endpoints when performing remote
Expand Down Expand Up @@ -321,7 +323,7 @@ def __init__(
),
)

self._bq_connection = context.bq_connection
self._bq_connection = context.bq_connection or _BIGFRAMES_DEFAULT_CONNECTION_ID

# Now that we're starting the session, don't allow the options to be
# changed.
Expand Down
16 changes: 4 additions & 12 deletions tests/unit/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,22 @@
from bigframes import clients


def test_get_connection_name_full_none():
    """A ``None`` connection name resolves to the default connection id."""
    expected = "default-project.us.bigframes-default-connection"

    resolved = clients.get_connection_name_full(
        None,
        default_project="default-project",
        default_location="us",
    )

    assert resolved == expected


def test_get_connection_name_full_connection_id():
connection_name = clients.get_connection_name_full(
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
"connection-id", default_project="default-project", default_location="us"
)
assert connection_name == "default-project.us.connection-id"


def test_get_connection_name_full_location_connection_id():
connection_name = clients.get_connection_name_full(
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
"eu.connection-id", default_project="default-project", default_location="us"
)
assert connection_name == "default-project.eu.connection-id"


def test_get_connection_name_full_all():
connection_name = clients.get_connection_name_full(
connection_name = clients.BqConnectionManager.resolve_full_connection_name(
"my-project.eu.connection-id",
default_project="default-project",
default_location="us",
Expand All @@ -48,9 +41,8 @@ def test_get_connection_name_full_all():


def test_get_connection_name_full_raise_value_error():

with pytest.raises(ValueError):
clients.get_connection_name_full(
clients.BqConnectionManager.resolve_full_connection_name(
"my-project.eu.connection-id.extra_field",
default_project="default-project",
default_location="us",
Expand Down
0