使用scipy / python计算多元(4d)正态分布的维数顺序?

ehxuflar  于 2022-12-10  发布在  Python
关注(0)|答案(1)|浏览(164)

我想在一个四维网格上计算一个四维高斯/正态分布。(x1,y1,x2,y2)。那么如果我有均值=(x1=1,y1=0,x2=2,y2=0),我希望当我在x1,x2方向上绘制2D等值线图时,在y1=y2=0处,可以看到以(x1=1,x2=2)。然而,我看到的是平均值/中心在(x1=2,x2=0)。
我错过了什么?是我如何定义网格开始?
对于2d正态分布,它的工作方式与预期一致。

import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import multivariate_normal

xy_min = -5
xy_max = 5
npoints = 50
x = np.linspace(xy_min, xy_max, npoints)
dim = 4
xx1,yy1,xx2,yy2 = np.meshgrid(x, x,x,x)
points = np.concatenate([xx1[:, :,:, :,None], yy1[:, :, :,:,None],xx2[:, :, :,:,None],yy2[:, :, :,:,None]], axis=-1)

cov = np.diag(np.ones(4))
mean=np.array([1,0,2,0])
rv = multivariate_normal.pdf(points , mean=mean, cov=cov)

plt.figure()
plt.contourf(x, x, rv[:,0,:,0])

我试着先手动重塑评估点,但它给出了相同的结果。所以我想我在这里在概念上遗漏了一些东西?

points_resh = np.reshape(points,[npoints**4,dim],order='C')
rv_resh = multivariate_normal.pdf(points_resh , mean=mean, cov=cov)
rv2 = np.reshape(rv_resh,[npoints,npoints,npoints,npoints],order='C')

plt.figure()
plt.contourf(x, x, rv2[:,0,:,0])

EDIT:SOLVED对meshgrid使用ij索引,一切都按预期工作。只需记住,绘制等值线时需要转置矩阵。请参见以下示例:

#%% Instead use ij indexing

x = np.linspace(-5, 5, 50)
y = np.linspace(-3, 3, 30)
z= np.linspace(-2, 2, 20)
w= np.linspace(-1, 1, 10)

x4d,y4d,z4d,w4d= np.meshgrid(x, y,z,w,indexing='ij')
points4d= np.concatenate([x4d[:, :,:,:,None], y4d[:,  :,:,:,None], z4d[:, :,:,:,None],w4d[:, :,:,:,None]], axis=-1)


rv4d = multivariate_normal.pdf(points4d  , mean=[1,0.0,2,0.0],  cov=[0.1,0.1,0.1,0.1])

fig,ax=plt.subplots()
ax.contourf(x,z,rv4d[:,0,:,0].T)
ax.set(xlabel='x',ylabel='y')
print(x_mean)
0sgqnhkj

0sgqnhkj1#

对meshgrid使用ij索引,一切都按预期工作。只需要记住,矩阵需要转置以绘制等值线。参见以下示例:

#%% Instead use ij indexing

x = np.linspace(-5, 5, 50)
y = np.linspace(-3, 3, 30)
z= np.linspace(-2, 2, 20)
w= np.linspace(-1, 1, 10)

x4d,y4d,z4d,w4d= np.meshgrid(x, y,z,w,indexing='ij')
points4d= np.concatenate([x4d[:, :,:,:,None], y4d[:,  :,:,:,None], z4d[:, :,:,:,None],w4d[:, :,:,:,None]], axis=-1)


rv4d = multivariate_normal.pdf(points4d  , mean=[1,0.0,2,0.0],  cov=[0.1,0.1,0.1,0.1])

fig,ax=plt.subplots()
ax.contourf(x,z,rv4d[:,0,:,0].T)
ax.set(xlabel='x',ylabel='y')
print(x_mean)

相关问题