PyTorchの埋め込み (Embedding) の使い方


埋め込みは、主に自然言語処理のタスクで利用されます。例えば、単語をベクトルに変換してニューラルネットワークに入力する際に使われます。以下に、PyTorchでの埋め込みの使い方とコード例を示します。

まず、PyTorchライブラリをインポートします。

import torch
import torch.nn as nn

次に、埋め込み層を定義します。埋め込み層は、カテゴリ値の表現を密なベクトルに変換する役割を持ちます。以下は、埋め込み層の定義例です。

vocab_size = 10000  # 語彙数
embedding_dim = 300  # 埋め込みベクトルの次元数
embedding = nn.Embedding(vocab_size, embedding_dim)

上記の例では、語彙数が10,000で、埋め込みベクトルの次元数が300です。この埋め込み層は、語彙数のサイズの入力を受け取り、各カテゴリ値を300次元のベクトル表現に変換します。

次に、埋め込み層に入力データを与えて変換を行います。例えば、以下のような入力データがあるとします。

input_data = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]])

この場合、input_dataは2つの文からなるバッチであり、各文は4つの単語からなります。input_dataを埋め込み層に入力するには、次のようにします。

embedded_data = embedding(input_data)

これにより、embedded_dataには入力データが埋め込みベクトルに変換された結果が格納されます。

以上が、PyTorchでの埋め込みの使い方とコード例です。埋め込みを利用することで、カテゴリ値の表現を連続値のベクトルに変換し、機械学習モデルに入力することができます。この方法は、自然言語処理タスクなどで特に有用です。