# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# LICENSE
#
# Copyright (C) 2010-2025 GEM Foundation, G. Weatherill, M. Pagani,
# D. Monelli.
#
# The Hazard Modeller's Toolkit is free software: you can redistribute
# it and/or modify it under the terms of the GNU Affero General Public
# License as published by the Free Software Foundation, either version
# 3 of the License, or (at your option) any later version.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake. If not, see <http://www.gnu.org/licenses/>
#
# DISCLAIMER
#
# The software Hazard Modeller's Toolkit (openquake.hmtk) provided herein
# is released as a prototype implementation on behalf of
# scientists and engineers working within the GEM Foundation (Global
# Earthquake Model).
#
# It is distributed for the purpose of open collaboration and in the
# hope that it will be useful to the scientific, engineering, disaster
# risk and software design communities.
#
# The software is NOT distributed as part of GEM’s OpenQuake suite
# (https://www.globalquakemodel.org/tools-products) and must be considered as a
# separate entity. The software provided herein is designed and implemented
# by scientific staff. It is not developed to the design standards, nor
# subject to same level of critical review by professional software
# developers, as GEM’s OpenQuake software suite.
#
# Feedback and contribution to the software is welcome, and can be
# directed to the hazard scientific staff of the GEM Model Facility
# (hazard@globalquakemodel.org).
#
# The Hazard Modeller's Toolkit (openquake.hmtk) is therefore distributed WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
# for more details.
#
# The GEM Foundation, and the authors of the software, assume no
# liability for use of the software.
"""
"""
import numpy as np
from openquake.hmtk.seismicity.occurrence.utils import (
    input_checks,
    recurrence_table,
)
from openquake.hmtk.seismicity.occurrence.base import (
    SeismicityOccurrence,
    OCCURRENCE_METHODS,
)
from openquake.hmtk.seismicity.occurrence.aki_maximum_likelihood import (
    AkiMaxLikelihood,
)
@OCCURRENCE_METHODS.add(
    "calculate",
    **{
        "completeness": True,
        "reference_magnitude": 0.0,
        "magnitude_interval": 0.1,
        "Average Type": ["Weighted", "Harmonic"],
    },
)
class BMaxLikelihood(SeismicityOccurrence):
    """Implements maximum likelihood calculations taking into account time
    variation in completeness.

    For every distinct year of the completeness table, the b-value and its
    uncertainty are estimated (Aki maximum likelihood) from the events that
    occurred in or after that year with magnitude at or above the
    corresponding completeness magnitude. The per-period estimates are then
    averaged, weighted by the number of events used in each period, and the
    a-value is derived following Weichert (1980).
    """

    def calculate(self, catalogue, config, completeness=None):
        """Calculates recurrence parameters a_value and b_value, and their
        respective uncertainties

        :param catalogue: Earthquake Catalogue
            An instance of :class:`openquake.hmtk.seismicity.catalogue`
        :param dict config:
            A configuration dictionary. The keys registered for this method
            are "reference_magnitude", "magnitude_interval" and
            "Average Type"; the latter selects the type of average applied
            to the per-period b-values ("Weighted" or "Harmonic") and
            defaults to "Weighted" when absent.
        :param list or numpy.ndarray completeness:
            Completeness table
        :returns:
            Tuple (bval, sigma_b, aval, sigma_a) or, when a non-zero
            "reference_magnitude" is configured, (bval, sigma_b, rate,
            sigma_rate) with ``rate = 10 ** (aval - bval * reference_mag)``
        :raises ValueError:
            If the configured "Average Type" is not recognised
        """
        # Input checks
        cmag, ctime, ref_mag, dmag, config = input_checks(
            catalogue, config, completeness
        )
        # Check the configuration; "Weighted" is also the default used by
        # _average_parameters
        average_type = config.get("Average Type", "Weighted")
        if average_type not in ["Weighted", "Harmonic"]:
            raise ValueError("Average type not recognised in bMaxLiklihood!")
        return self._b_ml(catalogue, config, cmag, ctime, ref_mag, dmag)

    def _b_ml(self, catalogue, config, cmag, ctime, ref_mag, dmag):
        """Estimates b and sigma_b for each completeness period, averages
        them and derives the a-value (or the rate at a reference magnitude).

        :param catalogue:
            Earthquake catalogue exposing ``end_year`` and ``data`` (with
            "year" and "magnitude" arrays)
        :param dict config:
            Configuration ("Average Type", "reference_magnitude")
        :param numpy.ndarray cmag:
            Completeness magnitudes, one per row of the completeness table
        :param numpy.ndarray ctime:
            Completeness years, one per row of the completeness table
        :param float ref_mag:
            Unused; the reference magnitude is read from ``config``
        :param float dmag:
            Magnitude bin width
        :returns:
            Same tuple as :meth:`calculate`
        :raises ValueError:
            If the completeness table is empty
        """
        end_year = float(catalogue.end_year)
        data = catalogue.data
        # A single completeness entry may arrive as a 0-d array; promote both
        # to 1-D so they can be indexed like a completeness table
        cmag = np.atleast_1d(cmag)
        ctime = np.atleast_1d(ctime)
        if ctime.size == 0:
            raise ValueError("Completeness table is empty")
        mag_eq_tolerance = 1e-5
        aki_ml = AkiMaxLikelihood()
        gr_rows = []
        n_events = []
        ival = 0
        while ival < ctime.shape[0]:
            # Rows sharing this completeness year (within tolerance) form one
            # period, using their lowest completeness magnitude
            id0 = np.abs(ctime - ctime[ival]) < mag_eq_tolerance
            m_c = np.min(cmag[id0])
            # Find events later than cut-off year, and with magnitude
            # greater than or equal to the corresponding completeness
            # magnitude. m_c - mag_eq_tolerance is required to correct
            # floating point differences.
            id1 = np.logical_and(
                data["year"] >= ctime[ival],
                data["magnitude"] >= (m_c - mag_eq_tolerance),
            )
            # Get b-value and its uncertainty for the selected events
            temp_rec_table = recurrence_table(
                data["magnitude"][id1],
                dmag,
                data["year"][id1],
                end_year - ctime[ival] + 1,
            )
            bval, sigma_b = aki_ml._aki_ml(
                temp_rec_table[:, 0], temp_rec_table[:, 1], dmag, m_c
            )
            gr_rows.append(np.hstack([bval, sigma_b]))
            n_events.append(np.sum(id1))
            # Skip all rows of this period (assumes rows with the same year
            # are contiguous in the table)
            ival += np.sum(id0)
        gr_pars = np.vstack(gr_rows)
        neq = np.array(n_events)
        # Get average GR parameters
        bval, sigma_b = self._average_parameters(
            gr_pars, neq, config.get("Average Type", "Weighted")
        )
        total_events = float(np.sum(neq))
        aval = self._calculate_a_value(
            bval,
            total_events,
            cmag,
            ctime,
            data["magnitude"],
            end_year,
            dmag,
        )
        # a-value recomputed with b + sigma_b; its difference from aval is
        # reported as the a-value uncertainty
        aval_upper = self._calculate_a_value(
            bval + sigma_b,
            total_events,
            cmag,
            ctime,
            data["magnitude"],
            end_year,
            dmag,
        )
        reference_magnitude = config.get("reference_magnitude")
        # None and 0.0 are both treated as "no reference magnitude"
        if not reference_magnitude:
            return bval, sigma_b, aval, aval_upper - aval
        rate = 10.0 ** (aval - bval * reference_magnitude)
        sigma_rate = (
            10.0 ** (aval_upper - bval * reference_magnitude) - rate
        )
        return bval, sigma_b, rate, sigma_rate

    def _average_parameters(self, gr_params, neq, average_type="Weighted"):
        """
        Calculates the average of a set of Gutenberg-Richter parameters
        depending on the average type

        :param numpy.ndarray gr_params:
            Gutenberg-Richter parameters, one row per completeness period;
            the first two columns are [b, sigma_b] (any further columns are
            averaged but not returned)
        :param numpy.ndarray neq:
            Number of events for each row of ``gr_params``, used as weights
        :param str average_type:
            Any value containing "Harmonic" selects a weighted harmonic
            mean; otherwise a weighted arithmetic mean is used
        :returns:
            Tuple of the averaged (b-value, sigma_b)
        :raises ValueError:
            If the number of weights differs from the number of rows
        """
        neq = np.asarray(neq)
        if np.shape(gr_params)[0] != neq.size:
            raise ValueError(
                "Number of weights does not correspond"
                " to number of parameters"
            )
        if "Harmonic" in average_type:
            average_parameters = self._harmonic_mean(gr_params, neq)
        else:
            average_parameters = self._weighted_mean(gr_params, neq)
        bval = average_parameters[0]
        sigma_b = average_parameters[1]
        return bval, sigma_b

    def _calculate_a_value(
        self, bvalue, nvalue, cmag, cyear, magnitude, end_year, dmag
    ):
        """
        Calculates the a-value using the method of Weichert (1980) and
        McGuire (2004)

        Magnitude bins are bounded by consecutive completeness magnitudes and
        each bin is paired with the observation period of the completeness
        row that opens it. When the largest observed magnitude reaches the
        highest completeness magnitude, a final bin closing at
        ``mmax + dmag`` is added.

        :param float bvalue: b-value
        :param float nvalue: number of events
        :param numpy.ndarray cmag:
            Completeness magnitudes (assumed ascending; ``cmag[0]`` is used
            as the minimum magnitude)
        :param numpy.ndarray cyear: completeness years, one per ``cmag``
        :param numpy.ndarray magnitude: catalogue magnitudes
        :param float end_year: last year of the catalogue
        :param float dmag: magnitude bin width
        :returns: a-value (log10 of the rate scaled to magnitude 0)
        """
        cmag = np.atleast_1d(cmag)
        cyear = np.atleast_1d(cyear)
        mmin = cmag[0]
        mmax = np.max(magnitude)
        if mmax >= np.max(cmag):
            # Close the last completeness bin above the largest magnitude
            cmag = np.hstack([cmag, mmax + dmag])
        target_mag = (cmag[:-1] + cmag[1:]) / 2.0
        # One observation period per bin: when no closing bin was added,
        # the last completeness row opens no bin, so its period is dropped
        # rather than mis-broadcast against the bin centres
        nyear = (end_year - cyear + 1.0)[: target_mag.size]
        beta = bvalue * np.log(10.0)
        bin_weights = np.exp(-beta * target_mag)
        rate_mmin = nvalue * np.sum(bin_weights) / np.sum(nyear * bin_weights)
        return np.log10(rate_mmin) + bvalue * mmin

    def _weighted_mean(self, parameters, neq):
        """Weighted arithmetic mean of each column of ``parameters``, with
        weights ``neq / sum(neq)``.

        :raises ValueError: If the weights do not match the number of rows
        """
        weight = np.asarray(neq).astype(float) / np.sum(neq)
        if np.shape(parameters)[0] != weight.size:
            raise ValueError("Parameter vector not same shape as weights")
        average_value = np.zeros(np.shape(parameters)[1], dtype=float)
        for iloc in range(np.shape(parameters)[1]):
            average_value[iloc] = np.sum(parameters[:, iloc] * weight)
        return average_value

    def _harmonic_mean(self, parameters, neq):
        """Weighted harmonic mean of each column of ``parameters``:
        ``1 / sum(w_i / x_i)`` with weights ``w = neq / sum(neq)``.

        :raises ValueError: If the weights do not match the number of rows
        """
        weight = np.asarray(neq).astype(float) / np.sum(neq)
        if np.shape(parameters)[0] != weight.size:
            raise ValueError("Parameter vector not same shape as weights")
        average_value = np.zeros(np.shape(parameters)[1], dtype=float)
        for iloc in range(np.shape(parameters)[1]):
            average_value[iloc] = 1.0 / np.sum(
                weight * (1.0 / parameters[:, iloc])
            )
        return average_value