next up previous contents
Next: A.2 仕様 Up: 付録A 学習プログラム Previous: 付録A 学習プログラム

A.1 プログラムリスト

  1: /* 学習プログラム cz.c */
  2: 
  3: #include <stdio.h>
  4: #include <stdlib.h>
  5: #include <math.h>
  6: #include <string.h>
  7: 
  8: 
  9: #define NIN   20   /* 入力層の素子数 */
 10: #define NMID  50   /* 中間層の素子数 */
 11: #define NOUT   2   /* 出力層の素子数 */
 12: #define SMAX  256
 13: #define FALSE 0
 14: #define TRUE  !FALSE
 15: #define ESTFILE "cz.est"
 16: #define CHAOS 1
 17: #define OUT   2
 18: 
 19: #define fs(x) (1/(1+exp(-(x)/epsilon)))   /* シグモイド関数 */
 20: #define fd(x) ((x)*(1-(x))/epsilon)       /* シグモイド関数の微分 */
 21: 
 22: #define EQU !strcmp
 23: 
 24: 
 25: double drand48();
 26: void   srand48(long seedval);
 27: 
 28: int nin=NIN, nmid=NMID, nout=NOUT;
 29: 
 30: /* 出力 */
 31: double     xi[ NIN ],     xm[ NMID ],     xo[ NOUT ];
 32: double old_xi[ NIN ], old_xm[ NMID ], old_xo[ NOUT ];
 33: 
 34: /* 内部状態 (イータ, ゼータ) */
 35: double em[ NMID ], old_em[ NMID ], eo[ NOUT ], old_eo[ NOUT ];
 36: double zm[ NMID ], old_zm[ NMID ], zo[ NOUT ], old_zo[ NOUT ];
 37: 
 38: /* カオスニューロンのパラメータ */
 39: double kmm[ NMID ], kmdm[ NMID ], kmo[ NOUT ], kmdo[ NOUT ];
 40: double krm[ NMID ], krdm[ NMID ], kro[ NOUT ], krdo[ NOUT ];
 41: double alm[ NMID ], aldm[ NMID ], alo[ NOUT ], aldo[ NOUT ];
 42: 
 43: /* 結合荷重と閾値θ */
 44: double wim[ NIN+ 1 ][ NMID ], old_dwim[ NIN +1 ][ NMID ];
 45: double wmo[ NMID+1 ][ NOUT ], old_dwmo[ NMID+1 ][ NOUT ];
 46: 
 47: /* 教師信号 */
 48: double *ts;
 49: int     tsmax;
 50: 
 51: /* 学習定数 */
 52: double ra=0.01, rb=0.005, rc=0.02;
 53: 
 54: double epsilon=1.0;
 55: 
 56: int flg;
 57: 
 58: 
 59: main(int argc, char *argv[]){
 60:   int n, ninterval=1, rinterval=1;
 61:   double ratio=1.0;
 62:   long seedval;
 63:   char ts_file[SMAX], ld_file[SMAX], sv_file[SMAX];
 64:   FILE *fi;
 65:   char buf1[SMAX], buf2[SMAX];
 66:   void load_ts(char fn[]), load_prm(char fn[]), save_prm(char fn[]);
 67:   void syokika(), learn(int n1, int n2, int n3, double r);
 68:   void chaoticparameter(double prm), karamawasi(int n);
 69: 
 70:   if(argc>=2){
 71:     if((fi=fopen(argv[1], "r"))==NULL){
 72:       fprintf(stderr, "can't open %s\n", argv[1]);
 73:       exit(1);
 74:     }
 75:   }else{
 76:     if((fi=fopen(ESTFILE, "r"))==NULL){
 77:       fprintf(stderr, "can't open %s\n", argv[1]);
 78:       exit(1);
 79:     }
 80:   }
 81: 
 82:   seedval=0;
 83:   flg=CHAOS | OUT;
 84:   strcpy(ts_file,"cz.ts");
 85:   ld_file[0]=NULL;
 86:   sv_file[0]=NULL;
 87:   while(fscanf(fi, "%s %s", buf1, buf2)!=EOF){
 88:     if(EQU(buf1, "teachersignalfile"))   strcpy(ts_file, buf2);
 89:     else if(EQU(buf1, "ratio_a"))        ra=atof(buf2);
 90:     else if(EQU(buf1, "ratio_b"))        rb=atof(buf2);
 91:     else if(EQU(buf1, "ratio_c"))        rc=atof(buf2);
 92:     else if(EQU(buf1, "epsilon"))        epsilon=atof(buf2);
 93:     else if(EQU(buf1, "seedval"))        seedval=atoi(buf2);
 94:     else if(EQU(buf1, "learn"))          n=atoi(buf2);
 95:     else if(EQU(buf1, "printinterval"))  ninterval=atoi(buf2);
 96:     else if(EQU(buf1, "ratiointerval"))  rinterval=atoi(buf2);
 97:     else if(EQU(buf1, "ratiochange"))    ratio=atof(buf2);
 98:     else if(EQU(buf1, "loadfile"))       strcpy(ld_file, buf2);
 99:     else if(EQU(buf1, "savefile"))       strcpy(sv_file, buf2);
100:     else if(EQU(buf1, "no_in"))          nin =atoi(buf2)*2;
101:     else if(EQU(buf1, "no_mid"))         nmid=atoi(buf2);
102:     else if(EQU(buf1, "chaos"))         {if(EQU(buf2, "off")) flg &= ~CHAOS;}
103:     else if(EQU(buf1, "output"))        {if(EQU(buf2, "off")) flg &= ~OUT;}
104:   }
105: 
106:   fclose(fi);
107: 
108:   if(seedval!=0){
109:     syokika(seedval);
110:   }
111: 
112:   load_ts(ts_file);
113: 
114:   if(ld_file[0]!=NULL){
115:     load_prm(ld_file);
116:   }
117: 
118:   if(!(flg & CHAOS)){
119:     chaoticparameter(0.0);
120:   }
121: 
122:   karamawasi(100);
123: 
124:   learn(n, ninterval, rinterval, ratio);
125: 
126:   if(sv_file[0]!=NULL){
127:     save_prm(sv_file);
128:   }
129: }
130: 
131: 
132: /* 教師信号の読み込み */
133: void load_ts(char *filename){
134:   FILE *fi;
135:   int i, ct;
136:   char buf[SMAX];
137: 
138:   if((fi=fopen(filename,"r"))==NULL){
139:     fprintf(stderr, "Can't open %s\n", filename);
140:     exit(1);
141:   }
142: 
143:   ct=0;
144:   while(fgets(buf, SMAX, fi)!=NULL)
145:     ct++;
146:   ct*=2;
147: 
148:   if((ts=(double *)calloc(ct, sizeof(double)))==NULL){
149:     fprintf(stderr, "Can't get memory\n");
150:     exit(1);
151:   }
152: 
153:   rewind(fi);
154:   for(i=0; i<ct; i+=2){
155:     fscanf(fi, "%lf %lf", &ts[i], &ts[i+1]);
156:   }
157: 
158:   tsmax=ct;
159: }
160: 
161: 
162: /* パラメータの読み込み */
163: void load_prm(char *filename){
164:   FILE *fi;
165:   size_t wsize;
166:   int i;
167: 
168:   if((fi=fopen(filename, "r"))==NULL){
169:     fprintf(stderr, "can't open %s\n", filename);
170:     exit(1);
171:   }
172: 
173:   wsize=sizeof(double);
174:   for(i=0; i<nin+1; i++){
175:     fread(wim[i],      wsize, nmid, fi);
176:     fread(old_dwim[i], wsize, nmid, fi);
177:   }
178:   for(i=0; i<nmid+1; i++){
179:     fread(wmo[i],      wsize, nout, fi);
180:     fread(old_dwmo[i], wsize, nout, fi);
181:   }
182:   fread(kmm,  wsize, nmid, fi);
183:   fread(kmdm, wsize, nmid, fi);
184:   fread(krm,  wsize, nmid, fi);
185:   fread(krdm, wsize, nmid, fi);
186:   fread(alm,  wsize, nmid, fi);
187:   fread(aldm, wsize, nmid, fi);
188:   fread(em,   wsize, nmid, fi);
189:   fread(zm,   wsize, nmid, fi);
190:   fread(xm,   wsize, nmid, fi);
191:   fread(kmo,  wsize, nout, fi);
192:   fread(kmdo, wsize, nout, fi);
193:   fread(kro,  wsize, nout, fi);
194:   fread(krdo, wsize, nout, fi);
195:   fread(alo,  wsize, nout, fi);
196:   fread(aldo, wsize, nout, fi);
197:   fread(eo,   wsize, nout, fi);
198:   fread(zo,   wsize, nout, fi);
199:   fread(xo,   wsize, nout, fi);
200: 
201:   fclose(fi);
202: }
203: 
204: 
205: /* パラメータの保存 */
206: void save_prm(char *filename){
207:   FILE *fo;
208:   size_t wsize;
209:   int i;
210: 
211:   if((fo=fopen(filename, "w"))==NULL){
212:     fprintf(stderr, "can't open %s\n", filename);
213:     exit(1);
214:   }
215: 
216:   wsize=sizeof(double);
217:   for(i=0; i<nin+1; i++){
218:     fwrite(wim[i],      wsize, nmid, fo);
219:     fwrite(old_dwim[i], wsize, nmid, fo);
220:   }
221:   for(i=0; i<nmid+1; i++){
222:     fwrite(wmo[i],      wsize, nout, fo);
223:     fwrite(old_dwmo[i], wsize, nout, fo);
224:   }
225: 
226:   fwrite(kmm,  wsize, nmid, fo);
227:   fwrite(kmdm, wsize, nmid, fo);
228:   fwrite(krm,  wsize, nmid, fo);
229:   fwrite(krdm, wsize, nmid, fo);
230:   fwrite(alm,  wsize, nmid, fo);
231:   fwrite(aldm, wsize, nmid, fo);
232:   fwrite(em,   wsize, nmid, fo);
233:   fwrite(zm,   wsize, nmid, fo);
234:   fwrite(xm,   wsize, nmid, fo);
235:   fwrite(kmo,  wsize, nout, fo);
236:   fwrite(kmdo, wsize, nout, fo);
237:   fwrite(kro,  wsize, nout, fo);
238:   fwrite(krdo, wsize, nout, fo);
239:   fwrite(alo,  wsize, nout, fo);
240:   fwrite(aldo, wsize, nout, fo);
241:   fwrite(eo,   wsize, nout, fo);
242:   fwrite(zo,   wsize, nout, fo);
243:   fwrite(xo,   wsize, nout, fo);
244: 
245:   fclose(fo);
246: }
247: 
248: 
249: /* パラメータの初期化 */
250: void syokika(long seedval){
251:   int i, j;
252: 
253:   srand48(seedval);
254:   for(i=0; i<nmid; i++){
255:     if(flg & CHAOS){
256:       kmm[i]=drand48();
257:       kmdm[i]=-log(1/kmm[i]-1)/epsilon;
258:       krm[i]=drand48();
259:       krdm[i]=-log(1/krm[i]-1)/epsilon;
260:       alm[i]=5*drand48()+5;
261:       aldm[i]=log(alm[i]);
262:     }
263:     for(j=0; j<nin; j++)
264:       wim[j][i]=drand48()-0.5;
265:     wim[j][i]=drand48();
266:   }
267:   for(i=0; i<nout; i++){
268:     if(flg & CHAOS){
269:       kmo[i]=drand48();
270:       kmdo[i]=-log(1/kmo[i]-1)/epsilon;
271:       kro[i]=drand48();
272:       krdo[i]=-log(1/kro[i]-1)/epsilon;
273:       alo[i]=5*drand48()+5;
274:       aldo[i]=log(alo[i]);
275:     }
276:     for(j=0; j<nmid; j++)
277:       wmo[j][i]=drand48()-0.5;
278:     wmo[j][i]=drand48();
279:   }
280: }
281: 
282: 
283: /* カオスニューロンのパラメータの変更 */
284: void chaoticparameter(double x){
285:   int i;
286: 
287:   for(i=0; i<nmid; i++){
288:     kmm[i]=krm[i]=alm[i]=x;
289:   }
290: 
291:   for(i=0; i<nout; i++){
292:     kmo[i]=kro[i]=alo[i]=x;
293:   }
294: }
295: 
296: 
297: /* 学習 */
298: void learn(int n, int ninterval, int rinterval, double r){
299:   int i, j, k, l, nn;
300:   double gosa1, gosa2;
301:   void keisan(), bkp(int);
302: 
303:   if(flg & OUT){
304:     for(i=0; i<n; i++){
305:       gosa1=gosa2=0.0;
306:       for(j=0; j<tsmax; j+=2){
307:    for(k=0; k<nin; k++){
308:      if(j-nin+k>=0)
309:        xi[k]=ts[j-nin+k];
310:      else
311:        xi[k]=ts[tsmax+j-nin+k];
312:    }
313:    keisan();
314:    gosa1+=fabs(xo[0]-ts[j]);
315:    gosa2+=fabs(xo[1]-ts[j+1]);
316:    bkp(j);
317:       }
318:       printf("%lf+%lf=%lf\n", gosa1/tsmax, gosa2/tsmax, (gosa1+gosa2)/tsmax);
319:     }
320:   }else{
321:     nn=n/ninterval;
322:     for(l=0; l<nn; l++){
323:       for(i=0; i<ninterval-1; i++){
324:    for(j=0; j<tsmax; j+=2){
325:      for(k=0; k<nin; k++){
326:        if(j-nin+k>=0)
327:          xi[k]=ts[j-nin+k];
328:        else
329:          xi[k]=ts[tsmax+j-nin+k];
330:      }
331:      keisan();
332:      bkp(j);
333:    }
334:       }
335:       gosa1=0.0;
336:       for(j=0; j<tsmax; j+=2){
337:    for(k=0; k<nin; k++){
338:      if(j-nin+k>=0)
339:        xi[k]=ts[j-nin+k];
340:      else
341:        xi[k]=ts[tsmax+j-nin+k];
342:    }
343:    keisan();
344:    gosa1+=fabs(xo[0]-ts[j])+fabs(xo[1]-ts[j+1]);
345:    bkp(j);
346:       }
347:       printf("%lf\n", gosa1/tsmax);
348:       fflush(stdout);
349:       if((l+1)%(rinterval/ninterval)==0){
350:    ra*=r; rb*=r; rc*=r;
351:       }
352:     }
353:   }
354: }
355: 
356: 
357: /* ネットワークの計算 */
358: void keisan(){
359:   int i, j;
360:   double sum;
361: 
362:   for(j=0; j<nmid; j++){
363:     sum=0.0;
364:     for(i=0; i<nin; i++)
365:       sum+=wim[i][j]*xi[i];
366:     sum+=wim[i][j];
367:     old_em[j]=em[j];
368:     em[j]=kmm[j]*old_em[j]+sum;
369:     old_zm[j]=zm[j];
370:     old_xm[j]=xm[j];
371:     zm[j]=krm[j]*zm[j]-alm[j]*old_xm[j];
372:     xm[j]=fs(em[j]+zm[j]);
373:   }
374:   for(j=0; j<nout; j++){
375:     sum=0.0;
376:     for(i=0; i<nmid; i++)
377:       sum+=wmo[i][j]*xm[i];
378:     sum+=wmo[i][j];
379:     old_eo[j]=eo[j];
380:     eo[j]=kmo[j]*old_eo[j]+sum;
381:     old_zo[j]=zo[j];
382:     old_xo[j]=xo[j];
383:     zo[j]=kro[j]*zo[j]-alo[j]*old_xo[j];
384:     xo[j]=fs(eo[j]+zo[j]);
385:   }
386: }
387: 
388: 
389: /* バックプロパゲーション */
390: void bkp(int n){
391:   double dlo[NOUT], dlm[NMID];
392:   double wk;
393:   int i, j, k;
394: 
395:   for(i=0; i<nout; i++)
396:     dlo[i]=(xo[i]-ts[n+i])*fd(xo[i]);
397: 
398:   for(i=0; i<nmid; i++){
399:     wk=0.0;
400:     for(j=0; j<nout; j++)
401:       wk+=dlo[j]*wmo[i][j];
402:     dlm[i]=wk*fd(xm[i]);
403:   }
404: 
405:   for(i=0; i<nmid; i++){
406:     for(j=0; j<nin+1; j++){
407:       wk=-ra*dlm[i]*((j<nin)?xi[j]:1)+rb*old_dwim[j][i];
408:       old_dwim[j][i]=wk;
409:       wim[j][i]+=wk;
410:     }
411:   }
412: 
413:   for(i=0; i<nout; i++){
414:     for(j=0; j<nmid+1; j++){
415:       wk=-ra*dlo[i]*((j<nmid)?xm[j]:1)+rb*old_dwmo[j][i];
416:       old_dwmo[j][i]=wk;
417:       wmo[j][i]+=wk;
418:     }
419:   }
420: 
421:   if(flg & CHAOS){
422:     for(i=0; i<nmid; i++){
423:       kmdm[i]-=rc*dlm[i]*old_em[i]*fd(kmm[i]);
424:       kmm[i]=fs(kmdm[i]);
425:       krdm[i]-=rc*dlm[i]*old_zm[i]*fd(krm[i]);
426:       krm[i]=fs(krdm[i]);
427:       aldm[i]+=rc*dlm[i]*old_xm[i]*exp(aldm[i]/epsilon)/epsilon;
428:       alm[i]=exp(aldm[i]/epsilon);
429:     }
430: 
431:     for(i=0; i<nout; i++){
432:       kmdo[i]-=rc*dlo[i]*old_eo[i]*fd(kmo[i]);
433:       kmo[i]=fs(kmdo[i]);
434:       krdo[i]-=rc*dlo[i]*old_zo[i]*fd(kro[i]);
435:       kro[i]=fs(krdo[i]);
436:       aldo[i]+=rc*dlo[i]*old_xo[i]*exp(aldo[i]/epsilon)/epsilon;
437:       alo[i]=exp(aldo[i]/epsilon);
438:     }
439:   }
440: }
441: 
442: 
443: /* ネットワークの空回し */
444: void karamawasi(int n){
445:   int i, j, k;
446:   for(i=0; i<n; i++){
447:     for(j=0; j<tsmax; j+=2){
448:       for(k=0; k<nin; k++){
449:    if(j-nin+k>=0){
450:      xi[k]=ts[j-nin+k];
451:    }else{
452:      xi[k]=ts[tsmax+j-nin+k];
453:    }
454:       }
455:       keisan();
456:     }
457:   }
458: }



Deguchi Toshinori
1999年03月23日 (火) 15時43分49秒 JST