且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

在 Keras 中具有多个输入/输出的 tf.data

更新时间:2023-12-01 22:28:46

我没有使用 Keras,但我会使用 tf.data.Dataset.from_generator() - 比如:

I'm not using Keras but I would go with an tf.data.Dataset.from_generator() - like:

def _input_fn():
  sent1 = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int64)
  sent2 = np.array([20, 25, 35, 40, 600, 30, 20, 30], dtype=np.int64)
  sent1 = np.reshape(sent1, (8, 1, 1))
  sent2 = np.reshape(sent2, (8, 1, 1))

  labels = np.array([40, 30, 20, 10, 80, 70, 50, 60], dtype=np.int64)
  labels = np.reshape(labels, (8, 1))

  def generator():
    for s1, s2, l in zip(sent1, sent2, labels):
      yield {"input_1": s1, "input_2": s2}, l

  dataset = tf.data.Dataset.from_generator(generator, output_types=({"input_1": tf.int64, "input_2": tf.int64}, tf.int64))
  dataset = dataset.batch(2)
  return dataset

...

model.fit(_input_fn(), epochs=10, steps_per_epoch=4)

这个生成器可以迭代你的例如文本文件/numpy 数组,并在每次调用时产生一个示例.在这个例子中,我假设句子的单词已经转换为词汇表中的索引.

This generator can iterate over your e.g text-files / numpy arrays and yield on every call a example. In this example, I assume that the word of the sentences are already converted to the indices in the vocabulary.

由于 OP 要求,使用 Dataset.from_tensor_slices() 也应该是可能的:

Since OP asked, it should be also possible with Dataset.from_tensor_slices():

def _input_fn():
  sent1 = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int64)
  sent2 = np.array([20, 25, 35, 40, 600, 30, 20, 30], dtype=np.int64)
  sent1 = np.reshape(sent1, (8, 1))
  sent2 = np.reshape(sent2, (8, 1))

  labels = np.array([40, 30, 20, 10, 80, 70, 50, 60], dtype=np.int64)
  labels = np.reshape(labels, (8))

  dataset = tf.data.Dataset.from_tensor_slices(({"input_1": sent1, "input_2": sent2}, labels))
  dataset = dataset.batch(2, drop_remainder=True)
  return dataset