PyTorch的Dataset和DataLoader设计上还算优雅。在PyTorch的官方文档,以及大多数项目中,Dataset的作用是加载数据集提供样本的乱序访问__getitem__
,通常在构造函数中传入数据集路径。Sampler设计得很漂亮,用于遍历数据集时的顺序控制__iter__
和batching,可以做一些很fancy的事情,比如根据sequence长度进行kmeans batching。而DataLoader的设计目的则是将Dataset和Sampler组合起来,提供迭代方法__iter__
。
设计决定应用,正因为Dataset被设计为DataLoader的成员,并且需要提供乱序访问,所以它常驻内存。PyTorch初级用户于是习惯了将数据集全部加载到内存,给系统带来了很大的负担。前几天,我们实验室512GB内存的服务器就因为OOM宕机了。
而TensorFlow的Dataset.from_generator
就特别优雅,处理大数据集、特别是生产环境中的数据特别合适。在HanLP2.0alpha中,所有数据集都是通过Dataset.from_generator
创建的,天生适合生产环境。同时,如果需要乱序,可以用map方法shuffle;如果需要缓存,同样一个参数就行,还支持缓存到磁盘。
当然,这些都是人写的,PyTorch当然也可以写出来。不知道大家平时对PyTorch的data API有什么看法,欢迎留言。