|
16 | 16 |
|
17 | 17 | import textwrap
|
18 | 18 | from typing import Optional
|
| 19 | +from unittest import mock |
19 | 20 |
|
20 | 21 | from google.adk.tools import BaseTool
|
21 | 22 | from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
22 | 23 | from google.adk.tools.bigquery import BigQueryToolset
|
23 | 24 | from google.adk.tools.bigquery.config import BigQueryToolConfig
|
24 | 25 | from google.adk.tools.bigquery.config import WriteMode
|
| 26 | +from google.adk.tools.bigquery.query_tool import execute_sql |
| 27 | +from google.cloud import bigquery |
| 28 | +from google.oauth2.credentials import Credentials |
25 | 29 | import pytest
|
26 | 30 |
|
27 | 31 |
|
@@ -218,3 +222,123 @@ async def test_execute_sql_declaration_write(tool_config):
|
218 | 222 | - Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
|
219 | 223 | - First run "DROP TABLE", followed by "CREATE TABLE".
|
220 | 224 | - To insert data into a table, use "INSERT INTO" statement.""")
|
| 225 | + |
| 226 | + |
| 227 | +@pytest.mark.parametrize( |
| 228 | + ("write_mode",), |
| 229 | + [ |
| 230 | + pytest.param( |
| 231 | + WriteMode.BLOCKED, |
| 232 | + id="blocked", |
| 233 | + ), |
| 234 | + pytest.param( |
| 235 | + WriteMode.ALLOWED, |
| 236 | + id="allowed", |
| 237 | + ), |
| 238 | + ], |
| 239 | +) |
| 240 | +def test_execute_sql_select_stmt(write_mode): |
| 241 | + """Test execute_sql tool for SELECT query when writes are blocked.""" |
| 242 | + project = "my_project" |
| 243 | + query = "SELECT 123 AS num" |
| 244 | + statement_type = "SELECT" |
| 245 | + query_result = [{"num": 123}] |
| 246 | + credentials = mock.create_autospec(Credentials, instance=True) |
| 247 | + tool_config = BigQueryToolConfig(write_mode=write_mode) |
| 248 | + |
| 249 | + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: |
| 250 | + # The mock instance |
| 251 | + bq_client = Client.return_value |
| 252 | + |
| 253 | + # Simulate the result of query API |
| 254 | + query_job = mock.create_autospec(bigquery.QueryJob) |
| 255 | + query_job.statement_type = statement_type |
| 256 | + bq_client.query.return_value = query_job |
| 257 | + |
| 258 | + # Simulate the result of query_and_wait API |
| 259 | + bq_client.query_and_wait.return_value = query_result |
| 260 | + |
| 261 | + # Test the tool |
| 262 | + result = execute_sql(project, query, credentials, tool_config) |
| 263 | + assert result == {"status": "SUCCESS", "rows": query_result} |
| 264 | + |
| 265 | + |
| 266 | +@pytest.mark.parametrize( |
| 267 | + ("query", "statement_type"), |
| 268 | + [ |
| 269 | + pytest.param( |
| 270 | + "CREATE TABLE my_dataset.my_table AS SELECT 123 AS num", |
| 271 | + "CREATE_AS_SELECT", |
| 272 | + id="create-as-select", |
| 273 | + ), |
| 274 | + pytest.param( |
| 275 | + "DROP TABLE my_dataset.my_table", |
| 276 | + "DROP_TABLE", |
| 277 | + id="drop-table", |
| 278 | + ), |
| 279 | + ], |
| 280 | +) |
| 281 | +def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): |
| 282 | + """Test execute_sql tool for SELECT query when writes are blocked.""" |
| 283 | + project = "my_project" |
| 284 | + query_result = [] |
| 285 | + credentials = mock.create_autospec(Credentials, instance=True) |
| 286 | + tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) |
| 287 | + |
| 288 | + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: |
| 289 | + # The mock instance |
| 290 | + bq_client = Client.return_value |
| 291 | + |
| 292 | + # Simulate the result of query API |
| 293 | + query_job = mock.create_autospec(bigquery.QueryJob) |
| 294 | + query_job.statement_type = statement_type |
| 295 | + bq_client.query.return_value = query_job |
| 296 | + |
| 297 | + # Simulate the result of query_and_wait API |
| 298 | + bq_client.query_and_wait.return_value = query_result |
| 299 | + |
| 300 | + # Test the tool |
| 301 | + result = execute_sql(project, query, credentials, tool_config) |
| 302 | + assert result == {"status": "SUCCESS", "rows": query_result} |
| 303 | + |
| 304 | + |
| 305 | +@pytest.mark.parametrize( |
| 306 | + ("query", "statement_type"), |
| 307 | + [ |
| 308 | + pytest.param( |
| 309 | + "CREATE TABLE my_dataset.my_table AS SELECT 123 AS num", |
| 310 | + "CREATE_AS_SELECT", |
| 311 | + id="create-as-select", |
| 312 | + ), |
| 313 | + pytest.param( |
| 314 | + "DROP TABLE my_dataset.my_table", |
| 315 | + "DROP_TABLE", |
| 316 | + id="drop-table", |
| 317 | + ), |
| 318 | + ], |
| 319 | +) |
| 320 | +def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): |
| 321 | + """Test execute_sql tool for SELECT query when writes are blocked.""" |
| 322 | + project = "my_project" |
| 323 | + query_result = [] |
| 324 | + credentials = mock.create_autospec(Credentials, instance=True) |
| 325 | + tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED) |
| 326 | + |
| 327 | + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: |
| 328 | + # The mock instance |
| 329 | + bq_client = Client.return_value |
| 330 | + |
| 331 | + # Simulate the result of query API |
| 332 | + query_job = mock.create_autospec(bigquery.QueryJob) |
| 333 | + query_job.statement_type = statement_type |
| 334 | + bq_client.query.return_value = query_job |
| 335 | + |
| 336 | + # Simulate the result of query_and_wait API |
| 337 | + bq_client.query_and_wait.return_value = query_result |
| 338 | + |
| 339 | + # Test the tool |
| 340 | + result = execute_sql(project, query, credentials, tool_config) |
| 341 | + assert result == { |
| 342 | + "status": "ERROR", |
| 343 | + "error_details": "Read-only mode only supports SELECT statements.", |
| 344 | + } |
0 commit comments