Problem of The Day: K Closest Points to Origin
Problem Statement
Intuition
The goal is to find the k points closest to the origin. A point \((x, y)\) lies at Euclidean distance \(\sqrt{x^2 + y^2}\) from the origin. Because the square root is monotonically increasing, comparing the squared distances \(x^2 + y^2\) gives the same ordering, so the square root never needs to be computed.
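As a quick sanity check of that claim, the sketch below sorts a handful of made-up points once by squared distance and once by true distance. The sample points and the helper name `squared_distance` are illustrative assumptions, not part of the original solution.

```python
import math

# Hypothetical sample points, chosen so that all distances are distinct
points = [[3, 3], [-2, 4], [1, 3], [5, -1]]


def squared_distance(point):
    """Squared Euclidean distance from the origin (no square root)."""
    x, y = point
    return x * x + y * y


# Order by squared distance, then by the true distance
by_squared = sorted(points, key=squared_distance)
by_true = sorted(points, key=lambda p: math.hypot(p[0], p[1]))

# Both orderings agree, so the square root is not needed for comparisons
print(by_squared == by_true)
```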
Approach
- Use a max-heap to keep track of the k closest points.
- Iterate over all points and compute the squared distance of each one.
- Push the negated distance together with the point onto the heap. Python's heapq is a min-heap, so negating the distances keeps the largest distance at the top of the heap.
- If the size of the heap exceeds k, pop the top element, which is the point with the largest distance.
- Once all points are processed, return the points in the heap.
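The negation trick in the steps above can be seen in isolation on plain numbers. This is a minimal sketch, assuming a hypothetical helper `k_smallest` that keeps the k smallest values of a list; it is not the solution itself, only the heap mechanics.

```python
import heapq


def k_smallest(values, k):
    """Keep the k smallest values using a size-bounded, negated min-heap."""
    heap = []  # stores negated values, so heap[0] is -(largest value kept)
    for v in values:
        heapq.heappush(heap, -v)
        if len(heap) > k:
            # Popping the smallest negated value removes the largest original
            heapq.heappop(heap)
    # Undo the negation before returning
    return sorted(-v for v in heap)


print(k_smallest([18, 26, 20, 10, 5], 2))
```

The heap never grows beyond k + 1 entries, which is where the \(O(\log k)\) cost per point comes from.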
Complexity
- Time complexity: \(O(n \log k)\), where n is the number of points. For each point, the heap insertion/removal takes \(O(\log k)\).
- Space complexity: \(O(k)\), as we maintain a heap of size k.
Code
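import heapq
from typing import List


class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # Max-heap of the k closest points, simulated with negated distances
        max_heap = []
        for x, y in points:
            distance = x * x + y * y
            heapq.heappush(max_heap, (-distance, [x, y]))
            # Evict the farthest point once the heap holds more than k entries
            if len(max_heap) > k:
                heapq.heappop(max_heap)
        # Drop the distances and return only the points
        return [point for _, point in max_heap]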
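For comparison, the standard library's `heapq.nsmallest` selects the k smallest items by a key in a single call. This is a different route to the same result, not the solution above; the function name `k_closest` is an assumption made for the sketch.

```python
import heapq
from typing import List


def k_closest(points: List[List[int]], k: int) -> List[List[int]]:
    """Return the k points nearest the origin, closest first."""
    # nsmallest returns the k items with the smallest key, in ascending order
    return heapq.nsmallest(k, points, key=lambda p: p[0] * p[0] + p[1] * p[1])


print(k_closest([[3, 3], [5, -1], [-2, 4]], 2))
```

Unlike the hand-written heap, this version returns the points ordered by distance, which the problem does not require but which can be convenient when checking results.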
Editorial
Approach 1: Sort with Custom Comparator
from typing import List


class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # Sort the points in place, using the squared distance as the sort key
        points.sort(key=self.squared_distance)
        # Return the first k elements of the sorted list
        return points[:k]

    def squared_distance(self, point: List[int]) -> int:
        """Calculate and return the squared Euclidean distance."""
        return point[0] ** 2 + point[1] ** 2
- time: O(N log N)
- space: O(log N) to O(N), depending on the sorting algorithm (Python's Timsort needs O(N) in the worst case)
Approach 2: Max Heap or Max Priority Queue
import heapq
from typing import List


class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # heapq is a min-heap, so negate the distances to simulate a max-heap,
        # and fill the heap with the first k elements of points
        heap = [(-self.squared_distance(points[i]), i) for i in range(k)]
        heapq.heapify(heap)
        for i in range(k, len(points)):
            dist = -self.squared_distance(points[i])
            if dist > heap[0][0]:
                # If this point is closer than the farthest of the k kept
                # so far, discard that farthest point and add this one
                heapq.heappushpop(heap, (dist, i))
        # Return all points stored in the max heap
        return [points[i] for (_, i) in heap]

    def squared_distance(self, point: List[int]) -> int:
        """Calculate and return the squared Euclidean distance."""
        return point[0] ** 2 + point[1] ** 2
- time: O(N log k)
- space: O(k)
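The key call in this approach is `heapq.heappushpop`, which pushes an item and then pops and returns the smallest one. The sketch below shows its behavior on negated distances; the numbers are made up for illustration.

```python
import heapq

# Negated squared distances 18 and 20; the farthest point (20) sits at heap[0]
heap = [-20, -18]
heapq.heapify(heap)

# A closer candidate (distance 10) replaces the farthest entry
evicted = heapq.heappushpop(heap, -10)

# A farther candidate (distance 50) is pushed and immediately popped again,
# so the heap is left unchanged
rejected = heapq.heappushpop(heap, -50)

print(evicted, rejected, sorted(heap))
```

Since a farther candidate is returned straight away, the `if dist > heap[0][0]` guard in the solution is an optimization that skips the call rather than a correctness requirement.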
Approach 3: Binary Search
from typing import List


class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # Precompute the squared Euclidean distance for each point
        distances = [self.euclidean_distance(point) for point in points]
        # Create a reference list of point indices
        remaining = list(range(len(points)))
        # Define the initial binary search range
        low, high = 0, max(distances)
        # Perform a binary search of the distances
        # to find the k closest points.
        # Note: this terminates only if the k-th closest distance
        # is not tied with the (k+1)-th one.
        closest = []
        while k:
            mid = (low + high) / 2
            closer, farther = self.split_distances(remaining, distances, mid)
            if len(closer) > k:
                # If more than k points are in the closer distances
                # then discard the farther points and continue
                remaining = closer
                high = mid
            else:
                # Add the closer points to the answer array and keep
                # searching the farther distances for the remaining points
                k -= len(closer)
                closest.extend(closer)
                remaining = farther
                low = mid
        # Return the k closest points using the reference indices
        return [points[i] for i in closest]

    def split_distances(self, remaining: List[int], distances: List[float],
                        mid: float) -> List[List[int]]:
        """Split the remaining indices around the midpoint distance
        and return them in separate lists."""
        closer, farther = [], []
        for index in remaining:
            if distances[index] <= mid:
                closer.append(index)
            else:
                farther.append(index)
        return [closer, farther]

    def euclidean_distance(self, point: List[int]) -> float:
        """Calculate and return the squared Euclidean distance."""
        return point[0] ** 2 + point[1] ** 2
- time: O(n) on average, assuming each iteration discards about half of the remaining points
- space: O(n)
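To see how the search window narrows, here is a small trace of the same idea on plain distance values. It is a sketch under the assumption that all distances are distinct (ties at the cut-off would prevent termination), and `trace_binary_search` is a hypothetical helper that also records the midpoint and the size of the closer group at each step.

```python
from typing import List, Tuple


def trace_binary_search(distances: List[float],
                        k: int) -> Tuple[List[int], List[Tuple[float, int]]]:
    """Return indices of the k smallest distances plus a per-step trace.

    Assumes all distances are distinct.
    """
    remaining = list(range(len(distances)))
    low, high = 0.0, float(max(distances))
    closest: List[int] = []
    steps: List[Tuple[float, int]] = []
    while k > 0:
        mid = (low + high) / 2
        closer = [i for i in remaining if distances[i] <= mid]
        farther = [i for i in remaining if distances[i] > mid]
        steps.append((mid, len(closer)))
        if len(closer) > k:
            # Too many candidates: keep only the closer half of the range
            remaining, high = closer, mid
        else:
            # All closer candidates belong to the answer; search the rest
            k -= len(closer)
            closest.extend(closer)
            remaining, low = farther, mid
    return closest, steps


closest, steps = trace_binary_search([1, 4, 9, 16, 25], 2)
print(closest, steps)
```

On this input the first midpoint, 12.5, leaves three candidates, which is more than k, so the upper bound drops; the second midpoint, 6.25, leaves exactly two and the search stops.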