13#ifndef _RBTNMSIMPLEX_H_
14#define _RBTNMSIMPLEX_H_
17#define EIGEN_DONT_VECTORIZE
21#include "rxdock/NMState.h"
// Nelder-Mead (downhill simplex) minimizer over an Eigen parameter vector.
//
// Template parameters, as exercised by the visible code in this listing:
//   DataType      - scalar objective type; Optimize() initialises
//                   formerValue/bestValue to std::numeric_limits<DataType>::max().
//   ParameterType - vector-like Eigen type (it has size(), operator() and
//                   array()) that is assigned to and from columns of
//                   m_polytopePoints; presumably an Eigen column vector.
//   Function      - objective, called as fun(parameters); its result is
//                   stored in a DataType.
//   Criterion     - stopping predicate, called as m_criterion(m_state);
//                   Optimize() keeps iterating while it returns true.
//
// NOTE(review): this is an extracted listing. The `class Simplex {` line and
// the declarations of m_delta, use_deltas and m_state (all referenced later)
// are not visible here.
38template <
class DataType,
class ParameterType,
class Function,
class Criterion>
// Per-coordinate initial step sizes, set by SetDelta(ParameterType).
42 ParameterType m_deltas;
// Polytope vertices, one per column: n rows x (n + 1) columns, where n is the
// start point's size (see InitializePolytope).
45 Eigen::Array<DataType, Eigen::Dynamic, Eigen::Dynamic> m_polytopePoints;
// Cached objective value per vertex: 1 row x (n + 1) columns; column i holds
// fun(m_polytopePoints.col(i)).
46 Eigen::Array<DataType, Eigen::Dynamic, Eigen::Dynamic> m_polytopeValues;
// Stopping criterion checked at the top of every Optimize() iteration.
49 Criterion m_criterion;
// Builds the initial polytope around start_point with the same step `delta`
// on every axis. Vertex 0 is start_point. For i = 1..n, vertex i is
// start_point with coordinate i - 1 increased by delta. Each vertex is
// evaluated once, and its value is cached in m_polytopeValues.
// NOTE(review): the rest of the parameter list is not visible in this
// listing; `fun` (called below) is presumably the last parameter. A delta of
// 0 makes every vertex coincide with start_point, giving a degenerate polytope.
51 void InitializePolytope(
const ParameterType &start_point, DataType delta,
// n x (n + 1) vertex matrix and 1 x (n + 1) value row.
54 m_polytopePoints.resize(start_point.size(), start_point.size() + 1);
55 m_polytopeValues.resize(1, start_point.size() + 1);
56 m_polytopePoints.col(0) = start_point;
57 m_polytopeValues(0, 0) = fun(start_point);
58 Display(fun, m_polytopePoints.col(0));
// One axis-aligned vertex per coordinate, each offset by the shared step.
59 for (
int i = 1; i < start_point.size() + 1; ++i) {
60 m_polytopePoints.col(i) = start_point;
61 m_polytopePoints(i - 1, i) += delta;
62 m_polytopeValues(0, i) = fun(m_polytopePoints.col(i));
63 Display(fun, m_polytopePoints.col(i));
// Per-axis variant of the uniform-step overload. For i = 1..n, vertex i is
// start_point with coordinate i - 1 increased by deltas(i - 1), so each axis
// has its own initial step. Vertex 0 is start_point. Each vertex is evaluated
// once, and its value is cached in m_polytopeValues.
// NOTE(review): deltas must hold at least start_point.size() entries, and no
// size check is visible here. deltas is also taken by value, so each call
// copies it; `const ParameterType &` would avoid the copy without affecting
// callers.
67 void InitializePolytope(
const ParameterType &start_point,
68 ParameterType deltas, Function &fun)
// n x (n + 1) vertex matrix and 1 x (n + 1) value row.
70 m_polytopePoints.resize(start_point.size(), start_point.size() + 1);
71 m_polytopeValues.resize(1, start_point.size() + 1);
72 m_polytopePoints.col(0) = start_point;
73 m_polytopeValues(0, 0) = fun(start_point);
74 Display(fun, m_polytopePoints.col(0));
// One axis-aligned vertex per coordinate, offset by that coordinate's own step.
75 for (
int i = 1; i < start_point.size() + 1; ++i) {
76 m_polytopePoints.col(i) = start_point;
77 m_polytopePoints(i - 1, i) += deltas(i - 1);
78 m_polytopeValues(0, i) = fun(m_polytopePoints.col(i));
79 Display(fun, m_polytopePoints.col(i));
83 void Display(Function &fun,
const ParameterType ¶meters)
85 LOG_F(1,
"Point: {}", parameters);
86 LOG_F(1,
"Value: {}", fun(parameters));
// Scans the cached vertex values in m_polytopeValues. Through the reference
// out-parameters it reports three column indices:
//   best       - lowest value
//   worst      - highest value
//   near_worst - next-highest value
// Only cached values are read; `fun` is not called in the visible lines.
// Columns 0 and 1 are read unconditionally, so the polytope needs at least
// two vertices (a parameter vector of size >= 1).
// NOTE(review): several parts are not visible in this listing: the final
// parameter (near_worst, used below), the seeding assignments inside the
// first if/else, and the assignments inside the loop.
89 void FindBestWorstNearWorst(Function &fun,
int &best,
int &worst,
// Seed best/worst/near_worst from the first two vertices (assignments not visible).
92 if (m_polytopeValues(0, 0) > m_polytopeValues(0, 1)) {
// Classify every remaining vertex against the current best / worst / near-worst.
100 for (
int i = 2; i < m_polytopeValues.cols(); ++i) {
101 if (m_polytopeValues(0, i) < m_polytopeValues(0, best)) {
104 if (m_polytopeValues(0, i) > m_polytopeValues(0, worst)) {
107 }
else if (m_polytopeValues(0, i) > m_polytopeValues(0, near_worst)) {
// Diagnostic dump of the three selected vertices at verbosity 1.
111 LOG_F(1,
"Worst value is in position {} {}\n{}", worst,
112 m_polytopeValues(0, worst), m_polytopePoints.col(worst));
113 LOG_F(1,
"Near-worst value is in position {} {}\n{}", near_worst,
114 m_polytopeValues(0, near_worst), m_polytopePoints.col(near_worst));
115 LOG_F(1,
"Best value is in position {} {}\n{}", best,
116 m_polytopeValues(0, best), m_polytopePoints.col(best));
// Returns a candidate vertex on the line through the discarded (worst)
// vertex w and the centroid of the remaining vertices. Let n = sum.size(),
// fac1 = (1 - t) / n and fac2 = fac1 - t. Then sum*fac1 - w*fac2 simplifies to
//   (1 - t) * c + t * w,   with c = (sum - w) / n.
// When `sum` is the sum of all n + 1 vertices (as Optimize() passes), c is the
// centroid of the n vertices other than w. Optimize() uses:
//   t = -1  -> reflection           2c - w
//   t = -2  -> expansion            3c - 2w
//   t = -.5 -> outside contraction  1.5c - 0.5w
// NOTE(review): the line declaring `t` is not visible in this listing. This
// assumes t is a floating type; an integral t would make (1 - t) / sum.size()
// an integer division — confirm the declaration.
119 ParameterType CreateNewParameters(
const ParameterType &sum,
120 const ParameterType &discarded_point,
122 DataType fac1 = (1 - t) / sum.size();
123 DataType fac2 = fac1 - t;
124 return sum * fac1 - discarded_point * fac2;
128 Simplex(
const Criterion &criterion)
129 : m_delta(0), use_deltas(
false), m_criterion(criterion) {}
// Runs Nelder-Mead minimisation of `fun` from the point set by
// SetStartPoint(), looping while m_criterion(m_state) returns true.
// Each iteration, as far as it is visible in this listing:
//   1. Rank the vertices. Publish the best vertex as current and the worst as
//      former in m_state. Update bestValue/bestParameters when the best vertex
//      improves on them.
//   2. Reflect the worst vertex through the centroid of the others (t = -1).
//   3. If the reflection beats bestValue, also try an expansion (t = -2).
//   4. If the reflection is worse than the near-worst vertex, try an outside
//      contraction (t = -.5). If that is worse than the reflection, shrink
//      every vertex halfway towards the best one and re-evaluate all of them.
//   5. Otherwise the accepted point replaces the worst vertex (the branch
//      headers for these cases are not visible here).
// Effect on m_state: currentValue/currentParameters are first set to the best
// vertex, then overwritten with the point that replaced the worst vertex.
// NOTE(review): the end of the loop, including any m_state.iteration
// increment, is not visible in this listing.
131 void Optimize(Function &fun)
// Reset the state. formerValue/bestValue start at the largest DataType, so
// the first ranked best vertex replaces bestValue.
// NOTE(review): the start point is evaluated here and again as vertex 0 inside
// InitializePolytope, costing one extra fun() call.
133 m_state.iteration = 0;
134 m_state.currentValue = fun(m_state.currentParameters);
135 m_state.formerValue = std::numeric_limits<DataType>::max();
136 m_state.bestValue = std::numeric_limits<DataType>::max();
// Build the initial polytope from either the per-axis steps (m_deltas) or the
// uniform step (m_delta).
// NOTE(review): the condition choosing between these two calls (presumably
// use_deltas) is not visible in this listing.
139 InitializePolytope(m_state.currentParameters, m_deltas, fun);
141 InitializePolytope(m_state.currentParameters, m_delta, fun);
// One Nelder-Mead step per iteration while the criterion holds.
144 while (m_criterion(m_state)) {
145 LOG_F(1,
"Starting iteration {}", m_state.iteration);
// Rank the vertices; publish best as "current" and worst as "former".
146 int best, worst, near_worst;
147 FindBestWorstNearWorst(fun, best, worst, near_worst);
148 m_state.currentValue = m_polytopeValues(0, best);
149 m_state.currentParameters = m_polytopePoints.col(best);
150 m_state.formerValue = m_polytopeValues(0, worst);
151 m_state.formerParameters = m_polytopePoints.col(worst);
// Track the best vertex seen so far.
153 if (m_state.currentValue < m_state.bestValue) {
154 m_state.bestValue = m_state.currentValue;
155 m_state.bestParameters = m_state.currentParameters;
// Reflection: 2c - w, where c is the centroid of the non-worst vertices.
// NOTE(review): rowwise().sum() is recomputed for every candidate in this
// iteration; it could be computed once per iteration.
158 ParameterType new_parameters = CreateNewParameters(
159 m_polytopePoints.rowwise().sum(), m_polytopePoints.col(worst), -1);
160 DataType new_value = fun(new_parameters);
161 LOG_F(1,
"Trying normal");
162 Display(fun, new_parameters);
// The reflection beat the best value so far: also try an expansion (3c - 2w).
163 if (new_value < m_state.bestValue) {
164 ParameterType expansion_parameters = CreateNewParameters(
165 m_polytopePoints.rowwise().sum(), m_polytopePoints.col(worst), -2);
166 DataType expansion_value = fun(expansion_parameters);
167 LOG_F(1,
"Trying expansion");
168 Display(fun, expansion_parameters);
// NOTE(review): the expansion is kept whenever it beats bestValue, so it can be
// kept even when it is worse than the reflected point (new_value). Textbook
// Nelder-Mead keeps the better of the two. Behaviour left unchanged here.
169 if (expansion_value < m_state.bestValue) {
170 m_state.currentValue = m_polytopeValues(0, worst) = expansion_value;
171 m_state.currentParameters = m_polytopePoints.col(worst) =
172 expansion_parameters;
// Presumably the else-branch (its header is not visible): keep the reflected point.
174 m_state.currentValue = m_polytopeValues(0, worst) = new_value;
175 m_state.currentParameters = m_polytopePoints.col(worst) =
178 }
else if (new_value > m_polytopeValues(0, near_worst)) {
// The reflection is worse than the near-worst vertex: try an outside
// contraction (1.5c - 0.5w).
180 ParameterType contraction_parameters = CreateNewParameters(
181 m_polytopePoints.rowwise().sum(), m_polytopePoints.col(worst), -.5);
182 DataType contraction_value = fun(contraction_parameters);
183 LOG_F(1,
"Trying contraction");
184 Display(fun, contraction_parameters);
// The contraction did not improve on the reflection: shrink the polytope
// halfway towards the best vertex.
185 if (contraction_value > new_value) {
186 LOG_F(1,
"Contraction around lowest");
187 Display(fun, contraction_parameters);
188 ParameterType best_parameters = m_polytopePoints.col(best);
189 LOG_F(1,
"Difference from the best {}",
190 ((m_polytopePoints.colwise() - best_parameters.array()) / 2));
// NOTE(review): the assignment target (line 191) and the operator on line 193
// are not visible. Presumably m_polytopePoints =
// ((points - best) / 2).colwise() + best.
192 ((m_polytopePoints.colwise() - best_parameters.array()) / 2)
194 best_parameters.array();
// Re-evaluate every vertex after the shrink.
195 for (
int i = 0; i < m_polytopePoints.cols(); ++i) {
196 m_polytopeValues(0, i) = fun(m_polytopePoints.col(i));
// Presumably the else-branch (its header is not visible): accept the
// contraction in place of the worst vertex.
199 m_state.currentValue = m_polytopeValues(0, worst) = contraction_value;
200 m_state.currentParameters = m_polytopePoints.col(worst) =
201 contraction_parameters;
// Presumably the remaining case (its branch header is not visible): the
// reflection lies between best and near-worst, so accept it.
205 m_state.currentValue = m_polytopeValues(0, worst) = new_value;
206 m_state.currentParameters = m_polytopePoints.col(worst) =
// Body of GetBestParameters(): returns the best parameter vector recorded in
// m_state. Optimize() updates it whenever the ranked best vertex improves on
// bestValue.
// NOTE(review): the signature line (217) is not visible in this listing.
218 return m_state.bestParameters;
// Sets the starting point. Optimize() evaluates m_state.currentParameters and
// builds the initial polytope around it.
226 void SetStartPoint(
const ParameterType &point) {
227 m_state.currentParameters = point;
// Sets a single initial step, applied to every axis by
// InitializePolytope(start_point, DataType delta, ...).
// NOTE(review): Optimize() picks m_delta or m_deltas with a condition that is
// not visible (presumably use_deltas). Any line in this setter that updates
// that flag is also not visible in this listing.
230 void SetDelta(DataType delta) {
231 this->m_delta = delta;
// Sets per-axis initial steps; deltas(i) offsets coordinate i when
// InitializePolytope(start_point, ParameterType deltas, fun) builds vertex i + 1.
// NOTE(review): any line that switches Optimize() to the per-axis overload
// (presumably by setting use_deltas) is not visible in this listing. deltas is
// taken by value, so it is copied on every call.
235 void SetDelta(ParameterType deltas) {
236 this->m_deltas = deltas;
// Factory that deduces Simplex's template arguments. DataType and
// ParameterType come from the nested Function::DataType and
// Function::ParameterType typedefs; Function and Criterion come from the
// argument types. `fun` is only used for this deduction and is neither
// stored nor called here.
// NOTE(review): `static` on a function template in a header gives every
// translation unit its own internal-linkage copy; `inline` (or no specifier)
// is the usual choice. The template-argument line (243) and the braces are not
// visible in this listing.
241template <
class Function,
class Criterion>
242static Simplex<
typename Function::DataType,
typename Function::ParameterType,
244CreateSimplex(Function &fun,
const Criterion &criterion)
246 return Simplex<
typename Function::DataType,
typename Function::ParameterType,
247 Function, Criterion>(criterion);
Definition NMSimplex.h:39
const ParameterType & GetBestParameters() const
Definition NMSimplex.h:217
DataType GetBestValue() const
Definition NMSimplex.h:224