next up previous contents
Next: この文書について ... Up: 無題 Previous: 参考文献

付録A カオスニューロン学習プログラム

ファイル名: "cnn.c"

  1: /*************************************/
  2: /* 卒業研究プログラム 「cnn.c」 */
  3: /*   平成9年度 電気工学科5年16番 */
  4: /*************************************/
  5: 
#include        <math.h>
#include        <stdio.h>
#include        <stdlib.h>
#include        <string.h>
#include        <malloc.h>
#include        <local/bgi.h>
 11: 
 12: /*****************************/
 13: /* definition                */
 14: /*****************************/
 15: 
 16: typedef struct {
 17:         double  Sin, Sout, Sbk;
 18:         double  Tin, Tout;
 19:         double  t;
 20:         double  Sm, Smbk, km, kmo, *w;
 21:         double  Sr, Srbk, kr, kro, ar;
 22:     } Unit;
 23: 
 24: /*****************************/
 25: /* proto type                */
 26: /*****************************/
 27: 
 28: double  drand48();
 29: int     tgene();
 30: int     *tcheck( int );
 31: 
 32: void    fpatchread( char *, char * );
 33: void    fparamread( FILE *, int, int, Unit * );
 34: void    fparamwrite( FILE *, int, int, Unit * );
 35: void    fplotwrite( FILE *, int, Unit *, int );
 36: void    putbox( int, int, int, int, int );
 37: void    putpat( int, int, char * );
 38: void    putres( int, int, int, Unit *, int * );
 39: void    disparam( int, int, Unit *, int, Unit * );
 40: 
 41: Unit    *generate( int );
 42: void    conect( Unit *, int, Unit *, int );
 43: void    travel( Unit *, int, Unit *, int );
 44: void    learn( Unit *, int, Unit *, int );
 45: void    backp( Unit *, int, Unit *, int );
 46: double  sigmoid( double );
 47: double  difmoid( double );
 48: 
 49: /*****************************/
 50: /* main program              */
 51: /*****************************/
 52: 
 53: double  pa = 0.1;
 54: double  pb = 0.1;
 55: double  epsilon = 1.0;
 56: double  ps = 1.0;
 57: 
 58: int     tcn;                    /* 学習時の入力パターンの周期 */
 59: int     tin[120];               /* 学習時の入力パターン */
 60: int     tch[120][2];            /* 学習時の入力パターンに対する教師信号 */
 61: char    pac[6];                 /* パターンの文字 */
 62: char    pat[12][100];           /* パターンの内容 */
 63: 
 64: void main( int argc, char *argv[] ) {
 65:         int     i, j, k, n, gd, gm;
 66:         float   f;
 67:         double  err = 0;
 68:         char    buf[20], fbuf[20], fcnn[20], w1st = 0;
 69:         FILE    *fp;
 70: 
 71:         int     unum1, unum2, unum3;
 72:         Unit    *unit1, *unit2, *unit3;
 73: 
 74:         /* x = 1060 =  5 + 55 * 6 + 6 * 120 + 5 > 940 */
 75:         /* y =  362 = 15 + 55 * 2 + 120 + 5 + 107 + 5 */
 76:         initgraph( &gd, &gm, "1060x362" );
 77:         settextfont( "6x12" );
 78: 
 79:         unit1 = generate( unum1 = 100 );                /* 入力層 */
 80:         unit2 = generate( unum2 = 50 );                 /* 中間層 */
 81:         unit3 = generate( unum3 = 2 );                  /* 出力層 */
 82:         conect( unit1, unum1, unit2, unum2 );
 83:         conect( unit2, unum2, unit3, unum3 );
 84: 
 85:         sprintf( fcnn, "cnn%s.", ( argc > 1 ? argv[1] : "0" ) );
 86:         strcpy( fbuf, fcnn );
 87:         strcat( fbuf, "tch" );
 88:         fpatchread( "cnn.pat", fbuf );
 89:         for( i = 0; i < tcn; i++ )
 90:             putres( 335 + i * 6, 0, tin[i], unit3, tch[i] );
 91: 
 92:         for( buf[0] = 0; buf[0] != 'q'; ) {
 93:             printf( "rwsltuadxq? > " );
 94:             scanf( "%s", buf );
 95:             switch( buf[0] ) {
 96:               case 'r':                                 /* ファイル読込 */
 97:                 strcpy( fbuf, fcnn );
 98:                 strcat( fbuf, "wrt" );
 99:                 if( ( fp = fopen( fbuf, "r" ) ) == NULL )
100:                     exit( -1 );
101:                 fparamread( fp, unum1, unum2, unit2 );
102:                 fparamread( fp, unum2, unum3, unit3 );
103:                 fclose( fp );
104:                 break;
105:               case 'w':                                 /* ファイル書出 */
106:                 strcpy( fbuf, fcnn );
107:                 strcat( fbuf, "wrt" );
108:                 if( ( fp = fopen( fbuf, "w" ) ) == NULL )
109:                     exit( -1 );
110:                 fparamwrite( fp, unum1, unum2, unit2 );
111:                 fparamwrite( fp, unum2, unum3, unit3 );
112:                 fclose( fp );
113: 
114:                 strcpy( fbuf, fcnn );
115:                 strcat( fbuf, "plt" );
116:                 if( ( fp = fopen( fbuf, w1st == 0 ? "w" : "a" ) ) == NULL )
117:                     exit( -1 );
118:                 w1st = 1;
119:                 scanf( "%s", buf );
120:                 fprintf( fp, "#_%s_err=%f\n", buf, err );
121:                 fplotwrite( fp, unum2, unit2,  1 );
122:                 fplotwrite( fp, unum3, unit3, 51 );
123:                 fclose( fp );
124:                 printf( "wrote '%s'\n", buf );
125:                 break;
126:               case 's':                                 /* パラメータ設定 */
127:                 scanf( "%s %f", buf, &f );
128:                 switch( buf[0] ) {
129:                   case 'a':
130:                     pa = f;
131:                     break;
132:                   case 'b':
133:                     pb = f;
134:                     break;
135:                   case 's':
136:                     ps = f;
137:                     break;
138:                 }
139:                 printf( "now, alpha = %f, beta = %f", pa, pb );
140:                 printf( ", scale = %f\n", ps );
141:                 break;
142:               case 'l':                                 /* 学習ありループ */
143:               case 't':                                 /* 学習なしループ */
144:                 scanf( "%d", &n );
145:                 for( k = 0; k < n; k++ ) {
146:                     err = 0;
147:                     for( i = 0; i < tcn; i++ ) {
148:                         for( j = 0; j < unum1; j++ )
149:                             unit1[j].Sout = pat[ tin[i] ][j];
150:                         travel( unit1, unum1, unit2, unum2 );
151:                         travel( unit2, unum2, unit3, unum3 );
152:                         for( j = 0; j < unum3; j++ ) {
153:                             unit3[j].Tin = unit3[j].Sout - tch[i][j];
154:                             err += unit3[j].Tin * unit3[j].Tin;
155:                         }
156:                         if( buf[0] == 'l' ) {
157:                             learn( unit2, unum2, unit3, unum3 );
158:                             backp( unit2, unum2, unit3, unum3 );
159:                             learn( unit1, unum1, unit2, unum2 );
160:                         }
161:                         putres( 335 + i * 6, 0, tin[i], unit3, tch[i] );
162:                     }
163:                     if( k % 10 == 9 )
164:                     printf( "loop = %d\n", k + 1 );
165:                 }
166:                 err = sqrt( err / tcn );
167:                 printf( "err = %f\n", err );
168:                 break;
169:               case 'u':                                 /* 使う:自分で入力 */
170:               case 'a':                                 /* 使う:乱数で入力 */
171:                 for( i = 0; i < 175; i++ ) {
172:                     if( buf[0] == 'u' ) {
173:                         printf( "input-charactor : " );
174:                         scanf( "%s", &buf[1] );
175:                         if( !strcmp( &buf[1], "end" ) )
176:                             break;
177:                         for( j = 0; j < 6; j++ )
178:                             if( buf[1] == pac[j] )
179:                                 break;
180:                         if( j == 6 ) {
181:                             printf( "wrong charactor\n" );
182:                             i--;
183:                             continue;
184:                         }
185:                         k = j * 2 + ( buf[2] == 0 ? 0 : 1 );
186:                     } else {
187:                         k = tgene();
188:                     }
189:                     for( j = 0; j < unum1; j++ )
190:                         unit1[j].Sout = pat[k][j];
191:                     travel( unit1, unum1, unit2, unum2 );
192:                     travel( unit2, unum2, unit3, unum3 );
193: 
194:                     putres( 5 + i * 6, 125, k, unit3, tcheck( k ) );
195:                     if( buf[0] == 'u' && i == 174 )
196:                         i = -1;
197:                 }
198:                 break;
199:               case 'd':                                 /* Wの表示 */
200:                 disparam( unum1, unum2, unit2, unum3, unit3 );
201:                 break;
202:               case 'x':                                 /* 画面の保存 */
203:                 strcpy( fbuf, fcnn );
204:                 strcat( fbuf, "xbm" );
205:                 graph2xbm( fbuf );
206:                 break;
207:               case '?':                                 /* ヘルプ */
208:                 printf( " r             : read \n" );
209:                 printf( " w (s)         : write (s:title) \n" );
210:                 printf( " s (a,b,s) (n) : set (a:alpha, b:beta, s:scale) " );
211:                 printf( "(n:value) \n" );
212:                 printf( " l (n)         : loop (n), learn \n" );
213:                 printf( " t (n)         : loop (n), no learn \n" );
214:                 printf( " u             : use, my input \n" );
215:                 printf( " a             : use, random input \n" );
216:                 printf( " d             : display w \n" );
217:                 printf( " x             : save screen \n" );
218:                 printf( " q             : quit \n" );
219:                 printf( " ?             : help \n" );
220:                 break;
221:               default:
222:                 break;
223:             }
224:         }
225:         closegraph();
226: }
227: 
228: int tgene() {
229:         static  int     tk = 0, tf = 0;
230:         static  int     d[6] = { 1, 1, 1, 2, 2, 3 };
231:         tk = ( tk + d[ (int)( drand48() * 6 ) ] ) % 3;
232:         tf = ( drand48() < 0.9 ? tf : 6 - tf );
233:         return tk * 2 + tf + (int)( drand48() * 1.1 );
234: }
235: 
236: int *tcheck( int tin ) {
237:         static  int     tbk[3], res[2];
238:         int     i, j, k, k0, k1, k2;
239:         tbk[0] = tbk[1];
240:         tbk[1] = tbk[2];
241:         tbk[2] = ( tin % 2 == 0 ? tin / 2 : -4 );
242:         for( i = 0; i < 2; i++ ) {
243:             k0 = ( tbk[0] + i * 3 ) % 6;
244:             k1 = ( tbk[1] + i * 3 ) % 6;
245:             k2 = ( tbk[2] + i * 3 ) % 6;
246:             for( j = 0, k = 0; j < 3; j++ ) {
247:                 if( k0 == j && k1 == ( j + 1 ) % 3 && k2 == ( j + 2 ) % 3 )
248:                     k += 1;
249:             }
250:             res[i] = ( k > 0 ? 1 : 0 );
251:         }
252:         return res;
253: }
254: 
255: /*****************************/
256: /* sub functions             */
257: /*****************************/
258: 
259: void fpatchread( char *fnpat, char *fntch ) {
260:         int     i, j;
261:         char    buf[20];
262:         FILE    *fp;
263: 
264:         if( ( fp = fopen( fnpat, "r" ) ) == NULL ) {
265:             printf( "error : can't open the file '%s' \n", fnpat );
266:             exit( -1 );
267:         }
268:         for( i = 0; i < 6; i++ ) {
269:             fscanf( fp, "%s", buf );
270:             pac[i] = buf[0];
271:             for( j = 0; j < 100; j++ ) {
272:                 if( j % 10 == 0 )
273:                     fscanf( fp, "%s", buf );
274:                 pat[i*2  ][j] = ( buf[ j % 10 ] == '#' ? 1 : 0 );
275:                 pat[i*2+1][j] = pat[i*2][j] ^ 1;
276:             }
277:             putpat( i * 55 + 5, 15, pat[i*2  ] );
278:             putpat( i * 55 + 5, 70, pat[i*2+1] );
279:         }
280:         fclose( fp );
281: 
282:         if( ( fp = fopen( fntch, "r" ) ) == NULL ) {
283:             printf( "error : can't open the file '%s' \n", fntch );
284:             exit( -1 );
285:         }
286:         fscanf( fp, "%d", &tcn );
287:         for( i = 0; i < tcn; i++ ) {
288:             fscanf( fp, "%s %d %d", buf, &tch[i][0], &tch[i][1] );
289:             for( j = 0; j < 6; j++ )
290:                 if( buf[0] == pac[j] )
291:                     break;
292:             tin[i] = j * 2 + ( buf[1] == 0 ? 0 : 1 );
293:         }
294:         fclose( fp );
295: }
296: 
297: void fparamread( FILE *fp, int unum_in, int unum, Unit *unit ) {
298:         int     i, s = sizeof( double );
299:         Unit    *u;
300:         for( i = 0; i < unum; i++ ) {
301:             u = &unit[i];
302:             fread( &u->km, s, 1, fp );
303:             fread( &u->kr, s, 1, fp );
304:             fread( &u->ar, s, 1, fp );
305:             fread( &u->t,  s, 1, fp );
306:             fread(  u->w,  s, unum_in, fp );
307:             u->kmo = epsilon * ( log( u->km ) - log( 1 - u->km ) );
308:             u->kro = epsilon * ( log( u->kr ) - log( 1 - u->kr ) );
309:         }
310: }
311: 
312: void fparamwrite( FILE *fp, int unum_in, int unum, Unit *unit ) {
313:         int     i, s = sizeof( double );
314:         Unit    *u;
315:         for( i = 0; i < unum; i++ ) {
316:             u = &unit[i];
317:             fwrite( &u->km, s, 1, fp );
318:             fwrite( &u->kr, s, 1, fp );
319:             fwrite( &u->ar, s, 1, fp );
320:             fwrite( &u->t,  s, 1, fp );
321:             fwrite(  u->w,  s, unum_in, fp );
322:         }
323: }
324: 
325: void fplotwrite( FILE *fp, int unum, Unit *un, int j ) {
326:         int     i;
327:         for( i = 0; i < unum; i++ ) {
328:             fprintf( fp, "%d %f ", i + j, un[i].km );
329:             fprintf( fp, "%f %f\n", un[i].kr, un[i].ar );
330:         }
331: }
332: 
333: void putbox( int sx, int sy, int nx, int ny, int col ) {
334:         if( col ) {
335:             bar( sx, sy, sx + nx, sy + ny );
336:         } else {
337:             rectangle( sx, sy, sx + nx - 1, sy + ny - ( ny != 0 ? 1 : 0 ) );
338:         }
339: }
340: 
341: void putpat( int sx, int sy, char pat[] ) {
342:         int     x, y, i = 0;
343:         for( y = 0; y < 10; y++ )
344:             for( x = 0; x < 10; x++ )
345:                 putbox( sx + x * 5, sy + y * 5, 4, 4, pat[i++] == 1 );
346: }
347: 
348: void putres( int sx, int sy, int tin, Unit *unit, int tch[] ) {
349:         char    buf[] = " ";
350:         int     i, y;
351:         double  d;
352:         setcolor( BLACK );
353:         putbox( sx, sy, 6, 120, 0 == 0 );
354:         setcolor( WHITE );
355:         buf[0] = pac[ tin / 2 ];
356:         outtextxy( sx, sy + 12, buf );
357:         if( tin % 2 == 1 )
358:             putbox( sx, sy + 2, 4, 1, 0 == 0 );
359:         for( i = 0; i < 2; i++ ) {
360:             y = sy + 15 + i * 55;
361:             d = unit[i].Sout;
362:             putbox( sx, y, 5, 4, tch[i] != 0 );
363:             putbox( sx, y + 5, 5, (int)( d * 45 ), d >= 0.5 );
364:         }
365: }
366: 
367: void disparam( int unum1, int unum2, Unit *unit2, int unum3, Unit *unit3 ) {
368:         int     i, j, x, y, sx = 20, sy = 250;
369:         for( i = 0; i < unum2; i++ ) {
370:             x = sx + i * 18;
371:             setcolor( BLACK );
372:             putbox( x, sy, 18, unum1 + 5 + unum3, 0 == 0 );
373:             setcolor( WHITE );
374:             x = x + 9;
375:             for( j = 0; j < unum1; j++ ) {
376:                 y = sy + j;
377:                 bar( x, y, x + (int)( unit2[i].w[j] * 9 / ps ), y + 1 );
378:             }
379:             for( j = 0; j < unum3; j++ ) {
380:                 y = sy + j + unum1 + 5;
381:                 bar( x, y, x + (int)( unit3[j].w[i] * 9 / ps ), y + 1 );
382:             }
383:         }
384: }
385: 
386: /*****************************/
387: /* cnn functions             */
388: /*****************************/
389: 
390: Unit *generate( int n ) {
391:         int     i;
392:         Unit    *unit, *u;
393:         unit = (Unit *)calloc( n, sizeof( Unit ) );
394:         if( unit == NULL ) {
395:             printf( "error : can't allocate memory \n" );
396:             exit( -1 );
397:         }
398:         for( i = 0; i < n; i++ ) {
399:             u = &unit[i];
400:             u->Sin = 0; u->Sout = 0; u->Sbk = 0;
401:             u->Tin = 0; u->Tout = 0;
402:             u->t  = 0;
403:             u->Sm = 0; u->Smbk = 0; u->w  = NULL;
404:             u->Sr = 0; u->Srbk = 0; u->ar = drand48() * 2 + 1;
405:             u->km = sigmoid( u->kmo = drand48() * 3 - 1.5 );
406:             u->kr = sigmoid( u->kro = drand48() * 3 - 1.5 );
407:         }
408:         return unit;
409: }
410: 
411: void conect( Unit *unit_in, int unum_in, Unit *unit, int unum ) {
412:         int     i, j;
413:         Unit    *u;
414:         double  *w;
415:         for( i = 0; i < unum; i++ ) {
416:             u = &unit[i];
417:             w = (double *)calloc( unum_in, sizeof( double ) );
418:             if( w == NULL ) {
419:                 printf( "error : can't allocate memory \n" );
420:                 exit( -1 );
421:             }
422:             for( j = 0; j < unum_in; j++ ) {
423:                 w[j] = drand48() * 1 - 0.5;
424:             }
425:             u->w = w;
426:             u->t = drand48() * 1 - 0.5;
427:         }
428: }
429: 
430: void travel( Unit *unit_in, int unum_in, Unit *unit, int unum ) {
431:         int     i, j;
432:         Unit    *u;
433:         for( i = 0; i < unum; i++ ) {
434:             u = &unit[i];
435:             u->Sbk  = u->Sout;
436:             u->Srbk = u->Sr;
437:             u->Smbk = u->Sm;
438: 
439:             u->Sr = u->kr * u->Sr - u->ar * u->Sout;
440:             u->Sm = u->km * u->Sm;
441:             for( j = 0; j < unum_in; j++ ) {
442:                 u->Sm += u->w[j] * unit_in[j].Sout;
443:             }
444:             u->Sin  = u->Sr + u->Sm + u->t;
445:             u->Sout = sigmoid( u->Sin );
446:         }
447: }
448: 
449: void learn( Unit *unit_in, int unum_in, Unit *unit, int unum ) {
450:         int     i, j;
451:         Unit    *u;
452:         for( i = 0; i < unum; i++ ) {
453:             u = &unit[i];
454:             u->Tout = u->Tin * difmoid( u->Sout );
455:             for( j = 0; j < unum_in; j++ ) {
456:                  u->w[j] -= pa * u->Tout * unit_in[j].Sout;
457:             }
458:             u->t   -= pa * u->Tout;
459:             u->kmo -= pb * u->Tout * u->Smbk * difmoid( u->km );
460:             u->kro -= pb * u->Tout * u->Srbk * difmoid( u->kr );
461:             u->ar  -= pb * u->Tout * ( - u->Sbk );
462:             u->km = sigmoid( u->kmo );
463:             u->kr = sigmoid( u->kro );
464:             if( u->ar < 0 )
465:                 u->ar = 0;
466:         }
467: }
468: 
469: void backp( Unit *unit, int unum, Unit *unit_out, int unum_out ) {
470:         int     i, j;
471:         Unit    *u;
472:         for( i = 0; i < unum; i++ ) {
473:             u = &unit[i];
474:             u->Tin = 0;
475:             for( j = 0; j < unum_out; j++ ) {
476:                 u->Tin += unit_out[j].Tout * unit_out[j].w[i];
477:             }
478:         }
479: }
480: 
481: double sigmoid( double x ) {
482:         return 1 / ( 1 + exp( - x / epsilon ) );
483: }
484: 
485: double difmoid( double x ) {
486:         return x * ( 1 - x ) / epsilon;
487: }
488: 
489: /*****************************/
490: /* program end               */
491: /*****************************/



Deguchi Toshinori
1998年04月01日 (水) 12時03分23秒 JST