#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This script will run all tasks in a prefect Flow.
When you add steps to you step workflow be sure to add them to the step list
and configure their IO in the `run` function.
"""
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional
from dask_jobqueue import SLURMCluster
from distributed import LocalCluster
from prefect import Flow
from prefect.engine.executors import DaskExecutor, LocalExecutor
from actk import steps
###############################################################################
log = logging.getLogger(__name__)
###############################################################################
[docs]class All:
def __init__(self):
"""
Set all of your available steps here.
This is only used for data logging operations, not computation purposes.
"""
self.step_list = [
steps.StandardizeFOVArray(),
steps.SingleCellFeatures(),
steps.SingleCellImages(),
steps.DiagnosticSheets(),
]
[docs] def run(
self,
dataset: str,
include_raw: bool = False,
batch_size: Optional[int] = None,
distributed: bool = False,
n_workers: int = 10,
worker_cpu: int = 8,
worker_mem: str = "120GB",
overwrite: bool = False,
debug: bool = False,
**kwargs,
):
"""
Run a flow with your steps.
Parameters
----------
dataset: str
The dataset to use for the pipeline.
include_raw: bool
A boolean option to determine if the raw data should be included in the
Quilt package.
Default: False (Do not include the raw data)
batch_size: Optional[int]
An optional batch size to provide to each step for processing their items.
Default: None (auto batch size depending on CPU / threads available)
distributed: bool
A boolean option to determine if the jobs should be distributed to a SLURM
cluster when possible.
Default: False (Do not distribute)
n_workers: int
Number of workers to request (when distributed is enabled).
Default: 10
worker_cpu: int
Number of cores to provide per worker (when distributed is enabled).
Default: 8
worker_mem: str
Amount of memory to provide per worker (when distributed is enabled).
Default: 120GB
overwrite: bool
If this pipeline has already partially or completely run, should it
overwrite the previous files or not.
Default: False (Do not overwrite or regenerate files)
debug: bool
A debug flag for the developer to use to manipulate how much data runs,
how it is processed, etc. Additionally, if debug is True, any mapped
operation will run on threads instead of processes.
Default: False (Do not debug)
"""
# Initalize steps
raw = steps.Raw()
standardize_fov_array = steps.StandardizeFOVArray()
single_cell_features = steps.SingleCellFeatures()
single_cell_images = steps.SingleCellImages()
diagnostic_sheets = steps.DiagnosticSheets()
# Cluster / distributed defaults
distributed_executor_address = None
# Choose executor
if debug:
exe = LocalExecutor()
log.info("Debug flagged. Will use threads instead of Dask.")
else:
if distributed:
# Create or get log dir
# Do not include ms
log_dir_name = datetime.now().isoformat().split(".")[0]
log_dir = Path(f".dask_logs/{log_dir_name}").expanduser()
# Log dir settings
log_dir.mkdir(parents=True, exist_ok=True)
# Create cluster
log.info("Creating SLURMCluster")
cluster = SLURMCluster(
cores=worker_cpu,
memory=worker_mem,
queue="aics_cpu_general",
walltime="9-23:00:00",
local_directory=str(log_dir),
log_directory=str(log_dir),
)
# Spawn workers
cluster.scale(jobs=n_workers)
log.info("Created SLURMCluster")
# Use the port from the created connector to set executor address
distributed_executor_address = cluster.scheduler_address
# Only auto batch size if it is not None
if batch_size is None:
# Batch size is n_workers * worker_cpu * 0.75
# We could just do n_workers * worker_cpu but 3/4 of that is safer
batch_size = int(n_workers * worker_cpu * 0.75)
# Log dashboard URI
log.info(f"Dask dashboard available at: {cluster.dashboard_link}")
else:
# Create local cluster
log.info("Creating LocalCluster")
cluster = LocalCluster()
log.info("Created LocalCluster")
# Set distributed_executor_address
distributed_executor_address = cluster.scheduler_address
# Log dashboard URI
log.info(f"Dask dashboard available at: {cluster.dashboard_link}")
# Use dask cluster
exe = DaskExecutor(distributed_executor_address)
# Configure your flow
with Flow("actk") as flow:
if include_raw:
dataset = raw(dataset, **kwargs)
standardized_fov_paths_dataset = standardize_fov_array(
dataset=dataset,
distributed_executor_address=distributed_executor_address,
batch_size=batch_size,
overwrite=overwrite,
debug=debug,
# Allows us to pass `--desired_pixel_sizes [{float},{float},{float}]`
**kwargs,
)
single_cell_features_dataset = single_cell_features(
dataset=standardized_fov_paths_dataset,
distributed_executor_address=distributed_executor_address,
batch_size=batch_size,
overwrite=overwrite,
debug=debug,
# Allows us to pass `--cell_ceiling_adjustment {int}`
**kwargs,
)
single_cell_images_dataset = single_cell_images(
dataset=single_cell_features_dataset,
distributed_executor_address=distributed_executor_address,
batch_size=batch_size,
overwrite=overwrite,
debug=debug,
# Allows us to pass `--cell_ceiling_adjustment {int}`
**kwargs,
)
diagnostic_sheets(
dataset=single_cell_images_dataset,
distributed_executor_address=distributed_executor_address,
overwrite=overwrite,
# Allows us to pass `--metadata {str}`,
# `--feature {str}'`
**kwargs,
)
# Run flow and get ending state, log duration
start = datetime.now()
state = flow.run(executor=exe)
duration = datetime.now() - start
log.info(
f"Total duration of pipeline: "
f"{duration.seconds // 60 // 60}:"
f"{duration.seconds // 60}:"
f"{duration.seconds % 60}"
)
# Get and display any outputs you want to see on your local terminal
log.info(single_cell_images_dataset.get_result(state, flow))
[docs] def pull(self):
"""
Pull all steps.
"""
for step in self.step_list:
step.pull()
[docs] def checkout(self):
"""
Checkout all steps.
"""
for step in self.step_list:
step.checkout()
[docs] def push(self):
"""
Push all steps.
"""
for step in self.step_list:
step.push()
[docs] def clean(self):
"""
Clean all steps.
"""
for step in self.step_list:
step.clean()