# Source code for scludam.plots

# scludam, Star CLUster Detection And Membership estimation package
# Copyright (C) 2022  Simón Pedro González

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""Module for helper plotting functions."""

import warnings
from numbers import Number
from typing import List, Optional, Tuple, Union

# import matplotlib.patches as mp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import Ellipse
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler

from scludam.type_utils import ArrayLike, Numeric1DArray, Numeric2DArray, NumericArray


def _prepare_data_to_plot(
    data: Union[Numeric2DArray, pd.DataFrame], cols: Optional[List[str]] = None
):
    """Return ``data`` as a DataFrame with named columns.

    Numpy arrays are wrapped in a new DataFrame whose columns are named
    ``cols`` or, when ``cols`` is ``None``, ``"var 1"``, ``"var 2"``, ...
    Any other input (e.g. an existing DataFrame) is returned unchanged and
    ``cols`` is ignored.

    Parameters
    ----------
    data : Union[Numeric2DArray, pd.DataFrame]
        Data to be plotted. Arrays must be two dimensional.
    cols : List[str], optional
        Column names for array input, by default ``None``.

    Returns
    -------
    pd.DataFrame
        The data with named columns.

    Raises
    ------
    ValueError
        If ``cols`` is given and its length differs from the number of
        columns of the array.
    """
    if not isinstance(data, np.ndarray):
        return data

    # Tuple unpacking (rather than shape[1]) rejects non-2D arrays.
    _, n_dims = data.shape
    frame = pd.DataFrame(data)
    if cols is None:
        frame.columns = [f"var {k + 1}" for k in range(n_dims)]
        return frame
    if len(cols) != n_dims:
        raise ValueError("Data and cols must have the same length.")
    frame.columns = cols
    return frame


def color_from_proba(proba: Numeric2DArray, palette: str):
    """Create color list from palette and probabilities.

    It desaturates the colors given the probabilities.

    Parameters
    ----------
    proba : Numeric2DArray
        Membership probability array of shape (n_points, n_classes). A 1D
        array is treated as a single class, i.e. shape (n_points, 1).
    palette : str
        Name of seaborn palette.

    Returns
    -------
    List
        Color list of length n_points where each point has a color according
        to the class it belongs (the column with the highest probability).
    List
        Desaturated color list of length n_points where each point has a
        color according to the class it belongs. With several classes the
        saturation is higher if the probability is closer to 1 and lower if
        it is closer to 1 / n_classes. With a single class the probabilities
        are min-max scaled to [0, 1] and used as saturation factors.
    List
        Color list of length n_classes, defining a color for each class.
    """
    proba = np.asarray(proba)
    if proba.ndim == 1:
        proba = np.atleast_2d(proba).T
    _, n_classes = proba.shape
    color_palette = sns.color_palette(palette, n_classes)

    winners = np.argmax(proba, axis=1)
    c = [color_palette[k] for k in winners]

    if n_classes == 1:
        # MinMaxScaler returns an (n_points, 1) array: flatten it so that each
        # factor handed to sns.desaturate is a scalar, not a 1-element array
        # (which would make desaturate return a tuple of arrays).
        factors = MinMaxScaler().fit_transform(proba).ravel()
        proba_c = [sns.desaturate(color_palette[0], f) for f in factors]
    else:
        # Map the winning probability from [1 / n_classes, 1] onto [0, 1].
        # Clip, because rows that do not sum exactly to 1 (or floating point
        # noise) can fall slightly outside that range and sns.desaturate
        # raises for factors outside [0, 1].
        chance = 1 / n_classes
        factors = np.clip((proba.max(axis=1) - chance) / (1 - chance), 0, 1)
        proba_c = [
            sns.desaturate(color_palette[k], f) for k, f in zip(winners, factors)
        ]
    return c, proba_c, color_palette
def scatter3dprobaplot(
    data: Union[Numeric2DArray, pd.DataFrame],
    proba: Numeric2DArray,
    cols: Optional[List[str]] = None,
    x: int = 0,
    y: int = 1,
    z: int = 2,
    palette: str = "viridis",
    desaturate: bool = True,
    **kwargs,
):
    """Create a 3D probability plot.

    It represents the provided data in x, y and z. It passes kwargs to
    matplotlib scatter3D [1]_

    Parameters
    ----------
    data : Union[Numeric2DArray, pd.DataFrame]
        Data to be plotted.
    proba : Numeric2DArray
        Array of membership probabilities, of shape (n_points, n_classes).
        A 1D array is treated as a single class.
    cols : List[str], optional
        List of ordered column names, by default ``None``.
        Used if data is provided as numpy array.
    x : int, optional
        Index of the x variable, by default 0.
    y : int, optional
        Index of the y variable, by default 1.
    z : int, optional
        Index of the z variable, by default 2.
    palette : str, optional
        Seaborn palette string, by default "viridis"
    desaturate : bool, optional
        If ``True``, desaturate colors according to probability,
        by default ``True``.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the plot.
    mpl_toolkits.mplot3d.axes3d.Axes3D
        3D axes with the plotted clustering results.

    Raises
    ------
    ValueError
        If data has less than 3 columns.

    References
    ----------
    .. [1] https://matplotlib.org/stable/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html?highlight=scatter3d#mpl_toolkits.mplot3d.axes3d.Axes3D.scatter3D
    """  # noqa: E501
    data = _prepare_data_to_plot(data, cols)
    cols = data.columns
    data = data.values
    if data.shape[1] < 3:
        raise ValueError("Data must have at least 3 columns.")

    proba = np.asarray(proba)
    if proba.ndim == 1:
        proba = proba.reshape(-1, 1)

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    c, proba_c, _ = color_from_proba(proba, palette)

    # Marker size grows with the highest membership probability
    # (51 for probability 1, 1 for probability 0.5). Clip at 1 so that points
    # whose highest probability is below 0.49 (possible with 3+ classes or a
    # single class) do not get a negative marker size.
    sizes = np.clip((proba.max(axis=1).round(2) * 100).astype(int) - 49, 1, None)

    default_kws = {
        "c": proba_c if desaturate else c,
        "alpha": 1,
        "s": sizes,
    }
    default_kws.update(kwargs)
    ax.scatter3D(
        data[:, x],
        data[:, y],
        data[:, z],
        **default_kws,
    )
    ax.set_xlabel(cols[x])
    ax.set_ylabel(cols[y])
    ax.set_zlabel(cols[z])
    return fig, ax
def surfprobaplot(
    data: Union[pd.DataFrame, Numeric2DArray],
    proba: Numeric2DArray,
    x: int = 0,
    y: int = 1,
    palette: str = "viridis",
    cols: Optional[List[str]] = None,
    **kwargs,
):
    """Create surface 3D probability plot.

    It represents the provided data in x y, using a membership probability
    as height. It passes kwargs to matplotlib plot_trisurf [2]_.

    Parameters
    ----------
    data : Union[pd.DataFrame, Numeric2DArray]
        Data to be plotted.
    proba : Numeric2DArray
        Membership probability array. With a single column (or a 1D array)
        that column is used as height; otherwise column 0 is skipped and the
        highest of the remaining probabilities is used.
    x : int, optional
        Index of the x variable, by default 0
    y : int, optional
        Index of the y variable, by default 1
    palette : str, optional
        Seaborn palette string, by default "viridis"
    cols : List[str], optional
        List of ordered column names, by default ``None``.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the plot.
    mpl_toolkits.mplot3d.axes3d.Axes3D
        3D axes with the surface.

    Raises
    ------
    ValueError
        If data has less than 2 columns.
    ValueError
        If x or y parameters are invalid.

    References
    ----------
    .. [2] https://matplotlib.org/stable/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html?highlight=plot_trisurf#mpl_toolkits.mplot3d.axes3d.Axes3D.plot_trisurf
    """  # noqa: E501
    data = _prepare_data_to_plot(data, cols)
    cols = data.columns
    data = data.values
    if data.shape[1] < 2:
        raise ValueError("Data must have at least 2 columns.")
    if x == y or x >= data.shape[1] or y >= data.shape[1]:
        raise ValueError("Invalid x, y parameters.")

    proba = np.asarray(proba)
    if proba.ndim == 1 or proba.shape[1] == 1:
        z = proba.ravel()
    else:
        # Column 0 is skipped (presumably the noise class, label -1);
        # height is the best non-noise membership probability.
        z = proba[:, 1:].max(axis=1)

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    default_kws = {
        "cmap": palette,
        "shade": True,
    }
    default_kws.update(kwargs)
    ax.plot_trisurf(
        data[:, x],
        data[:, y],
        z,
        **default_kws,
    )
    ax.set_xlabel(cols[x])
    ax.set_ylabel(cols[y])
    ax.set_zlabel("proba")
    return fig, ax
def pairprobaplot(
    data: Union[Numeric2DArray, pd.DataFrame],
    proba: Numeric2DArray,
    labels: Numeric1DArray,
    cols: Optional[List[str]] = None,
    palette: str = "viridis_r",
    diag_kind: str = "kde",
    diag_kws: Optional[dict] = None,
    plot_kws: Optional[dict] = None,
    **kwargs,
):
    """Pairplot of the data and the membership probabilities.

    It passes kwargs, diag_kws and plot_kws to seaborn pairplot [3]_
    function. The input data is never modified.

    Parameters
    ----------
    data : Union[Numeric2DArray, pd.DataFrame]
        Data to be plotted.
    proba : Numeric2DArray
        Membership probability array. A 1D array is used directly as the
        per-point probability.
    labels : Numeric1DArray
        Labels of the data.
    cols : List[str], optional
        Column names, by default ``None``
    palette : str, optional
        Seaborn palette, by default "viridis_r"
    diag_kind : str, optional
        Kind of plot for diagonal, by default "kde". Valid values are
        "hist" and "kde".
    diag_kws : dict, optional
        Additional arguments for diagonal plots, by default ``None``
    plot_kws : dict, optional
        Additional arguments for off-diagonal plots, by default ``None``

    Returns
    -------
    seaborn.PairGrid
        Pairplot.

    Raises
    ------
    ValueError
        Invalid diag_kind.

    References
    ----------
    .. [3] https://seaborn.pydata.org/generated/seaborn.pairplot.html
    """
    # _prepare_data_to_plot returns DataFrames as-is; copy so the "Label"
    # and "Proba" helper columns are not added to the caller's DataFrame.
    df = _prepare_data_to_plot(data, cols).copy()
    labels = np.asarray(labels)
    proba = np.asarray(proba)

    df["Label"] = labels.astype(str)
    # Largest label first in the legend / stacking order.
    hue_order = np.sort(np.unique(labels))[::-1].astype(str).tolist()
    df["Proba"] = proba if proba.ndim == 1 else proba.max(axis=1)

    if diag_kind == "kde":
        default_diag_kws = {
            "multiple": "stack",
            "fill": True,
            "linewidth": 0,
        }
    elif diag_kind == "hist":
        default_diag_kws = {
            "multiple": "stack",
            "element": "step",
            "linewidth": 0,
        }
    else:
        raise ValueError("Invalid diag_kind")
    if diag_kws is not None:
        default_diag_kws.update(diag_kws)

    # Off-diagonal marker size follows the membership probability.
    default_plot_kws = {
        "marker": "o",
        "size": df["Proba"],
    }
    if plot_kws is not None:
        default_plot_kws.update(plot_kws)

    grid = sns.pairplot(
        df.loc[:, df.columns != "Proba"],
        hue="Label",
        hue_order=hue_order,
        palette=palette,
        diag_kind=diag_kind,
        diag_kws=default_diag_kws,
        plot_kws=default_plot_kws,
        **kwargs,
    )
    if grid.legend is not None:
        grid.legend.set_title("")
    return grid
def tsneprobaplot(
    data: Union[pd.DataFrame, Numeric2DArray],
    labels: Numeric1DArray,
    proba: Numeric2DArray,
    **kwargs,
):
    """Plot of data and membership probabilities using t-SNE projection.

    The data is projected to two dimensions with scikit-learn's ``TSNE``
    (default parameters) and drawn with a seaborn scatterplot [4]_, coloured
    by label and sized by the highest membership probability of each point.
    kwargs are forwarded to the scatterplot and override the defaults.

    Parameters
    ----------
    data : Union[pd.DataFrame, Numeric2DArray]
        Data to be plotted.
    labels : Numeric1DArray
        Labels of the data.
    proba : Numeric2DArray
        Membership probability array.

    Returns
    -------
    matplotlib.axes.Axes
        T-SNE projected plot.

    References
    ----------
    .. [4] https://seaborn.pydata.org/generated/seaborn.scatterplot.html
    """
    values = data.values if isinstance(data, pd.DataFrame) else data
    embedding = TSNE().fit_transform(values)

    frame = pd.DataFrame(
        {
            "x": embedding[:, 0],
            "y": embedding[:, 1],
            "Label": labels.astype(str),
            "Proba": proba.max(axis=1),
        }
    )

    scatter_kws = {
        "edgecolor": None,
        "size": frame["Proba"],
        "palette": "viridis_r",
        **kwargs,
    }
    # Largest label first in the legend.
    order = np.sort(np.unique(labels))[::-1].astype(str).tolist()
    return sns.scatterplot(
        data=frame,
        x="x",
        y="y",
        hue="Label",
        hue_order=order,
        **scatter_kws,
    )
def heatmap2D(
    hist2D: NumericArray,
    edges: ArrayLike,
    bin_shape: ArrayLike,
    index: Optional[ArrayLike] = None,
    annot: bool = True,
    annot_prec: int = 2,
    annot_threshold: Number = 0.1,
    ticks: bool = True,
    tick_prec: int = 2,
    **kwargs,
):
    """Create a heatmap from a 2D histogram.

    Also marks index if provided. Create ticklabels from bin centers and not
    from bin indices. kwargs are passed to seaborn.heatmap [5]_.

    Parameters
    ----------
    hist2D : NumericArray
        Histogram.
    edges : ArrayLike
        Edges.
    bin_shape : ArrayLike
        Bin shape of the histogram.
    index : ArrayLike, optional
        Index to be marked, by default ``None``
    annot : bool, optional
        Use default annotations, by default ``True``. If true, annotations
        are created taking into account the rest of annot parameters.
    annot_prec : int, optional
        Annotation number precision, by default 2
    annot_threshold : Number, optional
        Only annotate cells whose rounded value is bigger than
        annot_threshold, by default 0.1
    ticks : bool, optional
        Create ticklabels from the bin centers, by default ``True``
    tick_prec : int, optional
        Ticklabels number precision, by default 2

    Returns
    -------
    matplotlib.axes.Axes
        Heatmap. To get the figure, call ``get_figure()`` on the returned
        axes.

    References
    ----------
    .. [5] https://seaborn.pydata.org/generated/seaborn.heatmap.html
    """
    # annotations
    if annot:
        # Annotate with the (rounded) histogram value, but only the bins
        # whose rounded value is above the threshold; the rest stay blank.
        rounded = np.asarray(hist2D).round(annot_prec)
        annot_indices = np.argwhere(rounded > annot_threshold)
        annot_values = rounded[annot_indices[:, 0], annot_indices[:, 1]]
        if annot_prec == 0:
            annot_values = annot_values.astype(int)
        n_rows, n_cols = rounded.shape
        # Explicitly blank grid (np.ndarray(dtype=str) would be uninitialized).
        annot_grid = [[""] * n_cols for _ in range(n_rows)]
        for (row, col), value in zip(annot_indices, annot_values):
            annot_grid[row][col] = str(value)
        kwargs["annot"] = annot_grid
        kwargs["fmt"] = "s"
        # Copy so the caller's annot_kws dict is not modified.
        annot_kws = dict(kwargs.get("annot_kws") or {})
        annot_kws.setdefault("fontsize", 8)
        kwargs["annot_kws"] = annot_kws

    # labels
    # set tick labels as the value of the center of the bins, not the indices
    if ticks:
        labels = [
            np.round((edges[i] + bin_shape[i] / 2).astype(float), tick_prec)[:-1]
            for i in range(2)
        ]
        # Rows of the histogram follow dimension 0, columns dimension 1.
        kwargs["yticklabels"] = labels[0]
        kwargs["xticklabels"] = labels[1]

    if kwargs.get("cmap", None) is None:
        kwargs["cmap"] = "gist_yarg_r"

    hm = sns.heatmap(
        hist2D,
        **kwargs,
    )

    if index is not None:
        # add white lines around the marked (peak) bin
        hlines = [index[0], index[0] + 1]
        vlines = [index[1], index[1] + 1]
        hm.hlines(hlines, *hm.get_xlim(), color="w")
        hm.vlines(vlines, *hm.get_ylim(), color="w")

    hm.invert_yaxis()
    hm.set_xticklabels(hm.get_xticklabels(), rotation=45)
    return hm
def univariate_density_plot(
    x: Numeric1DArray,
    y: Numeric1DArray,
    ax: Optional[Axes] = None,
    figure: Optional[Figure] = None,
    figsize: Tuple[int, int] = (8, 6),
    grid: bool = True,
    **kwargs,
):
    """Plot univariate density plot.

    Create a filled lineplot given the densities for x. kwargs are passed
    to matplotlib scatter plot [6]_.

    Parameters
    ----------
    x : Numeric1DArray
        X linespace.
    y : Numeric1DArray
        Densities.
    ax : Optional[Axes], optional
        Ax to plot, by default None. If None, a new subplot is added to
        ``figure``.
    figure : Optional[Figure], optional
        Figure to plot, by default None. If None (and ``ax`` is None), a new
        figure of size ``figsize`` is created.
    figsize : Tuple[int, int], optional
        Figure size, by default (8, 6)
    grid : bool, optional
        Add grid, by default True

    Returns
    -------
    matplotlib.axes.Axes
        Axes of the plot.

    References
    ----------
    .. [6] https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html
    """
    if ax is None:
        if figure is None:
            figure = plt.figure(figsize=figsize)
        ax = figure.add_subplot(1, 1, 1)

    # Tiny pixel markers: the visual comes from the fill below, the scatter
    # only exposes the kwargs hook.
    default_kwargs = {
        "color": "blue",
        "marker": ",",
        "s": 0.01,
    }
    default_kwargs.update(kwargs)
    ax.scatter(x, y, **default_kwargs)

    # Fill the area between the density curve and 0 where the density is
    # non-negative.
    ax.fill_between(
        x,
        y,
        where=np.asarray(y) >= 0,
        interpolate=True,
        color=default_kwargs.get("color", "blue"),
    )
    # Density values are not meaningful to read off, so hide y ticks.
    ax.set_yticks([])
    if grid:
        ax.grid(True)
    return ax
def bivariate_density_plot(
    x: Numeric1DArray,
    y: Numeric1DArray,
    z: Numeric1DArray,
    levels: Optional[int] = None,
    contour_color: str = "black",
    ax: Optional[Axes] = None,
    figure: Optional[Figure] = None,
    figsize: Tuple[int, int] = (8, 6),
    colorbar: bool = True,
    title: Optional[str] = None,
    title_size: int = 16,
    grid: bool = True,
    **kwargs,
):
    """Create a bivariate density plot.

    Create a heatmap like density plot given densities in x and y. kwargs
    are passed to matplotlib imshow [7]_.

    Parameters
    ----------
    x : Numeric1DArray
        X linespace.
    y : Numeric1DArray
        Y linespace.
    z : Numeric1DArray
        Densities in x and y.
    levels : int, optional
        Number of levels to draw contour, by default None (no contour). When
        contours are drawn the image alpha defaults to 0.75 instead of 1.
    contour_color : str, optional
        Color to draw contour, by default "black"
    ax : Optional[Axes], optional
        Ax to plot, by default None
    figure : Optional[Figure], optional
        Figure to plot, by default None
    figsize : Tuple[int, int], optional
        Figure size, by default (8, 6)
    colorbar : bool, optional
        Add a colorbar, by default True
    title : Optional[str], optional
        Title to set, by default None
    title_size : int, optional
        Title size, by default 16
    grid : bool, optional
        Add grid, by default True

    Returns
    -------
    matplotlib.axes.Axes
        Axes of the plot.
    matplotlib.image.AxesImage
        Image of the plot.

    References
    ----------
    .. [7] https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
    """
    if ax is None:
        if figure is None:
            figure = plt.figure(figsize=figsize)
        ax = figure.add_subplot(1, 1, 1)

    if levels is not None:
        contour = ax.contour(x, y, z, levels, colors=contour_color)
        ax.clabel(contour, inline=True, fontsize=8)
        # make the image slightly transparent so the contours stay readable
        alpha = 0.75
    else:
        alpha = 1

    default_kws = {
        "origin": "lower",
        "aspect": "auto",
        "cmap": "inferno",
        "alpha": alpha,
    }
    default_kws.update(kwargs)
    im = ax.imshow(
        z,
        extent=[np.min(x), np.max(x), np.min(y), np.max(y)],
        **default_kws,
    )

    if colorbar:
        # Dedicated colorbar axis to the right so the main axes keep their size.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.1)
        ax.get_figure().colorbar(im, cax=cax, orientation="vertical")
    if title is not None:
        ax.set_title(title, size=title_size)
    if grid:
        ax.grid(True)
    return ax, im
# def add_label_legend(labels, palette: List[tuple], ax):
#     patches = [
#         mp.Patch(color=palette[i], label=f"Label {i}") for i in np.unique(labels)
#     ]
#     ax.legend(handles=patches)
#     return ax


def _select_labels(labels, proba, select_labels):
    """Keep only the selected labels and merge every other class into noise.

    The probability layout assumed here is: column ``k + 1`` of ``proba``
    holds the probability of label ``k`` and column 0 that of the noise
    label (-1).

    Parameters
    ----------
    labels : array of int
        Label of each point, -1 meaning noise.
    proba : 2D array of shape (n_points, n_classes)
        Membership probabilities.
    select_labels : int or sequence of int
        Labels to keep. -1 is ignored. The caller's sequence is not modified.

    Returns
    -------
    new_labels : np.ndarray
        Copy of ``labels`` where every label not selected is set to -1.
    new_proba : np.ndarray
        Array of shape (n_points, 1 + n_selected): column 0 is the summed
        probability of noise and all non-selected classes, followed by the
        selected classes in the order given. If nothing is selected, the
        inputs are returned unchanged.
    """
    if isinstance(select_labels, (int, np.integer)):
        select_labels = [select_labels]
    # Build a new list instead of calling .remove() on the caller's list.
    select_labels = [lab for lab in select_labels if lab != -1]
    if len(select_labels) == 0:
        return labels, proba

    labels = np.asarray(labels)
    proba = np.asarray(proba)

    # Keep the original label values (not proba column indices) for selected
    # points, everything else becomes noise.
    new_labels = labels.copy()
    new_labels[~np.isin(labels, select_labels)] = -1

    selected_cols = np.asarray(select_labels) + 1
    summarize_cols = np.setdiff1d(np.arange(proba.shape[1]), selected_cols)
    noise_proba = proba[:, summarize_cols].sum(axis=1)
    new_proba = np.column_stack((noise_proba, proba[:, selected_cols]))
    return new_labels, new_proba


def _select_1(proba, select_1):
    """Reduce ``proba`` to a two-class problem around column ``select_1``.

    ``select_1`` is used directly as a column index into ``proba``.
    NOTE(review): scatter2dprobaplot documents ``select_1`` as a label; under
    the layout used by ``_select_labels`` (label ``k`` in column ``k + 1``)
    that would need ``select_1 + 1`` here — confirm what callers pass.

    Parameters
    ----------
    proba : 2D array of shape (n_points, n_classes)
        Membership probabilities.
    select_1 : int
        Column of ``proba`` to keep.

    Returns
    -------
    new_labels : np.ndarray of int
        0 where ``proba[:, select_1] > 0``, -1 otherwise.
    new_proba : np.ndarray of shape (n_points, 2)
        Column 1 is ``proba[:, select_1]``, column 0 its complement.
    """
    selected = np.asarray(proba)[:, select_1]
    new_labels = np.where(selected > 0, 0, -1).astype(int)
    new_proba = np.zeros((selected.shape[0], 2))
    new_proba[:, 1] = selected
    new_proba[:, 0] = 1 - selected
    return new_labels, new_proba
def scatter2dprobaplot(
    data: pd.DataFrame,
    proba: np.ndarray,
    labels: np.ndarray,
    cols: Optional[List[str]] = None,
    palette: str = "Set1",
    select_labels: Optional[Union[List[int], int]] = None,
    select_1: Optional[int] = None,
    bg_kws: Optional[dict] = None,
    fg_kws: Optional[dict] = None,
):
    """Create a scatter plot with labels and probabilities.

    Parameters
    ----------
    data : pd.DataFrame
        Dataframe (or numpy array) with exactly 2 columns. It is not
        modified.
    proba : np.ndarray
        Probability array.
    labels : np.ndarray
        Label array.
    cols : Optional[List[str]], optional
        Axes labels to be used, by default None. If None, the columns of
        data are used (``"x"`` and ``"y"`` for numpy input).
    palette : str, optional
        Palette to be used to choose label colors, by default "Set1"
    select_labels : Optional[Union[List[int], int]], optional
        Select labels to plot, by default None. If None, all labels are
        plotted.
    select_1: Optional[int], optional
        Used to select only one of the labels. Only plots that population
        and the background (noise label -1), by default None.
    bg_kws : dict, optional
        kwargs to be passed to sns.scatterplot for the background (noise
        label [-1]) scatter plot, by default None.
    fg_kws : dict, optional
        kwargs to be passed to sns.scatterplot for the foreground (labels
        [0, 1, ...]), by default None.

    Returns
    -------
    Axes
        Axes with the plot.

    Raises
    ------
    ValueError
        If data does not have exactly 2 columns.
    ValueError
        If probability and data have different number of rows.
    """
    sns.set_style("whitegrid")
    if select_labels is not None:
        labels, proba = _select_labels(labels, proba, select_labels)
    if select_1 is not None:
        labels, proba = _select_1(proba, select_1)

    if data.shape[1] != 2:
        raise ValueError("Data must have 2 columns")

    if isinstance(data, np.ndarray):
        if cols is None:
            cols = ["x", "y"]
        df = pd.DataFrame(data, columns=cols)
    else:
        df = data
        if cols is not None:
            # Rename a copy so the caller's DataFrame keeps its columns.
            df = data.copy()
            df.columns = cols
        else:
            cols = df.columns

    if proba.shape[0] != data.shape[0]:
        raise ValueError("proba must have the same number of rows as data")

    labels = np.asarray(labels)
    if select_1 is None:
        point_proba = np.max(proba, axis=1)
    else:
        # Probability of the class each point was assigned to
        # (label -1 -> column 0, label 0 -> column 1).
        point_proba = proba[np.arange(proba.shape[0]), labels + 1]

    plotdf = pd.concat(
        [
            df.reset_index(drop=True),
            pd.DataFrame(
                np.vstack((point_proba, labels)).T,
                columns=["Probability", "Label"],
            ).reset_index(drop=True),
        ],
        axis=1,
        sort=False,
    )

    # One color per probability column; the first one is used for noise.
    label_c = sns.color_palette(palette, proba.shape[1])

    # plot background
    default_kws = {
        "s": 5,
        "alpha": 0.2,
        "color": label_c[0],
        "palette": palette,
    }
    if bg_kws:
        default_kws.update(bg_kws)
    ax = sns.scatterplot(data=plotdf[labels == -1], x=cols[0], y=cols[1], **default_kws)

    # plot foreground, sized by probability and colored by label
    default_kws = {
        "sizes": (5, 50),
        "size": "Probability",
        "hue": "Label",
        "palette": label_c[1:],
        "alpha": 0.8,
    }
    if fg_kws:
        default_kws.update(fg_kws)
    sns.scatterplot(
        ax=ax,
        data=plotdf[labels != -1],
        x=cols[0],
        y=cols[1],
        **default_kws,
    )
    return ax
def plot_objects(df: pd.DataFrame, ax: Axes, cols: List[str]):
    """Plot object dataframe in an axis.

    Object dataframe refers to a pandas dataframe created from simbad Table
    result, translated with :func:`~scludam.fetcher.simbad2gaiacolnames`.
    Stars (OTYPE ``"Star"``) are drawn as red stars, other objects as red
    squares, and every object is annotated with ``MAIN_ID(OTYPE)`` plus its
    TYPED_ID on a second line when it is not empty. The input dataframe is
    not modified.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe of objects. must contain at least "MAIN_ID", "TYPED_ID",
        "OTYPE" and the two columns in ``cols``.
    ax : Axes
        Axis to plot on.
    cols : list
        Columns in the object dataframe to plot in the x y axes of ``ax``.

    Returns
    -------
    Axes
        axis with plotted objects (unchanged if required columns are
        missing).

    Warns
    -----
    UserWarning
        If the dataframe lacks any of the required columns; nothing is
        plotted in that case.
    """
    necessary_cols = ["MAIN_ID", "TYPED_ID", "OTYPE", cols[0], cols[1]]
    if not set(necessary_cols).issubset(set(df.columns)):
        warnings.warn(
            f"Object dataframe must contain {necessary_cols} columns, not plotting"
            " objects",
            UserWarning,
        )
        return ax

    # Work on a copy so the caller's dataframe does not gain an "annot" column.
    df = df.copy()
    df["annot"] = df["MAIN_ID"].astype(str) + "(" + df["OTYPE"].astype(str) + ")"
    # Use .loc: chained indexing (df[mask]["annot"] = ...) writes to a
    # temporary copy and leaves df unchanged.
    has_typed_id = df["TYPED_ID"] != ""
    df.loc[has_typed_id, "annot"] = (
        df.loc[has_typed_id, "annot"]
        + "\n"
        + df.loc[has_typed_id, "TYPED_ID"].astype(str)
    )

    stardf = df[df["OTYPE"] == "Star"]
    nonstardf = df[df["OTYPE"] != "Star"]
    ax.plot(stardf[cols[0]], stardf[cols[1]], "*", color="red", alpha=0.5)
    ax.plot(nonstardf[cols[0]], nonstardf[cols[1]], "s", color="red", alpha=0.5)

    for row in df[[cols[0], cols[1], "annot"]].itertuples():
        _, col1, col2, annot = row
        ax.annotate(annot, (col1, col2))
    return ax
def _plot_cov_ellipse(cov, pos, nstd=2, ax=None, **kwargs):
    """Draw an ``nstd``-sigma error ellipse for a 2x2 covariance matrix.

    The ellipse is oriented along the eigenvector of the largest eigenvalue
    of ``cov`` and centered at ``pos``.

    Parameters
    ----------
    cov : array of shape (2, 2)
        Covariance matrix the ellipse is based on.
    pos : sequence of 2 numbers
        Center of the ellipse, ``[x0, y0]``.
    nstd : number, optional
        Radius of the ellipse in standard deviations, by default 2.
    ax : Axes, optional
        Axis to draw on, by default the current axis (``plt.gca()``).
    **kwargs
        Forwarded to :class:`matplotlib.patches.Ellipse`, overriding the
        defaults (no face color, black edge, linewidth 0.5, alpha 1).

    Returns
    -------
    matplotlib.patches.Ellipse
        The ellipse artist added to the axis.
    """
    target_ax = plt.gca() if ax is None else ax

    eigvals, eigvecs = np.linalg.eigh(cov)
    # Order eigenpairs from largest to smallest so column 0 is the major axis.
    order = np.argsort(eigvals)[::-1]
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]

    major = eigvecs[:, 0]
    angle = np.degrees(np.arctan2(major[1], major[0]))
    # Ellipse takes full widths (diameters), hence the factor 2.
    width, height = 2 * nstd * np.sqrt(eigvals)

    style = {
        "facecolor": "none",
        "edgecolor": "k",
        "linewidth": 0.5,
        "alpha": 1,
        **kwargs,
    }
    patch = Ellipse(xy=pos, width=width, height=height, angle=angle, **style)
    target_ax.add_artist(patch)
    return patch
def plot_kernels(ax, means, covariances, nstd=3, **kwargs):
    """Draw each 2D Gaussian kernel as an ``nstd``-sigma ellipse.

    Parameters
    ----------
    ax : Axes
        Axis to draw on.
    means : np.ndarray
        Array of shape (n_kernels, 2) with the kernel centers.
    covariances : np.ndarray
        Array of shape (n_kernels, 2, 2) with one covariance per kernel.
    nstd : int, optional
        Ellipse radius in standard deviations, by default 3.
    **kwargs
        Forwarded to every ellipse patch.

    Returns
    -------
    Axes
        ``ax`` with the ellipses added.
    """
    n_kernels = means.shape[0]
    for k in range(n_kernels):
        center = means[k]
        _plot_cov_ellipse(
            cov=covariances[k],
            pos=center,
            nstd=nstd,
            ax=ax,
            **kwargs,
        )
    return ax
def horizontal_lineplots(
    ys: List[np.ndarray], cols: Optional[List[str]] = None, **kwargs
):
    """Plot a list of 1d arrays as vertically stacked lineplots.

    One row of axes is created per array, all sharing the x axis (the
    position index), with integer tick spacing.

    Parameters
    ----------
    ys : List[np.ndarray]
        List of 1d arrays to plot, all of the same length.
    cols : List[str], optional
        Names used as y labels, by default None, meaning
        ``"col0", "col1", ...``.
    **kwargs
        Forwarded to ``sns.lineplot``, overriding the defaults
        (``marker="o"``, ``color="k"``).

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the plots.
    np.ndarray
        1D array of Axes, one per array in ``ys`` (also when ``ys`` has a
        single element).
    """
    import matplotlib.ticker as ticker

    if not cols:
        cols = [f"col{i}" for i in range(len(ys))]
    df = pd.DataFrame(dict(zip(cols, ys)))
    df["index"] = df.index

    sns.set_style("whitegrid")
    default_kws = {
        "marker": "o",
        "color": "k",
    }
    default_kws.update(kwargs)

    # squeeze=False always returns a 2D array of axes; with the default a
    # single row would come back as a bare Axes, which is not indexable.
    fig, ax = plt.subplots(nrows=len(ys), sharex=True, squeeze=False)
    ax = ax[:, 0]
    for i, col in enumerate(cols):
        sns.lineplot(data=df, x=df["index"], y=col, ax=ax[i], **default_kws)
        ax[i].xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax[i].xaxis.set_major_formatter(ticker.ScalarFormatter())
    return fig, ax