算法描述
You are given an integer array nums and you have to return a new counts array. The counts array has the property where counts[i] is the number of smaller elements to the right of nums[i].
Example:
Given nums = [5, 2, 6, 1]
To the right of 5 there are 2 smaller elements (2 and 1).
To the right of 2 there is only 1 smaller element (1).
To the right of 6 there is 1 smaller element (1).
To the right of 1 there is 0 smaller element.
Return the array [2, 1, 1, 0].
题目大意
给定一个数组nums,要求返回一个数组counts,其中counts数组中的第i个元素是在nums数组中位于nums[i]的右边且比nums[i]小的元素的个数。
解题思路
题目要求是计算位于数组元素右端且小于该数组元素的元素数目,因此最直观的想法是从右到左遍历数组,同时维护一个辅助数组,辅助数组下标从小到大分别代表排序后的nums中的元素。我们用辅助数组来记录已经遇到过的元素的出现次数。这样求小于该数组元素的元素数目只需要对辅助数组求前缀和即可。由于普通数组求前缀和时间复杂度为O(n),因此我们可以考虑使用树状数组来作为辅助数组,这样可以降低时间复杂度到O(log n)。关于树状数组的介绍,可以参考Fenwick Tree。
另一种思路,我们可以用二分查找树(Binary Search Tree),树的每一个结点带有一个count值,表示该结点元素出现次数。我们在从右至左遍历nums数组时,同时更新BST树,counts的值可以由搜索路径所经过的结点的counts值之和得到。
解法I:Fenwick Tree
class Solution(object):
    def countSmaller(self, nums):
        """
        :type nums: List[int]
        :rtype: List[int]
        """
        result = [0] * len(nums)
        order = {}
        for i, num in enumerate(sorted(set(nums))):
            order[num] = i + 1
        tree = FenwickTree(len(nums))
        for i in xrange(len(nums) - 1, -1, -1):
            result[i] = tree.sum(order[nums[i]] - 1)
            tree.add(order[nums[i]], 1)
        return result
class FenwickTree(object):
    def __init__(self, n):
        self.sum_array = [0] * (n + 1)
        self.n = n
    def lowbit(self, x):
        return x & -x
    def add(self, x, val):
        while x <= self.n:
            self.sum_array[x] += val
            x += self.lowbit(x)
    def sum(self, x):
        ret = 0
        while x > 0:
            ret += self.sum_array[x]
            x -= self.lowbit(x)
        return ret解法II:Binary Search Tree
class Solution(object):
    def __init__(self):
        self.root = None
    def countSmaller(self, nums):
        """
        :type nums: List[int]
        :rtype: List[int]
        """
        counts = [0] * len(nums)
        for i in range(len(nums) - 1, -1, -1):
            counts[i] = self.traverse(nums[i])
        return counts
    def traverse(self, val):
        if not self.root:
            self.root = Node(val)
            return 0
        count = 0
        p = self.root
        while p:
            if val < p.val:
                p.small_cnt += 1
                if not p.left:
                    p.left = Node(val)
                    break
                p = p.left
            elif val > p.val:
                count += p.small_cnt + p.count
                if not p.right:
                    p.right = Node(val)
                    break
                p = p.right
            else:
                count += p.small_cnt
                p.count += 1
                break
        return count
class Node(object):
    def __init__(self, val):
        self.small_cnt = 0
        self.count = 1
        self.val = val
        self.left = None
        self.right = None