Skip to content

Commit

Permalink
pikvm/kvmd#66: OCR API
Browse files Browse the repository at this point in the history
  • Loading branch information
mdevaev committed Jan 18, 2022
1 parent 3ee1948 commit 3ab43ed
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 11 deletions.
4 changes: 4 additions & 0 deletions kvmd/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ def _get_config_scheme() -> Dict:
"cmd_append": Option([], type=valid_options),
},

"ocr": {
"langs": Option(["eng"], type=valid_string_list, unpack_as="default_langs"),
},

"snapshot": {
"idle_interval": Option(0.0, type=valid_float_f0),
"live_interval": Option(0.0, type=valid_float_f0),
Expand Down
2 changes: 2 additions & 0 deletions kvmd/apps/kvmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter
from .tesseract import TesseractOcr
from .server import KvmdServer


Expand Down Expand Up @@ -86,6 +87,7 @@ def main(argv: Optional[List[str]]=None) -> None:
info_manager=InfoManager(global_config),
log_reader=LogReader(),
user_gpio=UserGpio(config.gpio, global_config.otg.udc),
ocr=TesseractOcr(**config.ocr._unpack()),

hid=hid,
atx=get_atx_class(config.atx.type)(**config.atx._unpack(ignore=["type"])),
Expand Down
4 changes: 2 additions & 2 deletions kvmd/apps/kvmd/api/hid.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def __reset_handler(self, _: Request) -> Response:

# =====

def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py)
async def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see server.py)
keymaps: Set[str] = set()
for keymap_name in os.listdir(self.__keymaps_dir_path):
path = os.path.join(self.__keymaps_dir_path, keymap_name)
Expand All @@ -127,7 +127,7 @@ def get_keymaps(self) -> Dict: # Ugly hack to generate hid_keymaps_state (see s

@exposed_http("GET", "/hid/keymaps")
async def __keymaps_handler(self, _: Request) -> Response:
return make_json_response(self.get_keymaps())
return make_json_response(await self.get_keymaps())

@exposed_http("POST", "/hid/print")
async def __print_handler(self, request: Request) -> Response:
Expand Down
54 changes: 52 additions & 2 deletions kvmd/apps/kvmd/api/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,19 @@
import io
import functools

from typing import List
from typing import Dict

from aiohttp.web import Request
from aiohttp.web import Response

from PIL import Image as PilImage

from ....validators import check_string_in_list
from ....validators.basic import valid_bool
from ....validators.basic import valid_number
from ....validators.basic import valid_int_f0
from ....validators.basic import valid_string_list
from ....validators.kvm import valid_stream_quality

from .... import aiotools
Expand All @@ -41,11 +47,14 @@
from ..streamer import StreamerSnapshot
from ..streamer import Streamer

from ..tesseract import TesseractOcr


# =====
class StreamerApi:
def __init__(self, streamer: Streamer) -> None:
def __init__(self, streamer: Streamer, ocr: TesseractOcr) -> None:
self.__streamer = streamer
self.__ocr = ocr

# =====

Expand All @@ -61,7 +70,25 @@ async def __take_snapshot_handler(self, request: Request) -> Response:
allow_offline=valid_bool(request.query.get("allow_offline", "false")),
)
if snapshot:
if valid_bool(request.query.get("preview", "false")):
if valid_bool(request.query.get("ocr", "false")):
langs = await self.__ocr.get_available_langs()
return Response(
body=(await self.__ocr.recognize(
data=snapshot.data,
langs=valid_string_list(
arg=str(request.query.get("ocr_langs", "")).strip(),
subval=(lambda lang: check_string_in_list(lang, "OCR lang", langs)),
name="OCR langs list",
),
left=int(valid_number(request.query.get("ocr_left", "-1"))),
top=int(valid_number(request.query.get("ocr_top", "-1"))),
right=int(valid_number(request.query.get("ocr_right", "-1"))),
bottom=int(valid_number(request.query.get("ocr_bottom", "-1"))),
)),
headers=dict(snapshot.headers),
content_type="text/plain",
)
elif valid_bool(request.query.get("preview", "false")):
data = await self.__make_preview(
snapshot=snapshot,
max_width=valid_int_f0(request.query.get("preview_max_width", "0")),
Expand All @@ -84,6 +111,29 @@ async def __remove_snapshot_handler(self, _: Request) -> Response:

# =====

async def get_ocr(self) -> Dict: # XXX: Ugly hack
enabled = self.__ocr.is_available()
default: List[str] = []
available: List[str] = []
if enabled:
default = await self.__ocr.get_default_langs()
available = await self.__ocr.get_available_langs()
return {
"ocr": {
"enabled": enabled,
"langs": {
"default": default,
"available": available,
},
},
}

@exposed_http("GET", "/streamer/ocr")
async def __ocr_handler(self, _: Request) -> Response:
return make_json_response(await self.get_ocr())

# =====

async def __make_preview(self, snapshot: StreamerSnapshot, max_width: int, max_height: int, quality: int) -> bytes:
if max_width == 0 and max_height == 0:
max_width = snapshot.width // 5
Expand Down
34 changes: 27 additions & 7 deletions kvmd/apps/kvmd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from typing import Dict
from typing import Set
from typing import Callable
from typing import Awaitable
from typing import Coroutine
from typing import AsyncGenerator
from typing import Optional
Expand Down Expand Up @@ -68,6 +69,7 @@
from .ugpio import UserGpio
from .streamer import Streamer
from .snapshoter import Snapshoter
from .tesseract import TesseractOcr

from .http import HttpError
from .http import HttpExposed
Expand Down Expand Up @@ -147,6 +149,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
info_manager: InfoManager,
log_reader: LogReader,
user_gpio: UserGpio,
ocr: TesseractOcr,

hid: BaseHid,
atx: BaseAtx,
Expand Down Expand Up @@ -192,6 +195,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
]

self.__hid_api = HidApi(hid, keymap_path, ignore_keys, mouse_x_range, mouse_y_range) # Ugly hack to get keymaps state
self.__streamer_api = StreamerApi(streamer, ocr) # Same hack to get ocr langs state
self.__apis: List[object] = [
self,
AuthApi(auth_manager),
Expand All @@ -201,7 +205,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self.__hid_api,
AtxApi(atx),
MsdApi(msd),
StreamerApi(streamer),
self.__streamer_api,
ExportApi(info_manager, atx, user_gpio),
RedfishApi(info_manager, atx),
]
Expand Down Expand Up @@ -251,21 +255,27 @@ async def __streamer_reset_handler(self, _: aiohttp.web.Request) -> aiohttp.web.
@exposed_http("GET", "/ws")
async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSocketResponse:
logger = get_logger(0)

client = _WsClient(
ws=aiohttp.web.WebSocketResponse(heartbeat=self.__heartbeat),
stream=valid_bool(request.query.get("stream", "true")),
)
await client.ws.prepare(request)
await self.__register_ws_client(client)

try:
await self.__send_event(client.ws, "gpio_model_state", await self.__user_gpio.get_model())
await self.__send_event(client.ws, "hid_keymaps_state", self.__hid_api.get_keymaps())
await asyncio.gather(*[
self.__send_event(client.ws, component.event_type, await component.get_state())
for component in self.__components
if component.get_state
await self.__send_events_aws(client.ws, [
("gpio_model_state", self.__user_gpio.get_model()),
("hid_keymaps_state", self.__hid_api.get_keymaps()),
("streamer_ocr_state", self.__streamer_api.get_ocr()),
])
await self.__send_events_aws(client.ws, [
(comp.event_type, comp.get_state())
for comp in self.__components
if comp.get_state
])
await self.__send_event(client.ws, "loop", {})

async for msg in client.ws:
if msg.type == aiohttp.web.WSMsgType.TEXT:
try:
Expand All @@ -282,6 +292,7 @@ async def __ws_handler(self, request: aiohttp.web.Request) -> aiohttp.web.WebSoc
logger.error("Unknown websocket event: %r", data)
else:
break

return client.ws
finally:
await self.__remove_ws_client(client)
Expand Down Expand Up @@ -380,6 +391,15 @@ async def __on_cleanup(self, _: aiohttp.web.Application) -> None:
logger.exception("Cleanup error on %s", comp.name)
logger.info("On-Cleanup complete")

async def __send_events_aws(self, ws: aiohttp.web.WebSocketResponse, sources: List[Tuple[str, Awaitable]]) -> None:
await asyncio.gather(*[
self.__send_event(ws, event_type, state)
for (event_type, state) in zip(
map(operator.itemgetter(0), sources),
await asyncio.gather(*map(operator.itemgetter(1), sources)),
)
])

async def __send_event(self, ws: aiohttp.web.WebSocketResponse, event_type: str, event: Optional[Dict]) -> None:
await ws.send_str(json.dumps({
"event_type": event_type,
Expand Down
161 changes: 161 additions & 0 deletions kvmd/apps/kvmd/tesseract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# ========================================================================== #
# #
# KVMD - The main PiKVM daemon. #
# #
# Copyright (C) 2018-2022 Maxim Devaev <[email protected]> #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# ========================================================================== #


import io
import ctypes
import ctypes.util
import contextlib
import warnings

from ctypes import POINTER
from ctypes import Structure
from ctypes import c_int
from ctypes import c_bool
from ctypes import c_char_p
from ctypes import c_void_p
from ctypes import c_char

from typing import List
from typing import Set
from typing import Generator
from typing import Optional

from PIL import Image as PilImage

from ...errors import OperationError

from ... import libc
from ... import aiotools


# =====
class OcrError(OperationError):
pass


# =====
class _TessBaseAPI(Structure):
pass


def _load_libtesseract() -> Optional[ctypes.CDLL]:
try:
path = ctypes.util.find_library("tesseract")
if not path:
raise RuntimeError("Can't find libtesseract")
lib = ctypes.CDLL(path)
for (name, restype, argtypes) in [
("TessBaseAPICreate", POINTER(_TessBaseAPI), []),
("TessBaseAPIInit3", c_int, [POINTER(_TessBaseAPI), c_char_p, c_char_p]),
("TessBaseAPISetImage", None, [POINTER(_TessBaseAPI), c_void_p, c_int, c_int, c_int, c_int]),
("TessBaseAPIGetUTF8Text", POINTER(c_char), [POINTER(_TessBaseAPI)]),
("TessBaseAPISetVariable", c_bool, [POINTER(_TessBaseAPI), c_char_p, c_char_p]),
("TessBaseAPIGetAvailableLanguagesAsVector", POINTER(POINTER(c_char)), [POINTER(_TessBaseAPI)]),
]:
func = getattr(lib, name)
if not func:
raise RuntimeError(f"Can't find libtesseract.{name}")
setattr(func, "restype", restype)
setattr(func, "argtypes", argtypes)
return lib
except Exception as err:
warnings.warn(f"Can't load libtesseract: {err}", RuntimeWarning)
return None


_libtess = _load_libtesseract()


@contextlib.contextmanager
def _tess_api(langs: List[str]) -> Generator[_TessBaseAPI, None, None]:
if not _libtess:
raise OcrError("Tesseract is not available")
api = _libtess.TessBaseAPICreate()
try:
if _libtess.TessBaseAPIInit3(api, None, "+".join(langs).encode()) != 0:
raise OcrError("Can't initialize Tesseract")
if not _libtess.TessBaseAPISetVariable(api, b"debug_file", b"/dev/null"):
raise OcrError("Can't set debug_file=/dev/null")
yield api
finally:
_libtess.TessBaseAPIDelete(api)


# =====
class TesseractOcr:
def __init__(self, default_langs: List[str]) -> None:
self.__default_langs = default_langs

def is_available(self) -> bool:
return bool(_libtess)

async def get_default_langs(self) -> List[str]:
return list(self.__default_langs)

async def get_available_langs(self) -> List[str]:
return (await aiotools.run_async(self.__inner_get_available_langs))

def __inner_get_available_langs(self) -> List[str]:
with _tess_api(["osd"]) as api:
assert _libtess
langs: Set[str] = set()
langs_ptr = _libtess.TessBaseAPIGetAvailableLanguagesAsVector(api)
if langs_ptr is not None:
index = 0
while langs_ptr[index]:
lang = ctypes.cast(langs_ptr[index], c_char_p).value
if lang is not None:
langs.add(lang.decode())
libc.free(langs_ptr[index])
index += 1
libc.free(langs_ptr)
return sorted(langs)

async def recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str:
if not langs:
langs = self.__default_langs
return (await aiotools.run_async(self.__inner_recognize, data, langs, left, top, right, bottom))

def __inner_recognize(self, data: bytes, langs: List[str], left: int, top: int, right: int, bottom: int) -> str:
with _tess_api(langs) as api:
assert _libtess
with io.BytesIO(data) as bio:
with PilImage.open(bio) as image:
if left >= 0 or top >= 0 or right >= 0 or bottom >= 0:
left = (0 if left < 0 else min(image.width, left))
top = (0 if top < 0 else min(image.height, top))
right = (image.width if right < 0 else min(image.width, right))
bottom = (image.height if bottom < 0 else min(image.height, bottom))
if left < right and top < bottom:
image.crop((left, top, right, bottom))

_libtess.TessBaseAPISetImage(api, image.tobytes("raw", "RGB"), image.width, image.height, 3, image.width * 3)
text_ptr = None
try:
text_ptr = _libtess.TessBaseAPIGetUTF8Text(api)
text = ctypes.cast(text_ptr, c_char_p).value
if text is None:
raise OcrError("Can't recognize image")
return text.decode("utf-8")
finally:
if text_ptr is not None:
libc.free(text_ptr)
Loading

0 comments on commit 3ab43ed

Please sign in to comment.