improvements from external prs
-fix corner case where dr converter failed when initializing Closes twitter/the-algorithm#550
This commit is contained in:
parent
23fa75d406
commit
31e82d6474
|
@ -44,5 +44,6 @@ pub struct RenamedFeatures {
|
|||
}
|
||||
|
||||
pub fn parse(json_str: &str) -> Result<AllConfig, Error> {
|
||||
serde_json::from_str(json_str)
|
||||
let all_config: AllConfig = serde_json::from_str(json_str)?;
|
||||
Ok(all_config)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,9 @@ use std::collections::BTreeSet;
|
|||
use std::fmt::{self, Debug, Display};
|
||||
use std::fs;
|
||||
|
||||
use crate::all_config;
|
||||
use crate::all_config::AllConfig;
|
||||
use anyhow::{bail, Context};
|
||||
use bpr_thrift::data::DataRecord;
|
||||
use bpr_thrift::prediction_service::BatchPredictionRequest;
|
||||
use bpr_thrift::tensor::GeneralTensor;
|
||||
|
@ -16,8 +19,6 @@ use segdense::util;
|
|||
use thrift::protocol::{TBinaryInputProtocol, TSerializable};
|
||||
use thrift::transport::TBufferChannel;
|
||||
|
||||
use crate::{all_config, all_config::AllConfig};
|
||||
|
||||
pub fn log_feature_match(
|
||||
dr: &DataRecord,
|
||||
seg_dense_config: &DensificationTransformSpec,
|
||||
|
@ -28,20 +29,24 @@ pub fn log_feature_match(
|
|||
|
||||
for (feature_id, feature_value) in dr.continuous_features.as_ref().unwrap() {
|
||||
debug!(
|
||||
"{dr_type} - Continuous Datarecord => Feature ID: {feature_id}, Feature value: {feature_value}"
|
||||
"{} - Continous Datarecord => Feature ID: {}, Feature value: {}",
|
||||
dr_type, feature_id, feature_value
|
||||
);
|
||||
for input_feature in &seg_dense_config.cont.input_features {
|
||||
if input_feature.feature_id == *feature_id {
|
||||
debug!("Matching input feature: {input_feature:?}")
|
||||
debug!("Matching input feature: {:?}", input_feature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for feature_id in dr.binary_features.as_ref().unwrap() {
|
||||
debug!("{dr_type} - Binary Datarecord => Feature ID: {feature_id}");
|
||||
debug!(
|
||||
"{} - Binary Datarecord => Feature ID: {}",
|
||||
dr_type, feature_id
|
||||
);
|
||||
for input_feature in &seg_dense_config.binary.input_features {
|
||||
if input_feature.feature_id == *feature_id {
|
||||
debug!("Found input feature: {input_feature:?}")
|
||||
debug!("Found input feature: {:?}", input_feature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -90,18 +95,19 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
model_version: &str,
|
||||
reporting_feature_ids: Vec<(i64, &str)>,
|
||||
register_metric_fn: Option<impl Fn(&HistogramVec)>,
|
||||
) -> BatchPredictionRequestToTorchTensorConverter {
|
||||
let all_config_path = format!("{model_dir}/{model_version}/all_config.json");
|
||||
let seg_dense_config_path =
|
||||
format!("{model_dir}/{model_version}/segdense_transform_spec_home_recap_2022.json");
|
||||
let seg_dense_config = util::load_config(&seg_dense_config_path);
|
||||
) -> anyhow::Result<BatchPredictionRequestToTorchTensorConverter> {
|
||||
let all_config_path = format!("{}/{}/all_config.json", model_dir, model_version);
|
||||
let seg_dense_config_path = format!(
|
||||
"{}/{}/segdense_transform_spec_home_recap_2022.json",
|
||||
model_dir, model_version
|
||||
);
|
||||
let seg_dense_config = util::load_config(&seg_dense_config_path)?;
|
||||
let all_config = all_config::parse(
|
||||
&fs::read_to_string(&all_config_path)
|
||||
.unwrap_or_else(|error| panic!("error loading all_config.json - {error}")),
|
||||
)
|
||||
.unwrap();
|
||||
.with_context(|| "error loading all_config.json - ")?,
|
||||
)?;
|
||||
|
||||
let feature_mapper = util::load_from_parsed_config_ref(&seg_dense_config);
|
||||
let feature_mapper = util::load_from_parsed_config(seg_dense_config.clone())?;
|
||||
|
||||
let user_embedding_feature_id = Self::get_feature_id(
|
||||
&all_config
|
||||
|
@ -131,11 +137,11 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
let (discrete_feature_metrics, continuous_feature_metrics) = METRICS.get_or_init(|| {
|
||||
let discrete = HistogramVec::new(
|
||||
HistogramOpts::new(":navi:feature_id:discrete", "Discrete Feature ID values")
|
||||
.buckets(Vec::from([
|
||||
0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0,
|
||||
.buckets(Vec::from(&[
|
||||
0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0,
|
||||
120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0,
|
||||
300.0, 500.0, 1000.0, 10000.0, 100000.0,
|
||||
])),
|
||||
] as &'static [f64])),
|
||||
&["feature_id"],
|
||||
)
|
||||
.expect("metric cannot be created");
|
||||
|
@ -144,18 +150,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
":navi:feature_id:continuous",
|
||||
"continuous Feature ID values",
|
||||
)
|
||||
.buckets(Vec::from([
|
||||
0.0f64, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0,
|
||||
120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0,
|
||||
500.0, 1000.0, 10000.0, 100000.0,
|
||||
])),
|
||||
.buckets(Vec::from(&[
|
||||
0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
|
||||
130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 250.0, 300.0, 500.0,
|
||||
1000.0, 10000.0, 100000.0,
|
||||
] as &'static [f64])),
|
||||
&["feature_id"],
|
||||
)
|
||||
.expect("metric cannot be created");
|
||||
if let Some(r) = register_metric_fn {
|
||||
register_metric_fn.map(|r| {
|
||||
r(&discrete);
|
||||
r(&continuous);
|
||||
}
|
||||
});
|
||||
(discrete, continuous)
|
||||
});
|
||||
|
||||
|
@ -164,13 +170,16 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
|
||||
for (feature_id, feature_type) in reporting_feature_ids.iter() {
|
||||
match *feature_type {
|
||||
"discrete" => discrete_features_to_report.insert(*feature_id),
|
||||
"continuous" => continuous_features_to_report.insert(*feature_id),
|
||||
_ => panic!("Invalid feature type {feature_type} for reporting metrics!"),
|
||||
"discrete" => discrete_features_to_report.insert(feature_id.clone()),
|
||||
"continuous" => continuous_features_to_report.insert(feature_id.clone()),
|
||||
_ => bail!(
|
||||
"Invalid feature type {} for reporting metrics!",
|
||||
feature_type
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
BatchPredictionRequestToTorchTensorConverter {
|
||||
Ok(BatchPredictionRequestToTorchTensorConverter {
|
||||
all_config,
|
||||
seg_dense_config,
|
||||
all_config_path,
|
||||
|
@ -183,7 +192,7 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
continuous_features_to_report,
|
||||
discrete_feature_metrics,
|
||||
continuous_feature_metrics,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_feature_id(feature_name: &str, seg_dense_config: &Root) -> i64 {
|
||||
|
@ -218,43 +227,45 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
let mut working_set = vec![0 as f32; total_size];
|
||||
let mut bpr_start = 0;
|
||||
for (bpr, &bpr_end) in bprs.iter().zip(batch_size) {
|
||||
if bpr.common_features.is_some()
|
||||
&& bpr.common_features.as_ref().unwrap().tensors.is_some()
|
||||
&& bpr
|
||||
.common_features
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.tensors
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains_key(&feature_id)
|
||||
{
|
||||
let source_tensor = bpr
|
||||
.common_features
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.tensors
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.get(&feature_id)
|
||||
.unwrap();
|
||||
let tensor = match source_tensor {
|
||||
GeneralTensor::FloatTensor(float_tensor) =>
|
||||
//Tensor::of_slice(
|
||||
if bpr.common_features.is_some() {
|
||||
if bpr.common_features.as_ref().unwrap().tensors.is_some() {
|
||||
if bpr
|
||||
.common_features
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.tensors
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.contains_key(&feature_id)
|
||||
{
|
||||
float_tensor
|
||||
.floats
|
||||
.iter()
|
||||
.map(|x| x.into_inner() as f32)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
_ => vec![0 as f32; cols],
|
||||
};
|
||||
let source_tensor = bpr
|
||||
.common_features
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.tensors
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.get(&feature_id)
|
||||
.unwrap();
|
||||
let tensor = match source_tensor {
|
||||
GeneralTensor::FloatTensor(float_tensor) =>
|
||||
//Tensor::of_slice(
|
||||
{
|
||||
float_tensor
|
||||
.floats
|
||||
.iter()
|
||||
.map(|x| x.into_inner() as f32)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
_ => vec![0 as f32; cols],
|
||||
};
|
||||
|
||||
// since the tensor is found in common feature, add it in all batches
|
||||
for row in bpr_start..bpr_end {
|
||||
for col in 0..cols {
|
||||
working_set[row * cols + col] = tensor[col];
|
||||
// since the tensor is found in common feature, add it in all batches
|
||||
for row in bpr_start..bpr_end {
|
||||
for col in 0..cols {
|
||||
working_set[row * cols + col] = tensor[col];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -298,9 +309,9 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
// (INT64 --> INT64, DataRecord.discrete_feature)
|
||||
fn get_continuous(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
||||
// These need to be part of model schema
|
||||
let rows = batch_ends[batch_ends.len() - 1];
|
||||
let cols = 5293;
|
||||
let full_size = rows * cols;
|
||||
let rows: usize = batch_ends[batch_ends.len() - 1];
|
||||
let cols: usize = 5293;
|
||||
let full_size: usize = rows * cols;
|
||||
let default_val = f32::NAN;
|
||||
|
||||
let mut tensor = vec![default_val; full_size];
|
||||
|
@ -325,15 +336,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
.unwrap();
|
||||
|
||||
for feature in common_features {
|
||||
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
if idx < cols {
|
||||
// Set value in each row
|
||||
for r in bpr_start..bpr_end {
|
||||
let flat_index = r * cols + idx;
|
||||
tensor[flat_index] = feature.1.into_inner() as f32;
|
||||
match self.feature_mapper.get(feature.0) {
|
||||
Some(f_info) => {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
if idx < cols {
|
||||
// Set value in each row
|
||||
for r in bpr_start..bpr_end {
|
||||
let flat_index: usize = r * cols + idx;
|
||||
tensor[flat_index] = feature.1.into_inner() as f32;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
if self.continuous_features_to_report.contains(feature.0) {
|
||||
self.continuous_feature_metrics
|
||||
|
@ -349,24 +363,28 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
|
||||
// Process the batch of datarecords
|
||||
for r in bpr_start..bpr_end {
|
||||
let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
|
||||
let dr: &DataRecord =
|
||||
&bpr.individual_features_list[usize::try_from(r - bpr_start).unwrap()];
|
||||
if dr.continuous_features.is_some() {
|
||||
for feature in dr.continuous_features.as_ref().unwrap() {
|
||||
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
let flat_index = r * cols + idx;
|
||||
if flat_index < tensor.len() && idx < cols {
|
||||
tensor[flat_index] = feature.1.into_inner() as f32;
|
||||
match self.feature_mapper.get(&feature.0) {
|
||||
Some(f_info) => {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
let flat_index: usize = r * cols + idx;
|
||||
if flat_index < tensor.len() && idx < cols {
|
||||
tensor[flat_index] = feature.1.into_inner() as f32;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
if self.continuous_features_to_report.contains(feature.0) {
|
||||
self.continuous_feature_metrics
|
||||
.with_label_values(&[feature.0.to_string().as_str()])
|
||||
.observe(feature.1.into_inner())
|
||||
.observe(feature.1.into_inner() as f64)
|
||||
} else if self.discrete_features_to_report.contains(feature.0) {
|
||||
self.discrete_feature_metrics
|
||||
.with_label_values(&[feature.0.to_string().as_str()])
|
||||
.observe(feature.1.into_inner())
|
||||
.observe(feature.1.into_inner() as f64)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -383,10 +401,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
|
||||
fn get_binary(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
||||
// These need to be part of model schema
|
||||
let rows = batch_ends[batch_ends.len() - 1];
|
||||
let cols = 149;
|
||||
let full_size = rows * cols;
|
||||
let default_val = 0;
|
||||
let rows: usize = batch_ends[batch_ends.len() - 1];
|
||||
let cols: usize = 149;
|
||||
let full_size: usize = rows * cols;
|
||||
let default_val: i64 = 0;
|
||||
|
||||
let mut v = vec![default_val; full_size];
|
||||
|
||||
|
@ -410,15 +428,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
.unwrap();
|
||||
|
||||
for feature in common_features {
|
||||
if let Some(f_info) = self.feature_mapper.get(feature) {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
if idx < cols {
|
||||
// Set value in each row
|
||||
for r in bpr_start..bpr_end {
|
||||
let flat_index = r * cols + idx;
|
||||
v[flat_index] = 1;
|
||||
match self.feature_mapper.get(feature) {
|
||||
Some(f_info) => {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
if idx < cols {
|
||||
// Set value in each row
|
||||
for r in bpr_start..bpr_end {
|
||||
let flat_index: usize = r * cols + idx;
|
||||
v[flat_index] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -428,10 +449,13 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
let dr: &DataRecord = &bpr.individual_features_list[r - bpr_start];
|
||||
if dr.binary_features.is_some() {
|
||||
for feature in dr.binary_features.as_ref().unwrap() {
|
||||
if let Some(f_info) = self.feature_mapper.get(feature) {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
let flat_index = r * cols + idx;
|
||||
v[flat_index] = 1;
|
||||
match self.feature_mapper.get(&feature) {
|
||||
Some(f_info) => {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
let flat_index: usize = r * cols + idx;
|
||||
v[flat_index] = 1;
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -448,10 +472,10 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
#[allow(dead_code)]
|
||||
fn get_discrete(&self, bprs: &[BatchPredictionRequest], batch_ends: &[usize]) -> InputTensor {
|
||||
// These need to be part of model schema
|
||||
let rows = batch_ends[batch_ends.len() - 1];
|
||||
let cols = 320;
|
||||
let full_size = rows * cols;
|
||||
let default_val = 0;
|
||||
let rows: usize = batch_ends[batch_ends.len() - 1];
|
||||
let cols: usize = 320;
|
||||
let full_size: usize = rows * cols;
|
||||
let default_val: i64 = 0;
|
||||
|
||||
let mut v = vec![default_val; full_size];
|
||||
|
||||
|
@ -475,15 +499,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
.unwrap();
|
||||
|
||||
for feature in common_features {
|
||||
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
if idx < cols {
|
||||
// Set value in each row
|
||||
for r in bpr_start..bpr_end {
|
||||
let flat_index = r * cols + idx;
|
||||
v[flat_index] = *feature.1;
|
||||
match self.feature_mapper.get(feature.0) {
|
||||
Some(f_info) => {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
if idx < cols {
|
||||
// Set value in each row
|
||||
for r in bpr_start..bpr_end {
|
||||
let flat_index: usize = r * cols + idx;
|
||||
v[flat_index] = *feature.1;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
if self.discrete_features_to_report.contains(feature.0) {
|
||||
self.discrete_feature_metrics
|
||||
|
@ -495,15 +522,18 @@ impl BatchPredictionRequestToTorchTensorConverter {
|
|||
|
||||
// Process the batch of datarecords
|
||||
for r in bpr_start..bpr_end {
|
||||
let dr: &DataRecord = &bpr.individual_features_list[r];
|
||||
let dr: &DataRecord = &bpr.individual_features_list[usize::try_from(r).unwrap()];
|
||||
if dr.discrete_features.is_some() {
|
||||
for feature in dr.discrete_features.as_ref().unwrap() {
|
||||
if let Some(f_info) = self.feature_mapper.get(feature.0) {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
let flat_index = r * cols + idx;
|
||||
if flat_index < v.len() && idx < cols {
|
||||
v[flat_index] = *feature.1;
|
||||
match self.feature_mapper.get(&feature.0) {
|
||||
Some(f_info) => {
|
||||
let idx = f_info.index_within_tensor as usize;
|
||||
let flat_index: usize = r * cols + idx;
|
||||
if flat_index < v.len() && idx < cols {
|
||||
v[flat_index] = *feature.1;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
if self.discrete_features_to_report.contains(feature.0) {
|
||||
self.discrete_feature_metrics
|
||||
|
@ -569,7 +599,7 @@ impl Converter for BatchPredictionRequestToTorchTensorConverter {
|
|||
.map(|bpr| bpr.individual_features_list.len())
|
||||
.scan(0usize, |acc, e| {
|
||||
//running total
|
||||
*acc += e;
|
||||
*acc = *acc + e;
|
||||
Some(*acc)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
|
|
@ -122,7 +122,7 @@ enum FullTypeId {
|
|||
// TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
|
||||
// is a Tensor of int32 element type and unknown shape.
|
||||
//
|
||||
// TODO: Define TFT_SHAPE and add more examples.
|
||||
// TODO(mdan): Define TFT_SHAPE and add more examples.
|
||||
TFT_TENSOR = 1000;
|
||||
|
||||
// Array (or tensorflow::TensorList in the variant type registry).
|
||||
|
@ -178,7 +178,7 @@ enum FullTypeId {
|
|||
// object (for now).
|
||||
|
||||
// The bool element type.
|
||||
// TODO
|
||||
// TODO(mdan): Quantized types, legacy representations (e.g. ref)
|
||||
TFT_BOOL = 200;
|
||||
// Integer element types.
|
||||
TFT_UINT8 = 201;
|
||||
|
@ -195,7 +195,7 @@ enum FullTypeId {
|
|||
TFT_DOUBLE = 211;
|
||||
TFT_BFLOAT16 = 215;
|
||||
// Complex element types.
|
||||
// TODO: Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
|
||||
// TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
|
||||
TFT_COMPLEX64 = 212;
|
||||
TFT_COMPLEX128 = 213;
|
||||
// The string element type.
|
||||
|
@ -240,7 +240,7 @@ enum FullTypeId {
|
|||
// ownership is in the true sense: "the op argument representing the lock is
|
||||
// available".
|
||||
// Mutex locks are the dynamic counterpart of control dependencies.
|
||||
// TODO: Properly document this thing.
|
||||
// TODO(mdan): Properly document this thing.
|
||||
//
|
||||
// Parametrization: TFT_MUTEX_LOCK[].
|
||||
TFT_MUTEX_LOCK = 10202;
|
||||
|
@ -271,6 +271,6 @@ message FullTypeDef {
|
|||
oneof attr {
|
||||
string s = 3;
|
||||
int64 i = 4;
|
||||
// TODO: list/tensor, map? Need to reconcile with TFT_RECORD, etc.
|
||||
// TODO(mdan): list/tensor, map? Need to reconcile with TFT_RECORD, etc.
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ message FunctionDefLibrary {
|
|||
// with a value. When a GraphDef has a call to a function, it must
|
||||
// have binding for every attr defined in the signature.
|
||||
//
|
||||
// TODO:
|
||||
// TODO(zhifengc):
|
||||
// * device spec, etc.
|
||||
message FunctionDef {
|
||||
// The definition of the function's name, arguments, return values,
|
||||
|
|
|
@ -61,7 +61,7 @@ message NodeDef {
|
|||
// one of the names from the corresponding OpDef's attr field).
|
||||
// The values must have a type matching the corresponding OpDef
|
||||
// attr's type field.
|
||||
// TODO: Add some examples here showing best practices.
|
||||
// TODO(josh11b): Add some examples here showing best practices.
|
||||
map<string, AttrValue> attr = 5;
|
||||
|
||||
message ExperimentalDebugInfo {
|
||||
|
|
|
@ -96,7 +96,7 @@ message OpDef {
|
|||
// Human-readable description.
|
||||
string description = 4;
|
||||
|
||||
// TODO: bool is_optional?
|
||||
// TODO(josh11b): bool is_optional?
|
||||
|
||||
// --- Constraints ---
|
||||
// These constraints are only in effect if specified. Default is no
|
||||
|
@ -139,7 +139,7 @@ message OpDef {
|
|||
// taking input from multiple devices with a tree of aggregate ops
|
||||
// that aggregate locally within each device (and possibly within
|
||||
// groups of nearby devices) before communicating.
|
||||
// TODO: Implement that optimization.
|
||||
// TODO(josh11b): Implement that optimization.
|
||||
bool is_aggregate = 16; // for things like add
|
||||
|
||||
// Other optimizations go here, like
|
||||
|
|
|
@ -53,7 +53,7 @@ message MemoryStats {
|
|||
|
||||
// Time/size stats recorded for a single execution of a graph node.
|
||||
message NodeExecStats {
|
||||
// TODO: Use some more compact form of node identity than
|
||||
// TODO(tucker): Use some more compact form of node identity than
|
||||
// the full string name. Either all processes should agree on a
|
||||
// global id (cost_id?) for each node, or we should use a hash of
|
||||
// the name.
|
||||
|
|
|
@ -16,7 +16,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framewo
|
|||
message TensorProto {
|
||||
DataType dtype = 1;
|
||||
|
||||
// Shape of the tensor. TODO: sort out the 0-rank issues.
|
||||
// Shape of the tensor. TODO(touts): sort out the 0-rank issues.
|
||||
TensorShapeProto tensor_shape = 2;
|
||||
|
||||
// Only one of the representations below is set, one of "tensor_contents" and
|
||||
|
|
|
@ -532,7 +532,7 @@ message ConfigProto {
|
|||
|
||||
// We removed the flag client_handles_error_formatting. Marking the tag
|
||||
// number as reserved.
|
||||
// TODO: Should we just remove this tag so that it can be
|
||||
// TODO(shikharagarwal): Should we just remove this tag so that it can be
|
||||
// used in future for other purpose?
|
||||
reserved 2;
|
||||
|
||||
|
@ -576,7 +576,7 @@ message ConfigProto {
|
|||
// - If isolate_session_state is true, session states are isolated.
|
||||
// - If isolate_session_state is false, session states are shared.
|
||||
//
|
||||
// TODO: Add a single API that consistently treats
|
||||
// TODO(b/129330037): Add a single API that consistently treats
|
||||
// isolate_session_state and ClusterSpec propagation.
|
||||
bool share_session_state_in_clusterspec_propagation = 8;
|
||||
|
||||
|
@ -704,7 +704,7 @@ message ConfigProto {
|
|||
|
||||
// Options for a single Run() call.
|
||||
message RunOptions {
|
||||
// TODO Turn this into a TraceOptions proto which allows
|
||||
// TODO(pbar) Turn this into a TraceOptions proto which allows
|
||||
// tracing to be controlled in a more orthogonal manner?
|
||||
enum TraceLevel {
|
||||
NO_TRACE = 0;
|
||||
|
@ -781,7 +781,7 @@ message RunMetadata {
|
|||
repeated GraphDef partition_graphs = 3;
|
||||
|
||||
message FunctionGraphs {
|
||||
// TODO: Include some sort of function/cache-key identifier?
|
||||
// TODO(nareshmodi): Include some sort of function/cache-key identifier?
|
||||
repeated GraphDef partition_graphs = 1;
|
||||
|
||||
GraphDef pre_optimization_graph = 2;
|
||||
|
|
|
@ -194,7 +194,7 @@ service CoordinationService {
|
|||
|
||||
// Report error to the task. RPC sets the receiving instance of coordination
|
||||
// service agent to error state permanently.
|
||||
// TODO: Consider splitting this into a different RPC service.
|
||||
// TODO(b/195990880): Consider splitting this into a different RPC service.
|
||||
rpc ReportErrorToAgent(ReportErrorToAgentRequest)
|
||||
returns (ReportErrorToAgentResponse);
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ message DebugTensorWatch {
|
|||
// are to be debugged, the callers of Session::Run() must use distinct
|
||||
// debug_urls to make sure that the streamed or dumped events do not overlap
|
||||
// among the invocations.
|
||||
// TODO: More visible documentation of this in g3docs.
|
||||
// TODO(cais): More visible documentation of this in g3docs.
|
||||
repeated string debug_urls = 4;
|
||||
|
||||
// Do not error out if debug op creation fails (e.g., due to dtype
|
||||
|
|
|
@ -12,7 +12,7 @@ option java_package = "org.tensorflow.util";
|
|||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
|
||||
|
||||
// Available modes for extracting debugging information from a Tensor.
|
||||
// TODO: Document the detailed column names and semantics in a separate
|
||||
// TODO(cais): Document the detailed column names and semantics in a separate
|
||||
// markdown file once the implementation settles.
|
||||
enum TensorDebugMode {
|
||||
UNSPECIFIED = 0;
|
||||
|
@ -223,7 +223,7 @@ message DebuggedDevice {
|
|||
// A debugger-generated ID for the device. Guaranteed to be unique within
|
||||
// the scope of the debugged TensorFlow program, including single-host and
|
||||
// multi-host settings.
|
||||
// TODO: Test the uniqueness guarantee in multi-host settings.
|
||||
// TODO(cais): Test the uniqueness guarantee in multi-host settings.
|
||||
int32 device_id = 2;
|
||||
}
|
||||
|
||||
|
@ -264,7 +264,7 @@ message Execution {
|
|||
// field with the DebuggedDevice messages.
|
||||
repeated int32 output_tensor_device_ids = 9;
|
||||
|
||||
// TODO support, add more fields
|
||||
// TODO(cais): When backporting to V1 Session.run() support, add more fields
|
||||
// such as fetches and feeds.
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
|
|||
|
||||
// Used to serialize and transmit tensorflow::Status payloads through
|
||||
// grpc::Status `error_details` since grpc::Status lacks payload API.
|
||||
// TODO: Use GRPC API once supported.
|
||||
// TODO(b/204231601): Use GRPC API once supported.
|
||||
message GrpcPayloadContainer {
|
||||
map<string, bytes> payloads = 1;
|
||||
}
|
||||
|
|
|
@ -172,7 +172,7 @@ message WaitQueueDoneRequest {
|
|||
}
|
||||
|
||||
message WaitQueueDoneResponse {
|
||||
// TODO: Consider adding NodeExecStats here to be able to
|
||||
// TODO(nareshmodi): Consider adding NodeExecStats here to be able to
|
||||
// propagate some stats.
|
||||
}
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ message ExtendSessionRequest {
|
|||
}
|
||||
|
||||
message ExtendSessionResponse {
|
||||
// TODO: Return something about the operation?
|
||||
// TODO(mrry): Return something about the operation?
|
||||
|
||||
// The new version number for the extended graph, to be used in the next call
|
||||
// to ExtendSession.
|
||||
|
|
|
@ -176,7 +176,7 @@ message SavedBareConcreteFunction {
|
|||
// allows the ConcreteFunction to be called with nest structure inputs. This
|
||||
// field may not be populated. If this field is absent, the concrete function
|
||||
// can only be called with flat inputs.
|
||||
// TODO: support calling saved ConcreteFunction with structured
|
||||
// TODO(b/169361281): support calling saved ConcreteFunction with structured
|
||||
// inputs in C++ SavedModel API.
|
||||
FunctionSpec function_spec = 4;
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
|
|||
|
||||
// Special header that is associated with a bundle.
|
||||
//
|
||||
// TODO: maybe in the future, we can add information about
|
||||
// TODO(zongheng,zhifengc): maybe in the future, we can add information about
|
||||
// which binary produced this checkpoint, timestamp, etc. Sometime, these can be
|
||||
// valuable debugging information. And if needed, these can be used as defensive
|
||||
// information ensuring reader (binary version) of the checkpoint and the writer
|
||||
|
|
|
@ -188,7 +188,7 @@ message DeregisterGraphRequest {
|
|||
}
|
||||
|
||||
message DeregisterGraphResponse {
|
||||
// TODO: Optionally add summary stats for the graph.
|
||||
// TODO(mrry): Optionally add summary stats for the graph.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -294,7 +294,7 @@ message RunGraphResponse {
|
|||
|
||||
// If the request asked for execution stats, the cost graph, or the partition
|
||||
// graphs, these are returned here.
|
||||
// TODO: Package these in a RunMetadata instead.
|
||||
// TODO(suharshs): Package these in a RunMetadata instead.
|
||||
StepStats step_stats = 2;
|
||||
CostGraphDef cost_graph = 3;
|
||||
repeated GraphDef partition_graph = 4;
|
||||
|
|
|
@ -13,5 +13,5 @@ message LogMetadata {
|
|||
SamplingConfig sampling_config = 2;
|
||||
// List of tags used to load the relevant MetaGraphDef from SavedModel.
|
||||
repeated string saved_model_tags = 3;
|
||||
// TODO: Add more metadata as mentioned in the bug.
|
||||
// TODO(b/33279154): Add more metadata as mentioned in the bug.
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ message FileSystemStoragePathSourceConfig {
|
|||
|
||||
// A single servable name/base_path pair to monitor.
|
||||
// DEPRECATED: Use 'servables' instead.
|
||||
// TODO: Stop using these fields, and ultimately remove them here.
|
||||
// TODO(b/30898016): Stop using these fields, and ultimately remove them here.
|
||||
string servable_name = 1 [deprecated = true];
|
||||
string base_path = 2 [deprecated = true];
|
||||
|
||||
|
@ -76,7 +76,7 @@ message FileSystemStoragePathSourceConfig {
|
|||
// check for a version to appear later.)
|
||||
// DEPRECATED: Use 'servable_versions_always_present' instead, which includes
|
||||
// this behavior.
|
||||
// TODO: Remove 2019-10-31 or later.
|
||||
// TODO(b/30898016): Remove 2019-10-31 or later.
|
||||
bool fail_if_zero_versions_at_startup = 4 [deprecated = true];
|
||||
|
||||
// If true, the servable is always expected to exist on the underlying
|
||||
|
|
|
@ -9,7 +9,7 @@ import "tensorflow_serving/config/logging_config.proto";
|
|||
option cc_enable_arenas = true;
|
||||
|
||||
// The type of model.
|
||||
// TODO: DEPRECATED.
|
||||
// TODO(b/31336131): DEPRECATED.
|
||||
enum ModelType {
|
||||
MODEL_TYPE_UNSPECIFIED = 0 [deprecated = true];
|
||||
TENSORFLOW = 1 [deprecated = true];
|
||||
|
@ -31,7 +31,7 @@ message ModelConfig {
|
|||
string base_path = 2;
|
||||
|
||||
// Type of model.
|
||||
// TODO: DEPRECATED. Please use 'model_platform' instead.
|
||||
// TODO(b/31336131): DEPRECATED. Please use 'model_platform' instead.
|
||||
ModelType model_type = 3 [deprecated = true];
|
||||
|
||||
// Type of model (e.g. "tensorflow").
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use log::{info, warn};
|
||||
use x509_parser::{prelude::{parse_x509_pem}, parse_x509_certificate};
|
||||
use std::collections::HashMap;
|
||||
use tokio::time::Instant;
|
||||
use tonic::{
|
||||
|
@ -27,6 +28,7 @@ use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
|
|||
use crate::metrics::{
|
||||
NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
|
||||
NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
|
||||
CERT_EXPIRY_EPOCH
|
||||
};
|
||||
use crate::predict_service::{Model, PredictService};
|
||||
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
|
||||
|
@ -233,6 +235,12 @@ impl<T: Model> PredictionService for PredictService<T> {
|
|||
}
|
||||
}
|
||||
|
||||
// A function that takes a timestamp as input and returns a ticker stream
|
||||
fn report_expiry(expiry_time: i64) {
|
||||
info!("Certificate expires at epoch: {:?}", expiry_time);
|
||||
CERT_EXPIRY_EPOCH.set(expiry_time as i64);
|
||||
}
|
||||
|
||||
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
||||
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
|
||||
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH
|
||||
|
@ -249,6 +257,7 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
|||
);
|
||||
}
|
||||
|
||||
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.thread_name("async worker")
|
||||
.worker_threads(ARGS.num_worker_threads)
|
||||
|
@ -266,6 +275,21 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
|||
let mut builder = if ARGS.ssl_dir.is_empty() {
|
||||
Server::builder()
|
||||
} else {
|
||||
// Read the pem file as a string
|
||||
let pem_str = std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
|
||||
let res = parse_x509_pem(&pem_str.as_bytes());
|
||||
match res {
|
||||
Ok((rem, pem_2)) => {
|
||||
assert!(rem.is_empty());
|
||||
assert_eq!(pem_2.label, String::from("CERTIFICATE"));
|
||||
let res_x509 = parse_x509_certificate(&pem_2.contents);
|
||||
info!("Certificate label: {}", pem_2.label);
|
||||
assert!(res_x509.is_ok());
|
||||
report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
|
||||
},
|
||||
_ => panic!("PEM parsing failed: {:?}", res),
|
||||
}
|
||||
|
||||
let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
|
||||
.await
|
||||
.expect("can't find key file");
|
||||
|
@ -281,7 +305,7 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
|
|||
let identity = Identity::from_pem(pem.clone(), key);
|
||||
let client_ca_cert = Certificate::from_pem(pem.clone());
|
||||
let tls = ServerTlsConfig::new()
|
||||
.identity(identity)
|
||||
.identity(identity)
|
||||
.client_ca_root(client_ca_cert);
|
||||
Server::builder()
|
||||
.tls_config(tls)
|
||||
|
|
|
@ -171,6 +171,9 @@ lazy_static! {
|
|||
&["model_name"]
|
||||
)
|
||||
.expect("metric can be created");
|
||||
pub static ref CERT_EXPIRY_EPOCH: IntGauge =
|
||||
IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
|
||||
.expect("metric can be created");
|
||||
}
|
||||
|
||||
pub fn register_custom_metrics() {
|
||||
|
@ -249,6 +252,10 @@ pub fn register_custom_metrics() {
|
|||
REGISTRY
|
||||
.register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
|
||||
.expect("collector can be registered");
|
||||
REGISTRY
|
||||
.register(Box::new(CERT_EXPIRY_EPOCH.clone()))
|
||||
.expect("collector can be registered");
|
||||
|
||||
}
|
||||
|
||||
pub fn register_dynamic_metrics(c: &HistogramVec) {
|
||||
|
|
|
@ -189,7 +189,7 @@ pub mod onnx {
|
|||
&version,
|
||||
reporting_feature_ids,
|
||||
Some(metrics::register_dynamic_metrics),
|
||||
)),
|
||||
)?),
|
||||
};
|
||||
onnx_model.warmup()?;
|
||||
Ok(onnx_model)
|
||||
|
|
|
@ -24,7 +24,7 @@ use serde_json::{self, Value};
|
|||
|
||||
pub trait Model: Send + Sync + Display + Debug + 'static {
|
||||
fn warmup(&self) -> Result<()>;
|
||||
//TODO: refactor this to return Vec<Vec<TensorScores>>, i.e.
|
||||
//TODO: refactor this to return vec<vec<TensorScores>>, i.e.
|
||||
//we have the underlying runtime impl to split the response to each client.
|
||||
//It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
|
||||
fn do_predict(
|
||||
|
@ -222,8 +222,8 @@ impl<T: Model> PredictService<T> {
|
|||
.map(|b| b.parse().unwrap())
|
||||
.collect::<Vec<u64>>();
|
||||
let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
|
||||
let mut all_model_predictors =
|
||||
ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS>::new();
|
||||
let mut all_model_predictors: ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS> =
|
||||
(0 ..MAX_NUM_MODELS).map( |_| ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new()).collect();
|
||||
loop {
|
||||
let msg = rx.try_recv();
|
||||
let no_more_msg = match msg {
|
||||
|
@ -272,27 +272,23 @@ impl<T: Model> PredictService<T> {
|
|||
queue_reset_ts: Instant::now(),
|
||||
queue_earliest_rq_ts: Instant::now(),
|
||||
};
|
||||
if idx < all_model_predictors.len() {
|
||||
metrics::NEW_MODEL_SNAPSHOT
|
||||
.with_label_values(&[&MODEL_SPECS[idx]])
|
||||
.inc();
|
||||
assert!(idx < all_model_predictors.len());
|
||||
metrics::NEW_MODEL_SNAPSHOT
|
||||
.with_label_values(&[&MODEL_SPECS[idx]])
|
||||
.inc();
|
||||
|
||||
info!("now we serve updated model: {}", predictor.model);
|
||||
//we can do this since the vector is small
|
||||
let predictors = &mut all_model_predictors[idx];
|
||||
if predictors.len() == ARGS.versions_per_model {
|
||||
predictors.remove(predictors.len() - 1);
|
||||
}
|
||||
predictors.insert(0, predictor);
|
||||
} else {
|
||||
info!("now we serve new model: {:}", predictor.model);
|
||||
let mut predictors =
|
||||
ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new();
|
||||
predictors.push(predictor);
|
||||
all_model_predictors.push(predictors);
|
||||
//check the invariant that we always push the last model to the end
|
||||
assert_eq!(all_model_predictors.len(), idx + 1)
|
||||
//we can do this since the vector is small
|
||||
let predictors = &mut all_model_predictors[idx];
|
||||
if predictors.len() == 0 {
|
||||
info!("now we serve new model: {}", predictor.model);
|
||||
}
|
||||
else {
|
||||
info!("now we serve updated model: {}", predictor.model);
|
||||
}
|
||||
if predictors.len() == ARGS.versions_per_model {
|
||||
predictors.remove(predictors.len() - 1);
|
||||
}
|
||||
predictors.insert(0, predictor);
|
||||
false
|
||||
}
|
||||
Err(TryRecvError::Empty) => true,
|
||||
|
|
|
@ -5,39 +5,49 @@ use std::fmt::Display;
|
|||
*/
|
||||
#[derive(Debug)]
|
||||
pub enum SegDenseError {
|
||||
IoError(std::io::Error),
|
||||
Json(serde_json::Error),
|
||||
JsonMissingRoot,
|
||||
JsonMissingObject,
|
||||
JsonMissingArray,
|
||||
JsonArraySize,
|
||||
JsonMissingInputFeature,
|
||||
IoError(std::io::Error),
|
||||
Json(serde_json::Error),
|
||||
JsonMissingRoot,
|
||||
JsonMissingObject,
|
||||
JsonMissingArray,
|
||||
JsonArraySize,
|
||||
JsonMissingInputFeature,
|
||||
}
|
||||
|
||||
impl Display for SegDenseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
|
||||
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
|
||||
SegDenseError::JsonMissingRoot => write!(f, "{}", "SegDense JSON: Root Node note found!"),
|
||||
SegDenseError::JsonMissingObject => write!(f, "{}", "SegDense JSON: Object note found!"),
|
||||
SegDenseError::JsonMissingArray => write!(f, "{}", "SegDense JSON: Array Node note found!"),
|
||||
SegDenseError::JsonArraySize => write!(f, "{}", "SegDense JSON: Array size not as expected!"),
|
||||
SegDenseError::JsonMissingInputFeature => write!(f, "{}", "SegDense JSON: Missing input feature!"),
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
|
||||
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
|
||||
SegDenseError::JsonMissingRoot => {
|
||||
write!(f, "{}", "SegDense JSON: Root Node note found!")
|
||||
}
|
||||
SegDenseError::JsonMissingObject => {
|
||||
write!(f, "{}", "SegDense JSON: Object note found!")
|
||||
}
|
||||
SegDenseError::JsonMissingArray => {
|
||||
write!(f, "{}", "SegDense JSON: Array Node note found!")
|
||||
}
|
||||
SegDenseError::JsonArraySize => {
|
||||
write!(f, "{}", "SegDense JSON: Array size not as expected!")
|
||||
}
|
||||
SegDenseError::JsonMissingInputFeature => {
|
||||
write!(f, "{}", "SegDense JSON: Missing input feature!")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SegDenseError {}
|
||||
|
||||
impl From<std::io::Error> for SegDenseError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
SegDenseError::IoError(err)
|
||||
}
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
SegDenseError::IoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for SegDenseError {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
SegDenseError::Json(err)
|
||||
}
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
SegDenseError::Json(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
pub mod error;
|
||||
pub mod segdense_transform_spec_home_recap_2022;
|
||||
pub mod mapper;
|
||||
pub mod util;
|
||||
pub mod segdense_transform_spec_home_recap_2022;
|
||||
pub mod util;
|
||||
|
|
|
@ -5,19 +5,18 @@ use segdense::error::SegDenseError;
|
|||
use segdense::util;
|
||||
|
||||
fn main() -> Result<(), SegDenseError> {
|
||||
env_logger::init();
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
let schema_file_name: &str = if args.len() == 1 {
|
||||
"json/compact.json"
|
||||
} else {
|
||||
&args[1]
|
||||
};
|
||||
env_logger::init();
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
let json_str = fs::read_to_string(schema_file_name)?;
|
||||
let schema_file_name: &str = if args.len() == 1 {
|
||||
"json/compact.json"
|
||||
} else {
|
||||
&args[1]
|
||||
};
|
||||
|
||||
util::safe_load_config(&json_str)?;
|
||||
let json_str = fs::read_to_string(schema_file_name)?;
|
||||
|
||||
Ok(())
|
||||
util::safe_load_config(&json_str)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -19,13 +19,13 @@ pub struct FeatureMapper {
|
|||
impl FeatureMapper {
|
||||
pub fn new() -> FeatureMapper {
|
||||
FeatureMapper {
|
||||
map: HashMap::new()
|
||||
map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait MapWriter {
|
||||
fn set(&mut self, feature_id: i64, info: FeatureInfo);
|
||||
fn set(&mut self, feature_id: i64, info: FeatureInfo);
|
||||
}
|
||||
|
||||
pub trait MapReader {
|
||||
|
|
|
@ -164,7 +164,6 @@ pub struct ComplexFeatureTypeTransformSpec {
|
|||
pub tensor_shape: Vec<i64>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InputFeatureMapRecord {
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
use log::debug;
|
||||
use std::fs;
|
||||
use log::{debug};
|
||||
|
||||
use serde_json::{Value, Map};
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
use crate::error::SegDenseError;
|
||||
use crate::mapper::{FeatureMapper, FeatureInfo, MapWriter};
|
||||
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
|
||||
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
|
||||
|
||||
pub fn load_config(file_name: &str) -> seg_dense::Root {
|
||||
let json_str = fs::read_to_string(file_name).expect(
|
||||
&format!("Unable to load segdense file {}", file_name));
|
||||
let seg_dense_config = parse(&json_str).expect(
|
||||
&format!("Unable to parse segdense file {}", file_name));
|
||||
return seg_dense_config;
|
||||
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
|
||||
let json_str = fs::read_to_string(file_name)?;
|
||||
// &format!("Unable to load segdense file {}", file_name));
|
||||
let seg_dense_config = parse(&json_str)?;
|
||||
// &format!("Unable to parse segdense file {}", file_name));
|
||||
Ok(seg_dense_config)
|
||||
}
|
||||
|
||||
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
|
||||
let root: seg_dense::Root = serde_json::from_str(json_str)?;
|
||||
return Ok(root);
|
||||
Ok(root)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -44,15 +44,8 @@ pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError>
|
|||
load_from_parsed_config(root)
|
||||
}
|
||||
|
||||
pub fn load_from_parsed_config_ref(root: &seg_dense::Root) -> FeatureMapper {
|
||||
load_from_parsed_config(root.clone()).unwrap_or_else(
|
||||
|error| panic!("Error loading all_config.json - {}", error))
|
||||
}
|
||||
|
||||
// Perf note : make 'root' un-owned
|
||||
pub fn load_from_parsed_config(root: seg_dense::Root) ->
|
||||
Result<FeatureMapper, SegDenseError> {
|
||||
|
||||
pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
|
||||
let v = root.input_features_map;
|
||||
|
||||
// Do error check
|
||||
|
@ -86,7 +79,7 @@ pub fn load_from_parsed_config(root: seg_dense::Root) ->
|
|||
Some(info) => {
|
||||
debug!("{:?}", info);
|
||||
fm.set(feature_id, info)
|
||||
},
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
|
@ -94,19 +87,22 @@ pub fn load_from_parsed_config(root: seg_dense::Root) ->
|
|||
Ok(fm)
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
fn add_feature_info_to_mapper(feature_mapper: &mut FeatureMapper, input_features: &Vec<InputFeature>) {
|
||||
fn add_feature_info_to_mapper(
|
||||
feature_mapper: &mut FeatureMapper,
|
||||
input_features: &Vec<InputFeature>,
|
||||
) {
|
||||
for input_feature in input_features.iter() {
|
||||
let feature_id = input_feature.feature_id;
|
||||
let feature_info = to_feature_info(input_feature);
|
||||
|
||||
match feature_info {
|
||||
Some(info) => {
|
||||
debug!("{:?}", info);
|
||||
feature_mapper.set(feature_id, info)
|
||||
},
|
||||
None => (),
|
||||
let feature_id = input_feature.feature_id;
|
||||
let feature_info = to_feature_info(input_feature);
|
||||
|
||||
match feature_info {
|
||||
Some(info) => {
|
||||
debug!("{:?}", info);
|
||||
feature_mapper.set(feature_id, info)
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
|
||||
|
@ -139,7 +135,7 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
|
|||
2 => 0,
|
||||
3 => 2,
|
||||
_ => -1,
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if input_feature.index < 0 {
|
||||
|
@ -156,4 +152,3 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
|
|||
index_within_tensor: input_feature.index,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue