主页 > 游戏开发  > 

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务 1. 认为主流程code2. NLP 对话和预测基本均属于分类任务详细见3. Tensorborad

1. 认为主流程code import time import torch import numpy as np from train_eval import train, init_network from importlib import import_module import argparse from tensorboardX import SummaryWriter ###制定参数 --model TextRNN parser = argparse.ArgumentParser(description='Chinese Text Classification') parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer') parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained') parser.add_argument('--word', default=False, type=bool, help='True for word, False for char') args = parser.parse_args() if __name__ == '__main__': dataset = 'THUCNews' # 数据集 # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random embedding = 'embedding_SougouNews.npz' if args.embedding == 'random': embedding = 'random' model_name = args.model #TextCNN, TextRNN, if model_name == 'FastText': from utils_fasttext import build_dataset, build_iterator, get_time_dif embedding = 'random' else: from utils import build_dataset, build_iterator, get_time_dif x = import_module('models.' + model_name) config = x.Config(dataset, embedding) np.random.seed(1) torch.manual_seed(1) torch.cuda.manual_seed_all(1) torch.backends.cudnn.deterministic = True # 保证每次结果一样 start_time = time.time() print("Loading data...") vocab, train_data, dev_data, test_data = build_dataset(config, args.word) train_iter = build_iterator(train_data, config) dev_iter = build_iterator(dev_data, config) test_iter = build_iterator(test_data, config) time_dif = get_time_dif(start_time) print("Time usage:", time_dif) # train config.n_vocab = len(vocab) model = x.Model(config).to(config.device) writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime())) if model_name != 'Transformer': init_network(model) print(model.parameters) train(config, model, train_iter, dev_iter, test_iter,writer)

RNN

class Model(nn.Module): def __init__(self, config): super(Model, self).__init__() if config.embedding_pretrained is not None: self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False) else: self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1) self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, bidirectional=True, batch_first=True, dropout=config.dropout) self.fc = nn.Linear(config.hidden_size * 2, config.num_classes) def forward(self, x): x, _ = x out = self.embedding(x) # [batch_size, seq_len, embeding]=[128, 32, 300] out, _ = self.lstm(out) out = self.fc(out[:, -1, :]) # 句子最后时刻的 hidden state return out

TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

标签:

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战由讯客互联游戏开发栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战