Matrice de corrélation de tracé à l'aide de pandas


212

J'ai un ensemble de données avec un grand nombre de fonctionnalités, donc l'analyse de la matrice de corrélation est devenue très difficile. Je veux tracer une matrice de corrélation que nous obtenons en utilisant la dataframe.corr()fonction de la bibliothèque pandas. Existe-t-il une fonction intégrée fournie par la bibliothèque pandas pour tracer cette matrice?


Des réponses connexes peuvent être trouvées ici Création d'une carte
thermique à

Réponses:


293

Vous pouvez utiliser à pyplot.matshow() partir de matplotlib:

import matplotlib.pyplot as plt

plt.matshow(dataframe.corr())
plt.show()

Éditer:

Dans les commentaires, il y avait une demande sur la façon de changer les étiquettes des graduations des axes. Voici une version de luxe qui est dessinée sur une plus grande taille de figure, a des étiquettes d'axe pour correspondre au cadre de données et une légende de barre de couleur pour interpréter l'échelle de couleurs.

J'inclus comment ajuster la taille et la rotation des étiquettes, et j'utilise un ratio de chiffres qui fait sortir la barre de couleur et la figure principale de la même hauteur.

f = plt.figure(figsize=(19, 15))
plt.matshow(df.corr(), fignum=f.number)
plt.xticks(range(df.shape[1]), df.columns, fontsize=14, rotation=45)
plt.yticks(range(df.shape[1]), df.columns, fontsize=14)
cb = plt.colorbar()
cb.ax.tick_params(labelsize=14)
plt.title('Correlation Matrix', fontsize=16);

exemple de tracé de corrélation


1
Je dois manquer quelque chose:AttributeError: 'module' object has no attribute 'matshow'
Tom Russell

1
@TomRussell L'avez-vous fait import matplotlib.pyplot as plt?
joelostblom

1
J'aimerais penser que je l'ai fait! :-)
Tom Russell

7
savez-vous comment afficher les noms de colonne réels sur le tracé?
WebQube

2
@Cecilia J'avais résolu ce problème en changeant le paramètre de rotation à 90
ikbel benabdessamad

182

Si votre objectif principal est de visualiser la matrice de corrélation, plutôt que de créer un tracé en soi, les pandas options de style pratiques sont une solution intégrée viable:

import pandas as pd
import numpy as np

rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
corr = df.corr()
corr.style.background_gradient(cmap='coolwarm')
# 'RdBu_r' & 'BrBG' are other good diverging colormaps

entrez la description de l'image ici

Notez que cela doit être dans un backend qui prend en charge le rendu HTML, tel que le bloc-notes JupyterLab. (Le texte clair automatique sur fond sombre provient d'un PR existant et non de la dernière version publiée, pandas0,23).


Coiffant

Vous pouvez facilement limiter la précision des chiffres:

corr.style.background_gradient(cmap='coolwarm').set_precision(2)

entrez la description de l'image ici

Ou supprimez complètement les chiffres si vous préférez la matrice sans annotations:

corr.style.background_gradient(cmap='coolwarm').set_properties(**{'font-size': '0pt'})

entrez la description de l'image ici

La documentation de style comprend également des instructions de styles plus avancés, telles que la façon de modifier l'affichage de la cellule sur laquelle le pointeur de la souris survole. Pour enregistrer la sortie, vous pouvez retourner le code HTML en ajoutant la render()méthode, puis l'écrire dans un fichier (ou simplement prendre une capture d'écran à des fins moins formelles).


Comparaison de temps

Dans mes tests, style.background_gradient()était 4x plus rapide que plt.matshow()et 120x plus rapide sns.heatmap()qu'avec une matrice 10x10. Malheureusement, il n'est pas aussi évolutif que cela plt.matshow(): les deux prennent environ le même temps pour une matrice 100x100 et plt.matshow()sont 10 fois plus rapides pour une matrice 1000x1000.


Économie

Il existe plusieurs façons d’enregistrer la trame de données stylisée:

  • Renvoyez le code HTML en ajoutant le render() méthode, puis écrivez la sortie dans un fichier.
  • Enregistrez en tant que .xslxfichier avec une mise en forme conditionnelle en ajoutant la to_excel()méthode.
  • Combinez avec imgkit pour enregistrer un bitmap
  • Faites une capture d'écran (à des fins moins formelles).

Mise à jour pour les pandas> = 0,24

En définissant axis=None, il est désormais possible de calculer les couleurs en fonction de la matrice entière plutôt que par colonne ou par ligne:

corr.style.background_gradient(cmap='coolwarm', axis=None)

entrez la description de l'image ici


2
S'il y avait un moyen d'exporter en tant qu'image, ça aurait été génial!
Kristada673

1
Merci! Vous avez certainement besoin d'une palette divergenteimport seaborn as sns corr = df.corr() cm = sns.light_palette("green", as_cmap=True) cm = sns.diverging_palette(220, 20, sep=20, as_cmap=True) corr.style.background_gradient(cmap=cm).set_precision(2)
stallingOne

1
@stallingOne Bon point, je n'aurais pas dû inclure de valeurs négatives dans l'exemple, je pourrais changer cela plus tard. Juste pour référence pour les personnes qui lisent ceci, vous n'avez pas besoin de créer une cmap divergente personnalisée avec seaborn (bien que celle dans le commentaire ci-dessus semble assez lisse), vous pouvez également utiliser les cmaps divergentes intégrées de matplotlib, par exemple corr.style.background_gradient(cmap='coolwarm'). Il n'existe actuellement aucun moyen de centrer la cmap sur une valeur spécifique, ce qui peut être une bonne idée avec des cmaps divergentes.
joelostblom

1
@rovyko Êtes-vous sous pandas> = 0,24,0?
joelostblom

2
Ces tracés sont visuellement superbes, mais la question @ Kristada673 est tout à fait pertinente, comment les exporteriez-vous?
Erfan

89

Essayez cette fonction, qui affiche également des noms de variables pour la matrice de corrélation:

def plot_corr(df,size=10):
    '''Function plots a graphical correlation matrix for each pair of columns in the dataframe.

    Input:
        df: pandas DataFrame
        size: vertical and horizontal size of the plot'''

    corr = df.corr()
    fig, ax = plt.subplots(figsize=(size, size))
    ax.matshow(corr)
    plt.xticks(range(len(corr.columns)), corr.columns);
    plt.yticks(range(len(corr.columns)), corr.columns);

6
plt.xticks(range(len(corr.columns)), corr.columns, rotation='vertical')si vous voulez une orientation verticale des noms de colonnes sur l'axe des x
nishant

Autre élément graphique, mais l'ajout d'un plt.tight_layout()peut également être utile pour les noms de colonne longs.
user3017048

86

Version de la carte thermique de Seaborn:

import seaborn as sns
corr = dataframe.corr()
sns.heatmap(corr, 
            xticklabels=corr.columns.values,
            yticklabels=corr.columns.values)

9
La carte thermique Seaborn est fantaisiste mais elle fonctionne mal sur les grandes matrices. La méthode matshow de matplotlib est beaucoup plus rapide.
anilbey

3
Seaborn peut automatiquement déduire les ticklabels des noms de colonne.
Tulio Casagrande

80

Vous pouvez observer la relation entre les entités en dessinant une carte thermique à partir de la mer ou une matrice de dispersion des pandas.

Matrice de dispersion:

pd.scatter_matrix(dataframe, alpha = 0.3, figsize = (14,8), diagonal = 'kde');

Si vous souhaitez également visualiser l'asymétrie de chaque fonctionnalité, utilisez des diagrammes de paires nés en mer.

sns.pairplot(dataframe)

Sns Heatmap:

import seaborn as sns

f, ax = pl.subplots(figsize=(10, 8))
corr = dataframe.corr()
sns.heatmap(corr, mask=np.zeros_like(corr, dtype=np.bool), cmap=sns.diverging_palette(220, 10, as_cmap=True),
            square=True, ax=ax)

Le résultat sera une carte de corrélation des entités. c'est à dire voir l'exemple ci-dessous.

entrez la description de l'image ici

La corrélation entre l'épicerie et les détergents est élevée. De même:

Pdoducts à corrélation élevée:
  1. Épicerie et détergents.
Produits à corrélation moyenne:
  1. Lait et épicerie
  2. Lait et détergents_Paper
Produits à faible corrélation:
  1. Lait et charcuterie
  2. Congelé et frais.
  3. Frozen and Deli.

De Pairplots: Vous pouvez observer le même ensemble de relations à partir de pairplots ou de matrice de dispersion. Mais à partir de ceux-ci, nous pouvons dire que les données sont normalement distribuées ou non.

entrez la description de l'image ici

Remarque: Ce qui précède est le même graphique tiré des données, qui est utilisé pour dessiner une carte thermique.


3
Je pense que ce devrait être .plt pas .pl (si cela fait référence à matplotlib)
ghukill

2
@ghukill Pas nécessairement. Il aurait pu le désigner commefrom matplotlib import pyplot as pl
Jeru Luke

comment définir la limite de la corrélation entre -1 et +1 toujours, dans le tracé de corrélation
debaonline4u

7

Vous pouvez utiliser la méthode imshow () de matplotlib

import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

plt.imshow(X.corr(), cmap=plt.cm.Reds, interpolation='nearest')
plt.colorbar()
tick_marks = [i for i in range(len(X.columns))]
plt.xticks(tick_marks, X.columns, rotation='vertical')
plt.yticks(tick_marks, X.columns)
plt.show()

5

Si votre dataframe est, dfvous pouvez simplement utiliser:

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(15, 10))
sns.heatmap(df.corr(), annot=True)

3

les graphiques statmodels donnent également une belle vue de la matrice de corrélation

import statsmodels.api as sm
import matplotlib.pyplot as plt

corr = dataframe.corr()
sm.graphics.plot_corr(corr, xnames=list(corr.columns))
plt.show()

3

Pour être complet, la solution la plus simple que je connaisse avec Seaborn à la fin de 2019, si l'on utilise Jupyter :

import seaborn as sns
sns.heatmap(dataframe.corr())

1

Avec d'autres méthodes, il est également bon d'avoir un pairplot qui donnera un nuage de points pour tous les cas -

import pandas as pd
import numpy as np
import seaborn as sns
rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
sns.pairplot(df)

0

Matrice de corrélation de forme, dans mon cas, zdf est la trame de données dont j'ai besoin pour effectuer la matrice de corrélation.

corrMatrix =zdf.corr()
corrMatrix.to_csv('sm_zscaled_correlation_matrix.csv');
html = corrMatrix.style.background_gradient(cmap='RdBu').set_precision(2).render()

# Writing the output to a html file.
with open('test.html', 'w') as f:
   print('<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-widthinitial-scale=1.0"><title>Document</title></head><style>table{word-break: break-all;}</style><body>' + html+'</body></html>', file=f)

Ensuite, nous pouvons prendre une capture d'écran. ou convertissez le HTML en fichier image.

En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.