Vous pouvez passer plt.scatter
un c
argument qui vous permettra de sélectionner les couleurs. Le code ci-dessous définit un colors
dictionnaire pour mapper vos couleurs de diamant aux couleurs de traçage.
import matplotlib.pyplot as plt
import pandas as pd
carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color =['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G',]
df = pd.DataFrame(dict(carat=carat, price=price, color=color))
fig, ax = plt.subplots()
colors = {'D':'red', 'E':'blue', 'F':'green', 'G':'black'}
ax.scatter(df['carat'], df['price'], c=df['color'].apply(lambda x: colors[x]))
plt.show()
df['color'].apply(lambda x: colors[x])
mappe efficacement les couleurs du «diamant» au «tracé».
(Pardonnez-moi de ne pas mettre une autre image d'exemple, je pense que 2 suffit: P)
Avec seaborn
Vous pouvez utiliser seaborn
un wrapper matplotlib
qui le rend plus joli par défaut (plutôt basé sur l'opinion, je sais: P) mais ajoute également des fonctions de traçage.
Pour cela, vous pouvez utiliser seaborn.lmplot
avec fit_reg=False
(ce qui l'empêche de faire automatiquement une régression).
Le code ci-dessous utilise un exemple de jeu de données. En sélectionnant, hue='color'
vous dites à seaborn de diviser votre trame de données en fonction de vos couleurs, puis de tracer chacune d'elles.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
carat = [5, 10, 20, 30, 5, 10, 20, 30, 5, 10, 20, 30]
price = [100, 100, 200, 200, 300, 300, 400, 400, 500, 500, 600, 600]
color =['D', 'D', 'D', 'E', 'E', 'E', 'F', 'F', 'F', 'G', 'G', 'G',]
df = pd.DataFrame(dict(carat=carat, price=price, color=color))
sns.lmplot('carat', 'price', data=df, hue='color', fit_reg=False)
plt.show()
Sans seaborn
utiliserpandas.groupby
Si vous ne souhaitez pas utiliser seaborn, vous pouvez utiliser pandas.groupby
pour obtenir les couleurs seules, puis les tracer en utilisant uniquement matplotlib, mais vous devrez attribuer manuellement les couleurs au fur et à mesure, j'ai ajouté un exemple ci-dessous:
fig, ax = plt.subplots()
colors = {'D':'red', 'E':'blue', 'F':'green', 'G':'black'}
grouped = df.groupby('color')
for key, group in grouped:
group.plot(ax=ax, kind='scatter', x='carat', y='price', label=key, color=colors[key])
plt.show()
Ce code suppose le même DataFrame que ci-dessus, puis le regroupe en fonction de color
. Il itère ensuite sur ces groupes, en traçant pour chacun d'eux. Pour sélectionner une couleur, j'ai créé un colors
dictionnaire qui peut mapper la couleur du diamant (par exemple D
) à une couleur réelle (par exemple red
).