-
Notifications
You must be signed in to change notification settings - Fork 3
/
Solution.h
140 lines (94 loc) · 4.23 KB
/
Solution.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#ifndef Solution_H
#define Solution_H
/* Authors: Daniel Gribel and Thibaut Vidal
* Contact: dgribel@inf.puc-rio.br
*
* The Solution class represents a clustering solution. It can be expressed in two representations:
* (i) assignment: specifies, for each sample, the index of the cluster it is associated with;
* (ii) centroids: the coordinates of the center of each cluster.
* A Solution object also stores the cost of the solution (in terms of the MSSC objective),
* the mutation parameter alpha, and a copy of the problem data (PbData is held by value).
*/
#include <iostream>
#include <vector>
#include <math.h>
#include "hamerly/dataset.h"
#include "hamerly/hamerly_kmeans.h"
#include "MathUtils.h"
#include "PbData.h"
// Mutation rate constant (0.2). Presumably consumed by the mutation routines
// implemented in Solution.cpp — TODO confirm its exact role there.
#define MUTATION_RATE 0.2
// NOTE(review): using-directives at header scope leak into every file that includes
// this header. They are kept because dependent code may rely on them being present.
using namespace std;
using namespace MathUtils;
// A clustering solution, stored both as a per-sample cluster assignment and as a set of
// centroids, together with its MSSC cost, its mutation weight alpha, a copy of the
// problem data, and cached external clustering measures.
class Solution {
private:
    // Assignment representation: for each sample, the index of its cluster.
    // NOTE(review): this raw pointer is released by ~Solution (per that destructor's
    // comment), yet the class declares no copy constructor or copy assignment. The
    // compiler-generated versions copy the pointer, not the array, so two copies of a
    // Solution would share the same buffer and both free it. Avoid copying Solution
    // objects, or add deep-copy members in Solution.cpp (the array length presumably
    // comes from pb_data — TODO confirm).
    unsigned short* assignment;
    // Centroids representation: coordinates of the center of each cluster.
    std::vector< std::vector<double> > centroids;
    // Solution cost, in terms of the MSSC objective.
    double cost;
    // Alpha weight (mutation parameter).
    double alpha;
    // Problem data, held by value.
    PbData pb_data;
    // External clustering measures; presumably set by ComputeCRand / ComputeNmi /
    // ComputeCentroidIndex and read back through the matching getters.
    double crand;
    double nmi;
    double centroid_index;

    // Generate the centroids representation (the assignment representation must already exist).
    void AssignmentToCentroids();
    // Generate the assignment representation (the centroids representation must already exist).
    void CentroidsToAssignment();
    // Initialize the assignment data structure.
    void InitAssignment();
    // Initialize the centroids data structure.
    void InitCentroids();
    // Count the pair coefficients a, b, c, d used by the Rand and C-Rand indicators.
    void CountRandCoefficients(unsigned short* ground_truth, int& a, int& b, int& c, int& d);
    // Remove a random center and re-assign data points to the closest remaining center.
    // A partial assignment with m-1 centers is produced.
    void RemoveCenter(int barycenter);
    // Reinsert removed center c at the position of data point p; re-assign data points
    // to their closest center. A complete assignment with m centers is produced.
    void ReinsertCenter(int c, int p, std::vector<double> & dist_centroid);

public:
    // Constructor for which the solution assignment and cost are known.
    Solution(unsigned short* assignment, double cost, double alpha, PbData pb_data);
    // Constructor for which the solution centroids and cost are known.
    Solution(std::vector< std::vector<double> > & centroids, double cost, double alpha, PbData pb_data);
    // Constructor for which the solution assignment is known.
    Solution(unsigned short* assignment, double alpha, PbData pb_data);
    // Constructor for which the solution centroids are known.
    Solution(std::vector< std::vector<double> > & centroids, double alpha, PbData pb_data);
    // Destructor: frees the assignment memory.
    ~Solution();

    // Get the assignment representation. The returned pointer is still owned by this Solution.
    unsigned short* GetAssignment() const { return assignment; }
    // Get all centroids. Returned by const reference to avoid copying the whole matrix;
    // the reference is valid only while this Solution is alive and unmodified.
    const std::vector< std::vector<double> >& GetCentroids() const { return centroids; }
    // Get the centroid of cluster i (no bounds check). Same lifetime rule as above.
    const std::vector<double>& GetCentroids(int i) const { return centroids[i]; }
    // Get the solution cost (in terms of the MSSC objective).
    double GetCost() const { return cost; }
    // Get the alpha parameter.
    double GetAlpha() const { return alpha; }

    // Mutate the solution.
    void Mutate();
    // Apply local search to the solution (K-means).
    void DoLocalSearch(Dataset const *x);
    // Mutate the alpha weight.
    void MutateAlpha();
    // Repair the solution if the assignment is degenerate (fewer than m clusters).
    void Repair();

    /* External clustering measures */
    // Calculate the C-Rand indicator.
    void ComputeCRand();
    // Calculate the normalized mutual information indicator.
    void ComputeNmi();
    // Calculate the Centroid Index indicator.
    void ComputeCentroidIndex();
    // Calculate all external clustering measures.
    void ComputeExternalMetrics();
    double GetCRand() const { return crand; }
    double GetNmi() const { return nmi; }
    double GetCentroidIndex() const { return centroid_index; }
};
#endif