💎一站式轻松地调用各大LLM模型接口,支持GPT4、智谱、星火、月之暗面及文生图 广告
这里以官网的一个入门案例来进行回顾: ### 1. 导入包 ``` import tensorflow as tf ``` ### 2. 准备和处理数据集 ``` // [MNIST 数据集](http://yann.lecun.com/exdb/mnist/) mnist = tf.keras.datasets.mnist // 加载训练集和测试集 (x_train, y_train), (x_test, y_test) = mnist.load_data() ``` 由于图片的构成为0-255的值,且含有三个维度,所以这里为了方便进行灰度处理,也就是: ``` x_train, x_test = x_train / 255.0, x_test / 255.0 ``` 将其数值整到0-1的范围,其维度不变,这里不妨看下维度信息: ``` x_train.shape y_train.shape ``` 结果为: ![](https://img.kancloud.cn/1d/5b/1d5bbad1462caf2b9604241a33df4260_296x164.png) ### 3. 定义模型 示例中使用的是序列接模型: ``` model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax') ]) ``` 使用Flatten来将三维数据扁平化,变为二维数据,也就是60000x768。Dense为一个全连接层,定义为: ~~~ `output = activation(dot(input, kernel) + bias)` ~~~ 传入的128,也就是units,表示输出的维度。activation为激活函数,当然它还有其余的参数: ~~~ tf.keras.layers.Dense(units, activation=None, use_bias=True, kernel_initializer='glorot_uniform',    bias_initializer='zeros', kernel_regularizer=None,    bias_regularizer=None, activity_regularizer=None, kernel_constraint=None,    bias_constraint=None, **kwargs) ~~~ ![](https://img.kancloud.cn/06/e0/06e0f006b1aa5906fa9ba1231486c52b_1086x552.png) 注意到这里只指定了输出的维度为128,而没有想第一层Flatten中指定input_shape,因为默认程序自己可以知道,所以不必设置。 Dropout定义为: ~~~ tf.keras.layers.Dropout(rate, noise_shape=None, seed=None, **kwargs) ~~~ 主要用来防止过拟合,也就是:Dropout层在训练期间的每一步中将输入单位随机设置为0,未设置为0的输入将按1 /(1\-rate)放大,以使**所有输入的总和不变**。比如在官网中给的例子: ![](https://img.kancloud.cn/f1/c7/f1c7f5dee1c55dcff5d0dca5e18a4d91_680x432.png) ### 4. 模型训练 先指定一下优化器,损失函数和度量值: ``` model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) ``` 不妨看下模型参数量: ``` model.summary() ``` ![](https://img.kancloud.cn/15/7c/157c035fb6dabf3eb142af2f3084f115_905x333.png) 然后进行拟合数据,训练: ``` model.fit(x_train, y_train, epochs=5) ``` 可以看见训练过程中的结果: ![](https://img.kancloud.cn/de/51/de513c8c1480b70a509abb625ccdd9f0_956x273.png) ### 5. 模型评估 ``` model.evaluate(x_test, y_test, verbose=2) ``` ![](https://img.kancloud.cn/5d/f2/5df2f100f8580861db5ca60827d66892_456x120.png)