Commit afdd120a authored by Lukas Pravda's avatar Lukas Pravda
Browse files

remove empty H labels from annotations file

parent 8ceadca2
__version__ = '0.5.6'
__version__ = '0.5.7'
......@@ -23,6 +23,7 @@ import re
import xml.etree.ElementTree as ET
from collections import OrderedDict
from sys import platform
from numpy.core.multiarray import result_type
import rdkit
from PIL import Image, ImageDraw, ImageFont
......@@ -40,9 +41,13 @@ def save_no_image(path_to_image, default_msg=None, width=200):
path_to_image (str): path to the image
width (int, optional): Defaults to 200. width of the image
"""
if path_to_image.split('.')[-1] == "svg":
svg = _svg_no_image_with_id(default_msg, width) if default_msg else _svg_no_image(width)
with open(path_to_image, 'w') as f:
if path_to_image.split(".")[-1] == "svg":
svg = (
_svg_no_image_with_id(default_msg, width)
if default_msg
else _svg_no_image(width)
)
with open(path_to_image, "w") as f:
f.write(svg)
else:
_png_no_image(path_to_image, width)
......@@ -63,26 +68,36 @@ def draw_molecule(mol, drawer, file_name, wedge_bonds, atom_highlight, bond_high
ids and RGB colors.
"""
try:
copy = rdkit.Chem.Draw.rdMolDraw2D.PrepareMolForDrawing(mol, wedgeBonds=wedge_bonds,
kekulize=True, addChiralHs=True)
copy = rdkit.Chem.Draw.rdMolDraw2D.PrepareMolForDrawing(
mol, wedgeBonds=wedge_bonds, kekulize=True, addChiralHs=True
)
except (RuntimeError, ValueError):
try:
copy = rdkit.Chem.Draw.rdMolDraw2D.PrepareMolForDrawing(mol, wedgeBonds=False,
kekulize=True, addChiralHs=True)
copy = rdkit.Chem.Draw.rdMolDraw2D.PrepareMolForDrawing(
mol, wedgeBonds=False, kekulize=True, addChiralHs=True
)
except (RuntimeError, ValueError):
copy = rdkit.Chem.Draw.rdMolDraw2D.PrepareMolForDrawing(mol, wedgeBonds=False,
kekulize=True, addChiralHs=False)
copy = rdkit.Chem.Draw.rdMolDraw2D.PrepareMolForDrawing(
mol, wedgeBonds=False, kekulize=True, addChiralHs=False
)
if bond_highlight is None:
drawer.DrawMolecule(copy, highlightAtoms=atom_highlight.keys(),
highlightAtomColors=atom_highlight)
drawer.DrawMolecule(
copy,
highlightAtoms=atom_highlight.keys(),
highlightAtomColors=atom_highlight,
)
else:
drawer.DrawMolecule(copy, highlightAtoms=atom_highlight.keys(),
highlightAtomColors=atom_highlight,
highlightBonds=bond_highlight.keys(), highlightBondColors=bond_highlight)
drawer.DrawMolecule(
copy,
highlightAtoms=atom_highlight.keys(),
highlightAtomColors=atom_highlight,
highlightBonds=bond_highlight.keys(),
highlightBondColors=bond_highlight,
)
drawer.FinishDrawing()
with open(file_name, 'w') as f:
with open(file_name, "w") as f:
svg = drawer.GetDrawingText()
f.write(svg)
......@@ -127,85 +142,146 @@ def convert_svg(svg_string, ccd_id, mol: rdkit.Chem.Mol):
json serialization.
"""
result_bag = OrderedDict([
('ccd_id', ccd_id),
('resolution', {}),
('atoms', []),
('bonds', [])
])
result_bag = OrderedDict(
[("ccd_id", ccd_id), ("resolution", {}), ("atoms", []), ("bonds", [])]
)
svg_string = _fix_svg(svg_string)
svg = ET.fromstring(svg_string)
atom_elem = svg.findall('{http://www.w3.org/2000/svg}circle')
bond_elem = svg.findall('{http://www.w3.org/2000/svg}path')
dimensions_svg = svg.find('{http://www.w3.org/2000/svg}rect')
label_elem = svg.findall('{http://www.w3.org/2000/svg}text')
atom_elem = svg.findall("{http://www.w3.org/2000/svg}circle")
bond_elem = svg.findall("{http://www.w3.org/2000/svg}path")
dimensions_svg = svg.find("{http://www.w3.org/2000/svg}rect")
label_elem = svg.findall("{http://www.w3.org/2000/svg}text")
kd_tree = None
for atom_svg in atom_elem:
atom_id_str = re.search(r'\d+', atom_svg.attrib.get('class')).group(0)
atom_id = int(atom_id_str)
atoms = _parse_atoms_from_svg(atom_elem, mol)
bonds = _parse_bonds_from_svg(bond_elem, mol)
atom_centers = [[atom["x"], atom["y"]] for atom in atoms]
kd_tree = KDTree(atom_centers)
_parse_labels_from_svg(label_elem, kd_tree, atoms)
result_bag['atoms'] = atoms
result_bag['bonds'] = bonds
if atom_id >= mol.GetNumAtoms():
result_bag["resolution"] = {
"x": float(dimensions_svg.attrib.get("width")),
"y": float(dimensions_svg.attrib.get("height")),
}
return result_bag
def _parse_atoms_from_svg(atom_elements, mol: rdkit.Chem.Mol):
"""[summary]
Args:
atom_elements ([type]): [description]
mol (rdkit.Chem.Mol): [description]
Returns:
[type]: [description]
"""
result = []
for atom_svg in atom_elements:
atom_id_str = re.search(r"\d+", atom_svg.attrib.get("class")).group(0)
atom_id = int(atom_id_str)
if atom_id >= mol.GetNumAtoms():
continue
temp = {
"name": mol.GetAtomWithIdx(atom_id).GetProp("name"),
"labels": [],
"x": float(atom_svg.attrib.get("cx")),
"y": float(atom_svg.attrib.get("cy")),
}
result.append(temp)
return result
def _parse_labels_from_svg(label_elements, kd_tree, atoms):
"""Parse atom label information from the SVG.
Args:
label_elements ([type]):
kd_tree (KDTree): Kdtree with atom proximities
atoms (list of dict): JSON representation of atoms
"""
for label_svg in label_elements:
x = label_svg.attrib.get("x")
y = label_svg.attrib.get("y")
if "nan" in x or "nan" in y: # check for broken labels with 'nan' and '-nan'
continue
temp = {
'name': mol.GetAtomWithIdx(atom_id).GetProp('name'),
'labels': [],
'x': float(atom_svg.attrib.get('cx')),
'y': float(atom_svg.attrib.get('cy'))
"x": float(x),
"y": float(y),
"style": label_svg.attrib.get("style"),
"dominant-baseline": label_svg.attrib.get("dominant-baseline"),
"text-anchor": label_svg.attrib.get("text-anchor"),
"tspans": [],
}
result_bag['atoms'].append(temp)
filtered_tspans = filter(
lambda x: x.text is not None,
label_svg.findall("{http://www.w3.org/2000/svg}tspan"),
)
for tspan in filtered_tspans:
if tspan.text == "H": # get rid of H as we do not have any connection to them anyway in 2D.
continue
tspan_item = {
"value": tspan.text,
"style": ""
if tspan.attrib.get("style") is None
else tspan.attrib.get("style"),
}
atom_centers = [[atom['x'], atom['y']] for atom in result_bag['atoms']]
kd_tree = KDTree(atom_centers)
temp["tspans"].append(tspan_item)
nearest_index = kd_tree.query([temp["x"], temp["y"]])[1]
if temp['tspans']:
atoms[nearest_index]["labels"].append(temp)
def _parse_bonds_from_svg(bond_elements, mol):
"""[summary]
for bond_svg in bond_elem:
if 'class' not in bond_svg.attrib or 'bond-selector' in bond_svg.attrib['class']:
Args:
bond_elements ([type]): [description]
mol ([type]): [description]
Returns:
[type]: [description]
"""
result = []
for bond_svg in bond_elements:
if (
"class" not in bond_svg.attrib
or "bond-selector" in bond_svg.attrib["class"]
):
continue
bond_id_str = re.search(r'\d+', bond_svg.attrib['class']).group(0)
bond_id_str = re.search(r"\d+", bond_svg.attrib["class"]).group(0)
bond_id = int(bond_id_str)
if bond_id >= mol.GetNumBonds():
continue
bond = mol.GetBondWithIdx(bond_id)
temp = {
'bgn': bond.GetBeginAtom().GetProp('name'),
'end': bond.GetEndAtom().GetProp('name'),
'coords': bond_svg.attrib.get('d'),
'style': bond_svg.attrib.get('style')
"bgn": bond.GetBeginAtom().GetProp("name"),
"end": bond.GetEndAtom().GetProp("name"),
"coords": bond_svg.attrib.get("d"),
"style": bond_svg.attrib.get("style"),
}
result_bag['bonds'].append(temp)
for label_svg in label_elem:
x = label_svg.attrib.get('x')
y = label_svg.attrib.get('y')
if 'nan' in x or 'nan' in y: # check for broken labels with 'nan' and '-nan'
continue
temp = {
'x': float(x),
'y': float(y),
'style': label_svg.attrib.get('style'),
'dominant-baseline': label_svg.attrib.get('dominant-baseline'),
'text-anchor': label_svg.attrib.get('text-anchor'),
'tspans': [{
'value': tspan.text,
'style': '' if tspan.attrib.get('style') is None else tspan.attrib.get('style')
}
for tspan in filter(lambda x: x.text is not None, label_svg.findall('{http://www.w3.org/2000/svg}tspan'))]
}
nearest_index = kd_tree.query([temp['x'], temp['y']])[1]
result_bag['atoms'][nearest_index]['labels'].append(temp)
result_bag['resolution'] = {
'x': float(dimensions_svg.attrib.get('width')),
'y': float(dimensions_svg.attrib.get('height'))
}
result.append(temp)
return result
return result_bag
def _fix_svg(svg_string):
......@@ -220,10 +296,10 @@ def _fix_svg(svg_string):
str: Fixed XML representation as a string.
"""
svg_string = re.sub('<sub>', '&lt;sub&gt;', svg_string)
svg_string = re.sub('</sub>', '&lt;/sub&gt;', svg_string)
svg_string = re.sub('<sup>', '&lt;sup&gt;', svg_string)
svg_string = re.sub('</sup>', '&lt;/sup&gt;', svg_string)
svg_string = re.sub("<sub>", "&lt;sub&gt;", svg_string)
svg_string = re.sub("</sub>", "&lt;/sub&gt;", svg_string)
svg_string = re.sub("<sup>", "&lt;sup&gt;", svg_string)
svg_string = re.sub("</sup>", "&lt;/sup&gt;", svg_string)
return svg_string
......@@ -249,8 +325,13 @@ def _png_no_image(path_to_image, width):
black = (0, 0, 0)
img = Image.new("RGBA", (width, width), white)
draw = ImageDraw.Draw(img)
draw.multiline_text((width / 4, width / 3), "No image\n available",
font=font, align='center', fill=black)
draw.multiline_text(
(width / 4, width / 3),
"No image\n available",
font=font,
align="center",
fill=black,
)
draw = ImageDraw.Draw(img)
img.save(path_to_image)
......@@ -314,13 +395,13 @@ def _supply_font():
Returns:
str: path to the font
"""
font = ''
font = ""
if platform == "linux" or platform == "linux2":
font = '/usr/share/fonts/gnu-free/FreeSans.ttf'
font = "/usr/share/fonts/gnu-free/FreeSans.ttf"
elif platform == "darwin":
font = '/Library/Fonts/arial.ttf'
font = "/Library/Fonts/arial.ttf"
elif platform == "win32":
font = 'c:\\windows\\font\\arial.ttf'
font = "c:\\windows\\font\\arial.ttf"
if os.path.isfile(font):
return font
......
......@@ -12,15 +12,12 @@ from pdbeccdutils.core.depictions import DepictionManager, DepictionSource
from pdbeccdutils.tests.tst_utilities import cif_filename
collision_free_templates = [
('hem', ['HEM', 'HEA', 'HEB', 'HEC', 'HDD', 'HEG']),
('porphycene', ['HNN', 'HME']),
('ru_complex', ['11R']),
("hem", ["HEM", "HEA", "HEB", "HEC", "HDD", "HEG"]),
("porphycene", ["HNN", "HME"]),
("ru_complex", ["11R"]),
]
collision_templates = [
('cube', ['SF4', '0KA', '1CL']),
('adamantane', ['ADM'])
]
collision_templates = [("cube", ["SF4", "0KA", "1CL"]), ("adamantane", ["ADM"])]
depictions = DepictionManager()
......@@ -33,32 +30,36 @@ def load_molecule(id_):
class TestWriteImg:
@staticmethod
def test_file_generated(tmpdir): # tmpdir is a fixture with temporary directory
mol = load_molecule('ATP')
path = str(tmpdir.join('atp_test.svg'))
mol = load_molecule("ATP")
path = str(tmpdir.join("atp_test.svg"))
mol.export_2d_svg(path)
assert os.path.isfile(path)
@staticmethod
@pytest.mark.parametrize("ccd_id,expected,names", [
("NAG", 'C8', True),
("ATP", 'C5&apos;', True),
("08T", 'BE', True),
("BCD", 'C66', True),
("ATP", '<rect', False),
("08T", '<rect', False),
("10R", '<rect', False),
("0OD", '<rect', False),
])
@pytest.mark.parametrize(
"ccd_id,expected,names",
[
("NAG", "C8", True),
("ATP", "C5&apos;", True),
("08T", "BE", True),
("BCD", "C66", True),
("ATP", "<rect", False),
("08T", "<rect", False),
("10R", "<rect", False),
("0OD", "<rect", False),
],
)
def test_image_generation_with_names(tmpdir, ccd_id, expected, names):
mol = load_molecule(ccd_id)
path = str(tmpdir.join('{}_{}.svg'.format(ccd_id, 'names' if names else 'no_names')))
path = str(
tmpdir.join("{}_{}.svg".format(ccd_id, "names" if names else "no_names"))
)
mol.export_2d_svg(path, names=names)
with open(path, 'r') as f:
with open(path, "r") as f:
content = f.readlines()
assert any(expected in i for i in content)
......@@ -106,27 +107,36 @@ class TestWriteImg:
return
json_obj = None
wd = tmpdir_factory.mktemp('svg_json_test')
out_file = os.path.join(wd, f'{component.id}.json')
wd = tmpdir_factory.mktemp("svg_json_test")
out_file = os.path.join(wd, f"{component.id}.json")
component.compute_2d(depictions)
component.export_2d_annotation(out_file)
assert os.path.isfile(out_file)
assert os.path.getsize(out_file) > 0
with open(out_file, 'r') as fp:
with open(out_file, "r") as fp:
json_obj = json.load(fp)
assert json_obj['ccd_id'] == component.id
assert json_obj['resolution']['x'] >= 0
assert json_obj['resolution']['y'] >= 0
atom_names = [atom['name'] for atom in json_obj['atoms']]
assert len(json_obj['atoms']) == component.mol_no_h.GetNumAtoms()
assert len(json_obj['bonds']) >= component.mol_no_h.GetNumBonds()
assert any(atom['labels'] for atom in json_obj['atoms']) # do we have any labels (not all atoms has one)
assert all(atom['name'] for atom in json_obj['atoms']) # do we have atom names?
assert all(bond['bgn'] in atom_names and bond['end'] in atom_names for bond in json_obj['bonds']) # are all the atoms defined?
assert all(bond['coords'] for bond in json_obj['bonds']) # do we have coordinates?
assert all(bond['style'] for bond in json_obj['bonds']) # and its stylling?
assert json_obj["ccd_id"] == component.id
assert json_obj["resolution"]["x"] >= 0
assert json_obj["resolution"]["y"] >= 0
atom_names = [atom["name"] for atom in json_obj["atoms"]]
assert len(json_obj["atoms"]) == component.mol_no_h.GetNumAtoms()
assert len(json_obj["bonds"]) >= component.mol_no_h.GetNumBonds()
assert all(atom["name"] for atom in json_obj["atoms"]) # do we have atom names?
for atom in json_obj["atoms"]:
for l in atom["labels"]:
for t in l["tspans"]:
assert t['value'] != "H" # we do not have any H labels, because we dont have links to them.
assert all(
bond["bgn"] in atom_names and bond["end"] in atom_names
for bond in json_obj["bonds"]
) # are all the atoms defined?
assert all(
bond["coords"] for bond in json_obj["bonds"]
) # do we have coordinates?
assert all(bond["style"] for bond in json_obj["bonds"]) # and its stylling?
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment