
Reshaping Dataframe Using Pivot and Melt in Apache Spark and Pandas

Data cleaning is one of the most important and tedious parts of the data science workflow: often mentioned, but rarely discussed. Reflecting on my daily workflow, reshaping a DataFrame is one of the operations I perform most often to get data into the desired format. Reshaping a DataFrame means transforming the structure of the table, for example by adding or removing columns or rows, or by aggregating certain rows and producing a new column that summarizes the result. In this post I won't cover everything about reshaping; I will discuss the two most frequently used operations, pivot and melt. The solutions are written for Spark, more specifically PySpark, and I will also give a brief solution for pandas. If you want a detailed explanation of the pandas solution, I recommend reading this post.

#Pivoting operation on data

Data is usually stored in a stacked (record) format, in which values are often repeated, meaning the data is not normalized. When you want to summarize the data in a tabular view, pivot can be a very useful transformation. Let's take the example of an online business that sells music and books; its (oversimplified) sales data is shown in the table below.

TABLE A

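| product | category | quarter | profit |
|---|---|---|---|
| memories | book | q1 | 10 |
| dreams | book | q2 | 20 |
| reflections | book | q3 | 30 |
| how to build a house | book | q4 | 40 |
| wonderful life | music | q1 | 10 |
| million miles | music | q2 | 20 |
| run away | music | q3 | 30 |
| mind and body | music | q4 | 40 |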

The table above has four columns. The product and category columns are self-explanatory; the quarter and profit columns describe the profit contributed by each product in that quarter. The table records the profit of each product for each quarter, but that is difficult to grasp at first glance. This format of the table can be described as:

  1. stacked format: the individual observations for each product are stacked on top of each other.
  2. record format: each row is the record for a song or a book.
  3. long format: with a million music titles the table becomes very long, while the number of columns stays small.

If we want to summarize the quarterly profit for each product category in tabular format, the table below is more appropriate.

TABLE B

| category | q1 | q2 | q3 | q4 |
|---|---|---|---|---|
| music | 10 | 20 | 30 | 40 |
| book | 10 | 20 | 30 | 40 |

The table above is much more intuitive than TABLE A, and this is what the pivot operation helps us achieve. Pivot takes the unique values of one or more columns and turns each unique value into a new column named after that value. In our example, q1 to q4 are the unique values of the quarter column, so a new column is created for each quarter; these newly added columns are called pivot columns. Spark provides a pivot function for DataFrames (applied after groupBy) to perform this transformation. A pivot needs four parameters, which are as follows:

  1. Pivot column: the column whose unique values become the pivot columns. In our example, quarter is the pivot column.
  2. Value column: the column whose values are aggregated and mapped to the index column. We aggregate profit in this example, so it is the value column.
  3. Index column: the column you want to use as the index (the row labels) for the pivot columns. In this example, category is the index column.
  4. Aggregation function: the function applied when there is more than one row for a combination of index and pivot values. We use sum in this example: if there is more than one row for the book category in q4, we sum their profits. You can change the aggregation function depending on the question you are trying to answer; for instance, an average function gives the average profit of the books sold in each quarter.
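
To make the role of the four parameters concrete, here is a minimal sketch in plain Python (no Spark or pandas). The helper name `pivot` and its signature are illustrative assumptions, not part of any library; the data is TABLE A.

```python
from collections import defaultdict

COLUMNS = ("product", "category", "quarter", "profit")
DATA = [
    ("memories", "book", "q1", 10),
    ("dreams", "book", "q2", 20),
    ("reflections", "book", "q3", 30),
    ("how to build a house", "book", "q4", 40),
    ("wonderful life", "music", "q1", 10),
    ("million miles", "music", "q2", 20),
    ("run away", "music", "q3", 30),
    ("mind and body", "music", "q4", 40),
]
rows = [dict(zip(COLUMNS, record)) for record in DATA]


def pivot(rows, index, pivot_col, value, aggfunc=sum):
    """Illustrative helper: collect values per (index, pivot) cell, then aggregate."""
    cells = defaultdict(list)
    for row in rows:
        cells[(row[index], row[pivot_col])].append(row[value])

    table = defaultdict(dict)
    for (index_value, pivot_value), values in cells.items():
        # The aggregation function decides what happens when a cell has several rows
        table[index_value][pivot_value] = aggfunc(values)
    return dict(table)


table_b = pivot(rows, index="category", pivot_col="quarter", value="profit", aggfunc=sum)
for category, quarters in table_b.items():
    print(category, quarters)
```

Passing a different `aggfunc` (for example a function that computes the mean of a list) changes only the last step, which is why the aggregation function is listed as a separate parameter.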

#Pivot in Spark

Let's try a code example. Below is the PySpark implementation of the pivot transformation:

from pyspark.sql import SQLContext

sqlContext = SQLContext(sc)

# create RDD for products
data = sc.parallelize([
['memories','book','q1',10],
['dreams','book','q2',20],
['reflections','book','q3',30],
['how to build a house','book','q4',40],
['wonderful life','music','q1',10],
['million miles','music','q2',20],
['run away','music','q3',30],
['mind and body','music','q4',40],
])

# convert the RDD to DataFrame
df_products = sqlContext.createDataFrame(data, ['product','category','quarter','profit'])

# index column : category
# value column : profit
# pivot column : quarter
# agg function : sum
#
# apply pivot on the DataFrame
df_products.groupBy('category').pivot('quarter').sum('profit').show()

We first create an RDD of the products and then create a DataFrame from that RDD. In the next step the pivot transformation is applied. As you might have noticed, the four parameters are not passed to the pivot function together: pivot itself takes the name of the pivot column, while the index column goes to groupBy and the value column goes to the aggregation function (sum here). The result of the above code snippet is TABLE B.

#Pivot in Pandas

A pivot method is also available in the pandas library and takes the same four parameters described above; the difference is that in pandas a single method call receives all four. The code below first converts the Spark DataFrame to a pandas DataFrame and then applies the pivot_table function to it. The resulting DataFrame looks like TABLE B.

# convert from Spark DataFrame to pandas DataFrame
df_products_pd = df_products.toPandas()

# apply pivot on the pandas DataFrame
df_products_pd.pivot_table(index='category',
                           columns='quarter',
                           values='profit',
                           aggfunc='sum')
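
The snippet above depends on a running Spark session. For readers who only have pandas, the following is a self-contained sketch of the same pivot that builds the DataFrame directly from TABLE A. The variable names are my own, and the second call, which uses the mean instead of the sum, is only there to illustrate swapping the aggregation function.

```python
import pandas as pd

df = pd.DataFrame(
    [
        ("memories", "book", "q1", 10),
        ("dreams", "book", "q2", 20),
        ("reflections", "book", "q3", 30),
        ("how to build a house", "book", "q4", 40),
        ("wonderful life", "music", "q1", 10),
        ("million miles", "music", "q2", 20),
        ("run away", "music", "q3", 30),
        ("mind and body", "music", "q4", 40),
    ],
    columns=["product", "category", "quarter", "profit"],
)

# index column: category, pivot column: quarter, value column: profit, aggregation: sum
table_b = df.pivot_table(index="category", columns="quarter",
                         values="profit", aggfunc="sum")
print(table_b)

# Same reshape, different question: mean profit per category and quarter
mean_profit = df.pivot_table(index="category", columns="quarter",
                             values="profit", aggfunc="mean")
print(mean_profit)
```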

#Common Mistakes in Pivot operation

A common mistake when using the pivot transformation is to apply it to a column with numeric values, whereas pivot is meant to be used on a column with categorical values: every distinct value of the pivot column becomes its own column. In our example, we applied it to the quarter column, whose values are all categorical.
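
One way to see the problem is to list the columns a pivot would create for a given pivot column. The sketch below does this in plain Python for TABLE A; `pivot_column_names` is a hypothetical helper, and the extra row with profit 15 is made up purely for illustration.

```python
COLUMNS = ("product", "category", "quarter", "profit")
DATA = [
    ("memories", "book", "q1", 10),
    ("dreams", "book", "q2", 20),
    ("reflections", "book", "q3", 30),
    ("how to build a house", "book", "q4", 40),
    ("wonderful life", "music", "q1", 10),
    ("million miles", "music", "q2", 20),
    ("run away", "music", "q3", 30),
    ("mind and body", "music", "q4", 40),
]
rows = [dict(zip(COLUMNS, record)) for record in DATA]


def pivot_column_names(rows, pivot_col):
    """Hypothetical helper: a pivot creates one column per distinct value."""
    return sorted({row[pivot_col] for row in rows})


quarter_columns = pivot_column_names(rows, "quarter")
profit_columns = pivot_column_names(rows, "profit")
print(quarter_columns)  # categorical: the column names are meaningful labels
print(profit_columns)   # numeric: the column names are raw profit figures

# Made-up extra sale (not in TABLE A): a new profit figure would add a column,
# while the set of quarters stays the same.
extra_row = {"product": "hypothetical title", "category": "book",
             "quarter": "q1", "profit": 15}
rows_plus = rows + [extra_row]
print(len(pivot_column_names(rows_plus, "quarter")))
print(len(pivot_column_names(rows_plus, "profit")))
```

With a categorical pivot column the set of output columns is small and stable; with a numeric one it grows with every new distinct value.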

#Melt operation on data

The melt transformation is the opposite of the pivot transformation: it converts data from wide (unstacked) format to stacked (long) format. The pivot operation gives you a quick overview (tabular view) of the data, which is good for human analysis. But while tables in wide format are a nice way to summarize data, they make complex operations such as grouping difficult, and that is where the melt operation helps. Let's take the example of the pewforum.org income data for various religious groups in the US. A tabular view of the data is as follows:

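| religion | <$10k | $10-20k | $20-30k | $30-40k | $40-50k | $50-75k | $75-100k | $100-150k | >150k |
|---|---|---|---|---|---|---|---|---|---|
| Agnostic | 27 | 34 | 60 | 81 | 76 | 137 | 122 | 109 | 84 |
| Atheist | 12 | 27 | 37 | 52 | 35 | 70 | 73 | 59 | 74 |
| Buddhist | 27 | 21 | 30 | 34 | 33 | 58 | 62 | 39 | 53 |
| Catholic | 418 | 617 | 732 | 670 | 638 | 1116 | 949 | 792 | 633 |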

This table gives a good summary of the income distribution of each religious group, but it would be better to have a table with three columns: religion, income, and the number of individuals in that income bracket. That format lets us do more complex analysis, such as grouping the table by income category and finding which religious traditions are present in, or missing from, a particular income category.
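
Before turning to Spark, here is a minimal sketch of what melt does, written in plain Python. The function `melt` and its parameter names are illustrative assumptions (chosen to mirror the usual id/value naming), not a library API; the data is the income table above.

```python
COLUMNS = ["religion", "<$10k", "$10-20k", "$20-30k", "$30-40k",
           "$40-50k", "$50-75k", "$75-100k", "$100-150k", ">150k"]
DATA = [
    ["Agnostic", 27, 34, 60, 81, 76, 137, 122, 109, 84],
    ["Atheist", 12, 27, 37, 52, 35, 70, 73, 59, 74],
    ["Buddhist", 27, 21, 30, 34, 33, 58, 62, 39, 53],
    ["Catholic", 418, 617, 732, 670, 638, 1116, 949, 792, 633],
]
wide_rows = [dict(zip(COLUMNS, record)) for record in DATA]


def melt(rows, id_vars, value_vars, var_name="variable", value_name="value"):
    """Illustrative helper: emit one long row per (input row, value column) pair."""
    long_rows = []
    for row in rows:
        for column in value_vars:
            long_row = {key: row[key] for key in id_vars}  # identifiers are repeated
            long_row[var_name] = column                    # former column name
            long_row[value_name] = row[column]             # former cell value
            long_rows.append(long_row)
    return long_rows


long_rows = melt(wide_rows, ["religion"], COLUMNS[1:], "income", "count")
print(len(long_rows))  # 4 religions x 9 income brackets
print(long_rows[0])
```

Each wide row with nine income columns becomes nine long rows, so the table gets longer and narrower while carrying the same information.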

#Melt in Spark

Older versions of Spark do not provide a melt API (releases from 3.4 onward add DataFrame.unpivot, with melt as an alias), but it is not that difficult to build this operation ourselves. Below is the code that implements the melt function.

from typing import Iterable

from pyspark.sql import DataFrame
from pyspark.sql.functions import array, col, explode, lit, struct

def melt_df(
        df: DataFrame,
        id_vars: Iterable[str], value_vars: Iterable[str],
        var_name: str = "variable", value_name: str = "value") -> DataFrame:
    """Convert :class:`DataFrame` from wide to long format."""

    # Create array<struct<variable: str, value: ...>>
    _vars_and_vals = array(*(
        struct(lit(c).alias(var_name), col(c).alias(value_name))
        for c in value_vars))

    # Add to the DataFrame and explode: one row per (variable, value) pair
    _tmp = df.withColumn("_vars_and_vals", explode(_vars_and_vals))

    # Keep the id columns and unpack the struct fields into two columns;
    # list() lets id_vars be any iterable, not only a list
    cols = list(id_vars) + [
        col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return _tmp.select(*cols)

Let's apply the melt_df function to the religious groups income dataset. The melt_df function takes five parameters:

  1. df: the DataFrame on which the operation is carried out.
  2. id_vars: the list of columns that identify each row and are kept as they are; the melted values are matched to them. In our example religion is the only id_vars column, as we want to map it to the various income classes.
  3. value_vars: the columns from which the actual values are extracted; each of these columns is turned into rows.
  4. var_name: the name of the variable column in the resulting DataFrame, which holds the original column names.
  5. value_name: the name of the value column in the resulting DataFrame.

from pyspark.sql import SQLContext

sqlContext = SQLContext(sc)

# create RDD for the religious groups income data
data = sc.parallelize([
['Agnostic' ,27,34,60,81,76,137,122,109,84],
['Atheist',12,27,37,52,35,70,73,59,74],
['Buddhist',27,21,30,34,33,58,62,39,53],
['Catholic',418,617,732,670,638,1116,949,792,633],
])

# column headers
table_columns = ["religion","<$10k","$10-20k","$20-30k","$30-40k",
"$40-50k","$50-75k","$75-100k","$100-150k",">150k"]

# create the DataFrame
df_rel_trad = sqlContext.createDataFrame(data, table_columns)

df_rel = melt_df(df_rel_trad, ['religion'], table_columns[1:10], 'income', 'count')

df_rel.show()

The above code gives the output below (only an excerpt of the rows is shown):

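| religion | income | count |
|---|---|---|
| Agnostic | <$10k | 27 |
| Agnostic | $10-20k | 34 |
| Agnostic | $75-100k | 122 |
| Agnostic | >150k | 84 |
| Atheist | <$10k | 12 |
| Atheist | $10-20k | 27 |
| Atheist | $100-150k | 59 |
| Atheist | >150k | 74 |
| Buddhist | <$10k | 27 |
| Buddhist | $10-20k | 21 |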
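
The long format makes the grouping mentioned earlier straightforward. As a rough sketch in plain Python, the snippet below groups the ten excerpt rows shown above by income bracket; because it uses only the excerpt, the totals are not totals for the full dataset. The variable names are my own.

```python
from collections import defaultdict

# The excerpt rows from the table above: (religion, income, count)
long_rows = [
    ("Agnostic", "<$10k", 27),
    ("Agnostic", "$10-20k", 34),
    ("Agnostic", "$75-100k", 122),
    ("Agnostic", ">150k", 84),
    ("Atheist", "<$10k", 12),
    ("Atheist", "$10-20k", 27),
    ("Atheist", "$100-150k", 59),
    ("Atheist", ">150k", 74),
    ("Buddhist", "<$10k", 27),
    ("Buddhist", "$10-20k", 21),
]

count_by_income = defaultdict(int)      # total individuals per income bracket
religions_by_income = defaultdict(set)  # which religions appear per bracket

for religion, income, count in long_rows:
    count_by_income[income] += count
    religions_by_income[income].add(religion)

for income in count_by_income:
    print(income, count_by_income[income], sorted(religions_by_income[income]))
```

In the wide format the same question would require iterating over column names; in the long format income is ordinary data that can be grouped and filtered.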

#Melt in pandas

pandas provides a melt method, shown in the snippet below; its parameters are the same as explained previously. In the example below, melt is applied to a pandas DataFrame that stores the income data (here with differently named income columns), and the result has the same long format as in the previous example.

# religious_df is the dataframe which stores above table
In [1]: value_variables = ['Less than $30000','$30000-$49999','$50000-$99999','$100000 or more']

In [2]: religious_df = religious_df.melt(id_vars=['Religious tradition'], value_vars=value_variables)

In [3]: religious_df[religious_df['Religious tradition'] == 'Buddhist']

Out[3]:
Religious tradition variable value
0 Buddhist Less than $30000 36
12 Buddhist $30000-$49999 18
24 Buddhist $50000-$99999 32
36 Buddhist $100000 or more 13
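
The session above assumes an existing religious_df with its own column names. As a self-contained sketch, the snippet below applies pandas melt to the four-row income table used in the Spark example; the variable names are my own choices.

```python
import pandas as pd

columns = ["religion", "<$10k", "$10-20k", "$20-30k", "$30-40k",
           "$40-50k", "$50-75k", "$75-100k", "$100-150k", ">150k"]
income_wide = pd.DataFrame(
    [
        ["Agnostic", 27, 34, 60, 81, 76, 137, 122, 109, 84],
        ["Atheist", 12, 27, 37, 52, 35, 70, 73, 59, 74],
        ["Buddhist", 27, 21, 30, 34, 33, 58, 62, 39, 53],
        ["Catholic", 418, 617, 732, 670, 638, 1116, 949, 792, 633],
    ],
    columns=columns,
)

# religion identifies each row; every income column is turned into rows
income_long = income_wide.melt(
    id_vars=["religion"],
    value_vars=columns[1:],
    var_name="income",
    value_name="count",
)

buddhist = income_long[income_long["religion"] == "Buddhist"]
print(buddhist)
```

Passing var_name and value_name replaces the default column names variable and value that appear in the session output above.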

#Conclusion

We just saw how to implement the pivot and melt transformations, which reshape a DataFrame. To summarize what we have learnt: the pivot operation quickly summarizes a table in a tabular format that is easy for humans to analyze. In contrast, data in that wide tabular format is not very helpful for complex analysis, and this is where the melt operation comes in, converting the table from wide format to long format. There is quite a bit to say about which format is suitable in which situation; that is the topic of this post, where I discuss various heuristics for tidying data.

#Useful Links

  1. Spark Pivot API docs
  2. Databricks blog on pivot
  3. Stack Overflow answer on the PySpark melt operation