3 minute read

Problem Statement

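Given an array `points` where `points[i] = [x_i, y_i]` is a point on the X-Y plane, and an integer `k`, return the `k` points closest to the origin (0, 0). Distance is measured with the Euclidean distance \(\sqrt{x^2 + y^2}\). The answer may be returned in any order and is guaranteed to be unique (apart from ordering).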

Intuition

The goal is to find the k closest points to the origin. The distance of a point (x, y) from the origin is the Euclidean distance \(\sqrt{x^2 + y^2}\). The square root is monotonically increasing, so comparing the squared distances \(x^2 + y^2\) gives exactly the same ordering. We can therefore skip the square root and avoid unnecessary (floating-point) computation.

Approach

  • Use a max-heap to keep track of the k closest points.
  • Iterate over all points, calculate the squared distance for each point.
  • Push the negated distance together with the point into the heap. Python's heapq is a min-heap, so storing negated distances places the point with the largest distance at the root (the top) of the heap.
  • If the size of the heap exceeds k, pop the element with the largest distance.
  • Once all points are processed, return the points in the heap.

Complexity

  • Time complexity: \(O(n \log k)\), where n is the number of points. For each point, the heap insertion/removal takes \(O(\log k)\).

  • Space complexity: \(O(k)\), as we maintain a heap of size k.

Code

import heapq

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """Return the k points nearest the origin, in heap order.

        A bounded heap of at most k entries is maintained.  heapq is a
        min-heap, so each entry stores the *negated* squared distance: the
        farthest retained point then sits at the root and is the one evicted
        whenever the heap grows past k.
        """
        candidates = []
        for px, py in points:
            # Squared distance preserves the ordering, so no sqrt is needed.
            neg_dist = -(px * px + py * py)
            heapq.heappush(candidates, (neg_dist, [px, py]))
            if len(candidates) > k:
                # Evict the current farthest point (smallest negated value).
                heapq.heappop(candidates)
        return [coords for _, coords in candidates]

Editorial

Approach 1: Sort with a Custom Key Function

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """Return the k points nearest the origin.

        Side effect: ``points`` itself is reordered in place by ascending
        squared distance.  list.sort is stable, so equally distant points
        keep their input order.
        """
        by_distance = self.squared_distance
        points.sort(key=by_distance)
        nearest = points[:k]
        return nearest

    def squared_distance(self, point: List[int]) -> int:
        """Return x*x + y*y for ``point``; ranking by this avoids a sqrt."""
        x, y = point[0], point[1]
        return x * x + y * y
  • time: O(N log N)
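  • space: O(N) in Python, because `list.sort(key=...)` stores one key per element and Timsort may need up to N/2 temporary slots. The O(log N) bound applies only to in-place sorts in other languages.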

Approach 2: Max Heap or Max Priority Queue

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """Return the k points closest to the origin, in no particular order.

        heapq only provides a min-heap, so squared distances are stored
        negated.  The heap root is then the farthest of the points kept so
        far, and it is evicted whenever a closer point is seen.

        Returns [] when k <= 0 (or ``points`` is empty), and every point when
        k >= len(points).
        """
        # Clamp k: seeding the heap with range(k) would index past the end of
        # ``points`` when k > len(points), and heap[0] below would be read
        # from an empty heap when k == 0.
        k = min(k, len(points))
        if k <= 0:
            return []

        # Seed the heap with the first k points as (negated distance, index)
        # pairs, then restore the min-heap invariant in O(k).
        heap = [(-self.squared_distance(points[i]), i) for i in range(k)]
        heapq.heapify(heap)
        for i in range(k, len(points)):
            dist = -self.squared_distance(points[i])
            if dist > heap[0][0]:
                # This point is closer than the farthest one kept so far:
                # push it and pop the farthest in a single heap operation.
                heapq.heappushpop(heap, (dist, i))

        # The heap holds the indices of the k closest points.
        return [points[i] for (_, i) in heap]

    def squared_distance(self, point: List[int]) -> int:
        """Calculate and return the squared Euclidean distance."""
        return point[0] ** 2 + point[1] ** 2
  • time: O(N log k)
  • space: O(k)

Approach 3: Binary Search
class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """Return the k points closest to the origin, in no particular order.

        Binary-searches a threshold ``mid`` on squared distance.  Each pass
        splits the still-undecided candidates into those at distance <= mid
        and those farther away:

        * If the near group fits in the remaining quota k, it is accepted and
          the search moves to larger distances.
        * Otherwise the far group is discarded and the search moves to
          smaller distances.

        Edge cases:

        * Returns [] when k <= 0 or ``points`` is empty.
        * Returns every point when k >= len(points).
        * If more candidates than needed share exactly the same distance at
          the cut-off, the earliest ones in input order are returned.
        """
        # Clamp k to the number of points.  With k > len(points) the search
        # would run out of candidates while k is still positive and never
        # terminate; with no points, max() below would raise ValueError.
        k = min(k, len(points))
        if k <= 0:
            return []

        # Squared distances, indexed like ``points``.
        distances = [self.euclidean_distance(point) for point in points]
        # Indices of points not yet accepted or discarded (kept in input order).
        remaining = list(range(len(points)))
        low, high = 0, max(distances)

        closest = []
        while k:
            mid = (low + high) / 2
            closer, farther = self.split_distances(remaining, distances, mid)
            if len(closer) > k:
                first = distances[closer[0]]
                if all(distances[i] == first for i in closer):
                    # Every candidate is equally distant, so no threshold can
                    # separate them and the search would never terminate;
                    # any k of them completes a valid answer.
                    closest.extend(closer[:k])
                    break
                # Too many near points: discard the farther ones, search lower.
                remaining = closer
                high = mid
            else:
                # Accept all near points, then look among the farther ones
                # for the rest of the quota.
                k -= len(closer)
                closest.extend(closer)
                remaining = farther
                low = mid

        # Map the chosen indices back to their points.
        return [points[i] for i in closest]

    def split_distances(self, remaining: List[int], distances: List[float],
                        mid: float) -> List[List[int]]:
        """Partition the ``remaining`` indices around the threshold ``mid``.

        Returns ``[closer, farther]``.  ``closer`` holds the indices with
        ``distances[i] <= mid`` and ``farther`` those with
        ``distances[i] > mid``.  Each list keeps the order the indices have
        in ``remaining``.
        """
        closer, farther = [], []
        for index in remaining:
            if distances[index] <= mid:
                closer.append(index)
            else:
                farther.append(index)
        return [closer, farther]

    def euclidean_distance(self, point: List[int]) -> float:
        """Return the *squared* Euclidean distance of ``point`` from the origin.

        Despite the name, no square root is taken.  Squaring preserves the
        ordering of distances, which is all the search needs.
        """
        return point[0] ** 2 + point[1] ** 2
  • time: O(n)
  • space: O(n)