This repository has been archived by the owner on Nov 6, 2024. It is now read-only.
forked from jgraving/DeepPoseKit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- The perfect image format seems to be a square and gray image - Added the photo which respect these constraints - Updated the classes which implement this photo (annotation_set, annotator) - Created the class "train" useful to train the network
- Loading branch information
Alberto Ursino
committed
Nov 26, 2020
1 parent
6480faa
commit 89b3db8
Showing
7 changed files
with
182 additions
and
9 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,12 @@ | ||
from deepposekit import Annotator | ||
from os.path import expanduser | ||
import glob | ||
|
||
HOME = 'C:/Users/Alberto Ursino/Desktop/IntellIj Local Files/DeepPoseKit/alberto' | ||
|
||
app = Annotator( | ||
datapath=HOME + '/deepposekit-data/datasets/dog/example_annotation_set.h5', | ||
datapath=HOME + '/deepposekit-data/datasets/dog/annotation_set.h5', | ||
dataset='images', | ||
skeleton=HOME + '/deepposekit-data/datasets/dog/skeleton.csv', | ||
shuffle_colors=False, | ||
text_scale=1) | ||
text_scale=0.7) | ||
|
||
app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import tensorflow as tf | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import glob | ||
|
||
from deepposekit.io import TrainingGenerator, DataGenerator | ||
from deepposekit.augment import FlipAxis | ||
import imgaug.augmenters as iaa | ||
import imgaug as ia | ||
|
||
from deepposekit.models import StackedHourglass | ||
from deepposekit.models import load_model | ||
|
||
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping | ||
|
||
from deepposekit.callbacks import Logger, ModelCheckpoint | ||
|
||
import time | ||
from os.path import expanduser | ||
|
||
HOME = 'C:/Users/Alberto Ursino/Desktop/IntellIj Local Files/DeepPoseKit/alberto' | ||
|
||
data_generator = DataGenerator(HOME + '/deepposekit-data/datasets/dog/annotation_set.h5') | ||
|
||
# image, keypoints = data_generator[0] | ||
# | ||
# plt.figure(figsize=(5, 5)) | ||
# image = image[0] if image.shape[-1] is 3 else image[0, ..., 0] | ||
# cmap = None if image.shape[-1] is 3 else 'gray' | ||
# plt.imshow(image, cmap=cmap, interpolation='none') | ||
# for idx, jdx in enumerate(data_generator.graph): | ||
# if jdx > -1: | ||
# plt.plot( | ||
# [keypoints[0, idx, 0], keypoints[0, jdx, 0]], | ||
# [keypoints[0, idx, 1], keypoints[0, jdx, 1]], | ||
# 'r-' | ||
# ) | ||
# plt.scatter(keypoints[0, :, 0], keypoints[0, :, 1], c=np.arange(data_generator.keypoints_shape[0]), s=50, | ||
# cmap=plt.cm.hsv, zorder=3) | ||
# | ||
# plt.show() | ||
|
||
# Augmentation | ||
|
||
augmenter = [] | ||
|
||
augmenter.append(FlipAxis(data_generator, axis=0)) # flip image up-down | ||
augmenter.append(FlipAxis(data_generator, axis=1)) # flip image left-right | ||
|
||
sometimes = [] | ||
sometimes.append(iaa.Affine(scale={"x": (0.95, 1.05), "y": (0.95, 1.05)}, | ||
translate_percent={'x': (-0.05, 0.05), 'y': (-0.05, 0.05)}, | ||
shear=(-8, 8), | ||
order=ia.ALL, | ||
cval=ia.ALL, | ||
mode=ia.ALL) | ||
) | ||
sometimes.append(iaa.Affine(scale=(0.8, 1.2), | ||
mode=ia.ALL, | ||
order=ia.ALL, | ||
cval=ia.ALL) | ||
) | ||
augmenter.append(iaa.Sometimes(0.75, sometimes)) | ||
augmenter.append(iaa.Affine(rotate=(-180, 180), | ||
mode=ia.ALL, | ||
order=ia.ALL, | ||
cval=ia.ALL) | ||
) | ||
augmenter = iaa.Sequential(augmenter) | ||
|
||
# image, keypoints = data_generator[0] | ||
# image, keypoints = augmenter(images=image, keypoints=keypoints) | ||
# plt.figure(figsize=(5,5)) | ||
# image = image[0] if image.shape[-1] is 3 else image[0, ..., 0] | ||
# cmap = None if image.shape[-1] is 3 else 'gray' | ||
# plt.imshow(image, cmap=cmap, interpolation='none') | ||
# for idx, jdx in enumerate(data_generator.graph): | ||
# if jdx > -1: | ||
# plt.plot( | ||
# [keypoints[0, idx, 0], keypoints[0, jdx, 0]], | ||
# [keypoints[0, idx, 1], keypoints[0, jdx, 1]], | ||
# 'r-' | ||
# ) | ||
# plt.scatter(keypoints[0, :, 0], keypoints[0, :, 1], c=np.arange(data_generator.keypoints_shape[0]), s=50, cmap=plt.cm.hsv, zorder=3) | ||
# | ||
# plt.show() | ||
|
||
train_generator = TrainingGenerator(generator=data_generator, | ||
downsample_factor=3, | ||
augmenter=augmenter, | ||
sigma=5, | ||
validation_split=0, | ||
use_graph=True, | ||
random_seed=1, | ||
graph_scale=1) | ||
train_generator.get_config() | ||
|
||
# n_keypoints = data_generator.keypoints_shape[0] | ||
# batch = train_generator(batch_size=1, validation=False)[0] | ||
# inputs = batch[0] | ||
# outputs = batch[1] | ||
# | ||
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 10)) | ||
# ax1.set_title('image') | ||
# ax1.imshow(inputs[0, ..., 0], cmap='gray', vmin=0, vmax=255) | ||
# | ||
# ax2.set_title('posture graph') | ||
# ax2.imshow(outputs[0, ..., n_keypoints:-1].max(-1)) | ||
# | ||
# ax3.set_title('keypoints confidence') | ||
# ax3.imshow(outputs[0, ..., :n_keypoints].max(-1)) | ||
# | ||
# ax4.set_title('posture graph and keypoints confidence') | ||
# ax4.imshow(outputs[0, ..., -1], vmin=0) | ||
# plt.show() | ||
# | ||
# train_generator.on_epoch_end() | ||
|
||
# Define a model | ||
|
||
model = StackedHourglass(train_generator) | ||
|
||
model.get_config() | ||
|
||
# data_size = (10,) + data_generator.image_shape | ||
# x = np.random.randint(0, 255, data_size, dtype="uint8") | ||
# y = model.predict(x[:100], batch_size=100) # make sure the model is in GPU memory | ||
# t0 = time.time() | ||
# y = model.predict(x, batch_size=100, verbose=1) | ||
# t1 = time.time() | ||
# print(x.shape[0] / (t1 - t0)) | ||
|
||
logger = Logger(validation_batch_size=10, | ||
# filepath saves the logger data to a .h5 file | ||
filepath=HOME + "/deepposekit-data/datasets/dog/log_densenet.h5" | ||
) | ||
|
||
# Remember, if you set validation_split=0 for your TrainingGenerator, | ||
# which will just use the training set for model fitting, | ||
# make sure to set monitor="loss" instead of monitor="val_loss". | ||
reduce_lr = ReduceLROnPlateau(monitor="loss", factor=0.2, verbose=1, patience=20) | ||
|
||
model_checkpoint = ModelCheckpoint( | ||
HOME + "/deepposekit-data/datasets/fly/best_model_densenet.h5", | ||
monitor="val_loss", | ||
# monitor="loss" # use if validation_split=0 | ||
verbose=1, | ||
save_best_only=True, | ||
) | ||
|
||
early_stop = EarlyStopping( | ||
monitor="val_loss", | ||
# monitor="loss" # use if validation_split=0 | ||
min_delta=0.001, | ||
patience=100, | ||
verbose=1 | ||
) | ||
|
||
callbacks = [early_stop, reduce_lr, model_checkpoint, logger] |
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
name,parent,swap | ||
snout,, | ||
head,snout, | ||
neck,head, | ||
forelegL1,neck,forelegR1 | ||
forelegR1,neck,forelegL1 | ||
hindlegL1,tailbase,hindlegR1 | ||
hindlegR1,tailbase,hindlegL1 | ||
tailbase,, | ||
tailtip,tailbase, |