00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef _PASSENGER_MESSAGE_SERVER_H_
00026 #define _PASSENGER_MESSAGE_SERVER_H_
00027
00028 #include <string>
00029 #include <vector>
00030
00031 #include <boost/shared_ptr.hpp>
00032 #include <boost/thread.hpp>
00033 #include <oxt/system_calls.hpp>
00034 #include <oxt/dynamic_thread_group.hpp>
00035
00036 #include <sys/types.h>
00037 #include <sys/stat.h>
00038 #include <sys/un.h>
00039 #include <unistd.h>
00040 #include <cerrno>
00041 #include <cassert>
00042
00043 #include "Account.h"
00044 #include "AccountsDatabase.h"
00045 #include "Constants.h"
00046 #include "FileDescriptor.h"
00047 #include "MessageChannel.h"
00048 #include "Logging.h"
00049 #include "Exceptions.h"
00050 #include "Utils/StrIntUtils.h"
00051 #include "Utils/IOUtils.h"
00052
00053 namespace Passenger {
00054
00055 using namespace std;
00056 using namespace boost;
00057 using namespace oxt;
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158 class MessageServer {
00159 public:
00160 static const unsigned int CLIENT_THREAD_STACK_SIZE = 64 * 1024;
00161
00162
00163 class ClientContext {
00164 public:
00165 virtual ~ClientContext() { }
00166 };
00167
00168 typedef shared_ptr<ClientContext> ClientContextPtr;
00169
00170
00171
00172
00173
00174 class CommonClientContext: public ClientContext {
00175 public:
00176
00177 FileDescriptor fd;
00178
00179
00180 MessageChannel channel;
00181
00182
00183 AccountPtr account;
00184
00185
00186 CommonClientContext(FileDescriptor &theFd, AccountPtr &theAccount)
00187 : fd(theFd), channel(theFd), account(theAccount)
00188 { }
00189
00190
00191 string name() {
00192 return toString(channel.filenum());
00193 }
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204 void requireRights(Account::Rights rights) {
00205 if (!account->hasRights(rights)) {
00206 P_TRACE(2, "Security error: insufficient rights to execute this command.");
00207 channel.write("SecurityException", "Insufficient rights to execute this command.", NULL);
00208 throw SecurityException("Insufficient rights to execute this command.");
00209 } else {
00210 channel.write("Passed security", NULL);
00211 }
00212 }
00213 };
00214
00215
00216
00217
00218
00219
00220
00221
00222 class Handler {
00223 public:
00224 virtual ~Handler() { }
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235 virtual ClientContextPtr newClient(CommonClientContext &context) {
00236 return ClientContextPtr();
00237 }
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250 virtual void clientDisconnected(MessageServer::CommonClientContext &context,
00251 MessageServer::ClientContextPtr &handlerSpecificContext)
00252 { }
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266 virtual bool processMessage(CommonClientContext &commonContext,
00267 ClientContextPtr &handlerSpecificContext,
00268 const vector<string> &args) = 0;
00269 };
00270
00271 typedef shared_ptr<Handler> HandlerPtr;
00272
00273 protected:
00274
00275 string socketFilename;
00276
00277
00278 AccountsDatabasePtr accountsDatabase;
00279
00280
00281 vector<HandlerPtr> handlers;
00282
00283
00284
00285
00286
00287
00288 unsigned long long loginTimeout;
00289
00290
00291 dynamic_thread_group threadGroup;
00292
00293
00294
00295
00296 int serverFd;
00297
00298
00299
00300 struct DisconnectEventBroadcastGuard {
00301 vector<HandlerPtr> &handlers;
00302 CommonClientContext &commonContext;
00303 vector<ClientContextPtr> &handlerSpecificContexts;
00304
00305 DisconnectEventBroadcastGuard(vector<HandlerPtr> &_handlers,
00306 CommonClientContext &_commonContext,
00307 vector<ClientContextPtr> &_handlerSpecificContexts)
00308 : handlers(_handlers),
00309 commonContext(_commonContext),
00310 handlerSpecificContexts(_handlerSpecificContexts)
00311 { }
00312
00313 ~DisconnectEventBroadcastGuard() {
00314 vector<HandlerPtr>::iterator handler_iter;
00315 vector<ClientContextPtr>::iterator context_iter;
00316
00317 for (handler_iter = handlers.begin(), context_iter = handlerSpecificContexts.begin();
00318 handler_iter != handlers.end();
00319 handler_iter++, context_iter++) {
00320 (*handler_iter)->clientDisconnected(commonContext, *context_iter);
00321 }
00322 }
00323 };
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334 void startListening() {
00335 TRACE_POINT();
00336 int ret;
00337
00338 serverFd = createUnixServer(socketFilename.c_str());
00339 do {
00340 ret = chmod(socketFilename.c_str(),
00341 S_ISVTX |
00342 S_IRUSR | S_IWUSR | S_IXUSR |
00343 S_IRGRP | S_IWGRP | S_IXGRP |
00344 S_IROTH | S_IWOTH | S_IXOTH);
00345 } while (ret == -1 && errno == EINTR);
00346 }
00347
00348
00349
00350
00351
00352
00353 AccountPtr authenticate(FileDescriptor &client) {
00354 MessageChannel channel(client);
00355 string username, password;
00356 MemZeroGuard passwordGuard(password);
00357 unsigned long long timeout = loginTimeout;
00358
00359 try {
00360 channel.write("version", "1", NULL);
00361
00362 try {
00363 if (!channel.readScalar(username, MESSAGE_SERVER_MAX_USERNAME_SIZE, &timeout)) {
00364 return AccountPtr();
00365 }
00366 } catch (const SecurityException &) {
00367 channel.write("The supplied username is too long.", NULL);
00368 return AccountPtr();
00369 }
00370
00371 try {
00372 if (!channel.readScalar(password, MESSAGE_SERVER_MAX_PASSWORD_SIZE, &timeout)) {
00373 return AccountPtr();
00374 }
00375 } catch (const SecurityException &) {
00376 channel.write("The supplied password is too long.", NULL);
00377 return AccountPtr();
00378 }
00379
00380 AccountPtr account = accountsDatabase->authenticate(username, password);
00381 passwordGuard.zeroNow();
00382 if (account == NULL) {
00383 channel.write("Invalid username or password.", NULL);
00384 return AccountPtr();
00385 } else {
00386 channel.write("ok", NULL);
00387 return account;
00388 }
00389 } catch (const SystemException &) {
00390 return AccountPtr();
00391 } catch (const TimeoutException &) {
00392 return AccountPtr();
00393 }
00394 }
00395
00396 void broadcastNewClientEvent(CommonClientContext &context,
00397 vector<ClientContextPtr> &handlerSpecificContexts) {
00398 vector<HandlerPtr>::iterator it;
00399
00400 for (it = handlers.begin(); it != handlers.end(); it++) {
00401 handlerSpecificContexts.push_back((*it)->newClient(context));
00402 }
00403 }
00404
00405 bool processMessage(CommonClientContext &commonContext,
00406 vector<ClientContextPtr> &handlerSpecificContexts,
00407 const vector<string> &args) {
00408 vector<HandlerPtr>::iterator handler_iter;
00409 vector<ClientContextPtr>::iterator context_iter;
00410
00411 for (handler_iter = handlers.begin(), context_iter = handlerSpecificContexts.begin();
00412 handler_iter != handlers.end();
00413 handler_iter++, context_iter++) {
00414 if ((*handler_iter)->processMessage(commonContext, *context_iter, args)) {
00415 return true;
00416 }
00417 }
00418 return false;
00419 }
00420
00421 void processUnknownMessage(CommonClientContext &commonContext, const vector<string> &args) {
00422 TRACE_POINT();
00423 string name;
00424 if (args.empty()) {
00425 name = "(null)";
00426 } else {
00427 name = args[0];
00428 }
00429 P_TRACE(2, "A MessageServer client sent an invalid command: "
00430 << name << " (" << args.size() << " elements)");
00431 }
00432
00433
00434
00435
00436 void clientHandlingMainLoop(FileDescriptor &client) {
00437 TRACE_POINT();
00438 vector<string> args;
00439
00440 P_TRACE(4, "MessageServer client thread " << (int) client << " started.");
00441
00442 try {
00443 AccountPtr account(authenticate(client));
00444 if (account == NULL) {
00445 P_TRACE(4, "MessageServer client thread " << (int) client << " exited.");
00446 return;
00447 }
00448
00449 CommonClientContext commonContext(client, account);
00450 vector<ClientContextPtr> handlerSpecificContexts;
00451 broadcastNewClientEvent(commonContext, handlerSpecificContexts);
00452 DisconnectEventBroadcastGuard dguard(handlers, commonContext, handlerSpecificContexts);
00453
00454 while (!this_thread::interruption_requested()) {
00455 UPDATE_TRACE_POINT();
00456 if (!commonContext.channel.read(args)) {
00457
00458 break;
00459 }
00460
00461 P_TRACE(4, "MessageServer client " << commonContext.name() <<
00462 ": received message: " << toString(args));
00463
00464 UPDATE_TRACE_POINT();
00465 if (!processMessage(commonContext, handlerSpecificContexts, args)) {
00466 processUnknownMessage(commonContext, args);
00467 break;
00468 }
00469 args.clear();
00470 }
00471
00472 P_TRACE(4, "MessageServer client thread " << (int) client << " exited.");
00473 client.close();
00474 } catch (const boost::thread_interrupted &) {
00475 P_TRACE(2, "MessageServer client thread " << (int) client << " interrupted.");
00476 } catch (const tracable_exception &e) {
00477 P_TRACE(2, "An error occurred in a MessageServer client thread " << (int) client << ":\n"
00478 << " message: " << toString(args) << "\n"
00479 << " exception: " << e.what() << "\n"
00480 << " backtrace:\n" << e.backtrace());
00481 }
00482 }
00483
00484 public:
00485
00486
00487
00488
00489
00490
00491
00492
00493
00494
00495
00496
00497 MessageServer(const string &socketFilename, AccountsDatabasePtr accountsDatabase) {
00498 this->socketFilename = socketFilename;
00499 this->accountsDatabase = accountsDatabase;
00500 loginTimeout = 2000;
00501 startListening();
00502 }
00503
00504 ~MessageServer() {
00505 this_thread::disable_syscall_interruption dsi;
00506 syscalls::close(serverFd);
00507 syscalls::unlink(socketFilename.c_str());
00508 }
00509
00510 string getSocketFilename() const {
00511 return socketFilename;
00512 }
00513
00514
00515
00516
00517
00518
00519
00520
00521
00522
00523 void mainLoop() {
00524 TRACE_POINT();
00525 while (true) {
00526 this_thread::interruption_point();
00527 sockaddr_un addr;
00528 socklen_t len = sizeof(addr);
00529 FileDescriptor fd;
00530
00531 UPDATE_TRACE_POINT();
00532 fd = syscalls::accept(serverFd, (struct sockaddr *) &addr, &len);
00533 if (fd == -1) {
00534 throw SystemException("Unable to accept a new client", errno);
00535 }
00536
00537 UPDATE_TRACE_POINT();
00538 this_thread::disable_interruption di;
00539 this_thread::disable_syscall_interruption dsi;
00540
00541 function<void ()> func(boost::bind(&MessageServer::clientHandlingMainLoop,
00542 this, fd));
00543 string name = "MessageServer client thread ";
00544 name.append(toString(fd));
00545 threadGroup.create_thread(func, name, CLIENT_THREAD_STACK_SIZE);
00546 }
00547 }
00548
00549
00550
00551
00552
00553
00554 void addHandler(HandlerPtr handler) {
00555 handlers.push_back(handler);
00556 }
00557
00558
00559
00560
00561
00562
00563
00564
00565 void setLoginTimeout(unsigned long long timeout) {
00566 assert(timeout != 0);
00567 loginTimeout = timeout;
00568 }
00569 };
00570
00571 typedef shared_ptr<MessageServer> MessageServerPtr;
00572
00573 }
00574
00575 #endif