/* projgrad.c ************************************************
   Steven Dirkse
   Computer Science Department, UW-Madison
   Source for a (simple) projected-gradient solver for MCP
   *******************************************************/

#include <stdio.h>
#include <stdlib.h>   /* malloc, free, exit (standard home of these) */
#include <malloc.h>   /* non-standard; kept for existing builds */
#include <math.h>

#ifdef GAMSLINK
#include "c_cplib.h"
#elif defined(AMPLLINK)
#include "mcp.h"
#else
bad_compile_defines;
#endif


/* Allocate `num` objects of `type` through Mymalloc (memory is not zeroed). */
#define MEMALLOC(type, num)  ((type *) Mymalloc (sizeof (type) * (num)))
/* Clamp z into [l, u].  Arguments may be evaluated more than once, so
   pass only side-effect-free expressions. */
#define PROJECT(l, z, u)     ((z) < (l) ? (l) : (z) > (u) ? (u) : (z))

void projected_gradient (INT n);
DREAL nonsmooth_norm (INT n, DREAL *z, DREAL *F, DREAL *lb, DREAL *ub);
void *Mymalloc (long Len);

/* Run settings; in the GAMS build solver() reads these from CPLIB. */
INT iterlimit;			/* maximum number of projected-gradient steps */
INT iosta;			/* CPLIB output channel (banner is printed here) */
INT iolog;			/* CPLIB output channel used for the iteration log */
INT screen;			/* CPLIB output channel for the screen */

/* Working vectors of length n, allocated in projected_gradient(). */
DREAL *z;			/* current iterate */
DREAL *bl;			/* lower bounds */
DREAL *bu;			/* upper bounds */
DREAL *F;			/* function values at z */

CHAR msgbuf[256];		/* scratch buffer for formatted messages */

#ifdef GAMSLINK
/* corerq -- tell CPLIB what this solver needs: how much work space to
   provide ("nwucor") and which problem class it handles ("istype"). */
#ifdef POSTUC
void corerq_ (void)
#else
void corerq (void)
#endif
{
  /* One "word" of work space is requested; solver() never touches it. */
  c_cpputi ("nwucor", 1);
  /* Problem type 2 selects general MCPs. */
  c_cpputi ("istype", 2);
}

/* solver -- GAMS entry point.
 *
 * Reads the problem size, the three CPLIB output channels and the
 * iteration limit into the file-scope settings, prints a banner, and then
 * runs projected_gradient() on the n-variable problem.  The `work` and
 * `nwucor` arguments are not referenced (presumably the work space
 * requested through "nwucor" in corerq).
 */
#ifdef POSTUC
void solver_ (DREAL *work, INT *nwucor)
#else
void solver (DREAL *work, INT *nwucor)
#endif
{
  INT n = c_cpgeti ("n");

  iosta     = c_cpgeti ("iosta");
  iolog     = c_cpgeti ("iolog");
  screen    = c_cpgeti ("screen");
  iterlimit = c_cpgeti ("iterlim");

  /* The banner goes to the log and the status channel, and to the screen
     only when the screen is a different channel from the log. */
  c_cpputi ("startc", 0);
  sprintf (msgbuf, "%s", "Sample solver programmed by Steve Dirkse");
  if (screen != iolog) {
    c_print_msg (screen, msgbuf);
  }
  c_print_msg (iolog, msgbuf);
  c_print_msg (iosta, msgbuf);
  c_cpputi ("stopc", 0);

  projected_gradient (n);
}
#endif

/* projected_gradient -- fixed-step projected-gradient method for the MCP.
 *
 * Not a complete implementation, just a model: the step size is fixed and
 * there is no line search or other globalization.
 *
 * The initial point, bounds and F come from the link layer selected at
 * compile time (GAMS/CPLIB or AMPL).  Each iteration performs
 *     z[i] <- PROJECT(bl[i], z[i] - stepsize*F[i], bu[i])
 * and then re-evaluates F and the residual returned by nonsmooth_norm().
 * The residual is tested before every step and once more after the last
 * allowed step.  The run therefore succeeds as soon as the residual drops
 * below `tolerance` (including at the initial point or on the final
 * iteration); otherwise it stops after iterlimit steps.
 *
 * n is the number of variables.  The file-scope arrays bl, bu, z and F are
 * (re)allocated with n entries.  Arrays left over from a previous call are
 * released first.  The new arrays stay allocated on return, holding the
 * bounds, last iterate and last function values.  Progress is written to
 * the CPLIB log (iolog) in the GAMS build, or to stdout otherwise.
 */
void projected_gradient (INT n)
{
  const DREAL tolerance = 1e-6;	/* success threshold on the residual */
  const DREAL stepsize = 0.5;	/* should be set by user, but . .  */
  int iteration = 0,
  i;
  DREAL metric;
#ifdef GAMSLINK
  CHAR *s;
  DREAL obj;
#endif

  /* free(NULL) is a no-op, so this only releases arrays from an earlier
     call; otherwise every call would leak four n-vectors. */
  free (bl);
  free (bu);
  free (z);
  free (F);
  bl = MEMALLOC (DREAL,n);
  bu = MEMALLOC (DREAL,n);
  z = MEMALLOC (DREAL,n);
  F = MEMALLOC (DREAL,n);

  /* get initial iterate and lower, upper bounds */
  /* then evaluate F */
#ifdef GAMSLINK
#ifdef POSTUC
  cpbnds_ (z, bl, bu, &n);
  cpfunf_ (z, F, &n);
#else
  cpbnds (z, bl, bu, &n);
  cpfunf (z, F, &n);
#endif
#elif defined(AMPLLINK)
  mcp_bounds (n, bl, bu);
  mcp_init_z (n, z);
  mcp_F (n, z, F);
#else
  bad_compile_defines;
#endif

  metric = nonsmooth_norm (n, z, F, bl, bu);
#ifdef GAMSLINK
  s = msgbuf;
  sprintf (s, "\n iterate\t  residual norm\t     CPLIB norm\n");
  while (*s) s++;
  sprintf (s,   " -------\t  -------------\t     ----------");
  c_print_msg (iolog, msgbuf);
  obj = c_cpgetd ("obj");
  sprintf (msgbuf, "%8d\t%15.7f\t%15.7f", iteration, metric, obj);
  c_print_msg (iolog, msgbuf);
#else
  fprintf (stdout, "\n iterate\t  residual norm\n");
  fprintf (stdout, " -------\t  -------------\n");
  fprintf (stdout, "%8d\t%15.7f\n", iteration, metric);	/* one line per iterate */
#endif

  for (;;) {
    /* Convergence is tested before the iteration limit so that a point
       that already solves the problem (initially, or after the last
       allowed step) is reported as solved rather than as a failure. */
    if (metric < tolerance) {
#ifdef GAMSLINK
#ifdef POSTUC
      cpsoln_ (z, &n);
#else
      cpsoln (z, &n);
#endif
      c_cpputi ("modsta", MODEL_SOLVED);
      c_cpputi ("solsta", SOLU_NORMAL);
#else /* AMPL link */
      mcp_report_soln ("Solution found", n, z, F);
#endif
      return;
    }
    if (iteration >= iterlimit)
      break;

    /* Gradient-type step on F followed by projection onto the box. */
    for (i = 0;  i < n;  i++) {
      z[i] -= stepsize * F[i];
      z[i] = PROJECT(bl[i], z[i], bu[i]);
    }
    iteration++;
#ifdef GAMSLINK
#ifdef POSTUC
    cpfunf_ (z, F, &n);
#else
    cpfunf (z, F, &n);
#endif
#else  /* AMPL link */
    mcp_F (n, z, F);
#endif
    metric = nonsmooth_norm (n, z, F, bl, bu);
#ifdef GAMSLINK
    obj = c_cpgetd ("obj");
    sprintf (msgbuf, "%8d\t%15.7f\t%15.7f", iteration, metric, obj);
    c_print_msg (iolog, msgbuf);
#else  /* AMPL link */
    fprintf (stdout, "%8d\t%15.7f\n", iteration, metric);
#endif
  }
#ifdef GAMSLINK
  c_cpputi ("modsta", MODEL_NOT_SOLVED);
  c_cpputi ("solsta", SOLU_ITERATION);
#else  /* AMPL link */
  mcp_report_soln ("Solution not found: iteration limit", n, NULL, NULL);
#endif
  return;
}

#define MAX(a,b) ((a) > (b) ? (a) : (b))
#define MIN(a,b) ((a) < (b) ? (a) : (b))
/* nonsmooth_norm -- Euclidean norm of the MCP complementarity residual.
 *
 * For each of the n components the residual r_i is:
 *   0            if lb[i] == ub[i] == z[i] (fixed variable),
 *   min(0, F[i]) if z[i] is at (or below) its lower bound,
 *   max(0, F[i]) if z[i] is at (or above) its upper bound,
 *   F[i]         if z[i] lies strictly between its bounds.
 * Returns sqrt(sum_i r_i^2); projected_gradient() treats a small value as
 * convergence.
 */
DREAL
nonsmooth_norm (INT n, DREAL z[], DREAL F[], DREAL lb[], DREAL ub[])
{
  int i;
  double norm = 0,
  dt;

  for (i = 0;  i < n;  i++) {
    if (z[i] <= lb[i] && z[i] >= ub[i])
      /* Fixed variable: z = PROJECT(lb, z - t*F, ub) for every F, so no
         sign of F is a violation.  Previously only the lower-bound test
         ran here, penalizing F < 0 but not F > 0. */
      dt = 0;
    else if (z[i] <= lb[i])
      dt = MIN(0,F[i]);		/* at lower bound: only F < 0 violates */
    else if (z[i] >= ub[i])
      dt = MAX(0,F[i]);		/* at upper bound: only F > 0 violates */
    else
      dt = F[i];		/* interior: F must vanish */
    norm += dt*dt;
  }
  return (sqrt(norm));
}

/* Mymalloc -- malloc wrapper that reports failures and, depending on the
 * link layer, aborts the run.
 *
 * Len is the request size in bytes.  A negative Len, or one that cannot
 * be represented in size_t, is rejected as "Too large a request".  A
 * failed malloc is reported as "Ran out of memory".  In both cases the
 * message goes to stdout.  The GAMS build then calls cppunt(); the AMPL
 * and STANDALONE builds call exit(-1).  If control still comes back here
 * (cppunt presumably does not return -- confirm; other builds just fall
 * through), NULL is returned.
 *
 * A zero-byte request allocates one byte.  This keeps n == 0 from being
 * misreported as out of memory, since malloc(0) may legally return NULL.
 */
void *Mymalloc(long Len)
{
  void *rv;
  size_t len = (size_t) Len;

  /* Reject sizes that are negative or would be truncated by the size_t
     conversion; either way malloc would not see the requested size. */
  if (Len < 0 || (sizeof(Len) != sizeof(len) && Len != (long)len)) {
    fprintf (stdout, "%s(%ld) failure: %s.\n",
             "Mymalloc", Len, "Too large a request");
#ifdef GAMSLINK
#ifdef POSTUC
    cppunt_ ();
#else
    cppunt ();
#endif /* ! defined(POSTUC) */

#elif defined(AMPLLINK)
    exit(-1);
#elif defined(STANDALONE)
    exit(-1);
#endif
    /* Never hand back a buffer smaller than the caller asked for. */
    return NULL;
  }
  rv = malloc(len ? len : 1);
  if (!rv) {
    fprintf (stdout, "%s(%ld) failure: %s.\n",
             "Mymalloc", Len, "Ran out of memory");
#ifdef GAMSLINK
#ifdef POSTUC
    cppunt_ ();
#else
    cppunt ();
#endif /* ! defined(POSTUC) */

#elif defined(AMPLLINK)
    exit(-1);
#elif defined(STANDALONE)
    exit(-1);
#endif
  }
  return rv;
}
