tf.contrib.opt.AGNCustomGetter(
    worker_device
)
A custom getter class that does the following:
- Moves trainable variables into the local variables collection and places them on the worker device.
- Creates global variables (the global center variables).
- Creates gradient variables that record the sum of gradients and places them on the worker device.
This class should be used together with tf.train.replica_device_setter, so that the global center variables and the global step variable are placed on the parameter server (ps) device.
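The three steps listed above can be modelled without TensorFlow. The sketch below is a simplified, hypothetical stand-in and not the library's internals: `Variable`, `base_getter`, `AGNGetterSketch`, the `global_center/` and `grad/` name prefixes, and the device strings are all invented for illustration. It also assigns the ps device explicitly, which in real use is the job of tf.train.replica_device_setter.

```python
from dataclasses import dataclass

# Hypothetical collection names standing in for TensorFlow's graph collections.
LOCAL = "local_variables"
GLOBAL = "global_variables"


@dataclass
class Variable:
    name: str
    device: str
    trainable: bool
    collections: list


# Registry of every variable "created"; stands in for the TF graph.
created = []


def base_getter(name, device, trainable, collections):
    """Hypothetical default getter: builds a Variable and records it."""
    var = Variable(name, device, trainable, list(collections))
    created.append(var)
    return var


class AGNGetterSketch:
    """Simplified model of the AGN custom getter (illustrative only)."""

    def __init__(self, worker_device, ps_device):
        self.worker_device = worker_device
        # In TensorFlow this placement comes from replica_device_setter.
        self.ps_device = ps_device
        self.global_map = {}  # local variable name -> global center variable
        self.grad_map = {}    # local variable name -> gradient-sum variable

    def __call__(self, getter, name, trainable, collections):
        if not trainable:
            # Non-trainable variables keep their collections; local ones
            # stay on the worker, everything else goes to the ps device.
            device = self.worker_device if LOCAL in collections else self.ps_device
            return getter(name, device, False, collections)
        # 1. The trainable variable becomes a local variable on the worker
        #    (the requested collections are replaced by the local one).
        local = getter(name, self.worker_device, True, [LOCAL])
        # 2. A global center variable, shared via the parameter server.
        center = getter("global_center/" + name, self.ps_device, False, [GLOBAL])
        # 3. A gradient-sum variable, kept on the worker.
        grad = getter("grad/" + name, self.worker_device, False, [LOCAL])
        self.global_map[name] = center
        self.grad_map[name] = grad
        return local  # callers only ever see the local copy


agn = AGNGetterSketch("/job:worker/task:0", "/job:ps/task:0")
w = agn(base_getter, "dense/kernel", True, [GLOBAL])
step = agn(base_getter, "global_step", False, [GLOBAL])

for v in created:
    print(v.name, v.device, v.collections)
```

A single trainable request thus yields three variables: the local copy that the model uses, plus the center and gradient-sum variables that the getter keeps in its own maps for the optimizer to use later.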
Methods
__call__
__call__(
getter, name, trainable, collections, *args, **kwargs
)
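Invoked when the object is installed as the custom_getter of a variable scope (for example tf.variable_scope(..., custom_getter=...)). getter is the underlying variable getter, name is the name of the requested variable, and trainable and collections are the corresponding arguments of that request. Returns the variable that is handed back to the caller; for trainable variables this is the local copy on the worker device.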
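The calling convention can be illustrated with a small model. In TensorFlow 1.x the variable-creation machinery passes its own getter to the custom getter, as keyword arguments, and uses whatever the custom getter returns. The sketch below is hypothetical: `default_getter`, `make_get_variable`, and `logging_getter` are invented names, and the dict returned by `default_getter` only stands in for a real variable.

```python
def default_getter(name, trainable=True, collections=None, **kwargs):
    """Hypothetical stand-in for the default getter; returns a plain dict."""
    return {"name": name, "trainable": trainable,
            "collections": collections or ["global_variables"], **kwargs}


def make_get_variable(custom_getter=None):
    """Build a get_variable-like function, optionally routed through custom_getter."""
    def get_variable(name, trainable=True, collections=None, **kwargs):
        if custom_getter is None:
            return default_getter(name, trainable, collections, **kwargs)
        # The underlying getter is passed in; the custom getter decides
        # whether and how to call it, and its return value is used as-is.
        return custom_getter(getter=default_getter, name=name,
                             trainable=trainable, collections=collections,
                             **kwargs)
    return get_variable


calls = []


def logging_getter(getter, name, trainable, collections, *args, **kwargs):
    """Trivial custom getter with the same signature as __call__ above."""
    calls.append(name)
    return getter(name, *args, trainable=trainable, collections=collections, **kwargs)


get_variable = make_get_variable(logging_getter)
v = get_variable("bias", trainable=False, shape=(3,))
print(calls, v)
```

Because the custom getter receives the underlying getter rather than replacing it, it can call it several times per request, which is how a single trainable variable can give rise to the extra global center and gradient variables.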