How to load a TensorFlow model and continue training

Updated: 2023-10-18 12:38:22

I think I found the answer. The key point is that you don't need to call tf.train.import_meta_graph() if the graph has already been rebuilt in code and the weights are then restored with saver.restore(sess, tf.train.latest_checkpoint('./')). Here is my code.
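
The code below calls a multilayer_perceptron helper that is defined elsewhere in the original program and not shown here. Purely for completeness, here is a minimal sketch of what such a helper might look like; only the variable names (h1, b1, h2, b2, h3, b3, w_o, b_o) come from the snippet, while the hidden-layer width of 256 and the ReLU activations are assumptions:

import tensorflow as tf

def multilayer_perceptron(x, n_input, n_classes, names, n_hidden=256):
    # Layer sizes and activations are assumptions; only the variable
    # names passed in via `names` come from the original snippet.
    h1 = tf.Variable(tf.random_normal([n_input, n_hidden]), name=names[0])
    b1 = tf.Variable(tf.random_normal([n_hidden]), name=names[1])
    h2 = tf.Variable(tf.random_normal([n_hidden, n_hidden]), name=names[2])
    b2 = tf.Variable(tf.random_normal([n_hidden]), name=names[3])
    h3 = tf.Variable(tf.random_normal([n_hidden, n_hidden]), name=names[4])
    b3 = tf.Variable(tf.random_normal([n_hidden]), name=names[5])
    w_o = tf.Variable(tf.random_normal([n_hidden, n_classes]), name=names[6])
    b_o = tf.Variable(tf.random_normal([n_classes]), name=names[7])

    layer_1 = tf.nn.relu(tf.add(tf.matmul(x, h1), b1))
    layer_2 = tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2))
    layer_3 = tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3))
    # Return raw logits; softmax is applied inside the loss op.
    return tf.add(tf.matmul(layer_3, w_o), b_o)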

import tensorflow as tf

# n_input, n_classes, learning_rate, training_epochs, total_batch and
# train_generator are assumed to be defined earlier, as in the original
# training script.

# tf Graph input
X = tf.placeholder("float", [None, n_input])
Y = tf.placeholder("float", [None, n_classes])
mlp_layer_name = ['h1', 'b1', 'h2', 'b2', 'h3', 'b3', 'w_o', 'b_o']
logits = multilayer_perceptron(X, n_input, n_classes, mlp_layer_name)
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y), name='loss_op')
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op, name='train_op')

with tf.Session() as sess:
    saver = tf.train.Saver()
    # Restore the variable values from the most recent checkpoint in the
    # current directory; the graph itself was rebuilt above, so
    # tf.train.import_meta_graph() is not needed.
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    graph = tf.get_default_graph()  # not actually used below

    for epoch in range(training_epochs):
        avg_cost = 0.

        # Loop over all batches
        for i in range(total_batch):
            batch_x, batch_y = next(train_generator)

            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([train_op, loss_op], feed_dict={X: batch_x,
                                                            Y: batch_y})
            # Compute average loss
            avg_cost += c / total_batch

        print("Epoch: {:3d}, cost = {:.6f}".format(epoch+1, avg_cost))