网创优客建站品牌官网
为成都网站建设公司企业提供高品质网站建设
热线:028-86922220
成都专业网站建设公司

定制建站费用3500元

符合中小企业对网站设计、功能常规化式的企业展示型网站建设

成都品牌网站建设

品牌网站建设费用6000元

本套餐主要针对企业品牌型网站、中高端设计、前端互动体验...

成都商城网站建设

商城网站建设费用8000元

商城网站建设因基本功能的需求不同费用上面也有很大的差别...

成都微信网站建设

手机微信网站建站3000元

手机微信网站开发、微信官网、微信商城网站...

建站知识

当前位置:首页 > 建站知识

怎么使用pytorch框架

这篇文章主要讲解了“怎么使用pytorch框架”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“怎么使用pytorch框架”吧!

专注于为中小企业提供成都网站设计、网站建设服务,电脑端+手机端+微信端的三站合一,更高效的管理,为中小企业商城免费做网站提供优质的服务。我们立足成都,凝聚了一批互联网行业人才,有力地推动了1000多家企业的稳健成长,帮助中小企业通过网站建设实现规模扩充和转变。

中文新闻情感分类 Bert-Pytorch-transformers

使用pytorch框架以及transformers包,以及Bert的中文预训练模型

文件目录

data

Train_DataSet.csv

Train_DataSet_Label.csv

main.py

NewsData.py

#main.py

from transformers import BertTokenizer

from transformers import BertForSequenceClassification

from transformers import BertConfig

from transformers import BertPreTrainedModel

import torch

import torch.nn as nn

from transformers import BertModel

import time

import argparse

from NewsData import NewsData

import os

def get_train_args():

parser=argparse.ArgumentParser()

parser.add_argument('--batch_size',type=int,default=10,help = '每批数据的数量')

parser.add_argument('--nepoch',type=int,default=3,help = '训练的轮次')

parser.add_argument('--lr',type=float,default=0.001,help = '学习率')

parser.add_argument('--gpu',type=bool,default=True,help = '是否使用gpu')

parser.add_argument('--num_workers',type=int,default=2,help='dataloader使用的线程数量')

parser.add_argument('--num_labels',type=int,default=3,help='分类类数')

parser.add_argument('--data_path',type=str,default='./data',help='数据路径')

opt=parser.parse_args()

print(opt)

return opt

def get_model(opt):

#类方法.from_pretrained()获取预训练模型,num_labels是分类的类数

model = BertForSequenceClassification.from_pretrained('bert-base-chinese',num_labels=opt.num_labels)

return model

def get_data(opt):

#NewsData继承于pytorch的Dataset类

trainset = NewsData(opt.data_path,is_train = 1)

trainloader=torch.utils.data.DataLoader(trainset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers)

testset = NewsData(opt.data_path,is_train = 0)

testloader=torch.utils.data.DataLoader(testset,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)

return trainloader,testloader

def train(epoch,model,trainloader,testloader,optimizer,opt):

print('\ntrain-Epoch: %d' % (epoch+1))

model.train()

start_time = time.time()

print_step = int(len(trainloader)/10)

for batch_idx,(sue,label,posi) in enumerate(trainloader):

if opt.gpu:

sue = sue.cuda()

posi = posi.cuda()

label = label.unsqueeze(1).cuda()

optimizer.zero_grad()

#输入参数为词列表、位置列表、标签

outputs = model(sue, position_ids=posi,labels = label)

loss, logits = outputs[0],outputs[1]

loss.backward()

optimizer.step()

if batch_idx % print_step == 0:

print("Epoch:%d [%d|%d] loss:%f" %(epoch+1,batch_idx,len(trainloader),loss.mean()))

print("time:%.3f" % (time.time() - start_time))

def test(epoch,model,trainloader,testloader,opt):

print('\ntest-Epoch: %d' % (epoch+1))

model.eval()

total=0

correct=0

with torch.no_grad():

for batch_idx,(sue,label,posi) in enumerate(testloader):

if opt.gpu:

sue = sue.cuda()

posi = posi.cuda()

labels = label.unsqueeze(1).cuda()

label = label.cuda()

else:

labels = label.unsqueeze(1)

outputs = model(sue, labels=labels)

loss, logits = outputs[:2]

_,predicted=torch.max(logits.data,1)

total+=sue.size(0)

correct+=predicted.data.eq(label.data).cpu().sum()

s = ("Acc:%.3f" %((1.0*correct.numpy())/total))

print(s)

if __name__=='__main__':

opt = get_train_args()

model = get_model(opt)

trainloader,testloader = get_data(opt)

if opt.gpu:

model.cuda()

optimizer=torch.optim.SGD(model.parameters(),lr=opt.lr,momentum=0.9)

if not os.path.exists('./model.pth'):

for epoch in range(opt.nepoch):

train(epoch,model,trainloader,testloader,optimizer,opt)

test(epoch,model,trainloader,testloader,opt)

torch.save(model.state_dict(),'./model.pth')

else:郑州治疗妇科哪个医院好 http://www.120kdfk.com/

model.load_state_dict(torch.load('model.pth'))

print('模型存在,直接test')

test(0,model,trainloader,testloader,opt)

#NewsData.py

from transformers import BertTokenizer

from transformers import BertForSequenceClassification

from transformers import BertConfig

from transformers import BertPreTrainedModel

import torch

import torch.nn as nn

from transformers import BertModel

import time

import argparse

class NewsData(torch.utils.data.Dataset):

def __init__(self,root,is_train = 1):

self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

self.data_num = 7346

self.x_list = []

self.y_list = []

self.posi = []

with open(root + '/Train_DataSet.csv',encoding='UTF-8') as f:

for i in range(self.data_num+1):

line = f.readline()[:-1] + '这是一个中性的数据'

data_one_str = line.split(',')[len(line.split(','))-2]

data_two_str = line.split(',')[len(line.split(','))-1]

if len(data_one_str) < 6:

z = len(data_one_str)

data_one_str = data_one_str + ',' + data_two_str[0:min(200,len(data_two_str))]

else:

data_one_str = data_one_str

if i==0:

continue

word_l = self.tokenizer.encode(data_one_str, add_special_tokens=False)

if len(word_l)<100:

while(len(word_l)!=100):

word_l.append(0)

else:

word_l = word_l[0:100]

word_l.append(102)

l = word_l

word_l = [101]

word_l.extend(l)

self.x_list.append(torch.tensor(word_l))

self.posi.append(torch.tensor([i for i in range(102)]))

with open(root + '/Train_DataSet_Label.csv',encoding='UTF-8') as f:

for i in range(self.data_num+1):

#print(i)

label_one = f.readline()[-2]

if i==0:

continue

label_one = int(label_one)

self.y_list.append(torch.tensor(label_one))

#训练集或者是测试集

if is_train == 1:

self.x_list = self.x_list[0:6000]

self.y_list = self.y_list[0:6000]

self.posi = self.posi[0:6000]

else:

self.x_list = self.x_list[6000:]

self.y_list = self.y_list[6000:]

self.posi = self.posi[6000:]

self.len = len(self.x_list)

def __getitem__(self, index):

return self.x_list[index], self.y_list[index],self.posi[index]

def __len__(self):

return self.len

感谢各位的阅读,以上就是“怎么使用pytorch框架”的内容了,经过本文的学习后,相信大家对怎么使用pytorch框架这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是创新互联,小编将为大家推送更多相关知识点的文章,欢迎关注!


名称栏目:怎么使用pytorch框架
网站地址:http://bjjierui.cn/article/peschd.html

其他资讯