Commit 86d1b2e5 authored by Yuxin Wu

add gpu_compat option, improve globalstepcounter

parent 476d7cf0
@@ -55,13 +55,14 @@ class MaintainStepCounter(Callback):
         # ensure it exists
         gs_var = get_global_step_var()
         with tf.name_scope(None):
-            self.gs_incr_var = tf.assign_add(
-                gs_var, 1,
-                name=GLOBAL_STEP_INCR_OP_NAME)
+            with tf.device(gs_var.device):
+                self.gs_incr_op = tf.assign_add(
+                    gs_var, 1,
+                    name=GLOBAL_STEP_INCR_OP_NAME).op
             # tf.mod(
             #     self.gs_incr_var, self.trainer.config.steps_per_epoch,
             #     name=LOCAL_STEP_OP_NAME)
-        self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)

     def _before_train(self):
         gs_val = get_global_step_value()
...
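The hunk above pins the increment to the same device as the global-step variable and fetches the bare .op instead of the assign's output tensor, so running it each step produces no value that has to be copied back. Below is a minimal standalone sketch of the same pattern, using stock TF 1.x APIs in place of tensorpack's own helpers (get_global_step_var, GLOBAL_STEP_INCR_OP_NAME); the names and the literal op name are illustrative only, not the project's code.

    # Sketch only: plain TF 1.x equivalents of the tensorpack helpers used above.
    import tensorflow as tf

    gs_var = tf.train.get_or_create_global_step()

    with tf.name_scope(None):               # escape any surrounding name scope
        with tf.device(gs_var.device):      # build the increment where the variable lives
            gs_incr_op = tf.assign_add(
                gs_var, 1,
                name='global_step_incr').op  # keep only the Operation, not its output

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            sess.run(gs_incr_op)            # runs the side effect, returns None
        print(sess.run(gs_var))             # prints 3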
@@ -39,9 +39,11 @@ def get_default_sess_config(mem_fraction=0.99):
     conf.inter_op_parallelism_threads = 0
     conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
+    # TODO check version
+    conf.gpu_options.force_gpu_compatible = True
     conf.gpu_options.allocator_type = 'BFC'
     conf.gpu_options.allow_growth = True
-    # force gpu compatible?
     conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
     return conf
...
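The second hunk turns the old "force gpu compatible?" note into a real option: gpu_options.force_gpu_compatible makes a GPU build allocate CPU-side tensors in CUDA pinned memory, which speeds up host-to-device copies at the cost of extra page-locked RAM (the "TODO check version" is there because older 1.x releases lack the field). A rough sketch of a config built this way follows; the function name is a hypothetical stand-in for get_default_sess_config, which may set additional fields not shown in the hunk.

    # Sketch only: a session config with the newly added option, assuming TF 1.x.
    import tensorflow as tf

    def make_sess_config(mem_fraction=0.99):
        conf = tf.ConfigProto()
        conf.inter_op_parallelism_threads = 0
        conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
        # Allocate host tensors in CUDA pinned memory (only meaningful on GPU builds).
        conf.gpu_options.force_gpu_compatible = True
        conf.gpu_options.allocator_type = 'BFC'
        conf.gpu_options.allow_growth = True
        conf.graph_options.optimizer_options.global_jit_level = \
            tf.OptimizerOptions.ON_1
        return conf

    sess = tf.Session(config=make_sess_config())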