Skip to content

Commit

Permalink
fix bot pm + pool freeze issue
Browse files Browse the repository at this point in the history
  • Loading branch information
rking32 committed Jul 24, 2021
1 parent 6ca203b commit 9e2e6cb
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 108 deletions.
25 changes: 20 additions & 5 deletions userge/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
from typing import List, Awaitable, Any, Optional, Union

from pyrogram import idle
from pyrogram.types import User

from userge import logging, Config, logbot
from userge.utils import time_formatter
from userge.utils.exceptions import UsergeBotNotFound
from userge.plugins import get_all_plugins
from .methods import Methods
from .ext import RawClient, pool
from .database import get_collection, _close_db
from .database import get_collection

_LOG = logging.getLogger(__name__)
_LOG_STR = "<<<! ##### %s ##### !>>>"
Expand Down Expand Up @@ -63,8 +64,13 @@ async def _complete_init_tasks() -> None:


class _AbstractUserge(Methods, RawClient):
def __init__(self, **kwargs) -> None:
self._me: Optional[User] = None
super().__init__(**kwargs)

@property
def id(self) -> int:
""" returns client id """
if self.is_bot:
return RawClient.BOT_ID
return RawClient.USER_ID
Expand Down Expand Up @@ -129,6 +135,19 @@ async def reload_plugins(self) -> int:
await self.finalize_load()
return len(reloaded)

async def get_me(self, cached: bool = True) -> User:
if not cached or self._me is None:
self._me = await super().get_me()
return self._me

async def start(self):
await super().start()
self._me = await self.get_me()
if self.is_bot:
RawClient.BOT_ID = self._me.id
else:
RawClient.USER_ID = self._me.id

def __eq__(self, o: object) -> bool:
return isinstance(o, _AbstractUserge) and self.id == o.id

Expand Down Expand Up @@ -167,8 +186,6 @@ def __init__(self, **kwargs) -> None:
kwargs['bot'] = UsergeBot(bot=self, **kwargs)
kwargs['session_name'] = Config.HU_STRING_SESSION or ":memory:"
super().__init__(**kwargs)
self.executor.shutdown()
self.executor = pool._get() # pylint: disable=protected-access

@property
def dual_mode(self) -> bool:
Expand Down Expand Up @@ -215,7 +232,6 @@ async def stop(self) -> None: # pylint: disable=arguments-differ
_LOG.info(_LOG_STR, "Stopping Userge")
await super().stop()
await _set_running(False)
_close_db()
pool._stop() # pylint: disable=protected-access

def begin(self, coro: Optional[Awaitable[Any]] = None) -> None:
Expand All @@ -237,7 +253,6 @@ async def _finalize() -> None:
if self.is_initialized:
await self.stop()
else:
_close_db()
pool._stop() # pylint: disable=protected-access
# pylint: disable=expression-not-assigned
[t.cancel() for t in asyncio.all_tasks() if t is not asyncio.current_task()]
Expand Down
4 changes: 0 additions & 4 deletions userge/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,4 @@ def get_collection(name: str) -> AgnosticCollection:
return _DATABASE[name]


def _close_db() -> None:
_MGCLIENT.close()


logbot.del_last_msg()
9 changes: 2 additions & 7 deletions userge/core/ext/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
from concurrent.futures import ThreadPoolExecutor, Future
from functools import wraps, partial

from motor.frameworks.asyncio import _EXECUTOR # pylint: disable=protected-access

from userge import logging
from userge import logging, Config

_LOG = logging.getLogger(__name__)
_LOG_STR = "<<<! |||| %s |||| !>>>"
_EXECUTOR = ThreadPoolExecutor(Config.WORKERS)


def submit_thread(func: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Future:
Expand All @@ -37,10 +36,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper


def _get() -> ThreadPoolExecutor:
return _EXECUTOR


def _stop():
_EXECUTOR.shutdown()
# pylint: disable=protected-access
Expand Down
25 changes: 3 additions & 22 deletions userge/core/methods/decorators/raw_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

_CH_LKS: Dict[str, asyncio.Lock] = {}
_CH_LKS_LK = asyncio.Lock()
_INIT_LK = asyncio.Lock()


async def _update_u_cht(r_m: RawMessage) -> Optional[ChatMember]:
Expand Down Expand Up @@ -91,30 +90,12 @@ def _clear_cht() -> None:
_TASK_1_START_TO = time.time()


async def _init(r_c: Union['_client.Userge', '_client.UsergeBot'],
r_m: RawMessage, is_bot: bool) -> None:
async def _init(r_m: RawMessage) -> None:
if r_m.from_user and (
r_m.from_user.is_self or (
r_m.from_user.id in Config.SUDO_USERS) or (
r_m.from_user.id in Config.OWNER_ID)):
RawClient.LAST_OUTGOING_TIME = time.time()
async with _INIT_LK:
if RawClient.DUAL_MODE:
if RawClient.USER_ID and RawClient.BOT_ID:
return
else:
if RawClient.USER_ID or RawClient.BOT_ID:
return
if is_bot:
if not RawClient.BOT_ID:
RawClient.BOT_ID = (await r_c.get_me()).id
if RawClient.DUAL_MODE and not RawClient.USER_ID:
RawClient.USER_ID = (await r_c.ubot.get_me()).id
else:
if not RawClient.USER_ID:
RawClient.USER_ID = (await r_c.get_me()).id
if RawClient.DUAL_MODE and not RawClient.BOT_ID:
RawClient.BOT_ID = (await r_c.bot.get_me()).id


async def _raise_func(r_c: Union['_client.Userge', '_client.UsergeBot'],
Expand Down Expand Up @@ -245,13 +226,13 @@ async def template(r_c: Union['_client.Userge', '_client.UsergeBot'],
return
if r_m.chat and r_m.chat.id in Config.DISABLED_CHATS:
return
is_bot = r_c.is_bot
await _init(r_c, r_m, is_bot)
await _init(r_m)
_raise = partial(_raise_func, r_c, r_m)
if r_m.chat and r_m.chat.type not in flt.scope:
if isinstance(flt, types.raw.Command):
await _raise(f"`invalid chat type [{r_m.chat.type}]`")
return
is_bot = r_c.is_bot
if r_m.chat and flt.only_admins and not await _is_admin(r_m, is_bot):
if isinstance(flt, types.raw.Command):
await _raise("`chat admin required`")
Expand Down
Loading

0 comments on commit 9e2e6cb

Please sign in to comment.