Skip to content

Commit

Permalink
bugfix (mindspore-lab#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
LoomisChen authored Apr 20, 2023
1 parent a3cb11a commit 1447389
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 9 deletions.
11 changes: 7 additions & 4 deletions vision/wukong-huahua/ldm/models/diffusion/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,10 @@ def construct(self, x, c):
def p_losses(self, x_start, cond, t, noise=None):
noise = msnp.randn(x_start.shape)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)

model_output = self.apply_model(x_noisy, t,
c_concat=cond if self.model.conditioning_key == 'concat' else None,
c_crossattn=cond if self.model.conditioning_key == 'crossattn' else None)

if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
Expand Down Expand Up @@ -484,7 +486,9 @@ def construct(self, train_x, train_c, reg_x, reg_c):
def p_losses(self, x_start, cond, t, noise=None):
noise = msnp.randn(x_start.shape)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
model_output = self.apply_model(x_noisy, t,
c_concat=cond if self.model.conditioning_key == 'concat' else None,
c_crossattn=cond if self.model.conditioning_key == 'crossattn' else None)

if self.parameterization == "x0":
target = x_start
Expand Down Expand Up @@ -518,4 +522,3 @@ def __init__(
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys

Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ def sample(self,
img = x_T

ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
lambda x, t, c: self.model.apply_model(x, t,
c_concat=c if self.model.model.conditioning_key == 'concat' else None,
c_crossattn=c if self.model.model.conditioning_key == 'crossattn' else None),
ns,
model_type="noise",
guidance_type="classifier-free",
Expand Down
2 changes: 1 addition & 1 deletion vision/wukong-huahua/ldm/models/diffusion/plms.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F

def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
e_t = self.model.apply_model(x, t, c_crossattn=c)
else:
x_in = ops.concat((x, x), axis=0)
t_in = ops.concat((t, t), axis=0)
Expand Down
4 changes: 2 additions & 2 deletions vision/wukong-huahua/scripts/run_db_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ export ASCEND_GLOBAL_LOG_LEVEL=3
export ASCEND_SLOG_PRINT_TO_STDOUT=0
device_id=3

output_path=/mnt/liuqiyuan/output
output_path=output
task_name=α猫
train_data_path=dataset/train_cat
reg_data_path=dataset/reg_cat
class_word=猫
pretrained_model_path=/mnt/liuqiyuan/wkhh_models
pretrained_model_path=models
train_config_file=configs/train_db_config.json
token=α

Expand Down

0 comments on commit 1447389

Please sign in to comment.