Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
opt = hook.wrap_optimizer(opt)
model.compile(
optimizer=opt,
loss="sparse_categorical_crossentropy",
run_eagerly=False,
metrics=["accuracy"],
)
hooks = [hook]
hook.save_scalar("tf_keras_num_steps", steps, sm_metric=True)
hook.save_scalar("tf_keras_before_train", 1, sm_metric=False)
hook.set_mode(ModeKeys.TRAIN)
model.fit(x_train, y_train, epochs=1, steps_per_epoch=steps, callbacks=hooks, verbose=0)
hook.set_mode(ModeKeys.EVAL)
model.evaluate(x_test, y_test, steps=10, callbacks=hooks, verbose=0)
hook.save_scalar("tf_keras_after_train", 1, sm_metric=False)
)
opt = tf.train.RMSPropOptimizer(lr)
opt = hook.wrap_optimizer(opt)
model.compile(
optimizer=opt,
loss="sparse_categorical_crossentropy",
run_eagerly=False,
metrics=["accuracy"],
)
hooks = [hook]
hook.save_scalar("tf_keras_num_steps", steps, sm_metric=True)
hook.save_scalar("tf_keras_before_train", 1, sm_metric=False)
hook.set_mode(ModeKeys.TRAIN)
model.fit(x_train, y_train, epochs=1, steps_per_epoch=steps, callbacks=hooks, verbose=0)
hook.set_mode(ModeKeys.EVAL)
model.evaluate(x_test, y_test, steps=10, callbacks=hooks, verbose=0)
hook.save_scalar("tf_keras_after_train", 1, sm_metric=False)
def get_keras_mode(mode):
    """Translate a ``ModeKeys`` member into its Keras ``ModeKeys`` counterpart.

    TRAIN maps to Keras TRAIN, EVAL maps to Keras TEST and PREDICT maps to
    Keras PREDICT. Any other value yields ``None``.
    """
    # Imported inside the function: this should never be called in TF 1.13,
    # where the Keras mode_keys module is not available.
    from tensorflow.python.keras.utils.mode_keys import ModeKeys as KerasModeKeys

    if mode == ModeKeys.TRAIN:
        return KerasModeKeys.TRAIN
    if mode == ModeKeys.EVAL:
        # Keras names its evaluation mode TEST.
        return KerasModeKeys.TEST
    if mode == ModeKeys.PREDICT:
        return KerasModeKeys.PREDICT
    # No matching mode: same result as falling off the end of the chain.
    return None
def _get_collections_to_save_for_step(self) -> Set["Collection"]:
# Return the set of collections to save for the current step of the current
# mode. The set is stored in self._collections_to_save_for_step and is only
# rebuilt when that attribute is None.
if self._collections_to_save_for_step is None:
self._assert_prep()
self._collections_to_save_for_step = set()
for coll in self._get_all_collections_to_save():
# In EVAL and PREDICT modes the GRADIENTS and OPTIMIZER_VARIABLES
# collections are never added to the set.
if self.mode in [ModeKeys.EVAL, ModeKeys.PREDICT]:
if coll.name in [CollectionKeys.GRADIENTS, CollectionKeys.OPTIMIZER_VARIABLES]:
continue
# Each collection's own save_config decides, from the current mode and that
# mode's step counter, whether it is saved at this step.
if coll.save_config.should_save_step(self.mode, self.mode_steps[self.mode]):
self._collections_to_save_for_step.add(coll)
# Emit a debug log listing the selected collection names when the set is
# non-empty.
# NOTE(review): indentation is lost in this listing, so it cannot be told
# whether this logging block runs only when the set was just rebuilt or on
# every call -- confirm against the original file.
if self._collections_to_save_for_step:
# GLOBAL mode reports self.step; any other mode reports its own step
# counter together with the mode name.
if self.mode == ModeKeys.GLOBAL:
step_str = f"for step {self.step}"
else:
step_str = f"for step {self.mode_steps[self.mode]} of mode {self.mode.name}"
self.logger.debug(
f"Saving the collections "
f"{', '.join([x.name for x in self._collections_to_save_for_step])} {step_str}"
)
return self._collections_to_save_for_step
def _get_mode_modestep(self, step, plugin_data):
    """Extract the mode and mode step recorded in ``plugin_data``.

    Starts from ``ModeKeys.GLOBAL`` and the given ``step``. An entry whose
    ``plugin_name`` equals ``MODE_STEP_PLUGIN_NAME`` replaces the step with
    ``int(content)``; an entry whose ``plugin_name`` equals
    ``MODE_PLUGIN_NAME`` replaces the mode with ``ModeKeys(int(content))``.
    When several entries match, the last one wins.

    Returns:
        A ``(mode, mode_step)`` tuple.
    """
    resolved_mode = ModeKeys.GLOBAL
    resolved_step = step
    for entry in plugin_data:
        # Two independent checks, so one entry may satisfy both.
        if entry.plugin_name == MODE_STEP_PLUGIN_NAME:
            resolved_step = int(entry.content)
        if entry.plugin_name == MODE_PLUGIN_NAME:
            resolved_mode = ModeKeys(int(entry.content))
    return resolved_mode, resolved_step
def _check_mode_step(mode, mode_step, global_step):
    """Fill in defaults for ``mode`` and ``mode_step`` and validate ``mode``.

    A ``mode_step`` of ``None`` is replaced by ``global_step``; a ``mode`` of
    ``None`` is replaced by ``ModeKeys.GLOBAL``.

    Returns:
        A ``(mode, mode_step)`` tuple with the defaults applied.

    Raises:
        ValueError: if the resulting mode is not a ``ModeKeys`` instance.
    """
    resolved_step = global_step if mode_step is None else mode_step
    resolved_mode = ModeKeys.GLOBAL if mode is None else mode

    if isinstance(resolved_mode, ModeKeys):
        return resolved_mode, resolved_step

    # Build the list of valid members for the error message.
    allowed = ", ".join("ModeKeys." + member.name for member in ModeKeys)
    raise ValueError("mode can be one of " + allowed)
def tensor_names(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collection=None) -> list:
self.maybe_refresh()
ts = set()
if step is None and mode == ModeKeys.GLOBAL:
ts.update(self._tensors.keys())
if step is None and mode != ModeKeys.GLOBAL:
ts.update(self.mode_to_tensors_map[mode])
else:
ts.update(self._tensors_for_step(step, mode))
self.logger.debug(
f"getting tensor_names with params: step:{step} mode:{mode} regex:{regex} collection:{collection}"
)
if regex is None and collection is None:
return sorted(list(ts))
elif regex is not None and collection is not None:
raise ValueError("Only one of `regex` or `collection` can be passed to this method")
else:
if collection is not None:
xs = set(self._tensors.keys()).intersection(self._tensors_in_collection(collection))
def _get_exec_function(self, mode):
    """Return the function object associated with ``mode``.

    When ``self.distribution_strategy`` is ``TFDistributionStrategy.NONE`` or
    ``TFDistributionStrategy.HOROVOD``, the function is read directly from
    ``self.model``: ``train_function`` for TRAIN, ``test_function`` for EVAL
    and ``predict_function`` for PREDICT. For any other strategy it is the
    ``_distributed_function`` of the model returned by
    ``self._get_distributed_model(mode)``.

    Raises:
        NotImplementedError: if the strategy is NONE or HOROVOD and ``mode``
            is not TRAIN, EVAL or PREDICT.
    """
    if self.distribution_strategy not in [
        TFDistributionStrategy.NONE,
        TFDistributionStrategy.HOROVOD,
    ]:
        return self._get_distributed_model(mode)._distributed_function

    if mode == ModeKeys.TRAIN:
        return self.model.train_function
    if mode == ModeKeys.EVAL:
        return self.model.test_function
    if mode == ModeKeys.PREDICT:
        return self.model.predict_function
    raise NotImplementedError