Home > analyzePRF > utilities > fitnonlinearmodel_helper.m

fitnonlinearmodel_helper

PURPOSE

This is a helper function for fitnonlinearmodel.m. Not for external use!

SYNOPSIS

function results = fitnonlinearmodel_helper(opt,stimulus,tmatrix,smatrix,trainfun,testfun)

DESCRIPTION

 This is a helper function for fitnonlinearmodel.m.  Not for external use!

 Notes:
 - opt.data is always a cell vector and contains only one voxel
 - in the nonlinear case, the seed(s) to use have been placed into model{1}{1}, which may have multiple rows (one seed per row)

CROSS-REFERENCE INFORMATION

This function calls: (none listed)
This function is called by: (none listed)

SOURCE CODE

function results = fitnonlinearmodel_helper(opt,stimulus,tmatrix,smatrix,trainfun,testfun)

% This is a helper function for fitnonlinearmodel.m.  Not for external use!
%
% Inputs (as used in this function):
%   <opt> is a struct.  The fields read here are .model, .data, .dontsave, .dosave,
%     .optimoptions, .outputfcn, and .metric.
%     - if <opt.model> is a function handle, the linear case is used: it is called as
%       feval(opt.model,X,y) and its output is used below as finalparams'
%       (presumably a row vector -- NOTE(review): confirm in fitnonlinearmodel.m).
%     - otherwise <opt.model> is a cell vector of model stages, each being
%       {seed bounds model transform}.  For stages after the first, seed, model,
%       and transform are functions of the previous stage's estimated parameters.
%   <stimulus> is a cell vector of stimulus matrices.  A third dimension, if present,
%     indexes multiple frames whose (predicted) responses are averaged.
%   <tmatrix>, <smatrix> are passed through the resampling functions and then through
%     projectionmatrix.  The T projection is applied to data and model before fitting.
%     The S projection is applied before computing performance metrics.
%   <trainfun>, <testfun> are cell vectors of function handles (one per resampling
%     case) that extract the training and testing portions of their input.
%
% Output <results> is a struct with fields:
%   .params                    - resampling cases x parameters
%   .testdata, .modelpred      - S-projected testing data and model predictions,
%                                concatenated along the first dimension
%   .modelfit                  - model fit to the full stimulus (each case's fit is
%                                transposed before concatenation), or empty if the
%                                user does not want 'modelfit'
%   .trainperformance          - 1 x resampling cases, opt.metric on the training data
%   .testperformance           - 1 x resampling cases, opt.metric on the testing data
%                                (NaN when there is no testing data)
%   .aggregatedtestperformance - opt.metric on the concatenated testing data
%   .numiters                  - resampling cases x seeds x models ([] in the linear case)
%   .resnorms                  - resampling cases x seeds ([] in the linear case)
%
% Notes:
% - opt.data is always a cell vector and contains only one voxel
% - in the nonlinear case, the seed(s) to use have been placed into model{1}{1},
%   which may have multiple rows (one seed per row).  The seed whose final
%   resnorm is smallest is selected.

% calc
islinear = isa(opt.model,'function_handle');
if ~islinear
  numseeds = size(opt.model{1}{1},1);
  nummodels = length(opt.model);
  ismultipleseeds = numseeds > 1;
  ismultiplemodels = nummodels > 1;
end

% calc
wantmodelfit = ~(ismember('modelfit',opt.dontsave) && ~ismember('modelfit',opt.dosave));
if islinear
  numparams = size(stimulus{1},2);
else
  numparams = size(opt.model{end}{2},2);
end
numcases = length(trainfun);

% init
results = struct;
results.params = zeros(numcases,numparams);
results.testdata =  cell(1,numcases);  % but converted to a matrix at the end
results.modelpred = cell(1,numcases);  % but converted to a matrix at the end
results.modelfit =  cell(1,numcases);  % but converted to a matrix at the end
results.trainperformance = zeros(1,numcases);
results.testperformance =  zeros(1,numcases);
results.aggregatedtestperformance = [];
if islinear
  results.numiters = [];
  results.resnorms = [];
else
  results.numiters = zeros(numcases,numseeds,nummodels);
  results.resnorms = zeros(numcases,numseeds);
end

% the full stimulus does not depend on the resampling case, so prepare it only once.
% (only do this if the user wants 'modelfit', to save on memory.)
if wantmodelfit
  allstim = catcell(1,stimulus);
  if islinear
    allstim = mean(allstim,3);  % the linear fit averages across frames, so do the same here
  end
end

% loop over resampling cases
for rr=1:numcases
  fprintf('  starting resampling case %d of %d.\n',rr,numcases);

  % deal with resampling
  trainstim = feval(trainfun{rr},stimulus);
  traindata = feval(trainfun{rr},opt.data);  % result is a column vector
  trainT =    projectionmatrix(feval(trainfun{rr},tmatrix));   % NOTE: potentially slow step. make sparse? [or CACHE]
  trainS =    projectionmatrix(feval(trainfun{rr},smatrix));   % NOTE: potentially slow step. make sparse? [or CACHE]
  teststim =  feval(testfun{rr},stimulus);
  testdata =  feval(testfun{rr},opt.data);   % result is a column vector
  testS =     projectionmatrix(feval(testfun{rr},smatrix));    % NOTE: potentially slow step. make sparse? [or CACHE]
  % (no T projection is computed for the testing data since nothing below uses it.)

  % precompute
  traindataT = trainT*traindata;  % remove regressors from data (fitting)

  % deal with last-minute data division
  if ~islinear
    datastd = std(traindataT);
    if datastd == 0
      datastd = 1;
    end
    traindataT = traindataT / datastd;
  end

  % deal with options
  if ~islinear
    options = opt.optimoptions;
    if ~isempty(opt.outputfcn)
      if nargin(opt.outputfcn) == 4
        options.OutputFcn = @(a,b,c) feval(opt.outputfcn,a,b,c,traindataT);
      else
        options.OutputFcn = opt.outputfcn;
      end
    end
  end

  % ok, deal with linear case
  if islinear

    % average across the third dimension to deal with the case where the
    % stimulus consists of multiple frames.  the averaged stimulus is kept so
    % that the training predictions below are consistent with the fit.
    trainstim = mean(trainstim,3);

    % do the fitting
    finalparams = feval(opt.model,trainT*trainstim,traindataT);

    % report
    fprintf('      the estimated parameters are ['); ...
      fprintf('%.3f ',finalparams); fprintf('].\n');

  % ok, deal with nonlinear case
  else

    % loop over seeds.  each seed's final model and transform are remembered
    % because, with multiple model stages, they may be computed from that
    % seed's earlier-stage parameter estimates.
    params = [];
    seedmodels = cell(1,numseeds);
    seedtransforms = cell(1,numseeds);
    for ss=1:numseeds
      if ismultipleseeds
        fprintf('    trying seed %d of %d.\n',ss,numseeds);
      end

      % loop through models
      for mm=1:nummodels

        % which parameters are we actually fitting?
        ix = ~isnan(opt.model{mm}{2}(1,:));

        % calculate seed, model, and transform
        if mm==1
          seed = opt.model{mm}{1}(ss,:);
          model = opt.model{mm}{3};
          transform = opt.model{mm}{4};
        else
          seed = feval(opt.model{mm}{1},params0);
          model = feval(opt.model{mm}{3},params0);
          transform = feval(opt.model{mm}{4},params0);
        end

        % in the special case that the stimulus consists of multiple frames,
        % then we have to modify model so that it averages across the
        % predicted response associated with each frame.  the input <dd> is
        % reshaped to (frames*rows) x columns, so the chunk sizes are taken
        % from <dd> itself; this lets the wrapped model also be applied to
        % stimuli with a different number of rows (testing data, full stimulus).
        if size(trainstim,3) > 1
          model = @(pp,dd) chunkfun(feval(model,pp,squish(permute(dd,[3 1 2]),2)), ...
                                    repmat(size(dd,3),[1 size(dd,1)]),@(x) mean(x,1)).';
        end

        % figure out bounds to use
        if isequal(options.Algorithm,'levenberg-marquardt')
          lb = [];
          ub = [];
        else
          lb = opt.model{mm}{2}(1,ix);
          ub = opt.model{mm}{2}(2,ix);
        end

        % precompute
        trainstimTRANSFORM = feval(transform,trainstim);

        % define the final model function
        fun = @(pp) trainT*feval(model,copymatrix(seed,ix,pp),trainstimTRANSFORM);

        % report
        if ismultiplemodels
          fprintf('      for model %d of %d, the seed is [', ...
                  mm,nummodels); fprintf('%.3f ',seed); fprintf('].\n');
        else
          fprintf('      the seed is ['); fprintf('%.3f ',seed); fprintf('].\n');
        end

        % perform the fit (NOTICE THE DIVISION BY DATASTD, THE NAN PROTECTION, THE CONVERSION TO DOUBLE)
        if ~any(ix)
          params0 = seed;   % if no parameters are to be optimized, just return the seed
          resnorm = NaN;
          output = [];
          output.iterations = NaN;
        else
          [params0,resnorm,~,~,output] = ...
            lsqcurvefit(@(x,y) double(nanreplace(feval(fun,x) / datastd,0,2)),seed(ix),[],double(traindataT),lb,ub,options);
          params0 = copymatrix(seed,ix,params0);
        end

        % report
        fprintf('      the estimated parameters are ['); ...
          fprintf('%.3f ',params0); fprintf('].\n');

        % record
        results.numiters(rr,ss,mm) = output.iterations;

      end

      % record
      results.resnorms(rr,ss) = resnorm;  % final resnorm
      params(ss,:) = params0;  % final parameters
      seedmodels{ss} = model;
      seedtransforms{ss} = transform;

    end

    % which seed produced the best results?
    [~,mnix] = min(results.resnorms(rr,:));
    finalparams = params(mnix,:);

    % use the model and transform belonging to the best seed.  (those of the
    % last seed, and the stimulus transformed by them, are already in place.)
    if mnix ~= numseeds
      model = seedmodels{mnix};
      transform = seedtransforms{mnix};
      trainstimTRANSFORM = feval(transform,trainstim);
    end

  end

  % record the results
  results.params(rr,:) = finalparams;

  % report
  if ~islinear && ismultipleseeds
    fprintf('    seed %d was best. final estimated parameters are [',mnix); ...
      fprintf('%.3f ',finalparams); fprintf('].\n');
  end

  % prepare data and model fits
  % [NOTE: in the nonlinear case, this uses model, transform, and trainstimTRANSFORM
  %  of the best seed, as set above.  in the linear case, trainstim has already
  %  been averaged across frames.]
  traindatatemp = trainS*traindata;
  if islinear
    modelfittemp = trainS*(trainstim*finalparams');
  else
    modelfittemp = nanreplace(trainS*feval(model,finalparams,trainstimTRANSFORM),0,2);
  end
  if isempty(testdata)  % handle this case explicitly, just to avoid problems
    results.testdata{rr} = [];
    results.modelpred{rr} = [];
  else
    results.testdata{rr} = testS*testdata;
    if islinear
      results.modelpred{rr} = testS*(mean(teststim,3)*finalparams');  % average frames, as in the fit
    else
      results.modelpred{rr} = nanreplace(testS*feval(model,finalparams,feval(transform,teststim)),0,2);
    end
  end

  % prepare modelfit
  if wantmodelfit
    if islinear
      results.modelfit{rr} = (allstim*finalparams')';
    else
      results.modelfit{rr} = nanreplace(feval(model,finalparams,feval(transform,allstim)),0,2)';
    end
  else
    results.modelfit{rr} = [];  % if not wanted by user, don't bother computing
  end

  % compute metrics
  results.trainperformance(rr) = feval(opt.metric,modelfittemp,traindatatemp);
  if isempty(results.testdata{rr})  % handle this case explicitly, just to avoid problems
    results.testperformance(rr) = NaN;
  else
    results.testperformance(rr) = feval(opt.metric,results.modelpred{rr},results.testdata{rr});
  end

  % report
  fprintf('    trainperformance is %.2f. testperformance is %.2f.\n', ...
    results.trainperformance(rr),results.testperformance(rr));

end

% compute aggregated metrics
results.testdata = catcell(1,results.testdata);
results.modelpred = catcell(1,results.modelpred);
results.modelfit = catcell(1,results.modelfit);
if isempty(results.testdata)
  results.aggregatedtestperformance = NaN;
else
  results.aggregatedtestperformance = feval(opt.metric,results.modelpred,results.testdata);
end

% report
fprintf('  aggregatedtestperformance is %.2f.\n',results.aggregatedtestperformance);

Generated on Wed 18-Jun-2014 21:47:41 by m2html © 2005