大家好,欢迎来到IT知识分享网。
TorchDrug教程–逆合成
教程来源TorchDrug开源
目录
- TorchDrug安装
- 分子数据结构
- 属性预测
- 预训练的分子表示
- 分子生成
- 逆合成
- 知识图推理
反合成是药物发现的一项基本任务。给定一个目标分子,反合成的目标是确定一组可以产生目标的反应物。
在这个例子中,我们将展示如何使用G2Gs框架预测逆合成。G2Gs首先识别反应中心,即产物中产生的键。根据反应中心,产物被分解成几个合成子,每个合成子被转化为一个反应物。
准备数据
我们使用标准USPTO50k数据集。该数据集包含50k分子及其合成途径。首先,让我们下载并加载数据集。这可能需要一段时间。有两种模式来加载数据集。reaction模式将数据集加载为(reactants
, product
)对,用于中心识别。synthon模式将数据集作为(reactant
,synthon
)对加载,用于synthon
完成。
from torchdrug import data, datasets, utils reaction_dataset = datasets.USPTO50k("~/molecule-datasets/", atom_feature="center_identification", kekulize=True) synthon_dataset = datasets.USPTO50k("~/molecule-datasets/", as_synthon=True, atom_feature="synthon_completion", kekulize=True)
然后我们将数据集中的一些样本可视化。对于反应数据集,我们可以使用connected components()将反应物图和生成物图拆分为单个分子。注意USPTO50k忽略了所有非目标产品,所以右边只有一个产品。
from torchdrug.utils import plot for i in range(2): sample = reaction_dataset[i] reactant, product = sample["graph"] reactants = reactant.connected_components()[0] products = product.connected_components()[0] plot.reaction(reactants, products)
下面是synthon数据集中对应的示例。
for i in range(3): sample = synthon_dataset[i] reactant, synthon = sample["graph"] plot.reaction([reactant], [synthon])
为了确保两个数据集使用相同的split,我们可以在调用split()
之前设置随机种子。
import torch torch.manual_seed(1) reaction_train, reaction_valid, reaction_test = reaction_dataset.split() torch.manual_seed(1) synthon_train, synthon_valid, synthon_test = synthon_dataset.split()
中心识别
现在我们定义我们的模型。我们使用一个关系图卷积网络(RGCN)作为我们的表示模型,并包装它来完成中心识别任务。注意,这里也可以使用其他图表示学习模型。
from torchdrug import core, models, tasks reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim, hidden_dims=[256, 256, 256, 256, 256, 256], num_relation=reaction_dataset.num_bond_type, concat_hidden=True) reaction_task = tasks.CenterIdentification(reaction_model, feature=("graph", "atom", "bond"))
reaction_optimizer = torch.optim.Adam(reaction_task.parameters(), lr=1e-3) reaction_solver = core.Engine(reaction_task, reaction_train, reaction_valid, reaction_test, reaction_optimizer, gpus=[0], batch_size=128) reaction_solver.train(num_epoch=50) reaction_solver.evaluate("valid") reaction_solver.save("g2gs_reaction_model.pth")
验证集上的计算结果可能如下所示
accuracy: 0.
我们可以从我们的模型中展示一些预测。为了多样性,我们收集了4种不同反应类型的样品。
batch = [] reaction_set = set() for sample in reaction_valid: if sample["reaction"] not in reaction_set: reaction_set.add(sample["reaction"]) batch.append(sample) if len(batch) == 4: break batch = data.graph_collate(batch) batch = utils.cuda(batch) result = reaction_task.predict_synthon(batch)
下面的代码可视化了基本事实以及我们对样本的预测。我们用蓝色代表基本事实,红色代表错误的预测,紫色代表正确的预测。
def atoms_and_bonds(molecule, reaction_center): is_reaction_atom = (molecule.atom_map > 0) & \ (molecule.atom_map.unsqueeze(-1) == \ reaction_center.unsqueeze(0)).any(dim=-1) node_in, node_out = molecule.edge_list.t()[:2] edge_map = molecule.atom_map[molecule.edge_list[:, :2]] is_reaction_bond = (edge_map > 0).all(dim=-1) & \ (edge_map == reaction_center.unsqueeze(0)).all(dim=-1) atoms = is_reaction_atom.nonzero().flatten().tolist() bonds = is_reaction_bond[node_in < node_out].nonzero().flatten().tolist() return atoms, bonds products = batch["graph"][1] reaction_centers = result["reaction_center"] for i, product in enumerate(products): true_atoms, true_bonds = atoms_and_bonds(product, product.reaction_center) true_atoms, true_bonds = set(true_atoms), set(true_bonds) pred_atoms, pred_bonds = atoms_and_bonds(product, reaction_centers[i]) pred_atoms, pred_bonds = set(pred_atoms), set(pred_bonds) overlap_atoms = true_atoms.intersection(pred_atoms) overlap_bonds = true_bonds.intersection(pred_bonds) atoms = true_atoms.union(pred_atoms) bonds = true_bonds.union(pred_bonds) red = (1, 0.5, 0.5) blue = (0.5, 0.5, 1) purple = (1, 0.5, 1) atom_colors = {
} bond_colors = {
} for atom in atoms: if atom in overlap_atoms: atom_colors[atom] = purple elif atom in pred_atoms: atom_colors[atom] = red else: atom_colors[atom] = blue for bond in bonds: if bond in overlap_bonds: bond_colors[bond] = purple elif bond in pred_bonds: bond_colors[bond] = red else: bond_colors[bond] = blue plot.highlight(product, atoms, bonds, atom_colors, bond_colors)
合成纤维完成
类似地,我们在synthon数据集上训练synthon完成模型。
synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim, hidden_dims=[256, 256, 256, 256, 256, 256], num_relation=synthon_dataset.num_bond_type, concat_hidden=True) synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))
synthon_optimizer = torch.optim.Adam(synthon_task.parameters(), lr=1e-3) synthon_solver = core.Engine(synthon_task, synthon_train, synthon_valid, synthon_test, synthon_optimizer, gpus=[0], batch_size=128) synthon_solver.train(num_epoch=10) synthon_solver.evaluate("valid") synthon_solver.save("g2gs_synthon_model.pth")
我们可以得到一些结果
bond accuracy: 0. node in accuracy: 0. node out accuracy: 0. stop accuracy: 0. total accuracy: 0.
然后,我们执行束搜索,以产生候选反应物。
batch = [] reaction_set = set() for sample in synthon_valid: if sample["reaction"] not in reaction_set: reaction_set.add(sample["reaction"]) batch.append(sample) if len(batch) == 4: break batch = data.graph_collate(batch) batch = utils.cuda(batch) reactants, synthons = batch["graph"] reactants = reactants.ion_to_molecule() predictions = synthon_task.predict_reactant(batch, num_beam=10, max_prediction=5) synthon_id = -1 i = 0 titles = [] graphs = [] for prediction in predictions: if synthon_id != prediction.synthon_id: synthon_id = prediction.synthon_id.item() i = 0 graphs.append(reactants[synthon_id]) titles.append("Truth %d" % synthon_id) i += 1 graphs.append(prediction) if reactants[synthon_id] == prediction: titles.append("Prediction %d-%d, Correct!" % (synthon_id, i)) else: titles.append("Prediction %d-%d" % (synthon_id, i)) # reset attributes so that pack can work properly mols = [graph.to_molecule() for graph in graphs] graphs = data.PackedMolecule.from_molecule(mols) graphs.visualize(titles, save_file="uspto50k_synthon_valid.png", num_col=6)
逆合成
给定训练过的模型,我们可以将它们组合成一个端点管道进行逆向合成。这是通过将两个子任务包裹在一个逆合成任务中来完成的。
注意,如果您从未声明reaction_task
和synthon_task
的求解器,那么在将它们组合到管道中之前,您需要手动调用它们的preprocess()
方法。
# reaction_task.preprocess(reaction_train, None, None) # synthon_task.preprocess(synthon_train, None, None) task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2, num_synthon_beam=5, max_prediction=10)
管道将对来自两个子任务的预测之间的所有可能组合执行波束搜索。为了演示,我们使用一个较小的光束尺寸,并且只对验证集的子集进行评估。注意,如果我们给光束搜索更多的预算,结果会更好。
from torch.utils import data as torch_data lengths = [len(reaction_valid) // 10, len(reaction_valid) - len(reaction_valid) // 10] reaction_valid_small = torch_data.random_split(reaction_valid, lengths)[0] optimizer = torch.optim.Adam(task.parameters(), lr=1e-3) solver = core.Engine(task, reaction_train, reaction_valid_small, reaction_test, optimizer, gpus=[0], batch_size=32)
要加载两个子任务的参数,我们只需load_optimizer
。注意负载优化器应该设置为False
以避免冲突。
solver.load("g2gs_reaction_model.pth", load_optimizer=False) solver.load("g2gs_synthon_model.pth", load_optimizer=False) solver.evaluate("valid")
反合成的准确性可能接近于以下
top-1 accuracy: 0.47541 top-3 accuracy: 0. top-5 accuracy: 0. top-10 accuracy: 0.
以下是验证集中样本的前1个预测
batch = [] reaction_set = set() for sample in reaction_valid: if sample["reaction"] not in reaction_set: reaction_set.add(sample["reaction"]) batch.append(sample) if len(batch) == 4: break batch = data.graph_collate(batch) batch = utils.cuda(batch) predictions, num_prediction = task.predict(batch) products = batch["graph"][1] top1_index = num_prediction.cumsum(0) - num_prediction for i in range(len(products)): reactant = predictions[top1_index[i]].connected_components()[0] product = products[i].connected_components()[0] plot.reaction(reactant, product)
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/158266.html