Skip to content

Commit 6b1c214

Browse files
committed
Pull missing changes from develop
1 parent 8aa260a commit 6b1c214

File tree

3 files changed

+288
-7
lines changed

3 files changed

+288
-7
lines changed

gdl/datasets/FaceAlignmentTools.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import numpy as np
2+
from pathlib import Path
3+
from gdl.datasets.ImageDatasetHelpers import bbox2point, bbpoint_warp
4+
import skvideo
5+
import types
6+
7+
8+
def align_face(image, landmarks, landmark_type, scale_adjustment, target_size_height, target_size_width=None,):
9+
"""
10+
Returns an image with the face aligned to the center of the image.
11+
:param image: The full resolution image in which to align the face.
12+
:param landmarks: The landmarks of the face in the image (in the original image coordinates).
13+
:param landmark_type: The type of landmarks. Such as 'kpt68' or 'bbox' or 'mediapipe'.
14+
:param scale_adjustment: The scale adjustment to apply to the image.
15+
:param target_size_height: The height of the output image.
16+
:param target_size_width: The width of the output image. If not provided, it is assumed to be the same as target_size_height.
17+
:return: The aligned face image. The image will be in range [0,1].
18+
"""
19+
# landmarks_for_alignment = "mediapipe"
20+
left = landmarks[:,0].min()
21+
top = landmarks[:,1].min()
22+
right = landmarks[:,0].max()
23+
bottom = landmarks[:,1].max()
24+
25+
old_size, center = bbox2point(left, right, top, bottom, type=landmark_type)
26+
size = (old_size * scale_adjustment).astype(np.int32)
27+
28+
img_warped, lmk_warped = bbpoint_warp(image, center, size, target_size_height, target_size_width, landmarks=landmarks)
29+
30+
return img_warped
31+
32+
33+
def align_video(video, centers, sizes, landmarks, target_size_height, target_size_width=None, ):
34+
"""
35+
Returns a video with the face aligned to the center of the image.
36+
:param video: The full resolution video in which to align the face.
37+
:param landmarks: The landmarks of the face in the video (in the original video coordinates).
38+
:param target_size_height: The height of the output video.
39+
:param target_size_width: The width of the output video. If not provided, it is assumed to be the same as target_size_height.
40+
:return: The aligned face video. The video will be in range [0,1].
41+
"""
42+
if isinstance(video, (str, Path)):
43+
video = skvideo.io.vread(video)
44+
elif isinstance(video, (np.ndarray, types.GeneratorType)):
45+
pass
46+
else:
47+
raise ValueError("video must be a string, Path, or numpy array")
48+
49+
aligned_video = []
50+
warped_landmarks = []
51+
if isinstance(video, np.ndarray):
52+
for i in range(len(centers)):
53+
img_warped, lmk_warped = bbpoint_warp(video[i], centers[i], sizes[i],
54+
target_size_height=target_size_height, target_size_width=target_size_width,
55+
landmarks=landmarks[i])
56+
aligned_video.append(img_warped)
57+
warped_landmarks += [lmk_warped]
58+
59+
elif isinstance(video, types.GeneratorType):
60+
for i, frame in enumerate(video):
61+
img_warped, lmk_warped = bbpoint_warp(frame, centers[i], sizes[i],
62+
target_size_height=target_size_height, target_size_width=target_size_width,
63+
landmarks=landmarks[i])
64+
aligned_video.append(img_warped)
65+
warped_landmarks += [lmk_warped]
66+
67+
aligned_video = np.stack(aligned_video, axis=0)
68+
return aligned_video, warped_landmarks
69+
70+
71+
def align_and_save_video(video, out_video_path, centers, sizes, landmarks, target_size_height, target_size_width=None, output_dict=None):
72+
"""
73+
Returns a video with the face aligned to the center of the image.
74+
:param video: The full resolution video in which to align the face.
75+
:param landmarks: The landmarks of the face in the video (in the original video coordinates).
76+
:param target_size_height: The height of the output video.
77+
:param target_size_width: The width of the output video. If not provided, it is assumed to be the same as target_size_height.
78+
:return: The aligned face video. The video will be in range [0,1].
79+
"""
80+
if isinstance(video, (str, Path)):
81+
video = skvideo.io.vread(video)
82+
elif isinstance(video, (np.ndarray, types.GeneratorType)):
83+
pass
84+
else:
85+
raise ValueError("video must be a string, Path, or numpy array")
86+
87+
writer = skvideo.io.FFmpegWriter(str(out_video_path), outputdict=output_dict)
88+
warped_landmarks = []
89+
if isinstance(video, np.ndarray):
90+
for i in range(len(centers)):
91+
img_warped, lmk_warped = bbpoint_warp(video[i], centers[i], sizes[i],
92+
target_size_height=target_size_height, target_size_width=target_size_width,
93+
landmarks=landmarks[i])
94+
img_warped = (img_warped * 255).astype(np.uint8)
95+
writer.writeFrame(img_warped)
96+
warped_landmarks += [lmk_warped]
97+
98+
elif isinstance(video, types.GeneratorType):
99+
for i, frame in enumerate(video):
100+
img_warped, lmk_warped = bbpoint_warp(frame, centers[i], sizes[i],
101+
target_size_height=target_size_height, target_size_width=target_size_width,
102+
landmarks=landmarks[i])
103+
img_warped = (img_warped * 255).astype(np.uint8)
104+
writer.writeFrame(img_warped)
105+
warped_landmarks += [lmk_warped]
106+
writer.close()
107+
108+
return warped_landmarks

gdl/datasets/ImageDatasetHelpers.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,25 @@ def bbox2point(left, right, top, bottom, type='bbox'):
2727
'''
2828
if type == 'kpt68':
2929
old_size = (right - left + bottom - top) / 2 * 1.1
30-
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
30+
center_x = right - (right - left) / 2.0
31+
center_y = bottom - (bottom - top) / 2.0
32+
# center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
3133
elif type == 'bbox':
3234
old_size = (right - left + bottom - top) / 2
33-
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size * 0.12])
35+
center_x = right - (right - left) / 2.0
36+
center_y = bottom - (bottom - top) / 2.0 + old_size * 0.12
37+
# center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size * 0.12])
38+
elif type == "mediapipe":
39+
old_size = (right - left + bottom - top) / 2 * 1.1
40+
center_x = right - (right - left) / 2.0
41+
center_y = bottom - (bottom - top) / 2.0
42+
# center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
3443
else:
35-
raise NotImplementedError
44+
raise NotImplementedError(f" bbox2point not implemented for {type} ")
45+
if isinstance(center_x, np.ndarray):
46+
center = np.stack([center_x, center_y], axis=1)
47+
else:
48+
center = np.array([center_x, center_y])
3649
return old_size, center
3750

3851

@@ -53,15 +66,31 @@ def point2transform(center, size, target_size_height, target_size_width):
5366
return tform
5467

5568

56-
def bbpoint_warp(image, center, size, target_size_height, target_size_width=None, output_shape=None, inv=True, landmarks=None):
69+
def bbpoint_warp(image, center, size, target_size_height, target_size_width=None, output_shape=None, inv=True, landmarks=None,
70+
order=3 # order of interpolation, bicubic by default
71+
):
5772
target_size_width = target_size_width or target_size_height
5873
tform = point2transform(center, size, target_size_height, target_size_width)
5974
tf = tform.inverse if inv else tform
6075
output_shape = output_shape or (target_size_height, target_size_width)
61-
dst_image = warp(image, tf, output_shape=output_shape, order=3)
76+
dst_image = warp(image, tf, output_shape=output_shape, order=order)
6277
if landmarks is None:
6378
return dst_image
6479
# points need the matrix
65-
tf_lmk = tform if inv else tform.inverse
66-
dst_landmarks = tf_lmk(landmarks)
80+
if isinstance(landmarks, np.ndarray):
81+
assert isinstance(landmarks, np.ndarray)
82+
tf_lmk = tform if inv else tform.inverse
83+
dst_landmarks = tf_lmk(landmarks[:, :2])
84+
elif isinstance(landmarks, list):
85+
tf_lmk = tform if inv else tform.inverse
86+
dst_landmarks = []
87+
for i in range(len(landmarks)):
88+
dst_landmarks += [tf_lmk(landmarks[i][:, :2])]
89+
elif isinstance(landmarks, dict):
90+
tf_lmk = tform if inv else tform.inverse
91+
dst_landmarks = {}
92+
for key, value in landmarks.items():
93+
dst_landmarks[key] = tf_lmk(landmarks[key][:, :2])
94+
else:
95+
raise ValueError("landmarks must be np.ndarray, list or dict")
6796
return dst_image, dst_landmarks
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
Author: Radek Danecek
3+
Copyright (c) 2022, Radek Danecek
4+
All rights reserved.
5+
6+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
7+
# holder of all proprietary rights on this computer program.
8+
# Using this computer program means that you agree to the terms
9+
# in the LICENSE file included with this software distribution.
10+
# Any use not explicitly granted by the LICENSE is prohibited.
11+
#
12+
# Copyright©2022 Max-Planck-Gesellschaft zur Förderung
13+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14+
# for Intelligent Systems. All rights reserved.
15+
#
16+
# For comments or questions, please email us at emoca@tue.mpg.de
17+
# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
18+
"""
19+
20+
21+
import numpy as np
22+
import torch
23+
from PIL import Image
24+
from skimage.io import imread
25+
from torchvision.transforms import ToTensor
26+
27+
from gdl.utils.FaceDetector import load_landmark
28+
from gdl.datasets.FaceAlignmentTools import align_face
29+
30+
from skvideo.io import vread, vreader
31+
from types import GeneratorType
32+
import pickle as pkl
33+
34+
class VideoFaceDetectionDataset(torch.utils.data.Dataset):
35+
36+
def __init__(self, video_name, landmark_path, image_transforms=None,
37+
align_landmarks=False, vid_read=None, output_im_range=None,
38+
scale_adjustment=1.25,
39+
target_size_height=256,
40+
target_size_width=256,
41+
):
42+
super().__init__()
43+
self.video_name = video_name
44+
self.landmark_path = landmark_path / "landmarks_original.pkl"
45+
# if landmark_list is not None and len(lanmark_file_name) != len(image_list):
46+
# raise RuntimeError("There must be a landmark for every image")
47+
self.image_transforms = image_transforms
48+
self.vid_read = vid_read or 'skvreader' # 'skvread'
49+
self.prev_index = -1
50+
51+
self.scale_adjustment=scale_adjustment
52+
self.target_size_height=target_size_height
53+
self.target_size_width=target_size_width
54+
55+
self.video_frames = None
56+
if self.vid_read == "skvread":
57+
self.video_frames = vread(str(self.video_name))
58+
elif self.vid_read == "skvreader":
59+
self.video_frames = vreader(str(self.video_name))
60+
61+
with open(self.landmark_path, "rb") as f:
62+
self.landmark_list = pkl.load(f)
63+
64+
with open(landmark_path / "landmark_types.pkl", "rb") as f:
65+
self.landmark_types = pkl.load(f)
66+
67+
self.total_len = 0
68+
self.frame_map = {} # detection index to frame map
69+
self.index_for_frame_map = {} # detection index to frame map
70+
for i in range(len(self.landmark_list)):
71+
for j in range(len(self.landmark_list[i])):
72+
self.frame_map[self.total_len + j] = i
73+
self.index_for_frame_map[self.total_len + j] = j
74+
self.total_len += len(self.landmark_list[i])
75+
76+
self.output_im_range = output_im_range
77+
78+
79+
def __getitem__(self, index):
80+
# if index < len(self.image_list):
81+
# x = self.mnist_data[index]
82+
# raise IndexError("Out of bounds")
83+
if index != self.prev_index+1 and self.vid_read != 'skvread':
84+
raise RuntimeError("This dataset is meant to be accessed in ordered way only (and with 0 or 1 workers)")
85+
86+
frame_index = self.frame_map[index]
87+
detection_in_frame_index = self.index_for_frame_map[index]
88+
landmark = self.landmark_list[frame_index][detection_in_frame_index]
89+
landmark_type = self.landmark_types[frame_index][detection_in_frame_index]
90+
91+
if isinstance(self.video_frames, np.ndarray):
92+
img = self.video_frames[frame_index, ...]
93+
elif isinstance(self.video_frames, GeneratorType):
94+
img = next(self.video_frames)
95+
else:
96+
raise NotImplementedError()
97+
98+
# try:
99+
# if self.vid_read == 'skvread':
100+
# img = vread(self.image_list[index])
101+
# img = img.transpose([2, 0, 1]).astype(np.float32)
102+
# img_torch = torch.from_numpy(img)
103+
# path = str(self.image_list[index])
104+
# elif self.vid_read == 'pil':
105+
# img = Image.open(self.image_list[index])
106+
# img_torch = ToTensor()(img)
107+
# path = str(self.image_list[index])
108+
# # path = f"{index:05d}"
109+
# else:
110+
# raise ValueError(f"Invalid image reading method {self.im_read}")
111+
# except Exception as e:
112+
# print(f"Failed to read '{self.image_list[index]}'. File is probably corrupted. Rerun data processing")
113+
# raise e
114+
115+
# crop out the face
116+
img = align_face(img, landmark, landmark_type, scale_adjustment=1.25, target_size_height=256, target_size_width=256,)
117+
if self.output_im_range == 255:
118+
img = img * 255.0
119+
img = img.astype(np.float32)
120+
img_torch = ToTensor()(img)
121+
122+
# # plot img with pyplot
123+
# import matplotlib.pyplot as plt
124+
# plt.figure()
125+
# plt.imshow(img)
126+
# plt.show()
127+
# # plot image with plotly
128+
# import plotly.graph_objects as go
129+
# fig = go.Figure(data=go.Image(z=img*255.,))
130+
# fig.show()
131+
132+
133+
if self.image_transforms is not None:
134+
img_torch = self.image_transforms(img_torch)
135+
136+
batch = {"image" : img_torch,
137+
# "path" : path
138+
}
139+
140+
self.prev_index += 1
141+
return batch
142+
143+
def __len__(self):
144+
return self.total_len

0 commit comments

Comments
 (0)