Skip to content

Commit

Permalink
Update demo
Browse files Browse the repository at this point in the history
  • Loading branch information
stansf committed May 28, 2022
1 parent d5983d9 commit bf0b278
Showing 1 changed file with 111 additions and 54 deletions.
165 changes: 111 additions & 54 deletions demo/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import base64
import yaml
import colorsys
import warnings

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from PIL.ImageDraw import ImageDraw
import streamlit as st
from st_aggrid import AgGrid, GridOptionsBuilder
from models import create_df, get_path, get_points, WalrusCoord

from requests_toolbelt import MultipartEncoder
import requests
Expand Down Expand Up @@ -48,56 +51,42 @@ def get_mask(b64mask, width, height) -> np.ndarray:
}
BBOX_WIDTH = 2
CIRCLE_WIDTH = 2
POLY_WIDTH = 1
COEF = 0.002




def visualize_results(
image: Image,
json_results,
draw_boxes=False,
draw_centers=True,
draw_mask=False
draw_centers,
draw_polygons
):
if draw_mask:
# TODO
pass
# mask = get_mask(
# json_results['mask']['b64mask'],
# json_results['mask']['width'],
# json_results['mask']['height']
# )
# image = apply_mask(image, mask)
draw = ImageDraw(image)
if draw_boxes:
boxes = json_results['boxes']
bbox_width = max(BBOX_WIDTH, int(image.width * COEF))

for bbox in boxes:
x1, y1, x2, y2 = bbox
x1 *= image.width
x2 *= image.width
y1 *= image.height
y2 *= image.height
draw.rectangle(
(x1, y1, x2, y2),
fill=None,
outline='white',
width=bbox_width
)
if draw_centers:
if draw_centers and 'centers' in json_results:
circle_width = max(CIRCLE_WIDTH, int(image.width * COEF))
centers = json_results['centers']
# print(centers)
for x, y in centers:

# x *= image.width
# y *= image.height
if 'classes' in json_results:
classes = json_results['classes']
else:
classes = [False] * len(centers)
for (x, y), is_young in zip(centers, classes):
draw.ellipse(
(x - circle_width, y - circle_width,
x + circle_width, y + circle_width),
fill='red',
fill='green' if is_young else 'red',
)
if draw_polygons and 'polygons' in json_results:
polygons = json_results['polygons']
if 'classes' in json_results:
classes = json_results['classes']
else:
classes = [False] * len(polygons)
width = max(POLY_WIDTH, int(image.width * COEF))
for poly, is_young in zip(polygons, classes):
draw.polygon(
poly,
outline='green' if is_young else 'red',
width=width
)
return image

Expand All @@ -106,6 +95,7 @@ def display_map():
# TODO
pass


def save_csv(image: Image, json_data: dict, dst_fname: str):
result_str = ''
for x, y in json_data.get('centers', []):
Expand All @@ -116,12 +106,68 @@ def save_csv(image: Image, json_data: dict, dst_fname: str):
f.write('x,y\n' + result_str)


from typing import List


def _prepare_coords(coords: List[WalrusCoord]) -> dict:
result = {
'centers': [],
'classes': []
}
for c in coords:
result['centers'].append((c.x, c.y))
result['classes'].append(c.is_young)
return result


def show_historic_data():
data = create_df()
gb = GridOptionsBuilder.from_dataframe(data)
gb.configure_selection(use_checkbox=True)
gridOptions = gb.build()
grid_response = AgGrid(
data,
gridOptions=gridOptions,
data_return_mode='AS_INPUT',
update_mode='MODEL_CHANGED',
fit_columns_on_grid_load=False,
theme='light', # Add theme color to the table
enable_enterprise_modules=True,
# height=350,
width='100%',
reload_data=False # True
)
load_btn = st.button('Загрузить')
image = None
container = st.empty()
path = ''
coords = []
selected = grid_response['selected_rows']
if len(selected) > 0:
path = get_path(selected[0]['id'])

if load_btn and path != '':
print(path)
image = Image.open(path)
coords = get_points(selected[0]['id'])
st.text(f'Кол-во моржей на фото: {len(coords)}')
with container:
if image is not None:
coords_json = _prepare_coords(coords)
image = visualize_results(
image,
coords_json,
draw_centers=True,
draw_polygons=False
)
st.image(image)


def layout():
st.sidebar.markdown('---')
st.sidebar.subheader('Параметры визуализации')
draw_boxes = st.sidebar.checkbox('Рамки', value=True, disabled=True)
draw_centers = st.sidebar.checkbox('Центры', value=True, disabled=True)
# draw_mask = st.sidebar.checkbox('Маска сегментации', value=True)
# st.sidebar.subheader('Параметры визуализации')
# draw_poly = st.sidebar.checkbox('Объекты', value=True, disabled=False)
# draw_centers = st.sidebar.checkbox('Центры', value=True, disabled=False)

image_uploaded = st.file_uploader(
'Загрузите изображение',
Expand All @@ -130,6 +176,7 @@ def layout():
st.subheader('Изображение')
container = st.empty()
if image_uploaded is not None:
# Image.open(image_uploaded).save(image_uploaded.name)
# mp_encoder = MultipartEncoder(
# fields={
# 'image': (
Expand All @@ -143,17 +190,16 @@ def layout():
# json_results = response.json()
# json_results = MOCK_JSON

image = Image.open(image_uploaded)
with st.spinner('Выполняется обработка...'):
json_results = predict(np.array(image))
image, json_results = process_image(image_uploaded)

fname = 'results.csv'

save_csv(image, json_results, fname)

with container:
vis_image = visualize_results(
image, json_results, draw_boxes=False, draw_centers=True, draw_mask=False
image, json_results, draw_centers=True, draw_polygons=True
)
st.image(vis_image)

Expand All @@ -168,16 +214,29 @@ def layout():
)



def process_image(image_uploaded: str):
image = Image.open(image_uploaded)
json_results = predict(np.array(image))
return image, json_results


def predict(image: np.ndarray):
# return {'centers': mock_predict()}
wri = WindowReadyImage(image, model)
return {'centers': wri.get_points(), 'boxes': []}
polygons = [
np.array(det.poly.exterior.xy).T.astype(int).ravel().tolist()
for det in wri.detections
]

return {
'centers': wri.get_points(),
'boxes': [],
'polygons': polygons
}


def mock_predict(*args):
return np.array([
# [0.1, 0.2],
# [0.3, 0.3]
[100, 200],
[250, 500]
])
Expand All @@ -186,21 +245,19 @@ def mock_predict(*args):
def main():
st.set_page_config(
page_title='Мониторинг популяции ненецких моржей',
# page_icon='./walrus-icon.png', # config['icon'],
layout='wide'
)
# print(config)
st.title('Мониторинг популяции ненецких моржей')
st.sidebar.image('icon.png')
# st.sidebar.subheader('Выберите режим работы')
# radio = st.sidebar.radio(
# '', options=['Одно изображение', 'Несколько изображений']
# '', options=['Обработать изображение', 'Исторические данные']
# )
# if radio == 'Обработать изображение':
layout()
# else:
# show_historic_data()


if __name__ == '__main__':
main()
# img = cv2.imread('../../DJI_0005 (2).jpg')
# h, w = img.shape[:2]
# print(predict(img)[0].shape)

0 comments on commit bf0b278

Please sign in to comment.