torch nn.Embedding 举例

依据最大索引转向量:在最后一维上增加一个维度,将索引转化为一个浮点向量

 
import torch
from torch import nn

embedding = nn.Embedding(num_embeddings=8,
                        embedding_dim=3,
                            padding_idx=0)


index = torch.tensor([1,3,5,7])
print(embedding(index).shape)  #torch.Size([4, 3])

 

    

 


 

  

 

    

 


nn.Embedding

num_embeddings

 
nn.Embedding中num_embeddings理论上指索引的个数,
但这有个前提,索引编码从0开始,并且一个连一个,
不能跳跃,一般情况下也不会跳跃,但跳跃也没有关系

实际上num_embeddings指提供的索引中的最大值 
比如,本例如果num_embeddings低于8将报错,而实际上的索引个数只有4个
正常情况下索引编码都是从0开始,也无跳跃

 

    

 

    
参考