Skip to content

Commit a9525aa

Browse files
authored
_
1 parent 3380b79 commit a9525aa

File tree

1 file changed

+175
-82
lines changed

1 file changed

+175
-82
lines changed

helper/FastTelethon.py

+175-82
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,53 @@
99
import math
1010
import os
1111
from collections import defaultdict
12-
from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple, BinaryIO
13-
14-
from telethon import utils, helpers, TelegramClient
12+
from typing import (
13+
AsyncGenerator,
14+
Awaitable,
15+
BinaryIO,
16+
DefaultDict,
17+
List,
18+
Optional,
19+
Tuple,
20+
Union,
21+
)
22+
23+
from telethon import TelegramClient, helpers, utils
1524
from telethon.crypto import AuthKey
1625
from telethon.network import MTProtoSender
1726
from telethon.tl.alltlobjects import LAYER
1827
from telethon.tl.functions import InvokeWithLayerRequest
19-
from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
20-
from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
21-
SaveBigFilePartRequest)
22-
from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
23-
InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
24-
InputFileBig, InputFile)
28+
from telethon.tl.functions.auth import (
29+
ExportAuthorizationRequest,
30+
ImportAuthorizationRequest,
31+
)
32+
from telethon.tl.functions.upload import (
33+
GetFileRequest,
34+
SaveBigFilePartRequest,
35+
SaveFilePartRequest,
36+
)
37+
from telethon.tl.types import (
38+
Document,
39+
InputDocumentFileLocation,
40+
InputFile,
41+
InputFileBig,
42+
InputFileLocation,
43+
InputPeerPhotoFileLocation,
44+
InputPhotoFileLocation,
45+
TypeInputFile,
46+
)
2547

2648
filename = ""
2749

28-
async_encrypt_attachment = None
29-
30-
log: logging.Logger = logging.getLogger("telethon")
50+
log: logging.Logger = logging.getLogger("FastTelethon")
3151

32-
TypeLocation = Union[Document, InputDocumentFileLocation, InputPeerPhotoFileLocation,
33-
InputFileLocation, InputPhotoFileLocation]
52+
TypeLocation = Union[
53+
Document,
54+
InputDocumentFileLocation,
55+
InputPeerPhotoFileLocation,
56+
InputFileLocation,
57+
InputPhotoFileLocation,
58+
]
3459

3560

3661
class DownloadSender:
@@ -40,8 +65,16 @@ class DownloadSender:
4065
remaining: int
4166
stride: int
4267

43-
def __init__(self, client: TelegramClient, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
44-
stride: int, count: int) -> None:
68+
def __init__(
69+
self,
70+
client: TelegramClient,
71+
sender: MTProtoSender,
72+
file: TypeLocation,
73+
offset: int,
74+
limit: int,
75+
stride: int,
76+
count: int,
77+
) -> None:
4578
self.sender = sender
4679
self.client = client
4780
self.request = GetFileRequest(file, offset=offset, limit=limit)
@@ -69,9 +102,17 @@ class UploadSender:
69102
previous: Optional[asyncio.Task]
70103
loop: asyncio.AbstractEventLoop
71104

72-
def __init__(self, client: TelegramClient, sender: MTProtoSender, file_id: int, part_count: int, big: bool,
73-
index: int,
74-
stride: int, loop: asyncio.AbstractEventLoop) -> None:
105+
def __init__(
106+
self,
107+
client: TelegramClient,
108+
sender: MTProtoSender,
109+
file_id: int,
110+
part_count: int,
111+
big: bool,
112+
index: int,
113+
stride: int,
114+
loop: asyncio.AbstractEventLoop,
115+
) -> None:
75116
self.client = client
76117
self.sender = sender
77118
self.part_count = part_count
@@ -90,8 +131,6 @@ async def next(self, data: bytes) -> None:
90131

91132
async def _next(self, data: bytes) -> None:
92133
self.request.bytes = data
93-
log.debug(f"Sending file part {self.request.file_part}/{self.part_count}"
94-
f" with {len(data)} bytes")
95134
await self.client._call(self.sender, self.request)
96135
self.request.file_part += self.stride
97136

@@ -113,8 +152,11 @@ def __init__(self, client: TelegramClient, dc_id: Optional[int] = None) -> None:
113152
self.client = client
114153
self.loop = self.client.loop
115154
self.dc_id = dc_id or self.client.session.dc_id
116-
self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
117-
else self.client.session.auth_key)
155+
self.auth_key = (
156+
None
157+
if dc_id and self.client.session.dc_id != dc_id
158+
else self.client.session.auth_key
159+
)
118160
self.senders = None
119161
self.upload_ticker = 0
120162

@@ -123,14 +165,16 @@ async def _cleanup(self) -> None:
123165
self.senders = None
124166

125167
@staticmethod
126-
def _get_connection_count(file_size: int, max_count: int = 20,
127-
full_size: int = 100 * 1024 * 1024) -> int:
168+
def _get_connection_count(
169+
file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
170+
) -> int:
128171
if file_size > full_size:
129172
return max_count
130173
return math.ceil((file_size / full_size) * max_count)
131174

132-
async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
133-
part_size: int) -> None:
175+
async def _init_download(
176+
self, connections: int, file: TypeLocation, part_count: int, part_size: int
177+
) -> None:
134178
minimum, remainder = divmod(part_count, connections)
135179

136180
def get_part_count() -> int:
@@ -143,52 +187,93 @@ def get_part_count() -> int:
143187
# The first cross-DC sender will export+import the authorization, so we always create it
144188
# before creating any other senders.
145189
self.senders = [
146-
await self._create_download_sender(file, 0, part_size, connections * part_size,
147-
get_part_count()),
190+
await self._create_download_sender(
191+
file, 0, part_size, connections * part_size, get_part_count()
192+
),
148193
*await asyncio.gather(
149-
*[self._create_download_sender(file, i, part_size, connections * part_size,
150-
get_part_count())
151-
for i in range(1, connections)])
194+
*[
195+
self._create_download_sender(
196+
file, i, part_size, connections * part_size, get_part_count()
197+
)
198+
for i in range(1, connections)
199+
]
200+
),
152201
]
153202

154-
async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
155-
stride: int,
156-
part_count: int) -> DownloadSender:
157-
return DownloadSender(self.client, await self._create_sender(), file, index * part_size, part_size,
158-
stride, part_count)
159-
160-
async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
161-
) -> None:
203+
async def _create_download_sender(
204+
self,
205+
file: TypeLocation,
206+
index: int,
207+
part_size: int,
208+
stride: int,
209+
part_count: int,
210+
) -> DownloadSender:
211+
return DownloadSender(
212+
self.client,
213+
await self._create_sender(),
214+
file,
215+
index * part_size,
216+
part_size,
217+
stride,
218+
part_count,
219+
)
220+
221+
async def _init_upload(
222+
self, connections: int, file_id: int, part_count: int, big: bool
223+
) -> None:
162224
self.senders = [
163225
await self._create_upload_sender(file_id, part_count, big, 0, connections),
164226
*await asyncio.gather(
165-
*[self._create_upload_sender(file_id, part_count, big, i, connections)
166-
for i in range(1, connections)])
227+
*[
228+
self._create_upload_sender(file_id, part_count, big, i, connections)
229+
for i in range(1, connections)
230+
]
231+
),
167232
]
168233

169-
async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
170-
stride: int) -> UploadSender:
171-
return UploadSender(self.client, await self._create_sender(), file_id, part_count, big, index, stride,
172-
loop=self.loop)
234+
async def _create_upload_sender(
235+
self, file_id: int, part_count: int, big: bool, index: int, stride: int
236+
) -> UploadSender:
237+
return UploadSender(
238+
self.client,
239+
await self._create_sender(),
240+
file_id,
241+
part_count,
242+
big,
243+
index,
244+
stride,
245+
loop=self.loop,
246+
)
173247

174248
async def _create_sender(self) -> MTProtoSender:
175249
dc = await self.client._get_dc(self.dc_id)
176250
sender = MTProtoSender(self.auth_key, loggers=self.client._log)
177-
await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
178-
loggers=self.client._log,
179-
proxy=self.client._proxy))
251+
await sender.connect(
252+
self.client._connection(
253+
dc.ip_address,
254+
dc.port,
255+
dc.id,
256+
loggers=self.client._log,
257+
proxy=self.client._proxy,
258+
)
259+
)
180260
if not self.auth_key:
181-
log.debug(f"Exporting auth to DC {self.dc_id}")
182261
auth = await self.client(ExportAuthorizationRequest(self.dc_id))
183-
self.client._init_request.query = ImportAuthorizationRequest(id=auth.id,
184-
bytes=auth.bytes)
262+
self.client._init_request.query = ImportAuthorizationRequest(
263+
id=auth.id, bytes=auth.bytes
264+
)
185265
req = InvokeWithLayerRequest(LAYER, self.client._init_request)
186266
await sender.send(req)
187267
self.auth_key = sender.auth_key
188268
return sender
189269

190-
async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
191-
connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
270+
async def init_upload(
271+
self,
272+
file_id: int,
273+
file_size: int,
274+
part_size_kb: Optional[float] = None,
275+
connection_count: Optional[int] = None,
276+
) -> Tuple[int, int, bool]:
192277
connection_count = connection_count or self._get_connection_count(file_size)
193278
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
194279
part_count = (file_size + part_size - 1) // part_size
@@ -203,14 +288,16 @@ async def upload(self, part: bytes) -> None:
203288
async def finish_upload(self) -> None:
204289
await self._cleanup()
205290

206-
async def download(self, file: TypeLocation, file_size: int,
207-
part_size_kb: Optional[float] = None,
208-
connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
291+
async def download(
292+
self,
293+
file: TypeLocation,
294+
file_size: int,
295+
part_size_kb: Optional[float] = None,
296+
connection_count: Optional[int] = None,
297+
) -> AsyncGenerator[bytes, None]:
209298
connection_count = connection_count or self._get_connection_count(file_size)
210299
part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024
211300
part_count = math.ceil(file_size / part_size)
212-
log.debug("Starting parallel download: "
213-
f"{connection_count} {part_size} {part_count} {file!s}")
214301
await self._init_download(connection_count, file, part_count, part_size)
215302

216303
part = 0
@@ -224,13 +311,12 @@ async def download(self, file: TypeLocation, file_size: int,
224311
break
225312
yield data
226313
part += 1
227-
log.debug(f"Part {part} downloaded")
228-
229-
log.debug("Parallel download finished, cleaning up connections")
230314
await self._cleanup()
231315

232316

233-
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
317+
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(
318+
lambda: asyncio.Lock()
319+
)
234320

235321

236322
def stream_file(file_to_stream: BinaryIO, chunk_size=1024):
@@ -241,10 +327,9 @@ def stream_file(file_to_stream: BinaryIO, chunk_size=1024):
241327
yield data_read
242328

243329

244-
async def _internal_transfer_to_telegram(client: TelegramClient,
245-
response: BinaryIO,
246-
progress_callback: callable
247-
) -> Tuple[TypeInputFile, int]:
330+
async def _internal_transfer_to_telegram(
331+
client: TelegramClient, response: BinaryIO, progress_callback: callable
332+
) -> Tuple[TypeInputFile, int]:
248333
file_id = helpers.generate_random_long()
249334
file_size = os.path.getsize(response.name)
250335

@@ -256,7 +341,10 @@ async def _internal_transfer_to_telegram(client: TelegramClient,
256341
if progress_callback:
257342
r = progress_callback(response.tell(), file_size)
258343
if inspect.isawaitable(r):
259-
await r
344+
try:
345+
await r
346+
except BaseException:
347+
pass
260348
if not is_large:
261349
hash_md5.update(data)
262350
if len(buffer) == 0 and len(data) == part_size:
@@ -280,11 +368,12 @@ async def _internal_transfer_to_telegram(client: TelegramClient,
280368
return InputFile(file_id, part_count, filename, hash_md5.hexdigest()), file_size
281369

282370

283-
async def download_file(client: TelegramClient,
284-
location: TypeLocation,
285-
out: BinaryIO,
286-
progress_callback: callable = None
287-
) -> BinaryIO:
371+
async def download_file(
372+
client: TelegramClient,
373+
location: TypeLocation,
374+
out: BinaryIO,
375+
progress_callback: callable = None,
376+
) -> BinaryIO:
288377
size = location.size
289378
dc_id, location = utils.get_input_location(location)
290379
# We lock the transfers because telegram has connection count limits
@@ -295,16 +384,20 @@ async def download_file(client: TelegramClient,
295384
if progress_callback:
296385
r = progress_callback(out.tell(), size)
297386
if inspect.isawaitable(r):
298-
await r
387+
try:
388+
await r
389+
except BaseException:
390+
pass
299391

300392
return out
301393

302394

303-
async def upload_file(client: TelegramClient,
304-
file: BinaryIO,
305-
name,
306-
progress_callback: callable = None,
307-
) -> TypeInputFile:
308-
global filename
309-
filename = name
310-
return (await _internal_transfer_to_telegram(client, file, progress_callback))[0]
395+
async def upload_file(
396+
client: TelegramClient,
397+
file: BinaryIO,
398+
name,
399+
progress_callback: callable = None,
400+
) -> TypeInputFile:
401+
global filename
402+
filename = name
403+
return (await _internal_transfer_to_telegram(client, file, progress_callback))[0]

0 commit comments

Comments
 (0)