We all know that to find the index of the maximum value we can use np.argmax. But what if you want the indices of the top 3 or top 5 values? Then you can use np.argpartition.
Let's take an example array.
import numpy as np
x = [10, 1, 6, 8, 2, 12, 20, 15, 56, 23]
In this array, it's easy to find the index of the maximum value: it's 8 (the value 56).
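For reference, this is the single-maximum case with np.argmax. A small sketch using the same array; note that np.argmax returns the index of the first occurrence if the maximum appears more than once.

```python
import numpy as np

x = [10, 1, 6, 8, 2, 12, 20, 15, 56, 23]

# Index of the largest value (first occurrence if there are ties).
max_idx = np.argmax(x)
print(max_idx)     # 8
print(x[max_idx])  # 56
```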
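But what if you want the top 3 or top 5? Then you can use np.argpartition.
It does not sort the whole array. Instead, it returns indices arranged so that the element at position kth is the one that would be there if the array were fully sorted. All elements before it are smaller than or equal to it, and all elements after it are larger than or equal to it. The order within each of those two groups is not guaranteed.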
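The partition guarantee can be checked directly. This is a minimal sketch on the same data; kth = len(x) - 3 is the positive equivalent of kth=-3, and the assertions test only what NumPy promises (the kth element is in its sorted position, smaller-or-equal values before it, larger-or-equal values after it), not any particular order inside each side.

```python
import numpy as np

x = np.array([10, 1, 6, 8, 2, 12, 20, 15, 56, 23])
kth = len(x) - 3  # position 7, same as kth=-3

idx = np.argpartition(x, kth)
part = x[idx]

# The element at position kth matches what a full sort puts there.
assert part[kth] == np.sort(x)[kth]
# Everything before it is <= it, everything after it is >= it.
assert np.all(part[:kth] <= part[kth])
assert np.all(part[kth + 1:] >= part[kth])
print(part[kth])  # 20
```

Because only this weaker property is required, NumPy can use a selection algorithm (the default kind is "introselect") rather than a full sort.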
Let’s see with a few examples.
idx = np.argpartition(x, kth=-3)
print(idx)
>>> [1 4 2 3 0 5 7 6 8 9]
print([x[i] for i in idx])
>>> [1, 2, 6, 8, 10, 12, 15, 20, 56, 23]
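Here you can see that the indices of the top 3 values are the last 3 entries of idx, so you can extract them with idx[-3:]. Note that these 3 are not guaranteed to be in sorted order (here 56 comes before 23), and the exact arrangement you see may differ from the output above.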
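A quick sketch of the slicing step. Since the order inside the slice is not guaranteed, the prints sort the results before displaying them, and the assertion compares the slice as a set against the last 3 indices of a full np.argsort.

```python
import numpy as np

x = np.array([10, 1, 6, 8, 2, 12, 20, 15, 56, 23])

# Indices of the 3 largest values, in no guaranteed order.
top3 = np.argpartition(x, -3)[-3:]
print(sorted(top3.tolist()))                   # [6, 8, 9]
print(sorted(x[top3].tolist(), reverse=True))  # [56, 23, 20]

# Same set of indices as a full argsort would give.
assert set(top3.tolist()) == set(np.argsort(x)[-3:].tolist())
```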
Similarly, for the top 5:
idx = np.argpartition(x, kth=-5)
print(idx[-5:])
>>> [5 7 6 8 9]
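If you also want the top k ordered from largest to smallest, one option is to partition first and then sort only those k indices. The helper below is a hypothetical sketch (the name top_k_indices is not part of NumPy) for 1-D input; it is cross-checked against a full argsort, which gives the same answer here because the values are distinct.

```python
import numpy as np

def top_k_indices(a, k):
    """Indices of the k largest values of a 1-D array, largest first.

    Hypothetical helper: np.argpartition selects the top k,
    then np.argsort orders just those k by value.
    """
    a = np.asarray(a)
    if not 0 < k <= a.size:
        raise ValueError("k must be between 1 and len(a)")
    part = np.argpartition(a, -k)[-k:]
    # argsort is ascending, so reverse it for largest first.
    return part[np.argsort(a[part])[::-1]]

x = [10, 1, 6, 8, 2, 12, 20, 15, 56, 23]
print(top_k_indices(x, 5))  # [8 9 6 7 5]

# Cross-check against a full sort (fine for small arrays).
assert top_k_indices(x, 5).tolist() == np.argsort(x)[::-1][:5].tolist()
```

Sorting only k elements keeps the extra cost small when k is much smaller than the array.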
Hopefully, this post explains how you can use np.argpartition to get the indices of the top k elements. If you have any questions, feel free to ask in the comments or on my YouTube channel.