
Commit

add objectstitch
bo-zhang-cs committed Apr 11, 2024
1 parent c1b7f4a commit a42cddf
Showing 90 changed files with 8,274 additions and 69 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -38,7 +38,8 @@
- **HarmonyScoreModel** evaluates the harmony level between foreground and background in a composite image.
- **InharmoniousLocalizationModel** localizes the inharmonious region in a synthetic image.
- **FOSScoreModel** evaluates the compatibility between foreground and background in a composite image in terms of geometry and semantics.
- **ControlComModel** is a generative image composition model, which unifies image blending and image harmonization in one diffusion model.
- **ControlComModel** is a generative image composition model, which unifies image blending, image harmonization, view synthesis, and generative composition within a diffusion model.
- **ObjectStitchModel** is another generative image composition model, which generates a composite image from a pair of background and foreground, where the foreground's non-object pixels are filled with black.
- **ShadowGenerationModel** generates plausible shadow for the inserted object in a composite image.

**For the detailed method descriptions, code examples, visualization results, and performance comments, please refer to our [[documents]](https://libcom.readthedocs.io/en/latest/).**
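
A minimal sketch of how the newly added ObjectStitchModel might be invoked, modeled on the ControlComModel example further down in this commit; the constructor, the call signature, and the test-data paths are assumptions rather than the confirmed API:

```python
# Hypothetical usage sketch of the new ObjectStitchModel, mirroring the
# ControlComModel example in this commit; the exact signature is assumed.
import cv2
from libcom import ObjectStitchModel

bg_img  = 'tests/source/background/1.jpg'         # background image (hypothetical path)
fg_img  = 'tests/source/foreground/1.jpg'         # foreground with non-object pixels filled with black
fg_mask = 'tests/source/foreground_mask/1.png'    # foreground object mask (hypothetical path)
bbox    = [175, 82, 309, 310]                     # placement box on the background

net  = ObjectStitchModel(device=0)                # assumed to follow the ControlComModel constructor
comp = net(bg_img, fg_img, bbox, fg_mask)         # assumed call signature; returns generated composite(s)
cv2.imwrite('objectstitch_result.jpg', comp[0])
```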
Binary file added docs/_static/image/objectstitch_result1.jpg
Binary file added docs/_static/image/objectstitch_result2.jpg
3 changes: 2 additions & 1 deletion libcom/__init__.py
@@ -9,11 +9,12 @@
from .fos_score import *
from .controllable_composition import *
from .shadow_generation import *
from .objectstitch import *

__all__ = [
'color_transfer', 'get_composite_image', 'OPAScoreModel',
'HarmonyScoreModel', 'InharmoniousLocalizationModel',
'ImageHarmonizationModel', 'PainterlyHarmonizationModel',
'FOPAHeatMapModel', 'FOSScoreModel', 'ControlComModel',
'ShadowGenerationModel'
'ShadowGenerationModel', 'ObjectStitchModel'
]
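
With the added star import and the expanded `__all__`, the new model is exposed at the package root alongside the existing ones:

```python
# After this change, ObjectStitchModel can be imported directly from libcom.
from libcom import ObjectStitchModel, ControlComModel, ShadowGenerationModel
```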
14 changes: 9 additions & 5 deletions libcom/controllable_composition/controllable_composition.py
@@ -30,16 +30,16 @@

cur_dir = os.path.dirname(os.path.abspath(__file__))
model_dir = os.environ.get('LIBCOM_MODEL_DIR',cur_dir)
model_set = ['ControlCom']
task_set = ['blending', 'harmonization'] # 'viewsynthesis', 'composition'
model_set = ['ControlCom', 'ControlCom_blend_harm', 'ControlCom_view_comp']
task_set = ['blending', 'harmonization', 'viewsynthesis', 'composition']

class ControlComModel:
"""
Controllable composition model.
Args:
device (str | torch.device): gpu id
model_type (str): predefined model type
model_type (str): predefined model type. "ControlCom" is the full version trained jointly on all four tasks. "ControlCom_blend_harm" fine-tunes the full version for image blending and harmonization. "ControlCom_view_comp" fine-tunes the full version for view synthesis and generative composition.
kwargs (dict): sampler='ddim' (default) or 'plms', other parameters for building model
Examples:
@@ -54,7 +54,7 @@ class ControlComModel:
>>> fg_img = test_dir + 'foreground/' + img_names[i]
>>> bbox = bboxes[i]
>>> mask = test_dir + 'foreground_mask/' + img_names[i]
>>> net = ControlComModel(device=0)
>>> net = ControlComModel(device=0, model_type="ControlCom")
>>> comp = net(bg_img, fg_img, bbox, mask, task=['blending', 'harmonization'])
>>> bg_img = draw_bbox_on_image(bg_img, bbox)
>>> grid_img = make_image_grid([bg_img, fg_img, comp[0], comp[1]])
@@ -74,7 +74,7 @@ def __init__(self, device=0, model_type='ControlCom', **kwargs):
self.model_type = model_type
self.option = kwargs

weight_path = os.path.join(model_dir, 'pretrained_models', 'ControlCom.pth')
weight_path = os.path.join(cur_dir, 'pretrained_models', f'{self.model_type}.pth')
download_pretrained_model(weight_path)

self.device = check_gpu_device(device)
@@ -182,6 +182,10 @@ def task_to_indicator(self, task):
indicator.append([0,0])
elif t == 'harmonization':
indicator.append([1,0])
elif t == 'viewsynthesis':
indicator.append([0,1])
else:
indicator.append([1,1])
return indicator
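
Read as a two-bit flag, the indicator tells the diffusion model which foreground attributes to adjust for each task (in the ControlCom design, the first bit roughly governs appearance/illumination and the second pose/view). A self-contained sketch of the mapping just shown:

```python
# Standalone sketch of the task-to-indicator mapping above; unrecognized
# task names fall through to [1, 1] (full generative composition).
def task_to_indicator(task):
    indicator = []
    for t in task:
        if t == 'blending':
            indicator.append([0, 0])       # keep appearance and pose
        elif t == 'harmonization':
            indicator.append([1, 0])       # adjust appearance only
        elif t == 'viewsynthesis':
            indicator.append([0, 1])       # adjust pose/view only
        else:                              # 'composition'
            indicator.append([1, 1])       # adjust both
    return indicator

assert task_to_indicator(['blending', 'composition']) == [[0, 0], [1, 1]]
```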

@torch.no_grad()
@@ -1,6 +1,6 @@
model:
base_learning_rate: 1.0e-05
target: ldm.models.diffusion.ddpm.LatentDiffusion
target: libcom.controllable_composition.source.ControlCom.ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.012
@@ -41,7 +41,7 @@ model:
use_guidance: true
local_uncond: same
scheduler_config:
target: ldm.lr_scheduler.LambdaLinearScheduler
target: libcom.controllable_composition.source.ControlCom.ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps:
- 1000
@@ -54,7 +54,7 @@ model:
f_min:
- 1.0
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
target: libcom.controllable_composition.source.ControlCom.ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 11
@@ -78,7 +78,7 @@ model:
legacy: false
add_conv_in_front_of_unet: false
local_encoder_config:
conditioning_key: ldm.modules.local_module.LocalRefineBlock
conditioning_key: libcom.controllable_composition.source.ControlCom.ldm.modules.local_module.LocalRefineBlock
add_position_emb: false
roi_size: 16
context_dim: 1024
@@ -89,7 +89,7 @@ model:
add_in_decoder: true
add_before_crossattn: false
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
target: libcom.controllable_composition.source.ControlCom.ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
@@ -111,7 +111,7 @@ model:
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
target: libcom.controllable_composition.source.ControlCom.ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
params:
version: openai-clip-vit-large-patch14
local_hidden_index: 12
@@ -17,8 +17,8 @@
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
from libcom.controllable_composition.source.ControlCom.ldm.data.base import Txt2ImgIterableBaseDataset
from libcom.controllable_composition.source.ControlCom.ldm.util import instantiate_from_config
import socket
from pytorch_lightning.plugins.environments import ClusterEnvironment,SLURMEnvironment

@@ -10,8 +10,8 @@
from glob import glob
from natsort import natsorted

from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
from libcom.controllable_composition.source.ControlCom.ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from libcom.controllable_composition.source.ControlCom.ldm.util import log_txt_as_img, default, ismap, instantiate_from_config

__models__ = {
'class_label': EncoderUNetModel,
@@ -80,7 +80,6 @@ def __init__(self,
super().__init__()
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
# print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
@@ -1636,7 +1635,7 @@ def log_local(save_dir,

if __name__ == '__main__':
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from libcom.controllable_composition.source.ControlCom.ldm.util import instantiate_from_config
import os, torchvision
from PIL import Image
import shutil
@@ -8,7 +8,7 @@
import os, sys
proj_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, proj_dir)
from ldm.modules.diffusionmodules.util import checkpoint
from libcom.controllable_composition.source.ControlCom.ldm.modules.diffusionmodules.util import checkpoint
from torchvision.ops import roi_align

def exists(val):
@@ -11,7 +11,7 @@
import os, sys
proj_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.insert(0, proj_dir)
from ldm.modules.diffusionmodules.util import (
from libcom.controllable_composition.source.ControlCom.ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
linear,
@@ -1141,7 +1141,7 @@ def get_intermediate_features(self, x_bbox, timesteps=None, context=None, y=None
import torch
device = torch.device("cuda:0")
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from libcom.controllable_composition.source.ControlCom.ldm.util import instantiate_from_config
cfg_path = os.path.join(proj_dir, 'configs/finetune_paint.yaml')
configs = OmegaConf.load(cfg_path).model.params
model = instantiate_from_config(configs.unet_config).to(device)
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
from functools import partial
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel,CLIPVisionModel,CLIPModel,CLIPVisionModelWithProjection
from transformers import CLIPVisionModel
import os,sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from ldm.modules.encoders.xf import LayerNorm, Transformer
from libcom.controllable_composition.source.ControlCom.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from libcom.controllable_composition.source.ControlCom.ldm.modules.encoders.xf import LayerNorm, Transformer
import math
import torch.nn.functional as F

@@ -217,7 +216,7 @@ def encode(self, image):


if __name__ == "__main__":
from ldm.util import count_params
from libcom.controllable_composition.source.ControlCom.ldm.util import count_params
device = torch.device("cuda:0")
model = FrozenCLIPImageEmbedder().to(device)
count_params(model, verbose=True)
@@ -6,8 +6,8 @@
from torch import nn, einsum
from einops import rearrange, repeat
import os, sys
from ldm.modules.diffusionmodules.util import checkpoint
from ldm.modules.attention import CrossAttention, zero_module, Normalize
from libcom.controllable_composition.source.ControlCom.ldm.modules.diffusionmodules.util import checkpoint
from libcom.controllable_composition.source.ControlCom.ldm.modules.attention import CrossAttention, zero_module, Normalize
from torchvision.ops import roi_align

class FDN(nn.Module):
@@ -1 +1 @@
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
from libcom.controllable_composition.source.ControlCom.ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
2 changes: 0 additions & 2 deletions libcom/controllable_composition/source/ControlCom/ldm/util.py
@@ -90,8 +90,6 @@ def get_obj_from_str(string, reload=False):
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
if 'ldm' in module:
module = '.source.ControlCom.' + module
return getattr(importlib.import_module(module, package='libcom.controllable_composition'), cls)
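
Because the YAML configs above now carry fully qualified `target` paths, `get_obj_from_str` no longer needs to rewrite the `ldm` prefix. A minimal sketch of how such a target resolves, assuming the standard latent-diffusion helpers and that the package's dependencies are installed:

```python
# Sketch of resolving a fully qualified config target; assumes the standard
# latent-diffusion helpers exported by this util module.
from omegaconf import OmegaConf
from libcom.controllable_composition.source.ControlCom.ldm.util import (
    get_obj_from_str, instantiate_from_config)

# The first_stage_config target from the YAML diff above resolves directly,
# without any prefix rewriting:
AutoencoderKL = get_obj_from_str(
    'libcom.controllable_composition.source.ControlCom.ldm.models.autoencoder.AutoencoderKL')

# Targets outside the package (e.g. the lossconfig entry) resolve the same way:
identity = instantiate_from_config(OmegaConf.create({'target': 'torch.nn.Identity'}))
```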


4 changes: 4 additions & 0 deletions libcom/objectstitch/__init__.py
@@ -0,0 +1,4 @@
# change to your lib name
from .objectstitch import ObjectStitchModel

__all__ = ['ObjectStitchModel']