forked from levihsu/OOTDiffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
eb22946
commit 3fa73c0
Showing
3 changed files
with
289 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
import numpy as np | ||
|
||
class Image2VideoWriter():
    """Collect frames in memory and encode them into a single video file.

    Frames are queued with :meth:`append` and written by :meth:`make_video`.
    """

    def __init__(self):
        # Frames in insertion order; each entry becomes one video frame.
        self.image_list = []

    def append(self, image):
        """Queue ``image`` as the next frame of the video."""
        self.image_list.append(image)

    def make_video(self, outvid=None, fps=5, size=None,
                   is_color=True, format="MP4V", isRGB=False):
        """
        Create a video from the frames queued with ``append``.

        The frames are first written with OpenCV using the ``format`` codec,
        then the file is re-encoded in place to H.264 with the ``ffmpeg``
        command-line tool. If ffmpeg is missing or fails, the OpenCV-encoded
        file is kept and a warning is issued.

        @param outvid    path of the output video (required)
        @param fps       frames per second
        @param size      (width, height) of each frame; defaults to the size
                         of the first frame, halved when width + height
                         exceeds 3000
        @param is_color  whether the frames are color images
        @param format    FourCC codec, see http://www.fourcc.org/codecs.php
        @param isRGB     set when frames are RGB; they are converted to the
                         BGR channel order before writing
        @return          the (already released) cv2.VideoWriter
        @raise ValueError  if no frames were appended or ``outvid`` is None

        Every frame whose size differs from ``size`` is resized before it is
        added to the video.
        """
        if not self.image_list:
            raise ValueError("no frames have been appended; nothing to encode")
        if outvid is None:
            raise ValueError("outvid must be the path of the output video")

        import subprocess
        import warnings
        from cv2 import VideoWriter, VideoWriter_fourcc, resize

        fourcc = VideoWriter_fourcc(*format)
        vid = None
        try:
            for img in self.image_list:
                if isRGB:
                    # Swap the first and third channels (RGB -> BGR).
                    img = img[:, :, [2, 1, 0]]
                if vid is None:
                    if size is None:
                        size = img.shape[1], img.shape[0]
                        # NOTE(review): the halving is applied only to the
                        # auto-derived size so an explicit ``size`` is
                        # honoured -- confirm against callers.
                        if size[0] + size[1] > 3000:
                            size = img.shape[1] // 2, img.shape[0] // 2
                    vid = VideoWriter(outvid, fourcc, float(fps), size, is_color)
                # Resize when EITHER dimension differs from the target size.
                if size[0] != img.shape[1] or size[1] != img.shape[0]:
                    img = resize(img, size)
                vid.write(img)
        finally:
            if vid is not None:
                vid.release()

        # Re-encode to H.264. Arguments are passed as a list so paths with
        # spaces or shell metacharacters are handled safely.
        path, name = os.path.split(outvid)
        temp_path = os.path.join(path, name.split('.')[0] + 'temp.mp4')
        try:
            result = subprocess.run(
                ["ffmpeg", "-y", "-i", outvid, "-vcodec", "libx264", temp_path]
            )
            encoded = result.returncode == 0 and os.path.exists(temp_path)
        except OSError:
            # ffmpeg is not installed or could not be started.
            encoded = False

        if encoded:
            os.replace(temp_path, outvid)
        else:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            warnings.warn(
                "ffmpeg re-encoding failed; keeping the OpenCV-encoded video"
            )
        return vid
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from PIL import Image | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
def crop2_43(img: Image.Image):
    """Return ``img`` reshaped by ``ImageReshaper.get_reshaped``."""
    return ImageReshaper(img).get_reshaped()
|
||
|
||
class ImageReshaper:
    """Warp an image onto a fixed-size canvas and paste results back.

    The forward and inverse affine transforms come from ``crop2_43_trans``.
    ``get_reshaped`` produces the canvas-sized image; ``back2rawSahpe`` maps a
    canvas-sized image back into the original frame and composites it over
    the original pixels.
    """

    # Working canvas size in pixels (rows, columns).
    TARGET_H = 1024
    TARGET_W = 768

    def __init__(self, img: Image.Image):
        """Store ``img`` and precompute its transforms and paste-back mask."""
        self.img = img
        self.trans, self.inv_trans = crop2_43_trans(self.img)
        w, h = self.img.size  # PIL reports (width, height)
        # Boolean mask of the original-image pixels covered by the canvas.
        self.trans_mask = self.get_trans_mask(self.inv_trans, [h, w])

    def get_reshaped(self):
        """Return the stored image warped onto the canvas as a PIL image.

        Canvas pixels not covered by the source image are filled with black.
        """
        img = np.array(self.img)
        trans_img = cv2.warpAffine(img, self.trans,
                                   (self.TARGET_W, self.TARGET_H),
                                   flags=cv2.INTER_LINEAR,
                                   borderMode=cv2.BORDER_CONSTANT,
                                   borderValue=(0, 0, 0))
        return Image.fromarray(trans_img)

    def back2rawSahpe(self, img):
        """Map a canvas-sized ``img`` back onto the original image.

        ``img`` may be a NumPy array or anything ``np.asarray`` accepts (such
        as the PIL image returned by ``get_reshaped``). Pixels outside the
        transformed region keep their original values.

        Returns the composited image as a NumPy array in the original shape.
        """
        new_img = np.asarray(img)
        w, h = self.img.size
        raw_new_img = self.roi2raw(new_img, self.inv_trans, [h, w])
        # np.array() already yields a fresh copy we are free to modify.
        composed = np.array(self.img)
        composed[self.trans_mask] = raw_new_img[self.trans_mask]
        return composed

    # Correctly spelled alias; the original name is kept for existing callers.
    back2rawShape = back2rawSahpe

    def roi2raw(self, img, trans, raw_shape):
        """Warp ``img`` with ``trans`` into a frame of ``raw_shape`` (h, w).

        Border pixels are replicated rather than filled with a constant.
        """
        trans_img = cv2.warpAffine(img, trans, (raw_shape[1], raw_shape[0]),
                                   flags=cv2.INTER_LINEAR,
                                   borderMode=cv2.BORDER_REPLICATE)
        return trans_img

    def get_trans_mask(self, inv_trans, raw_shape):
        """Return a boolean mask of the pixels of a ``raw_shape`` (h, w) frame
        that the canvas covers after warping with ``inv_trans``."""
        mask = np.ones([self.TARGET_H, self.TARGET_W]).astype(np.uint8)
        roi_mask = cv2.warpAffine(mask, inv_trans, (raw_shape[1], raw_shape[0]),
                                  flags=cv2.INTER_LINEAR,
                                  borderMode=cv2.BORDER_CONSTANT,
                                  borderValue=0)
        return roi_mask.astype(bool)
|
||
|
||
def crop2_43_trans(img: Image.Image):
    """Build the affine transforms between ``img`` and a 768x1024 canvas.

    The largest centered region of ``img`` with a 3:4 (width:height) aspect
    ratio is mapped onto a canvas 768 pixels wide and 1024 pixels high.

    Returns ``(trans, inv_trans)``: the 2x3 matrix mapping image coordinates
    to canvas coordinates, and the 2x3 matrix for the opposite direction.
    """
    # PIL reports size as (width, height).
    raw_w, raw_h = img.size
    dst_w, dst_h = 768, 1024

    if 3 * raw_h > 4 * raw_w:  # too tall: trim equally from top and bottom
        delta = (raw_h - raw_w * (4 / 3)) / 2
        left, right = 0, raw_w
        top, bottom = delta, raw_h - delta
    else:  # too wide (or already 3:4): trim equally from left and right
        delta = (raw_w - raw_h * (3 / 4)) / 2
        left, right = delta, raw_w - delta
        top, bottom = 0, raw_h

    # OpenCV expects control points in (x, y) order:
    # top-left, bottom-left and bottom-right corners of the crop / canvas.
    src = np.float32([[left, top], [left, bottom], [right, bottom]])
    dst = np.float32([[0, 0], [0, dst_h], [dst_w, dst_h]])

    trans = cv2.getAffineTransform(src, dst)
    inv_trans = cv2.getAffineTransform(dst, src)
    return trans, inv_trans
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import cv2 | ||
import numpy as np | ||
# from util.garment_heatmap import HeatmapGenerator | ||
import torch | ||
import torchvision.transforms as transforms | ||
import ffmpeg | ||
from OpticalFlow.optical_flow import OpticalFlow | ||
|
||
class VideoLoader:
    """Load every frame of a video into memory and serve 512x512 frames.

    Frames are stored as read by OpenCV. When served, a frame is cropped to
    the bounding box (``min_h``/``max_h``/``min_w``/``max_w``), zero-padded by
    ``l``/``r``/``u``/``d`` pixels and resized to 512x512.
    """

    def __init__(self, path):
        """Read the video at ``path`` and set up default crop and padding."""
        self.path = path
        self.frames = self.load_video()
        # Default bounding box: the whole frame, center-cropped to a square
        # when the video is wider than it is tall.
        self.min_h = 0
        self.min_w = 0
        self.max_h = self.frameHeight
        self.max_w = self.frameWidth
        self.crop2square()
        # Default padding: none, except left/right padding that squares up a
        # video that is taller than it is wide.
        self.l = 0
        self.r = 0
        self.u = 0
        self.d = 0
        if self.frameHeight > self.frameWidth:
            pad_total = self.frameHeight - self.frameWidth
            self.l = pad_total // 2
            # Give the odd pixel to the right side so the result is square.
            self.r = pad_total - self.l
        # self.heatmap_gen = HeatmapGenerator()
        self.post_transform = transforms.Resize((512, 512))
        self.opt_flow = None
        self.optical_flow = OpticalFlow()

    def compute_opt_flow(self):
        """Compute the flow between each pair of consecutive raw frames.

        The results are moved to the CPU and stored as a list in
        ``self.opt_flow`` (one entry per consecutive pair).
        """
        print("Start computing optical flow")
        opt_flow_list = []
        with torch.no_grad():
            for i in range(len(self) - 1):
                opt_flow = self.optical_flow(self.frames[i], self.frames[i + 1]).cpu()
                opt_flow_list.append(opt_flow)
        self.opt_flow = opt_flow_list
        print("Finish computing optical flow:", len(self.opt_flow))

    def crop2square(self):
        """Center the bounding box horizontally when width exceeds height."""
        if self.frameWidth > self.frameHeight:
            offset = (self.frameWidth - self.frameHeight) // 2
            self.min_w = offset
            self.max_w = offset + self.frameHeight

    def __getitem__(self, idx):
        """Return frame ``idx`` as produced by :meth:`get_image`."""
        # Heatmap generation is disabled, so only the image is returned.
        return self.get_image(idx)

    def __len__(self):
        """Return the number of loaded frames."""
        return self.frames.shape[0]

    def set_bbox(self, min_h, min_w, max_h, max_w):
        """Set the crop box applied to each raw frame (pixel indices)."""
        self.min_h = min_h
        self.min_w = min_w
        self.max_h = max_h
        self.max_w = max_w

    def set_padding(self, l, r, u, d):
        """Set the zero padding (left, right, up, down) added after cropping."""
        self.l = l
        self.r = r
        self.u = u
        self.d = d

    def get_image(self, idx):
        """Return frame ``idx`` as a float tensor of shape (1, C, 512, 512).

        Values are scaled by 1/255, and the tensor is moved to the GPU when
        CUDA is available.
        """
        frame = self.get_numpy_image(idx)
        img = torch.from_numpy(frame) / 255.0
        img = img.permute(2, 0, 1)  # CHW, BGR
        if torch.cuda.is_available():
            img = img.cuda()
        img = img.unsqueeze(0)
        img = self.post_transform(img)
        return img

    def get_numpy_image(self, idx):
        """Return frame ``idx`` cropped, zero-padded and resized to 512x512."""
        frame = self.frames[idx]
        frame = frame[self.min_h:self.max_h, self.min_w:self.max_w, :]
        # Non-positive padding values are ignored.
        pad_u, pad_d = max(self.u, 0), max(self.d, 0)
        pad_l, pad_r = max(self.l, 0), max(self.r, 0)
        if pad_u or pad_d or pad_l or pad_r:
            frame = np.pad(frame, ((pad_u, pad_d), (pad_l, pad_r), (0, 0)))
        frame = cv2.resize(frame, dsize=(512, 512), interpolation=cv2.INTER_CUBIC)
        return frame

    def get_raw_numpy_image(self, idx):
        """Return frame ``idx`` exactly as it was read from the video."""
        return self.frames[idx]

    def get_heatmap(self, idx):
        """Heatmaps are unavailable while the heatmap generator is disabled.

        ``__getitem__`` returns only the image, so there is nothing to
        return here.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "heatmap generation is disabled; __getitem__ returns only the image"
        )

    def get_motor(self, idx):
        """Return a zero vector of length 6 (on the GPU when available)."""
        return torch.zeros(6).cuda() if torch.cuda.is_available() else torch.zeros(6)

    def check_rotation(self, path_video_file):
        """Return the rotation tag of the first stream, rounded to 90 degrees.

        Returns 0 when the file has no streams or no ``rotate`` tag.
        """
        # ffmpeg.probe returns the file's metadata as a dictionary; the value
        # of interest is meta_dict['streams'][0]['tags']['rotate'].
        meta_dict = ffmpeg.probe(path_video_file)
        streams = meta_dict.get('streams') or [{}]
        rotate = streams[0].get('tags', {}).get('rotate', 0)
        return round(int(rotate) / 90.0) * 90

    def load_video(self):
        """Read all frames of ``self.path`` into one array.

        Also records ``frameCount``, ``frameWidth`` and ``frameHeight`` as
        reported by OpenCV.

        Returns:
            Array of shape (num_frames, height, width, channels).

        Raises:
            IOError: if the video cannot be opened or has no readable frames.
        """
        # rotateCode = self.check_rotation(self.path)
        cap = cv2.VideoCapture(self.path)
        if not cap.isOpened():
            cap.release()
            raise IOError(self.path + ":video load failed!")

        frame_list = []
        try:
            self.frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            self.frameWidth = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.frameHeight = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Honour the reported frame count when it is positive; otherwise
            # read until the stream is exhausted.
            limit = self.frameCount if self.frameCount > 0 else None
            while limit is None or len(frame_list) < limit:
                ret, frame = cap.read()
                if not ret or frame is None:
                    break
                frame_list.append(frame)
        finally:
            # Release the capture even if reading fails part-way.
            cap.release()

        if not frame_list:
            raise IOError(self.path + ":video contains no readable frames!")
        return np.stack(frame_list, 0)
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: load a sample video and preview one processed frame.
    import matplotlib.pyplot as plt

    path = './videos/garment_test.mov'
    video_loader = VideoLoader(path)
    print(video_loader.frames.shape)
    print(len(video_loader))

    video_loader.set_bbox(0, 180, 720, 1280 - 180)
    # Stay in range for clips shorter than 201 frames.
    frame_idx = min(200, len(video_loader) - 1)
    # imshow needs an HxWxC array, not the batched tensor from __getitem__;
    # reverse the channel axis because OpenCV frames are BGR.
    preview = video_loader.get_numpy_image(frame_idx)
    plt.imshow(preview[:, :, ::-1])
    plt.show()