#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/op_kernel.h" #include #include "tensorflow_utils.h" #include "resource_utils.h" REGISTER_OP("DecodeAndHashBatchPredictionRequest") .Input("input_bytes: uint8") .Attr("keep_features: list(int)") .Attr("keep_codes: list(int)") .Attr("decode_mode: int = 0") .Output("hashed_data_record_handle: resource") .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( A tensorflow OP that decodes batch prediction request and creates a handle to the batch of hashed data records. Attr keep_features: a list of int ids to keep. keep_codes: their corresponding code. decode_mode: integer, indicates which decoding method to use. Let a sparse continuous have a feature_name and a dict of {name: value}. 0 indicates feature_ids are computed as hash(name). 1 indicates feature_ids are computed as hash(feature_name, name) shared_name: name used by the resource handle inside the resource manager. container: name used by the container of the resources. shared_name and container are required when inheriting from ResourceOpKernel. Input input_bytes: Input tensor containing the serialized batch of BatchPredictionRequest. Outputs hashed_data_record_handle: A resource handle to the HashedDataRecordResource containing batch of HashedDataRecords. )doc"); class DecodeAndHashBatchPredictionRequest : public OpKernel { public: explicit DecodeAndHashBatchPredictionRequest(OpKernelConstruction* context) : OpKernel(context) { std::vector keep_features; std::vector keep_codes; OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features)); OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes)); OP_REQUIRES_OK(context, context->GetAttr("decode_mode", &m_decode_mode)); OP_REQUIRES(context, keep_features.size() == keep_codes.size(), errors::InvalidArgument("keep keys and values must have same size.")); #ifdef USE_DENSE_HASH m_keep_map.set_empty_key(0); #endif // USE_DENSE_HASH for (uint64_t i = 0; i < keep_features.size(); i++) { m_keep_map[keep_features[i]] = keep_codes[i]; } } private: twml::Map m_keep_map; int64 m_decode_mode; void Compute(OpKernelContext* context) override { try { HashedDataRecordResource *resource = nullptr; OP_REQUIRES_OK(context, makeResourceHandle(context, 0, &resource)); // Store the input bytes in the resource so it isnt freed before the resource. // This is necessary because we are not copying the contents for tensors. resource->input = context->input(0); const uint8_t *input_bytes = resource->input.flat().data(); twml::HashedDataRecordReader reader; twml::HashedBatchPredictionRequest bpr; reader.setKeepMap(&m_keep_map); reader.setBuffer(input_bytes); reader.setDecodeMode(m_decode_mode); bpr.decode(reader); resource->common = std::move(bpr.common()); resource->records = std::move(bpr.requests()); // Each datarecord has a copy of common features. 
      // Initialize total_size as common_size * num_records.
      int64 common_size = static_cast<int64>(resource->common.totalSize());
      int64 num_records = static_cast<int64>(resource->records.size());
      int64 total_size = common_size * num_records;

      for (const auto &record : resource->records) {
        total_size += static_cast<int64>(record.totalSize());
      }

      resource->total_size = total_size;
      resource->num_labels = 0;
      resource->num_weights = 0;
    } catch (const std::exception &e) {
      context->CtxFailureWithWarning(errors::InvalidArgument(e.what()));
    }
  }
};

REGISTER_KERNEL_BUILDER(
  Name("DecodeAndHashBatchPredictionRequest").Device(DEVICE_CPU),
  DecodeAndHashBatchPredictionRequest);

REGISTER_OP("DecodeBatchPredictionRequest")
.Input("input_bytes: uint8")
.Attr("keep_features: list(int)")
.Attr("keep_codes: list(int)")
.Output("data_record_handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
A tensorflow OP that decodes a batch prediction request and creates a handle
to the batch of data records.

Attr
  keep_features: a list of int ids to keep.
  keep_codes: their corresponding codes.
  shared_name: name used by the resource handle inside the resource manager.
  container: name used by the container of the resources.

shared_name and container are required when inheriting from ResourceOpKernel.

Input
  input_bytes: Input tensor containing the serialized batch of BatchPredictionRequest.

Outputs
  data_record_handle: A resource handle to the DataRecordResource containing
  the batch of DataRecords.
)doc");

class DecodeBatchPredictionRequest : public OpKernel {
 public:
  explicit DecodeBatchPredictionRequest(OpKernelConstruction* context)
    : OpKernel(context) {
    std::vector<int64> keep_features;
    std::vector<int64> keep_codes;

    OP_REQUIRES_OK(context, context->GetAttr("keep_features", &keep_features));
    OP_REQUIRES_OK(context, context->GetAttr("keep_codes", &keep_codes));

    OP_REQUIRES(context, keep_features.size() == keep_codes.size(),
                errors::InvalidArgument("keep keys and values must have the same size."));

#ifdef USE_DENSE_HASH
    m_keep_map.set_empty_key(0);
#endif  // USE_DENSE_HASH

    for (uint64_t i = 0; i < keep_features.size(); i++) {
      m_keep_map[keep_features[i]] = keep_codes[i];
    }
  }

 private:
  twml::Map<int64, int64> m_keep_map;

  void Compute(OpKernelContext* context) override {
    try {
      DataRecordResource *resource = nullptr;
      OP_REQUIRES_OK(context,
                     makeResourceHandle<DataRecordResource>(context, 0, &resource));

      // Store the input bytes in the resource so they aren't freed before the resource.
      // This is necessary because we are not copying the tensor contents.
      resource->input = context->input(0);
      const uint8_t *input_bytes = resource->input.flat<uint8>().data();

      twml::DataRecordReader reader;
      twml::BatchPredictionRequest bpr;
      reader.setKeepMap(&m_keep_map);
      reader.setBuffer(input_bytes);
      bpr.decode(reader);

      resource->common = std::move(bpr.common());
      resource->records = std::move(bpr.requests());

      resource->num_weights = 0;
      resource->num_labels = 0;
      resource->keep_map = &m_keep_map;
    } catch (const std::exception &e) {
      context->CtxFailureWithWarning(errors::InvalidArgument(e.what()));
    }
  }
};

REGISTER_KERNEL_BUILDER(
  Name("DecodeBatchPredictionRequest").Device(DEVICE_CPU),
  DecodeBatchPredictionRequest);
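// For reference: makeResourceHandle<T> is provided by resource_utils.h, which
// is not shown here. The sketch below is an illustrative assumption of how
// such a helper could be written against the TensorFlow ResourceMgr API, not
// the actual twml implementation. It allocates the scalar output tensor that
// carries the handle, registers a freshly allocated resource with the
// per-device resource manager under a unique name, and returns the raw
// pointer to the caller. The helper name and the pointer-based naming scheme
// are hypothetical.

#include <cstdint>  // std::uintptr_t
#include <string>   // std::string, std::to_string

namespace sketch {

template <typename ResourceType>
tensorflow::Status makeResourceHandleSketch(tensorflow::OpKernelContext* context,
                                            int output_index,
                                            ResourceType** resource_out) {
  // Allocate the scalar output tensor that will carry the resource handle.
  Tensor* handle_tensor = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(output_index, TensorShape({}), &handle_tensor));

  // Create the resource; derive a unique name from its address (assumption:
  // one resource is created per Compute call and never shared by name).
  ResourceType* resource = new ResourceType();
  const std::string name =
      "twml_resource_" + std::to_string(reinterpret_cast<std::uintptr_t>(resource));
  const std::string& container = context->resource_manager()->default_container();

  // Register the resource so downstream ops can look it up via the handle.
  TF_RETURN_IF_ERROR(context->resource_manager()->Create(container, name, resource));
  handle_tensor->scalar<ResourceHandle>()() =
      MakeResourceHandle<ResourceType>(context, container, name);

  *resource_out = resource;
  return Status::OK();
}

}  // namespace sketch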