Skip to content

Commit

Permalink
Merge pull request pytorch#864 from kleinicke/master
Browse files Browse the repository at this point in the history
fast neural style: run onnxmodel using onnxruntime
  • Loading branch information
msaroufim authored Mar 17, 2022
2 parents 8016876 + 2cde431 commit 15a638f
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions fast_neural_style/neural_style/neural_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def stylize(args):
content_image = content_image.unsqueeze(0).to(device)

if args.model.endswith(".onnx"):
output = stylize_onnx_caffe2(content_image, args)
output = stylize_onnx(content_image, args)
else:
with torch.no_grad():
style_model = TransformerNet()
Expand All @@ -142,31 +142,40 @@ def stylize(args):
del state_dict[k]
style_model.load_state_dict(state_dict)
style_model.to(device)
style_model.eval()
if args.export_onnx:
assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
output = torch.onnx._export(
style_model, content_image, args.export_onnx, opset_version=11,
).cpu()
else:
output = style_model(content_image).cpu()
utils.save_image(args.output_image, output[0])


def stylize_onnx_caffe2(content_image, args):
def stylize_onnx(content_image, args):
"""
Read ONNX model and run it using Caffe2
Read ONNX model and run it using onnxruntime
"""

assert not args.export_onnx

import onnx
import onnx_caffe2.backend
import onnxruntime

model = onnx.load(args.model)
ort_session = onnxruntime.InferenceSession(args.model)

prepared_backend = onnx_caffe2.backend.prepare(model, device='CUDA' if args.cuda else 'CPU')
inp = {model.graph.input[0].name: content_image.numpy()}
c2_out = prepared_backend.run(inp)[0]
def to_numpy(tensor):
return (
tensor.detach().cpu().numpy()
if tensor.requires_grad
else tensor.cpu().numpy()
)

return torch.from_numpy(c2_out)
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(content_image)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]

return torch.from_numpy(img_out_y)


def main():
Expand Down

0 comments on commit 15a638f

Please sign in to comment.