increase efficiency
This commit is contained in:
parent
138bb51997
commit
0b50156255
|
@ -1,4 +1,5 @@
|
|||
package com.twitter.follow_recommendations.common.rankers.weighted_candidate_source_ranker
|
||||
|
||||
import com.twitter.follow_recommendations.common.base.Ranker
|
||||
import com.twitter.follow_recommendations.common.models.CandidateUser
|
||||
import com.twitter.follow_recommendations.common.rankers.common.DedupCandidates
|
||||
|
@ -19,65 +20,42 @@ import com.twitter.timelines.configapi.HasParams
|
|||
* @param shuffleFn the shuffle function that will be used to shuffle each algorithm's sorted candidate list.
|
||||
* @param dedup whether to remove duplicated candidates from the final output.
|
||||
*/
|
||||
class WeightedCandidateSourceRanker[Target <: HasParams](
|
||||
basedRanker: WeightedCandidateSourceBaseRanker[
|
||||
CandidateSourceIdentifier,
|
||||
CandidateUser
|
||||
],
|
||||
shuffleFn: Seq[CandidateUser] => Seq[CandidateUser],
|
||||
dedup: Boolean)
|
||||
extends Ranker[Target, CandidateUser] {
|
||||
|
||||
val name: String = this.getClass.getSimpleName
|
||||
|
||||
class WeightedCandidateSourceRanker(
|
||||
basedRanker: WeightedCandidateSourceBaseRanker[CandidateSourceIdentifier, CandidateUser],
|
||||
shuffleFn: Seq[CandidateUser] => Seq[CandidateUser],
|
||||
dedup: Boolean
|
||||
) extends Ranker[Target, CandidateUser] {
|
||||
|
||||
override def rank(target: Target, candidates: Seq[CandidateUser]): Stitch[Seq[CandidateUser]] = {
|
||||
val scribeRankingInfo: Boolean =
|
||||
target.params(WeightedCandidateSourceRankerParams.ScribeRankingInfoInWeightedRanker)
|
||||
val rankedCands = rankCandidates(group(candidates))
|
||||
Stitch.value(if (scribeRankingInfo) Utils.addRankingInfo(rankedCands, name) else rankedCands)
|
||||
val scribeRankingInfo = target.params(WeightedCandidateSourceRankerParams.ScribeRankingInfoInWeightedRanker)
|
||||
val rankedCandidates = rankCandidates(group(candidates))
|
||||
Stitch.value(if (scribeRankingInfo) Utils.addRankingInfo(rankedCandidates, name) else rankedCandidates)
|
||||
}
|
||||
|
||||
private def group(
|
||||
candidates: Seq[CandidateUser]
|
||||
): Map[CandidateSourceIdentifier, Seq[CandidateUser]] = {
|
||||
val flattened = for {
|
||||
candidate <- candidates
|
||||
identifier <- candidate.getPrimaryCandidateSource
|
||||
} yield (identifier, candidate)
|
||||
flattened.groupBy(_._1).mapValues(_.map(_._2))
|
||||
private def group(candidates: Seq[CandidateUser]): Map[CandidateSourceIdentifier, Seq[CandidateUser]] = {
|
||||
candidates.flatMap(_.getPrimaryCandidateSource.map(identifier => (identifier, Seq(_)))).toMap
|
||||
}
|
||||
|
||||
private def rankCandidates(
|
||||
input: Map[CandidateSourceIdentifier, Seq[CandidateUser]]
|
||||
): Seq[CandidateUser] = {
|
||||
// Sort and shuffle candidates per candidate source.
|
||||
// Note 1: Using map instead mapValue here since mapValue somehow caused infinite loop when used as part of Stream.
|
||||
val sortAndShuffledCandidates = input.map {
|
||||
case (source, candidates) =>
|
||||
// Note 2: toList is required here since candidates is a view, and it will result in infinit loop when used as part of Stream.
|
||||
// Note 3: there is no real sorting logic here, it assumes the input is already sorted by candidate sources
|
||||
val sortedCandidates = candidates.toList
|
||||
source -> shuffleFn(sortedCandidates).iterator
|
||||
}
|
||||
private def rankCandidates(input: Map[CandidateSourceIdentifier, Seq[CandidateUser]]): Seq[CandidateUser] = {
|
||||
val sortAndShuffledCandidates = input.mapValues(shuffleFn compose (_.toList)).toSeq
|
||||
val rankedCandidates = basedRanker(sortAndShuffledCandidates)
|
||||
|
||||
if (dedup) DedupCandidates(rankedCandidates) else rankedCandidates
|
||||
}
|
||||
|
||||
val name: String = getClass.getSimpleName
|
||||
}
|
||||
|
||||
object WeightedCandidateSourceRanker {
|
||||
|
||||
def build[Target <: HasParams](
|
||||
def build(
|
||||
candidateSourceWeight: Map[CandidateSourceIdentifier, Double],
|
||||
shuffleFn: Seq[CandidateUser] => Seq[CandidateUser] = identity,
|
||||
dedup: Boolean = false,
|
||||
randomSeed: Option[Long] = None
|
||||
): WeightedCandidateSourceRanker[Target] = {
|
||||
): WeightedCandidateSourceRanker = {
|
||||
new WeightedCandidateSourceRanker(
|
||||
new WeightedCandidateSourceBaseRanker(
|
||||
candidateSourceWeight,
|
||||
WeightMethod.WeightedRandomSampling,
|
||||
randomSeed = randomSeed),
|
||||
new WeightedCandidateSourceBaseRanker(candidateSourceWeight, WeightMethod.WeightedRandomSampling, randomSeed),
|
||||
shuffleFn,
|
||||
dedup
|
||||
)
|
||||
|
@ -85,16 +63,11 @@ object WeightedCandidateSourceRanker {
|
|||
}
|
||||
|
||||
object WeightedCandidateSourceRankerWithoutRandomSampling {
|
||||
def build[Target <: HasParams](
|
||||
candidateSourceWeight: Map[CandidateSourceIdentifier, Double]
|
||||
): WeightedCandidateSourceRanker[Target] = {
|
||||
def build(candidateSourceWeight: Map[CandidateSourceIdentifier, Double]): WeightedCandidateSourceRanker = {
|
||||
new WeightedCandidateSourceRanker(
|
||||
new WeightedCandidateSourceBaseRanker(
|
||||
candidateSourceWeight,
|
||||
WeightMethod.WeightedRoundRobin,
|
||||
randomSeed = None),
|
||||
new WeightedCandidateSourceBaseRanker(candidateSourceWeight, WeightMethod.WeightedRoundRobin, None),
|
||||
identity,
|
||||
false,
|
||||
false
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue