From 930efaf0e9ee8df0e88d8c7fe386356a899b78d4 Mon Sep 17 00:00:00 2001
From: xaoyaoo
Date: Sat, 20 Jul 2024 19:11:36 +0800
Subject: [PATCH] Speed up merge_db merging
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pywxdump/wx_info/merge_db.py | 171 ++++++++++++++++++++---------------
 1 file changed, 99 insertions(+), 72 deletions(-)

diff --git a/pywxdump/wx_info/merge_db.py b/pywxdump/wx_info/merge_db.py
index 6b5ded1..7b0ea45 100644
--- a/pywxdump/wx_info/merge_db.py
+++ b/pywxdump/wx_info/merge_db.py
@@ -203,7 +203,7 @@ def execute_sql(connection, sql, params=None):
         return None
 
 
-def merge_db(db_paths, save_path="merge.db", CreateTime: int = 0, endCreateTime: int = 0):
+def merge_db(db_paths, save_path="merge.db", startCreateTime: int = 0, endCreateTime: int = 0):
     """
     Merge databases; primary keys and duplicate rows are ignored.
     :param db_paths:
@@ -214,88 +214,115 @@ def merge_db(db_paths, save_path="merge.db", CreateTime: int = 0, endCreateTime:
     if os.path.isdir(save_path):
         save_path = os.path.join(save_path, f"merge_{int(time.time())}.db")
 
+    _db_paths = []
+    if isinstance(db_paths, str):
+        if os.path.isdir(db_paths):
+            _db_paths = [os.path.join(db_paths, i) for i in os.listdir(db_paths) if i.endswith(".db")]
+        elif os.path.isfile(db_paths):
+            _db_paths = [db_paths]
+        else:
+            raise FileNotFoundError("db_paths 不存在")
+
     if isinstance(db_paths, list):
         # alias, file_path
         databases = {f"MSG{i}": db_path for i, db_path in enumerate(db_paths)}
-    elif isinstance(db_paths, str):
-        # Determine whether it is a file or a directory
-        if os.path.isdir(db_paths):
-            db_paths = [os.path.join(db_paths, i) for i in os.listdir(db_paths) if i.endswith(".db")]
-            databases = {f"MSG{i}": db_path for i, db_path in enumerate(db_paths)}
-        elif os.path.isfile(db_paths):
-            databases = {"MSG": db_paths}
-        else:
-            raise FileNotFoundError("db_paths 不存在")
     else:
         raise TypeError("db_paths 类型错误")
 
     outdb = sqlite3.connect(save_path)
     out_cursor = outdb.cursor()
+
+    # Check whether the sync_log table exists; it records sync history: WeChat DB path, table name, row counts, sync time
+    sync_log_status = execute_sql(outdb, "SELECT name FROM sqlite_master WHERE type='table' AND name='sync_log'")
+    if len(sync_log_status) < 1:
+        # db_path: WeChat DB path; tbl_name: table name; src_count: row count in the source DB; current_count: row count of the corresponding table in the merged DB
+        sync_record_create_sql = ("CREATE TABLE sync_log ("
+                                  "id INTEGER PRIMARY KEY AUTOINCREMENT,"
+                                  "db_path TEXT NOT NULL,"
+                                  "tbl_name TEXT NOT NULL,"
+                                  "src_count INT,"
+                                  "current_count INT,"
+                                  "createTime INT DEFAULT (strftime('%s', 'now')), "
+                                  "updateTime INT DEFAULT (strftime('%s', 'now'))"
+                                  ");")
+        out_cursor.execute(sync_record_create_sql)
+        # Create indexes
+        out_cursor.execute("CREATE INDEX idx_sync_log_db_path ON sync_log (db_path);")
+        out_cursor.execute("CREATE INDEX idx_sync_log_tbl_name ON sync_log (tbl_name);")
+        # Create a composite unique index to prevent duplicates
+        out_cursor.execute("CREATE UNIQUE INDEX idx_sync_log_db_tbl ON sync_log (db_path, tbl_name);")
+        outdb.commit()
+
     # Merge the data from MSG_db_paths into out_db_path
-    for alias in databases:
-        db = sqlite3.connect(databases[alias])
-        # Get table names
-        sql = f"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
-        tables = execute_sql(db, sql)
-        try:
-            for table in tables:
-                table = table[0]
-                if table == "sqlite_sequence":
-                    continue
-                # Get the column names of the table
-                sql = f"PRAGMA table_info({table})"
-                columns = execute_sql(db, sql)
-                if not columns or len(columns) < 1:
-                    continue
-                col_type = {
-                    (i[1] if isinstance(i[1], str) else i[1].decode(), i[2] if isinstance(i[2], str) else i[2].decode())
-                    for
-                    i in columns}
-                columns = [i[1] if isinstance(i[1], str) else i[1].decode() for i in columns]
-                if not columns or len(columns) < 1:
-                    continue
+    for alias, path in databases.items():
+        # Attach the source database
+        sql_attach = f"ATTACH DATABASE '{path}' AS {alias}"
+        out_cursor.execute(sql_attach)
+        outdb.commit()
+        sql_query_tbl_name = f"SELECT name FROM {alias}.sqlite_master WHERE type='table' ORDER BY name;"
+        tables = execute_sql(outdb, sql_query_tbl_name)
+        for table in tables:
+            table = table[0]
+            if table == "sqlite_sequence":
+                continue
+            # Get the column names of the table
+            sql_query_columns = f"PRAGMA table_info({table})"
+            columns = execute_sql(outdb, sql_query_columns)
+            col_type = {
+                (i[1] if isinstance(i[1], str) else i[1].decode(),
+                 i[2] if isinstance(i[2], str) else i[2].decode())
+                for i in columns}
+            columns = [i[0] for i in col_type]
+            if not columns or len(columns) < 1:
+                continue
+            # Create the table
+            sql_create_tbl = f"CREATE TABLE IF NOT EXISTS {table} AS SELECT * FROM {alias}.{table} WHERE 0 = 1;"
+            out_cursor.execute(sql_create_tbl)
+            # Create a UNIQUE index that also compares NULL values
+            index_name = f"{table}_unique_index"
+            coalesce_columns = ','.join(f"COALESCE({column}, '')" for column in columns)
+            sql = f"CREATE UNIQUE INDEX IF NOT EXISTS {index_name} ON {table} ({coalesce_columns})"
+            out_cursor.execute(sql)
 
-                # Check whether the table exists
-                sql = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table}'"
-                out_cursor.execute(sql)
-                if len(out_cursor.fetchall()) < 1:
-                    # Create the table
-                    # Build the CREATE TABLE SQL statement
-                    column_definitions = []
-                    for column in col_type:
-                        column_name = column[0] if isinstance(column[0], str) else column[0].decode()
-                        column_type = column[1] if isinstance(column[1], str) else column[1].decode()
-                        column_definition = f"{column_name} {column_type}"
-                        column_definitions.append(column_definition)
-                    sql = f"CREATE TABLE IF NOT EXISTS {table} ({','.join(column_definitions)})"
-                    # sql = f"CREATE TABLE IF NOT EXISTS {table} ({','.join(columns)})"
-                    out_cursor.execute(sql)
+            # Insert into sync_log
+            sql_query_sync_log = f"SELECT * FROM sync_log WHERE db_path=? AND tbl_name=?"
+            sync_log = execute_sql(outdb, sql_query_sync_log, (path, table))
+            if not sync_log or len(sync_log) < 1:
+                sql_insert_sync_log = "INSERT INTO sync_log (db_path, tbl_name, src_count, current_count) VALUES (?, ?, ?, ?)"
+                out_cursor.execute(sql_insert_sync_log, (path, table, 0, 0))
+                outdb.commit()
 
-                    # Create a UNIQUE index that also compares NULL values
-                    index_name = f"{table}_unique_index"
-                    coalesce_columns = ','.join(f"COALESCE({column}, '')" for column in columns)  # Convert NULL values to ''
-                    sql = f"CREATE UNIQUE INDEX IF NOT EXISTS {index_name} ON {table} ({coalesce_columns})"
-                    out_cursor.execute(sql)
+            # Compare row counts between the source DB and the merged DB
+            log_src_count = execute_sql(outdb, sql_query_sync_log, (path, table))[0][3]
+            src_count = execute_sql(outdb, f"SELECT COUNT(*) FROM {alias}.{table}")[0][0]
+            if src_count <= log_src_count:
+                continue
 
-                # Fetch the data from the table
-                if "CreateTime" in columns and CreateTime > 0:
-                    sql = f"SELECT {','.join([i[0] for i in col_type])} FROM {table} WHERE CreateTime>? ORDER BY CreateTime"
-                    src_data = execute_sql(db, sql, (CreateTime,))
-                else:
-                    sql = f"SELECT {','.join([i[0] for i in col_type])} FROM {table}"
-                    src_data = execute_sql(db, sql)
-                if not src_data or len(src_data) < 1:
-                    continue
-                # Insert the data
-                sql = f"INSERT OR IGNORE INTO {table} ({','.join([i[0] for i in col_type])}) VALUES ({','.join(['?'] * len(columns))})"
-                try:
-                    out_cursor.executemany(sql, src_data)
-                except Exception as e:
-                    logging.error(f"error: {alias}\n{table}\n{sql}\n{src_data}\n{len(src_data)}\n{e}\n**********")
-                outdb.commit()
-        except Exception as e:
-            logging.error(f"fun(merge_db) error: {alias}\n{e}\n**********")
-        db.close()
+            sql_base = f"SELECT {','.join([i for i in columns])} FROM {alias}.{table} "
+            # Build the WHERE clause
+            where_clauses, params = [], []
+            if "CreateTime" in columns:
+                if startCreateTime > 0:
+                    where_clauses.append("CreateTime > ?")
+                    params.append(startCreateTime)
+                if endCreateTime > 0:
+                    where_clauses.append("CreateTime < ?")
+                    params.append(endCreateTime)
+            # If there is a WHERE clause, append it to the SQL together with an ORDER BY clause
+            sql = f"{sql_base} WHERE {' AND '.join(where_clauses)} ORDER BY CreateTime" if where_clauses else sql_base
+            src_data = execute_sql(outdb, sql, tuple(params))
+            if not src_data or len(src_data) < 1:
+                continue
+            # Insert the data
+            sql = f"INSERT OR IGNORE INTO {table} ({','.join([i for i in columns])}) VALUES ({','.join(['?'] * len(columns))})"
+            try:
+                out_cursor.executemany(sql, src_data)
+            except Exception as e:
+                logging.error(f"error: {path}\n{table}\n{sql}\n{src_data}\n{len(src_data)}\n{e}\n", exc_info=True)
+        # Detach the source database
+        sql_detach = f"DETACH DATABASE {alias}"
+        out_cursor.execute(sql_detach)
+        outdb.commit()
     outdb.close()
     return save_path
 
@@ -362,7 +389,7 @@ def decrypt_merge(wx_path, key, outpath="", CreateTime: int = 0, endCreateTime:
     de_db_type = [f"de_{i}" for i in db_type]
     parpare_merge_db_path = [i for i in out_dbs if any(keyword in i for keyword in de_db_type)]
 
-    merge_save_path = merge_db(parpare_merge_db_path, merge_save_path, CreateTime=CreateTime,
+    merge_save_path = merge_db(parpare_merge_db_path, merge_save_path, startCreateTime=CreateTime,
                                endCreateTime=endCreateTime)
 
     return True, merge_save_path
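
The core of the speed-up is that each source database is now ATTACHed to the destination connection and copied table by table through INSERT OR IGNORE against an all-column UNIQUE index, instead of opening a separate connection per source and rebuilding every schema by hand. Below is a minimal standalone sketch of that pattern, not pywxdump code: the helper name attach_and_copy and its parameters are hypothetical, and unlike the patch it copies rows entirely inside SQLite rather than filtering by CreateTime in Python and tracking per-table progress in sync_log.

import sqlite3

def attach_and_copy(out_path, src_path, alias="SRC"):
    """Sketch: attach src_path to the destination DB and copy every table,
    letting an all-column UNIQUE index drop duplicate rows."""
    outdb = sqlite3.connect(out_path)
    cur = outdb.cursor()
    cur.execute(f"ATTACH DATABASE '{src_path}' AS {alias}")
    tables = cur.execute(
        f"SELECT name FROM {alias}.sqlite_master WHERE type='table'"
    ).fetchall()
    for (table,) in tables:
        if table == "sqlite_sequence":
            continue
        # Clone the source schema into an empty destination table.
        cur.execute(
            f"CREATE TABLE IF NOT EXISTS {table} AS "
            f"SELECT * FROM {alias}.{table} WHERE 0 = 1"
        )
        cols = [row[1] for row in cur.execute(f"PRAGMA table_info({table})")]
        # UNIQUE index over all columns (NULLs folded to '') so that
        # INSERT OR IGNORE silently skips rows that already exist.
        coalesced = ",".join(f"COALESCE({c}, '')" for c in cols)
        cur.execute(
            f"CREATE UNIQUE INDEX IF NOT EXISTS {table}_unique_index "
            f"ON {table} ({coalesced})"
        )
        # Copy the rows inside SQLite, without a Python round trip.
        cur.execute(
            f"INSERT OR IGNORE INTO {table} ({','.join(cols)}) "
            f"SELECT {','.join(cols)} FROM {alias}.{table}"
        )
    outdb.commit()
    cur.execute(f"DETACH DATABASE {alias}")
    outdb.close()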
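
For callers, the visible change is the rename of the CreateTime parameter to startCreateTime in merge_db (decrypt_merge keeps its old CreateTime argument and forwards it). A short usage sketch, assuming hypothetical decrypted database paths and Unix-timestamp bounds:

from pywxdump.wx_info.merge_db import merge_db

# Hypothetical inputs: two decrypted message databases and a time window
# given as Unix timestamps (0 disables the corresponding bound).
merged_path = merge_db(
    ["decrypted/de_MSG0.db", "decrypted/de_MSG1.db"],
    save_path="decrypted/merged.db",
    startCreateTime=1717171200,  # renamed from CreateTime in this patch
    endCreateTime=1721471496,
)
print(merged_path)  # merge_db returns the path of the merged database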