Package npsgd :: Module model_manager
[hide private]
[frames | no frames]

Source Code for Module npsgd.model_manager

  1  # Author: Thomas Dimson [tdimson@gmail.com] 
  2  # Date:   January 2011 
  3  # For distribution details, see LICENSE 
  4  """Model (plug-in) loader with versioning support.""" 
  5  import os 
  6  import sys 
  7  import imp 
  8  import glob 
  9  import hashlib 
 10  import inspect 
 11  import logging 
 12  import threading 
 13  from npsgd.config import config 
 14  from model_task import ModelTask 
 15   
class InvalidModelError(RuntimeError):
    """Raised when a model class is malformed or a model/version pair is unknown."""
class ModelManager(object):
    """Registry of all models available to the NPSGD daemons.

    Maps ``(modelName, modelVersion)`` pairs to model classes and remembers,
    for each model name, the class that was registered most recently.
    Every read and write of the two registries happens while holding a
    re-entrant lock.
    """

    def __init__(self):
        # Re-entrant lock guarding both dictionaries below.
        self.modelLock = threading.RLock()
        # (short_name, version) -> model class
        self.models = {}
        # short_name -> most recently registered model class
        self.latestVersions = {}

    def modelNames(self):
        """Return the model name of every registered (name, version) pair.

        A name is repeated once for each of its registered versions.
        """
        with self.modelLock:
            return [name for (name, _version) in self.models]

    def modelVersions(self):
        """Return a list of all registered ``(name, version)`` tuples."""
        with self.modelLock:
            return list(self.models)

    def getLatestModel(self, name):
        """Return the most recently registered class for ``name``.

        Raises:
            KeyError: if no model with that name has been registered.
        """
        with self.modelLock:
            return self.latestVersions[name]

    def getModel(self, name, version):
        """Return the class registered under ``(name, version)``.

        Raises:
            KeyError: if that combination has not been registered.
        """
        with self.modelLock:
            return self.models[(name, version)]

    def getModelFromTaskDict(self, taskDict):
        """Build a model instance from a serialized task dictionary.

        ``taskDict`` must contain the keys ``"modelName"`` and
        ``"modelVersion"``; the matching class's ``fromDict`` is called with
        the whole dictionary and its result is returned.

        Raises:
            InvalidModelError: if the name/version pair is not registered.
        """
        name = taskDict["modelName"]
        version = taskDict["modelVersion"]
        with self.modelLock:
            if (name, version) not in self.models:
                raise InvalidModelError("Invalid model-version combination %s-%s" % (name, version))
            model = self.models[(name, version)]

        # Deserialization runs outside the lock; only the lookup is guarded.
        return model.fromDict(taskDict)

    def hasModel(self, name, version):
        """Return True if ``(name, version)`` is registered."""
        with self.modelLock:
            return (name, version) in self.models

    def addModel(self, cls, version):
        """Add a model to the hash, provided it is well formed.

        Classes without an ``abstractModel`` attribute, or whose
        ``abstractModel`` equals their own class name, are treated as
        abstract and silently skipped. A class that is already registered
        under the same ``(short_name, version)`` is left untouched.

        Raises:
            InvalidModelError: if ``cls`` lacks ``short_name``, ``full_name``
                or ``parameters``.
        """
        # Ignore abstract models
        if not hasattr(cls, 'abstractModel') or cls.abstractModel == cls.__name__:
            return

        if not hasattr(cls, 'short_name'):
            raise InvalidModelError("Model '%s' lacks a short_name" % cls.__name__)

        if not hasattr(cls, 'full_name'):
            raise InvalidModelError("Model '%s' lacks a full_name" % cls.__name__)

        if not hasattr(cls, 'parameters'):
            raise InvalidModelError("Model '%s' has no parameters" % cls.__name__)

        key = (cls.short_name, version)
        # Check and insert under a single lock acquisition so two threads
        # cannot both decide the model is missing and register it twice.
        with self.modelLock:
            if key in self.models:
                return
            cls.version = version
            self.models[key] = cls
            self.latestVersions[cls.short_name] = cls

        logging.info("Found and loaded model '%s', version '%s'", cls.short_name, cls.version)

    def getModelVersion(self, cls):
        """Return the MD5 hex digest of the source of the module defining ``cls``."""
        sourceCode = inspect.getsource(inspect.getmodule(cls))
        # hashlib only accepts bytes; on Python 3 getsource() returns text.
        if not isinstance(sourceCode, bytes):
            sourceCode = sourceCode.encode("utf-8")
        return hashlib.md5(sourceCode).hexdigest()
88 89
def loadMembers(mod, version):
    """Register every NPSGD model class that ``mod`` itself defines.

    Walks the members of ``mod`` and hands each class that was defined in
    that module (not merely imported into it) and that subclasses
    ``ModelTask`` to the module-level ``modelManager`` under ``version``.
    """
    moduleName = mod.__name__
    for _memberName, member in inspect.getmembers(mod):
        if not inspect.isclass(member):
            continue
        # Skip classes that were only imported into the module.
        if member.__module__ != moduleName:
            continue
        if issubclass(member, ModelTask):
            modelManager.addModel(member, version)
96
def setupModels():
    """Attempt to load all models. Must be called on script startup.

    Scans the model directory for python scripts. For each one it computes
    an MD5 hash of the file's bytes (its 'version'), loads the script as a
    module and registers the NPSGD models found in it under that version.
    A script that fails to hash, load or register is logged and skipped so
    that the remaining scripts are still processed.
    """
    if config.modelDirectory not in sys.path:
        sys.path.append(config.modelDirectory)

    # Suppress bytecode files while loading model scripts, then put back
    # whatever setting was in effect before instead of forcing it to False.
    previousSetting = sys.dont_write_bytecode
    sys.dont_write_bytecode = True
    try:
        for pyfile in glob.glob("%s/*.py" % config.modelDirectory):
            importName = os.path.basename(pyfile).rsplit(".", 1)[0]
            try:
                # Hash the raw bytes: hashlib rejects text on Python 3 and
                # binary mode avoids any newline translation.
                with open(pyfile, "rb") as f:
                    version = hashlib.md5(f.read()).hexdigest()

                module = imp.load_source(importName, pyfile)
                loadMembers(module, version)
            except Exception:
                logging.exception("Unable to load model from '%s'", importName)
    finally:
        sys.dont_write_bytecode = previousSetting
125
class ModelScannerThread(threading.Thread):
    """Thread for periodically loading new versions of models.

    Runs as a daemon thread. Setting the ``done`` event makes ``run``
    return the next time it wakes up.
    """

    def __init__(self):
        threading.Thread.__init__(self)
        # Set this event to ask the scan loop to exit.
        self.done = threading.Event()
        self.daemon = True

    def run(self):
        """Call setupModels() every ``config.modelScanInterval`` until ``done`` is set."""
        while True:
            # Doubles as the sleep between scans and the shutdown signal.
            self.done.wait(config.modelScanInterval)
            # is_set() replaces the camelCase isSet() alias, which is deprecated.
            if self.done.is_set():
                break
            logging.debug("Model scanner thread scanning for models")
            setupModels()
# Module-level handle to the scanner thread; rebound by startScannerThread().
modelScannerThread = None
def startScannerThread():
    """Start the dynamic model loader, which reloads models as they change.

    Creates a new ModelScannerThread, stores it in the module-level
    ``modelScannerThread`` variable and starts it.
    """
    global modelScannerThread
    scanner = ModelScannerThread()
    # Publish the handle before the thread begins running.
    modelScannerThread = scanner
    scanner.start()
# Module-level ModelManager instance; loadMembers() registers models on it.
modelManager = ModelManager()