reluを分岐無しで実装する

ReLUとは、ご存知の通り、xが正の値ならxを返し、そうでなければ0になるという非常に単純な関数である。だがこれが深層学習において重要な役割を演じたのだから面白い。バックプロパゲーションをする時に微分係数が1より小さいよくある活性化関数だと指数的に係数が弱まって深層にすると学習できなくなるとか何とか。

同僚とReLUの話題になった時、私の頭には分岐無しで abs を実装するという話(『ハッカーのたのしみ』参照)が浮かんだ。現代的なプロセッサでは、分岐予測が働くので、予測を外した場合分岐は時間のかかる命令になる。なのでよく呼ばれる簡単な関数では、分岐を減らすことには十分な意味がある。 abs では何が行われていたかというと、bit演算をうまく組み合わせて一度も分岐をすることなく abs 関数を実装していたのだ。負の値だと最上位ビットが1になるがそうでなければ0であるという事実をうまく使って。

それが脳裏にあったので、ReLUも同様にbit演算を使えば分岐無しで実装できるのではないかと思った。

という訳でチャレンジしよう。まず、IEEE754を思い出すと、浮動小数点数は最上位ビットが符号ビットだ。なので負なら最上位が1、正なら0になっている。これは使える。というのも、int32_tにして算術右シフトを31回行えば、正なら 0 、負なら 0xFFFF.... の値になるからだ。

float x = 1.0;
std::cout << std::hex
    << (*reinterpret_cast<std::int32_t*>(x) >> 31) << std::endl;
x = -1.0;
std::cout << std::hex
    << (*reinterpret_cast<std::int32_t*>(x) >> 31) << std::endl;

続いて、これをうまく使って正負どちらの場合でも望みの結果が出るようなbit演算を考えたい。まず、負なら0になってほしいが、ゼロは指数部、仮数部共に0であるので、要するに全てのビットが0である(負の0は考えないことにする)。これを作るのは簡単そうだ。なにせ全てのビットが1の値を作ることができるので、例えば算術右シフトを31回行った結果とのbitwise orを考えよう。これだと、片方が1なら(負の場合)もう片方に関係なく1になる。片方が0なら(正の場合)もう片方と同じ値が出てくる。

筆算で書くとこうなる(便宜的に8bitとする。sは符号、eはexponent、fはfractional)。

    正の数    | 負の数
    seeeffff | seeeffff
    00101001 | 10101001
bor 00000000 | 11111111
------------ | --------
    00101001 | 11111111

負の数が全て1になり、正の数は保たれている。いい感じだ。

だが、負の数の場合に欲しい値は0で、そのbit表現は全てのbitが0というものだ。なのでbitを反転しなければならない。ただし、正の数の場合は反転してはいけない。要するに、次はbit二項演算の中から、0と行うと何も起きないが、1と行うとbitが反転するような操作を探さなければならない。

そしてそれは存在する。xorだ。片方が1だと(今回は負の数の場合)、もう片方が1なら0、0なら1とビットが反転する。片方が0だと(今回は正の数の場合)、もう片方が1なら1、0なら0と何も起きない。なのでこれが探していた演算だ。

という訳で以下のようになる。

    正の数    | 負の数
    seeeffff | seeeffff
    00101001 | 10101001
bor 00000000 | 11111111
------------ | --------
    00101001 | 11111111
xor 00000000 | 11111111
------------ | --------
    00101001 | 00000000

ReLUができたではないか。ここまで来ると、notしてandでもいいということに気づくが、まあそれは何でもいい。実装は以下のようになる。reinterpret_cast が長いので、floatstd::int32_t を変関する ftoiitof を実装したことにする。それはreinterpret_castでもできるし、 union を使って読み替えても良い。

float ReLU(float x)
{
    const auto y =  ftoi(x) >> 31;
    return itof((ftoi(x) | y) ^ y);
}

通常の ReLU はこうだ。

float ReLU_normal(float x)
{
    return (x > 0.0) ? x : 0.0;
}

ではどうなるか比べてみよう。コンパイルしたのちobjdumpで逆アセンブルする。

ReLU:
    movd %xmm0, %eax
    movd %xmm0, %edx
    sarl $31, %eax
    notl %eax
    andl %edx, %eax
    movd %eax, %xmm0
    retq

ReLU_normal:
    maxss    (%rip), %xmm0
    retq

はい。max って1命令でしたね! これは実際に逆アセンブルした時に気づいた。そして書いている間は ReLUmax(x, 0.0) であることも頭になかった。

確かに今回書いたコードには分岐はない。だが、それ以上に命令が多い。知っている限りmaxのレイテンシは1か2なので、シフトやビット命令をしているこちらの方が時間がかかってしまうだろう。

物事に熱中している間は根本的なところに気づかない、という良い例だった。皆さんもたまには立ち止まって出発地点を確かめましょう。