Skip to content

Commit

Permalink
fixed batch first bug
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanonardo committed Apr 12, 2018
1 parent 259a261 commit a053adc
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ PyTorch-ESN is a PyTorch module implementing Echo State Networks with leaky-inte

### Offline training (ridge regression)

```
```python
from torchesn.nn import ESN
from torchesn.utils import prepare_target

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages

setup(name='pytorch-esn',
version='1.0.2',
version='1.0.3',
packages=find_packages(),
install_requires=[
'torch',
Expand Down
3 changes: 2 additions & 1 deletion torchesn/nn/echo_state_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def forward(self, input, h_0, washout=0, target=None):
padded_input = padded_input.transpose(0, 1)
output = torch.cat([padded_input[washout:], output], -1)
else:
input = input.transpose(0, 1)
if self.batch_first:
input = input.transpose(0, 1)
output = torch.cat([input[washout:], output], -1)

if self.readout_training == 'online' or target is None:
Expand Down

0 comments on commit a053adc

Please sign in to comment.