Represents a Mesh configuration over a certain list of Mesh Dimensions.
tf.experimental.dtensor.Mesh(
    dim_names: List[str],
    global_device_ids: np.ndarray,
    local_device_ids: List[int],
    local_devices: List[Union[tf_device.DeviceSpec, str]],
    mesh_name: str = '',
    global_devices: Optional[List[Union[tf_device.DeviceSpec, str]]] = None,
    use_xla_spmd: bool = USE_XLA_SPMD
)
A mesh consists of named dimensions with sizes, which describe how a set of devices is arranged. Defining tensor layouts in terms of mesh dimensions allows us to efficiently determine the communication required when computing an operation with tensors of different layouts.
A mesh provides information not only about the placement of the tensors but also about the topology of the underlying devices. For example, we can group 8 TPUs as a 1-D array for data parallelism, or as a 2x4 grid for 2-way data parallelism and 4-way model parallelism.
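The 2x4 arrangement can be pictured with a plain-Python sketch. It assumes device IDs 0 through 7 are laid out in row-major order over two hypothetical dimensions, batch (size 2) and model (size 4), and it does not call the DTensor API. Devices in the same row share a data shard and split the model 4 ways; devices in the same column share a model shard and split the data 2 ways.

```python
# Sketch only: models an 8-device mesh with nested lists, not tf.experimental.dtensor.Mesh.
# Assumes row-major device IDs and hypothetical dimension names "batch" and "model".
NUM_DEVICES = 8

def grid(num_devices: int, batch: int, model: int) -> list:
    """Arrange device IDs 0..num_devices-1 into a batch x model grid (row-major)."""
    assert batch * model == num_devices, "mesh shape must cover every device"
    return [[b * model + m for m in range(model)] for b in range(batch)]

# 1-D arrangement: all 8 devices on one dimension, i.e. 8-way data parallelism.
one_d_mesh = list(range(NUM_DEVICES))

# 2-D arrangement: 2-way data parallelism x 4-way model parallelism.
mesh = grid(NUM_DEVICES, batch=2, model=4)

# Each row shares one data shard and splits the model across 4 devices.
model_parallel_groups = mesh
# Each column holds the same model shard on different data shards; these are
# the groups that would combine gradients under data parallelism.
data_parallel_groups = [list(col) for col in zip(*mesh)]

print(model_parallel_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(data_parallel_groups)   # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Changing only the shape passed to the hypothetical grid helper (for example 4x2) regroups the same 8 devices without touching the devices themselves, which is the point of describing topology with named mesh dimensions.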
Refer to DTensor Concepts for an in-depth discussion and examples.
Methods
as_proto
as_proto()
as_proto(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> tensorflow::dtensor::MeshProto
Returns the MeshProto protobuf message.
contains_dim
contains_dim()
contains_dim(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> bool
Returns True if a Mesh contains the given dimension name.
coords
coords(
device_idx: int
) -> tf.Tensor
Converts the device index into a tensor of mesh coordinates.
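The unravel_index example later on this page lays devices out in row-major order (the last dimension varies fastest). Assuming coords follows the same order, the arithmetic can be sketched in plain Python. coords_sketch is a hypothetical helper that returns a list instead of a tf.Tensor; it is not the library implementation.

```python
from typing import List, Sequence

def coords_sketch(device_idx: int, shape: Sequence[int]) -> List[int]:
    """Hypothetical helper: convert a flat device index to row-major mesh coordinates."""
    total = 1
    for size in shape:
        total *= size
    if not 0 <= device_idx < total:
        raise ValueError(f"device index {device_idx} out of range for shape {list(shape)}")
    coords = []
    # Peel off the fastest-varying (last) dimension first.
    for size in reversed(shape):
        device_idx, rem = divmod(device_idx, size)
        coords.append(rem)
    return coords[::-1]

print(coords_sketch(5, [2, 4]))  # [1, 1]
print(coords_sketch(6, [2, 4]))  # [1, 2]
```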
device_location
device_location()
device_location(self: tensorflow.python._pywrap_dtensor_device.Mesh, arg0: int) -> List[int]
device_type
device_type()
device_type(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> str
Returns the device_type of a Mesh.
dim_size
dim_size()
dim_size(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> int
Returns the size of the given mesh dimension.
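The dimension lookups contains_dim and dim_size, together with shape (documented below), behave like queries on an ordered mapping from dimension name to size. MeshDimsSketch is a hypothetical stand-in used only to illustrate that model; the ValueError raised for an unknown name is an assumption of the sketch, not documented DTensor behavior.

```python
from typing import Dict, List

class MeshDimsSketch:
    """Hypothetical stand-in mirroring a mesh's dimension lookups."""

    def __init__(self, dims: Dict[str, int]):
        # Python dicts preserve insertion order, so dimension order is kept.
        self._dims = dict(dims)

    def contains_dim(self, dim_name: str) -> bool:
        return dim_name in self._dims

    def dim_size(self, dim_name: str) -> int:
        if dim_name not in self._dims:
            raise ValueError(f"unknown mesh dimension: {dim_name!r}")
        return self._dims[dim_name]

    def shape(self) -> List[int]:
        # Sizes listed in dimension order.
        return list(self._dims.values())

mesh = MeshDimsSketch({"batch": 2, "model": 4})
print(mesh.contains_dim("model"), mesh.dim_size("model"), mesh.shape())
# True 4 [2, 4]
```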
from_device
@classmethod
from_device(device: str) -> 'Mesh'
Constructs a single device mesh from a device string.
from_proto
@classmethod
from_proto(proto: layout_pb2.MeshProto) -> 'Mesh'
Constructs a mesh instance from an input proto.
from_string
@classmethod
from_string(mesh_str: str) -> 'Mesh'
global_device_ids
global_device_ids() -> np.ndarray
Returns a global device list as an array.
global_devices
global_devices()
global_devices(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> List[str]
Returns a list of global device specs represented as strings.
host_mesh
host_mesh() -> 'Mesh'
Returns a host mesh.
is_remote
is_remote()
is_remote(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
Returns True if a Mesh contains only remote devices.
is_single_device
is_single_device()
is_single_device(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
Returns True if the mesh represents a non-distributed device.
local_device_ids
local_device_ids()
local_device_ids(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> Span[int]
Returns a list of local device IDs.
local_device_locations
local_device_locations() -> List[Dict[str, int]]
Returns a list of local device locations.
A device location is a dictionary from dimension names to indices on those dimensions.
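In a multi-client setup each client owns only part of the global mesh. The sketch below assumes a 2x4 mesh with hypothetical dimension names batch and model, where the current client owns devices 4 through 7, and builds location dictionaries of the shape described above in row-major order. local_locations_sketch is a hypothetical helper; it does not call DTensor.

```python
from typing import Dict, List, Sequence

def local_locations_sketch(dim_names: Sequence[str], shape: Sequence[int],
                           local_ids: Sequence[int]) -> List[Dict[str, int]]:
    """Hypothetical helper: map each local device ID to {dim_name: index}, row-major."""
    locations = []
    for device_id in local_ids:
        rest = device_id
        loc = {}
        # Peel off indices from the fastest-varying (last) dimension first.
        for name, size in zip(reversed(dim_names), reversed(shape)):
            rest, loc[name] = divmod(rest, size)
        # Rebuild the dict so keys appear in mesh dimension order.
        locations.append({name: loc[name] for name in dim_names})
    return locations

# Assumed: this client owns the second half of an 8-device 2x4 mesh.
for loc in local_locations_sketch(["batch", "model"], [2, 4], [4, 5, 6, 7]):
    print(loc)
# prints {'batch': 1, 'model': 0} through {'batch': 1, 'model': 3}
```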
local_devices
local_devices()
local_devices(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> Span[str]
Returns a list of local device specs represented as strings.
min_global_device_id
min_global_device_id()
min_global_device_id(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> int
Returns the minimum global device ID.
num_local_devices
num_local_devices()
num_local_devices(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> int
Returns the number of local devices.
shape
shape()
shape(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> List[int]
Returns the shape of the mesh.
to_string
to_string()
to_string(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> str
Returns string representation of Mesh.
unravel_index
unravel_index()
Returns a dictionary from device ID to {dim_name: dim_index}.
For example, for a 3x2 mesh with dimensions x and y, it returns:
{ 0: {'x': 0, 'y': 0},
  1: {'x': 0, 'y': 1},
  2: {'x': 1, 'y': 0},
  3: {'x': 1, 'y': 1},
  4: {'x': 2, 'y': 0},
  5: {'x': 2, 'y': 1} }
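The 3x2 mapping in the example can be regenerated with itertools.product, which varies its last argument fastest and therefore enumerates coordinates in row-major order. This sketch assumes device IDs start at 0 and run contiguously; unravel_index_sketch is a hypothetical helper, not the DTensor implementation.

```python
import itertools
from typing import Dict, Sequence

def unravel_index_sketch(dim_names: Sequence[str],
                         shape: Sequence[int]) -> Dict[int, Dict[str, int]]:
    """Hypothetical helper: device ID -> {dim_name: dim_index}, row-major."""
    # product() yields (0, 0), (0, 1), (1, 0), ... so enumerate() supplies
    # the flat device ID matching each coordinate tuple.
    coords = itertools.product(*(range(size) for size in shape))
    return {device_id: dict(zip(dim_names, coord))
            for device_id, coord in enumerate(coords)}

for device_id, loc in unravel_index_sketch(["x", "y"], [3, 2]).items():
    print(device_id, loc)
# 0 {'x': 0, 'y': 0} ... 5 {'x': 2, 'y': 1}
```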
use_xla_spmd
use_xla_spmd()
use_xla_spmd(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
Returns True if Mesh will use XLA for SPMD instead of DTensor SPMD.
__contains__
__contains__()
__contains__(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> bool
__eq__
__eq__()
__eq__(self: tensorflow.python._pywrap_dtensor_device.Mesh, arg0: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool
__getitem__
__getitem__(
dim_name: str
) -> MeshDimension