diff --git a/generate.py b/generate.py index ecfd2bf6d..5e807d7f1 100644 --- a/generate.py +++ b/generate.py @@ -136,6 +136,8 @@ def main(): args = get_arguments() started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) logdir = os.path.join(args.logdir, 'generate', started_datestring) + if not os.path.exists(logdir): + os.makedirs(logdir) with open(args.wavenet_params, 'r') as config_file: wavenet_params = json.load(config_file) diff --git a/train.py b/train.py index 02ec80074..13671a733 100644 --- a/train.py +++ b/train.py @@ -112,9 +112,6 @@ def save(saver, sess, logdir, step): print('Storing checkpoint to {} ...'.format(logdir), end="") sys.stdout.flush() - if not os.path.exists(logdir): - os.makedirs(logdir) - saver.save(sess, checkpoint_path, global_step=step) print(' Done.') @@ -171,6 +168,8 @@ def validate_directories(args): if logdir is None: logdir = get_default_logdir(logdir_root) print('Using default logdir: {}'.format(logdir)) + if not os.path.exists(logdir): + os.makedirs(logdir) restore_from = args.restore_from if restore_from is None: