Skip to content

Commit

Permalink
funding
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed May 20, 2024
1 parent 8fe566b commit 05140f0
Show file tree
Hide file tree
Showing 14 changed files with 218 additions and 148 deletions.
2 changes: 2 additions & 0 deletions .github/FUNDING.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
patreon: GreenWizard
buy_me_a_coffee: greenwizard89
9 changes: 9 additions & 0 deletions FUNDING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Support and Funding

If you want to support my quest for sarcastic humor, brilliant software achievements, and witty commentary, here’s how you can do it:

1. **Patreon**: All my amazing content there is free. But if you feel an irresistible urge to support me with monthly donations, head over to [my Patreon page](https://www.patreon.com/GreenWizard). Your support will help me keep delighting you with sarcasm and software revelations.

2. **Buy Me a Coffee**: Prefer one-time acts of generosity? You can buy me a coffee at [Buy Me a Coffee](https://buymeacoffee.com/greenwizard89). Because, honestly, my level of sarcasm and ability to write genius code directly correlate with the amount of caffeine I consume.

By supporting me, you ensure a continuous flow of sarcasm, wit, and cutting-edge software insights. And who wouldn't want more of that in their life?
22 changes: 19 additions & 3 deletions NN/restorators/CARProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ def __init__(self, predictions, sourceDistribution, sampler):
def forward(self, x0, xT=None):
B = tf.shape(x0)[0]
# source distribution need to know the shape of the input, so we need to ensure it explicitly
x0 = tf.ensure_shape(x0, (None, self.predictions))
# x0 = tf.ensure_shape(x0, (None, self.predictions))
sampled = self._sourceDistribution.sampleFor(x0)
x1 = sampled['xT'] if xT is None else xT

tf.assert_equal(tf.shape(x0), (B, self.predictions))
# tf.assert_equal(tf.shape(x0), (B, self.predictions))
return self._sampler.train(x0=x0, x1=x1, T=sampled['T'], xT=xT)

def calculate_loss(self, x_hat, predicted, **kwargs):
if hasattr(self._sampler, 'calculate_loss'):
return self._sampler.calculate_loss(x_hat, predicted, **kwargs)
lossFn = kwargs.get('lossFn', tf.losses.mae) # default loss function
target = self.withExtraOutputs(x_hat['target'], **kwargs)
tf.assert_equal(tf.shape(target), tf.shape(predicted))
Expand All @@ -45,11 +47,25 @@ def reverse(self, value, denoiser, modelT=None, index=0, **kwargs):

denoiser = self._makeDenoiser(denoiser, modelT)
res = self._sampler.sample(value=value, model=denoiser, index=index, **kwargs)
tf.assert_equal(tf.shape(res), tf.shape(value))
# tf.assert_equal(tf.shape(res), tf.shape(value))
return res

def targets(self, x_hat, values):
return self._sampler.targets(x_hat, values[:, :self.predictions])

@property
def channels(self):
sampler = self._sampler
if hasattr(sampler, 'channels'):
return sampler.channels
return super().channels

@property
def predictions(self):
sampler = self._sampler
if hasattr(sampler, 'predictions'):
return sampler.predictions
return super().predictions
# End of CARProcess

def autoregressive_restoration_from_config(config):
Expand Down
9 changes: 7 additions & 2 deletions NN/restorators/IRestorationProcess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import tensorflow as tf

class IRestorationProcess:
class IRestorationProcess(tf.keras.Model):
@staticmethod
def getOutputSize(outputs):
sizes = {
Expand All @@ -15,7 +15,9 @@ def getOutputSize(outputs):
continue
return channels

def __init__(self, predictions):
def __init__(self, predictions, name=None, **kwargs):
if name is None: name = self.__class__.__name__
super().__init__(name=name, **kwargs)
predictions = list(predictions)
if 'rgb' not in predictions: predictions.insert(0, 'rgb')
self._outputs = predictions
Expand All @@ -24,6 +26,9 @@ def __init__(self, predictions):
print('[IRestorationProcess] Restorator:', self.__class__.__name__)
return

def call(self, *args, **kwargs):
raise RuntimeError('IRestorationProcess object cannot be called directly')

def forward(self, x0, xT=None):
raise NotImplementedError()

Expand Down
2 changes: 1 addition & 1 deletion NN/restorators/samplers/CARSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def solve(self, x_hat, step, value, interpolant, params, **kwargs):
)

def directSolve(self, x_hat, values, interpolant, **kwargs):
solved = interpolant.solve(x_hat=x_hat['x1'], xt=values, t=1.0).x0
solved = interpolant.solve(x_hat=values, xt=x_hat['x1'], t=1.0).x0
return solved
# End of CARSamplingAlgorithm

Expand Down
24 changes: 21 additions & 3 deletions NN/restorators/samplers/CBasicInterpolantSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,21 @@
from .ISamplingAlgorithm import ISamplingAlgorithm
from NN.utils import is_namedtuple

class CBasicInterpolantSampler:
def __init__(self, interpolant, algorithm):
class IBasicInterpolantSampler(tf.keras.Model):
@property
def interpolant(self):
raise NotImplementedError()

@tf.function
def sample(self, value, model, index=0, **kwargs):
raise NotImplementedError()

def targets(self, x_hat, values):
raise NotImplementedError()

class CBasicInterpolantSampler(IBasicInterpolantSampler):
def __init__(self, interpolant, algorithm, **kwargs):
super().__init__(**kwargs)
self._interpolant = interpolant
self._algorithm = algorithm
return
Expand All @@ -17,7 +30,7 @@ def sample(self, value, model, index=0, **kwargs):
kwargs = dict(**kwargs, interpolant=self._interpolant, index=index)
# wrap algorithm with hook, if provided
algorithm = kwargs.get('algorithmInterceptor', lambda x: x)( self._algorithm )
assert isinstance(algorithm, ISamplingAlgorithm), f'Algorithm must be an instance of ISamplingAlgorithm, but got {type(algorithm)}'
assert issubclass(type(algorithm), ISamplingAlgorithm), f'Algorithm must be an instance of ISamplingAlgorithm, but got {type(algorithm)}'

step = algorithm.firstStep(value=value, **kwargs)
# CFakeObject is a namedtuple, so we need to check for it
Expand All @@ -39,6 +52,11 @@ def sample(self, value, model, index=0, **kwargs):
# update value
tf.assert_equal(tf.shape(value), tf.shape(solution.value))
value = solution.value
# for debugging, print euclidean distance to GT
if 'GT' in kwargs:
gt = kwargs['GT']
dist = tf.reduce_sum(tf.square(value - gt), axis=-1)
tf.print(f'Iteration {iteration}:', dist, summarize=10)
iteration += 1
continue

Expand Down
107 changes: 2 additions & 105 deletions NN/restorators/samplers/CDDIMInterpolantSampler.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,6 @@
import tensorflow as tf
from Utils.utils import CFakeObject
from NN.utils import normVec
from .CBasicInterpolantSampler import CBasicInterpolantSampler, ISamplingAlgorithm

class CDDIMSamplingAlgorithm(ISamplingAlgorithm):
def __init__(self, stochasticity, noiseProvider, schedule, steps, clipping, projectNoise):
self._stochasticity = stochasticity
self._noiseProvider = noiseProvider
self._schedule = schedule
self._steps = steps
self._clipping = clipping
self._projectNoise = projectNoise
return

def _makeStep(self, current_step, steps, **kwargs):
schedule = kwargs.get('schedule', self._schedule)
eta = kwargs.get('eta', self._stochasticity)

T = steps[0][current_step]
alpha_hat_t = schedule.parametersForT(T).alphaHat
prevStepInd = steps[1][current_step]
alpha_hat_t_prev = schedule.parametersForT(prevStepInd).alphaHat

stepVariance = schedule.varianceBetween(alpha_hat_t, alpha_hat_t_prev)
sigma = tf.sqrt(stepVariance) * eta

return CFakeObject(
steps=steps,
current_step=current_step,
active=(0 <= current_step),
sigma=sigma,
#
T=T,
t=alpha_hat_t,
t_prev=alpha_hat_t_prev,
t_prev_2=1.0 - alpha_hat_t_prev - tf.square(sigma),
)

def firstStep(self, **kwargs):
schedule = kwargs.get('schedule', self._schedule)
assert schedule is not None, 'schedule is None'
assert schedule.is_discrete, 'schedule is not discrete'
steps = schedule.steps_sequence(
startStep=kwargs.get('startStep', None),
endStep=kwargs.get('endStep', None),
config=kwargs.get('stepsConfig', self._steps),
reverse=True, # reverse steps order to make it easier to iterate over them
)

return self._makeStep(
current_step=tf.size(steps[0]) - 1,
steps=steps,
**kwargs
)

def nextStep(self, step, **kwargs):
return self._makeStep(
current_step=step.current_step - 1,
steps=step.steps,
**kwargs
)

def inference(self, model, step, value, **kwargs):
schedule = kwargs.get('schedule', self._schedule)
return model(
x=value,
T=step.T,
t=schedule.to_continuous(step.T),
)

def _withNoise(self, value, sigma, x0, kwargs):
noise_provider = kwargs.get('noiseProvider', self._noiseProvider)
noise = noise_provider(shape=tf.shape(value), sigma=sigma)
if not kwargs.get('projectNoise', self._projectNoise): return value + noise

_, L = normVec(value - x0)
vec, _ = normVec(value + noise - x0)
return x0 + L * vec # project noise back to the spherical manifold

def _withClipping(self, value, kwargs):
clipping = kwargs.get('clipping', self._clipping)
if clipping is None: return value
return tf.clip_by_value(value, clip_value_min=clipping['min'], clip_value_max=clipping['max'])

def solve(self, x_hat, step, value, interpolant, **kwargs):
solved = interpolant.solve(x_hat=x_hat, xt=value, t=step.t)
x_prev = interpolant.interpolate(
x0=solved.x0, x1=solved.x1,
t=step.t_prev, t2=step.t_prev_2
)
x_prev = self._withNoise(x_prev, sigma=step.sigma, x0=solved.x0, kwargs=kwargs)
x_prev = self._withClipping(x_prev, kwargs=kwargs)
# return solution and additional information for debugging
return CFakeObject(
value=x_prev,
x0=solved.x0,
x1=solved.x1,
T=step.T,
current_step=step.current_step,
sigma=step.sigma,
)

def directSolve(self, x_hat, xt, interpolant):
return interpolant.solve(x_hat=xt, xt=x_hat['xT'], t=x_hat['alphaHat']).x0
# End of CDDIMSamplingAlgorithm
from .CBasicInterpolantSampler import CBasicInterpolantSampler
from .CDDIMSamplingAlgorithm import CDDIMSamplingAlgorithm

class CDDIMInterpolantSampler(CBasicInterpolantSampler):
def __init__(
Expand Down
1 change: 0 additions & 1 deletion NN/restorators/samplers/CSamplerWatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import tensorflow as tf
from NN.utils import is_namedtuple
import NN.utils as NNU
from .CSamplingInterceptor import CSamplingInterceptor
from .ISamplerWatcher import ISamplerWatcher

Expand Down
34 changes: 29 additions & 5 deletions NN/restorators/samplers/CSamplingInterceptor.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
from .ISamplingAlgorithm import ISamplingAlgorithm
from .ISamplerWatcher import ISamplerWatcher

class CSamplingInterceptor(ISamplingAlgorithm):
class CSamplingInterceptor(ISamplingAlgorithm, ISamplerWatcher):
def __init__(self, watcher, algorithm):
assert issubclass(type(watcher), ISamplerWatcher), f'Invalid watcher: {watcher}'
self._watcher = watcher
self._algorithm = algorithm
return

def interceptor(self):
def F(algorithm):
if isinstance(self._watcher, ISamplerWatcher):
self._watcher = algorithm = self._watcher.interceptor()(algorithm)

if callable(self._watcher): # replace the watcher with the interceptor
self._watcher = algorithm = self._watcher(algorithm)

# if self._algorithm is not None:
# assert isinstance(algorithm, ISamplingAlgorithm), f'algorithm is not an instance of ISamplingAlgorithm: {algorithm}'
return self
return F

def firstStep(self, **kwargs):
res = self._algorithm.firstStep(**kwargs)
res = self._algorithm.firstStep(**kwargs) if self._algorithm is not None else None
self._watcher._onStart(value=kwargs['value'], kwargs=kwargs)
return res

def _onStart(self, value, kwargs):
return self._watcher._onStart(value=value, kwargs=kwargs)

def nextStep(self, **kwargs):
self._watcher._onNextStep(iteration=kwargs['iteration'], kwargs=kwargs)
res = self._algorithm.nextStep(**kwargs)
res = self._algorithm.nextStep(**kwargs) if self._algorithm is not None else None
return res

def _onNextStep(self, iteration, kwargs):
return self._watcher._onNextStep(iteration=iteration, kwargs=kwargs)

def inference(self, **kwargs):
return self._algorithm.inference(**kwargs)
return self._algorithm.inference(**kwargs) if self._algorithm is not None else None

def solve(self, **kwargs):
return self._algorithm.solve(**kwargs)
return self._algorithm.solve(**kwargs) if self._algorithm is not None else None

def tracked(self, name):
return self._watcher.tracked(name)
# End of CSamplingInterceptor
Loading

0 comments on commit 05140f0

Please sign in to comment.