diff --git a/libcom/objectstitch/objectstitch.py b/libcom/objectstitch/objectstitch.py
index b767f57..efa6846 100644
--- a/libcom/objectstitch/objectstitch.py
+++ b/libcom/objectstitch/objectstitch.py
@@ -3,7 +3,7 @@
 from libcom.utils.model_download import download_pretrained_model, download_entire_folder
 from libcom.utils.process_image import *
 from libcom.utils.environment import *
-import torch 
+import torch
 import os
 import torchvision.transforms as transforms
 import torch.nn.functional as F
@@ -32,24 +32,24 @@
 cur_dir = os.path.dirname(os.path.abspath(__file__))
 model_dir = os.environ.get('LIBCOM_MODEL_DIR',cur_dir)
-model_set = ['ObjectStitch'] 
+model_set = ['ObjectStitch']
 
 
 class Mure_ObjectStitchModel:
     """
     Unofficial implementation of the paper "ObjectStitch: Object Compositing with Diffusion Model", CVPR 2023.
     Building upon ObjectStitch, we have made improvements to support input of multiple foreground images.
-    
+
     Args:
         device (str | torch.device): gpu id
         model_type (str): predefined model type.
         kwargs (dict): sampler='ddim' (default) or 'plms', other parameters for building model
-    
+
     Examples:
-        >>> from libcom import MureObjectStitchModel
+        >>> from libcom import Mure_ObjectStitchModel
         >>> from libcom.utils.process_image import make_image_grid, draw_bbox_on_image
         >>> import cv2
         >>> import os
-        >>> net = MureObjectStitchModel(device=0, sampler='plms')
+        >>> net = Mure_ObjectStitchModel(device=0, sampler='plms')
         >>> sample_list = ['000000000003', '000000000004']
         >>> sample_dir = './tests/mure_objectstitch/'
         >>> bbox_list = [[623, 1297, 1159, 1564], [363, 205, 476, 276]]
@@ -71,17 +71,17 @@ class Mure_ObjectStitchModel:
         :scale: 21 %
     .. image:: _static/image/mureobjectstitch_result2.jpg
         :scale: 21 %
-    
-    
+
+
     """
     def __init__(self, device=0, model_type='ObjectStitch', **kwargs):
         assert model_type in model_set, f'Not implementation for {model_type}'
         self.model_type = model_type
         self.option = kwargs
-        
+
         weight_path = os.path.join(cur_dir, 'pretrained_models', f'{self.model_type}.pth')
         download_pretrained_model(weight_path)
-        
+
         self.device = check_gpu_device(device)
         self.build_pretrained_model(weight_path)
         self.build_data_transformer()
@@ -100,7 +100,7 @@ def build_pretrained_model(self, weight_path):
             self.sampler = PLMSSampler(self.model)
         else:
             self.sampler = DDIMSampler(self.model)
-    
+
     def build_data_transformer(self):
         self.image_size = (512, 512)
         self.clip_transform = get_tensor_clip(image_size=(224, 224))
@@ -153,11 +153,11 @@ def draw_compose_fg_img(self, fg_img_compose):
 
         if fg_img_nums>5:
             fg_img_compose = fg_img_compose[:5]
-        
+
         for idx, img in enumerate(fg_img_compose):
             fg_img = img.resize(size)
             final_img.paste(fg_img, positions[idx])
-        
+
         return final_img
 
     def rescale_image_with_bbox(self, image, bbox=None, long_size=1024):
@@ -201,7 +201,7 @@ def generate_multifg(self, fg_list_path, fgmask_list_path):
         for fg_mask_name in fgmask_list_path:
             fg_mask = Image.open(fg_mask_name).convert('RGB')
             fg_mask_list.append(fg_mask)
-        
+
         for idx, fg_mask in enumerate(fg_mask_list):
             fg_mask = fg_mask.convert('L')
             mask = np.asarray(fg_mask)
@@ -248,7 +248,7 @@ def generate_image_batch(self, bg_path, fg_list_path, fgmask_list_path, bbox):
                 "fg_img": fg_img,
                 "fg_img_list": fg_img_list,
                 "bbox": bbox_t.unsqueeze(0)}
-    
+
     def prepare_input(self, batch, shape, num_samples):
         if num_samples > 1:
             for k in batch.keys():
@@ -275,7 +275,7 @@ def prepare_input(self, batch, shape, num_samples):
             c = torch.cat([c] * num_samples, dim=0)
         uc = self.model.learnable_vector.repeat(c.shape[0], c.shape[1], 1) # 1,1,768
         return test_model_kwargs, c, uc
-    
+
     def inputs_preprocess(self, background_image, fg_list_path, fgmask_list_path, bbox, num_samples):
         batch = self.generate_image_batch(background_image, fg_list_path, fgmask_list_path, bbox)
 
@@ -283,17 +283,17 @@ def inputs_preprocess(self, background_image, fg_list_path, fgmask_list_path, bb
         show_fg_img = batch["fg_img"]
 
         return test_kwargs, c, uc, show_fg_img
-    
-    
+
+
     def outputs_postprocess(self, outputs):
         x_samples_ddim = self.model.decode_first_stage(outputs[:,:4]).cpu().float()
         comp_img = tensor2numpy(x_samples_ddim, image_size=self.image_size)
         if len(comp_img) == 1:
             return comp_img[0]
         return comp_img
-    
+
     @torch.no_grad()
-    def __call__(self, background_image, foreground_image, foreground_mask, bbox, 
+    def __call__(self, background_image, foreground_image, foreground_mask, bbox,
                  num_samples=1, sample_steps=50, guidance_scale=5, seed=321):
         """
         Controllable image composition based on diffusion model.
@@ -306,15 +306,15 @@ def __call__(self, background_image, foreground_image, foreground_mask, bbox,
            num_samples (int): Number of images to be generated. default: 1.
            sample_steps (int): Number of denoising steps. The recommended setting is 25 for PLMS sampler and 50 for DDIM sampler. default: 50.
            guidance_scale (int): Scale in classifier-free guidance (minimum: 1; maximum: 20). default: 5.
-            seed (int): Random Seed is used to reproduce results and same seed will lead to same results. 
+            seed (int): Random Seed is used to reproduce results and same seed will lead to same results.
 
         Returns:
-            composite_images (numpy.ndarray): Generated images with a shape of 512x512x3 or Nx512x512x3, where N indicates the number of generated images. 
+            composite_images (numpy.ndarray): Generated images with a shape of 512x512x3 or Nx512x512x3, where N indicates the number of generated images.
 
         """
         seed_everything(seed)
-        test_kwargs, c, uc, show_fg_img = self.inputs_preprocess(background_image, foreground_image, 
+        test_kwargs, c, uc, show_fg_img = self.inputs_preprocess(background_image, foreground_image,
                                                                  foreground_mask, bbox, num_samples)
 
         start_code = torch.randn([num_samples]+self.latent_shape, device=self.device)
@@ -329,4 +329,4 @@ def __call__(self, background_image, foreground_image, foreground_mask, bbox,
                                          unconditional_conditioning=uc,
                                          test_model_kwargs=test_kwargs)
         comp_img = self.outputs_postprocess(outputs)
-        return comp_img, show_fg_img
\ No newline at end of file
+        return comp_img, show_fg_img
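
For reference, a minimal usage sketch consistent with the corrected docstring example and the __call__ signature shown in this patch. The foreground/background file paths and the mask filenames below are illustrative assumptions only; they are not files referenced by this diff.

    # Sketch: compose one object into a background with Mure_ObjectStitchModel.
    # Paths below are placeholders (assumptions), not part of the patched code.
    import cv2
    from libcom import Mure_ObjectStitchModel

    net = Mure_ObjectStitchModel(device=0, sampler='plms')

    # Multiple foreground views of the same object and their binary masks (assumed paths).
    fg_list     = ['./fg/view1.jpg', './fg/view2.jpg']
    fgmask_list = ['./fg/view1_mask.png', './fg/view2_mask.png']
    bg_path     = './bg/background.jpg'
    bbox        = [623, 1297, 1159, 1564]   # target box in background coordinates, from the docstring example

    # Per the docstring, 25 steps are recommended for the PLMS sampler; the call returns
    # the generated composites (Nx512x512x3 for num_samples > 1) and a preview of the
    # composed foreground inputs.
    comp_imgs, show_fg_img = net(bg_path, fg_list, fgmask_list, bbox,
                                 num_samples=3, sample_steps=25, guidance_scale=5)
    for i, img in enumerate(comp_imgs):
        cv2.imwrite(f'composite_{i}.jpg', img)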