大家好,欢迎来到IT知识分享网。
写一个DQN的程序,并解释一下其原理
其实DQN就是 Q-Learning 算法 + 神经网络。我们知道,Q-Learning 算法需要维护一张 Q 表格,按照下面公式来更新:

然后学习的过程就是更新 这张 Q表格,如下图所示:

而DQN就是用神经网络来代替这张 Q 表格,其余相同,如下图:

具体参见下面的链接:
https://blog.csdn.net/november_chopin/article/details/
以下是使用PyTorch实现的DQN(Deep Q-Network)程序及其原理详解,以及与Q-learning的区别分析。
import gym import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import random # 深度Q网络模型 class DQN(nn.Module): def __init__(self, state_dim, action_dim): super(DQN, self).__init__() self.fc = nn.Sequential( nn.Linear(state_dim, 24), nn.ReLU(), nn.Linear(24, 24), nn.ReLU(), nn.Linear(24, action_dim) ) def forward(self, x): return self.fc(x) # 经验回放缓冲区 class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size) def __len__(self): return len(self.buffer) # 超参数 BATCH_SIZE = 64 GAMMA = 0.99 EPS_START = 1.0 EPS_END = 0.01 EPS_DECAY = 0.995 TARGET_UPDATE = 10 # 初始化环境、网络和优化器 env = gym.make('CartPole-v1') state_dim = env.observation_space.shape[0] action_dim = env.action_space.n policy_net = DQN(state_dim, action_dim) target_net = DQN(state_dim, action_dim) target_net.load_state_dict(policy_net.state_dict()) optimizer = optim.Adam(policy_net.parameters(), lr=0.001) buffer = ReplayBuffer(10000) epsilon = EPS_START # 训练循环 for episode in range(500): state = env.reset() total_reward = 0 while True: # ε-贪婪策略选择动作 if np.random.rand() < epsilon: action = env.action_space.sample() else: with torch.no_grad(): state_tensor = torch.FloatTensor(state).unsqueeze(0) q_values = policy_net(state_tensor) action = q_values.argmax().item() # 执行动作并存储经验 next_state, reward, done, _ = env.step(action) buffer.push(state, action, reward, next_state, done) total_reward += reward state = next_state # 经验回放训练 if len(buffer) >= BATCH_SIZE: transitions = buffer.sample(BATCH_SIZE) batch = list(zip(*transitions)) # 转换为张量 state_batch = torch.FloatTensor(np.array(batch[0])) action_batch = torch.LongTensor(batch[1]) reward_batch = torch.FloatTensor(batch[2]) next_state_batch = torch.FloatTensor(np.array(batch[3])) done_batch = torch.FloatTensor(batch[4]) # 计算当前Q值 current_q = policy_net(state_batch).gather(1, action_batch.unsqueeze(1)) # 计算目标Q值 next_q = target_net(next_state_batch).max(1)[0].detach() target_q = reward_batch + (1 - done_batch) * GAMMA * next_q # 计算损失并优化 loss = nn.MSELoss()(current_q.squeeze(), target_q) optimizer.zero_grad() loss.backward() optimizer.step() if done: break # 更新目标网络和ε if episode % TARGET_UPDATE == 0: target_net.load_state_dict(policy_net.state_dict()) epsilon = max(EPS_END, epsilon * EPS_DECAY) print(f"Episode {episode}, Total Reward: {total_reward}") env.close()
一、DQN原理说明
- 核心思想:
DQN结合了深度神经网络和Q-learning,用神经网络近似Q值函数: - Q(s,a)≈Qθ(s,a)Q(s,a)≈Qθ(s,a)
- 其中θ表示神经网络的参数。
- 关键技术:
- 经验回放(Experience Replay):存储历史经验(s, a, r, s’, done),训练时随机采样打破数据相关性。
- 目标网络(Target Network):使用独立的网络计算目标Q值,避免Q值估计的波动。
- 损失函数:
- L(θ)=E[(r+γmaxa′Qθ−(s′,a′)−Qθ(s,a))2]L(θ)=E[(r+γa′maxQθ−(s′,a′)−Qθ(s,a))2]
- 其中θ⁻为目标网络参数,定期从主网络复制。
二、DQN vs Q-learning 的区别
特性 |
Q-learning |
DQN |
状态表示 |
离散状态(表格存储) |
连续/高维状态(神经网络近似) |
动作空间 |
适合低维离散动作 |
可处理高维离散动作 |
数据使用效率 |
单步更新后丢弃经验 |
经验回放提高数据利用率 |
稳定性 |
直接更新可能导致震荡 |
目标网络稳定训练过程 |
适用场景 |
小规模离散问题(如Grid World) |
复杂环境(如Atari游戏、机器人控制) |
三、关键改进点
- 函数近似:DQN使用神经网络替代Q表,解决维度灾难问题。
- 经验回放:打破时间相关性,提高数据利用率。
- 目标网络:延迟更新目标值计算网络,防止Q值震荡。
通过结合深度学习和强化学习,DQN成为处理复杂RL任务的基础框架,后续算法(如Double DQN、Dueling DQN)均在此基础上改进。
什么是目标网络?为啥需要目标网络?
目标网络(Target Network)是深度Q网络(DQN)中的关键设计,主要用于稳定Q值的估计过程。以下是详细说明:
一、什么是目标网络?
目标网络是DQN中与主网络(Policy Network)结构完全相同但参数更新延迟的神经网络:
- 主网络(Policy Network):实时更新参数,用于选择动作和计算当前Q值。
- 目标网络(Target Network):参数周期性从主网络复制(或按比例缓慢更新),专门用于计算目标Q值。
例如,在DQN代码中:
# 每10个episode将主网络参数复制到目标网络 if episode % TARGET_UPDATE == 0: target_net.load_state_dict(policy_net.state_dict())
二、为什么需要目标网络?
(1)解决“追逐自己尾巴”问题
在普通Q-learning中,Q值的更新公式为:
Q(s,a)←Q(s,a)+α[r+γmaxa′Q(s′,a′)−Q(s,a)]Q(s,a)←Q(s,a)+α[r+γa′maxQ(s′,a′)−Q(s,a)]
- 如果使用同一个网络同时计算当前Q值和目标Q值(maxa′Q(s′,a′)maxa′Q(s′,a′)),会导致目标值随网络参数频繁变化。
- 问题:网络参数更新后,目标值会立即改变,导致Q值估计像“追逐自己尾巴”一样不稳定。
(2)目标网络的作用
通过目标网络计算目标Q值:
目标Q值=r+γmaxa′Qtarget(s′,a′)目标Q值=r+γa′maxQtarget(s′,a′)
- 目标网络的参数固定:在一段时间内,目标网络的参数保持不变,使目标Q值的计算相对稳定。
- 延迟更新:定期将主网络参数同步到目标网络,确保目标值的更新是平缓的。
(3)直观类比
假设你在学习射击移动靶子:
- 没有目标网络:靶子随着你的每次射击不断快速移动(参数实时变化),难以瞄准。
- 有目标网络:靶子每隔一段时间才移动一次(延迟更新),你更容易对准目标。
四、目标网络的技术细节
(1)更新方式
- 硬更新(Hard Update):每隔固定步数完全复制主网络参数(如DQN原始论文)。
target_net.load_state_dict(policy_net.state_dict())
- target_net.load_state_dict(policy_net.state_dict())
- 软更新(Soft Update):每次按比例混合主网络和目标网络参数(如DDPG算法):
- θtarget←τθpolicy+(1−τ)θtargetθtarget←τθpolicy+(1−τ)θtarget
- 其中ττ是一个小值(如0.001)。
(2)实验效果
- 没有目标网络时,Q值估计会出现剧烈震荡,导致训练不稳定。
- 使用目标网络后,Q值收敛更平稳,模型性能显著提升。
五、目标网络 vs 普通Q-learning的对比
场景 |
普通Q-learning |
带目标网络的DQN |
目标Q值计算 |
使用当前网络参数 |
使用目标网络参数 |
更新频率 |
实时更新 |
延迟更新(如每10步更新一次) |
训练稳定性 |
容易震荡 |
显著稳定 |
适用场景 |
小规模离散问题 |
复杂高维状态空间(如Atari游戏) |
六、代码示例说明
在DQN代码中,目标网络的作用体现在以下关键部分:
# 计算目标Q值时使用目标网络 next_q = target_net(next_state_batch).max(1)[0].detach() target_q = reward_batch + (1 - done_batch) * GAMMA * next_q # 主网络计算当前Q值 current_q = policy_net(state_batch).gather(1, action_batch.unsqueeze(1)) # 损失计算(均方误差) loss = nn.MSELoss()(current_q.squeeze(), target_q)
七、总结
目标网络通过延迟参数更新,使目标Q值的计算更加稳定,解决了深度强化学习中因Q值估计波动导致的训练不收敛问题。这一设计是DQN成功应用于复杂环境(如Atari游戏)的核心技术之一。
经验回放(Experience Replay)是深度强化学习中的一种关键技术,用于存储并重复利用智能体(Agent)与环境交互的经验,以提高训练效率和稳定性。以下是详细解释:
经验回放是什么?
一、核心思想
智能体在环境中执行动作后,会将每一步的交互结果(称为“经验”)存储在一个固定容量的缓冲区(Replay Buffer)中。每条经验通常包含:
- 当前状态(state)
- 执行的动作(action)
- 获得的奖励(reward)
- 下一个状态(next_state)
- 是否终止(done,表示是否结束当前回合)
- 训练时,随机从缓冲区中抽取一批历史经验用于更新神经网络,而不是直接使用最新生成的经验。
二、代码示例
class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) # 固定容量的队列 def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size) # 随机采样
三、为什么需要经验回放?
(1)打破数据相关性
- 问题:如果直接使用连续采集的经验(如最近10步),相邻数据之间具有强时间相关性。
- 后果:神经网络容易“记住”局部的连续模式,导致训练不稳定或陷入局部最优。
- 解决:通过随机采样历史经验,使训练数据之间独立同分布(i.i.d),模拟监督学习的条件。
(2)提高数据利用率
- 普通Q-learning每步更新后丢弃经验,而经验回放可以重复利用每条经验多次。
- 例如,一条经验可能在多个训练批次中被采样,降低与环境交互的成本(尤其在现实机器人等耗时场景中)。
(3)稳定训练过程
- 随机采样减少了单步更新对当前策略的依赖,避免因策略快速变化导致的Q值震荡。
四、经验回放的工作流程
- 收集经验
智能体与环境交互,将每一步的 (state, action, reward, next_state, done) 存入缓冲区。 - 随机采样
训练时,从缓冲区中随机抽取一个小批次(如64条)经验,例如:
transitions = buffer.sample(BATCH_SIZE) batch = list(zip(*transitions))
- 计算损失并更新网络
使用这批经验计算Q值的预测值和目标值,通过梯度下降更新网络参数(如DQN中的policy_net)。
五、经验回放的直观类比
假设你在学习骑自行车:
- 没有经验回放:你只能根据最近1秒的操作(如“向左转”)调整动作,容易过度关注当前错误,导致频繁摔倒。
- 有经验回放:你从记忆库中随机回顾过去多次骑行经历(包括直行、转弯、刹车等),综合分析不同场景下的错误,学习更稳健的策略。
六、关键参数与设计
- 缓冲区容量(capacity)
- 容量过小:无法覆盖多样化的经验,可能丢失早期重要数据。
- 容量过大:旧经验长期占据缓冲区,影响对新策略的学习。
- 通常设置为数万到百万级(如DQN论文中使用100万条存储Atari游戏经验)。
- 批量大小(batch_size)
- 一般为32、64、128等,需平衡训练效率和GPU内存。
- 优先经验回放(Prioritized Experience Replay)
- 改进版方法,根据经验的“重要性”(如TD误差大小)加权采样,加速关键经验的学习。
六、经验回放 vs 在线学习
特性 |
在线学习(无经验回放) |
经验回放 |
数据使用 |
每步经验只用一次 |
重复利用历史经验 |
数据相关性 |
高(连续经验) |
低(随机采样) |
训练稳定性 |
容易波动 |
更稳定 |
适用场景 |
简单环境 |
复杂环境(如高维状态输入) |
七、总结
经验回放通过存储历史经验和随机采样,解决了强化学习中数据相关性和低效利用的问题,是DQN等算法成功应用于复杂任务(如Atari游戏、机器人控制)的关键技术之一。其核心价值在于:
让智能体像人类一样,从多样化的过往经历中学习,而非局限于眼前的最新经验。
解释一下current q,next q,target q的计算原理和过程
在深度Q网络(DQN)中,Current Q、Next Q 和 Target Q 是Q值计算的核心概念,它们的组合定义了损失函数,驱动神经网络的训练。以下是详细的原理和计算过程说明:
一、Q-learning回顾与DQN的改进
在传统Q-learning中,Q值的更新公式为:
Q(s,a)←Q(s,a)+α[r+γmaxa′Q(s′,a′)−Q(s,a)]Q(s,a)←Q(s,a)+α[r+γa′maxQ(s′,a′)−Q(s,a)]
DQN在此基础上引入神经网络和目标网络,将更新过程转化为监督学习问题。
二、三个Q值的定义与计算
1.Current Q(当前Q值)
- 定义:主网络(Policy Network)对当前状态-动作对的Q值估计。
- 计算目标:预测在状态ss下执行动作aa的预期累积奖励。
- 公式:
- Current Q=Qθ(s,a)Current Q=Qθ(s,a)
- 其中θθ是主网络的参数。
- 代码实现:
# 输入状态state_batch(批量数据),通过主网络计算所有动作的Q值 q_values = policy_net(state_batch) # 形状:[batch_size, action_dim] # 用gather选择实际执行的动作对应的Q值 current_q = q_values.gather(1, action_batch.unsqueeze(1)) # 形状:[batch_size, 1]
2.Next Q(下一状态Q值)
- 定义:目标网络(Target Network)对下一个状态s′s′的最大Q值估计。
- 计算目标:预测在状态s′s′下所有可能动作中的最优Q值。
- 公式:
- Next Q=maxa′Qθ−(s′,a′)Next Q=a′maxQθ−(s′,a′)
- 其中θ−θ−是目标网络的参数。
- 代码实现:
# 输入next_state_batch,通过目标网络计算所有动作的Q值 next_q_values = target_net(next_state_batch) # 形状:[batch_size, action_dim] # 取每行的最大Q值,并detach()阻止梯度传播 next_q = next_q_values.max(1)[0].detach() # 形状:[batch_size]
3.Target Q(目标Q值)
- 定义:基于即时奖励和折现后的Next Q,计算的Q值目标。
- 计算目标:构建一个稳定的监督信号(类似监督学习中的标签)。
- 公式:
- Target Q=r+γ⋅Next Q⋅(1−done)Target Q=r+γ⋅Next Q⋅(1−done)
- 若done=True(回合终止),则目标Q值为rr(无未来奖励)。
- 若done=False,则目标Q值为r+γ⋅Next Qr+γ⋅Next Q。
- 代码实现:
target_q = reward_batch + (1 - done_batch) * GAMMA * next_q # 形状:[batch_size]
三、损失函数与训练过程
- 损失函数:均方误差(MSE)
- L(θ)=E[(Target Q−Current Q)2]L(θ)=E[(Target Q−Current Q)2]
- 优化步骤:
- 计算current_q(主网络预测值)和target_q(目标值)。
- 通过最小化二者的均方误差更新主网络参数。
# 计算均方误差损失 loss = nn.MSELoss()(current_q.squeeze(), target_q) # 梯度清零 + 反向传播 + 参数更新 optimizer.zero_grad() loss.backward() optimizer.step()
四、计算过程图示
经验回放采样 ↓ [state, action, reward, next_state, done] │ ├───▶ Current Q ← policy_net(state)[action] │ └───▶ Target Q ← reward + γ * max(target_net(next_state)) * (1 - done) │ └─────────────┐ ↓ [计算损失并更新主网络]
五、关键设计解析
1.为什么用目标网络计算Next Q?
- 直接使用主网络会导致目标值(target_q)随主网络频繁变化,产生“追逐自己尾巴”的不稳定性。
- 目标网络参数定期更新,使目标值在一段时间内固定,稳定训练。
2.为什么用detach()?
- next_q = target_net(…).max().detach()
阻止目标网络的梯度传播到主网络,避免干扰主网络的参数更新。
3.done标志的作用
- 当回合终止时(done=True),未来奖励不存在,目标Q值仅包含当前奖励rr。
六、与传统Q-learning的对比
步骤 |
Q-learning |
DQN |
Current Q |
查表得到Q(s,a)Q(s,a) |
主网络预测Qθ(s,a)Qθ(s,a) |
Next Q |
查表得到maxQ(s′,a′)maxQ(s′,a′) |
目标网络计算maxQθ−(s′,a′)maxQθ−(s′,a′) |
Target Q |
r+γ⋅Next Qr+γ⋅Next Q |
同左,但引入done标志 |
更新方式 |
直接修改Q表 |
通过梯度下降优化神经网络 |
七、实例说明
假设一个批量数据(batch_size=3):
- state_batch: 3个状态
- action_batch: [0, 1, 0](每个状态选择的动作)
- reward_batch: [1.0, 0.5, -1.0]
- next_state_batch: 3个下一状态
- done_batch: [False, False, True]
计算过程:
- Current Q:
policy_net(state_batch) → [[0.8, 1.2], [0.5, 0.7], [1.0, 0.3]] gather(对应动作) → [0.8, 0.7, 1.0] # 取动作0、1、0的Q值
- Next Q:
target_net(next_state_batch) → [[0.9, 1.5], [0.6, 0.8], [0.0, 0.0]] max(1)[0] → [1.5, 0.8, 0.0]
- Target Q:
Target Q=[1.0+0.99⋅1.5⋅1,0.5+0.99⋅0.8⋅1,−1.0+0.99⋅0.0⋅0]=[2.485,1.292,−1.0]
- 损失计算:

八、总结
- Current Q:主网络对当前动作的价值预测。
- Next Q:目标网络对下一状态最优动作的价值估计。
- Target Q:结合即时奖励和折现后的Next Q,作为监督学习的“目标标签”。
通过分离目标网络和主网络,并利用经验回放,DQN能够稳定地训练深度神经网络,解决传统Q-learning在高维状态空间下的局限性。
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/178193.html