Commit

Merge branch 'master' of github.com:sethjuarez/AttnGAN

sethjuarez committed May 1, 2018
2 parents edc33da + 78cb3d9 commit 8638971
Showing 7 changed files with 206 additions and 67 deletions.
2 changes: 1 addition & 1 deletion eval/build.cmd
@@ -1,3 +1,3 @@
docker build -t "attngan" .
docker build -t "attngan" -f dockerfile.cpu .
REM docker run -it -e BLOB_KEY=KEY -p 5678:8080 attngan
REM curl -H "Content-Type: application/json" -X POST -d '{"caption":"the bird has a yellow crown and a black eyering that is round"}' https://attgan.azurewebsites.net/api/v1.0/bird
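
For a quick local smoke test of the container started by the commented docker run line above, a small Python equivalent of the curl call is sketched below; the localhost:5678 address is an assumption taken from that comment's port mapping, not part of the commit:

import requests

payload = {"caption": "the bird has a yellow crown and a black eyering that is round"}
resp = requests.post("http://localhost:5678/api/v1.0/bird", json=payload)

print(resp.status_code)                  # the endpoint returns 201 on success
bird = resp.json()["bird"]
print(bird["large"], bird["elapsed"])    # blob URL of the largest image and generation time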
10 changes: 6 additions & 4 deletions eval/dockerfile → eval/dockerfile.cpu
@@ -4,16 +4,18 @@ RUN mkdir -p /usr/src/app
WORKDIR /usr/src/app

COPY requirements.txt /usr/src/app/
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install http://download.pytorch.org/whl/cpu/torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
RUN pip install torchvision

COPY . /usr/src/app

ENV GPU False
ENV EXPORT_MODEL True

EXPOSE 8080

CMD ["python", "main.py"]

#RUN python code/main.py --cfg code/cfg/eval_bird.yml

#RUN python main.py --cfg cfg/eval_bird.yml --gpu -1

#CMD ["sh"]
29 changes: 29 additions & 0 deletions eval/dockerfile.gpu
@@ -0,0 +1,29 @@
FROM nvidia/cuda

RUN apt-get update \
&& apt-get upgrade -y \
&& apt-get install -y \
python-pip \
python2.7 \
&& apt-get autoremove \
&& apt-get clean

RUN mkdir -p /usr/src/app
WORKDIR /usr/src/app

COPY requirements.txt /usr/src/app/
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install http://download.pytorch.org/whl/cu90/torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
RUN pip install torchvision

COPY . /usr/src/app

ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
ENV GPU True
ENV EXPORT_MODEL False

EXPOSE 8080

CMD ["python", "main.py"]
135 changes: 92 additions & 43 deletions eval/eval.py
@@ -7,6 +7,7 @@
import time
import numpy as np
from PIL import Image
import torch.onnx
from datetime import datetime
from torch.autograd import Variable
from miscc.config import cfg
@@ -22,7 +23,7 @@
from werkzeug.contrib.cache import SimpleCache
cache = SimpleCache()

def vectorize_caption(wordtoix, caption):
def vectorize_caption(wordtoix, caption, copies=2):
# create caption vector
tokens = caption.split(' ')
cap_v = []
@@ -32,34 +33,68 @@ def vectorize_caption(wordtoix, caption):
cap_v.append(wordtoix[t])

# expected state for single generation
captions, cap_lens = np.array([cap_v, cap_v]), np.array([len(cap_v), len(cap_v)])
captions = np.zeros((copies, len(cap_v)))
for i in range(copies):
captions[i,:] = np.array(cap_v)
cap_lens = np.zeros(copies) + len(cap_v)

return captions, cap_lens
#print(captions.astype(int), cap_lens.astype(int))
#captions, cap_lens = np.array([cap_v, cap_v]), np.array([len(cap_v), len(cap_v)])
#print(captions, cap_lens)
#return captions, cap_lens

def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service):
return captions.astype(int), cap_lens.astype(int)

def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copies=2):
# load word vector
captions, cap_lens = vectorize_caption(wordtoix, caption)
captions, cap_lens = vectorize_caption(wordtoix, caption, copies)
n_words = len(wordtoix)

# only one to generate
batch_size = captions.shape[0]

nz = cfg.GAN.Z_DIM
captions = Variable(torch.from_numpy(captions), volatile=True)
cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)

if cfg.CUDA:
captions = captions.cuda()
cap_lens = cap_lens.cuda()
noise = noise.cuda()



#######################################################
# (1) Extract text embeddings
#######################################################
hidden = text_encoder.init_hidden(batch_size)
words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
mask = (captions == 0)


#######################################################
# (2) Generate fake images
#######################################################
noise.data.normal_(0, 1)
fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)

# ONNX EXPORT
#export = os.environ["EXPORT_MODEL"].lower() == 'true'
if False:
print("saving text_encoder.onnx")
text_encoder_out = torch.onnx._export(text_encoder, (captions, cap_lens, hidden), "text_encoder.onnx", export_params=True)
print("uploading text_encoder.onnx")
blob_service.create_blob_from_path('models', "text_encoder.onnx", os.path.abspath("text_encoder.onnx"))
print("done")

print("saving netg.onnx")
netg_out = torch.onnx._export(netG, (noise, sent_emb, words_embs, mask), "netg.onnx", export_params=True)
print("uploading netg.onnx")
blob_service.create_blob_from_path('models', "netg.onnx", os.path.abspath("netg.onnx"))
print("done")
return

# G attention
cap_lens_np = cap_lens.cpu().data.numpy()

@@ -69,53 +104,63 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service):
prefix = datetime.now().strftime('%Y/%B/%d/%H_%M_%S_%f')
urls = []
# only look at first one
j = 0
for k in range(len(fake_imgs)):
im = fake_imgs[k][j].data.cpu().numpy()
im = (im + 1.0) * 127.5
im = im.astype(np.uint8)
im = np.transpose(im, (1, 2, 0))
im = Image.fromarray(im)

# save image to stream
stream = io.BytesIO()
im.save(stream, format="png")
stream.seek(0)

blob_name = '%s/%s_g%d.png' % (prefix, "bird", k)
blob_service.create_blob_from_stream(container_name, blob_name, stream)
urls.append(full_path % blob_name)

for k in range(len(attention_maps)):
if len(fake_imgs) > 1:
im = fake_imgs[k + 1].detach().cpu()
else:
im = fake_imgs[0].detach().cpu()
attn_maps = attention_maps[k]
att_sze = attn_maps.size(2)
img_set, sentences = \
build_super_images2(im[j].unsqueeze(0),
captions[j].unsqueeze(0),
[cap_lens_np[j]], ixtoword,
[attn_maps[j]], att_sze)
if img_set is not None:
im = Image.fromarray(img_set)
#j = 0
for j in range(batch_size):
for k in range(len(fake_imgs)):
im = fake_imgs[k][j].data.cpu().numpy()
im = (im + 1.0) * 127.5
im = im.astype(np.uint8)
im = np.transpose(im, (1, 2, 0))
im = Image.fromarray(im)

# save image to stream
stream = io.BytesIO()
im.save(stream, format="png")
stream.seek(0)

blob_name = '%s/%s_a%d.png' % (prefix, "attmaps", k)
if copies > 2:
blob_name = '%s/%d/%s_g%d.png' % (prefix, j, "bird", k)
else:
blob_name = '%s/%s_g%d.png' % (prefix, "bird", k)
blob_service.create_blob_from_stream(container_name, blob_name, stream)
urls.append(full_path % blob_name)

if copies == 2:
for k in range(len(attention_maps)):
#if False:
if len(fake_imgs) > 1:
im = fake_imgs[k + 1].detach().cpu()
else:
im = fake_imgs[0].detach().cpu()

attn_maps = attention_maps[k]
att_sze = attn_maps.size(2)

img_set, sentences = \
build_super_images2(im[j].unsqueeze(0),
captions[j].unsqueeze(0),
[cap_lens_np[j]], ixtoword,
[attn_maps[j]], att_sze)

if img_set is not None:
im = Image.fromarray(img_set)
stream = io.BytesIO()
im.save(stream, format="png")
stream.seek(0)

blob_name = '%s/%s_a%d.png' % (prefix, "attmaps", k)
blob_service.create_blob_from_stream(container_name, blob_name, stream)
urls.append(full_path % blob_name)
if copies == 2:
break

#print(len(urls), urls)
return urls

def word_index():

ixtoword = cache.get('ixtoword')
wordtoix = cache.get('wordtoix')
if ixtoword is None or wordtoix is None:
print("ix and word not cached")
#print("ix and word not cached")
# load word to index dictionary
x = pickle.load(open('data/captions.pickle', 'rb'))
ixtoword = x[2]
@@ -127,22 +172,26 @@ def word_index():
return wordtoix, ixtoword

def models(word_len):

#print(word_len)
text_encoder = cache.get('text_encoder')
if text_encoder is None:
print("text_encoder not cached")
#print("text_encoder not cached")
text_encoder = RNN_ENCODER(word_len, nhidden=cfg.TEXT.EMBEDDING_DIM)
state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
text_encoder.load_state_dict(state_dict)
if cfg.CUDA:
text_encoder.cuda()
text_encoder.eval()
cache.set('text_encoder', text_encoder, timeout=60 * 60 * 24)

netG = cache.get('netG')
if netG is None:
print("netG not cached")
#print("netG not cached")
netG = G_NET()
state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
netG.load_state_dict(state_dict)
if cfg.CUDA:
netG.cuda()
netG.eval()
cache.set('netG', netG, timeout=60 * 60 * 24)

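For reference, a self-contained sketch of the new copies-aware caption vectorization is shown below. The toy word index is invented for illustration; in the service the real mapping is loaded from data/captions.pickle:

import numpy as np

def vectorize_caption(wordtoix, caption, copies=2):
    # map each known token to its index
    cap_v = [wordtoix[t] for t in caption.split(' ') if len(t) > 0 and t in wordtoix]

    # replicate the same caption vector `copies` times so the generator
    # produces that many samples in a single batch
    captions = np.zeros((copies, len(cap_v)))
    for i in range(copies):
        captions[i, :] = np.array(cap_v)
    cap_lens = np.zeros(copies) + len(cap_v)

    return captions.astype(int), cap_lens.astype(int)

toy_index = {'the': 1, 'bird': 2, 'has': 3, 'a': 4, 'yellow': 5, 'crown': 6}
caps, lens = vectorize_caption(toy_index, 'the bird has a yellow crown', copies=3)
print(caps.shape, lens)   # (3, 6) [6 6 6]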
88 changes: 73 additions & 15 deletions eval/main.py
@@ -1,35 +1,93 @@
#!flask/bin/python
import os
import time
import random
from eval import *
from flask import Flask, jsonify, request, abort
from applicationinsights import TelemetryClient
from applicationinsights.requests import WSGIApplication
from applicationinsights.exceptions import enable
from miscc.config import cfg
from werkzeug.contrib.profiler import ProfilerMiddleware
#from werkzeug.contrib.profiler import ProfilerMiddleware

enable(os.environ["TELEMETRY"])
app = Flask(__name__)


# load word dictionaries
wordtoix, ixtoword = word_index()
# load models
text_encoder, netG = models(len(wordtoix))
# load blob service
blob_service = BlockBlobService(account_name='attgan', account_key=os.environ["BLOB_KEY"])
app.wsgi_app = WSGIApplication(os.environ["TELEMETRY"], app.wsgi_app)

@app.route('/api/v1.0/bird', methods=['POST'])
def create_bird():
if not request.json or not 'caption' in request.json:
abort(400)

response = eval(request.json['caption'])
caption = request.json['caption']

t0 = time.time()
urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service)
t1 = time.time()

response = {
'small': urls[0],
'medium': urls[1],
'large': urls[2],
'map1': urls[3],
'map2': urls[4],
'caption': caption,
'elapsed': t1 - t0
}
return jsonify({'bird': response}), 201

@app.route('/api/v1.0/birds', methods=['POST'])
def create_birds():
if not request.json or not 'caption' in request.json:
abort(400)

caption = request.json['caption']

t0 = time.time()
urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copies=6)
t1 = time.time()

response = {
'bird1' : { 'small': urls[0], 'medium': urls[1], 'large': urls[2] },
'bird2' : { 'small': urls[3], 'medium': urls[4], 'large': urls[5] },
'bird3' : { 'small': urls[6], 'medium': urls[7], 'large': urls[8] },
'bird4' : { 'small': urls[9], 'medium': urls[10], 'large': urls[11] },
'bird5' : { 'small': urls[12], 'medium': urls[13], 'large': urls[14] },
'bird6' : { 'small': urls[15], 'medium': urls[16], 'large': urls[17] },
'caption': caption,
'elapsed': t1 - t0
}
return jsonify({'bird': response}), 201

@app.route('/', methods=['GET'])
def get_bird():
return 'hello!'
return 'Version 1'

if __name__ == '__main__':
app.config['PROFILE'] = True
app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30])
app.run(host='0.0.0.0', port=8080, debug=True)
t0 = time.time()
tc = TelemetryClient(os.environ["TELEMETRY"])

# gpu based
cfg.CUDA = os.environ["GPU"].lower() == 'true'
tc.track_event('container initializing', {"CUDA": str(cfg.CUDA)})

# load word dictionaries
wordtoix, ixtoword = word_index()
# load models
text_encoder, netG = models(len(wordtoix))
# load blob service
blob_service = BlockBlobService(account_name='attgan', account_key=os.environ["BLOB_KEY"])

seed = 100
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if cfg.CUDA:
torch.cuda.manual_seed_all(seed)

#app.config['PROFILE'] = True
#app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30])
#app.run(host='0.0.0.0', port=8080, debug = True)

t1 = time.time()
tc.track_event('container start', {"starttime": str(t1-t0)})
app.run(host='0.0.0.0', port=8080)
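
A minimal client sketch for the new /api/v1.0/birds route is shown below; it assumes the service is reachable on localhost:8080 (the port the app binds above, with an example caption) and simply walks the six-image response shape defined in create_birds:

import requests

payload = {"caption": "this small bird has a blue crown and a white belly"}
resp = requests.post("http://localhost:8080/api/v1.0/birds", json=payload)

body = resp.json()["bird"]
for key in ("bird1", "bird2", "bird3", "bird4", "bird5", "bird6"):
    print(key, body[key]["large"])           # blob URL of the full-size image
print("generated in %.2f s" % body["elapsed"])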
(remaining changed files not shown)