I'm using matplotlib to display my data set in 3D, but I'd like to show each change to the data set in a 3D chart and turn those charts into an animation (and possibly save the animation). Here is the code I'm using:
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
import random
import time
K = 3  # number of clusters (initial centroids sampled, clusters built per pass)
ATTRIBUTE_COUNT = 4  # number of numeric attributes per point
# random.seed() returns None, so binding its result (the old
# `_random = random.seed()`) only created a useless name. Called with no
# argument it seeds from OS randomness / current time; pass a fixed value
# instead for reproducible runs.
random.seed()
data = []  # loaded points, each a list of floats
data_classes = {}  # str(point) -> class label read from the data file
def manhattan_distance(point_x, point_y):
    """Return the Manhattan (L1) distance between two points.

    Args:
        point_x: Sequence of numeric coordinates.
        point_y: Sequence of numeric coordinates, same length as point_x.

    Returns:
        The sum of absolute coordinate-wise differences (0 for two empty
        points).

    Raises:
        ValueError: If the points have different lengths. The previous index
            loop silently ignored any extra coordinates in point_y.
    """
    if len(point_x) != len(point_y):
        raise ValueError(
            f'points have different lengths: {len(point_x)} and {len(point_y)}'
        )
    return sum(abs(a - b) for a, b in zip(point_x, point_y))
def get_centroid(cluster):
    """Return the attribute-wise mean of the points in *cluster*.

    Args:
        cluster: List of points, each a sequence of numbers. All points are
            expected to have the same number of attributes.

    Returns:
        A list holding the mean of each attribute, or None if *cluster* is
        empty (an empty cluster has no mean).
    """
    if not cluster:
        return None
    size = len(cluster)
    # zip(*cluster) transposes the points into per-attribute columns, so the
    # dimension comes from the data instead of the global ATTRIBUTE_COUNT
    # (which padded shorter points with 0 and broke on longer ones).
    return [sum(column) / size for column in zip(*cluster)]
# Read the data set from file. Each non-blank line is a comma-separated list
# of numeric attributes followed by a class label as the last field.
with open('iris.data', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines (such as trailing empty lines at end of file);
            # the old split(',I')[1] indexing raised IndexError on them.
            continue
        # Split off only the last field, so labels need not start with "I".
        attributes, classe_name = line.rsplit(',', 1)
        point = list(map(float, attributes.split(',')))
        data.append(point)
        # Key by str(point) so a point's label can be looked up later.
        data_classes[str(point)] = classe_name
# Show the original data set: the first three attributes give each point's
# 3D position and the fourth scales its marker size.
aux = [[point[column] for point in data] for column in range(4)]
fig = plt.figure(1)
fig.suptitle('Data Set')
# Figure.gca() no longer accepts keyword arguments such as projection
# (removed in Matplotlib 3.6); add_subplot is the supported way to get 3D axes.
ax = fig.add_subplot(111, projection='3d')
# A single fixed colour is used, so no colormap is needed. The old
# cmap=plt.hot() passed None (plt.hot() returns nothing) and, as a side
# effect, changed the global default colormap.
ax.scatter(aux[0], aux[1], aux[2], c='b', s=[25 * x for x in aux[3]])
plt.show()
# k-means: start from K randomly chosen data points as centroids, then
# alternate assignment and update steps until the centroids stop moving.
centroids = random.sample(data, K)
# Manhattan-distance assignment combined with mean updates is not guaranteed
# to converge, so cap the number of passes instead of looping forever.
MAX_ITERATIONS = 100
changed_centroid = True
iteration = 0
while changed_centroid and iteration < MAX_ITERATIONS:
    iteration += 1
    # Assignment step: put each point in the cluster of its nearest centroid.
    clusters = [[] for _ in range(K)]
    for element in data:
        distances = [manhattan_distance(element, centroid) for centroid in centroids]
        index = distances.index(min(distances))  # index of minimum distance
        clusters[index].append(element)
    # TODO: draw each cluster in a different colour here.
    # Update step: move each centroid to the mean of its cluster. An empty
    # cluster (e.g. two identical sampled points) has no mean, since
    # get_centroid returns None. Keep its previous centroid, because None
    # would crash manhattan_distance on the next pass.
    new_centroids = []
    for cluster, old_centroid in zip(clusters, centroids):
        new_centroid = get_centroid(cluster)
        new_centroids.append(old_centroid if new_centroid is None else new_centroid)
    if new_centroids == centroids:
        changed_centroid = False
    centroids = new_centroids
# TODO: draw each centroid as a triangle marker.
# Print the classification result: for every cluster, list the class label
# originally read from the file for each point assigned to that cluster.
for cluster_number, members in enumerate(clusters):
    print('classe', cluster_number, '------------------------------')
    for label in (data_classes[str(member)] for member in members):
        print(label, end=' ')
    print('\n')