View source on GitHub |
Decorator to override default implementation for unary elementwise APIs.
tf.experimental.dispatch_for_unary_elementwise_apis(
x_type
)
The decorated function (known as the "elementwise API handler") overrides
the default implementation for any unary elementwise API whenever the value
for the first argument (typically named `x`) matches the type annotation
`x_type`. The elementwise API handler is called with two arguments:
elementwise_api_handler(api_func, x)
where `api_func` is a function that takes a single parameter and performs the
elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
elementwise API.
The following example shows how this decorator can be used to update all
unary elementwise operations to handle a `MaskedTensor` type:
from tensorflow.python.framework import extension_type
class MaskedTensor(extension_type.ExtensionType):
values: tf.Tensor
mask: tf.Tensor
@dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_api_handler(api_func, x):
return MaskedTensor(api_func(x.values), x.mask)
mt = MaskedTensor([1, -2, -3], [True, False, True])
abs_mt = tf.abs(mt)
print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
values=[1 2 3], mask=[ True False True]
For unary elementwise operations that take extra arguments beyond `x`, those
arguments are not passed to the elementwise API handler, but are
automatically added when `api_func` is called. E.g., in the following
example, the `dtype` parameter is not passed to
`unary_elementwise_api_handler`, but is added by `api_func`.
ones_mt = tf.ones_like(mt, dtype=tf.float32)
print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
values=[1.0 1.0 1.0], mask=[ True False True]
| Args | |
|---|---|
| `x_type` | A type annotation indicating when the API handler should be called. See `dispatch_for_api` for a list of supported annotation types. |
| Returns | |
|---|---|
| A decorator. | |
Registered APIs
The unary elementwise APIs are:
tf.bitwise.invert(x, name=None)
tf.cast(x, dtype, name=None)
tf.clip_by_value(t, clip_value_min, clip_value_max, name=None)
tf.compat.v1.ones_like(tensor, dtype=None, name=None, optimize=True)
tf.compat.v1.strings.length(input, name=None, unit='BYTE')
tf.compat.v1.strings.substr(input, pos, len, name=None, unit='BYTE')
tf.compat.v1.strings.to_hash_bucket(string_tensor=None, num_buckets=None, name=None, input=None)
tf.compat.v1.substr(input, pos, len, name=None, unit='BYTE')
tf.compat.v1.zeros_like(tensor, dtype=None, name=None, optimize=True)
tf.debugging.check_numerics(tensor, message, name=None)
tf.dtypes.saturate_cast(value, dtype, name=None)
tf.io.decode_base64(input, name=None)
tf.io.decode_compressed(bytes, compression_type='', name=None)
tf.io.encode_base64(input, pad=False, name=None)
tf.math.abs(x, name=None)
tf.math.acos(x, name=None)
tf.math.acosh(x, name=None)
tf.math.angle(input, name=None)
tf.math.asin(x, name=None)
tf.math.asinh(x, name=None)
tf.math.atan(x, name=None)
tf.math.atanh(x, name=None)
tf.math.bessel_i0(x, name=None)
tf.math.bessel_i0e(x, name=None)
tf.math.bessel_i1(x, name=None)
tf.math.bessel_i1e(x, name=None)
tf.math.ceil(x, name=None)
tf.math.conj(x, name=None)
tf.math.cos(x, name=None)
tf.math.cosh(x, name=None)
tf.math.digamma(x, name=None)
tf.math.erf(x, name=None)
tf.math.erfc(x, name=None)
tf.math.erfcinv(x, name=None)
tf.math.erfinv(x, name=None)
tf.math.exp(x, name=None)
tf.math.expm1(x, name=None)
tf.math.floor(x, name=None)
tf.math.imag(input, name=None)
tf.math.is_finite(x, name=None)
tf.math.is_inf(x, name=None)
tf.math.is_nan(x, name=None)
tf.math.lgamma(x, name=None)
tf.math.log(x, name=None)
tf.math.log1p(x, name=None)
tf.math.log_sigmoid(x, name=None)
tf.math.logical_not(x, name=None)
tf.math.ndtri(x, name=None)
tf.math.negative(x, name=None)
tf.math.nextafter(x1, x2, name=None)
tf.math.real(input, name=None)
tf.math.reciprocal(x, name=None)
tf.math.reciprocal_no_nan(x, name=None)
tf.math.rint(x, name=None)
tf.math.round(x, name=None)
tf.math.rsqrt(x, name=None)
tf.math.sigmoid(x, name=None)
tf.math.sign(x, name=None)
tf.math.sin(x, name=None)
tf.math.sinh(x, name=None)
tf.math.softplus(features, name=None)
tf.math.special.bessel_j0(x, name=None)
tf.math.special.bessel_j1(x, name=None)
tf.math.special.bessel_k0(x, name=None)
tf.math.special.bessel_k0e(x, name=None)
tf.math.special.bessel_k1(x, name=None)
tf.math.special.bessel_k1e(x, name=None)
tf.math.special.bessel_y0(x, name=None)
tf.math.special.bessel_y1(x, name=None)
tf.math.special.dawsn(x, name=None)
tf.math.special.expint(x, name=None)
tf.math.special.fresnel_cos(x, name=None)
tf.math.special.fresnel_sin(x, name=None)
tf.math.special.spence(x, name=None)
tf.math.sqrt(x, name=None)
tf.math.square(x, name=None)
tf.math.tan(x, name=None)
tf.math.tanh(x, name=None)
tf.ones_like(input, dtype=None, name=None)
tf.strings.as_string(input, precision=-1, scientific=False, shortest=False, width=-1, fill='', name=None)
tf.strings.length(input, unit='BYTE', name=None)
tf.strings.regex_full_match(input, pattern, name=None)
tf.strings.regex_replace(input, pattern, rewrite, replace_global=True, name=None)
tf.strings.strip(input, name=None)
tf.strings.substr(input, pos, len, unit='BYTE', name=None)
tf.strings.to_hash_bucket(input, num_buckets, name=None)
tf.strings.to_hash_bucket_fast(input, num_buckets, name=None)
tf.strings.to_hash_bucket_strong(input, num_buckets, key, name=None)
tf.strings.to_number(input, out_type=tf.float32, name=None)
tf.strings.unicode_script(input, name=None)
tf.zeros_like(input, dtype=None, name=None)