RxDock 0.1.0
A fast, versatile, and open-source program for docking ligands to proteins and nucleic acids
Loading...
Searching...
No Matches
NMSimplex.h
1/***********************************************************************
2 * The rDock program was developed from 1998 - 2006 by the software team
3 * at RiboTargets (subsequently Vernalis (R&D) Ltd).
4 * In 2006, the software was licensed to the University of York for
5 * maintenance and distribution.
6 * In 2012, Vernalis and the University of York agreed to release the
7 * program as Open Source software.
8 * This version is licensed under GNU-LGPL version 3.0 with support from
9 * the University of Barcelona.
10 * http://rdock.sourceforge.net/
11 ***********************************************************************/
12
13#ifndef _RBTNMSIMPLEX_H_
14#define _RBTNMSIMPLEX_H_
15
#include <limits>
#include <utility>

#ifdef __PGI
#define EIGEN_DONT_VECTORIZE
#endif
#include <Eigen/Core>

#include "rxdock/NMState.h"

#include <loguru.hpp>
24
25namespace rxdock {
26
27namespace neldermead {
28
38template <class DataType, class ParameterType, class Function, class Criterion>
39class Simplex {
40private:
41 DataType m_delta;
42 ParameterType m_deltas;
43 bool use_deltas;
44
45 Eigen::Array<DataType, Eigen::Dynamic, Eigen::Dynamic> m_polytopePoints;
46 Eigen::Array<DataType, Eigen::Dynamic, Eigen::Dynamic> m_polytopeValues;
47
49 Criterion m_criterion;
50
51 void InitializePolytope(const ParameterType &start_point, DataType delta,
52 Function &fun) // const
53 {
54 m_polytopePoints.resize(start_point.size(), start_point.size() + 1);
55 m_polytopeValues.resize(1, start_point.size() + 1);
56 m_polytopePoints.col(0) = start_point;
57 m_polytopeValues(0, 0) = fun(start_point);
58 Display(fun, m_polytopePoints.col(0));
59 for (int i = 1; i < start_point.size() + 1; ++i) {
60 m_polytopePoints.col(i) = start_point;
61 m_polytopePoints(i - 1, i) += delta;
62 m_polytopeValues(0, i) = fun(m_polytopePoints.col(i));
63 Display(fun, m_polytopePoints.col(i));
64 }
65 }
66
67 void InitializePolytope(const ParameterType &start_point,
68 ParameterType deltas, Function &fun) // const
69 {
70 m_polytopePoints.resize(start_point.size(), start_point.size() + 1);
71 m_polytopeValues.resize(1, start_point.size() + 1);
72 m_polytopePoints.col(0) = start_point;
73 m_polytopeValues(0, 0) = fun(start_point);
74 Display(fun, m_polytopePoints.col(0));
75 for (int i = 1; i < start_point.size() + 1; ++i) {
76 m_polytopePoints.col(i) = start_point;
77 m_polytopePoints(i - 1, i) += deltas(i - 1);
78 m_polytopeValues(0, i) = fun(m_polytopePoints.col(i));
79 Display(fun, m_polytopePoints.col(i));
80 }
81 }
82
83 void Display(Function &fun, const ParameterType &parameters) // const
84 {
85 LOG_F(1, "Point: {}", parameters);
86 LOG_F(1, "Value: {}", fun(parameters));
87 }
88
89 void FindBestWorstNearWorst(Function &fun, int &best, int &worst,
90 int &near_worst) // const
91 {
92 if (m_polytopeValues(0, 0) > m_polytopeValues(0, 1)) {
93 worst = 0;
94 near_worst = 1;
95 } else {
96 worst = 1;
97 near_worst = 0;
98 }
99 best = near_worst;
100 for (int i = 2; i < m_polytopeValues.cols(); ++i) {
101 if (m_polytopeValues(0, i) < m_polytopeValues(0, best)) {
102 best = i;
103 }
104 if (m_polytopeValues(0, i) > m_polytopeValues(0, worst)) {
105 near_worst = worst;
106 worst = i;
107 } else if (m_polytopeValues(0, i) > m_polytopeValues(0, near_worst)) {
108 near_worst = i;
109 }
110 }
111 LOG_F(1, "Worst value is in position {} {}\n{}", worst,
112 m_polytopeValues(0, worst), m_polytopePoints.col(worst));
113 LOG_F(1, "Near-worst value is in position {} {}\n{}", near_worst,
114 m_polytopeValues(0, near_worst), m_polytopePoints.col(near_worst));
115 LOG_F(1, "Best value is in position {} {}\n{}", best,
116 m_polytopeValues(0, best), m_polytopePoints.col(best));
117 }
118
119 ParameterType CreateNewParameters(const ParameterType &sum,
120 const ParameterType &discarded_point,
121 DataType t) {
122 DataType fac1 = (1 - t) / sum.size();
123 DataType fac2 = fac1 - t;
124 return sum * fac1 - discarded_point * fac2;
125 }
126
127public:
128 Simplex(const Criterion &criterion)
129 : m_delta(0), use_deltas(false), m_criterion(criterion) {}
130
131 void Optimize(Function &fun) // const
132 {
133 m_state.iteration = 0;
134 m_state.currentValue = fun(m_state.currentParameters);
135 m_state.formerValue = std::numeric_limits<DataType>::max();
136 m_state.bestValue = std::numeric_limits<DataType>::max();
137
138 if (use_deltas) {
139 InitializePolytope(m_state.currentParameters, m_deltas, fun);
140 } else {
141 InitializePolytope(m_state.currentParameters, m_delta, fun);
142 }
143
144 while (m_criterion(m_state)) {
145 LOG_F(1, "Starting iteration {}", m_state.iteration);
146 int best, worst, near_worst;
147 FindBestWorstNearWorst(fun, best, worst, near_worst);
148 m_state.currentValue = m_polytopeValues(0, best);
149 m_state.currentParameters = m_polytopePoints.col(best);
150 m_state.formerValue = m_polytopeValues(0, worst);
151 m_state.formerParameters = m_polytopePoints.col(worst);
152
153 if (m_state.currentValue < m_state.bestValue) {
154 m_state.bestValue = m_state.currentValue;
155 m_state.bestParameters = m_state.currentParameters;
156 }
157
158 ParameterType new_parameters = CreateNewParameters(
159 m_polytopePoints.rowwise().sum(), m_polytopePoints.col(worst), -1);
160 DataType new_value = fun(new_parameters);
161 LOG_F(1, "Trying normal");
162 Display(fun, new_parameters);
163 if (new_value < m_state.bestValue) {
164 ParameterType expansion_parameters = CreateNewParameters(
165 m_polytopePoints.rowwise().sum(), m_polytopePoints.col(worst), -2);
166 DataType expansion_value = fun(expansion_parameters);
167 LOG_F(1, "Trying expansion");
168 Display(fun, expansion_parameters);
169 if (expansion_value < m_state.bestValue) {
170 m_state.currentValue = m_polytopeValues(0, worst) = expansion_value;
171 m_state.currentParameters = m_polytopePoints.col(worst) =
172 expansion_parameters;
173 } else {
174 m_state.currentValue = m_polytopeValues(0, worst) = new_value;
175 m_state.currentParameters = m_polytopePoints.col(worst) =
176 new_parameters;
177 }
178 } else if (new_value > m_polytopeValues(0, near_worst)) {
179 // New point is not better than near worst
180 ParameterType contraction_parameters = CreateNewParameters(
181 m_polytopePoints.rowwise().sum(), m_polytopePoints.col(worst), -.5);
182 DataType contraction_value = fun(contraction_parameters);
183 LOG_F(1, "Trying contraction");
184 Display(fun, contraction_parameters);
185 if (contraction_value > new_value) {
186 LOG_F(1, "Contraction around lowest");
187 Display(fun, contraction_parameters);
188 ParameterType best_parameters = m_polytopePoints.col(best);
189 LOG_F(1, "Difference from the best {}",
190 ((m_polytopePoints.colwise() - best_parameters.array()) / 2));
191 m_polytopePoints =
192 ((m_polytopePoints.colwise() - best_parameters.array()) / 2)
193 .colwise() +
194 best_parameters.array();
195 for (int i = 0; i < m_polytopePoints.cols(); ++i) {
196 m_polytopeValues(0, i) = fun(m_polytopePoints.col(i));
197 }
198 } else {
199 m_state.currentValue = m_polytopeValues(0, worst) = contraction_value;
200 m_state.currentParameters = m_polytopePoints.col(worst) =
201 contraction_parameters;
202 }
203 } else {
204 // New point, not the best, but better than the near worst
205 m_state.currentValue = m_polytopeValues(0, worst) = new_value;
206 m_state.currentParameters = m_polytopePoints.col(worst) =
207 new_parameters;
208 }
209
210 ++m_state.iteration;
211 }
212 }
213
217 const ParameterType &GetBestParameters() const {
218 return m_state.bestParameters;
219 }
220
224 DataType GetBestValue() const { return m_state.bestValue; }
225
226 void SetStartPoint(const ParameterType &point) {
227 m_state.currentParameters = point;
228 }
229
230 void SetDelta(DataType delta) {
231 this->m_delta = delta;
232 use_deltas = false;
233 }
234
235 void SetDelta(ParameterType deltas) {
236 this->m_deltas = deltas;
237 use_deltas = true;
238 }
239};
240
241template <class Function, class Criterion>
242static Simplex<typename Function::DataType, typename Function::ParameterType,
243 Function, Criterion>
244CreateSimplex(Function &fun, const Criterion &criterion) // const
245{
246 return Simplex<typename Function::DataType, typename Function::ParameterType,
247 Function, Criterion>(criterion);
248}
249
250} // namespace neldermead
251
252} // namespace rxdock
253
254#endif /* _RBTNMSIMPLEX_H_ */
Definition NMSimplex.h:39
const ParameterType & GetBestParameters() const
Definition NMSimplex.h:217
DataType GetBestValue() const
Definition NMSimplex.h:224
Definition NMState.h:27