# 作者 by Tianzhi Jia / 2022-03-16 / 暂无评论 / 182 个足迹
# (Author: Tianzhi Jia, 2022-03-16; blog-page header: "no comments yet", "182 visits")
import h5py
import numpy as np
import torch
# Which MVP split to load (train vs. test) and the point count of the
# complete (ground-truth) clouds; together they select the HDF5 files below.
train = True
npoints = 4096
# Both splits share the same file-name templates, so derive them once instead
# of duplicating the two paths in each branch.
split = 'train' if train else 'test'
input_path = f'./data/mvp_{split}_input.h5'
gt_path = f'./data/mvp_{split}_gt_{npoints}pts.h5'
# Load the incomplete (partial) point clouds and their labels, plus the
# matching ``novel_*`` arrays, from the input file; then the complete
# (ground-truth) clouds from the GT file.
# ``with`` guarantees each file is closed even if a key lookup raises; the
# previous explicit close() was skipped on any exception.
# ``dset[()]`` already reads the whole dataset into a fresh ndarray, so
# np.asarray avoids the extra full copy that np.array made.
with h5py.File(input_path, 'r') as input_file:
    input_data = np.asarray(input_file['incomplete_pcds'][()])
    labels = np.asarray(input_file['labels'][()])
    novel_input_data = np.asarray(input_file['novel_incomplete_pcds'][()])
    novel_labels = np.asarray(input_file['novel_labels'][()])

with h5py.File(gt_path, 'r') as gt_file:
    gt_data = np.asarray(gt_file['complete_pcds'][()])
    novel_gt_data = np.asarray(gt_file['novel_complete_pcds'][()])
# gt_file = h5py.File(gt_path, 'r')
# gt_labels = np.array((gt_file['labels'][()]))
# gt_normal = np.array((gt_file['normal'][()]))
# gt_novel_labels = np.array((gt_file['novel_labels'][()]))
# gt_novel_normal = np.array((gt_file['novel_normal'][()]))
# gt_file.close()
# Choose which subset of samples to use:
#   novel_input_only=True -> use only the ``novel_*`` arrays (takes precedence).
#   novel_input=True      -> append the ``novel_*`` arrays after the regular ones.
# NOTE(review): "novel" presumably means held-out categories — TODO confirm.
novel_input = True
novel_input_only = False
if novel_input_only:
    input_data = novel_input_data
    gt_data = novel_gt_data
    labels = novel_labels
elif novel_input:
    # Concatenate along the sample axis in the same order for all three arrays,
    # so regular samples come first and novel samples follow in each of them.
    input_data = np.concatenate((input_data, novel_input_data), axis=0)
    gt_data = np.concatenate((gt_data, novel_gt_data), axis=0)
    labels = np.concatenate((labels, novel_labels), axis=0)
# Print the shapes of the (possibly merged) arrays as a quick sanity check.
print(input_data.shape)
print(gt_data.shape)
print(labels.shape)
# Number of partial point-cloud samples. Deliberately NOT named ``len``:
# rebinding the builtin at module level made the later ``len(dataset)`` calls
# fail with "TypeError: 'int' object is not callable".
num_samples = input_data.shape[0]
def get_item(index, views_per_shape=26):
    """Return the ``(label, partial, complete)`` sample at ``index``.

    Reads the module-level ``input_data``, ``gt_data`` and ``labels`` arrays
    loaded earlier in this script.

    Args:
        index: Position of the partial (incomplete) point cloud in
            ``input_data`` and ``labels``.
        views_per_shape: Number of consecutive partial clouds that map to one
            complete ground-truth cloud. Defaults to 26, the ratio the
            original indexing hard-coded.

    Returns:
        Tuple ``(label, partial, complete)``. ``label`` is ``labels[index]``.
        ``partial`` and ``complete`` are tensors built with
        ``torch.from_numpy``, so they share memory with the NumPy arrays.
    """
    partial = torch.from_numpy(input_data[index])
    # Consecutive partial clouds share one ground-truth cloud, so integer
    # division selects it. NOTE(review): when the novel arrays are
    # concatenated, this assumes they use the same views-per-shape ratio.
    complete = torch.from_numpy(gt_data[index // views_per_shape])
    label = labels[index]
    return label, partial, complete
# Quick smoke test: fetch sample 0 through get_item.
label, partial, complete = get_item(0)
import argparse
import logging

from dataset import ShapeNetH5
import torch

# NOTE(review): this section used ``args`` and ``logging`` without defining or
# importing them (it looks copied from a training script), so it died with a
# NameError. A minimal parser now supplies the three options read below.
# ``--num_points`` defaults to ``npoints`` from the top of this script. The
# batch-size default is an assumption — TODO confirm against the training config.
parser = argparse.ArgumentParser(description='Inspect the MVP ShapeNetH5 datasets.')
parser.add_argument('--num_points', type=int, default=npoints)
parser.add_argument('--batch_size', type=int, default=32)
# 0 loads batches in the main process. This also avoids re-running this
# unguarded top-level script inside worker processes on platforms whose
# multiprocessing start method is "spawn".
parser.add_argument('--workers', type=int, default=0)
# parse_known_args ignores unrelated command-line flags instead of exiting.
args, _ = parser.parse_known_args()

# Without configuration the root logger only emits WARNING and above, so the
# INFO lines below would be silently dropped.
logging.basicConfig(level=logging.INFO)

dataset = ShapeNetH5(train=True, npoints=args.num_points)
dataset_test = ShapeNetH5(train=False, npoints=args.num_points)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
logging.info('Length of train dataset:%d', len(dataset))
logging.info('Length of test dataset:%d', len(dataset_test))

# Inspect only the first training batch. NOTE(review): the batch elements are
# presumably (labels, partial, complete), mirroring get_item — TODO confirm
# against ShapeNetH5.__getitem__.
first_batch = next(iter(dataloader), None)
if first_batch is None:
    logging.warning('Train dataloader produced no batches')
else:
    print(first_batch[0])
    print(first_batch[1].shape)
    print(first_batch[2].shape)
# 独特见解 ("Unique insights" — text from the source blog page)