How to remove all entries of a specific ID after a binary variable becomes true in Pandas?
Suppose we have the following already sorted dataset:
ID Dead
01 F
01 F
01 T
01 T
01 T
02 F
02 F
02 F
02 F
02 T
03 T
03 T
03 T
03 T
03 T
We have 3 IDs (01, 02, and 03) and a flag indicating whether the individual is dead (T for true, F for false). I want to keep the rows where the individual is alive plus the first row in which the individual is recorded as dead, which would leave me with the following dataset:
ID Dead
0 01 F
1 01 F
2 01 T
5 02 F
6 02 F
7 02 F
8 02 F
9 02 T
10 03 T
I came up with a solution that involves looping over all rows and appending the ID to a list if they have died previously. Is there a quicker approach?
Edit: The filter also has to respect row order. The data is not "perfect"; for example, we might have the following dataset, where F rows appear after the first T:
ID Dead
04 F
04 T
04 F
04 F
04 F
And the desired output is:
ID Dead
04 F
04 T
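For reference, the sample data and the row-by-row approach mentioned in the question could be sketched as follows. The question does not show the original loop, so this is an assumption about what it looked like; the names `died`, `keep`, and `out` are hypothetical.

```python
import pandas as pd

# Sample data from the question, including the "imperfect" ID 04.
df = pd.DataFrame({
    "ID": ["01"] * 5 + ["02"] * 5 + ["03"] * 5 + ["04"] * 5,
    "Dead": list("FFTTT") + list("FFFFT") + list("TTTTT") + list("FTFFF"),
})

died = set()   # IDs for which a T row has already been seen
keep = []      # index labels of the rows to keep

for idx, id_, dead in zip(df.index, df["ID"], df["Dead"]):
    if id_ in died:
        continue           # this ID already died on an earlier row: drop the row
    keep.append(idx)       # alive so far, or this is the first T for the ID
    if dead == "T":
        died.add(id_)

out = df.loc[keep]
print(out.index.tolist())  # → [0, 1, 2, 5, 6, 7, 8, 9, 10, 15, 16]
```

This loop is useful mainly as a correctness baseline for the vectorized answers.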
Solution 1:[1]
You can use groupby with transform('idxmax') to find the index of the first T in each ID and keep every row up to and including it:
out = df[df.index<=df['Dead'].eq('T').groupby(df['ID']).transform('idxmax')]
Out[545]:
ID Dead
0 1 F
1 1 F
2 1 T
5 2 F
6 2 F
7 2 F
8 2 F
9 2 T
10 3 T
Or, if an F never follows a T within an ID, use a grouped cumsum:
out = df[df['Dead'].eq('T').groupby(df['ID']).cumsum()<=1]
Out[546]:
ID Dead
0 1 F
1 1 F
2 1 T
5 2 F
6 2 F
7 2 F
8 2 F
9 2 T
10 3 T
For the updated example, where an F can follow a T, use the idxmax version, since it cuts each ID off at its first T:
out = df[df.index<=df['Dead'].eq('T').groupby(df['ID']).transform('idxmax')]
out
Out[552]:
ID Dead
0 1 F
1 1 F
2 1 T
5 2 F
6 2 F
7 2 F
8 2 F
9 2 T
10 3 T
15 4 F
16 4 T
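A minimal, self-contained sketch comparing the two expressions above. It assumes the question's data with string IDs, plus a hypothetical ID 05 that never dies (not part of the original data), to illustrate one caveat: for a group with no T, idxmax returns the first row of the group, so the idxmax filter keeps only that row. The `has_t` guard is an addition for illustration, not part of the answer.

```python
import pandas as pd

df = pd.DataFrame({
    # ID 05 is hypothetical: an individual with no T row at all.
    "ID": ["01"] * 5 + ["02"] * 5 + ["03"] * 5 + ["04"] * 5 + ["05"] * 3,
    "Dead": (list("FFTTT") + list("FFFFT") + list("TTTTT")
             + list("FTFFF") + list("FFF")),
})

is_dead = df["Dead"].eq("T")

# Approach 1: index label of the first T per ID, broadcast to each row of that ID.
first_t = is_dead.groupby(df["ID"]).transform("idxmax")
out_idxmax = df[df.index <= first_t]

# Approach 2: running count of T per ID; keeps rows while at most one T was seen,
# so F rows that follow a single T (ID 04) are kept as well.
out_cumsum = df[is_dead.groupby(df["ID"]).cumsum() <= 1]

# Guarded variant of approach 1: keep every row of IDs that contain no T.
has_t = is_dead.groupby(df["ID"]).transform("any")
out_guarded = df[(df.index <= first_t) | ~has_t]

print(out_idxmax.index.tolist())   # ID 05 is reduced to its first row (20)
print(out_cumsum.index.tolist())   # rows 17-19 of ID 04 survive
print(out_guarded.index.tolist())  # rows 20-22 of ID 05 are all kept
```

If every ID is guaranteed to contain at least one T, the guard is unnecessary and the original expression is enough.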
Solution 2:[2]
If I understand correctly, you want to keep the rows where "Dead" equals "F", or where the ("ID", "Dead") pair appears for the first time (not duplicated).
You can use boolean indexing:
m1 = df['Dead'].eq('F')
m2 = ~df.duplicated(['ID', 'Dead'])
df[m1|m2] # keep if either mask is True
output:
ID Dead
0 1 F
1 1 F
2 1 T
5 2 F
6 2 F
7 2 F
8 2 F
9 2 T
10 3 T
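For the updated example, stop after the first T in each ID:
# True where the row is marked dead
m = df['Dead'].eq('T')
# within each ID, propagate the first T forward, then shift by one row
# so that the first T itself is kept and every later row is flagged;
# group_keys=False keeps the original index so the mask aligns with df
mask = m.where(m).groupby(df['ID'], group_keys=False).apply(lambda x: x.ffill().shift())
df[~mask.fillna(False).astype(bool)]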
output:
ID Dead
0 1 F
1 1 F
2 1 T
5 2 F
6 2 F
7 2 F
8 2 F
9 2 T
10 3 T
15 4 F
16 4 T
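As an alternative to the ffill/shift mask, the same "drop everything after the first T" rule can be sketched with a grouped cumulative sum. This is a swapped-in technique rather than the answer's own code, and the names `is_dead`, `dead_before`, and the output variables are hypothetical. The first part also shows why the m1 | m2 mask is not enough for the updated example.

```python
import pandas as pd

df = pd.DataFrame({
    "ID": ["01"] * 5 + ["02"] * 5 + ["03"] * 5 + ["04"] * 5,
    "Dead": list("FFTTT") + list("FFFFT") + list("TTTTT") + list("FTFFF"),
})

# Original two-mask version: keeps every F row, including those after a T.
m1 = df["Dead"].eq("F")
m2 = ~df.duplicated(["ID", "Dead"])
out_masks = df[m1 | m2]

# Count the T rows seen strictly before the current row within each ID:
# (running count including this row) minus (1 if this row is T else 0).
is_dead = df["Dead"].eq("T").astype(int)
dead_before = is_dead.groupby(df["ID"]).cumsum() - is_dead

# Keep a row only if no T has occurred earlier for its ID.
out_first_t = df[dead_before.eq(0)]

print(out_masks.index.tolist())    # rows 17-19 of ID 04 are still present
print(out_first_t.index.tolist())  # ID 04 stops at row 16
```

Because the result of cumsum keeps the original index, no index realignment is needed before filtering.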
Solution 3:[3]
An easy solution uses datar, a pandas wrapper that exposes a dplyr-style API on top of pandas:
Construct data
>>> from datar.all import c, f, rep, tibble, group_by, filter, cumsum
[2022-03-31 11:03:49][datar][WARNING] Builtin name "filter" has been overriden by datar.
>>>
>>> df = tibble(
... ID=rep(["01", "02", "03", "04"], each=5),
... Dead=c(
... rep(["F", "T"], [2, 3]),
... rep(["F", "T"], [4, 1]),
... rep(["F", "T"], [0, 5]),
... rep(["F", "T", "F"], [1, 1, 3]),
... )
... )
>>> df
ID Dead
<object> <object>
0 01 F
1 01 F
2 01 T
3 01 T
4 01 T
5 02 F
6 02 F
7 02 F
8 02 F
9 02 T
10 03 T
11 03 T
12 03 T
13 03 T
14 03 T
15 04 F
16 04 T
17 04 F
18 04 F
19 04 F
Filter data
>>> df >> group_by(f.ID) >> filter(cumsum(cumsum(f.Dead=="T")) <= 1)
ID Dead
<object> <object>
0 01 F
1 01 F
2 01 T
3 02 F
4 02 F
5 02 F
6 02 F
7 02 T
8 03 T
9 04 F
10 04 T
[TibbleGrouped: ID (n=4)]
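Explanation
The first cumsum() counts the Ts seen so far, so the Fs that appear before any T are marked 0 and the first T is marked 1. The second cumsum() accumulates those counts, so every row after the first T (including later Fs) gets a value greater than 1 and is filtered out by <= 1.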
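The double-cumsum idea can also be sketched in plain pandas without datar. This is an illustrative translation under the assumption that the data matches the tibble constructed above; the names `once` and `twice` are hypothetical.

```python
import pandas as pd

df = pd.DataFrame({
    "ID": ["01"] * 5 + ["02"] * 5 + ["03"] * 5 + ["04"] * 5,
    "Dead": list("FFTTT") + list("FFFFT") + list("TTTTT") + list("FTFFF"),
})

is_dead = df["Dead"].eq("T").astype(int)

# First cumsum: number of Ts seen so far within each ID.
# For ID 04 (F T F F F) this gives 0 1 1 1 1.
once = is_dead.groupby(df["ID"]).cumsum()

# Second cumsum: running total of those counts within each ID.
# For ID 04 this gives 0 1 2 3 4, so only the first two rows are <= 1.
twice = once.groupby(df["ID"]).cumsum()

out = df[twice <= 1]
print(out.index.tolist())  # → [0, 1, 2, 5, 6, 7, 8, 9, 10, 15, 16]
```

Unlike the datar output above, this keeps the original index labels rather than resetting them.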
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | |
| Solution 3 | Panwen Wang |
