Matplotlib: find the top 10 countries by one column, then plot a grouped bar chart by another (Python)

i1icjdpr posted on 2023-02-23 in Python

So, I have a DataFrame (`df`) built from a Kaggle survey. The columns I am interested in are:
| country | gender | id |
| ------------ | ------------ | ------------ |
| USA | Woman | 5612 |
| Germany | Man | 5613 |
| USA | non-binary | 5614 |
| India | Man | 5615 |

What I want to do now is plot a grouped bar chart of the top 10 countries (i.e. the 10 countries with the most survey participants), showing the gender distribution within each of those countries.
I managed to get a result close to what I want:

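```python
import plotly.express as px

# Count rows per (country, gender) pair; the 'id' column then holds the number of non-null ids
data_gender = df.groupby(['country', 'gender']).count().sort_values('id', ascending=False).reset_index()
data_gender.head()

fig = px.histogram(data_gender, x='country', y='id',
                   color='gender', barmode='group', height=400)
fig.show()
```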

It produces exactly what I want: several bars per country, one for each gender category.
[Plot produced by the code above]
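For reference, the intermediate `data_gender` is a long-format table: one row per (country, gender) pair plus the number of participants in it, which is the shape that `color='gender', barmode='group'` works from. A minimal sketch of that step on a toy DataFrame (rows invented, mirroring the sample table), using `.size()` with an explicitly named `count` column instead of relying on `.count()` to fill the `id` column:

```python
import pandas as pd

# Toy data (invented) with the same columns as the survey sample
df = pd.DataFrame({
    "country": ["USA", "Germany", "USA", "India", "USA", "India"],
    "gender": ["Woman", "Man", "non-binary", "Man", "Woman", "Woman"],
    "id": [5612, 5613, 5614, 5615, 5616, 5617],
})

# One row per (country, gender) pair; .size() counts the rows in each group
data_gender = (
    df.groupby(["country", "gender"])
    .size()
    .reset_index(name="count")
    .sort_values("count", ascending=False)
)
print(data_gender)
```

With this variant the plot call takes `y='count'` instead of `y='id'`; when `y` is given, `px.histogram` sums the `y` values within each bar.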
BUT I just can't figure out how to show the plot for only the top 10 countries by participant count.
I did find out which countries are in the top 10 with:

```python
dftop10 = df.groupby(['country']).size().to_frame('count').sort_values('count', ascending=False).reset_index()
```

and also like so:

```python
df_top10 = df.value_counts('country')
top10 = df_top10.head(10).index.tolist()
genders = df[df['country'].isin(top10)].groupby('gender').count()['id']
```

But I keep running into dead ends. Once I find the top 10 countries, I lose the information about each gender category. How can I:

  1. find the top 10 countries by overall participant count,
  2. get a subset of the data containing only those countries, with counts for each gender category per country,
  3. and then run the plot code above on that subset?

I have been trying to figure this out for hours now. Please help me solve this :)
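The three steps above can be sketched in pandas roughly as follows, assuming the column names from the sample table. The order is what matters: pick the top countries first, filter the rows, and only then group by both `country` and `gender`, so the gender breakdown survives. The function name `top_n_gender_counts` is hypothetical and the data is a toy example.

```python
import pandas as pd


def top_n_gender_counts(df: pd.DataFrame, n: int = 10) -> pd.DataFrame:
    """Long-format gender counts for the n countries with the most participants."""
    # Step 1: rank countries by number of rows (participants), keep the top n labels
    top_countries = df["country"].value_counts().nlargest(n).index

    # Step 2: filter rows to those countries, then count per (country, gender) pair
    subset = df[df["country"].isin(top_countries)]
    return (
        subset.groupby(["country", "gender"])
        .size()
        .reset_index(name="count")
    )


# Toy data (invented): USA has 3 rows, India 2, Germany 1
df = pd.DataFrame({
    "country": ["USA", "USA", "USA", "India", "India", "Germany"],
    "gender": ["Woman", "Man", "Woman", "Man", "Woman", "Man"],
    "id": range(5612, 5618),
})
top2 = top_n_gender_counts(df, n=2)
print(top2)
```

Step 3 would then be the original plot call with `y='count'`, e.g. `px.histogram(top_n_gender_counts(df), x='country', y='count', color='gender', barmode='group', height=400)`. In the last attempt above, grouping by `gender` alone is what collapses the countries.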
Answer #1 (lxkprmvk)

When identifying the countries, you can keep just the first 10 rows using `[:10]`. When plotting, filter `data_gender` to only those countries. That should do it:

```python
import plotly.express as px

data_gender = df.groupby(['country', 'gender']).count().sort_values('id', ascending=False).reset_index()

# Keep only the first 10 rows (the 10 largest countries) using [:10]
dfTop10 = df.groupby(['country']).size().to_frame('count').sort_values('count', ascending=False).reset_index()[:10]

# Filter data_gender to the countries in dfTop10 using .isin
fig = px.histogram(data_gender[data_gender['country'].isin(dfTop10['country'])], x='country', y='id',
                   color='gender', barmode='group', height=400)
fig.show()
```
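Since the title mentions matplotlib, here is a sketch of the same chart using pandas' built-in plotting (which draws through matplotlib) instead of Plotly; this is a swapped-in approach, not the one in the answer. `pd.crosstab` builds a wide table with countries as rows and genders as columns, and `DataFrame.plot.bar` draws one bar per column within each country group. The function names are hypothetical and the data is invented.

```python
import pandas as pd


def top_n_gender_table(df: pd.DataFrame, n: int = 10) -> pd.DataFrame:
    """One row per top-n country, one column per gender category."""
    # Countries ranked by total participant count, largest first
    top_countries = df["country"].value_counts().nlargest(n).index
    # Cross-tabulate: each cell is the row count for that (country, gender), 0 if none
    table = pd.crosstab(df["country"], df["gender"])
    # Select and order rows so the largest country comes first on the x-axis
    return table.loc[top_countries]


def plot_top_n(df: pd.DataFrame, n: int = 10):
    """Grouped bar chart via pandas' matplotlib backend (requires matplotlib)."""
    ax = top_n_gender_table(df, n).plot.bar(figsize=(10, 4), rot=45)
    ax.set_xlabel("country")
    ax.set_ylabel("participants")
    ax.figure.tight_layout()
    return ax


# Toy data (invented): USA has 3 rows, India 2, Germany 1
df = pd.DataFrame({
    "country": ["USA", "USA", "USA", "India", "India", "Germany"],
    "gender": ["Woman", "Man", "non-binary", "Man", "Woman", "Man"],
    "id": range(5612, 5618),
})
table = top_n_gender_table(df, n=2)
print(table)
```

Calling `plot_top_n(df)` followed by `matplotlib.pyplot.show()` displays the chart. On the Plotly side, passing `category_orders={'country': dfTop10['country'].tolist()}` to `px.histogram` would similarly pin the countries in descending order of total; otherwise Plotly orders them by first appearance in the filtered `data_gender`.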

[Output plot, generated with random data]
