Get name / alias of column in PySpark

I am defining a column object like this:

column = F.col('foo').alias('bar')

I know I can get the full expression using str(column). But how can I get the column's alias only?

In the example, I'm looking for a function get_column_name where get_column_name(column) returns the string bar.



Solution 1:[1]

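One way is to parse str(column) with a regular expression. The pattern below assumes the alias is printed in backticks, as in the output shown here; other PySpark versions format the string differently (see Solution 3):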

from pyspark.sql.functions import col
column = col('foo').alias('bar')
print(column)
#Column<foo AS `bar`>

import re
# Index the list returned by findall, not the return value of print
print(re.findall(r"(?<=AS `)\w+(?=`>$)", str(column))[0])
#bar
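
To get closer to the get_column_name function asked for in the question, the regex can be wrapped in a small helper. This is only a sketch: the name alias_from_legacy_repr is made up here, and it assumes the Column<foo AS `bar`> string format printed above. It accepts anything whose str() has that shape, so plain strings stand in for columns and no Spark session is needed.

```python
import re

# Hypothetical helper (not part of PySpark); assumes the repr format
# Column<foo AS `bar`> printed above.
_ALIAS_PATTERN = re.compile(r"(?<=AS `)\w+(?=`>$)")


def alias_from_legacy_repr(column):
    """Return the alias found in str(column), or None if there is none."""
    matches = _ALIAS_PATTERN.findall(str(column))
    return matches[0] if matches else None


# Plain strings stand in for str(column) so the sketch runs without Spark.
print(alias_from_legacy_repr("Column<foo AS `bar`>"))  # bar
print(alias_from_legacy_repr("Column<foo>"))           # None
```

Returning None rather than indexing [0] directly avoids an IndexError for a column that has no alias.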

Solution 2:[2]

Alternatively, we can wrap the Column.alias and Column.name methods so that they also store just the alias in an AS attribute:

from pyspark.sql import Column, SparkSession
from pyspark.sql.functions import col, explode, array, struct, lit
SparkSession.builder.getOrCreate()

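def alias_wrapper(self, *alias, **kwargs):
    # Call the original alias method, which is saved below as Column._alias
    renamed_col = Column._alias(self, *alias, **kwargs)
    # Store a single alias as a string, several aliases as a tuple
    renamed_col.AS = alias[0] if len(alias) == 1 else alias
    return renamed_col

# Keep the original method as _alias, route alias and name through the wrapper,
# and default AS to None for columns without an alias
Column._alias, Column.alias, Column.name, Column.AS = Column.alias, alias_wrapper, alias_wrapper, None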

which then guarantees:

assert(col("foo").alias("bar").AS == "bar")
# `name` should act like `alias`
assert(col("foo").name("bar").AS == "bar")
# column without alias should have None in `AS`
assert(col("foo").AS is None)
# multialias should be handled
assert(explode(array(struct(lit(1), lit("a")))).alias("foo", "bar").AS == ("foo", "bar"))
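
The patching pattern itself can be tried without Spark. The sketch below applies the same wrapper to FakeColumn, a hypothetical stand-in class invented for this illustration (it is not a PySpark class and only imitates alias and name), to show what the four-way assignment does.

```python
class FakeColumn:
    """Hypothetical stand-in for pyspark.sql.Column; illustration only."""

    def __init__(self, expr):
        self.expr = expr

    def alias(self, *alias, **kwargs):
        # Imitates Column.alias: returns a new object with the renamed expression.
        return FakeColumn(f"{self.expr} AS {', '.join(alias)}")

    name = alias  # imitates Column.name behaving like Column.alias


def alias_wrapper(self, *alias, **kwargs):
    # Delegate to the original method, then record the alias on the result.
    renamed = FakeColumn._alias(self, *alias, **kwargs)
    renamed.AS = alias[0] if len(alias) == 1 else alias
    return renamed


# The right-hand side is evaluated first, so _alias receives the original alias.
FakeColumn._alias, FakeColumn.alias, FakeColumn.name, FakeColumn.AS = (
    FakeColumn.alias, alias_wrapper, alias_wrapper, None,
)

print(FakeColumn("foo").alias("bar").AS)     # bar
print(FakeColumn("foo").name("bar").AS)      # bar
print(FakeColumn("foo").AS)                  # None
print(FakeColumn("foo").alias("a", "b").AS)  # ('a', 'b')
```

Because AS is set on the instance returned by the wrapper, any new object built from an aliased one should fall back to the class-level None until it is aliased again.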

Solution 3:[3]

I've noticed that in some systems you may have backticks surrounding column names. The following options work both with backticks and without.

Option 1 (no regex): str(col).replace("`", "").split("'")[-2].split(" AS ")[-1]

from pyspark.sql.functions import col
col_1 = col('foo')
col_2 = col('foo').alias('bar')
col_3 = col('foo').alias('bar').alias('baz')

s = str(col_1)
print(col_1)
print(s.replace("`", "").split("'")[-2].split(" AS ")[-1])
# Column<'foo'>
# foo

s = str(col_2)
print(col_2)
print(s.replace("`", "").split("'")[-2].split(" AS ")[-1])
# Column<'foo AS bar'>
# bar

s = str(col_3)
print(col_3)
print(s.replace("`", "").split("'")[-2].split(" AS ")[-1])
# Column<'foo AS bar AS baz'>
# baz
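
Option 1 can be packaged as a function, which is roughly what the question asks for. A minimal sketch, assuming the Column<'...'> format printed above; the name name_via_split is hypothetical, and plain strings are passed in place of columns so that it runs without Spark.

```python
def name_via_split(column):
    """Hypothetical helper: last alias (or plain name) taken from str(column)."""
    s = str(column).replace("`", "")  # drop optional backticks
    inner = s.split("'")[-2]          # text between the single quotes
    return inner.split(" AS ")[-1]    # keep whatever follows the last " AS "


# Plain strings stand in for str(column).
for text in ("Column<'foo'>", "Column<'foo AS bar'>", "Column<'foo AS `bar` AS baz'>"):
    print(name_via_split(text))
# foo
# bar
# baz
```

If the string contains no single quotes, as in the Column<foo AS `bar`> output of Solution 1, the [-2] index raises an IndexError, so this form depends on the quoted format.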

Option 2 (regex): pattern '.*?`?(\w+)`?' looks safe enough:
re.search(r"'.*?`?(\w+)`?'", str(col)).group(1)

from pyspark.sql.functions import col
col_1 = col('foo')
col_2 = col('foo').alias('bar')
col_3 = col('foo').alias('bar').alias('baz')

import re

print(col_1)
print(re.search(r"'.*?`?(\w+)`?'", str(col_1)).group(1))
# Column<'foo'>
# foo

print(col_2)
print(re.search(r"'.*?`?(\w+)`?'", str(col_2)).group(1))
# Column<'foo AS bar'>
# bar

print(col_3)
print(re.search(r"'.*?`?(\w+)`?'", str(col_3)).group(1))
# Column<'foo AS bar AS baz'>
# baz
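
The regex option can be wrapped the same way. This is a sketch under the same assumption about the Column<'...'> format; name_via_regex is a hypothetical name, and it returns None when the pattern does not match instead of failing on .group(1).

```python
import re

# Same pattern as above, compiled once.
_NAME_PATTERN = re.compile(r"'.*?`?(\w+)`?'")


def name_via_regex(column):
    """Hypothetical helper: last alias (or plain name) from str(column), else None."""
    match = _NAME_PATTERN.search(str(column))
    return match.group(1) if match else None


# Plain strings stand in for str(column) so the sketch runs without Spark.
print(name_via_regex("Column<'foo'>"))                # foo
print(name_via_regex("Column<'foo AS `bar`'>"))       # bar
print(name_via_regex("Column<'foo AS bar AS baz'>"))  # baz
print(name_via_regex("Column<foo>"))                  # None
```

Note that \w+ matches only letters, digits and underscores, so an alias containing spaces or punctuation would be returned only in part.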

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 pault
Solution 2
Solution 3