From 92c804a4743e3cd3d0e5cd669c3a944f91355b82 Mon Sep 17 00:00:00 2001 From: Alfredo Di Stasio Date: Tue, 10 Mar 2026 15:47:01 +0100 Subject: [PATCH] Fix combined search filter semantics across player season joins --- apps/api/views.py | 2 +- apps/players/services/search.py | 182 +++++++++++++++++-------- apps/players/views.py | 2 +- tests/test_api.py | 42 ++++++ tests/test_players_filters_advanced.py | 146 ++++++++++++++++++++ 5 files changed, 314 insertions(+), 60 deletions(-) diff --git a/apps/api/views.py b/apps/api/views.py index 14ddd33..53d3190 100644 --- a/apps/api/views.py +++ b/apps/api/views.py @@ -46,7 +46,7 @@ class PlayerSearchApiView(ReadOnlyBaseAPIView, generics.ListAPIView): queryset = filter_players(queryset, form.cleaned_data) sort_key = form.cleaned_data.get("sort", "name_asc") if sort_key in METRIC_SORT_KEYS: - queryset = annotate_player_metrics(queryset) + queryset = annotate_player_metrics(queryset, form.cleaned_data) queryset = apply_sorting(queryset, sort_key) else: queryset = queryset.order_by("full_name", "id") diff --git a/apps/players/services/search.py b/apps/players/services/search.py index 09d18af..85d294e 100644 --- a/apps/players/services/search.py +++ b/apps/players/services/search.py @@ -3,11 +3,13 @@ from datetime import date, timedelta from django.db.models import ( Case, DecimalField, + Exists, ExpressionWrapper, F, FloatField, IntegerField, Max, + OuterRef, Q, Value, When, @@ -15,6 +17,7 @@ from django.db.models import ( from django.db.models.functions import Coalesce from apps.players.models import Player +from apps.stats.models import PlayerSeason METRIC_SORT_KEYS = {"ppg_desc", "ppg_asc", "mpg_desc", "mpg_asc"} @@ -31,15 +34,15 @@ def _years_ago_today(years: int) -> date: def _apply_min_max_filter(queryset, min_key: str, max_key: str, field_name: str, data: dict): min_value = data.get(min_key) max_value = data.get(max_key) - if min_value is not None: + if min_value not in (None, ""): queryset = queryset.filter(**{f"{field_name}__gte": min_value}) - if max_value is not None: + if max_value not in (None, ""): queryset = queryset.filter(**{f"{field_name}__lte": max_value}) return queryset -def _needs_distinct(data: dict) -> bool: - join_filter_keys = ( +def _season_scope_filter_keys() -> tuple[str, ...]: + return ( "q", "team", "competition", @@ -69,7 +72,105 @@ def _needs_distinct(data: dict) -> bool: "efficiency_metric_min", "efficiency_metric_max", ) - return any(data.get(key) not in (None, "") for key in join_filter_keys) + + +def _has_season_scope_filters(data: dict) -> bool: + return any(data.get(key) not in (None, "") for key in _season_scope_filter_keys() if key != "q") + + +def _apply_mpg_filter(queryset, *, games_field: str, minutes_field: str, min_value, max_value): + if min_value not in (None, ""): + queryset = queryset.filter(**{f"{games_field}__gt": 0}).filter( + **{f"{minutes_field}__gte": F(games_field) * min_value} + ) + if max_value not in (None, ""): + queryset = queryset.filter(**{f"{games_field}__gt": 0}).filter( + **{f"{minutes_field}__lte": F(games_field) * max_value} + ) + return queryset + + +def _apply_player_season_scope_filters(queryset, data: dict): + if data.get("team"): + queryset = queryset.filter(team=data["team"]) + if data.get("competition"): + queryset = queryset.filter(competition=data["competition"]) + if data.get("season"): + queryset = queryset.filter(season=data["season"]) + + queryset = _apply_min_max_filter(queryset, "games_played_min", "games_played_max", "games_played", data) + queryset = _apply_mpg_filter( + queryset, + games_field="games_played", + minutes_field="minutes_played", + min_value=data.get("minutes_per_game_min"), + max_value=data.get("minutes_per_game_max"), + ) + + stat_pairs = ( + ("points_per_game_min", "points_per_game_max", "stats__points"), + ("rebounds_per_game_min", "rebounds_per_game_max", "stats__rebounds"), + ("assists_per_game_min", "assists_per_game_max", "stats__assists"), + ("steals_per_game_min", "steals_per_game_max", "stats__steals"), + ("blocks_per_game_min", "blocks_per_game_max", "stats__blocks"), + ("turnovers_per_game_min", "turnovers_per_game_max", "stats__turnovers"), + ("fg_pct_min", "fg_pct_max", "stats__fg_pct"), + ("three_pct_min", "three_pct_max", "stats__three_pct"), + ("ft_pct_min", "ft_pct_max", "stats__ft_pct"), + ("efficiency_metric_min", "efficiency_metric_max", "stats__player_efficiency_rating"), + ) + for min_key, max_key, field_name in stat_pairs: + queryset = _apply_min_max_filter(queryset, min_key, max_key, field_name, data) + + return queryset + + +def _build_metric_context_filter(data: dict) -> Q: + context_filter = Q() + if data.get("team"): + context_filter &= Q(player_seasons__team=data["team"]) + if data.get("competition"): + context_filter &= Q(player_seasons__competition=data["competition"]) + if data.get("season"): + context_filter &= Q(player_seasons__season=data["season"]) + + minmax_pairs = ( + ("games_played_min", "games_played_max", "player_seasons__games_played"), + ("points_per_game_min", "points_per_game_max", "player_seasons__stats__points"), + ("rebounds_per_game_min", "rebounds_per_game_max", "player_seasons__stats__rebounds"), + ("assists_per_game_min", "assists_per_game_max", "player_seasons__stats__assists"), + ("steals_per_game_min", "steals_per_game_max", "player_seasons__stats__steals"), + ("blocks_per_game_min", "blocks_per_game_max", "player_seasons__stats__blocks"), + ("turnovers_per_game_min", "turnovers_per_game_max", "player_seasons__stats__turnovers"), + ("fg_pct_min", "fg_pct_max", "player_seasons__stats__fg_pct"), + ("three_pct_min", "three_pct_max", "player_seasons__stats__three_pct"), + ("ft_pct_min", "ft_pct_max", "player_seasons__stats__ft_pct"), + ( + "efficiency_metric_min", + "efficiency_metric_max", + "player_seasons__stats__player_efficiency_rating", + ), + ) + for min_key, max_key, field_name in minmax_pairs: + min_value = data.get(min_key) + max_value = data.get(max_key) + if min_value not in (None, ""): + context_filter &= Q(**{f"{field_name}__gte": min_value}) + if max_value not in (None, ""): + context_filter &= Q(**{f"{field_name}__lte": max_value}) + + mpg_min = data.get("minutes_per_game_min") + mpg_max = data.get("minutes_per_game_max") + if mpg_min not in (None, ""): + context_filter &= Q(player_seasons__games_played__gt=0) & Q( + player_seasons__minutes_played__gte=F("player_seasons__games_played") * mpg_min + ) + if mpg_max not in (None, ""): + context_filter &= Q(player_seasons__games_played__gt=0) & Q( + player_seasons__minutes_played__lte=F("player_seasons__games_played") * mpg_max + ) + + return context_filter def filter_players(queryset, data: dict): @@ -88,13 +189,6 @@ def filter_players(queryset, data: dict): if data.get("origin_team"): queryset = queryset.filter(origin_team=data["origin_team"]) - if data.get("team"): - queryset = queryset.filter(player_seasons__team=data["team"]) - if data.get("competition"): - queryset = queryset.filter(player_seasons__competition=data["competition"]) - if data.get("season"): - queryset = queryset.filter(player_seasons__season=data["season"]) - queryset = _apply_min_max_filter(queryset, "height_min", "height_max", "height_cm", data) queryset = _apply_min_max_filter(queryset, "weight_min", "weight_max", "weight_kg", data) @@ -106,50 +200,22 @@ def filter_players(queryset, data: dict): earliest_birth = _years_ago_today(age_max + 1) + timedelta(days=1) queryset = queryset.filter(birth_date__gte=earliest_birth) - queryset = _apply_min_max_filter( - queryset, - "games_played_min", - "games_played_max", - "player_seasons__games_played", - data, - ) - - mpg_min = data.get("minutes_per_game_min") - mpg_max = data.get("minutes_per_game_max") - if mpg_min is not None: - queryset = queryset.filter(player_seasons__games_played__gt=0).filter( - player_seasons__minutes_played__gte=F("player_seasons__games_played") * mpg_min - ) - if mpg_max is not None: - queryset = queryset.filter(player_seasons__games_played__gt=0).filter( - player_seasons__minutes_played__lte=F("player_seasons__games_played") * mpg_max + if _has_season_scope_filters(data): + scoped_seasons = _apply_player_season_scope_filters( + PlayerSeason.objects.filter(player_id=OuterRef("pk")), + data, ) + queryset = queryset.filter(Exists(scoped_seasons)) - stat_pairs = ( - ("points_per_game_min", "points_per_game_max", "player_seasons__stats__points"), - ("rebounds_per_game_min", "rebounds_per_game_max", "player_seasons__stats__rebounds"), - ("assists_per_game_min", "assists_per_game_max", "player_seasons__stats__assists"), - ("steals_per_game_min", "steals_per_game_max", "player_seasons__stats__steals"), - ("blocks_per_game_min", "blocks_per_game_max", "player_seasons__stats__blocks"), - ("turnovers_per_game_min", "turnovers_per_game_max", "player_seasons__stats__turnovers"), - ("fg_pct_min", "fg_pct_max", "player_seasons__stats__fg_pct"), - ("three_pct_min", "three_pct_max", "player_seasons__stats__three_pct"), - ("ft_pct_min", "ft_pct_max", "player_seasons__stats__ft_pct"), - ( - "efficiency_metric_min", - "efficiency_metric_max", - "player_seasons__stats__player_efficiency_rating", - ), - ) - for min_key, max_key, field_name in stat_pairs: - queryset = _apply_min_max_filter(queryset, min_key, max_key, field_name, data) - - if _needs_distinct(data): + if query: return queryset.distinct() return queryset -def annotate_player_metrics(queryset): +def annotate_player_metrics(queryset, data: dict | None = None): + data = data or {} + context_filter = _build_metric_context_filter(data) + mpg_expression = Case( When( player_seasons__games_played__gt=0, @@ -164,38 +230,38 @@ def annotate_player_metrics(queryset): return queryset.annotate( games_played_value=Coalesce( - Max("player_seasons__games_played"), + Max("player_seasons__games_played", filter=context_filter), Value(0, output_field=IntegerField()), output_field=IntegerField(), ), - mpg_value=Coalesce(Max(mpg_expression), Value(0.0)), + mpg_value=Coalesce(Max(mpg_expression, filter=context_filter), Value(0.0)), ppg_value=Coalesce( - Max("player_seasons__stats__points"), + Max("player_seasons__stats__points", filter=context_filter), Value(0, output_field=DecimalField(max_digits=6, decimal_places=2)), output_field=DecimalField(max_digits=6, decimal_places=2), ), rpg_value=Coalesce( - Max("player_seasons__stats__rebounds"), + Max("player_seasons__stats__rebounds", filter=context_filter), Value(0, output_field=DecimalField(max_digits=6, decimal_places=2)), output_field=DecimalField(max_digits=6, decimal_places=2), ), apg_value=Coalesce( - Max("player_seasons__stats__assists"), + Max("player_seasons__stats__assists", filter=context_filter), Value(0, output_field=DecimalField(max_digits=6, decimal_places=2)), output_field=DecimalField(max_digits=6, decimal_places=2), ), spg_value=Coalesce( - Max("player_seasons__stats__steals"), + Max("player_seasons__stats__steals", filter=context_filter), Value(0, output_field=DecimalField(max_digits=6, decimal_places=2)), output_field=DecimalField(max_digits=6, decimal_places=2), ), bpg_value=Coalesce( - Max("player_seasons__stats__blocks"), + Max("player_seasons__stats__blocks", filter=context_filter), Value(0, output_field=DecimalField(max_digits=6, decimal_places=2)), output_field=DecimalField(max_digits=6, decimal_places=2), ), top_efficiency=Coalesce( - Max("player_seasons__stats__player_efficiency_rating"), + Max("player_seasons__stats__player_efficiency_rating", filter=context_filter), Value(0, output_field=DecimalField(max_digits=6, decimal_places=2)), output_field=DecimalField(max_digits=6, decimal_places=2), ), diff --git a/apps/players/views.py b/apps/players/views.py index 37a118e..c0969b6 100644 --- a/apps/players/views.py +++ b/apps/players/views.py @@ -48,7 +48,7 @@ class PlayerSearchView(ListView): if form.is_valid(): queryset = filter_players(queryset, form.cleaned_data) - queryset = annotate_player_metrics(queryset) + queryset = annotate_player_metrics(queryset, form.cleaned_data) queryset = apply_sorting(queryset, form.cleaned_data.get("sort", "name_asc")) else: queryset = annotate_player_metrics(queryset).order_by("full_name", "id") diff --git a/tests/test_api.py b/tests/test_api.py index 5bde9b2..7a8e95b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -154,3 +154,45 @@ def test_player_detail_api_includes_origin_fields(client): payload = response.json() assert payload["origin_competition"] == competition.name assert payload["origin_team"] == team.name + + +@pytest.mark.django_db +def test_api_combined_filters_respect_same_player_season_context(client): + nationality = Nationality.objects.create(name="Poland", iso2_code="PL", iso3_code="POL") + competition = Competition.objects.create( + name="PLK", + slug="plk", + competition_type=Competition.CompetitionType.LEAGUE, + gender=Competition.Gender.MEN, + country=nationality, + ) + season = Season.objects.create(label="2024-2025", start_date=date(2024, 9, 1), end_date=date(2025, 6, 30)) + team_a = Team.objects.create(name="Warsaw", slug="warsaw", country=nationality) + team_b = Team.objects.create(name="Gdansk", slug="gdansk", country=nationality) + + player = Player.objects.create(first_name="Piotr", last_name="Filter", full_name="Piotr Filter", nationality=nationality) + ps_a = PlayerSeason.objects.create( + player=player, + season=season, + team=team_a, + competition=competition, + games_played=10, + minutes_played=200, + ) + PlayerSeasonStats.objects.create(player_season=ps_a, points=7, rebounds=2, assists=3, steals=1, blocks=0, turnovers=1) + ps_b = PlayerSeason.objects.create( + player=player, + season=season, + team=team_b, + competition=competition, + games_played=10, + minutes_played=300, + ) + PlayerSeasonStats.objects.create(player_season=ps_b, points=21, rebounds=4, assists=5, steals=1, blocks=0, turnovers=2) + + response = client.get( + reverse("api:players"), + data={"team": team_a.id, "season": season.id, "competition": competition.id, "points_per_game_min": "20"}, + ) + assert response.status_code == 200 + assert response.json()["count"] == 0 diff --git a/tests/test_players_filters_advanced.py b/tests/test_players_filters_advanced.py index f5d1ae0..8aa1892 100644 --- a/tests/test_players_filters_advanced.py +++ b/tests/test_players_filters_advanced.py @@ -186,3 +186,149 @@ def test_player_search_results_include_favorite_ids(client): response = client.get(reverse("players:index")) assert response.status_code == 200 assert player.id in response.context["favorite_player_ids"] + + +@pytest.mark.django_db +def test_combined_reverse_join_filters_do_not_match_across_different_player_seasons(client): + nationality = Nationality.objects.create(name="Lithuania", iso2_code="LT", iso3_code="LTU") + position = Position.objects.create(code="SG", name="Shooting Guard") + role = Role.objects.create(code="scorer", name="Scorer") + competition = Competition.objects.create( + name="LKL", + slug="lkl", + competition_type=Competition.CompetitionType.LEAGUE, + gender=Competition.Gender.MEN, + country=nationality, + ) + season = Season.objects.create(label="2025-2026", start_date=date(2025, 9, 1), end_date=date(2026, 6, 30)) + target_team = Team.objects.create(name="Kaunas", slug="kaunas", country=nationality) + other_team = Team.objects.create(name="Vilnius", slug="vilnius", country=nationality) + + player = Player.objects.create( + first_name="Jonas", + last_name="Scope", + full_name="Jonas Scope", + birth_date=date(2001, 1, 1), + nationality=nationality, + nominal_position=position, + inferred_role=role, + ) + + # Matching team/season row but low scoring. + ps_target = PlayerSeason.objects.create( + player=player, + season=season, + team=target_team, + competition=competition, + games_played=20, + minutes_played=400, + ) + PlayerSeasonStats.objects.create( + player_season=ps_target, + points=8.0, + rebounds=3.0, + assists=2.0, + steals=1.0, + blocks=0.2, + turnovers=1.5, + ) + + # High-scoring row but different team; should not satisfy combined filter. + ps_other = PlayerSeason.objects.create( + player=player, + season=season, + team=other_team, + competition=competition, + games_played=20, + minutes_played=400, + ) + PlayerSeasonStats.objects.create( + player_season=ps_other, + points=22.0, + rebounds=4.0, + assists=3.0, + steals=1.2, + blocks=0.3, + turnovers=2.0, + ) + + response = client.get( + reverse("players:index"), + data={ + "team": target_team.id, + "season": season.id, + "competition": competition.id, + "points_per_game_min": "20", + }, + ) + assert response.status_code == 200 + assert list(response.context["players"]) == [] + + +@pytest.mark.django_db +def test_displayed_metrics_are_scoped_to_filtered_context(client): + nationality = Nationality.objects.create(name="Turkey", iso2_code="TR", iso3_code="TUR") + position = Position.objects.create(code="PG", name="Point Guard") + role = Role.objects.create(code="playmaker", name="Playmaker") + competition = Competition.objects.create( + name="BSL", + slug="bsl", + competition_type=Competition.CompetitionType.LEAGUE, + gender=Competition.Gender.MEN, + country=nationality, + ) + season = Season.objects.create(label="2025-2026", start_date=date(2025, 9, 1), end_date=date(2026, 6, 30)) + target_team = Team.objects.create(name="Ankara", slug="ankara", country=nationality) + other_team = Team.objects.create(name="Izmir", slug="izmir", country=nationality) + + player = Player.objects.create( + first_name="Can", + last_name="Context", + full_name="Can Context", + birth_date=date(2000, 2, 2), + nationality=nationality, + nominal_position=position, + inferred_role=role, + ) + + ps_target = PlayerSeason.objects.create( + player=player, + season=season, + team=target_team, + competition=competition, + games_played=10, + minutes_played=250, + ) + PlayerSeasonStats.objects.create( + player_season=ps_target, + points=9.0, + rebounds=2.0, + assists=4.0, + steals=1.0, + blocks=0.1, + turnovers=2.0, + ) + + ps_other = PlayerSeason.objects.create( + player=player, + season=season, + team=other_team, + competition=competition, + games_played=12, + minutes_played=420, + ) + PlayerSeasonStats.objects.create( + player_season=ps_other, + points=24.0, + rebounds=5.0, + assists=7.0, + steals=1.5, + blocks=0.2, + turnovers=3.0, + ) + + response = client.get(reverse("players:index"), data={"team": target_team.id, "season": season.id}) + assert response.status_code == 200 + row = list(response.context["players"])[0] + assert float(row.ppg_value) == pytest.approx(9.0) + assert float(row.mpg_value) == pytest.approx(25.0)