def initialize_corrector(self):
    t1 = time.time()
    # chinese common char dict
    self.cn_char_set = self.load_char_set(self.common_char_path)
    # same pinyin
    self.same_pinyin = self.load_same_pinyin(self.same_pinyin_text_path)
    t2 = time.time()
    logger.debug("Loaded same pinyin file: %s, spend: %.3f s." % (
        self.same_pinyin_text_path, t2 - t1))
    # same stroke
    self.same_stroke = self.load_same_stroke(self.same_stroke_text_path)
    logger.debug("Loaded same stroke file: %s, spend: %.3f s." % (
        self.same_stroke_text_path, time.time() - t2))
    self.initialized_corrector = True
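For context, the two loaders used above can be sketched roughly as follows. This is a minimal sketch, not pycorrector's actual implementation: the file layouts (a plain list of common characters, and tab-separated "char<TAB>same-pinyin candidates" lines) are assumptions for illustration only.

def load_char_set(path):
    # assumed layout: a plain text file containing the common Chinese characters
    with open(path, encoding="utf-8") as f:
        return {ch for ch in f.read() if not ch.isspace()}

def load_same_pinyin(path, sep="\t"):
    # assumed layout: one line per char, "char<TAB>characters sharing its pinyin"
    same_pinyin = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(sep)
            if len(parts) >= 2 and parts[0]:
                same_pinyin[parts[0]] = set(parts[1])
    return same_pinyin

# load_same_stroke would follow the same pattern for visually similar characters.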
def detect(self, sentence):
    """
    Detect possible character errors in a sentence (uses self.threshold).
    :param sentence: sentence text
    :return: list[list], [error_word, begin_pos, end_pos, error_type]
    """
    maybe_errors = []
    for prob, f in self.predict_token_prob(sentence):
        logger.debug('prob:%s, token:%s, idx:%s' % (prob, f.token, f.id))
        if prob < self.threshold:
            maybe_errors.append([f.token, f.id, f.id + 1, ErrorType.char])
    return maybe_errors
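To show the detection idea end to end, here is a self-contained sketch, not pycorrector's implementation: it masks each character in turn, asks a Chinese masked language model for the probability of the original character at that position, and flags positions that fall below a threshold. The model name, threshold value, and example sentence are illustrative assumptions.

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForMaskedLM.from_pretrained("bert-base-chinese")
model.eval()

def token_probs(sentence):
    # mask each character and score how likely the original character is at
    # that position (one forward pass per character, so this is illustrative
    # rather than efficient)
    for i, ch in enumerate(sentence):
        masked = sentence[:i] + tokenizer.mask_token + sentence[i + 1:]
        inputs = tokenizer(masked, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0].item()
        probs = torch.softmax(logits[0, mask_pos], dim=-1)
        yield ch, probs[tokenizer.convert_tokens_to_ids(ch)].item()

def detect(sentence, threshold=0.001):
    # [error_char, begin_pos, end_pos, "char"] for positions below the threshold
    return [[ch, i, i + 1, "char"]
            for i, (ch, p) in enumerate(token_probs(sentence))
            if p < threshold]

print(detect("少先队员因该为老人让座"))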
def predict_mask_token(self, sentence, error_begin_idx, error_end_idx):
    # Note: the original snippet started mid-call; the function head, the
    # default value of corrected_item, and the feature-builder helper name
    # below are reconstructed assumptions.
    corrected_item = sentence[error_begin_idx:error_end_idx]
    eval_features = self._convert_sentence_to_correct_features(
        sentence=sentence,
        error_begin_idx=error_begin_idx,
        error_end_idx=error_end_idx
    )
    for f in eval_features:
        input_ids = torch.tensor([f.input_ids])
        segment_ids = torch.tensor([f.segment_ids])
        outputs = self.model(input_ids, segment_ids)
        predictions = outputs[0]
        # positions that were replaced with [MASK] in the input
        masked_ids = f.mask_ids
        if masked_ids:
            for idx, i in enumerate(masked_ids):
                predicted_index = torch.argmax(predictions[0, i]).item()
                predicted_token = self.bert_tokenizer.convert_ids_to_tokens([predicted_index])[0]
                logger.debug('original text is: %s' % f.input_tokens)
                logger.debug('Mask predict is: %s' % predicted_token)
                corrected_item = predicted_token
    return corrected_item
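As a quick usage illustration of the masked-prediction step above, the same idea can be tried with the transformers fill-mask pipeline. This is only a sketch: the model name, the example sentence, and the choice of which character to mask are assumptions for demonstration.

from transformers import pipeline

# Rank replacement candidates for one suspicious character by masking it.
fill_mask = pipeline("fill-mask", model="bert-base-chinese")
sentence = "少先队员因该为老人让座"
masked = sentence.replace("因", "[MASK]", 1)  # assume position 4 was flagged by detect()
for candidate in fill_mask(masked)[:3]:
    print(candidate["token_str"], round(candidate["score"], 4))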
def initialize_bert_detector(self):
    t1 = time.time()
    self.bert_tokenizer = BertTokenizer(vocab_file=self.bert_model_vocab)
    self.MASK_TOKEN = "[MASK]"
    self.MASK_ID = self.bert_tokenizer.convert_tokens_to_ids([self.MASK_TOKEN])[0]
    # Prepare model
    self.model = BertForMaskedLM.from_pretrained(self.bert_model_dir)
    logger.debug("Loaded model ok, path: %s, spend: %.3f s." % (self.bert_model_dir, time.time() - t1))
    self.initialized_bert_detector = True
def initialize_detector(self):
    # Note: the function head, timer, and kenlm import guard are reconstructed;
    # the original snippet started at the raise statement below.
    t1 = time.time()
    try:
        import kenlm
    except ImportError:
        raise ImportError('pycorrector dependencies are not fully installed, '
                          'they are required for the statistical language model. '
                          'Please use "pip install kenlm" to install it. '
                          'On Windows, please install kenlm under Cygwin.')
    if not os.path.exists(self.language_model_path):
        filename = self.pre_trained_language_models.get(self.language_model_path,
                                                        'zh_giga.no_cna_cmn.prune01244.klm')
        url = self.pre_trained_language_models.get(filename)
        get_file(
            filename, url, extract=True,
            cache_dir=config.USER_DIR,
            cache_subdir=config.USER_DATA_DIR,
            verbose=1
        )
    self.lm = kenlm.Model(self.language_model_path)
    logger.debug('Loaded language model: %s, spend: %s s' % (self.language_model_path, str(time.time() - t1)))
    # word -> frequency dict
    t2 = time.time()
    self.word_freq = self.load_word_freq_dict(self.word_freq_path)
    t3 = time.time()
    logger.debug('Loaded word freq file: %s, size: %d, spend: %s s' %
                 (self.word_freq_path, len(self.word_freq), str(t3 - t2)))
    # custom confusion set
    self.custom_confusion = self._get_custom_confusion_dict(self.custom_confusion_path)
    t4 = time.time()
    logger.debug('Loaded confusion file: %s, size: %d, spend: %s s' %
                 (self.custom_confusion_path, len(self.custom_confusion), str(t4 - t3)))
    # custom word-segmentation dict
    self.custom_word_freq = self.load_word_freq_dict(self.custom_word_freq_path)
    self.person_names = self.load_word_freq_dict(self.person_name_path)
    self.place_names = self.load_word_freq_dict(self.place_name_path)
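To illustrate how the loaded kenlm model is typically used, here is a small sketch that ranks candidate sentences by their n-gram log-probability. The model path and candidate sentences are assumptions; kenlm scores whitespace-separated tokens and returns log10 probabilities.

import kenlm

# Path assumed to point at the model downloaded by get_file() above.
lm = kenlm.Model("zh_giga.no_cna_cmn.prune01244.klm")

def char_lm_score(text):
    # kenlm expects whitespace-separated tokens, so split the sentence into characters
    return lm.score(" ".join(text), bos=True, eos=True)

candidates = ["少先队员因该为老人让座", "少先队员应该为老人让座"]
for c in sorted(candidates, key=char_lm_score, reverse=True):
    print(round(char_lm_score(c), 2), c)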