Skip to content

Commit

Permalink
fixed wrong format
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacperez committed Aug 19, 2024
1 parent c50e34e commit 3c68254
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions docs/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,36 +179,35 @@ During backpropagation, the `GradientFunction` plays a vital role by orchestrati

#### How Backpropagation Works

1. __Starting the Backward Pass__:
1. __Starting the Backward Pass__:
Backpropagation begins at the tensor where the loss is computed. Typically, this tensor is a scalar (i.e., a tensor with zero dimensions). The `backward()` method is called on this tensor to initiate the gradient computation.

```python
loss.backward()
```
If the tensor is a scalar, its gradient is initialized to `1.0`, since the derivative of a value with respect to itself is `1.0`. If the tensor is not a scalar, an external gradient must be provided.

2. __Signal Propagation__:
2. __Signal Propagation__:
Before any actual gradient computation takes place, the system propagates a "signal" through the computational graph. This signal helps each tensor keep track of how many gradients it should expect to receive from its downstream operations. The purpose of this step is to ensure that a tensor only propagates its gradient backward once it has received all the expected gradients. This reduces redundant gradient propagations and makes the backpropagation process more efficient.

- __Signal Propagation Implementation__:
When the `backward()` method is called, the tensor first calls `_propagate_reference_signal()`. This method increments a counter (`_pending_gradients_count`) in each tensor to track the number of gradients it needs to receive. Only after all expected gradients have been received will the tensor propagate its accumulated gradient further back through the graph.

3. __Gradient Initialization__:
3. __Gradient Initialization__:
Once the signal propagation is complete, the actual gradient computation begins. The gradient for the starting tensor (typically the loss) is initialized. For scalar tensors, this gradient is `1.0`. For non-scalar tensors, the provided gradient is validated and used for the backward pass.

4. __Gradient Accumulation__:
4. __Gradient Accumulation__:
As gradients flow backward through the graph, each tensor accumulates the incoming gradients. This is especially important when a tensor contributes to multiple operations, as it needs to sum the gradients from all those operations before sending its accumulated gradient backward.

- __Gradient Accumulation Implementation__:
In the `Tensor` class, the `_accumulate_gradient()` method is responsible for accumulating gradients in the `grad` attribute (for leaf tensors or when retaining gradients) and in the `_accumulated_gradient_to_propagate` attribute, which is used to store gradients temporarily until the tensor has received all the expected gradients.

5. __Gradient Propagation__:
5. __Gradient Propagation__:
Once a tensor has accumulated all the gradients it expects, it propagates the accumulated gradient to its input tensors. This step involves calling the `backward()` method of the `GradientFunction` associated with the tensor, which further propagates the gradient to the tensor's inputs.

- __Gradient Propagation Implementation__:
The `_propagate_gradient()` method in the `Tensor` class is responsible for this step. It ensures that the gradient is only propagated once all expected gradients have been received, thereby avoiding unnecessary recomputation.

6. __Releasing the Computational Graph__:
6. __Releasing the Computational Graph__:
After backpropagation, the computational graph is usually released to free up memory. This means that you cannot perform another backward pass unless you explicitly retain the graph by passing `retain_graph=True` to the `backward()` method.

```python
Expand Down

0 comments on commit 3c68254

Please sign in to comment.