From c84836c0df673af9395d65b96a349ef28419be4d Mon Sep 17 00:00:00 2001
From: rishab-partha
Date: Wed, 21 Aug 2024 06:05:16 +0000
Subject: [PATCH] add fsdp wrapping

---
 diffusion/models/models.py | 40 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 8fd5a33d..737b125b 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -233,6 +233,26 @@ def stable_diffusion_2(
             target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
         )
         model.unet.add_adapter(unet_lora_config)
+        model.unet._fsdp_wrap = True
+        if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
+            for attention in model.unet.mid_block.attentions:
+                attention._fsdp_wrap = True
+            for resnet in model.unet.mid_block.resnets:
+                resnet._fsdp_wrap = True
+        for block in model.unet.up_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+        for block in model.unet.down_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
     if torch.cuda.is_available():
         model = DeviceGPU().module_to_device(model)
         if is_xformers_installed and use_xformers:
@@ -518,6 +538,26 @@ def stable_diffusion_xl(
             target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
         )
         model.unet.add_adapter(unet_lora_config)
+        model.unet._fsdp_wrap = True
+        if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
+            for attention in model.unet.mid_block.attentions:
+                attention._fsdp_wrap = True
+            for resnet in model.unet.mid_block.resnets:
+                resnet._fsdp_wrap = True
+        for block in model.unet.up_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+        for block in model.unet.down_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
     if torch.cuda.is_available():
         model = DeviceGPU().module_to_device(model)
         if is_xformers_installed and use_xformers:
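
Note (not part of the patch): Composer's FSDP integration treats a truthy `_fsdp_wrap`
attribute on a submodule as a request to shard that submodule as its own FSDP unit, which
is why the patch tags the UNet and its attention/resnet blocks right after `add_adapter`
installs the LoRA layers. The snippet below is a minimal sketch of how such flags could be
consumed with plain PyTorch FSDP; it assumes torch.distributed is already initialized, and
the helper names (`_should_wrap`, `wrap_unet`) are illustrative only, not code from
Composer or this repo.

    import functools

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


    def _should_wrap(module: nn.Module) -> bool:
        # Wrap any submodule that the patch tagged with `_fsdp_wrap = True`.
        return bool(getattr(module, '_fsdp_wrap', False))


    def wrap_unet(unet: nn.Module) -> FSDP:
        # Assumes a torch.distributed process group is initialized; Composer
        # normally handles this and applies its own wrapping policy.
        policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=_should_wrap)
        return FSDP(unet, auto_wrap_policy=policy)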