Source code for datastep.step

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import getpass
import inspect
import json
import logging
import os
import warnings
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Union

import botocore
import git
import pandas as pd
import prefect
import quilt3
from prefect import Flow, Task

from . import constants, exceptions, file_utils, get_module_version, quilt_utils

###############################################################################

log = logging.getLogger(__name__)

###############################################################################


# Decorator for run that logs non-default args and kwargs to file
def log_run_params(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Get the params for the function, not the wrapper
        params = inspect.signature(func).bind(self, *args, **kwargs).arguments
        params.pop("self")

        # In case the operation is happening in a distributed fashion,
        # always make the local staging dir prior to run
        self.step_local_staging_dir.mkdir(parents=True, exist_ok=True)
        parameter_store = self.step_local_staging_dir / "run_parameters.json"

        # Dump run params
        with open(parameter_store, "w") as write_out:
            json.dump(params, write_out, default=str)
        log.debug(f"Stored params for run at: {parameter_store}")

        # Check if we want to clean the step local staging prior to run.
        # If the user has defined clean in their run function it will be in
        # the top level params; if they haven't, it will be in the kwargs.
        if "clean" in params and params["clean"]:
            file_utils._clean(self.step_local_staging_dir)
            log.info(f"Cleaned directory: {self.step_local_staging_dir}")
        elif "kwargs" in params:
            if "clean" in params["kwargs"] and params["kwargs"]["clean"]:
                file_utils._clean(self.step_local_staging_dir)
                log.info(f"Cleaned directory: {self.step_local_staging_dir}")

        return func(self, *args, **kwargs)

    return wrapper
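
# A minimal usage sketch of the decorator (``MyStep`` is a hypothetical
# subclass, not part of this module). Wrapping an overridden ``run`` records
# the call's arguments to ``run_parameters.json`` in the step's staging dir:
#
#     from datastep.step import Step, log_run_params
#
#     class MyStep(Step):
#         @log_run_params
#         def run(self, n_items: int = 10, **kwargs):
#             ...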

class Step(Task):
    """
    A class for creating "pure function" steps in a DAG.

    This object's sole purpose is to handle and enforce data logging tied to
    code using Quilt. It manages this data logging through heavy use of a
    local staging directory and supporting files such as the initialization
    parameters and a manifest CSV / Parquet in which you can record the files
    you will want to send to Quilt.

    Part of the problem with stepwise workflows is that their dependence on
    upstream data is hard to manage; the more you rely on this object, the
    easier those upstream dependencies become. If your upstream data
    dependencies are generated by other Step modules, you can place them in
    the downstream Step as "direct_upstream_tasks" and use the `Step.pull`
    function to retrieve their data.

    Parameters
    ----------
    step_name: Optional[str]
        A name for this step.
        Default: the lowercased version of the inheriting object name
    filepath_columns: List[str]
        In the final manifest CSV / Parquet you generate, which columns store
        filepaths.
        Default: ["filepath"]
    metadata_columns: List[str]
        In the final manifest CSV / Parquet you generate, which columns store
        metadata.
        Default: []
    direct_upstream_tasks: List[Step]
        If this task needs data that was generated by other Step objects in
        order to run, place references to those objects here; the `pull`
        method will then retrieve the required data.
    config: Optional[Union[str, Path, Dict[str, str]]]
        A path or dictionary detailing the entire workflow config.
        Refer to `datastep.constants` for details on workflow config defaults.
    """

    def _unpack_config(
        self, config: Optional[Union[str, Path, Dict[str, str]]] = None
    ):
        # If not provided, check for other places the config could live
        if config is None:
            # Check environment
            if constants.CONFIG_ENV_VAR_NAME in os.environ:
                config = os.environ[constants.CONFIG_ENV_VAR_NAME]

            # Check current working directory
            else:
                cwd = Path().resolve()
                cwd_files = [str(f.name) for f in cwd.iterdir()]

                # Attach config file name to cwd path
                if constants.CWD_CONFIG_FILE_NAME in cwd_files:
                    config = cwd / constants.CWD_CONFIG_FILE_NAME

        # Config should now either be a path to JSON, a Dict, or None
        if isinstance(config, (str, Path)):
            # Resolve path
            config = file_utils.resolve_filepath(config)

            # Read config
            with open(config, "r") as read_in:
                config = json.load(read_in)

        # Config should now either have been provided as a dict, parsed, or None
        if isinstance(config, dict):
            # Get or default storage bucket
            config["quilt_storage_bucket"] = config.get(
                "quilt_storage_bucket", constants.DEFAULT_QUILT_STORAGE
            )

            # Get or default package owner
            config["quilt_package_owner"] = config.get(
                "quilt_package_owner", constants.DEFAULT_QUILT_PACKAGE_OWNER
            )

            # Get or default project local staging
            config["project_local_staging_dir"] = file_utils.resolve_directory(
                config.get(
                    "project_local_staging_dir",
                    constants.DEFAULT_PROJECT_LOCAL_STAGING_DIR.format(cwd="."),
                ),
                make=True,
                strict=False,
            )

            # Get or default step local staging
            if self.step_name in config:
                config[self.step_name][
                    "step_local_staging_dir"
                ] = file_utils.resolve_directory(
                    config[self.step_name].get(
                        "step_local_staging_dir",
                        f"{config['project_local_staging_dir'] / self.step_name}",
                    ),
                    make=True,
                    strict=False,
                )
            else:
                # Step name wasn't in the config, add it as a key to a further dict
                config[self.step_name] = {}
                config[self.step_name][
                    "step_local_staging_dir"
                ] = file_utils.resolve_directory(
                    f"{config['project_local_staging_dir'] / self.step_name}",
                    make=True,
                    strict=False,
                )

            # Get or default quilt package name
            config["quilt_package_name"] = file_utils._sanitize_name(
                config.get("quilt_package_name", self.__module__.split(".")[0])
            )

            log.debug(f"Unpacked config: {config}")

        else:
            # Log debug message indicating using defaults
            log.debug("Using default project and step configuration.")

            # Construct config dictionary object
            config = {
                "quilt_storage_bucket": constants.DEFAULT_QUILT_STORAGE,
                "quilt_package_owner": constants.DEFAULT_QUILT_PACKAGE_OWNER,
                "quilt_package_name": self.__module__.split(".")[0],
                "project_local_staging_dir": file_utils.resolve_directory(
                    constants.DEFAULT_PROJECT_LOCAL_STAGING_DIR.format(cwd="."),
                    make=True,
                    strict=False,
                ),
                self.step_name: {
                    "step_local_staging_dir": file_utils.resolve_directory(
                        constants.DEFAULT_STEP_LOCAL_STAGING_DIR.format(
                            cwd=".", module_name=self.step_name
                        ),
                        make=True,
                        strict=False,
                    )
                },
            }

        # Set object properties from config
        self._storage_bucket = config["quilt_storage_bucket"]
        self._quilt_package_owner = config["quilt_package_owner"]
        self._quilt_package_name = config["quilt_package_name"]
        self._project_local_staging_dir = config["project_local_staging_dir"]
        self._step_local_staging_dir = config[self.step_name]["step_local_staging_dir"]

        return config

    def __init__(
        self,
        step_name: Optional[str] = None,
        filepath_columns: List[str] = ["filepath"],
        metadata_columns: List[str] = [],
        direct_upstream_tasks: List["Step"] = [],
        config: Optional[Union[str, Path, Dict[str, str]]] = None,
    ):
        # Run super prefect Task init
        super().__init__()

        # Set step name as attribute if not None
        self._step_name = (
            file_utils._sanitize_name(step_name)
            if step_name is not None
            else self.__class__.__name__.lower()
        )

        # Set kwargs as attributes
        self._upstream_tasks = direct_upstream_tasks
        self.filepath_columns = filepath_columns
        self.metadata_columns = metadata_columns

        # Prepare locals to be stored for data logging
        params = locals()
        params["step_name"] = self._step_name
        params.pop("self")
        params.pop("__class__")

        # Unpack config into param log dict
        params["config"] = self._unpack_config(config)

        # Store current version of datastep in initialization parameters
        params["__version__"] = get_module_version()

        # Write out initialization params for data logging
        parameter_store = self.step_local_staging_dir / "init_parameters.json"
        with open(parameter_store, "w") as write_out:
            json.dump(params, write_out, default=str)
        log.debug(f"Stored params for run at: {parameter_store}")

        # Attempt to read a previously written manifest produced by this step
        m_path = Path(self.step_local_staging_dir)

        # Check if a prior manifest exists
        if (m_path / "manifest.parquet").is_file():
            m_path = m_path / "manifest.parquet"
            self.manifest = pd.read_parquet(m_path)
            log.debug(f"Read previously produced manifest from file: {m_path}")
        elif (m_path / "manifest.csv").is_file():
            m_path = m_path / "manifest.csv"
            self.manifest = pd.read_csv(m_path)
            log.debug(f"Read previously produced manifest from file: {m_path}")
        else:
            self.manifest = None
            log.debug(f"No previous manifest found. Checked path: {m_path}")

        # Set name for prefect task retrieval
        self.name = self.step_name

        # Prior to any operation, log where we are operating
        log.info(
            f"{self.step_name} will use step local staging directory: "
            f"{self.step_local_staging_dir}"
        )

    @property
    def step_name(self) -> str:
        """
        Return the name of this step as a string.
        """
        return self._step_name

    @property
    def upstream_tasks(self) -> List[str]:
        warnings.warn(
            "To enforce that there is no reliance on object state during run "
            "functions, the upstream_tasks property will be deprecated on the "
            "next datastep release.",
            PendingDeprecationWarning,
        )
        return self._upstream_tasks

    @property
    def storage_bucket(self) -> str:
        warnings.warn(
            "To enforce that there is no reliance on object state during run "
            "functions, the storage_bucket property will be deprecated on the "
            "next datastep release.",
            PendingDeprecationWarning,
        )
        return self._storage_bucket

    @property
    def project_local_staging_dir(self) -> Path:
        warnings.warn(
            "To enforce that there is no reliance on object state during run "
            "functions, the project_local_staging_dir property will be deprecated "
            "on the next datastep release.",
            PendingDeprecationWarning,
        )
        return self._project_local_staging_dir

    @property
    def step_local_staging_dir(self) -> Path:
        """
        A preconfigured directory for you to store output files in.
        Can be specifically set using a workflow_config.json file.
        """
        return self._step_local_staging_dir

    @property
    def quilt_package_name(self) -> str:
        warnings.warn(
            "To enforce that there is no reliance on object state during run "
            "functions, the quilt_package_name property will be deprecated on the "
            "next datastep release.",
            PendingDeprecationWarning,
        )
        return self._quilt_package_name

    @property
    def quilt_package_owner(self) -> str:
        warnings.warn(
            "To enforce that there is no reliance on object state during run "
            "functions, the quilt_package_owner property will be deprecated on the "
            "next datastep release.",
            PendingDeprecationWarning,
        )
        return self._quilt_package_owner

    def run(
        self,
        distributed_executor_address: Optional[str] = None,
        clean: bool = False,
        debug: bool = False,
        **kwargs,
    ) -> Any:
        """
        Run a pure function.

        There are a few "protected" parameters, listed below:

        Parameters
        ----------
        distributed_executor_address: Optional[str]
            An optional executor address to pass to some computation engine.
        clean: bool
            Whether the local staging directory should be cleaned prior to
            this run.
            Default: False (do not clean)
        debug: bool
            A debug flag for the developer to use to manipulate how much data
            runs, how it is processed, etc.
            Default: False (do not debug)

        Returns
        -------
        result: Any
            A picklable object or value that is the result of any processing
            you do.
        """
        # Your code here
        #
        # The `self.step_local_staging_dir` is exposed to save files in
        #
        # The user should set `self.manifest` to a dataframe of absolute paths
        # that point to the created files and each file's metadata
        #
        # By default, `self.filepath_columns` is ["filepath"], but it should be
        # edited if there is more than a single column of filepaths
        #
        # By default, `self.metadata_columns` is [], but it should be edited to
        # include any columns that should be parsed for metadata and attached
        # to objects
        #
        # The user should not rely on object state to retrieve results from
        # prior steps, i.e. do not use the attribute `self.upstream_tasks` to
        # retrieve data. Pass the required path to a directory of files, the
        # path to a prior manifest, or in general, the exact parameters
        # required for this function to run.
        return
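
    # A minimal sketch of an overridden ``run`` (hypothetical subclass; the
    # file name and columns are illustrative):
    #
    #     class MyStep(Step):
    #         @log_run_params
    #         def run(self, **kwargs):
    #             out = self.step_local_staging_dir / "table.csv"
    #             pd.DataFrame({"a": [1, 2]}).to_csv(out, index=False)
    #             self.manifest = pd.DataFrame({"filepath": [str(out)]})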

    def get_result(self, state: prefect.engine.state.State, flow: Flow) -> Any:
        """
        Get the result of this step.

        Parameters
        ----------
        state: prefect.engine.state.State
            The final state object of a prefect flow produced by running the
            flow.
        flow: prefect.core.flow.Flow
            The flow that ran this step.

        Returns
        -------
        result: Any
            The resulting object from running this step in a flow.

        Notes
        -----
        This will always return the first item that matches this step. What
        this means for the user is that if this step was used in a mapped
        task, you would only receive the result of the first iteration of that
        map. Generally though, you shouldn't be using these steps in mapped
        tasks. (It's on our to-do list...)
        """
        return state.result[flow.get_tasks(name=self.step_name)[0]].result
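
    # A usage sketch with the prefect 0.x functional API (``MyStep`` is a
    # hypothetical subclass):
    #
    #     step = MyStep()
    #     with Flow("example") as flow:
    #         step()
    #     state = flow.run()
    #     result = step.get_result(state, flow)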

    def pull(self, data_version: Optional[str] = None, bucket: Optional[str] = None):
        """
        Pull all upstream data dependencies using the list of upstream steps.

        Parameters
        ----------
        data_version: Optional[str]
            Request a specific version of the upstream data.
            Default: 'latest' for all upstreams
        bucket: Optional[str]
            Request data from a specific bucket different from the bucket
            defined by your workflow_config.json or the defaulted bucket.
        """
        # Resolve None bucket
        if bucket is None:
            bucket = self._storage_bucket

        # Run checkout for each upstream
        for UpstreamTask in self._upstream_tasks:
            upstream_task = UpstreamTask()
            upstream_task.checkout(data_version=data_version, bucket=bucket)
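
    # A usage sketch (``Raw`` and ``Norm`` are hypothetical subclasses; because
    # ``Norm`` declares ``Raw`` as a direct upstream task, ``pull`` checks out
    # ``Raw``'s data into ``Raw``'s step local staging directory):
    #
    #     norm = Norm(direct_upstream_tasks=[Raw])
    #     norm.pull()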

    @staticmethod
    def _get_current_git_branch() -> str:
        repo = git.Repo(Path(".").expanduser().resolve())
        return repo.active_branch.name

    @staticmethod
    def _check_git_status_is_clean(push_target: str) -> None:
        # This will throw an error if the current working directory is not a git repo
        repo = git.Repo(Path(".").expanduser().resolve())
        current_branch = repo.active_branch.name

        # Check current git status
        if repo.is_dirty() or len(repo.untracked_files) > 0:
            dirty_files = [f.b_path for f in repo.index.diff(None)]
            all_changed_files = repo.untracked_files + dirty_files
            raise exceptions.InvalidGitStatus(
                f"Push to '{push_target}' was rejected because the current git "
                f"status of this branch ({current_branch}) is not clean. "
                f"Check files: {all_changed_files}."
            )

        # Check that the current hash is the same as the remote head hash.
        # First check that the current branch has even been pushed to origin.
        origin_branches = [b.name for b in repo.remotes.origin.refs]
        if f"origin/{current_branch}" not in origin_branches:
            raise exceptions.InvalidGitStatus(
                f"Push to '{push_target}' was rejected because the current git "
                f"branch was not found on the origin."
            )

        # Origin has current branch, check for matching commit hash
        else:
            # Find matching origin branch
            for origin_branch in repo.remotes.origin.refs:
                if origin_branch.name == f"origin/{current_branch}":
                    matching_origin_branch = origin_branch
                    break

            # Check git commit hash match
            if matching_origin_branch.commit.hexsha != repo.head.object.hexsha:
                raise exceptions.InvalidGitStatus(
                    f"Push to '{push_target}' was rejected because the current git "
                    f"commit has not been pushed to {matching_origin_branch.name}"
                )

    @staticmethod
    def _create_data_commit_message() -> str:
        # This will throw an error if the current working directory is not a git repo
        repo = git.Repo(Path(".").expanduser().resolve())
        current_branch = repo.active_branch.name
        return (
            f"data created from code repo {repo.remotes.origin.url} on branch "
            f"{current_branch} at commit {repo.head.object.hexsha}"
        )

    @staticmethod
    def _get_git_origin_url() -> str:
        # This will throw an error if the current working directory is not a git repo
        repo = git.Repo(Path(".").expanduser().resolve())

        # Get origin info
        origin = repo.remotes.origin

        # If there is an @ character, this remote was set up with ssh
        if "@" in origin.url:
            url = origin.url.split("@")[1].replace(":", "/").replace(".git", "")
            return f"https://{url}"
        else:
            return origin.url.replace(".git", "")

    @staticmethod
    def _get_current_git_commit_hash() -> str:
        # This will throw an error if the current working directory is not a git repo
        repo = git.Repo(Path(".").expanduser().resolve())
        return repo.head.object.hexsha
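
    # Examples of the ``_get_git_origin_url`` rewrite (illustrative URLs):
    #
    #     git@github.com:org/repo.git      ->  https://github.com/org/repo
    #     https://github.com/org/repo.git  ->  https://github.com/org/repo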

    def manifest_filepaths_rel2abs(self):
        """
        Convert manifest filepaths to absolute paths.
        Useful after you pull data from a remote bucket.
        """
        self.manifest = file_utils.manifest_filepaths_rel2abs(
            self.manifest, self.filepath_columns, self.step_local_staging_dir
        )

    def manifest_filepaths_abs2rel(self):
        """
        Convert manifest filepaths to relative paths.
        Useful when you are ready to upload to a remote bucket.
        """
        self.manifest = file_utils.manifest_filepaths_abs2rel(
            self.manifest, self.filepath_columns, self.step_local_staging_dir
        )
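
    # A round-trip sketch (hypothetical step instance; per the docstrings,
    # ``rel2abs`` is useful after pulling from a remote bucket and ``abs2rel``
    # when you are ready to upload to one):
    #
    #     step.checkout()
    #     step.manifest_filepaths_rel2abs()   # resolve against the staging dir
    #     ...                                 # work with the files
    #     step.manifest_filepaths_abs2rel()   # relativize again before upload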

    def checkout(
        self, data_version: Optional[str] = None, bucket: Optional[str] = None
    ):
        """
        Pull data previously generated by a run of this step.

        Parameters
        ----------
        data_version: Optional[str]
            Request a specific version of the prior generated data.
            Default: 'latest'
        bucket: Optional[str]
            Request data from a specific bucket different from the bucket
            defined by your workflow_config.json or the defaulted bucket.
        """
        # Resolve None bucket
        if bucket is None:
            bucket = self._storage_bucket

        # Get current git branch
        current_branch = self._get_current_git_branch()

        # Normalize branch name
        # This is to stop quilt from making extra directories from names like:
        # feature/some-feature
        current_branch = current_branch.replace("/", ".")

        # Checkout this step's output from quilt.
        # Check for files on this branch and default to master.

        # Browse top level project package
        quilt_loc = f"{self._quilt_package_owner}/{self._quilt_package_name}"
        p = quilt3.Package.browse(quilt_loc, bucket, top_hash=data_version)

        # Check to see if step data exists on this branch in quilt
        try:
            quilt_branch_step = f"{current_branch}/{self.step_name}"
            p[quilt_branch_step]

        # If not, use the version on master
        except KeyError:
            quilt_branch_step = f"master/{self.step_name}"
            p[quilt_branch_step]

        # Fetch the data and save it to the local staging dir
        p[quilt_branch_step].fetch(self.step_local_staging_dir)
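
    # A usage sketch (hypothetical step instance; ``data_version`` is passed
    # to ``quilt3.Package.browse`` as a top hash from the package's history):
    #
    #     step = MyStep()
    #     step.checkout()                         # latest data for this branch
    #     step.checkout(data_version="abc123")    # a specific top hash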

    def push(self, bucket: Optional[str] = None):
        """
        Push the most recently generated data.

        Parameters
        ----------
        bucket: Optional[str]
            Push data to a specific bucket different from the bucket defined
            by your workflow_config.json or the defaulted bucket.

        Notes
        -----
        If your git status isn't clean, or you haven't committed and pushed to
        origin, any attempt to push data will be rejected.
        """
        # Check if manifest is None
        if self.manifest is None:
            raise exceptions.PackagingError(
                "No manifest found to construct package with."
            )

        # Resolve None bucket
        if bucket is None:
            bucket = self._storage_bucket

        # Get current git branch
        current_branch = self._get_current_git_branch()

        # Normalize branch name
        # This is to stop quilt from making extra directories from names like:
        # feature/some-feature
        current_branch = current_branch.replace("/", ".")

        # Resolve push target
        quilt_loc = f"{self._quilt_package_owner}/{self._quilt_package_name}"
        push_target = f"{quilt_loc}/{current_branch}/{self.step_name}"

        # Check git status is clean
        self._check_git_status_is_clean(push_target)

        # Construct the package
        step_pkg, relative_manifest = quilt_utils.create_package(
            manifest=self.manifest,
            step_pkg_root=self.step_local_staging_dir,
            filepath_columns=self.filepath_columns,
            metadata_columns=self.metadata_columns,
        )

        # Add the relative manifest and generated README to the package
        with TemporaryDirectory() as tempdir:
            # Store the relative manifest in a temporary directory
            m_path = Path(tempdir) / "manifest.parquet"
            relative_manifest.to_parquet(m_path)
            step_pkg.set("manifest.parquet", m_path)

            # Add the params files to the package
            for param_file in ["run_parameters.json", "init_parameters.json"]:
                param_file_path = self.step_local_staging_dir / param_file
                step_pkg.set(param_file, param_file_path)

            # Generate README
            readme_path = Path(tempdir) / "README.md"
            with open(readme_path, "w") as write_readme:
                write_readme.write(
                    constants.README_TEMPLATE.render(
                        quilt_package_name=self._quilt_package_name,
                        source_url=self._get_git_origin_url(),
                        branch_name=self._get_current_git_branch(),
                        commit_hash=self._get_current_git_commit_hash(),
                        creator=getpass.getuser(),
                    )
                )
            step_pkg.set("README.md", readme_path)

            # Browse top level project package and add / overwrite to it in step dir
            try:
                project_pkg = quilt3.Package.browse(quilt_loc, self._storage_bucket)
            except botocore.errorfactory.ClientError:
                log.info(
                    f"Could not find existing package: {quilt_loc} "
                    f"in bucket: {self._storage_bucket}. "
                    f"Creating a new package."
                )
                project_pkg = quilt3.Package()

            # Regardless of whether we found a prior version of the package or
            # are starting from a new package, we "merge" them together to
            # place this step's data in the correct location.

            # Remove the current step if it exists in the previous project package
            if current_branch in project_pkg.keys():
                if self.step_name in project_pkg[current_branch].keys():
                    project_pkg = project_pkg.delete(
                        f"{current_branch}/{self.step_name}"
                    )

            # Merge packages
            for logical_key, pkg_entry in step_pkg.walk():
                project_pkg.set(
                    f"{current_branch}/{self.step_name}/{logical_key}", pkg_entry
                )

            # Push the data
            project_pkg.push(
                quilt_loc,
                registry=self._storage_bucket,
                message=self._create_data_commit_message(),
            )
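
    # A usage sketch (hypothetical step instance; requires a clean git status
    # that has been pushed to origin and a populated ``self.manifest``):
    #
    #     step = MyStep()
    #     step.run()
    #     step.push()                             # package and upload to quilt
    #     step.push(bucket="s3://other-bucket")   # or target another bucket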

    def clean(self) -> None:
        """
        Completely reset this step's local staging directory by removing all
        previously generated files.
        """
        file_utils._clean(self.step_local_staging_dir)

    def __str__(self):
        return (
            f"<{self.step_name} [ "
            f"upstream_tasks: {self._upstream_tasks}, "
            f"storage_bucket: '{self._storage_bucket}', "
            f"project_local_staging_dir: '{self._project_local_staging_dir}', "
            f"step_local_staging_dir: '{self.step_local_staging_dir}' "
            f"]>"
        )

    def __repr__(self):
        return str(self)