[docx] split commit for file 2000
Signed-off-by: Ari Archer <ari.web.xyz@gmail.com>
This commit is contained in:
parent 65c3a3fe90
commit 2488f40edf
Binary file not shown.
@@ -1,47 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::{ARGS, MODEL_SPECS};
use navi::cores::validator::validatior::cli_validator;
use navi::tf_model::tf::TFModel;
use navi::{bootstrap, metrics};
use sha256::digest;

fn main() -> Result<()> {
    env_logger::init();
    cli_validator::validate_input_args();
    //only validate this for tf as other models don't have it
    assert_eq!(MODEL_SPECS.len(), ARGS.serving_sig.len());
    metrics::register_custom_metrics();

    //load all the custom ops - comma separated
    if let Some(ref customops_lib) = ARGS.customops_lib {
        for op_lib in customops_lib.split(",") {
            load_custom_op(op_lib);
        }
    }

    // versioning the customop .so library
    bootstrap::bootstrap(TFModel::new)
}

fn load_custom_op(lib_path: &str) {
    let res = tensorflow::Library::load(lib_path);
    info!("{} load status:{:?}", lib_path, res);
    let customop_version_num = get_custom_op_version(lib_path);
    // Last OP version is recorded
    metrics::CUSTOMOP_VERSION.set(customop_version_num);
}

//fn get_custom_op_version(customops_lib: &String) -> i64 {
fn get_custom_op_version(customops_lib: &str) -> i64 {
    let customop_bytes = std::fs::read(customops_lib).unwrap(); // Vec<u8>
    let customop_hash = digest(customop_bytes.as_slice());
    //convert the last 4 hex digits to a version number, since Prometheus metrics don't support strings; the total space is 16^4 == 65536
    let customop_version_num =
        i64::from_str_radix(&customop_hash[customop_hash.len() - 4..], 16).unwrap();
    info!(
        "customop hash: {}, version_number: {}",
        customop_hash, customop_version_num
    );
    customop_version_num
}
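To make the version derivation above concrete, here is a minimal standalone sketch (the digest value is illustrative, not taken from a real custom-op build): the last four hex characters of the SHA-256 digest are reinterpreted as an integer, so the reported version always falls within the 16^4 == 65536 slots mentioned in the comment.

fn main() {
    // Illustrative digest string only; a real value comes from sha256::digest over the .so bytes.
    let customop_hash = "9b74c9897bac770ffc029102a200c5de";
    let version = i64::from_str_radix(&customop_hash[customop_hash.len() - 4..], 16).unwrap();
    // "c5de" -> 0xc5de == 50654, which is what the CUSTOMOP_VERSION gauge would report.
    println!("customop version gauge value: {}", version);
}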
Binary file not shown.
@@ -1,24 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::{ARGS, MODEL_SPECS};
use navi::onnx_model::onnx::OnnxModel;
use navi::{bootstrap, metrics};

fn main() -> Result<()> {
    env_logger::init();
    info!("global: {:?}", ARGS.onnx_global_thread_pool_options);
    let assert_session_params = if ARGS.onnx_global_thread_pool_options.is_empty() {
        // std::env::set_var("OMP_NUM_THREADS", "1");
        info!("now we use per session thread pool");
        MODEL_SPECS.len()
    } else {
        info!("now we use global thread pool");
        0
    };
    assert_eq!(assert_session_params, ARGS.inter_op_parallelism.len());
    assert_eq!(assert_session_params, ARGS.intra_op_parallelism.len());

    metrics::register_custom_metrics();
    bootstrap::bootstrap(OnnxModel::new)
}
Binary file not shown.
@@ -1,19 +0,0 @@
use anyhow::Result;
use log::info;
use navi::cli_args::ARGS;
use navi::metrics;
use navi::torch_model::torch::TorchModel;

fn main() -> Result<()> {
    env_logger::init();
    //torch only has global threadpool settings, whereas tf has per-model threadpool settings
    assert_eq!(1, ARGS.inter_op_parallelism.len());
    assert_eq!(1, ARGS.intra_op_parallelism.len());
    //TODO: for now, we assume each model's output has only 1 tensor.
    //this will be lifted once torch_model properly implements mtl outputs
    tch::set_num_interop_threads(ARGS.inter_op_parallelism[0].parse()?);
    tch::set_num_threads(ARGS.intra_op_parallelism[0].parse()?);
    info!("torch custom ops not used for now");
    metrics::register_custom_metrics();
    navi::bootstrap::bootstrap(TorchModel::new)
}
Binary file not shown.
|
@@ -1,326 +0,0 @@
|
|||
use anyhow::Result;
|
||||
use log::{info, warn};
|
||||
use x509_parser::{prelude::{parse_x509_pem}, parse_x509_certificate};
|
||||
use std::collections::HashMap;
|
||||
use tokio::time::Instant;
|
||||
use tonic::{
|
||||
Request,
|
||||
Response, Status, transport::{Certificate, Identity, Server, ServerTlsConfig},
|
||||
};
|
||||
|
||||
// protobuf related
|
||||
use crate::tf_proto::tensorflow_serving::{
|
||||
ClassificationRequest, ClassificationResponse, GetModelMetadataRequest,
|
||||
GetModelMetadataResponse, MultiInferenceRequest, MultiInferenceResponse, PredictRequest,
|
||||
PredictResponse, RegressionRequest, RegressionResponse,
|
||||
};
|
||||
use crate::{kf_serving::{
|
||||
grpc_inference_service_server::GrpcInferenceService, ModelInferRequest, ModelInferResponse,
|
||||
ModelMetadataRequest, ModelMetadataResponse, ModelReadyRequest, ModelReadyResponse,
|
||||
ServerLiveRequest, ServerLiveResponse, ServerMetadataRequest, ServerMetadataResponse,
|
||||
ServerReadyRequest, ServerReadyResponse,
|
||||
}, ModelFactory, tf_proto::tensorflow_serving::prediction_service_server::{
|
||||
PredictionService, PredictionServiceServer,
|
||||
}, VERSION, NAME};
|
||||
|
||||
use crate::PredictResult;
|
||||
use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
|
||||
use crate::metrics::{
|
||||
NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
|
||||
NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
|
||||
CERT_EXPIRY_EPOCH
|
||||
};
|
||||
use crate::predict_service::{Model, PredictService};
|
||||
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
|
||||
use crate::tf_proto::tensorflow_serving::ModelSpec;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TensorInputEnum {
|
||||
String(Vec<Vec<u8>>),
|
||||
Int(Vec<i32>),
|
||||
Int64(Vec<i64>),
|
||||
Float(Vec<f32>),
|
||||
Double(Vec<f64>),
|
||||
Boolean(Vec<bool>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TensorInput {
|
||||
pub tensor_data: TensorInputEnum,
|
||||
pub name: String,
|
||||
pub dims: Option<Vec<i64>>,
|
||||
}
|
||||
|
||||
impl TensorInput {
|
||||
pub fn new(tensor_data: TensorInputEnum, name: String, dims: Option<Vec<i64>>) -> TensorInput {
|
||||
TensorInput {
|
||||
tensor_data,
|
||||
name,
|
||||
dims,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorInputEnum {
|
||||
#[inline(always)]
|
||||
pub(crate) fn extend(&mut self, another: TensorInputEnum) {
|
||||
match (self, another) {
|
||||
(Self::String(input), Self::String(ex)) => input.extend(ex),
|
||||
(Self::Int(input), Self::Int(ex)) => input.extend(ex),
|
||||
(Self::Int64(input), Self::Int64(ex)) => input.extend(ex),
|
||||
(Self::Float(input), Self::Float(ex)) => input.extend(ex),
|
||||
(Self::Double(input), Self::Double(ex)) => input.extend(ex),
|
||||
(Self::Boolean(input), Self::Boolean(ex)) => input.extend(ex),
|
||||
x => panic!("input enum type not matched. input:{:?}, ex:{:?}", x.0, x.1),
|
||||
}
|
||||
}
|
||||
#[inline(always)]
|
||||
pub(crate) fn merge_batch(input_tensors: Vec<Vec<TensorInput>>) -> Vec<TensorInput> {
|
||||
input_tensors
|
||||
.into_iter()
|
||||
.reduce(|mut acc, e| {
|
||||
for (i, ext) in acc.iter_mut().zip(e) {
|
||||
i.tensor_data.extend(ext.tensor_data);
|
||||
}
|
||||
acc
|
||||
})
|
||||
.unwrap() //invariant: we expect there are always rows in input_tensors
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
///entry point for tfServing gRPC
|
||||
#[tonic::async_trait]
|
||||
impl<T: Model> GrpcInferenceService for PredictService<T> {
|
||||
async fn server_live(
|
||||
&self,
|
||||
_request: Request<ServerLiveRequest>,
|
||||
) -> Result<Response<ServerLiveResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
async fn server_ready(
|
||||
&self,
|
||||
_request: Request<ServerReadyRequest>,
|
||||
) -> Result<Response<ServerReadyResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn model_ready(
|
||||
&self,
|
||||
_request: Request<ModelReadyRequest>,
|
||||
) -> Result<Response<ModelReadyResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn server_metadata(
|
||||
&self,
|
||||
_request: Request<ServerMetadataRequest>,
|
||||
) -> Result<Response<ServerMetadataResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn model_metadata(
|
||||
&self,
|
||||
_request: Request<ModelMetadataRequest>,
|
||||
) -> Result<Response<ModelMetadataResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn model_infer(
|
||||
&self,
|
||||
_request: Request<ModelInferRequest>,
|
||||
) -> Result<Response<ModelInferResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl<T: Model> PredictionService for PredictService<T> {
|
||||
async fn classify(
|
||||
&self,
|
||||
_request: Request<ClassificationRequest>,
|
||||
) -> Result<Response<ClassificationResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
async fn regress(
|
||||
&self,
|
||||
_request: Request<RegressionRequest>,
|
||||
) -> Result<Response<RegressionResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
async fn predict(
|
||||
&self,
|
||||
request: Request<PredictRequest>,
|
||||
) -> Result<Response<PredictResponse>, Status> {
|
||||
NUM_REQUESTS_RECEIVED.inc();
|
||||
let start = Instant::now();
|
||||
let mut req = request.into_inner();
|
||||
let (model_spec, version) = req.take_model_spec();
|
||||
NUM_REQUESTS_RECEIVED_BY_MODEL
|
||||
.with_label_values(&[&model_spec])
|
||||
.inc();
|
||||
let idx = PredictService::<T>::get_model_index(&model_spec).ok_or_else(|| {
|
||||
Status::failed_precondition(format!("model spec not found:{}", model_spec))
|
||||
})?;
|
||||
let input_spec = match INPUTS[idx].get() {
|
||||
Some(input) => input,
|
||||
_ => return Err(Status::not_found(format!("model input spec {}", idx))),
|
||||
};
|
||||
let input_val = req.take_input_vals(input_spec);
|
||||
self.predict(idx, version, input_val, start)
|
||||
.await
|
||||
.map_or_else(
|
||||
|e| {
|
||||
NUM_REQUESTS_FAILED.inc();
|
||||
NUM_REQUESTS_FAILED_BY_MODEL
|
||||
.with_label_values(&[&model_spec])
|
||||
.inc();
|
||||
Err(Status::internal(e.to_string()))
|
||||
},
|
||||
|res| {
|
||||
RESPONSE_TIME_COLLECTOR
|
||||
.with_label_values(&[&model_spec])
|
||||
.observe(start.elapsed().as_millis() as f64);
|
||||
|
||||
match res {
|
||||
PredictResult::Ok(tensors, version) => {
|
||||
let mut outputs = HashMap::new();
|
||||
NUM_PREDICTIONS.with_label_values(&[&model_spec]).inc();
|
||||
//FIXME: uncomment when prediction scores are normal
|
||||
// PREDICTION_SCORE_SUM
|
||||
// .with_label_values(&[&model_spec])
|
||||
// .inc_by(tensors[0]as f64);
|
||||
for (tp, output_name) in tensors
|
||||
.into_iter()
|
||||
.map(|tensor| tensor.create_tensor_proto())
|
||||
.zip(OUTPUTS[idx].iter())
|
||||
{
|
||||
outputs.insert(output_name.to_owned(), tp);
|
||||
}
|
||||
let reply = PredictResponse {
|
||||
model_spec: Some(ModelSpec {
|
||||
version_choice: Some(Version(version)),
|
||||
..Default::default()
|
||||
}),
|
||||
outputs,
|
||||
};
|
||||
Ok(Response::new(reply))
|
||||
}
|
||||
PredictResult::DropDueToOverload => Err(Status::resource_exhausted("")),
|
||||
PredictResult::ModelNotFound(idx) => {
|
||||
Err(Status::not_found(format!("model index {}", idx)))
|
||||
},
|
||||
PredictResult::ModelNotReady(idx) => {
|
||||
Err(Status::unavailable(format!("model index {}", idx)))
|
||||
}
|
||||
PredictResult::ModelVersionNotFound(idx, version) => Err(
|
||||
Status::not_found(format!("model index:{}, version {}", idx, version)),
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
async fn multi_inference(
|
||||
&self,
|
||||
_request: Request<MultiInferenceRequest>,
|
||||
) -> Result<Response<MultiInferenceResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
async fn get_model_metadata(
|
||||
&self,
|
||||
_request: Request<GetModelMetadataRequest>,
|
||||
) -> Result<Response<GetModelMetadataResponse>, Status> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
// Log the certificate's expiry timestamp and record it in the CERT_EXPIRY_EPOCH gauge
|
||||
fn report_expiry(expiry_time: i64) {
|
||||
info!("Certificate expires at epoch: {:?}", expiry_time);
|
||||
CERT_EXPIRY_EPOCH.set(expiry_time as i64);
|
||||
}
|
||||
|
||||
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
||||
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
|
||||
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH
|
||||
let parts = VERSION
|
||||
.split(".")
|
||||
.map(|v| v.parse::<i64>())
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||
if let [major, minor, patch] = &parts[..] {
|
||||
NAVI_VERSION.set(major * 1000_000 + minor * 1000 + patch);
|
||||
} else {
|
||||
warn!(
|
||||
"version {} doesn't follow SemVer conversion of MAJOR.MINOR.PATCH",
|
||||
VERSION
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.thread_name("async worker")
|
||||
.worker_threads(ARGS.num_worker_threads)
|
||||
.max_blocking_threads(ARGS.max_blocking_threads)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
#[cfg(feature = "navi_console")]
|
||||
console_subscriber::init();
|
||||
let addr = format!("0.0.0.0:{}", ARGS.port).parse()?;
|
||||
|
||||
let ps = PredictService::init(model_factory).await;
|
||||
|
||||
let mut builder = if ARGS.ssl_dir.is_empty() {
|
||||
Server::builder()
|
||||
} else {
|
||||
// Read the pem file as a string
|
||||
let pem_str = std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
|
||||
let res = parse_x509_pem(&pem_str.as_bytes());
|
||||
match res {
|
||||
Ok((rem, pem_2)) => {
|
||||
assert!(rem.is_empty());
|
||||
assert_eq!(pem_2.label, String::from("CERTIFICATE"));
|
||||
let res_x509 = parse_x509_certificate(&pem_2.contents);
|
||||
info!("Certificate label: {}", pem_2.label);
|
||||
assert!(res_x509.is_ok());
|
||||
report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
|
||||
},
|
||||
_ => panic!("PEM parsing failed: {:?}", res),
|
||||
}
|
||||
|
||||
let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
|
||||
.await
|
||||
.expect("can't find key file");
|
||||
let crt = tokio::fs::read(format!("{}/server.crt", ARGS.ssl_dir))
|
||||
.await
|
||||
.expect("can't find crt file");
|
||||
let chain = tokio::fs::read(format!("{}/server.chain", ARGS.ssl_dir))
|
||||
.await
|
||||
.expect("can't find chain file");
|
||||
let mut pem = Vec::new();
|
||||
pem.extend(crt);
|
||||
pem.extend(chain);
|
||||
let identity = Identity::from_pem(pem.clone(), key);
|
||||
let client_ca_cert = Certificate::from_pem(pem.clone());
|
||||
let tls = ServerTlsConfig::new()
|
||||
.identity(identity)
|
||||
.client_ca_root(client_ca_cert);
|
||||
Server::builder()
|
||||
.tls_config(tls)
|
||||
.expect("fail to config SSL")
|
||||
};
|
||||
|
||||
info!(
|
||||
"Prometheus server started: 0.0.0.0: {}",
|
||||
ARGS.prometheus_port
|
||||
);
|
||||
|
||||
let ps_server = builder
|
||||
.add_service(PredictionServiceServer::new(ps).accept_gzip().send_gzip())
|
||||
.serve(addr);
|
||||
info!("Prediction server started: {}", addr);
|
||||
ps_server.await.map_err(anyhow::Error::msg)
|
||||
})
|
||||
}
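As a quick worked check of the MAJOR.MINOR.PATCH packing reported via NAVI_VERSION near the top of bootstrap() (a standalone sketch; pack_semver is our illustrative helper, not part of navi): "2.5.7" becomes 2*1_000_000 + 5*1_000 + 7 = 2_005_007.

// Standalone sketch of the version packing used for the NAVI_VERSION gauge.
fn pack_semver(version: &str) -> Option<i64> {
    let parts = version
        .split('.')
        .map(|v| v.parse().ok())
        .collect::<Option<Vec<i64>>>()?;
    match parts.as_slice() {
        &[major, minor, patch] => Some(major * 1_000_000 + minor * 1_000 + patch),
        _ => None,
    }
}

fn main() {
    assert_eq!(pack_semver("2.5.7"), Some(2_005_007));
    assert_eq!(pack_semver("not-semver"), None); // non-SemVer strings are rejected
}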
|
Binary file not shown.
|
@@ -1,236 +0,0 @@
|
|||
use crate::{MAX_NUM_INPUTS, MAX_NUM_MODELS, MAX_NUM_OUTPUTS};
|
||||
use arrayvec::ArrayVec;
|
||||
use clap::Parser;
|
||||
use log::info;
|
||||
use once_cell::sync::OnceCell;
|
||||
use std::error::Error;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::well_known::Rfc3339;
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
///Navi is configured through CLI arguments (for now) defined below.
|
||||
//TODO: use clap_serde to make it config file driven
|
||||
pub struct Args {
|
||||
#[clap(short, long, help = "gRPC port Navi runs ons")]
|
||||
pub port: i32,
|
||||
#[clap(long, default_value_t = 9000, help = "prometheus metrics port")]
|
||||
pub prometheus_port: u16,
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
default_value_t = 1,
|
||||
help = "number of worker threads for tokio async runtime"
|
||||
)]
|
||||
pub num_worker_threads: usize,
|
||||
#[clap(
|
||||
long,
|
||||
default_value_t = 14,
|
||||
help = "number of blocking threads in tokio blocking thread pool"
|
||||
)]
|
||||
pub max_blocking_threads: usize,
|
||||
#[clap(long, default_value = "16", help = "maximum batch size for a batch")]
|
||||
pub max_batch_size: Vec<String>,
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
default_value = "2",
|
||||
help = "max wait time for accumulating a batch"
|
||||
)]
|
||||
pub batch_time_out_millis: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value_t = 90,
|
||||
help = "threshold to start dropping batches under stress"
|
||||
)]
|
||||
pub batch_drop_millis: u64,
|
||||
#[clap(
|
||||
long,
|
||||
default_value_t = 300,
|
||||
help = "polling interval for new version of a model and META.json config"
|
||||
)]
|
||||
pub model_check_interval_secs: u64,
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
default_value = "models/pvideo/",
|
||||
help = "root directory for models"
|
||||
)]
|
||||
pub model_dir: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
help = "directory containing META.json config. separate from model_dir to facilitate remote config management"
|
||||
)]
|
||||
pub meta_json_dir: Option<String>,
|
||||
#[clap(short, long, default_value = "", help = "directory for ssl certs")]
|
||||
pub ssl_dir: String,
|
||||
#[clap(
|
||||
long,
|
||||
help = "call out to external process to check model updates. custom logic can be written to pull from hdfs, gcs etc"
|
||||
)]
|
||||
pub modelsync_cli: Option<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value_t = 1,
|
||||
help = "specify how many versions Navi retains in memory. good for cases of rolling model upgrade"
|
||||
)]
|
||||
pub versions_per_model: usize,
|
||||
#[clap(
|
||||
short,
|
||||
long,
|
||||
help = "most runtimes support loading ops custom writen. currently only implemented for TF"
|
||||
)]
|
||||
pub customops_lib: Option<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "8",
|
||||
help = "number of threads to paralleling computations inside an op"
|
||||
)]
|
||||
pub intra_op_parallelism: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
help = "number of threads to parallelize computations of the graph"
|
||||
)]
|
||||
pub inter_op_parallelism: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
help = "signature of a serving. only TF"
|
||||
)]
|
||||
pub serving_sig: Vec<String>,
|
||||
#[clap(long, default_value = "examples", help = "name of each input tensor")]
|
||||
pub input: Vec<String>,
|
||||
#[clap(long, default_value = "output_0", help = "name of each output tensor")]
|
||||
pub output: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value_t = 500,
|
||||
help = "max warmup records to use. warmup only implemented for TF"
|
||||
)]
|
||||
pub max_warmup_records: usize,
|
||||
#[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter=',')]
|
||||
pub onnx_global_thread_pool_options: Vec<(String, String)>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "true",
|
||||
help = "when to use graph parallelization. only for ONNX"
|
||||
)]
|
||||
pub onnx_use_parallel_mode: String,
|
||||
// #[clap(long, default_value = "false")]
|
||||
// pub onnx_use_onednn: String,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "true",
|
||||
help = "trace internal memory allocation and generate bulk memory allocations. only for ONNX. turn if off if batch size dynamic"
|
||||
)]
|
||||
pub onnx_use_memory_pattern: String,
|
||||
#[clap(long, value_parser = Args::parse_key_val::<String, String>, value_delimiter=',')]
|
||||
pub onnx_ep_options: Vec<(String, String)>,
|
||||
#[clap(long, help = "choice of gpu EPs for ONNX: cuda or tensorrt")]
|
||||
pub onnx_gpu_ep: Option<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "home",
|
||||
help = "converter for various input formats"
|
||||
)]
|
||||
pub onnx_use_converter: Option<String>,
|
||||
#[clap(
|
||||
long,
|
||||
help = "whether to enable runtime profiling. only implemented for ONNX for now"
|
||||
)]
|
||||
pub profiling: Option<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "",
|
||||
help = "metrics reporting for discrete features. only for Home converter for now"
|
||||
)]
|
||||
pub onnx_report_discrete_feature_ids: Vec<String>,
|
||||
#[clap(
|
||||
long,
|
||||
default_value = "",
|
||||
help = "metrics reporting for continuous features. only for Home converter for now"
|
||||
)]
|
||||
pub onnx_report_continuous_feature_ids: Vec<String>,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
pub fn get_model_specs(model_dir: Vec<String>) -> Vec<String> {
|
||||
let model_specs = model_dir
|
||||
.iter()
|
||||
//let it panic if some model_dir entries are wrong
|
||||
.map(|dir| {
|
||||
dir.trim_end_matches('/')
|
||||
.rsplit_once('/')
|
||||
.unwrap()
|
||||
.1
|
||||
.to_owned()
|
||||
})
|
||||
.collect();
|
||||
info!("all model_specs: {:?}", model_specs);
|
||||
model_specs
|
||||
}
|
||||
pub fn version_str_to_epoch(dt_str: &str) -> Result<i64, anyhow::Error> {
|
||||
dt_str
|
||||
.parse::<i64>()
|
||||
.or_else(|_| {
|
||||
let ts = OffsetDateTime::parse(dt_str, &Rfc3339)
|
||||
.map(|d| (d.unix_timestamp_nanos()/1_000_000) as i64);
|
||||
if ts.is_ok() {
|
||||
info!("original version {} -> {}", dt_str, ts.unwrap());
|
||||
}
|
||||
ts
|
||||
})
|
||||
.map_err(anyhow::Error::msg)
|
||||
}
|
||||
/// Parse a single key-value pair
|
||||
fn parse_key_val<T, U>(s: &str) -> Result<(T, U), Box<dyn Error + Send + Sync + 'static>>
|
||||
where
|
||||
T: std::str::FromStr,
|
||||
T::Err: Error + Send + Sync + 'static,
|
||||
U: std::str::FromStr,
|
||||
U::Err: Error + Send + Sync + 'static,
|
||||
{
|
||||
let pos = s
|
||||
.find('=')
|
||||
.ok_or_else(|| format!("invalid KEY=value: no `=` found in `{}`", s))?;
|
||||
Ok((s[..pos].parse()?, s[pos + 1..].parse()?))
|
||||
}
|
||||
}
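For illustration, a standalone sketch of the key=value parsing above, simplified to plain String keys and values (the option names in the comment are only examples of what a comma-delimited flag might carry):

// Simplified version of Args::parse_key_val, splitting at the first '='.
fn parse_key_val(s: &str) -> Result<(String, String), String> {
    let pos = s
        .find('=')
        .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{}`", s))?;
    Ok((s[..pos].to_string(), s[pos + 1..].to_string()))
}

fn main() {
    // e.g. a flag value like "device_id=0,do_copy_in_default_stream=1" is first split
    // on ',' by clap (value_delimiter), then each piece is parsed here.
    for pair in "device_id=0,do_copy_in_default_stream=1".split(',') {
        println!("{:?}", parse_key_val(pair).unwrap());
    }
}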
|
||||
|
||||
lazy_static! {
|
||||
pub static ref ARGS: Args = Args::parse();
|
||||
pub static ref MODEL_SPECS: ArrayVec<String, MAX_NUM_MODELS> = {
|
||||
let mut specs = ArrayVec::<String, MAX_NUM_MODELS>::new();
|
||||
Args::get_model_specs(ARGS.model_dir.clone())
|
||||
.into_iter()
|
||||
.for_each(|m| specs.push(m));
|
||||
specs
|
||||
};
|
||||
pub static ref INPUTS: ArrayVec<OnceCell<ArrayVec<String, MAX_NUM_INPUTS>>, MAX_NUM_MODELS> = {
|
||||
let mut inputs =
|
||||
ArrayVec::<OnceCell<ArrayVec<String, MAX_NUM_INPUTS>>, MAX_NUM_MODELS>::new();
|
||||
for (idx, o) in ARGS.input.iter().enumerate() {
|
||||
if o.trim().is_empty() {
|
||||
info!("input spec is empty for model {}, auto detect later", idx);
|
||||
inputs.push(OnceCell::new());
|
||||
} else {
|
||||
inputs.push(OnceCell::with_value(
|
||||
o.split(",")
|
||||
.map(|s| s.to_owned())
|
||||
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>(),
|
||||
));
|
||||
}
|
||||
}
|
||||
info!("all inputs:{:?}", inputs);
|
||||
inputs
|
||||
};
|
||||
pub static ref OUTPUTS: ArrayVec<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS> = {
|
||||
let mut outputs = ArrayVec::<ArrayVec<String, MAX_NUM_OUTPUTS>, MAX_NUM_MODELS>::new();
|
||||
for o in ARGS.output.iter() {
|
||||
outputs.push(
|
||||
o.split(",")
|
||||
.map(|s| s.to_owned())
|
||||
.collect::<ArrayVec<String, MAX_NUM_OUTPUTS>>(),
|
||||
);
|
||||
}
|
||||
info!("all outputs:{:?}", outputs);
|
||||
outputs
|
||||
};
|
||||
}
|
Binary file not shown.
@@ -1,22 +0,0 @@
pub mod validatior {
    pub mod cli_validator {
        use crate::cli_args::{ARGS, MODEL_SPECS};

        pub fn validate_input_args() {
            assert_eq!(MODEL_SPECS.len(), ARGS.inter_op_parallelism.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.intra_op_parallelism.len());
            //TODO: for now, we assume each model's output has only 1 tensor.
            //this will be lifted once tf_model properly implements mtl outputs
            //assert_eq!(OUTPUTS.len(), OUTPUTS.iter().fold(0usize, |a, b| a+b.len()));
        }

        pub fn validate_ps_model_args() {
            assert!(ARGS.versions_per_model <= 2);
            assert!(ARGS.versions_per_model >= 1);
            assert_eq!(MODEL_SPECS.len(), ARGS.input.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.model_dir.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.max_batch_size.len());
            assert_eq!(MODEL_SPECS.len(), ARGS.batch_time_out_millis.len());
        }
    }
}
Binary file not shown.
|
@@ -1,215 +0,0 @@
|
|||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
extern crate core;
|
||||
|
||||
use serde_json::Value;
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::Instant;
|
||||
use std::ops::Deref;
|
||||
use itertools::Itertools;
|
||||
use crate::bootstrap::TensorInput;
|
||||
use crate::predict_service::Model;
|
||||
use crate::tf_proto::{DataType, TensorProto};
|
||||
|
||||
pub mod batch;
|
||||
pub mod bootstrap;
|
||||
pub mod cli_args;
|
||||
pub mod metrics;
|
||||
pub mod onnx_model;
|
||||
pub mod predict_service;
|
||||
pub mod tf_model;
|
||||
pub mod torch_model;
|
||||
pub mod cores {
|
||||
pub mod validator;
|
||||
}
|
||||
|
||||
pub mod tf_proto {
|
||||
tonic::include_proto!("tensorflow");
|
||||
pub mod tensorflow_serving {
|
||||
tonic::include_proto!("tensorflow.serving");
|
||||
}
|
||||
}
|
||||
|
||||
pub mod kf_serving {
|
||||
tonic::include_proto!("inference");
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::cli_args::Args;
|
||||
#[test]
|
||||
fn test_version_string_to_epoch() {
|
||||
assert_eq!(
|
||||
Args::version_str_to_epoch("2022-12-20T10:18:53.000Z").unwrap_or(-1),
|
||||
1671531533000
|
||||
);
|
||||
assert_eq!(Args::version_str_to_epoch("1203444").unwrap_or(-1), 1203444);
|
||||
}
|
||||
}
|
||||
|
||||
mod utils {
|
||||
use crate::cli_args::{ARGS, MODEL_SPECS};
|
||||
use anyhow::Result;
|
||||
use log::info;
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn read_config(meta_file: &String) -> Result<Value> {
|
||||
let json = std::fs::read_to_string(meta_file)?;
|
||||
let v: Value = serde_json::from_str(&json)?;
|
||||
Ok(v)
|
||||
}
|
||||
pub fn get_config_or_else<F>(model_config: &Value, key: &str, default: F) -> String
|
||||
where
|
||||
F: FnOnce() -> String,
|
||||
{
|
||||
match model_config[key] {
|
||||
Value::String(ref v) => {
|
||||
info!("from model_config: {}={}", key, v);
|
||||
v.to_string()
|
||||
}
|
||||
Value::Number(ref num) => {
|
||||
info!(
|
||||
"from model_config: {}={} (turn number into a string)",
|
||||
key, num
|
||||
);
|
||||
num.to_string()
|
||||
}
|
||||
_ => {
|
||||
let d = default();
|
||||
info!("from default: {}={}", key, d);
|
||||
d
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn get_config_or(model_config: &Value, key: &str, default: &str) -> String {
|
||||
get_config_or_else(model_config, key, || default.to_string())
|
||||
}
|
||||
pub fn get_meta_dir() -> &'static str {
|
||||
ARGS.meta_json_dir
|
||||
.as_ref()
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or_else(|| {
|
||||
let model_dir = &ARGS.model_dir[0];
|
||||
let meta_dir = &model_dir[0..model_dir.rfind(&MODEL_SPECS[0]).unwrap()];
|
||||
info!(
|
||||
"no meta_json_dir specified, hence derive from first model dir:{}->{}",
|
||||
model_dir, meta_dir
|
||||
);
|
||||
meta_dir
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub type SerializedInput = Vec<u8>;
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
pub const NAME: &str = env!("CARGO_PKG_NAME");
|
||||
pub type ModelFactory<T> = fn(usize, String, &Value) -> anyhow::Result<T>;
|
||||
pub const MAX_NUM_MODELS: usize = 16;
|
||||
pub const MAX_NUM_OUTPUTS: usize = 30;
|
||||
pub const MAX_NUM_INPUTS: usize = 120;
|
||||
pub const META_INFO: &str = "META.json";
|
||||
|
||||
//use a heap allocated generic type here so that both
|
||||
//Tensorflow & Pytorch implementation can return their Tensor wrapped in a Box
|
||||
//without an extra memcopy to Vec
|
||||
pub type TensorReturn<T> = Box<dyn Deref<Target = [T]>>;
|
||||
|
||||
//returned tensor may be int64 i.e., a list of relevant ad ids
|
||||
pub enum TensorReturnEnum {
|
||||
FloatTensorReturn(TensorReturn<f32>),
|
||||
StringTensorReturn(TensorReturn<String>),
|
||||
Int64TensorReturn(TensorReturn<i64>),
|
||||
Int32TensorReturn(TensorReturn<i32>),
|
||||
}
|
||||
|
||||
impl TensorReturnEnum {
|
||||
#[inline(always)]
|
||||
pub fn slice(&self, start: usize, end: usize) -> TensorScores {
|
||||
match self {
|
||||
TensorReturnEnum::FloatTensorReturn(f32_return) => {
|
||||
TensorScores::Float32TensorScores(f32_return[start..end].to_vec())
|
||||
}
|
||||
TensorReturnEnum::Int64TensorReturn(i64_return) => {
|
||||
TensorScores::Int64TensorScores(i64_return[start..end].to_vec())
|
||||
}
|
||||
TensorReturnEnum::Int32TensorReturn(i32_return) => {
|
||||
TensorScores::Int32TensorScores(i32_return[start..end].to_vec())
|
||||
}
|
||||
TensorReturnEnum::StringTensorReturn(str_return) => {
|
||||
TensorScores::StringTensorScores(str_return[start..end].to_vec())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PredictResult {
|
||||
Ok(Vec<TensorScores>, i64),
|
||||
DropDueToOverload,
|
||||
ModelNotFound(usize),
|
||||
ModelNotReady(usize),
|
||||
ModelVersionNotFound(usize, i64),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TensorScores {
|
||||
Float32TensorScores(Vec<f32>),
|
||||
Int64TensorScores(Vec<i64>),
|
||||
Int32TensorScores(Vec<i32>),
|
||||
StringTensorScores(Vec<String>),
|
||||
}
|
||||
|
||||
impl TensorScores {
|
||||
pub fn create_tensor_proto(self) -> TensorProto {
|
||||
match self {
|
||||
TensorScores::Float32TensorScores(f32_tensor) => TensorProto {
|
||||
dtype: DataType::DtFloat as i32,
|
||||
float_val: f32_tensor,
|
||||
..Default::default()
|
||||
},
|
||||
TensorScores::Int64TensorScores(i64_tensor) => TensorProto {
|
||||
dtype: DataType::DtInt64 as i32,
|
||||
int64_val: i64_tensor,
|
||||
..Default::default()
|
||||
},
|
||||
TensorScores::Int32TensorScores(i32_tensor) => TensorProto {
|
||||
dtype: DataType::DtInt32 as i32,
|
||||
int_val: i32_tensor,
|
||||
..Default::default()
|
||||
},
|
||||
TensorScores::StringTensorScores(str_tensor) => TensorProto {
|
||||
dtype: DataType::DtString as i32,
|
||||
string_val: str_tensor.into_iter().map(|s| s.into_bytes()).collect_vec(),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
pub fn len(&self) -> usize {
|
||||
match &self {
|
||||
TensorScores::Float32TensorScores(t) => t.len(),
|
||||
TensorScores::Int64TensorScores(t) => t.len(),
|
||||
TensorScores::Int32TensorScores(t) => t.len(),
|
||||
TensorScores::StringTensorScores(t) => t.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PredictMessage<T: Model> {
|
||||
Predict(
|
||||
usize,
|
||||
Option<i64>,
|
||||
Vec<TensorInput>,
|
||||
Sender<PredictResult>,
|
||||
Instant,
|
||||
),
|
||||
UpsertModel(T),
|
||||
/*
|
||||
#[allow(dead_code)]
|
||||
DeleteModel(usize),
|
||||
*/
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Callback(Sender<PredictResult>, usize);
|
||||
|
||||
pub const MAX_VERSIONS_PER_MODEL: usize = 2;
|
Binary file not shown.
|
@@ -1,297 +0,0 @@
|
|||
use log::error;
|
||||
use prometheus::{
|
||||
CounterVec, HistogramOpts, HistogramVec, IntCounter, IntCounterVec, IntGauge, IntGaugeVec,
|
||||
Opts, Registry,
|
||||
};
|
||||
use warp::{Rejection, Reply};
|
||||
use crate::{NAME, VERSION};
|
||||
|
||||
lazy_static! {
|
||||
pub static ref REGISTRY: Registry = Registry::new();
|
||||
pub static ref NUM_REQUESTS_RECEIVED: IntCounter =
|
||||
IntCounter::new(":navi:num_requests", "Number of Requests Received")
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_REQUESTS_FAILED: IntCounter = IntCounter::new(
|
||||
":navi:num_requests_failed",
|
||||
"Number of Request Inference Failed"
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_REQUESTS_DROPPED: IntCounter = IntCounter::new(
|
||||
":navi:num_requests_dropped",
|
||||
"Number of Oneshot Receivers Dropped"
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_BATCHES_DROPPED: IntCounter = IntCounter::new(
|
||||
":navi:num_batches_dropped",
|
||||
"Number of Batches Proactively Dropped"
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_BATCH_PREDICTION: IntCounter =
|
||||
IntCounter::new(":navi:num_batch_prediction", "Number of batch prediction")
|
||||
.expect("metric can be created");
|
||||
pub static ref BATCH_SIZE: IntGauge =
|
||||
IntGauge::new(":navi:batch_size", "Size of current batch").expect("metric can be created");
|
||||
pub static ref NAVI_VERSION: IntGauge =
|
||||
IntGauge::new(":navi:navi_version", "navi's current version")
|
||||
.expect("metric can be created");
|
||||
pub static ref RESPONSE_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
|
||||
HistogramOpts::new(":navi:response_time", "Response Time in ms").buckets(Vec::from(&[
|
||||
0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0, 130.0,
|
||||
140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, 1000.0
|
||||
]
|
||||
as &'static [f64])),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_PREDICTIONS: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:num_predictions",
|
||||
"Number of predictions made by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref PREDICTION_SCORE_SUM: CounterVec = CounterVec::new(
|
||||
Opts::new(
|
||||
":navi:prediction_score_sum",
|
||||
"Sum of prediction score made by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NEW_MODEL_SNAPSHOT: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:new_model_snapshot",
|
||||
"Load a new version of model snapshot"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref MODEL_SNAPSHOT_VERSION: IntGaugeVec = IntGaugeVec::new(
|
||||
Opts::new(
|
||||
":navi:model_snapshot_version",
|
||||
"Record model snapshot version"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_REQUESTS_RECEIVED_BY_MODEL: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:num_requests_by_model",
|
||||
"Number of Requests Received by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_REQUESTS_FAILED_BY_MODEL: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:num_requests_failed_by_model",
|
||||
"Number of Request Inference Failed by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_REQUESTS_DROPPED_BY_MODEL: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:num_requests_dropped_by_model",
|
||||
"Number of Oneshot Receivers Dropped by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_BATCHES_DROPPED_BY_MODEL: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:num_batches_dropped_by_model",
|
||||
"Number of Batches Proactively Dropped by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref INFERENCE_FAILED_REQUESTS_BY_MODEL: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:inference_failed_requests_by_model",
|
||||
"Number of failed inference requests by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_PREDICTION_BY_MODEL: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:num_prediction_by_model",
|
||||
"Number of prediction by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref NUM_BATCH_PREDICTION_BY_MODEL: IntCounterVec = IntCounterVec::new(
|
||||
Opts::new(
|
||||
":navi:num_batch_prediction_by_model",
|
||||
"Number of batch prediction by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref BATCH_SIZE_BY_MODEL: IntGaugeVec = IntGaugeVec::new(
|
||||
Opts::new(
|
||||
":navi:batch_size_by_model",
|
||||
"Size of current batch by model"
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref CUSTOMOP_VERSION: IntGauge =
|
||||
IntGauge::new(":navi:customop_version", "The hashed Custom OP Version")
|
||||
.expect("metric can be created");
|
||||
pub static ref MPSC_CHANNEL_SIZE: IntGauge =
|
||||
IntGauge::new(":navi:mpsc_channel_size", "The mpsc channel's request size")
|
||||
.expect("metric can be created");
|
||||
pub static ref BLOCKING_REQUEST_NUM: IntGauge = IntGauge::new(
|
||||
":navi:blocking_request_num",
|
||||
"The (batch) request waiting or being executed"
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref MODEL_INFERENCE_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
|
||||
HistogramOpts::new(":navi:model_inference_time", "Model inference time in ms").buckets(
|
||||
Vec::from(&[
|
||||
0.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0, 65.0,
|
||||
70.0, 75.0, 80.0, 85.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0, 150.0, 160.0,
|
||||
170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0, 1000.0
|
||||
] as &'static [f64])
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref CONVERTER_TIME_COLLECTOR: HistogramVec = HistogramVec::new(
|
||||
HistogramOpts::new(":navi:converter_time", "converter time in microseconds").buckets(
|
||||
Vec::from(&[
|
||||
0.0, 500.0, 1000.0, 1500.0, 2000.0, 2500.0, 3000.0, 3500.0, 4000.0, 4500.0, 5000.0,
|
||||
5500.0, 6000.0, 6500.0, 7000.0, 20000.0
|
||||
] as &'static [f64])
|
||||
),
|
||||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref CERT_EXPIRY_EPOCH: IntGauge =
|
||||
IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
|
||||
.expect("metric can be created");
|
||||
}
|
||||
|
||||
pub fn register_custom_metrics() {
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_REQUESTS_RECEIVED.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_REQUESTS_FAILED.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_REQUESTS_DROPPED.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(RESPONSE_TIME_COLLECTOR.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NAVI_VERSION.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(BATCH_SIZE.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_BATCH_PREDICTION.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_BATCHES_DROPPED.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_PREDICTIONS.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(PREDICTION_SCORE_SUM.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NEW_MODEL_SNAPSHOT.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(MODEL_SNAPSHOT_VERSION.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_REQUESTS_RECEIVED_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_REQUESTS_FAILED_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_REQUESTS_DROPPED_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_BATCHES_DROPPED_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(INFERENCE_FAILED_REQUESTS_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_PREDICTION_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(NUM_BATCH_PREDICTION_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(BATCH_SIZE_BY_MODEL.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(CUSTOMOP_VERSION.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(MPSC_CHANNEL_SIZE.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(BLOCKING_REQUEST_NUM.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(MODEL_INFERENCE_TIME_COLLECTOR.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(CERT_EXPIRY_EPOCH.clone()))
|
||||
.expect("collector can be registered");
|
||||
|
||||
}
|
||||
|
||||
pub fn register_dynamic_metrics(c: &HistogramVec) {
|
||||
REGISTRY
|
||||
.register(Box::new(c.clone()))
|
||||
.expect("dynamic metric collector cannot be registered");
|
||||
}
|
||||
|
||||
pub async fn metrics_handler() -> Result<impl Reply, Rejection> {
|
||||
use prometheus::Encoder;
|
||||
let encoder = prometheus::TextEncoder::new();
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
if let Err(e) = encoder.encode(®ISTRY.gather(), &mut buffer) {
|
||||
error!("could not encode custom metrics: {}", e);
|
||||
};
|
||||
let mut res = match String::from_utf8(buffer) {
|
||||
Ok(v) => format!("#{}:{}\n{}", NAME, VERSION, v),
|
||||
Err(e) => {
|
||||
error!("custom metrics could not be from_utf8'd: {}", e);
|
||||
String::default()
|
||||
}
|
||||
};
|
||||
|
||||
buffer = Vec::new();
|
||||
if let Err(e) = encoder.encode(&prometheus::gather(), &mut buffer) {
|
||||
error!("could not encode prometheus metrics: {}", e);
|
||||
};
|
||||
let res_custom = match String::from_utf8(buffer) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("prometheus metrics could not be from_utf8'd: {}", e);
|
||||
String::default()
|
||||
}
|
||||
};
|
||||
|
||||
res.push_str(&res_custom);
|
||||
Ok(res)
|
||||
}
|
Binary file not shown.
|
@@ -1,275 +0,0 @@
|
|||
#[cfg(feature = "onnx")]
|
||||
pub mod onnx {
|
||||
use crate::TensorReturnEnum;
|
||||
use crate::bootstrap::{TensorInput, TensorInputEnum};
|
||||
use crate::cli_args::{
|
||||
Args, ARGS, INPUTS, MODEL_SPECS, OUTPUTS,
|
||||
};
|
||||
use crate::metrics::{self, CONVERTER_TIME_COLLECTOR};
|
||||
use crate::predict_service::Model;
|
||||
use crate::{MAX_NUM_INPUTS, MAX_NUM_OUTPUTS, META_INFO, utils};
|
||||
use anyhow::Result;
|
||||
use arrayvec::ArrayVec;
|
||||
use dr_transform::converter::{BatchPredictionRequestToTorchTensorConverter, Converter};
|
||||
use itertools::Itertools;
|
||||
use log::{debug, info};
|
||||
use dr_transform::ort::environment::Environment;
|
||||
use dr_transform::ort::session::Session;
|
||||
use dr_transform::ort::tensor::InputTensor;
|
||||
use dr_transform::ort::{ExecutionProvider, GraphOptimizationLevel, SessionBuilder};
|
||||
use dr_transform::ort::LoggingLevel;
|
||||
use serde_json::Value;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, fs};
|
||||
use tokio::time::Instant;
|
||||
lazy_static! {
|
||||
pub static ref ENVIRONMENT: Arc<Environment> = Arc::new(
|
||||
Environment::builder()
|
||||
.with_name("onnx home")
|
||||
.with_log_level(LoggingLevel::Error)
|
||||
.with_global_thread_pool(ARGS.onnx_global_thread_pool_options.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct OnnxModel {
|
||||
pub session: Session,
|
||||
pub model_idx: usize,
|
||||
pub version: i64,
|
||||
pub export_dir: String,
|
||||
pub output_filters: ArrayVec<usize, MAX_NUM_OUTPUTS>,
|
||||
pub input_converter: Box<dyn Converter>,
|
||||
}
|
||||
impl Display for OnnxModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"idx: {}, onnx model_name:{}, version:{}, output_filters:{:?}, converter:{:}",
|
||||
self.model_idx,
|
||||
MODEL_SPECS[self.model_idx],
|
||||
self.version,
|
||||
self.output_filters,
|
||||
self.input_converter
|
||||
)
|
||||
}
|
||||
}
|
||||
impl Drop for OnnxModel {
|
||||
fn drop(&mut self) {
|
||||
if ARGS.profiling != None {
|
||||
self.session.end_profiling().map_or_else(
|
||||
|e| {
|
||||
info!("end profiling with some error:{:?}", e);
|
||||
},
|
||||
|f| {
|
||||
info!("profiling ended with file:{}", f);
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
impl OnnxModel {
|
||||
fn get_output_filters(session: &Session, idx: usize) -> ArrayVec<usize, MAX_NUM_OUTPUTS> {
|
||||
OUTPUTS[idx]
|
||||
.iter()
|
||||
.map(|output| session.outputs.iter().position(|o| o.name == *output))
|
||||
.flatten()
|
||||
.collect::<ArrayVec<usize, MAX_NUM_OUTPUTS>>()
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
fn ep_choices() -> Vec<ExecutionProvider> {
|
||||
match ARGS.onnx_gpu_ep.as_ref().map(|e| e.as_str()) {
|
||||
Some("onednn") => vec![Self::ep_with_options(ExecutionProvider::onednn())],
|
||||
Some("tensorrt") => vec![Self::ep_with_options(ExecutionProvider::tensorrt())],
|
||||
Some("cuda") => vec![Self::ep_with_options(ExecutionProvider::cuda())],
|
||||
_ => vec![Self::ep_with_options(ExecutionProvider::cpu())],
|
||||
}
|
||||
}
|
||||
fn ep_with_options(mut ep: ExecutionProvider) -> ExecutionProvider {
|
||||
for (ref k, ref v) in ARGS.onnx_ep_options.clone() {
|
||||
ep = ep.with(k, v);
|
||||
info!("setting option:{} -> {} and now ep is:{:?}", k, v, ep);
|
||||
}
|
||||
ep
|
||||
}
|
||||
#[cfg(target_os = "macos")]
|
||||
fn ep_choices() -> Vec<ExecutionProvider> {
|
||||
vec![Self::ep_with_options(ExecutionProvider::cpu())]
|
||||
}
|
||||
pub fn new(idx: usize, version: String, model_config: &Value) -> Result<OnnxModel> {
|
||||
let export_dir = format!("{}/{}/model.onnx", ARGS.model_dir[idx], version);
|
||||
let meta_info = format!("{}/{}/{}", ARGS.model_dir[idx], version, META_INFO);
|
||||
let mut builder = SessionBuilder::new(&ENVIRONMENT)?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_parallel_execution(ARGS.onnx_use_parallel_mode == "true")?;
|
||||
if ARGS.onnx_global_thread_pool_options.is_empty() {
|
||||
builder = builder
|
||||
.with_inter_threads(
|
||||
utils::get_config_or(
|
||||
model_config,
|
||||
"inter_op_parallelism",
|
||||
&ARGS.inter_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
)?
|
||||
.with_intra_threads(
|
||||
utils::get_config_or(
|
||||
model_config,
|
||||
"intra_op_parallelism",
|
||||
&ARGS.intra_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
)?;
|
||||
}
|
||||
else {
|
||||
builder = builder.with_disable_per_session_threads()?;
|
||||
}
|
||||
builder = builder
|
||||
.with_memory_pattern(ARGS.onnx_use_memory_pattern == "true")?
|
||||
.with_execution_providers(&OnnxModel::ep_choices())?;
|
||||
match &ARGS.profiling {
|
||||
Some(p) => {
|
||||
debug!("Enable profiling, writing to {}", *p);
|
||||
builder = builder.with_profiling(p)?
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let session = builder.with_model_from_file(&export_dir)?;
|
||||
|
||||
info!(
|
||||
"inputs: {:?}, outputs: {:?}",
|
||||
session.inputs.iter().format(","),
|
||||
session.outputs.iter().format(",")
|
||||
);
|
||||
|
||||
fs::read_to_string(&meta_info)
|
||||
.ok()
|
||||
.map(|info| info!("meta info:{}", info));
|
||||
let output_filters = OnnxModel::get_output_filters(&session, idx);
|
||||
let mut reporting_feature_ids: Vec<(i64, &str)> = vec![];
|
||||
|
||||
let input_spec_cell = &INPUTS[idx];
|
||||
if input_spec_cell.get().is_none() {
|
||||
let input_spec = session
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| input.name.clone())
|
||||
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>();
|
||||
input_spec_cell.set(input_spec.clone()).map_or_else(
|
||||
|_| info!("unable to set the input_spec for model {}", idx),
|
||||
|_| info!("auto detect and set the inputs: {:?}", input_spec),
|
||||
);
|
||||
}
|
||||
ARGS.onnx_report_discrete_feature_ids
|
||||
.iter()
|
||||
.for_each(|ids| {
|
||||
ids.split(",")
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.parse::<i64>().unwrap())
|
||||
.for_each(|id| reporting_feature_ids.push((id, "discrete")))
|
||||
});
|
||||
ARGS.onnx_report_continuous_feature_ids
|
||||
.iter()
|
||||
.for_each(|ids| {
|
||||
ids.split(",")
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.parse::<i64>().unwrap())
|
||||
.for_each(|id| reporting_feature_ids.push((id, "continuous")))
|
||||
});
|
||||
|
||||
let onnx_model = OnnxModel {
|
||||
session,
|
||||
model_idx: idx,
|
||||
version: Args::version_str_to_epoch(&version)?,
|
||||
export_dir,
|
||||
output_filters,
|
||||
input_converter: Box::new(BatchPredictionRequestToTorchTensorConverter::new(
|
||||
&ARGS.model_dir[idx],
|
||||
&version,
|
||||
reporting_feature_ids,
|
||||
Some(metrics::register_dynamic_metrics),
|
||||
)?),
|
||||
};
|
||||
onnx_model.warmup()?;
|
||||
Ok(onnx_model)
|
||||
}
|
||||
}
|
||||
///Currently we assume the input is just one string tensor.
///The string tensor will be converted to the actual raw tensors.
/// The converter we are using is very specific to home.
/// It reads a BatchDataRecord thrift and decodes it to a batch of raw input tensors.
/// Navi will then do server side batching and feed it to ONNX runtime
|
||||
impl Model for OnnxModel {
|
||||
//TODO: implement a generic online warmup for all runtimes
|
||||
fn warmup(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn do_predict(
|
||||
&self,
|
||||
input_tensors: Vec<Vec<TensorInput>>,
|
||||
_: u64,
|
||||
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
|
||||
let batched_tensors = TensorInputEnum::merge_batch(input_tensors);
|
||||
let (inputs, batch_ends): (Vec<Vec<InputTensor>>, Vec<Vec<usize>>) = batched_tensors
|
||||
.into_iter()
|
||||
.map(|batched_tensor| {
|
||||
match batched_tensor.tensor_data {
|
||||
TensorInputEnum::String(t) if ARGS.onnx_use_converter.is_some() => {
|
||||
let start = Instant::now();
|
||||
let (inputs, batch_ends) = self.input_converter.convert(t);
|
||||
// info!("batch_ends:{:?}", batch_ends);
|
||||
CONVERTER_TIME_COLLECTOR
|
||||
.with_label_values(&[&MODEL_SPECS[self.model_idx()]])
|
||||
.observe(
|
||||
start.elapsed().as_micros() as f64
|
||||
/ (*batch_ends.last().unwrap() as f64),
|
||||
);
|
||||
(inputs, batch_ends)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
})
|
||||
.unzip();
|
||||
//invariant we only support one input as string. will relax later
|
||||
assert_eq!(inputs.len(), 1);
|
||||
let output_tensors = self
|
||||
.session
|
||||
.run(inputs.into_iter().flatten().collect::<Vec<_>>())
|
||||
.unwrap();
|
||||
self.output_filters
|
||||
.iter()
|
||||
.map(|&idx| {
|
||||
let mut size = 1usize;
|
||||
let output = output_tensors[idx].try_extract::<f32>().unwrap();
|
||||
for &dim in self.session.outputs[idx].dimensions.iter().flatten() {
|
||||
size *= dim as usize;
|
||||
}
|
||||
let tensor_ends = batch_ends[0]
|
||||
.iter()
|
||||
.map(|&batch| batch * size)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
(
|
||||
//only works for batch major
|
||||
//TODO: to_vec() obviously wasteful, especially for large batches(GPU) . Will refactor to
|
||||
//break up output and return Vec<Vec<TensorScore>> here
|
||||
TensorReturnEnum::FloatTensorReturn(Box::new(output.view().as_slice().unwrap().to_vec(),
|
||||
)),
|
||||
tensor_ends,
|
||||
)
|
||||
})
|
||||
.unzip()
|
||||
}
|
||||
#[inline(always)]
|
||||
fn model_idx(&self) -> usize {
|
||||
self.model_idx
|
||||
}
|
||||
#[inline(always)]
|
||||
fn version(&self) -> i64 {
|
||||
self.version
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
|
@@ -1,315 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use arrayvec::ArrayVec;
|
||||
use itertools::Itertools;
|
||||
use log::{error, info};
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::string::String;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::{Instant, sleep};
|
||||
use warp::Filter;
|
||||
|
||||
use crate::batch::BatchPredictor;
|
||||
use crate::bootstrap::TensorInput;
|
||||
use crate::{MAX_NUM_MODELS, MAX_VERSIONS_PER_MODEL, META_INFO, metrics, ModelFactory, PredictMessage, PredictResult, TensorReturnEnum, utils};
|
||||
|
||||
use crate::cli_args::{ARGS, MODEL_SPECS};
|
||||
use crate::cores::validator::validatior::cli_validator;
|
||||
use crate::metrics::MPSC_CHANNEL_SIZE;
|
||||
use serde_json::{self, Value};
|
||||
|
||||
pub trait Model: Send + Sync + Display + Debug + 'static {
|
||||
fn warmup(&self) -> Result<()>;
|
||||
//TODO: refactor this to return vec<vec<TensorScores>>, i.e.
|
||||
//we have the underlying runtime impl to split the response to each client.
|
||||
//It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
|
||||
fn do_predict(
|
||||
&self,
|
||||
input_tensors: Vec<Vec<TensorInput>>,
|
||||
total_len: u64,
|
||||
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>);
|
||||
fn model_idx(&self) -> usize;
|
||||
fn version(&self) -> i64;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PredictService<T: Model> {
|
||||
tx: Sender<PredictMessage<T>>,
|
||||
}
|
||||
impl<T: Model> PredictService<T> {
|
||||
pub async fn init(model_factory: ModelFactory<T>) -> Self {
|
||||
cli_validator::validate_ps_model_args();
|
||||
let (tx, rx) = mpsc::channel(32_000);
|
||||
tokio::spawn(PredictService::tf_queue_manager(rx));
|
||||
tokio::spawn(PredictService::model_watcher_latest(
|
||||
model_factory,
|
||||
tx.clone(),
|
||||
));
|
||||
let metrics_route = warp::path!("metrics").and_then(metrics::metrics_handler);
|
||||
let metric_server = warp::serve(metrics_route).run(([0, 0, 0, 0], ARGS.prometheus_port));
|
||||
tokio::spawn(metric_server);
|
||||
PredictService { tx }
|
||||
}
|
||||
#[inline(always)]
|
||||
pub async fn predict(
|
||||
&self,
|
||||
idx: usize,
|
||||
version: Option<i64>,
|
||||
val: Vec<TensorInput>,
|
||||
ts: Instant,
|
||||
) -> Result<PredictResult> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = self
|
||||
.tx
|
||||
.clone()
|
||||
.send(PredictMessage::Predict(idx, version, val, tx, ts))
|
||||
.await
|
||||
{
|
||||
error!("mpsc send error:{}", e);
|
||||
Err(anyhow!(e))
|
||||
} else {
|
||||
MPSC_CHANNEL_SIZE.inc();
|
||||
rx.await.map_err(anyhow::Error::msg)
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_latest_model_from_model_dir(
|
||||
model_factory: ModelFactory<T>,
|
||||
model_config: &Value,
|
||||
tx: Sender<PredictMessage<T>>,
|
||||
idx: usize,
|
||||
max_version: String,
|
||||
latest_version: &mut String,
|
||||
) {
|
||||
match model_factory(idx, max_version.clone(), model_config) {
|
||||
Ok(tf_model) => tx
|
||||
.send(PredictMessage::UpsertModel(tf_model))
|
||||
.await
|
||||
.map_or_else(
|
||||
|e| error!("send UpsertModel error: {}", e),
|
||||
|_| *latest_version = max_version,
|
||||
),
|
||||
Err(e) => {
|
||||
error!("skip loading model due to failure: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn scan_load_latest_model_from_model_dir(
|
||||
model_factory: ModelFactory<T>,
|
||||
model_config: &Value,
|
||||
tx: Sender<PredictMessage<T>>,
|
||||
model_idx: usize,
|
||||
cur_version: &mut String,
|
||||
) -> Result<()> {
|
||||
let model_dir = &ARGS.model_dir[model_idx];
|
||||
let next_version = utils::get_config_or_else(model_config, "version", || {
|
||||
info!("no version found, hence use max version");
|
||||
std::fs::read_dir(model_dir)
|
||||
.map_err(|e| format!("read dir error:{}", e))
|
||||
.and_then(|paths| {
|
||||
paths
|
||||
.into_iter()
|
||||
.flat_map(|p| {
|
||||
p.map_err(|e| error!("dir entry error: {}", e))
|
||||
.and_then(|dir| {
|
||||
dir.file_name()
|
||||
.into_string()
|
||||
.map_err(|e| error!("osstring error: {:?}", e))
|
||||
})
|
||||
.ok()
|
||||
})
|
||||
.filter(|f| !f.to_lowercase().contains(&META_INFO.to_lowercase()))
|
||||
.max()
|
||||
.ok_or_else(|| "no dir found hence no max".to_owned())
|
||||
})
|
||||
.unwrap_or_else(|e| {
|
||||
error!(
|
||||
"can't get the max version hence return cur_version, error is: {}",
|
||||
e
|
||||
);
|
||||
cur_version.to_string()
|
||||
})
|
||||
});
|
||||
//as long as next version doesn't match cur version maintained we reload
|
||||
if next_version.ne(cur_version) {
|
||||
info!("reload the version: {}->{}", cur_version, next_version);
|
||||
PredictService::load_latest_model_from_model_dir(
|
||||
model_factory,
|
||||
model_config,
|
||||
tx,
|
||||
model_idx,
|
||||
next_version,
|
||||
cur_version,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Ok(())
|
||||
}
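// The version-selection logic above boils down to: list the model dir, skip the
// metadata entry, and take the lexicographically largest remaining name. A minimal
// standalone sketch of just that step (std only; "meta.json" stands in for META_INFO
// and the directory path is made up):
use std::fs;
use std::io;

fn max_version_in(dir: &str, meta_name: &str) -> io::Result<Option<String>> {
    let mut max: Option<String> = None;
    for entry in fs::read_dir(dir)? {
        let name = entry?.file_name().to_string_lossy().into_owned();
        // skip the metadata file, keep the largest remaining name
        if name.to_lowercase().contains(&meta_name.to_lowercase()) {
            continue;
        }
        if max.as_deref().map_or(true, |m| name.as_str() > m) {
            max = Some(name);
        }
    }
    Ok(max)
}

fn main() {
    // e.g. a dir containing "1669858243", "1670032115", "meta.json" -> picks "1670032115"
    match max_version_in("/tmp/models/model_a", "meta.json") {
        Ok(v) => println!("max version: {:?}", v),
        Err(e) => println!("scan failed (dir may not exist in this sketch): {}", e),
    }
}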
|
||||
|
||||
async fn model_watcher_latest(model_factory: ModelFactory<T>, tx: Sender<PredictMessage<T>>) {
|
||||
async fn call_external_modelsync(cli: &str, cur_versions: &Vec<String>) -> Result<()> {
|
||||
let mut args = cli.split_whitespace();
|
||||
|
||||
let mut cmd = Command::new(args.next().ok_or(anyhow!("model sync cli empty"))?);
|
||||
let extr_args = MODEL_SPECS
|
||||
.iter()
|
||||
.zip(cur_versions)
|
||||
.flat_map(|(spec, version)| vec!["--model-spec", spec, "--cur-version", version])
|
||||
.collect_vec();
|
||||
info!("run model sync: {} with extra args: {:?}", cli, extr_args);
|
||||
let output = cmd.args(args).args(extr_args).output().await?;
|
||||
info!("model sync stdout:{}", String::from_utf8(output.stdout)?);
|
||||
info!("model sync stderr:{}", String::from_utf8(output.stderr)?);
|
||||
if output.status.success() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"model sync failed with status: {:?}!",
|
||||
output.status
|
||||
))
|
||||
}
|
||||
}
|
||||
let meta_dir = utils::get_meta_dir();
|
||||
let meta_file = format!("{}{}", meta_dir, META_INFO);
|
||||
//initialize the latest version array
|
||||
let mut cur_versions = vec!["".to_owned(); MODEL_SPECS.len()];
|
||||
loop {
|
||||
info!("***polling for models***"); //nice deliminter
|
||||
if let Some(ref cli) = ARGS.modelsync_cli {
|
||||
if let Err(e) = call_external_modelsync(cli, &cur_versions).await {
|
||||
error!("model sync cli running error:{}", e)
|
||||
}
|
||||
}
|
||||
let config = utils::read_config(&meta_file).unwrap_or_else(|e| {
|
||||
info!("config file {} not found due to: {}", meta_file, e);
|
||||
Value::Null
|
||||
});
|
||||
info!("config:{}", config);
|
||||
for (idx, cur_version) in cur_versions.iter_mut().enumerate() {
|
||||
let model_dir = &ARGS.model_dir[idx];
|
||||
PredictService::scan_load_latest_model_from_model_dir(
|
||||
model_factory,
|
||||
&config[&MODEL_SPECS[idx]],
|
||||
tx.clone(),
|
||||
idx,
|
||||
cur_version,
|
||||
)
|
||||
.await
|
||||
.map_or_else(
|
||||
|e| error!("scanned {}, error {:?}", model_dir, e),
|
||||
|_| info!("scanned {}, latest_version: {}", model_dir, cur_version),
|
||||
);
|
||||
}
|
||||
sleep(Duration::from_secs(ARGS.model_check_interval_secs)).await;
|
||||
}
|
||||
}
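// The extra-args assembly for the model sync CLI above is a zip + flat_map over the
// model specs and their current versions. A tiny standalone sketch of just that step
// (plain std; the spec and version values are made up, the flag names come from the
// code above):
fn modelsync_extra_args<'a>(specs: &[&'a str], versions: &[&'a str]) -> Vec<&'a str> {
    specs
        .iter()
        .zip(versions.iter())
        .flat_map(|(&spec, &version)| vec!["--model-spec", spec, "--cur-version", version])
        .collect()
}

fn main() {
    let args = modelsync_extra_args(&["home_recap", "home_recap_realtime"], &["1670032115", ""]);
    assert_eq!(
        args,
        ["--model-spec", "home_recap", "--cur-version", "1670032115",
         "--model-spec", "home_recap_realtime", "--cur-version", ""]
    );
    println!("{:?}", args);
}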
|
||||
async fn tf_queue_manager(mut rx: Receiver<PredictMessage<T>>) {
|
||||
// Start receiving messages
|
||||
info!("setting up queue manager");
|
||||
let max_batch_size = ARGS
|
||||
.max_batch_size
|
||||
.iter()
|
||||
.map(|b| b.parse().unwrap())
|
||||
.collect::<Vec<usize>>();
|
||||
let batch_time_out_millis = ARGS
|
||||
.batch_time_out_millis
|
||||
.iter()
|
||||
.map(|b| b.parse().unwrap())
|
||||
.collect::<Vec<u64>>();
|
||||
let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
|
||||
let mut all_model_predictors: ArrayVec<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS> =
|
||||
(0..MAX_NUM_MODELS).map(|_| ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new()).collect();
|
||||
loop {
|
||||
let msg = rx.try_recv();
|
||||
let no_more_msg = match msg {
|
||||
Ok(PredictMessage::Predict(model_spec_at, version, val, resp, ts)) => {
|
||||
if let Some(model_predictors) = all_model_predictors.get_mut(model_spec_at) {
|
||||
if model_predictors.is_empty() {
|
||||
resp.send(PredictResult::ModelNotReady(model_spec_at))
|
||||
.unwrap_or_else(|e| error!("cannot send back model not ready error: {:?}", e));
|
||||
}
|
||||
else {
|
||||
match version {
|
||||
None => model_predictors[0].push(val, resp, ts),
|
||||
Some(the_version) => match model_predictors
|
||||
.iter_mut()
|
||||
.find(|x| x.model.version() == the_version)
|
||||
{
|
||||
None => resp
|
||||
.send(PredictResult::ModelVersionNotFound(
|
||||
model_spec_at,
|
||||
the_version,
|
||||
))
|
||||
.unwrap_or_else(|e| {
|
||||
error!("cannot send back version error: {:?}", e)
|
||||
}),
|
||||
Some(predictor) => predictor.push(val, resp, ts),
|
||||
},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resp.send(PredictResult::ModelNotFound(model_spec_at))
|
||||
.unwrap_or_else(|e| error!("cannot send back model not found error: {:?}", e))
|
||||
}
|
||||
MPSC_CHANNEL_SIZE.dec();
|
||||
false
|
||||
}
|
||||
Ok(PredictMessage::UpsertModel(tf_model)) => {
|
||||
let idx = tf_model.model_idx();
|
||||
let predictor = BatchPredictor {
|
||||
model: Arc::new(tf_model),
|
||||
input_tensors: Vec::with_capacity(max_batch_size[idx]),
|
||||
callbacks: Vec::with_capacity(max_batch_size[idx]),
|
||||
cur_batch_size: 0,
|
||||
max_batch_size: max_batch_size[idx],
|
||||
batch_time_out_millis: batch_time_out_millis[idx],
|
||||
//initialize to be current time
|
||||
queue_reset_ts: Instant::now(),
|
||||
queue_earliest_rq_ts: Instant::now(),
|
||||
};
|
||||
assert!(idx < all_model_predictors.len());
|
||||
metrics::NEW_MODEL_SNAPSHOT
|
||||
.with_label_values(&[&MODEL_SPECS[idx]])
|
||||
.inc();
|
||||
|
||||
//we can do this since the vector is small
|
||||
let predictors = &mut all_model_predictors[idx];
|
||||
if predictors.is_empty() {
|
||||
info!("now we serve new model: {}", predictor.model);
|
||||
}
|
||||
else {
|
||||
info!("now we serve updated model: {}", predictor.model);
|
||||
}
|
||||
if predictors.len() == ARGS.versions_per_model {
|
||||
predictors.remove(predictors.len() - 1);
|
||||
}
|
||||
predictors.insert(0, predictor);
|
||||
false
|
||||
}
|
||||
Err(TryRecvError::Empty) => true,
|
||||
Err(TryRecvError::Disconnected) => true,
|
||||
};
|
||||
for predictor in all_model_predictors.iter_mut().flatten() {
|
||||
//if predictor batch queue not empty and times out or no more msg in the queue, flush
|
||||
if (!predictor.input_tensors.is_empty() && (predictor.duration_past(predictor.batch_time_out_millis) || no_more_msg))
|
||||
//if batch queue reaches limit, flush
|
||||
|| predictor.cur_batch_size >= predictor.max_batch_size
|
||||
{
|
||||
predictor.batch_predict();
|
||||
}
|
||||
}
|
||||
if no_more_msg {
|
||||
sleep(Duration::from_millis(no_msg_wait_millis)).await;
|
||||
}
|
||||
}
|
||||
}
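// The flush condition in the loop above can be read as a small pure predicate: flush a
// batch when it is non-empty and has either timed out or the queue has drained, or when
// it has reached the size cap. A standalone sketch (plain std; queue length and current
// batch size are collapsed into one `queued` count, and elapsed-since-reset stands in
// for duration_past, which is not shown here):
use std::time::{Duration, Instant};

fn should_flush(
    queued: usize,
    max_batch_size: usize,
    queue_reset_ts: Instant,
    batch_time_out: Duration,
    no_more_msg: bool,
) -> bool {
    (queued > 0 && (queue_reset_ts.elapsed() >= batch_time_out || no_more_msg))
        || queued >= max_batch_size
}

fn main() {
    let started = Instant::now();
    // a half-full batch with messages still arriving and no timeout yet: keep waiting
    assert!(!should_flush(4, 8, started, Duration::from_millis(50), false));
    // the queue drained, so the partial batch goes out
    assert!(should_flush(4, 8, started, Duration::from_millis(50), true));
    // size cap reached: flush regardless of timing
    assert!(should_flush(8, 8, started, Duration::from_millis(50), false));
    println!("flush predicate behaves as sketched");
}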
|
||||
#[inline(always)]
|
||||
pub fn get_model_index(model_spec: &str) -> Option<usize> {
|
||||
MODEL_SPECS.iter().position(|m| m == model_spec)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,492 +0,0 @@
|
|||
#[cfg(feature = "tf")]
|
||||
pub mod tf {
|
||||
use arrayvec::ArrayVec;
|
||||
use itertools::Itertools;
|
||||
use log::{debug, error, info, warn};
|
||||
use prost::Message;
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::string::String;
|
||||
use tensorflow::io::{RecordReader, RecordReadError};
|
||||
use tensorflow::Operation;
|
||||
use tensorflow::SavedModelBundle;
|
||||
use tensorflow::SessionOptions;
|
||||
use tensorflow::SessionRunArgs;
|
||||
use tensorflow::Tensor;
|
||||
use tensorflow::{DataType, FetchToken, Graph, TensorInfo, TensorType};
|
||||
|
||||
use std::thread::sleep;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::cli_args::{Args, ARGS, INPUTS, MODEL_SPECS, OUTPUTS};
|
||||
use crate::tf_proto::tensorflow_serving::prediction_log::LogType;
|
||||
use crate::tf_proto::tensorflow_serving::{PredictionLog, PredictLog};
|
||||
use crate::tf_proto::ConfigProto;
|
||||
use anyhow::{Context, Result};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::TensorReturnEnum;
|
||||
use crate::bootstrap::{TensorInput, TensorInputEnum};
|
||||
use crate::metrics::{
|
||||
INFERENCE_FAILED_REQUESTS_BY_MODEL, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
|
||||
};
|
||||
use crate::predict_service::Model;
|
||||
use crate::{MAX_NUM_INPUTS, utils};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TFTensorEnum {
|
||||
String(Tensor<String>),
|
||||
Int(Tensor<i32>),
|
||||
Int64(Tensor<i64>),
|
||||
Float(Tensor<f32>),
|
||||
Double(Tensor<f64>),
|
||||
Boolean(Tensor<bool>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TFModel {
|
||||
pub model_idx: usize,
|
||||
pub bundle: SavedModelBundle,
|
||||
pub input_names: ArrayVec<String, MAX_NUM_INPUTS>,
|
||||
pub input_info: Vec<TensorInfo>,
|
||||
pub input_ops: Vec<Operation>,
|
||||
pub output_names: Vec<String>,
|
||||
pub output_info: Vec<TensorInfo>,
|
||||
pub output_ops: Vec<Operation>,
|
||||
pub export_dir: String,
|
||||
pub version: i64,
|
||||
pub inter_op: i32,
|
||||
pub intra_op: i32,
|
||||
}
|
||||
|
||||
impl Display for TFModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"idx: {}, tensorflow model_name:{}, export_dir:{}, version:{}, inter:{}, intra:{}",
|
||||
self.model_idx,
|
||||
MODEL_SPECS[self.model_idx],
|
||||
self.export_dir,
|
||||
self.version,
|
||||
self.inter_op,
|
||||
self.intra_op
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl TFModel {
|
||||
pub fn new(idx: usize, version: String, model_config: &Value) -> Result<TFModel> {
|
||||
// Create input variables for our addition
|
||||
let config = ConfigProto {
|
||||
intra_op_parallelism_threads: utils::get_config_or(
|
||||
model_config,
|
||||
"intra_op_parallelism",
|
||||
&ARGS.intra_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
inter_op_parallelism_threads: utils::get_config_or(
|
||||
model_config,
|
||||
"inter_op_parallelism",
|
||||
&ARGS.inter_op_parallelism[idx],
|
||||
)
|
||||
.parse()?,
|
||||
..Default::default()
|
||||
};
|
||||
let mut buf = Vec::new();
|
||||
buf.reserve(config.encoded_len());
|
||||
config.encode(&mut buf).unwrap();
|
||||
let mut opts = SessionOptions::new();
|
||||
opts.set_config(&buf)?;
|
||||
let export_dir = format!("{}/{}", ARGS.model_dir[idx], version);
|
||||
let mut graph = Graph::new();
|
||||
let bundle = SavedModelBundle::load(&opts, ["serve"], &mut graph, &export_dir)
|
||||
.context("error load model")?;
|
||||
let signature = bundle
|
||||
.meta_graph_def()
|
||||
.get_signature(&ARGS.serving_sig[idx])
|
||||
.context("error finding signature")?;
|
||||
let input_names = INPUTS[idx]
|
||||
.get_or_init(|| {
|
||||
let input_spec = signature
|
||||
.inputs()
|
||||
.iter()
|
||||
.map(|p| p.0.clone())
|
||||
.collect::<ArrayVec<String, MAX_NUM_INPUTS>>();
|
||||
info!(
|
||||
"input not set from cli, now we set from model metadata:{:?}",
|
||||
input_spec
|
||||
);
|
||||
input_spec
|
||||
})
|
||||
.clone();
|
||||
let input_info = input_names
|
||||
.iter()
|
||||
.map(|i| {
|
||||
signature
|
||||
.get_input(i)
|
||||
.context("error finding input op info")
|
||||
.unwrap()
|
||||
.clone()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let input_ops = input_info
|
||||
.iter()
|
||||
.map(|i| {
|
||||
graph
|
||||
.operation_by_name_required(&i.name().name)
|
||||
.context("error finding input op")
|
||||
.unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
info!("Model Input size: {}", input_info.len());
|
||||
|
||||
let output_names = OUTPUTS[idx].to_vec();
|
||||
|
||||
let output_info = output_names
|
||||
.iter()
|
||||
.map(|o| {
|
||||
signature
|
||||
.get_output(o)
|
||||
.context("error finding output op info")
|
||||
.unwrap()
|
||||
.clone()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let output_ops = output_info
|
||||
.iter()
|
||||
.map(|o| {
|
||||
graph
|
||||
.operation_by_name_required(&o.name().name)
|
||||
.context("error finding output op")
|
||||
.unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let tf_model = TFModel {
|
||||
model_idx: idx,
|
||||
bundle,
|
||||
input_names,
|
||||
input_info,
|
||||
input_ops,
|
||||
output_names,
|
||||
output_info,
|
||||
output_ops,
|
||||
export_dir,
|
||||
version: Args::version_str_to_epoch(&version)?,
|
||||
inter_op: config.inter_op_parallelism_threads,
|
||||
intra_op: config.intra_op_parallelism_threads,
|
||||
};
|
||||
tf_model.warmup()?;
|
||||
Ok(tf_model)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_tftensor_dimensions<T>(
|
||||
t: &[T],
|
||||
input_size: u64,
|
||||
batch_size: u64,
|
||||
input_dims: Option<Vec<i64>>,
|
||||
) -> Vec<u64> {
|
||||
// if input size is 1, we just specify a single dimension to outgoing tensor matching the
|
||||
// size of the input tensor. This is for backwards compatibility with existing Navi clients
|
||||
// which specify input as a single string tensor (like tfexample) and use batching support.
|
||||
let mut dims = vec![];
|
||||
if input_size > 1 {
|
||||
if batch_size == 1 && input_dims.is_some() {
|
||||
// client side batching is enabled?
|
||||
input_dims
|
||||
.unwrap()
|
||||
.iter()
|
||||
.for_each(|axis| dims.push(*axis as u64));
|
||||
} else {
|
||||
dims.push(batch_size);
|
||||
dims.push(t.len() as u64 / batch_size);
|
||||
}
|
||||
} else {
|
||||
dims.push(t.len() as u64);
|
||||
}
|
||||
dims
|
||||
}
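// The dimension rule above, restated as a standalone function over a flat length:
// multi-input models get either the client-supplied dims (client-side batching with
// batch_size == 1) or [batch_size, len / batch_size]; single-input models keep a flat
// [len]. A minimal sketch with a couple of spot checks (nothing TensorFlow-specific):
fn tensor_dims(len: u64, input_size: u64, batch_size: u64, input_dims: Option<Vec<i64>>) -> Vec<u64> {
    if input_size > 1 {
        if batch_size == 1 && input_dims.is_some() {
            input_dims.unwrap().into_iter().map(|axis| axis as u64).collect()
        } else {
            vec![batch_size, len / batch_size]
        }
    } else {
        vec![len]
    }
}

fn main() {
    // single string-tensor input (legacy Navi clients): stay 1-D
    assert_eq!(tensor_dims(3, 1, 3, None), vec![3]);
    // multi-input, server-side batching: [batch, features]
    assert_eq!(tensor_dims(12, 4, 3, None), vec![3, 4]);
    // multi-input, client supplied an explicit shape
    assert_eq!(tensor_dims(12, 4, 1, Some(vec![2, 6])), vec![2, 6]);
    println!("dimension rule matches the sketch");
}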
|
||||
|
||||
fn convert_to_tftensor_enum(
|
||||
input: TensorInput,
|
||||
input_size: u64,
|
||||
batch_size: u64,
|
||||
) -> TFTensorEnum {
|
||||
match input.tensor_data {
|
||||
TensorInputEnum::String(t) => {
|
||||
let strings = t
|
||||
.into_iter()
|
||||
.map(|x| unsafe { String::from_utf8_unchecked(x) })
|
||||
.collect_vec();
|
||||
TFTensorEnum::String(
|
||||
Tensor::new(&TFModel::get_tftensor_dimensions(
|
||||
strings.as_slice(),
|
||||
input_size,
|
||||
batch_size,
|
||||
input.dims,
|
||||
))
|
||||
.with_values(strings.as_slice())
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
TensorInputEnum::Int(t) => TFTensorEnum::Int(
|
||||
Tensor::new(&TFModel::get_tftensor_dimensions(
|
||||
t.as_slice(),
|
||||
input_size,
|
||||
batch_size,
|
||||
input.dims,
|
||||
))
|
||||
.with_values(t.as_slice())
|
||||
.unwrap(),
|
||||
),
|
||||
TensorInputEnum::Int64(t) => TFTensorEnum::Int64(
|
||||
Tensor::new(&TFModel::get_tftensor_dimensions(
|
||||
t.as_slice(),
|
||||
input_size,
|
||||
batch_size,
|
||||
input.dims,
|
||||
))
|
||||
.with_values(t.as_slice())
|
||||
.unwrap(),
|
||||
),
|
||||
TensorInputEnum::Float(t) => TFTensorEnum::Float(
|
||||
Tensor::new(&TFModel::get_tftensor_dimensions(
|
||||
t.as_slice(),
|
||||
input_size,
|
||||
batch_size,
|
||||
input.dims,
|
||||
))
|
||||
.with_values(t.as_slice())
|
||||
.unwrap(),
|
||||
),
|
||||
TensorInputEnum::Double(t) => TFTensorEnum::Double(
|
||||
Tensor::new(&TFModel::get_tftensor_dimensions(
|
||||
t.as_slice(),
|
||||
input_size,
|
||||
batch_size,
|
||||
input.dims,
|
||||
))
|
||||
.with_values(t.as_slice())
|
||||
.unwrap(),
|
||||
),
|
||||
TensorInputEnum::Boolean(t) => TFTensorEnum::Boolean(
|
||||
Tensor::new(&TFModel::get_tftensor_dimensions(
|
||||
t.as_slice(),
|
||||
input_size,
|
||||
batch_size,
|
||||
input.dims,
|
||||
))
|
||||
.with_values(t.as_slice())
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
fn fetch_output<T: TensorType>(
|
||||
args: &mut SessionRunArgs,
|
||||
token_output: &FetchToken,
|
||||
batch_size: u64,
|
||||
output_size: u64,
|
||||
) -> (Tensor<T>, u64) {
|
||||
let tensor_output = args.fetch::<T>(*token_output).expect("fetch output failed");
|
||||
let mut tensor_width = tensor_output.dims()[1];
|
||||
if batch_size == 1 && output_size > 1 {
|
||||
tensor_width = tensor_output.dims().iter().fold(1, |mut total, &val| {
|
||||
total *= val;
|
||||
total
|
||||
});
|
||||
}
|
||||
(tensor_output, tensor_width)
|
||||
}
|
||||
}
|
||||
|
||||
impl Model for TFModel {
|
||||
fn warmup(&self) -> Result<()> {
|
||||
// warm up
|
||||
let warmup_file = format!(
|
||||
"{}/assets.extra/tf_serving_warmup_requests",
|
||||
self.export_dir
|
||||
);
|
||||
if std::path::Path::new(&warmup_file).exists() {
|
||||
use std::io::Cursor;
|
||||
info!(
|
||||
"found warmup assets in {}, now perform warming up",
|
||||
warmup_file
|
||||
);
|
||||
let f = std::fs::File::open(warmup_file).context("cannot open warmup file")?;
|
||||
// let mut buf = Vec::new();
|
||||
let read = std::io::BufReader::new(f);
|
||||
let mut reader = RecordReader::new(read);
|
||||
let mut warmup_cnt = 0;
|
||||
loop {
|
||||
let next = reader.read_next_owned();
|
||||
match next {
|
||||
Ok(res) => match res {
|
||||
Some(vec) => {
|
||||
// info!("read one tfRecord");
|
||||
match PredictionLog::decode(&mut Cursor::new(vec))
|
||||
.context("can't parse PredictonLog")?
|
||||
{
|
||||
PredictionLog {
|
||||
log_metadata: _,
|
||||
log_type:
|
||||
Some(LogType::PredictLog(PredictLog {
|
||||
request: Some(mut req),
|
||||
response: _,
|
||||
})),
|
||||
} => {
|
||||
if warmup_cnt == ARGS.max_warmup_records {
|
||||
//warm up to max_warmup_records records
|
||||
warn!(
|
||||
"reached max warmup {} records, exit warmup for {}",
|
||||
ARGS.max_warmup_records,
|
||||
MODEL_SPECS[self.model_idx]
|
||||
);
|
||||
break;
|
||||
}
|
||||
self.do_predict(
|
||||
vec![req.take_input_vals(&self.input_names)],
|
||||
1,
|
||||
);
|
||||
sleep(Duration::from_millis(100));
|
||||
warmup_cnt += 1;
|
||||
}
|
||||
_ => error!("some wrong record in warming up file"),
|
||||
}
|
||||
}
|
||||
None => {
|
||||
info!("end of warmup file, warmed up with records: {}", warmup_cnt);
|
||||
break;
|
||||
}
|
||||
},
|
||||
Err(RecordReadError::CorruptFile)
|
||||
| Err(RecordReadError::IoError { .. }) => {
|
||||
error!("read tfrecord error for warmup files, skip");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn do_predict(
|
||||
&self,
|
||||
input_tensors: Vec<Vec<TensorInput>>,
|
||||
batch_size: u64,
|
||||
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
|
||||
// let mut batch_ends = input_tensors.iter().map(|t| t.len()).collect::<Vec<usize>>();
|
||||
let output_size = self.output_names.len() as u64;
|
||||
let input_size = self.input_names.len() as u64;
|
||||
debug!(
|
||||
"Request for Tensorflow with batch size: {} and input_size: {}",
|
||||
batch_size, input_size
|
||||
);
|
||||
// build a set of input TF tensors
|
||||
|
||||
let batch_end = (1..=input_tensors.len())
|
||||
.into_iter()
|
||||
.collect_vec();
|
||||
let mut batch_ends = vec![batch_end; output_size as usize];
|
||||
|
||||
let batched_tensors = TensorInputEnum::merge_batch(input_tensors)
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(_, i)| TFModel::convert_to_tftensor_enum(i, input_size, batch_size))
|
||||
.collect_vec();
|
||||
|
||||
let mut args = SessionRunArgs::new();
|
||||
for (index, tf_tensor) in batched_tensors.iter().enumerate() {
|
||||
match tf_tensor {
|
||||
TFTensorEnum::String(inner) => args.add_feed(&self.input_ops[index], 0, inner),
|
||||
TFTensorEnum::Int(inner) => args.add_feed(&self.input_ops[index], 0, inner),
|
||||
TFTensorEnum::Int64(inner) => args.add_feed(&self.input_ops[index], 0, inner),
|
||||
TFTensorEnum::Float(inner) => args.add_feed(&self.input_ops[index], 0, inner),
|
||||
TFTensorEnum::Double(inner) => args.add_feed(&self.input_ops[index], 0, inner),
|
||||
TFTensorEnum::Boolean(inner) => args.add_feed(&self.input_ops[index], 0, inner),
|
||||
}
|
||||
}
|
||||
// For output ops, we receive the same op object by name. Actual tensor tokens are available at different offsets.
|
||||
// Since indices are ordered, it's important to specify the output flag to Navi in the same order.
|
||||
let token_outputs = self
|
||||
.output_ops
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, op)| args.request_fetch(op, idx as i32))
|
||||
.collect_vec();
|
||||
match self.bundle.session.run(&mut args) {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
NUM_REQUESTS_FAILED.inc_by(batch_size);
|
||||
NUM_REQUESTS_FAILED_BY_MODEL
|
||||
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
|
||||
.inc_by(batch_size);
|
||||
INFERENCE_FAILED_REQUESTS_BY_MODEL
|
||||
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
|
||||
.inc_by(batch_size);
|
||||
panic!("{model}: {e:?}", model = MODEL_SPECS[self.model_idx], e = e);
|
||||
}
|
||||
}
|
||||
let mut predict_return = vec![];
|
||||
// Check the output.
|
||||
for (index, token_output) in token_outputs.iter().enumerate() {
|
||||
// same ops, with type info at different offsets.
|
||||
let (res, width) = match self.output_ops[index].output_type(index) {
|
||||
DataType::Float => {
|
||||
let (tensor_output, tensor_width) =
|
||||
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
|
||||
(
|
||||
TensorReturnEnum::FloatTensorReturn(Box::new(tensor_output)),
|
||||
tensor_width,
|
||||
)
|
||||
}
|
||||
DataType::Int64 => {
|
||||
let (tensor_output, tensor_width) =
|
||||
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
|
||||
(
|
||||
TensorReturnEnum::Int64TensorReturn(Box::new(tensor_output)),
|
||||
tensor_width,
|
||||
)
|
||||
}
|
||||
DataType::Int32 => {
|
||||
let (tensor_output, tensor_width) =
|
||||
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
|
||||
(
|
||||
TensorReturnEnum::Int32TensorReturn(Box::new(tensor_output)),
|
||||
tensor_width,
|
||||
)
|
||||
}
|
||||
DataType::String => {
|
||||
let (tensor_output, tensor_width) =
|
||||
TFModel::fetch_output(&mut args, token_output, batch_size, output_size);
|
||||
(
|
||||
TensorReturnEnum::StringTensorReturn(Box::new(tensor_output)),
|
||||
tensor_width,
|
||||
)
|
||||
}
|
||||
_ => panic!("Unsupported return type!"),
|
||||
};
|
||||
let width = width as usize;
|
||||
for b in batch_ends[index].iter_mut() {
|
||||
*b *= width;
|
||||
}
|
||||
predict_return.push(res)
|
||||
}
|
||||
//TODO: remove in the future
|
||||
//TODO: support actual mtl model outputs
|
||||
(predict_return, batch_ends)
|
||||
}
|
||||
#[inline(always)]
|
||||
fn model_idx(&self) -> usize {
|
||||
self.model_idx
|
||||
}
|
||||
#[inline(always)]
|
||||
fn version(&self) -> i64 {
|
||||
self.version
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,183 +0,0 @@
|
|||
#[cfg(feature = "torch")]
|
||||
pub mod torch {
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::string::String;
|
||||
|
||||
use crate::TensorReturnEnum;
|
||||
use crate::SerializedInput;
|
||||
use crate::bootstrap::TensorInput;
|
||||
use crate::cli_args::{Args, ARGS, MODEL_SPECS};
|
||||
use crate::metrics;
|
||||
use crate::metrics::{
|
||||
INFERENCE_FAILED_REQUESTS_BY_MODEL, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
|
||||
};
|
||||
use crate::predict_service::Model;
|
||||
use anyhow::Result;
|
||||
use dr_transform::converter::BatchPredictionRequestToTorchTensorConverter;
|
||||
use dr_transform::converter::Converter;
|
||||
use serde_json::Value;
|
||||
use tch::Tensor;
|
||||
use tch::{kind, CModule, IValue};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TorchModel {
|
||||
pub model_idx: usize,
|
||||
pub version: i64,
|
||||
pub module: CModule,
|
||||
pub export_dir: String,
|
||||
// FIXME: make this Option<Box<..>> so the input converter can be optional.
|
||||
// Also consider adding output_converter.
|
||||
pub input_converter: Box<dyn Converter>,
|
||||
}
|
||||
|
||||
impl Display for TorchModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"idx: {}, torch model_name:{}, version:{}",
|
||||
self.model_idx, MODEL_SPECS[self.model_idx], self.version
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl TorchModel {
|
||||
pub fn new(idx: usize, version: String, _model_config: &Value) -> Result<TorchModel> {
|
||||
let export_dir = format!("{}/{}/model.pt", ARGS.model_dir[idx], version);
|
||||
let model = CModule::load(&export_dir).unwrap();
|
||||
let torch_model = TorchModel {
|
||||
model_idx: idx,
|
||||
version: Args::version_str_to_epoch(&version)?,
|
||||
module: model,
|
||||
export_dir,
|
||||
//TODO: move converter lookup into a registry.
|
||||
input_converter: Box::new(BatchPredictionRequestToTorchTensorConverter::new(
|
||||
&ARGS.model_dir[idx].as_str(),
|
||||
version.as_str(),
|
||||
vec![],
|
||||
Some(&metrics::register_dynamic_metrics),
|
||||
)),
|
||||
};
|
||||
|
||||
torch_model.warmup()?;
|
||||
Ok(torch_model)
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn decode_to_inputs(bytes: SerializedInput) -> Vec<Tensor> {
|
||||
//FIXME: for now we generate 6 random tensors as inputs to unblock end-to-end testing
|
||||
//when Shajan's decoder is ready we will swap it in
|
||||
let row = bytes.len() as i64;
|
||||
let t1 = Tensor::randn(&[row, 5293], kind::FLOAT_CPU); //continuous
|
||||
let t2 = Tensor::randint(10, &[row, 149], kind::INT64_CPU); //binary
|
||||
let t3 = Tensor::randint(10, &[row, 320], kind::INT64_CPU); //discrete
|
||||
let t4 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //user_embedding
|
||||
let t5 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //user_eng_embedding
|
||||
let t6 = Tensor::randn(&[row, 200], kind::FLOAT_CPU); //author_embedding
|
||||
|
||||
vec![t1, t2, t3, t4, t5, t6]
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn output_to_vec(res: IValue, dst: &mut Vec<f32>) {
|
||||
match res {
|
||||
IValue::Tensor(tensor) => TorchModel::tensors_to_vec(&[tensor], dst),
|
||||
IValue::Tuple(ivalues) => {
|
||||
TorchModel::tensors_to_vec(&TorchModel::ivalues_to_tensors(ivalues), dst)
|
||||
}
|
||||
_ => panic!("we only support output as a single tensor or a vec of tensors"),
|
||||
}
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn tensor_flatten_size(t: &Tensor) -> usize {
|
||||
t.size().into_iter().fold(1, |acc, x| acc * x) as usize
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn tensor_to_vec<T: kind::Element>(res: &Tensor) -> Vec<T> {
|
||||
let size = TorchModel::tensor_flatten_size(res);
|
||||
let mut res_f32: Vec<T> = Vec::with_capacity(size);
|
||||
unsafe {
|
||||
res_f32.set_len(size);
|
||||
}
|
||||
res.copy_data(res_f32.as_mut_slice(), size);
|
||||
// println!("Copied tensor:{}, {:?}", res_f32.len(), res_f32);
|
||||
res_f32
|
||||
}
|
||||
#[inline(always)]
|
||||
pub fn tensors_to_vec(tensors: &[Tensor], dst: &mut Vec<f32>) {
|
||||
let mut offset = dst.len();
|
||||
tensors.iter().for_each(|t| {
|
||||
let size = TorchModel::tensor_flatten_size(t);
|
||||
let next_size = offset + size;
|
||||
unsafe {
|
||||
dst.set_len(next_size);
|
||||
}
|
||||
t.copy_data(&mut dst[offset..], size);
|
||||
offset = next_size;
|
||||
});
|
||||
}
|
||||
pub fn ivalues_to_tensors(ivalues: Vec<IValue>) -> Vec<Tensor> {
|
||||
ivalues
|
||||
.into_iter()
|
||||
.map(|t| {
|
||||
if let IValue::Tensor(vanilla_t) = t {
|
||||
vanilla_t
|
||||
} else {
|
||||
panic!("not a tensor")
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Tensor>>()
|
||||
}
|
||||
}
|
||||
|
||||
impl Model for TorchModel {
|
||||
fn warmup(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
//TODO: torch runtime needs some refactor to make it a generic interface
|
||||
#[inline(always)]
|
||||
fn do_predict(
|
||||
&self,
|
||||
input_tensors: Vec<Vec<TensorInput>>,
|
||||
total_len: u64,
|
||||
) -> (Vec<TensorReturnEnum>, Vec<Vec<usize>>) {
|
||||
let mut buf: Vec<f32> = Vec::with_capacity(10_000);
|
||||
let mut batch_ends = vec![0usize; input_tensors.len()];
|
||||
for (i, batch_bytes_in_request) in input_tensors.into_iter().enumerate() {
|
||||
for _ in batch_bytes_in_request.into_iter() {
|
||||
//FIXME: for now use some hack
|
||||
let model_input = TorchModel::decode_to_inputs(vec![0u8; 30]); //self.input_converter.convert(bytes);
|
||||
let input_batch_tensors = model_input
|
||||
.into_iter()
|
||||
.map(|t| IValue::Tensor(t))
|
||||
.collect::<Vec<IValue>>();
|
||||
// match self.module.forward_is(&input_batch_tensors) {
|
||||
match self.module.method_is("forward_serve", &input_batch_tensors) {
|
||||
Ok(res) => TorchModel::output_to_vec(res, &mut buf),
|
||||
Err(e) => {
|
||||
NUM_REQUESTS_FAILED.inc_by(total_len);
|
||||
NUM_REQUESTS_FAILED_BY_MODEL
|
||||
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
|
||||
.inc_by(total_len);
|
||||
INFERENCE_FAILED_REQUESTS_BY_MODEL
|
||||
.with_label_values(&[&MODEL_SPECS[self.model_idx]])
|
||||
.inc_by(total_len);
|
||||
panic!("{model}: {e:?}", model = MODEL_SPECS[self.model_idx], e = e);
|
||||
}
|
||||
}
|
||||
}
|
||||
batch_ends[i] = buf.len();
|
||||
}
|
||||
(
|
||||
vec![TensorReturnEnum::FloatTensorReturn(Box::new(buf))],
|
||||
vec![batch_ends],
|
||||
)
|
||||
}
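// The batch_ends bookkeeping above is just "record the running length of the flat
// output buffer after each request's scores are appended"; the serving layer can then
// slice buf[prev_end..end] back out per caller. A standalone sketch with plain
// Vec<f32> standing in for real model outputs:
fn append_batches(per_request_scores: Vec<Vec<f32>>) -> (Vec<f32>, Vec<usize>) {
    let mut buf: Vec<f32> = Vec::new();
    let mut batch_ends = vec![0usize; per_request_scores.len()];
    for (i, scores) in per_request_scores.into_iter().enumerate() {
        buf.extend(scores);
        batch_ends[i] = buf.len();
    }
    (buf, batch_ends)
}

fn main() {
    let (buf, ends) = append_batches(vec![vec![0.1, 0.2], vec![0.3], vec![0.4, 0.5, 0.6]]);
    assert_eq!(ends, vec![2, 3, 6]);
    // request 2's scores are buf[ends[1]..ends[2]]
    assert_eq!(&buf[ends[1]..ends[2]], &[0.4, 0.5, 0.6]);
    println!("batch_ends slices each request back out of the flat buffer");
}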
|
||||
#[inline(always)]
|
||||
fn model_idx(&self) -> usize {
|
||||
self.model_idx
|
||||
}
|
||||
#[inline(always)]
|
||||
fn version(&self) -> i64 {
|
||||
self.version
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,11 +0,0 @@
|
|||
[package]
|
||||
name = "segdense"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
|
||||
[dependencies]
|
||||
env_logger = "0.10.0"
|
||||
serde = { version = "1.0.104", features = ["derive"] }
|
||||
serde_json = "1.0.48"
|
||||
log = "0.4.17"
|
Binary file not shown.
|
@ -1,53 +0,0 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
/**
|
||||
* Custom error
|
||||
*/
|
||||
#[derive(Debug)]
|
||||
pub enum SegDenseError {
|
||||
IoError(std::io::Error),
|
||||
Json(serde_json::Error),
|
||||
JsonMissingRoot,
|
||||
JsonMissingObject,
|
||||
JsonMissingArray,
|
||||
JsonArraySize,
|
||||
JsonMissingInputFeature,
|
||||
}
|
||||
|
||||
impl Display for SegDenseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
|
||||
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
|
||||
SegDenseError::JsonMissingRoot => {
|
||||
write!(f, "{}", "SegDense JSON: Root Node note found!")
|
||||
}
|
||||
SegDenseError::JsonMissingObject => {
|
||||
write!(f, "{}", "SegDense JSON: Object note found!")
|
||||
}
|
||||
SegDenseError::JsonMissingArray => {
|
||||
write!(f, "{}", "SegDense JSON: Array Node note found!")
|
||||
}
|
||||
SegDenseError::JsonArraySize => {
|
||||
write!(f, "{}", "SegDense JSON: Array size not as expected!")
|
||||
}
|
||||
SegDenseError::JsonMissingInputFeature => {
|
||||
write!(f, "{}", "SegDense JSON: Missing input feature!")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SegDenseError {}
|
||||
|
||||
impl From<std::io::Error> for SegDenseError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
SegDenseError::IoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for SegDenseError {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
SegDenseError::Json(err)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,4 +0,0 @@
|
|||
pub mod error;
|
||||
pub mod mapper;
|
||||
pub mod segdense_transform_spec_home_recap_2022;
|
||||
pub mod util;
|
Binary file not shown.
|
@ -1,22 +0,0 @@
|
|||
use std::env;
|
||||
use std::fs;
|
||||
|
||||
use segdense::error::SegDenseError;
|
||||
use segdense::util;
|
||||
|
||||
fn main() -> Result<(), SegDenseError> {
|
||||
env_logger::init();
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
let schema_file_name: &str = if args.len() == 1 {
|
||||
"json/compact.json"
|
||||
} else {
|
||||
&args[1]
|
||||
};
|
||||
|
||||
let json_str = fs::read_to_string(schema_file_name)?;
|
||||
|
||||
util::safe_load_config(&json_str)?;
|
||||
|
||||
Ok(())
|
||||
}
|
Binary file not shown.
|
@ -1,45 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FeatureInfo {
|
||||
pub tensor_index: i8,
|
||||
pub index_within_tensor: i64,
|
||||
}
|
||||
|
||||
pub static NULL_INFO: FeatureInfo = FeatureInfo {
|
||||
tensor_index: -1,
|
||||
index_within_tensor: -1,
|
||||
};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct FeatureMapper {
|
||||
map: HashMap<i64, FeatureInfo>,
|
||||
}
|
||||
|
||||
impl FeatureMapper {
|
||||
pub fn new() -> FeatureMapper {
|
||||
FeatureMapper {
|
||||
map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MapWriter {
|
||||
fn set(&mut self, feature_id: i64, info: FeatureInfo);
|
||||
}
|
||||
|
||||
pub trait MapReader {
|
||||
fn get(&self, feature_id: &i64) -> Option<&FeatureInfo>;
|
||||
}
|
||||
|
||||
impl MapWriter for FeatureMapper {
|
||||
fn set(&mut self, feature_id: i64, info: FeatureInfo) {
|
||||
self.map.insert(feature_id, info);
|
||||
}
|
||||
}
|
||||
|
||||
impl MapReader for FeatureMapper {
|
||||
fn get(&self, feature_id: &i64) -> Option<&FeatureInfo> {
|
||||
self.map.get(feature_id)
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,182 +0,0 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Root {
|
||||
#[serde(rename = "common_prefix")]
|
||||
pub common_prefix: String,
|
||||
#[serde(rename = "densification_transform_spec")]
|
||||
pub densification_transform_spec: DensificationTransformSpec,
|
||||
#[serde(rename = "identity_transform_spec")]
|
||||
pub identity_transform_spec: Vec<IdentityTransformSpec>,
|
||||
#[serde(rename = "complex_feature_type_transform_spec")]
|
||||
pub complex_feature_type_transform_spec: Vec<ComplexFeatureTypeTransformSpec>,
|
||||
#[serde(rename = "input_features_map")]
|
||||
pub input_features_map: Value,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DensificationTransformSpec {
|
||||
pub discrete: Discrete,
|
||||
pub cont: Cont,
|
||||
pub binary: Binary,
|
||||
pub string: Value, // Use StringType
|
||||
pub blob: Blob,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Discrete {
|
||||
pub tag: String,
|
||||
#[serde(rename = "generic_feature_type")]
|
||||
pub generic_feature_type: i64,
|
||||
#[serde(rename = "feature_identifier")]
|
||||
pub feature_identifier: String,
|
||||
#[serde(rename = "fixed_length")]
|
||||
pub fixed_length: i64,
|
||||
#[serde(rename = "default_value")]
|
||||
pub default_value: DefaultValue,
|
||||
#[serde(rename = "input_features")]
|
||||
pub input_features: Vec<InputFeature>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DefaultValue {
|
||||
#[serde(rename = "type")]
|
||||
pub type_field: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InputFeature {
|
||||
#[serde(rename = "feature_id")]
|
||||
pub feature_id: i64,
|
||||
#[serde(rename = "full_feature_name")]
|
||||
pub full_feature_name: String,
|
||||
#[serde(rename = "feature_type")]
|
||||
pub feature_type: i64,
|
||||
pub index: i64,
|
||||
#[serde(rename = "maybe_exclude")]
|
||||
pub maybe_exclude: bool,
|
||||
pub tag: String,
|
||||
#[serde(rename = "added_at")]
|
||||
pub added_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Cont {
|
||||
pub tag: String,
|
||||
#[serde(rename = "generic_feature_type")]
|
||||
pub generic_feature_type: i64,
|
||||
#[serde(rename = "feature_identifier")]
|
||||
pub feature_identifier: String,
|
||||
#[serde(rename = "fixed_length")]
|
||||
pub fixed_length: i64,
|
||||
#[serde(rename = "default_value")]
|
||||
pub default_value: DefaultValue,
|
||||
#[serde(rename = "input_features")]
|
||||
pub input_features: Vec<InputFeature>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Binary {
|
||||
pub tag: String,
|
||||
#[serde(rename = "generic_feature_type")]
|
||||
pub generic_feature_type: i64,
|
||||
#[serde(rename = "feature_identifier")]
|
||||
pub feature_identifier: String,
|
||||
#[serde(rename = "fixed_length")]
|
||||
pub fixed_length: i64,
|
||||
#[serde(rename = "default_value")]
|
||||
pub default_value: DefaultValue,
|
||||
#[serde(rename = "input_features")]
|
||||
pub input_features: Vec<InputFeature>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct StringType {
|
||||
pub tag: String,
|
||||
#[serde(rename = "generic_feature_type")]
|
||||
pub generic_feature_type: i64,
|
||||
#[serde(rename = "feature_identifier")]
|
||||
pub feature_identifier: String,
|
||||
#[serde(rename = "fixed_length")]
|
||||
pub fixed_length: i64,
|
||||
#[serde(rename = "default_value")]
|
||||
pub default_value: DefaultValue,
|
||||
#[serde(rename = "input_features")]
|
||||
pub input_features: Vec<InputFeature>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Blob {
|
||||
pub tag: String,
|
||||
#[serde(rename = "generic_feature_type")]
|
||||
pub generic_feature_type: i64,
|
||||
#[serde(rename = "feature_identifier")]
|
||||
pub feature_identifier: String,
|
||||
#[serde(rename = "fixed_length")]
|
||||
pub fixed_length: i64,
|
||||
#[serde(rename = "default_value")]
|
||||
pub default_value: DefaultValue,
|
||||
#[serde(rename = "input_features")]
|
||||
pub input_features: Vec<Value>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct IdentityTransformSpec {
|
||||
#[serde(rename = "feature_id")]
|
||||
pub feature_id: i64,
|
||||
#[serde(rename = "full_feature_name")]
|
||||
pub full_feature_name: String,
|
||||
#[serde(rename = "feature_type")]
|
||||
pub feature_type: i64,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ComplexFeatureTypeTransformSpec {
|
||||
#[serde(rename = "feature_id")]
|
||||
pub feature_id: i64,
|
||||
#[serde(rename = "full_feature_name")]
|
||||
pub full_feature_name: String,
|
||||
#[serde(rename = "feature_type")]
|
||||
pub feature_type: i64,
|
||||
pub index: i64,
|
||||
#[serde(rename = "maybe_exclude")]
|
||||
pub maybe_exclude: bool,
|
||||
pub tag: String,
|
||||
#[serde(rename = "tensor_data_type")]
|
||||
pub tensor_data_type: Option<i64>,
|
||||
#[serde(rename = "added_at")]
|
||||
pub added_at: i64,
|
||||
#[serde(rename = "tensor_shape")]
|
||||
#[serde(default)]
|
||||
pub tensor_shape: Vec<i64>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InputFeatureMapRecord {
|
||||
#[serde(rename = "feature_id")]
|
||||
pub feature_id: i64,
|
||||
#[serde(rename = "full_feature_name")]
|
||||
pub full_feature_name: String,
|
||||
#[serde(rename = "feature_type")]
|
||||
pub feature_type: i64,
|
||||
pub index: i64,
|
||||
#[serde(rename = "maybe_exclude")]
|
||||
pub maybe_exclude: bool,
|
||||
pub tag: String,
|
||||
#[serde(rename = "added_at")]
|
||||
pub added_at: i64,
|
||||
}
|
Binary file not shown.
|
@ -1,154 +0,0 @@
|
|||
use log::debug;
|
||||
use std::fs;
|
||||
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
use crate::error::SegDenseError;
|
||||
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
|
||||
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
|
||||
|
||||
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
|
||||
let json_str = fs::read_to_string(file_name)?;
|
||||
// &format!("Unable to load segdense file {}", file_name));
|
||||
let seg_dense_config = parse(&json_str)?;
|
||||
// &format!("Unable to parse segdense file {}", file_name));
|
||||
Ok(seg_dense_config)
|
||||
}
|
||||
|
||||
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
|
||||
let root: seg_dense::Root = serde_json::from_str(json_str)?;
|
||||
Ok(root)
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a json string containing a seg dense schema, create a feature mapper
|
||||
* which is essentially:
|
||||
*
|
||||
* {feature-id -> (Tensor Index, Index of feature within the tensor)}
|
||||
*
|
||||
* Feature id : 64 bit hash of the feature name used in DataRecords.
|
||||
*
|
||||
* Tensor Index : A vector of tensors is passed to the model. Tensor
|
||||
* index refers to the tensor this feature is part of.
|
||||
*
|
||||
* Index of feature in tensor : The tensors are vectors, the index of
|
||||
* feature is the position to put the feature value.
|
||||
*
|
||||
* There are many assumptions made in this function that are very model-specific.
|
||||
* These assumptions are called out below and need to be schematized eventually.
|
||||
*
|
||||
* Call this once for each segdense schema and cache the FeatureMapper.
|
||||
*/
|
||||
pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError> {
|
||||
let root = parse(json_str)?;
|
||||
load_from_parsed_config(root)
|
||||
}
|
||||
|
||||
// Perf note : make 'root' un-owned
|
||||
pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
|
||||
let v = root.input_features_map;
|
||||
|
||||
// Do error check
|
||||
let map: Map<String, Value> = match v {
|
||||
Value::Object(map) => map,
|
||||
_ => return Err(SegDenseError::JsonMissingObject),
|
||||
};
|
||||
|
||||
let mut fm: FeatureMapper = FeatureMapper::new();
|
||||
|
||||
let items = map.values();
|
||||
|
||||
// Perf : Consider a way to avoid clone here
|
||||
for item in items.cloned() {
|
||||
let mut vec = match item {
|
||||
Value::Array(v) => v,
|
||||
_ => return Err(SegDenseError::JsonMissingArray),
|
||||
};
|
||||
|
||||
if vec.len() != 1 {
|
||||
return Err(SegDenseError::JsonArraySize);
|
||||
}
|
||||
|
||||
let val = vec.pop().unwrap();
|
||||
|
||||
let input_feature: seg_dense::InputFeature = serde_json::from_value(val)?;
|
||||
let feature_id = input_feature.feature_id;
|
||||
let feature_info = to_feature_info(&input_feature);
|
||||
|
||||
match feature_info {
|
||||
Some(info) => {
|
||||
debug!("{:?}", info);
|
||||
fm.set(feature_id, info)
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(fm)
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
fn add_feature_info_to_mapper(
|
||||
feature_mapper: &mut FeatureMapper,
|
||||
input_features: &Vec<InputFeature>,
|
||||
) {
|
||||
for input_feature in input_features.iter() {
|
||||
let feature_id = input_feature.feature_id;
|
||||
let feature_info = to_feature_info(input_feature);
|
||||
|
||||
match feature_info {
|
||||
Some(info) => {
|
||||
debug!("{:?}", info);
|
||||
feature_mapper.set(feature_id, info)
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
|
||||
if input_feature.maybe_exclude {
|
||||
return None;
|
||||
}
|
||||
|
||||
// This part needs to be schema driven
|
||||
//
|
||||
// tensor index : Which of these tensors this feature is part of
|
||||
// [Continuous, Binary, Discrete, User_embedding, user_eng_embedding, author_embedding]
|
||||
// Note that this order is fixed/hardcoded here, and need to be schematized
|
||||
//
|
||||
let tensor_idx: i8 = match input_feature.feature_id {
|
||||
// user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings
|
||||
// Feature name is mapped to a feature-id value. The hardcoded values below correspond to a specific feature name.
|
||||
-2550691008059411095 => 3,
|
||||
|
||||
// user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings
|
||||
5390650078733277231 => 4,
|
||||
|
||||
// original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings
|
||||
3223956748566688423 => 5,
|
||||
|
||||
_ => match input_feature.feature_type {
|
||||
// feature_type : src/thrift/com/twitter/ml/api/data.thrift
|
||||
// BINARY = 1, CONTINUOUS = 2, DISCRETE = 3,
|
||||
// Map to slots in [Continuous, Binary, Discrete, ..]
|
||||
1 => 1,
|
||||
2 => 0,
|
||||
3 => 2,
|
||||
_ => -1,
|
||||
},
|
||||
};
|
||||
|
||||
if input_feature.index < 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Handle this case later
|
||||
if tensor_idx == -1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(FeatureInfo {
|
||||
tensor_index: tensor_idx,
|
||||
index_within_tensor: input_feature.index,
|
||||
})
|
||||
}
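// A minimal usage sketch of the mapper built above: start from a default Root, drop a
// single hand-written entry into input_features_map, and look the feature back up. The
// feature id and names are made up for illustration; the crate name "segdense" comes
// from the Cargo.toml above, and serde_json is assumed to be available to the caller.
use segdense::error::SegDenseError;
use segdense::mapper::MapReader;
use segdense::segdense_transform_spec_home_recap_2022 as seg_dense;
use segdense::util;

fn main() -> Result<(), SegDenseError> {
    let mut root = seg_dense::Root::default();
    root.input_features_map = serde_json::json!({
        "example.continuous.feature": [{
            "feature_id": 1234567890123_i64,
            "full_feature_name": "example.continuous.feature",
            "feature_type": 2,          // CONTINUOUS -> tensor slot 0
            "index": 7,
            "maybe_exclude": false,
            "tag": "cont",
            "added_at": 0
        }]
    });

    let mapper = util::load_from_parsed_config(root)?;
    let info = mapper.get(&1234567890123_i64).expect("feature was mapped");
    assert_eq!((info.tensor_index, info.index_within_tensor), (0, 7));
    println!("feature lands in tensor {} at offset {}", info.tensor_index, info.index_within_tensor);
    Ok(())
}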
|
Binary file not shown.
|
@ -1,8 +0,0 @@
|
|||
[package]
|
||||
name = "bpr_thrift"
|
||||
description = "Thrift parser for Batch Prediction Request"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
thrift = "0.17.0"
|
Binary file not shown.
File diff suppressed because it is too large
Binary file not shown.
|
@ -1,78 +0,0 @@
|
|||
|
||||
// A feature value can be one of these
|
||||
enum FeatureVal {
|
||||
Empty,
|
||||
U8Vector(Vec<u8>),
|
||||
FloatVector(Vec<f32>),
|
||||
}
|
||||
|
||||
// A Feature has a name and a value
|
||||
// The name for now is 'id' of type string
|
||||
// Eventually this needs to be flexible - for example to accommodate feature-id
|
||||
struct Feature {
|
||||
id: String,
|
||||
val: FeatureVal,
|
||||
}
|
||||
|
||||
impl Feature {
|
||||
fn new() -> Feature {
|
||||
Feature {
|
||||
id: String::new(),
|
||||
val: FeatureVal::Empty
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A single inference record will have multiple features
|
||||
struct Record {
|
||||
fields: Vec<Feature>,
|
||||
}
|
||||
|
||||
impl Record {
|
||||
fn new() -> Record {
|
||||
Record { fields: vec![] }
|
||||
}
|
||||
}
|
||||
|
||||
// This is the main API used by external components
|
||||
// Given a serialized input, decode it into Records
|
||||
fn decode(input: Vec<u8>) -> Vec<Record> {
|
||||
// For helping define the interface
|
||||
vec![get_random_record(), get_random_record()]
|
||||
}
|
||||
|
||||
// Used for testing the API, will be eventually removed
|
||||
fn get_random_record() -> Record {
|
||||
let mut record: Record = Record::new();
|
||||
|
||||
let f1: Feature = Feature {
|
||||
id: String::from("continuous_features"),
|
||||
val: FeatureVal::FloatVector(vec![1.0f32; 2134]),
|
||||
};
|
||||
|
||||
record.fields.push(f1);
|
||||
|
||||
let f2: Feature = Feature {
|
||||
id: String::from("user_embedding"),
|
||||
val: FeatureVal::FloatVector(vec![2.0f32; 200]),
|
||||
};
|
||||
|
||||
record.fields.push(f2);
|
||||
|
||||
let f3: Feature = Feature {
|
||||
id: String::from("author_embedding"),
|
||||
val: FeatureVal::FloatVector(vec![3.0f32; 200]),
|
||||
};
|
||||
|
||||
record.fields.push(f3);
|
||||
|
||||
let f4: Feature = Feature {
|
||||
id: String::from("binary_features"),
|
||||
val: FeatureVal::U8Vector(vec![4u8; 43]),
|
||||
};
|
||||
|
||||
record.fields.push(f4);
|
||||
|
||||
record
|
||||
}
|
||||
|
Binary file not shown.
|
@ -1,4 +0,0 @@
|
|||
pub mod prediction_service;
|
||||
pub mod data;
|
||||
pub mod tensor;
|
||||
|
Binary file not shown.
|
@ -1,81 +0,0 @@
|
|||
use std::collections::BTreeSet;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use bpr_thrift::data::DataRecord;
|
||||
use bpr_thrift::prediction_service::BatchPredictionRequest;
|
||||
use thrift::OrderedFloat;
|
||||
|
||||
use thrift::protocol::TBinaryInputProtocol;
|
||||
use thrift::protocol::TSerializable;
|
||||
use thrift::transport::TBufferChannel;
|
||||
use thrift::Result;
|
||||
|
||||
fn main() {
|
||||
let data_path = "/tmp/current/timelines/output-1";
|
||||
let bin_data: Vec<u8> = std::fs::read(data_path).expect("Could not read file!");
|
||||
|
||||
println!("Length : {}", bin_data.len());
|
||||
|
||||
let mut bc = TBufferChannel::with_capacity(bin_data.len(), 0);
|
||||
|
||||
bc.set_readable_bytes(&bin_data);
|
||||
|
||||
let mut protocol = TBinaryInputProtocol::new(bc, true);
|
||||
|
||||
let result: Result<BatchPredictionRequest> =
|
||||
BatchPredictionRequest::read_from_in_protocol(&mut protocol);
|
||||
|
||||
match result {
|
||||
Ok(bpr) => logBP(bpr),
|
||||
Err(err) => println!("Error {}", err),
|
||||
}
|
||||
}
|
||||
|
||||
fn logBP(bpr: BatchPredictionRequest) {
|
||||
println!("-------[OUTPUT]---------------");
|
||||
println!("data {:?}", bpr);
|
||||
println!("------------------------------");
|
||||
|
||||
/*
|
||||
let common = bpr.common_features;
|
||||
let recs = bpr.individual_features_list;
|
||||
|
||||
println!("--------[Len : {}]------------------", recs.len());
|
||||
|
||||
println!("-------[COMMON]---------------");
|
||||
match common {
|
||||
Some(dr) => logDR(dr),
|
||||
None => println!("None"),
|
||||
}
|
||||
println!("------------------------------");
|
||||
for rec in recs {
|
||||
logDR(rec);
|
||||
}
|
||||
println!("------------------------------");
|
||||
*/
|
||||
}
|
||||
|
||||
fn logDR(dr: DataRecord) {
|
||||
println!("--------[DR]------------------");
|
||||
|
||||
match dr.binary_features {
|
||||
Some(bf) => logBin(bf),
|
||||
_ => (),
|
||||
}
|
||||
|
||||
match dr.continuous_features {
|
||||
Some(cf) => logCF(cf),
|
||||
_ => (),
|
||||
}
|
||||
println!("------------------------------");
|
||||
}
|
||||
|
||||
fn logBin(bin: BTreeSet<i64>) {
|
||||
println!("B: {:?}", bin)
|
||||
}
|
||||
|
||||
fn logCF(cf: BTreeMap<i64, OrderedFloat<f64>>) {
|
||||
for (id, fs) in cf {
|
||||
println!("C: {} -> [{}]", id, fs);
|
||||
}
|
||||
}
|
Binary file not shown.
File diff suppressed because it is too large
Binary file not shown.
File diff suppressed because it is too large
Binary file not shown.
|
@ -1,41 +0,0 @@
|
|||
Product Mixer
|
||||
=============
|
||||
|
||||
## Overview
|
||||
|
||||
Product Mixer is a common service framework and set of libraries that make it easy to build,
|
||||
iterate on, and own product surface areas. It consists of:
|
||||
|
||||
- **Core Libraries:** A set of libraries that enable you to build execution pipelines out of
|
||||
reusable components. You define your logic in small, well-defined, reusable components and focus
|
||||
on expressing the business logic you want to have. Then you can define easy to understand pipelines
|
||||
that compose your components. Product Mixer handles the execution and monitoring of your pipelines
|
||||
allowing you to focus on what really matters, your business logic.
|
||||
|
||||
- **Service Framework:** A common service skeleton for teams to host their Product Mixer products.
|
||||
|
||||
- **Component Library:** A shared library of components made by the Product Mixer Team, or
|
||||
contributed by users. This enables you to both easily share the reusable components you make as well
|
||||
as benefit from the work other teams have done by utilizing their shared components in the library.
|
||||
|
||||
## Architecture
|
||||
|
||||
The bulk of a Product Mixer can be broken down into Pipelines and Components. Components allow you
|
||||
to break business logic into separate, standardized, reusable, testable, and easily composable
|
||||
pieces, where each component has a well defined abstraction. Pipelines are essentially configuration
|
||||
files specifying which Components should be used and when. This makes it easy to understand how your
|
||||
code will execute while keeping it organized and structured in a maintainable way.
|
||||
|
||||
Requests first go to Product Pipelines, which are used to select which Mixer Pipeline or
|
||||
Recommendation Pipeline to run for a given request. Each Mixer or Recommendation
|
||||
Pipeline may run multiple Candidate Pipelines to fetch candidates to include in the response.
|
||||
|
||||
Mixer Pipelines combine the results of multiple heterogeneous Candidate Pipelines together
|
||||
(e.g. ads, tweets, users) while Recommendation Pipelines are used to score (via Scoring Pipelines)
|
||||
and rank the results of homogeneous Candidate Pipelines so that the top-ranked ones can be returned.
|
||||
These pipelines also marshall candidates into a domain object and then into a transport object
|
||||
to return to the caller.
|
||||
|
||||
Candidate Pipelines fetch candidates from underlying Candidate Sources and perform some basic
|
||||
operations on the Candidates, such as filtering out unwanted candidates, applying decorations,
|
||||
and hydrating features.
|
Binary file not shown.
|
@ -1,57 +0,0 @@
|
|||
package com.twitter.product_mixer.component_library.candidate_source.account_recommendations_mixer
|
||||
|
||||
import com.twitter.account_recommendations_mixer.{thriftscala => t}
|
||||
import com.twitter.product_mixer.component_library.model.candidate.UserCandidate
|
||||
import com.twitter.product_mixer.core.feature.Feature
|
||||
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
|
||||
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSourceWithExtractedFeatures
|
||||
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidatesWithSourceFeatures
|
||||
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
|
||||
import com.twitter.stitch.Stitch
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
object WhoToFollowModuleHeaderFeature extends Feature[UserCandidate, t.Header]
|
||||
object WhoToFollowModuleFooterFeature extends Feature[UserCandidate, Option[t.Footer]]
|
||||
object WhoToFollowModuleDisplayOptionsFeature
|
||||
extends Feature[UserCandidate, Option[t.DisplayOptions]]
|
||||
|
||||
@Singleton
|
||||
class AccountRecommendationsMixerCandidateSource @Inject() (
|
||||
accountRecommendationsMixer: t.AccountRecommendationsMixer.MethodPerEndpoint)
|
||||
extends CandidateSourceWithExtractedFeatures[
|
||||
t.AccountRecommendationsMixerRequest,
|
||||
t.RecommendedUser
|
||||
] {
|
||||
|
||||
override val identifier: CandidateSourceIdentifier =
|
||||
CandidateSourceIdentifier(name = "AccountRecommendationsMixer")
|
||||
|
||||
override def apply(
|
||||
request: t.AccountRecommendationsMixerRequest
|
||||
): Stitch[CandidatesWithSourceFeatures[t.RecommendedUser]] = {
|
||||
Stitch
|
||||
.callFuture(accountRecommendationsMixer.getWtfRecommendations(request))
|
||||
.map { response: t.WhoToFollowResponse =>
|
||||
responseToCandidatesWithSourceFeatures(
|
||||
response.userRecommendations,
|
||||
response.header,
|
||||
response.footer,
|
||||
response.displayOptions)
|
||||
}
|
||||
}
|
||||
|
||||
private def responseToCandidatesWithSourceFeatures(
|
||||
userRecommendations: Seq[t.RecommendedUser],
|
||||
header: t.Header,
|
||||
footer: Option[t.Footer],
|
||||
displayOptions: Option[t.DisplayOptions],
|
||||
): CandidatesWithSourceFeatures[t.RecommendedUser] = {
|
||||
val features = FeatureMapBuilder()
|
||||
.add(WhoToFollowModuleHeaderFeature, header)
|
||||
.add(WhoToFollowModuleFooterFeature, footer)
|
||||
.add(WhoToFollowModuleDisplayOptionsFeature, displayOptions)
|
||||
.build()
|
||||
CandidatesWithSourceFeatures(userRecommendations, features)
|
||||
}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
scala_library(
|
||||
sources = ["*.scala"],
|
||||
compiler_option_sets = ["fatal_warnings"],
|
||||
platform = "java8",
|
||||
strict_deps = True,
|
||||
tags = ["bazel-compatible"],
|
||||
dependencies = [
|
||||
"account-recommendations-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
|
||||
"product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/candidate",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/pipeline/pipeline_failure",
|
||||
"src/thrift/com/twitter/ads/adserver:adserver_common-scala",
|
||||
"stitch/stitch-core",
|
||||
],
|
||||
exports = [
|
||||
"account-recommendations-mixer/thrift/src/main/thrift:thrift-scala",
|
||||
"finatra/inject/inject-core/src/main/scala/com/twitter/inject",
|
||||
"product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
|
||||
"stitch/stitch-core",
|
||||
],
|
||||
)
|
Binary file not shown.
Binary file not shown.
|
@@ -1,29 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads

import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.AdRequestResponse
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyFetcherSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.ads.admixer.MakeAdRequestClientColumn
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class AdsProdStratoCandidateSource @Inject() (adsClient: MakeAdRequestClientColumn)
    extends StratoKeyFetcherSource[
      AdRequestParams,
      AdRequestResponse,
      AdImpression
    ] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("AdsProdStrato")

  override val fetcher: Fetcher[AdRequestParams, Unit, AdRequestResponse] = adsClient.fetcher

  override protected def stratoResultTransformer(
    stratoResult: AdRequestResponse
  ): Seq[AdImpression] =
    stratoResult.impressions
}
Binary file not shown
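AdsProdStratoCandidateSource (and AdsStagingCandidateSource below) follow the same StratoKeyFetcherSource recipe: declare an identifier, expose the column's fetcher, and flatten the fetched value into candidates. A hypothetical sketch of that recipe follows; MyAdsClientColumn is an imagined column, not a real client, and everything else mirrors the members shown above.

// Hypothetical sketch, not in this diff: the same three overrides against an
// imagined Strato column (MyAdsClientColumn is an assumption, not a real client).
@Singleton
class MyAdsCandidateSource @Inject() (adsClient: MyAdsClientColumn)
    extends StratoKeyFetcherSource[AdRequestParams, AdRequestResponse, AdImpression] {

  // Unique name used to identify this source within a pipeline.
  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("MyAds")

  // The column's fetcher has the shape Fetcher[Key, Unit, Result], as in the sources above.
  override val fetcher: Fetcher[AdRequestParams, Unit, AdRequestResponse] = adsClient.fetcher

  // Flatten the fetched response into one candidate per impression.
  override protected def stratoResultTransformer(
    stratoResult: AdRequestResponse
  ): Seq[AdImpression] = stratoResult.impressions
}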
@@ -1,22 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads

import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.NewAdServer
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class AdsProdThriftCandidateSource @Inject() (
  adServerClient: NewAdServer.MethodPerEndpoint)
    extends CandidateSource[AdRequestParams, AdImpression] {

  override val identifier: CandidateSourceIdentifier =
    CandidateSourceIdentifier("AdsProdThrift")

  override def apply(request: AdRequestParams): Stitch[Seq[AdImpression]] =
    Stitch.callFuture(adServerClient.makeAdRequest(request)).map(_.impressions)
}
Binary file not shown
@@ -1,28 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ads

import com.twitter.adserver.thriftscala.AdImpression
import com.twitter.adserver.thriftscala.AdRequestParams
import com.twitter.adserver.thriftscala.AdRequestResponse
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyFetcherSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.ads.admixer.MakeAdRequestStagingClientColumn
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class AdsStagingCandidateSource @Inject() (adsClient: MakeAdRequestStagingClientColumn)
    extends StratoKeyFetcherSource[
      AdRequestParams,
      AdRequestResponse,
      AdImpression
    ] {
  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("AdsStaging")

  override val fetcher: Fetcher[AdRequestParams, Unit, AdRequestResponse] = adsClient.fetcher

  override protected def stratoResultTransformer(
    stratoResult: AdRequestResponse
  ): Seq[AdImpression] =
    stratoResult.impressions
}
@@ -1,18 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/javax/inject:javax.inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/ads/adserver:adserver_common-scala",
        "src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
        "strato/config/columns/ads/admixer:admixer-strato-client",
    ],
    exports = [
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/ads/adserver:adserver_common-scala",
        "src/thrift/com/twitter/ads/adserver:adserver_rpc-scala",
    ],
)
Binary file not shown
Binary file not shown
@@ -1,43 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ann

import com.twitter.ann.common._
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import com.twitter.util.{Time => _, _}
import com.twitter.finagle.util.DefaultTimer

/**
 * @param annQueryableById Ann Queryable by Id client that returns nearest neighbors for a sequence of queries
 * @param identifier Candidate Source Identifier
 * @tparam T1 type of the query.
 * @tparam T2 type of the result.
 * @tparam P runtime parameters supported by the index.
 * @tparam D distance function used in the index.
 */
class AnnCandidateSource[T1, T2, P <: RuntimeParams, D <: Distance[D]](
  val annQueryableById: QueryableById[T1, T2, P, D],
  val batchSize: Int,
  val timeoutPerRequest: Duration,
  override val identifier: CandidateSourceIdentifier)
    extends CandidateSource[AnnIdQuery[T1, P], NeighborWithDistanceWithSeed[T1, T2, D]] {

  implicit val timer = DefaultTimer

  override def apply(
    request: AnnIdQuery[T1, P]
  ): Stitch[Seq[NeighborWithDistanceWithSeed[T1, T2, D]]] = {
    val ids = request.ids
    val numOfNeighbors = request.numOfNeighbors
    val runtimeParams = request.runtimeParams
    Stitch
      .collect(
        ids
          .grouped(batchSize).map { batchedIds =>
            annQueryableById
              .batchQueryWithDistanceById(batchedIds, numOfNeighbors, runtimeParams).map {
                annResult => annResult.toSeq
              }.within(timeoutPerRequest).handle { case _ => Seq.empty }
          }.toSeq).map(_.flatten)
  }
}
Binary file not shown
@@ -1,18 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.ann

import com.twitter.ann.common._

/**
 * A [[AnnIdQuery]] is a query class which defines the ann entities with runtime params and number of neighbors requested
 *
 * @param ids Sequence of queries
 * @param numOfNeighbors Number of neighbors requested
 * @param runtimeParams ANN Runtime Params
 * @param batchSize Batch size to the stitch client
 * @tparam T type of query.
 * @tparam P runtime parameters supported by the index.
 */
case class AnnIdQuery[T, P <: RuntimeParams](
  ids: Seq[T],
  numOfNeighbors: Int,
  runtimeParams: P)
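Note that although the scaladoc above mentions a batchSize, the query case class does not carry one; batching and the per-request timeout are configured on AnnCandidateSource, which groups the query ids, fans out one ANN call per batch, and replaces any batch that fails or times out with an empty result. A minimal sketch (not from this diff) of wiring the two together, with all index-specific types left abstract and the imports assumed to match AnnCandidateSource.scala above:

// Minimal sketch, not part of this diff. All type parameters are left abstract so
// nothing about a concrete ANN index is assumed; only constructors shown above are used.
def buildAnnSource[T1, T2, P <: RuntimeParams, D <: Distance[D]](
  client: QueryableById[T1, T2, P, D]
): AnnCandidateSource[T1, T2, P, D] =
  new AnnCandidateSource(
    annQueryableById = client,
    batchSize = 10,                                      // query ids per batched ANN call
    timeoutPerRequest = Duration.fromMilliseconds(200),  // slow batches fall back to Seq.empty
    identifier = CandidateSourceIdentifier("MyAnnSource"))

def buildAnnQuery[T, P <: RuntimeParams](ids: Seq[T], k: Int, params: P): AnnIdQuery[T, P] =
  AnnIdQuery(ids = ids, numOfNeighbors = k, runtimeParams = params)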
@@ -1,17 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "ann/src/main/scala/com/twitter/ann/common",
        "ann/src/main/scala/com/twitter/ann/hnsw",
        "ann/src/main/thrift/com/twitter/ann/common:ann-common-scala",
        "product-mixer/component-library/src/main/thrift/com/twitter/product_mixer/component_library:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "servo/manhattan/src/main/scala",
        "servo/repo/src/main/scala",
        "servo/util/src/main/scala",
        "stitch/stitch-core",
    ],
)
Binary file not shown
@@ -1,14 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/cursor",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/periscope/audio_space:audio_space-scala",
        "strato/config/columns/periscope:periscope-strato-client",
        "strato/config/src/thrift/com/twitter/strato/graphql:graphql-scala",
        "strato/src/main/scala/com/twitter/strato/client",
    ],
)
Binary file not shown
Binary file not shown
@@ -1,49 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.audiospace

import com.twitter.periscope.audio_space.thriftscala.CreatedSpacesView
import com.twitter.periscope.audio_space.thriftscala.SpaceSlice
import com.twitter.product_mixer.component_library.model.cursor.NextCursorFeature
import com.twitter.product_mixer.component_library.model.cursor.PreviousCursorFeature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyViewFetcherWithSourceFeaturesSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.periscope.CreatedSpacesSliceOnUserClientColumn
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class CreatedSpacesCandidateSource @Inject() (
  column: CreatedSpacesSliceOnUserClientColumn)
    extends StratoKeyViewFetcherWithSourceFeaturesSource[
      Long,
      CreatedSpacesView,
      SpaceSlice,
      String
    ] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("CreatedSpaces")

  override val fetcher: Fetcher[Long, CreatedSpacesView, SpaceSlice] = column.fetcher

  override def stratoResultTransformer(
    stratoKey: Long,
    stratoResult: SpaceSlice
  ): Seq[String] =
    stratoResult.items

  override protected def extractFeaturesFromStratoResult(
    stratoKey: Long,
    stratoResult: SpaceSlice
  ): FeatureMap = {
    val featureMapBuilder = FeatureMapBuilder()
    stratoResult.sliceInfo.previousCursor.foreach { cursor =>
      featureMapBuilder.add(PreviousCursorFeature, cursor)
    }
    stratoResult.sliceInfo.nextCursor.foreach { cursor =>
      featureMapBuilder.add(NextCursorFeature, cursor)
    }
    featureMapBuilder.build()
  }
}
@@ -1,13 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "product-mixer/component-library/src/main/scala/com/twitter/product_mixer/component_library/model/cursor",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "strato/config/columns/consumer-identity/business-profiles:business-profiles-strato-client",
        "strato/config/src/thrift/com/twitter/strato/graphql:graphql-scala",
        "strato/src/main/scala/com/twitter/strato/client",
    ],
)
Binary file not shown
Binary file not shown
@@ -1,53 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.business_profiles

import com.twitter.product_mixer.component_library.model.cursor.NextCursorFeature
import com.twitter.product_mixer.component_library.model.cursor.PreviousCursorFeature
import com.twitter.product_mixer.core.feature.featuremap.FeatureMap
import com.twitter.product_mixer.core.feature.featuremap.FeatureMapBuilder
import com.twitter.product_mixer.core.functional_component.candidate_source.strato.StratoKeyViewFetcherWithSourceFeaturesSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.strato.client.Fetcher
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn.{
  Value => TeamMembersSlice
}
import com.twitter.strato.generated.client.consumer_identity.business_profiles.BusinessProfileTeamMembersOnUserClientColumn.{
  View => TeamMembersView
}
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class TeamMembersCandidateSource @Inject() (
  column: BusinessProfileTeamMembersOnUserClientColumn)
    extends StratoKeyViewFetcherWithSourceFeaturesSource[
      Long,
      TeamMembersView,
      TeamMembersSlice,
      Long
    ] {
  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier(
    "BusinessProfileTeamMembers")

  override val fetcher: Fetcher[Long, TeamMembersView, TeamMembersSlice] = column.fetcher

  override def stratoResultTransformer(
    stratoKey: Long,
    stratoResult: TeamMembersSlice
  ): Seq[Long] =
    stratoResult.members

  override protected def extractFeaturesFromStratoResult(
    stratoKey: Long,
    stratoResult: TeamMembersSlice
  ): FeatureMap = {
    val featureMapBuilder = FeatureMapBuilder()
    stratoResult.previousCursor.foreach { cursor =>
      featureMapBuilder.add(PreviousCursorFeature, cursor.toString)
    }
    stratoResult.nextCursor.foreach { cursor =>
      featureMapBuilder.add(NextCursorFeature, cursor.toString)
    }
    featureMapBuilder.build()
  }
}
@@ -1,12 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "cr-mixer/thrift/src/main/thrift:thrift-scala",
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "stitch/stitch-core",
    ],
)
Binary file not shown
Binary file not shown
@@ -1,25 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.cr_mixer

import com.twitter.cr_mixer.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

/**
 * Returns out-of-network Tweet recommendations by using user recommendations
 * from FollowRecommendationService as an input seed-set to Earlybird
 */
@Singleton
class CrMixerFrsBasedTweetRecommendationsCandidateSource @Inject() (
  crMixerClient: t.CrMixer.MethodPerEndpoint)
    extends CandidateSource[t.FrsTweetRequest, t.FrsTweet] {

  override val identifier: CandidateSourceIdentifier =
    CandidateSourceIdentifier("CrMixerFrsBasedTweetRecommendations")

  override def apply(request: t.FrsTweetRequest): Stitch[Seq[t.FrsTweet]] = Stitch
    .callFuture(crMixerClient.getFrsBasedTweetRecommendations(request))
    .map(_.tweets)
}
Binary file not shown
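Per the scaladoc, this source treats FollowRecommendationService users as a seed set and lets CrMixer expand them into out-of-network Tweets via Earlybird. A minimal usage sketch follows (not part of this diff); the FrsTweetRequest is assumed to be constructed elsewhere and only the apply signature shown above is used.

// Minimal sketch, not in this diff: invoke the source and keep only the number of
// returned candidates. FrsTweetRequest construction is assumed to happen elsewhere.
def countFrsTweets(
  source: CrMixerFrsBasedTweetRecommendationsCandidateSource,
  request: t.FrsTweetRequest
): Stitch[Int] =
  source(request).map(_.size)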
@@ -1,21 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.cr_mixer

import com.twitter.cr_mixer.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class CrMixerTweetRecommendationsCandidateSource @Inject() (
  crMixerClient: t.CrMixer.MethodPerEndpoint)
    extends CandidateSource[t.CrMixerTweetRequest, t.TweetRecommendation] {

  override val identifier: CandidateSourceIdentifier =
    CandidateSourceIdentifier("CrMixerTweetRecommendations")

  override def apply(request: t.CrMixerTweetRequest): Stitch[Seq[t.TweetRecommendation]] = Stitch
    .callFuture(crMixerClient.getTweetRecommendations(request))
    .map(_.tweets)
}
@@ -1,12 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "src/thrift/com/twitter/search:earlybird-scala",
        "stitch/stitch-core",
    ],
)
Binary file not shown
Binary file not shown
@@ -1,26 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.earlybird

import com.twitter.search.earlybird.{thriftscala => t}
import com.twitter.inject.Logging
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class EarlybirdTweetCandidateSource @Inject() (
  earlybirdService: t.EarlybirdService.MethodPerEndpoint)
    extends CandidateSource[t.EarlybirdRequest, t.ThriftSearchResult]
    with Logging {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("EarlybirdTweets")

  override def apply(request: t.EarlybirdRequest): Stitch[Seq[t.ThriftSearchResult]] = {
    Stitch
      .callFuture(earlybirdService.search(request))
      .map { response: t.EarlybirdResponse =>
        response.searchResults.map(_.results).getOrElse(Seq.empty)
      }
  }
}
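The apply above defends against an absent searchResults field by collapsing it to an empty sequence rather than failing the pipeline. A self-contained illustration of that Option handling (plain Scala, not Twitter code):

// Plain-Scala illustration of response.searchResults.map(_.results).getOrElse(Seq.empty).
case class SearchResults(results: Seq[Int])

def extractResults(searchResults: Option[SearchResults]): Seq[Int] =
  searchResults.map(_.results).getOrElse(Seq.empty)

// extractResults(Some(SearchResults(Seq(1, 2, 3)))) == Seq(1, 2, 3)
// extractResults(None)                              == Seq.empty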
@@ -1,12 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/javax/inject:javax.inject",
        "explore/explore-ranker/thrift/src/main/thrift:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "stitch/stitch-core",
    ],
)
Binary file not shown
Binary file not shown
@@ -1,31 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.explore_ranker

import com.twitter.explore_ranker.{thriftscala => t}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

@Singleton
class ExploreRankerCandidateSource @Inject() (
  exploreRankerService: t.ExploreRanker.MethodPerEndpoint)
    extends CandidateSource[t.ExploreRankerRequest, t.ImmersiveRecsResult] {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier("ExploreRanker")

  override def apply(
    request: t.ExploreRankerRequest
  ): Stitch[Seq[t.ImmersiveRecsResult]] = {
    Stitch
      .callFuture(exploreRankerService.getRankedResults(request))
      .map {
        case t.ExploreRankerResponse(
              t.ExploreRankerProductResponse
                .ImmersiveRecsResponse(t.ImmersiveRecsResponse(immersiveRecsResults))) =>
          immersiveRecsResults
        case response =>
          throw new UnsupportedOperationException(s"Unknown response type: $response")
      }
  }
}
@@ -1,17 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    platform = "java8",
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "finatra/inject/inject-core/src/main/scala/com/twitter/inject",
        "onboarding/service/thrift/src/main/thrift:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
        "stitch/stitch-core",
    ],
    exports = [
        "onboarding/service/thrift/src/main/thrift:thrift-scala",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source",
    ],
)
Binary file not shown
Binary file not shown
@@ -1,50 +0,0 @@
package com.twitter.product_mixer.component_library.candidate_source.flexible_injection_pipeline

import com.twitter.inject.Logging
import com.twitter.onboarding.injections.{thriftscala => injectionsthrift}
import com.twitter.onboarding.task.service.{thriftscala => servicethrift}
import com.twitter.product_mixer.core.functional_component.candidate_source.CandidateSource
import com.twitter.product_mixer.core.model.common.identifier.CandidateSourceIdentifier
import com.twitter.stitch.Stitch
import javax.inject.Inject
import javax.inject.Singleton

/**
 * Returns a list of prompts to insert into a user's timeline (inline prompt, cover modals, etc)
 * from go/flip (the prompting platform for Twitter).
 */
@Singleton
class PromptCandidateSource @Inject() (taskService: servicethrift.TaskService.MethodPerEndpoint)
    extends CandidateSource[servicethrift.GetInjectionsRequest, IntermediatePrompt]
    with Logging {

  override val identifier: CandidateSourceIdentifier = CandidateSourceIdentifier(
    "InjectionPipelinePrompts")

  override def apply(
    request: servicethrift.GetInjectionsRequest
  ): Stitch[Seq[IntermediatePrompt]] = {
    Stitch
      .callFuture(taskService.getInjections(request)).map {
        _.injections.flatMap {
          // The entire carousel is getting added to each IntermediatePrompt item with a
          // corresponding index to be unpacked later on to populate its TimelineEntry counterpart.
          case injection: injectionsthrift.Injection.TilesCarousel =>
            injection.tilesCarousel.tiles.zipWithIndex.map {
              case (tile: injectionsthrift.Tile, index: Int) =>
                IntermediatePrompt(injection, Some(index), Some(tile))
            }
          case injection => Seq(IntermediatePrompt(injection, None, None))
        }
      }
  }
}

/**
 * Gives an intermediate step to help 'explosion' of tile carousel tiles due to TimelineModule
 * not being an extension of TimelineItem
 */
case class IntermediatePrompt(
  injection: injectionsthrift.Injection,
  offsetInModule: Option[Int],
  carouselTile: Option[injectionsthrift.Tile])
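The flatMap above "explodes" a tiles carousel: each tile becomes its own IntermediatePrompt that keeps the whole injection plus its offset, while every other injection maps to a single prompt with no offset or tile. A self-contained illustration of the zipWithIndex step (plain Scala, not Twitter code):

// Plain-Scala illustration of the explosion: every element is paired with its
// position so it can later be reassembled into the right slot of the module.
def explode[A](tiles: Seq[A]): Seq[(A, Int)] = tiles.zipWithIndex

// explode(Seq("tileA", "tileB", "tileC")) == Seq(("tileA", 0), ("tileB", 1), ("tileC", 2))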
@@ -1,16 +0,0 @@
scala_library(
    sources = ["*.scala"],
    compiler_option_sets = ["fatal_warnings"],
    strict_deps = True,
    tags = ["bazel-compatible"],
    dependencies = [
        "3rdparty/jvm/javax/inject:javax.inject",
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/hermit:hermit-scala",
        "strato/config/columns/onboarding:onboarding-strato-client",
    ],
    exports = [
        "product-mixer/core/src/main/scala/com/twitter/product_mixer/core/functional_component/candidate_source/strato",
        "src/thrift/com/twitter/hermit:hermit-scala",
    ],
)
Binary file not shown
Some files were not shown because too many files have changed in this diff.