Updated: 2021-12-11 02:18:23
I think tf.contrib.estimator.replicate_model_fn is a cleaner solution. The following is from the tf.contrib.estimator.replicate_model_fn documentation:
...
def model_fn(...):  # See `model_fn` in `Estimator`.
  loss = ...
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
  optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
  if mode == tf.estimator.ModeKeys.TRAIN:
    # See the section below on `EstimatorSpec.train_op`.
    return EstimatorSpec(mode=mode, loss=loss,
                         train_op=optimizer.minimize(loss))
  # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`.
  return EstimatorSpec(...)
...
classifier = tf.estimator.Estimator(
    model_fn=tf.contrib.estimator.replicate_model_fn(model_fn))
What you need to do is wrap the optimizer with tf.contrib.estimator.TowerOptimizer and model_fn() with tf.contrib.estimator.replicate_model_fn().
I followed the description and made a TPU SqueezeNet model work on a machine with 4 GPUs. My modifications are here.
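To see why this wrapping is all that is needed, it may help to look at the shape of the pattern in isolation. Below is a toy, framework-free sketch of the idea behind replicate_model_fn (the names toy_replicate_model_fn and num_towers are mine, and this is not the real TensorFlow implementation, which handles variable placement, gradient aggregation, and device assignment): the wrapper splits each batch across "towers", calls the user's unchanged model_fn on each shard, and averages the per-tower losses.

```python
# Toy sketch of the wrapping idea behind replicate_model_fn.
# NOT the real TensorFlow code -- just the structural pattern:
# the user's model_fn is untouched; the wrapper handles replication.

def toy_replicate_model_fn(model_fn, num_towers=4):
    def replicated(features):
        # Split the batch into one shard per tower.
        shard = max(1, len(features) // num_towers)
        shards = [features[i:i + shard]
                  for i in range(0, len(features), shard)]
        # Run the unchanged model_fn on each shard and average the losses,
        # mimicking how per-tower losses are combined.
        losses = [model_fn(s) for s in shards]
        return sum(losses) / len(losses)
    return replicated

def model_fn(batch):
    # Stand-in for a real model_fn: "loss" is just the batch mean.
    return sum(batch) / len(batch)

wrapped = toy_replicate_model_fn(model_fn, num_towers=2)
print(wrapped([1.0, 2.0, 3.0, 4.0]))  # -> 2.5
```

The real replicate_model_fn follows the same shape: you hand it your existing model_fn and get back a model_fn with the same signature, which is why the only other change needed is wrapping the optimizer in TowerOptimizer so gradients are combined across towers.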