# Loading and preparing the PTB dataset

First, import the module and load the data as follows:

```py
from datasetslib.ptb import PTBSimple

ptb = PTBSimple()
# downloads data, converts words to ids, converts files to a list of ids
ptb.load_data()
print('Train :', ptb.part['train'][0:5])
print('Test: ', ptb.part['test'][0:5])
print('Valid: ', ptb.part['valid'][0:5])
print('Vocabulary Length = ', ptb.vocab_len)
```

This prints the first five elements of each dataset and the vocabulary length:

```
Train : [9970, 9971, 9972, 9974, 9975]
Test: [102, 14, 24, 32, 752]
Valid: [1132, 93, 358, 5, 329]
Vocabulary Length = 10000
```

We set the context window (`skip_window`) to two words on each side of the target and obtain the CBOW pairs:

```py
ptb.skip_window = 2
ptb.reset_index_in_epoch()
# in CBOW, the input is the context words and the output is the target word
y_batch, x_batch = ptb.next_batch_cbow()
print('The CBOW pairs : context,target')
for i in range(5 * ptb.skip_window):
    print('(', [ptb.id2word[x_i] for x_i in x_batch[i]], ',',
          y_batch[i], ptb.id2word[y_batch[i]], ')')
```

The output is as follows. Each target word is paired with the two words on either side of it, giving four context words per pair:

```
The CBOW pairs : context,target
( ['aer', 'banknote', 'calloway', 'centrust'] , 9972 berlitz )
( ['banknote', 'berlitz', 'centrust', 'cluett'] , 9974 calloway )
( ['berlitz', 'calloway', 'cluett', 'fromstein'] , 9975 centrust )
( ['calloway', 'centrust', 'fromstein', 'gitano'] , 9976 cluett )
( ['centrust', 'cluett', 'gitano', 'guterman'] , 9980 fromstein )
( ['cluett', 'fromstein', 'guterman', 'hydro-quebec'] , 9981 gitano )
( ['fromstein', 'gitano', 'hydro-quebec', 'ipo'] , 9982 guterman )
( ['gitano', 'guterman', 'ipo', 'kia'] , 9983 hydro-quebec )
( ['guterman', 'hydro-quebec', 'kia', 'memotec'] , 9984 ipo )
( ['hydro-quebec', 'ipo', 'memotec', 'mlx'] , 9986 kia )
```

Now let's look at the skip-gram pairs:

```py
ptb.skip_window = 2
ptb.reset_index_in_epoch()
# in skip-gram, the input is the target word and the output is a context word
x_batch, y_batch = ptb.next_batch()
print('The skip-gram pairs : target,context')
for i in range(5 * ptb.skip_window):
    print('(', x_batch[i], ptb.id2word[x_batch[i]], ',',
          y_batch[i], ptb.id2word[y_batch[i]], ')')
```

The output is as follows. Each target word now appears in 2 × `skip_window` = 4 separate pairs, one for each of its context words:

```
The skip-gram pairs : target,context
( 9972 berlitz , 9970 aer )
( 9972 berlitz , 9971 banknote )
( 9972 berlitz , 9974 calloway )
( 9972 berlitz , 9975 centrust )
( 9974 calloway , 9971 banknote )
( 9974 calloway , 9972 berlitz )
( 9974 calloway , 9975 centrust )
( 9974 calloway , 9976 cluett )
( 9975 centrust , 9972 berlitz )
( 9975 centrust , 9974 calloway )
```
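The pairing logic behind both outputs can be sketched in plain Python. This is a minimal sketch, not the datasetslib implementation: `cbow_pairs` and `skipgram_pairs` are hypothetical helper names, and the sketch assumes, as the printed pairs suggest, that the window slides over the sequence one target at a time and lists context words from left to right. The real `next_batch_cbow()` and `next_batch()` may batch, shuffle, or handle sentence boundaries differently. The sample words are the first eight training words shown in the outputs above.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def cbow_pairs(seq: Sequence[T], skip_window: int) -> List[Tuple[List[T], T]]:
    """Hypothetical helper: build (context, target) pairs.

    The context is the skip_window items to the left of the target followed
    by the skip_window items to its right. Targets too close to either end
    of the sequence to have a full window are skipped.
    """
    pairs = []
    for i in range(skip_window, len(seq) - skip_window):
        context = list(seq[i - skip_window:i]) + list(seq[i + 1:i + 1 + skip_window])
        pairs.append((context, seq[i]))
    return pairs


def skipgram_pairs(seq: Sequence[T], skip_window: int) -> List[Tuple[T, T]]:
    """Hypothetical helper: build (target, context) pairs.

    Each CBOW example is split into one pair per context word, so every
    target appears 2 * skip_window times.
    """
    return [(target, word)
            for context, target in cbow_pairs(seq, skip_window)
            for word in context]


# First eight words of the PTB training set, as printed above.
words = ["aer", "banknote", "berlitz", "calloway",
         "centrust", "cluett", "fromstein", "gitano"]

if __name__ == "__main__":
    for context, target in cbow_pairs(words, 2):
        print("(", context, ",", target, ")")
    for target, context in skipgram_pairs(words, 2):
        print("(", target, ",", context, ")")
```

Writing `skipgram_pairs` as a flattening of `cbow_pairs` makes the relationship between the two models explicit: CBOW bundles the 2 × `skip_window` context words into one example, while skip-gram spreads them across separate (target, context) examples.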