Comment un modèle de régression logistique simple permet-il d'obtenir une précision de classification de 92% sur le MNIST?


73

Même si toutes les images du jeu de données MNIST sont centrées, avec une échelle similaire et face visible sans rotations, elles présentent une variation importante de l'écriture manuscrite qui me laisse perplexe sur la précision avec laquelle un modèle linéaire atteint une précision de classification aussi élevée.

Dans la mesure où je suis en mesure de visualiser, compte tenu de la variation importante de l'écriture manuscrite, les chiffres doivent être linéairement indissociables dans un espace à 784 dimensions, c'est-à-dire qu'il doit exister une petite limite non linéaire complexe (mais pas très complexe) séparant les différents chiffres. , semblable à l'exemple bien cité de XOR où les classes positives et négatives ne peuvent être séparées par aucun classifieur linéaire. Il me semble déconcertant que la régression logistique multi-classes produise une telle précision avec des caractéristiques entièrement linéaires (aucune caractéristique polynomiale).

À titre d'exemple, étant donné n'importe quel pixel de l'image, différentes variations manuscrites des chiffres 2 et 3 peuvent éclairer ou non ce pixel. Par conséquent, avec un ensemble de poids appris, chaque pixel peut donner à un chiffre l’apparence d’un 2 ou d’un 3 . Seule une combinaison de valeurs de pixels doit permettre de dire si un chiffre est un 2 ou un 3 . Cela est vrai pour la plupart des paires de chiffres. Alors, comment la régression logistique, qui base aveuglément sa décision indépendamment sur toutes les valeurs de pixels (sans aucune dépendance entre pixels), est-elle capable d’atteindre une telle précision?

Je sais que je me trompe quelque part ou que je surestime quelque peu la variation des images. Cependant, ce serait formidable si quelqu'un pouvait m'aider avec une intuition sur la façon dont les chiffres sont «presque» séparables linéairement.


Consultez le manuel Apprentissage statistique avec parcimonie: le lasso et les généralisations 3.3.1 Exemple: chiffres manuscrits web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian

J'ai été curieux: à quel point quelque chose comme un modèle linéaire pénalisé (c'est-à-dire, glmnet) fait-il pour résoudre le problème? Si je me souviens bien, ce que vous signalez, c'est l'exactitude non dénaturée de l'échantillon.
Cliff AB

Réponses:


91

tl; dr Même s’il s’agit d’un jeu de données de classification d’images, cela reste une tâche très facile , pour laquelle on peut facilement trouver un mappage direct des entrées aux prévisions.


Répondre:

C'est une question très intéressante et, grâce à la simplicité de la régression logistique, vous pouvez réellement trouver la réponse.

78478428×28

Notez, encore une fois, que ce sont les poids .

Regardez maintenant l'image ci-dessus et concentrez-vous sur les deux premiers chiffres (c'est-à-dire zéro et un). Les poids bleus signifient que l'intensité de ce pixel contribue beaucoup à cette classe et les valeurs rouges signifient qu'il contribue négativement.

0

1

2378

Grâce à cela, vous pouvez voir que la régression logistique a de très bonnes chances d’obtenir beaucoup d’images, c’est pourquoi elle est si performante.


Le code pour reproduire le chiffre ci-dessus est un peu daté, mais ici vous allez:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

13
2378

13
Bien entendu, il est utile que les échantillons MNIST soient centrés, mis à l'échelle et normalisés par contraste avant que le classificateur ne les voie jamais. Vous n'avez pas à répondre à des questions telles que "que se passe-t-il si le bord du zéro passe réellement au milieu de la boîte?" parce que le pré-processeur a déjà fait beaucoup pour que tous les zéros se ressemblent.
Hobbs

1
@EricDuminil J'ai ajouté un commentaire sur le script avec votre suggestion. Merci beaucoup pour votre contribution! : D
Djib2011

1
@NitishAgarwal, Si vous pensez que cette réponse est la réponse à votre question, pensez à la marquer comme telle.
Sintax le

16
Pour ceux qui s'intéressent à ce type de traitement mais ne sont pas très familiarisés avec ce type de traitement, cette réponse constitue un exemple fantastique et intuitif de la mécanique.
Chrylis -on grève- le
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.