1: /* 学習プログラム cz.c */
2:
3: #include <stdio.h>
4: #include <stdlib.h>
5: #include <math.h>
6: #include <string.h>
7:
8:
#define NIN 20      /* maximum number of input units (array size; nin may be smaller) */
#define NMID 50     /* maximum number of hidden units (array size; nmid may be smaller) */
#define NOUT 2      /* number of output units */
#define SMAX 256    /* size of string / line buffers */
#define FALSE 0
#define TRUE (!FALSE)
#define ESTFILE "cz.est"    /* settings file used when none is given on the command line */
#define CHAOS 1     /* flg bit: chaotic-neuron parameters are initialised and learned */
#define OUT 2       /* flg bit: print the error after every epoch */

/* Sigmoid with gain 1/epsilon.  NOTE: reads the global variable epsilon. */
#define fs(x) (1/(1+exp(-(x)/epsilon)))
/* Derivative of fs, expressed through its output y = fs(x): y(1-y)/epsilon.
   NOTE: reads the global variable epsilon. */
#define fd(x) ((x)*(1-(x))/epsilon)

/* String equality test (non-zero when a and b are equal). */
#define EQU(a, b) (strcmp((a), (b)) == 0)
23:
24:
25: double drand48();
26: void srand48(long seedval);
27:
28: int nin=NIN, nmid=NMID, nout=NOUT;
29:
30: /* 出力 */
31: double xi[ NIN ], xm[ NMID ], xo[ NOUT ];
32: double old_xi[ NIN ], old_xm[ NMID ], old_xo[ NOUT ];
33:
34: /* 内部状態 (イータ, ゼータ) */
35: double em[ NMID ], old_em[ NMID ], eo[ NOUT ], old_eo[ NOUT ];
36: double zm[ NMID ], old_zm[ NMID ], zo[ NOUT ], old_zo[ NOUT ];
37:
38: /* カオスニューロンのパラメータ */
39: double kmm[ NMID ], kmdm[ NMID ], kmo[ NOUT ], kmdo[ NOUT ];
40: double krm[ NMID ], krdm[ NMID ], kro[ NOUT ], krdo[ NOUT ];
41: double alm[ NMID ], aldm[ NMID ], alo[ NOUT ], aldo[ NOUT ];
42:
43: /* 結合荷重と閾値θ */
44: double wim[ NIN+ 1 ][ NMID ], old_dwim[ NIN +1 ][ NMID ];
45: double wmo[ NMID+1 ][ NOUT ], old_dwmo[ NMID+1 ][ NOUT ];
46:
47: /* 教師信号 */
48: double *ts;
49: int tsmax;
50:
51: /* 学習定数 */
52: double ra=0.01, rb=0.005, rc=0.02;
53:
54: double epsilon=1.0;
55:
56: int flg;
57:
58:
59: main(int argc, char *argv[]){
60: int n, ninterval=1, rinterval=1;
61: double ratio=1.0;
62: long seedval;
63: char ts_file[SMAX], ld_file[SMAX], sv_file[SMAX];
64: FILE *fi;
65: char buf1[SMAX], buf2[SMAX];
66: void load_ts(char fn[]), load_prm(char fn[]), save_prm(char fn[]);
67: void syokika(), learn(int n1, int n2, int n3, double r);
68: void chaoticparameter(double prm), karamawasi(int n);
69:
70: if(argc>=2){
71: if((fi=fopen(argv[1], "r"))==NULL){
72: fprintf(stderr, "can't open %s\n", argv[1]);
73: exit(1);
74: }
75: }else{
76: if((fi=fopen(ESTFILE, "r"))==NULL){
77: fprintf(stderr, "can't open %s\n", argv[1]);
78: exit(1);
79: }
80: }
81:
82: seedval=0;
83: flg=CHAOS | OUT;
84: strcpy(ts_file,"cz.ts");
85: ld_file[0]=NULL;
86: sv_file[0]=NULL;
87: while(fscanf(fi, "%s %s", buf1, buf2)!=EOF){
88: if(EQU(buf1, "teachersignalfile")) strcpy(ts_file, buf2);
89: else if(EQU(buf1, "ratio_a")) ra=atof(buf2);
90: else if(EQU(buf1, "ratio_b")) rb=atof(buf2);
91: else if(EQU(buf1, "ratio_c")) rc=atof(buf2);
92: else if(EQU(buf1, "epsilon")) epsilon=atof(buf2);
93: else if(EQU(buf1, "seedval")) seedval=atoi(buf2);
94: else if(EQU(buf1, "learn")) n=atoi(buf2);
95: else if(EQU(buf1, "printinterval")) ninterval=atoi(buf2);
96: else if(EQU(buf1, "ratiointerval")) rinterval=atoi(buf2);
97: else if(EQU(buf1, "ratiochange")) ratio=atof(buf2);
98: else if(EQU(buf1, "loadfile")) strcpy(ld_file, buf2);
99: else if(EQU(buf1, "savefile")) strcpy(sv_file, buf2);
100: else if(EQU(buf1, "no_in")) nin =atoi(buf2)*2;
101: else if(EQU(buf1, "no_mid")) nmid=atoi(buf2);
102: else if(EQU(buf1, "chaos")) {if(EQU(buf2, "off")) flg &= ~CHAOS;}
103: else if(EQU(buf1, "output")) {if(EQU(buf2, "off")) flg &= ~OUT;}
104: }
105:
106: fclose(fi);
107:
108: if(seedval!=0){
109: syokika(seedval);
110: }
111:
112: load_ts(ts_file);
113:
114: if(ld_file[0]!=NULL){
115: load_prm(ld_file);
116: }
117:
118: if(!(flg & CHAOS)){
119: chaoticparameter(0.0);
120: }
121:
122: karamawasi(100);
123:
124: learn(n, ninterval, rinterval, ratio);
125:
126: if(sv_file[0]!=NULL){
127: save_prm(sv_file);
128: }
129: }
130:
131:
132: /* 教師信号の読み込み */
133: void load_ts(char *filename){
134: FILE *fi;
135: int i, ct;
136: char buf[SMAX];
137:
138: if((fi=fopen(filename,"r"))==NULL){
139: fprintf(stderr, "Can't open %s\n", filename);
140: exit(1);
141: }
142:
143: ct=0;
144: while(fgets(buf, SMAX, fi)!=NULL)
145: ct++;
146: ct*=2;
147:
148: if((ts=(double *)calloc(ct, sizeof(double)))==NULL){
149: fprintf(stderr, "Can't get memory\n");
150: exit(1);
151: }
152:
153: rewind(fi);
154: for(i=0; i<ct; i+=2){
155: fscanf(fi, "%lf %lf", &ts[i], &ts[i+1]);
156: }
157:
158: tsmax=ct;
159: }
160:
161:
162: /* パラメータの読み込み */
163: void load_prm(char *filename){
164: FILE *fi;
165: size_t wsize;
166: int i;
167:
168: if((fi=fopen(filename, "r"))==NULL){
169: fprintf(stderr, "can't open %s\n", filename);
170: exit(1);
171: }
172:
173: wsize=sizeof(double);
174: for(i=0; i<nin+1; i++){
175: fread(wim[i], wsize, nmid, fi);
176: fread(old_dwim[i], wsize, nmid, fi);
177: }
178: for(i=0; i<nmid+1; i++){
179: fread(wmo[i], wsize, nout, fi);
180: fread(old_dwmo[i], wsize, nout, fi);
181: }
182: fread(kmm, wsize, nmid, fi);
183: fread(kmdm, wsize, nmid, fi);
184: fread(krm, wsize, nmid, fi);
185: fread(krdm, wsize, nmid, fi);
186: fread(alm, wsize, nmid, fi);
187: fread(aldm, wsize, nmid, fi);
188: fread(em, wsize, nmid, fi);
189: fread(zm, wsize, nmid, fi);
190: fread(xm, wsize, nmid, fi);
191: fread(kmo, wsize, nout, fi);
192: fread(kmdo, wsize, nout, fi);
193: fread(kro, wsize, nout, fi);
194: fread(krdo, wsize, nout, fi);
195: fread(alo, wsize, nout, fi);
196: fread(aldo, wsize, nout, fi);
197: fread(eo, wsize, nout, fi);
198: fread(zo, wsize, nout, fi);
199: fread(xo, wsize, nout, fi);
200:
201: fclose(fi);
202: }
203:
204:
205: /* パラメータの保存 */
206: void save_prm(char *filename){
207: FILE *fo;
208: size_t wsize;
209: int i;
210:
211: if((fo=fopen(filename, "w"))==NULL){
212: fprintf(stderr, "can't open %s\n", filename);
213: exit(1);
214: }
215:
216: wsize=sizeof(double);
217: for(i=0; i<nin+1; i++){
218: fwrite(wim[i], wsize, nmid, fo);
219: fwrite(old_dwim[i], wsize, nmid, fo);
220: }
221: for(i=0; i<nmid+1; i++){
222: fwrite(wmo[i], wsize, nout, fo);
223: fwrite(old_dwmo[i], wsize, nout, fo);
224: }
225:
226: fwrite(kmm, wsize, nmid, fo);
227: fwrite(kmdm, wsize, nmid, fo);
228: fwrite(krm, wsize, nmid, fo);
229: fwrite(krdm, wsize, nmid, fo);
230: fwrite(alm, wsize, nmid, fo);
231: fwrite(aldm, wsize, nmid, fo);
232: fwrite(em, wsize, nmid, fo);
233: fwrite(zm, wsize, nmid, fo);
234: fwrite(xm, wsize, nmid, fo);
235: fwrite(kmo, wsize, nout, fo);
236: fwrite(kmdo, wsize, nout, fo);
237: fwrite(kro, wsize, nout, fo);
238: fwrite(krdo, wsize, nout, fo);
239: fwrite(alo, wsize, nout, fo);
240: fwrite(aldo, wsize, nout, fo);
241: fwrite(eo, wsize, nout, fo);
242: fwrite(zo, wsize, nout, fo);
243: fwrite(xo, wsize, nout, fo);
244:
245: fclose(fo);
246: }
247:
248:
249: /* パラメータの初期化 */
250: void syokika(long seedval){
251: int i, j;
252:
253: srand48(seedval);
254: for(i=0; i<nmid; i++){
255: if(flg & CHAOS){
256: kmm[i]=drand48();
257: kmdm[i]=-log(1/kmm[i]-1)/epsilon;
258: krm[i]=drand48();
259: krdm[i]=-log(1/krm[i]-1)/epsilon;
260: alm[i]=5*drand48()+5;
261: aldm[i]=log(alm[i]);
262: }
263: for(j=0; j<nin; j++)
264: wim[j][i]=drand48()-0.5;
265: wim[j][i]=drand48();
266: }
267: for(i=0; i<nout; i++){
268: if(flg & CHAOS){
269: kmo[i]=drand48();
270: kmdo[i]=-log(1/kmo[i]-1)/epsilon;
271: kro[i]=drand48();
272: krdo[i]=-log(1/kro[i]-1)/epsilon;
273: alo[i]=5*drand48()+5;
274: aldo[i]=log(alo[i]);
275: }
276: for(j=0; j<nmid; j++)
277: wmo[j][i]=drand48()-0.5;
278: wmo[j][i]=drand48();
279: }
280: }
281:
282:
283: /* カオスニューロンのパラメータの変更 */
284: void chaoticparameter(double x){
285: int i;
286:
287: for(i=0; i<nmid; i++){
288: kmm[i]=krm[i]=alm[i]=x;
289: }
290:
291: for(i=0; i<nout; i++){
292: kmo[i]=kro[i]=alo[i]=x;
293: }
294: }
295:
296:
297: /* 学習 */
298: void learn(int n, int ninterval, int rinterval, double r){
299: int i, j, k, l, nn;
300: double gosa1, gosa2;
301: void keisan(), bkp(int);
302:
303: if(flg & OUT){
304: for(i=0; i<n; i++){
305: gosa1=gosa2=0.0;
306: for(j=0; j<tsmax; j+=2){
307: for(k=0; k<nin; k++){
308: if(j-nin+k>=0)
309: xi[k]=ts[j-nin+k];
310: else
311: xi[k]=ts[tsmax+j-nin+k];
312: }
313: keisan();
314: gosa1+=fabs(xo[0]-ts[j]);
315: gosa2+=fabs(xo[1]-ts[j+1]);
316: bkp(j);
317: }
318: printf("%lf+%lf=%lf\n", gosa1/tsmax, gosa2/tsmax, (gosa1+gosa2)/tsmax);
319: }
320: }else{
321: nn=n/ninterval;
322: for(l=0; l<nn; l++){
323: for(i=0; i<ninterval-1; i++){
324: for(j=0; j<tsmax; j+=2){
325: for(k=0; k<nin; k++){
326: if(j-nin+k>=0)
327: xi[k]=ts[j-nin+k];
328: else
329: xi[k]=ts[tsmax+j-nin+k];
330: }
331: keisan();
332: bkp(j);
333: }
334: }
335: gosa1=0.0;
336: for(j=0; j<tsmax; j+=2){
337: for(k=0; k<nin; k++){
338: if(j-nin+k>=0)
339: xi[k]=ts[j-nin+k];
340: else
341: xi[k]=ts[tsmax+j-nin+k];
342: }
343: keisan();
344: gosa1+=fabs(xo[0]-ts[j])+fabs(xo[1]-ts[j+1]);
345: bkp(j);
346: }
347: printf("%lf\n", gosa1/tsmax);
348: fflush(stdout);
349: if((l+1)%(rinterval/ninterval)==0){
350: ra*=r; rb*=r; rc*=r;
351: }
352: }
353: }
354: }
355:
356:
357: /* ネットワークの計算 */
358: void keisan(){
359: int i, j;
360: double sum;
361:
362: for(j=0; j<nmid; j++){
363: sum=0.0;
364: for(i=0; i<nin; i++)
365: sum+=wim[i][j]*xi[i];
366: sum+=wim[i][j];
367: old_em[j]=em[j];
368: em[j]=kmm[j]*old_em[j]+sum;
369: old_zm[j]=zm[j];
370: old_xm[j]=xm[j];
371: zm[j]=krm[j]*zm[j]-alm[j]*old_xm[j];
372: xm[j]=fs(em[j]+zm[j]);
373: }
374: for(j=0; j<nout; j++){
375: sum=0.0;
376: for(i=0; i<nmid; i++)
377: sum+=wmo[i][j]*xm[i];
378: sum+=wmo[i][j];
379: old_eo[j]=eo[j];
380: eo[j]=kmo[j]*old_eo[j]+sum;
381: old_zo[j]=zo[j];
382: old_xo[j]=xo[j];
383: zo[j]=kro[j]*zo[j]-alo[j]*old_xo[j];
384: xo[j]=fs(eo[j]+zo[j]);
385: }
386: }
387:
388:
389: /* バックプロパゲーション */
390: void bkp(int n){
391: double dlo[NOUT], dlm[NMID];
392: double wk;
393: int i, j, k;
394:
395: for(i=0; i<nout; i++)
396: dlo[i]=(xo[i]-ts[n+i])*fd(xo[i]);
397:
398: for(i=0; i<nmid; i++){
399: wk=0.0;
400: for(j=0; j<nout; j++)
401: wk+=dlo[j]*wmo[i][j];
402: dlm[i]=wk*fd(xm[i]);
403: }
404:
405: for(i=0; i<nmid; i++){
406: for(j=0; j<nin+1; j++){
407: wk=-ra*dlm[i]*((j<nin)?xi[j]:1)+rb*old_dwim[j][i];
408: old_dwim[j][i]=wk;
409: wim[j][i]+=wk;
410: }
411: }
412:
413: for(i=0; i<nout; i++){
414: for(j=0; j<nmid+1; j++){
415: wk=-ra*dlo[i]*((j<nmid)?xm[j]:1)+rb*old_dwmo[j][i];
416: old_dwmo[j][i]=wk;
417: wmo[j][i]+=wk;
418: }
419: }
420:
421: if(flg & CHAOS){
422: for(i=0; i<nmid; i++){
423: kmdm[i]-=rc*dlm[i]*old_em[i]*fd(kmm[i]);
424: kmm[i]=fs(kmdm[i]);
425: krdm[i]-=rc*dlm[i]*old_zm[i]*fd(krm[i]);
426: krm[i]=fs(krdm[i]);
427: aldm[i]+=rc*dlm[i]*old_xm[i]*exp(aldm[i]/epsilon)/epsilon;
428: alm[i]=exp(aldm[i]/epsilon);
429: }
430:
431: for(i=0; i<nout; i++){
432: kmdo[i]-=rc*dlo[i]*old_eo[i]*fd(kmo[i]);
433: kmo[i]=fs(kmdo[i]);
434: krdo[i]-=rc*dlo[i]*old_zo[i]*fd(kro[i]);
435: kro[i]=fs(krdo[i]);
436: aldo[i]+=rc*dlo[i]*old_xo[i]*exp(aldo[i]/epsilon)/epsilon;
437: alo[i]=exp(aldo[i]/epsilon);
438: }
439: }
440: }
441:
442:
443: /* ネットワークの空回し */
444: void karamawasi(int n){
445: int i, j, k;
446: for(i=0; i<n; i++){
447: for(j=0; j<tsmax; j+=2){
448: for(k=0; k<nin; k++){
449: if(j-nin+k>=0){
450: xi[k]=ts[j-nin+k];
451: }else{
452: xi[k]=ts[tsmax+j-nin+k];
453: }
454: }
455: keisan();
456: }
457: }
458: }