Skip to content
Snippets Groups Projects
test_sponsor.py 2.07 KiB
from functools import partial

import pytest

from greg.models import (
    OrganizationalUnit,
    Sponsor,
    SponsorOrganizationalUnit,
)


def sponsor_ou_relation(*args, **kwargs) -> SponsorOrganizationalUnit:
    """Create a sponsor/organizational-unit link for a test.

    All arguments are forwarded to ``SponsorOrganizationalUnit.objects.create``.
    ``hierarchical_access`` is set to ``False`` unless the caller passes it
    explicitly.
    """
    kwargs.setdefault("hierarchical_access", False)
    return SponsorOrganizationalUnit.objects.create(*args, **kwargs)


@pytest.fixture
def sponsor_foo() -> Sponsor:
    """Persist and return a sponsor whose Feide id is ``foosponsor@uio.no``."""
    created = Sponsor.objects.create(feide_id="foosponsor@uio.no")
    return created


@pytest.fixture
def sponsor_bar() -> Sponsor:
    """Persist and return a sponsor whose Feide id is ``barsponsor@uio.no``."""
    created = Sponsor.objects.create(feide_id="barsponsor@uio.no")
    return created


@pytest.fixture
def unit1() -> OrganizationalUnit:
    """Persist and return an organizational unit named "First unit"."""
    unit = OrganizationalUnit.objects.create(name_en="First unit")
    return unit


@pytest.fixture
def unit2() -> OrganizationalUnit:
    """Persist and return an organizational unit named "Second unit"."""
    unit = OrganizationalUnit.objects.create(name_en="Second unit")
    return unit


@pytest.mark.django_db
def test_add_sponsor_to_multiple_units(sponsor_foo, unit1, unit2):
    """A sponsor linked to two units sees exactly those two units via ``units``."""
    sponsor_ou_relation(sponsor=sponsor_foo, organizational_unit=unit1)
    sponsor_ou_relation(sponsor=sponsor_foo, organizational_unit=unit2)
    # The query has no explicit order_by(), so the database is free to return
    # the units in any order (unless a model Meta.ordering applies, which is
    # not visible here). Compare sorted primary keys so the check does not
    # depend on row order, while still catching missing or duplicated rows.
    linked_ids = sorted(unit.id for unit in sponsor_foo.units.all())
    assert linked_ids == sorted([unit1.id, unit2.id])


@pytest.mark.django_db
def test_add_muliple_sponsors_to_unit(sponsor_foo, sponsor_bar, unit1, unit2):
    """Several sponsors can share one unit; an unlinked unit has no sponsors."""
    # NOTE: the "muliple" spelling in the name is kept so existing test
    # selectors (e.g. ``pytest -k``) keep matching.
    sponsor_ou_relation(sponsor=sponsor_foo, organizational_unit=unit1)
    sponsor_ou_relation(sponsor=sponsor_bar, organizational_unit=unit1)
    assert list(sponsor_foo.units.all()) == [unit1]
    assert list(sponsor_bar.units.all()) == [unit1]
    # Two sponsors come back from a query with no explicit order_by(), so
    # compare sorted ids rather than relying on the database's row order.
    sponsor_ids = sorted(s.id for s in Sponsor.objects.filter(units=unit1.id))
    assert sponsor_ids == sorted([sponsor_foo.id, sponsor_bar.id])
    assert not Sponsor.objects.filter(units=unit2.id).exists()


@pytest.mark.django_db
def test_sponsor_repr(sponsor_guy):
    """``repr`` lists the sponsor's id followed by its identifying fields."""
    # Do not hard-code the primary key: the id sequence is not guaranteed to
    # restart at 1 for each test (e.g. PostgreSQL sequence increments are not
    # rolled back with the test transaction), so read it from the instance.
    assert repr(sponsor_guy) == (
        f"Sponsor(id={sponsor_guy.id}, feide_id='guy@example.org', "
        "first_name='Sponsor', last_name='Guy', "
        "work_email='sponsor_guy@example.com')"
    )


@pytest.mark.django_db
def test_sponsor_str(sponsor_guy):
    """``str`` renders the Feide id followed by the full name in parentheses."""
    expected = "guy@example.org (Sponsor Guy)"
    rendered = str(sponsor_guy)
    assert rendered == expected


@pytest.mark.django_db
def test_get_allowed_loop(loop_sponsor, looped_units, unit_foo):
    """``get_allowed_units`` returns the looped units, then ``unit_foo``, in order."""
    # Fixtures come from elsewhere (presumably conftest); the names suggest a
    # cyclic unit hierarchy — confirm against the fixture definitions.
    allowed = loop_sponsor.get_allowed_units()
    expected_ids = [*(unit.id for unit in looped_units), unit_foo.id]
    actual_ids = [unit.id for unit in allowed]
    assert actual_ids == expected_ids