Nuages ​​de points dans Pandas / Pyplot: Comment tracer par catégorie


89

J'essaie de faire un simple nuage de points dans pyplot en utilisant un objet Pandas DataFrame, mais je veux un moyen efficace de tracer deux variables mais que les symboles soient dictés par une troisième colonne (clé). J'ai essayé différentes manières d'utiliser df.groupby, mais sans succès. Un exemple de script df est ci-dessous. Cela colore les marqueurs en fonction de «key1», mais j'aimerais voir une légende avec les catégories «key1». Suis-je proche? Merci.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()

Réponses:


118

Vous pouvez utiliser scatterpour cela, mais cela nécessite d'avoir des valeurs numériques pour votre key1, et vous n'aurez pas de légende, comme vous l'avez remarqué.

Il est préférable de l'utiliser uniquement plotpour des catégories discrètes comme celle-ci. Par exemple:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

entrez la description de l'image ici

Si vous souhaitez que les choses ressemblent au pandasstyle par défaut , mettez simplement à jour la rcParamsfeuille de style pandas et utilisez son générateur de couleurs. (Je peaufine également légèrement la légende):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

entrez la description de l'image ici


Pourquoi, dans l'exemple RVB ci-dessus, le symbole est-il affiché deux fois dans la légende? Comment ne montrer qu'une seule fois?
Steve Schulist

1
@SteveSchulist - Permet ax.legend(numpoints=1)d'afficher un seul marqueur. Il y en a deux, comme avec a Line2D, il y a souvent une ligne reliant les deux marqueurs.
Joe Kington du

Ce code n'a fonctionné pour moi qu'après l'ajout plt.hold(True)après la ax.plot()commande. Une idée pourquoi?
Yuval Atzmon

set_color_cycle() était obsolète dans matplotlib 1.5. Il y a set_prop_cycle(), maintenant.
ale

52

C'est simple à faire avec Seaborn ( pip install seaborn) comme oneliner

sns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1") :

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.scatterplot(x="one", y="two", data=df, hue="key1")

entrez la description de l'image ici

Voici le dataframe pour référence:

entrez la description de l'image ici

Puisque vous avez trois colonnes variables dans vos données, vous souhaiterez peut-être tracer toutes les dimensions par paires avec:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

entrez la description de l'image ici

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ est une autre option.


19

Avec plt.scatter, je ne peux penser qu'à un seul: utiliser un artiste proxy:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

Et le résultat est:

entrez la description de l'image ici


10

Vous pouvez utiliser df.plot.scatter, et passer un tableau à c = argument définissant la couleur de chaque point:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

entrez la description de l'image ici


4

Vous pouvez également essayer Altair ou ggpot qui sont axés sur les visualisations déclaratives.

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Code Altair

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

entrez la description de l'image ici

code ggplot

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

entrez la description de l'image ici


3

À partir de matplotlib 3.1, vous pouvez utiliser .legend_elements(). Un exemple est présenté dans la création de légende automatisée . L'avantage est qu'un seul appel scatter peut être utilisé.

Dans ce cas:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

entrez la description de l'image ici

Dans le cas où les clés ne seraient pas directement données sous forme de nombres, cela ressemblerait à

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

entrez la description de l'image ici


J'ai eu une erreur disant que l'objet 'PathCollection' n'a pas d'attribut 'legends_elements'. Mon code est le suivant. fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
Nandish Patel

1
@NandishPatel Vérifiez la toute première phrase de cette réponse. Assurez-vous également de ne pas confondre legends_elementset legend_elements.
ImportanceOfBeingErnest

Oui merci. C'était une faute de frappe (légendes / légende). Je travaillais sur quelque chose depuis 6 heures, donc la version Matplotlib ne m'est pas venue à l'esprit. Je pensais que j'utilisais le dernier. J'étais confus que la documentation indique qu'il existe une telle méthode, mais le code donnait une erreur. Merci encore. Je peux dormir maintenant.
Nandish Patel

2

Il est assez hacky, mais vous pouvez utiliser one1comme Float64Indexpour tout faire en une seule fois:

df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)

entrez la description de l'image ici

Notez qu'à partir de 0.20.3, le tri de l'index est nécessaire et la légende est un peu bancale .


1

seaborn a une fonction d'emballage scatterplotqui le fait plus efficacement.

sns.scatterplot(data = df, x = 'one', y = 'two', data =  'key1'])
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.