【VRCNet】参数解析、日志、AverageValueMeter

作者：Tianzhi Jia / 2022-03-18 / 暂无评论 / 156 个足迹

import argparse
import munch
import yaml
import datetime
import os
import logging
import sys
import math
from dataset import ShapeNetH5
import torch
# Command-line parsing.
# NOTE(review): the argument list is hard-coded, so real command-line
# arguments are ignored and '-c cfgs/pcn.yaml' is always used (handy in a
# notebook, surprising from a shell). Call `parser.parse_args()` with no
# argument to honour sys.argv instead.
parser = argparse.ArgumentParser(description='Train config file')
parser.add_argument('-c', '--config', help='path to config file', required=True)
arg = parser.parse_args(['-c', 'cfgs/pcn.yaml'])

config_path = arg.config
# Load the YAML config into a Munch so keys can be read as attributes
# (args.load_model, args.batch_size, ...). The `with` block closes the file
# handle, which a bare open() would leak. YAML is UTF-8 by spec, so decode it
# explicitly rather than relying on the platform locale.
with open(config_path, encoding='utf-8') as _config_file:
    args = munch.munchify(yaml.safe_load(_config_file))
# Run timestamp 'YYYY-MM-DDTHH:MM:SS' (isoformat truncated to whole seconds),
# used below to name a fresh experiment directory.
# NOTE(review): contains ':' which is not allowed in Windows directory names.
time = datetime.datetime.now().isoformat()[:19]

# Choose the experiment directory.
# - When resuming (load_model set), reuse the directory that holds the
#   checkpoint.
# - Otherwise create <work_dir>/<model_name>_<loss>_<flag>_<timestamp>.
# .get() treats a config without a 'load_model' key as "start fresh" instead
# of raising AttributeError on the Munch.
if args.get('load_model'):
    exp_name = os.path.basename(os.path.dirname(args.load_model))
    log_dir = os.path.dirname(args.load_model)
else:
    exp_name = args.model_name + '_' + args.loss + '_' + args.flag + '_' + time
    log_dir = os.path.join(args.work_dir, exp_name)
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test (e.g. another run creating the same directory).
    os.makedirs(log_dir, exist_ok=True)

# Send INFO-level logs both to <log_dir>/train.log and to stdout.
logging.basicConfig(
    level=logging.INFO,
    handlers=[
        logging.FileHandler(os.path.join(log_dir, 'train.log'), encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
)

# Record the full configuration at the top of the log.
logging.info(str(args))
class AverageValueMeter(object):
    """Track the most recent value and the running weighted mean of a scalar.

    In this file it backs the training-loss meter and one meter per
    validation metric. Call ``update(value, n)`` to add a value with weight
    ``n``, then read ``avg``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0      # most recent value passed to update()
        self.avg = 0      # running weighted mean, sum / count
        self.sum = 0      # weighted sum of all recorded values
        self.count = 0.0  # total weight recorded so far

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Args:
            val: value to record; must support ``* n`` and ``/`` with numbers.
            n: weight of ``val`` (e.g. the number of samples it summarizes).

        If the total weight is still zero (e.g. ``update(x, n=0)`` on a fresh
        meter), ``avg`` is left unchanged instead of raising
        ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count
# Metric names shared by the best-epoch table and the validation meters.
metrics = ['cd_p', 'cd_t', 'emd', 'f1']
# Best-so-far entry per metric, presumably (epoch, value) -- confirm against
# the training loop. 'f1' starts at 0 (presumably higher is better); every
# other metric starts at +inf (presumably lower is better).
best_epoch_losses = {
    name: ((0, math.inf) if name != 'f1' else (0, 0))
    for name in metrics
}
# One running-average meter for the training loss, then one per metric.
train_loss_meter = AverageValueMeter()
val_loss_meters = {name: AverageValueMeter() for name in metrics}
# Build the train split first, then the test split, both at the configured
# point count.
dataset, dataset_test = (
    ShapeNetH5(train=is_train, npoints=args.num_points)
    for is_train in (True, False)
)
# Wrap each split in a DataLoader. The training loader shuffles; the test
# loader keeps a fixed order.
dataloader, dataloader_test = (
    torch.utils.data.DataLoader(
        split,
        batch_size=args.batch_size,
        shuffle=do_shuffle,
        num_workers=int(args.workers),
    )
    for split, do_shuffle in ((dataset, True), (dataset_test, False))
)

logging.info('Length of train dataset:%d', len(dataset))
logging.info('Length of test dataset:%d', len(dataset_test))

独特见解