import os

import numpy as np
import paddle

# MyNet is the Transformer model defined in the training script; it is
# assumed to be importable (or defined) in this module.


class IDmanager:
    """Maps between words and vocabulary ids.

    Ids 0-3 are reserved for <unk>, <s>, <e> and <pad>; each line of the
    vocabulary file is expected to hold a word and its frequency count.
    """

    def __init__(self, vocab_dir='datasets/zh.vocab', thre=1):
        self.vocab = ['<unk>', '<s>', '<e>', '<pad>']
        with open(vocab_dir, 'r') as f:
            for item in f.readlines():
                word, num = item.split(' ')
                if int(num.strip()) >= thre:
                    self.vocab.append(word)

    def word2id(self, sentence):
        # Out-of-vocabulary words map to id 0 (<unk>). The ids are kept as
        # ints (not strings) so they can be fed to paddle.to_tensor directly.
        ids = []
        for word in sentence.split(' '):
            try:
                ids.append(self.vocab.index(word))
            except ValueError:
                ids.append(0)
        return ids

    def id2word(self, bos_id_list):
        return [self.vocab[i] for i in bos_id_list]


def PADtoMax(bos_id, max_len=52, with_s=0):
    # Append <e> (id 2), pad with <pad> (id 3) up to max_len, and optionally
    # prepend <s> (id 1); the result is always exactly max_len ids long.
    if with_s:
        bos_id = [1] + bos_id
    if len(bos_id) + 1 < max_len:
        bos_id = bos_id + [2] + (max_len - len(bos_id) - 1) * [3]
    else:
        bos_id = bos_id + [2]
    return bos_id[:max_len]


def predict(direction='zh_fr', n_head=4, num_encoder_layers=1,
            num_decoder_layers=1, emb_dim=512, dim_feedforward=256):
    L1, L2 = direction.split('_')
    # The vocabulary sizes are simply the line counts of the vocab files.
    with open('datasets/' + L1 + '.vocab') as f:
        L1_vocab_size = len(f.readlines())
    with open('datasets/' + L2 + '.vocab') as f:
        L2_vocab_size = len(f.readlines())

    model = MyNet(emb_dim=emb_dim, n_head=n_head,
                  num_encoder_layers=num_encoder_layers,
                  num_decoder_layers=num_decoder_layers,
                  dim_feedforward=dim_feedforward,
                  L1_vocab_size=L1_vocab_size, L2_vocab_size=L2_vocab_size)
    param_dict = paddle.load(direction + '.pdparams')
    model.load_dict(param_dict)
    model.eval()

    with open(direction + '.rst', 'w') as output_file:
        with open('datasets/' + direction + '.test') as input_file:
            for sentence in input_file.readlines():
                # For a Chinese source, segment the sentence with jieba
                # first (requires `import jieba`):
                # if L1 == 'zh':
                #     sentence = ' '.join(jieba.cut(sentence))
                with open('tmp.txt', 'w') as f:
                    f.write(sentence)
                # Apply the learned BPE codes to the source sentence.
                os.system('subword-nmt apply-bpe -c datasets/' + L1
                          + '.bpe < tmp.txt > tmp2.txt')

                manager = IDmanager(vocab_dir='datasets/' + L1 + '.vocab')
                with open('tmp2.txt', 'r') as f:
                    # Strip the trailing newline so the last token is not
                    # mistaken for an out-of-vocabulary word.
                    line = f.readlines()[0].strip()
                bos_id_list = manager.word2id(line)

                os.remove('tmp.txt')
                os.remove('tmp2.txt')
                max_len = 52
                src_id = paddle.to_tensor(PADtoMax(bos_id_list), dtype='int64')
                src_id = src_id.reshape([1] + src_id.shape)  # add batch dim
                # Greedy decoding: feed the target ids predicted so far and
                # take the argmax for position i at each step.
                raw_tgt_id = []
                for i in range(max_len):
                    tgt_id = PADtoMax(raw_tgt_id, max_len=52, with_s=1)
                    tgt_id = paddle.to_tensor(tgt_id, dtype='int64')
                    tgt_id = tgt_id.reshape([1] + tgt_id.shape)
                    pre = model(src_id, tgt_id)
                    new_id = int(np.argmax(pre[0][i].numpy()))
                    raw_tgt_id.append(new_id)
                    if new_id == 2:  # stop once <e> is produced
                        break
                # Switch to the target-language vocabulary for decoding ids.
                manager = IDmanager(vocab_dir='datasets/' + L2 + '.vocab')
                # Drop the special tokens and undo the BPE split markers
                # ('@@ ') before writing the translation out.
                tokens = manager.id2word(raw_tgt_id)
                s1 = ' '.join(t for t in tokens
                              if t not in ('<s>', '<e>', '<pad>', '<unk>'))
                s2 = s1.replace('@@ ', '').replace('@@', '')
                output_file.write(s2 + '\n')
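
# Example usage (a minimal sketch): running the script end to end assumes
# that MyNet is defined or importable, that 'datasets/zh.vocab',
# 'datasets/fr.vocab', 'datasets/zh.bpe' and 'datasets/zh_fr.test' exist,
# and that trained weights were saved to 'zh_fr.pdparams'. Translations
# are written one per line to 'zh_fr.rst'.
if __name__ == '__main__':
    predict(direction='zh_fr')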