From f926bfd0bfbdf2a8420b9cfbcb6c9ea59854eaaf Mon Sep 17 00:00:00 2001 From: zjkjzj Date: Sat, 10 Oct 2020 23:41:51 +0800 Subject: [PATCH] =?UTF-8?q?fix(model):=20=E5=8C=85=E8=A3=85DDP=E4=B9=8B?= =?UTF-8?q?=E5=89=8D=E5=85=88=E5=B0=86=E6=A8=A1=E5=9E=8B=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E5=A4=8D=E5=88=B6=E5=88=B0=E6=8C=87=E5=AE=9AGPU=20device?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/train.py | 2 +- tsn/model/build.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/train.py b/tools/train.py index 4e935ac..00301e4 100644 --- a/tools/train.py +++ b/tools/train.py @@ -32,7 +32,7 @@ def train(gpu, args, cfg): device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu') map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} - model = build_model(cfg, gpu, map_location=map_location).to(device) + model = build_model(cfg, gpu, map_location=map_location) criterion = build_criterion(cfg) optimizer = build_optimizer(cfg, model) lr_scheduler = build_lr_scheduler(cfg, optimizer) diff --git a/tsn/model/build.py b/tsn/model/build.py index 01cd41e..ff120c5 100644 --- a/tsn/model/build.py +++ b/tsn/model/build.py @@ -22,7 +22,7 @@ def build_model(cfg, gpu, map_location=None, logger=None): - model = registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.NAME](cfg, map_location=map_location) + model = registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.NAME](cfg, map_location=map_location).cuda(gpu) world_size = du.get_world_size() rank = du.get_rank()