Подтвердить что ты не робот

Сравните (утвердите равенство) две сложные структуры данных, содержащие массивы numpy в unittest

Я использую модуль Python unittest и хочу проверить, равны ли две сложные структуры данных. Объектами могут быть списки dicts со всеми значениями: номерами, строками, контейнерами Python (списки/кортежи/dicts) и numpy массивами. Последние являются причиной задавать вопрос, потому что я не могу просто сделать

self.assertEqual(big_struct1, big_struct2)

поскольку он создает

ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

Я предполагаю, что для этого мне нужно написать свой собственный тест на равенство. Он должен работать для произвольных структур. Моя текущая идея - это рекурсивная функция, которая:

  • пытается прямое сравнение текущего "node" arg1 с соответствующим node of arg2;
  • Если исключение не создано, он перемещается (здесь также обрабатываются "конечные" узлы/листья);
  • если ValueError пойман, идет глубже, пока не найдет numpy.array;
  • сравнивает массивы (например, как это).

Кажется, что проблематично отслеживать "соответствующие" узлы двух структур, но, возможно, zip - это все, что мне нужно.

Вопрос: есть ли хорошие (более простые) альтернативы этому подходу? Может быть, numpy содержит некоторые инструменты для этого? Если альтернативы не предложены, я реализую эту идею (если у меня не будет лучшей) и опубликую ответ.

P.S. У меня есть смутное чувство, что я мог бы рассмотреть вопрос, касающийся этой проблемы, но я не могу найти его сейчас.

P.P.S. Альтернативный подход - это функция, которая пересекает структуру и преобразует все numpy.array в списки, но проще ли это реализовать? Кажется таким же для меня.


Изменить: Подклассификация numpy.ndarray звучит очень многообещающе, но, очевидно, у меня нет двух сторон сравнения, жестко закодированных в тесте. Один из них, правда, действительно жестко закодирован, поэтому я могу:

  • заполнить его пользовательскими подклассами numpy.array;
  • измените isinstance(other, SaneEqualityArray) на isinstance(other, np.ndarray) в jterrace ответ;
  • всегда используйте его как LHS в сравнении.

Мои вопросы в этом отношении:

  • Будет ли это работать (я имею в виду, это звучит для меня правильно, но, может быть, некоторые сложные случаи кросс не будут обрабатываться правильно)? Будет ли мой пользовательский объект всегда заканчиваться как LHS в рекурсивных проверках равенства, как я ожидаю?
  • Опять же, есть ли лучшие способы (учитывая, что я получаю хотя бы одну из структур с реальными массивами numpy).

Изменить 2. Я пробовал это, рабочая версия (по-видимому) показана в этом ответе.

4b9b3361

Ответ 1

Итак, идея, проиллюстрированная jterrace, кажется, работает для меня с небольшой модификацией:

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        return (isinstance(other, np.ndarray) and self.shape == other.shape and 
            np.allclose(self, other))

Как я уже сказал, контейнер с этими объектами должен находиться в левой части проверки равенства. Я создаю объекты SaneEqualityArray из существующего numpy.ndarray следующим образом:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

в соответствии с ndarray сигнатурой конструктора:

ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

Этот класс определен в наборе тестов и служит только для тестирования. RHS проверки равенства является фактическим объектом, возвращаемым проверенной функцией и содержит реальные объекты numpy.ndarray.

P.S. Благодаря авторам обоих ответов, опубликованных до сих пор, они оба были очень полезными. Если кто-либо увидит какие-либо проблемы с этим подходом, я буду благодарен за ваши отзывы.

Ответ 2

Прокомментировал бы, но он слишком длинный...

Забавно, вы не можете использовать == для проверки того, являются ли массивы одинаковыми, я бы предложил вместо np.testing.assert_array_equal.

  • который проверяет тип, форму и т.д.,
  • что не подходит для аккуратной математической математики (float('nan') == float('nan')) == False (нормальная последовательность python == имеет еще более интересный способ игнорировать это иногда, потому что она использует PyObject_RichCompareBool, которая делает (для NaNs неверно) is быстрая проверка (для тестирования, конечно, это идеально)...
  • Существует также assert_allclose, поскольку равенство с плавающей запятой может стать очень сложным, если вы выполняете фактические вычисления, и обычно вы хотите получить почти одинаковые значения, поскольку значения могут стать зависимыми от оборудования или, возможно, случайными, в зависимости от того, что вы с ними делаете.

Я бы почти предложил попробовать сериализовать его с рассолом, если вы хотите что-то такое безумно вложенное, но это чересчур строго (а точка 3, конечно, полностью сломана), например, макет памяти вашего массива не имеет значения, но имеет значение для его сериализации.

Ответ 3

Функция assertEqual будет вызывать метод объектов __eq__, который должен обрабатывать сложные типы данных. Исключением является numpy, который не имеет разумного метода __eq__. Используя подкласс numpy из этого вопроса, вы можете восстановить здравомыслие в отношении поведения равенства:

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

Этот тест проходит.

Ответ 4

Я бы определил свой собственный метод assertNumpyArraysEqual(), который явно делает сравнение, которое вы хотите использовать. Таким образом, ваш производственный код не изменится, но вы можете сделать разумные утверждения в своих модульных тестах. Обязательно определите его в модуле, который включает __unittest = True, чтобы он не включался в трассировку стека:

import numpy
__unittest = True

def assertNumpyArraysEqual(self, other):
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other)
        raise AssertionError("Elements don't match!")

Ответ 5

Я столкнулся с той же проблемой и разработал функцию сравнения равенства на основе создания фиксированного хэша для объекта. Это дает дополнительное преимущество в том, что вы можете проверить, что объект, как и ожидалось, сравнивает его хэш с фиксированным, который хранится в коде.

Код (автономный файл python, здесь). Существуют две функции: fixed_hash_eq, которая решает вашу проблему, и compute_fixed_hash, что делает хэш из структуры. Тесты здесь

Здесь тест:

obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)

Ответ 6

Основываясь на @dbw (с благодарностью), следующий метод, вставленный в подклассу тестового случая, хорошо работал у меня:

 def assertNumpyArraysEqual(self,this,that,msg=''):
    '''
    modified from http://stackoverflow.com/a/15399475/5459638
    '''
    if this.shape != that.shape:
        raise AssertionError("Shapes don't match")
    if not np.allclose(this,that):
        raise AssertionError("Elements don't match!")

Я использовал его как self.assertNumpyArraysEqual(this,that) внутри моих тестовых примеров и работал как шарм.