# Source code for scludam.pipeline

# scludam, Star CLUster Detection And Membership estimation package
# Copyright (C) 2022  Simón Pedro González

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.


"""Module with default Detection and Membership Estimation Pipeline.

All steps of the process can be configured following the documentation
for each function.

Examples
--------
.. literalinclude:: ../../examples/pipeline/dep.py
    :language: python
    :linenos:

.. image:: ../../examples/pipeline/dep_pmra_pmdec.png

"""

import copy
import warnings
from typing import List, Optional

import numpy as np
import pandas as pd
from astropy.table.table import Table
from attrs import Factory, define, field, validators
from beartype import beartype
from sklearn.preprocessing import MinMaxScaler, RobustScaler

from scludam.detection import CountPeakDetector, DetectionResult
from scludam.fetcher import search_objects_near_data, simbad2gaiacolnames
from scludam.masker import RangeMasker
from scludam.membership import DBME
from scludam.plots import plot_objects, scatter2dprobaplot
from scludam.shdbscan import SHDBSCAN
from scludam.stat_tests import StatTest, TestResult
from scludam.utils import Colnames


@define
class DEP:
    """Detection and membership estimation pipeline.

    Class for running detection and membership estimation in a dataframe.

    Attributes
    ----------
    detector : CountPeakDetector
        Detector to use, it is required and should not be fitted.
    det_cols : List[str]
        Columns to use for detection. Should be a subset of the columns
        in the dataframe to be used.
    sample_sigma_factor : float, optional
        Factor to multiply the sigma of the detection region, in order to
        get a region sample for the overdensity detected, by default 1.
        Take into account that the sigma used is currently the bin shape
        in each dimension.
    tests : List[StatTest], optional
        Statistical tests to use for detection, optional, by default [].
        The list can include a non fitted instance of
        :class:`~scludam.stat_tests.StatTest`.
    test_cols : List[List[str]], optional
        Columns to use for statistical tests, optional, by default [].
        Note that the list must have the same length as the tests list,
        and each item is a list of columns to use for the statistical
        test.
    test_mode : str, optional
        Mode to use for statistical tests, optional, by default 'any'.
        If 'any', the test is considered passed if any of the stat tests
        results in the rejection of their null hypothesis. Other options
        are 'all' and 'majority'. With 'majority', a region passes when at
        least half of the tests reject their null hypothesis.
    clusterer : SHDBSCAN, optional
        Clusterer to be used to get the initial probabilities. By default,
        an instance of :class:`~scludam.shdbscan.SHDBSCAN` with the
        following parameters: auto_allow_single_cluster=True,
        min_cluster_size=50, noise_proba_mode="conservative",
        cluster_proba_mode="hard", scaler=RobustScaler() (from
        sklearn.preprocessing). A deep copy of it is fitted for each
        region.
    estimator : DBME, optional
        Estimator to use for membership estimation, by default an instance
        of :class:`~scludam.membership.DBME` with the default parameters.
        A deep copy of it is fitted for each region.
    mem_cols : List[str], optional
        Columns to use for membership estimation, by default, ``det_cols``.
        Should be a subset of the columns in the dataframe to be used.
        It is recommended to use the same columns as for detection, or at
        least some of them, so center estimation can be used for better
        clustering. For the estimation process, error and correlation
        columns can be used if they are available in the dataframe. For
        example, if mem_cols is ['x', 'y'], the program will check if
        columns 'x_error', 'y_error', 'x_y_corr' are present.
    n_detected : int
        Output attribute, number of overdensities detected.
    detection_result : DetectionResult
        Output attribute, result of the detection.
    test_results : List[List[TestResult]]
        Output attribute, one list of statistical test results per
        detection region tested.
    is_clusterable : List[bool]
        Output attribute, list of decisions taken in the tests, whether
        the overdensity can be clustered. It stays empty when no tests
        are configured.
    n_estimated : int
        Output attribute, number of overdensities estimated, that is,
        overdensities that passed the tests.
    proba : np.ndarray
        Output attribute, membership probabilities of each cluster found.
        The first column holds the field (noise) probability.
    labels : np.ndarray
        Output attribute, label of each data point, starting from -1
        (noise).
    limits : List[Numeric1DArray]
        Output attribute, list of limits used for each detection region.
    masks : List[NDArray[bool]]
        Output attribute, list of masks used for each detection region.
    clusterers : List[SHDBSCAN]
        Output attribute, list of clusterers used for each detection
        region.
    estimators : List[DBME]
        Output attribute, list of estimators used for each detection
        region.

    """

    # input attributes
    detector: CountPeakDetector
    det_cols: List[str]
    tests: List[StatTest] = Factory(list)
    test_cols: List[List[str]] = field(factory=list)
    # Factories give every DEP instance its own default clusterer and
    # estimator instead of sharing one object built at class-definition time.
    clusterer: SHDBSCAN = Factory(
        lambda: SHDBSCAN(
            auto_allow_single_cluster=True,
            min_cluster_size=50,
            noise_proba_mode="conservative",
            cluster_proba_mode="hard",
            scaler=RobustScaler(),
        )
    )
    estimator: DBME = Factory(DBME)
    test_mode: str = field(
        default="any", validator=validators.in_(["any", "all", "majority"])
    )
    mem_cols: Optional[List[str]] = field(default=None)
    sample_sigma_factor: float = 1

    # output attributes
    n_detected: int = None
    n_estimated: int = None
    test_results: List[List[TestResult]] = Factory(list)
    detection_result: DetectionResult = None
    proba: np.ndarray = None
    labels: np.ndarray = None
    limits: List = Factory(list)
    masks: List = Factory(list)
    clusterers: List = Factory(list)
    estimators: List = Factory(list)
    is_clusterable: List = Factory(list)

    # internal attributes
    _df: pd.DataFrame = None
    _colnames: Colnames = None
    _objects: pd.DataFrame = None

    @test_cols.validator
    def _test_cols_validator(self, attr, value):
        # each statistical test needs its own list of columns
        if len(value) != len(self.tests):
            raise ValueError("test_cols must have the same length as tests")

    def __attrs_post_init__(self):
        """Attrs initialization.

        Do not execute.

        """
        # When mem_cols holds the same columns as det_cols (or is missing),
        # use the det_cols ordering so that detection centers line up with
        # the membership data columns.
        if self.mem_cols is None or sorted(self.mem_cols) == sorted(self.det_cols):
            self.mem_cols = list(self.det_cols)

    def _check_cols(self, cols):
        # every requested column must be recognised as a data column
        if len(self._colnames.data(cols)) != len(cols):
            raise ValueError(
                "Columns must be a subset of {}".format(self._colnames.data())
            )

    def _reset_outputs(self):
        # Clear results of a previous fit so that per-region lists do not
        # accumulate stale entries when the instance is fitted again.
        self.n_detected = None
        self.n_estimated = None
        self.detection_result = None
        self.proba = None
        self.labels = None
        self.test_results = []
        self.limits = []
        self.masks = []
        self.clusterers = []
        self.estimators = []
        self.is_clusterable = []

    def _set_noise_only(self, n: int, n_detected: int):
        # No cluster was estimated: every point is field with probability 1.
        self.n_detected = n_detected
        self.n_estimated = 0
        self.proba = np.ones(n).reshape(-1, 1)
        self.labels = np.full(n, -1)
        return self

    def _detect(self, df: pd.DataFrame):
        # run the detector over the detection columns and store the result
        detection_data = df[self.det_cols].values
        detection_result = self.detector.detect(detection_data)
        self.detection_result = detection_result
        return detection_result

    def _get_region_mask(self, df: pd.DataFrame, center: np.ndarray, sigma: np.ndarray):
        # Build the (lower, upper) limits per detection dimension around the
        # center, record them, and return the mask of rows inside them.
        detection_data = df[self.det_cols].values
        half_width = sigma * self.sample_sigma_factor
        limits = np.vstack((center - half_width, center + half_width)).T
        self.limits.append(limits)
        mask = RangeMasker(limits).mask(detection_data)
        self.masks.append(mask)
        return mask

    def _test(self, df: pd.DataFrame):
        # Run every configured test on its own min-max scaled columns.
        # Returns None when no tests are configured.
        if not self.tests:
            return None
        results = []
        for stat_test, cols in zip(self.tests, self.test_cols):
            data = MinMaxScaler().fit_transform(df[cols].values)
            results.append(stat_test.test(data))
        self.test_results.append(results)
        return results

    def _is_sample_clusterable(self, test_results: Optional[List[TestResult]]):
        # Without tests there is nothing to veto clustering, so the region
        # is accepted and no decision is recorded.
        if test_results is None or len(test_results) == 0:
            return True
        rejections = np.asarray([tr.rejectH0 for tr in test_results])
        if self.test_mode == "any":
            is_clusterable = bool(np.any(rejections))
        elif self.test_mode == "all":
            is_clusterable = bool(np.all(rejections))
        elif self.test_mode == "majority":
            # at least half of the tests must reject H0 (ties pass)
            is_clusterable = bool(np.sum(rejections) >= rejections.size / 2)
        else:
            raise ValueError("test_mode must be one of 'any', 'all', 'majority'")
        self.is_clusterable.append(is_clusterable)
        return is_clusterable

    def _estimate_membership(self, df: pd.DataFrame, count: int, center: np.ndarray):
        # Cluster the region to get initial probabilities, then refine them
        # with the membership estimator. Returns the estimator posteriors.
        data = df[self.mem_cols].values

        # create a clusterer for the data
        # with the configuration defined by the user
        clusterer = copy.deepcopy(self.clusterer)
        if clusterer.min_cluster_size and clusterer.clusterer is None:
            clusterer.min_cluster_size = int(count)

        if set(self.mem_cols) <= set(self.det_cols):
            # if possible, use the center for cluster selection; the center
            # is given in det_cols order, so pick its components by name to
            # match the mem_cols order and length
            center = center.take([self.det_cols.index(c) for c in self.mem_cols])
            clusterer.fit(data=data, centers=[center])
        else:
            clusterer.fit(data=data)
        self.clusterers.append(clusterer)

        # get err and corr columns and use them if they exist
        err_cols = self._colnames.error(self.mem_cols)
        corr_cols = self._colnames.corr(self.mem_cols)
        err = None
        if not self._colnames.missing_error(self.mem_cols):
            err = df[err_cols].values
        corr = None
        if not self._colnames.missing_corr(self.mem_cols):
            corr = df[corr_cols].values

        # estimate memberships
        estimator = copy.deepcopy(self.estimator)
        estimator.fit(data=data, init_proba=clusterer.proba, err=err, corr=corr)
        self.estimators.append(estimator)
        return estimator.posteriors

    @beartype
    def fit(self, df: pd.DataFrame):
        """Perform the detection and membership estimation.

        NaNs are dropped from the dataframe copy and are not
        taken into account. Output attributes of a previous fit
        are discarded.

        Parameters
        ----------
        df : pd.DataFrame
            Data frame.

        Returns
        -------
        DEP
            Fitted instance of DEP.

        Raises
        ------
        ValueError
            If any of the configured columns is not a data column of
            the dataframe.

        """
        df = df.dropna()
        self._df = df
        n = df.shape[0]

        # check all columns
        self._colnames = Colnames(df.columns)
        self._check_cols(self.det_cols)
        self._check_cols(self.mem_cols)
        for cols in self.test_cols:
            self._check_cols(cols)

        self._reset_outputs()

        # detect
        print("detecting overdensities...")
        self.detection_result = self._detect(df)
        centers = self.detection_result.centers

        # if no clusters detected, return full noise probs
        if not centers.size:
            return self._set_noise_only(n, n_detected=0)

        n_detected = centers.shape[0]
        print(f"found {n_detected} overdensities")

        cluster_columns = []
        for i, center in enumerate(centers):
            count = self.detection_result.counts[i]
            sigma = self.detection_result.sigmas[i]

            mask = self._get_region_mask(df, center, sigma)
            region_df = df[mask]

            # test
            print(f"testing peak {i}...")
            test_results = self._test(region_df)
            if not self._is_sample_clusterable(test_results):
                continue

            print(f"estimating membership of peak {i}...")
            proba = self._estimate_membership(region_df, count, center)

            # column 0 is the field; add each found cluster as a full-length
            # column that is zero outside the region
            for col in range(1, proba.shape[1]):
                cluster_proba = np.zeros(n)
                cluster_proba[mask] = proba[:, col]
                cluster_columns.append(cluster_proba)

        global_proba = np.array(cluster_columns).T
        if global_proba.size == 0:
            # overdensities were detected, but none yielded a cluster
            return self._set_noise_only(n, n_detected=n_detected)

        n_rows, total_clusters = global_proba.shape
        rows = np.arange(n_rows)
        best = global_proba.argmax(axis=1)

        # in case of region overlap, only the highest prob is kept for each
        # point; the remainder goes to the field column
        result = np.zeros((n, total_clusters + 1))
        result[rows, best + 1] = global_proba[rows, best]
        result[:, 0] = 1 - result[rows, best + 1]

        self.proba = result
        self.labels = np.argmax(self.proba, axis=1) - 1
        self.n_detected = n_detected
        self.n_estimated = self.proba.shape[1] - 1
        return self

    def _is_fitted(self):
        return self.proba is not None

    @beartype
    def proba_df(self):
        """Return the data frame with the probabilities.

        Returns the full dataframe used for the process with
        added columns for the labels and the probabilities.

        Returns
        -------
        pd.DataFrame
            Data with probabilities

        Raises
        ------
        Exception
            If DEP instance is not fitted.

        """
        if not self._is_fitted():
            raise Exception("Not fitted, try running fit()")
        # first probability column is the field, named proba(-1)
        cols = [f"proba({i - 1})" for i in range(self.proba.shape[1])]
        df = pd.DataFrame(self.proba, columns=cols)
        df["label"] = self.labels
        return pd.concat(
            [self._df.reset_index(drop=True), df.reset_index(drop=True)],
            axis=1,
            sort=False,
        )

    @beartype
    def write(self, path: str, **kwargs):
        """Write the data frame with the probabilities to a file.

        Writes the data frame used for the process with labels
        and probabilities to a file. kwargs are passed to the
        astropy.table.Table.write method. Default kwargs are
        "overwrite"=True and "format"="fits".

        Parameters
        ----------
        path : str
            Full filepath with filename.

        Raises
        ------
        Exception
            If DEP instance is not fitted.

        """
        if not self._is_fitted():
            raise Exception("Not fitted, try running fit()")
        table = Table.from_pandas(self.proba_df())
        # user supplied kwargs take precedence over the defaults
        write_kws = {"overwrite": True, "format": "fits"}
        write_kws.update(kwargs)
        return table.write(path, **write_kws)

    def _scatter(self, cols, plotcols, **kwargs):
        # Shared first step of every plot: check state, select the columns
        # and draw the probability scatter plot.
        if not self._is_fitted():
            raise Exception("Not fitted, try running fit()")
        return scatter2dprobaplot(
            self._df[cols], self.proba, self.labels, plotcols, **kwargs
        )

    @beartype
    def cm_diagram(
        self,
        cols: List[str] = ["bp_rp", "phot_g_mean_mag"],
        plotcols: Optional[List[str]] = None,
        plot_objects: bool = True,
        **kwargs,
    ):
        """Color-magnitude diagram.

        Plots a 2d color magnitude diagram of the data, labels and
        probabilities, with the y axis inverted. kwargs are passed to
        the :func:`~scludam.plots.scatter2dprobaplot`, some useful kwargs
        are "select_labels" for choosing which clusters to plot and
        "palette" for choosing the color palette.

        Parameters
        ----------
        cols : List[str], optional
            Dataframe columns to be used, by default
            ["bp_rp", "phot_g_mean_mag"]. If the columns are not present
            in the dataframe, a KeyError is raised.
        plotcols : List[str], optional
            Colnames used for the axis labels in the plot, by default None.
        plot_objects : bool, optional
            Whether to plot objects found in the data region, by default
            True. By default the objects retrieved are of simbad otype
            "Cl*", meaning star clusters. This can be changed executing
            the function :func:`~scludam.pipeline.DEP.get_simbad_objects`.

        Returns
        -------
        Axes
            Axis of the plot

        Raises
        ------
        Exception
            If DEP instance is not fitted.

        """
        ax = self._scatter(cols, plotcols, **kwargs)
        ax.invert_yaxis()
        if plot_objects:
            self._plot_objects(ax, cols)
        return ax

    @beartype
    def radec_plot(
        self,
        cols: List[str] = ["ra", "dec"],
        plotcols: Optional[List[str]] = None,
        plot_objects: bool = True,
        **kwargs,
    ):
        """Right ascension and declination plot.

        Plots a 2d scatterplot of the sky position of the data, labels
        and probabilities, with the x axis inverted. kwargs are passed to
        the :func:`~scludam.plots.scatter2dprobaplot`, some useful kwargs
        are "select_labels" for choosing which clusters to plot and
        "palette" for choosing the color palette.

        Parameters
        ----------
        cols : List[str], optional
            Dataframe columns to be used, by default ["ra", "dec"].
            If the columns are not present in the dataframe, a KeyError
            is raised.
        plotcols : List[str], optional
            Colnames used for the axis labels in the plot, by default None.
        plot_objects : bool, optional
            Whether to plot objects found in the data region, by default
            True. By default the objects retrieved are of simbad otype
            "Cl*", meaning star clusters. This can be changed executing
            the function :func:`~scludam.pipeline.DEP.get_simbad_objects`.

        Returns
        -------
        Axes
            Axis of the plot

        Raises
        ------
        Exception
            If DEP instance is not fitted.

        """
        ax = self._scatter(cols, plotcols, **kwargs)
        ax.invert_xaxis()
        if plot_objects:
            self._plot_objects(ax, cols)
        return ax

    @beartype
    def scatterplot(
        self,
        cols: List[str] = ["pmra", "pmdec"],
        plotcols: Optional[List[str]] = None,
        plot_objects: bool = True,
        **kwargs,
    ):
        """Scatter plot with results.

        Plots a 2d scatterplot of the data, labels and probabilities.
        kwargs are passed to the
        :func:`~scludam.plots.scatter2dprobaplot`, some useful kwargs
        are "select_labels" for choosing which clusters to plot and
        "palette" for choosing the color palette.

        Parameters
        ----------
        cols : List[str], optional
            Dataframe columns to be used, by default ["pmra", "pmdec"].
            If the columns are not present in the dataframe, a KeyError
            is raised.
        plotcols : Optional[List[str]], optional
            Names of the axes labels to be used, by default None. If None,
            ``cols`` are used.
        plot_objects : bool, optional
            Whether to plot simbad objects found in the data region, by
            default True. By default the objects retrieved are of simbad
            otype "Cl*", meaning star clusters. This can be changed
            executing the function
            :func:`~scludam.pipeline.DEP.get_simbad_objects`.

        Returns
        -------
        Axes
            Axis of the plot

        Raises
        ------
        Exception
            If DEP instance is not fitted.

        """
        ax = self._scatter(cols, plotcols, **kwargs)
        if plot_objects:
            self._plot_objects(ax, cols)
        return ax

    def _plot_objects(self, ax, cols, otype="Cl*"):
        # Best effort: overlaying catalogue objects must never break the
        # plot itself, so any failure is reported as a warning only.
        try:
            if self._objects is None:
                self._objects = self.get_simbad_objects(otype=otype)
            objects_to_plot = simbad2gaiacolnames(self._objects).to_pandas()
            plot_objects(objects_to_plot, ax, cols)
            return ax
        except Exception as e:
            warnings.warn(f"Could not plot objects: {e}")

    def get_simbad_objects(self, **kwargs):
        """Get simbad objects found in the data region.

        kwargs are passed to the
        :func:`~scludam.fetcher.search_objects_near_data` function.
        The result is stored, so the plotting methods will plot
        the objects found in the data region.

        Returns
        -------
        astropy.table.Table
            Table with the objects found.

        Raises
        ------
        Exception
            If DEP instance is not fitted.

        """
        if not self._is_fitted():
            raise Exception("Not fitted, try running fit()")
        print("getting simbad objects...")
        self._objects = search_objects_near_data(self._df, **kwargs)
        return self._objects