多应用+插件架构,代码干净,二开方便,首家独创一键云编译技术,文档视频完善,免费商用码云13.8K 广告
# 强化学习的朴素神经网络策略 我们按照以下策略进行: 1. 让我们实现一个朴素的基于神经网络的策略。为定义一个新策略使用基于神经网络的预测来返回动作: ```py def policy_naive_nn(nn,obs): return np.argmax(nn.predict(np.array([obs]))) ``` 1. 将`nn`定义为一个简单的单层 MLP 网络,它将具有四个维度的观测值作为输入,并产生两个动作的概率: ```py from keras.models import Sequential from keras.layers import Dense model = Sequential() model.add(Dense(8,input_dim=4, activation='relu')) model.add(Dense(2, activation='softmax')) model.compile(loss='categorical_crossentropy',optimizer='adam') model.summary() ``` 这就是模型的样子: ```py Layer (type) Output Shape Param # ================================================================= dense_16 (Dense) (None, 8) 40 _________________________________________________________________ dense_17 (Dense) (None, 2) 18 ================================================================= Total params: 58 Trainable params: 58 Non-trainable params: 0 ``` 1. 这个模型需要训练。运行 100 集的模拟并仅收集分数大于 100 的那些剧集的训练数据。如果分数小于 100,那么这些状态和动作不值得记录,因为它们不是好戏的例子: ```py # create training data env = gym.make('CartPole-v0') n_obs = 4 n_actions = 2 theta = np.random.rand(4) * 2 - 1 n_episodes = 100 r_max = 0 t_max = 0 x_train, y_train = experiment(env, policy_random, n_episodes, theta,r_max,t_max, return_hist_reward=100 ) y_train = np.eye(n_actions)[y_train] print(x_train.shape,y_train.shape) ``` 我们能够收集 5732 个样本进行训练: ```py (5732, 4) (5732, 2) ``` 1. 接下来,训练模型: ```py model.fit(x_train, y_train, epochs=50, batch_size=10) ``` 1. 训练的模型可用于玩游戏。但是,在我们合并更新训练数据的循环之前,模型不会从游戏的进一步游戏中学习: ```py n_episodes = 200 r_max = 0 t_max = 0 _ = experiment(env, policy_naive_nn, n_episodes, theta=model, r_max=r_max, t_max=t_max, return_hist_reward=0 ) _ = experiment(env, policy_random, n_episodes, theta,r_max,t_max, return_hist_reward=0 ) ``` 我们可以看到,这种朴素的策略几乎以同样的方式执行,虽然比随机策略好一点: ```py Policy:policy_naive_nn, Min reward:37.0, Max reward:200.0, Average reward:71.05 Policy:policy_random, Min reward:36.0, Max reward:200.0, Average reward:68.755 ``` 我们可以通过网络调整和超参数调整,或通过学习更多游戏玩法来进一步改进结果。 但是,有更好的算法,例如 Q-Learning。 在本章的其余部分,我们将重点关注 Q-Learning 算法,因为大多数现实生活中的问题涉及无模型学习。