31 using namespace crn::literals;
38 row_covered(dm.size(),
false),
39 col_covered(dm.size(),
false),
40 path(dm.size() * 2, std::vector<size_t>(dm.size() * 2, 0)),
41 marked(dm.size(), std::vector<uint8_t>(dm.size(), 0))
46 row_covered(dm.GetRows(),
false),
47 col_covered(dm.GetRows(),
false),
48 path(dm.GetRows() * 2, std::vector<size_t>(dm.GetRows() * 2, 0)),
49 marked(dm.GetRows(), std::vector<uint8_t>(dm.GetRows(), 0))
54 size_t z0_r = 0, z0_c = 0;
55 std::vector<std::vector<size_t>>
path;
56 std::vector<std::vector<uint8_t>>
marked;
66 static inline std::pair<int, int> find_a_zero(
const kuhn_munkres &km)
69 for (
int i = 0; i < int(km.
n); ++i)
70 for (
int j = 0; j < int(km.
n); ++j)
73 return std::make_pair(i, j);
75 return std::make_pair(-1, -1);
78 static inline int find_star_in_row(
const kuhn_munkres &km,
size_t row)
81 for (
int j = 0; j < int(km.
n); ++j)
82 if (km.
marked[row][j] == 1)
87 static inline int find_star_in_col(
const kuhn_munkres &km,
size_t col)
90 for (
int i = 0; i < int(km.
n); ++i)
91 if (km.
marked[i][col] == 1)
96 static inline int find_prime_in_row(
const kuhn_munkres &km,
size_t row)
99 for (
int j = 0; j < int(km.
n); ++j)
100 if (km.
marked[row][j] == 2)
105 static inline void convert_path(
kuhn_munkres &km,
size_t count)
107 for (
size_t i = 0; i <= count; ++i)
117 for (
int i = 0; i < int(km.
n); ++i)
118 for (
int j = 0; j < int(km.
n); ++j)
123 static inline double find_smallest(
const kuhn_munkres &km)
126 auto minval = std::numeric_limits<double>::max();
127 for (
size_t i = 0; i < km.
n; ++i)
128 for (
size_t j = 0; j < km.
n; ++j)
130 if (minval > km.
c[i][j])
138 for (
size_t i = 0; i < km.
n; ++i)
141 auto minval = *std::min_element(km.
c[i], km.
c[i] + km.
c.
GetCols());
142 for (
auto v =
size_t(0); v < km.
c.
GetCols(); ++v)
143 km.
c[i][v] -= minval;
151 for (
size_t i = 0; i < km.
n; ++i)
152 for (
size_t j = 0; j < km.
n; ++j)
167 auto count = size_t(0);
168 for (
size_t i = 0; i < km.
n; ++i)
169 for (
size_t j = 0; j < km.
n; ++j)
186 auto rc = find_a_zero(km);
189 km.
marked[rc.first][rc.second] = 2;
190 auto star_col = find_star_in_row(km, rc.first);
193 rc.second = star_col;
209 auto count = size_t(0);
214 auto row = find_star_in_col(km, km.
path[count][1]);
218 km.
path[count][0] = row;
219 km.
path[count][1] = km.
path[count-1][1];
224 auto col = find_prime_in_row(km, km.
path[count][0]);
226 km.
path[count][0] = km.
path[count-1][0];
227 km.
path[count][1] = col;
229 convert_path(km, count);
238 auto minval = find_smallest(km);
239 for (
size_t i = 0; i < km.
n; ++i)
240 for (
size_t j = 0; j < km.
n; ++j)
243 km.
c[i][j] += minval;
245 km.
c[i][j] -= minval;
250 template<
typename T> std::tuple<double, std::vector<std::pair<size_t, size_t>>>
hung(
const T &distmat)
278 auto pairs = std::vector<std::pair<size_t, size_t>>{};
281 for (
size_t i = 0; i < km.
n; ++i)
282 for (
size_t j = 0; j < km.
n; ++j)
285 pairs.emplace_back(i, j);
286 cost += distmat[i][j];
289 return std::make_tuple(cost, std::move(pairs));
300 std::tuple<double, std::vector<std::pair<size_t, size_t>>>
crn::Hungarian(
const std::vector<std::vector<double>> &distmat)
304 for (
const auto &row : distmat)
305 if (row.size() != distmat.size())
307 return hung(distmat);
320 return hung(distmat);
std::vector< std::vector< size_t > > path
std::vector< bool > row_covered
size_t GetCols() const noexcept
Returns the number of columns.
kuhn_munkres(const std::vector< std::vector< double >> &dm)
crn::SquareMatrixDouble c
std::vector< bool > col_covered
std::tuple< double, std::vector< std::pair< size_t, size_t > > > hung(const T &distmat)
std::vector< std::vector< uint8_t > > marked
void erase_primes(kuhn_munkres &km)
std::tuple< double, std::vector< std::pair< size_t, size_t > > > Hungarian(const std::vector< std::vector< double >> &distmat)
Square double matrix class.
kuhn_munkres(const crn::SquareMatrixDouble &dm)
Invalid argument error (e.g.: nullptr pointer)