Skip to content

Commit

Permalink
add channel and group support to conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
rking32 committed Aug 13, 2020
1 parent 4cda2e6 commit fc46fcf
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
8 changes: 6 additions & 2 deletions userge/core/methods/chats/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
class Conversation: # pylint: disable=missing-class-docstring
def conversation(self,
chat_id: Union[str, int],
*, timeout: Union[int, float] = 10,
*, user_id: Union[str, int] = 0,
timeout: Union[int, float] = 10,
limit: int = 10) -> 'types.new.Conversation':
"""\nThis returns new conversation object.
Expand All @@ -30,6 +31,9 @@ def conversation(self,
For a contact that exists in your Telegram address book
you can use his phone number (str).
user_id (``int`` | ``str`` | , *optional*):
define a specific user in this chat.
timeout (``int`` | ``float`` | , *optional*):
set conversation timeout.
defaults to 10.
Expand All @@ -38,4 +42,4 @@ def conversation(self,
set conversation message limit.
defaults to 10.
"""
return types.new.Conversation(self, chat_id, timeout, limit)
return types.new.Conversation(self, chat_id, user_id, timeout, limit)
43 changes: 32 additions & 11 deletions userge/core/types/new/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
__all__ = ['Conversation']

import asyncio
from typing import Union, Dict, Optional
from typing import Union, Dict, Tuple, Optional

from pyrogram import Message as RawMessage, Filters, MessageHandler

Expand All @@ -22,7 +22,7 @@
_LOG = logging.getLogger(__name__)
_LOG_STR = "<<<! ::::: %s ::::: !>>>"

_CONV_DICT: Dict[int, asyncio.Queue] = {}
_CONV_DICT: Dict[int, Union[asyncio.Queue, Tuple[int, asyncio.Queue]]] = {}


class _MsgLimitReached(Exception):
Expand All @@ -33,14 +33,17 @@ class Conversation:
""" Conversation class for userge """
def __init__(self,
client: '_client.Userge',
chat: Union[str, int],
user: Union[str, int],
timeout: Union[int, float],
limit: int) -> None:
self._client = client
self._chat = chat
self._user = user
self._timeout = timeout
self._limit = limit
self._chat_id: int
self._user_id: int
self._count = 0

@property
Expand All @@ -65,8 +68,10 @@ async def get_response(self, *, timeout: Union[int, float] = 0,
"""
if self._count >= self._limit:
raise _MsgLimitReached
response_ = await asyncio.wait_for(_CONV_DICT[self._chat_id].get(),
timeout or self._timeout)
queue = _CONV_DICT[self._chat_id]
if isinstance(queue, tuple):
queue = queue[1]
response_ = await asyncio.wait_for(queue.get(), timeout or self._timeout)
self._count += 1
if mark_read:
await self.mark_read(response_)
Expand Down Expand Up @@ -136,23 +141,39 @@ async def forward_message(self, message: RawMessage) -> RawMessage:
def init(client: '_client.Userge') -> None:
""" initialize the conversation method """
async def _on_conversation(_, msg: RawMessage) -> None:
_CONV_DICT[msg.from_user.id].put_nowait(msg)
data = _CONV_DICT[msg.chat.id]
if isinstance(data, asyncio.Queue):
data.put_nowait(msg)
elif msg.from_user and msg.from_user.id == data[0]:
data[1].put_nowait(msg)
msg.continue_propagation()
client.add_handler(
MessageHandler(
_on_conversation,
Filters.create(
lambda _, query: _CONV_DICT and query.from_user
and query.from_user.id in _CONV_DICT)), 0)
lambda _, query: _CONV_DICT and query.chat
and query.chat.id in _CONV_DICT, 0)))

async def __aenter__(self) -> 'Conversation':
self._chat_id = int(self._user) if isinstance(self._user, int) else \
(await self._client.get_users(self._user)).id
_CONV_DICT[self._chat_id] = asyncio.Queue(self._limit)
self._chat_id = int(self._chat) if isinstance(self._chat, int) else \
(await self._client.get_chats(self._chat)).id
if self._chat_id in _CONV_DICT:
error = f"already started conversation with {self._chat_id} !"
_LOG.error(_LOG_STR, error)
raise StopConversation(error)
if self._user:
self._user_id = int(self._user) if isinstance(self._user, int) else \
(await self._client.get_users(self._user)).id
_CONV_DICT[self._chat_id] = (self._user_id, asyncio.Queue(self._limit))
else:
_CONV_DICT[self._chat_id] = asyncio.Queue(self._limit)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
_CONV_DICT[self._chat_id].put_nowait(None)
queue = _CONV_DICT[self._chat_id]
if isinstance(queue, tuple):
queue = queue[1]
queue.put_nowait(None)
del _CONV_DICT[self._chat_id]
error = ''
if isinstance(exc_val, asyncio.exceptions.TimeoutError):
Expand Down

0 comments on commit fc46fcf

Please sign in to comment.