Skip to content

Commit

Permalink
add fsdp wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
rishab-partha committed Aug 21, 2024
1 parent 500d523 commit c84836c
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions diffusion/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,26 @@ def stable_diffusion_2(
target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
)
model.unet.add_adapter(unet_lora_config)
model.unet._fsdp_wrap = True
if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
for attention in model.unet.mid_block.attentions:
attention._fsdp_wrap = True
for resnet in model.unet.mid_block.resnets:
resnet._fsdp_wrap = True
for block in model.unet.up_blocks:
if hasattr(block, 'attentions'):
for attention in block.attentions:
attention._fsdp_wrap = True
if hasattr(block, 'resnets'):
for resnet in block.resnets:
resnet._fsdp_wrap = True
for block in model.unet.down_blocks:
if hasattr(block, 'attentions'):
for attention in block.attentions:
attention._fsdp_wrap = True
if hasattr(block, 'resnets'):
for resnet in block.resnets:
resnet._fsdp_wrap = True

if torch.cuda.is_available():
model = DeviceGPU().module_to_device(model)
Expand Down Expand Up @@ -518,6 +538,26 @@ def stable_diffusion_xl(
target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
)
model.unet.add_adapter(unet_lora_config)
model.unet._fsdp_wrap = True
if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
for attention in model.unet.mid_block.attentions:
attention._fsdp_wrap = True
for resnet in model.unet.mid_block.resnets:
resnet._fsdp_wrap = True
for block in model.unet.up_blocks:
if hasattr(block, 'attentions'):
for attention in block.attentions:
attention._fsdp_wrap = True
if hasattr(block, 'resnets'):
for resnet in block.resnets:
resnet._fsdp_wrap = True
for block in model.unet.down_blocks:
if hasattr(block, 'attentions'):
for attention in block.attentions:
attention._fsdp_wrap = True
if hasattr(block, 'resnets'):
for resnet in block.resnets:
resnet._fsdp_wrap = True
if torch.cuda.is_available():
model = DeviceGPU().module_to_device(model)
if is_xformers_installed and use_xformers:
Expand Down

0 comments on commit c84836c

Please sign in to comment.