Skip to content

Commit

Permalink
Merge pull request #1 from calincru/master
Browse files Browse the repository at this point in the history
Use a conda .environment file to easily set up the dev environment and reproduce the behaviour
  • Loading branch information
josipd authored Mar 19, 2018
2 parents 3ce5abf + 4fbaf85 commit 00a2196
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 10 deletions.
16 changes: 16 additions & 0 deletions .environment
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
dependencies:
- python=3.5
- numpy=1.14
- scipy=1.0
- matplotlib=2.2
- pyzmq=17.0
- jupyter=1.0
- cython=0.28
- pytorch=0.3 # from channel pytorch
- torchvision=0.2 # from channel pytorch
- shogun=6.0 # from channel conda-forge
- tqdm
- pylint
- sphinxcontrib-bibtex
- sphinx_rtd_theme
- mock
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ build
*.swo
dist
docs/_build
.pytest_cache/
10 changes: 4 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@ install:
- hash -r
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
# Add channels necessary for the dependencies
- conda config --add channels pytorch
- conda config --add channels conda-forge
# Useful for debugging any issues with conda
- conda info -a

- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
- conda env create -n test-environment -f .environment python=$TRAVIS_PYTHON_VERSION
- source activate test-environment
- conda install pytorch -c soumith
- conda install -c conda-forge shogun
- conda install numpy
- conda install cython
- conda install scipy
- python setup.py install

script:
Expand Down
2 changes: 2 additions & 0 deletions notebooks/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.ipynb_checkpoints/
mnist_dir/
2 changes: 1 addition & 1 deletion tests/inference_cardinality_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_softmax():
unaries = Variable(torch.randn(batch_size, 15))
count_potentials = Variable(NINF * torch.ones(batch_size, k + 1))
count_potentials[:, 1] = 0
output_softmax = softmax(unaries)
output_softmax = softmax(unaries, dim=1)
output_cardinf = inference_cardinality(unaries, count_potentials)
assert np.allclose(output_softmax.cpu().data.numpy(),
output_cardinf.cpu().data.numpy())
Expand Down
5 changes: 3 additions & 2 deletions torch_two_sample/inference_cardinality.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def create_var(val, *dims):

bb = bmsgs[:, 2:, :]
ff = fmsgs[:, :-2, :]
b0 = logsumexp(bb + ff, 2)
b1 = logsumexp(bb[:, :, :-1] + ff[:, :, 1:], 2) + node_potentials[:, :-1]
b0 = logsumexp(bb + ff, 2).view(batch_size, dim_node-1)
b1 = logsumexp(bb[:, :, :-1] + ff[:, :, 1:], 2).view(
batch_size, dim_node-1) + node_potentials[:, :-1]

marginals = create_var(0, batch_size, dim_node)
marginals[:, :-1] = torch.sigmoid(b1 - b0)
Expand Down
2 changes: 1 addition & 1 deletion torch_two_sample/statistics_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def __call__(self, sample_1, sample_2, alphas, norm=2, ret_matrix=False):
margs_ = None
for alpha in alphas:
if self.k == 1:
margs_a = softmax(- alpha * diffs)
margs_a = softmax(-alpha * diffs, dim=1)
else:
margs_a = inference_cardinality(
- alpha * diffs.cpu(), count_potential)
Expand Down

0 comments on commit 00a2196

Please sign in to comment.