Skip to content

Commit

Permalink
Add type hints, convert Bing/Google to coros
Browse files Browse the repository at this point in the history
- Adds type hints.

- Converts get_bing_links, get_google_links,
  and get_search_links to coroutines, in preparation
  for using aiohttp to fetch Bing URLs.

- Adds "Tie" output when a search method returns no best answer

- Refactors live_show.py message handling
  • Loading branch information
Exaphis committed Jan 5, 2021
1 parent 6c1d1ce commit d496caa
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 129 deletions.
20 changes: 13 additions & 7 deletions hackq_trivia/hq_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json.decoder
import time
from typing import Optional
from datetime import datetime
import os

Expand All @@ -19,7 +20,12 @@ class BearerError(Exception):
"""Raise when bearer token is invalid/expired"""


def next_available_name(base_name):
def next_available_name(base_name: str) -> str:
"""
Finds lowest available file name using .format() to insert numbers (starts at 1).
:param base_name: File name containing format placeholder ({})
:return: File name with lowest number inserted.
"""
num = 1
curr_name = base_name.format(num)
while os.path.exists(curr_name):
Expand All @@ -29,7 +35,7 @@ def next_available_name(base_name):
return curr_name


def init_root_logger():
def init_root_logger() -> None:
import os

class LogFilterColor(logging.Filter):
Expand Down Expand Up @@ -65,7 +71,7 @@ def filter(self, record):
logging.config.dictConfig(log_conf_dict)


def download_nltk_resources():
def download_nltk_resources() -> None:
for resource in ('stopwords', 'averaged_perceptron_tagger', 'punkt'):
nltk.download(resource, raise_on_error=True)

Expand Down Expand Up @@ -103,7 +109,7 @@ def __init__(self):
self.validate_bearer()
self.logger.info('HackQ-Trivia initialized.\n', extra={'pre': colorama.Fore.GREEN})

def validate_bearer(self):
def validate_bearer(self) -> None:
try:
# verify and options args exist to support all versions of pyjwt
# iat/exp is not checked by pyjwt if verify_signature is False
Expand All @@ -127,11 +133,11 @@ def validate_bearer(self):
self.logger.info(f' Issuing time: {iat_local.strftime("%Y-%m-%d %I:%M %p")}')
self.logger.info(f' Expiration time: {exp_local.strftime("%Y-%m-%d %I:%M %p")}')

async def __connect_show(self, uri):
async def __connect_show(self, uri) -> None:
async with LiveShow(self.headers) as show:
await show.connect(uri)

def connect(self):
def connect(self) -> None:
while True:
try:
websocket_uri = self.get_next_show_info()
Expand All @@ -144,7 +150,7 @@ def connect(self):
self.logger.error('Interrupted, exiting...')
break

def get_next_show_info(self):
def get_next_show_info(self) -> Optional[str]:
"""
Gets info of upcoming shows from HQ, prints it out if ShowNextShowInfo is True
:return: The show's WebSocket URI if it is live, else None
Expand Down
97 changes: 52 additions & 45 deletions hackq_trivia/live_show.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from typing import Dict

import aiohttp
import colorama
Expand All @@ -25,72 +26,78 @@ def __init__(self, headers):
self.logger = logging.getLogger(__name__)
self.logger.info('LiveShow initialized.')

async def connect(self, uri):
async def connect(self, uri: str) -> None:
session = aiohttp.ClientSession()

rejoin = True
while rejoin:
async with session.ws_connect(uri, headers=self.headers, heartbeat=5) as ws:
async for msg in ws:
# suppress incorrect type warning for msg in PyCharm
rejoin = await self.handle_msg(msg) # noqa
if msg.type != aiohttp.WSMsgType.TEXT: # noqa
continue
message = json.loads(msg.data) # noqa

await self.handle_msg(message)

rejoin = self.should_rejoin(message)
if rejoin:
break

self.logger.info('Disconnected.')

async def handle_msg(self, msg):
"""
Handles WebSocket frame received from HQ server.
:param msg: Message received by aiohttp
:return: True if the WS connection should be rejoined, False otherwise
"""
if msg.type == aiohttp.WSMsgType.TEXT:
message = json.loads(msg.data)
self.logger.debug(message)
@staticmethod
def should_rejoin(message: Dict) -> bool:
if message['type'] != 'broadcastEnded':
return False

if 'error' in message and message['error'] == 'Auth not valid':
raise ConnectionRefusedError('User ID/Bearer invalid. Please check your settings.ini.')
return message.get('reason', '') == 'You are no longer in the game. Please join again.'

message_type = message['type']
async def handle_msg(self, message: Dict) -> None:
self.logger.debug(message)

if message_type == 'broadcastEnded' and \
message['reason'] == 'You are no longer in the game. Please join again.':
return True
if 'error' in message and message['error'] == 'Auth not valid':
raise ConnectionRefusedError('User ID/Bearer invalid. Please check your settings.ini.')

elif message_type == 'interaction' and self.show_chat and not self.block_chat:
self.logger.info(f'{message["metadata"]["username"]}: {message["metadata"]["message"]}')

elif message_type == 'question':
question = unidecode(message['question'])
choices = [unidecode(choice['text']) for choice in message['answers']]
message_type = message['type']

self.logger.info('\n' * 5)
self.logger.info(f'Question {message["questionNumber"]} out of {message["questionCount"]}')
self.logger.info(question, extra={"pre": colorama.Fore.BLUE})
self.logger.info(f'Choices: {", ".join(choices)}', extra={'pre': colorama.Fore.BLUE})
if message_type == 'broadcastEnded':
if 'reason' in message:
reason = message['reason']
self.logger.info(f'Disconnected: {reason}')
else:
self.logger.info('Disconnected.')

elif message_type == 'interaction' and self.show_chat and not self.block_chat:
self.logger.info(f'{message["metadata"]["username"]}: {message["metadata"]["message"]}')

elif message_type == 'question':
question = unidecode(message['question'])
choices = [unidecode(choice['text']) for choice in message['answers']]

await self.question_handler.answer_question(question, choices)
self.logger.info('\n' * 5)
self.logger.info(f'Question {message["questionNumber"]} out of {message["questionCount"]}')
self.logger.info(question, extra={"pre": colorama.Fore.BLUE})
self.logger.info(f'Choices: {", ".join(choices)}', extra={'pre': colorama.Fore.BLUE})

self.block_chat = True

elif message_type == 'questionSummary' and self.show_question_summary:
question = unidecode(message['question'])
self.logger.info(f'Question summary: {question}', extra={'pre': colorama.Fore.BLUE})
await self.question_handler.answer_question(question, choices)

for answer in message['answerCounts']:
ans_str = unidecode(answer['answer'])
self.block_chat = True

self.logger.info(f'{ans_str}:{answer["count"]}:{answer["correct"]}',
extra={'pre': colorama.Fore.GREEN if answer['correct'] else colorama.Fore.RED})
elif message_type == 'questionSummary' and self.show_question_summary:
question = unidecode(message['question'])
self.logger.info(f'Question summary: {question}', extra={'pre': colorama.Fore.BLUE})

self.logger.info(f'{message["advancingPlayersCount"]} players advancing')
self.logger.info(f'{message["eliminatedPlayersCount"]} players eliminated\n')

elif message_type == 'questionClosed' and self.block_chat:
self.block_chat = False
if self.show_chat:
self.logger.info('\n' * 5)
for answer in message['answerCounts']:
ans_str = unidecode(answer['answer'])

return False
self.logger.info(f'{ans_str}:{answer["count"]}:{answer["correct"]}',
extra={'pre': colorama.Fore.GREEN if answer['correct'] else colorama.Fore.RED})

self.logger.info(f'{message["advancingPlayersCount"]} players advancing')
self.logger.info(f'{message["eliminatedPlayersCount"]} players eliminated\n')

elif message_type == 'questionClosed' and self.block_chat:
self.block_chat = False
if self.show_chat:
self.logger.info('\n' * 5)
62 changes: 25 additions & 37 deletions hackq_trivia/question_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import string
from time import time
from typing import Match
from typing import Dict, List, Match

import nltk
import colorama
Expand All @@ -28,7 +28,7 @@ def __init__(self):
async def close(self):
await self.searcher.close()

async def answer_question(self, question, original_choices):
async def answer_question(self, question: str, original_choices: List[str]):
self.logger.info('Searching...')
start_time = time()

Expand All @@ -50,7 +50,7 @@ async def answer_question(self, question, original_choices):
self.logger.debug(f'Keywords took {round(time() - keyword_start_time, 2)} seconds')

search_start_time = time()
links = self.searcher.get_search_links(' '.join(question_keywords), self.num_sites)
links = await self.searcher.get_search_links(' '.join(question_keywords), self.num_sites)
self.logger.debug(f'Web search took {round(time() - search_start_time, 2)} seconds')
self.logger.debug(f'Found links: {links}')

Expand All @@ -61,16 +61,23 @@ async def answer_question(self, question, original_choices):
self.logger.debug(f'Fetching took {round(time() - fetch_start_time, 2)} seconds')

# Step 3: Find best answer for all search methods
# TODO: async-ify the search methods
post_process_start_time = time()
answers = []
for search_method in self.search_methods_to_use:
self.logger.info(search_method(link_texts, choices, choice_groups, reverse),
extra={'pre': colorama.Fore.BLUE})
answer = await search_method(link_texts, choices, choice_groups, reverse)
answers.append(answer)
if answer:
self.logger.info(answer, extra={'pre': colorama.Fore.BLUE})
else:
self.logger.info('Tie', extra={'pre': colorama.Fore.BLUE})

self.logger.debug(f'Post-processing took {round(time() - post_process_start_time, 2)} seconds')

self.logger.info(f'Search took {round(time() - start_time, 2)} seconds')
return answers

def __method1(self, texts, answers, answer_groups, reverse):
async def __method1(self, texts: List[str], answers: List[str],
answer_groups: List[List[str]], reverse: bool) -> str:
"""
Returns the answer with the best number of exact occurrences in texts.
:param texts: List of webpages (strings) to analyze
Expand All @@ -89,7 +96,8 @@ def __method1(self, texts, answers, answer_groups, reverse):
self.logger.info(counts)
return self.__get_best_answer(counts, answer_groups, reverse)

def __method2(self, texts, answers, answer_groups, reverse):
async def __method2(self, texts: List[str], answers: List[str],
answer_groups: List[List[str]], reverse: bool) -> str:
"""
Returns the answers with the best number of occurrences of the answer's keywords in texts.
:param texts: List of webpages (strings) to analyze
Expand All @@ -109,7 +117,7 @@ def __method2(self, texts, answers, answer_groups, reverse):
self.logger.info(counts)
return self.__get_best_answer(counts, answer_groups, reverse)

def find_keywords(self, text: str, sentences=True):
def find_keywords(self, text: str, sentences: bool = True) -> List[str]:
"""
Returns the keywords from a string containing text, in the order they appear.
Keywords:
Expand Down Expand Up @@ -156,35 +164,15 @@ def process_match(match: Match[str]):
# TODO: handle plural and singular, see test_question_handler.py
return keywords

def find_nouns(self, text, num_words, reverse=False):
    """
    Returns the noun phrases found in a window of the tagged tokens of text.
    Runs of consecutive tokens whose tag contains 'NN' are joined with spaces into one phrase.
    :param text: String to tokenize and tag
    :param num_words: Number of tagged tokens to examine
    :param reverse: If True, examine the last num_words tagged tokens, otherwise the first
    :return: List of noun phrases, in the order they appear in the examined window
    """
    tokens = nltk.word_tokenize(text)
    # Entries tagged 'POS' are discarded before the window is taken
    tags = [tag for tag in self.perceptron_tagger.tag(tokens) if tag[1] != 'POS']

    if not self.simplified_output:
        self.logger.info(tags)

    if not reverse:
        tags = tags[:num_words]
    elif num_words:
        tags = tags[-num_words:]
    else:
        # tags[-0:] is tags[0:] (every tag); a window of zero words must be empty,
        # matching what tags[:0] gives in the non-reversed case
        tags = []

    nouns = []
    consecutive_nouns = []

    for tag in tags:
        word, tag_type = tag[0], tag[1]

        if 'NN' in tag_type:
            consecutive_nouns.append(word)
        elif consecutive_nouns:
            # A non-noun token ends the current run of nouns
            nouns.append(' '.join(consecutive_nouns))
            consecutive_nouns = []

    # Flush a run of nouns that reaches the end of the window
    if consecutive_nouns:
        nouns.append(' '.join(consecutive_nouns))

    return nouns

@staticmethod
def __get_best_answer(all_scores, choice_groups, reverse=False):
def __get_best_answer(all_scores: Dict, choice_groups: List[List[str]], reverse: bool = False):
"""
Returns best answer based on scores for each choice and groups of choices.
:param all_scores: Dict mapping choices to scores
:param choice_groups: List of lists (groups) of choices
:param reverse: If True, return lowest scoring choice group, otherwise return highest
:return: String (first entry in group) of the group with the highest/lowest total score
"""
# Add scores of the same answer together due to two ways of removing punctuation
scores = {choices[0]: sum(all_scores[choice] for choice in choices) for choices in choice_groups}

Expand Down
26 changes: 14 additions & 12 deletions hackq_trivia/searcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
import operator
from html import unescape
from typing import Iterable, List

import aiohttp
import bs4
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(self):
self.session = aiohttp.ClientSession(headers=Searcher.HEADERS, timeout=client_timeout)
self.logger = logging.getLogger(__name__)

async def fetch(self, url):
async def fetch(self, url: str) -> str:
try:
async with self.session.get(url, timeout=self.timeout) as response:
return await response.text()
Expand All @@ -54,24 +54,26 @@ async def fetch(self, url):

return ""

async def fetch_multiple(self, urls):
tasks = [asyncio.create_task(self.fetch(url)) for url in urls]
responses = await asyncio.gather(*tasks)
# no typing info for return value because https://github.com/python/typeshed/issues/2652
async def fetch_multiple(self, urls: Iterable[str]):
coroutines = [self.fetch(url) for url in urls]
responses = await asyncio.gather(*coroutines)
return responses

async def close(self):
async def close(self) -> None:
await self.session.close()

def get_search_links(self, query, num_results):
return self.search_func(query, num_results)
async def get_search_links(self, query: str, num_results: int) -> List[str]:
return await self.search_func(query, num_results)

def get_google_links(self, query, num_results):
async def get_google_links(self, query: str, num_results: int) -> List[str]:
response = self.google_service.cse().list(q=query, cx=self.google_cse_id, num=num_results).execute()
self.logger.debug(f'google: {query}, n={num_results}')
self.logger.debug(response)
return list(map(operator.itemgetter('link'), response['items']))

def get_bing_links(self, query, num_results):
return [item['link'] for item in response['items']]

async def get_bing_links(self, query: str, num_results: int) -> List[str]:
# could be using aiohttp here...
search_params = {'q': query, 'count': num_results}
resp = requests.get(self.BING_ENDPOINT, headers=self.bing_headers, params=search_params)
Expand All @@ -85,7 +87,7 @@ def get_bing_links(self, query, num_results):
self.logger.debug(f'bing: {query}, n={num_results}')
self.logger.debug(resp_data)

return list(map(operator.itemgetter('url'), resp_data['webPages']['value']))
return [item['url'] for item in resp_data['webPages']['value']]

@staticmethod
def html_to_visible_text(html):
Expand Down
Loading

0 comments on commit d496caa

Please sign in to comment.