更新时间:2023-12-02 15:22:55
我已解决问题:)
# Number of samples per generated mini-batch; also used further down to derive
# steps_per_epoch for training.
batch_size = 32
# Leftover from an earlier approach (a fixed-size all-ones label vector). The
# labels are now built per batch inside gen_flow_multi_inputs, sized to the
# batch that was actually produced.
# y = np.ones(batch_size)
# Fit the augmenter on the 'anc_input' array only.
# NOTE(review): presumably `aug` is a Keras ImageDataGenerator, whose fit()
# computes data-dependent statistics -- confirm where `aug` is defined.
aug.fit(X['anc_input'])
def gen_flow_multi_inputs(X):
    """Endlessly yield ``(inputs, targets)`` batches for a multi-input model.

    Each value of ``X`` is wrapped in its own ``aug.flow(...)`` iterator,
    using the module-level ``batch_size`` and a fixed ``seed=7``. On every
    step one batch is pulled from each iterator and the batches are
    collected into a dict under the same keys as ``X``.

    Args:
        X: Mapping from input name to the array for that input. It must
            contain the key ``'anc_input'``, whose batch length determines
            the number of targets.

    Yields:
        A ``(batch, targets)`` tuple where ``batch`` maps each input name to
        its next augmented batch and ``targets`` is ``np.ones(n)`` with ``n``
        the length of the ``'anc_input'`` batch.
    """
    # One flow iterator per named input.
    # NOTE(review): the shared seed presumably keeps the per-input batches
    # aligned sample-for-sample -- confirm against the augmenter's docs.
    flows = {
        name: aug.flow(array, batch_size=batch_size, seed=7)
        for name, array in X.items()
    }
    while True:
        batch = {name: flows[name].next() for name in X}
        # Size the targets from the produced batch, not from batch_size, so
        # a batch of a different length still gets matching targets.
        yield batch, np.ones(len(batch['anc_input']))
# Train the model from the endless multi-input generator. Every target is 1
# (np.ones), both for the generated training batches and for validation.
# NOTE(review): presumably the loss is computed from the model outputs alone
# and the targets are placeholders -- confirm against the model definition.
# NOTE(review): fit_generator is deprecated in recent Keras/TensorFlow
# releases in favour of fit() -- confirm against the pinned version.
self.model.fit_generator(gen_flow_multi_inputs(X),
validation_data=[X_te, np.ones(len(anc_ins_te))],  # validation inputs paired with all-ones targets
steps_per_epoch=len(anc_ins) // batch_size,  # floor division: a trailing partial batch is not counted
epochs=50,
callbacks=self.callbacks)