
Module taskcat

taskcat python module


View Source
"""

taskcat python module

"""

from ._cfn.stack import Stack  # noqa: F401

from ._cfn.template import Template  # noqa: F401

from ._cli import main  # noqa: F401

from ._config import Config  # noqa: F401

__all__ = ["Stack", "Template", "Config", "main"]
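
A minimal sketch of importing the public API re-exported here (assumes taskcat is installed):

from taskcat import Config, Stack, Template, main  # the names listed in __all__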

Functions

main

def main(
    cli_core_class=CliCore,
    exit_func=exit_with_code
)
View Source
def main(cli_core_class=CliCore, exit_func=exit_with_code):

    signal.signal(signal.SIGINT, _sigint_handler)

    log_level = _setup_logging(sys.argv)

    args = sys.argv[1:]

    if not args:

        args.append("-h")

    try:

        _welcome()

        version = get_installed_version()

        cli = cli_core_class(NAME, _cli_modules, DESCRIPTION, version, GLOBAL_ARGS.ARGS)

        cli.parse(args)

        _default_profile = cli.parsed_args.__dict__.get("_profile")

        if _default_profile:

            GLOBAL_ARGS.profile = _default_profile

        cli.run()

    except TaskCatException as e:

        LOG.error(str(e), exc_info=_print_tracebacks(log_level))

        exit_func(1)

    except Exception as e:  # pylint: disable=broad-except

        LOG.error(

            "%s %s", e.__class__.__name__, str(e), exc_info=_print_tracebacks(log_level)

        )

        exit_func(1)
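
A hedged sketch of driving the CLI entry point from Python. main() parses sys.argv itself, so the argument vector is set before the call; the example assumes it runs from a directory containing a taskcat project:

import sys

from taskcat import main

# Equivalent to running `taskcat lint` on the command line; on failure main()
# calls exit_func(1) (exit_with_code by default).
sys.argv = ["taskcat", "lint"]
main()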

Classes

Config

class Config(
    sources: list,
    uid: uuid.UUID,
    project_root: pathlib.Path
)
View Source
class Config:

    def __init__(self, sources: list, uid: uuid.UUID, project_root: Path):

        self.config = BaseConfig.from_dict(DEFAULTS)

        self.config.set_source("TASKCAT_DEFAULT")

        self.project_root = project_root

        self.uid = uid

        for source in sources:

            config_dict: dict = source["config"]

            source_name: str = source["source"]

            source_config = BaseConfig.from_dict(config_dict)

            source_config.set_source(source_name)

            self.config = BaseConfig.merge(self.config, source_config)

    @classmethod

    # pylint: disable=too-many-locals

    def create(

        cls,

        template_file: Optional[Path] = None,

        args: Optional[dict] = None,

        global_config_path: Path = GENERAL,

        project_config_path: Path = PROJECT,

        overrides_path: Path = OVERRIDES,

        env_vars: Optional[dict] = None,

        project_root: Path = PROJECT_ROOT,

        uid: uuid.UUID = None,

    ) -> "Config":

        uid = uid if uid else uuid.uuid4()

        project_source = cls._get_project_source(

            cls, project_config_path, project_root, template_file

        )

        # general

        legacy_overrides(

            Path("~/.aws/taskcat_global_override.json").expanduser().resolve(),

            global_config_path,

            "global",

        )

        sources = [

            {

                "source": str(global_config_path),

                "config": cls._dict_from_file(global_config_path),

            }

        ]

        # project config file

        if project_source:

            sources.append(project_source)

        # template file

        if isinstance(template_file, Path):

            sources.append(

                {

                    "source": str(template_file),

                    "config": cls._dict_from_template(template_file),

                }

            )

        # override file

        legacy_overrides(

            project_root / "ci/taskcat_project_override.json", overrides_path, "project"

        )

        if overrides_path.is_file():

            overrides = BaseConfig().to_dict()

            with open(str(overrides_path), "r", encoding="utf-8") as file_handle:

                override_params = yaml.safe_load(file_handle)

            overrides["project"]["parameters"] = override_params

            sources.append({"source": str(overrides_path), "config": overrides})

        # environment variables

        sources.append(

            {

                "source": "EnvoronmentVariable",

                "config": cls._dict_from_env_vars(env_vars),

            }

        )

        # cli arguments

        if args:

            sources.append({"source": "CliArgument", "config": args})

        return cls(sources=sources, uid=uid, project_root=project_root)

    # pylint: disable=protected-access,inconsistent-return-statements

    @staticmethod

    def _get_project_source(base_cls, project_config_path, project_root, template_file):

        try:

            return {

                "source": str(project_config_path),

                "config": base_cls._dict_from_file(project_config_path, fail_ok=False),

            }

        except FileNotFoundError as e:

            error = e

            try:

                legacy_conf = parse_legacy_config(project_root)

                return {

                    "source": str(project_root / "ci/taskcat.yml"),

                    "config": legacy_conf.to_dict(),

                }

            except Exception as e:  # pylint: disable=broad-except

                LOG.debug(str(e), exc_info=True)

                if not template_file:

                    # pylint: disable=raise-missing-from

                    raise error

    @staticmethod

    def _dict_from_file(file_path: Path, fail_ok=True) -> dict:

        config_dict = BaseConfig().to_dict()

        if not file_path.is_file() and fail_ok:

            return config_dict

        try:

            with open(str(file_path), "r", encoding="utf-8") as file_handle:

                config_dict = yaml.safe_load(file_handle)

            return config_dict

        except Exception as e:  # pylint: disable=broad-except

            LOG.warning(f"failed to load config from {file_path}")

            LOG.debug(str(e), exc_info=True)

            if not fail_ok:

                raise e

        return config_dict

    @staticmethod

    def _dict_from_template(file_path: Path) -> dict:

        relative_path = str(file_path.relative_to(PROJECT_ROOT))

        config_dict = (

            BaseConfig()

            .from_dict(

                {"project": {"template": relative_path}, "tests": {"default": {}}}

            )

            .to_dict()

        )

        if not file_path.is_file():

            raise TaskCatException(f"invalid template path {file_path}")

        try:

            template = Template(

                str(file_path), template_cache=tcat_template_cache

            ).template

        except Exception as e:

            LOG.warning(f"failed to load template from {file_path}")

            LOG.debug(str(e), exc_info=True)

            raise e

        if not template.get("Metadata"):

            return config_dict

        if not template["Metadata"].get("taskcat"):

            return config_dict

        template_config_dict = template["Metadata"]["taskcat"]

        if not template_config_dict.get("project"):

            template_config_dict["project"] = {}

        template_config_dict["project"]["template"] = relative_path

        if not template_config_dict.get("tests"):

            template_config_dict["tests"] = {"default": {}}

        return template_config_dict

    # pylint: disable=protected-access

    @staticmethod

    def _dict_from_env_vars(

        env_vars: Optional[Union[os._Environ, Dict[str, str]]] = None

    ):

        if env_vars is None:

            env_vars = os.environ

        config_dict: Dict[str, Dict[str, Union[str, bool, int]]] = {}

        for key, value in env_vars.items():

            if key.startswith("TASKCAT_"):

                key = key[8:].lower()

                sub_key = None

                key_section = None

                for section in ["general", "project", "tests"]:

                    if key.startswith(section):

                        sub_key = key[len(section) + 1 :]

                        key_section = section

                if isinstance(sub_key, str) and isinstance(key_section, str):

                    if value.isnumeric():

                        value = int(value)

                    elif value.lower() in ["true", "false"]:

                        value = value.lower() == "true"

                    if not config_dict.get(key_section):

                        config_dict[key_section] = {}

                    config_dict[key_section][sub_key] = value

        return config_dict

    def _get_regions(self, region_parameter_name, test, boto3_cache: Boto3Cache = None):

        if boto3_cache is None:

            boto3_cache = Boto3Cache()

        region_object = {}

        for region in getattr(test, region_parameter_name, []):

            # TODO: comon_utils/determine_profile_for_region

            profile = (

                test.auth.get(region, test.auth.get("default", "default"))

                if test.auth

                else "default"

            )

            region_object[region] = RegionObj(

                name=region,

                account_id=boto3_cache.account_id(profile),

                partition=boto3_cache.partition(profile),

                profile=profile,

                _boto3_cache=boto3_cache,

                taskcat_id=self.uid,

                _role_name=test.role_name,

            )

        return region_object

    def get_regions(self, boto3_cache: Boto3Cache = None):

        region_objects: Dict[str, Dict[str, RegionObj]] = {}

        for test_name, test in self.config.tests.items():

            region_objects[test_name] = self._get_regions("regions", test, boto3_cache)

        return region_objects

    def get_artifact_regions(self, boto3_cache: Boto3Cache = None):

        region_objects: Dict[str, Dict[str, RegionObj]] = {}

        for test_name, test in self.config.tests.items():

            if test.artifact_regions is not None:

                region_objects[test_name] = self._get_regions(

                    "artifact_regions", test, boto3_cache

                )

            else:

                region_objects[test_name] = self._get_regions(

                    "regions", test, boto3_cache

                )

        return region_objects

    def get_buckets(self, boto3_cache: Boto3Cache = None):

        regions = self.get_artifact_regions(boto3_cache)

        bucket_objects: Dict[str, S3BucketObj] = {}

        bucket_mappings: Dict[str, Dict[str, S3BucketObj]] = {}

        for test_name, test in self.config.tests.items():

            bucket_mappings[test_name] = {}

            for region_name, region in regions[test_name].items():

                if test.s3_regional_buckets:

                    bucket_obj = self._create_regional_bucket_obj(

                        bucket_objects, region, test

                    )

                    bucket_objects[f"{region.account_id}{region.name}"] = bucket_obj

                else:

                    bucket_obj = self._create_legacy_bucket_obj(

                        bucket_objects, region, test

                    )

                    bucket_objects[region.account_id] = bucket_obj

                bucket_mappings[test_name][region_name] = bucket_obj

        return bucket_mappings

    def _create_legacy_bucket_obj(self, bucket_objects, region, test):

        new = False

        object_acl = (

            self.config.project.s3_object_acl

            if self.config.project.s3_object_acl

            else "private"

        )

        sigv4 = not self.config.project.s3_enable_sig_v2

        org_id = self.config.project.org_id

        if not test.s3_bucket and not bucket_objects.get(region.account_id):

            name = generate_bucket_name(self.config.project.name)

            auto_generated = True

            new = True

        elif bucket_objects.get(region.account_id):

            name = bucket_objects[region.account_id].name

            auto_generated = bucket_objects[region.account_id].auto_generated

        else:

            name = test.s3_bucket

            auto_generated = False

        bucket_region = self._get_bucket_region_for_partition(region.partition)

        bucket_obj = S3BucketObj(

            name=name,

            region=bucket_region,

            account_id=region.account_id,

            s3_client=region.session.client("s3", region_name=bucket_region),

            auto_generated=auto_generated,

            object_acl=object_acl,

            sigv4=sigv4,

            taskcat_id=self.uid,

            partition=region.partition,

            regional_buckets=test.s3_regional_buckets,

            org_id=org_id,

        )

        if new:

            bucket_obj.create()

        return bucket_obj

    def _create_regional_bucket_obj(self, bucket_objects, region, test):

        _bucket_obj_key = f"{region.account_id}{region.name}"

        new = False

        object_acl = (

            self.config.project.s3_object_acl

            if self.config.project.s3_object_acl

            else "private"

        )

        sigv4 = not self.config.project.s3_enable_sig_v2

        org_id = self.config.project.org_id

        if not test.s3_bucket and not bucket_objects.get(_bucket_obj_key):

            name = generate_regional_bucket_name(region)

            auto_generated = True

            new = True

        elif bucket_objects.get(_bucket_obj_key):

            name = bucket_objects[_bucket_obj_key].name

            auto_generated = bucket_objects[_bucket_obj_key].auto_generated

        else:

            name = f"{test.s3_bucket}-{region.name}"

            auto_generated = False

            try:

                region.client("s3").head_bucket(Bucket=name)

            except ClientError as e:

                if "(404)" in str(e):

                    new = True

                else:

                    raise

        bucket_obj = S3BucketObj(

            name=name,

            region=region.name,

            account_id=region.account_id,

            s3_client=region.session.client("s3", region_name=region.name),

            auto_generated=auto_generated,

            object_acl=object_acl,

            sigv4=sigv4,

            taskcat_id=self.uid,

            partition=region.partition,

            regional_buckets=test.s3_regional_buckets,

            org_id=org_id,

        )

        if new:

            bucket_obj.create()

        return bucket_obj

    @staticmethod

    def _get_bucket_region_for_partition(partition):

        region = "us-east-1"

        if partition == "aws-us-gov":

            region = "us-gov-east-1"

        elif partition == "aws-cn":

            region = "cn-north-1"

        return region

    def get_rendered_parameters(self, bucket_objects, region_objects, template_objects):

        parameters = {}

        template_params = self.get_params_from_templates(template_objects)

        for test_name, test in self.config.tests.items():

            parameters[test_name] = {}

            for region_name in test.regions:

                region_params = template_params[test_name].copy()

                for param_key, param_value in test.parameters.items():

                    if param_key in region_params:

                        region_params[param_key] = param_value

                region = region_objects[test_name][region_name]

                s3bucket = bucket_objects[test_name][region_name]

                parameters[test_name][region_name] = ParamGen(

                    region_params,

                    s3bucket.name,

                    region.name,

                    region.client,

                    self.config.project.name,

                    test_name,

                    test.az_blacklist,

                ).results

        return parameters

    @staticmethod

    def get_params_from_templates(template_objects):

        parameters = {}

        for test_name, template in template_objects.items():

            parameters[test_name] = template.parameters()

        return parameters

    def get_templates(self):

        templates = {}

        for test_name, test in self.config.tests.items():

            templates[test_name] = Template(

                template_path=self.project_root / test.template,

                project_root=self.project_root,

                s3_key_prefix=f"{self.config.project.name}/",

                template_cache=tcat_template_cache,

            )

        return templates

    def get_tests(self, templates, regions, buckets, parameters):

        tests = {}

        for test_name, test in self.config.tests.items():

            region_list = []

            artifact_region_list = []

            tag_list = []

            if test.tags:

                for tag_key, tag_value in test.tags.items():

                    tag_list.append(Tag({"Key": tag_key, "Value": tag_value}))

            for region_obj in regions[test_name].values():

                region_list.append(

                    TestRegion.from_region_obj(

                        region_obj,

                        buckets[test_name][region_obj.name],

                        parameters[test_name][region_obj.name],

                    )

                )

            tests[test_name] = TestObj(

                name=test_name,

                template_path=self.project_root / test.template,

                template=templates[test_name],

                project_root=self.project_root,

                regions=region_list,

                artifact_regions=artifact_region_list,

                tags=tag_list,

                uid=self.uid,

                _project_name=self.config.project.name,

                _shorten_stack_name=self.config.project.shorten_stack_name,

                _stack_name=test.stack_name,

                _stack_name_prefix=test.stack_name_prefix,

                _stack_name_suffix=test.stack_name_suffix,

            )

        return tests

Static methods

create

def create(
    template_file: Union[pathlib.Path, NoneType] = None,
    args: Union[dict, NoneType] = None,
    global_config_path: pathlib.Path = GENERAL,
    project_config_path: pathlib.Path = PROJECT,
    overrides_path: pathlib.Path = OVERRIDES,
    env_vars: Union[dict, NoneType] = None,
    project_root: pathlib.Path = PROJECT_ROOT,
    uid: uuid.UUID = None
) -> 'Config'
View Source
    @classmethod

    # pylint: disable=too-many-locals

    def create(

        cls,

        template_file: Optional[Path] = None,

        args: Optional[dict] = None,

        global_config_path: Path = GENERAL,

        project_config_path: Path = PROJECT,

        overrides_path: Path = OVERRIDES,

        env_vars: Optional[dict] = None,

        project_root: Path = PROJECT_ROOT,

        uid: uuid.UUID = None,

    ) -> "Config":

        uid = uid if uid else uuid.uuid4()

        project_source = cls._get_project_source(

            cls, project_config_path, project_root, template_file

        )

        # general

        legacy_overrides(

            Path("~/.aws/taskcat_global_override.json").expanduser().resolve(),

            global_config_path,

            "global",

        )

        sources = [

            {

                "source": str(global_config_path),

                "config": cls._dict_from_file(global_config_path),

            }

        ]

        # project config file

        if project_source:

            sources.append(project_source)

        # template file

        if isinstance(template_file, Path):

            sources.append(

                {

                    "source": str(template_file),

                    "config": cls._dict_from_template(template_file),

                }

            )

        # override file

        legacy_overrides(

            project_root / "ci/taskcat_project_override.json", overrides_path, "project"

        )

        if overrides_path.is_file():

            overrides = BaseConfig().to_dict()

            with open(str(overrides_path), "r", encoding="utf-8") as file_handle:

                override_params = yaml.safe_load(file_handle)

            overrides["project"]["parameters"] = override_params

            sources.append({"source": str(overrides_path), "config": overrides})

        # environment variables

        sources.append(

            {

                "source": "EnvoronmentVariable",

                "config": cls._dict_from_env_vars(env_vars),

            }

        )

        # cli arguments

        if args:

            sources.append({"source": "CliArgument", "config": args})

        return cls(sources=sources, uid=uid, project_root=project_root)
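
A minimal sketch of building a Config for a project; the path is hypothetical and assumes a .taskcat.yml at the project root:

from pathlib import Path

from taskcat import Config

project_root = Path("./my-project")  # hypothetical project location
config = Config.create(
    project_root=project_root,
    project_config_path=project_root / ".taskcat.yml",
)
print(config.uid)
print(config.config.project.name)  # merged project settings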

get_params_from_templates

def get_params_from_templates(
    template_objects
)
View Source
    @staticmethod

    def get_params_from_templates(template_objects):

        parameters = {}

        for test_name, template in template_objects.items():

            parameters[test_name] = template.parameters()

        return parameters

Methods

get_artifact_regions

def get_artifact_regions(
    self,
    boto3_cache: taskcat._client_factory.Boto3Cache = None
)
View Source
    def get_artifact_regions(self, boto3_cache: Boto3Cache = None):

        region_objects: Dict[str, Dict[str, RegionObj]] = {}

        for test_name, test in self.config.tests.items():

            if test.artifact_regions is not None:

                region_objects[test_name] = self._get_regions(

                    "artifact_regions", test, boto3_cache

                )

            else:

                region_objects[test_name] = self._get_regions(

                    "regions", test, boto3_cache

                )

        return region_objects

get_buckets

def get_buckets(
    self,
    boto3_cache: taskcat._client_factory.Boto3Cache = None
)
View Source
    def get_buckets(self, boto3_cache: Boto3Cache = None):

        regions = self.get_artifact_regions(boto3_cache)

        bucket_objects: Dict[str, S3BucketObj] = {}

        bucket_mappings: Dict[str, Dict[str, S3BucketObj]] = {}

        for test_name, test in self.config.tests.items():

            bucket_mappings[test_name] = {}

            for region_name, region in regions[test_name].items():

                if test.s3_regional_buckets:

                    bucket_obj = self._create_regional_bucket_obj(

                        bucket_objects, region, test

                    )

                    bucket_objects[f"{region.account_id}{region.name}"] = bucket_obj

                else:

                    bucket_obj = self._create_legacy_bucket_obj(

                        bucket_objects, region, test

                    )

                    bucket_objects[region.account_id] = bucket_obj

                bucket_mappings[test_name][region_name] = bucket_obj

        return bucket_mappings

get_regions

def get_regions(
    self,
    boto3_cache: taskcat._client_factory.Boto3Cache = None
)
View Source
    def get_regions(self, boto3_cache: Boto3Cache = None):

        region_objects: Dict[str, Dict[str, RegionObj]] = {}

        for test_name, test in self.config.tests.items():

            region_objects[test_name] = self._get_regions("regions", test, boto3_cache)

        return region_objects

get_rendered_parameters

def get_rendered_parameters(
    self,
    bucket_objects,
    region_objects,
    template_objects
)
View Source
    def get_rendered_parameters(self, bucket_objects, region_objects, template_objects):

        parameters = {}

        template_params = self.get_params_from_templates(template_objects)

        for test_name, test in self.config.tests.items():

            parameters[test_name] = {}

            for region_name in test.regions:

                region_params = template_params[test_name].copy()

                for param_key, param_value in test.parameters.items():

                    if param_key in region_params:

                        region_params[param_key] = param_value

                region = region_objects[test_name][region_name]

                s3bucket = bucket_objects[test_name][region_name]

                parameters[test_name][region_name] = ParamGen(

                    region_params,

                    s3bucket.name,

                    region.name,

                    region.client,

                    self.config.project.name,

                    test_name,

                    test.az_blacklist,

                ).results

        return parameters

get_templates

def get_templates(
    self
)
View Source
    def get_templates(self):

        templates = {}

        for test_name, test in self.config.tests.items():

            templates[test_name] = Template(

                template_path=self.project_root / test.template,

                project_root=self.project_root,

                s3_key_prefix=f"{self.config.project.name}/",

                template_cache=tcat_template_cache,

            )

        return templates

get_tests

def get_tests(
    self,
    templates,
    regions,
    buckets,
    parameters
)
View Source
    def get_tests(self, templates, regions, buckets, parameters):

        tests = {}

        for test_name, test in self.config.tests.items():

            region_list = []

            artifact_region_list = []

            tag_list = []

            if test.tags:

                for tag_key, tag_value in test.tags.items():

                    tag_list.append(Tag({"Key": tag_key, "Value": tag_value}))

            for region_obj in regions[test_name].values():

                region_list.append(

                    TestRegion.from_region_obj(

                        region_obj,

                        buckets[test_name][region_obj.name],

                        parameters[test_name][region_obj.name],

                    )

                )

            tests[test_name] = TestObj(

                name=test_name,

                template_path=self.project_root / test.template,

                template=templates[test_name],

                project_root=self.project_root,

                regions=region_list,

                artifact_regions=artifact_region_list,

                tags=tag_list,

                uid=self.uid,

                _project_name=self.config.project.name,

                _shorten_stack_name=self.config.project.shorten_stack_name,

                _stack_name=test.stack_name,

                _stack_name_prefix=test.stack_name_prefix,

                _stack_name_suffix=test.stack_name_suffix,

            )

        return tests
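
The accessor methods above are typically chained into one pipeline. A hedged sketch, assuming valid AWS credentials and a config built with Config.create() as in the earlier sketch (note that get_buckets() may create S3 buckets in the target accounts):

templates = config.get_templates()
regions = config.get_regions()
buckets = config.get_buckets()  # may create per-account or per-region buckets
parameters = config.get_rendered_parameters(buckets, regions, templates)
tests = config.get_tests(templates, regions, buckets, parameters)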

Stack

class Stack(
    region: taskcat._dataclasses.TestRegion,
    stack_id: str,
    template: taskcat._cfn.template.Template,
    test_name,
    uuid: uuid.UUID = None
)
View Source
class Stack:  # pylint: disable=too-many-instance-attributes

    REMOTE_TEMPLATE_PATH = Path(".taskcat/.remote_templates")

    def __init__(

        self,

        region: TestRegion,

        stack_id: str,

        template: Template,

        test_name,

        uuid: UUID = None,

    ):

        uuid = uuid if uuid else uuid4()

        self.test_name: str = test_name

        self.uuid: UUID = uuid

        self.id: str = stack_id

        self.template: Template = template

        self.name: str = self._get_name()

        self.region: TestRegion = region

        self.region_name = region.name

        self.client: boto3.client = region.client("cloudformation")

        self.completion_time: timedelta = timedelta(0)

        self.role_arn = region.role_arn

        # properties from additional cfn api calls

        self._events: Events = Events()

        self._resources: Resources = Resources()

        self._children: Stacks = Stacks()

        # properties from describe_stacks response

        self.change_set_id: str = ""

        self.parameters: List[Parameter] = []

        self.creation_time: datetime = datetime.fromtimestamp(0)

        self.deletion_time: datetime = datetime.fromtimestamp(0)

        self._status: str = ""

        self.status_reason: str = ""

        self.disable_rollback: bool = False

        self.timeout_in_minutes: int = 0

        self.capabilities: List[str] = []

        self.outputs: List[Output] = []

        self.tags: List[Tag] = []

        self.parent_id: str = ""

        self.root_id: str = ""

        self._launch_succeeded: bool = False

        self._auto_refresh_interval: timedelta = timedelta(seconds=60)

        self._last_event_refresh: datetime = datetime.fromtimestamp(0)

        self._last_resource_refresh: datetime = datetime.fromtimestamp(0)

        self._last_child_refresh: datetime = datetime.fromtimestamp(0)

        self._timer = Timer(self._auto_refresh_interval.total_seconds(), self.refresh)

        self._timer.start()

    def __str__(self):

        return self.id

    def __repr__(self):

        return "<Stack object {} at {}>".format(self.name, hex(id(self)))

    def _get_region(self) -> str:

        return self.id.split(":")[3]

    def _get_name(self) -> str:

        return self.id.split(":")[5].split("/")[1]

    def _auto_refresh(self, last_refresh):

        if datetime.now() - last_refresh > self._auto_refresh_interval:

            return True

        return False

    @property

    def status(self):

        if self._status in StackStatus.COMPLETE:

            if not self.launch_succeeded:

                self._status = "OUT_OF_ORDER_EVENT"

                self.status_reason = (

                    "COMPLETE event not detected. "

                    + "Potential out-of-band action against the stack."

                )

        return self._status

    @status.setter

    def status(self, status):

        _complete = StackStatus.COMPLETE.copy()

        del _complete[_complete.index("DELETE_COMPLETE")]

        self._status = status

        if status in StackStatus.FAILED:

            self._launch_succeeded = False

            return

        if status in _complete:

            self._launch_succeeded = True

            return

        return

    @property

    def launch_succeeded(self):

        return self._launch_succeeded

    @classmethod

    def create(

        cls,

        region: TestRegion,

        stack_name: str,

        template: Template,

        tags: List[Tag] = None,

        disable_rollback: bool = True,

        test_name: str = "",

        uuid: UUID = None,

    ) -> "Stack":

        parameters = cls._cfn_format_parameters(region.parameters)

        uuid = uuid if uuid else uuid4()

        cfn_client = region.client("cloudformation")

        tags = [t.dump() for t in tags] if tags else []

        template = Template(

            template_path=template.template_path,

            project_root=template.project_root,

            s3_key_prefix=template.s3_key_prefix,

            url=s3_url_maker(

                region.s3_bucket.name,

                template.s3_key,

                region.client("s3"),

                region.s3_bucket.auto_generated,

            ),

            template_cache=tcat_template_cache,

        )

        create_options = {

            "StackName": stack_name,

            "TemplateURL": template.url,

            "Parameters": parameters,

            "DisableRollback": disable_rollback,

            "Tags": tags,

            "Capabilities": Capabilities.ALL,

        }

        if region.role_arn:

            create_options["RoleARN"] = region.role_arn

        stack_id = cfn_client.create_stack(**create_options)["StackId"]

        stack = cls(region, stack_id, template, test_name, uuid)

        # fetch property values from cfn

        stack.refresh()

        return stack

    @staticmethod

    def _cfn_format_parameters(parameters):

        return [{"ParameterKey": k, "ParameterValue": v} for k, v in parameters.items()]

    @classmethod

    def _import_child(  # pylint: disable=too-many-locals

        cls, stack_properties: dict, parent_stack: "Stack"

    ) -> Optional["Stack"]:

        try:

            url = ""

            for event in parent_stack.events():

                if (

                    event.physical_id == stack_properties["StackId"]

                    and event.properties

                ):

                    url = event.properties["TemplateURL"]

            if url.startswith(parent_stack.template.url_prefix()):

                # Template is part of the project, discovering path

                relative_path = url.replace(

                    parent_stack.template.url_prefix(), ""

                ).lstrip("/")

                absolute_path = parent_stack.template.project_root / relative_path

                if not absolute_path.is_file():

                    # try with the base folder stripped off

                    relative_path2 = Path(relative_path)

                    relative_path2 = relative_path2.relative_to(

                        *relative_path2.parts[:1]

                    )

                    absolute_path = parent_stack.template.project_root / relative_path2

                if not absolute_path.is_file():

                    LOG.warning(

                        f"Failed to find template for child stack "

                        f"{stack_properties['StackId']}. tried "

                        f"{parent_stack.template.project_root / relative_path}"

                        f" and {absolute_path}"

                    )

                    return None

            else:

                # Assuming template is remote to project and downloading it

                cfn_client = parent_stack.client

                tempate_body = cfn_client.get_template(

                    StackName=stack_properties["StackId"]

                )["TemplateBody"]

                path = parent_stack.template.project_root / Stack.REMOTE_TEMPLATE_PATH

                os.makedirs(path, exist_ok=True)

                fname = (

                    "".join(

                        random.choice(string.ascii_lowercase)  # nosec

                        for _ in range(16)

                    )

                    + ".template"

                )

                absolute_path = path / fname

                if not isinstance(tempate_body, str):

                    tempate_body = ordered_dump(tempate_body, dumper=yaml.SafeDumper)

                if not absolute_path.exists():

                    with open(absolute_path, "w", encoding="utf-8") as fh:

                        fh.write(tempate_body)

            template = Template(

                template_path=str(absolute_path),

                project_root=parent_stack.template.project_root,

                url=url,

                template_cache=tcat_template_cache,

            )

            stack = cls(

                parent_stack.region,

                stack_properties["StackId"],

                template,

                parent_stack.name,

                parent_stack.uuid,

            )

            stack.set_stack_properties(stack_properties)

        except Exception as e:  # pylint: disable=broad-except

            LOG.warning(f"Failed to import child stack: {str(e)}")

            LOG.debug("traceback:", exc_info=True)

            return None

        return stack

    @classmethod

    def import_existing(

        cls,

        stack_properties: dict,

        template: Template,

        region: TestRegion,

        test_name: str,

        uid: UUID,

    ) -> "Stack":

        stack = cls(region, stack_properties["StackId"], template, test_name, uid)

        stack.set_stack_properties(stack_properties)

        return stack

    def refresh(

        self,

        properties: bool = True,

        events: bool = False,

        resources: bool = False,

        children: bool = False,

    ) -> None:

        if properties:

            self.set_stack_properties()

        if events:

            self._fetch_stack_events()

            self._last_event_refresh = datetime.now()

        if resources:

            self._fetch_stack_resources()

            self._last_resource_refresh = datetime.now()

        if children:

            self._fetch_children()

            self._last_child_refresh = datetime.now()

    def set_stack_properties(self, stack_properties: Optional[dict] = None) -> None:

        # TODO: get time to complete for complete stacks and % complete

        props: dict = stack_properties if stack_properties else {}

        self._timer.cancel()

        if not props:

            describe_stacks = self.client.describe_stacks

            props = describe_stacks(StackName=self.id)["Stacks"][0]

        iterable_props: List[Tuple[str, Callable]] = [

            ("Parameters", Parameter),

            ("Outputs", Output),

            ("Tags", Tag),

        ]

        for prop_name, prop_class in iterable_props:

            for item in props.get(prop_name, []):

                item = prop_class(item)

                self._merge_props(getattr(self, prop_name.lower()), item)

        for key, value in props.items():

            if key in [p[0] for p in iterable_props]:  # noqa: C412

                continue

            key = pascal_to_snake(key).replace("stack_", "")

            setattr(self, key, value)

        if self.status in StackStatus.IN_PROGRESS:

            self._timer = Timer(

                self._auto_refresh_interval.total_seconds(), self.refresh

            )

            self._timer.start()

    @staticmethod

    def _merge_props(existing_props, new):

        added = False

        for existing_id, prop in enumerate(existing_props):

            if prop.key == new.key:

                existing_props[existing_id] = new

                added = True

        if not added:

            existing_props.append(new)

    def events(self, refresh: bool = False, include_generic: bool = True) -> Events:

        if refresh or not self._events or self._auto_refresh(self._last_event_refresh):

            self._fetch_stack_events()

        events = self._events

        if not include_generic:

            events = Events([event for event in events if not self._is_generic(event)])

        return events

    @staticmethod

    def _is_generic(event: Event) -> bool:

        generic = False

        for regex in GENERIC_ERROR_PATTERNS:

            if re.search(regex, event.status_reason):

                generic = True

        return generic

    def _fetch_stack_events(self) -> None:

        self._last_event_refresh = datetime.now()

        events = Events()

        for page in self.client.get_paginator("describe_stack_events").paginate(

            StackName=self.id

        ):

            for event in page["StackEvents"]:

                events.append(Event(event))

        self._events = events

    def resources(self, refresh: bool = False) -> Resources:

        if (

            refresh

            or not self._resources

            or self._auto_refresh(self._last_resource_refresh)

        ):

            self._fetch_stack_resources()

        return self._resources

    def _fetch_stack_resources(self) -> None:

        self._last_resource_refresh = datetime.now()

        resources = Resources()

        for page in self.client.get_paginator("list_stack_resources").paginate(

            StackName=self.id

        ):

            for resource in page["StackResourceSummaries"]:

                resources.append(Resource(self.id, resource, self.test_name, self.uuid))

        self._resources = resources

    @staticmethod

    def delete(client, stack_id) -> None:

        client.delete_stack(StackName=stack_id)

        LOG.info(f"Deleting stack: {stack_id}")

    def update(self, *args, **kwargs):

        raise NotImplementedError("Stack updates not implemented")

    def _fetch_children(self) -> None:

        self._last_child_refresh = datetime.now()

        for page in self.client.get_paginator("describe_stacks").paginate():

            for stack in page["Stacks"]:

                if self._children.filter(id=stack["StackId"]):

                    continue

                if "ParentId" in stack.keys():

                    if self.id == stack["ParentId"]:

                        stack_obj = Stack._import_child(stack, self)

                        if stack_obj:

                            self._children.append(stack_obj)

    def children(self, refresh=False) -> Stacks:

        if (

            refresh

            or not self._children

            or self._auto_refresh(self._last_child_refresh)

        ):

            self._fetch_children()

        return self._children

    def descendants(self, refresh=False) -> Stacks:

        if refresh or not self._children:

            self._fetch_children()

        def recurse(stack: Stack, descendants: Stacks = None) -> Stacks:

            descendants = descendants if descendants else Stacks()

            if stack.children(refresh=refresh):

                descendants += stack.children()

                for child in stack.children():

                    descendants = recurse(child, descendants)

            return descendants

        return recurse(self)

    def error_events(

        self, recurse: bool = True, include_generic: bool = False, refresh=False

    ) -> Events:

        errors = Events()

        stacks = Stacks([self])

        if recurse:

            stacks += self.descendants()

        for stack in stacks:

            for status in StackStatus.FAILED:

                errors += stack.events(

                    refresh=refresh, include_generic=include_generic

                ).filter({"status": status})

        return errors

Class variables

REMOTE_TEMPLATE_PATH

Static methods

create

def create(
    region: taskcat._dataclasses.TestRegion,
    stack_name: str,
    template: taskcat._cfn.template.Template,
    tags: List[taskcat._dataclasses.Tag] = None,
    disable_rollback: bool = True,
    test_name: str = '',
    uuid: uuid.UUID = None
) -> 'Stack'
View Source
    @classmethod

    def create(

        cls,

        region: TestRegion,

        stack_name: str,

        template: Template,

        tags: List[Tag] = None,

        disable_rollback: bool = True,

        test_name: str = "",

        uuid: UUID = None,

    ) -> "Stack":

        parameters = cls._cfn_format_parameters(region.parameters)

        uuid = uuid if uuid else uuid4()

        cfn_client = region.client("cloudformation")

        tags = [t.dump() for t in tags] if tags else []

        template = Template(

            template_path=template.template_path,

            project_root=template.project_root,

            s3_key_prefix=template.s3_key_prefix,

            url=s3_url_maker(

                region.s3_bucket.name,

                template.s3_key,

                region.client("s3"),

                region.s3_bucket.auto_generated,

            ),

            template_cache=tcat_template_cache,

        )

        create_options = {

            "StackName": stack_name,

            "TemplateURL": template.url,

            "Parameters": parameters,

            "DisableRollback": disable_rollback,

            "Tags": tags,

            "Capabilities": Capabilities.ALL,

        }

        if region.role_arn:

            create_options["RoleARN"] = region.role_arn

        stack_id = cfn_client.create_stack(**create_options)["StackId"]

        stack = cls(region, stack_id, template, test_name, uuid)

        # fetch property values from cfn

        stack.refresh()

        return stack
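
A hedged sketch of launching a stack from the objects produced by the Config pipeline above; the test name "default" and the stack name are hypothetical:

test = tests["default"]
stack = Stack.create(
    region=test.regions[0],  # a TestRegion carrying bucket, client and parameters
    stack_name="tcat-example-stack",
    template=templates["default"],
    test_name="default",
)
print(stack.name, stack.status)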

delete

def delete(
    client,
    stack_id
) -> None
View Source
    @staticmethod

    def delete(client, stack_id) -> None:

        client.delete_stack(StackName=stack_id)

        LOG.info(f"Deleting stack: {stack_id}")

import_existing

def import_existing(
    stack_properties: dict,
    template: taskcat._cfn.template.Template,
    region: taskcat._dataclasses.TestRegion,
    test_name: str,
    uid: uuid.UUID
) -> 'Stack'
View Source
    @classmethod

    def import_existing(

        cls,

        stack_properties: dict,

        template: Template,

        region: TestRegion,

        test_name: str,

        uid: UUID,

    ) -> "Stack":

        stack = cls(region, stack_properties["StackId"], template, test_name, uid)

        stack.set_stack_properties(stack_properties)

        return stack

Instance variables

launch_succeeded
status

Methods

children

def children(
    self,
    refresh=False
) -> taskcat._cfn.stack.Stacks
View Source
    def children(self, refresh=False) -> Stacks:

        if (

            refresh

            or not self._children

            or self._auto_refresh(self._last_child_refresh)

        ):

            self._fetch_children()

        return self._children

descendants

def descendants(
    self,
    refresh=False
) -> taskcat._cfn.stack.Stacks
View Source
    def descendants(self, refresh=False) -> Stacks:

        if refresh or not self._children:

            self._fetch_children()

        def recurse(stack: Stack, descendants: Stacks = None) -> Stacks:

            descendants = descendants if descendants else Stacks()

            if stack.children(refresh=refresh):

                descendants += stack.children()

                for child in stack.children():

                    descendants = recurse(child, descendants)

            return descendants

        return recurse(self)

error_events

def error_events(
    self,
    recurse: bool = True,
    include_generic: bool = False,
    refresh=False
) -> taskcat._cfn.stack.Events
View Source
    def error_events(

        self, recurse: bool = True, include_generic: bool = False, refresh=False

    ) -> Events:

        errors = Events()

        stacks = Stacks([self])

        if recurse:

            stacks += self.descendants()

        for stack in stacks:

            for status in StackStatus.FAILED:

                errors += stack.events(

                    refresh=refresh, include_generic=include_generic

                ).filter({"status": status})

        return errors
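
A short sketch of inspecting failures across a stack and its nested stacks:

for event in stack.error_events(recurse=True):
    # physical_id and status_reason are Event attributes used elsewhere in this module
    print(event.physical_id, event.status_reason)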

events

def events(
    self,
    refresh: bool = False,
    include_generic: bool = True
) -> taskcat._cfn.stack.Events
View Source
    def events(self, refresh: bool = False, include_generic: bool = True) -> Events:

        if refresh or not self._events or self._auto_refresh(self._last_event_refresh):

            self._fetch_stack_events()

        events = self._events

        if not include_generic:

            events = Events([event for event in events if not self._is_generic(event)])

        return events

refresh

def refresh(
    self,
    properties: bool = True,
    events: bool = False,
    resources: bool = False,
    children: bool = False
) -> None
View Source
    def refresh(

        self,

        properties: bool = True,

        events: bool = False,

        resources: bool = False,

        children: bool = False,

    ) -> None:

        if properties:

            self.set_stack_properties()

        if events:

            self._fetch_stack_events()

            self._last_event_refresh = datetime.now()

        if resources:

            self._fetch_stack_resources()

            self._last_resource_refresh = datetime.now()

        if children:

            self._fetch_children()

            self._last_child_refresh = datetime.now()

resources

def resources(
    self,
    refresh: bool = False
) -> taskcat._cfn.stack.Resources
View Source
    def resources(self, refresh: bool = False) -> Resources:

        if (

            refresh

            or not self._resources

            or self._auto_refresh(self._last_resource_refresh)

        ):

            self._fetch_stack_resources()

        return self._resources

set_stack_properties

def set_stack_properties(
    self,
    stack_properties: Union[dict, NoneType] = None
) -> None
View Source
    def set_stack_properties(self, stack_properties: Optional[dict] = None) -> None:

        # TODO: get time to complete for complete stacks and % complete

        props: dict = stack_properties if stack_properties else {}

        self._timer.cancel()

        if not props:

            describe_stacks = self.client.describe_stacks

            props = describe_stacks(StackName=self.id)["Stacks"][0]

        iterable_props: List[Tuple[str, Callable]] = [

            ("Parameters", Parameter),

            ("Outputs", Output),

            ("Tags", Tag),

        ]

        for prop_name, prop_class in iterable_props:

            for item in props.get(prop_name, []):

                item = prop_class(item)

                self._merge_props(getattr(self, prop_name.lower()), item)

        for key, value in props.items():

            if key in [p[0] for p in iterable_props]:  # noqa: C412

                continue

            key = pascal_to_snake(key).replace("stack_", "")

            setattr(self, key, value)

        if self.status in StackStatus.IN_PROGRESS:

            self._timer = Timer(

                self._auto_refresh_interval.total_seconds(), self.refresh

            )

            self._timer.start()

update

def update(
    self,
    *args,
    **kwargs
)
View Source
    def update(self, *args, **kwargs):

        raise NotImplementedError("Stack updates not implemented")

Template

class Template(
    template_path: Union[str, pathlib.Path],
    project_root: Union[str, pathlib.Path] = '',
    url: str = '',
    s3_key_prefix: str = '',
    template_cache: taskcat._cfn.template.TemplateCache = tcat_template_cache
)
View Source
class Template:

    def __init__(

        self,

        template_path: Union[str, Path],

        project_root: Union[str, Path] = "",

        url: str = "",

        s3_key_prefix: str = "",

        template_cache: TemplateCache = tcat_template_cache,

    ):

        self.template_cache = template_cache

        self.template_path: Path = Path(template_path).expanduser().resolve()

        self.template = self.template_cache.get(str(self.template_path))

        with open(template_path, "r", encoding="utf-8") as file_handle:

            self.raw_template = file_handle.read()

        project_root = (

            project_root if project_root else self.template_path.parent.parent

        )

        self.project_root = Path(project_root).expanduser().resolve()

        self.url = url

        self._s3_key_prefix = s3_key_prefix

        self.children: List[Template] = []

        self._find_children()

    def __str__(self):

        return str(self.template)

    def __repr__(self):

        return f"<Template {self.template_path} at {hex(id(self))}>"

    @property

    def s3_key(self):

        suffix = str(self.template_path.relative_to(self.project_root).as_posix())

        return self._s3_key_prefix + suffix

    @property

    def s3_key_prefix(self):

        return self._s3_key_prefix

    @property

    def linesplit(self):

        return self.raw_template.split("\n")

    def write(self):

        """writes raw_template back to file, and reloads decoded template, useful if

        the template has been modified"""

        with open(str(self.template_path), "w", encoding="utf-8") as file_handle:

            file_handle.write(self.raw_template)

        self.template = cfnlint.decode.cfn_yaml.load(self.template_path)

        self._find_children()

    def _template_url_to_path(

        self,

        template_url,

        template_mappings=None,

    ):

        try:

            helper = StackURLHelper(

                template_mappings=template_mappings,

                template_parameters=self.template.get("Parameters"),

            )

            urls = helper.template_url_to_path(

                current_template_path=self.template_path, template_url=template_url

            )

            if len(urls) > 0:

                return urls[0]

        except Exception as e:  # pylint: disable=broad-except

            LOG.debug("Traceback:", exc_info=True)

            LOG.error("TemplateURL parsing error: %s " % str(e))

        LOG.warning(

            "Failed to discover path for %s, path %s does not exist",

            template_url,

            None,

        )

        return ""

    def _get_relative_url(self, path: str) -> str:

        suffix = str(path).replace(str(self.project_root), "")

        url = self.url_prefix() + suffix

        return url

    def url_prefix(self) -> str:

        if not self.url:

            return ""

        regionless_url = re.sub(

            r"\.s3\.(.*)\.amazonaws\.com",

            ".s3.amazonaws.com",

            self.url,

        )

        suffix = str(self.template_path).replace(str(self.project_root), "")

        suffix_length = len(suffix.lstrip("/").split("/"))

        url_prefix = "/".join(regionless_url.split("/")[0:-suffix_length])

        return url_prefix

    def _find_children(self) -> None:  # noqa: C901

        children = set()

        if "Resources" not in self.template:

            raise TaskCatException(

                f"did not receive a valid template: {self.template_path} does not "

                f"have a Resources section"

            )

        for resource in self.template["Resources"].keys():

            resource = self.template["Resources"][resource]

            if resource["Type"] == "AWS::CloudFormation::Stack":

                child_name = self._template_url_to_path(

                    template_url=resource["Properties"]["TemplateURL"],

                )

                # print(child_name)

                if child_name:

                    # for child_url in child_name:

                    children.add(child_name)

        for child in children:

            child_template_instance = None

            for descendent in self.descendents:

                if str(descendent.template_path) == str(child):

                    child_template_instance = descendent

            if not child_template_instance:

                try:

                    child_template_instance = Template(

                        child,

                        self.project_root,

                        self._get_relative_url(child),

                        self._s3_key_prefix,

                        tcat_template_cache,

                    )

                except Exception:  # pylint: disable=broad-except

                    LOG.debug("Traceback:", exc_info=True)

                    LOG.error(f"Failed to add child template {child}")

            if isinstance(child_template_instance, Template):

                self.children.append(child_template_instance)

    @property

    def descendents(self) -> List["Template"]:

        desc_map = {}

        def recurse(template):

            for child in template.children:

                desc_map[str(child.template_path)] = child

                recurse(child)

        recurse(self)

        return list(desc_map.values())

    def parameters(

        self,

    ) -> Dict[str, Union[None, str, int, bool, List[Union[int, str]]]]:

        parameters = {}

        for param_key, param in self.template.get("Parameters", {}).items():

            parameters[param_key] = param.get("Default")

        return parameters

Instance variables

descendents
linesplit
s3_key
s3_key_prefix

Methods

parameters

def parameters(
    self
) -> Dict[str, Union[NoneType, str, int, bool, List[Union[str, int]]]]
View Source
    def parameters(

        self,

    ) -> Dict[str, Union[None, str, int, bool, List[Union[int, str]]]]:

        parameters = {}

        for param_key, param in self.template.get("Parameters", {}).items():

            parameters[param_key] = param.get("Default")

        return parameters
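
A minimal sketch of loading a template and reading its parameter defaults; the path is hypothetical and must point at a template with a Resources section (enforced by _find_children):

from taskcat import Template

tpl = Template("./my-project/templates/main.template.yaml")
print(tpl.parameters())  # mapping of parameter name -> Default value (or None)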

url_prefix

def url_prefix(
    self
) -> str
View Source
    def url_prefix(self) -> str:

        if not self.url:

            return ""

        regionless_url = re.sub(

            r"\.s3\.(.*)\.amazonaws\.com",

            ".s3.amazonaws.com",

            self.url,

        )

        suffix = str(self.template_path).replace(str(self.project_root), "")

        suffix_length = len(suffix.lstrip("/").split("/"))

        url_prefix = "/".join(regionless_url.split("/")[0:-suffix_length])

        return url_prefix

write

def write(
    self
)

Writes raw_template back to file and reloads the decoded template; useful if the template has been modified.

View Source
    def write(self):

        """writes raw_template back to file, and reloads decoded template, useful if

        the template has been modified"""

        with open(str(self.template_path), "w", encoding="utf-8") as file_handle:

            file_handle.write(self.raw_template)

        self.template = cfnlint.decode.cfn_yaml.load(self.template_path)

        self._find_children()
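
A hedged sketch of editing raw_template and persisting it; the string substitution is illustrative, and write() re-parses the file with cfnlint, so cfnlint must be importable:

tpl.raw_template = tpl.raw_template.replace("t2.micro", "t3.micro")
tpl.write()  # rewrites template_path and reloads self.template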