Skip to content

Commit

Permalink
PythonAPI: more keypoint support (showAnns, loadRes)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylin committed Sep 7, 2016
1 parent af29c4f commit 1a32502
Showing 1 changed file with 42 additions and 47 deletions.
89 changes: 42 additions & 47 deletions PythonAPI/pycocotools/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,17 @@
# Licensed under the Simplified BSD License [see bsd.txt]

import json
import datetime
import time
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
import numpy as np
from skimage.draw import polygon
import urllib
import copy
import itertools
import mask
import os
from collections import defaultdict

class COCO:
def __init__(self, annotation_file=None):
Expand All @@ -67,12 +66,8 @@ def __init__(self, annotation_file=None):
:return:
"""
# load dataset
self.dataset = {}
self.anns = []
self.imgToAnns = {}
self.catToImgs = {}
self.imgs = {}
self.cats = {}
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
if not annotation_file == None:
print 'loading annotations into memory...'
tic = time.time()
Expand All @@ -84,30 +79,22 @@ def __init__(self, annotation_file=None):
def createIndex(self):
# create index
print 'creating index...'
anns = {}
imgToAnns = {}
catToImgs = {}
cats = {}
imgs = {}
anns,cats,imgs = dict(),dict(),dict()
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
if 'annotations' in self.dataset:
imgToAnns = {ann['image_id']: [] for ann in self.dataset['annotations']}
anns = {ann['id']: [] for ann in self.dataset['annotations']}
for ann in self.dataset['annotations']:
imgToAnns[ann['image_id']] += [ann]
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann

if 'images' in self.dataset:
imgs = {im['id']: {} for im in self.dataset['images']}
for img in self.dataset['images']:
imgs[img['id']] = img

if 'categories' in self.dataset:
cats = {cat['id']: [] for cat in self.dataset['categories']}
for cat in self.dataset['categories']:
cats[cat['id']] = cat
catToImgs = {cat['id']: [] for cat in self.dataset['categories']}
for ann in self.dataset['annotations']:
catToImgs[ann['category_id']] += [ann['image_id']]
catToImgs[ann['category_id']].append(ann['image_id'])

print 'index created!'

Expand Down Expand Up @@ -142,7 +129,6 @@ def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
anns = self.dataset['annotations']
else:
if not len(imgIds) == 0:
# this can be changed by defaultdict
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
anns = list(itertools.chain.from_iterable(lists))
else:
Expand Down Expand Up @@ -239,39 +225,42 @@ def showAnns(self, anns):
"""
if len(anns) == 0:
return 0
if 'segmentation' in anns[0]:
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
datasetType = 'instances'
elif 'caption' in anns[0]:
datasetType = 'captions'
else:
raise Exception("datasetType not supported")
if datasetType == 'instances':
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in anns:
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
if type(ann['segmentation']) == list:
# polygon
for seg in ann['segmentation']:
poly = np.array(seg).reshape((len(seg)/2, 2))
polygons.append(Polygon(poly))
color.append(c)
else:
# mask
t = self.imgs[ann['image_id']]
if type(ann['segmentation']['counts']) == list:
rle = mask.frPyObjects([ann['segmentation']], t['height'], t['width'])
if 'segmentation' in ann:
if type(ann['segmentation']) == list:
# polygon
for seg in ann['segmentation']:
poly = np.array(seg).reshape((len(seg)/2, 2))
polygons.append(Polygon(poly))
color.append(c)
else:
rle = [ann['segmentation']]
m = mask.decode(rle)
img = np.ones( (m.shape[0], m.shape[1], 3) )
if ann['iscrowd'] == 1:
color_mask = np.array([2.0,166.0,101.0])/255
if ann['iscrowd'] == 0:
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack( (img, m*0.5) ))
# mask
t = self.imgs[ann['image_id']]
if type(ann['segmentation']['counts']) == list:
rle = mask.frPyObjects([ann['segmentation']], t['height'], t['width'])
else:
rle = [ann['segmentation']]
m = mask.decode(rle)
img = np.ones( (m.shape[0], m.shape[1], 3) )
if ann['iscrowd'] == 1:
color_mask = np.array([2.0,166.0,101.0])/255
if ann['iscrowd'] == 0:
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack( (img, m*0.5) ))
if 'keypoints' in ann and type(ann['keypoints']) == list:
# turn skeleton into zero-based index
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
Expand All @@ -282,8 +271,8 @@ def showAnns(self, anns):
for sk in sks:
if np.all(v[sk]>0):
plt.plot(x[sk],y[sk], linewidth=3, color=c)
plt.plot(x[v==1], y[v==1],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
plt.plot(x[v==2], y[v==2],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
ax.add_collection(p)
p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
Expand All @@ -300,8 +289,6 @@ def loadRes(self, resFile):
"""
res = COCO()
res.dataset['images'] = [img for img in self.dataset['images']]
# res.dataset['info'] = copy.deepcopy(self.dataset['info'])
# res.dataset['licenses'] = copy.deepcopy(self.dataset['licenses'])

print 'Loading and preparing results... '
tic = time.time()
Expand Down Expand Up @@ -339,6 +326,14 @@ def loadRes(self, resFile):
ann['bbox'] = mask.toBbox([ann['segmentation']])[0]
ann['id'] = id+1
ann['iscrowd'] = 0
elif 'keypoints' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
s = ann['keypoints']
x = s[0::3]
y = s[1::3]
ann['area'] = float((np.max(x)-np.min(x))*(np.max(y)-np.min(y)))
ann['id'] = id + 1
print 'DONE (t=%0.2fs)'%(time.time()- tic)

res.dataset['annotations'] = anns
Expand Down

0 comments on commit 1a32502

Please sign in to comment.