8000 fix: enable self signed jwt for grpc (#217) · googleapis/python-automl@20a72aa · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Dec 31, 2023. It is now read-only.

Commit 20a72aa

Browse files
fix: enable self signed jwt for grpc (#217)
PiperOrigin-RevId: 386504689 Source-Link: googleapis/googleapis@762094a Source-Link: https://github.com/googleapis/googleapis-gen/commit/6bfc480e1a161d5de121c2bcc3745885d33b265a
1 parent 2b09d79 commit 20a72aa

File tree

8 files changed

+88
-52
lines changed

8 files changed

+88
-52
lines changed

google/cloud/automl_v1/services/auto_ml/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ def __init__(
435435
client_cert_source_for_mtls=client_cert_source_func,
436436
quota_project_id=client_options.quota_project_id,
437437
client_info=client_info,
438+
always_use_jwt_access=(
439+
Transport == type(self).get_transport_class("grpc")
440+
or Transport == type(self).get_transport_class("grpc_asyncio")
441+
),
438442
)
439443

440444
def create_dataset(

google/cloud/automl_v1/services/prediction_service/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,10 @@ def __init__(
355355
client_cert_source_for_mtls=client_cert_source_func,
356356
quota_project_id=client_options.quota_project_id,
357357
client_info=client_info,
358+
always_use_jwt_access=(
359+
Transport == type(self).get_transport_class("grpc")
360+
or Transport == type(self).get_transport_class("grpc_asyncio")
361+
),
358362
)
359363

360364
def predict(

google/cloud/automl_v1beta1/services/auto_ml/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,10 @@ def __init__(
483483
client_cert_source_for_mtls=client_cert_source_func,
484484
quota_project_id=client_options.quota_project_id,
485485
client_info=client_info,
486+
always_use_jwt_access=(
487+
Transport == type(self).get_transport_class("grpc")
488+
or Transport == type(self).get_transport_class("grpc_asyncio")
489+
),
486490
)
487491

488492
def create_dataset(

google/cloud/automl_v1beta1/services/prediction_service/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,10 @@ def __init__(
355355
client_cert_source_for_mtls=client_cert_source_func,
356356
quota_project_id=client_options.quota_project_id,
357357
client_info=client_info,
358+
always_use_jwt_access=(
359+
Transport == type(self).get_transport_class("grpc")
360+
or Transport == type(self).get_transport_class("grpc_asyncio")
361+
),
358362
)
359363

360364
def predict(

tests/unit/gapic/automl_v1/test_auto_ml.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -129,33 +129,28 @@ def test_auto_ml_client_from_service_account_info(client_class):
129129
assert client.transport._host == "automl.googleapis.com:443"
130130

131131

132-
@pytest.mark.parametrize("client_class", [AutoMlClient, AutoMlAsyncClient,])
133-
def test_auto_ml_client_service_account_always_use_jwt(client_class):
134-
with mock.patch.object(
135-
service_account.Credentials, "with_always_use_jwt_access", create=True
136-
) as use_jwt:
137-
creds = service_account.Credentials(None, None, None)
138-
client = client_class(credentials=creds)
139-
use_jwt.assert_not_called()
140-
141-
142132
@pytest.mark.parametrize(
143133
"transport_class,transport_name",
144134
[
145135
(transports.AutoMlGrpcTransport, "grpc"),
146136
(transports.AutoMlGrpcAsyncIOTransport, "grpc_asyncio"),
147137
],
148138
)
149-
def test_auto_ml_client_service_account_always_use_jwt_true(
150-
transport_class, transport_name
151-
):
139+
def test_auto_ml_client_service_account_always_use_jwt(transport_class, transport_name):
152140
with mock.patch.object(
153141
service_account.Credentials, "with_always_use_jwt_access", create=True
154142
) as use_jwt:
155143
creds = service_account.Credentials(None, None, None)
156144
transport = transport_class(credentials=creds, always_use_jwt_access=True)
157145
use_jwt.assert_called_once_with(True)
158146

147+
with mock.patch.object(
148+
service_account.Credentials, "with_always_use_jwt_access", create=True
149+
) as use_jwt:
150+
creds = service_account.Credentials(None, None, None)
151+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
152+
use_jwt.assert_not_called()
153+
159154

160155
@pytest.mark.parametrize("client_class", [AutoMlClient, AutoMlAsyncClient,])
161156
def test_auto_ml_client_from_service_account_file(client_class):
@@ -224,6 +219,7 @@ def test_auto_ml_client_client_options(client_class, transport_class, transport_
224219
client_cert_source_for_mtls=None,
225220
quota_project_id=None,
226221
client_info=transports.base.DEFAULT_CLIENT_INFO,
222+
always_use_jwt_access=True,
227223
)
228224

229225
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -240,6 +236,7 @@ def test_auto_ml_client_client_options(client_class, transport_class, transport_
240236
client_cert_source_for_mtls=None,
241237
quota_project_id=None,
242238
client_info=transports.base.DEFAULT_CLIENT_INFO,
239+
always_use_jwt_access=True,
243240
)
244241

245242
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -256,6 +253,7 @@ def test_auto_ml_client_client_options(client_class, transport_class, transport_
256253
client_cert_source_for_mtls=None,
257254
quota_project_id=None,
258255
client_info=transports.base.DEFAULT_CLIENT_INFO,
256+
always_use_jwt_access=True,
259257
)
260258

261259
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -284,6 +282,7 @@ def test_auto_ml_client_client_options(client_class, transport_class, transport_
284282
client_cert_source_for_mtls=None,
285283
quota_project_id="octopus",
286284
client_info=transports.base.DEFAULT_CLIENT_INFO,
285+
always_use_jwt_access=True,
287286
)
288287

289288

@@ -346,6 +345,7 @@ def test_auto_ml_client_mtls_env_auto(
346345
client_cert_source_for_mtls=expected_client_cert_source,
347346
quota_project_id=None,
348347
client_info=transports.base.DEFAULT_CLIENT_INFO,
348+
always_use_jwt_access=True,
349349
)
350350

351351
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -379,6 +379,7 @@ def test_auto_ml_client_mtls_env_auto(
379379
client_cert_source_for_mtls=expected_client_cert_source,
380380
quota_project_id=None,
381381
client_info=transports.base.DEFAULT_CLIENT_INFO,
382+
always_use_jwt_access=True,
382383
)
383384

384385
# Check the case client_cert_source and ADC client cert are not provided.
@@ -400,6 +401,7 @@ def test_auto_ml_client_mtls_env_auto(
400401
client_cert_source_for_mtls=None,
401402
quota_project_id=None,
402403
client_info=transports.base.DEFAULT_CLIENT_INFO,
404+
always_use_jwt_access=True,
403405
)
404406

405407

@@ -426,6 +428,7 @@ def test_auto_ml_client_client_options_scopes(
426428
client_cert_source_for_mtls=None,
427429
quota_project_id=None,
428430
client_info=transports.base.DEFAULT_CLIENT_INFO,
431+
always_use_jwt_access=True,
429432
)
430433

431434

@@ -452,6 +455,7 @@ def test_auto_ml_client_client_options_credentials_file(
452455
client_cert_source_for_mtls=None,
453456
quota_project_id=None,
454457
client_info=transports.base.DEFAULT_CLIENT_INFO,
458+
always_use_jwt_access=True,
455459
)
456460

457461

@@ -469,6 +473,7 @@ def test_auto_ml_client_client_options_from_dict():
469473
client_cert_source_for_mtls=None,
470474
quota_project_id=None,
471475
client_info=transports.base.DEFAULT_CLIENT_INFO,
476+
always_use_jwt_access=True,
472477
)
473478

474479

tests/unit/gapic/automl_v1/test_prediction_service.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -130,26 +130,14 @@ def test_prediction_service_client_from_service_account_info(client_class):
130130
assert client.transport._host == "automl.googleapis.com:443"
131131

132132

133-
@pytest.mark.parametrize(
134-
"client_class", [PredictionServiceClient, PredictionServiceAsyncClient,]
135-
)
136-
def test_prediction_service_client_service_account_always_use_jwt(client_class):
137-
with mock.patch.object(
138-
service_account.Credentials, "with_always_use_jwt_access", create=True
139-
) as use_jwt:
140-
creds = service_account.Credentials(None, None, None)
141-
client = client_class(credentials=creds)
142-
use_jwt.assert_not_called()
143-
144-
145133
@pytest.mark.parametrize(
146134
"transport_class,transport_name",
147135
[
148136
(transports.PredictionServiceGrpcTransport, "grpc"),
149137
(transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio"),
150138
],
151139
)
152-
def test_prediction_service_client_service_account_always_use_jwt_true(
140+
def test_prediction_service_client_service_account_always_use_jwt(
153141
transport_class, transport_name
154142
):
155143
with mock.patch.object(
@@ -159,6 +147,13 @@ def test_prediction_service_client_service_account_always_use_jwt_true(
159147
transport = transport_class(credentials=creds, always_use_jwt_access=True)
160148
use_jwt.assert_called_once_with(True)
161149

150+
with mock.patch.object(
151+
service_account.Credentials, "with_always_use_jwt_access", create=True
152+
) as use_jwt:
153+
creds = service_account.Credentials(None, None, None)
154+
transport = transport_class(credentials=creds, always_use_jwt_access=False)
155+
use_jwt.assert_not_called()
156+
162157

163158
@pytest.mark.parametrize(
164159
"client_class", [PredictionServiceClient, PredictionServiceAsyncClient,]
@@ -239,6 +234,7 @@ def test_prediction_service_client_client_options(
239234
client_cert_source_for_mtls=None,
240235
quota_project_id=None,
241236
client_info=transports.base.DEFAULT_CLIENT_INFO,
237+
always_use_jwt_access=True,
242238
)
243239

244240
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -255,6 +251,7 @@ def test_prediction_service_client_client_options(
255251
client_cert_source_for_mtls=None,
256252
quota_project_id=None,
257253
client_info=transports.base.DEFAULT_CLIENT_INFO,
254+
always_use_jwt_access=True,
258255
)
259256

260257
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
@@ -271,6 +268,7 @@ def test_prediction_service_client_client_options(
271268
client_cert_source_for_mtls=None,
272269
quota_project_id=None,
273270
client_info=transports.base.DEFAULT_CLIENT_INFO,
271+
always_use_jwt_access=True,
274272
)
275273

276274
# Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
@@ -299,6 +297,7 @@ def test_prediction_service_client_client_options(
299297
client_cert_source_for_mtls=None,
300298
quota_project_id="octopus",
301299
client_info=transports.base.DEFAULT_CLIENT_INFO,
300+
always_use_jwt_access=True,
302301
)
303302

304303

@@ -375,6 +374,7 @@ def test_prediction_service_client_mtls_env_auto(
375374
client_cert_source_for_mtls=expected_client_cert_source,
376375
quota_project_id=None,
377376
client_info=transports.base.DEFAULT_CLIENT_INFO,
377+
always_use_jwt_access=True,
378378
)
379379

380380
# Check the case ADC client cert is provided. Whether client cert is used depends on
@@ -408,6 +408,7 @@ def test_prediction_service_client_mtls_env_auto(
408408
client_cert_source_for_mtls=expected_client_cert_source,
409409
quota_project_id=None,
410410
client_info=transports.base.DEFAULT_CLIENT_INFO,
411+
always_use_jwt_access=True,
411412
)
412413

413414
# Check the case client_cert_source and ADC client cert are not provided.
@@ -429,6 +430,7 @@ def test_prediction_service_client_mtls_env_auto(
429430
client_cert_source_for_mtls=None,
430431
quota_project_id=None,
431432
client_info=transports.base.DEFAULT_CLIENT_INFO,
433+
always_use_jwt_access=True,
432434
)
433435

434436

@@ -459,6 +461,7 @@ def test_prediction_service_client_client_options_scopes(
459461
client_cert_source_for_mtls=None,
460462
quota_project_id=None,
461463
client_info=transports.base.DEFAULT_CLIENT_INFO,
464+
always_use_jwt_access=True,
462465
)
463466

464467

@@ -489,6 +492,7 @@ def test_prediction_service_client_client_options_credentials_file(
489492
client_cert_source_for_mtls=None,
490493
quota_project_id=None,
491494
client_info=transports.base.DEFAULT_CLIENT_INFO,
495+
always_use_jwt_access=True,
492496
)
493497

494498

@@ -508,6 +512,7 @@ def test_prediction_service_client_client_options_from_dict():
508512
client_cert_source_for_mtls=None,
509513
quota_project_id=None,
510514
client_info=transports.base.DEFAULT_CLIENT_INFO,
515+
always_use_jwt_access=True,
511516
)
512517

513518

0 commit comments

Comments
 (0)
0