View source on GitHub |
Decorator to override default implementation for unary elementwise APIs.
tf.experimental.dispatch_for_unary_elementwise_apis(x_type)
The decorated function (known as the "elementwise api handler") overrides
the default implementation for any unary elementwise API whenever the value
for the first argument (typically named `x`) matches the type annotation
`x_type`. The elementwise api handler is called with two arguments:
elementwise_api_handler(api_func, x)
where `api_func` is a function that takes a single parameter and performs the
elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
elementwise API.
The following example shows how this decorator can be used to update all
unary elementwise operations to handle a `MaskedTensor` type:
class MaskedTensor(tf.experimental.ExtensionType):
    """Example extension type: a tensor of values paired with a boolean mask."""

    values: tf.Tensor
    mask: tf.Tensor


# Use the public, fully-qualified decorator name so the example runs as-is
# after a plain `import tensorflow as tf`; the bare name is not in scope there.
@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_api_handler(api_func, x):
    """Apply a unary elementwise op to ``x.values``, keeping ``x.mask`` as-is.

    Args:
      api_func: Single-argument callable that performs the elementwise
        operation (e.g. ``tf.abs``). Any extra arguments from the original
        API call are added by ``api_func`` itself, not passed to this handler.
      x: The ``MaskedTensor`` passed as the API's first argument.

    Returns:
      A new ``MaskedTensor`` holding the transformed values and the original
      mask.
    """
    return MaskedTensor(api_func(x.values), x.mask)


mt = MaskedTensor([1, -2, -3], [True, False, True])
# tf.abs now dispatches to unary_elementwise_api_handler for MaskedTensor.
abs_mt = tf.abs(mt)
print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
values=[1 2 3], mask=[ True False  True]
For unary elementwise operations that take extra arguments beyond `x`, those
arguments are not passed to the elementwise api handler, but are
automatically added when `api_func` is called. E.g., in the following
example, the `dtype` parameter is not passed to
`unary_elementwise_api_handler`, but is added by `api_func`.
# `dtype` is never seen by unary_elementwise_api_handler: the handler only
# receives (api_func, mt), and api_func adds dtype=tf.float32 itself when it
# is applied to the underlying values.
ones_mt = tf.ones_like(mt, dtype=tf.float32)
print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
values=[1. 1. 1.], mask=[ True False  True]
| Args | |
|---|---|
| `x_type` | A type annotation indicating when the api handler should be called. See `tf.experimental.dispatch_for_api` for a list of supported annotation types. |
| Returns | |
|---|---|
| A decorator. | |
Registered APIs
The unary elementwise APIs are:
tf.bitwise.invert(x, name)
tf.cast(x, dtype, name)
tf.clip_by_value(t, clip_value_min, clip_value_max, name)
tf.compat.v1.nn.log_softmax(logits, axis, name, dim)
tf.compat.v1.ones_like(tensor, dtype, name, optimize)
tf.compat.v1.strings.length(input, name, unit)
tf.compat.v1.strings.substr(input, pos, len, name, unit)
tf.compat.v1.strings.to_hash_bucket(string_tensor, num_buckets, name, input)
tf.compat.v1.substr(input, pos, len, name, unit)
tf.compat.v1.to_bfloat16(x, name)
tf.compat.v1.to_complex128(x, name)
tf.compat.v1.to_complex64(x, name)
tf.compat.v1.to_double(x, name)
tf.compat.v1.to_float(x, name)
tf.compat.v1.to_int32(x, name)
tf.compat.v1.to_int64(x, name)
tf.compat.v1.zeros_like(tensor, dtype, name, optimize)
tf.debugging.check_numerics(tensor, message, name)
tf.dtypes.saturate_cast(value, dtype, name)
tf.image.adjust_brightness(image, delta)
tf.image.adjust_gamma(image, gamma, gain)
tf.image.convert_image_dtype(image, dtype, saturate, name)
tf.image.random_brightness(image, max_delta, seed)
tf.image.stateless_random_brightness(image, max_delta, seed)
tf.io.decode_base64(input, name)
tf.io.decode_compressed(bytes, compression_type, name)
tf.io.encode_base64(input, pad, name)
tf.math.abs(x, name)
tf.math.acos(x, name)
tf.math.acosh(x, name)
tf.math.angle(input, name)
tf.math.asin(x, name)
tf.math.asinh(x, name)
tf.math.atan(x, name)
tf.math.atanh(x, name)
tf.math.bessel_i0(x, name)
tf.math.bessel_i0e(x, name)
tf.math.bessel_i1(x, name)
tf.math.bessel_i1e(x, name)
tf.math.ceil(x, name)
tf.math.conj(x, name)
tf.math.cos(x, name)
tf.math.cosh(x, name)
tf.math.digamma(x, name)
tf.math.erf(x, name)
tf.math.erfc(x, name)
tf.math.erfcinv(x, name)
tf.math.erfinv(x, name)
tf.math.exp(x, name)
tf.math.expm1(x, name)
tf.math.floor(x, name)
tf.math.imag(input, name)
tf.math.is_finite(x, name)
tf.math.is_inf(x, name)
tf.math.is_nan(x, name)
tf.math.lgamma(x, name)
tf.math.log(x, name)
tf.math.log1p(x, name)
tf.math.log_sigmoid(x, name)
tf.math.logical_not(x, name)
tf.math.ndtri(x, name)
tf.math.negative(x, name)
tf.math.nextafter(x1, x2, name)
tf.math.real(input, name)
tf.math.reciprocal(x, name)
tf.math.reciprocal_no_nan(x, name)
tf.math.rint(x, name)
tf.math.round(x, name)
tf.math.rsqrt(x, name)
tf.math.sigmoid(x, name)
tf.math.sign(x, name)
tf.math.sin(x, name)
tf.math.sinh(x, name)
tf.math.softplus(features, name)
tf.math.special.bessel_j0(x, name)
tf.math.special.bessel_j1(x, name)
tf.math.special.bessel_k0(x, name)
tf.math.special.bessel_k0e(x, name)
tf.math.special.bessel_k1(x, name)
tf.math.special.bessel_k1e(x, name)
tf.math.special.bessel_y0(x, name)
tf.math.special.bessel_y1(x, name)
tf.math.special.dawsn(x, name)
tf.math.special.expint(x, name)
tf.math.special.fresnel_cos(x, name)
tf.math.special.fresnel_sin(x, name)
tf.math.special.spence(x, name)
tf.math.sqrt(x, name)
tf.math.square(x, name)
tf.math.tan(x, name)
tf.math.tanh(x, name)
tf.nn.elu(features, name)
tf.nn.gelu(features, approximate, name)
tf.nn.leaky_relu(features, alpha, name)
tf.nn.relu(features, name)
tf.nn.relu6(features, name)
tf.nn.selu(features, name)
tf.nn.silu(features, beta)
tf.nn.softsign(features, name)
tf.ones_like(input, dtype, name)
tf.strings.as_string(input, precision, scientific, shortest, width, fill, name)
tf.strings.length(input, unit, name)
tf.strings.lower(input, encoding, name)
tf.strings.regex_full_match(input, pattern, name)
tf.strings.regex_replace(input, pattern, rewrite, replace_global, name)
tf.strings.strip(input, name)
tf.strings.substr(input, pos, len, unit, name)
tf.strings.to_hash_bucket(input, num_buckets, name)
tf.strings.to_hash_bucket_fast(input, num_buckets, name)
tf.strings.to_hash_bucket_strong(input, num_buckets, key, name)
tf.strings.to_number(input, out_type, name)
tf.strings.unicode_script(input, name)
tf.strings.unicode_transcode(input, input_encoding, output_encoding, errors, replacement_char, replace_control_characters, name)
tf.strings.upper(input, encoding, name)
tf.zeros_like(input, dtype, name)