Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_checker_call(setup):
    """Unknown keys read back as ``None``; saving a checkpoint creates ``checkpoints/``.

    ``setup`` is the shared fixture; this test reads its ``'path'`` and
    ``'cp'`` entries (``'cp'`` is expanded as keyword arguments).
    """
    model_dir = setup['path']
    checker = Checker(model_dir)

    missing_value = checker['no exists']
    assert missing_value is None

    checker.set_checkpoint(**setup['cp'])
    checkpoints_dir = Path(checker.path) / 'checkpoints'
    assert checkpoints_dir.exists()
def test_checker_path(setup):
    """Check how ``Checker`` derives ``path`` and ``model_name``.

    Three cases are covered:

    * no path given: the path is ``<resolved cwd>/<cwd name>`` and the model
      name is the current working directory's name;
    * an explicit path: ``path`` and ``model_name`` come from the fixture;
    * an explicit path with ``increment=True``: ``@1`` is appended to both.

    This test was previously defined three times verbatim. Each later ``def``
    rebinds the module-level name, so only one copy was ever collected. It is
    now defined once.
    """
    # Hoisted: the original recomputed this expression for every assertion.
    cwd_name = Path(os.getcwd()).name

    # Default: no path argument.
    checker = Checker()
    assert checker.path == str(Path('.').resolve() / cwd_name)
    assert checker.model_name == cwd_name

    # Explicit path.
    checker = Checker(setup['path'])
    assert checker.path == str(Path(setup['path']))
    assert checker.model_name == setup['name']

    # Explicit path with auto-increment: the first increment is ``@1``.
    checker = Checker(setup['path'], increment=True)
    assert checker.path == str(Path(setup['path'] + '@1'))
    assert checker.model_name == setup['name'] + '@1'
def test_checker_model_1(setup):
    """Exercise the ``model`` / ``init_state`` attributes of ``Checker``.

    A fresh checker reports ``None`` for ``model``, ``describe`` and
    ``model_structure``. Assigning ``None`` to ``model`` or ``init_state`` must
    raise ``TypeError``. After a real model is assigned, the structure string
    mentions ``LinearLayer`` and ``init_state`` is not ``None``.
    """
    checker = Checker(setup['path'])

    # Nothing has been attached yet (checked in the same order as before).
    for attr_name in ('model', 'describe', 'model_structure'):
        assert getattr(checker, attr_name) is None

    # ``None`` is not an acceptable model.
    with pytest.raises(TypeError):
        checker.model = None

    model = setup['model']
    checker.model = model
    assert isinstance(checker.model, LinearLayer)

    # The structure is exposed as text and names the layer class.
    assert isinstance(checker.model_structure, str)
    assert 'LinearLayer' in checker.model_structure
    assert str(checker.model) == str(model)

    # The test checks ``init_state`` only after a model has been assigned.
    assert checker.init_state is not None
    with pytest.raises(TypeError):
        checker.init_state = None
def test_checker_from_cp(setup):
    """Data written through one ``Checker`` is readable after ``Checker.load``.

    A keyword value (``a=1``) and a named checkpoint (``test_cp``) are saved,
    then the model dir is reloaded and both are read back.
    """
    writer = Checker(setup['path'])
    model_dir = writer.path
    writer.set_checkpoint(test_cp=setup['cp'])
    writer(a=1)

    reader = Checker.load(model_dir)
    assert reader['a'] == 1

    checkpoint = reader.checkpoints['test_cp']
    assert 'b' in checkpoint
    # NOTE(review): if ``model_state`` is a dict of multi-element tensors,
    # dict ``==`` would have to call ``bool()`` on a tensor comparison and
    # raise. Confirm what setup['cp']['model_state'] actually contains.
    assert checkpoint['model_state'] == setup['cp']['model_state']
If ``True``, the directory name of ``path`` will be decorated with an auto-increment number,
e.g. use ``model_dir@1`` for ``model_dir``.
sync_training_step
If ``True``, will save ``trainer.training_info`` at each iteration.
Default is ``False``, only save ``trainer.training_info`` at each epoch.
describe
Any other information to describe this model.
This information will be saved under the model dir with the name ``describe.pkl.z``.
"""
self._model_class: Callable = model_class
self._model_params: Union[list, dict] = model_params
self.sync_training_step = sync_training_step
self._increment = increment
self._describe = describe
self._describe_ = None
self._checker: Union[Checker, None] = None
self._tmp_args: list = []
self._tmp_kwargs: dict = {}
self._epoch_count = 0
self.path = path
clip_grad
Clip gradients before each optimization step.
epochs
Number of training epochs.
cuda
Set training device(s).
non_blocking
When non_blocking is ``True``,
it tries to convert/move asynchronously with respect to the host if possible.
Returns
-------
"""
if isinstance(checker, (str, Path)):
checker = Checker(checker)
else:
checker = checker
if len(checker.files) == 0:
raise RuntimeError(f'{checker.path} is not a model dir')
tmp = cls(model=checker.model,
cuda=cuda,
loss_func=loss_func,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
clip_grad=clip_grad,
epochs=epochs,
non_blocking=non_blocking)
tmp._training_info = checker.training_info.to_dict(orient='records')
if Path(checker.path + '/checkpoints').is_dir():
for k in checker.checkpoints.files:
def before_proc(self, trainer: Trainer) -> None:
self._checker = Checker(self._path, increment=self._increment)
if self._model_class is not None:
self._checker(model_class=self._model_class)
if self._model_params is not None:
self._checker(model_params=self._model_params)
self._checker.model = trainer.model
self._describe_ = dict(
python=py_ver,
system=sys_ver(),
numpy=np.__version__,
torch=torch.__version__,
xenonpy=__version__,
device=str(trainer.device),
start=datetime.now().strftime('%Y/%m/%d %H:%M:%S'),
finish='N/A',
time_elapsed='N/A',
**self._describe,