Bamba news

KAN (Kolmogorov-Arnold Networks)を実装してみよう!わかりやすく解説

【スマホからでも実行可能】AIの新しい波、KAN(Kolmogorov-Arnold Networks)の実装方法を、具体的なPythonコード付きで一から丁寧に解説。本記事では、KANの理論的な背景から、実際の回帰問題への適用、モデルの解釈、そしてチューニングまでを網羅。機械学習の新しいアーキテクチャを実践的に学びたい方に最適です。


はじめに

近年、ニューラルネットワークの世界に新しい風を吹き込んでいる「KAN(Kolmogorov-Arnold Networks)」というアーキテクチャが注目を集めています。従来の多層パーセプトロン(MLP)とは一線を画すその構造は、高い精度と、これまでブラックボックスとされがちだったモデルの「解釈性」を両立する可能性を秘めています。

この記事は、KANを実際に動かしてみたいと考えるエンジニアや研究者、学生の皆さんに向けて、Pythonライブラリ pykan を用いた具体的な実装方法をステップバイステップで解説する実践編です。理論的な難解さを極力排除し、コードとその意味を理解することに焦点を当てます。

数学的な背景や基本的な概念については、入門編や活用例編で詳しく解説するとして、この実践編では「とにかく動かしてみる」ことを目指します。

記事の最後に、環境構築なしでスマホからでも即実行可能なGoogle Colabノートブックをご用意しています。


KANとは?MLPとの違いを簡単に

実装に入る前に、KANが従来のニューラルネットワーク(特にMLP)と何が違うのか、その核心だけを簡単におさらいしましょう。

  • MLP (Multi-Layer Perceptron): ニューロン(ノード)に固定の活性化関数(ReLUやSigmoidなど)を持ち、ノード間をつなぐエッジの重みを学習します。計算は「線形変換(重み付け)→非線形変換(活性化関数)」の繰り返しです。

  • KAN (Kolmogorov-Arnold Network): ニューロン(ノード)は単純な足し算を行うだけで、ノード間をつなぐエッジ上に学習可能な活性化関数を持ちます。つまり、KANはエッジ上の「関数の形」そのものをデータから学習します。

この構造的な違いが、KANの大きな特徴である「高い精度」と「モデルの解釈性」を生み出しています。エッジで学習された関数を可視化することで、モデルが入力と出力の間にどのような関係性を見出したのかを直感的に理解できるのです。

(出典: KAN: Kolmogorov-Arnold Networks 論文より)


実践:pykan を使ってみよう

それでは、実際にPythonライブラリ pykan を使ってKANモデルを構築し、学習させるプロセスを見ていきましょう。今回は、特定の数式で表されるデータセットをKANに学習させ、もとの数式を再発見できるか試す「回帰問題」に取り組みます。

1. 環境構築とライブラリのインストール

まず、pykan ライブラリをインストールします。Google ColabやJupyter Notebook環境では、以下のコマンドを実行します。

# pykanライブラリをpipでインストールします
!pip install pykan

必要なライブラリもインポートしておきましょう。

import torch
from kan import *

# KANモデルのインスタンスを作成
# width: 各層のニューロン数をリストで指定
# grid: 活性化関数を表現するスプラインの分割数
# k: スプラインの次数
model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0)

2. データセットの作成

今回は、以下の数式をターゲットとして、学習用のデータセットを人工的に作成します。

f(x,y)=exp(sin(πx)+y2)f(x, y) = \exp(\sin(\pi x) + y^2)

pykan には、このような関数からデータセットを手軽に作成するユーティリティが用意されています。

# データセットを作成する関数を定義
f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]]**2)

# 上記の関数から学習、テストデータを生成
# n_var: 変数の数(今回はxとyの2つ)
dataset = create_dataset(f, n_var=2)

print(dataset['train_input'].shape)
print(dataset['train_label'].shape)

create_dataset 関数は、指定された関数の入力(train_input)と出力(train_label)のペアを自動で生成してくれる便利な関数です。

3. モデルの学習

データセットの準備ができたので、いよいよモデルを学習させます。model.fit() を呼び出すだけで学習が始まります。

# モデルの学習を実行
# dataset: 使用するデータセット
# opt: 最適化アルゴリズム("LBFGS"を推奨)
# steps: 学習ステップ数
# lamb: 正則化の強さ(後述)
model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_l1=1.0)

学習プロセス中のログには、train_loss (学習データの誤差)、test_loss (テストデータの誤差)、reg (正則化項) などが表示され、学習の進捗を確認できます。

4. 学習結果の可視化と解釈

KANの真骨頂はここからです。学習済みのモデルをプロットして、ネットワークがどのような関数を学習したのかを見てみましょう。

# 学習済みモデルの構造と活性化関数をプロット
model.plot()

このコマンドを実行すると、以下のような図が出力されます。

この図は、KANの内部構造を表しています。

  • ノード: 各層の丸い部分がノードです。
  • エッジ: ノード間を結ぶ線がエッジです。
  • 小さなグラフ: 各エッジの上に乗っている小さなグラフが、学習された活性化関数(B-スプライン)の形です。

入力層(左側)には2つのノード(x_1x\_1, x_2x\_2)があり、中間層(中央)の5つのノードを経由して、出力層(右側)の1つのノードにつながっています。

この図を観察することで、どの入力変数が中間層のどのノードに、どのような関数(線形、非線形など)で影響を与えているかを視覚的に理解できます。


モデルの改良:プルーニングと記号的回帰

現在のモデルはまだ複雑です。KANでは、モデルをよりシンプルで解釈しやすくするための「プルーニング(枝刈り)」というテクニックが使えます。

1. モデルのプルーニング

プルーニングは、ネットワーク内の重要でない(影響の小さい)接続を自動的に削除するプロセスです。これにより、モデルがより疎(スパース)になり、本質的な関係性だけが残ります。

# モデルを自動的にプルーニング
pruned_model = model.prune()

# プルーニング後のモデルをプロット
pruned_model.plot()

実行すると、先ほどよりもエッジの数が減り、シンプルな構造になったネットワークが表示されるはずです。

このシンプルな構造から、モデルが入力x_1x\_1x_2x\_2をそれぞれ別々に処理し、後段で合算していることが推測できます。

2. 記号的回帰(Symbolic Regression)

pykan の最も強力な機能の一つが、学習した活性化関数を既知の数学関数(例: sin\\sin, exp\\exp, x2x^2など)に当てはめて、数式として提案してくれる「記号的回帰」です。

# 学習した関数を記号的な数式に変換
pruned_model.auto_symbolic()

# 結果を再度プロット
pruned_model.plot()

この auto_symbolic() を実行すると、pykanはライブラリ内の既知の関数と、学習したスプラインの形状を比較し、最もフィットするものを自動で見つけ出してくれます。

プロットを見ると、

  • x_1x\_1 につながるエッジでは sin 関数が
  • x_2x\_2 につながるエッジでは x^2 (2次関数) が
  • 中間層から出力層へのエッジでは exp 関数が

それぞれ学習されていることがわかります。これは、私たちが最初に定義したターゲット関数 f(x,y)=exp(sin(pix)+y2)f(x, y) = \\exp(\\sin(\\pi x) + y^2) の構造と非常によく一致しています。

このように、KANは単に高い精度で予測するだけでなく、データに潜む「数理的な構造」そのものを発見し、人間が解釈できる形で提示してくれる能力を持っています。


まとめと今後の展望

今回は、新しいニューラルネットワークアーキテクチャであるKANを、Pythonライブラリ pykan を用いて実装する具体的な手順を解説しました。

  • MLPとの違い: KANは重みではなく、エッジ上の「関数」を学習する。
  • 実装: pykan を使えば、データ作成、学習、可視化が数行のコードで実現できる。
  • 解釈性: plot() による可視化、prune() による単純化、auto_symbolic() による数式化を通じて、モデルの内部動作を深く理解できる。

KANはまだ新しい技術であり、学習に時間がかかるなどの課題も残されていますが、その高い性能と解釈性は、科学的な法則の発見、金融モデリング、医療診断支援など、これまで機械学習の適用が難しかった領域への応用が期待されています。

この記事が、皆さんがKANの世界に足を踏み入れるための一助となれば幸いです。ぜひ、ご自身のデータセットで試したり、ネットワークの構造やハイパーパラメータを変更して、その挙動を探求してみてください。


環境構築なし
実行できるファイルはこちら!

このボタンからGoogle Colabを開き、すぐにコードをお試しいただけます。

お仕事のご依頼・ご相談はこちら

フロントエンドからバックエンドまで、アプリケーション開発のご相談を承っております。
まずはお気軽にご連絡ください。

関連する記事

主成分分析(PCA)を実装してみよう!わかりやすく解説

【スマホからでも実行可能】主成分分析(PCA)のPythonによる実装方法を、初心者向けにコード付きで丁寧に解説します。scikit-learnライブラリを使い、データの要約から可視化までをステップバイステップで学べます。機械学習やデータ分析の第一歩に最適です。

畳み込みニューラルネットワーク(CNN)を実装してみよう!わかりやすく解説

【スマホからでも実行可能】畳み込みニューラルネットワーク(CNN)の実装方法を初心者にも分かりやすく解説。Pythonコード付きで、画像認識の仕組みを実践的に学べます。AI、ディープラーニング、画像分類に興味がある方必見です。

ガウス過程回帰を実装してみよう!わかりやすく解説

【スマホからでも実行可能】Pythonでガウス過程回帰を実装する方法を初心者向けにわかりやすく解説します。機械学習のモデル構築や不確実性の可視化に興味がある方必見です。

ベイズ最適化を実装してみよう!わかりやすく解説

【スマホからでも実行可能】ベイズ最適化の実装方法をPythonコード付きで徹底解説。機械学習のハイパーパラメータチューニングを効率化したい方必見。サンプルコードを動かしながら、実践的に学べます。

エクストラツリー(ExtraTrees)を実装してみよう!わかりやすく解説

【スマホからでも実行可能】この記事では、機械学習のアルゴリズムであるエクストラツリー(ExtraTrees)について、その仕組みからPythonによる実装方法までを丁寧に解説します。ランダムフォレストとの違いも理解しながら、実践的なスキルを身につけましょう。

Bamba news