Created January 13, 2022 14:44
-
-
Save mducle/ebeae8584755008fdd4df12e3d5de58a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/+euphonic/CoherentCrystal.m b/+euphonic/CoherentCrystal.m | |
| index aebc44d..d39d186 100644 | |
| --- a/+euphonic/CoherentCrystal.m | |
| +++ b/+euphonic/CoherentCrystal.m | |
| @@ -11,6 +11,9 @@ classdef CoherentCrystal < light_python_wrapper.light_python_wrapper | |
| is_initialised = euphonic_on(); | |
| is_redirected = light_python_wrapper.light_python_wrapper.redirect_python_warnings(); | |
| end | |
| + properties (Constant) | |
| + horace_disp = horace_disp_private(); | |
| + end | |
| methods | |
| % Constructor | |
| function obj = CoherentCrystal(varargin) | |
| @@ -23,58 +26,84 @@ classdef CoherentCrystal < light_python_wrapper.light_python_wrapper | |
| else | |
| obj.pyobj = py.getattr(eu, 'CoherentCrystal'); | |
| end | |
| - obj.overrides = {'horace_disp'}; | |
| + obj.overrides = {{'horace_disp', @horace_disp_internal, 'obj'}}; | |
| end | |
| - function out = horace_disp(self, qh, qk, ql, pars, varargin) | |
| - % Overrides Python function to do chunking in Matlab to print messages | |
| + end | |
| +end | |
| - args = {}; | |
| - kwargs = pyargs(); | |
| - if ~isempty(varargin) | |
| - all_args = [pars, varargin]; | |
| - % Find first occurence of str/char - assume everything before | |
| - % is positional args, everything after is kwargs | |
| - is_str = cellfun(@isstring, all_args); | |
| - if ~any(is_str) | |
| - is_str = cellfun(@ischar, all_args); | |
| - end | |
| - str_idx = find(is_str==1); | |
| - if isempty(str_idx) | |
| - % No strings - all positional | |
| - args = all_args; | |
| - else | |
| - args = all_args(1:str_idx(1) - 1); | |
| - kwargs = pyargs(all_args{str_idx(1):end}); | |
| - end | |
| - else | |
| - % If no varargin, assume all positional arguments | |
| - args = num2cell(pars); | |
| - end | |
| +function out = horace_disp_private() | |
| + % Wraps the Python method euphonic_sqw_models.CoherentCrystal.horace_disp | |
| + % (the module is imported once and kept in a persistent variable) in a | |
| + % generic_python_wrapper whose call is overridden by horace_disp_internal. | |
| + % NOTE(review): the 'obj' flag passes this wrapper (whose pyobj is horace_disp | |
| + % taken from the Python class, not an instance) as `self` - confirm intended. | |
| + persistent eu_mod | |
| + if isempty(eu_mod), eu_mod = py.importlib.import_module('euphonic_sqw_models'); end | |
| + out = light_python_wrapper.generic_python_wrapper(eu_mod.CoherentCrystal.horace_disp, ... | |
| + 'overrides', {{'constructor__', @horace_disp_internal, 'obj'}}); | |
| +end | |
| - horace_disp = py.getattr(self.pyobj, 'horace_disp'); | |
| - chunk_size = double(self.pyobj.chunk); | |
| - lqh = numel(qh); | |
| - if self.pyobj.verbose && chunk_size > 0 | |
| - self.pyobj.chunk = 0; | |
| - nchunk = ceil(lqh / chunk_size); | |
| - pyout = {}; | |
| - for ii = 1:nchunk | |
| - qi = (ii-1)*chunk_size + 1; | |
| - qf = min([ii*chunk_size lqh]); | |
| - fprintf('Using Euphonic to interpolate for q-points %d:%d out of %d\n', qi, qf, lqh); | |
| - pyout = cat(1, pyout, light_python_wrapper.p2m(horace_disp(qh(qi:qf), qk(qi:qf), ql(qi:qf), args{:}, kwargs))); | |
| - end | |
| - self.pyobj.chunk = chunk_size; | |
| - for jj = 1:2 | |
| - tmp = cat(1, pyout{:,jj}); | |
| - for ii = 1:size(tmp,2) | |
| - out{jj}{ii} = cell2mat(tmp(:,ii)'); %#ok<AGROW> | |
| - end | |
| - end | |
| - else | |
| - out = light_python_wrapper.p2m(horace_disp(qh, qh, qk, args{:}, kwargs)); | |
| +function out = horace_disp_internal(self, qh, qk, ql, pars, varargin) | |
| + % Overrides Python function to do chunking in Matlab to print messages | |
| + % | |
| + % self - wrapper whose pyobj is the Python CoherentCrystal object | |
| + % qh, qk, ql - q-point coordinates, passed to Python in this order | |
| + % pars - positional parameters (split with num2cell) if varargin is | |
| + % empty; otherwise [pars, varargin] is split at the first | |
| + % string into positional args and keyword (name, value) args | |
| + % | |
| + % If pyobj.verbose is true and pyobj.chunk > 0, the q-points are sent to | |
| + % Python in chunks of pyobj.chunk points, printing progress per chunk. | |
| + | |
| + args = {}; | |
| + kwargs = pyargs(); | |
| + if ~isempty(varargin) | |
| + all_args = [pars, varargin]; | |
| + % Find first occurence of str/char - assume everything before | |
| + % is positional args, everything after is kwargs | |
| + is_str = cellfun(@isstring, all_args); | |
| + if ~any(is_str) | |
| + is_str = cellfun(@ischar, all_args); | |
| + end | |
| + str_idx = find(is_str==1); | |
| + if isempty(str_idx) | |
| + % No strings - all positional | |
| + args = all_args; | |
| + else | |
| + args = all_args(1:str_idx(1) - 1); | |
| + kwargs = pyargs(all_args{str_idx(1):end}); | |
| + end | |
| + else | |
| + % If no varargin, assume all positional arguments | |
| + args = num2cell(pars); | |
| + end | |
| + | |
| + horace_disp = py.getattr(self.pyobj, 'horace_disp'); | |
| + orig_chunk = self.pyobj.chunk; | |
| + chunk_size = double(orig_chunk); | |
| + lqh = numel(qh); | |
| + if self.pyobj.verbose && chunk_size > 0 && lqh > 0 | |
| + % Matlab does the chunking, so switch off Python-side chunking. The value | |
| + % read above (not the double copy, which would reach Python as a float) | |
| + % is restored on exit, even if one of the Python calls below errors. | |
| + self.pyobj.chunk = 0; | |
| + restore_chunk = onCleanup(@() py.setattr(self.pyobj, 'chunk', orig_chunk)); %#ok<NASGU> | |
| + nchunk = ceil(lqh / chunk_size); | |
| + pyout = {}; | |
| + for ii = 1:nchunk | |
| + qi = (ii-1)*chunk_size + 1; | |
| + qf = min([ii*chunk_size lqh]); | |
| + fprintf('Using Euphonic to interpolate for q-points %d:%d out of %d\n', qi, qf, lqh); | |
| + pyout = cat(1, pyout, light_python_wrapper.p2m(horace_disp(qh(qi:qf), qk(qi:qf), ql(qi:qf), args{:}, kwargs))); | |
| + end | |
| + % Concatenate the per-chunk results of each of the two outputs | |
| + for jj = 1:2 | |
| + tmp = cat(1, pyout{:,jj}); | |
| + for ii = 1:size(tmp,2) | |
| + out{jj}{ii} = cell2mat(tmp(:,ii)'); %#ok<AGROW> | |
| end | |
| end | |
| + else | |
| + out = light_python_wrapper.p2m(horace_disp(qh, qk, ql, args{:}, kwargs)); | |
| end | |
| end | |
| - | |
| diff --git a/light_python_wrapper b/light_python_wrapper | |
| --- a/light_python_wrapper | |
| +++ b/light_python_wrapper | |
| @@ -1 +1 @@ | |
| -Subproject commit b5a0c516085002fdcc0b7d4b3e8aec04bcdb8d71 | |
| +Subproject commit b5a0c516085002fdcc0b7d4b3e8aec04bcdb8d71-dirty | |
| diff --git a/+light_python_wrapper/generic_python_wrapper.m b/+light_python_wrapper/generic_python_wrapper.m | |
| index 9c93601..1fdfc48 100644 | |
| --- a/+light_python_wrapper/generic_python_wrapper.m | |
| +++ b/+light_python_wrapper/generic_python_wrapper.m | |
| @@ -9,13 +9,23 @@ classdef generic_python_wrapper < light_python_wrapper.light_python_wrapper | |
| end | |
| methods | |
| % Constructor | |
| - function obj = generic_python_wrapper(pyobj) | |
| + function obj = generic_python_wrapper(pyobj, varargin) | |
| if strncmp(class(pyobj), 'py.', 3) | |
| obj.pyobj = pyobj; | |
| obj.populate_props(); | |
| else | |
| error('This class only wraps Python objects'); | |
| end | |
| + % Any further inputs are Name,Value pairs, each assigned in order to | |
| + % the property of that name (e.g. 'overrides') | |
| + if ~isempty(varargin) | |
| + assert(rem(numel(varargin), 2) == 0, 'generic_python_wrapper: Expected keyword,value pairs'); | |
| + prop_names = varargin(1:2:end); | |
| + prop_values = varargin(2:2:end); | |
| + for kk = 1:numel(prop_names) | |
| + obj.(prop_names{kk}) = prop_values{kk}; | |
| + end | |
| + end | |
| end | |
| end | |
| end | |
| diff --git a/+light_python_wrapper/light_python_wrapper.m b/+light_python_wrapper/light_python_wrapper.m | |
| index fef91f6..fbc4527 100644 | |
| --- a/+light_python_wrapper/light_python_wrapper.m | |
| +++ b/+light_python_wrapper/light_python_wrapper.m | |
| @@ -105,8 +105,15 @@ classdef light_python_wrapper < dynamicprops | |
| varargout = python_redirection(varargout, s((ii+1):end)); | |
| end | |
| case '.' | |
| - if any(cellfun(@(c) strcmp(s(1).subs, c), obj.overrides)) | |
| - varargout = get_matlab(obj, s); | |
| + % Look the override up on every call: each wrapper object has its own | |
| + % overrides list, so a persistent (shared) index cache would be unsafe. | |
| + oidx = find_override(obj.overrides, s(1).subs); | |
| + if ~isempty(oidx) | |
| + if iscell(obj.overrides{oidx}) | |
| + varargout = get_override(obj.overrides{oidx}, s, obj); | |
| + else | |
| + varargout = get_matlab(obj, s); | |
| + end | |
| else | |
| try | |
| varargout = python_redirection(obj.pyobj, s); | |
| @@ -128,8 +135,17 @@ classdef light_python_wrapper < dynamicprops | |
| end | |
| end | |
| case '()' | |
| - args = light_python_wrapper.light_python_wrapper.parse_args(s(1).subs, obj.pyobj); | |
| - varargout = light_python_wrapper.p2m(obj.pyobj(args{:})); | |
| + % A 'constructor__' override replaces calling the wrapped Python object | |
| + oidx = find_override(obj.overrides, 'constructor__'); | |
| + if isempty(oidx) | |
| + args = light_python_wrapper.light_python_wrapper.parse_args(s(1).subs, obj.pyobj); | |
| + varargout = light_python_wrapper.p2m(obj.pyobj(args{:})); | |
| + elseif iscell(obj.overrides{oidx}) | |
| + varargout = get_override(obj.overrides{oidx}, s, obj); | |
| + else | |
| + % Plain-name override: call the Matlab method constructor__ with these args | |
| + varargout = get_matlab(obj, [substruct('.', 'constructor__'), s]); | |
| + end | |
| end | |
| if ~iscell(varargout) | |
| varargout = {varargout}; | |
| @@ -332,10 +348,53 @@ function varargout = python_redirection(first_obj, s) | |
| end | |
| function out = get_matlab(obj, s) | |
| + % Resolves an override implemented in Matlab. | |
| + % obj - the wrapper object (the override is then its Matlab method or | |
| + % property named s(1).subs) or a function handle to call | |
| + % s - subsref indexing structure; call arguments are taken from the | |
| + % first '()' level | |
| + if ~isa(obj, 'light_python_wrapper.light_python_wrapper') | |
| + if strcmp(s(1).type, '()') | |
| + out = obj(s(1).subs{:}); | |
| + elseif numel(s) > 1 && strcmp(s(2).type, '()') | |
| + out = obj(s(2).subs{:}); | |
| + else | |
| + out = obj; | |
| + end | |
| - if numel(s) > 1 && strcmp(s(2).type, '()') | |
| + elseif numel(s) > 1 && strcmp(s(2).type, '()') | |
| + % Call with the arguments in one step: evaluating obj.(name) on its | |
| + % own would first invoke a method override with no arguments. | |
| out = obj.(s(1).subs)(s(2).subs{:}); | |
| else | |
| out = obj.(s(1).subs); | |
| end | |
| end | |
| + | |
| +function out = get_override(override, s, obj) | |
| + % Calls the Matlab function handle override{2} of an override entry | |
| + % {name, handle[, flag]}; when a third element is present the wrapper | |
| + % object is passed as the handle's first argument. | |
| + if numel(override) > 2 | |
| + out = get_matlab(@(varargin) override{2}(obj, varargin{:}), s); | |
| + else | |
| + out = get_matlab(override{2}, s); | |
| + end | |
| +end | |
| + | |
| +function idx = find_override(overrides, name) | |
| + % Index of the first entry in overrides named `name`, or [] if none. | |
| + % An entry is either a name or a cell array whose first element is the | |
| + % name, so other elements (e.g. an 'obj' flag) are never matched. | |
| + idx = []; | |
| + for ii = 1:numel(overrides) | |
| + entry = overrides{ii}; | |
| + if iscell(entry) | |
| + entry = entry{1}; | |
| + end | |
| + if strcmp(entry, name) | |
| + idx = ii; | |
| + return; | |
| + end | |
| + end | |
| +end | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.