This post records one of the core multi-GPU operations in PyTorch: data-parallel training (the snippets below use nn.DataParallel).
During training:
```python
import os
import torch
import torch.nn as nn

# Restrict the visible devices to GPUs 3 and 5 before any CUDA call
os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"

multi_gpu = True
model = Model(args)
if multi_gpu:
    print("training on multi_gpu: ")
    torch.cuda.empty_cache()
    # Replicate the model across the visible GPUs and split each batch among them
    model = nn.DataParallel(model)
model.train(True)
model.cuda()
```
During testing:
```python
model = Model(args)
if multi_gpu:
    print("testing on multi_gpu...")
    # Wrap with DataParallel first so the "module." prefixes in the checkpoint keys match
    model = nn.DataParallel(model)
model.load_state_dict(torch.load(path))
model.train(False)  # equivalent to model.eval()
model.cuda()
```
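If the same checkpoint has to be used without DataParallel (for example, single-GPU inference), the "module." prefix must be stripped from the keys before loading. A minimal sketch, assuming the checkpoint was saved from the wrapped model as above:

```python
import torch

state_dict = torch.load(path, map_location="cuda:0")
# Strip the "module." prefix added by nn.DataParallel so the keys
# match a plain (unwrapped) Model instance.
clean_state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

model = Model(args)
model.load_state_dict(clean_state_dict)
model.eval()
model.cuda()
```

Either way, keeping the save and load sides consistent (both wrapped, or both unwrapped) avoids the usual "Missing key(s)" / "Unexpected key(s)" errors from load_state_dict.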