Encapsulates tensor normalization and owns normalization variables.
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link">
<code>tf_agents.utils.tensor_normalizer.TensorNormalizer(
    tensor_spec, scope='normalize_tensor'
)
</code></pre>
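Example usage, with the `StreamingTensorNormalizer` subclass:

```python
import tensorflow as tf
from tf_agents.utils.tensor_normalizer import StreamingTensorNormalizer

tensor_normalizer = StreamingTensorNormalizer(
    tf.TensorSpec([], tf.float32))
# Batches of float32 scalars, with the batch dimension first.
observation_list = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0, 5.0])]
normalized_list = []
for o in observation_list:
  # Normalize with the statistics gathered so far, then fold `o` into them.
  normalized_list.append(tensor_normalizer.normalize(o))
  tensor_normalizer.update(o)
```

For float64 inputs:

```python
tensor_normalizer = StreamingTensorNormalizer(
    tf.TensorSpec([], tf.float64), dtype=tf.float64)
observation_list = [tf.constant([1.0, 2.0], dtype=tf.float64)]
normalized_list = []
for o in observation_list:
  normalized_list.append(tensor_normalizer.normalize(o))
  tensor_normalizer.update(o)
```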
<!-- Tabular view -->
<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2"><h2 class="add-link">Args</h2></th></tr>
<tr>
<td>
`tensor_spec`<a id="tensor_spec"></a>
</td>
<td>
A `tf.TensorSpec`, or a nested structure of specs, describing the tensors to normalize.
</td>
</tr><tr>
<td>
`scope`<a id="scope"></a>
</td>
<td>
Scope for the `tf.Module`.
</td>
</tr>
</table>
## Methods
<h3 id="map_dtype"><code>map_dtype</code></h3>
<a target="_blank" class="external" href="https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/utils/tensor_normalizer.py#L89-L91">View source</a>
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link">
<code>map_dtype(
dtype
)
</code></pre>
<h3 id="normalize"><code>normalize</code></h3>
<a target="_blank" class="external" href="https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/utils/tensor_normalizer.py#L134-L205">View source</a>
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link">
<code>normalize(
tensor, clip_value=5.0, center_mean=True, variance_epsilon=0.001
)
</code></pre>
Applies normalization to `tensor` using the normalizer's current mean and variance estimates.
<!-- Tabular view -->
<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Args</th></tr>
<tr>
<td>
`tensor`
</td>
<td>
Tensor to normalize.
</td>
</tr><tr>
<td>
`clip_value`
</td>
<td>
If `clip_value > 0`, clips the normalized values to `[-clip_value, clip_value]`;
otherwise no clipping is applied.
</td>
</tr><tr>
<td>
`center_mean`
</td>
<td>
If `True`, subtracts the mean estimate before scaling by the standard deviation; if `False`, the tensor is only scaled.
</td>
</tr><tr>
<td>
`variance_epsilon`
</td>
<td>
Small constant added to the variance estimate to avoid division by zero.
</td>
</tr>
</table>
<!-- Tabular view -->
<table class="responsive fixed orange">
<colgroup><col width="214px"><col></colgroup>
<tr><th colspan="2">Returns</th></tr>
<tr>
<td>
`normalized_tensor`
</td>
<td>
Tensor after applying normalization.
</td>
</tr>
</table>
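As a rough guide to what the arguments above control, the following NumPy sketch applies the same kind of transformation to a single array, assuming the normalizer has already produced broadcastable `mean` and `var` estimates. The function `normalize_sketch` is hypothetical and not part of TF-Agents; it mirrors the documented behavior (optional centering, `variance_epsilon` added to the variance, symmetric clipping when `clip_value > 0`) rather than reproducing the library's code.

```python
import numpy as np


def normalize_sketch(x, mean, var, clip_value=5.0, center_mean=True,
                     variance_epsilon=1e-3):
  """Hypothetical stand-in for `normalize` on a single NumPy array."""
  x = np.asarray(x, dtype=np.float64)
  # Subtract the mean estimate only when centering is requested.
  shift = mean if center_mean else 0.0
  # Scale by the standard deviation; epsilon guards against var == 0.
  normalized = (x - shift) / np.sqrt(var + variance_epsilon)
  # Clip symmetrically only when clip_value is positive.
  if clip_value > 0:
    normalized = np.clip(normalized, -clip_value, clip_value)
  return normalized


# With mean 2.0 and variance 4.0: 2.0 maps to 0, 4.0 to roughly 1,
# and 100.0 is clipped to 5.0.
print(normalize_sketch([2.0, 4.0, 100.0], mean=2.0, var=4.0))
```

Because epsilon is added inside the square root, a feature whose variance estimate is still zero is scaled by `1 / sqrt(variance_epsilon)` instead of producing `inf`, which is why clipping is useful early in training.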
<h3 id="update"><code>update</code></h3>
<a target="_blank" class="external" href="https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/utils/tensor_normalizer.py#L119-L132">View source</a>
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link">
<code>update(
tensor, outer_dims=(0,)
)
</code></pre>
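Updates the normalizer's statistics variables with the values in `tensor`, treating the dimensions listed in `outer_dims` as batch dimensions to reduce over.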
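To illustrate how a streaming normalizer can fold each batch into running statistics, here is a NumPy sketch that keeps a count, mean and variance and merges every batch with the pairwise (Chan et al.) combination rule, reducing over `outer_dims` as the batch dimensions. The class `RunningMeanVar` is hypothetical; this is one standard way to maintain such statistics and is not claimed to be how `StreamingTensorNormalizer` stores or updates its variables.

```python
import numpy as np


class RunningMeanVar:
  """Hypothetical running mean/variance tracker for a fixed feature shape."""

  def __init__(self, shape=()):
    self.count = 0
    self.mean = np.zeros(shape)
    self.var = np.zeros(shape)  # Population variance of all data seen.

  def update(self, batch, outer_dims=(0,)):
    batch = np.asarray(batch, dtype=np.float64)
    axes = tuple(outer_dims)
    # Samples in this batch = product of the sizes of the outer dimensions.
    n_b = int(np.prod([batch.shape[a] for a in axes]))
    if n_b == 0:
      return
    mean_b = batch.mean(axis=axes)
    var_b = batch.var(axis=axes)
    n_a = self.count
    n = n_a + n_b
    delta = mean_b - self.mean
    # Merge the means, then the sums of squared deviations (var * count).
    self.mean = self.mean + delta * n_b / n
    m2 = self.var * n_a + var_b * n_b + delta**2 * n_a * n_b / n
    self.var = m2 / n
    self.count = n


stats = RunningMeanVar()
stats.update([1.0, 2.0])
stats.update([3.0, 4.0, 5.0])
# Agrees with the statistics of all five values (mean 3.0, variance 2.0)
# up to floating-point rounding.
print(stats.count, stats.mean, stats.var)
```

Merging per-batch moments this way avoids storing past observations and stays numerically stable compared with accumulating raw sums of squares.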