1.增加了对象存储的支持,并兼容读取本地文件
This commit is contained in:
parent
bc90597e54
commit
4086cc53c8
26
pywxdump/file/AttachmentAbstract.py
Normal file
26
pywxdump/file/AttachmentAbstract.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from typing import Protocol, IO
|
||||||
|
|
||||||
|
|
||||||
|
# 基类
|
||||||
|
class Attachment(Protocol):
|
||||||
|
|
||||||
|
def exists(self, path) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def makedirs(self, path) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def open(self, path, param) -> IO:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def join(cls, __a: str, *paths: str) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dirname(cls, path: str) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def basename(cls, path: str) -> str:
|
||||||
|
pass
|
72
pywxdump/file/AttachmentContext.py
Normal file
72
pywxdump/file/AttachmentContext.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import AnyStr, BinaryIO, Callable, Union, IO
|
||||||
|
from flask import send_file, Response
|
||||||
|
|
||||||
|
from pywxdump.file.AttachmentAbstract import Attachment
|
||||||
|
from pywxdump.file.LocalAttachment import LocalAttachment
|
||||||
|
from pywxdump.file.S3Attachment import S3Attachment
|
||||||
|
|
||||||
|
|
||||||
|
def determine_strategy(file_path: str) -> Attachment:
|
||||||
|
if file_path.startswith("s3://"):
|
||||||
|
return S3Attachment()
|
||||||
|
else:
|
||||||
|
return LocalAttachment()
|
||||||
|
|
||||||
|
|
||||||
|
def exists(path: str) -> bool:
|
||||||
|
return determine_strategy(path).exists(path)
|
||||||
|
|
||||||
|
|
||||||
|
def open_file(path: str, mode: str) -> IO:
|
||||||
|
return determine_strategy(path).open(path, mode)
|
||||||
|
|
||||||
|
|
||||||
|
def makedirs(path: str) -> bool:
|
||||||
|
return determine_strategy(path).makedirs(path)
|
||||||
|
|
||||||
|
|
||||||
|
def join(__a: str, *paths: str) -> str:
|
||||||
|
return determine_strategy(__a).join(__a, *paths)
|
||||||
|
|
||||||
|
|
||||||
|
def dirname(path: str) -> str:
|
||||||
|
return determine_strategy(path).dirname(path)
|
||||||
|
|
||||||
|
|
||||||
|
def basename(path: str) -> str:
|
||||||
|
return determine_strategy(path).basename(path)
|
||||||
|
|
||||||
|
|
||||||
|
def send_attachment(
|
||||||
|
path_or_file: Union[os.PathLike[AnyStr], str],
|
||||||
|
mimetype: Union[str, None] = None,
|
||||||
|
as_attachment: bool = False,
|
||||||
|
download_name: Union[str, None] = None,
|
||||||
|
conditional: bool = True,
|
||||||
|
etag: Union[bool, str] = True,
|
||||||
|
last_modified: Union[datetime, int, float, None] = None,
|
||||||
|
max_age: Union[None, int, Callable[[Union[str, None]], Union[int, None]]] = None,
|
||||||
|
) -> Response:
|
||||||
|
file_io = open_file(path_or_file, "rb")
|
||||||
|
|
||||||
|
# 如果没有提供 download_name 或 mimetype,则从 path_or_file 中获取文件名和 MIME 类型
|
||||||
|
if download_name is None:
|
||||||
|
download_name = basename(path_or_file)
|
||||||
|
if mimetype is None:
|
||||||
|
mimetype = 'application/octet-stream'
|
||||||
|
|
||||||
|
return send_file(file_io, mimetype, as_attachment, download_name, conditional, etag, last_modified, max_age)
|
||||||
|
|
||||||
|
|
||||||
|
def download_file(db_path, local_path):
|
||||||
|
with open(local_path, 'wb') as f:
|
||||||
|
with open_file(db_path, 'rb') as r:
|
||||||
|
f.write(r.read())
|
||||||
|
return local_path
|
||||||
|
|
||||||
|
|
||||||
|
def isLocalPath(path: str) -> bool:
|
||||||
|
return isinstance(determine_strategy(path), LocalAttachment)
|
||||||
|
|
46
pywxdump/file/LocalAttachment.py
Normal file
46
pywxdump/file/LocalAttachment.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# 本地文件处理类
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import IO
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAttachment:
|
||||||
|
|
||||||
|
def open(self, path, mode) -> IO:
|
||||||
|
path = self.dealLocalPath(path)
|
||||||
|
return open(path, mode)
|
||||||
|
|
||||||
|
def exists(self, path) -> bool:
|
||||||
|
path = self.dealLocalPath(path)
|
||||||
|
return os.path.exists(path)
|
||||||
|
|
||||||
|
def makedirs(self, path) -> bool:
|
||||||
|
path = self.dealLocalPath(path)
|
||||||
|
os.makedirs(path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def join(cls, __a: str, *paths: str) -> str:
|
||||||
|
return os.path.join(__a, *paths)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dirname(cls, path: str) -> str:
|
||||||
|
return os.path.dirname(path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def basename(cls, path: str) -> str:
|
||||||
|
return os.path.basename(path)
|
||||||
|
|
||||||
|
def dealLocalPath(self, path: str) -> str:
|
||||||
|
# 获取当前系统的地址分隔符
|
||||||
|
# 将path中的 /替换为当前系统的分隔符
|
||||||
|
path = path.replace('/', os.sep)
|
||||||
|
if sys.platform == "win32":
|
||||||
|
# 如果是windows系统,且路径长度超过260个字符
|
||||||
|
if len(path) >= 260:
|
||||||
|
# 添加前缀
|
||||||
|
return '\\\\?\\' + path
|
||||||
|
else:
|
||||||
|
return path
|
||||||
|
else:
|
||||||
|
return path
|
96
pywxdump/file/S3Attachment.py
Normal file
96
pywxdump/file/S3Attachment.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
# 对象存储文件处理类(示例:假设是 AWS S3)
|
||||||
|
import os
|
||||||
|
from typing import IO
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from smart_open import open
|
||||||
|
import boto3
|
||||||
|
from botocore.client import Config
|
||||||
|
|
||||||
|
class S3Attachment:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 腾讯云 COS 配置
|
||||||
|
self.cos_endpoint = "https://cos.<your-region>.myqcloud.com" # 替换 <your-region> 为你的 COS 区域,例如 ap-shanghai
|
||||||
|
self.access_key_id = "SecretId" # 替换为你的腾讯云 SecretId
|
||||||
|
self.secret_access_key = "SecretKey" # 替换为你的腾讯云 SecretKey
|
||||||
|
|
||||||
|
# 创建 S3 客户端
|
||||||
|
self.s3_client = boto3.client(
|
||||||
|
's3',
|
||||||
|
endpoint_url=self.cos_endpoint,
|
||||||
|
aws_access_key_id=self.access_key_id,
|
||||||
|
aws_secret_access_key=self.secret_access_key,
|
||||||
|
config=Config(s3={"addressing_style": "virtual", "signature_version": 's3v4'})
|
||||||
|
)
|
||||||
|
|
||||||
|
def exists(self, path) -> bool:
|
||||||
|
bucket_name, path = self.dealS3Url(path)
|
||||||
|
# 检查是否为目录
|
||||||
|
if path.endswith('/'):
|
||||||
|
# 尝试列出该路径下的对象
|
||||||
|
try:
|
||||||
|
response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1)
|
||||||
|
if 'Contents' in response:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except ClientError as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# 检查是否为文件
|
||||||
|
try:
|
||||||
|
self.s3_client.head_object(Bucket=bucket_name, Key=path)
|
||||||
|
return True
|
||||||
|
except ClientError as e:
|
||||||
|
if e.response['Error']['Code'] == '404':
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def makedirs(self, path) -> bool:
|
||||||
|
if not self.exists(path):
|
||||||
|
bucket_name, path = self.dealS3Url(path)
|
||||||
|
self.s3_client.put_object(Bucket=bucket_name, Key=f'{path}/')
|
||||||
|
return True
|
||||||
|
|
||||||
|
def open(self, path, mode) -> IO:
|
||||||
|
self.dealS3Url(path)
|
||||||
|
return open(uri=path, mode=mode, transport_params={'client': self.s3_client})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def join(cls, __a: str, *paths: str) -> str:
|
||||||
|
return os.path.join(__a, *paths)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dirname(cls, path: str) -> str:
|
||||||
|
return os.path.dirname(path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def basename(cls, path: str) -> str:
|
||||||
|
return os.path.basename(path)
|
||||||
|
|
||||||
|
def dealS3Url(self, path: str) -> object:
|
||||||
|
"""
|
||||||
|
解析 S3 URL 并返回存储桶名称和路径
|
||||||
|
|
||||||
|
参数:
|
||||||
|
path (str): S3 URL
|
||||||
|
|
||||||
|
返回:
|
||||||
|
tuple: 包含存储桶名称和路径的元组
|
||||||
|
"""
|
||||||
|
parsed_url = urlparse(path)
|
||||||
|
|
||||||
|
# 确保URL是S3 URL
|
||||||
|
if parsed_url.scheme != 's3':
|
||||||
|
raise ValueError("URL必须是S3 URL,格式为s3://bucket_name/path")
|
||||||
|
|
||||||
|
bucket_name = parsed_url.netloc
|
||||||
|
s3_path = parsed_url.path.lstrip('/')
|
||||||
|
|
||||||
|
return bucket_name, s3_path
|
||||||
|
|
Loading…
Reference in New Issue
Block a user