pytorch 为什么我在pyro中所有的HMM的发射μ都收敛到同一个数?

iecba09b  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(126)

我试图用pyro语言建立一个高斯HMM模型来推断一个非常简单的马尔可夫序列的参数。但是,我的模型无法推断出参数,并且在训练过程中发生了一些异常。使用相同的序列,hmmlearn成功地推断出了真实的参数。
完整的代码可以在这里访问:
https://colab.research.google.com/drive/1u_4J-dg9Y1CDLwByJ6FL4oMWMFUVnVNd#scrollTo=ZJ4PzdTUBgJi
我的模型是从下面的示例中修改而来的:
https://github.com/pyro-ppl/pyro/blob/dev/examples/hmm.py
我手动创建了一个一阶马尔可夫序列,其中有3个状态,真实均值是[-10,0,10],σ是[1,2,1]。
这是我的模型

def model(observations, num_state):

    assert not torch._C._get_tracing_state()

    with poutine.mask(mask = True):

        p_transition = pyro.sample("p_transition",
                                   dist.Dirichlet((1 / num_state) * torch.ones(num_state, num_state)).to_event(1))

        p_init = pyro.sample("p_init",
                             dist.Dirichlet((1 / num_state) * torch.ones(num_state)))

    p_mu = pyro.param(name = "p_mu",
                      init_tensor = torch.randn(num_state),
                      constraint = constraints.real)

    p_tau = pyro.param(name = "p_tau",
                       init_tensor = torch.ones(num_state),
                       constraint = constraints.positive)

    current_state = pyro.sample("x_0",
                                dist.Categorical(p_init),
                                infer = {"enumerate" : "parallel"})

    for t in pyro.markov(range(1, len(observations))):

        current_state = pyro.sample("x_{}".format(t),
                                    dist.Categorical(Vindex(p_transition)[current_state, :]),
                                    infer = {"enumerate" : "parallel"})

        pyro.sample("y_{}".format(t),
                    dist.Normal(Vindex(p_mu)[current_state], Vindex(p_tau)[current_state]),
                    obs = observations[t])

我的模型被编译为

device = torch.device("cuda:0")
obs = torch.tensor(obs)
obs = obs.to(device)

torch.set_default_tensor_type("torch.cuda.FloatTensor")

guide = AutoDelta(poutine.block(model, expose_fn = lambda msg : msg["name"].startswith("p_")))

Elbo = Trace_ELBO
elbo = Elbo(max_plate_nesting = 1)

optim = Adam({"lr": 0.001})
svi = SVI(model, guide, optim, elbo)

随着训练的进行,ELBO稳定地减小,如图所示。然而,状态的三个平均值收敛。
我试过将我的模型的for循环放入pyro.plate中,并将pyro.param切换到pyro.sample,反之亦然,但对我的模型没有任何效果。

iyfamqjs

iyfamqjs1#

我没有尝试过这个模型,但是我认为应该可以通过以下方式修改模型来解决问题:def模型(观察值,num_state):

assert not torch._C._get_tracing_state()

with poutine.mask(mask = True):

    p_transition = pyro.sample("p_transition",
                            dist.Dirichlet((1 / num_state) * torch.ones(num_state, num_state)).to_event(1))

    p_init = pyro.sample("p_init",
                        dist.Dirichlet((1 / num_state) * torch.ones(num_state)))

p_mu = pyro.sample("p_mu",
                dist.Normal(torch.zeros(num_state), torch.ones(num_state)).to_event(1))

p_tau = pyro.sample("p_tau",
                dist.HalfCauchy(torch.zeros(num_state)).to_event(1))

current_state = pyro.sample("x_0",
                            dist.Categorical(p_init),
                            infer = {"enumerate" : "parallel"})

for t in pyro.markov(range(1, len(observations))):

    current_state = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(p_transition)[current_state, :]),
                                infer = {"enumerate" : "parallel"})

    pyro.sample("y_{}".format(t),
                dist.Normal(Vindex(p_mu)[current_state], Vindex(p_tau)[current_state]),
                obs = observations[t])

然后使用MCMC训练模型:


# MCMC

hmc_kernel = NUTS(model, target_accept_prob = 0.9, max_tree_depth = 7)
mcmc = MCMC(hmc_kernel, num_samples = 1000, warmup_steps = 100, num_chains = 1)
mcmc.run(obs)

然后可以使用以下方法分析结果:mcmc.get_samples()

相关问题