Skip to content

Commit

Permalink
Begin re-factor of URLs
Browse files Browse the repository at this point in the history
* Move API_BASE to common
* Move endpoints to model source
  • Loading branch information
adithyabsk committed Apr 13, 2020
1 parent 0a80a05 commit 247b245
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 57 deletions.
5 changes: 5 additions & 0 deletions pyrh/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,10 @@

from typing import Any, Dict

from yarl import URL


JSON = Dict[str, Any]

API_BASE = URL("https://api.robinhood.com")
"""Base robinhood api endpoint."""
74 changes: 30 additions & 44 deletions pyrh/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
"""Define Robinhood endpoints."""

from typing import Callable

from yarl import URL


BASE = URL("https://api.robinhood.com")

# OAuth
OAUTH: URL = BASE.with_path("/oauth2/token/")
OAUTH_REVOKE: URL = BASE.with_path("/oauth2/revoke_token/")
CHALLENGE: Callable[[str], URL] = lambda cid: BASE.with_path(
f"/challenge/{cid}/respond/"
)


def logout():
return BASE.with_path("/oauth2/revoke_token/")
from .common import API_BASE


def investment_profile():
return BASE.with_path("/user/investment_profile/")
return API_BASE.with_path("/user/investment_profile/")


def accounts():
return BASE.with_path("/accounts/")
return API_BASE.with_path("/accounts/")


def ach(option):
Expand All @@ -35,30 +19,30 @@ def ach(option):
* transfers
"""
return (
BASE.with_path("/ach/iav/auth/")
API_BASE.with_path("/ach/iav/auth/")
if option == "iav"
else BASE.with_path(f"/ach/{option}/")
else API_BASE.with_path(f"/ach/{option}/")
)


def applications():
return BASE.with_path("/applications/")
return API_BASE.with_path("/applications/")


def dividends():
return BASE.with_path("/dividends/")
return API_BASE.with_path("/dividends/")


def edocuments():
return BASE.with_path("/documents/")
return API_BASE.with_path("/documents/")


def instruments(instrument_id=None, option=None):
"""
Return information about a specific instrument by providing its instrument id.
Add extra options for additional information such as "popularity"
"""
url = BASE.with_path(f"/instruments/")
url = API_BASE.with_path(f"/instruments/")
if instrument_id is not None:
url += f"{instrument_id}"
if option is not None:
Expand All @@ -68,82 +52,84 @@ def instruments(instrument_id=None, option=None):


def margin_upgrades():
return BASE.with_path("/margin/upgrades/")
return API_BASE.with_path("/margin/upgrades/")


def markets():
return BASE.with_path("/markets/")
return API_BASE.with_path("/markets/")


def notifications():
return BASE.with_path("/notifications/")
return API_BASE.with_path("/notifications/")


def orders(order_id=""):
return BASE.with_path(f"/orders/{order_id}/")
return API_BASE.with_path(f"/orders/{order_id}/")


def password_reset():
return BASE.with_path("/password_reset/request/")
return API_BASE.with_path("/password_reset/request/")


def portfolios():
return BASE.with_path("/portfolios/")
return API_BASE.with_path("/portfolios/")


def positions():
return BASE.with_path("/positions/")
return API_BASE.with_path("/positions/")


def quotes():
return BASE.with_path("/quotes/")
return API_BASE.with_path("/quotes/")


def historicals():
return BASE.with_path("/quotes/historicals/")
return API_BASE.with_path("/quotes/historicals/")


def document_requests():
return BASE.with_path("/upload/document_requests/")
return API_BASE.with_path("/upload/document_requests/")


def user():
return BASE.with_path("/user/")
return API_BASE.with_path("/user/")


def watchlists():
return BASE.with_path("/watchlists/")
return API_BASE.with_path("/watchlists/")


def news(stock):
return BASE.with_path(f"/midlands/news/{stock}/")
return API_BASE.with_path(f"/midlands/news/{stock}/")


def fundamentals(stock):
return BASE.with_path(f"/fundamentals/{stock}/")
return API_BASE.with_path(f"/fundamentals/{stock}/")


def tags(tag):
"""
Returns endpoint with tag concatenated.
"""
return BASE.with_path(f"/midlands/tags/tag/{tag}/")
return API_BASE.with_path(f"/midlands/tags/tag/{tag}/")


def chain(instrument_id):
return BASE.with_path(f"/options/chains/?equity_instrument_ids={instrument_id}/")
return API_BASE.with_path(
f"/options/chains/?equity_instrument_ids={instrument_id}/"
)


def options(chain_id, dates, option_type):
return BASE.with_path(
return API_BASE.with_path(
f"/options/instruments/?chain_id={chain_id}&expiration_dates={dates}"
f"&state=active&tradability=tradable&type={option_type}"
)


def market_data(option_id):
return BASE.with_path(f"/marketdata/options/{option_id}/")
return API_BASE.with_path(f"/marketdata/options/{option_id}/")


def convert_token():
return BASE.with_path("/oauth2/migrate_token/")
return API_BASE.with_path("/oauth2/migrate_token/")
33 changes: 24 additions & 9 deletions pyrh/models/sessionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
import uuid
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Union,
cast,
overload,
)
from urllib.request import getproxies

import certifi
Expand All @@ -16,9 +25,8 @@
from typing_extensions import Literal
from yarl import URL

from pyrh import endpoints
from pyrh.cache import CACHE_ROOT
from pyrh.common import JSON
from pyrh.common import API_BASE, JSON
from pyrh.exceptions import AuthenticationError

from .base import BaseModel, BaseSchema
Expand All @@ -37,6 +45,13 @@
else:
CaseInsensitiveDictType = CaseInsensitiveDict

# Endpoints
OAUTH: URL = API_BASE.with_path("/oauth2/token/")
OAUTH_REVOKE: URL = API_BASE.with_path("/oauth2/revoke_token/")
CHALLENGE: Callable[[str], URL] = lambda cid: API_BASE.with_path(
f"/challenge/{cid}/respond/"
)

Proxies = Dict[str, str]
HEADERS: CaseInsensitiveDictType = CaseInsensitiveDict(
{
Expand Down Expand Up @@ -356,7 +371,7 @@ def _challenge_oauth2(self, oauth: OAuth, oauth_payload: JSON) -> OAuth:
"""
# login challenge
challenge_url = endpoints.CHALLENGE(oauth.challenge.id)
challenge_url = CHALLENGE(oauth.challenge.id)
print(
f"Input challenge code from {oauth.challenge.type.capitalize()} "
f"({oauth.challenge.remaining_attempts}/"
Expand All @@ -379,7 +394,7 @@ def _challenge_oauth2(self, oauth: OAuth, oauth_payload: JSON) -> OAuth:
if res.status_code == requests.codes.ok:
try:
res2 = self.post(
endpoints.OAUTH,
OAUTH,
data=oauth_payload,
headers=challenge_header,
auto_login=False,
Expand Down Expand Up @@ -416,7 +431,7 @@ def _mfa_oauth2(self, oauth_payload: JSON, attempts: int = 3) -> OAuth:
mfa_code = input()
oauth_payload["mfa_code"] = mfa_code
res = self.post(
endpoints.OAUTH,
OAUTH,
data=oauth_payload,
raise_errors=False,
auto_login=False,
Expand Down Expand Up @@ -454,7 +469,7 @@ def _login_oauth2(self) -> None:
}

res = self.post(
endpoints.OAUTH,
OAUTH,
data=oauth_payload,
raise_errors=False,
auto_login=False,
Expand Down Expand Up @@ -493,7 +508,7 @@ def _refresh_oauth2(self) -> None:
}
self.session.headers.pop("Authorization", None)
try:
res = self.post(endpoints.OAUTH, data=relogin_payload, auto_login=False)
res = self.post(OAUTH, data=relogin_payload, auto_login=False)
except HTTPError:
raise AuthenticationError("Failed to refresh token")

Expand All @@ -509,7 +524,7 @@ def logout(self) -> None:
"""
logout_payload = {"client_id": CLIENT_ID, "token": self.oauth.refresh_token}
try:
self.post(endpoints.OAUTH_REVOKE, data=logout_payload, auto_login=False)
self.post(OAUTH_REVOKE, data=logout_payload, auto_login=False)
self.oauth = OAuth()
self.session.headers.pop("Authorization", None)
except HTTPError:
Expand Down
7 changes: 3 additions & 4 deletions tests/test_sessionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def sm_adap(monkeypatch):
"password": "some password",
}

monkeypatch.setattr("pyrh.endpoints.OAUTH", MOCK_URL)
monkeypatch.setattr("pyrh.endpoints.OAUTH_REVOKE", MOCK_URL)
monkeypatch.setattr("pyrh.endpoints.CHALLENGE", lambda x: MOCK_URL)
monkeypatch.setattr("pyrh.models.sessionmanager.OAUTH", MOCK_URL)
monkeypatch.setattr("pyrh.models.sessionmanager.OAUTH_REVOKE", MOCK_URL)
monkeypatch.setattr("pyrh.models.sessionmanager.CHALLENGE", lambda x: MOCK_URL)

session_manager = SessionManager(**sample_user)
adapter = requests_mock.Adapter()
Expand Down Expand Up @@ -71,7 +71,6 @@ def test_login_oauth2_errors(monkeypatch, sm_adap):
# oauth from the mfa approaches as those individual functions will error
# out themselves

monkeypatch.setattr("pyrh.endpoints.OAUTH", MOCK_URL)
adapter.register_uri(
"POST", MOCK_URL, text='{"error": "Some error"}', status_code=400
)
Expand Down

0 comments on commit 247b245

Please sign in to comment.