CBOW in Practice with PyTorch
Published: 2019-05-25


Prepare the training data: for every target word, take the two words to its left and the two to its right as its context (the imports used throughout the post are added here).

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

CONTEXT_SIZE = 2      # 2 words to the left, 2 to the right
EMBEDDING_DIM = 100

raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()

# By deriving a set from `raw_text`, we deduplicate the array
vocab = set(raw_text)
vocab_size = len(vocab)
word_to_ix = {word: i for i, word in enumerate(vocab)}

data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    data.append((context, target))

print(data[:5])

Printing data[:5] as a test gives:

[(['We', 'are', 'to', 'study'], 'about'), (['are', 'about', 'study', 'the'], 'to'), (['about', 'to', 'the', 'idea'], 'study'), (['to', 'study', 'idea', 'of'], 'the'), (['study', 'the', 'of', 'a'], 'idea')]

Build the model:

class CBOW(nn.Module):
    def __init__(self, vocab_size, n_dim, context_size):
        super(CBOW, self).__init__()
        # One n_dim embedding per vocabulary word
        self.embeddings = nn.Embedding(vocab_size, n_dim)
        # The 2*context_size context embeddings are concatenated, then projected
        self.linear1 = nn.Linear(2 * context_size * n_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        # inputs: LongTensor of 2*context_size word indices
        embeds = self.embeddings(inputs).view(1, -1)  # concatenate into one row
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)  # log-probabilities over the vocabulary
        return log_probs
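Before training, it helps to verify the tensor shapes. The quick check below is my own addition, not part of the original post: it feeds one dummy context of 2*CONTEXT_SIZE word indices through an untrained model and confirms that the output is one row of log-probabilities over the vocabulary.

# Shape check (added sketch): one dummy context through an untrained model
check_model = CBOW(vocab_size, EMBEDDING_DIM, CONTEXT_SIZE)
dummy_context = torch.tensor([0, 1, 2, 3], dtype=torch.long)  # 2*CONTEXT_SIZE indices
print(check_model(dummy_context).shape)  # torch.Size([1, vocab_size])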
Train the model: each (context, target) pair is converted to index tensors, and the network is optimized with negative log-likelihood loss.

# create your model and train.  here are some functions to help you make
# the data ready for use by your module
def make_context_vector(context, word_to_ix):
    idxs = [word_to_ix[w] for w in context]
    return torch.tensor(idxs, dtype=torch.long)

print(make_context_vector(data[0][0], word_to_ix))  # example

model = CBOW(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
if torch.cuda.is_available():
    model = model.cuda()

losses = []
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

for epoch in range(200):
    total_loss = 0
    for context, target in data:
        context_vector = make_context_vector(context, word_to_ix)
        target = torch.tensor([word_to_ix[target]], dtype=torch.long)
        if torch.cuda.is_available():
            context_vector = context_vector.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        log_probs = model(context_vector)
        loss = loss_function(log_probs, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print("epoch", epoch, " -->", total_loss)
    losses.append(total_loss)
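Since the model already applies log_softmax and is trained with NLLLoss, the combination is equivalent to applying nn.CrossEntropyLoss to the raw linear output. After training, a quick qualitative check is to predict the middle word for a context. The snippet below is my own sketch, not from the original post; ix_to_word is a helper introduced here.

# Inference sketch (added): predict the middle word for one training context
ix_to_word = {i: w for w, i in word_to_ix.items()}  # inverse of word_to_ix (helper added here)

context, target = data[0]                        # (['We', 'are', 'to', 'study'], 'about')
context_vector = make_context_vector(context, word_to_ix)
if torch.cuda.is_available():
    context_vector = context_vector.cuda()
with torch.no_grad():
    pred_ix = model(context_vector).argmax(dim=1).item()
print(context, '->', ix_to_word[pred_ix], '(true target:', target, ')')

# The learned EMBEDDING_DIM-dimensional vector for a single word:
print(model.embeddings.weight[word_to_ix['process']].shape)  # torch.Size([100])

On a toy corpus this small, the prediction mostly confirms that the model can fit the data; meaningful word vectors would require a much larger corpus.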

Reposted from: http://svwpi.baihongyu.com/
