From 14988ef8a259f536ab428a8646326acca974e111 Mon Sep 17 00:00:00 2001 From: Wok Date: Wed, 8 May 2019 19:27:39 +0200 Subject: [PATCH 1/2] Create a .tar archive before copying a checkpoint folder to Google Drive --- gpt_2_simple/gpt_2.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py index f3c2dba..33e64ee 100644 --- a/gpt_2_simple/gpt_2.py +++ b/gpt_2_simple/gpt_2.py @@ -1,3 +1,4 @@ +import tarfile import os import json import requests @@ -448,18 +449,36 @@ def is_mounted(): assert os.path.isdir('/content/drive'), "You must mount first using mount_gdrive()" +def get_tarfile_name(checkpoint_folder): + """Converts a folder path into a filename for a .tar archive""" + tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar' + + return tarfile_name + + def copy_checkpoint_to_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1')): """Copies the checkpoint folder to a mounted Google Drive.""" is_mounted() - shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder) + file_path = get_tarfile_name(checkpoint_folder) + + # Reference: https://stackoverflow.com/a/17081026 + with tarfile.open(file_path, 'w') as tar: + tar.add(checkpoint_folder, arcname=os.path.basename(checkpoint_folder)) + + shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path) def copy_checkpoint_from_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1')): """Copies the checkpoint folder from a mounted Google Drive.""" is_mounted() - shutil.copytree("/content/drive/My Drive/" + checkpoint_folder, checkpoint_folder) + file_path = get_tarfile_name(checkpoint_folder) + + shutil.copyfile("/content/drive/My Drive/" + file_path, file_path) + + with tarfile.open(file_path, 'r') as tar: + tar.extractall() def copy_file_to_gdrive(file_path): From 84c961e8d53b8e8a1968bd451bf89deb44bfdd0d Mon Sep 17 00:00:00 2001 From: Wok Date: Wed, 8 May 2019 19:49:09 +0200 Subject: [PATCH 2/2] Keep the tree structure --- gpt_2_simple/gpt_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py index 33e64ee..456c6d4 100644 --- a/gpt_2_simple/gpt_2.py +++ b/gpt_2_simple/gpt_2.py @@ -464,7 +464,7 @@ def copy_checkpoint_to_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1 # Reference: https://stackoverflow.com/a/17081026 with tarfile.open(file_path, 'w') as tar: - tar.add(checkpoint_folder, arcname=os.path.basename(checkpoint_folder)) + tar.add(checkpoint_folder) shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)