How to use the pycorrector.seq2seq.config module in pycorrector

To help you get started, we’ve selected a few pycorrector examples based on popular ways the library is used in public projects.

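In the snippets below, config is the module pycorrector/seq2seq/config.py, which centralizes the paths and hyperparameters shared by the training and inference scripts. A minimal sketch of the pattern, using only attributes that appear in the examples that follow:

from pycorrector.seq2seq import config

# config is a plain Python module; scripts read its attributes as defaults
print('train data:', config.train_path)
print('vocab file:', config.vocab_path)
print('batch size:', config.batch_size, '| epochs:', config.epochs)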

github shibing624 / pycorrector / pycorrector / seq2seq / train.py (View on GitHub)
def train(train_path=config.train_path,
          output_dir=config.output_dir,
          save_model_dir=config.save_model_dir,
          vocab_path=config.vocab_path,
          val_path=config.val_path,
          vocab_max_size=config.vocab_max_size,
          vocab_min_count=config.vocab_min_count,
          batch_size=config.batch_size,
          epochs=config.epochs,
          learning_rate=0.0001,
          src_emb_dim=128,
          trg_emb_dim=128,
          src_hidden_dim=256,
          trg_hidden_dim=256,
          src_num_layers=1,
          batch_first=True,
          src_bidirection=True,
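Every default in the signature above is pulled from config, so retraining with the project's standard setup is a plain call; a minimal sketch, with the final override purely illustrative:

from pycorrector.seq2seq import config
from pycorrector.seq2seq.train import train

# defaults come from config.py; override only what you need
train(train_path=config.train_path,
      vocab_path=config.vocab_path,
      batch_size=config.batch_size,
      epochs=config.epochs,
      learning_rate=1e-3)  # illustrative override of the 0.0001 default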
github shibing624 / pycorrector / pycorrector / seq2seq / infer.py (View on GitHub)
    inputs = [
        # sample inputs: Chinese sentences with the kinds of errors pycorrector targets
        '没有解决这个问题,',
        '由我起开始做。',
        '由我起开始做',
        '不能人类实现更美好的将来。',
        '这几年前时间,',
        '歌曲使人的感到快乐,',
    ]
    inference = Inference(vocab_path=config.vocab_path,
                          model_path=config.model_path)
    for i in inputs:
        gen = inference.infer(i)
        print('input:', i, 'output:', gen)

    if not os.path.exists(config.predict_out_path):
        # infer test file
        infer_by_file(model_path=config.model_path,
                      output_dir=config.output_dir,
                      test_path=config.test_path,
                      predict_out_path=config.predict_out_path,
                      vocab_path=config.vocab_path)
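Loading the model is the expensive step, so one Inference object should be reused across sentences; a sketch, assuming Inference is importable from the same module as this snippet and that infer returns the corrected string, as the print above suggests:

from pycorrector.seq2seq import config
from pycorrector.seq2seq.infer import Inference

# load vocab and model once, then reuse for every sentence
inference = Inference(vocab_path=config.vocab_path,
                      model_path=config.model_path)

def correct_all(sentences):
    return [inference.infer(s) for s in sentences]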
github shibing624 / pycorrector / pycorrector / seq2seq / infer.py (View on GitHub)
def infer_by_file(model_path,
                  output_dir,
                  test_path,
                  predict_out_path,
                  vocab_path,
                  src_seq_lens=128,
                  trg_seq_lens=128,
                  beam_size=5,
                  batch_size=1,
                  gpu_id=0):
    if gpu_id > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # use the gpu_id argument, not config.gpu_id
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
    else:
        device = torch.device('cpu')
    print('device:', device)
    test_batch = create_batch_file(output_dir, 'test', test_path, batch_size=batch_size)
    print('The number of batches (test): {}'.format(test_batch))

    vocab2id = load_word_dict(vocab_path)
    id2vocab = {v: k for k, v in vocab2id.items()}
    print('The vocabulary file:%s, size: %s' % (vocab_path, len(vocab2id)))

    model = Seq2Seq(
        src_vocab_size=len(vocab2id),
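The device-selection block at the top of infer_by_file is a standard PyTorch pattern: pin one GPU via CUDA_VISIBLE_DEVICES, then fall back to CPU when CUDA is unavailable. Pulled out as a standalone helper (a sketch; the helper name is ours, not pycorrector's):

import os
import torch

def select_device(gpu_id: int) -> torch.device:
    # pin one GPU by id, or fall back to CPU when gpu_id < 0 or CUDA is absent
    if gpu_id > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        if torch.cuda.is_available():
            return torch.device('cuda')
    return torch.device('cpu')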
github shibing624 / pycorrector / pycorrector / seq2seq / train_generator.py (View on GitHub)
        epochs=epochs,
        verbose=1,
        validation_data=([encoder_input_data_val, decoder_input_data_val], decoder_target_data_val),
        callbacks=callbacks_list)
    encoder_model.save(encoder_model_path)
    decoder_model.save(decoder_model_path)
    logger.info("Model save to " + save_model_path)
    logger.info("Training has finished.")

    evaluate(encoder_model, decoder_model, num_encoder_tokens,
             num_decoder_tokens, rnn_hidden_dim, target_token_index,
             max_target_texts_len, encoder_input_data_val, input_texts)


if __name__ == "__main__":
    train(train_path=config.train_path,
          save_model_path=config.save_model_path,
          encoder_model_path=config.encoder_model_path,
          decoder_model_path=config.decoder_model_path,
          save_input_token_path=config.input_vocab_path,
          save_target_token_path=config.target_vocab_path,
          batch_size=config.batch_size,
          epochs=config.epochs,
          rnn_hidden_dim=config.rnn_hidden_dim)
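train_generator.py is the Keras variant of the pipeline: encoder and decoder are saved as separate models so inference can run them step by step. Assuming they were written with Keras's standard Model.save, as the .save calls above suggest, reloading them is one line each (a sketch):

from keras.models import load_model

from pycorrector.seq2seq import config

# reload the separately saved encoder/decoder for step-by-step decoding
encoder_model = load_model(config.encoder_model_path)
decoder_model = load_model(config.decoder_model_path)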
github shibing624 / pycorrector / pycorrector / seq2seq / train.py (View on GitHub)
          attn_method='luong_concat',
          repetition='vanilla',
          network='lstm',
          pointer_net=True,
          attn_decoder=True,
          shared_embedding=True,
          share_emb_weight=True,
          src_seq_lens=128,
          trg_seq_lens=128,
          grad_clip=2.0,
          save_model_batch_num=config.save_model_batch_num,
          gpu_id=config.gpu_id):
    print('Training model...')

    if gpu_id > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # use the gpu_id argument, not config.gpu_id
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
    else:
        device = torch.device('cpu')
    print('device:', device)
    source_texts, target_texts = build_dataset(train_path)
    print('source_texts:', source_texts[0])
    print('target_texts:', target_texts[0])

    vocab2id = read_vocab(source_texts, max_size=vocab_max_size, min_count=vocab_min_count)
    num_encoder_tokens = len(vocab2id)
    max_input_texts_len = max([len(text) for text in source_texts])
    print('num of samples:', len(source_texts))
    print('num of unique input tokens:', num_encoder_tokens)
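read_vocab caps the vocabulary at vocab_max_size entries and drops tokens seen fewer than vocab_min_count times, which is what keeps num_encoder_tokens bounded on large corpora. A sketch of how such a builder typically works (illustrative, not pycorrector's exact implementation):

from collections import Counter

def read_vocab_sketch(texts, max_size=50000, min_count=5):
    # count tokens (characters, for Chinese text) across the corpus
    counter = Counter(token for text in texts for token in text)
    # keep the most frequent tokens, capped at max_size, filtered by min_count
    tokens = [t for t, c in counter.most_common(max_size) if c >= min_count]
    return {token: idx for idx, token in enumerate(tokens)}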