💎一站式轻松地调用各大LLM模型接口,支持GPT4、智谱、星火、月之暗面及文生图 广告
# TensorFlow 中的 LSTM 文本生成 您可以在 Jupyter 笔记本`ch-08b_RNN_Text_TensorFlow`中按照本节的代码进行操作。 我们使用以下步骤在 TensorFlow 中实现文本生成 LSTM: 1. 让我们为`x`和`y`定义参数和占位符: ```py batch_size = 128 n_x = 5 # number of input words n_y = 1 # number of output words n_x_vars = 1 # in case of our text, there is only 1 variable at each timestep n_y_vars = text8.vocab_len state_size = 128 learning_rate = 0.001 x_p = tf.placeholder(tf.float32, [None, n_x, n_x_vars], name='x_p') y_p = tf.placeholder(tf.float32, [None, n_y_vars], name='y_p') ``` 对于输入,我们使用单词的整数表示,因此`n_x_vars`是 1.对于输出,我们使用单热编码值,因此输出的数量等于词汇长度。 1. 接下来,创建一个长度为`n_x`的张量列表: ```py x_in = tf.unstack(x_p,axis=1,name='x_in') ``` 1. 接下来,从输入和单元创建 LSTM 单元和静态 RNN 网络: ```py cell = tf.nn.rnn_cell.LSTMCell(state_size) rnn_outputs, final_states = tf.nn.static_rnn(cell, x_in,dtype=tf.float32) ``` 1. 接下来,我们定义最终层的权重,偏差和公式。最后一层只需要为第六个单词选择输出,因此我们应用以下公式来仅获取最后一个输出: ```py # output node parameters w = tf.get_variable('w', [state_size, n_y_vars], initializer= tf.random_normal_initializer) b = tf.get_variable('b', [n_y_vars], initializer=tf.constant_initializer(0.0)) y_out = tf.matmul(rnn_outputs[-1], w) + b ``` 1. 接下来,创建一个损失函数和优化器: ```py loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( logits=y_out, labels=y_p)) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) .minimize(loss) ``` 1. 创建我们可以在会话块中运行的准确率函数,以检查训练模式的准确性: ```py n_correct_pred = tf.equal(tf.argmax(y_out,1), tf.argmax(y_p,1)) accuracy = tf.reduce_mean(tf.cast(n_correct_pred, tf.float32)) ``` 1. 最后,我们训练模型 1000 个周期,并每 100 个周期打印结果。此外,每 100 个周期,我们从上面描述的种子字符串打印生成的文本。 LSTM 和 RNN 网络需要对大量数据集进行大量周期的训练,以获得更好的结果。 请尝试加载完整的数据集并在计算机上运行 50,000或80,000 个周期,并使用其他超参数来改善结果。 ```py n_epochs = 1000 learning_rate = 0.001 text8.reset_index_in_epoch() n_batches = text8.n_batches_seq(batch_size=batch_size,n_tx=n_x,n_ty=n_y) n_epochs_display = 100 with tf.Session() as tfs: tf.global_variables_initializer().run() for epoch in range(n_epochs): epoch_loss = 0 epoch_accuracy = 0 for step in range(n_batches): x_batch, y_batch = text8.next_batch_seq(batch_size=batch_size, n_tx=n_x,n_ty=n_y) y_batch = dsu.to2d(y_batch,unit_axis=1) y_onehot = np.zeros(shape=[batch_size,text8.vocab_len], dtype=np.float32) for i in range(batch_size): y_onehot[i,y_batch[i]]=1 feed_dict = {x_p: x_batch.reshape(-1, n_x, n_x_vars), y_p: y_onehot} _, batch_accuracy, batch_loss = tfs.run([optimizer,accuracy, loss],feed_dict=feed_dict) epoch_loss += batch_loss epoch_accuracy += batch_accuracy if (epoch+1) % (n_epochs_display) == 0: epoch_loss = epoch_loss / n_batches epoch_accuracy = epoch_accuracy / n_batches print('\nEpoch {0:}, Average loss:{1:}, Average accuracy:{2:}'. format(epoch,epoch_loss,epoch_accuracy )) y_pred_r5 = np.empty([10]) y_pred_f5 = np.empty([10]) x_test_r5 = random5.copy() x_test_f5 = first5.copy() # let us generate text of 10 words after feeding 5 words for i in range(10): for x,y in zip([x_test_r5,x_test_f5], [y_pred_r5,y_pred_f5]): x_input = x.copy() feed_dict = {x_p: x_input.reshape(-1, n_x, n_x_vars)} y_pred = tfs.run(y_out, feed_dict=feed_dict) y_pred_id = int(tf.argmax(y_pred, 1).eval()) y[i]=y_pred_id x[:-1] = x[1:] x[-1] = y_pred_id print(' Random 5 prediction:',id2string(y_pred_r5)) print(' First 5 prediction:',id2string(y_pred_f5)) ``` 结果如下: ```py Epoch 99, Average loss:1.3972469369570415, Average accuracy:0.8489583333333334 Random 5 prediction: labor warren together strongly profits strongly supported supported co without First 5 prediction: market own self free together strongly profits strongly supported supported Epoch 199, Average loss:0.7894854595263799, Average accuracy:0.9186197916666666 Random 5 prediction: syndicalists spanish class movements also also anarcho anarcho anarchist was First 5 prediction: five civil association class movements also anarcho anarcho anarcho anarcho Epoch 299, Average loss:1.360412875811259, Average accuracy:0.865234375 Random 5 prediction: anarchistic beginnings influenced true tolstoy tolstoy tolstoy tolstoy tolstoy tolstoy First 5 prediction: early civil movement be for was two most most most Epoch 399, Average loss:1.1692512730757396, Average accuracy:0.8645833333333334 Random 5 prediction: including war than than revolutionary than than war than than First 5 prediction: left including including including other other other other other other Epoch 499, Average loss:0.5921860883633295, Average accuracy:0.923828125 Random 5 prediction: ever edited interested interested variety variety variety variety variety variety First 5 prediction: english market herbert strongly price interested variety variety variety variety Epoch 599, Average loss:0.8356450994809469, Average accuracy:0.8958333333333334 Random 5 prediction: management allow trabajo trabajo national national mag mag ricardo ricardo First 5 prediction: spain prior am working n war war war self self Epoch 699, Average loss:0.7057955612738928, Average accuracy:0.8971354166666666 Random 5 prediction: teachings can directive tend resist obey christianity author christianity christianity First 5 prediction: early early called social called social social social social social Epoch 799, Average loss:0.772875706354777, Average accuracy:0.90234375 Random 5 prediction: associated war than revolutionary revolutionary revolutionary than than revolutionary revolutionary First 5 prediction: political been hierarchy war than see anti anti anti anti Epoch 899, Average loss:0.43675946692625683, Average accuracy:0.9375 Random 5 prediction: individualist which which individualist warren warren tucker benjamin how tucker First 5 prediction: four at warren individualist warren published considered considered considered considered Epoch 999, Average loss:0.23202441136042276, Average accuracy:0.9602864583333334 Random 5 prediction: allow allow trabajo you you you you you you you First 5 prediction: labour spanish they they they movement movement anarcho anarcho two ``` 生成的文本中的重复单词是常见的,并且应该更好地训练模型。虽然模型的准确性提高到 96%,但仍然不足以生成清晰的文本。尝试增加 LSTM 单元/隐藏层的数量,同时在较大的数据集上运行模型以获取大量周期。 现在让我们在 Keras 建立相同的模型: