且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

tf读取数据pai

更新时间:2021-10-09 03:54:04



  1. import argparse
  2. import tensorflow as tf
  3. import os
  4. FLAGS=None
  5. def main(_):
  6. dirname = os.path.join(FLAGS.buckets, "csvtest.csv")
  7. reader=tf.TextLineReader()
  8. filename_queue=tf.train.string_input_producer([dirname])
  9. key,value=reader.read(filename_queue)
  10. record_defaults=[[''],[''],[''],[''],['']]
  11. d1, d2, d3, d4, d5= tf.decode_csv(value, record_defaults, ',')
  12. init=tf.initialize_all_variables()
  13. with tf.Session() as sess:
  14. sess.run(init)
  15. coord = tf.train.Coordinator()
  16. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  17. for i in range(4):
  18. print(sess.run(d2))
  19. coord.request_stop()
  20. coord.join(threads)
  21. if __name__ == '__main__':
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument('--buckets', type=str, default='',
  24. help='input data path')
  25. parser.add_argument('--checkpointDir', type=str, default='',
  26. help='output model path')
  27. FLAGS, _ = parser.parse_known_args()
  28. tf.app.run(main=main)


  • dirname:OSS文件路径,可以是数组,方便下一阶段shuffle
  • reader:TF内置各种reader API,可以根据需求选用
  • tf.train.string_input_producer:将文件生成队列
  • tf.decode_csv:是一个splite功能的OP,可以拿到每一行的特定参数
  • 通过OP获取数据,在session中需要tf.train.Coordinator()和tf.train.start_queue_runners(sess=sess,coord=coord)

在代码中,我们的输入是3行5个字段:


  1. 1,1,1,1,1
  2. 2,2,2,2,2
  3. 3,3,3,3,3

我们循环输出4次,打印出第2个字段。结果如图:

tf读取数据pai

输出结果也证明了数据结构是成队列。