OpenCores
URL https://opencores.org/ocsvn/or1k/or1k/trunk

Subversion Repositories or1k

[/] [or1k/] [trunk/] [rc203soc/] [sw/] [uClinux/] [fs/] [smbfs/] [sock.c] - Rev 1771

Go to most recent revision | Compare with Previous | Blame | View Log

/*
 *  sock.c
 *
 *  Copyright (C) 1995, 1996 by Paal-Kr. Engstad and Volker Lendecke
 *
 */
 
#include <linux/sched.h>
#include <linux/smb_fs.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/fcntl.h>
#include <linux/stat.h>
#include <asm/segment.h>
#include <linux/in.h>
#include <linux/net.h>
#include <linux/mm.h>
#include <linux/netdevice.h>
#include <net/ip.h>
 
#include <linux/smb.h>
#include <linux/smbno.h>
 
 
/*
 * Bit mask for the (1-based) signal number nr, as used below on
 * current->blocked and current->signal.  The constant is unsigned long so
 * that the shift stays well defined for every nr up to the width of
 * unsigned long (a plain int 1 overflows for nr == 32).
 */
#define _S(nr) (1UL<<((nr)-1))
 
static int
_recvfrom(struct socket *sock, unsigned char *ubuf, int size,
	  int noblock, unsigned flags, struct sockaddr_in *sa, int *addr_len)
{
	struct iovec iov;
	struct msghdr msg;
 
	iov.iov_base = ubuf;
	iov.iov_len = size;
 
	msg.msg_name = (void *) sa;
	msg.msg_namelen = 0;
	if (addr_len)
		msg.msg_namelen = *addr_len;
	msg.msg_control = NULL;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
 
	return sock->ops->recvmsg(sock, &msg, size, noblock, flags, addr_len);
}
 
static int
_send(struct socket *sock, const void *buff, int len,
      int nonblock, unsigned flags)
{
	struct iovec iov;
	struct msghdr msg;
 
	iov.iov_base = (void *) buff;
	iov.iov_len = len;
 
	msg.msg_name = NULL;
	msg.msg_namelen = 0;
	msg.msg_control = NULL;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
 
	return sock->ops->sendmsg(sock, &msg, len, nonblock, flags);
}
 
/*
 * Data-ready callback installed on the SMB socket by smb_catch_keepalive().
 *
 * Does nothing when sk->dead is set.  Otherwise it peeks, without
 * blocking, at the first pending byte.  While that byte is 0x85 (handled
 * here as a SESSION KEEP ALIVE) the 4-byte packet is consumed and the next
 * byte is peeked.  Sleepers on sk->sleep are woken unless the last receive
 * reported -EAGAIN.
 *
 * peek_buf is only examined after a receive that actually delivered data
 * (result > 0).  Testing it after an error or end-of-stream would read an
 * uninitialised byte, or a stale 0x85 from the previous keep-alive, which
 * could keep this loop spinning forever.
 */
static void
smb_data_callback(struct sock *sk, int len)
{
	struct socket *sock = sk->socket;
	unsigned char peek_buf[4];
	int result;
	unsigned short fs;

	if (sk->dead)
	{
		return;
	}

	/* The buffer lives in kernel space, so switch the segment. */
	fs = get_fs();
	set_fs(get_ds());

	result = _recvfrom(sock, (void *) peek_buf, 1, 1,
			   MSG_PEEK, NULL, NULL);

	while ((result > 0) && (peek_buf[0] == 0x85))
	{
		/* got SESSION KEEP ALIVE */
		result = _recvfrom(sock, (void *) peek_buf,
				   4, 1, 0, NULL, NULL);

		DDPRINTK("smb_data_callback:"
			 " got SESSION KEEP ALIVE\n");

		if (result <= 0)
		{
			/* Nothing was consumed; peek_buf is not trustworthy. */
			break;
		}
		result = _recvfrom(sock, (void *) peek_buf,
				   1, 1, MSG_PEEK,
				   NULL, NULL);
	}
	set_fs(fs);

	/* Real data, end-of-stream or a hard error: let the sleeper see it. */
	if (result != -EAGAIN)
	{
		wake_up_interruptible(sk->sleep);
	}
}
 
/*
 * Install smb_data_callback as the data_ready handler of the server's
 * socket, remembering the previous handler in server->data_ready.
 *
 * Returns 0 on success.  Returns -EINVAL when the server, its socket file
 * or inode is missing, the inode is not a socket, the socket is not
 * SOCK_STREAM, the sock is NULL, or the callback is already installed.
 * On every failure except "already done" server->data_ready is cleared.
 */
int
smb_catch_keepalive(struct smb_server *server)
{
	struct file *file;
	struct inode *inode;
	struct socket *sock;
	struct sock *sk;

	/* A NULL server has no data_ready field to clear; bail out first. */
	if (server == NULL)
	{
		printk("smb_catch_keepalive: did not get valid server!\n");
		return -EINVAL;
	}
	if (((file = server->sock_file) == NULL)
	    || ((inode = file->f_inode) == NULL)
	    || (!S_ISSOCK(inode->i_mode)))
	{
		printk("smb_catch_keepalive: did not get valid server!\n");
		server->data_ready = NULL;
		return -EINVAL;
	}
	sock = &(inode->u.socket_i);

	if (sock->type != SOCK_STREAM)
	{
		printk("smb_catch_keepalive: did not get SOCK_STREAM\n");
		server->data_ready = NULL;
		return -EINVAL;
	}
	sk = (struct sock *) (sock->data);

	if (sk == NULL)
	{
		printk("smb_catch_keepalive: sk == NULL");
		server->data_ready = NULL;
		return -EINVAL;
	}
	DDPRINTK("smb_catch_keepalive.: sk->d_r = %x, server->d_r = %x\n",
		 (unsigned int) (sk->data_ready),
		 (unsigned int) (server->data_ready));

	if (sk->data_ready == smb_data_callback)
	{
		/* Keep the saved handler: overwriting it would lose it. */
		printk("smb_catch_keepalive: already done\n");
		return -EINVAL;
	}
	server->data_ready = sk->data_ready;
	sk->data_ready = smb_data_callback;
	return 0;
}
 
/*
 * Undo smb_catch_keepalive(): put the handler saved in server->data_ready
 * back on the socket and clear the saved copy.
 *
 * Returns 0 on success.  Returns -EINVAL when the server/socket cannot be
 * resolved, the socket is not SOCK_STREAM, the sock is NULL, no handler
 * was saved, or the socket's current handler is not smb_data_callback.
 */
int
smb_dont_catch_keepalive(struct smb_server *server)
{
	struct file *filp;
	struct inode *ino;
	struct socket *sock;
	struct sock *sk;

	/* Walk server -> file -> inode, stopping at the first NULL link. */
	filp = (server != NULL) ? server->sock_file : NULL;
	ino = (filp != NULL) ? filp->f_inode : NULL;

	if ((ino == NULL) || !S_ISSOCK(ino->i_mode))
	{
		printk("smb_dont_catch_keepalive: "
		       "did not get valid server!\n");
		return -EINVAL;
	}

	sock = &(ino->u.socket_i);
	if (sock->type != SOCK_STREAM)
	{
		printk("smb_dont_catch_keepalive: did not get SOCK_STREAM\n");
		return -EINVAL;
	}

	sk = (struct sock *) (sock->data);
	if (sk == NULL)
	{
		printk("smb_dont_catch_keepalive: sk == NULL");
		return -EINVAL;
	}

	if (server->data_ready == NULL)
	{
		printk("smb_dont_catch_keepalive: "
		       "server->data_ready == NULL\n");
		return -EINVAL;
	}

	if (sk->data_ready != smb_data_callback)
	{
		printk("smb_dont_catch_keepalive: "
		       "sk->data_callback != smb_data_callback\n");
		return -EINVAL;
	}

	DDPRINTK("smb_dont_catch_keepalive: sk->d_r = %x, server->d_r = %x\n",
		 (unsigned int) (sk->data_ready),
		 (unsigned int) (server->data_ready));

	/* Restore the original handler and forget the saved copy. */
	sk->data_ready = server->data_ready;
	server->data_ready = NULL;
	return 0;
}
 
/*
 * Send exactly `length` bytes from `source`, calling _send() repeatedly
 * until everything has gone out.  Returns the number of bytes sent, or
 * the negative value from the first _send() call that fails.
 */
static int
smb_send_raw(struct socket *sock, unsigned char *source, int length)
{
	int sent;
	int chunk;

	for (sent = 0; sent < length; sent += chunk)
	{
		chunk = _send(sock, (void *) (source + sent),
			      length - sent, 0, 0);
		if (chunk < 0)
		{
			DPRINTK("smb_send_raw: sendto error = %d\n",
				-chunk);
			return chunk;
		}
	}
	return sent;
}
 
/*
 * Read exactly `length` bytes into `target`, calling _recvfrom()
 * repeatedly until the buffer is full.  Returns the number of bytes read,
 * -EIO when a receive returns 0, or the negative value from the first
 * receive that fails.
 */
static int
smb_receive_raw(struct socket *sock, unsigned char *target, int length)
{
	int got;
	int chunk;

	for (got = 0; got < length; got += chunk)
	{
		chunk = _recvfrom(sock, (void *) (target + got),
				  length - got, 0, 0,
				  NULL, NULL);
		if (chunk == 0)
		{
			return -EIO;
		}
		if (chunk < 0)
		{
			DPRINTK("smb_receive_raw: recvfrom error = %d\n",
				-chunk);
			return chunk;
		}
	}
	return got;
}
 
static int
smb_get_length(struct socket *sock, unsigned char *header)
{
	int result;
	unsigned char peek_buf[4];
	unsigned short fs;
 
      re_recv:
	fs = get_fs();
	set_fs(get_ds());
	result = smb_receive_raw(sock, peek_buf, 4);
	set_fs(fs);
 
	if (result < 0)
	{
		DPRINTK("smb_get_length: recv error = %d\n", -result);
		return result;
	}
	switch (peek_buf[0])
	{
	case 0x00:
	case 0x82:
		break;
 
	case 0x85:
		DPRINTK("smb_get_length: Got SESSION KEEP ALIVE\n");
		goto re_recv;
 
	default:
		printk("smb_get_length: Invalid NBT packet\n");
		return -EIO;
	}
 
	if (header != NULL)
	{
		memcpy(header, peek_buf, 4);
	}
	/* The length in the RFC NB header is the raw data length */
	return smb_len(peek_buf);
}
 
/*
 * Resolve the socket embedded in the inode of server->sock_file.
 * Returns NULL when the server, its socket file or the file's inode is
 * NULL.
 */
static struct socket *
server_sock(struct smb_server *server)
{
	struct file *filp = (server != NULL) ? server->sock_file : NULL;
	struct inode *ino = (filp != NULL) ? filp->f_inode : NULL;

	if (ino == NULL)
	{
		return NULL;
	}
	return &(ino->u.socket_i);
}
 
/*
 * smb_receive
 * fs points to the correct segment
 */
static int
smb_receive(struct smb_server *server)
{
	struct socket *sock = server_sock(server);
	int len;
	int result;
	unsigned char peek_buf[4];
 
	len = smb_get_length(sock, peek_buf);
 
	if (len < 0)
	{
		return len;
	}
	if (len + 4 > server->packet_size)
	{
		/* Some servers do not care about our max_xmit. They
		   send larger packets */
		DPRINTK("smb_receive: Increase packet size from %d to %d\n",
			server->packet_size, len + 4);
		smb_vfree(server->packet);
		server->packet = NULL;
 
		server->packet_size = 0;
		server->packet = smb_vmalloc(len + 4);
		if (server->packet == NULL)
		{
			return -ENOMEM;
		}
		server->packet_size = len + 4;
	}
	memcpy(server->packet, peek_buf, 4);
	result = smb_receive_raw(sock, server->packet + 4, len);
 
	if (result < 0)
	{
		printk("smb_receive: receive error: %d\n", result);
		return result;
	}
	server->rcls = BVAL(server->packet, 9);
	server->err = WVAL(server->packet, 11);
 
	if (server->rcls != 0)
	{
		DPRINTK("smb_receive: rcls=%d, err=%d\n",
			server->rcls, server->err);
	}
	return result;
}
 
/*
 * Receive a (possibly multi-packet) trans2 response and reassemble it.
 *
 * If the first packet carries a non-zero error class (server->rcls), both
 * *param and *data are pointed at server->packet, both lengths are set to
 * 0 and 0 is returned.
 *
 * Otherwise a buffer large enough for the announced totals (and at least
 * server->packet_size) is allocated.  Parameter bytes are gathered at its
 * start (*param) and data bytes directly after the initially announced
 * parameter total (*data).  Packets are read until both announced totals
 * are reached.  On success the buffer replaces server->packet, *lparam and
 * *ldata receive the accumulated byte counts, and 0 is returned.
 *
 * Returns -EIO for oversized or inconsistent counts, for a fragment that
 * does not lie inside the received packet, or when a follow-up packet
 * reports an error class; -ENOMEM if the buffer cannot be allocated; or
 * the negative error from smb_receive().  On these failures the
 * reassembly buffer is freed, so *param and *data must not be used.
 */
static int
smb_receive_trans2(struct smb_server *server,
		   int *ldata, unsigned char **data,
		   int *lparam, unsigned char **param)
{
	int total_data = 0;
	int total_param = 0;
	int result;
	unsigned char *rcv_buf;
	int buf_len;
	int data_len = 0;
	int param_len = 0;

	if ((result = smb_receive(server)) < 0)
	{
		return result;
	}
	if (server->rcls != 0)
	{
		*param = *data = server->packet;
		*ldata = *lparam = 0;
		return 0;
	}
	total_data = WVAL(server->packet, smb_tdrcnt);
	total_param = WVAL(server->packet, smb_tprcnt);

	DDPRINTK("smb_receive_trans2: td=%d,tp=%d\n", total_data, total_param);

	if ((total_data > TRANS2_MAX_TRANSFER)
	    || (total_param > TRANS2_MAX_TRANSFER))
	{
		DPRINTK("smb_receive_trans2: data/param too long\n");
		return -EIO;
	}
	buf_len = total_data + total_param;
	if (server->packet_size > buf_len)
	{
		/* The buffer becomes server->packet later on, so it must
		   not be smaller than the current packet buffer. */
		buf_len = server->packet_size;
	}
	if ((rcv_buf = smb_vmalloc(buf_len)) == NULL)
	{
		DPRINTK("smb_receive_trans2: could not alloc data area\n");
		return -ENOMEM;
	}
	*param = rcv_buf;
	*data = rcv_buf + total_param;

	while (1)
	{
		unsigned char *inbuf = server->packet;
		/* Length announced in the 4-byte packet header.  Assumes,
		   as smb_receive() does, that this many bytes follow the
		   header and that smb_base() points just past it. */
		int pkt_len = smb_len(inbuf);
		int pr_disp = WVAL(inbuf, smb_prdisp);
		int pr_off = WVAL(inbuf, smb_proff);
		int pr_cnt = WVAL(inbuf, smb_prcnt);
		int dr_disp = WVAL(inbuf, smb_drdisp);
		int dr_off = WVAL(inbuf, smb_droff);
		int dr_cnt = WVAL(inbuf, smb_drcnt);

		/* Destination check: fragment must fit the param area. */
		if (pr_disp + pr_cnt > total_param)
		{
			DPRINTK("smb_receive_trans2: invalid parameters\n");
			result = -EIO;
			goto fail;
		}
		/* Source check: fragment must lie inside what was received,
		   otherwise memcpy would read past the packet. */
		if ((pr_cnt > 0) && (pr_off + pr_cnt > pkt_len))
		{
			DPRINTK("smb_receive_trans2: "
				"parameters beyond packet end\n");
			result = -EIO;
			goto fail;
		}
		memcpy(*param + pr_disp,
		       smb_base(inbuf) + pr_off,
		       pr_cnt);
		param_len += pr_cnt;

		/* Same two checks for the data fragment. */
		if (dr_disp + dr_cnt > total_data)
		{
			DPRINTK("smb_receive_trans2: invalid data block\n");
			result = -EIO;
			goto fail;
		}
		if ((dr_cnt > 0) && (dr_off + dr_cnt > pkt_len))
		{
			DPRINTK("smb_receive_trans2: "
				"data beyond packet end\n");
			result = -EIO;
			goto fail;
		}
		DDPRINTK("target: %X\n",
			 (unsigned int) *data + dr_disp);
		DDPRINTK("source: %X\n",
			 (unsigned int)
			 smb_base(inbuf) + dr_off);
		DDPRINTK("disp: %d, off: %d, cnt: %d\n",
			 dr_disp, dr_off, dr_cnt);

		memcpy(*data + dr_disp,
		       smb_base(inbuf) + dr_off,
		       dr_cnt);
		data_len += dr_cnt;

		if ((WVAL(inbuf, smb_tdrcnt) > total_data)
		    || (WVAL(inbuf, smb_tprcnt) > total_param))
		{
			printk("smb_receive_trans2: data/params grew!\n");
			result = -EIO;
			goto fail;
		}
		/* the total lengths might shrink! */
		total_data = WVAL(inbuf, smb_tdrcnt);
		total_param = WVAL(inbuf, smb_tprcnt);

		if ((data_len >= total_data) && (param_len >= total_param))
		{
			break;
		}
		if ((result = smb_receive(server)) < 0)
		{
			goto fail;
		}
		if (server->rcls != 0)
		{
			result = -EIO;
			goto fail;
		}
	}
	*ldata = data_len;
	*lparam = param_len;

	/* Hand the reassembled buffer over as the new packet buffer. */
	smb_vfree(server->packet);
	server->packet = rcv_buf;
	server->packet_size = buf_len;
	return 0;

      fail:
	smb_vfree(rcv_buf);
	return result;
}
 
/*
 * Release the server's socket and immediately re-create it through the
 * socket ops, leaving it in state SS_UNCONNECTED in between.
 *
 * Returns -EINVAL if no socket can be resolved, otherwise the result of
 * sock->ops->create (the result of the release call is only logged).
 */
int
smb_release(struct smb_server *server)
{
	struct socket *sock;
	int rc;

	sock = server_sock(server);
	if (sock == NULL)
	{
		return -EINVAL;
	}

	rc = sock->ops->release(sock, NULL);
	DPRINTK("smb_release: sock->ops->release = %d\n", rc);

	/* inet_release does not set sock->state.  Maybe someone is
	   confused about sock->state being SS_CONNECTED while there
	   is nothing behind it, so I set it to SS_UNCONNECTED. */
	sock->state = SS_UNCONNECTED;

	rc = sock->ops->create(sock, 0);
	DPRINTK("smb_release: sock->ops->create = %d\n", rc);

	return rc;
}
 
/*
 * Connect the server's socket to the address stored in server->m.addr.
 * A socket that is not in state SS_UNCONNECTED is only logged, not
 * rejected.
 *
 * Returns -EINVAL if no socket can be resolved, otherwise the result of
 * sock->ops->connect.
 */
int
smb_connect(struct smb_server *server)
{
	struct socket *sock;
	struct sockaddr *peer;

	sock = server_sock(server);
	if (sock == NULL)
	{
		return -EINVAL;
	}

	if (sock->state != SS_UNCONNECTED)
	{
		DPRINTK("smb_connect: socket is not unconnected: %d\n",
			sock->state);
	}

	peer = (struct sockaddr *) &(server->m.addr);
	return sock->ops->connect(sock, peer,
				  sizeof(struct sockaddr_in), 0);
}
 
/*
 * Send the request already assembled in server->packet and read the
 * reply back into server->packet.
 *
 * The keep-alive callback is removed for the duration of the exchange and
 * re-installed afterwards.  While sending and receiving, all signals other
 * than SIGKILL and SIGSTOP are blocked; a pending SIGPIPE is cleared
 * afterwards.
 *
 * Returns -EBADF when the server or its packet buffer is NULL and -EIO
 * when the connection is not CONN_VALID.  Otherwise returns the result of
 * the exchange; on any negative result the connection is marked
 * CONN_INVALID and all inodes are invalidated.
 */
int
smb_request(struct smb_server *server)
{
	unsigned char *packet = NULL;
	unsigned long saved_blocked;
	unsigned short saved_fs;
	int total;
	int rc;

	if (server != NULL)
	{
		packet = server->packet;
	}
	if (packet == NULL)
	{
		printk("smb_request: Bad server!\n");
		return -EBADF;
	}
	if (server->state != CONN_VALID)
	{
		return -EIO;
	}

	rc = smb_dont_catch_keepalive(server);
	if (rc != 0)
	{
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
		return rc;
	}

	/* Packet header (4 bytes) plus the length it announces. */
	total = smb_len(packet) + 4;

	DPRINTK("smb_request: len = %d cmd = 0x%X\n", total, packet[8]);

	saved_blocked = current->blocked;
	current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
	saved_fs = get_fs();
	set_fs(get_ds());

	rc = smb_send_raw(server_sock(server), (void *) packet, total);
	if (rc > 0)
	{
		rc = smb_receive(server);
	}

	/* read/write errors are handled by errno */
	current->signal &= ~_S(SIGPIPE);
	current->blocked = saved_blocked;
	set_fs(saved_fs);

	if (rc >= 0)
	{
		int rearm = smb_catch_keepalive(server);

		if (rearm < 0)
		{
			rc = rearm;
		}
	}
	if (rc < 0)
	{
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
	}

	DDPRINTK("smb_request: result = %d\n", rc);

	return rc;
}
 
/* Round x up to the next multiple of 4. */
#define ROUND_UP(x) (((x)+3) & ~3)

/*
 * Build a single-packet SMBtrans2 request in server->packet and send it
 * together with the caller's parameter and data blocks in one sendmsg
 * call (header, parameters, alignment padding, data).
 *
 * Returns -ENOMEM when the request would exceed server->max_xmit,
 * otherwise the result of sock->ops->sendmsg.
 */
static int
smb_send_trans2(struct smb_server *server, __u16 trans2_command,
		int ldata, unsigned char *data,
		int lparam, unsigned char *param)
{
	struct socket *sock = server_sock(server);

	/* I know the following is very ugly, but I want to build the
	   smb packet as efficiently as possible. */

	const int smb_parameters = 15;
	/* Header, parameter words and the 2-byte byte count. */
	const int fixed_len = SMB_HEADER_LEN + 2 * smb_parameters + 2;
	/* Offsets of the parameter and data blocks, each 4-byte aligned. */
	const int oparam = ROUND_UP(fixed_len + 3);
	const int odata = ROUND_UP(oparam + lparam);
	const int bcc = odata + ldata - fixed_len;
	const int packet_length = fixed_len + bcc;

	unsigned char padding[4] = {0, 0, 0, 0};
	struct iovec iov[4];
	struct msghdr msg;
	char *p;

	if ((bcc + oparam) > server->max_xmit)
	{
		return -ENOMEM;
	}

	p = smb_setup_header(server, SMBtrans2, smb_parameters, bcc);

	WSET(server->packet, smb_tpscnt, lparam);
	WSET(server->packet, smb_tdscnt, ldata);
	WSET(server->packet, smb_mprcnt, TRANS2_MAX_TRANSFER);
	WSET(server->packet, smb_mdrcnt, TRANS2_MAX_TRANSFER);
	WSET(server->packet, smb_msrcnt, 0);
	WSET(server->packet, smb_flags, 0);
	DSET(server->packet, smb_timeout, 0);
	WSET(server->packet, smb_pscnt, lparam);
	WSET(server->packet, smb_psoff, oparam - 4);
	WSET(server->packet, smb_dscnt, ldata);
	WSET(server->packet, smb_dsoff, odata - 4);
	WSET(server->packet, smb_suwcnt, 1);
	WSET(server->packet, smb_setup0, trans2_command);

	*p++ = 0;		/* null smb_name for trans2 */
	*p++ = 'D';		/* this was added because OS/2 does it */
	*p++ = ' ';

	/* Segment 0: the header built above, up to the parameter block. */
	iov[0].iov_len = oparam;
	iov[0].iov_base = (void *) server->packet;

	/* Segment 1: caller's parameters (padding stands in for NULL). */
	iov[1].iov_len = lparam;
	iov[1].iov_base = (param != NULL) ? param : padding;

	/* Segment 2: filler that aligns the data block. */
	iov[2].iov_len = odata - oparam - lparam;
	iov[2].iov_base = padding;

	/* Segment 3: caller's data (padding stands in for NULL). */
	iov[3].iov_len = ldata;
	iov[3].iov_base = (data != NULL) ? data : padding;

	msg.msg_iovlen = 4;
	msg.msg_iov = iov;
	msg.msg_control = NULL;
	msg.msg_namelen = 0;
	msg.msg_name = NULL;

	return sock->ops->sendmsg(sock, &msg, packet_length, 0, 0);
}
 
/*
 * This is not really a trans2 request, we assume that you only have
 * one packet to send.
 *
 * Sends one trans2 request via smb_send_trans2() and collects the reply
 * via smb_receive_trans2().  The keep-alive callback is removed for the
 * exchange and re-installed afterwards; all signals except SIGKILL and
 * SIGSTOP are blocked meanwhile and a pending SIGPIPE is cleared.
 *
 * Returns -EIO when the connection is not CONN_VALID, otherwise the
 * result of the exchange.  On any negative result the connection is
 * marked CONN_INVALID and all inodes are invalidated.
 */
int
smb_trans2_request(struct smb_server *server, __u16 trans2_command,
		   int ldata, unsigned char *data,
		   int lparam, unsigned char *param,
		   int *lrdata, unsigned char **rdata,
		   int *lrparam, unsigned char **rparam)
{
	unsigned long saved_blocked;
	unsigned short saved_fs;
	int rc;

	DPRINTK("smb_trans2_request: com=%d, ld=%d, lp=%d\n",
		trans2_command, ldata, lparam);

	if (server->state != CONN_VALID)
	{
		return -EIO;
	}

	rc = smb_dont_catch_keepalive(server);
	if (rc != 0)
	{
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
		return rc;
	}

	saved_blocked = current->blocked;
	current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
	saved_fs = get_fs();
	set_fs(get_ds());

	rc = smb_send_trans2(server, trans2_command,
			     ldata, data, lparam, param);
	if (rc >= 0)
	{
		rc = smb_receive_trans2(server,
					lrdata, rdata, lrparam, rparam);
	}

	/* read/write errors are handled by errno */
	current->signal &= ~_S(SIGPIPE);
	current->blocked = saved_blocked;
	set_fs(saved_fs);

	if (rc >= 0)
	{
		int rearm = smb_catch_keepalive(server);

		if (rearm < 0)
		{
			rc = rearm;
		}
	}
	if (rc < 0)
	{
		server->state = CONN_INVALID;
		smb_invalidate_all_inodes(server);
	}

	DDPRINTK("smb_trans2_request: result = %d\n", rc);

	return rc;
}
 

Go to most recent revision | Compare with Previous | Blame | View Log

powered by: WebSVN 2.1.0

© copyright 1999-2025 OpenCores.org, equivalent to Oliscience, all rights reserved. OpenCores®, registered trademark.