10 Commits

19 changed files with 425 additions and 79 deletions

View File

@ -4,11 +4,10 @@ repos:
hooks: hooks:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.1.0 - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.13
hooks: hooks:
- id: black - id: ruff
- repo: https://github.com/PyCQA/isort args: ["--fix"]
rev: 5.12.0 - id: ruff-format
hooks:
- id: isort

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "arcaea-offline-ocr" name = "arcaea-offline-ocr"
version = "0.0.98" version = "0.0.99"
authors = [{ name = "283375", email = "log_283375@163.com" }] authors = [{ name = "283375", email = "log_283375@163.com" }]
description = "Extract your Arcaea play result from screenshot." description = "Extract your Arcaea play result from screenshot."
readme = "README.md" readme = "README.md"

View File

@ -1,3 +1,2 @@
attrs==23.1.0 numpy~=2.3
numpy==1.26.1 opencv-python~=4.11
opencv-python==4.8.1.78

View File

@ -1,4 +1,3 @@
from math import floor
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import cv2 import cv2
@ -13,9 +12,21 @@ from ....ocr import (
) )
from ....phash_db import ImagePhashDatabase from ....phash_db import ImagePhashDatabase
from ....types import Mat from ....types import Mat
from ....utils import construct_int_xywh_rect
from ...shared import B30OcrResultItem from ...shared import B30OcrResultItem
from .colors import * from .colors import (
BYD_MAX_HSV,
BYD_MIN_HSV,
FAR_BG_MAX_HSV,
FAR_BG_MIN_HSV,
FTR_MAX_HSV,
FTR_MIN_HSV,
LOST_BG_MAX_HSV,
LOST_BG_MIN_HSV,
PRS_MAX_HSV,
PRS_MIN_HSV,
PURE_BG_MAX_HSV,
PURE_BG_MIN_HSV,
)
from .rois import ChieriBotV4Rois from .rois import ChieriBotV4Rois
@ -25,7 +36,7 @@ class ChieriBotV4Ocr:
score_knn: cv2.ml.KNearest, score_knn: cv2.ml.KNearest,
pfl_knn: cv2.ml.KNearest, pfl_knn: cv2.ml.KNearest,
phash_db: ImagePhashDatabase, phash_db: ImagePhashDatabase,
factor: Optional[float] = 1.0, factor: float = 1.0,
): ):
self.__score_knn = score_knn self.__score_knn = score_knn
self.__pfl_knn = pfl_knn self.__pfl_knn = pfl_knn
@ -72,9 +83,8 @@ class ChieriBotV4Ocr:
self.factor = img.shape[0] / 4400 self.factor = img.shape[0] / 4400
def ocr_component_rating_class(self, component_bgr: Mat) -> int: def ocr_component_rating_class(self, component_bgr: Mat) -> int:
rating_class_rect = construct_int_xywh_rect( rating_class_rect = self.rois.component_rois.rating_class_rect.rounded()
self.rois.component_rois.rating_class_rect
)
rating_class_roi = crop_xywh(component_bgr, rating_class_rect) rating_class_roi = crop_xywh(component_bgr, rating_class_rect)
rating_class_roi = cv2.cvtColor(rating_class_roi, cv2.COLOR_BGR2HSV) rating_class_roi = cv2.cvtColor(rating_class_roi, cv2.COLOR_BGR2HSV)
rating_class_masks = [ rating_class_masks = [
@ -89,9 +99,7 @@ class ChieriBotV4Ocr:
return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1 return max(enumerate(rating_class_results), key=lambda i: i[1])[0] + 1
def ocr_component_song_id(self, component_bgr: Mat): def ocr_component_song_id(self, component_bgr: Mat):
jacket_rect = construct_int_xywh_rect( jacket_rect = self.rois.component_rois.jacket_rect.floored()
self.rois.component_rois.jacket_rect, floor
)
jacket_roi = cv2.cvtColor( jacket_roi = cv2.cvtColor(
crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY crop_xywh(component_bgr, jacket_rect), cv2.COLOR_BGR2GRAY
) )
@ -99,7 +107,7 @@ class ChieriBotV4Ocr:
def ocr_component_score_knn(self, component_bgr: Mat) -> int: def ocr_component_score_knn(self, component_bgr: Mat) -> int:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
score_rect = construct_int_xywh_rect(self.rois.component_rois.score_rect) score_rect = self.rois.component_rois.score_rect.rounded()
score_roi = cv2.cvtColor( score_roi = cv2.cvtColor(
crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY crop_xywh(component_bgr, score_rect), cv2.COLOR_BGR2GRAY
) )
@ -119,7 +127,9 @@ class ChieriBotV4Ocr:
score_roi = cv2.fillPoly(score_roi, [contour], 0) score_roi = cv2.fillPoly(score_roi, [contour], 0)
return ocr_digits_by_contour_knn(score_roi, self.score_knn) return ocr_digits_by_contour_knn(score_roi, self.score_knn)
def find_pfl_rects(self, component_pfl_processed: Mat) -> List[List[int]]: def find_pfl_rects(
self, component_pfl_processed: Mat
) -> List[Tuple[int, int, int, int]]:
# sourcery skip: inline-immediately-returned-variable # sourcery skip: inline-immediately-returned-variable
pfl_roi_find = cv2.morphologyEx( pfl_roi_find = cv2.morphologyEx(
component_pfl_processed, component_pfl_processed,
@ -146,7 +156,7 @@ class ChieriBotV4Ocr:
return pfl_rects_adjusted return pfl_rects_adjusted
def preprocess_component_pfl(self, component_bgr: Mat) -> Mat: def preprocess_component_pfl(self, component_bgr: Mat) -> Mat:
pfl_rect = construct_int_xywh_rect(self.rois.component_rois.pfl_rect) pfl_rect = self.rois.component_rois.pfl_rect.rounded()
pfl_roi = crop_xywh(component_bgr, pfl_rect) pfl_roi = crop_xywh(component_bgr, pfl_rect)
pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV) pfl_roi_hsv = cv2.cvtColor(pfl_roi, cv2.COLOR_BGR2HSV)

View File

@ -1,12 +1,12 @@
from typing import List, Optional from typing import List
from ....crop import crop_xywh from ....crop import crop_xywh
from ....types import Mat, XYWHRect from ....types import Mat, XYWHRect
from ....utils import apply_factor, construct_int_xywh_rect from ....utils import apply_factor
class ChieriBotV4ComponentRois: class ChieriBotV4ComponentRois:
def __init__(self, factor: Optional[float] = 1.0): def __init__(self, factor: float = 1.0):
self.__factor = factor self.__factor = factor
@property @property
@ -19,11 +19,11 @@ class ChieriBotV4ComponentRois:
@property @property
def top_font_color_detect(self): def top_font_color_detect(self):
return apply_factor((35, 10, 120, 100), self.factor) return apply_factor(XYWHRect(35, 10, 120, 100), self.factor)
@property @property
def bottom_font_color_detect(self): def bottom_font_color_detect(self):
return apply_factor((30, 125, 175, 110), self.factor) return apply_factor(XYWHRect(30, 125, 175, 110), self.factor)
@property @property
def bg_point(self): def bg_point(self):
@ -31,31 +31,31 @@ class ChieriBotV4ComponentRois:
@property @property
def rating_class_rect(self): def rating_class_rect(self):
return apply_factor((21, 40, 7, 20), self.factor) return apply_factor(XYWHRect(21, 40, 7, 20), self.factor)
@property @property
def title_rect(self): def title_rect(self):
return apply_factor((35, 10, 430, 50), self.factor) return apply_factor(XYWHRect(35, 10, 430, 50), self.factor)
@property @property
def jacket_rect(self): def jacket_rect(self):
return apply_factor((263, 0, 239, 239), self.factor) return apply_factor(XYWHRect(263, 0, 239, 239), self.factor)
@property @property
def score_rect(self): def score_rect(self):
return apply_factor((30, 60, 270, 55), self.factor) return apply_factor(XYWHRect(30, 60, 270, 55), self.factor)
@property @property
def pfl_rect(self): def pfl_rect(self):
return apply_factor((50, 125, 80, 100), self.factor) return apply_factor(XYWHRect(50, 125, 80, 100), self.factor)
@property @property
def date_rect(self): def date_rect(self):
return apply_factor((205, 200, 225, 25), self.factor) return apply_factor(XYWHRect(205, 200, 225, 25), self.factor)
class ChieriBotV4Rois: class ChieriBotV4Rois:
def __init__(self, factor: Optional[float] = 1.0): def __init__(self, factor: float = 1.0):
self.__factor = factor self.__factor = factor
self.__component_rois = ChieriBotV4ComponentRois(factor) self.__component_rois = ChieriBotV4ComponentRois(factor)
@ -100,9 +100,7 @@ class ChieriBotV4Rois:
def horizontal_items(self): def horizontal_items(self):
return 3 return 3
@property vertical_items = 10
def vertical_items(self):
return 10
@property @property
def b33_vertical_gap(self): def b33_vertical_gap(self):
@ -112,16 +110,17 @@ class ChieriBotV4Rois:
first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height) first_rect = XYWHRect(x=self.left, y=self.top, w=self.width, h=self.height)
results = [] results = []
last_rect = first_rect
for vi in range(self.vertical_items): for vi in range(self.vertical_items):
rect = XYWHRect(*first_rect) rect = XYWHRect(*first_rect)
rect += (0, (self.vertical_gap + self.height) * vi, 0, 0) rect += (0, (self.vertical_gap + self.height) * vi, 0, 0)
for hi in range(self.horizontal_items): for hi in range(self.horizontal_items):
if hi > 0: if hi > 0:
rect += ((self.width + self.horizontal_gap), 0, 0, 0) rect += ((self.width + self.horizontal_gap), 0, 0, 0)
int_rect = construct_int_xywh_rect(rect) results.append(crop_xywh(img_bgr, rect.rounded()))
results.append(crop_xywh(img_bgr, int_rect)) last_rect = rect
rect += ( last_rect += (
-(self.width + self.horizontal_gap) * 2, -(self.width + self.horizontal_gap) * 2,
self.height + self.b33_vertical_gap, self.height + self.b33_vertical_gap,
0, 0,
@ -129,8 +128,7 @@ class ChieriBotV4Rois:
) )
for hi in range(self.horizontal_items): for hi in range(self.horizontal_items):
if hi > 0: if hi > 0:
rect += ((self.width + self.horizontal_gap), 0, 0, 0) last_rect += ((self.width + self.horizontal_gap), 0, 0, 0)
int_rect = construct_int_xywh_rect(rect) results.append(crop_xywh(img_bgr, last_rect.rounded()))
results.append(crop_xywh(img_bgr, int_rect))
return results return results

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
import attrs
@dataclass
@attrs.define
class B30OcrResultItem: class B30OcrResultItem:
rating_class: int rating_class: int
score: int score: int

View File

View File

@ -0,0 +1,3 @@
from .index import average, dct, difference
__all__ = ["average", "dct", "difference"]

View File

@ -0,0 +1,7 @@
import cv2
from arcaea_offline_ocr.types import Mat
def _resize_image(src: Mat, dsize: ...) -> Mat:
return cv2.resize(src, dsize, fx=0, fy=0, interpolation=cv2.INTER_AREA)

View File

@ -0,0 +1,35 @@
import cv2
import numpy as np
from arcaea_offline_ocr.types import Mat
from ._common import _resize_image
def average(img_gray: Mat, hash_size: int) -> Mat:
img_resized = _resize_image(img_gray, (hash_size, hash_size))
diff = img_resized > img_resized.mean()
return diff.flatten()
def difference(img_gray: Mat, hash_size: int) -> Mat:
img_size = (hash_size + 1, hash_size)
img_resized = _resize_image(img_gray, img_size)
previous = img_resized[:, :-1]
current = img_resized[:, 1:]
diff = previous > current
return diff.flatten()
def dct(img_gray: Mat, hash_size: int = 16, high_freq_factor: int = 4) -> Mat:
# TODO: consistency?
img_size_base = hash_size * high_freq_factor
img_size = (img_size_base, img_size_base)
img_resized = _resize_image(img_gray, img_size)
img_resized = img_resized.astype(np.float32)
dct_mat = cv2.dct(img_resized)
hash_mat = dct_mat[:hash_size, :hash_size]
return hash_mat > hash_mat.mean()

View File

@ -0,0 +1,18 @@
from .builder import ImageHashesDatabaseBuilder
from .index import ImageHashesDatabase, ImageHashesDatabasePropertyMissingError
from .models import (
ImageHashBuildTask,
ImageHashHashType,
ImageHashResult,
ImageHashCategory,
)
__all__ = [
"ImageHashesDatabase",
"ImageHashesDatabasePropertyMissingError",
"ImageHashHashType",
"ImageHashResult",
"ImageHashCategory",
"ImageHashesDatabaseBuilder",
"ImageHashBuildTask",
]

View File

@ -0,0 +1,85 @@
import logging
from datetime import datetime, timezone
from sqlite3 import Connection
from typing import List
from arcaea_offline_ocr.core import hashers
from .index import ImageHashesDatabase
from .models import ImageHash, ImageHashBuildTask, ImageHashHashType
logger = logging.getLogger(__name__)
class ImageHashesDatabaseBuilder:
@staticmethod
def __insert_property(conn: Connection, key: str, value: str):
return conn.execute(
"INSERT INTO properties (key, value) VALUES (?, ?)",
(key, value),
)
@classmethod
def build(
cls,
conn: Connection,
tasks: List[ImageHashBuildTask],
*,
hash_size: int = 16,
high_freq_factor: int = 4,
):
rows: List[ImageHash] = []
for task in tasks:
try:
img_gray = task.imread_function(task.image_path)
for hash_type, hash_mat in [
(
ImageHashHashType.AVERAGE,
hashers.average(img_gray, hash_size),
),
(
ImageHashHashType.DCT,
hashers.dct(img_gray, hash_size, high_freq_factor),
),
(
ImageHashHashType.DIFFERENCE,
hashers.difference(img_gray, hash_size),
),
]:
rows.append(
ImageHash(
hash_type=hash_type,
category=task.category,
label=task.label,
hash=ImageHashesDatabase.hash_mat_to_bytes(hash_mat),
)
)
except Exception:
logger.exception("Error processing task %r", task)
conn.execute("CREATE TABLE properties (`key` VARCHAR, `value` VARCHAR)")
conn.execute(
"CREATE TABLE hashes (`hash_type` INTEGER, `category` INTEGER, `label` VARCHAR, `hash` BLOB)"
)
now = datetime.now(tz=timezone.utc)
timestamp = int(now.timestamp() * 1000)
cls.__insert_property(conn, ImageHashesDatabase.KEY_HASH_SIZE, str(hash_size))
cls.__insert_property(
conn, ImageHashesDatabase.KEY_HIGH_FREQ_FACTOR, str(high_freq_factor)
)
cls.__insert_property(
conn, ImageHashesDatabase.KEY_BUILT_TIMESTAMP, str(timestamp)
)
conn.executemany(
"INSERT INTO hashes (hash_type, category, label, hash) VALUES (?, ?, ?, ?)",
[
(row.hash_type.value, row.category.value, row.label, row.hash)
for row in rows
],
)
conn.commit()

View File

@ -0,0 +1,144 @@
import sqlite3
from datetime import datetime, timezone
from typing import Any, Callable, List, Optional, TypeVar
from arcaea_offline_ocr.core import hashers
from arcaea_offline_ocr.types import Mat
from .models import ImageHashHashType, ImageHashResult, ImageHashCategory
T = TypeVar("T")
def _sql_hamming_distance(hash1: bytes, hash2: bytes):
assert len(hash1) == len(hash2), "hash size does not match!"
count = sum(1 for byte1, byte2 in zip(hash1, hash2) if byte1 != byte2)
return count
class ImageHashesDatabasePropertyMissingError(Exception):
pass
class ImageHashesDatabase:
KEY_HASH_SIZE = "hash_size"
KEY_HIGH_FREQ_FACTOR = "high_freq_factor"
KEY_BUILT_TIMESTAMP = "built_timestamp"
def __init__(self, conn: sqlite3.Connection):
self.conn = conn
self.conn.create_function("HAMMING_DISTANCE", 2, _sql_hamming_distance)
self._hash_size: int = -1
self._high_freq_factor: int = -1
self._built_time: Optional[datetime] = None
self._hashes_count = {
ImageHashCategory.JACKET: 0,
ImageHashCategory.PARTNER_ICON: 0,
}
self._hash_length: int = -1
self._initialize()
@property
def hash_size(self):
return self._hash_size
@property
def high_freq_factor(self):
return self._high_freq_factor
@property
def hash_length(self):
return self._hash_length
def _initialize(self):
def query_property(key, convert_func: Callable[[Any], T]) -> Optional[T]:
result = self.conn.execute(
"SELECT value FROM properties WHERE key = ?",
(key,),
).fetchone()
return convert_func(result[0]) if result is not None else None
def set_hashes_count(category: ImageHashCategory):
self._hashes_count[category] = self.conn.execute(
"SELECT COUNT(DISTINCT label) FROM hashes WHERE category = ?",
(category.value,),
).fetchone()[0]
hash_size = query_property(self.KEY_HASH_SIZE, lambda x: int(x))
if hash_size is None:
raise ImageHashesDatabasePropertyMissingError("hash_size")
self._hash_size = hash_size
high_freq_factor = query_property(self.KEY_HIGH_FREQ_FACTOR, lambda x: int(x))
if high_freq_factor is None:
raise ImageHashesDatabasePropertyMissingError("high_freq_factor")
self._high_freq_factor = high_freq_factor
self._built_time = query_property(
self.KEY_BUILT_TIMESTAMP,
lambda ts: datetime.fromtimestamp(int(ts) / 1000, tz=timezone.utc),
)
set_hashes_count(ImageHashCategory.JACKET)
set_hashes_count(ImageHashCategory.PARTNER_ICON)
self._hash_length = self._hash_size**2
def lookup_hash(
self, category: ImageHashCategory, hash_type: ImageHashHashType, hash: bytes
) -> List[ImageHashResult]:
cursor = self.conn.execute(
"SELECT"
" label,"
" HAMMING_DISTANCE(hash, ?) AS distance"
" FROM hashes"
" WHERE category = ? AND hash_type = ?"
" ORDER BY distance ASC LIMIT 10",
(hash, category.value, hash_type.value),
)
results = []
for label, distance in cursor.fetchall():
results.append(
ImageHashResult(
hash_type=hash_type,
category=category,
label=label,
confidence=(self.hash_length - distance) / self.hash_length,
)
)
return results
@staticmethod
def hash_mat_to_bytes(hash: Mat) -> bytes:
return bytes([255 if b else 0 for b in hash.flatten()])
def identify_image(self, category: ImageHashCategory, img) -> List[ImageHashResult]:
results = []
ahash = hashers.average(img, self.hash_size)
dhash = hashers.difference(img, self.hash_size)
phash = hashers.dct(img, self.hash_size, self.high_freq_factor)
results.extend(
self.lookup_hash(
category, ImageHashHashType.AVERAGE, self.hash_mat_to_bytes(ahash)
)
)
results.extend(
self.lookup_hash(
category, ImageHashHashType.DIFFERENCE, self.hash_mat_to_bytes(dhash)
)
)
results.extend(
self.lookup_hash(
category, ImageHashHashType.DCT, self.hash_mat_to_bytes(phash)
)
)
return results

View File

@ -0,0 +1,46 @@
import dataclasses
from enum import IntEnum
from typing import Callable
import cv2
from arcaea_offline_ocr.types import Mat
class ImageHashHashType(IntEnum):
AVERAGE = 0
DIFFERENCE = 1
DCT = 2
class ImageHashCategory(IntEnum):
JACKET = 0
PARTNER_ICON = 1
@dataclasses.dataclass
class ImageHash:
hash_type: ImageHashHashType
category: ImageHashCategory
label: str
hash: bytes
@dataclasses.dataclass
class ImageHashResult:
hash_type: ImageHashHashType
category: ImageHashCategory
label: str
confidence: float
def _default_imread_gray(image_path: str):
return cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY)
@dataclasses.dataclass
class ImageHashBuildTask:
image_path: str
category: ImageHashCategory
label: str
imread_function: Callable[[str], Mat] = _default_imread_gray

View File

@ -1,9 +1,8 @@
from dataclasses import dataclass
from typing import Optional from typing import Optional
import attrs
@dataclass
@attrs.define
class DeviceOcrResult: class DeviceOcrResult:
rating_class: int rating_class: int
pure: int pure: int

View File

@ -110,7 +110,7 @@ class DeviceOcr:
@staticmethod @staticmethod
def preprocess_char_icon(img_gray: Mat): def preprocess_char_icon(img_gray: Mat):
h, w = img_gray.shape[:2] h, w = img_gray.shape[:2]
img = cv2.copyMakeBorder(img_gray, w - h, 0, 0, 0, cv2.BORDER_REPLICATE) img = cv2.copyMakeBorder(img_gray, max(w - h, 0), 0, 0, 0, cv2.BORDER_REPLICATE)
h, w = img.shape[:2] h, w = img.shape[:2]
img = cv2.fillPoly( img = cv2.fillPoly(
img, img,

View File

@ -125,7 +125,7 @@ class DeviceRoisMaskerAutoT1(DeviceRoisMaskerAuto):
class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto): class DeviceRoisMaskerAutoT2(DeviceRoisMaskerAuto):
PFL_HSV_MIN = np.array([0, 0, 248], np.uint8) PFL_HSV_MIN = np.array([0, 0, 248], np.uint8)
PFL_HSV_MAX = np.array([179, 10, 255], np.uint8) PFL_HSV_MAX = np.array([179, 40, 255], np.uint8)
SCORE_HSV_MIN = np.array([0, 0, 180], np.uint8) SCORE_HSV_MIN = np.array([0, 0, 180], np.uint8)
SCORE_HSV_MAX = np.array([179, 255, 255], np.uint8) SCORE_HSV_MAX = np.array([179, 255, 255], np.uint8)

View File

@ -1,25 +1,36 @@
from collections.abc import Iterable from math import floor
from typing import NamedTuple, Tuple, Union from typing import Callable, NamedTuple, Union
import numpy as np import numpy as np
Mat = np.ndarray Mat = np.ndarray
_IntOrFloat = Union[int, float]
class XYWHRect(NamedTuple): class XYWHRect(NamedTuple):
x: int x: _IntOrFloat
y: int y: _IntOrFloat
w: int w: _IntOrFloat
h: int h: _IntOrFloat
def __add__(self, other: Union["XYWHRect", Tuple[int, int, int, int]]): def _to_int(self, func: Callable[[_IntOrFloat], int]):
if not isinstance(other, Iterable) or len(other) != 4: return (func(self.x), func(self.y), func(self.w), func(self.h))
raise ValueError()
def rounded(self):
return self._to_int(round)
def floored(self):
return self._to_int(floor)
def __add__(self, other):
if not isinstance(other, (list, tuple)) or len(other) != 4:
raise TypeError()
return self.__class__(*[a + b for a, b in zip(self, other)]) return self.__class__(*[a + b for a, b in zip(self, other)])
def __sub__(self, other: Union["XYWHRect", Tuple[int, int, int, int]]): def __sub__(self, other):
if not isinstance(other, Iterable) or len(other) != 4: if not isinstance(other, (list, tuple)) or len(other) != 4:
raise ValueError() raise TypeError()
return self.__class__(*[a - b for a, b in zip(self, other)]) return self.__class__(*[a - b for a, b in zip(self, other)])

View File

@ -1,5 +1,5 @@
from collections.abc import Iterable from collections.abc import Iterable
from typing import Callable, TypeVar, Union, overload from typing import TypeVar, overload
import cv2 import cv2
import numpy as np import numpy as np
@ -15,32 +15,25 @@ def imread_unicode(filepath: str, flags: int = cv2.IMREAD_UNCHANGED):
return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags) return cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), flags)
def construct_int_xywh_rect( @overload
rect: XYWHRect, func: Callable[[Union[int, float]], int] = round def apply_factor(item: int, factor: float) -> float: ...
):
return XYWHRect(*[func(num) for num in rect])
@overload @overload
def apply_factor(item: int, factor: float) -> float: def apply_factor(item: float, factor: float) -> float: ...
...
@overload
def apply_factor(item: float, factor: float) -> float:
...
T = TypeVar("T", bound=Iterable) T = TypeVar("T", bound=Iterable)
@overload @overload
def apply_factor(item: T, factor: float) -> T: def apply_factor(item: T, factor: float) -> T: ...
...
def apply_factor(item, factor: float): def apply_factor(item, factor: float):
if isinstance(item, (int, float)): if isinstance(item, (int, float)):
return item * factor return item * factor
if isinstance(item, XYWHRect):
return item.__class__(*[i * factor for i in item])
if isinstance(item, Iterable): if isinstance(item, Iterable):
return item.__class__([i * factor for i in item]) return item.__class__([i * factor for i in item])