Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. import json
  4. import math
  5. import random
  6. import statistics
  7. import sys
  8. BENCHMARK_JSON = 'benchmark-data.json'
  9. N_CLUSTERS = None
  10. N_VECTORS = None
  11. VECTOR_SIZE = None
  12. N_ITERATIONS = None
  13. def read_benchmark_data(filename):
  14. d = None
  15. with open(BENCHMARK_JSON, 'r', encoding="utf-8") as f:
  16. d = json.load(f)
  17. return d
  18. def init_values(data):
  19. global N_VECTORS, VECTOR_SIZE, N_CLUSTERS, N_ITERATIONS
  20. N_VECTORS = len(data)
  21. VECTOR_SIZE = len(data[0])
  22. try:
  23. N_CLUSTERS = int(sys.argv[1])
  24. N_ITERATIONS = int(sys.argv[2])
  25. except IndexError:
  26. print("Usage: {} N_CLUSTERS N_ITERATION [N_CHANGES]".format(sys.argv[0]))
  27. exit(0)
  28. def distance(u, v):
  29. return math.sqrt(sum((i - j)**2 for i, j in zip(u, v)))
  30. def is_solution_valid(s):
  31. cluster_empty = [True] * N_CLUSTERS
  32. for cluster in s:
  33. cluster_empty[cluster] = False
  34. return not bool(list(filter(None, cluster_empty)))
  35. def generate_initial_solution():
  36. s = [random.randrange(N_CLUSTERS) for _ in range(N_VECTORS)]
  37. while not is_solution_valid(s):
  38. s = [random.randrange(N_CLUSTERS) for _ in range(N_VECTORS)]
  39. return s
  40. def print_solution(s, name="Solution"):
  41. print(name + ": [")
  42. for i, cluster in enumerate(s):
  43. end = "]\n" if i + 1 == len(s) \
  44. else ",\n" if ((i + 1) % 23 == 0) else ", "
  45. print(cluster, end=end)
  46. def print_vector(v, name="Vector"):
  47. print(name + ": [")
  48. for i, attr in enumerate(v):
  49. end = "]\n" if i + 1 == len(v) \
  50. else ",\n" if ((i + 1) % 7 == 0) else ", "
  51. print("{:.5f}".format(attr), end=end)
  52. def neighbor_solution(s, n_changes=1):
  53. new_s = s[:]
  54. change_locations = get_change_locations(n_changes)
  55. # print("Changes:")
  56. for loc in change_locations:
  57. old_cluster = s[loc]
  58. new_cluster = random.randrange(N_CLUSTERS - 1)
  59. new_cluster += int(new_cluster >= old_cluster)
  60. # print("\tvector {}: ({} -> {})".format(loc, old_cluster, new_cluster))
  61. new_s[loc] = new_cluster
  62. return new_s
  63. def get_change_locations(n):
  64. locations = []
  65. for _ in range(n):
  66. loc = random.randrange(N_VECTORS)
  67. # Make sure the locations are unique
  68. while loc in locations:
  69. loc = random.randrange(N_VECTORS)
  70. locations.append(loc)
  71. return locations
  72. def objective_function(data, s, global_center):
  73. centers = calculate_centers(data, s, global_center)
  74. inter = calculate_inter(global_center, centers)
  75. intra = calculate_intra(data, s, centers)
  76. return (inter, intra, inter - intra)
  77. def calculate_inter(global_center, centers):
  78. return sum([distance(global_center, c) for c in centers]) / N_CLUSTERS
  79. def calculate_intra(data, s, centers):
  80. intras = [0.0] * N_CLUSTERS
  81. vectors_in_cluster = [0] * N_CLUSTERS
  82. for vector, cluster in zip(data, s):
  83. vectors_in_cluster[cluster] += 1
  84. intras[cluster] += distance(centers[cluster], vector)
  85. for cluster, intra in enumerate(intras):
  86. intras[cluster] = intra / vectors_in_cluster[cluster]
  87. return statistics.fmean(intras)
  88. def calculate_global_center(data):
  89. gc = [0.0] * VECTOR_SIZE
  90. for vector in data:
  91. for i, attr in enumerate(vector):
  92. gc[i] += attr
  93. for i, attr in enumerate(gc):
  94. gc[i] = attr / N_VECTORS
  95. return gc
  96. def calculate_centers(data, s, global_center):
  97. centers = [[0.0] * VECTOR_SIZE for _ in range(N_CLUSTERS)]
  98. vectors_in_cluster = [0] * N_CLUSTERS
  99. for vector, cluster in zip(data, s):
  100. vectors_in_cluster[cluster] += 1
  101. for i, attr in enumerate(vector):
  102. centers[cluster][i] += attr
  103. for cluster, vector in enumerate(centers):
  104. for i, attr in enumerate(vector):
  105. centers[cluster][i] = attr / vectors_in_cluster[cluster]
  106. return centers
  107. def main():
  108. data = read_benchmark_data(BENCHMARK_JSON)
  109. init_values(data)
  110. print("N_VECTORS", N_VECTORS)
  111. print("VECTOR_SIZE", VECTOR_SIZE)
  112. print("N_CLUSTERS", N_CLUSTERS)
  113. global_center = calculate_global_center(data)
  114. print_vector(global_center, "Global center")
  115. s0 = generate_initial_solution()
  116. print_solution(s0, "Initial solution")
  117. f0 = objective_function(data, s0, global_center)
  118. print("F = {:f} - {:f} = {:f}".format(*f0))
  119. try:
  120. n_changes = int(sys.argv[3])
  121. except IndexError:
  122. n_changes = 1
  123. for _ in range(N_ITERATIONS):
  124. s1 = neighbor_solution(s0, n_changes)
  125. # Make sure this solution we just crated is valid
  126. while not is_solution_valid(s1):
  127. print("Solution is not valid, generating another...")
  128. s1 = neighbor_solution(s0)
  129. f1 = objective_function(data, s1, global_center)
  130. if f1[2] > f0[2]:
  131. print_solution(s1, "New optimal solution")
  132. print("F = {:f} - {:f} = {:f} ({:f}% improvement)".format(*f1,
  133. math.fabs((f1[2] - f0[2]) / f0[2] * 100)))
  134. s0 = s1
  135. f0 = f1
  136. print_solution(s1, "Final solution")
  137. print("F = {:f} - {:f} = {:f}".format(*f0))
  138. if __name__ == "__main__":
  139. main()