ファイル名: "cnn.c"
1: /*************************************/
2: /* 卒業研究プログラム 「cnn.c」 */
3: /* 平成9年度 電気工学科5年16番 */
4: /*************************************/
5:
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <malloc.h>
#include <local/bgi.h>
11:
12: /*****************************/
13: /* definition */
14: /*****************************/
15:
/*
 * One neuron.  Field order is kept unchanged; parameters are stored on disk
 * field-by-field (see fparamwrite), not as a raw struct image.
 */
typedef struct {
    /* activation */
    double Sin;     /* net input, Sr + Sm + t (computed in travel) */
    double Sout;    /* output, sigmoid( Sin ); set directly for input-layer units */
    double Sbk;     /* Sout from the previous travel() step */

    /* error signals and bias */
    double Tin;     /* error reaching the unit: Sout - teacher at the output, from backp() otherwise */
    double Tout;    /* Tin * difmoid( Sout ), computed in learn() */
    double t;       /* bias added to Sin, adjusted by learn() */

    /* decaying input term */
    double Sm;      /* km * Sm + sum of w[j] * input[j].Sout */
    double Smbk;    /* Sm from the previous step */
    double km;      /* decay factor of Sm, kept equal to sigmoid( kmo ) */
    double kmo;     /* unconstrained parameter adjusted by learn() to move km */
    double *w;      /* incoming weights, one per unit of the previous layer; NULL until conect() */

    /* self-feedback term (presumably the refractoriness of a chaotic neuron model) */
    double Sr;      /* kr * Sr - ar * Sout */
    double Srbk;    /* Sr from the previous step */
    double kr;      /* decay factor of Sr, kept equal to sigmoid( kro ) */
    double kro;     /* unconstrained parameter adjusted by learn() to move kr */
    double ar;      /* strength of the self-feedback, clamped to >= 0 by learn() */
} Unit;
23:
24: /*****************************/
25: /* proto type */
26: /*****************************/
27:
28: double drand48();
29: int tgene();
30: int *tcheck( int );
31:
32: void fpatchread( char *, char * );
33: void fparamread( FILE *, int, int, Unit * );
34: void fparamwrite( FILE *, int, int, Unit * );
35: void fplotwrite( FILE *, int, Unit *, int );
36: void putbox( int, int, int, int, int );
37: void putpat( int, int, char * );
38: void putres( int, int, int, Unit *, int * );
39: void disparam( int, int, Unit *, int, Unit * );
40:
41: Unit *generate( int );
42: void conect( Unit *, int, Unit *, int );
43: void travel( Unit *, int, Unit *, int );
44: void learn( Unit *, int, Unit *, int );
45: void backp( Unit *, int, Unit *, int );
46: double sigmoid( double );
47: double difmoid( double );
48:
49: /*****************************/
50: /* main program */
51: /*****************************/
52:
/* Learning rates and display scale; pa, pb and ps are changed with the 's' command. */
double pa = 0.1;        /* alpha: step size for the weights w and the bias t */
double pb = 0.1;        /* beta: step size for kmo, kro and ar */
double ps = 1.0;        /* divisor applied to weights when disparam() draws them */
double epsilon = 1.0;   /* sigmoid temperature: sigmoid(x) = 1 / (1 + exp(-x / epsilon)) */

/* Training sequence, filled by fpatchread(). */
int tcn;                /* number of samples in the training sequence */
int tin[120];           /* pattern index (into pat[]) of each training sample */
int tch[120][2];        /* teacher signal of each sample, one value per output unit */

/* Pattern set, filled by fpatchread(). */
char pac[6];            /* label character of each of the six base patterns */
char pat[12][100];      /* 10x10 cells; pat[2i] is pattern i, pat[2i+1] its complement */
63:
64: void main( int argc, char *argv[] ) {
65: int i, j, k, n, gd, gm;
66: float f;
67: double err = 0;
68: char buf[20], fbuf[20], fcnn[20], w1st = 0;
69: FILE *fp;
70:
71: int unum1, unum2, unum3;
72: Unit *unit1, *unit2, *unit3;
73:
74: /* x = 1060 = 5 + 55 * 6 + 6 * 120 + 5 > 940 */
75: /* y = 362 = 15 + 55 * 2 + 120 + 5 + 107 + 5 */
76: initgraph( &gd, &gm, "1060x362" );
77: settextfont( "6x12" );
78:
79: unit1 = generate( unum1 = 100 ); /* 入力層 */
80: unit2 = generate( unum2 = 50 ); /* 中間層 */
81: unit3 = generate( unum3 = 2 ); /* 出力層 */
82: conect( unit1, unum1, unit2, unum2 );
83: conect( unit2, unum2, unit3, unum3 );
84:
85: sprintf( fcnn, "cnn%s.", ( argc > 1 ? argv[1] : "0" ) );
86: strcpy( fbuf, fcnn );
87: strcat( fbuf, "tch" );
88: fpatchread( "cnn.pat", fbuf );
89: for( i = 0; i < tcn; i++ )
90: putres( 335 + i * 6, 0, tin[i], unit3, tch[i] );
91:
92: for( buf[0] = 0; buf[0] != 'q'; ) {
93: printf( "rwsltuadxq? > " );
94: scanf( "%s", buf );
95: switch( buf[0] ) {
96: case 'r': /* ファイル読込 */
97: strcpy( fbuf, fcnn );
98: strcat( fbuf, "wrt" );
99: if( ( fp = fopen( fbuf, "r" ) ) == NULL )
100: exit( -1 );
101: fparamread( fp, unum1, unum2, unit2 );
102: fparamread( fp, unum2, unum3, unit3 );
103: fclose( fp );
104: break;
105: case 'w': /* ファイル書出 */
106: strcpy( fbuf, fcnn );
107: strcat( fbuf, "wrt" );
108: if( ( fp = fopen( fbuf, "w" ) ) == NULL )
109: exit( -1 );
110: fparamwrite( fp, unum1, unum2, unit2 );
111: fparamwrite( fp, unum2, unum3, unit3 );
112: fclose( fp );
113:
114: strcpy( fbuf, fcnn );
115: strcat( fbuf, "plt" );
116: if( ( fp = fopen( fbuf, w1st == 0 ? "w" : "a" ) ) == NULL )
117: exit( -1 );
118: w1st = 1;
119: scanf( "%s", buf );
120: fprintf( fp, "#_%s_err=%f\n", buf, err );
121: fplotwrite( fp, unum2, unit2, 1 );
122: fplotwrite( fp, unum3, unit3, 51 );
123: fclose( fp );
124: printf( "wrote '%s'\n", buf );
125: break;
126: case 's': /* パラメータ設定 */
127: scanf( "%s %f", buf, &f );
128: switch( buf[0] ) {
129: case 'a':
130: pa = f;
131: break;
132: case 'b':
133: pb = f;
134: break;
135: case 's':
136: ps = f;
137: break;
138: }
139: printf( "now, alpha = %f, beta = %f", pa, pb );
140: printf( ", scale = %f\n", ps );
141: break;
142: case 'l': /* 学習ありループ */
143: case 't': /* 学習なしループ */
144: scanf( "%d", &n );
145: for( k = 0; k < n; k++ ) {
146: err = 0;
147: for( i = 0; i < tcn; i++ ) {
148: for( j = 0; j < unum1; j++ )
149: unit1[j].Sout = pat[ tin[i] ][j];
150: travel( unit1, unum1, unit2, unum2 );
151: travel( unit2, unum2, unit3, unum3 );
152: for( j = 0; j < unum3; j++ ) {
153: unit3[j].Tin = unit3[j].Sout - tch[i][j];
154: err += unit3[j].Tin * unit3[j].Tin;
155: }
156: if( buf[0] == 'l' ) {
157: learn( unit2, unum2, unit3, unum3 );
158: backp( unit2, unum2, unit3, unum3 );
159: learn( unit1, unum1, unit2, unum2 );
160: }
161: putres( 335 + i * 6, 0, tin[i], unit3, tch[i] );
162: }
163: if( k % 10 == 9 )
164: printf( "loop = %d\n", k + 1 );
165: }
166: err = sqrt( err / tcn );
167: printf( "err = %f\n", err );
168: break;
169: case 'u': /* 使う:自分で入力 */
170: case 'a': /* 使う:乱数で入力 */
171: for( i = 0; i < 175; i++ ) {
172: if( buf[0] == 'u' ) {
173: printf( "input-charactor : " );
174: scanf( "%s", &buf[1] );
175: if( !strcmp( &buf[1], "end" ) )
176: break;
177: for( j = 0; j < 6; j++ )
178: if( buf[1] == pac[j] )
179: break;
180: if( j == 6 ) {
181: printf( "wrong charactor\n" );
182: i--;
183: continue;
184: }
185: k = j * 2 + ( buf[2] == 0 ? 0 : 1 );
186: } else {
187: k = tgene();
188: }
189: for( j = 0; j < unum1; j++ )
190: unit1[j].Sout = pat[k][j];
191: travel( unit1, unum1, unit2, unum2 );
192: travel( unit2, unum2, unit3, unum3 );
193:
194: putres( 5 + i * 6, 125, k, unit3, tcheck( k ) );
195: if( buf[0] == 'u' && i == 174 )
196: i = -1;
197: }
198: break;
199: case 'd': /* Wの表示 */
200: disparam( unum1, unum2, unit2, unum3, unit3 );
201: break;
202: case 'x': /* 画面の保存 */
203: strcpy( fbuf, fcnn );
204: strcat( fbuf, "xbm" );
205: graph2xbm( fbuf );
206: break;
207: case '?': /* ヘルプ */
208: printf( " r : read \n" );
209: printf( " w (s) : write (s:title) \n" );
210: printf( " s (a,b,s) (n) : set (a:alpha, b:beta, s:scale) " );
211: printf( "(n:value) \n" );
212: printf( " l (n) : loop (n), learn \n" );
213: printf( " t (n) : loop (n), no learn \n" );
214: printf( " u : use, my input \n" );
215: printf( " a : use, random input \n" );
216: printf( " d : display w \n" );
217: printf( " x : save screen \n" );
218: printf( " q : quit \n" );
219: printf( " ? : help \n" );
220: break;
221: default:
222: break;
223: }
224: }
225: closegraph();
226: }
227:
228: int tgene() {
229: static int tk = 0, tf = 0;
230: static int d[6] = { 1, 1, 1, 2, 2, 3 };
231: tk = ( tk + d[ (int)( drand48() * 6 ) ] ) % 3;
232: tf = ( drand48() < 0.9 ? tf : 6 - tf );
233: return tk * 2 + tf + (int)( drand48() * 1.1 );
234: }
235:
/*
 * Teacher signal for the "use" modes, computed from the last three pattern
 * indices passed in (the history starts as three copies of pattern 0).
 *  - res[0] is 1 when those three are cyclically consecutive characters
 *    0,1,2 / 1,2,0 / 2,0,1 of the first group pac[0..2].
 *  - res[1] is the same test for the second group pac[3..5].
 *  - A complemented (odd) index is recorded as -4, which never matches, so
 *    it breaks any run.
 * Returns a pointer to static storage that the next call overwrites.
 */
int *tcheck( int tin ) {
    static int hist[3];
    static int res[2];
    int grp;

    hist[0] = hist[1];
    hist[1] = hist[2];
    if( tin % 2 == 0 )
        hist[2] = tin / 2;
    else
        hist[2] = -4;

    for( grp = 0; grp < 2; grp++ ) {
        /* shifting by 3 maps the second group onto positions 0..2 */
        int a = ( hist[0] + grp * 3 ) % 6;
        int b = ( hist[1] + grp * 3 ) % 6;
        int c = ( hist[2] + grp * 3 ) % 6;
        int hit = 0;
        int start;

        for( start = 0; start < 3 && !hit; start++ )
            hit = ( a == start && b == ( start + 1 ) % 3 && c == ( start + 2 ) % 3 );
        res[grp] = hit;
    }
    return res;
}
254:
255: /*****************************/
256: /* sub functions */
257: /*****************************/
258:
259: void fpatchread( char *fnpat, char *fntch ) {
260: int i, j;
261: char buf[20];
262: FILE *fp;
263:
264: if( ( fp = fopen( fnpat, "r" ) ) == NULL ) {
265: printf( "error : can't open the file '%s' \n", fnpat );
266: exit( -1 );
267: }
268: for( i = 0; i < 6; i++ ) {
269: fscanf( fp, "%s", buf );
270: pac[i] = buf[0];
271: for( j = 0; j < 100; j++ ) {
272: if( j % 10 == 0 )
273: fscanf( fp, "%s", buf );
274: pat[i*2 ][j] = ( buf[ j % 10 ] == '#' ? 1 : 0 );
275: pat[i*2+1][j] = pat[i*2][j] ^ 1;
276: }
277: putpat( i * 55 + 5, 15, pat[i*2 ] );
278: putpat( i * 55 + 5, 70, pat[i*2+1] );
279: }
280: fclose( fp );
281:
282: if( ( fp = fopen( fntch, "r" ) ) == NULL ) {
283: printf( "error : can't open the file '%s' \n", fntch );
284: exit( -1 );
285: }
286: fscanf( fp, "%d", &tcn );
287: for( i = 0; i < tcn; i++ ) {
288: fscanf( fp, "%s %d %d", buf, &tch[i][0], &tch[i][1] );
289: for( j = 0; j < 6; j++ )
290: if( buf[0] == pac[j] )
291: break;
292: tin[i] = j * 2 + ( buf[1] == 0 ? 0 : 1 );
293: }
294: fclose( fp );
295: }
296:
297: void fparamread( FILE *fp, int unum_in, int unum, Unit *unit ) {
298: int i, s = sizeof( double );
299: Unit *u;
300: for( i = 0; i < unum; i++ ) {
301: u = &unit[i];
302: fread( &u->km, s, 1, fp );
303: fread( &u->kr, s, 1, fp );
304: fread( &u->ar, s, 1, fp );
305: fread( &u->t, s, 1, fp );
306: fread( u->w, s, unum_in, fp );
307: u->kmo = epsilon * ( log( u->km ) - log( 1 - u->km ) );
308: u->kro = epsilon * ( log( u->kr ) - log( 1 - u->kr ) );
309: }
310: }
311:
312: void fparamwrite( FILE *fp, int unum_in, int unum, Unit *unit ) {
313: int i, s = sizeof( double );
314: Unit *u;
315: for( i = 0; i < unum; i++ ) {
316: u = &unit[i];
317: fwrite( &u->km, s, 1, fp );
318: fwrite( &u->kr, s, 1, fp );
319: fwrite( &u->ar, s, 1, fp );
320: fwrite( &u->t, s, 1, fp );
321: fwrite( u->w, s, unum_in, fp );
322: }
323: }
324:
325: void fplotwrite( FILE *fp, int unum, Unit *un, int j ) {
326: int i;
327: for( i = 0; i < unum; i++ ) {
328: fprintf( fp, "%d %f ", i + j, un[i].km );
329: fprintf( fp, "%f %f\n", un[i].kr, un[i].ar );
330: }
331: }
332:
/*
 * Draw a box at (sx, sy) in the current colour.  When col is non-zero it is
 * filled with bar() up to (sx + nx, sy + ny).  Otherwise it is outlined with
 * rectangle(): the right edge stops one pixel short, and so does the bottom
 * edge when ny is non-zero.
 */
void putbox( int sx, int sy, int nx, int ny, int col ) {
    int right, bottom;

    if( col != 0 ) {
        bar( sx, sy, sx + nx, sy + ny );
        return;
    }
    right = sx + nx - 1;
    bottom = ( ny == 0 ) ? sy : sy + ny - 1;
    rectangle( sx, sy, right, bottom );
}
340:
/*
 * Draw a 10x10 pattern (row-major, 100 cells) with its top-left corner at
 * (sx, sy).  Each cell is a 4x4 box on a 5-pixel grid, filled when the cell
 * value is 1.
 */
void putpat( int sx, int sy, char pat[] ) {
    int cell;
    for( cell = 0; cell < 100; cell++ )
        putbox( sx + ( cell % 10 ) * 5, sy + ( cell / 10 ) * 5, 4, 4, pat[cell] == 1 );
}
347:
348: void putres( int sx, int sy, int tin, Unit *unit, int tch[] ) {
349: char buf[] = " ";
350: int i, y;
351: double d;
352: setcolor( BLACK );
353: putbox( sx, sy, 6, 120, 0 == 0 );
354: setcolor( WHITE );
355: buf[0] = pac[ tin / 2 ];
356: outtextxy( sx, sy + 12, buf );
357: if( tin % 2 == 1 )
358: putbox( sx, sy + 2, 4, 1, 0 == 0 );
359: for( i = 0; i < 2; i++ ) {
360: y = sy + 15 + i * 55;
361: d = unit[i].Sout;
362: putbox( sx, y, 5, 4, tch[i] != 0 );
363: putbox( sx, y + 5, 5, (int)( d * 45 ), d >= 0.5 );
364: }
365: }
366:
367: void disparam( int unum1, int unum2, Unit *unit2, int unum3, Unit *unit3 ) {
368: int i, j, x, y, sx = 20, sy = 250;
369: for( i = 0; i < unum2; i++ ) {
370: x = sx + i * 18;
371: setcolor( BLACK );
372: putbox( x, sy, 18, unum1 + 5 + unum3, 0 == 0 );
373: setcolor( WHITE );
374: x = x + 9;
375: for( j = 0; j < unum1; j++ ) {
376: y = sy + j;
377: bar( x, y, x + (int)( unit2[i].w[j] * 9 / ps ), y + 1 );
378: }
379: for( j = 0; j < unum3; j++ ) {
380: y = sy + j + unum1 + 5;
381: bar( x, y, x + (int)( unit3[j].w[i] * 9 / ps ), y + 1 );
382: }
383: }
384: }
385:
386: /*****************************/
387: /* cnn functions */
388: /*****************************/
389:
390: Unit *generate( int n ) {
391: int i;
392: Unit *unit, *u;
393: unit = (Unit *)calloc( n, sizeof( Unit ) );
394: if( unit == NULL ) {
395: printf( "error : can't allocate memory \n" );
396: exit( -1 );
397: }
398: for( i = 0; i < n; i++ ) {
399: u = &unit[i];
400: u->Sin = 0; u->Sout = 0; u->Sbk = 0;
401: u->Tin = 0; u->Tout = 0;
402: u->t = 0;
403: u->Sm = 0; u->Smbk = 0; u->w = NULL;
404: u->Sr = 0; u->Srbk = 0; u->ar = drand48() * 2 + 1;
405: u->km = sigmoid( u->kmo = drand48() * 3 - 1.5 );
406: u->kr = sigmoid( u->kro = drand48() * 3 - 1.5 );
407: }
408: return unit;
409: }
410:
411: void conect( Unit *unit_in, int unum_in, Unit *unit, int unum ) {
412: int i, j;
413: Unit *u;
414: double *w;
415: for( i = 0; i < unum; i++ ) {
416: u = &unit[i];
417: w = (double *)calloc( unum_in, sizeof( double ) );
418: if( w == NULL ) {
419: printf( "error : can't allocate memory \n" );
420: exit( -1 );
421: }
422: for( j = 0; j < unum_in; j++ ) {
423: w[j] = drand48() * 1 - 0.5;
424: }
425: u->w = w;
426: u->t = drand48() * 1 - 0.5;
427: }
428: }
429:
430: void travel( Unit *unit_in, int unum_in, Unit *unit, int unum ) {
431: int i, j;
432: Unit *u;
433: for( i = 0; i < unum; i++ ) {
434: u = &unit[i];
435: u->Sbk = u->Sout;
436: u->Srbk = u->Sr;
437: u->Smbk = u->Sm;
438:
439: u->Sr = u->kr * u->Sr - u->ar * u->Sout;
440: u->Sm = u->km * u->Sm;
441: for( j = 0; j < unum_in; j++ ) {
442: u->Sm += u->w[j] * unit_in[j].Sout;
443: }
444: u->Sin = u->Sr + u->Sm + u->t;
445: u->Sout = sigmoid( u->Sin );
446: }
447: }
448:
449: void learn( Unit *unit_in, int unum_in, Unit *unit, int unum ) {
450: int i, j;
451: Unit *u;
452: for( i = 0; i < unum; i++ ) {
453: u = &unit[i];
454: u->Tout = u->Tin * difmoid( u->Sout );
455: for( j = 0; j < unum_in; j++ ) {
456: u->w[j] -= pa * u->Tout * unit_in[j].Sout;
457: }
458: u->t -= pa * u->Tout;
459: u->kmo -= pb * u->Tout * u->Smbk * difmoid( u->km );
460: u->kro -= pb * u->Tout * u->Srbk * difmoid( u->kr );
461: u->ar -= pb * u->Tout * ( - u->Sbk );
462: u->km = sigmoid( u->kmo );
463: u->kr = sigmoid( u->kro );
464: if( u->ar < 0 )
465: u->ar = 0;
466: }
467: }
468:
469: void backp( Unit *unit, int unum, Unit *unit_out, int unum_out ) {
470: int i, j;
471: Unit *u;
472: for( i = 0; i < unum; i++ ) {
473: u = &unit[i];
474: u->Tin = 0;
475: for( j = 0; j < unum_out; j++ ) {
476: u->Tin += unit_out[j].Tout * unit_out[j].w[i];
477: }
478: }
479: }
480:
481: double sigmoid( double x ) {
482: return 1 / ( 1 + exp( - x / epsilon ) );
483: }
484:
485: double difmoid( double x ) {
486: return x * ( 1 - x ) / epsilon;
487: }
488:
489: /*****************************/
490: /* program end */
491: /*****************************/