12 #ifndef PLSSVM_DETAIL_IO_LIBSVM_MODEL_PARSING_HPP_
13 #define PLSSVM_DETAIL_IO_LIBSVM_MODEL_PARSING_HPP_
22 #include "fmt/compile.h"
23 #include "fmt/format.h"
37 #include <string_view>
// Parse the header of a LIBSVM model file that is given as a list of lines.
// Returns a tuple of: the parsed plssvm::parameter, rho, one label per support vector,
// and `header_line + 1` (the index of the line following the last header line).
// Throws invalid_file_format_exception if the visible consistency checks fail.
// NOTE(review): this listing omits several original source lines (branch heads, closing braces);
// the comments below describe only what is visible.
82 template <
typename real_type,
typename label_type,
typename size_type>
83 [[nodiscard]]
inline std::tuple<plssvm::parameter, real_type, std::vector<label_type>, std::size_t>
parse_libsvm_model_header(
const std::vector<std::string_view> &lines) {
// total number of support vectors (assigned from the converted header value further below)
87 size_type num_support_vectors{};
// One flag per header entry. Only kernel_type_set is visibly assigned and checked in this listing;
// presumably the others are handled the same way in the omitted lines -- TODO confirm.
90 bool svm_type_set{
false };
91 bool kernel_type_set{
false };
92 bool nr_class_set{
false };
93 bool total_sv_set{
false };
94 bool rho_set{
false };
95 bool label_set{
false };
96 bool nr_sv_set{
false };
// the distinct labels and the number of support vectors per label, filled from the header entries
98 std::vector<label_type> labels{};
99 std::vector<size_type> num_support_vectors_per_class{};
// declared outside the loop because the index is still needed after the loop ends
102 std::size_t header_line = 0;
104 for (; header_line < lines.size(); ++header_line) {
// Drop everything up to and including the first space so that only the entry's value remains.
// If the line contains no space, npos + 1 wraps around to 0 and nothing is removed.
110 std::string_view value{ line };
111 value.remove_prefix(std::min(value.find_first_of(
' ') + 1, value.size()));
// compares the SVM type against "c_svc"; the branch body is not visible (presumably an error) -- TODO confirm
116 if (value !=
"c_svc") {
// the kernel type is read through its stream extraction operator
123 std::istringstream iss{ std::string{ value } };
124 iss >> params.kernel_type;
129 kernel_type_set =
true;
// gamma, degree, and coef0 are converted to the value_type of the respective parameter member
132 params.gamma =
detail::convert_to<
typename decltype(params.gamma)::value_type>(value);
135 params.degree =
detail::convert_to<
typename decltype(params.degree)::value_type>(value);
138 params.coef0 =
detail::convert_to<
typename decltype(params.coef0)::value_type>(value);
141 nr_class = detail::convert_to<unsigned long long>(value);
// a total support vector count of zero is treated specially (branch body not visible)
146 num_support_vectors = detail::convert_to<size_type>(value);
147 if (num_support_vectors == 0) {
154 rho = detail::convert_to<real_type>(value);
// The labels are taken from the trimmed original line lines[header_line] instead of `value`.
// NOTE(review): presumably because `line` was normalized (e.g., lower-cased) before -- confirm.
159 std::string_view original_line =
detail::trim(lines[header_line]);
160 original_line.remove_prefix(std::min(original_line.find_first_of(
' ') + 1, original_line.size()));
// split the remaining text at spaces; fewer than two labels is an error
162 labels = detail::split_as<label_type>(original_line,
' ');
163 if (labels.size() < 2) {
164 throw invalid_file_format_exception{ fmt::format(
"At least two labels must be set, but only {} label ([{}]) was given!", labels.size(), fmt::join(labels,
", ")) };
// duplicate labels are detected by comparing against the size of a std::set built from them
167 std::set<label_type> unique_labels{};
168 for (
const label_type &label : labels) {
169 unique_labels.insert(label);
171 if (labels.size() != unique_labels.size()) {
172 throw invalid_file_format_exception{ fmt::format(
"Provided {} labels but only {} of them was/where unique!", labels.size(), unique_labels.size()) };
// number of support vectors per class; fewer than two entries is an error
178 num_support_vectors_per_class = detail::split_as<size_type>(value,
' ');
179 if (num_support_vectors_per_class.size() < 2) {
180 throw invalid_file_format_exception{ fmt::format(
"At least two nr_sv must be set, but only {} ([{}]) was given!", num_support_vectors_per_class.size(), fmt::join(num_support_vectors_per_class,
", ")) };
// a line equal to "sv" is handled separately; presumably it marks the end of the header -- TODO confirm
184 }
else if (line ==
"sv") {
// checks that a kernel type was provided (branch body not visible)
197 if (!kernel_type_set) {
// Kernel specific validation of explicitly set (non-default) parameters.
// The visible throws reject degree and coef0 for the radial basis function kernel.
201 switch (params.kernel_type) {
203 if (!params.degree.is_default()) {
206 if (!params.gamma.is_default()) {
209 if (!params.coef0.is_default()) {
217 if (!params.degree.is_default()) {
218 throw invalid_file_format_exception{
"Explicitly provided a value for the degree parameter which is not used in the radial basis function kernel!" };
220 if (!params.coef0.is_default()) {
221 throw invalid_file_format_exception{
"Explicitly provided a value for the coef0 parameter which is not used in the radial basis function kernel!" };
// nr_class must match the number of provided labels ...
238 if (nr_class != labels.size()) {
239 throw invalid_file_format_exception{ fmt::format(
"The number of classes (nr_class) is {}, but the provided number of different labels is {} (label)!", nr_class, labels.size()) };
// ... as well as the number of nr_sv entries
245 if (nr_class != num_support_vectors_per_class.size()) {
246 throw invalid_file_format_exception{ fmt::format(
"The number of classes (nr_class) is {}, but the provided number of different labels is {} (nr_sv)!", nr_class, num_support_vectors_per_class.size()) };
// the per-class support vector counts must add up to the total number of support vectors
249 const auto nr_sv_sum = std::accumulate(num_support_vectors_per_class.begin(), num_support_vectors_per_class.end(), size_type{ 0 });
250 if (nr_sv_sum != num_support_vectors) {
251 throw invalid_file_format_exception{ fmt::format(
"The total number of support vectors is {}, but the sum of nr_sv is {}!", num_support_vectors, nr_sv_sum) };
// checks whether any line follows the header (branch body not visible)
254 if (header_line + 1 >= lines.size()) {
// Expand the per-class information into one label per support vector: for each label i, the next
// num_support_vectors_per_class[i] entries starting at `pos` are set to labels[i].
// (`pos` is declared in an omitted line; presumably it starts at 0 -- TODO confirm.)
259 std::vector<label_type> data_labels(num_support_vectors);
261 for (size_type i = 0; i < labels.size(); ++i) {
262 std::fill(data_labels.begin() + pos, data_labels.begin() + pos + num_support_vectors_per_class[i], labels[i]);
263 pos += num_support_vectors_per_class[i];
// the condition guarding this throw is not visible in this listing
268 throw invalid_file_format_exception{ fmt::format(
"Currently only binary classification is supported, but {} different label where given!", nr_class) };
271 return std::make_tuple(params, rho, std::move(data_labels), header_line + 1);
// Fragment of the function that writes the LIBSVM model file header
// (per the accompanying documentation: write_libsvm_model_header, defined at line 296).
// The header text is assembled in out_string and finally printed to `out`.
// NOTE(review): the signature line and several statements are missing from this listing.
295 template <
typename real_type,
typename label_type>
// the header always starts with the SVM type (fixed to c_svc) and the kernel type
300 std::string out_string = fmt::format(
"svm_type c_svc\nkernel_type {}\n", params.
kernel_type);
// Kernel dependent parameter lines: one variant writes degree, gamma, and coef0, the other only gamma.
// The conditions selecting between them are not visible here.
306 out_string += fmt::format(
"degree {}\ngamma {}\ncoef0 {}\n", params.
degree, params.
gamma, params.
coef0);
309 out_string += fmt::format(
"gamma {}\n", params.
gamma);
// different_labels() returns a std::optional according to its listed declaration, hence .value()
314 const std::vector<label_type> label_values = data.
different_labels().value();
// count how often each label occurs in the data set
317 std::map<label_type, std::size_t> label_counts_map;
318 const std::vector<label_type> labels = data.
labels().value();
319 for (
const label_type &l : labels) {
320 ++label_counts_map[l];
// store the counts in the same order as label_values
325 label_counts[i] = label_counts_map[label_values[i]];
// Append nr_class, label, total_sv, nr_sv, and rho followed by the "SV" line.
// Only the label and count arguments are visible; the remaining arguments are in omitted lines.
328 out_string += fmt::format(
"nr_class {}\nlabel {}\ntotal_sv {}\nnr_sv {}\nrho {}\nSV\n",
330 fmt::join(label_values,
" "),
332 fmt::join(label_counts,
" "),
// out_string is additionally passed to a call whose name is not visible (presumably a log call -- TODO confirm)
337 "\n{}\n", out_string);
// write the assembled header to the output
339 out.print(
"{}", out_string);
// Fragment of the function that writes a LIBSVM model file
// (per the accompanying documentation: write_libsvm_model_data, defined at line 371).
// Visible behavior: the support vectors are formatted inside an OpenMP parallel region and printed to the
// file grouped by label, in the order given by label_order.
// NOTE(review): the signature line, closing braces, and several statements are missing from this listing.
370 template <
typename real_type,
typename label_type>
375 const std::vector<std::vector<real_type>> &support_vectors = data.
data();
376 const std::vector<label_type> &labels = data.
labels().value();
// open the output file
380 fmt::ostream out = fmt::output_file(filename);
// CHARS_PER_BLOCK * BLOCK_SIZE is the size of the character buffer used while formatting a line:
// BLOCK_SIZE features are formatted per step, with CHARS_PER_BLOCK characters reserved for each.
394 static constexpr std::size_t CHARS_PER_BLOCK = 48;
396 static constexpr std::size_t BLOCK_SIZE = 128;
// once a thread's output string grows beyond this size (1 MiB), it is printed to the file
398 constexpr std::size_t STRING_BUFFER_SIZE = 1024 * 1024;
// Appends one model line to `output`: the coefficient `a` in "{:.10e}" format followed by all
// non-zero entries of `d` as "index:value" pairs (index is 1-based) and a trailing newline.
401 auto format_libsvm_line = [](std::string &output,
const real_type a,
const std::vector<real_type> &d) {
// static scratch buffer, declared threadprivate through the OpenMP pragma below
402 static constexpr std::size_t STACK_BUFFER_SIZE = BLOCK_SIZE * CHARS_PER_BLOCK;
403 static char buffer[STACK_BUFFER_SIZE];
404 #pragma omp threadprivate(buffer)
406 output.append(fmt::format(FMT_COMPILE(
"{:.10e} "), a));
// process the features in chunks of BLOCK_SIZE; std::min handles the last, possibly shorter chunk
407 for (
typename std::vector<real_type>::size_type j = 0; j < d.size(); j += BLOCK_SIZE) {
409 for (std::size_t i = 0; i < std::min<std::size_t>(BLOCK_SIZE, d.size() - j); ++i) {
// zero entries are skipped
410 if (d[j + i] != real_type{ 0.0 }) {
// `ptr` is declared in an omitted line; presumably it starts at `buffer` for each chunk -- TODO confirm
412 ptr = fmt::format_to(ptr, FMT_COMPILE(
"{}:{:.10e} "), j + i + 1, d[j + i]);
// append only the part of the buffer that was written for this chunk
415 output.append(buffer, ptr - buffer);
417 output.push_back(
'\n');
// One counter per label: it is incremented after a thread printed its remaining output for that label
// and polled further below until it reaches the number of threads.
// NOTE(review): volatile together with omp flush is used for cross-thread signalling here;
// std::atomic is the conventional tool for this -- confirm whether this is intentional.
421 auto counts = std::make_unique<volatile int[]>(label_order.size());
422 #pragma omp parallel default(none) shared(counts, alpha, format_libsvm_line, label_order, labels, support_vectors, out) firstprivate(BLOCK_SIZE, CHARS_PER_BLOCK, num_features)
// output string with reserved capacity (presumably one per thread, since it follows the parallel pragma)
425 std::string out_string;
426 out_string.reserve(STRING_BUFFER_SIZE + (num_features + 1) * CHARS_PER_BLOCK);
// first pass: format all support vectors that carry the first label in label_order
429 #pragma omp for nowait
430 for (
typename std::vector<real_type>::size_type i = 0; i < alpha.size(); ++i) {
431 if (labels[i] == label_order[0]) {
432 format_libsvm_line(out_string, alpha[i], support_vectors[i]);
// Print once the string exceeds the threshold. Any mutual exclusion and the clearing of
// out_string are in omitted lines -- TODO confirm.
435 if (out_string.size() > STRING_BUFFER_SIZE) {
438 out.print(
"{}", out_string);
439 #pragma omp flush(out)
// print whatever is left for the first label, then record that this thread is done with it
449 if (!out_string.empty()) {
450 out.print(
"{}", out_string);
453 counts[0] = counts[0] + 1;
454 #pragma omp flush(counts, out)
// remaining labels, handled in the order given by label_order
457 for (
typename std::vector<label_type>::size_type l = 1; l < label_order.size(); ++l) {
460 #pragma omp for nowait
461 for (
typename std::vector<real_type>::size_type i = 0; i < alpha.size(); ++i) {
462 if (labels[i] == label_order[l]) {
463 format_libsvm_line(out_string, alpha[i], support_vectors[i]);
466 if (out_string.size() > STRING_BUFFER_SIZE) {
469 out.print(
"{}", out_string);
470 #pragma omp flush(out)
// loop while fewer than omp_get_num_threads() increments were recorded for the previous label
// (loop body not visible)
479 while (counts[l - 1] < omp_get_num_threads()) {
// print whatever is left for label l, then record that this thread is done with it
487 if (!out_string.empty()) {
488 out.print(
"{}", out_string);
491 counts[l] = counts[l] + 1;
492 #pragma omp flush(counts, out)
Implements a custom assert macro PLSSVM_ASSERT.
#define PLSSVM_ASSERT(cond, msg,...)
Defines the PLSSVM_ASSERT macro if PLSSVM_ASSERT_ENABLED is defined.
Definition: assert.hpp:74
bool has_labels() const noexcept
Returns whether this data set contains labels or not.
Definition: data_set.hpp:194
size_type num_different_labels() const noexcept
Returns the number of different labels in this data set.
Definition: data_set.hpp:225
const std::vector< std::vector< real_type > > & data() const noexcept
Return the data points in this data set.
Definition: data_set.hpp:189
size_type num_features() const noexcept
Returns the number of features in this data set.
Definition: data_set.hpp:218
std::optional< std::vector< label_type > > different_labels() const
Returns an optional to the different labels in this data set.
Definition: data_set.hpp:633
optional_ref< const std::vector< label_type > > labels() const noexcept
Returns an optional reference to the labels in this data set.
Definition: data_set.hpp:625
size_type num_data_points() const noexcept
Returns the number of data points in this data set.
Definition: data_set.hpp:213
Implements a data set class encapsulating all data points, features, and potential labels.
Defines universal utility functions.
Defines a simple logging function.
Namespace containing implementation details for the IO related functions. Should not directly be used...
Definition: core.hpp:44
std::vector< label_type > write_libsvm_model_header(fmt::ostream &out, const plssvm::parameter &params, const real_type rho, const data_set< real_type, label_type > &data)
Write the LIBSVM model file header to out.
Definition: libsvm_model_parsing.hpp:296
void write_libsvm_model_data(const std::string &filename, const plssvm::parameter &params, const real_type rho, const std::vector< real_type > &alpha, const data_set< real_type, label_type > &data)
Write the LIBSVM model to the file filename.
Definition: libsvm_model_parsing.hpp:371
std::tuple< plssvm::parameter, real_type, std::vector< label_type >, std::size_t > parse_libsvm_model_header(const std::vector< std::string_view > &lines)
Parse the LIBSVM model file header.
Definition: libsvm_model_parsing.hpp:83
bool starts_with(std::string_view str, std::string_view sv) noexcept
Checks if the string str starts with the prefix sv.
void log(const verbosity_level verb, const std::string_view msg, Args &&...args)
Definition: logger.hpp:109
T convert_to(const std::string_view str)
Converts the string str to a value of type T.
Definition: string_conversion.hpp:47
std::string_view trim(std::string_view str) noexcept
Returns a new std::string_view equal to str where all leading and trailing whitespaces are removed.
std::string_view trim_left(std::string_view str) noexcept
Returns a new std::string_view equal to str where all leading whitespaces are removed.
std::string current_date_time()
Return the current date time in the format "YYYY-MM-DD hh:mm:ss".
std::string & to_lower_case(std::string &str)
Convert the string str to its all lower case representation.
Implements the parameter class encapsulating all important C-SVM parameters.
default_value< real_type > coef0
The coef0 parameter used in the polynomial kernel function.
Definition: parameter.hpp:163
default_value< int > degree
The degree parameter used in the polynomial kernel function.
Definition: parameter.hpp:159
default_value< kernel_function_type > kernel_type
The used kernel function: linear, polynomial, or radial basis functions (rbf).
Definition: parameter.hpp:157
default_value< real_type > gamma
The gamma parameter used in the polynomial and rbf kernel functions.
Definition: parameter.hpp:161