diff --git a/_includes/quick-start-module.js b/_includes/quick-start-module.js
index 37da3ab79893..f17ece931ec8 100644
--- a/_includes/quick-start-module.js
+++ b/_includes/quick-start-module.js
@@ -21,6 +21,7 @@ var opts = {
pm: 'pip',
language: 'python',
ptbuild: 'stable',
+ 'torch-compile': null
};
var supportedCloudPlatforms = [
@@ -34,6 +35,7 @@ var package = $(".package > .option");
var language = $(".language > .option");
var cuda = $(".cuda > .option");
var ptbuild = $(".ptbuild > .option");
+var torchCompile = $(".torch-compile > .option")
os.on("click", function() {
selectedOption(os, this, "os");
@@ -50,6 +52,9 @@ cuda.on("click", function() {
ptbuild.on("click", function() {
selectedOption(ptbuild, this, "ptbuild")
});
+torchCompile.on("click", function() {
+ selectedOption(torchCompile, this, "torch-compile")
+});
// Pre-select user's operating system
$(function() {
@@ -168,9 +173,123 @@ function changeAccNoneName(osname) {
}
}
+function getIDFromBackend(backend) {
+ const idTobackendMap = {
+ inductor: 'inductor',
+ cgraphs : 'cudagraphs',
+ onnxrt: 'onnxrt',
+ openvino: 'openvino',
+ tensorrt: 'tensorrt',
+ tvm: 'tvm',
+ };
+ return idTobackendMap[backend];
+}
+
+function getPmCmd(backend) {
+ const pmCmd = {
+ onnxrt: 'onnxruntime',
+ tvm: 'apache-tvm',
+ openvino: 'openvino',
+ tensorrt: 'torch-tensorrt',
+ };
+ return pmCmd[backend];
+}
+
+function getImportCmd(backend) {
+ const importCmd = {
+ onnxrt: 'import onnxruntime',
+ tvm: 'import tvm',
+ openvino: 'import openvino.torch',
+ tensorrt: 'import torch_tensorrt'
+ }
+ return importCmd[backend];
+}
+
+function getInstallCommand(optionID) {
+ backend = getIDFromBackend(optionID);
+ pmCmd = getPmCmd(optionID);
+ finalCmd = "";
+ if (opts.pm == "pip") {
+ finalCmd = `pip3 install ${pmCmd}`;
+ }
+ else if (opts.pm == "conda") {
+ finalCmd = `conda install ${pmCmd}`;
+ }
+ return finalCmd;
+}
+
+function getTorchCompileUsage(optionId) {
+ backend = getIDFromBackend(optionId);
+ importCmd = getImportCmd(optionId) + "
";
+ finalCmd = "";
+ tcUsage = "# Torch Compile usage: " + "
";
+ backendCmd = `torch.compile(model, backend="${backend}")`;
+ libtorchCmd = `# Torch compile ${backend} not supported with Libtorch`;
+
+ if (opts.pm == "libtorch") {
+ return libtorchCmd;
+ }
+ if (backend == "inductor" || backend == "cudagraphs") {
+ return tcUsage + backendCmd;
+ }
+ if (backend == "openvino") {
+ if (opts.pm == "source") {
+ finalCmd += "# Follow instructions at this URL to build openvino from source: https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build.md" + "
" ;
+ tcUsage += importCmd;
+ }
+ else if (opts.pm == "conda") {
+ tcUsage += importCmd;
+ }
+ if (opts.os == "windows" && !tcUsage.includes(importCmd)) {
+ tcUsage += importCmd;
+ }
+ }
+ else{
+ tcUsage += importCmd;
+ }
+ if (backend == "onnxrt") {
+ if (opts.pm == "source") {
+ finalCmd += "# Follow instructions at this URL to build onnxruntime from source: https://onnxruntime.ai/docs/build" + "
" ;
+ }
+ }
+ if (backend == "tvm") {
+ if (opts.pm == "source") {
+ finalCmd += "# Follow instructions at this URL to build tvm from source: https://tvm.apache.org/docs/install/from_source.html" + "
" ;
+ }
+ }
+ if (backend == "tensorrt") {
+ if (opts.pm == "source") {
+ finalCmd += "# Follow instructions at this URL to build tensorrt from source: https://pytorch.org/TensorRT/getting_started/installation.html#compiling-from-source" + "
" ;
+ }
+ }
+ finalCmd += tcUsage + backendCmd;
+ return finalCmd
+}
+
+function addTorchCompileCommandNote(selectedOptionId) {
+
+ if (!selectedOptionId) {
+ return;
+ }
+ if (selectedOptionId == "inductor" || selectedOptionId == "cgraphs") {
+ $("#command").append(
+ `
${getTorchCompileUsage(selectedOptionId)}` + ); + } + else { + $("#command").append( + `
${getInstallCommand(selectedOptionId)}` + ); + $("#command").append( + `
${getTorchCompileUsage(selectedOptionId)}` + ); + } +} + function selectedOption(option, selection, category) { $(option).removeClass("selected"); $(selection).addClass("selected"); + const previousSelection = opts[category]; opts[category] = selection.id; if (category === "pm") { var elements = document.getElementsByClassName("language")[0].children; @@ -208,6 +327,11 @@ function selectedOption(option, selection, category) { changeVersion(opts.ptbuild); //make sure unsupported platforms are disabled disableUnsupportedPlatforms(opts.os); + } else if (category === "torch-compile") { + if (selection.id === previousSelection) { + $(selection).removeClass("selected"); + opts[category] = null; + } } commandMessage(buildMatcher()); if (category === "os") { @@ -215,6 +339,7 @@ function selectedOption(option, selection, category) { display(opts.os, 'installation', 'os'); } changeAccNoneName(opts.os); + addTorchCompileCommandNote(opts['torch-compile']) } function display(selection, id, category) { diff --git a/_includes/quick_start_local.html b/_includes/quick_start_local.html index d56c586a2f02..cd59dccca14c 100644 --- a/_includes/quick_start_local.html +++ b/_includes/quick_start_local.html @@ -24,6 +24,9 @@
${getTorchCompileUsage(selectedOptionId)}` + ); + } + else { + $("#command").append( + `
${getInstallCommand(selectedOptionId)}` + ); + $("#command").append( + `
${getTorchCompileUsage(selectedOptionId)}` + ); + } +} // determine os (mac, linux, windows) based on user's platform function getDefaultSelectedOS() { @@ -171,6 +289,7 @@ function changeAccNoneName(osname) { function selectedOption(option, selection, category) { $(option).removeClass("selected"); $(selection).addClass("selected"); + const previousSelection = opts[category]; opts[category] = selection.id; if (category === "pm") { var elements = document.getElementsByClassName("language")[0].children; @@ -208,6 +327,11 @@ function selectedOption(option, selection, category) { changeVersion(opts.ptbuild); //make sure unsupported platforms are disabled disableUnsupportedPlatforms(opts.os); + } else if (category === "torch-compile") { + if (selection.id === previousSelection) { + $(selection).removeClass("selected"); + opts[category] = null; + } } commandMessage(buildMatcher()); if (category === "os") { @@ -215,6 +339,7 @@ function selectedOption(option, selection, category) { display(opts.os, 'installation', 'os'); } changeAccNoneName(opts.os); + addTorchCompileCommandNote(opts['torch-compile']) } function display(selection, id, category) {