Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add torch compile row to pytorch install table #1762

Open
wants to merge 3 commits into
base: site
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions _includes/quick-start-module.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ var opts = {
pm: 'pip',
language: 'python',
ptbuild: 'stable',
'torch-compile': null
};

var supportedCloudPlatforms = [
Expand All @@ -34,6 +35,7 @@ var package = $(".package > .option");
var language = $(".language > .option");
var cuda = $(".cuda > .option");
var ptbuild = $(".ptbuild > .option");
var torchCompile = $(".torch-compile > .option")

os.on("click", function() {
selectedOption(os, this, "os");
Expand All @@ -50,6 +52,9 @@ cuda.on("click", function() {
ptbuild.on("click", function() {
selectedOption(ptbuild, this, "ptbuild")
});
torchCompile.on("click", function() {
selectedOption(torchCompile, this, "torch-compile")
});

// Pre-select user's operating system
$(function() {
Expand Down Expand Up @@ -168,9 +173,123 @@ function changeAccNoneName(osname) {
}
}

function getIDFromBackend(backend) {
const idTobackendMap = {
inductor: 'inductor',
cgraphs : 'cudagraphs',
onnxrt: 'onnxrt',
openvino: 'openvino',
tensorrt: 'tensorrt',
tvm: 'tvm',
};
return idTobackendMap[backend];
}

function getPmCmd(backend) {
const pmCmd = {
onnxrt: 'onnxruntime',
tvm: 'apache-tvm',
openvino: 'openvino',
tensorrt: 'torch-tensorrt',
};
return pmCmd[backend];
}

function getImportCmd(backend) {
const importCmd = {
onnxrt: 'import onnxruntime',
tvm: 'import tvm',
openvino: 'import openvino.torch',
tensorrt: 'import torch_tensorrt'
}
return importCmd[backend];
}

function getInstallCommand(optionID) {
backend = getIDFromBackend(optionID);
pmCmd = getPmCmd(optionID);
finalCmd = "";
if (opts.pm == "pip") {
finalCmd = `pip3 install ${pmCmd}`;
}
else if (opts.pm == "conda") {
finalCmd = `conda install ${pmCmd}`;
}
return finalCmd;
}

function getTorchCompileUsage(optionId) {
backend = getIDFromBackend(optionId);
importCmd = getImportCmd(optionId) + "<br>";
finalCmd = "";
tcUsage = "# Torch Compile usage: " + "<br>";
backendCmd = `torch.compile(model, backend="${backend}")`;
libtorchCmd = `# Torch compile ${backend} not supported with Libtorch`;

if (opts.pm == "libtorch") {
return libtorchCmd;
}
if (backend == "inductor" || backend == "cudagraphs") {
return tcUsage + backendCmd;
}
if (backend == "openvino") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build openvino from source: https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build.md" + "<br>" ;
tcUsage += importCmd;
}
else if (opts.pm == "conda") {
tcUsage += importCmd;
}
if (opts.os == "windows" && !tcUsage.includes(importCmd)) {
tcUsage += importCmd;
}
}
else{
tcUsage += importCmd;
}
if (backend == "onnxrt") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build onnxruntime from source: https://onnxruntime.ai/docs/build" + "<br>" ;
}
}
if (backend == "tvm") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build tvm from source: https://tvm.apache.org/docs/install/from_source.html" + "<br>" ;
}
}
if (backend == "tensorrt") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build tensorrt from source: https://pytorch.org/TensorRT/getting_started/installation.html#compiling-from-source" + "<br>" ;
}
}
finalCmd += tcUsage + backendCmd;
return finalCmd
}

function addTorchCompileCommandNote(selectedOptionId) {

if (!selectedOptionId) {
return;
}
if (selectedOptionId == "inductor" || selectedOptionId == "cgraphs") {
$("#command").append(
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
);
}
else {
$("#command").append(
`<pre> ${getInstallCommand(selectedOptionId)} </pre>`
);
$("#command").append(
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
);
}
}

function selectedOption(option, selection, category) {
$(option).removeClass("selected");
$(selection).addClass("selected");
const previousSelection = opts[category];
opts[category] = selection.id;
if (category === "pm") {
var elements = document.getElementsByClassName("language")[0].children;
Expand Down Expand Up @@ -208,13 +327,19 @@ function selectedOption(option, selection, category) {
changeVersion(opts.ptbuild);
//make sure unsupported platforms are disabled
disableUnsupportedPlatforms(opts.os);
} else if (category === "torch-compile") {
if (selection.id === previousSelection) {
$(selection).removeClass("selected");
opts[category] = null;
}
}
commandMessage(buildMatcher());
if (category === "os") {
disableUnsupportedPlatforms(opts.os);
display(opts.os, 'installation', 'os');
}
changeAccNoneName(opts.os);
addTorchCompileCommandNote(opts['torch-compile'])
}

function display(selection, id, category) {
Expand Down
28 changes: 28 additions & 0 deletions _includes/quick_start_local.html
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
<div class="col-md-12 title-block">
<div class="option-text">Compute Platform</div>
</div>
<div class="col-md-12 title-block">
<div class="option-text">Torch Compile</div>
</div>
<div class="col-md-12 title-block command-block">
<div class="option-text command-text">Run this Command:</div>
</div>
Expand Down Expand Up @@ -103,6 +106,31 @@
<div class="option-text">CPU</div>
</div>
</div>
<div class="row torch-compile">
<!-- Section Label -->
<div class="col-md-12 title-block mobile-heading">
<div class="option-text">Torch Compile</div>
</div>
<!-- Section Label -->
<div class="col-md-2 option block version" id="inductor">
<div class="option-text">Inductor</div>
</div>
<div class="col-md-2 option block version" id="cgraphs">
<div class="option-text">CUDA Graphs</div>
</div>
<div class="col-md-2 option block version" id="openvino">
<div class="option-text">OpenVINO</div>
</div>
<div class="col-md-2 option block version" id="onnxrt">
<div class="option-text">ONNX Runtime</div>
</div>
<div class="col-md-2 option block version" id="tensorrt">
<div class="option-text">TensorRT</div>
</div>
<div class="col-md-2 option block version" id="tvm">
<div class="option-text">TVM</div>
</div>
</div>
<div class="row">
<div class="col-md-12 title-block command-mobile-heading">
<div class="option-text">Run this Command:</div>
Expand Down
125 changes: 125 additions & 0 deletions assets/quick-start-module.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ var opts = {
pm: 'pip',
language: 'python',
ptbuild: 'stable',
'torch-compile': null
};

var supportedCloudPlatforms = [
Expand All @@ -34,6 +35,7 @@ var package = $(".package > .option");
var language = $(".language > .option");
var cuda = $(".cuda > .option");
var ptbuild = $(".ptbuild > .option");
var torchCompile = $(".torch-compile > .option")

os.on("click", function() {
selectedOption(os, this, "os");
Expand All @@ -50,6 +52,9 @@ cuda.on("click", function() {
ptbuild.on("click", function() {
selectedOption(ptbuild, this, "ptbuild")
});
torchCompile.on("click", function() {
selectedOption(torchCompile, this, "torch-compile")
});

// Pre-select user's operating system
$(function() {
Expand All @@ -63,6 +68,119 @@ $(function() {
}
});

function getIDFromBackend(backend) {
const idTobackendMap = {
inductor: 'inductor',
cgraphs : 'cudagraphs',
onnxrt: 'onnxrt',
openvino: 'openvino',
tensorrt: 'tensorrt',
tvm: 'tvm',
};
return idTobackendMap[backend];
}

function getPmCmd(backend) {
const pmCmd = {
onnxrt: 'onnxruntime',
tvm: 'apache-tvm',
openvino: 'openvino',
tensorrt: 'torch-tensorrt',
};
return pmCmd[backend];
}

function getImportCmd(backend) {
const importCmd = {
onnxrt: 'import onnxruntime',
tvm: 'import tvm',
openvino: 'import openvino.torch',
tensorrt: 'import torch_tensorrt'
}
return importCmd[backend];
}

function getInstallCommand(optionID) {
backend = getIDFromBackend(optionID);
pmCmd = getPmCmd(optionID);
finalCmd = "";
if (opts.pm == "pip") {
finalCmd = `pip3 install ${pmCmd}`;
}
else if (opts.pm == "conda") {
finalCmd = `conda install ${pmCmd}`;
}
return finalCmd;
}

function getTorchCompileUsage(optionId) {
backend = getIDFromBackend(optionId);
importCmd = getImportCmd(optionId) + "<br>";
finalCmd = "";
tcUsage = "# Torch Compile usage: " + "<br>";
backendCmd = `torch.compile(model, backend="${backend}")`;
libtorchCmd = `# Torch compile ${backend} not supported with Libtorch`;
console.log("Surya log", finalCmd)

if (opts.pm == "libtorch") {
return libtorchCmd;
}
if (backend == "inductor" || backend == "cudagraphs") {
return tcUsage + backendCmd;
}
if (backend == "openvino") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build openvino from source: https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build.md" + "<br>" ;
tcUsage += importCmd;
}
else if (opts.pm == "conda") {
tcUsage += importCmd;
}
if (opts.os == "windows" && !tcUsage.includes(importCmd)) {
tcUsage += importCmd;
}
}
else{
tcUsage += importCmd;
}
if (backend == "onnxrt") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build onnxruntime from source: https://onnxruntime.ai/docs/build" + "<br>" ;
}
}
if (backend == "tvm") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build tvm from source: https://tvm.apache.org/docs/install/from_source.html" + "<br>" ;
}
}
if (backend == "tensorrt") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build tensorrt from source: https://pytorch.org/TensorRT/getting_started/installation.html#compiling-from-source" + "<br>" ;
}
}
finalCmd += tcUsage + backendCmd;
return finalCmd
}

function addTorchCompileCommandNote(selectedOptionId) {

if (!selectedOptionId) {
return;
}
if (selectedOptionId == "inductor" || selectedOptionId == "cgraphs") {
$("#command").append(
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
);
}
else {
$("#command").append(
`<pre> ${getInstallCommand(selectedOptionId)} </pre>`
);
$("#command").append(
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
);
}
}

// determine os (mac, linux, windows) based on user's platform
function getDefaultSelectedOS() {
Expand Down Expand Up @@ -171,6 +289,7 @@ function changeAccNoneName(osname) {
function selectedOption(option, selection, category) {
$(option).removeClass("selected");
$(selection).addClass("selected");
const previousSelection = opts[category];
opts[category] = selection.id;
if (category === "pm") {
var elements = document.getElementsByClassName("language")[0].children;
Expand Down Expand Up @@ -208,13 +327,19 @@ function selectedOption(option, selection, category) {
changeVersion(opts.ptbuild);
//make sure unsupported platforms are disabled
disableUnsupportedPlatforms(opts.os);
} else if (category === "torch-compile") {
if (selection.id === previousSelection) {
$(selection).removeClass("selected");
opts[category] = null;
}
}
commandMessage(buildMatcher());
if (category === "os") {
disableUnsupportedPlatforms(opts.os);
display(opts.os, 'installation', 'os');
}
changeAccNoneName(opts.os);
addTorchCompileCommandNote(opts['torch-compile'])
}

function display(selection, id, category) {
Expand Down