Skip to content

Commit

Permalink
chore: code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Oct 13, 2022
1 parent 14e6101 commit 3638213
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 794 deletions.
70 changes: 0 additions & 70 deletions draw_match.py

This file was deleted.

6 changes: 2 additions & 4 deletions inference/base_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ def predict_from_df(self, df: pd.DataFrame, img: np.ndarray) -> pd.DataFrame:
TypeError
If df is not a pandas DataFrame
"""

if type(df) != pd.DataFrame:
raise TypeError("result must be a pandas DataFrame")
if not isinstance(df, pd.DataFrame):
raise TypeError("df must be a pandas DataFrame")

box_images = []

Expand Down Expand Up @@ -132,7 +131,6 @@ def accuarcy_on_folder(
List[np.ndarray]
List of the images that were misclassified
"""
# load images in array
images = []
for filename in os.listdir(folder_path):
img = cv2.imread(os.path.join(folder_path, filename))
Expand Down
61 changes: 28 additions & 33 deletions inference/hsv_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,20 @@ def check_tuple_format(self, a_tuple: tuple, name: str) -> tuple:
ValueError
If tuple elements are not integers
"""
# Check upper hsv is a tuple
# Check class is a tuple
if type(a_tuple) != tuple:
raise ValueError(f"{name} must be a tuple")

# Check lower hsv is a tuple of length 3
# Check length 3
if len(a_tuple) != 3:
raise ValueError(f"{name} must be a tuple of length 3")

# Check all lower hsv tuple values are ints
# Check all values are ints
for value in a_tuple:
if type(value) != int:
raise ValueError(f"{name} values must be ints")

def check_tuple_intervals(self, a_tuple: tuple, name: str) -> tuple:
def check_tuple_intervals(self, a_tuple: tuple, name: str):
"""
Check tuple intervals
Expand All @@ -102,11 +102,6 @@ def check_tuple_intervals(self, a_tuple: tuple, name: str) -> tuple:
name : str
Name of the tuple
Returns
-------
tuple
Tuple checked
Raises
------
ValueError
Expand Down Expand Up @@ -203,7 +198,7 @@ def check_filter_format(self, filter: dict) -> dict:
ValueError
If filter does not have colors
ValueError
If filter colors is not a list
If filter colors is not a list or a tuple
"""

if type(filter) != dict:
Expand All @@ -216,8 +211,8 @@ def check_filter_format(self, filter: dict) -> dict:
if type(filter["name"]) != str:
raise ValueError("Filter name must be a string")

if type(filter["colors"]) != list:
raise ValueError("Filter colors must be a list")
if type(filter["colors"]) != list and type(filter["colors"] != tuple):
raise ValueError("Filter colors must be a list or tuple")

filter["colors"] = [
self.check_color_format(color) for color in filter["colors"]
Expand Down Expand Up @@ -300,21 +295,19 @@ def add_median_blur(self, img: np.ndarray) -> np.ndarray:
"""
return cv2.medianBlur(img, 5)

def get_img_power(self, img: np.ndarray) -> float:
def non_black_pixels_count(self, img: np.ndarray) -> float:
"""
Get image power.
Power is defined as the number of non black pixels of an image.
Returns the amount of non black pixels an image has
Parameters
----------
img : np.ndarray
Image to get power of
Image
Returns
-------
float
Image power
Count of non black pixels in img
"""
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
return cv2.countNonZero(img)
Expand All @@ -341,29 +334,31 @@ def crop_filter_and_blur_img(self, img: np.ndarray, filter: dict) -> np.ndarray:
transformed_img = self.add_median_blur(transformed_img)
return transformed_img

def set_power_in_filter(self, img: np.ndarray, filter: dict) -> dict:
def add_non_black_pixels_count_in_filter(
self, img: np.ndarray, filter: dict
) -> dict:
"""
Applies filter to image and saves the output power in the filter.
Applies filter to image and saves the number of non black pixels in the filter.
Parameters
----------
img : np.ndarray
Image to apply filter to
filter : dict
Filter to apply
Filter to apply to img
Returns
-------
dict
Filter with power
Filter with non black pixels count
"""
transformed_img = self.crop_filter_and_blur_img(img, filter)
filter["power"] = self.get_img_power(transformed_img)
filter["non_black_pixels_count"] = self.non_black_pixels_count(transformed_img)
return filter

def predict_img(self, img: np.ndarray) -> str:
"""
Gets the filter with most power on img and returns its name.
Gets the filter with most non blakc pixels on img and returns its name.
Parameters
----------
Expand All @@ -373,9 +368,7 @@ def predict_img(self, img: np.ndarray) -> str:
Returns
-------
str
Name of the filter with most power
float
Confidence of the filter with most power
Name of the filter with most non black pixels on img
"""
if img is None:
raise ValueError("Image can't be None")
Expand All @@ -384,14 +377,16 @@ def predict_img(self, img: np.ndarray) -> str:

for i, filter in enumerate(filters):
for color in filter["colors"]:
color = self.set_power_in_filter(img, color)
if "power" not in filter:
filter["power"] = 0
filter["power"] += color["power"]
color = self.add_non_black_pixels_count_in_filter(img, color)
if "non_black_pixels_count" not in filter:
filter["non_black_pixels_count"] = 0
filter["non_black_pixels_count"] += color["non_black_pixels_count"]

max_power_filter = max(filters, key=lambda x: x["power"])
max_non_black_pixels_filter = max(
filters, key=lambda x: x["non_black_pixels_count"]
)

return max_power_filter["name"]
return max_non_black_pixels_filter["name"]

def predict(self, input_image: List[np.ndarray]) -> str:
"""
Expand Down
Loading

0 comments on commit 3638213

Please sign in to comment.