PLSSVM - Parallel Least Squares Support Vector Machine  2.0.0
A Least Squares Support Vector Machine implementation using different backends.
kernel_function_types.hpp
Go to the documentation of this file.
1 
12 #ifndef PLSSVM_KERNEL_FUNCTION_TYPES_HPP_
13 #define PLSSVM_KERNEL_FUNCTION_TYPES_HPP_
14 #pragma once
15 
16 #include "plssvm/detail/assert.hpp" // PLSSVM_ASSERT
17 #include "plssvm/detail/operators.hpp" // dot product, plssvm::squared_euclidean_dist
18 #include "plssvm/detail/type_traits.hpp" // plssvm::detail::always_false_v
19 #include "plssvm/detail/utility.hpp" // plssvm::detail::get
20 #include "plssvm/exceptions/exceptions.hpp" // plssvm::unsupported_kernel_type_exception
21 
22 #include <cmath> // std::pow, std::exp, std::fma
23 #include <iosfwd> // forward declare std::ostream and std::istream
24 #include <vector> // std::vector
25 
26 namespace plssvm {
27 
/**
 * @brief Enum class for all implemented kernel functions.
 * @details The explicit underlying values (0, 1, 2) are part of the declaration and must stay stable.
 */
enum class kernel_function_type {
    /** The linear kernel function: the dot product of both vectors, `<u, v>`. */
    linear = 0,
    /** The polynomial kernel function: `(gamma * <u, v> + coef0)^degree`. */
    polynomial = 1,
    /** The radial basis function (rbf) kernel: `exp(-gamma * ||u - v||^2)`. */
    rbf = 2
};
39 
46 std::ostream &operator<<(std::ostream &out, kernel_function_type kernel);
47 
54 [[nodiscard]] std::string_view kernel_function_type_to_math_string(kernel_function_type kernel) noexcept;
55 
63 std::istream &operator>>(std::istream &in, kernel_function_type &kernel);
64 
75 template <kernel_function_type kernel, typename real_type, typename... Args>
76 [[nodiscard]] inline real_type kernel_function(const std::vector<real_type> &xi, const std::vector<real_type> &xj, Args &&...args) {
77  using namespace plssvm::operators;
78 
79  PLSSVM_ASSERT(xi.size() == xj.size(), "Sizes mismatch!: {} != {}", xi.size(), xj.size());
80 
81  if constexpr (kernel == kernel_function_type::linear) {
82  static_assert(sizeof...(args) == 0, "Illegal number of additional parameters! Must be 0.");
83  return transposed{ xi } * xj;
84  } else if constexpr (kernel == kernel_function_type::polynomial) {
85  static_assert(sizeof...(args) == 3, "Illegal number of additional parameters! Must be 3.");
86  const auto degree = static_cast<real_type>(detail::get<0>(args...));
87  const auto gamma = static_cast<real_type>(detail::get<1>(args...));
88  const auto coef0 = static_cast<real_type>(detail::get<2>(args...));
89  return std::pow(std::fma(gamma, (transposed<real_type>{ xi } * xj), coef0), degree);
90  } else if constexpr (kernel == kernel_function_type::rbf) {
91  static_assert(sizeof...(args) == 1, "Illegal number of additional parameters! Must be 1.");
92  const auto gamma = static_cast<real_type>(detail::get<0>(args...));
93  return std::exp(-gamma * squared_euclidean_dist(xi, xj));
94  } else {
95  static_assert(detail::always_false_v<real_type>, "Unknown kernel type!");
96  }
97 }
98 
99 // forward declare parameter class
100 namespace detail {
101 template <typename>
102 struct parameter;
103 }
104 
114 template <typename real_type>
115 [[nodiscard]] real_type kernel_function(const std::vector<real_type> &xi, const std::vector<real_type> &xj, const detail::parameter<real_type> &params);
116 
117 } // namespace plssvm
118 
119 #endif // PLSSVM_KERNEL_FUNCTION_TYPES_HPP_
Implements a custom assert macro PLSSVM_ASSERT.
#define PLSSVM_ASSERT(cond, msg,...)
Defines the PLSSVM_ASSERT macro if PLSSVM_ASSERT_ENABLED is defined.
Definition: assert.hpp:74
Defines universal utility functions.
Implements custom exception classes derived from std::runtime_error including source location information.
Namespace containing operator overloads for std::vector and other mathematical functions on vectors.
Definition: core.hpp:49
T squared_euclidean_dist(const std::vector< T > &lhs, const std::vector< T > &rhs)
Calculates the squared Euclidean distance of both vectors: $d^2(x, y) = (x_1 - y_1)^2 + (x_2 - y_2)^2 + \dots + (x_n - y_n)^2$.
Definition: operators.hpp:162
The main namespace containing all public API functions.
Definition: backend_types.hpp:24
std::istream & operator>>(std::istream &in, backend_type &backend)
Use the input-stream in to initialize the backend type.
std::ostream & operator<<(std::ostream &out, backend_type backend)
Output the backend to the given output-stream out.
kernel_function_type
Enum class for all implemented kernel functions.
Definition: kernel_function_types.hpp:31
detail::parameter< double > parameter
The public parameter type uses double to store the SVM parameters.
Definition: parameter.hpp:328
real_type kernel_function(const std::vector< real_type > &xi, const std::vector< real_type > &xj, Args &&...args)
Computes the value of the kernel function, selected at compile time, for the two vectors xi and xj.
Definition: kernel_function_types.hpp:76
std::string_view kernel_function_type_to_math_string(kernel_function_type kernel) noexcept
Return the mathematical representation of the kernel_type kernel.
Defines (arithmetic) functions on std::vector and scalars.
Class for encapsulating all important C-SVM parameters.
Definition: parameter.hpp:106
Wrapper struct for overloading the dot product operator.
Definition: operators.hpp:99
Defines some generic type traits used in the PLSSVM library.