【VRCNet】随机种子、动态导入(模型)、学习率、lr_decay、优化方法、varying_constant、加载模型

作者：Tianzhi Jia / 2022-03-19 / 暂无评论 / 200 个足迹

import importlib
import inspect
import logging
import random

import torch
import torch.optim as optim
# Use the configured seed when one is given (any truthy value, converted to
# int); otherwise draw one in [1, 10000]. The chosen value is logged so a
# randomly-seeded run can still be reproduced.
seed = int(args.manual_seed) if args.manual_seed else random.randint(1, 10000)
logging.info('Random Seed: %d', seed)

# Seed Python's RNG first, then PyTorch's, with the same value.
random.seed(seed)
torch.manual_seed(seed)
# Dynamically import the model module ``models.<args.model_name>`` and wrap its
# ``Model`` in DataParallel on the GPU.
model_module = importlib.import_module('.%s' % args.model_name, 'models')
net = torch.nn.DataParallel(model_module.Model(args))
net.cuda()
# ``weights_init`` is optional per model module; when present it is applied to
# the unwrapped module (``net.module``), not the DataParallel wrapper.
if hasattr(model_module, 'weights_init'):
    net.module.apply(model_module.weights_init)
# (Removed a stray debug expression, ``net.module.encoder.conv1.weight.shape``:
# it had no effect and raised AttributeError for models without that layer.)
# Only the 'cascade' model is trained with a discriminator; net_d stays None
# for every other model.
cascade_gan = (args.model_name == 'cascade')
net_d = None
if cascade_gan:
    net_d = torch.nn.DataParallel(model_module.Discriminator(args))
    net_d.cuda()
    # Treat weights_init as optional, exactly as for the generator network,
    # instead of assuming every model module defines it.
    if hasattr(model_module, 'weights_init'):
        net_d.module.apply(model_module.weights_init)
# Base learning rate; the discriminator (cascade model only) uses half of it.
lr = args.lr
if cascade_gan:
    lr_d = lr / 2
if args.lr_decay:
    # Interval decay and step decay are alternative schedules: at most one may
    # be configured.
    if args.lr_decay_interval and args.lr_step_decay_epochs:
        raise ValueError('lr_decay_interval and lr_step_decay_epochs are mutually exclusive!')
    if args.lr_step_decay_epochs:
        # Both options are comma-separated strings, e.g. "10, 20" and "0.5, 0.5".
        # The rates are presumably paired by index with the epochs (the consumer
        # is outside this chunk), so fail fast here rather than mid-training.
        if not args.lr_step_decay_rates:
            raise ValueError('lr_step_decay_epochs requires lr_step_decay_rates!')
        decay_epoch_list = [int(ep.strip()) for ep in args.lr_step_decay_epochs.split(',')]
        decay_rate_list = [float(rt.strip()) for rt in args.lr_step_decay_rates.split(',')]
        if len(decay_epoch_list) != len(decay_rate_list):
            raise ValueError('lr_step_decay_epochs and lr_step_decay_rates must have the same number of entries!')
# Build the main optimizer from its torch.optim class name (args.optimizer).
optimizer_cls = getattr(optim, args.optimizer)
if args.optimizer == 'Adagrad':
    # Adagrad is configured with an initial accumulator value instead of betas.
    optimizer = optimizer_cls(net.module.parameters(), lr=lr,
                              initial_accumulator_value=args.initial_accum_val)
else:
    # Forward only the hyper-parameters this optimizer's constructor accepts:
    # e.g. SGD and RMSprop take no ``betas`` and would raise TypeError if it
    # were passed unconditionally. Adam-family optimizers get both as before.
    accepted = inspect.signature(optimizer_cls).parameters
    opt_kwargs = {'lr': lr}
    if 'weight_decay' in accepted:
        opt_kwargs['weight_decay'] = args.weight_decay
    if 'betas' in accepted:
        # args.betas is a comma-separated pair such as "0.9, 0.999".
        beta_parts = args.betas.split(',')
        betas = (float(beta_parts[0].strip()), float(beta_parts[1].strip()))
        opt_kwargs['betas'] = betas
    optimizer = optimizer_cls(net.module.parameters(), **opt_kwargs)

if cascade_gan:
    # The discriminator always uses Adam with fixed hyper-parameters,
    # independent of args.optimizer.
    optimizer_d = optim.Adam(net_d.parameters(), lr=lr_d, weight_decay=0.00001, betas=(0.5, 0.999))
# alpha starts unset here (NOTE(review): presumably assigned from
# varying_constant later in training — confirm against the training loop).
alpha = None
if args.varying_constant:
    # Both options are comma-separated strings; N epoch boundaries must come
    # with exactly N + 1 constants.
    varying_constant_epochs = [int(ep.strip()) for ep in args.varying_constant_epochs.split(',')]
    varying_constant = [float(c.strip()) for c in args.varying_constant.split(',')]
    # Raise instead of assert so the check is not stripped under ``python -O``.
    if len(varying_constant) != len(varying_constant_epochs) + 1:
        raise ValueError(
            'varying_constant must have exactly one more value than varying_constant_epochs '
            '(got %d constants for %d epochs)!' % (len(varying_constant), len(varying_constant_epochs)))
# Optionally resume network weights (and the discriminator's, for cascade)
# from a checkpoint file.
if args.load_model:
    # Map tensors to CPU: load_state_dict copies them into the parameters that
    # were already moved to the GPU with .cuda(). Loading no longer depends on
    # the device the checkpoint was saved from.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(args.load_model, map_location='cpu')
    net.module.load_state_dict(ckpt['net_state_dict'])
    if cascade_gan:
        net_d.module.load_state_dict(ckpt['D_state_dict'])
    logging.info("%s's previous weights loaded." % args.model_name)

独特见解