Skip to content

Commit

Permalink
Merge pull request jiupinjia#17 from amrzv/edit-colab-notebook
Browse files Browse the repository at this point in the history
Edit colab notebook
  • Loading branch information
jiupinjia authored Jan 18, 2021
2 parents d6e66a9 + 9018705 commit 408ea96
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 119 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SkyAR

[Preprint](<https://arxiv.org/abs/2010.11800>) | [Project Page](<https://jiupinjia.github.io/skyar/>) | [Google Colab](<https://colab.research.google.com/drive/1-BqXD3EzDY6PHRdwb3cWayk2KictbFaz?usp=sharing>)
[Preprint](<https://arxiv.org/abs/2010.11800>) | [Project Page](<https://jiupinjia.github.io/skyar/>) | [Google Colab](https://colab.research.google.com/github/jiupinjia/SkyAR/blob/main/colab_demo.ipynb)

### Official Pytorch implementation of the preprint paper "Castle in the Sky: Dynamic Sky Replacement and Harmonization in Videos", in arXiv:2010.11800.

Expand Down Expand Up @@ -103,7 +103,7 @@ If you want to try on your own data, or want a different blending style, you can

## Google Colab

Here we also provide a minimal working example of the inference runtime of our method. Check out [this link](https://colab.research.google.com/drive/1-BqXD3EzDY6PHRdwb3cWayk2KictbFaz?usp=sharing) and see your result on Colab.
Here we also provide a minimal working example of the inference runtime of our method. Check out [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jiupinjia/SkyAR/blob/main/colab_demo.ipynb) and see your result on Colab.



Expand Down
152 changes: 35 additions & 117 deletions colab_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"outputId": "4dd5f90b-53bb-4084-b18c-0ca2f5c86466",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 117.0
"height": 117
}
},
"outputs": [
Expand All @@ -47,7 +47,7 @@
],
"source": [
"# Clone the repository\n",
"!git clone https://github.com/jiupinjia/SkyAR.git "
"!git clone https://github.com/jiupinjia/SkyAR.git"
]
},
{
Expand All @@ -58,7 +58,7 @@
"outputId": "fd259ee6-29aa-4fe6-d51c-b70d57007195",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33.0
"height": 33
}
},
"outputs": [
Expand Down Expand Up @@ -86,7 +86,6 @@
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import os\n",
"import glob\n",
"import argparse\n",
"from networks import *\n",
"from skyboxengine import *\n",
Expand All @@ -108,48 +107,6 @@
"Download pretrained sky matting model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I669W3x7eLPJ"
},
"outputs": [],
"source": [
"# Define some helper functions for downloading pretrained model\n",
"# taken from this StackOverflow answer: https://stackoverflow.com/a/39225039\n",
"import requests\n",
"\n",
"def download_file_from_google_drive(id, destination):\n",
" URL = \"https://docs.google.com/uc?export=download\"\n",
"\n",
" session = requests.Session()\n",
"\n",
" response = session.get(URL, params = { 'id' : id }, stream = True)\n",
" token = get_confirm_token(response)\n",
"\n",
" if token:\n",
" params = { 'id' : id, 'confirm' : token }\n",
" response = session.get(URL, params = params, stream = True)\n",
"\n",
" save_response_content(response, destination) \n",
"\n",
"def get_confirm_token(response):\n",
" for key, value in response.cookies.items():\n",
" if key.startswith('download_warning'):\n",
" return value\n",
"\n",
" return None\n",
"\n",
"def save_response_content(response, destination):\n",
" CHUNK_SIZE = 32768\n",
"\n",
" with open(destination, \"wb\") as f:\n",
" for chunk in response.iter_content(CHUNK_SIZE):\n",
" if chunk: # filter out keep-alive new chunks\n",
" f.write(chunk)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -160,34 +117,10 @@
"source": [
"# download and unzip...\n",
"file_id = '1COMROzwR4R_7mym6DL9LXhHQlJmJaV0J'\n",
"destination = './checkpoints_G_coord_resnet50.zip'\n",
"download_file_from_google_drive(file_id, destination)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Eu13Edbtv3fl",
"outputId": "0130edcc-6a2e-4696-b79a-3eb67408f8c7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 67.0
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Archive: checkpoints_G_coord_resnet50.zip\n",
" creating: checkpoints_G_coord_resnet50/\n",
" inflating: checkpoints_G_coord_resnet50/best_ckpt.pt \n"
]
}
],
"source": [
"!unzip checkpoints_G_coord_resnet50.zip"
"file_name = 'checkpoints_G_coord_resnet50.zip'\n",
"\n",
"!gdown --id {file_id}\n",
"!unzip {file_name}"
]
},
{
Expand Down Expand Up @@ -233,7 +166,7 @@
"# args.out_size_h = 480 # ...\n",
"\n",
"# args.skybox_center_crop = 0.5 # view of the virtual camera\n",
"# args.auto_light_matching = False \n",
"# args.auto_light_matching = False\n",
"# args.relighting_factor = 0.8\n",
"# args.recoloring_factor = 0.5\n",
"# args.halo_effect = True"
Expand Down Expand Up @@ -272,10 +205,14 @@
" self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)\n",
" self.load_model()\n",
"\n",
" self.video_writer = cv2.VideoWriter('demo.avi', cv2.VideoWriter_fourcc(*'MJPG'),\n",
" 20.0, (args.out_size_w, args.out_size_h))\n",
" self.video_writer_cat = cv2.VideoWriter('demo-cat.avi', cv2.VideoWriter_fourcc(*'MJPG'),\n",
" 20.0, (2*args.out_size_w, args.out_size_h))\n",
" self.video_writer = cv2.VideoWriter('demo.avi',\n",
" cv2.VideoWriter_fourcc(*'MJPG'),\n",
" 20.0,\n",
" (args.out_size_w, args.out_size_h))\n",
" self.video_writer_cat = cv2.VideoWriter('demo-cat.avi',\n",
" cv2.VideoWriter_fourcc(*'MJPG'),\n",
" 20.0,\n",
" (2*args.out_size_w, args.out_size_h))\n",
"\n",
" if os.path.exists(args.output_dir) is False:\n",
" os.mkdir(args.output_dir)\n",
Expand All @@ -288,7 +225,8 @@
" def load_model(self):\n",
" # load pretrained sky matting model\n",
" print('loading the best checkpoint...')\n",
" checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'))\n",
" checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),\n",
" map_location=device)\n",
" self.net_G.load_state_dict(checkpoint['model_G_state_dict'])\n",
" self.net_G.to(device)\n",
" self.net_G.eval()\n",
Expand All @@ -305,7 +243,7 @@
"\n",
" # define a result buffer\n",
" self.output_img_list.append(frame_cat)\n",
" \n",
"\n",
"\n",
" def synthesize(self, img_HD, img_HD_prev):\n",
"\n",
Expand All @@ -318,7 +256,10 @@
"\n",
" with torch.no_grad():\n",
" G_pred = self.net_G(img.to(device))\n",
" G_pred = torch.nn.functional.interpolate(G_pred, (h, w), mode='bicubic', align_corners=False)\n",
" G_pred = torch.nn.functional.interpolate(G_pred,\n",
" (h, w),\n",
" mode='bicubic',\n",
" align_corners=False)\n",
" G_pred = G_pred[0, :].permute([1, 2, 0])\n",
" G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)\n",
" G_pred = np.array(G_pred.detach().cpu())\n",
Expand All @@ -331,7 +272,6 @@
" return syneth, G_pred, skymask\n",
"\n",
"\n",
"\n",
" def cvtcolor_and_resize(self, img_HD):\n",
"\n",
" img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)\n",
Expand All @@ -348,7 +288,7 @@
" cap = cv2.VideoCapture(self.datadir)\n",
" m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
" img_HD_prev = None\n",
" \n",
"\n",
" for idx in range(m_frames):\n",
" ret, frame = cap.read()\n",
" if ret:\n",
Expand All @@ -363,11 +303,11 @@
"\n",
" img_HD_prev = img_HD\n",
"\n",
" if idx % 50 == 1:\n",
" print('processing video, frame %d / %d ... ' % (idx, m_frames))\n",
" if (idx + 1) % 50 == 0:\n",
" print(f'processing video, frame {idx + 1} / {m_frames} ... ')\n",
"\n",
" else: # if reach the last frame\n",
" break\n"
" break"
]
},
{
Expand All @@ -387,7 +327,7 @@
"outputId": "8d7af558-d532-49d9-bd24-fdfedb5d257d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 150.0
"height": 150
}
},
"outputs": [
Expand Down Expand Up @@ -428,45 +368,23 @@
},
"outputs": [],
"source": [
"# Check out your results at './SkyAR/demo.avi' and './SkyAR/demo-cat.avi'. \n",
"# Check out your results at './SkyAR/demo.avi' and './SkyAR/demo-cat.avi'.\n",
"# Download them and enjoy.\n",
"\n",
"# If you would like to pre-view your results. Run the following to see the animated \n",
"# results of the first 40 frames.\n",
"# If you would like to pre-view your results. Run the following to see the\n",
"# animated results of the first 40 frames.\n",
"\n",
"import matplotlib.animation as animation\n",
"from IPython.display import HTML\n",
"\n",
"fig = plt.figure(figsize=(4,8))\n",
"fig = plt.figure(figsize=(8, 4))\n",
"plt.axis('off')\n",
"fig.axes.get_yaxis().set_visible(False)\n",
"ims = [[plt.imshow(img[:,:,::-1], animated=True)] for img in sf.output_img_list[0:40]]\n",
"ims = [[plt.imshow(img[:, :, ::-1], animated=True)]\n",
" for img in sf.output_img_list[0:40]]\n",
"ani = animation.ArtistAnimation(fig, ims, interval=50)\n",
"\n",
"HTML(ani.to_jshtml())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gpg5oPhdySnX"
},
"outputs": [],
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w0axvUtii8ua"
},
"outputs": [],
"source": [
""
]
}
],
"metadata": {
Expand All @@ -483,4 +401,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

0 comments on commit 408ea96

Please sign in to comment.