Created July 3, 2014 13:42
-
-
Save gdementen/8980d152f7627a6ec3bf to your computer and use it in GitHub Desktop.
Alexis Eidelman — "Default values" patch for LIAM2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src_liam/data.py b/src_liam/data.py | |
index 31fe79b..44f11b6 100644 | |
--- a/src_liam/data.py | |
+++ b/src_liam/data.py | |
@@ -35,7 +35,7 @@ def append_carray_to_table(array, table, numlines=None, buffersize=10 * MB): | |
class ColumnArray(object): | |
- def __init__(self, array=None): | |
+ def __init__(self, array=None, default_values=None): | |
columns = {} | |
if array is not None: | |
if isinstance(array, (np.ndarray, ColumnArray)): | |
@@ -43,18 +43,25 @@ class ColumnArray(object): | |
columns[name] = array[name].copy() | |
self.dtype = array.dtype | |
self.columns = columns | |
+ if isinstance(array, ColumnArray): | |
+ self.dval = array.dval | |
+ else: | |
+ self.dval = [] | |
elif isinstance(array, list): | |
for name, column in array: | |
columns[name] = column | |
self.dtype = np.dtype([(name, column.dtype) | |
for name, column in array]) | |
self.columns = columns | |
+ self.dval = [] | |
else: | |
#TODO: make a property instead? | |
self.dtype = None | |
self.columns = columns | |
+ self.dval = [] | |
else: | |
self.dtype = None | |
+ self.dval = [] | |
self.columns = columns | |
def __getitem__(self, key): | |
@@ -163,7 +170,7 @@ class ColumnArray(object): | |
return ca | |
@classmethod | |
- def from_table(cls, table, start=0, stop=None, buffersize=10 * 2 ** 20): | |
+ def from_table(cls, table, start=0, stop=None, default_values={}, buffersize=10 * 2 ** 20): | |
# reading a table one column at a time is very slow, this is why this | |
# function is even necessary | |
if stop is None: | |
@@ -172,6 +179,7 @@ class ColumnArray(object): | |
max_buffer_rows = buffersize // dtype.itemsize | |
numlines = stop - start | |
ca = cls.empty(numlines, dtype) | |
+ ca.dval = default_values | |
buffer_rows = min(numlines, max_buffer_rows) | |
# chunk = np.empty(buffer_rows, dtype=dtype) | |
array_start = 0 | |
@@ -220,10 +228,15 @@ class ColumnArray(object): | |
output_dtype = np.dtype(output_fields) | |
output_names = set(output_dtype.names) | |
input_names = set(self.dtype.names) | |
+ default_values = self.dval | |
length = len(self) | |
# add missing fields | |
for name in output_names - input_names: | |
- self[name] = get_missing_vector(length, output_dtype[name]) | |
+ if name in default_values: | |
+ self[name] = np.empty(length, dtype=output_dtype[name]) | |
+ self[name].fill(default_values[name]) | |
+ else: | |
+ self[name] = get_missing_vector(length, output_dtype[name]) | |
# delete extra fields | |
for name in input_names - output_names: | |
del self[name] | |
@@ -274,7 +287,7 @@ def assertValidType(array, wanted_type, allowed_missing=None, context=None): | |
wanted_type.__name__)) | |
-def add_and_drop_fields(array, output_fields, missing_fields={}, output_array=None): | |
+def add_and_drop_fields(array, output_fields, default_values={}, output_array=None): | |
output_dtype = np.dtype(output_fields) | |
output_names = set(output_dtype.names) | |
input_names = set(array.dtype.names) | |
@@ -283,8 +296,8 @@ def add_and_drop_fields(array, output_fields, missing_fields={}, output_array=No | |
if output_array is None: | |
output_array = np.empty(len(array), dtype=output_dtype) | |
for fname in all_missing_fields: | |
- if fname in missing_fields.keys(): | |
- output_array[fname] = missing_fields[fname] | |
+ if fname in default_values: | |
+ output_array[fname] = default_values[fname] | |
else: | |
output_array[fname] = get_missing_value(output_array[fname]) | |
else: | |
@@ -415,7 +428,7 @@ def appendTable(input_table, output_table, chunksize=10000, condition=None, | |
num_chunks += 1 | |
if output_fields is not None: | |
- expanded_data = np.empty(chunksize, dtype=np.dtype(output_fields)) | |
+ expanded_data = ColumnArray.empty(chunksize, dtype=np.dtype(output_fields)) | |
expanded_data[:] = get_missing_record(expanded_data) | |
def copyChunk(chunk_idx, chunk_num): | |
@@ -430,11 +443,11 @@ def appendTable(input_table, output_table, chunksize=10000, condition=None, | |
if output_fields is not None: | |
# use our pre-allocated buffer (except for the last chunk) | |
if len(input_data) == len(expanded_data): | |
- missing_fields = {} | |
+ default_values = {} | |
output_data = add_and_drop_fields(input_data, output_fields, | |
- missing_fields, expanded_data) | |
+ default_values, expanded_data) | |
else: | |
- output_data = add_and_drop_fields(input_data, output_fields) #, missing_fields | |
+ output_data = add_and_drop_fields(input_data, output_fields, default_values) | |
else: | |
output_data = input_data | |
@@ -472,7 +485,7 @@ def copyTable(input_table, output_node, output_fields=None, | |
# 1) all arrays have the same columns | |
# 2) we have id_to_rownum already computed for each array | |
def buildArrayForPeriod(input_table, output_fields, input_rows, | |
- input_index, start_period, missing_fields={}): | |
+ input_index, start_period, default_values={}): | |
periods_before = [p for p in input_rows.iterkeys() if p <= start_period] | |
if not periods_before: | |
id_to_rownum = np.empty(0, dtype=int) | |
@@ -495,7 +508,7 @@ def buildArrayForPeriod(input_table, output_fields, input_rows, | |
# if all individuals are present in the target period, we are done already! | |
if np.array_equal(present_in_period, is_present): | |
start, stop = input_rows[target_period] | |
- input_array = ColumnArray.from_table(input_table, start, stop) | |
+ input_array = ColumnArray.from_table(input_table, start, stop, default_values) | |
input_array.add_and_drop_fields(output_fields) | |
return input_array, period_id_to_rownum | |
@@ -807,9 +820,8 @@ class H5Data(DataSource): | |
# would be brought back to life. In conclusion, it should be | |
# optional. | |
entity.array, entity.id_to_rownum = \ | |
- buildArrayForPeriod(table.table, entity.fields, | |
- entity.input_rows, | |
- entity.input_index, start_period) | |
+ buildArrayForPeriod(table.table, entity.fields, entity.input_rows, | |
+ entity.input_index, start_period, entity.default_values) | |
assert isinstance(entity.array, ColumnArray) | |
entity.array_period = start_period | |
print("done (%s elapsed)." % time2str(time.time() - start_time)) | |
diff --git a/src_liam/entities.py b/src_liam/entities.py | |
index b705b0d..62b40d5 100644 | |
--- a/src_liam/entities.py | |
+++ b/src_liam/entities.py | |
@@ -10,7 +10,7 @@ import config | |
from context import EntityContext, context_length | |
from data import mergeArrays, get_fields, ColumnArray | |
from expr import (Variable, GlobalVariable, GlobalTable, GlobalArray, | |
- expr_eval, get_missing_value) | |
+ expr_eval, missing_values, get_missing_value) | |
from exprparser import parse | |
from process import Assignment, Compute, Process, ProcessGroup | |
from registry import entity_registry | |
@@ -32,7 +32,7 @@ class Entity(object): | |
''' | |
fields is a list of tuple (name, type, options) | |
''' | |
- def __init__(self, name, fields, missing_fields=None, links=None, | |
+ def __init__(self, name, fields, missing_fields=None, default_values={}, links=None, | |
macro_strings=None, process_strings=None, | |
on_align_overflow='carry'): | |
self.name = name | |
@@ -64,6 +64,7 @@ class Entity(object): | |
# another solution is to use a Field class | |
# seems like the better long term solution | |
self.missing_fields = missing_fields | |
+ self.default_values = default_values | |
self.period_individual_fnames = [name for name, _ in fields] | |
self.links = links | |
@@ -114,15 +115,29 @@ class Entity(object): | |
fields = [] | |
missing_fields = [] | |
+ default_values = {} | |
for name, fielddef in fields_def: | |
if isinstance(fielddef, dict): | |
strtype = fielddef['type'] | |
+ import pdb | |
if not fielddef.get('initialdata', True): | |
missing_fields.append(name) | |
+ | |
+ fieldtype = field_str_to_type(strtype, "field '%s'" % name) | |
+ dflt_type = missing_values[fieldtype] | |
+ default = fielddef.get('default', dflt_type) | |
+ if fieldtype != type(default): | |
+ raise Exception("The default value given to %s is %s" | |
+ " but %s was expected" %(name, type(default), strtype) ) | |
+ | |
else: | |
strtype = fielddef | |
- fields.append((name, | |
- field_str_to_type(strtype, "field '%s'" % name))) | |
+ fieldtype = field_str_to_type(strtype, "field '%s'" % name) | |
+ default = missing_values[fieldtype] | |
+ | |
+ fields.append((name, fieldtype)) | |
+ default_values[name] = default | |
+ | |
link_defs = entity_def.get('links', {}) | |
str2class = {'one2many': One2Many, 'many2one': Many2One} | |
@@ -131,13 +146,13 @@ class Entity(object): | |
for name, l in link_defs.iteritems()) | |
#TODO: add option for on_align_overflow | |
- return Entity(ent_name, fields, missing_fields, links, | |
+ return Entity(ent_name, fields, missing_fields, default_values, links, | |
entity_def.get('macros', {}), | |
entity_def.get('processes', {})) | |
@classmethod | |
def from_table(cls, table): | |
- return Entity(table.name, get_fields(table), missing_fields=[], | |
+ return Entity(table.name, get_fields(table), missing_fields=[], default_values={}, | |
links={}, macro_strings={}, process_strings={}) | |
@staticmethod | |
@@ -363,6 +378,8 @@ class Entity(object): | |
# but the usual case (in retro) is that self.array is a superset of | |
# input_array, in which case mergeArrays returns a ColumnArray | |
if not isinstance(self.array, ColumnArray): | |
+ import pdb | |
+ pdb.set_trace() | |
self.array = ColumnArray(self.array) | |
def store_period_data(self, period): | |
diff --git a/src_liam/expr.py b/src_liam/expr.py | |
index 9ec2fcf..816f398 100644 | |
--- a/src_liam/expr.py | |
+++ b/src_liam/expr.py | |
@@ -73,7 +73,10 @@ def get_missing_vector(num, dtype): | |
def get_missing_record(array): | |
row = np.empty(1, dtype=array.dtype) | |
for fname in array.dtype.names: | |
- row[fname] = get_missing_value(row[fname]) | |
+ if fname in array.dval: | |
+ row[fname] = array.dval[fname] | |
+ else: | |
+ row[fname] = get_missing_value(row[fname]) | |
return row | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.