From 8caef960346f090f6e14d8b86725f03db0359544 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 9 Jun 2026 23:27:48 +0200 Subject: [PATCH 1/2] Support gateways with multiple replicas MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A gateway can now have multiple replicas for improved availability. ```yaml type: gateway name: example-gateway backend: aws region: eu-west-1 domain: example.com certificate: null replicas: 2 ``` To balance requests between gateway replicas, add DNS records for each replica or set up a load balancer outside of `dstack`. Replica hostnames are displayed in `dstack` CLI and UI. ```shell $ dstack gateway list NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS example-gateway example.com ✓ running replica=0 aws (eu-west-1) 34.244.128.46 replica=1 aws (eu-west-1) 18.201.201.174 ``` Limitations: - Changing the number of replicas or redeploying replicas is not supported. - HTTPS is not supported. Use an external load balancer for TLS termination. - An unavailable gateway replica prevents any new services or service replicas from being added. - All replicas are bound to the same backend and region. Implementation notes: - `GatewayComputeModel` now represents a gateway replica. - In this version, the terms "compute" and "replica" are used interchangeably. The plan is to switch to using exclusively "replica" later. - In this version, replica provisioning and termination are still done in the gateway pipeline, for all replicas at once. The plan is to introduce gateway replica pipelines later to allow for independent replica processing. 
--- frontend/src/locale/en.json | 1 + .../Table/hooks/useColumnsDefinitions.tsx | 16 +- frontend/src/types/gateway.d.ts | 8 + mkdocs/docs/concepts/gateways.md | 43 +++ .../cli/services/configurators/gateway.py | 4 + src/dstack/_internal/cli/utils/gateway.py | 32 +- .../_internal/core/compatibility/gateways.py | 2 + src/dstack/_internal/core/models/gateways.py | 30 +- .../background/pipeline_tasks/gateways.py | 212 +++++++---- .../background/pipeline_tasks/jobs_running.py | 63 ++-- .../pipeline_tasks/jobs_terminating.py | 47 ++- .../pipeline_tasks/runs/__init__.py | 12 +- .../pipeline_tasks/runs/terminating.py | 43 ++- .../server/compatibility/gateways.py | 15 + ...ea4d_add_gatewaycomputemodel_gateway_id.py | 52 +++ src/dstack/_internal/server/models.py | 43 ++- .../_internal/server/routers/gateways.py | 53 +-- .../server/services/gateways/__init__.py | 161 ++++++-- .../server/services/services/__init__.py | 92 ++--- src/dstack/_internal/server/testing/common.py | 4 +- .../pipeline_tasks/test_gateways.py | 351 +++++++++++++++++- .../pipeline_tasks/test_running_jobs.py | 20 +- .../server/compatibility/__init__.py | 0 .../server/compatibility/test_gateways.py | 88 +++++ .../_internal/server/routers/test_gateways.py | 341 +++++++++++++---- .../_internal/server/routers/test_runs.py | 33 +- .../server/services/gateways/__init__.py | 0 .../server/services/gateways/test_gateways.py | 88 +++++ 28 files changed, 1468 insertions(+), 386 deletions(-) create mode 100644 src/dstack/_internal/server/compatibility/gateways.py create mode 100644 src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py create mode 100644 src/tests/_internal/server/compatibility/__init__.py create mode 100644 src/tests/_internal/server/compatibility/test_gateways.py create mode 100644 src/tests/_internal/server/services/gateways/__init__.py create mode 100644 src/tests/_internal/server/services/gateways/test_gateways.py diff --git 
a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 7d8c9f2465..804134f3d1 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -144,6 +144,7 @@ "region_description": "Select a region", "default": "Default", "default_checkbox": "Turn on default", + "hostname": "Hostname", "external_ip": "External IP", "wildcard_domain": "Wildcard domain", "wildcard_domain_description": "Specify the wildcard domain mapped to the external IP.", diff --git a/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx b/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx index 79b4aac22c..f63a77fa14 100644 --- a/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx +++ b/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx @@ -30,13 +30,15 @@ export const useColumnsDefinitions = ({ loading, projectName, onDeleteClick, onE { id: 'type', header: t('gateway.edit.backend'), - cell: (gateway: IGateway) => gateway.backend, + cell: (gateway: IGateway) => + gateway.replicas.length > 0 ? gateway.replicas.map((r, i) =>
{r.backend}
) : null, }, { id: 'region', header: t('gateway.edit.region'), - cell: (gateway: IGateway) => gateway.region, + cell: (gateway: IGateway) => + gateway.replicas.length > 0 ? gateway.replicas.map((r, i) =>
{r.region}
) : null, }, { @@ -46,9 +48,13 @@ export const useColumnsDefinitions = ({ loading, projectName, onDeleteClick, onE }, { - id: 'external_ip', - header: t('gateway.edit.external_ip'), - cell: (gateway: IGateway) => gateway.ip_address, + id: 'hostname', + header: t('gateway.edit.hostname'), + cell: (gateway: IGateway) => { + if (gateway.hostname) return gateway.hostname; + if (gateway.replicas.length > 0) return gateway.replicas.map((r, i) =>
{r.hostname}
); + return null; + }, }, { diff --git a/frontend/src/types/gateway.d.ts b/frontend/src/types/gateway.d.ts index 4ef2eeeb54..1442cf4d62 100644 --- a/frontend/src/types/gateway.d.ts +++ b/frontend/src/types/gateway.d.ts @@ -1,3 +1,9 @@ +declare interface IGatewayReplica { + hostname: string, + backend: string, + region: string, +} + declare interface IGateway { backend: string, name: string, @@ -5,8 +11,10 @@ declare interface IGateway { ip_address: string, instance_id: string, region:string + hostname?: string, wildcard_domain?: string default: boolean + replicas: IGatewayReplica[], created_at?: number, } diff --git a/mkdocs/docs/concepts/gateways.md b/mkdocs/docs/concepts/gateways.md index bd71187964..26374a6751 100644 --- a/mkdocs/docs/concepts/gateways.md +++ b/mkdocs/docs/concepts/gateways.md @@ -182,6 +182,49 @@ domain: example.com +### Replicas + +A gateway can have multiple replicas for improved availability. + +
+ +```yaml +type: gateway +name: example-gateway + +backend: aws +region: eu-west-1 + +domain: example.com + +certificate: null +replicas: 2 +``` + +
+ +To balance requests between gateway replicas, add DNS records for each replica or set up a load balancer outside of `dstack`. Replica hostnames are displayed in the `dstack` CLI and UI. + +
+ +```shell +$ dstack gateway list + NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS + example-gateway example.com ✓ running + replica=0 aws (eu-west-1) 34.244.128.46 + replica=1 aws (eu-west-1) 18.201.201.174 +``` + +
+ +!!! warning "Experimental" + Replicated gateways are an experimental feature and currently have limitations: + + - Changing the number of replicas or redeploying replicas is not supported. + - HTTPS is not supported. Use an external load balancer for TLS termination. + - An unavailable gateway replica prevents any new services or service replicas from being added. + - All replicas are bound to the same backend and region. + !!! info "Reference" For all gateway configuration options, refer to the [reference](../reference/dstack.yml/gateway.md). diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py index 0b6993e18b..9f8e6cd0d1 100644 --- a/src/dstack/_internal/cli/services/configurators/gateway.py +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -236,6 +236,10 @@ def th(s: str) -> str: configuration_table.add_row(th("Region"), plan.spec.configuration.region) configuration_table.add_row(th("Domain"), domain) + if plan.spec.configuration.replicas is not None: + assert isinstance(plan.spec.configuration.replicas, int) + configuration_table.add_row(th("Replicas"), str(plan.spec.configuration.replicas)) + console.print(configuration_table) console.print() diff --git a/src/dstack/_internal/cli/utils/gateway.py b/src/dstack/_internal/cli/utils/gateway.py index 5605326211..0d873a9a5d 100644 --- a/src/dstack/_internal/cli/utils/gateway.py +++ b/src/dstack/_internal/cli/utils/gateway.py @@ -95,15 +95,39 @@ def get_gateways_table( # Ignore errors in case future server versions introduce more interpolation variables exception_type=None, ) - row = { + + gateway_row = { "NAME": name, - "BACKEND": format_backend(gateway.configuration.backend, gateway.configuration.region), - "HOSTNAME": gateway.hostname, "DOMAIN": domain, "DEFAULT": "✓" if gateway.default else "", "STATUS": gateway.status, "CREATED": format_date(gateway.created_at), "ERROR": gateway.status_message, } - 
add_row_from_dict(table, row) + if gateway.hostname is not None: + gateway_row["HOSTNAME"] = gateway.hostname + if len(gateway.replicas) == 0: + # replicas not yet created, or it's a pre-0.20.25 server without replica support + gateway_row["BACKEND"] = format_backend( + gateway.configuration.backend, gateway.configuration.region + ) + gateway_row["HOSTNAME"] = gateway_row.get("HOSTNAME", gateway.ip_address) + if len(gateway.replicas) == 1: + # compact display for single-replica gateway + gateway_row["BACKEND"] = format_backend( + gateway.replicas[0].backend, gateway.replicas[0].region + ) + gateway_row["HOSTNAME"] = gateway_row.get("HOSTNAME", gateway.replicas[0].hostname) + add_row_from_dict(table, gateway_row) + + if len(gateway.replicas) > 1: + for replica in gateway.replicas: + replica_row = { + "NAME": f" replica={replica.replica_num}", + "BACKEND": format_backend(replica.backend, replica.region), + "HOSTNAME": replica.hostname, + "CREATED": format_date(replica.created_at), + } + add_row_from_dict(table, replica_row, style="secondary") + return table diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index a2fc6101e6..0a89e86113 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -41,5 +41,7 @@ def _get_gateway_configuration_excludes( if configuration.router is None: configuration_excludes["router"] = True + if configuration.replicas is None: + configuration_excludes["replicas"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 6f92b449b7..74b3f4e835 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -11,6 +11,8 @@ from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.utils.tags import tags_validator +GATEWAY_REPLICAS_DEFAULT 
= 1 + class GatewayStatus(str, Enum): SUBMITTED = "submitted" @@ -90,6 +92,13 @@ class GatewayConfiguration(CoreModel): " Set to `null` to disable. Defaults to `type: lets-encrypt`" ), ] = LetsEncryptGatewayCertificate() + replicas: Annotated[ + Optional[int], + Field( + description=f"The number of gateway replicas. Defaults to `{GATEWAY_REPLICAS_DEFAULT}`", + ge=1, + ), + ] = None tags: Annotated[ Optional[Dict[str, str]], Field( @@ -109,6 +118,14 @@ class GatewaySpec(CoreModel): configuration_path: Optional[str] = None +class GatewayReplica(CoreModel): + hostname: str + replica_num: int + backend: BackendType + region: str + created_at: datetime.datetime + + class Gateway(CoreModel): # TODO(0.21): Make `id` required. id: Optional[uuid.UUID] = None @@ -121,14 +138,13 @@ class Gateway(CoreModel): status: GatewayStatus status_message: Optional[str] hostname: Optional[str] - """`hostname` is the IP address or hostname the user should set up the domain for. - Could be the same as `ip_address` but also different, for example a gateway behind ALB. + """Hostname of the load balancer. + Unset if there is no load balancer, in which case users are expected to point the gateway's + wildcard domain name to `replicas[i].hostname`. """ - ip_address: Optional[str] - """`ip_address` is the IP address of the gateway instance.""" - instance_id: Optional[str] wildcard_domain: Optional[str] default: bool + replicas: list[GatewayReplica] = [] backend: Optional[BackendType] = None """`backend` duplicates a configuration field on the top level for backward compatibility with 0.19.x clients that expect it to be required. @@ -139,6 +155,10 @@ class Gateway(CoreModel): with 0.19.x clients that expect it to be required. Remove after 0.21. 
""" + ip_address: Optional[str] = None + """Deprecated in favor of `replicas[i].hostname`, only set for pre-0.20.25 clients.""" + instance_id: Optional[str] = None + """Deprecated, unused, kept for pre-0.20.25 clients.""" class GatewayPlan(CoreModel): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 1f8f0c64f5..05393eb1ae 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -5,12 +5,12 @@ from typing import Optional, Sequence, TypedDict from sqlalchemy import delete, or_, select, update -from sqlalchemy.orm import joinedload, load_only +from sqlalchemy.orm import joinedload, load_only, selectinload from dstack._internal.core.backends.base.compute import ComputeWithGatewaySupport from dstack._internal.core.errors import BackendError, BackendNotAvailable from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.gateways import GatewayStatus +from dstack._internal.core.models.gateways import GATEWAY_REPLICAS_DEFAULT, GatewayStatus from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, @@ -34,7 +34,10 @@ from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import events from dstack._internal.server.services import gateways as gateways_services -from dstack._internal.server.services.gateways import emit_gateway_status_change_event +from dstack._internal.server.services.gateways import ( + emit_gateway_status_change_event, + get_gateway_compute_models, +) from dstack._internal.server.services.gateways.pool import gateway_connections_pool from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt @@ -239,11 +242,8 @@ async def _process_submitted_item(item: 
GatewayPipelineItem): set_processed_update_map_fields(update_map) set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: - gateway_compute_model = result.gateway_compute_model - if gateway_compute_model is not None: + for gateway_compute_model in result.gateway_compute_models: session.add(gateway_compute_model) - await session.flush() - update_map["gateway_compute_id"] = gateway_compute_model.id now = get_current_datetime() resolve_now_placeholders(update_map, now=now) res = await session.execute( @@ -258,7 +258,7 @@ async def _process_submitted_item(item: GatewayPipelineItem): updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: log_lock_token_changed_after_processing(logger, item) - # TODO: Clean up gateway_compute_model. + # TODO: Clean up gateway_compute_models. return emit_gateway_status_change_event( session=session, @@ -272,7 +272,6 @@ async def _process_submitted_item(item: GatewayPipelineItem): class _GatewayUpdateMap(ItemUpdateMap, total=False): status: GatewayStatus status_message: str - gateway_compute_id: uuid.UUID class _GatewayComputeUpdateMap(TypedDict, total=False): @@ -283,7 +282,7 @@ class _GatewayComputeUpdateMap(TypedDict, total=False): @dataclass class _SubmittedResult: update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) - gateway_compute_model: Optional[GatewayComputeModel] = None + gateway_compute_models: list[GatewayComputeModel] = field(default_factory=list) async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedResult: @@ -303,16 +302,28 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR "status_message": "Backend not available", } ) + replicas = ( + configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT + ) + gateway_compute_models = [] try: - gateway_compute_model = await gateways_services.create_gateway_compute( - backend_compute=backend.compute(), - 
project_name=gateway_model.project.name, - configuration=configuration, - backend_id=backend_model.id, - ) + for replica_num in range(replicas): + logger.debug( + "%s replica %d: creating gateway compute", fmt(gateway_model), replica_num + ) + gateway_compute_model = await gateways_services.create_gateway_compute( + backend_compute=backend.compute(), + project_name=gateway_model.project.name, + configuration=configuration, + replica_num=replica_num, + gateway_id=gateway_model.id, + backend_id=backend_model.id, + ) + logger.info("%s replica %d: gateway compute created", fmt(gateway_model), replica_num) + gateway_compute_models.append(gateway_compute_model) return _SubmittedResult( update_map={"status": GatewayStatus.PROVISIONING}, - gateway_compute_model=gateway_compute_model, + gateway_compute_models=gateway_compute_models, ) except BackendError as e: status_message = f"Backend error: {repr(e)}" @@ -322,7 +333,8 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR update_map={ "status": GatewayStatus.FAILED, "status_message": status_message, - } + }, + gateway_compute_models=gateway_compute_models, ) except Exception as e: logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) @@ -330,7 +342,8 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR update_map={ "status": GatewayStatus.FAILED, "status_message": f"Unexpected error: {repr(e)}", - } + }, + gateway_compute_models=gateway_compute_models, ) @@ -343,6 +356,7 @@ async def _process_provisioning_item(item: GatewayPipelineItem): GatewayModel.lock_token == item.lock_token, ) .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: @@ -377,34 +391,39 @@ async def _process_provisioning_item(item: GatewayPipelineItem): new_status=gateway_update_map.get("status", gateway_model.status), 
status_message=gateway_update_map.get("status_message", gateway_model.status_message), ) - if result.gateway_compute_update_map: + if result.all_computes_update_map: res = await session.execute( update(GatewayComputeModel) - .where(GatewayComputeModel.id == gateway_model.gateway_compute_id) - .values(**result.gateway_compute_update_map) + .where( + or_( + GatewayComputeModel.gateway_id == gateway_model.id, + GatewayComputeModel.id == gateway_model.gateway_compute_id, + ) + ) + .values(**result.all_computes_update_map) .returning(GatewayComputeModel.id) ) updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: + if len(updated_ids) < len(get_gateway_compute_models(gateway_model)): logger.error( - "Failed to update compute model %s for gateway %s." + "Failed to update compute models for gateway %s." " This is unexpected and may happen only if the compute model was manually deleted.", gateway_model.id, - item.id, ) @dataclass class _ProvisioningResult: gateway_update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) - gateway_compute_update_map: _GatewayComputeUpdateMap = field( + all_computes_update_map: _GatewayComputeUpdateMap = field( default_factory=_GatewayComputeUpdateMap ) async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult: + gateway_computes = get_gateway_compute_models(gateway_model) # Provisioning gateways must have compute. - assert gateway_model.gateway_compute is not None + assert len(gateway_computes) > 0 # FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway: # - cannot delete the gateway before it is provisioned because the DB model is locked @@ -413,32 +432,58 @@ async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _Provisi # Easy to fix by doing only one connection/configuration attempt per processing iteration. 
The # main challenge is applying the same provisioning model to the dstack Sky gateway to avoid # maintaining a different model for Sky. - connection = await gateways_services.connect_to_gateway_with_retry( - gateway_model.gateway_compute + + errors = await asyncio.gather( + *(_connect_and_configure_gateway_replica(gateway_model, gc) for gc in gateway_computes) ) - if connection is None: + if any(errors): return _ProvisioningResult( gateway_update_map={ "status": GatewayStatus.FAILED, - "status_message": "Failed to connect to gateway", + "status_message": next(e for e in errors if e), }, - gateway_compute_update_map={"active": False}, + all_computes_update_map={"active": False}, + ) + + return _ProvisioningResult( + gateway_update_map={"status": GatewayStatus.RUNNING}, + ) + + +async def _connect_and_configure_gateway_replica( + gateway_model: GatewayModel, + gateway_compute: GatewayComputeModel, +) -> Optional[str]: + """Returns an error message on failure, None on success.""" + logger.debug( + "%s replica %d: connecting to gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, + ) + connection = await gateways_services.connect_to_gateway_with_retry(gateway_compute) + if connection is None: + logger.warning( + "%s replica %d: failed to connect to gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, ) + return "Failed to connect to gateway" try: await gateways_services.configure_gateway(connection) except Exception: - logger.exception("%s: failed to configure gateway", fmt(gateway_model)) - await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) - return _ProvisioningResult( - gateway_update_map={ - "status": GatewayStatus.FAILED, - "status_message": "Failed to configure gateway", - }, - gateway_compute_update_map={"active": False}, + logger.exception( + "%s replica %d: failed to configure gateway", + fmt(gateway_model), + gateway_compute.replica_num, ) - return _ProvisioningResult( - 
gateway_update_map={"status": GatewayStatus.RUNNING}, + await gateway_connections_pool.remove(gateway_compute.ip_address) + return "Failed to configure gateway" + logger.info( + "%s replica %d: gateway compute connected and configured", + fmt(gateway_model), + gateway_compute.replica_num, ) + return None async def _process_to_be_deleted_item(item: GatewayPipelineItem): @@ -451,6 +496,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) .options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends)) .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) ) gateway_model = res.unique().scalar_one_or_none() @@ -460,6 +506,27 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): result = await _process_to_be_deleted_gateway(gateway_model) async with get_session_ctx() as session: + if result.all_computes_update_map: + res = await session.execute( + update(GatewayComputeModel) + .where( + or_( + GatewayComputeModel.gateway_id == gateway_model.id, + GatewayComputeModel.id == gateway_model.gateway_compute_id, + ) + ) + .values(**result.all_computes_update_map) + .returning(GatewayComputeModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) < len(get_gateway_compute_models(gateway_model)): + logger.error( + "Failed to update compute models for gateway %s." 
+ " This is unexpected and may happen only if the compute model was manually deleted.", + gateway_model.id, + ) + return + if result.delete_gateway: res = await session.execute( delete(GatewayModel) @@ -503,28 +570,11 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): log_lock_token_changed_after_processing(logger, item) return - if result.gateway_compute_update_map: - res = await session.execute( - update(GatewayComputeModel) - .where(GatewayComputeModel.id == gateway_model.gateway_compute_id) - .values(**result.gateway_compute_update_map) - .returning(GatewayComputeModel.id) - ) - updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: - logger.error( - "Failed to update compute model %s for gateway %s." - " This is unexpected and may happen only if the compute model was manually deleted.", - gateway_model.id, - item.id, - ) - return - @dataclass class _ProcessToBeDeletedResult: delete_gateway: bool - gateway_compute_update_map: _GatewayComputeUpdateMap = field( + all_computes_update_map: _GatewayComputeUpdateMap = field( default_factory=_GatewayComputeUpdateMap ) @@ -536,27 +586,39 @@ async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _Proces ) compute = backend.compute() assert isinstance(compute, ComputeWithGatewaySupport) - gateway_compute_configuration = gateways_services.get_gateway_compute_configuration( - gateway_model - ) - if gateway_model.gateway_compute is not None and gateway_compute_configuration is not None: - logger.info("Deleting gateway compute for %s...", gateway_model.name) + + for gateway_compute in get_gateway_compute_models(gateway_model): + gateway_compute_configuration = gateways_services.get_gateway_compute_configuration( + gateway_compute=gateway_compute, + gateway_model=gateway_model, + ) + logger.debug( + "%s replica %d: terminating gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, + ) try: await run_async( compute.terminate_gateway, - 
gateway_model.gateway_compute.instance_id, + gateway_compute.instance_id, gateway_compute_configuration, - gateway_model.gateway_compute.backend_data, + gateway_compute.backend_data, ) except Exception: logger.exception( - "Error when deleting gateway compute for %s", - gateway_model.name, + "%s replica %d: error when terminating gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, ) return _ProcessToBeDeletedResult(delete_gateway=False) - logger.info("Deleted gateway compute for %s", gateway_model.name) - result = _ProcessToBeDeletedResult(delete_gateway=True) - if gateway_model.gateway_compute is not None: - await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) - result.gateway_compute_update_map = {"active": False, "deleted": True} - return result + logger.info( + "%s replica %d: gateway compute terminated", + fmt(gateway_model), + gateway_compute.replica_num, + ) + await gateway_connections_pool.remove(gateway_compute.ip_address) + + return _ProcessToBeDeletedResult( + delete_gateway=True, + all_computes_update_map={"active": False, "deleted": True}, + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 014f84c604..61599172b5 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -76,7 +76,7 @@ get_instance_specific_mounts, resolve_provisioning_image, ) -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_or_add_gateway_connections from dstack._internal.server.services.instances import ( get_instance_remote_connection_info, get_instance_ssh_private_keys, @@ -1185,7 +1185,7 @@ async def _register_service_replica( return None async with get_session_ctx() as session: - gateway_model, conn = await 
get_or_add_gateway_connection( + gateway_model, connections = await get_or_add_gateway_connections( session, context.run_model.gateway_id ) gateway_target = events.Target.from_model(gateway_model) @@ -1197,35 +1197,40 @@ async def _register_service_replica( # so we must update job_submission with the result value. job_submission = context.job_submission.copy(deep=True) job_submission.job_runtime_data = _get_result_job_runtime_data(context.job_model, result) - try: - logger.debug( - "%s: registering replica for service %s", fmt(context.job_model), context.run.id.hex - ) - async with conn.client() as gateway_client: - await gateway_client.register_replica( - run=context.run, - job_spec=job_spec, - job_submission=job_submission, - instance_project_ssh_private_key=instance_project_ssh_private_key, - ssh_head_proxy=ssh_head_proxy, - ssh_head_proxy_private_key=ssh_head_proxy_private_key, - ) - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(repr(e)) - except GatewayError as e: - if "already exists in service" in e.msg: - logger.warning( - ( - "%s: could not register replica in gateway: %s." 
- " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+," - " expect to see this warning once for every running service replica" - ), + for conn in connections: + try: + logger.debug( + "%s: registering replica for service %s on gateway replica %s", fmt(context.job_model), - e.msg, + context.run.id.hex, + conn.ip_address, ) - else: - raise + async with conn.client() as gateway_client: + await gateway_client.register_replica( + run=context.run, + job_spec=job_spec, + job_submission=job_submission, + instance_project_ssh_private_key=instance_project_ssh_private_key, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, + ) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + except GatewayError as e: + if "already exists in service" in e.msg: + logger.warning( + ( + "%s: could not register replica in gateway %s: %s." + " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+," + " expect to see this warning once for every running service replica" + ), + fmt(context.job_model), + conn.ip_address, + e.msg, + ) + else: + raise return gateway_target diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 3ae30c3ef2..adedf9bb4b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -51,7 +51,7 @@ ) from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import events -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_or_add_gateway_connections from dstack._internal.server.services.instances import ( emit_instance_status_change_event, get_instance_ssh_private_keys, @@ 
-795,25 +795,36 @@ async def _unregister_replica( run_model = job_model.run if run_model.gateway_id is not None: async with get_session_ctx() as session: - gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - gateway_target = events.Target.from_model(gateway) - try: - logger.debug( - "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex + gateway, connections = await get_or_add_gateway_connections( + session, run_model.gateway_id ) - async with conn.client() as client: - await client.unregister_replica( - project=run_model.project.name, - run_name=run_model.run_name, - job_id=job_model.id, + gateway_target = events.Target.from_model(gateway) + for conn in connections: + try: + logger.debug( + "%s: unregistering replica from service %s on gateway replica %s", + fmt(job_model), + job_model.run_id.hex, + conn.ip_address, + ) + async with conn.client() as client: + await client.unregister_replica( + project=run_model.project.name, + run_name=run_model.run_name, + job_id=job_model.id, + ) + except GatewayError as e: + logger.warning( + "%s: unregistering replica from service on gateway replica %s: %s", + fmt(job_model), + conn.ip_address, + e, ) - except GatewayError as e: - logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e) - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - # FIXME: Unhandled exception raised. - # Handle and retry unregister with timeout. - raise GatewayError(repr(e)) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + # FIXME: Unhandled exception raised. + # Handle and retry unregister with timeout. 
+ raise GatewayError(repr(e)) return gateway_target diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py index c26df9d4d3..071af9fbd3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py @@ -28,7 +28,7 @@ from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel from dstack._internal.server.services import events -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_combined_gateway_stats from dstack._internal.server.services.jobs import emit_job_status_change_event from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.pipelines import PipelineHinterProtocol @@ -313,8 +313,9 @@ async def _load_pending_context( gateway_stats = None if run_spec.configuration.type == "service" and run_model.gateway_id is not None: - _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name) + gateway_stats = await get_combined_gateway_stats( + session, run_model.gateway_id, run_model.project.name, run_model.run_name + ) return pending.PendingContext( run_model=run_model, @@ -494,8 +495,9 @@ async def _load_active_context( gateway_stats = None if run_spec.configuration.type == "service" and run_model.gateway_id is not None: - _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name) + gateway_stats = await get_combined_gateway_stats( + session, run_model.gateway_id, run_model.project.name, run_model.run_name + ) return active.ActiveContext( 
run_model=run_model, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py index eece7dfa7c..c9a75e3c71 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py @@ -16,7 +16,7 @@ from dstack._internal.server.background.pipeline_tasks.base import ItemUpdateMap from dstack._internal.server.db import get_session_ctx from dstack._internal.server.services import events -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_or_add_gateway_connections from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.runs import _get_next_triggered_at, get_run_spec from dstack._internal.utils.common import get_or_error @@ -148,24 +148,37 @@ async def _unregister_service(run_model: models.RunModel) -> Optional[ServiceUnr return None async with get_session_ctx() as session: - gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway, connections = await get_or_add_gateway_connections(session, run_model.gateway_id) gateway_target = events.Target.from_model(gateway) - try: - logger.debug("%s: unregistering service", fmt(run_model)) - async with conn.client() as client: - await client.unregister_service( - project=run_model.project.name, - run_name=run_model.run_name, + gateway_errors = [] + for conn in connections: + try: + logger.debug( + "%s: unregistering service on gateway replica %s", fmt(run_model), conn.ip_address + ) + async with conn.client() as client: + await client.unregister_service( + project=run_model.project.name, + run_name=run_model.run_name, + ) + except GatewayError as e: + # Ignore if the service is not registered on this replica. 
+ logger.warning( + "%s: unregistering service on gateway replica %s: %s", + fmt(run_model), + conn.ip_address, + e, ) + gateway_errors.append(str(e)) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + + if gateway_errors: + event_message = f"Gateway error when unregistering service: {'; '.join(gateway_errors)}" + else: event_message = "Service unregistered from gateway" - except GatewayError as e: - # Ignore if the service is not registered. - logger.warning("%s: unregistering service: %s", fmt(run_model), e) - event_message = f"Gateway error when unregistering service: {e}" - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(repr(e)) return ServiceUnregistration( event_message=event_message, gateway_target=gateway_target, diff --git a/src/dstack/_internal/server/compatibility/gateways.py b/src/dstack/_internal/server/compatibility/gateways.py new file mode 100644 index 0000000000..3e410b5a9c --- /dev/null +++ b/src/dstack/_internal/server/compatibility/gateways.py @@ -0,0 +1,15 @@ +from typing import Optional + +from packaging.version import Version + +from dstack._internal.core.models.gateways import Gateway + + +def patch_gateway(gateway: Gateway, client_version: Optional[Version]) -> None: + if client_version is None: + return + if client_version < Version("0.20.25") and len(gateway.replicas) < 2: + gateway.instance_id = "" + gateway.ip_address = gateway.replicas[0].hostname if gateway.replicas else "" + if gateway.hostname is None: + gateway.hostname = gateway.ip_address diff --git a/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py b/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py new file mode 100644 index 0000000000..2729699af1 --- /dev/null +++ 
b/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py @@ -0,0 +1,52 @@ +"""Add GatewayComputeModel.gateway_id + +Revision ID: b7609b94ea4d +Revises: 201cb7ccd0d3 +Create Date: 2026-06-01 19:11:30.641417+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b7609b94ea4d" +down_revision = "201cb7ccd0d3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("gateway_computes", schema=None) as batch_op: + batch_op.add_column( + sa.Column("replica_num", sa.Integer(), server_default="0", nullable=False) + ) + batch_op.add_column( + sa.Column( + "gateway_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.create_foreign_key( + batch_op.f("fk_gateway_computes_gateway_id_gateways"), + "gateways", + ["gateway_id"], + ["id"], + ondelete="SET NULL", + use_alter=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("gateway_computes", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_gateway_computes_gateway_id_gateways"), type_="foreignkey" + ) + batch_op.drop_column("gateway_id") + batch_op.drop_column("replica_num") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index d433244ea3..8d6f3c512c 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -629,7 +629,21 @@ class GatewayModel(PipelineModelMixin, BaseModel): gateway_compute_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("gateway_computes.id", ondelete="CASCADE") ) - gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship() + gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship( + foreign_keys=[gateway_compute_id] + ) + """ + Relationship with gateway computes for pre-0.20.25 gateways. + Use `get_gateway_compute_models()` for version-agnostic gateway compute retrieval. + """ + gateway_computes: Mapped[List["GatewayComputeModel"]] = relationship( + back_populates="gateway", + foreign_keys="GatewayComputeModel.gateway_id", + ) + """ + Relationship with gateway computes for 0.20.25+ gateways. + Use `get_gateway_compute_models()` for version-agnostic gateway compute retrieval. + """ runs: Mapped[List["RunModel"]] = relationship(back_populates="gateway") @@ -639,15 +653,26 @@ class GatewayModel(PipelineModelMixin, BaseModel): class GatewayComputeModel(BaseModel): + """A single gateway replica. + **TODO**: consider renaming to `GatewayReplicaModel`. 
+ """ + __tablename__ = "gateway_computes" id: Mapped[uuid.UUID] = mapped_column( UUIDType(binary=False), primary_key=True, default=uuid.uuid4 ) created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + replica_num: Mapped[int] = mapped_column(Integer, server_default="0") instance_id: Mapped[str] = mapped_column(String(100)) ip_address: Mapped[str] = mapped_column(String(100)) + """Gateway replica IP address or domain name (e.g., k8s can use domain names). + **TODO**: rename. + """ hostname: Mapped[Optional[str]] = mapped_column(String(100)) + """Hostname of the gateway's load balancer. + **TODO**: move to `GatewayModel`. + """ configuration: Mapped[Optional[str]] = mapped_column(Text) """`configuration` is optional for compatibility with pre-0.18.2 gateways. Use `get_gateway_compute_configuration` to construct `configuration` for old gateways. @@ -655,6 +680,22 @@ class GatewayComputeModel(BaseModel): backend_data: Mapped[Optional[str]] = mapped_column(Text) region: Mapped[str] = mapped_column(String(100)) + gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column( + ForeignKey( + "gateways.id", + ondelete="SET NULL", + use_alter=True, + ) + ) + gateway: Mapped[Optional["GatewayModel"]] = relationship( + back_populates="gateway_computes", + foreign_keys=[gateway_id], + ) + """ + Gateway. Can be None for pre-0.20.25 gateways, which use GatewayModel.gateway_compute_id to + establish the relationship. 
+ """ + backend_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("backends.id", ondelete="CASCADE") ) diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index eee99077c1..6b9a6718dd 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple from fastapi import APIRouter, Depends +from packaging.version import Version from sqlalchemy.ext.asyncio import AsyncSession import dstack._internal.core.models.gateways as models @@ -8,6 +9,7 @@ import dstack._internal.server.services.gateways as gateways from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.common import EntityReference +from dstack._internal.server.compatibility.gateways import patch_gateway from dstack._internal.server.db import get_session from dstack._internal.server.deps import Project from dstack._internal.server.models import ProjectModel, UserModel @@ -21,6 +23,7 @@ from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, + get_client_version, ) router = APIRouter( @@ -35,17 +38,19 @@ async def list_gateways( body: Optional[schemas.ListGatewaysRequest] = None, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), + client_version: Optional[Version] = Depends(get_client_version), ): _, project = user_project if body is None: body = schemas.ListGatewaysRequest() - return CustomORJSONResponse( - await gateways.list_project_gateways( - session=session, - project=project, - include_imported=body.include_imported, - ) + gateway_list = await gateways.list_project_gateways( + session=session, + project=project, + include_imported=body.include_imported, ) + for gateway in gateway_list: + patch_gateway(gateway, client_version) + return 
CustomORJSONResponse(gateway_list) @router.post("/get", summary="Get gateway", response_model=models.Gateway) @@ -54,6 +59,7 @@ async def get_gateway( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), project: ProjectModel = Depends(Project()), + client_version: Optional[Version] = Depends(get_client_version), ): await check_can_access_gateway( session=session, user=user, gateway_project=project, gateway_name=body.name @@ -61,6 +67,7 @@ async def get_gateway( gateway = await gateways.get_gateway_by_name(session=session, project=project, name=body.name) if gateway is None: raise ResourceNotExistsError() + patch_gateway(gateway, client_version) return CustomORJSONResponse(gateway) @@ -70,17 +77,18 @@ async def create_gateway( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), + client_version: Optional[Version] = Depends(get_client_version), ): user, project = user_project - return CustomORJSONResponse( - await gateways.create_gateway( - session=session, - user=user, - project=project, - configuration=body.configuration, - pipeline_hinter=pipeline_hinter, - ) + gateway = await gateways.create_gateway( + session=session, + user=user, + project=project, + configuration=body.configuration, + pipeline_hinter=pipeline_hinter, ) + patch_gateway(gateway, client_version) + return CustomORJSONResponse(gateway) @router.post("/delete", summary="Delete gateways") @@ -118,14 +126,15 @@ async def set_gateway_wildcard_domain( body: schemas.SetWildcardDomainRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + client_version: Optional[Version] = Depends(get_client_version), ): user, project = user_project - return CustomORJSONResponse( - await gateways.set_gateway_wildcard_domain( - session=session, - project=project, - 
name=body.name, - wildcard_domain=body.wildcard_domain, - user=user, - ) + gateway = await gateways.set_gateway_wildcard_domain( + session=session, + project=project, + name=body.name, + wildcard_domain=body.wildcard_domain, + user=user, ) + patch_gateway(gateway, client_version) + return CustomORJSONResponse(gateway) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 287117e2ed..d82d658127 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -10,7 +10,7 @@ import httpx from sqlalchemy import exists, func, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, selectinload import dstack._internal.utils.random_names as random_names from dstack._internal.core.backends.base.compute import ( @@ -32,15 +32,19 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.gateways import ( + GATEWAY_REPLICAS_DEFAULT, AnyGatewayRouterConfig, Gateway, GatewayComputeConfiguration, GatewayConfiguration, + GatewayReplica, GatewaySpec, GatewayStatus, LetsEncryptGatewayCertificate, ) from dstack._internal.core.services import validate_dstack_resource_name +from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS +from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, Stat from dstack._internal.server import settings from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( @@ -169,6 +173,8 @@ async def create_gateway_compute( project_name: str, backend_compute: Compute, configuration: GatewayConfiguration, + replica_num: int, + gateway_id: Optional[uuid.UUID] = None, backend_id: Optional[uuid.UUID] = None, ) -> 
GatewayComputeModel: assert isinstance(backend_compute, ComputeWithGatewaySupport) @@ -180,7 +186,7 @@ async def create_gateway_compute( compute_configuration = GatewayComputeConfiguration( project_name=project_name, - instance_name=configuration.name, + instance_name=f"{configuration.name}-{replica_num}", backend=configuration.backend, region=configuration.region, instance_type=configuration.instance_type, @@ -197,7 +203,9 @@ async def create_gateway_compute( ) return GatewayComputeModel( + gateway_id=gateway_id, backend_id=backend_id, + replica_num=replica_num, region=gpd.region, ip_address=gpd.ip_address, instance_id=gpd.instance_id, @@ -467,6 +475,7 @@ async def list_project_gateway_models( stmt = stmt.where(GatewayModel.project_id == project.id) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -495,6 +504,7 @@ async def get_project_gateway_model_by_reference( ) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -529,6 +539,7 @@ async def get_project_gateway_model_by_name_for_update( select(GatewayModel) .where(GatewayModel.id.in_([gateway_id]), *filters) .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) .with_for_update(key_share=True, of=GatewayModel) ) @@ -555,6 +566,7 @@ async def get_project_default_gateway_model( ) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + stmt = 
stmt.options(selectinload(GatewayModel.gateway_computes)) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -571,30 +583,71 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> # TODO: Connect to gateway outside session -async def get_or_add_gateway_connection( +async def get_or_add_gateway_connections( session: AsyncSession, gateway_id: uuid.UUID -) -> tuple[GatewayModel, GatewayConnection]: - gateway = await session.get( - GatewayModel, - gateway_id, - options=[joinedload(GatewayModel.gateway_compute)], - populate_existing=True, +) -> tuple[GatewayModel, List[GatewayConnection]]: + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway_id) + .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) ) + gateway = res.scalar_one_or_none() if gateway is None: raise GatewayError("Gateway not found") - if gateway.gateway_compute is None: + computes = get_gateway_compute_models(gateway) + if not computes: raise GatewayError("Gateway compute not found") + connections: List[GatewayConnection] = [] + for compute in computes: + try: + conn = await gateway_connections_pool.get_or_add( + hostname=compute.ip_address, + id_rsa=compute.ssh_private_key, + ) + connections.append(conn) + except Exception as e: + logger.warning("Failed to connect to gateway %s: %s", compute.ip_address, e) + raise GatewayError("Failed to connect to gateway") + return gateway, connections + + +async def get_combined_gateway_stats( + session: AsyncSession, + gateway_id: uuid.UUID, + project_name: str, + run_name: str, +) -> Optional[PerWindowStats]: + """ + Return stats for *run_name* aggregated across all replicas of *gateway_id*. 
+ """ try: - conn = await gateway_connections_pool.get_or_add( - hostname=gateway.gateway_compute.ip_address, - id_rsa=gateway.gateway_compute.ssh_private_key, - ) - except Exception as e: - logger.warning( - "Failed to connect to gateway %s: %s", gateway.gateway_compute.ip_address, e + _, connections = await get_or_add_gateway_connections(session, gateway_id) + except GatewayError: + return None + per_replica: list[PerWindowStats] = [] + for conn in connections: + stats = await conn.get_stats(project_name, run_name) + if stats is None: # Stats not fetched yet + return None + per_replica.append(stats) + return _merge_per_window_stats(per_replica) if per_replica else None + + +def _merge_per_window_stats(stats_per_gateway_replica: list[PerWindowStats]) -> PerWindowStats: + merged: PerWindowStats = {} + for window in SERVICE_SCALING_WINDOWS: + total_requests = 0 + total_time_of_all_requests = 0.0 + for gateway_replica_stats in stats_per_gateway_replica: + stat = gateway_replica_stats[window] + total_requests += stat.requests + total_time_of_all_requests += stat.requests * stat.request_time + merged[window] = Stat( + requests=total_requests, + request_time=(total_time_of_all_requests / total_requests if total_requests else 0.0), ) - raise GatewayError("Failed to connect to gateway") - return gateway, conn + return merged async def init_gateways(session: AsyncSession): @@ -732,6 +785,14 @@ async def configure_gateway( logger.info("Gateway %s configured", connection.ip_address) +def get_gateway_compute_models(gateway_model: GatewayModel) -> List[GatewayComputeModel]: + if gateway_model.gateway_computes: # 0.20.25+ gateway + return list(gateway_model.gateway_computes) + if gateway_model.gateway_compute is not None: # pre-0.20.25 gateway + return [gateway_model.gateway_compute] + return [] + + def get_gateway_configuration(gateway_model: GatewayModel) -> GatewayConfiguration: if gateway_model.configuration is not None: return 
GatewayConfiguration.__response__.parse_raw(gateway_model.configuration) @@ -746,22 +807,19 @@ def get_gateway_configuration(gateway_model: GatewayModel) -> GatewayConfigurati def get_gateway_compute_configuration( + gateway_compute: GatewayComputeModel, gateway_model: GatewayModel, -) -> Optional[GatewayComputeConfiguration]: - if gateway_model.gateway_compute is None: - return None - if gateway_model.gateway_compute.configuration is not None: - return GatewayComputeConfiguration.__response__.parse_raw( - gateway_model.gateway_compute.configuration - ) +) -> GatewayComputeConfiguration: + if gateway_compute.configuration is not None: + return GatewayComputeConfiguration.__response__.parse_raw(gateway_compute.configuration) # Handle gateways created before GatewayComputeConfiguration was introduced return GatewayComputeConfiguration( project_name=gateway_model.project.name, - instance_name=gateway_model.gateway_compute.instance_id, + instance_name=gateway_compute.instance_id, backend=gateway_model.backend.type, - region=gateway_model.gateway_compute.region, + region=gateway_compute.region, public_ip=True, - ssh_key_pub=gateway_model.gateway_compute.ssh_public_key, + ssh_key_pub=gateway_compute.ssh_public_key, certificate=LetsEncryptGatewayCertificate(), ) @@ -775,28 +833,34 @@ def gateway_model_to_gateway( default_gateway_id: ID of the default gateway in the project where `gateway_model` is being viewed. Can be different from `gateway_model.project` if the gateway is imported. 
""" - ip_address = "" - instance_id = "" - hostname = "" - if gateway_model.gateway_compute is not None: - ip_address = gateway_model.gateway_compute.ip_address - instance_id = gateway_model.gateway_compute.instance_id - hostname = gateway_model.gateway_compute.hostname - if hostname is None: - hostname = ip_address backend_type = gateway_model.backend.type if gateway_model.backend.type == BackendType.DSTACK: backend_type = BackendType.AWS is_default = default_gateway_id == gateway_model.id configuration = get_gateway_configuration(gateway_model) configuration.default = is_default + + compute_models = sorted(get_gateway_compute_models(gateway_model), key=lambda c: c.replica_num) + gateway_hostname = None + replicas = [] + for compute in compute_models: + compute_configuration = get_gateway_compute_configuration(compute, gateway_model) + replicas.append( + GatewayReplica( + hostname=compute.ip_address, + replica_num=compute.replica_num, + backend=compute_configuration.backend, + region=compute_configuration.region, + created_at=compute.created_at, + ) + ) + gateway_hostname = compute.hostname + return Gateway( id=gateway_model.id, name=gateway_model.name, project_name=gateway_model.project.name, - ip_address=ip_address, - instance_id=instance_id, - hostname=hostname, + hostname=gateway_hostname, backend=backend_type, region=gateway_model.region, wildcard_domain=gateway_model.wildcard_domain, @@ -805,6 +869,7 @@ def gateway_model_to_gateway( status=gateway_model.status, status_message=gateway_model.status_message, configuration=configuration, + replicas=replicas, ) @@ -838,6 +903,10 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): f" {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}." 
) + replicas = ( + configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT + ) + if configuration.certificate is not None: if configuration.certificate.type == "lets-encrypt" and not configuration.public_ip: raise ServerClientError( @@ -845,3 +914,13 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): ) if configuration.certificate.type == "acm" and configuration.backend != BackendType.AWS: raise ServerClientError("acm certificate type is supported for aws backend only") + if replicas > 1: + raise ServerClientError( + "Replicated gateways do not support certificates." + " Set either `certificate: null` or `replicas: 1` in the gateway configuration" + ) + + if configuration.router is not None and replicas > 1: + raise ServerClientError( + "The deprecated `router` property is not supported for multi-replica gateways" + ) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 273054e74f..5f936ef4ae 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -32,8 +32,9 @@ from dstack._internal.server.models import GatewayModel, RunModel from dstack._internal.server.services import events from dstack._internal.server.services.gateways import ( + get_gateway_compute_models, get_gateway_configuration, - get_or_add_gateway_connection, + get_or_add_gateway_connections, get_project_default_gateway_model, get_project_gateway_model_by_reference, ) @@ -100,7 +101,7 @@ async def _register_service_in_gateway( ) -> ServiceSpec: assert run_spec.configuration.type == "service" - if gateway.gateway_compute is None: + if not get_gateway_compute_models(gateway): raise ServerClientError("Gateway has no instance associated with it") if gateway.status != GatewayStatus.RUNNING: @@ -178,50 +179,51 @@ async def _register_service_in_gateway( domain = service_spec.get_domain() 
assert domain is not None - _, conn = await get_or_add_gateway_connection(session, gateway.id) - try: - logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) - async with conn.client() as client: - do_register = partial( - client.register_service, - project=run_model.project.name, - run_name=run_model.run_name, - domain=domain, - service_https=configure_service_https, - gateway_https=gateway_https, - auth=run_spec.configuration.auth, - client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE, - options=service_spec.options, - rate_limits=run_spec.configuration.rate_limits, - ssh_private_key=run_model.project.ssh_private_key, - has_router_replica=has_replica_group_router, - router=router, - ) - try: - await do_register() - except GatewayError as e: - if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format( - ref=f"{run_model.project.name}/{run_model.run_name}" - ): - # Happens if there was a communication issue with the gateway when last unregistering - logger.warning( - "Service %s/%s is dangling on gateway %s, unregistering and re-registering", - run_model.project.name, - run_model.run_name, - gateway.name, - ) - await client.unregister_service( - project=run_model.project.name, - run_name=run_model.run_name, - ) + _, connections = await get_or_add_gateway_connections(session, gateway.id) + for conn in connections: + try: + logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) + async with conn.client() as client: + do_register = partial( + client.register_service, + project=run_model.project.name, + run_name=run_model.run_name, + domain=domain, + service_https=configure_service_https, + gateway_https=gateway_https, + auth=run_spec.configuration.auth, + client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE, + options=service_spec.options, + rate_limits=run_spec.configuration.rate_limits, + ssh_private_key=run_model.project.ssh_private_key, + has_router_replica=has_replica_group_router, 
+ router=router, + ) + try: await do_register() - else: - raise - except SSHError: - raise ServerClientError("Gateway tunnel is not working") - except httpx.RequestError as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(f"Gateway is not working: {e!r}") + except GatewayError as e: + if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format( + ref=f"{run_model.project.name}/{run_model.run_name}" + ): + # Happens if there was a communication issue with the gateway when last unregistering + logger.warning( + "Service %s/%s is dangling on gateway replica %s, unregistering and re-registering", + run_model.project.name, + run_model.run_name, + conn.ip_address, + ) + await client.unregister_service( + project=run_model.project.name, + run_name=run_model.run_name, + ) + await do_register() + else: + raise + except SSHError: + raise ServerClientError("Gateway tunnel is not working") + except httpx.RequestError as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(f"Gateway is not working: {e!r}") events.emit( session, diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 6c8b7233f6..2c0a66be5a 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -638,7 +638,6 @@ async def create_gateway( name: str = "test_gateway", region: str = "us", wildcard_domain: Optional[str] = None, - gateway_compute_id: Optional[UUID] = None, status: Optional[GatewayStatus] = GatewayStatus.SUBMITTED, last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), forbid_new_services: bool = False, @@ -649,7 +648,6 @@ async def create_gateway( name=name, region=region, wildcard_domain=wildcard_domain, - gateway_compute_id=gateway_compute_id, status=status, last_processed_at=last_processed_at, forbid_new_services=forbid_new_services, @@ -661,6 +659,7 @@ async def create_gateway( async def 
create_gateway_compute( session: AsyncSession, + gateway_id: Optional[UUID] = None, backend_id: Optional[UUID] = None, ip_address: Optional[str] = "1.1.1.1", region: str = "us", @@ -669,6 +668,7 @@ async def create_gateway_compute( ssh_public_key: str = "", ) -> GatewayComputeModel: gateway_compute = GatewayComputeModel( + gateway_id=gateway_id, backend_id=backend_id, ip_address=ip_address, region=region, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py index d1113c90d1..2759d8a236 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -6,10 +6,15 @@ import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import selectinload from dstack._internal.core.errors import BackendError -from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.gateways import ( + GatewayConfiguration, + GatewayProvisioningData, + GatewayStatus, +) from dstack._internal.server.background.pipeline_tasks.gateways import ( GatewayFetcher, GatewayPipeline, @@ -257,12 +262,12 @@ async def test_submitted_to_provisioning( res = await session.execute( select(GatewayModel) .where(GatewayModel.id == gateway.id) - .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) ) gateway = res.unique().scalar_one() assert gateway.status == GatewayStatus.PROVISIONING - assert gateway.gateway_compute is not None - assert gateway.gateway_compute.ip_address == "2.2.2.2" + assert len(gateway.gateway_computes) > 0 + assert gateway.gateway_computes[0].ip_address == "2.2.2.2" events = await list_events(session) assert len(events) == 
1 assert events[0].message == "Gateway status changed SUBMITTED -> PROVISIONING" @@ -300,23 +305,129 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( assert len(events) == 1 assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)" + async def test_submitted_creates_multiple_computes_for_multi_replica( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.SUBMITTED, + ) + config = GatewayConfiguration( + name=gateway.name, + backend=BackendType.AWS, + region=gateway.region, + replicas=2, + ) + gateway.configuration = config.json() + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: + aws = Mock() + m.return_value = (backend, aws) + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.side_effect = [ + GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"), + GatewayProvisioningData(instance_id="i-bbb", ip_address="3.3.3.3", region="us"), + ] + await worker.process(_gateway_to_pipeline_item(gateway)) + assert aws.compute.return_value.create_gateway.call_count == 2 + + await session.refresh(gateway) + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway.id) + .options(selectinload(GatewayModel.gateway_computes)) + ) + gateway = res.unique().scalar_one() + assert gateway.status == GatewayStatus.PROVISIONING + computes = sorted(gateway.gateway_computes, key=lambda c: c.replica_num) + assert len(computes) == 2 + assert computes[0].ip_address == "2.2.2.2" + assert 
computes[0].replica_num == 0 + assert computes[1].ip_address == "3.3.3.3" + assert computes[1].replica_num == 1 + + async def test_marks_gateway_as_failed_if_second_replica_creation_errors( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.SUBMITTED, + ) + config = GatewayConfiguration( + name=gateway.name, + backend=BackendType.AWS, + region=gateway.region, + replicas=2, + ) + gateway.configuration = config.json() + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: + aws = Mock() + m.return_value = (backend, aws) + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.side_effect = [ + GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"), + BackendError("Some error"), + ] + await worker.process(_gateway_to_pipeline_item(gateway)) + assert aws.compute.return_value.create_gateway.call_count == 2 + + await session.refresh(gateway) + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway.id) + .options(selectinload(GatewayModel.gateway_computes)) + ) + gateway = res.unique().scalar_one() + assert gateway.status == GatewayStatus.FAILED + assert gateway.status_message == "Some error" + # The first replica's compute is saved even though the second failed + assert len(gateway.gateway_computes) == 1 + assert gateway.gateway_computes[0].ip_address == "2.2.2.2" + assert gateway.gateway_computes[0].replica_num == 0 + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == 
"Gateway status changed SUBMITTED -> FAILED (Some error)" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestGatewayWorkerProvisioning: + @pytest.mark.parametrize("legacy_compute", [False, True]) async def test_provisioning_to_running( - self, test_db, session: AsyncSession, worker: GatewayWorker + self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.PROVISIONING, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + await create_gateway_compute(session, gateway_id=gateway.id) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() @@ -335,19 +446,57 @@ async def test_provisioning_to_running( assert len(events) == 1 assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" - async def test_marks_gateway_as_failed_if_fails_to_connect( + async def test_provisioning_to_running_with_multiple_replicas( self, test_db, session: AsyncSession, worker: GatewayWorker ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.PROVISIONING, ) + await create_gateway_compute(session, gateway_id=gateway.id, ip_address="1.1.1.1") + compute1 = await 
create_gateway_compute( + session, gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" + ) as pool_add: + pool_add.return_value = MagicMock() + pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) + await worker.process(_gateway_to_pipeline_item(gateway)) + assert pool_add.call_count == 2 + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.RUNNING + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_marks_gateway_as_failed_if_fails_to_connect( + self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute(session, gateway_id=gateway.id) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() @@ -360,8 +509,55 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( connect_to_gateway_with_retry_mock.assert_called_once() await session.refresh(gateway) + await session.refresh(gateway_compute) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Failed to connect 
to gateway" + assert gateway_compute.active is False + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message + == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)" + ) + + async def test_marks_gateway_as_failed_if_any_replica_fails_to_connect( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + compute0 = await create_gateway_compute( + session, gateway_id=gateway.id, ip_address="1.1.1.1" + ) + compute1 = await create_gateway_compute( + session, gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" + ) as connect_mock: + connect_mock.return_value = None + await worker.process(_gateway_to_pipeline_item(gateway)) + assert connect_mock.call_count == 2 + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.FAILED + assert gateway.status_message == "Failed to connect to gateway" + + await session.refresh(compute0) + await session.refresh(compute1) + assert compute0.active is False + assert compute1.active is False + events = await list_events(session) assert len(events) == 1 assert ( @@ -373,19 +569,25 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestGatewayWorkerDeleted: + @pytest.mark.parametrize("legacy_compute", [False, True]) async def test_deletes_gateway_and_marks_compute_deleted( - self, test_db, session: AsyncSession, worker: 
GatewayWorker + self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id + ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) gateway.to_be_deleted = True @@ -418,19 +620,25 @@ async def test_deletes_gateway_and_marks_compute_deleted( assert len(events) == 1 assert events[0].message == "Gateway deleted" + @pytest.mark.parametrize("legacy_compute", [False, True]) async def test_keeps_gateway_if_terminate_fails( - self, test_db, session: AsyncSession, worker: GatewayWorker + self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, 
backend_id=backend.id, gateway_id=gateway.id + ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) gateway.lock_owner = "GatewayPipeline" @@ -470,3 +678,112 @@ async def test_keeps_gateway_if_terminate_fails( assert gateway_compute.deleted is False events = await list_events(session) assert len(events) == 0 + + async def test_deletes_gateway_with_multiple_replicas( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.RUNNING, + ) + compute0 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1" + ) + compute1 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.to_be_deleted = True + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" + ) as get_backend_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" + ) as remove_connection_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + get_backend_mock.return_value = backend_mock + + await worker.process(_gateway_to_pipeline_item(gateway)) + + assert backend_mock.compute.return_value.terminate_gateway.call_count == 2 + assert remove_connection_mock.call_count == 2 + + await session.refresh(compute0) + await session.refresh(compute1) + res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id)) + assert 
res.scalar_one_or_none() is None + assert compute0.active is False + assert compute0.deleted is True + assert compute1.active is False + assert compute1.deleted is True + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway deleted" + + async def test_keeps_gateway_if_second_replica_terminate_fails( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.RUNNING, + ) + compute0 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1" + ) + compute1 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.lock_owner = "GatewayPipeline" + gateway.to_be_deleted = True + original_last_processed_at = gateway.last_processed_at + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" + ) as get_backend_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" + ) as remove_connection_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.terminate_gateway.side_effect = [ + None, + BackendError("Terminate failed"), + ] + get_backend_mock.return_value = backend_mock + + await worker.process(_gateway_to_pipeline_item(gateway)) + + assert backend_mock.compute.return_value.terminate_gateway.call_count == 2 + remove_connection_mock.assert_called_once_with(compute0.ip_address) + + await 
session.refresh(gateway) + await session.refresh(compute0) + await session.refresh(compute1) + assert gateway.to_be_deleted is True + assert gateway.last_processed_at > original_last_processed_at + assert gateway.lock_token is None + assert gateway.lock_expires_at is None + assert gateway.lock_owner is None + assert compute0.deleted is False + assert compute1.deleted is False diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index e308b89ce8..85aa00e0b6 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -1892,19 +1892,19 @@ async def test_registers_service_replica_in_gateway( project = await create_project(session=session, owner=user) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="test-gateway", wildcard_domain="example.com", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) run = await create_run( session=session, project=project, @@ -1985,19 +1985,19 @@ async def test_registers_service_replica_in_gateway_when_running_on_imported_ins ) repo = await create_repo(session=session, project_id=importer_project.id) backend = await create_backend(session=session, project_id=importer_project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=importer_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, 
status=GatewayStatus.RUNNING, name="test-gateway", wildcard_domain="example.com", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) run = await create_run( session=session, project=importer_project, diff --git a/src/tests/_internal/server/compatibility/__init__.py b/src/tests/_internal/server/compatibility/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/compatibility/test_gateways.py b/src/tests/_internal/server/compatibility/test_gateways.py new file mode 100644 index 0000000000..4bbd2a80d1 --- /dev/null +++ b/src/tests/_internal/server/compatibility/test_gateways.py @@ -0,0 +1,88 @@ +import uuid +from datetime import datetime, timezone + +from packaging.version import Version + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.gateways import ( + Gateway, + GatewayConfiguration, + GatewayReplica, + GatewayStatus, +) +from dstack._internal.server.compatibility.gateways import patch_gateway +from dstack._internal.utils.common import get_current_datetime + +_CREATED_AT = datetime(2025, 1, 1, tzinfo=timezone.utc) +_CONFIG = GatewayConfiguration(name="gw", backend=BackendType.AWS, region="us") + + +def _make_gateway_replica(hostname: str = "1.2.3.4") -> GatewayReplica: + return GatewayReplica( + hostname=hostname, + replica_num=0, + backend=BackendType.AWS, + region="us", + created_at=get_current_datetime(), + ) + + +def _make_gateway(replicas=None, hostname=None) -> Gateway: + return Gateway( + id=uuid.uuid4(), + name="test", + project_name="proj", + backend=BackendType.AWS, + region="us", + created_at=_CREATED_AT, + status=GatewayStatus.RUNNING, + status_message=None, + hostname=hostname, + wildcard_domain=None, + default=False, + replicas=replicas or [], + configuration=_CONFIG, + ) + + +class TestPatchGateway: + def test_none_version_is_noop(self): + replica = _make_gateway_replica("1.2.3.4") + gw = 
_make_gateway(replicas=[replica]) + patch_gateway(gw, None) + assert gw.ip_address is None + assert gw.instance_id is None + assert gw.hostname is None + + def test_new_version_is_noop(self): + replica = _make_gateway_replica("1.2.3.4") + gw = _make_gateway(replicas=[replica]) + patch_gateway(gw, Version("0.20.25")) + assert gw.ip_address is None + assert gw.instance_id is None + + def test_old_version_fills_hostname_from_replica(self): + replica = _make_gateway_replica("1.2.3.4") + gw = _make_gateway(replicas=[replica], hostname=None) + patch_gateway(gw, Version("0.20.24")) + assert gw.hostname == "1.2.3.4" + + def test_old_version_keeps_existing_hostname(self): + replica = _make_gateway_replica("1.2.3.4") + gw = _make_gateway(replicas=[replica], hostname="lb.example.com") + patch_gateway(gw, Version("0.20.24")) + assert gw.hostname == "lb.example.com" + + def test_old_version_no_replicas_sets_empty_strings(self): + gw = _make_gateway(replicas=[]) + patch_gateway(gw, Version("0.20.24")) + assert gw.ip_address == "" + assert gw.instance_id == "" + assert gw.hostname == "" + + def test_old_version_multi_replica_is_noop(self): + replicas = [_make_gateway_replica("1.2.3.4"), _make_gateway_replica("5.6.7.8")] + gw = _make_gateway(replicas=replicas) + patch_gateway(gw, Version("0.20.24")) + assert gw.ip_address is None + assert gw.instance_id is None diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 5cc6bdd715..cdc69b5ff9 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import patch import pytest @@ -29,23 +30,29 @@ async def test_returns_40x_if_not_authenticated(self, client: AsyncClient): @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): 
+ @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_list( + self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( session=session, project=project, user=user, project_role=ProjectRole.USER ) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id + ) + await session.commit() response = await client.post( f"/api/project/{project.name}/gateways/list", headers=get_auth_headers(user.token), @@ -60,9 +67,18 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): "default": False, "status": "submitted", "status_message": None, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()[0]["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, @@ -78,29 +94,36 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } 
] @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_get( + self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( session=session, project=project, user=user, project_role=ProjectRole.USER ) backend = await create_backend(session, project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id + ) + await session.commit() response = await client.post( f"/api/project/{project.name}/gateways/get", json={"name": gateway.name}, @@ -115,9 +138,18 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): "default": False, "status": "submitted", "status_message": None, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, @@ -133,26 +165,60 @@ async def test_get(self, 
test_db, session: AsyncSession, client: AsyncClient): "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_list_non_member_public_project( + async def test_list_legacy_client_populates_compat_fields( self, test_db, session: AsyncSession, client: AsyncClient ): + """Old clients (< 0.20.25) get ip_address/instance_id/hostname back-filled.""" user = await create_user(session, global_role=GlobalRole.USER) - project = await create_project(session, is_public=True) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + ) gateway_compute = await create_gateway_compute( session=session, backend_id=backend.id, + gateway_id=gateway.id, + ) + response = await client.post( + f"/api/project/{project.name}/gateways/list", + headers={**get_auth_headers(user.token), "x-api-version": "0.20.24"}, ) + assert response.status_code == 200 + assert len(response.json()) == 1 + gw = response.json()[0] + assert gw["ip_address"] == gateway_compute.ip_address + assert gw["instance_id"] == "" + assert gw["hostname"] == gateway_compute.ip_address + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_list_non_member_public_project( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session, is_public=True) + backend = await create_backend(session=session, project_id=project.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - 
gateway_compute_id=gateway_compute.id, + ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, ) response = await client.post( f"/api/project/{project.name}/gateways/list", @@ -170,15 +236,15 @@ async def test_get_non_member_public_project( user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session, is_public=True) backend = await create_backend(session, project.id) - gateway_compute = await create_gateway_compute( + gateway = await create_gateway( session=session, + project_id=project.id, backend_id=backend.id, ) - gateway = await create_gateway( + await create_gateway_compute( session=session, - project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, + gateway_id=gateway.id, ) response = await client.post( f"/api/project/{project.name}/gateways/get", @@ -222,14 +288,13 @@ async def test_list_returns_imported_gateway_with_include_imported( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -266,14 +331,13 @@ async def test_list_not_returns_imported_gateway_without_include_imported( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await 
create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -308,14 +372,13 @@ async def test_get_returns_imported_gateway( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -356,14 +419,13 @@ async def test_get_returns_403_on_foreign_gateway_if_not_imported( project_role=ProjectRole.USER, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -426,9 +488,10 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "region": "us", "status": "submitted", "status_message": None, - "instance_id": "", - "ip_address": "", - "hostname": "", + "replicas": [], + "instance_id": None, + "ip_address": None, + "hostname": None, "wildcard_domain": None, "default": True, "created_at": response.json()["created_at"], @@ -444,11 +507,43 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } 
events = await list_events(session) assert events[0].message == "Gateway created. Status: SUBMITTED" + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_create_multi_replica_gateway( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_backend(session, project.id, backend_type=BackendType.AWS) + response = await client.post( + f"/api/project/{project.name}/gateways/create", + json={ + "configuration": { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "replicas": 2, + "certificate": None, + }, + }, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert response.json()["configuration"]["replicas"] == 2 + assert response.json()["replicas"] == [] # populated later by pipelines + events = await list_events(session) + assert events[0].message == "Gateway created. 
Status: SUBMITTED" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_create_gateway_without_name( @@ -484,9 +579,10 @@ async def test_create_gateway_without_name( "region": "us", "status": "submitted", "status_message": None, - "instance_id": "", - "ip_address": "", - "hostname": "", + "replicas": [], + "instance_id": None, + "ip_address": None, + "hostname": None, "wildcard_domain": None, "default": True, "created_at": response.json()["created_at"], @@ -502,6 +598,7 @@ async def test_create_gateway_without_name( "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } events = await list_events(session) @@ -583,6 +680,88 @@ async def test_create_gateway_with_invalid_domain_interpolation( ) assert response.status_code == 400 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "configuration, expected_error", + [ + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "domain": "${{ run.unknown_variable }}.example.com", + }, + "Cannot interpolate gateway domain name: Failed to interpolate due to missing vars: ['run.unknown_variable']", + id="invalid-domain-interpolation", + ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": { + "type": "acm", + "arn": "arn:aws:acm:us-east-1:123456789:certificate/abc", + }, + "replicas": 2, + }, + "Replicated gateways do not support certificates." + " Set either `certificate: null` or `replicas: 1` in the gateway configuration", + id="multi-replica-with-acm-cert", + ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": {"type": "lets-encrypt"}, + "replicas": 2, + }, + "Replicated gateways do not support certificates." 
+ " Set either `certificate: null` or `replicas: 1` in the gateway configuration", + id="multi-replica-with-letsencrypt-cert", + ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": None, + "router": {"type": "sglang"}, + "replicas": 2, + }, + "The deprecated `router` property is not supported for multi-replica gateways", + id="multi-replica-with-router", + ), + ], + ) + async def test_invalid_configuration_rejected( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + configuration: dict[str, Any], + expected_error: str, + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_backend(session, project.id, backend_type=BackendType.AWS) + response = await client.post( + f"/api/project/{project.name}/gateways/create", + json={"configuration": configuration}, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 400 + assert response.json()["detail"][0]["msg"] == expected_error + class TestDefaultGateway: @pytest.mark.asyncio @@ -613,17 +792,17 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: session=session, project=project, user=user, project_role=ProjectRole.ADMIN ) backend = await create_backend(session, project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="first_gateway", ) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) response = await client.post( f"/api/project/{project.name}/gateways/set_default", json={"name": gateway.name}, @@ -645,9 +824,18 @@ async def 
test_set_default_gateway(self, test_db, session: AsyncSession, client: "default": True, "status": "submitted", "status_message": None, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, @@ -663,23 +851,24 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } events = await list_events(session) assert len(events) == 1 assert events[0].message == "Gateway set as project default" - second_gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) second_gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=second_gateway_compute.id, name="second_gateway", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=second_gateway.id, + ) await clear_events(session) response = await client.post( f"/api/project/{project.name}/gateways/set_default", @@ -775,14 +964,13 @@ async def test_set_imported_gateway_as_default( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await 
create_export( session=session, exporter_project=exporter_project, @@ -820,14 +1008,13 @@ async def test_cannot_set_non_imported_foreign_gateway_as_default( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -872,27 +1059,27 @@ async def test_marks_gateways_to_be_deleted( ) backend_aws = await create_backend(session, project.id) backend_gcp = await create_backend(session, project.id, backend_type=BackendType.GCP) - gateway_compute_aws = await create_gateway_compute( - session=session, - backend_id=backend_aws.id, - ) gateway_aws = await create_gateway( session=session, project_id=project.id, backend_id=backend_aws.id, name="gateway-aws", - gateway_compute_id=gateway_compute_aws.id, ) - gateway_compute_gcp = await create_gateway_compute( + gateway_compute_aws = await create_gateway_compute( session=session, - backend_id=backend_gcp.id, + backend_id=backend_aws.id, + gateway_id=gateway_aws.id, ) gateway_gcp = await create_gateway( session=session, project_id=project.id, backend_id=backend_gcp.id, name="gateway-gcp", - gateway_compute_id=gateway_compute_gcp.id, + ) + gateway_compute_gcp = await create_gateway_compute( + session=session, + backend_id=backend_gcp.id, + gateway_id=gateway_gcp.id, ) response = await client.post( f"/api/project/{project.name}/gateways/delete", @@ -991,17 +1178,17 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: session=session, project=project, user=user, project_role=ProjectRole.ADMIN ) backend = await 
create_backend(session, project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, wildcard_domain="old.example", ) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) response = await client.post( f"/api/project/{project.name}/gateways/set_wildcard_domain", json={"name": gateway.name, "wildcard_domain": "new.example"}, @@ -1016,9 +1203,18 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "status": "submitted", "status_message": None, "default": False, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": "new.example", @@ -1034,6 +1230,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } events = await list_events(session) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 93414c9f42..01bc19f5eb 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -3748,19 +3748,19 @@ async def test_submit_to_correct_proxy( repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) for gateway_name, is_default in existing_gateways: - gateway_compute = 
await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name=gateway_name, wildcard_domain=f"{gateway_name}.example", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) if is_default: project.default_gateway_id = gateway.id await session.commit() @@ -3844,16 +3844,15 @@ async def test_submit_to_foreign_gateway_only_if_imported( session=session, owner=exporter_user, name="exporter-project" ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="exported-gateway", wildcard_domain="exported-gateway.example", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) importer_user = await create_user( session=session, global_role=GlobalRole.USER, name="importer_user" @@ -3929,14 +3928,13 @@ async def test_not_submits_to_default_gateway_if_not_imported( user = await create_user(session=session, global_role=GlobalRole.USER) gateway_project = await create_project(session=session, owner=user, name="gateway-project") backend = await create_backend(session=session, project_id=gateway_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=gateway_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) service_project = await create_project(session=session, 
owner=user, name="service-project") # The project's default_gateway_id may point to the gateway (e.g., if the gateway was @@ -3982,16 +3980,15 @@ async def test_interpolates_project_name_in_imported_gateway_domain( session=session, owner=exporter_user, name="exporter-project" ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="exported-gateway", wildcard_domain="${{ run.project_name }}.example.com", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) importer_user = await create_user( session=session, global_role=GlobalRole.USER, name="importer_user" @@ -4041,16 +4038,15 @@ async def test_returns_error_if_imported_gateway_domain_has_unknown_variable( session=session, owner=exporter_user, name="exporter-project" ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="exported-gateway", wildcard_domain="${{ run.unknown_variable }}.example.com", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) importer_user = await create_user( session=session, global_role=GlobalRole.USER, name="importer_user" @@ -4108,15 +4104,14 @@ async def test_unregister_dangling_service( ) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = 
await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, wildcard_domain="example.com", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) project.default_gateway_id = gateway.id await session.commit() @@ -4158,16 +4153,15 @@ async def test_return_error_if_default_gateway_forbids_new_services( ) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, wildcard_domain="example.com", forbid_new_services=True, ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) project.default_gateway_id = gateway.id await session.commit() @@ -4196,17 +4190,16 @@ async def test_return_error_if_explicitly_specified_gateway_forbids_new_services ) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) - await create_gateway( + gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="restricted-gateway", wildcard_domain="example.com", forbid_new_services=True, ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) response = await client.post( "/api/project/test-project/runs/submit", diff --git a/src/tests/_internal/server/services/gateways/__init__.py b/src/tests/_internal/server/services/gateways/__init__.py new file mode 100644 index 
0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/services/gateways/test_gateways.py b/src/tests/_internal/server/services/gateways/test_gateways.py new file mode 100644 index 0000000000..aaf8fe6d52 --- /dev/null +++ b/src/tests/_internal/server/services/gateways/test_gateways.py @@ -0,0 +1,88 @@ +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS +from dstack._internal.proxy.gateway.schemas.stats import Stat +from dstack._internal.server.services.gateways import ( + _merge_per_window_stats, + get_gateway_compute_models, +) +from dstack._internal.server.testing.common import ( + create_backend, + create_gateway, + create_gateway_compute, + create_project, +) + + +class TestMergePerWindowStats: + def test_empty_returns_zero_stats(self): + result = _merge_per_window_stats([]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 0 + assert result[window].request_time == 0.0 + + def test_single_replica_returns_same_values(self): + stats = {w: Stat(requests=10, request_time=0.5) for w in SERVICE_SCALING_WINDOWS} + result = _merge_per_window_stats([stats]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 10 + assert result[window].request_time == pytest.approx(0.5) + + def test_multiple_replicas_sums_requests_and_averages_time(self): + stats_a = {w: Stat(requests=10, request_time=1.0) for w in SERVICE_SCALING_WINDOWS} + stats_b = {w: Stat(requests=30, request_time=3.0) for w in SERVICE_SCALING_WINDOWS} + result = _merge_per_window_stats([stats_a, stats_b]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 40 + assert result[window].request_time == pytest.approx(2.5) # (10*1 + 30*3) / 40 + + def test_zero_requests_across_all_replicas_returns_zero_time(self): + stats_a = {w: Stat(requests=0, request_time=0.0) for w in SERVICE_SCALING_WINDOWS} + stats_b = {w: Stat(requests=0, request_time=0.0) for 
w in SERVICE_SCALING_WINDOWS} + result = _merge_per_window_stats([stats_a, stats_b]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 0 + assert result[window].request_time == 0.0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGetGatewayComputeModels: + async def test_new_style_returns_gateway_computes(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id + ) + compute = await create_gateway_compute( + session=session, gateway_id=gateway.id, backend_id=backend.id + ) + await session.refresh(gateway, ["gateway_computes", "gateway_compute"]) + result = get_gateway_compute_models(gateway) + assert len(result) == 1 + assert result[0].id == compute.id + + async def test_old_style_returns_single_compute(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id + ) + gateway.gateway_compute_id = compute.id + await session.commit() + await session.refresh(gateway, ["gateway_computes", "gateway_compute"]) + result = get_gateway_compute_models(gateway) + assert len(result) == 1 + assert result[0].id == compute.id + + async def test_no_computes_returns_empty(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id + ) + await session.refresh(gateway, ["gateway_computes", "gateway_compute"]) + result = 
get_gateway_compute_models(gateway) + assert result == [] From ccfec5611f4cbff00cddeaa63c970ed9bbfc68da Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 16 Jun 2026 02:06:01 +0200 Subject: [PATCH 2/2] Review fixes --- mkdocs/docs/concepts/gateways.md | 1 + .../_internal/server/services/gateways/__init__.py | 11 +++++++++++ .../_internal/server/services/services/__init__.py | 2 +- src/tests/_internal/server/routers/test_gateways.py | 12 ++++++++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mkdocs/docs/concepts/gateways.md b/mkdocs/docs/concepts/gateways.md index 26374a6751..b71a23d7b6 100644 --- a/mkdocs/docs/concepts/gateways.md +++ b/mkdocs/docs/concepts/gateways.md @@ -224,6 +224,7 @@ $ dstack gateway list - HTTPS is not supported. Use an external load balancer for TLS termination. - An unavailable gateway replica prevents any new services or service replicas from being added. - All replicas are bound to the same backend and region. + - At most 3 replicas are allowed per gateway. !!! info "Reference" For all gateway configuration options, refer to the [reference](../reference/dstack.yml/gateway.md). diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index d82d658127..bfd05cecf9 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -134,6 +134,10 @@ def get_gateway_status_change_message( GATEWAY_CONNECT_DELAY = 10 GATEWAY_CONFIGURE_ATTEMPTS = 50 GATEWAY_CONFIGURE_DELAY = 3 +# Artificial limit to avoid doing too many per-replica operations (gateway replica provisioning, +# service registration, etc) in a single pipeline tick. Can be lifted once the implementation is +# more mature. 
+GATEWAY_MAX_REPLICAS = 3 # documented in gateways.md, keep in sync async def list_project_gateways( @@ -629,6 +633,8 @@ async def get_combined_gateway_stats( for conn in connections: stats = await conn.get_stats(project_name, run_name) if stats is None: # Stats not fetched yet + # TODO: find a way to make service scaling decisions even if some gateway replicas are + # unavailable for fetching stats. return None per_replica.append(stats) return _merge_per_window_stats(per_replica) if per_replica else None @@ -907,6 +913,11 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT ) + if replicas > GATEWAY_MAX_REPLICAS: + raise ServerClientError( + f"Cannot provision {replicas} gateway replicas. This server allows at most {GATEWAY_MAX_REPLICAS}" + ) + if configuration.certificate is not None: if configuration.certificate.type == "lets-encrypt" and not configuration.public_ip: raise ServerClientError( diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 5f936ef4ae..b637683af7 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -205,7 +205,7 @@ async def _register_service_in_gateway( if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format( ref=f"{run_model.project.name}/{run_model.run_name}" ): - # Happens if there was a communication issue with the gateway when last unregistering + # Happens if there was a communication issue with the gateway when last (un)registering logger.warning( "Service %s/%s is dangling on gateway replica %s, unregistering and re-registering", run_model.project.name, diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index cdc69b5ff9..075d1f6d4a 100644 --- 
a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -738,6 +738,18 @@ async def test_create_gateway_with_invalid_domain_interpolation( "The deprecated `router` property is not supported for multi-replica gateways", id="multi-replica-with-router", ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": None, + "replicas": 4, + }, + "Cannot provision 4 gateway replicas. This server allows at most 3", + id="replicas-exceed-max", + ), ], ) async def test_invalid_configuration_rejected(