Merge branch 'main' into demo-save_outputs
rubendax committed Nov 4, 2024
2 parents 16dc689 + 19af013 commit c58951d
Showing 11 changed files with 252 additions and 110 deletions.
22 changes: 11 additions & 11 deletions OmniGen/pipeline.py
@@ -42,27 +42,27 @@
"""


class OmniGenPipeline:
def __init__(
self,
vae: AutoencoderKL,
model: OmniGen,
processor: OmniGenProcessor,
device: Union[str, torch.device] = None,
):
self.vae = vae
self.model = model
self.processor = processor
self.device = device

if torch.cuda.is_available():
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
elif is_torch_npu_available():
self.device = torch.device("npu")
else:
logger.info("Don't detect any available devices, using CPU instead")
self.device = torch.device("cpu")
if device is None:
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!")
self.device = torch.device("cpu")

self.model.to(torch.bfloat16)
self.model.eval()
@@ -306,4 +306,4 @@ def __call__(

torch.cuda.empty_cache() # Clear VRAM
gc.collect() # Run garbage collection to free system RAM
return output_images
return output_images
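
The `if device is None:` variant in the hunk above only auto-detects a backend when the caller passes no device. A minimal sketch of that selection order follows, written without a PyTorch dependency so it runs anywhere. `pick_device` and the `probes` list are hypothetical names for illustration and are not part of OmniGen.

```python
from typing import Callable, Optional, Sequence, Tuple

# Hypothetical helper for illustration; not part of OmniGen.
Probe = Tuple[str, Callable[[], bool]]


def pick_device(requested: Optional[str], probes: Sequence[Probe]) -> str:
    """Return `requested` if given, else the first available backend, else "cpu"."""
    if requested is not None:
        return requested  # an explicit choice always wins
    for name, is_available in probes:
        if is_available():
            return name
    return "cpu"  # nothing else detected


# With PyTorch installed, the probes could be wired up as:
#   probes = [("cuda", torch.cuda.is_available),
#             ("mps", torch.backends.mps.is_available)]
# Stand-in probes (assumed values) are used here so the sketch runs without torch.
probes = [("cuda", lambda: False), ("mps", lambda: True)]
print(pick_device(None, probes))   # auto-detects the first available backend
print(pick_device("cpu", probes))  # explicit request is returned unchanged
```

Keeping the probes in an ordered list makes the priority (CUDA, then MPS, then CPU) explicit and easy to extend.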
2 changes: 2 additions & 0 deletions OmniGen/scheduler.py
@@ -11,6 +11,8 @@ class OmniGenCache(DynamicCache):
def __init__(self,
num_tokens_for_img: int, offload_kv_cache: bool=False) -> None:
if not torch.cuda.is_available():
print("No avaliable GPU, offload_kv_cache wiil be set to False, which will result in large memory usage and time cost when input multiple images!!!")
offload_kv_cache = False
raise RuntimeError("OffloadedCache can only be used with a GPU")
super().__init__()
self.original_device = []
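
In the hunk above, the fallback (`offload_kv_cache = False`) and the `raise RuntimeError(...)` sit under the same `if not torch.cuda.is_available():` condition. If they execute in the order shown, the exception would fire before the fallback can take effect. The sketch below shows the warn-and-fall-back behaviour that the printed message describes, assuming that fallback is the intent. `resolve_offload` is a hypothetical helper, not OmniGen code, and GPU availability is passed in as a plain boolean so the example runs without PyTorch.

```python
import warnings


def resolve_offload(offload_kv_cache: bool, gpu_available: bool) -> bool:
    """Hypothetical helper: downgrade the flag with a warning when no GPU is present."""
    if offload_kv_cache and not gpu_available:
        warnings.warn(
            "No available GPU; offload_kv_cache will be set to False, "
            "which increases memory usage and time cost for multiple input images."
        )
        return False
    return offload_kv_cache


# In real code, gpu_available would come from torch.cuda.is_available().
print(resolve_offload(True, gpu_available=False))  # flag is downgraded
print(resolve_offload(True, gpu_available=True))   # flag is kept
```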
2 changes: 1 addition & 1 deletion OmniGen/transformer.py
@@ -121,7 +121,7 @@ def forward(
attention_mask = (1 - attention_mask) * min_dtype
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
else:
raise
raise Exception("attention_mask parameter was unavailable or invalid")
# causal_mask = self._update_causal_mask(
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
# )
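
The change from a bare `raise` to `raise Exception(...)` matters because a bare `raise` outside an `except` block has nothing to re-raise, so Python reports a generic `RuntimeError` that says nothing about `attention_mask`. A small standalone illustration follows. The function names are hypothetical and only the error message is taken from the diff.

```python
def bare() -> None:
    raise  # no exception is being handled here


def explicit() -> None:
    raise Exception("attention_mask parameter was unavailable or invalid")


def describe(fn) -> str:
    """Call fn and return the name of the exception type it raises."""
    try:
        fn()
    except Exception as exc:
        print(f"{type(exc).__name__}: {exc}")
        return type(exc).__name__
    return "no exception"


describe(bare)      # a RuntimeError unrelated to the real cause
describe(explicit)  # an Exception carrying the descriptive message
```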
42 changes: 26 additions & 16 deletions README.md
@@ -13,6 +13,9 @@
</a>
<a href="https://huggingface.co/Shitao/OmniGen-v1">
<img alt="Build" src="https://img.shields.io/badge/HF%20Model-🤗-yellow">
</a>
<a href="https://replicate.com/chenxwh/omnigen">
<img alt="Build" src="https://replicate.com/chenxwh/omnigen/badge">
</a>
</p>

@@ -31,21 +34,21 @@


## 1. News
- 2024-11-03: Added Replicate Demo and API: [![Replicate](https://replicate.com/chenxwh/omnigen/badge)](https://replicate.com/chenxwh/omnigen)
- 2024-10-28: We release new version of inference code, optimizing the memory usage and time cost. You can refer to [docs/inference.md](docs/inference.md#requiremented-resources) for detailed information.
- 2024-10-22: :fire: We release the code for OmniGen. Inference: [docs/inference.md](docs/inference.md) Train: [docs/fine-tuning.md](docs/fine-tuning.md)
- 2024-10-22: :fire: We release the first version of OmniGen. Model Weight: [Shitao/OmniGen-v1](https://huggingface.co/Shitao/OmniGen-v1) HF Demo: [🤗](https://huggingface.co/spaces/Shitao/OmniGen)


## 2. Overview

OmniGen is a unified image generation model that can generate a wide range of images from multi-modal prompts. It is designed to be simple, flexible and easy to use. We provide [inference code](#5-quick-start) so that everyone can explore more functionalities of OmniGen.
OmniGen is a unified image generation model that can generate a wide range of images from multi-modal prompts. It is designed to be simple, flexible, and easy to use. We provide [inference code](#5-quick-start) so that everyone can explore more functionalities of OmniGen.

Existing image generation models often require loading several additional network modules (such as ControlNet, IP-Adapter, Reference-Net, etc.) and performing extra preprocessing steps (e.g., face detection, pose estimation, cropping, etc.) to generate a satisfactory image. However, **we believe that the future image generation paradigm should be more simple and flexible, that is, generating various images directly through arbitrarily multi-modal instructions without the need for additional plugins and operations, similar to how GPT works in language generation.**

Due to the limited resources, OmniGen still has room for improvement. We will continue to optimize it, and hope it inspire more universal image generation models. You can also easily fine-tune OmniGen without worrying about designing networks for specific tasks; you just need to prepare the corresponding data, and then run the [script](#6-finetune). Imagination is no longer limited; everyone can construct any image generation task, and perhaps we can achieve very interesting, wonderful and creative things.

If you have any questions, ideas or interesting tasks you want OmniGen to accomplish, feel free to discuss with us: [email protected], [email protected], [email protected]. We welcome any feedback to help us improve the model.
Due to the limited resources, OmniGen still has room for improvement. We will continue to optimize it, and hope it inspires more universal image-generation models. You can also easily fine-tune OmniGen without worrying about designing networks for specific tasks; you just need to prepare the corresponding data, and then run the [script](#6-finetune). Imagination is no longer limited; everyone can construct any image-generation task, and perhaps we can achieve very interesting, wonderful, and creative things.

If you have any questions, ideas, or interesting tasks you want OmniGen to accomplish, feel free to discuss with us: [email protected], [email protected], [email protected]. We welcome any feedback to help us improve the model.



@@ -54,13 +57,13 @@ If you have any questions, ideas or interesting tasks you want OmniGen to accomp
You can see details in our [paper](https://arxiv.org/abs/2409.11340).


## 4. What Can OmniGen do?

## 4. What Can OmniGen do?

OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, image editing, and image-conditioned generation. **OmniGen don't need additional plugins or operations, it can automatically identify the features (e.g., required object, human pose, depth mapping) in input images according the text prompt.**
We showcase some examples in [inference.ipynb](inference.ipynb). And in [inference_demo.ipynb](inference_demo.ipynb), we show an interesting pipeline to generate and modify a image.
OmniGen is a unified image generation model that you can use to perform various tasks, including but not limited to text-to-image generation, subject-driven generation, Identity-Preserving Generation, image editing, and image-conditioned generation. **OmniGen doesn't need additional plugins or operations, it can automatically identify the features (e.g., required object, human pose, depth mapping) in input images according to the text prompt.**
We showcase some examples in [inference.ipynb](inference.ipynb). And in [inference_demo.ipynb](inference_demo.ipynb), we show an interesting pipeline to generate and modify an image.

Here is the illustration of OmniGen's capabilities:
Here is the illustrations of OmniGen's capabilities:
- You can control the image generation flexibly via OmniGen
![demo](./imgs/demo_cases.png)
- Referring Expression Generation: You can input multiple images and use simple, general language to refer to the objects within those images. OmniGen can automatically recognize the necessary objects in each image and generate new images based on them. No additional operations, such as image cropping or face detection, are required.
@@ -85,9 +88,11 @@ Here are some examples:
```python
from OmniGen import OmniGenPipeline

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1")
# Note: Your local model path is also acceptable, such as 'pipe = OmniGenPipeline.from_pretrained(your_local_model_path)', where all files in your_local_model_path should be organized as https://huggingface.co/Shitao/OmniGen-v1/tree/main
# Note: If the original link https://huggingface.co/Shitao/OmniGen-v1/tree/main is unstable when downloading, it is recommended to use this mirror link https://hf-mirror.com/Shitao/OmniGen-v1/tree/main or other ways in https://hf-mirror.com/

# Text to Image
## Text to Image
images = pipe(
prompt="A curly-haired man in a red shirt is drinking tea.",
height=1024,
@@ -97,8 +102,8 @@ images = pipe(
)
images[0].save("example_t2i.png") # save output PIL Image

# Multi-modal to Image
# In prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
## Multi-modal to Image
# In the prompt, we use the placeholder to represent the image. The image placeholder should be in the format of <img><|image_*|></img>
# You can add multiple images in the input_images. Please ensure that each image has its placeholder. For example, for the list input_images [img1_path, img2_path], the prompt needs to have two placeholders: <img><|image_1|></img>, <img><|image_2|></img>.
images = pipe(
prompt="A man in a black shirt is reading a book. The man is the right man in <img><|image_1|></img>.",
@@ -111,27 +116,27 @@ images = pipe(
)
images[0].save("example_ti2i.png") # save output PIL image
```
- If out of memory, you can set `offload_model=True`. If inference time is too long when input multiple images, you can reduce the `max_input_image_size`. For thre required resources and the method to run OmniGen efficiently, please refer to [docs/inference.md#requiremented-resources](docs/inference.md#requiremented-resources).
- For more examples for image generation, you can refer to [inference.ipynb](inference.ipynb) and [inference_demo.ipynb](inference_demo.ipynb)
- If out of memory, you can set `offload_model=True`. If the inference time is too long when inputting multiple images, you can reduce the `max_input_image_size`. For the required resources and the method to run OmniGen efficiently, please refer to [docs/inference.md#requiremented-resources](docs/inference.md#requiremented-resources).
- For more examples of image generation, you can refer to [inference.ipynb](inference.ipynb) and [inference_demo.ipynb](inference_demo.ipynb)
- For more details about the argument in inference, please refer to [docs/inference.md](docs/inference.md).
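
The README comments above require one `<img><|image_*|></img>` placeholder per entry in `input_images`. A minimal sketch of a pre-flight check for that rule follows, assuming placeholders are numbered from 1 in list order as in the README's `[img1_path, img2_path]` example. `missing_placeholders` is a hypothetical helper and not part of the OmniGen API.

```python
import re
from typing import List

# Matches placeholders such as <img><|image_1|></img>
PLACEHOLDER = re.compile(r"<img><\|image_(\d+)\|></img>")


def missing_placeholders(prompt: str, input_images: List[str]) -> List[int]:
    """Hypothetical helper: 1-based indices of images with no placeholder in the prompt."""
    found = {int(n) for n in PLACEHOLDER.findall(prompt)}
    return [i for i in range(1, len(input_images) + 1) if i not in found]


prompt = "A man is reading a book. The man is the right man in <img><|image_1|></img>."
# Two images are supplied but only image_1 is referenced in the prompt.
print(missing_placeholders(prompt, ["img1.png", "img2.png"]))  # → [2]
```

Running such a check before calling the pipeline would catch a forgotten placeholder early rather than during generation.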


### Using Diffusers

Coming soon.


### Gradio Demo

We construct an online demo in [Huggingface](https://huggingface.co/spaces/Shitao/OmniGen).

For the local gradio demo, you need to install `pip install gradio spaces` , and then you can run:
For the local gradio demo, you need to install `pip install gradio spaces`, and then you can run:
```python
pip install gradio spaces
python app.py
```



## 6. Finetune
We provide a training script `train.py` to fine-tune OmniGen.
Here is a toy example about LoRA finetune:
@@ -157,7 +162,12 @@ accelerate launch --num_processes=1 train.py \

Please refer to [docs/fine-tuning.md](docs/fine-tuning.md) for more details (e.g. full finetune).

### Contributors:
Thank all our contributors for their efforts and warmly welcome new members to join in!

<a href="https://github.com/VectorSpaceLab/OmniGen/graphs/contributors">
<img src="https://contrib.rocks/image?repo=VectorSpaceLab/OmniGen" />
</a>

## License
This repo is licensed under the [MIT License](LICENSE).