🔥码云GVP开源项目 12k star Uniapp+ElementUI 功能强大 支持多语言、二开方便! 广告
# R 中的 TF 核心 API 我们在第 1 章中了解了 TensorFlow 核心 API。在 R 中,该 API 使用 `tensorflow` R 包实现。 作为一个例子,我们提供了 MLP 模型的演练,用于在以下链接中对来自 MNIST 数据集的手写数字进行分类: [https://tensorflow.rstudio.com/tensorflow/articles/examples/mnist_softmax.html](https://tensorflow.rstudio.com/tensorflow/articles/examples/mnist_softmax.html) 。 您可以按照 Jupyter R 笔记本中的代码`ch-17a_TFCore_in_R`。 1. 首先,加载库: ```r library(tensorflow) ``` 1. 定义超参数: ```r batch_size <- 128 num_classes <- 10 steps <- 1000 ``` 1. 准备数据: ```r datasets <- tf$contrib$learn$datasets mnist <- datasets$mnist$read_data_sets("MNIST-data", one_hot = TRUE) ``` 数据从 TensorFlow 数据集库加载,并已标准化为[0,1]范围。 1. 定义模型: ```r # Create the model x <- tf$placeholder(tf$float32, shape(NULL, 784L)) W <- tf$Variable(tf$zeros(shape(784L, num_classes))) b <- tf$Variable(tf$zeros(shape(num_classes))) y <- tf$nn$softmax(tf$matmul(x, W) + b) # Define loss and optimizer y_ <- tf$placeholder(tf$float32, shape(NULL, num_classes)) cross_entropy <- tf$reduce_mean(-tf$reduce_sum(y_ * log(y), reduction_indices=1L)) train_step <- tf$train$GradientDescentOptimizer(0.5)$minimize(cross_entropy) ``` 1. 训练模型: ```r # Create session and initialize variables sess <- tf$Session() sess$run(tf$global_variables_initializer()) # Train for (i in 1:steps) { batches <- mnist$train$next_batch(batch_size) batch_xs <- batches[[1]] batch_ys <- batches[[2]] sess$run(train_step, feed_dict = dict(x = batch_xs, y_ = batch_ys)) } ``` 1. 评估模型: ```r correct_prediction <- tf$equal(tf$argmax(y, 1L), tf$argmax(y_, 1L)) accuracy <- tf$reduce_mean(tf$cast(correct_prediction, tf$float32)) score <-sess$run(accuracy, feed_dict = dict(x = mnist$test$images, y_ = mnist$test$labels)) cat('Test accuracy:', score, '\n') ``` 输出如下: ```r Test accuracy: 0.9185 ``` 太酷了! 通过以下链接查找 R 中 TF Core 的更多示例:[https://tensorflow.rstudio.com/tensorflow/articles/examples/](https://tensorflow.rstudio.com/tensorflow/articles/examples/) 有关`tensorflow` R 包的更多文档可以在以下链接中找到:[https://tensorflow.rstudio.com/tensorflow/reference/](https://tensorflow.rstudio.com/tensorflow/reference/).