Automagically generate `CompositeTensor` behavior for `cls`.
```python
tfp.experimental.auto_composite_tensor(
    cls=None, omit_kwargs=(), non_identifying_kwargs=(), module_name=None
)
```
`CompositeTensor` objects can pass in and out of `tf.function` and `tf.while_loop`, or serve as part of the signature of a TF saved model.
The contract of `auto_composite_tensor` is that all `__init__` args and kwargs must have corresponding public or private attributes (or properties); for example, an arg `x` may be stored as `self.x` or `self._x`. Each of these attributes is inspected (recursively) to determine whether it is (or contains) `Tensor`s or non-`Tensor` metadata. Nested (`list`, `tuple`, `dict`, etc.) attributes are supported, but must contain either only `Tensor`s (or lists, etc., thereof) or no `Tensor`s. For example:
- `object.attribute = [1., 2., 'abc']  # valid`
- `object.attribute = [tf.constant(1.), [tf.constant(2.)]]  # valid`
- `object.attribute = ['abc', tf.constant(1.)]  # invalid`
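The "all `Tensor`s or no `Tensor`s" rule for nested attributes can be modeled in a few lines of plain Python. This is a simplified sketch, not TFP's implementation: `FakeTensor` is a hypothetical stand-in for `tf.Tensor`, and `classify_attribute` and `_leaves` are illustrative names. It reproduces the three cases listed above.

```python
class FakeTensor:
  """Hypothetical stand-in for tf.Tensor, used only for illustration."""

  def __init__(self, value):
    self.value = value


def _leaves(value):
  """Yields the leaves of nested lists, tuples and dict values."""
  if isinstance(value, (list, tuple)):
    for item in value:
      yield from _leaves(item)
  elif isinstance(value, dict):
    for item in value.values():
      yield from _leaves(item)
  else:
    yield value


def classify_attribute(value):
  """Returns 'tensor' if all leaves are FakeTensors, 'metadata' if none are.

  Raises ValueError when Tensors and non-Tensors are mixed.
  """
  leaves = list(_leaves(value))
  num_tensors = sum(isinstance(leaf, FakeTensor) for leaf in leaves)
  if num_tensors == 0:
    return 'metadata'
  if num_tensors == len(leaves):
    return 'tensor'
  raise ValueError('Attribute mixes Tensors and non-Tensors.')


print(classify_attribute([1., 2., 'abc']))  # metadata
print(classify_attribute([FakeTensor(1.), [FakeTensor(2.)]]))  # tensor
```

Calling `classify_attribute(['abc', FakeTensor(1.)])` raises `ValueError`, mirroring the invalid case.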
All `__init__` args that may be `ResourceVariable`s must also admit `Tensor`s (or else `_convert_variables_to_tensors` must be overridden).
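Why both requirements matter can be sketched in plain Python, assuming (as a simplification) that the object is rebuilt by looking up each `__init__` arg on the instance and calling `__init__` again with variables swapped for tensors. `FakeVariable`, `FakeTensor`, `recover_init_kwargs`, `convert_variables_to_tensors` and `Scale` are all hypothetical names for illustration, not TFP's API.

```python
import inspect


class FakeVariable:
  """Hypothetical stand-in for a tf.Variable (e.g. a ResourceVariable)."""

  def __init__(self, value):
    self.value = value


class FakeTensor:
  """Hypothetical stand-in for a tf.Tensor."""

  def __init__(self, value):
    self.value = value


def recover_init_kwargs(obj):
  """Looks up each __init__ arg `k` as `obj.k`, falling back to `obj._k`."""
  kwargs = {}
  for name in inspect.signature(type(obj).__init__).parameters:
    if name == 'self':
      continue
    if hasattr(obj, name):
      kwargs[name] = getattr(obj, name)
    elif hasattr(obj, '_' + name):
      kwargs[name] = getattr(obj, '_' + name)
    else:
      raise AttributeError(f'No attribute backs __init__ arg {name!r}.')
  return kwargs


def convert_variables_to_tensors(obj):
  """Rebuilds `obj`, replacing every FakeVariable arg with a FakeTensor."""
  kwargs = {
      name: FakeTensor(value.value) if isinstance(value, FakeVariable) else value
      for name, value in recover_init_kwargs(obj).items()
  }
  # The rebuild calls __init__ again, so __init__ must accept a tensor
  # anywhere it accepts a variable.
  return type(obj)(**kwargs)


class Scale:

  def __init__(self, scale, name='scale'):
    self._scale = scale  # private attribute backing the `scale` arg
    self.name = name  # public attribute backing the `name` arg


converted = convert_variables_to_tensors(Scale(FakeVariable(2.0)))
print(type(converted._scale).__name__, converted._scale.value)  # FakeTensor 2.0
```

If `Scale.__init__` stored `scale` under some unrelated attribute name, `recover_init_kwargs` would fail; if it rejected tensors, the rebuild would fail.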
If the attribute is a callable, serialization of the `TypeSpec`, and therefore interoperability with `tf.saved_model`, is not currently supported. As a workaround, callables that do not contain or close over `Tensor`s may be expressed as functors that subclass `AutoCompositeTensor` and are used in place of the original callable arg:
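```python
import tensorflow_probability as tfp

@tfp.experimental.auto_composite_tensor(module_name='my.module')
class F(tfp.experimental.AutoCompositeTensor):

  def __call__(self, *args, **kwargs):
    # `original_callable` is the callable being replaced (placeholder).
    return original_callable(*args, **kwargs)
```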
Callable objects that do contain or close over `Tensor`s should either (1) subclass `AutoCompositeTensor`, with the `Tensor`s passed to the constructor, (2) subclass `CompositeTensor` and implement their own `TypeSpec`, or (3) have a conversion function registered with `type_spec.register_type_spec_from_value_converter`.
If the object has a `_composite_tensor_shape_parameters` field (presumed to be a `tuple` of `str`), then for each field whose name appears in that tuple, the flattening code uses `tf.get_static_value` to try to preserve the field's value as static metadata. Preserving static values can be important for correctly propagating shapes through a loop.
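The selection logic described above can be modeled as follows. This is a plain-Python sketch under stated assumptions, not TFP's flattening code: `Symbolic` stands in for a `Tensor` whose value is unknown at trace time, and `fake_get_static_value` imitates `tf.get_static_value` by returning `None` when no static value exists. Only fields named in the shape-parameter tuple are candidates for static metadata.

```python
class Symbolic:
  """Hypothetical stand-in for a Tensor whose value is unknown when tracing."""


def fake_get_static_value(x):
  """Hypothetical stand-in for tf.get_static_value: None if not static."""
  return None if isinstance(x, Symbolic) else x


def split_fields(fields, shape_parameters, get_static_value):
  """Splits fields into static metadata and dynamic (Tensor) components.

  Only fields named in `shape_parameters` are candidates for static
  metadata, and only if `get_static_value` returns something other than None.
  """
  static, dynamic = {}, {}
  for name, value in fields.items():
    static_value = (
        get_static_value(value) if name in shape_parameters else None)
    if static_value is not None:
      static[name] = static_value
    else:
      dynamic[name] = value
  return static, dynamic


fields = {'batch_shape': (2, 3), 'event_shape': Symbolic(), 'loc': 0.5}
static, dynamic = split_fields(
    fields, ('batch_shape', 'event_shape'), fake_get_static_value)
print(static)  # {'batch_shape': (2, 3)}
print(sorted(dynamic))  # ['event_shape', 'loc']
```

Note that `loc` stays dynamic even though its value is known, because it is not listed as a shape parameter.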
Note that the `Distribution` and `Bijector` base classes provide a default implementation of `_composite_tensor_shape_parameters`, populated by `parameter_properties` annotations.
If the decorated class `A` does not subclass `CompositeTensor`, a new class will be generated, which mixes in `A` and `CompositeTensor`.
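The effect on the class hierarchy can be sketched with Python's three-argument `type()`. This is a simplified model that only reproduces the class naming and bases (it does nothing about `TypeSpec`s); `CompositeTensorStandIn` and `mix_in_composite_tensor` are hypothetical names, and the `_AutoCompositeTensor` suffix follows the `MyClass_AutoCompositeTensor` output shown in the examples on this page.

```python
class CompositeTensorStandIn:
  """Hypothetical stand-in for TF's CompositeTensor base class."""


def mix_in_composite_tensor(cls):
  """Returns `cls` if it already subclasses the base, else a generated mix-in."""
  if issubclass(cls, CompositeTensorStandIn):
    return cls
  return type(cls.__name__ + '_AutoCompositeTensor',
              (cls, CompositeTensorStandIn), {})


@mix_in_composite_tensor
class MyClass(object):
  pass


@mix_in_composite_tensor
class AlreadyComposite(CompositeTensorStandIn):
  pass


print(type(MyClass()).__name__)  # MyClass_AutoCompositeTensor
print(type(AlreadyComposite()).__name__)  # AlreadyComposite
```

Because the decorator rebinds the name `MyClass` to the generated class, code that checks `type(obj) is MyClass` against the original class object would no longer match.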
To avoid this extra class in the class hierarchy, we suggest inheriting from `auto_composite_tensor.AutoCompositeTensor`, which inherits from `CompositeTensor` and implants a trivial `_type_spec` `@property`. The `@auto_composite_tensor` decorator then overwrites this trivial `_type_spec` `@property`. The trivial one is necessary because `_type_spec` is an abstract property of `CompositeTensor`, and `ABCMeta` determines a class's abstract members when the class is created, before the decorator runs: without the trivial `_type_spec` property present, `ABCMeta` will raise an error when the class is instantiated. The user may thus do any of the following:
`AutoCompositeTensor` base class (recommended)

```python
import tensorflow_probability as tfp

@tfp.experimental.auto_composite_tensor
class MyClass(tfp.experimental.AutoCompositeTensor):
  ...

mc = MyClass()
type(mc)
# ==> MyClass
```
No `CompositeTensor` base class (ok, but changes expected types)

```python
import tensorflow_probability as tfp

@tfp.experimental.auto_composite_tensor
class MyClass(object):
  ...

mc = MyClass()
type(mc)
# ==> MyClass_AutoCompositeTensor
```
`CompositeTensor` base class, requiring trivial `_type_spec`

```python
import tensorflow_probability as tfp
from tensorflow.python.framework import composite_tensor

@tfp.experimental.auto_composite_tensor
class MyClass(composite_tensor.CompositeTensor):

  @property
  def _type_spec(self):  # will be overwritten by @auto_composite_tensor
    pass

  ...

mc = MyClass()
type(mc)
# ==> MyClass
```
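The need for the trivial `_type_spec` can be reproduced in plain Python. The sketch below uses a hypothetical `CompositeTensorStandIn` ABC in place of TF's `CompositeTensor`, and a plain attribute assignment in place of the decorator. It shows that `ABCMeta` fixes a class's set of abstract members when the class is created, so installing `_type_spec` afterwards only yields an instantiable class if a concrete placeholder was already there.

```python
import abc


class CompositeTensorStandIn(abc.ABC):
  """Hypothetical stand-in: declares an abstract `_type_spec` property."""

  @property
  @abc.abstractmethod
  def _type_spec(self):
    raise NotImplementedError


class NoTrivialSpec(CompositeTensorStandIn):
  pass


class WithTrivialSpec(CompositeTensorStandIn):

  @property
  def _type_spec(self):  # trivial placeholder, replaced below
    pass


# Imitate a decorator that installs `_type_spec` after class creation.
NoTrivialSpec._type_spec = property(lambda self: 'spec')
WithTrivialSpec._type_spec = property(lambda self: 'spec')

try:
  NoTrivialSpec()
except TypeError as e:
  # ABCMeta still records `_type_spec` as abstract for this class.
  print('TypeError:', e)

print(WithTrivialSpec()._type_spec)  # spec
```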
Full usage example

```python
import tensorflow as tf
import tensorflow_probability as tfp

# `omit_kwargs`: kwarg names to be omitted from the spec.
@tfp.experimental.auto_composite_tensor(omit_kwargs=('name',))
class Adder(tfp.experimental.AutoCompositeTensor):

  def __init__(self, x, y, name=None):
    with tf.name_scope(name or 'Adder') as name:
      self._x = tf.convert_to_tensor(x)
      self._y = tf.convert_to_tensor(y)
      self._name = name

  def xpy(self):
    return self._x + self._y

def body(obj):
  # Returns a 1-tuple to match the structure of `loop_vars`.
  return Adder(obj.xpy(), 1.),

result, = tf.while_loop(
    cond=lambda _: True,
    body=body,
    loop_vars=(Adder(1., 1.),),
    maximum_iterations=3)
result.xpy()  # => 5.
```
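To see why the loop ends at `5.`, here is the same state update in plain Python. `PyAdder` is a hypothetical TensorFlow-free analogue of `Adder`, and an ordinary `for` loop stands in for `tf.while_loop`; since `cond` always returns `True`, `maximum_iterations=3` alone decides the iteration count. Each iteration replaces the state with `Adder(obj.xpy(), 1.)`, so `xpy()` goes 2, 3, 4, 5.

```python
class PyAdder:
  """Hypothetical plain-Python analogue of Adder (no TensorFlow)."""

  def __init__(self, x, y):
    self._x = x
    self._y = y

  def xpy(self):
    return self._x + self._y


def body(obj):
  return PyAdder(obj.xpy(), 1.)


state = PyAdder(1., 1.)
for _ in range(3):  # cond is always True, so maximum_iterations=3 decides
  state = body(state)

print(state.xpy())  # 5.0
```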
| Returns | |
|---|---|
| `composite_tensor_subclass` | A subclass of `cls` and TF `CompositeTensor`. |