# Author: Tianzhi Jia / 2022-03-19 / no comments yet / 200 visits
import importlib
import logging
import random

import torch
import torch.optim as optim
# Seed every RNG for reproducibility. A fixed seed may be supplied via
# --manual_seed; otherwise a fresh one is drawn and logged so the run can
# be reproduced later.
if not args.manual_seed:
    seed = random.randint(1, 10000)
else:
    seed = int(args.manual_seed)
logging.info('Random Seed: %d' % seed)
random.seed(seed)
torch.manual_seed(seed)
# Fix: the model is moved to the GPU below, but the CUDA RNG was never
# seeded, so GPU-side randomness (e.g. dropout) was not reproducible.
torch.cuda.manual_seed_all(seed)
# Dynamically import models/<model_name>.py and build the network, wrapped
# in DataParallel for multi-GPU training.
model_module = importlib.import_module('.%s' % args.model_name, 'models')
net = torch.nn.DataParallel(model_module.Model(args))
net.cuda()
# Apply the model module's custom weight initializer if it defines one.
if hasattr(model_module, 'weights_init'):
    net.module.apply(model_module.weights_init)
# Fix: removed a bare `net.module.encoder.conv1.weight.shape` expression
# statement that was here — a no-op leftover from interactive debugging.
# Generator learning rate; the 'cascade' model trains adversarially and
# additionally needs a discriminator running at half the generator's rate.
lr = args.lr
cascade_gan = args.model_name == 'cascade'
net_d = None
if cascade_gan:
    discriminator = model_module.Discriminator(args)
    net_d = torch.nn.DataParallel(discriminator)
    net_d.cuda()
    net_d.module.apply(model_module.weights_init)
    lr_d = lr / 2
# Parse the optional learning-rate decay schedule. Interval-based decay and
# step-based decay cannot both be active at once.
if args.lr_decay:
    if args.lr_decay_interval and args.lr_step_decay_epochs:
        raise ValueError('lr_decay_interval and lr_step_decay_epochs are mutually exclusive!')
    if args.lr_step_decay_epochs:
        # Comma-separated epoch/rate lists, e.g. "30,60" with "0.5,0.1".
        decay_epoch_list = [int(tok.strip()) for tok in args.lr_step_decay_epochs.split(',')]
        decay_rate_list = [float(tok.strip()) for tok in args.lr_step_decay_rates.split(',')]
# Resolve the optimizer class by name from torch.optim (e.g. 'Adam',
# 'Adagrad'); a separate name avoids shadowing the class with the instance.
optimizer_cls = getattr(optim, args.optimizer)
if args.optimizer == 'Adagrad':
    optimizer = optimizer_cls(net.module.parameters(), lr=lr,
                              initial_accumulator_value=args.initial_accum_val)
else:
    # Non-Adagrad path expects a comma-separated "beta1,beta2" string.
    parts = args.betas.split(',')
    betas = (float(parts[0].strip()), float(parts[1].strip()))
    optimizer = optimizer_cls(net.module.parameters(), lr=lr,
                              weight_decay=args.weight_decay, betas=betas)
if cascade_gan:
    # Discriminator always uses Adam with fixed hyperparameters.
    optimizer_d = optim.Adam(net_d.parameters(), lr=lr_d,
                             weight_decay=0.00001, betas=(0.5, 0.999))
# Optional schedule for a loss-weighting constant that changes at the given
# epochs: N boundary epochs require N+1 constant values.
alpha = None
if args.varying_constant:
    varying_constant_epochs = [int(ep.strip()) for ep in args.varying_constant_epochs.split(',')]
    varying_constant = [float(c.strip()) for c in args.varying_constant.split(',')]
    # Fix: `assert` is stripped under `python -O`, silently skipping this
    # input validation; raise explicitly so a mismatch is always caught.
    if len(varying_constant) != len(varying_constant_epochs) + 1:
        raise ValueError('varying_constant must contain exactly one more value '
                         'than varying_constant_epochs')
# Optionally resume from a checkpoint produced by an earlier run: restore
# the generator weights and, for the cascade GAN, the discriminator's too.
if args.load_model:
    ckpt = torch.load(args.load_model)
    net.module.load_state_dict(ckpt['net_state_dict'])
    if cascade_gan:
        net_d.module.load_state_dict(ckpt['D_state_dict'])
    # Lazy %-style args let logging skip formatting when INFO is disabled.
    logging.info("%s's previous weights loaded.", args.model_name)
# Unique insights