I am given a data set with two integer columns: tr_id (the traveler's id, unique per traveler) and train_id (the train ride's id, unique per ride). I am asked to find the pairs of travelers who have been on more than 3 rides together.
So the output should look like:
tr_id 1 | tr_id 2 | number of rides together
48 | 54 | 6
My thought process: first remove all travelers who have been on three rides or fewer.
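import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
// Count rides per traveler and keep only travelers with more than 3 rides.
val window = Window.partitionBy("tr_id")
val moreThanThreeRides_df = flight_df.withColumn("count", count("train_id").over(window))
  .filter(col("count") > 3).drop("count").orderBy("tr_id", "train_id")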
Then I would collect the set of all train_id values for each tr_id and map one to the other.
val collectIds_df = moreThanThreeRides_df.groupBy("tr_id")
.agg(collect_set(struct("train_id")).as("Set of Rides")).orderBy("tr_id")
collectIds_df.withColumn("mapToRides", map(col("tr_id"),col("Set of Rides")))
How I would proceed: intersect the set of rides of one traveler with that of another traveler, check whether the intersection contains more than 3 elements, and if so return the two tr_id values and the number of elements. This would require a nested for loop, but the inner loop needs fewer iterations because the order does not matter: (tr_id1, tr_id2) is the same pair as (tr_id2, tr_id1). I have tried to write this in Scala, but without success. How do I write this?
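To make the idea concrete, here is a minimal sketch of that pairwise intersection in plain Scala collections (not Spark), assuming the data is small enough to fit in memory. The sample rows and the names rides, ridesByTraveler and pairsTogether are made up for illustration.

```scala
// Hypothetical sample data: (tr_id, train_id) rows.
val rides: Seq[(Int, Int)] = Seq(
  (48, 1), (48, 2), (48, 3), (48, 4),
  (54, 1), (54, 2), (54, 3), (54, 4),
  (60, 1), (60, 2)
)

// Set of train_id values for each tr_id.
val ridesByTraveler: Map[Int, Set[Int]] =
  rides.groupBy(_._1).map { case (traveler, rows) => traveler -> rows.map(_._2).toSet }

// Sorted ids, so that each unordered pair is visited exactly once (index i < index j).
val travelers: Vector[Int] = ridesByTraveler.keys.toVector.sorted

// Nested loop: the inner index starts at i + 1, which skips (b, a) once (a, b) was seen.
val pairsTogether: Seq[(Int, Int, Int)] =
  for {
    i <- travelers.indices
    j <- (i + 1) until travelers.size
    a = travelers(i)
    b = travelers(j)
    shared = ridesByTraveler(a).intersect(ridesByTraveler(b)).size
    if shared > 3
  } yield (a, b, shared)

pairsTogether.foreach { case (a, b, n) => println(s"$a | $b | $n") }
```

This compares every pair of travelers, so the work grows quadratically with the number of travelers; it only illustrates the logic and is not meant for a large data set.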
If there is a better or more efficient way of solving this, I am open to that as well. Any help is appreciated!
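One alternative that avoids explicit loops is a self-join on train_id, sketched below under a few assumptions: a local SparkSession, each (tr_id, train_id) row appearing at most once, and made-up sample data. The names rides_df, together_df, tr_id_1, tr_id_2 and rides_together are hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count}

// Assumption: a local session, for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("rides-together").getOrCreate()
import spark.implicits._

// Hypothetical sample data with the two columns from the question.
val rides_df = Seq(
  (48, 1), (48, 2), (48, 3), (48, 4),
  (54, 1), (54, 2), (54, 3), (54, 4),
  (60, 1), (60, 2)
).toDF("tr_id", "train_id")

// Two views of the same data, with the traveler column renamed on each side.
val left  = rides_df.select(col("tr_id").as("tr_id_1"), col("train_id"))
val right = rides_df.select(col("tr_id").as("tr_id_2"), col("train_id"))

val together_df = left
  .join(right, Seq("train_id"))             // pair up travelers on the same ride
  .filter(col("tr_id_1") < col("tr_id_2"))  // keep each unordered pair once, drop self-pairs
  .groupBy("tr_id_1", "tr_id_2")
  .agg(count("train_id").as("rides_together"))
  .filter(col("rides_together") > 3)
  .orderBy("tr_id_1", "tr_id_2")

together_df.show()
```

The condition tr_id_1 < tr_id_2 plays the role of the shortened inner loop. Filtering out travelers with three rides or fewer beforehand is then an optional optimisation rather than a required step, since such travelers can never appear in a pair with more than 3 shared rides.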