Comment utiliser la propagation d'étiquettes Scikit-Learn sur des données structurées graphiquement?


11

Dans le cadre de mes recherches, je m'intéresse à la propagation d'étiquettes sur un graphe. Je suis particulièrement intéressé par ces deux méthodes:

J'ai vu que scikit-learn propose un modèle pour le faire. Cependant, ce modèle est censé être appliqué à des données vectorielles structurées ( c'est-à-dire des points de données).

Le modèle construit une matrice d'affinité à partir des points de données à l'aide d'un noyau, puis exécute l'algorithme sur la matrice construite. Je voudrais pouvoir entrer directement la matrice d'adjacence de mon graphique à la place de la matrice de similarité.

Une idée sur la façon d'y parvenir? Ou connaissez-vous une bibliothèque Python qui permettra d'exécuter la propagation d'étiquettes directement sur des données structurées par graphe pour les deux méthodes susmentionnées?

Merci d'avance pour votre aide!


Avez-vous vérifié le code source de Scikit-learn pour voir ce qu'il fait après avoir calculé la matrice d'affinité? Peut-être pourriez-vous "copier" le code après cette partie pour l'appliquer directement à votre matrice d'adjacence.
Tasos

Merci pour votre commentaire! Donc, en fait, c'est ce que je fais actuellement, mais certaines parties du code que je dois modifier pour répondre à mes besoins sont quelque peu cryptiques. Je crains que la réécriture de ces parties ne provoque des erreurs. J'espérais qu'il existait une méthode plus simple.
Thibaud Martinez

1
Le code source sur github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/… - dit que les implémentations devraient remplacer la méthode _build_graph. Donc, nativement, vous devriez essayer de créer une classe dérivée qui accepte une matrice précalculée.
mikalai

Réponses:


2

Répondre à ma propre question ici, car j'espère qu'elle sera utile à certains lecteurs.

Scikit-learn est principalement conçu pour traiter des données structurées vectorielles. Par conséquent, si vous souhaitez effectuer la propagation / la propagation d'étiquettes sur des données structurées sous forme de graphique, il est probablement préférable de réimplémenter la méthode vous-même plutôt que d'utiliser l'interface Scikit.

Voici une implémentation de la propagation d'étiquettes et de l'étalement d'étiquettes dans PyTorch.

Dans l'ensemble, les deux méthodes suivent les mêmes étapes algorithmiques, avec des variations sur la façon dont la matrice d'adjacence est normalisée et la façon dont les étiquettes sont propagées à chaque étape. Créons donc une classe de base pour nos deux modèles.

from abc import abstractmethod
import torch

class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """
    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None 
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()

            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break

            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices

Le modèle prend en entrée la matrice d'adjacence du graphe ainsi que les étiquettes des nœuds. Les étiquettes sont sous la forme d'un vecteur d'un entier indiquant le numéro de classe de chaque nœud avec un -1 à la position des nœuds non étiquetés.

L'algorithme de propagation des étiquettes est présenté ci-dessous.

W: adjacency matrix of the graph Compute the diagonal degree matrix D by DiijWij Initialize Y^(0)(y1,,yl,0,0,,0) Iterate  1. Y^(t+1)D1WY^(t) 2. Y^l(t+1)Yl until convergence to Y^() Label point xi by the sign of y^i()

De Xiaojin Zhu et Zoubin Ghahramani. Apprendre à partir de données étiquetées et non étiquetées avec propagation d'étiquettes. Rapport technique CMU-CALD-02-107, Université Carnegie Mellon, 2002

Nous obtenons l'implémentation suivante.

class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)

L'algorithme d'étalement des étiquettes est:

W: adjacency matrix of the graph Compute the diagonal degree matrix D by DiijWij Compute the normalized graph Laplacian LD1/2WD1/2 Initialize Y^(0)(y1,,yl,0,0,,0) Choose a parameter α[0,1) Iterate Y^(t+1)αLY^(t)+(1α)Y^(0) until convergence to Y^() Label point xi by the sign of y^i()

De Dengyong Zhou, Olivier Bousquet, Thomas Navin Lal, Jason Weston, Bernhard Schoelkopf. Apprendre avec cohérence locale et mondiale (2004)

La mise en œuvre est donc la suivante.

class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)

Essayons maintenant nos modèles de propagation sur des données synthétiques. Pour ce faire, nous choisissons d'utiliser un graphe d'homme des cavernes .

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()

# Create labels
labels = np.full(n_cliques * size_cliques, -1.)

# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3

# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()

# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]

plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)

ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")

pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)

# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)

plt.show()

Les modèles mis en œuvre fonctionnent correctement et permettent de détecter les communautés dans le graphique.

Implémentations de propagation et d'étalement d'étiquettes testées sur un graphe d'homme des cavernes

Remarque: Les méthodes de propagation présentées sont destinées à être utilisées sur des graphiques non orientés.

Le code est disponible en bloc - notes interactif Jupyter ici .

Classeur

En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.