"""
GNPS Utils - Visualizer Module
This module provides functionality to visualize data from GNPS.
Author: Shahneh
"""
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, rdFMCS
from rdkit.Chem import Draw
from modifinder.convert import to_spectrum
from modifinder.utilities.network import *
from modifinder.utilities.gnps_types import *
from modifinder.utilities.general_utils import *
import modifinder.utilities.mol_utils as mu
# import modifinder.utilities.spectra_utils as su
# from modifinder.alignment import _cosine_fast
import io
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from PIL import ImageDraw, ImageFont, Image
from matplotlib.patches import ConnectionPatch
import os
from io import BytesIO
import re
def draw_molecule(mol, output_type='png', font_size = None, label=None, label_font_size = 20, label_color = (0,0,0), label_position = 'top', **kwargs):
    """
    Draw a molecule using RDKit

    Parameters
    ----------
    mol : rdkit molecule or str
        rdkit molecule or str for SMILES or InChI or GNPS identifier (USI or Accession)
    output_type : str
        type of output (png or svg), case-insensitive
    font_size : int, optional (default=None)
        font size for the labels; when None it is set to ``width // 20``
    label : str, optional (default=None)
        label for the molecule. The label is only drawn for png output; the
        svg branch does not use it.
    label_font_size : int, optional (default=20)
        font size for the label
    label_color : tuple, optional (default=(0,0,0))
        color of the label
    label_position : str, optional (default='top')
        position of the label ('top'; any other value places it at the bottom)
    kwargs : dict
        additional arguments for drawing the molecule in rdkit
        like highlightAtoms, highlightAtomColors, highlightBonds, highlightBondColors, highlightAtomRadii, etc.
        The following keys are also read: ``size`` (a ``(width, height)`` tuple,
        default ``(400, 400)``), ``x_dim`` / ``y_dim`` (override the width / height)
        and ``extra_info`` (dict with ``annotation_scale`` and the heatmap legend
        settings filled in by :func:`draw_molecule_heatmap`).

    Returns
    -------
    img : numpy array or str
        image of the molecule (numpy array for png, svg text for svg)

    Raises
    ------
    ValueError
        if ``output_type`` is not a string, is neither png nor svg, or if the
        molecule could not be resolved

    Examples
    --------
    .. code-block:: python

        import modifinder.utilities.visualizer as mf_vis
        from matplotlib import pyplot as plt
        img = mf_vis.draw_molecule('CN1C=NC2=C1C(=O)N(C(=O)N2C)C', output_type='png', label="Caffeine")
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    .. image:: ../_static/draw_molecule1.png
        :width: 300px

    .. code-block:: python

        import modifinder.utilities.visualizer as mf_vis
        from matplotlib import pyplot as plt
        from rdkit import Chem
        def mol_with_atom_index(mol):
            for atom in mol.GetAtoms():
                atom.SetAtomMapNum(atom.GetIdx())
            return mol
        mol = Chem.MolFromSmiles('CN1C=NC2=C1C(=O)N(C(=O)N2C)C')
        mol = mol_with_atom_index(mol)
        highlightAtoms = {1, 3, 11, 8}
        img = mf_vis.draw_molecule(mol, output_type='png', label="Caffeine", highlightAtoms=highlightAtoms)
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    .. image:: ../_static/draw_molecule2.png
        :width: 300px
    """
    # TODO: test svg
    # Validate the cheap arguments first so that a bad output type fails
    # before the molecule is resolved.
    if not isinstance(output_type, str):
        raise ValueError("Output type should be a string")
    output_type = output_type.lower()
    if output_type not in ("png", "svg"):
        raise ValueError("Output type should be either png or svg")

    molecule = mu._get_molecule(mol)
    if molecule is None:
        raise ValueError("Molecule not found")

    # if size is not provided, use default size; x_dim / y_dim win over size
    x_dim, y_dim = kwargs.get("size", (400, 400))
    x_dim = kwargs.get("x_dim", x_dim)
    y_dim = kwargs.get("y_dim", y_dim)
    if font_size is None:
        font_size = x_dim // 20

    extra_info = kwargs.get("extra_info", {})

    # only the highlight options are handed to RDKit's DrawMolecule
    highlight_keys = ('highlightAtoms', 'highlightAtomColors', 'highlightBonds',
                      'highlightBondColors', 'highlightAtomRadii')
    draw_kwargs = {key: kwargs[key] for key in highlight_keys if key in kwargs}

    if output_type == "png":
        d2d = Draw.MolDraw2DCairo(x_dim, y_dim)
        # pin min and max to the same value to force a fixed font size
        d2d.drawOptions().minFontSize = font_size
        d2d.drawOptions().maxFontSize = font_size
    else:
        d2d = Draw.MolDraw2DSVG(x_dim, y_dim)
        d2d.SetFontSize(font_size)
    if "annotation_scale" in extra_info:
        d2d.drawOptions().annotationFontScale = extra_info["annotation_scale"]
    d2d.DrawMolecule(molecule, **draw_kwargs)
    d2d.FinishDrawing()
    drawing = d2d.GetDrawingText()

    if output_type == "png":
        img = mpimg.imread(io.BytesIO(drawing))
        if label:
            # go through PIL to render the text, then back to a numpy array
            labelled = Image.fromarray((img*255).astype(np.uint8))
            draw = ImageDraw.Draw(labelled)
            path_of_this_file = os.path.abspath(os.path.dirname(__file__))
            font = ImageFont.truetype(os.path.join(path_of_this_file, "fonts/NotoSans-Regular.ttf"), label_font_size)
            if label_position == 'top':
                draw.text((x_dim//2, 0), label, label_color, font=font, anchor="mt")
            else:
                draw.text((x_dim//2, y_dim), label, label_color, font=font, anchor="mb")
            img = np.array(labelled)
    else:
        img = drawing

    # the heatmap legend is requested through extra_info (see draw_molecule_heatmap)
    if extra_info.get("show_legend"):
        legend = _generate_heatmap_legend(img, extra_info["scores"], kwargs["highlightAtomColors"],
                                          x_dim, y_dim, extra_info["legend_width"], extra_info["legend_font"], output_type)
        img = _overlay_legend(img, legend, (0, y_dim-extra_info["legend_width"]-10), output_type)
    return img
def draw_modifications(mol1, mol2, output_type='png', show_legend = True, legend_font = 15, legend_position = None, highlight_common = True, highlight_added = True, highlight_removed=True, modification_only = False, **kwargs):
    """
    Draw the modifications from molecule 1 to molecule 2

    Parameters
    ----------
    mol1 : rdkit molecule, str
        rdkit molecule or str for SMILES or InChI or GNPS identifier (USI or Accession)
    mol2 : rdkit molecule, str
        rdkit molecule or str for SMILES or InChI or GNPS identifier (USI or Accession)
    output_type : str, optional (default='png')
        type of output (png or svg), case-insensitive
    show_legend : bool, optional (default=True)
        show the legend or not
    legend_font : int, optional (default=15)
        font size of the legend
    legend_position : tuple, optional (default=None)
        position of the legend; None lets the overlay helper choose it
    highlight_common : bool, optional (default=True)
        highlight the common atoms
    highlight_added : bool, optional (default=True)
        highlight the added atoms
    highlight_removed : bool, optional (default=True)
        highlight the removed atoms
    modification_only : bool, optional (default=False)
        only highlight the modification edges, if highlight_removed or highlight_added is False, this will only show the enabled ones
    kwargs : dict
        forwarded to :func:`draw_molecule` (e.g. ``label``, ``size``). Keys that
        are not draw_molecule options are also forwarded to the legend helper.

    Returns
    -------
    img : numpy array or str
        image of the modification

    Examples
    --------
    .. code-block:: python

        import modifinder.utilities.visualizer as mf_vis
        from matplotlib import pyplot as plt
        from rdkit import Chem
        smiles1, smiles2 = 'N[C@@H](CCC(=O)N[C@@H](CS)C(=O)NCC(O)=O)C(O)=O', 'CCCCCCSCC(CNCC(=O)O)NC(=O)CCC(C(=O)O)N'
        mol1 = mf_vis.draw_molecule(smiles1, label="mol1")
        mol2 = mf_vis.draw_molecule(smiles2, label="mol2")
        modification = mf_vis.draw_modifications(smiles1, smiles2, label="modifications")
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        ax[0].imshow(mol1)
        ax[1].imshow(mol2)
        ax[2].imshow(modification)
        for a in ax:
            a.axis('off')
        plt.show()

    .. image:: ../_static/draw_modifications.png
    """
    # draw_molecule lower-cases output_type itself; do the same here so the
    # legend and overlay helpers see the same value as the molecule drawing.
    if isinstance(output_type, str):
        output_type = output_type.lower()

    mol1, mol2 = mu._get_molecules(mol1, mol2)
    result = mu.get_transition(mol1, mol2)
    merged_mol = result['merged_mol']

    highlight_color_removed = (0.8, 0.1, 0.1, 0.2)
    highlight_color_removed_dark = (1, 0.1, 0, 0.8)
    highlight_color_added = (0.1, 0.8, 0.1, 0.2)
    highlight_color_added_dark = (0, 0.7, 0.2, 0.7)
    highlight_color_common = (0.1, 0.1, 0.8, 0.2)

    highlight_atoms = []
    highlight_bonds = []
    highlight_atoms_colors = dict()
    highlight_bonds_colors = dict()

    if modification_only:
        highlight_common = False

    # atoms: one colour per enabled category
    atom_groups = (
        (highlight_common, 'common_atoms', highlight_color_common),
        (highlight_added, 'added_atoms', highlight_color_added),
        (highlight_removed, 'removed_atoms', highlight_color_removed),
    )
    for enabled, key, color in atom_groups:
        if not enabled:
            continue
        for atom in result[key]:
            highlight_atoms.append(atom)
            highlight_atoms_colors[atom] = color

    modification_added_edges = result['modified_added_edges_bridge'] + result['modified_added_edges_inside']
    modification_removed_edges = result['modified_removed_edges_bridge'] + result['modified_removed_edges_inside']

    # bonds: the modification edges get the darker colour; the remaining
    # added/removed edges are only coloured when modification_only is off.
    # Later categories overwrite the colour chosen by earlier ones.
    for bondIdx in range(merged_mol.GetNumBonds()):
        bond = merged_mol.GetBondWithIdx(bondIdx)
        # edges are looked up in both directions
        pair1 = (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        pair2 = (bond.GetEndAtomIdx(), bond.GetBeginAtomIdx())
        if highlight_common:
            if pair1 in result['common_bonds'] or pair2 in result['common_bonds']:
                highlight_bonds_colors[bondIdx] = highlight_color_common
                highlight_bonds.append(bondIdx)
        if highlight_added:
            if pair1 in modification_added_edges or pair2 in modification_added_edges:
                highlight_bonds_colors[bondIdx] = highlight_color_added_dark
                highlight_bonds.append(bondIdx)
            elif (not modification_only) and (pair1 in result['added_edges'] or pair2 in result['added_edges']):
                highlight_bonds_colors[bondIdx] = highlight_color_added
                highlight_bonds.append(bondIdx)
        if highlight_removed:
            if pair1 in modification_removed_edges or pair2 in modification_removed_edges:
                highlight_bonds_colors[bondIdx] = highlight_color_removed_dark
                highlight_bonds.append(bondIdx)
            elif (not modification_only) and (pair1 in result['removed_edges'] or pair2 in result['removed_edges']):
                highlight_bonds_colors[bondIdx] = highlight_color_removed
                highlight_bonds.append(bondIdx)

    img = draw_molecule(merged_mol, highlightAtoms=highlight_atoms, highlightAtomColors=highlight_atoms_colors, highlightBonds=highlight_bonds, highlightBondColors=highlight_bonds_colors, output_type=output_type, **kwargs)

    if show_legend:
        # The legend helper passes its kwargs on to ``ax.legend``. Options meant
        # for draw_molecule (label, size, ...) are not legend options, so they
        # are stripped here instead of being forwarded.
        draw_only = {'label', 'label_font_size', 'label_color', 'label_position',
                     'font_size', 'size', 'x_dim', 'y_dim', 'extra_info',
                     'highlightAtomRadii'}
        legend_kwargs = {key: value for key, value in kwargs.items() if key not in draw_only}
        legend = _generate_modification_legend(legend_font, output_type, **legend_kwargs)
        img = _overlay_legend(img, legend, legend_position, output_type)
    return img
def draw_molecule_heatmap(mol, scores, output_type='png', show_labels = False, shrink_labels = False, annotation_scale = 1, show_legend = True, legend_width = 50, legend_font = 40, **kwargs):
    """
    Draw a molecule and color the atoms based on the scores

    Parameters
    ----------
    mol : rdkit molecule
        Molecule to Draw the heatmap for
    scores : list
        list of scores for each atom (the order should be the same as the order of atoms in the molecule).
        If any score is negative, all scores are shifted up by the absolute value of the minimum.
    output_type : str, optional (default='png')
        type of output (png or svg)
    show_labels : bool, optional (default=False)
        add labels to the atoms
    shrink_labels : bool, optional (default=False)
        shrink the labels to more compact form (score rounded to two decimals
        and multiplied by 100; zero scores get an empty label)
    annotation_scale : float, optional (default=1)
        size of the labels (font size)
    show_legend : bool, optional (default=True)
        show the legend or not
    legend_width : int, optional (default=50)
        width of the legend in pixels
    legend_font : int, optional (default=40)
        font size of the legend
    kwargs : dict
        forwarded to :func:`draw_molecule`; the highlight options and
        ``extra_info`` are overwritten by this function

    Returns
    -------
    img : numpy array or str

    Raises
    ------
    ValueError
        if the molecule could not be resolved

    Examples
    --------
    .. code-block:: python

        import modifinder.utilities.visualizer as mf_vis
        from matplotlib import pyplot as plt
        import numpy as np
        mol = Chem.MolFromSmiles('CN1C=NC2=C1C(=O)N(C(=O)N2C)C')
        scores = np.random.rand(mol.GetNumAtoms())
        img = mf_vis.draw_molecule_heatmap(mol, scores, label="Caffeine", show_labels=True, legend_font=20)
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    .. image:: ../_static/draw_molecule_heatmap.png
        :width: 300px
    """
    #TODO: test svg
    molecule = mu._get_molecule(mol)
    if isinstance(molecule, str) or molecule is None:
        raise ValueError("Molecule not found")

    # get copy of the molecule (atom notes are written onto it below)
    mol = Chem.Mol(molecule)

    # shift negative scores so the minimum becomes 0, then scale to [0, 1];
    # min / max are computed once instead of once per element
    lowest = min(scores)
    if lowest < 0:
        offset = abs(lowest)
        scores = [x + offset for x in scores]
    highest = max(scores)
    if highest > 0:
        vals = [x/highest for x in scores]
    else:
        vals = [0 for _ in scores]

    colors = {i: _get_heat_map_colors(vals[i]) for i in range(mol.GetNumAtoms())}

    if show_labels:
        for atom in mol.GetAtoms():
            score = scores[atom.GetIdx()]
            if not shrink_labels:
                lbl = str(round(score, 2))
            elif score == 0:
                lbl = ""
            else:
                # round again after scaling: 0.29 * 100 is 28.999... in
                # floating point and a bare int() would truncate it to 28
                lbl = str(int(round(round(score, 2)*100)))
            atom.SetProp('atomNote', lbl)

    # set the colors in kwargs
    kwargs['highlightAtoms'] = list(range(mol.GetNumAtoms()))
    kwargs['highlightAtomColors'] = colors
    kwargs['highlightBonds'] = []
    kwargs['extra_info'] = {
        'annotation_scale': annotation_scale,
        'output_type': output_type,
        'show_legend': show_legend,
        'legend_width': legend_width,
        'legend_font': legend_font,
        'scores': scores,
    }

    # draw the molecule
    img = draw_molecule(mol, output_type=output_type, **kwargs)
    return img
def draw_spectrum(spectrum, output_type='png', normalize_peaks = False, colors: dict = None, flipped = False, size = None, show_x_label = False, show_y_label = False, font_size = None, bar_width = 3, x_lim = None, dpi = 300, **kwargs):
    """
    Draw a spectrum

    Parameters
    ----------
    spectrum : Spectrum or list of tuples or USI or Accession
        Spectrum to draw
    output_type : str or matplotlib Axes, optional (default='png')
        type of output ('png', 'svg' or 'ax', case-insensitive). If an existing
        matplotlib Axes is passed, the spectrum is drawn on it.
    normalize_peaks : bool, optional (default=False)
        normalize the peaks or not
    colors : dict, optional (default=None)
        dictionary of colors for the peaks, keys are the indices of the peaks,
        if value is a list, the first value is the color of top half and the
        second value is the color of the bottom half (None entries fall back
        to gray). Peaks without an entry are gray. None means no custom colors.
        The dictionary and its lists are not modified.
    flipped : bool, optional (default=False)
        flip the spectrum or not
    size : tuple, optional (default=None)
        size of the figure
    show_x_label : bool, optional (default=False)
        show x label or not
    show_y_label : bool, optional (default=False)
        show y label or not
    font_size : int, optional (default=None)
        font size of the labels (applied through ``plt.rcParams``)
    bar_width : int, optional (default=3)
        width of the spectrum bars
    x_lim : tuple, optional (default=None)
        x limit of the figure
    dpi : int, optional (default=300)
        dpi of the stored image (png output only)
    kwargs : dict
        accepted for interface compatibility; not used by this function

    Returns
    -------
    img : numpy array or str
        numpy array for 'png', svg text for 'svg', the Axes for 'ax'.
        None is returned when an Axes was passed as ``output_type`` or the
        output type is not one of the above.

    Examples
    --------
    .. code-block:: python

        import modifinder.utilities.visualizer as mf_vis
        from matplotlib import pyplot as plt
        import numpy as np
        mz = [100, 200, 270, 400, 450]
        intensity = [0.1, 0.2, 0.5, 0.4, 0.3]
        colors = {
            0: 'red',
            1: ['blue', 'green'],
            3: '#FFA500',
            4: (0.9,0.9,0.2)
        }
        img = mf_vis.draw_spectrum(list(zip(mz, intensity)), output_type='png', show_x_label=True, show_y_label=True, colors=colors)
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    .. image:: ../_static/draw_spectrum.png
        :width: 300px
    """
    spectrum = to_spectrum(spectrum)
    if normalize_peaks:
        spectrum.normalize_peaks()
    if colors is None:
        colors = {}
    if isinstance(output_type, str):
        output_type = output_type.lower()

    # if output type is ax, use the ax to draw the spectrum
    if isinstance(output_type, plt.Axes):
        ax = output_type
    elif size is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = plt.subplots(figsize=size)

    if normalize_peaks:
        ax.set_ylim(0, 1)
        ax.set_yticks(list(np.arange(0, 1.001, 0.2)))

    if font_size is not None:
        # NOTE: this changes the global matplotlib font size, not only this figure
        plt.rcParams.update({'font.size': font_size})

    for i, (mz, intensity) in enumerate(zip(spectrum.mz, spectrum.intensity)):
        color = colors.get(i, 'gray')
        if isinstance(color, list):
            # two-colour peak: resolve the fallbacks locally so the caller's
            # list is left untouched
            top_color = 'gray' if color[0] is None else color[0]
            bottom_color = 'gray' if color[1] is None else color[1]
            ax.bar(mz, intensity/2, color=top_color, width=bar_width, bottom=intensity/2)
            ax.bar(mz, intensity/2, color=bottom_color, width=bar_width)
        else:
            ax.bar(mz, intensity, color=color, width=bar_width)

    if show_x_label:
        ax.set_xlabel("m/z")
    if show_y_label:
        ax.set_ylabel("Intensity")

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if flipped:
        # mirror the plot: peaks hang from an x axis drawn at the top
        ax.invert_yaxis()
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')
        ax.spines['top'].set_visible(True)
        ax.spines['bottom'].set_visible(False)

    if x_lim is not None:
        ax.set_xlim(x_lim)

    if output_type == "png":
        alignment_bytes = BytesIO()
        fig.patch.set_alpha(0) # Set the figure patch alpha to 0 (fully transparent)
        fig.savefig(alignment_bytes, format='png', bbox_inches='tight', pad_inches=0, transparent=True, dpi=dpi)
        alignment_bytes.seek(0)
        plt.close(fig)
        # convert the image to numpy array
        img = mpimg.imread(alignment_bytes)
        return img
    elif output_type == "svg":
        fig.patch.set_alpha(0)
        fig.canvas.draw()
        buffer = io.BytesIO()
        fig.savefig(buffer, format='svg', bbox_inches='tight', pad_inches=0, transparent=True)
        buffer.seek(0)
        svg = buffer.read().decode('utf-8')
        plt.close(fig)
        return svg
    elif output_type == "ax":
        return ax
def draw_alignment(spectrums, matches = None, output_type='png', normalize_peaks = False, size = None, dpi=300, draw_mapping_lines = True, ppm=40, x_lim=None, **kwargs):
    """
    Draw the alignment of multiple spectrums

    Parameters
    ----------
    spectrums : list of SpectrumTuple or list of list of tuples (mz, intensity)
        list of spectrums to draw
    matches : list of list of tuples or list of tuples, optional (default=None)
        matching between the spectrums: one list of (index1, index2) pairs for
        each consecutive pair of spectrums. A flat list of pairs is accepted
        as the single matching between two spectrums. None draws no matches.
    output_type : str, optional (default='png')
        type of output (png or svg, case-insensitive); any other value returns
        the matplotlib figure and axes
    normalize_peaks : bool, optional (default=False)
        normalize the peaks or not
    size : tuple, optional (default=None)
        size of the figure
    dpi : int, optional (default=300)
        dpi of the stored image
    draw_mapping_lines : bool, optional (default=True)
        draw the mapping lines
    ppm : float, optional (default=40)
        ppm value for the alignment
    x_lim : tuple, optional (default=None)
        x limit of the figure; None uses the m/z range over all spectrums
    kwargs : dict
        additional arguments forwarded to :func:`draw_spectrum`. ``flipped``
        is consumed here (it only flips the second of exactly two spectrums)
        and ``bar_width`` also sets the width of the mapping lines.

    Returns
    -------
    img : numpy array or str
        numpy array for png, svg text for svg, otherwise ``(fig, axs)``

    Raises
    ------
    ValueError
        if matches are given and their number is not ``len(spectrums) - 1``

    Examples
    --------
    .. code-block:: python

        import modifinder.utilities.visualizer as mf_vis
        from matplotlib import pyplot as plt
        import numpy as np
        peaks1 = list(zip([100, 200, 270, 400, 450], [0.1, 0.2, 0.5, 0.4, 0.3]))
        peaks2 = list(zip([100, 230, 350, 360, 430], [0.1, 0.3, 0.2, 0.4, 0.5]))
        peaks3 = list(zip([120, 230, 300, 380, 550], [0.1, 0.2, 0.5, 0.4, 0.3]))
        matches = [[(0, 0), (1, 1), (3, 4)], [(0, 0), (1, 1), (3, 3)]]
        img = mf_vis.draw_alignment([peaks1, peaks2, peaks3], matches=matches,
                                    output_type='png',
                                    normalize_peaks=True,
                                    x_lim=(0, 550))
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    .. image:: ../_static/draw_alignment.png
        :width: 400px
    """
    if isinstance(output_type, str):
        output_type = output_type.lower()

    spectrums = [to_spectrum(spectrum) for spectrum in spectrums]
    if normalize_peaks:
        for spectrum in spectrums:
            spectrum.normalize_peaks()

    if x_lim is None:
        x_lim = (min(spectrums[0].mz), max(spectrums[0].mz))
        for spectrum in spectrums:
            x_lim = (min(x_lim[0], min(spectrum.mz)), max(x_lim[1], max(spectrum.mz)))

    # calculating the colors
    # matches should be a list of matchings for each spectrum pairs, each matching is a list of tuples (index1, index2)
    colors = [dict() for _ in range(len(spectrums))]
    if matches is None:
        matches = []

    # Accept a flat list of (index1, index2) pairs as one matching. This has
    # to happen before the count check below, otherwise a flat list with
    # several pairs would be rejected as "too many matches".
    if len(matches) > 0:
        first = matches[0]
        if len(first) > 0 and isinstance(first[0], (int, np.integer)):
            matches = [matches]

    if len(matches) > 0 and len(matches) != len(spectrums) - 1:
        raise ValueError("Number of matches should be equal to the number of spectrums - 1")

    flipped = kwargs.pop("flipped", False)

    # each entry: (upper axis index, upper point, lower axis index, lower point, color)
    lines = []
    last_match_index = len(spectrums) - 2
    for match_index, match in enumerate(matches):
        upper = spectrums[match_index]
        lower = spectrums[match_index+1]
        upper_colors = colors[match_index]
        lower_colors = colors[match_index+1]
        for pair in match:
            temp_color = 'blue'
            if is_shifted(upper.mz[pair[0]], lower.mz[pair[1]], ppm, None):
                temp_color = 'red'

            # Inner spectrums take part in two matchings, so their peaks use a
            # [top, bottom] colour pair; the first and last spectrum use one
            # plain colour. A plain colour is simply overwritten when the same
            # peak is matched more than once.
            if match_index > 0:
                upper_colors.setdefault(pair[0], [None, None])[1] = temp_color
            else:
                upper_colors[pair[0]] = temp_color
            if match_index < last_match_index:
                lower_colors.setdefault(pair[1], [None, None])[0] = temp_color
            else:
                lower_colors[pair[1]] = temp_color

            if draw_mapping_lines:
                if flipped and match_index == 0:
                    # the lower spectrum is mirrored, so the line ends on its baseline
                    lower_y = 0
                else:
                    lower_y = lower.intensity[pair[1]]
                lines.append((match_index, (upper.mz[pair[0]], 0),
                              match_index+1, (lower.mz[pair[1]], lower_y),
                              temp_color))

    # squeeze=False always yields a 2-D array of axes, so a single spectrum
    # can be indexed the same way as several
    figure_kwargs = {} if size is None else {'figsize': size}
    fig, axs = plt.subplots(len(spectrums), 1, squeeze=False, **figure_kwargs)
    axs = axs[:, 0]

    for index, spectrum in enumerate(spectrums):
        # only the second of exactly two spectrums may be drawn mirrored
        flip_this = flipped if (len(spectrums) == 2 and index == 1) else False
        draw_spectrum(spectrum, output_type=axs[index], colors=colors[index], x_lim=x_lim, normalize_peaks=normalize_peaks, flipped=flip_this, **kwargs)

    patch_kwargs = {}
    if "bar_width" in kwargs:
        patch_kwargs['linewidth'] = max(1, kwargs["bar_width"]//1.5)
    for upper_index, upper_point, lower_index, lower_point, line_color in lines:
        con = ConnectionPatch(xyA=lower_point, xyB=upper_point, coordsA="data", coordsB="data", axesA=axs[lower_index], axesB=axs[upper_index], color=line_color, linestyle='dotted', **patch_kwargs)
        axs[lower_index].add_artist(con)

    # remove all extra paddings (but make sure the labels are not cut)
    if flipped and len(spectrums) == 2:
        # remove xtick labels
        axs[1].set_xticklabels([])
        plt.subplots_adjust(hspace=0.01, wspace=0)
    else:
        plt.subplots_adjust(hspace=0.1, wspace=0)
        # NOTE(review): tight_layout is assumed to apply to this branch only - confirm
        plt.tight_layout()

    if output_type == "png":
        alignment_bytes = BytesIO()
        fig.patch.set_alpha(0) # Set the figure patch alpha to 0 (fully transparent)
        fig.savefig(alignment_bytes, format='png', bbox_inches='tight', pad_inches=0, transparent=True, dpi=dpi)
        alignment_bytes.seek(0)
        plt.close(fig)
        # convert the image to numpy array
        img = mpimg.imread(alignment_bytes)
        return img
    elif output_type == "svg":
        fig.patch.set_alpha(0)
        fig.canvas.draw()
        buffer = io.BytesIO()
        fig.savefig(buffer, format='svg', bbox_inches='tight', pad_inches=0, transparent=True)
        buffer.seek(0)
        svg = buffer.read().decode('utf-8')
        plt.close(fig)
        return svg
    else:
        return fig, axs
def draw_frag_of_molecule(mol, fragment, output_type='png', **kwargs):
    """
    draws a fragment of the molecule. The fragment is represented by a binary string where 1 indicates the presence of the atom and 0 indicates the absence.

    Parameters
    ----------
    mol : rdkit molecule
        Molecule to draw the fragment for
    fragment : int or list
        fragment represented in base 10 where in the binary representation 1 indicates the presence of the atom and 0 indicates the absence
        (bit ``i`` corresponds to atom index ``i``), or a list of atom indices.
        Numpy integers and tuples, sets or numpy arrays of indices are accepted
        as well. Any other type highlights nothing.
    output_type : str, optional (default='png')
        type of output (png or svg)
    kwargs : dict
        additional arguments for drawing the molecule in rdkit; ``highlightAtoms``
        is overwritten by this function

    Returns
    -------
    img : numpy array or str

    Examples
    --------
    .. code-block:: python

        import modifinder.utilities.visualizer as mf_vis
        from matplotlib import pyplot as plt
        from rdkit import Chem
        mol = Chem.MolFromSmiles('CN1C=NC2=C1C(=O)N(C(=O)N2C)C')
        def mol_with_atom_index(mol):
            for atom in mol.GetAtoms():
                atom.SetAtomMapNum(atom.GetIdx())
            return mol
        mol = mol_with_atom_index(mol)
        fragment = int("110111", 2) # Convert binary to decimal
        img = mf_vis.draw_frag_of_molecule(mol, fragment, output_type='png')
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    .. image:: ../_static/draw_frag_of_molecule.png
        :width: 300px
    """
    mol = mu._get_molecule(mol)

    highlightAtoms = []
    if isinstance(fragment, (int, np.integer)):
        # convert to a Python int so the bit test cannot overflow a fixed-width
        # numpy integer
        bitmask = int(fragment)
        highlightAtoms = [i for i in range(mol.GetNumAtoms()) if bitmask & (1 << i)]
    elif isinstance(fragment, list):
        highlightAtoms = fragment
    elif isinstance(fragment, (tuple, set, frozenset, np.ndarray)):
        # other index collections used to be ignored silently; plain ints are
        # handed on to the drawing code
        highlightAtoms = [int(index) for index in fragment]

    kwargs['highlightAtoms'] = highlightAtoms
    img = draw_molecule(mol, output_type=output_type, **kwargs)
    return img
def _get_heat_map_colors(val):
"""
Get the heat map color based on an input value
:param val: float - value between 0 and 1
"""
if val == 0:
return (val, 0.2*(1-val), (1-val), 0.4)
else:
return (0.95*val, 0.2*(1-val), (1-val), val*0.3 + 0.40)
def _generate_heatmap_legend(img, scores, colors, image_width, image_height, legend_span, fontSize, output_type):
    """
    Build the heat map legend: a horizontal color gradient with a caption on
    each end

    :param img: image (not used)
    :param scores: list of scores (not used)
    :param colors: dictionary of colors (not used)
    :param image_width: int - width of the legend
    :param image_height: int - height of the image (not used)
    :param legend_span: int - height of the legend
    :param fontSize: int - font size of the captions
    :param output_type: str - type of output ("png"; anything else produces svg)
    :return: numpy array (png) or str (svg)
    """
    steps = 100

    if output_type != "png":
        band_width = image_width/steps
        band_height = legend_span
        pieces = ["""<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(image_width, legend_span)]
        # bands are emitted right to left; each one is 1% wider than its slot
        for step in range(steps, -1, -1):
            red, green, blue, alpha = _get_heat_map_colors(step/steps)
            pieces.append("""<rect x="{}" y="{}" width="{}" height="{}" fill="rgba({}, {}, {}, {})"/>""".format(
                step*band_width, 0, band_width*1.01, band_height,
                red*255, green*255, blue*255, alpha))
        pieces.append("""<text x="{}" y="{}" font-size="{}px" fill="white">{}</text>""".format(5, int(legend_span*0.8), fontSize, "Low likelihood"))
        # add text to the right side
        right_caption = "high likelihood"
        pieces.append("""<text x="{}" y="{}" font-size="{}px" fill="white" text-anchor="end">{}</text>""".format(image_width-5, int(legend_span*0.84), fontSize, right_caption))
        pieces.append("</svg>")
        return "".join(pieces)

    # png: paint the gradient column band by column band
    gradient = np.zeros((legend_span, image_width, 4))
    for step in range(steps):
        left = int(step*image_width/steps)
        right = int((step+1)*image_width/steps)
        gradient[:, left:right, :] = np.asarray(_get_heat_map_colors(step/steps))
    gradient *= 255
    legend_image = Image.fromarray(gradient.astype(np.uint8), 'RGBA')

    # add legend text
    pen = ImageDraw.Draw(legend_image)
    module_dir = os.path.abspath(os.path.dirname(__file__))
    font = ImageFont.truetype(os.path.join(module_dir, "fonts/NotoSans-Regular.ttf"), fontSize)
    middle = legend_span//2
    pen.text((5, middle), "Low likelihood", (255, 255, 255), font=font, anchor="lm")
    pen.text((image_width-5, middle), "high likelihood", (255, 255, 255), font=font, anchor="rm")
    return np.array(legend_image)
def _generate_modification_legend(fontSize, output_type, modification_categories = ['Common', 'Added', 'Removed'], **kwargs):
"""
generate legend for image
:param img: image
:param type: str - type of output (png or svg)
return: numpy array of the image
"""
colors = {'Common':(0.1, 0.1, 0.8, 0.7), 'Added':(0.1, 0.8, 0.1, 0.7), 'Removed':(0.8, 0.1, 0.1, 0.7)}
if output_type == "png":
fig, ax = plt.subplots(figsize=(1, 1))
# Plot the dummy data to create a legend
for category in modification_categories:
ax.plot([], [], color=colors[category], label=category, linewidth=fontSize//1.5)
legend_location = (1, 1)
legend = ax.legend(fontsize=fontSize, **kwargs, loc='center', bbox_to_anchor=legend_location)
ax.axis('off')
fig.patch.set_alpha(0)
fig.canvas.draw()
legend_bytes = BytesIO()
fig.savefig(legend_bytes, format='png', bbox_inches='tight', pad_inches=0.1, transparent=True)
legend_bytes.seek(0)
plt.close(fig)
legend_image = Image.open(legend_bytes)
legend_image = np.array(legend_image)
return legend_image
else:
width = int(fontSize * 7 + 20)
height = len(modification_categories) * fontSize + (len(modification_categories) - 1) * fontSize//2 + 10 + 10
svg = """<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">""".format(width, height)
# draw dashed line around the border
svg += """<rect x="0" y="0" width="{}" height="{}" fill="none" stroke="black" stroke-width="1" stroke-dasharray="5,5"/>""".format(width, height)
for i, category in enumerate(modification_categories):
color = colors[category]
svg += """<rect x="10" y="{}" width="{}" height="{}" fill="rgba({}, {}, {}, {})"/>""".format(10 + i*((fontSize*3)//2), width - 20, fontSize, color[0]*255, color[1]*255, color[2]*255, color[3])
svg += """<text x="{}" y="{}" font-size="{}px" fill="black" text-anchor="middle">{}</text>""".format(width//2, 10 + i*((fontSize*3)//2) + int(fontSize*0.84), fontSize, category)
svg += "</svg>"
return svg
def _overlay_legend(img, legend, position, output_type):
"""
Overlay the legend on the image
:param img: image
:param legend: image
:param position: tuple - position of the legend
:param output_type: str - type of output (png or svg)
"""
if output_type == "png":
if img.max() <= 1:
img = (img*255).astype(np.uint8)
else:
img = img.astype(np.uint8)
img2 = Image.fromarray(img)
legend = Image.fromarray(legend)
# if position is not provided, place the legend at the bottom right corner
if position is None:
position = (max(0, img2.size[0] - legend.size[0]), max(img2.size[1] - legend.size[1], 0))
if type(position) != tuple:
raise ValueError("Position should be a tuple")
size = (max(img2.size[0], legend.size[0] + position[0]), max(img2.size[1], legend.size[1] + position[1] ))
combined_image = Image.new("RGBA", size, img2.getpixel((0, 0)))
combined_image.paste(img2, (0, 0))
combined_image.paste(legend, position, legend)
return np.array(combined_image)
else:
trimmed_img = img.split("</svg>")[0]
# find first > and trim the legend
first_gt = legend.find(">")
trimmed_legend = legend[first_gt+1:]
# add the legend
img_size = re.search(r'width="(\d+)" height="(\d+)"', img)
legend_size = re.search(r'width="(\d+)" height="(\d+)"', legend)
if position is None:
position = (max(0, img_size[0] - legend_size[0]), max(img_size[1] - legend_size[1], 0))
if type(position) is not tuple:
raise ValueError("Position should be a tuple")
# get all the x+ and y+ values and add the position to them
x_values = re.findall(r'x="(\d+)"', trimmed_legend)
y_values = re.findall(r'y="(\d+)"', trimmed_legend)
x_values = [int(x) + position[0] for x in x_values]
y_values = [int(y) + position[1] for y in y_values]
trimmed_legend = re.sub(r'x="(\d+)"', lambda x: f'x="{x_values.pop(0)}"', trimmed_legend)
trimmed_legend = re.sub(r'y="(\d+)"', lambda x: f'y="{y_values.pop(0)}"', trimmed_legend)
svg = trimmed_img + trimmed_legend
return svg
def _cname2hex(cname):
    """
    Look up the hex code of a matplotlib color name

    :param cname: str - a matplotlib base color ('r', 'g', ...) or CSS4 color name
    :return: str - hex code such as '#ff0000', or None (after printing a
        message) when the name is not known
    """
    # dictionary. key: names, values: hex codes (CSS4) or RGB tuples (base colors)
    colors = dict(mpl.colors.BASE_COLORS, **mpl.colors.CSS4_COLORS)
    try:
        value = colors[cname]
    except KeyError:
        print(cname, ' is not registered as default colors by matplotlib!')
        return None
    if not isinstance(value, str):
        # the single-letter base colors are stored as RGB tuples, not hex
        value = mpl.colors.to_hex(value)
    return value
def _hex2rgb(hex, normalize=False):
h = hex.strip('#')
rgb = np.asarray(list(int(h[i:i + 2], 16) for i in (0, 2, 4)))
return rgb
def _handle_color(color):
if type(color) == str:
if color[0] == '#':
return _hex2rgb(color)
else:
return _hex2rgb(_cname2hex(color))
else:
return color
def return_public_functions():
    """
    Return the public functions of the module

    Returns
    -------
    dict
        maps the function name to the function object. Only draw_molecule,
        draw_molecule_heatmap and draw_frag_of_molecule are listed;
        draw_modifications, draw_spectrum and draw_alignment are currently
        left out.
    """
    return dict(
        draw_molecule=draw_molecule,
        draw_molecule_heatmap=draw_molecule_heatmap,
        draw_frag_of_molecule=draw_frag_of_molecule,
    )