def _push_pull_grad_async(self, p):
    """Call the BytePS API to push-pull a gradient asynchronously.

    Arguments:
        p: The parameter whose gradient should be push-pulled.

    Returns:
        A push-pull handle and the compression context.
    """
    name = self._get_parameter_name(p)
    tensor = p.grad
    tensor_compressed, ctx = self._compression.compress(tensor)
    # Hold the per-parameter lock until the poller has seen the result.
    self._locks[p].acquire()
    handle = byteps_push_pull(tensor_compressed, average=True, name="Gradient."+name)
    self._logger.debug("{} calls byteps_push_pull for {}".format(self._desc, name))
    # Add to queue to poll completion
    self._event_queue.put((p, handle, ctx))
    return handle, ctx
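The handles queued above still have to be drained somewhere. Below is a minimal sketch of a matching poller loop, assuming BytePS's poll/synchronize ops (the same ones used in the benchmark further down) and the _locks, _event_queue and _compression fields shown above; the _poll name and the None shutdown sentinel are illustrative, not the library's exact code.

def _poll(self):
    # Hypothetical background-thread body: wait for each queued
    # push-pull to finish, write the result back, release the lock.
    while True:
        p, handle, ctx = self._event_queue.get()
        if p is None:  # assumed shutdown sentinel
            break
        # Busy-poll until the asynchronous push-pull completes.
        while not poll(handle):
            time.sleep(0.001)
        output = synchronize(handle)
        p.grad.set_(self._compression.decompress(output, ctx))
        self._locks[p].release()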
def broadcast_parameters(params, root_rank):
    """Broadcast the parameters from root rank to all other processes."""
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run synchronous broadcasts.
    for name, p in params:
        # Broadcast is implemented as push + pull in BytePS.
        # To make it a real broadcast, we set the non-root tensors to all zeros.
        if rank() != root_rank:
            p.fill_(0)
        # Remember to disable averaging because we are doing a broadcast.
        handle = byteps_push_pull(p, average=False, name="Parameter."+name)
        synchronize(handle)
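As a usage sketch: broadcast_parameters is typically called once after bps.init(), so every worker starts training from the root rank's weights. The model below is a stand-in.

import torch
import byteps.torch as bps

bps.init()
model = torch.nn.Linear(10, 2)  # any model
# Ship rank 0's initial weights to every other worker.
bps.broadcast_parameters(model.state_dict(), root_rank=0)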
def benchmark(tensor, average, name):
    if not args.no_wait and bps.rank() == 0:
        time.sleep(0.01)
    start = time.time()
    handle = push_pull_async_inplace(tensor, average, name)
    while True:
        if poll(handle):
            synchronize(handle)
            break
    end = time.time()
    return (end - start) * 1000
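A sketch of how this helper might be driven, assuming push_pull_async_inplace, poll and synchronize are importable from byteps.torch as in the library's benchmark; the real script builds args with argparse, so the Namespace, tensor size, and iteration count below are stand-ins.

import argparse
import time
import torch
import byteps.torch as bps
from byteps.torch import push_pull_async_inplace, poll, synchronize

args = argparse.Namespace(no_wait=True)  # stand-in for the script's CLI flags
bps.init()
tensor = torch.rand(1024 * 1024)  # ~4 MB payload, arbitrary size
durations = [benchmark(tensor, average=True, name='benchmark.0')
             for _ in range(50)]
if bps.rank() == 0:
    print('mean push-pull latency: %.3f ms' % (sum(durations) / len(durations)))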
def _push_pull_grad_async(self, p):
    if self._is_tensor_instance:
        name = self._parameter_names.get(p.__hash__())
    else:
        name = self._parameter_names.get(p)
    if self._enable_async:
        # In asynchronous mode the real handle is created in step(),
        # after the local update has been applied.
        handle, ctx = None, None
    else:
        tensor = p.grad
        tensor_compressed, ctx = self._compression.compress(tensor)
        handle = byteps_push_pull(tensor_compressed, average=True, name="Gradient."+name)
    return handle, ctx
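For context, a hedged sketch of how a method like this is usually wired up: a hook on each parameter's gradient accumulator fires _push_pull_grad_async as soon as the gradient is ready (a Horovod-style pattern; _make_hook, _register_hooks and _named_parameters are illustrative names, not necessarily the library's exact code).

def _make_hook(self, p):
    def hook(*ignore):
        handle, ctx = self._push_pull_grad_async(p)
        # In async mode the handle is None here; step() fills it in later.
        self._handles[p] = (handle, ctx)
    return hook

def _register_hooks(self):
    for _, p in self._named_parameters:
        if p.requires_grad:
            # expand_as gives p a grad_fn whose next function is the
            # accumulator node that runs once p.grad is fully formed
            p_tmp = p.expand_as(p)
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
            grad_acc.register_hook(self._make_hook(p))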
def step(self, closure=None):
    if self._enable_async:
        old_weight_map = {}
        # store the weights before the local update
        for p, _ in self._handles.items():
            old_weight_map[p] = p.data.clone().detach()
        # update
        loss = super(self.__class__, self).step(closure)

        for p, (h, _) in self._handles.items():
            # get the diff for each weight (in-place)
            p.data.sub_(old_weight_map.get(p))
            if h is None:
                # the handle was deferred by _push_pull_grad_async; create it now
                if self._is_tensor_instance:
                    name = self._parameter_names.get(p.__hash__())
                else:
                    name = self._parameter_names.get(p)
                handle = byteps_push_pull(p, average=False, name="AsyncParam."+name)
                _, ctx = self._compression.compress(p)
                self._handles[p] = (handle, ctx)

        self.synchronize()
        return loss
    else:
        self.synchronize()
        return super(self.__class__, self).step(closure)
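The self.synchronize() calls above drain self._handles. A minimal sketch of what that entails, assuming the synchronize op and compression API already shown; the library's real method also deals with gradient scaling and hooks that never fired.

def synchronize(self):
    for p, (handle, ctx) in self._handles.items():
        output = synchronize(handle)  # block until the push-pull completes
        if self._enable_async:
            # async mode pushed the weight diff; the pulled tensor is
            # the globally updated parameter
            p.data.set_(self._compression.decompress(output, ctx))
        else:
            # sync mode pushed the gradient; write the average back
            p.grad.set_(self._compression.decompress(output, ctx))
    self._handles.clear()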