Nach dem, was ich bisher gesammelt habe, gibt es verschiedene Möglichkeiten, ein TensorFlow-Diagramm in eine Datei zu kopieren und dann in ein anderes Programm zu laden, aber ich konnte keine eindeutigen Beispiele / Informationen zu deren Funktionsweise finden. Was ich bereits weiß, ist Folgendes:
- Speichern Sie die Variablen des Modells mit a in einer Prüfpunktdatei (.ckpt)
tf.train.Saver()
und stellen Sie sie später wieder her ( Quelle ). - Speichern Sie ein Modell in einer .pb-Datei und laden Sie es mit
tf.train.write_graph()
undtf.import_graph_def()
( Quelle ) zurück. - Laden Sie ein Modell aus einer .pb-Datei, trainieren Sie es erneut und speichern Sie es mit Bazel ( Quelle ) in einer neuen .pb-Datei.
- Frieren Sie das Diagramm ein, um das Diagramm und die Gewichte zusammen zu speichern ( Quelle )
- Verwenden Sie
as_graph_def()
diese Option , um das Modell zu speichern und für Gewichte / Variablen Konstanten zuzuordnen ( Quelle ).
Ich konnte jedoch einige Fragen zu diesen verschiedenen Methoden nicht klären:
- Speichern Checkpoint-Dateien nur die trainierten Gewichte eines Modells? Könnten Checkpoint-Dateien in ein neues Programm geladen und zum Ausführen des Modells verwendet werden, oder dienen sie einfach dazu, die Gewichte in einem Modell zu einem bestimmten Zeitpunkt / in einem bestimmten Stadium zu speichern?
- Werden
tf.train.write_graph()
auch die Gewichte / Variablen gespeichert? - Kann Bazel nur zur Umschulung in .pb-Dateien gespeichert oder daraus geladen werden? Gibt es einen einfachen Bazel-Befehl, um ein Diagramm in eine .pb-Datei zu kopieren?
- Kann beim Einfrieren ein eingefrorenes Diagramm mit verwendet werden
tf.import_graph_def()
? - Die Android-Demo für TensorFlow wird im Google Inception-Modell aus einer .pb-Datei geladen. Wenn ich meine eigene .pb-Datei ersetzen wollte, wie würde ich das tun? Muss ich nativen Code / native Methoden ändern?
- Was genau ist im Allgemeinen der Unterschied zwischen all diesen Methoden? Oder allgemeiner, was ist der Unterschied zwischen
as_graph_def()
/.ckpt/.pb?
Kurz gesagt, ich suche nach einer Methode, um sowohl ein Diagramm (wie in den verschiedenen Operationen und dergleichen) als auch seine Gewichte / Variablen in einer Datei zu speichern, die dann zum Laden des Diagramms und der Gewichte in ein anderes Programm verwendet werden kann , zur Verwendung (nicht unbedingt Fortsetzung / Umschulung).
Die Dokumentation zu diesem Thema ist nicht sehr einfach, daher wären Antworten / Informationen sehr willkommen.
quelle
Antworten:
Es gibt viele Möglichkeiten, das Problem des Speicherns eines Modells in TensorFlow anzugehen, was es etwas verwirrend machen kann. Nehmen Sie jede Ihrer Unterfragen der Reihe nach:
Die Prüfpunktdateien (die beispielsweise durch Aufrufen
saver.save()
einestf.train.Saver
Objekts erstellt wurden) enthalten nur die Gewichte und alle anderen Variablen, die im selben Programm definiert sind. Um sie in einem anderen Programm zu verwenden, müssen Sie die zugehörige Diagrammstruktur neu erstellen (z. B. indem Sie Code ausführen, um sie erneut zu erstellen, oder aufrufentf.import_graph_def()
), um TensorFlow mitzuteilen, was mit diesen Gewichten zu tun ist. Beachten Sie, dass beim Aufrufensaver.save()
auch eine Datei mit a erstellt wirdMetaGraphDef
, die ein Diagramm und Details zum Zuordnen der Gewichte von einem Prüfpunkt zu diesem Diagramm enthält. Weitere Informationen finden Sie im Tutorial .tf.train.write_graph()
schreibt nur die Graphstruktur; nicht die Gewichte.Bazel hat nichts mit dem Lesen oder Schreiben von TensorFlow-Diagrammen zu tun. (Vielleicht verstehe ich Ihre Frage falsch: Sie können sie gerne in einem Kommentar klarstellen.)
Ein eingefrorenes Diagramm kann mit geladen werden
tf.import_graph_def()
. In diesem Fall sind die Gewichte (normalerweise) in das Diagramm eingebettet, sodass Sie keinen separaten Prüfpunkt laden müssen.Die Hauptänderung besteht darin, die Namen der Tensoren, die in das Modell eingespeist werden, und die Namen der Tensoren, die aus dem Modell abgerufen werden, zu aktualisieren. In der TensorFlow-Android-Demo entspricht dies den
inputName
undoutputName
Zeichenfolgen, an die übergeben wirdTensorFlowClassifier.initializeTensorFlow()
.Dies
GraphDef
ist die Programmstruktur, die sich normalerweise während des Trainingsprozesses nicht ändert. Der Checkpoint ist eine Momentaufnahme des Status eines Trainingsprozesses, der sich normalerweise bei jedem Schritt des Trainingsprozesses ändert. Infolgedessen verwendet TensorFlow unterschiedliche Speicherformate für diese Datentypen, und die Low-Level-API bietet verschiedene Möglichkeiten zum Speichern und Laden. Übergeordnete Bibliotheken wieMetaGraphDef
Bibliotheken, Keras und Skflow bauen auf diesen Mechanismen auf und bieten bequemere Möglichkeiten zum Speichern und Wiederherstellen eines gesamten Modells.quelle
tf.train.write_graph()
und dann ausführen können?GraphDef
gespeicherten vontf.train.write_graph()
auch die Namen der Tensoren merken müssen, die Sie beim Ausführen des Diagramms füttern und abrufen möchten (Punkt 5 oben).Sie können den folgenden Code ausprobieren:
quelle