package anotacja;

import java.util.Random;

import Jama.Matrix;

//import zbiory.Dataset;
//import zbiory.DatasetGenerator;

/**
 * Two-class linear classifier trained by minimising an entropy criterion.
 *
 * <p>The model is a single hyperplane with a {@code tanh} output:
 * {@code f(x) = tanh(mnoznik * (w . x - bias))}. The output is read as a soft
 * assignment of the sample to the "positive side" with weight {@code (f + 1) / 2}
 * and to the "negative side" with weight {@code (1 - f) / 2}. Training looks for
 * weights that minimise the side-weighted entropy of the class distribution on the
 * two sides (see {@link #ocenRozwiazanieEntropia()}).
 *
 * <p>Training runs in up to three stages:
 * <ol>
 *   <li>random search over weights drawn uniformly from [-1, 1)
 *       ({@link #losoweCykle} draws);</li>
 *   <li>optionally ({@link #uczenieGradientowe}) a damped Newton refinement,
 *       limited to {@link #gradientoweCykle} iterations;</li>
 *   <li>assignment of a class decision to each side of the hyperplane.</li>
 * </ol>
 *
 * <p>The configuration fields are public mutable statics and are shared by all
 * instances.
 */
public class EntropiaGrad {

    /** Number of input features (length of one training vector). */
    private int dim;
    /** Training vectors, one row per sample. The array is referenced, not copied. */
    private double[][] dane;
    /** Class of every sample encoded as +1.0 (positive) or -1.0 (negative). */
    private double[] klasy;
    /** Class of every sample encoded as an index: 1 (positive) or 0 (negative). */
    private int[] idx;
    /** Number of training samples per class index; {@code null} until the first training. */
    private int[] klasyLicz;
    /** Decision returned for a sample on the positive side ([0]) and on the negative side ([1]). */
    private boolean[] decyzje;
    /** Current weights; the last element is the bias (it is subtracted in {@link #liniowy}). */
    private double[] wagi;
    /** Best weights found so far. */
    private double[] bWagi;
    /** Criterion value of {@link #bWagi}. */
    private double bOc;

    /** Number of random weight draws in the random-search stage. */
    public static int losoweCykle = 1000;
    /** Upper limit on the number of iterations of the gradient/Newton stage. */
    public static int gradientoweCykle = 1000;
    /** When {@code true}, training progress is printed to standard output. */
    public static boolean glosno = false;
    /** Gain applied to the linear combination before {@code tanh}. */
    private final double mnoznik = 20;
    /** When {@code true}, the random search is followed by the gradient/Newton stage. */
    public static boolean uczenieGradientowe = true;

    /** Source of the random weights. */
    private Random r;

    /**
     * Creates an untrained classifier.
     *
     * @param seed seed of the random generator used by the random-search stage
     */
    public EntropiaGrad(int seed) {
        this.r = new Random(seed);
    }

    /**
     * Trains the classifier on integer labels.
     *
     * @param dane training vectors, one row per sample; must be non-empty
     * @param kl   labels; the value {@code 1} denotes the positive class, every
     *             other value the negative class
     * @throws IllegalArgumentException if {@code dane} is empty or the two arrays
     *                                  differ in length
     */
    public void ucz(double[][] dane, int[] kl) {
        boolean[] pozytywne = new boolean[kl.length];
        for (int i = 0; i < kl.length; i++) {
            pozytywne[i] = kl[i] == 1;
        }
        this.przygotujIUcz(dane, pozytywne);
    }

    /**
     * Trains the classifier on boolean labels.
     *
     * @param dane training vectors, one row per sample; must be non-empty
     * @param kl   labels; {@code true} denotes the positive class
     * @throws IllegalArgumentException if {@code dane} is empty or the two arrays
     *                                  differ in length
     */
    public void ucz(double[][] dane, boolean[] kl) {
        this.przygotujIUcz(dane, kl);
    }

    /** Validates the training set, stores it in both label encodings and runs the training. */
    private void przygotujIUcz(double[][] dane, boolean[] kl) {
        // Validate before touching any field so a rejected call leaves the
        // previously trained state intact.
        if (dane.length == 0) {
            throw new IllegalArgumentException("Training set is empty");
        }
        if (dane.length != kl.length) {
            throw new IllegalArgumentException("Number of vectors (" + dane.length
                    + ") differs from number of labels (" + kl.length + ")");
        }
        this.dane = dane;
        this.klasy = new double[kl.length];
        this.idx = new int[kl.length];
        for (int i = 0; i < kl.length; i++) {
            this.klasy[i] = kl[i] ? 1.0 : -1.0;
            this.idx[i] = kl[i] ? 1 : 0;
        }
        this.dim = this.dane[0].length;
        this.ucz();
    }

    /**
     * Runs the training stages on the stored training set. When only one class is
     * present nothing is fitted; {@link #testuj(double[])} then always answers with
     * that class.
     */
    private void ucz() {
        this.klasyLicz = new int[2];
        for (int i = 0; i < this.klasy.length; i++) {
            this.klasyLicz[this.idx[i]]++;
        }
        if ((this.klasyLicz[0] == 0) || (this.klasyLicz[1] == 0)) {
            return;
        }
        //System.out.println(this.klasyLicz[0] + " " + this.klasyLicz[1]);
        this.uczLosowo();
        if (uczenieGradientowe) {
            this.uczGradient();
        }
        this.przypisanie();
        // correction of the inequality sides, delta rule (disabled)
        /*double MSEp = this.ocenRozwiazanieWDelta();
        this.odwrocWagi();
        double MSEn = this.ocenRozwiazanieWDelta();
        if (MSEn > MSEp) {
            this.odwrocWagi();
        }
        this.ocenRozwiazanieEntropia();*/
    }

    /** Draws every weight (bias included) uniformly from [-1, 1). */
    private void losujWagi() {
        for (int i = 0; i < this.wagi.length; i++) {
            this.wagi[i] = r.nextDouble() * 2 - 1;
        }
    }

    /** Stores the current weights as the best ones. */
    private void zapiszWagi() {
        System.arraycopy(this.wagi, 0, this.bWagi, 0, this.wagi.length);
    }

    /** Replaces the current weights with the best ones stored so far. */
    private void przywrocWagi() {
        System.arraycopy(this.bWagi, 0, this.wagi, 0, this.wagi.length);
    }

    /** Negates every weight. Used only by the disabled delta-rule correction in {@link #ucz()}. */
    private void odwrocWagi() {
        for (int i = 0; i < this.wagi.length; i++) {
            this.wagi[i] = -this.wagi[i];
        }
    }

    /**
     * Random-search stage: evaluates {@link #losoweCykle} random weight vectors and
     * keeps the one with the lowest criterion value. If no draw yields a comparable
     * (non-NaN) value the all-zero weights remain the best ones.
     */
    private void uczLosowo() {
        this.wagi = new double[dim + 1];
        this.bWagi = new double[dim + 1];
        this.zapiszWagi();
        this.bOc = Double.POSITIVE_INFINITY;
        for (int i = 0; i < losoweCykle; i++) {
            this.losujWagi();
            double oc = this.ocenRozwiazanieEntropia();
            if (oc < this.bOc) {
                this.zapiszWagi();
                this.bOc = oc;
            }
            if (glosno) {
                System.out.println("L" + i + " " + this.bOc + " " + oc);
            }
        }
        this.przywrocWagi();
    }

    /** Returns input {@code w} of sample {@code d}; the last input is the constant -1 paired with the bias. */
    private double wejscie(int d, int w) {
        return (w == this.wagi.length - 1) ? -1 : this.dane[d][w];
    }

    /**
     * Gradient stage. Every iteration starts from the best weights found so far,
     * builds the gradient (and, in Newton mode, the Hessian) of the criterion from
     * the soft side counts, takes one step and keeps the result only if the
     * criterion improved.
     *
     * <p>In Newton mode the constant {@code k} is added to the Hessian diagonal; it
     * is divided by 10 after an improvement and multiplied by 10 after a failure.
     * In plain-gradient mode the step length {@code beta} is reset after an
     * improvement and halved after a failure. The loop ends after 20 consecutive
     * failures or {@link #gradientoweCykle} iterations.
     */
    private void uczGradient() {
        boolean newton = true;
        double initBeta = 10;
        double beta = initBeta;
        double k = 0.001;
        int iter = 0;
        int bezZmian = 0;
        while ((bezZmian < 20) && (iter < gradientoweCykle)) {
            this.przywrocWagi();
            int lw = this.wagi.length;

            // Sample counts: positive class, negative class, all.
            double p = 0;
            double n = 0;
            double I = this.dane.length;
            // Sums of activations over the same three groups.
            double fp = 0;
            double fn = 0;
            double fi = 0;
            // Sums of activation'(l) * input over the same three groups.
            double[] gxp = new double[lw];
            double[] gxn = new double[lw];
            double[] gxi = new double[lw];
            // Sums of activation''(l) * input * input over the same three groups.
            double[][] hxp = new double[lw][lw];
            double[][] hxn = new double[lw][lw];
            double[][] hxi = new double[lw][lw];
            double[] grad = new double[lw];
            double[] grd2 = new double[lw];
            double[][] hess = new double[lw][lw];

            for (int d = 0; d < this.dane.length; d++) {
                double l = this.liniowy(this.dane[d], this.wagi);
                double f = this.aktywacja(l);
                double g = this.aktywacjaPrim(l);
                double h = this.aktywacjaBis(l);
                boolean pozytywna = this.klasy[d] > 0;
                if (pozytywna) {
                    p = p + 1;
                } else {
                    n = n + 1;
                }
                fi = fi + f;
                if (pozytywna) {
                    fp = fp + f;
                } else {
                    fn = fn + f;
                }
                // NOTE(review): liniowy() multiplies by mnoznik, but that chain-rule
                // factor is not applied to g and h below - confirm this is intended.
                for (int w = 0; w < lw; w++) {
                    double xw = this.wejscie(d, w);
                    gxi[w] = gxi[w] + g * xw;
                    if (pozytywna) {
                        gxp[w] = gxp[w] + g * xw;
                    } else {
                        gxn[w] = gxn[w] + g * xw;
                    }
                }
                if (newton) {
                    for (int w = 0; w < lw; w++) {
                        double xw = this.wejscie(d, w);
                        for (int v = 0; v < lw; v++) {
                            double xv = this.wejscie(d, v);
                            hxi[w][v] = hxi[w][v] + h * xw * xv;
                            if (pozytywna) {
                                hxp[w][v] = hxp[w][v] + h * xw * xv;
                            } else {
                                hxn[w][v] = hxn[w][v] + h * xw * xv;
                            }
                        }
                    }
                }
            }

            // Soft counts on the positive side, sum of (f + 1) / 2, for the positive
            // class, the negative class and all samples, each with its logarithm.
            // A non-positive count uses -1 as the logarithm so that (log + 1) vanishes.
            double fEpp = 0.5 * fp + 0.5 * p;
            double fEppL = fEpp > 0 ? Math.log(fEpp) : -1;
            double fEpn = 0.5 * fn + 0.5 * n;
            double fEpnL = fEpn > 0 ? Math.log(fEpn) : -1;
            double fEpk = 0.5 * fi + 0.5 * I;
            double fEpkL = fEpk > 0 ? Math.log(fEpk) : -1;
            // The same for the negative side, sum of (1 - f) / 2.
            double fEnp = 0.5 * p - 0.5 * fp;
            double fEnpL = fEnp > 0 ? Math.log(fEnp) : -1;
            double fEnn = 0.5 * n - 0.5 * fn;
            double fEnnL = fEnn > 0 ? Math.log(fEnn) : -1;
            double fEnk = 0.5 * I - 0.5 * fi;
            double fEnkL = fEnk > 0 ? Math.log(fEnk) : -1;

            for (int w = 0; w < lw; w++) {
                double gEpp = 0.5 * gxp[w] * (fEppL + 1);
                double gEpn = 0.5 * gxn[w] * (fEpnL + 1);
                double gEpk = 0.5 * gxi[w] * (fEpkL + 1);
                double gEp = -(gEpp + gEpn - gEpk) / I;
                double gEnp = -0.5 * gxp[w] * (fEnpL + 1);
                double gEnn = -0.5 * gxn[w] * (fEnnL + 1);
                double gEnk = -0.5 * gxi[w] * (fEnkL + 1);
                double gEn = -(gEnp + gEnn - gEnk) / I;
                grad[w] = gEp + gEn;
            }

            if (newton) {
                // Hessian matrix. A soft count equal to zero makes the corresponding
                // quotient non-finite; the resulting step then evaluates to NaN and
                // is rejected by the comparison below.
                for (int w = 0; w < lw; w++) {
                    for (int v = 0; v < lw; v++) {
                        double hEpp = 0.5 * hxp[w][v] * (1 + fEppL) + 0.25 * gxp[w] * gxp[v] / fEpp;
                        double hEpn = 0.5 * hxn[w][v] * (1 + fEpnL) + 0.25 * gxn[w] * gxn[v] / fEpn;
                        double hEpk = 0.5 * hxi[w][v] * (1 + fEpkL) + 0.25 * gxi[w] * gxi[v] / fEpk;
                        double hEp = -(hEpp + hEpn - hEpk) / I;
                        double hEnp = -0.5 * hxp[w][v] * (1 + fEnpL) + 0.25 * gxp[w] * gxp[v] / fEnp;
                        double hEnn = -0.5 * hxn[w][v] * (1 + fEnnL) + 0.25 * gxn[w] * gxn[v] / fEnn;
                        double hEnk = -0.5 * hxi[w][v] * (1 + fEnkL) + 0.25 * gxi[w] * gxi[v] / fEnk;
                        double hEn = -(hEnp + hEnn - hEnk) / I;
                        hess[w][v] = hEp + hEn;
                    }
                }
                // Levenberg-Marquardt correction (additive damping of the diagonal)
                for (int w = 0; w < lw; w++) {
                    //hess[w][w] = (1 + k) * hess[w][w];
                    hess[w][w] = hess[w][w] + k;
                }
                Matrix hM = new Matrix(hess);
                Matrix ihM = Matrix.identity(lw, lw);
                try {
                    ihM = hM.inverse();
                } catch (RuntimeException ignored) {
                    // Deliberate best effort: when the Hessian cannot be inverted the
                    // identity is kept, so this iteration takes a plain gradient step.
                    if (glosno) {
                        System.out.println("Hessian not invertible, using identity: "
                                + ignored.getMessage());
                    }
                }
                double[][] inverseHess = ihM.getArray();
                // product of the inverse Hessian and the gradient vector
                for (int w = 0; w < lw; w++) {
                    double delta = 0;
                    for (int v = 0; v < lw; v++) {
                        delta = delta + inverseHess[w][v] * grad[v];
                    }
                    grd2[w] = delta;
                }
            }

            if (newton) {
                // step of the Levenberg-Marquardt method, without the beta parameter
                for (int w = 0; w < lw; w++) {
                    double delta = grd2[w];
                    this.wagi[w] = this.wagi[w] - delta;
                }
            } else {
                // step of the plain gradient method
                for (int w = 0; w < lw; w++) {
                    double delta = grad[w] * beta;
                    this.wagi[w] = this.wagi[w] - delta;
                }
            }

            boolean improve = false;
            double oc = this.ocenRozwiazanieEntropia();
            if (oc < this.bOc) {
                this.zapiszWagi();
                this.bOc = oc;
                beta = initBeta;
                improve = true;
                bezZmian = 0;
                k = k * 0.1;
            } else {
                beta = 0.5 * beta;
                bezZmian++;
                k = k * 10;
            }

            boolean print = false;
            if (print && glosno) {
                System.out.print("HESS:");
                for (int w = 0; w < lw; w++) {
                    System.out.print((float) grd2[w] + " ");
                }
                System.out.println();
                System.out.print("GRAD:");
                for (int w = 0; w < lw; w++) {
                    System.out.print((float) grad[w] + " ");
                }
                System.out.println();
            }
            if (glosno) {
                if (newton) {
                    System.out.println("H" + iter + " " + this.bOc + " " + oc + " " + improve + " " + (float) k);
                } else {
                    System.out.println("G" + iter + " " + this.bOc + " " + oc + " " + improve + " " + (float) beta);
                }
            }
            iter++;
        }
        this.przywrocWagi();
    }

    /**
     * Chooses the decision for each side of the hyperplane. For every class the soft
     * counts on both sides are divided by (class size + 1); a side answers
     * {@code true} when the positive class has at least as large a share there as
     * the negative class.
     */
    private void przypisanie() {
        double[] zliczenieP = new double[2];
        double[] zliczenieN = new double[2];
        for (int d = 0; d < this.dane.length; d++) {
            double l = this.liniowy(this.dane[d], this.wagi);
            double f = this.aktywacja(l);
            double p = (f + 1) * 0.5;
            double n = (1 - f) * 0.5;
            int klasa = this.idx[d];
            zliczenieP[klasa] = zliczenieP[klasa] + p;
            zliczenieN[klasa] = zliczenieN[klasa] + n;
        }
        zliczenieP[0] = zliczenieP[0] / (this.klasyLicz[0] + 1);
        zliczenieN[0] = zliczenieN[0] / (this.klasyLicz[0] + 1);
        zliczenieP[1] = zliczenieP[1] / (this.klasyLicz[1] + 1);
        zliczenieN[1] = zliczenieN[1] / (this.klasyLicz[1] + 1);
        //System.out.println((float)zliczenieP[0] + " " + (float)zliczenieP[1] + " " + (float)zliczenieN[0] + " " + (float)zliczenieN[1]);
        this.decyzje = new boolean[2];
        this.decyzje[0] = zliczenieP[1] >= zliczenieP[0];
        this.decyzje[1] = zliczenieN[1] >= zliczenieN[0];
    }

    /**
     * Computes the squared error between the +/-1 labels and the network output,
     * with every sample weighted by the inverse of its class size.
     *
     * @return the sum of the per-class mean squared errors
     */
    private double ocenRozwiazanieWDelta() {
        double MSE = 0;
        int[] cnt = new int[2];
        for (int i = 0; i < this.dane.length; i++) {
            cnt[idx[i]] = cnt[idx[i]] + 1;
        }
        for (int i = 0; i < this.dane.length; i++) {
            double p = this.testujSmooth(this.dane[i]);
            double mse = (this.klasy[i] - p) * (this.klasy[i] - p);
            MSE = MSE + mse / (double) cnt[idx[i]];
        }
        return MSE;
    }

    /** Returns {@code x * ln(x)} with the convention {@code 0 * ln(0) = 0}; NaN is propagated. */
    private static double xLogX(double x) {
        // Compare with == rather than <= so that a NaN input stays NaN and a broken
        // weight vector can never look like a perfect (zero-entropy) solution.
        return (x == 0.0) ? 0.0 : x * Math.log(x);
    }

    /** Returns {@code czesc / calosc}, or 0 when the whole is exactly zero (an empty side). */
    private static double udzial(double czesc, double calosc) {
        return (calosc == 0.0) ? 0.0 : czesc / calosc;
    }

    /**
     * Evaluates the current weights with the entropy criterion (lower is better).
     *
     * <p>Every sample contributes {@code (f + 1) / 2} to the positive side and
     * {@code (1 - f) / 2} to the negative side. The result is the entropy of the
     * class distribution on each side, weighted by the share of the total soft
     * count that falls on that side. The value is computed in two algebraically
     * equivalent ways and the two results are cross-checked.
     *
     * <p>Terms of the form {@code 0 * ln(0)} are taken as 0. This matters because
     * {@code tanh} saturates to exactly +/-1 for well separated samples, which
     * would otherwise turn the best solutions into NaN and get them discarded.
     *
     * @return the criterion value; NaN if the weights are not finite
     * @throws IllegalStateException if the two computations differ by more than 1e-5
     */
    private double ocenRozwiazanieEntropia() {
        double[] pp = new double[2];   // soft count on the positive side, per class
        double[] pn = new double[2];   // soft count on the negative side, per class
        double[] dct = new double[2];  // total soft count: [1] positive side, [0] negative side
        int accP = 0;
        int accN = 0;
        double n = this.dane.length;
        for (int i = 0; i < this.dane.length; i++) {
            double p = this.testujSmooth(this.dane[i]);
            int klasa = this.idx[i];
            pp[klasa] = pp[klasa] + (p + 1) * 0.5;
            pn[klasa] = pn[klasa] + (1 - p) * 0.5;
            dct[1] = dct[1] + (p + 1) * 0.5;
            dct[0] = dct[0] + (1 - p) * 0.5;
            /*if (p >= 0) {
                pp[idx] = pp[idx] + 1;
                dct[1] = dct[1] + 1;
            } else {
                pn[idx] = pn[idx] + 1;
                dct[0] = dct[0] + 1;
            }*/
            if (p * this.klasy[i] > 0) {
                accP = accP + 1;
            } else {
                accN = accN + 1;
            }
        }
        // First form: computed directly from the un-normalised soft counts.
        double sep = -(xLogX(pp[1]) + xLogX(pp[0]) - xLogX(dct[1])) / n;
        double sen = -(xLogX(pn[1]) + xLogX(pn[0]) - xLogX(dct[0])) / n;
        double sig = sep + sen;
        // Second form: per-side class shares, entropies weighted by the side share.
        for (int i = 0; i < 2; i++) {
            pp[i] = udzial(pp[i], dct[1]);
            pn[i] = udzial(pn[i], dct[0]);
        }
        double ep = -xLogX(pp[0]) - xLogX(pp[1]);
        double en = -xLogX(pn[0]) - xLogX(pn[1]);
        double ig = (dct[1] / n) * ep + (dct[0] / n) * en;
        if (glosno) {
            // The delta-rule error is only reported, so it is computed only here.
            double mse = this.ocenRozwiazanieWDelta();
            System.out.format("%7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %7.3f %d %d\n",
                    pp[0], pn[0], pp[1], pn[1], ep, en, ig, sig, mse, accP, accN);
        }
        double delta = Math.abs(sig - ig);
        if (delta > 0.00001) {
            // An internal inconsistency must not terminate the host JVM (least of
            // all with a success exit status); report it to the caller instead.
            throw new IllegalStateException("BLAD SIG: " + delta + ". Przerywam!");
        }
        return ig;
    }

    /** Returns the network output, a value in [-1, 1], for one vector under the current weights. */
    private double testujSmooth(double[] dane) {
        double l = liniowy(dane, this.wagi);
        return this.aktywacja(l);
    }

    /**
     * Classifies one vector.
     *
     * @param dane feature vector with as many elements as the training vectors
     * @return {@code true} for the positive class; when the training set held a
     *         single class, that class is returned for every input
     * @throws IllegalStateException if the classifier has not been trained
     */
    public boolean testuj(double[] dane) {
        if (this.klasyLicz == null) {
            throw new IllegalStateException("Classifier has not been trained; call ucz() first");
        }
        if (this.klasyLicz[1] == 0) {
            return false;
        }
        if (this.klasyLicz[0] == 0) {
            return true;
        }
        // Side 0 is the positive side of the hyperplane (output >= 0), side 1 the negative one.
        int decyzja = 0;
        double wart = this.testujSmooth(dane);
        if (wart < 0) {
            decyzja = 1;
        }
        return this.decyzje[decyzja];
    }

    /**
     * Computes the scaled linear part of the model.
     *
     * @param cechy feature vector
     * @param wagi  weights; must hold at least {@code cechy.length + 1} elements,
     *              the element at index {@code cechy.length} being the bias
     * @return {@code mnoznik * (sum(cechy[i] * wagi[i]) - wagi[cechy.length])}
     */
    public double liniowy(double[] cechy, double[] wagi) {
        double res = 0;
        for (int i = 0; i < cechy.length; i++) {
            res = res + cechy[i] * wagi[i];
        }
        res = res - wagi[cechy.length];
        return res * this.mnoznik;
    }

    /**
     * Activation function.
     *
     * @param liniowy argument of the activation
     * @return {@code tanh(liniowy)}
     */
    public double aktywacja(double liniowy) {
        return Math.tanh(liniowy);
    }

    /**
     * First derivative of the activation function.
     *
     * @param liniowy argument of the activation
     * @return {@code 1 - tanh(liniowy)^2}
     */
    public double aktywacjaPrim(double liniowy) {
        double t = Math.tanh(liniowy);
        return 1.0 - t * t;
    }

    /**
     * Second derivative of the activation function.
     *
     * @param liniowy argument of the activation
     * @return {@code 2 * tanh(liniowy) * (tanh(liniowy)^2 - 1)}
     */
    public double aktywacjaBis(double liniowy) {
        double t = Math.tanh(liniowy);
        return 2.0 * t * (t * t - 1.0);
    }

    /*public static void main(String[] args) {
        DatasetGenerator dr = new DatasetGenerator();
        //Dataset ds = dr.getGaussianDataset(0, 200, 100, 1);
        Dataset ds = dr.getNSGaussianDataset(0, 100, 100, 6);
        boolean[] cls = new boolean[ds.cls.length];
        for (int i = 0; i < cls.length; i++) {
            cls[i] = ds.cls[i] > 0;
        }
        EntropiaGrad grad = new EntropiaGrad(0);
        grad.ucz(ds.data, cls);
        int poprawneP = 0;
        int poprawneN = 0;
        int niepoprawne = 0;
        for (int i = 0; i < ds.data.length; i++) {
            if (grad.testuj(ds.data[i]) == cls[i]) {
                if (cls[i]) {
                    poprawneP++;
                } else {
                    poprawneN++;
                }
            } else {
                niepoprawne++;
            }
        }
        System.out.println("Wynik: " + poprawneP + " " + poprawneN + " " + niepoprawne);
    }*/
}