// NOTE(review): this file was recovered from a paste that stripped every span
// between '<' and '>' (include targets and template-argument lists). The
// reconstructions below are inferred from usage in the code itself; the
// project-header names in particular should be confirmed against the tree.
#include "internal/endianutils.h"
#include "internal/error.h"
#include "internal/thrift.h"

// Project headers — exact names lost in the paste; TODO confirm.
#include <twml/BatchPredictionResponse.h>
#include <twml/DataRecord.h>
#include <twml/DataRecordWriter.h>
#include <twml/Tensor.h>
#include <twml/ThriftWriter.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

// When the number of predictions is very high, as some cases that Ads wants,
// the generic thrift encoder becomes super expensive because we have to deal
// with lua tables.
// This function is a special operation to efficiently write a batch prediction
// responses based on tensors.

namespace twml {

// Builds a response from parallel key/value tensors (continuous predictions)
// and/or a set of dense tensors addressed by dense_keys.
//
// Throws twml::Error(TWML_ERR_TYPE) when both inputs are empty, when the
// dense key/tensor counts disagree, or when the dense tensors disagree on
// batch size (dim 0).
BatchPredictionResponse::BatchPredictionResponse(
    const Tensor &keys, const Tensor &values,
    const Tensor &dense_keys, const std::vector<RawTensor> &dense_values
) : keys_(keys), values_(values),
    dense_keys_(dense_keys), dense_values_(dense_values) {
  // determine batch size
  if (values_.getNumDims() > 0) {
    // Continuous predictions present: dim 0 of `values` is the batch size.
    batch_size_ = values_.getDim(0);
  } else if (dense_keys_.getNumElements() < 1) {
    throw twml::Error(TWML_ERR_TYPE, "Continuous values and dense tensors are both empty");
  } else if (dense_keys_.getNumElements() != dense_values_.size()) {
    throw twml::Error(TWML_ERR_TYPE, "Number of tensors not equal to number of keys");
  } else {
    // dim 0 for each tensor indexes batch elements
    std::vector<int64_t> batch_sizes;
    batch_sizes.reserve(dense_values_.size());
    for (size_t i = 0; i < dense_values_.size(); i++)
      batch_sizes.push_back(dense_values_.at(i).getDim(0));

    // Every dense tensor must agree on dim 0; adjacent_find locates the
    // first neighboring pair that differs (end() means all equal).
    if (std::adjacent_find(
          batch_sizes.begin(), batch_sizes.end(),
          std::not_equal_to<int64_t>()) != batch_sizes.end())
      throw twml::Error(TWML_ERR_TYPE, "Batch size (dim 0) for all tensors must be the same");

    batch_size_ = dense_values.at(0).getDim(0);
  }
}

// Serializes this response into `thrift_writer`, dispatching on the dtype of
// the continuous prediction values. For dense-only responses T is never used
// by the continuous branch, so the <double> instantiation is arbitrary but
// required to name a concrete template instantiation.
void BatchPredictionResponse::encode(twml::ThriftWriter &thrift_writer) {
  if (hasContinuous()) {
    switch (values_.getType()) {
      case TWML_TYPE_FLOAT:
        serializePredictions<float>(thrift_writer);
        break;
      case TWML_TYPE_DOUBLE:
        serializePredictions<double>(thrift_writer);
        break;
      default:
        throw twml::Error(TWML_ERR_TYPE, "Predictions must be float or double.");
    }
  } else {
    // dense tensor predictions
    serializePredictions<double>(thrift_writer);
  }
}

// Writes one DataRecord per batch element: the i-th row of the continuous
// values (if any) plus the i-th slice of every dense tensor.
template <typename T>
void BatchPredictionResponse::serializePredictions(twml::ThriftWriter &thrift_writer) {
  twml::DataRecordWriter record_writer = twml::DataRecordWriter(thrift_writer);

  // start BatchPredictionResponse
  thrift_writer.writeStructFieldHeader(TTYPE_LIST, BPR_PREDICTIONS);
  thrift_writer.writeListHeader(TTYPE_STRUCT, getBatchSize());

  for (int i = 0; i < getBatchSize(); i++) {
    twml::DataRecord record = twml::DataRecord();

    if (hasContinuous()) {
      const T *values = values_.getData<T>();
      const int64_t *local_keys = keys_.getData<int64_t>();
      // Row i of the [batch_size, prediction_size] value matrix.
      const T *local_values = values + (i * getPredictionSize());
      record.addContinuous(local_keys, getPredictionSize(), local_values);
    }

    if (hasDenseTensors()) {
      const int64_t *local_dense_keys = dense_keys_.getData<int64_t>();
      for (int j = 0; j < dense_keys_.getNumElements(); j++) {
        const RawTensor &dense_value = dense_values_.at(j).getSlice(i);
        record.addRawTensor(local_dense_keys[j], dense_value);
      }
    }

    record_writer.write(record);
  }

  // end BatchPredictionResponse
  thrift_writer.writeStructStop();
}

// calculate expected binary Thrift size (no memory is copied)
uint64_t BatchPredictionResponse::encodedSize() {
  // A dry-run writer counts bytes without a destination buffer.
  bool dry_mode = true;
  twml::ThriftWriter dry_writer = twml::ThriftWriter(nullptr, 0, dry_mode);
  encode(dry_writer);
  return dry_writer.getBytesWritten();
}

// Serializes this response into `result`, whose element count must equal
// encodedSize() exactly; throws twml::Error(TWML_ERR_SIZE) otherwise.
void BatchPredictionResponse::write(Tensor &result) {
  size_t result_size = result.getNumElements();
  uint8_t *result_data = result.getData<uint8_t>();

  if (result_size != this->encodedSize()) {
    throw twml::Error(TWML_ERR_SIZE, "Sizes do not match");
  }

  twml::ThriftWriter writer = twml::ThriftWriter(result_data, result_size);
  encode(writer);
}

}  // namespace twml