1 参数定义
需要定义以下参数以从断点继续训练:
# 训练的总轮次
parser.add_argument(
"--epochs",type=int,metavar="N",help="number of total epochs to run"
)
# 适用于训练中断,重新从某个检查点中断的位置开始训练,训练轮次为epochs - start-epoch
parser.add_argument(
"--start-epoch",
default=0,
type=int,
metavar="N",
help="manual epoch number (useful on restarts)"
)
# 加载检查点模型文件
parser.add_argument(
"--resume",
default="",
type=str,
metavar="PATH",
help="path to latest checkpoint (default: none)"
)
2 模型和优化器构建
完成模型和优化器构建(以MoCo代码为例):
# 模型构建
model = moco.builder.MoCo(
models.__dict__[args.arch],
args.moco_dim,
args.moco_k,
args.moco_m,
args.moco_t,
args.mlp,
)
# 优化器构建
optimizer = torch.optim.SGD(
model.parameters(),
args.lr,
momentum = args.momentum,
weight_decay = args.weight_decay
)
3 断点重启
在main函数中或者其他设置的函数中添加以下代码:
# 从检测点中选择一个重启的点
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
if args.gpu is None:
# 默认加载方式(通常会自动加载到当前设备)
checkpoint = torch.load(args.resume)
else:
# 将待加载的模型映射到指定的单个GPU。
loc = "cuda:{}".format(args.gpu)
checkpoint = torch.load(args.resume, map_location = loc)
args.start_epoch = checkpoint['epoch']
# state_dict 是PyTorch中保存模型所有可想学习参数(权重和偏置)的字典
model.load_state_dict(checkpoint['state_dict'])
# 恢复优化器的内部状态
optimizer.load_state_dict(checkpoint['optimizer'])
print(
"=> loaded checkpoint '{}' (epoch {})".format(
args.resume, checkpoint['epoch']
)
)
else:
print("=> no checkpoint found at '{}'".format(args.resume))
4 训练周期定义
训练周期定义从args.start_epoch到args.epochs:
for epoch in range(args.start_epoch, args.epochs):