Nachdem Sie ein Modell in Tensorflow trainiert haben:
- Wie speichern Sie das trainierte Modell?
- Wie können Sie dieses gespeicherte Modell später wiederherstellen?
python
tensorflow
machine-learning
model
Matheten
quelle
quelle
Antworten:
Docs
ausführliches und nützliches Tutorial -> https://www.tensorflow.org/guide/saved_model
Keras ausführliche Anleitung zum Speichern von Modellen -> https://www.tensorflow.org/guide/keras/save_and_serialize
Aus den Dokumenten:
speichern
Wiederherstellen
Tensorflow 2
Dies ist noch Beta, daher würde ich vorerst davon abraten. Wenn Sie diesen Weg noch gehen möchten, finden Sie hier die
tf.saved_model
GebrauchsanweisungTensorflow <2
simple_save
Viele gute Antworten, der Vollständigkeit halber füge ich meine 2 Cent hinzu: simple_save . Auch ein eigenständiges Codebeispiel mit der
tf.data.Dataset
API.Python 3; Tensorflow 1.14
Wiederherstellen:
Eigenständiges Beispiel
Ursprünglicher Blog-Beitrag
Der folgende Code generiert zur Demonstration zufällige Daten.
Dataset
und dann seinIterator
. Wir erhalten den vom Iterator erzeugten Tensor, genanntinput_tensor
der als Eingabe für unser Modell dient.input_tensor
: einem GRU-basierten bidirektionalen RNN, gefolgt von einem dichten Klassifikator. Weil warum nicht.softmax_cross_entropy_with_logits
, optimiert mitAdam
. Nach 2 Epochen (mit jeweils 2 Chargen) speichern wir das "trainierte" Modell mittf.saved_model.simple_save
. Wenn Sie den Code unverändert ausführen, wird das Modell in einem Ordner gespeichert, dersimple/
in Ihrem aktuellen Arbeitsverzeichnis aufgerufen wird .tf.saved_model.loader.load
. Wir greifen die Platzhalter und Logs mitgraph.get_tensor_by_name
und denIterator
Initialisierungsvorgang mitgraph.get_operation_by_name
.Code:
Dies wird gedruckt:
quelle
tf.contrib.layers
?[n.name for n in graph2.as_graph_def().node]
. Wie in der Dokumentation angegeben, zielt das einfache Speichern darauf ab, die Interaktion mit dem Tensorflow-Serving zu vereinfachen. Dies ist der Punkt der Argumente. andere Variablen werden jedoch noch wiederhergestellt, da sonst keine Inferenz auftreten würde. Nehmen Sie einfach Ihre interessierenden Variablen wie im Beispiel. Schauen Sie sich die Dokumentation anglobal_step
ein Argument, wenn Sie aufhören und dann versuchen, das Training wieder aufzunehmen, wird es denken, dass Sie ein Schritt eins sind. Es wird zumindest Ihre Tensorboard-Visualisierungen vermasselnIch verbessere meine Antwort, um weitere Details zum Speichern und Wiederherstellen von Modellen hinzuzufügen.
In (und nach) Tensorflow Version 0.11 :
Speichern Sie das Modell:
Stellen Sie das Modell wieder her:
Dieser und einige fortgeschrittenere Anwendungsfälle wurden hier sehr gut erklärt.
Ein kurzes, vollständiges Tutorial zum Speichern und Wiederherstellen von Tensorflow-Modellen
quelle
:0
sie den Namen hinzu?In (und nach) TensorFlow Version 0.11.0RC1 können Sie Ihr Modell direkt durch den Aufruf speichern und wiederherzustellen
tf.train.export_meta_graph
undtf.train.import_meta_graph
nach https://www.tensorflow.org/programmers_guide/meta_graph .Speichern Sie das Modell
Stellen Sie das Modell wieder her
quelle
<built-in function TF_Run> returned a result with an error set
tf.get_variable_scope().reuse_variables()
gefolgt von erhaltenvar = tf.get_variable("varname")
. Dies gibt mir den Fehler: "ValueError: Variable varname existiert nicht oder wurde nicht mit tf.get_variable () erstellt." Warum? Sollte das nicht möglich sein?Für TensorFlow-Version <0.11.0RC1:
Die gespeicherten Prüfpunkte enthalten Werte für die
Variable
s in Ihrem Modell, nicht für das Modell / Diagramm selbst. Dies bedeutet, dass das Diagramm beim Wiederherstellen des Prüfpunkts identisch sein sollte.Hier ist ein Beispiel für eine lineare Regression, bei der es eine Trainingsschleife gibt, in der Variablenprüfpunkte gespeichert werden, und einen Bewertungsabschnitt, in dem in einem vorherigen Lauf gespeicherte Variablen wiederhergestellt und Vorhersagen berechnet werden. Natürlich können Sie auch Variablen wiederherstellen und das Training fortsetzen, wenn Sie möchten.
Hier sind die Dokumente für
Variable
s, die das Speichern und Wiederherstellen abdecken. Und hier sind die Dokumente für dieSaver
.quelle
batch_x
sein? Binär? Numpy Array?undefined
. Können Sie mir sagen, welche FLAGS für diesen Code def ist? @ RyanSepassiMeine Umgebung: Python 3.6, Tensorflow 1.3.0
Obwohl es viele Lösungen gab, basieren die meisten auf diesen
tf.train.Saver
. Wenn wir eine Last.ckpt
von gespeichertSaver
, haben wir das tensorflow Netzwerk entweder neu oder einige seltsame und schwer erinnerten Namen verwenden, zB'placehold_0:0'
,'dense/Adam/Weight:0'
. Hier empfehle ich die Verwendungtf.saved_model
eines einfachen Beispiels, mit dem Sie mehr über das Servieren eines TensorFlow-Modells lernen können :Speichern Sie das Modell:
Laden Sie das Modell:
quelle
Das Modell besteht aus zwei Teilen: der Modelldefinition, die
Supervisor
wiegraph.pbtxt
im Modellverzeichnis gespeichert ist, und den numerischen Werten der Tensoren, die in Prüfpunktdateien wie gespeichert werdenmodel.ckpt-1003418
.Die Modelldefinition kann mithilfe von wiederhergestellt werden
tf.import_graph_def
, und die Gewichte werden mithilfe von wiederhergestelltSaver
.Verwendet jedoch
Saver
eine spezielle Sammlungsliste mit Variablen, die an das Modelldiagramm angehängt sind, und diese Sammlung wird nicht mit import_graph_def initialisiert, sodass Sie die beiden derzeit nicht zusammen verwenden können (dies ist in unserer Roadmap zu beheben). Im Moment müssen Sie den Ansatz von Ryan Sepassi verwenden - manuell ein Diagramm mit identischen Knotennamen erstellen undSaver
die Gewichte darin laden.(Alternativ können Sie es hacken, indem Sie verwenden
import_graph_def
, Variablen manuell erstellen undtf.add_to_collection(tf.GraphKeys.VARIABLES, variable)
für jede Variable verwenden und dann verwenden.Saver
)quelle
Sie können diesen Weg auch einfacher nehmen.
Schritt 1: Initialisieren Sie alle Ihre Variablen
Schritt 2: Speichern Sie die Sitzung im Modell
Saver
und speichern Sie sieSchritt 3: Stellen Sie das Modell wieder her
Schritt 4: Überprüfen Sie Ihre Variable
Verwenden Sie, während Sie in einer anderen Python-Instanz ausgeführt werden
quelle
In den meisten Fällen ist das Speichern und Wiederherstellen von der Festplatte mit a
tf.train.Saver
die beste Option:Sie können auch die Diagrammstruktur selbst speichern / wiederherstellen ( Einzelheiten finden Sie in der MetaGraph-Dokumentation ). Standardmäßig
Saver
speichert das die Diagrammstruktur in einer.meta
Datei. Sie können anrufenimport_meta_graph()
, um es wiederherzustellen. Es stellt die Diagrammstruktur wieder her und gibt eine zurückSaver
, mit der Sie den Status des Modells wiederherstellen können:Es gibt jedoch Fälle, in denen Sie etwas viel schneller benötigen. Wenn Sie beispielsweise ein vorzeitiges Anhalten implementieren, möchten Sie jedes Mal, wenn sich das Modell während des Trainings verbessert (gemessen am Validierungssatz), Prüfpunkte speichern. Wenn für einige Zeit keine Fortschritte erzielt werden, möchten Sie zum besten Modell zurückkehren. Wenn Sie das Modell bei jeder Verbesserung auf der Festplatte speichern, wird das Training erheblich verlangsamt. Der Trick besteht darin, die variablen Zustände im Speicher zu speichern und sie später wiederherzustellen:
Eine kurze Erklärung: Wenn Sie eine Variable erstellen
X
, erstellt TensorFlow automatisch eine ZuweisungsoperationX/Assign
, um den Anfangswert der Variablen festzulegen. Anstatt Platzhalter und zusätzliche Zuweisungsoperationen zu erstellen (was das Diagramm nur unübersichtlich machen würde), verwenden wir nur diese vorhandenen Zuweisungsoperationen. Die erste Eingabe jeder Zuweisung op ist eine Referenz auf die Variable, die initialisiert werden soll, und die zweite Eingabe (assign_op.inputs[1]
) ist der Anfangswert. Um also einen beliebigen Wert festzulegen (anstelle des Anfangswertes), müssen wir a verwendenfeed_dict
und den Anfangswert ersetzen. Ja, mit TensorFlow können Sie einen Wert für jede Operation eingeben, nicht nur für Platzhalter. Dies funktioniert also einwandfrei.quelle
Wie Jaroslaw sagte, können Sie die Wiederherstellung von einem graph_def und einem Checkpoint hacken, indem Sie das Diagramm importieren, manuell Variablen erstellen und dann einen Sparer verwenden.
Ich habe dies für meinen persönlichen Gebrauch implementiert, daher würde ich den Code hier teilen.
Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(Dies ist natürlich ein Hack, und es gibt keine Garantie dafür, dass auf diese Weise gespeicherte Modelle in zukünftigen Versionen von TensorFlow lesbar bleiben.)
quelle
Wenn es sich um ein intern gespeichertes Modell handelt, geben Sie einfach einen Wiederhersteller für alle Variablen als an
und verwenden Sie es, um Variablen in einer aktuellen Sitzung wiederherzustellen:
Für das externe Modell müssen Sie die Zuordnung von den Variablennamen zu Ihren Variablennamen angeben. Sie können die Namen der Modellvariablen mit dem Befehl anzeigen
Das Skript inspect_checkpoint.py befindet sich im Ordner './tensorflow/python/tools' der Tensorflow-Quelle.
Um die Zuordnung festzulegen, können Sie mein Tensorflow-Worklab verwenden , das eine Reihe von Klassen und Skripten enthält, um verschiedene Modelle zu trainieren und neu zu trainieren. Es enthält ein Beispiel für die Umschulung von ResNet-Modellen, das sich hier befindet
quelle
all_variables()
ist jetzt veraltetHier ist meine einfache Lösung für die beiden grundlegenden Fälle, die sich darin unterscheiden, ob Sie das Diagramm aus einer Datei laden oder zur Laufzeit erstellen möchten.
Diese Antwort gilt für Tensorflow 0.12+ (einschließlich 1.0).
Neuaufbau des Diagramms im Code
Sparen
Wird geladen
Laden Sie auch das Diagramm aus einer Datei
Stellen Sie bei Verwendung dieser Technik sicher, dass alle Ihre Ebenen / Variablen explizit eindeutige Namen festgelegt haben.Andernfalls macht Tensorflow die Namen selbst eindeutig und unterscheidet sich somit von den in der Datei gespeicherten Namen. Bei der vorherigen Technik ist dies kein Problem, da die Namen beim Laden und Speichern auf die gleiche Weise "entstellt" werden.
Sparen
Wird geladen
quelle
global_step
Variablen und die gleitenden Durchschnitte der Chargennormalisierung nicht trainierbare Variablen, aber beide sind definitiv eine Speicherung wert. Außerdem sollten Sie die Konstruktion des Diagramms klarer von der Ausführung der Sitzung unterscheiden. BeispielsweiseSaver(...).save()
werden bei jeder Ausführung neue Knoten erstellt. Wahrscheinlich nicht das, was du willst. Und es gibt noch mehr ...: /Sie können sich auch Beispiele in TensorFlow / skflow ansehen , das Angebote
save
undrestore
Methoden bietet, mit denen Sie Ihre Modelle einfach verwalten können. Es verfügt über Parameter, mit denen Sie auch steuern können, wie oft Sie Ihr Modell sichern möchten.quelle
Wenn Sie tf.train.MonitoredTrainingSession als Standardsitzung verwenden, müssen Sie keinen zusätzlichen Code hinzufügen, um Dinge zu speichern / wiederherzustellen. Übergeben Sie einfach einen Checkpoint-Verzeichnisnamen an den Konstruktor von MonitoredTrainingSession. Er verwendet Sitzungs-Hooks, um diese zu verarbeiten.
quelle
Alle Antworten hier sind großartig, aber ich möchte zwei Dinge hinzufügen.
Um die Antwort von @ user7505159 näher zu erläutern, kann es wichtig sein, das "./" am Anfang des Dateinamens hinzuzufügen, den Sie wiederherstellen.
Sie können beispielsweise ein Diagramm ohne "./" im Dateinamen wie folgt speichern:
Um das Diagramm wiederherzustellen, müssen Sie dem Dateinamen möglicherweise ein "./" voranstellen:
Sie benötigen nicht immer das "./", es kann jedoch je nach Umgebung und Version von TensorFlow Probleme verursachen.
Es möchte auch erwähnt werden, dass
sess.run(tf.global_variables_initializer())
dies wichtig sein kann, bevor die Sitzung wiederhergestellt wird.Wenn beim Versuch, eine gespeicherte Sitzung wiederherzustellen, eine Fehlermeldung bezüglich nicht initialisierter Variablen angezeigt wird, stellen Sie sicher, dass Sie diese
sess.run(tf.global_variables_initializer())
vor dersaver.restore(sess, save_file)
Zeile einfügen. Es kann Ihnen Kopfschmerzen ersparen.quelle
Wie in Ausgabe 6255 beschrieben :
anstatt
quelle
Laut der neuen Tensorflow-Version
tf.train.Checkpoint
ist dies die bevorzugte Methode zum Speichern und Wiederherstellen eines Modells:Hier ist ein Beispiel:
Weitere Informationen und Beispiele hier.
quelle
Für Tensorflow 2.0 ist es so einfach wie
Etwas wiederherstellen:
quelle
tf.keras Modell speichern mit
TF2.0
Ich sehe gute Antworten zum Speichern von Modellen mit TF1.x. Ich möchte ein paar weitere Hinweise zum Speichern von
tensorflow.keras
Modellen geben, was etwas kompliziert ist, da es viele Möglichkeiten gibt, ein Modell zu speichern.Hier gebe ich ein Beispiel für das Speichern eines
tensorflow.keras
Modells in einemmodel_path
Ordner im aktuellen Verzeichnis. Dies funktioniert gut mit dem neuesten Tensorflow (TF2.0). Ich werde diese Beschreibung aktualisieren, wenn sich in naher Zukunft etwas ändert.Speichern und Laden des gesamten Modells
Modell speichern und laden Nur Gewichte
Wenn Sie nur Modellgewichte speichern und dann Gewichte laden möchten, um das Modell wiederherzustellen, dann
Speichern und Wiederherstellen mit Keras Checkpoint Callback
Modell mit benutzerdefinierten Metriken speichern
Speichern des Keras-Modells mit benutzerdefinierten Operationen
Wenn wir benutzerdefinierte Operationen wie im folgenden Fall (
tf.tile
) haben, müssen wir eine Funktion erstellen und mit einer Lambda-Ebene umbrechen. Andernfalls kann das Modell nicht gespeichert werden.Ich glaube, ich habe einige der vielen Möglichkeiten zum Speichern des tf.keras-Modells behandelt. Es gibt jedoch viele andere Möglichkeiten. Bitte kommentieren Sie unten, wenn Sie sehen, dass Ihr Anwendungsfall oben nicht behandelt wird. Vielen Dank!
quelle
Verwenden Sie tf.train.Saver, um ein Modell zu speichern. Remerber, Sie müssen die var_list angeben, wenn Sie die Modellgröße reduzieren möchten. Die val_list kann tf.trainable_variables oder tf.global_variables sein.
quelle
Sie können die Variablen im Netzwerk mit speichern
Um Wiederherstellung des Netzes für die Wiederverwendung später oder in einem anderen Skript verwenden:
Wichtige Punkte:
sess
muss zwischen erstem und späterem Lauf gleich sein (kohärente Struktur).saver.restore
benötigt den Pfad des Ordners der gespeicherten Dateien, keinen einzelnen Dateipfad.quelle
Wo immer Sie das Modell speichern möchten,
Stellen Sie sicher, dass alle Ihre
tf.Variable
Namen haben, da Sie sie möglicherweise später unter Verwendung ihrer Namen wiederherstellen möchten. Und wo Sie vorhersagen wollen,Stellen Sie sicher, dass der Sparer in der entsprechenden Sitzung ausgeführt wird. Denken Sie daran, dass bei Verwendung von
tf.train.latest_checkpoint('./')
nur der neueste Kontrollpunkt verwendet wird.quelle
Ich bin auf Version:
Einfacher Weg ist
Speichern:
Wiederherstellen:
quelle
Für Tensorflow-2.0
es ist sehr einfach.
SPEICHERN
WIEDERHERSTELLEN
quelle
Nach der Antwort von @Vishnuvardhan Janapati gibt es eine weitere Möglichkeit, ein Modell mit benutzerdefinierten Ebenen / Metriken / Verlusten unter TensorFlow 2.0.0 zu speichern und neu zu laden
Auf diese Weise können Sie, sobald Sie solche Codes ausgeführt und Ihr Modell mit
tf.keras.models.save_model
odermodel.save
oderModelCheckpoint
Rückruf gespeichert haben, Ihr Modell erneut laden, ohne präzise benutzerdefinierte Objekte zu benötigenquelle
In der neuen Version von Tensorflow 2.0 ist das Speichern / Laden eines Modells viel einfacher. Aufgrund der Implementierung der Keras-API, einer übergeordneten API für TensorFlow.
So speichern Sie ein Modell: Überprüfen Sie die Dokumentation als Referenz: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model
So laden Sie ein Modell:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model
quelle
Hier ist ein einfaches Beispiel für die Verwendung des Tensorflow 2.0 SavedModel- Formats (das laut Dokumentation das empfohlene Format ist ) für einen einfachen MNIST-Dataset-Klassifikator, bei dem die Keras-Funktions-API verwendet wird, ohne dass zu viel Lust darauf besteht:
Was ist
serving_default
?Dies ist der Name der Signatur def des von Ihnen ausgewählten Tags (in diesem Fall wurde das Standard-
serve
Tag ausgewählt). Außerdem wird hier erläutert, wie Sie die Tags und Signaturen eines Modells mithilfe von findensaved_model_cli
.Haftungsausschluss
Dies ist nur ein einfaches Beispiel, wenn Sie es nur zum Laufen bringen möchten, aber es ist keineswegs eine vollständige Antwort - vielleicht kann ich es in Zukunft aktualisieren. Ich wollte nur ein einfaches Beispiel mit dem geben
SavedModel
in TF 2.0 geben, weil ich nirgendwo eines gesehen habe, auch nicht so einfach.@ Toms Antwort ist ein SavedModel-Beispiel, aber es funktioniert nicht mit Tensorflow 2.0, da es leider einige wichtige Änderungen gibt.
@ Vishnuvardhan Janapatis Antwort lautet TF 2.0, aber nicht für das SavedModel-Format.
quelle