Commit 546af8b2 authored by Yuxin Wu's avatar Yuxin Wu

add function to set rng seed

parent dfaa63eb
......@@ -135,7 +135,8 @@ def get_config():
logger.auto_set_dir()
dataset_train, dataset_test = get_data()
# how many iterations you want in each epoch
# How many iterations you want in each epoch.
# This is the default value, don't actually need to set it in the config
steps_per_epoch = dataset_train.size()
# get the config which contains everything necessary in a training
......
......@@ -13,6 +13,7 @@ import numpy as np
__all__ = ['change_env',
'get_rng',
'fix_rng_seed',
'get_tqdm_kwargs',
'get_tqdm',
'execute_only_once',
......@@ -38,6 +39,21 @@ def change_env(name, val):
os.environ[name] = oldval
_RNG_SEED = None
def fix_rng_seed(seed):
"""
Args:
seed (int):
Note:
See https://github.com/ppwwyyxx/tensorpack/issues/196.
"""
global _RNG_SEED
_RNG_SEED = int(seed)
def get_rng(obj=None):
"""
Get a good RNG seeded with time, pid and the object.
......@@ -49,6 +65,8 @@ def get_rng(obj=None):
"""
seed = (id(obj) + os.getpid() +
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
if _RNG_SEED is not None:
seed = _RNG_SEED
return np.random.RandomState(seed)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment