-
Notifications
You must be signed in to change notification settings - Fork 6
8000
Show file tree
Hide file tree
Jan 25, 2024
Jan 25, 2024
Jan 26, 2024
Jan 26, 2024
Jan 26, 2024
Jan 26, 2024
Jan 26, 2024
Jan 26, 2024
Jan 26, 2024
Jan 29, 2024
Jan 29, 2024
Jan 29, 2024
Jan 30, 2024
Loading
feat: add MySQLEngine and Loader load functionality #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
9ff3839
feat: add CloudSQLMySQLEngine class (#7)
jackwotherspoon eb2cba1
feat: support load document by query (#8)
loeng2023 53bfa4f
chore: update file and folder naming
jackwotherspoon bc10773
chore: lint
jackwotherspoon 3bf7506
chore: run mypy on tests
jackwotherspoon bd9d471
fix: load document schema.
loeng2023 a7a09e5
chore: make connector a class attribute
jackwotherspoon b579491
chore: merge in main
jackwotherspoon 5115207
chore: remove close and raise errors
jackwotherspoon 8f1926c
chore: add whitespace
jackwotherspoon 31c49e4
Add int tests to cover combination of customized content columns and …
loeng2023 e6b46e5
feat: support default metadata langchain_metadata in doc loader.
loeng2023 6f3bd2b
chore: re-add service_account_email check
jackwotherspoon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# TODO: Remove below import when minimum supported Python version is 3.10 | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Dict, Optional | ||
|
||
import google.auth | ||
import google.auth.transport.requests | ||
import requests | ||
import sqlalchemy | ||
from google.cloud.sql.connector import Connector | ||
|
||
if TYPE_CHECKING: | ||
import google.auth.credentials | ||
import pymysql | ||
|
||
|
||
def _get_iam_principal_email( | ||
credentials: google.auth.credentials.Credentials, | ||
) -> str: | ||
"""Get email address associated with current authenticated IAM principal. | ||
|
||
Email will be used for automatic IAM database authentication to Cloud SQL. | ||
|
||
Args: | ||
credentials (google.auth.credentials.Credentials): | ||
The credentials object to use in finding the associated IAM | ||
principal email address. | ||
|
||
Returns: | ||
email (str): | ||
The email address associated with the current authenticated IAM | ||
principal. | ||
""" | ||
# refresh credentials if they are not valid | ||
if not credentials.valid: | ||
request = google.auth.transport.requests.Request() | ||
credentials.refresh(request) | ||
jackwotherspoon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# if credentials are associated with a service account email, return early | ||
if hasattr(credentials, "_service_account_email"): | ||
return credentials._service_account_email | ||
# call OAuth2 api to get IAM principal email associated with OAuth2 token | ||
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" | ||
response = requests.get(url) | ||
jackwotherspoon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
response.raise_for_status() | ||
response_json: Dict = response.json() | ||
email = response_json.get("email") | ||
if email is None: | ||
raise ValueError( | ||
"Failed to automatically obtain authenticated IAM princpal's " | ||
"email address using environment's ADC credentials!" | ||
kurtisvg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
return email | ||
|
||
|
||
class MySQLEngine: | ||
"""A class for managing connections to a Cloud SQL for MySQL database.""" | ||
|
||
_connector: Optional[Connector] = None | ||
|
||
def __init__( | ||
self, | ||
engine: sqlalchemy.engine.Engine, | ||
) -> None: | ||
self.engine = engine | ||
|
||
@classmethod | ||
def from_instance( | ||
cls, | ||
project_id: str, | ||
region: str, | ||
instance: str, | ||
database: str, | ||
) -> MySQLEngine: | ||
"""Create an instance of MySQLEngine from Cloud SQL instance | ||
details. | ||
|
||
This method uses the Cloud SQL Python Connector to connect to Cloud SQL | ||
using automatic IAM database authentication with the Google ADC | ||
credentials sourced from the environment. | ||
|
||
More details can be found at https://github.com/GoogleCloudPlatform/cloud-sql-python-connector#credentials | ||
|
||
Args: | ||
project_id (str): Project ID of the Google Cloud Project where | ||
the Cloud SQL instance is located. | ||
region (str): Region where the Cloud SQL instance is located. | ||
instance (str): The name of the Cloud SQL instance. | ||
database (str): The name of the database to connect to on the | ||
Cloud SQL instance. | ||
|
||
Returns: | ||
(MySQLEngine): The engine configured to connect to a | ||
Cloud SQL instance database. | ||
""" | ||
engine = cls._create_connector_engine( | ||
instance_connection_name=f"{project_id}:{region}:{instance}", | ||
database=database, | ||
) | ||
return cls(engine=engine) | ||
|
||
@classmethod | ||
def _create_connector_engine( | ||
cls, instance_connection_name: str, database: str | ||
) -> sqlalchemy.engine.Engine: | ||
"""Create a SQLAlchemy engine using the Cloud SQL Python Connector. | ||
|
||
Defaults to use "pymysql" driver and to connect using automatic IAM | ||
database authentication with the IAM principal associated with the | ||
environment's Google Application Default Credentials. | ||
|
||
Args: | ||
instance_connection_name (str): The instance connection | ||
name of the Cloud SQL instance to establish a connection to. | ||
(ex. "project-id:instance-region:instance-name") | ||
database (str): The name of the database to connect to on the | ||
Cloud SQL instance. | ||
Returns: | ||
(sqlalchemy.engine.Engine): Engine configured using the Cloud SQL | ||
Python Connector. | ||
""" | ||
# get application default credentials | ||
credentials, _ = google.auth.default( | ||
scopes=["https://www.googleapis.com/auth/userinfo.email"] | ||
) | ||
iam_database_user = _get_iam_principal_email(credentials) | ||
if cls._connector is None: | ||
cls._connector = Connector() | ||
|
||
# anonymous function to be used for SQLAlchemy 'creator' argument | ||
def getconn() -> pymysql.Connection: | ||
conn = cls._connector.connect( # type: ignore | ||
instance_connection_name, | ||
"pymysql", | ||
user=iam_database_user, | ||
db=database, | ||
enable_iam_auth=True, | ||
) | ||
return conn | ||
|
||
return sqlalchemy.create_engine( | ||
"mysql+pymysql://", | ||
creator=getconn, | ||
) | ||
|
||
def connect(self) -> sqlalchemy.engine.Connection: | ||
"""Create a connection from SQLAlchemy connection pool. | ||
|
||
Returns: | ||
(sqlalchemy.engine.Connection): a single DBAPI connection checked | ||
out from the connection pool. | ||
""" | ||
return self.engine.connect() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import json | ||
from collections.abc import Iterable | ||
from typing import Any, Dict, List, Optional, Sequence, cast | ||
|
||
import sqlalchemy | ||
from langchain_community.document_loaders.base import BaseLoader | ||
from langchain_core.documents import Document | ||
|
||
from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine | ||
|
||
DEFAULT_METADATA_COL = "langchain_metadata" | ||
|
||
|
||
def _parse_doc_from_table( | ||
content_columns: Iterable[str], | ||
metadata_columns: Iterable[str], | ||
column_names: Iterable[str], | ||
rows: Sequence[Any], | ||
) -> List[Document]: | ||
docs = [] | ||
for row in rows: | ||
page_content = " ".join( | ||
str(getattr(row, column)) | ||
for column in content_columns | ||
if column in column_names | ||
) | ||
metadata = { | ||
column: getattr(row, column) | ||
for column in metadata_columns | ||
if column in column_names | ||
} | ||
if DEFAULT_METADATA_COL in metadata: | ||
extra_metadata = json.loads(metadata[DEFAULT_METADATA_COL]) | ||
del metadata[DEFAULT_METADATA_COL] | ||
metadata |= extra_metadata | ||
doc = Document(page_content=page_content, metadata=metadata) | ||
docs.append(doc) | ||
return docs | ||
|
||
|
||
class MySQLLoader(BaseLoader): | ||
"""A class for loading langchain documents from a Cloud SQL MySQL database.""" | ||
|
||
def __init__( | ||
self, | ||
engine: MySQLEngine, | ||
query: str, | ||
content_columns: Optional[List[str]] = None, | ||
metadata_columns: Optional[List[str]] = None, | ||
): | ||
""" | ||
Args: | ||
engine (MySQLEngine): MySQLEngine object to connect to the MySQL database. | ||
query (str): The query to execute in MySQL format. | ||
content_columns (List[str]): The columns to write into the `page_content` | ||
of the document. Optional. | ||
metadata_columns (List[str]): The columns to write into the `metadata` of the document. | ||
Optional. | ||
""" | ||
self.engine = engine | ||
self.query = query | ||
8C71 td> | self.content_columns = content_columns | |
self.metadata_columns = metadata_columns | ||
|
||
def load(self) -> List[Document]: | ||
""" | ||
Load langchain documents from a Cloud SQL MySQL database. | ||
|
||
Document page content defaults to the first columns present in the query or table and | ||
metadata defaults to all other columns. Use with content_columns to overwrite the column | ||
used for page content. Use metadata_columns to select specific metadata columns rather | ||
than using all remaining columns. | ||
|
||
If multiple content columns are specified, page_content’s string format will default to | ||
space-separated string concatenation. | ||
|
||
Returns: | ||
(List[langchain_core.documents.Document]): a list of Documents with metadata from | ||
specific columns. | ||
""" | ||
with self.engine.connect() as connection: | ||
result_proxy = connection.execute(sqlalchemy.text(self.query)) | ||
column_names = list(result_proxy.keys()) | ||
results = result_proxy.fetchall() | ||
content_columns = self.content_columns or [column_names[0]] | ||
metadata_columns = self.metadata_columns or [ | ||
col for col in column_names if col not in content_columns | ||
] | ||
return _parse_doc_from_table( | ||
content_columns, | ||
metadata_columns, | ||
column_names, | ||
results, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.