For very large datasets, say several terabytes of data, it is impossible to load everything into memory at once, so we need to build an iterable dataset with IterableDataset.
Reading a text file iteratively relies on PyTorch's IterableDataset module; the official documentation is: IterableDataset.
According to the official docs, you subclass IterableDataset and override the __iter__ method so that it returns an iterable.
Since the data we are processing is in the standard Libsvm format, we also need to implement a process_line function that parses each line.
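For reference, each line of a Libsvm file stores the label first, followed by sparse index:value pairs whose indices start at 1; the concrete values below are made up purely for illustration:

```
1 3:0.5 7:1.2 10:0.3
0 1:2.0 4:0.7
```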
```python
import torch
from torch.utils.data import IterableDataset


class LibsvmDataset(IterableDataset):
    def __init__(self, file_path, n_features):
        """
        file_path: path to the Libsvm-format data file
        n_features: number of features, indexed from 1
        """
        self.file_path = file_path
        self.n_features = n_features

    def process_line(self, line):
        # Each line: "label idx1:val1 idx2:val2 ..."
        line = line.split(' ')
        label, values = int(line[0]), line[1:]
        value = torch.zeros(self.n_features)
        for item in values:
            idx, val = item.split(':')
            # Libsvm indices start at 1, tensor indices at 0.
            value[int(idx) - 1] = float(val)
        return label, value

    def __iter__(self):
        # Stream the file lazily: one line is read and parsed per sample,
        # so the whole dataset never has to fit in memory.
        with open(self.file_path, 'r') as fp:
            for line in fp:
                yield self.process_line(line.strip("\n"))
```
We can then wrap LibsvmDataset directly in a DataLoader to get a data loader.
```python
from torch.utils.data import DataLoader

dataset = LibsvmDataset("./test.libsvm", 10)
dataloader = DataLoader(dataset, batch_size=3)
for data in dataloader:
    print(data)
```
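One caveat: if you pass num_workers > 0 to the DataLoader, every worker process runs the same __iter__ over the whole file, so each sample gets yielded num_workers times. Below is a minimal sketch of one way to shard lines across workers with torch.utils.data.get_worker_info(); this __iter__ is an assumed replacement for the one in LibsvmDataset above, not part of the original code:

```python
import itertools

from torch.utils.data import get_worker_info

def __iter__(self):
    # Sketch only: assumed replacement for LibsvmDataset.__iter__.
    worker_info = get_worker_info()
    start, step = 0, 1  # single-process loading reads every line
    if worker_info is not None:
        # Worker i reads lines i, i + num_workers, i + 2 * num_workers, ...
        start, step = worker_info.id, worker_info.num_workers
    with open(self.file_path, 'r') as fp:
        for line in itertools.islice(fp, start, None, step):
            yield self.process_line(line.strip("\n"))
```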
If you want shuffle behavior, you can add a buffer pool manually: keep a pool of samples, yield a randomly chosen one, and fill the freed slot with the next line read from the file.
```python
import random

import torch
from torch.utils.data import IterableDataset


class LibsvmDataset(IterableDataset):
    def __init__(self, file_path, n_features, buffer_size=256):
        """
        file_path: path to the Libsvm-format data file
        n_features: number of features, indexed from 1
        buffer_size: number of samples kept in the shuffle buffer
        """
        self.file_path = file_path
        self.n_features = n_features
        self.buffer_size = buffer_size

    def process_line(self, line):
        line = line.split(' ')
        label = torch.tensor([int(line[0])], dtype=torch.float)
        value = torch.zeros(self.n_features, dtype=torch.float)
        for item in line[1:]:
            idx, val = item.split(':')
            value[int(idx) - 1] = float(val)
        return value, label

    def __iter__(self):
        shuffle_buffer = []
        with open(self.file_path, 'r') as fp:
            # Fill the buffer with the first buffer_size samples.
            for line in fp:
                shuffle_buffer.append(self.process_line(line.strip("\n")))
                if len(shuffle_buffer) >= self.buffer_size:
                    break
            # For every remaining line, yield a random sample from the
            # buffer and put the new sample in the freed slot. Continuing
            # from the same file handle avoids re-reading the lines that
            # already filled the buffer.
            for line in fp:
                evict_idx = random.randint(0, len(shuffle_buffer) - 1)
                yield shuffle_buffer[evict_idx]
                shuffle_buffer[evict_idx] = self.process_line(line.strip("\n"))
        # Drain whatever is left in the buffer, in random order.
        random.shuffle(shuffle_buffer)
        yield from shuffle_buffer
```
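As before, the shuffled dataset plugs straight into a DataLoader. A quick sanity check might look like the following; the buffer_size of 4 and the shapes in the comment assume the 10-feature test.libsvm file used earlier:

```python
dataset = LibsvmDataset("./test.libsvm", 10, buffer_size=4)
dataloader = DataLoader(dataset, batch_size=3)
for value, label in dataloader:
    # Expected shapes for a full batch: value (3, 10), label (3, 1).
    print(value.shape, label.shape)
```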