matplotlib 创建网络图时节点颜色错误

ldxq2e6h  于 2023-08-06  发布在  其他
关注(0)|答案(1)|浏览(115)

创建网络图时,某些节点显示的类别颜色错误。
我编写了以下创建网络图的代码。它从两个CSV文件中读取数据-一个用于邻接矩阵,另一个用于类别。它假设根据节点所属的类别对节点进行着色,但它最终将一些节点着色错误。起初我认为这个问题是由于使用shell布局造成的,但当我使用不同的布局时也会发生这种情况。假设csv文件没有任何问题(类别定义正确,矩阵和类别之间没有不匹配等),问题可能是什么?

import pandas as pd 
import networkx as nx 
import matplotlib.pyplot as plt
 
def create_network_diagram_from_csv(adjacency_csv_path, categories_csv_path):
     # Read data from CSV into pandas DataFrames
     df_adjacency = pd.read_csv(adjacency_csv_path, index_col=0)
     df_categories = pd.read_csv(categories_csv_path, index_col=0)
 
     # Convert DataFrames to NumPy arrays for easier handling
     adjacency_matrix = df_adjacency.to_numpy()
 
     # Get the node names from the DataFrame's index
     nodes = df_adjacency.index.to_list()
 
     # Create a directed graph using networkx
     G = nx.DiGraph()
 
     # Iterate through the adjacency matrix and add edges to the graph
     for i, source_node in enumerate(nodes):
         for j, target_node in enumerate(nodes):
             edge_weight = adjacency_matrix[i, j]
             if not pd.isna(edge_weight) and edge_weight > 0:
                 G.add_edge(source_node, target_node)
 
     # Create the network diagram using shell layout
     pos = nx.shell_layout(G)
 
     # Get categories for each node from the categories DataFrame
     categories = df_categories.to_dict()['Category']
 
     # Define a mapping of categories to colors
     category_colors = {
         'category_a': 'green',
         'category_b': 'blue',
         'category_c': 'yellow',
         'category_d': 'brown',
         'category_e': 'red',
     }
 
     # Get the colors for each node based on its category
     node_colors = [category_colors.get(categories[node], 'gray') for node in nodes]

     # Customize node size
     node_size = 300
 
     # Draw the network diagram with nodes colored based on categories
     nx.draw(G, pos, with_labels=True, node_size=node_size, node_color=node_colors, font_size=10, arrows=True)
 
      #Create a custom legend for categories and colors
     legend_elements = [plt.Line2D([0], [0], marker='o', color='w', label=category, markerfacecolor=color, markersize=10) for category, color in category_colors.items()]
 
     # Add the legend to the plot
     plt.legend(handles=legend_elements, title='Categories', loc='upper right', frameon=True)
 
     edge_color = 'grey'  # Specify the color of the arrows
     # Draw the edges with the specified edge color
     nx.draw_networkx_edges(G, pos, arrows=True, edge_color=edge_color)
 
     plt.show()
 
adjacency_csv_path = r'C:\location of file'
categories_csv_path = r'C:\location of file'
create_network_diagram_from_csv(adjacency_csv_path, categories_csv_path)

字符串
分类数据:

Node, Category
node_1, category_a
node_2, category_b
node_3, category_d
node_4, category_a
node_5, category_a


邻接数据:

,node_1,node_2,node_3,node_4,node_5
node_1, 1, 0, 0, 0, 1
node_2, 0, 0, 1, 0, 0
node_3, 0, 1, 0, 1, 1
node_4, 0, 1, 0, 1, 0
node_5, 1, 1, 1, 1, 1

sirbozc5

sirbozc51#

创建图形的方式不正确,不要手动循环,而是使用from_pandas_adjacency

G = nx.from_pandas_adjacency(df_adjacency, create_using=nx.DiGraph)

字符串
输出量:


的数据

相关问题