我试图用pyro语言建立一个高斯HMM模型来推断一个非常简单的马尔可夫序列的参数。但是,我的模型无法推断出参数,并且在训练过程中发生了一些异常。使用相同的序列,hmmlearn成功地推断出了真实的参数。
完整的代码可以在这里访问:
https://colab.research.google.com/drive/1u_4J-dg9Y1CDLwByJ6FL4oMWMFUVnVNd#scrollTo=ZJ4PzdTUBgJi
我的模型是从下面的示例中修改而来的:
https://github.com/pyro-ppl/pyro/blob/dev/examples/hmm.py
我手动创建了一个一阶马尔可夫序列,其中有3个状态,真实均值是[-10,0,10],σ是[1,2,1]。
这是我的模型
def model(observations, num_state):
assert not torch._C._get_tracing_state()
with poutine.mask(mask = True):
p_transition = pyro.sample("p_transition",
dist.Dirichlet((1 / num_state) * torch.ones(num_state, num_state)).to_event(1))
p_init = pyro.sample("p_init",
dist.Dirichlet((1 / num_state) * torch.ones(num_state)))
p_mu = pyro.param(name = "p_mu",
init_tensor = torch.randn(num_state),
constraint = constraints.real)
p_tau = pyro.param(name = "p_tau",
init_tensor = torch.ones(num_state),
constraint = constraints.positive)
current_state = pyro.sample("x_0",
dist.Categorical(p_init),
infer = {"enumerate" : "parallel"})
for t in pyro.markov(range(1, len(observations))):
current_state = pyro.sample("x_{}".format(t),
dist.Categorical(Vindex(p_transition)[current_state, :]),
infer = {"enumerate" : "parallel"})
pyro.sample("y_{}".format(t),
dist.Normal(Vindex(p_mu)[current_state], Vindex(p_tau)[current_state]),
obs = observations[t])
我的模型被编译为
device = torch.device("cuda:0")
obs = torch.tensor(obs)
obs = obs.to(device)
torch.set_default_tensor_type("torch.cuda.FloatTensor")
guide = AutoDelta(poutine.block(model, expose_fn = lambda msg : msg["name"].startswith("p_")))
Elbo = Trace_ELBO
elbo = Elbo(max_plate_nesting = 1)
optim = Adam({"lr": 0.001})
svi = SVI(model, guide, optim, elbo)
随着训练的进行,ELBO稳定地减小,如图所示。然而,状态的三个平均值收敛。
我试过将我的模型的for循环放入pyro.plate中,并将pyro.param切换到pyro.sample,反之亦然,但对我的模型没有任何效果。
1条答案
按热度按时间iyfamqjs1#
我没有尝试过这个模型,但是我认为应该可以通过以下方式修改模型来解决问题:def模型(观察值,num_state):
然后使用MCMC训练模型:
然后可以使用以下方法分析结果:
mcmc.get_samples()