🔥码云GVP开源项目 12k star Uniapp+ElementUI 功能强大 支持多语言、二开方便! 广告
# `tf.Assert()` 调试 TensorFlow 模型的另一种方法是插入条件断言。`tf.Assert()`函数需要一个条件,如果条件为假,则打印给定张量的列表并抛出 `tf.errors.InvalidArgumentError` 。 1. `tf.Assert()` 函数具有以下特征: ```py tf.Assert( condition, data, summarize=None, name=None ) ``` 1. 断言操作不会像`tf.Print()`函数那样落入图的路径中。为了确保`tf.Assert()`操作得到执行,我们需要将它添加到依赖项中。例如,让我们定义一个断言来检查所有输入是否为正: ```py assert_op = tf.Assert(tf.reduce_all(tf.greater_equal(x,0)),[x]) ``` 1. 在定义模型时将`assert_op`添加到依赖项,如下所示: ```py with tf.control_dependencies([assert_op]): # x is input layer layer = x # add hidden layers for i in range(num_layers): layer = tf.nn.relu(tf.matmul(layer, w[i]) + b[i]) # add output layer layer = tf.matmul(layer, w[num_layers]) + b[num_layers] ``` 1. 为了测试这段代码,我们在第 5 周期之后引入了一个杂质,如下: ```py if epoch > 5: X_batch = np.copy(X_batch) X_batch[0,0]=-2 ``` 1. 代码运行正常五个周期,然后抛出错误: ```py epoch: 0000 loss = 6.975991 epoch: 0001 loss = 2.246228 epoch: 0002 loss = 1.924571 epoch: 0003 loss = 1.745509 epoch: 0004 loss = 1.616791 epoch: 0005 loss = 1.520804 ----------------------------------------------------------------- InvalidArgumentError Traceback (most recent call last) ... InvalidArgumentError: assertion failed: [[-2 0 0]...] ... ``` 除了`tf.Assert()`函数,它可以采用任何有效的条件表达式,TensorFlow 提供以下断言操作,检查特定条件并具有简单的语法: * `assert_equal` * `assert_greater` * `assert_greater_equal` * `assert_integer` * `assert_less` * `assert_less_equal` * `assert_negative` * `assert_none_equal` * `assert_non_negative` * `assert_non_positive` * `assert_positive` * `assert_proper_iterable` * `assert_rank` * `assert_rank_at_least` * `assert_rank_in` * `assert_same_float_dtype` * `assert_scalar` * `assert_type` * `assert_variables_initialized` 作为示例,前面提到的示例断言操作也可以写成如下: ```py assert_op = tf.assert_greater_equal(x,0) ```