大家好,欢迎来到IT知识分享网。
目录
一、Surya模型检测使用Python接口中的源码详解
使用surya源码进行模型检测的过程中, 模型的各种参数设置、环境变量配置都写在下载的源码下 /surya/settings.py 文件中,修改其中的参数即可实现全局的配置。
1.选择模型检测GPU
修改 settings.py 文件中的 TORCH_DEVICE 参数,默认为 None 时,运行监测代码会自动检查当前设备,并选择索引顺序最小的GPU运行——‘cuda:0’。在实际部署中,如果服务器中第一块GPU——‘cuda:0’有其他模型在跑,可能需要调整模型预测位置,将模型放到另一块GPU上运行。只需修改以下参数代码。
# 指定模型所在GPU TORCH_DEVICE: Optional[str] = 'cuda:1'
2.配置加载模型参数
在 settings.py 文件中模型参数默认为在线加载,当服务器无法连接外部网络时,离线部署加载模型参数需调整设置中的地址。需根据自己模型存放位置,修改为下面的参数字符串。
地址需改为你存放模型的绝对地址 # 文本行检测模型 DETECTOR_MODEL_CHECKPOINT: str = "//Surya-OCR/hugging_model/surya_det2" # 文本区域检测模型 LAYOUT_MODEL_CHECKPOINT: str = "//Surya-OCR/hugging_model/surya_layout"
默认在线加载模型参数位置:
参数修改内容:(为你存放离线下载模型地址)
测试模型是否加载成功的代码如下:
from surya.model.detection.segformer import load_model, load_processor from surya.settings import settings # 行检测模型:surya_det_2 det_model = load_model() det_processor = load_processor() print('det2_model load success') # 区域检测:surya_layout model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) print('layout_model load success')
修改后,检验模型加载情况:
3.批量检测图片
实际部署中,官方文档提供 Python 接口检测的代码对单个图片检测顺利,但在批量检测图片集——文件夹时报错。 (单图代码写在上一篇博文《从零开始使用Surya-OCR——文本目标检测模型的安装与部署》中,具体可参考从零开始使用Surya-OCR——项目源码拆解)
第一个报错是,官方提供接口函数,无法读取文件夹内图片,报读取文件权限被拒。暂未实现直接解决该问题的办法。参看后续 batch_text_detection 源代码传参信息,得知图片读取后的传入函数结果是一个列表,可以选择替代方案实现同等效果,替代方案代码如下。
import os from PIL import Image from surya.detection import batch_text_detection from surya.model.detection.segformer import load_model, load_processor IMAGE_PATH = 'image_path' model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) print('model load success') # 批量将文件夹图片读入images列表中 images = [] for file in os.listdir(IMAGE_PATH): image_path = os.path.join(IMAGE_PATH, file) image = Image.open(image_path) images.append(image) predictions = batch_text_detection(images, model, processor) print(predictions)
使用新代码后,原本的报错问题解决了,但出现了新的报错。
第二个报错是,检查surya模型做批量预测任务时,在得到模型输出后,还会对多张图片的结果进行多进程处理。问题就在多进程处理时,源代码内置函数重复调用主函数——导致模型本来只需加载一次,此时不断加载模型致使GPU显存爆了。我们找到对应报错的源码处检查。
按住 Ctrl 并点击报错的 batch_text_detection 函数,即可进入源码处,检查研究发现可能是Windows系统和 Linux 系统对于多进程和多线程的解释存在差异,本机是windows系统, 可能在导入 ProcessPoolExecutor 函数时,将我的主函数视为多进程对象,不断创建新的进程。但是此处希望实现的是中间过程结果的多线程处理,而不影响主函数。因此需要将 ProcessPoolExecutor 改为使用多线程的 ThreadPoolExecutor,问题即可解决。
经过上述所有源代码修改处理,成功运行主函数,得到surya模型批量检测图片后得到的框信息结果,并将其打印出来。
4.检测输出结果源码解读
Surya模型输出是自定义类的数据格式,下面根据其官方文档和项目源码解读其输出的格式,以方便后续对输出的处理,提取出所需的数据信息。
官方文档:https://github.com/VikParuchuri/surya
Surya 模型有三种预测模式——OCR & Text Line & Layout,对应三种模型输出的格式,每种模式的输出都是以类的形式定义的。下面重点放在 Text Line 文本行检测和 Layout 区域检测的源码信息解读上。
①文本行检测的模型输出——Text Line
将与输出相关的源代码从项目中单独提取出来看,下面是输出的基础类,即每个图片模型预测后的信息都封装在了 TextDetectionResult 里面。
"文本行检测" # 输出的基础类 class TextDetectionResult(BaseModel): bboxes: List[PolygonBox] vertical_lines: List[ColumnLine] horizontal_lines: List[ColumnLine] heatmap: Any affinity_map: Any image_bbox: List[float]
下面分别解释输出基础类中的具体信息都是怎么定义的,从源码中提出相关代码。
输出的第一个类信息:PolygonBox 注解
(下面非完整代码,为清晰类输出含义,只选取主要功能代码)
# 输出框信息类 class PolygonBox(BaseModel): polygon: List[List[float]] 存储框四个角——全坐标 confidence: Optional[float] = None 框预测置信度 def bbox(self) -> List[float]: box = [self.polygon[0][0], self.polygon[0][1], self.polygon[1][0], self.polygon[2][1]] if box[0] > box[2]: box[0], box[2] = box[2], box[0] if box[1] > box[3]: box[1], box[3] = box[3], box[1] return box 存储框左上右下——对角坐标
通过源码可知,此类保存的是检测框的坐标信息和置信度,这是预测中的主要信息。使用具体的模型输出结果,可以更清楚该类的输出形状。输出TextDetectionResult 包含多个类信息,其中定义的 bboxes ——即 PolygonBox 存储的是一张图片内检测出来的所有框,而每个框的信息结构是包含三个子类:全坐标、置信度和对角坐标。
输出的第二个类信息:ColumnLine 注解
(下面非完整代码,为清晰类输出含义,只选取主要功能代码)
# 图片中线的检测 class ColumnLine(Bbox): vertical: bool # 垂直线:有-True;无-False horizontal: bool # 水平线:有-True;无-False # 检测线的框坐标 class Bbox(BaseModel): bbox: List[float]
同样通过源码可知,输出的第二个类信息 ColumnLine 是用来保存在模型预测中对图片检测到的水平线、垂直线的信息,如果检测到了那么就同时保存其框位置。通过具体模型输出对应位置检索,可以清晰理解。
输出的剩余类信息:heatmap、affinity_map、image_bbox
剩余的类信息是对图片结果输出的补充,我们可以直接将其打印出来,看看其内容。最容易看出来的是 image_bbox ,其实际就是用一个框将整个图片框起来,然后返回对角坐标。而heatmap、affinity_map 则是 PIL.Image.Image 的类信息。
②文本区域检测的模型输出 ——Layout
同文本行检测一样,废话少说,直接上代码和图。
# 区域检测输出 class LayoutResult(BaseModel): bboxes: List[LayoutBox] segmentation_map: Any image_bbox: List[float]、 class LayoutBox(PolygonBox): label: str 多了一个区域类别的预测
5.批量信息的保存和可视化
将surya模型自定义的类输出转化为 json 格式保存到指定文件夹,代码如下。
import os import json from PIL import Image from surya.detection import batch_text_detection from surya.model.detection.segformer import load_model, load_processor import cv2 import numpy as np IMAGE_PATH = 'iamge_path' 检测图片保存地址 json_file = 'json_path' 框json保存地址 checkpoint = 'model_path' 模型参数加载地址 heat_file = 'heat_path' 热图保存 上述为修改部分,根据实际地址填入,下面无需修改 model, processor = load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint) print('model load success') # 模型预测 images = [] image_name = [] for file in os.listdir(IMAGE_PATH): image_path = os.path.join(IMAGE_PATH, file) image = Image.open(image_path) images.append(image) image_name.append(file) predictions = batch_text_detection(images, model, processor) print('predict success') # 保存模型结果 类型转为json def class_to_json(bboxes, file, box_type=True): json_list = [] for i, bbox in enumerate(bboxes): if box_type: json_dict = dict() box = bbox.bbox box.append(bbox.confidence) json_dict["id"] = i json_dict["name"] = file json_dict["box"] = box json_list.append(json_dict) else: json_dict = dict() box = bbox.bbox json_dict["id"] = i json_dict["name"] = file json_dict["box"] = box json_list.append(json_dict) return json_list 保存到指定文件夹 def save_json(json_list, json_path): with open(json_path, 'w') as f: json.dump(json_list, f) 主函数 def save_predict(predictions, image_name, heat_file): for i, pred in enumerate(predictions): # 框信息保存 bboxes = pred.bboxes vertical = pred.vertical_lines horizontal = pred.horizontal_lines file = image_name[i] bboxes_json = class_to_json(bboxes, file) vertical_json = class_to_json(vertical, file, box_type=False) horizontal_json = class_to_json(horizontal, file, box_type=False) basename = file.split('.')[0] save_json(bboxes_json, os.path.join(json_file+'box/', basename + '.json')) save_json(vertical_json, os.path.join(json_file + 'vertical/', basename + '.json')) save_json(horizontal_json, os.path.join(json_file + 'horizontal/', basename + '.json')) # 热图调参信息保存 heatmap = pred.heatmap img = cv2.cvtColor(np.asarray(heatmap), cv2.COLOR_RGB2BGR) cv2.imwrite(heat_file+basename+'.jpg', img) print(basename + ' success') if __name__ == '__main__': save_predict(predictions, image_name, heat_file)
可视化框的代码如下。
import os import json import cv2 # jpg、json、vis文件位置 jpg_path = 'JPG' json_path = 'JSON' vis_path = 'VIS' 上述为修改部分,根据实际地址填入,下面无需修改 # 可视化锚框 锚框展示细节 def hsv2bgr(h, s, v): h_i = int(h * 6) f = h * 6 - h_i p = v * (1 - s) q = v * (1 - f * s) t = v * (1 - (1 - f) * s) r, g, b = 0, 0, 0 if h_i == 0: r, g, b = v, t, p elif h_i == 1: r, g, b = q, v, p elif h_i == 2: r, g, b = p, v, t elif h_i == 3: r, g, b = p, q, v elif h_i == 4: r, g, b = t, p, v elif h_i == 5: r, g, b = v, p, q return int(b * 255), int(g * 255), int(r * 255) def random_color(id): h_plane = (((id << 2) ^ 0x) % 100) / 100.0 s_plane = (((id << 3) ^ 0x) % 100) / 100.0 return hsv2bgr(h_plane, s_plane, 1) # 可视化主函数 def visualize(json_path, jpg_path, vis_path, box_type=True): if box_type: for file in os.listdir(json_path): with open(json_path+file,'r') as f: drawResult = json.load(f) basefile = file.split('.')[0] jpg_file = os.path.join(jpg_path,basefile+".jpg") img = cv2.imread(jpg_file) for idx, result in enumerate(drawResult): left, top, right, bottom = int(result['box'][0]), int(result['box'][1]), int(result['box'][2]), int(result['box'][3]) label = int(result['box'][4]) color = random_color(1) cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA) caption = f"{'ZW'}" w, h = cv2.getTextSize(caption, 0, 1, 2)[0] cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1) cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16) save_file = os.path.join(vis_path, basefile+".jpg") print(save_file) cv2.imwrite(save_file, img) else: for file in os.listdir(json_path): with open(json_path+file,'r') as f: drawResult = json.load(f) basefile = file.split('.')[0] jpg_file = os.path.join(jpg_path,basefile+".jpg") img = cv2.imread(jpg_file) for idx, result in enumerate(drawResult): left, top, right, bottom = int(result['box'][0]), int(result['box'][1]), int(result['box'][2]), int(result['box'][3]) color = random_color(1) cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA) caption = f"{'ZW'}" w, h = cv2.getTextSize(caption, 0, 1, 2)[0] cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1) cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16) save_file = os.path.join(vis_path, basefile+".jpg") print(save_file) cv2.imwrite(save_file, img) if __name__ == '__main__': visualize(json_path+'box/', jpg_path, vis_path+'box/') visualize(json_path + 'vertical/', jpg_path, vis_path + 'vertical/', box_type=False) visualize(json_path + 'horizontal/', jpg_path, vis_path + 'horizontal/', box_type=False)
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/131940.html