From dd17ab32f51a3c4bc2377765f8444bb7de2813ff Mon Sep 17 00:00:00 2001
From: Robin Goodall <rjg21@cam.ac.uk>
Date: Sat, 4 Mar 2023 14:25:14 +0000
Subject: [PATCH] Cache API results to file

---
 .gitignore                |  1 +
 README.md                 | 10 ++++++
 gsuitesync/__init__.py    | 15 ++++++--
 gsuitesync/sync/base.py   |  3 +-
 gsuitesync/sync/gapi.py   | 76 +++++++++++++++++++++++----------------
 gsuitesync/sync/lookup.py |  6 +++-
 gsuitesync/sync/main.py   |  6 ++--
 gsuitesync/sync/utils.py  | 39 +++++++++++++++++++-
 8 files changed, 117 insertions(+), 39 deletions(-)
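
Note: below is a minimal, hypothetical sketch of how the new cache_to_disk
decorator (added to gsuitesync/sync/utils.py in this patch) is intended to be
used. DemoRetriever, fetch_things, 'demo_things_{kind}' and ./cache are
invented for illustration only; the real consumers are the GAPIRetriever and
LookupRetriever methods decorated in this patch. The sketch assumes the cache
directory already exists (the decorator does not create it) and that any
placeholder in the cache file name is supplied as a keyword argument, which is
why the calls in gapi.py are switched to keyword form.

```
import os
from gsuitesync.sync.utils import cache_to_disk


class DemoRetriever:
    def __init__(self, cache_dir=None):
        # cache_to_disk looks for a cache_dir attribute on the instance
        self.cache_dir = cache_dir

    @cache_to_disk('demo_things_{kind}')
    def fetch_things(self, kind):
        print(f'expensive API call for {kind}')
        return [{'kind': kind, 'value': 42}]


os.makedirs('./cache', exist_ok=True)  # the decorator does not create the directory
r = DemoRetriever(cache_dir='./cache')
r.fetch_things(kind='colleges')  # calls the API, writes ./cache/demo_things_colleges.yaml
r.fetch_things(kind='colleges')  # reads the YAML cache file instead of calling the API
```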

diff --git a/.gitignore b/.gitignore
index b5e2785..e0638c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -119,3 +119,4 @@ api_gateway_credentials.yml
 gsuitesync.yaml
 
 .vscode/
+cache/
diff --git a/README.md b/README.md
index ad62085..8f5198a 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,16 @@ value:
 $ gsuitesync --timeout=600
 ```
 
+A directory for caching responses from the Google and Lookup APIs can be
+specified with ``--cache-dir``. This is helpful when performing repeated dry
+runs while testing code. If the appropriate cache file exists it is used
+instead of making the API call; if not, the file is created for later runs.
+
+```
+# Cache API results in ./cache
+$ gsuitesync --cache-dir ./cache
+```
+
 See the output of ``gsuitesync --help`` for more information on valid
 command-line flags.
 
diff --git a/gsuitesync/__init__.py b/gsuitesync/__init__.py
index 51281b3..cf2b792 100644
--- a/gsuitesync/__init__.py
+++ b/gsuitesync/__init__.py
@@ -4,7 +4,7 @@ Synchronise users to GSuite
 Usage:
     gsuitesync (-h | --help)
     gsuitesync [--configuration=FILE] [--quiet] [--group-settings] [--just-users]
-               [--licensing] [--really-do-this] [--timeout=SECONDS]
+               [--licensing] [--really-do-this] [--timeout=SECONDS] [--cache-dir=DIR]
 
 Options:
     -h, --help                  Show a brief usage summary.
@@ -24,6 +24,8 @@ Options:
     --timeout=SECONDS           Integer timeout for socket when performing batch
                                 operations. [default: 300]
 
+    --cache-dir=DIR             Directory in which to cache API responses, if provided.
+
 """
 import logging
 import os
@@ -48,9 +50,15 @@ def main():
     logging.getLogger('googleapiclient.discovery').setLevel(logging.WARN)
     logging.getLogger('googleapiclient.discovery_cache').setLevel(logging.ERROR)
 
-    # Convert to integer, any raised exceptions will propogate upwards
+    # Convert to integer, any raised exceptions will propagate upwards
     timeout_val = int(opts['--timeout'])
 
+    # Safety check - writing changes based on out-of-date cached data could be harmful
+    if opts['--really-do-this'] and opts['--cache-dir']:
+        LOG.error('Running in WRITE-MODE with a cache directory is blocked to avoid potentially '
+                  'writing out-of-date cached data.')
+        exit(1)
+
     LOG.info('Loading configuration')
     configuration = config.load_configuration(opts['--configuration'])
 
@@ -60,4 +68,5 @@ def main():
               timeout=timeout_val,
               group_settings=opts['--group-settings'],
               just_users=opts['--just-users'],
-              licensing=opts['--licensing'])
+              licensing=opts['--licensing'],
+              cache_dir=opts['--cache-dir'])
diff --git a/gsuitesync/sync/base.py b/gsuitesync/sync/base.py
index c7e39b8..b0607b1 100644
--- a/gsuitesync/sync/base.py
+++ b/gsuitesync/sync/base.py
@@ -7,9 +7,10 @@ Base classes for retrievers, comparator and updater classes that consume configu
 class ConfigurationStateConsumer:
     required_config = None
 
-    def __init__(self, configuration, state, read_only=True):
+    def __init__(self, configuration, state, read_only=True, cache_dir=None):
         # For convenience, create properties for required configuration
         for c in (self.required_config if self.required_config is not None else []):
             setattr(self, f'{c}_config', configuration.get(c, {}))
         self.state = state
         self.read_only = read_only
+        self.cache_dir = cache_dir
diff --git a/gsuitesync/sync/gapi.py b/gsuitesync/sync/gapi.py
index 1c8c184..1ac6dd0 100644
--- a/gsuitesync/sync/gapi.py
+++ b/gsuitesync/sync/gapi.py
@@ -4,6 +4,7 @@ Load current user, group and institution data from Google.
 """
 import logging
 import re
+import functools
 
 from google.oauth2 import service_account
 from googleapiclient import discovery
@@ -12,7 +13,7 @@ import socket
 from .base import ConfigurationStateConsumer
 from .. import gapiutil
 from .utils import (email_to_uid, email_to_gid, groupID_regex, instID_regex,
-                    date_months_ago, isodate_parse, customAction)
+                    date_months_ago, isodate_parse, customAction, cache_to_disk)
 
 
 LOG = logging.getLogger(__name__)
@@ -104,7 +105,7 @@ class GAPIRetriever(ConfigurationStateConsumer):
 
     def retrieve_users(self):
         # Retrieve information on all users excluding domain admins.
-        all_google_users = self._filtered_user_list()
+        all_google_users = self._filtered_user_list(show_deleted=False)
 
         # Form mappings from uid to Google user.
         all_google_users_by_uid = {
@@ -144,7 +145,7 @@ class GAPIRetriever(ConfigurationStateConsumer):
 
         # Get recently deleted users so we can restore them instead of recreating a duplicate
         # new account if needed
-        all_deleted_google_users = self._filtered_user_list(True)
+        all_deleted_google_users = self._filtered_user_list(show_deleted=True)
 
         # There are potentially multiple deletions of the same user. Assuming we want to restore
         # the last deleted, sort so that the dict references the last one.
@@ -175,6 +176,7 @@ class GAPIRetriever(ConfigurationStateConsumer):
             'all_deleted_google_users_by_uid': all_deleted_google_users_by_uid,
         })
 
+    @cache_to_disk('google_users_{show_deleted}')
     def _filtered_user_list(self, show_deleted=False):
         LOG.info(
             f"Getting information on {'deleted' if show_deleted else 'active'}"
@@ -233,21 +235,13 @@ class GAPIRetriever(ConfigurationStateConsumer):
             return
         # Retrieve all license assignments
         LOG.info('Getting information on licence assignment for Google domain users')
-        fields = ['userId', 'skuId']
-        all_licence_assignments = gapiutil.list_all(
-            self.state.licensing_service.licenseAssignments().listForProduct,
-            customerId=self.licensing_config.customer_id,
-            productId=self.licensing_config.product_id,
-            fields='nextPageToken,items(' + ','.join(fields) + ')',
-            retries=self.sync_config.http_retries, retry_delay=self.sync_config.http_retry_delay,
-        )
-        LOG.info('Total licence assignments: %s', len(all_licence_assignments))
+        LOG.info('Total licence assignments: %s', len(self.all_licence_assignments))
 
         # Build a map of uids to license SKU. We are only interested in those in our domain
         # and SKUs in our configuration
         skus_by_uid = {
             parts[1]: parts[0] for parts in
-            [[lic['skuId']] + lic['userId'].split('@', 1) for lic in all_licence_assignments]
+            [[lic['skuId']] + lic['userId'].split('@', 1) for lic in self.all_licence_assignments]
             if (
                 len(parts) == 3 and
                 parts[2] == self.gapi_domain_config.name and
@@ -278,18 +272,30 @@ class GAPIRetriever(ConfigurationStateConsumer):
             'google_available_by_sku': available_by_sku,
         })
 
+    @functools.cached_property
+    @cache_to_disk('google_licence_assignments')
+    def all_licence_assignments(self):
+        fields = ['userId', 'skuId']
+        return gapiutil.list_all(
+            self.state.licensing_service.licenseAssignments().listForProduct,
+            customerId=self.licensing_config.customer_id,
+            productId=self.licensing_config.product_id,
+            fields='nextPageToken,items(' + ','.join(fields) + ')',
+            retries=self.sync_config.http_retries, retry_delay=self.sync_config.http_retry_delay,
+        )
+
     def retrieve_groups(self):
         # Retrieve information on all Google groups that come from Lookup groups
         LOG.info('Getting information on Google domain groups')
         all_google_groups = [
-            g for g in self._fetch_groups(self.state.groups_domain)
+            g for g in self._fetch_groups(domain=self.state.groups_domain)
             if groupID_regex.match(g['email'].split('@')[0])
         ]
 
         # Append information on all Google groups that come from Lookup institutions
         LOG.info('Getting information on Google domain institutions')
         all_google_groups.extend([
-            g for g in self._fetch_groups(self.state.insts_domain)
+            g for g in self._fetch_groups(domain=self.state.insts_domain)
             if instID_regex.match(g['email'].split('@')[0].upper())
         ])
 
@@ -313,14 +319,7 @@ class GAPIRetriever(ConfigurationStateConsumer):
 
         # Retrieve all Google group memberships. This is a mapping from internal Google group ids
         # to lists of member resources, corresponding to both Lookup groups and institutions.
-        fields = ['id', 'email']
-        all_google_members = gapiutil.list_all_in_list(
-            self.state.directory_service, self.state.directory_service.members().list,
-            item_ids=[g['id'] for g in all_google_groups], id_key='groupKey',
-            batch_size=self.sync_config.batch_size, items_key='members',
-            fields='nextPageToken,members(' + ','.join(fields) + ')',
-            retries=self.sync_config.http_retries, retry_delay=self.sync_config.http_retry_delay,
-        )
+        all_google_members = self._fetch_group_members([g['id'] for g in all_google_groups])
 
         # Sanity check. We should have a group members list for each managed group.
         if len(all_google_members) != len(all_google_groups):
@@ -343,13 +342,8 @@ class GAPIRetriever(ConfigurationStateConsumer):
 
     def retrieve_group_settings(self):
         # Retrieve all Google group settings.
-        fields = ['email', *[k for k in self.sync_config.group_settings.keys()]]
-        all_google_group_settings = gapiutil.get_all_in_list(
-            self.state.groupssettings_service, self.state.groupssettings_service.groups().get,
-            item_ids=[g['email'] for g in self.state.all_google_groups], id_key='groupUniqueId',
-            batch_size=self.sync_config.batch_size, fields=','.join(fields),
-            retries=self.sync_config.http_retries, retry_delay=self.sync_config.http_retry_delay,
-        )
+        all_google_group_emails = [g['email'] for g in self.state.all_google_groups]
+        all_google_group_settings = self._fetch_group_settings(all_google_group_emails)
 
         # Form a mapping from gid to Google group settings.
         all_google_group_settings_by_gid = {
@@ -366,6 +360,7 @@ class GAPIRetriever(ConfigurationStateConsumer):
             'all_google_group_settings_by_gid': all_google_group_settings_by_gid,
         })
 
+    @cache_to_disk('google_groups_{domain}')
     def _fetch_groups(self, domain):
         """
         Function to fetch Google group information from the specified domain
@@ -377,3 +372,24 @@ class GAPIRetriever(ConfigurationStateConsumer):
             fields='nextPageToken,groups(' + ','.join(fields) + ')',
             retries=self.sync_config.http_retries, retry_delay=self.sync_config.http_retry_delay,
         )
+
+    @cache_to_disk('google_group_members')
+    def _fetch_group_members(self, ids):
+        fields = ['id', 'email']
+        return gapiutil.list_all_in_list(
+            self.state.directory_service, self.state.directory_service.members().list,
+            item_ids=ids, id_key='groupKey',
+            batch_size=self.sync_config.batch_size, items_key='members',
+            fields='nextPageToken,members(' + ','.join(fields) + ')',
+            retries=self.sync_config.http_retries, retry_delay=self.sync_config.http_retry_delay,
+        )
+
+    @cache_to_disk('google_group_settings')
+    def _fetch_group_settings(self, ids):
+        fields = ['email', *[k for k in self.sync_config.group_settings.keys()]]
+        return gapiutil.get_all_in_list(
+            self.state.groupssettings_service, self.state.groupssettings_service.groups().get,
+            item_ids=ids, id_key='groupUniqueId',
+            batch_size=self.sync_config.batch_size, fields=','.join(fields),
+            retries=self.sync_config.http_retries, retry_delay=self.sync_config.http_retry_delay,
+        )
diff --git a/gsuitesync/sync/lookup.py b/gsuitesync/sync/lookup.py
index 95c2656..edeae0e 100644
--- a/gsuitesync/sync/lookup.py
+++ b/gsuitesync/sync/lookup.py
@@ -13,7 +13,7 @@ from identitylib.lookup_client.api.person_api import PersonApi
 from identitylib.lookup_client.api.group_api import GroupApi
 from identitylib.lookup_client.api.institution_api import InstitutionApi
 
-from .utils import date_months_ago, isodate_parse
+from .utils import date_months_ago, isodate_parse, cache_to_disk
 
 from .base import ConfigurationStateConsumer
 
@@ -206,6 +206,7 @@ class LookupRetriever(ConfigurationStateConsumer):
     # Functions to perform Lookup API calls
     ###
     @functools.cached_property
+    @cache_to_disk('lookup_eligible_users')
     def eligible_users_by_uid(self):
         """
         Dictionary mapping CRSid to UserEntry instances. An entry exists in the dictionary for each
@@ -235,6 +236,7 @@ class LookupRetriever(ConfigurationStateConsumer):
         }
 
     @functools.cached_property
+    @cache_to_disk('lookup_cancelled_dates')
     def cancelled_dates_by_uid(self):
         """
         Return a dictionary mapping CRSid to cancelledDate for cancelled users. Limit to
@@ -269,6 +271,7 @@ class LookupRetriever(ConfigurationStateConsumer):
         return cancelled_users_to_cancelled_date
 
     @functools.cached_property
+    @cache_to_disk('lookup_eligible_groups')
     def eligible_groups_by_groupID(self):
         """
         Dictionary mapping groupID to GroupEntry instances. An entry exists in the dictionary for
@@ -299,6 +302,7 @@ class LookupRetriever(ConfigurationStateConsumer):
         return groups
 
     @functools.cached_property
+    @cache_to_disk('lookup_eligible_insts')
     def eligible_insts_by_instID(self):
         """
         Dictionary mapping instID to GroupEntry instances. An entry exists in the dictionary for
diff --git a/gsuitesync/sync/main.py b/gsuitesync/sync/main.py
index 18b8b37..64b3f63 100644
--- a/gsuitesync/sync/main.py
+++ b/gsuitesync/sync/main.py
@@ -15,7 +15,7 @@ LOG = logging.getLogger(__name__)
 
 
 def sync(configuration, *, read_only=True, timeout=300, group_settings=False, just_users=False,
-         licensing=False):
+         licensing=False, cache_dir=None):
     """Perform sync given configuration dictionary."""
     if read_only:
         LOG.info('Performing synchronisation in READ ONLY mode.')
@@ -30,13 +30,13 @@ def sync(configuration, *, read_only=True, timeout=300, group_settings=False, ju
     state = SyncState()
 
     # Get users and optionally groups from Lookup
-    lookup = LookupRetriever(configuration, state)
+    lookup = LookupRetriever(configuration, state, cache_dir=cache_dir)
     lookup.retrieve_users()
     if not just_users:
         lookup.retrieve_groups()
 
     # Get users and optionally groups from Google
-    gapi = GAPIRetriever(configuration, state, read_only=read_only)
+    gapi = GAPIRetriever(configuration, state, read_only=read_only, cache_dir=cache_dir)
     gapi.connect(timeout)
     gapi.retrieve_users()
     if licensing:
diff --git a/gsuitesync/sync/utils.py b/gsuitesync/sync/utils.py
index 2872f60..3c72b52 100644
--- a/gsuitesync/sync/utils.py
+++ b/gsuitesync/sync/utils.py
@@ -1,6 +1,9 @@
 import logging
 import re
+import yaml
 from datetime import date
+from os import path
+from functools import wraps
 from dateutil.relativedelta import relativedelta
 
 LOG = logging.getLogger(__name__)
@@ -59,4 +62,38 @@ def isodate_parse(str):
 
 
 def customAction(user, property):
-    return user.get('customSchemas', {}).get('UCam', {}).get(property)
\ No newline at end of file
+    return user.get('customSchemas', {}).get('UCam', {}).get(property)
+
+
+def cache_to_disk(file_name):
+    """
+    Decorator for an instance method that, if self.cache_dir is set, caches
+    the result to a file in self.cache_dir named by the decorator parameter.
+
+    file_name can contain keyword argument placeholders that are expanded to
+    form the actual cache file name, e.g. 'google_groups_{domain}'.
+
+    """
+    def decorator(method):
+
+        @wraps(method)
+        def _impl(self, *args, **kwargs):
+            # Skip if file caching not enabled
+            if not self.cache_dir:
+                return method(self, *args, **kwargs)
+            # Check whether the cache file exists
+            full_path = path.join(self.cache_dir, file_name.format(**kwargs)) + '.yaml'
+            if path.exists(full_path):
+                LOG.info(f'Reading from cache file: {full_path}')
+                # Read and return file result instead
+                with open(full_path) as fp:
+                    data = yaml.unsafe_load(fp)  # cached data may contain Python objects
+                return data
+            # Get the real data and cache it to file
+            data = method(self, *args, **kwargs)
+            LOG.info(f'Writing to cache file: {full_path}')
+            with open(full_path, 'w') as fp:
+                yaml.dump(data, fp)
+            return data
+        return _impl
+    return decorator
-- 
GitLab