"n_jobs": n_jobs,
"alpha": 0.0001,
"fit_intercept": True,
"random_state": 0,
}
_LOGGER.info("Fit SGD Classifier")
clf = SGDClassifier(**parameters).fit(matrix, train_target)
obj = {
"scikit-learn_version": pkg_resources.get_distribution("scikit-learn").version,
"tfidf_model": tfidf_model,
"clf": clf,
"target_names": target_names,
}
if store: # pragma: no cover
path = language.topdir(lang).joinpath("clf.joblib")
_LOGGER.info("Store classifier at {}".format(path))
with path.open("wb") as file:
joblib.dump(obj, file)
return obj
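
# A minimal sketch of how the names referenced above (tfidf_model, matrix,
# train_target, target_names) could be produced from the training examples.
# Assumptions: the {"unit": ..., "text": ...} record layout written by the
# training scripts further down, and a plain TfidfVectorizer; the helper name
# build_training_matrix and the vectorizer settings are illustrative only.
from sklearn.feature_extraction.text import TfidfVectorizer


def build_training_matrix(training_set_):
    # One class per distinct unit name; targets are indices into target_names
    target_names = sorted({example["unit"] for example in training_set_})
    train_target = [target_names.index(example["unit"]) for example in training_set_]
    tfidf_model = TfidfVectorizer(sublinear_tf=True)
    matrix = tfidf_model.fit_transform(example["text"] for example in training_set_)
    return tfidf_model, matrix, train_target, target_names
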
def training_set(lang="en_US"):
    # Concatenate every JSON file in the language's train/ directory
    training_set_ = []
    path = language.topdir(lang).joinpath("train")
    for file in path.iterdir():
        if file.suffix == ".json":
            with file.open("r", encoding="utf-8") as train_file:
                training_set_ += json.load(train_file)
    return training_set_
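
# Illustration of the entries training_set() returns, assuming the
# {"unit": ..., "text": ...} layout written by the scripts below;
# the values here are made up.
example_entries = [
    {"unit": "metre", "text": "kilometre centimetre yard foot"},
    {"unit": "second", "text": "minute hour millisecond"},
]
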
def __init__(self, obj=None, lang="en_US"):
    """
    Load the intent classifier
    """
    self.tfidf_model = None
    self.classifier = None
    self.target_names = None
    if not USE_CLF:
        return
    if not obj:
        # No pre-built object passed in: load the stored one from disk
        path = language.topdir(lang).joinpath("clf.joblib")
        with path.open("rb") as file:
            obj = joblib.load(file)
    cur_scipy_version = pkg_resources.get_distribution("scikit-learn").version
    if cur_scipy_version != obj.get("scikit-learn_version"):  # pragma: no cover
        _LOGGER.warning(
            "The classifier was built using a different scikit-learn "
            "version (={}, !={}). The disambiguation tool could behave "
            "unexpectedly. Consider running classifier.train_classifier()".format(
                obj.get("scikit-learn_version"), cur_scipy_version
            )
        )
    self.tfidf_model = obj["tfidf_model"]
    self.classifier = obj["clf"]
    self.target_names = obj["target_names"]
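
# A minimal sketch of how the three attributes loaded above fit together,
# assuming tfidf_model is the fitted vectorizer used at training time and that
# the class labels seen at fit time index into target_names; classify_unit is
# a hypothetical helper name, not part of the library.
def classify_unit(clf_obj, text):
    # clf_obj is a populated Classifier instance
    vector = clf_obj.tfidf_model.transform([text])
    label = clf_obj.classifier.predict(vector)[0]
    return clf_obj.target_names[label]
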
        # Gather all written forms (surfaces and symbols) of the unit, then use
        # each surface's nearest neighbours in the word-vector space as
        # synthetic training text for that unit
        if isinstance(unit, classes.Unit):
            surfaces.update(unit.surfaces)
            surfaces.update(unit.symbols)
        for surface in surfaces:
            neighbours = v.most_similar(
                v.query(surface), topn=topn, min_similarity=min_similarity
            )
            training_set.append(
                {
                    "unit": name,
                    "text": " ".join(neighbour[0] for neighbour in neighbours),
                }
            )
    print("Done")
    with language.topdir(lang).joinpath("train/similars.json").open(
        "w", encoding="utf-8"
    ) as file:
        json.dump(training_set, file, sort_keys=True, indent=4)
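
# Illustration only (made-up words and scores): the loop above reads
# neighbour[0], so each element of `neighbours` is expected to be a
# (word, similarity) pair.
sample_neighbours = [("kilometre", 0.91), ("centimetre", 0.88), ("yard", 0.74)]
sample_text = " ".join(word for word, _ in sample_neighbours)
# -> "kilometre centimetre yard"
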
        assert len(set(names)) == len(general_entities)
    except AssertionError:  # pragma: no cover
        raise Exception(
            "Entities with same name: %s" % [i for i in names if names.count(i) > 1]
        )
    # Map each entity name to an Entity object built from the general,
    # language-independent definitions
    self.names = dict(
        (
            k["name"],
            c.Entity(name=k["name"], dimensions=k["dimensions"], uri=k["URI"]),
        )
        for k in general_entities
    )
    # Update with language specific URI
    with TOPDIR.joinpath(language.topdir(lang), "entities.json").open(
        "r", encoding="utf-8"
    ) as file:
        lang_entities = json.load(file)
    for ent in lang_entities:
        self.names[ent["name"]].uri = ent["URI"]
    # Generate derived units
    derived_ent = defaultdict(set)
    for entity in self.names.values():
        if not entity.dimensions:
            continue
        perms = self.get_dimension_permutations(entity.dimensions)
        for perm in perms:
            key = get_key_from_dimensions(perm)
            derived_ent[key].add(entity)
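
# A minimal sketch of how derived_ent can be used afterwards, assuming entity
# dimensions are lists of {"base": ..., "power": ...} dicts and that
# get_key_from_dimensions() turns such a list into a hashable key; the
# "speed"-like example is illustrative only.
speed_like = [{"base": "length", "power": 1}, {"base": "time", "power": -1}]
candidates = derived_ent[get_key_from_dimensions(speed_like)]
# -> set of Entity objects whose dimension signature matches length / time
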
    objs = []
    for num, page in enumerate(pages):
        # Each page is a (unit name, Wikipedia article title) pair
        obj = {
            "_id": page[1],
            "url": "https://{}.wikipedia.org/wiki/{}".format(lang[:2], page[1]),
            "clean": page[1].replace("_", " "),
        }
        print("---> Downloading %s (%d of %d)" % (obj["clean"], num + 1, len(pages)))
        obj["text"] = wikipedia.page(obj["clean"], auto_suggest=False).content
        obj["unit"] = page[0]
        objs.append(obj)
    path = language.topdir(lang).joinpath("train/wiki.json")
    if store:
        with path.open("w") as wiki_file:
            json.dump(objs, wiki_file, indent=4, sort_keys=True)
    print("\n---> All done.\n")
    return objs
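
# A hedged usage sketch: the loop above treats each page as a
# (unit name, Wikipedia article title) pair, so a caller of the enclosing
# function (its real name and signature are not shown in this excerpt) might
# pass something like the list below; the titles are illustrative.
#
#     pages = [
#         ("metre", "Metre"),
#         ("second", "Second"),
#     ]
#     objs = download_wiki_pages(pages)  # placeholder name for the function above
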