How to create a transition matrix with groupby in PySpark

I have a PySpark DataFrame that looks like this:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

so = pd.DataFrame({'id': ['a','a','a','a','b','b','b','b','c','c','c','c'],
                   'time': [1,2,3,4,1,2,3,4,1,2,3,4],
                   'group':['A','A','A','A','A','A','A','A','B','B','B','B'],
                   'value':['S','C','C','C','S','C','H','H','S','C','C','C']})

df_so = spark.createDataFrame(so)
df_so.show()

+---+----+-----+-----+
| id|time|group|value|
+---+----+-----+-----+
|  a|   1|    A|    S|
|  a|   2|    A|    C|
|  a|   3|    A|    C|
|  a|   4|    A|    C|
|  b|   1|    A|    S|
|  b|   2|    A|    C|
|  b|   3|    A|    H|
|  b|   4|    A|    H|
|  c|   1|    B|    S|
|  c|   2|    B|    C|
|  c|   3|    B|    C|
|  c|   4|    B|    C|
+---+----+-----+-----+

I would like to create the "transition matrix" of value by group.

The transition matrix gives the probability of, e.g., going from value S to value C within each id as time progresses.
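Written as a formula (the notation is ours, not the asker's), with n_g(s, t) the number of times an id in group g has value s at one time step and value t at the next:

```latex
p_g(s \to t) = \frac{n_g(s, t)}{N_g}, \qquad N_g = \sum_{s', t'} n_g(s', t')
```

For group A in the example below, N_A = 6 and n_A(S, C) = 2, giving 2/6. This divides by all movements in the group, so the entries for one group sum to 1. A Markov-chain transition matrix would instead divide by the number of movements leaving s (the sum over t' of n_g(s, t')), so that each row sums to 1.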

Example:

For group A:

  • There are 6 movements in total
  • S->C occurs once for id==a and once for id==b, so S->C is (1+1)/6
  • C->S is 0, since no id has a transition from C to S
  • C->C is 2/6
  • C->H is 1/6
  • H->H is 1/6

The same can be done for group B.

Is there a way to do this in PySpark?
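Before moving to Spark, the arithmetic above can be checked with a small pure-Python sketch. It hard-codes the question's sample rows, assumes each id belongs to a single group, and uses a hypothetical helper `transition_probabilities`; it is only meant to confirm the expected numbers, not to replace a distributed solution.

```python
from collections import Counter, defaultdict

# Sample rows from the question: (id, time, group, value).
rows = [
    ("a", 1, "A", "S"), ("a", 2, "A", "C"), ("a", 3, "A", "C"), ("a", 4, "A", "C"),
    ("b", 1, "A", "S"), ("b", 2, "A", "C"), ("b", 3, "A", "H"), ("b", 4, "A", "H"),
    ("c", 1, "B", "S"), ("c", 2, "B", "C"), ("c", 3, "B", "C"), ("c", 4, "B", "C"),
]

def transition_probabilities(rows):
    """Return {group: {(source, target): count / total movements in that group}}."""
    # Values of each id in time order; assumes an id never changes group.
    sequences = defaultdict(list)
    group_of = {}
    for id_, _, group, value in sorted(rows, key=lambda row: (row[0], row[1])):
        sequences[id_].append(value)
        group_of[id_] = group

    # Count consecutive (source, target) pairs, per group.
    counts = defaultdict(Counter)
    for id_, values in sequences.items():
        counts[group_of[id_]].update(zip(values, values[1:]))

    # Divide by the total number of movements in each group.
    result = {}
    for group, pair_counts in counts.items():
        total = sum(pair_counts.values())
        result[group] = {pair: n / total for pair, n in pair_counts.items()}
    return result

probs = transition_probabilities(rows)
print(probs["A"])  # S->C and C->C: 2/6, C->H and H->H: 1/6
print(probs["B"])  # S->C: 1/3, C->C: 2/3
```

Transitions that never occur (such as C->S) are simply absent from the dictionary, which corresponds to a probability of 0.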



Solution 1:[1]

First, I use lag to create a source column (the left side of the transition) for each row. Then I count how often each source/value (target) pair occurs within the group and divide that by the group's total number of transitions.

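from pyspark.sql import Window
import pyspark.sql.functions as F

# Previous value within each id, in time order
lagw = Window.partitionBy(['group', 'id']).orderBy('time')
# All rows with the same (group, source, target) transition
frqw = Window.partitionBy(['group', 'source', 'value'])
# All rows of the group (count() skips null sources)
ttlw = Window.partitionBy('group')
df = (df_so.withColumn('source', F.lag('value').over(lagw))
  .withColumn('transition_p', F.count('source').over(frqw) / F.count('source').over(ttlw)))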

df.show()

# +---+----+-----+-----+------+------------+
# | id|time|group|value|source|transition_p|
# +---+----+-----+-----+------+------------+
# |  c|   1|    B|    S|  null|         0.0|
# |  c|   3|    B|    C|     C| 0.666666666|
# |  c|   4|    B|    C|     C| 0.666666666|
# |  c|   2|    B|    C|     S| 0.333333333|
# |  b|   1|    A|    S|  null|         0.0|
# .....
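To see why the rows with a null source show 0.0, here is a pure-Python emulation of what the two window expressions compute on the sample data. It is an illustration, not Spark code; it relies on two Spark behaviors: `F.count` on a column does not count nulls, and a window partition groups null keys together, so each first row lands in a window whose non-null count is 0.

```python
from collections import Counter, defaultdict

# Sample rows from the question: (id, time, group, value).
data = [
    ("a", 1, "A", "S"), ("a", 2, "A", "C"), ("a", 3, "A", "C"), ("a", 4, "A", "C"),
    ("b", 1, "A", "S"), ("b", 2, "A", "C"), ("b", 3, "A", "H"), ("b", 4, "A", "H"),
    ("c", 1, "B", "S"), ("c", 2, "B", "C"), ("c", 3, "B", "C"), ("c", 4, "B", "C"),
]
rows = [dict(zip(("id", "time", "group", "value"), t)) for t in data]

# Like F.lag('value') over partitionBy('group', 'id').orderBy('time'):
# each row's source is the previous value in its partition, None for the first row.
partitions = defaultdict(list)
for r in rows:
    partitions[(r["group"], r["id"])].append(r)
for part in partitions.values():
    previous = None
    for r in sorted(part, key=lambda row: row["time"]):
        r["source"] = previous
        previous = r["value"]

# Like F.count('source') over the frqw and ttlw windows: null sources are not counted.
pair_counts = Counter(
    (r["group"], r["source"], r["value"]) for r in rows if r["source"] is not None
)
group_counts = Counter(r["group"] for r in rows if r["source"] is not None)

for r in rows:
    # A null-source row's numerator window has no non-null sources, so its ratio is 0.0.
    numerator = 0 if r["source"] is None else pair_counts[(r["group"], r["source"], r["value"])]
    r["transition_p"] = numerator / group_counts[r["group"]]

for r in rows:
    print(r["id"], r["time"], r["group"], r["value"], r["source"], round(r["transition_p"], 3))
```

The denominator is 6 for group A and 3 for group B, which is why id c's C->C rows show roughly 0.667 while the same transition in group A would be 2/6.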

If I understand the final output you want correctly, pivot the result for each group, e.g. for group A:

(df.filter(df.group == 'A')
 .groupby('source')
 .pivot('value')
 .agg(F.first('transition_p'))
).show()

# +------+---------+---------+---------+
# |source|        C|        H|        S| 
# +------+---------+---------+---------+
# |  null|     null|     null|      0.0|
# |     C|0.3333333|0.1666666|     null|
# |     S|0.3333333|     null|     null|
# |     H|     null|0.1666666|     null|
# +------+---------+---------+---------+
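The pivot leaves null for transitions that never occur and keeps a row for the null source. In PySpark, filtering with `F.col('source').isNotNull()` before the pivot and calling `.na.fill(0)` afterwards should remove both, although the result is not necessarily square (a value that only ever appears as a target gets a column but no row). The plain-Python sketch below shows the idea of a full square matrix, using a hypothetical `to_dense` helper and the group A values from the table above.

```python
# Group A probabilities, read off the pivoted table above (null-source row dropped).
probs_a = {("S", "C"): 2 / 6, ("C", "C"): 2 / 6, ("C", "H"): 1 / 6, ("H", "H"): 1 / 6}

def to_dense(probs, states):
    """Square matrix (list of rows, source x target); missing transitions become 0.0."""
    return [[probs.get((src, dst), 0.0) for dst in states] for src in states]

states = ["C", "H", "S"]
matrix = to_dense(probs_a, states)
for src, row in zip(states, matrix):
    print(src, [round(p, 3) for p in row])

# Every movement of the group is counted exactly once, so all entries sum to 1.
print(round(sum(map(sum, matrix)), 9))
```

Passing the state list explicitly keeps every state as both a row and a column, so C->S appears as an explicit 0.0 rather than a missing cell.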

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1