Skip to content

Commit

Permalink
like dislike saving
Browse files Browse the repository at this point in the history
  • Loading branch information
krflorian committed Feb 4, 2024
1 parent 3ccd1ca commit 78edf78
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 6 deletions.
3 changes: 2 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gradio
import yaml
from pathlib import Path

from mtg.bot import MagicGPT
from mtg.history.chat_history import ChatHistory
Expand Down Expand Up @@ -38,7 +39,7 @@ def get_magic_bot() -> MagicGPT:
model=MODEL_NAME,
temperature_deck_building=0.7,
max_token_limit=1000,
data_filepath=config.get("message_filepath", "data/messages"),
data_filepath=Path(config.get("message_filepath", "data/messages")),
)
return magic_bot

Expand Down
4 changes: 2 additions & 2 deletions mtg/bot/magic_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def __init__(
)

self.data_filepath: Path = data_filepath
(self.data_filepath / "liked").mkdir(exist_ok=True)
(self.data_filepath / "disliked").mkdir(exist_ok=True)
(self.data_filepath / "liked").mkdir(exist_ok=True, parents=True)
(self.data_filepath / "disliked").mkdir(exist_ok=True, parents=True)

self.chat_history: ChatHistory = chat_history
self.memory = memory
Expand Down
15 changes: 14 additions & 1 deletion mtg/history/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def create_message(self, text: str, message_type: MessageType) -> Message:
)

# add additional cards
if message_type == "deckbuilding":
if message_type == MessageType.DECKBUILDING:
message = self.add_additional_cards(message=message, max_number_of_cards=10)

logger.info(
Expand All @@ -205,6 +205,7 @@ def add_additional_cards(
threshold: float = 0.5,
lasso_threshold: float = 0.03,
) -> Message:

additional_cards = []
for card in message.cards:
# for each card in message get max_number_of_cards
Expand All @@ -218,13 +219,25 @@ def add_additional_cards(
)
)

# from message
additional_cards.extend(
self.data_service.get_cards(
message.text,
k=max(10, max_number_of_cards),
threshold=threshold,
lasso_threshold=lasso_threshold,
sample_results=True,
)
)

# choose cards
cards = []
for card in additional_cards:
if card not in message.cards and not card in cards:
cards.append(card)
cards = random.choices(cards, k=min(len(cards), max_number_of_cards))

# add additional cards
message.cards.extend(cards)
logger.debug(f"added {len(cards)} additional cards")

Expand Down
3 changes: 3 additions & 0 deletions mtg/objects/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class Card(BaseModel):
def __repr__(self) -> str:
return f"Card({self.name})"

def to_dict(self) -> dict:
return self.model_dump()

def to_text(self, include_price: bool = True):
"""parse card data to text format"""
text = []
Expand Down
3 changes: 3 additions & 0 deletions mtg/objects/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ class Document(BaseModel):

def __repr__(self):
return f"Document({self.name})"

def to_dict(self):
return self.model_dump()
4 changes: 2 additions & 2 deletions mtg/objects/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def to_dict(self):
return {
"text": self.text,
"type": str(self.type),
"cards": [asdict(card) for card in self.cards],
"rules": [asdict(rule) for rule in self.rules],
"cards": [card.to_dict() for card in self.cards],
"rules": [rule.to_dict() for rule in self.rules],
}

0 comments on commit 78edf78

Please sign in to comment.