Ich habe an einem Regressionsproblem gearbeitet, bei dem die Eingabe ein Bild und die Bezeichnung ein kontinuierlicher Wert zwischen 80 und 350 ist. Bei den Bildern handelt es sich um einige Chemikalien, nachdem eine Reaktion stattgefunden hat. Die Farbe, die angezeigt wird, gibt die Konzentration einer anderen Chemikalie an, die übrig bleibt, und das ist, was das Modell ausgeben soll - die Konzentration dieser Chemikalie. Die Bilder können gedreht, gespiegelt und gespiegelt werden, und die erwartete Ausgabe sollte immer noch dieselbe sein. Diese Art der Analyse wird in realen Labors durchgeführt (sehr spezialisierte Maschinen geben die Konzentration der Chemikalien mithilfe der Farbanalyse aus, genau wie ich dieses Modell trainiere).
Bisher habe ich nur mit Modellen experimentiert, die grob auf VGG basieren (mehrere Sequenzen von Conv-Conv-Conv-Pool-Blöcken). Bevor ich mit neueren Architekturen (Inception, ResNets usw.) experimentierte, dachte ich, ich würde nachforschen, ob es andere Architekturen gibt, die häufiger für die Regression mithilfe von Bildern verwendet werden.
Der Datensatz sieht folgendermaßen aus:
Der Datensatz enthält ungefähr 5.000 250x250 Samples, die ich auf 64x64 skaliert habe, um das Training zu vereinfachen. Sobald ich eine vielversprechende Architektur gefunden habe, experimentiere ich mit Bildern mit größerer Auflösung.
Bisher haben meine besten Modelle einen mittleren Fehlerquadrat für Trainings- und Validierungssätze von etwa 0,3, was in meinem Anwendungsfall alles andere als akzeptabel ist.
Mein bisher bestes Modell sieht so aus:
// pseudo code
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])
x = dropout()->conv2d(x, filters=128, kernel=[1, 1])->batch_norm()->relu()
x = dropout()->conv2d(x, filters=32, kernel=[1, 1])->batch_norm()->relu()
y = dense(x, units=1)
// loss = mean_squared_error(y, labels)
Frage
Was ist eine geeignete Architektur für die Regressionsausgabe aus einer Bildeingabe?
Bearbeiten
Ich habe meine Erklärung umformuliert und Erwähnungen der Genauigkeit entfernt.
Bearbeiten 2
Ich habe meine Frage so umstrukturiert, dass hoffentlich klar ist, wonach ich suche
quelle
Antworten:
Zunächst ein allgemeiner Vorschlag: Führen Sie eine Literaturrecherche durch, bevor Sie Experimente zu einem Thema durchführen, mit dem Sie nicht vertraut sind. Sie sparen sich viel Zeit.
In diesem Fall haben Sie bei der Betrachtung vorhandener Papiere möglicherweise bemerkt, dass
Die Regression mit CNNs ist kein triviales Problem. Wenn Sie sich das erste Papier noch einmal ansehen, werden Sie feststellen, dass es ein Problem gibt, bei dem im Grunde unbegrenzte Daten generiert werden können. Ihr Ziel ist es, den Drehwinkel vorherzusagen, der zum Korrigieren von 2D-Bildern erforderlich ist. Das bedeutet, dass ich mein Trainingsset im Grunde genommen durch Drehen jedes Bildes um beliebige Winkel erweitern und ein gültiges, größeres Trainingsset erhalten kann. Daher scheint das Problem relativ einfach zu sein, was Deep-Learning-Probleme angeht. Übrigens, beachten Sie auch die anderen von ihnen verwendeten Tricks zur Datenerweiterung:
Bei einem viel einfacheren (gedrehten MNIST-) Problem können Sie etwas Besseres erzielen , aber Sie unterschreiten dennoch nicht einen RMSE-Fehler, der des maximal möglichen Fehlers beträgt.2.6%
Was können wir daraus lernen? Zuallererst sind diese 5000 Bilder ein kleiner Datensatz für Ihre Aufgabe. Die erste Arbeit verwendete ein Netzwerk, das auf Bildern aufgebaut war, die denen ähnelten, für die sie die Regressionsaufgabe lernen wollten: Sie müssen nicht nur eine andere Aufgabe lernen als die, für die die Architektur entworfen wurde (Klassifizierung), sondern Ihr Trainingsset funktioniert auch nicht Sieht überhaupt nicht so aus wie die Trainingssets, auf denen diese Netzwerke normalerweise trainiert werden (CIFAR-10/100 oder ImageNet). Sie werden also wahrscheinlich keinen Nutzen aus dem Transferlernen ziehen. Das MATLAB-Beispiel hatte 5000 Bilder, aber sie waren schwarz-weiß und semantisch alle sehr ähnlich (na ja, das könnte auch Ihr Fall sein).
Wie realistisch ist es dann besser als 0,3? Zunächst müssen wir verstehen, was Sie unter einem durchschnittlichen Verlust von 0,3 verstehen. Meinen Sie, dass der RMSE-Fehler 0,3 ist,
wobei die Größe Ihres Trainingssatzes ist (also ), die Ausgabe Ihres CNN für image und die entsprechende Konzentration der Chemikalie ist? Seit Sie weniger als Fehler , wenn Sie davon ausgehen, dass Sie die Prognosen Ihres CNN zwischen 80 und 350 (oder einfach ein Logit verwenden, um sie in dieses Intervall einzupassen) . Ernsthaft, was erwartest du? Es scheint mir überhaupt kein großer Fehler zu sein.N N<5000 h(xi) xi yi yi∈[80,350] 0.12%
Versuchen Sie auch einfach, die Anzahl der Parameter in Ihrem Netzwerk zu berechnen: Ich habe es eilig und mache möglicherweise dumme Fehler. Überprüfen Sie meine Berechnungen daher auf jeden Fall noch einmal mit einer
summary
Funktion aus dem von Ihnen verwendeten Framework. Aber ungefähr würde ich sagen, dass Sie haben(Anmerkung: Ich habe die Parameter der Batch-Norm-Layer übersprungen, aber sie sind nur vier Parameter für Layer, damit sie keinen Unterschied machen.) Sie haben eine halbe Million Parameter und 5000 Beispiele ... was würden Sie erwarten? Sicher, die Anzahl der Parameter ist kein guter Indikator für die Kapazität eines neuronalen Netzwerks (es ist ein nicht identifizierbares Modell), aber dennoch ... Ich denke, Sie können nicht viel besser als dies tun, aber Sie können es versuchen wenige Sachen:
quelle