Skip to content

Commit

Permalink
Make latent dtype configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Aug 26, 2024
1 parent 054e853 commit aea93ea
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions diffusion/datasets/image_caption_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,25 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
Default: ``((512, 4096), (77, 768))``.
attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
**streaming_kwargs: Additional arguments to pass in the construction of the ``StreamingDataset``.
"""

def __init__(
self,
streams: Sequence[Stream],
caption_drop_prob: float = 0.0,
microcond_drop_prob: float = 0.0,
crop: Optional[Callable] = None,
transform: Optional[Callable] = None,
image_key: str = 'image',
caption_keys: Tuple[str, ...] = ('caption',),
caption_selection_probs: Tuple[float, ...] = (1.0,),
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
**streaming_kwargs,
self,
streams: Sequence[Stream],
caption_drop_prob: float = 0.0,
microcond_drop_prob: float = 0.0,
crop: Optional[Callable] = None,
transform: Optional[Callable] = None,
image_key: str = 'image',
caption_keys: Tuple[str, ...] = ('caption',),
caption_selection_probs: Tuple[float, ...] = (1.0,),
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: torch.dtype = torch.bfloat16,
**streaming_kwargs,
):

# Set defaults for vision-friendly streaming args.
Expand All @@ -73,6 +75,7 @@ def __init__(
self.text_latent_keys = text_latent_keys
self.text_latent_shapes = text_latent_shapes
self.attention_mask_keys = attention_mask_keys
self.latent_dtype = latent_dtype

def __getitem__(self, index):
sample = super().__getitem__(index)
Expand Down Expand Up @@ -124,18 +127,19 @@ def __getitem__(self, index):
attention_key = f'{caption_key}_{self.attention_mask_keys[i]}'

if torch.rand(1) < self.caption_drop_prob:
out[self.text_latent_keys[i]] = torch.zeros(latent_shape, dtype=torch.float16)
out[self.text_latent_keys[i]] = torch.zeros(latent_shape, dtype=self.latent_dtype)
out[self.attention_mask_keys[i]] = torch.zeros(latent_shape[0])
if 'CLIP_LATENTS' in latent_key:
out['CLIP_POOLED'] = torch.zeros(latent_shape[1])
else:
text_latent = np.frombuffer(sample[latent_key], dtype=np.float16).copy()
out[self.text_latent_keys[i]] = torch.from_numpy(text_latent).reshape(latent_shape)
text_latent = np.frombuffer(sample[latent_key], dtype=np.float32).copy()
out[self.text_latent_keys[i]] = torch.from_numpy(text_latent).to(
self.latent_dtype).reshape(latent_shape)
attention_mask = np.frombuffer(sample[attention_key], dtype=np.bool_).copy()
out[self.attention_mask_keys[i]] = torch.from_numpy(attention_mask).to(dtype=torch.float).reshape(-1) #.reshape(latent_shape[0])
if 'CLIP_LATENTS' in latent_key:
clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float16).copy()
out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).reshape(latent_shape[1])
clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
return out


Expand All @@ -155,6 +159,7 @@ def build_streaming_image_caption_latents_dataloader(
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: str = 'torch.bfloat16',
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
Expand Down Expand Up @@ -185,6 +190,8 @@ def build_streaming_image_caption_latents_dataloader(
Default: ``((512, 4096), (77, 768))``.
attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
Expand All @@ -197,6 +204,11 @@ def build_streaming_image_caption_latents_dataloader(
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

# Check latent dtype
dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
dtype = dtypes[latent_dtype]

# Handle ``None`` kwargs
if streaming_kwargs is None:
streaming_kwargs = {}
Expand Down Expand Up @@ -233,6 +245,7 @@ def build_streaming_image_caption_latents_dataloader(
text_latent_keys=text_latent_keys,
text_latent_shapes=text_latent_shapes,
attention_mask_keys=attention_mask_keys,
latent_dtype=dtype,
**streaming_kwargs,
)

Expand Down

0 comments on commit aea93ea

Please sign in to comment.