2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-27 20:37:54 +00:00

Use OS temp file, specific path download via path seperator inspection

This commit is contained in:
Eric Blundell 2018-03-20 15:20:04 -05:00
parent 62831001b7
commit 5bc10b45a3

View File

@ -34,6 +34,11 @@ from hashlib import sha256, md5
from queue import Queue from queue import Queue
from signal import signal, SIGINT, SIGTERM, SIGABRT from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Event, Thread from threading import Event, Thread
import tempfile
import shutil
import errno
from pyrogram.api import functions, types from pyrogram.api import functions, types
from pyrogram.api.core import Object from pyrogram.api.core import Object
@ -55,8 +60,6 @@ from pyrogram.session.internals import MsgId
from .input_media import InputMedia from .input_media import InputMedia
from .style import Markdown, HTML from .style import Markdown, HTML
from typing import Any
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"]) ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"])
@ -510,14 +513,18 @@ class Client:
break break
try: try:
media, file_dir, file_name, done, progress, path = media media, file_name, done, progress, path = media
tmp_file_name = None
if file_dir is not None: download_directory = "downloads"
# Make file_dir if it was specified
os.makedirs(file_dir, exist_ok=True)
if isinstance(file_name, str) and file_name is not None: if file_name.endswith('/') or file_name.endswith('\\'):
os.makedirs(os.path.dirname(file_name), exist_ok=True) # treat the file name as a directory
download_directory = file_name
file_name = None
elif '/' in file_name or '\\' in file_name:
# use file_name as a full path instead
download_directory = ''
if isinstance(media, types.MessageMediaDocument): if isinstance(media, types.MessageMediaDocument):
document = media.document document = media.document
@ -543,16 +550,13 @@ class Client:
elif isinstance(i, types.DocumentAttributeAnimated): elif isinstance(i, types.DocumentAttributeAnimated):
file_name = file_name.replace("doc", "gif") file_name = file_name.replace("doc", "gif")
file_name = os.path.join(file_dir if file_dir is not None else '', file_name) tmp_file_name = self.get_file(
self.get_file(
dc_id=document.dc_id, dc_id=document.dc_id,
id=document.id, id=document.id,
access_hash=document.access_hash, access_hash=document.access_hash,
version=document.version, version=document.version,
size=document.size, size=document.size,
progress=progress, progress=progress
file_out=file_name
) )
elif isinstance(media, (types.MessageMediaPhoto, types.Photo)): elif isinstance(media, (types.MessageMediaPhoto, types.Photo)):
if isinstance(media, types.MessageMediaPhoto): if isinstance(media, types.MessageMediaPhoto):
@ -567,27 +571,46 @@ class Client:
self.rnd_id() self.rnd_id()
) )
file_name = os.path.join(file_dir if file_dir is not None else '', file_name)
photo_loc = photo.sizes[-1].location photo_loc = photo.sizes[-1].location
self.get_file( tmp_file_name = self.get_file(
dc_id=photo_loc.dc_id, dc_id=photo_loc.dc_id,
volume_id=photo_loc.volume_id, volume_id=photo_loc.volume_id,
local_id=photo_loc.local_id, local_id=photo_loc.local_id,
secret=photo_loc.secret, secret=photo_loc.secret,
size=photo.sizes[-1].size, size=photo.sizes[-1].size,
progress=progress, progress=progress
file_out=file_name
) )
if file_name is not None: if file_name is not None:
path[0] = file_name path[0] = os.path.join(download_directory, file_name)
try:
os.remove(os.path.join(download_directory, file_name))
except OSError:
pass
finally:
try:
if download_directory:
os.makedirs(download_directory, exist_ok=True)
else:
os.makedirs(os.path.dirname(file_name), exist_ok=True)
# avoid errors moving between drives on windows
shutil.move(tmp_file_name, os.path.join(download_directory, file_name))
except OSError as e:
log.error(e, exc_info=True)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
finally: finally:
done.set() done.set()
try:
os.remove(tmp_file_name)
except OSError as e:
if not e.errno == errno.ENOENT:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
def updates_worker(self): def updates_worker(self):
@ -2176,9 +2199,7 @@ class Client:
secret: int = None, secret: int = None,
version: int = 0, version: int = 0,
size: int = None, size: int = None,
progress: callable = None, progress: callable = None) -> str:
file_out: Any = None) -> str:
if dc_id != self.dc_id: if dc_id != self.dc_id:
exported_auth = self.send( exported_auth = self.send(
functions.auth.ExportAuthorization( functions.auth.ExportAuthorization(
@ -2226,13 +2247,11 @@ class Client:
version=version version=version
) )
fd, file_name = tempfile.mkstemp()
limit = 1024 * 1024 limit = 1024 * 1024
offset = 0 offset = 0
# file object being written
f = None
close_file, call_flush, call_fsync = False, False, False
try: try:
r = session.send( r = session.send(
functions.upload.GetFile( functions.upload.GetFile(
@ -2242,51 +2261,30 @@ class Client:
) )
) )
if file_out is None:
f = open("download_{}.temp".format(MsgId(), 'wb'))
close_file = True
elif isinstance(file_out, str):
f = open(file_out, 'wb')
close_file = True
elif hasattr(file_out, 'write'):
f = file_out
if hasattr(file_out, 'flush'):
call_flush = True
if hasattr(file_out, 'fileno'):
call_fsync = True
else:
raise ValueError('file_out argument of client.get_file must at least implement a write method if not a '
'string.')
if isinstance(r, types.upload.File): if isinstance(r, types.upload.File):
while True: with os.fdopen(fd, "wb") as f:
chunk = r.bytes while True:
chunk = r.bytes
if not chunk: if not chunk:
break break
f.write(chunk) f.write(chunk)
if call_flush:
f.flush() f.flush()
if call_fsync:
os.fsync(f.fileno()) os.fsync(f.fileno())
offset += limit offset += limit
if progress: if progress:
progress(min(offset, size), size) progress(min(offset, size), size)
r = session.send( r = session.send(
functions.upload.GetFile( functions.upload.GetFile(
location=location, location=location,
offset=offset, offset=offset,
limit=limit limit=limit
)
) )
)
if isinstance(r, types.upload.FileCdnRedirect): if isinstance(r, types.upload.FileCdnRedirect):
cdn_session = Session( cdn_session = Session(
@ -2301,76 +2299,77 @@ class Client:
cdn_session.start() cdn_session.start()
try: try:
while True: with os.fdopen(fd, "wb") as f:
r2 = cdn_session.send( while True:
functions.upload.GetCdnFile( r2 = cdn_session.send(
location=location, functions.upload.GetCdnFile(
file_token=r.file_token, location=location,
offset=offset, file_token=r.file_token,
limit=limit offset=offset,
) limit=limit
)
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
) )
except VolumeLocNotFound: )
break
else:
continue
chunk = r2.bytes if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
)
except VolumeLocNotFound:
break
else:
continue
# https://core.telegram.org/cdn#decrypting-files chunk = r2.bytes
decrypted_chunk = AES.ctr_decrypt(
chunk,
r.encryption_key,
r.encryption_iv,
offset
)
hashes = session.send( # https://core.telegram.org/cdn#decrypting-files
functions.upload.GetCdnFileHashes( decrypted_chunk = AES.ctr_decrypt(
r.file_token, chunk,
r.encryption_key,
r.encryption_iv,
offset offset
) )
)
# https://core.telegram.org/cdn#verifying-files hashes = session.send(
for i, h in enumerate(hashes): functions.upload.GetCdnFileHashes(
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] r.file_token,
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) offset
)
)
f.write(decrypted_chunk) # https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
if call_flush: f.write(decrypted_chunk)
f.flush() f.flush()
if call_fsync:
os.fsync(f.fileno()) os.fsync(f.fileno())
offset += limit offset += limit
if progress: if progress:
progress(min(offset, size), size) progress(min(offset, size), size)
if len(chunk) < limit: if len(chunk) < limit:
break break
except Exception as e: except Exception as e:
log.error(e) raise e
finally: finally:
cdn_session.stop() cdn_session.stop()
except Exception as e: except Exception as e:
log.error(e) log.error(e, exc_info=True)
try:
os.remove(file_name)
except OSError:
pass
else: else:
return file_out return file_name
finally: finally:
if close_file and f is not None:
f.close()
session.stop() session.stop()
def join_chat(self, chat_id: str): def join_chat(self, chat_id: str):
@ -2627,27 +2626,18 @@ class Client:
def download_media(self, def download_media(self,
message: types.Message, message: types.Message,
file_name: str = None, file_name: str = None,
file_dir: str = None,
block: bool = True, block: bool = True,
progress: callable = None progress: callable = None):
):
"""Use this method to download the media from a Message. """Use this method to download the media from a Message.
Files are saved in the *downloads* folder.
Args: Args:
message (:obj:`Message <pyrogram.api.types.Message>`): message (:obj:`Message <pyrogram.api.types.Message>`):
The Message containing the media. The Message containing the media.
file_name (:obj:`str`, optional): file_name (:obj:`str`, optional):
Specify a custom *file_name* to be used instead of the one provided by Telegram. Specify a custom *file_name* to be used instead of the one provided by Telegram.
This parameter is expected to be a full file path to the location you want the
file to be placed, or a file like object. If not specified, the file will
be put into the directory specified by *file_dir* with a generated name.
file_dir (:obj:`str`, optional):
Specify a directory to place the file in if no *file_name* is specified.
If *file_dir* is *None*, the current working directory is used. The default
value is the "downloads" folder in the current working directory. The
directory tree will be created if it does not exist.
block (:obj:`bool`, optional): block (:obj:`bool`, optional):
Blocks the code execution until the file has been downloaded. Blocks the code execution until the file has been downloaded.
@ -2669,15 +2659,7 @@ class Client:
Raises: Raises:
:class:`pyrogram.Error` :class:`pyrogram.Error`
:class:`ValueError` if both file_name and file_dir are specified.
""" """
if file_name is not None and file_dir is not None:
raise ValueError('file_name and file_dir may not be specified together.')
if file_name is None and file_dir is None:
file_dir = 'downloads'
if isinstance(message, (types.Message, types.Photo)): if isinstance(message, (types.Message, types.Photo)):
done = Event() done = Event()
path = [None] path = [None]
@ -2688,7 +2670,7 @@ class Client:
media = message media = message
if media is not None: if media is not None:
self.download_queue.put((media, file_dir, file_name, done, progress, path)) self.download_queue.put((media, file_name, done, progress, path))
else: else:
return return
@ -2700,7 +2682,6 @@ class Client:
def download_photo(self, def download_photo(self,
photo: types.Photo or types.UserProfilePhoto or types.ChatPhoto, photo: types.Photo or types.UserProfilePhoto or types.ChatPhoto,
file_name: str = None, file_name: str = None,
file_dir: str = None,
block: bool = True): block: bool = True):
"""Use this method to download a photo not contained inside a Message. """Use this method to download a photo not contained inside a Message.
For example, a photo of a User or a Chat/Channel. For example, a photo of a User or a Chat/Channel.
@ -2712,16 +2693,7 @@ class Client:
The photo object. The photo object.
file_name (:obj:`str`, optional): file_name (:obj:`str`, optional):
Specify a custom *file_name* to be used instead of the one provided by Telegram. Specify a custom *file_name* to be used.
This parameter is expected to be a full file path to the location you want the
photo to be placed, or a file like object. If not specified, the photo will
be put into the directory specified by *file_dir* with a generated name.
file_dir (:obj:`str`, optional):
Specify a directory to place the photo in if no *file_name* is specified.
If *file_dir* is *None*, the current working directory is used. The default
value is the "downloads" folder in the current working directory. The
directory tree will be created if it does not exist.
block (:obj:`bool`, optional): block (:obj:`bool`, optional):
Blocks the code execution until the photo has been downloaded. Blocks the code execution until the photo has been downloaded.
@ -2747,7 +2719,7 @@ class Client:
)] )]
) )
return self.download_media(photo, file_name, file_dir, block) return self.download_media(photo, file_name, block)
def add_contacts(self, contacts: list): def add_contacts(self, contacts: list):
"""Use this method to add contacts to your Telegram address book. """Use this method to add contacts to your Telegram address book.