您当前的位置: 首页 >  pytorch

PyTorch 加载超大 Libsvm 格式数据

发布时间:2022-01-10 22:01:53 ,浏览量:0

对于比较大的数据集,比如好几个T的数据,没有办法一次性全部加载进内存,因此需要构建一个可迭代的数据集IterableDataset。

迭代读取文本文件

要借助pytorch的IterableDataset模块,官方文档是:IterableDataset。

按照官网的说法,需要继承这个IterableDataset类,然后覆写__iter__这个方法,返回一个可迭代的对象即可。

因为我们要处理的时标准 Libsvm 格式数据,所以还需要实现又给process_line函数处理每一行数据。

class LibsvmDataset(IterableDataset): def __init__(self, file_path, n_features): """
        file_path: Libsvm格式数据文件地址
        n_features: 特征数,从1开始
        """ self.file_path = file_path
        self.n_features = n_features def process_line(self, line): line = line.split(' ') label, values = int(line[0]), line[1:] value = torch.zeros((self.n_features)) for item in values: idx, val = item.split(':') value[int(idx) - 1] = float(val) return label, value def __iter__(self): with open(self.file_path, 'r') as fp: for line in fp: yield self.process_line(line.strip("\n")) 

然后我们就可以直接把LibsvmDataset通过DataLoader封装成一个加载器。

dataset = LibsvmDataset("./test.libsvm", 10) dataloader = DataLoader(dataset, batch_size=3) for data in dataloader: print(data) 

在这里插入图片描述

Shuffle 操作

如果说想实现shuffle操作的话,可以手动增加一个缓冲池,然后随机抽取。

class LibsvmDataset(IterableDataset): def __init__(self, file_path, n_features, buffer_size=256): """
        file_path: Libsvm格式数据文件地址
        n_features: 特征数,从1开始
        """ self.file_path = file_path
        self.n_features = n_features
        self.buffer_size = buffer_size def process_line(self, line): line = line.split(' ') label, values = torch.tensor([int(line[0])], dtype=torch.float), line[1:] value = torch.zeros((self.n_features), dtype=torch.float) for item in values: idx, val = item.split(':') value[int(idx) - 1] = float(val) return value, label def __iter__(self): shuffle_buffer = [] with open(self.file_path, 'r') as fp: index = 0 for line in fp: shuffle_buffer.append(self.process_line(line.strip("\n"))) index += 1 if index > self.buffer_size: break with open(self.file_path, 'r') as fp: for line in fp: evict_idx = random.randint(0, self.buffer_size - 1) yield shuffle_buffer[evict_idx] shuffle_buffer[evict_idx] = self.process_line(line.strip("\n")) 
关注
打赏
1688896170
查看更多评论

暂无认证

  • 0浏览

    0关注

    108697博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文
立即登录/注册

微信扫码登录

0.3184s