From f99ad4ed9e5611ced9196e5e9cb875667e03bdc1 Mon Sep 17 00:00:00 2001 From: Joe Anderson Date: Sun, 13 Nov 2011 15:28:39 -0600 Subject: in the middle of fixing up opt_fun --- diffeq.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/diffeq.py b/diffeq.py index c9447ac..1a756e1 100644 --- a/diffeq.py +++ b/diffeq.py @@ -58,7 +58,7 @@ class DiffEq(object): if key[0] == "_": continue if type(value) != float: continue self.opt_vars[key] = value - def time(self): return time.time() - self.time + def _gettime(self): return time.time() - self.time def __getattribute__(self, attr): """Provides a few getattr psuedonyms: @@ -93,30 +93,27 @@ class DiffEq(object): vars(self).remove(key) ssetattr(attr) - def check_in(self, variables): - """Usage is as follows. + def solve_diff_eq(self): + """wrapper for odeint and self.diff_eq - def method(self, var0=None, var1=None, ...): - self.check_in(locals()) - return do_stuff_with_vars() + requires the following: + self.diff_eq + self.init_vars + self.timespace """ - for key, value in variables.items(): - if value == None: - variables[key] = getattr(self, key) - - def solve_diff_eq(self, diff_eq=None, init_vars=None, timespace=None): - """wrapper for odeint and self.diff_eq""" - self.check_in(locals()) - time0 = self.time() - time = numpy.linspace(*timespace) - soln = odeint(diff_eq, init_vars, time) - print " :: took %s seconds" % (self.time() - time0) - return time, soln + time0 = self.gettime() + self.time = numpy.linspace(*self.timespace) + self.soln = odeint(self.diff_eq, self.init_vars, self.time) + print " :: took %s seconds" % (self.gettime() - time0) #in optimize(): vars().update(self.opt_vars) - def opt_fun(guesses, f, func, init_vars): + def opt_fun(self, guesses, f, func, init_vars): """Optimize the free variables for the diff eq.""" - global r0, r_roll, length, m, Mass + for key, val in zip(opt_vars.iterkeys(), guesses): + setattr(self, key, val) + #map(setattr, self, opt_vars.iterkeys(), guesses) + self.__dict__.update(dict(zip(self.opt_vars.iterkeys(), guesses))) + outfile = self.__dict__.get('outfile') r0, r_roll, length, m, Mass = guesses time, soln = solve_fun(func, init_vars) -- cgit v1.2.3-54-g00ecf