Details
- Type: Bug
- Status: Open
- Priority: Minor
- Resolution: Unresolved
- Affects Version/s: 3.1.1
- Fix Version/s: None
- Component/s: None
Description
I originally posted a question on SO because I thought perhaps I was doing something wrong:
https://stackoverflow.com/questions/69308560
Perhaps I am, but I'm now fairly convinced that there's something wonky in the implementation of `first` that gives it unnecessarily worse complexity than `last`.
More or less copy-pasted from the above post:
I was working on a pyspark routine to interpolate the missing values in a configuration table.
Imagine a table of configuration values indexed by a rank that runs from 0 to 50,000. The user specifies a few data points in between (say at ranks 0, 50, 100, 500, 2000, 50000) and we interpolate the remainder. My solution follows this blog post quite closely, except that I'm not using any UDFs.
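To make the setup concrete, here is a rough sketch of the shape of the two inputs used in the snippet below (toy data with made-up values and only a handful of ranks; the real tables come out to ~300,006 rows):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Template of every (PORT_TYPE, loss_process, rank) combination that needs a
# scale factor (toy ranks 0..10 here instead of 0..50,000).
df_scale_factors_template = spark.createDataFrame(
    [('A', 'proc1', r) for r in range(11)],
    ['PORT_TYPE', 'loss_process', 'rank'])

# The handful of points the user actually supplied; everything else gets interpolated.
df_users_config = spark.createDataFrame(
    [('A', 'proc1', 0, 1.0), ('A', 'proc1', 5, 2.0), ('A', 'proc1', 10, 4.0)],
    ['PORT_TYPE', 'loss_process', 'rank', 'scale_factor'])
```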
While troubleshooting the performance of this routine (it takes ~3 minutes), I found that one particular window function is taking nearly all of that time, while everything else I'm doing takes mere seconds.
Here is the main area of interest - where I use window functions to fill in the previous and next user-supplied configuration values:
```python
from pyspark.sql import Window, functions as F

# Create partition windows that are required to generate new rows from the ones provided
win_last = Window.partitionBy('PORT_TYPE', 'loss_process').orderBy('rank').rowsBetween(Window.unboundedPreceding, 0)
win_next = Window.partitionBy('PORT_TYPE', 'loss_process').orderBy('rank').rowsBetween(0, Window.unboundedFollowing)

# Join back in the provided config table to populate the "known" scale factors
df_part1 = (df_scale_factors_template
    .join(df_users_config, ['PORT_TYPE', 'loss_process', 'rank'], 'leftouter')
    # Add computed columns that can look up the prior config and next config for each missing value
    .withColumn('last_rank', F.last(F.col('rank'), ignorenulls=True).over(win_last))
    .withColumn('last_sf', F.last(F.col('scale_factor'), ignorenulls=True).over(win_last))
).cache()
debug_log_dataframe(df_part1, 'df_part1')  # Force a .count() and time Part 1

df_part2 = (df_part1
    .withColumn('next_rank', F.first(F.col('rank'), ignorenulls=True).over(win_next))
    .withColumn('next_sf', F.first(F.col('scale_factor'), ignorenulls=True).over(win_next))
).cache()
debug_log_dataframe(df_part2, 'df_part2')  # Force a .count() and time Part 2

df_part3 = (df_part2
    # Implements standard linear interpolation: y = y1 + ((y2-y1)/(x2-x1)) * (x-x1)
    .withColumn('scale_factor',
        F.when(F.col('last_rank') == F.col('next_rank'), F.col('last_sf'))  # Handle div/0 case
         .otherwise(F.col('last_sf')
                    + ((F.col('next_sf') - F.col('last_sf')) / (F.col('next_rank') - F.col('last_rank')))
                    * (F.col('rank') - F.col('last_rank'))))
    .select('PORT_TYPE', 'loss_process', 'rank', 'scale_factor')
).cache()
debug_log_dataframe(df_part3, 'df_part3', explain=True)
```
The above used to be a single chained dataframe statement, but I've since split it into 3 parts so that I could isolate the part that's taking so long. The results are:
- Part 1: Generated 8 columns and 300006 rows in 0.65 seconds
- Part 2: Generated 10 columns and 300006 rows in 189.55 seconds
- Part 3: Generated 4 columns and 300006 rows in 0.24 seconds
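(For reference, `debug_log_dataframe` above is just a small local helper that forces the cached DataFrame to materialize and logs the timing; roughly along these lines, though the exact implementation doesn't matter:)

```python
import time

def debug_log_dataframe(df, name, explain=False):
    # Force evaluation of the (cached) DataFrame and log how long it took.
    start = time.time()
    row_count = df.count()
    elapsed = time.time() - start
    print(f'{name}: Generated {len(df.columns)} columns and {row_count} rows '
          f'in {elapsed:.2f} seconds')
    if explain:
        df.explain()
```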
In trying various things to speed up my routine, it occurred to me to try rewriting my usages of first() to just be usages of last() with a reversed sort order.
So rewriting this:
```python
win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
            .orderBy('rank').rowsBetween(0, Window.unboundedFollowing))

df_part2 = (df_part1
    .withColumn('next_rank', F.first(F.col('rank'), ignorenulls=True).over(win_next))
    .withColumn('next_sf', F.first(F.col('scale_factor'), ignorenulls=True).over(win_next))
)
```
As this:
```python
win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
            .orderBy(F.desc('rank')).rowsBetween(Window.unboundedPreceding, 0))

df_part2 = (df_part1
    .withColumn('next_rank', F.last(F.col('rank'), ignorenulls=True).over(win_next))
    .withColumn('next_sf', F.last(F.col('scale_factor'), ignorenulls=True).over(win_next))
)
```
Much to my amazement, this actually solved the performance problem, and now the entire dataframe is generated in just 3 seconds.
I don't know anything about the internals, but conceptually I feel as though the initial solution should be at least as fast, because all 4 columns should be able to take advantage of the same window and sort order by merely looking forwards or backwards along the window; re-sorting like this shouldn't be necessary.
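For what it's worth, a minimal self-contained sketch along these lines should reproduce the asymmetry without my data (a synthetic single-group table with made-up values; exact timings will obviously vary with the environment):

```python
import time
from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()

# Synthetic stand-in for the config table: a single group with 300k ranks,
# where scale_factor is only known on every 1,000th row (null elsewhere).
df = (spark.range(300000).withColumnRenamed('id', 'rank')
      .withColumn('PORT_TYPE', F.lit('A'))
      .withColumn('loss_process', F.lit('proc1'))
      .withColumn('scale_factor',
                  F.when(F.col('rank') % 1000 == 0, F.col('rank').cast('double'))))

win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
            .orderBy('rank').rowsBetween(0, Window.unboundedFollowing))
win_next_rev = (Window.partitionBy('PORT_TYPE', 'loss_process')
                .orderBy(F.desc('rank')).rowsBetween(Window.unboundedPreceding, 0))

def time_variant(label, frame):
    # Aggregate over the window column so it can't be pruned away.
    start = time.time()
    frame.agg(F.max('next_sf')).collect()
    print(f'{label}: {time.time() - start:.2f} s')

# Slow: first() with ignorenulls over an unbounded-following frame
time_variant('first / unboundedFollowing',
             df.withColumn('next_sf', F.first('scale_factor', ignorenulls=True).over(win_next)))

# Fast: last() with ignorenulls over the reversed sort order
time_variant('last / reversed order',
             df.withColumn('next_sf', F.last('scale_factor', ignorenulls=True).over(win_next_rev)))
```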