Keras阻止意外行为

u3r8eeie  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(136)

我正在使用DNN解决语音去噪问题。我通过下面的函数计算信噪比。

def calculate_snr(clean_signal, recovered_signal):

    clean_power = tf.reduce_sum(tf.square(clean_signal))

    noise_power = tf.reduce_sum(tf.square(clean_signal - recovered_signal))

    snr_db = 10 * tf.math.log(clean_power / noise_power) / tf.math.log(10.0)

    return snr_db

我正在使用keras API来创建这样的模型

model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam(learning_rate=learning_rate),metrics=[calculate_snr])

sound_denoising_history = model.fit(x = X_abs.T, y = S_abs.T,epochs=200,batch_size = 100,validation_data=(X_test_01_abs.T,S_test_01_abs.T))

calculate_snr (X_test_01_abs.T,model.predict(X_test_01_abs.T) : 10.9
While model fit: -4.4 to -3

当我训练它时,我看到用于验证的SNR指标为-7,并在该范围内振荡。而如果我预测xval输入,然后用上面的函数,它给我8.2。这是相同的功能,我已经检查了尺寸多次。我不知道发生了什么事?
编辑:我知道我错过了一个处理信号计算的步骤,但即使指标是独立使用的,它也应该在列车端产生几乎相同的近似值,然后进行推理,然后进行计算

q1qsirdb

q1qsirdb1#

当您在model.compile中使用calculate_snr作为度量时,它会在训练期间分批应用,然后对这些分批值进行平均以计算最终度量。这可能会导致计算的SNR与在代码结束时进行预测后在整个数据集上手动计算时的SNR存在差异。
您可以通过将snr_metric定义为类来克服此限制。

class SNRMetric(keras.metrics.Metric):
    def __init__(self, **kwargs):
        super(SNRMetric, self).__init__(**kwargs)
        self.clean_power = self.add_weight(name="clean_power", initializer="zeros")
        self.noise_power = self.add_weight(name="noise_power", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        clean_power = tf.reduce_sum(tf.square(y_true))
        noise_power = tf.reduce_sum(tf.square(y_true - y_pred))

        self.clean_power.assign_add(clean_power)
        self.noise_power.assign_add(noise_power)
        self.count.assign_add(1)

    def result(self):
        snr_db = 10 * tf.math.log(self.clean_power / self.noise_power) / tf.math.log(10.0)
        return snr_db

然后你可以修改你的代码来进行训练和测试,如下所示:

# TODO define your model

model.compile(
   loss='mean_squared_error', 
   optimizer=keras.optimizers.Adam(learning_rate=learning_rate), 
   metrics=[SNRMetric()] # here the crucial point
)

# Train 
sound_denoising_history = model.fit(x=X_abs.T, y=S_abs.T, epochs=200, batch_size=100, validation_data=(X_test_01_abs.T, S_test_01_abs.T))

# Calculate SNR using the custom metric after training
snr_metric = SNRMetric()
snr_metric.update_state(S_test_01_abs.T, model.predict(X_test_01_abs.T))
snr_value = snr_metric.result()
print(f"SNR after training: {snr_value.numpy()}")

相关问题