How to improve the performance of this data pipeline for my TensorFlow model

Updated: 2023-12-02 18:38:04

The suggestion from hampi to profile your training job is a good one, and may be necessary to understand the actual bottlenecks in your pipeline. The other suggestions in the Input Pipeline performance guide should be useful as well.
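
For reference, one way to do that profiling in TF 1.x is to trace a single training step and dump a Chrome trace. This is a minimal sketch, in which sess and train_op stand in for your own session and training op:

from tensorflow.python.client import timeline

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

# Trace one step; the step_stats collected here show where time is spent.
sess.run(train_op, options=run_options, run_metadata=run_metadata)

# Write a trace that can be opened in chrome://tracing.
trace = timeline.Timeline(run_metadata.step_stats)
with open('timeline.json', 'w') as f:
  f.write(trace.generate_chrome_trace_format())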

However, there is another possible "quick fix" that might be useful. In some cases, the amount of work in a Dataset.map() transformation can be very small, and dominated by the cost of invoking the function for each element. In those cases, we often try to vectorize the map function, and move it after the Dataset.batch() transformation, in order to invoke the function fewer times (1/512 as many times, in this case), and perform larger—and potentially easier-to-parallelize—operations on each batch. Fortunately, your pipeline can be vectorized as follows:

import glob
import random

import tensorflow as tf


def _batch_parser(record_batch):
  # NOTE: Use `tf.parse_example()` to operate on batches of records.
  parsed = tf.parse_example(record_batch, _keys_to_map)
  return parsed['d'], parsed['s']

def init_tfrecord_dataset():
  files_train = glob.glob(DIR_TFRECORDS + '*.tfrecord')
  random.shuffle(files_train)

  with tf.name_scope('tfr_iterator'):
    ds = tf.data.TFRecordDataset(files_train)      # define data from randomly ordered files
    ds = ds.shuffle(buffer_size=10000)             # select elements randomly from the buffer
    # NOTE: Change begins here.
    ds = ds.batch(BATCH_SIZE, drop_remainder=True) # group into batches; drop a final batch smaller than BATCH_SIZE
    ds = ds.map(_batch_parser)                     # map batches based on tfrecord format
    # NOTE: Change ends here.
    ds = ds.repeat()                               # iterate infinitely 

    return ds.make_initializable_iterator()        # return an initializable iterator (run its initializer before use)
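
For completeness, here is a minimal sketch of how the returned iterator could be consumed in a TF 1.x session; NUM_STEPS and the training step are placeholders for your own loop and model:

iterator = init_tfrecord_dataset()
next_d, next_s = iterator.get_next()       # one batch of 'd' and 's' per sess.run call

with tf.Session() as sess:
  sess.run(iterator.initializer)           # the iterator must be initialized before first use
  for step in range(NUM_STEPS):            # NUM_STEPS: placeholder for your loop length
    d_batch, s_batch = sess.run([next_d, next_s])
    # ... run your training step on d_batch / s_batch here ...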

Currently, vectorization is a change that you have to make manually, but the tf.data team are working on an optimization pass that provides automatic vectorization.
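
As an aside, some later TensorFlow releases exposed an experimental switch for exactly this optimization. This is a hedged sketch, assuming a version in which tf.data.Options() still has experimental_optimization.map_vectorization (it was removed again in more recent releases), with _parse_one_record standing in for a per-element parsing function:

options = tf.data.Options()
# Assumption: this option existed only in some TF releases; check your version before relying on it.
options.experimental_optimization.map_vectorization.enabled = True

ds = tf.data.TFRecordDataset(files_train)
ds = ds.map(_parse_one_record)             # per-element map, as in the original pipeline
ds = ds.batch(BATCH_SIZE)
ds = ds.with_options(options)              # ask tf.data to vectorize the map and move it after batch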