diff --git a/fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py b/fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
index 90116776..c0a2749f 100644
--- a/fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
+++ b/fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
@@ -12,10 +12,12 @@
 import logging
 import os
+from copy import deepcopy
 
 import torch
 from tqdm.autonotebook import tqdm
 
+from fusion_bench import separate_io
 from fusion_bench.method import ModelFusionAlgorithm
 from fusion_bench.method.adamerging.entropy_loss import entropy_loss
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
@@ -27,6 +29,7 @@
     get_task_wise_weights,
 )
 from fusion_bench.tasks.clip_classification.clip_mixin import CLIPClassificationMixin
+from fusion_bench.utils.dtype import parse_dtype
 from fusion_bench.utils.parameters import print_parameters
 from fusion_bench.utils.type import _StateDict
@@ -40,6 +43,8 @@ class ConcreteTaskArithmeticAlgorithmForCLIP(
 ):
     @torch.no_grad()
     def setup_models(self):
+        config = self.config
+        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
         modelpool = self.modelpool
 
         # Load the pretrained model
@@ -51,6 +56,8 @@
             ignore_untrained_params=True,
             parameter_type="logits",
         )
+        if self.merge_dtype is not None:
+            mask_model.to(self.merge_dtype)
         mask_model.fill_(self.config.initial_logits)
         # TODO: ablation study for the initialization of mask model
         # for param in mask_model.parameters():
@@ -76,16 +83,22 @@
             clamp_weights=self.config.clamp_weights,
             tie_weights=self.config.tie_weights,
             strict=self.config.strict,
+            task_vector_dtype=self.merge_dtype,
         )
+
         return module, mask_model
 
-    def train_mask(self, module, mask_model: MaskModel):
+    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
         config = self.config
+        # mask_model: MaskModel = self.fabric.to_device(mask_model)
 
         # configure optimizer
         lr_scheduler = None
         if self.config.optimizer == "adam":
-            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
+            optimizer = torch.optim.Adam(
+                filter(lambda p: p.requires_grad, mask_model.parameters()),
+                lr=self.config.lr,
+            )
             print(f"{optimizer=}")
             # TODO: ablation study for the learning rate scheduler. It should yield similar results.
             # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
@@ -102,6 +115,9 @@ def train_mask(self, module, mask_model: MaskModel):
         else:
             raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
 
+        module.to(mask_model.device)
+        module.requires_grad_(False)
+
         mask_model.train()
         optimizer.zero_grad()
         for step_idx in (
@@ -137,7 +153,7 @@
                 with self.profile("data loading"):
                     batch = next(self.get_shuffled_test_loader_iter(task))
                     # NOTE: The labels are not allowed to be used during test-time adaptation
-                    images = batch[0]
+                    images = batch[0].to(dtype=self.merge_dtype)
                 with self.profile("forward pass"):
                     logits = self.compute_logits(module, images, task)
                     loss = entropy_loss(logits)
@@ -183,12 +199,12 @@ def run(self, modelpool: HuggingFaceClipVisionPool):
         with self.profile("setup models"):
            module, mask_model = self.setup_models()
-        mask_model: MaskModel = self.fabric.to_device(mask_model)
-        module: TaskWiseMergedModel = self.fabric.to_device(module)
         self.setup_zero_shot_classification_head()
 
         if config.mask_checkpoint is None:
-            self.train_mask(module=module, mask_model=mask_model)
+            if not config.skip_training:
+                torch.cuda.empty_cache()
+                self.train_mask(module=module, mask_model=mask_model)
         else:
             if self.fabric.is_global_zero:
                 print("loading mask from checkpoint", config.mask_checkpoint)
@@ -205,4 +221,4 @@ def run(self, modelpool: HuggingFaceClipVisionPool):
         for name, m in mask.items():
             mask[name] = m / torch.mean(m)
         model = module.merge_and_unload(mask)
-        return model
+        return model.to(dtype=torch.float32)
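Note: the hunks above thread an optional `merge_dtype` setting through mask training. The mask model and task vectors are cast to it, input images are cast before the forward pass, and the merged model is cast back to float32 on return. A minimal sketch of the dtype parsing this relies on, assuming `merge_dtype` is a string key in the method's YAML config; `parse_dtype_sketch` is a hypothetical stand-in, the real mapping lives in `fusion_bench.utils.dtype.parse_dtype` and may accept more spellings:

```python
import torch

# Hypothetical stand-in for fusion_bench.utils.dtype.parse_dtype.
def parse_dtype_sketch(name):
    if name is None:
        return None  # unset merge_dtype -> keep native precision, no casting
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[name]

assert parse_dtype_sketch("bfloat16") is torch.bfloat16
assert parse_dtype_sketch(None) is None
```

Since `Tensor.to(dtype=None)` is a no-op in PyTorch, `images = batch[0].to(dtype=self.merge_dtype)` stays correct when no `merge_dtype` is configured.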
diff --git a/fusion_bench/models/separate_io.py b/fusion_bench/models/separate_io.py
index 12e31e99..3ca72cd4 100644
--- a/fusion_bench/models/separate_io.py
+++ b/fusion_bench/models/separate_io.py
@@ -11,6 +11,10 @@
 __all__ = ["separate_save", "separate_load"]
 
+def dir_is_empty(path: str) -> bool:
+    return not os.path.exists(path) or len(os.listdir(path)) == 0
+
+
 def separate_save(
     model: nn.Module,
     save_dir: str,
@@ -30,6 +34,9 @@
         model_file (str, optional): The name of the file to save the model's architecture. Default is "functional.bin".
         state_dict_file (str, optional): The name of the file to save the model's state dictionary. Default is "state_dict.bin".
     """
+    if os.path.exists(save_dir) and not dir_is_empty(save_dir):
+        raise FileExistsError(f"Directory exists and is not empty. {save_dir}")
+
     if not in_place:
         model = deepcopy(model)
     state_dict = {}
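Note: `separate_save` now refuses to write into an existing, non-empty directory instead of silently overwriting files. A self-contained check of the `dir_is_empty` semantics; the function body is copied verbatim from the hunk above, the temporary-directory harness is illustrative only:

```python
import os
import tempfile

# Copied from the hunk above: a missing path counts as empty.
def dir_is_empty(path: str) -> bool:
    return not os.path.exists(path) or len(os.listdir(path)) == 0

with tempfile.TemporaryDirectory() as tmp:
    assert dir_is_empty(os.path.join(tmp, "missing"))  # non-existent path
    assert dir_is_empty(tmp)                           # exists, but no entries
    open(os.path.join(tmp, "state_dict.bin"), "wb").close()
    assert not dir_is_empty(tmp)  # separate_save would now raise FileExistsError
```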
{save_dir}") + if not in_place: model = deepcopy(model) state_dict = {} diff --git a/fusion_bench/models/wrappers/task_wise_fusion.py b/fusion_bench/models/wrappers/task_wise_fusion.py index 0eac8517..e56de8b1 100644 --- a/fusion_bench/models/wrappers/task_wise_fusion.py +++ b/fusion_bench/models/wrappers/task_wise_fusion.py @@ -171,11 +171,13 @@ def __init__( clamp_weights: bool = True, tie_weights: bool = False, strict: bool = True, + task_vector_dtype: Optional[torch.dtype] = None, ): super().__init__() self.clamp_weights = clamp_weights self.tie_weights = tie_weights self.strict = strict + self.task_vector_dtype = task_vector_dtype self.merge_weight = nn.Parameter(task_wise_weight, requires_grad=True) @@ -192,6 +194,8 @@ def __init__( for m in finetuned_models: m.requires_grad_(False) self.task_vectors = nn.ModuleList(finetuned_models) + if self.task_vector_dtype is not None: + self.task_vectors = self.task_vectors.to(self.task_vector_dtype) @property def forward_model(self): diff --git a/fusion_bench/tasks/clip_classification/clip_mixin.py b/fusion_bench/tasks/clip_classification/clip_mixin.py index dd2f9ce3..09c78e9f 100644 --- a/fusion_bench/tasks/clip_classification/clip_mixin.py +++ b/fusion_bench/tasks/clip_classification/clip_mixin.py @@ -31,6 +31,7 @@ class CLIPClassificationMixin(LightningFabricMixin): This mixin provides methods to classify images using the CLIP model. """ + config: DictConfig # the modelpool is set by inheriting class modelpool: HuggingFaceClipVisionPool = None _clip_processor: CLIPProcessor = None @@ -84,6 +85,7 @@ def get_shuffled_test_loader_iter(self, task: str): loader = self.fabric.setup_dataloaders(loader) return iter(InfiniteDataLoader(loader)) + @torch.no_grad() def setup_zero_shot_classification_head( self, clip_processor: Optional[CLIPProcessor] = None, @@ -111,7 +113,7 @@ def setup_zero_shot_classification_head( self.visual_projection = self.to_device(self.visual_projection) self.logit_scale = self.to_device(self.logit_scale) - cache_dir = cache_file = os.path.join( + cache_dir = os.path.join( self.config.get("cache_dir", "outputs"), os.path.normpath(f"{os.path.basename(clip_model_config.path)}"), ) @@ -133,7 +135,9 @@ def setup_zero_shot_classification_head( ) if os.path.exists(cache_file): log.info(f"Loading cached zeroshot weights for task: {task}") - zeroshot_weights = torch.load(cache_file, map_location="cpu") + zeroshot_weights = torch.load( + cache_file, map_location="cpu" + ).detach() else: log.info( f"Construct zero shot classification head for task: {task}" @@ -142,7 +146,7 @@ def setup_zero_shot_classification_head( self.modelpool.get_train_dataset_config(task)["dataset"].name ) clip_classifier.set_classification_task(classnames, templates) - zeroshot_weights = clip_classifier.zeroshot_weights + zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone() log.info(f"save zeroshot weights to {cache_file}") torch.save(zeroshot_weights, cache_file)