Skip to content

Commit

Permalink
fix a model inference problem on cpu (WongKinYiu#502)
Browse files Browse the repository at this point in the history
* fix model inference problem on cpu

* Update keypoint.ipynb
  • Loading branch information
spacewalk01 authored Aug 16, 2022
1 parent 36ce6b2 commit 064c71e
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tools/keypoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"weigths = torch.load('yolov7-w6-pose.pt')\n",
"weigths = torch.load('yolov7-w6-pose.pt', map_location=device)\n",
"model = weigths['model']\n",
"model = model.half().to(device)\n",
"_ = model.eval()"
"_ = model.float().eval()\n",
"\n",
"if torch.cuda.is_available():\n",
" model.half().to(device)"
]
},
{
Expand All @@ -43,9 +45,9 @@
"image_ = image.copy()\n",
"image = transforms.ToTensor()(image)\n",
"image = torch.tensor(np.array([image.numpy()]))\n",
"image = image.to(device)\n",
"image = image.half()\n",
"\n",
"if torch.cuda.is_available():\n",
" image = image.half().to(device) \n",
"output, _ = model(image)"
]
},
Expand Down Expand Up @@ -118,7 +120,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 064c71e

Please sign in to comment.