mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-29 05:18:10 +00:00
Allow download_media to download media to anywhere
Remove the use of a temporary file in the programs working directory.
This commit is contained in:
parent
2fd7cd0054
commit
19b1bbb942
@ -55,6 +55,8 @@ from pyrogram.session.internals import MsgId
|
|||||||
from .input_media import InputMedia
|
from .input_media import InputMedia
|
||||||
from .style import Markdown, HTML
|
from .style import Markdown, HTML
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"])
|
ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"])
|
||||||
@ -509,7 +511,6 @@ class Client:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
media, file_name, done, progress, path = media
|
media, file_name, done, progress, path = media
|
||||||
tmp_file_name = None
|
|
||||||
|
|
||||||
if isinstance(media, types.MessageMediaDocument):
|
if isinstance(media, types.MessageMediaDocument):
|
||||||
document = media.document
|
document = media.document
|
||||||
@ -535,13 +536,14 @@ class Client:
|
|||||||
elif isinstance(i, types.DocumentAttributeAnimated):
|
elif isinstance(i, types.DocumentAttributeAnimated):
|
||||||
file_name = file_name.replace("doc", "gif")
|
file_name = file_name.replace("doc", "gif")
|
||||||
|
|
||||||
tmp_file_name = self.get_file(
|
self.get_file(
|
||||||
dc_id=document.dc_id,
|
dc_id=document.dc_id,
|
||||||
id=document.id,
|
id=document.id,
|
||||||
access_hash=document.access_hash,
|
access_hash=document.access_hash,
|
||||||
version=document.version,
|
version=document.version,
|
||||||
size=document.size,
|
size=document.size,
|
||||||
progress=progress
|
progress=progress,
|
||||||
|
file_out=file_name
|
||||||
)
|
)
|
||||||
elif isinstance(media, (types.MessageMediaPhoto, types.Photo)):
|
elif isinstance(media, (types.MessageMediaPhoto, types.Photo)):
|
||||||
if isinstance(media, types.MessageMediaPhoto):
|
if isinstance(media, types.MessageMediaPhoto):
|
||||||
@ -558,37 +560,23 @@ class Client:
|
|||||||
|
|
||||||
photo_loc = photo.sizes[-1].location
|
photo_loc = photo.sizes[-1].location
|
||||||
|
|
||||||
tmp_file_name = self.get_file(
|
self.get_file(
|
||||||
dc_id=photo_loc.dc_id,
|
dc_id=photo_loc.dc_id,
|
||||||
volume_id=photo_loc.volume_id,
|
volume_id=photo_loc.volume_id,
|
||||||
local_id=photo_loc.local_id,
|
local_id=photo_loc.local_id,
|
||||||
secret=photo_loc.secret,
|
secret=photo_loc.secret,
|
||||||
size=photo.sizes[-1].size,
|
size=photo.sizes[-1].size,
|
||||||
progress=progress
|
progress=progress,
|
||||||
|
file_out=file_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_name is not None:
|
if file_name is not None:
|
||||||
path[0] = "downloads/{}".format(file_name)
|
path[0] = file_name
|
||||||
|
|
||||||
try:
|
|
||||||
os.remove("downloads/{}".format(file_name))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
os.renames("{}".format(tmp_file_name), "downloads/{}".format(file_name))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e, exc_info=True)
|
log.error(e, exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
try:
|
|
||||||
os.remove("{}".format(tmp_file_name))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
log.debug("{} stopped".format(name))
|
log.debug("{} stopped".format(name))
|
||||||
|
|
||||||
def updates_worker(self):
|
def updates_worker(self):
|
||||||
@ -2177,7 +2165,9 @@ class Client:
|
|||||||
secret: int = None,
|
secret: int = None,
|
||||||
version: int = 0,
|
version: int = 0,
|
||||||
size: int = None,
|
size: int = None,
|
||||||
progress: callable = None) -> str:
|
progress: callable = None,
|
||||||
|
file_out: Any = None) -> str:
|
||||||
|
|
||||||
if dc_id != self.dc_id:
|
if dc_id != self.dc_id:
|
||||||
exported_auth = self.send(
|
exported_auth = self.send(
|
||||||
functions.auth.ExportAuthorization(
|
functions.auth.ExportAuthorization(
|
||||||
@ -2225,10 +2215,13 @@ class Client:
|
|||||||
version=version
|
version=version
|
||||||
)
|
)
|
||||||
|
|
||||||
file_name = "download_{}.temp".format(MsgId())
|
|
||||||
limit = 1024 * 1024
|
limit = 1024 * 1024
|
||||||
offset = 0
|
offset = 0
|
||||||
|
|
||||||
|
# file object being written
|
||||||
|
f = None
|
||||||
|
close_file, call_flush, call_fsync = False, False, False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = session.send(
|
r = session.send(
|
||||||
functions.upload.GetFile(
|
functions.upload.GetFile(
|
||||||
@ -2238,30 +2231,49 @@ class Client:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if file_out is None:
|
||||||
|
f = open("download_{}.temp".format(MsgId(), 'wb'))
|
||||||
|
close_file = True
|
||||||
|
|
||||||
|
elif isinstance(file_out, str):
|
||||||
|
f = open(file_out, 'wb')
|
||||||
|
elif hasattr(file_out, 'write'):
|
||||||
|
f = file_out
|
||||||
|
|
||||||
|
if hasattr(file_out, 'flush'):
|
||||||
|
call_flush = True
|
||||||
|
if hasattr(file_out, 'fileno'):
|
||||||
|
call_fsync = True
|
||||||
|
else:
|
||||||
|
raise ValueError('file_out argument of client.get_file must at least implement a write method if not a '
|
||||||
|
'string.')
|
||||||
|
|
||||||
if isinstance(r, types.upload.File):
|
if isinstance(r, types.upload.File):
|
||||||
with open(file_name, "wb") as f:
|
while True:
|
||||||
while True:
|
chunk = r.bytes
|
||||||
chunk = r.bytes
|
|
||||||
|
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
|
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
|
|
||||||
|
if call_flush:
|
||||||
f.flush()
|
f.flush()
|
||||||
|
if call_fsync:
|
||||||
os.fsync(f.fileno())
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
offset += limit
|
offset += limit
|
||||||
|
|
||||||
if progress:
|
if progress:
|
||||||
progress(min(offset, size), size)
|
progress(min(offset, size), size)
|
||||||
|
|
||||||
r = session.send(
|
r = session.send(
|
||||||
functions.upload.GetFile(
|
functions.upload.GetFile(
|
||||||
location=location,
|
location=location,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(r, types.upload.FileCdnRedirect):
|
if isinstance(r, types.upload.FileCdnRedirect):
|
||||||
cdn_session = Session(
|
cdn_session = Session(
|
||||||
@ -2276,63 +2288,65 @@ class Client:
|
|||||||
cdn_session.start()
|
cdn_session.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(file_name, "wb") as f:
|
while True:
|
||||||
while True:
|
r2 = cdn_session.send(
|
||||||
r2 = cdn_session.send(
|
functions.upload.GetCdnFile(
|
||||||
functions.upload.GetCdnFile(
|
location=location,
|
||||||
location=location,
|
file_token=r.file_token,
|
||||||
file_token=r.file_token,
|
offset=offset,
|
||||||
offset=offset,
|
limit=limit
|
||||||
limit=limit
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
|
||||||
try:
|
try:
|
||||||
session.send(
|
session.send(
|
||||||
functions.upload.ReuploadCdnFile(
|
functions.upload.ReuploadCdnFile(
|
||||||
file_token=r.file_token,
|
file_token=r.file_token,
|
||||||
request_token=r2.request_token
|
request_token=r2.request_token
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except VolumeLocNotFound:
|
)
|
||||||
break
|
except VolumeLocNotFound:
|
||||||
else:
|
break
|
||||||
continue
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
chunk = r2.bytes
|
chunk = r2.bytes
|
||||||
|
|
||||||
# https://core.telegram.org/cdn#decrypting-files
|
# https://core.telegram.org/cdn#decrypting-files
|
||||||
decrypted_chunk = AES.ctr_decrypt(
|
decrypted_chunk = AES.ctr_decrypt(
|
||||||
chunk,
|
chunk,
|
||||||
r.encryption_key,
|
r.encryption_key,
|
||||||
r.encryption_iv,
|
r.encryption_iv,
|
||||||
|
offset
|
||||||
|
)
|
||||||
|
|
||||||
|
hashes = session.send(
|
||||||
|
functions.upload.GetCdnFileHashes(
|
||||||
|
r.file_token,
|
||||||
offset
|
offset
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
hashes = session.send(
|
# https://core.telegram.org/cdn#verifying-files
|
||||||
functions.upload.GetCdnFileHashes(
|
for i, h in enumerate(hashes):
|
||||||
r.file_token,
|
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||||
offset
|
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://core.telegram.org/cdn#verifying-files
|
f.write(decrypted_chunk)
|
||||||
for i, h in enumerate(hashes):
|
|
||||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
|
||||||
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
|
|
||||||
|
|
||||||
f.write(decrypted_chunk)
|
if call_flush:
|
||||||
f.flush()
|
f.flush()
|
||||||
|
if call_fsync:
|
||||||
os.fsync(f.fileno())
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
offset += limit
|
offset += limit
|
||||||
|
|
||||||
if progress:
|
if progress:
|
||||||
progress(min(offset, size), size)
|
progress(min(offset, size), size)
|
||||||
|
|
||||||
if len(chunk) < limit:
|
if len(chunk) < limit:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e)
|
log.error(e)
|
||||||
finally:
|
finally:
|
||||||
@ -2340,8 +2354,10 @@ class Client:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e)
|
log.error(e)
|
||||||
else:
|
else:
|
||||||
return file_name
|
return file_out
|
||||||
finally:
|
finally:
|
||||||
|
if close_file and f and hasattr(f, 'close'):
|
||||||
|
f.close()
|
||||||
session.stop()
|
session.stop()
|
||||||
|
|
||||||
def join_chat(self, chat_id: str):
|
def join_chat(self, chat_id: str):
|
||||||
@ -2602,8 +2618,6 @@ class Client:
|
|||||||
progress: callable = None):
|
progress: callable = None):
|
||||||
"""Use this method to download the media from a Message.
|
"""Use this method to download the media from a Message.
|
||||||
|
|
||||||
Files are saved in the *downloads* folder.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message (:obj:`Message <pyrogram.api.types.Message>`):
|
message (:obj:`Message <pyrogram.api.types.Message>`):
|
||||||
The Message containing the media.
|
The Message containing the media.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user