0%

对RNN中Output与State的区分

对RNN中Output与State的区分

之所以写这篇文章,主要是在之前实现Seq2Seq模型时发现了一些问题,在这方面有着一些混淆,主要是对LSTM的理解不够

先看基础RNN的结构

BasicRNN
在这里,每一个RNN单元实现的功能是

那么$O_t$是干嘛的呢?
实际上$O_t$的输出并不是由RNN单元完成的,在每次生成$S_t$后,还需要进行一次全连接才能产生$O_t$,即

在Tensorflow中,如果不自加全连接层而使用BasicRNN

1
2
3
4
5
def get_basicRNN_cell(rnn_size):
return tf.nn.rnn_cell.BasicRNNCell(rnn_size)

stack_rnn = tf.nn.rnn_cell.MultiRNNCell([get_basicRNN_cell(rnn_size) for _ in range(num_layers)])
output, state = tf.nn.dynamic_rnn(stack_rnn, input_data, dtype = tf.float32)

那么output[-1]和state是一样的,即output是所有state的集合,而最后一个output则是finalState


LSTM,其整体结构和RNN一样,但是由于内部结构有了变化,以上的情况就不适用了
再看LSTM的输入输出结构
BasicRNN
每个LSTM实际上进行的运算是

在LSTM中,$S_t$由$C_t$和$h_t$两部分组成,通常是两向量拼接而成,但是

即在Tensorflow中,如果不自加全连接层,那么output[-1]应和finalState的一部分相同