5 minute read

Problem Statement

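LeetCode 726. Number of Atoms: given a chemical formula string, return every element name in sorted order, each followed by its total count when that count is greater than 1.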

Intuition

When solving the problem of counting atoms in a chemical formula, my first thought is to break down the formula by identifying individual elements and their counts. Since the formula can have nested parentheses, I’ll need to handle these cases carefully, especially when numbers outside the parentheses multiply the counts of elements inside.

Approach

I’ll use a stack to help parse the formula:

  1. Traverse the formula from the end to the beginning.
  2. Use a stack to store characters and manage groups of elements.
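  3. Handle digits: a number right after an element gives that element's count, and a number right after a closing parenthesis multiplies every count inside that group.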
  4. Use dictionaries to store the counts of elements within each group.
  5. Finally, combine all the counts and sort them lexicographically to get the required format.

Complexity

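  • Time complexity: $O(n)$ for the scan itself, where $n$ is the length of the formula string, because each character is pushed and popped a constant number of times. Nested groups add to this: every closing group re-merges and re-multiplies the atoms it contains, so a deeply nested formula costs $O(n^2)$ in the worst case, plus the final sort of the distinct element names.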
  • Space complexity: $O(n)$, as we use a stack and dictionaries to store intermediate data.

Code

class Solution:
    def countOfAtoms(self, formula: str) -> str:
        """Return the atom counts of ``formula`` as a canonical string.

        The formula is consumed right to left. Everything that is not an
        opening parenthesis is parked on ``pending``; because the tokens were
        parked in reverse, popping ``pending`` yields them in their original
        left-to-right order. Each '(' folds the tokens parked since its
        matching ')' into one count map, applies the multiplier written after
        that ')', and feeds the map back into the input so an enclosing group
        can absorb it as a single token.

        Args:
            formula: Chemical formula such as ``"K4(ON(SO3)2)2"``.

        Returns:
            Element names in sorted order, each followed by its total count
            when that count is greater than 1.
        """

        def read_number(first_digit, tokens):
            """Parse a run of digit tokens, most significant digit first."""
            value = int(first_digit)
            while tokens and isinstance(tokens[-1], str) and tokens[-1].isdigit():
                value = value * 10 + int(tokens.pop())
            return value

        def fold(tokens, counts, stop_at_close):
            """Pop tokens into ``counts`` until empty, or until a ')' if asked."""
            atom = ''
            while tokens and not (stop_at_close and tokens[-1] == ')'):
                tok = tokens.pop()
                if isinstance(tok, defaultdict):
                    # An already-evaluated nested group: absorb its totals.
                    for name, qty in tok.items():
                        counts[name] += qty
                elif tok.isdigit():
                    # The atom was counted once when its name was read, so
                    # only the remainder of the explicit count is added.
                    qty = read_number(tok, tokens)
                    counts[atom] += qty - 1
                else:
                    # Start of an element name; glue on its lowercase tail.
                    atom = tok
                    while (tokens and isinstance(tokens[-1], str)
                           and tokens[-1].isalpha() and tokens[-1].islower()):
                        atom += tokens.pop()
                    counts[atom] += 1

        remaining = list(formula)
        pending = []
        while remaining:
            item = remaining.pop()
            if item != '(':
                pending.append(item)
                continue

            group_counts = defaultdict(int)
            fold(pending, group_counts, True)
            pending.pop()  # discard the ')' that closes this group
            if pending and isinstance(pending[-1], str) and pending[-1].isdigit():
                factor = read_number(pending.pop(), pending)
                for name in group_counts:
                    group_counts[name] *= factor

            # Re-queue the evaluated group so it is parked like any other token.
            remaining.append(group_counts)

        totals = defaultdict(int)
        fold(pending, totals, False)

        parts = []
        for name, qty in sorted(totals.items(), key=lambda kv: kv[0]):
            parts.append(name)
            if qty > 1:
                parts.append(str(qty))
        return ''.join(parts)

Editorial

Approach 1: Recursion

class Solution:
    def countOfAtoms(self, formula: str) -> str:
        """Count the atoms in ``formula`` with a recursive character scanner.

        Args:
            formula: Chemical formula built from element names (one uppercase
                letter followed by lowercase letters), decimal counts, and
                parentheses.

        Returns:
            Element names in sorted order, each followed by its total count
            when that count is greater than 1.

        Raises:
            ValueError: If ``formula`` contains a character that is not an
                uppercase letter, a lowercase letter, a digit, or a
                parenthesis. Such a character previously left the cursor
                stuck and the scan never terminated.
        """
        n = len(formula)

        # The cursor lives on the instance so that every level of the
        # recursion advances the same position.
        self.index = 0

        def parse_formula():
            """Parse from the cursor up to the matching ')' or the end."""
            curr_map = defaultdict(int)
            curr_atom = ""
            curr_count = ""

            def flush():
                """Record the atom being read, with a default count of 1."""
                if curr_atom:
                    curr_map[curr_atom] += int(curr_count) if curr_count else 1

            while self.index < n:
                ch = formula[self.index]

                if ch.isupper():
                    # A new element starts, so the previous one is complete.
                    flush()
                    curr_atom = ch
                    curr_count = ""
                    self.index += 1

                elif ch.islower():
                    curr_atom += ch
                    self.index += 1

                elif ch.isdigit():
                    curr_count += ch
                    self.index += 1

                elif ch == "(":
                    self.index += 1
                    # The nested call consumes through ')' and its multiplier.
                    for atom, count in parse_formula().items():
                        curr_map[atom] += count

                elif ch == ")":
                    flush()
                    self.index += 1
                    start = self.index
                    while self.index < n and formula[self.index].isdigit():
                        self.index += 1
                    digits = formula[start:self.index]
                    if digits:
                        multiplier = int(digits)
                        for atom in curr_map:
                            curr_map[atom] *= multiplier
                    return curr_map

                else:
                    raise ValueError(
                        f"unexpected character {ch!r} at index {self.index}"
                    )

            # End of input: the last atom has not been recorded yet.
            flush()
            return curr_map

        final_map = parse_formula()

        return "".join(
            atom + (str(count) if count > 1 else "")
            for atom, count in sorted(final_map.items())
        )

Improve the implementation

class Solution:
    def countOfAtoms(self, formula: str) -> str:
        """Count the atoms in ``formula`` by recursing on each parenthesised group.

        Args:
            formula: Chemical formula such as ``"Mg(OH)2"``.

        Returns:
            Element names in sorted order, each followed by its total count
            when that count is greater than 1.
        """
        size = len(formula)

        # The cursor is stored on the instance so that all recursion levels
        # read and advance the same position.
        self.index = 0

        def parse_formula():
            """Parse one group starting at the cursor and return its counts."""
            tally = defaultdict(int)

            while self.index < size:
                head = formula[self.index]
                if head == ")":
                    break
                self.index += 1

                if head == "(":
                    # Fold the nested group's totals into this level.
                    for name, qty in parse_formula().items():
                        tally[name] += qty
                    continue

                # Anything else starts an element: one leading character,
                # then a lowercase tail, then an optional digit run.
                name_start = self.index - 1
                while self.index < size and formula[self.index].islower():
                    self.index += 1
                digits_start = self.index
                while self.index < size and formula[self.index].isdigit():
                    self.index += 1

                name = formula[name_start:digits_start]
                digits = formula[digits_start:self.index]
                tally[name] += int(digits) if digits else 1

            # Step over the ')' that ended the loop, then read the group's
            # multiplier if one follows.
            self.index += 1
            factor_start = self.index
            while self.index < size and formula[self.index].isdigit():
                self.index += 1
            factor_text = formula[factor_start:self.index]

            if factor_text:
                factor = int(factor_text)
                for name in tally:
                    tally[name] *= factor

            return tally

        ordered = sorted(parse_formula().items())

        pieces = []
        for name, qty in ordered:
            pieces.append(name)
            if qty > 1:
                pieces.append(str(qty))
        return "".join(pieces)

Approach 2: Stack

class Solution:
    def countOfAtoms(self, formula: str) -> str:
        """Count the atoms in ``formula`` using an explicit stack of count maps.

        Args:
            formula: Chemical formula such as ``"K4(ON(SO3)2)2"``.

        Returns:
            Element names in sorted order, each followed by its total count
            when that count is greater than 1.
        """
        length = len(formula)

        def run_end(start, accepts):
            """Return the first position at or after ``start`` not accepted."""
            end = start
            while end < length and accepts(formula[end]):
                end += 1
            return end

        # One map per open parenthesis level; frames[0] is the whole formula.
        frames = [defaultdict(int)]
        pos = 0

        while pos < length:
            symbol = formula[pos]
            pos += 1

            if symbol == "(":
                # Open a fresh map for the nested group.
                frames.append(defaultdict(int))
                continue

            if symbol == ")":
                # Close the group: scale it by the trailing multiplier, if
                # any, and merge it into the enclosing level.
                closed = frames.pop()
                factor_end = run_end(pos, str.isdigit)
                factor_text = formula[pos:factor_end]
                pos = factor_end
                if factor_text:
                    factor = int(factor_text)
                    for name in closed:
                        closed[name] *= factor
                parent = frames[-1]
                for name in closed:
                    parent[name] += closed[name]
                continue

            # Anything else starts an element: the character just consumed,
            # its lowercase tail, then an optional digit run.
            name_end = run_end(pos, str.islower)
            count_end = run_end(name_end, str.isdigit)
            name = formula[pos - 1:name_end]
            count_text = formula[name_end:count_end]
            pos = count_end
            frames[-1][name] += int(count_text) if count_text else 1

        pieces = []
        for name, qty in sorted(frames[0].items()):
            pieces.append(name)
            if qty > 1:
                pieces.append(str(qty))
        return "".join(pieces)