diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index a1033629..d3a5850f 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -152,7 +152,8 @@ class Client: force_sms: bool = False, first_name: str = None, last_name: str = None, - workers: int = 4): + workers: int = 4, + workdir: str = "."): self.session_name = session_name self.api_id = int(api_id) if api_id else None self.api_hash = api_hash @@ -167,6 +168,7 @@ class Client: self.force_sms = force_sms self.workers = workers + self.workdir = workdir self.token = None @@ -882,7 +884,7 @@ class Client: def load_session(self): try: - with open("{}.session".format(self.session_name), encoding="utf-8") as f: + with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f: s = json.load(f) except FileNotFoundError: self.dc_id = 1 @@ -914,7 +916,9 @@ class Client: auth_key = base64.b64encode(self.auth_key).decode() auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] - with open("{}.session".format(self.session_name), "w", encoding="utf-8") as f: + os.makedirs(self.workdir, exist_ok=True) + + with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), "w", encoding="utf-8") as f: json.dump( dict( dc_id=self.dc_id, diff --git a/pyrogram/client/syncer.py b/pyrogram/client/syncer.py index e92f4084..125c5ce0 100644 --- a/pyrogram/client/syncer.py +++ b/pyrogram/client/syncer.py @@ -23,6 +23,7 @@ import os import shutil import time from threading import Thread, Event, Lock + from . import utils log = logging.getLogger(__name__) @@ -80,6 +81,9 @@ class Syncer: @classmethod def sync(cls, client): + temporary = os.path.join(client.workdir, "{}.sync".format(client.session_name)) + persistent = os.path.join(client.workdir, "{}.session".format(client.session_name)) + try: auth_key = base64.b64encode(client.auth_key).decode() auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] @@ -104,7 +108,9 @@ class Syncer: } ) - with open("{}.sync".format(client.session_name), "w", encoding="utf-8") as f: + os.makedirs(client.workdir, exist_ok=True) + + with open(temporary, "w", encoding="utf-8") as f: json.dump(data, f, indent=4) f.flush() @@ -112,14 +118,10 @@ class Syncer: except Exception as e: log.critical(e, exc_info=True) else: - shutil.move( - "{}.sync".format(client.session_name), - "{}.session".format(client.session_name) - ) - + shutil.move(temporary, persistent) log.info("Synced {}".format(client.session_name)) finally: try: - os.remove("{}.sync".format(client.session_name)) + os.remove(temporary) except OSError: pass