import numpy as np
import scipy.stats as stats
import scipy
import math
import matplotlib.pyplot as plt

import pystan

# Initial state of the single ODE variable (one-element array).
init = np.array([1e-4])

# Evenly spaced time grid: n_times points from t_min to t_max inclusive.
n_times = 100
t_min, t_max = 1e-2, 30
ts = np.linspace(t_min, t_max, num=n_times)

# Ground-truth coefficients read by ode_func when simulating the trajectory.
real_alpha = 2
real_gamma = 1

def ode_func(t, xs):
    """Return the right-hand side of dx/dt = real_alpha - real_gamma * x.

    Parameters
    ----------
    t : float
        Current time; unused because the system is autonomous.
    xs : sequence
        Current state; only the first component ``xs[0]`` is read.

    Returns
    -------
    numpy.ndarray
        One-element array holding the derivative of ``xs[0]``.
    """
    x = xs[0]
    dxdt = real_alpha - x * real_gamma
    return np.array([dxdt])

# Integrate the reference ODE over [t_min, t_max]; `init` is the state at
# t_min. dense_output=True makes `ode_sol.sol` callable at arbitrary times.
ode_sol = scipy.integrate.solve_ivp(ode_func, (t_min, t_max),
                                    init, dense_output = True)

# Standard deviation of the observation noise on the log scale.
real_sigma = 0.05

# Noise-free trajectory on the grid, then log-normal observations whose
# underlying normal has mean log(x_hats) and standard deviation real_sigma.
x_hats = ode_sol.sol(ts)[0]
x_obs  = np.random.lognormal(np.log(x_hats), real_sigma)

# Load the Stan program from the file next to this script.
model = pystan.StanModel('simplest_ode.stan')

# Data handed to the Stan program.
# - "n_times" appears exactly once: the original literal listed the key twice
#   and the later value (len(ts) - 1) silently replaced the earlier one.
# - "t0" is t_min (== ts[0]), the time at which `init` was imposed when the
#   data were simulated; that grid point is therefore dropped from "ts"/"xs".
# NOTE(review): "n_params": 2 presumably counts alpha and gamma -- confirm
# against simplest_ode.stan.
simu_data = {"init": init,
             "n_params": 2,
             "n_times": len(ts) - 1,
             "n_vars": len(init),
             "t0": t_min,
             "ts": ts[1:],
             "xs": x_obs[1:]}

fit = model.sampling(data=simu_data, chains=8)

print(fit.stansummary())

