2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-29 05:18:10 +00:00

Allow download_media to download media to anywhere

Remove the use of a temporary file in the programs
working directory.
This commit is contained in:
Eric Blundell 2018-03-20 07:04:35 -05:00
parent 2fd7cd0054
commit 19b1bbb942

View File

@ -55,6 +55,8 @@ from pyrogram.session.internals import MsgId
from .input_media import InputMedia from .input_media import InputMedia
from .style import Markdown, HTML from .style import Markdown, HTML
from typing import Any
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"]) ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"])
@ -509,7 +511,6 @@ class Client:
try: try:
media, file_name, done, progress, path = media media, file_name, done, progress, path = media
tmp_file_name = None
if isinstance(media, types.MessageMediaDocument): if isinstance(media, types.MessageMediaDocument):
document = media.document document = media.document
@ -535,13 +536,14 @@ class Client:
elif isinstance(i, types.DocumentAttributeAnimated): elif isinstance(i, types.DocumentAttributeAnimated):
file_name = file_name.replace("doc", "gif") file_name = file_name.replace("doc", "gif")
tmp_file_name = self.get_file( self.get_file(
dc_id=document.dc_id, dc_id=document.dc_id,
id=document.id, id=document.id,
access_hash=document.access_hash, access_hash=document.access_hash,
version=document.version, version=document.version,
size=document.size, size=document.size,
progress=progress progress=progress,
file_out=file_name
) )
elif isinstance(media, (types.MessageMediaPhoto, types.Photo)): elif isinstance(media, (types.MessageMediaPhoto, types.Photo)):
if isinstance(media, types.MessageMediaPhoto): if isinstance(media, types.MessageMediaPhoto):
@ -558,37 +560,23 @@ class Client:
photo_loc = photo.sizes[-1].location photo_loc = photo.sizes[-1].location
tmp_file_name = self.get_file( self.get_file(
dc_id=photo_loc.dc_id, dc_id=photo_loc.dc_id,
volume_id=photo_loc.volume_id, volume_id=photo_loc.volume_id,
local_id=photo_loc.local_id, local_id=photo_loc.local_id,
secret=photo_loc.secret, secret=photo_loc.secret,
size=photo.sizes[-1].size, size=photo.sizes[-1].size,
progress=progress progress=progress,
file_out=file_name
) )
if file_name is not None: if file_name is not None:
path[0] = "downloads/{}".format(file_name) path[0] = file_name
try:
os.remove("downloads/{}".format(file_name))
except OSError:
pass
finally:
try:
os.renames("{}".format(tmp_file_name), "downloads/{}".format(file_name))
except OSError:
pass
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
finally: finally:
done.set() done.set()
try:
os.remove("{}".format(tmp_file_name))
except OSError:
pass
log.debug("{} stopped".format(name)) log.debug("{} stopped".format(name))
def updates_worker(self): def updates_worker(self):
@ -2177,7 +2165,9 @@ class Client:
secret: int = None, secret: int = None,
version: int = 0, version: int = 0,
size: int = None, size: int = None,
progress: callable = None) -> str: progress: callable = None,
file_out: Any = None) -> str:
if dc_id != self.dc_id: if dc_id != self.dc_id:
exported_auth = self.send( exported_auth = self.send(
functions.auth.ExportAuthorization( functions.auth.ExportAuthorization(
@ -2225,10 +2215,13 @@ class Client:
version=version version=version
) )
file_name = "download_{}.temp".format(MsgId())
limit = 1024 * 1024 limit = 1024 * 1024
offset = 0 offset = 0
# file object being written
f = None
close_file, call_flush, call_fsync = False, False, False
try: try:
r = session.send( r = session.send(
functions.upload.GetFile( functions.upload.GetFile(
@ -2238,30 +2231,49 @@ class Client:
) )
) )
if file_out is None:
f = open("download_{}.temp".format(MsgId(), 'wb'))
close_file = True
elif isinstance(file_out, str):
f = open(file_out, 'wb')
elif hasattr(file_out, 'write'):
f = file_out
if hasattr(file_out, 'flush'):
call_flush = True
if hasattr(file_out, 'fileno'):
call_fsync = True
else:
raise ValueError('file_out argument of client.get_file must at least implement a write method if not a '
'string.')
if isinstance(r, types.upload.File): if isinstance(r, types.upload.File):
with open(file_name, "wb") as f: while True:
while True: chunk = r.bytes
chunk = r.bytes
if not chunk: if not chunk:
break break
f.write(chunk) f.write(chunk)
if call_flush:
f.flush() f.flush()
if call_fsync:
os.fsync(f.fileno()) os.fsync(f.fileno())
offset += limit offset += limit
if progress: if progress:
progress(min(offset, size), size) progress(min(offset, size), size)
r = session.send( r = session.send(
functions.upload.GetFile( functions.upload.GetFile(
location=location, location=location,
offset=offset, offset=offset,
limit=limit limit=limit
)
) )
)
if isinstance(r, types.upload.FileCdnRedirect): if isinstance(r, types.upload.FileCdnRedirect):
cdn_session = Session( cdn_session = Session(
@ -2276,63 +2288,65 @@ class Client:
cdn_session.start() cdn_session.start()
try: try:
with open(file_name, "wb") as f: while True:
while True: r2 = cdn_session.send(
r2 = cdn_session.send( functions.upload.GetCdnFile(
functions.upload.GetCdnFile( location=location,
location=location, file_token=r.file_token,
file_token=r.file_token, offset=offset,
offset=offset, limit=limit
limit=limit
)
) )
)
if isinstance(r2, types.upload.CdnFileReuploadNeeded): if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try: try:
session.send( session.send(
functions.upload.ReuploadCdnFile( functions.upload.ReuploadCdnFile(
file_token=r.file_token, file_token=r.file_token,
request_token=r2.request_token request_token=r2.request_token
)
) )
except VolumeLocNotFound: )
break except VolumeLocNotFound:
else: break
continue else:
continue
chunk = r2.bytes chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files # https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = AES.ctr_decrypt( decrypted_chunk = AES.ctr_decrypt(
chunk, chunk,
r.encryption_key, r.encryption_key,
r.encryption_iv, r.encryption_iv,
offset
)
hashes = session.send(
functions.upload.GetCdnFileHashes(
r.file_token,
offset offset
) )
)
hashes = session.send( # https://core.telegram.org/cdn#verifying-files
functions.upload.GetCdnFileHashes( for i, h in enumerate(hashes):
r.file_token, cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
offset assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
)
)
# https://core.telegram.org/cdn#verifying-files f.write(decrypted_chunk)
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
f.write(decrypted_chunk) if call_flush:
f.flush() f.flush()
if call_fsync:
os.fsync(f.fileno()) os.fsync(f.fileno())
offset += limit offset += limit
if progress: if progress:
progress(min(offset, size), size) progress(min(offset, size), size)
if len(chunk) < limit: if len(chunk) < limit:
break break
except Exception as e: except Exception as e:
log.error(e) log.error(e)
finally: finally:
@ -2340,8 +2354,10 @@ class Client:
except Exception as e: except Exception as e:
log.error(e) log.error(e)
else: else:
return file_name return file_out
finally: finally:
if close_file and f and hasattr(f, 'close'):
f.close()
session.stop() session.stop()
def join_chat(self, chat_id: str): def join_chat(self, chat_id: str):
@ -2602,8 +2618,6 @@ class Client:
progress: callable = None): progress: callable = None):
"""Use this method to download the media from a Message. """Use this method to download the media from a Message.
Files are saved in the *downloads* folder.
Args: Args:
message (:obj:`Message <pyrogram.api.types.Message>`): message (:obj:`Message <pyrogram.api.types.Message>`):
The Message containing the media. The Message containing the media.