You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

542 lines
19 KiB

import os
import json
import time
import shutil
import zipfile
from urllib.parse import urlparse
from uuid import uuid4
from PySide2.QtCore import QMutex
from typing import Dict
from .bottle import request, static_file, HTTPError
from .ocr_server import get_ocr_options
from ..ocr.output import Output
from ..mission.mission_doc import MissionDOC
from ..utils.utils import initConfigDict, DocSuf
from ..ocr.output.tools import getDataText
from call_func import CallFunc
UPLOAD_DIR = "./temp_doc"  # Temporary directory for uploaded files
TEMP_FILE_RETENTION_DURATION = 24  # How long a task's temporary files are kept, in hours
TEMP_FILE_CLEANUP_INTERVAL = 0.5  # Interval between automatic cleanups, in hours
# Build the option template dictionary
def get_doc_options():
    """Return the option template dictionary for document tasks.

    The unformatted OCR options from ``get_ocr_options(is_format=False)``
    are extended with the document-specific entries defined below, and the
    combined dictionary is passed through ``initConfigDict`` before being
    returned.
    """

    def _int_entry(title, tool_tip, default):
        # Every integer option shares the same four-key layout.
        return {
            "title": title,
            "toolTip": tool_tip,
            "default": default,
            "isInt": True,
        }

    template = get_ocr_options(is_format=False)

    # Document-specific entries, in the order they are added to the template.
    doc_entries = {
        "tbpu.ignoreRangeStart": _int_entry(
            "忽略区域起始",
            "忽略区域生效的页数范围起始。从1开始。",
            1,
        ),
        "tbpu.ignoreRangeEnd": _int_entry(
            "忽略区域结束",
            "忽略区域生效的页数范围结束。可以用负数表示倒数第X页。",
            -1,
        ),
        "pageRangeStart": _int_entry(
            "OCR页数起始",
            "OCR的页数范围起始。从1开始。",
            1,
        ),
        "pageRangeEnd": _int_entry(
            "OCR页数结束",
            "OCR的页数范围结束。可以用负数表示倒数第X页。",
            -1,
        ),
        "pageList": {
            "title": "OCR页数列表",
            "toolTip": "数组,可指定单个或多个页数。例:[1,2,5]表示对第1、2、5页进行OCR。如果与页数范围同时填写,则 pageList 优先。",
            "default": [],
            "type": "var",
        },
        "password": {
            "title": "密码",
            "toolTip": "如果文档已加密,则填写文档密码。",
            "default": "",
        },
        "doc.extractionMode": {
            "title": "内容提取模式",
            "toolTip": "若一页文档既存在图片又存在文本,如何进行处理。",
            "optionsList": [
                ["mixed", "混合OCR/原文本"],
                ["fullPage", "整页强制OCR"],
                ["imageOnly", "仅OCR图片"],
                ["textOnly", "仅拷贝原有文本"],
            ],
        },
    }
    for key, entry in doc_entries.items():
        template[key] = entry

    # Format the combined template
    return initConfigDict(template)
UPLOAD_DIR = os.path.abspath(UPLOAD_DIR)  # Convert the path to an absolute path
TEMP_FILE_RETENTION_DURATION *= 3600  # Convert hours to seconds
TEMP_FILE_CLEANUP_INTERVAL *= 3600  # Convert hours to seconds
# Exception type
class DocUnitError(Exception):
    """Error raised while building a document task unit.

    ``data`` carries the payload given by the raiser; in this module it is a
    ``{"code": ..., "data": ...}`` dictionary that the upload route returns
    to the client as-is.
    """

    def __init__(self, data):
        super().__init__(data)
        self.data = data
# 单个任务单元
class _DocUnit:
def __init__(
self, dir_id, dir_path, origin_path, origin_name, origin_prefix, options
):
# 提取文档信息
doc_info = MissionDOC.getDocInfo(origin_path)
if "error" in doc_info.keys():
raise DocUnitError({"code": 201, "data": doc_info["error"]})
# 补充缺失的默认参数
default = get_doc_options()
for key in default:
if key not in options:
options[key] = default[key]["default"]
# 提取参数
page_range = [options["pageRangeStart"], options["pageRangeEnd"]] # 识别范围
page_list = options["pageList"] # 页数列表
if page_list: # 下标起始由1转为0
page_list = [x - 1 for x in page_list]
password = options["password"] # 密码
if not password and doc_info["is_encrypted"]:
raise DocUnitError(
{
"code": 202,
"data": "The doc is encrypted, please fill in the password.",
}
)
# 从 options 中提取一些条目,组装 docArgd 作为 MissionDoc 任务参数字典
prefixes = ["ocr.", "doc.", "tbpu."] # 要提取的条目前缀
doc_argd = {}
for k, v in options.items():
for prefix in prefixes:
if k.startswith(prefix):
doc_argd[k] = v
break
# 任务信息
msnInfo = {
"onStart": self._onStart,
"onGet": self._onGet,
"onEnd": self._onEnd,
"argd": doc_argd,
}
# 提交任务
self.msnID = ""
msg = MissionDOC.addMission(
msnInfo, origin_path, page_range, page_list, password
)
if not msg:
raise DocUnitError({"code": 203, "data": "addMission unknow."})
if msg.startswith("["):
raise DocUnitError({"code": 204, "data": msg})
page_list = msnInfo["pageList"]
self.password = password
self.dir_id = dir_id
self.dir_path = dir_path
self.origin_prefix = origin_prefix
self.origin_name = origin_name
self.origin_path = origin_path
self.msnID = msg # 任务ID
self.results = {} # 任务结果原始字典,键为页数
self.pages_count = len(page_list) # 任务总页数
self.processed_count = 0 # 已处理的页数
self.unread_list = [] # 未读的任务列表
self.is_done = False # 当前任务是否完成
self.state = "waiting" # 任务状态, waiting running success failure
self.message = "" # 如果任务失败,则记录失败信息
self.start_timestamp = time.time() # 开始时间戳
self.end_timestamp = time.time() # 任务结束的时间戳
self._mutex = QMutex() # 主锁
# ========================= 【接口】 =========================
# 获取结果
def get_result(
self,
is_data=False, # True 时返回识别内容data
format="dict", # 识别内容格式, "dict", "text"
is_unread=False, # True 时只返回未读过的识别内容
):
self._mutex.lock()
data = {
"code": 100,
"processed_count": self.processed_count, # 已处理的数量
"pages_count": self.pages_count, # 总页数
"is_done": self.is_done, # 是否已结束
"state": self.state, # 任务状态
"data": [], # 结果
}
if self.state == "failure":
data["message"] = self.message
# 需要返回识别内容
if is_data:
datas = []
# 增量式
if is_unread:
for page in self.unread_list:
datas.append(self.results[page])
self.unread_list = []
# 全量式
else:
for _, res in self.results.items():
datas.append(res)
# 需要转为纯文本
if format == "text":
datas_text = ""
for res in datas:
if res["code"] == 100:
datas_text += getDataText(res["data"])
datas = datas_text
data["data"] = datas
self._mutex.unlock()
return data
# 获取文件
def get_files(
self,
base_url, # 下载基础url
file_types=["pdfLayered"], # 输出文件类型,可选:
# txt, txtPlain, jsonl, csv, pdfLayered, pdfOneLayer
ingore_blank=True, # 忽略空白页数
):
if not self.is_done:
return {"code": 201, "data": f"{self.msnID} 任务尚未结束,无法获取文件"}
if not self.state == "success":
return {"code": 201, "data": f"{self.msnID} 任务处理失败,无法获取文件"}
if not isinstance(file_types, list) or not isinstance(ingore_blank, bool):
return {
"code": 202,
"data": f"参数类型错误: file_types={file_types} , ingore_blank={ingore_blank}",
}
# 删除旧的文件
for filename in os.listdir(self.dir_path):
file_path = os.path.join(self.dir_path, filename)
if filename != self.origin_name and os.path.isfile(file_path):
os.remove(file_path)
# 准备参数
startDatetime = time.strftime( # 日期时间字符串(标准格式)
r"%Y-%m-%d %H:%M:%S", time.localtime(self.start_timestamp)
)
outputArgd = {
"outputDir": self.dir_path, # 输出路径
"outputDirType": "specify",
"outputFileName": "[OCR]_" + self.origin_prefix, # 输出文件名(前缀)
"startDatetime": startDatetime, # 开始日期
"ingoreBlank": ingore_blank, # 忽略空白页数
"originPath": self.origin_path, # 原始文件
"password": self.password, # 文档密码
}
# 创建输出器
output = []
try:
for f in file_types:
output.append(Output[f](outputArgd))
except Exception as e:
return {"code": 203, "data": f"初始化输出器失败。{e}"}
# 输出
for o in output:
for _, res in self.results.items():
try:
o.print(res)
except Exception as e:
return {"code": 204, "data": f"输出失败:{o}\n{e}"}
try:
o.onEnd() # 保存
except Exception as e:
return {"code": 205, "data": f"保存失败:{o}\n{e}"}
# 收集新的文件
download_paths = []
for filename in os.listdir(self.dir_path):
file_path = os.path.join(self.dir_path, filename)
if filename != self.origin_name and os.path.isfile(file_path):
download_paths.append(file_path)
# 如果文件多,则打包zip
if not download_paths:
return {"code": 206, "data": "未找到生成的文件"}
elif len(download_paths) == 1:
download_name = os.path.basename(download_paths[0])
else:
download_name = f"[OCR]_{self.origin_prefix}.zip"
zip_path = os.path.join(self.dir_path, download_name)
# 将 download_list 中的所有文件打包为 zip
try:
with zipfile.ZipFile(zip_path, "w") as zipf:
for p in download_paths:
zipf.write(p, os.path.basename(p))
except Exception as e:
return {"code": 207, "data": f"无法打包zip:{e}"}
# 组合下载地址
url = f"{base_url}/api/doc/download/{self.dir_id}/{download_name}"
return {"code": 100, "data": url, "name": download_name}
# 清理任务
def clear(self):
# 停止任务
if not self.is_done:
MissionDOC.stopMissionList([self.msnID])
time.sleep(0.1) # 给一些时间收尾
# 删除目录
if os.path.exists(self.dir_path):
shutil.rmtree(self.dir_path)
# ========================= 【任务控制器的异步回调】 =========================
def _onStart(self, msnInfo): # 一个文档 开始
self.state = "running"
def _onGet(self, msnInfo, page, res): # 一个文档的一页 获取结果
page += 1
res["page"] = page
res["path"] = f"{self.origin_name} - {page}"
res["fileName"] = f"{self.origin_name} - {page}"
# 记录信息
self._mutex.lock()
self.results[page] = res
self.processed_count += 1
self.unread_list.append(page)
self._mutex.unlock()
def _onEnd(self, msnInfo, msg): # 一个文档处理完毕
# msg: [Success] [Warning] [Error]
# 记录信息
self._mutex.lock()
self.is_done = True
if msg == "[Success]":
self.state = "success"
else:
self.state = "failure"
self.message = msg
self.end_timestamp = time.time() # 刷新结束时间戳
self._mutex.unlock()
# 管理所有任务单元
class _DocUnitManagerClass:
def __init__(self):
self.doc_units: Dict[str, _DocUnit] = {}
# 添加一个任务单元
def add(self, id: str, unit: _DocUnit):
self.doc_units[id] = unit
# 获取一个任务单元
def get(self, id: str):
if id not in self.doc_units:
return None
return self.doc_units[id]
# 手动清理一个任务
def clear(self, id: str):
if id in self.doc_units:
self.doc_units[id].clear()
del self.doc_units[id]
return True
return False
# 自动清理
def auto_clear(self):
# 清理超时的任务和文件
if self.doc_units:
now = time.time() # 当前时间戳
del_list = [] # 要清理的id
for id, unit in self.doc_units.items():
if now - unit.end_timestamp > TEMP_FILE_RETENTION_DURATION:
print(f"超时自动清理 {id}")
unit.clear() # 清理文件
del_list.append(id)
for id in del_list:
del self.doc_units[id] # 清理任务对象
# 计划下一次清理
CallFunc.delay(self.auto_clear, TEMP_FILE_CLEANUP_INTERVAL)
# Module-level manager instance used by the route handlers below
_DocUnitManager = _DocUnitManagerClass()
# Route registration
def init(UmiWeb):
    """Reset the upload directory, start the cleanup loop and register routes.

    Registers the ``/api/doc/...`` handlers on ``UmiWeb``. Existing content
    of ``UPLOAD_DIR`` is deleted when this function is called.
    """
    # Empty the upload directory
    if os.path.exists(UPLOAD_DIR):
        shutil.rmtree(UPLOAD_DIR)
    os.makedirs(UPLOAD_DIR)
    # Start the automatic cleanup loop
    _DocUnitManager.auto_clear()

    # Security check: does `path` lie inside UPLOAD_DIR?
    def _is_in_upload_dir(path):
        # The path is normalised first: commonpath() compares path components
        # literally, so an un-normalised ".." segment would pass the check.
        try:
            common = os.path.commonpath([UPLOAD_DIR, os.path.abspath(path)])
        except ValueError:  # e.g. paths on different drives
            return False
        return common == os.path.commonpath([UPLOAD_DIR])

    # Parse the JSON request body and look up the task unit it refers to.
    # Returns (user_data, doc_unit, None) on success, (None, None, error) otherwise.
    def _unit_from_request():
        try:
            user_data = request.json
        except Exception:
            return None, None, {"code": 101, "data": "请求无法解析为json。"}
        if not isinstance(user_data, dict) or "id" not in user_data:
            return None, None, {"code": 102, "data": "未填写id。"}
        msnID = user_data["id"]
        # Ids are strings; anything else cannot name a registered task
        doc_unit = _DocUnitManager.get(msnID) if isinstance(msnID, str) else None
        if not doc_unit:
            return None, None, {"code": 103, "data": f"任务 {msnID} 不存在。"}
        return user_data, doc_unit, None

    @UmiWeb.route("/api/doc/get_options")
    def _get_options_json():
        opts = get_doc_options()
        res = json.dumps(opts)
        return res

    # Upload a document. Method: POST
    # Parameters: the document content (form file "file"), optional form
    #     field "json" holding the task options
    # Returns:
    #     success: {"code": 100, "data": "task id"}
    #     failure: {"code": a value other than 100, "data": "failure reason"}
    @UmiWeb.route("/api/doc/upload", method="POST")
    def _upload():
        # 1. Get the uploaded file
        upload = request.files.get("file")
        if not upload:
            return {"code": 101, "data": "[Error] No file was uploaded."}
        # 2. Check the file extension
        origin_name = upload.filename
        origin_prefix, ext = os.path.splitext(origin_name)
        ext = ext.lower()
        if ext not in DocSuf:
            return {
                "code": 102,
                "data": f"[Error] File extension '{ext}' is not allowed.",
            }
        # 3. Pick a directory id, create the directory and save the file
        #    to UPLOAD_DIR/dir_id/original file name
        dir_id = str(uuid4())
        dir_path = os.path.abspath(os.path.join(UPLOAD_DIR, dir_id))
        file_path = os.path.abspath(os.path.join(dir_path, origin_name))
        # Security check: is file_path inside UPLOAD_DIR?
        if not _is_in_upload_dir(file_path):
            return {"code": 103, "data": "[Error] Unauthorized path"}
        try:
            if os.path.exists(dir_path):  # Remove the directory if it exists
                shutil.rmtree(dir_path)
            os.makedirs(dir_path)  # Create it again
        except Exception as e:
            return {"code": 104, "data": f"[Error] Failed to create dir_id: {e}"}
        try:
            upload.save(file_path, overwrite=True)  # Save the file
        except Exception as e:
            # Do not leave the freshly created directory behind
            shutil.rmtree(dir_path, ignore_errors=True)
            return {"code": 105, "data": f"[Error] Failed to save file: {e}"}
        # 4. Extract the options
        options = request.forms.get("json")
        if options:
            try:
                options = json.loads(options)
            except Exception as e:
                shutil.rmtree(dir_path, ignore_errors=True)
                return {
                    "code": 106,
                    "data": f"[Error] Invalid JSON format: {options} | {e}",
                }
        if not isinstance(options, dict):
            options = {}
        # 5. Build the task object
        try:
            doc_unit = _DocUnit(
                dir_id, dir_path, file_path, origin_name, origin_prefix, options
            )
            msnID = doc_unit.msnID
            _DocUnitManager.add(msnID, doc_unit)
            return {"code": 100, "data": msnID}
        except DocUnitError as e:
            # Cleanup is best effort so the original error is what gets returned
            shutil.rmtree(dir_path, ignore_errors=True)
            return e.data
        except Exception as e:
            shutil.rmtree(dir_path, ignore_errors=True)
            return {"code": 107, "data": f"[Error] Failed to submit mission: {e}"}

    # Get results. Method: POST
    # JSON parameters:
    #     "id"="",            # Task ID
    #     "is_data"=False,    # True: include the recognised content
    #     "format"="dict",    # Content format: "dict" or "text"
    #     "is_unread"=False,  # True: only content that has not been read yet
    # Returns: {}
    @UmiWeb.route("/api/doc/result", method="POST")
    def _result():
        user_data, doc_unit, error = _unit_from_request()
        if error:
            return error
        is_data = user_data.get("is_data", False)
        format = user_data.get("format", "dict")
        is_unread = user_data.get("is_unread", False)
        return doc_unit.get_result(is_data, format, is_unread)

    # Get files. Method: POST
    # JSON parameters:
    #     "id"="",                      # Task ID
    #     "file_types"=["pdfLayered"],  # Output file types, options:
    #         ["txt", "txtPlain", "jsonl", "csv", "pdfLayered", "pdfOneLayer"]
    #     "ingore_blank"=True,          # Ignore blank pages
    # Returns: {}
    @UmiWeb.route("/api/doc/download", method="POST")
    def _download_build():
        user_data, doc_unit, error = _unit_from_request()
        if error:
            return error
        file_types = user_data.get("file_types", ["pdfLayered"])
        ingore_blank = user_data.get("ingore_blank", True)
        parsed_url = urlparse(request.url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
        return doc_unit.get_files(base_url, file_types, ingore_blank)

    # Download a file
    @UmiWeb.route("/api/doc/download/<id>/<download_name>")
    def _download_get(id, download_name):
        dir_path = os.path.abspath(os.path.join(UPLOAD_DIR, id))
        path = os.path.join(dir_path, download_name)
        # Security check: is path inside UPLOAD_DIR?
        if not _is_in_upload_dir(path):
            # 403 Forbidden; 103 is an informational status, not an error
            raise HTTPError(403, "[Error] Unauthorized path.")
        return static_file(download_name, root=dir_path)

    # Clean up a task
    @UmiWeb.route("/api/doc/clear/<id>")
    def _clear(id):
        flag = _DocUnitManager.clear(id)
        if flag:
            return {"code": 100, "data": "Success"}
        return {"code": 101, "data": f"{id} does not exist."}