Was ist die Ausgabe eines tf.nn.dynamic_rnn ()?

8

Ich bin mir nicht sicher, was ich aus der offiziellen Dokumentation verstehe, in der es heißt:

Rückgabe: Ein Paar (Ausgänge, Status) wobei:

outputs: Der RNN-Ausgangstensor.

Wenn time_major == False(Standard), ist dies eine Tensorform : [batch_size, max_time, cell.output_size].

Wenn time_major == Truedies ein Tensor ist : [max_time, batch_size, cell.output_size].

Wenn cell.output_sizees sich um ein (möglicherweise verschachteltes) Tupel von Ganzzahlen oder TensorShape-Objekten handelt, sind die Ausgaben ein Tupel mit derselben Struktur wie cell.output_size, das Tensoren mit Formen enthält, die den Formdaten in entsprechen cell.output_size.

state: Der Endzustand. Wenn cell.state_size ein int ist, wird dies geformt [batch_size, cell.state_size]. Wenn es sich um eine TensorShape handelt, wird diese geformt [batch_size] + cell.state_size. Wenn es sich um ein (möglicherweise verschachteltes) Tupel von Ints oder TensorShape handelt, handelt es sich um ein Tupel mit den entsprechenden Formen. Wenn Zellen LSTMCells sind, ist der Status ein Tupel, das ein LSTMStateTuple für jede Zelle enthält.

Ist output[-1] immer (in allen drei Zelltypen, dh RNN, GRU, LSTM) gleich state (zweites Element des Rückgabetupels)? Ich denke, die Literatur ist überall zu liberal in der Verwendung des Begriffs versteckter Zustand. Ist der versteckte Zustand in allen drei Zellen die Punktzahl, die herauskommt (warum er als versteckt bezeichnet wird, ist mir ein Rätsel, es scheint, dass der Zellzustand in LSTM als versteckter Zustand bezeichnet werden sollte, da er nicht offengelegt wird)?

MiloMinderbinder
quelle

Antworten:

10

Ja, die Zellenausgabe entspricht dem verborgenen Zustand. Im Fall von LSTM ist es der kurzfristige Teil des Tupels (zweites Element von LSTMStateTuple), wie in diesem Bild zu sehen ist:

LSTM

Aber für tf.nn.dynamic_rnnden zurück kann Zustand unterschiedlich sein , wenn die Sequenz kürzer ist ( sequence_lengthArgument). Schauen Sie sich dieses Beispiel an:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

Hier enthält der Eingabestapel 4 Sequenzen, von denen eine kurz und mit Nullen aufgefüllt ist. Beim Laufen solltest du so etwas machen:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

... was in der Tat zeigt, dass state == output[1]für vollständige Sequenzen und state == output[0]für die kurze. Ist auch output[1]ein Nullvektor für diese Sequenz. Gleiches gilt für LSTM- und GRU-Zellen.

Das stateist also ein praktischer Tensor, der den letzten tatsächlichen RNN-Zustand enthält und die Nullen ignoriert. Der outputTensor enthält die Ausgänge aller Zellen, sodass die Nullen nicht ignoriert werden. Das ist der Grund für die Rückgabe beider.

Maxime
quelle
2

Mögliche Kopie von /programming/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930

Wie auch immer, lassen Sie uns mit der Antwort fortfahren.

Dieser Code-Snip kann helfen, zu verstehen, was wirklich von der dynamic_rnnEbene zurückgegeben wird

=> Tupel von (Ausgaben, final_output_state) .

Also für einen Eingang mit max Sequenzlänge von T Zeitschritten Ausgänge ist von der Form [Batch_size, T, num_inputs](bei time_major= False; Standardwert) , und es enthält den Ausgangszustand bei jedem Zeitschritt h1, h2.....hT.

Und final_output_state hat die Form [Batch_size,num_inputs]und hat den endgültigen Zellen- cTund Ausgabezustand hTjeder Stapelsequenz.

Aber da das verwendet dynamic_rnnwird, schätze ich, dass Ihre Sequenzlängen für jede Charge variieren.

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

Die endgültige Behauptung schlägt fehl, da der endgültige Zustand für die 2. Sequenz im 6. Zeitschritt liegt, d. H. Der Index 5 und die restlichen Ausgaben von [6: 9] sind im 2. Zeitschritt alle Nullen

Bhaskar Arun
quelle