Les forêts aléatoires sont à peine une boîte noire. Ils sont basés sur des arbres de décision, très faciles à interpréter:
#Setup a binary classification problem
require(randomForest)
data(iris)
set.seed(1)
dat <- iris
dat$Species <- factor(ifelse(dat$Species=='virginica','virginica','other'))
trainrows <- runif(nrow(dat)) > 0.3
train <- dat[trainrows,]
test <- dat[!trainrows,]
#Build a decision tree
require(rpart)
model.rpart <- rpart(Species~., train)
Cela donne un arbre de décision simple:
> model.rpart
n= 111
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 111 35 other (0.68468468 0.31531532)
2) Petal.Length< 4.95 77 3 other (0.96103896 0.03896104) *
3) Petal.Length>=4.95 34 2 virginica (0.05882353 0.94117647) *
Si Petal.Length <4,95, cette arborescence classe l’observation dans la catégorie "autre". S'il est supérieur à 4,95, l'observation est classée comme "virginica". Une forêt aléatoire est une simple collection de nombreux arbres de ce type, où chacun est formé sur un sous-ensemble aléatoire de données. Chaque arbre "vote" ensuite sur le classement final de chaque observation.
model.rf <- randomForest(Species~., train, ntree=25, proximity=TRUE, importance=TRUE, nodesize=5)
> getTree(model.rf, k=1, labelVar=TRUE)
left daughter right daughter split var split point status prediction
1 2 3 Petal.Width 1.70 1 <NA>
2 4 5 Petal.Length 4.95 1 <NA>
3 6 7 Petal.Length 4.95 1 <NA>
4 0 0 <NA> 0.00 -1 other
5 0 0 <NA> 0.00 -1 virginica
6 0 0 <NA> 0.00 -1 other
7 0 0 <NA> 0.00 -1 virginica
Vous pouvez même extraire des arbres individuels de la radio et regarder leur structure. Le format est légèrement différent de celui des rpart
modèles, mais vous pouvez inspecter chaque arbre si vous le souhaitez et voir comment il modélise les données.
De plus, aucun modèle n'est vraiment une boîte noire, car vous pouvez examiner les réponses prévues par rapport aux réponses réelles pour chaque variable de l'ensemble de données. C'est une bonne idée quel que soit le type de modèle que vous construisez:
library(ggplot2)
pSpecies <- predict(model.rf,test,'vote')[,2]
plotData <- lapply(names(test[,1:4]), function(x){
out <- data.frame(
var = x,
type = c(rep('Actual',nrow(test)),rep('Predicted',nrow(test))),
value = c(test[,x],test[,x]),
species = c(as.numeric(test$Species)-1,pSpecies)
)
out$value <- out$value-min(out$value) #Normalize to [0,1]
out$value <- out$value/max(out$value)
out
})
plotData <- do.call(rbind,plotData)
qplot(value, species, data=plotData, facets = type ~ var, geom='smooth', span = 0.5)
J'ai normalisé les variables (longueur et largeur des sépales et des pétales) sur une plage de 0 à 1. La réponse est également 0-1, 0 étant autre et 1 virginica. Comme vous pouvez le constater, la forêt aléatoire est un bon modèle, même sur l’ensemble de tests.
De plus, une forêt aléatoire calculera diverses mesures d'importance variable, ce qui peut être très informatif:
> importance(model.rf, type=1)
MeanDecreaseAccuracy
Sepal.Length 0.28567162
Sepal.Width -0.08584199
Petal.Length 0.64705819
Petal.Width 0.58176828
Ce tableau représente combien le fait de supprimer chaque variable réduit la précision du modèle. Enfin, il existe de nombreux autres tracés que vous pouvez créer à partir d'un modèle de forêt aléatoire pour voir ce qui se passe dans la boîte noire:
plot(model.rf)
plot(margin(model.rf))
MDSplot(model.rf, iris$Species, k=5)
plot(outlier(model.rf), type="h", col=c("red", "green", "blue")[as.numeric(dat$Species)])
Vous pouvez afficher les fichiers d'aide de chacune de ces fonctions pour avoir une meilleure idée de ce qu'ils affichent.