libcrn  3.9.5
A document image processing library
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CRNBasicClassify.h
Go to the documentation of this file.
1 /* Copyright 2008-2016 INSA Lyon, CoReNum, ENS-Lyon
2  *
3  * This file is part of libcrn.
4  *
5  * libcrn is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU Lesser General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * libcrn is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU Lesser General Public License for more details.
14  *
15  * You should have received a copy of the GNU Lesser General Public License
16  * along with libcrn. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * file: CRNBasicClassify.h
19  * \author Yann LEYDIER
20  */
21 
22 #ifndef CRNBASICCLASSIFY_HEADER
23 #define CRNBASICCLASSIFY_HEADER
24 
25 #include <CRNObject.h>
26 #include <CRNData/CRNMap.h>
27 #include <CRNData/CRNVector.h>
28 #include <CRNException.h>
29 #include <CRNAI/CRNClassifResult.h>
30 #include <set>
31 #include <map>
32 #include <algorithm>
33 #include <limits>
34 
35 namespace crn
36 {
37  /****************************************************************************/
48  {
49  public:
62  template<
63  typename ConstIterator,
64  typename std::enable_if<IsMetric<typename std::iterator_traits<ConstIterator>::value_type>::value, int>::type = 0
65  >
66  static ClassifResult NearestNeighbor(const typename std::iterator_traits<ConstIterator>::value_type &obj, ConstIterator begin, ConstIterator end)
67  {
68  auto nearest = 0;
69  auto mindist = std::numeric_limits<double>::max();
70  auto classid = 0;
71  ConstIterator prot;
72  for (auto it = begin; it != end; ++it)
73  {
74  const auto d = Distance(obj, *it);
75  if (d < mindist)
76  {
77  mindist = d;
78  nearest = classid;
79  prot = it;
80  }
81  classid += 1;
82  }
83 
84  return ClassifResult(nearest, mindist, *prot);
85  }
86 
99  template<
100  typename T,
101  typename std::enable_if<IsMetric<T>::value, int>::type = 0
102  >
103  static ClassifResult kNearestNeighbors(const T &obj, const Map &database, int k)
104  {
105  std::set<ClassifResult> knn;
106  std::set<ClassifResult>::iterator maxneighbor;
107  auto classid = 0;
108  for (Map::const_iterator dataclass = database.begin();
109  dataclass != database.end(); ++dataclass)
110  {
111  String label = dataclass->first;
112  SVector samples = std::dynamic_pointer_cast<Vector>(dataclass->second);
113  if (!samples)
114  throw ExceptionInvalidArgument("ClassifResult BasicClassify::"
115  "kNearestNeighbors(const Object &obj, const Map &database, "
116  "int k): invalid database.");
117  for (Vector::const_iterator sample = samples->begin();
118  sample != samples->end(); ++sample)
119  {
120  auto d = Distance(obj, dynamic_cast<const T&>(**sample)); // may throw
121  if (knn.size() < (unsigned int)k)
122  { // knn list not already full
123  knn.insert(ClassifResult(classid, label, d, *sample));
124  }
125  else
126  { // knn list full
127  if (d < maxneighbor->distance)
128  { // add to list and remove the (k+1)th neighbor
129  knn.insert(ClassifResult(classid, label, d, *sample));
130  knn.erase(std::max_element(knn.begin(), knn.end()));
131  }
132  }
133  maxneighbor = std::max_element(knn.begin(), knn.end());
134  }
135  classid += 1;
136  }
137  return chooseClass(knn);
138  }
139 
152  template<
153  typename T,
154  typename std::enable_if<IsMetric<T>::value, int>::type = 0
155  >
156  static ClassifResult EpsilonNeighbors(const T &obj, const Map &database, double epsilon)
157  {
158  std::set<ClassifResult> en;
159  int classid = 0;
160  for (Map::const_iterator dataclass = database.begin();
161  dataclass != database.end(); ++dataclass)
162  {
163  String label = dataclass->first;
164  SVector samples = std::dynamic_pointer_cast<Vector>(dataclass->second);
165  if (!samples)
166  throw ExceptionInvalidArgument("ClassifResult BasicClassify::"
167  "EpsilonNeighbors(const Object &obj, const Map &database, "
168  "double epsilon): invalid database.");
169  for (Vector::const_iterator sample = samples->begin();
170  sample != samples->end(); ++sample)
171  {
172  double d = Distance(obj, dynamic_cast<const T&>(**sample)); // may throw
173  if (d < epsilon)
174  { // add to list
175  en.insert(ClassifResult(classid, label, d, *sample));
176  }
177  }
178  classid += 1;
179  }
180  return chooseClass(en);
181  }
182 
183  private:
/*!
 * \brief Orders two (key, value) pairs by their value only
 *
 * \param[in]	p1	first pair
 * \param[in]	p2	second pair
 * \return	true if the second member of p1 is strictly lower than the second member of p2
 */
static bool pairintintvalcmp(const std::pair<int, int> &p1, const std::pair<int, int> &p2)
{
	const int lhs = p1.second;
	const int rhs = p2.second;
	return lhs < rhs;
}
/*!
 * \brief Predicate telling whether the value of a (key, value) pair equals a reference value
 *
 * The class no longer derives from std::unary_function, which was deprecated
 * in C++11 and removed in C++17; the member types that base class used to
 * provide are declared explicitly instead.
 */
class pairintintvaleq
{
	public:
		typedef const std::pair<int, int>& argument_type;
		typedef bool result_type;

		/*!
		 * \brief Constructor
		 * \param[in]	v	the reference value to compare to
		 */
		explicit pairintintvaleq(int v):val(v) {}
		/*!
		 * \brief Compares the value of a pair with the reference value
		 * \param[in]	p	the pair to check
		 * \return	true if the second member of p equals the reference value
		 */
		bool operator()(const std::pair<int, int> &p) const
		{
			return p.second == val;
		}
	private:
		int val; /*!< the reference value */
};
226  static ClassifResult chooseClass(std::set<ClassifResult> &nn)
227  {
228  // count population for each class
229  std::map<int, int> pop;
230  for (std::set<ClassifResult>::iterator it = nn.begin();
231  it != nn.end(); ++it)
232  {
233  pop[it->class_id] += 1;
234  }
235  // retrieve maximal population
236  int maxpop = std::max_element(pop.begin(), pop.end(),
237  pairintintvalcmp)->second;
238  // look for classes that have maximal population
239  // (there can be several such classes)
240  std::map<int, int>::iterator popit = pop.begin();
241  pairintintvaleq sel(maxpop);
242  std::set<ClassifResult> sameclass;
243  while ((popit = std::find_if(popit, pop.end(), sel)) != pop.end())
244  {
245  // look for neighbors whose class has maximal population
246  ClassifResult::SelectId csel(popit->first);
247  std::set<ClassifResult>::iterator cit = nn.begin();
248  while ((cit = std::find_if(cit, nn.end(), csel)) != nn.end())
249  {
250  sameclass.insert(*cit);
251  ++cit;
252  }
253  ++popit;
254  }
255  return *std::min_element(sameclass.begin(), sameclass.end());
256  }
258  };
259 }
260 #endif
static ClassifResult EpsilonNeighbors(const T &obj, const Map &database, double epsilon)
Classify a sample using the neighbors closer than a distance epsilon.
static ClassifResult kNearestNeighbors(const T &obj, const Map &database, int k)
Classify a sample using the k nearest neighbors.
Basic classification tools.
A UTF32 character string class.
Definition: CRNString.h:61
iterator begin()
Returns an iterator to the first element.
Definition: CRNMap.h:86
std::map< String, SObject >::const_iterator const_iterator
const_iterator on the contents of the container
Definition: CRNMap.h:93
A generic classification result.
double Distance(const Int &i1, const Int &i2) noexcept
Definition: CRNInt.h:78
Data vector class.
Definition: CRNVector.h:42
static ClassifResult NearestNeighbor(const typename std::iterator_traits< ConstIterator >::value_type &obj, ConstIterator begin, ConstIterator end)
Finds the nearest neighbor in a set of objects.
Data map class.
Definition: CRNMap.h:42
iterator end()
Returns an iterator after the last element.
Definition: CRNMap.h:88
Invalid argument error (e.g., a null pointer)
Definition: CRNException.h:107