Skip to content

Commit

Permalink
review fixes 2
Browse files Browse the repository at this point in the history
  • Loading branch information
VIGNESHinZONE authored and atreyamaj committed Sep 16, 2021
1 parent 347ef3d commit 386979b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions deepchem/models/jax_models/pinns_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''
This python consists of different variations of the Physics Informer Neural Network model using the JaxModel API
This module contains different variations of the Physics Informer Neural Network model using the JaxModel API
'''
import numpy as np
import time
Expand Down Expand Up @@ -29,7 +29,9 @@
logger = logging.getLogger(__name__)


def create_default_update_fn(optimizer, model_loss):
def create_default_update_fn(
optimizer: optax.GradientTransformation,
model_loss: callable):
"""
This function calls the update function, to implement the backpropagation
"""
Expand All @@ -51,7 +53,7 @@ class PINNModel(JaxModel):
but it has the option of passing multiple arguments(Done using *args) suitable for PINNs model.
Ex - Approximating f(x, y, z, t) satisfying a Linear differential equation.
This model is recommended for Linear differential equations but if you can accurately write
This model is recommended for linear partial differential equations but if you can accurately write
the gradient function in Jax depending on your use case, then it will work as well.
This class requires two functions apart from the usual function definition and weights
Expand Down
2 changes: 1 addition & 1 deletion deepchem/models/jax_models/tests/test_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@pytest.mark.jax
def test_sine_x():
"""
Here we are solving the diffrential equation- f'(x) = -sin(x) and f(0) = 1
Here we are solving the differential equation- f'(x) = -sin(x) and f(0) = 1
We give initial for the neural network at x_init --> np.linspace(-1 * np.pi, 1 * np.pi, 5)
And we try to approximate the function for the domain (-np.pi, np.pi)
"""
Expand Down

0 comments on commit 386979b

Please sign in to comment.