How to use the byteps.torch.ops.push_pull_async_inplace function in byteps

To help you get started, we’ve selected a few byteps examples based on popular ways the library is used in public projects.

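Before diving into the examples, here is a minimal, hedged sketch of the basic call pattern: push_pull_async_inplace starts the collective on the tensor in place and returns a handle that is later completed with poll/synchronize. It assumes a properly launched BytePS job and that poll and synchronize are importable from byteps.torch.ops alongside push_pull_async_inplace; the tensor size and name are illustrative.

import torch
import byteps.torch as bps
from byteps.torch.ops import push_pull_async_inplace, poll, synchronize

bps.init()
tensor = torch.ones(1024)                  # illustrative payload

# Start the asynchronous push-pull; the tensor is averaged in place on completion.
handle = push_pull_async_inplace(tensor, average=True, name="example.tensor")

# Either block right away...
synchronize(handle)

# ...or test for completion first, as the microbenchmark example below does:
# if poll(handle):
#     synchronize(handle)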

github bytedance / byteps / byteps / torch / cross_barrier.py
def _push_pull_grad_async(self, p):
        """Call byteps API to push-pull gradient asynchronously
        Arguments:
            tensor: The tensor to push-pull.
            name: The name of the tensor.
        Returns:
            an push-pull handle and context
        """
        name = self._get_parameter_name(p)
        tensor = p.grad
        tensor_compressed, ctx = self._compression.compress(tensor)

        self._locks[p].acquire()
        handle = byteps_push_pull(tensor_compressed, average=True, name="Gradient."+name)
        self._logger.debug("{} calls byteps_push_pull for {}".format(self._desc, self._get_parameter_name(p)))
        # Add to queue to poll completion
        self._event_queue.put((p, handle, ctx))
        return handle, ctx
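Note that cross_barrier.py does not synchronize the handle inline: it acquires a per-parameter lock, queues (p, handle, ctx), and lets a background thread finish the work. A simplified, hedged sketch of such a polling thread follows; the sentinel, decompress call, and lock release mirror the snippet's intent but are illustrative, not the exact byteps implementation.

from byteps.torch.ops import poll, synchronize

def _poll_loop(event_queue, compression, locks):
    # Illustrative drain loop: complete finished push-pulls and release locks.
    while True:
        p, handle, ctx = event_queue.get()
        if p is None:                        # sentinel to stop this sketch's loop
            break
        while handle is not None and not poll(handle):
            pass                             # spin until the push-pull has finished
        if handle is not None:
            output = synchronize(handle)
            p.grad.set_(compression.decompress(output, ctx))
        locks[p].release()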
github bytedance / byteps / byteps / torch / __init__.py
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run synchronous broadcasts.
    for name, p in params:
        # Broadcast is implemented as push + pull in BytePS
        # To make it a real broadcast, we set the non-root tensors all 0.
        if rank() != root_rank:
            p.fill_(0)
        # Remember to disable averaging because we are doing broadcast
        handle = byteps_push_pull(p, average=False, name="Parameter."+name)
        synchronize(handle)
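In practice you rarely call this loop directly; a typical, hedged usage is to broadcast the initial model and optimizer state from rank 0 at startup (model and optimizer below are placeholders):

import byteps.torch as bps

bps.init()
# Both helpers are built on push-pull with average=False, as in the loop above.
bps.broadcast_parameters(model.state_dict(), root_rank=0)
bps.broadcast_optimizer_state(optimizer, root_rank=0)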
github bytedance / byteps / example / pytorch / microbenchmark-byteps.py
def benchmark(tensor, average, name):
    if not args.no_wait and bps.rank() == 0:
        time.sleep(0.01)
    start = time.time()
    handle = push_pull_async_inplace(tensor, average, name)
    while True:
        if poll(handle):
            synchronize(handle)
            break
    end = time.time()
    return (end - start) * 1000
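A hedged, self-contained sketch of driving the same timing pattern directly; the payload size, iteration count, and tensor name are illustrative rather than the actual microbenchmark-byteps.py arguments:

import time
import torch
import byteps.torch as bps
from byteps.torch.ops import push_pull_async_inplace, poll, synchronize

bps.init()
data = torch.rand(1024 * 1024)               # illustrative payload
if torch.cuda.is_available():
    data = data.cuda()

for _ in range(10):
    start = time.time()
    # Reuse the same name so BytePS keeps reusing the declared tensor slot.
    handle = push_pull_async_inplace(data, average=True, name="benchmark.tensor")
    while not poll(handle):                  # spin until the operation is ready
        pass
    synchronize(handle)
    if bps.rank() == 0:
        print("push-pull took %.3f ms" % ((time.time() - start) * 1000))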
github bytedance / byteps / byteps / torch / __init__.py
def _push_pull_grad_async(self, p):
        if self._is_tensor_instance:
            name = self._parameter_names.get(p.__hash__())
        else:
            name = self._parameter_names.get(p)
        if self._enable_async:
            # the real handle will be created in step()
            handle, ctx = None, None
        else:
            tensor = p.grad
            tensor_compressed, ctx = self._compression.compress(tensor)
            handle = byteps_push_pull(tensor_compressed, average=True, name="Gradient."+name)
        return handle, ctx
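The returned (handle, ctx) pair is stored per parameter by a backward hook and resolved later when the optimizer synchronizes. A simplified, hedged sketch of that hook (registration on the gradient accumulator and the delayed push-pull bookkeeping are omitted; the body is illustrative):

def _make_hook(self, p):
    # Simplified sketch: fires once p's gradient is ready and kicks off
    # the asynchronous push-pull for that parameter.
    def hook(*ignore):
        handle, ctx = self._push_pull_grad_async(p)
        self._handles[p] = (handle, ctx)     # resolved later in synchronize()
    return hook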
github bytedance / byteps / byteps / torch / __init__.py
# store the weights before update
            for p, _ in self._handles.items():
                old_weight_map[p] = p.data.clone().detach()
            # update
            loss = super(self.__class__, self).step(closure)
            
            for p, (h, _) in self._handles.items():
                # get the diff for each weight (in-place)
                p.data.sub_(old_weight_map.get(p))
                if h is None:
                    # create the handle now
                    if self._is_tensor_instance:
                        name = self._parameter_names.get(p.__hash__())
                    else:
                        name = self._parameter_names.get(p)
                    handle = byteps_push_pull(p, average=False, name="AsyncParam."+name)
                    _, ctx = self._compression.compress(p)
                    self._handles[p] = (handle, ctx)

            self.synchronize()
            return loss
        else:
            self.synchronize()
            return super(self.__class__, self).step(closure)
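All of this plumbing is set up for you when a regular optimizer is wrapped. A typical, hedged training-script setup (model, optimizer, and batch are placeholders, and the compression choice is illustrative):

import torch
import byteps.torch as bps

bps.init()
torch.cuda.set_device(bps.local_rank())

# Wrap a regular torch.optim optimizer; gradients are push-pulled
# asynchronously during backward() and resolved in step()/synchronize().
optimizer = bps.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=bps.Compression.none)

loss = model(batch).sum()                    # placeholder forward pass
loss.backward()
optimizer.step()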