2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-30 13:57:54 +00:00

Fix script executions not working outside the current directory

Fixes #41
This commit is contained in:
Dan
2019-06-15 23:02:31 +02:00
parent abc0e992cf
commit 80d8443be4
3 changed files with 47 additions and 38 deletions

View File

@@ -28,7 +28,6 @@ import tempfile
import threading import threading
import time import time
from configparser import ConfigParser from configparser import ConfigParser
from datetime import datetime
from hashlib import sha256, md5 from hashlib import sha256, md5
from importlib import import_module from importlib import import_module
from pathlib import Path from pathlib import Path
@@ -842,39 +841,7 @@ class Client(Methods, BaseClient):
final_file_path = "" final_file_path = ""
try: try:
data, file_name, done, progress, progress_args, path = packet data, directory, file_name, done, progress, progress_args, path = packet
directory, file_name = os.path.split(file_name)
directory = directory or "downloads"
media_type_str = Client.MEDIA_TYPE_ID[data.media_type]
file_name = file_name or data.file_name
if not file_name:
guessed_extension = self.guess_extension(data.mime_type)
if data.media_type in (0, 1, 2, 14):
extension = ".jpg"
elif data.media_type == 3:
extension = guessed_extension or ".ogg"
elif data.media_type in (4, 10, 13):
extension = guessed_extension or ".mp4"
elif data.media_type == 5:
extension = guessed_extension or ".zip"
elif data.media_type == 8:
extension = guessed_extension or ".webp"
elif data.media_type == 9:
extension = guessed_extension or ".mp3"
else:
continue
file_name = "{}_{}_{}{}".format(
media_type_str,
datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
)
temp_file_path = self.get_file( temp_file_path = self.get_file(
media_type=data.media_type, media_type=data.media_type,

View File

@@ -19,6 +19,8 @@
import os import os
import platform import platform
import re import re
import sys
from pathlib import Path
from queue import Queue from queue import Queue
from threading import Lock from threading import Lock
@@ -45,6 +47,8 @@ class BaseClient:
LANG_CODE = "en" LANG_CODE = "en"
PARENT_DIR = Path(sys.argv[0]).parent
INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$") INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$")
BOT_TOKEN_RE = re.compile(r"^\d+:[\w-]+$") BOT_TOKEN_RE = re.compile(r"^\d+:[\w-]+$")
DIALOGS_AT_ONCE = 100 DIALOGS_AT_ONCE = 100
@@ -52,8 +56,8 @@ class BaseClient:
DOWNLOAD_WORKERS = 1 DOWNLOAD_WORKERS = 1
OFFLINE_SLEEP = 900 OFFLINE_SLEEP = 900
WORKERS = 4 WORKERS = 4
WORKDIR = "." WORKDIR = PARENT_DIR
CONFIG_FILE = "./config.ini" CONFIG_FILE = PARENT_DIR / "config.ini"
MEDIA_TYPE_ID = { MEDIA_TYPE_ID = {
0: "photo_thumbnail", 0: "photo_thumbnail",

View File

@@ -17,7 +17,10 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import binascii import binascii
import os
import struct import struct
import time
from datetime import datetime
from threading import Event from threading import Event
from typing import Union from typing import Union
@@ -25,12 +28,14 @@ import pyrogram
from pyrogram.client.ext import BaseClient, FileData, utils from pyrogram.client.ext import BaseClient, FileData, utils
from pyrogram.errors import FileIdInvalid from pyrogram.errors import FileIdInvalid
DEFAULT_DOWNLOAD_DIR = "downloads/"
class DownloadMedia(BaseClient): class DownloadMedia(BaseClient):
def download_media( def download_media(
self, self,
message: Union["pyrogram.Message", str], message: Union["pyrogram.Message", str],
file_name: str = "", file_name: str = DEFAULT_DOWNLOAD_DIR,
block: bool = True, block: bool = True,
progress: callable = None, progress: callable = None,
progress_args: tuple = () progress_args: tuple = ()
@@ -169,7 +174,40 @@ class DownloadMedia(BaseClient):
done = Event() done = Event()
path = [None] path = [None]
self.download_queue.put((data, file_name, done, progress, progress_args, path)) directory, file_name = os.path.split(file_name)
file_name = file_name or data.file_name or ""
if not os.path.isabs(file_name):
directory = self.PARENT_DIR / (directory or DEFAULT_DOWNLOAD_DIR)
media_type_str = self.MEDIA_TYPE_ID[data.media_type]
if not file_name:
guessed_extension = self.guess_extension(data.mime_type)
if data.media_type in (0, 1, 2, 14):
extension = ".jpg"
elif data.media_type == 3:
extension = guessed_extension or ".ogg"
elif data.media_type in (4, 10, 13):
extension = guessed_extension or ".mp4"
elif data.media_type == 5:
extension = guessed_extension or ".zip"
elif data.media_type == 8:
extension = guessed_extension or ".webp"
elif data.media_type == 9:
extension = guessed_extension or ".mp3"
else:
extension = ".unknown"
file_name = "{}_{}_{}{}".format(
media_type_str,
datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
)
self.download_queue.put((data, directory, file_name, done, progress, progress_args, path))
if block: if block:
done.wait() done.wait()