Skip to content

Commit

Permalink
Update generate_full_weights_1024x1024.py
Browse files Browse the repository at this point in the history
  • Loading branch information
w1oves authored Mar 19, 2024
1 parent d4535c6 commit a90ec5f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tools/generate_full_weights_1024x1024.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def main(args):
backbone = args.backbone
rein_head = args.rein_head

if not osp.isfile(dinov2_segmentor_path):
weight = torch.load(rein_head, map_location='cpu')
weight['state_dict'].update({f'backbone.{k}': v for k, v in load_backbone(backbone).items()})
torch.save(weight, dinov2_segmentor_path)
weight = torch.load(rein_head, map_location='cpu')
if 'state_dict' not in weight:
weight=dict(state_dict=weight)
weight['state_dict'].update({f'backbone.{k}': v for k, v in load_backbone(backbone).items()})
torch.save(weight, dinov2_segmentor_path)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Load and process model weights.")
Expand Down

0 comments on commit a90ec5f

Please sign in to comment.