#include <linux/module.h>
#include <linux/init.h>
#include <linux/in.h>
#include <net/sock.h>

/* TCP port the server binds to; converted with htons() in create_socket(). */
static const unsigned short server_port = 5555;
/* Listening socket, filled in by create_socket() and released via sock_release(). */
static struct socket *serversocket = NULL;
/* Completed by echo_server() just before it returns; server_exit() waits on it. */
static DECLARE_COMPLETION(threadcomplete);
/* Return value of kernel_thread() for echo_server(); negative means the thread was not created. */
static int echo_server_pid;

static int create_socket(void)
{
    struct sockaddr_in server;
    int servererror;

    if (sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &serversocket) < 0) {
        printk(KERN_ERR "server: Error creating serversocket.\n");
        return -EIO;
    }
    server.sin_family      = AF_INET;
    server.sin_addr.s_addr = INADDR_ANY;
    server.sin_port        = htons((unsigned short)server_port);
    servererror            = kernel_bind(serversocket,
        (struct sockaddr *) &server, sizeof(server));
    if (servererror)
        goto release;
    servererror = kernel_listen(serversocket, 3);
    if (servererror)
        goto release;
    return 0;

release:
    sock_release(serversocket);
    printk(KERN_ERR "server: Error serversocket\n");
    return -EIO;
}

/*
 * socket_accept() - accept one connection on a listening socket.
 * @server: listening socket; may be NULL.
 *
 * Calls kernel_accept() with a flags argument of 0.
 *
 * Return: the accepted client socket, or NULL if @server is NULL or
 * kernel_accept() returned a negative error.  The caller owns the returned
 * socket and is expected to release it with sock_release().
 */
static struct socket *socket_accept(struct socket *server)
{
    struct socket *clientsocket = NULL;
    int error;

    if (!server)
        return NULL;

    error = kernel_accept(server, &clientsocket, 0);
    if (error < 0)
        return NULL;

    return clientsocket;
}

static int server_send(struct socket *sock, unsigned char *buf, int len)
{
    struct msghdr msg;
    struct kvec iov;

    if (!sock->sk) 
        return 0;

    iov.iov_base = buf;
    iov.iov_len  = len;

    msg.msg_control    = NULL;
    msg.msg_controllen = 0;
    msg.msg_flags      = 0;

    len = kernel_sendmsg(sock, &msg, &iov, 1, len);
	
    return len;
}

static int server_receive(struct socket *sptr, unsigned char *buf, int len)
{
    struct msghdr msg;
    struct kvec iov;

    if (!sptr->sk) 
        return 0;
    iov.iov_base = buf;
    iov.iov_len  = len;
    msg.msg_control    = NULL;
    msg.msg_controllen = 0;

    len = kernel_recvmsg(sptr, &msg, &iov, 1, len, 0);

    return len;
}

/*
 * echo_server() - kernel thread body: accept clients and echo their data.
 * @data: unused.
 *
 * Until a signal is pending for the current task, accepts one client at a
 * time on 'serversocket' and sends every received chunk straight back.  A
 * client is dropped (its socket released) when receiving yields no data or
 * an error, or when sending the echo fails.  SIGTERM is enabled for the
 * thread so that server_exit() can stop it.
 *
 * Return: always 0, after completing 'threadcomplete'.
 */
static int echo_server(void *data)
{
    struct socket *clientsocket;
    unsigned char buffer[100];
    int len;

    daemonize("echo");
    allow_signal(SIGTERM);

    while (!signal_pending(current)) {
        clientsocket = socket_accept(serversocket);
        printk("clientsocket(%p)\n", clientsocket);

        while (clientsocket) {
            len = server_receive(clientsocket, buffer, sizeof(buffer));

            /* Keep the connection only while both directions work. */
            if (len > 0 && server_send(clientsocket, buffer, len) >= 0)
                continue;

            sock_release(clientsocket);
            clientsocket = NULL;
        }
    }

    /*
     * NOTE(review): the thread still executes module code (the return
     * below) after waking server_exit(); confirm whether this should use
     * complete_and_exit() to close that window on module unload.
     */
    complete(&threadcomplete);
    return 0;
}

/*
 * server_init() - module entry point.
 *
 * Sets up the listening socket through create_socket() and then starts
 * echo_server() as a kernel thread, recording kernel_thread()'s result in
 * 'echo_server_pid'.
 *
 * Return: 0 on success; -EIO if the socket could not be set up or the
 * thread could not be created (the socket is released in the latter case).
 */
static int __init server_init(void)
{
    if (create_socket() < 0)
        return -EIO;

    echo_server_pid = kernel_thread(echo_server, NULL, CLONE_KERNEL);
    printk("echo_server_pid: %d\n", echo_server_pid);

    /* Non-negative result: the thread is running, nothing left to do. */
    if (echo_server_pid >= 0)
        return 0;

    printk(KERN_ERR "server: Error creating echo_server\n");
    sock_release(serversocket);
    return -EIO;
}

/*
 * server_exit() - module exit point.
 *
 * If a thread PID was recorded, sends SIGTERM to it and waits until
 * echo_server() completes 'threadcomplete'.  Afterwards releases the
 * listening socket if one is set.
 */
static void __exit server_exit(void)
{
    printk("server_exit()\n");

    if (echo_server_pid != 0) {
        /*
         * NOTE(review): find_pid_ns() is called here without
         * rcu_read_lock() or tasklist_lock held - confirm whether the
         * lookup needs that protection.
         */
        kill_pid(find_pid_ns(echo_server_pid, &init_pid_ns), SIGTERM, 1);
        wait_for_completion(&threadcomplete);
    }

    if (serversocket == NULL)
        return;

    sock_release(serversocket);
}

module_init(server_init);
module_exit(server_exit);
MODULE_LICENSE("GPL");