[Docs][Serve] Add model serve using AWS NeuronCore (ray-project#38811)
* [Docs] Add model serve using AWS NeuronCore

Signed-off-by: maheedhar reddy chappidi <[email protected]>
Showing 5 changed files with 204 additions and 1 deletion.

115 changes: 115 additions & 0 deletions
doc/source/serve/doc_code/aws_neuron_core_inference_serve.py

from contextlib import contextmanager

# __compile_neuron_code_start__
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch, torch_neuronx  # noqa

hf_model = "j-hartmann/emotion-english-distilroberta-base"
neuron_model = "./sentiment_neuron.pt"

model = AutoModelForSequenceClassification.from_pretrained(hf_model)
tokenizer = AutoTokenizer.from_pretrained(hf_model)
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "HuggingFace's headquarters are situated in Manhattan"
example_inputs = tokenizer.encode_plus(
    sequence_0,
    sequence_1,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=128,
)
neuron_inputs = example_inputs["input_ids"], example_inputs["attention_mask"]
# Trace (compile) the model for NeuronCores using the example inputs.
n_model = torch_neuronx.trace(model, neuron_inputs)
n_model.save(neuron_model)
print(f"Saved Neuron-compiled model {neuron_model}")
# __compile_neuron_code_end__


# __neuron_serve_code_start__
from fastapi import FastAPI  # noqa
from ray import serve  # noqa

import torch  # noqa

app = FastAPI()

hf_model = "j-hartmann/emotion-english-distilroberta-base"
neuron_model = "./sentiment_neuron.pt"


@serve.deployment(num_replicas=1, route_prefix="/")
@serve.ingress(app)
class APIIngress:
    def __init__(self, bert_base_model_handle) -> None:
        self.handle = bert_base_model_handle

    @app.get("/infer")
    async def infer(self, sentence: str):
        # Forward the sentence to the model deployment and await its result.
        ref = await self.handle.infer.remote(sentence)
        result = await ref
        return result


@serve.deployment(
    ray_actor_options={"resources": {"neuron_cores": 1}},
    autoscaling_config={"min_replicas": 1, "max_replicas": 2},
)
class BertBaseModel:
    def __init__(self):
        import torch, torch_neuronx  # noqa
        from transformers import AutoTokenizer

        self.model = torch.jit.load(neuron_model)
        self.tokenizer = AutoTokenizer.from_pretrained(hf_model)
        self.classmap = {
            0: "anger",
            1: "disgust",
            2: "fear",
            3: "joy",
            4: "neutral",
            5: "sadness",
            6: "surprise",
        }

    def infer(self, sentence: str):
        inputs = self.tokenizer.encode_plus(
            sentence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        output = self.model(*(inputs["input_ids"], inputs["attention_mask"]))
        class_id = torch.argmax(output["logits"], dim=1).item()
        return self.classmap[class_id]


entrypoint = APIIngress.bind(BertBaseModel.bind())


# __neuron_serve_code_end__


@contextmanager
def serve_session(deployment):
    # Run the deployment for the duration of the block, then shut Serve down.
    handle = serve.run(deployment)
    try:
        yield handle
    finally:
        serve.shutdown()


if __name__ == "__main__":
    import requests
    import ray

    # On an inf2.8xlarge instance, there are 2 Neuron cores.
    ray.init(resources={"neuron_cores": 2})

    with serve_session(entrypoint):
        prompt = "Ray is super cool."
        resp = requests.get(f"http://127.0.0.1:8000/infer?sentence={prompt}")
        print(resp.status_code, resp.json())

        assert resp.status_code == 200

77 changes: 77 additions & 0 deletions

(aws-neuron-core-inference-tutorial)=

# Serving an inference model on AWS NeuronCores using FastAPI (Experimental)

This example compiles a BERT-based model and deploys the traced model on an AWS Inferentia (Inf2) or Trainium (Trn1)
instance using Ray Serve and FastAPI.

:::{note}
This setup assumes you have followed the
[PyTorch Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx)
setup guide and installed the AWS NeuronCore drivers and tools, as well as `torch-neuronx`, for your instance type.
:::
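
For reference, installing `torch-neuronx` typically looks like the following. This is a sketch only: the authoritative package list and versions come from the setup guide above, and it assumes the AWS Neuron pip repository at `https://pip.repos.neuron.amazonaws.com`.

```bash
# Pull torch-neuronx from the AWS Neuron pip repository (versions per the setup guide).
python -m pip install torch-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com
```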

Install Ray Serve and the other dependencies this example uses:

```bash
python -m pip install "ray[serve]" requests transformers
```

This example uses the [j-hartmann/emotion-english-distilroberta-base](https://huggingface.co/j-hartmann/emotion-english-distilroberta-base) model and [FastAPI](https://fastapi.tiangolo.com/).

Use the following code to compile the model:

```{literalinclude} ../doc_code/aws_neuron_core_inference_serve.py
:language: python
:start-after: __compile_neuron_code_start__
:end-before: __compile_neuron_code_end__
```
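
Optionally, you can sanity-check the traced model before deploying it. This check isn't part of the original example; it reuses `model`, `neuron_inputs`, and `neuron_model` from the compile snippet above and assumes the traced model exposes a `"logits"` entry, as the serving code below does:

```python
# Hedged sketch: verify the Neuron-traced model roughly matches the CPU model.
import torch

traced = torch.jit.load(neuron_model)
with torch.no_grad():
    cpu_logits = model(*neuron_inputs)["logits"]
neuron_logits = traced(*neuron_inputs)["logits"]
# Compiled output isn't bit-identical to CPU output, so compare with a tolerance.
print(torch.allclose(cpu_logits, neuron_logits, atol=1e-2))
```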

When compiling the model, you should see logs similar to the following:

```text
Downloading (…)lve/main/config.json: 100%|██████████| 1.00k/1.00k [00:00<00:00, 242kB/s]
Downloading pytorch_model.bin: 100%|██████████| 329M/329M [00:01<00:00, 217MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 294/294 [00:00<00:00, 305kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 798k/798k [00:00<00:00, 22.0MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 57.0MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 6.16MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 239/239 [00:00<00:00, 448kB/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Saved Neuron-compiled model ./sentiment_neuron.pt
```

The traced model is now ready for deployment. Save the following code to a file named `aws_neuron_core_inference_serve.py`,
then use `serve run aws_neuron_core_inference_serve:entrypoint` to start the Serve application:

```{literalinclude} ../doc_code/aws_neuron_core_inference_serve.py
:language: python
:start-after: __neuron_serve_code_start__
:end-before: __neuron_serve_code_end__
```
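
Note that the `BertBaseModel` deployment requests one `neuron_cores` resource per replica. If this custom resource isn't already advertised on your node, one way to supply it (a sketch; the count depends on your instance type) is to start Ray with it explicitly before running `serve run`:

```bash
# Advertise the NeuronCores as a custom Ray resource (an inf2.8xlarge has 2).
ray start --head --resources='{"neuron_cores": 2}'
```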

You should see the following logs for a successful deployment:

```text
(ServeController pid=43105) INFO 2023-08-23 20:29:32,694 controller 43105 deployment_state.py:1372 - Deploying new version of deployment default_BertBaseModel.
(ServeController pid=43105) INFO 2023-08-23 20:29:32,695 controller 43105 deployment_state.py:1372 - Deploying new version of deployment default_APIIngress.
(HTTPProxyActor pid=43147) INFO 2023-08-23 20:29:32,620 http_proxy 10.0.1.234 http_proxy.py:1328 - Proxy actor 8be14f6b6b10c0190cd0c39101000000 starting on node 46a7f740898fef723c3360ef598c1309701b07d11fb9dc45e236620a.
(HTTPProxyActor pid=43147) INFO: Started server process [43147]
(ServeController pid=43105) INFO 2023-08-23 20:29:32,799 controller 43105 deployment_state.py:1654 - Adding 1 replica to deployment default_BertBaseModel.
(ServeController pid=43105) INFO 2023-08-23 20:29:32,801 controller 43105 deployment_state.py:1654 - Adding 1 replica to deployment default_APIIngress.
2023-08-23 20:29:44,690 SUCC scripts.py:462 -- Deployed Serve app successfully.
```
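
You can also confirm that the application is running with the Serve CLI. This step isn't in the original example; `serve status` is the standard Ray Serve status command:

```bash
serve status
```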

Use the following code to send requests:

```python
import requests

response = requests.get("http://127.0.0.1:8000/infer?sentence=Ray is super cool")
print(response.status_code, response.json())
```

The response includes the status code and the classifier output:

```text
200 joy
```
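
When you're done experimenting, you can tear everything down with the standard Ray CLI commands (not specific to this example):

```bash
serve shutdown  # stop the Serve application
ray stop        # stop a local Ray instance started with `ray start`
```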