# Source code for openquake.baselib.slots

# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright (C) 2013-2018 GEM Foundation

# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake.  If not, see <http://www.gnu.org/licenses/>.

import operator
import numpy


def with_slots(cls):
    """
    Decorator for a class with _slots_. It automatically defines
    the methods __eq__, __ne__, assert_equal.

    Two instances are considered equal when all the attributes listed in
    ``_slots_`` are equal. Attributes are compared as follows:

    - if either side is a numpy array, with ``numpy.array_equal``;
    - if the attribute has itself a ``_slots_`` attribute, by calling
      its ``assert_equal`` method;
    - otherwise with the regular ``==`` operator.

    :param cls: a class defining a ``_slots_`` sequence of attribute names
    :returns: the same class, with the comparison methods attached
    :raises AttributeError: if the class has no ``_slots_`` attribute
    """
    def _compare(self, other, ignore=(), strict=False):
        # Yield (slot, source, target, eq) for each non-ignored slot.
        # With strict=True a difference found inside a nested slotted
        # object propagates as the AssertionError raised by the nested
        # assert_equal (so the message names the innermost slot); with
        # strict=False it is reported as eq=False instead.
        for slot in self.__class__._slots_:
            if slot in ignore:
                # skip *before* comparing, so that an ignored nested
                # object cannot raise from its own assert_equal
                continue
            attr = operator.attrgetter(slot)
            source = attr(self)
            target = attr(other)
            if (isinstance(source, numpy.ndarray) or
                    isinstance(target, numpy.ndarray)):
                # `==` with an array on either side is element-wise and
                # has no unambiguous truth value
                eq = numpy.array_equal(source, target)
            elif hasattr(source, '_slots_'):
                try:
                    source.assert_equal(target)
                except AssertionError:
                    if strict:
                        raise
                    eq = False
                else:
                    eq = True
            else:
                eq = source == target
            yield slot, source, target, eq

    def __eq__(self, other):
        # an equality test must return a boolean, never raise
        # AssertionError because of a differing nested object
        return all(eq for slot, source, target, eq in _compare(self, other))

    def __ne__(self, other):
        return not self.__eq__(other)

    def assert_equal(self, other, ignore=()):
        """
        Raise an AssertionError naming the first slot whose value differs
        between `self` and `other`; slots listed in `ignore` are skipped.
        """
        for slot, source, target, eq in _compare(
                self, other, ignore, strict=True):
            if not eq:
                raise AssertionError('slot %s: %s is different from %s' %
                                     (slot, source, target))

    cls._slots_  # raise an AttributeError for missing slots
    cls.__eq__ = __eq__
    cls.__ne__ = __ne__
    cls.assert_equal = assert_equal
    return cls