forked from 1Danish-00/CompressorBot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FastTelethon.py
403 lines (359 loc) · 12.1 KB
/
FastTelethon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
"""
> Based on parallel_file_transfer.py from mautrix-telegram, with permission to distribute under the MIT license
> Copyright (C) 2019 Tulir Asokan - https://github.com/tulir/mautrix-telegram
"""
import asyncio
import hashlib
import inspect
import logging
import math
import os
from collections import defaultdict
from typing import (
AsyncGenerator,
Awaitable,
BinaryIO,
DefaultDict,
List,
Optional,
Tuple,
Union,
)
from telethon import TelegramClient, helpers, utils
from telethon.crypto import AuthKey
from telethon.network import MTProtoSender
from telethon.tl.alltlobjects import LAYER
from telethon.tl.functions import InvokeWithLayerRequest
from telethon.tl.functions.auth import (
ExportAuthorizationRequest,
ImportAuthorizationRequest,
)
from telethon.tl.functions.upload import (
GetFileRequest,
SaveBigFilePartRequest,
SaveFilePartRequest,
)
from telethon.tl.types import (
Document,
InputDocumentFileLocation,
InputFile,
InputFileBig,
InputFileLocation,
InputPeerPhotoFileLocation,
InputPhotoFileLocation,
TypeInputFile,
)
# Name given to the InputFile/InputFileBig built at the end of an upload.
# upload_file() assigns it just before delegating to the upload routine.
filename: str = ""

log: logging.Logger = logging.getLogger("FastTelethon")

# Every kind of file location accepted by the parallel download path.
TypeLocation = Union[
    Document,
    InputDocumentFileLocation,
    InputPeerPhotoFileLocation,
    InputFileLocation,
    InputPhotoFileLocation,
]
class DownloadSender:
    """Fetch an interleaved subset of a remote file's chunks over one sender.

    Every call to :meth:`next` requests ``limit`` bytes at the current offset and
    then moves the offset forward by ``stride``; after ``count`` successful calls
    it reports exhaustion by returning ``None``.
    """

    client: TelegramClient
    sender: MTProtoSender
    request: GetFileRequest
    remaining: int
    stride: int

    def __init__(
        self,
        client: TelegramClient,
        sender: MTProtoSender,
        file: TypeLocation,
        offset: int,
        limit: int,
        stride: int,
        count: int,
    ) -> None:
        self.client = client
        self.sender = sender
        # A single request object is reused; only its offset changes between calls.
        self.request = GetFileRequest(file, offset=offset, limit=limit)
        self.stride = stride
        self.remaining = count

    async def next(self) -> Optional[bytes]:
        """Return the next chunk's bytes, or ``None`` once ``count`` chunks were read."""
        if self.remaining:
            response = await self.client._call(self.sender, self.request)
            self.remaining -= 1
            self.request.offset += self.stride
            return response.bytes
        return None

    def disconnect(self) -> Awaitable[None]:
        """Return the awaitable that closes the underlying sender."""
        return self.sender.disconnect()
class UploadSender:
    """Upload an interleaved subset of one file's parts over a single sender.

    The first part sent is number ``index``; each successful upload advances the
    part number by ``stride``. At most one request is in flight per sender:
    :meth:`next` waits for the previous part before scheduling the new one as a
    background task.
    """

    client: TelegramClient
    sender: MTProtoSender
    request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
    part_count: int
    stride: int
    previous: Optional[asyncio.Task]
    loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        client: TelegramClient,
        sender: MTProtoSender,
        file_id: int,
        part_count: int,
        big: bool,
        index: int,
        stride: int,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self.client = client
        self.sender = sender
        self.part_count = part_count
        # The "big" request variant additionally carries the total part count.
        if big:
            self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
        else:
            self.request = SaveFilePartRequest(file_id, index, b"")
        self.stride = stride
        self.previous = None
        self.loop = loop

    async def next(self, data: bytes) -> None:
        """Wait for the previous part (re-raising its error), then start sending ``data``."""
        if self.previous:
            await self.previous
        self.previous = self.loop.create_task(self._next(data))

    async def _next(self, data: bytes) -> None:
        """Send ``data`` as the current part, then advance to this sender's next part."""
        self.request.bytes = data
        await self.client._call(self.sender, self.request)
        self.request.file_part += self.stride

    async def disconnect(self) -> None:
        """Wait for the last queued part, then close the connection.

        The connection is closed even when that last part failed; the failure is
        still re-raised afterwards so the caller sees it.
        """
        try:
            if self.previous:
                await self.previous
        finally:
            # Without this, a failed final part would skip the disconnect and
            # leak the connection.
            await self.sender.disconnect()
class ParallelTransferrer:
    """Run one upload or download over several MTProto connections to one DC.

    Uploads: call :meth:`init_upload`, then :meth:`upload` once per part in file
    order, then :meth:`finish_upload`. Downloads: iterate :meth:`download`, which
    opens its own connections and closes them when iteration ends.
    """

    client: TelegramClient
    loop: asyncio.AbstractEventLoop
    dc_id: int
    senders: Optional[List[Union[DownloadSender, UploadSender]]]
    auth_key: Optional[AuthKey]
    upload_ticker: int

    def __init__(self, client: TelegramClient, dc_id: Optional[int] = None) -> None:
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id or self.client.session.dc_id
        # A DC other than the session's own has no known key yet; _create_sender()
        # obtains one by exporting/importing the authorization.
        self.auth_key = (
            None
            if dc_id and self.client.session.dc_id != dc_id
            else self.client.session.auth_key
        )
        self.senders = None
        # Index of the sender that receives the next uploaded part.
        self.upload_ticker = 0

    async def _cleanup(self) -> None:
        """Disconnect and forget every open sender; a no-op when none are open."""
        senders, self.senders = self.senders, None
        if senders:
            await asyncio.gather(*[sender.disconnect() for sender in senders])

    @staticmethod
    def _get_connection_count(
        file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024
    ) -> int:
        """Choose how many connections to open for ``file_size`` bytes.

        Scales linearly up to ``max_count`` at ``full_size`` bytes. Never returns
        less than one, so an empty file cannot cause a division by zero in
        _init_download().
        """
        if file_size > full_size:
            return max_count
        return max(1, math.ceil((file_size / full_size) * max_count))

    async def _init_download(
        self, connections: int, file: TypeLocation, part_count: int, part_size: int
    ) -> None:
        """Open ``connections`` download senders that together cover ``part_count`` parts.

        Sender ``i`` starts at byte ``i * part_size`` and strides by
        ``connections * part_size``. The first ``part_count % connections``
        senders get one extra part.
        """
        minimum, remainder = divmod(part_count, connections)

        def get_part_count() -> int:
            nonlocal remainder
            if remainder > 0:
                remainder -= 1
                return minimum + 1
            return minimum

        # The first cross-DC sender will export+import the authorization, so we always create it
        # before creating any other senders.
        self.senders = [
            await self._create_download_sender(
                file, 0, part_size, connections * part_size, get_part_count()
            ),
            *await asyncio.gather(
                *[
                    self._create_download_sender(
                        file, i, part_size, connections * part_size, get_part_count()
                    )
                    for i in range(1, connections)
                ]
            ),
        ]

    async def _create_download_sender(
        self,
        file: TypeLocation,
        index: int,
        part_size: int,
        stride: int,
        part_count: int,
    ) -> DownloadSender:
        """Open a connection and wrap it in a DownloadSender starting at part ``index``."""
        return DownloadSender(
            self.client,
            await self._create_sender(),
            file,
            index * part_size,
            part_size,
            stride,
            part_count,
        )

    async def _init_upload(
        self, connections: int, file_id: int, part_count: int, big: bool
    ) -> None:
        """Open ``connections`` upload senders; sender ``i`` owns parts ``i, i + n, ...``."""
        # As with downloads, create the first sender alone so any authorization
        # export/import happens once before the rest are opened concurrently.
        self.senders = [
            await self._create_upload_sender(file_id, part_count, big, 0, connections),
            *await asyncio.gather(
                *[
                    self._create_upload_sender(file_id, part_count, big, i, connections)
                    for i in range(1, connections)
                ]
            ),
        ]

    async def _create_upload_sender(
        self, file_id: int, part_count: int, big: bool, index: int, stride: int
    ) -> UploadSender:
        """Open a connection and wrap it in an UploadSender starting at part ``index``."""
        return UploadSender(
            self.client,
            await self._create_sender(),
            file_id,
            part_count,
            big,
            index,
            stride,
            loop=self.loop,
        )

    async def _create_sender(self) -> MTProtoSender:
        """Open a new MTProto connection to ``self.dc_id``.

        If no auth key is known for that DC yet, the authorization is exported from
        the current session and imported over this connection. The resulting key is
        remembered for the senders created afterwards.
        """
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, loggers=self.client._log)
        await sender.connect(
            self.client._connection(
                dc.ip_address,
                dc.port,
                dc.id,
                loggers=self.client._log,
                proxy=self.client._proxy,
            )
        )
        if not self.auth_key:
            auth = await self.client(ExportAuthorizationRequest(self.dc_id))
            self.client._init_request.query = ImportAuthorizationRequest(
                id=auth.id, bytes=auth.bytes
            )
            req = InvokeWithLayerRequest(LAYER, self.client._init_request)
            await sender.send(req)
            self.auth_key = sender.auth_key
        return sender

    async def init_upload(
        self,
        file_id: int,
        file_size: int,
        part_size_kb: Optional[float] = None,
        connection_count: Optional[int] = None,
    ) -> Tuple[int, int, bool]:
        """Open upload connections for a ``file_size``-byte file.

        Returns ``(part_size, part_count, is_large)``. ``is_large`` is True above
        10 MiB, in which case the big-file part requests are used.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        # int() keeps sizes integral even when a fractional part_size_kb is passed.
        part_size = int((part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024)
        part_count = (file_size + part_size - 1) // part_size
        is_large = file_size > 10 * 1024 * 1024
        # Restart round-robin at sender 0 in case this transferrer is reused.
        self.upload_ticker = 0
        await self._init_upload(connection_count, file_id, part_count, is_large)
        return part_size, part_count, is_large

    async def upload(self, part: bytes) -> None:
        """Queue ``part`` on the next sender in round-robin order.

        Parts must be supplied in file order, because each sender numbers its parts
        ``index, index + n, ...`` for ``n`` senders.
        """
        await self.senders[self.upload_ticker].next(part)
        self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)

    async def finish_upload(self) -> None:
        """Wait for outstanding parts and close all upload connections."""
        await self._cleanup()

    async def download(
        self,
        file: TypeLocation,
        file_size: int,
        part_size_kb: Optional[float] = None,
        connection_count: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """Yield the file's chunks in order, fetching one chunk per sender per round.

        The connections are closed when iteration finishes, fails, or is abandoned
        (via ``aclose()``).
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        part_size = int((part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024)
        part_count = math.ceil(file_size / part_size)
        await self._init_download(connection_count, file, part_count, part_size)
        tasks: List[asyncio.Task] = []
        try:
            part = 0
            while part < part_count:
                tasks = [self.loop.create_task(sender.next()) for sender in self.senders]
                for task in tasks:
                    data = await task
                    if not data:
                        # An exhausted sender (or an empty reply) means no more data
                        # follows in file order. Stop here rather than re-polling
                        # exhausted senders forever.
                        return
                    yield data
                    part += 1
        finally:
            # Don't leave requests in flight when stopping early or on error.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            await self._cleanup()
# One asyncio.Lock per integer key, created lazily on first access.
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock)
def stream_file(file_to_stream: BinaryIO, chunk_size=1024):
    """Yield successive reads of up to ``chunk_size`` from ``file_to_stream``.

    Iteration stops at the first empty (falsy) read.
    """
    chunk = file_to_stream.read(chunk_size)
    while chunk:
        yield chunk
        chunk = file_to_stream.read(chunk_size)
async def _internal_transfer_to_telegram(
    client: TelegramClient,
    response: BinaryIO,
    progress_callback: callable,
    file_name: Optional[str] = None,
) -> Tuple[TypeInputFile, int]:
    """Upload the open file ``response`` over parallel connections.

    Args:
        client: Client whose session is used for the upload.
        response: File object opened for reading. Its ``.name`` must be a path,
            because the size is taken from ``os.path.getsize``.
        progress_callback: Optional ``(bytes_read, total)`` callable. It may
            return an awaitable. Errors from that awaitable are logged and
            ignored; cancellation still propagates.
        file_name: Name to attach to the resulting input file. Defaults to the
            module-level ``filename`` for backward compatibility.

    Returns:
        ``(input_file, file_size)``. ``input_file`` is an ``InputFileBig`` when the
        uploader reports a large file, otherwise an ``InputFile`` carrying the
        MD5 of the content.
    """
    name = filename if file_name is None else file_name
    file_id = helpers.generate_random_long()
    file_size = os.path.getsize(response.name)
    hash_md5 = hashlib.md5()
    uploader = ParallelTransferrer(client)
    part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
    buffer = bytearray()
    try:
        # Read a whole part at a time. This avoids a read() call and a progress
        # callback for every KiB, and usually needs no re-buffering.
        for data in stream_file(response, chunk_size=part_size):
            if progress_callback:
                r = progress_callback(response.tell(), file_size)
                if inspect.isawaitable(r):
                    try:
                        await r
                    except Exception:
                        # Progress reporting is best-effort. Catch Exception, not
                        # BaseException, so cancellation is not swallowed.
                        log.debug("Upload progress callback failed", exc_info=True)
            if not is_large:
                hash_md5.update(data)
            if not buffer and len(data) == part_size:
                await uploader.upload(data)
                continue
            buffer.extend(data)
            # A loop, not a single cut, so any read length is split correctly.
            while len(buffer) >= part_size:
                await uploader.upload(bytes(buffer[:part_size]))
                del buffer[:part_size]
        if buffer:
            await uploader.upload(bytes(buffer))
    except BaseException:
        # Release the connections on failure too, without masking the original error.
        try:
            await uploader.finish_upload()
        except Exception:
            log.debug("Cleanup after failed upload also failed", exc_info=True)
        raise
    await uploader.finish_upload()
    if is_large:
        return InputFileBig(file_id, part_count, name), file_size
    return InputFile(file_id, part_count, name, hash_md5.hexdigest()), file_size
async def download_file(
    client: TelegramClient,
    location: TypeLocation,
    out: BinaryIO,
    progress_callback: callable = None,
) -> BinaryIO:
    """Download ``location`` into ``out`` over parallel connections.

    Args:
        client: Client used to open the connections.
        location: Object with a ``.size`` attribute that telethon's
            ``utils.get_input_location`` can resolve (e.g. a Document).
        out: Writable binary stream; each chunk is written in file order.
        progress_callback: Optional ``(bytes_written, total)`` callable. It may
            return an awaitable. Errors from that awaitable are logged and
            ignored; cancellation still propagates.

    Returns:
        ``out``, after all chunks have been written.
    """
    size = location.size
    dc_id, location = utils.get_input_location(location)
    # NOTE: no lock is acquired here, so concurrent downloads each open their
    # own set of connections.
    downloader = ParallelTransferrer(client, dc_id)
    downloaded = downloader.download(location, size)
    try:
        async for x in downloaded:
            out.write(x)
            if progress_callback:
                r = progress_callback(out.tell(), size)
                if inspect.isawaitable(r):
                    try:
                        await r
                    except Exception:
                        # Best-effort progress; don't swallow cancellation.
                        log.debug("Download progress callback failed", exc_info=True)
    finally:
        # If we stopped early (e.g. out.write failed), finalize the generator now
        # rather than at garbage collection; a no-op once it is exhausted.
        await downloaded.aclose()
    return out
async def upload_file(
    client: TelegramClient,
    file: BinaryIO,
    name,
    progress_callback: callable = None,
) -> TypeInputFile:
    """Upload ``file`` over parallel connections and return its input-file object.

    Args:
        client: Client used for the upload.
        file: File object opened for reading, whose ``.name`` is a path on disk.
        name: File name to attach to the returned input file.
        progress_callback: Optional ``(bytes_read, total)`` callable, possibly async.
    """
    global filename
    # Still published for any code that reads the module-level name.
    filename = name
    input_file, _ = await _internal_transfer_to_telegram(client, file, progress_callback)
    # A concurrent upload_file() call may have overwritten the global while this
    # upload was running; pin the name this caller asked for.
    input_file.name = name
    return input_file