Skip to content

Commit 7537a57

Browse files
committed
refactor to add documentation, clarify variable names, add test cases, and better encapsulate behaviors (among other things to simplify testing)
1 parent 4cfb11f commit 7537a57

File tree

5 files changed

+301
-90
lines changed

5 files changed

+301
-90
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
build
44
dist
55
*.egg-info
6+
.idea

ssm-diff

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,53 @@
11
#!/usr/bin/env python
22
from __future__ import print_function
3-
from states import *
4-
import states.helpers as helpers
3+
54
import argparse
65
import os
76

7+
from states import states
8+
from states.helpers import DiffResolver
9+
10+
11+
def configure_endpoints(args):
12+
# pre-configure resolver, but still accept remote and local at runtime
13+
diff_resolver = DiffResolver.configure(force=args.force)
14+
return states.RemoteState(args.profile, diff_resolver, paths=args.path), states.LocalState(args.filename, paths=args.path)
15+
816

917
def init(args):
10-
r, l = RemoteState(args.profile), LocalState(args.filename)
11-
l.save(r.get(flat=False, paths=args.path))
18+
"""Create a local YAML file from the SSM Parameter Store (per configs in args)"""
19+
remote, local = configure_endpoints(args)
20+
local.save(remote.clone())
1221

1322

1423
def pull(args):
15-
dictfilter = lambda x, y: dict([ (i,x[i]) for i in x if i in set(y) ])
16-
r, l = RemoteState(args.profile), LocalState(args.filename)
17-
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
18-
if args.force:
19-
ref_set = diff.changed().union(diff.removed()).union(diff.unchanged())
20-
target_set = diff.added()
21-
else:
22-
ref_set = diff.unchanged().union(diff.removed())
23-
target_set = diff.added().union(diff.changed())
24-
state = dictfilter(diff.ref, ref_set)
25-
state.update(dictfilter(diff.target, target_set))
26-
l.save(helpers.unflatten(state))
24+
"""Update local YAML file with changes in the SSM Parameter Store (per configs in args)"""
25+
remote, local = configure_endpoints(args)
26+
local.save(remote.pull(local.get()))
2727

2828

2929
def apply(args):
30-
r, _, diff = plan(args)
31-
30+
"""Apply local changes to the SSM Parameter Store"""
31+
remote, local = configure_endpoints(args)
3232
print("\nApplying changes...")
3333
try:
34-
r.apply(diff)
34+
remote.push(local.get())
3535
except Exception as e:
3636
print("Failed to apply changes to remote:", e)
3737
print("Done.")
3838

3939

4040
def plan(args):
41-
r, l = RemoteState(args.profile), LocalState(args.filename)
42-
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
41+
"""Print a representation of the changes that would be applied to SSM Parameter Store if applied (per config in args)"""
42+
remote, local = configure_endpoints(args)
43+
diff = remote.dry_run(local.get())
4344

4445
if diff.differ:
45-
diff.print_state()
46+
print(diff.describe_diff())
4647
else:
4748
print("Remote state is up to date.")
4849

49-
return r, l, diff
50+
return remote, local, diff
5051

5152

5253
if __name__ == "__main__":

states/helpers.py

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,101 @@
1-
from termcolor import colored
2-
from copy import deepcopy
31
import collections
2+
from copy import deepcopy
3+
from functools import partial
44

5+
from termcolor import colored
56

6-
class FlatDictDiffer(object):
7-
def __init__(self, ref, target):
8-
self.ref, self.target = ref, target
9-
self.ref_set, self.target_set = set(ref.keys()), set(target.keys())
10-
self.isect = self.ref_set.intersection(self.target_set)
7+
8+
class DiffResolver(object):
9+
"""Determines diffs between two dicts, where the remote copy is considered the baseline"""
10+
def __init__(self, remote, local, force=False):
11+
self.remote_flat, self.local_flat = self._flatten(remote), self._flatten(local)
12+
self.remote_set, self.local_set = set(self.remote_flat.keys()), set(self.local_flat.keys())
13+
self.intersection = self.remote_set.intersection(self.local_set)
14+
self.force = force
1115

1216
if self.added() or self.removed() or self.changed():
1317
self.differ = True
1418
else:
1519
self.differ = False
1620

21+
@classmethod
22+
def configure(cls, *args, **kwargs):
23+
return partial(cls, *args, **kwargs)
24+
1725
def added(self):
18-
return self.target_set - self.isect
26+
"""Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}"""
27+
return self.local_set - self.intersection
1928

2029
def removed(self):
21-
return self.ref_set - self.isect
30+
"""Returns a (flattened) dict of removed leaves i.e. {"full/path": value, ...}"""
31+
return self.remote_set - self.intersection
2232

2333
def changed(self):
24-
return set(k for k in self.isect if self.ref[k] != self.target[k])
34+
"""Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}"""
35+
return set(k for k in self.intersection if self.remote_flat[k] != self.local_flat[k])
2536

2637
def unchanged(self):
27-
return set(k for k in self.isect if self.ref[k] == self.target[k])
38+
"""Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}"""
39+
return set(k for k in self.intersection if self.remote_flat[k] == self.local_flat[k])
2840

29-
def print_state(self):
41+
def describe_diff(self):
42+
"""Return a (multi-line) string describing all differences"""
43+
description = ""
3044
for k in self.added():
31-
print(colored("+", 'green'), "{} = {}".format(k, self.target[k]))
45+
description += colored("+", 'green'), "{} = {}".format(k, self.local_flat[k]) + '\n'
3246

3347
for k in self.removed():
34-
print(colored("-", 'red'), k)
48+
description += colored("-", 'red'), k + '\n'
3549

3650
for k in self.changed():
37-
print(colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.ref[k], self.target[k]))
38-
39-
40-
def flatten(d, pkey='', sep='/'):
41-
items = []
42-
for k in d:
43-
new = pkey + sep + k if pkey else k
44-
if isinstance(d[k], collections.MutableMapping):
45-
items.extend(flatten(d[k], new, sep=sep).items())
51+
description += colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.remote_flat[k], self.local_flat[k]) + '\n'
52+
53+
return description
54+
55+
def _flatten(self, d, current_path='', sep='/'):
56+
"""Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}"""
57+
items = []
58+
for k in d:
59+
new = current_path + sep + k if current_path else k
60+
if isinstance(d[k], collections.MutableMapping):
61+
items.extend(self._flatten(d[k], new, sep=sep).items())
62+
else:
63+
items.append((sep + new, d[k]))
64+
return dict(items)
65+
66+
def _unflatten(self, d, sep='/'):
67+
"""Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure"""
68+
output = {}
69+
for k in d:
70+
add(
71+
obj=output,
72+
path=k,
73+
value=d[k],
74+
sep=sep,
75+
)
76+
return output
77+
78+
def merge(self):
79+
"""Generate a merge of the local and remote dicts, following configurations set during __init__"""
80+
dictfilter = lambda original, keep_keys: dict([(i, original[i]) for i in original if i in set(keep_keys)])
81+
if self.force:
82+
# Overwrite local changes (i.e. only preserve added keys)
83+
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
84+
prior_set = self.changed().union(self.removed()).union(self.unchanged())
85+
current_set = self.added()
4686
else:
47-
items.append((sep + new, d[k]))
48-
return dict(items)
49-
50-
51-
def add(obj, path, value):
52-
parts = path.strip("/").split("/")
87+
# Preserve added keys and changed keys
88+
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
89+
prior_set = self.unchanged().union(self.removed())
90+
current_set = self.added().union(self.changed())
91+
state = dictfilter(original=self.remote_flat, keep_keys=prior_set)
92+
state.update(dictfilter(original=self.local_flat, keep_keys=current_set))
93+
return self._unflatten(state)
94+
95+
96+
def add(obj, path, value, sep='/'):
97+
"""Add value to the `obj` dict at the specified path"""
98+
parts = path.strip(sep).split(sep)
5399
last = len(parts) - 1
54100
for index, part in enumerate(parts):
55101
if index == last:
@@ -61,7 +107,7 @@ def add(obj, path, value):
61107
def search(state, path):
62108
result = state
63109
for p in path.strip("/").split("/"):
64-
if result.get(p):
110+
if result.clone(p):
65111
result = result[p]
66112
else:
67113
result = {}
@@ -71,16 +117,6 @@ def search(state, path):
71117
return output
72118

73119

74-
def unflatten(d):
75-
output = {}
76-
for k in d:
77-
add(
78-
obj=output,
79-
path=k,
80-
value=d[k])
81-
return output
82-
83-
84120
def merge(a, b):
85121
if not isinstance(b, dict):
86122
return b

states/states.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import print_function
2-
from botocore.exceptions import ClientError, NoCredentialsError
3-
from .helpers import flatten, merge, add, search
2+
43
import sys
5-
import os
6-
import yaml
4+
75
import boto3
86
import termcolor
7+
import yaml
8+
from botocore.exceptions import ClientError, NoCredentialsError
9+
10+
from .helpers import merge, add, search
11+
912

1013
def str_presenter(dumper, data):
1114
if len(data.splitlines()) == 1 and data[-1] == '\n':
@@ -17,8 +20,10 @@ def str_presenter(dumper, data):
1720
return dumper.represent_scalar(
1821
'tag:yaml.org,2002:str', data.strip())
1922

23+
2024
yaml.SafeDumper.add_representer(str, str_presenter)
2125

26+
2227
class SecureTag(yaml.YAMLObject):
2328
yaml_tag = u'!secure'
2429

@@ -38,7 +43,7 @@ def __hash__(self):
3843
return hash(self.secure)
3944

4045
def __ne__(self, other):
41-
return (not self.__eq__(other))
46+
return not self.__eq__(other)
4247

4348
@classmethod
4449
def from_yaml(cls, loader, node):
@@ -50,25 +55,28 @@ def to_yaml(cls, dumper, data):
5055
return dumper.represent_scalar(cls.yaml_tag, data.secure, style='|')
5156
return dumper.represent_scalar(cls.yaml_tag, data.secure)
5257

58+
5359
yaml.SafeLoader.add_constructor('!secure', SecureTag.from_yaml)
5460
yaml.SafeDumper.add_multi_representer(SecureTag, SecureTag.to_yaml)
5561

5662

5763
class LocalState(object):
58-
def __init__(self, filename):
64+
"""Encodes/decodes a dictionary to/from a YAML file"""
65+
def __init__(self, filename, paths=('/',)):
5966
self.filename = filename
67+
self.paths = paths
6068

61-
def get(self, paths, flat=True):
69+
def get(self):
6270
try:
6371
output = {}
64-
with open(self.filename,'rb') as f:
65-
l = yaml.safe_load(f.read())
66-
for path in paths:
72+
with open(self.filename, 'rb') as f:
73+
local = yaml.safe_load(f.read())
74+
for path in self.paths:
6775
if path.strip('/'):
68-
output = merge(output, search(l, path))
76+
output = merge(output, search(local, path))
6977
else:
70-
return flatten(l) if flat else l
71-
return flatten(output) if flat else output
78+
return local
79+
return output
7280
except IOError as e:
7381
print(e, file=sys.stderr)
7482
if e.errno == 2:
@@ -90,53 +98,68 @@ def save(self, state):
9098

9199

92100
class RemoteState(object):
93-
def __init__(self, profile):
101+
"""Encodes/decodes a dict to/from the SSM Parameter Store"""
102+
def __init__(self, profile, diff_class, paths=('/',)):
94103
if profile:
95104
boto3.setup_default_session(profile_name=profile)
96105
self.ssm = boto3.client('ssm')
106+
self.diff_class = diff_class
107+
self.paths = paths
97108

98-
def get(self, paths=['/'], flat=True):
109+
def clone(self):
99110
p = self.ssm.get_paginator('get_parameters_by_path')
100111
output = {}
101-
for path in paths:
112+
for path in self.paths:
102113
try:
103114
for page in p.paginate(
104-
Path=path,
105-
Recursive=True,
106-
WithDecryption=True):
115+
Path=path,
116+
Recursive=True,
117+
WithDecryption=True):
107118
for param in page['Parameters']:
108119
add(obj=output,
109120
path=param['Name'],
110121
value=self._read_param(param['Value'], param['Type']))
111122
except (ClientError, NoCredentialsError) as e:
112123
print("Failed to fetch parameters from SSM!", e, file=sys.stderr)
113124

114-
return flatten(output) if flat else output
125+
return output
115126

127+
# noinspection PyMethodMayBeStatic
116128
def _read_param(self, value, ssm_type='String'):
117129
return SecureTag(value) if ssm_type == 'SecureString' else str(value)
118130

119-
def apply(self, diff):
131+
def pull(self, local):
132+
diff = self.diff_class(
133+
remote=self.clone(),
134+
local=local,
135+
)
136+
return diff.merge()
120137

138+
def dry_run(self, local):
139+
return self.diff_class(self.clone(), local)
140+
141+
def push(self, local):
142+
diff = self.dry_run(local)
143+
144+
# diff.added|removed|changed return a "flattened" dict i.e. {"full/path": "value", ...}
121145
for k in diff.added():
122146
ssm_type = 'String'
123-
if isinstance(diff.target[k], list):
147+
if isinstance(diff.local[k], list):
124148
ssm_type = 'StringList'
125-
if isinstance(diff.target[k], SecureTag):
149+
if isinstance(diff.local[k], SecureTag):
126150
ssm_type = 'SecureString'
127151
self.ssm.put_parameter(
128152
Name=k,
129-
Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]),
153+
Value=repr(diff.local[k]) if type(diff.local[k]) == SecureTag else str(diff.local[k]),
130154
Type=ssm_type)
131155

132156
for k in diff.removed():
133157
self.ssm.delete_parameter(Name=k)
134158

135159
for k in diff.changed():
136-
ssm_type = 'SecureString' if isinstance(diff.target[k], SecureTag) else 'String'
137-
160+
ssm_type = 'SecureString' if isinstance(diff.local[k], SecureTag) else 'String'
138161
self.ssm.put_parameter(
139162
Name=k,
140-
Value=repr(diff.target[k]) if type(diff.target[k]) == SecureTag else str(diff.target[k]),
163+
Value=repr(diff.local[k]) if type(diff.local[k]) == SecureTag else str(diff.local[k]),
141164
Overwrite=True,
142165
Type=ssm_type)

0 commit comments

Comments
 (0)