from pyspark.sql import Window
from pyspark.sql.functions import nth_value
df = self.spark.createDataFrame(
    [
        ("a", 0, None),
        ("a", 1, "x"),
        ("a", 2, "y"),
        ("a", 3, "z"),
        ("a", 4, None),
        ("b", 1, None),
        ("b", 2, None),
    ],
    schema=("key", "order", "value"),
)
# orderBy gives a running frame (partition start to current row),
# so nth_value is evaluated over a growing window at each row.
w = Window.partitionBy("key").orderBy("order")
rs = df.select(
    df.key,
    df.order,
    nth_value("value", 2).over(w),  # default: nulls are counted
    nth_value("value", 2, False).over(w),  # ignoreNulls=False, same as default
    nth_value("value", 2, True).over(w),  # ignoreNulls=True: skip nulls
).collect()
# With ignoreNulls=True, the second non-null value ("y") only appears
# once the frame has seen two non-null values.
expected = [
    ("a", 0, None, None, None),
    ("a", 1, "x", "x", None),
    ("a", 2, "x", "x", "y"),
    ("a", 3, "x", "x", "y"),
    ("a", 4, "x", "x", "y"),
    ("b", 1, None, None, None),
    ("b", 2, None, None, None),
]
for r, ex in zip(sorted(rs), sorted(expected)):
    self.assertEqual(tuple(r), ex[: len(r)])
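
# For reference, the same null-skipping behavior can be expressed in
# Spark SQL, where ignoreNulls=True corresponds to the IGNORE NULLS
# clause (a sketch assuming Spark 3.1+ and the df defined above; this
# is not part of the original assertions).
from pyspark.sql.functions import expr

rs_sql = df.select(
    df.key,
    df.order,
    expr(
        "nth_value(value, 2) IGNORE NULLS "
        "OVER (PARTITION BY key ORDER BY `order`)"
    ),
).collect()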