Files
meezi/src/Meezi.API/Middleware/TenantMiddleware.cs
T

155 lines
5.5 KiB
C#
Raw Normal View History

2026-05-27 21:33:48 +03:30
using System.Text.Json;
using Microsoft.EntityFrameworkCore;
using Meezi.Core.Constants;
using Meezi.Core.Enums;
using Meezi.Core.Interfaces;
using Meezi.Infrastructure.Data;
using Meezi.Shared;
namespace Meezi.API.Middleware;
/// <summary>
/// ASP.NET Core middleware that scopes each authenticated request to a cafe (tenant).
/// It copies user, language, role, plan-tier and branch information from the caller's
/// claims into the scoped <see cref="ITenantContext"/> / <see cref="IBranchContext"/>,
/// and short-circuits requests that cannot be scoped.
/// </summary>
/// <remarks>
/// Short-circuit responses are JSON <c>ApiResponse</c> envelopes:
/// <list type="bullet">
/// <item>401 <c>UNAUTHORIZED</c> — caller not authenticated, or no cafeId claim.</item>
/// <item>403 <c>FORBIDDEN</c> — consumer principal outside <c>/api/customers/me</c>.</item>
/// <item>403 <c>CAFE_SUSPENDED</c> — the claimed cafe is flagged as suspended.</item>
/// </list>
/// Paths listed in <c>PublicPrefixes</c> skip all of the above. Prefixes are matched on
/// path-segment boundaries, so <c>/api/auth</c> covers <c>/api/auth/login</c> but not
/// <c>/api/authors</c>.
/// </remarks>
public class TenantMiddleware
{
    // Paths that bypass authentication and tenant resolution entirely.
    // Matched with MatchesPathPrefix: an entry must be followed by end-of-path or '/',
    // except an entry that itself ends in '/' (e.g. "/api/q/"), which only covers paths below it.
    private static readonly string[] PublicPrefixes =
    [
        "/api/auth",
        "/api/public",
        "/api/q/",
        "/api/webhooks",
        "/api/billing/verify",
        "/hubs/guest-order",
        "/health",
        "/swagger",
        "/hangfire"
    ];

    // The only path subtree a consumer (customer-account) principal may reach.
    private const string ConsumerAccountPrefix = "/api/customers/me";

    // Shared options for error bodies. A JsonSerializerOptions instance caches type metadata,
    // so allocating a new one per response throws that cache away (CA1869).
    private static readonly JsonSerializerOptions ErrorJsonOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase
    };

    private readonly RequestDelegate _next;
    private readonly ILogger<TenantMiddleware> _logger;

    public TenantMiddleware(RequestDelegate next, ILogger<TenantMiddleware> logger)
    {
        _next = next;
        _logger = logger;
    }

    /// <summary>
    /// Populates the scoped tenant/branch contexts for the current request, or writes a
    /// 401/403 JSON error and stops the pipeline.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="tenant">Scoped tenant context; only written when it is a <c>TenantContext</c>.</param>
    /// <param name="branchContext">Scoped branch context; only written when it is a <c>BranchContext</c>.</param>
    /// <param name="db">Database used to check cafe suspension and branch validity.</param>
    public async Task InvokeAsync(
        HttpContext context,
        ITenantContext tenant,
        IBranchContext branchContext,
        AppDbContext db)
    {
        if (IsPublicPath(context.Request.Path))
        {
            await _next(context);
            return;
        }

        if (context.User.Identity?.IsAuthenticated != true)
        {
            await WriteErrorAsync(context, StatusCodes.Status401Unauthorized, "UNAUTHORIZED", "Authentication required.");
            return;
        }

        var actor = context.User.FindFirst(MeeziClaimTypes.Actor)?.Value;
        var pathValue = context.Request.Path.Value ?? string.Empty;
        if (actor == MeeziActorKinds.Consumer)
        {
            // Consumer principals are not scoped to a cafe here.
            // They may only reach their own account subtree.
            if (MatchesPathPrefix(pathValue, ConsumerAccountPrefix))
            {
                await _next(context);
                return;
            }

            await WriteErrorAsync(context, StatusCodes.Status403Forbidden, "FORBIDDEN", "Consumer access is limited to account endpoints.");
            return;
        }

        // Null when the registered ITenantContext is some other implementation;
        // in that case nothing is written to it.
        var scopedTenant = tenant as TenantContext;
        if (scopedTenant is not null)
        {
            scopedTenant.UserId = context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
                ?? context.User.FindFirst("sub")?.Value;
            scopedTenant.Language = context.User.FindFirst(MeeziClaimTypes.Language)?.Value ?? "fa";
        }

        var cafeId = context.User.FindFirst(MeeziClaimTypes.CafeId)?.Value;
        if (string.IsNullOrEmpty(cafeId))
        {
            _logger.LogWarning("Authenticated request missing cafeId claim for {Path}", context.Request.Path);
            await WriteErrorAsync(context, StatusCodes.Status401Unauthorized, "UNAUTHORIZED", "Cafe context is missing.");
            return;
        }

        // NOTE(review): this only detects an existing, suspended cafe. A cafeId claim whose cafe
        // row is not found also passes this check — confirm that is intended.
        var cafeSuspended = await db.Cafes
            .AsNoTracking()
            .AnyAsync(c => c.Id == cafeId && c.IsSuspended, context.RequestAborted);
        if (cafeSuspended)
        {
            await WriteErrorAsync(context, StatusCodes.Status403Forbidden, "CAFE_SUSPENDED", "This cafe account is suspended. Contact Meezi support.");
            return;
        }

        if (scopedTenant is not null)
        {
            scopedTenant.CafeId = cafeId;

            // A missing or unparseable claim leaves the existing Role / PlanTier value untouched.
            var roleClaim = context.User.FindFirst(MeeziClaimTypes.Role)?.Value;
            if (Enum.TryParse<EmployeeRole>(roleClaim, ignoreCase: true, out var role))
            {
                scopedTenant.Role = role;
            }

            var planClaim = context.User.FindFirst(MeeziClaimTypes.PlanTier)?.Value;
            if (Enum.TryParse<PlanTier>(planClaim, ignoreCase: true, out var plan))
            {
                scopedTenant.PlanTier = plan;
            }

            var branchIdClaim = context.User.FindFirst(MeeziClaimTypes.BranchId)?.Value;
            if (!string.IsNullOrEmpty(branchIdClaim))
            {
                // Only trust a branch that belongs to this cafe and is active.
                // Otherwise the claim is logged and ignored; the request is not rejected.
                var branchValid = await db.Branches.AnyAsync(
                    b => b.Id == branchIdClaim && b.CafeId == cafeId && b.IsActive,
                    context.RequestAborted);
                if (branchValid)
                {
                    scopedTenant.BranchId = branchIdClaim;
                }
                else
                {
                    _logger.LogWarning("Ignoring invalid or inactive branchId claim for cafe {CafeId}", cafeId);
                }
            }
        }

        if (branchContext is BranchContext scopedBranch)
        {
            scopedBranch.CafeId = cafeId;

            // Propagate the branch id only if it was set (validated above) on the tenant context.
            var validatedBranchId = scopedTenant?.BranchId;
            if (!string.IsNullOrEmpty(validatedBranchId))
            {
                scopedBranch.BranchId = validatedBranchId;
            }
        }

        await _next(context);
    }

    /// <summary>Returns true when <paramref name="path"/> falls under one of <c>PublicPrefixes</c>.</summary>
    private static bool IsPublicPath(PathString path)
    {
        var value = path.Value ?? string.Empty;
        return PublicPrefixes.Any(prefix => MatchesPathPrefix(value, prefix));
    }

    /// <summary>
    /// Case-insensitive prefix test that respects path-segment boundaries.
    /// </summary>
    /// <remarks>
    /// <paramref name="prefix"/> must be followed by end-of-path or '/'. The exception is a
    /// prefix that already ends in '/', which supplies its own boundary.
    /// Example: "/api/auth" matches "/api/auth" and "/api/auth/login" but not "/api/authors".
    /// </remarks>
    private static bool MatchesPathPrefix(string path, string prefix)
    {
        if (!path.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
        {
            return false;
        }

        return prefix.EndsWith('/')
            || path.Length == prefix.Length
            || path[prefix.Length] == '/';
    }

    /// <summary>
    /// Writes a failed <c>ApiResponse</c> envelope as camelCase JSON with the given status code.
    /// </summary>
    /// <param name="context">The current HTTP context; its response is written to.</param>
    /// <param name="statusCode">HTTP status code to set (401 or 403 in this middleware).</param>
    /// <param name="code">Machine-readable error code placed in the envelope.</param>
    /// <param name="message">Human-readable error message placed in the envelope.</param>
    private static async Task WriteErrorAsync(HttpContext context, int statusCode, string code, string message)
    {
        context.Response.StatusCode = statusCode;
        context.Response.ContentType = "application/json";
        var payload = new ApiResponse<object>(false, null, new ApiError(code, message));
        await context.Response.WriteAsync(
            JsonSerializer.Serialize(payload, ErrorJsonOptions),
            context.RequestAborted);
    }
}