From 3459613338372c212e431019229af03b45e6e777 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 28 Aug 2024 23:19:16 +0000 Subject: [PATCH 01/24] Initial model class --- diffusion/models/t5_diffusion.py | 612 +++++++++++++++++++++++++++++++ 1 file changed, 612 insertions(+) create mode 100644 diffusion/models/t5_diffusion.py diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py new file mode 100644 index 00000000..bf5390ab --- /dev/null +++ b/diffusion/models/t5_diffusion.py @@ -0,0 +1,612 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Diffusion models.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from composer.devices import DeviceGPU +from composer.models import ComposerModel +from composer.utils import dist +from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel +from scipy.stats import qmc +from torchmetrics import MeanSquaredError +from tqdm.auto import tqdm +from transformers import PretrainedConfig + +from diffusion.models.autoencoder import AutoEncoder, load_autoencoder +from diffusion.models.layers import zero_module +from diffusion.models.models import _parse_latent_statistics + +try: + import xformers # type: ignore + del xformers + is_xformers_installed = True +except: + is_xformers_installed = False + + +class DiffusionV1(ComposerModel): + """Stable Diffusion ComposerModel. + + This is a Latent Diffusion model conditioned on text prompts that are run through + a pre-trained CLIP or LLM model. The CLIP outputs are then passed to as an + additional input to our Unet during training and can later be used to guide + the image generation process. + + Args: + unet (torch.nn.Module): HuggingFace conditional unet, must accept a + (B, C, H, W) input, (B,) timestep array of noise timesteps, + and (B, 77, 768) text conditioning vectors. + vae (torch.nn.Module): HuggingFace or compatible vae. + must support `.encode()` and `decode()` functions. + noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers + noise scheduler. Used during the forward diffusion process (training). + inference_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers + noise scheduler. Used during the backward diffusion process (inference). + loss_fn (torch.nn.Module): torch loss function. Default: `F.mse_loss`. + prediction_type (str): The type of prediction to use. Must be one of 'sample', + 'epsilon', or 'v_prediction'. Default: `epsilon`. + latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to + . Default: ``(0.0,) * 4``. + latent_std (Optional[tuple[float]]): The standard deviations of the latent space. Default: ``(1/0.13025,)*4``. + downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`. + train_metrics (list): List of torchmetrics to calculate during training. + Default: `None`. + val_metrics (list): List of torchmetrics to calculate during validation. + Default: `None`. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. + val_seed (int): Seed to use for generating eval images. Default: `1138`. + fsdp (bool): whether to use FSDP, Default: `False`. 
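+
+    Example:
+        A minimal construction sketch (not a drop-in recipe): the UNet, VAE, and the two
+        schedulers are assumed to be built elsewhere, e.g. by ``build_diffusion_v1`` below,
+        and ``batch`` is assumed to carry the image tensor plus precomputed T5/CLIP latents,
+        attention masks, pooled CLIP embeddings, and SDXL size/crop conditioning::
+
+            model = DiffusionV1(unet=unet, vae=vae, noise_scheduler=train_scheduler,
+                                inference_noise_scheduler=inference_scheduler,
+                                prediction_type='epsilon')
+            prediction, targets, timesteps = model(batch)
+            loss = model.loss((prediction, targets, timesteps), batch)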
+ """ + + def __init__( + self, + unet, + vae, + noise_scheduler, + inference_noise_scheduler, + loss_fn=F.mse_loss, + prediction_type: str = 'epsilon', + latent_mean: Tuple[float] = (0.0,) * 4, + latent_std: Tuple[float] = (1 / 0.13025,) * 4, + downsample_factor: int = 8, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + text_embed_dim: int = 4096, + fsdp: bool = False, + ): + super().__init__() + self.unet = unet + self.vae = vae + self.noise_scheduler = noise_scheduler + self.loss_fn = loss_fn + self.prediction_type = prediction_type.lower() + if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']: + raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}') + self.downsample_factor = downsample_factor + self.quasirandomness = quasirandomness + self.train_seed = train_seed + self.val_seed = val_seed + self.latent_mean = latent_mean + self.latent_std = latent_std + self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1) + self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1) + self.train_metrics = train_metrics if train_metrics is not None else [MeanSquaredError()] + self.val_metrics = val_metrics if val_metrics is not None else [MeanSquaredError()] + self.inference_scheduler = inference_noise_scheduler + # freeze VAE during diffusion training + self.vae.requires_grad_(False) + self.vae = self.vae.half() + if fsdp: + # only wrap models we are training + self.vae._fsdp_wrap = False + self.unet._fsdp_wrap = True + + # Optional rng generator + self.rng_generator: Optional[torch.Generator] = None + if self.quasirandomness: + self.sobol = qmc.Sobol(d=1, scramble=True, seed=self.train_seed) + + self.clip_proj = nn.Linear(768, text_embed_dim) + self.t5_proj = nn.Linear(4096, text_embed_dim) + + def _apply(self, fn): + super(DiffusionV1, self)._apply(fn) + self.latent_mean = fn(self.latent_mean) + self.latent_std = fn(self.latent_std) + return self + + def _generate_timesteps(self, latents: torch.Tensor): + if self.quasirandomness: + # Generate a quasirandom sequence of timesteps equal to the global batch size + global_batch_size = latents.shape[0] * dist.get_world_size() + sampled_fractions = torch.tensor(self.sobol.random(global_batch_size), device=latents.device) + timesteps = (len(self.noise_scheduler) * sampled_fractions).squeeze() + timesteps = torch.floor(timesteps).long() + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * latents.shape[0] + timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) + else: + timesteps = torch.randint(0, + len(self.noise_scheduler), (latents.shape[0],), + device=latents.device, + generator=self.rng_generator) + return timesteps + + def set_rng_generator(self, rng_generator: torch.Generator): + """Sets the rng generator for the model.""" + self.rng_generator = rng_generator + + def forward(self, batch): + latents, text_embeds, text_pooled_embeds, encoder_attention_mask = None, None, None, None + + inputs = batch['image'] + with torch.cuda.amp.autocast(enabled=False): + latents = self.vae.encode(inputs.half())['latent_dist'].sample().data + latents = (latents - self.latent_mean) / self.latent_std # scale latents + + t5_embed = self.t5_proj(batch['T5_LATENTS']) + clip_embed = self.clip_proj(batch['CLIP_LATENTS']) + text_embeds = torch.cat([t5_embed, clip_embed], dim=1) + text_pooled_embeds = batch['CLIP_POOLED'] 
+ + encoder_attention_mask = torch.cat([batch['T5_ATTENTION_MASK'], batch['CLIP_ATTENTION_MASK']], dim=1) + + # Sample the diffusion timesteps + timesteps = self._generate_timesteps(latents) + # Add noise to the inputs (forward diffusion) + noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator) + noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + # Generate the targets + if self.prediction_type == 'epsilon': + targets = noise + elif self.prediction_type == 'sample': + targets = latents + elif self.prediction_type == 'v_prediction': + targets = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}') + + # Prepare added time ids & embeddings + add_time_ids = torch.cat( + [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1) + added_cond_kwargs = {'text_embeds': text_pooled_embeds, 'time_ids': add_time_ids} + + # Forward through the model + return self.unet(noised_latents, + timesteps, + text_embeds, + encoder_attention_mask=encoder_attention_mask, + added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps + + def loss(self, outputs, batch): + """Loss between unet output and added noise, typically mse.""" + return self.loss_fn(outputs[0], outputs[1]) + + def eval_forward(self, batch, outputs=None): + """For stable diffusion, eval forward computes unet outputs as well as some samples.""" + # Skip this if outputs have already been computed, e.g. during training + if outputs is not None: + return outputs + return self.forward(batch) + + def get_metrics(self, is_train: bool = False): + if is_train: + metrics = self.train_metrics + else: + metrics = self.val_metrics + metrics_dict = {metric.__class__.__name__: metric for metric in metrics} + return metrics_dict + + def update_metric(self, batch, outputs, metric): + metric.update(outputs[0], outputs[1]) + + @torch.no_grad() + def generate( + self, + prompt_embeds: torch.FloatTensor, + pooled_prompt: torch.FloatTensor, + prompt_mask: torch.LongTensor, + neg_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_neg_prompt: Optional[torch.FloatTensor] = None, + neg_prompt_mask: Optional[torch.LongTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 3.0, + rescaled_guidance: Optional[float] = None, + num_images_per_prompt: Optional[int] = 1, + seed: Optional[int] = None, + progress_bar: Optional[bool] = True, + crop_params: Optional[torch.Tensor] = None, + input_size_params: Optional[torch.Tensor] = None, + ): + """Generates image from noise. + + Performs the backward diffusion process, each inference step takes + one forward pass through the unet. + + Args: + prompt (str or List[str]): The prompt or prompts to guide the image generation. + negative_prompt (str or List[str]): The prompt or prompts to guide the + image generation away from. Ignored when not using guidance + (i.e., ignored if guidance_scale is less than 1). + Must be the same length as list of prompts. Default: `None`. + tokenized_prompts (torch.LongTensor): Optionally pass pre-tokenized prompts instead + of string prompts. If SDXL, this will be a tensor of size [B, 2, max_length], + otherwise will be of shape [B, max_length]. Default: `None`. 
+ tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative + prompts instead of string prompts. Default: `None`. + tokenized_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for + pre-tokenized prompts. Default `None`. + tokenized_negative_prompts_pad_mask (torch.LongTensor): Optionall pass padding mask for + pre-tokenized negative prompts. Default `None`. + prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead + of string prompts. If both prompt and prompt_embeds + are passed, prompt_embeds will be used. Default: `None`. + neg_prompt_embeds (torch.FloatTensor): Optionally pass pre-embedded negative + prompts instead of string negative prompts. If both negative_prompt and + negative_prompt_embeds are passed, prompt_embeds will be used. Default: `None`. + height (int, optional): The height in pixels of the generated image. + Default: `self.unet.config.sample_size * 8)`. + width (int, optional): The width in pixels of the generated image. + Default: `self.unet.config.sample_size * 8)`. + num_inference_steps (int): The number of denoising steps. + More denoising steps usually lead to a higher quality image at the expense + of slower inference. Default: `50`. + guidance_scale (float): Guidance scale as defined in + Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation + 2. of Imagen Paper. Guidance scale is enabled by setting guidance_scale > 1. + Higher guidance scale encourages to generate images that are closely linked + to the text prompt, usually at the expense of lower image quality. + Default: `3.0`. + rescaled_guidance (float, optional): Rescaled guidance scale. If not specified, rescaled guidance will + not be used. Default: `None`. + num_images_per_prompt (int): The number of images to generate per prompt. + Default: `1`. + progress_bar (bool): Whether to use the tqdm progress bar during generation. + Default: `True`. + seed (int): Random seed to use for generation. Set a seed for reproducible generation. + Default: `None`. + crop_params (torch.FloatTensor of size [Bx2], optional): Crop parameters to use + when generating images with SDXL. Default: `None`. + input_size_params (torch.FloatTensor of size [Bx2], optional): Size parameters + (representing original size of input image) to use when generating images with SDXL. + Default: `None`. + """ + # TODO: do checks + # if prompt_embeds.shape[:2] == prompt_mask.shape[:2]: + # raise ValueError(' ') + + # Check all parts of negative prompts exist and are equal length + # if neg_prompt_embeds is not None or neg_prompt_mask is not None or pooled_neg_prompt is not None: + + # if negative_negative_embedlen(prompt_embeds) != len(negative_prompt_embeds): + # raise ValueError('len(prompts) and len(negative_prompts) must be the same. 
\ + # A negative prompt must be provided for each given prompt.') + + # Create rng for the generation + device = self.vae.device + rng_generator = torch.Generator(device=device) + if seed: + rng_generator = rng_generator.manual_seed(seed) # type: ignore + + if height is None: + height = self.unet.config.sample_size * self.downsample_factor + if width is None: + width = self.unet.config.sample_size * self.downsample_factor + + do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore + + text_embeddings = _duplicate_tensor(prompt_embeds, num_images_per_prompt) + pooled_embeddings = _duplicate_tensor(pooled_prompt, num_images_per_prompt) + encoder_attn_mask = _duplicate_tensor(prompt_mask, num_images_per_prompt) + + batch_size = len(prompt_embeds) # len prompts * num_images_per_prompt + # classifier free guidance + negative prompts + # negative prompt is given in place of the unconditional input in classifier free guidance + if do_classifier_free_guidance: + if not neg_prompt_embeds: + # Negative prompt is empty and we want to zero it out + neg_prompt_embeds = torch.zeros_like(text_embeddings) + pooled_neg_prompt = torch.zeros_like(pooled_embeddings) + neg_prompt_mask = torch.zeros_like(encoder_attn_mask) + else: + neg_prompt_embeds = _duplicate_tensor(neg_prompt_embeds, num_images_per_prompt) + pooled_neg_prompt = _duplicate_tensor(pooled_neg_prompt, num_images_per_prompt) + neg_prompt_mask = _duplicate_tensor(neg_prompt_mask, num_images_per_prompt) + + # concat uncond + prompt + text_embeddings = torch.cat([neg_prompt_embeds, text_embeddings]) + pooled_embeddings = torch.cat([pooled_neg_prompt, pooled_embeddings]) + encoder_attn_mask = torch.cat([neg_prompt_mask, encoder_attn_mask]) + + # prepare for diffusion generation process + latents = torch.randn( + (batch_size, self.unet.config.in_channels, height // self.downsample_factor, + width // self.downsample_factor), + device=device, + dtype=self.unet.dtype, + generator=rng_generator, + ) + + self.inference_scheduler.set_timesteps(num_inference_steps) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.inference_scheduler.init_noise_sigma + + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + + if crop_params is None: + crop_params = torch.zeros((batch_size, 2), dtype=text_embeddings.dtype) + if input_size_params is None: + input_size_params = torch.tensor([[width, height]] * batch_size, dtype=text_embeddings.dtype) + output_size_params = torch.tensor([[width, height]] * batch_size, dtype=text_embeddings.dtype) + + if do_classifier_free_guidance: + crop_params = torch.cat([crop_params, crop_params]) + input_size_params = torch.cat([input_size_params, input_size_params]) + output_size_params = torch.cat([output_size_params, output_size_params]) + + add_time_ids = torch.cat([input_size_params, crop_params, output_size_params], dim=1).to(device) + added_cond_kwargs = {'text_embeds': pooled_embeddings, 'time_ids': add_time_ids} + + # backward diffusion process + for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): + if do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + + latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t) + # Model prediction + pred = self.unet(latent_model_input, + t, + encoder_hidden_states=text_embeddings, + encoder_attention_mask=encoder_attn_mask, + added_cond_kwargs=added_cond_kwargs).sample + + if 
do_classifier_free_guidance: + # perform guidance. Note this is only techincally correct for prediction_type 'epsilon' + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + # Optionally rescale the classifer free guidance + if rescaled_guidance is not None: + std_pos = torch.std(pred_text, dim=(1, 2, 3), keepdim=True) + std_cfg = torch.std(pred, dim=(1, 2, 3), keepdim=True) + pred_rescaled = pred * (std_pos / std_cfg) + pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample + + # We now use the vae to decode the generated latents back into the image. + # scale and decode the image latents with vae + latents = latents * self.latent_std + self.latent_mean + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + return image.detach() # (batch*num_images_per_prompt, channel, h, w) + + +def _duplicate_tensor(tensor, num_images_per_prompt): + """Duplicate tensor for multiple generations from a single prompt.""" + batch_size, seq_len = tensor.shape[:2] + tensor = tensor.repeat(1, num_images_per_prompt, *[ + 1, + ] * len(tensor.shape[2:])) + return tensor.view(batch_size * num_images_per_prompt, seq_len, *[ + -1, + ] * len(tensor.shape[2:])) + + +def build_diffusion_v1( + unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + prediction_type: str = 'epsilon', + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 7.67754318618, + text_embed_dim: int = 4096, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + fsdp: bool = True, + use_xformers: bool = True, +): + """Stable diffusion 2 training setup + SDXL UNet and VAE. + + Requires batches of matched images and text prompts to train. Generates images from text + prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. + + Args: + unet_model_name (str): Name of the UNet model to load. Defaults to + 'stabilityai/stable-diffusion-xl-base-1.0'. + vae_model_name (str): Name of the VAE model to load. Defaults to + 'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from + 'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + prediction_type (str): The type of prediction to use. Must be one of 'sample', + 'epsilon', or 'v_prediction'. Default: `epsilon`. + latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. 
Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.13025`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. + train_metrics (list, optional): List of metrics to compute during training. If None, defaults to + [MeanSquaredError()]. + val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to + [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. + fsdp (bool): Whether to use FSDP. Defaults to True. + use_xformers (bool): Whether to use xformers for attention. Defaults to True. + """ + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if train_metrics is None: + train_metrics = [MeanSquaredError()] + if val_metrics is None: + val_metrics = [MeanSquaredError()] + + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + downsample_factor = 8 + # Use the pretrained vae + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch.float16) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=torch.float16) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + + # Make the unet + unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0] + + if isinstance(vae, AutoEncoder): + # Adapt the unet config to account for differing number of latent channels if necessary + unet_config['in_channels'] = vae.config['latent_channels'] + unet_config['out_channels'] = vae.config['latent_channels'] + unet_config['cross_attention_dim'] = text_embed_dim + # This config variable is the sum of the text encoder projection dimension (768 for CLIP) and + # the number of additional time embeddings (6) * addition_time_embed_dim (256) + unet_config['projection_class_embeddings_input_dim'] = 2304 + # Init the unet from the config + unet = UNet2DConditionModel(**unet_config) + + # Zero initialization trick + for name, layer in unet.named_modules(): + # Final conv in ResNet blocks + if name.endswith('conv2'): + layer = 
zero_module(layer) + # proj_out in attention blocks + if name.endswith('to_out.0'): + layer = zero_module(layer) + # Last conv block out projection + unet.conv_out = zero_module(unet.conv_out) + + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * unet_config['in_channels'] + if isinstance(latent_std, float): + latent_std = (latent_std,) * unet_config['in_channels'] + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + + # FSDP Wrapping Scheme + if hasattr(unet, 'mid_block') and unet.mid_block is not None: + for attention in unet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in unet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in unet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in unet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + # Make the noise schedulers + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + if beta_schedule == 'squaredcos_cap_v2': + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + rescale_betas_zero_snr=zero_terminal_snr) + else: + inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + prediction_type=prediction_type, + interpolation_type='linear', + use_karras_sigmas=False, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + # Make the composer model + model = DiffusionV1( + unet=unet, + vae=vae, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + downsample_factor=downsample_factor, + train_metrics=train_metrics, + val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, + val_seed=val_seed, + text_embed_dim=text_embed_dim, + fsdp=fsdp, + ) + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + if is_xformers_installed and use_xformers: + model.unet.enable_xformers_memory_efficient_attention() + if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): + model.vae.enable_xformers_memory_efficient_attention() + + return model From 4b5b22750ab700b89519f97c92480a8a602df836 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 29 Aug 2024 05:42:54 +0000 Subject: [PATCH 02/24] Support truncating embeddings --- diffusion/callbacks/log_diffusion_images.py | 12 ++++---- diffusion/models/t5_diffusion.py | 31 +++++++++++++++++---- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index c36b75d9..a9a70b85 100644 --- 
a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -121,8 +121,9 @@ def __init__(self, clip_attention_mask = clip_attention_mask.cpu().to(torch.long) latent_batch['T5_LATENTS'] = t5_latents + latent_batch['T5_ATTENTION_MASK'] = t5_attention_mask latent_batch['CLIP_LATENTS'] = clip_latents - latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1) + latent_batch['CLIP_ATTENTION_MASK'] = clip_attention_mask latent_batch['CLIP_POOLED'] = clip_pooled self.batched_latents.append(latent_batch) @@ -144,11 +145,10 @@ def eval_start(self, state: State, logger: Logger): if self.precomputed_latents: for batch in self.batched_latents: pooled_prompt = batch['CLIP_POOLED'].cuda() - prompt_mask = batch['ATTENTION_MASK'].cuda() - t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda()) - clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda()) - prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1) - + prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch['T5_LATENTS'].cuda(), + batch['CLIP_LATENTS'].cuda(), + batch['T5_ATTENTION_MASK'].cuda(), + batch['CLIP_ATTENTION_MASK'].cuda()) gen_images = model.generate(prompt_embeds=prompt_embeds, pooled_prompt=pooled_prompt, prompt_mask=prompt_mask, diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index bf5390ab..f3eb81f8 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -54,6 +54,7 @@ class DiffusionV1(ComposerModel): . Default: ``(0.0,) * 4``. latent_std (Optional[tuple[float]]): The standard deviations of the latent space. Default: ``(1/0.13025,)*4``. downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`. + max_seq_len (int): The maximum sequence length for the text encoder. Default: `77`. train_metrics (list): List of torchmetrics to calculate during training. Default: `None`. val_metrics (list): List of torchmetrics to calculate during validation. @@ -77,6 +78,7 @@ def __init__( latent_mean: Tuple[float] = (0.0,) * 4, latent_std: Tuple[float] = (1 / 0.13025,) * 4, downsample_factor: int = 8, + max_seq_len: int = 77, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, quasirandomness: bool = False, @@ -94,6 +96,7 @@ def __init__( if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']: raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. 
Got {prediction_type}') self.downsample_factor = downsample_factor + self.max_seq_len = max_seq_len self.quasirandomness = quasirandomness self.train_seed = train_seed self.val_seed = val_seed @@ -147,21 +150,37 @@ def set_rng_generator(self, rng_generator: torch.Generator): """Sets the rng generator for the model.""" self.rng_generator = rng_generator + def prepare_text_embeddings(self, t5_embed, clip_embed, t5_mask, clip_mask): + t5_embed = self.t5_proj(t5_embed) + clip_embed = self.clip_proj(clip_embed) + if t5_embed.shape[1] > self.max_seq_len: + t5_embed = t5_embed[:, :self.max_seq_len] + t5_mask = t5_mask[:, :self.max_seq_len] + if clip_embed.shape[1] > self.max_seq_len: + clip_embed = clip_embed[:, :self.max_seq_len] + clip_mask = clip_mask[:, :self.max_seq_len] + # Concatenate the text embeddings + text_embeds = torch.cat([t5_embed, clip_embed], dim=1) + encoder_attention_mask = torch.cat([t5_mask, clip_mask], dim=1) + return text_embeds, encoder_attention_mask + def forward(self, batch): latents, text_embeds, text_pooled_embeds, encoder_attention_mask = None, None, None, None + # Encode the images with the autoencoder encoder inputs = batch['image'] with torch.cuda.amp.autocast(enabled=False): latents = self.vae.encode(inputs.half())['latent_dist'].sample().data latents = (latents - self.latent_mean) / self.latent_std # scale latents - t5_embed = self.t5_proj(batch['T5_LATENTS']) - clip_embed = self.clip_proj(batch['CLIP_LATENTS']) - text_embeds = torch.cat([t5_embed, clip_embed], dim=1) + # Text embeddings are shape (B, seq_len, emb_dim), optionally truncate to a max length + t5_embed = batch['T5_LATENTS'] + t5_mask = batch['T5_ATTENTION_MASK'] + clip_embed = batch['CLIP_LATENTS'] + clip_mask = batch['CLIP_ATTENTION_MASK'] text_pooled_embeds = batch['CLIP_POOLED'] - - encoder_attention_mask = torch.cat([batch['T5_ATTENTION_MASK'], batch['CLIP_ATTENTION_MASK']], dim=1) - + text_embeds, encoder_attention_mask = self.prepare_text_embeddings(t5_embed, clip_embed, t5_mask, clip_mask) + # Sample the diffusion timesteps timesteps = self._generate_timesteps(latents) # Add noise to the inputs (forward diffusion) From 39ed5f72ff43e48e9189edecc7f5247d0fe9c4a9 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 29 Aug 2024 06:03:02 +0000 Subject: [PATCH 03/24] Truncate before embedding, and add position embeds --- diffusion/models/t5_diffusion.py | 58 +++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index f3eb81f8..e3f440fa 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -3,6 +3,7 @@ """Diffusion models.""" +import math from typing import List, Optional, Tuple, Union import torch @@ -120,8 +121,16 @@ def __init__( if self.quasirandomness: self.sobol = qmc.Sobol(d=1, scramble=True, seed=self.train_seed) + # Projection layers for the text embeddings self.clip_proj = nn.Linear(768, text_embed_dim) self.t5_proj = nn.Linear(4096, text_embed_dim) + # Learnable position embeddings for the conitioning sequences + t5_position_embeddings = torch.randn(self.max_seq_len, text_embed_dim) + t5_position_embeddings /= math.sqrt(text_embed_dim) + self.t5_position_embedding = torch.nn.Parameter(t5_position_embeddings, requires_grad=True) + clip_position_embeddings = torch.randn(self.max_seq_len, text_embed_dim) + clip_position_embeddings /= math.sqrt(text_embed_dim) + self.clip_position_embedding = torch.nn.Parameter(clip_position_embeddings, 
requires_grad=True) def _apply(self, fn): super(DiffusionV1, self)._apply(fn) @@ -150,28 +159,51 @@ def set_rng_generator(self, rng_generator: torch.Generator): """Sets the rng generator for the model.""" self.rng_generator = rng_generator + def encode_images(self, inputs): + with torch.cuda.amp.autocast(enabled=False): + latents = self.vae.encode(inputs.half())['latent_dist'].sample().data + latents = (latents - self.latent_mean) / self.latent_std # scale latents + return latents + def prepare_text_embeddings(self, t5_embed, clip_embed, t5_mask, clip_mask): - t5_embed = self.t5_proj(t5_embed) - clip_embed = self.clip_proj(clip_embed) if t5_embed.shape[1] > self.max_seq_len: t5_embed = t5_embed[:, :self.max_seq_len] t5_mask = t5_mask[:, :self.max_seq_len] if clip_embed.shape[1] > self.max_seq_len: clip_embed = clip_embed[:, :self.max_seq_len] clip_mask = clip_mask[:, :self.max_seq_len] + t5_embed = self.t5_proj(t5_embed) + clip_embed = self.clip_proj(clip_embed) + # Add position embeddings + t5_embed = 0.707 * t5_embed + 0.707 * self.t5_position_embedding[:t5_embed.shape[1]].unsqueeze(0) + clip_embed = 0.707 * clip_embed + 0.707 * self.clip_position_embedding[:clip_embed.shape[1]].unsqueeze(0) # Concatenate the text embeddings text_embeds = torch.cat([t5_embed, clip_embed], dim=1) encoder_attention_mask = torch.cat([t5_mask, clip_mask], dim=1) return text_embeds, encoder_attention_mask + def diffusion_forward(self, latents, timesteps): + # Add noise to the inputs (forward diffusion) + noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator) + noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + # Generate the targets + if self.prediction_type == 'epsilon': + targets = noise + elif self.prediction_type == 'sample': + targets = latents + elif self.prediction_type == 'v_prediction': + targets = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}') + return noised_latents, targets + def forward(self, batch): latents, text_embeds, text_pooled_embeds, encoder_attention_mask = None, None, None, None # Encode the images with the autoencoder encoder inputs = batch['image'] - with torch.cuda.amp.autocast(enabled=False): - latents = self.vae.encode(inputs.half())['latent_dist'].sample().data - latents = (latents - self.latent_mean) / self.latent_std # scale latents + latents = self.encode_images(inputs) # Text embeddings are shape (B, seq_len, emb_dim), optionally truncate to a max length t5_embed = batch['T5_LATENTS'] @@ -180,22 +212,10 @@ def forward(self, batch): clip_mask = batch['CLIP_ATTENTION_MASK'] text_pooled_embeds = batch['CLIP_POOLED'] text_embeds, encoder_attention_mask = self.prepare_text_embeddings(t5_embed, clip_embed, t5_mask, clip_mask) - + # Sample the diffusion timesteps timesteps = self._generate_timesteps(latents) - # Add noise to the inputs (forward diffusion) - noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator) - noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) - # Generate the targets - if self.prediction_type == 'epsilon': - targets = noise - elif self.prediction_type == 'sample': - targets = latents - elif self.prediction_type == 'v_prediction': - targets = self.noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError( - f'prediction type must be one of sample, epsilon, or v_prediction. 
Got {self.prediction_type}') + noised_latents, targets = self.diffusion_forward(latents, timesteps) # Prepare added time ids & embeddings add_time_ids = torch.cat( From 8ef76e62aef4cf16a023d5352cc94904689ec1cd Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 29 Aug 2024 16:40:36 +0000 Subject: [PATCH 04/24] Don't need an arg for the loss --- diffusion/models/t5_diffusion.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index e3f440fa..23b4768e 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -48,7 +48,6 @@ class DiffusionV1(ComposerModel): noise scheduler. Used during the forward diffusion process (training). inference_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers noise scheduler. Used during the backward diffusion process (inference). - loss_fn (torch.nn.Module): torch loss function. Default: `F.mse_loss`. prediction_type (str): The type of prediction to use. Must be one of 'sample', 'epsilon', or 'v_prediction'. Default: `epsilon`. latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to @@ -74,7 +73,6 @@ def __init__( vae, noise_scheduler, inference_noise_scheduler, - loss_fn=F.mse_loss, prediction_type: str = 'epsilon', latent_mean: Tuple[float] = (0.0,) * 4, latent_std: Tuple[float] = (1 / 0.13025,) * 4, @@ -92,7 +90,6 @@ def __init__( self.unet = unet self.vae = vae self.noise_scheduler = noise_scheduler - self.loss_fn = loss_fn self.prediction_type = prediction_type.lower() if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']: raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}') @@ -165,6 +162,12 @@ def encode_images(self, inputs): latents = (latents - self.latent_mean) / self.latent_std # scale latents return latents + def decode_latents (self, latents): + latents = latents * self.latent_std + self.latent_mean + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + return self.vae.decode(latents) + def prepare_text_embeddings(self, t5_embed, clip_embed, t5_mask, clip_mask): if t5_embed.shape[1] > self.max_seq_len: t5_embed = t5_embed[:, :self.max_seq_len] @@ -231,7 +234,8 @@ def forward(self, batch): def loss(self, outputs, batch): """Loss between unet output and added noise, typically mse.""" - return self.loss_fn(outputs[0], outputs[1]) + loss = F.mse_loss(outputs[0], outputs[1]) + return loss def eval_forward(self, batch, outputs=None): """For stable diffusion, eval forward computes unet outputs as well as some samples.""" @@ -431,9 +435,7 @@ def generate( # We now use the vae to decode the generated latents back into the image. 
# scale and decode the image latents with vae - latents = latents * self.latent_std + self.latent_mean - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) + image = self.decode_latents(latents) return image.detach() # (batch*num_images_per_prompt, channel, h, w) From 402b1a5909a9c09facc2157adb2b7f6bd14fa48d Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 29 Aug 2024 17:10:20 +0000 Subject: [PATCH 05/24] Prep for inference --- diffusion/models/t5_diffusion.py | 46 +++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index 23b4768e..faa15e4c 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -48,6 +48,10 @@ class DiffusionV1(ComposerModel): noise scheduler. Used during the forward diffusion process (training). inference_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers noise scheduler. Used during the backward diffusion process (inference). + t5_tokenizer (Optional): Tokenizer for T5. Should only be specified during inference. Default: `None`. + t5_encoder (Optional): T5 text encoder. Should only be specified during inference. Default: `None`. + clip_tokenizer (Optional): Tokenizer for CLIP. Should only be specified during inference. Default: `None`. + clip_encoder (Optional): CLIP text encoder. Should only be specified during inference. Default: `None`. prediction_type (str): The type of prediction to use. Must be one of 'sample', 'epsilon', or 'v_prediction'. Default: `epsilon`. latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to @@ -73,6 +77,10 @@ def __init__( vae, noise_scheduler, inference_noise_scheduler, + t5_tokenizer: Optional = None, + t5_encoder: Optional = None, + clip_tokenizer: Optional = None, + clip_encoder: Optional = None, prediction_type: str = 'epsilon', latent_mean: Tuple[float] = (0.0,) * 4, latent_std: Tuple[float] = (1 / 0.13025,) * 4, @@ -89,6 +97,10 @@ def __init__( super().__init__() self.unet = unet self.vae = vae + self.t5_tokenizer = t5_tokenizer + self.t5_encoder = t5_encoder + self.clip_tokenizer = clip_tokenizer + self.clip_encoder = clip_encoder self.noise_scheduler = noise_scheduler self.prediction_type = prediction_type.lower() if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']: @@ -136,20 +148,28 @@ def _apply(self, fn): return self def _generate_timesteps(self, latents: torch.Tensor): - if self.quasirandomness: - # Generate a quasirandom sequence of timesteps equal to the global batch size + if not self.model.training: + # Sample equally spaced timesteps across all devices global_batch_size = latents.shape[0] * dist.get_world_size() - sampled_fractions = torch.tensor(self.sobol.random(global_batch_size), device=latents.device) - timesteps = (len(self.noise_scheduler) * sampled_fractions).squeeze() - timesteps = torch.floor(timesteps).long() + global_timesteps = torch.linspace(0, len(self.noise_scheduler), global_batch_size) # Get this device's subset of all the timesteps idx_offset = dist.get_global_rank() * latents.shape[0] - timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) + timesteps = global_timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) else: - timesteps = torch.randint(0, - len(self.noise_scheduler), (latents.shape[0],), - device=latents.device, - generator=self.rng_generator) + if self.quasirandomness: + # Generate a quasirandom 
sequence of timesteps equal to the global batch size + global_batch_size = latents.shape[0] * dist.get_world_size() + sampled_fractions = torch.tensor(self.sobol.random(global_batch_size), device=latents.device) + timesteps = (len(self.noise_scheduler) * sampled_fractions).squeeze() + timesteps = torch.floor(timesteps).long() + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * latents.shape[0] + timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) + else: + timesteps = torch.randint(0, + len(self.noise_scheduler), (latents.shape[0],), + device=latents.device, + generator=self.rng_generator) return timesteps def set_rng_generator(self, rng_generator: torch.Generator): @@ -162,12 +182,12 @@ def encode_images(self, inputs): latents = (latents - self.latent_mean) / self.latent_std # scale latents return latents - def decode_latents (self, latents): + def decode_latents(self, latents): latents = latents * self.latent_std + self.latent_mean image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - return self.vae.decode(latents) - + return image + def prepare_text_embeddings(self, t5_embed, clip_embed, t5_mask, clip_mask): if t5_embed.shape[1] > self.max_seq_len: t5_embed = t5_embed[:, :self.max_seq_len] From 19cc8fbbf3f7b54c9b486bd0236751ab64997e00 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Fri, 30 Aug 2024 05:25:34 +0000 Subject: [PATCH 06/24] Prep for string inputs --- diffusion/models/models.py | 232 +++++++++++++++++++++- diffusion/models/t5_diffusion.py | 318 +++++++------------------------ 2 files changed, 303 insertions(+), 247 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index a6d782d8..671791bb 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -12,7 +12,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel from peft import LoraConfig from torchmetrics import MeanSquaredError -from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, PretrainedConfig from diffusion.models.autoencoder import (AutoEncoder, AutoEncoderLoss, ComposerAutoEncoder, ComposerDiffusersAutoEncoder, load_autoencoder) @@ -20,6 +20,7 @@ from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT +from diffusion.models.t5_diffusion import DiffusionV1 from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer from diffusion.models.transformer import DiffusionTransformer from diffusion.schedulers.schedulers import ContinuousTimeScheduler @@ -580,6 +581,235 @@ def stable_diffusion_xl( return model +def build_diffusion_v1( + unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + include_text_encoders: bool = False, + cache_dir: str = '/tmp/hf_files', + prediction_type: str = 'epsilon', + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 7.67754318618, + text_embed_dim: int = 4096, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: 
bool = False, + train_seed: int = 42, + val_seed: int = 1138, + fsdp: bool = True, + use_xformers: bool = True, +): + """Stable diffusion 2 training setup + SDXL UNet and VAE. + + Requires batches of matched images and text prompts to train. Generates images from text + prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. + + Args: + unet_model_name (str): Name of the UNet model to load. Defaults to + 'stabilityai/stable-diffusion-xl-base-1.0'. + vae_model_name (str): Name of the VAE model to load. Defaults to + 'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from + 'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + include_text_encoders (bool): Whether to include text encoders in the model. Should only do this for running + inference. Default: `False`. + cache_dir (str): Directory to cache the model in if using `include_text_encoders`. Default: `'/tmp/hf_files'`. + prediction_type (str): The type of prediction to use. Must be one of 'sample', + 'epsilon', or 'v_prediction'. Default: `epsilon`. + latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.13025`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. + train_metrics (list, optional): List of metrics to compute during training. If None, defaults to + [MeanSquaredError()]. + val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to + [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. + val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. + fsdp (bool): Whether to use FSDP. Defaults to True. + use_xformers (bool): Whether to use xformers for attention. Defaults to True. 
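+
+    Example:
+        A minimal usage sketch; it assumes a CUDA device and that the pretrained UNet/VAE
+        configs and (for ``include_text_encoders=True``) the T5/CLIP weights are already
+        present in the local cache::
+
+            model = build_diffusion_v1(include_text_encoders=True, fsdp=False)
+            images = model.generate(prompt=['a photo of a corgi'],
+                                    negative_prompt=[''],
+                                    num_inference_steps=30,
+                                    guidance_scale=5.0,
+                                    seed=42)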
+ """ + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if train_metrics is None: + train_metrics = [MeanSquaredError()] + if val_metrics is None: + val_metrics = [MeanSquaredError()] + + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + downsample_factor = 8 + # Use the pretrained vae + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch.float16) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=torch.float16) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + + # Make the unet + unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0] + + if isinstance(vae, AutoEncoder): + # Adapt the unet config to account for differing number of latent channels if necessary + unet_config['in_channels'] = vae.config['latent_channels'] + unet_config['out_channels'] = vae.config['latent_channels'] + unet_config['cross_attention_dim'] = text_embed_dim + # This config variable is the sum of the text encoder projection dimension (768 for CLIP) and + # the number of additional time embeddings (6) * addition_time_embed_dim (256) + unet_config['projection_class_embeddings_input_dim'] = 2304 + # Init the unet from the config + unet = UNet2DConditionModel(**unet_config) + + # Zero initialization trick + for name, layer in unet.named_modules(): + # Final conv in ResNet blocks + if name.endswith('conv2'): + layer = zero_module(layer) + # proj_out in attention blocks + if name.endswith('to_out.0'): + layer = zero_module(layer) + # Last conv block out projection + unet.conv_out = zero_module(unet.conv_out) + + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * unet_config['in_channels'] + if isinstance(latent_std, float): + latent_std = (latent_std,) * unet_config['in_channels'] + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + + # FSDP Wrapping Scheme + if hasattr(unet, 'mid_block') and unet.mid_block is not None: + for attention in unet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in unet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in unet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in unet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in 
block.resnets: + resnet._fsdp_wrap = True + + # Make the noise schedulers + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + if beta_schedule == 'squaredcos_cap_v2': + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + rescale_betas_zero_snr=zero_terminal_snr) + else: + inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + prediction_type=prediction_type, + interpolation_type='linear', + use_karras_sigmas=False, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + # Optionally load the tokenizers and text encoders + t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None + if include_text_encoders: + t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True) + clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', + subfolder='tokenizer', + cache_dir=cache_dir, + local_files_only=True) + t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl', + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + local_files_only=True).encoder.eval() + clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', + subfolder='text_encoder', + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + local_files_only=True).cuda().eval() + # Make the composer model + model = DiffusionV1( + unet=unet, + vae=vae, + t5_tokenizer=t5_tokenizer, + t5_encoder=t5_encoder, + clip_tokenizer=clip_tokenizer, + clip_encoder=clip_encoder, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + downsample_factor=downsample_factor, + train_metrics=train_metrics, + val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, + val_seed=val_seed, + text_embed_dim=text_embed_dim, + fsdp=fsdp, + ) + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + if is_xformers_installed and use_xformers: + model.unet.enable_xformers_memory_efficient_attention() + if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): + model.vae.enable_xformers_memory_efficient_attention() + + return model + + def text_to_image_transformer( tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer'), text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder'), diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index faa15e4c..9a2e38df 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -4,23 +4,17 @@ """Diffusion models.""" import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F -from composer.devices import DeviceGPU from 
composer.models import ComposerModel from composer.utils import dist -from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel from scipy.stats import qmc from torchmetrics import MeanSquaredError from tqdm.auto import tqdm -from transformers import PretrainedConfig - -from diffusion.models.autoencoder import AutoEncoder, load_autoencoder -from diffusion.models.layers import zero_module -from diffusion.models.models import _parse_latent_statistics +from transformers import PreTrainedTokenizer try: import xformers # type: ignore @@ -77,10 +71,10 @@ def __init__( vae, noise_scheduler, inference_noise_scheduler, - t5_tokenizer: Optional = None, - t5_encoder: Optional = None, - clip_tokenizer: Optional = None, - clip_encoder: Optional = None, + t5_tokenizer: Optional[PreTrainedTokenizer] = None, + t5_encoder: Optional[torch.nn.Module] = None, + clip_tokenizer: Optional[PreTrainedTokenizer] = None, + clip_encoder: Optional[torch.nn.Module] = None, prediction_type: str = 'epsilon', latent_mean: Tuple[float] = (0.0,) * 4, latent_std: Tuple[float] = (1 / 0.13025,) * 4, @@ -110,8 +104,6 @@ def __init__( self.quasirandomness = quasirandomness self.train_seed = train_seed self.val_seed = val_seed - self.latent_mean = latent_mean - self.latent_std = latent_std self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1) self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1) self.train_metrics = train_metrics if train_metrics is not None else [MeanSquaredError()] @@ -148,7 +140,7 @@ def _apply(self, fn): return self def _generate_timesteps(self, latents: torch.Tensor): - if not self.model.training: + if not self.unet.training: # Sample equally spaced timesteps across all devices global_batch_size = latents.shape[0] * dist.get_world_size() global_timesteps = torch.linspace(0, len(self.noise_scheduler), global_batch_size) @@ -167,9 +159,9 @@ def _generate_timesteps(self, latents: torch.Tensor): timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) else: timesteps = torch.randint(0, - len(self.noise_scheduler), (latents.shape[0],), - device=latents.device, - generator=self.rng_generator) + len(self.noise_scheduler), (latents.shape[0],), + device=latents.device, + generator=self.rng_generator) return timesteps def set_rng_generator(self, rng_generator: torch.Generator): @@ -187,7 +179,34 @@ def decode_latents(self, latents): image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) return image - + + def encode_text(self, text, device): + assert self.t5_tokenizer is not None and self.t5_encoder is not None + assert self.clip_tokenizer is not None and self.clip_encoder is not None + # Encode with T5 + t5_tokenizer_out = self.t5_tokenizer(text, + padding='max_length', + max_length=self.t5_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_captions = t5_tokenizer_out['input_ids'].to(device) + t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device) + t5_embed = self.t5_encoder(input_ids=tokenized_captions, attention_mask=t5_attn_mask) + # Encode with CLIP + clip_tokenizer_out = self.clip_tokenizer(text, + padding='max_length', + max_length=self.clip_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_captions = clip_tokenizer_out['input_ids'].to(device) + clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device) + clip_out = self.clip_encoder(input_ids=tokenized_captions, + 
attention_mask=clip_attn_mask, + output_hidden_states=True) + clip_embed = clip_out.hidden_states[-2] + pooled_embeddings = clip_out[1] + return t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_embeddings + def prepare_text_embeddings(self, t5_embed, clip_embed, t5_mask, clip_mask): if t5_embed.shape[1] > self.max_seq_len: t5_embed = t5_embed[:, :self.max_seq_len] @@ -278,9 +297,11 @@ def update_metric(self, batch, outputs, metric): @torch.no_grad() def generate( self, - prompt_embeds: torch.FloatTensor, - pooled_prompt: torch.FloatTensor, - prompt_mask: torch.LongTensor, + prompt: Optional[list] = None, + negative_prompt: Optional[list] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt: Optional[torch.FloatTensor] = None, + prompt_mask: Optional[torch.LongTensor] = None, neg_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_neg_prompt: Optional[torch.FloatTensor] = None, neg_prompt_mask: Optional[torch.LongTensor] = None, @@ -301,20 +322,11 @@ def generate( one forward pass through the unet. Args: - prompt (str or List[str]): The prompt or prompts to guide the image generation. + prompt (List[str]): The prompts to guide the image generation. negative_prompt (str or List[str]): The prompt or prompts to guide the image generation away from. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1). Must be the same length as list of prompts. Default: `None`. - tokenized_prompts (torch.LongTensor): Optionally pass pre-tokenized prompts instead - of string prompts. If SDXL, this will be a tensor of size [B, 2, max_length], - otherwise will be of shape [B, max_length]. Default: `None`. - tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative - prompts instead of string prompts. Default: `None`. - tokenized_prompts_pad_mask (torch.LongTensor): Optionally pass padding mask for - pre-tokenized prompts. Default `None`. - tokenized_negative_prompts_pad_mask (torch.LongTensor): Optionall pass padding mask for - pre-tokenized negative prompts. Default `None`. prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead of string prompts. If both prompt and prompt_embeds are passed, prompt_embeds will be used. Default: `None`. @@ -348,17 +360,6 @@ def generate( (representing original size of input image) to use when generating images with SDXL. Default: `None`. """ - # TODO: do checks - # if prompt_embeds.shape[:2] == prompt_mask.shape[:2]: - # raise ValueError(' ') - - # Check all parts of negative prompts exist and are equal length - # if neg_prompt_embeds is not None or neg_prompt_mask is not None or pooled_neg_prompt is not None: - - # if negative_negative_embedlen(prompt_embeds) != len(negative_prompt_embeds): - # raise ValueError('len(prompts) and len(negative_prompts) must be the same. \ - # A negative prompt must be provided for each given prompt.') - # Create rng for the generation device = self.vae.device rng_generator = torch.Generator(device=device) @@ -372,6 +373,34 @@ def generate( do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore + # Check that inputs are consistent with all embeddings or text inputs. All embeddings should be provided if using + # embeddings, and none if using text. 
+ if (prompt_embeds is None) == (prompt is None): + raise ValueError('One and only one of prompt or prompt_embeds should be provided.') + if (pooled_prompt is None) != (prompt_embeds is None): + raise ValueError('pooled_prompt should be provided if and only if using embeddings') + if (prompt_mask is None) != (prompt_embeds is None): + raise ValueError('prompt_mask should be provided if and only if using embeddings') + if (neg_prompt_embeds is None) == (negative_prompt is None): + raise ValueError('One and only one of negative_prompt or neg_prompt_embeds should be provided.') + if (neg_prompt_mask is None) != (neg_prompt_embeds is None): + raise ValueError('neg_prompt_mask should be provided if and only if using embeddings') + if (pooled_neg_prompt is None) != (neg_prompt_embeds is None): + raise ValueError('pooled_neg_prompt should be provided if and only if using embeddings') + + # If the prompt is specified as text, encode it. + if prompt is not None: + t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_prompt = self.encode_text( + prompt, self.vae.device) + prompt_embeds, prompt_mask = self.prepare_text_embeddings(t5_embed, clip_embed, t5_attn_mask, + clip_attn_mask) + # If negative prompt is specified as text, encode it. + if negative_prompt is not None: + t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_neg_prompt = self.encode_text( + negative_prompt, self.vae.device) + neg_prompt_embeds, neg_prompt_mask = self.prepare_text_embeddings(t5_embed, clip_embed, t5_attn_mask, + clip_attn_mask) + text_embeddings = _duplicate_tensor(prompt_embeds, num_images_per_prompt) pooled_embeddings = _duplicate_tensor(pooled_prompt, num_images_per_prompt) encoder_attn_mask = _duplicate_tensor(prompt_mask, num_images_per_prompt) @@ -468,206 +497,3 @@ def _duplicate_tensor(tensor, num_images_per_prompt): return tensor.view(batch_size * num_images_per_prompt, seq_len, *[ -1, ] * len(tensor.shape[2:])) - - -def build_diffusion_v1( - unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', - vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', - autoencoder_path: Optional[str] = None, - autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', - prediction_type: str = 'epsilon', - latent_mean: Union[float, Tuple, str] = 0.0, - latent_std: Union[float, Tuple, str] = 7.67754318618, - text_embed_dim: int = 4096, - beta_schedule: str = 'scaled_linear', - zero_terminal_snr: bool = False, - train_metrics: Optional[List] = None, - val_metrics: Optional[List] = None, - quasirandomness: bool = False, - train_seed: int = 42, - val_seed: int = 1138, - fsdp: bool = True, - use_xformers: bool = True, -): - """Stable diffusion 2 training setup + SDXL UNet and VAE. - - Requires batches of matched images and text prompts to train. Generates images from text - prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. - - Args: - unet_model_name (str): Name of the UNet model to load. Defaults to - 'stabilityai/stable-diffusion-xl-base-1.0'. - vae_model_name (str): Name of the VAE model to load. Defaults to - 'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from - 'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16. - autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, - will use the vae from `model_name`. Default `None`. - autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. 
- prediction_type (str): The type of prediction to use. Must be one of 'sample', - 'epsilon', or 'v_prediction'. Default: `epsilon`. - latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, - a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder - checkpoint. Defaults to `0.0`. - latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, - a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder - checkpoint. Defaults to `1/0.13025`. - beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. - Default: `scaled_linear`. - zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. - train_metrics (list, optional): List of metrics to compute during training. If None, defaults to - [MeanSquaredError()]. - val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to - [MeanSquaredError()]. - quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. - Default: `False`. - train_seed (int): Seed to use for generating diffusion process noise during training if using - quasirandomness. Default: `42`. - val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. - fsdp (bool): Whether to use FSDP. Defaults to True. - use_xformers (bool): Whether to use xformers for attention. Defaults to True. - """ - latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) - - if train_metrics is None: - train_metrics = [MeanSquaredError()] - if val_metrics is None: - val_metrics = [MeanSquaredError()] - - # Make the autoencoder - if autoencoder_path is None: - if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': - raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') - downsample_factor = 8 - # Use the pretrained vae - try: - vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch.float16) - except: # for handling SDXL vae fp16 fixed checkpoint - vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) - else: - # Use a custom autoencoder - vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=torch.float16) - if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): - raise ValueError( - 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') - if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': - assert isinstance(latent_statistics, dict) - latent_mean = tuple(latent_statistics['latent_channel_means']) - if isinstance(latent_std, str) and latent_std == 'latent_statistics': - assert isinstance(latent_statistics, dict) - latent_std = tuple(latent_statistics['latent_channel_stds']) - downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) - - # Make the unet - unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0] - - if isinstance(vae, AutoEncoder): - # Adapt the unet config to account for differing number of latent channels if necessary - unet_config['in_channels'] = vae.config['latent_channels'] - unet_config['out_channels'] = vae.config['latent_channels'] - unet_config['cross_attention_dim'] = text_embed_dim - # This config 
variable is the sum of the text encoder projection dimension (768 for CLIP) and - # the number of additional time embeddings (6) * addition_time_embed_dim (256) - unet_config['projection_class_embeddings_input_dim'] = 2304 - # Init the unet from the config - unet = UNet2DConditionModel(**unet_config) - - # Zero initialization trick - for name, layer in unet.named_modules(): - # Final conv in ResNet blocks - if name.endswith('conv2'): - layer = zero_module(layer) - # proj_out in attention blocks - if name.endswith('to_out.0'): - layer = zero_module(layer) - # Last conv block out projection - unet.conv_out = zero_module(unet.conv_out) - - if isinstance(latent_mean, float): - latent_mean = (latent_mean,) * unet_config['in_channels'] - if isinstance(latent_std, float): - latent_std = (latent_std,) * unet_config['in_channels'] - assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) - - # FSDP Wrapping Scheme - if hasattr(unet, 'mid_block') and unet.mid_block is not None: - for attention in unet.mid_block.attentions: - attention._fsdp_wrap = True - for resnet in unet.mid_block.resnets: - resnet._fsdp_wrap = True - for block in unet.up_blocks: - if hasattr(block, 'attentions'): - for attention in block.attentions: - attention._fsdp_wrap = True - if hasattr(block, 'resnets'): - for resnet in block.resnets: - resnet._fsdp_wrap = True - for block in unet.down_blocks: - if hasattr(block, 'attentions'): - for attention in block.attentions: - attention._fsdp_wrap = True - if hasattr(block, 'resnets'): - for resnet in block.resnets: - resnet._fsdp_wrap = True - - # Make the noise schedulers - noise_scheduler = DDPMScheduler(num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - variance_type='fixed_small', - clip_sample=False, - prediction_type=prediction_type, - sample_max_value=1.0, - timestep_spacing='leading', - steps_offset=1, - rescale_betas_zero_snr=zero_terminal_snr) - if beta_schedule == 'squaredcos_cap_v2': - inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - rescale_betas_zero_snr=zero_terminal_snr) - else: - inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - prediction_type=prediction_type, - interpolation_type='linear', - use_karras_sigmas=False, - timestep_spacing='leading', - steps_offset=1, - rescale_betas_zero_snr=zero_terminal_snr) - - # Make the composer model - model = DiffusionV1( - unet=unet, - vae=vae, - noise_scheduler=noise_scheduler, - inference_noise_scheduler=inference_noise_scheduler, - prediction_type=prediction_type, - latent_mean=latent_mean, - latent_std=latent_std, - downsample_factor=downsample_factor, - train_metrics=train_metrics, - val_metrics=val_metrics, - quasirandomness=quasirandomness, - train_seed=train_seed, - val_seed=val_seed, - text_embed_dim=text_embed_dim, - fsdp=fsdp, - ) - if torch.cuda.is_available(): - model = DeviceGPU().module_to_device(model) - if is_xformers_installed and use_xformers: - model.unet.enable_xformers_memory_efficient_attention() - if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): - model.vae.enable_xformers_memory_efficient_attention() - - return model From 101c353a5b252ba663321145f97276a97570267f Mon Sep 17 00:00:00 2001 
From: Cory Stephenson Date: Fri, 30 Aug 2024 05:52:11 +0000 Subject: [PATCH 07/24] Don't need to check for negative prompt existing --- diffusion/models/t5_diffusion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index 9a2e38df..6ee5b592 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -381,8 +381,6 @@ def generate( raise ValueError('pooled_prompt should be provided if and only if using embeddings') if (prompt_mask is None) != (prompt_embeds is None): raise ValueError('prompt_mask should be provided if and only if using embeddings') - if (neg_prompt_embeds is None) == (negative_prompt is None): - raise ValueError('One and only one of negative_prompt or neg_prompt_embeds should be provided.') if (neg_prompt_mask is None) != (neg_prompt_embeds is None): raise ValueError('neg_prompt_mask should be provided if and only if using embeddings') if (pooled_neg_prompt is None) != (neg_prompt_embeds is None): From 34022236aba1eea7d71ee14a2d3dc107addd416a Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Fri, 30 Aug 2024 06:04:28 +0000 Subject: [PATCH 08/24] Timesteps shall be ints --- diffusion/models/t5_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index 6ee5b592..4c25d9e6 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -143,7 +143,7 @@ def _generate_timesteps(self, latents: torch.Tensor): if not self.unet.training: # Sample equally spaced timesteps across all devices global_batch_size = latents.shape[0] * dist.get_world_size() - global_timesteps = torch.linspace(0, len(self.noise_scheduler), global_batch_size) + global_timesteps = torch.linspace(0, len(self.noise_scheduler), global_batch_size).to(torch.int64) # Get this device's subset of all the timesteps idx_offset = dist.get_global_rank() * latents.shape[0] timesteps = global_timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) From 72476c2e0106b1c3e480100e5e84e6df5cb5f5a9 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Fri, 30 Aug 2024 06:24:03 +0000 Subject: [PATCH 09/24] Fix off by one --- diffusion/models/t5_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index 4c25d9e6..1472b048 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -143,7 +143,7 @@ def _generate_timesteps(self, latents: torch.Tensor): if not self.unet.training: # Sample equally spaced timesteps across all devices global_batch_size = latents.shape[0] * dist.get_world_size() - global_timesteps = torch.linspace(0, len(self.noise_scheduler), global_batch_size).to(torch.int64) + global_timesteps = torch.linspace(0, len(self.noise_scheduler) - 1, global_batch_size).to(torch.int64) # Get this device's subset of all the timesteps idx_offset = dist.get_global_rank() * latents.shape[0] timesteps = global_timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) From 363edd151c93cbd4a23860ba4074b464a609afe1 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sun, 1 Sep 2024 20:00:30 +0000 Subject: [PATCH 10/24] Add layernorms before sequence concat --- diffusion/models/t5_diffusion.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index 1472b048..7c69983d 100644 --- 
a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -125,6 +125,9 @@ def __init__( # Projection layers for the text embeddings self.clip_proj = nn.Linear(768, text_embed_dim) self.t5_proj = nn.Linear(4096, text_embed_dim) + # Layernorms for the text embeddings + self.clip_ln = nn.LayerNorm(text_embed_dim) + self.t5_ln = nn.LayerNorm(text_embed_dim) # Learnable position embeddings for the conitioning sequences t5_position_embeddings = torch.randn(self.max_seq_len, text_embed_dim) t5_position_embeddings /= math.sqrt(text_embed_dim) @@ -219,6 +222,9 @@ def prepare_text_embeddings(self, t5_embed, clip_embed, t5_mask, clip_mask): # Add position embeddings t5_embed = 0.707 * t5_embed + 0.707 * self.t5_position_embedding[:t5_embed.shape[1]].unsqueeze(0) clip_embed = 0.707 * clip_embed + 0.707 * self.clip_position_embedding[:clip_embed.shape[1]].unsqueeze(0) + # Apply layernorms + t5_embed = self.t5_ln(t5_embed) + clip_embed = self.clip_ln(clip_embed) # Concatenate the text embeddings text_embeds = torch.cat([t5_embed, clip_embed], dim=1) encoder_attention_mask = torch.cat([t5_mask, clip_mask], dim=1) From 260eb6c722e6a9b1fc78ce8c659b548eedddd23b Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Mon, 2 Sep 2024 06:08:46 +0000 Subject: [PATCH 11/24] Changes for running with bf16 --- diffusion/models/t5_diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index 7c69983d..c965befe 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -111,7 +111,7 @@ def __init__( self.inference_scheduler = inference_noise_scheduler # freeze VAE during diffusion training self.vae.requires_grad_(False) - self.vae = self.vae.half() + self.vae = self.vae.bfloat16() if fsdp: # only wrap models we are training self.vae._fsdp_wrap = False @@ -171,9 +171,9 @@ def set_rng_generator(self, rng_generator: torch.Generator): """Sets the rng generator for the model.""" self.rng_generator = rng_generator - def encode_images(self, inputs): + def encode_images(self, inputs, dtype=torch.bfloat16): with torch.cuda.amp.autocast(enabled=False): - latents = self.vae.encode(inputs.half())['latent_dist'].sample().data + latents = self.vae.encode(inputs.to(dtype))['latent_dist'].sample().data latents = (latents - self.latent_mean) / self.latent_std # scale latents return latents @@ -489,7 +489,7 @@ def generate( # We now use the vae to decode the generated latents back into the image. # scale and decode the image latents with vae image = self.decode_latents(latents) - return image.detach() # (batch*num_images_per_prompt, channel, h, w) + return image.detach().float() # (batch*num_images_per_prompt, channel, h, w) def _duplicate_tensor(tensor, num_images_per_prompt): From e30fa2b390ba5bd9671fae2521d0d9bc7dea7a32 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Mon, 9 Sep 2024 20:30:54 +0000 Subject: [PATCH 12/24] Update docstrings and fix types --- diffusion/models/t5_diffusion.py | 145 +++++++++++++++---------------- 1 file changed, 69 insertions(+), 76 deletions(-) diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/t5_diffusion.py index c965befe..f57b8d65 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/t5_diffusion.py @@ -25,17 +25,15 @@ class DiffusionV1(ComposerModel): - """Stable Diffusion ComposerModel. + """Diffusion ComposerModel for running with precomputed T5 and CLIP embeddings. 
This is a Latent Diffusion model conditioned on text prompts that are run through - a pre-trained CLIP or LLM model. The CLIP outputs are then passed to as an - additional input to our Unet during training and can later be used to guide - the image generation process. + a pre-trained CLIP and T5 text encoder. Args: unet (torch.nn.Module): HuggingFace conditional unet, must accept a (B, C, H, W) input, (B,) timestep array of noise timesteps, - and (B, 77, 768) text conditioning vectors. + and (B, 77, text_embed_dim) text conditioning vectors. vae (torch.nn.Module): HuggingFace or compatible vae. must support `.encode()` and `decode()` functions. noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers @@ -46,13 +44,14 @@ class DiffusionV1(ComposerModel): t5_encoder (Optional): T5 text encoder. Should only be specified during inference. Default: `None`. clip_tokenizer (Optional): Tokenizer for CLIP. Should only be specified during inference. Default: `None`. clip_encoder (Optional): CLIP text encoder. Should only be specified during inference. Default: `None`. + text_embed_dim (int): The common dimension to project the text embeddings to. Default: `4096`. prediction_type (str): The type of prediction to use. Must be one of 'sample', 'epsilon', or 'v_prediction'. Default: `epsilon`. latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to . Default: ``(0.0,) * 4``. latent_std (Optional[tuple[float]]): The standard deviations of the latent space. Default: ``(1/0.13025,)*4``. downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`. - max_seq_len (int): The maximum sequence length for the text encoder. Default: `77`. + max_seq_len (int): The maximum sequence length for the text encoder. Default: `256`. train_metrics (list): List of torchmetrics to calculate during training. Default: `None`. val_metrics (list): List of torchmetrics to calculate during validation. 
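For readers tracking the conditioning changes across patches 10 and 12, the following is a minimal, self-contained sketch of the text-conditioning path as these hunks leave it: project the T5 (4096-d) and CLIP (768-d) hidden states to a shared width, blend in learnable position embeddings with 0.707 weights, apply the new layernorms, and concatenate the two sequences and their attention masks. The class and attribute names here are illustrative rather than the repository's own, and the CLIP sequence length of 77 is an assumption taken from the tokenizer's default.

import math
from typing import Tuple

import torch
import torch.nn as nn


class TextConditioningSketch(nn.Module):
    """Illustrative stand-in for the T5 + CLIP conditioning path (not the repo's class)."""

    def __init__(self, text_embed_dim: int = 4096, max_seq_len: int = 256, clip_seq_len: int = 77):
        super().__init__()
        # Project both encoders into the shared embedding width expected by the unet
        self.t5_proj = nn.Linear(4096, text_embed_dim)
        self.clip_proj = nn.Linear(768, text_embed_dim)
        # Layernorms applied just before the sequences are concatenated
        self.t5_ln = nn.LayerNorm(text_embed_dim)
        self.clip_ln = nn.LayerNorm(text_embed_dim)
        # Learnable position embeddings, initialized with roughly unit-norm rows
        self.t5_pos = nn.Parameter(torch.randn(max_seq_len, text_embed_dim) / math.sqrt(text_embed_dim))
        self.clip_pos = nn.Parameter(torch.randn(clip_seq_len, text_embed_dim) / math.sqrt(text_embed_dim))
        self.max_seq_len = max_seq_len

    def forward(self, t5_embed: torch.Tensor, clip_embed: torch.Tensor, t5_mask: torch.Tensor,
                clip_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Truncate the T5 sequence to the maximum conditioning length
        t5_embed, t5_mask = t5_embed[:, :self.max_seq_len], t5_mask[:, :self.max_seq_len]
        t5_embed = self.t5_proj(t5_embed)
        clip_embed = self.clip_proj(clip_embed)
        # Mix in position embeddings with equal weight (0.707 is approximately 1/sqrt(2))
        t5_embed = 0.707 * t5_embed + 0.707 * self.t5_pos[:t5_embed.shape[1]].unsqueeze(0)
        clip_embed = 0.707 * clip_embed + 0.707 * self.clip_pos[:clip_embed.shape[1]].unsqueeze(0)
        t5_embed, clip_embed = self.t5_ln(t5_embed), self.clip_ln(clip_embed)
        # Concatenate along the sequence dimension, and do the same for the attention masks
        text_embeds = torch.cat([t5_embed, clip_embed], dim=1)
        encoder_attention_mask = torch.cat([t5_mask, clip_mask], dim=1)
        return text_embeds, encoder_attention_mask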
@@ -75,17 +74,17 @@ def __init__( t5_encoder: Optional[torch.nn.Module] = None, clip_tokenizer: Optional[PreTrainedTokenizer] = None, clip_encoder: Optional[torch.nn.Module] = None, + text_embed_dim: int = 4096, prediction_type: str = 'epsilon', latent_mean: Tuple[float] = (0.0,) * 4, latent_std: Tuple[float] = (1 / 0.13025,) * 4, downsample_factor: int = 8, - max_seq_len: int = 77, + max_seq_len: int = 256, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, quasirandomness: bool = False, train_seed: int = 42, val_seed: int = 1138, - text_embed_dim: int = 4096, fsdp: bool = False, ): super().__init__() @@ -210,7 +209,8 @@ def encode_text(self, text, device): pooled_embeddings = clip_out[1] return t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_embeddings - def prepare_text_embeddings(self, t5_embed, clip_embed, t5_mask, clip_mask): + def prepare_text_embeddings(self, t5_embed: torch.Tensor, clip_embed: torch.Tensor, t5_mask: torch.Tensor, + clip_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if t5_embed.shape[1] > self.max_seq_len: t5_embed = t5_embed[:, :self.max_seq_len] t5_mask = t5_mask[:, :self.max_seq_len] @@ -305,20 +305,20 @@ def generate( self, prompt: Optional[list] = None, negative_prompt: Optional[list] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt: Optional[torch.FloatTensor] = None, - prompt_mask: Optional[torch.LongTensor] = None, - neg_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_neg_prompt: Optional[torch.FloatTensor] = None, - neg_prompt_mask: Optional[torch.LongTensor] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 3.0, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt: Optional[torch.Tensor] = None, + prompt_mask: Optional[torch.Tensor] = None, + neg_prompt_embeds: Optional[torch.Tensor] = None, + pooled_neg_prompt: Optional[torch.Tensor] = None, + neg_prompt_mask: Optional[torch.Tensor] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 3.0, rescaled_guidance: Optional[float] = None, - num_images_per_prompt: Optional[int] = 1, + num_images_per_prompt: int = 1, seed: Optional[int] = None, - progress_bar: Optional[bool] = True, + progress_bar: bool = True, crop_params: Optional[torch.Tensor] = None, input_size_params: Optional[torch.Tensor] = None, ): @@ -328,21 +328,28 @@ def generate( one forward pass through the unet. Args: - prompt (List[str]): The prompts to guide the image generation. + prompt (List[str]): The prompts to guide the image generation. Only use if not + using embeddings. Default: `None`. negative_prompt (str or List[str]): The prompt or prompts to guide the image generation away from. Ignored when not using guidance - (i.e., ignored if guidance_scale is less than 1). - Must be the same length as list of prompts. Default: `None`. - prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead - of string prompts. If both prompt and prompt_embeds - are passed, prompt_embeds will be used. Default: `None`. - neg_prompt_embeds (torch.FloatTensor): Optionally pass pre-embedded negative - prompts instead of string negative prompts. If both negative_prompt and - negative_prompt_embeds are passed, prompt_embeds will be used. Default: `None`. + (i.e., ignored if guidance_scale is less than 1). Must be the same length + as list of prompts. Only use if not using negative embeddings. 
Default: `None`. + prompt_embeds (torch.Tensor): Optionally pass pre-tokenized prompts instead + of string prompts. Default: `None`. + pooled_prompt (torch.Tensor): Optionally pass a precomputed pooled prompt embedding + if using embeddings. Default: `None`. + prompt_mask (torch.Tensor): Optionally pass a precomputed attention mask for the + prompt embeddings. Default: `None`. + neg_prompt_embeds (torch.Tensor): Optionally pass pre-embedded negative + prompts instead of string negative prompts. Default: `None`. + pooled_neg_prompt (torch.Tensor): Optionally pass a precomputed pooled negative + prompt embedding if using embeddings. Default: `None`. + neg_prompt_mask (torch.Tensor): Optionally pass a precomputed attention mask for the + negative prompt embeddings. Default: `None`. height (int, optional): The height in pixels of the generated image. - Default: `self.unet.config.sample_size * 8)`. + Default: `1024`. width (int, optional): The width in pixels of the generated image. - Default: `self.unet.config.sample_size * 8)`. + Default: `1024`. num_inference_steps (int): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. Default: `50`. @@ -360,9 +367,9 @@ def generate( Default: `True`. seed (int): Random seed to use for generation. Set a seed for reproducible generation. Default: `None`. - crop_params (torch.FloatTensor of size [Bx2], optional): Crop parameters to use + crop_params (torch.Tensor of size [Bx2], optional): Crop parameters to use when generating images with SDXL. Default: `None`. - input_size_params (torch.FloatTensor of size [Bx2], optional): Size parameters + input_size_params (torch.Tensor of size [Bx2], optional): Size parameters (representing original size of input image) to use when generating images with SDXL. Default: `None`. """ @@ -370,14 +377,7 @@ def generate( device = self.vae.device rng_generator = torch.Generator(device=device) if seed: - rng_generator = rng_generator.manual_seed(seed) # type: ignore - - if height is None: - height = self.unet.config.sample_size * self.downsample_factor - if width is None: - width = self.unet.config.sample_size * self.downsample_factor - - do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore + rng_generator = rng_generator.manual_seed(seed) # Check that inputs are consistent with all embeddings or text inputs. All embeddings should be provided if using # embeddings, and none if using text. 
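The hunk below makes classifier-free guidance unconditional by always batching the negative and positive streams through the unet, so the guidance arithmetic that remains in the loop is worth a standalone illustration. A minimal sketch follows; the helper name is illustrative and not part of the patch.

from typing import Optional

import torch


def apply_cfg(pred: torch.Tensor, guidance_scale: float, rescaled_guidance: Optional[float] = None) -> torch.Tensor:
    """Illustrative classifier-free guidance step.

    `pred` stacks the unconditional and conditional unet outputs along the batch
    dimension, mirroring the torch.cat([latents] * 2) input used in the sampling loop.
    """
    pred_uncond, pred_text = pred.chunk(2)
    guided = pred_uncond + guidance_scale * (pred_text - pred_uncond)
    if rescaled_guidance is not None:
        # Pull the guided prediction's per-sample std back toward that of the conditional
        # prediction, then interpolate between the rescaled and unrescaled results.
        std_pos = torch.std(pred_text, dim=(1, 2, 3), keepdim=True)
        std_cfg = torch.std(guided, dim=(1, 2, 3), keepdim=True)
        guided = (guided * (std_pos / std_cfg)) * rescaled_guidance + guided * (1 - rescaled_guidance)
    return guided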
@@ -409,24 +409,23 @@ def generate( pooled_embeddings = _duplicate_tensor(pooled_prompt, num_images_per_prompt) encoder_attn_mask = _duplicate_tensor(prompt_mask, num_images_per_prompt) - batch_size = len(prompt_embeds) # len prompts * num_images_per_prompt + batch_size = len(text_embeddings) # len prompts * num_images_per_prompt # classifier free guidance + negative prompts # negative prompt is given in place of the unconditional input in classifier free guidance - if do_classifier_free_guidance: - if not neg_prompt_embeds: - # Negative prompt is empty and we want to zero it out - neg_prompt_embeds = torch.zeros_like(text_embeddings) - pooled_neg_prompt = torch.zeros_like(pooled_embeddings) - neg_prompt_mask = torch.zeros_like(encoder_attn_mask) - else: - neg_prompt_embeds = _duplicate_tensor(neg_prompt_embeds, num_images_per_prompt) - pooled_neg_prompt = _duplicate_tensor(pooled_neg_prompt, num_images_per_prompt) - neg_prompt_mask = _duplicate_tensor(neg_prompt_mask, num_images_per_prompt) + if not neg_prompt_embeds: + # Negative prompt is empty and we want to zero it out + neg_prompt_embeds = torch.zeros_like(text_embeddings) + pooled_neg_prompt = torch.zeros_like(pooled_embeddings) + neg_prompt_mask = torch.zeros_like(encoder_attn_mask) + else: + neg_prompt_embeds = _duplicate_tensor(neg_prompt_embeds, num_images_per_prompt) + pooled_neg_prompt = _duplicate_tensor(pooled_neg_prompt, num_images_per_prompt) + neg_prompt_mask = _duplicate_tensor(neg_prompt_mask, num_images_per_prompt) - # concat uncond + prompt - text_embeddings = torch.cat([neg_prompt_embeds, text_embeddings]) - pooled_embeddings = torch.cat([pooled_neg_prompt, pooled_embeddings]) - encoder_attn_mask = torch.cat([neg_prompt_mask, encoder_attn_mask]) + # concat uncond + prompt + text_embeddings = torch.cat([neg_prompt_embeds, text_embeddings]) + pooled_embeddings = torch.cat([pooled_neg_prompt, pooled_embeddings]) + encoder_attn_mask = torch.cat([neg_prompt_mask, encoder_attn_mask]) # prepare for diffusion generation process latents = torch.randn( @@ -450,21 +449,16 @@ def generate( input_size_params = torch.tensor([[width, height]] * batch_size, dtype=text_embeddings.dtype) output_size_params = torch.tensor([[width, height]] * batch_size, dtype=text_embeddings.dtype) - if do_classifier_free_guidance: - crop_params = torch.cat([crop_params, crop_params]) - input_size_params = torch.cat([input_size_params, input_size_params]) - output_size_params = torch.cat([output_size_params, output_size_params]) + crop_params = torch.cat([crop_params, crop_params]) + input_size_params = torch.cat([input_size_params, input_size_params]) + output_size_params = torch.cat([output_size_params, output_size_params]) add_time_ids = torch.cat([input_size_params, crop_params, output_size_params], dim=1).to(device) added_cond_kwargs = {'text_embeds': pooled_embeddings, 'time_ids': add_time_ids} # backward diffusion process for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): - if do_classifier_free_guidance: - latent_model_input = torch.cat([latents] * 2) - else: - latent_model_input = latents - + latent_model_input = torch.cat([latents] * 2) latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t) # Model prediction pred = self.unet(latent_model_input, @@ -473,16 +467,15 @@ def generate( encoder_attention_mask=encoder_attn_mask, added_cond_kwargs=added_cond_kwargs).sample - if do_classifier_free_guidance: - # perform guidance. 
Note this is only techincally correct for prediction_type 'epsilon' - pred_uncond, pred_text = pred.chunk(2) - pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) - # Optionally rescale the classifer free guidance - if rescaled_guidance is not None: - std_pos = torch.std(pred_text, dim=(1, 2, 3), keepdim=True) - std_cfg = torch.std(pred, dim=(1, 2, 3), keepdim=True) - pred_rescaled = pred * (std_pos / std_cfg) - pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance) + # perform guidance. Note this is only techincally correct for prediction_type 'epsilon' + pred_uncond, pred_text = pred.chunk(2) + pred = pred_uncond + guidance_scale * (pred_text - pred_uncond) + # Optionally rescale the classifer free guidance + if rescaled_guidance is not None: + std_pos = torch.std(pred_text, dim=(1, 2, 3), keepdim=True) + std_cfg = torch.std(pred, dim=(1, 2, 3), keepdim=True) + pred_rescaled = pred * (std_pos / std_cfg) + pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance) # compute the previous noisy sample x_t -> x_t-1 latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample From ca94d5ba882ef043223809e2aed13c495cfa0e02 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Mon, 9 Sep 2024 23:43:15 +0000 Subject: [PATCH 13/24] Drop nans --- diffusion/datasets/image_caption_latents.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py index fa1158f5..6107edc5 100644 --- a/diffusion/datasets/image_caption_latents.py +++ b/diffusion/datasets/image_caption_latents.py @@ -140,6 +140,14 @@ def __getitem__(self, index): if 'CLIP_LATENTS' in latent_key: clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy() out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1]) + if out['CLIP_POOLED'].isnan().any(): + out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED']) + if out['T5_LATENTS'].isnan().any(): + out['T5_LATENTS'] = torch.zeros_like(out['T5_LATENTS']) + out['T5_ATTENTION_MASK'] = torch.zeros_like(out['T5_ATTENTION_MASK']) + if out['CLIP_LATENTS'].isnan().any(): + out['CLIP_LATENTS'] = torch.zeros_like(out['CLIP_LATENTS']) + out['CLIP_ATTENTION_MASK'] = torch.zeros_like(out['CLIP_ATTENTION_MASK']) return out From 3d7b65ef5212297d8a398133caaab184e374f404 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Tue, 10 Sep 2024 20:14:07 +0000 Subject: [PATCH 14/24] Clean up names --- diffusion/datasets/image_caption_latents.py | 20 +++++++++++-------- diffusion/models/__init__.py | 1 + diffusion/models/models.py | 6 +++--- ...y => precomputed_text_latent_diffusion.py} | 4 ++-- 4 files changed, 18 insertions(+), 13 deletions(-) rename diffusion/models/{t5_diffusion.py => precomputed_text_latent_diffusion.py} (99%) diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py index 6107edc5..d019b284 100644 --- a/diffusion/datasets/image_caption_latents.py +++ b/diffusion/datasets/image_caption_latents.py @@ -40,6 +40,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset): attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset. Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``. latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``. 
+ drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``. **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader """ @@ -57,6 +58,7 @@ def __init__( text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)), attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), latent_dtype: torch.dtype = torch.bfloat16, + drop_nans: bool = True, **streaming_kwargs, ): @@ -76,6 +78,7 @@ def __init__( self.text_latent_shapes = text_latent_shapes self.attention_mask_keys = attention_mask_keys self.latent_dtype = latent_dtype + self.drop_nans = drop_nans def __getitem__(self, index): sample = super().__getitem__(index) @@ -140,14 +143,15 @@ def __getitem__(self, index): if 'CLIP_LATENTS' in latent_key: clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy() out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1]) - if out['CLIP_POOLED'].isnan().any(): - out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED']) - if out['T5_LATENTS'].isnan().any(): - out['T5_LATENTS'] = torch.zeros_like(out['T5_LATENTS']) - out['T5_ATTENTION_MASK'] = torch.zeros_like(out['T5_ATTENTION_MASK']) - if out['CLIP_LATENTS'].isnan().any(): - out['CLIP_LATENTS'] = torch.zeros_like(out['CLIP_LATENTS']) - out['CLIP_ATTENTION_MASK'] = torch.zeros_like(out['CLIP_ATTENTION_MASK']) + if self.drop_nans: + if out['CLIP_POOLED'].isnan().any(): + out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED']) + if out['T5_LATENTS'].isnan().any(): + out['T5_LATENTS'] = torch.zeros_like(out['T5_LATENTS']) + out['T5_ATTENTION_MASK'] = torch.zeros_like(out['T5_ATTENTION_MASK']) + if out['CLIP_LATENTS'].isnan().any(): + out['CLIP_LATENTS'] = torch.zeros_like(out['CLIP_LATENTS']) + out['CLIP_ATTENTION_MASK'] = torch.zeros_like(out['CLIP_ATTENTION_MASK']) return out diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py index 1a4bb0a2..83111005 100644 --- a/diffusion/models/__init__.py +++ b/diffusion/models/__init__.py @@ -20,5 +20,6 @@ 'stable_diffusion_2', 'stable_diffusion_xl', 'StableDiffusion', + 'TextLatentDiffusion', 'text_to_image_transformer', ] diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 671791bb..4496faa2 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -20,8 +20,8 @@ from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT -from diffusion.models.t5_diffusion import DiffusionV1 from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer +from diffusion.models.text_latent_diffusion import PrecomputedTextLatentDiffusion from diffusion.models.transformer import DiffusionTransformer from diffusion.schedulers.schedulers import ContinuousTimeScheduler from diffusion.schedulers.utils import shift_noise_schedule @@ -581,7 +581,7 @@ def stable_diffusion_xl( return model -def build_diffusion_v1( +def build_text_latent_diffusion( unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', autoencoder_path: Optional[str] = None, @@ -779,7 +779,7 @@ def build_diffusion_v1( cache_dir=cache_dir, local_files_only=True).cuda().eval() # Make the composer model - model = DiffusionV1( + model = PrecomputedTextLatentDiffusion( unet=unet, vae=vae, t5_tokenizer=t5_tokenizer, 
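With the rename in place, a hedged end-to-end usage sketch may help reviewers see how the pieces fit at inference time. It assumes the builder name exposed later in this series (`precomputed_text_latent_diffusion`), its `include_text_encoders` flag so that string prompts can be encoded on the fly, and an available CUDA device; the prompt, guidance scale, and other argument values are illustrative.

import torch

from diffusion.models import precomputed_text_latent_diffusion

# Build the composer model with tokenizers/encoders attached so generate() can take raw text.
model = precomputed_text_latent_diffusion(include_text_encoders=True, fsdp=False)
model = model.eval()

with torch.no_grad():
    images = model.generate(
        prompt=['a photograph of an astronaut riding a horse'],
        negative_prompt=['low quality, blurry'],
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=7.0,
        seed=42,
    )  # (num_prompts * num_images_per_prompt, C, H, W) float tensor in [0, 1]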
diff --git a/diffusion/models/t5_diffusion.py b/diffusion/models/precomputed_text_latent_diffusion.py similarity index 99% rename from diffusion/models/t5_diffusion.py rename to diffusion/models/precomputed_text_latent_diffusion.py index f57b8d65..26a3e8be 100644 --- a/diffusion/models/t5_diffusion.py +++ b/diffusion/models/precomputed_text_latent_diffusion.py @@ -24,7 +24,7 @@ is_xformers_installed = False -class DiffusionV1(ComposerModel): +class PrecomputedTextLatentDiffusion(ComposerModel): """Diffusion ComposerModel for running with precomputed T5 and CLIP embeddings. This is a Latent Diffusion model conditioned on text prompts that are run through @@ -136,7 +136,7 @@ def __init__( self.clip_position_embedding = torch.nn.Parameter(clip_position_embeddings, requires_grad=True) def _apply(self, fn): - super(DiffusionV1, self)._apply(fn) + super(PrecomputedTextLatentDiffusion, self)._apply(fn) self.latent_mean = fn(self.latent_mean) self.latent_std = fn(self.latent_std) return self From 7bc1796a6e3623d8300b06a52c6141749f2ff7da Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Tue, 10 Sep 2024 22:12:56 +0000 Subject: [PATCH 15/24] Fix depreciation --- diffusion/models/__init__.py | 3 ++- diffusion/models/models.py | 5 +++-- diffusion/models/precomputed_text_latent_diffusion.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py index 83111005..ef0c1e6f 100644 --- a/diffusion/models/__init__.py +++ b/diffusion/models/__init__.py @@ -8,6 +8,7 @@ text_to_image_transformer) from diffusion.models.noop import NoOpModel from diffusion.models.pixel_diffusion import PixelDiffusion +from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion from diffusion.models.stable_diffusion import StableDiffusion __all__ = [ @@ -20,6 +21,6 @@ 'stable_diffusion_2', 'stable_diffusion_xl', 'StableDiffusion', - 'TextLatentDiffusion', + 'PrecomputedTextLatentDiffusion', 'text_to_image_transformer', ] diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 4496faa2..7d5ed150 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -18,10 +18,10 @@ ComposerDiffusersAutoEncoder, load_autoencoder) from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion +from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer -from diffusion.models.text_latent_diffusion import PrecomputedTextLatentDiffusion from diffusion.models.transformer import DiffusionTransformer from diffusion.schedulers.schedulers import ContinuousTimeScheduler from diffusion.schedulers.utils import shift_noise_schedule @@ -627,6 +627,7 @@ def build_text_latent_diffusion( latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder checkpoint. Defaults to `1/0.13025`. + text_embed_dim (int): The dimension to project the text embeddings to. Default: `4096`. beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. Default: `scaled_linear`. 
zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. @@ -681,7 +682,7 @@ def build_text_latent_diffusion( unet_config['in_channels'] = vae.config['latent_channels'] unet_config['out_channels'] = vae.config['latent_channels'] unet_config['cross_attention_dim'] = text_embed_dim - # This config variable is the sum of the text encoder projection dimension (768 for CLIP) and + # This config variable is the sum of the text encoder projection dimension and # the number of additional time embeddings (6) * addition_time_embed_dim (256) unet_config['projection_class_embeddings_input_dim'] = 2304 # Init the unet from the config diff --git a/diffusion/models/precomputed_text_latent_diffusion.py b/diffusion/models/precomputed_text_latent_diffusion.py index 26a3e8be..a3166390 100644 --- a/diffusion/models/precomputed_text_latent_diffusion.py +++ b/diffusion/models/precomputed_text_latent_diffusion.py @@ -171,7 +171,7 @@ def set_rng_generator(self, rng_generator: torch.Generator): self.rng_generator = rng_generator def encode_images(self, inputs, dtype=torch.bfloat16): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast('cuda', enabled=False): latents = self.vae.encode(inputs.to(dtype))['latent_dist'].sample().data latents = (latents - self.latent_mean) / self.latent_std # scale latents return latents From 4c36a069daf70a5040084f50dfe001f1b5123302 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 11 Sep 2024 16:28:47 +0000 Subject: [PATCH 16/24] More name changes --- diffusion/models/__init__.py | 5 +++-- diffusion/models/models.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py index ef0c1e6f..69cf02bd 100644 --- a/diffusion/models/__init__.py +++ b/diffusion/models/__init__.py @@ -4,8 +4,8 @@ """Diffusion models.""" from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion, - discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl, - text_to_image_transformer) + discrete_pixel_diffusion, precomputed_text_latent_diffusion, stable_diffusion_2, + stable_diffusion_xl, text_to_image_transformer) from diffusion.models.noop import NoOpModel from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion @@ -18,6 +18,7 @@ 'discrete_pixel_diffusion', 'NoOpModel', 'PixelDiffusion', + 'precomputed_text_latent_diffusion', 'stable_diffusion_2', 'stable_diffusion_xl', 'StableDiffusion', diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 7d5ed150..a94f21b8 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -581,7 +581,7 @@ def stable_diffusion_xl( return model -def build_text_latent_diffusion( +def precomputed_text_latent_diffusion( unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', autoencoder_path: Optional[str] = None, From 040afaecc87850678977032d8ce37acb7acb42cb Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 11 Sep 2024 18:22:37 +0000 Subject: [PATCH 17/24] Fixes for running inference --- diffusion/models/models.py | 10 +++++----- diffusion/models/precomputed_text_latent_diffusion.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index a94f21b8..8c6ffffb 100644 --- a/diffusion/models/models.py +++ 
b/diffusion/models/models.py @@ -769,16 +769,16 @@ def precomputed_text_latent_diffusion( clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer', cache_dir=cache_dir, - local_files_only=True) + local_files_only=False) t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl', - torch_dtype=torch.bfloat16, + torch_dtype=torch.float16, cache_dir=cache_dir, - local_files_only=True).encoder.eval() + local_files_only=False).encoder.eval() clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', - torch_dtype=torch.bfloat16, + torch_dtype=torch.float16, cache_dir=cache_dir, - local_files_only=True).cuda().eval() + local_files_only=False).cuda().eval() # Make the composer model model = PrecomputedTextLatentDiffusion( unet=unet, diff --git a/diffusion/models/precomputed_text_latent_diffusion.py b/diffusion/models/precomputed_text_latent_diffusion.py index a3166390..c74a7468 100644 --- a/diffusion/models/precomputed_text_latent_diffusion.py +++ b/diffusion/models/precomputed_text_latent_diffusion.py @@ -191,18 +191,18 @@ def encode_text(self, text, device): max_length=self.t5_tokenizer.model_max_length, truncation=True, return_tensors='pt') - tokenized_captions = t5_tokenizer_out['input_ids'].to(device) + t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device) t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device) - t5_embed = self.t5_encoder(input_ids=tokenized_captions, attention_mask=t5_attn_mask) + t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0] # Encode with CLIP clip_tokenizer_out = self.clip_tokenizer(text, padding='max_length', max_length=self.clip_tokenizer.model_max_length, truncation=True, return_tensors='pt') - tokenized_captions = clip_tokenizer_out['input_ids'].to(device) + clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device) clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device) - clip_out = self.clip_encoder(input_ids=tokenized_captions, + clip_out = self.clip_encoder(input_ids=clip_tokenized_captions, attention_mask=clip_attn_mask, output_hidden_states=True) clip_embed = clip_out.hidden_states[-2] @@ -412,7 +412,7 @@ def generate( batch_size = len(text_embeddings) # len prompts * num_images_per_prompt # classifier free guidance + negative prompts # negative prompt is given in place of the unconditional input in classifier free guidance - if not neg_prompt_embeds: + if neg_prompt_embeds is None: # Negative prompt is empty and we want to zero it out neg_prompt_embeds = torch.zeros_like(text_embeddings) pooled_neg_prompt = torch.zeros_like(pooled_embeddings) From 9abc15b3d34e6473ee538abfa1d103122d6a11bf Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 11 Sep 2024 18:57:28 +0000 Subject: [PATCH 18/24] Update docstrings --- diffusion/models/models.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 8c6ffffb..2965e13c 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -602,10 +602,7 @@ def precomputed_text_latent_diffusion( fsdp: bool = True, use_xformers: bool = True, ): - """Stable diffusion 2 training setup + SDXL UNet and VAE. - - Requires batches of matched images and text prompts to train. Generates images from text - prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. 
+ """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP. Args: unet_model_name (str): Name of the UNet model to load. Defaults to From ebacb59dc4e5b83bd79cd76b10793f7f9d131e32 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 11 Sep 2024 20:56:13 +0000 Subject: [PATCH 19/24] Configurable schedulers --- diffusion/models/models.py | 81 +++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 40 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 2965e13c..57b7a3ad 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -5,7 +5,7 @@ import logging import math -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from composer.devices import DeviceGPU @@ -592,8 +592,8 @@ def precomputed_text_latent_diffusion( latent_mean: Union[float, Tuple, str] = 0.0, latent_std: Union[float, Tuple, str] = 7.67754318618, text_embed_dim: int = 4096, - beta_schedule: str = 'scaled_linear', - zero_terminal_snr: bool = False, + train_noise_scheduler_params: Optional[Dict[str, Any]] = None, + inference_noise_scheduler_params: Optional[Dict[str, Any]] = None, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, quasirandomness: bool = False, @@ -625,9 +625,10 @@ def precomputed_text_latent_diffusion( a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder checkpoint. Defaults to `1/0.13025`. text_embed_dim (int): The dimension to project the text embeddings to. Default: `4096`. - beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. - Default: `scaled_linear`. - zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. + train_noise_scheduler_params (Dict): Parameters to overried in the training noise scheduler. Anything not + specified will default to SDXL values. Default: `None`. + inference_noise_scheduler_params (Dict): Parameters to overried in the inference noise scheduler. Anything + not specified will default to SDXL values. Default: `None`. train_metrics (list, optional): List of metrics to compute during training. If None, defaults to [MeanSquaredError()]. val_metrics (list, optional): List of metrics to compute during validation. 
If None, defaults to @@ -724,40 +725,40 @@ def precomputed_text_latent_diffusion( resnet._fsdp_wrap = True # Make the noise schedulers - noise_scheduler = DDPMScheduler(num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - variance_type='fixed_small', - clip_sample=False, - prediction_type=prediction_type, - sample_max_value=1.0, - timestep_spacing='leading', - steps_offset=1, - rescale_betas_zero_snr=zero_terminal_snr) - if beta_schedule == 'squaredcos_cap_v2': - inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - rescale_betas_zero_snr=zero_terminal_snr) - else: - inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - prediction_type=prediction_type, - interpolation_type='linear', - use_karras_sigmas=False, - timestep_spacing='leading', - steps_offset=1, - rescale_betas_zero_snr=zero_terminal_snr) + train_scheduler_params: Dict[str, Any] = { + 'num_train_timesteps': 1000, + 'beta_start': 0.00085, + 'beta_end': 0.012, + 'beta_schedule': 'scaled_linear', + 'variance_type': 'fixed_small', + 'clip_sample': False, + 'prediction_type': prediction_type, + 'sample_max_value': 1.0, + 'timestep_spacing': 'leading', + 'steps_offset': 1, + 'rescale_betas_zero_snr': False, + } + if train_noise_scheduler_params is not None: + train_scheduler_params.update(train_noise_scheduler_params) + noise_scheduler = DDPMScheduler(**train_scheduler_params) + + inference_scheduler_params: Dict[str, Any] = { + 'num_train_timesteps': 1000, + 'beta_start': 0.00085, + 'beta_end': 0.012, + 'beta_schedule': 'scaled_linear', + 'trained_betas': None, + 'prediction_type': prediction_type, + 'interpolation_type': 'linear', + 'use_karras_sigmas': False, + 'timestep_spacing': 'leading', + 'steps_offset': 1, + 'rescale_betas_zero_snr': False, + } + + if inference_noise_scheduler_params is not None: + inference_scheduler_params.update(inference_noise_scheduler_params) + inference_noise_scheduler = EulerDiscreteScheduler(**inference_scheduler_params) # Optionally load the tokenizers and text encoders t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None From 5ceee3fd5647ab6657d734b00be79a1d2b55dc74 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 11 Sep 2024 21:13:26 +0000 Subject: [PATCH 20/24] Add schedule shifting --- diffusion/models/models.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 57b7a3ad..8bf71d14 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -594,6 +594,7 @@ def precomputed_text_latent_diffusion( text_embed_dim: int = 4096, train_noise_scheduler_params: Optional[Dict[str, Any]] = None, inference_noise_scheduler_params: Optional[Dict[str, Any]] = None, + scheduler_shift_resolution: int = 256, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, quasirandomness: bool = False, @@ -629,6 +630,7 @@ def precomputed_text_latent_diffusion( specified will default to SDXL values. Default: `None`. inference_noise_scheduler_params (Dict): Parameters to overried in the inference noise scheduler. Anything not specified will default to SDXL values. Default: `None`. 
+ scheduler_shift_resolution (int): The resolution to shift the noise scheduler to. Default: `256`. train_metrics (list, optional): List of metrics to compute during training. If None, defaults to [MeanSquaredError()]. val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to @@ -760,6 +762,14 @@ def precomputed_text_latent_diffusion( inference_scheduler_params.update(inference_noise_scheduler_params) inference_noise_scheduler = EulerDiscreteScheduler(**inference_scheduler_params) + # Shift noise scheduler to correct for resolution changes + noise_scheduler = shift_noise_schedule(noise_scheduler, + base_dim=32, + shift_dim=scheduler_shift_resolution // downsample_factor) + inference_noise_scheduler = shift_noise_schedule(inference_noise_scheduler, + base_dim=32, + shift_dim=scheduler_shift_resolution // downsample_factor) + # Optionally load the tokenizers and text encoders t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None if include_text_encoders: From d950a50376aafa7b4855bbad523d0faab333fbe5 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 11 Sep 2024 21:27:19 +0000 Subject: [PATCH 21/24] Add option for LoRA --- diffusion/models/models.py | 39 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 8bf71d14..ff510cbf 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -602,6 +602,8 @@ def precomputed_text_latent_diffusion( val_seed: int = 1138, fsdp: bool = True, use_xformers: bool = True, + lora_rank: Optional[int] = None, + lora_alpha: Optional[int] = None, ): """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP. @@ -642,6 +644,8 @@ def precomputed_text_latent_diffusion( val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. fsdp (bool): Whether to use FSDP. Defaults to True. use_xformers (bool): Whether to use xformers for attention. Defaults to True. + lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None. + lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None. 
""" latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) @@ -809,6 +813,41 @@ def precomputed_text_latent_diffusion( text_embed_dim=text_embed_dim, fsdp=fsdp, ) + + if lora_rank is not None: + assert lora_alpha is not None + model.unet.requires_grad_(False) + for param in model.unet.parameters(): + param.requires_grad_(False) + + unet_lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + init_lora_weights='gaussian', + target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'], + ) + model.unet.add_adapter(unet_lora_config) + model.unet._fsdp_wrap = True + if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None: + for attention in model.unet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in model.unet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in model.unet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in model.unet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + if torch.cuda.is_available(): model = DeviceGPU().module_to_device(model) if is_xformers_installed and use_xformers: From 40ecb592f4ecea29d4c8c05f7e9d33450dfb3e8a Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sat, 14 Sep 2024 05:21:03 +0000 Subject: [PATCH 22/24] Proper tensor timestep --- diffusion/models/precomputed_text_latent_diffusion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffusion/models/precomputed_text_latent_diffusion.py b/diffusion/models/precomputed_text_latent_diffusion.py index c74a7468..6a747222 100644 --- a/diffusion/models/precomputed_text_latent_diffusion.py +++ b/diffusion/models/precomputed_text_latent_diffusion.py @@ -460,9 +460,11 @@ def generate( for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): latent_model_input = torch.cat([latents] * 2) latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t) + # Make timestep + timestep = t.unsqueeze(0).repeat(latent_model_input.shape[0]).to(device) # Model prediction pred = self.unet(latent_model_input, - t, + timestep, encoder_hidden_states=text_embeddings, encoder_attention_mask=encoder_attn_mask, added_cond_kwargs=added_cond_kwargs).sample From 479fe544fd961c419086c793bed0de4bfa327cb5 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Fri, 27 Sep 2024 22:28:24 +0000 Subject: [PATCH 23/24] Add option for pre-bucketed aspect ratio buckets --- diffusion/datasets/image_caption_latents.py | 24 +++++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py index d019b284..dd2cbcc8 100644 --- a/diffusion/datasets/image_caption_latents.py +++ b/diffusion/datasets/image_caption_latents.py @@ -14,7 +14,8 @@ from torch.utils.data import DataLoader from torchvision import transforms -from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare +from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform, + RandomCropBucketedAspectRatioTransform, RandomCropSquare) from diffusion.datasets.utils import make_streams log = logging.getLogger(__name__) @@ 
-172,6 +173,7 @@ def build_streaming_image_caption_latents_dataloader( text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)), attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), latent_dtype: str = 'torch.bfloat16', + aspect_ratio_bucket_key: Optional[str] = None, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, ): @@ -190,11 +192,12 @@ def build_streaming_image_caption_latents_dataloader( ``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected. Default: ``None``. transform (Callable, optional): The transforms to apply to the image. Default: ``None``. - crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio']. + crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']. Default: ``'square'``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``. caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``. text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset. Default: ``('T5_LATENTS', 'CLIP_LATENTS')``. text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset. @@ -204,18 +207,22 @@ def build_streaming_image_caption_latents_dataloader( Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``. latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32', or 'torch.bfloat16'. Default: ``'torch.bfloat16'``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. + Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``. streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. """ # Check crop type if crop_type is not None: crop_type = crop_type.lower() - if crop_type not in ['square', 'random', 'aspect_ratio']: - raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]') - if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)): + if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']: raise ValueError( - 'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.') - + f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]' + ) + if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or + isinstance(resize_size[0], int)): + raise ValueError( + 'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.') # Check latent dtype dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16} assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. 
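# A hedged usage sketch for the new 'bucketed_aspect_ratio' crop type. The dataset
# location, cache path, batch size, bucket sizes, and bucket column name below are
# illustrative placeholders, and the leading argument names (remote, local,
# batch_size) are assumed from the repo's other dataloader builders.
from diffusion.datasets.image_caption_latents import build_streaming_image_caption_latents_dataloader

dataloader = build_streaming_image_caption_latents_dataloader(
    remote='s3://my-bucket/mds-latents',               # hypothetical streaming dataset location
    local='/tmp/mds-cache',                            # hypothetical local cache path
    batch_size=32,
    resize_size=((256, 256), (192, 320), (320, 192)),  # aspect ratio buckets as (H, W) tuples
    crop_type='bucketed_aspect_ratio',
    aspect_ratio_bucket_key='aspect_ratio_bucket',     # hypothetical column holding each sample's bucket
    latent_dtype='torch.bfloat16',
)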
Must be one of {list(dtypes.keys())}' @@ -237,6 +244,9 @@ def build_streaming_image_caption_latents_dataloader( crop = RandomCropSquare(resize_size) elif crop_type == 'aspect_ratio': crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore + elif crop_type == 'bucketed_aspect_ratio': + assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type' + crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore else: crop = None From e7fcb59e7c6676b8c100554ea8d52177ed0d632f Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sat, 28 Sep 2024 00:27:58 +0000 Subject: [PATCH 24/24] Fix some missing keys --- diffusion/datasets/image_caption_latents.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py index dd2cbcc8..18fb2ffa 100644 --- a/diffusion/datasets/image_caption_latents.py +++ b/diffusion/datasets/image_caption_latents.py @@ -33,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset): image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``. caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``. text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset. Default: ``('T5_LATENTS', 'CLIP_LATENTS')``. text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset. @@ -55,6 +56,7 @@ def __init__( image_key: str = 'image', caption_keys: Tuple[str, ...] = ('caption',), caption_selection_probs: Tuple[float, ...] = (1.0,), + aspect_ratio_bucket_key: Optional[str] = None, text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'), text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)), attention_mask_keys: Tuple[str, ...] 
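# A hedged sketch of the call interface the dataset relies on in __getitem__ below:
# the bucketed transform receives both the image and the sample's precomputed bucket
# value and returns the crop plus its top-left corner. The body is illustrative only
# (it assumes the bucket value indexes into resize_size); the real
# RandomCropBucketedAspectRatioTransform lives in diffusion/datasets/laion/transforms.py
# and may key on the stored aspect ratio instead.
import random
from typing import Sequence, Tuple

import torchvision.transforms.functional as TF
from PIL import Image

class BucketedAspectRatioCropSketch:

    def __init__(self, resize_size: Sequence[Tuple[int, int]]):
        self.resize_size = resize_size  # one (height, width) target per bucket

    def __call__(self, img: Image.Image, bucket: int):
        target_h, target_w = self.resize_size[int(bucket)]  # pick the sample's precomputed bucket
        # Resize so the image covers the target, then take a random crop inside it.
        scale = max(target_h / img.height, target_w / img.width)
        img = img.resize((round(img.width * scale), round(img.height * scale)))
        crop_top = random.randint(0, img.height - target_h)
        crop_left = random.randint(0, img.width - target_w)
        img = TF.crop(img, crop_top, crop_left, target_h, target_w)
        return img, crop_top, crop_left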
= ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), @@ -75,6 +77,9 @@ def __init__( self.image_key = image_key self.caption_keys = caption_keys self.caption_selection_probs = caption_selection_probs + self.aspect_ratio_bucket_key = aspect_ratio_bucket_key + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform' self.text_latent_keys = text_latent_keys self.text_latent_shapes = text_latent_shapes self.attention_mask_keys = attention_mask_keys @@ -94,15 +99,16 @@ def __getitem__(self, index): out['cond_original_size'] = torch.tensor(img.size) # Image transforms - if self.crop is not None: + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key]) + elif self.crop is not None: img, crop_top, crop_left = self.crop(img) else: crop_top, crop_left = 0, 0 - out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) - if self.transform is not None: img = self.transform(img) out['image'] = img + out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) # Get the new height and width if isinstance(img, torch.Tensor): @@ -264,6 +270,7 @@ def build_streaming_image_caption_latents_dataloader( image_key=image_key, caption_keys=caption_keys, caption_selection_probs=caption_selection_probs, + aspect_ratio_bucket_key=aspect_ratio_bucket_key, text_latent_keys=text_latent_keys, text_latent_shapes=text_latent_shapes, attention_mask_keys=attention_mask_keys,
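# A hedged sketch of what a collated batch from this dataloader is expected to
# contain, continuing the builder example above and assuming the default
# text_latent_keys / attention_mask_keys plus an image transform that yields
# tensors; latent shapes follow the builder's documented defaults.
batch = next(iter(dataloader))
print(batch['image'].shape)                 # (B, C, H, W) cropped, transformed images
print(batch['T5_LATENTS'].shape)            # (B, 512, 4096) precomputed T5-XXL text latents
print(batch['CLIP_LATENTS'].shape)          # (B, 77, 768) precomputed CLIP text latents
print(batch['T5_ATTENTION_MASK'].shape)     # attention mask matching the T5 latent sequence
print(batch['CLIP_ATTENTION_MASK'].shape)   # attention mask matching the CLIP latent sequence
print(batch['cond_original_size'])          # (B, 2) original (width, height) per image
print(batch['cond_crops_coords_top_left'])  # (B, 2) crop offsets for SDXL-style size/crop conditioning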