python 2.7 - AssertionError when running K means Main Function -


when running below code, receive assertionerror in main function, assert len(args) > 1. idea in code issue occurs?

k-means clustering implementation

import numpy np math import sqrt import csv import sys 

====

define function computes distance between 2 data points

gap = 2 min_val = 1000000  def get_distance(point1, point2):     dis = sqrt(pow(point1[0] - point2[0],2) + pow(point1[1] - point2[1],2))     return dis 

====

define function reads data in csv

def csvreader(data_file):     sampledata = []     global countries     open(data_file, 'r') csvfile:         read_data = csv.reader(csvfile, delimiter=' ', quotechar='|')         row in read_data:             print ', '.join(row)         if read_data <> none:             f in read_data:                 values = f.split(",")                 if values[0] <> 'countries':                     sampledata.append([values[1],values[2]])         return sampledata 

====

write initialisation procedure

def cluster_dis(centroid, cluster):     dis = 0.0     point in cluster:         dis += get_distance(centroid, point)     return dis  def update_centroids(centroids, cluster_id, cluster):     x, y = 0.0, 0.0     length = len(cluster)     if length == 0:         return     item in cluster:         x += item[0]         y += item[1]     centroids[cluster_id] = (x / length, y / length) 

====

implement k-means algorithm, using appropriate looping

def kmeans(data, k):     assert k <= len(data)      seed_ids = np.random.randint(0, len(data), k)     centroids = [data[idx] idx in seed_ids]     clusters = [[] _ in xrange(k)]     cluster_idx = [-1] * len(data)      pre_dis = 0      while true:         point_id, point in enumerate(data):             min_distance, tmp_id = min_val, -1         seed_id, seed in enumerate(centroids):             distance = get_distance(seed, point)         if distance < min_distance:             min_distance = distance             tmp_id = seed_id         if cluster_idx[point_id] != -1:             dex = clusters[cluster_idx[point_id]].index(point)         del clusters[cluster_idx[point_id]][dex]         clusters[tmp_id].append(point)         cluster_idx[point_id] = tmp_id          now_dis = 0.0         cluster_id, cluster in enumerate(clusters):             now_dis += cluster_dis(centroids[cluster_id], cluster)             update_centroids(centroids, cluster_id, cluster)          delta_dis = now_dis - pre_dis         pre_dis = now_dis          if delta_dis < gap:             break      print(centroids)     print(clusters)      return centroids, clusters  def main():     args = sys.argv[1:]     assert len(args) > 1     data_file, k = args[0], int(args[1])      data = csvreader(data_file)     kmeans(data, k)  if __name__ == '__main__':     main() 


Comments