Added a video generation script and instructions to README
aquadzn authored and Mathilde Caron committed May 13, 2021
1 parent c66329a commit 2d21cae
Showing 2 changed files with 330 additions and 1 deletion.
28 changes: 27 additions & 1 deletion README.md
@@ -147,7 +147,33 @@ You can look at the self-attention of the [CLS] token on the different heads of
python visualize_attention.py
```

## Self-attention video generation
You can generate videos like the one in the blog post with `video_generation.py`.

Extract frames from an input video and generate the attention video:
```
python video_generation.py --input_path ../video.mp4 \
--output_dir ../output/ \
    --resize 256
```

Use a folder of frames already extracted and generate the attention video:
```
python video_generation.py --input_path ../frames/ \
--output_dir ../output/ \
--resize 720 1280 \
--video_format avi
```

Only generate a video from a folder of attention map images:
```
python video_generation.py --output_dir ../output/ \
--resize 256 \
--fps 60 \
--video_only
```

Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for a video inference notebook.
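
If you prefer to run a single frame from Python, the sketch below strings together the same calls `video_generation.py` makes; the frame path, architecture, and output name are placeholder assumptions:
```
import torch
from PIL import Image
from torchvision import transforms as pth_transforms
import matplotlib.pyplot as plt
import vision_transformer as vits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vits.__dict__["deit_small"](patch_size=8, num_classes=0).eval().to(device)
state_dict = torch.hub.load_state_dict_from_url(
    "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
)
model.load_state_dict(state_dict, strict=True)

transform = pth_transforms.Compose([
    pth_transforms.ToTensor(),
    pth_transforms.Resize(256),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = transform(Image.open("frame-0000.jpg").convert("RGB"))  # placeholder frame
# crop so both spatial dims are divisible by the patch size
w, h = img.shape[1] - img.shape[1] % 8, img.shape[2] - img.shape[2] % 8
img = img[:, :w, :h].unsqueeze(0)

with torch.no_grad():
    # (1, nh, npatch + 1, npatch + 1); query 0 is the [CLS] token
    attn = model.forward_selfattention(img.to(device))
nh = attn.shape[1]
attn = attn[0, :, 0, 1:].reshape(nh, w // 8, h // 8)
plt.imsave("attn.jpg", attn.mean(0).cpu().numpy(), cmap="inferno")
```
Averaging the heads reproduces the per-frame image the script saves.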

<div align="center">
<img width="100%" alt="Self-attention from a Vision Transformer with 8x8 patches trained with DINO" src=".github/attention_maps.png">
303 changes: 303 additions & 0 deletions video_generation.py
@@ -0,0 +1,303 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import glob
import sys
import argparse
import cv2

from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
import vision_transformer as vits


def extract_frames_from_video():
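    # Uses the module-level `args` set in __main__; frames are written to
    # args.output_dir as frame-0000.jpg, frame-0001.jpg, ...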
print("Extracting frames from", args.input_path)
vidcap = cv2.VideoCapture(args.input_path)
success, image = vidcap.read()
count = 0
while success:
cv2.imwrite(os.path.join(args.output_dir, f"frame-{count:04}.jpg"), image)
success, image = vidcap.read()
count += 1


def generate_video_from_images(video_format="mp4"):
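    # Stitches the attention images found in args.output_dir into a single
    # video file, written to the current working directory.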
print("Generating video...")
    img_array = []
    # Prefer the generated attention maps: when frames were extracted into
    # output_dir, it also contains the raw frame-*.jpg images, which should
    # not end up in the video. Change the extension to png here if needed.
    filenames = sorted(glob.glob(os.path.join(args.output_dir, "attn-*.jpg")))
    if not filenames:
        filenames = sorted(glob.glob(os.path.join(args.output_dir, "*.jpg")))
    for filename in tqdm(filenames):
        with open(filename, "rb") as f:
            img = Image.open(f)
            img = img.convert("RGB")
            size = (img.width, img.height)
            img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

    if not img_array:
        print("No images found in", args.output_dir)
        return

    if video_format == "avi":
        out = cv2.VideoWriter(
            "video.avi", cv2.VideoWriter_fourcc(*"XVID"), args.fps, size
        )
    else:
        out = cv2.VideoWriter(
            "video.mp4", cv2.VideoWriter_fourcc(*"MP4V"), args.fps, size
        )

    for frame in img_array:
        out.write(frame)
out.release()
print("Done")


def inference(images_folder_list: list):
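    # For each frame: preprocess, run the model, average the [CLS] token's
    # per-head self-attention over the patch tokens, save it as a heatmap,
    # then stitch all heatmaps into a video.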
for img_path in tqdm(images_folder_list):
with open(img_path, "rb") as f:
img = Image.open(f)
img = img.convert("RGB")

if args.resize is not None:
transform = pth_transforms.Compose(
[
pth_transforms.ToTensor(),
pth_transforms.Resize(args.resize),
pth_transforms.Normalize(
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
),
]
)
else:
transform = pth_transforms.Compose(
[
pth_transforms.ToTensor(),
pth_transforms.Normalize(
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
),
]
)

img = transform(img)

# make the image divisible by the patch size
w, h = (
img.shape[1] - img.shape[1] % args.patch_size,
img.shape[2] - img.shape[2] % args.patch_size,
)
img = img[:, :w, :h].unsqueeze(0)

w_featmap = img.shape[-2] // args.patch_size
h_featmap = img.shape[-1] // args.patch_size

attentions = model.forward_selfattention(img.to(device))
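        # expected shape: (1, nh, npatch + 1, npatch + 1); index 0 along the
        # last two dims is the [CLS] token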

        nh = attentions.shape[1]  # number of heads

        # keep only the [CLS] token's attention over the patch tokens, per head
        attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

# we keep only a certain percentage of the mass
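        # (sort each head's attention in ascending order, normalize it to sum
        # to 1, and take the cumulative sum: entries above (1 - threshold)
        # form the smallest set of patches holding `threshold` of the mass)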
val, idx = torch.sort(attentions)
val /= torch.sum(val, dim=1, keepdim=True)
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - args.threshold)
idx2 = torch.argsort(idx)
for head in range(nh):
th_attn[head] = th_attn[head][idx2[head]]
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
        # upsample the patch-level masks back to pixel resolution
        # (th_attn is computed here but not used further by this script)
th_attn = (
nn.functional.interpolate(
th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
)[0]
.cpu()
.numpy()
)

attentions = attentions.reshape(nh, w_featmap, h_featmap)
attentions = (
nn.functional.interpolate(
attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest"
)[0]
.cpu()
.numpy()
)

# save attentions heatmaps
os.makedirs(args.output_dir, exist_ok=True)
fname = os.path.join(args.output_dir, "attn-" + os.path.basename(img_path))
plt.imsave(
fname=fname,
arr=sum(
attentions[i] * 1 / attentions.shape[0]
for i in range(attentions.shape[0])
),
cmap="inferno",
format="jpg",
)

generate_video_from_images(args.video_format)


def load_model():
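    # Builds the ViT backbone, freezes its parameters, and loads weights: from
    # --pretrained_weights if given, otherwise the reference DINO checkpoint
    # matching --arch / --patch_size (random weights as a last resort).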
# build model
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
for p in model.parameters():
p.requires_grad = False
model.eval()
model.to(device)
if os.path.isfile(args.pretrained_weights):
state_dict = torch.load(args.pretrained_weights, map_location="cpu")
if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
state_dict = state_dict[args.checkpoint_key]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
msg = model.load_state_dict(state_dict, strict=False)
print(
"Pretrained weights found at {} and loaded with msg: {}".format(
args.pretrained_weights, msg
)
)
else:
print(
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
)
url = None
if args.arch == "deit_small" and args.patch_size == 16:
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
elif args.arch == "deit_small" and args.patch_size == 8:
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
elif args.arch == "vit_base" and args.patch_size == 16:
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
elif args.arch == "vit_base" and args.patch_size == 8:
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
if url is not None:
print(
"Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
)
state_dict = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/dino/" + url
)
model.load_state_dict(state_dict, strict=True)
else:
print(
"There is no reference weights available for this model => We use random weights."
)
return model


def parse_args():
parser = argparse.ArgumentParser("Visualize Self-Attention maps")
parser.add_argument(
"--arch",
default="deit_small",
type=str,
choices=["deit_tiny", "deit_small", "vit_base"],
help="Architecture (support only ViT atm).",
)
parser.add_argument(
"--patch_size", default=8, type=int, help="Patch resolution of the model."
)
parser.add_argument(
"--pretrained_weights",
default="",
type=str,
help="Path to pretrained weights to load.",
)
parser.add_argument(
"--checkpoint_key",
default="teacher",
type=str,
help='Key to use in the checkpoint (example: "teacher")',
)
parser.add_argument(
"--input_path",
default=None,
type=str,
help="""Path to a video file if you want to extract frames
or to a folder of images already extracted by yourself.""",
)
parser.add_argument(
"--output_dir",
required=True,
type=str,
help="Path where to save visualizations and / or video.",
)
parser.add_argument(
"--threshold",
type=float,
default=0.6,
help="""We visualize masks
obtained by thresholding the self-attention maps to keep xx percent of the mass.""",
)
parser.add_argument(
"--resize",
default=None,
type=int,
nargs="+",
help="""Apply a resize transformation to input image(s). Use if OOM error.
Usage (single or W H): --resize 512, --resize 720 1280""",
)
parser.add_argument(
"--fps",
default=30.0,
type=float,
help="FPS of input / output video. Default: 30",
)
parser.add_argument(
"--video_only",
action="store_true",
help="""Use this flag if you only want to generate a video and not all attention images.
If used, --output_dir must be set to the folder containing attention images.""",
)
parser.add_argument(
"--video_format",
default="mp4",
type=str,
choices=["mp4", "avi"],
help="Format of generated video (mp4 or avi).",
)

return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
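    # `args`, `device`, and `model` are used as module-level globals by the
    # functions defined above.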
model = load_model()

# If you only want a video
if args.video_only:
generate_video_from_images(args.video_format)
else:
# If input path isn't set
if args.input_path is None:
print(f"Provided input path {args.input_path} is non valid.")
sys.exit(1)
else:
# If input path exists
if os.path.exists(args.input_path):
# If input is a video file
if os.path.isfile(args.input_path):
extract_frames_from_video()
imgs_list = [
os.path.join(args.output_dir, i)
for i in sorted(os.listdir(args.output_dir))
]
inference(imgs_list)
                # If input is a folder of images
                elif os.path.isdir(args.input_path):
imgs_list = [
os.path.join(args.input_path, i)
for i in sorted(os.listdir(args.input_path))
]
inference(imgs_list)
            # If input path doesn't exist
            else:
                print(f"Provided input path {args.input_path} does not exist.")
sys.exit(1)
