关于PyTorch的Dataset、DataLoader和Sampler的设计看法

PyTorch的Dataset和DataLoader设计上还算优雅。在PyTorch的官方文档,以及大多数项目中,Dataset的作用是加载数据集提供样本的乱序访问__getitem__,通常在构造函数中传入数据集路径。Sampler设计得很漂亮,用于遍历数据集时的顺序控制__iter__和batching,可以做一些很fancy的事情,比如根据sequence长度进行kmeans batching。而DataLoader的设计目的则是将Dataset和Sampler组合起来,提供迭代方法__iter__
设计决定应用,正因为Dataset被设计为DataLoader的成员,并且需要提供乱序访问,所以它常驻内存。PyTorch初级用户于是习惯了将数据集全部加载到内存,给系统带来了很大的负担。前几天,我们实验室512GB内存的服务器就因为OOM宕机了。

而TensorFlow的Dataset.from_generator就特别优雅,处理大数据集、特别是生产环境中的数据特别合适。在HanLP2.0alpha中,所有数据集都是通过Dataset.from_generator创建的,天生适合生产环境。同时,如果需要乱序,可以用map方法shuffle;如果需要缓存,同样一个参数就行,还支持缓存到磁盘。

当然,这些都是人写的,PyTorch当然也可以写出来。不知道大家平时对PyTorch的data API有什么看法,欢迎留言。

1 Like

PyTorch 1.2提供了IterableDataset,类比于TF的from_generator。然而IterableDataset并不能像TF那样进行各种并行map、shuffle,而是需要在worker_init_fn中让每个worker处理dataset的一个子集。例如这篇文章:

它要求dataset在启动worker前全部加载到内存,这是完全无法用于生产环境的。

不全部加载到内存这就麻烦了,比如生产环境中读取一个1TB的训练集文件,将文件路径传入2个worker,第一个worker读取前一半sample,第二个worker读取后一半sample。如果每个sample不定长,第二个worker根本不知道后一半sample在文件的什么位置,最后第二个worker还是得一个一个数。也可以约定第一个读奇数,第二个处理偶数,然而最终结果同一个文件还是被读取了2遍,不可接受。

TF是怎么做到的?读文件只有一个线程,这个线程读出来的sample放入一个缓冲区,在把这个缓冲区分区给不同的worker处理。文件只读一遍,但是多个worker在处理,同时用户也没多写一行代码。

1 Like

根据 https://discuss.pytorch.org/t/how-to-speed-up-the-data-loader ,要在目前的PyTorch里实现TF同等质量的data pipeline,可选的方案有:

  1. 将数据集转换为h5格式,在__getitem__读取
  2. 自己写一个生产者消费者,完全忽略PyTorch的multiprocess机制。
1 Like

https://github.com/pytorch/text/issues/130#issuecomment-531901039 这个Issue中给出了一个使用
f.seek(), f.tell()的方案,代价是预先计算每行在文件中的byte offset
可惜的是,这个方案是不支持多线程的。

1 Like

这个按system programing,应该用生产者消费者模式。

我自己的实现是起一个python原生进程来读数据,生产者消费者