#include "internal/interpolate.h"
#include "internal/error.h"
// NOTE(review): the header name of this third include was missing in the
// reviewed text. It is restored from the twml_optim_* prefix of the C entry
// points defined at the bottom of this file -- confirm against the repository.
#include <twml/optim.h>

// NOTE(review): every template parameter/argument list in this file was
// missing in the reviewed text (e.g. a bare `template` in front of functions
// that use `T`, and `getData()` without an element type). They have been
// reconstructed from how each value is used; confirm them against the Tensor
// API and internal/interpolate.h.

namespace twml {

/**
 * Typed worker for MDL inference: maps every (feature key, value) pair to a
 * bin using nearest interpolation over that feature's slice of the bin table.
 *
 * For each element i of input_keys:
 *   - the key is shifted by TWML_INDEX_BASE to obtain a feature index;
 *   - a feature index of -1 copies the input value to output_vals[i] and
 *     leaves output_keys[i] untouched;
 *   - otherwise the feature's bins are bin_vals/bin_ids in the range
 *     [feature_offsets[idx], feature_offsets[idx + 1]) (the last feature runs
 *     to the end of bin_ids), the interpolation result shifted back by
 *     TWML_INDEX_BASE is written to output_keys[i], and output_vals[i] is
 *     set to 1.
 *
 * @param output_keys        receives the resulting keys (int64 elements).
 * @param output_vals        receives the resulting values (elements of type T).
 * @param input_keys         feature keys (int64 elements).
 * @param input_vals         feature values (elements of type T).
 * @param bin_ids            ids associated with the bins (int64 elements).
 * @param bin_vals           bin boundaries (elements of type T).
 * @param feature_offsets    start offset of every feature inside the bin
 *                           table; indexed without a stride, so it must be
 *                           contiguous.
 * @param return_bin_indices forwarded unchanged to the interpolation helper.
 *
 * @throws twml::Error (TWML_ERR_SIZE) if a feature index is neither -1 nor a
 *         valid index into feature_offsets.
 */
template<typename T>
void mdlInfer(Tensor &output_keys, Tensor &output_vals,
              const Tensor &input_keys, const Tensor &input_vals,
              const Tensor &bin_ids,
              const Tensor &bin_vals,
              const Tensor &feature_offsets,
              bool return_bin_indices) {
  auto okeysData = output_keys.getData<int64_t>();
  auto ovalsData = output_vals.getData<T>();
  uint64_t okeysStride = output_keys.getStride(0);
  uint64_t ovaluesStride = output_vals.getStride(0);

  auto ikeysData = input_keys.getData<int64_t>();
  auto ivalsData = input_vals.getData<T>();
  uint64_t ikeysStride = input_keys.getStride(0);
  uint64_t ivaluesStride = input_vals.getStride(0);

  auto xsData = bin_vals.getData<T>();
  auto ysData = bin_ids.getData<int64_t>();
  uint64_t xsStride = bin_vals.getStride(0);
  uint64_t ysStride = bin_ids.getStride(0);

  auto offsetData = feature_offsets.getData<int64_t>();

  uint64_t size = input_keys.getDim(0);
  uint64_t total_bins = bin_ids.getNumElements();
  uint64_t fsize = feature_offsets.getNumElements();

  for (uint64_t i = 0; i < size; i++) {
    const int64_t ikey = ikeysData[i * ikeysStride] - TWML_INDEX_BASE;
    const T val = ivalsData[i * ivaluesStride];

    // -1 marks an entry that is not discretized: pass the value through.
    if (ikey == -1) {
      ovalsData[i * ovaluesStride] = val;
      continue;
    }

    // Reject indices that would read outside feature_offsets instead of
    // silently reading out of bounds.
    if (ikey < 0 || static_cast<uint64_t>(ikey) >= fsize) {
      throw twml::Error(TWML_ERR_SIZE,
                        "input_keys contains a feature index outside of feature_offsets");
    }

    // Perform interpolation
    const uint64_t offset = offsetData[ikey];
    // The last feature has no successor offset; its bins run to the end.
    const uint64_t next_offset =
        (ikey == static_cast<int64_t>(fsize - 1)) ? total_bins : offsetData[ikey + 1];
    const uint64_t mainSize = next_offset - offset;

    const T *lxsData = xsData + offset;
    const int64_t *lysData = ysData + offset;
    const int64_t okey = interpolation<T, int64_t>(lxsData, xsStride,
                                                   lysData, ysStride,
                                                   val, mainSize, NEAREST, 0,
                                                   return_bin_indices);
    okeysData[i * okeysStride] = okey + TWML_INDEX_BASE;
    ovalsData[i * ovaluesStride] = 1;
  }
}

/**
 * Validates the tensors and dispatches MDL inference on the element type of
 * input_vals (float or double).
 *
 * @throws twml::Error (TWML_ERR_TYPE) if input_keys, output_keys, bin_ids or
 *         feature_offsets is not an int64 tensor, if input_vals and bin_vals
 *         differ in type, or if the value type is neither float nor double.
 * @throws twml::Error (TWML_ERR_SIZE) if bin_vals or bin_ids is not
 *         1-dimensional, if their element counts differ, or if
 *         feature_offsets is not contiguous.
 */
void mdlInfer(Tensor &output_keys, Tensor &output_vals,
              const Tensor &input_keys, const Tensor &input_vals,
              const Tensor &bin_ids,
              const Tensor &bin_vals,
              const Tensor &feature_offsets,
              bool return_bin_indices) {
  if (input_keys.getType() != TWML_TYPE_INT64) {
    throw twml::Error(TWML_ERR_TYPE, "input_keys must be a Long Tensor");
  }

  if (output_keys.getType() != TWML_TYPE_INT64) {
    throw twml::Error(TWML_ERR_TYPE, "output_keys must be a Long Tensor");
  }

  if (bin_ids.getType() != TWML_TYPE_INT64) {
    throw twml::Error(TWML_ERR_TYPE, "bin_ids must be a Long Tensor");
  }

  if (feature_offsets.getType() != TWML_TYPE_INT64) {
    // Previously this reported "bin_ids", pointing at the wrong argument.
    throw twml::Error(TWML_ERR_TYPE, "feature_offsets must be a Long Tensor");
  }

  if (input_vals.getType() != bin_vals.getType()) {
    throw twml::Error(TWML_ERR_TYPE,
                      "Data type of input_vals does not match type of bin_vals");
  }

  if (bin_vals.getNumDims() != 1) {
    throw twml::Error(TWML_ERR_SIZE, "bin_vals must be 1 Dimensional");
  }

  if (bin_ids.getNumDims() != 1) {
    throw twml::Error(TWML_ERR_SIZE, "bin_ids must be 1 Dimensional");
  }

  if (bin_vals.getNumElements() != bin_ids.getNumElements()) {
    throw twml::Error(TWML_ERR_SIZE,
                      "Dimensions of bin_vals and bin_ids do not match");
  }

  // The typed worker indexes feature_offsets directly, without a stride.
  if (feature_offsets.getStride(0) != 1) {
    throw twml::Error(TWML_ERR_SIZE, "feature_offsets must be contiguous");
  }

  switch (input_vals.getType()) {
    case TWML_TYPE_FLOAT:
      twml::mdlInfer<float>(output_keys, output_vals,
                            input_keys, input_vals,
                            bin_ids, bin_vals, feature_offsets,
                            return_bin_indices);
      break;
    case TWML_TYPE_DOUBLE:
      twml::mdlInfer<double>(output_keys, output_vals,
                             input_keys, input_vals,
                             bin_ids, bin_vals, feature_offsets,
                             return_bin_indices);
      break;
    default:
      throw twml::Error(TWML_ERR_TYPE,
                        "Unsupported datatype for mdlInfer");
  }
}

// Value passed as the `lowest` argument of the interpolation helper by the
// tensor-level interpolation below.
const int DEFAULT_INTERPOLATION_LOWEST = 0;

/**
 * @param output tensor to hold linear or nearest interpolation output.
 *    This function does not allocate space.
 *    The output tensor must have space allocated.
 * @param input input tensor; size must match output.
 *    input is assumed to have size [batch_size, number_of_labels].
 * @param xs the bins; indexed as [number_of_labels, bins_per_label].
 * @param ys the values for the bins; indexed like xs.
 * @param mode: linear or nearest InterpolationMode.
 *    linear is used for isotonic calibration.
 *    nearest is used for MDL calibration and MDL inference.
 *
 * @return Returns nothing. Output is stored into the output tensor.
 *
 * @throws twml::Error (TWML_ERR_TYPE) if input or output is not
 *    2-dimensional, if their sizes differ, or if the number of labels in
 *    input does not match the first dimension of xs and ys.
 *
 * This is used by IsotonicCalibration inference.
 */
template <typename T>
void interpolation(
  Tensor output,
  const Tensor input,
  const Tensor xs,
  const Tensor ys,
  const InterpolationMode mode) {
  // Sanity check: input and output should have two dims.
  if (input.getNumDims() != 2 || output.getNumDims() != 2) {
    throw twml::Error(TWML_ERR_TYPE,
                      "input and output should have 2 dimensions.");
  }

  // Sanity check: input and output size should match.
  for (int i = 0; i < input.getNumDims(); i++) {
    if (input.getDim(i) != output.getDim(i))  {
      throw twml::Error(TWML_ERR_TYPE,
                        "input and output mismatch in size.");
    }
  }

  // Sanity check: number of labels in input should match
  // number of labels in xs / ys.
  if (input.getDim(1) != xs.getDim(0)
    || input.getDim(1) != ys.getDim(0)) {
    throw twml::Error(TWML_ERR_TYPE,
                      "input, xs, ys should have the same number of labels.");
  }

  const uint64_t inputStride0 = input.getStride(0);
  const uint64_t inputStride1 = input.getStride(1);
  const uint64_t outputStride0 = output.getStride(0);
  const uint64_t outputStride1 = output.getStride(1);
  const uint64_t xsStride0 = xs.getStride(0);
  const uint64_t xsStride1 = xs.getStride(1);
  const uint64_t ysStride0 = ys.getStride(0);
  const uint64_t ysStride1 = ys.getStride(1);
  // Number of bins per label.
  const uint64_t mainSize = xs.getDim(1);

  // for each value in the input matrix, compute output value by
  // calling interpolation.
  auto inputData = input.getData<T>();
  auto outputData = output.getData<T>();
  auto xsData = xs.getData<T>();
  auto ysData = ys.getData<T>();

  for (uint64_t i = 0; i < input.getDim(0); i++) {
    for (uint64_t j = 0; j < input.getDim(1); j++) {
      const T val = inputData[i * inputStride0 + j * inputStride1];
      // Row j of xs / ys holds the bins and values of label j.
      const T *lxsData = xsData + j * xsStride0;
      const T *lysData = ysData + j * ysStride0;
      const T res = interpolation<T, T>(
        lxsData, xsStride1,
        lysData, ysStride1,
        val,
        mainSize,
        mode,
        DEFAULT_INTERPOLATION_LOWEST);
      outputData[i * outputStride0 + j * outputStride1] = res;
    }
  }
}

/**
 * Linear interpolation of `input` over the bins (xs, ys), written to `output`.
 * Dispatches on the element type of `input`.
 *
 * @throws twml::Error (TWML_ERR_TYPE) if `input` is neither float nor double,
 *         or if the shape checks of interpolation() fail.
 */
void linearInterpolation(
  Tensor output,
  const Tensor input,
  const Tensor xs,
  const Tensor ys) {
  switch (input.getType()) {
    case TWML_TYPE_FLOAT:
      twml::interpolation<float>(output, input, xs, ys, LINEAR);
      break;
    case TWML_TYPE_DOUBLE:
      twml::interpolation<double>(output, input, xs, ys, LINEAR);
      break;
    default:
      throw twml::Error(TWML_ERR_TYPE,
                        "Unsupported datatype for linearInterpolation.");
  }
}

/**
 * Nearest interpolation of `input` over the bins (xs, ys), written to
 * `output`. Dispatches on the element type of `input`.
 *
 * @throws twml::Error (TWML_ERR_TYPE) if `input` is neither float nor double,
 *         or if the shape checks of interpolation() fail.
 */
void nearestInterpolation(
  Tensor output,
  const Tensor input,
  const Tensor xs,
  const Tensor ys) {
  switch (input.getType()) {
    case TWML_TYPE_FLOAT:
      twml::interpolation<float>(output, input, xs, ys, NEAREST);
      break;
    case TWML_TYPE_DOUBLE:
      twml::interpolation<double>(output, input, xs, ys, NEAREST);
      break;
    default:
      throw twml::Error(TWML_ERR_TYPE,
                        "Unsupported datatype for nearestInterpolation.");
  }
}

}  // namespace twml

// C entry point for twml::mdlInfer. The call is wrapped in HANDLE_EXCEPTIONS
// (defined elsewhere; presumably it turns thrown errors into a twml_err
// return value -- confirm). Returns TWML_ERR_NONE when the call completes.
twml_err twml_optim_mdl_infer(twml_tensor output_keys,
                              twml_tensor output_vals,
                              const twml_tensor input_keys,
                              const twml_tensor input_vals,
                              const twml_tensor bin_ids,
                              const twml_tensor bin_vals,
                              const twml_tensor feature_offsets,
                              bool return_bin_indices) {
  HANDLE_EXCEPTIONS(
    using namespace twml;
    mdlInfer(*getTensor(output_keys),
             *getTensor(output_vals),
             *getConstTensor(input_keys),
             *getConstTensor(input_vals),
             *getConstTensor(bin_ids),
             *getConstTensor(bin_vals),
             *getConstTensor(feature_offsets),
             return_bin_indices););
  return TWML_ERR_NONE;
}

// C entry point for twml::nearestInterpolation; same exception wrapping and
// return convention as twml_optim_mdl_infer.
twml_err twml_optim_nearest_interpolation(
  twml_tensor output,
  const twml_tensor input,
  const twml_tensor xs,
  const twml_tensor ys) {
  HANDLE_EXCEPTIONS(
    using namespace twml;
    nearestInterpolation(*getTensor(output),
                         *getConstTensor(input),
                         *getConstTensor(xs),
                         *getConstTensor(ys)););
  return TWML_ERR_NONE;
}