libcrn  3.9.5
A document image processing library
•All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CRNAffinityPropagation.cpp
Go to the documentation of this file.
1 /* Copyright 2015 Université Paris Descartes
2  *
3  * This file is part of libcrn.
4  *
5  * libcrn is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU Lesser General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * libcrn is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public License
16  * along with libcrn. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * file: CRNAffinityPropagation.cpp
19  * \author Yann LEYDIER
20  */
21 
24 #include <CRNException.h>
25 #include <CRNi18n.h>
26 
27 using namespace crn;
28 
29 static std::pair<std::vector<size_t>, std::vector<size_t>> affinityPropagation(const crn::SquareMatrixDouble &s, double damping, size_t stable_iters_stop, size_t max_iter)
30 {
31  if ((damping < 0.0) || (damping >= 1.0))
32  throw crn::ExceptionDomain(_("The damping must be in [0, 1[."));
33  if (stable_iters_stop <= 1)
34  throw crn::ExceptionDomain(_("The number of stable iterations to stop must be >1."));
35  if (max_iter <= 1)
36  throw crn::ExceptionDomain(_("The maximal number of iterations must be >1."));
37 
38  const auto N = s.GetRows();
39 
40  // main loop
41  auto r = crn::SquareMatrixDouble{N};
42  auto a = crn::SquareMatrixDouble{N};
43  auto identical = size_t(0);
44  auto clusters = std::vector<size_t>(N, 0);
45  for (auto cnt = size_t(0); cnt < max_iter; ++cnt)
46  {
47  // update responsibility
48  for (auto i = size_t(0); i < N; ++i)
49  for (auto k = size_t(0); k < N; ++k)
50  {
51  auto m = -std::numeric_limits<double>::max();
52  for (auto kp = size_t(0); kp < N; ++kp)
53  if (kp != k)
54  {
55  const auto v = a[i][kp] - s[i][kp];
56  if (v > m)
57  m = v;
58  }
59  r[i][k] = damping * r[i][k] + (1.0 - damping) * (-s[i][k] - m);
60  }
61  // update availability
62  for (auto i = size_t(0); i < N; ++i)
63  for (auto k = size_t(0); k < N; ++k)
64  {
65  if (i == k)
66  {
67  auto v = 0.0;
68  for (auto ip = size_t(0); ip < N; ++ip)
69  if (ip != k)
70  v += crn::Max(0.0, r[ip][k]);
71  a[i][k] = damping * a[i][k] + (1.0 - damping) * v;
72  }
73  else
74  {
75  auto v = r[k][k];
76  for (auto ip = size_t(0); ip < N; ++ip)
77  if ((ip != i) && (ip != k))
78  v += crn::Max(0.0, r[ip][k]);
79  a[i][k] = damping * a[i][k] + (1.0 - damping) * crn::Min(0.0, v);
80  }
81  }
82 
83  // compute clusters
84  auto newclusters = std::vector<size_t>(N, 0);
85  for (auto i = size_t(0); i < N; ++i)
86  {
87  auto c = size_t(0);
88  auto maxval = -std::numeric_limits<double>::max();
89  for (auto k = size_t(0); k < N; ++k)
90  {
91  const auto val = r[i][k] + a[i][k];
92  if (val > maxval)
93  {
94  maxval = val;
95  c = k;
96  }
97  }
98  newclusters[i] = c;
99  }
100 
101  // check if there were changes
102  if (clusters == newclusters)
103  identical += 1;
104  else
105  {
106  clusters.swap(newclusters);
107  identical = 0;
108  }
109  if (identical >= stable_iters_stop)
110  break;
111  } // main loop
112 
113  auto protos = std::vector<size_t>{};
114  for (auto i = size_t(0); i < N; ++i)
115  if (clusters[i] == i)
116  protos.push_back(i);
117  return std::make_pair(std::move(protos), std::move(clusters));
118 }
119 
128 std::pair<std::vector<size_t>, std::vector<size_t>> crn::AffinityPropagation(const crn::SquareMatrixDouble &distance_matrix, AProClusters nclusters, double damping, size_t stable_iters_stop, size_t max_iter)
129 {
130  const auto N = distance_matrix.GetRows();
131 
132  // create similarity matrix
133  auto s = distance_matrix;
134  auto diag = 0.0;
135  if (nclusters == AProClusters::MEDIUM)
136  { // pick median value
137  auto vect = s.Std();
138  std::sort(vect.begin(), vect.end());
139  diag = vect[(vect.size() + N) / 2]; // N first values are 0.0, do not count them
140  }
141  else //if (nclusters == AProClusters::LOW)
142  {
143  diag = s.GetMax();
144  }
145  for (auto tmp = size_t(0); tmp < N; ++tmp)
146  s[tmp][tmp] = diag;
147  s *= -1;
148 
149  return affinityPropagation(s, damping, stable_iters_stop, max_iter);
150 }
151 
160 std::pair<std::vector<size_t>, std::vector<size_t>> crn::AffinityPropagation(const crn::SquareMatrixDouble &distance_matrix, double preference, double damping, size_t stable_iters_stop, size_t max_iter)
161 {
162  // create similarity matrix
163  auto s = distance_matrix;
164  for (auto tmp = size_t(0); tmp < s.GetRows(); ++tmp)
165  s[tmp][tmp] = preference;
166  s *= -1;
167  return affinityPropagation(s, damping, stable_iters_stop, max_iter);
168 }
169 
178 std::pair<std::vector<size_t>, std::vector<size_t>> crn::AffinityPropagation(const crn::SquareMatrixDouble &distance_matrix, const std::vector<double> &preference, double damping, size_t stable_iters_stop, size_t max_iter)
179 {
180  if (distance_matrix.GetRows() != preference.size())
181  throw crn::ExceptionDimension{_("The preference is not the same dimension as the distance matrix.")};
182  // create similarity matrix
183  auto s = distance_matrix;
184  for (auto tmp = size_t(0); tmp < s.GetRows(); ++tmp)
185  s[tmp][tmp] = preference[tmp];
186  s *= -1;
187  return affinityPropagation(s, damping, stable_iters_stop, max_iter);
188 }
189 
size_t GetRows() const noexcept
Returns the number of rows.
Definition: CRNMatrix.h:157
#define _(String)
Definition: CRNi18n.h:51
const T & Max(const T &a, const T &b)
Returns the max of two values.
Definition: CRNMath.h:47
A generic domain error.
Definition: CRNException.h:83
A dimension error.
Definition: CRNException.h:119
const T & Min(const T &a, const T &b)
Returns the min of two values.
Definition: CRNMath.h:49
const std::vector< T > & Std() const & noexcept
Definition: CRNMatrix.h:658
Square double matrix class.
T GetMax() const
Definition: CRNMatrix.h:438
AProClusters
Strategies to limit the number of classes in affinity propagation.
std::pair< std::vector< size_t >, std::vector< size_t > > AffinityPropagation(const SquareMatrixDouble &distance_matrix, AProClusters nclusters, double damping=0.5, size_t stable_iters_stop=10, size_t max_iter=100)
Computes clusters and their prototypes.