Commit f1a8acf4 authored by Yuxin Wu

make Image2Image support windows (fix #1412)

parent e50423db
@@ -146,11 +146,12 @@ class Model(GANModelDesc):
         return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)


-def split_input(img):
+def split_input(dp):
     """
-    img: an RGB image of shape (s, 2s, 3).
+    dp: the datapoint. first component is an RGB image of shape (s, 2s, 3).
     :return: [input, output]
     """
+    img = dp[0]
     # split the image into left + right pairs
     s = img.shape[0]
     assert img.shape[1] == 2 * s
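After this change, split_input receives the whole datapoint rather than a bare image and unpacks the image itself. A minimal sketch of the new contract, assuming a NumPy image; any extra processing the real function in Image2Image.py performs is omitted:

    import numpy as np

    def split_input(dp):
        # dp is the datapoint produced by the dataflow: [img],
        # where img is an RGB image of shape (s, 2s, 3)
        img = dp[0]
        s = img.shape[0]
        assert img.shape[1] == 2 * s
        # left half is the input, right half is the target output
        return [img[:, :s, :], img[:, s:, :]]

    # illustrative call with a dummy 4x8 RGB image
    inp, outp = split_input([np.zeros((4, 8, 3), dtype=np.uint8)])
    assert inp.shape == (4, 4, 3) and outp.shape == (4, 4, 3)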
@@ -169,7 +170,7 @@ def get_data():
     imgs = glob.glob(os.path.join(datadir, '*.jpg'))
     ds = ImageFromFile(imgs, channel=3, shuffle=True)
-    ds = MapData(ds, lambda dp: split_input(dp[0]))
+    ds = MapData(ds, split_input)
     augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
     ds = AugmentImageComponents(ds, augs, (0, 1))
     ds = BatchData(ds, BATCH)
@@ -186,7 +187,7 @@ def sample(datadir, model_path):
     imgs = glob.glob(os.path.join(datadir, '*.jpg'))
     ds = ImageFromFile(imgs, channel=3, shuffle=True)
-    ds = MapData(ds, lambda dp: split_input(dp[0]))
+    ds = MapData(ds, split_input)
     ds = AugmentImageComponents(ds, [imgaug.Resize(256)], (0, 1))
     ds = BatchData(ds, 6)
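In both get_data and sample, the lambda that unpacked the datapoint is replaced by the plain module-level split_input. A likely motivation, consistent with the Windows fix in the commit message: Windows starts multiprocessing workers with the spawn method, which requires the mapper (and any dataflow that contains it) to be picklable, and lambdas cannot be pickled. A small sketch of that distinction (not code from this repository):

    import pickle

    def split_input(dp):        # module-level function: picklable by reference
        return [dp[0]]

    bad = lambda dp: dp[0]      # lambda: cannot be pickled

    pickle.dumps(split_input)   # fine
    try:
        pickle.dumps(bad)
    except Exception as e:      # typically pickle.PicklingError
        print("lambda mappers break spawn-based workers:", e)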
......
# Copyright (c) Tensorpack Contributors. All Rights Reserved
# -*- coding: utf-8 -*-
# File: batch_norm.py
......
@@ -201,7 +201,7 @@ def collect_env_info():
     # List devices with NVML
     data.append(
         ("CUDA_VISIBLE_DEVICES",
-         os.environ.get("CUDA_VISIBLE_DEVICES", str(None))))
+         os.environ.get("CUDA_VISIBLE_DEVICES", "Unspecified")))
     try:
         devs = defaultdict(list)
         with NVMLContext() as ctx:
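In the environment report, the fallback printed when CUDA_VISIBLE_DEVICES is unset changes from the string "None" (the result of str(None)) to the clearer "Unspecified". A quick illustration of the lookup; only the environment-variable name comes from the diff:

    import os

    os.environ.pop("CUDA_VISIBLE_DEVICES", None)                  # make sure it is unset
    print(os.environ.get("CUDA_VISIBLE_DEVICES", str(None)))      # old behaviour: prints "None"
    print(os.environ.get("CUDA_VISIBLE_DEVICES", "Unspecified"))  # new behaviour: prints "Unspecified"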