How to assert that the sum of two series is equal to the sum of another two series
Let's say I have 4 series objects:
import numpy as np
import pandas as pd
ser1 = pd.Series(data={'a': 1, 'b': 2, 'c': np.nan, 'd': 5, 'e': 50})
ser2 = pd.Series(data={'a': 4, 'b': np.nan, 'c': np.nan, 'd': 10, 'e': 100})
ser3 = pd.Series(data={'a': 0, 'b': np.nan, 'c': 7, 'd': 15, 'e': np.nan})
ser4 = pd.Series(data={'a': 5, 'b': 2, 'c': 10, 'd': np.nan, 'e': np.nan})
I would like to assert that ser1 + ser2 == ser3 + ser4, treating NaNs as zeros, with one exception: when ser1 and ser2 are both NaN for the same label, I want to skip that row and treat it as passing. For example, both are NaN at 'c', so the assertion should hold there no matter what the values of ser3 and ser4 are. When only one of ser1 or ser2 is NaN, filling NaNs with zeros would work.
Solution 1:[1]
Here is one way to do it:
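def assert_sum_equality(ser1, ser2, ser3, ser4):
    """Return True if ser1 + ser2 equals ser3 + ser4 on every row, treating NaN as 0.

    Returns True outright when ser1 and ser2 are both entirely NaN.
    """
    if ser1.isna().all() and ser2.isna().all():
        return True
    # Fill NaNs on copies so the caller's series are left unchanged
    ser1, ser2, ser3, ser4 = (ser.fillna(0) for ser in (ser1, ser2, ser3, ser4))
    return bool((ser1 + ser2 == ser3 + ser4).all())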
import pandas as pd
# ser1 and ser2 are filled with pd.NA
ser1 = pd.Series({"a": pd.NA, "b": pd.NA, "c": pd.NA, "d": pd.NA, "e": pd.NA})
ser2 = pd.Series({"a": pd.NA, "b": pd.NA, "c": pd.NA, "d": pd.NA, "e": pd.NA})
ser3 = pd.Series({"a": 0, "b": pd.NA, "c": 7, "d": 15, "e": pd.NA})
ser4 = pd.Series({"a": 5, "b": 2, "c": 10, "d": pd.NA, "e": 125})
print(assert_sum_equality(ser1, ser2, ser3, ser4)) # True
# ser1 + ser2 == ser3 + ser4 on all rows
ser1 = pd.Series({"a": 1, "b": 2, "c": 13, "d": 5, "e": 50})
ser2 = pd.Series({"a": 4, "b": pd.NA, "c": 4, "d": 10, "e": 100})
ser3 = pd.Series({"a": 0, "b": pd.NA, "c": 7, "d": 15, "e": pd.NA})
ser4 = pd.Series({"a": 5, "b": 2, "c": 10, "d": pd.NA, "e": 150})
print(assert_sum_equality(ser1, ser2, ser3, ser4)) # True
# ser1 + ser2 != ser3 + ser4 on rows 'c' and 'e'
ser1 = pd.Series({"a": 1, "b": 2, "c": pd.NA, "d": 5, "e": 50})
ser2 = pd.Series({"a": 4, "b": pd.NA, "c": pd.NA, "d": 10, "e": 100})
ser3 = pd.Series({"a": 0, "b": pd.NA, "c": 7, "d": 15, "e": pd.NA})
ser4 = pd.Series({"a": 5, "b": 2, "c": 10, "d": pd.NA, "e": pd.NA})
print(assert_sum_equality(ser1, ser2, ser3, ser4)) # False
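Note that the function above only short-circuits when ser1 and ser2 are NaN on every row. If the rule from the question should apply per row (skip only the labels where both ser1 and ser2 are NaN), a boolean mask is one way to express it. This is a minimal sketch, not part of the original answer: the name sums_match_rowwise is hypothetical, and it assumes all four series share the same index and hold values that can be compared exactly.

```python
import numpy as np
import pandas as pd


def sums_match_rowwise(ser1, ser2, ser3, ser4):
    """Hypothetical helper: compare ser1 + ser2 with ser3 + ser4 row by row."""
    # Rows where both left-hand values are missing are skipped (count as passing)
    skip = ser1.isna() & ser2.isna()
    # Everywhere else, a missing value is treated as 0
    left = ser1.fillna(0) + ser2.fillna(0)
    right = ser3.fillna(0) + ser4.fillna(0)
    return bool((skip | (left == right)).all())


# Data from the question
ser1 = pd.Series({"a": 1, "b": 2, "c": np.nan, "d": 5, "e": 50})
ser2 = pd.Series({"a": 4, "b": np.nan, "c": np.nan, "d": 10, "e": 100})
ser3 = pd.Series({"a": 0, "b": np.nan, "c": 7, "d": 15, "e": np.nan})
ser4 = pd.Series({"a": 5, "b": 2, "c": 10, "d": np.nan, "e": np.nan})

# Row 'c' is skipped, but row 'e' still differs (150 vs 0)
print(sums_match_rowwise(ser1, ser2, ser3, ser4))  # False
```

With the question's data only row 'e' fails, so changing ser4['e'] to 150 should make the check pass. For non-integer floats, swapping the == comparison for np.isclose may be safer.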
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Laurent |
