Skip to content

Commit 4851942

Browse files
NumberPiOsodhruvmanilagithub-actions
authored
Reduce complexity linear_discriminant_analysis. (TheAlgorithms#2452)
* Reduce complexity linear_discriminant_analysis. * Fix whitespace * Update machine_learning/linear_discriminant_analysis.py Co-authored-by: Dhruv Manilawala <[email protected]> * fixup! Format Python code with psf/black push * Fix format to surpass pre-commit tests * updating DIRECTORY.md * Update machine_learning/linear_discriminant_analysis.py Co-authored-by: Dhruv Manilawala <[email protected]> * fixup! Format Python code with psf/black push Co-authored-by: Dhruv Manilawala <[email protected]> Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
1 parent a6ad25c commit 4851942

File tree

1 file changed

+65
-65
lines changed

1 file changed

+65
-65
lines changed

machine_learning/linear_discriminant_analysis.py

+65-65
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Linear Discriminant Analysis
33
44
5+
56
Assumptions About Data :
67
1. The input variables has a gaussian distribution.
78
2. The variance calculated for each input variables by class grouping is the
@@ -44,6 +45,7 @@
4445
from math import log
4546
from os import name, system
4647
from random import gauss, seed
48+
from typing import Callable, TypeVar
4749

4850

4951
# Make a training dataset drawn from a gaussian distribution
@@ -245,6 +247,40 @@ def accuracy(actual_y: list, predicted_y: list) -> float:
245247
return (correct / len(actual_y)) * 100
246248

247249

250+
num = TypeVar("num")
251+
252+
253+
def valid_input(
254+
input_type: Callable[[object], num], # Usually float or int
255+
input_msg: str,
256+
err_msg: str,
257+
condition: Callable[[num], bool] = lambda x: True,
258+
default: str = None,
259+
) -> num:
260+
"""
261+
Ask for user value and validate that it fulfill a condition.
262+
263+
:input_type: user input expected type of value
264+
:input_msg: message to show user in the screen
265+
:err_msg: message to show in the screen in case of error
266+
:condition: function that represents the condition that user input is valid.
267+
:default: Default value in case the user does not type anything
268+
:return: user's input
269+
"""
270+
while True:
271+
try:
272+
user_input = input_type(input(input_msg).strip() or default)
273+
if condition(user_input):
274+
return user_input
275+
else:
276+
print(f"{user_input}: {err_msg}")
277+
continue
278+
except ValueError:
279+
print(
280+
f"{user_input}: Incorrect input type, expected {input_type.__name__!r}"
281+
)
282+
283+
248284
# Main Function
249285
def main():
250286
""" This function starts execution phase """
@@ -254,87 +290,51 @@ def main():
254290
print("First of all we should specify the number of classes that")
255291
print("we want to generate as training dataset")
256292
# Trying to get number of classes
257-
n_classes = 0
258-
while True:
259-
try:
260-
user_input = int(
261-
input("Enter the number of classes (Data Groupings): ").strip()
262-
)
263-
if user_input > 0:
264-
n_classes = user_input
265-
break
266-
else:
267-
print(
268-
f"Your entered value is {user_input} , Number of classes "
269-
f"should be positive!"
270-
)
271-
continue
272-
except ValueError:
273-
print("Your entered value is not numerical!")
293+
n_classes = valid_input(
294+
input_type=int,
295+
condition=lambda x: x > 0,
296+
input_msg="Enter the number of classes (Data Groupings): ",
297+
err_msg="Number of classes should be positive!",
298+
)
274299

275300
print("-" * 100)
276301

277-
std_dev = 1.0 # Default value for standard deviation of dataset
278302
# Trying to get the value of standard deviation
279-
while True:
280-
try:
281-
user_sd = float(
282-
input(
283-
"Enter the value of standard deviation"
284-
"(Default value is 1.0 for all classes): "
285-
).strip()
286-
or "1.0"
287-
)
288-
if user_sd >= 0.0:
289-
std_dev = user_sd
290-
break
291-
else:
292-
print(
293-
f"Your entered value is {user_sd}, Standard deviation should "
294-
f"not be negative!"
295-
)
296-
continue
297-
except ValueError:
298-
print("Your entered value is not numerical!")
303+
std_dev = valid_input(
304+
input_type=float,
305+
condition=lambda x: x >= 0,
306+
input_msg=(
307+
"Enter the value of standard deviation"
308+
"(Default value is 1.0 for all classes): "
309+
),
310+
err_msg="Standard deviation should not be negative!",
311+
default="1.0",
312+
)
299313

300314
print("-" * 100)
301315

302316
# Trying to get number of instances in classes and theirs means to generate
303317
# dataset
304318
counts = [] # An empty list to store instance counts of classes in dataset
305319
for i in range(n_classes):
306-
while True:
307-
try:
308-
user_count = int(
309-
input(f"Enter The number of instances for class_{i+1}: ")
310-
)
311-
if user_count > 0:
312-
counts.append(user_count)
313-
break
314-
else:
315-
print(
316-
f"Your entered value is {user_count}, Number of "
317-
"instances should be positive!"
318-
)
319-
continue
320-
except ValueError:
321-
print("Your entered value is not numerical!")
320+
user_count = valid_input(
321+
input_type=int,
322+
condition=lambda x: x > 0,
323+
input_msg=(f"Enter The number of instances for class_{i+1}: "),
324+
err_msg="Number of instances should be positive!",
325+
)
326+
counts.append(user_count)
322327
print("-" * 100)
323328

324329
# An empty list to store values of user-entered means of classes
325330
user_means = []
326331
for a in range(n_classes):
327-
while True:
328-
try:
329-
user_mean = float(
330-
input(f"Enter the value of mean for class_{a+1}: ")
331-
)
332-
if isinstance(user_mean, float):
333-
user_means.append(user_mean)
334-
break
335-
print(f"You entered an invalid value: {user_mean}")
336-
except ValueError:
337-
print("Your entered value is not numerical!")
332+
user_mean = valid_input(
333+
input_type=float,
334+
input_msg=(f"Enter the value of mean for class_{a+1}: "),
335+
err_msg="This is an invalid value.",
336+
)
337+
user_means.append(user_mean)
338338
print("-" * 100)
339339

340340
print("Standard deviation: ", std_dev)

0 commit comments

Comments
 (0)