Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def main(_):
    """Initialize BytePS and make sure the Keras dataset cache directory exists.

    The single positional argument is ignored.

    Raises:
        OSError: if the cache directory cannot be created for any reason
            other than it already existing as a directory.
    """
    # BytePS: initialize BytePS.
    bps.init()

    # Keras automatically creates a cache directory in ~/.keras/datasets for
    # storing the downloaded MNIST data. This creates a race condition among
    # the workers that share the same filesystem. If the directory already
    # exists by the time this worker gets around to creating it, ignore the
    # resulting exception and continue.
    cache_dir = os.path.join(os.path.expanduser('~'), '.keras', 'datasets')
    if not os.path.exists(cache_dir):
        try:
            # makedirs rather than mkdir: the parent ~/.keras may not exist
            # yet, in which case mkdir would fail with ENOENT.
            os.makedirs(cache_dir)
        except OSError as e:
            # Another worker won the race and created the directory: fine.
            # Anything else (permissions, a file in the way, ...) is fatal.
            if e.errno != errno.EEXIST or not os.path.isdir(cache_dir):
                raise
def __init__(self, average=True):
    """
    Initialize BytePS for this worker process, record its rank information,
    and then run the SingleCostTrainer initializer.

    Args:
        average (bool): if True, gradients are averaged across processes;
            otherwise they are summed.
    """
    import byteps.tensorflow as bps

    # BytePS has the same interface as Horovod, so the module is stored under
    # the ``hvd`` attribute; ``push_pull`` is aliased to ``allreduce``
    # (see https://github.com/bytedance/byteps/issues/8).
    self.hvd = bps
    self.hvd.allreduce = bps.push_pull

    # The process must have been launched as a BytePS worker.
    env = os.environ
    assert env.get("DMLC_ROLE") == "worker"
    assert "DMLC_WORKER_ID" in env
    assert "DMLC_NUM_WORKER" in env

    bps.init()

    rank = bps.rank()
    self.is_chief = rank == 0
    self._local_rank = bps.local_rank()
    self._rank = rank

    self._average = average
    self._compression = None
    self._has_compression = False

    logger.info("[BytePSTrainer] local rank={}".format(self._local_rank))
    SingleCostTrainer.__init__(self)
def on_train_begin(self, logs=None):
    """Broadcast global variables from ``self.root_rank`` when training starts.

    Does nothing when ``bps.size()`` reports a single process (or fewer).

    Args:
        logs: accepted for the callback signature; not used here.
    """
    if bps.size() > 1:
        with tf.device(self.device):
            # Build the broadcast op on the configured device, then execute
            # it in the backend's session.
            broadcast_op = bps.broadcast_global_variables(self.root_rank)
            session = self.backend.get_session()
            session.run(broadcast_op)