From 2bcc8de8688503c16a4ac5c1de8f7743eafe26b3 Mon Sep 17 00:00:00 2001 From: Jihun Kim Date: Tue, 2 Jun 2020 17:31:58 +0900 Subject: [PATCH] [3] Add progress bar --- assignment3/main.py | 82 ++++++++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/assignment3/main.py b/assignment3/main.py index e4f4d88..faebd39 100644 --- a/assignment3/main.py +++ b/assignment3/main.py @@ -1,12 +1,49 @@ import argparse from typing import List -from tqdm import trange import os from multipledispatch import dispatch -import numpy as np +import numpy as np import matplotlib.pyplot as plt import seaborn as sns +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("data_path", help="path of input data file", type=str) + parser.add_argument("n", help="number of clusters for the corresponding input data", type=int) + parser.add_argument("eps", help="maximum radius of the neighborhood", type=float) + parser.add_argument("min_pts", help="minimum number of points in an Eps-neighborhood of a given point", type=int) + args = parser.parse_args() + + # Read unclustered data. + f = open(args.data_path, "r") + objects = [] + while True: + line = f.readline() + if line == '': + break + objects.append(DataObject(*line.split('\t'))) + f.close() + + # Get clustered data with DBSCAN. + clustered = DBSCAN(objects, args.eps, args.min_pts) + + # Write clustered data. + t = tqdm(desc="Writing results", total=len(clustered)-1) + for idx, cluster in enumerate(clustered): + # Last cluster is unclustered objects. + if idx == len(clustered) - 1: + break + f = open(os.path.splitext(args.data_path)[0] + "_cluster_{}.txt".format(idx), "w") + for obj in cluster: + f.write(str(obj.id) + "\n") + f.close() + t.update(1) + t.close() + + draw_scatter(clustered) class DataObject(object): @@ -54,6 +91,8 @@ def DBSCAN(data_objects: List[DataObject], eps: float, min_pts: int) \ :param min_pts: Minimum number of neighbors to be core point. :return: List of clusters which contains data objects. Last cluster contains unclustered objects. """ + t = tqdm(total=len(data_objects), desc="Clustering") + clusters = [] objects = [] for obj in data_objects: @@ -65,7 +104,7 @@ def DBSCAN(data_objects: List[DataObject], eps: float, min_pts: int) \ continue # If a cluster formed, append it. - new_cluster = form_cluster(objects, obj, eps, min_pts) + new_cluster = form_cluster(objects, obj, eps, min_pts, t) if new_cluster: clusters.append(new_cluster) @@ -73,11 +112,13 @@ def DBSCAN(data_objects: List[DataObject], eps: float, min_pts: int) \ for obj in objects: if not obj.clustered: clusters[-1].append(obj) + t.update(1) + t.close() return clusters -def form_cluster(objects: List[DBSCAN_Object], seed: DBSCAN_Object, eps: float, min_pts: int) \ +def form_cluster(objects: List[DBSCAN_Object], seed: DBSCAN_Object, eps: float, min_pts: int, t: tqdm) \ -> List[DBSCAN_Object]: # It cannot be seed unless it is dense enough. if len(seed.get_neighbors(objects, eps)) <= min_pts: @@ -90,6 +131,7 @@ def form_cluster(objects: List[DBSCAN_Object], seed: DBSCAN_Object, eps: float, queue = [seed] while queue: neighbors = queue.pop(0).get_neighbors(objects, eps) + t.update(1) if len(neighbors) <= min_pts: continue for n in neighbors: @@ -116,34 +158,4 @@ def draw_scatter(clustered: List[List[DataObject]]): if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("data_path", help="path of input data file", type=str) - parser.add_argument("n", help="number of clusters for the corresponding input data", type=int) - parser.add_argument("eps", help="maximum radius of the neighborhood", type=float) - parser.add_argument("min_pts", help="minimum number of points in an Eps-neighborhood of a given point", type=int) - args = parser.parse_args() - - # Read unclustered data. - f = open(args.data_path, "r") - objects = [] - while True: - line = f.readline() - if line == '': - break - objects.append(DataObject(*line.split('\t'))) - f.close() - - # Get clustered data with DBSCAN. - clustered = DBSCAN(objects, args.eps, args.min_pts) - - # Write clustered data. - for idx, cluster in enumerate(clustered): - # Last cluster is unclustered objects. - if idx == len(clustered) - 1: - break - f = open(os.path.splitext(args.data_path)[0] + "_cluster_{}.txt".format(idx), "w") - for obj in cluster: - f.write(str(obj.id) + "\n") - f.close() - - draw_scatter(clustered) + main() \ No newline at end of file