Skip to content

Commit

Permalink
format code using black
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed Aug 25, 2024
1 parent 02ad4ea commit 50b4dbd
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 8 deletions.
14 changes: 13 additions & 1 deletion fusion_bench/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
from . import dataset, method, modelpool, models, taskpool, tasks, utils, constants
from . import (
constants,
dataset,
method,
metrics,
mixins,
modelpool,
models,
optim,
taskpool,
tasks,
utils,
)
2 changes: 1 addition & 1 deletion fusion_bench/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .stats import *
from .paths import *
from .stats import *
2 changes: 1 addition & 1 deletion fusion_bench/method/pwe_moe/phn/solvers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from abc import abstractmethod
from typing import Tuple

import cvxopt
import cvxpy as cp
import numpy as np
import torch
from torch import Tensor
from typing import Tuple

"""Implementation of Pareto HyperNetworks with:
1. Linear scalarization
Expand Down
2 changes: 1 addition & 1 deletion fusion_bench/method/weighted_average/weighted_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from copy import deepcopy
from typing import List, Mapping, Union

import numpy as np
import torch
from torch import Tensor, nn
import numpy as np

from fusion_bench.method.base_algorithm import ModelFusionAlgorithm
from fusion_bench.modelpool import ModelPool, to_modelpool
Expand Down
5 changes: 3 additions & 2 deletions fusion_bench/models/separate_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from copy import deepcopy

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from torch import nn

from fusion_bench.utils.dtype import parse_dtype
from safetensors import safe_open
from safetensors.torch import save_file

__all__ = ["separate_save", "separate_load"]


def separate_save(
model: nn.Module,
save_dir: str,
Expand Down
2 changes: 1 addition & 1 deletion fusion_bench/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import logging
import os
from typing import Iterable, Union, Dict
from typing import Dict, Iterable, Union

import hydra
from omegaconf import DictConfig, OmegaConf
Expand Down
2 changes: 1 addition & 1 deletion fusion_bench/scripts/nyuv2_mtl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run(self):
LearningRateMonitor(logging_interval="step"),
RichModelSummary(max_depth=1),
ModelCheckpoint(save_last=True),
]
],
)

train_loader = DataLoader(
Expand Down

0 comments on commit 50b4dbd

Please sign in to comment.