Skip to content

Commit

Permalink
[Example] High-resolution Reconstruction of Cardiac Magnetic Resonanc…
Browse files Browse the repository at this point in the history
…e Imaging (PaddlePaddle#1004)

* merge code of upstream

* merge code of upstream

* merge code of upstream

* merge code of upstream

* merge code of upstream

* merge code of upstream

* Update jointContribution/HighResolution/ffd/engine.py

* Update jointContribution/HighResolution/README.md

Co-authored-by: lijialin03 <[email protected]>

---------

Co-authored-by: HydrogenSulfate <[email protected]>
Co-authored-by: lijialin03 <[email protected]>
  • Loading branch information
3 people authored Nov 11, 2024
1 parent 25f3896 commit db2bc09
Show file tree
Hide file tree
Showing 15 changed files with 2,380 additions and 0 deletions.
30 changes: 30 additions & 0 deletions jointContribution/HighResolution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
## Install deepali

``` sh
# get deepali(with paddle backend)
git clone https://github.com/PFCCLab/deepali.git
# add temporary path
export PYTHONPATH="/path_to_deepali/src/:$PYTHONPATH"
```

## Dataset

Download demo dataset:

``` sh
cd PaddleScience/jointContribution/HighResolution
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/patient001.zip
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/Hammersmith_myo2.zip
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/patient001.zip
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/Hammersmith_myo2.zip

# unzip
unzip patient001.zip -d data
unzip Hammersmith_myo2.zip
```

## Run

python main_ACDC.py
211 changes: 211 additions & 0 deletions jointContribution/HighResolution/ffd/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from __future__ import annotations

import math
import weakref
from collections import OrderedDict
from timeit import default_timer as timer
from typing import Any
from typing import Callable
from typing import Tuple

import paddle

from .losses import RegistrationLoss
from .losses import RegistrationResult
from .optim import slope_of_least_squares_fit

PROFILING = False


class RegistrationEngine:
"""Minimize registration loss until convergence."""

def __init__(
self,
model: paddle.nn.Layer,
loss: RegistrationLoss,
optimizer: paddle.optimizer.Optimizer,
max_steps: int = 500,
min_delta: float = 1e-06,
min_value: float = float("nan"),
max_history: int = 10,
):
"""Initialize registration loop."""
self.model = model
self.loss = loss
self.optimizer = optimizer
self.num_steps = 0
self.max_steps = max_steps
self.min_delta = min_delta
self.min_value = min_value
self.max_history = max(2, max_history)
self.loss_values = []
self._eval_hooks = OrderedDict()
self._step_hooks = OrderedDict()

@property
def loss_value(self) -> float:
if not self.loss_values:
return float("inf")
return self.loss_values[-1]

def step(self) -> float:
"""Perform one registration step.
Returns:
Loss value prior to taking gradient step.
"""
num_evals = 0

def closure() -> float:
self.optimizer.clear_grad()
t_start = timer()
result = self.loss.eval()
if PROFILING:
print(f"Forward pass in {timer() - t_start:.3f}s")
loss = result["loss"]
assert isinstance(loss, paddle.Tensor)
t_start = timer()
loss.backward()
if PROFILING:
print(f"Backward pass in {timer() - t_start:.3f}s")
nonlocal num_evals
num_evals += 1
with paddle.no_grad():
for hook in self._eval_hooks.values():
hook(self, self.num_steps, num_evals, result)
return float(loss)

loss_value = closure()
self.optimizer.step()
assert loss_value is not None
with paddle.no_grad():
for hook in self._step_hooks.values():
hook(self, self.num_steps, num_evals, loss_value)
return loss_value

def run(self) -> float:
"""Perform registration steps until convergence.
Returns:
Loss value prior to taking last gradient step.
"""
self.loss_values = []
self.num_steps = 0
while self.num_steps < self.max_steps and not self.converged():
value = self.step()
self.num_steps += 1
if math.isnan(value):
raise RuntimeError(
f"NaN value in registration loss at gradient step {self.num_steps}"
)
if math.isinf(value):
raise RuntimeError(
f"Inf value in registration loss at gradient step {self.num_steps}"
)
self.loss_values.append(value)
if len(self.loss_values) > self.max_history:
self.loss_values.pop(0)
return self.loss_value

def converged(self) -> bool:
"""Check convergence criteria."""
values = self.loss_values
if not values:
return False
value = values[-1]
if self.min_delta < 0:
epsilon = abs(self.min_delta * value)
else:
epsilon = self.min_delta
slope = slope_of_least_squares_fit(values)
if abs(slope) < epsilon:
return True
if value < self.min_value:
return True
return False

def register_eval_hook(
self,
hook: Callable[["RegistrationEngine", int, int, "RegistrationResult"], None],
) -> "RemovableHandle":
r"""Registers an evaluation hook."""
handle = RemovableHandle(self._eval_hooks)
self._eval_hooks[handle.id] = hook
return handle

def register_step_hook(
self, hook: Callable[["RegistrationEngine", int, int, float], None]
) -> "RemovableHandle":
r"""Registers a gradient step hook."""
handle = RemovableHandle(self._step_hooks)
self._step_hooks[handle.id] = hook
return handle


class RemovableHandle:
r"""
A handle which provides the capability to remove a hook.
Args:
hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
extra_dict (Union[dict, List[dict]]): An additional dictionary or list of
dictionaries whose keys will be deleted when the same keys are
removed from ``hooks_dict``.
"""

id: int
next_id: int = 0

def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
self.hooks_dict_ref = weakref.ref(hooks_dict)
self.id = RemovableHandle.next_id
RemovableHandle.next_id += 1

self.extra_dict_ref: Tuple = ()
if isinstance(extra_dict, dict):
self.extra_dict_ref = (weakref.ref(extra_dict),)
elif isinstance(extra_dict, list):
self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

def remove(self) -> None:
hooks_dict = self.hooks_dict_ref()
if hooks_dict is not None and self.id in hooks_dict:
del hooks_dict[self.id]

for ref in self.extra_dict_ref:
extra_dict = ref()
if extra_dict is not None and self.id in extra_dict:
del extra_dict[self.id]

def __getstate__(self):
if self.extra_dict_ref is None:
return (self.hooks_dict_ref(), self.id)
else:
return (
self.hooks_dict_ref(),
self.id,
tuple(ref() for ref in self.extra_dict_ref),
)

def __setstate__(self, state) -> None:
if state[0] is None:
# create a dead reference
self.hooks_dict_ref = weakref.ref(OrderedDict())
else:
self.hooks_dict_ref = weakref.ref(state[0])
self.id = state[1]
RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

if len(state) < 3 or state[2] is None:
self.extra_dict_ref = ()
else:
self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2])

def __enter__(self) -> "RemovableHandle":
return self

def __exit__(self, type: Any, value: Any, tb: Any) -> None:
self.remove()
72 changes: 72 additions & 0 deletions jointContribution/HighResolution/ffd/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Callable

import paddle
from deepali.core import functional as U
from deepali.core.kernels import gaussian1d
from deepali.spatial import is_linear_transform

from .engine import RegistrationEngine
from .engine import RegistrationResult

RegistrationEvalHook = Callable[
[RegistrationEngine, int, int, RegistrationResult], None
]
RegistrationStepHook = Callable[[RegistrationEngine, int, int, float], None]


def noop(reg: RegistrationEngine, *args, **kwargs) -> None:
"""Dummy no-op loss evaluation hook."""
...


def normalize_linear_grad(reg: RegistrationEngine, *args, **kwargs) -> None:
"""Loss evaluation hook for normalization of linear transformation gradient after backward pass."""
denom = None
for param in reg.model.parameters():
if not param.stop_gradient and param.grad is not None:
max_abs_grad = paddle.max(paddle.abs(param.grad))
if denom is None or denom < max_abs_grad:
denom = max_abs_grad
if denom is None:
return
for param in reg.model.parameters():
if not param.stop_gradient and param.grad is not None:
param.grad /= denom


def normalize_nonrigid_grad(reg: RegistrationEngine, *args, **kwargs) -> None:
"""Loss evaluation hook for normalization of non-rigid transformation gradient after backward pass."""
for param in reg.model.parameters():
if not param.stop_gradient and param.grad is not None:
paddle.assign(
paddle.nn.functional.normalize(x=param.grad, p=2, axis=1),
output=param.grad,
)


def normalize_grad_hook(transform) -> RegistrationEvalHook:
"""Loss evaluation hook for normalization of transformation gradient after backward pass."""
if is_linear_transform(transform):
return normalize_linear_grad
return normalize_nonrigid_grad


def _smooth_nonrigid_grad(reg: RegistrationEngine, sigma: float = 1) -> None:
"""Loss evaluation hook for Gaussian smoothing of non-rigid transformation gradient after backward pass."""
if sigma <= 0:
return
kernel = gaussian1d(sigma)
for param in reg.model.parameters():
if not param.stop_gradient and param.grad is not None:
param.grad.copy_(U.conv(param.grad, kernel))


def smooth_grad_hook(transform, sigma: float) -> RegistrationEvalHook:
"""Loss evaluation hook for Gaussian smoothing of non-rigid gradient after backward pass."""
if is_linear_transform(transform):
return noop

def fn(reg: RegistrationEngine, *args, **kwargs):
return _smooth_nonrigid_grad(reg, sigma=sigma)

return fn
Loading

0 comments on commit db2bc09

Please sign in to comment.