Add composer model class for running with precomputed CLIP and T5 text latents #171
Open
coryMosaicML wants to merge 24 commits into mosaicml:main from coryMosaicML:t5-model (base: main)
Changes from all commits (24 commits)
Commits (all by corystephenson-db):

3459613  Initial model class
4b5b227  Support truncating embeddings
39ed5f7  Truncate before embedding, and add position embeds
8ef76e6  Don't need an arg for the loss
402b1a5  Prep for inference
19cc8fb  Prep for string inputs
101c353  Don't need to check for negative prompt existing
3402223  Timesteps shall be ints
72476c2  Fix off by one
363edd1  Add layernorms before sequence concat
260eb6c  Changes for running with bf16
e30fa2b  Update docstrings and fix types
ca94d5b  Drop nans
3d7b65e  Clean up names
7bc1796  Fix depreciation
4c36a06  More name changes
040afae  Fixes for running inference
9abc15b  Update docstrings
ebacb59  Configurable schedulers
5ceee3f  Add schedule shifting
d950a50  Add option for LoRA
40ecb59  Proper tensor timestep
479fe54  Add option for pre-bucketed aspect ratio buckets
e7fcb59  Fix some missing keys
@@ -14,7 +14,8 @@
 from torch.utils.data import DataLoader
 from torchvision import transforms

-from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
+from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
+                                                 RandomCropBucketedAspectRatioTransform, RandomCropSquare)
 from diffusion.datasets.utils import make_streams

 log = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
         caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
         text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
             Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
         text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -40,6 +42,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
         attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
             Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
         latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
+        drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``.
         **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
     """
@@ -53,10 +56,12 @@ def __init__(
         image_key: str = 'image',
         caption_keys: Tuple[str, ...] = ('caption',),
         caption_selection_probs: Tuple[float, ...] = (1.0,),
+        aspect_ratio_bucket_key: Optional[str] = None,
         text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
         text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
         attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
         latent_dtype: torch.dtype = torch.bfloat16,
+        drop_nans: bool = True,
         **streaming_kwargs,
     ):
@@ -72,10 +77,14 @@ def __init__(
         self.image_key = image_key
         self.caption_keys = caption_keys
         self.caption_selection_probs = caption_selection_probs
+        self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
+            assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
         self.text_latent_keys = text_latent_keys
         self.text_latent_shapes = text_latent_shapes
         self.attention_mask_keys = attention_mask_keys
         self.latent_dtype = latent_dtype
+        self.drop_nans = drop_nans

     def __getitem__(self, index):
         sample = super().__getitem__(index)
@@ -90,15 +99,16 @@ def __getitem__(self, index):
         out['cond_original_size'] = torch.tensor(img.size)

         # Image transforms
-        if self.crop is not None:
+        if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
+            img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
+        elif self.crop is not None:
             img, crop_top, crop_left = self.crop(img)
         else:
             crop_top, crop_left = 0, 0
-        out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])
-
         if self.transform is not None:
             img = self.transform(img)
         out['image'] = img
+        out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

         # Get the new height and width
         if isinstance(img, torch.Tensor):
@@ -140,6 +150,15 @@ def __getitem__(self, index):
             if 'CLIP_LATENTS' in latent_key:
                 clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
                 out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
+        if self.drop_nans:
+            if out['CLIP_POOLED'].isnan().any():
+                out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])
+            if out['T5_LATENTS'].isnan().any():
+                out['T5_LATENTS'] = torch.zeros_like(out['T5_LATENTS'])
+                out['T5_ATTENTION_MASK'] = torch.zeros_like(out['T5_ATTENTION_MASK'])
+            if out['CLIP_LATENTS'].isnan().any():
+                out['CLIP_LATENTS'] = torch.zeros_like(out['CLIP_LATENTS'])
+                out['CLIP_ATTENTION_MASK'] = torch.zeros_like(out['CLIP_ATTENTION_MASK'])
         return out

Review comment on the ``drop_nans`` block: this will work for us, but this section will fail if the text latent keys are named differently; probably safer to loop through the keys.
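The reviewer suggests looping through the configured latent keys rather than hardcoding them. A standalone sketch of that idea follows; the `drop_nan_latents` helper and the assumption that each `*_LATENTS` key has a matching `*_ATTENTION_MASK` key are illustrative, not part of the PR:

```python
import torch


def drop_nan_latents(out, text_latent_keys=('T5_LATENTS', 'CLIP_LATENTS')):
    """Zero out any text latent (and its attention mask) that contains NaNs.

    Loops over the configured keys instead of hardcoding them. Assumes
    (hypothetically) that every '<NAME>_LATENTS' key has a matching
    '<NAME>_ATTENTION_MASK' key.
    """
    if 'CLIP_POOLED' in out and out['CLIP_POOLED'].isnan().any():
        out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])
    for key in text_latent_keys:
        if out[key].isnan().any():
            out[key] = torch.zeros_like(out[key])
            mask_key = key.replace('_LATENTS', '_ATTENTION_MASK')
            if mask_key in out:
                out[mask_key] = torch.zeros_like(out[mask_key])
    return out


# Tiny demo: CLIP latents contain NaNs, T5 latents are clean.
sample = {
    'CLIP_POOLED': torch.tensor([float('nan'), 1.0]),
    'T5_LATENTS': torch.ones(2, 4),
    'T5_ATTENTION_MASK': torch.ones(2),
    'CLIP_LATENTS': torch.full((2, 4), float('nan')),
    'CLIP_ATTENTION_MASK': torch.ones(2),
}
sample = drop_nan_latents(sample)
print(sample['CLIP_POOLED'].tolist())             # [0.0, 0.0]
print(sample['CLIP_LATENTS'].abs().sum().item())  # 0.0
print(sample['T5_LATENTS'].sum().item())          # 8.0 (untouched)
```

With this shape, renaming or extending `text_latent_keys` would not require touching the NaN-handling code.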
@@ -160,6 +179,7 @@ def build_streaming_image_caption_latents_dataloader(
     text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
     attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
     latent_dtype: str = 'torch.bfloat16',
+    aspect_ratio_bucket_key: Optional[str] = None,
     streaming_kwargs: Optional[Dict] = None,
     dataloader_kwargs: Optional[Dict] = None,
 ):
@@ -178,11 +198,12 @@ def build_streaming_image_caption_latents_dataloader(
             ``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
             Default: ``None``.
         transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
-        crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
+        crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
             Default: ``'square'``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
         caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
         text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
             Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
         text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -192,18 +213,22 @@ def build_streaming_image_caption_latents_dataloader(
             Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
         latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
             or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
+        aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
+            Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
         streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
         dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
     """
     # Check crop type
     if crop_type is not None:
         crop_type = crop_type.lower()
-        if crop_type not in ['square', 'random', 'aspect_ratio']:
-            raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
-        if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
-            raise ValueError(
-                'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')
+        if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
+            raise ValueError(
+                f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
+            )
+        if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
+                                                                       isinstance(resize_size[0], int)):
+            raise ValueError(
+                'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
     # Check latent dtype
     dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
     assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
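The new validation accepts `bucketed_aspect_ratio` and requires a tuple-of-tuples `resize_size` for either aspect-ratio mode. The sketch below mirrors those checks in isolation; the `validate_crop_config` helper is illustrative, not part of the PR:

```python
VALID_CROPS = ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']


def validate_crop_config(crop_type, resize_size):
    """Standalone mirror of the crop_type checks added in this PR (illustrative)."""
    if crop_type is None:
        return
    crop_type = crop_type.lower()
    if crop_type not in VALID_CROPS:
        raise ValueError(f'Invalid crop_type: {crop_type}. Must be one of {VALID_CROPS} or None')
    # Both aspect-ratio modes need buckets: resize_size must be a tuple of (H, W) tuples.
    if crop_type in ('aspect_ratio', 'bucketed_aspect_ratio') and (isinstance(resize_size, int) or
                                                                   isinstance(resize_size[0], int)):
        raise ValueError('If using aspect ratio bucketing, specify aspect ratio buckets in '
                         'resize_size as a tuple of tuples.')


validate_crop_config('square', 256)                                      # ok: plain size
validate_crop_config('bucketed_aspect_ratio', ((256, 256), (320, 192)))  # ok: buckets
try:
    validate_crop_config('bucketed_aspect_ratio', 256)  # rejected: no buckets given
except ValueError as err:
    print(err)
```

This keeps the error behavior identical for the pre-existing `aspect_ratio` mode while extending it to the bucketed variant.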
@@ -225,6 +250,9 @@ def build_streaming_image_caption_latents_dataloader(
         crop = RandomCropSquare(resize_size)
     elif crop_type == 'aspect_ratio':
         crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries)  # type: ignore
+    elif crop_type == 'bucketed_aspect_ratio':
+        assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
+        crop = RandomCropBucketedAspectRatioTransform(resize_size)  # type: ignore
     else:
         crop = None
@@ -242,6 +270,7 @@ def build_streaming_image_caption_latents_dataloader(
         image_key=image_key,
         caption_keys=caption_keys,
         caption_selection_probs=caption_selection_probs,
+        aspect_ratio_bucket_key=aspect_ratio_bucket_key,
         text_latent_keys=text_latent_keys,
         text_latent_shapes=text_latent_shapes,
         attention_mask_keys=attention_mask_keys,
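Putting the pieces together, a call to the builder using the new bucketed option might look like the following sketch. The remote/local paths, batch size, bucket shapes, and the `'aspect_ratio_bucket'` key name are placeholder assumptions, not values from the PR:

```python
dataloader = build_streaming_image_caption_latents_dataloader(
    remote='s3://my-bucket/laion-latents',             # placeholder shard location
    local='/tmp/laion-latents',                        # placeholder local cache
    batch_size=8,
    resize_size=((256, 256), (320, 192), (192, 320)),  # aspect ratio buckets: tuple of tuples
    crop_type='bucketed_aspect_ratio',
    aspect_ratio_bucket_key='aspect_ratio_bucket',     # placeholder key name in the shards
    latent_dtype='torch.bfloat16',
)
```

Note that `aspect_ratio_bucket_key` must be set whenever `crop_type='bucketed_aspect_ratio'`; both the builder and the dataset assert this.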
Review comment: ideally these key names wouldn't be hardcoded and would instead be grabbed from the dataset class, but I think this is OK for now.