mirror of
https://github.com/283375/arcaea-offline-ocr.git
synced 2025-04-20 22:10:17 +00:00
97 lines
2.6 KiB
Python
97 lines
2.6 KiB
Python
import io
import sqlite3
from gzip import GzipFile
from typing import Optional, Tuple

import cv2
import numpy as np

from .types import Mat
|
|
|
|
|
|
class SIFTDatabase:
    """Tagged SIFT descriptor sets loaded from an SQLite database.

    :meth:`load_db` reads:

    * ``properties`` rows keyed by ``id``: ``'size'`` holds comma-separated
      integers used as the resize target, ``'gzip'`` holds an integer flag;
    * ``sift`` rows of ``(tag, descriptors)``, where ``descriptors`` is a
      NumPy ``.npy`` payload, gzip-compressed when the ``'gzip'`` flag is
      non-zero.

    :meth:`lookup_img` matches an image's SIFT descriptors against every
    stored set and returns the best-scoring tag.
    """

    def __init__(self, db_path: str, load: bool = True):
        """Wrap the SQLite database at *db_path*.

        :param db_path: Path handed to :func:`sqlite3.connect`.
        :param load: Call :meth:`load_db` immediately when true; otherwise the
            instance starts empty with ``size`` set to ``None``.
        """
        self.__db_path = db_path
        self.__tags = []  # one tag per stored descriptor set, parallel to __descriptors
        self.__descriptors = []  # arrays produced by np.load, parallel to __tags
        self.__size = None  # resize target used by lookup_img; None disables resizing

        if load:
            self.load_db()

    @property
    def db_path(self):
        """Path of the SQLite database read by :meth:`load_db`."""
        return self.__db_path

    @db_path.setter
    def db_path(self, value):
        self.__db_path = value

    @property
    def tags(self):
        """Loaded tags; ``tags[i]`` labels ``descriptors[i]``."""
        return self.__tags

    @property
    def descriptors(self):
        """Loaded descriptor arrays; ``descriptors[i]`` belongs to ``tags[i]``."""
        return self.__descriptors

    @property
    def size(self):
        """Resize target passed to ``cv2.resize`` as ``dsize``, or ``None``.

        OpenCV interprets ``dsize`` as ``(width, height)``.
        """
        return self.__size

    @size.setter
    def size(self, value: Tuple[int, int]):
        self.__size = value

    def load_db(self):
        """(Re)load size, tags and descriptors from :attr:`db_path`.

        Previously loaded entries are replaced, not appended to, so calling
        this again (e.g. after changing :attr:`db_path`) leaves no stale or
        duplicated entries. Instance state is only updated after everything
        has been read and decoded successfully.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            size_str = cursor.execute(
                "SELECT value FROM properties WHERE id = 'size'"
            ).fetchone()[0]
            gzipped = int(
                cursor.execute(
                    "SELECT value FROM properties WHERE id = 'gzip'"
                ).fetchone()[0]
            )
            rows = cursor.execute("SELECT tag, descriptors FROM sift").fetchall()
        finally:
            # Using a sqlite3 connection as a context manager only commits or
            # rolls back; it has to be closed explicitly to release the file.
            conn.close()

        # int() ignores surrounding whitespace, so "w, h" and "w,h" both parse.
        size = tuple(int(part) for part in size_str.split(","))

        tags = []
        descriptors = []
        for tag, descriptor_bytes in rows:
            buffer = io.BytesIO(descriptor_bytes)
            # np.load keeps its default allow_pickle=False, so blobs are only
            # ever decoded as plain .npy arrays.
            if gzipped:
                with GzipFile(fileobj=buffer, mode="rb") as gzipped_buffer:
                    descriptors.append(np.load(gzipped_buffer))
            else:
                descriptors.append(np.load(buffer))
            tags.append(tag)

        self.size = size
        # Slice assignment keeps the list objects returned by the
        # ``tags``/``descriptors`` properties valid for existing holders.
        self.__tags[:] = tags
        self.__descriptors[:] = descriptors

    def lookup_img(
        self,
        __img: "Mat",
        *,
        sift=None,
        bf: "Optional[cv2.BFMatcher]" = None,
        ratio: float = 0.75,
    ) -> Tuple[str, float]:
        """Return the stored tag whose descriptors best match the image.

        The image is resized to :attr:`size` (when set), its descriptors are
        extracted with ``sift.detectAndCompute`` and matched against each
        stored descriptor set with ``bf.knnMatch(..., k=2)``. A match counts
        as good when it passes Lowe's ratio test
        (best distance < ``ratio`` * second-best distance).

        :param __img: Image to look up (positional).
        :param sift: Object with a ``detectAndCompute`` method; a fresh
            ``cv2.SIFT_create()`` is used when omitted.
        :param bf: Object with a ``knnMatch`` method; a fresh
            ``cv2.BFMatcher()`` is used when omitted.
        :param ratio: Ratio-test threshold (default 0.75).
        :returns: ``(tag, confidence)`` where confidence is the best tag's
            good-match count divided by the number of descriptors extracted
            from the image.
        :raises ValueError: if the database holds no descriptor sets, or no
            descriptors could be extracted from the image.
        """
        if not self.descriptors:
            raise ValueError("SIFT database is empty; nothing to match against")

        # Created per call (not once as default-argument values at class
        # definition), so separate calls never share OpenCV object state.
        if sift is None:
            sift = cv2.SIFT_create()
        if bf is None:
            bf = cv2.BFMatcher()

        img = __img
        if self.size is not None:
            img = cv2.resize(img, self.size)
        _, descriptors = sift.detectAndCompute(img, None)
        # detectAndCompute yields None when it finds no keypoints; matching
        # or dividing by its length would then fail.
        if descriptors is None or len(descriptors) == 0:
            raise ValueError("no SIFT descriptors could be extracted from the image")

        good_results = [
            self._count_good_matches(bf.knnMatch(descriptors, des, k=2), ratio)
            for des in self.descriptors
        ]
        # max() returns the first index on ties, matching the original order.
        best_match_index = max(
            range(len(good_results)), key=good_results.__getitem__
        )

        return (
            self.tags[best_match_index],
            good_results[best_match_index] / len(descriptors),
        )

    @staticmethod
    def _count_good_matches(matches, ratio: float) -> int:
        """Count k=2 nearest-neighbour results that pass the ratio test.

        knnMatch may return fewer than two neighbours for a query descriptor
        (when fewer than two candidates exist); such entries cannot be
        ratio-tested and are skipped rather than breaking tuple unpacking.
        """
        return sum(
            1
            for pair in matches
            if len(pair) >= 2 and pair[0].distance < ratio * pair[1].distance
        )
|