且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

使用Keras ImageDataGenerator时多输入模型中的内存错误

更新时间:2023-12-02 15:22:55

我已解决问题:)

batch_size = 32
# y = np.ones(batch_size)
aug.fit(X['anc_input'])

def gen_flow_multi_inputs(X):
    gen_X_ = {}
    for k, X_ in X.items():
        gen_X_[k] = aug.flow(X_, batch_size=batch_size, seed=7)
    while True:
        XX = {}
        for k, X_ in X.items():
            XX[k] = gen_X_[k].next()
        N = len(XX['anc_input'])
        yield XX, np.ones(N)

self.model.fit_generator(gen_flow_multi_inputs(X),
                         validation_data=[X_te, np.ones(len(anc_ins_te))],
                         steps_per_epoch=len(anc_ins) // batch_size,
                         epochs=50,
                         callbacks=self.callbacks)