我试图生成16个子图,我的最终目标是有一个8 x 2大小的最终图,我的代码看起来像这样:
def visualize_t2t(token_dict, scores):
fig = plt.figure(figsize=(50, 50))
for idx, scores in enumerate(scores):
scores_np = np.array(scores)
ax = fig.add_subplot(12, 12, idx+1)
# append the attention weights
im = ax.imshow(scores, cmap='viridis')
fontdict = {'fontsize': 3}
ax.set_xticks(range(len(all_tokens)))
ax.set_yticks(range(len(all_tokens)))
ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
ax.set_yticklabels(all_tokens, fontdict=fontdict)
ax.set_xlabel('{} {}'.format('label_name', idx+1))
fig.colorbar(im, fraction=0.046, pad=0.04)
plt.tight_layout()
name_f = str(uuid.uuid4())
plt.savefig(f'{name_f}.pdf',
bbox_inches='tight',
dpi=350)
输入数据
all_tokens = ['[CLS]',
'what',
'type',
'of',
'heart',
'issue',
'does',
'the',
'Person',
'have',
'[CLS]']
dummy_input = np.random.uniform(-1, 1, [16, len(all_tokens), len(all_tokens)])
visualize_t2t(all_tokens, dummy_input)
但结果看起来是这样的:
如何设置行和列,使一行中有8个子图,而另一行中有8个子图?
1条答案
按热度按时间z6psavjg1#
只需将
ax = fig.add_subplot(12, 12, idx+1)
替换为ax = fig.add_subplot(2, 8, idx+1)
即可。