120
120
}
121
121
TFLITE_FORMAT_2 = ml .TFLiteFormat .from_dict (TFLITE_FORMAT_JSON_2 )
122
122
123
+ AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263'
124
+ AUTOML_MODEL_SOURCE = ml .TFLiteAutoMlSource (AUTOML_MODEL_NAME )
125
+ TFLITE_FORMAT_JSON_3 = {
126
+ 'automlModel' : AUTOML_MODEL_NAME ,
127
+ 'sizeBytes' : '3456789'
128
+ }
129
+ TFLITE_FORMAT_3 = ml .TFLiteFormat .from_dict (TFLITE_FORMAT_JSON_3 )
130
+
131
+ AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222'
132
+ AUTOML_MODEL_NAME_JSON_2 = {'automlModel' : AUTOML_MODEL_NAME_2 }
133
+ AUTOML_MODEL_SOURCE_2 = ml .TFLiteAutoMlSource (AUTOML_MODEL_NAME_2 )
134
+
123
135
CREATED_UPDATED_MODEL_JSON_1 = {
124
136
'name' : MODEL_NAME_1 ,
125
137
'displayName' : DISPLAY_NAME_1 ,
@@ -403,7 +415,15 @@ def test_model_keyword_based_creation_and_setters(self):
403
415
'tfliteModel' : TFLITE_FORMAT_JSON_2
404
416
}
405
417
406
- def test_model_format_source_creation (self ):
418
+ model .model_format = TFLITE_FORMAT_3
419
+ assert model .as_dict () == {
420
+ 'displayName' : DISPLAY_NAME_2 ,
421
+ 'tags' : TAGS_2 ,
422
+ 'tfliteModel' : TFLITE_FORMAT_JSON_3
423
+ }
424
+
425
+
426
+ def test_gcs_tflite_model_format_source_creation (self ):
407
427
model_source = ml .TFLiteGCSModelSource (gcs_tflite_uri = GCS_TFLITE_URI )
408
428
model_format = ml .TFLiteFormat (model_source = model_source )
409
429
model = ml .Model (display_name = DISPLAY_NAME_1 , model_format = model_format )
@@ -414,19 +434,37 @@ def test_model_format_source_creation(self):
414
434
}
415
435
}
416
436
437
+ def test_auto_ml_tflite_model_format_source_creation (self ):
438
+ model_source = ml .TFLiteAutoMlSource (auto_ml_model = AUTOML_MODEL_NAME )
439
+ model_format = ml .TFLiteFormat (model_source = model_source )
440
+ model = ml .Model (display_name = DISPLAY_NAME_1 , model_format = model_format )
441
+ assert model .as_dict () == {
442
+ 'displayName' : DISPLAY_NAME_1 ,
443
+ 'tfliteModel' : {
444
+ 'automlModel' : AUTOML_MODEL_NAME
445
+ }
446
+ }
447
+
417
448
def test_source_creation_from_tflite_file (self ):
418
449
model_source = ml .TFLiteGCSModelSource .from_tflite_model_file (
419
450
"my_model.tflite" , "my_bucket" )
420
451
assert model_source .as_dict () == {
421
452
'gcsTfliteUri' : 'gs://my_bucket/Firebase/ML/Models/my_model.tflite'
422
453
}
423
454
424
- def test_model_source_setters (self ):
10000
td>
455
+ def test_gcs_tflite_model_source_setters (self ):
425
456
model_source = ml .TFLiteGCSModelSource (GCS_TFLITE_URI )
426
457
model_source .gcs_tflite_uri = GCS_TFLITE_URI_2
427
458
assert model_source .gcs_tflite_uri == GCS_TFLITE_URI_2
428
459
assert model_source .as_dict () == GCS_TFLITE_URI_JSON_2
429
460
461
+ def test_auto_ml_tflite_model_source_setters (self ):
462
+ model_source = ml .TFLiteAutoMlSource (AUTOML_MODEL_NAME )
463
+ model_source .auto_ml_model = AUTOML_MODEL_NAME_2
464
+ assert model_source .auto_ml_model == AUTOML_MODEL_NAME_2
465
+ assert model_source .as_dict () == AUTOML_MODEL_NAME_JSON_2
466
+
467
+
430
468
def test_model_format_setters (self ):
431
469
model_format = ml .TFLiteFormat (model_source = GCS_TFLITE_MODEL_SOURCE )
432
470
model_format .model_source = GCS_TFLITE_MODEL_SOURCE_2
@@ -437,6 +475,14 @@ def test_model_format_setters(self):
437
475
}
438
476
}
439
477
478
+ model_format .model_source = AUTOML_MODEL_SOURCE
479
+ assert model_format .model_source == AUTOML_MODEL_SOURCE
480
+ assert model_format .as_dict () == {
481
+ 'tfliteModel' : {
482
+ 'automlModel' : AUTOML_MODEL_NAME
483
+ }
484
+ }
485
+
440
486
def test_model_as_dict_for_upload (self ):
441
487
model_source = ml .TFLiteGCSModelSource (gcs_tflite_uri = GCS_TFLITE_URI )
442
488
model_format = ml .TFLiteFormat (model_source = model_source )
@@ -522,6 +568,23 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type):
522
568
ml .TFLiteGCSModelSource (gcs_tflite_uri = uri )
523
569
check_error (excinfo , exc_type )
524
570
571
+ @pytest .mark .parametrize ('auto_ml_model, exc_type' , [
572
+ (123 , TypeError ),
573
+ ('abc' , ValueError ),
574
+ ('/projects/123456/locations/us-central1/models/noLeadingSlash' , ValueError ),
575
+ ('projects/123546/models/ICN123456' , ValueError ),
576
+ ('projects//locations/us-central1/models/ICN123456' , ValueError ),
577
+ ('projects/123456/locations//models/ICN123456' , ValueError ),
578
+ ('projects/123456/locations/us-central1/models/' , ValueError ),
579
+ ('projects/ABC/locations/us-central1/models/ICN123456' , ValueError ),
580
+ ('projects/123456/locations/us-central1/models/@#$%^&' , ValueError ),
581
+ ('projects/123456/locations/us-cent/ral1/models/ICN123456' , ValueError ),
582
+ ])
583
+ def test_auto_ml_tflite_source_validation_errors (self , auto_ml_model , exc_type ):
584
+ with pytest .raises (exc_type ) as excinfo :
585
+ ml .TFLiteAutoMlSource (auto_ml_model = auto_ml_model )
586
+ check_error (excinfo , exc_type )
587
+
525
588
def test_wait_for_unlocked_not_locked (self ):
526
589
model = ml .Model (display_name = "not_locked" )
527
590
model .wait_for_unlocked ()
0 commit comments