Last active
May 13, 2020 12:15
-
-
Save hsuRush/4bf21fff72f00cb7e37c2af3f572a86e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import glob | |
import xml.etree.ElementTree as ET | |
import numpy as np | |
# The best result is written to './k_means_anchor'.

# Directory containing the Pascal-VOC style *.xml annotation files.
path_to_dataset = "/data1000G/steven/ML_PLATE/data/train/labels_voc/"
# Number of anchors (k) computed by k-means.
CLUSTERS = 9
# Expected image size in pixels; also used to convert normalized anchor
# sizes back to pixels.
HEIGHT = 240
WIDTH = 320
# Only objects whose <name> is in this list are used.
classes = ["apple", "banana", "pineapple"]
# Run k-means this many times and keep the run with the best average IoU.
# Every run processes the whole dataset, so keep this low for large datasets.
times = 50
def iou(box, clusters):
    """
    Calculate the Intersection over Union (IoU) between one box and k clusters.

    Boxes are described only by (width, height), i.e. they are compared as if
    their top-left corners were aligned at the origin, so the intersection is
    min(width) * min(height).

    :param box: tuple or array (width, height)
    :param clusters: array-like of shape (k, 2) with (width, height) rows
    :return: numpy array of shape (k,) with the IoU against every cluster
    :raises ValueError: if the intersection with any cluster has a zero side,
        i.e. the box or one of the clusters has no area
    """
    clusters = np.asarray(clusters)
    x = np.minimum(clusters[:, 0], box[0])
    y = np.minimum(clusters[:, 1], box[1])
    # A zero side would yield IoU 0 (or 0/0 when both areas are 0); reject it
    # instead of silently producing a meaningless distance.
    if np.any(x == 0) or np.any(y == 0):
        raise ValueError("Box has no area")
    intersection = x * y
    box_area = box[0] * box[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]
    return intersection / (box_area + cluster_area - intersection)
def avg_iou(boxes, clusters):
    """
    Average, over all boxes, of each box's best IoU with any cluster.

    :param boxes: numpy array of shape (r, 2) with (width, height) rows
    :param clusters: numpy array of shape (k, 2) with (width, height) rows
    :return: mean best-match IoU as a single float
    """
    best_match_per_box = [iou(box, clusters).max() for box in boxes]
    return np.mean(best_match_per_box)
def translate_boxes(boxes):
    """
    Translate corner-format boxes to the origin, keeping only their size.

    Each row (x1, y1, x2, y2, ...) becomes (|x2 - x1|, |y2 - y1|, ...); any
    columns after the fourth are carried over unchanged. The input is not
    modified.

    :param boxes: array-like of shape (r, 4) (or wider)
    :return: numpy array of shape (r, 2) (or (r, c - 2) for wider input),
        with the same dtype as the input
    """
    boxes = np.asarray(boxes)
    # Slicing a copy of columns 2.. keeps dtype and any trailing columns,
    # exactly like deleting columns 0 and 1 did; then overwrite the sizes.
    sizes = boxes[:, 2:].copy()
    sizes[:, :2] = np.abs(boxes[:, 2:4] - boxes[:, 0:2])
    return sizes
def kmeans(boxes, k, dist=np.median, seed=None):
    """
    Cluster (width, height) boxes with k-means, using 1 - IoU as the distance.

    The initial centres are k distinct boxes picked at random (Forgy method).
    Each iteration assigns every box to the centre with the highest IoU, then
    moves each centre to ``dist`` of its members. Iteration stops when no
    assignment changes.

    :param boxes: numpy array of shape (r, 2) with (width, height) rows
    :param k: number of clusters; must not exceed r
    :param dist: reduction that recomputes a centre from its members, called
        as ``dist(members, axis=0)`` (median by default)
    :param seed: optional value passed to ``np.random.seed``; the default None
        reseeds from OS entropy as before, so every run differs
    :return: numpy array of shape (k, 2) with the cluster centres
    :raises ValueError: if k exceeds the number of boxes, or (from ``iou``)
        if a box or centre has zero width/height
    """
    boxes = np.asarray(boxes)
    rows = boxes.shape[0]
    if k > rows:
        raise ValueError(
            "Cannot pick {} initial clusters from {} boxes".format(k, rows))
    distances = np.empty((rows, k))
    # -1 never equals a real cluster index, so the first assignment can never
    # be mistaken for "converged" (a zeros sentinel stopped the loop before any
    # update whenever every box first landed in cluster 0).
    last_clusters = np.full(rows, -1)
    np.random.seed(seed)
    # If many rows are identical, several initial centres can coincide; the
    # duplicates then attract no boxes (argmin picks the first tie) and are
    # kept unchanged by the empty-cluster rule below.
    clusters = boxes[np.random.choice(rows, k, replace=False)]
    # Keep centres in floating point so the median of integer sizes is not
    # truncated when assigned back into the array.
    if not np.issubdtype(clusters.dtype, np.floating):
        clusters = clusters.astype(np.float64)
    while True:
        for row in range(rows):
            distances[row] = 1 - iou(boxes[row], clusters)
        nearest_clusters = np.argmin(distances, axis=1)
        if (last_clusters == nearest_clusters).all():
            break
        for cluster in range(k):
            members = boxes[nearest_clusters == cluster]
            # dist() of an empty selection is NaN (median of an empty slice),
            # which would poison every later IoU; keep the previous centre.
            if len(members):
                clusters[cluster] = dist(members, axis=0)
        last_clusters = nearest_clusters
    return clusters
def load_dataset(path, class_names=None, expected_size=None):
    """
    Collect normalized (width, height) sizes of labelled boxes from VOC XML files.

    Every ``*.xml`` file directly inside ``path`` is parsed, in sorted order.
    Files whose image size differs from ``expected_size`` are reported and
    skipped. Objects whose ``<name>`` is not in ``class_names``, or that are
    marked ``<difficult>1</difficult>``, are ignored. Boxes with non-positive
    width or height are also reported and ignored. Box sizes are divided by
    the image's own width/height.

    :param path: directory containing the annotation files
    :param class_names: classes to keep; defaults to the module-level ``classes``
    :param expected_size: (height, width) every image must have; defaults to
        the module-level (HEIGHT, WIDTH)
    :return: float64 numpy array of shape (n, 2); shape (0, 2) if no box was
        found
    """
    if class_names is None:
        class_names = classes
    if expected_size is None:
        expected_size = (HEIGHT, WIDTH)
    expected_height, expected_width = expected_size

    dataset = []
    # sorted() makes the row order, and hence seeded k-means runs, reproducible
    for xml_file in sorted(glob.glob("{}/*.xml".format(path))):
        tree = ET.parse(xml_file)
        height = int(tree.findtext("./size/height"))
        width = int(tree.findtext("./size/width"))
        # Either dimension differing is a mismatch. Skip just this file rather
        # than aborting the scan and silently returning a partial dataset.
        if height != expected_height or width != expected_width:
            print("{}: image size {}x{} does not match expected {}x{}, skipped"
                  .format(xml_file, width, height,
                          expected_width, expected_height))
            continue
        for obj in tree.iter("object"):
            cls = (obj.findtext("name") or "").strip()
            # <difficult> is optional; a missing or empty tag counts as 0
            difficult = (obj.findtext("difficult") or "").strip() or "0"
            if cls not in class_names or int(difficult) == 1:
                continue
            # float() also accepts the fractional coordinates some tools write
            xmin = float(obj.findtext("bndbox/xmin"))
            ymin = float(obj.findtext("bndbox/ymin"))
            xmax = float(obj.findtext("bndbox/xmax"))
            ymax = float(obj.findtext("bndbox/ymax"))
            box_w = (xmax - xmin) / width
            box_h = (ymax - ymin) / height
            if box_w <= 0 or box_h <= 0:
                print(xml_file, "w or h is not positive")
                continue
            dataset.append([box_w, box_h])
    # reshape keeps the (n, 2) contract even when nothing was collected
    return np.array(dataset, dtype=np.float64).reshape(-1, 2)
def _format_anchors(anchors):
    """Render [[w, h], ...] as 'w,h, w,h, ...' (one line, no trailing separator)."""
    return ", ".join("{},{}".format(w, h) for w, h in anchors)


def main():
    """
    Run k-means ``times`` times on the dataset and save the best anchors.

    Each run's average IoU ("Accuracy") and anchors are printed. The run with
    the highest average IoU is printed again and written to ./k_means_anchor.
    Anchors are converted to pixels with WIDTH/HEIGHT and sorted by area,
    smallest first, because k-means returns its centres in random order.
    """
    data = load_dataset(path_to_dataset)
    if len(data) < CLUSTERS:
        raise SystemExit("Need at least {} boxes for k-means, found {}"
                         .format(CLUSTERS, len(data)))

    best_acc = 0
    best_anchor = None
    for i in range(times):
        print('the ', i, ' times')
        out = kmeans(data, k=CLUSTERS)
        # avg_iou touches every box/cluster pair; compute it once per run
        acc = avg_iou(data, out) * 100
        print("Accuracy: {:.2f}%".format(acc))
        anchors_for_yolo = sorted(
            ([int(round(w * WIDTH)), int(round(h * HEIGHT))] for w, h in out),
            key=lambda anchor: anchor[0] * anchor[1])
        print('anchors = ' + _format_anchors(anchors_for_yolo))
        # Record the anchors together with their score, instead of re-deriving
        # "is this the best run?" through float equality afterwards.
        if acc > best_acc:
            best_acc = acc
            best_anchor = anchors_for_yolo

    if best_anchor is None:
        raise SystemExit("No k-means run produced anchors (is times > 0?)")

    summary = "Accuracy: {:.2f}%".format(best_acc)
    anchors_line = 'anchors = ' + _format_anchors(best_anchor)
    with open('k_means_anchor', 'w') as f:
        print(summary, file=f)
        print(anchors_line, file=f)
    print(summary)
    print(anchors_line)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment