[Model] Add image preprocess for vision model (mlc-ai#2892)
This PR adds image preprocessing for vision models.
mengshyu authored Sep 14, 2024
1 parent 6277afb commit 36d0ed1
Showing 8 changed files with 335 additions and 170 deletions.
32 changes: 23 additions & 9 deletions python/mlc_llm/model/llava/llava_model.py
@@ -10,12 +10,12 @@
from tvm import tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Module, Tensor
from tvm.relax.frontend.nn.op import reshape, wrap_nested
from tvm.relax.frontend.nn.op import permute_dims, reshape, wrap_nested
from tvm.relax.op import strided_slice

from mlc_llm import op as op_ext
from mlc_llm.model.model_preset import MODEL_PRESETS
from mlc_llm.model.vision import CLIPVisionConfig, CLIPVisionModel
from mlc_llm.model.vision import CLIPVisionConfig, CLIPVisionModel, ImageProcessor
from mlc_llm.nn import PagedKVCache, RopeMode

from ...support.config import ConfigBase
@@ -139,6 +139,7 @@ def __init__(self, config: LlavaConfig):
super().__init__()
self.config = config
self.vision_tower = CLIPVisionModel(config.vision_config)
self.image_processor = ImageProcessor()
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = ARCHITECTURE_MAP[config.text_architecture](config.text_config)
self.vocab_size = config.vocab_size
@@ -153,7 +154,25 @@ def to(self, dtype: Optional[str] = None):
def embed(self, input_ids: Tensor) -> Tensor:
return self.language_model.embed(input_ids)

def image_preprocess(self, pixel_values: Tensor) -> Tensor:
pixel_values = permute_dims(pixel_values, axes=(0, 2, 3, 1)) # NCHW -> NHWC
pixel_values = self.image_processor.resize(
pixel_values, {"shortest_edge": self.config.vision_config.image_size}
)
pixel_values = self.image_processor.crop(
pixel_values,
{
"height": self.config.vision_config.image_size,
"width": self.config.vision_config.image_size,
},
)
pixel_values = self.image_processor.rescale(pixel_values)
pixel_values = self.image_processor.normalize(pixel_values)
pixel_values = permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
return pixel_values

def image_embed(self, pixel_values: Tensor) -> Tensor:
pixel_values = self.image_preprocess(pixel_values)
pixel_values = pixel_values.astype(self.dtype)
image_features_all = self.vision_tower.forward(pixel_values)
image_features = wrap_nested(
@@ -237,13 +256,8 @@ def get_default_spec(self):
},
"image_embed": {
"pixel_values": nn.spec.Tensor(
[
1,
3,
self.config.vision_config.image_size,
self.config.vision_config.image_size,
],
"float32",
[1, 3, "image_height", "image_width"],
"uint8",
),
"$": {
"param_mode": "packed",
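For reference, the `image_preprocess` method added to the LLaVA model above follows the usual CLIP-style pipeline: shortest-edge resize, center crop to `image_size`, rescale to [0, 1], and per-channel normalization, while the `image_embed` spec now takes a raw uint8 image with symbolic height and width instead of a fixed-size float32 tensor. Below is a rough host-side NumPy sketch of that pipeline, assuming nearest-neighbor resizing and the standard CLIP normalization constants; the helper names are illustrative, not the actual `ImageProcessor` implementation.

import numpy as np

# Illustrative host-side equivalent of the in-graph pipeline; not the TVM/Relax code.
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])   # assumed CLIP defaults
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711])

def resize_shortest_edge(img: np.ndarray, shortest_edge: int) -> np.ndarray:
    """Nearest-neighbor resize of an NHWC uint8 batch so that min(H, W) == shortest_edge."""
    h, w = img.shape[1:3]
    scale = shortest_edge / min(h, w)
    new_h, new_w = round(h * scale), round(w * scale)
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    return img[:, rows][:, :, cols]

def center_crop(img: np.ndarray, size: int) -> np.ndarray:
    h, w = img.shape[1:3]
    top, left = (h - size) // 2, (w - size) // 2
    return img[:, top:top + size, left:left + size, :]

def preprocess(img_nchw: np.ndarray, image_size: int = 336) -> np.ndarray:
    x = img_nchw.transpose(0, 2, 3, 1)         # NCHW -> NHWC
    x = resize_shortest_edge(x, image_size)    # resize(shortest_edge=image_size)
    x = center_crop(x, image_size)             # crop(height=width=image_size)
    x = x.astype(np.float32) / 255.0           # rescale
    x = (x - CLIP_MEAN) / CLIP_STD             # normalize
    return x.transpose(0, 3, 1, 2)             # NHWC -> NCHW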
11 changes: 5 additions & 6 deletions python/mlc_llm/model/phi3v/phi3v_image.py
@@ -34,8 +34,7 @@ def __init__(self, config: ConfigBase):
super().__init__()

self.img_processor = CLIPVisionModel(config.vision_config)
self.image_dim_out = 1024
self.num_img_tokens = 144
self.image_dim_out = config.img_processor["image_dim_out"]

self.glb_GN = nn.Parameter((1, 1, self.image_dim_out * 4))
self.sub_GN = nn.Parameter((1, 1, 1, self.image_dim_out * 4))
@@ -48,16 +47,16 @@ def get_img_features(self, img_embeds: Tensor) -> Tensor:
patch_feature = nn.op.split(img_processor_output, indices_or_sections=[1], axis=1)
return patch_feature[1]

def forward(self, pixel_values: Tensor) -> Tensor: # pylint: disable=too-many-locals
# pylint: disable=too-many-locals,unused-argument
def forward(self, pixel_values: Tensor, raw_image_h, raw_image_w) -> Tensor:
h = 3 # raw_image_h // self.image_size
w = 4 # raw_image_w // self.image_size
B_ = h * w
C = self.image_dim_out

img_embeds = nn.op.squeeze(pixel_values, 0)
img_features = self.get_img_features(img_embeds)
# img_embeds = nn.op.squeeze(pixel_values, 0)
img_features = self.get_img_features(pixel_values)
H = T.int32((img_features.shape[1] ** 0.5))

img_features = nn.op.split(img_features, indices_or_sections=[1], axis=0)
global_img_feature = img_features[0]
global_img_feature = nn.op.reshape(global_img_feature, ([1, H, H, C]))
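In `Phi3ImageEmbedding.forward` above, the side length `H` of the square patch grid is recovered from the number of patch tokens produced by the CLIP vision tower (`img_features.shape[1] ** 0.5`). A small sanity-check sketch of that arithmetic, assuming a 336-pixel, patch-14 backbone (typical CLIP ViT-L/14-336 numbers, not values read from this diff):

# Patch-grid arithmetic behind H = T.int32(img_features.shape[1] ** 0.5).
image_size, patch_size = 336, 14                 # assumed ViT-L/14-336-style backbone
num_patches = (image_size // patch_size) ** 2    # 24 * 24 = 576 patch tokens per crop
H = int(num_patches ** 0.5)                      # 24, side length of the square feature grid
assert H * H == num_patches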
82 changes: 69 additions & 13 deletions python/mlc_llm/model/phi3v/phi3v_model.py
@@ -6,13 +6,13 @@
import dataclasses
from typing import Any, Dict, Optional

from tvm import te, tir
from tvm import relax, te, tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op

from mlc_llm import op as op_ext
from mlc_llm.model.phi3 import Phi3Model
from mlc_llm.model.vision import CLIPVisionConfig
from mlc_llm.model.vision import CLIPVisionConfig, ImageProcessor
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.support import logging
from mlc_llm.support.config import ConfigBase
@@ -49,6 +49,7 @@ class Phi3VConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
num_key_value_heads: int
max_position_embeddings: int
vision_config: CLIPVisionConfig = None
img_processor: Optional[Dict[str, Any]] = None
position_embedding_base: int = 0
rope_scaling: Optional[Dict[str, Any]] = None
original_max_position_embeddings: int = 0
@@ -133,6 +134,7 @@ def __init__(self, config: Phi3VConfig) -> None:
self.model = Phi3Model(config)
self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False)
self.vision_embed_tokens = Phi3ImageEmbedding(config)
self.image_processor = ImageProcessor()
self.num_hidden_layers = config.num_hidden_layers
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
@@ -215,9 +217,72 @@ def embed(self, input_ids: Tensor):
embeds = self.model.embd(input_ids)
return embeds

# pylint: disable=protected-access
def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
pixel_values = op.permute_dims(pixel_values, axes=(0, 2, 3, 1)) # NCHW -> NHWC
pixel_values = self.image_processor.resize(pixel_values, params={"hd_transform": 336})
new_h = tir.Var("new_h", "int64")
new_w = tir.Var("new_w", "int64")
pixel_values = op.wrap_nested(
relax.BlockBuilder()
.current()
.match_cast(
pixel_values._expr,
relax.TensorStructInfo(
[pixel_values.shape[0], new_h, new_w, pixel_values.shape[3]], pixel_values.dtype
),
),
"pixel_values",
)

pixel_values = self.image_processor.pad(pixel_values)
pixel_values = self.image_processor.rescale(pixel_values)
pixel_values = self.image_processor.normalize(pixel_values)
global_image = self.image_processor.resize(
pixel_values, params={"height": 336, "width": 336}
)
global_image = op.wrap_nested(
relax.BlockBuilder()
.current()
.match_cast(
global_image._expr,
relax.TensorStructInfo(
[global_image.shape[0], 336, 336, global_image.shape[3]], global_image.dtype
),
),
"global_image",
)

global_image = op.permute_dims(global_image, axes=(0, 3, 1, 2))
n, h, w, c = pixel_values.shape # pylint: disable=unused-variable
pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
pixel_values = op.reshape(pixel_values, shape=(1, 3, h // 336, 336, w // 336, 336))
pixel_values = op.permute_dims(pixel_values, axes=(0, 2, 4, 1, 3, 5))
pixel_values = op.reshape(pixel_values, shape=(-1, 3, 336, 336))
combined_image = op.concat([pixel_values, global_image], dim=0)

# pad to max num crops tensor
b, c, h, w = combined_image.shape
zeros = op.zeros((num_crops + 1 - b, c, h, w))
combined_image = op.concat([combined_image, zeros], dim=0)

combined_image = op.wrap_nested(
relax.BlockBuilder()
.current()
.match_cast(
combined_image._expr,
relax.TensorStructInfo([num_crops + 1, c, h, w], combined_image.dtype),
),
"combined_image",
)

return combined_image

def image_embed(self, pixel_values: Tensor) -> Tensor:
n, c, h, w = pixel_values.shape # pylint: disable=unused-variable
pixel_values = self.image_preprocess(pixel_values)
pixel_values = pixel_values.astype(self.dtype)
return self.vision_embed_tokens(pixel_values)
return self.vision_embed_tokens(pixel_values, h, w)

def create_paged_kv_cache( # pylint: disable=too-many-arguments
self,
@@ -255,16 +320,7 @@ def get_default_spec(self):
},
},
"image_embed": {
"pixel_values": nn.spec.Tensor(
[
1,
17,
3,
self.config.vision_config.image_size,
self.config.vision_config.image_size,
],
"float32",
),
"pixel_values": nn.spec.Tensor([1, 3, "image_height", "image_width"], "uint8"),
"$": {
"param_mode": "packed",
"effect_mode": "none",
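The new Phi-3-Vision `image_preprocess` above performs the HD transform in-graph: the raw image is resized and padded so both sides are multiples of 336, split into 336x336 crops, concatenated with a 336x336 global view, and zero-padded to a fixed `num_crops + 1` batch so the output has a static leading dimension. Below is a shape-only NumPy mirror of the crop-splitting and padding steps, using an example 3x4 tile layout; the tile counts and zero-filled tensors are illustrative.

import numpy as np

num_crops = 16
x = np.zeros((1, 3, 336 * 3, 336 * 4), dtype=np.float32)  # padded HD image: 3 x 4 tiles of 336
n, c, h, w = x.shape
tiles = x.reshape(1, 3, h // 336, 336, w // 336, 336)
tiles = tiles.transpose(0, 2, 4, 1, 3, 5).reshape(-1, 3, 336, 336)   # (12, 3, 336, 336)
global_view = np.zeros((1, 3, 336, 336), dtype=np.float32)           # whole image resized to 336x336
combined = np.concatenate([tiles, global_view], axis=0)              # (13, 3, 336, 336)
pad = np.zeros((num_crops + 1 - combined.shape[0], 3, 336, 336), dtype=np.float32)
combined = np.concatenate([combined, pad], axis=0)                   # fixed (17, 3, 336, 336)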
1 change: 1 addition & 0 deletions python/mlc_llm/model/vision/__init__.py
@@ -1,3 +1,4 @@
"""Common `nn.Modules` used to define LLMs in this project."""

from .clip_vision import CLIPVisionConfig, CLIPVisionModel
from .image_processing import ImageProcessor
5 changes: 2 additions & 3 deletions python/mlc_llm/model/vision/clip_vision.py
@@ -11,6 +11,7 @@
from tvm.relax.frontend.nn import Module, Tensor
from tvm.relax.frontend.nn.modules import Conv2D
from tvm.relax.frontend.nn.op import (
add,
broadcast_to,
concat,
permute_dims,
@@ -45,8 +46,6 @@ class CLIPVisionConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes


# pylint: disable=invalid-name,missing-docstring


class CLIPVisionEmbeddings(Module): # pylint: disable=too-many-instance-attributes
def __init__(self, config: CLIPVisionConfig):
super().__init__()
@@ -86,7 +85,7 @@ def forward(self, pixel_values: Tensor) -> Tensor:
self.position_embedding(posi_ids),
shape=(batch_size, self.num_positions, self.embed_dim),
)
embeddings = embeddings + batch_position_embedding
embeddings = add(embeddings, batch_position_embedding)
return embeddings


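In `CLIPVisionEmbeddings.forward` above, the learned position embedding is broadcast to `(batch_size, num_positions, embed_dim)` and summed with the patch embeddings; the change simply switches from the `+` operator to the explicit `add` op. A shape-level NumPy sketch of what that addition computes (the 336/14 geometry and 1024-dim width are illustrative assumptions):

import numpy as np

batch_size, embed_dim = 1, 1024                 # assumed ViT-L width
num_positions = (336 // 14) ** 2 + 1            # 576 patch tokens + 1 class token = 577
patch_embeds = np.zeros((batch_size, num_positions, embed_dim), dtype=np.float32)
pos_embed = np.zeros((num_positions, embed_dim), dtype=np.float32)
batch_pos = np.broadcast_to(pos_embed, (batch_size, num_positions, embed_dim))
embeddings = patch_embeds + batch_pos           # what add(embeddings, batch_position_embedding) computes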
(Diffs for the remaining 3 changed files are not shown.)
