fe4aeb3034
- Financial year page (/financial-year): year selector, 3 KPI cards (income, fixed costs, actual expenses), income and budget-items tabs with inline CRUD - Revenue accounts as income source: salary-months toggle (12/13) per account - Household support: create household, invite members by email (existing and new users via PendingHouseholdInvite), accept invitations, set roles - Combined household income view across all active members - FinancialYear, YearlyIncome, YearlyBudgetItem, Household, HouseholdMembership models with migrations; household invite email template - Management command to migrate existing accounts/budgets to financial years - FinancialYearService in Angular with full API integration - Dashboard updated: income/fixed-costs read from financial year data, year dropdown synced with available financial years - Sidebar: financial year nav item added - i18n: all keys in DE/EN/FR/IT
1459 lines
59 KiB
Python
1459 lines
59 KiB
Python
import base64
|
||
import datetime
|
||
import hmac
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import secrets
|
||
import time
|
||
import urllib.parse
|
||
import urllib.request
|
||
import pyotp
|
||
|
||
logger = logging.getLogger('armarium')
|
||
from django.conf import settings
|
||
from django.contrib.auth import get_user_model, authenticate
|
||
from django.http import HttpResponse
|
||
from icalendar import Calendar as iCalendar, Event as iCalEvent
|
||
|
||
from django.db import models
|
||
from rest_framework import viewsets, views, status
|
||
from rest_framework.response import Response
|
||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||
from rest_framework.throttling import AnonRateThrottle
|
||
from rest_framework_simplejwt.tokens import RefreshToken
|
||
from rest_framework_simplejwt.exceptions import TokenError
|
||
from django.db import transaction as db_transaction
|
||
from .models import (
|
||
Account, Transaction, Budget, Expense, Profile, Deadline, ReadEvent, BackupCode, UserSession,
|
||
Household, HouseholdMembership, FinancialYear, YearlyIncome, YearlyBudgetItem,
|
||
)
|
||
from .serializers import (
|
||
AccountSerializer, TransactionSerializer, BudgetSerializer,
|
||
ExpenseSerializer, ProfileSerializer, DeadlineSerializer, RegisterSerializer,
|
||
HouseholdSerializer, HouseholdMembershipSerializer,
|
||
FinancialYearSerializer, YearlyIncomeSerializer, YearlyBudgetItemSerializer,
|
||
)
|
||
|
||
|
||
def _verify_turnstile(token: str, remote_ip: str = '') -> bool:
    """Validate a Cloudflare Turnstile response token via the siteverify API.

    Args:
        token: The Turnstile response value submitted by the client.
        remote_ip: Optional client IP, forwarded as ``remoteip`` when non-empty.

    Returns:
        ``True`` only when the siteverify response carries ``success: true``.
        Verification is skipped (always ``True``) while ``settings.DEBUG`` is
        on. A missing token, a missing ``settings.TURNSTILE_SECRET_KEY`` or
        any request/parsing failure yields ``False``.
    """
    if settings.DEBUG:
        return True
    if not token or not settings.TURNSTILE_SECRET_KEY:
        return False

    fields = {
        'secret': settings.TURNSTILE_SECRET_KEY,
        'response': token,
    }
    # Only forward the client IP when one is actually known.
    if remote_ip:
        fields['remoteip'] = remote_ip
    data = urllib.parse.urlencode(fields).encode()

    try:
        req = urllib.request.Request(
            'https://challenges.cloudflare.com/turnstile/v0/siteverify',
            data=data,
            method='POST',
        )
        with urllib.request.urlopen(req, timeout=5) as resp:
            result = json.loads(resp.read())
        # Honour the declared ``bool`` return type: anything other than a
        # literal JSON ``true`` counts as a failed verification.
        return result.get('success') is True
    except Exception:
        # Fail closed, but keep the cause in the log for diagnosis.
        logger.warning('Turnstile verification request failed', exc_info=True)
        return False
|
||
|
||
|
||
def generate_ical_token(user_id: int) -> str:
    """Return the hex HMAC-SHA256 of ``str(user_id)`` keyed with ``settings.SECRET_KEY``."""
    key = settings.SECRET_KEY.encode()
    message = str(user_id).encode()
    return hmac.new(key, message, hashlib.sha256).hexdigest()
|
||
|
||
# Largest accepted avatar upload, in bytes (2 * 1024 * 1024).
MAX_AVATAR_SIZE_BYTES = 2 * 1024 ** 2
|
||
|
||
|
||
class AuthThrottle(AnonRateThrottle):
    """``AnonRateThrottle`` subclass configured with a rate of ``'5/min'``."""

    rate = '5/min'
|
||
|
||
|
||
class AccountViewSet(viewsets.ModelViewSet):
    """Model viewset for ``Account`` rows belonging to the requesting user."""

    serializer_class = AccountSerializer

    def get_queryset(self):
        """Restrict the queryset to accounts whose ``user`` is the requester."""
        owner = self.request.user
        return Account.objects.filter(user=owner)

    def perform_create(self, serializer):
        """Save the new account with the requester as its ``user``."""
        owner = self.request.user
        serializer.save(user=owner)
|
||
|
||
|
||
class TransactionViewSet(viewsets.ModelViewSet):
    """Model viewset for transactions whose source account belongs to the requester."""

    serializer_class = TransactionSerializer

    def get_queryset(self):
        """Restrict the queryset to transactions sent from the requester's accounts."""
        requester = self.request.user
        return Transaction.objects.filter(source_account__user=requester)

    def get_serializer_context(self):
        """Return the inherited serializer context with ``request`` set explicitly."""
        ctx = super().get_serializer_context()
        ctx.update(request=self.request)
        return ctx
|
||
|
||
|
||
class BudgetViewSet(viewsets.ModelViewSet):
    """Model viewset for budgets attached to the requester's accounts."""

    serializer_class = BudgetSerializer

    def get_queryset(self):
        """Restrict the queryset to budgets whose account belongs to the requester."""
        requester = self.request.user
        return Budget.objects.filter(account__user=requester)
|
||
|
||
|
||
class ExpenseViewSet(viewsets.ModelViewSet):
    """Model viewset for expenses attached to the requester's accounts."""

    serializer_class = ExpenseSerializer

    def get_queryset(self):
        """Restrict the queryset to expenses whose account belongs to the requester."""
        requester = self.request.user
        return Expense.objects.filter(account__user=requester)
|
||
|
||
|
||
class DeadlineViewSet(viewsets.ModelViewSet):
    """Model viewset for the requester's deadlines, ordered by ``date``."""

    serializer_class = DeadlineSerializer

    def get_queryset(self):
        """Return the requester's deadlines sorted ascending by date."""
        own_deadlines = Deadline.objects.filter(user=self.request.user)
        return own_deadlines.order_by('date')

    def perform_create(self, serializer):
        """Save the new deadline with the requester as its ``user``."""
        owner = self.request.user
        serializer.save(user=owner)
|
||
|
||
|
||
class ProfileView(views.APIView):
    """Read, update and delete the profile (and user) of the requester."""

    def get(self, request):
        """Return the serialized profile, creating one if none exists yet."""
        profile, _ = Profile.objects.get_or_create(user=request.user)
        return Response(ProfileSerializer(profile).data)

    def put(self, request):
        """Partially update the profile from ``request.data``.

        Rejects avatar uploads larger than ``MAX_AVATAR_SIZE_BYTES`` and a
        recovery email equal to the login email (both with HTTP 400). When
        the user's email differs after saving, an ``email_changed`` message
        is sent to the previous address.
        """
        from .email import send_email

        avatar = request.FILES.get('avatar_image')
        if avatar and avatar.size > MAX_AVATAR_SIZE_BYTES:
            return Response({'detail': 'Image must be smaller than 2 MB.'}, status=400)

        # The field may be absent, JSON ``null`` or a non-string value; coerce
        # it so this pre-check cannot crash and the serializer does the real
        # validation.
        recovery_email = str(request.data.get('recovery_email') or '').strip().lower()
        login_email = (request.user.email or '').lower()
        if recovery_email and recovery_email == login_email:
            return Response(
                {'recovery_email': 'Recovery email must differ from your login email.'},
                status=400,
            )

        old_email = request.user.email
        profile, _ = Profile.objects.get_or_create(user=request.user)
        serializer = ProfileSerializer(profile, data=request.data, partial=True)
        if not serializer.is_valid():
            return Response(serializer.errors, status=400)

        serializer.save()
        new_email = request.user.email
        if new_email != old_email:
            # Notify the *previous* address so the owner notices the change.
            send_email(
                'email_changed',
                {'new_email': new_email},
                'Armarium – Deine E-Mail-Adresse wurde geändert',
                old_email,
            )
        return Response(serializer.data)

    def delete(self, request):
        """Delete the requesting user after confirming their password.

        Responds with HTTP 403 when the password is missing or wrong and with
        HTTP 204 once the user has been deleted.
        """
        password = request.data.get('password', '')
        if not password or not request.user.check_password(password):
            return Response({'detail': 'Passwort ungültig.'}, status=status.HTTP_403_FORBIDDEN)
        request.user.delete()
        return Response(status=status.HTTP_204_NO_CONTENT)
|
||
|
||
|
||
class RegisterView(views.APIView):
    """Create a user and email a verification link valid for 24 hours."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Register a new user after a successful captcha check.

        Only the SHA-256 hash of the verification token is stored on the
        profile; the plain token is embedded in the emailed link.
        """
        from datetime import timedelta

        from django.utils import timezone

        from .email import send_email

        captcha_ok = _verify_turnstile(
            request.data.get('cf_turnstile_response', ''),
            request.META.get('REMOTE_ADDR', ''),
        )
        if not captcha_ok:
            return Response({'detail': 'Captcha verification failed.'}, status=status.HTTP_400_BAD_REQUEST)

        serializer = RegisterSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        new_user = serializer.save()
        plain_token = secrets.token_urlsafe(32)
        profile, _ = Profile.objects.get_or_create(user=new_user)
        profile.email_verify_token = hashlib.sha256(plain_token.encode()).hexdigest()
        profile.email_verify_token_expires = timezone.now() + timedelta(hours=24)
        profile.save(update_fields=['email_verify_token', 'email_verify_token_expires'])

        verify_link = f"{settings.FRONTEND_URL}/verify-email?token={plain_token}"
        send_email(
            'registration_confirm',
            {'link': verify_link},
            'Armarium – E-Mail-Adresse bestätigen',
            new_user.email,
        )
        return Response({'detail': 'Account created.'}, status=status.HTTP_201_CREATED)
|
||
|
||
|
||
class LogoutView(views.APIView):
    """Blacklist a refresh token and remove the matching session rows."""

    permission_classes = [AllowAny]

    def post(self, request):
        """Log out using the ``refresh`` token from the request body.

        An already-invalid token is treated as success. Sessions are removed
        both by the token's ``jti`` and by the ``X-Session-Key`` header.
        """
        raw_refresh = request.data.get('refresh')
        if not raw_refresh:
            return Response({'detail': 'Refresh token required.'}, status=400)

        try:
            parsed = RefreshToken(raw_refresh)
            jti = parsed.payload.get('jti', '')
            parsed.blacklist()
        except TokenError:
            # Token already invalid: nothing to blacklist, still a success.
            pass
        else:
            if jti:
                UserSession.objects.filter(refresh_jti=jti).delete()

        header_key = request.headers.get('X-Session-Key', '')
        if header_key:
            UserSession.objects.filter(session_key=header_key).delete()

        return Response(status=status.HTTP_204_NO_CONTENT)
|
||
|
||
|
||
class ChangePasswordView(views.APIView):
    """Set a new password for the requester and revoke their other sessions."""

    def post(self, request):
        """Change the password to ``request.data['password']``.

        The password must be a string of at least eight characters; anything
        else is answered with HTTP 400. Every session except the one named by
        the ``X-Session-Key`` header is blacklisted, and a ``password_changed``
        email is sent to the user.
        """
        from .email import send_email

        password = request.data.get('password', '')
        # JSON bodies can carry null/number/list here; reject those with the
        # same 400 instead of letting ``len()`` raise.
        if not isinstance(password, str) or len(password) < 8:
            return Response({'detail': 'Password must be at least 8 characters.'}, status=400)

        # NOTE(review): the current password is not requested here — confirm
        # that this is intended for an already-authenticated session.
        request.user.set_password(password)
        request.user.save()

        current_key = request.headers.get('X-Session-Key', '')
        other_sessions = UserSession.objects.filter(user=request.user).exclude(session_key=current_key)
        for session in other_sessions:
            _blacklist_session(session)

        send_email('password_changed', {}, 'Armarium – Dein Passwort wurde geändert', request.user.email)
        return Response({'detail': 'Password updated.'})
|
||
|
||
|
||
class SearchView(views.APIView):
    """Global search across all user resources."""

    def get(self, request):
        """Search the requester's resources for the ``q`` query parameter.

        Queries shorter than two characters return an empty object. Each
        resource type contributes at most five hits and is only present in
        the response when it has at least one.
        """
        term = request.query_params.get('q', '').strip()
        if len(term) < 2:
            return Response({})

        owner = request.user

        # (response key, lazily evaluated queryset, row -> result dict)
        sections = (
            (
                'accounts',
                Account.objects.filter(user=owner, name__icontains=term)[:5],
                lambda a: {'id': a.id, 'title': a.name, 'subtitle': a.get_account_type_display()},
            ),
            (
                'budgets',
                Budget.objects.filter(account__user=owner, name__icontains=term)[:5],
                lambda b: {'id': b.id, 'title': b.name, 'subtitle': f'CHF {b.amount}'},
            ),
            (
                'expenses',
                Expense.objects.filter(account__user=owner, name__icontains=term)[:5],
                lambda e: {'id': e.id, 'title': e.name, 'subtitle': f'{e.date} · CHF {e.amount}'},
            ),
            (
                'transactions',
                Transaction.objects.filter(source_account__user=owner, description__icontains=term)[:5],
                lambda t: {'id': t.id, 'title': t.description, 'subtitle': f'{t.date} · CHF {t.amount}'},
            ),
            (
                'deadlines',
                Deadline.objects.filter(user=owner, title__icontains=term)[:5],
                lambda d: {'id': d.id, 'title': d.title, 'subtitle': str(d.date), 'date': str(d.date)},
            ),
        )

        found = {}
        for key, rows, describe in sections:
            if rows:
                found[key] = [describe(row) for row in rows]
        return Response(found)
|
||
|
||
|
||
class NotificationsView(views.APIView):
    """Returns all unread active events (date <= today) for the authenticated user."""

    def get(self, request):
        """List unread deadlines and expense due dates up to today, sorted by date."""
        from datetime import date

        today = date.today()

        read_deadlines = set(
            ReadEvent.objects.filter(user=request.user, event_type='deadline')
            .values_list('event_id', flat=True)
        )
        read_expenses = set(
            ReadEvent.objects.filter(user=request.user, event_type='expense')
            .values_list('event_id', flat=True)
        )

        notifications = [
            {
                'event_type': 'deadline',
                'event_id': d.id,
                'title': d.title,
                'date': str(d.date),
            }
            for d in Deadline.objects.filter(user=request.user, date__lte=today)
            if d.id not in read_deadlines
        ]
        notifications += [
            {
                'event_type': 'expense',
                'event_id': e.id,
                'title': e.name,
                'date': str(e.due_date),
            }
            for e in Expense.objects.filter(account__user=request.user, due_date__lte=today)
            if e.id not in read_expenses
        ]

        notifications.sort(key=lambda item: item['date'])
        return Response(notifications)

    def post(self, request):
        """Mark a single event as read.

        Expects ``event_type`` (``'deadline'`` or ``'expense'``) and a positive
        integer ``event_id``. Anything else is answered with HTTP 400.
        """
        event_type = request.data.get('event_type')
        if event_type not in ('deadline', 'expense'):
            return Response({'detail': 'Invalid payload.'}, status=400)

        # Validate the id here so malformed input yields a 400 rather than an
        # error raised from the database layer.
        try:
            event_id = int(request.data.get('event_id'))
        except (TypeError, ValueError):
            return Response({'detail': 'Invalid payload.'}, status=400)
        if event_id <= 0:
            return Response({'detail': 'Invalid payload.'}, status=400)

        ReadEvent.objects.get_or_create(
            user=request.user, event_type=event_type, event_id=event_id
        )
        return Response(status=status.HTTP_204_NO_CONTENT)
|
||
|
||
|
||
class ICalUrlView(views.APIView):
    """Returns the personal iCal feed URL for the authenticated user."""

    def get(self, request):
        """Build the absolute feed URL containing the user id and its HMAC token."""
        uid = request.user.id
        feed_token = generate_ical_token(uid)
        root = request.build_absolute_uri('/')
        return Response({'url': f"{root}api/calendar/ical/{uid}/{feed_token}/"})
|
||
|
||
|
||
class ICalFeedView(views.APIView):
    """Serves the iCal feed. Token acts as authentication — no JWT required."""

    permission_classes = [AllowAny]

    def get(self, request, user_id, token):
        """Return the user's deadlines and expense due dates as an ``.ics`` file.

        Responds with HTTP 404 when the token does not match the one derived
        for ``user_id`` or when the user does not exist.
        """
        expected = generate_ical_token(user_id)
        # Compare as bytes: with ``str`` arguments ``hmac.compare_digest``
        # raises TypeError on non-ASCII input, which a crafted URL could
        # trigger.
        if not hmac.compare_digest(expected.encode(), str(token).encode()):
            return HttpResponse(status=404)

        User = get_user_model()
        try:
            user = User.objects.get(pk=user_id)
        except User.DoesNotExist:
            return HttpResponse(status=404)

        cal = iCalendar()
        cal.add('prodid', '-//Budget App//EN')
        cal.add('version', '2.0')
        cal.add('x-wr-calname', 'Budget App')
        cal.add('x-wr-timezone', 'Europe/Zurich')

        # NOTE(review): DTEND is set equal to DTSTART below; if these values
        # are plain dates, confirm calendar clients render them as intended.

        # Deadlines
        for deadline in Deadline.objects.filter(user=user):
            event = iCalEvent()
            event.add('summary', f'[{deadline.get_type_display()}] {deadline.title}')
            event.add('dtstart', deadline.date)
            event.add('dtend', deadline.date)
            event.add('uid', f'deadline-{deadline.id}@budget-app')
            if deadline.notes:
                event.add('description', deadline.notes)
            cal.add_component(event)

        # Expense due dates
        for expense in Expense.objects.filter(account__user=user, due_date__isnull=False):
            event = iCalEvent()
            event.add('summary', f'[Invoice] {expense.name} – CHF {expense.amount}')
            event.add('dtstart', expense.due_date)
            event.add('dtend', expense.due_date)
            event.add('uid', f'expense-{expense.id}@budget-app')
            if expense.notes:
                event.add('description', expense.notes)
            cal.add_component(event)

        response = HttpResponse(cal.to_ical(), content_type='text/calendar; charset=utf-8')
        response['Content-Disposition'] = 'attachment; filename="budget-app.ics"'
        return response
|
||
|
||
|
||
# ── 2FA helpers ──────────────────────────────────────────────────────────────
|
||
|
||
def _make_2fa_token(user_id: int) -> str:
    """Create a short-lived signed token binding step-1 to step-2 of login.

    The token is the URL-safe base64 encoding of ``"<user_id>:<unix_ts>:<sig>"``
    where ``sig`` is the hex HMAC-SHA256 of ``"<user_id>:<unix_ts>"`` keyed
    with ``settings.SECRET_KEY``.
    """
    issued_at = int(time.time())
    body = f"{user_id}:{issued_at}"
    mac = hmac.new(settings.SECRET_KEY.encode(), body.encode(), hashlib.sha256).hexdigest()
    raw = f"{body}:{mac}".encode()
    return base64.urlsafe_b64encode(raw).decode()
|
||
|
||
|
||
def _verify_2fa_token(token: str, max_age: int = 300) -> int | None:
    """Return user_id if token is valid and not expired, else None.

    Args:
        token: Value previously produced by ``_make_2fa_token``.
        max_age: Maximum accepted token age in seconds.

    Any decoding, structure or signature problem results in ``None``.
    """
    try:
        decoded = base64.urlsafe_b64decode(token.encode()).decode()
        # Layout is "<user_id>:<timestamp>:<signature>".
        body, _, sig = decoded.rpartition(':')
        user_id_str, ts_str = body.split(':')
        expected = hmac.new(settings.SECRET_KEY.encode(), body.encode(), hashlib.sha256).hexdigest()
        if not hmac.compare_digest(expected, sig):
            return None
        age = int(time.time()) - int(ts_str)
        return int(user_id_str) if age <= max_age else None
    except Exception:
        return None
|
||
|
||
|
||
def _generate_backup_codes(user, count: int = 8) -> list[str]:
    """Invalidate all old backup codes and return a fresh set of plain-text codes.

    Args:
        user: Owner of the backup codes.
        count: Number of codes to generate.

    Returns:
        The plain-text codes in ``XXXXXXXX-XXXXXXXX`` upper-case hex format.
        Only their SHA-256 hashes are stored.
    """
    plain = [
        f"{secrets.token_hex(4).upper()}-{secrets.token_hex(4).upper()}"
        for _ in range(count)
    ]
    # Replace the old set atomically so a failure midway cannot leave the
    # user with no codes or only part of the new set.
    with db_transaction.atomic():
        BackupCode.objects.filter(user=user).delete()
        for code in plain:
            BackupCode.objects.create(
                user=user,
                code_hash=hashlib.sha256(code.encode()).hexdigest(),
            )
    return plain
|
||
|
||
|
||
def _verify_totp_with_replay_check(profile, code: str) -> bool:
    """Verify a TOTP code and refuse the most recently accepted one.

    A code equal to ``profile.totp_last_used_code`` is rejected. An accepted
    code is stored on the profile so it cannot be used again immediately.
    """
    if code == profile.totp_last_used_code:
        return False
    if pyotp.TOTP(profile.totp_secret).verify(code, valid_window=1):
        profile.totp_last_used_code = code
        profile.save(update_fields=['totp_last_used_code'])
        return True
    return False
|
||
|
||
|
||
# ── 2FA views ─────────────────────────────────────────────────────────────────
|
||
|
||
class LoginView(views.APIView):
    """Replaces TokenObtainPairView. Returns a short-lived temp_token when 2FA is required."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Authenticate with ``username``/``password`` after a captcha check.

        Returns JWT tokens plus a session key, or ``2fa_required`` with a
        ``temp_token`` when the profile has TOTP enabled.
        """
        captcha_ok = _verify_turnstile(
            request.data.get('cf_turnstile_response', ''),
            request.META.get('REMOTE_ADDR', ''),
        )
        if not captcha_ok:
            return Response({'detail': 'Captcha verification failed.'}, status=status.HTTP_400_BAD_REQUEST)

        user = authenticate(
            request,
            username=request.data.get('username', ''),
            password=request.data.get('password', ''),
        )
        if user is None:
            return Response({'detail': 'No active account found with the given credentials.'}, status=401)

        profile, _ = Profile.objects.get_or_create(user=user)
        if profile.totp_enabled:
            challenge = {'2fa_required': True, 'temp_token': _make_2fa_token(user.id)}
            return Response(challenge, status=200)

        refresh = RefreshToken.for_user(user)
        session_key = _create_session(user, request, refresh)
        tokens = {
            'access': str(refresh.access_token),
            'refresh': str(refresh),
            'session_key': session_key,
        }
        return Response(tokens)
|
||
|
||
|
||
class TwoFactorLoginView(views.APIView):
    """Step 2 of login — accepts TOTP code or backup code, returns JWT tokens."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Complete a 2FA login.

        ``temp_token`` must be a valid token from step 1. ``code`` is treated
        as a TOTP code when it is exactly six digits and as a backup code
        otherwise. A backup code can be redeemed only once.
        """
        temp_token = request.data.get('temp_token', '')
        code = str(request.data.get('code', '')).strip()

        user_id = _verify_2fa_token(temp_token)
        if user_id is None:
            return Response({'detail': 'Session expired. Please log in again.'}, status=401)

        User = get_user_model()
        try:
            user = User.objects.get(pk=user_id)
        except User.DoesNotExist:
            return Response({'detail': 'Invalid credentials.'}, status=401)

        profile, _ = Profile.objects.get_or_create(user=user)
        if not profile.totp_enabled or not profile.totp_secret:
            return Response({'detail': 'Invalid credentials.'}, status=401)

        if code.isdigit() and len(code) == 6:
            if not _verify_totp_with_replay_check(profile, code):
                return Response({'detail': 'Invalid or already used code.'}, status=400)
        else:
            code_hash = hashlib.sha256(code.encode()).hexdigest()
            # Claim the code with one conditional UPDATE so that two
            # concurrent requests cannot both redeem the same backup code.
            claimed = BackupCode.objects.filter(
                user=user, code_hash=code_hash, used=False
            ).update(used=True)
            if not claimed:
                return Response({'detail': 'Invalid backup code.'}, status=400)

        refresh = RefreshToken.for_user(user)
        session_key = _create_session(user, request, refresh)
        return Response({'access': str(refresh.access_token), 'refresh': str(refresh), 'session_key': session_key})
|
||
|
||
|
||
class TwoFactorSetupView(views.APIView):
    """Generates a fresh TOTP secret and returns the otpauth:// URI for QR display."""

    def get(self, request):
        """Store a new, not yet enabled TOTP secret and return its provisioning URI.

        Responds with HTTP 400 when 2FA is already enabled: replacing the
        secret would switch 2FA off without the proof (TOTP or backup code)
        that ``TwoFactorDisableView`` requires.
        """
        profile, _ = Profile.objects.get_or_create(user=request.user)
        if profile.totp_enabled:
            return Response({'detail': '2FA is already enabled. Disable it first.'}, status=400)

        secret = pyotp.random_base32()
        profile.totp_secret = secret
        profile.totp_enabled = False
        profile.save(update_fields=['totp_secret', 'totp_enabled'])

        # Label shown in the authenticator app.
        email = request.user.email or request.user.username
        uri = pyotp.TOTP(secret).provisioning_uri(name=email, issuer_name='Armarium')
        return Response({'uri': uri})
|
||
|
||
|
||
class TwoFactorEnableView(views.APIView):
    """Verifies the first TOTP code, activates 2FA and returns one-time backup codes."""

    def post(self, request):
        """Enable 2FA when ``code`` verifies against the stored TOTP secret."""
        submitted = str(request.data.get('code', '')).strip()
        profile, _ = Profile.objects.get_or_create(user=request.user)

        if not profile.totp_secret:
            return Response({'detail': 'Run setup first.'}, status=400)

        code_ok = _verify_totp_with_replay_check(profile, submitted)
        if not code_ok:
            return Response({'detail': 'Invalid code.'}, status=400)

        profile.totp_enabled = True
        profile.save(update_fields=['totp_enabled'])

        fresh_codes = _generate_backup_codes(request.user)
        return Response({'detail': '2FA enabled.', 'backup_codes': fresh_codes})
|
||
|
||
|
||
class TwoFactorDisableView(views.APIView):
    """Disables 2FA — accepts TOTP code or a backup code as proof."""

    def post(self, request):
        """Disable 2FA after verifying ``code``.

        A six-digit ``code`` is checked as TOTP; anything else is checked as
        an unused backup code. On success the TOTP secret, the last-used code
        and all backup codes are cleared.
        """
        code = str(request.data.get('code', '')).strip()
        profile, _ = Profile.objects.get_or_create(user=request.user)
        if not profile.totp_enabled:
            return Response({'detail': '2FA is not enabled.'}, status=400)

        if code.isdigit() and len(code) == 6:
            authenticated = _verify_totp_with_replay_check(profile, code)
        else:
            code_hash = hashlib.sha256(code.encode()).hexdigest()
            # Claim the code with one conditional UPDATE so that concurrent
            # requests cannot both redeem the same backup code.
            claimed = BackupCode.objects.filter(
                user=request.user, code_hash=code_hash, used=False
            ).update(used=True)
            authenticated = bool(claimed)

        if not authenticated:
            return Response({'detail': 'Invalid code.'}, status=400)

        profile.totp_enabled = False
        profile.totp_secret = ''
        profile.totp_last_used_code = ''
        profile.save(update_fields=['totp_enabled', 'totp_secret', 'totp_last_used_code'])
        BackupCode.objects.filter(user=request.user).delete()
        return Response({'detail': '2FA disabled.'})
|
||
|
||
|
||
# ── Recovery email helpers ────────────────────────────────────────────────────
|
||
|
||
def _mask_email(email: str) -> str:
|
||
if '@' not in email:
|
||
return '***'
|
||
local, domain = email.split('@', 1)
|
||
return f"{local[0]}{'*' * min(len(local) - 1, 18)}@{domain}"
|
||
|
||
|
||
def _generate_recovery_code() -> str:
|
||
"""Generate a human-readable 8-character code in XXXX-XXXX format."""
|
||
alphabet = 'ABCDEFGHJKLMNPQRSTUVWXYZ23456789' # no O/0, I/1 confusion
|
||
part = lambda: ''.join(secrets.choice(alphabet) for _ in range(4))
|
||
return f"{part()}-{part()}"
|
||
|
||
|
||
class TwoFactorRecoverRequestView(views.APIView):
    """Generate a recovery code, store its hash in Profile and email the plain code."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Email a recovery code, valid for 15 minutes, to the recovery address.

        An invalid ``temp_token``, an unknown user or a missing recovery
        email are all answered with the same plain ``{'detail': 'ok'}``.
        """
        from datetime import timedelta

        from django.utils import timezone

        from .email import send_email

        def quiet_ok():
            return Response({'detail': 'ok'})

        user_id = _verify_2fa_token(request.data.get('temp_token', ''))
        if user_id is None:
            return quiet_ok()

        account = get_user_model().objects.filter(pk=user_id).first()
        if not account:
            return quiet_ok()

        profile = Profile.objects.filter(user=account).first()
        if not (profile and profile.recovery_email):
            return quiet_ok()

        plain_code = _generate_recovery_code()
        profile.recovery_code_hash = hashlib.sha256(plain_code.encode()).hexdigest()
        profile.recovery_code_expires = timezone.now() + timedelta(minutes=15)
        profile.save(update_fields=['recovery_code_hash', 'recovery_code_expires'])

        delivered = send_email(
            template_name='2fa_recovery',
            context={'code': plain_code},
            subject='Armarium – 2FA-Wiederherstellung',
            to=profile.recovery_email,
        )
        if not delivered:
            return Response({'detail': 'Failed to send recovery email.'}, status=500)

        return Response({'detail': 'ok', 'masked_email': _mask_email(profile.recovery_email)})
|
||
|
||
|
||
class TwoFactorRecoverConfirmView(views.APIView):
    """Verify the recovery code, disable 2FA and return JWT tokens."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Redeem an unexpired recovery code for the user named by ``temp_token``.

        On success the profile's TOTP and recovery fields are cleared, all
        backup codes are deleted and a new session with JWT tokens is issued.
        """
        from django.utils import timezone

        user_id = _verify_2fa_token(request.data.get('temp_token', ''))
        if user_id is None:
            return Response({'detail': 'Session expired. Please log in again.'}, status=401)

        submitted = str(request.data.get('recovery_code', '')).strip().upper()
        if not submitted:
            return Response({'detail': 'Code required.'}, status=400)

        profile = Profile.objects.filter(
            user_id=user_id,
            recovery_code_hash=hashlib.sha256(submitted.encode()).hexdigest(),
            recovery_code_expires__gt=timezone.now(),
        ).first()
        if not profile:
            return Response({'detail': 'Invalid or expired recovery code.'}, status=400)

        # Reset every 2FA / recovery field in one save.
        resets = {
            'totp_enabled': False,
            'totp_secret': '',
            'totp_last_used_code': '',
            'recovery_code_hash': '',
            'recovery_code_expires': None,
        }
        for field, value in resets.items():
            setattr(profile, field, value)
        profile.save(update_fields=list(resets))
        BackupCode.objects.filter(user=profile.user).delete()

        refresh = RefreshToken.for_user(profile.user)
        session_key = _create_session(profile.user, request, refresh)
        return Response({'access': str(refresh.access_token), 'refresh': str(refresh), 'session_key': session_key})
|
||
|
||
|
||
# ── Session helpers ───────────────────────────────────────────────────────────
|
||
|
||
def _parse_device(ua: str) -> str:
|
||
ua = ua.lower()
|
||
if 'iphone' in ua: return 'iPhone'
|
||
if 'ipad' in ua: return 'iPad'
|
||
if 'android' in ua and 'mobile' in ua: return 'Android (Phone)'
|
||
if 'android' in ua: return 'Android (Tablet)'
|
||
if 'macintosh' in ua or 'mac os x' in ua: return 'Mac'
|
||
if 'windows nt' in ua: return 'Windows'
|
||
if 'linux' in ua: return 'Linux'
|
||
return 'Unbekanntes Gerät'
|
||
|
||
|
||
def _get_client_ip(request) -> str | None:
|
||
forwarded = request.META.get('HTTP_X_FORWARDED_FOR', '')
|
||
if forwarded:
|
||
return forwarded.split(',')[0].strip()
|
||
return request.META.get('REMOTE_ADDR') or None
|
||
|
||
|
||
def _create_session(user, request, refresh_token: RefreshToken) -> str:
    """Create a ``UserSession`` row for a fresh login and return its session key.

    The row records a random session key, the refresh token's ``jti``, a
    device label derived from the User-Agent and the client IP.
    """
    fields = {
        'user': user,
        'session_key': secrets.token_urlsafe(32),
        'refresh_jti': str(refresh_token.payload.get('jti', '')),
        'device_name': _parse_device(request.META.get('HTTP_USER_AGENT', '')),
        'ip_address': _get_client_ip(request),
    }
    UserSession.objects.create(**fields)
    return fields['session_key']
|
||
|
||
|
||
# ── Session views ─────────────────────────────────────────────────────────────
|
||
|
||
class SessionListView(views.APIView):
    """Lists the requester's sessions and flags the current one."""

    def get(self, request):
        """Return every session of the user; ``is_current`` marks the ``X-Session-Key`` match."""
        active_key = request.headers.get('X-Session-Key', '')
        payload = []
        for entry in UserSession.objects.filter(user=request.user):
            payload.append({
                'session_key': entry.session_key,
                'device_name': entry.device_name,
                'ip_address': entry.ip_address,
                'created_at': entry.created_at,
                'last_active_at': entry.last_active_at,
                'is_current': entry.session_key == active_key,
            })
        return Response(payload)
|
||
|
||
|
||
class SessionRevokeView(views.APIView):
    """Revokes one of the requester's sessions by its session key."""

    def delete(self, request, session_key):
        """Blacklist and delete the named session; HTTP 404 if it is not the user's."""
        own_sessions = UserSession.objects.filter(user=request.user, session_key=session_key)
        target = own_sessions.first()
        if target:
            _blacklist_session(target)
            return Response(status=204)
        return Response({'detail': 'Not found.'}, status=404)
|
||
|
||
|
||
class SessionRevokeAllView(views.APIView):
    """Revokes every session of the requester except the current one."""

    def delete(self, request):
        """Blacklist and delete all sessions not matching the ``X-Session-Key`` header."""
        keep_key = request.headers.get('X-Session-Key', '')
        others = UserSession.objects.filter(user=request.user).exclude(session_key=keep_key)
        for stale in others:
            _blacklist_session(stale)
        return Response(status=204)
|
||
|
||
|
||
def _blacklist_session(session: UserSession) -> None:
    """Blacklist the session's refresh token (best effort) and delete the session.

    The session row is always deleted. A refresh token that is not found
    among the outstanding tokens is skipped. Any other failure while
    blacklisting is logged and does not prevent the deletion.
    """
    if session.refresh_jti:
        try:
            from rest_framework_simplejwt.token_blacklist.models import OutstandingToken, BlacklistedToken

            outstanding = OutstandingToken.objects.filter(jti=session.refresh_jti).first()
            if outstanding is not None:
                BlacklistedToken.objects.get_or_create(token=outstanding)
        except Exception:
            # Blacklisting stays best-effort, but failures must be visible.
            logger.warning(
                'Could not blacklist refresh token for session %s',
                session.pk,
                exc_info=True,
            )
    session.delete()
|
||
|
||
|
||
# ── Data export ───────────────────────────────────────────────────────────────
|
||
|
||
class DataExportView(views.APIView):
|
||
def get(self, request):
|
||
import io
|
||
import zipfile
|
||
from datetime import date
|
||
from fpdf import FPDF
|
||
|
||
user = request.user
|
||
profile = Profile.objects.filter(user=user).first()
|
||
today = date.today().strftime('%d.%m.%Y')
|
||
export_date = date.today().strftime('%Y-%m-%d')
|
||
|
||
VIOLET = (124, 58, 237)
|
||
HEADER_BG = (243, 240, 255)
|
||
ALT_ROW = (249, 249, 252)
|
||
TEXT_DARK = (30, 30, 40)
|
||
TEXT_GRAY = (120, 120, 135)
|
||
|
||
def safe(text: str) -> str:
|
||
return str(text).encode('latin-1', errors='replace').decode('latin-1')
|
||
|
||
class ArmPDF(FPDF):
|
||
def __init__(self, section_title):
|
||
super().__init__()
|
||
self.section_title = section_title
|
||
self.set_auto_page_break(auto=True, margin=18)
|
||
|
||
def header(self):
|
||
self.set_fill_color(*VIOLET)
|
||
self.rect(0, 0, 210, 10, 'F')
|
||
self.set_xy(14, 13)
|
||
self.set_font('Helvetica', 'B', 15)
|
||
self.set_text_color(*TEXT_DARK)
|
||
self.cell(0, 7, safe(f'Armarium - {self.section_title}'), ln=True)
|
||
self.set_font('Helvetica', '', 8)
|
||
self.set_text_color(*TEXT_GRAY)
|
||
self.set_x(14)
|
||
self.cell(0, 5, safe(f'Export vom {today}'), ln=True)
|
||
self.ln(3)
|
||
|
||
def footer(self):
|
||
self.set_y(-12)
|
||
self.set_font('Helvetica', '', 7)
|
||
self.set_text_color(*TEXT_GRAY)
|
||
self.cell(0, 5, safe(f'Armarium - {today} - Seite {self.page_no()}'), align='C')
|
||
|
||
def table_header(self, cols):
|
||
self.set_fill_color(*HEADER_BG)
|
||
self.set_font('Helvetica', 'B', 8)
|
||
self.set_text_color(*VIOLET)
|
||
for label, width in cols:
|
||
self.cell(width, 7, safe(label), border=0, fill=True, align='L')
|
||
self.ln()
|
||
self.set_draw_color(*VIOLET)
|
||
self.set_line_width(0.4)
|
||
x = self.get_x()
|
||
y = self.get_y()
|
||
self.line(14, y, 196, y)
|
||
|
||
def table_row(self, values, cols, fill=False):
|
||
if fill:
|
||
self.set_fill_color(*ALT_ROW)
|
||
self.set_font('Helvetica', '', 8)
|
||
self.set_text_color(*TEXT_DARK)
|
||
for (label, width), val in zip(cols, values):
|
||
self.cell(width, 6, safe(val), border=0, fill=fill)
|
||
self.ln()
|
||
|
||
def make_pdf(title, build_fn):
|
||
pdf = ArmPDF(title)
|
||
pdf.add_page()
|
||
build_fn(pdf)
|
||
return pdf.output()
|
||
|
||
# ── Profil ────────────────────────────────────────────────────────────
|
||
def build_profile(pdf):
|
||
name = f"{profile.first_name} {profile.last_name}".strip() if profile else ''
|
||
rows = [
|
||
('Name', name or '-'),
|
||
('E-Mail', user.email or '-'),
|
||
('Kanton', profile.canton if profile else '-'),
|
||
('Sprache', profile.language if profile else '-'),
|
||
('2FA', 'Aktiviert' if (profile and profile.totp_enabled) else 'Deaktiviert'),
|
||
]
|
||
cols = [('Feld', 60), ('Wert', 120)]
|
||
pdf.table_header(cols)
|
||
for i, (field, val) in enumerate(rows):
|
||
pdf.table_row([field, val], cols, fill=i % 2 == 1)
|
||
|
||
# ── Konten ────────────────────────────────────────────────────────────
|
||
def build_accounts(pdf):
    """Render one table row per account of the user: name, type and balance."""
    columns = [('Name', 90), ('Typ', 60), ('Saldo (CHF)', 42)]
    pdf.table_header(columns)
    shaded = False  # toggled per row so every second row is filled
    for account in Account.objects.filter(user=user):
        cells = [account.name, account.account_type, f'{account.balance:,.2f}']
        pdf.table_row(cells, columns, fill=shaded)
        shaded = not shaded
|
||
|
||
# ── Budgets ───────────────────────────────────────────────────────────
|
||
def build_budgets(pdf):
    """Render the user's budgets, ordered by main category and then name."""
    columns = [('Name', 80), ('Kategorie', 60), ('Betrag (CHF)', 42), ('Aktiv', 10)]
    pdf.table_header(columns)
    budgets = Budget.objects.filter(account__user=user).order_by('main_category', 'name')
    for row_number, budget in enumerate(budgets):
        active_label = 'Ja' if budget.active else 'Nein'
        pdf.table_row(
            [budget.name, budget.main_category, f'{budget.amount:,.2f}', active_label],
            columns,
            fill=bool(row_number % 2),
        )
|
||
|
||
# ── Ausgaben ──────────────────────────────────────────────────────────
|
||
def build_expenses(pdf):
    """Render all expenses on the user's accounts, newest first.

    select_related('account') loads each expense's account in the same query,
    so reading e.account.name per row does not trigger one extra query per
    expense.
    """
    cols = [('Datum', 26), ('Name', 70), ('Kategorie', 46), ('Konto', 30), ('CHF', 20)]
    pdf.table_header(cols)
    expenses = (
        Expense.objects
        .filter(account__user=user)
        .select_related('account')
        .order_by('-date')
    )
    for i, e in enumerate(expenses):
        pdf.table_row([
            e.date.strftime('%d.%m.%Y'), e.name, e.category,
            e.account.name, f'{e.amount:,.2f}'
        ], cols, fill=i % 2 == 1)
|
||
|
||
# ── Transaktionen ─────────────────────────────────────────────────────
|
||
def build_transactions(pdf):
    """Render the transactions whose source account belongs to the user, newest first."""
    columns = [('Datum', 26), ('Beschreibung', 70), ('Von', 38), ('Nach', 38), ('CHF', 20)]
    pdf.table_header(columns)
    # Both accounts are joined in so the row loop reads their names without extra queries.
    transfers = (
        Transaction.objects
        .filter(source_account__user=user)
        .order_by('-date')
        .select_related('source_account', 'destination_account')
    )
    shaded = False
    for transfer in transfers:
        pdf.table_row(
            [
                transfer.date.strftime('%d.%m.%Y'),
                transfer.description,
                transfer.source_account.name,
                transfer.destination_account.name,
                f'{transfer.amount:,.2f}',
            ],
            columns,
            fill=shaded,
        )
        shaded = not shaded
|
||
|
||
# ── Termine ───────────────────────────────────────────────────────────
|
||
def build_deadlines(pdf):
    """Render the user's deadlines in ascending date order.

    Notes are cut to their first 20 characters; a deadline without notes
    (None) renders an empty cell instead of raising on the slice.
    """
    cols = [('Datum', 30), ('Titel', 100), ('Typ', 42), ('Notizen', 20)]
    pdf.table_header(cols)
    for i, d in enumerate(Deadline.objects.filter(user=user).order_by('date')):
        notes_preview = (d.notes or '')[:20]
        pdf.table_row(
            [d.date.strftime('%d.%m.%Y'), d.title, d.type, notes_preview],
            cols,
            fill=i % 2 == 1,
        )
|
||
|
||
pdfs = [
|
||
('profil.pdf', 'Profil', build_profile),
|
||
('konten.pdf', 'Konten', build_accounts),
|
||
('budgets.pdf', 'Budgets', build_budgets),
|
||
('ausgaben.pdf', 'Ausgaben', build_expenses),
|
||
('transaktionen.pdf', 'Transaktionen', build_transactions),
|
||
('termine.pdf', 'Termine', build_deadlines),
|
||
]
|
||
|
||
zip_buffer = io.BytesIO()
|
||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf:
|
||
for filename, title, build_fn in pdfs:
|
||
zf.writestr(filename, bytes(make_pdf(title, build_fn)))
|
||
zip_buffer.seek(0)
|
||
|
||
response = HttpResponse(zip_buffer.read(), content_type='application/zip')
|
||
response['Content-Disposition'] = f'attachment; filename="armarium-export-{export_date}.zip"'
|
||
return response
|
||
|
||
|
||
# ── Notification preferences ──────────────────────────────────────────────────
|
||
|
||
class NotificationPrefsView(views.APIView):
    """PATCH endpoint that updates the caller's notification switches.

    Only the switches present in the request body are changed; the response
    always carries the current value of all three.
    """

    # Profile fields a client may set through this endpoint.
    _FIELDS = ('notif_deadlines', 'notif_budget_alerts', 'notif_monthly_summary')
    # String spellings read as "off". A plain bool() would turn every
    # non-empty string, including "false" and "0", into True.
    _FALSE_STRINGS = frozenset({'', '0', 'false', 'no', 'off'})

    @classmethod
    def _as_bool(cls, value):
        """Coerce a request value to bool, honouring textual false values."""
        if isinstance(value, str):
            return value.strip().lower() not in cls._FALSE_STRINGS
        return bool(value)

    def patch(self, request):
        """Apply the submitted switches to the caller's profile and return all of them."""
        profile, _ = Profile.objects.get_or_create(user=request.user)
        changed = []
        for field in self._FIELDS:
            if field in request.data:
                setattr(profile, field, self._as_bool(request.data[field]))
                changed.append(field)
        if changed:
            # Write only the columns that were actually submitted.
            profile.save(update_fields=changed)
        return Response({field: getattr(profile, field) for field in self._FIELDS})
|
||
|
||
|
||
class VerifyEmailView(views.APIView):
    """Public endpoint that confirms an e-mail address from a verification token."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Mark the matching profile as verified and consume its token.

        Responds 400 when the token is missing, not a string, unknown or
        expired.
        """
        from django.utils import timezone

        token = request.data.get('token')
        # A JSON null or number must be rejected here; calling .strip() on it
        # would raise and surface as a 500.
        if not isinstance(token, str) or not token.strip():
            return Response({'detail': 'Token required.'}, status=400)

        # The stored value is compared against the SHA-256 hex digest of the token.
        token_hash = hashlib.sha256(token.strip().encode()).hexdigest()
        profile = Profile.objects.filter(
            email_verify_token=token_hash,
            email_verify_token_expires__gt=timezone.now(),
        ).first()
        if not profile:
            return Response({'detail': 'Invalid or expired token.'}, status=400)

        profile.email_verified = True
        # Clear the token so it cannot be used a second time.
        profile.email_verify_token = ''
        profile.email_verify_token_expires = None
        profile.save(update_fields=['email_verified', 'email_verify_token', 'email_verify_token_expires'])
        return Response({'detail': 'Email verified.'})
|
||
|
||
|
||
class PasswordResetRequestView(views.APIView):
    """Public endpoint that e-mails a password reset link."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Issue a 15-minute reset token for the account with the given e-mail.

        The response is the same whether or not an account matched, so the
        endpoint does not reveal which addresses are registered.
        """
        from django.utils import timezone
        from datetime import timedelta
        from .email import send_email

        raw_email = request.data.get('email')
        # Non-string input (e.g. JSON null) is treated like a missing address.
        email = raw_email.strip().lower() if isinstance(raw_email, str) else ''

        user = None
        if email:
            # An empty address is never looked up: it could match accounts
            # stored without an e-mail. The match is case-insensitive so
            # addresses stored with capitals are found as well.
            user = get_user_model().objects.filter(email__iexact=email).first()

        if user:
            token = secrets.token_urlsafe(32)
            # Only the SHA-256 digest is stored; the raw token travels in the mail.
            token_hash = hashlib.sha256(token.encode()).hexdigest()
            profile, _ = Profile.objects.get_or_create(user=user)
            profile.password_reset_token_hash = token_hash
            profile.password_reset_token_expires = timezone.now() + timedelta(minutes=15)
            profile.save(update_fields=['password_reset_token_hash', 'password_reset_token_expires'])
            link = f"{settings.FRONTEND_URL}/reset-password?token={token}"
            send_email('password_reset', {'link': link}, 'Armarium – Passwort zurücksetzen', user.email)
        return Response({'detail': 'Wenn ein Konto mit dieser E-Mail existiert, wurde eine E-Mail gesendet.'})
|
||
|
||
|
||
class PasswordResetConfirmView(views.APIView):
    """Public endpoint that sets a new password using an e-mailed reset token."""

    permission_classes = [AllowAny]
    throttle_classes = [AuthThrottle]

    def post(self, request):
        """Validate the token, store the new password and end existing sessions.

        Responds 400 when the token is missing, unknown or expired, or when
        the password is missing or shorter than 8 characters.
        """
        from django.utils import timezone

        token = request.data.get('token')
        password = request.data.get('password')
        # Non-string values (JSON null, numbers) are rejected up front; they
        # would otherwise raise in .strip() / len() and surface as a 500.
        if not isinstance(token, str) or not token.strip():
            return Response({'detail': 'Token required.'}, status=400)
        if not isinstance(password, str) or len(password) < 8:
            return Response({'detail': 'Password must be at least 8 characters.'}, status=400)

        token_hash = hashlib.sha256(token.strip().encode()).hexdigest()
        profile = Profile.objects.filter(
            password_reset_token_hash=token_hash,
            password_reset_token_expires__gt=timezone.now(),
        ).first()
        if not profile:
            return Response({'detail': 'Invalid or expired token.'}, status=400)

        user = profile.user
        # Change the password and consume the token together, so a failure
        # cannot leave a new password in place with a still-valid token.
        with db_transaction.atomic():
            user.set_password(password)
            user.save()
            profile.password_reset_token_hash = ''
            profile.password_reset_token_expires = None
            profile.save(update_fields=['password_reset_token_hash', 'password_reset_token_expires'])

        # Every recorded session of the user is passed to the blacklist helper.
        for session in UserSession.objects.filter(user=user):
            _blacklist_session(session)
        return Response({'detail': 'Password updated.'})
|
||
|
||
|
||
# ── FinancialYear Helpers ─────────────────────────────────────────────────────
|
||
|
||
def _get_user_financial_year(user, year):
    """Return the FinancialYear for `year` that `user` may access, or None.

    The user's personal year wins. Otherwise the year of a household is
    returned in which the user holds an active membership that starts no
    later than `year` and has no end year or an end year after `year`.
    """
    personal = FinancialYear.objects.filter(user=user, year=year).first()
    if personal:
        return personal

    not_ended = (
        models.Q(effective_until_year__isnull=True)
        | models.Q(effective_until_year__gt=year)
    )
    household_ids = (
        HouseholdMembership.objects
        .filter(not_ended, user=user, status='active', effective_from_year__lte=year)
        .values_list('household_id', flat=True)
    )
    return FinancialYear.objects.filter(household_id__in=household_ids, year=year).first()
|
||
|
||
|
||
def _all_user_financial_years(user):
    """Return every FinancialYear the user can reach, newest year first.

    Covers the user's personal years plus the years of all households in
    which the user has an active membership. The membership's
    effective_from_year / effective_until_year bounds are not applied here.
    """
    active_household_ids = (
        HouseholdMembership.objects
        .filter(user=user, status='active')
        .values_list('household_id', flat=True)
    )
    reachable = models.Q(user=user) | models.Q(household_id__in=active_household_ids)
    return FinancialYear.objects.filter(reachable).distinct().order_by('-year')
|
||
|
||
|
||
def _max_year_for_user(user):
    """Return the highest year the user currently has access to, or None if there is none."""
    # Sort descending and take the first row, so the database returns a single
    # value instead of every year being loaded just to compute max().
    return (
        _all_user_financial_years(user)
        .order_by('-year')
        .values_list('year', flat=True)
        .first()
    )
|
||
|
||
|
||
# ── FinancialYear Views ───────────────────────────────────────────────────────
|
||
|
||
class FinancialYearListCreateView(views.APIView):
    """List the caller's financial years (GET) or open the next one (POST)."""

    def get(self, request):
        """Return all personal and household financial years of the caller."""
        qs = _all_user_financial_years(request.user)
        return Response(FinancialYearSerializer(qs, many=True).data)

    def post(self, request):
        """Create a financial year and make it the owner's only active one.

        Body: `year` (required integer) and optional `household_id`. With a
        household id the year belongs to that household, otherwise to the
        caller. The year may be at most next calendar year and, when the
        caller already has years, must directly follow the highest one.
        """
        year = request.data.get('year')
        if not year:
            return Response({'year': 'This field is required.'}, status=400)
        try:
            year = int(year)
        except (TypeError, ValueError):
            return Response({'year': 'Must be an integer.'}, status=400)

        current_year = datetime.date.today().year
        if year > current_year + 1:
            return Response(
                {'year': f'You can only create years up to {current_year + 1}.'},
                status=400,
            )

        max_year = _max_year_for_user(request.user)
        if max_year is not None and year != max_year + 1:
            return Response(
                {'year': f'You can only create the next year ({max_year + 1}).'},
                status=400,
            )

        household_id = request.data.get('household_id')
        if household_id:
            # The caller must be an active member of the target household.
            household = Household.objects.filter(
                id=household_id,
                memberships__user=request.user,
                memberships__status='active',
            ).first()
            if not household:
                return Response({'detail': 'Household not found or not a member.'}, status=404)
            owner = {'household': household}
            duplicate_message = 'This year already exists for this household.'
        else:
            owner = {'user': request.user}
            duplicate_message = 'This year already exists.'

        # Deactivating the previous year and creating the new one happen in one
        # transaction, so a failed insert cannot leave the owner with no
        # active year.
        with db_transaction.atomic():
            if FinancialYear.objects.filter(year=year, **owner).exists():
                return Response({'year': duplicate_message}, status=400)
            FinancialYear.objects.filter(is_active=True, **owner).update(is_active=False)
            fy = FinancialYear.objects.create(year=year, is_active=True, **owner)

        return Response(FinancialYearSerializer(fy).data, status=201)
|
||
|
||
|
||
class FinancialYearDetailView(views.APIView):
    """Retrieve, partially update or delete one financial year of the caller."""

    def _get_or_404(self, request, year):
        """Return the caller's FinancialYear for `year`, or None when there is none."""
        return _get_user_financial_year(request.user, year) or None

    @staticmethod
    def _not_found():
        """Build the 404 response shared by all handlers."""
        return Response({'detail': 'Not found.'}, status=404)

    def get(self, request, year):
        """Return the serialized financial year."""
        financial_year = self._get_or_404(request, year)
        if financial_year:
            return Response(FinancialYearSerializer(financial_year).data)
        return self._not_found()

    def patch(self, request, year):
        """Apply a partial update through FinancialYearSerializer."""
        financial_year = self._get_or_404(request, year)
        if not financial_year:
            return self._not_found()
        serializer = FinancialYearSerializer(financial_year, data=request.data, partial=True)
        if not serializer.is_valid():
            return Response(serializer.errors, status=400)
        serializer.save()
        return Response(serializer.data)

    def delete(self, request, year):
        """Delete the year; years that are not active are refused with 400."""
        financial_year = self._get_or_404(request, year)
        if not financial_year:
            return self._not_found()
        if financial_year.is_active:
            financial_year.delete()
            return Response(status=204)
        return Response({'detail': 'Archived years cannot be deleted.'}, status=400)
|
||
|
||
|
||
class FinancialYearCopyView(views.APIView):
    """Copy incomes and budget items from one financial year into another."""

    def post(self, request, year, source_year):
        """Duplicate all incomes and budget items of `source_year` into `year`.

        Both years must be accessible to the caller and the target must be
        active. Copying a year onto itself is rejected, because it would
        duplicate every entry of that year. All rows are created in one
        transaction. The response reports how many rows were copied.
        """
        source = _get_user_financial_year(request.user, source_year)
        if not source:
            return Response({'detail': f'Source year {source_year} not found.'}, status=404)

        target = _get_user_financial_year(request.user, year)
        if not target:
            return Response({'detail': f'Target year {year} not found.'}, status=404)

        if not target.is_active:
            return Response({'detail': 'Target year is archived.'}, status=400)

        # Compare the resolved rows rather than the URL values.
        if source.pk == target.pk:
            return Response({'detail': 'Source and target year must differ.'}, status=400)

        with db_transaction.atomic():
            incomes_copied = 0
            for income in source.incomes.all():
                YearlyIncome.objects.create(
                    financial_year=target,
                    member=income.member,
                    name=income.name,
                    amount=income.amount,
                    active=income.active,
                    notes=income.notes,
                )
                incomes_copied += 1

            items_copied = 0
            for item in source.budget_items.all():
                YearlyBudgetItem.objects.create(
                    financial_year=target,
                    name=item.name,
                    amount=item.amount,
                    active=item.active,
                    notes=item.notes,
                )
                items_copied += 1

        return Response({
            'year': year,
            'source_year': source_year,
            'incomes_copied': incomes_copied,
            'budget_items_copied': items_copied,
        })
|
||
|
||
|
||
# ── YearlyIncome Views ────────────────────────────────────────────────────────
|
||
|
||
class YearlyIncomeListCreateView(views.APIView):
    """List the incomes of a financial year (GET) or add one (POST)."""

    def _get_year_or_404(self, request, year):
        """Return the caller's FinancialYear for `year`, or None."""
        return _get_user_financial_year(request.user, year)

    def get(self, request, year):
        """Return every income recorded for the year."""
        financial_year = self._get_year_or_404(request, year)
        if financial_year:
            incomes = financial_year.incomes.all()
            return Response(YearlyIncomeSerializer(incomes, many=True).data)
        return Response({'detail': 'Not found.'}, status=404)

    def post(self, request, year):
        """Create an income owned by the caller; archived years answer 403."""
        financial_year = self._get_year_or_404(request, year)
        if not financial_year:
            return Response({'detail': 'Not found.'}, status=404)
        if not financial_year.is_active:
            return Response({'detail': 'Archived years are read-only.'}, status=403)

        serializer = YearlyIncomeSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=400)
        # Year and member come from the URL and the session, not from the request body.
        serializer.save(financial_year=financial_year, member=request.user)
        return Response(serializer.data, status=201)
|
||
|
||
|
||
class YearlyIncomeDetailView(views.APIView):
    """Update or delete a single income entry of a financial year."""

    def _get_income_or_404(self, request, year, pk):
        """Return (financial_year, income); (None, None) when the year is not accessible.

        The income is None when no entry with this pk belongs to the year.
        """
        fy = _get_user_financial_year(request.user, year)
        if not fy:
            return None, None
        income = fy.incomes.filter(pk=pk).first()
        return fy, income

    def patch(self, request, year, pk):
        """Partially update the income; archived years answer 403."""
        fy, income = self._get_income_or_404(request, year, pk)
        if not income:
            return Response({'detail': 'Not found.'}, status=404)
        # Same read-only rule that creating an income in this year enforces.
        if not fy.is_active:
            return Response({'detail': 'Archived years are read-only.'}, status=403)
        serializer = YearlyIncomeSerializer(income, data=request.data, partial=True)
        if serializer.is_valid():
            serializer.save()
            return Response(serializer.data)
        return Response(serializer.errors, status=400)

    def delete(self, request, year, pk):
        """Delete the income; archived years answer 403."""
        fy, income = self._get_income_or_404(request, year, pk)
        if not income:
            return Response({'detail': 'Not found.'}, status=404)
        if not fy.is_active:
            return Response({'detail': 'Archived years are read-only.'}, status=403)
        income.delete()
        return Response(status=204)
|
||
|
||
|
||
# ── YearlyBudgetItem Views ────────────────────────────────────────────────────
|
||
|
||
class YearlyBudgetItemListCreateView(views.APIView):
    """List the budget items of a financial year (GET) or add one (POST)."""

    def _get_year_or_404(self, request, year):
        """Return the caller's FinancialYear for `year`, or None."""
        return _get_user_financial_year(request.user, year)

    def get(self, request, year):
        """Return every budget item recorded for the year."""
        fy = self._get_year_or_404(request, year)
        if not fy:
            return Response({'detail': 'Not found.'}, status=404)
        return Response(YearlyBudgetItemSerializer(fy.budget_items.all(), many=True).data)

    def post(self, request, year):
        """Create a budget item in the year; archived years answer 403."""
        fy = self._get_year_or_404(request, year)
        if not fy:
            return Response({'detail': 'Not found.'}, status=404)
        # Archived years are read-only, matching the rule applied when an
        # income is created in a year.
        if not fy.is_active:
            return Response({'detail': 'Archived years are read-only.'}, status=403)
        serializer = YearlyBudgetItemSerializer(data=request.data)
        if serializer.is_valid():
            serializer.save(financial_year=fy)
            return Response(serializer.data, status=201)
        return Response(serializer.errors, status=400)
|
||
|
||
|
||
class YearlyBudgetItemDetailView(views.APIView):
    """Update or delete a single budget item of a financial year."""

    def _get_item_or_404(self, request, year, pk):
        """Return (financial_year, item); (None, None) when the year is not accessible.

        The item is None when no budget item with this pk belongs to the year.
        """
        fy = _get_user_financial_year(request.user, year)
        if not fy:
            return None, None
        item = fy.budget_items.filter(pk=pk).first()
        return fy, item

    def patch(self, request, year, pk):
        """Partially update the budget item; archived years answer 403."""
        fy, item = self._get_item_or_404(request, year, pk)
        if not item:
            return Response({'detail': 'Not found.'}, status=404)
        # Archived years are read-only, matching the rule applied when an
        # income is created in a year.
        if not fy.is_active:
            return Response({'detail': 'Archived years are read-only.'}, status=403)
        serializer = YearlyBudgetItemSerializer(item, data=request.data, partial=True)
        if serializer.is_valid():
            serializer.save()
            return Response(serializer.data)
        return Response(serializer.errors, status=400)

    def delete(self, request, year, pk):
        """Delete the budget item; archived years answer 403."""
        fy, item = self._get_item_or_404(request, year, pk)
        if not item:
            return Response({'detail': 'Not found.'}, status=404)
        if not fy.is_active:
            return Response({'detail': 'Archived years are read-only.'}, status=403)
        item.delete()
        return Response(status=204)
|
||
|
||
|
||
# ── Household Views ───────────────────────────────────────────────────────────
|
||
|
||
class HouseholdListCreateView(views.APIView):
    """List the caller's households (GET) or found a new one (POST)."""

    def get(self, request):
        """Return households the caller founded or belongs to (active or pending)."""
        memberships = HouseholdMembership.objects.filter(
            user=request.user, status__in=['active', 'pending']
        )
        household_ids = memberships.values_list('household_id', flat=True)
        created = Household.objects.filter(created_by=request.user)
        qs = (Household.objects.filter(id__in=household_ids) | created).distinct()
        return Response(HouseholdSerializer(qs, many=True).data)

    def post(self, request):
        """Create a household and enrol the caller as its active admin.

        The creator's membership takes effect from next calendar year.
        """
        serializer = HouseholdSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=400)

        next_year = datetime.date.today().year + 1
        # Household and creator membership are written in one transaction so a
        # failure cannot leave a household without its admin membership.
        with db_transaction.atomic():
            household = serializer.save(created_by=request.user)
            HouseholdMembership.objects.create(
                household=household,
                user=request.user,
                invited_by=request.user,
                status='active',
                role='admin',
                effective_from_year=next_year,
            )
        return Response(HouseholdSerializer(household).data, status=201)
|
||
|
||
|
||
class HouseholdInviteView(views.APIView):
    """Invite a person to a household by e-mail address.

    An address that belongs to an existing user creates a pending membership;
    an unknown address creates a PendingHouseholdInvite and receives a link to
    the registration page. Both take effect from next calendar year.
    """

    @staticmethod
    def _send_invitation(household, inviter_name, invitee_name, to, accept_url, cta_label):
        """Send the 'household_invite' mail with the given link and button label."""
        from .email import send_email

        send_email(
            template_name='household_invite',
            subject=f'Einladung zum Haushalt «{household.name}»',
            context={
                'invitee_name': invitee_name,
                'inviter_name': inviter_name,
                'household_name': household.name,
                'accept_url': accept_url,
                'cta_label': cta_label,
            },
            to=to,
        )

    def post(self, request, pk):
        """Invite the address given in `email` to household `pk`.

        Only the founder or an active admin may invite (403 otherwise).
        Responds 400 for a missing or malformed address, for self-invitation
        and for people who are already invited or members.
        """
        from django.core.exceptions import ValidationError
        from django.core.validators import validate_email
        from .models import PendingHouseholdInvite

        household = Household.objects.filter(pk=pk).first()
        if not household:
            return Response({'detail': 'Not found.'}, status=404)

        # Only founder or active admins can invite
        is_founder = household.created_by == request.user
        is_admin = HouseholdMembership.objects.filter(
            household=household, user=request.user, status='active', role='admin'
        ).exists()
        if not (is_founder or is_admin):
            return Response({'detail': 'Only admins can invite members.'}, status=403)

        raw_email = request.data.get('email')
        # Non-string input (e.g. JSON null) is treated like a missing address.
        email = raw_email.strip().lower() if isinstance(raw_email, str) else ''
        try:
            # Rejects empty and malformed addresses before anything is stored
            # or mailed.
            validate_email(email)
        except ValidationError:
            return Response({'detail': 'A valid email address is required.'}, status=400)

        invitee = get_user_model().objects.filter(email__iexact=email).first()
        next_year = datetime.date.today().year + 1
        inviter_name = request.user.get_full_name() or request.user.email

        if not invitee:
            # No account yet: remember the invitation and point to registration.
            if PendingHouseholdInvite.objects.filter(household=household, invited_email__iexact=email).exists():
                return Response({'detail': 'Invitation already sent to this email.'}, status=400)
            PendingHouseholdInvite.objects.create(
                household=household,
                invited_by=request.user,
                invited_email=email,
                effective_from_year=next_year,
            )
            self._send_invitation(
                household,
                inviter_name,
                invitee_name=email,
                to=email,
                accept_url=f"{settings.FRONTEND_URL}/register",
                cta_label='Konto erstellen & beitreten',
            )
            return Response({'detail': f'Registration invitation sent to {email}.'})

        if invitee == request.user:
            return Response({'detail': 'You cannot invite yourself.'}, status=400)
        if HouseholdMembership.objects.filter(
            household=household, user=invitee, status__in=['pending', 'active']
        ).exists():
            return Response({'detail': 'User is already a member or has a pending invitation.'}, status=400)

        HouseholdMembership.objects.create(
            household=household,
            user=invitee,
            invited_by=request.user,
            status='pending',
            effective_from_year=next_year,
        )
        self._send_invitation(
            household,
            inviter_name,
            invitee_name=invitee.get_full_name() or invitee.email,
            to=invitee.email,
            accept_url=f"{settings.FRONTEND_URL}/financial-year",
            cta_label='Einladung annehmen',
        )
        return Response({'detail': f'Invitation sent to {email} for year {next_year}.'})
|
||
|
||
|
||
class HouseholdAcceptView(views.APIView):
    """Accept the caller's pending invitation to a household."""

    def post(self, request, pk):
        """Switch the caller's pending membership of household `pk` to active; 404 if there is none."""
        pending = HouseholdMembership.objects.filter(
            household_id=pk,
            user=request.user,
            status='pending',
        ).first()
        if pending:
            pending.status = 'active'
            pending.save(update_fields=['status'])
            return Response({'detail': 'Invitation accepted.'})
        return Response({'detail': 'No pending invitation found.'}, status=404)
|
||
|
||
|
||
class HouseholdLeaveView(views.APIView):
    """Let the caller leave a household they are an active member of."""

    def post(self, request, pk):
        """Mark the caller's active membership of household `pk` as left.

        The membership's end year is set to next calendar year and the
        response names the current year as the last one.
        """
        current = HouseholdMembership.objects.filter(
            household_id=pk,
            user=request.user,
            status='active',
        ).first()
        if not current:
            return Response({'detail': 'You are not an active member of this household.'}, status=404)

        end_year = datetime.date.today().year + 1
        # NOTE(review): the status changes to 'left' right away while the
        # message announces the end of the current year — confirm which
        # behaviour is intended.
        current.effective_until_year = end_year
        current.status = 'left'
        current.save(update_fields=['status', 'effective_until_year'])
        return Response({'detail': f'You will leave this household at the end of {end_year - 1}.'})
|
||
|
||
|
||
class HouseholdSetRoleView(views.APIView):
    """Let the founder of a household change the role of an active member."""

    def post(self, request, pk, membership_id):
        """Set `role` ('member' or 'admin') on an active membership of household `pk`.

        Responds 404 when the caller did not create the household or the
        membership is not an active one of it, and 400 for the caller's own
        membership or an unknown role.
        """
        # Filtering on created_by restricts this endpoint to the founder.
        owned = Household.objects.filter(pk=pk, created_by=request.user).first()
        if not owned:
            return Response({'detail': 'Not found or not owner.'}, status=404)

        target = HouseholdMembership.objects.filter(
            pk=membership_id,
            household=owned,
            status='active',
        ).first()
        if not target:
            return Response({'detail': 'Active membership not found.'}, status=404)
        if target.user == request.user:
            return Response({'detail': 'Cannot change your own role.'}, status=400)

        new_role = request.data.get('role')
        if new_role not in ('member', 'admin'):
            return Response({'detail': 'Role must be "member" or "admin".'}, status=400)

        target.role = new_role
        target.save(update_fields=['role'])
        return Response(HouseholdMembershipSerializer(target).data)
|
||
|
||
|
||
class HouseholdRevenueAccountsView(views.APIView):
    """List the active revenue accounts of all active members of a household."""

    def get(self, request, pk):
        """Return one entry per account; 403 unless the caller is an active member.

        Each entry carries the account id, name, balance (as a string),
        salary_months, the owner's e-mail and an `is_mine` flag.
        """
        active_members = HouseholdMembership.objects.filter(household_id=pk, status='active')
        if not active_members.filter(user=request.user).exists():
            return Response({'detail': 'Not a member of this household.'}, status=403)

        member_ids = active_members.values_list('user_id', flat=True)
        revenue_accounts = Account.objects.filter(
            user_id__in=member_ids,
            account_type='revenue',
            active=True,
        ).select_related('user')

        payload = []
        for account in revenue_accounts:
            payload.append({
                'id': account.id,
                'name': account.name,
                'balance': str(account.balance),
                'salary_months': account.salary_months,
                'owner_email': account.user.email,
                'is_mine': account.user_id == request.user.id,
            })
        return Response(payload)
|