Anwendung der stochastischen Variationsinferenz auf die Bayes'sche Mischung von Gauß'sch

9

Ich versuche , Gaussian Mixture Modell mit stochastischen Variations Inferenz zu implementieren, nach diesem Papier .

Geben Sie hier die Bildbeschreibung ein

Dies ist die pgm der Gaußschen Mischung.

Dem Artikel zufolge ist der vollständige Algorithmus der stochastischen Variationsinferenz: Geben Sie hier die Bildbeschreibung ein

Und ich bin immer noch sehr verwirrt über die Methode, sie auf GMM zu skalieren.

Zuerst dachte ich, der lokale Variationsparameter sei nur und andere sind alle globale Parameter. Bitte korrigieren Sie mich, wenn ich falsch lag. Was bedeutet Schritt 6 ? Was soll ich tun, um dies zu erreichen?qzas though Xi is replicated by N times

Könnten Sie mir bitte dabei helfen? Danke im Voraus!

user5779223
quelle
Anstatt den gesamten Datensatz zu verwenden, sollten Sie einen Datenpunkt abtasten und so tun, als hätten Sie Datenpunkte gleicher Größe. In vielen Fällen wird dies gleichbedeutend mit einer Erwartung eines Datenpunkt durch Multiplizieren N . NN
Daeyoung Lim
@DaeyoungLim Danke für deine Antwort! Ich habe verstanden, was Sie jetzt meinen, aber ich bin immer noch verwirrt darüber, welche Statistiken lokal und welche global aktualisiert werden sollten. Hier ist zum Beispiel eine Implementierung der Mischung aus Gauß. Können Sie mir sagen, wie man sie auf svi skaliert? Ich bin ein bisschen verloren. Vielen Dank!
user5779223
Ich habe nicht den gesamten Code gelesen, aber wenn Sie es mit einem Gaußschen Mischungsmodell zu tun haben, sollten die Indikatorvariablen der Mischungskomponenten die lokalen Variablen sein, da jede von ihnen nur einer Beobachtung zugeordnet ist. Latente Variablen der Mischungskomponente, die der Multinoulli-Verteilung folgen (auch als kategoriale Verteilung in ML bekannt), sind also in Ihrer obigen Beschreibung. zi,i=1,,N
Daeyoung Lim
@DaeyoungLim Ja, ich verstehe, was du bisher gesagt hast. Für die Variationsverteilung q (Z) q (\ pi, \ mu, \ lambda) sollte q (Z) eine lokale Variable sein. Mit q (Z) sind jedoch viele Parameter verbunden. Andererseits sind auch viele Parameter mit q verbunden (\ pi, \ mu, \ lambda). Und ich weiß nicht, wie ich sie angemessen aktualisieren soll.
user5779223
Sie sollten die Mittelfeldannahme verwenden, um die optimalen Variationsverteilungen für Variationsparameter zu erhalten. Hier ist eine Referenz: maths.usyd.edu.au/u/jormerod/JTOpapers/Ormerod10.pdf
Daeyoung Lim

Antworten:

1

Zunächst einige Anmerkungen, die mir helfen, das SVI-Papier zu verstehen:

  • NN
  • ηgβ

kμk,τkηg

μ,τN(μ|γ,τ(2α1)Ga(τ|α,β)

η0=2α1η1=γ(2α1)η2=2β+γ2(2α1)a,b,mα,β,μ

μk,τkη˙+Nzn,kNzn,kxNN.zn,kxn2η˙zn,kexpln(p))N.p(xn|zn,α,β,γ)=N.K.(p(xn|αk,βk,γk))zn,k

Damit können wir Schritt (5) des SVI-Pseudocodes abschließen mit:

ϕn,kexp(ln(π)+E.qln(p(xn|αk,βk,γk))=exp(ln(π)+E.q[μkτk,- -τ2x,x2- -μ2τ- -lnτ2)]]

Das Aktualisieren der globalen Parameter ist einfacher, da jeder Parameter einer Anzahl der Daten oder einer seiner ausreichenden Statistiken entspricht:

λ^=η˙+N.ϕn1,x,x2

0ein,b,mα,β,μ

Geben Sie hier die Bildbeschreibung ein

Geben Sie hier die Bildbeschreibung ein

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 12 12:49:15 2018

@author: SeanEaster
"""

import numpy as np
from matplotlib import pylab as plt
from scipy.stats import t
from scipy.special import digamma 

# These are priors for mu, alpha and beta

def calc_rho(t, delay=16,forgetting=1.):
    return np.power(t + delay, -forgetting)

m_prior, alpha_prior, beta_prior = 0., 1., 1.
eta_0 = 2 * alpha_prior - 1
eta_1 = m_prior * (2 * alpha_prior - 1)
eta_2 = 2 *  beta_prior + np.power(m_prior, 2.) * (2 * alpha_prior - 1)

k = 3

eta_shape = (k,3)
eta_prior = np.ones(eta_shape)
eta_prior[:,0] = eta_0
eta_prior[:,1] = eta_1
eta_prior[:,2] = eta_2

np.random.seed(123) 
size = 1000
dummy_data = np.concatenate((
        np.random.normal(-1., scale=.25, size=size),
        np.random.normal(0.,  scale=.25,size=size),
        np.random.normal(1., scale=.25, size=size)
        ))
N = len(dummy_data)
S = 1

# randomly init global params
alpha = np.random.gamma(3., scale=1./3., size=k)
m = np.random.normal(scale=1, size=k)
beta = np.random.gamma(3., scale=1./3., size=k)

eta = np.zeros(eta_shape)
eta[:,0] = 2 * alpha - 1
eta[:,1] = m * eta[:,0]
eta[:,2] = 2. * beta + np.power(m, 2.) * eta[:,0]


phi = np.random.dirichlet(np.ones(k) / k, size = dummy_data.shape[0])

nrows, ncols = 4, 5
total_plots = nrows * ncols
total_iters = np.power(2, total_plots - 1)
iter_idx = 0

x = np.linspace(dummy_data.min(), dummy_data.max(), num=200)

while iter_idx < total_iters:

    if np.log2(iter_idx + 1) % 1 == 0:

        alpha = 0.5 * (eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2.) / eta[:,0])
        m = eta[:,1] / eta[:,0]
        idx = int(np.log2(iter_idx + 1)) + 1

        f = plt.subplot(nrows, ncols, idx)
        s = np.zeros(x.shape)
        for _ in range(k):
            y = t.pdf(x, alpha[_], m[_], 2 * beta[_] / (2 * alpha[_] - 1))
            s += y
            plt.plot(x, y)
        plt.plot(x, s)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)

    # randomly sample data point, update parameters
    interm_eta = np.zeros(eta_shape)
    for _ in range(S):
        datum = np.random.choice(dummy_data, 1)

        # mean params for ease of calculating expectations
        alpha = 0.5 * ( eta[:,0] + 1)
        beta = 0.5 * (eta[:,2] - np.power(eta[:,1], 2) / eta[:,0])
        m = eta[:,1] / eta[:,0]

        exp_mu = m
        exp_tau = alpha / beta 
        exp_tau_m_sq = 1. / (2 * alpha - 1) + np.power(m, 2.) * alpha / beta
        exp_log_tau = digamma(alpha) - np.log(beta)


        like_term = datum * (exp_mu * exp_tau) - np.power(datum, 2.) * exp_tau / 2 \
            - (0.5 * exp_tau_m_sq - 0.5 * exp_log_tau)
        log_phi = np.log(1. / k) + like_term
        phi = np.exp(log_phi)
        phi = phi / phi.sum()

        interm_eta[:, 0] += phi
        interm_eta[:, 1] += phi * datum
        interm_eta[:, 2] += phi * np.power(datum, 2.)

    interm_eta = interm_eta * N / S
    interm_eta += eta_prior

    rho = calc_rho(iter_idx + 1)

    eta = (1 - rho) * eta + rho * interm_eta

    iter_idx += 1
Sean Easter
quelle