Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from byteps.torch.ops import poll, synchronize
from byteps.torch.ops import init, shutdown
from byteps.torch.ops import size, local_size, rank, local_rank
import threading
import logging
try:
import queue
except ImportError:
import Queue as queue
import time
import math
import torch
import byteps.torch as bps
# Module-level aliases for BytePS PyTorch helpers. `_DistributedOptimizer` is
# the base class subclassed by `_CrossBarrier` in this module. The public
# `bps.DistributedOptimizer` is kept under a private name, presumably so this
# module can expose its own `DistributedOptimizer` wrapper (NOTE(review):
# confirm against the rest of the file). The two broadcast helpers are
# re-exported unchanged. Attributes are read from `bps` in this exact order.
(
    _DistributedOptimizer,
    _bps_DistributedOptimizer,
    broadcast_parameters,
    broadcast_optimizer_state,
) = (
    bps._DistributedOptimizer,
    bps.DistributedOptimizer,
    bps.broadcast_parameters,
    bps.broadcast_optimizer_state,
)
class _CrossBarrier(_DistributedOptimizer):
"""An optimizer that wraps a _DistributedOptimizer, intercepting push-pull operations.
This class enables overlapping gradient push-pull with both backward and forward propagation while maintaining
correct dependencies. It can achieve even higher training performance than the default BytePS with proper system
parameters. To understand the principles behind barrier crossing, check the paper
https://dl.acm.org/citation.cfm?id=3359642
"""
def __init__(self, model, byteps_opt, num_steps=10**6):
"""Construct a new ScheduledOptimizer, which uses byteps optimizer under the hood for averaging gradients
across all workers.
Args: