[docx] split commit for file 2400
Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
This commit is contained in:
parent
b471ac86b4
commit
53c2f869f3
Binary file not shown.
|
@ -1,16 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.ClearCacheTimelineInstruction
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
case class ClearCacheInstructionBuilder[Query <: PipelineQuery](
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends UrtInstructionBuilder[Query, ClearCacheTimelineInstruction] {
|
||||
|
||||
override def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[ClearCacheTimelineInstruction] =
|
||||
if (includeInstruction(query, entries)) Seq(ClearCacheTimelineInstruction()) else Seq.empty
|
||||
}
|
Binary file not shown.
|
@ -1,33 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtPassThroughCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
case class PassThroughCursorBuilder[
|
||||
-Query <: PipelineQuery with HasPipelineCursor[UrtPassThroughCursor]
|
||||
](
|
||||
cursorFeature: Feature[Query, String],
|
||||
override val cursorType: CursorType)
|
||||
extends UrtCursorBuilder[Query] {
|
||||
|
||||
override val includeOperation: IncludeInstruction[Query] = { (query, _) =>
|
||||
query.features.exists(_.getOrElse(cursorFeature, "").nonEmpty)
|
||||
}
|
||||
|
||||
override def cursorValue(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): String =
|
||||
UrtCursorSerializer.serializeCursor(
|
||||
UrtPassThroughCursor(
|
||||
cursorSortIndex(query, entries),
|
||||
query.features.map(_.get(cursorFeature)).getOrElse(""),
|
||||
cursorType = Some(cursorType)
|
||||
)
|
||||
)
|
||||
}
|
Binary file not shown.
|
@ -1,31 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
trait IncludeInstruction[-Query <: PipelineQuery] { self =>
|
||||
def apply(query: Query, entries: Seq[TimelineEntry]): Boolean
|
||||
|
||||
def inverse(): IncludeInstruction[Query] = new IncludeInstruction[Query] {
|
||||
def apply(query: Query, entries: Seq[TimelineEntry]): Boolean = !self.apply(query, entries)
|
||||
}
|
||||
}
|
||||
|
||||
object AlwaysInclude extends IncludeInstruction[PipelineQuery] {
|
||||
override def apply(query: PipelineQuery, entries: Seq[TimelineEntry]): Boolean = true
|
||||
}
|
||||
|
||||
object IncludeOnFirstPage extends IncludeInstruction[PipelineQuery with HasPipelineCursor[_]] {
|
||||
override def apply(
|
||||
query: PipelineQuery with HasPipelineCursor[_],
|
||||
entries: Seq[TimelineEntry]
|
||||
): Boolean = query.isFirstPage
|
||||
}
|
||||
|
||||
object IncludeAfterFirstPage extends IncludeInstruction[PipelineQuery with HasPipelineCursor[_]] {
|
||||
override def apply(
|
||||
query: PipelineQuery with HasPipelineCursor[_],
|
||||
entries: Seq[TimelineEntry]
|
||||
): Boolean = !query.isFirstPage
|
||||
}
|
Binary file not shown.
|
@ -1,32 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.MarkEntriesUnreadInstruction
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.metadata.MarkUnreadableEntry
|
||||
|
||||
/**
|
||||
* Build a MarkUnreadEntries instruction
|
||||
*
|
||||
* Note that this implementation currently supports top-level entries, but not module item entries.
|
||||
*/
|
||||
case class MarkUnreadInstructionBuilder[Query <: PipelineQuery](
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends UrtInstructionBuilder[Query, MarkEntriesUnreadInstruction] {
|
||||
|
||||
override def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[MarkEntriesUnreadInstruction] = {
|
||||
if (includeInstruction(query, entries)) {
|
||||
val filteredEntries = entries.collect {
|
||||
case entry: MarkUnreadableEntry if entry.isMarkUnread.contains(true) =>
|
||||
entry.entryIdentifier
|
||||
}
|
||||
if (filteredEntries.nonEmpty) Seq(MarkEntriesUnreadInstruction(filteredEntries))
|
||||
else Seq.empty
|
||||
} else {
|
||||
Seq.empty
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,45 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtOrderedCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.BottomCursor
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineCursorSerializer
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
/**
|
||||
* Builds [[UrtOrderedCursor]] in the Bottom position
|
||||
*
|
||||
* @param idSelector Specifies the entry from which to derive the `id` field
|
||||
* @param includeOperation Logic to determine whether or not to build the bottom cursor, which only
|
||||
* applies if gap cursors are required (e.g. Home Latest). When applicable,
|
||||
* this logic should always be the inverse of the logic used to decide
|
||||
* whether or not to build the gap cursor via [[OrderedGapCursorBuilder]],
|
||||
* since either the gap or the bottom cursor must always be returned.
|
||||
* @param serializer Converts the cursor to an encoded string
|
||||
*/
|
||||
case class OrderedBottomCursorBuilder[
|
||||
-Query <: PipelineQuery with HasPipelineCursor[UrtOrderedCursor]
|
||||
](
|
||||
idSelector: PartialFunction[TimelineEntry, Long],
|
||||
override val includeOperation: IncludeInstruction[Query] = AlwaysInclude,
|
||||
serializer: PipelineCursorSerializer[UrtOrderedCursor] = UrtCursorSerializer)
|
||||
extends UrtCursorBuilder[Query] {
|
||||
override val cursorType: CursorType = BottomCursor
|
||||
|
||||
override def cursorValue(query: Query, timelineEntries: Seq[TimelineEntry]): String = {
|
||||
val bottomId = timelineEntries.reverseIterator.collectFirst(idSelector)
|
||||
|
||||
val id = bottomId.orElse(query.pipelineCursor.flatMap(_.id))
|
||||
|
||||
val cursor = UrtOrderedCursor(
|
||||
initialSortIndex = nextBottomInitialSortIndex(query, timelineEntries),
|
||||
id = id,
|
||||
cursorType = Some(cursorType)
|
||||
)
|
||||
|
||||
serializer.serializeCursor(cursor)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,54 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtOrderedCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.GapCursor
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineCursorSerializer
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
/**
|
||||
* Builds [[UrtOrderedCursor]] in the Bottom position as a Gap cursor.
|
||||
*
|
||||
* @param idSelector Specifies the entry from which to derive the `id` field
|
||||
* @param includeOperation Logic to determine whether or not to build the gap cursor, which should
|
||||
* always be the inverse of the logic used to decide whether or not to build
|
||||
* the bottom cursor via [[OrderedBottomCursorBuilder]], since either the
|
||||
* gap or the bottom cursor must always be returned.
|
||||
* @param serializer Converts the cursor to an encoded string
|
||||
*/
|
||||
case class OrderedGapCursorBuilder[
|
||||
-Query <: PipelineQuery with HasPipelineCursor[UrtOrderedCursor]
|
||||
](
|
||||
idSelector: PartialFunction[TimelineEntry, Long],
|
||||
override val includeOperation: IncludeInstruction[Query],
|
||||
serializer: PipelineCursorSerializer[UrtOrderedCursor] = UrtCursorSerializer)
|
||||
extends UrtCursorBuilder[Query] {
|
||||
override val cursorType: CursorType = GapCursor
|
||||
|
||||
override def cursorValue(
|
||||
query: Query,
|
||||
timelineEntries: Seq[TimelineEntry]
|
||||
): String = {
|
||||
// To determine the gap boundary, use any existing cursor gap boundary id (i.e. if submitted
|
||||
// from a previous gap cursor, else use the existing cursor id (i.e. from a previous top cursor)
|
||||
val gapBoundaryId = query.pipelineCursor.flatMap(_.gapBoundaryId).orElse {
|
||||
query.pipelineCursor.flatMap(_.id)
|
||||
}
|
||||
|
||||
val bottomId = timelineEntries.reverseIterator.collectFirst(idSelector)
|
||||
|
||||
val id = bottomId.orElse(gapBoundaryId)
|
||||
|
||||
val cursor = UrtOrderedCursor(
|
||||
initialSortIndex = nextBottomInitialSortIndex(query, timelineEntries),
|
||||
id = id,
|
||||
cursorType = Some(cursorType),
|
||||
gapBoundaryId = gapBoundaryId
|
||||
)
|
||||
|
||||
serializer.serializeCursor(cursor)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,52 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtOrderedCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.component_library.premarshaller.urt.builder.OrderedTopCursorBuilder.TopCursorOffset
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.TopCursor
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineCursorSerializer
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
case object OrderedTopCursorBuilder {
|
||||
// Ensure that the next initial sort index is at least 10000 entries away from top cursor's
|
||||
// current sort index. This is to ensure that the contents of the next page can be populated
|
||||
// without being assigned sort indices which conflict with that of the current page. This assumes
|
||||
// that each page will have fewer than 10000 entries.
|
||||
val TopCursorOffset = 10000L
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds [[UrtOrderedCursor]] in the Top position
|
||||
*
|
||||
* @param idSelector Specifies the entry from which to derive the `id` field
|
||||
* @param serializer Converts the cursor to an encoded string
|
||||
*/
|
||||
case class OrderedTopCursorBuilder(
|
||||
idSelector: PartialFunction[UniversalNoun[_], Long],
|
||||
serializer: PipelineCursorSerializer[UrtOrderedCursor] = UrtCursorSerializer)
|
||||
extends UrtCursorBuilder[
|
||||
PipelineQuery with HasPipelineCursor[UrtOrderedCursor]
|
||||
] {
|
||||
override val cursorType: CursorType = TopCursor
|
||||
|
||||
override def cursorValue(
|
||||
query: PipelineQuery with HasPipelineCursor[UrtOrderedCursor],
|
||||
timelineEntries: Seq[TimelineEntry]
|
||||
): String = {
|
||||
val topId = timelineEntries.collectFirst(idSelector)
|
||||
|
||||
val id = topId.orElse(query.pipelineCursor.flatMap(_.id))
|
||||
|
||||
val cursor = UrtOrderedCursor(
|
||||
initialSortIndex = cursorSortIndex(query, timelineEntries) + TopCursorOffset,
|
||||
id = id,
|
||||
cursorType = Some(cursorType)
|
||||
)
|
||||
|
||||
serializer.serializeCursor(cursor)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,23 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.PinEntryTimelineInstruction
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.metadata.PinnableEntry
|
||||
|
||||
case class PinEntryInstructionBuilder()
|
||||
extends UrtInstructionBuilder[PipelineQuery, PinEntryTimelineInstruction] {
|
||||
|
||||
override def build(
|
||||
query: PipelineQuery,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[PinEntryTimelineInstruction] = {
|
||||
// Only one entry can be pinned and the desirable behavior is to pick the entry with the highest
|
||||
// sort index in the event that multiple pinned items exist. Since the entries are already
|
||||
// sorted we can accomplish this by picking the first one.
|
||||
entries.collectFirst {
|
||||
case entry: PinnableEntry if entry.isPinned.getOrElse(false) =>
|
||||
PinEntryTimelineInstruction(entry)
|
||||
}.toSeq
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,34 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtPlaceholderCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.component_library.premarshaller.urt.builder.PlaceholderTopCursorBuilder.DefaultPlaceholderCursor
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.TopCursor
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineCursorSerializer
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.pipeline.UrtPipelineCursor
|
||||
|
||||
object PlaceholderTopCursorBuilder {
|
||||
val DefaultPlaceholderCursor = UrtPlaceholderCursor()
|
||||
}
|
||||
|
||||
/**
|
||||
* Top cursor builder that can be used when the Product does not support paging up. The URT spec
|
||||
* requires that both bottom and top cursors always be present on each page. Therefore, if the
|
||||
* product does not support paging up, then we can use a cursor value that is not deserializable.
|
||||
* This way if the client submits a TopCursor, the backend will treat the the request as if no
|
||||
* cursor was submitted.
|
||||
*/
|
||||
case class PlaceholderTopCursorBuilder(
|
||||
serializer: PipelineCursorSerializer[UrtPipelineCursor] = UrtCursorSerializer)
|
||||
extends UrtCursorBuilder[PipelineQuery with HasPipelineCursor[UrtPipelineCursor]] {
|
||||
override val cursorType: CursorType = TopCursor
|
||||
|
||||
override def cursorValue(
|
||||
query: PipelineQuery with HasPipelineCursor[UrtPipelineCursor],
|
||||
timelineEntries: Seq[TimelineEntry]
|
||||
): String = serializer.serializeCursor(DefaultPlaceholderCursor)
|
||||
}
|
Binary file not shown.
|
@ -1,63 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.ReplaceEntryTimelineInstruction
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorOperation
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
/**
|
||||
* Selects one or more [[TimelineEntry]] instance from the input timeline entries.
|
||||
*
|
||||
* @tparam Query The domain model for the [[PipelineQuery]] used as input.
|
||||
*/
|
||||
trait EntriesToReplace[-Query <: PipelineQuery] {
|
||||
def apply(query: Query, entries: Seq[TimelineEntry]): Seq[TimelineEntry]
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects all entries with a non-empty valid entryIdToReplace.
|
||||
*
|
||||
* @note this will result in multiple [[ReplaceEntryTimelineInstruction]]s
|
||||
*/
|
||||
case object ReplaceAllEntries extends EntriesToReplace[PipelineQuery] {
|
||||
def apply(query: PipelineQuery, entries: Seq[TimelineEntry]): Seq[TimelineEntry] =
|
||||
entries.filter(_.entryIdToReplace.isDefined)
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects a replaceable URT [[CursorOperation]] from the timeline entries, that matches the
|
||||
* input cursorType.
|
||||
*/
|
||||
case class ReplaceUrtCursor(cursorType: CursorType) extends EntriesToReplace[PipelineQuery] {
|
||||
override def apply(query: PipelineQuery, entries: Seq[TimelineEntry]): Seq[TimelineEntry] =
|
||||
entries.collectFirst {
|
||||
case cursorOperation: CursorOperation
|
||||
if cursorOperation.cursorType == cursorType && cursorOperation.entryIdToReplace.isDefined =>
|
||||
cursorOperation
|
||||
}.toSeq
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a ReplaceEntry instruction
|
||||
*
|
||||
* @param entriesToReplace each replace instruction can contain only one entry. Users specify which
|
||||
* entry to replace using [[EntriesToReplace]]. If multiple entries are
|
||||
* specified, multiple [[ReplaceEntryTimelineInstruction]]s will be created.
|
||||
* @param includeInstruction whether the instruction should be included in the response
|
||||
*/
|
||||
case class ReplaceEntryInstructionBuilder[Query <: PipelineQuery](
|
||||
entriesToReplace: EntriesToReplace[Query],
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends UrtInstructionBuilder[Query, ReplaceEntryTimelineInstruction] {
|
||||
|
||||
override def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[ReplaceEntryTimelineInstruction] = {
|
||||
if (includeInstruction(query, entries))
|
||||
entriesToReplace(query, entries).map(ReplaceEntryTimelineInstruction)
|
||||
else
|
||||
Seq.empty
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,23 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.ShowAlert
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.ShowAlertInstruction
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
case class ShowAlertInstructionBuilder[Query <: PipelineQuery](
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends UrtInstructionBuilder[Query, ShowAlertInstruction] {
|
||||
|
||||
override def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[ShowAlertInstruction] = {
|
||||
if (includeInstruction(query, entries)) {
|
||||
// Currently only one Alert is supported per response
|
||||
entries.collectFirst {
|
||||
case alertEntry: ShowAlert => ShowAlertInstruction(alertEntry)
|
||||
}.toSeq
|
||||
} else Seq.empty
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,24 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.ShowCoverInstruction
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.Cover
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
case class ShowCoverInstructionBuilder[Query <: PipelineQuery](
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends UrtInstructionBuilder[Query, ShowCoverInstruction] {
|
||||
override def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[ShowCoverInstruction] = {
|
||||
if (includeInstruction(query, entries)) {
|
||||
// Currently only one cover is supported per response
|
||||
entries.collectFirst {
|
||||
case coverEntry: Cover => ShowCoverInstruction(coverEntry)
|
||||
}.toSeq
|
||||
} else {
|
||||
Seq.empty
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,15 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineScribeConfig
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
case class StaticTimelineScribeConfigBuilder(
|
||||
timelineScribeConfig: TimelineScribeConfig)
|
||||
extends TimelineScribeConfigBuilder[PipelineQuery] {
|
||||
|
||||
def build(
|
||||
query: PipelineQuery,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Option[TimelineScribeConfig] = Some(timelineScribeConfig)
|
||||
}
|
Binary file not shown.
|
@ -1,44 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.BottomTermination
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TerminateTimelineInstruction
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineTerminationDirection
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TopAndBottomTermination
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TopTermination
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
sealed trait TerminateInstructionBuilder[Query <: PipelineQuery]
|
||||
extends UrtInstructionBuilder[Query, TerminateTimelineInstruction] {
|
||||
|
||||
def direction: TimelineTerminationDirection
|
||||
|
||||
override def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[TerminateTimelineInstruction] =
|
||||
if (includeInstruction(query, entries))
|
||||
Seq(TerminateTimelineInstruction(terminateTimelineDirection = direction))
|
||||
else Seq.empty
|
||||
}
|
||||
|
||||
case class TerminateTopInstructionBuilder[Query <: PipelineQuery](
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends TerminateInstructionBuilder[Query] {
|
||||
|
||||
override val direction = TopTermination
|
||||
}
|
||||
|
||||
case class TerminateBottomInstructionBuilder[Query <: PipelineQuery](
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends TerminateInstructionBuilder[Query] {
|
||||
|
||||
override val direction = BottomTermination
|
||||
}
|
||||
|
||||
case class TerminateTopAndBottomInstructionBuilder[Query <: PipelineQuery](
|
||||
override val includeInstruction: IncludeInstruction[Query] = AlwaysInclude)
|
||||
extends TerminateInstructionBuilder[Query] {
|
||||
|
||||
override val direction = TopAndBottomTermination
|
||||
}
|
Binary file not shown.
|
@ -1,18 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineScribeConfig
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
/**
|
||||
* Trait for our builder which given a query and entries will return an `Option[TimelineScribeConfig]`
|
||||
*
|
||||
* @tparam Query
|
||||
*/
|
||||
trait TimelineScribeConfigBuilder[-Query <: PipelineQuery] {
|
||||
|
||||
def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Option[TimelineScribeConfig]
|
||||
}
|
Binary file not shown.
|
@ -1,43 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtUnorderedBloomFilterCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.BottomCursor
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineCursorSerializer
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.search.common.util.bloomfilter.AdaptiveLongIntBloomFilterBuilder
|
||||
|
||||
/**
|
||||
* Builds [[UrtUnorderedBloomFilterCursor]] in the Bottom position
|
||||
*
|
||||
* @param idSelector Specifies the entry from which to derive the `id` field
|
||||
* @param serializer Converts the cursor to an encoded string
|
||||
*/
|
||||
case class UnorderedBloomFilterBottomCursorBuilder(
|
||||
idSelector: PartialFunction[UniversalNoun[_], Long],
|
||||
serializer: PipelineCursorSerializer[UrtUnorderedBloomFilterCursor] = UrtCursorSerializer)
|
||||
extends UrtCursorBuilder[
|
||||
PipelineQuery with HasPipelineCursor[UrtUnorderedBloomFilterCursor]
|
||||
] {
|
||||
|
||||
override val cursorType: CursorType = BottomCursor
|
||||
|
||||
override def cursorValue(
|
||||
query: PipelineQuery with HasPipelineCursor[UrtUnorderedBloomFilterCursor],
|
||||
entries: Seq[TimelineEntry]
|
||||
): String = {
|
||||
val bloomFilter = query.pipelineCursor.map(_.longIntBloomFilter)
|
||||
val ids = entries.collect(idSelector)
|
||||
|
||||
val cursor = UrtUnorderedBloomFilterCursor(
|
||||
initialSortIndex = nextBottomInitialSortIndex(query, entries),
|
||||
longIntBloomFilter = AdaptiveLongIntBloomFilterBuilder.build(ids, bloomFilter)
|
||||
)
|
||||
|
||||
serializer.serializeCursor(cursor)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,26 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtUnorderedExcludeIdsCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineCursorSerializer
|
||||
import com.twitter.timelines.configapi.Param
|
||||
|
||||
/**
|
||||
* Builds [[UrtUnorderedExcludeIdsCursor]] in the Bottom position
|
||||
*
|
||||
* @param excludedIdsMaxLengthParam The maximum length of the cursor
|
||||
* @param excludeIdsSelector Specifies the entry Ids to populate on the `excludedIds` field
|
||||
* @param serializer Converts the cursor to an encoded string
|
||||
*/
|
||||
case class UnorderedExcludeIdsBottomCursorBuilder(
|
||||
override val excludedIdsMaxLengthParam: Param[Int],
|
||||
excludeIdsSelector: PartialFunction[UniversalNoun[_], Long],
|
||||
override val serializer: PipelineCursorSerializer[UrtUnorderedExcludeIdsCursor] =
|
||||
UrtCursorSerializer)
|
||||
extends BaseUnorderedExcludeIdsBottomCursorBuilder {
|
||||
|
||||
override def excludeEntriesCollector(entries: Seq[TimelineEntry]): Seq[Long] =
|
||||
entries.collect(excludeIdsSelector)
|
||||
}
|
Binary file not shown.
|
@ -1,30 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.model.cursor.UrtUnorderedExcludeIdsCursor
|
||||
import com.twitter.product_mixer.component_library.premarshaller.cursor.UrtCursorSerializer
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineCursorSerializer
|
||||
import com.twitter.timelines.configapi.Param
|
||||
|
||||
/**
|
||||
* Builds [[UrtUnorderedExcludeIdsCursor]] in the Bottom position when we want to also exclude ids
|
||||
* of items inside a module. The reason we cannot use [[UnorderedExcludeIdsBottomCursorBuilder]] in
|
||||
* such case is that the excludeIdsSelector of [[UnorderedExcludeIdsBottomCursorBuilder]] is doing a
|
||||
* one to one mapping between entries and excluded ids, but in case of having a module, a module
|
||||
* entry can result in excluding a sequence of entries.
|
||||
*
|
||||
* @param excludedIdsMaxLengthParam The maximum length of the cursor
|
||||
* @param excludeIdsSelector Specifies the entry Ids to populate on the `excludedIds` field
|
||||
* @param serializer Converts the cursor to an encoded string
|
||||
*/
|
||||
case class UnorderedExcludeIdsSeqBottomCursorBuilder(
|
||||
override val excludedIdsMaxLengthParam: Param[Int],
|
||||
excludeIdsSelector: PartialFunction[UniversalNoun[_], Seq[Long]],
|
||||
override val serializer: PipelineCursorSerializer[UrtUnorderedExcludeIdsCursor] =
|
||||
UrtCursorSerializer)
|
||||
extends BaseUnorderedExcludeIdsBottomCursorBuilder {
|
||||
|
||||
override def excludeEntriesCollector(entries: Seq[TimelineEntry]): Seq[Long] =
|
||||
entries.collect(excludeIdsSelector).flatten
|
||||
}
|
Binary file not shown.
|
@ -1,94 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorOperation
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.Timeline
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineInstruction
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.pipeline.UrtPipelineCursor
|
||||
import com.twitter.product_mixer.core.util.SortIndexBuilder
|
||||
|
||||
trait UrtBuilder[-Query <: PipelineQuery, +Instruction <: TimelineInstruction] {
|
||||
private val TimelineIdSuffix = "-Timeline"
|
||||
|
||||
def instructionBuilders: Seq[UrtInstructionBuilder[Query, Instruction]]
|
||||
|
||||
def cursorBuilders: Seq[UrtCursorBuilder[Query]]
|
||||
def cursorUpdaters: Seq[UrtCursorUpdater[Query]]
|
||||
|
||||
def metadataBuilder: Option[BaseUrtMetadataBuilder[Query]]
|
||||
|
||||
// Timeline entry sort indexes will count down by this value. Values higher than 1 are useful to
|
||||
// leave room in the sequence for dynamically injecting content in between existing entries.
|
||||
def sortIndexStep: Int = 1
|
||||
|
||||
final def buildTimeline(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Timeline = {
|
||||
val initialSortIndex = getInitialSortIndex(query)
|
||||
|
||||
// Set the sort indexes of the entries before we pass them to the cursor builders, since many
|
||||
// cursor implementations use the sort index of the first/last entry as part of the cursor value
|
||||
val sortIndexedEntries = updateSortIndexes(initialSortIndex, entries)
|
||||
|
||||
// Iterate over the cursorUpdaters in the order they were defined. Note that each updater will
|
||||
// be passed the timelineEntries updated by the previous cursorUpdater.
|
||||
val updatedCursorEntries: Seq[TimelineEntry] =
|
||||
cursorUpdaters.foldLeft(sortIndexedEntries) { (timelineEntries, cursorUpdater) =>
|
||||
cursorUpdater.update(query, timelineEntries)
|
||||
}
|
||||
|
||||
val allCursoredEntries =
|
||||
updatedCursorEntries ++ cursorBuilders.flatMap(_.build(query, updatedCursorEntries))
|
||||
|
||||
val instructions: Seq[Instruction] =
|
||||
instructionBuilders.flatMap(_.build(query, allCursoredEntries))
|
||||
|
||||
val metadata = metadataBuilder.map(_.build(query, allCursoredEntries))
|
||||
|
||||
Timeline(
|
||||
id = query.product.identifier.toString + TimelineIdSuffix,
|
||||
instructions = instructions,
|
||||
metadata = metadata
|
||||
)
|
||||
}
|
||||
|
||||
final def getInitialSortIndex(query: Query): Long =
|
||||
query match {
|
||||
case cursorQuery: HasPipelineCursor[_] =>
|
||||
UrtPipelineCursor
|
||||
.getCursorInitialSortIndex(cursorQuery)
|
||||
.getOrElse(SortIndexBuilder.timeToId(query.queryTime))
|
||||
case _ => SortIndexBuilder.timeToId(query.queryTime)
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the sort indexes in the timeline entries starting from the given initial sort index
|
||||
* value and decreasing by the value defined in the sort index step field
|
||||
*
|
||||
* @param initialSortIndex The initial value of the sort index
|
||||
* @param timelineEntries Timeline entries to update
|
||||
*/
|
||||
final def updateSortIndexes(
|
||||
initialSortIndex: Long,
|
||||
timelineEntries: Seq[TimelineEntry]
|
||||
): Seq[TimelineEntry] = {
|
||||
val indexRange =
|
||||
initialSortIndex to (initialSortIndex - (timelineEntries.size * sortIndexStep)) by -sortIndexStep
|
||||
|
||||
// Skip any existing cursors because their sort indexes will be managed by their cursor updater.
|
||||
// If the cursors are not removed first, then the remaining entries would have a gap everywhere
|
||||
// an existing cursor was present.
|
||||
val (cursorEntries, nonCursorEntries) = timelineEntries.partition {
|
||||
case _: CursorOperation => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
nonCursorEntries.zip(indexRange).map {
|
||||
case (entry, index) =>
|
||||
entry.withSortIndex(index)
|
||||
} ++ cursorEntries
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,134 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.premarshaller.urt.builder.UrtCursorBuilder.DefaultSortIndex
|
||||
import com.twitter.product_mixer.component_library.premarshaller.urt.builder.UrtCursorBuilder.NextPageTopCursorEntryOffset
|
||||
import com.twitter.product_mixer.component_library.premarshaller.urt.builder.UrtCursorBuilder.UrtEntryOffset
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.BottomCursor
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorItem
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorOperation
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.GapCursor
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.TopCursor
|
||||
import com.twitter.product_mixer.core.pipeline.HasPipelineCursor
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.pipeline.UrtPipelineCursor
|
||||
import com.twitter.product_mixer.core.util.SortIndexBuilder
|
||||
|
||||
object UrtCursorBuilder {
|
||||
val NextPageTopCursorEntryOffset = 1L
|
||||
val UrtEntryOffset = 1L
|
||||
val DefaultSortIndex = (query: PipelineQuery) => SortIndexBuilder.timeToId(query.queryTime)
|
||||
}
|
||||
|
||||
trait UrtCursorBuilder[-Query <: PipelineQuery] {
|
||||
|
||||
val includeOperation: IncludeInstruction[Query] = AlwaysInclude
|
||||
|
||||
def cursorType: CursorType
|
||||
def cursorValue(query: Query, entries: Seq[TimelineEntry]): String
|
||||
|
||||
/**
|
||||
* Identifier of an *existing* timeline cursor that this new cursor would replace, if this cursor
|
||||
* is returned in a `ReplaceEntry` timeline instruction.
|
||||
*
|
||||
* Note:
|
||||
* - This id is used to populate the `entryIdToReplace` field on the URT TimelineEntry
|
||||
* generated. More details at [[CursorOperation.entryIdToReplace]].
|
||||
* - As a convention, we use the sortIndex of the cursor for its id/entryId fields. So the
|
||||
* `idToReplace` should represent the sortIndex of the existing cursor to be replaced.
|
||||
*/
|
||||
def idToReplace(query: Query): Option[Long] = None
|
||||
|
||||
def cursorSortIndex(query: Query, entries: Seq[TimelineEntry]): Long =
|
||||
(query, cursorType) match {
|
||||
case (query: PipelineQuery with HasPipelineCursor[_], TopCursor) =>
|
||||
topCursorSortIndex(query, entries)
|
||||
case (query: PipelineQuery with HasPipelineCursor[_], BottomCursor | GapCursor) =>
|
||||
bottomCursorSortIndex(query, entries)
|
||||
case _ =>
|
||||
throw new UnsupportedOperationException(
|
||||
"Automatic sort index support limited to top and bottom cursors")
|
||||
}
|
||||
|
||||
def build(query: Query, entries: Seq[TimelineEntry]): Option[CursorOperation] = {
|
||||
if (includeOperation(query, entries)) {
|
||||
val sortIndex = cursorSortIndex(query, entries)
|
||||
|
||||
val cursorOperation = CursorOperation(
|
||||
id = sortIndex,
|
||||
sortIndex = Some(sortIndex),
|
||||
value = cursorValue(query, entries),
|
||||
cursorType = cursorType,
|
||||
displayTreatment = None,
|
||||
idToReplace = idToReplace(query),
|
||||
)
|
||||
|
||||
Some(cursorOperation)
|
||||
} else None
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the top cursor sort index which handles the following cases:
|
||||
* 1. When there is at least one non-cursor entry, use the first entry's sort index + UrtEntryOffset
|
||||
* 2. When there are no non-cursor entries, and initialSortIndex is not set which indicates that
|
||||
* it is the first page, use DefaultSortIndex + UrtEntryOffset
|
||||
* 3. When there are no non-cursor entries, and initialSortIndex is set which indicates that it is
|
||||
* not the first page, use the query.initialSortIndex from the passed-in cursor + UrtEntryOffset
|
||||
*/
|
||||
protected def topCursorSortIndex(
|
||||
query: PipelineQuery with HasPipelineCursor[_],
|
||||
entries: Seq[TimelineEntry]
|
||||
): Long = {
|
||||
val nonCursorEntries = entries.filter {
|
||||
case _: CursorOperation => false
|
||||
case _: CursorItem => false
|
||||
case _ => true
|
||||
}
|
||||
|
||||
lazy val initialSortIndex =
|
||||
UrtPipelineCursor.getCursorInitialSortIndex(query).getOrElse(DefaultSortIndex(query))
|
||||
|
||||
nonCursorEntries.headOption.flatMap(_.sortIndex).getOrElse(initialSortIndex) + UrtEntryOffset
|
||||
}
|
||||
|
||||
/**
|
||||
* Specifies the point at which the next page's entries' sort indices will start counting.
|
||||
*
|
||||
* Note that in the case of URT, the next page's entries' does not include the top cursor. As
|
||||
* such, the value of initialSortIndex passed back in the cursor is typically the bottom cursor's
|
||||
* sort index - 2. Subtracting 2 leaves room for the next page's top cursor, which will have a
|
||||
* sort index of top entry + 1.
|
||||
*/
|
||||
protected def nextBottomInitialSortIndex(
|
||||
query: PipelineQuery with HasPipelineCursor[_],
|
||||
entries: Seq[TimelineEntry]
|
||||
): Long = {
|
||||
bottomCursorSortIndex(query, entries) - NextPageTopCursorEntryOffset - UrtEntryOffset
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the bottom cursor sort index which handles the following cases:
|
||||
* 1. When there is at least one non-cursor entry, use the last entry's sort index - UrtEntryOffset
|
||||
* 2. When there are no non-cursor entries, and initialSortIndex is not set which indicates that
|
||||
* it is the first page, use DefaultSortIndex
|
||||
* 3. When there are no non-cursor entries, and initialSortIndex is set which indicates that it is
|
||||
* not the first page, use the query.initialSortIndex from the passed-in cursor
|
||||
*/
|
||||
protected def bottomCursorSortIndex(
|
||||
query: PipelineQuery with HasPipelineCursor[_],
|
||||
entries: Seq[TimelineEntry]
|
||||
): Long = {
|
||||
val nonCursorEntries = entries.filter {
|
||||
case _: CursorOperation => false
|
||||
case _: CursorItem => false
|
||||
case _ => true
|
||||
}
|
||||
|
||||
lazy val initialSortIndex =
|
||||
UrtPipelineCursor.getCursorInitialSortIndex(query).getOrElse(DefaultSortIndex(query))
|
||||
|
||||
nonCursorEntries.lastOption
|
||||
.flatMap(_.sortIndex).map(_ - UrtEntryOffset).getOrElse(initialSortIndex)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,44 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.component_library.premarshaller.urt.builder.UrtCursorUpdater.getCursorByType
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorOperation
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.operation.CursorType
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
object UrtCursorUpdater {
|
||||
|
||||
def getCursorByType(
|
||||
entries: Seq[TimelineEntry],
|
||||
cursorType: CursorType
|
||||
): Option[CursorOperation] = {
|
||||
entries.collectFirst {
|
||||
case cursor: CursorOperation if cursor.cursorType == cursorType => cursor
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a CursorCandidate is returned by a Candidate Source, use this trait to update that Cursor as
|
||||
// necessary (as opposed to building a new cursor which is done with the UrtCursorBuilder)
|
||||
trait UrtCursorUpdater[-Query <: PipelineQuery] extends UrtCursorBuilder[Query] { self =>
|
||||
|
||||
def getExistingCursor(entries: Seq[TimelineEntry]): Option[CursorOperation] = {
|
||||
getCursorByType(entries, self.cursorType)
|
||||
}
|
||||
|
||||
def update(query: Query, entries: Seq[TimelineEntry]): Seq[TimelineEntry] = {
|
||||
if (includeOperation(query, entries)) {
|
||||
getExistingCursor(entries)
|
||||
.map { existingCursor =>
|
||||
// Safe .get because includeOperation() is shared in this context
|
||||
// build() method creates a new CursorOperation. We copy over the `idToReplace`
|
||||
// from the existing cursor.
|
||||
val newCursor =
|
||||
build(query, entries).get
|
||||
.copy(idToReplace = existingCursor.idToReplace)
|
||||
|
||||
entries.filterNot(_ == existingCursor) :+ newCursor
|
||||
}.getOrElse(entries)
|
||||
} else entries
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,15 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineInstruction
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
|
||||
trait UrtInstructionBuilder[-Query <: PipelineQuery, +Instruction <: TimelineInstruction] {
|
||||
|
||||
def includeInstruction: IncludeInstruction[Query] = AlwaysInclude
|
||||
|
||||
def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): Seq[Instruction]
|
||||
}
|
Binary file not shown.
|
@ -1,43 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.premarshaller.urt.builder
|
||||
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineEntry
|
||||
import com.twitter.product_mixer.core.model.marshalling.response.urt.TimelineMetadata
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.stringcenter.client.StringCenter
|
||||
import com.twitter.stringcenter.client.core.ExternalString
|
||||
|
||||
trait BaseUrtMetadataBuilder[-Query <: PipelineQuery] {
|
||||
def build(
|
||||
query: Query,
|
||||
entries: Seq[TimelineEntry]
|
||||
): TimelineMetadata
|
||||
}
|
||||
|
||||
case class UrtMetadataBuilder(
|
||||
title: Option[String] = None,
|
||||
scribeConfigBuilder: Option[TimelineScribeConfigBuilder[PipelineQuery]])
|
||||
extends BaseUrtMetadataBuilder[PipelineQuery] {
|
||||
|
||||
override def build(
|
||||
query: PipelineQuery,
|
||||
entries: Seq[TimelineEntry]
|
||||
): TimelineMetadata = TimelineMetadata(
|
||||
title = title,
|
||||
scribeConfig = scribeConfigBuilder.flatMap(_.build(query, entries))
|
||||
)
|
||||
}
|
||||
|
||||
case class UrtMetadataStringCenterBuilder(
|
||||
titleKey: ExternalString,
|
||||
scribeConfigBuilder: Option[TimelineScribeConfigBuilder[PipelineQuery]],
|
||||
stringCenter: StringCenter)
|
||||
extends BaseUrtMetadataBuilder[PipelineQuery] {
|
||||
|
||||
override def build(
|
||||
query: PipelineQuery,
|
||||
entries: Seq[TimelineEntry]
|
||||
): TimelineMetadata = TimelineMetadata(
|
||||
title = Some(stringCenter.prepare(titleKey)),
|
||||
scribeConfig = scribeConfigBuilder.flatMap(_.build(query, entries))
|
||||
)
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/io/grpc:grpc-netty",
|
||||
"3rdparty/jvm/io/netty:netty4-tcnative-boringssl-static",
|
||||
"3rdparty/jvm/io/opil:tensorflow-serving-client",
|
||||
"3rdparty/jvm/javax/inject:javax.inject",
|
||||
"3rdparty/jvm/triton/inference:triton-grpc",
|
||||
"finagle-internal/finagle-grpc/src/main/scala",
|
||||
"finagle-internal/finagle-grpc/src/test/java",
|
||||
"finagle-internal/finagle-grpc/src/test/proto",
|
||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/authentication",
|
||||
"finagle-internal/mtls/src/main/scala/com/twitter/finagle/mtls/client",
|
||||
"finagle/finagle-http/src/main/scala",
|
||||
"finatra-internal/mtls/src/main/scala",
|
||||
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
|
||||
"ml-serving/scala:kfserving-tfserving-converter",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
"stitch/stitch-core",
|
||||
],
|
||||
exports = [
|
||||
"3rdparty/jvm/triton/inference:triton-grpc",
|
||||
"finagle/finagle-http/src/main/scala",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
"stitch/stitch-core",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@ -1,12 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.common
|
||||
|
||||
import com.twitter.stitch.Stitch
|
||||
import inference.GrpcService.ModelInferRequest
|
||||
import inference.GrpcService.ModelInferResponse
|
||||
|
||||
/**
|
||||
* MLModelInferenceClient for calling different Inference Service such as ManagedModelClient or NaviModelClient.
|
||||
*/
|
||||
trait MLModelInferenceClient {
|
||||
def score(request: ModelInferRequest): Stitch[ModelInferResponse]
|
||||
}
|
Binary file not shown.
|
@ -1,33 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.common
|
||||
|
||||
import com.twitter.finagle.Http
|
||||
import com.twitter.finagle.grpc.FinagleChannelBuilder
|
||||
import com.twitter.finagle.grpc.FutureConverters
|
||||
import com.twitter.stitch.Stitch
|
||||
import inference.GRPCInferenceServiceGrpc
|
||||
import inference.GrpcService.ModelInferRequest
|
||||
import inference.GrpcService.ModelInferResponse
|
||||
import io.grpc.ManagedChannel
|
||||
|
||||
/**
|
||||
* Client wrapper for calling a Cortex Managed Inference Service (go/cmis) ML Model using GRPC.
|
||||
* @param httpClient Finagle HTTP Client to use for connection.
|
||||
* @param modelPath Wily path to the ML Model service (e.g. /cluster/local/role/service/instance).
|
||||
*/
|
||||
case class ManagedModelClient(
|
||||
httpClient: Http.Client,
|
||||
modelPath: String)
|
||||
extends MLModelInferenceClient {
|
||||
|
||||
private val channel: ManagedChannel =
|
||||
FinagleChannelBuilder.forTarget(modelPath).httpClient(httpClient).build()
|
||||
|
||||
private val inferenceServiceStub = GRPCInferenceServiceGrpc.newFutureStub(channel)
|
||||
|
||||
def score(request: ModelInferRequest): Stitch[ModelInferResponse] = {
|
||||
Stitch
|
||||
.callFuture(
|
||||
FutureConverters
|
||||
.RichListenableFuture(inferenceServiceStub.modelInfer(request)).toTwitter)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,28 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.common
|
||||
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.timelines.configapi.Param
|
||||
|
||||
/**
|
||||
* Selector for choosing which Model ID/Name to use when calling an underlying ML Model Service.
|
||||
*/
|
||||
trait ModelSelector[-Query <: PipelineQuery] {
|
||||
def apply(query: Query): Option[String]
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple Model ID Selector that chooses model based off of a Param object.
|
||||
* @param param ConfigAPI Param that decides the model id.
|
||||
*/
|
||||
case class ParamModelSelector[Query <: PipelineQuery](param: Param[String])
|
||||
extends ModelSelector[Query] {
|
||||
override def apply(query: Query): Option[String] = Some(query.params(param))
|
||||
}
|
||||
|
||||
/**
|
||||
* Static Selector that chooses the same model name always
|
||||
* @param modelName The model name to use.
|
||||
*/
|
||||
case class StaticModelSelector(modelName: String) extends ModelSelector[PipelineQuery] {
|
||||
override def apply(query: PipelineQuery): Option[String] = Some(modelName)
|
||||
}
|
Binary file not shown.
|
@ -1,50 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.common
|
||||
|
||||
import com.twitter.finagle.Http
|
||||
import com.twitter.finagle.grpc.FinagleChannelBuilder
|
||||
import com.twitter.finagle.grpc.FutureConverters
|
||||
import com.twitter.mlserving.frontend.TFServingInferenceServiceImpl
|
||||
import com.twitter.stitch.Stitch
|
||||
import tensorflow.serving.PredictionServiceGrpc
|
||||
import inference.GrpcService.ModelInferRequest
|
||||
import inference.GrpcService.ModelInferResponse
|
||||
import io.grpc.ManagedChannel
|
||||
import io.grpc.Status
|
||||
|
||||
/**
|
||||
* Client wrapper for calling a Navi Inference Service (go/navi).
|
||||
* @param httpClient Finagle HTTP Client to use for connection.
|
||||
* @param modelPath Wily path to the ML Model service (e.g. /s/role/service).
|
||||
*/
|
||||
case class NaviModelClient(
|
||||
httpClient: Http.Client,
|
||||
modelPath: String)
|
||||
extends MLModelInferenceClient {
|
||||
|
||||
private val channel: ManagedChannel =
|
||||
FinagleChannelBuilder
|
||||
.forTarget(modelPath)
|
||||
.httpClient(httpClient)
|
||||
// Navi enforces an authority name.
|
||||
.overrideAuthority("rustserving")
|
||||
// certain GRPC errors need to be retried.
|
||||
.enableRetryForStatus(Status.UNKNOWN)
|
||||
.enableRetryForStatus(Status.RESOURCE_EXHAUSTED)
|
||||
// this is required at channel level as mTLS is enabled at httpClient level
|
||||
.usePlaintext()
|
||||
.build()
|
||||
|
||||
private val inferenceServiceStub = PredictionServiceGrpc.newFutureStub(channel)
|
||||
|
||||
def score(request: ModelInferRequest): Stitch[ModelInferResponse] = {
|
||||
val tfServingRequest = TFServingInferenceServiceImpl.adaptModelInferRequest(request)
|
||||
Stitch
|
||||
.callFuture(
|
||||
FutureConverters
|
||||
.RichListenableFuture(inferenceServiceStub.predict(tfServingRequest)).toTwitter
|
||||
.map { response =>
|
||||
TFServingInferenceServiceImpl.adaptModelInferResponse(response)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"finagle/finagle-http/src/main/scala",
|
||||
"finatra-internal/mtls/src/main/scala",
|
||||
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/module/http",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/common",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/tensorbuilder",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/datarecord",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap/datarecord",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline/pipeline_failure",
|
||||
"src/scala/com/twitter/ml/featurestore/lib",
|
||||
"src/thrift/com/twitter/ml/prediction_service:prediction_service-java",
|
||||
],
|
||||
exports = [
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/module/http",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/common",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/tensorbuilder",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/datarecord",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap/datarecord",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@ -1,137 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.cortex
|
||||
|
||||
import com.google.protobuf.ByteString
|
||||
import com.twitter.ml.prediction_service.BatchPredictionRequest
|
||||
import com.twitter.ml.prediction_service.BatchPredictionResponse
|
||||
import com.twitter.product_mixer.component_library.scorer.common.ManagedModelClient
|
||||
import com.twitter.product_mixer.component_library.scorer.common.ModelSelector
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.datarecord.BaseDataRecordFeature
|
||||
import com.twitter.product_mixer.core.feature.datarecord.TensorDataRecordCompatible
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.DataRecordConverter
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.DataRecordExtractor
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.FeaturesScope
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.pipeline.pipeline_failure.IllegalStateFailure
|
||||
import inference.GrpcService
|
||||
import inference.GrpcService.ModelInferRequest
|
||||
import inference.GrpcService.ModelInferResponse
|
||||
import com.twitter.product_mixer.core.pipeline.pipeline_failure.PipelineFailure
|
||||
import com.twitter.stitch.Stitch
|
||||
import org.apache.thrift.TDeserializer
|
||||
import org.apache.thrift.TSerializer
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
private[cortex] class CortexManagedDataRecordScorer[
|
||||
Query <: PipelineQuery,
|
||||
Candidate <: UniversalNoun[Any],
|
||||
QueryFeatures <: BaseDataRecordFeature[Query, _],
|
||||
CandidateFeatures <: BaseDataRecordFeature[Candidate, _],
|
||||
ResultFeatures <: BaseDataRecordFeature[Candidate, _] with TensorDataRecordCompatible[_]
|
||||
](
|
||||
override val identifier: ScorerIdentifier,
|
||||
modelSignature: String,
|
||||
modelSelector: ModelSelector[Query],
|
||||
modelClient: ManagedModelClient,
|
||||
queryFeatures: FeaturesScope[QueryFeatures],
|
||||
candidateFeatures: FeaturesScope[CandidateFeatures],
|
||||
resultFeatures: Set[ResultFeatures])
|
||||
extends Scorer[Query, Candidate] {
|
||||
|
||||
require(resultFeatures.nonEmpty, "Result features cannot be empty")
|
||||
override val features: Set[Feature[_, _]] = resultFeatures.asInstanceOf[Set[Feature[_, _]]]
|
||||
|
||||
private val queryDataRecordAdapter = new DataRecordConverter(queryFeatures)
|
||||
private val candidatesDataRecordAdapter = new DataRecordConverter(candidateFeatures)
|
||||
private val resultDataRecordExtractor = new DataRecordExtractor(resultFeatures)
|
||||
|
||||
private val localTSerializer = new ThreadLocal[TSerializer] {
|
||||
override protected def initialValue: TSerializer = new TSerializer()
|
||||
}
|
||||
|
||||
private val localTDeserializer = new ThreadLocal[TDeserializer] {
|
||||
override protected def initialValue: TDeserializer = new TDeserializer()
|
||||
}
|
||||
|
||||
override def apply(
|
||||
query: Query,
|
||||
candidates: Seq[CandidateWithFeatures[Candidate]]
|
||||
): Stitch[Seq[FeatureMap]] = {
|
||||
modelClient.score(buildRequest(query, candidates)).map(buildResponse(candidates, _))
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes candidates to be scored and converts it to a ModelInferRequest that can be passed to the
|
||||
* managed ML service
|
||||
*/
|
||||
private def buildRequest(
|
||||
query: Query,
|
||||
scorerCandidates: Seq[CandidateWithFeatures[Candidate]]
|
||||
): ModelInferRequest = {
|
||||
// Convert the feature maps to thrift data records and construct thrift request.
|
||||
val thriftDataRecords = scorerCandidates.map { candidate =>
|
||||
candidatesDataRecordAdapter.toDataRecord(candidate.features)
|
||||
}
|
||||
val batchRequest = new BatchPredictionRequest(thriftDataRecords.asJava)
|
||||
query.features.foreach { featureMap =>
|
||||
batchRequest.setCommonFeatures(queryDataRecordAdapter.toDataRecord(featureMap))
|
||||
}
|
||||
val serializedBatchRequest = localTSerializer.get().serialize(batchRequest)
|
||||
|
||||
// Build Tensor Request
|
||||
val requestBuilder = ModelInferRequest
|
||||
.newBuilder()
|
||||
|
||||
modelSelector.apply(query).foreach { modelName =>
|
||||
requestBuilder.setModelName(modelName) // model name in the model config
|
||||
}
|
||||
|
||||
val inputTensorBuilder = ModelInferRequest.InferInputTensor
|
||||
.newBuilder()
|
||||
.setName("request")
|
||||
.setDatatype("UINT8")
|
||||
.addShape(serializedBatchRequest.length)
|
||||
|
||||
val inferParameter = GrpcService.InferParameter
|
||||
.newBuilder()
|
||||
.setStringParam(modelSignature) // signature of exported tf function
|
||||
.build()
|
||||
|
||||
requestBuilder
|
||||
.addInputs(inputTensorBuilder)
|
||||
.addRawInputContents(ByteString.copyFrom(serializedBatchRequest))
|
||||
.putParameters("signature_name", inferParameter)
|
||||
.build()
|
||||
}
|
||||
|
||||
private def buildResponse(
|
||||
scorerCandidates: Seq[CandidateWithFeatures[Candidate]],
|
||||
response: ModelInferResponse
|
||||
): Seq[FeatureMap] = {
|
||||
|
||||
val responseByteString = if (response.getRawOutputContentsList.isEmpty()) {
|
||||
throw PipelineFailure(
|
||||
IllegalStateFailure,
|
||||
"Model inference response has empty raw outputContents")
|
||||
} else {
|
||||
response.getRawOutputContents(0)
|
||||
}
|
||||
val batchPredictionResponse: BatchPredictionResponse = new BatchPredictionResponse()
|
||||
localTDeserializer.get().deserialize(batchPredictionResponse, responseByteString.toByteArray)
|
||||
|
||||
// get the prediction values from the batch prediction response
|
||||
val resultScoreMaps =
|
||||
batchPredictionResponse.predictions.asScala.map(resultDataRecordExtractor.fromDataRecord)
|
||||
|
||||
if (resultScoreMaps.size != scorerCandidates.size) {
|
||||
throw PipelineFailure(IllegalStateFailure, "Result Size mismatched candidates size")
|
||||
}
|
||||
|
||||
resultScoreMaps
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,67 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.cortex
|
||||
|
||||
import com.twitter.finagle.Http
|
||||
import com.twitter.product_mixer.component_library.module.http.FinagleHttpClientModule.FinagleHttpClientModule
|
||||
import com.twitter.product_mixer.component_library.scorer.common.ManagedModelClient
|
||||
import com.twitter.product_mixer.component_library.scorer.common.ModelSelector
|
||||
import com.twitter.product_mixer.core.feature.datarecord.BaseDataRecordFeature
|
||||
import com.twitter.product_mixer.core.feature.datarecord.TensorDataRecordCompatible
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.FeaturesScope
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Named
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class CortexManagedInferenceServiceDataRecordScorerBuilder @Inject() (
|
||||
@Named(FinagleHttpClientModule) httpClient: Http.Client) {
|
||||
|
||||
/**
|
||||
* Builds a configurable Scorer to call into your desired DataRecord-backed Cortex Managed ML Model Service.
|
||||
*
|
||||
* If your service does not bind an Http.Client implementation, add
|
||||
* [[com.twitter.product_mixer.component_library.module.http.FinagleHttpClientModule]]
|
||||
* to your server module list
|
||||
*
|
||||
* @param scorerIdentifier Unique identifier for the scorer
|
||||
* @param modelPath MLS path to model
|
||||
* @param modelSignature Model Signature Key
|
||||
* @param modelSelector [[ModelSelector]] for choosing the model name, can be an anon function.
|
||||
* @param candidateFeatures Desired candidate level feature store features to pass to the model.
|
||||
* @param resultFeatures Desired candidate level feature store features to extract from the model.
|
||||
* Since the Cortex Managed Platform always returns tensor values, the
|
||||
* feature must use a [[TensorDataRecordCompatible]].
|
||||
* @tparam Query Type of pipeline query.
|
||||
* @tparam Candidate Type of candidates to score.
|
||||
* @tparam QueryFeatures type of the query level features consumed by the scorer.
|
||||
* @tparam CandidateFeatures type of the candidate level features consumed by the scorer.
|
||||
* @tparam ResultFeatures type of the candidate level features returned by the scorer.
|
||||
*/
|
||||
def build[
|
||||
Query <: PipelineQuery,
|
||||
Candidate <: UniversalNoun[Any],
|
||||
QueryFeatures <: BaseDataRecordFeature[Query, _],
|
||||
CandidateFeatures <: BaseDataRecordFeature[Candidate, _],
|
||||
ResultFeatures <: BaseDataRecordFeature[Candidate, _] with TensorDataRecordCompatible[_]
|
||||
](
|
||||
scorerIdentifier: ScorerIdentifier,
|
||||
modelPath: String,
|
||||
modelSignature: String,
|
||||
modelSelector: ModelSelector[Query],
|
||||
queryFeatures: FeaturesScope[QueryFeatures],
|
||||
candidateFeatures: FeaturesScope[CandidateFeatures],
|
||||
resultFeatures: Set[ResultFeatures]
|
||||
): Scorer[Query, Candidate] =
|
||||
new CortexManagedDataRecordScorer(
|
||||
identifier = scorerIdentifier,
|
||||
modelSignature = modelSignature,
|
||||
modelSelector = modelSelector,
|
||||
modelClient = ManagedModelClient(httpClient, modelPath),
|
||||
queryFeatures = queryFeatures,
|
||||
candidateFeatures = candidateFeatures,
|
||||
resultFeatures = resultFeatures
|
||||
)
|
||||
}
|
Binary file not shown.
|
@ -1,97 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.cortex
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.product_mixer.component_library.scorer.common.MLModelInferenceClient
|
||||
import com.twitter.product_mixer.component_library.scorer.tensorbuilder.ModelInferRequestBuilder
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.pipeline.pipeline_failure.IllegalStateFailure
|
||||
import com.twitter.product_mixer.core.pipeline.pipeline_failure.PipelineFailure
|
||||
import com.twitter.stitch.Stitch
|
||||
import com.twitter.util.logging.Logging
|
||||
import inference.GrpcService.ModelInferRequest
|
||||
import inference.GrpcService.ModelInferResponse.InferOutputTensor
|
||||
import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
|
||||
|
||||
private[scorer] class CortexManagedInferenceServiceTensorScorer[
|
||||
Query <: PipelineQuery,
|
||||
Candidate <: UniversalNoun[Any]
|
||||
](
|
||||
override val identifier: ScorerIdentifier,
|
||||
modelInferRequestBuilder: ModelInferRequestBuilder[
|
||||
Query,
|
||||
Candidate
|
||||
],
|
||||
resultFeatureExtractors: Seq[FeatureWithExtractor[Query, Candidate, _]],
|
||||
client: MLModelInferenceClient,
|
||||
statsReceiver: StatsReceiver)
|
||||
extends Scorer[Query, Candidate]
|
||||
with Logging {
|
||||
|
||||
require(resultFeatureExtractors.nonEmpty, "Result Extractors cannot be empty")
|
||||
|
||||
private val managedServiceRequestFailures = statsReceiver.counter("managedServiceRequestFailures")
|
||||
override val features: Set[Feature[_, _]] =
|
||||
resultFeatureExtractors.map(_.feature).toSet.asInstanceOf[Set[Feature[_, _]]]
|
||||
|
||||
override def apply(
|
||||
query: Query,
|
||||
candidates: Seq[CandidateWithFeatures[Candidate]]
|
||||
): Stitch[Seq[FeatureMap]] = {
|
||||
val batchInferRequest: ModelInferRequest = modelInferRequestBuilder(query, candidates)
|
||||
|
||||
val managedServiceResponse: Stitch[Seq[InferOutputTensor]] =
|
||||
client.score(batchInferRequest).map(_.getOutputsList.toSeq).onFailure { e =>
|
||||
error(s"request to ML Managed Service Failed: $e")
|
||||
managedServiceRequestFailures.incr()
|
||||
}
|
||||
|
||||
managedServiceResponse.map { responses =>
|
||||
extractResponse(query, candidates.map(_.candidate), responses)
|
||||
}
|
||||
}
|
||||
|
||||
def extractResponse(
|
||||
query: Query,
|
||||
candidates: Seq[Candidate],
|
||||
tensorOutput: Seq[InferOutputTensor]
|
||||
): Seq[FeatureMap] = {
|
||||
val featureMapBuilders = candidates.map { _ => FeatureMapBuilder.apply() }
|
||||
// Extract the feature for each candidate from the tensor outputs
|
||||
resultFeatureExtractors.foreach {
|
||||
case FeatureWithExtractor(feature, extractor) =>
|
||||
val extractedValues = extractor.apply(query, tensorOutput)
|
||||
if (candidates.size != extractedValues.size) {
|
||||
throw PipelineFailure(
|
||||
IllegalStateFailure,
|
||||
s"Managed Service returned a different number of $feature than the number of candidates." +
|
||||
s"Returned ${extractedValues.size} scores but there were ${candidates.size} candidates."
|
||||
)
|
||||
}
|
||||
// Go through the extracted features list one by one and update the feature map result for each candidate.
|
||||
featureMapBuilders.zip(extractedValues).foreach {
|
||||
case (builder, value) =>
|
||||
builder.add(feature, Some(value))
|
||||
}
|
||||
}
|
||||
|
||||
featureMapBuilders.map(_.build())
|
||||
}
|
||||
}
|
||||
|
||||
case class FeatureWithExtractor[
|
||||
-Query <: PipelineQuery,
|
||||
-Candidate <: UniversalNoun[Any],
|
||||
ResultType
|
||||
](
|
||||
feature: Feature[Candidate, Option[ResultType]],
|
||||
featureExtractor: ModelFeatureExtractor[Query, ResultType])
|
||||
|
||||
class UnexpectedFeatureTypeException(feature: Feature[_, _])
|
||||
extends UnsupportedOperationException(s"Unsupported Feature type passed in $feature")
|
Binary file not shown.
|
@ -1,47 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.cortex
|
||||
|
||||
import com.twitter.finagle.stats.StatsReceiver
|
||||
import com.twitter.product_mixer.component_library.scorer.common.MLModelInferenceClient
|
||||
import com.twitter.product_mixer.component_library.scorer.tensorbuilder.ModelInferRequestBuilder
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class CortexManagedInferenceServiceTensorScorerBuilder @Inject() (
|
||||
statsReceiver: StatsReceiver) {
|
||||
|
||||
/**
|
||||
* Builds a configurable Scorer to call into your desired Cortex Managed ML Model Service.
|
||||
*
|
||||
* If your service does not bind an Http.Client implementation, add
|
||||
* [[com.twitter.product_mixer.component_library.module.http.FinagleHttpClientModule]]
|
||||
* to your server module list
|
||||
*
|
||||
* @param scorerIdentifier Unique identifier for the scorer
|
||||
* @param resultFeatureExtractors The result features an their tensor extractors for each candidate.
|
||||
* @tparam Query Type of pipeline query.
|
||||
* @tparam Candidate Type of candidates to score.
|
||||
* @tparam QueryFeatures type of the query level features consumed by the scorer.
|
||||
* @tparam CandidateFeatures type of the candidate level features consumed by the scorer.
|
||||
*/
|
||||
def build[Query <: PipelineQuery, Candidate <: UniversalNoun[Any]](
|
||||
scorerIdentifier: ScorerIdentifier,
|
||||
modelInferRequestBuilder: ModelInferRequestBuilder[
|
||||
Query,
|
||||
Candidate
|
||||
],
|
||||
resultFeatureExtractors: Seq[FeatureWithExtractor[Query, Candidate, _]],
|
||||
client: MLModelInferenceClient
|
||||
): Scorer[Query, Candidate] =
|
||||
new CortexManagedInferenceServiceTensorScorer(
|
||||
scorerIdentifier,
|
||||
modelInferRequestBuilder,
|
||||
resultFeatureExtractors,
|
||||
client,
|
||||
statsReceiver.scope(scorerIdentifier.name)
|
||||
)
|
||||
}
|
Binary file not shown.
|
@ -1,15 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.cortex
|
||||
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import inference.GrpcService.ModelInferResponse.InferOutputTensor
|
||||
|
||||
/**
|
||||
* Extractor defining how a Scorer should go from outputted tensors to the individual results
|
||||
* for each candidate being scored.
|
||||
*
|
||||
* @tparam Result the type of the Value being returned.
|
||||
* Users can pass in an anonymous function
|
||||
*/
|
||||
trait ModelFeatureExtractor[-Query <: PipelineQuery, Result] {
|
||||
def apply(query: Query, tensorOutput: Seq[InferOutputTensor]): Seq[Result]
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
scala_library(
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/javax/inject:javax.inject",
|
||||
"cr-ml-ranker/thrift/src/main/thrift:thrift-scala",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/feature_hydrator/query/cr_ml_ranker",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
],
|
||||
exports = [
|
||||
"cr-ml-ranker/thrift/src/main/thrift:thrift-scala",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/feature_hydrator/query/cr_ml_ranker",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@ -1,52 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.cr_ml_ranker
|
||||
|
||||
import com.twitter.product_mixer.component_library.feature_hydrator.query.cr_ml_ranker.CrMlRankerCommonFeatures
|
||||
import com.twitter.product_mixer.component_library.feature_hydrator.query.cr_ml_ranker.CrMlRankerRankingConfig
|
||||
import com.twitter.product_mixer.component_library.model.candidate.TweetCandidate
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.stitch.Stitch
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
object CrMlRankerScore extends Feature[TweetCandidate, Double]
|
||||
|
||||
/**
|
||||
* Scorer that scores tweets using the Content Recommender ML Light Ranker: http://go/cr-ml-ranker
|
||||
*/
|
||||
@Singleton
|
||||
class CrMlRankerScorer @Inject() (crMlRanker: CrMlRankerScoreStitchClient)
|
||||
extends Scorer[PipelineQuery, TweetCandidate] {
|
||||
|
||||
override val identifier: ScorerIdentifier = ScorerIdentifier("CrMlRanker")
|
||||
|
||||
override val features: Set[Feature[_, _]] = Set(CrMlRankerScore)
|
||||
|
||||
override def apply(
|
||||
query: PipelineQuery,
|
||||
candidates: Seq[CandidateWithFeatures[TweetCandidate]]
|
||||
): Stitch[Seq[FeatureMap]] = {
|
||||
val queryFeatureMap = query.features.getOrElse(FeatureMap.empty)
|
||||
val rankingConfig = queryFeatureMap.get(CrMlRankerRankingConfig)
|
||||
val commonFeatures = queryFeatureMap.get(CrMlRankerCommonFeatures)
|
||||
val userId = query.getRequiredUserId
|
||||
|
||||
val scoresStitch = Stitch.collect(candidates.map { candidateWithFeatures =>
|
||||
crMlRanker
|
||||
.getScore(userId, candidateWithFeatures.candidate, rankingConfig, commonFeatures).map(
|
||||
_.score)
|
||||
})
|
||||
scoresStitch.map { scores =>
|
||||
scores.map { score =>
|
||||
FeatureMapBuilder()
|
||||
.add(CrMlRankerScore, score)
|
||||
.build()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,79 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.cr_ml_ranker
|
||||
|
||||
import com.twitter.cr_ml_ranker.{thriftscala => t}
|
||||
import com.twitter.product_mixer.component_library.model.candidate.BaseTweetCandidate
|
||||
import com.twitter.stitch.SeqGroup
|
||||
import com.twitter.stitch.Stitch
|
||||
import com.twitter.util.Future
|
||||
import com.twitter.util.Return
|
||||
import com.twitter.util.Try
|
||||
|
||||
case class CrMlRankerResult(
|
||||
tweetId: Long,
|
||||
score: Double)
|
||||
|
||||
class CrMlRankerScoreStitchClient(
|
||||
crMLRanker: t.CrMLRanker.MethodPerEndpoint,
|
||||
maxBatchSize: Int) {
|
||||
|
||||
def getScore(
|
||||
userId: Long,
|
||||
tweetCandidate: BaseTweetCandidate,
|
||||
rankingConfig: t.RankingConfig,
|
||||
commonFeatures: t.CommonFeatures
|
||||
): Stitch[CrMlRankerResult] = {
|
||||
Stitch.call(
|
||||
tweetCandidate,
|
||||
CrMlRankerGroup(
|
||||
userId = userId,
|
||||
rankingConfig = rankingConfig,
|
||||
commonFeatures = commonFeatures
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private case class CrMlRankerGroup(
|
||||
userId: Long,
|
||||
rankingConfig: t.RankingConfig,
|
||||
commonFeatures: t.CommonFeatures)
|
||||
extends SeqGroup[BaseTweetCandidate, CrMlRankerResult] {
|
||||
|
||||
override val maxSize: Int = maxBatchSize
|
||||
|
||||
override protected def run(
|
||||
tweetCandidates: Seq[BaseTweetCandidate]
|
||||
): Future[Seq[Try[CrMlRankerResult]]] = {
|
||||
val crMlRankerCandidates =
|
||||
tweetCandidates.map { tweetCandidate =>
|
||||
t.RankingCandidate(
|
||||
tweetId = tweetCandidate.id,
|
||||
hydrationContext = Some(
|
||||
t.FeatureHydrationContext.HomeHydrationContext(t
|
||||
.HomeFeatureHydrationContext(tweetAuthor = None)))
|
||||
)
|
||||
}
|
||||
|
||||
val thriftResults = crMLRanker.getRankedResults(
|
||||
t.RankingRequest(
|
||||
requestContext = t.RankingRequestContext(
|
||||
userId = userId,
|
||||
config = rankingConfig
|
||||
),
|
||||
candidates = crMlRankerCandidates,
|
||||
commonFeatures = commonFeatures.commonFeatures
|
||||
)
|
||||
)
|
||||
|
||||
thriftResults.map { response =>
|
||||
response.scoredTweets.map { scoredTweet =>
|
||||
Return(
|
||||
CrMlRankerResult(
|
||||
tweetId = scoredTweet.tweetId,
|
||||
score = scoredTweet.score
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"3rdparty/jvm/javax/inject:javax.inject",
|
||||
"cortex-deepbird/thrift/src/main/thrift:thrift-java",
|
||||
"finagle/finagle-http/src/main/scala",
|
||||
"finatra-internal/mtls/src/main/scala",
|
||||
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/common",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap/datarecord",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap/featurestorev1",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline/pipeline_failure",
|
||||
"scrooge/scrooge-serializer",
|
||||
"src/java/com/twitter/ml/api:api-base",
|
||||
"src/java/com/twitter/ml/common/base",
|
||||
"src/java/com/twitter/ml/prediction/core",
|
||||
"src/thrift/com/twitter/ml/prediction_service:prediction_service-java",
|
||||
"src/thrift/com/twitter/ml/prediction_service:prediction_service-scala",
|
||||
"twml/runtime/src/main/scala/com/twitter/deepbird/runtime/prediction_engine",
|
||||
],
|
||||
exports = [
|
||||
"cortex-deepbird/thrift/src/main/thrift:thrift-java",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/common",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap/datarecord",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline/pipeline_failure",
|
||||
"src/java/com/twitter/ml/prediction/core",
|
||||
"src/thrift/com/twitter/ml/prediction_service:prediction_service-java",
|
||||
"twml/runtime/src/main/scala/com/twitter/deepbird/runtime/prediction_engine",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@ -1,91 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.deepbird
|
||||
|
||||
import com.twitter.product_mixer.core.feature.datarecord.BaseDataRecordFeature
|
||||
import com.twitter.ml.prediction_service.BatchPredictionRequest
|
||||
import com.twitter.ml.prediction_service.BatchPredictionResponse
|
||||
import com.twitter.cortex.deepbird.thriftjava.{ModelSelector => TModelSelector}
|
||||
import com.twitter.ml.api.DataRecord
|
||||
import com.twitter.product_mixer.component_library.scorer.common.ModelSelector
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.DataRecordConverter
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.DataRecordExtractor
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.FeaturesScope
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import scala.collection.JavaConverters._
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.pipeline.pipeline_failure.IllegalStateFailure
|
||||
import com.twitter.product_mixer.core.pipeline.pipeline_failure.PipelineFailure
|
||||
import com.twitter.stitch.Stitch
|
||||
import com.twitter.util.Future
|
||||
|
||||
abstract class BaseDeepbirdV2Scorer[
|
||||
Query <: PipelineQuery,
|
||||
Candidate <: UniversalNoun[Any],
|
||||
QueryFeatures <: BaseDataRecordFeature[Query, _],
|
||||
CandidateFeatures <: BaseDataRecordFeature[Candidate, _],
|
||||
ResultFeatures <: BaseDataRecordFeature[Candidate, _]
|
||||
](
|
||||
override val identifier: ScorerIdentifier,
|
||||
modelIdSelector: ModelSelector[Query],
|
||||
queryFeatures: FeaturesScope[QueryFeatures],
|
||||
candidateFeatures: FeaturesScope[CandidateFeatures],
|
||||
resultFeatures: Set[ResultFeatures])
|
||||
extends Scorer[Query, Candidate] {
|
||||
|
||||
private val queryDataRecordConverter = new DataRecordConverter(queryFeatures)
|
||||
private val candidateDataRecordConverter = new DataRecordConverter(candidateFeatures)
|
||||
private val resultDataRecordExtractor = new DataRecordExtractor(resultFeatures)
|
||||
|
||||
require(resultFeatures.nonEmpty, "Result features cannot be empty")
|
||||
override val features: Set[Feature[_, _]] = resultFeatures.asInstanceOf[Set[Feature[_, _]]]
|
||||
def getBatchPredictions(
|
||||
request: BatchPredictionRequest,
|
||||
modelSelector: TModelSelector
|
||||
): Future[BatchPredictionResponse]
|
||||
|
||||
override def apply(
|
||||
query: Query,
|
||||
candidates: Seq[CandidateWithFeatures[Candidate]]
|
||||
): Stitch[Seq[FeatureMap]] = {
|
||||
// Convert all candidate feature maps to java datarecords then to scala datarecords.
|
||||
val thriftCandidateDataRecords = candidates.map { candidate =>
|
||||
candidateDataRecordConverter.toDataRecord(candidate.features)
|
||||
}
|
||||
|
||||
val request = new BatchPredictionRequest(thriftCandidateDataRecords.asJava)
|
||||
|
||||
// Convert the query feature map to data record if available.
|
||||
query.features.foreach { featureMap =>
|
||||
request.setCommonFeatures(queryDataRecordConverter.toDataRecord(featureMap))
|
||||
}
|
||||
|
||||
val modelSelector = modelIdSelector
|
||||
.apply(query).map { id =>
|
||||
val selector = new TModelSelector()
|
||||
selector.setId(id)
|
||||
selector
|
||||
}.orNull
|
||||
|
||||
Stitch.callFuture(getBatchPredictions(request, modelSelector)).map { response =>
|
||||
val dataRecords = Option(response.predictions).map(_.asScala).getOrElse(Seq.empty)
|
||||
buildResults(candidates, dataRecords)
|
||||
}
|
||||
}
|
||||
|
||||
private def buildResults(
|
||||
candidates: Seq[CandidateWithFeatures[Candidate]],
|
||||
dataRecords: Seq[DataRecord]
|
||||
): Seq[FeatureMap] = {
|
||||
if (dataRecords.size != candidates.size) {
|
||||
throw PipelineFailure(IllegalStateFailure, "Result Size mismatched candidates size")
|
||||
}
|
||||
|
||||
dataRecords.map { resultDataRecord =>
|
||||
resultDataRecordExtractor.fromDataRecord(resultDataRecord)
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,55 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.deepbird
|
||||
|
||||
import com.twitter.cortex.deepbird.{thriftjava => t}
|
||||
import com.twitter.ml.prediction_service.BatchPredictionRequest
|
||||
import com.twitter.ml.prediction_service.BatchPredictionResponse
|
||||
import com.twitter.product_mixer.component_library.scorer.common.ModelSelector
|
||||
import com.twitter.product_mixer.core.feature.datarecord.BaseDataRecordFeature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.FeaturesScope
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.util.Future
|
||||
|
||||
/**
|
||||
* Configurable Scorer that calls any Deepbird Prediction Service thrift.
|
||||
* @param identifier Unique identifier for the scorer
|
||||
* @param predictionService The Prediction Thrift Service
|
||||
* @param modelSelector Model ID Selector to decide which model to select, can also be represented
|
||||
* as an anonymous function: { query: Query => Some("Ex") }
|
||||
* @param queryFeatures The Query Features to convert and pass to the deepbird model.
|
||||
* @param candidateFeatures The Candidate Features to convert and pass to the deepbird model.
|
||||
* @param resultFeatures The Candidate features returned by the model.
|
||||
* @tparam Query Type of pipeline query.
|
||||
* @tparam Candidate Type of candidates to score.
|
||||
* @tparam QueryFeatures type of the query level features consumed by the scorer.
|
||||
* @tparam CandidateFeatures type of the candidate level features consumed by the scorer.
|
||||
* @tparam ResultFeatures type of the candidate level features returned by the scorer.
|
||||
*/
|
||||
case class DeepbirdV2PredictionServerScorer[
|
||||
Query <: PipelineQuery,
|
||||
Candidate <: UniversalNoun[Any],
|
||||
QueryFeatures <: BaseDataRecordFeature[Query, _],
|
||||
CandidateFeatures <: BaseDataRecordFeature[Candidate, _],
|
||||
ResultFeatures <: BaseDataRecordFeature[Candidate, _]
|
||||
](
|
||||
override val identifier: ScorerIdentifier,
|
||||
predictionService: t.DeepbirdPredictionService.ServiceToClient,
|
||||
modelSelector: ModelSelector[Query],
|
||||
queryFeatures: FeaturesScope[QueryFeatures],
|
||||
candidateFeatures: FeaturesScope[CandidateFeatures],
|
||||
resultFeatures: Set[ResultFeatures])
|
||||
extends BaseDeepbirdV2Scorer[
|
||||
Query,
|
||||
Candidate,
|
||||
QueryFeatures,
|
||||
CandidateFeatures,
|
||||
ResultFeatures
|
||||
](identifier, modelSelector, queryFeatures, candidateFeatures, resultFeatures) {
|
||||
|
||||
override def getBatchPredictions(
|
||||
request: BatchPredictionRequest,
|
||||
modelSelector: t.ModelSelector
|
||||
): Future[BatchPredictionResponse] =
|
||||
predictionService.batchPredictFromModel(request, modelSelector)
|
||||
}
|
Binary file not shown.
|
@ -1,61 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.deepbird
|
||||
|
||||
import com.twitter.ml.prediction.core.PredictionEngine
|
||||
import com.twitter.ml.prediction_service.PredictionRequest
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.datarecord.BaseDataRecordFeature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.DataRecordConverter
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.DataRecordExtractor
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.FeaturesScope
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.stitch.Stitch
|
||||
|
||||
/**
|
||||
* Scorer that locally loads a Deepbird model.
|
||||
* @param identifier Unique identifier for the scorer
|
||||
* @param predictionEngine Prediction Engine hosting the Deepbird model.
|
||||
* @param candidateFeatures The Candidate Features to convert and pass to the deepbird model.
|
||||
* @param resultFeatures The Candidate features returned by the model.
|
||||
* @tparam Query Type of pipeline query.
|
||||
* @tparam Candidate Type of candidates to score.
|
||||
* @tparam QueryFeatures type of the query level features consumed by the scorer.
|
||||
* @tparam CandidateFeatures type of the candidate level features consumed by the scorer.
|
||||
* @tparam ResultFeatures type of the candidate level features returned by the scorer.
|
||||
*/
|
||||
class LollyPredictionEngineScorer[
|
||||
Query <: PipelineQuery,
|
||||
Candidate <: UniversalNoun[Any],
|
||||
QueryFeatures <: BaseDataRecordFeature[Query, _],
|
||||
CandidateFeatures <: BaseDataRecordFeature[Candidate, _],
|
||||
ResultFeatures <: BaseDataRecordFeature[Candidate, _]
|
||||
](
|
||||
override val identifier: ScorerIdentifier,
|
||||
predictionEngine: PredictionEngine,
|
||||
candidateFeatures: FeaturesScope[CandidateFeatures],
|
||||
resultFeatures: Set[ResultFeatures])
|
||||
extends Scorer[Query, Candidate] {
|
||||
|
||||
private val dataRecordAdapter = new DataRecordConverter(candidateFeatures)
|
||||
|
||||
require(resultFeatures.nonEmpty, "Result features cannot be empty")
|
||||
override val features: Set[Feature[_, _]] = resultFeatures.asInstanceOf[Set[Feature[_, _]]]
|
||||
|
||||
private val resultsDataRecordExtractor = new DataRecordExtractor(resultFeatures)
|
||||
|
||||
override def apply(
|
||||
query: Query,
|
||||
candidates: Seq[CandidateWithFeatures[Candidate]]
|
||||
): Stitch[Seq[FeatureMap]] = {
|
||||
val featureMaps = candidates.map { candidateWithFeatures =>
|
||||
val dataRecord = dataRecordAdapter.toDataRecord(candidateWithFeatures.features)
|
||||
val predictionResponse = predictionEngine.apply(new PredictionRequest(dataRecord), true)
|
||||
resultsDataRecordExtractor.fromDataRecord(predictionResponse.getPrediction)
|
||||
}
|
||||
Stitch.value(featureMaps)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,58 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.deepbird
|
||||
|
||||
import com.twitter.cortex.deepbird.runtime.prediction_engine.TensorflowPredictionEngine
|
||||
import com.twitter.cortex.deepbird.thriftjava.ModelSelector
|
||||
import com.twitter.ml.prediction_service.BatchPredictionRequest
|
||||
import com.twitter.ml.prediction_service.BatchPredictionResponse
|
||||
import com.twitter.product_mixer.core.feature.datarecord.BaseDataRecordFeature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.datarecord.FeaturesScope
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.util.Future
|
||||
|
||||
/**
|
||||
* Configurable Scorer that calls a TensorflowPredictionEngine.
|
||||
* @param identifier Unique identifier for the scorer
|
||||
* @param tensorflowPredictionEngine The TensorFlow Prediction Engine
|
||||
* @param queryFeatures The Query Features to convert and pass to the deepbird model.
|
||||
* @param candidateFeatures The Candidate Features to convert and pass to the deepbird model.
|
||||
* @param resultFeatures The Candidate features returned by the model.
|
||||
* @tparam Query Type of pipeline query.
|
||||
* @tparam Candidate Type of candidates to score.
|
||||
* @tparam QueryFeatures type of the query level features consumed by the scorer.
|
||||
* @tparam CandidateFeatures type of the candidate level features consumed by the scorer.
|
||||
* @tparam ResultFeatures type of the candidate level features returned by the scorer.
|
||||
*/
|
||||
class TensorflowPredictionEngineScorer[
|
||||
Query <: PipelineQuery,
|
||||
Candidate <: UniversalNoun[Any],
|
||||
QueryFeatures <: BaseDataRecordFeature[Query, _],
|
||||
CandidateFeatures <: BaseDataRecordFeature[Candidate, _],
|
||||
ResultFeatures <: BaseDataRecordFeature[Candidate, _]
|
||||
](
|
||||
override val identifier: ScorerIdentifier,
|
||||
tensorflowPredictionEngine: TensorflowPredictionEngine,
|
||||
queryFeatures: FeaturesScope[QueryFeatures],
|
||||
candidateFeatures: FeaturesScope[CandidateFeatures],
|
||||
resultFeatures: Set[ResultFeatures])
|
||||
extends BaseDeepbirdV2Scorer[
|
||||
Query,
|
||||
Candidate,
|
||||
QueryFeatures,
|
||||
CandidateFeatures,
|
||||
ResultFeatures
|
||||
](
|
||||
identifier,
|
||||
{ _: Query =>
|
||||
None
|
||||
},
|
||||
queryFeatures,
|
||||
candidateFeatures,
|
||||
resultFeatures) {
|
||||
|
||||
override def getBatchPredictions(
|
||||
request: BatchPredictionRequest,
|
||||
modelSelector: ModelSelector
|
||||
): Future[BatchPredictionResponse] = tensorflowPredictionEngine.getBatchPrediction(request)
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
],
|
||||
exports = [
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@ -1,43 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.param_gated
|
||||
|
||||
import com.twitter.product_mixer.component_library.scorer.param_gated.ParamGatedScorer.IdentifierPrefix
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||||
import com.twitter.product_mixer.core.functional_component.common.alert.Alert
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.Conditionally
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.stitch.Stitch
|
||||
import com.twitter.timelines.configapi.Param
|
||||
|
||||
/**
|
||||
* A [[scorer]] with [[Conditionally]] based on a [[Param]]
|
||||
*
|
||||
* @param enabledParam the param to turn this [[scorer]] on and off
|
||||
* @param scorer the underlying [[scorer]] to run when `enabledParam` is true
|
||||
* @tparam Query The domain model for the query or request
|
||||
* @tparam Result The type of the candidates
|
||||
*/
|
||||
case class ParamGatedScorer[-Query <: PipelineQuery, Result <: UniversalNoun[Any]](
|
||||
enabledParam: Param[Boolean],
|
||||
scorer: Scorer[Query, Result])
|
||||
extends Scorer[Query, Result]
|
||||
with Conditionally[Query] {
|
||||
override val identifier: ScorerIdentifier = ScorerIdentifier(
|
||||
IdentifierPrefix + scorer.identifier.name)
|
||||
override val alerts: Seq[Alert] = scorer.alerts
|
||||
override val features: Set[Feature[_, _]] = scorer.features
|
||||
override def onlyIf(query: Query): Boolean =
|
||||
Conditionally.and(query, scorer, query.params(enabledParam))
|
||||
override def apply(
|
||||
query: Query,
|
||||
candidates: Seq[CandidateWithFeatures[Result]]
|
||||
): Stitch[Seq[FeatureMap]] = scorer(query, candidates)
|
||||
}
|
||||
|
||||
object ParamGatedScorer {
|
||||
val IdentifierPrefix = "ParamGated"
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
],
|
||||
exports = [
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/configapi",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/scorer",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@ -1,59 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.qualityfactor_gated
|
||||
|
||||
import com.twitter.product_mixer.component_library.scorer.qualityfactor_gated.QualityFactorGatedScorer.IdentifierPrefix
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
|
||||
import com.twitter.product_mixer.core.functional_component.common.alert.Alert
|
||||
import com.twitter.product_mixer.core.functional_component.scorer.Scorer
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.Conditionally
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ComponentIdentifier
|
||||
import com.twitter.product_mixer.core.model.common.identifier.ScorerIdentifier
|
||||
import com.twitter.product_mixer.core.pipeline.PipelineQuery
|
||||
import com.twitter.product_mixer.core.quality_factor.HasQualityFactorStatus
|
||||
import com.twitter.stitch.Stitch
|
||||
import com.twitter.timelines.configapi.Param
|
||||
|
||||
/**
|
||||
* A [[scorer]] with [[Conditionally]] based on quality factor value and threshold
|
||||
*
|
||||
* @param qualityFactorThreshold quliaty factor threshold that turn off the scorer
|
||||
* @param pipelineIdentifier identifier of the pipeline that quality factor is based on
|
||||
* @param scorer the underlying [[scorer]] to run when `enabledParam` is true
|
||||
* @tparam Query The domain model for the query or request
|
||||
* @tparam Result The type of the candidates
|
||||
*/
|
||||
case class QualityFactorGatedScorer[
|
||||
-Query <: PipelineQuery with HasQualityFactorStatus,
|
||||
Result <: UniversalNoun[Any]
|
||||
](
|
||||
pipelineIdentifier: ComponentIdentifier,
|
||||
qualityFactorThresholdParam: Param[Double],
|
||||
scorer: Scorer[Query, Result])
|
||||
extends Scorer[Query, Result]
|
||||
with Conditionally[Query] {
|
||||
|
||||
override val identifier: ScorerIdentifier = ScorerIdentifier(
|
||||
IdentifierPrefix + scorer.identifier.name)
|
||||
|
||||
override val alerts: Seq[Alert] = scorer.alerts
|
||||
|
||||
override val features: Set[Feature[_, _]] = scorer.features
|
||||
|
||||
override def onlyIf(query: Query): Boolean =
|
||||
Conditionally.and(
|
||||
query,
|
||||
scorer,
|
||||
query.getQualityFactorCurrentValue(pipelineIdentifier) >= query.params(
|
||||
qualityFactorThresholdParam))
|
||||
|
||||
override def apply(
|
||||
query: Query,
|
||||
candidates: Seq[CandidateWithFeatures[Result]]
|
||||
): Stitch[Seq[FeatureMap]] = scorer(query, candidates)
|
||||
}
|
||||
|
||||
object QualityFactorGatedScorer {
|
||||
val IdentifierPrefix = "QualityFactorGated"
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/common",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featuremap/featurestorev1",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/feature/featurestorev1",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
"src/thrift/com/twitter/ml/api:embedding-scala",
|
||||
],
|
||||
exports = [
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/scorer/common",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@ -1,13 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.tensorbuilder
|
||||
|
||||
import inference.GrpcService.ModelInferRequest.InferInputTensor
|
||||
|
||||
case object BooleanInferInputTensorBuilder extends InferInputTensorBuilder[Boolean] {
|
||||
def apply(
|
||||
featureName: String,
|
||||
featureValues: Seq[Boolean]
|
||||
): Seq[InferInputTensor] = {
|
||||
val tensorShape = Seq(featureValues.size, 1)
|
||||
InferInputTensorBuilder.buildBoolInferInputTensor(featureName, featureValues, tensorShape)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,13 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.tensorbuilder
|
||||
|
||||
import inference.GrpcService.ModelInferRequest.InferInputTensor
|
||||
|
||||
case object BytesInferInputTensorBuilder extends InferInputTensorBuilder[String] {
|
||||
def apply(
|
||||
featureName: String,
|
||||
featureValues: Seq[String]
|
||||
): Seq[InferInputTensor] = {
|
||||
val tensorShape = Seq(featureValues.size, 1)
|
||||
InferInputTensorBuilder.buildBytesInferInputTensor(featureName, featureValues, tensorShape)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,70 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.scorer.tensorbuilder
|
||||
|
||||
import com.twitter.ml.api.thriftscala.FloatTensor
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.FeatureWithDefaultOnFailure
|
||||
import com.twitter.product_mixer.core.feature.ModelFeatureName
|
||||
import com.twitter.product_mixer.core.feature.featuremap.featurestorev1.FeatureStoreV1FeatureMap._
|
||||
import com.twitter.product_mixer.core.feature.featurestorev1.FeatureStoreV1CandidateFeature
|
||||
import com.twitter.product_mixer.core.feature.featurestorev1.FeatureStoreV1QueryFeature
|
||||
import com.twitter.product_mixer.core.model.common.CandidateWithFeatures
|
||||
import com.twitter.product_mixer.core.model.common.UniversalNoun
|
||||
import inference.GrpcService.ModelInferRequest.InferInputTensor
|
||||
|
||||
class CandidateInferInputTensorBuilder[-Candidate <: UniversalNoun[Any], +Value](
|
||||
builder: InferInputTensorBuilder[Value],
|
||||
features: Set[_ <: Feature[Candidate, _] with ModelFeatureName]) {
|
||||
def apply(
|
||||
candidates: Seq[CandidateWithFeatures[Candidate]],
|
||||
): Seq[InferInputTensor] = {
|
||||
features.flatMap { feature =>
|
||||
val featureValues: Seq[Value] = feature match {
|
||||
case feature: FeatureStoreV1CandidateFeature[_, Candidate, _, Value] =>
|
||||
candidates.map(_.features.getFeatureStoreV1CandidateFeature(feature))
|
||||
case feature: FeatureStoreV1QueryFeature[_, _, _] =>
|
||||
throw new UnexpectedFeatureTypeException(feature)
|
||||
case feature: FeatureWithDefaultOnFailure[Candidate, Value] =>
|
||||
candidates.map(_.features.getTry(feature).toOption.getOrElse(feature.defaultValue))
|
||||
case feature: Feature[Candidate, Value] =>
|
||||
candidates.map(_.features.get(feature))
|
||||
}
|
||||
builder.apply(feature.featureName, featureValues)
|
||||
}.toSeq
|
||||
}
|
||||
}
|
||||
|
||||
case class CandidateBooleanInferInputTensorBuilder[-Candidate <: UniversalNoun[Any]](
|
||||
features: Set[_ <: Feature[Candidate, Boolean] with ModelFeatureName])
|
||||
extends CandidateInferInputTensorBuilder[Candidate, Boolean](
|
||||
BooleanInferInputTensorBuilder,
|
||||
features)
|
||||
|
||||
case class CandidateBytesInferInputTensorBuilder[-Candidate <: UniversalNoun[Any]](
|
||||
features: Set[_ <: Feature[Candidate, String] with ModelFeatureName])
|
||||
extends CandidateInferInputTensorBuilder[Candidate, String](
|
||||
BytesInferInputTensorBuilder,
|
||||
features)
|
||||
|
||||
case class CandidateFloat32InferInputTensorBuilder[-Candidate <: UniversalNoun[Any]](
|
||||
features: Set[_ <: Feature[Candidate, _ <: AnyVal] with ModelFeatureName])
|
||||
extends CandidateInferInputTensorBuilder[Candidate, AnyVal](
|
||||
Float32InferInputTensorBuilder,
|
||||
features)
|
||||
|
||||
case class CandidateFloatTensorInferInputTensorBuilder[-Candidate <: UniversalNoun[Any]](
|
||||
features: Set[_ <: Feature[Candidate, FloatTensor] with ModelFeatureName])
|
||||
extends CandidateInferInputTensorBuilder[Candidate, FloatTensor](
|
||||
FloatTensorInferInputTensorBuilder,
|
||||
features)
|
||||
|
||||
case class CandidateInt64InferInputTensorBuilder[-Candidate <: UniversalNoun[Any]](
|
||||
features: Set[_ <: Feature[Candidate, _ <: AnyVal] with ModelFeatureName])
|
||||
extends CandidateInferInputTensorBuilder[Candidate, AnyVal](
|
||||
Int64InferInputTensorBuilder,
|
||||
features)
|
||||
|
||||
case class CandidateSparseMapInferInputTensorBuilder[-Candidate <: UniversalNoun[Any]](
|
||||
features: Set[_ <: Feature[Candidate, Option[Map[Int, Double]]] with ModelFeatureName])
|
||||
extends CandidateInferInputTensorBuilder[Candidate, Option[Map[Int, Double]]](
|
||||
SparseMapInferInputTensorBuilder,
|
||||
features)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue