# Copyright [1999-2015] Wellcome Trust Sanger Institute and the EMBL-European Bioinformatics Institute
# Copyright [2016-2019] EMBL-European Bioinformatics Institute
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import eHive.Params

import os
import sys
import json
import numbers
import unittest
import warnings
import traceback

# Version of this module; sent to the parent process in the initial
# 'VERSION' message of the life-cycle (see BaseRunnable).
__version__ = "4.0"

__doc__ = """
This module mainly implements python's counterpart of GuestProcess. Read
the latter for more information about the JSON protocol used to communicate.
"""

class Job(object):
    """Dummy class to hold job-related information.

    BaseRunnable creates one instance per job and sets plain attributes on it
    (dbID, input_id, retry_count, autoflow, lethal_for_worker,
    transient_error); runnables read and update them via self.input_job.
    """
    pass


class CompleteEarlyException(Exception):
    """Can be raised by a derived class of BaseRunnable to indicate an early
    successful termination. Its first argument, if any, is logged as a
    non-error warning."""
    pass


class JobFailedException(Exception):
    """Can be raised by a derived class of BaseRunnable to indicate an early
    unsuccessful termination."""
    pass


class HiveJSONMessageException(Exception):
    """Raised when we could not parse the JSON message coming from GuestProcess,
    or when it is not the expected response."""
    pass


class LostHiveConnectionException(Exception):
    """Raised when the process has lost the communication pipe with the Perl side."""
    pass


class BaseRunnable(object):
    """This is the counterpart of GuestProcess. Note that most of the methods
    are private to be hidden in the derived classes.

    This class can be used as a base-class for people to redefine fetch_input(),
    run() and/or write_output() (and/or pre_cleanup(), post_healthcheck(),
    post_cleanup()).
    Jobs are supposed to raise CompleteEarlyException in case they complete
    before reaching the end of their life-cycle. They can also raise
    JobFailedException (or any other exception) to indicate a failure.
    """

    # Private BaseRunnable interface
    #################################

    def __init__(self, read_fileno, write_fileno, debug):
        """Open the pipes to the parent process and run the whole life-cycle.

        read_fileno / write_fileno are the file descriptors used to receive
        messages from, and send messages to, the parent (GuestProcess).
        The constructor only returns once the parent sends a message without
        'input_job', i.e. when the wrapper is asked to stop.
        """
        # We need the binary mode to disable the buffering
        self.__read_pipe = os.fdopen(read_fileno, mode='rb', buffering=0)
        self.__write_pipe = os.fdopen(write_fileno, mode='wb', buffering=0)
        self.__pid = os.getpid()
        self.debug = debug
        self.__process_life_cycle()

    def __print_debug(self, *args):
        """Print a debug line on stderr, prefixed with the process id, when debug > 1"""
        if self.debug > 1:
            print("PYTHON {0}".format(self.__pid), *args, file=sys.stderr)

    def __write_line(self, text):
        """Send one line of text (a JSON document) to the parent process.

        Raises LostHiveConnectionException if the pipe is broken.
        """
        try:
            self.__write_pipe.write((text + "\n").encode('utf-8'))
        except BrokenPipeError:
            raise LostHiveConnectionException("__write_pipe") from None

    def __send_message(self, event, content):
        """Serialize the message in JSON and send it to the parent process"""
        def default_json_encoder(o):
            # Objects json cannot handle are replaced by a placeholder string
            self.__print_debug("Cannot serialize {0} (type {1}) in JSON".format(o, type(o)))
            return 'UNSERIALIZABLE OBJECT'
        # json.dumps escapes non-ASCII characters by default (ensure_ascii=True),
        # so the line written to the pipe is plain ASCII
        j = json.dumps({'event': event, 'content': content}, indent=None, default=default_json_encoder)
        self.__print_debug('__send_message:', j)
        self.__write_line(j)

    def __send_response(self, response):
        """Send a response message ({"response": "..."}) to the parent process"""
        self.__print_debug('__send_response:', response)
        # json.dumps properly escapes quotes/backslashes, unlike string concatenation
        self.__write_line(json.dumps({'response': str(response)}))

    def __read_message(self):
        """Read one line from the parent and parse it as JSON.

        Raises LostHiveConnectionException if the pipe is closed or broken,
        HiveJSONMessageException if the line cannot be decoded/parsed.
        """
        self.__print_debug("__read_message ...")
        try:
            line = self.__read_pipe.readline()
        except BrokenPipeError:
            raise LostHiveConnectionException("__read_pipe") from None
        if not line:
            # readline() returns b'' at end-of-file: the parent closed its end
            raise LostHiveConnectionException("__read_pipe")
        try:
            text = line.decode()
            self.__print_debug(" ... -> ", text.rstrip("\n"))
            return json.loads(text)
        except ValueError as e:
            # HiveJSONMessageException is a more meaningful name than ValueError
            raise HiveJSONMessageException from e

    def __send_message_and_wait_for_OK(self, event, content):
        """Send a message and expect the response to be 'OK'.

        Raises HiveJSONMessageException for any other (or malformed) response.
        """
        self.__send_message(event, content)
        response = self.__read_message()
        if not isinstance(response, dict) or response.get('response') != 'OK':
            raise HiveJSONMessageException("Received '{0}' instead of OK".format(response))

    def __process_life_cycle(self):
        """Simple loop: wait for job parameters, do the job's life-cycle"""
        self.__send_message_and_wait_for_OK('VERSION', __version__)
        self.__send_message_and_wait_for_OK('PARAM_DEFAULTS', self.param_defaults())
        self.__created_worker_temp_directory = None
        while True:
            self.__print_debug("waiting for instructions")
            config = self.__read_message()
            if 'input_job' not in config:
                self.__print_debug("no params, this is the end of the wrapper")
                return
            self.__job_life_cycle(config)

    def __job_life_cycle(self, config):
        """Job's life-cycle. See GuestProcess for a description of the protocol to communicate with the parent"""
        self.__print_debug("__life_cycle")

        # Params (built with the debug level in force before this job's config is applied)
        self.__params = eHive.Params.ParamContainer(config['input_job']['parameters'], self.debug > 1)

        # Job attributes
        self.input_job = Job()
        for x in ['dbID', 'input_id', 'retry_count']:
            setattr(self.input_job, x, config['input_job'][x])
        self.input_job.autoflow = True
        self.input_job.lethal_for_worker = False
        self.input_job.transient_error = True

        # Worker attributes
        self.debug = config['debug']

        # Which methods should be run
        steps = ['fetch_input', 'run']
        if self.input_job.retry_count > 0:
            steps.insert(0, 'pre_cleanup')
        if config['execute_writes']:
            steps.extend(['write_output', 'post_healthcheck'])
        self.__print_debug("steps to run:", steps)
        self.__send_response('OK')

        # The actual life-cycle
        died_somewhere = False
        try:
            for s in steps:
                self.__run_method_if_exists(s)
        except CompleteEarlyException as e:
            self.warning(e.args[0] if len(e.args) else repr(e), False)
        except LostHiveConnectionException:
            # Nothing we can do, let's just exit
            raise
        except BaseException:
            # Deliberately broad: whatever went wrong in the job is reported
            # to the parent as an error and the job is marked as failed
            died_somewhere = True
            self.warning(self.__traceback(2), True)

        try:
            self.__run_method_if_exists('post_cleanup')
        except LostHiveConnectionException:
            # Nothing we can do, let's just exit
            raise
        except BaseException:
            died_somewhere = True
            self.warning(self.__traceback(2), True)

        job_end_structure = {
            'complete': not died_somewhere,
            'job': {},
            'params': {
                'substituted': self.__params.param_hash,
                'unsubstituted': self.__params.unsubstituted_param_hash,
            },
        }
        for x in ['autoflow', 'lethal_for_worker', 'transient_error']:
            job_end_structure['job'][x] = getattr(self.input_job, x)
        self.__send_message_and_wait_for_OK('JOB_END', job_end_structure)

    def __run_method_if_exists(self, method):
        """method is one of "pre_cleanup", "fetch_input", "run", "write_output",
        "post_healthcheck", "post_cleanup".
        We only call the method if it exists, to save a trip to the database."""
        if hasattr(self, method):
            self.__send_message_and_wait_for_OK('JOB_STATUS_UPDATE', method)
            getattr(self, method)()

    def __traceback(self, skipped_traces):
        """Format the current exception, dropping the first "skipped_traces"
        frames of the stack trace (the eHive part)"""
        (etype, value, tb) = sys.exc_info()
        s1 = traceback.format_exception_only(etype, value)
        l = traceback.extract_tb(tb)[skipped_traces:]
        s2 = traceback.format_list(l)
        return "".join(s1 + s2)

    # Public BaseRunnable interface
    ################################

    def warning(self, message, is_error=False):
        """Store a message in the log_message table with is_error indicating whether the warning is actually an error or not"""
        self.__send_message_and_wait_for_OK('WARNING', {'message': message, 'is_error': is_error})

    def dataflow(self, output_ids, branch_name_or_code=1):
        """Dataflows the output_id(s) on a given branch (default 1). Returns whatever the Perl side returns.

        An explicit dataflow on branch 1 turns off the job's autoflow.
        """
        if branch_name_or_code == 1:
            # Must be set on input_job: that is the flag reported in JOB_END
            self.input_job.autoflow = False
        self.__send_message('DATAFLOW', {
            'output_ids': output_ids,
            'branch_name_or_code': branch_name_or_code,
            'params': {
                'substituted': self.__params.param_hash,
                'unsubstituted': self.__params.unsubstituted_param_hash,
            },
        })
        return self.__read_message()['response']

    def worker_temp_directory(self):
        """Returns the full path of the temporary directory created by the worker.

        The path is requested from the parent on the first call and cached.
        """
        if self.__created_worker_temp_directory is None:
            self.__send_message('WORKER_TEMP_DIRECTORY', None)
            self.__created_worker_temp_directory = self.__read_message()['response']
        return self.__created_worker_temp_directory

    # Param interface
    ##################

    def param_defaults(self):
        """Returns the defaults parameters for this runnable"""
        return {}

    def param_required(self, param_name):
        """Returns the value of the parameter "param_name" or raises an exception
        if anything wrong happens or the value is None. The exception is
        marked as non-transient."""
        t = self.input_job.transient_error
        # Any exception raised below leaves transient_error at False on purpose
        self.input_job.transient_error = False
        v = self.__params.get_param(param_name)
        if v is None:
            raise eHive.Params.NullParamException(param_name)
        self.input_job.transient_error = t
        return v

    def param(self, param_name, *args):
        """When called as a setter: sets the value of the parameter "param_name".
        When called as a getter: returns the value of the parameter "param_name".
        It does not raise an exception if the parameter (or another one in the
        substitution stack) is undefined: it warns and returns None instead."""
        # As a setter
        if args:
            return self.__params.set_param(param_name, args[0])

        # As a getter
        try:
            return self.__params.get_param(param_name)
        except KeyError as e:
            warnings.warn("parameter '{0}' cannot be initialized because {1} is missing !".format(param_name, e), eHive.Params.ParamWarning, 2)
            return None

    def param_exists(self, param_name):
        """Returns True if the parameter exists and can be successfully
        substituted, None if the substitution fails, False if it is missing.
        Exceptions other than KeyError raised during substitution propagate."""
        if not self.__params.has_param(param_name):
            return False
        try:
            self.__params.get_param(param_name)
            return True
        except KeyError:
            return None

    def param_is_defined(self, param_name):
        """Returns True if the parameter exists and can be successfully
        substituted to a defined value, None if the substitution fails,
        False if it is missing or evaluates as None"""
        e = self.param_exists(param_name)
        if not e:
            # False or None
            return e
        try:
            return self.__params.get_param(param_name) is not None
        except KeyError:
            return False


class RunnableTest(unittest.TestCase):
    """Unit tests of the parameter interface of BaseRunnable (no pipes involved)."""

    def test_job_param(self):
        class FakeRunnableWithParams(BaseRunnable):
            # Bypasses BaseRunnable.__init__ (which would open pipes and run
            # the life-cycle) and only sets what the param methods need
            def __init__(self, d):
                self._BaseRunnable__params = eHive.Params.ParamContainer(d)
                self.input_job = Job()
                self.input_job.transient_error = True
        j = FakeRunnableWithParams({
            'a': 3,
            'b': None,
            'c': '#other#',
            'e': '#e#'
        })

        # param_exists
        self.assertIs( j.param_exists('a'), True, '"a" exists' )
        self.assertIs( j.param_exists('b'), True, '"b" exists' )
        self.assertIs( j.param_exists('c'), None, '"c"\'s existence is unclear' )
        self.assertIs( j.param_exists('d'), False, '"d" doesn\'t exist' )
        with self.assertRaises(eHive.Params.ParamInfiniteLoopException):
            j.param_exists('e')

        # param_is_defined
        self.assertIs( j.param_is_defined('a'), True, '"a" is defined' )
        self.assertIs( j.param_is_defined('b'), False, '"b" is not defined' )
        self.assertIs( j.param_is_defined('c'), None, '"c"\'s defined-ness is unclear' )
        self.assertIs( j.param_is_defined('d'), False, '"d" is not defined (it doesn\'t exist)' )
        with self.assertRaises(eHive.Params.ParamInfiniteLoopException):
            j.param_is_defined('e')

        # param
        # assertEqual, not assertIs: int identity is a CPython caching detail
        self.assertEqual( j.param('a'), 3, '"a" is 3' )
        self.assertIs( j.param('b'), None, '"b" is None' )
        with self.assertWarns(eHive.Params.ParamWarning):
            self.assertIs( j.param('c'), None, '"c"\'s value is unclear' )
        with self.assertWarns(eHive.Params.ParamWarning):
            self.assertIs( j.param('d'), None, '"d" is not defined (it doesn\'t exist)' )
        with self.assertRaises(eHive.Params.ParamInfiniteLoopException):
            j.param('e')

        # param_required
        self.assertEqual( j.param_required('a'), 3, '"a" is 3' )
        self.assertIs( j.input_job.transient_error, True, 'transient_error is restored on success' )
        with self.assertRaises(eHive.Params.NullParamException):
            j.param_required('b')
        self.assertIs( j.input_job.transient_error, False, 'a failed param_required is non-transient' )
        with self.assertRaises(KeyError):
            j.param_required('c')
        with self.assertRaises(KeyError):
            j.param_required('d')
        with self.assertRaises(eHive.Params.ParamInfiniteLoopException):
            j.param_required('e')