Process.py 6.32 KB
Newer Older
1

2
import eHive.Params
3
4
5
6
7
8
9
10
11
12
13

import os
import sys
import json
import numbers
import warnings
import traceback

class Job:
    """Plain attribute holder describing the job currently being processed.

    Instances start empty; the attributes are attached dynamically by the
    code that creates them.
    """

Matthieu Muffato's avatar
cleanup  
Matthieu Muffato committed
14
class CompleteEarlyException(Exception):
    """Raised by a runnable to stop the current job early without failing it.

    The job life-cycle reports the first argument of the exception (or its
    ``repr`` when there is none) as a non-error warning and still considers
    the job complete.
    """
    pass
class HiveJSONMessageException(Exception):
    """Signals a problem in the JSON message exchange over the pipes.

    Raised when an incoming message cannot be decoded, or when a message
    that expects an 'OK' acknowledgement receives something else.
    """


20
class BaseRunnable(object):
    """Base class of runnables that are driven through a pair of pipes.

    The object exchanges newline-terminated JSON messages with the process
    at the other end of the pipes.  Outgoing messages have the shape
    ``{'event': ..., 'content': ...}``; the events emitted by this class are
    PARAM_DEFAULTS, JOB_STATUS_UPDATE, WARNING, DATAFLOW,
    WORKER_TEMP_DIRECTORY and JOB_END.

    Subclasses may define any of ``pre_cleanup``, ``fetch_input``, ``run``,
    ``write_output`` and ``post_cleanup`` (each is only called when it
    exists), override ``param_defaults`` and optionally provide
    ``worker_temp_directory_name``.
    """

    # Private BaseRunnable interface
    #################################

    def __init__(self, read_fileno, write_fileno):
        """Open both pipes and immediately run the whole process life-cycle.

        Note that the constructor only returns once the life-cycle is over.

        Args:
            read_fileno: file descriptor the incoming messages are read from.
            write_fileno: file descriptor the outgoing messages are written to.
        """
        # We need the binary mode to disable the buffering
        self.read_pipe = os.fdopen(read_fileno, mode='rb', buffering=0)
        self.write_pipe = os.fdopen(write_fileno, mode='wb', buffering=0)
        self.pid = os.getpid()
        self.__process_life_cycle()

    def __print_debug(self, *args):
        """Print *args* on stderr, prefixed with "PYTHON <pid>"."""
        print("PYTHON {0}".format(self.pid), *args, file=sys.stderr)

    def __send_message(self, event, content):
        """Serialise ``{'event': event, 'content': content}`` and write it as one line.

        Objects that JSON cannot serialise are reported on stderr and
        replaced with the string 'UNSERIALIZABLE OBJECT' instead of
        aborting the message.
        """
        def default_json_encoder(o):
            self.__print_debug("Cannot serialize {0} (type {1}) in JSON".format(o, type(o)))
            return 'UNSERIALIZABLE OBJECT'

        j = json.dumps({'event': event, 'content': content}, indent=None, default=default_json_encoder)
        self.__print_debug('__send_message:', j)
        # One message per line: the newline is the message terminator
        self.write_pipe.write(bytes(j + "\n", 'utf-8'))

    def __read_message(self):
        """Read one line from the input pipe and return the decoded JSON value.

        Raises:
            HiveJSONMessageException: if the line is not valid JSON (this
                includes the empty line returned at end-of-file) or cannot
                be decoded as text.
        """
        try:
            self.__print_debug("__read_message ...")
            line = self.read_pipe.readline()
            # Strip the trailing newline for the debug output only
            self.__print_debug(" ... -> ", line[:-1].decode())
            return json.loads(line.decode())
        except ValueError as e:
            # Both JSON and Unicode decoding errors derive from ValueError
            raise HiveJSONMessageException from e

    def __send_message_and_wait_for_OK(self, event, content):
        """Send a message and block until it has been acknowledged with 'OK'.

        Raises:
            HiveJSONMessageException: if the reply is not a JSON object
                whose 'response' entry is 'OK'.
        """
        self.__send_message(event, content)
        response = self.__read_message()
        # A reply without a 'response' key (or that is not an object at all) is
        # a protocol error too, so report it as such rather than leaking a
        # KeyError / TypeError.
        if not isinstance(response, dict) or response.get('response') != 'OK':
            raise HiveJSONMessageException("Received '{0}' instead of OK".format(response))

    def __process_life_cycle(self):
        """Announce the parameter defaults, then run jobs until told to stop.

        The loop ends on the first incoming message that has no 'input_job'
        entry.
        """
        self.__send_message_and_wait_for_OK('PARAM_DEFAULTS', self.param_defaults())
        while True:
            self.__print_debug("waiting for instructions")
            config = self.__read_message()
            if 'input_job' not in config:
                self.__print_debug("no params")
                return
            self.__job_life_cycle(config)

    def __job_life_cycle(self, config):
        """Run a single job described by *config* and report its outcome.

        *config* must provide 'input_job' (with 'parameters', 'dbID',
        'input_id' and 'retry_count'), 'debug' and 'execute_writes'.
        """
        self.__print_debug("__life_cycle")

        # Params
        self.p = eHive.Params.ParamContainer(config['input_job']['parameters'])

        # Job attributes
        self.input_job = Job()
        for x in ['dbID', 'input_id', 'retry_count']:
            setattr(self.input_job, x, config['input_job'][x])
        self.input_job.autoflow = True
        self.input_job.lethal_for_worker = False
        self.input_job.transient_error = True

        # Worker attributes
        self.debug = config['debug']

        # Which methods should be run
        steps = ['fetch_input', 'run']
        if self.input_job.retry_count > 0:
            # Only jobs that have already been attempted need a pre-cleanup
            steps.insert(0, 'pre_cleanup')
        if config['execute_writes']:
            steps.append('write_output')
        self.__print_debug("steps to run:", steps)

        # The actual life-cycle
        died_somewhere = False
        try:
            for s in steps:
                self.__run_method_if_exists(s)
        except CompleteEarlyException as e:
            # Not a failure: only reported as a plain (non-error) warning
            self.warning(e.args[0] if e.args else repr(e), False)
        except BaseException:
            # Deliberate catch-all: whatever the runnable raised must be
            # reported in JOB_END instead of killing this process.
            died_somewhere = True
            self.warning(self.__traceback(2), True)

        # post_cleanup is attempted even when one of the steps above failed
        try:
            self.__run_method_if_exists('post_cleanup')
        except BaseException:
            died_somewhere = True
            self.warning(self.__traceback(2), True)

        job_end_structure = {
            'complete': not died_somewhere,
            'job': {},
            'params': {
                'substituted': self.p._param_hash,
                'unsubstituted': self.p._unsubstituted_param_hash,
            },
        }
        for x in ['autoflow', 'lethal_for_worker', 'transient_error']:
            job_end_structure['job'][x] = getattr(self.input_job, x)
        self.__send_message_and_wait_for_OK('JOB_END', job_end_structure)

    def __run_method_if_exists(self, method):
        """Announce and call the method named *method*, if this object has it."""
        if hasattr(self, method):
            self.__send_message_and_wait_for_OK('JOB_STATUS_UPDATE', method)
            getattr(self, method)()

    def __traceback(self, skipped_traces):
        """Format the exception currently being handled as a single string.

        The exception line comes first, followed by the stack entries minus
        the first *skipped_traces* ones.
        """
        (etype, value, tb) = sys.exc_info()
        s1 = traceback.format_exception_only(etype, value)
        frames = traceback.extract_tb(tb)[skipped_traces:]
        s2 = traceback.format_list(frames)
        return "".join(s1 + s2)

    # Public BaseRunnable interface
    ################################

    def warning(self, message, is_error = False):
        """Send a WARNING message (flagged as an error if *is_error*) and wait for the 'OK'."""
        self.__send_message_and_wait_for_OK('WARNING', {'message': message, 'is_error': is_error})

    def dataflow(self, output_ids, branch_name_or_code = 1):
        """Send a DATAFLOW message and return the reply that comes back.

        The message carries *output_ids*, *branch_name_or_code* and the
        current substituted / unsubstituted parameter hashes.
        """
        self.__send_message('DATAFLOW', {
            'output_ids': output_ids,
            'branch_name_or_code': branch_name_or_code,
            'params': {
                'substituted': self.p._param_hash,
                'unsubstituted': self.p._unsubstituted_param_hash,
            },
        })
        return self.__read_message()

    def worker_temp_directory(self):
        """Return the reply to a WORKER_TEMP_DIRECTORY request, asking only once.

        The request content is the value of ``worker_temp_directory_name()``
        when the object defines that method, and None otherwise.  The reply
        is cached on the instance, so later calls do not use the pipes.
        """
        if not hasattr(self, '_created_worker_temp_directory'):
            template_name = self.worker_temp_directory_name() if hasattr(self, 'worker_temp_directory_name') else None
            self.__send_message('WORKER_TEMP_DIRECTORY', template_name)
            self._created_worker_temp_directory = self.__read_message()
        return self._created_worker_temp_directory

    # Param interface
    ##################

    def param_defaults(self):
        """Return the default parameters of the runnable (none in the base class)."""
        return {}

    def param_required(self, param_name):
        """Return the value of *param_name*, flagging a failure as non-transient.

        ``transient_error`` is switched off for the duration of the lookup
        and only restored when the lookup succeeds.  This is deliberate (no
        ``finally``): if the lookup raises, the flag must stay False while
        the exception propagates.
        """
        t = self.input_job.transient_error
        self.input_job.transient_error = False
        v = self.p.get_param(param_name)
        self.input_job.transient_error = t
        return v

    def param(self, param_name, *args):
        """Get *param_name*, or set it when a value is passed as extra argument.

        As a getter, a KeyError raised by the lookup is turned into a
        ``ParamWarning`` and None is returned.
        """
        # As a setter
        if args:
            return self.p.set_param(param_name, args[0])

        # As a getter
        try:
            return self.p.get_param(param_name)
        except KeyError as e:
            # The module is imported as "eHive.Params", so the warning class
            # has to be reached through its fully-qualified name.
            warnings.warn("parameter '{0}' cannot be initialized because {1} is not defined !\n".format(param_name, e), eHive.Params.ParamWarning, 2)
            return None

    def param_exists(self, param_name):
        """Return whether the parameter container knows *param_name*."""
        return self.p.has_param(param_name)

    def param_is_defined(self, param_name):
        """Return True if *param_name* exists, can be retrieved and is not None."""
        if not self.param_exists(param_name):
            return False
        try:
            return self.p.get_param(param_name) is not None
        except KeyError:
            return False