diff --git a/pywxdump/file/AttachmentAbstract.py b/pywxdump/file/AttachmentAbstract.py
new file mode 100644
index 0000000..811202a
--- /dev/null
+++ b/pywxdump/file/AttachmentAbstract.py
@@ -0,0 +1,31 @@
+"""Interface shared by every attachment storage backend."""
+from typing import Protocol, IO
+
+
+class Attachment(Protocol):
+    """Structural interface for an attachment storage backend.
+
+    LocalAttachment and S3Attachment both provide these operations, so
+    callers can stay agnostic of where a file actually lives.
+    """
+
+    def exists(self, path) -> bool:
+        """Return whether ``path`` exists in this backend."""
+
+    def makedirs(self, path) -> bool:
+        """Create the directory ``path``."""
+
+    def open(self, path, param) -> IO:
+        """Open ``path`` with mode ``param`` and return a file object."""
+
+    @classmethod
+    def join(cls, __a: str, *paths: str) -> str:
+        """Join ``__a`` and ``paths`` into a single path."""
+
+    @classmethod
+    def dirname(cls, path: str) -> str:
+        """Return the directory portion of ``path``."""
+
+    @classmethod
+    def basename(cls, path: str) -> str:
+        """Return the final component of ``path``."""
diff --git a/pywxdump/file/AttachmentContext.py b/pywxdump/file/AttachmentContext.py
new file mode 100644
index 0000000..3556c2a
--- /dev/null
+++ b/pywxdump/file/AttachmentContext.py
@@ -0,0 +1,115 @@
+"""Storage-agnostic file helpers that dispatch on the path scheme."""
+import mimetypes
+import os
+import shutil
+from datetime import datetime
+from functools import lru_cache
+from typing import AnyStr, BinaryIO, Callable, Union, IO
+
+from flask import send_file, Response
+
+from pywxdump.file.AttachmentAbstract import Attachment
+from pywxdump.file.LocalAttachment import LocalAttachment
+from pywxdump.file.S3Attachment import S3Attachment
+
+
+@lru_cache(maxsize=None)
+def _s3_backend() -> S3Attachment:
+    """Return one shared S3 backend so its client is built only once."""
+    return S3Attachment()
+
+
+def determine_strategy(file_path: str) -> Attachment:
+    """Return the backend for ``file_path``: S3 for ``s3://``, else local."""
+    if file_path.startswith("s3://"):
+        # Reuse one S3 backend: building its client on every call is wasteful.
+        return _s3_backend()
+    return LocalAttachment()
+
+
+def exists(path: str) -> bool:
+    """Return whether ``path`` exists in its backend."""
+    return determine_strategy(path).exists(path)
+
+
+def open_file(path: str, mode: str) -> IO:
+    """Open ``path`` with ``mode`` through its backend."""
+    return determine_strategy(path).open(path, mode)
+
+
+def makedirs(path: str) -> bool:
+    """Create the directory ``path`` through its backend."""
+    return determine_strategy(path).makedirs(path)
+
+
+def join(__a: str, *paths: str) -> str:
+    """Join path components; the first component selects the backend."""
+    return determine_strategy(__a).join(__a, *paths)
+
+
+def dirname(path: str) -> str:
+    """Return the directory portion of ``path``."""
+    return determine_strategy(path).dirname(path)
+
+
+def basename(path: str) -> str:
+    """Return the final component of ``path``."""
+    return determine_strategy(path).basename(path)
+
+
+def send_attachment(
+        path_or_file: Union[os.PathLike[AnyStr], str],
+        mimetype: Union[str, None] = None,
+        as_attachment: bool = False,
+        download_name: Union[str, None] = None,
+        conditional: bool = True,
+        etag: Union[bool, str] = True,
+        last_modified: Union[datetime, int, float, None] = None,
+        max_age: Union[None, int, Callable[[Union[str, None]], Union[int, None]]] = None,
+) -> Response:
+    """Build a Flask response that streams a local or S3 file.
+
+    Mirrors ``flask.send_file`` but opens the file through the attachment
+    backends. When omitted, ``download_name`` defaults to the basename of
+    the path and ``mimetype`` is guessed from that name, falling back to
+    ``application/octet-stream``.
+    """
+    path = os.fspath(path_or_file)
+    if download_name is None:
+        download_name = basename(path)
+    if mimetype is None:
+        guessed, _ = mimetypes.guess_type(download_name)
+        mimetype = guessed or 'application/octet-stream'
+
+    file_io = open_file(path, "rb")
+    try:
+        return send_file(
+            file_io,
+            mimetype=mimetype,
+            as_attachment=as_attachment,
+            download_name=download_name,
+            conditional=conditional,
+            etag=etag,
+            last_modified=last_modified,
+            max_age=max_age,
+        )
+    except Exception:
+        # No response exists to close the handle when send_file fails.
+        file_io.close()
+        raise
+
+
+def download_file(db_path, local_path):
+    """Copy ``db_path`` (local or S3) to ``local_path`` and return it.
+
+    The source is opened first so a missing source does not leave an empty
+    file behind, and the data is streamed instead of read in one piece.
+    """
+    with open_file(db_path, 'rb') as src, open(local_path, 'wb') as dst:
+        shutil.copyfileobj(src, dst)
+    return local_path
+
+
+def isLocalPath(path: str) -> bool:
+    """Return True when ``path`` is handled by the local-disk backend."""
+    return isinstance(determine_strategy(path), LocalAttachment)
diff --git a/pywxdump/file/LocalAttachment.py b/pywxdump/file/LocalAttachment.py
new file mode 100644
index 0000000..3705b45
--- /dev/null
+++ b/pywxdump/file/LocalAttachment.py
@@ -0,0 +1,58 @@
+"""Local-disk attachment backend."""
+import os
+import sys
+from typing import IO
+
+# Extended-length path prefix that dealLocalPath adds to Windows paths of
+# 260 characters or more.
+_WIN_LONG_PATH_PREFIX = '\\\\?\\'
+
+
+class LocalAttachment:
+    """Attachment backend backed by the local file system."""
+
+    def open(self, path, mode) -> IO:
+        """Open ``path`` with ``mode`` using the builtin ``open``."""
+        return open(self.dealLocalPath(path), mode)
+
+    def exists(self, path) -> bool:
+        """Return whether ``path`` exists on disk."""
+        return os.path.exists(self.dealLocalPath(path))
+
+    def makedirs(self, path) -> bool:
+        """Create ``path`` with any missing parents and return True.
+
+        An already existing directory is accepted, which matches the
+        idempotent behaviour of ``S3Attachment.makedirs``.
+        """
+        os.makedirs(self.dealLocalPath(path), exist_ok=True)
+        return True
+
+    @classmethod
+    def join(cls, __a: str, *paths: str) -> str:
+        """Join path components with ``os.path.join``."""
+        return os.path.join(__a, *paths)
+
+    @classmethod
+    def dirname(cls, path: str) -> str:
+        """Return the directory portion of ``path``."""
+        return os.path.dirname(path)
+
+    @classmethod
+    def basename(cls, path: str) -> str:
+        """Return the final component of ``path``."""
+        return os.path.basename(path)
+
+    def dealLocalPath(self, path: str) -> str:
+        """Normalise ``path`` for the current operating system.
+
+        Forward slashes become ``os.sep``. On Windows, paths of 260 or more
+        characters also get the extended-length prefix unless they already
+        start with it.
+        """
+        path = path.replace('/', os.sep)
+        if sys.platform != "win32" or len(path) < 260:
+            return path
+        if path.startswith(_WIN_LONG_PATH_PREFIX):
+            return path
+        return _WIN_LONG_PATH_PREFIX + path
diff --git
a/pywxdump/file/S3Attachment.py b/pywxdump/file/S3Attachment.py
new file mode 100644
index 0000000..d117981
--- /dev/null
+++ b/pywxdump/file/S3Attachment.py
@@ -0,0 +1,146 @@
+"""S3-compatible attachment backend (configured for Tencent Cloud COS)."""
+import os
+import posixpath
+from typing import IO, Optional, Tuple
+from urllib.parse import urlparse
+
+from botocore.exceptions import ClientError
+# NOTE: this import deliberately shadows the builtin ``open`` in this module.
+from smart_open import open
+import boto3
+from botocore.client import Config
+
+# Error codes meaning "no such object" rather than a real failure.
+_NOT_FOUND_CODES = ('404', 'NoSuchKey', 'NotFound')
+
+
+class S3Attachment:
+    """Attachment backend for ``s3://bucket_name/path`` URLs."""
+
+    def __init__(self, endpoint: Optional[str] = None,
+                 access_key_id: Optional[str] = None,
+                 secret_access_key: Optional[str] = None):
+        """Create the S3 client used by every operation.
+
+        Each setting comes from the argument, else from the environment
+        (``COS_ENDPOINT``, ``COS_SECRET_ID``, ``COS_SECRET_KEY``), else from
+        the built-in placeholder, so real credentials never have to be
+        written into this file.
+        """
+        # Tencent Cloud COS settings; the placeholder endpoint still needs a
+        # region inserted, for example ap-shanghai.
+        self.cos_endpoint = endpoint or os.environ.get(
+            "COS_ENDPOINT", "https://cos..myqcloud.com")
+        self.access_key_id = access_key_id or os.environ.get(
+            "COS_SECRET_ID", "SecretId")
+        self.secret_access_key = secret_access_key or os.environ.get(
+            "COS_SECRET_KEY", "SecretKey")
+
+        # NOTE(review): botocore documents signature_version as a top-level
+        # Config option; it may be ignored inside the ``s3`` dict - confirm.
+        self.s3_client = boto3.client(
+            's3',
+            endpoint_url=self.cos_endpoint,
+            aws_access_key_id=self.access_key_id,
+            aws_secret_access_key=self.secret_access_key,
+            config=Config(s3={"addressing_style": "virtual", "signature_version": 's3v4'})
+        )
+
+    def exists(self, path) -> bool:
+        """Return whether ``path`` names an existing object or directory.
+
+        The bucket root and keys ending in ``/`` are treated as directories.
+        Any other key is looked up as an object first and, if missing, as
+        the directory ``key/``, so a directory is found with or without its
+        trailing slash. Unexpected client errors are printed and reported
+        as False.
+        """
+        bucket_name, key = self.dealS3Url(path)
+        if not key:
+            return self._bucket_exists(bucket_name)
+        if key.endswith('/'):
+            return self._prefix_exists(bucket_name, key)
+        try:
+            self.s3_client.head_object(Bucket=bucket_name, Key=key)
+            return True
+        except ClientError as e:
+            if e.response['Error']['Code'] not in _NOT_FOUND_CODES:
+                print(f"Error: {e}")
+                return False
+        return self._prefix_exists(bucket_name, f'{key}/')
+
+    def _bucket_exists(self, bucket_name: str) -> bool:
+        """Return whether the bucket itself can be reached."""
+        try:
+            self.s3_client.head_bucket(Bucket=bucket_name)
+        except ClientError as e:
+            print(f"Error: {e}")
+            return False
+        return True
+
+    def _prefix_exists(self, bucket_name: str, prefix: str) -> bool:
+        """Return whether at least one object key starts with ``prefix``."""
+        try:
+            response = self.s3_client.list_objects_v2(
+                Bucket=bucket_name, Prefix=prefix, MaxKeys=1)
+        except ClientError as e:
+            print(f"Error: {e}")
+            return False
+        return 'Contents' in response
+
+    def makedirs(self, path) -> bool:
+        """Create the directory marker object ``key/`` unless it exists.
+
+        Trailing slashes are stripped first so the marker key never ends in
+        ``//``; the bucket root needs no marker. Always returns True.
+        """
+        bucket_name, key = self.dealS3Url(path)
+        key = key.rstrip('/')
+        if key and not self.exists(path):
+            self.s3_client.put_object(Bucket=bucket_name, Key=f'{key}/')
+        return True
+
+    def open(self, path, mode) -> IO:
+        """Open the object at ``path`` via ``smart_open`` with our client."""
+        # Called for validation only: raises ValueError for a non-S3 URL.
+        self.dealS3Url(path)
+        return open(uri=path, mode=mode, transport_params={'client': self.s3_client})
+
+    @classmethod
+    def join(cls, __a: str, *paths: str) -> str:
+        """Join key components with ``/`` on every platform.
+
+        ``os.path.join`` would insert backslashes on Windows and corrupt
+        the object key.
+        """
+        return posixpath.join(__a, *paths)
+
+    @classmethod
+    def dirname(cls, path: str) -> str:
+        """Return the directory portion of ``path``, split on ``/``."""
+        return posixpath.dirname(path)
+
+    @classmethod
+    def basename(cls, path: str) -> str:
+        """Return the final component of ``path``, split on ``/``."""
+        return posixpath.basename(path)
+
+    def dealS3Url(self, path: str) -> Tuple[str, str]:
+        """Split an S3 URL into its bucket name and object key.
+
+        Args:
+            path: URL of the form ``s3://bucket_name/path``.
+
+        Returns:
+            ``(bucket_name, key)``; the key has no leading slash.
+
+        Raises:
+            ValueError: if ``path`` is not an ``s3://`` URL.
+        """
+        # The URL is split by hand because urlparse cuts the path at '?' and
+        # '#', both of which may legitimately occur in object keys.
+        scheme, separator, remainder = path.partition('://')
+        if not separator or scheme.lower() != 's3':
+            raise ValueError("URL必须是S3 URL,格式为s3://bucket_name/path")
+        bucket_name, _, key = remainder.partition('/')
+        return bucket_name, key.lstrip('/')