2009年2月11日

リファクタリング-5-

さて。一番重たい部分を計算するプログラム公開の前に。どうやら問題を理解してもらえていないようなので、もう一度書いておく。

m 個のデータ列がある。各データ列は N 個のデータからなる。N個のデータはm個のデータ列の間で互いにどれとどれが対応するのかは判っているとする。

これらのデータ列同士の相関を「総当りで」求めたい。相関係数 r は r(x,y) と r(y,x) (ただし、x, y はデータ列) は同じ値になるので、m個のデータ列同士の総当たり戦は m*(m-1)/2 通りになる。これを
面倒なコーディングを可能な限り行わない
形で、実用的な速度で解きたい。

R であれ、ライブラリであれ、モジュールであれ、1組2個のデータ列同士の相関を求める関数は存在するし、それはどれを使っても相応に早い。しかし「総当たり戦」の場合、その関数を m*(m-1)/2 回呼び出さなくてはいけない。

相関を求める式には、データ列単位で計算しておけばそれを使いまわせるものが結構ある。平均値を求めたり、各データと当該平均値との差分を求めたり、それら差分の二乗値の和の平方根を取ったりした結果などがそうだ。
ある任意のデータ列は、(m-1)/2 回相関を求めるために呼び出される。と言うことは、上記の値は、最初に1回だけ計算しておけば ((m-1)/2) - 1 回分、計算をしなくてよい。

このような最適化を施したライブラリは、いまだに存在しない。というか、存在しなくは無いのだが、メモリを食いすぎるため、私のPCでは遅すぎるのだ(32bit環境だと、そもそもアドレス空間が足りなくて core 吐いて終わる。64bit 環境は、メモリが足りないため swap IO ばかりで計算が全然前に進まない)。計算環境が64bit環境だけで作れて、メモリが潤沢にあれば問題は無いのだが…

その一方で。この相関は何のためにとるのかと言うと、「Windows の perfmon の出力から、何か性能劣化の兆候が見えないか?」とか「性能が劣化しているのだが、何が原因かわからないか?」という分析をする際の、調査対象に優先度をつけるために行っている。つまり計算誤差なんぞ多少あっても構わないのだ。まじめな数値計算ではなく、誤差が多少あってもいいから高速に結果を出して欲しい。

実用になる範囲であれば、perl 辺りでライブラリを叩けば十分だったのだが、速度が遅すぎる以上 C でコーディングをするしかないわけだ。

ちなみに、mは 27000.Nが 675。 perl で組んでいたときは、1組当たり 9msec かかっていた。つまり37.9673.. 38日かかる計算になる。これが C だと 37分…まぁIOの関係で40分ぐらい。「値が変化しない」とわかりきっている部分が半分近くあり、それらについては相関は絶対 NaN になると判っているのでメモリも消費しなくなった。すると32bit環境でも一発で計算できるではないか…。NIH(Not Invented Here)症候群と言われようが、自分でコーディングする気になるというものだよ。


次は一番重たい部分を計算するCのプログラム

correlation.c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <math.h>

#define DEBUG (0)



typedef struct __data {
struct __data *next;
struct __data *prev;
char *element_name;
int correlation;
double sum;
double mean;
double sqsum;
double subpart;
long numdata;
double *datalist;
} data;


data* read_each( char *listfilename );
data* read_datafile( char *linebuffer );
void add_data_array( data **datalistpp, data *new_datap );
void calculate( data *data_list );
void remove_trailing_newline( char *linebuffer );

void help( char *programname );
void *o_malloc( size_t size, const char *msg, const char *func, const long line );



int
main( int argc, char *argv[] )
{
data *data_list;
if ( argc != 2 ) {
help( argv[0] );
}

data_list = read_each( argv[1] );
if ( data_list == NULL ) {
fprintf( stderr, "no data given.\n" );
exit( 1 );
}

calculate( data_list );
}

data*
read_each( char *listfilename )
{
FILE *fp;
char *linebuffer;
int linebufsize;
data *read_datap = NULL;


linebufsize = getpagesize();
linebuffer = o_malloc( linebufsize, "unable to prepare linebuffer\n", __func__, __LINE__ );

fp = fopen( listfilename, "r" );
if ( fp == NULL ) {
fprintf( stderr, "unable to open file %s\n", listfilename );
exit( 1 );
}

while( fgets( linebuffer, linebufsize, fp ) != NULL ) {
data *new_datap;
long len;

if ( linebuffer[0] == '#' ) {
continue;
}

remove_trailing_newline( linebuffer );

new_datap = read_datafile( linebuffer );
if ( new_datap->correlation != 0 ) {
add_data_array( &read_datap, new_datap );
} else {
/* take it away from the memory.*/
free( new_datap->datalist );
free( new_datap->element_name );
free( new_datap );
}
}
fclose( fp );

return read_datap;
}

data*
read_datafile( char *datafilename )
{
data *newdata;
FILE *fp;
char *linebuffer;
size_t linebufsize;
char *tagname;


newdata = o_malloc( sizeof( data ), "unable to allocate memory for newdata\n", __func__, __LINE__ );

linebufsize = getpagesize();
linebuffer = o_malloc( linebufsize, "unable to prepare linebuffer for read_datafile()\n", __func__, __LINE__ );

fp = fopen( datafilename, "r" );
if ( fp == NULL ) {
fprintf( stderr, "unable to open file %s as data file target\n", datafilename );
exit( 1 );
}


/* read tagname */
{
char *p;
long dataoffset;

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read tagname at file %s\n", datafilename );
exit( 1 );
}
remove_trailing_newline( linebuffer );

p = strstr( linebuffer, "tagname\t" );
if ( p != linebuffer ) {
fprintf( stderr, "unable to find \"tagname\t\" indicator in %s\n", datafilename );
exit( 1 );
}

dataoffset = strlen( "tagname\t" );

newdata->element_name = strdup( linebuffer + dataoffset );
if ( newdata->element_name == NULL ) {
fprintf( stderr, "unable to copy tagname %s in %s\n",
linebuffer + dataoffset, datafilename );
exit( 1 );
}
#if DEBUG
fprintf( stderr, "tagname was %s\n", newdata->element_name );
#endif/*DEBUG*/
}

/* read Correlation */
{
char *p;
long dataoffset;

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read Correlation at file %s\n", datafilename );
exit( 1 );
}
p = strstr( linebuffer, "Correlation\t" );
if ( p != linebuffer ) {
fprintf( stderr, "unable to find \"Correlation\t\" indicator in %s\n", datafilename );
exit( 1 );
}
dataoffset = strlen( "Correlation\t" );
if ( strncasecmp( linebuffer + dataoffset, "OK", 2 ) != 0 ) {
newdata->correlation = 0;
} else {
newdata->correlation = 1;
}
#if DEBUG
fprintf( stderr, "correlation was %d\n", newdata->correlation );
#endif/*DEBUG*/
}


/* read sum */
{
char *p;
long dataoffset;
int res;

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read sum at file %s\n", datafilename );
exit( 1 );
}
p = strstr( linebuffer, "sum\t" );
if ( p != linebuffer ) {
fprintf( stderr, "unable to find \"sum\t\" indicator in %s\n", datafilename );
exit( 1 );
}
dataoffset = strlen( "sum\t" );
newdata->sum = 0.0;
res = sscanf( linebuffer + dataoffset, "%lf", &( newdata->sum ));
if ( res != 1 ) {
perror( "sscanf failed when reading sum.\n" );
exit( 1 );
}
#if DEBUG
fprintf( stderr, "sum was %lf\n", newdata->sum );
#endif/*DEBUG*/
}


/* read mean */
{
char *p;
long dataoffset;
int res;

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read mean at file %s\n", datafilename );
exit( 1 );
}
p = strstr( linebuffer, "mean\t" );
if ( p != linebuffer ) {
fprintf( stderr, "unable to find \"mean\t\" indicator in %s\n", datafilename );
exit( 1 );
}
dataoffset = strlen( "mean\t" );
newdata->mean = 0.0;
res = sscanf( linebuffer + dataoffset, "%lf", &( newdata->mean ));
if ( res != 1 ) {
perror( "sscanf failed when reading mean.\n" );
exit( 1 );
}
#if DEBUG
fprintf( stderr, "mean was %lf\n", newdata->mean );
#endif/*DEBUG*/
}


/* read sqsum */
{
char *p;
long dataoffset;
int res;

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read sqsum at file %s\n", datafilename );
exit( 1 );
}
p = strstr( linebuffer, "sqsum\t" );
if ( p != linebuffer ) {
fprintf( stderr, "unable to find \"sqsum\t\" indicator in %s\n", datafilename );
exit( 1 );
}
dataoffset = strlen( "sqsum\t" );
newdata->sqsum = 0.0;
res = sscanf( linebuffer + dataoffset, "%lf", &( newdata->sqsum ));
if ( res != 1 ) {
perror( "sscanf failed when reading sqsum.\n" );
exit( 1 );
}
#if DEBUG
fprintf( stderr, "sqsum was %lf\n", newdata->sqsum );
#endif/*DEBUG*/
}


/* read subpart */
{
char *p;
long dataoffset;
int res;

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read subpart at file %s\n", datafilename );
exit( 1 );
}
p = strstr( linebuffer, "subpart\t" );
if ( p != linebuffer ) {
fprintf( stderr, "unable to find \"subpart\t\" indicator in %s\n", datafilename );
exit( 1 );
}
dataoffset = strlen( "subpart\t" );
newdata->subpart = 0.0;
res = sscanf( linebuffer + dataoffset, "%lf", &( newdata->subpart ));
if ( res != 1 ) {
perror( "sscanf failed when reading subpart.\n" );
exit( 1 );
}
#if DEBUG
fprintf( stderr, "subpart was %lf\n", newdata->subpart );
#endif/*DEBUG*/
}


/* read datas */
{
char *p;
long dataoffset;
int res;

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read datas at file %s\n", datafilename );
exit( 1 );
}
p = strstr( linebuffer, "datas\t" );
if ( p != linebuffer ) {
fprintf( stderr, "unable to find \"datas\t\" indicator in %s\n", datafilename );
exit( 1 );
}
dataoffset = strlen( "datas\t" );
newdata->numdata = 0;
res = sscanf( linebuffer + dataoffset, "%ld", &( newdata->numdata ));
if ( res != 1 ) {
perror( "sscanf failed when reading datas.\n" );
exit( 1 );
}
#if DEBUG
fprintf( stderr, "datas was %d\n", newdata->numdata );
#endif/*DEBUG*/
}

/* read data list */
{
long i;
int res;

newdata->datalist
= malloc( sizeof( double ) * newdata->numdata );
if ( newdata->datalist == NULL ) {
fprintf( stderr, "not enough memory for datalist: %s\n", datafilename );
exit( 1 );
}

if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read empty line at file %s\n", datafilename );
exit( 1 );
}

for ( i = 0; i < newdata->numdata; i++ ) {
if( fgets( linebuffer, linebufsize, fp ) == NULL ) {
fprintf( stderr, "unable to read datas at file %s\n", datafilename );
exit( 1 );
}
res = sscanf( linebuffer, "%lf", &( newdata->datalist[i] ));
/* if we failed in reading data, that means that element should be 0.0 */
if ( res != 1 ) {
newdata->datalist[i] = 0.0;
}

#if DEBUG
fprintf( stderr, "datas was %lf\n", newdata->datalist[i] );
#endif/*DEBUG*/
}
}

fclose( fp );
return newdata;
}

void
add_data_array( data **datalistpp, data *new_datap )
{
if (( *datalistpp ) == NULL ) {
new_datap->next = NULL;
new_datap->prev = NULL;
(*datalistpp) = new_datap;
} else {
new_datap->next = (*datalistpp);
new_datap->prev = NULL;
(*datalistpp)->prev = new_datap;
(*datalistpp) = new_datap;
}
}


void calculate( data *data_list )
{
data *one, *two;
double cor, cor2;
int i;

/* First, change all the data in datalist[] to be subtraction of each's mean */
for ( one = data_list; one != NULL; one = one->next ) {
for ( i = 0; i < one->numdata; i++ ) {
one->datalist[i] -= one->mean;
}
}


for ( one = data_list; one != NULL; one = one->next ) {
for ( two = one->next; two != NULL; two = two->next ) {
cor = 0.0;

for ( i = 0; i < one->numdata; i++ ) {
cor += one->datalist[i] * two->datalist[i];
}
cor /= ( one->subpart * two->subpart );
cor2 = cor * cor;

fprintf( stdout,
"%s\t%s\t%lf\t%lf\n",
one->element_name,
two->element_name,
cor, cor2 );
}
}
}

void remove_trailing_newline( char *linebuffer )
{
size_t len;

while( 1 ) {
len = strlen( linebuffer );
if ( linebuffer[len-1] == '\n' ) {
linebuffer[len-1] = '\0';
} else {
break;
}
}
}


void
help( char *programname )
{
fprintf( stderr,
"%s <filename containing target>\n",
programname );
exit( 127 );
}

void*
o_malloc( size_t size, const char *msg, const char *func, const long line )
{
void *p;
p = malloc( size );
if ( p == NULL ) {
fprintf( stderr, "%s:%ld: %s", func, line, msg );
exit( 1 );
}
}

みての通り、ファイルを読んでくるところの処理が延々。実際の計算部分は calculate() の部分だけ。しかもこれすら標準出力への出力コードが含まれているのに、このサイズ。あぁ、Cの最大の弱点はファイルIOの記述/メモリ管理の記述が面倒くさいことだなぁ。

ちなみに。「C++の例外処理で書けばいいじゃない」という提案をいただいた。が、その手段は使えないし、趣味として使わない。例外処理は素晴らしい機能だ。ただし、あれは「コード的には正しいが、動作環境との兼ね合いで不幸に至った場合」に、どこで不幸を食らったかに依存せず「不幸を食らった事実だけ」を書けばよいから素晴らしいのだ。

しかし、上記のコード、多分端々に見えているだろうが、「おかしくなるのは私のコーディングが間違っているから」というケースに対応するためのものだ。実際に発生した現象としては:
  1. tagname を引っ張ってきたのはいいが、行末の "\n" を削り忘れていた。
  2. Correlation のOK を case insensitive にチェックしているが、実は最初のコードでは "OK" に対して入力は "OK\n" で合致しない攻撃をくらいまくっていた。
  3. Correlation が NG の場合、計算対象からはずすためにリンクリストに登録しないのだが、そのメモリを「放置」していた(メモリリークだ)。64bit環境では特に問題は無かった(ちょっと swap IO するだけ)のだが、32bit環境に持ってきたらパンクした。
  4. Correlation 自身はOKであっても、データが全フィールドに渡って存在するとは限らない。存在しないエントリは 0.0 として扱うことになっていたのだが、sscanf() の戻り値をチェックするときにそのことを完全に失念していて、おかしな部分で give up して exit() していた。
まぁ、そういうケースに対応するための「デバッガコード」であって、異常系ではなかったりする。例外処理だと、「どこでおかしいのか」も判らないし、特に 1 のケースなどは assert() にもひっかけようがない。

先にも書いたが、1データ組み合わせ当たり 6.74usec で計算できた。はーやーいー。9msec とは比べ物にならない。やっぱりコンパイラ言語は本気を出すとすごいわ。

4 件のコメント:

  1. 「ある任意のデータ列は、(m-1)/2 回相関を求めるために呼び出される。と言うことは、上記の値は、最初に1回だけ計算しておけば ((m-1)/2) - 1 回分、計算をしなくてよい。」ことの効果は約3倍程度で、後の数百倍がPerl->Cの効果ですよね。
     ループの最内部までPerlなのをその部分はちゃんとコンパイラ言語で作ってあるライブラリーなら数十倍ですんでいそうというものでしょう。
     が、新しいコンピュータが必要な口実にはCできっちり最適化して限界ですと言ったほうがいいかもしれませんね。:P

    返信削除
  2. ちがいますぅ。
    Perl->Cの効果が1500倍もあるんですぅ。日記に書いてないだけで、「計算をしないバージョン」は作りました。それの性能が9msec(これをやらないと14msec)。

    新しいコンピュータが必要なのは、メモリ不足と並列度不足ですね。メモリが十分あれば swap IO は生じませんし、そうなれば multi core がものを言うようにコーディングできます。あとはSSEまわりですかね。gcc にはベクトル演算ができるように記述する方法があるようなので…とは言え、そちらはもう少し時間がある時に。

    返信削除
  3. あ「Perl->Cの効果が1500倍もあるんですぅ。」これはいいんです。要は「データ列ごとに計算した数字を使いまわす」最適化は多分3倍程度しか効果が無くってその最適化を含まない単純にデータ列同士の相関を計算するCプログラムで今のPerlのプログラムの数百倍のスピードが出るだろうと。で、外側のオーバーヘッドもあるでしょうが、Nが大きくて、ベクターデータとして一気に処理するCorFortranのルーチンに渡す処理系なら、ループの内側までPure Perlなプログラムの数十倍まではいけたんではないかと(でも実行時間半日~1日はかかるかな、、)。それで足りない以上Cに結局いくのでしょうね。算法を注意して計算量を減らすのもオーダーが変わらないような最適化だと意外とたいした効果が無くてとっととコンパイラを使えという反教科書的な例ですね。

    返信削除
  4. まぁ、計算のオーダー自体は最初からどうしようもありませんからねぇ。
    ですが、これを「反教科書的な例」と考えるのは早計のような気がします。relog.exeの出力するテキストファイルは(TSVが一番扱いやすいとはいえ)C言語で扱いやすいものではありません。メモリ管理・Token処理など、げんなりするコーディングの山になります。
    「どこまでを文字列処理・メモリ管理が自動化されているLightweight Language で実行するか」
    「どこからをコンパイラ型言語にするか」
    「どのように区分けすると具合が良いのか」
    などを考える、という意味ではむしろ教科書的かと。

    たとえば、私はO(m*N)の部分は perl で計算しています。で、このO(m*N)の部分には「データの個数を数える」というポイントが含まれています。ご存知のようにC言語は「いくつ来るか判らないデータ」を効率よく扱うのが苦手です(というか配列とかが効率よすぎ。SSEとかの計算を含めるとさらに…)。こう「勘定」は perl でやっておいたほうが楽になる。で、そうなると平均値ぐらいまでは perl で計算してもそうは変わらない。

    標準偏差まで求めておくと、Cでデータを読む段階で、「こいつとの計算結果は必ず NaN になる」ものが判ります。なのでその場合はデータ自体を読み込まない、という最適化がC側でできる。つまり O(m*m)の m の値そのものを小さくできる。そのためには標準偏差までは perl で計算しておいたほうがよい。また、標準偏差が別途ファイルに記載されることで、何か別の兆候を人間が見つけられるかもしれない。

    このように「どこまでで何をすることによって、どのように最適化されるか。だからどこまでは遅い処理系でやっても価値があるか」という判断をする、というアルゴリズムを考える上では、良い教材だと思います。

    返信削除