Prédiction probabiliste de forêt aléatoire vs vote majoritaire


10

Scikit learn semble utiliser la prédiction probabiliste au lieu du vote majoritaire pour la technique d'agrégation du modèle sans expliquer pourquoi (1.9.2.1. Forêts aléatoires).

Y a-t-il une explication claire pourquoi? De plus, existe-t-il un bon article ou article de synthèse sur les différentes techniques d'agrégation de modèles pouvant être utilisées pour l'ensachage en forêt aléatoire?

Merci!

Réponses:


10

Il est toujours préférable de répondre à ces questions en consultant le code, si vous parlez couramment Python.

RandomForestClassifier.predict, au moins dans la version actuelle 0.16.1, prédit la classe avec l'estimation de probabilité la plus élevée, telle que donnée par predict_proba. ( cette ligne )

La documentation de predict_probadit:

Les probabilités de classe prédites d'un échantillon d'entrée sont calculées comme les probabilités moyennes de classe prédites des arbres dans la forêt. La probabilité de classe d'un seul arbre est la fraction d'échantillons de la même classe dans une feuille.

La différence par rapport à la méthode d'origine est probablement juste pour que les predictprédictions soient cohérentes avec predict_proba. Le résultat est parfois appelé "vote doux", plutôt que le vote majoritaire "dur" utilisé dans le document original de Breiman. Je n'ai pas pu en recherche rapide trouver une comparaison appropriée des performances des deux méthodes, mais elles semblent toutes les deux assez raisonnables dans cette situation.

La predictdocumentation est au mieux assez trompeuse; J'ai soumis une demande d'extraction pour le corriger.

Si vous voulez plutôt faire une prédiction de vote majoritaire, voici une fonction pour le faire. Appelez ça predict_majvote(clf, X)plutôt que clf.predict(X). (Basé sur predict_proba; seulement légèrement testé, mais je pense que cela devrait fonctionner.)

from scipy.stats import mode
from sklearn.ensemble.forest import _partition_estimators, _parallel_helper
from sklearn.tree._tree import DTYPE
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

def predict_majvote(forest, X):
    """Predict class for X.

    Uses majority voting, rather than the soft voting scheme
    used by RandomForestClassifier.predict.

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csr_matrix``.
    Returns
    -------
    y : array of shape = [n_samples] or [n_samples, n_outputs]
        The predicted classes.
    """
    check_is_fitted(forest, 'n_outputs_')

    # Check data
    X = check_array(X, dtype=DTYPE, accept_sparse="csr")

    # Assign chunk of trees to jobs
    n_jobs, n_trees, starts = _partition_estimators(forest.n_estimators,
                                                    forest.n_jobs)

    # Parallel loop
    all_preds = Parallel(n_jobs=n_jobs, verbose=forest.verbose,
                         backend="threading")(
        delayed(_parallel_helper)(e, 'predict', X, check_input=False)
        for e in forest.estimators_)

    # Reduce
    modes, counts = mode(all_preds, axis=0)

    if forest.n_outputs_ == 1:
        return forest.classes_.take(modes[0], axis=0)
    else:
        n_samples = all_preds[0].shape[0]
        preds = np.zeros((n_samples, forest.n_outputs_),
                         dtype=forest.classes_.dtype)
        for k in range(forest.n_outputs_):
            preds[:, k] = forest.classes_[k].take(modes[:, k], axis=0)
        return preds

Sur le cas synthétique stupide que j'ai essayé, les prédictions étaient d'accord avec la predictméthode à chaque fois.


Excellente réponse, Dougal! Merci d'avoir pris le temps de l'expliquer soigneusement. S'il vous plaît également envisager d' aller vers un débordement de pile et de répondre à cette question là .
user1745038

1
Il y a aussi un article, ici , qui traite de la prédiction probabiliste.
user1745038
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.