keras tf.math.pow(x,0.5)的激活函数导致NaN损失

dw1jzc5e  于 2023-01-05  发布在  其他
关注(0)|答案(1)|浏览(213)

我尝试为Keras序列模型(特别是MNIST数据集)使用自定义平方根激活函数。当我使用tf.math.sqrt(x)时,训练进行得很顺利,模型也相当准确。但是,当我尝试使用tf.math.pow(x, 0.5)时,模型训练失败,损失变为NaN。
我真的不知道为什么会发生这种情况,因为我会认为这两种选择是相同的。
平方根函数

def tfsqrt(x):
    cond = tf.greater_equal(x, 0)
    return tf.where(cond, tf.math.sqrt(x), -tf.math.sqrt(-x))

幂函数

def pwsqrt(x):
  cond = tf.greater_equal(x, 0)
  return tf.where(cond, tf.math.pow(x, 0.5), -tf.math.pow(-x, 0.5))

如果有人能解释这种意外的行为,那将是非常感激的。谢谢!

ufj5ltwl

ufj5ltwl1#

功能正确:x=tf.变量([-2.0,-3.0,0.0,1.0,2.0])

y=tfsqrt(x)
y
y=pwsqrt(x)
y

这些函数在google colab中工作得很好,也许数据中有一些nan值。
可能是模型丢失或度量有问题。

相关问题