8000 [train_text_to_image_lora] Better image interpolation in training scr… · huggingface/diffusers@3da98e7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3da98e7

Browse files
authored
[train_text_to_image_lora] Better image interpolation in training scripts follow up (#11427)
* Update train_text_to_image_lora.py * update_train_text_to_image_lora
1 parent b3b04fe commit 3da98e7

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,15 @@ def parse_args():
418418
default=4,
419419
help=("The dimension of the LoRA update matrices."),
420420
)
421+
parser.add_argument(
422+
"--image_interpolation_mode",
423+
type=str,
424+
default="lanczos",
425+
choices=[
426+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
427+
],
428+
help="The image interpolation method to use for resizing images.",
429+
)
421430

422431
args = parser.parse_args()
423432
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -649,10 +658,17 @@ def tokenize_captions(examples, is_train=True):
649658
)
650659
return inputs.input_ids
651660

652-
# Preprocessing the datasets.
661+
# Get the specified interpolation method from the args
662+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
663+
664+
# Raise an error if the interpolation method is invalid
665+
if interpolation is None:
666+
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
667+
668+
# Data preprocessing transformations
653669
train_transforms = transforms.Compose(
654670
[
655-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
671+
transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
656672
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
657673
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
658674
transforms.ToTensor(),

0 commit comments

Comments
 (0)
0