diff --git a/chapter12/loaves.py b/chapter12/loaves.py index c93027e..6206f69 100644 --- a/chapter12/loaves.py +++ b/chapter12/loaves.py @@ -30,10 +30,7 @@ def knn(point: Point, neighbours): neighbour.distance = np.linalg.norm(point.array - neighbour.array) logger.debug(f"{neighbour.identifier}: {neighbour.distance}") - total = 0 - for n in sorted(neighbours, key=lambda x: x.distance)[:K]: - total += n.sold - return total / K + return sorted(neighbours, key=lambda x: x.distance)[:K] neighbours = [ @@ -47,7 +44,13 @@ neighbours = [ point = Point("T", 4, True, False) K = 4 -average_distance = knn(point, neighbours) -logger.debug(average_distance) -print(f"Number of loaves to make: {int(round(average_distance, 0))}") +k_nearest = knn(point, neighbours) + +average_sold = 0 +for n in k_nearest: + average_sold += n.sold +average_sold = average_sold / K + +logger.debug(average_sold) +print(f"Number of loaves to make: {int(round(average_sold, 0))}")