Skip to content

Commit

Permalink
[3] Add progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
starpiens committed Jun 2, 2020
1 parent 002de8c commit 2bcc8de
Showing 1 changed file with 47 additions and 35 deletions.
82 changes: 47 additions & 35 deletions assignment3/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,49 @@
import argparse
from typing import List
from tqdm import trange
import os
from multipledispatch import dispatch
import numpy as np

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm


def main():
parser = argparse.ArgumentParser()
parser.add_argument("data_path", help="path of input data file", type=str)
parser.add_argument("n", help="number of clusters for the corresponding input data", type=int)
parser.add_argument("eps", help="maximum radius of the neighborhood", type=float)
parser.add_argument("min_pts", help="minimum number of points in an Eps-neighborhood of a given point", type=int)
args = parser.parse_args()

# Read unclustered data.
f = open(args.data_path, "r")
objects = []
while True:
line = f.readline()
if line == '':
break
objects.append(DataObject(*line.split('\t')))
f.close()

# Get clustered data with DBSCAN.
clustered = DBSCAN(objects, args.eps, args.min_pts)

# Write clustered data.
t = tqdm(desc="Writing results", total=len(clustered)-1)
for idx, cluster in enumerate(clustered):
# Last cluster is unclustered objects.
if idx == len(clustered) - 1:
break
f = open(os.path.splitext(args.data_path)[0] + "_cluster_{}.txt".format(idx), "w")
for obj in cluster:
f.write(str(obj.id) + "\n")
f.close()
t.update(1)
t.close()

draw_scatter(clustered)


class DataObject(object):
Expand Down Expand Up @@ -54,6 +91,8 @@ def DBSCAN(data_objects: List[DataObject], eps: float, min_pts: int) \
:param min_pts: Minimum number of neighbors to be core point.
:return: List of clusters which contains data objects. Last cluster contains unclustered objects.
"""
t = tqdm(total=len(data_objects), desc="Clustering")

clusters = []
objects = []
for obj in data_objects:
Expand All @@ -65,19 +104,21 @@ def DBSCAN(data_objects: List[DataObject], eps: float, min_pts: int) \
continue

# If a cluster formed, append it.
new_cluster = form_cluster(objects, obj, eps, min_pts)
new_cluster = form_cluster(objects, obj, eps, min_pts, t)
if new_cluster:
clusters.append(new_cluster)

clusters.append([])
for obj in objects:
if not obj.clustered:
clusters[-1].append(obj)
t.update(1)

t.close()
return clusters


def form_cluster(objects: List[DBSCAN_Object], seed: DBSCAN_Object, eps: float, min_pts: int) \
def form_cluster(objects: List[DBSCAN_Object], seed: DBSCAN_Object, eps: float, min_pts: int, t: tqdm) \
-> List[DBSCAN_Object]:
# It cannot be seed unless it is dense enough.
if len(seed.get_neighbors(objects, eps)) <= min_pts:
Expand All @@ -90,6 +131,7 @@ def form_cluster(objects: List[DBSCAN_Object], seed: DBSCAN_Object, eps: float,
queue = [seed]
while queue:
neighbors = queue.pop(0).get_neighbors(objects, eps)
t.update(1)
if len(neighbors) <= min_pts:
continue
for n in neighbors:
Expand All @@ -116,34 +158,4 @@ def draw_scatter(clustered: List[List[DataObject]]):


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("data_path", help="path of input data file", type=str)
parser.add_argument("n", help="number of clusters for the corresponding input data", type=int)
parser.add_argument("eps", help="maximum radius of the neighborhood", type=float)
parser.add_argument("min_pts", help="minimum number of points in an Eps-neighborhood of a given point", type=int)
args = parser.parse_args()

# Read unclustered data.
f = open(args.data_path, "r")
objects = []
while True:
line = f.readline()
if line == '':
break
objects.append(DataObject(*line.split('\t')))
f.close()

# Get clustered data with DBSCAN.
clustered = DBSCAN(objects, args.eps, args.min_pts)

# Write clustered data.
for idx, cluster in enumerate(clustered):
# Last cluster is unclustered objects.
if idx == len(clustered) - 1:
break
f = open(os.path.splitext(args.data_path)[0] + "_cluster_{}.txt".format(idx), "w")
for obj in cluster:
f.write(str(obj.id) + "\n")
f.close()

draw_scatter(clustered)
main()

0 comments on commit 2bcc8de

Please sign in to comment.