You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tabu.py 5.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. import json
  4. import math
  5. import random
  6. import statistics
  7. import sys
  8. from collections import deque
  9. BENCHMARK_JSON = 'benchmark-data.json'
  10. N_CLUSTERS = None
  11. N_VECTORS = None
  12. VECTOR_SIZE = None
  13. N_ITERATIONS = None
  14. N_NEIGHBORS = 128
  15. TABU_SIZE = 1024
  16. def read_benchmark_data(filename):
  17. d = None
  18. with open(BENCHMARK_JSON, 'r', encoding="utf-8") as f:
  19. d = json.load(f)
  20. return d
  21. def init_values(data):
  22. global N_VECTORS, VECTOR_SIZE, N_CLUSTERS, N_ITERATIONS
  23. N_VECTORS = len(data)
  24. VECTOR_SIZE = len(data[0])
  25. try:
  26. N_CLUSTERS = int(sys.argv[1])
  27. N_ITERATIONS = int(sys.argv[2])
  28. except IndexError:
  29. print("Usage: {} N_CLUSTERS N_ITERATION".format(sys.argv[0]))
  30. exit(0)
  31. def distance(u, v):
  32. return math.sqrt(sum((i - j)**2 for i, j in zip(u, v)))
  33. def is_solution_valid(s, tabu=[]):
  34. cluster_empty = [True] * N_CLUSTERS
  35. for cluster in s:
  36. cluster_empty[cluster] = False
  37. return not bool(list(filter(None, cluster_empty))) and s not in tabu
  38. def generate_initial_solution():
  39. s = [random.randrange(N_CLUSTERS) for _ in range(N_VECTORS)]
  40. while not is_solution_valid(s):
  41. s = [random.randrange(N_CLUSTERS) for _ in range(N_VECTORS)]
  42. return s
  43. def print_solution(s, name="Solution"):
  44. print(name + ": [")
  45. for i, cluster in enumerate(s):
  46. end = "]\n" if i + 1 == len(s) \
  47. else ",\n" if ((i + 1) % 23 == 0) else ", "
  48. print(cluster, end=end)
  49. def print_vector(v, name="Vector"):
  50. print(name + ": [")
  51. for i, attr in enumerate(v):
  52. end = "]\n" if i + 1 == len(v) \
  53. else ",\n" if ((i + 1) % 7 == 0) else ", "
  54. print("{:.5f}".format(attr), end=end)
  55. def neighbor_solution(s, n_changes=1):
  56. new_s = s[:]
  57. change_locations = get_change_locations(n_changes)
  58. # print("Changes:")
  59. for loc in change_locations:
  60. old_cluster = s[loc]
  61. new_cluster = random.randrange(N_CLUSTERS - 1)
  62. new_cluster += int(new_cluster >= old_cluster)
  63. # print("\tvector {}: ({} -> {})".format(loc, old_cluster, new_cluster))
  64. new_s[loc] = new_cluster
  65. return new_s
  66. def get_change_locations(n):
  67. locations = []
  68. for _ in range(n):
  69. loc = random.randrange(N_VECTORS)
  70. # Make sure the locations are unique
  71. while loc in locations:
  72. loc = random.randrange(N_VECTORS)
  73. locations.append(loc)
  74. return locations
  75. def objective_function(data, s, global_center):
  76. centers = calculate_centers(data, s, global_center)
  77. inter = calculate_inter(global_center, centers)
  78. intra = calculate_intra(data, s, centers)
  79. return (inter, intra, inter - intra)
  80. def calculate_inter(global_center, centers):
  81. return sum([distance(global_center, c) for c in centers]) / N_CLUSTERS
  82. def calculate_intra(data, s, centers):
  83. intras = [0.0] * N_CLUSTERS
  84. vectors_in_cluster = [0] * N_CLUSTERS
  85. for vector, cluster in zip(data, s):
  86. vectors_in_cluster[cluster] += 1
  87. intras[cluster] += distance(centers[cluster], vector)
  88. for cluster, intra in enumerate(intras):
  89. intras[cluster] = intra / vectors_in_cluster[cluster]
  90. return statistics.fmean(intras)
  91. def calculate_global_center(data):
  92. gc = [0.0] * VECTOR_SIZE
  93. for vector in data:
  94. for i, attr in enumerate(vector):
  95. gc[i] += attr
  96. for i, attr in enumerate(gc):
  97. gc[i] = attr / N_VECTORS
  98. return gc
  99. def calculate_centers(data, s, global_center):
  100. centers = [[0.0] * VECTOR_SIZE for _ in range(N_CLUSTERS)]
  101. vectors_in_cluster = [0] * N_CLUSTERS
  102. for vector, cluster in zip(data, s):
  103. vectors_in_cluster[cluster] += 1
  104. for i, attr in enumerate(vector):
  105. centers[cluster][i] += attr
  106. for cluster, vector in enumerate(centers):
  107. for i, attr in enumerate(vector):
  108. centers[cluster][i] = attr / vectors_in_cluster[cluster]
  109. return centers
  110. def main():
  111. data = read_benchmark_data(BENCHMARK_JSON)
  112. init_values(data)
  113. print("N_VECTORS", N_VECTORS)
  114. print("VECTOR_SIZE", VECTOR_SIZE)
  115. print("N_CLUSTERS", N_CLUSTERS)
  116. global_center = calculate_global_center(data)
  117. print_vector(global_center, "Global center")
  118. s0 = generate_initial_solution()
  119. print_solution(s0, "Initial solution")
  120. f0 = objective_function(data, s0, global_center)
  121. print("F = {:f} - {:f} = {:f}".format(*f0))
  122. tabu = deque([], TABU_SIZE)
  123. neighbors = []
  124. for _ in range(N_ITERATIONS):
  125. # Generate N_NEIGHBORS neighbor solutions of s0
  126. # TODO: make sure they're all distinct
  127. for _ in range(N_NEIGHBORS):
  128. s = neighbor_solution(s0)
  129. while not is_solution_valid(s, tabu):
  130. s = neighbor_solution(s0)
  131. f = objective_function(data, s, global_center)
  132. neighbors.append((s, f))
  133. best_neighbor = max(neighbors, key=lambda neighbor: neighbor[1][2])
  134. s, f = best_neighbor
  135. if f[2] > f0[2]:
  136. print_solution(s, "New optimal solution")
  137. print("F = {:f} - {:f} = {:f} ({:f}% improvement)"
  138. .format(*f, math.fabs((f[2] - f0[2]) / f0[2] * 100)))
  139. s0 = s
  140. f0 = f
  141. tabu.append(s)
  142. print_solution(s0, "Final solution")
  143. print("F = {:f} - {:f} = {:f}".format(*f0))
  144. if __name__ == "__main__":
  145. main()