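import tensorflow as tf

# tf.compat.v1.placeholder only works in graph mode; under TF 2.x,
# eager execution has to be disabled first.
tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.int32)
print(x.shape)
# ==> TensorShape(None)
y = x * 2
print(y.shape)
# ==> TensorShape(None)

y = tf.ensure_shape(y, (None, 3, 3))
print(y.shape)
# ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])

with tf.compat.v1.Session() as sess:
    # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
    # compatible with the shape (None, 3, 3)
    sess.run(y, feed_dict={x: [1, 2, 3]})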
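In the example, the static shape of y changes from fully unknown to (None, 3, 3) once tf.ensure_shape is applied. A minimal pure-Python sketch of that refinement follows. It assumes that the static result is the merge of the input's known shape with the requested shape, so that the more specific dimension wins wherever one side is None. The helper merge_shapes is hypothetical and is not TensorFlow code. A shape of None stands for unknown rank.

```python
from typing import Optional, Tuple

Dim = Optional[int]                 # None = unknown dimension size
Shape = Optional[Tuple[Dim, ...]]   # None = unknown rank


def merge_shapes(a: Shape, b: Shape) -> Shape:
    """Hypothetical helper: combine two partially known shapes.

    Keeps the more specific value for each dimension and raises
    ValueError when the shapes cannot describe the same tensor.
    """
    if a is None:   # unknown rank: defer entirely to the other shape
        return b
    if b is None:
        return a
    if len(a) != len(b):
        raise ValueError(f"rank mismatch: {a} vs {b}")
    merged = []
    for da, db in zip(a, b):
        if da is None:
            merged.append(db)
        elif db is None or da == db:
            merged.append(da)
        else:
            raise ValueError(f"dimension mismatch: {a} vs {b}")
    return tuple(merged)


# y = x * 2 has unknown rank; requesting (None, 3, 3) refines it.
print(merge_shapes(None, (None, 3, 3)))        # (None, 3, 3)
# Known dimensions from either side are kept.
print(merge_shapes((None, 3, 3), (4, None, 3)))  # (4, 3, 3)
```

Because only the static shape is refined here, nothing about the actual runtime value is checked at this stage. That check happens when the graph is run.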
Args
x
A Tensor.
shape
The expected shape of x, given as a TensorShape, a TensorShapeProto, a
list, a tuple, or None. A None entry leaves that dimension unconstrained,
as in (None, 3, 3) above.
name
A name for this operation (optional). Defaults to "EnsureShape".
Returns
A Tensor. Has the same type and contents as x. At runtime, raises a
tf.errors.InvalidArgumentError if shape is incompatible with the shape
of x.
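The runtime compatibility rule can be sketched in plain Python. This sketch assumes that the rule matches the usual partial-shape semantics, so the shape is satisfied when the expected rank is unknown, or when the ranks agree and every expected dimension is either None or equal to the actual size. The functions is_compatible and check_shape are hypothetical illustrations. The ValueError merely stands in for tf.errors.InvalidArgumentError.

```python
from typing import Optional, Sequence, Tuple


def is_compatible(actual: Tuple[int, ...],
                  expected: Optional[Sequence[Optional[int]]]) -> bool:
    """Return True if a concrete runtime shape satisfies a partial shape."""
    if expected is None:
        # Unknown rank: any concrete shape is accepted.
        return True
    if len(actual) != len(expected):
        # Ranks must agree exactly.
        return False
    # Each expected dimension is either unconstrained (None) or must match.
    return all(e is None or a == e for a, e in zip(actual, expected))


def check_shape(actual: Tuple[int, ...],
                expected: Optional[Sequence[Optional[int]]]) -> Tuple[int, ...]:
    """Hypothetical runtime check; ValueError stands in for
    tf.errors.InvalidArgumentError."""
    if not is_compatible(actual, expected):
        raise ValueError(
            f"shape {tuple(actual)} is incompatible with {expected}")
    return actual


# Mirrors the Session example: feeding [1, 2, 3] gives shape (3,).
print(is_compatible((3,), (None, 3, 3)))       # False (rank 1 vs rank 3)
print(is_compatible((5, 3, 3), (None, 3, 3)))  # True
print(is_compatible((2, 4), None))             # True
```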