@@ -418,6 +418,15 @@ def parse_args():
418
418
default = 4 ,
419
419
help = ("The dimension of the LoRA update matrices." ),
420
420
)
421
+ parser .add_argument (
422
+ "--image_interpolation_mode" ,
423
+ type = str ,
424
+ default = "lanczos" ,
425
+ choices = [
426
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
427
+ ],
428
+ help = "The image interpolation method to use for resizing images." ,
429
+ )
421
430
422
431
args = parser .parse_args ()
423
432
env_local_rank = int (os .environ .get ("LOCAL_RANK" , - 1 ))
@@ -649,10 +658,17 @@ def tokenize_captions(examples, is_train=True):
649
658
)
650
659
return inputs .input_ids
651
660
652
- # Preprocessing the datasets.
661
+ # Get the specified interpolation method from the args
662
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
663
+
664
+ # Raise an error if the interpolation method is invalid
665
+ if interpolation is None :
666
+ raise ValueError (f"Unsupported interpolation mode { args .image_interpolation_mode } ." )
667
+
668
+ # Data preprocessing transformations
653
669
train_transforms = transforms .Compose (
654
670
[
655
- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
671
+ transforms .Resize (args .resolution , interpolation = interpolation ), # Use dynamic interpolation method
656
672
transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution ),
657
673
transforms .RandomHorizontalFlip () if args .random_flip else transforms .Lambda (lambda x : x ),
658
674
transforms .ToTensor (),
0 commit comments