Skip to content

Commit

Permalink
modelをddpmに固定
Browse files Browse the repository at this point in the history
  • Loading branch information
kenta-tsukaue committed Jan 8, 2023
1 parent 4a7a9d4 commit f7e2f8c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
1 change: 0 additions & 1 deletion models/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
default_initializer = layers.default_init


@utils.register_model(name='ddpm')
class DDPM(nn.Module):
def __init__(self, config):
super().__init__()
Expand Down
6 changes: 4 additions & 2 deletions models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
import sde_lib
from .ddpm import DDPM
import numpy as np


Expand Down Expand Up @@ -87,8 +88,9 @@ def get_ddpm_params(config):

def create_model(config):
"""Create the score model."""
model_name = config.model.name
score_model = get_model(model_name)(config)
print("モデルの名前は",config.model.name)
#score_model = get_model(model_name)(config)
score_model = DDPM(config)
score_model = score_model.to(config.device)
score_model = torch.nn.DataParallel(score_model)
return score_model
Expand Down

0 comments on commit f7e2f8c

Please sign in to comment.