2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Add initial support for downloading media

This commit is contained in:
Dan 2018-02-18 18:11:33 +01:00
parent d89d238d30
commit 15561d19d5

View File

@ -27,6 +27,7 @@ import threading
import time import time
from collections import namedtuple from collections import namedtuple
from configparser import ConfigParser from configparser import ConfigParser
from datetime import datetime
from hashlib import sha256, md5 from hashlib import sha256, md5
from queue import Queue from queue import Queue
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
@ -39,8 +40,8 @@ from pyrogram.api.errors import (
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty, PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded, PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing, PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing,
ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned,
) VolumeLocNotFound)
from pyrogram.api.types import ( from pyrogram.api.types import (
User, Chat, Channel, User, Chat, Channel,
PeerUser, PeerChannel, PeerUser, PeerChannel,
@ -49,6 +50,7 @@ from pyrogram.api.types import (
) )
from pyrogram.crypto import AES from pyrogram.crypto import AES
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
from pyrogram.session.internals import MsgId
from .input_media import InputMedia from .input_media import InputMedia
from .style import Markdown, HTML from .style import Markdown, HTML
@ -103,6 +105,7 @@ class Client:
INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$") INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$")
DIALOGS_AT_ONCE = 100 DIALOGS_AT_ONCE = 100
UPDATES_WORKERS = 2 UPDATES_WORKERS = 2
DOWNLOAD_WORKERS = 1
def __init__(self, def __init__(self,
session_name: str, session_name: str,
@ -148,6 +151,8 @@ class Client:
self.update_queue = Queue() self.update_queue = Queue()
self.update_handler = None self.update_handler = None
self.download_queue = Queue()
def start(self): def start(self):
"""Use this method to start the Client after creating it. """Use this method to start the Client after creating it.
Requires no parameters. Requires no parameters.
@ -176,7 +181,7 @@ class Client:
self.password = None self.password = None
self.save_session() self.save_session()
self.rnd_id = self.session.msg_id self.rnd_id = MsgId
self.get_dialogs() self.get_dialogs()
for i in range(self.UPDATES_WORKERS): for i in range(self.UPDATES_WORKERS):
@ -185,6 +190,9 @@ class Client:
for i in range(self.workers): for i in range(self.workers):
Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start() Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start()
for i in range(self.DOWNLOAD_WORKERS):
Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start()
mimetypes.init() mimetypes.init()
def stop(self): def stop(self):
@ -199,6 +207,9 @@ class Client:
for _ in range(self.workers): for _ in range(self.workers):
self.update_queue.put(None) self.update_queue.put(None)
for _ in range(self.DOWNLOAD_WORKERS):
self.download_queue.put(None)
def fetch_peers(self, entities: list): def fetch_peers(self, entities: list):
for entity in entities: for entity in entities:
if isinstance(entity, User): if isinstance(entity, User):
@ -260,6 +271,67 @@ class Client:
if username is not None: if username is not None:
self.peers_by_username[username] = input_peer self.peers_by_username[username] = input_peer
def download_worker(self):
name = threading.current_thread().name
log.debug("{} started".format(name))
while True:
message = self.download_queue.get()
if message is None:
break
message, done = message
try:
if isinstance(message.media, types.MessageMediaDocument):
document = message.media.document
if isinstance(document, types.Document):
file_name = "doc_{}{}".format(
datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"),
mimetypes.guess_extension(document.mime_type) or ".unknown"
)
for i in document.attributes:
if isinstance(i, types.DocumentAttributeFilename):
file_name = i.file_name
break
elif isinstance(i, types.DocumentAttributeSticker):
file_name = file_name.replace("doc", "sticker")
elif isinstance(i, types.DocumentAttributeAudio):
file_name = file_name.replace("doc", "audio")
elif isinstance(i, types.DocumentAttributeVideo):
file_name = file_name.replace("doc", "video")
elif isinstance(i, types.DocumentAttributeAnimated):
file_name = file_name.replace("doc", "gif")
tmp_file_name = self.get_file(
dc_id=document.dc_id,
id=document.id,
access_hash=document.access_hash,
version=document.version
)
i = 1
while True:
try:
os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(
".".join(file_name.split(".")[:-1])
+ (" ({}).".format(i) if i > 1 else ".")
+ file_name.split(".")[-1]
))
except FileExistsError:
i += 1
else:
break
done.set()
except Exception as e:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))
def updates_worker(self): def updates_worker(self):
name = threading.current_thread().name name = threading.current_thread().name
log.debug("{} started".format(name)) log.debug("{} started".format(name))
@ -1667,8 +1739,7 @@ class Client:
part_size = 512 * 1024 part_size = 512 * 1024
file_size = os.path.getsize(path) file_size = os.path.getsize(path)
file_total_parts = math.ceil(file_size / part_size) file_total_parts = math.ceil(file_size / part_size)
# is_big = True if file_size > 10 * 1024 * 1024 else False is_big = True if file_size > 10 * 1024 * 1024 else False
is_big = False # Treat all files as not-big to have the server check for the md5 sum
is_missing_part = True if file_id is not None else False is_missing_part = True if file_id is not None else False
file_id = file_id or self.rnd_id() file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None md5_sum = md5() if not is_big and not is_missing_part else None
@ -1759,22 +1830,19 @@ class Client:
session.start() session.start()
if volume_id: # Photos are accessed by volume_id, local_id, secret if volume_id: # Photos are accessed by volume_id, local_id, secret
file_name = "_".join(str(i) for i in [dc_id, volume_id, local_id, secret])
location = types.InputFileLocation( location = types.InputFileLocation(
volume_id=volume_id, volume_id=volume_id,
local_id=local_id, local_id=local_id,
secret=secret secret=secret
) )
else: # Any other file can be more easily accessed by id and access_hash else: # Any other file can be more easily accessed by id and access_hash
file_name = "_".join(str(i) for i in [dc_id, id, access_hash, version])
location = types.InputDocumentFileLocation( location = types.InputDocumentFileLocation(
id=id, id=id,
access_hash=access_hash, access_hash=access_hash,
version=version version=version
) )
file_name = str(MsgId())
limit = 1024 * 1024 limit = 1024 * 1024
offset = 0 offset = 0
@ -1822,63 +1890,57 @@ class Client:
cdn_session.start() cdn_session.start()
try: try:
r2 = cdn_session.send( with open(file_name, "wb") as f:
functions.upload.GetCdnFile( while True:
location=location, r2 = cdn_session.send(
file_token=r.file_token, functions.upload.GetCdnFile(
offset=offset, location=location,
limit=limit file_token=r.file_token,
) offset=offset,
) limit=limit
)
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
) )
)
else: if isinstance(r2, types.upload.CdnFileReuploadNeeded):
with open(file_name, "wb") as f: try:
while True: session.send(
if not isinstance(r2, types.upload.CdnFile): functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
)
except VolumeLocNotFound:
break break
else:
continue
chunk = r2.bytes chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files # https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = AES.ctr_decrypt( decrypted_chunk = AES.ctr_decrypt(
chunk, chunk,
r.encryption_key, r.encryption_key,
r.encryption_iv, r.encryption_iv,
offset
)
hashes = session.send(
functions.upload.GetCdnFileHashes(
r.file_token,
offset offset
) )
)
hashes = session.send( # https://core.telegram.org/cdn#verifying-files
functions.upload.GetCdnFileHashes( for i, h in enumerate(hashes):
r.file_token, cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
offset assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
)
)
for i, h in enumerate(hashes): f.write(decrypted_chunk)
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] f.flush()
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) os.fsync(f.fileno())
f.write(decrypted_chunk) offset += limit
f.flush()
os.fsync(f.fileno())
offset += limit
r2 = cdn_session.send(
functions.upload.GetCdnFile(
location=location,
file_token=r.file_token,
offset=offset,
limit=limit
)
)
except Exception as e: except Exception as e:
log.error(e) log.error(e)
finally: finally:
@ -2238,3 +2300,8 @@ class Client:
reply_to_msg_id=reply_to_message_id reply_to_msg_id=reply_to_message_id
) )
) )
def download_media(self, message: types.Message):
done = Event()
self.download_queue.put((message, done))
done.wait()