[Vision][Fix] Enable image processing kernel on non-CUDA backends (mlc-ai#2923)

Prior to this PR, when compiling/running phi3.5-vision on a non-CUDA backend such as Metal, we would run into the following issues:

- Shape-inference values would exceed the int32 range (CUDA does not hit this, since we use int64 there), leading to a runtime error (a short int32-overflow sketch follows this list):
```
TVMError: Assert fail: (T.Div(new_h - 2147483185, 336) - -6391320) * 336 == T.Cast("int32", resize2d1_var_lv4_shape[1]), Argument resize2d1.var_lv4.shape[1] has an unsatisfied constraint: new_h + T.Div((new_h + 336 - 1) // 336 * 336 - new_h, 2) + ((new_h + 336 - 1) // 336 * 336 - new_h - T.Div((new_h + 336 - 1) // 336 * 336 - new_h, 2)) == T.Cast("int32", resize2d1_var_lv4_shape[1])
```   
- If we naively keep int64 on Metal, we run into:
  - `TVMError: Check failed: blockSize <= maxTotalThreadsPerThreadgroup (1024 vs. 896) :`
  - This happens because using more registers reduces the number of threads available per block (to 896 here).
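For context, here is a minimal NumPy sketch (not part of the PR) of the silent int32 wraparound behind the first bullet; any intermediate shape value past 2**31 - 1 simply wraps, which is presumably why the folded constraint above contains constants close to INT32_MAX:

```python
import numpy as np

# Illustrative only: int32 arithmetic wraps silently once a value exceeds 2**31 - 1,
# so overgrown intermediate shape expressions turn into nonsense constraints.
x = np.array([2**31 - 1], dtype=np.int32)
print(x + 1)  # [-2147483648]: silent wraparound, no error raised
```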

This PR fixes the issues above.

In addition, we rename `std` to `stddev` to avoid reserved-name issues on backends such as WGSL.

Tested on Metal with:
```
python python/mlc_llm/testing/debug_chat.py "List the objects you can identify in this image succinctly." --generate-len 256 --model dist/phi-3_5-vision-q4f16_1 --model-lib dist/libs/phi-3_5-vision-q4f16_1-metal.so --debug-dir debug/ --image-url https://www.islandvulnerability.org/borders/ai8699.jpg --disable-instrument
```

---------

Co-authored-by: Ruihang Lai <[email protected]>
CharlieFRuan and MasterJH5574 authored Sep 19, 2024
1 parent 1828f95 commit 763a677
Showing 2 changed files with 14 additions and 14 deletions.
python/mlc_llm/model/phi3v/phi3v_model.py (3 changes: 2 additions & 1 deletion)
```diff
@@ -255,8 +255,9 @@ def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
 
         global_image = op.permute_dims(global_image, axes=(0, 3, 1, 2))
         n, h, w, c = pixel_values.shape  # pylint: disable=unused-variable
+        assert isinstance(h, tir.Mul) and isinstance(h.b, tir.IntImm) and h.b.value == 336
         pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2))  # NHWC -> NCHW
-        pixel_values = op.reshape(pixel_values, shape=(1, 3, h // 336, 336, w // 336, 336))
+        pixel_values = op.reshape(pixel_values, shape=(1, 3, h.a, 336, w // 336, 336))
         pixel_values = op.permute_dims(pixel_values, axes=(0, 2, 4, 1, 3, 5))
         pixel_values = op.reshape(pixel_values, shape=(-1, 3, 336, 336))
         combined_image = op.concat([pixel_values, global_image], dim=0)
```
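For readers puzzled by `h.a`: the added assert pins down that `h` is a symbolic product whose constant factor is 336, so the crop count can be read directly off the `tir.Mul` node instead of re-deriving it with a symbolic division. A minimal sketch of that structural assumption (the name `num_h_crops` is illustrative, not from the model code):

```python
from tvm import tir

# Illustrative only: `h` is assumed to arrive as `<crop count> * 336`, i.e. a
# tir.Mul node whose right operand is the constant 336. `h.a` is then the crop
# count itself, with no extra symbolic floordiv for shape inference to fold.
num_h_crops = tir.Var("num_h_crops", "int64")
h = num_h_crops * 336  # tir.Mul(a=num_h_crops, b=IntImm("int64", 336))

assert isinstance(h, tir.Mul) and isinstance(h.b, tir.IntImm) and h.b.value == 336
print(h.a)  # num_h_crops
```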
python/mlc_llm/model/vision/image_processing.py (25 changes: 12 additions & 13 deletions)
```diff
@@ -153,25 +153,25 @@ def normalize_func(image: T.handle, out: T.handle):
         image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
         out_buf = T.match_buffer(out, (n, h, w, c), dtype=o_dtype)
         mean = _var(o_dtype)
-        std = _var(o_dtype)
+        stddev = _var(o_dtype)
         for n_idx in T.thread_binding(n, thread="blockIdx.x"):
             for h_idx, w_idx, c_idx in T.grid(h, w, c):
                 with T.block("compute"):
                     T.reads(image_buf[n_idx, h_idx, w_idx, c_idx])
                     T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
                     if 0 == c_idx:
                         mean[0] = 0.48145466
-                        std[0] = 0.26862954
+                        stddev[0] = 0.26862954
                     elif 1 == c_idx:
                         mean[0] = 0.4578275
-                        std[0] = 0.26130258
+                        stddev[0] = 0.26130258
                     elif 2 == c_idx:
                         mean[0] = 0.40821073
-                        std[0] = 0.27577711
+                        stddev[0] = 0.27577711
 
                     out_buf[n_idx, h_idx, w_idx, c_idx] = (
                         T.cast(image_buf[n_idx, h_idx, w_idx, c_idx], o_dtype) - mean[0]
-                    ) / std[0]
+                    ) / stddev[0]
 
     return normalize_func
```
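The kernel itself is plain per-channel CLIP normalization; a NumPy equivalent for intuition (illustration only, not part of the PR):

```python
import numpy as np

# NumPy equivalent of the TIR normalize kernel above (illustration only):
# NHWC input, per-channel (x - mean) / stddev with the CLIP image statistics.
mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
stddev = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

image = np.random.rand(1, 336, 336, 3).astype(np.float32)  # NHWC, values in [0, 1]
normalized = (image - mean) / stddev  # broadcasts over the channel axis
```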

```diff
@@ -206,19 +206,18 @@ def pad_func(image: T.handle, out: T.handle, t: T.int64(), b: T.int64()):
 
         return pad_func
 
-    def cal_pad_num(image):
-        h = image.shape[1]
-        tar = tir.generic.cast(tir.ceildiv(h, 336) * 336, "int64")
-        t = tir.generic.cast(tir.div(tar - h, 2), "int64")
-        b = tar - h - t
-        return 0, t, 0, b
+    h = image.shape[1]
+    tar = tir.truncdiv(h + 335, 336) * 336
+    t = tir.div(tar - h, 2)
+    b = tar - h - t
+    l = 0
+    r = 0
 
     n, h, w, c = image.shape
-    l, t, r, b = cal_pad_num(image)
     out = op.tensor_ir_op(
         create_pad_func(l, r),
         "pad",
         [image, t, b],
-        [Tensor.placeholder((n, h + t + b, w + l + r, c), image.dtype)],
+        [Tensor.placeholder((n, tar, w, c), image.dtype)],
     )
     return out
```
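The inlined padding arithmetic rounds the height up to the next multiple of 336 and splits the slack between top and bottom; a plain-Python check of the same formula with illustrative heights:

```python
# Plain-Python version of the padding arithmetic above (illustration only).
def pad_amounts(h: int, tile: int = 336):
    tar = (h + tile - 1) // tile * tile  # round h up to the next multiple of 336
    t = (tar - h) // 2                   # top padding
    b = tar - h - t                      # bottom padding absorbs any odd pixel
    return tar, t, b

print(pad_amounts(700))  # (1008, 154, 154)
print(pad_amounts(673))  # (1008, 167, 168)
```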
