Add latent logger for T5-XXL text encoder #154

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Changes from all commits
Commits
113 changes: 101 additions & 12 deletions diffusion/callbacks/log_diffusion_images.py
@@ -3,13 +3,15 @@

"""Logger for generated images."""

import gc
from math import ceil
from typing import List, Optional, Tuple, Union

import torch
from composer import Callback, Logger, State
from composer.core import TimeUnit, get_precision_context
from torch.nn.parallel import DistributedDataParallel
from transformers import AutoModel, AutoTokenizer, CLIPTextModel


class LogDiffusionImages(Callback):
@@ -35,6 +37,9 @@ class LogDiffusionImages(Callback):
seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
Default: ``1138``.
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
t5_encoder (str, optional): path to the T5 encoder to use as the second text encoder.
clip_encoder (str, optional): path to the CLIP encoder to use as the first text encoder.
cache_dir (str, optional): path for HuggingFace to cache files while downloading the model. Default: ``/tmp/hf_files``.
"""

def __init__(self,
@@ -45,14 +50,18 @@ def __init__(self,
guidance_scale: float = 0.0,
rescaled_guidance: Optional[float] = None,
seed: Optional[int] = 1138,
use_table: bool = False):
use_table: bool = False,
t5_encoder: Optional[str] = None,
clip_encoder: Optional[str] = None,
cache_dir: Optional[str] = '/tmp/hf_files'):
self.prompts = prompts
self.size = (size, size) if isinstance(size, int) else size
self.num_inference_steps = num_inference_steps
self.guidance_scale = guidance_scale
self.rescaled_guidance = rescaled_guidance
self.seed = seed
self.use_table = use_table
self.cache_dir = cache_dir

# Batch prompts
batch_size = len(prompts) if batch_size is None else batch_size
@@ -62,6 +71,66 @@ def __init__(self,
start, end = i * batch_size, (i + 1) * batch_size
self.batched_prompts.append(prompts[start:end])

if (t5_encoder is not None and clip_encoder is None) or (t5_encoder is None and clip_encoder is not None):
raise ValueError('Cannot specify only one of t5_encoder and clip_encoder; provide both or neither.')

self.precomputed_latents = False
self.batched_latents = []
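# If both encoders are provided, precompute the text latents for every prompt batch
# here so the full T5-XXL and CLIP text encoders never need to be loaded during eval.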
if t5_encoder:
self.precomputed_latents = True
t5_tokenizer = AutoTokenizer.from_pretrained(t5_encoder, cache_dir=self.cache_dir, local_files_only=True)
clip_tokenizer = AutoTokenizer.from_pretrained(clip_encoder,
subfolder='tokenizer',
cache_dir=self.cache_dir,
local_files_only=True)

t5_model = AutoModel.from_pretrained(t5_encoder,
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).encoder.cuda().eval()
clip_model = CLIPTextModel.from_pretrained(clip_encoder,
subfolder='text_encoder',
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).cuda().eval()

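# Tokenize and encode each prompt batch once, keeping the resulting latents on CPU
# so they can be cached for the lifetime of the callback.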
for batch in self.batched_prompts:
latent_batch = {}
tokenized_t5 = t5_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
t5_ids = tokenized_t5['input_ids'].cuda()
t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
t5_attention_mask = t5_attention_mask.cpu().to(torch.long)

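# CLIP: use the penultimate hidden state as the sequence embedding and the pooled
# output for conditioning.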
tokenized_clip = clip_tokenizer(batch,
padding='max_length',
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_attention_mask = tokenized_clip['attention_mask'].cuda()
clip_ids = tokenized_clip['input_ids'].cuda()
clip_outputs = clip_model(input_ids=clip_ids,
attention_mask=clip_attention_mask,
output_hidden_states=True)
clip_latents = clip_outputs.hidden_states[-2].cpu()
clip_pooled = clip_outputs[1].cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

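# Cache everything on CPU; the T5 and CLIP attention masks are concatenated so that
# generate() receives a single combined prompt mask.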
latent_batch['T5_LATENTS'] = t5_latents
latent_batch['CLIP_LATENTS'] = clip_latents
latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
latent_batch['CLIP_POOLED'] = clip_pooled
self.batched_latents.append(latent_batch)

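# The text encoders are no longer needed once the latents are cached; free their GPU memory.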
del t5_model
del clip_model
gc.collect()
torch.cuda.empty_cache()

def eval_start(self, state: State, logger: Logger):
# Get the model object if it has been wrapped by DDP to access the image generation function.
if isinstance(state.model, DistributedDataParallel):
@@ -72,17 +141,37 @@ def eval_start(self, state: State, logger: Logger):
# Generate images
with get_precision_context(state.precision):
all_gen_images = []
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
if self.precomputed_latents:
for batch in self.batched_latents:
pooled_prompt = batch['CLIP_POOLED'].cuda()
prompt_mask = batch['ATTENTION_MASK'].cuda()
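# Project the cached encoder outputs through the model's text projection layers and
# concatenate along the sequence dimension to form the prompt embeddings.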
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)

gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
else:
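# No precomputed latents; fall back to encoding the raw prompt strings inside generate().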
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
gen_images = torch.cat(all_gen_images)

# Log images to wandb
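
As a usage note, here is a minimal sketch of how the new arguments could be passed when constructing the callback. The checkpoint ids and prompt below are placeholders rather than values from this PR: t5_encoder can be any T5 checkpoint and clip_encoder any diffusers-style repo containing tokenizer and text_encoder subfolders, and both must already be present in cache_dir because the callback loads with local_files_only=True.

from diffusion.callbacks.log_diffusion_images import LogDiffusionImages

# Placeholder checkpoint ids and prompt; substitute whatever encoders the training run uses.
image_logger = LogDiffusionImages(
    prompts=['a photo of a corgi wearing a top hat'],
    size=1024,
    num_inference_steps=30,
    guidance_scale=7.0,
    seed=1138,
    t5_encoder='google/t5-v1_1-xxl',
    clip_encoder='stabilityai/stable-diffusion-xl-base-1.0',
    cache_dir='/tmp/hf_files',
)

Passing only one of t5_encoder and clip_encoder raises a ValueError, and omitting both preserves the previous behavior of encoding prompts inside model.generate().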