Computes the LSTM cell forward propagation for one time step.
This implementation uses a single weight matrix and a single bias vector, with optional peephole connections.
This kernel op implements the following mathematical equations:
xh = [x, h_prev]
[i, f, ci, o] = xh * w + b
f = f + forget_bias
if not use_peephole:
  wci = wcf = wco = 0
i = sigmoid(cs_prev * wci + i)
f = sigmoid(cs_prev * wcf + f)
ci = tanh(ci)
cs = ci .* i + cs_prev .* f
cs = clip(cs, cell_clip)
o = sigmoid(cs * wco + o)
co = tanh(cs)
h = co .* o
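The equations above can be traced with a plain-Java sketch of one forward step for a single example (batch size 1). It is an illustration of the math, not the TensorFlow kernel, and it does not use the TensorFlow Java API. Assumptions: `w` is a row-major `double[numInputs + cellSize][4 * cellSize]` whose gate columns follow the `[i, f, ci, o]` order written above; `wci`, `wcf` and `wco` act element-wise with one weight per cell unit; `clip(cs, cell_clip)` clamps to `[-cellClip, cellClip]`, and a non-positive `cellClip` is treated as "no clipping". The class name `LstmCellSketch` is hypothetical.

```java
/**
 * Sketch of one LSTM cell forward step (batch size 1) following the
 * LSTMBlockCell equations. Illustrative only; not the TensorFlow kernel.
 */
final class LstmCellSketch {

    private LstmCellSketch() {}

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /**
     * Returns {cs, h}.
     * Assumed layout: w[row][col] with rows = numInputs + cellSize and
     * cols = 4 * cellSize, gate blocks ordered [i, f, ci, o].
     * Peephole arrays are only read when usePeephole is true.
     */
    static double[][] step(double[] x, double[] csPrev, double[] hPrev,
                           double[][] w, double[] wci, double[] wcf, double[] wco,
                           double[] b, double forgetBias, double cellClip,
                           boolean usePeephole) {
        int n = csPrev.length;                 // cell size

        // xh = [x, h_prev]
        double[] xh = new double[x.length + n];
        System.arraycopy(x, 0, xh, 0, x.length);
        System.arraycopy(hPrev, 0, xh, x.length, n);

        // [i, f, ci, o] = xh * w + b  (vector-matrix product plus bias)
        double[] z = b.clone();
        for (int r = 0; r < xh.length; r++) {
            for (int col = 0; col < 4 * n; col++) {
                z[col] += xh[r] * w[r][col];
            }
        }

        double[] cs = new double[n];
        double[] h = new double[n];
        for (int j = 0; j < n; j++) {
            // if not use_peephole: wci = wcf = wco = 0
            double pi = usePeephole ? wci[j] : 0.0;
            double pf = usePeephole ? wcf[j] : 0.0;
            double po = usePeephole ? wco[j] : 0.0;

            double iGate = sigmoid(csPrev[j] * pi + z[j]);
            // f = f + forget_bias, then f = sigmoid(cs_prev * wcf + f)
            double fGate = sigmoid(csPrev[j] * pf + z[n + j] + forgetBias);
            double ci = Math.tanh(z[2 * n + j]);

            // cs = ci .* i + cs_prev .* f, then clip
            double c = ci * iGate + csPrev[j] * fGate;
            if (cellClip > 0) {                // assumption: <= 0 disables clipping
                c = Math.max(-cellClip, Math.min(cellClip, c));
            }

            double oGate = sigmoid(c * po + z[3 * n + j]);
            cs[j] = c;
            h[j] = Math.tanh(c) * oGate;       // h = co .* o with co = tanh(cs)
        }
        return new double[][] {cs, h};
    }
}
```

With all weights and biases set to zero, `i = o = 0.5` and `ci = 0`, so the step reduces to `cs = cs_prev .* sigmoid(forget_bias)`, which makes the sketch easy to check by hand.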
Nested Classes
| Kind | Class | Description |
|---|---|---|
| class | LSTMBlockCell.Options | Optional attributes for LSTMBlockCell |
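Public Methods
| Return type | Method | Description |
|---|---|---|
| static LSTMBlockCell.Options | cellClip(Float cellClip) | |
| Output<T> | ci() | The cell input. |
| Output<T> | co() | The cell state after the tanh, co = tanh(cs). |
| static <T extends Number> LSTMBlockCell<T> | create(Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, Options... options) | Factory method to create a class wrapping a new LSTMBlockCell operation. |
| Output<T> | cs() | The cell state before the tanh (after clipping). |
| Output<T> | f() | The forget gate. |
| static LSTMBlockCell.Options | forgetBias(Float forgetBias) | |
| Output<T> | h() | The output h vector. |
| Output<T> | i() | The input gate. |
| Output<T> | o() | The output gate. |
| static LSTMBlockCell.Options | usePeephole(Boolean usePeephole) | |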
Public Methods
public static LSTMBlockCell.Options cellClip (Float cellClip)
Parameters
cellClip | Value to which the cell state 'cs' is clipped (cs = clip(cs, cell_clip)). |
---|---|
public static LSTMBlockCell<T> create (Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, Options... options)
Factory method to create a class wrapping a new LSTMBlockCell operation.
Parameters
scope | current scope |
---|---|
x | The input to the LSTM cell, shape (batch_size, num_inputs). |
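csPrev | Value of the cell state at the previous time step, shape (batch_size, cell_size). |
hPrev | Output of the cell at the previous time step, shape (batch_size, cell_size). |
w | The weight matrix applied to xh = [x, h_prev], shape (num_inputs + cell_size, 4 * cell_size). |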
wci | The weight matrix for input gate peephole connection. |
wcf | The weight matrix for forget gate peephole connection. |
wco | The weight matrix for output gate peephole connection. |
b | The bias vector, shape (4 * cell_size). |
options | carries optional attribute values |
Returns
- a new instance of LSTMBlockCell
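The shapes of `w` and `b` follow from the equations: `xh = [x, h_prev]` has `num_inputs + cell_size` columns, and `xh * w + b` must produce the four gate blocks `[i, f, ci, o]`, each `cell_size` wide. The hypothetical helper below (not part of the TensorFlow Java API) encodes that reasoning so shapes can be checked before calling `create`. The one-dimensional peephole shape is an assumption, based on the peephole terms acting element-wise on `cs_prev` and `cs`.

```java
import java.util.Arrays;

/** Hypothetical shape checks derived from the LSTMBlockCell equations. */
final class LstmShapes {

    private LstmShapes() {}

    /** w maps xh = [x, h_prev] to the four gate blocks [i, f, ci, o]. */
    static long[] weightShape(long numInputs, long cellSize) {
        return new long[] {numInputs + cellSize, 4 * cellSize};
    }

    /** b adds one bias per gate unit. */
    static long[] biasShape(long cellSize) {
        return new long[] {4 * cellSize};
    }

    /** cs_prev and h_prev: one row per batch element, one column per cell unit. */
    static long[] stateShape(long batchSize, long cellSize) {
        return new long[] {batchSize, cellSize};
    }

    /** Assumption: each peephole weight (wci, wcf, wco) has one entry per cell unit. */
    static long[] peepholeShape(long cellSize) {
        return new long[] {cellSize};
    }

    /** Throws IllegalArgumentException when the actual shape differs from the expected one. */
    static void require(String name, long[] actual, long[] expected) {
        if (!Arrays.equals(actual, expected)) {
            throw new IllegalArgumentException(name + " has shape " + Arrays.toString(actual)
                + ", expected " + Arrays.toString(expected));
        }
    }
}
```

For example, with 3 inputs and a cell size of 5, `w` is expected to be 8 x 20 and `b` to have 20 entries.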
public static LSTMBlockCell.Options forgetBias (Float forgetBias)
Parameters
forgetBias | The forget gate bias, added to f before the sigmoid (f = f + forget_bias). |
---|---|
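Because forget_bias is added to f before the sigmoid, it sets the forget gate's value when the rest of its pre-activation (cs_prev * wcf plus the f slice of xh * w + b) is zero, and a larger f keeps more of cs_prev in cs = ci .* i + cs_prev .* f. A minimal sketch of that arithmetic follows; the class and method names are illustrative only.

```java
/** Illustrative only: forget-gate value as a function of forgetBias. */
final class ForgetBiasSketch {

    private ForgetBiasSketch() {}

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /** f = sigmoid(preActivation + forgetBias), per the equations above. */
    static double forgetGate(double preActivation, double forgetBias) {
        return sigmoid(preActivation + forgetBias);
    }

    /** Contribution of cs_prev to cs through the cs_prev .* f term. */
    static double retained(double csPrev, double preActivation, double forgetBias) {
        return csPrev * forgetGate(preActivation, forgetBias);
    }
}
```

With a zero pre-activation, forgetBias = 0 gives f = 0.5, while forgetBias = 1 gives f of about 0.731.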
public static LSTMBlockCell.Options usePeephole (Boolean usePeephole)
Parameters
usePeephole | Whether to use peephole weights. If false, wci, wcf and wco are treated as zero. |
---|---|