From 2900967ffa2685cc867ac0e694e0e53e97231d34 Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Fri, 12 Jul 2024 06:06:52 +0000
Subject: [PATCH 1/6] Add new aspect ratio bucketing transform

---
 diffusion/datasets/image_caption.py    | 32 +++++++++++++-----
 diffusion/datasets/laion/transforms.py | 47 ++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 8 deletions(-)

diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py
index 81939c61..f45f7203 100644
--- a/diffusion/datasets/image_caption.py
+++ b/diffusion/datasets/image_caption.py
@@ -15,7 +15,8 @@
 from torch.utils.data import DataLoader
 from torchvision import transforms
 
-from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare
+from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransorm,
+                                                 RandomCropBucketedAspectRatioTransorm, RandomCropSquare)
 from diffusion.datasets.utils import make_streams
 from diffusion.models.text_encoder import MultiTokenizer
 
@@ -45,6 +46,7 @@ class StreamingImageCaptionDataset(StreamingDataset):
         transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
         sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`.
         zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.
         **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
@@ -63,6 +65,7 @@ def __init__(
         self,
         transform: Optional[Callable] = None,
         image_key: str = 'image',
         caption_key: str = 'caption',
+        aspect_ratio_bucket_key: Optional[str] = None,
         sdxl_conditioning: bool = False,
         zero_dropped_captions: bool = False,
         **streaming_kwargs,
@@ -90,6 +93,9 @@ def __init__(
         self.caption_selection = caption_selection
         self.image_key = image_key
         self.caption_key = caption_key
+        self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransorm):
+            assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransorm'
         self.zero_dropped_captions = zero_dropped_captions
 
         self.tokenizer = tokenizer
@@ -107,7 +113,9 @@ def __getitem__(self, index):
         orig_w, orig_h = img.size
 
         # Image transforms
-        if self.crop is not None:
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransorm):
+            img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
+        elif self.crop is not None:
             img, crop_top, crop_left = self.crop(img)
         else:
             crop_top, crop_left = 0, 0
@@ -179,6 +187,7 @@ def build_streaming_image_caption_dataloader(
     transform: Optional[List[Callable]] = None,
     image_key: str = 'image',
     caption_key: str = 'caption',
+    aspect_ratio_bucket_key: Optional[str] = None,
     crop_type: Optional[str] = 'square',
     zero_dropped_captions: bool = True,
     sdxl_conditioning: bool = False,
@@ -212,7 +221,8 @@ def build_streaming_image_caption_dataloader(
         transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
-        crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
+        crop_type (str, optional): Type of crop to perform, one of ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
             Default: ``'square'``.
         zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``.
         sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`.
@@ -225,12 +235,14 @@
     # Check crop type
     if crop_type is not None:
         crop_type = crop_type.lower()
-        if crop_type not in ['square', 'random', 'aspect_ratio']:
-            raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
-        if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
+        if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
             raise ValueError(
-                'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')
-
+                f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
+            )
+        if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
+                                                                       isinstance(resize_size[0], int)):
+            raise ValueError(
+                'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
     # Handle ``None`` kwargs
     if streaming_kwargs is None:
         streaming_kwargs = {}
@@ -247,6 +259,9 @@
         crop = RandomCropSquare(resize_size)
     elif crop_type == 'aspect_ratio':
         crop = RandomCropAspectRatioTransorm(resize_size, ar_bucket_boundaries)  # type: ignore
+    elif crop_type == 'bucketed_aspect_ratio':
+        assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
+        crop = RandomCropBucketedAspectRatioTransorm(resize_size)  # type: ignore
     else:
         crop = None
 
@@ -265,6 +280,7 @@
         transform=transform,
         image_key=image_key,
         caption_key=caption_key,
+        aspect_ratio_bucket_key=aspect_ratio_bucket_key,
         batch_size=batch_size,
         sdxl_conditioning=sdxl_conditioning,
         zero_dropped_captions=zero_dropped_captions,
diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py
index e9700b9b..1c73f648 100644
--- a/diffusion/datasets/laion/transforms.py
+++ b/diffusion/datasets/laion/transforms.py
@@ -3,6 +3,7 @@
 
 """Transforms for the training and eval dataset."""
 
+import math
 from typing import Optional, Tuple
 
 import torch
@@ -111,3 +112,49 @@ def __call__(self, img):
         c_top, c_left, height, width = transforms.RandomCrop.get_params(img, output_size=(target_height, target_width))
         img = crop(img, c_top, c_left, height, width)
         return img, c_top, c_left
+
+
+class RandomCropBucketedAspectRatioTransorm:
+    """Assigns an image to an arbitrary set of aspect ratio buckets, then resizes and crops to fit into the bucket.
+
+    This transform requires the desired aspect ratio bucket to be specified manually in the call to the transform.
+
+    Args:
+        resize_size (Tuple[Tuple[int, int], ...]): A tuple of 2-tuples of integers representing the aspect ratio buckets.
+            The format is ((height_bucket1, width_bucket1), (height_bucket2, width_bucket2), ...).
+    """
+
+    def __init__(
+        self,
+        resize_size: Tuple[Tuple[int, int], ...],
+    ):
+        self.height_buckets = torch.tensor([size[0] for size in resize_size])
+        self.width_buckets = torch.tensor([size[1] for size in resize_size])
+        self.aspect_ratio_buckets = self.height_buckets / self.width_buckets
+        self.log_aspect_ratio_buckets = torch.log(self.aspect_ratio_buckets)
+
+    def __call__(self, img, aspect_ratio):
+        orig_h, orig_w = img.shape[1:]
+        orig_aspect_ratio = orig_h / orig_w
+        # Figure out target H/W given the input aspect ratio
+        bucket_ind = torch.abs(self.log_aspect_ratio_buckets - math.log(aspect_ratio)).argmin()
+        target_width, target_height = self.width_buckets[bucket_ind].item(), self.height_buckets[bucket_ind].item()
+        target_aspect_ratio = target_height / target_width
+
+        # Determine resize size
+        if orig_aspect_ratio > target_aspect_ratio:
+            # Resize width and crop height
+            w_scale = target_width / orig_w
+            resize_size = (round(w_scale * orig_h), target_width)
+        elif orig_aspect_ratio < target_aspect_ratio:
+            # Resize height and crop width
+            h_scale = target_height / orig_h
+            resize_size = (target_height, round(h_scale * orig_w))
+        else:
+            resize_size = (target_height, target_width)
+        img = transforms.functional.resize(img, resize_size, antialias=True)
+
+        # Crop based on aspect ratio
+        c_top, c_left, height, width = transforms.RandomCrop.get_params(img, output_size=(target_height, target_width))
+        img = crop(img, c_top, c_left, height, width)
+        return img, c_top, c_left
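For reference, a minimal sketch of how the new crop type could be driven from the dataloader builder. It is illustrative only: the remote path, bucket shapes, and the 'aspect_ratio' key name are hypothetical, and the other required arguments of build_streaming_image_caption_dataloader are omitted.

```python
# Hypothetical usage sketch. The remote path, bucket sizes, and bucket key
# below are illustrative and not taken from this patch series.
from diffusion.datasets.image_caption import build_streaming_image_caption_dataloader

dataloader = build_streaming_image_caption_dataloader(
    remote=['s3://my-bucket/image-caption-mds'],  # hypothetical MDS dataset location
    batch_size=32,
    # With bucketed cropping, resize_size is a tuple of (height, width) buckets.
    resize_size=((512, 512), (448, 576), (576, 448)),
    crop_type='bucketed_aspect_ratio',
    # Each sample must carry its precomputed aspect ratio bucket under this key.
    aspect_ratio_bucket_key='aspect_ratio',
)
```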
From 1150340f05e2395f1b1932f3f9b92636fe5537cf Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Fri, 12 Jul 2024 06:22:10 +0000
Subject: [PATCH 2/6] PIL images don't have shapes

---
 diffusion/datasets/laion/transforms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py
index 1c73f648..7c315417 100644
--- a/diffusion/datasets/laion/transforms.py
+++ b/diffusion/datasets/laion/transforms.py
@@ -134,7 +134,7 @@ def __init__(
         self.log_aspect_ratio_buckets = torch.log(self.aspect_ratio_buckets)
 
     def __call__(self, img, aspect_ratio):
-        orig_h, orig_w = img.shape[1:]
+        orig_h, orig_w = img.size
         orig_aspect_ratio = orig_h / orig_w
         # Figure out target H/W given the input aspect ratio
         bucket_ind = torch.abs(self.log_aspect_ratio_buckets - math.log(aspect_ratio)).argmin()

From 6fabbf51770548b1ef27db1a61b26a5f44d865ff Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Sat, 20 Jul 2024 04:35:06 +0000
Subject: [PATCH 3/6] Fix bug getting h,w

---
 diffusion/datasets/laion/transforms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py
index 7c315417..7f6c8f86 100644
--- a/diffusion/datasets/laion/transforms.py
+++ b/diffusion/datasets/laion/transforms.py
@@ -134,7 +134,7 @@ def __init__(
         self.log_aspect_ratio_buckets = torch.log(self.aspect_ratio_buckets)
 
     def __call__(self, img, aspect_ratio):
-        orig_h, orig_w = img.size
+        orig_w, orig_h = img.size
         orig_aspect_ratio = orig_h / orig_w
         # Figure out target H/W given the input aspect ratio
         bucket_ind = torch.abs(self.log_aspect_ratio_buckets - math.log(aspect_ratio)).argmin()
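The two one-line fixes above address the classic (width, height) vs. (height, width) pitfall: PIL reports Image.size as (width, height), not as a tensor-style shape. A quick sanity-check sketch of the transform after these fixes, with made-up sizes and buckets:

```python
# Sanity-check sketch; all concrete sizes and bucket shapes are illustrative.
from PIL import Image

from diffusion.datasets.laion.transforms import RandomCropBucketedAspectRatioTransorm

transform = RandomCropBucketedAspectRatioTransorm(resize_size=((512, 512), (448, 576), (576, 448)))
img = Image.new('RGB', (640, 480))  # PIL size is (width, height): 640 wide, 480 tall
# The caller passes the sample's aspect ratio as height / width, matching how the buckets are defined.
out, crop_top, crop_left = transform(img, aspect_ratio=480 / 640)
print(out.size)  # (576, 448): the closest bucket in log aspect ratio, reported in PIL's (W, H) order
```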
From 24c59165fae85e63dda0822ecd46ffefb621bfb2 Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Mon, 22 Jul 2024 04:32:16 +0000
Subject: [PATCH 4/6] Adjustable betas

---
 diffusion/models/models.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index eaaee630..e06f977e 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -244,6 +244,8 @@ def stable_diffusion_xl(
     latent_mean: Union[float, Tuple, str] = 0.0,
     latent_std: Union[float, Tuple, str] = 7.67754318618,
     beta_schedule: str = 'scaled_linear',
+    beta_start: float = 0.00085,
+    beta_end: float = 0.012,
     zero_terminal_snr: bool = False,
     use_karras_sigmas: bool = False,
     offset_noise: Optional[float] = None,
@@ -291,6 +293,8 @@ def stable_diffusion_xl(
             checkpoint. Defaults to `1/0.13025`.
         beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or
             'squaredcos_cap_v2'. Default: `scaled_linear`.
+        beta_start (float): The starting beta value. Default: `0.00085`.
+        beta_end (float): The ending beta value. Default: `0.012`.
         zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`.
         use_karras_sigmas (bool): Whether to use the Karras sigmas for the diffusion process noise. Default: `False`.
         offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
@@ -412,8 +416,8 @@ def stable_diffusion_xl(
 
     # Make the noise schedulers
     noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
-                                    beta_start=0.00085,
-                                    beta_end=0.012,
+                                    beta_start=beta_start,
+                                    beta_end=beta_end,
                                     beta_schedule=beta_schedule,
                                     trained_betas=None,
                                     variance_type='fixed_small',
@@ -425,18 +429,18 @@ def stable_diffusion_xl(
                                     rescale_betas_zero_snr=zero_terminal_snr)
     if beta_schedule == 'squaredcos_cap_v2':
         inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
-                                                  beta_start=0.00085,
-                                                  beta_end=0.012,
+                                                  beta_start=beta_start,
+                                                  beta_end=beta_end,
                                                   beta_schedule=beta_schedule,
                                                   trained_betas=None,
                                                   clip_sample=False,
                                                   set_alpha_to_one=False,
                                                   prediction_type=prediction_type,
                                                   rescale_betas_zero_snr=zero_terminal_snr)
     else:
         inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000,
-                                                           beta_start=0.00085,
-                                                           beta_end=0.012,
+                                                           beta_start=beta_start,
+                                                           beta_end=beta_end,
                                                            beta_schedule=beta_schedule,
                                                            trained_betas=None,
                                                            prediction_type=prediction_type,
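A sketch of how the new arguments might be exercised; the values shown simply restate the previous hard-coded defaults, and it assumes the remaining stable_diffusion_xl arguments have workable defaults:

```python
# Hypothetical call: beta_start/beta_end override what used to be hard-coded
# in the three schedulers; all other factory arguments are left at their defaults.
from diffusion.models.models import stable_diffusion_xl

model = stable_diffusion_xl(
    beta_schedule='scaled_linear',
    beta_start=0.00085,  # the old fixed value
    beta_end=0.012,      # the old fixed value
)
```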
From 516b098830a9c2433db8ffa43e089f2f81b2e521 Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Wed, 21 Aug 2024 05:32:06 +0000
Subject: [PATCH 5/6] Need an arg

---
 diffusion/generate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diffusion/generate.py b/diffusion/generate.py
index b6b766ad..f03b3190 100644
--- a/diffusion/generate.py
+++ b/diffusion/generate.py
@@ -26,7 +26,7 @@ def generate(config: DictConfig) -> None:
         config (DictConfig): Configuration composed by Hydra
     """
     reproducibility.seed_all(config.seed)
-    device = get_device()  # type: ignore
+    device = get_device(None)  # type: ignore
     dist.initialize_dist(device, config.dist_timeout)
 
     # The model to evaluate
From b82ba31ae63a2864f75975a84711f84af6432a34 Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Mon, 26 Aug 2024 17:21:55 +0000
Subject: [PATCH 6/6] Load state dict on cpu

---
 diffusion/evaluation/generate_images.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/diffusion/evaluation/generate_images.py b/diffusion/evaluation/generate_images.py
index 3ac40c9c..ebb8af28 100644
--- a/diffusion/evaluation/generate_images.py
+++ b/diffusion/evaluation/generate_images.py
@@ -104,7 +104,7 @@ def __init__(self,
             get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
         with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
             # Load the model
-            state_dict = torch.load(self.local_checkpoint_path)
+            state_dict = torch.load(self.local_checkpoint_path, map_location='cpu')
             for key in list(state_dict['state']['model'].keys()):
                 if 'val_metrics.' in key:
                     del state_dict['state']['model'][key]
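The map_location change follows the standard PyTorch pattern for loading a checkpoint saved on GPU onto a machine that may not have one, and for keeping GPU memory flat during deserialization. A generic illustration, with a made-up filename:

```python
import torch

# map_location='cpu' deserializes every tensor to host memory regardless of
# the device it was saved from; 'checkpoint.pt' is an illustrative path.
state_dict = torch.load('checkpoint.pt', map_location='cpu')
```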