;; Copyright (c) 2023 
;; SPDX-License-Identifier: MIT
#!r6rs

(library (conbot private pty)
  (export
    open-pty-process
    process?
    process-id
    process-exit-code
    process-exit-signal
    process-pty-fd
    process-pty-port
    process-stdin-port
    process-stdout-port
    process-args
    process-env
    process-kill
    process-close
    process-alive?
    process-io-error?
    process-resize
    process-wait
    current-directory)
  (import (rnrs)
          (pffi))

;; Handle to the helper library that creates ptys, resizes them and waits on
;; the spawned children.
(define libpty (open-shared-object "libpty.so"))
;; int pty_open(...) -- as called by open-pty-process: out-cell for the pty
;; fd, program path, two ints (25 and 80 there, presumably rows and columns),
;; argv array, cwd, env array, and out-cells for the redirected stdin/stdout
;; fds.  The return value is used as the child's pid.
(define pty_open (foreign-procedure libpty int pty_open
                   (pointer pointer int int pointer pointer pointer pointer pointer)))
;; int pty_resize(pid, fd, rows, cols) -- argument order as used by
;; process-resize.
(define pty_resize (foreign-procedure libpty int pty_resize
                   (int int int int)))
;; libc getcwd(3).  NOTE(review): the C prototype takes a size_t length;
;; declaring it `int` relies on pffi widening the argument -- confirm.
(define getcwd (foreign-procedure #f pointer getcwd (pointer int)))
;; int process_wait(pid, timeout); the packed result is decoded by
;; process-wait.
(define process_wait (foreign-procedure libpty int process_wait (int int)))

;; errno values inspected by read/bv and write/bv.  These are the Linux
;; numbers; other platforms may use different values.
(define EINTR 4)
(define EIO 5)
(define EAGAIN 11)

;; Decode the NUL-terminated C string at pointer P into a Scheme string.
;; The bytes up to (not including) the first zero byte are decoded as UTF-8,
;; so non-ASCII text (e.g. paths with accented characters) round-trips
;; correctly; invalid sequences are replaced per R6RS utf8->string.
;; P must be non-null.
(define (pointer->string p)
  (let loop ((i 0) (bytes '()))
    (let ((b (pointer-ref-c-uint8 p i)))
      (if (fxzero? b)
          (utf8->string (u8-list->bytevector (reverse bytes)))
          (loop (fx+ i 1) (cons b bytes))))))

;; Thin bindings to libc open(2), read(2), write(2), close(2) and kill(2).
;; NOTE(review): read/write take a size_t count and return ssize_t in C; they
;; are declared with `int` here -- confirm pffi's argument/return widening.
;; sys-open is not referenced elsewhere in this library as shown.
(define sys-open (foreign-procedure #f int open (pointer int)))
(define sys-read (foreign-procedure #f int read (int pointer int)))
(define sys-write (foreign-procedure #f int write (int pointer int)))
(define sys-close (foreign-procedure #f int close (int)))
(define sys-kill (foreign-procedure #f int kill (int int)))

;; libc's errno, consulted by read/bv and write/bv after a failed call.
;; NOTE(review): glibc's errno is thread-local and normally accessed via a
;; macro; confirm this foreign variable really tracks the current thread's
;; errno.
(define-foreign-variable #f int errno)

;; Signal a failed read on FD (port name FNAME) with errno value ERR.
(define handle-read-error
  (lambda (fd fname err)
    (error 'pty-read "Error during reading from pty" fd fname err)))

;; Read up to COUNT bytes from FD into BV starting at index START.
;; Returns the number of bytes read (0 means end of file), as required of an
;; R6RS custom-port read! procedure.
;;  - EINTR / EAGAIN: the read is retried, and the retried read's result is
;;    what gets returned.
;;  - EIO: treated as end of file (presumably what the pty reports once the
;;    child side has gone away); when PROC is a process record, its io-error?
;;    field is set to #t.
;;  - anything else: raised via handle-read-error.
;; FNAME is only used in error reports.
;; NOTE(review): retrying EAGAIN busy-waits if FD is non-blocking.
(define (read/bv proc fd fname bv start count)
  (let retry ()
    (let ([ret (sys-read fd
                         (integer->pointer
                          (fx+ (pointer->integer (bytevector->pointer bv)) start))
                         count)])
      (if (fx>=? ret 0)
          ret
          ;; Snapshot errno once so nothing in between can clobber it.
          (let ([err errno])
            (cond ((or (eqv? err EAGAIN) (eqv? err EINTR))
                   ;; Tail call: the retried read's result is returned.
                   (retry))
                  ((eqv? err EIO)
                   (when proc
                     (process-io-error?-set! proc #t))
                   0)
                  (else
                   (handle-read-error fd fname err))))))))

;; Signal a failed write on FD (port name FNAME) with errno value ERR.
(define (handle-write-error fd fname err)
  (error 'pty-write "Error during writing to pty" fd fname err))

;; Write up to COUNT bytes from BV, starting at index START, to FD and return
;; the number of bytes actually written.
;;  - EINTR / EAGAIN: the write is retried, and the retried write's result is
;;    what gets returned.
;;  - EIO: reported as 0 bytes written; when PROC is a process record, its
;;    io-error? field is set to #t (matching read/bv).
;;  - anything else: raised via handle-write-error.
;; FNAME is only used in error reports.
;; NOTE(review): an R6RS write! that returns 0 for a non-zero COUNT may make
;; some port implementations retry forever -- confirm with the host Scheme.
;; NOTE(review): retrying EAGAIN busy-waits if FD is non-blocking.
(define (write/bv proc fd fname bv start count)
  (let retry ()
    (let ([ret (sys-write fd
                          (integer->pointer
                           (fx+ (pointer->integer (bytevector->pointer bv)) start))
                          count)])
      (if (fx>=? ret 0)
          ret
          ;; Snapshot errno once so nothing in between can clobber it.
          (let ([err errno])
            (cond ((or (eqv? err EAGAIN) (eqv? err EINTR))
                   ;; Tail call: the retried write's result is returned.
                   (retry))
                  ((eqv? err EIO)
                   (when proc
                     (process-io-error?-set! proc #t))
                   0)
                  (else
                   (handle-write-error fd fname err))))))))

;; Wrap the readable file descriptor FD in an R6RS custom binary input port
;; named FILENAME.  Reads go through read/bv with no process record attached;
;; the port has no position support, and closing it closes FD.
(define (make-binary-input-port-from-fd fd filename)
  (make-custom-binary-input-port
   filename
   (lambda (bv start count)                     ; read!
     (assert (fx<=? (fx+ start count) (bytevector-length bv)))
     (read/bv #f fd filename bv start count))
   #f                                           ; get-position: unsupported
   #f                                           ; set-position!: unsupported
   (lambda ()                                   ; close
     (sys-close fd))))

;; Wrap the writable file descriptor FD in an R6RS custom binary output port
;; named FILENAME.  Writes go through write/bv with no process record
;; attached; the port has no position support, and closing it closes FD.
(define (make-binary-output-port-from-fd fd filename)
  (make-custom-binary-output-port
   filename
   (lambda (bv start count)                     ; write!
     (assert (fx<=? (fx+ start count) (bytevector-length bv)))
     (write/bv #f fd filename bv start count))
   #f                                           ; get-position: unsupported
   #f                                           ; set-position!: unsupported
   (lambda ()                                   ; close
     (sys-close fd))))

;; Build the bidirectional binary port over PROC's pty fd.  The port is named
;; "pty-<pid>" and routes reads and writes through read/bv and write/bv,
;; passing PROC along.  It has no position support, and closing it closes
;; the fd.
(define (make-process-input/output-port proc)
  (define fd (process-pty-fd proc))
  (define port-name
    (string-append "pty-" (number->string (process-id proc))))
  (make-custom-binary-input/output-port
   port-name
   (lambda (bv start count)                     ; read!
     (assert (fx<=? (fx+ start count) (bytevector-length bv)))
     (read/bv proc fd port-name bv start count))
   (lambda (bv start count)                     ; write!
     (assert (fx<=? (fx+ start count) (bytevector-length bv)))
     (write/bv proc fd port-name bv start count))
   #f                                           ; get-position: unsupported
   #f                                           ; set-position!: unsupported
   (lambda ()                                   ; close
     (sys-close fd))))

;; Copy STR into a fresh zero-filled bytevector as UTF-8 plus a trailing NUL
;; byte, and return a pointer to it.  #f maps to the null pointer.
;; NOTE(review): a string containing U+0000 is silently truncated by C code.
;; NOTE(review): only the pointer escapes, so nothing here keeps the backing
;; bytevector reachable -- confirm pffi's bytevector->pointer guarantees it is
;; neither collected nor moved before the C call reads it.
(define (string->cstr str)
  (cond
    ((not str)
     (integer->pointer 0))
    (else
     (let* ([encoded (string->utf8 str)]
            [n (bytevector-length encoded)]
            [buf (make-bytevector (fx+ n 1) 0)])
       (bytevector-copy! encoded 0 buf 0 n)
       (bytevector->pointer buf)))))

;; Build a C `char *[]` from LS: one pointer per element (converted with
;; string->cstr, so #f elements become null pointers), followed by a
;; terminating null pointer.  Returns a pointer to the array.
;; NOTE(review): a #f element puts a null pointer mid-array, which C code
;; that scans to the first NULL will treat as the end of the array.
(define (string-list->null-terminated-cstr-array ls)
  (let* ([n (length ls)]
         ;; Zero-filled, so the trailing slot is already the NULL terminator.
         [arr-ptr (bytevector->pointer
                   (make-bytevector (fx* (fx+ n 1) size-of-pointer) 0))])
    ;; Walk the list once; indexing with list-ref per slot (and recomputing
    ;; the length each iteration) made this quadratic.
    (let loop ([rest ls] [offs 0])
      (unless (null? rest)
        (pointer-set-c-pointer! arr-ptr offs (string->cstr (car rest)))
        (loop (cdr rest) (fx+ offs size-of-pointer))))
    arr-ptr))

;; Flatten a list of environment entries into a flat list of alternating
;; keys and values.  Each entry is a list (KEY VALUE ...) or (KEY); an entry
;; without a value contributes #f in the value position.  Elements past the
;; second are ignored.
;;   (("A" "1") ("B"))  =>  ("A" "1" "B" #f)
(define (list/kv->list envs)
  (let loop ([envs envs] [acc '()])
    (if (null? envs)
        (reverse acc)
        (let* ([kv (car envs)]
               [k (car kv)]
               [v (if (fx>? (length kv) 1) (cadr kv) #f)])
          ;; Accumulate in reverse (O(1) per entry) instead of appending to
          ;; the end of the result, which was quadratic.
          (loop (cdr envs) (cons v (cons k acc)))))))

;; A child process running on a pseudo-terminal, as created by
;; open-pty-process (which passes #f for exit-code, exit-signal, pty-port and
;; io-error? at construction).  Fields:
;;   id          -- pid returned by pty_open
;;   exit-code   -- set by process-wait
;;   exit-signal -- set by process-wait
;;   pty-fd      -- fd that pty_open stored for the pty (presumably the master side)
;;   pty-port    -- binary input/output port over pty-fd, installed right after construction
;;   stdin-port  -- output port for the redirected stdin fd, or #f when not redirected
;;   stdout-port -- input port for the redirected stdout fd, or #f when not redirected
;;   args, env   -- the argument and environment lists given to open-pty-process
;;   io-error?   -- set to #t when pty I/O fails with EIO (see read/bv)
(define-record-type process
  (fields id
          (mutable exit-code)
          (mutable exit-signal)
          pty-fd
          (mutable pty-port)
          stdin-port
          stdout-port
          args
          env
          (mutable io-error?)))

;; Spawn PROG attached to a new pseudo-terminal and return a process record.
;;   (open-pty-process prog)
;;   (open-pty-process prog args)
;;   (open-pty-process prog args envs)
;;   (open-pty-process prog args envs cwd)
;;   (open-pty-process prog args envs cwd redir?)
;;   (open-pty-process prog args envs cwd redir? rows cols)
;; ARGS   -- list of argument strings; PROG is prepended as argv[0].
;; ENVS   -- list of (KEY VALUE) / (KEY) entries, flattened by list/kv->list.
;; CWD    -- working-directory string; #f passes a null pointer.
;; REDIR? -- when true, pty_open also returns separate stdin/stdout fds,
;;           which are wrapped in process-stdin-port / process-stdout-port.
;; ROWS, COLS -- initial terminal size passed to pty_open (default 25x80).
;; Raises an error when pty_open does not return a positive pid; a record
;; with such a pid would make process-kill signal an unrelated process group.
;; NOTE(review): the C buffers are reachable only through pointers; confirm
;; pffi keeps bytevectors alive and unmoved across the pty_open call.
(define open-pty-process
  (case-lambda
    [(prog)
     (open-pty-process prog #f #f #f #f)]

    [(prog args)
     (open-pty-process prog args #f #f #f)]

    [(prog args envs)
     (open-pty-process prog args envs #f #f)]

    [(prog args envs cwd)
     (open-pty-process prog args envs cwd #f)]

    [(prog args envs cwd redir?)
     (open-pty-process prog args envs cwd redir? 25 80)]

    [(prog args envs cwd redir? rows cols)
     ;; Out-parameter cells are pre-filled with #xFF bytes so a cell that
     ;; pty_open leaves untouched reads back as fd -1 rather than garbage.
     (let* ([make-fd-cell
             (lambda ()
               (bytevector->pointer (make-bytevector size-of-pointer #xFF)))]
            [pty-ptr (make-fd-cell)]
            [to-ptr (if redir? (make-fd-cell) (integer->pointer 0))]
            [from-ptr (if redir? (make-fd-cell) (integer->pointer 0))]
            [prog-ptr (string->cstr prog)]
            [cwd-ptr (string->cstr cwd)]
            [args-ptr (string-list->null-terminated-cstr-array
                       (cons prog (or args '())))]
            [envs-ptr (string-list->null-terminated-cstr-array
                       (list/kv->list (or envs '())))]
            [pid (pty_open pty-ptr prog-ptr rows cols
                           args-ptr cwd-ptr envs-ptr to-ptr from-ptr)])
       (unless (fx>? pid 0)
         (error 'open-pty-process "failed to start process" prog pid))
       (let* ([pty-fd (pointer-ref-c-int pty-ptr 0)]
              [stdin-port
               (and redir?
                    (make-binary-output-port-from-fd
                     (pointer-ref-c-int to-ptr 0)
                     (string-append "stdin-" (number->string pid))))]
              [stdout-port
               (and redir?
                    (make-binary-input-port-from-fd
                     (pointer-ref-c-int from-ptr 0)
                     (string-append "stdout-" (number->string pid))))]
              [proc (make-process pid #f #f pty-fd #f
                                  stdin-port stdout-port args envs #f)])
         (process-pty-port-set! proc (make-process-input/output-port proc))
         proc))]))

;; True while no EIO has been recorded for P's pty and a /proc/<pid> entry
;; exists; #f otherwise.
;; NOTE(review): /proc/<pid> also exists for an unreaped zombie and can name
;; an unrelated process after pid reuse; this check is Linux-specific.
(define (process-alive? p)
  (if (process-io-error? p)
      #f
      (let ([proc-dir (string-append "/proc/" (number->string (process-id p)))])
        (file-exists? proc-dir))))

;; Send signal 9 (SIGKILL) via kill(2) to the negated pid of P, i.e. to the
;; process group whose id equals P's pid.  Returns kill(2)'s result.
;; NOTE(review): this reaches the child only if it leads its own process
;; group -- presumably pty_open makes it a session leader; confirm.
(define (process-kill p)
  (let ([sigkill 9]
        [group (fx- 0 (process-id p))])
    (sys-kill group sigkill)))

;; Close P's pty port and, when present, its redirected stdin/stdout ports
;; (each port's close procedure also closes its file descriptor).  Ports that
;; are #f, as stdin/stdout are for a non-redirected process, are skipped.
;; R6RS close-port is a no-op on an already-closed port, so calling this
;; twice is harmless.
(define (process-close p)
  (for-each (lambda (port)
              (when port
                (close-port port)))
            (list (process-pty-port p)
                  (process-stdin-port p)
                  (process-stdout-port p))))

;; Ask libpty to resize P's terminal to ROWS x COLS via pty_resize.
;; Returns pty_resize's integer result.
(define (process-resize p rows cols)
  (let ([pid (process-id p)]
        [fd (process-pty-fd p)])
    (pty_resize pid fd rows cols)))

;; Initial getcwd buffer size in bytes (excluding the NUL terminator).
(define MAX-PATH 512)

;; Return the process's current working directory as a string.
;; getcwd(3) returns NULL when the buffer is too small (errno ERANGE) or on
;; other failure.  On ERANGE the buffer is doubled and the call retried, up to
;; a 1 MiB limit; any other failure raises an error carrying errno instead of
;; dereferencing the null pointer.
(define (current-directory)
  (define ERANGE 34)                    ; Linux errno value
  (define size-limit 1048576)
  (let loop ([size (fx+ MAX-PATH 1)])
    (let* ([buf (make-bytevector size 0)]
           [ret (getcwd (bytevector->pointer buf) size)])
      (if (not (zero? (pointer->integer ret)))
          (pointer->string ret)
          ;; Snapshot errno immediately after the failing call.
          (let ([err errno])
            (if (and (eqv? err ERANGE) (fx<? size size-limit))
                (loop (fx* size 2))
                (error 'current-directory "getcwd failed" err)))))))

;; Wait for P via libpty's process_wait, with TIMEOUT (default 30000;
;; presumably milliseconds -- confirm against libpty).
;; The integer result is decoded as (KIND << 8) | STATUS, where STATUS is the
;; low byte:
;;   KIND 0 -- STATUS is stored as P's exit code; returns #t.
;;   KIND 1 -- STATUS is stored as P's exit signal; returns #t.
;;   KIND 3 -- returns #t without recording anything.
;;   anything else (including negative results) -- returns #f.
;; STATUS must come from the low byte of the raw result: masking the
;; already-shifted value would just reproduce KIND.
;; NOTE(review): bit layout inferred from the KIND tests; confirm against
;; libpty's process_wait.
(define process-wait
  (case-lambda
    [(p)
     (process-wait p 30000)]

    [(p timeout)
     (let* ([ret (process_wait (process-id p) timeout)]
            [kind (fxarithmetic-shift-right ret 8)]
            [status (fxand ret #xFF)])
       (case kind
         [(0) (process-exit-code-set! p status) #t]
         [(1) (process-exit-signal-set! p status) #t]
         [(3) #t]
         [else #f]))]))

)
