OpenCores
URL https://opencores.org/ocsvn/or1k_old/or1k_old/trunk

Subversion Repositories or1k_old

[/] [or1k_old/] [trunk/] [rc203soc/] [sw/] [uClinux/] [fs/] [smbfs/] [sock.c] - Diff between revs 1628 and 1765

Go to most recent revision | Only display areas with differences | Details | Blame | View Log

Rev 1628 Rev 1765
/*
/*
 *  sock.c
 *  sock.c
 *
 *
 *  Copyright (C) 1995, 1996 by Paal-Kr. Engstad and Volker Lendecke
 *  Copyright (C) 1995, 1996 by Paal-Kr. Engstad and Volker Lendecke
 *
 *
 */
 */
 
 
#include <linux/sched.h>
#include <linux/sched.h>
#include <linux/smb_fs.h>
#include <linux/smb_fs.h>
#include <linux/errno.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/socket.h>
#include <linux/fcntl.h>
#include <linux/fcntl.h>
#include <linux/stat.h>
#include <linux/stat.h>
#include <asm/segment.h>
#include <asm/segment.h>
#include <linux/in.h>
#include <linux/in.h>
#include <linux/net.h>
#include <linux/net.h>
#include <linux/mm.h>
#include <linux/mm.h>
#include <linux/netdevice.h>
#include <linux/netdevice.h>
#include <net/ip.h>
#include <net/ip.h>
 
 
#include <linux/smb.h>
#include <linux/smb.h>
#include <linux/smbno.h>
#include <linux/smbno.h>
 
 
 
 
#define _S(nr) (1<<((nr)-1))
#define _S(nr) (1<<((nr)-1))
 
 
static int
static int
_recvfrom(struct socket *sock, unsigned char *ubuf, int size,
_recvfrom(struct socket *sock, unsigned char *ubuf, int size,
          int noblock, unsigned flags, struct sockaddr_in *sa, int *addr_len)
          int noblock, unsigned flags, struct sockaddr_in *sa, int *addr_len)
{
{
        struct iovec iov;
        struct iovec iov;
        struct msghdr msg;
        struct msghdr msg;
 
 
        iov.iov_base = ubuf;
        iov.iov_base = ubuf;
        iov.iov_len = size;
        iov.iov_len = size;
 
 
        msg.msg_name = (void *) sa;
        msg.msg_name = (void *) sa;
        msg.msg_namelen = 0;
        msg.msg_namelen = 0;
        if (addr_len)
        if (addr_len)
                msg.msg_namelen = *addr_len;
                msg.msg_namelen = *addr_len;
        msg.msg_control = NULL;
        msg.msg_control = NULL;
        msg.msg_iov = &iov;
        msg.msg_iov = &iov;
        msg.msg_iovlen = 1;
        msg.msg_iovlen = 1;
 
 
        return sock->ops->recvmsg(sock, &msg, size, noblock, flags, addr_len);
        return sock->ops->recvmsg(sock, &msg, size, noblock, flags, addr_len);
}
}
 
 
static int
static int
_send(struct socket *sock, const void *buff, int len,
_send(struct socket *sock, const void *buff, int len,
      int nonblock, unsigned flags)
      int nonblock, unsigned flags)
{
{
        struct iovec iov;
        struct iovec iov;
        struct msghdr msg;
        struct msghdr msg;
 
 
        iov.iov_base = (void *) buff;
        iov.iov_base = (void *) buff;
        iov.iov_len = len;
        iov.iov_len = len;
 
 
        msg.msg_name = NULL;
        msg.msg_name = NULL;
        msg.msg_namelen = 0;
        msg.msg_namelen = 0;
        msg.msg_control = NULL;
        msg.msg_control = NULL;
        msg.msg_iov = &iov;
        msg.msg_iov = &iov;
        msg.msg_iovlen = 1;
        msg.msg_iovlen = 1;
 
 
        return sock->ops->sendmsg(sock, &msg, len, nonblock, flags);
        return sock->ops->sendmsg(sock, &msg, len, nonblock, flags);
}
}
 
 
static void
static void
smb_data_callback(struct sock *sk, int len)
smb_data_callback(struct sock *sk, int len)
{
{
        struct socket *sock = sk->socket;
        struct socket *sock = sk->socket;
 
 
        if (!sk->dead)
        if (!sk->dead)
        {
        {
                unsigned char peek_buf[4];
                unsigned char peek_buf[4];
                int result;
                int result;
                unsigned short fs;
                unsigned short fs;
 
 
                fs = get_fs();
                fs = get_fs();
                set_fs(get_ds());
                set_fs(get_ds());
 
 
                result = _recvfrom(sock, (void *) peek_buf, 1, 1,
                result = _recvfrom(sock, (void *) peek_buf, 1, 1,
                                   MSG_PEEK, NULL, NULL);
                                   MSG_PEEK, NULL, NULL);
 
 
                while ((result != -EAGAIN) && (peek_buf[0] == 0x85))
                while ((result != -EAGAIN) && (peek_buf[0] == 0x85))
                {
                {
                        /* got SESSION KEEP ALIVE */
                        /* got SESSION KEEP ALIVE */
                        result = _recvfrom(sock, (void *) peek_buf,
                        result = _recvfrom(sock, (void *) peek_buf,
                                           4, 1, 0, NULL, NULL);
                                           4, 1, 0, NULL, NULL);
 
 
                        DDPRINTK("smb_data_callback:"
                        DDPRINTK("smb_data_callback:"
                                 " got SESSION KEEP ALIVE\n");
                                 " got SESSION KEEP ALIVE\n");
 
 
                        if (result == -EAGAIN)
                        if (result == -EAGAIN)
                        {
                        {
                                break;
                                break;
                        }
                        }
                        result = _recvfrom(sock, (void *) peek_buf,
                        result = _recvfrom(sock, (void *) peek_buf,
                                           1, 1, MSG_PEEK,
                                           1, 1, MSG_PEEK,
                                           NULL, NULL);
                                           NULL, NULL);
                }
                }
                set_fs(fs);
                set_fs(fs);
 
 
                if (result != -EAGAIN)
                if (result != -EAGAIN)
                {
                {
                        wake_up_interruptible(sk->sleep);
                        wake_up_interruptible(sk->sleep);
                }
                }
        }
        }
}
}
 
 
int
int
smb_catch_keepalive(struct smb_server *server)
smb_catch_keepalive(struct smb_server *server)
{
{
        struct file *file;
        struct file *file;
        struct inode *inode;
        struct inode *inode;
        struct socket *sock;
        struct socket *sock;
        struct sock *sk;
        struct sock *sk;
 
 
        if ((server == NULL)
        if ((server == NULL)
            || ((file = server->sock_file) == NULL)
            || ((file = server->sock_file) == NULL)
            || ((inode = file->f_inode) == NULL)
            || ((inode = file->f_inode) == NULL)
            || (!S_ISSOCK(inode->i_mode)))
            || (!S_ISSOCK(inode->i_mode)))
        {
        {
                printk("smb_catch_keepalive: did not get valid server!\n");
                printk("smb_catch_keepalive: did not get valid server!\n");
                server->data_ready = NULL;
                server->data_ready = NULL;
                return -EINVAL;
                return -EINVAL;
        }
        }
        sock = &(inode->u.socket_i);
        sock = &(inode->u.socket_i);
 
 
        if (sock->type != SOCK_STREAM)
        if (sock->type != SOCK_STREAM)
        {
        {
                printk("smb_catch_keepalive: did not get SOCK_STREAM\n");
                printk("smb_catch_keepalive: did not get SOCK_STREAM\n");
                server->data_ready = NULL;
                server->data_ready = NULL;
                return -EINVAL;
                return -EINVAL;
        }
        }
        sk = (struct sock *) (sock->data);
        sk = (struct sock *) (sock->data);
 
 
        if (sk == NULL)
        if (sk == NULL)
        {
        {
                printk("smb_catch_keepalive: sk == NULL");
                printk("smb_catch_keepalive: sk == NULL");
                server->data_ready = NULL;
                server->data_ready = NULL;
                return -EINVAL;
                return -EINVAL;
        }
        }
        DDPRINTK("smb_catch_keepalive.: sk->d_r = %x, server->d_r = %x\n",
        DDPRINTK("smb_catch_keepalive.: sk->d_r = %x, server->d_r = %x\n",
                 (unsigned int) (sk->data_ready),
                 (unsigned int) (sk->data_ready),
                 (unsigned int) (server->data_ready));
                 (unsigned int) (server->data_ready));
 
 
        if (sk->data_ready == smb_data_callback)
        if (sk->data_ready == smb_data_callback)
        {
        {
                printk("smb_catch_keepalive: already done\n");
                printk("smb_catch_keepalive: already done\n");
                return -EINVAL;
                return -EINVAL;
        }
        }
        server->data_ready = sk->data_ready;
        server->data_ready = sk->data_ready;
        sk->data_ready = smb_data_callback;
        sk->data_ready = smb_data_callback;
        return 0;
        return 0;
}
}
 
 
int
int
smb_dont_catch_keepalive(struct smb_server *server)
smb_dont_catch_keepalive(struct smb_server *server)
{
{
        struct file *file;
        struct file *file;
        struct inode *inode;
        struct inode *inode;
        struct socket *sock;
        struct socket *sock;
        struct sock *sk;
        struct sock *sk;
 
 
        if ((server == NULL)
        if ((server == NULL)
            || ((file = server->sock_file) == NULL)
            || ((file = server->sock_file) == NULL)
            || ((inode = file->f_inode) == NULL)
            || ((inode = file->f_inode) == NULL)
            || (!S_ISSOCK(inode->i_mode)))
            || (!S_ISSOCK(inode->i_mode)))
        {
        {
                printk("smb_dont_catch_keepalive: "
                printk("smb_dont_catch_keepalive: "
                       "did not get valid server!\n");
                       "did not get valid server!\n");
                return -EINVAL;
                return -EINVAL;
        }
        }
        sock = &(inode->u.socket_i);
        sock = &(inode->u.socket_i);
 
 
        if (sock->type != SOCK_STREAM)
        if (sock->type != SOCK_STREAM)
        {
        {
                printk("smb_dont_catch_keepalive: did not get SOCK_STREAM\n");
                printk("smb_dont_catch_keepalive: did not get SOCK_STREAM\n");
                return -EINVAL;
                return -EINVAL;
        }
        }
        sk = (struct sock *) (sock->data);
        sk = (struct sock *) (sock->data);
 
 
        if (sk == NULL)
        if (sk == NULL)
        {
        {
                printk("smb_dont_catch_keepalive: sk == NULL");
                printk("smb_dont_catch_keepalive: sk == NULL");
                return -EINVAL;
                return -EINVAL;
        }
        }
        if (server->data_ready == NULL)
        if (server->data_ready == NULL)
        {
        {
                printk("smb_dont_catch_keepalive: "
                printk("smb_dont_catch_keepalive: "
                       "server->data_ready == NULL\n");
                       "server->data_ready == NULL\n");
                return -EINVAL;
                return -EINVAL;
        }
        }
        if (sk->data_ready != smb_data_callback)
        if (sk->data_ready != smb_data_callback)
        {
        {
                printk("smb_dont_catch_keepalive: "
                printk("smb_dont_catch_keepalive: "
                       "sk->data_callback != smb_data_callback\n");
                       "sk->data_callback != smb_data_callback\n");
                return -EINVAL;
                return -EINVAL;
        }
        }
        DDPRINTK("smb_dont_catch_keepalive: sk->d_r = %x, server->d_r = %x\n",
        DDPRINTK("smb_dont_catch_keepalive: sk->d_r = %x, server->d_r = %x\n",
                 (unsigned int) (sk->data_ready),
                 (unsigned int) (sk->data_ready),
                 (unsigned int) (server->data_ready));
                 (unsigned int) (server->data_ready));
 
 
        sk->data_ready = server->data_ready;
        sk->data_ready = server->data_ready;
        server->data_ready = NULL;
        server->data_ready = NULL;
        return 0;
        return 0;
}
}
 
 
static int
static int
smb_send_raw(struct socket *sock, unsigned char *source, int length)
smb_send_raw(struct socket *sock, unsigned char *source, int length)
{
{
        int result;
        int result;
        int already_sent = 0;
        int already_sent = 0;
 
 
        while (already_sent < length)
        while (already_sent < length)
        {
        {
                result = _send(sock,
                result = _send(sock,
                               (void *) (source + already_sent),
                               (void *) (source + already_sent),
                               length - already_sent, 0, 0);
                               length - already_sent, 0, 0);
 
 
                if (result < 0)
                if (result < 0)
                {
                {
                        DPRINTK("smb_send_raw: sendto error = %d\n",
                        DPRINTK("smb_send_raw: sendto error = %d\n",
                                -result);
                                -result);
                        return result;
                        return result;
                }
                }
                already_sent += result;
                already_sent += result;
        }
        }
        return already_sent;
        return already_sent;
}
}
 
 
static int
static int
smb_receive_raw(struct socket *sock, unsigned char *target, int length)
smb_receive_raw(struct socket *sock, unsigned char *target, int length)
{
{
        int result;
        int result;
        int already_read = 0;
        int already_read = 0;
 
 
        while (already_read < length)
        while (already_read < length)
        {
        {
                result = _recvfrom(sock,
                result = _recvfrom(sock,
                                   (void *) (target + already_read),
                                   (void *) (target + already_read),
                                   length - already_read, 0, 0,
                                   length - already_read, 0, 0,
                                   NULL, NULL);
                                   NULL, NULL);
 
 
                if (result == 0)
                if (result == 0)
                {
                {
                        return -EIO;
                        return -EIO;
                }
                }
                if (result < 0)
                if (result < 0)
                {
                {
                        DPRINTK("smb_receive_raw: recvfrom error = %d\n",
                        DPRINTK("smb_receive_raw: recvfrom error = %d\n",
                                -result);
                                -result);
                        return result;
                        return result;
                }
                }
                already_read += result;
                already_read += result;
        }
        }
        return already_read;
        return already_read;
}
}
 
 
static int
static int
smb_get_length(struct socket *sock, unsigned char *header)
smb_get_length(struct socket *sock, unsigned char *header)
{
{
        int result;
        int result;
        unsigned char peek_buf[4];
        unsigned char peek_buf[4];
        unsigned short fs;
        unsigned short fs;
 
 
      re_recv:
      re_recv:
        fs = get_fs();
        fs = get_fs();
        set_fs(get_ds());
        set_fs(get_ds());
        result = smb_receive_raw(sock, peek_buf, 4);
        result = smb_receive_raw(sock, peek_buf, 4);
        set_fs(fs);
        set_fs(fs);
 
 
        if (result < 0)
        if (result < 0)
        {
        {
                DPRINTK("smb_get_length: recv error = %d\n", -result);
                DPRINTK("smb_get_length: recv error = %d\n", -result);
                return result;
                return result;
        }
        }
        switch (peek_buf[0])
        switch (peek_buf[0])
        {
        {
        case 0x00:
        case 0x00:
        case 0x82:
        case 0x82:
                break;
                break;
 
 
        case 0x85:
        case 0x85:
                DPRINTK("smb_get_length: Got SESSION KEEP ALIVE\n");
                DPRINTK("smb_get_length: Got SESSION KEEP ALIVE\n");
                goto re_recv;
                goto re_recv;
 
 
        default:
        default:
                printk("smb_get_length: Invalid NBT packet\n");
                printk("smb_get_length: Invalid NBT packet\n");
                return -EIO;
                return -EIO;
        }
        }
 
 
        if (header != NULL)
        if (header != NULL)
        {
        {
                memcpy(header, peek_buf, 4);
                memcpy(header, peek_buf, 4);
        }
        }
        /* The length in the RFC NB header is the raw data length */
        /* The length in the RFC NB header is the raw data length */
        return smb_len(peek_buf);
        return smb_len(peek_buf);
}
}
 
 
static struct socket *
static struct socket *
server_sock(struct smb_server *server)
server_sock(struct smb_server *server)
{
{
        struct file *file;
        struct file *file;
        struct inode *inode;
        struct inode *inode;
 
 
        if (server == NULL)
        if (server == NULL)
                return NULL;
                return NULL;
        if ((file = server->sock_file) == NULL)
        if ((file = server->sock_file) == NULL)
                return NULL;
                return NULL;
        if ((inode = file->f_inode) == NULL)
        if ((inode = file->f_inode) == NULL)
                return NULL;
                return NULL;
        return &(inode->u.socket_i);
        return &(inode->u.socket_i);
}
}
 
 
/*
/*
 * smb_receive
 * smb_receive
 * fs points to the correct segment
 * fs points to the correct segment
 */
 */
static int
static int
smb_receive(struct smb_server *server)
smb_receive(struct smb_server *server)
{
{
        struct socket *sock = server_sock(server);
        struct socket *sock = server_sock(server);
        int len;
        int len;
        int result;
        int result;
        unsigned char peek_buf[4];
        unsigned char peek_buf[4];
 
 
        len = smb_get_length(sock, peek_buf);
        len = smb_get_length(sock, peek_buf);
 
 
        if (len < 0)
        if (len < 0)
        {
        {
                return len;
                return len;
        }
        }
        if (len + 4 > server->packet_size)
        if (len + 4 > server->packet_size)
        {
        {
                /* Some servers do not care about our max_xmit. They
                /* Some servers do not care about our max_xmit. They
                   send larger packets */
                   send larger packets */
                DPRINTK("smb_receive: Increase packet size from %d to %d\n",
                DPRINTK("smb_receive: Increase packet size from %d to %d\n",
                        server->packet_size, len + 4);
                        server->packet_size, len + 4);
                smb_vfree(server->packet);
                smb_vfree(server->packet);
                server->packet = NULL;
                server->packet = NULL;
 
 
                server->packet_size = 0;
                server->packet_size = 0;
                server->packet = smb_vmalloc(len + 4);
                server->packet = smb_vmalloc(len + 4);
                if (server->packet == NULL)
                if (server->packet == NULL)
                {
                {
                        return -ENOMEM;
                        return -ENOMEM;
                }
                }
                server->packet_size = len + 4;
                server->packet_size = len + 4;
        }
        }
        memcpy(server->packet, peek_buf, 4);
        memcpy(server->packet, peek_buf, 4);
        result = smb_receive_raw(sock, server->packet + 4, len);
        result = smb_receive_raw(sock, server->packet + 4, len);
 
 
        if (result < 0)
        if (result < 0)
        {
        {
                printk("smb_receive: receive error: %d\n", result);
                printk("smb_receive: receive error: %d\n", result);
                return result;
                return result;
        }
        }
        server->rcls = BVAL(server->packet, 9);
        server->rcls = BVAL(server->packet, 9);
        server->err = WVAL(server->packet, 11);
        server->err = WVAL(server->packet, 11);
 
 
        if (server->rcls != 0)
        if (server->rcls != 0)
        {
        {
                DPRINTK("smb_receive: rcls=%d, err=%d\n",
                DPRINTK("smb_receive: rcls=%d, err=%d\n",
                        server->rcls, server->err);
                        server->rcls, server->err);
        }
        }
        return result;
        return result;
}
}
 
 
static int
static int
smb_receive_trans2(struct smb_server *server,
smb_receive_trans2(struct smb_server *server,
                   int *ldata, unsigned char **data,
                   int *ldata, unsigned char **data,
                   int *lparam, unsigned char **param)
                   int *lparam, unsigned char **param)
{
{
        int total_data = 0;
        int total_data = 0;
        int total_param = 0;
        int total_param = 0;
        int result;
        int result;
        unsigned char *rcv_buf;
        unsigned char *rcv_buf;
        int buf_len;
        int buf_len;
        int data_len = 0;
        int data_len = 0;
        int param_len = 0;
        int param_len = 0;
 
 
        if ((result = smb_receive(server)) < 0)
        if ((result = smb_receive(server)) < 0)
        {
        {
                return result;
                return result;
        }
        }
        if (server->rcls != 0)
        if (server->rcls != 0)
        {
        {
                *param = *data = server->packet;
                *param = *data = server->packet;
                *ldata = *lparam = 0;
                *ldata = *lparam = 0;
                return 0;
                return 0;
        }
        }
        total_data = WVAL(server->packet, smb_tdrcnt);
        total_data = WVAL(server->packet, smb_tdrcnt);
        total_param = WVAL(server->packet, smb_tprcnt);
        total_param = WVAL(server->packet, smb_tprcnt);
 
 
        DDPRINTK("smb_receive_trans2: td=%d,tp=%d\n", total_data, total_param);
        DDPRINTK("smb_receive_trans2: td=%d,tp=%d\n", total_data, total_param);
 
 
        if ((total_data > TRANS2_MAX_TRANSFER)
        if ((total_data > TRANS2_MAX_TRANSFER)
            || (total_param > TRANS2_MAX_TRANSFER))
            || (total_param > TRANS2_MAX_TRANSFER))
        {
        {
                DPRINTK("smb_receive_trans2: data/param too long\n");
                DPRINTK("smb_receive_trans2: data/param too long\n");
                return -EIO;
                return -EIO;
        }
        }
        buf_len = total_data + total_param;
        buf_len = total_data + total_param;
        if (server->packet_size > buf_len)
        if (server->packet_size > buf_len)
        {
        {
                buf_len = server->packet_size;
                buf_len = server->packet_size;
        }
        }
        if ((rcv_buf = smb_vmalloc(buf_len)) == NULL)
        if ((rcv_buf = smb_vmalloc(buf_len)) == NULL)
        {
        {
                DPRINTK("smb_receive_trans2: could not alloc data area\n");
                DPRINTK("smb_receive_trans2: could not alloc data area\n");
                return -ENOMEM;
                return -ENOMEM;
        }
        }
        *param = rcv_buf;
        *param = rcv_buf;
        *data = rcv_buf + total_param;
        *data = rcv_buf + total_param;
 
 
        while (1)
        while (1)
        {
        {
                unsigned char *inbuf = server->packet;
                unsigned char *inbuf = server->packet;
 
 
                if (WVAL(inbuf, smb_prdisp) + WVAL(inbuf, smb_prcnt)
                if (WVAL(inbuf, smb_prdisp) + WVAL(inbuf, smb_prcnt)
                    > total_param)
                    > total_param)
                {
                {
                        DPRINTK("smb_receive_trans2: invalid parameters\n");
                        DPRINTK("smb_receive_trans2: invalid parameters\n");
                        result = -EIO;
                        result = -EIO;
                        goto fail;
                        goto fail;
                }
                }
                memcpy(*param + WVAL(inbuf, smb_prdisp),
                memcpy(*param + WVAL(inbuf, smb_prdisp),
                       smb_base(inbuf) + WVAL(inbuf, smb_proff),
                       smb_base(inbuf) + WVAL(inbuf, smb_proff),
                       WVAL(inbuf, smb_prcnt));
                       WVAL(inbuf, smb_prcnt));
                param_len += WVAL(inbuf, smb_prcnt);
                param_len += WVAL(inbuf, smb_prcnt);
 
 
                if (WVAL(inbuf, smb_drdisp) + WVAL(inbuf, smb_drcnt)
                if (WVAL(inbuf, smb_drdisp) + WVAL(inbuf, smb_drcnt)
                    > total_data)
                    > total_data)
                {
                {
                        DPRINTK("smb_receive_trans2: invalid data block\n");
                        DPRINTK("smb_receive_trans2: invalid data block\n");
                        result = -EIO;
                        result = -EIO;
                        goto fail;
                        goto fail;
                }
                }
                DDPRINTK("target: %X\n",
                DDPRINTK("target: %X\n",
                         (unsigned int) *data + WVAL(inbuf, smb_drdisp));
                         (unsigned int) *data + WVAL(inbuf, smb_drdisp));
                DDPRINTK("source: %X\n",
                DDPRINTK("source: %X\n",
                         (unsigned int)
                         (unsigned int)
                         smb_base(inbuf) + WVAL(inbuf, smb_droff));
                         smb_base(inbuf) + WVAL(inbuf, smb_droff));
                DDPRINTK("disp: %d, off: %d, cnt: %d\n",
                DDPRINTK("disp: %d, off: %d, cnt: %d\n",
                         WVAL(inbuf, smb_drdisp), WVAL(inbuf, smb_droff),
                         WVAL(inbuf, smb_drdisp), WVAL(inbuf, smb_droff),
                         WVAL(inbuf, smb_drcnt));
                         WVAL(inbuf, smb_drcnt));
 
 
                memcpy(*data + WVAL(inbuf, smb_drdisp),
                memcpy(*data + WVAL(inbuf, smb_drdisp),
                       smb_base(inbuf) + WVAL(inbuf, smb_droff),
                       smb_base(inbuf) + WVAL(inbuf, smb_droff),
                       WVAL(inbuf, smb_drcnt));
                       WVAL(inbuf, smb_drcnt));
                data_len += WVAL(inbuf, smb_drcnt);
                data_len += WVAL(inbuf, smb_drcnt);
 
 
                if ((WVAL(inbuf, smb_tdrcnt) > total_data)
                if ((WVAL(inbuf, smb_tdrcnt) > total_data)
                    || (WVAL(inbuf, smb_tprcnt) > total_param))
                    || (WVAL(inbuf, smb_tprcnt) > total_param))
                {
                {
                        printk("smb_receive_trans2: data/params grew!\n");
                        printk("smb_receive_trans2: data/params grew!\n");
                        result = -EIO;
                        result = -EIO;
                        goto fail;
                        goto fail;
                }
                }
                /* the total lengths might shrink! */
                /* the total lengths might shrink! */
                total_data = WVAL(inbuf, smb_tdrcnt);
                total_data = WVAL(inbuf, smb_tdrcnt);
                total_param = WVAL(inbuf, smb_tprcnt);
                total_param = WVAL(inbuf, smb_tprcnt);
 
 
                if ((data_len >= total_data) && (param_len >= total_param))
                if ((data_len >= total_data) && (param_len >= total_param))
                {
                {
                        break;
                        break;
                }
                }
                if ((result = smb_receive(server)) < 0)
                if ((result = smb_receive(server)) < 0)
                {
                {
                        goto fail;
                        goto fail;
                }
                }
                if (server->rcls != 0)
                if (server->rcls != 0)
                {
                {
                        result = -EIO;
                        result = -EIO;
                        goto fail;
                        goto fail;
                }
                }
        }
        }
        *ldata = data_len;
        *ldata = data_len;
        *lparam = param_len;
        *lparam = param_len;
 
 
        smb_vfree(server->packet);
        smb_vfree(server->packet);
        server->packet = rcv_buf;
        server->packet = rcv_buf;
        server->packet_size = buf_len;
        server->packet_size = buf_len;
        return 0;
        return 0;
 
 
      fail:
      fail:
        smb_vfree(rcv_buf);
        smb_vfree(rcv_buf);
        return result;
        return result;
}
}
 
 
int
int
smb_release(struct smb_server *server)
smb_release(struct smb_server *server)
{
{
        struct socket *sock = server_sock(server);
        struct socket *sock = server_sock(server);
        int result;
        int result;
 
 
        if (sock == NULL)
        if (sock == NULL)
        {
        {
                return -EINVAL;
                return -EINVAL;
        }
        }
        result = sock->ops->release(sock, NULL);
        result = sock->ops->release(sock, NULL);
        DPRINTK("smb_release: sock->ops->release = %d\n", result);
        DPRINTK("smb_release: sock->ops->release = %d\n", result);
 
 
        /* inet_release does not set sock->state.  Maybe someone is
        /* inet_release does not set sock->state.  Maybe someone is
           confused about sock->state being SS_CONNECTED while there
           confused about sock->state being SS_CONNECTED while there
           is nothing behind it, so I set it to SS_UNCONNECTED. */
           is nothing behind it, so I set it to SS_UNCONNECTED. */
        sock->state = SS_UNCONNECTED;
        sock->state = SS_UNCONNECTED;
 
 
        result = sock->ops->create(sock, 0);
        result = sock->ops->create(sock, 0);
        DPRINTK("smb_release: sock->ops->create = %d\n", result);
        DPRINTK("smb_release: sock->ops->create = %d\n", result);
        return result;
        return result;
}
}
 
 
int
int
smb_connect(struct smb_server *server)
smb_connect(struct smb_server *server)
{
{
        struct socket *sock = server_sock(server);
        struct socket *sock = server_sock(server);
        if (sock == NULL)
        if (sock == NULL)
        {
        {
                return -EINVAL;
                return -EINVAL;
        }
        }
        if (sock->state != SS_UNCONNECTED)
        if (sock->state != SS_UNCONNECTED)
        {
        {
                DPRINTK("smb_connect: socket is not unconnected: %d\n",
                DPRINTK("smb_connect: socket is not unconnected: %d\n",
                        sock->state);
                        sock->state);
        }
        }
        return sock->ops->connect(sock, (struct sockaddr *) &(server->m.addr),
        return sock->ops->connect(sock, (struct sockaddr *) &(server->m.addr),
                                  sizeof(struct sockaddr_in), 0);
                                  sizeof(struct sockaddr_in), 0);
}
}
 
 
int
int
smb_request(struct smb_server *server)
smb_request(struct smb_server *server)
{
{
        unsigned long old_mask;
        unsigned long old_mask;
        unsigned short fs;
        unsigned short fs;
        int len, result;
        int len, result;
 
 
        unsigned char *buffer = (server == NULL) ? NULL : server->packet;
        unsigned char *buffer = (server == NULL) ? NULL : server->packet;
 
 
        if (buffer == NULL)
        if (buffer == NULL)
        {
        {
                printk("smb_request: Bad server!\n");
                printk("smb_request: Bad server!\n");
                return -EBADF;
                return -EBADF;
        }
        }
        if (server->state != CONN_VALID)
        if (server->state != CONN_VALID)
        {
        {
                return -EIO;
                return -EIO;
        }
        }
        if ((result = smb_dont_catch_keepalive(server)) != 0)
        if ((result = smb_dont_catch_keepalive(server)) != 0)
        {
        {
                server->state = CONN_INVALID;
                server->state = CONN_INVALID;
                smb_invalidate_all_inodes(server);
                smb_invalidate_all_inodes(server);
                return result;
                return result;
        }
        }
        len = smb_len(buffer) + 4;
        len = smb_len(buffer) + 4;
 
 
        DPRINTK("smb_request: len = %d cmd = 0x%X\n", len, buffer[8]);
        DPRINTK("smb_request: len = %d cmd = 0x%X\n", len, buffer[8]);
 
 
        old_mask = current->blocked;
        old_mask = current->blocked;
        current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
        current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
        fs = get_fs();
        fs = get_fs();
        set_fs(get_ds());
        set_fs(get_ds());
 
 
        result = smb_send_raw(server_sock(server), (void *) buffer, len);
        result = smb_send_raw(server_sock(server), (void *) buffer, len);
        if (result > 0)
        if (result > 0)
        {
        {
                result = smb_receive(server);
                result = smb_receive(server);
        }
        }
        /* read/write errors are handled by errno */
        /* read/write errors are handled by errno */
        current->signal &= ~_S(SIGPIPE);
        current->signal &= ~_S(SIGPIPE);
        current->blocked = old_mask;
        current->blocked = old_mask;
        set_fs(fs);
        set_fs(fs);
 
 
        if (result >= 0)
        if (result >= 0)
        {
        {
                int result2 = smb_catch_keepalive(server);
                int result2 = smb_catch_keepalive(server);
                if (result2 < 0)
                if (result2 < 0)
                {
                {
                        result = result2;
                        result = result2;
                }
                }
        }
        }
        if (result < 0)
        if (result < 0)
        {
        {
                server->state = CONN_INVALID;
                server->state = CONN_INVALID;
                smb_invalidate_all_inodes(server);
                smb_invalidate_all_inodes(server);
        }
        }
        DDPRINTK("smb_request: result = %d\n", result);
        DDPRINTK("smb_request: result = %d\n", result);
 
 
        return result;
        return result;
}
}
 
 
#define ROUND_UP(x) (((x)+3) & ~3)
#define ROUND_UP(x) (((x)+3) & ~3)
static int
static int
smb_send_trans2(struct smb_server *server, __u16 trans2_command,
smb_send_trans2(struct smb_server *server, __u16 trans2_command,
                int ldata, unsigned char *data,
                int ldata, unsigned char *data,
                int lparam, unsigned char *param)
                int lparam, unsigned char *param)
{
{
        struct socket *sock = server_sock(server);
        struct socket *sock = server_sock(server);
 
 
        /* I know the following is very ugly, but I want to build the
        /* I know the following is very ugly, but I want to build the
           smb packet as efficiently as possible. */
           smb packet as efficiently as possible. */
 
 
        const int smb_parameters = 15;
        const int smb_parameters = 15;
        const int oparam =
        const int oparam =
        ROUND_UP(SMB_HEADER_LEN + 2 * smb_parameters + 2 + 3);
        ROUND_UP(SMB_HEADER_LEN + 2 * smb_parameters + 2 + 3);
        const int odata =
        const int odata =
        ROUND_UP(oparam + lparam);
        ROUND_UP(oparam + lparam);
        const int bcc =
        const int bcc =
        odata + ldata - (SMB_HEADER_LEN + 2 * smb_parameters + 2);
        odata + ldata - (SMB_HEADER_LEN + 2 * smb_parameters + 2);
        const int packet_length =
        const int packet_length =
        SMB_HEADER_LEN + 2 * smb_parameters + bcc + 2;
        SMB_HEADER_LEN + 2 * smb_parameters + bcc + 2;
 
 
        unsigned char padding[4] =
        unsigned char padding[4] =
        {0,};
        {0,};
        char *p;
        char *p;
 
 
        struct iovec iov[4];
        struct iovec iov[4];
        struct msghdr msg;
        struct msghdr msg;
 
 
        if ((bcc + oparam) > server->max_xmit)
        if ((bcc + oparam) > server->max_xmit)
        {
        {
                return -ENOMEM;
                return -ENOMEM;
        }
        }
        p = smb_setup_header(server, SMBtrans2, smb_parameters, bcc);
        p = smb_setup_header(server, SMBtrans2, smb_parameters, bcc);
 
 
        WSET(server->packet, smb_tpscnt, lparam);
        WSET(server->packet, smb_tpscnt, lparam);
        WSET(server->packet, smb_tdscnt, ldata);
        WSET(server->packet, smb_tdscnt, ldata);
        WSET(server->packet, smb_mprcnt, TRANS2_MAX_TRANSFER);
        WSET(server->packet, smb_mprcnt, TRANS2_MAX_TRANSFER);
        WSET(server->packet, smb_mdrcnt, TRANS2_MAX_TRANSFER);
        WSET(server->packet, smb_mdrcnt, TRANS2_MAX_TRANSFER);
        WSET(server->packet, smb_msrcnt, 0);
        WSET(server->packet, smb_msrcnt, 0);
        WSET(server->packet, smb_flags, 0);
        WSET(server->packet, smb_flags, 0);
        DSET(server->packet, smb_timeout, 0);
        DSET(server->packet, smb_timeout, 0);
        WSET(server->packet, smb_pscnt, lparam);
        WSET(server->packet, smb_pscnt, lparam);
        WSET(server->packet, smb_psoff, oparam - 4);
        WSET(server->packet, smb_psoff, oparam - 4);
        WSET(server->packet, smb_dscnt, ldata);
        WSET(server->packet, smb_dscnt, ldata);
        WSET(server->packet, smb_dsoff, odata - 4);
        WSET(server->packet, smb_dsoff, odata - 4);
        WSET(server->packet, smb_suwcnt, 1);
        WSET(server->packet, smb_suwcnt, 1);
        WSET(server->packet, smb_setup0, trans2_command);
        WSET(server->packet, smb_setup0, trans2_command);
        *p++ = 0;                /* null smb_name for trans2 */
        *p++ = 0;                /* null smb_name for trans2 */
        *p++ = 'D';             /* this was added because OS/2 does it */
        *p++ = 'D';             /* this was added because OS/2 does it */
        *p++ = ' ';
        *p++ = ' ';
 
 
        iov[0].iov_base = (void *) server->packet;
        iov[0].iov_base = (void *) server->packet;
        iov[0].iov_len = oparam;
        iov[0].iov_len = oparam;
        iov[1].iov_base = (param == NULL) ? padding : param;
        iov[1].iov_base = (param == NULL) ? padding : param;
        iov[1].iov_len = lparam;
        iov[1].iov_len = lparam;
        iov[2].iov_base = padding;
        iov[2].iov_base = padding;
        iov[2].iov_len = odata - oparam - lparam;
        iov[2].iov_len = odata - oparam - lparam;
        iov[3].iov_base = (data == NULL) ? padding : data;
        iov[3].iov_base = (data == NULL) ? padding : data;
        iov[3].iov_len = ldata;
        iov[3].iov_len = ldata;
 
 
        msg.msg_name = NULL;
        msg.msg_name = NULL;
        msg.msg_namelen = 0;
        msg.msg_namelen = 0;
        msg.msg_control = NULL;
        msg.msg_control = NULL;
        msg.msg_iov = iov;
        msg.msg_iov = iov;
        msg.msg_iovlen = 4;
        msg.msg_iovlen = 4;
 
 
        return sock->ops->sendmsg(sock, &msg, packet_length, 0, 0);
        return sock->ops->sendmsg(sock, &msg, packet_length, 0, 0);
}
}
 
 
/*
/*
 * This is not really a trans2 request, we assume that you only have
 * This is not really a trans2 request, we assume that you only have
 * one packet to send.
 * one packet to send.
 */
 */
int
int
smb_trans2_request(struct smb_server *server, __u16 trans2_command,
smb_trans2_request(struct smb_server *server, __u16 trans2_command,
                   int ldata, unsigned char *data,
                   int ldata, unsigned char *data,
                   int lparam, unsigned char *param,
                   int lparam, unsigned char *param,
                   int *lrdata, unsigned char **rdata,
                   int *lrdata, unsigned char **rdata,
                   int *lrparam, unsigned char **rparam)
                   int *lrparam, unsigned char **rparam)
{
{
        unsigned long old_mask;
        unsigned long old_mask;
        unsigned short fs;
        unsigned short fs;
        int result;
        int result;
 
 
        DPRINTK("smb_trans2_request: com=%d, ld=%d, lp=%d\n",
        DPRINTK("smb_trans2_request: com=%d, ld=%d, lp=%d\n",
                trans2_command, ldata, lparam);
                trans2_command, ldata, lparam);
 
 
        if (server->state != CONN_VALID)
        if (server->state != CONN_VALID)
        {
        {
                return -EIO;
                return -EIO;
        }
        }
        if ((result = smb_dont_catch_keepalive(server)) != 0)
        if ((result = smb_dont_catch_keepalive(server)) != 0)
        {
        {
                server->state = CONN_INVALID;
                server->state = CONN_INVALID;
                smb_invalidate_all_inodes(server);
                smb_invalidate_all_inodes(server);
                return result;
                return result;
        }
        }
        old_mask = current->blocked;
        old_mask = current->blocked;
        current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
        current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
        fs = get_fs();
        fs = get_fs();
        set_fs(get_ds());
        set_fs(get_ds());
 
 
        result = smb_send_trans2(server, trans2_command,
        result = smb_send_trans2(server, trans2_command,
                                 ldata, data, lparam, param);
                                 ldata, data, lparam, param);
        if (result >= 0)
        if (result >= 0)
        {
        {
                result = smb_receive_trans2(server,
                result = smb_receive_trans2(server,
                                            lrdata, rdata, lrparam, rparam);
                                            lrdata, rdata, lrparam, rparam);
        }
        }
        /* read/write errors are handled by errno */
        /* read/write errors are handled by errno */
        current->signal &= ~_S(SIGPIPE);
        current->signal &= ~_S(SIGPIPE);
        current->blocked = old_mask;
        current->blocked = old_mask;
        set_fs(fs);
        set_fs(fs);
 
 
        if (result >= 0)
        if (result >= 0)
        {
        {
                int result2 = smb_catch_keepalive(server);
                int result2 = smb_catch_keepalive(server);
                if (result2 < 0)
                if (result2 < 0)
                {
                {
                        result = result2;
                        result = result2;
                }
                }
        }
        }
        if (result < 0)
        if (result < 0)
        {
        {
                server->state = CONN_INVALID;
                server->state = CONN_INVALID;
                smb_invalidate_all_inodes(server);
                smb_invalidate_all_inodes(server);
        }
        }
        DDPRINTK("smb_trans2_request: result = %d\n", result);
        DDPRINTK("smb_trans2_request: result = %d\n", result);
 
 
        return result;
        return result;
}
}
 
 

powered by: WebSVN 2.1.0

© copyright 1999-2024 OpenCores.org, equivalent to Oliscience, all rights reserved. OpenCores®, registered trademark.