
# $+HEADER$
# 
# Copyright 2017-2018 Christoph Lueders
# 
# This file is part of the PTCUT project: <http://wrogn.com/ptcut>
# 
# PTCUT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or 
# (at your option) any later version.
# 
# PTCUT is distributed in the hope that it will be useful, 
# but WITHOUT ANY WARRANTY; without even the implied warranty of 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
# GNU Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public License
# along with PTCUT.  If not, see <http://www.gnu.org/licenses/>.
# 
# $-HEADER$

from __future__ import print_function, division
from sage.all_cmdline import *   # import sage library

import time
from util import mytime, prt

# Revision number of this file, parsed from the Subversion "Rev" keyword
# string below (presumably expanded by svn on commit via svn:keywords --
# confirm).  split() yields ["$Rev:", "<number>", "$"], so [1] is the number.
svnrev = int("$Rev: 6627 $".split()[1])


def handle_params(ps):
   """
   Convert parameters from Params.txt file into dictionary.

   ps is an iterable of lines of the form "name = value".  Lines that are
   empty or whitespace-only are skipped; every other line must contain
   exactly one "=" (checked with assert).  The name is stripped of
   surrounding whitespace and the value is converted with float().

   >>> sorted(handle_params(["k1 = 1.2121", "k2 = 0.13889", "k3 = 1e-12", "k4 = 19000000", "k5 = 25", "k6 = 0.813", "k7 = 0.557", "k8 = 400000", "k9 = 100000"]).items())
   [('k1', 1.2121), ('k2', 0.13889), ('k3', 1e-12), ('k4', 19000000.0), ('k5', 25.0), ('k6', 0.813), ('k7', 0.557), ('k8', 400000.0), ('k9', 100000.0)]
   """
   params = {}
   for line in ps:
      if line.strip():                # only non-blank lines carry a parameter
         fields = line.split("=")
         assert len(fields) == 2
         name, value = fields
         params[name.strip()] = float(value)
   return params


def logep(x, ep, scale=1):
   """
   Return the logarithm of x to the base ep, multiplied by scale and
   rounded to the nearest integer (returned as a Python int).

   >>> logep(3000,0.2)
   -5
   >>> logep(0.0001,0.2)
   6
   >>> logep(0.0001,1/5)
   6
   >>> logep(0.0001,Integer(1)/Integer(5))
   6
   >>> logep(100,10)
   2
   >>> logep(99,10,10)
   20
   >>> logep(80,10,10)
   19
   """
   exponent = log(x, ep) * scale
   return int(round(exponent))


def mysgn(x):
   """
   Simple self-made signum function.

   Returns -1 if x is negative, 1 if x is positive, and 0 otherwise
   (i.e. for zero, or any value that is neither < 0 nor > 0).

   >>> mysgn(12)
   1
   >>> mysgn(-1009912)
   -1
   >>> mysgn(0.001232)
   1
   >>> mysgn(0)
   0
   """
   if x < 0:
      return -1
   if x > 0:
      return 1
   return 0


def handle_polysys(str_polys, params, ep, sumup, verbose=0, keep_coeff=False, paramwise=True, quiet=False, scale=1):
   """
   Tropicalize a polynomial system given as a list of strings.

   str_polys is the polynomial system, one poly per line as string.  Text
   after the first ";" on a line is treated as a comment, and lines that are
   blank after that are skipped.
   params is a dict of parameter values (e.g. from handle_params()) that is
   substituted into the polynomials with sympy's subs().
   ep is the base handed to logep() when a coefficient is turned into an
   integer exponent; scale is passed on to logep() as well.
   sumup: if true, params is substituted into each expanded polynomial as a
   whole.  If false, it is substituted term by term and the terms are
   re-parsed with sympify(..., evaluate=False) (presumably so that sympy
   does not merge terms sharing a monomial -- confirm).
   keep_coeff: only used when sumup is false.  If false, the numeric
   constant factor of each term is divided out (its sign is kept) before
   the parameters are substituted.
   paramwise: if true, every positive parameter is first rounded to an
   integer power of ep; this mode requires scale == 1.
   quiet: suppress the progress output; also forces verbose off.

   Returns a list with one dict per polynomial.  Each key is a tuple
   (logep of |coefficient|, exponent of var 1, ..., exponent of var n),
   with the variables ordered by the number after their first character.
   Each value is the sign (1 or -1) of that term's coefficient, or 0 if
   terms with the same key but different signs were found.

   >>> d = handle_params(["k1 = 1.2121", "k2 = 0.13889", "k3 = 1e-12", "k4 = 19000000", "k5 = 25", "k6 = 0.813", "k7 = 0.557", "k8 = 400000", "k9 = 100000"])
   >>> s = ["-k4*x2*(-k8+x2+k9)+k5*(k8-x2-x4)", "k6*(k8-x2-x4)-k7*x4"]
   >>> [sorted(i.items()) for i in handle_polysys(s, d, 1.0/5, True, 0, True, False, True)]
   [[((-18, 1, 0), 1), ((-10, 0, 0), 1), ((-10, 2, 0), -1), ((-2, 0, 1), -1)], [((-8, 0, 0), 1), ((0, 0, 1), -1), ((0, 1, 0), -1)]]
   """
   if quiet:
      verbose = False
   if paramwise:
      assert scale == 1
      # Satya does parameter-wise rescaling.
      # To make this easy, round parameters according to ep before they are used:
      # each positive value v becomes ep**logep(v, ep); non-positive values are kept.
      p2 = {}
      for k,v in params.items():
         p2[k] = ep ** logep(v, ep) if v > 0 else v
      params = p2
   polys = []
   vars = set()
   from sympy import sympify
   for p in str_polys:
      p = p.split(";")[0]        # remove comments
      if not p.split():          # ignore empty lines
         continue
      if not quiet:
         prt(end="[{}]".format(len(polys)), flush=True)     # progress: index of the poly being parsed
      p = p.strip()
      if verbose:
         prt("\ninput: {}".format(p), flush=True)
      f = sympify(p)             # convert string to sympy object
      if verbose:
         prt("sympify into: {}".format(f), flush=True)
      f = f.expand()             # expand representation into terms
      if verbose:
         prt("expand into: {}".format(f), flush=True)
      # substitute "k" parameters
      if sumup:
         f = f.subs(params)
      else:
         # rebuild the polynomial as a string, one substituted term at a time
         f1 = ""
         for t in f.as_ordered_terms():         # list of terms
            if not keep_coeff:
               cm = t.as_coeff_mul()
               if abs(cm[0]) != 1:
                  t = sympify(str(t) + "/" + str(abs(cm[0])))          # remove constant (sign is kept)
            t = t.subs(params)                  # substitute "k" parameters
            f1 += (" + " if f1 else "") + str(t)
         f = sympify(f1, evaluate=False)
      if verbose:
         prt("subs into: {}".format(f))
      polys.append(f)
      if f.free_symbols:
         vars |= f.free_symbols
      if verbose:
         prt(flush=True)
   # Variables are expected to be named xNNN.
   # NOTE(review): a mismatch is only reported, not rejected.  The sort key
   # below applies int() to everything after the first character, so e.g. a
   # parameter "k5" left over because it was missing from params would be
   # silently treated as a variable, and a name without a numeric suffix
   # makes the sort raise ValueError.
   if not all([i[0] == "x" and i[1:].isdigit() for i in [str(j) for j in list(vars)]]):
      prt("all variables should be called xNNN, where NNN are digits")
   # sort variables numerically by their suffix (x2 before x10)
   vars = sorted(list(vars), key=lambda x: int(str(x)[1:]))     # all variables still left
   if verbose:
      prt("variables: {}".format(vars))

   trop = []
   for p in polys:
      d = {}
      for t in p.as_ordered_terms():         # list of terms
         tv = [0] * (len(vars) + 1)
         coeff, vprod = t.as_coeff_Mul()     # split off the coefficient
         #prt("{}*{}".format(coeff, vprod))
         # "vprod" are the vars, like x1*x2**2
         # coeff is the multiplicative factor
         tv[0] = logep(abs(coeff), ep, scale)
         sign = mysgn(coeff)
         for v, exp in vprod.as_powers_dict().items():
            #prt("  {}**{}".format(v, exp))
            if v.is_constant():
               assert v == 1                 # constant was consumed as coeff and put into tv[0]
            else:
               i = vars.index(v) + 1         # index 0 is the constant
               tv[i] = int(exp)
         k = tuple(tv)                       # must convert to immutable tuple and int to remove sympy.core.numbers.One
         if k in d:
            if d[k] != sign:
               # if we have two tropical points with different signs, save the sign as zero.
               # this will disable opposite sign checking for this point.
               d[k] = 0
         else:
            d[k] = sign
      trop.append(d)
      if verbose:
         prt("trop: {}".format(d), flush=True)
   return trop


epr5 = Integer(1) / Integer(5)
def tropicalize_system(mod, mod_dir, ep=epr5, sumup=True, verbose=0, keep_coeff=False, paramwise=True, param_override=None, scale=1):
   """
   Read a model's parameters and polynomial system and tropicalize it.

   mod is not referenced in this function body; it is kept for call
   compatibility.
   mod_dir is a path prefix that is concatenated directly with the file
   names "Params.txt" and "Polynomial_system.txt" (so it normally ends with
   a path separator).
   ep, sumup, verbose, keep_coeff, paramwise and scale are passed on to
   handle_polysys().
   param_override optionally replaces or adds parameter values after
   Params.txt has been read.  It may be a dict or an iterable of
   (name, value) pairs, such as a point produced by sample_grid();
   None (the default) means no overrides.

   Prints the time spent in handle_polysys() and returns its result.
   """
   # read parameters from file
   with open(mod_dir + "Params.txt") as f:
      str_params = f.readlines()
   params = handle_params(str_params)
   # apply overrides, if any given.  those must be str/float pairs;
   # dict.update() accepts both a mapping and an iterable of pairs.
   if param_override is not None:
      params.update(param_override)
   # read polynomial system from file; the file is closed before the
   # (possibly long) tropicalization starts
   with open(mod_dir + "Polynomial_system.txt") as f:
      str_polysys = f.readlines()
   start = mytime()
   r = handle_polysys(str_polysys, params, ep, sumup, verbose, keep_coeff, paramwise, scale=scale)
   total = mytime() - start
   prt("\nTropicalization time: {} sec".format(total))
   return r


def read_grid_data(ss):
   """
   Parse a parameter grid specification.

   ss is a comma-separated list of entries "name:start[:stop[:step]]".
   stop defaults to start and step defaults to 1.  A step written as "*N"
   means the value is multiplied by N on each step instead of incremented
   by N.  Whitespace around entries and fields is ignored and empty entries
   (e.g. from a trailing comma) are skipped.

   Returns a list of (name, start, stop, step, mult) tuples with the numbers
   converted to float and mult a bool.

   >>> read_grid_data("k1:100:200:10")
   [('k1', 100.0, 200.0, 10.0, False)]
   >>> read_grid_data("k10:1.5:2.5:0.1,k11:5:9:1")
   [('k10', 1.5, 2.5, 0.1, False), ('k11', 5.0, 9.0, 1.0, False)]
   >>> read_grid_data("")
   []
   >>> read_grid_data("k1:100:200:*10")
   [('k1', 100.0, 200.0, 10.0, True)]
   >>> read_grid_data("k1:100")
   [('k1', 100.0, 100.0, 1.0, False)]
   >>> read_grid_data("k1:100, k2:5:6")
   [('k1', 100.0, 100.0, 1.0, False), ('k2', 5.0, 6.0, 1.0, False)]
   """
   r = []
   for entry in ss.split(","):
      entry = entry.strip()
      if not entry:
         continue
      # strip each field so names like " k2" cannot silently miss a parameter
      a = [field.strip() for field in entry.split(":")]
      stop = a[2] if len(a) > 2 else a[1]     # single value: stop == start
      mult = False
      step = 1
      if len(a) > 3:
         step = a[3]
         if step.startswith("*"):             # "*N": multiplicative step
            mult = True
            step = step[1:]
      r.append((a[0], float(a[1]), float(stop), float(step), mult))
   return r


def sample_grid(grid):
   """
   Generate every point of a parameter grid.

   grid is a list of (name, start, stop, step, mult) tuples as returned by
   read_grid_data().  Each parameter runs from start up to stop (inclusive),
   adding step each time (mult false) or multiplying by step (mult true).
   The first parameter varies fastest.

   Each point is yielded as a list of (name, value) pairs.  For an empty
   grid a single empty dict is yielded, so iteration still produces exactly
   one point.  Both forms can be passed to dict() or dict.update().

   NOTE(review): values are accumulated by repeated float addition or
   multiplication, so rounding error can leave the final value slightly
   above stop (as can happen when repeatedly adding 0.1), in which case it
   is skipped.  A step that never moves the value above stop (e.g. a
   negative additive step, or a factor between 0 and 1 with a positive
   start) makes the generator loop forever; the asserts only reject an
   additive step of 0 and a factor of 1.

   >>> list(sample_grid([("k1",1,5,1,False)]))
   [[('k1', 1)], [('k1', 2)], [('k1', 3)], [('k1', 4)], [('k1', 5)]]
   >>> list(sample_grid([]))
   [{}]
   >>> list(sample_grid([("k1",1,3,1,False), ("k2",6,7,1,False)]))
   [[('k1', 1), ('k2', 6)], [('k1', 2), ('k2', 6)], [('k1', 3), ('k2', 6)], [('k1', 1), ('k2', 7)], [('k1', 2), ('k2', 7)], [('k1', 3), ('k2', 7)]]
   >>> list(sample_grid([("k1",1,8,2,True)]))
   [[('k1', 1)], [('k1', 2)], [('k1', 4)], [('k1', 8)]]
   >>> list(sample_grid([("k1",1,8,2,True), ("k2",1,10,10,True)]))
   [[('k1', 1), ('k2', 1)], [('k1', 2), ('k2', 1)], [('k1', 4), ('k2', 1)], [('k1', 8), ('k2', 1)], [('k1', 1), ('k2', 10)], [('k1', 2), ('k2', 10)], [('k1', 4), ('k2', 10)], [('k1', 8), ('k2', 10)]]
   """
   if not grid:
      yield {}
      return
   names = [g[0] for g in grid]
   values = [g[1] for g in grid]            # start values
   while True:
      # Yield a snapshot: values is mutated below, and on Python 3 a bare
      # zip() is a lazy iterator that would only read it when consumed.
      yield list(zip(names, values))
      for i, g in enumerate(grid):
         step = g[3]
         if g[4]:
            assert step != 1
            values[i] *= step             # multiply by step size
         else:
            assert step != 0
            values[i] += step             # increment by step size
         if values[i] > g[2]:             # over upper limit?  (inclusive!)
            values[i] = g[1]              # wrap around, carry into next parameter
         else:
            break                         # values ok, exit loop
      else:
         break                            # all positions covered, exit outer loop


# Lists of model names, grouped by (apparently) running time and difficulty.
# NOTE(review): presumably each name is a model directory and/or a key of the
# optional solution database used by load_satyas_solution() -- confirm
# against callers.

# less than 0.5 sec
# NOTE(review): this timing note presumably applies to all fast models
# (biomd_fast below includes biomd_nosol); "nosol" presumably means no
# solution is expected for these models -- confirm.
biomd_nosol = [
   "BIOMD0000000108",
   "BIOMD0000000220p",
   "BIOMD0000000342",
]

# biomd_nosol plus further fast models, sorted by name.
biomd_fast = sorted(biomd_nosol + [
   "BIOMD0000000005c_modified",
   "BIOMD0000000027_transfo", "BIOMD0000000029_transfo", "BIOMD0000000031_transfo",
   "BIOMD0000000035", "BIOMD0000000040", "BIOMD0000000072", "BIOMD0000000077", "BIOMD0000000101", "BIOMD0000000125", "BIOMD0000000150",
   "BIOMD0000000156", "BIOMD0000000159", "BIOMD0000000193", "BIOMD0000000194", "BIOMD0000000198", "BIOMD0000000199",
   "BIOMD0000000233", "BIOMD0000000257", "BIOMD0000000257c",
   "BIOMD0000000289", "BIOMD0000000361", "BIOMD0000000459", "BIOMD0000000460" ])

# more than 0.5 sec
biomd_slow = [ "BIOMD0000000001", "BIOMD0000000002", "BIOMD0000000009", "BIOMD0000000009p", "BIOMD0000000026", "BIOMD0000000026c",
   "BIOMD0000000028", "BIOMD0000000030", "BIOMD0000000038", "BIOMD0000000046",
   "BIOMD0000000080", "BIOMD0000000082", "BIOMD0000000102", "BIOMD0000000122", "BIOMD0000000123",
   "BIOMD0000000163", "BIOMD0000000226", "BIOMD0000000270",
   "BIOMD0000000287" ]

# Subset of biomd_slow.  NOTE(review): what "hull" refers to is not visible
# here -- confirm.
biomd_slowhull = [ "BIOMD0000000001", "BIOMD0000000002", "BIOMD0000000026", "BIOMD0000000028", "BIOMD0000000030", "BIOMD0000000038", "BIOMD0000000046",
   "BIOMD0000000080", "BIOMD0000000082", "BIOMD0000000123", "BIOMD0000000163", "BIOMD0000000270" ]

biomd_hard = [ "BIOMD0000000146_numer", "bluthgen0", "bluthgen1", "bluthgen2" ]
# Not included in biomd_all below.
biomd_toohard = [ "BIOMD0000000019", "BIOMD0000000255" ]

# Derived collections (all sorted by name):
#   biomd_simple = fast + slow models
#   biomd_all    = simple + hard models (biomd_toohard is excluded)
#   biomd_easy   = simple models without BIOMD0000000102 (reason not stated here)
biomd_simple = sorted(biomd_fast + biomd_slow)
biomd_all = sorted(biomd_simple + biomd_hard)
biomd_easy = sorted(set(biomd_simple) - set(["BIOMD0000000102"]))


def load_satyas_solution(mod):
   """
   Look up a stored reference solution for model mod.

   The solutions come from the optional module _biomdsoldb, whose
   biomd_sol_db maps keys to callables that are invoked with no arguments.
   Returns the result of calling the entry for mod, or None if the module
   cannot be imported or has no entry for mod.
   """
   try:
      from _biomdsoldb import biomd_sol_db
   except ImportError:
      return None                 # the solution database is optional
   if mod in biomd_sol_db:
      return biomd_sol_db[mod]()
   return None


if __name__ == "__main__":
   # run the doctests embedded in this module's docstrings
   from doctest import testmod
   testmod()
