Skip to content

Commit 3353266

Browse files
committed
network in progress
1 parent e69ec3b commit 3353266

File tree

7 files changed

+129
-26
lines changed

7 files changed

+129
-26
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ TensorFlow implementation of [Learning from Simulated and Unsupervised Images th
1313

1414
## Usage
1515

16+
First generate synthetic gaze dataset with [UnityEyes](http://www.cl.cam.ac.uk/research/rainbow/projects/unityeyes/). There are no details in the paper, but I changed `Camera parameters` to `0, 0, 20, 40` before generating images.
17+
1618
To train a model:
1719

1820
$ python main.py --data_set gaze

config.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,17 @@ def add_argument_group(name):
2222
data_arg = add_argument_group('Data')
2323
data_arg.add_argument('--data_set', type=str, default='gaze')
2424
data_arg.add_argument('--data_dir', type=str, default='data')
25+
data_arg.add_argument('--input_height', type=int, default=35)
26+
data_arg.add_argument('--input_width', type=int, default=55)
27+
data_arg.add_argument('--input_channel', type=int, default=1)
2528

2629
# Training / test parameters
2730
train_arg = add_argument_group('Training')
28-
train_arg.add_argument('--optimizer', default='rmsprop', help='')
31+
train_arg.add_argument('--optimizer', type=str, default='rmsprop', help='')
32+
train_arg.add_argument('--max_step', type=int, default=200, help='')
33+
train_arg.add_argument('--lambda', type=float, default=1, help='')
34+
train_arg.add_argument('--K_d', type=int, default=1, help='')
35+
train_arg.add_argument('--K_g', type=int, default=5, help='')
2936
train_arg.add_argument('--batch_size', type=int, default=512, help='')
3037
train_arg.add_argument('--num_epochs', type=int, default=12, help='')
3138
train_arg.add_argument('--random_seed', type=int, default=123, help='')

data/gaze_data.py

+12-20
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
def maybe_download_and_extract(
1818
data_path,
19-
url='http://datasets.d2.mpi-inf.mpg.de/MPIIGAZE_PATH/MPIIGAZE_PATH.tar.gz'):
20-
if not os.path.exists(os.path.join(data_path, 'MPIIGAZE_PATH')):
19+
url='http://datasets.d2.mpi-inf.mpg.de/MPIIGaze/MPIIGaze.tar.gz'):
20+
if not os.path.exists(os.path.join(data_path, MPIIGAZE_PATH)):
2121
if not os.path.exists(data_path):
2222
os.makedirs(data_path)
2323

@@ -32,11 +32,11 @@ def _progress(count, block_size, total_size):
3232

3333
filepath, _ = urllib.request.urlretrieve(url, filepath, _progress)
3434
statinfo = os.stat(filepath)
35-
print('\nSuccessfully downloaded', filename, statinfo.st_size, 'bytes.')
35+
print('\nSuccessfully downloaded {} {} bytes.'.format(filename, statinfo.st_size))
3636
tarfile.open(filepath, 'r:gz').extractall(data_path)
3737

3838
def maybe_preprocess(data_path):
39-
base_path = os.path.join(data_path, 'MPIIGAZE_PATH/Data/Normalized')
39+
base_path = os.path.join(data_path, '{}/Data/Normalized'.format(MPIIGAZE_PATH))
4040
npz_path = os.path.join(data_path, DATA_FNAME)
4141

4242
if os.path.exists(npz_path):
@@ -48,6 +48,8 @@ def maybe_preprocess(data_path):
4848
for filename in fnmatch.filter(filenames, '*.mat'):
4949
mat_paths.append(os.path.join(root, filename))
5050

51+
print("[*] Preprocessing `gaze` data...")
52+
5153
images =[]
5254
for mat_path in tqdm(mat_paths):
5355
mat = loadmat(mat_path)
@@ -57,14 +59,7 @@ def maybe_preprocess(data_path):
5759
images.extend(mat['data'][0][0][1][0][0][1])
5860

5961
real_data = np.stack(images, axis=0)
60-
61-
# UnityEyes dataset
62-
synthetic_data = None
63-
64-
#raise Exception("[!] Not implemented yet")
65-
66-
np.savez(npz_path, real=real_data, synthetic=synthetic_data)
67-
print("[*] Preprocessing of `gaze` data is finished.")
62+
np.savez(npz_path, real=real_data)
6863

6964
def load(data_path, debug=False):
7065
if not os.path.exists(data_path):
@@ -75,28 +70,25 @@ def load(data_path, debug=False):
7570
maybe_preprocess(data_path)
7671

7772
gaze_data = np.load(os.path.join(data_path, DATA_FNAME))
73+
real_data = gaze_data['real']
7874

79-
real_data, synthetic_data = gaze_data['real'], gaze_data['synthetic']
8075
if debug:
8176
print("[*] Save sample images in {}".format(data_path))
82-
for idx in range(10):
83-
image_path = os.path.join(synthetic_images,
84-
"sample_real_{}".format(idx))
77+
for idx in range(100):
78+
image_path = os.path.join(data_path, "sample_real_{}.png".format(idx))
8579
imwrite(image_path, real_data[idx])
86-
return real_data, synthetic_data
80+
return real_data
8781

8882
class DataLoader(object):
8983
def __init__(self, data_dir, batch_size, debug=False, rng=None):
9084
self.data_path = os.path.join(data_dir, 'gaze')
9185
self.batch_size = batch_size
9286

93-
self.data, self.labels = load(self.data_path, conf.debug)
94-
self.data = np.transpose(self.data, (0,2,3,1)) # (N,3,32,32) -> (N,32,32,3)
87+
self.data = load(self.data_path, debug)
9588

9689
self.p = 0 # pointer to where we are in iteration
9790
self.rng = np.random.RandomState(1) if rng is None else rng
9891

99-
10092
def get_observation_size(self):
10193
return self.data.shape[1:]
10294

layers.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,40 @@
1+
import tensorflow as tf
2+
import tensorflow.contrib.slim as slim
13
from tensorflow.contrib.framework import add_arg_scope
24

5+
def _update_dict(layer_dict, scope, layer):
6+
name = "{}/{}".format(tf.get_variable_scope().name, scope)
7+
layer_dict[name] = layer
8+
9+
@add_arg_scope
10+
def resnet_block(
11+
inputs, scope, num_outputs=64, kernel_size=[3, 3],
12+
stride=[1, 1], padding="SAME", layer_dict={}):
13+
with tf.variable_scope(scope):
14+
layer = slim.conv2d(
15+
inputs, num_outputs, kernel_size, stride,
16+
padding=padding, activation_fn=tf.nn.relu, scope="conv1")
17+
layer = slim.conv2d(
18+
inputs, num_outputs, kernel_size, stride,
19+
padding=padding, scope="conv2")
20+
outputs = tf.nn.relu(tf.add(inputs, layer))
21+
_update_dict(layer_dict, scope, outputs)
22+
return outputs
23+
24+
@add_arg_scope
25+
def repeat(inputs, repetitions, layer, layer_dict={}, **kargv):
26+
outputs = slim.repeat(inputs, repetitions, layer, **kargv)
27+
_update_dict(layer_dict, kargv['scope'], outputs)
28+
return outputs
29+
30+
@add_arg_scope
31+
def conv2d(inputs, num_outputs, kernel_size, stride, layer_dict={}, **kargv):
32+
outputs = slim.conv2d(inputs, num_outputs, kernel_size, stride, **kargv)
33+
_update_dict(layer_dict, kargv['scope'], outputs)
34+
return outputs
35+
336
@add_arg_scope
4-
def resnet_block():
5-
pass
37+
def max_pool2d(inputs, kernel_size=[3, 3], stride=[1, 1], layer_dict={}, **kargv):
38+
outputs = slim.max_pool2d(inputs, kernel_size, stride, **kargv)
39+
_update_dict(layer_dict, kargv['scope'], outputs)
40+
return outputs

main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def main(_):
2424
'hand': hand_data.DataLoader,
2525
}[config.data_set]
2626

27+
model = Model(config)
2728
data_loader = DataLoader(config.data_dir, config.batch_size,
2829
config.debug, rng=rng)
29-
model = Model()
3030

3131
if __name__ == "__main__":
3232
config, unparsed = get_config()

model.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
from tqdm import tqdm
2+
13
import tensorflow as tf
24
from tensorflow.contrib.framework.python.ops import arg_scope
35

46
from network import Network
57

68
class Model(object):
79
def __init__(self, config):
10+
self.K_d = config.K_d
11+
self.K_g = config.K_g
12+
813
self.network = Network(config)
9-
pass
14+
15+
def train(self):
16+
for step in range(self.max_step):
17+
for k in range(self.K_g):
18+
pass
19+
20+
for k in range(self.K_d):
21+
pass

network.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,60 @@
11
import tensorflow as tf
2+
from tensorflow.contrib.framework import arg_scope
3+
4+
from layers import *
25

36
class Network(object):
47
def __init__(self, config):
5-
pass
8+
input_dims = [
9+
None, config.input_height,
10+
config.input_width, config.input_channel,
11+
]
12+
13+
def to_float(layer):
14+
return tf.image.convert_image_dtype(layer, tf.float32)
15+
16+
self.inputs = to_float(tf.placeholder(tf.uint8, input_dims, 'inputs'))
17+
#self.input_real = to_float(tf.placeholder(tf.uint8, input_dims, 'input_real'))
18+
#self.input_synthetic = to_float(tf.placeholder(tf.uint8, input_dims, 'input_synthetic'))
19+
self.targets = to_float(tf.placeholder(tf.uint8, input_dims, 'targets'))
20+
21+
self.layer_dict = {}
22+
23+
with arg_scope([resnet_block, conv2d, max_pool2d], layer_dict=self.layer_dict):
24+
self.refiner_outputs = self._build_refiner_network()
25+
26+
self.discrim_inputs = self._build_discriminator_network(self.inputs)
27+
self.discrim_refiner = self._build_discriminator_network(self.refiner_outputs)
28+
import ipdb; ipdb.set_trace()
29+
30+
#self.estimate_outputs = self._build_estimation_network()
31+
32+
self.refiner_loss = tf.reduce_sum(self.refiner_outputs - self.inputs, [1, 2, 3])
33+
34+
def _build_refiner_network(self):
35+
layer = self.inputs
36+
with tf.variable_scope("refiner"):
37+
layer = repeat(layer, 5, resnet_block, scope="resnet")
38+
layer = conv2d(layer, 1, 1, 1, scope="conv_1")
39+
return layer
40+
41+
def _build_discriminator_network(self, layer):
42+
with tf.variable_scope("discriminator"):
43+
layer = conv2d(layer, 96, 3, 2, scope="conv_1")
44+
layer = conv2d(layer, 64, 3, 2, scope="conv_2")
45+
layer = max_pool2d(layer, 3, 1, scope="max_1")
46+
layer = conv2d(layer, 32, 3, 1, scope="conv_3")
47+
layer = conv2d(layer, 32, 1, 1, scope="conv_4")
48+
layer = conv2d(layer, 2, 1, 1, activation_fn=tf.nn.softmax, scope="conv_5")
49+
return layer
50+
51+
def _build_estimation_network(self):
52+
layer = self.inputs
53+
with tf.variable_scope("estimation"):
54+
layer = conv2d(layer, 96, 3, 2, scope="conv_1")
55+
layer = conv2d(layer, 64, 3, 2, scope="conv_2")
56+
layer = max_pool2d(layer, 64, 3, scope="max_1")
57+
layer = conv2d(layer, 32, 3, 1, scope="conv_3")
58+
layer = conv2d(layer, 32, 1, 1, scope="conv_4")
59+
layer = conv2d(layer, 2, 1, 1, activation_fn=slim.softmax)
60+
return layer

0 commit comments

Comments (0)