Refactoring

This commit is contained in:
Changhua 2025-02-01 08:45:28 +08:00
parent 2033ae2ab8
commit ff5f2378c0
3 changed files with 200 additions and 126 deletions

View File

@ -1,10 +1,16 @@
#include <iterator>
#include <algorithm>
#include <iterator>
#include "exec_sql.h"
#include "fill_response.h"
#include "log.hpp"
#include "sqlite3.h"
#include "util.h"
extern UINT64 g_WeChatWinDllAddr;
namespace exec_sql
{
#define OFFSET_DB_INSTANCE 0x5902000
#define OFFSET_DB_MICROMSG 0xB8
#define OFFSET_DB_CHAT_MSG 0x2C8
@ -16,131 +22,138 @@
#define OFFSET_DB_NAME 0x28
#define OFFSET_DB_MSG_MGR 0x595F900
extern UINT64 g_WeChatWinDllAddr;
using db_map_t = std::map<std::string, QWORD>;
static db_map_t db_map;
typedef map<string, QWORD> dbMap_t;
static dbMap_t dbMap;
static void GetDbHandle(QWORD base, QWORD offset)
static void get_db_handle(QWORD base, QWORD offset)
{
wchar_t *wsp = (wchar_t *)(*(QWORD *)(base + offset + OFFSET_DB_NAME));
string dbname = Wstring2String(wstring(wsp));
dbMap[dbname] = GET_QWORD(base + offset);
auto *wsp = reinterpret_cast<wchar_t *>(*(QWORD *)(base + offset + OFFSET_DB_NAME));
std::string dbname = Wstring2String(std::wstring(wsp));
db_map[dbname] = GET_QWORD(base + offset);
}
static void GetMsgDbHandle(QWORD msgMgrAddr)
static void get_msg_db_handle(QWORD msg_mgr_addr)
{
QWORD dbIndex = GET_QWORD(msgMgrAddr + 0x68);
QWORD pStart = GET_QWORD(msgMgrAddr + 0x50);
for (uint32_t i = 0; i < dbIndex; i++) {
QWORD dbAddr = GET_QWORD(pStart + i * 0x08);
if (dbAddr) {
QWORD db_index = GET_QWORD(msg_mgr_addr + 0x68);
QWORD p_start = GET_QWORD(msg_mgr_addr + 0x50);
for (uint32_t i = 0; i < db_index; i++) {
QWORD db_addr = GET_QWORD(p_start + i * 0x08);
if (db_addr) {
// MSGi.db
string dbname = Wstring2String(GET_WSTRING(dbAddr));
dbMap[dbname] = GET_QWORD(dbAddr + 0x78);
std::string dbname = Wstring2String(GET_WSTRING(db_addr));
db_map[dbname] = GET_QWORD(db_addr + 0x78);
// MediaMsgi.db
QWORD mmdbAddr = GET_QWORD(dbAddr + 0x20);
string mmdbname = Wstring2String(GET_WSTRING(mmdbAddr + 0x78));
dbMap[mmdbname] = GET_QWORD(mmdbAddr + 0x50);
QWORD mmdb_addr = GET_QWORD(db_addr + 0x20);
std::string mmdbname = Wstring2String(GET_WSTRING(mmdb_addr + 0x78));
db_map[mmdbname] = GET_QWORD(mmdb_addr + 0x50);
}
}
}
dbMap_t GetDbHandles()
db_map_t get_db_handles()
{
dbMap.clear();
db_map.clear();
QWORD db_instance_addr = GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_INSTANCE);
QWORD dbInstanceAddr = GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_INSTANCE);
get_db_handle(db_instance_addr, OFFSET_DB_MICROMSG); // MicroMsg.db
get_db_handle(db_instance_addr, OFFSET_DB_CHAT_MSG); // ChatMsg.db
get_db_handle(db_instance_addr, OFFSET_DB_MISC); // Misc.db
get_db_handle(db_instance_addr, OFFSET_DB_EMOTION); // Emotion.db
get_db_handle(db_instance_addr, OFFSET_DB_MEDIA); // Media.db
get_db_handle(db_instance_addr, OFFSET_DB_FUNCTION_MSG); // Function.db
GetDbHandle(dbInstanceAddr, OFFSET_DB_MICROMSG); // MicroMsg.db
GetDbHandle(dbInstanceAddr, OFFSET_DB_CHAT_MSG); // ChatMsg.db
GetDbHandle(dbInstanceAddr, OFFSET_DB_MISC); // Misc.db
GetDbHandle(dbInstanceAddr, OFFSET_DB_EMOTION); // Emotion.db
GetDbHandle(dbInstanceAddr, OFFSET_DB_MEDIA); // Media.db
GetDbHandle(dbInstanceAddr, OFFSET_DB_FUNCTION_MSG); // Function.db
get_msg_db_handle(GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_MSG_MGR)); // MSGi.db & MediaMsgi.db
GetMsgDbHandle(GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_MSG_MGR)); // MSGi.db & MediaMsgi.db
return dbMap;
return db_map;
}
DbNames_t GetDbNames()
DbNames_t get_db_names()
{
DbNames_t names;
if (dbMap.empty()) {
dbMap = GetDbHandles();
if (db_map.empty()) {
db_map = get_db_handles();
}
for (auto &[k, v] : dbMap) {
DbNames_t names;
for (const auto &[k, _] : db_map) {
names.push_back(k);
}
return names;
}
static int cbGetTables(void *ret, int argc, char **argv, char **azColName)
static int cb_get_tables(void *ret, int argc, char **argv, char **azColName)
{
DbTables_t *tbls = (DbTables_t *)ret;
auto *tables = static_cast<DbTables_t *>(ret);
DbTable_t tbl;
for (int i = 0; i < argc; i++) {
if (strcmp(azColName[i], "name") == 0) {
tbl.name = argv[i] ? argv[i] : "";
} else if (strcmp(azColName[i], "sql") == 0) {
string sql(argv[i]);
std::string sql(argv[i]);
sql.erase(std::remove(sql.begin(), sql.end(), '\t'), sql.end());
tbl.sql = sql.c_str();
tbl.sql = sql;
}
}
tbls->push_back(tbl);
tables->push_back(tbl);
return 0;
}
DbTables_t GetDbTables(const string db)
DbTables_t get_db_tables(const std::string &db)
{
DbTables_t tables;
if (dbMap.empty()) {
dbMap = GetDbHandles();
if (db_map.empty()) {
db_map = get_db_handles();
}
auto it = dbMap.find(db);
if (it == dbMap.end()) {
return tables; // DB not found
auto it = db_map.find(db);
if (it == db_map.end()) {
return tables;
}
const char *sql = "select name, sql from sqlite_master where type=\"table\";";
Sqlite3_exec p_Sqlite3_exec = (Sqlite3_exec)(g_WeChatWinDllAddr + SQLITE3_EXEC_OFFSET);
p_Sqlite3_exec(it->second, sql, (Sqlite3_callback)cbGetTables, (void *)&tables, 0);
constexpr const char *sql = "SELECT name FROM sqlite_master WHERE type='table';";
Sqlite3_exec p_sqlite3_exec = reinterpret_cast<Sqlite3_exec>(g_WeChatWinDllAddr + SQLITE3_EXEC_OFFSET);
p_sqlite3_exec(it->second, sql, (Sqlite3_callback)cb_get_tables, (void *)&tables, nullptr);
return tables;
}
DbRows_t ExecDbQuery(const string db, const string sql)
DbRows_t exec_db_query(const std::string &db, const std::string &sql)
{
DbRows_t rows;
Sqlite3_prepare func_prepare = (Sqlite3_prepare)(g_WeChatWinDllAddr + SQLITE3_PREPARE_OFFSET);
Sqlite3_step func_step = (Sqlite3_step)(g_WeChatWinDllAddr + SQLITE3_STEP_OFFSET);
Sqlite3_column_count func_column_count = (Sqlite3_column_count)(g_WeChatWinDllAddr + SQLITE3_COLUMN_COUNT_OFFSET);
Sqlite3_column_name func_column_name = (Sqlite3_column_name)(g_WeChatWinDllAddr + SQLITE3_COLUMN_NAME_OFFSET);
Sqlite3_column_type func_column_type = (Sqlite3_column_type)(g_WeChatWinDllAddr + SQLITE3_COLUMN_TYPE_OFFSET);
Sqlite3_column_blob func_column_blob = (Sqlite3_column_blob)(g_WeChatWinDllAddr + SQLITE3_COLUMN_BLOB_OFFSET);
Sqlite3_column_bytes func_column_bytes = (Sqlite3_column_bytes)(g_WeChatWinDllAddr + SQLITE3_COLUMN_BYTES_OFFSET);
Sqlite3_finalize func_finalize = (Sqlite3_finalize)(g_WeChatWinDllAddr + SQLITE3_FINALIZE_OFFSET);
if (dbMap.empty()) {
dbMap = GetDbHandles();
Sqlite3_prepare func_prepare = reinterpret_cast<Sqlite3_prepare>(g_WeChatWinDllAddr + SQLITE3_PREPARE_OFFSET);
Sqlite3_step func_step = reinterpret_cast<Sqlite3_step>(g_WeChatWinDllAddr + SQLITE3_STEP_OFFSET);
Sqlite3_column_count func_column_count
= reinterpret_cast<Sqlite3_column_count>(g_WeChatWinDllAddr + SQLITE3_COLUMN_COUNT_OFFSET);
Sqlite3_column_name func_column_name
= reinterpret_cast<Sqlite3_column_name>(g_WeChatWinDllAddr + SQLITE3_COLUMN_NAME_OFFSET);
Sqlite3_column_type func_column_type
= reinterpret_cast<Sqlite3_column_type>(g_WeChatWinDllAddr + SQLITE3_COLUMN_TYPE_OFFSET);
Sqlite3_column_blob func_column_blob
= reinterpret_cast<Sqlite3_column_blob>(g_WeChatWinDllAddr + SQLITE3_COLUMN_BLOB_OFFSET);
Sqlite3_column_bytes func_column_bytes
= reinterpret_cast<Sqlite3_column_bytes>(g_WeChatWinDllAddr + SQLITE3_COLUMN_BYTES_OFFSET);
Sqlite3_finalize func_finalize = reinterpret_cast<Sqlite3_finalize>(g_WeChatWinDllAddr + SQLITE3_FINALIZE_OFFSET);
if (db_map.empty()) {
db_map = get_db_handles();
}
auto it = db_map.find(db);
if (it == db_map.end() || it->second == 0) {
LOG_WARN("Empty handle for database '{}', retrying...", db);
db_map = get_db_handles();
it = db_map.find(db);
if (it == db_map.end() || it->second == 0) {
LOG_ERROR("Failed to get handle for database '{}'", db);
return rows;
}
}
QWORD *stmt;
QWORD handle = dbMap[db];
if (handle == 0) {
LOG_WARN("Empty handle, retrying...");
dbMap = GetDbHandles();
}
int rc = func_prepare(dbMap[db], sql.c_str(), -1, &stmt, 0);
int rc = func_prepare(it->second, sql.c_str(), -1, &stmt, nullptr);
if (rc != SQLITE_OK) {
LOG_ERROR("SQL prepare failed for '{}': error code {}", db, rc);
return rows;
}
@ -154,79 +167,119 @@ DbRows_t ExecDbQuery(const string db, const string sql)
int length = func_column_bytes(stmt, i);
const void *blob = func_column_blob(stmt, i);
if (length && (field.type != 5)) {
field.content.reserve(length);
copy((uint8_t *)blob, (uint8_t *)blob + length, back_inserter(field.content));
if (length > 0 && field.type != SQLITE_NULL) {
field.content.resize(length);
std::memcpy(field.content.data(), blob, length);
}
row.push_back(field);
}
rows.push_back(row);
}
func_finalize(stmt);
return rows;
}
int GetLocalIdandDbidx(uint64_t id, uint64_t *localId, uint32_t *dbIdx)
int get_local_id_and_dbidx(uint64_t id, uint64_t *local_id, uint32_t *db_idx)
{
QWORD msgMgrAddr = GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_MSG_MGR);
int dbIndex = (int)GET_QWORD(msgMgrAddr + 0x68); // 总不能 int 还不够吧?
QWORD pStart = GET_QWORD(msgMgrAddr + 0x50);
if (!local_id || !db_idx) {
LOG_ERROR("Invalid pointer arguments!");
return -1;
}
*dbIdx = 0;
for (int i = dbIndex - 1; i >= 0; i--) { // 从后往前遍历
QWORD dbAddr = GET_QWORD(pStart + i * 0x08);
if (dbAddr) {
string dbname = Wstring2String(GET_WSTRING(dbAddr));
dbMap[dbname] = GET_QWORD(dbAddr + 0x78);
string sql = "SELECT localId FROM MSG WHERE MsgSvrID=" + to_string(id) + ";";
DbRows_t rows = ExecDbQuery(dbname, sql);
if (rows.empty()) {
continue;
}
DbRow_t row = rows.front();
if (row.empty()) {
continue;
}
DbField_t field = row.front();
if ((field.column.compare("localId") != 0) && (field.type != 1)) {
continue;
}
QWORD msg_mgr_addr = GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_MSG_MGR);
int db_index = static_cast<int>(GET_QWORD(msg_mgr_addr + 0x68)); // 总不能 int 还不够吧?
QWORD p_start = GET_QWORD(msg_mgr_addr + 0x50);
*localId = strtoull((const char *)(field.content.data()), NULL, 10);
*dbIdx = (uint32_t)(GET_QWORD(GET_QWORD(dbAddr + 0x28) + 0x1E8) >> 32);
return 0;
*db_idx = 0;
for (int i = db_index - 1; i >= 0; i--) { // 从后往前遍历
QWORD db_addr = GET_QWORD(p_start + i * 0x08);
if (!db_addr) {
continue;
}
std::string dbname = Wstring2String(GET_WSTRING(db_addr));
db_map[dbname] = GET_QWORD(db_addr + 0x78);
std::string sql = "SELECT localId FROM MSG WHERE MsgSvrID=" + std::to_string(id) + ";";
DbRows_t rows = exec_db_query(dbname, sql);
if (rows.empty() || rows.front().empty()) {
continue;
}
const DbField_t &field = rows.front().front();
if (field.column != "localId" || field.type != SQLITE_INTEGER) {
continue;
}
std::string id_str(field.content.begin(), field.content.end());
try {
*local_id = std::stoull(id_str);
} catch (const std::exception &e) {
LOG_ERROR("Failed to parse localId: {}", e.what());
continue;
}
*db_idx = static_cast<uint32_t>(GET_QWORD(GET_QWORD(db_addr + 0x28) + 0x1E8) >> 32);
return 0;
}
return -1;
}
vector<uint8_t> GetAudioData(uint64_t id)
std::vector<uint8_t> get_audio_data(uint64_t id)
{
QWORD msgMgrAddr = GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_MSG_MGR);
int dbIndex = (int)GET_QWORD(msgMgrAddr + 0x68);
QWORD msg_mgr_addr = GET_QWORD(g_WeChatWinDllAddr + OFFSET_DB_MSG_MGR);
int db_index = static_cast<int>(GET_QWORD(msg_mgr_addr + 0x68));
string sql = "SELECT Buf FROM Media WHERE Reserved0=" + to_string(id) + ";";
for (int i = dbIndex - 1; i >= 0; i--) {
string dbname = "MediaMSG" + to_string(i) + ".db";
DbRows_t rows = ExecDbQuery(dbname, sql);
if (rows.empty()) {
std::string sql = "SELECT Buf FROM Media WHERE Reserved0=" + std::to_string(id) + ";";
for (int i = db_index - 1; i >= 0; i--) {
std::string dbname = "MediaMSG" + std::to_string(i) + ".db";
DbRows_t rows = exec_db_query(dbname, sql);
if (rows.empty() || rows.front().empty()) {
continue;
}
DbRow_t row = rows.front();
if (row.empty()) {
continue;
}
DbField_t field = row.front();
if (field.column.compare("Buf") != 0) {
const DbField_t &field = rows.front().front();
if (field.column != "Buf" || field.content.empty()) {
continue;
}
// 首字节为 0x02估计是混淆用的去掉。
vector<uint8_t> rv(field.content.begin() + 1, field.content.end());
if (field.content.front() == 0x02) {
return std::vector<uint8_t>(field.content.begin() + 1, field.content.end());
}
return rv;
return field.content;
}
return vector<uint8_t>();
return {};
}
bool rpc_get_db_names(uint8_t *out, size_t *len)
{
return fill_response<Functions_FUNC_GET_DB_NAMES>(out, len, [&](Response &rsp) {
rsp.msg.dbs.names.funcs.encode = encode_dbnames;
rsp.msg.dbs.names.arg = &get_db_names();
});
}
bool rpc_get_db_tables(const std::string &db, uint8_t *out, size_t *len)
{
return fill_response<Functions_FUNC_GET_DB_TABLES>(out, len, [&](Response &rsp) {
rsp.msg.tables.tables.funcs.encode = encode_tables;
rsp.msg.tables.tables.arg = &get_db_tables(db);
});
}
bool rpc_exec_db_query(const std::string &db, const std::string &sql, uint8_t *out, size_t *len)
{
return fill_response<Functions_FUNC_EXEC_DB_QUERY>(out, len, [&](Response &rsp) {
rsp.msg.rows.rows.funcs.encode = encode_rows;
rsp.msg.rows.rows.arg = &exec_db_query(db, sql);
});
}
} // namespace exec_sql

View File

@ -1,11 +1,32 @@
#pragma once
#include <optional>
#include <string>
#include <vector>
#include "pb_types.h"
DbNames_t GetDbNames();
DbTables_t GetDbTables(const string db);
DbRows_t ExecDbQuery(const string db, const string sql);
int GetLocalIdandDbidx(uint64_t id, uint64_t *localId, uint32_t *dbIdx);
vector<uint8_t> GetAudioData(uint64_t msgid);
namespace exec_sql
{
// 获取数据库名称列表
DbNames_t get_db_names();
// 获取指定数据库的表列表
DbTables_t get_db_tables(const std::string &db);
// 执行 SQL 查询
DbRows_t exec_db_query(const std::string &db, const std::string &sql);
// 获取本地消息 ID 和数据库索引
std::optional<std::pair<uint64_t, uint32_t>> get_local_id_and_dbidx(uint64_t id);
// 获取音频数据
std::vector<uint8_t> get_audio_data(uint64_t msg_id);
// RPC 方法
bool rpc_get_db_names(uint8_t *out, size_t *len);
bool rpc_get_db_tables(const std::string &db, uint8_t *out, size_t *len);
bool rpc_exec_db_query(const std::string &db, const std::string &sql, uint8_t *out, size_t *len);
} // namespace exec_sql

View File

@ -112,7 +112,7 @@ static bool func_get_contacts(uint8_t *out, size_t *len)
static bool func_get_db_names(uint8_t *out, size_t *len)
{
return FillResponse<Functions_FUNC_GET_DB_NAMES>(Response_dbs_tag, out, len, [](Response &rsp) {
static DbNames_t dbnames = GetDbNames();
static DbNames_t dbnames = exec_sql::get_db_names();
rsp.msg.dbs.names.funcs.encode = encode_dbnames;
rsp.msg.dbs.names.arg = &dbnames;
});
@ -121,7 +121,7 @@ static bool func_get_db_names(uint8_t *out, size_t *len)
static bool func_get_db_tables(char *db, uint8_t *out, size_t *len)
{
return FillResponse<Functions_FUNC_GET_DB_TABLES>(Response_tables_tag, out, len, [db](Response &rsp) {
static DbTables_t tables = GetDbTables(db);
static DbTables_t tables = exec_sql::get_db_tables(db);
rsp.msg.tables.tables.funcs.encode = encode_tables;
rsp.msg.tables.tables.arg = &tables;
});
@ -362,7 +362,7 @@ static bool func_exec_db_query(char *db, char *sql, uint8_t *out, size_t *len)
if ((db == nullptr) || (sql == nullptr)) {
LOG_ERROR("Empty db or sql.");
} else {
rows = ExecDbQuery(db, sql);
rows = exec_sql::exec_db_query(db, sql);
}
rsp.msg.rows.rows.arg = &rows;
rsp.msg.rows.rows.funcs.encode = encode_rows;