Mirror of https://github.com/pyrogram/pyrogram (synced 2025-08-28 21:07:59 +00:00)

Add initial support for downloading media

parent d89d238d30
commit 15561d19d5
@@ -27,6 +27,7 @@ import threading
 import time
 from collections import namedtuple
 from configparser import ConfigParser
+from datetime import datetime
 from hashlib import sha256, md5
 from queue import Queue
 from signal import signal, SIGINT, SIGTERM, SIGABRT
@@ -39,8 +40,8 @@ from pyrogram.api.errors import (
     PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
     PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
     PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing,
-    ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned
-)
+    ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned,
+    VolumeLocNotFound)
 from pyrogram.api.types import (
     User, Chat, Channel,
     PeerUser, PeerChannel,
@@ -49,6 +50,7 @@ from pyrogram.api.types import (
 )
 from pyrogram.crypto import AES
 from pyrogram.session import Auth, Session
+from pyrogram.session.internals import MsgId
 from .input_media import InputMedia
 from .style import Markdown, HTML
 
@@ -103,6 +105,7 @@ class Client:
     INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$")
     DIALOGS_AT_ONCE = 100
     UPDATES_WORKERS = 2
+    DOWNLOAD_WORKERS = 1
 
     def __init__(self,
                  session_name: str,
@@ -148,6 +151,8 @@ class Client:
         self.update_queue = Queue()
         self.update_handler = None
 
+        self.download_queue = Queue()
+
     def start(self):
         """Use this method to start the Client after creating it.
         Requires no parameters.
@@ -176,7 +181,7 @@ class Client:
         self.password = None
         self.save_session()
 
-        self.rnd_id = self.session.msg_id
+        self.rnd_id = MsgId
         self.get_dialogs()
 
         for i in range(self.UPDATES_WORKERS):
@@ -185,6 +190,9 @@ class Client:
         for i in range(self.workers):
             Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start()
 
+        for i in range(self.DOWNLOAD_WORKERS):
+            Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start()
+
         mimetypes.init()
 
     def stop(self):
@@ -199,6 +207,9 @@ class Client:
         for _ in range(self.workers):
             self.update_queue.put(None)
 
+        for _ in range(self.DOWNLOAD_WORKERS):
+            self.download_queue.put(None)
+
     def fetch_peers(self, entities: list):
         for entity in entities:
             if isinstance(entity, User):
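The start()/stop() hunks above reuse the client's existing queue-plus-sentinel worker pattern for downloads: start() spawns DOWNLOAD_WORKERS threads that block on download_queue.get(), and stop() pushes one None per worker so every thread wakes up and exits. A minimal standalone sketch of that pattern (the names here are illustrative, not Pyrogram's API):

    from queue import Queue
    from threading import Thread

    WORKERS = 2
    jobs = Queue()

    def worker(name):
        while True:
            job = jobs.get()
            if job is None:  # sentinel: tells exactly one worker to exit
                break
            print(name, "processing", job)

    threads = [Thread(target=worker, args=("Worker#{}".format(i + 1),)) for i in range(WORKERS)]

    for t in threads:
        t.start()

    for job in ("a.bin", "b.bin", "c.bin"):
        jobs.put(job)

    for _ in range(WORKERS):  # shutdown: one sentinel per worker
        jobs.put(None)

    for t in threads:
        t.join()

Because each sentinel is consumed by exactly one worker, stop() has to put as many None values as there are DOWNLOAD_WORKERS, which is what the hunk above does.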
@@ -260,6 +271,67 @@ class Client:
             if username is not None:
                 self.peers_by_username[username] = input_peer
 
+    def download_worker(self):
+        name = threading.current_thread().name
+        log.debug("{} started".format(name))
+
+        while True:
+            message = self.download_queue.get()
+
+            if message is None:
+                break
+
+            message, done = message
+
+            try:
+                if isinstance(message.media, types.MessageMediaDocument):
+                    document = message.media.document
+
+                    if isinstance(document, types.Document):
+                        file_name = "doc_{}{}".format(
+                            datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"),
+                            mimetypes.guess_extension(document.mime_type) or ".unknown"
+                        )
+
+                        for i in document.attributes:
+                            if isinstance(i, types.DocumentAttributeFilename):
+                                file_name = i.file_name
+                                break
+                            elif isinstance(i, types.DocumentAttributeSticker):
+                                file_name = file_name.replace("doc", "sticker")
+                            elif isinstance(i, types.DocumentAttributeAudio):
+                                file_name = file_name.replace("doc", "audio")
+                            elif isinstance(i, types.DocumentAttributeVideo):
+                                file_name = file_name.replace("doc", "video")
+                            elif isinstance(i, types.DocumentAttributeAnimated):
+                                file_name = file_name.replace("doc", "gif")
+
+                        tmp_file_name = self.get_file(
+                            dc_id=document.dc_id,
+                            id=document.id,
+                            access_hash=document.access_hash,
+                            version=document.version
+                        )
+
+                        i = 1
+                        while True:
+                            try:
+                                os.renames("./{}".format(tmp_file_name), "./downloads/{}".format(
+                                    ".".join(file_name.split(".")[:-1])
+                                    + (" ({}).".format(i) if i > 1 else ".")
+                                    + file_name.split(".")[-1]
+                                ))
+                            except FileExistsError:
+                                i += 1
+                            else:
+                                break
+
+                done.set()
+            except Exception as e:
+                log.error(e, exc_info=True)
+
+        log.debug("{} stopped".format(name))
+
     def updates_worker(self):
         name = threading.current_thread().name
         log.debug("{} started".format(name))
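The closing while loop of download_worker builds a unique destination under ./downloads by inserting " (2)", " (3)", … before the extension whenever the rename hits an existing file. A standalone sketch of the same collision handling, here probing with os.path.exists instead of catching the rename error (unique_destination is a hypothetical helper, not part of the commit):

    import os

    def unique_destination(directory, file_name):
        stem, ext = os.path.splitext(file_name)  # "report.pdf" -> ("report", ".pdf")
        candidate = os.path.join(directory, file_name)
        i = 1
        while os.path.exists(candidate):
            i += 1
            candidate = os.path.join(directory, "{} ({}){}".format(stem, i, ext))
        return candidate

    # If downloads/report.pdf already exists, the next file becomes "report (2).pdf".
    print(unique_destination("downloads", "report.pdf"))

Probing explicitly also sidesteps a platform difference: os.rename raises FileExistsError on Windows but silently overwrites an existing file on POSIX.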
@@ -1667,8 +1739,7 @@ class Client:
         part_size = 512 * 1024
         file_size = os.path.getsize(path)
         file_total_parts = math.ceil(file_size / part_size)
-        # is_big = True if file_size > 10 * 1024 * 1024 else False
-        is_big = False  # Treat all files as not-big to have the server check for the md5 sum
+        is_big = True if file_size > 10 * 1024 * 1024 else False
         is_missing_part = True if file_id is not None else False
         file_id = file_id or self.rnd_id()
         md5_sum = md5() if not is_big and not is_missing_part else None
@@ -1759,22 +1830,19 @@ class Client:
         session.start()
 
         if volume_id:  # Photos are accessed by volume_id, local_id, secret
-            file_name = "_".join(str(i) for i in [dc_id, volume_id, local_id, secret])
-
             location = types.InputFileLocation(
                 volume_id=volume_id,
                 local_id=local_id,
                 secret=secret
             )
         else:  # Any other file can be more easily accessed by id and access_hash
-            file_name = "_".join(str(i) for i in [dc_id, id, access_hash, version])
-
             location = types.InputDocumentFileLocation(
                 id=id,
                 access_hash=access_hash,
                 version=version
             )
 
+        file_name = str(MsgId())
         limit = 1024 * 1024
         offset = 0
 
@@ -1822,63 +1890,57 @@ class Client:
                 cdn_session.start()
 
                 try:
-                    r2 = cdn_session.send(
-                        functions.upload.GetCdnFile(
-                            location=location,
-                            file_token=r.file_token,
-                            offset=offset,
-                            limit=limit
-                        )
-                    )
-
-                    if isinstance(r2, types.upload.CdnFileReuploadNeeded):
-                        session.send(
-                            functions.upload.ReuploadCdnFile(
-                                file_token=r.file_token,
-                                request_token=r2.request_token
-                            )
-                        )
-                    else:
-                        with open(file_name, "wb") as f:
-                            while True:
-                                if not isinstance(r2, types.upload.CdnFile):
-                                    break
-
-                                chunk = r2.bytes
-
-                                # https://core.telegram.org/cdn#decrypting-files
-                                decrypted_chunk = AES.ctr_decrypt(
-                                    chunk,
-                                    r.encryption_key,
-                                    r.encryption_iv,
-                                    offset
-                                )
-
-                                hashes = session.send(
-                                    functions.upload.GetCdnFileHashes(
-                                        r.file_token,
-                                        offset
-                                    )
-                                )
-
-                                for i, h in enumerate(hashes):
-                                    cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
-                                    assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
-
-                                f.write(decrypted_chunk)
-                                f.flush()
-                                os.fsync(f.fileno())
-
-                                offset += limit
-
-                                r2 = cdn_session.send(
-                                    functions.upload.GetCdnFile(
-                                        location=location,
-                                        file_token=r.file_token,
-                                        offset=offset,
-                                        limit=limit
-                                    )
-                                )
+                    with open(file_name, "wb") as f:
+                        while True:
+                            r2 = cdn_session.send(
+                                functions.upload.GetCdnFile(
+                                    location=location,
+                                    file_token=r.file_token,
+                                    offset=offset,
+                                    limit=limit
+                                )
+                            )
+
+                            if isinstance(r2, types.upload.CdnFileReuploadNeeded):
+                                try:
+                                    session.send(
+                                        functions.upload.ReuploadCdnFile(
+                                            file_token=r.file_token,
+                                            request_token=r2.request_token
+                                        )
+                                    )
+                                except VolumeLocNotFound:
+                                    break
+                                else:
+                                    continue
+
+                            chunk = r2.bytes
+
+                            # https://core.telegram.org/cdn#decrypting-files
+                            decrypted_chunk = AES.ctr_decrypt(
+                                chunk,
+                                r.encryption_key,
+                                r.encryption_iv,
+                                offset
+                            )
+
+                            hashes = session.send(
+                                functions.upload.GetCdnFileHashes(
+                                    r.file_token,
+                                    offset
+                                )
+                            )
+
+                            # https://core.telegram.org/cdn#verifying-files
+                            for i, h in enumerate(hashes):
+                                cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
+                                assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
+
+                            f.write(decrypted_chunk)
+                            f.flush()
+                            os.fsync(f.fileno())
+
+                            offset += limit
                 except Exception as e:
                     log.error(e)
                 finally:
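The rewritten CDN loop above decrypts each downloaded chunk with AES-CTR and then checks it against the SHA-256 hashes returned by upload.GetCdnFileHashes before writing it to disk (see https://core.telegram.org/cdn#verifying-files). A self-contained sketch of just that verification step, with a namedtuple standing in for the server's hash records (an illustrative stand-in, not Telegram's types):

    from collections import namedtuple
    from hashlib import sha256

    # Stand-in for the per-chunk hash records: each one covers `limit` bytes.
    FileHash = namedtuple("FileHash", ["limit", "hash"])

    def verify_chunk(decrypted_chunk, hashes):
        for i, h in enumerate(hashes):
            part = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
            assert h.hash == sha256(part).digest(), "Invalid CDN hash part {}".format(i)

    # Example: an 8-byte chunk split into two 4-byte verified parts.
    chunk = b"abcdefgh"
    hashes = [FileHash(4, sha256(b"abcd").digest()), FileHash(4, sha256(b"efgh").digest())]
    verify_chunk(chunk, hashes)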
@@ -2238,3 +2300,8 @@ class Client:
                     reply_to_msg_id=reply_to_message_id
                 )
             )
+
+    def download_media(self, message: types.Message):
+        done = Event()
+        self.download_queue.put((message, done))
+        done.wait()
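With this commit a caller hands a raw types.Message carrying media to the new Client.download_media(), which enqueues it together with an Event and blocks until a DownloadWorker has saved the file into ./downloads. A rough usage sketch; the update-handler registration, its signature, and idle() reflect that era's Client API and are assumptions here, not part of this diff:

    from pyrogram import Client
    from pyrogram.api import types

    app = Client("example")  # session name is illustrative

    def update_handler(client, update, users, chats):
        # Assumption: new incoming messages arrive as UpdateNewMessage carrying a raw types.Message.
        if isinstance(update, types.UpdateNewMessage):
            message = update.message
            if isinstance(message, types.Message) and message.media is not None:
                client.download_media(message)  # blocks until the DownloadWorker has finished

    app.set_update_handler(update_handler)
    app.start()
    app.idle()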