URL
https://opencores.org/ocsvn/or1k/or1k/trunk
Subversion Repositories or1k
[/] [or1k/] [trunk/] [rc203soc/] [sw/] [uClinux/] [fs/] [smbfs/] [sock.c] - Rev 1765
Compare with Previous | Blame | View Log
/* * sock.c * * Copyright (C) 1995, 1996 by Paal-Kr. Engstad and Volker Lendecke * */ #include <linux/sched.h> #include <linux/smb_fs.h> #include <linux/errno.h> #include <linux/socket.h> #include <linux/fcntl.h> #include <linux/stat.h> #include <asm/segment.h> #include <linux/in.h> #include <linux/net.h> #include <linux/mm.h> #include <linux/netdevice.h> #include <net/ip.h> #include <linux/smb.h> #include <linux/smbno.h> #define _S(nr) (1<<((nr)-1)) static int _recvfrom(struct socket *sock, unsigned char *ubuf, int size, int noblock, unsigned flags, struct sockaddr_in *sa, int *addr_len) { struct iovec iov; struct msghdr msg; iov.iov_base = ubuf; iov.iov_len = size; msg.msg_name = (void *) sa; msg.msg_namelen = 0; if (addr_len) msg.msg_namelen = *addr_len; msg.msg_control = NULL; msg.msg_iov = &iov; msg.msg_iovlen = 1; return sock->ops->recvmsg(sock, &msg, size, noblock, flags, addr_len); } static int _send(struct socket *sock, const void *buff, int len, int nonblock, unsigned flags) { struct iovec iov; struct msghdr msg; iov.iov_base = (void *) buff; iov.iov_len = len; msg.msg_name = NULL; msg.msg_namelen = 0; msg.msg_control = NULL; msg.msg_iov = &iov; msg.msg_iovlen = 1; return sock->ops->sendmsg(sock, &msg, len, nonblock, flags); } static void smb_data_callback(struct sock *sk, int len) { struct socket *sock = sk->socket; if (!sk->dead) { unsigned char peek_buf[4]; int result; unsigned short fs; fs = get_fs(); set_fs(get_ds()); result = _recvfrom(sock, (void *) peek_buf, 1, 1, MSG_PEEK, NULL, NULL); while ((result != -EAGAIN) && (peek_buf[0] == 0x85)) { /* got SESSION KEEP ALIVE */ result = _recvfrom(sock, (void *) peek_buf, 4, 1, 0, NULL, NULL); DDPRINTK("smb_data_callback:" " got SESSION KEEP ALIVE\n"); if (result == -EAGAIN) { break; } result = _recvfrom(sock, (void *) peek_buf, 1, 1, MSG_PEEK, NULL, NULL); } set_fs(fs); if (result != -EAGAIN) { wake_up_interruptible(sk->sleep); } } } int smb_catch_keepalive(struct smb_server *server) { struct file *file; struct inode *inode; struct socket *sock; struct sock *sk; if ((server == NULL) || ((file = server->sock_file) == NULL) || ((inode = file->f_inode) == NULL) || (!S_ISSOCK(inode->i_mode))) { printk("smb_catch_keepalive: did not get valid server!\n"); server->data_ready = NULL; return -EINVAL; } sock = &(inode->u.socket_i); if (sock->type != SOCK_STREAM) { printk("smb_catch_keepalive: did not get SOCK_STREAM\n"); server->data_ready = NULL; return -EINVAL; } sk = (struct sock *) (sock->data); if (sk == NULL) { printk("smb_catch_keepalive: sk == NULL"); server->data_ready = NULL; return -EINVAL; } DDPRINTK("smb_catch_keepalive.: sk->d_r = %x, server->d_r = %x\n", (unsigned int) (sk->data_ready), (unsigned int) (server->data_ready)); if (sk->data_ready == smb_data_callback) { printk("smb_catch_keepalive: already done\n"); return -EINVAL; } server->data_ready = sk->data_ready; sk->data_ready = smb_data_callback; return 0; } int smb_dont_catch_keepalive(struct smb_server *server) { struct file *file; struct inode *inode; struct socket *sock; struct sock *sk; if ((server == NULL) || ((file = server->sock_file) == NULL) || ((inode = file->f_inode) == NULL) || (!S_ISSOCK(inode->i_mode))) { printk("smb_dont_catch_keepalive: " "did not get valid server!\n"); return -EINVAL; } sock = &(inode->u.socket_i); if (sock->type != SOCK_STREAM) { printk("smb_dont_catch_keepalive: did not get SOCK_STREAM\n"); return -EINVAL; } sk = (struct sock *) (sock->data); if (sk == NULL) { printk("smb_dont_catch_keepalive: sk == NULL"); return -EINVAL; } if (server->data_ready == NULL) { printk("smb_dont_catch_keepalive: " "server->data_ready == NULL\n"); return -EINVAL; } if (sk->data_ready != smb_data_callback) { printk("smb_dont_catch_keepalive: " "sk->data_callback != smb_data_callback\n"); return -EINVAL; } DDPRINTK("smb_dont_catch_keepalive: sk->d_r = %x, server->d_r = %x\n", (unsigned int) (sk->data_ready), (unsigned int) (server->data_ready)); sk->data_ready = server->data_ready; server->data_ready = NULL; return 0; } static int smb_send_raw(struct socket *sock, unsigned char *source, int length) { int result; int already_sent = 0; while (already_sent < length) { result = _send(sock, (void *) (source + already_sent), length - already_sent, 0, 0); if (result < 0) { DPRINTK("smb_send_raw: sendto error = %d\n", -result); return result; } already_sent += result; } return already_sent; } static int smb_receive_raw(struct socket *sock, unsigned char *target, int length) { int result; int already_read = 0; while (already_read < length) { result = _recvfrom(sock, (void *) (target + already_read), length - already_read, 0, 0, NULL, NULL); if (result == 0) { return -EIO; } if (result < 0) { DPRINTK("smb_receive_raw: recvfrom error = %d\n", -result); return result; } already_read += result; } return already_read; } static int smb_get_length(struct socket *sock, unsigned char *header) { int result; unsigned char peek_buf[4]; unsigned short fs; re_recv: fs = get_fs(); set_fs(get_ds()); result = smb_receive_raw(sock, peek_buf, 4); set_fs(fs); if (result < 0) { DPRINTK("smb_get_length: recv error = %d\n", -result); return result; } switch (peek_buf[0]) { case 0x00: case 0x82: break; case 0x85: DPRINTK("smb_get_length: Got SESSION KEEP ALIVE\n"); goto re_recv; default: printk("smb_get_length: Invalid NBT packet\n"); return -EIO; } if (header != NULL) { memcpy(header, peek_buf, 4); } /* The length in the RFC NB header is the raw data length */ return smb_len(peek_buf); } static struct socket * server_sock(struct smb_server *server) { struct file *file; struct inode *inode; if (server == NULL) return NULL; if ((file = server->sock_file) == NULL) return NULL; if ((inode = file->f_inode) == NULL) return NULL; return &(inode->u.socket_i); } /* * smb_receive * fs points to the correct segment */ static int smb_receive(struct smb_server *server) { struct socket *sock = server_sock(server); int len; int result; unsigned char peek_buf[4]; len = smb_get_length(sock, peek_buf); if (len < 0) { return len; } if (len + 4 > server->packet_size) { /* Some servers do not care about our max_xmit. They send larger packets */ DPRINTK("smb_receive: Increase packet size from %d to %d\n", server->packet_size, len + 4); smb_vfree(server->packet); server->packet = NULL; server->packet_size = 0; server->packet = smb_vmalloc(len + 4); if (server->packet == NULL) { return -ENOMEM; } server->packet_size = len + 4; } memcpy(server->packet, peek_buf, 4); result = smb_receive_raw(sock, server->packet + 4, len); if (result < 0) { printk("smb_receive: receive error: %d\n", result); return result; } server->rcls = BVAL(server->packet, 9); server->err = WVAL(server->packet, 11); if (server->rcls != 0) { DPRINTK("smb_receive: rcls=%d, err=%d\n", server->rcls, server->err); } return result; } static int smb_receive_trans2(struct smb_server *server, int *ldata, unsigned char **data, int *lparam, unsigned char **param) { int total_data = 0; int total_param = 0; int result; unsigned char *rcv_buf; int buf_len; int data_len = 0; int param_len = 0; if ((result = smb_receive(server)) < 0) { return result; } if (server->rcls != 0) { *param = *data = server->packet; *ldata = *lparam = 0; return 0; } total_data = WVAL(server->packet, smb_tdrcnt); total_param = WVAL(server->packet, smb_tprcnt); DDPRINTK("smb_receive_trans2: td=%d,tp=%d\n", total_data, total_param); if ((total_data > TRANS2_MAX_TRANSFER) || (total_param > TRANS2_MAX_TRANSFER)) { DPRINTK("smb_receive_trans2: data/param too long\n"); return -EIO; } buf_len = total_data + total_param; if (server->packet_size > buf_len) { buf_len = server->packet_size; } if ((rcv_buf = smb_vmalloc(buf_len)) == NULL) { DPRINTK("smb_receive_trans2: could not alloc data area\n"); return -ENOMEM; } *param = rcv_buf; *data = rcv_buf + total_param; while (1) { unsigned char *inbuf = server->packet; if (WVAL(inbuf, smb_prdisp) + WVAL(inbuf, smb_prcnt) > total_param) { DPRINTK("smb_receive_trans2: invalid parameters\n"); result = -EIO; goto fail; } memcpy(*param + WVAL(inbuf, smb_prdisp), smb_base(inbuf) + WVAL(inbuf, smb_proff), WVAL(inbuf, smb_prcnt)); param_len += WVAL(inbuf, smb_prcnt); if (WVAL(inbuf, smb_drdisp) + WVAL(inbuf, smb_drcnt) > total_data) { DPRINTK("smb_receive_trans2: invalid data block\n"); result = -EIO; goto fail; } DDPRINTK("target: %X\n", (unsigned int) *data + WVAL(inbuf, smb_drdisp)); DDPRINTK("source: %X\n", (unsigned int) smb_base(inbuf) + WVAL(inbuf, smb_droff)); DDPRINTK("disp: %d, off: %d, cnt: %d\n", WVAL(inbuf, smb_drdisp), WVAL(inbuf, smb_droff), WVAL(inbuf, smb_drcnt)); memcpy(*data + WVAL(inbuf, smb_drdisp), smb_base(inbuf) + WVAL(inbuf, smb_droff), WVAL(inbuf, smb_drcnt)); data_len += WVAL(inbuf, smb_drcnt); if ((WVAL(inbuf, smb_tdrcnt) > total_data) || (WVAL(inbuf, smb_tprcnt) > total_param)) { printk("smb_receive_trans2: data/params grew!\n"); result = -EIO; goto fail; } /* the total lengths might shrink! */ total_data = WVAL(inbuf, smb_tdrcnt); total_param = WVAL(inbuf, smb_tprcnt); if ((data_len >= total_data) && (param_len >= total_param)) { break; } if ((result = smb_receive(server)) < 0) { goto fail; } if (server->rcls != 0) { result = -EIO; goto fail; } } *ldata = data_len; *lparam = param_len; smb_vfree(server->packet); server->packet = rcv_buf; server->packet_size = buf_len; return 0; fail: smb_vfree(rcv_buf); return result; } int smb_release(struct smb_server *server) { struct socket *sock = server_sock(server); int result; if (sock == NULL) { return -EINVAL; } result = sock->ops->release(sock, NULL); DPRINTK("smb_release: sock->ops->release = %d\n", result); /* inet_release does not set sock->state. Maybe someone is confused about sock->state being SS_CONNECTED while there is nothing behind it, so I set it to SS_UNCONNECTED. */ sock->state = SS_UNCONNECTED; result = sock->ops->create(sock, 0); DPRINTK("smb_release: sock->ops->create = %d\n", result); return result; } int smb_connect(struct smb_server *server) { struct socket *sock = server_sock(server); if (sock == NULL) { return -EINVAL; } if (sock->state != SS_UNCONNECTED) { DPRINTK("smb_connect: socket is not unconnected: %d\n", sock->state); } return sock->ops->connect(sock, (struct sockaddr *) &(server->m.addr), sizeof(struct sockaddr_in), 0); } int smb_request(struct smb_server *server) { unsigned long old_mask; unsigned short fs; int len, result; unsigned char *buffer = (server == NULL) ? NULL : server->packet; if (buffer == NULL) { printk("smb_request: Bad server!\n"); return -EBADF; } if (server->state != CONN_VALID) { return -EIO; } if ((result = smb_dont_catch_keepalive(server)) != 0) { server->state = CONN_INVALID; smb_invalidate_all_inodes(server); return result; } len = smb_len(buffer) + 4; DPRINTK("smb_request: len = %d cmd = 0x%X\n", len, buffer[8]); old_mask = current->blocked; current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP)); fs = get_fs(); set_fs(get_ds()); result = smb_send_raw(server_sock(server), (void *) buffer, len); if (result > 0) { result = smb_receive(server); } /* read/write errors are handled by errno */ current->signal &= ~_S(SIGPIPE); current->blocked = old_mask; set_fs(fs); if (result >= 0) { int result2 = smb_catch_keepalive(server); if (result2 < 0) { result = result2; } } if (result < 0) { server->state = CONN_INVALID; smb_invalidate_all_inodes(server); } DDPRINTK("smb_request: result = %d\n", result); return result; } #define ROUND_UP(x) (((x)+3) & ~3) static int smb_send_trans2(struct smb_server *server, __u16 trans2_command, int ldata, unsigned char *data, int lparam, unsigned char *param) { struct socket *sock = server_sock(server); /* I know the following is very ugly, but I want to build the smb packet as efficiently as possible. */ const int smb_parameters = 15; const int oparam = ROUND_UP(SMB_HEADER_LEN + 2 * smb_parameters + 2 + 3); const int odata = ROUND_UP(oparam + lparam); const int bcc = odata + ldata - (SMB_HEADER_LEN + 2 * smb_parameters + 2); const int packet_length = SMB_HEADER_LEN + 2 * smb_parameters + bcc + 2; unsigned char padding[4] = {0,}; char *p; struct iovec iov[4]; struct msghdr msg; if ((bcc + oparam) > server->max_xmit) { return -ENOMEM; } p = smb_setup_header(server, SMBtrans2, smb_parameters, bcc); WSET(server->packet, smb_tpscnt, lparam); WSET(server->packet, smb_tdscnt, ldata); WSET(server->packet, smb_mprcnt, TRANS2_MAX_TRANSFER); WSET(server->packet, smb_mdrcnt, TRANS2_MAX_TRANSFER); WSET(server->packet, smb_msrcnt, 0); WSET(server->packet, smb_flags, 0); DSET(server->packet, smb_timeout, 0); WSET(server->packet, smb_pscnt, lparam); WSET(server->packet, smb_psoff, oparam - 4); WSET(server->packet, smb_dscnt, ldata); WSET(server->packet, smb_dsoff, odata - 4); WSET(server->packet, smb_suwcnt, 1); WSET(server->packet, smb_setup0, trans2_command); *p++ = 0; /* null smb_name for trans2 */ *p++ = 'D'; /* this was added because OS/2 does it */ *p++ = ' '; iov[0].iov_base = (void *) server->packet; iov[0].iov_len = oparam; iov[1].iov_base = (param == NULL) ? padding : param; iov[1].iov_len = lparam; iov[2].iov_base = padding; iov[2].iov_len = odata - oparam - lparam; iov[3].iov_base = (data == NULL) ? padding : data; iov[3].iov_len = ldata; msg.msg_name = NULL; msg.msg_namelen = 0; msg.msg_control = NULL; msg.msg_iov = iov; msg.msg_iovlen = 4; return sock->ops->sendmsg(sock, &msg, packet_length, 0, 0); } /* * This is not really a trans2 request, we assume that you only have * one packet to send. */ int smb_trans2_request(struct smb_server *server, __u16 trans2_command, int ldata, unsigned char *data, int lparam, unsigned char *param, int *lrdata, unsigned char **rdata, int *lrparam, unsigned char **rparam) { unsigned long old_mask; unsigned short fs; int result; DPRINTK("smb_trans2_request: com=%d, ld=%d, lp=%d\n", trans2_command, ldata, lparam); if (server->state != CONN_VALID) { return -EIO; } if ((result = smb_dont_catch_keepalive(server)) != 0) { server->state = CONN_INVALID; smb_invalidate_all_inodes(server); return result; } old_mask = current->blocked; current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP)); fs = get_fs(); set_fs(get_ds()); result = smb_send_trans2(server, trans2_command, ldata, data, lparam, param); if (result >= 0) { result = smb_receive_trans2(server, lrdata, rdata, lrparam, rparam); } /* read/write errors are handled by errno */ current->signal &= ~_S(SIGPIPE); current->blocked = old_mask; set_fs(fs); if (result >= 0) { int result2 = smb_catch_keepalive(server); if (result2 < 0) { result = result2; } } if (result < 0) { server->state = CONN_INVALID; smb_invalidate_all_inodes(server); } DDPRINTK("smb_trans2_request: result = %d\n", result); return result; }