Copied atexit() code from fakesmtp.c to fakepop.c so that its
[mmh] / test / fakesmtp.c
1 /*
2  * fakesmtp - A fake SMTP server used by the nmh test suite
3  *
4  * This code is Copyright (c) 2012, by the authors of nmh.  See the
5  * COPYRIGHT file in the root directory of the nmh distribution for
6  * complete copyright information.
7  */
8
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 #include <netdb.h>
14 #include <errno.h>
15 #include <sys/socket.h>
16 #include <sys/types.h>
17 #include <sys/select.h>
18 #include <sys/stat.h>
19 #include <sys/uio.h>
20 #include <signal.h>
21
22 #define PIDFILE "/tmp/fakesmtp.pid"
23
24 #define LINESIZE 1024
25
26 static void killpidfile(void);
27 static void handleterm(int);
28 static void putsmtp(int, char *);
29 static int getsmtp(int, char *);
30
31 int
32 main(int argc, char *argv[])
33 {
34         struct addrinfo hints, *res;
35         int rc, l, conn, on, datamode;
36         FILE *f, *pid;
37         pid_t child;
38         fd_set readfd;
39         struct stat st;
40         struct timeval tv;
41
42         if (argc != 3) {
43                 fprintf(stderr, "Usage: %s output-filename port\n", argv[0]);
44                 exit(1);
45         }
46
47         if (!(f = fopen(argv[1], "w"))) {
48                 fprintf(stderr, "Unable to open output file \"%s\": %s\n",
49                         argv[1], strerror(errno));
50                 exit(1);
51         }
52
53         /*
54          * If there is a pid file already around, kill the previously running
55          * fakesmtp process.  Hopefully this will reduce the race conditions
56          * that crop up when running the test suite.
57          */
58
59         if (stat(PIDFILE, &st) == 0) {
60                 long oldpid;
61
62                 if (!(pid = fopen(PIDFILE, "r"))) {
63                         fprintf(stderr, "Cannot open " PIDFILE
64                                 " (%s), continuing ...\n", strerror(errno));
65                 } else {
66                         rc = fscanf(pid, "%ld", &oldpid);
67                         fclose(pid);
68
69                         if (rc != 1) {
70                                 fprintf(stderr, "Unable to parse pid in "
71                                         PIDFILE ", continuing ...\n");
72                         } else {
73                                 kill((pid_t) oldpid, SIGTERM);
74                         }
75                 }
76
77                 unlink(PIDFILE);
78         }
79
80         memset(&hints, 0, sizeof(hints));
81
82         hints.ai_family = PF_INET;
83         hints.ai_socktype = SOCK_STREAM;
84         hints.ai_protocol = IPPROTO_TCP;
85         hints.ai_flags = AI_PASSIVE;
86
87         rc = getaddrinfo("127.0.0.1", argv[2], &hints, &res);
88
89         if (rc) {
90                 fprintf(stderr, "Unable to resolve localhost/%s: %s\n",
91                         argv[2], gai_strerror(rc));
92                 exit(1);
93         }
94
95         l = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
96
97         if (l == -1) {
98                 fprintf(stderr, "Unable to create listening socket: %s\n",
99                         strerror(errno));
100                 exit(1);
101         }
102
103         on = 1;
104
105         if (setsockopt(l, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
106                 fprintf(stderr, "Unable to set SO_REUSEADDR: %s\n",
107                         strerror(errno));
108                 exit(1);
109         }
110
111         if (bind(l, res->ai_addr, res->ai_addrlen) == -1) {
112                 fprintf(stderr, "Unable to bind socket: %s\n", strerror(errno));
113                 exit(1);
114         }
115
116         if (listen(l, 1) == -1) {
117                 fprintf(stderr, "Unable to listen on socket: %s\n",
118                         strerror(errno));
119                 exit(1);
120         }
121
122         /*
123          * Now we fork() and print out the process ID of our child
124          * for scripts to use.  Once we do that, then exit.
125          */
126
127         child = fork();
128
129         switch (child) {
130         case -1:
131                 fprintf(stderr, "Unable to fork child: %s\n", strerror(errno));
132                 exit(1);
133                 break;
134         case 0:
135                 /*
136                  * Close stdin & stdout, otherwise people can
137                  * think we're still doing stuff.  For now leave stderr
138                  * open.
139                  */
140                 fclose(stdin);
141                 fclose(stdout);
142                 break;
143         default:
144                 printf("%ld\n", (long) child);
145                 exit(0);
146         }
147
148         /*
149          * Now that our socket & files are set up, wait 30 seconds for
150          * a connection.  If there isn't one, then exit.
151          */
152
153         if (!(pid = fopen(PIDFILE, "w"))) {
154                 fprintf(stderr, "Cannot open " PIDFILE ": %s\n",
155                         strerror(errno));
156                 exit(1);
157         }
158
159         fprintf(pid, "%ld\n", (long) getpid());
160         fclose(pid);
161
162         signal(SIGTERM, handleterm);
163         atexit(killpidfile);
164
165         FD_ZERO(&readfd);
166         FD_SET(l, &readfd);
167         tv.tv_sec = 30;
168         tv.tv_usec = 0;
169
170         rc = select(l + 1, &readfd, NULL, NULL, &tv);
171
172         if (rc < 0) {
173                 fprintf(stderr, "select() failed: %s\n", strerror(errno));
174                 exit(1);
175         }
176
177         /*
178          * I think if we get a timeout, we should just exit quietly.
179          */
180
181         if (rc == 0) {
182                 exit(1);
183         }
184
185         /*
186          * Alright, got a connection!  Accept it.
187          */
188
189         if ((conn = accept(l, NULL, NULL)) == -1) {
190                 fprintf(stderr, "Unable to accept connection: %s\n",
191                         strerror(errno));
192                 exit(1);
193         }
194
195         close(l);
196         
197         /*
198          * Pretend to be an SMTP server.
199          */
200
201         putsmtp(conn, "220 Not really an ESMTP server");
202         datamode = 0;
203
204         for (;;) {
205                 char line[LINESIZE];
206
207                 rc = getsmtp(conn, line);
208
209                 if (rc == -1)
210                         break;  /* EOF */
211
212                 fprintf(f, "%s\n", line);
213
214                 /*
215                  * If we're in DATA mode, then check to see if we've got
216                  * a "."; otherwise, continue
217                  */
218
219                 if (datamode) {
220                         if (strcmp(line, ".") == 0) {
221                                 datamode = 0;
222                                 putsmtp(conn, "250 Thanks for the info!");
223                         }
224                         continue;
225                 }
226
227                 /*
228                  * Most commands we ignore and send the same response to.
229                  */
230
231                 if (strcmp(line, "QUIT") == 0) {
232                         fclose(f);
233                         f = NULL;
234                         putsmtp(conn, "221 Later alligator!");
235                         close(conn);
236                         break;
237                 } else if (strcmp(line, "DATA") == 0) {
238                         putsmtp(conn, "354 Go ahead");
239                         datamode = 1;
240                 } else {
241                         putsmtp(conn, "250 I'll buy that for a dollar!");
242                 }
243         }
244
245         if (f)
246                 fclose(f);
247
248         exit(0);
249 }
250
251 /*
252  * Write a line to the SMTP client on the other end
253  */
254
255 static void
256 putsmtp(int socket, char *data)
257 {
258         struct iovec iov[2];
259
260         iov[0].iov_base = data;
261         iov[0].iov_len = strlen(data);
262         iov[1].iov_base = "\r\n";
263         iov[1].iov_len = 2;
264
265         writev(socket, iov, 2);
266 }
267
268 /*
269  * Read a line (up to the \r\n)
270  */
271
272 static int
273 getsmtp(int socket, char *data)
274 {
275         int cc;
276         static unsigned int bytesinbuf = 0;
277         static char buffer[LINESIZE * 2], *p;
278
279         for (;;) {
280                 /*
281                  * Find our \r\n
282                  */
283
284                 if (bytesinbuf > 0 && (p = strchr(buffer, '\r')) &&
285                                                         *(p + 1) == '\n') {
286                         *p = '\0';
287                         strncpy(data, buffer, LINESIZE);
288                         data[LINESIZE - 1] = '\0';
289                         cc = strlen(buffer);
290
291                         /*
292                          * Shuffle leftover bytes back to the beginning
293                          */
294
295                         bytesinbuf -= cc + 2;   /* Don't forget \r\n */
296                         if (bytesinbuf > 0) {
297                                 memmove(buffer, buffer + cc + 2, bytesinbuf);
298                         }
299                         return cc;
300                 }
301
302                 if (bytesinbuf >= sizeof(buffer)) {
303                         fprintf(stderr, "Buffer overflow in getsmtp()!\n");
304                         exit(1);
305                 }
306
307                 memset(buffer + bytesinbuf, 0, sizeof(buffer) - bytesinbuf);
308                 cc = read(socket, buffer + bytesinbuf,
309                           sizeof(buffer) - bytesinbuf);
310
311                 if (cc < 0) {
312                         fprintf(stderr, "Read failed: %s\n", strerror(errno));
313                         exit(1);
314                 }
315
316                 if (cc == 0)
317                         return -1;
318
319                 bytesinbuf += cc;
320         }
321 }
322
323 /*
324  * Handle a SIGTERM
325  */
326
327 static void
328 handleterm(int signal)
329 {
330         (void) signal;
331
332         killpidfile();
333         fflush(NULL);
334         _exit(1);
335 }
336
337 /*
338  * Get rid of our pid file
339  */
340
341 static void
342 killpidfile(void)
343 {
344         unlink(PIDFILE);
345 }