
preprocess.py

#!/usr/bin/python3
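"""Preprocess an ARFF-style benchmark file into a C array.

Reads BENCHMARK_FILENAME, min-max normalizes every column onto [0, 1]
(enum columns are mapped to evenly spaced values at parse time), and
writes the result to OUTPUT_FILENAME as a const double matrix.
Diagnostic dumps of the parsed attributes and data go to stdout.

As a sketch of the round trip (hypothetical file contents), an input like

    @attribute temperature numeric
    @attribute class {low,high}
    @data
    1.0,low
    3.0,high
    2.0,low

comes out as

    const double benchmark_data[3][2] = {
        {0.0, 0.0},
        {1.0, 1.0},
        {0.5, 0.0},
    };
"""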
import json
import sys

BENCHMARK_FILENAME = "benchmark-data.txt"
OUTPUT_FILENAME = "benchmark-data.c"


def parse_attribute_line(line: str) -> dict:
    words = line.split()
    assert words[0] == "@attribute"
    attribute = {"name": words[1]}
    if words[2] == "numeric":
        attribute["value-types"] = words[2]
    elif words[2].startswith('{'):
        # Strip the braces and split the comma-separated enum values.
        values = words[2].lstrip('{').rstrip().rstrip('}').split(',')
        attribute["value-types"] = "enum"
        attribute["values"] = list(enumerate(values))
        attribute["namedict"] = {}
        for number, name in attribute["values"]:
            attribute["namedict"][name] = number
        # Map each enum value onto [0, 1] in declaration order.
        x_max = len(attribute["values"]) - 1
        attribute["normalized-values"] = []
        for value, _ in attribute["values"]:
            attribute["normalized-values"].append(value / x_max)
    return attribute
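

# For illustration, a hypothetical header line such as
#   @attribute class {low,medium,high}
# parses to
#   {"name": "class", "value-types": "enum",
#    "values": [(0, "low"), (1, "medium"), (2, "high")],
#    "namedict": {"low": 0, "medium": 1, "high": 2},
#    "normalized-values": [0.0, 0.5, 1.0]}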
def parse_data(attributes: list, line: str) -> list[float]:
    parsed_data = []
    for fieldnum, field in enumerate(line.split(',')):
        attr = attributes[fieldnum]
        if attr.get("value-types") == "numeric":
            # Numeric field. Just copy it as is; we'll do the normalization
            # later. (Keeping track of the min and max values for each field
            # now would be more efficient.)
            parsed_data.append(float(field))
        elif attr.get("value-types") == "enum":
            # Get the normalized numeric value for the current symbolic field.
            numeric_value = attr["namedict"].get(field)
            assert numeric_value is not None
            parsed_data.append(attr["normalized-values"][numeric_value])
        else:
            # An attribute whose type we didn't recognize in the header.
            print("Unknown value type at field {} ({}). Line: {}"
                  .format(fieldnum, field, line))
            print("attr: ", json.dumps(attr))
    return parsed_data
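

# min_max holds two parallel lists, [mins, maxes], with one slot per column.
# Slots start out as None (set up in main), so the first data row seeds both
# bounds.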
def update_min_max(min_max: list[list[float]], parsed_line: list[float]) -> None:
    for fieldnum, field in enumerate(parsed_line):
        oldmin = min_max[0][fieldnum]
        oldmax = min_max[1][fieldnum]
        if oldmin is None or field < oldmin:
            min_max[0][fieldnum] = field
        if oldmax is None or field > oldmax:
            min_max[1][fieldnum] = field
        # min_max[0][fieldnum] = field if oldmin is None or field < oldmin else oldmin
        # min_max[1][fieldnum] = field if oldmax is None or field > oldmax else oldmax
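

# Min-max scaling: x' = (x - x_min) / (x_max - x_min), mapping each numeric
# column onto [0, 1]. Constant columns (x_min == x_max) are passed through
# only when they already fit that range.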
def normalize(data: list[list[float]],
              min_max: list[list[float]],
              attributes: list) -> list[list[float]]:
    normalized_data: list[list[float]] = []
    for line in data:
        normalized_line = []
        for fieldnum, field in enumerate(line):
            # Fields with values of type enum are already normalized, so we
            # should skip them.
            if attributes[fieldnum]["value-types"] == "enum":
                normalized_line.append(field)
                continue
            x_min = min_max[0][fieldnum]
            x_max = min_max[1][fieldnum]
            if x_min == x_max:
                # Degenerate column: every row holds the same value. Keep it
                # only if it is already within [0, 1].
                if 0 <= field <= 1:
                    normalized_line.append(field)
                else:
                    print("Problem with field {} ({}). Line: {}"
                          .format(fieldnum, field, line))
                    print("attr: ", json.dumps(attributes[fieldnum]))
                    normalized_line.append('ERROR')
                continue
            normalized_value = (field - x_min) / (x_max - x_min)
            normalized_line.append(normalized_value)
        normalized_data.append(normalized_line)
    return normalized_data
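

# The header section is parsed until the "@data" marker; every non-blank line
# after that is one comma-separated data row.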
def main():
    attributes = []
    data = []
    min_max = [None, None]
    with open(BENCHMARK_FILENAME, 'r', encoding="utf-8") as benchmark_file:
        data_started = False
        for line in benchmark_file:
            if not line.rstrip():
                continue
            if not data_started:
                if line.startswith('@'):
                    if line.startswith("@attribute "):
                        attributes.append(parse_attribute_line(line))
                    elif line.rstrip() == "@data":
                        # min_max[0] = [+float('inf')] * len(attributes)
                        # min_max[1] = [-float('inf')] * len(attributes)
                        min_max[0] = [None] * len(attributes)
                        min_max[1] = [None] * len(attributes)
                        data_started = True
                else:
                    # Should not happen: a non-'@' line in the header section.
                    print("What the hell happened here?", file=sys.stderr)
            else:
                # Data row: parse it, store it, and track per-column bounds.
                parsed_line = parse_data(attributes, line.rstrip())
                data.append(parsed_line)
                update_min_max(min_max, parsed_line)
    data = normalize(data, min_max, attributes)
    print(json.dumps(attributes, indent=4))
    print('\n' + ('-' * 76))
    print(json.dumps(data, indent=4))
    print('\n' + ('-' * 76))
    print(json.dumps(list(zip(min_max[0], min_max[1],
                              [attr["name"] for attr in attributes])), indent=4))
    with open(OUTPUT_FILENAME, 'w', encoding="utf-8") as out:
        print("const double benchmark_data[{}][{}] = {}"
              .format(len(data), len(attributes), '{'),
              file=out)
        # print('\n'.join(map(lambda l: ','.join(map(str, l)), data)), file=out)
        for line in data:
            print("\t{" + ", ".join([str(field) for field in line]) + "},", file=out)
        print("};", file=out)


if __name__ == "__main__":
    main()