From c1359cf6f58cd58811bacd19ca13f659d2e62a85 Mon Sep 17 00:00:00 2001 From: 283375 Date: Fri, 11 Aug 2023 23:19:54 +0800 Subject: [PATCH] feat: 283375/image-sift-database --- src/arcaea_offline_ocr/sift_db.py | 96 +++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/arcaea_offline_ocr/sift_db.py diff --git a/src/arcaea_offline_ocr/sift_db.py b/src/arcaea_offline_ocr/sift_db.py new file mode 100644 index 0000000..11be09f --- /dev/null +++ b/src/arcaea_offline_ocr/sift_db.py @@ -0,0 +1,96 @@ +import io +import sqlite3 +from gzip import GzipFile +from typing import Tuple + +import cv2 +import numpy as np + +from .types import Mat + + +class SIFTDatabase: + def __init__(self, db_path: str, load: bool = True): + self.__db_path = db_path + self.__tags = [] + self.__descriptors = [] + self.__size = None + + if load: + self.load_db() + + @property + def db_path(self): + return self.__db_path + + @db_path.setter + def db_path(self, value): + self.__db_path = value + + @property + def tags(self): + return self.__tags + + @property + def descriptors(self): + return self.__descriptors + + @property + def size(self): + return self.__size + + @size.setter + def size(self, value: Tuple[int, int]): + self.__size = value + + def load_db(self): + conn = sqlite3.connect(self.db_path) + with conn: + cursor = conn.cursor() + + size_str = cursor.execute( + "SELECT value FROM properties WHERE id = 'size'" + ).fetchone()[0] + sizr_str_arr = size_str.split(", ") + self.size = tuple(int(s) for s in sizr_str_arr) + tag__descriptors_bytes = cursor.execute( + "SELECT tag, descriptors FROM sift" + ).fetchall() + + gzipped = int( + cursor.execute( + "SELECT value FROM properties WHERE id = 'gzip'" + ).fetchone()[0] + ) + for tag, descriptor_bytes in tag__descriptors_bytes: + buffer = io.BytesIO(descriptor_bytes) + self.tags.append(tag) + if gzipped == 0: + self.descriptors.append(np.load(buffer)) + else: + gzipped_buffer = GzipFile(None, "rb", fileobj=buffer) + self.descriptors.append(np.load(gzipped_buffer)) + + def lookup_img( + self, + __img: Mat, + *, + sift=cv2.SIFT_create(), + bf: cv2.BFMatcher = cv2.BFMatcher(), + ) -> Tuple[str, float]: + img = __img.copy() + if self.size is not None: + img = cv2.resize(img, self.size) + _, descriptors = sift.detectAndCompute(img, None) + + good_results = [] + for des in self.descriptors: + matches = bf.knnMatch(descriptors, des, k=2) + good = sum(m.distance < 0.75 * n.distance for m, n in matches) + good_results.append(good) + best_match_index = max(enumerate(good_results), key=lambda i: i[1])[0] + + return ( + self.tags[best_match_index], + good_results[best_match_index] / len(descriptors), + )