且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

创建一个输出dict的Tensorflow数据集

更新时间:2023-11-30 14:34:34

所以实际上可以按照您的意愿去做,只需要具体说明dict的内容即可:

So actually it is possible to do what you intend, you just have to be specific about the contents of the dict:

import tensorflow as tf
import numpy as np

N = 100
# dictionary of arrays:
metadata = {'m1': np.zeros(shape=(N,2)), 'm2': np.ones(shape=(N,3,5))}
num_samples = N

def meta_dict_gen():
    for i in range(num_samples):
        ls = {}
        for key, val in metadata.items():
            ls[key] = val[i]
        yield ls

dataset = tf.data.Dataset.from_generator(
    meta_dict_gen,
    output_types={k: tf.float32 for k in metadata},
    output_shapes={'m1': (2,), 'm2': (3, 5)})
iter = dataset.make_one_shot_iterator()
next_elem = iter.get_next()
print(next_elem)

输出:

{'m1': <tf.Tensor 'IteratorGetNext:0' shape=(2,) dtype=float32>,
 'm2': <tf.Tensor 'IteratorGetNext:1' shape=(3, 5) dtype=float32>}