diff --git a/src/arcaea_offline_ocr/__init__.py b/src/arcaea_offline_ocr/__init__.py index 41057d6..c2e0b50 100644 --- a/src/arcaea_offline_ocr/__init__.py +++ b/src/arcaea_offline_ocr/__init__.py @@ -1,5 +1,4 @@ from .crop import * from .device import * -from .mask import * from .ocr import * from .utils import * diff --git a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py index c033088..07dd434 100644 --- a/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py +++ b/src/arcaea_offline_ocr/b30/chieri/v4/ocr.py @@ -8,7 +8,6 @@ from PIL import Image from ....crop import crop_xywh from ....ocr import FixRects, ocr_digits_by_contour_knn, preprocess_hog from ....phash_db import ImagePHashDatabase -from ....sift_db import SIFTDatabase from ....types import Mat, cv2_ml_KNearest from ....utils import construct_int_xywh_rect from ...shared import B30OcrResultItem diff --git a/src/arcaea_offline_ocr/mask.py b/src/arcaea_offline_ocr/mask.py deleted file mode 100644 index 085c99d..0000000 --- a/src/arcaea_offline_ocr/mask.py +++ /dev/null @@ -1,119 +0,0 @@ -import cv2 -import numpy as np - -from .types import Mat - -__all__ = [ - "GRAY_MIN_HSV", - "GRAY_MAX_HSV", - "WHITE_MIN_HSV", - "WHITE_MAX_HSV", - "PFL_WHITE_MIN_HSV", - "PFL_WHITE_MAX_HSV", - "PST_MIN_HSV", - "PST_MAX_HSV", - "PRS_MIN_HSV", - "PRS_MAX_HSV", - "FTR_MIN_HSV", - "FTR_MAX_HSV", - "BYD_MIN_HSV", - "BYD_MAX_HSV", - "MAX_RECALL_PURPLE_MIN_HSV", - "MAX_RECALL_PURPLE_MAX_HSV", - "mask_gray", - "mask_white", - "mask_pfl_white", - "mask_pst", - "mask_prs", - "mask_ftr", - "mask_byd", - "mask_rating_class", - "mask_max_recall_purple", -] - -GRAY_MIN_HSV = np.array([0, 0, 70], np.uint8) -GRAY_MAX_HSV = np.array([0, 0, 200], np.uint8) - -GRAY_MIN_BGR = np.array([50] * 3, np.uint8) -GRAY_MAX_BGR = np.array([160] * 3, np.uint8) - -WHITE_MIN_HSV = np.array([0, 0, 240], np.uint8) -WHITE_MAX_HSV = np.array([179, 10, 255], np.uint8) - -PFL_WHITE_MIN_HSV = np.array([0, 0, 248], np.uint8) -PFL_WHITE_MAX_HSV = np.array([179, 10, 255], np.uint8) - -PST_MIN_HSV = np.array([100, 50, 80], np.uint8) -PST_MAX_HSV = np.array([100, 255, 255], np.uint8) - -PRS_MIN_HSV = np.array([43, 40, 75], np.uint8) -PRS_MAX_HSV = np.array([50, 155, 190], np.uint8) - -FTR_MIN_HSV = np.array([149, 30, 0], np.uint8) -FTR_MAX_HSV = np.array([155, 181, 150], np.uint8) - -BYD_MIN_HSV = np.array([170, 50, 50], np.uint8) -BYD_MAX_HSV = np.array([179, 210, 198], np.uint8) - -MAX_RECALL_PURPLE_MIN_HSV = np.array([125, 0, 0], np.uint8) -MAX_RECALL_PURPLE_MAX_HSV = np.array([130, 100, 150], np.uint8) - - -def mask_gray(__img_bgr: Mat): - # bgr_value_equal_mask = all(__img_bgr[:, 1:] == __img_bgr[:, :-1], axis=1) - bgr_value_equal_mask = np.max(__img_bgr, axis=2) - np.min(__img_bgr, axis=2) <= 5 - img_bgr = __img_bgr.copy() - img_bgr[~bgr_value_equal_mask] = np.array([0, 0, 0], __img_bgr.dtype) - img_bgr = cv2.erode(img_bgr, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))) - img_bgr = cv2.dilate(img_bgr, cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))) - return cv2.inRange(img_bgr, GRAY_MIN_BGR, GRAY_MAX_BGR) - - -def mask_white(img_hsv: Mat): - mask = cv2.inRange(img_hsv, WHITE_MIN_HSV, WHITE_MAX_HSV) - mask = cv2.dilate(mask, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))) - return mask - - -def mask_pfl_white(img_hsv: Mat): - mask = cv2.inRange(img_hsv, PFL_WHITE_MIN_HSV, PFL_WHITE_MAX_HSV) - mask = cv2.dilate(mask, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))) - return mask - - -def mask_pst(img_hsv: Mat): - mask = cv2.inRange(img_hsv, PST_MIN_HSV, PST_MAX_HSV) - mask = cv2.dilate(mask, (1, 1)) - return mask - - -def mask_prs(img_hsv: Mat): - mask = cv2.inRange(img_hsv, PRS_MIN_HSV, PRS_MAX_HSV) - mask = cv2.dilate(mask, (1, 1)) - return mask - - -def mask_ftr(img_hsv: Mat): - mask = cv2.inRange(img_hsv, FTR_MIN_HSV, FTR_MAX_HSV) - mask = cv2.dilate(mask, (1, 1)) - return mask - - -def mask_byd(img_hsv: Mat): - mask = cv2.inRange(img_hsv, BYD_MIN_HSV, BYD_MAX_HSV) - mask = cv2.dilate(mask, (2, 2)) - return mask - - -def mask_rating_class(img_hsv: Mat): - pst = mask_pst(img_hsv) - prs = mask_prs(img_hsv) - ftr = mask_ftr(img_hsv) - byd = mask_byd(img_hsv) - return cv2.bitwise_or(byd, cv2.bitwise_or(ftr, cv2.bitwise_or(pst, prs))) - - -def mask_max_recall_purple(img_hsv: Mat): - mask = cv2.inRange(img_hsv, MAX_RECALL_PURPLE_MIN_HSV, MAX_RECALL_PURPLE_MAX_HSV) - mask = cv2.dilate(mask, (2, 2)) - return mask diff --git a/src/arcaea_offline_ocr/ocr.py b/src/arcaea_offline_ocr/ocr.py index 1c9c36e..95eb0dc 100644 --- a/src/arcaea_offline_ocr/ocr.py +++ b/src/arcaea_offline_ocr/ocr.py @@ -7,7 +7,6 @@ import numpy as np from numpy.linalg import norm from .crop import crop_xywh -from .mask import mask_byd, mask_ftr, mask_prs, mask_pst from .types import Mat, cv2_ml_KNearest __all__ = [ @@ -199,13 +198,3 @@ def ocr_digits_by_contour_knn( ) -> int: samples = ocr_digits_by_contour_get_samples(__roi_gray, size) return ocr_digit_samples_knn(samples, knn_model, k) - - -def ocr_rating_class(roi_hsv: Mat): - mask_results = [ - mask_pst(roi_hsv), - mask_prs(roi_hsv), - mask_ftr(roi_hsv), - mask_byd(roi_hsv), - ] - return max(enumerate(mask_results), key=lambda e: np.count_nonzero(e[1]))[0] diff --git a/src/arcaea_offline_ocr/sift_db.py b/src/arcaea_offline_ocr/sift_db.py deleted file mode 100644 index d249dae..0000000 --- a/src/arcaea_offline_ocr/sift_db.py +++ /dev/null @@ -1,110 +0,0 @@ -import io -import sqlite3 -from gzip import GzipFile -from typing import Tuple - -import cv2 -import numpy as np - -from .types import Mat - - -class SIFTDatabase: - def __init__(self, db_path: str, load: bool = True): - self.__db_path = db_path - self.__tags = [] - self.__descriptors = [] - self.__size = None - - self.__sift = cv2.SIFT_create() - self.__bf_matcher = cv2.BFMatcher() - - if load: - self.load_db() - - @property - def db_path(self): - return self.__db_path - - @db_path.setter - def db_path(self, value): - self.__db_path = value - - @property - def tags(self): - return self.__tags - - @property - def descriptors(self): - return self.__descriptors - - @property - def size(self): - return self.__size - - @size.setter - def size(self, value: Tuple[int, int]): - self.__size = value - - @property - def sift(self): - return self.__sift - - @property - def bf_matcher(self): - return self.__bf_matcher - - def load_db(self): - conn = sqlite3.connect(self.db_path) - with conn: - cursor = conn.cursor() - - size_str = cursor.execute( - "SELECT value FROM properties WHERE id = 'size'" - ).fetchone()[0] - sizr_str_arr = size_str.split(", ") - self.size = tuple(int(s) for s in sizr_str_arr) - tag__descriptors_bytes = cursor.execute( - "SELECT tag, descriptors FROM sift" - ).fetchall() - - gzipped = int( - cursor.execute( - "SELECT value FROM properties WHERE id = 'gzip'" - ).fetchone()[0] - ) - for tag, descriptor_bytes in tag__descriptors_bytes: - buffer = io.BytesIO(descriptor_bytes) - self.tags.append(tag) - if gzipped == 0: - self.descriptors.append(np.load(buffer)) - else: - gzipped_buffer = GzipFile(None, "rb", fileobj=buffer) - self.descriptors.append(np.load(gzipped_buffer)) - - def lookup_img( - self, - __img: Mat, - *, - sift=None, - bf=None, - ) -> Tuple[str, float]: - sift = sift or self.sift - bf = bf or self.bf_matcher - - img = __img.copy() - if self.size is not None: - img = cv2.resize(img, self.size) - _, descriptors = sift.detectAndCompute(img, None) - - good_results = [] - for des in self.descriptors: - matches = bf.knnMatch(descriptors, des, k=2) - good = sum(m.distance < 0.75 * n.distance for m, n in matches) - good_results.append(good) - best_match_index = max(enumerate(good_results), key=lambda i: i[1])[0] - - return ( - self.tags[best_match_index], - good_results[best_match_index] / len(descriptors), - )