Commit 58fc1dea authored by Yuxin Wu

update

parent 1b98fe58
@@ -146,12 +146,12 @@ class GPUUtilizationTracker(Callback):
                         if evt.is_set():    # stop epoch
                             if stop_evt.is_set():   # or on exit
                                 return
-                            evt.clear()
                             if cnt > 1:
                                 # Ignore the last datapoint. Usually is zero, makes us underestimate the util.
                                 stats -= data
                                 cnt -= 1
                             rst_queue.put(stats / cnt)
+                            evt.clear()
                             break
             except Exception:
                 logger.exception("Exception in GPUUtilizationTracker.worker")
...
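The hunk above only moves `evt.clear()` to after `rst_queue.put(...)`, so the worker acknowledges the "stop epoch" signal only once the epoch's averaged utilization is already on the result queue. Below is a minimal, self-contained sketch of that handshake under simplified assumptions; the `worker` function, the random sampling, and the driver are illustrative stand-ins, not the actual tensorpack worker:

```python
import multiprocessing as mp
import random
import time


def worker(evt, rst_queue, stop_evt):
    """Accumulate fake utilization samples; publish the epoch mean when evt is set."""
    stats, cnt, data = 0.0, 0, 0.0
    while True:
        if evt.is_set():                 # parent signalled "stop epoch"
            if stop_evt.is_set():        # or process exit
                return
            if cnt > 1:
                # Drop the last datapoint, which tends to be low and would
                # make us underestimate the utilization.
                stats -= data
                cnt -= 1
            rst_queue.put(stats / max(cnt, 1))
            evt.clear()                  # acknowledge only AFTER publishing the result
            stats, cnt = 0.0, 0
        data = random.random()           # stand-in for one GPU utilization sample
        stats += data
        cnt += 1
        time.sleep(0.05)


if __name__ == "__main__":
    evt, stop_evt, q = mp.Event(), mp.Event(), mp.Queue()
    p = mp.Process(target=worker, args=(evt, q, stop_evt), daemon=True)
    p.start()
    time.sleep(0.5)          # "run one epoch"
    evt.set()                # ask the worker to finish the epoch
    print("mean utilization:", q.get(timeout=5))
    stop_evt.set()
    evt.set()                # wake the worker so it can observe stop_evt and exit
    p.join(timeout=5)
```

With this ordering, a parent that treats the cleared event as the acknowledgement can rely on the epoch average already being available on the queue when it proceeds.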
@@ -4,6 +4,7 @@
 from contextlib import contextmanager
+from ..utils import logger
 from ..compat import tfv1 as tf
 from .common import get_tf_version_tuple
...
@@ -113,7 +114,15 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
             # do not perform unnecessary changes if it's not originally trainable
             # otherwise the variable may get added to MODEL_VARIABLES twice
             if trainable and skip_collection:
-                tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
+                if isinstance(v, tf.Variable):
+                    tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
+                else:
+                    logger.warning("""
+[freeze_variables] The variable getter did not return a Variable, but '{}' instead, likely due to
+another custom getter. freeze_variables() works only if the other custom getter respects the
+`trainable` argument and does not put variables with `trainable=False` into the TRAINABLE_VARIABLES
+collection. Please double-check that this is true for the custom getter.
+""".format(str(v)).replace("\n", " ").strip())
             if trainable and stop_gradient:
                 v = tf.stop_gradient(v, name='freezed_' + name)
             return v
...
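The new `isinstance` guard matters when `freeze_variables()` is combined with another `custom_getter` on an enclosing scope: the chained getter it calls may already have wrapped the variable and hand back a plain Tensor, which should not be recorded in `MODEL_VARIABLES`. The following TF1-style sketch shows how that situation can arise; `casting_getter`, `freezing_getter`, and the scope names are hypothetical, and `freezing_getter` is only a simplified stand-in for what `freeze_variables()` installs:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def casting_getter(getter, name, *args, **kwargs):
    # An unrelated custom getter on an outer scope: it wraps the variable,
    # so whatever it returns is a Tensor rather than a tf.Variable.
    v = getter(name, *args, **kwargs)
    return tf.identity(v)


def freezing_getter(getter, name, *args, **kwargs):
    # Simplified stand-in for the getter installed by freeze_variables():
    # request a non-trainable variable and record it in MODEL_VARIABLES,
    # but only when the chained getters actually handed back a Variable.
    kwargs["trainable"] = False
    v = getter(name, *args, **kwargs)      # may invoke casting_getter first
    if isinstance(v, tf.Variable):
        tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
    else:
        print("getter returned a {}, not a Variable; skipping the collection".format(
            type(v).__name__))
    return tf.stop_gradient(v)


with tf.variable_scope("net", custom_getter=casting_getter):
    with tf.variable_scope("frozen", custom_getter=freezing_getter):
        w = tf.get_variable("w", shape=[3])   # hits the non-Variable branch above

print(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))    # stays empty here
```

Because the outer getter already turned the variable into a Tensor, the freezing getter sees a non-Variable and, as in the hunk above, warns and skips the collection instead of putting an arbitrary Tensor into `MODEL_VARIABLES`.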