Commit 694e404b authored by Yuxin Wu's avatar Yuxin Wu

[WIP] move away all graph building logic

parent ad5cb725
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: utils.py # File: _utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import copy import copy
from six.moves import zip from six.moves import zip
from contextlib import contextmanager
import operator
import tensorflow as tf
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..utils import logger from ..utils import logger
...@@ -57,3 +62,63 @@ def get_sublist_by_names(lst, names): ...@@ -57,3 +62,63 @@ def get_sublist_by_names(lst, names):
raise raise
ret.append(lst[idx]) ret.append(lst[idx])
return ret return ret
@contextmanager
def override_to_local_variable(enable=True):
if enable:
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=OverrideToLocalVariable()):
yield
else:
yield
class OverrideToLocalVariable(object):
"""
Ensures the created variable
is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
collections.add(tf.GraphKeys.LOCAL_VARIABLES)
kwargs['collections'] = list(collections)
return getter(name, *args, **kwargs)
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class LeastLoadedDeviceSetter(object):
""" Helper class to assign variables on the least loaded ps-device."""
def __init__(self, worker_device, ps_devices):
"""
Args:
worker_device: the device to use for compute ops.
ps_devices: a list of device to use for Variable ops.
"""
self.ps_devices = ps_devices
self.worker_device = worker_device
self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op):
def sanitize_name(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
if op.device:
return op.device
if op.type not in ['Variable', 'VariableV2']:
return sanitize_name(self.worker_device)
device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1))
device_name = self.ps_devices[device_index]
var_size = op.outputs[0].get_shape().num_elements()
self.ps_sizes[device_index] += var_size
return sanitize_name(device_name)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -7,9 +7,9 @@ from .base import Trainer ...@@ -7,9 +7,9 @@ from .base import Trainer
from ..utils import logger from ..utils import logger
from ..graph_builder.input_source import FeedInput, QueueInput from ..graph_builder.input_source import FeedInput, QueueInput
from ..graph_builder.training import SimpleGraphBuilder from ..graph_builder.training import SimpleBuilder
__all__ = ['SimpleTrainer'] __all__ = ['SimpleTrainer', 'QueueInputTrainer']
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
...@@ -46,7 +46,7 @@ class SimpleTrainer(Trainer): ...@@ -46,7 +46,7 @@ class SimpleTrainer(Trainer):
self.model.build_graph(inputs) self.model.build_graph(inputs)
return self.model.get_cost() return self.model.get_cost()
self.train_op = SimpleGraphBuilder().build(self._input_source, get_cost, self.model.get_optimizer) self.train_op = SimpleBuilder().build(self._input_source, get_cost, self.model.get_optimizer)
self.config.callbacks.extend(cbs) self.config.callbacks.extend(cbs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment