Wie wird der Generator in einem GAN trainiert?

9

Das Papier über GANs besagt, dass der Diskriminator den folgenden Gradienten zum Trainieren verwendet:

θd1mi=1m[logD(x(i))+log(1D(G(z(i))))]

Die Werte werden abgetastet, durch den Generator geleitet, um Datenabtastwerte zu erzeugen, und dann wird der Diskriminator unter Verwendung der erzeugten Datenabtastwerte rückpropogiert. Sobald der Generator die Daten erzeugt, spielt er keine weitere Rolle beim Training des Diskriminators. Mit anderen Worten, der Generator kann vollständig aus der Metrik entfernt werden, indem Datenproben generiert werden und dann nur mit den Proben gearbeitet wird.z

Ich bin etwas verwirrter darüber, wie der Generator trainiert ist. Es wird der folgende Farbverlauf verwendet:

θg1mi=1m[log(1D(G(z(i))))]

In diesem Fall ist der Diskriminator Teil der Metrik. Es kann nicht wie im vorherigen Fall entfernt werden. Dinge wie kleinste Quadrate oder logarithmische Wahrscheinlichkeit in regulären Unterscheidungsmodellen können leicht unterschieden werden, da sie eine schöne, eng geformte Definition haben. Ich bin jedoch etwas verwirrt darüber, wie Sie zurückpropogieren, wenn die Metrik von einem anderen neuronalen Netzwerk abhängt. Verbinden Sie im Wesentlichen die Ausgänge des Generators mit den Eingängen des Diskriminators und behandeln Sie dann das Ganze wie ein riesiges Netzwerk, in dem die Gewichte im Diskriminatorteil konstant sind?

Phidias
quelle

Antworten:

10

Es ist hilfreich, sich diesen Prozess im Pseudocode vorzustellen. Sei generator(z)eine Funktion, die einen gleichmäßig abgetasteten Rauschvektor nimmt zund einen Vektor mit der gleichen Größe wie der Eingabevektor zurückgibt X; Nennen wir diese Länge d. Sei discriminator(x)eine Funktion, die einen dDimensionsvektor nimmt und eine skalare Wahrscheinlichkeit zurückgibt, xdie zur wahren Datenverteilung gehört. Für das Training:

G_sample = generator(Z)
D_real = discriminator(X)
D_fake = discriminator(G_sample)

D_loss = maximize mean of (log(D_real) + log(1 - D_fake))
G_loss = maximize mean of log(D_fake)

# Only update D(X)'s parameters
D_solver = Optimizer().minimize(D_loss, theta_D)
# Only update G(X)'s parameters
G_solver = Optimizer().minimize(G_loss, theta_G)

# theta_D and theta_G are the weights and biases of D and G respectively
Repeat the above for a number of epochs

Ja, Sie haben Recht, dass wir Generator und Diskriminator im Wesentlichen als ein riesiges Netzwerk für abwechselnde Minibatches betrachten, wenn wir gefälschte Daten verwenden. Die Verlustfunktion des Generators kümmert sich um die Gradienten für diese Hälfte. Wenn Sie dieses Netzwerktraining isoliert betrachten, wird es so trainiert, wie Sie es normalerweise bei einem MLP trainieren würden, wobei sein Eingang der Ausgang der letzten Schicht des Generatornetzwerks ist.

Eine ausführliche Erklärung mit Code in Tensorflow finden Sie hier (unter anderem): http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/

Es sollte leicht zu befolgen sein, wenn Sie sich den Code ansehen.

Tejaskhot
quelle
1
Könnten Sie näher darauf D_losseingehen G_loss? Maximierung über welchen Raum? IIUC, D_realund D_fakesind jeweils eine Charge, also maximieren wir über die Charge?
P i
@Pi Ja, wir maximieren über eine Charge.
Tejaskhot
1

Verbinden Sie im Wesentlichen die Ausgänge des Generators mit den Eingängen des Diskriminators?> Und behandeln Sie dann das Ganze wie ein riesiges Netzwerk, in dem die Gewichte im> Diskriminatorteil konstant sind?

Kurz: Ja. (Ich habe einige Quellen der GAN ausgegraben, um dies zu überprüfen.)

Es gibt auch viel mehr im GAN-Training wie: Sollten wir D und G jedes Mal aktualisieren oder D bei ungeraden Iterationen und G bei geraden und vieles mehr. Es gibt auch ein sehr schönes Papier zu diesem Thema:

"Verbesserte Techniken zum Trainieren von GANs"

Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford und Xi Chen

https://arxiv.org/abs/1606.03498

Liberus
quelle
Könnten Sie bitte Links zu den Quellen angeben, die Sie untersucht haben? Es wäre hilfreich für mich, sie zu lesen.
Vivek Subramanian
0

Kürzlich habe ich eine Sammlung verschiedener GAN-Modelle auf Github Repo hochgeladen. Es basiert auf torch7 und ist sehr einfach zu bedienen. Der Code ist einfach genug, um mit experimentellen Ergebnissen verstanden zu werden. Hoffe das wird helfen

https://github.com/nashory/gans-collection.torch

Nashory
quelle