Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_loss_function_call():
    """Check that ``triplet_loss`` returns a callable named after its distance.

    Every key of the mapping returned by ``get_loss_functions(margin=...)``
    is passed to ``triplet_loss`` as ``distance`` (with the same margin), and
    the ``__name__`` of the returned loss function must equal that key.
    """
    margin = 2

    loss_dict = get_loss_functions(margin=margin)
    # Iterating the mapping directly yields its keys, so ``.keys()`` is
    # unnecessary (assumes ``get_loss_functions`` returns a dict-like object).
    for loss_name in loss_dict:
        loss_function = triplet_loss(distance=loss_name, margin=margin)
        assert loss_function.__name__ == loss_name
build_annoy_index(X, self.annoy_index_path,
ntrees=self.ntrees,
build_index_on_disk=self.build_index_on_disk,
verbose=self.verbose)
datagen = generator_from_index(X, Y,
index_path=self.annoy_index_path,
k=self.k,
batch_size=self.batch_size,
search_k=self.search_k,
precompute=self.precompute,
verbose=self.verbose)
loss_monitor = 'loss'
try:
triplet_loss_func = triplet_loss(distance=self.distance,
margin=self.margin)
except KeyError:
raise ValueError('Loss function `{}` not implemented.'.format(self.distance))
if self.model_ is None:
if type(self.model_def) is str:
input_size = (X.shape[-1],)
self.model_, anchor_embedding, _, _ = \
triplet_network(base_network(self.model_def, input_size),
embedding_dims=self.embedding_dims)
else:
self.model_, anchor_embedding, _, _ = \
triplet_network(self.model_def,
embedding_dims=self.embedding_dims)
if Y is None:
Parameters
----------
folder_path : string
Path to serialised model files and metadata
Returns
-------
returns an ivis instance
"""
ivis_config = json.load(open(os.path.join(folder_path,
'ivis_params.json'), 'r'))
self.__dict__ = ivis_config
loss_function = triplet_loss(self.distance, self.margin)
self.model_ = load_model(os.path.join(folder_path, 'ivis_model.h5'),
custom_objects={'tf': tf,
loss_function.__name__: loss_function })
self.encoder = self.model_.layers[3]
self.encoder._make_predict_function()
# If a supervised model exists, load it
supervised_path = os.path.join(folder_path, 'supervised_model.h5')
if os.path.exists(supervised_path):
self.supervised_model_ = load_model(supervised_path)
return self