add merge_dtype option to concrete task arithmetic
tanganke committed Aug 27, 2024
1 parent bb5bcbf commit eaa4f38
Showing 4 changed files with 41 additions and 10 deletions.
@@ -12,10 +12,12 @@

import logging
import os
from copy import deepcopy

import torch
from tqdm.autonotebook import tqdm

from fusion_bench import separate_io
from fusion_bench.method import ModelFusionAlgorithm
from fusion_bench.method.adamerging.entropy_loss import entropy_loss
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
@@ -27,6 +29,7 @@
get_task_wise_weights,
)
from fusion_bench.tasks.clip_classification.clip_mixin import CLIPClassificationMixin
from fusion_bench.utils.dtype import parse_dtype
from fusion_bench.utils.parameters import print_parameters
from fusion_bench.utils.type import _StateDict

@@ -40,6 +43,8 @@ class ConcreteTaskArithmeticAlgorithmForCLIP(
):
@torch.no_grad()
def setup_models(self):
config = self.config
self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
modelpool = self.modelpool

# Load the pretrained model
@@ -51,6 +56,8 @@ def setup_models(self):
ignore_untrained_params=True,
parameter_type="logits",
)
if self.merge_dtype is not None:
mask_model.to(self.merge_dtype)
mask_model.fill_(self.config.initial_logits)
# TODO: ablation study for the initialization of mask model
# for param in mask_model.parameters():
@@ -76,16 +83,22 @@ def setup_models(self):
clamp_weights=self.config.clamp_weights,
tie_weights=self.config.tie_weights,
strict=self.config.strict,
task_vector_dtype=self.merge_dtype,
)

return module, mask_model

def train_mask(self, module, mask_model: MaskModel):
def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
config = self.config
# mask_model: MaskModel = self.fabric.to_device(mask_model)

# configure optimizer
lr_scheduler = None
if self.config.optimizer == "adam":
optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, mask_model.parameters()),
lr=self.config.lr,
)
print(f"{optimizer=}")
# TODO: ablation study for the learning rate scheduler. It should yield similar results.
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
@@ -102,6 +115,9 @@ def train_mask(self, module, mask_model: MaskModel):
else:
raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

module.to(mask_model.device)
module.requires_grad_(False)

mask_model.train()
optimizer.zero_grad()
for step_idx in (
@@ -137,7 +153,7 @@
with self.profile("data loading"):
batch = next(self.get_shuffled_test_loader_iter(task))
# NOTE: The labels are not allowed to be used during test-time adaptation
images = batch[0]
images = batch[0].to(dtype=self.merge_dtype)
with self.profile("forward pass"):
logits = self.compute_logits(module, images, task)
loss = entropy_loss(logits)
@@ -183,12 +199,12 @@ def run(self, modelpool: HuggingFaceClipVisionPool):

with self.profile("setup models"):
module, mask_model = self.setup_models()
mask_model: MaskModel = self.fabric.to_device(mask_model)
module: TaskWiseMergedModel = self.fabric.to_device(module)
self.setup_zero_shot_classification_head()

if config.mask_checkpoint is None:
self.train_mask(module=module, mask_model=mask_model)
if not config.skip_training:
torch.cuda.empty_cache()
self.train_mask(module=module, mask_model=mask_model)
else:
if self.fabric.is_global_zero:
print("loading mask from checkpoint", config.mask_checkpoint)
@@ -205,4 +221,4 @@ def run(self, modelpool: HuggingFaceClipVisionPool):
for name, m in mask.items():
mask[name] = m / torch.mean(m)
model = module.merge_and_unload(mask)
return model
return model.to(dtype=torch.float32)
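For orientation, the new merge_dtype option threads a single dtype through the pipeline: parse_dtype reads it from the config, the mask model and task vectors are cast to it, test batches are converted to match before the forward pass, and the merged model is cast back to float32 before run() returns. A minimal sketch of the dtype parsing, assuming parse_dtype maps config strings to torch dtypes (parse_dtype_sketch below is a hypothetical stand-in, not the real fusion_bench.utils.dtype.parse_dtype):

    import torch

    def parse_dtype_sketch(dtype):
        # Hypothetical stand-in; the real helper may accept more spellings.
        if dtype is None or isinstance(dtype, torch.dtype):
            return dtype
        return {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[dtype]

    merge_dtype = parse_dtype_sketch("bfloat16")  # e.g. config.merge_dtype
    assert merge_dtype is torch.bfloat16

Running test-time adaptation in bfloat16 roughly halves the memory held by the frozen task vectors, which is the practical payoff of the option.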
7 changes: 7 additions & 0 deletions fusion_bench/models/separate_io.py
@@ -11,6 +11,10 @@
__all__ = ["separate_save", "separate_load"]


def dir_is_empty(path: str) -> bool:
return not os.path.exists(path) or len(os.listdir(path)) == 0


def separate_save(
model: nn.Module,
save_dir: str,
@@ -30,6 +34,9 @@ def separate_save(
model_file (str, optional): The name of the file to save the model's architecture. Default is "functional.bin".
state_dict_file (str, optional): The name of the file to save the model's state dictionary. Default is "state_dict.bin".
"""
if not dir_is_empty(save_dir):
raise FileExistsError(f"Directory exists and is not empty: {save_dir}")

if not in_place:
model = deepcopy(model)
state_dict = {}
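The new dir_is_empty guard makes separate_save fail fast instead of silently overwriting an existing checkpoint directory. A hedged usage sketch (the path is illustrative, and it assumes separate_save creates save_dir and writes into it on the first call):

    import torch.nn as nn

    from fusion_bench.models.separate_io import separate_save

    model = nn.Linear(4, 4)
    save_dir = "outputs/demo_checkpoint"  # illustrative path
    separate_save(model, save_dir)  # directory absent or empty: succeeds
    try:
        separate_save(model, save_dir)  # directory now non-empty: raises
    except FileExistsError as err:
        print(err)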
4 changes: 4 additions & 0 deletions fusion_bench/models/wrappers/task_wise_fusion.py
@@ -171,11 +171,13 @@ def __init__(
clamp_weights: bool = True,
tie_weights: bool = False,
strict: bool = True,
task_vector_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.clamp_weights = clamp_weights
self.tie_weights = tie_weights
self.strict = strict
self.task_vector_dtype = task_vector_dtype

self.merge_weight = nn.Parameter(task_wise_weight, requires_grad=True)

@@ -192,6 +194,8 @@ def __init__(
for m in finetuned_models:
m.requires_grad_(False)
self.task_vectors = nn.ModuleList(finetuned_models)
if self.task_vector_dtype is not None:
self.task_vectors = self.task_vectors.to(self.task_vector_dtype)

@property
def forward_model(self):
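For context, TaskWiseMergedModel implements task arithmetic: the merged parameters are the pretrained parameters plus a weighted sum of task vectors, with one learnable weight per task. A minimal sketch of that merge rule with the new task_vector_dtype cast applied (plain dicts stand in for the wrapper's real internals):

    import torch

    pretrained = {"weight": torch.randn(4)}
    task_vectors = [{"weight": torch.randn(4)} for _ in range(2)]
    merge_weight = torch.tensor([0.3, 0.7])  # learnable task-wise weights
    task_vector_dtype = torch.bfloat16

    # The new option: keep the frozen task vectors in reduced precision.
    task_vectors = [
        {k: v.to(task_vector_dtype) for k, v in tv.items()} for tv in task_vectors
    ]

    # Task arithmetic: merged = pretrained + sum_i w_i * tau_i
    merged = {
        k: pretrained[k]
        + sum(
            w * tv[k].to(pretrained[k].dtype)
            for w, tv in zip(merge_weight, task_vectors)
        )
        for k in pretrained
    }

Since the task vectors are frozen (requires_grad_(False) above) and only read during merging, the cast trades precision for memory without affecting which parameters are trained.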
10 changes: 7 additions & 3 deletions fusion_bench/tasks/clip_classification/clip_mixin.py
@@ -31,6 +31,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
This mixin provides methods to classify images using the CLIP model.
"""

config: DictConfig
# the modelpool is set by inheriting class
modelpool: HuggingFaceClipVisionPool = None
_clip_processor: CLIPProcessor = None
@@ -84,6 +85,7 @@ def get_shuffled_test_loader_iter(self, task: str):
loader = self.fabric.setup_dataloaders(loader)
return iter(InfiniteDataLoader(loader))

@torch.no_grad()
def setup_zero_shot_classification_head(
self,
clip_processor: Optional[CLIPProcessor] = None,
@@ -111,7 +113,7 @@ def setup_zero_shot_classification_head(
self.visual_projection = self.to_device(self.visual_projection)
self.logit_scale = self.to_device(self.logit_scale)

cache_dir = cache_file = os.path.join(
cache_dir = os.path.join(
self.config.get("cache_dir", "outputs"),
os.path.normpath(f"{os.path.basename(clip_model_config.path)}"),
)
@@ -133,7 +135,9 @@
)
if os.path.exists(cache_file):
log.info(f"Loading cached zeroshot weights for task: {task}")
zeroshot_weights = torch.load(cache_file, map_location="cpu")
zeroshot_weights = torch.load(
cache_file, map_location="cpu"
).detach()
else:
log.info(
f"Construct zero shot classification head for task: {task}"
@@ -142,7 +146,7 @@
self.modelpool.get_train_dataset_config(task)["dataset"].name
)
clip_classifier.set_classification_task(classnames, templates)
zeroshot_weights = clip_classifier.zeroshot_weights
zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
log.info(f"save zeroshot weights to {cache_file}")
torch.save(zeroshot_weights, cache_file)

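The added .detach() calls ensure the cached zero-shot classification weights are stored and loaded as plain tensors with no autograd graph attached. The surrounding cache-or-compute pattern, as a hedged sketch (load_or_build_zeroshot_weights is a hypothetical helper name, not part of the mixin):

    import os

    import torch

    def load_or_build_zeroshot_weights(cache_file, build_fn):
        # Mirrors the logic in setup_zero_shot_classification_head.
        if os.path.exists(cache_file):
            # detach() guards against caches saved with a graph attached.
            return torch.load(cache_file, map_location="cpu").detach()
        weights = build_fn().detach().clone()  # drop autograd history first
        os.makedirs(os.path.dirname(cache_file) or ".", exist_ok=True)
        torch.save(weights, cache_file)
        return weights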
