Je parcourais le papier BERT qui utilise GELU (Gaussian Error Linear Unit) qui énonce l'équation comme
qui à son tour est approximativement à
Pourriez-vous simplifier l'équation et expliquer comment elle a été approximée.
Je parcourais le papier BERT qui utilise GELU (Gaussian Error Linear Unit) qui énonce l'équation comme
qui à son tour est approximativement à
Pourriez-vous simplifier l'équation et expliquer comment elle a été approximée.
Réponses:
Nous pouvons étendre la distribution cumulative de , c'est-à-dire , comme suit:
Notez que c'est une définition , pas une équation (ou une relation). Les auteurs ont fourni quelques justifications à cette proposition, par exemple une analogie stochastique , mais mathématiquement, ce n'est qu'une définition.
Voici l'intrigue de GELU:
Pour ce type d'approximations numériques, l'idée clé est de trouver une fonction similaire (principalement basée sur l'expérience), de la paramétrer, puis de l'adapter à un ensemble de points de la fonction d'origine.
Sachant que est très proche de
et la première dérivée de coïncide avec celle de à , qui est , nous procédons à l'ajustement de
(ou avec plus de termes) à un ensemble de points .
J'ai ajusté cette fonction à 20 échantillons entre (en utilisant ce site ), et voici les coefficients:
En définissant , été estimé à . Avec plus d'échantillons d'une gamme plus large (ce site ne permettait que 20), le coefficient sera plus proche de du papier . Enfin, nous obtenons
avec une erreur quadratique moyenne pour .
Notez que si nous n'avions pas utilisé la relation entre les premières dérivées, le terme aurait été inclus dans les paramètres comme suit
ce qui est moins beau (moins analytique, plus numérique)!
Comme suggéré par @BookYourLuck , nous pouvons utiliser la parité des fonctions pour restreindre l'espace des polynômes dans lesquels nous recherchons. C'est-à-dire que est une fonction impaire, c'est-à-dire , et est également une fonction impaire, la fonction polynomiale intérieur doit également être impair (ne doit avoir que des pouvoirs impairs de ) pour avoir
Auparavant, nous avons eu la chance de nous retrouver avec des coefficients (presque) nuls pour les puissances paires et , mais en général, cela pourrait conduire à des approximations de faible qualité qui, par exemple, ont un terme comme qui est annulé par des termes supplémentaires (pairs ou impairs) au lieu d'opter simplement pour .
Une relation similaire existe entre et (sigmoid), qui est proposé dans l'article comme une autre approximation, avec erreur quadratique moyenne pour .
Voici un code Python pour générer des points de données, ajuster les fonctions et calculer les erreurs quadratiques moyennes:
import math
import numpy as np
import scipy.optimize as optimize
def tahn(xs, a):
return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]
def sigmoid(xs, a):
return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]
print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
# .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])
# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))
sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))
# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()
# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()
print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)
if print_points == 1:
print(len(xs))
for x, erf in zip(xs, erfs):
print(x, erf)
Production:
Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05