Skip to content

Commit

Permalink
Add a free Solve function (RobotLocomotion#9942)
Browse files Browse the repository at this point in the history
Add a free Solver function, and a MakeSolver from solver ID.
  • Loading branch information
hongkai-dai authored Nov 6, 2018
1 parent ae2be68 commit 83b0e51
Show file tree
Hide file tree
Showing 25 changed files with 288 additions and 47 deletions.
8 changes: 6 additions & 2 deletions bindings/pydrake/solvers/mathematicalprogram_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@ PYBIND11_MODULE(mathematicalprogram, m) {
doc.MathematicalProgramSolverInterface.available.doc)
.def("solver_id", &MathematicalProgramSolverInterface::solver_id,
doc.MathematicalProgramSolverInterface.solver_id.doc)
.def("Solve", &MathematicalProgramSolverInterface::Solve,
doc.MathematicalProgramSolverInterface.Solve.doc)
.def("Solve",
// NOLINTNEXTLINE(whitespace/parens)
static_cast<SolutionResult (MathematicalProgramSolverInterface::*)(
MathematicalProgram&) const>(
&MathematicalProgramSolverInterface::Solve),
py::arg("prog"), doc.MathematicalProgramSolverInterface.Solve.doc)
// TODO(m-chaturvedi) Add Pybind11 documentation.
.def("solver_type",
[](const MathematicalProgramSolverInterface& self) {
Expand Down
19 changes: 19 additions & 0 deletions solvers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ drake_cc_package_library(
":scs_solver",
":snopt_solver",
":solution_result",
":solve",
":solver_id",
":solver_options",
":solver_result",
Expand Down Expand Up @@ -597,6 +598,16 @@ drake_cc_library(
],
)

drake_cc_library(
name = "solve",
srcs = ["solve.cc"],
hdrs = ["solve.h"],
deps = [
":choose_best_solver",
":mathematical_program",
],
)

# Internal Solvers.

drake_cc_library(
Expand Down Expand Up @@ -1473,6 +1484,14 @@ drake_cc_googletest(
],
)

drake_cc_googletest(
name = "solve_test",
deps = [
":solve",
"//common/test_utilities:eigen_matrix_compare",
],
)

# The extra_srcs are required here because add_lint_tests() doesn't understand
# how to extract labels from select() functions yet.
add_lint_tests(cpplint_extra_srcs = [
Expand Down
27 changes: 27 additions & 0 deletions solvers/choose_best_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,32 @@ SolverId ChooseBestSolver(const MathematicalProgram& prog) {
throw std::invalid_argument(
"There is no available solver for the optimization program");
}

std::unique_ptr<MathematicalProgramSolverInterface> MakeSolver(
const SolverId& id) {
if (id == LinearSystemSolver::id()) {
return std::make_unique<LinearSystemSolver>();
} else if (id == EqualityConstrainedQPSolver::id()) {
return std::make_unique<EqualityConstrainedQPSolver>();
} else if (id == MosekSolver::id()) {
return std::make_unique<MosekSolver>();
} else if (id == GurobiSolver::id()) {
return std::make_unique<GurobiSolver>();
} else if (id == OsqpSolver::id()) {
return std::make_unique<OsqpSolver>();
} else if (id == MobyLcpSolverId::id()) {
return std::make_unique<MobyLCPSolver<double>>();
} else if (id == SnoptSolver::id()) {
return std::make_unique<SnoptSolver>();
} else if (id == IpoptSolver::id()) {
return std::make_unique<IpoptSolver>();
} else if (id == NloptSolver::id()) {
return std::make_unique<NloptSolver>();
} else if (id == ScsSolver::id()) {
return std::make_unique<ScsSolver>();
} else {
throw std::invalid_argument("MakeSolver: no matching solver " + id.name());
}
}
} // namespace solvers
} // namespace drake
9 changes: 9 additions & 0 deletions solvers/choose_best_solver.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <memory>

#include "drake/solvers/mathematical_program.h"
#include "drake/solvers/solver_id.h"

Expand All @@ -11,5 +13,12 @@ namespace solvers {
* @throw invalid_argument if there is no available solver for @p prog.
*/
SolverId ChooseBestSolver(const MathematicalProgram& prog);

/**
* Given the solver ID, create the solver with the matching ID.
* @throw invalid_argument if there is no matching solver.
*/
std::unique_ptr<MathematicalProgramSolverInterface> MakeSolver(
const SolverId& id);
} // namespace solvers
} // namespace drake
6 changes: 6 additions & 0 deletions solvers/dreal_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class DrealSolver : public MathematicalProgramSolverInterface {

SolutionResult Solve(MathematicalProgram& prog) const override;

void Solve(const MathematicalProgram&, const optional<Eigen::VectorXd>&,
const optional<SolverOptions>&,
MathematicalProgramResult*) const override {
throw std::runtime_error("Not implemented yet.");
}

SolverId solver_id() const override;

/// @return same as MathematicalProgramSolverInterface::solver_id()
Expand Down
6 changes: 6 additions & 0 deletions solvers/equality_constrained_qp_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ class EqualityConstrainedQPSolver : public MathematicalProgramSolverInterface {
*/
SolutionResult Solve(MathematicalProgram& prog) const override;

void Solve(const MathematicalProgram&, const optional<Eigen::VectorXd>&,
const optional<SolverOptions>&,
MathematicalProgramResult*) const override {
throw std::runtime_error("Not implemented yet.");
}

SolverId solver_id() const override;

/// @return same as MathematicalProgramSolverInterface::solver_id()
Expand Down
6 changes: 6 additions & 0 deletions solvers/gurobi_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ class GurobiSolver : public MathematicalProgramSolverInterface {

SolutionResult Solve(MathematicalProgram& prog) const override;

void Solve(const MathematicalProgram&, const optional<Eigen::VectorXd>&,
const optional<SolverOptions>&,
MathematicalProgramResult*) const override {
throw std::runtime_error("Not implemented yet.");
}

SolverId solver_id() const override;

/// @return same as MathematicalProgramSolverInterface::solver_id()
Expand Down
6 changes: 6 additions & 0 deletions solvers/ipopt_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class IpoptSolver : public MathematicalProgramSolverInterface {

SolutionResult Solve(MathematicalProgram& prog) const override;

void Solve(const MathematicalProgram&, const optional<Eigen::VectorXd>&,
const optional<SolverOptions>&,
MathematicalProgramResult*) const override {
throw std::runtime_error("Not implemented yet.");
}

SolverId solver_id() const override;

/// @return same as MathematicalProgramSolverInterface::solver_id()
Expand Down
29 changes: 20 additions & 9 deletions solvers/linear_system_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ namespace solvers {

bool LinearSystemSolver::is_available() { return true; }

SolutionResult LinearSystemSolver::Solve(MathematicalProgram& prog) const {
void LinearSystemSolver::Solve(const MathematicalProgram& prog,
const optional<Eigen::VectorXd>& initial_guess,
const optional<SolverOptions>& solver_options,
MathematicalProgramResult* result) const {
// The initial guess doesn't help us, and we don't offer any tuning options.
unused(initial_guess, solver_options);
size_t num_constraints = 0;
for (auto const& binding : prog.linear_equality_constraints()) {
num_constraints += binding.evaluator()->A().rows();
Expand Down Expand Up @@ -49,20 +54,26 @@ SolutionResult LinearSystemSolver::Solve(MathematicalProgram& prog) const {
// least-squares solution
const Eigen::VectorXd least_square_sol =
Aeq.jacobiSvd(Eigen::ComputeThinU | Eigen::ComputeThinV).solve(beq);
SolverResult solver_result(id());
solver_result.set_decision_variable_values(least_square_sol);

result->set_solver_id(id());
result->set_x_val(least_square_sol);
if (beq.isApprox(Aeq * least_square_sol)) {
solver_result.set_optimal_cost(0.);
prog.SetSolverResult(solver_result);
return SolutionResult::kSolutionFound;
result->set_optimal_cost(0.);
result->set_solution_result(SolutionResult::kSolutionFound);
} else {
solver_result.set_optimal_cost(MathematicalProgram::kGlobalInfeasibleCost);
prog.SetSolverResult(solver_result);
return SolutionResult::kInfeasibleConstraints;
result->set_optimal_cost(MathematicalProgram::kGlobalInfeasibleCost);
result->set_solution_result(SolutionResult::kInfeasibleConstraints);
}
}

SolutionResult LinearSystemSolver::Solve(MathematicalProgram& prog) const {
MathematicalProgramResult result;
Solve(prog, {}, {}, &result);
const SolverResult solver_result = result.ConvertToSolverResult();
prog.SetSolverResult(solver_result);
return result.get_solution_result();
}

SolverId LinearSystemSolver::solver_id() const { return id(); }

SolverId LinearSystemSolver::id() {
Expand Down
5 changes: 5 additions & 0 deletions solvers/linear_system_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class LinearSystemSolver : public MathematicalProgramSolverInterface {
/// Find the least-square solution to the linear system A * x = b.
SolutionResult Solve(MathematicalProgram& prog) const override;

void Solve(const MathematicalProgram& prog,
const optional<Eigen::VectorXd>& initial_guess,
const optional<SolverOptions>& solver_options,
MathematicalProgramResult* result) const override;

SolverId solver_id() const override;

/// @return same as MathematicalProgramSolverInterface::solver_id()
Expand Down
10 changes: 10 additions & 0 deletions solvers/mathematical_program_solver_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
#include <Eigen/Core>

#include "drake/common/drake_copyable.h"
#include "drake/solvers/mathematical_program_result.h"
#include "drake/solvers/solution_result.h"
#include "drake/solvers/solver_id.h"
#include "drake/solvers/solver_options.h"
#include "drake/solvers/solver_result.h"

namespace drake {
Expand All @@ -29,6 +31,14 @@ class MathematicalProgramSolverInterface {
// TODO(#2274) Fix NOLINTNEXTLINE(runtime/references).
virtual SolutionResult Solve(MathematicalProgram& prog) const = 0;

/// Solves an optimization program with optional initial guess and solver
/// options. Note that these initial guess and solver options are not written
/// to @p prog.
virtual void Solve(const MathematicalProgram& prog,
const optional<Eigen::VectorXd>& initial_guess,
const optional<SolverOptions>& solver_options,
MathematicalProgramResult* result) const = 0;

/// Returns the identifier of this solver.
virtual SolverId solver_id() const = 0;

Expand Down
6 changes: 6 additions & 0 deletions solvers/moby_lcp_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,12 @@ class MobyLCPSolver : public MathematicalProgramSolverInterface {

SolutionResult Solve(MathematicalProgram& prog) const override;

void Solve(const MathematicalProgram&, const optional<Eigen::VectorXd>&,
const optional<SolverOptions>&,
MathematicalProgramResult*) const override {
throw std::runtime_error("Not implemented yet.");
}

SolverId solver_id() const override;

bool AreProgramAttributesSatisfied(
Expand Down
42 changes: 21 additions & 21 deletions solvers/mosek_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,14 @@ std::shared_ptr<MosekSolver::License> MosekSolver::AcquireLicense() {

bool MosekSolver::is_available() { return true; }

MathematicalProgramResult MosekSolver::SolveConstProg(
const MathematicalProgram& prog) const {
void MosekSolver::Solve(const MathematicalProgram& prog,
const optional<Eigen::VectorXd>& initial_guess,
const optional<SolverOptions>& solver_options,
MathematicalProgramResult* result) const {
// TODO(hongkai.dai): support setting initial guess and solver options.
if (initial_guess.has_value() || solver_options.has_value()) {
throw std::runtime_error("Not implemented yet.");
}
const int num_vars = prog.num_vars();
MSKtask_t task = nullptr;
MSKrescodee rescode;
Expand Down Expand Up @@ -728,8 +734,7 @@ MathematicalProgramResult MosekSolver::SolveConstProg(
}
}

MathematicalProgramResult result;
result.set_solution_result(SolutionResult::kUnknownError);
result->set_solution_result(SolutionResult::kUnknownError);
// Run optimizer.
if (rescode == MSK_RES_OK) {
// TODO([email protected]): add trmcode to the returned struct.
Expand All @@ -753,7 +758,7 @@ MathematicalProgramResult MosekSolver::SolveConstProg(
solution_type = MSK_SOL_ITR;
}

result.set_solver_id(id());
result->set_solver_id(id());
// TODO([email protected]) : Add MOSEK parameters.
// Mosek parameter are added by enum, not by string.
MSKsolstae solution_status{MSK_SOL_STA_UNKNOWN};
Expand All @@ -767,7 +772,7 @@ MathematicalProgramResult MosekSolver::SolveConstProg(
case MSK_SOL_STA_NEAR_OPTIMAL:
case MSK_SOL_STA_INTEGER_OPTIMAL:
case MSK_SOL_STA_NEAR_INTEGER_OPTIMAL: {
result.set_solution_result(SolutionResult::kSolutionFound);
result->set_solution_result(SolutionResult::kSolutionFound);
MSKint32t num_mosek_vars;
rescode = MSK_getnumvar(task, &num_mosek_vars);
DRAKE_ASSERT(rescode == MSK_RES_OK);
Expand All @@ -783,35 +788,35 @@ MathematicalProgramResult MosekSolver::SolveConstProg(
}
}
if (rescode == MSK_RES_OK) {
result.set_x_val(sol_vector);
result->set_x_val(sol_vector);
}
MSKrealt optimal_cost;
rescode = MSK_getprimalobj(task, solution_type, &optimal_cost);
DRAKE_ASSERT(rescode == MSK_RES_OK);
if (rescode == MSK_RES_OK) {
result.set_optimal_cost(optimal_cost);
result->set_optimal_cost(optimal_cost);
}
break;
}
case MSK_SOL_STA_DUAL_INFEAS_CER:
case MSK_SOL_STA_NEAR_DUAL_INFEAS_CER:
result.set_solution_result(SolutionResult::kDualInfeasible);
result->set_solution_result(SolutionResult::kDualInfeasible);
break;
case MSK_SOL_STA_PRIM_INFEAS_CER:
case MSK_SOL_STA_NEAR_PRIM_INFEAS_CER: {
result.set_solution_result(SolutionResult::kInfeasibleConstraints);
result->set_solution_result(SolutionResult::kInfeasibleConstraints);
break;
}
default: {
result.set_solution_result(SolutionResult::kUnknownError);
result->set_solution_result(SolutionResult::kUnknownError);
break;
}
}
}
}

MosekSolverDetails& solver_details =
result.SetSolverDetailsType<MosekSolverDetails>();
result->SetSolverDetailsType<MosekSolverDetails>();
solver_details.rescode = rescode;
solver_details.solution_status = solution_status;
if (rescode == MSK_RES_OK) {
Expand All @@ -820,21 +825,16 @@ MathematicalProgramResult MosekSolver::SolveConstProg(
}
// rescode is not used after this. If in the future, the user wants to call
// more MSK functions after this line, then he/she needs to check if rescode
// is OK. But do not modify result.solution_result_ if rescode is not OK after
// this line.
// is OK. But do not modify result->solution_result_ if rescode is not OK
// after this line.
unused(rescode);

MSK_deletetask(&task);
return result;
}

MathematicalProgramResult MosekSolver::Solve(
const MathematicalProgram& prog) const {
return SolveConstProg(prog);
}

SolutionResult MosekSolver::Solve(MathematicalProgram& prog) const {
const MathematicalProgramResult result = SolveConstProg(prog);
MathematicalProgramResult result;
Solve(prog, {}, {}, &result);
const SolverResult solver_result = result.ConvertToSolverResult();
prog.SetSolverResult(solver_result);
return result.get_solution_result();
Expand Down
9 changes: 4 additions & 5 deletions solvers/mosek_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class MosekSolver : public MathematicalProgramSolverInterface {

static bool is_available();

MathematicalProgramResult Solve(const MathematicalProgram& prog) const;
void Solve(const MathematicalProgram& prog,
const optional<Eigen::VectorXd>& initial_guess,
const optional<SolverOptions>& solver_options,
MathematicalProgramResult* result) const override;

// Todo([email protected]): deprecate Solve with a non-const
// MathematicalProgram.
Expand Down Expand Up @@ -90,10 +93,6 @@ class MosekSolver : public MathematicalProgramSolverInterface {
static std::shared_ptr<License> AcquireLicense();

private:
// TODO(hongkai.dai) remove this function when we remove Solve() with a
// non-cost MathematicalProgram.
MathematicalProgramResult SolveConstProg(
const MathematicalProgram& prog) const;
// Note that this is mutable to allow latching the allocation of mosek_env_
// during the first call of Solve() (which avoids grabbing a Mosek license
// before we know that we actually want one).
Expand Down
Loading

0 comments on commit 83b0e51

Please sign in to comment.