def __init__(self, average=True):
    """
    Args:
        average (bool): whether to average or sum the gradients across processes.
    """
    import byteps.tensorflow as bps
    self.hvd = bps  # BytePS has the same interface as Horovod
    self.hvd.allreduce = bps.push_pull  # https://github.com/bytedance/byteps/issues/8

    # These environment variables are set by the BytePS/DMLC launcher;
    # the trainer refuses to run outside a properly launched worker process.
    assert os.environ.get("DMLC_ROLE", None) == "worker"
    assert "DMLC_WORKER_ID" in os.environ and "DMLC_NUM_WORKER" in os.environ

    bps.init()
    self.is_chief = bps.rank() == 0

    self._local_rank = bps.local_rank()
    self._rank = bps.rank()
    self._average = average

    self._compression = None
    self._has_compression = False
    logger.info("[BytePSTrainer] local rank={}".format(self._local_rank))
    SingleCostTrainer.__init__(self)
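
This constructor looks like tensorpack's BytePSTrainer. Below is a minimal, hypothetical launch sketch, assuming the process is started through BytePS's launcher (which exports the DMLC_* variables asserted above); MyModel and my_dataflow are placeholders, not names from the source.

# Run on each machine via the BytePS launcher, e.g.:
#   bpslaunch python train.py
from tensorpack import TrainConfig, launch_train_with_config

config = TrainConfig(
    model=MyModel(),        # placeholder ModelDesc
    dataflow=my_dataflow,   # placeholder DataFlow
    max_epoch=10,
)
launch_train_with_config(config, BytePSTrainer(average=True))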
def _make_variable(self, metric, value):
    # `bps` here refers to the module-level BytePS import
    # (import byteps.tensorflow as bps), mirroring Horovod's callback code.
    with tf.name_scope('MetricAverageCallback') as scope:
        var = tf.Variable(value, name=metric)
        self.backend.get_session().run(var.initializer)
        push_pull_op = bps.push_pull(var, scope, device_dense=self.device)
        return var, push_pull_op
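
The same mechanism can be exercised in isolation: wrap a per-worker value in a variable, initialize it, then push_pull it so every worker ends up with the averaged result. A small sketch, assuming a TF1-style session and an initialized BytePS; average_scalar is a hypothetical helper name.

import tensorflow as tf
import byteps.tensorflow as bps

bps.init()

def average_scalar(sess, value, name):
    # Mirror _make_variable: variable -> initializer -> push_pull.
    var = tf.Variable(float(value), name=name)
    sess.run(var.initializer)
    push_pull_op = bps.push_pull(var)  # averages across workers by default
    return sess.run(push_pull_op)

with tf.Session() as sess:
    print(average_scalar(sess, 0.5, 'my_metric'))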
"""
Compute gradients of all trainable variables.
See Optimizer.get_gradients() for more info.
In DistributedOptimizer, get_gradients() is overriden to also
push_pull the gradients before returning them.
"""
gradients = super(self.__class__, self).get_gradients(loss, params)
if bps.size() > 1:
averaged_gradients = []
with tf.name_scope(self._name + "_Push_Pull") as scope:
for grad in gradients:
if grad is not None:
if self._sparse_as_dense and \
isinstance(grad, tf.IndexedSlices):
grad = tf.convert_to_tensor(grad)
avg_grad = bps.push_pull(grad, scope,
device_dense=self._device_dense,
device_sparse=self._device_sparse,
compression=self._compression)
averaged_gradients.append(avg_grad)
else:
averaged_gradients.append(None)
return averaged_gradients
else:
return gradients
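
This override is normally reached by wrapping a stock optimizer in the distributed one. A hypothetical end-to-end sketch, assuming byteps.tensorflow.keras mirrors horovod.tensorflow.keras (init, size, rank, DistributedOptimizer); the tiny model and random data are placeholders.

import numpy as np
import byteps.tensorflow.keras as bps
from tensorflow import keras

bps.init()

model = keras.Sequential([
    keras.layers.Dense(10, activation='softmax', input_shape=(784,)),
])
opt = keras.optimizers.SGD(0.01 * bps.size())  # scale LR with worker count
opt = bps.DistributedOptimizer(opt)            # get_gradients() now push_pulls
model.compile(optimizer=opt, loss='categorical_crossentropy')

x = np.random.rand(256, 784).astype('float32')
y = keras.utils.to_categorical(np.random.randint(10, size=256), 10)
model.fit(x, y, batch_size=32, epochs=1,
          verbose=1 if bps.rank() == 0 else 0)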
def push_pull(backend, value, name, average):
    # Wrap a plain Python/numpy value in a constant tensor, push_pull it
    # across workers, and evaluate the result in the backend's session.
    push_pull_op = bps.push_pull(tf.constant(value, name=name), average=average)
    return backend.get_session().run(push_pull_op)
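
A hypothetical call to the helper above, assuming a TF1-style setup where tf.keras.backend still exposes get_session(); the value and name are made up for illustration.

from tensorflow.keras import backend as K

# Average a per-worker scalar (e.g. a validation metric) across all workers.
global_score = push_pull(K, 0.87, name='val_score', average=True)
print(global_score)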