如何在python中为matplotlib - scipy的树状图着色?

r8uurelv  于 2023-02-12  发布在  Python
关注(0)|答案(2)|浏览(185)

下面的代码对数据执行层次聚类:

Z = linkage(data,method='weighted')
  plt.subplot(2,1,1)
  dendro = dendrogram(Z)
  leaves = dendro['leaves']
  print leaves
  plt.show()

然而在树状图中所有的聚类都有相同的颜色(蓝色)。有没有办法根据聚类之间的相似性使用不同的颜色?

bf1o4zei

bf1o4zei1#

看看documentation,看起来你可以通过link_color_func关键字或color_threshold关键字来获得不同的颜色。
编辑:
树状图着色方案的默认行为是,给定color_threshold = 0.7*max(Z[:,2]),如果k是低于切割阈值的第一个节点,则将集群节点k之下的所有后代链接着色为相同的颜色;否则,连接距离大于或等于阈值的节点的所有链接都将被着色为蓝色[来自文档]。
这到底是什么意思?好吧,如果你看一个树状图,不同的聚类链接在一起。两个聚类之间的"距离"是它们之间链接的高度。color_threshold是低于这个高度的新聚类将是不同的颜色。如果你所有的聚类都是蓝色的,那么你需要提高你的color_threshold。例如,

In [48]: mat = np.random.rand(10, 10)
In [49]: z = linkage(mat, method="weighted")
In [52]: d = dendrogram(z)
In [53]: d['color_list']
Out[53]: ['g', 'g', 'b', 'r', 'c', 'c', 'c', 'b', 'b']
In [54]: plt.show()

我可以通过以下方式检查默认的color_threshold

In [56]: 0.7*np.max(z[:,2])
Out[56]: 1.0278719020096947

如果我降低color_threshold的值,我会得到更多的蓝色,因为更多的链接的距离大于新的color_threshold,你可以直观地看到这一点,因为所有大于0.9的链接现在都是蓝色的:

In [64]: d = dendrogram(z, color_threshold=.9)
In [65]: d['color_list']
Out[65]: ['g', 'b', 'b', 'r', 'b', 'b', 'b', 'b', 'b']
In [66]: plt.show()

如果我将color_threshold增加到1.21.2下面的链接将不再是蓝色。此外,青色和红色链接将合并为一种颜色,因为它们的父链接在1.2下面:

yrwegjxp

yrwegjxp2#

下面的代码将为每片叶子生成一个不同颜色的树状图,如果在合并聚类的过程中遇到两个不同颜色的聚类,那么它将选择默认的dflt_col = tab:blue
注意:link_matrix函数是scikit-learn中的AgglomerativeClustering示例中的一个函数的简单拷贝。
要解释它所做的一切,这真的很耗时。因此,print直接解释每一个不清楚的步骤。

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform, pdist

from matplotlib.pyplot import cm

from sklearn.cluster import AgglomerativeClustering
import matplotlib.colors as clrs

def link_matrix(model, **kwargs):
    # Create linkage matrix and then plot the dendrogram as in the standard sci-kit learn documentation
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count
    
    Z = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    return Z

def assign_link_colors(model):
    n_clusters = len(model.Z)
    scl_map_to_hex = mpl.cm.ScalarMappable(cmap = "jet").to_rgba(np.unique(model.labels_), norm = True) #colors.to_hex()
    col = [clrs.to_hex(rgb) for rgb in scl_map_to_hex]

    dic_labels = {s:[c, idx] for s, c, idx in zip(np.arange(len(model.feature_names_in_), dtype = int), model.feature_names_in_, model.labels_, )}
    model.dict_idx_name_cl = {k: v for k, v in sorted(dic_labels.items(), key=lambda item: item[1][1])}

    

    dflt_col = "tab:blue"   # Unclustered blue
    model.dict_colors = {x:col[model.dict_idx_name_cl[x][1]] for x in model.dict_idx_name_cl}
        
    link_cols = {}
    for i, i_cl in enumerate(model.Z[:,:2].astype(int)): # select only 1st two rows
        c1, c2 = (link_cols[x] if x > n_clusters else model.dict_colors[x] for x in i_cl)

        # Choice of coloring assignment: if same color --> ok; if no leaf, dft ("undefined") color 
        if c1 == c2:
            tmp_cl = c1 
        elif min(i_cl) <= n_clusters: # select the leaf color
            tmp_cl = model.dict_colors[min(i_cl)]
        else: 
            tmp_cl = dflt_col
        link_cols[i+1+n_clusters] = tmp_cl
        #print(f'-link_cols: {link_cols}',)
    
    return link_cols

def mod_2_dendrogram(model, **kwargs):

    plt.style.use('seaborn-whitegrid')
    plt.figure(figsize=(int(.5 * len(model.feature_names_in_)), 7))

    print(f'-0.7*max(Z[:,2]): {0.7*max(model.Z[:,2])}',)

    # Plot the corresponding dendrogram
    ddata = dendrogram(model.Z, #count_sort = "descending", 
                        **kwargs)

    # Plot distances on the dendrogram
    # plot cluster points & distance labels
    y_lim = dist_thr
    for i, d, c in zip(ddata['icoord'], ddata['dcoord'], ddata['color_list']):
        x = sum(i[1:3])/2
        y = d[1]
        if y > y_lim:
            plt.plot(x, y, 'o', c=c, markeredgewidth=0)
            plt.annotate(np.round(y,2), (x, y), xytext=(0, -5),
                        textcoords='offset points',
                        va='top', ha='center', fontsize=9)

    plt.axhline(y=dist_thr, color='orange', alpha = 0.7, linestyle='--', label = f"threshold: {int(model.dist_thr)}")
    plt.title(f'Agglomerative Dendrogram with n_clust: {model.n_clusters_}')
    plt.xlabel('Clusters')
    plt.ylabel('Distance')
    plt.legend()

    return ddata

现在,运行的例子:

import string
import pandas as pd
np.random.seed(0)
dist = np.random.randint(1e4, size = (10,10))
np.fill_diagonal(dist, 0)
dist = pd.DataFrame(dist, columns = list(string.ascii_lowercase)[:dist.shape[0]])

dist_thr = 1.5e3
model = AgglomerativeClustering(distance_threshold = dist_thr, n_clusters=None, linkage = "single", metric = "precomputed",)
model.dist_thr = dist_thr

model = model.fit(dist)
model.Z = link_matrix(model)

link_cols = assign_link_colors(model)

_ = mod_2_dendrogram(model, labels = dist.columns, 
                     link_color_func = lambda x: link_cols[x])

相关问题