Skip to content

Commit

Permalink
Export modules correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
mblondel committed Aug 2, 2021
1 parent 932f370 commit 3e8147f
Show file tree
Hide file tree
Showing 24 changed files with 291 additions and 53 deletions.
6 changes: 3 additions & 3 deletions examples/lasso_implicit_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.numpy as jnp

from jaxopt import BlockCoordinateDescent
from jaxopt import objectives
from jaxopt import objective
from jaxopt import OptaxSolver
from jaxopt import prox
from jaxopt import ProximalGradient
Expand All @@ -44,13 +44,13 @@ def outer_objective(theta, init_inner, data):

if FLAGS.solver == "pg":
solver = ProximalGradient(
fun=objectives.least_squares,
fun=objective.least_squares,
prox=prox.prox_lasso,
implicit_diff=not FLAGS.unrolling,
maxiter=500)
elif FLAGS.solver == "bcd":
solver = BlockCoordinateDescent(
fun=objectives.least_squares,
fun=objective.least_squares,
block_prox=prox.prox_lasso,
implicit_diff=not FLAGS.unrolling,
maxiter=500)
Expand Down
4 changes: 2 additions & 2 deletions examples/multiclass_linear_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from absl import app
import jax.numpy as jnp
from jaxopt import BlockCoordinateDescent
from jaxopt import objectives
from jaxopt import objective
from jaxopt import projection
from jaxopt import prox
from sklearn import datasets
Expand All @@ -44,7 +44,7 @@ def main(argv):

# Set up parameters.
block_prox = prox.make_prox_from_projection(projection.projection_simplex)
fun = objectives.multiclass_linear_svm_dual
fun = objective.multiclass_linear_svm_dual
data = (X, Y)
lam = 1000.0
beta_init = jnp.ones((n_samples, n_classes)) / n_classes
Expand Down
4 changes: 2 additions & 2 deletions examples/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp

from jaxopt import BlockCoordinateDescent
from jaxopt import objectives
from jaxopt import objective
from jaxopt import prox

import numpy as onp
Expand Down Expand Up @@ -51,7 +51,7 @@ def nnreg(U, V_init, X, maxiter=150):
else:
raise ValueError("Invalid penalty.")

bcd = BlockCoordinateDescent(fun=objectives.least_squares,
bcd = BlockCoordinateDescent(fun=objective.least_squares,
block_prox=block_prox,
maxiter=maxiter)
sol = bcd.run(init_params=V_init.T, hyperparams_prox=FLAGS.gamma, data=(U, X))
Expand Down
11 changes: 0 additions & 11 deletions jaxopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src import base
from jaxopt._src import implicit_diff
from jaxopt._src import linear_solve
from jaxopt._src import loop
from jaxopt._src import loss
from jaxopt._src import objectives
from jaxopt._src import perturbations
from jaxopt._src import projection
from jaxopt._src import prox
from jaxopt._src import tree_util

from jaxopt._src.bisection import Bisection
from jaxopt._src.block_cd import BlockCoordinateDescent
from jaxopt._src.gradient_descent import GradientDescent
Expand Down
6 changes: 0 additions & 6 deletions jaxopt/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
import jax.numpy as jnp


class OptimizeResults(NamedTuple):
error: float
nit: int
x: Any


class OptStep(NamedTuple):
params: Any
state: Any
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/block_cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class BlockCoordinateDescent:
Attributes:
fun: a smooth function of the form ``fun(params, *args, **kwargs)``.
It should be a ``objectives.CompositeLinearFunction`` object.
It should be a ``objective.CompositeLinearFunction`` object.
block_prox: block-wise proximity operator associated with ``non_smooth``,
a function of the form ``block_prox(x[j], hyperparams_prox, scaling=1.0)``.
See ``jaxopt.prox`` for examples.
Expand Down
File renamed without changes.
16 changes: 16 additions & 0 deletions jaxopt/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.base import LinearOperator
from jaxopt._src.base import OptStep
18 changes: 18 additions & 0 deletions jaxopt/implicit_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.implicit_diff import custom_root
from jaxopt._src.implicit_diff import custom_fixed_point
from jaxopt._src.implicit_diff import root_jvp
from jaxopt._src.implicit_diff import root_vjp
20 changes: 20 additions & 0 deletions jaxopt/linear_solve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.linear_solve import solve_lu
from jaxopt._src.linear_solve import solve_cholesky
from jaxopt._src.linear_solve import solve_cg
from jaxopt._src.linear_solve import solve_normal_cg
from jaxopt._src.linear_solve import solve_gmres
from jaxopt._src.linear_solve import solve_bicgstab
15 changes: 15 additions & 0 deletions jaxopt/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.loop import while_loop
18 changes: 18 additions & 0 deletions jaxopt/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.loss import binary_logistic_loss
from jaxopt._src.loss import huber_loss
from jaxopt._src.loss import multiclass_logistic_loss
from jaxopt._src.loss import multiclass_sparsemax_loss
22 changes: 22 additions & 0 deletions jaxopt/objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.objective import CompositeLinearFunction
from jaxopt._src.objective import least_squares
from jaxopt._src.objective import multiclass_logreg
from jaxopt._src.objective import multiclass_logreg_with_intercept
from jaxopt._src.objective import l2_multiclass_logreg
from jaxopt._src.objective import l2_multiclass_logreg_with_intercept
from jaxopt._src.objective import binary_logreg
from jaxopt._src.objective import multiclass_linear_svm_dual
18 changes: 18 additions & 0 deletions jaxopt/perturbations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.perturbations import Gumbel
from jaxopt._src.perturbations import Normal
from jaxopt._src.perturbations import make_perturbed_argmax
from jaxopt._src.perturbations import make_perturbed_max
27 changes: 27 additions & 0 deletions jaxopt/projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.projection import projection_non_negative
from jaxopt._src.projection import projection_box
from jaxopt._src.projection import projection_simplex
from jaxopt._src.projection import projection_l1_sphere
from jaxopt._src.projection import projection_l1_ball
from jaxopt._src.projection import projection_l2_sphere
from jaxopt._src.projection import projection_l2_ball
from jaxopt._src.projection import projection_linf_ball
from jaxopt._src.projection import projection_hyperplane
from jaxopt._src.projection import projection_halfspace
from jaxopt._src.projection import projection_affine_set
from jaxopt._src.projection import projection_polyhedron
from jaxopt._src.projection import projection_box_section
22 changes: 22 additions & 0 deletions jaxopt/prox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.prox import make_prox_from_projection
from jaxopt._src.prox import prox_none
from jaxopt._src.prox import prox_lasso
from jaxopt._src.prox import prox_non_negative_lasso
from jaxopt._src.prox import prox_elastic_net
from jaxopt._src.prox import prox_group_lasso
from jaxopt._src.prox import prox_ridge
from jaxopt._src.prox import prox_non_negative_ridge
25 changes: 25 additions & 0 deletions jaxopt/tree_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jaxopt._src.tree_util import tree_map
from jaxopt._src.tree_util import tree_multimap
from jaxopt._src.tree_util import tree_reduce
from jaxopt._src.tree_util import tree_add
from jaxopt._src.tree_util import tree_sub
from jaxopt._src.tree_util import tree_mul
from jaxopt._src.tree_util import tree_scalar_mul
from jaxopt._src.tree_util import tree_add_scalar_mul
from jaxopt._src.tree_util import tree_vdot
from jaxopt._src.tree_util import tree_sum
from jaxopt._src.tree_util import tree_l2_norm
16 changes: 8 additions & 8 deletions tests/block_cd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.numpy as jnp

from jaxopt import BlockCoordinateDescent
from jaxopt import objectives
from jaxopt import objective
from jaxopt import projection
from jaxopt import prox
from jaxopt._src import test_util
Expand All @@ -37,7 +37,7 @@ def test_lasso_manual_loop(self):
X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)

# Setup parameters.
fun = objectives.least_squares # fun(params, data)
fun = objective.least_squares # fun(params, data)
l2reg = 10.0
data = (X, y)

Expand All @@ -57,7 +57,7 @@ def test_lasso(self):
X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)

# Set up parameters.
fun = objectives.least_squares # fun(params, data)
fun = objective.least_squares # fun(params, data)
l2reg = 10.0
data = (X, y)
w_init = jnp.zeros(X.shape[1])
Expand All @@ -79,7 +79,7 @@ def test_elastic_net(self):
X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)

# Set up parameters.
fun = objectives.least_squares # fun(params, data)
fun = objective.least_squares # fun(params, data)
hyperparams_prox = (2.0, 0.8)
data = (X, y)
w_init = jnp.zeros(X.shape[1])
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_multitask_reg(self):
Y = jnp.dot(X, W) + rng.randn(n_samples, n_tasks)

# Set up parameters.
fun = objectives.least_squares # fun(params, data)
fun = objective.least_squares # fun(params, data)
block_prox = prox.prox_group_lasso
l2reg = 1e-1
W_init = jnp.zeros((n_features, n_tasks))
Expand Down Expand Up @@ -149,9 +149,9 @@ def test_logreg(self, multiclass, penalty):
block_prox = prox.prox_ridge

if multiclass:
fun = objectives.multiclass_logreg
fun = objective.multiclass_logreg
else:
fun = objectives.binary_logreg
fun = objective.binary_logreg

l2reg = 1e-2

Expand Down Expand Up @@ -205,7 +205,7 @@ def test_multiclass_linear_svm(self):

# Set up parameters.
block_prox = prox.make_prox_from_projection(projection.projection_simplex)
fun = objectives.multiclass_linear_svm_dual
fun = objective.multiclass_linear_svm_dual
data = (X, Y)
l2reg = 1000.0
beta_init = jnp.ones((n_samples, n_classes)) / n_classes
Expand Down
Loading

0 comments on commit 3e8147f

Please sign in to comment.