Skip to content

Instantly share code, notes, and snippets.

@mducle
Created January 13, 2022 14:44
Show Gist options
  • Select an option

  • Save mducle/ebeae8584755008fdd4df12e3d5de58a to your computer and use it in GitHub Desktop.

Select an option

Save mducle/ebeae8584755008fdd4df12e3d5de58a to your computer and use it in GitHub Desktop.
diff --git a/+euphonic/CoherentCrystal.m b/+euphonic/CoherentCrystal.m
index aebc44d..d39d186 100644
--- a/+euphonic/CoherentCrystal.m
+++ b/+euphonic/CoherentCrystal.m
@@ -11,6 +11,9 @@ classdef CoherentCrystal < light_python_wrapper.light_python_wrapper
is_initialised = euphonic_on();
is_redirected = light_python_wrapper.light_python_wrapper.redirect_python_warnings();
end
+ properties (Constant)
+ horace_disp = horace_disp_private(); % Python CoherentCrystal.horace_disp wrapped with a Matlab call override (built by horace_disp_private below)
+ end
methods
% Constructor
function obj = CoherentCrystal(varargin)
@@ -23,58 +26,82 @@ classdef CoherentCrystal < light_python_wrapper.light_python_wrapper
else
obj.pyobj = py.getattr(eu, 'CoherentCrystal');
end
- obj.overrides = {'horace_disp'};
+ obj.overrides = {{'horace_disp', @horace_disp_internal, 'obj'}};
end
- function out = horace_disp(self, qh, qk, ql, pars, varargin)
- % Overrides Python function to do chunking in Matlab to print messages
+ end
+end
- args = {};
- kwargs = pyargs();
- if ~isempty(varargin)
- all_args = [pars, varargin];
- % Find first occurence of str/char - assume everything before
- % is positional args, everything after is kwargs
- is_str = cellfun(@isstring, all_args);
- if ~any(is_str)
- is_str = cellfun(@ischar, all_args);
- end
- str_idx = find(is_str==1);
- if isempty(str_idx)
- % No strings - all positional
- args = all_args;
- else
- args = all_args(1:str_idx(1) - 1);
- kwargs = pyargs(all_args{str_idx(1):end});
- end
- else
- % If no varargin, assume all positional arguments
- args = num2cell(pars);
- end
+function out = horace_disp_private()
+ % Builds the value of the Constant property horace_disp: a generic wrapper
+ % around the unbound Python CoherentCrystal.horace_disp whose call operator
+ % is overridden to run horace_disp_internal. The wrapper is not tied to any
+ % CoherentCrystal instance, so no 'obj' marker is used: the arguments are
+ % forwarded as given and the caller passes the CoherentCrystal object first
+ % (an 'obj' marker would pass this function wrapper in as `self`).
+ persistent eu_mod
+ if isempty(eu_mod), eu_mod = py.importlib.import_module('euphonic_sqw_models'); end
+ out = light_python_wrapper.generic_python_wrapper(eu_mod.CoherentCrystal.horace_disp, ...
+ 'overrides', {{'constructor__', @horace_disp_internal}});
+end
- horace_disp = py.getattr(self.pyobj, 'horace_disp');
- chunk_size = double(self.pyobj.chunk);
- lqh = numel(qh);
- if self.pyobj.verbose && chunk_size > 0
- self.pyobj.chunk = 0;
- nchunk = ceil(lqh / chunk_size);
- pyout = {};
- for ii = 1:nchunk
- qi = (ii-1)*chunk_size + 1;
- qf = min([ii*chunk_size lqh]);
- fprintf('Using Euphonic to interpolate for q-points %d:%d out of %d\n', qi, qf, lqh);
- pyout = cat(1, pyout, light_python_wrapper.p2m(horace_disp(qh(qi:qf), qk(qi:qf), ql(qi:qf), args{:}, kwargs)));
- end
- self.pyobj.chunk = chunk_size;
- for jj = 1:2
- tmp = cat(1, pyout{:,jj});
- for ii = 1:size(tmp,2)
- out{jj}{ii} = cell2mat(tmp(:,ii)'); %#ok<AGROW>
- end
- end
- else
- out = light_python_wrapper.p2m(horace_disp(qh, qh, qk, args{:}, kwargs));
+function out = horace_disp_internal(self, qh, qk, ql, pars, varargin)
+ % Overrides Python function to do chunking in Matlab to print messages
+ %
+ % self       - wrapper whose pyobj provides horace_disp, chunk and verbose
+ % qh, qk, ql - q-point coordinates, indexed in parallel (assumed same numel)
+ % pars       - parameters; expanded with num2cell when varargin is empty
+ % varargin   - more positional args, then keyword,value pairs from the first string
+ % out        - Python horace_disp output converted by p2m; when chunking, out{1}
+ %              and out{2} are rebuilt by concatenating each column over chunks
+
+ args = {};
+ kwargs = pyargs();
+ if ~isempty(varargin)
+ all_args = [pars, varargin];
+ % Find first occurrence of str/char - assume everything before
+ % is positional args, everything after is kwargs
+ is_str = cellfun(@isstring, all_args);
+ if ~any(is_str)
+ is_str = cellfun(@ischar, all_args);
+ end
+ str_idx = find(is_str==1);
+ if isempty(str_idx)
+ % No strings - all positional
+ args = all_args;
+ else
+ args = all_args(1:str_idx(1) - 1);
+ kwargs = pyargs(all_args{str_idx(1):end});
+ end
+ else
+ % If no varargin, assume all positional arguments
+ args = num2cell(pars);
+ end
+
+ horace_disp = py.getattr(self.pyobj, 'horace_disp');
+ orig_chunk = self.pyobj.chunk;
+ chunk_size = double(orig_chunk);
+ lqh = numel(qh);
+ if self.pyobj.verbose && chunk_size > 0 && lqh > 0
+ % Disable Python-side chunking while Matlab chunks; onCleanup puts back
+ % the value read above (not a double copy of it), even on error
+ self.pyobj.chunk = 0;
+ restore_chunk = onCleanup(@() py.setattr(self.pyobj, 'chunk', orig_chunk));
+ nchunk = ceil(lqh / chunk_size);
+ pyout = {};
+ for ii = 1:nchunk
+ qi = (ii-1)*chunk_size + 1;
+ qf = min([ii*chunk_size lqh]);
+ fprintf('Using Euphonic to interpolate for q-points %d:%d out of %d\n', qi, qf, lqh);
+ pyout = cat(1, pyout, light_python_wrapper.p2m(horace_disp(qh(qi:qf), qk(qi:qf), ql(qi:qf), args{:}, kwargs)));
+ end
+ clear('restore_chunk'); % restore self.pyobj.chunk before assembling the output
+ for jj = 1:2
+ tmp = cat(1, pyout{:,jj});
+ for ii = 1:size(tmp,2)
+ out{jj}{ii} = cell2mat(tmp(:,ii)'); %#ok<AGROW>
end
end
+ else
+ out = light_python_wrapper.p2m(horace_disp(qh, qk, ql, args{:}, kwargs));
end
end
-
diff --git a/light_python_wrapper b/light_python_wrapper
--- a/light_python_wrapper
+++ b/light_python_wrapper
@@ -1 +1 @@
-Subproject commit b5a0c516085002fdcc0b7d4b3e8aec04bcdb8d71
+Subproject commit b5a0c516085002fdcc0b7d4b3e8aec04bcdb8d71-dirty
diff --git a/+light_python_wrapper/generic_python_wrapper.m b/+light_python_wrapper/generic_python_wrapper.m
index 9c93601..1fdfc48 100644
--- a/+light_python_wrapper/generic_python_wrapper.m
+++ b/+light_python_wrapper/generic_python_wrapper.m
@@ -9,13 +9,20 @@ classdef generic_python_wrapper < light_python_wrapper.light_python_wrapper
end
methods
% Constructor
- function obj = generic_python_wrapper(pyobj)
+ function obj = generic_python_wrapper(pyobj, varargin)
if strncmp(class(pyobj), 'py.', 3)
obj.pyobj = pyobj;
obj.populate_props();
else
error('This class only wraps Python objects');
end
+ % Optional trailing arguments are property name/value pairs (e.g. 'overrides')
+ assert(mod(numel(varargin), 2) == 0, 'generic_python_wrapper: Expected keyword,value pairs');
+ prop_names = varargin(1:2:end);
+ prop_vals = varargin(2:2:end);
+ for kk = 1:numel(prop_names)
+ obj.(prop_names{kk}) = prop_vals{kk};
+ end
end
end
end
diff --git a/+light_python_wrapper/light_python_wrapper.m b/+light_python_wrapper/light_python_wrapper.m
index fef91f6..fbc4527 100644
--- a/+light_python_wrapper/light_python_wrapper.m
+++ b/+light_python_wrapper/light_python_wrapper.m
@@ -105,8 +105,14 @@ classdef light_python_wrapper < dynamicprops
varargout = python_redirection(varargout, s((ii+1):end));
end
case '.'
- if any(cellfun(@(c) strcmp(s(1).subs, c), obj.overrides))
- varargout = get_matlab(obj, s);
+ % obj.overrides is per instance, so look the name up on every call
+ idx = find_override(obj.overrides, s(1).subs);
+ if ~isempty(idx)
+ if iscell(obj.overrides{idx})
+ varargout = get_override(obj.overrides{idx}, s, obj);
+ else
+ varargout = get_matlab(obj, s);
+ end
else
try
varargout = python_redirection(obj.pyobj, s);
@@ -128,8 +134,20 @@ classdef light_python_wrapper < dynamicprops
end
end
case '()'
- args = light_python_wrapper.light_python_wrapper.parse_args(s(1).subs, obj.pyobj);
- varargout = light_python_wrapper.p2m(obj.pyobj(args{:}));
+ idx = find_override(obj.overrides, 'constructor__');
+ if ~isempty(idx)
+ % Present the call as obj.constructor__(...) so the override sees the
+ % same ('.', '()') subscript shape as in the '.' case above
+ s_call = [substruct('.', 'constructor__'), s];
+ if iscell(obj.overrides{idx})
+ varargout = get_override(obj.overrides{idx}, s_call, obj);
+ else
+ varargout = get_matlab(obj, s_call);
+ end
+ else
+ args = light_python_wrapper.light_python_wrapper.parse_args(s(1).subs, obj.pyobj);
+ varargout = light_python_wrapper.p2m(obj.pyobj(args{:}));
+ end
end
if ~iscell(varargout)
varargout = {varargout};
@@ -332,10 +350,45 @@ function varargout = python_redirection(first_obj, s)
end
function out = get_matlab(obj, s)
+ % Reads the Matlab property s(1).subs of obj, or calls the method of that
+ % name with the arguments of s(2) when s(2) is '()'. obj.(name)(args) is
+ % used so that a method is never first evaluated without its arguments.
if numel(s) > 1 && strcmp(s(2).type, '()')
out = obj.(s(1).subs)(s(2).subs{:});
+ else
+ out = obj.(s(1).subs);
+ end
+end
+
+function idx = find_override(overrides, name)
+ % Index of the first overrides entry named `name`, or [] if there is none.
+ % An entry is either a char name or a cell {name, fn} / {name, fn, 'obj'};
+ % only the name is compared, so e.g. the 'obj' marker never matches.
+ idx = [];
+ for ii = 1:numel(overrides)
+ entry = overrides{ii};
+ if iscell(entry) && ~isempty(entry), entry = entry{1}; end
+ if strcmp(entry, name)
+ idx = ii;
+ return
+ end
+ end
+end
+
+function out = get_override(override, s, obj)
+ % Evaluates a cell override {name, fn} or {name, fn, 'obj'}. fn is called
+ % with the arguments of s(2) when s(2) is '()', otherwise the (possibly
+ % bound) handle is returned. With the 'obj' marker, obj is prepended to the
+ % arguments.
+ fn = override{2};
+ if numel(override) > 2
+ unbound_fn = fn;
+ fn = @(varargin) unbound_fn(obj, varargin{:});
+ end
+ if numel(s) > 1 && strcmp(s(2).type, '()')
+ out = fn(s(2).subs{:});
else
- out = obj.(s(1).subs);
+ out = fn;
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment