9
9
import math
10
10
import os
11
11
from collections import defaultdict
12
- from typing import Optional , List , AsyncGenerator , Union , Awaitable , DefaultDict , Tuple , BinaryIO
13
-
14
- from telethon import utils , helpers , TelegramClient
12
+ from typing import (
13
+ AsyncGenerator ,
14
+ Awaitable ,
15
+ BinaryIO ,
16
+ DefaultDict ,
17
+ List ,
18
+ Optional ,
19
+ Tuple ,
20
+ Union ,
21
+ )
22
+
23
+ from telethon import TelegramClient , helpers , utils
15
24
from telethon .crypto import AuthKey
16
25
from telethon .network import MTProtoSender
17
26
from telethon .tl .alltlobjects import LAYER
18
27
from telethon .tl .functions import InvokeWithLayerRequest
19
- from telethon .tl .functions .auth import ExportAuthorizationRequest , ImportAuthorizationRequest
20
- from telethon .tl .functions .upload import (GetFileRequest , SaveFilePartRequest ,
21
- SaveBigFilePartRequest )
22
- from telethon .tl .types import (Document , InputFileLocation , InputDocumentFileLocation ,
23
- InputPhotoFileLocation , InputPeerPhotoFileLocation , TypeInputFile ,
24
- InputFileBig , InputFile )
28
+ from telethon .tl .functions .auth import (
29
+ ExportAuthorizationRequest ,
30
+ ImportAuthorizationRequest ,
31
+ )
32
+ from telethon .tl .functions .upload import (
33
+ GetFileRequest ,
34
+ SaveBigFilePartRequest ,
35
+ SaveFilePartRequest ,
36
+ )
37
+ from telethon .tl .types import (
38
+ Document ,
39
+ InputDocumentFileLocation ,
40
+ InputFile ,
41
+ InputFileBig ,
42
+ InputFileLocation ,
43
+ InputPeerPhotoFileLocation ,
44
+ InputPhotoFileLocation ,
45
+ TypeInputFile ,
46
+ )
25
47
26
48
filename = ""
27
49
28
- async_encrypt_attachment = None
29
-
30
- log : logging .Logger = logging .getLogger ("telethon" )
50
+ log : logging .Logger = logging .getLogger ("FastTelethon" )
31
51
32
- TypeLocation = Union [Document , InputDocumentFileLocation , InputPeerPhotoFileLocation ,
33
- InputFileLocation , InputPhotoFileLocation ]
52
+ TypeLocation = Union [
53
+ Document ,
54
+ InputDocumentFileLocation ,
55
+ InputPeerPhotoFileLocation ,
56
+ InputFileLocation ,
57
+ InputPhotoFileLocation ,
58
+ ]
34
59
35
60
36
61
class DownloadSender :
@@ -40,8 +65,16 @@ class DownloadSender:
40
65
remaining : int
41
66
stride : int
42
67
43
- def __init__ (self , client : TelegramClient , sender : MTProtoSender , file : TypeLocation , offset : int , limit : int ,
44
- stride : int , count : int ) -> None :
68
+ def __init__ (
69
+ self ,
70
+ client : TelegramClient ,
71
+ sender : MTProtoSender ,
72
+ file : TypeLocation ,
73
+ offset : int ,
74
+ limit : int ,
75
+ stride : int ,
76
+ count : int ,
77
+ ) -> None :
45
78
self .sender = sender
46
79
self .client = client
47
80
self .request = GetFileRequest (file , offset = offset , limit = limit )
@@ -69,9 +102,17 @@ class UploadSender:
69
102
previous : Optional [asyncio .Task ]
70
103
loop : asyncio .AbstractEventLoop
71
104
72
- def __init__ (self , client : TelegramClient , sender : MTProtoSender , file_id : int , part_count : int , big : bool ,
73
- index : int ,
74
- stride : int , loop : asyncio .AbstractEventLoop ) -> None :
105
+ def __init__ (
106
+ self ,
107
+ client : TelegramClient ,
108
+ sender : MTProtoSender ,
109
+ file_id : int ,
110
+ part_count : int ,
111
+ big : bool ,
112
+ index : int ,
113
+ stride : int ,
114
+ loop : asyncio .AbstractEventLoop ,
115
+ ) -> None :
75
116
self .client = client
76
117
self .sender = sender
77
118
self .part_count = part_count
@@ -90,8 +131,6 @@ async def next(self, data: bytes) -> None:
90
131
91
132
async def _next (self , data : bytes ) -> None :
92
133
self .request .bytes = data
93
- log .debug (f"Sending file part { self .request .file_part } /{ self .part_count } "
94
- f" with { len (data )} bytes" )
95
134
await self .client ._call (self .sender , self .request )
96
135
self .request .file_part += self .stride
97
136
@@ -113,8 +152,11 @@ def __init__(self, client: TelegramClient, dc_id: Optional[int] = None) -> None:
113
152
self .client = client
114
153
self .loop = self .client .loop
115
154
self .dc_id = dc_id or self .client .session .dc_id
116
- self .auth_key = (None if dc_id and self .client .session .dc_id != dc_id
117
- else self .client .session .auth_key )
155
+ self .auth_key = (
156
+ None
157
+ if dc_id and self .client .session .dc_id != dc_id
158
+ else self .client .session .auth_key
159
+ )
118
160
self .senders = None
119
161
self .upload_ticker = 0
120
162
@@ -123,14 +165,16 @@ async def _cleanup(self) -> None:
123
165
self .senders = None
124
166
125
167
@staticmethod
126
- def _get_connection_count (file_size : int , max_count : int = 20 ,
127
- full_size : int = 100 * 1024 * 1024 ) -> int :
168
+ def _get_connection_count (
169
+ file_size : int , max_count : int = 20 , full_size : int = 100 * 1024 * 1024
170
+ ) -> int :
128
171
if file_size > full_size :
129
172
return max_count
130
173
return math .ceil ((file_size / full_size ) * max_count )
131
174
132
- async def _init_download (self , connections : int , file : TypeLocation , part_count : int ,
133
- part_size : int ) -> None :
175
+ async def _init_download (
176
+ self , connections : int , file : TypeLocation , part_count : int , part_size : int
177
+ ) -> None :
134
178
minimum , remainder = divmod (part_count , connections )
135
179
136
180
def get_part_count () -> int :
@@ -143,52 +187,93 @@ def get_part_count() -> int:
143
187
# The first cross-DC sender will export+import the authorization, so we always create it
144
188
# before creating any other senders.
145
189
self .senders = [
146
- await self ._create_download_sender (file , 0 , part_size , connections * part_size ,
147
- get_part_count ()),
190
+ await self ._create_download_sender (
191
+ file , 0 , part_size , connections * part_size , get_part_count ()
192
+ ),
148
193
* await asyncio .gather (
149
- * [self ._create_download_sender (file , i , part_size , connections * part_size ,
150
- get_part_count ())
151
- for i in range (1 , connections )])
194
+ * [
195
+ self ._create_download_sender (
196
+ file , i , part_size , connections * part_size , get_part_count ()
197
+ )
198
+ for i in range (1 , connections )
199
+ ]
200
+ ),
152
201
]
153
202
154
- async def _create_download_sender (self , file : TypeLocation , index : int , part_size : int ,
155
- stride : int ,
156
- part_count : int ) -> DownloadSender :
157
- return DownloadSender (self .client , await self ._create_sender (), file , index * part_size , part_size ,
158
- stride , part_count )
159
-
160
- async def _init_upload (self , connections : int , file_id : int , part_count : int , big : bool
161
- ) -> None :
203
+ async def _create_download_sender (
204
+ self ,
205
+ file : TypeLocation ,
206
+ index : int ,
207
+ part_size : int ,
208
+ stride : int ,
209
+ part_count : int ,
210
+ ) -> DownloadSender :
211
+ return DownloadSender (
212
+ self .client ,
213
+ await self ._create_sender (),
214
+ file ,
215
+ index * part_size ,
216
+ part_size ,
217
+ stride ,
218
+ part_count ,
219
+ )
220
+
221
+ async def _init_upload (
222
+ self , connections : int , file_id : int , part_count : int , big : bool
223
+ ) -> None :
162
224
self .senders = [
163
225
await self ._create_upload_sender (file_id , part_count , big , 0 , connections ),
164
226
* await asyncio .gather (
165
- * [self ._create_upload_sender (file_id , part_count , big , i , connections )
166
- for i in range (1 , connections )])
227
+ * [
228
+ self ._create_upload_sender (file_id , part_count , big , i , connections )
229
+ for i in range (1 , connections )
230
+ ]
231
+ ),
167
232
]
168
233
169
- async def _create_upload_sender (self , file_id : int , part_count : int , big : bool , index : int ,
170
- stride : int ) -> UploadSender :
171
- return UploadSender (self .client , await self ._create_sender (), file_id , part_count , big , index , stride ,
172
- loop = self .loop )
234
+ async def _create_upload_sender (
235
+ self , file_id : int , part_count : int , big : bool , index : int , stride : int
236
+ ) -> UploadSender :
237
+ return UploadSender (
238
+ self .client ,
239
+ await self ._create_sender (),
240
+ file_id ,
241
+ part_count ,
242
+ big ,
243
+ index ,
244
+ stride ,
245
+ loop = self .loop ,
246
+ )
173
247
174
248
async def _create_sender (self ) -> MTProtoSender :
175
249
dc = await self .client ._get_dc (self .dc_id )
176
250
sender = MTProtoSender (self .auth_key , loggers = self .client ._log )
177
- await sender .connect (self .client ._connection (dc .ip_address , dc .port , dc .id ,
178
- loggers = self .client ._log ,
179
- proxy = self .client ._proxy ))
251
+ await sender .connect (
252
+ self .client ._connection (
253
+ dc .ip_address ,
254
+ dc .port ,
255
+ dc .id ,
256
+ loggers = self .client ._log ,
257
+ proxy = self .client ._proxy ,
258
+ )
259
+ )
180
260
if not self .auth_key :
181
- log .debug (f"Exporting auth to DC { self .dc_id } " )
182
261
auth = await self .client (ExportAuthorizationRequest (self .dc_id ))
183
- self .client ._init_request .query = ImportAuthorizationRequest (id = auth .id ,
184
- bytes = auth .bytes )
262
+ self .client ._init_request .query = ImportAuthorizationRequest (
263
+ id = auth .id , bytes = auth .bytes
264
+ )
185
265
req = InvokeWithLayerRequest (LAYER , self .client ._init_request )
186
266
await sender .send (req )
187
267
self .auth_key = sender .auth_key
188
268
return sender
189
269
190
- async def init_upload (self , file_id : int , file_size : int , part_size_kb : Optional [float ] = None ,
191
- connection_count : Optional [int ] = None ) -> Tuple [int , int , bool ]:
270
+ async def init_upload (
271
+ self ,
272
+ file_id : int ,
273
+ file_size : int ,
274
+ part_size_kb : Optional [float ] = None ,
275
+ connection_count : Optional [int ] = None ,
276
+ ) -> Tuple [int , int , bool ]:
192
277
connection_count = connection_count or self ._get_connection_count (file_size )
193
278
part_size = (part_size_kb or utils .get_appropriated_part_size (file_size )) * 1024
194
279
part_count = (file_size + part_size - 1 ) // part_size
@@ -203,14 +288,16 @@ async def upload(self, part: bytes) -> None:
203
288
async def finish_upload (self ) -> None :
204
289
await self ._cleanup ()
205
290
206
- async def download (self , file : TypeLocation , file_size : int ,
207
- part_size_kb : Optional [float ] = None ,
208
- connection_count : Optional [int ] = None ) -> AsyncGenerator [bytes , None ]:
291
+ async def download (
292
+ self ,
293
+ file : TypeLocation ,
294
+ file_size : int ,
295
+ part_size_kb : Optional [float ] = None ,
296
+ connection_count : Optional [int ] = None ,
297
+ ) -> AsyncGenerator [bytes , None ]:
209
298
connection_count = connection_count or self ._get_connection_count (file_size )
210
299
part_size = (part_size_kb or utils .get_appropriated_part_size (file_size )) * 1024
211
300
part_count = math .ceil (file_size / part_size )
212
- log .debug ("Starting parallel download: "
213
- f"{ connection_count } { part_size } { part_count } { file !s} " )
214
301
await self ._init_download (connection_count , file , part_count , part_size )
215
302
216
303
part = 0
@@ -224,13 +311,12 @@ async def download(self, file: TypeLocation, file_size: int,
224
311
break
225
312
yield data
226
313
part += 1
227
- log .debug (f"Part { part } downloaded" )
228
-
229
- log .debug ("Parallel download finished, cleaning up connections" )
230
314
await self ._cleanup ()
231
315
232
316
233
- parallel_transfer_locks : DefaultDict [int , asyncio .Lock ] = defaultdict (lambda : asyncio .Lock ())
317
+ parallel_transfer_locks : DefaultDict [int , asyncio .Lock ] = defaultdict (
318
+ lambda : asyncio .Lock ()
319
+ )
234
320
235
321
236
322
def stream_file (file_to_stream : BinaryIO , chunk_size = 1024 ):
@@ -241,10 +327,9 @@ def stream_file(file_to_stream: BinaryIO, chunk_size=1024):
241
327
yield data_read
242
328
243
329
244
- async def _internal_transfer_to_telegram (client : TelegramClient ,
245
- response : BinaryIO ,
246
- progress_callback : callable
247
- ) -> Tuple [TypeInputFile , int ]:
330
+ async def _internal_transfer_to_telegram (
331
+ client : TelegramClient , response : BinaryIO , progress_callback : callable
332
+ ) -> Tuple [TypeInputFile , int ]:
248
333
file_id = helpers .generate_random_long ()
249
334
file_size = os .path .getsize (response .name )
250
335
@@ -256,7 +341,10 @@ async def _internal_transfer_to_telegram(client: TelegramClient,
256
341
if progress_callback :
257
342
r = progress_callback (response .tell (), file_size )
258
343
if inspect .isawaitable (r ):
259
- await r
344
+ try :
345
+ await r
346
+ except BaseException :
347
+ pass
260
348
if not is_large :
261
349
hash_md5 .update (data )
262
350
if len (buffer ) == 0 and len (data ) == part_size :
@@ -280,11 +368,12 @@ async def _internal_transfer_to_telegram(client: TelegramClient,
280
368
return InputFile (file_id , part_count , filename , hash_md5 .hexdigest ()), file_size
281
369
282
370
283
- async def download_file (client : TelegramClient ,
284
- location : TypeLocation ,
285
- out : BinaryIO ,
286
- progress_callback : callable = None
287
- ) -> BinaryIO :
371
+ async def download_file (
372
+ client : TelegramClient ,
373
+ location : TypeLocation ,
374
+ out : BinaryIO ,
375
+ progress_callback : callable = None ,
376
+ ) -> BinaryIO :
288
377
size = location .size
289
378
dc_id , location = utils .get_input_location (location )
290
379
# We lock the transfers because telegram has connection count limits
@@ -295,16 +384,20 @@ async def download_file(client: TelegramClient,
295
384
if progress_callback :
296
385
r = progress_callback (out .tell (), size )
297
386
if inspect .isawaitable (r ):
298
- await r
387
+ try :
388
+ await r
389
+ except BaseException :
390
+ pass
299
391
300
392
return out
301
393
302
394
303
- async def upload_file (client : TelegramClient ,
304
- file : BinaryIO ,
305
- name ,
306
- progress_callback : callable = None ,
307
- ) -> TypeInputFile :
308
- global filename
309
- filename = name
310
- return (await _internal_transfer_to_telegram (client , file , progress_callback ))[0 ]
395
+ async def upload_file (
396
+ client : TelegramClient ,
397
+ file : BinaryIO ,
398
+ name ,
399
+ progress_callback : callable = None ,
400
+ ) -> TypeInputFile :
401
+ global filename
402
+ filename = name
403
+ return (await _internal_transfer_to_telegram (client , file , progress_callback ))[0 ]
0 commit comments