diff --git a/follow-recommendations-service/common/src/main/scala/com/twitter/follow_recommendations/common/rankers/weighted_candidate_source_ranker/WeightedCandidateSourceRanker.scala b/follow-recommendations-service/common/src/main/scala/com/twitter/follow_recommendations/common/rankers/weighted_candidate_source_ranker/WeightedCandidateSourceRanker.scala index c6f55adbc..58ef8d33a 100644 --- a/follow-recommendations-service/common/src/main/scala/com/twitter/follow_recommendations/common/rankers/weighted_candidate_source_ranker/WeightedCandidateSourceRanker.scala +++ b/follow-recommendations-service/common/src/main/scala/com/twitter/follow_recommendations/common/rankers/weighted_candidate_source_ranker/WeightedCandidateSourceRanker.scala @@ -1,4 +1,5 @@ package com.twitter.follow_recommendations.common.rankers.weighted_candidate_source_ranker + import com.twitter.follow_recommendations.common.base.Ranker import com.twitter.follow_recommendations.common.models.CandidateUser import com.twitter.follow_recommendations.common.rankers.common.DedupCandidates @@ -19,65 +20,42 @@ import com.twitter.timelines.configapi.HasParams * @param shuffleFn the shuffle function that will be used to shuffle each algorithm's sorted candidate list. * @param dedup whether to remove duplicated candidates from the final output. */ -class WeightedCandidateSourceRanker[Target <: HasParams]( - basedRanker: WeightedCandidateSourceBaseRanker[ - CandidateSourceIdentifier, - CandidateUser - ], - shuffleFn: Seq[CandidateUser] => Seq[CandidateUser], - dedup: Boolean) - extends Ranker[Target, CandidateUser] { - val name: String = this.getClass.getSimpleName + +class WeightedCandidateSourceRanker( + basedRanker: WeightedCandidateSourceBaseRanker[CandidateSourceIdentifier, CandidateUser], + shuffleFn: Seq[CandidateUser] => Seq[CandidateUser], + dedup: Boolean +) extends Ranker[Target, CandidateUser] { override def rank(target: Target, candidates: Seq[CandidateUser]): Stitch[Seq[CandidateUser]] = { - val scribeRankingInfo: Boolean = - target.params(WeightedCandidateSourceRankerParams.ScribeRankingInfoInWeightedRanker) - val rankedCands = rankCandidates(group(candidates)) - Stitch.value(if (scribeRankingInfo) Utils.addRankingInfo(rankedCands, name) else rankedCands) + val scribeRankingInfo = target.params(WeightedCandidateSourceRankerParams.ScribeRankingInfoInWeightedRanker) + val rankedCandidates = rankCandidates(group(candidates)) + Stitch.value(if (scribeRankingInfo) Utils.addRankingInfo(rankedCandidates, name) else rankedCandidates) } - private def group( - candidates: Seq[CandidateUser] - ): Map[CandidateSourceIdentifier, Seq[CandidateUser]] = { - val flattened = for { - candidate <- candidates - identifier <- candidate.getPrimaryCandidateSource - } yield (identifier, candidate) - flattened.groupBy(_._1).mapValues(_.map(_._2)) + private def group(candidates: Seq[CandidateUser]): Map[CandidateSourceIdentifier, Seq[CandidateUser]] = { + candidates.flatMap(_.getPrimaryCandidateSource.map(identifier => (identifier, Seq(_)))).toMap } - private def rankCandidates( - input: Map[CandidateSourceIdentifier, Seq[CandidateUser]] - ): Seq[CandidateUser] = { - // Sort and shuffle candidates per candidate source. - // Note 1: Using map instead mapValue here since mapValue somehow caused infinite loop when used as part of Stream. - val sortAndShuffledCandidates = input.map { - case (source, candidates) => - // Note 2: toList is required here since candidates is a view, and it will result in infinit loop when used as part of Stream. - // Note 3: there is no real sorting logic here, it assumes the input is already sorted by candidate sources - val sortedCandidates = candidates.toList - source -> shuffleFn(sortedCandidates).iterator - } + private def rankCandidates(input: Map[CandidateSourceIdentifier, Seq[CandidateUser]]): Seq[CandidateUser] = { + val sortAndShuffledCandidates = input.mapValues(shuffleFn compose (_.toList)).toSeq val rankedCandidates = basedRanker(sortAndShuffledCandidates) - if (dedup) DedupCandidates(rankedCandidates) else rankedCandidates } + + val name: String = getClass.getSimpleName } object WeightedCandidateSourceRanker { - - def build[Target <: HasParams]( + def build( candidateSourceWeight: Map[CandidateSourceIdentifier, Double], shuffleFn: Seq[CandidateUser] => Seq[CandidateUser] = identity, dedup: Boolean = false, randomSeed: Option[Long] = None - ): WeightedCandidateSourceRanker[Target] = { + ): WeightedCandidateSourceRanker = { new WeightedCandidateSourceRanker( - new WeightedCandidateSourceBaseRanker( - candidateSourceWeight, - WeightMethod.WeightedRandomSampling, - randomSeed = randomSeed), + new WeightedCandidateSourceBaseRanker(candidateSourceWeight, WeightMethod.WeightedRandomSampling, randomSeed), shuffleFn, dedup ) @@ -85,16 +63,11 @@ object WeightedCandidateSourceRanker { } object WeightedCandidateSourceRankerWithoutRandomSampling { - def build[Target <: HasParams]( - candidateSourceWeight: Map[CandidateSourceIdentifier, Double] - ): WeightedCandidateSourceRanker[Target] = { + def build(candidateSourceWeight: Map[CandidateSourceIdentifier, Double]): WeightedCandidateSourceRanker = { new WeightedCandidateSourceRanker( - new WeightedCandidateSourceBaseRanker( - candidateSourceWeight, - WeightMethod.WeightedRoundRobin, - randomSeed = None), + new WeightedCandidateSourceBaseRanker(candidateSourceWeight, WeightMethod.WeightedRoundRobin, None), identity, - false, + false ) } }