LSTMCell and BasicLSTMCell
Views: 4283
Published: 2019-05-27

This article is about 6737 characters long; reading it takes roughly 22 minutes.

tf.contrib.rnn.BasicLSTMCell

Defined in tensorflow/python/ops/rnn_cell_impl.py.

__init__(
    num_units,
    forget_bias=1.0,
    state_is_tuple=True,
    activation=None,
    reuse=None)
Initialize the basic LSTM cell.

Args:
- *num_units*: int, the number of units in the LSTM cell.
- *forget_bias*: float, the bias added to forget gates (see above). Must be set to 0.0 manually when restoring from CudnnLSTM-trained checkpoints.
- *state_is_tuple*: if True, accepted and returned states are 2-tuples of the c_state and m_state. If False, they are concatenated along the column axis. The latter behavior will soon be deprecated.
- *activation*: activation function of the inner states. Default: tanh.
- *reuse*: (optional) Python boolean describing whether to reuse variables in an existing scope. If not True, and the existing scope already has the given variables, an error is raised.

When restoring from CudnnLSTM-trained checkpoints, CudnnCompatibleLSTMCell must be used instead.
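The state_is_tuple flag only changes how the (c, h) pair is packaged, not the math. A minimal NumPy sketch of the two layouts (the shapes here are made up for illustration):

```python
import numpy as np

batch, num_units = 2, 4
c = np.zeros((batch, num_units))  # cell state
h = np.ones((batch, num_units))   # hidden/output state

# state_is_tuple=True: the state is the pair (c, h)
state_tuple = (c, h)

# state_is_tuple=False (soon to be deprecated): c and h concatenated along axis 1
state_concat = np.concatenate([c, h], axis=1)
print(state_concat.shape)  # (2, 8)
```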

tf.contrib.rnn.LSTMCell

Defined in tensorflow/python/ops/rnn_cell_impl.py.

__init__(
    num_units,
    use_peepholes=False,
    cell_clip=None,
    initializer=None,
    num_proj=None,
    proj_clip=None,
    num_unit_shards=None,
    num_proj_shards=None,
    forget_bias=1.0,
    state_is_tuple=True,
    activation=None,
    reuse=None)
Initialize the parameters for an LSTM cell.

Args:
- *num_units*: int, the number of units in the LSTM cell.
- *use_peepholes*: bool, set True to enable diagonal/peephole connections.
- *cell_clip*: (optional) a float value; if provided, the cell state is clipped by this value prior to the cell output activation.
- *initializer*: (optional) the initializer to use for the weight and projection matrices.
- *num_proj*: (optional) int, the output dimensionality for the projection matrices. If None, no projection is performed.
- *proj_clip*: (optional) a float value. If num_proj > 0 and proj_clip is provided, the projected values are clipped elementwise to within [-proj_clip, proj_clip].
- *num_unit_shards*: deprecated, will be removed by Jan. 2017. Use a variable_scope partitioner instead.
- *num_proj_shards*: deprecated, will be removed by Jan. 2017. Use a variable_scope partitioner instead.
- *forget_bias*: biases of the forget gate are initialized to 1 by default in order to reduce the scale of forgetting at the beginning of training. Must be set manually to 0.0 when restoring from CudnnLSTM-trained checkpoints.
- *state_is_tuple*: if True, accepted and returned states are 2-tuples of the c_state and m_state. If False, they are concatenated along the column axis. This latter behavior will soon be deprecated.
- *activation*: activation function of the inner states. Default: tanh.
- *reuse*: (optional) Python boolean describing whether to reuse variables in an existing scope. If not True, and the existing scope already has the given variables, an error is raised.

When restoring from CudnnLSTM-trained checkpoints, CudnnCompatibleLSTMCell must be used instead.

Differences between LSTMCell and BasicLSTMCell

BasicLSTMCell:

if self._linear is None:
  self._linear = _Linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
    value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

new_c = (
    c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)

if self._state_is_tuple:
  new_state = LSTMStateTuple(new_c, new_h)
else:
  new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
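The gate arithmetic above can be replayed in plain NumPy. This is only a sketch: `W` and `b` stand in for the variables that `_Linear` creates internally, and the shapes are made up for illustration:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def basic_lstm_step(inputs, c, h, W, b, forget_bias=1.0):
    """One BasicLSTMCell step, same math as the TF snippet above."""
    # Equivalent of _Linear([inputs, h], 4 * num_units, True)
    concat = np.concatenate([inputs, h], axis=1) @ W + b
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = np.split(concat, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)

rng = np.random.default_rng(0)
batch, input_dim, num_units = 2, 3, 4
W = rng.normal(size=(input_dim + num_units, 4 * num_units))
b = np.zeros(4 * num_units)
x = rng.normal(size=(batch, input_dim))
c = np.zeros((batch, num_units))
h = np.zeros((batch, num_units))

out, (new_c, new_h) = basic_lstm_step(x, c, h, W, b)
print(out.shape)  # (2, 4)
```

The output equals the new hidden state, and both stay bounded because tanh and sigmoid are bounded.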

LSTMCell:

# i = input_gate, j = new_input, f = forget_gate, o = output_gate
lstm_matrix = self._linear1([inputs, m_prev])
i, j, f, o = array_ops.split(
    value=lstm_matrix, num_or_size_splits=4, axis=1)

# Diagonal connections
if self._use_peepholes and not self._w_f_diag:
  scope = vs.get_variable_scope()
  with vs.variable_scope(
      scope, initializer=self._initializer) as unit_scope:
    with vs.variable_scope(unit_scope):
      self._w_f_diag = vs.get_variable(
          "w_f_diag", shape=[self._num_units], dtype=dtype)
      self._w_i_diag = vs.get_variable(
          "w_i_diag", shape=[self._num_units], dtype=dtype)
      self._w_o_diag = vs.get_variable(
          "w_o_diag", shape=[self._num_units], dtype=dtype)

if self._use_peepholes:
  c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
       sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
else:
  c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
       self._activation(j))

if self._cell_clip is not None:
  # pylint: disable=invalid-unary-operand-type
  c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
  # pylint: enable=invalid-unary-operand-type

if self._use_peepholes:
  m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
else:
  m = sigmoid(o) * self._activation(c)

if self._num_proj is not None:
  if self._linear2 is None:
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer):
      with vs.variable_scope("projection") as proj_scope:
        if self._num_proj_shards is not None:
          proj_scope.set_partitioner(
              partitioned_variables.fixed_size_partitioner(
                  self._num_proj_shards))
        self._linear2 = _Linear(m, self._num_proj, False)
  m = self._linear2(m)

  if self._proj_clip is not None:
    # pylint: disable=invalid-unary-operand-type
    m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
    # pylint: enable=invalid-unary-operand-type

new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
             array_ops.concat([c, m], 1))
return m, new_state
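The peephole branch above can likewise be sketched in NumPy. Here `w_i_diag`, `w_f_diag`, and `w_o_diag` are stand-ins for the diagonal variables the cell creates, and all shapes are made up for illustration:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def peephole_update(i, j, f, o, c_prev, w_i_diag, w_f_diag, w_o_diag,
                    forget_bias=1.0):
    """Cell/output update with peepholes: the gates also see the cell state."""
    c = (sigmoid(f + forget_bias + w_f_diag * c_prev) * c_prev +
         sigmoid(i + w_i_diag * c_prev) * np.tanh(j))
    m = sigmoid(o + w_o_diag * c) * np.tanh(c)
    return c, m

rng = np.random.default_rng(1)
num_units = 4
i, j, f, o = (rng.normal(size=(1, num_units)) for _ in range(4))
c_prev = np.zeros((1, num_units))
w_i_diag, w_f_diag, w_o_diag = (rng.normal(size=num_units) for _ in range(3))

c, m = peephole_update(i, j, f, o, c_prev, w_i_diag, w_f_diag, w_o_diag)
print(c.shape, m.shape)  # (1, 4) (1, 4)
```

Note how each `w_*_diag * c_prev` term is an element-wise (diagonal) product, so each unit only peeks at its own cell value.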

Differences between LSTMCell and BasicLSTMCell:

1. Adds use_peepholes (bool): when True, diagonal (peephole) connections are enabled, letting the gates see the cell state directly.
2. Adds cell_clip (float): the cell state is clipped to within ±cell_clip:

c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 
3. Adds num_proj (int) and proj_clip (float): compared with BasicLSTMCell, a linear projection is applied to the output m after it is computed, and the projected values are clipped:

m = _linear(m, self._num_proj, bias=False, scope=scope)
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
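Both cell_clip and proj_clip are plain element-wise clamps. A NumPy stand-in for clip_by_value:

```python
import numpy as np

def clip_by_value(t, lo, hi):
    """NumPy equivalent of tf.clip_by_value: element-wise clamp to [lo, hi]."""
    return np.minimum(np.maximum(t, lo), hi)

c = np.array([-5.0, -0.5, 0.0, 2.0, 9.0])
cell_clip = 3.0
clipped = clip_by_value(c, -cell_clip, cell_clip)  # values outside [-3, 3] are clamped
```

Clipping the cell state (and the projected output) keeps activations from growing without bound during long sequences.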