Commit 9dc8ebaa authored by Benjamin Wingfield

fix bucket handling

parent 29854282
@@ -105,7 +105,7 @@ class PolygenicScoreJob(AsyncMachine):
},
]
def __init__(self, id):
def __init__(self, intp_id):
states = [
# a dummy initial state: /launch got POSTed
{"name": States.REQUESTED},
@@ -120,8 +120,8 @@ class PolygenicScoreJob(AsyncMachine):
{"name": States.FAILED},
]
self.id = id
self.handler = GoogleResourceHandler()
self.intp_id = intp_id
self.handler = GoogleResourceHandler(intp_id=intp_id)
# set up the state machine
super().__init__(
@@ -135,11 +135,16 @@ class PolygenicScoreJob(AsyncMachine):
async def create_resources(self, event: EventData):
"""Create resources required to start the job"""
print("creating resources")
await self.handler.create_resources(job_model=event.kwargs["job_model"])
try:
await self.handler.create_resources(job_model=event.kwargs["job_model"])
except Exception as e:
logger.warning(f"Something went wrong, {self.intp_id} entering error state")
await self.error()
raise Exception from e
async def destroy_resources(self, event: EventData):
"""Delete all resources associated with this job"""
print(f"deleting all resources: {self.id}")
print(f"deleting all resources: {self.intp_id}")
await self.handler.destroy_resources()
async def notify(self, event):
@@ -147,4 +152,4 @@ class PolygenicScoreJob(AsyncMachine):
print(f"sending state notification: {self.state}")
def __repr__(self):
return f"{self.__class__.__name__}(id={self.id!r})"
return f"{self.__class__.__name__}(id={self.intp_id!r})"
import asyncio
import pathlib
import shutil
import tempfile
from contextlib import asynccontextmanager
import datetime
import httpx
from fastapi import FastAPI, BackgroundTasks
from fastapi import FastAPI, BackgroundTasks, HTTPException
import logging
import shelve
@@ -18,20 +19,18 @@ from .logmodels import LogMessage, LogEvent, MonitorMessage, SummaryTrace
logger = logging.getLogger()
logger.setLevel(logging.INFO)
SHELF_PATH = ""
SHELF_LOCK = asyncio.Lock()
TIMEOUT_SECONDS = 1
CLIENT = httpx.AsyncClient()
@asynccontextmanager
async def lifespan(app: FastAPI):
_, SHELF_PATH = tempfile.mkstemp()
logger.info(f"Creating temporary shelf file {SHELF_PATH}")
SHELF_PATH = pathlib.Path(SHELF_PATH)
tempdir = tempfile.mkdtemp()
Config.SHELF_PATH = pathlib.Path(tempdir) / "shelve.dat"
logger.info(f"Created temporary shelf file {Config.SHELF_PATH}")
yield
SHELF_PATH.unlink()
logger.info(f"Cleaned up {SHELF_PATH}")
shutil.rmtree(tempdir)
logger.info(f"Cleaned up {Config.SHELF_PATH}")
# close the connection pool
await CLIENT.aclose()
logger.info("Closed httpx thread pool")
@@ -43,11 +42,12 @@ app = FastAPI(lifespan=lifespan)
async def launch_job(job_model: JobModel):
"""Background task to create a job, trigger create, and store the job on the shelf"""
id: str = job_model.pipeline_param.id
job_instance: PolygenicScoreJob = PolygenicScoreJob(id=id)
job_instance: PolygenicScoreJob = PolygenicScoreJob(intp_id=id)
await job_instance.create(job_model=job_model, client=CLIENT)
async with SHELF_LOCK:
with shelve.open(SHELF_PATH) as db:
with shelve.open(Config.SHELF_PATH) as db:
db[id] = job_instance
@@ -55,11 +55,11 @@ async def timeout_job(job_id: str):
"""Background task to check if a job is still on the shelf after a timeout.
If it is, trigger the error state, which will force a cleanup and notify the backend"""
logger.info(f"Async timeout for {TIMEOUT_SECONDS}s started for {job_id}")
await asyncio.sleep(TIMEOUT_SECONDS)
logger.info(f"Async timeout for {Config.TIMEOUT_SECONDS}s started for {job_id}")
await asyncio.sleep(Config.TIMEOUT_SECONDS)
async with SHELF_LOCK:
with shelve.open(SHELF_PATH) as db:
with shelve.open(Config.SHELF_PATH) as db:
job_instance: PolygenicScoreJob = db.get(job_id, None)
if job_instance is not None:
@@ -74,6 +74,13 @@ async def timeout_job(job_id: str):
@app.post("/launch", status_code=status.HTTP_201_CREATED)
async def launch(job: JobModel, background_tasks: BackgroundTasks):
with shelve.open(Config.SHELF_PATH) as db:
if job.pipeline_param.id in db:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Job {job.pipeline_param.id} already exists",
)
background_tasks.add_task(launch_job, job)
background_tasks.add_task(timeout_job, job.pipeline_param.id)
return {"id": job.pipeline_param.id}
@@ -111,15 +118,20 @@ async def monitor(message: LogMessage):
async def update_job_state(state, message: MonitorMessage, delete=False):
with shelve.open(SHELF_PATH) as db:
with shelve.open(Config.SHELF_PATH) as db:
job_instance: PolygenicScoreJob = db[message.run_name]
logger.info(f"Triggering state {state}")
await job_instance.trigger(state, client=CLIENT, message=message)
async with SHELF_LOCK:
with shelve.open(SHELF_PATH) as db:
with shelve.open(Config.SHELF_PATH) as db:
if not delete:
db[message.run_name] = job_instance
else:
db.pop(message.run_name)
class Config:
SHELF_PATH = None
TIMEOUT_SECONDS = 60 * 60 * 24
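# Illustrative note (not part of this commit): keeping these settings on a Config
# class rather than as module globals means they can be overridden at runtime,
# e.g. a test could shorten the timeout before startup:
#
#     Config.TIMEOUT_SECONDS = 1  # fail fast instead of waiting 24 hours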
@@ -13,6 +13,10 @@ logger = logging.getLogger(__name__)
class ResourceHandler(abc.ABC):
@abc.abstractmethod
def __init__(self, intp_id: str):
self.intp_id = intp_id
@abc.abstractmethod
def create_resources(self, job_model: JobModel):
"""Create the compute resources needed to run a job
@@ -35,6 +39,19 @@ class ResourceHandler(abc.ABC):
class GoogleResourceHandler(ResourceHandler):
dry_run = False
def __init__(
self,
intp_id,
project_id="prj-ext-dev-intervene-413412",
location="europe-west2",
):
super().__init__(intp_id=intp_id.lower())
self.project_id = project_id
self._work_bucket = f"{self.intp_id}-work"
self._results_bucket = f"{self.intp_id}-results"
self._location = location
self._work_bucket_existed_on_create = False
async def create_resources(self, job_model: JobModel):
"""Create some resources to run the job, including:
@@ -42,9 +59,41 @@ class GoogleResourceHandler(ResourceHandler):
- Render a helm chart
- Run helm install
"""
storage_client = storage.Client(project="prj-ext-dev-intervene-413412")
bucket: storage.bucket.Bucket = storage_client.bucket("intervene-test-bucket")
bucket.storage.bucket.SoftDeletePolicy(bucket, retention_duration_seconds=0)
self.make_buckets(job_model=job_model)
await helm_install(job_model=job_model)
async def destroy_resources(self):
# TODO: if the bucket exists already, we shouldn't destroy it in the error state
await helm_uninstall(
namespace="intervene-dev", release_name="helmvatti-1712756412"
)
self._delete_work_bucket()
def make_buckets(self, job_model: JobModel):
"""Create the buckets needed to run the job"""
self._make_work_bucket(job_model)
self._make_results_bucket(job_model)
def _make_work_bucket(self, job_model: JobModel):
"""Unfortunately google cloud storage doesn't support async
The work bucket has much stricter lifecycle policies than the results bucket
"""
client = storage.Client(project=self.project_id)
bucket: storage.bucket.Bucket = client.bucket(self._work_bucket)
if bucket.exists():
logger.critical(f"Bucket {self._work_bucket} exists!")
logger.critical(
"This bucket won't get cleaned up automatically by the error state"
)
self._work_bucket_existed_on_create = True
raise FileExistsError
bucket.add_lifecycle_abort_incomplete_multipart_upload_rule(age=1)
# these file suffixes are guaranteed to contain sensitive data
bucket.add_lifecycle_delete_rule(
age=1,
matches_suffix=[
@@ -59,27 +108,76 @@ class GoogleResourceHandler(ResourceHandler):
".gz",
],
)
bucket.create(location="europe-west2")
storage.bucket.SoftDeletePolicy(bucket, retention_duration_seconds=0)
storage_client.create_bucket(bucket, location="europe-west2")
await helm_install(job_model=job_model)
async def destroy_resources(self):
await helm_uninstall(
namespace="intervene-dev", release_name="helmvatti-1712756412"
)
# this is so dumb!
# if you init the SoftDeletePolicy with retention_duration_seconds then it never patches the bucket soft_delete_policy property
# the soft_delete_policy property has no setter
# instead init a minimal SoftDeletePolicy, then use the retention_duration_seconds property to make sure the setter is called and patches the bucket config
# took me way too long to figure this out
soft_policy = storage.bucket.SoftDeletePolicy(bucket)
soft_policy.retention_duration_seconds = 0
iam = storage.bucket.IAMConfiguration(bucket=bucket)
iam.public_access_prevention = "enforced"
bucket.create(location=self._location)
def _make_results_bucket(self, job_model: JobModel):
"""Unfortunately the google storage library doesn't support async"""
client = storage.Client(project=self.project_id)
bucket: storage.bucket.Bucket = client.bucket(self._results_bucket)
if bucket.exists():
logger.critical(f"Bucket {self._results_bucket} exists!")
raise FileExistsError
# results stay live for 7 days
bucket.add_lifecycle_delete_rule(age=7)
bucket.add_lifecycle_abort_incomplete_multipart_upload_rule(age=1)
# don't soft delete, it's annoying
soft_policy = storage.bucket.SoftDeletePolicy(bucket)
soft_policy.retention_duration_seconds = 0
iam = storage.bucket.IAMConfiguration(bucket=bucket)
iam.public_access_prevention = "enforced"
bucket.create(location=self._location)
def _delete_work_bucket(self):
# TODO: what if this is slow? it's not async!
if self._work_bucket_existed_on_create:
# don't delete a bucket that existed before the job was created
# otherwise a bad job will interfere with an existing good job
logger.warning(
"Work bucket existed during creation, so not deleting it to avoid modifying existing jobs"
)
return
client = storage.Client(project=self.project_id)
bucket = client.get_bucket(self._work_bucket)
if not bucket.exists():
logger.info("work bucket not found, so not deleting")
return
blobs = list(bucket.list_blobs())
if len(blobs) > 256:
logger.warning(f"Deleting a very big bucket: {len(blobs)} items")
for blob in blobs:
blob.delete()
logger.info(f"Deleting {bucket}")
bucket.delete(force=True)
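# Illustrative sketch (not part of this commit): the google-cloud-storage client
# is blocking, so the "it's not async" TODO above could be addressed by pushing
# the teardown onto a worker thread (asyncio.to_thread, Python 3.9+):
#
#     async def destroy_resources(self):
#         await helm_uninstall(
#             namespace="intervene-dev", release_name="helmvatti-1712756412"
#         )
#         # run the blocking bucket deletion without stalling the event loop
#         await asyncio.to_thread(self._delete_work_bucket)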
async def helm_install(job_model: JobModel):
if GoogleResourceHandler.dry_run:
logger.info("dry run enabled")
dry_run = "--dry-run"
logger.info("{dry_run} enabled")
else:
dry_run = ""
# TODO: add chart path and values file
cmd = "helm" # install -n intervene-dev {dry_run}"
cmd = f"helm # install -n intervene-dev {dry_run}"
proc = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
@@ -100,10 +198,11 @@ async def helm_render(job_model: JobModel):
async def helm_uninstall(release_name: str, namespace: str):
if GoogleResourceHandler.dry_run:
dry_run = "--dry-run"
logger.info(f"{dry_run} enabled")
else:
dry_run = ""
cmd = f"helm uninstall --namespace {namespace} {dry_run} {release_name}"
cmd = f"helm # uninstall --namespace {namespace} {dry_run} {release_name}"
proc = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE