feat: support saving with customized content column and saving/loading with non-default metadata JSON column. by loeng2023 · Pull Request #12 · googleapis/langchain-google-cloud-sql-mssql-python · GitHub
Merged · 2 commits · Feb 16, 2024
21 changes: 15 additions & 6 deletions src/langchain_google_cloud_sql_mssql/mssql_engine.py
@@ -153,7 +153,9 @@ def init_document_table(
self,
table_name: str,
metadata_columns: List[sqlalchemy.Column] = [],
- store_metadata: bool = True,
+ content_column: str = "page_content",
+ metadata_json_column: Optional[str] = "langchain_metadata",
+ overwrite_existing: bool = False,
) -> None:
"""
Create a table for saving langchain documents.
@@ -162,22 +164,29 @@ def init_document_table(
table_name (str): The MSSQL database table name.
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
to create for custom metadata. Optional.
- store_metadata (bool): Whether to store extra metadata in a metadata column
- if not described in 'metadata' field list (Default: True).
+ content_column (str): The column to store document content.
+ Default: `page_content`.
+ metadata_json_column (Optional[str]): The column to store extra metadata in JSON format.
+ Default: `langchain_metadata`. Optional.
+ overwrite_existing (bool): Whether to drop existing table. Default: False.
"""
+ if overwrite_existing:
+ with self.engine.connect() as conn:
+ conn.execute(sqlalchemy.text(f'DROP TABLE IF EXISTS "{table_name}";'))

columns = [
sqlalchemy.Column(
"page_content",
content_column,
sqlalchemy.UnicodeText,
primary_key=False,
nullable=False,
)
]
columns += metadata_columns
- if store_metadata:
+ if metadata_json_column:
columns.append(
sqlalchemy.Column(
"langchain_metadata",
metadata_json_column,
sqlalchemy.JSON,
primary_key=False,
nullable=True,
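To illustrate the new surface, here is a minimal sketch of creating tables with the customized columns. The connection values, table names, and column names are hypothetical placeholders, and `MSSQLEngine.from_instance` with these keyword arguments is assumed to be the package's documented constructor:

from langchain_google_cloud_sql_mssql import MSSQLEngine

# Hypothetical Cloud SQL connection values; substitute your own.
engine = MSSQLEngine.from_instance(
    project_id="my-project",
    region="us-central1",
    instance="my-instance",
    database="my-database",
    db_user="my-user",
    db_password="my-password",
)

# Custom names for both the content column and the JSON metadata column.
engine.init_document_table(
    "my_docs",
    content_column="article_body",
    metadata_json_column="extra_metadata",
    overwrite_existing=True,
)

# Passing metadata_json_column=None creates the table without a JSON column.
engine.init_document_table(
    "my_docs_no_json",
    metadata_json_column=None,
    overwrite_existing=True,
)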
113 changes: 88 additions & 25 deletions src/langchain_google_cloud_sql_mssql/mssql_loader.py
@@ -26,33 +26,41 @@


def _parse_doc_from_row(
- content_columns: Iterable[str], metadata_columns: Iterable[str], row: Dict
+ content_columns: Iterable[str],
+ metadata_columns: Iterable[str],
+ row: Dict,
+ metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Document:
page_content = " ".join(
str(row[column]) for column in content_columns if column in row
)
metadata: Dict[str, Any] = {}
# unnest metadata from langchain_metadata column
- if DEFAULT_METADATA_COL in metadata_columns and row.get(DEFAULT_METADATA_COL):
- for k, v in row[DEFAULT_METADATA_COL].items():
+ if row.get(metadata_json_column):
+ for k, v in row[metadata_json_column].items():
metadata[k] = v
# load metadata from other columns
for column in metadata_columns:
- if column in row and column != DEFAULT_METADATA_COL:
+ if column in row and column != metadata_json_column:
metadata[column] = row[column]
return Document(page_content=page_content, metadata=metadata)


- def _parse_row_from_doc(column_names: Iterable[str], doc: Document) -> Dict:
+ def _parse_row_from_doc(
+ column_names: Iterable[str],
+ doc: Document,
+ content_column: str = DEFAULT_CONTENT_COL,
+ metadata_json_column: str = DEFAULT_METADATA_COL,
+ ) -> Dict:
doc_metadata = doc.metadata.copy()
- row: Dict[str, Any] = {DEFAULT_CONTENT_COL: doc.page_content}
+ row: Dict[str, Any] = {content_column: doc.page_content}
for entry in doc.metadata:
if entry in column_names:
row[entry] = doc_metadata[entry]
del doc_metadata[entry]
# store extra metadata in langchain_metadata column in json format
- if DEFAULT_METADATA_COL in column_names and len(doc_metadata) > 0:
- row[DEFAULT_METADATA_COL] = doc_metadata
+ if metadata_json_column in column_names and len(doc_metadata) > 0:
+ row[metadata_json_column] = doc_metadata
return row


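A small sketch of what these two helpers do end to end, calling the module's private functions directly. The column names are invented for illustration, and the Document import path is assumed from langchain_core:

from langchain_core.documents import Document

from langchain_google_cloud_sql_mssql.mssql_loader import (
    _parse_doc_from_row,
    _parse_row_from_doc,
)

doc = Document(
    page_content="Granny Smith 150 0.99",
    metadata={"fruit_id": 1, "color": "green"},
)

# "fruit_id" has a dedicated column; "color" spills into the JSON column.
row = _parse_row_from_doc(
    column_names=["body", "fruit_id", "extra_json"],
    doc=doc,
    content_column="body",
    metadata_json_column="extra_json",
)
# row == {"body": "Granny Smith 150 0.99", "fruit_id": 1,
#         "extra_json": {"color": "green"}}

# Reading the row back reassembles an equivalent Document.
restored = _parse_doc_from_row(
    content_columns=["body"],
    metadata_columns=["fruit_id", "extra_json"],
    row=row,
    metadata_json_column="extra_json",
)
assert restored == doc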
@@ -66,6 +74,7 @@ def __init__(
query: str = "",
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
+ metadata_json_column: Optional[str] = None,
):
"""
Document page content defaults to the first column present in the query or table and
@@ -77,19 +86,22 @@
space-separated string concatenation.

Args:
- engine (MSSQLEngine): MSSQLEngine object to connect to the MSSQL database.
- table_name (str): The MSSQL database table name. (OneOf: table_name, query).
- query (str): The query to execute in MSSQL format. (OneOf: table_name, query).
- content_columns (List[str]): The columns to write into the `page_content`
- of the document. Optional.
- metadata_columns (List[str]): The columns to write into the `metadata` of the document.
- Optional.
+ engine (MSSQLEngine): MSSQLEngine object to connect to the MSSQL database.
+ table_name (str): The MSSQL database table name. (OneOf: table_name, query).
+ query (str): The query to execute in MSSQL format. (OneOf: table_name, query).
+ content_columns (List[str]): The columns to write into the `page_content`
+ of the document. Optional.
+ metadata_columns (List[str]): The columns to write into the `metadata` of the document.
+ Optional.
+ metadata_json_column (str): The name of the JSON column to use as the metadata's base
+ dictionary. Default: `langchain_metadata`. Optional.
"""
self.engine = engine
self.table_name = table_name
self.query = query
self.content_columns = content_columns
self.metadata_columns = metadata_columns
+ self.metadata_json_column = metadata_json_column
if not self.table_name and not self.query:
raise ValueError("One of 'table_name' or 'query' must be specified.")
if self.table_name and self.query:
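As a usage sketch against the hypothetical table from the engine example above (the loader class is from this PR; the engine, table, and column names are the placeholders introduced earlier):

from langchain_google_cloud_sql_mssql import MSSQLLoader

# `engine` and the column names come from the earlier hypothetical sketch.
loader = MSSQLLoader(
    engine=engine,
    table_name="my_docs",
    content_columns=["article_body"],
    metadata_json_column="extra_metadata",
)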
@@ -128,6 +140,25 @@ def lazy_load(self) -> Iterator[Document]:
metadata_columns = self.metadata_columns or [
col for col in column_names if col not in content_columns
]
+ # check validity of metadata json column
+ if (
+ self.metadata_json_column
+ and self.metadata_json_column not in column_names
+ ):
+ raise ValueError(
+ f"Column {self.metadata_json_column} not found in query result {column_names}."
+ )
+ # check validity of other columns
+ all_names = content_columns + metadata_columns
+ for name in all_names:
+ if name not in column_names:
+ raise ValueError(
+ f"Column {name} not found in query result {column_names}."
+ )
+ # use default metadata json column if not specified
+ metadata_json_column = self.metadata_json_column or DEFAULT_METADATA_COL
+
+ # load document one by one
while True:
row = result_proxy.fetchone()
if not row:
@@ -136,11 +167,13 @@
row_data = {}
for column in column_names:
value = getattr(row, column)
- if column == DEFAULT_METADATA_COL:
+ if column == metadata_json_column:
row_data[column] = json.loads(value)
else:
row_data[column] = value
- yield _parse_doc_from_row(content_columns, metadata_columns, row_data)
+ yield _parse_doc_from_row(
+ content_columns, metadata_columns, row_data, metadata_json_column
+ )


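The checks above make a mistyped column name fail fast instead of silently dropping metadata. A sketch of streaming documents with the same hypothetical loader:

# lazy_load() yields one Document per row; load() would collect them all.
for doc in loader.lazy_load():
    print(doc.page_content, doc.metadata)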
class MSSQLDocumentSaver:
@@ -150,6 +183,8 @@ def __init__(
self,
engine: MSSQLEngine,
table_name: str,
+ content_column: Optional[str] = None,
+ metadata_json_column: Optional[str] = None,
):
"""
MSSQLDocumentSaver allows saving langchain documents in a database. If the table
@@ -160,14 +195,28 @@
Args:
engine: MSSQLEngine object to connect to the MSSQL database.
table_name: The name of table for saving documents.
+ content_column (str): The column to store document content.
+ Default: `page_content`. Optional.
+ metadata_json_column (str): The name of the JSON column to use as the metadata's base
+ dictionary. Default: `langchain_metadata`. Optional.
"""
self.engine = engine
self.table_name = table_name
self._table = self.engine._load_document_table(table_name)
- if DEFAULT_CONTENT_COL not in self._table.columns.keys():
+ self.content_column = content_column or DEFAULT_CONTENT_COL
+ if self.content_column not in self._table.columns.keys():
raise ValueError(
+ f"Missing '{self.content_column}' field in table {table_name}."
+ )
+ # check metadata_json_column existence if it's provided.
+ if (
+ metadata_json_column
+ and metadata_json_column not in self._table.columns.keys()
+ ):
+ raise ValueError(
- f"Missing '{DEFAULT_CONTENT_COL}' field in table {table_name}."
+ f"Cannot find '{metadata_json_column}' column in table {table_name}."
)
+ self.metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL

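A construction sketch; with the new validation, a content or JSON column name that is missing from the table raises ValueError at construction time (hypothetical names again):

from langchain_google_cloud_sql_mssql import MSSQLDocumentSaver

saver = MSSQLDocumentSaver(
    engine=engine,
    table_name="my_docs",
    content_column="article_body",
    metadata_json_column="extra_metadata",
)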
def add_documents(self, docs: List[Document]) -> None:
"""
@@ -179,9 +228,16 @@ def add_documents(self, docs: List[Document]) -> None:
"""
with self.engine.connect() as conn:
for doc in docs:
- row = _parse_row_from_doc(self._table.columns.keys(), doc)
- if DEFAULT_METADATA_COL in row:
- row[DEFAULT_METADATA_COL] = json.dumps(row[DEFAULT_METADATA_COL])
+ row = _parse_row_from_doc(
+ self._table.columns.keys(),
+ doc,
+ self.content_column,
+ self.metadata_json_column,
+ )
+ if self.metadata_json_column in row:
+ row[self.metadata_json_column] = json.dumps(
+ row[self.metadata_json_column]
+ )
conn.execute(sqlalchemy.insert(self._table).values(row))
conn.commit()

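A save sketch: metadata keys with a dedicated column are stored natively, and the rest is JSON-encoded into the configured column (the saver and names are the hypothetical ones above):

from langchain_core.documents import Document

docs = [
    Document(
        page_content="Granny Smith 150 0.99",
        metadata={"fruit_id": 1, "fruit_name": "Apple"},
    ),
]
# my_docs has no dedicated metadata columns, so both keys are
# JSON-encoded into extra_metadata.
saver.add_documents(docs)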
@@ -195,9 +251,16 @@ def delete(self, docs: List[Document]) -> None:
"""
with self.engine.connect() as conn:
for doc in docs:
- row = _parse_row_from_doc(self._table.columns.keys(), doc)
- if DEFAULT_METADATA_COL in row:
- row[DEFAULT_METADATA_COL] = json.dumps(row[DEFAULT_METADATA_COL])
+ row = _parse_row_from_doc(
+ self._table.columns.keys(),
+ doc,
+ self.content_column,
+ self.metadata_json_column,
+ )
+ if self.metadata_json_column in row:
+ row[self.metadata_json_column] = json.dumps(
+ row[self.metadata_json_column]
+ )
# delete by matching all fields of document
where_conditions = []
for col in self._table.columns:
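Deletion reuses the same row construction and matches rows on every stored field, so the same documents round-trip through delete (sketch with the names above):

# Removes rows whose content and metadata match the given documents.
saver.delete(docs)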
54 changes: 37 additions & 17 deletions tests/integration/test_mssql_loader.py
@@ -266,7 +266,6 @@ def test_load_from_query_with_langchain_metadata(engine):
query=query,
metadata_columns=[
"fruit_name",
"langchain_metadata",
],
)

@@ -311,8 +310,9 @@ def test_save_doc_with_default_metadata(engine):
]


@pytest.mark.parametrize("store_metadata", [True, False])
def test_save_doc_with_customized_metadata(engine, store_metadata):
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
def test_save_doc_with_customized_metadata(engine, metadata_json_column):
content_column = "content_col_test"
engine.init_document_table(
table_name,
metadata_columns=[
@@ -329,35 +329,43 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
nullable=True,
),
],
- store_metadata=store_metadata,
+ content_column=content_column,
+ metadata_json_column=metadata_json_column,
+ overwrite_existing=True,
)
test_docs = [
Document(
page_content="Granny Smith 150 0.99",
metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1},
),
]
- saver = MSSQLDocumentSaver(engine=engine, table_name=table_name)
+ saver = MSSQLDocumentSaver(
+ engine=engine,
+ table_name=table_name,
+ content_column=content_column,
+ metadata_json_column=metadata_json_column,
+ )
loader = MSSQLLoader(
engine=engine,
table_name=table_name,
+ content_columns=[content_column],
metadata_columns=[
"fruit_id",
"fruit_name",
"organic",
],
+ metadata_json_column=metadata_json_column,
)

saver.add_documents(test_docs)
docs = loader.load()

- if store_metadata:
+ if metadata_json_column:
assert docs == test_docs
assert engine._load_document_table(table_name).columns.keys() == [
"page_content",
content_column,
"fruit_name",
"organic",
"langchain_metadata",
metadata_json_column,
]
else:
assert docs == [
@@ -367,7 +375,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
),
]
assert engine._load_document_table(table_name).columns.keys() == [
"page_content",
content_column,
"fruit_name",
"organic",
]
@@ -376,7 +384,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
def test_save_doc_without_metadata(engine):
engine.init_document_table(
table_name,
- store_metadata=False,
+ metadata_json_column=None,
)
test_docs = [
Document(
@@ -430,8 +438,9 @@ def test_delete_doc_with_default_metadata(engine):
assert len(loader.load()) == 0


@pytest.mark.parametrize("store_metadata", [True, False])
def test_delete_doc_with_customized_metadata(engine, store_metadata):
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
def test_delete_doc_with_customized_metadata(engine, metadata_json_column):
content_column = "content_col_test"
engine.init_document_table(
table_name,
metadata_columns=[
@@ -448,7 +457,9 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
nullable=True,
),
],
- store_metadata=store_metadata,
+ content_column=content_column,
+ metadata_json_column=metadata_json_column,
+ overwrite_existing=True,
)
test_docs = [
Document(
@@ -460,8 +471,18 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
metadata={"fruit_id": 2, "fruit_name": "Banana", "organic": 1},
),
]
- saver = MSSQLDocumentSaver(engine=engine, table_name=table_name)
- loader = MSSQLLoader(engine=engine, table_name=table_name)
+ saver = MSSQLDocumentSaver(
+ engine=engine,
+ table_name=table_name,
+ content_column=content_column,
+ metadata_json_column=metadata_json_column,
+ )
+ loader = MSSQLLoader(
+ engine=engine,
+ table_name=table_name,
+ content_columns=[content_column],
+ metadata_json_column=metadata_json_column,
+ )

saver.add_documents(test_docs)
docs = loader.load()
@@ -491,7 +512,6 @@ def test_delete_doc_with_query(engine):
nullable=True,
),
],
- store_metadata=True,
)
test_docs = [
Document(