From db50384808a3a5c747dbc56de5dc3590367d756d Mon Sep 17 00:00:00 2001 From: xaoyaoo Date: Fri, 22 Mar 2024 17:55:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E6=94=B9=E4=B8=BA=E5=85=B1=E7=94=A8=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5=EF=BC=8C=E9=99=8D=E4=BD=8E=E6=97=B6=E9=97=B4=E5=BC=80?= =?UTF-8?q?=E9=94=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pywxdump/__init__.py | 6 +- pywxdump/analyzer/__init__.py | 2 +- pywxdump/analyzer/export_chat.py | 431 +++++++++++++++---------------- pywxdump/analyzer/utils.py | 65 ++++- 4 files changed, 271 insertions(+), 233 deletions(-) diff --git a/pywxdump/__init__.py b/pywxdump/__init__.py index 0864d90..e2de77f 100644 --- a/pywxdump/__init__.py +++ b/pywxdump/__init__.py @@ -6,10 +6,10 @@ # Date: 2023/10/14 # ------------------------------------------------------------------------------- from .wx_info import BiasAddr, read_info, get_wechat_db, batch_decrypt, decrypt, get_core_db -from .wx_info import merge_copy_db, merge_msg_db, merge_media_msg_db, merge_db, decrypt_merge,merge_real_time_db +from .wx_info import merge_copy_db, merge_msg_db, merge_media_msg_db, merge_db, decrypt_merge, merge_real_time_db from .analyzer.db_parsing import read_img_dat, read_emoji, decompress_CompressContent, read_audio_buf, read_audio, \ parse_xml_string, read_BytesExtra -from .analyzer import export_csv,export_json +from .analyzer import export_csv, export_json, DBPool from .ui import app_show_chat, get_user_list, export from .server import start_falsk @@ -26,3 +26,5 @@ except: PYWXDUMP_ROOT_PATH = os.path.dirname(__file__) __version__ = "2.4.60" + +db_init = DBPool("DBPOOL_INIT") \ No newline at end of file diff --git a/pywxdump/analyzer/__init__.py b/pywxdump/analyzer/__init__.py index 6fe6e09..fdf84aa 100644 --- a/pywxdump/analyzer/__init__.py +++ b/pywxdump/analyzer/__init__.py @@ -9,4 +9,4 @@ from .db_parsing import read_img_dat, read_emoji, decompress_CompressContent, re parse_xml_string, read_BytesExtra from .export_chat import export_csv, get_contact_list, get_chatroom_list, get_msg_list, get_chat_count, export_json, \ get_all_chat_count -from .utils import get_type_name, get_name_typeid +from .utils import get_type_name, get_name_typeid,DBPool diff --git a/pywxdump/analyzer/export_chat.py b/pywxdump/analyzer/export_chat.py index 447ae0b..b1f3f52 100644 --- a/pywxdump/analyzer/export_chat.py +++ b/pywxdump/analyzer/export_chat.py @@ -20,7 +20,7 @@ import json import time from functools import wraps -from .utils import get_md5, attach_databases, execute_sql, get_type_name, match_BytesExtra +from .utils import get_md5, attach_databases, execute_sql, get_type_name, match_BytesExtra, DBPool from .db_parsing import parse_xml_string, decompress_CompressContent, read_BytesExtra @@ -31,24 +31,20 @@ def get_contact(MicroMsg_db_path, wx_id): :param wx_id: 微信id :return: 联系人信息 """ - db = sqlite3.connect(MicroMsg_db_path) - cursor = db.cursor() - # 获取username是wx_id的用户 - sql = ("SELECT A.UserName, A.NickName, A.Remark,A.Alias,A.Reserved6,B.bigHeadImgUrl " - "FROM Contact A,ContactHeadImgUrl B " - f"WHERE A.UserName = '{wx_id}' AND A.UserName = B.usrName " - "ORDER BY NickName ASC;") - cursor.execute(sql) - result = cursor.fetchone() - cursor.close() - db.close() - print('联系人信息:', result) - if not result: - print('居然没找到!') - print(wx_id) - return None - return {"username": result[0], "nickname": result[1], "remark": result[2], "account": result[3], - "describe": result[4], "headImgUrl": result[5]} + with DBPool(MicroMsg_db_path) as db: + # 获取username是wx_id的用户 + sql = ("SELECT A.UserName, A.NickName, A.Remark,A.Alias,A.Reserved6,B.bigHeadImgUrl " + "FROM Contact A,ContactHeadImgUrl B " + f"WHERE A.UserName = '{wx_id}' AND A.UserName = B.usrName " + "ORDER BY NickName ASC;") + result = execute_sql(db, sql) + print('联系人信息:', result) + if not result: + print('居然没找到!') + print(wx_id) + return None + return {"username": result[0], "nickname": result[1], "remark": result[2], "account": result[3], + "describe": result[4], "headImgUrl": result[5]} def get_contact_list(MicroMsg_db_path): @@ -59,24 +55,19 @@ def get_contact_list(MicroMsg_db_path): """ users = [] # 连接 MicroMsg.db 数据库,并执行查询 - db = sqlite3.connect(MicroMsg_db_path) - cursor = db.cursor() - sql = ("SELECT A.UserName, A.NickName, A.Remark,A.Alias,A.Reserved6,B.bigHeadImgUrl " - "FROM Contact A,ContactHeadImgUrl B " - "where UserName==usrName " - "ORDER BY NickName ASC;") - cursor.execute(sql) - result = cursor.fetchall() - - for row in result: - # 获取用户名、昵称、备注和聊天记录数量 - username, nickname, remark, Alias, describe, headImgUrl = row - users.append( - {"username": username, "nickname": nickname, "remark": remark, "account": Alias, "describe": describe, - "headImgUrl": headImgUrl}) - cursor.close() - db.close() - return users + with DBPool(MicroMsg_db_path) as db: + sql = ("SELECT A.UserName, A.NickName, A.Remark,A.Alias,A.Reserved6,B.bigHeadImgUrl " + "FROM Contact A,ContactHeadImgUrl B " + "where UserName==usrName " + "ORDER BY NickName ASC;") + result = execute_sql(db, sql) + for row in result: + # 获取用户名、昵称、备注和聊天记录数量 + username, nickname, remark, Alias, describe, headImgUrl = row + users.append( + {"username": username, "nickname": nickname, "remark": remark, "account": Alias, "describe": describe, + "headImgUrl": headImgUrl}) + return users def get_chatroom_list(MicroMsg_db_path): @@ -87,24 +78,21 @@ def get_chatroom_list(MicroMsg_db_path): """ rooms = [] # 连接 MicroMsg.db 数据库,并执行查询 - db = sqlite3.connect(MicroMsg_db_path) - - sql = ("SELECT A.ChatRoomName,A.UserNameList, A.DisplayNameList, B.Announcement,B.AnnouncementEditor " - "FROM ChatRoom A,ChatRoomInfo B " - "where A.ChatRoomName==B.ChatRoomName " - "ORDER BY A.ChatRoomName ASC;") - - result = execute_sql(db, sql) - db.close() - for row in result: - # 获取用户名、昵称、备注和聊天记录数量 - ChatRoomName, UserNameList, DisplayNameList, Announcement, AnnouncementEditor = row - UserNameList = UserNameList.split("^G") - DisplayNameList = DisplayNameList.split("^G") - rooms.append( - {"ChatRoomName": ChatRoomName, "UserNameList": UserNameList, "DisplayNameList": DisplayNameList, - "Announcement": Announcement, "AnnouncementEditor": AnnouncementEditor}) - return rooms + with DBPool(MicroMsg_db_path) as db: + sql = ("SELECT A.ChatRoomName,A.UserNameList, A.DisplayNameList, B.Announcement,B.AnnouncementEditor " + "FROM ChatRoom A,ChatRoomInfo B " + "where A.ChatRoomName==B.ChatRoomName " + "ORDER BY A.ChatRoomName ASC;") + result = execute_sql(db, sql) + for row in result: + # 获取用户名、昵称、备注和聊天记录数量 + ChatRoomName, UserNameList, DisplayNameList, Announcement, AnnouncementEditor = row + UserNameList = UserNameList.split("^G") + DisplayNameList = DisplayNameList.split("^G") + rooms.append( + {"ChatRoomName": ChatRoomName, "UserNameList": UserNameList, "DisplayNameList": DisplayNameList, + "Announcement": Announcement, "AnnouncementEditor": AnnouncementEditor}) + return rooms def get_room_user_list(MSG_db_path, selected_talker): @@ -116,36 +104,31 @@ def get_room_user_list(MSG_db_path, selected_talker): """ # 连接 MSG_ALL.db 数据库,并执行查询 - db1 = sqlite3.connect(MSG_db_path) - cursor1 = db1.cursor() + with DBPool(MSG_db_path) as db1: + sql = ( + "SELECT localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType,CreateTime,MsgSvrID,DisplayContent,CompressContent,BytesExtra,ROW_NUMBER() OVER (ORDER BY CreateTime ASC) AS id " + "FROM MSG WHERE StrTalker=? " + "ORDER BY CreateTime ASC") - sql = ( - "SELECT localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType,CreateTime,MsgSvrID,DisplayContent,CompressContent,BytesExtra,ROW_NUMBER() OVER (ORDER BY CreateTime ASC) AS id " - "FROM MSG WHERE StrTalker=? " - "ORDER BY CreateTime ASC") - - cursor1.execute(sql, (selected_talker,)) - result1 = cursor1.fetchall() - cursor1.close() - db1.close() - user_list = [] - read_user_wx_id = [] - for row in result1: - localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType, CreateTime, MsgSvrID, DisplayContent, CompressContent, BytesExtra, id = row - bytes_extra = read_BytesExtra(BytesExtra) - if bytes_extra: - try: - talker = bytes_extra['3'][0]['2'].decode('utf-8', errors='ignore') - except: + result1 = execute_sql(db1, sql, (selected_talker,)) + user_list = [] + read_user_wx_id = [] + for row in result1: + localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType, CreateTime, MsgSvrID, DisplayContent, CompressContent, BytesExtra, id = row + bytes_extra = read_BytesExtra(BytesExtra) + if bytes_extra: + try: + talker = bytes_extra['3'][0]['2'].decode('utf-8', errors='ignore') + except: + continue + if talker in read_user_wx_id: continue - if talker in read_user_wx_id: - continue - user = get_contact(MSG_db_path, talker) - if not user: - continue - user_list.append(user) - read_user_wx_id.append(talker) - return user_list + user = get_contact(MSG_db_path, talker) + if not user: + continue + user_list.append(user) + read_user_wx_id.append(talker) + return user_list def get_msg_list(MSG_db_path, selected_talker="", start_index=0, page_size=500): @@ -159,136 +142,132 @@ def get_msg_list(MSG_db_path, selected_talker="", start_index=0, page_size=500): """ # 连接 MSG_ALL.db 数据库,并执行查询 - db1 = sqlite3.connect(MSG_db_path) - cursor1 = db1.cursor() - if selected_talker: - sql = ( - "SELECT localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType,CreateTime,MsgSvrID,DisplayContent,CompressContent,BytesExtra,ROW_NUMBER() OVER (ORDER BY CreateTime ASC) AS id " - "FROM MSG WHERE StrTalker=? " - "ORDER BY CreateTime ASC LIMIT ?,?") - cursor1.execute(sql, (selected_talker, start_index, page_size)) - else: - sql = ( - "SELECT localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType,CreateTime,MsgSvrID,DisplayContent,CompressContent,BytesExtra,ROW_NUMBER() OVER (ORDER BY CreateTime ASC) AS id " - "FROM MSG ORDER BY CreateTime ASC LIMIT ?,?") - cursor1.execute(sql, (start_index, page_size)) - result1 = cursor1.fetchall() - cursor1.close() - db1.close() - - data = [] - for row in result1: - localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType, CreateTime, MsgSvrID, DisplayContent, CompressContent, BytesExtra, id = row - CreateTime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(CreateTime)) - - type_id = (Type, SubType) - type_name = get_type_name(type_id) - - content = {"src": "", "msg": StrContent} - - if type_id == (1, 0): # 文本 - content["msg"] = StrContent - - elif type_id == (3, 0): # 图片 - DictExtra = read_BytesExtra(BytesExtra) - DictExtra = str(DictExtra) - match = re.search(r"FileStorage(.*?)'", DictExtra) - if match: - img_path = match.group(0).replace("'", "") - img_path = [i for i in img_path.split("\\") if i] - img_path = os.path.join(*img_path) - content["src"] = img_path - else: - content["src"] = "" - content["msg"] = "图片" - elif type_id == (34, 0): - tmp_c = parse_xml_string(StrContent) - voicelength = tmp_c.get("voicemsg", {}).get("voicelength", "") - transtext = tmp_c.get("voicetrans", {}).get("transtext", "") - if voicelength.isdigit(): - voicelength = int(voicelength) / 1000 - voicelength = f"{voicelength:.2f}" - content[ - "msg"] = f"语音时长:{voicelength}秒\n翻译结果:{transtext}" if transtext else f"语音时长:{voicelength}秒" - content["src"] = os.path.join("audio", f"{StrTalker}", - f"{CreateTime.replace(':', '-').replace(' ', '_')}_{IsSender}_{MsgSvrID}.wav") - elif type_id == (43, 0): # 视频 - DictExtra = read_BytesExtra(BytesExtra) - DictExtra = str(DictExtra) - match = re.search(r"FileStorage(.*?)'", DictExtra) - if match: - video_path = match.group(0).replace("'", "") - content["src"] = video_path - else: - content["src"] = "" - content["msg"] = "视频" - - elif type_id == (47, 0): # 动画表情 - content_tmp = parse_xml_string(StrContent) - cdnurl = content_tmp.get("emoji", {}).get("cdnurl", "") - if cdnurl: - content = {"src": cdnurl, "msg": "表情"} - - elif type_id == (49, 0): - DictExtra = read_BytesExtra(BytesExtra) - url = match_BytesExtra(DictExtra) - content["src"] = url - file_name = os.path.basename(url) - content["msg"] = file_name - - elif type_id == (49, 19): # 合并转发的聊天记录 - CompressContent = decompress_CompressContent(CompressContent) - content_tmp = parse_xml_string(CompressContent) - title = content_tmp.get("appmsg", {}).get("title", "") - des = content_tmp.get("appmsg", {}).get("des", "") - recorditem = content_tmp.get("appmsg", {}).get("recorditem", "") - recorditem = parse_xml_string(recorditem) - content["msg"] = f"{title}\n{des}" - content["src"] = recorditem - - elif type_id == (49, 2000): # 转账消息 - CompressContent = decompress_CompressContent(CompressContent) - content_tmp = parse_xml_string(CompressContent) - feedesc = content_tmp.get("appmsg", {}).get("wcpayinfo", {}).get("feedesc", "") - content["msg"] = f"转账:{feedesc}" - content["src"] = "" - - elif type_id[0] == 49 and type_id[1] != 0: - DictExtra = read_BytesExtra(BytesExtra) - url = match_BytesExtra(DictExtra) - content["src"] = url - content["msg"] = type_name - - elif type_id == (50, 0): # 语音通话 - content["msg"] = "语音/视频通话[%s]" % DisplayContent - - # elif type_id == (10000, 0): - # content["msg"] = StrContent - # elif type_id == (10000, 4): - # content["msg"] = StrContent - # elif type_id == (10000, 8000): - # content["msg"] = StrContent - - talker = "未知" - if IsSender == 1: - talker = "我" + with DBPool(MSG_db_path) as db1: + if selected_talker: + sql = ( + "SELECT localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType,CreateTime,MsgSvrID,DisplayContent,CompressContent,BytesExtra,ROW_NUMBER() OVER (ORDER BY CreateTime ASC) AS id " + "FROM MSG WHERE StrTalker=? " + "ORDER BY CreateTime ASC LIMIT ?,?") + result1 = execute_sql(db1,sql, (selected_talker, start_index, page_size)) else: - if StrTalker.endswith("@chatroom"): - bytes_extra = read_BytesExtra(BytesExtra) - if bytes_extra: - try: - talker = bytes_extra['3'][0]['2'].decode('utf-8', errors='ignore') - if "publisher-id" in talker: - talker = "系统" - except: - pass - else: - talker = StrTalker + sql = ( + "SELECT localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType,CreateTime,MsgSvrID,DisplayContent,CompressContent,BytesExtra,ROW_NUMBER() OVER (ORDER BY CreateTime ASC) AS id " + "FROM MSG ORDER BY CreateTime ASC LIMIT ?,?") + result1 = execute_sql(db1,sql, (start_index, page_size)) - row_data = {"MsgSvrID": str(MsgSvrID), "type_name": type_name, "is_sender": IsSender, "talker": talker, - "room_name": StrTalker, "content": content, "CreateTime": CreateTime, "id": id} - data.append(row_data) - return data + data = [] + for row in result1: + localId, IsSender, StrContent, StrTalker, Sequence, Type, SubType, CreateTime, MsgSvrID, DisplayContent, CompressContent, BytesExtra, id = row + CreateTime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(CreateTime)) + + type_id = (Type, SubType) + type_name = get_type_name(type_id) + + content = {"src": "", "msg": StrContent} + + if type_id == (1, 0): # 文本 + content["msg"] = StrContent + + elif type_id == (3, 0): # 图片 + DictExtra = read_BytesExtra(BytesExtra) + DictExtra = str(DictExtra) + match = re.search(r"FileStorage(.*?)'", DictExtra) + if match: + img_path = match.group(0).replace("'", "") + img_path = [i for i in img_path.split("\\") if i] + img_path = os.path.join(*img_path) + content["src"] = img_path + else: + content["src"] = "" + content["msg"] = "图片" + elif type_id == (34, 0): + tmp_c = parse_xml_string(StrContent) + voicelength = tmp_c.get("voicemsg", {}).get("voicelength", "") + transtext = tmp_c.get("voicetrans", {}).get("transtext", "") + if voicelength.isdigit(): + voicelength = int(voicelength) / 1000 + voicelength = f"{voicelength:.2f}" + content[ + "msg"] = f"语音时长:{voicelength}秒\n翻译结果:{transtext}" if transtext else f"语音时长:{voicelength}秒" + content["src"] = os.path.join("audio", f"{StrTalker}", + f"{CreateTime.replace(':', '-').replace(' ', '_')}_{IsSender}_{MsgSvrID}.wav") + elif type_id == (43, 0): # 视频 + DictExtra = read_BytesExtra(BytesExtra) + DictExtra = str(DictExtra) + match = re.search(r"FileStorage(.*?)'", DictExtra) + if match: + video_path = match.group(0).replace("'", "") + content["src"] = video_path + else: + content["src"] = "" + content["msg"] = "视频" + + elif type_id == (47, 0): # 动画表情 + content_tmp = parse_xml_string(StrContent) + cdnurl = content_tmp.get("emoji", {}).get("cdnurl", "") + if cdnurl: + content = {"src": cdnurl, "msg": "表情"} + + elif type_id == (49, 0): + DictExtra = read_BytesExtra(BytesExtra) + url = match_BytesExtra(DictExtra) + content["src"] = url + file_name = os.path.basename(url) + content["msg"] = file_name + + elif type_id == (49, 19): # 合并转发的聊天记录 + CompressContent = decompress_CompressContent(CompressContent) + content_tmp = parse_xml_string(CompressContent) + title = content_tmp.get("appmsg", {}).get("title", "") + des = content_tmp.get("appmsg", {}).get("des", "") + recorditem = content_tmp.get("appmsg", {}).get("recorditem", "") + recorditem = parse_xml_string(recorditem) + content["msg"] = f"{title}\n{des}" + content["src"] = recorditem + + elif type_id == (49, 2000): # 转账消息 + CompressContent = decompress_CompressContent(CompressContent) + content_tmp = parse_xml_string(CompressContent) + feedesc = content_tmp.get("appmsg", {}).get("wcpayinfo", {}).get("feedesc", "") + content["msg"] = f"转账:{feedesc}" + content["src"] = "" + + elif type_id[0] == 49 and type_id[1] != 0: + DictExtra = read_BytesExtra(BytesExtra) + url = match_BytesExtra(DictExtra) + content["src"] = url + content["msg"] = type_name + + elif type_id == (50, 0): # 语音通话 + content["msg"] = "语音/视频通话[%s]" % DisplayContent + + # elif type_id == (10000, 0): + # content["msg"] = StrContent + # elif type_id == (10000, 4): + # content["msg"] = StrContent + # elif type_id == (10000, 8000): + # content["msg"] = StrContent + + talker = "未知" + if IsSender == 1: + talker = "我" + else: + if StrTalker.endswith("@chatroom"): + bytes_extra = read_BytesExtra(BytesExtra) + if bytes_extra: + try: + talker = bytes_extra['3'][0]['2'].decode('utf-8', errors='ignore') + if "publisher-id" in talker: + talker = "系统" + except: + pass + else: + talker = StrTalker + + row_data = {"MsgSvrID": str(MsgSvrID), "type_name": type_name, "is_sender": IsSender, "talker": talker, + "room_name": StrTalker, "content": content, "CreateTime": CreateTime, "id": id} + data.append(row_data) + return data def get_chat_count(MSG_db_path: [str, list], username: str = ""): @@ -301,15 +280,14 @@ def get_chat_count(MSG_db_path: [str, list], username: str = ""): sql = f"SELECT StrTalker,COUNT(*) FROM MSG WHERE StrTalker='{username}';" else: sql = f"SELECT StrTalker, COUNT(*) FROM MSG GROUP BY StrTalker ORDER BY COUNT(*) DESC;" - db1 = sqlite3.connect(MSG_db_path) - result = execute_sql(db1, sql) - chat_counts = {} - for row in result: - username, chat_count = row - chat_counts[username] = chat_count - db1.close() - return chat_counts + with DBPool(MSG_db_path) as db1: + result = execute_sql(db1, sql) + chat_counts = {} + for row in result: + username, chat_count = row + chat_counts[username] = chat_count + return chat_counts def get_all_chat_count(MSG_db_path: [str, list]): @@ -319,14 +297,13 @@ def get_all_chat_count(MSG_db_path: [str, list]): :return: 聊天记录数量 """ sql = f"SELECT COUNT(*) FROM MSG;" - db1 = sqlite3.connect(MSG_db_path) - result = execute_sql(db1, sql) - if result and len(result) > 0: - chat_counts = result[0][0] - db1.close() - return chat_counts - db1.close() - return 0 + with DBPool(MSG_db_path) as db1: + result = execute_sql(db1, sql) + if result and len(result) > 0: + chat_counts = result[0][0] + return chat_counts + return 0 + def export_csv(username, outpath, MSG_ALL_db_path, page_size=5000): diff --git a/pywxdump/analyzer/utils.py b/pywxdump/analyzer/utils.py index 3cac9b0..e328a7d 100644 --- a/pywxdump/analyzer/utils.py +++ b/pywxdump/analyzer/utils.py @@ -6,7 +6,9 @@ # Date: 2023/12/03 # ------------------------------------------------------------------------------- import hashlib +import os import re +import sqlite3 def read_dict_all_values(data): @@ -115,6 +117,12 @@ def get_name_typeid(type_name: str): (43, 0): "视频", (47, 0): "动画表情", + (37, 0): "添加好友", # 感谢 https://github.com/zhyc9de + (42, 0): "推荐公众号", # 感谢 https://github.com/zhyc9de + (48, 0): "地图信息", # 感谢 https://github.com/zhyc9de + (49, 40): "分享收藏夹", # 感谢 https://github.com/zhyc9de + (49, 53): "接龙", # 感谢 https://github.com/zhyc9de + (49, 0): "文件", (49, 1): "类似文字消息而不一样的消息", (49, 5): "卡片式链接", @@ -153,6 +161,57 @@ def get_md5(data): return md5.hexdigest() +import threading + + +def get_thread_id(): + current_thread = threading.current_thread() + thread_id = current_thread.ident + return thread_id + + +class DBPool: + __db_pool = {} + __thread_pool = {} + + def __new__(cls, *args, **kwargs): + if not hasattr(cls, '_instance'): + cls._instance = super(DBPool, cls).__new__(cls) + return cls._instance + + @classmethod + def create_connection(cls, db_path): + if db_path == "DBPOOL_INIT": + return + if not os.path.exists(db_path): + raise FileNotFoundError(f"数据库文件不存在:{db_path}") + + if db_path not in cls.__db_pool: + cls.__db_pool[db_path] = sqlite3.connect(db_path, check_same_thread=False) + print(f"数据库连接成功") + print(f"数据库连接成功 1") + print(cls.__db_pool) + cls.connection = cls.__db_pool[db_path] + + def __init__(self, db_path): + if db_path == "DBPOOL_INIT": + return + self.db_path = db_path + if db_path not in self.__db_pool: + self.create_connection(db_path) + self.connection = self.__db_pool.get(db_path) + + def __enter__(self): + return self.connection + + def __exit__(self, exc_type, exc_val, exc_tb): + self.connection = None + + def close(self): + self.connection.close() + self.connection = None + + def attach_databases(connection, databases): """ 将多个数据库附加到给定的SQLite连接。 @@ -198,7 +257,7 @@ def execute_sql(connection, sql, params=None): else: cursor.execute(sql) return cursor.fetchall() - except Exception as e: + except Exception as e1: try: connection.text_factory = bytes cursor = connection.cursor() @@ -209,8 +268,8 @@ def execute_sql(connection, sql, params=None): rdata = cursor.fetchall() connection.text_factory = str return rdata - except Exception as e: - print(f"**********\nSQL: {sql}\nparams: {params}\n{e}\n**********") + except Exception as e2: + print(f"**********\nSQL: {sql}\nparams: {params}\n{e1}\n{e2}\n**********") return None