diff --git a/tools/train.py b/tools/train.py index 4e935ac..00301e4 100644 --- a/tools/train.py +++ b/tools/train.py @@ -32,7 +32,7 @@ def train(gpu, args, cfg): device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu') map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} - model = build_model(cfg, gpu, map_location=map_location).to(device) + model = build_model(cfg, gpu, map_location=map_location) criterion = build_criterion(cfg) optimizer = build_optimizer(cfg, model) lr_scheduler = build_lr_scheduler(cfg, optimizer) diff --git a/tsn/model/build.py b/tsn/model/build.py index 01cd41e..ff120c5 100644 --- a/tsn/model/build.py +++ b/tsn/model/build.py @@ -22,7 +22,7 @@ def build_model(cfg, gpu, map_location=None, logger=None): - model = registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.NAME](cfg, map_location=map_location) + model = registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.NAME](cfg, map_location=map_location).cuda(gpu) world_size = du.get_world_size() rank = du.get_rank()