How do I learn the parameters of a beta distribution in TensorFlow Probability?

Posted by hpxqektj on 2023-05-01 in Other

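I am trying to use TensorFlow Probability to learn the alpha and beta parameters of a beta distribution. For some reason I can't get it to work: the loss is NaN at every step. Here is my code: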

from scipy.stats import beta
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

beta_sample_data = beta1 = beta.rvs(5,5,size=1000)
beta_train = tfd.Beta(concentration1=tf.Variable(1.,name='alpha'),concentration0=tf.Variable(1.,name='beta'),name='beta_train')

def nll(x_train,distribution):
    return -tf.reduce_mean(distribution.log_prob(x_train))

# Define a function to compute the loss and gradients
@tf.function
def get_loss_and_grads(x_train,distribution):
    with tf.GradientTape() as tape:
        tape.watch(distribution.trainable_variables)
        loss = nll(x_train, distribution)
        grads = tape.gradient(loss,distribution.trainable_variables)
        
    return loss,grads

def beta_dist_optimisation(data, distribution):

    # Keep results for plotting
    train_loss_results = []
    train_rate_results = []
    
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.005)

    num_steps = 10

    for i in range(num_steps):
        loss,grads = get_loss_and_grads(data,distribution)
        print(loss,grads)
        optimizer.apply_gradients(zip(grads,distribution.trainable_variables))
        alpha_value = distribution.concentration1.value()
        beta_value = distribution.concentration0.value()
        train_loss_results.append(loss)
        train_rate_results.append((alpha_value,beta_value))

        print("Step {:03d}: Loss: {:.3f}: Alpha: {:.3f} Beta: {:.3f}".format(i,loss,alpha_value,beta_value))
        
    return train_loss_results, train_rate_results

sample_data = tf.cast(beta_sample_data, tf.float32)

train_loss_results, train_rate_results = beta_dist_optimisation(sample_data,beta_train)

I am trying to use maximum likelihood to recover the alpha and beta parameters, which are both 5 for this sample.

w6lpcovy (answer #1)

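You need to constrain the variables to be positive before using them. Otherwise a gradient step can push them to zero or below, where the Beta distribution is not defined, and you get NaNs. You can call softplus on them before passing them in. You could also look at tfp.util.TransformedVariable; the documentation should have some examples.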
