From 7c287638c316b285b3dbcdaf56c650164828a42c Mon Sep 17 00:00:00 2001 From: Joe Anderson Date: Tue, 1 Nov 2011 12:05:08 -0500 Subject: solver almost done, needs more testing --- throw.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/throw.py b/throw.py index b144eaa..fecc93d 100644 --- a/throw.py +++ b/throw.py @@ -1,6 +1,6 @@ from time import time as current_time -import sys +import sys, itertools import numpy from scipy.interpolate import interp1d from scipy.integrate import odeint @@ -155,21 +155,19 @@ def solver(func, init_vars, timespace=(0,600,10**5), max_tension=max_tension, break getslice = lambda x: numpy.ndarray.__getslice__(x, 0, max_arg+1) - getmax = lambda x: numpy.ndarray.__getitem__( x, max_arg) + getmax = lambda x: numpy.ndarray.__getitem__( x, max_arg ) slices = map(getslice, (time,) + variables) maxes = map(getmax, (time,) + variables) out = [] for item in returns: out.append(vars()[item]) - return tuple(out) -def plot(time, soln, graphs=['all'], i=7): - phi, phi1, v_tan, tension, moment, momentum, max_arg = \ - interpret(soln) - (phi_max, phi1_max, v_tan_max, tension_max, moment_max, momentum_max, - time_max) = map(lambda x: numpy.ndarray.__getitem__(x, max_arg), - [phi, phi1, v_tan, tension, moment, momentum, time]) +def plot(solved, graphs=['all'], i=7): + time, variables, slices, maxes = solved + phi, phi1, v_tan, tension, moment, momentum = variables + (phi_max, phi1_max, v_tan_max, + tension_max, moment_max, momentum_max) = maxes def test(*args): b = 0 @@ -259,9 +257,8 @@ def opt_fun(guesses, f, func, init_vars): map(lambda x: numpy.ndarray.__getslice__(x, 0, max_arg+1), [phi, phi1, v_tan, tension, moment, momentum, time]) - (time, soln, - (phi, phi1, v_tan, tension, moment, momentum, max_arg), - ( + time, (phi, phi1, v_tan, tension, moment, momentum, max_arg) = \ + solver(func, init_vars, returns=['time', 'slices']) # Parameter values. f.write("%s, " % r_roll) @@ -296,8 +293,8 @@ def optimize(): finally: fil.close() def solve(): - time, soln = solve_fun(throw, inits) - plot(time, soln, ['all', '!gravity']) + solved = solver(throw, inits) + plot(solved, ['all', '!gravity']) pyplot.show() if len(sys.argv) == 2: -- cgit v1.2.3-54-g00ecf