Add composer model class for running with precomputed CLIP and T5 text latents #171

Open · wants to merge 24 commits into base: main

Commits (24)
3459613  Initial model class (corystephenson-db, Aug 28, 2024)
4b5b227  Support truncating embeddings (corystephenson-db, Aug 29, 2024)
39ed5f7  Truncate before embedding, and add position embeds (corystephenson-db, Aug 29, 2024)
8ef76e6  Don't need an arg for the loss (corystephenson-db, Aug 29, 2024)
402b1a5  Prep for inference (corystephenson-db, Aug 29, 2024)
19cc8fb  Prep for string inputs (corystephenson-db, Aug 30, 2024)
101c353  Don't need to check for negative prompt existing (corystephenson-db, Aug 30, 2024)
3402223  Timesteps shall be ints (corystephenson-db, Aug 30, 2024)
72476c2  Fix off by one (corystephenson-db, Aug 30, 2024)
363edd1  Add layernorms before sequence concat (corystephenson-db, Sep 1, 2024)
260eb6c  Changes for running with bf16 (corystephenson-db, Sep 2, 2024)
e30fa2b  Update docstrings and fix types (corystephenson-db, Sep 9, 2024)
ca94d5b  Drop nans (corystephenson-db, Sep 9, 2024)
3d7b65e  Clean up names (corystephenson-db, Sep 10, 2024)
7bc1796  Fix depreciation (corystephenson-db, Sep 10, 2024)
4c36a06  More name changes (corystephenson-db, Sep 11, 2024)
040afae  Fixes for running inference (corystephenson-db, Sep 11, 2024)
9abc15b  Update docstrings (corystephenson-db, Sep 11, 2024)
ebacb59  Configurable schedulers (corystephenson-db, Sep 11, 2024)
5ceee3f  Add schedule shifting (corystephenson-db, Sep 11, 2024)
d950a50  Add option for LoRA (corystephenson-db, Sep 11, 2024)
40ecb59  Proper tensor timestep (corystephenson-db, Sep 14, 2024)
479fe54  Add option for pre-bucketed aspect ratio buckets (corystephenson-db, Sep 27, 2024)
e7fcb59  Fix some missing keys (corystephenson-db, Sep 28, 2024)
12 changes: 6 additions & 6 deletions diffusion/callbacks/log_diffusion_images.py
@@ -121,8 +121,9 @@ def __init__(self,
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch['T5_LATENTS'] = t5_latents
latent_batch['T5_ATTENTION_MASK'] = t5_attention_mask
Contributor: Ideally these key names wouldn't be hardcoded and would instead be grabbed from the dataset class, but I think this is OK for now.
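
A minimal sketch of that suggestion, assuming the callback holds a reference to the dataset instance (the `dataset` variable and its wiring into the callback are hypothetical; `text_latent_keys` and `attention_mask_keys` are the attributes this PR adds to the dataset class, and the loop assumes their default T5-then-CLIP ordering):

    # Hypothetical: read the key names from the dataset instead of hardcoding them.
    # Assumes dataset.text_latent_keys == ('T5_LATENTS', 'CLIP_LATENTS') ordering.
    for latent_key, mask_key, latents, mask in zip(dataset.text_latent_keys,
                                                   dataset.attention_mask_keys,
                                                   (t5_latents, clip_latents),
                                                   (t5_attention_mask, clip_attention_mask)):
        latent_batch[latent_key] = latents
        latent_batch[mask_key] = mask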

latent_batch['CLIP_LATENTS'] = clip_latents
latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
latent_batch['CLIP_ATTENTION_MASK'] = clip_attention_mask
latent_batch['CLIP_POOLED'] = clip_pooled
self.batched_latents.append(latent_batch)

@@ -144,11 +145,10 @@ def eval_start(self, state: State, logger: Logger):
if self.precomputed_latents:
for batch in self.batched_latents:
pooled_prompt = batch['CLIP_POOLED'].cuda()
prompt_mask = batch['ATTENTION_MASK'].cuda()
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)

prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch['T5_LATENTS'].cuda(),
batch['CLIP_LATENTS'].cuda(),
batch['T5_ATTENTION_MASK'].cuda(),
batch['CLIP_ATTENTION_MASK'].cuda())
gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
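The new `model.prepare_text_embeddings` call replaces the manual project-and-concat logic it deletes above. A minimal sketch of what such a helper could look like, inferred from the replaced code and the "Add layernorms before sequence concat" commit (the `t5_ln`/`clip_ln` attribute names are assumptions, not taken from this diff):

    def prepare_text_embeddings(self, t5_latents, clip_latents, t5_mask, clip_mask):
        """Sketch: project each text encoder's latents, normalize, then concatenate."""
        t5_embeds = self.t5_proj(t5_latents)
        clip_embeds = self.clip_proj(clip_latents)
        # Layernorm each stream before concatenating along the sequence dimension.
        t5_embeds = self.t5_ln(t5_embeds)
        clip_embeds = self.clip_ln(clip_embeds)
        prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)
        prompt_mask = torch.cat([t5_mask, clip_mask], dim=1)
        return prompt_embeds, prompt_mask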
49 changes: 39 additions & 10 deletions diffusion/datasets/image_caption_latents.py
@@ -14,7 +14,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -40,6 +42,7 @@
attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``.
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
"""

@@ -53,10 +56,12 @@ def __init__(
image_key: str = 'image',
caption_keys: Tuple[str, ...] = ('caption',),
caption_selection_probs: Tuple[float, ...] = (1.0,),
aspect_ratio_bucket_key: Optional[str] = None,
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: torch.dtype = torch.bfloat16,
drop_nans: bool = True,
**streaming_kwargs,
):

@@ -72,10 +77,14 @@ def __init__(
self.image_key = image_key
self.caption_keys = caption_keys
self.caption_selection_probs = caption_selection_probs
self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
self.text_latent_keys = text_latent_keys
self.text_latent_shapes = text_latent_shapes
self.attention_mask_keys = attention_mask_keys
self.latent_dtype = latent_dtype
self.drop_nans = drop_nans

def __getitem__(self, index):
sample = super().__getitem__(index)
@@ -90,15 +99,16 @@ def __getitem__(self, index):
out['cond_original_size'] = torch.tensor(img.size)

# Image transforms
if self.crop is not None:
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
elif self.crop is not None:
img, crop_top, crop_left = self.crop(img)
else:
crop_top, crop_left = 0, 0
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

if self.transform is not None:
img = self.transform(img)
out['image'] = img
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

# Get the new height and width
if isinstance(img, torch.Tensor):
@@ -140,6 +150,15 @@ def __getitem__(self, index):
if 'CLIP_LATENTS' in latent_key:
clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
if self.drop_nans:
if out['CLIP_POOLED'].isnan().any():
Contributor: This will work for us, but this section will fail if the text latent keys are named differently. It's probably safer to loop through self.text_latent_keys and self.attention_mask_keys, e.g. something like:

if self.drop_nans:
    for latent_key, attn_key in zip(self.text_latent_keys, self.attention_mask_keys):
        if out[latent_key].isnan().any():
            out[latent_key] = torch.zeros_like(out[latent_key])
            out[attn_key] = torch.zeros_like(out[attn_key])
        if 'CLIP_LATENTS' in latent_key:
            out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])

out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])
if out['T5_LATENTS'].isnan().any():
out['T5_LATENTS'] = torch.zeros_like(out['T5_LATENTS'])
out['T5_ATTENTION_MASK'] = torch.zeros_like(out['T5_ATTENTION_MASK'])
if out['CLIP_LATENTS'].isnan().any():
out['CLIP_LATENTS'] = torch.zeros_like(out['CLIP_LATENTS'])
out['CLIP_ATTENTION_MASK'] = torch.zeros_like(out['CLIP_ATTENTION_MASK'])
return out
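
For orientation, the keys a sample from this dataset carries after `__getitem__` (shapes reflect the default `text_latent_shapes`; this listing is illustrative and hunks not shown here may populate further keys):

    sample = dataset[0]
    sample['image']                        # transformed image (tensor or PIL image)
    sample['cond_original_size']           # (2,) original (width, height)
    sample['cond_crops_coords_top_left']   # (2,) crop offsets
    sample['T5_LATENTS']                   # (512, 4096), zeroed if NaN and drop_nans=True
    sample['T5_ATTENTION_MASK']
    sample['CLIP_LATENTS']                 # (77, 768)
    sample['CLIP_ATTENTION_MASK']
    sample['CLIP_POOLED']                  # (768,) pooled CLIP embedding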


@@ -160,6 +179,7 @@ def build_streaming_image_caption_latents_dataloader(
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: str = 'torch.bfloat16',
aspect_ratio_bucket_key: Optional[str] = None,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
@@ -178,11 +198,12 @@
``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
Default: ``None``.
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
Default: ``'square'``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -192,18 +213,22 @@
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
# Check crop type
if crop_type is not None:
crop_type = crop_type.lower()
if crop_type not in ['square', 'random', 'aspect_ratio']:
raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')
if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
raise ValueError(
f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
)
if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
isinstance(resize_size[0], int)):
raise ValueError(
'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
# Check latent dtype
dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
@@ -225,6 +250,9 @@
crop = RandomCropSquare(resize_size)
elif crop_type == 'aspect_ratio':
crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore
elif crop_type == 'bucketed_aspect_ratio':
assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore
else:
crop = None

@@ -242,6 +270,7 @@
image_key=image_key,
caption_keys=caption_keys,
caption_selection_probs=caption_selection_probs,
aspect_ratio_bucket_key=aspect_ratio_bucket_key,
text_latent_keys=text_latent_keys,
text_latent_shapes=text_latent_shapes,
attention_mask_keys=attention_mask_keys,
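Putting the new pieces together, a hedged usage example of the bucketed crop type (the remote path, batch size, bucket shapes, and bucket key name are placeholders; `remote` and `batch_size` are assumed to follow the repo's other dataloader builders, as they are not visible in this hunk):

    from diffusion.datasets.image_caption_latents import build_streaming_image_caption_latents_dataloader

    # resize_size must be a tuple of (height, width) buckets for the aspect-ratio
    # crop types; each sample is routed to the bucket stored under
    # aspect_ratio_bucket_key in the streaming dataset.
    train_dataloader = build_streaming_image_caption_latents_dataloader(
        remote='s3://my-bucket/mds-latents',               # placeholder path
        batch_size=32,                                     # placeholder
        resize_size=((512, 256), (512, 512), (256, 512)),  # placeholder buckets
        crop_type='bucketed_aspect_ratio',
        aspect_ratio_bucket_key='aspect_ratio_bucket',     # assumed column name
        latent_dtype='torch.bfloat16',
    )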
7 changes: 5 additions & 2 deletions diffusion/models/__init__.py
@@ -4,10 +4,11 @@
"""Diffusion models."""

from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
text_to_image_transformer)
discrete_pixel_diffusion, precomputed_text_latent_diffusion, stable_diffusion_2,
stable_diffusion_xl, text_to_image_transformer)
from diffusion.models.noop import NoOpModel
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
from diffusion.models.stable_diffusion import StableDiffusion

__all__ = [
@@ -17,8 +18,10 @@
'discrete_pixel_diffusion',
'NoOpModel',
'PixelDiffusion',
'precomputed_text_latent_diffusion',
'stable_diffusion_2',
'stable_diffusion_xl',
'StableDiffusion',
'PrecomputedTextLatentDiffusion',
'text_to_image_transformer',
]
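
With these exports, the new model class and factory are importable from the package root. The factory's signature lives in diffusion/models/models.py and is not part of this diff, so the call below is schematic only:

    from diffusion.models import precomputed_text_latent_diffusion

    # Schematic: presumably returns a composer-wrapped PrecomputedTextLatentDiffusion;
    # see precomputed_text_latent_diffusion in diffusion/models/models.py for the
    # actual arguments (e.g. the scheduler and LoRA options added in this PR).
    model = precomputed_text_latent_diffusion()  # arguments omitted here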