SHARK

High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters

Communication Channels

Installation

Installation (Linux and macOS)

Set up a new pip Virtual Environment

This step sets up a new virtual environment (venv) for Python.

python --version  # Check that you have Python 3.7 to 3.10 on Linux, or 3.10 on macOS
python -m venv shark_venv
source shark_venv/bin/activate

# If you are using conda, create and activate a new conda env

# Some older pip installs may not be able to handle the recent PyTorch deps
python -m pip install --upgrade pip
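
The version requirement in the comment above can also be checked from Python before creating the environment. The following is a minimal sketch, not part of SHARK: `is_supported_python` is a hypothetical helper name, and the ranges are simply the ones stated in this README (3.7 to 3.10 on Linux, 3.10 on macOS).

```python
import platform
import sys


def is_supported_python(version=None, system=None):
    """Return True if (major, minor) falls in the range this README lists.

    Hypothetical helper, not part of SHARK. Ranges are taken from the
    README text: 3.7 to 3.10 on Linux, 3.10 only on macOS.
    """
    major, minor = (version or sys.version_info)[:2]
    system = system or platform.system()  # "Linux", "Darwin" (macOS), ...
    if system == "Linux":
        return (3, 7) <= (major, minor) <= (3, 10)
    if system == "Darwin":
        return (major, minor) == (3, 10)
    return False  # other platforms are not covered by these instructions


if __name__ == "__main__":
    print(platform.system(), sys.version_info[:2], is_supported_python())
```

Passing `version` and `system` explicitly makes it possible to check a target machine's configuration without running the script there.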

macOS Metal users: please install the Vulkan SDK from https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg and enable "System wide install".

Install SHARK

This step pip installs SHARK and related packages on Linux with Python 3.7, 3.8, 3.9 or 3.10, and on macOS with Python 3.10.

pip install nodai-shark -f https://github.com/nod-ai/SHARK/releases -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu

If you are on an Intel macOS machine you need this workaround for an upstream issue.

Download and run Resnet50 sample

curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
# Install deps for the test script
pip install --pre torch torchvision torchaudio tqdm pillow --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./resnet50_script.py --device="cpu"  # or use cuda, vulkan, or metal

Download and run BERT (MiniLM) sample

curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/minilm_jit.py
# Install deps for the test script
pip install transformers torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./minilm_jit.py --device="cpu"  # or use cuda, vulkan, or metal

Source Installation

Check out the code

git clone https://github.com/nod-ai/SHARK.git

Set up your Python Virtual Environment and Dependencies

# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
# USE_IREE=1 ./setup_venv.sh #uses the latest IREE nightly instead of SHARK
# Please activate the venv after installation.

Run a demo script

python -m shark.examples.shark_inference.resnet50_script --device="cpu"  # or use gpu | vulkan

Testing

Run all model tests on CPU/GPU/VULKAN/Metal

pytest tank

# On Linux, for quicker results (-n requires the pytest-xdist plugin):
pytest tank -n auto

Running specific tests

# Run tests for a specific model:
pytest tank/<MODEL_NAME>  # e.g., pytest tank/bert-base-uncased

# Run tests for a specific case:
pytest tank/<MODEL_NAME>/<MODEL_TEST>.py::<MODEL>ModuleTest::<CASE>
# e.g., pytest tank/bert-base-uncased/bert-base-uncased_test.py::BertModuleTest::test_module_static_cpu
# For frontends other than pytorch, if available for a model, add the frontend to the filename, e.g.: tank/bert-base-uncased/bert-base-uncased_tf_test.py

# Run all tests, including tests for benchmarking and SHARK modules:
# From base SHARK directory,
pytest
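
The naming pattern in the comments above (model directory, test file, test class, case, and an optional frontend suffix) can be put together programmatically. This is only an illustrative sketch: `tank_node_id` is a hypothetical helper that SHARK does not ship, and it assumes every model follows the same layout as the `bert-base-uncased` example.

```python
def tank_node_id(model, test_class, case, frontend=None):
    """Build a pytest node ID following the tank/ layout shown above.

    Hypothetical helper for illustration only. `frontend` is inserted into
    the file name for frontends other than pytorch (e.g. "tf").
    """
    suffix = f"_{frontend}_test.py" if frontend else "_test.py"
    return f"tank/{model}/{model}{suffix}::{test_class}::{case}"


if __name__ == "__main__":
    # Prints the same node ID as the bert-base-uncased example above.
    print(tank_node_id("bert-base-uncased", "BertModuleTest", "test_module_static_cpu"))
```

The returned string can be passed directly to `pytest` as a positional argument.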

Run all model benchmark tests on CPU/GPU/VULKAN/Metal

pytest benchmarks

API Reference

Shark Inference API

from shark.shark_inference import SharkInference

# Placeholders: `module` is the model class and `input` is a torch tensor.
shark_module = SharkInference(
    module,                  # model class
    (input,),                # inputs to the model (must be torch tensors)
    dynamic=False,           # boolean: pass the input shapes as static (False) or dynamic (True)
    device="cpu",            # `cpu`, `gpu` or `vulkan` is supported
    tracing_required=False,  # boolean: JIT-trace the module with the given input; useful where jit.script doesn't work
)
shark_module.set_frontend("pytorch")  # or tensorflow, mhlo, linalg, tosa
shark_module.compile()

result = shark_module.forward((input,))

Example demonstrating running MHLO IR.

from shark.shark_inference import SharkInference
import numpy as np

mhlo_ir = r"""builtin.module  {
      func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
        %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
        %1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
        return %1 : tensor<4x4xf32>
      }
}"""

arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)

shark_module = SharkInference(mhlo_ir, (arg0, arg1))
shark_module.set_frontend("mhlo")
shark_module.compile()
print(shark_module.forward((arg0, arg1)))
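
As a sanity check for the MHLO example above, the same computation can be written in plain NumPy: `chlo.broadcast_add` broadcasts the `1x4` and `4x1` inputs to `4x4`, and `mhlo.abs` takes the element-wise absolute value. This sketch does not use SHARK at all; it only shows what values to expect.

```python
import numpy as np

arg0 = np.ones((1, 4), dtype=np.float32)
arg1 = np.ones((4, 1), dtype=np.float32)

# chlo.broadcast_add: shapes (1, 4) and (4, 1) broadcast to (4, 4).
# mhlo.abs: element-wise absolute value.
expected = np.abs(arg0 + arg1)

print(expected.shape)  # (4, 4)
print(expected)        # every element is 2.0
```

Assuming `shark_module.forward` returns an array-like result, it could be compared against `expected` with `np.allclose`.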

Supported and Validated Models

PyTorch Models

Huggingface PyTorch Models

Hugging Face Models Torch-MLIR lowerable SHARK-CPU SHARK-CUDA SHARK-METAL
BERT 💚 (JIT) 💚 💚 💚
Albert 💚 (JIT) 💚 💚 💚
BigBird 💚 (AOT)
DistilBERT 💚 (JIT) 💚 💚 💚
GPT2 💔 (AOT)

Torchvision Models

TORCHVISION Models Torch-MLIR lowerable SHARK-CPU SHARK-CUDA SHARK-METAL
AlexNet 💚 (Script) 💚 💚 💚
DenseNet121 💚 (Script)
MNasNet1_0 💚 (Script)
MobileNetV2 💚 (Script)
MobileNetV3 💚 (Script)
Unet 💔 (Script)
Resnet18 💚 (Script) 💚 💚 💚
Resnet50 💚 (Script) 💚 💚 💚
Resnet101 💚 (Script) 💚 💚 💚
Resnext50_32x4d 💚 (Script)
ShuffleNet_v2 💔 (Script)
SqueezeNet 💚 (Script) 💚 💚 💚
EfficientNet 💚 (Script)
Regnet 💚 (Script)
Resnest 💔 (Script)
Vision Transformer 💚 (Script)
VGG 16 💚 (Script) 💚 💚
Wide Resnet 💚 (Script) 💚 💚 💚
RAFT 💔 (JIT)

For more information refer to MODEL TRACKING SHEET

PyTorch Training Models

Models Torch-MLIR lowerable SHARK-CPU SHARK-CUDA SHARK-METAL
BERT 💔 💔
FullyConnected 💚 💚

JAX Models

Models JAX-MHLO lowerable SHARK-CPU SHARK-CUDA SHARK-METAL
DALL-E 💔 💔
FullyConnected 💚 💚

TFLite Models

Models TOSA/LinAlg SHARK-CPU SHARK-CUDA SHARK-METAL
BERT 💔 💔
FullyConnected 💚 💚
albert 💚 💚
asr_conformer 💚 💚
bird_classifier 💚 💚
cartoon_gan 💚 💚
craft_text 💚 💚
deeplab_v3 💚 💚
densenet 💚 💚
east_text_detector 💚 💚
efficientnet_lite0_int8 💚 💚
efficientnet 💚 💚
gpt2 💚 💚
image_stylization 💚 💚
inception_v4 💚 💚
inception_v4_uint8 💚 💚
lightning_fp16 💚 💚
lightning_i8 💚 💚
lightning 💚 💚
magenta 💚 💚
midas 💚 💚
mirnet 💚 💚
mnasnet 💚 💚
mobilebert_edgetpu_s_float 💚 💚
mobilebert_edgetpu_s_quant 💚 💚
mobilebert 💚 💚
mobilebert_tf2_float 💚 💚
mobilebert_tf2_quant 💚 💚
mobilenet_ssd_quant 💚 💚
mobilenet_v1 💚 💚
mobilenet_v1_uint8 💚 💚
mobilenet_v2_int8 💚 💚
mobilenet_v2 💚 💚
mobilenet_v2_uint8 💚 💚
mobilenet_v3-large 💚 💚
mobilenet_v3-large_uint8 💚 💚
mobilenet_v35-int8 💚 💚
nasnet 💚 💚
person_detect 💚 💚
posenet 💚 💚
resnet_50_int8 💚 💚
rosetta 💚 💚
spice 💚 💚
squeezenet 💚 💚
ssd_mobilenet_v1 💚 💚
ssd_mobilenet_v1_uint8 💚 💚
ssd_mobilenet_v2_fpnlite 💚 💚
ssd_mobilenet_v2_fpnlite_uint8 💚 💚
ssd_mobilenet_v2_int8 💚 💚
ssd_mobilenet_v2 💚 💚
ssd_spaghettinet_large 💚 💚
ssd_spaghettinet_large_uint8 💚 💚
visual_wake_words_i8 💚 💚

Tensorflow Models (Inference)

Hugging Face Models tf-mhlo lowerable SHARK-CPU SHARK-CUDA SHARK-METAL
BERT 💚 💚 💚 💚
albert-base-v2 💚 💚 💚 💚
DistilBERT 💚 💚 💚 💚
CamemBert 💚 💚 💚 💚
ConvBert 💚 💚 💚 💚
Deberta
electra 💚 💚 💚 💚
funnel
layoutlm 💚 💚 💚 💚
longformer
mobile-bert 💚 💚 💚 💚
remembert
tapas
flaubert 💚 💚 💚 💚
roberta 💚 💚 💚 💚
xlm-roberta 💚 💚 💚 💚
mpnet 💚 💚 💚 💚

Related Projects

IREE Project Channels
MLIR and Torch-MLIR Project Channels

License

nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions. See LICENSE for more information.
