diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..273404c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort diff --git a/pyproject.toml b/pyproject.toml index 85543d8..535c215 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ classifiers = [ "Homepage" = "https://github.com/283375/arcaea-offline-ocr" "Bug Tracker" = "https://github.com/283375/arcaea-offline-ocr/issues" +[tool.black] +extend-exclude = 'src/arcaea_offline_ocr/_builtin_templates.py' + [tool.isort] profile = "black" src_paths = ["src/arcaea_offline_ocr"] diff --git a/src/arcaea_offline_ocr/__init__.py b/src/arcaea_offline_ocr/__init__.py index e69de29..074dcb1 100644 --- a/src/arcaea_offline_ocr/__init__.py +++ b/src/arcaea_offline_ocr/__init__.py @@ -0,0 +1,6 @@ +from .crop import * +from .device import * +from .mask import * +from .ocr import * +from .recognize import * +from .template import * diff --git a/src/arcaea_offline_ocr/crop.py b/src/arcaea_offline_ocr/crop.py index 0a3aedd..6ea48f5 100644 --- a/src/arcaea_offline_ocr/crop.py +++ b/src/arcaea_offline_ocr/crop.py @@ -4,6 +4,18 @@ from cv2 import Mat from .device import Device +__all__ = [ + "crop_img", + "crop_from_device_attr", + "crop_to_pure", + "crop_to_far", + "crop_to_lost", + "crop_to_max_recall", + "crop_to_rating_class", + "crop_to_score", + "crop_to_title", +] + def crop_img(img: Mat, *, top: int, left: int, bottom: int, right: int): return img[top:bottom, left:right] diff --git a/src/arcaea_offline_ocr/device.py b/src/arcaea_offline_ocr/device.py index 6af8f3f..d03b9f0 100644 --- a/src/arcaea_offline_ocr/device.py +++ b/src/arcaea_offline_ocr/device.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import Any, Dict, Tuple +__all__ = ["Device"] + @dataclass(kw_only=True) class Device: diff --git a/src/arcaea_offline_ocr/mask.py b/src/arcaea_offline_ocr/mask.py index 2bf7be4..4008b73 100644 --- a/src/arcaea_offline_ocr/mask.py +++ b/src/arcaea_offline_ocr/mask.py @@ -1,6 +1,28 @@ from cv2 import BORDER_CONSTANT, BORDER_ISOLATED, Mat, bitwise_or, dilate, inRange from numpy import array, uint8 +__all__ = [ + "GRAY_MIN_HSV", + "GRAY_MAX_HSV", + "WHITE_MIN_HSV", + "WHITE_MAX_HSV", + "PST_MIN_HSV", + "PST_MAX_HSV", + "PRS_MIN_HSV", + "PRS_MAX_HSV", + "FTR_MIN_HSV", + "FTR_MAX_HSV", + "BYD_MIN_HSV", + "BYD_MAX_HSV", + "mask_gray", + "mask_white", + "mask_pst", + "mask_prs", + "mask_ftr", + "mask_byd", + "mask_rating_class", +] + GRAY_MIN_HSV = array([0, 0, 70], uint8) GRAY_MAX_HSV = array([0, 70, 200], uint8) diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index a705790..af78c46 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -11,6 +11,19 @@ from .template import ( matchTemplateMultiple, ) +__all__ = [ + "group_numbers", + "FilterDigitResultDict", + "filter_digit_results", + "ocr_digits", + "ocr_pure", + "ocr_far_lost", + "ocr_score", + "ocr_max_recall", + "ocr_rating_class", + "ocr_title", +] + def group_numbers(numbers: List[int], threshold: int) -> List[List[int]]: """ diff --git a/src/arcaea_offline_ocr/recognize.py b/src/arcaea_offline_ocr/recognize.py index ef4d5b1..0caa66e 100644 --- a/src/arcaea_offline_ocr/recognize.py +++ b/src/arcaea_offline_ocr/recognize.py @@ -8,6 +8,19 @@ from .device import Device from .mask import * from .ocr import * +__all__ = [ + "process_digits_ocr_img", + "process_tesseract_ocr_img", + "recognize_pure", + "recognize_far_lost", + "recognize_score", + "recognize_max_recall", + "recognize_rating_class", + "recognize_title", + "RecognizeResult", + "recognize", +] + def process_digits_ocr_img(img_hsv_cropped: Mat, mask=Callable[[Mat], Mat]): img_hsv_cropped = mask(img_hsv_cropped) diff --git a/src/arcaea_offline_ocr/template.py b/src/arcaea_offline_ocr/template.py index 2d5c2b9..53f3bad 100644 --- a/src/arcaea_offline_ocr/template.py +++ b/src/arcaea_offline_ocr/template.py @@ -32,6 +32,13 @@ from numpy import uint8 from ._builtin_templates import GeoSansLight_Italic, GeoSansLight_Regular +__all__ = [ + "load_digit_template", + "load_builtin_digit_template", + "MatchTemplateMultipleResult", + "matchTemplateMultiple", +] + def load_digit_template(filename: str) -> Dict[int, Mat]: """