Merge branch 'main' into dev
haotian-liu committed May 30, 2023
2 parents 9312994 + 8b21169 commit 15f31eb
Showing 4 changed files with 13 additions and 20 deletions.
10 changes: 2 additions & 8 deletions README.md
@@ -14,7 +14,7 @@


## Release

- [5/13] 🔥 Interested in quantifying the emergent **zero-shot OCR** performance of LLaVA and other open-source LMMs? Please check out the paper ["On the Hidden Mystery of OCR in Large Multimodal Models"](https://arxiv.org/abs/2305.07895), where LLaVA consistently outperforms miniGPT4 on 17 out of 18 datasets, despite being trained on an order of magnitude less data.
- [5/6] 🔥 We are releasing [LLaVA-Lightning-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details.
- [5/2] 🔥 We are releasing LLaVA-Lightning! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details.
- [5/2] We upgraded the LLaVA package to v0.1 to support Vicuna v0 and v1 checkpoints; please upgrade by following the instructions [here](#install).
@@ -94,7 +94,7 @@ pip install -e .
3. Install additional packages for training
```
pip install ninja
pip install flash-attn
pip install flash-attn==1.0.2
```
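Pinning `flash-attn` to 1.0.2 keeps the training setup from pulling a newer, possibly incompatible build. A quick way to confirm which build ended up installed (plain pip, nothing LLaVA-specific):

```Shell
pip show flash-attn   # the Version field should read 1.0.2
```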

### Upgrade to v0.1
@@ -166,12 +166,6 @@ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:
```
Wait until the process finishes loading the model and you see "Uvicorn running on ...".
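If the controller exposes the FastChat-style `/list_models` endpoint (an assumption here; the endpoint is not shown in this diff), you can also confirm that the worker registered successfully:

```Shell
curl -X POST http://localhost:10000/list_models   # should include the worker's model name once registration succeeds
```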


#### Send a test message
```Shell
python -m llava.serve.test_message --model-name LLaVA-13B-v0 --controller http://localhost:10000
```

#### Launch a gradio web server.
```Shell
python -m llava.serve.gradio_web_server --controller http://localhost:10000
7 changes: 3 additions & 4 deletions llava/conversation.py
@@ -279,13 +279,12 @@ def dict(self):
)

simple_conv = Conversation(
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
"You are designed to assist human with a variety of tasks using natural language."
"Follow the instructions carefully.",
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!"),
("Assistant", "Hi there! How can I help you today?\n")
("Assistant", "Hi there! How can I help you today?")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
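This edit swaps `simple_conv`'s lab-specific system prompt for the neutral Vicuna-style one and drops the trailing newline from the seeded assistant turn. A minimal sketch of how the template is typically consumed (assuming the `copy`, `append_message`, and `get_prompt` helpers defined elsewhere in `llava/conversation.py`):

```python
from llava.conversation import simple_conv

# Work on a copy so the shared template object is not mutated.
conv = simple_conv.copy()
conv.append_message(conv.roles[0], "What is the capital of France?")
conv.append_message(conv.roles[1], None)  # None marks the slot the model will fill in
print(conv.get_prompt())  # system prompt + seeded turns + the new Human/Assistant turn
```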
14 changes: 7 additions & 7 deletions llava/serve/model_worker.py
@@ -46,7 +46,7 @@ def heart_beat_worker(controller):
controller.send_heart_beat()


def load_model(model_path, num_gpus):
def load_model(model_path, model_name, num_gpus):
if num_gpus == 1:
kwargs = {}
else:
@@ -56,19 +56,19 @@ def load_model(model_path, num_gpus):
}

tokenizer = AutoTokenizer.from_pretrained(model_path)
if 'llava' in model_path.lower():
if 'mpt' in model_path.lower():
if 'llava' in model_name.lower():
if 'mpt' in model_name.lower():
model = LlavaMPTForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
else:
model = LlavaLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
elif 'mpt' in model_path.lower():
elif 'mpt' in model_name.lower():
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)

image_processor = None

if 'llava' in model_path.lower():
if 'llava' in model_name.lower():
from transformers import CLIPImageProcessor, CLIPVisionModel
image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
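Threading `model_name` through `load_model` matters because a local checkpoint path does not have to contain the model's name. A minimal illustration (both the path and the name below are hypothetical):

```python
# Hypothetical values: a local checkpoint directory whose path never mentions "llava",
# versus the name the worker registers with the controller.
model_path = "/data/checkpoints/ckpt-3000"
model_name = "LLaVA-13B-v0"

print("llava" in model_path.lower())  # False -> the old check would fall through to AutoModelForCausalLM
print("llava" in model_name.lower())  # True  -> the new check selects the LLaVA model classes
```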

@@ -123,7 +123,7 @@ def __init__(self, controller_addr, worker_addr,
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.keep_aspect_ratio = keep_aspect_ratio
self.tokenizer, self.model, self.image_processor, self.context_len = load_model(
model_path, num_gpus)
model_path, self.model_name, num_gpus)
self.is_multimodal = 'llava' in model_path.lower()

if not no_register:
@@ -186,7 +186,7 @@ def generate_stream(self, params):
prompt = params["prompt"]
ori_prompt = prompt
images = params.get("images", None)
if images is not None and self.is_multimodal:
if images is not None and len(images) > 0 and self.is_multimodal:
from PIL import Image
from io import BytesIO
import base64
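The added `len(images) > 0` check means a request carrying an empty image list is handled as plain text instead of entering the multimodal branch. A minimal sketch of the decode path it guards (the request payload below is hypothetical):

```python
import base64
from io import BytesIO

from PIL import Image

params = {"prompt": "Hello", "images": []}  # hypothetical text-only request
images = params.get("images", None)

if images is not None and len(images) > 0:
    # Decode only when at least one base64-encoded image was actually sent.
    images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
else:
    images = None  # treat the request as text-only
```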
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"accelerate", "fastapi", "gradio==3.23", "markdown2[all]", "numpy",
"accelerate", "einops", "fastapi", "gradio==3.23", "markdown2[all]", "numpy",
"requests", "sentencepiece", "tokenizers==0.12.1",
"torch", "torchvision", "uvicorn", "wandb",
"transformers @ git+https://github.com/huggingface/transformers.git@cae78c46"
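The new `einops` dependency presumably backs the MPT support touched elsewhere in this commit; MPT's remote modeling code is generally reported to require it (an assumption, the diff itself does not say). A quick check after reinstalling the package:

```python
import einops

print(einops.__version__)  # should import cleanly once the dependency is installed
```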
