
Commit

add option merge_dtype to clip_concrete_adamerging method
tanganke committed Sep 24, 2024
1 parent 23f7712 commit 7d4f0ee
Showing 3 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions config/method/clip_concrete_layer_wise_adamerging.yaml
@@ -5,6 +5,7 @@ name: clip_concrete_layer_wise_adamerging
 batch_size: 16
 num_workers: 8
 
+merge_dtype: null
 optimizer: adam
 lr: 1e-3
 base_lr: 1
1 change: 1 addition & 0 deletions config/method/clip_concrete_task_wise_adamerging.yaml
@@ -5,6 +5,7 @@ name: clip_concrete_task_wise_adamerging
 batch_size: 16
 num_workers: 8
 
+merge_dtype: null
 optimizer: adam
 lr: 1e-3
 base_lr: 1
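In both configs the new key defaults to `null`, i.e. no dtype cast: merging stays in the models' original precision. A minimal sketch of opting into low-precision merging from Python (the OmegaConf loading and the `bfloat16` value are assumptions for illustration, not part of this commit):

```python
from omegaconf import OmegaConf

# Load one of the method configs touched by this commit and override the
# new key; "bfloat16" is an assumed example value.
cfg = OmegaConf.load("config/method/clip_concrete_task_wise_adamerging.yaml")
cfg.merge_dtype = "bfloat16"  # default: null (keep original precision)
```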
11 changes: 11 additions & 0 deletions fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py
@@ -42,6 +42,7 @@
     get_task_wise_weights,
 )
 from fusion_bench.tasks.clip_classification.clip_mixin import CLIPClassificationMixin
+from fusion_bench.utils.dtype import parse_dtype
 from fusion_bench.utils.parameters import print_parameters
 from fusion_bench.utils.type import StateDictType
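The new import pulls in `parse_dtype`, whose implementation is not shown in this diff; judging by its use below, it maps a config value (string or `None`) to a `torch.dtype`. A hypothetical stand-in:

```python
import torch

# Hypothetical sketch of fusion_bench.utils.dtype.parse_dtype; the real
# function may accept more spellings. None passes through unchanged.
def parse_dtype(dtype):
    if dtype is None or isinstance(dtype, torch.dtype):
        return dtype
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[str(dtype)]
```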

@@ -55,6 +56,8 @@ class ConcreteTaskWiseAdaMergingForCLIP(
 ):
     @torch.no_grad()
     def setup_models(self):
+        config = self.config
+        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
         modelpool = self.modelpool
 
         # Load the pretrained model
@@ -66,6 +69,8 @@ def setup_models(self):
             ignore_untrained_params=True,
             parameter_type="logits",
         )
+        if self.merge_dtype is not None:
+            mask_model.to(self.merge_dtype)
         mask_model.fill_(self.config.initial_logits)
         # TODO: ablation study for the initialization of mask model
         # for param in mask_model.parameters():
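The cast runs before the logits are filled, so the mask model's parameters end up in the target precision from the start. A minimal illustration of the same pattern, using a plain `nn.Linear` as a stand-in for the mask model:

```python
import torch
import torch.nn as nn

mask = nn.Linear(768, 768)    # stand-in for mask_model, not the real class
merge_dtype = torch.bfloat16  # e.g. what parse_dtype("bfloat16") might return
if merge_dtype is not None:
    mask.to(merge_dtype)      # nn.Module.to casts parameters in place
mask.weight.data.fill_(0.9)   # analogous to mask_model.fill_(initial_logits)
assert mask.weight.dtype == torch.bfloat16
```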
@@ -92,6 +97,7 @@ def setup_models(self):
             clamp_weights=self.config.clamp_weights,
             tie_weights=self.config.tie_weights,
             strict=self.config.strict,
+            task_vector_dtype=self.merge_dtype,
         )
         return module, mask_model
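Passing `task_vector_dtype` propagates the same dtype to the wrapped task vectors, so the per-task parameter deltas are stored in reduced precision as well. A rough sketch of the idea (the subtraction-based task vector below is standard task arithmetic, not code from this repository):

```python
import torch

def task_vector(finetuned_sd, pretrained_sd, dtype=None):
    """Delta between two state dicts, optionally stored in a smaller dtype."""
    tv = {k: finetuned_sd[k] - pretrained_sd[k] for k in pretrained_sd}
    if dtype is not None:
        tv = {k: v.to(dtype) for k, v in tv.items()}
    return tv
```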

@@ -321,6 +327,8 @@ class ConcreteLayerWiseAdaMergingForCLIP(
 ):
     @torch.no_grad()
     def setup_models(self):
+        config = self.config
+        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
         modelpool = self.modelpool
 
         # Load the pretrained model
@@ -332,6 +340,8 @@ def setup_models(self):
             ignore_untrained_params=True,
             parameter_type="logits",
         )
+        if self.merge_dtype is not None:
+            mask_model.to(self.merge_dtype)
         mask_model.fill_(self.config.initial_logits)
         # TODO: ablation study for the initialization of mask model
         # for param in mask_model.parameters():
@@ -361,6 +371,7 @@ def setup_models(self):
             clamp_weights=self.config.clamp_weights,
             tie_weights=self.config.tie_weights,
             strict=self.config.strict,
+            layer_vector_dtype=self.merge_dtype,
         )
         return module, mask_model
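The layer-wise variant mirrors the task-wise changes, with the dtype passed as `layer_vector_dtype`. The practical payoff is memory; a back-of-envelope estimate, assuming dense task vectors and a roughly 100M-parameter encoder (an assumed size, for illustration only):

```python
import torch

n_params = 100_000_000  # assumed model size, illustration only
for dtype in (torch.float32, torch.bfloat16):
    gb = n_params * torch.finfo(dtype).bits / 8 / 1e9
    print(f"{dtype}: ~{gb:.2f} GB per task vector")
# torch.float32: ~0.40 GB, torch.bfloat16: ~0.20 GB
```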

