1.增加了对象存储的支持,并兼容读取本地文件

This commit is contained in:
cllcode 2024-07-07 18:00:37 +08:00
parent bc90597e54
commit 4086cc53c8
4 changed files with 240 additions and 0 deletions

View File

@ -0,0 +1,26 @@
from typing import Protocol, IO
# 基类
class Attachment(Protocol):
def exists(self, path) -> bool:
pass
def makedirs(self, path) -> bool:
pass
def open(self, path, param) -> IO:
pass
@classmethod
def join(cls, __a: str, *paths: str) -> str:
pass
@classmethod
def dirname(cls, path: str) -> str:
pass
@classmethod
def basename(cls, path: str) -> str:
pass

View File

@ -0,0 +1,72 @@
import os
from datetime import datetime
from typing import AnyStr, BinaryIO, Callable, Union, IO
from flask import send_file, Response
from pywxdump.file.AttachmentAbstract import Attachment
from pywxdump.file.LocalAttachment import LocalAttachment
from pywxdump.file.S3Attachment import S3Attachment
def determine_strategy(file_path: str) -> Attachment:
if file_path.startswith("s3://"):
return S3Attachment()
else:
return LocalAttachment()
def exists(path: str) -> bool:
return determine_strategy(path).exists(path)
def open_file(path: str, mode: str) -> IO:
return determine_strategy(path).open(path, mode)
def makedirs(path: str) -> bool:
return determine_strategy(path).makedirs(path)
def join(__a: str, *paths: str) -> str:
return determine_strategy(__a).join(__a, *paths)
def dirname(path: str) -> str:
return determine_strategy(path).dirname(path)
def basename(path: str) -> str:
return determine_strategy(path).basename(path)
def send_attachment(
path_or_file: Union[os.PathLike[AnyStr], str],
mimetype: Union[str, None] = None,
as_attachment: bool = False,
download_name: Union[str, None] = None,
conditional: bool = True,
etag: Union[bool, str] = True,
last_modified: Union[datetime, int, float, None] = None,
max_age: Union[None, int, Callable[[Union[str, None]], Union[int, None]]] = None,
) -> Response:
file_io = open_file(path_or_file, "rb")
# 如果没有提供 download_name 或 mimetype则从 path_or_file 中获取文件名和 MIME 类型
if download_name is None:
download_name = basename(path_or_file)
if mimetype is None:
mimetype = 'application/octet-stream'
return send_file(file_io, mimetype, as_attachment, download_name, conditional, etag, last_modified, max_age)
def download_file(db_path, local_path):
with open(local_path, 'wb') as f:
with open_file(db_path, 'rb') as r:
f.write(r.read())
return local_path
def isLocalPath(path: str) -> bool:
return isinstance(determine_strategy(path), LocalAttachment)

View File

@ -0,0 +1,46 @@
# 本地文件处理类
import os
import sys
from typing import IO
class LocalAttachment:
def open(self, path, mode) -> IO:
path = self.dealLocalPath(path)
return open(path, mode)
def exists(self, path) -> bool:
path = self.dealLocalPath(path)
return os.path.exists(path)
def makedirs(self, path) -> bool:
path = self.dealLocalPath(path)
os.makedirs(path)
return True
@classmethod
def join(cls, __a: str, *paths: str) -> str:
return os.path.join(__a, *paths)
@classmethod
def dirname(cls, path: str) -> str:
return os.path.dirname(path)
@classmethod
def basename(cls, path: str) -> str:
return os.path.basename(path)
def dealLocalPath(self, path: str) -> str:
# 获取当前系统的地址分隔符
# 将path中的 /替换为当前系统的分隔符
path = path.replace('/', os.sep)
if sys.platform == "win32":
# 如果是windows系统且路径长度超过260个字符
if len(path) >= 260:
# 添加前缀
return '\\\\?\\' + path
else:
return path
else:
return path

View File

@ -0,0 +1,96 @@
# 对象存储文件处理类(示例:假设是 AWS S3
import os
from typing import IO
from urllib.parse import urlparse
from botocore.exceptions import ClientError
from smart_open import open
import boto3
from botocore.client import Config
class S3Attachment:
def __init__(self):
# 腾讯云 COS 配置
self.cos_endpoint = "https://cos.<your-region>.myqcloud.com" # 替换 <your-region> 为你的 COS 区域,例如 ap-shanghai
self.access_key_id = "SecretId" # 替换为你的腾讯云 SecretId
self.secret_access_key = "SecretKey" # 替换为你的腾讯云 SecretKey
# 创建 S3 客户端
self.s3_client = boto3.client(
's3',
endpoint_url=self.cos_endpoint,
aws_access_key_id=self.access_key_id,
aws_secret_access_key=self.secret_access_key,
config=Config(s3={"addressing_style": "virtual", "signature_version": 's3v4'})
)
def exists(self, path) -> bool:
bucket_name, path = self.dealS3Url(path)
# 检查是否为目录
if path.endswith('/'):
# 尝试列出该路径下的对象
try:
response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1)
if 'Contents' in response:
return True
else:
return False
except ClientError as e:
print(f"Error: {e}")
return False
else:
# 检查是否为文件
try:
self.s3_client.head_object(Bucket=bucket_name, Key=path)
return True
except ClientError as e:
if e.response['Error']['Code'] == '404':
return False
else:
print(f"Error: {e}")
return False
def makedirs(self, path) -> bool:
if not self.exists(path):
bucket_name, path = self.dealS3Url(path)
self.s3_client.put_object(Bucket=bucket_name, Key=f'{path}/')
return True
def open(self, path, mode) -> IO:
self.dealS3Url(path)
return open(uri=path, mode=mode, transport_params={'client': self.s3_client})
@classmethod
def join(cls, __a: str, *paths: str) -> str:
return os.path.join(__a, *paths)
@classmethod
def dirname(cls, path: str) -> str:
return os.path.dirname(path)
@classmethod
def basename(cls, path: str) -> str:
return os.path.basename(path)
def dealS3Url(self, path: str) -> object:
"""
解析 S3 URL 并返回存储桶名称和路径
参数:
path (str): S3 URL
返回:
tuple: 包含存储桶名称和路径的元组
"""
parsed_url = urlparse(path)
# 确保URL是S3 URL
if parsed_url.scheme != 's3':
raise ValueError("URL必须是S3 URL格式为s3://bucket_name/path")
bucket_name = parsed_url.netloc
s3_path = parsed_url.path.lstrip('/')
return bucket_name, s3_path