ファイル名: "cnn.c"
1: /*************************************/ 2: /* 卒業研究プログラム 「cnn.c」 */ 3: /* 平成9年度 電気工学科5年16番 */ 4: /*************************************/ 5: 6: #include <math.h> 7: #include <stdio.h> 8: #include <stdlib.h> 9: #include <malloc.h> 10: #include <local/bgi.h> 11: 12: /*****************************/ 13: /* definition */ 14: /*****************************/ 15: 16: typedef struct { 17: double Sin, Sout, Sbk; 18: double Tin, Tout; 19: double t; 20: double Sm, Smbk, km, kmo, *w; 21: double Sr, Srbk, kr, kro, ar; 22: } Unit; 23: 24: /*****************************/ 25: /* proto type */ 26: /*****************************/ 27: 28: double drand48(); 29: int tgene(); 30: int *tcheck( int ); 31: 32: void fpatchread( char *, char * ); 33: void fparamread( FILE *, int, int, Unit * ); 34: void fparamwrite( FILE *, int, int, Unit * ); 35: void fplotwrite( FILE *, int, Unit *, int ); 36: void putbox( int, int, int, int, int ); 37: void putpat( int, int, char * ); 38: void putres( int, int, int, Unit *, int * ); 39: void disparam( int, int, Unit *, int, Unit * ); 40: 41: Unit *generate( int ); 42: void conect( Unit *, int, Unit *, int ); 43: void travel( Unit *, int, Unit *, int ); 44: void learn( Unit *, int, Unit *, int ); 45: void backp( Unit *, int, Unit *, int ); 46: double sigmoid( double ); 47: double difmoid( double ); 48: 49: /*****************************/ 50: /* main program */ 51: /*****************************/ 52: 53: double pa = 0.1; 54: double pb = 0.1; 55: double epsilon = 1.0; 56: double ps = 1.0; 57: 58: int tcn; /* 学習時の入力パターンの周期 */ 59: int tin[120]; /* 学習時の入力パターン */ 60: int tch[120][2]; /* 学習時の入力パターンに対する教師信号 */ 61: char pac[6]; /* パターンの文字 */ 62: char pat[12][100]; /* パターンの内容 */ 63: 64: void main( int argc, char *argv[] ) { 65: int i, j, k, n, gd, gm; 66: float f; 67: double err = 0; 68: char buf[20], fbuf[20], fcnn[20], w1st = 0; 69: FILE *fp; 70: 71: int unum1, unum2, unum3; 72: Unit *unit1, *unit2, *unit3; 73: 74: /* x = 1060 = 5 + 55 * 6 + 6 * 120 + 5 > 940 */ 
75: /* y = 362 = 15 + 55 * 2 + 120 + 5 + 107 + 5 */ 76: initgraph( &gd, &gm, "1060x362" ); 77: settextfont( "6x12" ); 78: 79: unit1 = generate( unum1 = 100 ); /* 入力層 */ 80: unit2 = generate( unum2 = 50 ); /* 中間層 */ 81: unit3 = generate( unum3 = 2 ); /* 出力層 */ 82: conect( unit1, unum1, unit2, unum2 ); 83: conect( unit2, unum2, unit3, unum3 ); 84: 85: sprintf( fcnn, "cnn%s.", ( argc > 1 ? argv[1] : "0" ) ); 86: strcpy( fbuf, fcnn ); 87: strcat( fbuf, "tch" ); 88: fpatchread( "cnn.pat", fbuf ); 89: for( i = 0; i < tcn; i++ ) 90: putres( 335 + i * 6, 0, tin[i], unit3, tch[i] ); 91: 92: for( buf[0] = 0; buf[0] != 'q'; ) { 93: printf( "rwsltuadxq? > " ); 94: scanf( "%s", buf ); 95: switch( buf[0] ) { 96: case 'r': /* ファイル読込 */ 97: strcpy( fbuf, fcnn ); 98: strcat( fbuf, "wrt" ); 99: if( ( fp = fopen( fbuf, "r" ) ) == NULL ) 100: exit( -1 ); 101: fparamread( fp, unum1, unum2, unit2 ); 102: fparamread( fp, unum2, unum3, unit3 ); 103: fclose( fp ); 104: break; 105: case 'w': /* ファイル書出 */ 106: strcpy( fbuf, fcnn ); 107: strcat( fbuf, "wrt" ); 108: if( ( fp = fopen( fbuf, "w" ) ) == NULL ) 109: exit( -1 ); 110: fparamwrite( fp, unum1, unum2, unit2 ); 111: fparamwrite( fp, unum2, unum3, unit3 ); 112: fclose( fp ); 113: 114: strcpy( fbuf, fcnn ); 115: strcat( fbuf, "plt" ); 116: if( ( fp = fopen( fbuf, w1st == 0 ? 
"w" : "a" ) ) == NULL ) 117: exit( -1 ); 118: w1st = 1; 119: scanf( "%s", buf ); 120: fprintf( fp, "#_%s_err=%f\n", buf, err ); 121: fplotwrite( fp, unum2, unit2, 1 ); 122: fplotwrite( fp, unum3, unit3, 51 ); 123: fclose( fp ); 124: printf( "wrote '%s'\n", buf ); 125: break; 126: case 's': /* パラメータ設定 */ 127: scanf( "%s %f", buf, &f ); 128: switch( buf[0] ) { 129: case 'a': 130: pa = f; 131: break; 132: case 'b': 133: pb = f; 134: break; 135: case 's': 136: ps = f; 137: break; 138: } 139: printf( "now, alpha = %f, beta = %f", pa, pb ); 140: printf( ", scale = %f\n", ps ); 141: break; 142: case 'l': /* 学習ありループ */ 143: case 't': /* 学習なしループ */ 144: scanf( "%d", &n ); 145: for( k = 0; k < n; k++ ) { 146: err = 0; 147: for( i = 0; i < tcn; i++ ) { 148: for( j = 0; j < unum1; j++ ) 149: unit1[j].Sout = pat[ tin[i] ][j]; 150: travel( unit1, unum1, unit2, unum2 ); 151: travel( unit2, unum2, unit3, unum3 ); 152: for( j = 0; j < unum3; j++ ) { 153: unit3[j].Tin = unit3[j].Sout - tch[i][j]; 154: err += unit3[j].Tin * unit3[j].Tin; 155: } 156: if( buf[0] == 'l' ) { 157: learn( unit2, unum2, unit3, unum3 ); 158: backp( unit2, unum2, unit3, unum3 ); 159: learn( unit1, unum1, unit2, unum2 ); 160: } 161: putres( 335 + i * 6, 0, tin[i], unit3, tch[i] ); 162: } 163: if( k % 10 == 9 ) 164: printf( "loop = %d\n", k + 1 ); 165: } 166: err = sqrt( err / tcn ); 167: printf( "err = %f\n", err ); 168: break; 169: case 'u': /* 使う:自分で入力 */ 170: case 'a': /* 使う:乱数で入力 */ 171: for( i = 0; i < 175; i++ ) { 172: if( buf[0] == 'u' ) { 173: printf( "input-charactor : " ); 174: scanf( "%s", &buf[1] ); 175: if( !strcmp( &buf[1], "end" ) ) 176: break; 177: for( j = 0; j < 6; j++ ) 178: if( buf[1] == pac[j] ) 179: break; 180: if( j == 6 ) { 181: printf( "wrong charactor\n" ); 182: i--; 183: continue; 184: } 185: k = j * 2 + ( buf[2] == 0 ? 
0 : 1 ); 186: } else { 187: k = tgene(); 188: } 189: for( j = 0; j < unum1; j++ ) 190: unit1[j].Sout = pat[k][j]; 191: travel( unit1, unum1, unit2, unum2 ); 192: travel( unit2, unum2, unit3, unum3 ); 193: 194: putres( 5 + i * 6, 125, k, unit3, tcheck( k ) ); 195: if( buf[0] == 'u' && i == 174 ) 196: i = -1; 197: } 198: break; 199: case 'd': /* Wの表示 */ 200: disparam( unum1, unum2, unit2, unum3, unit3 ); 201: break; 202: case 'x': /* 画面の保存 */ 203: strcpy( fbuf, fcnn ); 204: strcat( fbuf, "xbm" ); 205: graph2xbm( fbuf ); 206: break; 207: case '?': /* ヘルプ */ 208: printf( " r : read \n" ); 209: printf( " w (s) : write (s:title) \n" ); 210: printf( " s (a,b,s) (n) : set (a:alpha, b:beta, s:scale) " ); 211: printf( "(n:value) \n" ); 212: printf( " l (n) : loop (n), learn \n" ); 213: printf( " t (n) : loop (n), no learn \n" ); 214: printf( " u : use, my input \n" ); 215: printf( " a : use, random input \n" ); 216: printf( " d : display w \n" ); 217: printf( " x : save screen \n" ); 218: printf( " q : quit \n" ); 219: printf( " ? : help \n" ); 220: break; 221: default: 222: break; 223: } 224: } 225: closegraph(); 226: } 227: 228: int tgene() { 229: static int tk = 0, tf = 0; 230: static int d[6] = { 1, 1, 1, 2, 2, 3 }; 231: tk = ( tk + d[ (int)( drand48() * 6 ) ] ) % 3; 232: tf = ( drand48() < 0.9 ? tf : 6 - tf ); 233: return tk * 2 + tf + (int)( drand48() * 1.1 ); 234: } 235: 236: int *tcheck( int tin ) { 237: static int tbk[3], res[2]; 238: int i, j, k, k0, k1, k2; 239: tbk[0] = tbk[1]; 240: tbk[1] = tbk[2]; 241: tbk[2] = ( tin % 2 == 0 ? tin / 2 : -4 ); 242: for( i = 0; i < 2; i++ ) { 243: k0 = ( tbk[0] + i * 3 ) % 6; 244: k1 = ( tbk[1] + i * 3 ) % 6; 245: k2 = ( tbk[2] + i * 3 ) % 6; 246: for( j = 0, k = 0; j < 3; j++ ) { 247: if( k0 == j && k1 == ( j + 1 ) % 3 && k2 == ( j + 2 ) % 3 ) 248: k += 1; 249: } 250: res[i] = ( k > 0 ? 
1 : 0 ); 251: } 252: return res; 253: } 254: 255: /*****************************/ 256: /* sub functions */ 257: /*****************************/ 258: 259: void fpatchread( char *fnpat, char *fntch ) { 260: int i, j; 261: char buf[20]; 262: FILE *fp; 263: 264: if( ( fp = fopen( fnpat, "r" ) ) == NULL ) { 265: printf( "error : can't open the file '%s' \n", fnpat ); 266: exit( -1 ); 267: } 268: for( i = 0; i < 6; i++ ) { 269: fscanf( fp, "%s", buf ); 270: pac[i] = buf[0]; 271: for( j = 0; j < 100; j++ ) { 272: if( j % 10 == 0 ) 273: fscanf( fp, "%s", buf ); 274: pat[i*2 ][j] = ( buf[ j % 10 ] == '#' ? 1 : 0 ); 275: pat[i*2+1][j] = pat[i*2][j] ^ 1; 276: } 277: putpat( i * 55 + 5, 15, pat[i*2 ] ); 278: putpat( i * 55 + 5, 70, pat[i*2+1] ); 279: } 280: fclose( fp ); 281: 282: if( ( fp = fopen( fntch, "r" ) ) == NULL ) { 283: printf( "error : can't open the file '%s' \n", fntch ); 284: exit( -1 ); 285: } 286: fscanf( fp, "%d", &tcn ); 287: for( i = 0; i < tcn; i++ ) { 288: fscanf( fp, "%s %d %d", buf, &tch[i][0], &tch[i][1] ); 289: for( j = 0; j < 6; j++ ) 290: if( buf[0] == pac[j] ) 291: break; 292: tin[i] = j * 2 + ( buf[1] == 0 ? 
0 : 1 ); 293: } 294: fclose( fp ); 295: } 296: 297: void fparamread( FILE *fp, int unum_in, int unum, Unit *unit ) { 298: int i, s = sizeof( double ); 299: Unit *u; 300: for( i = 0; i < unum; i++ ) { 301: u = &unit[i]; 302: fread( &u->km, s, 1, fp ); 303: fread( &u->kr, s, 1, fp ); 304: fread( &u->ar, s, 1, fp ); 305: fread( &u->t, s, 1, fp ); 306: fread( u->w, s, unum_in, fp ); 307: u->kmo = epsilon * ( log( u->km ) - log( 1 - u->km ) ); 308: u->kro = epsilon * ( log( u->kr ) - log( 1 - u->kr ) ); 309: } 310: } 311: 312: void fparamwrite( FILE *fp, int unum_in, int unum, Unit *unit ) { 313: int i, s = sizeof( double ); 314: Unit *u; 315: for( i = 0; i < unum; i++ ) { 316: u = &unit[i]; 317: fwrite( &u->km, s, 1, fp ); 318: fwrite( &u->kr, s, 1, fp ); 319: fwrite( &u->ar, s, 1, fp ); 320: fwrite( &u->t, s, 1, fp ); 321: fwrite( u->w, s, unum_in, fp ); 322: } 323: } 324: 325: void fplotwrite( FILE *fp, int unum, Unit *un, int j ) { 326: int i; 327: for( i = 0; i < unum; i++ ) { 328: fprintf( fp, "%d %f ", i + j, un[i].km ); 329: fprintf( fp, "%f %f\n", un[i].kr, un[i].ar ); 330: } 331: } 332: 333: void putbox( int sx, int sy, int nx, int ny, int col ) { 334: if( col ) { 335: bar( sx, sy, sx + nx, sy + ny ); 336: } else { 337: rectangle( sx, sy, sx + nx - 1, sy + ny - ( ny != 0 ? 
1 : 0 ) ); 338: } 339: } 340: 341: void putpat( int sx, int sy, char pat[] ) { 342: int x, y, i = 0; 343: for( y = 0; y < 10; y++ ) 344: for( x = 0; x < 10; x++ ) 345: putbox( sx + x * 5, sy + y * 5, 4, 4, pat[i++] == 1 ); 346: } 347: 348: void putres( int sx, int sy, int tin, Unit *unit, int tch[] ) { 349: char buf[] = " "; 350: int i, y; 351: double d; 352: setcolor( BLACK ); 353: putbox( sx, sy, 6, 120, 0 == 0 ); 354: setcolor( WHITE ); 355: buf[0] = pac[ tin / 2 ]; 356: outtextxy( sx, sy + 12, buf ); 357: if( tin % 2 == 1 ) 358: putbox( sx, sy + 2, 4, 1, 0 == 0 ); 359: for( i = 0; i < 2; i++ ) { 360: y = sy + 15 + i * 55; 361: d = unit[i].Sout; 362: putbox( sx, y, 5, 4, tch[i] != 0 ); 363: putbox( sx, y + 5, 5, (int)( d * 45 ), d >= 0.5 ); 364: } 365: } 366: 367: void disparam( int unum1, int unum2, Unit *unit2, int unum3, Unit *unit3 ) { 368: int i, j, x, y, sx = 20, sy = 250; 369: for( i = 0; i < unum2; i++ ) { 370: x = sx + i * 18; 371: setcolor( BLACK ); 372: putbox( x, sy, 18, unum1 + 5 + unum3, 0 == 0 ); 373: setcolor( WHITE ); 374: x = x + 9; 375: for( j = 0; j < unum1; j++ ) { 376: y = sy + j; 377: bar( x, y, x + (int)( unit2[i].w[j] * 9 / ps ), y + 1 ); 378: } 379: for( j = 0; j < unum3; j++ ) { 380: y = sy + j + unum1 + 5; 381: bar( x, y, x + (int)( unit3[j].w[i] * 9 / ps ), y + 1 ); 382: } 383: } 384: } 385: 386: /*****************************/ 387: /* cnn functions */ 388: /*****************************/ 389: 390: Unit *generate( int n ) { 391: int i; 392: Unit *unit, *u; 393: unit = (Unit *)calloc( n, sizeof( Unit ) ); 394: if( unit == NULL ) { 395: printf( "error : can't allocate memory \n" ); 396: exit( -1 ); 397: } 398: for( i = 0; i < n; i++ ) { 399: u = &unit[i]; 400: u->Sin = 0; u->Sout = 0; u->Sbk = 0; 401: u->Tin = 0; u->Tout = 0; 402: u->t = 0; 403: u->Sm = 0; u->Smbk = 0; u->w = NULL; 404: u->Sr = 0; u->Srbk = 0; u->ar = drand48() * 2 + 1; 405: u->km = sigmoid( u->kmo = drand48() * 3 - 1.5 ); 406: u->kr = sigmoid( u->kro = drand48() * 3 - 
1.5 ); 407: } 408: return unit; 409: } 410: 411: void conect( Unit *unit_in, int unum_in, Unit *unit, int unum ) { 412: int i, j; 413: Unit *u; 414: double *w; 415: for( i = 0; i < unum; i++ ) { 416: u = &unit[i]; 417: w = (double *)calloc( unum_in, sizeof( double ) ); 418: if( w == NULL ) { 419: printf( "error : can't allocate memory \n" ); 420: exit( -1 ); 421: } 422: for( j = 0; j < unum_in; j++ ) { 423: w[j] = drand48() * 1 - 0.5; 424: } 425: u->w = w; 426: u->t = drand48() * 1 - 0.5; 427: } 428: } 429: 430: void travel( Unit *unit_in, int unum_in, Unit *unit, int unum ) { 431: int i, j; 432: Unit *u; 433: for( i = 0; i < unum; i++ ) { 434: u = &unit[i]; 435: u->Sbk = u->Sout; 436: u->Srbk = u->Sr; 437: u->Smbk = u->Sm; 438: 439: u->Sr = u->kr * u->Sr - u->ar * u->Sout; 440: u->Sm = u->km * u->Sm; 441: for( j = 0; j < unum_in; j++ ) { 442: u->Sm += u->w[j] * unit_in[j].Sout; 443: } 444: u->Sin = u->Sr + u->Sm + u->t; 445: u->Sout = sigmoid( u->Sin ); 446: } 447: } 448: 449: void learn( Unit *unit_in, int unum_in, Unit *unit, int unum ) { 450: int i, j; 451: Unit *u; 452: for( i = 0; i < unum; i++ ) { 453: u = &unit[i]; 454: u->Tout = u->Tin * difmoid( u->Sout ); 455: for( j = 0; j < unum_in; j++ ) { 456: u->w[j] -= pa * u->Tout * unit_in[j].Sout; 457: } 458: u->t -= pa * u->Tout; 459: u->kmo -= pb * u->Tout * u->Smbk * difmoid( u->km ); 460: u->kro -= pb * u->Tout * u->Srbk * difmoid( u->kr ); 461: u->ar -= pb * u->Tout * ( - u->Sbk ); 462: u->km = sigmoid( u->kmo ); 463: u->kr = sigmoid( u->kro ); 464: if( u->ar < 0 ) 465: u->ar = 0; 466: } 467: } 468: 469: void backp( Unit *unit, int unum, Unit *unit_out, int unum_out ) { 470: int i, j; 471: Unit *u; 472: for( i = 0; i < unum; i++ ) { 473: u = &unit[i]; 474: u->Tin = 0; 475: for( j = 0; j < unum_out; j++ ) { 476: u->Tin += unit_out[j].Tout * unit_out[j].w[i]; 477: } 478: } 479: } 480: 481: double sigmoid( double x ) { 482: return 1 / ( 1 + exp( - x / epsilon ) ); 483: } 484: 485: double difmoid( double x ) { 
486: return x * ( 1 - x ) / epsilon; 487: } 488: 489: /*****************************/ 490: /* program end */ 491: /*****************************/