Impl GetDbNames, GetDbTalbes and ExecDbQuery

This commit is contained in:
Changhua 2023-02-17 22:06:40 +08:00
parent 33a5ed4033
commit 95ce4578bf
7 changed files with 277 additions and 56 deletions

View File

@ -2,6 +2,7 @@
#include <map> #include <map>
#include <string> #include <string>
#include <vector>
using namespace std; using namespace std;
@ -16,3 +17,19 @@ typedef struct {
string province; string province;
string city; string city;
} RpcContact_t; } RpcContact_t;
typedef vector<string> DbNames_t;
typedef struct {
string name;
string sql;
} DbTable_t;
typedef vector<DbTable_t> DbTables_t;
typedef struct {
int32_t type;
string column;
vector<uint8_t> content;
} DbField_t;
typedef vector<DbField_t> DbRow_t;
typedef vector<DbRow_t> DbRows_t;

View File

@ -23,6 +23,7 @@ bool encode_string(pb_ostream_t *stream, const pb_field_t *field, void *const *a
const char *str = (const char *)*arg; const char *str = (const char *)*arg;
if (!pb_encode_tag_for_field(stream, field)) { if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false; return false;
} }
@ -35,11 +36,24 @@ bool decode_string(pb_istream_t *stream, const pb_field_t *field, void **arg)
size_t len = stream->bytes_left; size_t len = stream->bytes_left;
str->resize(len); str->resize(len);
if (!pb_read(stream, (uint8_t *)str->data(), len)) { if (!pb_read(stream, (uint8_t *)str->data(), len)) {
LOG_ERROR("Decoding failed: {}", PB_GET_ERROR(stream));
return false; return false;
} }
return true; return true;
} }
bool encode_bytes(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
vector<uint8_t> *v = (vector<uint8_t> *)*arg;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
return pb_encode_string(stream, (uint8_t *)v->data(), v->size());
}
bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *arg) bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{ {
MsgTypes_t *m = (MsgTypes_t *)*arg; MsgTypes_t *m = (MsgTypes_t *)*arg;
@ -51,10 +65,12 @@ bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *ar
message.value.arg = (void *)it->second.c_str(); message.value.arg = (void *)it->second.c_str();
if (!pb_encode_tag_for_field(stream, field)) { if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false; return false;
} }
if (!pb_encode_submessage(stream, MsgTypes_TypesEntry_fields, &message)) { if (!pb_encode_submessage(stream, MsgTypes_TypesEntry_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false; return false;
} }
} }
@ -90,10 +106,112 @@ bool encode_contacts(pb_ostream_t *stream, const pb_field_t *field, void *const
message.gender = (*it).gender; message.gender = (*it).gender;
if (!pb_encode_tag_for_field(stream, field)) { if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false; return false;
} }
if (!pb_encode_submessage(stream, RpcContact_fields, &message)) { if (!pb_encode_submessage(stream, RpcContact_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
bool encode_dbnames(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
vector<string> *v = (vector<string> *)*arg;
DbNames message = DbNames_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.names.funcs.encode = &encode_string;
message.names.arg = (void *)(*it).c_str();
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbNames_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
bool encode_tables(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
DbTables_t *v = (DbTables_t *)*arg;
DbTable message = DbTable_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.name.funcs.encode = &encode_string;
message.name.arg = (void *)(*it).name.c_str();
message.sql.funcs.encode = &encode_string;
message.sql.arg = (void *)(*it).sql.c_str();
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbTable_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
static bool encode_fields(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
DbRow_t *v = (DbRow_t *)*arg;
DbField message = DbField_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.type = (*it).type;
message.column.arg = (void *)(*it).column.c_str();
message.column.funcs.encode = &encode_string;
message.content.arg = (void *)&(*it).content;
message.content.funcs.encode = &encode_bytes;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbField_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
bool encode_rows(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
DbRows_t *v = (DbRows_t *)*arg;
DbRow message = DbRow_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.fields.arg = (void *)&(*it);
message.fields.funcs.encode = &encode_fields;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbRow_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false; return false;
} }
} }

View File

@ -8,3 +8,6 @@ bool encode_string(pb_ostream_t *stream, const pb_field_t *field, void *const *a
bool decode_string(pb_istream_t *stream, const pb_field_t *field, void **arg); bool decode_string(pb_istream_t *stream, const pb_field_t *field, void **arg);
bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *arg); bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_contacts(pb_ostream_t *stream, const pb_field_t *field, void *const *arg); bool encode_contacts(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_dbnames(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_tables(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_rows(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);

View File

@ -3,3 +3,7 @@
* fallback_type:FT_POINTER * fallback_type:FT_POINTER
MsgTypes* fallback_type:FT_CALLBACK MsgTypes* fallback_type:FT_CALLBACK
RpcContact* fallback_type:FT_CALLBACK RpcContact* fallback_type:FT_CALLBACK
DbNames* fallback_type:FT_CALLBACK
DbTable* fallback_type:FT_CALLBACK
DbField* fallback_type:FT_CALLBACK
DbRow* fallback_type:FT_CALLBACK

View File

@ -1,13 +1,9 @@
#include <algorithm> #include <iterator>
#include <map>
#include <string>
#if 0
#include "exec_sql.h" #include "exec_sql.h"
#include "load_calls.h" #include "load_calls.h"
#include "util.h" #include "util.h"
using namespace std;
#define SQLITE_OK 0 /* Successful result */ #define SQLITE_OK 0 /* Successful result */
#define SQLITE_ERROR 1 /* Generic error */ #define SQLITE_ERROR 1 /* Generic error */
#define SQLITE_INTERNAL 2 /* Internal logic error in SQLite */ #define SQLITE_INTERNAL 2 /* Internal logic error in SQLite */
@ -71,22 +67,6 @@ typedef const void *(__cdecl *Sqlite3_column_blob)(DWORD *, int);
typedef int(__cdecl *Sqlite3_column_bytes)(DWORD *, int); typedef int(__cdecl *Sqlite3_column_bytes)(DWORD *, int);
typedef int(__cdecl *Sqlite3_finalize)(DWORD *); typedef int(__cdecl *Sqlite3_finalize)(DWORD *);
static int cbGetTables(void *ret, int argc, char **argv, char **azColName)
{
wcf::DbTables *tbls = (wcf::DbTables *)ret;
wcf::DbTable *tbl = tbls->add_tables();
for (int i = 0; i < argc; i++) {
if (strcmp(azColName[i], "name") == 0) {
tbl->set_name(argv[i] ? argv[i] : "");
} else if (strcmp(azColName[i], "sql") == 0) {
string sql(argv[i]);
sql.erase(std::remove(sql.begin(), sql.end(), '\t'), sql.end());
tbl->set_sql(sql.c_str());
}
}
return 0;
}
dbMap_t GetDbHandles() dbMap_t GetDbHandles()
{ {
if (!dbMap.empty()) if (!dbMap.empty())
@ -109,37 +89,60 @@ dbMap_t GetDbHandles()
return dbMap; return dbMap;
} }
void GetDbNames(wcf::DbNames *names) DbNames_t GetDbNames()
{ {
DbNames_t names;
if (dbMap.empty()) { if (dbMap.empty()) {
dbMap = GetDbHandles(); dbMap = GetDbHandles();
} }
for (auto &[k, v] : dbMap) { for (auto &[k, v] : dbMap) {
auto *name = names->add_names(); names.push_back(k);
name->assign(k);
} }
return names;
} }
void GetDbTables(const string db, wcf::DbTables *tables) static int cbGetTables(void *ret, int argc, char **argv, char **azColName)
{ {
DbTables_t *tbls = (DbTables_t *)ret;
DbTable_t tbl;
for (int i = 0; i < argc; i++) {
if (strcmp(azColName[i], "name") == 0) {
tbl.name = argv[i] ? argv[i] : "";
} else if (strcmp(azColName[i], "sql") == 0) {
string sql(argv[i]);
sql.erase(std::remove(sql.begin(), sql.end(), '\t'), sql.end());
tbl.sql = sql.c_str();
}
}
tbls->push_back(tbl);
return 0;
}
DbTables_t GetDbTables(const string db)
{
DbTables_t tables;
if (dbMap.empty()) { if (dbMap.empty()) {
dbMap = GetDbHandles(); dbMap = GetDbHandles();
} }
auto it = dbMap.find(db); auto it = dbMap.find(db);
if (it == dbMap.end()) { if (it == dbMap.end()) {
return; // DB not found return tables; // DB not found
} }
const char *sql = "select name, sql from sqlite_master where type=\"table\";"; const char *sql = "select name, sql from sqlite_master where type=\"table\";";
Sqlite3_exec p_Sqlite3_exec = (Sqlite3_exec)(g_WeChatWinDllAddr + g_WxCalls.sql.exec); Sqlite3_exec p_Sqlite3_exec = (Sqlite3_exec)(g_WeChatWinDllAddr + g_WxCalls.sql.exec);
p_Sqlite3_exec(it->second, sql, (sqlite3_callback)cbGetTables, tables, 0); p_Sqlite3_exec(it->second, sql, (sqlite3_callback)cbGetTables, (void *)&tables, 0);
return tables;
} }
void ExecDbQuery(const string db, const string sql, wcf::DbRows *rows) DbRows_t ExecDbQuery(const string db, const string sql)
{ {
DbRows_t rows;
Sqlite3_prepare func_prepare = (Sqlite3_prepare)(g_WeChatWinDllAddr + 0x14227F0); Sqlite3_prepare func_prepare = (Sqlite3_prepare)(g_WeChatWinDllAddr + 0x14227F0);
Sqlite3_step func_step = (Sqlite3_step)(g_WeChatWinDllAddr + 0x13EA780); Sqlite3_step func_step = (Sqlite3_step)(g_WeChatWinDllAddr + 0x13EA780);
Sqlite3_column_count func_column_count = (Sqlite3_column_count)(g_WeChatWinDllAddr + 0x13EACD0); Sqlite3_column_count func_column_count = (Sqlite3_column_count)(g_WeChatWinDllAddr + 0x13EACD0);
@ -156,22 +159,26 @@ void ExecDbQuery(const string db, const string sql, wcf::DbRows *rows)
DWORD *stmt; DWORD *stmt;
int rc = func_prepare(dbMap[db], sql.c_str(), -1, &stmt, 0); int rc = func_prepare(dbMap[db], sql.c_str(), -1, &stmt, 0);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
return; return rows;
} }
while (func_step(stmt) == SQLITE_ROW) { while (func_step(stmt) == SQLITE_ROW) {
wcf::DbRow *row = rows->add_rows(); DbRow_t row;
int col_count = func_column_count(stmt); int col_count = func_column_count(stmt);
for (int i = 0; i < col_count; i++) { for (int i = 0; i < col_count; i++) {
wcf::DbField *field = row->add_fields(); DbField_t field;
field->set_type(func_column_type(stmt, i)); field.type = func_column_type(stmt, i);
field->set_column(func_column_name(stmt, i)); field.column = func_column_name(stmt, i);
int length = func_column_bytes(stmt, i); int length = func_column_bytes(stmt, i);
const void *blob = func_column_blob(stmt, i); const void *blob = func_column_blob(stmt, i);
if (length && (field->type() != 5)) { if (length && (field.type != 5)) {
field->set_content(string((char *)blob, length)); field.content.reserve(length);
copy((uint8_t *)blob, (uint8_t *)blob + length, back_inserter(field.content));
} }
row.push_back(field);
} }
rows.push_back(row);
} }
return rows;
} }
#endif

View File

@ -1,11 +1,7 @@
#pragma once #pragma once
#if 0
#include <string>
#include <vector>
#include "../proto/wcf.grpc.pb.h" #include "pb_types.h"
void GetDbNames(wcf::DbNames *names); DbNames_t GetDbNames();
void GetDbTables(const std::string db, wcf::DbTables *tables); DbTables_t GetDbTables(const string db);
void ExecDbQuery(const std::string db, const std::string sql, wcf::DbRows *rows); DbRows_t ExecDbQuery(const string db, const string sql);
#endif

View File

@ -25,7 +25,7 @@
#include "spy_types.h" #include "spy_types.h"
#include "util.h" #include "util.h"
#define G_BUF_SIZE (1024 * 1024) #define G_BUF_SIZE (16 * 1024 * 1024)
extern int IsLogin(void); // Defined in spy.cpp extern int IsLogin(void); // Defined in spy.cpp
extern std::string GetSelfWxid(); // Defined in spy.cpp extern std::string GetSelfWxid(); // Defined in spy.cpp
@ -51,7 +51,7 @@ bool func_is_login(uint8_t *out, size_t *len)
pb_ostream_t stream = pb_ostream_from_buffer(out, *len); pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) { if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream)); LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false; return false;
} }
*len = stream.bytes_written; *len = stream.bytes_written;
@ -68,7 +68,7 @@ bool func_get_self_wxid(uint8_t *out, size_t *len)
pb_ostream_t stream = pb_ostream_from_buffer(out, *len); pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) { if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream)); LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false; return false;
} }
*len = stream.bytes_written; *len = stream.bytes_written;
@ -88,7 +88,7 @@ bool func_get_msg_types(uint8_t *out, size_t *len)
pb_ostream_t stream = pb_ostream_from_buffer(out, *len); pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) { if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream)); LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false; return false;
} }
*len = stream.bytes_written; *len = stream.bytes_written;
@ -102,13 +102,73 @@ bool func_get_contacts(uint8_t *out, size_t *len)
rsp.func = Functions_FUNC_GET_CONTACTS; rsp.func = Functions_FUNC_GET_CONTACTS;
rsp.which_msg = Response_contacts_tag; rsp.which_msg = Response_contacts_tag;
vector<RpcContact_t> contacts = GetContacts(); vector<RpcContact_t> contacts = GetContacts();
rsp.msg.types.types.funcs.encode = encode_contacts; rsp.msg.contacts.contacts.funcs.encode = encode_contacts;
rsp.msg.types.types.arg = &contacts; rsp.msg.contacts.contacts.arg = &contacts;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len); pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) { if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream)); LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
return true;
}
bool func_get_db_names(uint8_t *out, size_t *len)
{
Response rsp = Response_init_default;
rsp.func = Functions_FUNC_GET_DB_NAMES;
rsp.which_msg = Response_dbs_tag;
DbNames_t dbnames = GetDbNames();
rsp.msg.dbs.names.funcs.encode = encode_dbnames;
rsp.msg.dbs.names.arg = &dbnames;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
return true;
}
bool func_get_db_tables(char *db, uint8_t *out, size_t *len)
{
Response rsp = Response_init_default;
rsp.func = Functions_FUNC_GET_DB_TABLES;
rsp.which_msg = Response_tables_tag;
DbTables_t tables = GetDbTables(db);
rsp.msg.tables.tables.funcs.encode = encode_tables;
rsp.msg.tables.tables.arg = &tables;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
return true;
}
bool func_exec_db_query(char *db, char *sql, uint8_t *out, size_t *len)
{
Response rsp = Response_init_default;
rsp.func = Functions_FUNC_GET_DB_TABLES;
rsp.which_msg = Response_rows_tag;
DbRows_t rows = ExecDbQuery(db, sql);
rsp.msg.rows.rows.arg = &rows;
rsp.msg.rows.rows.funcs.encode = encode_rows;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false; return false;
} }
*len = stream.bytes_written; *len = stream.bytes_written;
@ -149,6 +209,21 @@ static bool dispatcher(uint8_t *in, size_t in_len, uint8_t *out, size_t *out_len
ret = func_get_contacts(out, out_len); ret = func_get_contacts(out, out_len);
break; break;
} }
case Functions_FUNC_GET_DB_NAMES: {
LOG_INFO("[Functions_FUNC_GET_DB_NAMES]");
ret = func_get_db_names(out, out_len);
break;
}
case Functions_FUNC_GET_DB_TABLES: {
LOG_INFO("[Functions_FUNC_GET_DB_TABLES]");
ret = func_get_db_tables(req.msg.str, out, out_len);
break;
}
case Functions_FUNC_EXEC_DB_QUERY: {
LOG_INFO("[Functions_FUNC_EXEC_DB_QUERY]");
ret = func_exec_db_query(req.msg.query.db, req.msg.query.sql, out, out_len);
break;
}
default: { default: {
LOG_ERROR("[UNKNOW FUNCTION]"); LOG_ERROR("[UNKNOW FUNCTION]");
break; break;
@ -189,7 +264,8 @@ static int RunServer()
log_buffer(in, in_len); log_buffer(in, in_len);
if (dispatcher(in, in_len, gBuffer, &out_len)) { if (dispatcher(in, in_len, gBuffer, &out_len)) {
log_buffer(gBuffer, out_len); LOG_INFO("Send data length {}", out_len);
// log_buffer(gBuffer, out_len);
rv = nng_send(sock, gBuffer, out_len, 0); rv = nng_send(sock, gBuffer, out_len, 0);
if (rv != 0) { if (rv != 0) {
LOG_ERROR("nng_send: {}", rv); LOG_ERROR("nng_send: {}", rv);
@ -199,7 +275,7 @@ static int RunServer()
// Error // Error
LOG_ERROR("Dispatcher failed..."); LOG_ERROR("Dispatcher failed...");
rv = nng_send(sock, gBuffer, 0, 0); rv = nng_send(sock, gBuffer, 0, 0);
break; // break;
} }
nng_free(in, in_len); nng_free(in, in_len);
} }