Wie breiten sich Gradienten in einem nicht abgerollten wiederkehrenden neuronalen Netzwerk aus?

8

Ich versuche zu verstehen, wie rnns verwendet werden können, um Sequenzen anhand eines einfachen Beispiels vorherzusagen. Hier ist mein einfaches Netzwerk, bestehend aus einem Eingang, einem versteckten Neuron und einem Ausgang:

Geben Sie hier die Bildbeschreibung ein

Das versteckte Neuron ist die Sigmoidfunktion, und die Ausgabe wird als einfache lineare Ausgabe angesehen. Ich denke, das Netzwerk funktioniert wie folgt: Wenn die verborgene Einheit im Status startet sund wir einen Datenpunkt verarbeiten, der eine Folge der Länge ist , dann:( x 1 , x 2 , x 3 )3(x1,x2,x3)

Zu einem Zeitpunkt 1, der vorhergesagte Wert, istp1

p1=u×σ(ws+vx1)

Zur Zeit 2haben wir

p2=u×σ(w×σ(ws+vx1)+vx2)

Zur Zeit 3haben wir

p3=u×σ(w×σ(w×σ(ws+vx1)+vx2)+vx3)

So weit, ist es gut?

Das "abgerollte" RNN sieht folgendermaßen aus:

Geben Sie hier die Bildbeschreibung ein

Wenn wir für die Zielfunktion eine Summe der quadratischen Fehlerterme verwenden, wie ist sie dann definiert? Auf die ganze Sequenz? In welchem ​​Fall hätten wir so etwas wie ?E=(p1x1)2+(p2x2)2+(p3x3)2

Werden Gewichte erst aktualisiert, wenn die gesamte Sequenz betrachtet wurde (in diesem Fall die 3-Punkt-Sequenz)?

Was den Gradienten in Bezug auf die Gewichte betrifft, müssen wir berechnen. Ich werde versuchen, dies einfach durch Untersuchen der 3 Gleichungen für oben zu tun , wenn alles andere korrekt aussieht. Abgesehen davon sieht dies für mich nicht nach Vanilla-Back-Propagation aus, da dieselben Parameter in verschiedenen Schichten des Netzwerks angezeigt werden. Wie stellen wir uns darauf ein?dE/dw,dE/dv,dE/dupi

Wenn mir jemand helfen kann, mich durch dieses Spielzeugbeispiel zu führen, wäre ich sehr dankbar.

Fequish
quelle
Ich denke, etwas stimmt nicht mit der Fehlerfunktion, Sie erhalten wahrscheinlich als zweiten Elementterm und Sie müssen es wahrscheinlich mit , im perfekten Fall müssen sie gleich sein. In Ihrer Fehlerfunktion vergleichen Sie einfach die Ein- und Ausgabe des Netzwerks. p1x2
itdxer
Ich dachte, das könnte der Fall sein. Aber wie ist dann der Fehler für das letzte vorhergesagte Element ? p3
Fequish

Antworten:

1

Ich denke, Sie brauchen Zielwerte. Für die Sequenz benötigen Sie also entsprechende Ziele . Da Sie den nächsten Term der ursprünglichen Eingabesequenz vorhersagen möchten, benötigen Sie: (x1,x2,x3)(t1,t2,t3)

t1=x2, t2=x3, t3=x4

Sie müssten definieren. Wenn Sie also eine Eingabesequenz der Länge zum Trainieren des RNN hätten, könnten Sie nur die ersten Terme als Eingabewerte und die letzten Terme als Ziel verwenden Werte.x4NN1N1

Wenn wir für die Zielfunktion eine Summe der quadratischen Fehlerterme verwenden, wie ist sie dann definiert?

Soweit mir bekannt ist, haben Sie Recht - der Fehler ist die Summe über die gesamte Sequenz. Dies liegt daran, dass die Gewichte , und über die entfaltete RNN gleich sind.uvw

Also ist

E=tEt=t(ttpt)2

Werden Gewichte erst aktualisiert, wenn die gesamte Sequenz betrachtet wurde (in diesem Fall die 3-Punkt-Sequenz)?

Ja, wenn ich die Rückausbreitung durch die Zeit verwende, dann glaube ich das.

Was die Differentiale betrifft, möchten Sie nicht den gesamten Ausdruck für und ihn differenzieren, wenn es um größere RNNs geht. Eine Notation kann es also ordentlicher machen:E

  • Es sei die Eingabe in das verborgene Neuron zum Zeitpunkt (dh ).zttz1=ws+vx1
  • Lassen bezeichnet die Ausgabe für den versteckten Neurons zum Zeitpunkt (dh ytty1=σ(ws+vx1))
  • Seiy0=s
  • Seiδt=Ezt

Dann sind die Derivate:

Eu=ytEv=tδtxtEw=tδtyt1

Wobei für eine Folge der Länge gilt und:t[1, T]T

δt=σ(zt)(u+δt+1w)

Diese wiederkehrende Beziehung ergibt sich aus der Erkenntnis, dass die verborgene Aktivität nicht nur den Fehler am -Ausgang , sondern auch den Rest des Fehlers weiter unten im RNN :tthtthEtEEt

Ezt=Etytytzt+(EEt)zt+1zt+1ytytztEzt=ytzt(Etyt+(EEt)zt+1zt+1yt)Ezt=σ(zt)(u+(EEt)zt+1w)δt=Ezt=σ(zt)(u+δt+1w)

Abgesehen davon sieht dies für mich nicht nach Vanilla-Back-Propagation aus, da dieselben Parameter in verschiedenen Schichten des Netzwerks angezeigt werden. Wie stellen wir uns darauf ein?

Diese Methode wird als Back Propagation Through Time (BPTT) bezeichnet und ähnelt der Back Propagation in dem Sinne, dass die Kettenregel wiederholt angewendet wird.

Ein detaillierteres, aber kompliziertes Beispiel für ein RNN finden Sie in Kapitel 3.2 von 'Überwachte Sequenzmarkierung mit wiederkehrenden neuronalen Netzen' von Alex Graves - wirklich interessante Lektüre!

dok
quelle
0

Der oben beschriebene Fehler (nach der Änderung, die ich im Kommentar unter der Frage geschrieben habe) können Sie nur wie einen Gesamtvorhersagefehler verwenden, aber Sie können ihn nicht im Lernprozess verwenden. Bei jeder Iteration geben Sie einen Eingabewert in das Netzwerk ein und erhalten eine Ausgabe. Wenn Sie eine Ausgabe erhalten, müssen Sie Ihr Netzwerkergebnis überprüfen und den Fehler auf alle Gewichte übertragen. Nach dem Update setzen Sie den nächsten Wert in die Reihenfolge und machen eine Vorhersage für diesen Wert, dann verbreiten Sie auch den Fehler und so weiter.

itdxer
quelle