function results = fitnonlinearmodel_helper(opt,stimulus,tmatrix,smatrix,trainfun,testfun)
% FITNONLINEARMODEL_HELPER  Fit a (linear or cascaded nonlinear) model once per
% resampling case and collect parameters, predictions and performance.
%
% results = fitnonlinearmodel_helper(opt,stimulus,tmatrix,smatrix,trainfun,testfun)
%
% Inputs (only the parts used here are described):
%   opt       - options struct. Fields read: model, data, dontsave, dosave,
%               optimoptions, outputfcn, metric.
%               * If opt.model is a function handle, the linear path is used:
%                 opt.model(X,y) is called and must return the parameter vector.
%               * Otherwise opt.model is a cell vector describing a cascade of
%                 stages; opt.model{mm} = {seed, bounds, model, transform}.
%                 For mm==1, seed is a matrix with one seed per row and model /
%                 transform are function handles. For mm>1, seed, model and
%                 transform are functions of the previous stage's parameters.
%                 bounds is always a matrix: row 1 = lower, row 2 = upper.
%                 Only columns whose row-1 bound is not NaN are optimized; the
%                 rest are taken from the seed (via copymatrix).
%                 model(params,stim) returns the prediction;
%                 transform(stim) preprocesses the stimulus for model.
%               * opt.metric(prediction,data) returns a scalar performance.
%   stimulus  - stimulus input. trainfun/testfun are applied to it, and
%               catcell(1,stimulus) forms the full stimulus for modelfit
%               (presumably a cell of runs -- TODO confirm against caller).
%   tmatrix   - trainfun{rr} is applied, then projectionmatrix; the result
%               multiplies the training data/predictions during fitting.
%   smatrix   - same, for train/test performance evaluation.
%   trainfun  - cell vector of function handles, one per resampling case,
%               presumably selecting the training portion of an input.
%   testfun   - cell vector of function handles, one per resampling case,
%               presumably selecting the testing portion of an input.
%
% Output struct fields:
%   params                    - (numcases x numparams) final parameters
%   testdata, modelpred       - test data / predictions, concatenated with catcell(1,...)
%   modelfit                  - full-stimulus model fits, concatenated ([] if not wanted)
%   trainperformance          - 1 x numcases, opt.metric on the training data
%   testperformance           - 1 x numcases, NaN when a case has no test data
%   aggregatedtestperformance - opt.metric on all concatenated test data (NaN if none)
%   numiters                  - cases x seeds x stages iteration counts ([] for linear)
%   resnorms                  - cases x seeds final-stage residual norms ([] for linear)
%
% If the stimulus has a third dimension, the predictions are averaged over it.
% For the linear path the fit and all predictions use mean(stim,3).

% ---------------- setup ----------------
islinear = isa(opt.model,'function_handle');
ismultipleseeds = false;
ismultiplemodels = false;
if ~islinear
  numseeds = size(opt.model{1}{1},1);
  nummodels = length(opt.model);
  ismultipleseeds = numseeds > 1;
  ismultiplemodels = nummodels > 1;
end
numcases = length(trainfun);

% modelfit is computed unless it is listed in dontsave without being re-enabled by dosave
wantmodelfit = ~(ismember('modelfit',opt.dontsave) && ~ismember('modelfit',opt.dosave));
if islinear
  numparams = size(stimulus{1},2);
else
  numparams = size(opt.model{end}{2},2);
end

% ---------------- initialize outputs ----------------
results = struct;
results.params = zeros(numcases,numparams);
results.testdata = cell(1,numcases);
results.modelpred = cell(1,numcases);
results.modelfit = cell(1,numcases);
results.trainperformance = zeros(1,numcases);
results.testperformance = zeros(1,numcases);
results.aggregatedtestperformance = [];
if islinear
  results.numiters = [];
  results.resnorms = [];
else
  results.numiters = zeros(numcases,numseeds,nummodels);
  results.resnorms = zeros(numcases,numseeds);
end

% the full stimulus does not depend on the resampling case, so build it once
if wantmodelfit && numcases > 0
  allstim = catcell(1,stimulus);
end

% ---------------- loop over resampling cases ----------------
for rr=1:numcases
  fprintf(' starting resampling case %d of %d.\n',rr,numcases);

  % split stimulus, data and projection matrices into train/test portions
  trainstim = feval(trainfun{rr},stimulus);
  traindata = feval(trainfun{rr},opt.data);
  trainT = projectionmatrix(feval(trainfun{rr},tmatrix));
  trainS = projectionmatrix(feval(trainfun{rr},smatrix));
  teststim = feval(testfun{rr},stimulus);
  testdata = feval(testfun{rr},opt.data);
  testS = projectionmatrix(feval(testfun{rr},smatrix));

  % data that the fit targets
  traindataT = trainT*traindata;

  if ~islinear
    % the nonlinear fit works on data scaled by its standard deviation
    % (guarding against a zero std); predictions are scaled the same way below
    datastd = std(traindataT);
    if datastd == 0
      datastd = 1;
    end
    traindataT = traindataT / datastd;

    % optimizer options; a 4-argument output function also gets the fitted data
    options = opt.optimoptions;
    if ~isempty(opt.outputfcn)
      if nargin(opt.outputfcn) == 4
        options.OutputFcn = @(a,b,c) feval(opt.outputfcn,a,b,c,traindataT);
      else
        options.OutputFcn = opt.outputfcn;
      end
    end
  end

  if islinear

    % linear fit on the stimulus averaged over its third dimension
    % (a no-op for 2-D stimuli); the same averaging is used for predictions
    finalparams = feval(opt.model,trainT*mean(trainstim,3),traindataT);
    fprintf(' the estimated parameters are ['); fprintf('%.3f ',finalparams); fprintf('].\n');

  else

    params = zeros(numseeds,numparams);
    seedmodels = cell(1,numseeds);      % final-stage model handle for each seed
    seedtransforms = cell(1,numseeds);  % final-stage transform handle for each seed
    for ss=1:numseeds
      if ismultipleseeds
        fprintf(' trying seed %d of %d.\n',ss,numseeds);
      end

      % run the cascade; each stage after the first is built from the
      % parameters estimated by the previous stage (params0)
      for mm=1:nummodels

        % columns with a NaN lower bound are not optimized
        ix = ~isnan(opt.model{mm}{2}(1,:));

        if mm==1
          seed = opt.model{mm}{1}(ss,:);
          model = opt.model{mm}{3};
          transform = opt.model{mm}{4};
        else
          seed = feval(opt.model{mm}{1},params0);
          model = feval(opt.model{mm}{3},params0);
          transform = feval(opt.model{mm}{4},params0);
        end

        % 3-D stimulus: evaluate the model on every slice along dim 3 and
        % average those predictions for each row. The chunk sizes are taken
        % from the input itself so the wrapper also works for test and full
        % stimuli, whose row counts differ from the training stimulus.
        if size(trainstim,3) > 1
          framemodel = model;
          model = @(pp,dd) chunkfun(feval(framemodel,pp,squish(permute(dd,[3 1 2]),2)), ...
            repmat(size(dd,3),[1 size(dd,1)]),@(x) mean(x,1)).';
        end

        % bounds are not passed to lsqcurvefit when using levenberg-marquardt
        if isequal(options.Algorithm,'levenberg-marquardt')
          lb = [];
          ub = [];
        else
          lb = opt.model{mm}{2}(1,ix);
          ub = opt.model{mm}{2}(2,ix);
        end

        trainstimTRANSFORM = feval(transform,trainstim);

        % prediction as a function of the free parameters only
        fun = @(pp) trainT*feval(model,copymatrix(seed,ix,pp),trainstimTRANSFORM);

        if ismultiplemodels
          fprintf(' for model %d of %d, the seed is [',mm,nummodels); fprintf('%.3f ',seed); fprintf('].\n');
        else
          fprintf(' the seed is ['); fprintf('%.3f ',seed); fprintf('].\n');
        end

        if ~any(ix)
          % nothing to optimize: keep the seed as-is
          params0 = seed;
          resnorm = NaN;
          output = struct('iterations',NaN);
        else
          % NaN predictions are replaced by 0 so the optimizer sees finite values
          [params0,resnorm,~,~,output] = ...
            lsqcurvefit(@(x,y) double(nanreplace(feval(fun,x) / datastd,0,2)),seed(ix),[],double(traindataT),lb,ub,options);
          params0 = copymatrix(seed,ix,params0);
        end

        fprintf(' the estimated parameters are ['); fprintf('%.3f ',params0); fprintf('].\n');

        results.numiters(rr,ss,mm) = output.iterations;
      end

      % seeds are compared on the residual norm of the final stage
      results.resnorms(rr,ss) = resnorm;
      params(ss,:) = params0;
      seedmodels{ss} = model;
      seedtransforms{ss} = transform;
    end

    % pick the best seed (min ignores NaN; an all-NaN row selects seed 1)
    [~,mnix] = min(results.resnorms(rr,:));
    finalparams = params(mnix,:);

    % evaluate with the winning seed's own final-stage model/transform; for
    % cascades these depend on that seed's earlier-stage parameters, so the
    % handles left over from the last seed would be the wrong ones
    model = seedmodels{mnix};
    transform = seedtransforms{mnix};
    if ismultiplemodels && mnix ~= numseeds
      trainstimTRANSFORM = feval(transform,trainstim);
    end

  end

  results.params(rr,:) = finalparams;

  if ~islinear && ismultipleseeds
    fprintf(' seed %d was best. final estimated parameters are [',mnix); fprintf('%.3f ',finalparams); fprintf('].\n');
  end

  % training fit and test predictions (NaN predictions replaced by 0)
  traindatatemp = trainS*traindata;
  if islinear
    modelfittemp = trainS*(mean(trainstim,3)*finalparams');
  else
    modelfittemp = nanreplace(trainS*feval(model,finalparams,trainstimTRANSFORM),0,2);
  end
  if isempty(testdata)
    results.testdata{rr} = [];
    results.modelpred{rr} = [];
  else
    results.testdata{rr} = testS*testdata;
    if islinear
      results.modelpred{rr} = testS*(mean(teststim,3)*finalparams');
    else
      results.modelpred{rr} = nanreplace(testS*feval(model,finalparams,feval(transform,teststim)),0,2);
    end
  end

  % model evaluated on the full stimulus (stored transposed)
  if wantmodelfit
    if islinear
      results.modelfit{rr} = (mean(allstim,3)*finalparams')';
    else
      results.modelfit{rr} = nanreplace(feval(model,finalparams,feval(transform,allstim)),0,2)';
    end
  else
    results.modelfit{rr} = [];
  end

  % performance
  results.trainperformance(rr) = feval(opt.metric,modelfittemp,traindatatemp);
  if isempty(results.testdata{rr})
    results.testperformance(rr) = NaN;
  else
    results.testperformance(rr) = feval(opt.metric,results.modelpred{rr},results.testdata{rr});
  end

  fprintf(' trainperformance is %.2f. testperformance is %.2f.\n', ...
    results.trainperformance(rr),results.testperformance(rr));

end

% ---------------- aggregate across resampling cases ----------------
results.testdata = catcell(1,results.testdata);
results.modelpred = catcell(1,results.modelpred);
results.modelfit = catcell(1,results.modelfit);
if isempty(results.testdata)
  results.aggregatedtestperformance = NaN;
else
  results.aggregatedtestperformance = feval(opt.metric,results.modelpred,results.testdata);
end

fprintf(' aggregatedtestperformance is %.2f.\n',results.aggregatedtestperformance);