import tensorflow as tf
tf.compat.v1.disable_eager_execution()  # placeholders and Sessions need graph mode in TF 2.x

x = tf.compat.v1.placeholder(tf.int32)
print(x.shape)
# ==> TensorShape(None)
y = x * 2
print(y.shape)
# ==> TensorShape(None)
y = tf.ensure_shape(y, (None, 3, 3))
print(y.shape)
# ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])
with tf.compat.v1.Session() as sess:
    # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
    # compatible with the shape (None, 3, 3).
    sess.run(y, feed_dict={x: [1, 2, 3]})
Args
x: A Tensor.
shape: The expected shape of x, given as a TensorShape, a TensorShapeProto,
a list, a tuple, or None.
name: A name for this operation (optional). Defaults to "EnsureShape".
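As a rough illustration of the accepted `shape` forms, the sketch below runs in TF 2.x eager mode and passes the same tensor through `tf.ensure_shape` with a TensorShape, a list, a tuple, and None (the TensorShapeProto form is omitted). It assumes that a None dimension, or `shape=None` (unknown rank), leaves that part of the shape unchecked; the variable names are illustrative only.

```python
import tensorflow as tf

t = tf.zeros([2, 3])  # float32 tensor of shape (2, 3)

# The same check expressed with different forms of `shape`.
a = tf.ensure_shape(t, tf.TensorShape([2, 3]))  # TensorShape, fully specified
b = tf.ensure_shape(t, [None, 3])               # list; None leaves dim 0 unchecked
c = tf.ensure_shape(t, (2, None))               # tuple; None leaves dim 1 unchecked
d = tf.ensure_shape(t, None)                    # unknown rank: any shape is compatible

# Each result has the same dtype and contents as `t`.
for out in (a, b, c, d):
    print(out.shape, out.dtype)
```

Because every form here is compatible with (2, 3), none of the calls raises; a fully specified mismatch such as `[3, 2]` would raise `tf.errors.InvalidArgumentError` instead.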
Returns
A Tensor. Has the same type and contents as x. At runtime, raises a
tf.errors.InvalidArgumentError if shape is incompatible with the shape
of x.
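A TF 2.x counterpart to the placeholder example above might use `tf.function` with an `input_signature` of unknown rank, so that nothing about the input shape is known at trace time. This is a sketch under that assumption, and `double_checked` is a hypothetical name. It shows both effects described under Returns: the traced output gains the refined static shape, and an incompatible input fails at runtime.

```python
import tensorflow as tf

# Accept int32 tensors of any rank, so the static shape is unknown when tracing.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.int32)])
def double_checked(x):
    y = x * 2  # static shape of y is unknown here
    # Refine the static shape to (None, 3) and add a runtime shape check.
    return tf.ensure_shape(y, (None, 3))

# The traced graph's output carries the refined static shape.
traced_shape = double_checked.get_concrete_function().outputs[0].shape
print(traced_shape)

# A compatible input passes through with its values unchanged (after doubling).
result = double_checked(tf.constant([[1, 2, 3]]))
print(result.numpy())

# An incompatible input (shape (3,) vs (None, 3)) fails when the graph runs.
failed = False
try:
    double_checked(tf.constant([1, 2, 3]))
except tf.errors.InvalidArgumentError:
    failed = True
print("runtime check failed:", failed)
```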