forked from pytorch/pytorch.github.io
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_install.py
83 lines (67 loc) · 2.55 KB
/
test_install.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
def read_published_versions(path=None):
    """Load and return the parsed contents of ``published_versions.json``.

    :param path: Optional explicit path to the JSON file. When omitted, the
        file is looked up in the parent of this script's directory.
    :returns: The decoded JSON document.
    """
    import json
    import os
    if path is None:
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        path = os.path.join(base_dir, "published_versions.json")
    # Decode explicitly as UTF-8: the default text encoding is locale-dependent
    # (e.g. cp1252 on Windows, a platform this script supports).
    with open(path, encoding="utf-8") as fp:
        return json.load(fp)
def get_os() -> str:
    """Translate ``sys.platform`` into the OS key used by ``main``.

    Returns "macos", "linux" or "windows".

    :raises RuntimeError: If the platform matches none of the known prefixes.
    """
    import sys
    platform_table = (
        (("darwin",), "macos"),
        (("linux",), "linux"),
        (("win32", "cygwin"), "windows"),
    )
    for prefixes, os_key in platform_table:
        # str.startswith accepts a tuple and matches if any prefix matches.
        if sys.platform.startswith(prefixes):
            return os_key
    raise RuntimeError(f"Unknown platform {sys.platform}")
def get_acc() -> str:
    """Return the ``TEST_ACC`` environment value, or ``"accnone"`` if unset.

    ``main`` uses the result as the key selecting the install variant.
    """
    from os import environ
    return environ.get("TEST_ACC", "accnone")
def get_ver() -> str:
    """Return the ``TEST_VER`` environment value, or ``"latest_stable"`` if unset.

    ``main`` resolves the aliases "latest_stable" / "latest_lts" to a concrete
    version through published_versions.json.
    """
    from os import environ
    return environ.get("TEST_VER", "latest_stable")
def get_pkg_type() -> str:
    """Return "conda" if the first command-line argument is ``--conda``, else "pip"."""
    import sys
    first_arg = sys.argv[1] if len(sys.argv) > 1 else None
    return "conda" if first_arg == "--conda" else "pip"
def _install_args(cmd, pkg_type):
    """Convert a published install command string into a subprocess argv list.

    :param cmd: The install command taken from published_versions.json.
    :param pkg_type: "pip" or "conda".
    :returns: The argument list to pass to ``subprocess.check_call``.
    :raises ValueError: If ``pkg_type`` is neither "pip" nor "conda".
    """
    import sys
    # split() with no separator collapses runs of whitespace, so a doubled or
    # trailing space in the published command cannot yield an empty argument.
    args = cmd.split()
    if pkg_type == "pip":
        # Run pip as a module of the current interpreter; the command's leading
        # word (the pip executable name) is replaced by "-mpip".
        return [sys.executable, "-mpip", *args[1:]]
    if pkg_type != "conda":
        # Explicit error rather than `assert`, which is stripped under -O.
        raise ValueError(f"Unsupported package type {pkg_type!r}")
    # Add "-y" after each "install" so conda does not prompt. Build a new list
    # instead of inserting into the list being iterated.
    conda_args = []
    for arg in args:
        conda_args.append(arg)
        if arg == "install":
            conda_args.append("-y")
    return conda_args


def _verify_install():
    """Import torch and torchvision in a fresh interpreter and check two ops.

    :raises subprocess.CalledProcessError: If any check exits non-zero.
    """
    import subprocess
    import sys
    snippets = (
        "import torch;print('PyTorch version is ', torch.__version__)",
        "import torchvision;print('torchvision version is ', torchvision.__version__)",
        "import torch;import torchvision;print('Is torchvision useable?', all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]))",
    )
    for snippet in snippets:
        subprocess.check_call([sys.executable, "-c", snippet])


def main() -> None:
    """Install the published PyTorch build for this environment and smoke-test it.

    The install command is looked up in published_versions.json by version
    (TEST_VER), OS, package type (``--conda`` flag) and TEST_ACC. If that entry
    has no command, its note is printed and the process exits with status 0.
    Otherwise the command is run and torch/torchvision imports are checked.
    """
    import subprocess
    import sys
    published_versions = read_published_versions()
    # Named os_name so the stdlib `os` module name is not shadowed.
    os_name = get_os()
    acc = get_acc()
    pkg_type = get_pkg_type()
    version = get_ver()
    # The aliases map to a concrete version string in the JSON document.
    if version in ["latest_lts", "latest_stable"]:
        version = published_versions[version]
    acc_vers = published_versions["versions"][version][os_name][pkg_type][acc]
    note, cmd = acc_vers["note"], acc_vers["command"]
    if cmd is None:
        print(note)
        sys.exit(0)
    # Check that PyTorch + Domains are installable
    print(f"Installing PyTorch {version} + {acc} using {pkg_type} and Python {sys.version}")
    subprocess.check_call(_install_args(cmd, pkg_type))
    # Check that torch is importable after install
    _verify_install()
# Run the install check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()