2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-27 20:37:54 +00:00

Use OS temp file, specific path download via path seperator inspection

This commit is contained in:
Eric Blundell 2018-03-20 15:20:04 -05:00
parent 62831001b7
commit 5bc10b45a3

View File

@ -34,6 +34,11 @@ from hashlib import sha256, md5
from queue import Queue
from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Event, Thread
import tempfile
import shutil
import errno
from pyrogram.api import functions, types
from pyrogram.api.core import Object
@ -55,8 +60,6 @@ from pyrogram.session.internals import MsgId
from .input_media import InputMedia
from .style import Markdown, HTML
from typing import Any
log = logging.getLogger(__name__)
ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"])
@ -510,14 +513,18 @@ class Client:
break
try:
media, file_dir, file_name, done, progress, path = media
media, file_name, done, progress, path = media
tmp_file_name = None
if file_dir is not None:
# Make file_dir if it was specified
os.makedirs(file_dir, exist_ok=True)
download_directory = "downloads"
if isinstance(file_name, str) and file_name is not None:
os.makedirs(os.path.dirname(file_name), exist_ok=True)
if file_name.endswith('/') or file_name.endswith('\\'):
# treat the file name as a directory
download_directory = file_name
file_name = None
elif '/' in file_name or '\\' in file_name:
# use file_name as a full path instead
download_directory = ''
if isinstance(media, types.MessageMediaDocument):
document = media.document
@ -543,16 +550,13 @@ class Client:
elif isinstance(i, types.DocumentAttributeAnimated):
file_name = file_name.replace("doc", "gif")
file_name = os.path.join(file_dir if file_dir is not None else '', file_name)
self.get_file(
tmp_file_name = self.get_file(
dc_id=document.dc_id,
id=document.id,
access_hash=document.access_hash,
version=document.version,
size=document.size,
progress=progress,
file_out=file_name
progress=progress
)
elif isinstance(media, (types.MessageMediaPhoto, types.Photo)):
if isinstance(media, types.MessageMediaPhoto):
@ -567,27 +571,46 @@ class Client:
self.rnd_id()
)
file_name = os.path.join(file_dir if file_dir is not None else '', file_name)
photo_loc = photo.sizes[-1].location
self.get_file(
tmp_file_name = self.get_file(
dc_id=photo_loc.dc_id,
volume_id=photo_loc.volume_id,
local_id=photo_loc.local_id,
secret=photo_loc.secret,
size=photo.sizes[-1].size,
progress=progress,
file_out=file_name
progress=progress
)
if file_name is not None:
path[0] = file_name
path[0] = os.path.join(download_directory, file_name)
try:
os.remove(os.path.join(download_directory, file_name))
except OSError:
pass
finally:
try:
if download_directory:
os.makedirs(download_directory, exist_ok=True)
else:
os.makedirs(os.path.dirname(file_name), exist_ok=True)
# avoid errors moving between drives on windows
shutil.move(tmp_file_name, os.path.join(download_directory, file_name))
except OSError as e:
log.error(e, exc_info=True)
except Exception as e:
log.error(e, exc_info=True)
finally:
done.set()
try:
os.remove(tmp_file_name)
except OSError as e:
if not e.errno == errno.ENOENT:
log.error(e, exc_info=True)
log.debug("{} stopped".format(name))
def updates_worker(self):
@ -2176,9 +2199,7 @@ class Client:
secret: int = None,
version: int = 0,
size: int = None,
progress: callable = None,
file_out: Any = None) -> str:
progress: callable = None) -> str:
if dc_id != self.dc_id:
exported_auth = self.send(
functions.auth.ExportAuthorization(
@ -2226,13 +2247,11 @@ class Client:
version=version
)
fd, file_name = tempfile.mkstemp()
limit = 1024 * 1024
offset = 0
# file object being written
f = None
close_file, call_flush, call_fsync = False, False, False
try:
r = session.send(
functions.upload.GetFile(
@ -2242,26 +2261,8 @@ class Client:
)
)
if file_out is None:
f = open("download_{}.temp".format(MsgId(), 'wb'))
close_file = True
elif isinstance(file_out, str):
f = open(file_out, 'wb')
close_file = True
elif hasattr(file_out, 'write'):
f = file_out
if hasattr(file_out, 'flush'):
call_flush = True
if hasattr(file_out, 'fileno'):
call_fsync = True
else:
raise ValueError('file_out argument of client.get_file must at least implement a write method if not a '
'string.')
if isinstance(r, types.upload.File):
with os.fdopen(fd, "wb") as f:
while True:
chunk = r.bytes
@ -2269,10 +2270,7 @@ class Client:
break
f.write(chunk)
if call_flush:
f.flush()
if call_fsync:
os.fsync(f.fileno())
offset += limit
@ -2301,6 +2299,7 @@ class Client:
cdn_session.start()
try:
with os.fdopen(fd, "wb") as f:
while True:
r2 = cdn_session.send(
functions.upload.GetCdnFile(
@ -2347,10 +2346,7 @@ class Client:
assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i)
f.write(decrypted_chunk)
if call_flush:
f.flush()
if call_fsync:
os.fsync(f.fileno())
offset += limit
@ -2361,16 +2357,19 @@ class Client:
if len(chunk) < limit:
break
except Exception as e:
log.error(e)
raise e
finally:
cdn_session.stop()
except Exception as e:
log.error(e)
log.error(e, exc_info=True)
try:
os.remove(file_name)
except OSError:
pass
else:
return file_out
return file_name
finally:
if close_file and f is not None:
f.close()
session.stop()
def join_chat(self, chat_id: str):
@ -2627,27 +2626,18 @@ class Client:
def download_media(self,
message: types.Message,
file_name: str = None,
file_dir: str = None,
block: bool = True,
progress: callable = None
):
progress: callable = None):
"""Use this method to download the media from a Message.
Files are saved in the *downloads* folder.
Args:
message (:obj:`Message <pyrogram.api.types.Message>`):
The Message containing the media.
file_name (:obj:`str`, optional):
Specify a custom *file_name* to be used instead of the one provided by Telegram.
This parameter is expected to be a full file path to the location you want the
file to be placed, or a file like object. If not specified, the file will
be put into the directory specified by *file_dir* with a generated name.
file_dir (:obj:`str`, optional):
Specify a directory to place the file in if no *file_name* is specified.
If *file_dir* is *None*, the current working directory is used. The default
value is the "downloads" folder in the current working directory. The
directory tree will be created if it does not exist.
block (:obj:`bool`, optional):
Blocks the code execution until the file has been downloaded.
@ -2669,15 +2659,7 @@ class Client:
Raises:
:class:`pyrogram.Error`
:class:`ValueError` if both file_name and file_dir are specified.
"""
if file_name is not None and file_dir is not None:
raise ValueError('file_name and file_dir may not be specified together.')
if file_name is None and file_dir is None:
file_dir = 'downloads'
if isinstance(message, (types.Message, types.Photo)):
done = Event()
path = [None]
@ -2688,7 +2670,7 @@ class Client:
media = message
if media is not None:
self.download_queue.put((media, file_dir, file_name, done, progress, path))
self.download_queue.put((media, file_name, done, progress, path))
else:
return
@ -2700,7 +2682,6 @@ class Client:
def download_photo(self,
photo: types.Photo or types.UserProfilePhoto or types.ChatPhoto,
file_name: str = None,
file_dir: str = None,
block: bool = True):
"""Use this method to download a photo not contained inside a Message.
For example, a photo of a User or a Chat/Channel.
@ -2712,16 +2693,7 @@ class Client:
The photo object.
file_name (:obj:`str`, optional):
Specify a custom *file_name* to be used instead of the one provided by Telegram.
This parameter is expected to be a full file path to the location you want the
photo to be placed, or a file like object. If not specified, the photo will
be put into the directory specified by *file_dir* with a generated name.
file_dir (:obj:`str`, optional):
Specify a directory to place the photo in if no *file_name* is specified.
If *file_dir* is *None*, the current working directory is used. The default
value is the "downloads" folder in the current working directory. The
directory tree will be created if it does not exist.
Specify a custom *file_name* to be used.
block (:obj:`bool`, optional):
Blocks the code execution until the photo has been downloaded.
@ -2747,7 +2719,7 @@ class Client:
)]
)
return self.download_media(photo, file_name, file_dir, block)
return self.download_media(photo, file_name, block)
def add_contacts(self, contacts: list):
"""Use this method to add contacts to your Telegram address book.