when running below code, receive assertionerror in main function, assert len(args) > 1. idea in code issue occurs?
k-means clustering implementation
import numpy np math import sqrt import csv import sys
====
define function computes distance between 2 data points
gap = 2 min_val = 1000000 def get_distance(point1, point2): dis = sqrt(pow(point1[0] - point2[0],2) + pow(point1[1] - point2[1],2)) return dis
====
define function reads data in csv
def csvreader(data_file): sampledata = [] global countries open(data_file, 'r') csvfile: read_data = csv.reader(csvfile, delimiter=' ', quotechar='|') row in read_data: print ', '.join(row) if read_data <> none: f in read_data: values = f.split(",") if values[0] <> 'countries': sampledata.append([values[1],values[2]]) return sampledata
====
write initialisation procedure
def cluster_dis(centroid, cluster): dis = 0.0 point in cluster: dis += get_distance(centroid, point) return dis def update_centroids(centroids, cluster_id, cluster): x, y = 0.0, 0.0 length = len(cluster) if length == 0: return item in cluster: x += item[0] y += item[1] centroids[cluster_id] = (x / length, y / length)
====
implement k-means algorithm, using appropriate looping
def kmeans(data, k): assert k <= len(data) seed_ids = np.random.randint(0, len(data), k) centroids = [data[idx] idx in seed_ids] clusters = [[] _ in xrange(k)] cluster_idx = [-1] * len(data) pre_dis = 0 while true: point_id, point in enumerate(data): min_distance, tmp_id = min_val, -1 seed_id, seed in enumerate(centroids): distance = get_distance(seed, point) if distance < min_distance: min_distance = distance tmp_id = seed_id if cluster_idx[point_id] != -1: dex = clusters[cluster_idx[point_id]].index(point) del clusters[cluster_idx[point_id]][dex] clusters[tmp_id].append(point) cluster_idx[point_id] = tmp_id now_dis = 0.0 cluster_id, cluster in enumerate(clusters): now_dis += cluster_dis(centroids[cluster_id], cluster) update_centroids(centroids, cluster_id, cluster) delta_dis = now_dis - pre_dis pre_dis = now_dis if delta_dis < gap: break print(centroids) print(clusters) return centroids, clusters def main(): args = sys.argv[1:] assert len(args) > 1 data_file, k = args[0], int(args[1]) data = csvreader(data_file) kmeans(data, k) if __name__ == '__main__': main()
Comments
Post a Comment