PLSSVM - Parallel Least Squares Support Vector Machine  2.0.0
A Least Squares Support Vector Machine implementation using different backends.
model.hpp
Go to the documentation of this file.
1 
12 #ifndef PLSSVM_MODEL_HPP_
13 #define PLSSVM_MODEL_HPP_
14 #pragma once
15 
16 #include "plssvm/data_set.hpp" // plssvm::data_set
17 #include "plssvm/detail/assert.hpp" // PLSSVM_ASSERT
18 #include "plssvm/detail/io/libsvm_model_parsing.hpp" // plssvm::detail::io::{parse_libsvm_model_header, write_libsvm_model_data}
19 #include "plssvm/detail/io/libsvm_parsing.hpp" // plssvm::detail::io::parse_libsvm_data
20 #include "plssvm/detail/logger.hpp" // plssvm::detail::log, plssvm::verbosity_level
21 #include "plssvm/detail/performance_tracker.hpp" // plssvm::detail::tracking_entry
22 #include "plssvm/detail/type_list.hpp" // plssvm::detail::{real_type_list, label_type_list, type_list_contains_v}
23 #include "plssvm/parameter.hpp" // plssvm::parameter
24 
25 #include "fmt/chrono.h" // format std::chrono types using fmt
26 #include "fmt/core.h" // fmt::format
27 
28 #include <chrono> // std::chrono::{time_point, steady_clock, duration_cast, milliseconds}
29 #include <cstddef> // std::size_t
30 #include <iostream> // std::cout, std::endl
31 #include <memory> // std::shared_ptr, std::make_shared
32 #include <string> // std::string
33 #include <tuple> // std::tie
34 #include <utility> // std::move
35 #include <vector> // std::vector
36 
37 namespace plssvm {
38 
// Class encapsulating the result of a call to the SVM fit function: the support vectors with their
// labels, the learned weights (alpha), the bias (rho), and the SVM parameters used for learning.
// T is the real type of the data points, U is the label type (int if not specified).
49 template <typename T, typename U = int>
50 class model {
51  // make sure only valid template types are used
52  static_assert(detail::type_list_contains_v<T, detail::real_type_list>, "Illegal real type provided! See the 'real_type_list' in the type_list.hpp header for a list of the allowed types.");
53  static_assert(detail::type_list_contains_v<U, detail::label_type_list>, "Illegal label type provided! See the 'label_type_list' in the type_list.hpp header for a list of the allowed types.");
54 
55  // plssvm::csvm needs the private constructor
56  friend class csvm;
57 
58  public:
  // The type of the data points.
60  using real_type = T;
  // The type of the labels.
62  using label_type = U;
  // The unsigned size type.
64  using size_type = std::size_t;
65 
  // Read a previously learned model from the LIBSVM model file `filename`.
71  explicit model(const std::string &filename);
72 
  // Save the model to the LIBSVM model file `filename` for later usage.
77  void save(const std::string &filename) const;
78 
  // The number of support vectors used in this model.
83  [[nodiscard]] size_type num_support_vectors() const noexcept { return num_support_vectors_; }
  // The number of features of the support vectors used in this model.
88  [[nodiscard]] size_type num_features() const noexcept { return num_features_; }
89 
  // The SVM parameters that were used to learn this model.
94  [[nodiscard]] const parameter &get_params() const noexcept { return params_; }
  // The support vectors representing the learned model (the data points of the wrapped data set).
100  [[nodiscard]] const std::vector<std::vector<real_type>> &support_vectors() const noexcept { return data_.data(); }
101 
  // The labels of the support vectors.
  // NOTE(review): the result of data_.labels() is dereferenced without a check; this assumes the
  // wrapped data set always carries labels -- confirm against the constructors.
107  [[nodiscard]] const std::vector<label_type> &labels() const noexcept { return data_.labels()->get(); }
  // The number of different labels, forwarded from the wrapped data set.
114  [[nodiscard]] size_type num_different_labels() const noexcept { return data_.num_different_labels(); }
  // The different labels of the support vectors, returned by value.
  // Not noexcept: .value() is called on the result of data_.different_labels().
120  [[nodiscard]] std::vector<label_type> different_labels() const { return data_.different_labels().value(); }
121 
  // The learned weights (alpha values) for the support vectors.
127  [[nodiscard]] const std::vector<real_type> &weights() const noexcept {
    // alpha_ptr_ is dereferenced below, so it must have been set by one of the constructors
128  PLSSVM_ASSERT(alpha_ptr_ != nullptr, "The alpha_ptr may never be a nullptr!");
129  return *alpha_ptr_;
130  }
  // The bias value after learning.
135  [[nodiscard]] real_type rho() const noexcept { return rho_; }
136 
137  private:
  // NOTE(review): source lines 138-154 are not visible in this listing. The accessors above use the
  // members params_, data_, num_support_vectors_ and num_features_, and csvm needs a private
  // constructor -- presumably all of them are declared here; confirm against the real header.
146 
155 
  // The learned weights for each support vector; null until set by a constructor.
157  std::shared_ptr<std::vector<real_type>> alpha_ptr_{ nullptr };
  // The bias after learning this model.
159  real_type rho_{ 0.0 };
160 
  // A vector used to speed up the prediction in case of the linear kernel function; starts out empty.
166  std::shared_ptr<std::vector<real_type>> w_{ std::make_shared<std::vector<real_type>>() };
167 };
168 
// Read a previously learned model from the LIBSVM model file `filename`.
// Fills params_, rho_, num_support_vectors_, num_features_, data_ and alpha_ptr_ from the file content
// and records the elapsed wall time together with the sizes and the file name as tracking entries.
169 template <typename T, typename U>
170 model<T, U>::model(const std::string &filename) {
171  const std::chrono::time_point start_time = std::chrono::steady_clock::now();
172 
173  // open the file
174  detail::io::file_reader reader{ filename };
  // split the file content into lines; '#' is presumably the comment marker (the parameter of read_lines is named 'comment')
175  reader.read_lines('#');
176 
177  // parse the libsvm model header
178  std::vector<label_type> labels{};
179  std::size_t num_header_lines{};
180  std::tie(params_, rho_, labels, num_header_lines) = detail::io::parse_libsvm_model_header<real_type, label_type, size_type>(reader.lines());
181 
182  // create empty support vectors and alpha vector
183  std::vector<std::vector<real_type>> support_vectors;
184  std::vector<real_type> alphas;
185 
186  // parse libsvm model data
  // the generic LIBSVM data parser is reused: its per-row value slot is received as the alpha values,
  // which is why its second template argument is real_type; num_header_lines is passed so that the
  // parser can account for the header (presumably by skipping it -- confirm in parse_libsvm_data)
187  std::tie(num_support_vectors_, num_features_, support_vectors, alphas) = detail::io::parse_libsvm_data<real_type, real_type>(reader, num_header_lines);
188 
189  // create data set
190  data_ = data_set<real_type, label_type>{ std::move(support_vectors), std::move(labels) };
  // decltype(alphas) is std::vector<real_type>, i.e., the pointee type of alpha_ptr_
191  alpha_ptr_ = std::make_shared<decltype(alphas)>(std::move(alphas));
192 
193  const std::chrono::time_point end_time = std::chrono::steady_clock::now();
  // NOTE(review): source line 194, which opens the call consuming the format string and tracking
  // entries below, is not visible in this listing -- presumably a detail::log(...) call; confirm.
195  "Read {} support vectors with {} features in {} using the libsvm model parser from file '{}'.\n\n",
196  detail::tracking_entry{ "model_read", "num_support_vectors", num_support_vectors_ },
197  detail::tracking_entry{ "model_read", "num_features", num_features_ },
198  detail::tracking_entry{ "model_read", "time", std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time) },
199  detail::tracking_entry{ "model_read", "filename", filename });
201 }
202 
// Definition of the constructor that builds a model from SVM parameters and a data set
// (presumably the private constructor needed by plssvm::csvm).
// NOTE(review): the signature line (source line 204) is not visible in this listing; the initializer
// list below shows two parameters named 'params' and 'data' -- confirm their exact types.
// The initializer list moves both arguments into the members, takes the number of support vectors
// and features from the stored data set, and allocates one value-initialized alpha entry per data
// point. The later initializers read the member data_ rather than the moved-from argument, so data_
// must be declared before num_support_vectors_, num_features_ and alpha_ptr_ -- confirm in the class.
203 template <typename T, typename U>
205  params_{ std::move(params) }, data_{ std::move(data) }, num_support_vectors_{ data_.num_data_points() }, num_features_{ data_.num_features() }, alpha_ptr_{ std::make_shared<std::vector<real_type>>(data_.num_data_points()) } {}
206 
// Save the model to the LIBSVM model file `filename` for later usage.
// Records the elapsed wall time together with the sizes and the file name as tracking entries.
207 template <typename T, typename U>
208 void model<T, U>::save(const std::string &filename) const {
209  const std::chrono::time_point start_time = std::chrono::steady_clock::now();
210 
211  // save model file header and support vectors
  // alpha_ptr_ is dereferenced here without a check; it is set by the constructors visible in this file
212  detail::io::write_libsvm_model_data(filename, params_, rho_, *alpha_ptr_, data_);
213 
214  const std::chrono::time_point end_time = std::chrono::steady_clock::now();
  // NOTE(review): source line 215, which opens the call consuming the format string and tracking
  // entries below, is not visible in this listing -- presumably a detail::log(...) call; confirm.
216  "Write {} support vectors with {} features in {} to the libsvm model file '{}'.\n",
217  detail::tracking_entry{ "model_write", "num_support_vectors", num_support_vectors_ },
218  detail::tracking_entry{ "model_write", "num_features", num_features_ },
219  detail::tracking_entry{ "model_write", "time", std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time) },
220  detail::tracking_entry{ "model_write", "filename", filename });
222 }
223 
224 } // namespace plssvm
225 
226 #endif // PLSSVM_MODEL_HPP_
Implements a custom assert macro PLSSVM_ASSERT.
#define PLSSVM_ASSERT(cond, msg,...)
Defines the PLSSVM_ASSERT macro if PLSSVM_ASSERT_ENABLED is defined.
Definition: assert.hpp:74
Base class for all C-SVM backends.
Definition: csvm.hpp:50
size_type num_different_labels() const noexcept
Returns the number of different labels in this data set.
Definition: data_set.hpp:225
const std::vector< std::vector< real_type > > & data() const noexcept
Return the data points in this data set.
Definition: data_set.hpp:189
std::optional< std::vector< label_type > > different_labels() const
Returns an optional to the different labels in this data set.
Definition: data_set.hpp:633
optional_ref< const std::vector< label_type > > labels() const noexcept
Returns an optional reference to the labels in this data set.
Definition: data_set.hpp:625
The plssvm::detail::file_reader class is responsible for reading a file and splitting it into its lines.
Definition: file_reader.hpp:42
const std::vector< std::string_view > & read_lines(std::string_view comment={ "\n" })
Read the content of the associated file and split it into lines, ignoring empty lines and lines starting with the comment token.
Implements a class encapsulating the result of a call to the SVM fit function. A model is used to predict the labels of new data points.
Definition: model.hpp:50
const std::vector< std::vector< real_type > > & support_vectors() const noexcept
The support vectors representing the learned model.
Definition: model.hpp:100
size_type num_features_
The number of features per support vector.
Definition: model.hpp:154
T real_type
The type of the data points: either float or double.
Definition: model.hpp:60
data_set< real_type, label_type > data_
The data (support vectors + respective label) used to learn this model.
Definition: model.hpp:150
real_type rho_
The bias after learning this model.
Definition: model.hpp:159
parameter params_
The SVM parameter used to learn this model.
Definition: model.hpp:148
const std::vector< label_type > & labels() const noexcept
Returns the labels of the support vectors.
Definition: model.hpp:107
void save(const std::string &filename) const
Save the model to a LIBSVM model file for later usage.
Definition: model.hpp:208
size_type num_features() const noexcept
The number of features of the support vectors used in this model.
Definition: model.hpp:88
const std::vector< real_type > & weights() const noexcept
The learned weights for the support vectors.
Definition: model.hpp:127
U label_type
The type of the labels: any arithmetic type or std::string.
Definition: model.hpp:62
std::shared_ptr< std::vector< real_type > > w_
A vector used to speedup the prediction in case of the linear kernel function.
Definition: model.hpp:166
const parameter & get_params() const noexcept
Returns the SVM parameters that were used to learn this model.
Definition: model.hpp:94
std::vector< label_type > different_labels() const
Returns the different labels of the support vectors.
Definition: model.hpp:120
std::size_t size_type
The unsigned size type.
Definition: model.hpp:64
size_type num_different_labels() const noexcept
Returns the number of different labels in this data set.
Definition: model.hpp:114
real_type rho() const noexcept
The bias value after learning.
Definition: model.hpp:135
size_type num_support_vectors_
The number of support vectors representing this model.
Definition: model.hpp:152
size_type num_support_vectors() const noexcept
The number of support vectors used in this model.
Definition: model.hpp:83
std::shared_ptr< std::vector< real_type > > alpha_ptr_
The learned weights for each support vector.
Definition: model.hpp:157
model(const std::string &filename)
Read a previously learned model from the LIBSVM model file filename.
Definition: model.hpp:170
Implements a data set class encapsulating all data points, features, and potential labels.
Implements parsing functions for the LIBSVM model file.
Implements parsing functions for the LIBSVM file format.
Defines a simple logging function.
void write_libsvm_model_data(const std::string &filename, const plssvm::parameter &params, const real_type rho, const std::vector< real_type > &alpha, const data_set< real_type, label_type > &data)
Write the LIBSVM model to the file filename.
Definition: libsvm_model_parsing.hpp:371
void log(const verbosity_level verb, const std::string_view msg, Args &&...args)
Definition: logger.hpp:109
The main namespace containing all public API functions.
Definition: backend_types.hpp:24
Implements the parameter class encapsulating all important C-SVM parameters.
Defines a performance tracker which can dump performance information in a YAML file.
#define PLSSVM_DETAIL_PERFORMANCE_TRACKER_ADD_TRACKING_ENTRY(entry)
Defines the PLSSVM_DETAIL_PERFORMANCE_TRACKER_ADD_TRACKING_ENTRY macro if PLSSVM_PERFORMANCE_TRACKER_ENABLED is defined.
Definition: performance_tracker.hpp:245
A single tracking entry containing a specific category, a unique name, and the actual value to be tracked.
Definition: performance_tracker.hpp:40
All possible real_type and label_type combinations for a plssvm::model and plssvm::data_set.