深度学习笔记(4)——TextCNN、BiLSTM实现情感分类(weibo100k数据集)

深度学习笔记(4)——TextCNN、BiLSTM实现情感分类(weibo100k数据集)使用数据集 微博数据集 共有约 12 万条数据 标签数为 2

大家好,欢迎来到IT知识分享网。

0 前言

1 数据准备

1.1 路径、常量、超参数

# 路径 DATASET_PATH = '../data/weibo/weibo_senti_100k.csv' USER_DICT = '../data/weibo/user_dict.txt' # 常量 DEVICE = 'cuda:0' if torch.cuda.is_available() else "cpu" loss_func = nn.CrossEntropyLoss() loss_list, accuracy_list = [], [] # 超参数 MAX_LEN = 200 # 语句长度 BATCH_SIZE = 128 # 批次大小 EMBEDDING_SIZE = 600 # embedding层大小 WINDOWS_SIZE = (2, 3, 4) # 窗口大小 FEATURE_SIZE = 200 # 特征大小 N_CLASSES = 2 # 类别数 EPOCHS = 10 # 批次大小 

1.2 加载数据集

# 数据探索性分析 def eda(): # 加载数据 dataset = pd.read_csv(DATASET_PATH) data1 = dataset['review'].iloc[:20000].values.tolist() data2 = dataset['review'].iloc[20000:40000].values.tolist() data3 = dataset['review'].iloc[40000:60000].values.tolist() data4 = dataset['review'].iloc[60000:80000].values.tolist() data5 = dataset['review'].iloc[80000:].values.tolist() data6 = dataset['review'].iloc[:].values.tolist() datas = [data1, data2, data3, data4, data5, data6] labels = dataset['label'] print(labels.value_counts()) # 文本特征提取器 return datas, labels 

难点
数据量过大、处理低频词耗时较长(算法时间复杂度为 O 2 O^2 O2
原因
去低频词时,双重for循环时间复杂度过高,文本量也大。
解决方案

  1. 减少min_threshold(效果不明显)
  2. 数据分批处理(效果显著,本文使用)
  3. 多线程优化(暂时没学会,好像效果也不好)
  4. list转为numpy(效果不明显)
  5. numba加速(效果不明显,不清楚什么原因)

2 文本清洗

corpus = [ '[鼓掌]//@权金城崔洪峰:扩散@权金城彭涌 @权金城-崔成哲 //@思想聚焦:转发微博', 'UP!虽然你很不和谐//@风言疯语LaiN胖子:为啥你不关注别人,却要别人关注你?学名人啊?[嘻嘻] //@ponponxu:转发微博。', '[鼓掌] //@金朝顺:帅哥美女如云~恭喜开课!@魏英俊 @solonso //@洪璐葫芦:恭喜了!//@昆晏:转发微博 转发微博', '#轻松一刻# 笑成狗了!主人太有才了![哈哈] #哈哈#', '#约惠海航 圆梦飞翔#【惠享直减】购票购票购票,直减直减直减[打哈欠] #测试#这不冲突,也很科学,来海航官网购票,管够,管实惠【惠享直减】http://t.cn/zRpYB9r [嘻嘻] 每天500个名额,20元的直减,ok的赶快来赞 今天第二波15点开始~', '激动人心的时刻[心]#微动日照#传播大赛大奖ipad实图奉上!感谢@日照市旅游局官方微博 的好活动.东方太阳城给了我太多惊喜,美食霸占味蕾,美景俘获视觉[爱你]仙山兔耳鳗鱼香螺,故地重游也仍有遗憾.日照,美就一个字,我还会再来的@日照旅游王立新@日照旅游-日出先照当属日照@日照旅游咨询网@山海美景', '#昆航动态#2010年11月6日,在昆明市创业投资引导基金推介暨颁奖晚宴上,2010年11月昆明航空有限公司董事长王清民(图中左五)从昆明市委常委、副市长刘光溪手上接过#2010泛亚地区最具投资潜质十强企业#证书和奖杯。[鼓掌] 昆明航空成为500多家报名企业中唯一一家获奖的航空企业。[礼花] http://sinaurl.cn/h4QFmF', '回复@夜里梵高:君亭的家门向每个游子敞开!欢迎回家![鼓掌] //@夜里梵高:我想回家!哈哈哈亲亲 //@杭州君亭湖滨酒店:君亭,你在杭州的另一个家', '... 。。。 !!! !!! ??? ??? ?! 。。。。 !!!!!! ' ] 
  1. 话题
  2. 转发微博
  3. 回复@usename
  4. @username(空格)
  5. 含有隐含意义的标点符号,例如。。。 ??? !!!(中文文化真是博大精深)
  6. 网址
  7. 时间
    通过设立正则表达式匹配实现替换这些这些对情感分析无关联的信息
    最后,清楚文本中的英文字符、中英文符号等
import string import jieba import re from zhon import hanzi from tqdm import tqdm class WeiBoTextCleaner: def __init__(self, corpus): self.corpus = corpus self.new_corpus = [] def extract_topic(self, sent): """ 提取话题 #*# 【*】 :return: """ pattern1 = re.compile('【([^】]+)】') pattern2 = re.compile('#([^#]+)#') sent = re.sub(pattern1, '', sent) sent = re.sub(pattern2, '', sent) return sent def extract_forward(self, sent): """ 提取转发微博 转发微博 :return: """ pattern = re.compile('转发微博') return re.sub(pattern, '', sent) def extract_reply(self, sent): """ 提取回复@username 回复@username: :return: """ pattern = re.compile('回复@[a-zA-Z\u4e00-\u9fa5_0-9-]+') return re.sub(pattern, '', sent) def extract_username(self, sent): """ 提取用户名 @username(空格) :return: """ pattern = re.compile('@[a-zA-Z\u4e00-\u9fa5_0-9-]+') return re.sub(pattern, '', sent) def extract_emotional_punctuation(self, sent): """ 提取含有隐含意义的标点符号 ... !!! ??? ?! 。。。 !!! ??? ?! :return: """ pattern1 = re.compile('。{3,}') pattern2 = re.compile(r'\.{3,}') pattern3 = re.compile('!{3,}') pattern4 = re.compile('!{3,}') pattern5 = re.compile(r'\?{3,}') pattern6 = re.compile('?{3,}') pattern7 = re.compile(r'\?!') pattern8 = re.compile(r'?!') sent = re.sub(pattern1, '自定义一', sent) sent = re.sub(pattern2, '自定义一', sent) sent = re.sub(pattern3, '自定义二', sent) sent = re.sub(pattern4, '自定义二', sent) sent = re.sub(pattern5, '自定义三', sent) sent = re.sub(pattern6, '自定义三', sent) sent = re.sub(pattern7, '自定义四', sent) sent = re.sub(pattern8, '自定义四', sent) return sent def extract_weblink(self, sent): """ 提取网址 http://* :return: """ pattern = re.compile('http://[0-9a-zA-Z./]+') return re.sub(pattern, '', sent) def extract_time(self, sent): """ 提取时间 *年*月*日 *年*月 *月*日 :return: """ pattern1 = re.compile(r'\d{4}年\d{1,2}月\d{1,2}日') pattern2 = re.compile(r'\d{4}年\d{1,2}月') pattern3 = re.compile(r'\d{1,2}月\d{1,2}日') sent = re.sub(pattern1, '', sent) sent = re.sub(pattern2, '', sent) sent = re.sub(pattern3, '', sent) return sent def clear_character(self, sent): """ 清楚无效字符 :param sent: :return: """ pattern1 = re.compile('[a-zA-Z0-9]') # 英文字符和数字 pattern2 = re.compile(r'[^\s::' + '\u4e00-\u9fa5]+') # 表情和其他字符 pattern3 = re.compile('[%s]+' % re.escape(string.punctuation + hanzi.punctuation)) # 标点符号 sent = re.sub(pattern1, '', sent) sent = re.sub(pattern2, '', sent) sent = re.sub(pattern3, '', sent) sent = ''.join(sent.split()) # 去除空白 return sent def execute(self): for sentence in tqdm(self.corpus): sentence = self.extract_forward(sentence) sentence = self.extract_reply(sentence) sentence = self.extract_username(sentence) sentence = self.extract_topic(sentence) sentence = self.extract_weblink(sentence) sentence = self.extract_time(sentence) sentence = self.extract_emotional_punctuation(sentence) sentence = self.clear_character(sentence) self.new_corpus.append(sentence) return self.new_corpus 

3 分词

# 去低频词 def remove_words(corpus, delete_list): for seg_list in tqdm(corpus): for seg in seg_list: if seg in delete_list: seg_list.remove(seg) return corpus # 分词 def tokenizer(corpus, min_threshold, i): t1 = time.time() # 加载用户字典 jieba.load_userdict(USER_DICT) corpus = list(map(jieba.lcut, corpus)) # 去低频词 print('去低频词') word_list = [] for seg_list in tqdm(corpus): word_list.extend(seg_list) counter = Counter(word_list) delete_list = [] # 要去除的词 for k, v in counter.items(): if v < min_threshold: delete_list.append(k) print(f'词总数:{ 
     len(word_list)}') print(f'要去除低频词数量:{ 
     len(delete_list)}') corpus = remove_words(corpus, delete_list) print(len(corpus)) print('序列化列表') with open(f'../data/weibo/corpus{ 
     i}.pkl', 'wb') as f: pickle.dump(corpus, f) t2 = time.time() print(f'共耗时{ 
     t2 - t1}秒') # 合并pkl def combine_pkl(paths: list): sentences = [] for path in paths: with open(path, 'rb') as f: sentences.extend(pickle.load(f)) return sentences 

4 工具类、文本向量化

工具类

import torch from tqdm import tqdm from sklearn.metrics import accuracy_score # 生成word2index def compute_word2index(sentences, word2index): for sentences in sentences: for word in sentences: if word not in word2index: word2index[word] = len(word2index) # word2index存储的是索引 return word2index # 生成sent2index def compute_sent2index(sentence, max_len, word2index): sent2index = [word2index.get(word, 0) for word in sentence] if len(sentence) < max_len: sent2index += (max_len - len(sentence)) * [0] else: sent2index = sentence[:max_len] return sent2index # 文本表示 def text_embedding(sentences, max_len): # 生成词向量与句向量 word2index = { 
   "PAD": 0} word2index = compute_word2index(sentences, word2index) sent2indexs = [] for sent in tqdm(sentences): sentence = compute_sent2index(sent, max_len, word2index) sent2indexs.append(sentence) return word2index, sent2indexs # 计算准确率 def get_accuracy(model, datas, labels): out = torch.softmax(model(datas), dim=1, dtype=torch.float32) predictions = torch.max(input=out, dim=1)[1] # 最大值的索引 y_predict = predictions.to('cpu').data.numpy() y_true = labels.to('cpu').data.numpy() accuracy = accuracy_score(y_true, y_predict) # 准确率 return accuracy 

5 模型构建

TextCNN见上一篇文章

BiLSTM
网络结构

代码

import torch from torch import nn, optim class BiLSTM(nn.Module): def __init__(self, num_embeddings, embedding_dim, hidden_size, num_layers, num_classes, device): super(BiLSTM, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.device = device # 词嵌入层 self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) # LSTM self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True) # Dropout层 self.dropout = nn.Dropout(p=0.5) # 全连接层 self.fc = nn.Linear(in_features=hidden_size * 2, out_features=num_classes) def forward(self, x): x = self.embed(x) # [batch_size, max_len, 100] h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(self.device) c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(self.device) out, (h_n, c_n) = self.lstm(x, (h0, c0)) output_fw = h_n[-2, :, :] # 正向最后一次输出 output_bw = h_n[-1, :, :] # 反向最后一词输出 out = torch.concat([output_fw, output_bw], dim=1) # [batch_size, hidden_size*2] # x = torch.softmax(x, dim=1) # x = self.fc(out[:, -1, :]) x = self.fc(out) return x 

6 评估

# 训练 def train(model, dataloaer, optimizer, epoch): model.train() # 模型训练 for i, (datas, labels) in enumerate(dataloaer): # 设备转换 datas = datas.to(DEVICE) labels = labels.to(DEVICE) # 计算结果 out = model(datas) # 计算损失值 loss = loss_func(out, labels) # 梯度清零 optimizer.zero_grad() # 反向传播 loss.backward() # 梯度更新 optimizer.step() # 打印损失值 if i % 300 == 0: loss_list.append(loss.item()) accuracy = get_accuracy(model, datas, labels) accuracy_list.append(accuracy) print('Train Epoch:%d Loss:%0.6f Accuracy:%0.6f' % (epoch, loss.item(), accuracy)) # 绘制曲线 def plot_curve(epochs, accuracy_list, loss_list, model_name): # 计算平均值 accuracy_array = np.array(accuracy_list).reshape(epochs, -1) accuracy_array = np.mean(accuracy_array, axis=1) loss_array = np.array(loss_list).reshape(epochs, -1) loss_array = np.mean(loss_array, axis=1) # 绘制Loss曲线 plt.rcParams['figure.figsize'] = (16, 8) plt.subplots(1, 2) plt.subplot(1, 2, 1) plt.plot(range(epochs), loss_array) plt.xlabel('epoch') plt.ylabel('loss') plt.title('Loss Curve') plt.subplot(1, 2, 2) plt.plot(range(epochs), accuracy_array) plt.xlabel('epoch') plt.ylabel('accuracy') plt.title('Accuracy Cure') plt.savefig(f'../figure/weibo_{ 
     model_name}.png') 

TextCNN
在这里插入图片描述

BiLSTM
在这里插入图片描述
两种模型仅训练了几个epoch,在训练集上的准确率均达到了98.5%

7 总览

def execute(): # # EDA datas, labels = eda() # 数据清洗 for i, data in enumerate(datas): print(f'数据清洗 第{ 
     i+1}份') cleaner = WeiBoTextCleaner(data) corpus = cleaner.execute() # 分词 tokenizer(corpus, 25, i + 1) # 合并pickle paths = [f'../data/weibo/corpus{ 
     i}.pkl' for i in range(1, 7)] sentences = combine_pkl(paths) print(len(sentences)) # 文本表示 print('文本表示') word2index, sent2index = text_embedding(sentences, MAX_LEN) # 装载数据集 train_dataset = MyDataSet(sent2index, labels) dataloader_train = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8) # 构建模型 vocab_size = len(word2index) # TextCNN # model = TextCNN(vocab_size=vocab_size, embedding_dim=EMBEDDING_SIZE, windows_size=WINDOWS_SIZE, # max_len=MAX_LEN, feature_size=FEATURE_SIZE, n_class=N_CLASSES).to(DEVICE) # optimizer = optim.Adam(model.parameters(), lr=0.001) print('GPU_Allocated:%d' % torch.cuda.memory_allocated()) model = BiLSTM(num_embeddings=vocab_size, embedding_dim=MAX_LEN, hidden_size=MAX_LEN, num_layers=2, num_classes=N_CLASSES, device=DEVICE).to(DEVICE) optimizer = optim.Adam(model.parameters(), lr=0.001) print('GPU_Allocated:%d' % torch.cuda.memory_allocated()) # 模型训练 for i in range(EPOCHS): print(f'{ 
     i+1}/{ 
     EPOCHS}') train(model, dataloader_train, optimizer, i+1) # 模型保存 torch.save(model.state_dict(), '../model/bilstem_weibo.pkl') # 绘制曲线 plot_curve(EPOCHS, accuracy_list, loss_list, 'BiLSTM') if __name__ == '__main__': execute() # test_model(64826) 

8 实时测试

# 实时检测 def test_model(vocab_size): # 加载模型 model = TextCNN(vocab_size=vocab_size, embedding_dim=EMBEDDING_SIZE, windows_size=WINDOWS_SIZE, max_len=MAX_LEN, feature_size=FEATURE_SIZE, n_class=N_CLASSES).to(DEVICE) model.load_state_dict(torch.load('../model/textcnn_weibo.pkl')) warnings.filterwarnings(action='ignore') while True: sentence = input("检测您的微博") data = [sentence] # 处理 cleaner = WeiBoTextCleaner(data) corpus = cleaner.execute() jieba.load_userdict(USER_DICT) corpus = list(map(jieba.lcut, corpus)) word2index, sent2index = text_embedding(corpus, MAX_LEN) datas = sent2index datas = torch.LongTensor(datas).to(DEVICE) # 预测 out = model(datas) out = torch.softmax(out, dim=1, dtype=torch.float32) predictions = torch.max(input=out, dim=1)[1] y_predict = predictions.to('cpu').data.numpy() if y_predict[0] == 1: print('积极') else: print('消极') 

在这里插入图片描述
在这里插入图片描述

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/147395.html

(0)
上一篇 2025-04-07 20:33
下一篇 2025-04-07 20:45

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信