Application de l'inférence variationnelle stochastique au mélange bayésien de gaussien


9

J'essaie d'implémenter le modèle de mélange gaussien avec l'inférence variationnelle stochastique, à la suite de cet article .

entrez la description de l'image ici

C'est le pgm du mélange gaussien.

Selon l'article, l'algorithme complet d'inférence variationnelle stochastique est: entrez la description de l'image ici

Et je suis encore très confus de la méthode pour l'adapter à GMM.

Tout d'abord, je pensais que le paramètre variationnel local est juste et que d'autres sont tous des paramètres globaux. Veuillez me corriger si je me trompais. Que signifie l'étape 6 ? Que dois-je faire pour y parvenir?qzas though Xi is replicated by N times

Pourriez-vous s'il vous plaît m'aider avec cela? Merci d'avance!


Cela signifie qu'au lieu d'utiliser l'ensemble de données, échantillonnez un point de données et prétendez que vous avez points de données de la même taille. Dans de nombreux cas, cela équivaudrait à multiplier une attente avec un point de données par . NNN
Daeyoung Lim

@DaeyoungLim Merci pour votre réponse! J'ai compris ce que vous voulez dire maintenant, mais je ne comprends toujours pas quelles statistiques doivent être mises à jour localement et lesquelles doivent être mises à jour globalement. Par exemple, voici une implémentation du mélange de gaussien, pourriez-vous me dire comment l'adapter à svi? Je suis un peu perdu. Merci beaucoup!
user5779223

Je n'ai pas lu tout le code mais si vous avez affaire à un modèle de mélange gaussien, les variables indicatrices des composants du mélange devraient être les variables locales car chacune d'elles est associée à une seule observation. Les variables latentes des composants du mélange qui suivent la distribution Multinoulli (également connue sous le nom de distribution catégorique en ML) sont dans votre description ci-dessus. zi,i=1,,N
Daeyoung Lim

@DaeyoungLim Oui, je comprends ce que vous avez dit jusqu'à présent. Ainsi, pour la distribution variationnelle q (Z) q (\ pi, \ mu, \ lambda), q (Z) doit être une variable locale. Mais il y a beaucoup de paramètres associés à q (Z). D'autre part, il existe également de nombreux paramètres associés à q (\ pi, \ mu, \ lambda). Et je ne sais pas comment les mettre à jour correctement.
user5779223

Vous devez utiliser l'hypothèse du champ moyen pour obtenir les distributions variationnelles optimales pour les paramètres variationnels. Voici une référence: maths.usyd.edu.au/u/jormerod/JTOpapers/Ormerod10.pdf
Daeyoung Lim

Réponses:



1

Tout d'abord, quelques notes qui m'aident à donner un sens au papier SVI:

  • En calculant la valeur intermédiaire pour le paramètre variationnel des paramètres globaux, nous échantillonnons un point de données et prétendons que l'ensemble de nos données de taille était ce point unique, fois.NNN
  • βηg est le paramètre naturel pour le conditionnel complet de la variable globale . La notation est utilisée pour souligner qu'elle est fonction des variables conditionnées, y compris les données observées. β

Dans un mélange de Gaussiens, nos paramètres globaux sont les paramètres de moyenne et de précision (variance inverse) params pour chacun. Autrement dit, est le paramètre naturel pour cette distribution, un Normal-Gamma de la formeμ k , τ k η gkμk,τkηg

μ,τN(μ|γ,τ(2α1)Ga(τ|α,β)

avec , et . (Bernardo et Smith, théorie bayésienne ; notez que cela varie un peu par rapport au gamma normal à quatre paramètres que vous verrez généralement .) Nous utiliserons pour faire référence aux paramètres variationnels pourη 1 = γ ( 2 α - 1 ) η 2 = 2 β + γ 2 ( 2 α - 1η0=2α1η1=γ(2α1)η2=2β+γ2(2α1)a,b,mα,β,μ

Le conditionnel complet de est un Normal-Gamma avec des paramètres , , , où est le prieur. (Le là-dedans peut aussi être déroutant; il est logique de commencer par une astuce appliquée à , et se terminant avec une bonne quantité d'algèbre laissée au lecteur.)˙ ημk,τkη˙+Nzn,kNzn,kxNNzn,kxn2η˙zn,kexpln(p))Np(xn|zn,α,β,γ)=NK(p(xn|αk,βk,γk))zn,k

Avec cela, nous pouvons terminer l'étape (5) du pseudocode SVI avec:

ϕn,kexp(ln(π)+Eqln(p(xn|αk,βk,γk))=exp(ln(π)+Eq[μkτk,τ2x,x2μ2τlnτ2)]

La mise à jour des paramètres globaux est plus facile, car chaque paramètre correspond à un décompte des données ou à l'une de ses statistiques suffisantes:

λ^=η˙+Nϕn1,x,x2

Voici à quoi ressemble la probabilité marginale de données sur de nombreuses itérations, lorsqu'elles sont formées sur des données très artificielles et facilement séparables (code ci-dessous). Le premier graphique montre la probabilité avec des paramètres variationnels aléatoires initiaux et itérations; chaque suivant est après la puissance suivante de deux itérations. Dans le code, font référence aux paramètres variationnels pour .a , b , m α , β , μ0a,b,mα,β,μ

entrez la description de l'image ici

entrez la description de l'image ici

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 12 12:49:15 2018

@author: SeanEaster
"""

import numpy as np
from matplotlib import pylab as plt
from scipy.stats import t
from scipy.special import digamma 

# These are priors for mu, alpha and beta

def calc_rho(t, delay=16,forgetting=1.):
    return np.power(t + delay, -forgetting)

m_prior, alpha_prior, beta_prior = 0., 1., 1.
eta_0 = 2 * alpha_prior - 1
eta_1 = m_prior * (2 * alpha_prior - 1)
eta_2 = 2 *  beta_prior + np.power(m_prior, 2.) * (2 * alpha_prior - 1)

k = 3

eta_shape = (k,3)
eta_prior = np.ones(eta_shape)
eta_prior[:,0] = eta_0
eta_prior[:,1] = eta_1
eta_prior[:,2] = eta_2

np.random.seed(123) 
size = 1000
dummy_data = np.concatenate((
        np.random.normal(-1., scale=.25, size=size),
        np.random.normal(0.,  scale=.25,size=size),
        np.random.normal(1., scale=.25, size=size)
        ))
N = len(dummy_data)
S = 1

# randomly init global params
alpha = np.random.gamma(3., scale=1./3., size=k)
m = np.random.normal(scale=1, size=k)
beta = np.random.gamma(3., scale=1./3., size=k)

eta = np.zeros(eta_shape)
eta[:,0] = 2 * alpha - 1
eta[:,1] = m * eta[:,0]
eta[:,2] = 2. * beta + np.power(m, 2.) * eta[:,0]


phi = np.random.dirichlet(np.ones(k) / k, size = dummy_data.shape[0])

nrows, ncols = 4, 5
total_plots = nrows * ncols
total_iters = np.power(2, total_plots - 1)
iter_idx = 0

x = np.linspace(dummy_data.min(), dummy_data.max(), num=200)

while iter_idx < total_iters:

    if np.log2(iter_idx + 1) % 1 == 0:

        alpha = 0.5 * (eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2.) / eta[:,0])
        m = eta[:,1] / eta[:,0]
        idx = int(np.log2(iter_idx + 1)) + 1

        f = plt.subplot(nrows, ncols, idx)
        s = np.zeros(x.shape)
        for _ in range(k):
            y = t.pdf(x, alpha[_], m[_], 2 * beta[_] / (2 * alpha[_] - 1))
            s += y
            plt.plot(x, y)
        plt.plot(x, s)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)

    # randomly sample data point, update parameters
    interm_eta = np.zeros(eta_shape)
    for _ in range(S):
        datum = np.random.choice(dummy_data, 1)

        # mean params for ease of calculating expectations
        alpha = 0.5 * ( eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2) / eta[:,0])
        m = eta[:,1] / eta[:,0]

        exp_mu = m
        exp_tau = alpha / beta 
        exp_tau_m_sq = 1. / (2 * alpha - 1) + np.power(m, 2.) * alpha / beta
        exp_log_tau = digamma(alpha) - np.log(beta)


        like_term = datum * (exp_mu * exp_tau) - np.power(datum, 2.) * exp_tau / 2 \
            - (0.5 * exp_tau_m_sq - 0.5 * exp_log_tau)
        log_phi = np.log(1. / k) + like_term
        phi = np.exp(log_phi)
        phi = phi / phi.sum()

        interm_eta[:, 0] += phi
        interm_eta[:, 1] += phi * datum
        interm_eta[:, 2] += phi * np.power(datum, 2.)

    interm_eta = interm_eta * N / S
    interm_eta += eta_prior

    rho = calc_rho(iter_idx + 1)

    eta = (1 - rho) * eta + rho * interm_eta

    iter_idx += 1
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.