Skip to content

Commit

Permalink
Fix for .bin files with shared weights
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Jan 22, 2024
1 parent ec75362 commit 9373d0c
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion util/convert_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
from safetensors.torch import save_file

parser = argparse.ArgumentParser(description="Convert .bin/.pt files to .safetensors")
parser.add_argument("--unshare", action = "store_true", help="Detach tensors to prevent any from sharing memory")
parser.add_argument("input_files", nargs='+', type=str, help="Input file(s)")
args = parser.parse_args()

for file in args.input_files:
print(f" -- Loading {file}...")
state_dict = torch.load(file, map_location="cpu")
state_dict = torch.load(file, map_location = "cpu")

if args.unshare:
for k in state_dict.keys():
state_dict[k] = state_dict[k].clone().detach()

out_file = os.path.splitext(file)[0] + ".safetensors"
print(f" -- Saving {out_file}...")
Expand Down

0 comments on commit 9373d0c

Please sign in to comment.