Wie liste ich alle verwendeten Operationen in Tensorflow SavedModel auf?

10

Wenn ich mein Modell mit der tensorflow.saved_model.saveFunktion im SavedModel-Format speichere, wie kann ich anschließend abrufen, welche Tensorflow-Operationen in diesem Modell verwendet werden? Da das Modell wiederhergestellt werden kann, werden diese Vorgänge im Diagramm gespeichert. Meine Vermutung liegt in der saved_model.pbDatei. Wenn ich diesen Protobuf lade (also nicht das gesamte Modell), listet der Bibliotheksteil des Protobuf diese auf, aber dies ist vorerst nicht dokumentiert und als experimentelle Funktion gekennzeichnet. In Tensorflow 1.x erstellte Modelle haben diesen Teil nicht.

Was ist also eine schnelle und zuverlässige Methode, um eine Liste der verwendeten Operationen (Like MatchingFilesoder WriteFile) von einem Modell im SavedModel-Format abzurufen?

Im Moment kann ich das Ganze einfrieren, so wie es tensorflowjs-convertertut. Da sie auch nach unterstützten Operationen suchen. Dies funktioniert derzeit nicht, wenn sich ein LSTM im Modell befindet (siehe hier) . Gibt es einen besseren Weg, dies zu tun, da die Ops definitiv da sind?

Ein Beispielmodell:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Erwartet in der Ausgabe alle Operationen, die in diesem Fall mindestens Folgendes enthalten:

  • ReadFilewie hier beschrieben
  • ...
sampers
quelle
1
Es ist schwer genau zu sagen, was Sie wollen, was ist saved_model.pb, ist es eine tf.GraphDefoder eine SavedModelProtobuf-Nachricht? Wenn Sie tf.GraphDefangerufen haben gd, können Sie die Liste der verwendeten Operationen mit abrufen sorted(set(n.op for n in gd.node)). Wenn Sie ein geladenes Modell haben, können Sie dies tun sorted(set(op.type for op in tf.get_default_graph().get_operations())). Wenn es ein ist SavedModel, können Sie das tf.GraphDefdavon bekommen (zB saved_model.meta_graphs[0].graph_def).
Jdehesa
Ich möchte die Operationen von einem gespeicherten SavedModel abrufen. Also in der Tat die letzte Option, die Sie beschreiben. Was ist die saved_modelVariable in Ihrem letzten Beispiel? Das Ergebnis tf.saved_model.load('/path/to/model')oder Laden des Protobufs der Datei saved_model.pb.
Sampers

Antworten:

1

Wenn saved_model.pbes sich um eine SavedModelProtobuf-Nachricht handelt, erhalten Sie die Operationen direkt von dort. Angenommen, wir erstellen ein Modell wie folgt:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Wir können nun die von diesem Modell verwendeten Operationen wie folgt finden:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin
jdehesa
quelle
Ich habe so etwas versucht, aber leider entspricht dies nicht input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')meinen Erwartungen : Angenommen, ich habe ein Modell, das dies tut: Dann ist das hier aufgeführte ReadFile Op dort, wird aber nicht gedruckt.
Sampers
1
@sampers Ich habe die Antwort mit einem Beispiel bearbeitet, wie Sie es vorschlagen. Ich bekomme die ReadFileOperation in der Ausgabe. Ist es möglich, dass in Ihrem tatsächlichen Fall diese Operation nicht zwischen der Eingabe und der Ausgabe des gespeicherten Modells liegt? In diesem Fall denke ich, dass es beschnitten werden könnte.
Jdehesa
In der Tat funktioniert es mit dem gegebenen Modell. Leider für ein Modul in tf2 nicht. Wenn ich ein tf.Module mit 1 Funktion mit einer file_nameArgumentanmerkung erstelle @tf.function, die die Aufrufe enthält, die ich in meinem vorherigen Kommentar aufgeführt habe, wird die folgende Liste Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
angezeigt
hat meiner Frage ein Modell hinzugefügt
sampers
@sampers Ich habe meine Antwort aktualisiert. Ich habe zuvor TF 1.x verwendet. Ich war mit den Änderungen an Grafikdefinitionsobjekten in TF 2.x nicht vertraut. Ich denke, die Antwort deckt jetzt alles im gespeicherten Modell ab. Ich denke, die Operationen, die der Python-Funktion entsprechen, die Sie geschrieben haben, befinden sich in saved_model.meta_graphs[0].graph_def.library.function[0](der node_defSammlung innerhalb dieses Funktionsobjekts).
Jdehesa