Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/.env.docker
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ DATABASE_URL_DEDICATED=postgresql://dbusername:dbpassword@postgres:5432/httpsms
# Redis connection string
REDIS_URL=redis://@redis:6379

# Rate limiting (set to "true" to enable per-user API rate tracking)
RATE_LIMIT_ENABLED=false

# Google Cloud Storage bucket for MMS attachments. Leave empty to use in-memory storage.
GCS_BUCKET_NAME=

Expand Down
60 changes: 53 additions & 7 deletions api/pkg/di/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ type Container struct {
userRistrettoCache *ristretto.Cache[string, entities.AuthContext]
phoneRistrettoCache *ristretto.Cache[string, *entities.Phone]
inMemoryCache cache.Cache
rateLimitService *services.RateLimitService
redisClient *redis.Client
}

// NewLiteContainer creates a Container without any routes or listeners
Expand All @@ -105,6 +107,13 @@ func NewLiteContainer() (container *Container) {
}
}

// Close gracefully shuts down container resources
func (container *Container) Close() {
if container.rateLimitService != nil {
container.rateLimitService.Close()
}
}

// NewContainer creates a new dependency injection container
func NewContainer(projectID string, version string) (container *Container) {
container = &Container{
Expand Down Expand Up @@ -204,6 +213,20 @@ func (container *Container) App() (app *fiber.App) {
app.Use(middlewares.BearerAuth(container.Logger(), container.Tracer(), container.FirebaseAuthClient()))
app.Use(middlewares.APIKeyAuth(container.Logger(), container.Tracer(), container.UserRepository()))

if os.Getenv("RATE_LIMIT_ENABLED") == "true" {
app.Use(middlewares.RateLimit(
container.Tracer(),
container.Logger(),
container.RateLimitService(),
container.UserRepository(),
[]string{"/v1/events"},
))
app.Hooks().OnPreShutdown(func() error {
container.RateLimitService().Close()
return nil
})
}

container.app = app
return app
}
Expand Down Expand Up @@ -443,6 +466,16 @@ func (container *Container) InMemoryCache() cache.Cache {
// Cache creates a new instance of cache.Cache
func (container *Container) Cache() cache.Cache {
container.logger.Debug("creating cache.Cache")
return cache.NewRedisCache(container.Tracer(), container.RedisClient())
}

// RedisClient creates or returns the shared *redis.Client
func (container *Container) RedisClient() *redis.Client {
if container.redisClient != nil {
return container.redisClient
}

container.logger.Debug("creating *redis.Client")
opt, err := redis.ParseURL(os.Getenv("REDIS_URL"))
if err != nil {
container.logger.Fatal(stacktrace.Propagate(err, fmt.Sprintf("cannot parse redis url [%s]", os.Getenv("REDIS_URL"))))
Expand All @@ -453,19 +486,32 @@ func (container *Container) Cache() cache.Cache {
}
}

redisClient := redis.NewClient(opt)
container.redisClient = redis.NewClient(opt)

// Enable tracing instrumentation.
if err = redisotel.InstrumentTracing(redisClient); err != nil {
if err = redisotel.InstrumentTracing(container.redisClient); err != nil {
container.logger.Error(stacktrace.Propagate(err, "cannot instrument redis tracing"))
}

// Enable metrics instrumentation.
if err = redisotel.InstrumentMetrics(redisClient); err != nil {
if err = redisotel.InstrumentMetrics(container.redisClient); err != nil {
container.logger.Fatal(stacktrace.Propagate(err, "cannot instrument redis metrics"))
}

return cache.NewRedisCache(container.Tracer(), redisClient)
return container.redisClient
}

// RateLimitService creates or returns the shared *services.RateLimitService
func (container *Container) RateLimitService() *services.RateLimitService {
if container.rateLimitService != nil {
return container.rateLimitService
}

container.logger.Debug("creating services.RateLimitService")
container.rateLimitService = services.NewRateLimitService(
container.Tracer(),
container.Logger(),
container.RedisClient(),
container.EventDispatcher(),
)
return container.rateLimitService
}

// FirebaseAuthClient creates a new instance of auth.Client
Expand Down
5 changes: 5 additions & 0 deletions api/pkg/entities/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func (subscription SubscriptionName) Limit() uint {
}
}

// RateLimit returns the daily API request rate limit for a subscription
func (subscription SubscriptionName) RateLimit() uint {
return subscription.Limit() * 2
}

// SubscriptionNameFree represents a free subscription
const SubscriptionNameFree = SubscriptionName("free")

Expand Down
16 changes: 16 additions & 0 deletions api/pkg/entities/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,19 @@ func TestUser_GetBillingAnchorDay_PaidUserDay31(t *testing.T) {
}
assert.Equal(t, 31, user.GetBillingAnchorDay())
}

func TestSubscriptionName_RateLimit_Free(t *testing.T) {
assert.Equal(t, uint(400), SubscriptionNameFree.RateLimit())
}

func TestSubscriptionName_RateLimit_Pro(t *testing.T) {
assert.Equal(t, uint(10000), SubscriptionNameProMonthly.RateLimit())
}

func TestSubscriptionName_RateLimit_Ultra(t *testing.T) {
assert.Equal(t, uint(20000), SubscriptionNameUltraMonthly.RateLimit())
}

func TestSubscriptionName_RateLimit_200K(t *testing.T) {
assert.Equal(t, uint(400000), SubscriptionName200KMonthly.RateLimit())
}
19 changes: 19 additions & 0 deletions api/pkg/events/rate_limit_exceeded_event.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package events

import (
"time"

"github.com/NdoleStudio/httpsms/pkg/entities"
)

// RateLimitExceeded is raised when a user exceeds their daily API rate limit.
const RateLimitExceeded = "rate.limit.exceeded"

// RateLimitExceededPayload stores the data for the RateLimitExceeded event
type RateLimitExceededPayload struct {
UserID entities.UserID `json:"user_id"`
Count int64 `json:"count"`
Limit uint `json:"limit"`
Plan string `json:"plan"`
Timestamp time.Time `json:"timestamp"`
}
62 changes: 62 additions & 0 deletions api/pkg/middlewares/rate_limit_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package middlewares

import (
"strconv"
"strings"

"github.com/NdoleStudio/httpsms/pkg/entities"
"github.com/NdoleStudio/httpsms/pkg/repositories"
"github.com/NdoleStudio/httpsms/pkg/services"
"github.com/NdoleStudio/httpsms/pkg/telemetry"
"github.com/gofiber/fiber/v3"
)

const rateLimitCostCap = 100

// RateLimit tracks per-user API request counts without blocking requests.
func RateLimit(
tracer telemetry.Tracer,
logger telemetry.Logger,
service *services.RateLimitService,
userRepository repositories.UserRepository,
excludePaths []string,
) fiber.Handler {
logger = logger.WithService("middlewares.RateLimit")

return func(c fiber.Ctx) error {
path := c.Path()
for _, excluded := range excludePaths {
if strings.HasPrefix(path, excluded) {
return c.Next()
}
}

ctx, span := tracer.StartFromFiberCtx(c, "middlewares.RateLimit")
defer span.End()

authUser, ok := c.Locals(ContextKeyAuthUserID).(entities.AuthContext)
if !ok || authUser.IsNoop() {
return c.Next()
}

cost := 1
if c.Method() == fiber.MethodGet {
if limitParam := c.Query("limit"); limitParam != "" {
if parsed, err := strconv.Atoi(limitParam); err == nil && parsed > 0 {
cost = min(parsed, rateLimitCostCap)
}
}
}

user, err := userRepository.Load(ctx, authUser.ID)
if err != nil {
ctxLogger := tracer.CtxLogger(logger, span)
ctxLogger.Error(err)
return c.Next()
}
Comment on lines +51 to +56

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Unconditional database query on every authenticated request

userRepository.Load is called on every request solely to retrieve user.SubscriptionName. With ~20M requests/month this adds an equivalent number of DB queries — one per API call — just for rate-limit plan lookups. The auth middlewares that run before this one (BearerAuth, APIKeyAuth) already load and validate the user; if the AuthContext stored in c.Locals already carries the subscription plan (or if there is a short-lived cache in the repository), this load could be avoided entirely. If AuthContext doesn't include the plan today, it would be worth adding it there rather than performing a separate load per request.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


_, _, _ = service.Increment(ctx, authUser.ID, user.SubscriptionName, cost)

return c.Next()
}
}
Loading
Loading