Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _learn(self, corpus, epochs):
    """Fit the model on *corpus* and persist it.

    The corpus is turned into input scores and target vectors by
    ``self._corpus_to_vectors``. ``self._model.fit`` is then run with a
    batch size of 32, verbose output and *epochs* epochs. Afterwards the
    model is written to ``self.datadir`` under ``self.MODEL_FILE`` via
    ``annif.util.atomic_save``.
    """
    source_scores, gold_targets = self._corpus_to_vectors(corpus)
    fit_options = dict(batch_size=32, verbose=True, epochs=epochs)
    self._model.fit(source_scores, gold_targets, **fit_options)
    annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)
def subjects(self):
    """Yield a Subject for every non-blank line of the subject file.

    ``self.path`` is read as UTF-8. Each line is stripped and split on the
    first run of whitespace into a URI and a label. The URI is normalized
    with ``annif.util.cleanup_uri``, and the Subject is created with
    ``text=None``.

    Blank or whitespace-only lines (e.g. a trailing empty line at the end
    of the file) are skipped. Previously they caused a ValueError when the
    empty split result was unpacked. A line holding a URI but no label
    still raises ValueError.
    """
    with open(self.path, encoding='utf-8') as subjfile:
        for line in subjfile:
            fields = line.strip().split(None, 1)
            if not fields:
                # blank line: nothing to parse
                continue
            uri, label = fields
            clean_uri = annif.util.cleanup_uri(uri)
            yield Subject(uri=clean_uri, label=label, text=None)
def _parse_tsv_line(self, line):
    """Turn one ``text<TAB>uris`` line into a document (generator).

    Everything before the first tab is the document text. The remainder
    is split on whitespace into subject URIs, and each URI is passed
    through ``annif.util.cleanup_uri``. A line without any tab produces
    no document; a warning is logged instead.
    """
    if '\t' not in line:
        logger.warning('Skipping invalid line (missing tab): "%s"',
                       line.rstrip())
        return
    text, _tab, uri_field = line.partition('\t')
    cleaned = [annif.util.cleanup_uri(raw) for raw in uri_field.split()]
    yield self._create_document(text=text, uris=cleaned, labels=[])
def _create_index(self, veccorpus):
    """Build a gensim sparse similarity index over *veccorpus* and save it.

    *veccorpus* is wrapped with ``Sparse2Corpus(documents_columns=False)``,
    so documents are read as rows. The feature count is the size of
    ``self._vectorizer.vocabulary_``. The resulting index is kept on
    ``self._index`` and written to ``self.datadir`` as ``self.INDEX_FILE``
    via ``annif.util.atomic_save``.
    """
    self.info('creating similarity index')
    corpus_view = Sparse2Corpus(veccorpus, documents_columns=False)
    feature_count = len(self._vectorizer.vocabulary_)
    similarity = gensim.similarities.SparseMatrixSimilarity(
        corpus_view, num_features=feature_count)
    self._index = similarity
    annif.util.atomic_save(similarity, self.datadir, self.INDEX_FILE)
def _create_subject_index(self, subject_corpus):
    """Build a SubjectIndex from *subject_corpus*, keep it, and save it.

    The index is stored on ``self._subjects`` and written to
    ``self.datadir`` under the file name ``'subjects'`` via
    ``annif.util.atomic_save``.
    """
    index = annif.corpus.SubjectIndex(subject_corpus)
    self._subjects = index
    annif.util.atomic_save(
        index,
        self.datadir,
        'subjects')
def _create_model(self):
    """Train an Omikuji model from the training file and save it.

    Hyperparameters start from ``omikuji.Model.default_hyper_param()``.
    Three of them are overridden from ``self.params``:

    - ``cluster_balanced``, converted with ``annif.util.boolean``;
    - ``cluster_k``, converted with ``int``;
    - ``max_depth``, converted with ``int``.

    The model is trained on ``self.TRAIN_FILE`` inside ``self.datadir``.
    Whatever currently occupies ``self.MODEL_FILE`` there is removed, and
    the new model is then saved to that path.
    """
    train_path = os.path.join(self.datadir, self.TRAIN_FILE)
    model_path = os.path.join(self.datadir, self.MODEL_FILE)
    hyper_param = omikuji.Model.default_hyper_param()
    hyper_param.cluster_balanced = annif.util.boolean(
        self.params['cluster_balanced'])
    hyper_param.cluster_k = int(self.params['cluster_k'])
    hyper_param.max_depth = int(self.params['max_depth'])
    # Train before clearing the old model, so a failed training run
    # leaves the previous model on disk.
    self._model = omikuji.Model.train_on_data(train_path, hyper_param)
    # The saved model is treated as a directory (hence rmtree). A stale
    # regular file at the same path would make rmtree raise
    # NotADirectoryError, so it is removed with os.remove instead.
    if os.path.isdir(model_path):
        shutil.rmtree(model_path)
    elif os.path.exists(model_path):
        os.remove(model_path)
    self._model.save(model_path)
def train(self, corpus):
    """Create the model for the configured sources, then train it.

    ``self.params['sources']`` is parsed with ``annif.util.parse_sources``
    and handed to ``self._create_model``. Training then runs on *corpus*
    via ``self._learn`` for ``int(self.params['epochs'])`` epochs.
    """
    self._create_model(annif.util.parse_sources(self.params['sources']))
    epoch_count = int(self.params['epochs'])
    self._learn(corpus, epochs=epoch_count)
def _merge_hits_from_sources(self, hits_from_sources, params):
    """Hook for merging hits from sources. Can be overridden by
    subclasses.

    This default implementation ignores *params*. It combines
    *hits_from_sources* with ``annif.util.merge_hits``, passing
    ``self.project.subjects`` along, and returns the result.
    """
    project_subjects = self.project.subjects
    return annif.util.merge_hits(hits_from_sources, project_subjects)