Decorator to override default implementation for binary elementwise APIs.
```python
tf.experimental.dispatch_for_binary_elementwise_apis(
    x_type, y_type
)
```
The decorated function (known as the "elementwise API handler") overrides the default implementation for any binary elementwise API whenever the values of the first two arguments (typically named `x` and `y`) match the specified type annotations. The elementwise API handler is called with three arguments:

`elementwise_api_handler(api_func, x, y)`

where `x` and `y` are the first two arguments to the elementwise API, and `api_func` is a TensorFlow function that takes two parameters and performs the elementwise operation (e.g., `tf.add`).
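To make the calling convention concrete, here is a toy, pure-Python model of type-based binary dispatch. It is not TensorFlow's implementation: the registry, `toy_dispatch_for_binary_elementwise_apis`, `toy_binary_elementwise_api`, and the `Masked` class are hypothetical names invented for illustration, and Python lists stand in for tensors. It models the two behaviors described above: when both operands match the registered types, the handler is called with the default implementation as `api_func` followed by `x` and `y`; otherwise the default implementation runs unchanged.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry mapping (x_type, y_type) to a handler. A toy stand-in
# for illustration only, not TensorFlow's internal data structure.
_HANDLERS: Dict[Tuple[type, type], Callable] = {}


def toy_dispatch_for_binary_elementwise_apis(x_type, y_type):
    """Toy decorator (hypothetical): register the handler for (x_type, y_type)."""
    def decorator(handler):
        _HANDLERS[(x_type, y_type)] = handler
        return handler
    return decorator


def toy_binary_elementwise_api(default_impl):
    """Toy wrapper (hypothetical) that lets registered handlers override default_impl."""
    def api(x, y):
        for (x_type, y_type), handler in _HANDLERS.items():
            if isinstance(x, x_type) and isinstance(y, y_type):
                # The handler receives the default implementation as `api_func`,
                # followed by the two original operands.
                return handler(default_impl, x, y)
        # No handler matched: fall back to the default implementation.
        return default_impl(x, y)
    return api


# Hypothetical "elementwise APIs" whose default implementations act on lists.
add = toy_binary_elementwise_api(lambda x, y: [a + b for a, b in zip(x, y)])
multiply = toy_binary_elementwise_api(lambda x, y: [a * b for a, b in zip(x, y)])


class Masked:
    """Hypothetical masked container standing in for a MaskedTensor."""

    def __init__(self, values, mask):
        self.values = list(values)
        self.mask = list(mask)


@toy_dispatch_for_binary_elementwise_apis(Masked, Masked)
def masked_handler(api_func, x, y):
    # api_func works on plain lists, much as a TensorFlow op works on plain tensors.
    mask = [mx and my for mx, my in zip(x.mask, y.mask)]
    return Masked(api_func(x.values, y.values), mask)


c = add(Masked([1, 2], [True, False]), Masked([10, 20], [True, True]))
print(c.values, c.mask)          # [11, 22] [True, False]
print(multiply([2, 3], [4, 5]))  # [8, 15] (plain lists, so no handler matches)
```

A single handler serves both `add` and `multiply` because the toy registry is consulted by every wrapped API, which mirrors how one registration covers every API in the Registered APIs list.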
The following example shows how this decorator can be used to update all binary elementwise operations to handle a `MaskedTensor` type:
```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def binary_elementwise_api_handler(api_func, x, y):
    # Apply the elementwise op to the values; an element stays valid only if
    # it is valid in both inputs.
    return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)


a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
c = tf.add(a, b)
print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
# values=[ 3  6  9 12  5], mask=[ True  True  True False False]
```
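Because the handler is registered for every API in the Registered APIs list rather than for `tf.add` alone, the same function should also handle calls such as `tf.math.multiply` and `tf.math.maximum` on two `MaskedTensor` values. The following is a minimal sketch that repeats the `MaskedTensor` type and handler from the example above so that it runs on its own; it assumes a TensorFlow 2.x release that provides `tf.experimental.ExtensionType`.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def binary_elementwise_api_handler(api_func, x, y):
    return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)


a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])

# Both calls are routed to the same handler; only `api_func` differs.
product = tf.math.multiply(a, b)
largest = tf.math.maximum(a, b)
print(product.values.numpy())  # [ 2  8 18 32  0]
print(largest.values.numpy())  # [2 4 6 8 5]
print(product.mask.numpy())    # [ True  True  True False False]
```

The masking logic is written once, in the handler, and every registered binary elementwise API reuses it.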
Args | |
---|---|
`x_type` | A type annotation that the API's first argument (typically named `x`) must match for the handler to be called. |
`y_type` | A type annotation that the API's second argument (typically named `y`) must match for the handler to be called. |
Returns | |
---|---|
A decorator. |
Registered APIs
The binary elementwise APIs are:
tf.bitwise.bitwise_and(x, y, name=None)
tf.bitwise.bitwise_or(x, y, name=None)
tf.bitwise.bitwise_xor(x, y, name=None)
tf.bitwise.left_shift(x, y, name=None)
tf.bitwise.right_shift(x, y, name=None)
tf.dtypes.complex(real, imag, name=None)
tf.math.add(x, y, name=None)
tf.math.atan2(y, x, name=None)
tf.math.divide(x, y, name=None)
tf.math.divide_no_nan(x, y, name=None)
tf.math.equal(x, y, name=None)
tf.math.floordiv(x, y, name=None)
tf.math.floormod(x, y, name=None)
tf.math.greater(x, y, name=None)
tf.math.greater_equal(x, y, name=None)
tf.math.less(x, y, name=None)
tf.math.less_equal(x, y, name=None)
tf.math.logical_and(x, y, name=None)
tf.math.logical_or(x, y, name=None)
tf.math.logical_xor(x, y, name='LogicalXor')
tf.math.maximum(x, y, name=None)
tf.math.minimum(x, y, name=None)
tf.math.multiply(x, y, name=None)
tf.math.multiply_no_nan(x, y, name=None)
tf.math.not_equal(x, y, name=None)
tf.math.pow(x, y, name=None)
tf.math.squared_difference(x, y, name=None)
tf.math.subtract(x, y, name=None)
tf.math.truediv(x, y, name=None)
tf.realdiv(x, y, name=None)
tf.truncatediv(x, y, name=None)
tf.truncatemod(x, y, name=None)
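Every API in this list is dispatched the same way, including the comparison APIs, whose results are boolean. The handler receives each API's first two arguments as `x` and `y` regardless of what the API itself names them, so for `tf.math.atan2(y, x)` the handler's `x` is the API's `y`. The sketch below again redefines the `MaskedTensor` type and handler from the example above so that it is self-contained, and applies three comparison APIs from the list; under the same TensorFlow 2.x assumption, the resulting `values` should be booleans while the mask is combined exactly as before.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
def binary_elementwise_api_handler(api_func, x, y):
    # `x` and `y` are the API's first two arguments, whatever the API calls them.
    return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)


a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])

# Comparison APIs from the list produce boolean `values`; the mask logic is unchanged.
results = {
    "greater": tf.math.greater(a, b),
    "less_equal": tf.math.less_equal(a, b),
    "not_equal": tf.math.not_equal(a, b),
}
for api_name, result in results.items():
    print(api_name, result.values.numpy(), result.mask.numpy())
```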