Skip to content

Commit

Permalink
Merge pull request ericjang#5 from iamgroot42/master
Browse files Browse the repository at this point in the history
Changes for Tensorflow API1. thanks @iamgroot42 !
  • Loading branch information
ericjang authored Apr 9, 2017
2 parents 00c55a4 + cbb82ea commit c384dbe
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@

x = tf.placeholder(tf.float32,shape=(batch_size,img_size)) # input (batch_size * img_size)
e=tf.random_normal((batch_size,z_size), mean=0, stddev=1) # Qsampler noise
lstm_enc = tf.nn.rnn_cell.LSTMCell(enc_size, state_is_tuple=True) # encoder Op
lstm_dec = tf.nn.rnn_cell.LSTMCell(dec_size, state_is_tuple=True) # decoder Op
lstm_enc = tf.contrib.rnn.LSTMCell(enc_size, state_is_tuple=True) # encoder Op
lstm_dec = tf.contrib.rnn.LSTMCell(dec_size, state_is_tuple=True) # decoder Op

def linear(x,output_dim):
"""
Expand Down Expand Up @@ -73,7 +73,8 @@ def filterbank(gx, gy, sigma2,delta, N):
def attn_window(scope,h_dec,N):
with tf.variable_scope(scope,reuse=DO_SHARE):
params=linear(h_dec,5)
gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(1,5,params)
# gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(1,5,params)
gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(params,5,1)
gx=(A+1)/2*(gx_+1)
gy=(B+1)/2*(gy_+1)
sigma2=tf.exp(log_sigma2)
Expand All @@ -82,19 +83,19 @@ def attn_window(scope,h_dec,N):

## READ ##
def read_no_attn(x,x_hat,h_dec_prev):
return tf.concat(1,[x,x_hat])
return tf.concat([x,x_hat], 1)

def read_attn(x,x_hat,h_dec_prev):
Fx,Fy,gamma=attn_window("read",h_dec_prev,read_n)
def filter_img(img,Fx,Fy,gamma,N):
Fxt=tf.transpose(Fx,perm=[0,2,1])
img=tf.reshape(img,[-1,B,A])
glimpse=tf.batch_matmul(Fy,tf.batch_matmul(img,Fxt))
glimpse=tf.matmul(Fy,tf.matmul(img,Fxt))
glimpse=tf.reshape(glimpse,[-1,N*N])
return glimpse*tf.reshape(gamma,[-1,1])
x=filter_img(x,Fx,Fy,gamma,read_n) # batch x (read_n*read_n)
x_hat=filter_img(x_hat,Fx,Fy,gamma,read_n)
return tf.concat(1,[x,x_hat]) # concat along feature axis
return tf.concat([x,x_hat], 1) # concat along feature axis

read = read_attn if FLAGS.read_attn else read_no_attn

Expand Down Expand Up @@ -140,7 +141,7 @@ def write_attn(h_dec):
w=tf.reshape(w,[batch_size,N,N])
Fx,Fy,gamma=attn_window("write",h_dec,write_n)
Fyt=tf.transpose(Fy,perm=[0,2,1])
wr=tf.batch_matmul(Fyt,tf.batch_matmul(w,Fx))
wr=tf.matmul(Fyt,tf.matmul(w,Fx))
wr=tf.reshape(wr,[batch_size,B*A])
#gamma=tf.tile(gamma,[1,B*A])
return wr*tf.reshape(1.0/gamma,[-1,1])
Expand All @@ -163,7 +164,7 @@ def write_attn(h_dec):
c_prev = tf.zeros((batch_size,img_size)) if t==0 else cs[t-1]
x_hat=x-tf.sigmoid(c_prev) # error image
r=read(x,x_hat,h_dec_prev)
h_enc,enc_state=encode(enc_state,tf.concat(1,[r,h_dec_prev]))
h_enc,enc_state=encode(enc_state,tf.concat([r,h_dec_prev], 1))
z,mus[t],logsigmas[t],sigmas[t]=sampleQ(h_enc)
h_dec,dec_state=decode(dec_state,z)
cs[t]=c_prev+write(h_dec) # store results
Expand Down Expand Up @@ -217,7 +218,7 @@ def binary_crossentropy(t,o):
sess=tf.InteractiveSession()

saver = tf.train.Saver() # saves variables learned during training
tf.initialize_all_variables().run()
tf.global_variables_initializer().run()
#saver.restore(sess, "/tmp/draw/drawmodel.ckpt") # to restore from model, uncomment this line

for i in range(train_iters):
Expand All @@ -241,5 +242,3 @@ def binary_crossentropy(t,o):
print("Model saved in file: %s" % saver.save(sess,ckpt_file))

sess.close()

print('Done drawing! Have a nice day! :)')

0 comments on commit c384dbe

Please sign in to comment.