Skip to content

Commit

Permalink
fix ckpt manager of nn and geonn
Browse files Browse the repository at this point in the history
  • Loading branch information
Zirui Zhang committed May 3, 2023
1 parent 1d898da commit 8379a9f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 32 deletions.
29 changes: 8 additions & 21 deletions Geonn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self,
self.input_dim = input_dim
self.num_hidden_layers = num_hidden_layers
self.num_neurons_per_layer = num_neurons_per_layer

# Input layer
self.input_layer = tf.keras.layers.Dense(self.num_neurons_per_layer, activation='tanh', input_shape=(self.input_dim,))

Expand All @@ -32,12 +32,7 @@ def __init__(self,

self.build(input_shape=(None,input_dim))



self.checkpoint = tf.train.Checkpoint(model=self)
self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint, directory='geockpt', max_to_keep=3)


def call(self, inputs):
x = self.input_layer(inputs)
for hidden_layer in self.hidden_layers:
Expand All @@ -52,22 +47,14 @@ def train(self, X, Pwmq, Pgmq, phiq, epochs=10000):
history = self.fit(X, {'Pwm': Pwmq, 'Pgm': Pgmq, 'phi': phiq}, epochs=epochs, batch_size=batch_size)
return history

def save_checkpoint(self):
self.checkpoint_manager.save()

def load_checkpoint(self, checkpoint_path=None):
if checkpoint_path:
self.checkpoint.restore(checkpoint_path)
else:
self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
# def save_checkpoint(self):
# self.checkpoint_manager.save()

def save_checkpoint(self, checkpoint_dir='geockpt'):
os.makedirs(checkpoint_dir, exist_ok=True)
self.save_weights(os.path.join(checkpoint_dir,'ckpt'))

def load_checkpoint(self, checkpoint_dir='geockpt'):
self.load_weights(os.path.join(checkpoint_dir,'ckpt'))

# def load_checkpoint(self, checkpoint_path=None):
# if checkpoint_path:
# self.checkpoint.restore(checkpoint_path)
# else:
# self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)

if __name__ == '__main__':

Expand Down
14 changes: 7 additions & 7 deletions glioma.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, opts) -> None:
# input is spatial coordiante, output Pwm, Pgm, phi
if opts['usegeo'] is True:
self.geomodel = Geonn(input_dim=self.xdim)
self.geomodel.manager = self.setup_ckpt(self.geomodel, ckptdir = 'geockpt', restore = self.opts['restore'])

# get init from dataset
if opts['initfromdata'] is True:
Expand Down Expand Up @@ -225,7 +226,7 @@ def pde(xr, nn, geomodel):
regularizer=reg)

# load model, also change self.param
self.manager = self.setup_ckpt(self.model, ckptdir = 'ckpt', restore = self.opts['restore'])
self.model.manager = self.setup_ckpt(self.model, ckptdir = 'ckpt', restore = self.opts['restore'])


# for x in self.param:
Expand All @@ -240,7 +241,6 @@ def pde(xr, nn, geomodel):
self.solver = PINNSolver(self.model, pde,
losses,
self.dataset,
manager = self.manager,
geomodel = self.geomodel,
options = opts)

Expand Down Expand Up @@ -271,17 +271,17 @@ def setup_ckpt(self, model, ckptdir = 'ckpt', restore = None):
# manager.latest_checkpoint is None if no ckpt found

if self.opts['restore'] is not None and (self.opts['restore']):
# restore from previous simulation
prev_manager = tf.train.CheckpointManager(checkpoint, directory=os.path.join(self.opts['restore'],ckptdir), max_to_keep=4)
# not None and not empty
# self.opts['restore'] is a number or a path
if isinstance(self.opts['restore'],int):
# restore check point in the same directory by integer, 0 = ckpt-1
ckptpath = manager.checkpoints[self.opts['restore']]
ckptpath = prev_manager.checkpoints[self.opts['restore']]
else:
# restore checkpoint by path
if "/ckpt" in self.opts['restore']:
ckptpath = os.path.join(self.opts['restore'])
else:
ckptpath = os.path.join(self.opts['restore'],ckptdir,'ckpt-2')
ckptpath = prev_manager.latest_checkpoint

checkpoint.restore(ckptpath)
print("Restored from {}".format(ckptpath))
else:
Expand Down
7 changes: 3 additions & 4 deletions pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,14 @@ class PINNSolver():
def __init__(self, model, pde,
losses,
dataset,
manager,
geomodel=None,
wr = None,
options=None):
self.model = model
self.geomodel = geomodel
self.pde = pde
self.dataset = dataset
self.manager = manager


self.losses = losses

Expand Down Expand Up @@ -537,10 +536,10 @@ def callback_train_end(self):
# also make prediction of xr at various time
self.earlystop.reset()
if self.options.get('saveckpt'):
save_path = self.manager.save()
save_path = self.model.manager.save()
print("Saved checkpoint for {} step {} {}".format(int(self.iter),self.current_optimizer, save_path))
if self.geomodel is not None:
save_path = self.geomodel.save_checkpoint()
save_path = self.geomodel.manager.save()

else:
print("checkpoint not saved")
Expand Down

0 comments on commit 8379a9f

Please sign in to comment.