From e2abfaf381984b9c441555a3614336b64008c80d Mon Sep 17 00:00:00 2001 From: Igor Sysoev Date: Mon, 26 Aug 2019 18:29:00 +0300 Subject: Adding body handler to nxt_http_request_header_send(). --- src/nxt_h1proto.c | 19 +++++++++++++++++-- src/nxt_http.h | 6 ++++-- src/nxt_http_error.c | 6 ++---- src/nxt_http_request.c | 5 +++-- src/nxt_router.c | 7 +------ 5 files changed, 27 insertions(+), 16 deletions(-) (limited to 'src') diff --git a/src/nxt_h1proto.c b/src/nxt_h1proto.c index a60bdb36..b8f62a88 100644 --- a/src/nxt_h1proto.c +++ b/src/nxt_h1proto.c @@ -45,7 +45,7 @@ static void nxt_h1p_conn_request_body_read(nxt_task_t *task, void *obj, void *data); static void nxt_h1p_request_local_addr(nxt_task_t *task, nxt_http_request_t *r); static void nxt_h1p_request_header_send(nxt_task_t *task, - nxt_http_request_t *r); + nxt_http_request_t *r, nxt_work_handler_t body_handler); static void nxt_h1p_request_send(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *out); static nxt_buf_t *nxt_h1p_chunk_create(nxt_task_t *task, nxt_http_request_t *r, @@ -996,7 +996,8 @@ static const nxt_str_t nxt_http_server_error[] = { #define UNKNOWN_STATUS_LENGTH nxt_length("HTTP/1.1 65536\r\n") static void -nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r) +nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler) { u_char *p; size_t size; @@ -1079,6 +1080,7 @@ nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r) if (http11) { if (n != NXT_HTTP_NOT_MODIFIED && n != NXT_HTTP_NO_CONTENT + && body_handler != NULL && !h1p->websocket) { h1p->chunked = 1; @@ -1165,6 +1167,19 @@ nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r) c->write = header; c->write_state = &nxt_h1p_request_send_state; + if (body_handler != NULL) { + /* + * The body handler will run before c->io->write() handler, + * because the latter was inqueued by nxt_conn_write() + * in engine->write_work_queue. + */ + nxt_work_queue_add(&task->thread->engine->fast_work_queue, + body_handler, task, r, NULL); + + } else { + header->next = nxt_http_buf_last(r); + } + nxt_conn_write(task->thread->engine, c); if (h1p->websocket) { diff --git a/src/nxt_http.h b/src/nxt_http.h index ac1eedcf..2ca3f536 100644 --- a/src/nxt_http.h +++ b/src/nxt_http.h @@ -175,7 +175,8 @@ struct nxt_http_pass_s { typedef struct { void (*body_read)(nxt_task_t *task, nxt_http_request_t *r); void (*local_addr)(nxt_task_t *task, nxt_http_request_t *r); - void (*header_send)(nxt_task_t *task, nxt_http_request_t *r); + void (*header_send)(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler); void (*send)(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *out); nxt_off_t (*body_bytes_sent)(nxt_task_t *task, nxt_http_proto_t proto); void (*discard)(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *last); @@ -195,7 +196,8 @@ nxt_http_request_t *nxt_http_request_create(nxt_task_t *task); void nxt_http_request_error(nxt_task_t *task, nxt_http_request_t *r, nxt_http_status_t status); void nxt_http_request_read_body(nxt_task_t *task, nxt_http_request_t *r); -void nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r); +void nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler); void nxt_http_request_ws_frame_start(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *ws_frame); void nxt_http_request_send(nxt_task_t *task, nxt_http_request_t *r, diff --git a/src/nxt_http_error.c b/src/nxt_http_error.c index c7c7e81a..1dcd8783 100644 --- a/src/nxt_http_error.c +++ b/src/nxt_http_error.c @@ -51,12 +51,10 @@ nxt_http_request_error(nxt_task_t *task, nxt_http_request_t *r, r->resp.content_length = NULL; r->resp.content_length_n = nxt_length(error); - nxt_http_request_header_send(task, r); - r->state = &nxt_http_request_send_error_body_state; - nxt_work_queue_add(&task->thread->engine->fast_work_queue, - nxt_http_request_send_error_body, task, r, NULL); + nxt_http_request_header_send(task, r, nxt_http_request_send_error_body); + return; fail: diff --git a/src/nxt_http_request.c b/src/nxt_http_request.c index 916004d2..8c38eeb2 100644 --- a/src/nxt_http_request.c +++ b/src/nxt_http_request.c @@ -369,7 +369,8 @@ nxt_http_request_read_body(nxt_task_t *task, nxt_http_request_t *r) void -nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r) +nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler) { u_char *p, *end; nxt_http_field_t *server, *date, *content_length; @@ -430,7 +431,7 @@ nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r) } if (nxt_fast_path(r->proto.any != NULL)) { - nxt_http_proto[r->protocol].header_send(task, r); + nxt_http_proto[r->protocol].header_send(task, r, body_handler); } return; diff --git a/src/nxt_router.c b/src/nxt_router.c index b87f588f..c3bcc51c 100644 --- a/src/nxt_router.c +++ b/src/nxt_router.c @@ -3531,7 +3531,7 @@ nxt_router_response_ready_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg, nxt_buf_chain_add(&r->out, b); } - nxt_http_request_header_send(task, r); + nxt_http_request_header_send(task, r, nxt_http_request_send_body); if (r->websocket_handshake && r->status == NXT_HTTP_SWITCHING_PROTOCOLS) @@ -3575,11 +3575,6 @@ nxt_router_response_ready_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg, } else { r->state = &nxt_http_request_send_state; } - - if (r->out) { - nxt_work_queue_add(&task->thread->engine->fast_work_queue, - nxt_http_request_send_body, task, r, NULL); - } } return; -- cgit From 7053a35a60bdacdfcf1277dd8f08f6c46c8028c7 Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Fri, 30 Aug 2019 00:07:54 +0300 Subject: Fixed WebSocket implementation that was broken on some systems. The "nxt_http_websocket" request state, defined in "nxt_http_websocket.c", is used in "nxt_router.c" and must be linked with external symbol declared in "nxt_router.c". Due to the missing "extern" keyword, building Unit with some linkers (notably gold and LLD) caused WebSocket connections to get stuck or even crash the router process. --- src/nxt_router.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/nxt_router.c b/src/nxt_router.c index c3bcc51c..89ea5fe6 100644 --- a/src/nxt_router.c +++ b/src/nxt_router.c @@ -247,7 +247,7 @@ static nxt_int_t nxt_router_http_request_done(nxt_task_t *task, static void nxt_router_http_request_release(nxt_task_t *task, void *obj, void *data); -const nxt_http_request_state_t nxt_http_websocket; +extern const nxt_http_request_state_t nxt_http_websocket; static nxt_router_t *nxt_router; -- cgit From daadc2e00bc47498513929549bf0364070a1557c Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Mon, 2 Sep 2019 18:27:08 +0300 Subject: Making request state handler calls more consistent. --- src/nxt_h1proto.c | 6 ++---- src/nxt_h1proto_websocket.c | 8 ++------ 2 files changed, 4 insertions(+), 10 deletions(-) (limited to 'src') diff --git a/src/nxt_h1proto.c b/src/nxt_h1proto.c index b8f62a88..39be4315 100644 --- a/src/nxt_h1proto.c +++ b/src/nxt_h1proto.c @@ -830,8 +830,7 @@ nxt_h1p_request_body_read(nxt_task_t *task, nxt_http_request_t *r) ready: - nxt_work_queue_add(&task->thread->engine->fast_work_queue, - r->state->ready_handler, task, r, NULL); + r->state->ready_handler(task, r, NULL); return; @@ -884,8 +883,7 @@ nxt_h1p_conn_request_body_read(nxt_task_t *task, void *obj, void *data) c->read = NULL; r = h1p->request; - nxt_work_queue_add(&engine->fast_work_queue, r->state->ready_handler, - task, r, NULL); + r->state->ready_handler(task, r, NULL); } } diff --git a/src/nxt_h1proto_websocket.c b/src/nxt_h1proto_websocket.c index dd9b6848..13754be0 100644 --- a/src/nxt_h1proto_websocket.c +++ b/src/nxt_h1proto_websocket.c @@ -417,16 +417,13 @@ nxt_h1p_conn_ws_frame_process(nxt_task_t *task, nxt_conn_t *c, uint8_t *p, *mask; uint16_t code; nxt_http_request_t *r; - nxt_event_engine_t *engine; - engine = task->thread->engine; r = h1p->request; c->read = NULL; if (nxt_slow_path(wsh->opcode == NXT_WEBSOCKET_OP_PING)) { - nxt_work_queue_add(&engine->fast_work_queue, nxt_h1p_conn_ws_pong, - task, r, NULL); + nxt_h1p_conn_ws_pong(task, r, NULL); return; } @@ -451,8 +448,7 @@ nxt_h1p_conn_ws_frame_process(nxt_task_t *task, nxt_conn_t *c, h1p->websocket_closed = 1; } - nxt_work_queue_add(&engine->fast_work_queue, r->state->ready_handler, - task, r, NULL); + r->state->ready_handler(task, r, NULL); } -- cgit From 2b8cab1e2478547398ad9c2fe68e025c180cac54 Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Thu, 5 Sep 2019 15:27:32 +0300 Subject: Java: introducing websocket support. --- src/java/javax/websocket/ClientEndpoint.java | 34 + src/java/javax/websocket/ClientEndpointConfig.java | 138 +++ src/java/javax/websocket/CloseReason.java | 122 ++ src/java/javax/websocket/ContainerProvider.java | 63 + src/java/javax/websocket/DecodeException.java | 56 + src/java/javax/websocket/Decoder.java | 53 + .../websocket/DefaultClientEndpointConfig.java | 80 ++ src/java/javax/websocket/DeploymentException.java | 30 + src/java/javax/websocket/EncodeException.java | 38 + src/java/javax/websocket/Encoder.java | 51 + src/java/javax/websocket/Endpoint.java | 49 + src/java/javax/websocket/EndpointConfig.java | 29 + src/java/javax/websocket/Extension.java | 29 + src/java/javax/websocket/HandshakeResponse.java | 30 + src/java/javax/websocket/MessageHandler.java | 42 + src/java/javax/websocket/OnClose.java | 27 + src/java/javax/websocket/OnError.java | 27 + src/java/javax/websocket/OnMessage.java | 28 + src/java/javax/websocket/OnOpen.java | 27 + src/java/javax/websocket/PongMessage.java | 32 + src/java/javax/websocket/RemoteEndpoint.java | 229 ++++ src/java/javax/websocket/SendHandler.java | 22 + src/java/javax/websocket/SendResult.java | 39 + src/java/javax/websocket/Session.java | 193 +++ src/java/javax/websocket/SessionException.java | 35 + src/java/javax/websocket/WebSocketContainer.java | 131 +++ .../server/DefaultServerEndpointConfig.java | 95 ++ .../javax/websocket/server/HandshakeRequest.java | 53 + src/java/javax/websocket/server/PathParam.java | 33 + .../websocket/server/ServerApplicationConfig.java | 51 + .../javax/websocket/server/ServerContainer.java | 30 + .../javax/websocket/server/ServerEndpoint.java | 46 + .../websocket/server/ServerEndpointConfig.java | 218 ++++ src/java/nginx/unit/Context.java | 139 ++- src/java/nginx/unit/Request.java | 90 +- .../unit/websocket/AsyncChannelGroupUtil.java | 151 +++ .../nginx/unit/websocket/AsyncChannelWrapper.java | 47 + .../websocket/AsyncChannelWrapperNonSecure.java | 112 ++ .../unit/websocket/AsyncChannelWrapperSecure.java | 578 +++++++++ .../unit/websocket/AuthenticationException.java | 35 + src/java/nginx/unit/websocket/Authenticator.java | 71 ++ .../nginx/unit/websocket/AuthenticatorFactory.java | 68 ++ .../nginx/unit/websocket/BackgroundProcess.java | 26 + .../unit/websocket/BackgroundProcessManager.java | 149 +++ .../nginx/unit/websocket/BasicAuthenticator.java | 66 ++ src/java/nginx/unit/websocket/Constants.java | 158 +++ src/java/nginx/unit/websocket/DecoderEntry.java | 39 + .../nginx/unit/websocket/DigestAuthenticator.java | 150 +++ .../nginx/unit/websocket/FutureToSendHandler.java | 112 ++ .../nginx/unit/websocket/LocalStrings.properties | 147 +++ .../nginx/unit/websocket/MessageHandlerResult.java | 42 + .../unit/websocket/MessageHandlerResultType.java | 23 + src/java/nginx/unit/websocket/MessagePart.java | 83 ++ .../nginx/unit/websocket/PerMessageDeflate.java | 476 ++++++++ .../websocket/ReadBufferOverflowException.java | 34 + src/java/nginx/unit/websocket/Transformation.java | 111 ++ .../unit/websocket/TransformationFactory.java | 51 + .../nginx/unit/websocket/TransformationResult.java | 37 + src/java/nginx/unit/websocket/Util.java | 666 +++++++++++ .../unit/websocket/WrappedMessageHandler.java | 25 + .../nginx/unit/websocket/WsContainerProvider.java | 28 + src/java/nginx/unit/websocket/WsExtension.java | 46 + .../nginx/unit/websocket/WsExtensionParameter.java | 40 + src/java/nginx/unit/websocket/WsFrameBase.java | 1010 ++++++++++++++++ src/java/nginx/unit/websocket/WsFrameClient.java | 228 ++++ .../nginx/unit/websocket/WsHandshakeResponse.java | 56 + src/java/nginx/unit/websocket/WsIOException.java | 41 + src/java/nginx/unit/websocket/WsPongMessage.java | 39 + .../unit/websocket/WsRemoteEndpointAsync.java | 79 ++ .../nginx/unit/websocket/WsRemoteEndpointBase.java | 64 + .../unit/websocket/WsRemoteEndpointBasic.java | 76 ++ .../unit/websocket/WsRemoteEndpointImplBase.java | 1234 ++++++++++++++++++++ .../unit/websocket/WsRemoteEndpointImplClient.java | 75 ++ src/java/nginx/unit/websocket/WsSession.java | 1070 +++++++++++++++++ .../nginx/unit/websocket/WsWebSocketContainer.java | 1123 ++++++++++++++++++ src/java/nginx/unit/websocket/pojo/Constants.java | 32 + .../unit/websocket/pojo/LocalStrings.properties | 40 + .../unit/websocket/pojo/PojoEndpointBase.java | 156 +++ .../unit/websocket/pojo/PojoEndpointClient.java | 47 + .../unit/websocket/pojo/PojoEndpointServer.java | 66 ++ .../websocket/pojo/PojoMessageHandlerBase.java | 122 ++ .../pojo/PojoMessageHandlerPartialBase.java | 77 ++ .../pojo/PojoMessageHandlerPartialBinary.java | 36 + .../pojo/PojoMessageHandlerPartialText.java | 35 + .../pojo/PojoMessageHandlerWholeBase.java | 94 ++ .../pojo/PojoMessageHandlerWholeBinary.java | 131 +++ .../pojo/PojoMessageHandlerWholePong.java | 48 + .../pojo/PojoMessageHandlerWholeText.java | 136 +++ .../unit/websocket/pojo/PojoMethodMapping.java | 731 ++++++++++++ .../nginx/unit/websocket/pojo/PojoPathParam.java | 47 + .../nginx/unit/websocket/pojo/package-info.java | 21 + .../nginx/unit/websocket/server/Constants.java | 38 + .../server/DefaultServerEndpointConfigurator.java | 88 ++ .../unit/websocket/server/LocalStrings.properties | 43 + .../nginx/unit/websocket/server/UpgradeUtil.java | 285 +++++ .../nginx/unit/websocket/server/UriTemplate.java | 177 +++ .../unit/websocket/server/WsContextListener.java | 51 + src/java/nginx/unit/websocket/server/WsFilter.java | 81 ++ .../unit/websocket/server/WsHandshakeRequest.java | 196 ++++ .../websocket/server/WsHttpUpgradeHandler.java | 172 +++ .../unit/websocket/server/WsMappingResult.java | 44 + .../server/WsPerSessionServerEndpointConfig.java | 84 ++ .../server/WsRemoteEndpointImplServer.java | 158 +++ src/java/nginx/unit/websocket/server/WsSci.java | 145 +++ .../unit/websocket/server/WsServerContainer.java | 470 ++++++++ .../unit/websocket/server/WsSessionListener.java | 36 + .../unit/websocket/server/WsWriteTimeout.java | 128 ++ .../nginx/unit/websocket/server/package-info.java | 21 + src/java/nxt_jni_Request.c | 144 +++ src/java/nxt_jni_Request.h | 5 + src/nxt_java.c | 72 +- 111 files changed, 15254 insertions(+), 58 deletions(-) create mode 100644 src/java/javax/websocket/ClientEndpoint.java create mode 100644 src/java/javax/websocket/ClientEndpointConfig.java create mode 100644 src/java/javax/websocket/CloseReason.java create mode 100644 src/java/javax/websocket/ContainerProvider.java create mode 100644 src/java/javax/websocket/DecodeException.java create mode 100644 src/java/javax/websocket/Decoder.java create mode 100644 src/java/javax/websocket/DefaultClientEndpointConfig.java create mode 100644 src/java/javax/websocket/DeploymentException.java create mode 100644 src/java/javax/websocket/EncodeException.java create mode 100644 src/java/javax/websocket/Encoder.java create mode 100644 src/java/javax/websocket/Endpoint.java create mode 100644 src/java/javax/websocket/EndpointConfig.java create mode 100644 src/java/javax/websocket/Extension.java create mode 100644 src/java/javax/websocket/HandshakeResponse.java create mode 100644 src/java/javax/websocket/MessageHandler.java create mode 100644 src/java/javax/websocket/OnClose.java create mode 100644 src/java/javax/websocket/OnError.java create mode 100644 src/java/javax/websocket/OnMessage.java create mode 100644 src/java/javax/websocket/OnOpen.java create mode 100644 src/java/javax/websocket/PongMessage.java create mode 100644 src/java/javax/websocket/RemoteEndpoint.java create mode 100644 src/java/javax/websocket/SendHandler.java create mode 100644 src/java/javax/websocket/SendResult.java create mode 100644 src/java/javax/websocket/Session.java create mode 100644 src/java/javax/websocket/SessionException.java create mode 100644 src/java/javax/websocket/WebSocketContainer.java create mode 100644 src/java/javax/websocket/server/DefaultServerEndpointConfig.java create mode 100644 src/java/javax/websocket/server/HandshakeRequest.java create mode 100644 src/java/javax/websocket/server/PathParam.java create mode 100644 src/java/javax/websocket/server/ServerApplicationConfig.java create mode 100644 src/java/javax/websocket/server/ServerContainer.java create mode 100644 src/java/javax/websocket/server/ServerEndpoint.java create mode 100644 src/java/javax/websocket/server/ServerEndpointConfig.java create mode 100644 src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java create mode 100644 src/java/nginx/unit/websocket/AsyncChannelWrapper.java create mode 100644 src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java create mode 100644 src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java create mode 100644 src/java/nginx/unit/websocket/AuthenticationException.java create mode 100644 src/java/nginx/unit/websocket/Authenticator.java create mode 100644 src/java/nginx/unit/websocket/AuthenticatorFactory.java create mode 100644 src/java/nginx/unit/websocket/BackgroundProcess.java create mode 100644 src/java/nginx/unit/websocket/BackgroundProcessManager.java create mode 100644 src/java/nginx/unit/websocket/BasicAuthenticator.java create mode 100644 src/java/nginx/unit/websocket/Constants.java create mode 100644 src/java/nginx/unit/websocket/DecoderEntry.java create mode 100644 src/java/nginx/unit/websocket/DigestAuthenticator.java create mode 100644 src/java/nginx/unit/websocket/FutureToSendHandler.java create mode 100644 src/java/nginx/unit/websocket/LocalStrings.properties create mode 100644 src/java/nginx/unit/websocket/MessageHandlerResult.java create mode 100644 src/java/nginx/unit/websocket/MessageHandlerResultType.java create mode 100644 src/java/nginx/unit/websocket/MessagePart.java create mode 100644 src/java/nginx/unit/websocket/PerMessageDeflate.java create mode 100644 src/java/nginx/unit/websocket/ReadBufferOverflowException.java create mode 100644 src/java/nginx/unit/websocket/Transformation.java create mode 100644 src/java/nginx/unit/websocket/TransformationFactory.java create mode 100644 src/java/nginx/unit/websocket/TransformationResult.java create mode 100644 src/java/nginx/unit/websocket/Util.java create mode 100644 src/java/nginx/unit/websocket/WrappedMessageHandler.java create mode 100644 src/java/nginx/unit/websocket/WsContainerProvider.java create mode 100644 src/java/nginx/unit/websocket/WsExtension.java create mode 100644 src/java/nginx/unit/websocket/WsExtensionParameter.java create mode 100644 src/java/nginx/unit/websocket/WsFrameBase.java create mode 100644 src/java/nginx/unit/websocket/WsFrameClient.java create mode 100644 src/java/nginx/unit/websocket/WsHandshakeResponse.java create mode 100644 src/java/nginx/unit/websocket/WsIOException.java create mode 100644 src/java/nginx/unit/websocket/WsPongMessage.java create mode 100644 src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java create mode 100644 src/java/nginx/unit/websocket/WsRemoteEndpointBase.java create mode 100644 src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java create mode 100644 src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java create mode 100644 src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java create mode 100644 src/java/nginx/unit/websocket/WsSession.java create mode 100644 src/java/nginx/unit/websocket/WsWebSocketContainer.java create mode 100644 src/java/nginx/unit/websocket/pojo/Constants.java create mode 100644 src/java/nginx/unit/websocket/pojo/LocalStrings.properties create mode 100644 src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java create mode 100644 src/java/nginx/unit/websocket/pojo/PojoPathParam.java create mode 100644 src/java/nginx/unit/websocket/pojo/package-info.java create mode 100644 src/java/nginx/unit/websocket/server/Constants.java create mode 100644 src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java create mode 100644 src/java/nginx/unit/websocket/server/LocalStrings.properties create mode 100644 src/java/nginx/unit/websocket/server/UpgradeUtil.java create mode 100644 src/java/nginx/unit/websocket/server/UriTemplate.java create mode 100644 src/java/nginx/unit/websocket/server/WsContextListener.java create mode 100644 src/java/nginx/unit/websocket/server/WsFilter.java create mode 100644 src/java/nginx/unit/websocket/server/WsHandshakeRequest.java create mode 100644 src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java create mode 100644 src/java/nginx/unit/websocket/server/WsMappingResult.java create mode 100644 src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java create mode 100644 src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java create mode 100644 src/java/nginx/unit/websocket/server/WsSci.java create mode 100644 src/java/nginx/unit/websocket/server/WsServerContainer.java create mode 100644 src/java/nginx/unit/websocket/server/WsSessionListener.java create mode 100644 src/java/nginx/unit/websocket/server/WsWriteTimeout.java create mode 100644 src/java/nginx/unit/websocket/server/package-info.java (limited to 'src') diff --git a/src/java/javax/websocket/ClientEndpoint.java b/src/java/javax/websocket/ClientEndpoint.java new file mode 100644 index 00000000..ee984171 --- /dev/null +++ b/src/java/javax/websocket/ClientEndpoint.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import javax.websocket.ClientEndpointConfig.Configurator; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ClientEndpoint { + String[] subprotocols() default {}; + Class[] decoders() default {}; + Class[] encoders() default {}; + public Class configurator() + default Configurator.class; +} diff --git a/src/java/javax/websocket/ClientEndpointConfig.java b/src/java/javax/websocket/ClientEndpointConfig.java new file mode 100644 index 00000000..13b6cba5 --- /dev/null +++ b/src/java/javax/websocket/ClientEndpointConfig.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public interface ClientEndpointConfig extends EndpointConfig { + + List getPreferredSubprotocols(); + + List getExtensions(); + + public Configurator getConfigurator(); + + public final class Builder { + + private static final Configurator DEFAULT_CONFIGURATOR = + new Configurator() {}; + + + public static Builder create() { + return new Builder(); + } + + + private Builder() { + // Hide default constructor + } + + private Configurator configurator = DEFAULT_CONFIGURATOR; + private List preferredSubprotocols = Collections.emptyList(); + private List extensions = Collections.emptyList(); + private List> encoders = + Collections.emptyList(); + private List> decoders = + Collections.emptyList(); + + + public ClientEndpointConfig build() { + return new DefaultClientEndpointConfig(preferredSubprotocols, + extensions, encoders, decoders, configurator); + } + + + public Builder configurator(Configurator configurator) { + if (configurator == null) { + this.configurator = DEFAULT_CONFIGURATOR; + } else { + this.configurator = configurator; + } + return this; + } + + + public Builder preferredSubprotocols( + List preferredSubprotocols) { + if (preferredSubprotocols == null || + preferredSubprotocols.size() == 0) { + this.preferredSubprotocols = Collections.emptyList(); + } else { + this.preferredSubprotocols = + Collections.unmodifiableList(preferredSubprotocols); + } + return this; + } + + + public Builder extensions( + List extensions) { + if (extensions == null || extensions.size() == 0) { + this.extensions = Collections.emptyList(); + } else { + this.extensions = Collections.unmodifiableList(extensions); + } + return this; + } + + + public Builder encoders(List> encoders) { + if (encoders == null || encoders.size() == 0) { + this.encoders = Collections.emptyList(); + } else { + this.encoders = Collections.unmodifiableList(encoders); + } + return this; + } + + + public Builder decoders(List> decoders) { + if (decoders == null || decoders.size() == 0) { + this.decoders = Collections.emptyList(); + } else { + this.decoders = Collections.unmodifiableList(decoders); + } + return this; + } + } + + + public class Configurator { + + /** + * Provides the client with a mechanism to inspect and/or modify the headers + * that are sent to the server to start the WebSocket handshake. + * + * @param headers The HTTP headers + */ + public void beforeRequest(Map> headers) { + // NO-OP + } + + /** + * Provides the client with a mechanism to inspect the handshake response + * that is returned from the server. + * + * @param handshakeResponse The response + */ + public void afterResponse(HandshakeResponse handshakeResponse) { + // NO-OP + } + } +} diff --git a/src/java/javax/websocket/CloseReason.java b/src/java/javax/websocket/CloseReason.java new file mode 100644 index 00000000..ef88d135 --- /dev/null +++ b/src/java/javax/websocket/CloseReason.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class CloseReason { + + private final CloseCode closeCode; + private final String reasonPhrase; + + public CloseReason(CloseReason.CloseCode closeCode, String reasonPhrase) { + this.closeCode = closeCode; + this.reasonPhrase = reasonPhrase; + } + + public CloseCode getCloseCode() { + return closeCode; + } + + public String getReasonPhrase() { + return reasonPhrase; + } + + @Override + public String toString() { + return "CloseReason: code [" + closeCode.getCode() + + "], reason [" + reasonPhrase + "]"; + } + + public interface CloseCode { + int getCode(); + } + + public enum CloseCodes implements CloseReason.CloseCode { + + NORMAL_CLOSURE(1000), + GOING_AWAY(1001), + PROTOCOL_ERROR(1002), + CANNOT_ACCEPT(1003), + RESERVED(1004), + NO_STATUS_CODE(1005), + CLOSED_ABNORMALLY(1006), + NOT_CONSISTENT(1007), + VIOLATED_POLICY(1008), + TOO_BIG(1009), + NO_EXTENSION(1010), + UNEXPECTED_CONDITION(1011), + SERVICE_RESTART(1012), + TRY_AGAIN_LATER(1013), + TLS_HANDSHAKE_FAILURE(1015); + + private int code; + + CloseCodes(int code) { + this.code = code; + } + + public static CloseCode getCloseCode(final int code) { + if (code > 2999 && code < 5000) { + return new CloseCode() { + @Override + public int getCode() { + return code; + } + }; + } + switch (code) { + case 1000: + return CloseCodes.NORMAL_CLOSURE; + case 1001: + return CloseCodes.GOING_AWAY; + case 1002: + return CloseCodes.PROTOCOL_ERROR; + case 1003: + return CloseCodes.CANNOT_ACCEPT; + case 1004: + return CloseCodes.RESERVED; + case 1005: + return CloseCodes.NO_STATUS_CODE; + case 1006: + return CloseCodes.CLOSED_ABNORMALLY; + case 1007: + return CloseCodes.NOT_CONSISTENT; + case 1008: + return CloseCodes.VIOLATED_POLICY; + case 1009: + return CloseCodes.TOO_BIG; + case 1010: + return CloseCodes.NO_EXTENSION; + case 1011: + return CloseCodes.UNEXPECTED_CONDITION; + case 1012: + return CloseCodes.SERVICE_RESTART; + case 1013: + return CloseCodes.TRY_AGAIN_LATER; + case 1015: + return CloseCodes.TLS_HANDSHAKE_FAILURE; + default: + throw new IllegalArgumentException( + "Invalid close code: [" + code + "]"); + } + } + + @Override + public int getCode() { + return code; + } + } +} diff --git a/src/java/javax/websocket/ContainerProvider.java b/src/java/javax/websocket/ContainerProvider.java new file mode 100644 index 00000000..1727ca93 --- /dev/null +++ b/src/java/javax/websocket/ContainerProvider.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.Iterator; +import java.util.ServiceLoader; + +/** + * Use the {@link ServiceLoader} mechanism to provide instances of the WebSocket + * client container. + */ +public abstract class ContainerProvider { + + private static final String DEFAULT_PROVIDER_CLASS_NAME = + "nginx.unit.websocket.WsWebSocketContainer"; + + /** + * Create a new container used to create outgoing WebSocket connections. + * + * @return A newly created container. + */ + public static WebSocketContainer getWebSocketContainer() { + WebSocketContainer result = null; + + ServiceLoader serviceLoader = + ServiceLoader.load(ContainerProvider.class); + Iterator iter = serviceLoader.iterator(); + while (result == null && iter.hasNext()) { + result = iter.next().getContainer(); + } + + // Fall-back. Also used by unit tests + if (result == null) { + try { + @SuppressWarnings("unchecked") + Class clazz = + (Class) Class.forName( + DEFAULT_PROVIDER_CLASS_NAME); + result = clazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException | IllegalArgumentException | + SecurityException e) { + // No options left. Just return null. + } + } + return result; + } + + protected abstract WebSocketContainer getContainer(); +} diff --git a/src/java/javax/websocket/DecodeException.java b/src/java/javax/websocket/DecodeException.java new file mode 100644 index 00000000..771cfa58 --- /dev/null +++ b/src/java/javax/websocket/DecodeException.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.nio.ByteBuffer; + +public class DecodeException extends Exception { + + private static final long serialVersionUID = 1L; + + private ByteBuffer bb; + private String encodedString; + + public DecodeException(ByteBuffer bb, String message, Throwable cause) { + super(message, cause); + this.bb = bb; + } + + public DecodeException(String encodedString, String message, + Throwable cause) { + super(message, cause); + this.encodedString = encodedString; + } + + public DecodeException(ByteBuffer bb, String message) { + super(message); + this.bb = bb; + } + + public DecodeException(String encodedString, String message) { + super(message); + this.encodedString = encodedString; + } + + public ByteBuffer getBytes() { + return bb; + } + + public String getText() { + return encodedString; + } +} diff --git a/src/java/javax/websocket/Decoder.java b/src/java/javax/websocket/Decoder.java new file mode 100644 index 00000000..fad262e3 --- /dev/null +++ b/src/java/javax/websocket/Decoder.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.nio.ByteBuffer; + +public interface Decoder { + + abstract void init(EndpointConfig endpointConfig); + + abstract void destroy(); + + interface Binary extends Decoder { + + T decode(ByteBuffer bytes) throws DecodeException; + + boolean willDecode(ByteBuffer bytes); + } + + interface BinaryStream extends Decoder { + + T decode(InputStream is) throws DecodeException, IOException; + } + + interface Text extends Decoder { + + T decode(String s) throws DecodeException; + + boolean willDecode(String s); + } + + interface TextStream extends Decoder { + + T decode(Reader reader) throws DecodeException, IOException; + } +} diff --git a/src/java/javax/websocket/DefaultClientEndpointConfig.java b/src/java/javax/websocket/DefaultClientEndpointConfig.java new file mode 100644 index 00000000..ce28cb26 --- /dev/null +++ b/src/java/javax/websocket/DefaultClientEndpointConfig.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +final class DefaultClientEndpointConfig implements ClientEndpointConfig { + + private final List preferredSubprotocols; + private final List extensions; + private final List> encoders; + private final List> decoders; + private final Map userProperties = new ConcurrentHashMap<>(); + private final Configurator configurator; + + + DefaultClientEndpointConfig(List preferredSubprotocols, + List extensions, + List> encoders, + List> decoders, + Configurator configurator) { + this.preferredSubprotocols = preferredSubprotocols; + this.extensions = extensions; + this.decoders = decoders; + this.encoders = encoders; + this.configurator = configurator; + } + + + @Override + public List getPreferredSubprotocols() { + return preferredSubprotocols; + } + + + @Override + public List getExtensions() { + return extensions; + } + + + @Override + public List> getEncoders() { + return encoders; + } + + + @Override + public List> getDecoders() { + return decoders; + } + + + @Override + public final Map getUserProperties() { + return userProperties; + } + + + @Override + public Configurator getConfigurator() { + return configurator; + } +} diff --git a/src/java/javax/websocket/DeploymentException.java b/src/java/javax/websocket/DeploymentException.java new file mode 100644 index 00000000..1678fd09 --- /dev/null +++ b/src/java/javax/websocket/DeploymentException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class DeploymentException extends Exception { + + private static final long serialVersionUID = 1L; + + public DeploymentException(String message) { + super(message); + } + + public DeploymentException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/java/javax/websocket/EncodeException.java b/src/java/javax/websocket/EncodeException.java new file mode 100644 index 00000000..fdb536ac --- /dev/null +++ b/src/java/javax/websocket/EncodeException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class EncodeException extends Exception { + + private static final long serialVersionUID = 1L; + + private Object object; + + public EncodeException(Object object, String message) { + super(message); + this.object = object; + } + + public EncodeException(Object object, String message, Throwable cause) { + super(message, cause); + this.object = object; + } + + public Object getObject() { + return this.object; + } +} diff --git a/src/java/javax/websocket/Encoder.java b/src/java/javax/websocket/Encoder.java new file mode 100644 index 00000000..42a107f0 --- /dev/null +++ b/src/java/javax/websocket/Encoder.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; + +public interface Encoder { + + abstract void init(EndpointConfig endpointConfig); + + abstract void destroy(); + + interface Text extends Encoder { + + String encode(T object) throws EncodeException; + } + + interface TextStream extends Encoder { + + void encode(T object, Writer writer) + throws EncodeException, IOException; + } + + interface Binary extends Encoder { + + ByteBuffer encode(T object) throws EncodeException; + } + + interface BinaryStream extends Encoder { + + void encode(T object, OutputStream os) + throws EncodeException, IOException; + } +} diff --git a/src/java/javax/websocket/Endpoint.java b/src/java/javax/websocket/Endpoint.java new file mode 100644 index 00000000..9dfdbcce --- /dev/null +++ b/src/java/javax/websocket/Endpoint.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public abstract class Endpoint { + + /** + * Event that is triggered when a new session starts. + * + * @param session The new session. + * @param config The configuration with which the Endpoint was + * configured. + */ + public abstract void onOpen(Session session, EndpointConfig config); + + /** + * Event that is triggered when a session has closed. + * + * @param session The session + * @param closeReason Why the session was closed + */ + public void onClose(Session session, CloseReason closeReason) { + // NO-OP by default + } + + /** + * Event that is triggered when a protocol error occurs. + * + * @param session The session. + * @param throwable The exception. + */ + public void onError(Session session, Throwable throwable) { + // NO-OP by default + } +} diff --git a/src/java/javax/websocket/EndpointConfig.java b/src/java/javax/websocket/EndpointConfig.java new file mode 100644 index 00000000..0b6c9681 --- /dev/null +++ b/src/java/javax/websocket/EndpointConfig.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; +import java.util.Map; + +public interface EndpointConfig { + + List> getEncoders(); + + List> getDecoders(); + + Map getUserProperties(); +} diff --git a/src/java/javax/websocket/Extension.java b/src/java/javax/websocket/Extension.java new file mode 100644 index 00000000..b95b27b8 --- /dev/null +++ b/src/java/javax/websocket/Extension.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; + +public interface Extension { + String getName(); + List getParameters(); + + interface Parameter { + String getName(); + String getValue(); + } +} diff --git a/src/java/javax/websocket/HandshakeResponse.java b/src/java/javax/websocket/HandshakeResponse.java new file mode 100644 index 00000000..807192e8 --- /dev/null +++ b/src/java/javax/websocket/HandshakeResponse.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; +import java.util.Map; + +public interface HandshakeResponse { + + /** + * Name of the WebSocket accept HTTP header. + */ + public static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + + Map> getHeaders(); +} diff --git a/src/java/javax/websocket/MessageHandler.java b/src/java/javax/websocket/MessageHandler.java new file mode 100644 index 00000000..2c30d997 --- /dev/null +++ b/src/java/javax/websocket/MessageHandler.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public interface MessageHandler { + + interface Partial extends MessageHandler { + + /** + * Called when part of a message is available to be processed. + * + * @param messagePart The message part + * @param last true if this is the last part of + * this message, else false + */ + void onMessage(T messagePart, boolean last); + } + + interface Whole extends MessageHandler { + + /** + * Called when a whole message is available to be processed. + * + * @param message The message + */ + void onMessage(T message); + } +} diff --git a/src/java/javax/websocket/OnClose.java b/src/java/javax/websocket/OnClose.java new file mode 100644 index 00000000..6ee61d36 --- /dev/null +++ b/src/java/javax/websocket/OnClose.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnClose { +} diff --git a/src/java/javax/websocket/OnError.java b/src/java/javax/websocket/OnError.java new file mode 100644 index 00000000..ce431484 --- /dev/null +++ b/src/java/javax/websocket/OnError.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnError { +} diff --git a/src/java/javax/websocket/OnMessage.java b/src/java/javax/websocket/OnMessage.java new file mode 100644 index 00000000..564fa994 --- /dev/null +++ b/src/java/javax/websocket/OnMessage.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnMessage { + long maxMessageSize() default -1; +} diff --git a/src/java/javax/websocket/OnOpen.java b/src/java/javax/websocket/OnOpen.java new file mode 100644 index 00000000..9f0ea6e3 --- /dev/null +++ b/src/java/javax/websocket/OnOpen.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnOpen { +} diff --git a/src/java/javax/websocket/PongMessage.java b/src/java/javax/websocket/PongMessage.java new file mode 100644 index 00000000..7e9e3b6a --- /dev/null +++ b/src/java/javax/websocket/PongMessage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.nio.ByteBuffer; + +/** + * Represents a WebSocket Pong message and used by message handlers to enable + * applications to process the response to any Pings they send. + */ +public interface PongMessage { + /** + * Get the payload of the Pong message. + * + * @return The payload of the Pong message. + */ + ByteBuffer getApplicationData(); +} diff --git a/src/java/javax/websocket/RemoteEndpoint.java b/src/java/javax/websocket/RemoteEndpoint.java new file mode 100644 index 00000000..19c7a100 --- /dev/null +++ b/src/java/javax/websocket/RemoteEndpoint.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; +import java.util.concurrent.Future; + + +public interface RemoteEndpoint { + + interface Async extends RemoteEndpoint { + + /** + * Obtain the timeout (in milliseconds) for sending a message + * asynchronously. The default value is determined by + * {@link WebSocketContainer#getDefaultAsyncSendTimeout()}. + * @return The current send timeout in milliseconds. A non-positive + * value means an infinite timeout. + */ + long getSendTimeout(); + + /** + * Set the timeout (in milliseconds) for sending a message + * asynchronously. The default value is determined by + * {@link WebSocketContainer#getDefaultAsyncSendTimeout()}. + * @param timeout The new timeout for sending messages asynchronously + * in milliseconds. A non-positive value means an + * infinite timeout. + */ + void setSendTimeout(long timeout); + + /** + * Send the message asynchronously, using the SendHandler to signal to the + * client when the message has been sent. + * @param text The text message to send + * @param completion Used to signal to the client when the message has + * been sent + */ + void sendText(String text, SendHandler completion); + + /** + * Send the message asynchronously, using the Future to signal to the + * client when the message has been sent. + * @param text The text message to send + * @return A Future that signals when the message has been sent. + */ + Future sendText(String text); + + /** + * Send the message asynchronously, using the Future to signal to the client + * when the message has been sent. + * @param data The text message to send + * @return A Future that signals when the message has been sent. + * @throws IllegalArgumentException if {@code data} is {@code null}. + */ + Future sendBinary(ByteBuffer data); + + /** + * Send the message asynchronously, using the SendHandler to signal to the + * client when the message has been sent. + * @param data The text message to send + * @param completion Used to signal to the client when the message has + * been sent + * @throws IllegalArgumentException if {@code data} or {@code completion} + * is {@code null}. + */ + void sendBinary(ByteBuffer data, SendHandler completion); + + /** + * Encodes object as a message and sends it asynchronously, using the + * Future to signal to the client when the message has been sent. + * @param obj The object to be sent. + * @return A Future that signals when the message has been sent. + * @throws IllegalArgumentException if {@code obj} is {@code null}. + */ + Future sendObject(Object obj); + + /** + * Encodes object as a message and sends it asynchronously, using the + * SendHandler to signal to the client when the message has been sent. + * @param obj The object to be sent. + * @param completion Used to signal to the client when the message has + * been sent + * @throws IllegalArgumentException if {@code obj} or + * {@code completion} is {@code null}. + */ + void sendObject(Object obj, SendHandler completion); + + } + + interface Basic extends RemoteEndpoint { + + /** + * Send the message, blocking until the message is sent. + * @param text The text message to send. + * @throws IllegalArgumentException if {@code text} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendText(String text) throws IOException; + + /** + * Send the message, blocking until the message is sent. + * @param data The binary message to send + * @throws IllegalArgumentException if {@code data} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendBinary(ByteBuffer data) throws IOException; + + /** + * Sends part of a text message to the remote endpoint. Once the first part + * of a message has been sent, no other text or binary messages may be sent + * until all remaining parts of this message have been sent. + * + * @param fragment The partial message to send + * @param isLast true if this is the last part of the + * message, otherwise false + * @throws IllegalArgumentException if {@code fragment} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendText(String fragment, boolean isLast) throws IOException; + + /** + * Sends part of a binary message to the remote endpoint. Once the first + * part of a message has been sent, no other text or binary messages may be + * sent until all remaining parts of this message have been sent. + * + * @param partialByte The partial message to send + * @param isLast true if this is the last part of the + * message, otherwise false + * @throws IllegalArgumentException if {@code partialByte} is + * {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendBinary(ByteBuffer partialByte, boolean isLast) throws IOException; + + OutputStream getSendStream() throws IOException; + + Writer getSendWriter() throws IOException; + + /** + * Encodes object as a message and sends it to the remote endpoint. + * @param data The object to be sent. + * @throws EncodeException if there was a problem encoding the + * {@code data} object as a websocket message. + * @throws IllegalArgumentException if {@code data} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendObject(Object data) throws IOException, EncodeException; + + } + /** + * Enable or disable the batching of outgoing messages for this endpoint. If + * batching is disabled when it was previously enabled then this method will + * block until any currently batched messages have been written. + * + * @param batchingAllowed New setting + * @throws IOException If changing the value resulted in a call to + * {@link #flushBatch()} and that call threw an + * {@link IOException}. + */ + void setBatchingAllowed(boolean batchingAllowed) throws IOException; + + /** + * Obtains the current batching status of the endpoint. + * + * @return true if batching is enabled, otherwise + * false. + */ + boolean getBatchingAllowed(); + + /** + * Flush any currently batched messages to the remote endpoint. This method + * will block until the flush completes. + * + * @throws IOException If an I/O error occurs while flushing + */ + void flushBatch() throws IOException; + + /** + * Send a ping message blocking until the message has been sent. Note that + * if a message is in the process of being sent asynchronously, this method + * will block until that message and this ping has been sent. + * + * @param applicationData The payload for the ping message + * + * @throws IOException If an I/O error occurs while sending the ping + * @throws IllegalArgumentException if the applicationData is too large for + * a control message (max 125 bytes) + */ + void sendPing(ByteBuffer applicationData) + throws IOException, IllegalArgumentException; + + /** + * Send a pong message blocking until the message has been sent. Note that + * if a message is in the process of being sent asynchronously, this method + * will block until that message and this pong has been sent. + * + * @param applicationData The payload for the pong message + * + * @throws IOException If an I/O error occurs while sending the pong + * @throws IllegalArgumentException if the applicationData is too large for + * a control message (max 125 bytes) + */ + void sendPong(ByteBuffer applicationData) + throws IOException, IllegalArgumentException; +} + diff --git a/src/java/javax/websocket/SendHandler.java b/src/java/javax/websocket/SendHandler.java new file mode 100644 index 00000000..65b9a19a --- /dev/null +++ b/src/java/javax/websocket/SendHandler.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public interface SendHandler { + + void onResult(SendResult result); +} diff --git a/src/java/javax/websocket/SendResult.java b/src/java/javax/websocket/SendResult.java new file mode 100644 index 00000000..a3797d5b --- /dev/null +++ b/src/java/javax/websocket/SendResult.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public final class SendResult { + private final Throwable exception; + private final boolean ok; + + public SendResult(Throwable exception) { + this.exception = exception; + this.ok = (exception == null); + } + + public SendResult() { + this (null); + } + + public Throwable getException() { + return exception; + } + + public boolean isOK() { + return ok; + } +} diff --git a/src/java/javax/websocket/Session.java b/src/java/javax/websocket/Session.java new file mode 100644 index 00000000..eea15e5b --- /dev/null +++ b/src/java/javax/websocket/Session.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.Closeable; +import java.io.IOException; +import java.net.URI; +import java.security.Principal; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public interface Session extends Closeable { + + /** + * Get the container that created this session. + * @return the container that created this session. + */ + WebSocketContainer getContainer(); + + /** + * Registers a {@link MessageHandler} for incoming messages. Only one + * {@link MessageHandler} may be registered for each message type (text, + * binary, pong). The message type will be derived at runtime from the + * provided {@link MessageHandler} instance. It is not always possible to do + * this so it is better to use + * {@link #addMessageHandler(Class, javax.websocket.MessageHandler.Partial)} + * or + * {@link #addMessageHandler(Class, javax.websocket.MessageHandler.Whole)}. + * + * @param handler The message handler for a incoming message + * + * @throws IllegalStateException If a message handler has already been + * registered for the associated message type + */ + void addMessageHandler(MessageHandler handler) throws IllegalStateException; + + Set getMessageHandlers(); + + void removeMessageHandler(MessageHandler listener); + + String getProtocolVersion(); + + String getNegotiatedSubprotocol(); + + List getNegotiatedExtensions(); + + boolean isSecure(); + + boolean isOpen(); + + /** + * Get the idle timeout for this session. + * @return The current idle timeout for this session in milliseconds. Zero + * or negative values indicate an infinite timeout. + */ + long getMaxIdleTimeout(); + + /** + * Set the idle timeout for this session. + * @param timeout The new idle timeout for this session in milliseconds. + * Zero or negative values indicate an infinite timeout. + */ + void setMaxIdleTimeout(long timeout); + + /** + * Set the current maximum buffer size for binary messages. + * @param max The new maximum buffer size in bytes + */ + void setMaxBinaryMessageBufferSize(int max); + + /** + * Get the current maximum buffer size for binary messages. + * @return The current maximum buffer size in bytes + */ + int getMaxBinaryMessageBufferSize(); + + /** + * Set the maximum buffer size for text messages. + * @param max The new maximum buffer size in characters. + */ + void setMaxTextMessageBufferSize(int max); + + /** + * Get the maximum buffer size for text messages. + * @return The maximum buffer size in characters. + */ + int getMaxTextMessageBufferSize(); + + RemoteEndpoint.Async getAsyncRemote(); + + RemoteEndpoint.Basic getBasicRemote(); + + /** + * Provides a unique identifier for the session. This identifier should not + * be relied upon to be generated from a secure random source. + * @return A unique identifier for the session. + */ + String getId(); + + /** + * Close the connection to the remote end point using the code + * {@link javax.websocket.CloseReason.CloseCodes#NORMAL_CLOSURE} and an + * empty reason phrase. + * + * @throws IOException if an I/O error occurs while the WebSocket session is + * being closed. + */ + @Override + void close() throws IOException; + + + /** + * Close the connection to the remote end point using the specified code + * and reason phrase. + * @param closeReason The reason the WebSocket session is being closed. + * + * @throws IOException if an I/O error occurs while the WebSocket session is + * being closed. + */ + void close(CloseReason closeReason) throws IOException; + + URI getRequestURI(); + + Map> getRequestParameterMap(); + + String getQueryString(); + + Map getPathParameters(); + + Map getUserProperties(); + + Principal getUserPrincipal(); + + /** + * Obtain the set of open sessions associated with the same local endpoint + * as this session. + * + * @return The set of currently open sessions for the local endpoint that + * this session is associated with. + */ + Set getOpenSessions(); + + /** + * Registers a {@link MessageHandler} for partial incoming messages. Only + * one {@link MessageHandler} may be registered for each message type (text + * or binary, pong messages are never presented as partial messages). + * + * @param The type of message that the given handler is intended + * for + * @param clazz The Class that implements T + * @param handler The message handler for a incoming message + * + * @throws IllegalStateException If a message handler has already been + * registered for the associated message type + * + * @since WebSocket 1.1 + */ + void addMessageHandler(Class clazz, MessageHandler.Partial handler) + throws IllegalStateException; + + /** + * Registers a {@link MessageHandler} for whole incoming messages. Only + * one {@link MessageHandler} may be registered for each message type (text, + * binary, pong). + * + * @param The type of message that the given handler is intended + * for + * @param clazz The Class that implements T + * @param handler The message handler for a incoming message + * + * @throws IllegalStateException If a message handler has already been + * registered for the associated message type + * + * @since WebSocket 1.1 + */ + void addMessageHandler(Class clazz, MessageHandler.Whole handler) + throws IllegalStateException; +} diff --git a/src/java/javax/websocket/SessionException.java b/src/java/javax/websocket/SessionException.java new file mode 100644 index 00000000..428b82ec --- /dev/null +++ b/src/java/javax/websocket/SessionException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class SessionException extends Exception { + + private static final long serialVersionUID = 1L; + + private final Session session; + + + public SessionException(String message, Throwable cause, Session session) { + super(message, cause); + this.session = session; + } + + + public Session getSession() { + return session; + } +} diff --git a/src/java/javax/websocket/WebSocketContainer.java b/src/java/javax/websocket/WebSocketContainer.java new file mode 100644 index 00000000..f2da3e43 --- /dev/null +++ b/src/java/javax/websocket/WebSocketContainer.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.net.URI; +import java.util.Set; + +public interface WebSocketContainer { + + /** + * Get the default timeout for sending a message asynchronously. + * @return The current default timeout in milliseconds. A non-positive value + * means an infinite timeout. + */ + long getDefaultAsyncSendTimeout(); + + /** + * Set the default timeout for sending a message asynchronously. + * @param timeout The new default timeout in milliseconds. A non-positive + * value means an infinite timeout. + */ + void setAsyncSendTimeout(long timeout); + + Session connectToServer(Object endpoint, URI path) + throws DeploymentException, IOException; + + Session connectToServer(Class annotatedEndpointClass, URI path) + throws DeploymentException, IOException; + + /** + * Creates a new connection to the WebSocket. + * + * @param endpoint + * The endpoint instance that will handle responses from the + * server + * @param clientEndpointConfiguration + * Used to configure the new connection + * @param path + * The full URL of the WebSocket endpoint to connect to + * + * @return The WebSocket session for the connection + * + * @throws DeploymentException If the connection cannot be established + * @throws IOException If an I/O occurred while trying to establish the + * connection + */ + Session connectToServer(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException, IOException; + + /** + * Creates a new connection to the WebSocket. + * + * @param endpoint + * An instance of this class will be created to handle responses + * from the server + * @param clientEndpointConfiguration + * Used to configure the new connection + * @param path + * The full URL of the WebSocket endpoint to connect to + * + * @return The WebSocket session for the connection + * + * @throws DeploymentException If the connection cannot be established + * @throws IOException If an I/O occurred while trying to establish the + * connection + */ + Session connectToServer(Class endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException, IOException; + + /** + * Get the current default session idle timeout. + * @return The current default session idle timeout in milliseconds. Zero or + * negative values indicate an infinite timeout. + */ + long getDefaultMaxSessionIdleTimeout(); + + /** + * Set the default session idle timeout. + * @param timeout The new default session idle timeout in milliseconds. Zero + * or negative values indicate an infinite timeout. + */ + void setDefaultMaxSessionIdleTimeout(long timeout); + + /** + * Get the default maximum buffer size for binary messages. + * @return The current default maximum buffer size in bytes + */ + int getDefaultMaxBinaryMessageBufferSize(); + + /** + * Set the default maximum buffer size for binary messages. + * @param max The new default maximum buffer size in bytes + */ + void setDefaultMaxBinaryMessageBufferSize(int max); + + /** + * Get the default maximum buffer size for text messages. + * @return The current default maximum buffer size in characters + */ + int getDefaultMaxTextMessageBufferSize(); + + /** + * Set the default maximum buffer size for text messages. + * @param max The new default maximum buffer size in characters + */ + void setDefaultMaxTextMessageBufferSize(int max); + + /** + * Get the installed extensions. + * @return The set of extensions that are supported by this WebSocket + * implementation. + */ + Set getInstalledExtensions(); +} diff --git a/src/java/javax/websocket/server/DefaultServerEndpointConfig.java b/src/java/javax/websocket/server/DefaultServerEndpointConfig.java new file mode 100644 index 00000000..7c3b8d7d --- /dev/null +++ b/src/java/javax/websocket/server/DefaultServerEndpointConfig.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.Extension; + +/** + * Provides the default configuration for WebSocket server endpoints. + */ +final class DefaultServerEndpointConfig implements ServerEndpointConfig { + + private final Class endpointClass; + private final String path; + private final List subprotocols; + private final List extensions; + private final List> encoders; + private final List> decoders; + private final Configurator serverEndpointConfigurator; + private final Map userProperties = new ConcurrentHashMap<>(); + + DefaultServerEndpointConfig( + Class endpointClass, String path, + List subprotocols, List extensions, + List> encoders, + List> decoders, + Configurator serverEndpointConfigurator) { + this.endpointClass = endpointClass; + this.path = path; + this.subprotocols = subprotocols; + this.extensions = extensions; + this.encoders = encoders; + this.decoders = decoders; + this.serverEndpointConfigurator = serverEndpointConfigurator; + } + + @Override + public Class getEndpointClass() { + return endpointClass; + } + + @Override + public List> getEncoders() { + return this.encoders; + } + + @Override + public List> getDecoders() { + return this.decoders; + } + + @Override + public String getPath() { + return path; + } + + @Override + public Configurator getConfigurator() { + return serverEndpointConfigurator; + } + + @Override + public final Map getUserProperties() { + return userProperties; + } + + @Override + public final List getSubprotocols() { + return subprotocols; + } + + @Override + public final List getExtensions() { + return extensions; + } +} diff --git a/src/java/javax/websocket/server/HandshakeRequest.java b/src/java/javax/websocket/server/HandshakeRequest.java new file mode 100644 index 00000000..f2e33273 --- /dev/null +++ b/src/java/javax/websocket/server/HandshakeRequest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.net.URI; +import java.security.Principal; +import java.util.List; +import java.util.Map; + +/** + * Represents the HTTP request that asked to be upgraded to WebSocket. + */ +public interface HandshakeRequest { + + static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; + static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; + static final String SEC_WEBSOCKET_EXTENSIONS= "Sec-WebSocket-Extensions"; + + Map> getHeaders(); + + Principal getUserPrincipal(); + + URI getRequestURI(); + + boolean isUserInRole(String role); + + /** + * Get the HTTP Session object associated with this request. Object is used + * to avoid a direct dependency on the Servlet API. + * @return The javax.servlet.http.HttpSession object associated with this + * request, if any. + */ + Object getHttpSession(); + + Map> getParameterMap(); + + String getQueryString(); +} diff --git a/src/java/javax/websocket/server/PathParam.java b/src/java/javax/websocket/server/PathParam.java new file mode 100644 index 00000000..ff1d085e --- /dev/null +++ b/src/java/javax/websocket/server/PathParam.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Used to annotate method parameters on POJO endpoints the the {@link + * ServerEndpoint} has been defined with a {@link ServerEndpoint#value()} that + * uses a URI template. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +public @interface PathParam { + String value(); +} diff --git a/src/java/javax/websocket/server/ServerApplicationConfig.java b/src/java/javax/websocket/server/ServerApplicationConfig.java new file mode 100644 index 00000000..b91f1c43 --- /dev/null +++ b/src/java/javax/websocket/server/ServerApplicationConfig.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.util.Set; + +import javax.websocket.Endpoint; + +/** + * Applications may provide an implementation of this interface to filter the + * discovered WebSocket endpoints that are deployed. Implementations of this + * class will be discovered via an ServletContainerInitializer scan. + */ +public interface ServerApplicationConfig { + + /** + * Enables applications to filter the discovered implementations of + * {@link ServerEndpointConfig}. + * + * @param scanned The {@link Endpoint} implementations found in the + * application + * @return The set of configurations for the endpoint the application + * wishes to deploy + */ + Set getEndpointConfigs( + Set> scanned); + + /** + * Enables applications to filter the discovered classes annotated with + * {@link ServerEndpoint}. + * + * @param scanned The POJOs annotated with {@link ServerEndpoint} found in + * the application + * @return The set of POJOs the application wishes to deploy + */ + Set> getAnnotatedEndpointClasses(Set> scanned); +} diff --git a/src/java/javax/websocket/server/ServerContainer.java b/src/java/javax/websocket/server/ServerContainer.java new file mode 100644 index 00000000..3243a07c --- /dev/null +++ b/src/java/javax/websocket/server/ServerContainer.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import javax.websocket.DeploymentException; +import javax.websocket.WebSocketContainer; + +/** + * Provides the ability to deploy endpoints programmatically. + */ +public interface ServerContainer extends WebSocketContainer { + public abstract void addEndpoint(Class clazz) throws DeploymentException; + + public abstract void addEndpoint(ServerEndpointConfig sec) + throws DeploymentException; +} diff --git a/src/java/javax/websocket/server/ServerEndpoint.java b/src/java/javax/websocket/server/ServerEndpoint.java new file mode 100644 index 00000000..43b7dfa2 --- /dev/null +++ b/src/java/javax/websocket/server/ServerEndpoint.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ServerEndpoint { + + /** + * URI or URI-template that the annotated class should be mapped to. + * @return The URI or URI-template that the annotated class should be mapped + * to. + */ + String value(); + + String[] subprotocols() default {}; + + Class[] decoders() default {}; + + Class[] encoders() default {}; + + public Class configurator() + default ServerEndpointConfig.Configurator.class; +} diff --git a/src/java/javax/websocket/server/ServerEndpointConfig.java b/src/java/javax/websocket/server/ServerEndpointConfig.java new file mode 100644 index 00000000..5afdf79c --- /dev/null +++ b/src/java/javax/websocket/server/ServerEndpointConfig.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.ServiceLoader; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; + +/** + * Provides configuration information for WebSocket endpoints published to a + * server. Applications may provide their own implementation or use + * {@link Builder}. + */ +public interface ServerEndpointConfig extends EndpointConfig { + + Class getEndpointClass(); + + /** + * Returns the path at which this WebSocket server endpoint has been + * registered. It may be a path or a level 0 URI template. + * @return The registered path + */ + String getPath(); + + List getSubprotocols(); + + List getExtensions(); + + Configurator getConfigurator(); + + + public final class Builder { + + public static Builder create( + Class endpointClass, String path) { + return new Builder(endpointClass, path); + } + + + private final Class endpointClass; + private final String path; + private List> encoders = + Collections.emptyList(); + private List> decoders = + Collections.emptyList(); + private List subprotocols = Collections.emptyList(); + private List extensions = Collections.emptyList(); + private Configurator configurator = + Configurator.fetchContainerDefaultConfigurator(); + + + private Builder(Class endpointClass, + String path) { + this.endpointClass = endpointClass; + this.path = path; + } + + public ServerEndpointConfig build() { + return new DefaultServerEndpointConfig(endpointClass, path, + subprotocols, extensions, encoders, decoders, configurator); + } + + + public Builder encoders( + List> encoders) { + if (encoders == null || encoders.size() == 0) { + this.encoders = Collections.emptyList(); + } else { + this.encoders = Collections.unmodifiableList(encoders); + } + return this; + } + + + public Builder decoders( + List> decoders) { + if (decoders == null || decoders.size() == 0) { + this.decoders = Collections.emptyList(); + } else { + this.decoders = Collections.unmodifiableList(decoders); + } + return this; + } + + + public Builder subprotocols( + List subprotocols) { + if (subprotocols == null || subprotocols.size() == 0) { + this.subprotocols = Collections.emptyList(); + } else { + this.subprotocols = Collections.unmodifiableList(subprotocols); + } + return this; + } + + + public Builder extensions( + List extensions) { + if (extensions == null || extensions.size() == 0) { + this.extensions = Collections.emptyList(); + } else { + this.extensions = Collections.unmodifiableList(extensions); + } + return this; + } + + + public Builder configurator(Configurator serverEndpointConfigurator) { + if (serverEndpointConfigurator == null) { + this.configurator = Configurator.fetchContainerDefaultConfigurator(); + } else { + this.configurator = serverEndpointConfigurator; + } + return this; + } + } + + + public class Configurator { + + private static volatile Configurator defaultImpl = null; + private static final Object defaultImplLock = new Object(); + + private static final String DEFAULT_IMPL_CLASSNAME = + "nginx.unit.websocket.server.DefaultServerEndpointConfigurator"; + + public static void setDefault(Configurator def) { + synchronized (defaultImplLock) { + defaultImpl = def; + } + } + + static Configurator fetchContainerDefaultConfigurator() { + if (defaultImpl == null) { + synchronized (defaultImplLock) { + if (defaultImpl == null) { + defaultImpl = loadDefault(); + } + } + } + return defaultImpl; + } + + + private static Configurator loadDefault() { + Configurator result = null; + + ServiceLoader serviceLoader = + ServiceLoader.load(Configurator.class); + + Iterator iter = serviceLoader.iterator(); + while (result == null && iter.hasNext()) { + result = iter.next(); + } + + // Fall-back. Also used by unit tests + if (result == null) { + try { + @SuppressWarnings("unchecked") + Class clazz = + (Class) Class.forName( + DEFAULT_IMPL_CLASSNAME); + result = clazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException | IllegalArgumentException | + SecurityException e) { + // No options left. Just return null. + } + } + return result; + } + + public String getNegotiatedSubprotocol(List supported, + List requested) { + return fetchContainerDefaultConfigurator().getNegotiatedSubprotocol(supported, requested); + } + + public List getNegotiatedExtensions(List installed, + List requested) { + return fetchContainerDefaultConfigurator().getNegotiatedExtensions(installed, requested); + } + + public boolean checkOrigin(String originHeaderValue) { + return fetchContainerDefaultConfigurator().checkOrigin(originHeaderValue); + } + + public void modifyHandshake(ServerEndpointConfig sec, + HandshakeRequest request, HandshakeResponse response) { + fetchContainerDefaultConfigurator().modifyHandshake(sec, request, response); + } + + public T getEndpointInstance(Class clazz) + throws InstantiationException { + return fetchContainerDefaultConfigurator().getEndpointInstance( + clazz); + } + } +} diff --git a/src/java/nginx/unit/Context.java b/src/java/nginx/unit/Context.java index e1482903..6fcd6018 100644 --- a/src/java/nginx/unit/Context.java +++ b/src/java/nginx/unit/Context.java @@ -98,10 +98,14 @@ import javax.servlet.http.HttpSessionEvent; import javax.servlet.http.HttpSessionIdListener; import javax.servlet.http.HttpSessionListener; +import javax.websocket.server.ServerEndpoint; + import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; +import nginx.unit.websocket.WsSession; + import org.eclipse.jetty.http.MimeTypes; import org.w3c.dom.Document; @@ -421,6 +425,9 @@ public class Context implements ServletContext, InitParams loader_ = new AppClassLoader(urls, Context.class.getClassLoader().getParent()); + Class wsSession_class = WsSession.class; + trace("wsSession.test: " + WsSession.wsSession_test()); + ClassLoader old = Thread.currentThread().getContextClassLoader(); Thread.currentThread().setContextClassLoader(loader_); @@ -429,28 +436,30 @@ public class Context implements ServletContext, InitParams addListener(listener_classname); } - ScanResult scan_res = null; + ClassGraph classgraph = new ClassGraph() + //.verbose() + .overrideClassLoaders(loader_) + .ignoreParentClassLoaders() + .enableClassInfo() + .enableAnnotationInfo() + //.enableSystemPackages() + .whitelistModules("javax.*") + //.enableAllInfo() + ; - if (!metadata_complete_) { - ClassGraph classgraph = new ClassGraph() - //.verbose() - .overrideClassLoaders(loader_) - .ignoreParentClassLoaders() - .enableClassInfo() - .enableAnnotationInfo() - //.enableSystemPackages() - .whitelistModules("javax.*") - //.enableAllInfo() - ; - - String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim(); - - if (verbose.equals("true")) { - classgraph.verbose(); - } + String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim(); + + if (verbose.equals("true")) { + classgraph.verbose(); + } + + ScanResult scan_res = classgraph.scan(); + + javax.websocket.server.ServerEndpointConfig.Configurator.setDefault(new nginx.unit.websocket.server.DefaultServerEndpointConfigurator()); - scan_res = classgraph.scan(); + loadInitializer(new nginx.unit.websocket.server.WsSci(), scan_res); + if (!metadata_complete_) { loadInitializers(scan_res); } @@ -1471,54 +1480,61 @@ public class Context implements ServletContext, InitParams ServiceLoader.load(ServletContainerInitializer.class, loader_); for (ServletContainerInitializer sci : initializers) { + loadInitializer(sci, scan_res); + } + } - trace("loadInitializers: initializer: " + sci.getClass().getName()); + private void loadInitializer(ServletContainerInitializer sci, ScanResult scan_res) + { + trace("loadInitializer: initializer: " + sci.getClass().getName()); - HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class); - if (ann == null) { - trace("loadInitializers: no HandlesTypes annotation"); - continue; - } + HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class); + if (ann == null) { + trace("loadInitializer: no HandlesTypes annotation"); + return; + } - Class[] classes = ann.value(); - if (classes == null) { - trace("loadInitializers: no handles classes"); - continue; - } + Class[] classes = ann.value(); + if (classes == null) { + trace("loadInitializer: no handles classes"); + return; + } - Set> handles_classes = new HashSet<>(); + Set> handles_classes = new HashSet<>(); - for (Class c : classes) { - trace("loadInitializers: find handles: " + c.getName()); + for (Class c : classes) { + trace("loadInitializer: find handles: " + c.getName()); - ClassInfoList handles = c.isInterface() + ClassInfoList handles = + c.isAnnotation() + ? scan_res.getClassesWithAnnotation(c.getName()) + : c.isInterface() ? scan_res.getClassesImplementing(c.getName()) : scan_res.getSubclasses(c.getName()); - for (ClassInfo ci : handles) { - if (ci.isInterface() - || ci.isAnnotation() - || ci.isAbstract()) - { - continue; - } - - trace("loadInitializers: handles class: " + ci.getName()); - handles_classes.add(ci.loadClass()); + for (ClassInfo ci : handles) { + if (ci.isInterface() + || ci.isAnnotation() + || ci.isAbstract()) + { + return; } - } - if (handles_classes.isEmpty()) { - trace("loadInitializers: no handles implementations"); - continue; + trace("loadInitializer: handles class: " + ci.getName()); + handles_classes.add(ci.loadClass()); } + } - try { - sci.onStartup(handles_classes, this); - metadata_complete_ = true; - } catch(Exception e) { - System.err.println("loadInitializers: exception caught: " + e.toString()); - } + if (handles_classes.isEmpty()) { + trace("loadInitializer: no handles implementations"); + return; + } + + try { + sci.onStartup(handles_classes, this); + metadata_complete_ = true; + } catch(Exception e) { + System.err.println("loadInitializer: exception caught: " + e.toString()); } } @@ -1691,6 +1707,21 @@ public class Context implements ServletContext, InitParams listener_classnames_.add(ci.getName()); } + + + ClassInfoList endpoints = scan_res.getClassesWithAnnotation(ServerEndpoint.class.getName()); + + for (ClassInfo ci : endpoints) { + if (ci.isInterface() + || ci.isAnnotation() + || ci.isAbstract()) + { + trace("scanClasses: skip server end point: " + ci.getName()); + continue; + } + + trace("scanClasses: server end point: " + ci.getName()); + } } public void stop() throws IOException diff --git a/src/java/nginx/unit/Request.java b/src/java/nginx/unit/Request.java index 98584efe..335d7980 100644 --- a/src/java/nginx/unit/Request.java +++ b/src/java/nginx/unit/Request.java @@ -16,6 +16,7 @@ import java.lang.StringBuffer; import java.net.URI; import java.net.URISyntaxException; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -65,6 +66,9 @@ import org.eclipse.jetty.http.MultiPartFormInputStream; import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.http.MimeTypes; +import nginx.unit.websocket.WsSession; +import nginx.unit.websocket.WsIOException; + public class Request implements HttpServletRequest, DynamicPathRequest { private final Context context; @@ -114,6 +118,9 @@ public class Request implements HttpServletRequest, DynamicPathRequest private boolean request_session_id_from_url = false; private Session session = null; + private WsSession wsSession = null; + private boolean skip_close_ws = false; + private final ServletRequestAttributeListener attr_listener; public static final String BARE = "nginx.unit.request.bare"; @@ -1203,11 +1210,30 @@ public class Request implements HttpServletRequest, DynamicPathRequest public T upgrade( Class httpUpgradeHandlerClass) throws java.io.IOException, ServletException { - log("upgrade: " + httpUpgradeHandlerClass.getName()); + trace("upgrade: " + httpUpgradeHandlerClass.getName()); - return null; + T handler; + + try { + handler = httpUpgradeHandlerClass.getConstructor().newInstance(); + } catch (Exception e) { + throw new ServletException(e); + } + + upgrade(req_info_ptr); + + return handler; + } + + private static native void upgrade(long req_info_ptr); + + public boolean isUpgrade() + { + return isUpgrade(req_info_ptr); } + private static native boolean isUpgrade(long req_info_ptr); + @Override public String changeSessionId() { @@ -1248,5 +1274,65 @@ public class Request implements HttpServletRequest, DynamicPathRequest public static native void trace(long req_info_ptr, String msg, int msg_len); private static native Response getResponse(long req_info_ptr); + + + public void setWsSession(WsSession s) + { + wsSession = s; + } + + private void processWsFrame(ByteBuffer buf, byte opCode, boolean last) + throws IOException + { + trace("processWsFrame: " + opCode + ", [" + buf.position() + ", " + buf.limit() + "]"); + try { + wsSession.processFrame(buf, opCode, last); + } catch (WsIOException e) { + wsSession.onClose(e.getCloseReason()); + } + } + + private void closeWsSession() + { + trace("closeWsSession"); + skip_close_ws = true; + + wsSession.onClose(); + } + + public void sendWsFrame(ByteBuffer payload, byte opCode, boolean last, + long timeoutExpiry) throws IOException + { + trace("sendWsFrame: " + opCode + ", [" + payload.position() + + ", " + payload.limit() + "]"); + + if (payload.isDirect()) { + sendWsFrame(req_info_ptr, payload, payload.position(), + payload.limit() - payload.position(), opCode, last); + } else { + sendWsFrame(req_info_ptr, payload.array(), payload.position(), + payload.limit() - payload.position(), opCode, last); + } + } + + private static native void sendWsFrame(long req_info_ptr, + ByteBuffer buf, int pos, int len, byte opCode, boolean last); + + private static native void sendWsFrame(long req_info_ptr, + byte[] arr, int pos, int len, byte opCode, boolean last); + + + public void closeWs() + { + if (skip_close_ws) { + return; + } + + trace("closeWs"); + + closeWs(req_info_ptr); + } + + private static native void closeWs(long req_info_ptr); } diff --git a/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java new file mode 100644 index 00000000..147112c1 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.channels.AsynchronousChannelGroup; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.threads.ThreadPoolExecutor; + +/** + * This is a utility class that enables multiple {@link WsWebSocketContainer} + * instances to share a single {@link AsynchronousChannelGroup} while ensuring + * that the group is destroyed when no longer required. + */ +public class AsyncChannelGroupUtil { + + private static final StringManager sm = + StringManager.getManager(AsyncChannelGroupUtil.class); + + private static AsynchronousChannelGroup group = null; + private static int usageCount = 0; + private static final Object lock = new Object(); + + + private AsyncChannelGroupUtil() { + // Hide the default constructor + } + + + public static AsynchronousChannelGroup register() { + synchronized (lock) { + if (usageCount == 0) { + group = createAsynchronousChannelGroup(); + } + usageCount++; + return group; + } + } + + + public static void unregister() { + synchronized (lock) { + usageCount--; + if (usageCount == 0) { + group.shutdown(); + group = null; + } + } + } + + + private static AsynchronousChannelGroup createAsynchronousChannelGroup() { + // Need to do this with the right thread context class loader else the + // first web app to call this will trigger a leak + ClassLoader original = Thread.currentThread().getContextClassLoader(); + + try { + Thread.currentThread().setContextClassLoader( + AsyncIOThreadFactory.class.getClassLoader()); + + // These are the same settings as the default + // AsynchronousChannelGroup + int initialSize = Runtime.getRuntime().availableProcessors(); + ExecutorService executorService = new ThreadPoolExecutor( + 0, + Integer.MAX_VALUE, + Long.MAX_VALUE, TimeUnit.MILLISECONDS, + new SynchronousQueue(), + new AsyncIOThreadFactory()); + + try { + return AsynchronousChannelGroup.withCachedThreadPool( + executorService, initialSize); + } catch (IOException e) { + // No good reason for this to happen. + throw new IllegalStateException(sm.getString("asyncChannelGroup.createFail")); + } + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + + + private static class AsyncIOThreadFactory implements ThreadFactory { + + static { + // Load NewThreadPrivilegedAction since newThread() will not be able + // to if called from an InnocuousThread. + // See https://bz.apache.org/bugzilla/show_bug.cgi?id=57490 + NewThreadPrivilegedAction.load(); + } + + + @Override + public Thread newThread(final Runnable r) { + // Create the new Thread within a doPrivileged block to ensure that + // the thread inherits the current ProtectionDomain which is + // essential to be able to use this with a Java Applet. See + // https://bz.apache.org/bugzilla/show_bug.cgi?id=57091 + return AccessController.doPrivileged(new NewThreadPrivilegedAction(r)); + } + + // Non-anonymous class so that AsyncIOThreadFactory can load it + // explicitly + private static class NewThreadPrivilegedAction implements PrivilegedAction { + + private static AtomicInteger count = new AtomicInteger(0); + + private final Runnable r; + + public NewThreadPrivilegedAction(Runnable r) { + this.r = r; + } + + @Override + public Thread run() { + Thread t = new Thread(r); + t.setName("WebSocketClient-AsyncIO-" + count.incrementAndGet()); + t.setContextClassLoader(this.getClass().getClassLoader()); + t.setDaemon(true); + return t; + } + + private static void load() { + // NO-OP. Just provides a hook to enable the class to be loaded + } + } + } +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapper.java b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java new file mode 100644 index 00000000..060ae9cb --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLException; + +/** + * This is a wrapper for a {@link java.nio.channels.AsynchronousSocketChannel} + * that limits the methods available thereby simplifying the process of + * implementing SSL/TLS support since there are fewer methods to intercept. + */ +public interface AsyncChannelWrapper { + + Future read(ByteBuffer dst); + + void read(ByteBuffer dst, A attachment, + CompletionHandler handler); + + Future write(ByteBuffer src); + + void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler handler); + + void close(); + + Future handshake() throws SSLException; +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java new file mode 100644 index 00000000..5b88bfe1 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * Generally, just passes calls straight to the wrapped + * {@link AsynchronousSocketChannel}. In some cases exceptions may be swallowed + * to save them being swallowed by the calling code. + */ +public class AsyncChannelWrapperNonSecure implements AsyncChannelWrapper { + + private static final Future NOOP_FUTURE = new NoOpFuture(); + + private final AsynchronousSocketChannel socketChannel; + + public AsyncChannelWrapperNonSecure( + AsynchronousSocketChannel socketChannel) { + this.socketChannel = socketChannel; + } + + @Override + public Future read(ByteBuffer dst) { + return socketChannel.read(dst); + } + + @Override + public void read(ByteBuffer dst, A attachment, + CompletionHandler handler) { + socketChannel.read(dst, attachment, handler); + } + + @Override + public Future write(ByteBuffer src) { + return socketChannel.write(src); + } + + @Override + public void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler handler) { + socketChannel.write( + srcs, offset, length, timeout, unit, attachment, handler); + } + + @Override + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + // Ignore + } + } + + @Override + public Future handshake() { + return NOOP_FUTURE; + } + + + private static final class NoOpFuture implements Future { + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return true; + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + return null; + } + } +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java new file mode 100644 index 00000000..21654487 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +/** + * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot + * more testing before it can be considered robust. + */ +public class AsyncChannelWrapperSecure implements AsyncChannelWrapper { + + private final Log log = + LogFactory.getLog(AsyncChannelWrapperSecure.class); + private static final StringManager sm = + StringManager.getManager(AsyncChannelWrapperSecure.class); + + private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921); + private final AsynchronousSocketChannel socketChannel; + private final SSLEngine sslEngine; + private final ByteBuffer socketReadBuffer; + private final ByteBuffer socketWriteBuffer; + // One thread for read, one for write + private final ExecutorService executor = + Executors.newFixedThreadPool(2, new SecureIOThreadFactory()); + private AtomicBoolean writing = new AtomicBoolean(false); + private AtomicBoolean reading = new AtomicBoolean(false); + + public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, + SSLEngine sslEngine) { + this.socketChannel = socketChannel; + this.sslEngine = sslEngine; + + int socketBufferSize = sslEngine.getSession().getPacketBufferSize(); + socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize); + socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize); + } + + @Override + public Future read(ByteBuffer dst) { + WrapperFuture future = new WrapperFuture<>(); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + + return future; + } + + @Override + public void read(ByteBuffer dst, A attachment, + CompletionHandler handler) { + + WrapperFuture future = + new WrapperFuture<>(handler, attachment); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + } + + @Override + public Future write(ByteBuffer src) { + + WrapperFuture inner = new WrapperFuture<>(); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = + new WriteTask(new ByteBuffer[] {src}, 0, 1, inner); + + executor.execute(writeTask); + + Future future = new LongToIntegerFuture(inner); + return future; + } + + @Override + public void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler handler) { + + WrapperFuture future = + new WrapperFuture<>(handler, attachment); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = new WriteTask(srcs, offset, length, future); + + executor.execute(writeTask); + } + + @Override + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + log.info(sm.getString("asyncChannelWrapperSecure.closeFail")); + } + executor.shutdownNow(); + } + + @Override + public Future handshake() throws SSLException { + + WrapperFuture wFuture = new WrapperFuture<>(); + + Thread t = new WebSocketSslHandshakeThread(wFuture); + t.start(); + + return wFuture; + } + + + private class WriteTask implements Runnable { + + private final ByteBuffer[] srcs; + private final int offset; + private final int length; + private final WrapperFuture future; + + public WriteTask(ByteBuffer[] srcs, int offset, int length, + WrapperFuture future) { + this.srcs = srcs; + this.future = future; + this.offset = offset; + this.length = length; + } + + @Override + public void run() { + long written = 0; + + try { + for (int i = offset; i < offset + length; i++) { + ByteBuffer src = srcs[i]; + while (src.hasRemaining()) { + socketWriteBuffer.clear(); + + // Encrypt the data + SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer); + written += r.bytesConsumed(); + Status s = r.getStatus(); + + if (s == Status.OK || s == Status.BUFFER_OVERFLOW) { + // Need to write out the bytes and may need to read from + // the source again to empty it + } else { + // Status.BUFFER_UNDERFLOW - only happens on unwrap + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusWrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + + socketWriteBuffer.flip(); + + // Do the write + int toWrite = r.bytesProduced(); + while (toWrite > 0) { + Future f = + socketChannel.write(socketWriteBuffer); + Integer socketWrite = f.get(); + toWrite -= socketWrite.intValue(); + } + } + } + + + if (writing.compareAndSet(true, false)) { + future.complete(Long.valueOf(written)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateWrite"))); + } + } catch (Exception e) { + writing.set(false); + future.fail(e); + } + } + } + + + private class ReadTask implements Runnable { + + private final ByteBuffer dest; + private final WrapperFuture future; + + public ReadTask(ByteBuffer dest, WrapperFuture future) { + this.dest = dest; + this.future = future; + } + + @Override + public void run() { + int read = 0; + + boolean forceRead = false; + + try { + while (read == 0) { + socketReadBuffer.compact(); + + if (forceRead) { + forceRead = false; + Future f = socketChannel.read(socketReadBuffer); + Integer socketRead = f.get(); + if (socketRead.intValue() == -1) { + throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof")); + } + } + + socketReadBuffer.flip(); + + if (socketReadBuffer.hasRemaining()) { + // Decrypt the data in the buffer + SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest); + read += r.bytesProduced(); + Status s = r.getStatus(); + + if (s == Status.OK) { + // Bytes available for reading and there may be + // sufficient data in the socketReadBuffer to + // support further reads without reading from the + // socket + } else if (s == Status.BUFFER_UNDERFLOW) { + // There is partial data in the socketReadBuffer + if (read == 0) { + // Need more data before the partial data can be + // processed and some output generated + forceRead = true; + } + // else return the data we have and deal with the + // partial data on the next read + } else if (s == Status.BUFFER_OVERFLOW) { + // Not enough space in the destination buffer to + // store all of the data. We could use a bytes read + // value of -bufferSizeRequired to signal the new + // buffer size required but an explicit exception is + // clearer. + if (reading.compareAndSet(true, false)) { + throw new ReadBufferOverflowException(sslEngine. + getSession().getApplicationBufferSize()); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } else { + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusUnwrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + } else { + forceRead = true; + } + } + + + if (reading.compareAndSet(true, false)) { + future.complete(Integer.valueOf(read)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException | + ExecutionException | InterruptedException e) { + reading.set(false); + future.fail(e); + } + } + } + + + private class WebSocketSslHandshakeThread extends Thread { + + private final WrapperFuture hFuture; + + private HandshakeStatus handshakeStatus; + private Status resultStatus; + + public WebSocketSslHandshakeThread(WrapperFuture hFuture) { + this.hFuture = hFuture; + } + + @Override + public void run() { + try { + sslEngine.beginHandshake(); + // So the first compact does the right thing + socketReadBuffer.position(socketReadBuffer.limit()); + + handshakeStatus = sslEngine.getHandshakeStatus(); + resultStatus = Status.OK; + + boolean handshaking = true; + + while(handshaking) { + switch (handshakeStatus) { + case NEED_WRAP: { + socketWriteBuffer.clear(); + SSLEngineResult r = + sslEngine.wrap(DUMMY, socketWriteBuffer); + checkResult(r, true); + socketWriteBuffer.flip(); + Future fWrite = + socketChannel.write(socketWriteBuffer); + fWrite.get(); + break; + } + case NEED_UNWRAP: { + socketReadBuffer.compact(); + if (socketReadBuffer.position() == 0 || + resultStatus == Status.BUFFER_UNDERFLOW) { + Future fRead = + socketChannel.read(socketReadBuffer); + fRead.get(); + } + socketReadBuffer.flip(); + SSLEngineResult r = + sslEngine.unwrap(socketReadBuffer, DUMMY); + checkResult(r, false); + break; + } + case NEED_TASK: { + Runnable r = null; + while ((r = sslEngine.getDelegatedTask()) != null) { + r.run(); + } + handshakeStatus = sslEngine.getHandshakeStatus(); + break; + } + case FINISHED: { + handshaking = false; + break; + } + case NOT_HANDSHAKING: { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.notHandshaking")); + } + } + } + } catch (Exception e) { + hFuture.fail(e); + return; + } + + hFuture.complete(null); + } + + private void checkResult(SSLEngineResult result, boolean wrap) + throws SSLException { + + handshakeStatus = result.getHandshakeStatus(); + resultStatus = result.getStatus(); + + if (resultStatus != Status.OK && + (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus)); + } + if (wrap && result.bytesConsumed() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap")); + } + if (!wrap && result.bytesProduced() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap")); + } + } + } + + + private static class WrapperFuture implements Future { + + private final CompletionHandler handler; + private final A attachment; + + private volatile T result = null; + private volatile Throwable throwable = null; + private CountDownLatch completionLatch = new CountDownLatch(1); + + public WrapperFuture() { + this(null, null); + } + + public WrapperFuture(CompletionHandler handler, A attachment) { + this.handler = handler; + this.attachment = attachment; + } + + public void complete(T result) { + this.result = result; + completionLatch.countDown(); + if (handler != null) { + handler.completed(result, attachment); + } + } + + public void fail(Throwable t) { + throwable = t; + completionLatch.countDown(); + if (handler != null) { + handler.failed(throwable, attachment); + } + } + + @Override + public final boolean cancel(boolean mayInterruptIfRunning) { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isCancelled() { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isDone() { + return completionLatch.getCount() > 0; + } + + @Override + public T get() throws InterruptedException, ExecutionException { + completionLatch.await(); + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + + @Override + public T get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + boolean latchResult = completionLatch.await(timeout, unit); + if (latchResult == false) { + throw new TimeoutException(); + } + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + } + + private static final class LongToIntegerFuture implements Future { + + private final Future wrapped; + + public LongToIntegerFuture(Future wrapped) { + this.wrapped = wrapped; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return wrapped.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return wrapped.isCancelled(); + } + + @Override + public boolean isDone() { + return wrapped.isDone(); + } + + @Override + public Integer get() throws InterruptedException, ExecutionException { + Long result = wrapped.get(); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + + @Override + public Integer get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + Long result = wrapped.get(timeout, unit); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + } + + + private static class SecureIOThreadFactory implements ThreadFactory { + + private AtomicInteger count = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet()); + // No need to set the context class loader. The threads will be + // cleaned up when the connection is closed. + t.setDaemon(true); + return t; + } + } +} diff --git a/src/java/nginx/unit/websocket/AuthenticationException.java b/src/java/nginx/unit/websocket/AuthenticationException.java new file mode 100644 index 00000000..001f1829 --- /dev/null +++ b/src/java/nginx/unit/websocket/AuthenticationException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +/** + * Exception thrown on authentication error connecting to a remote + * websocket endpoint. + */ +public class AuthenticationException extends Exception { + + private static final long serialVersionUID = 5709887412240096441L; + + /** + * Create authentication exception. + * @param message the error message + */ + public AuthenticationException(String message) { + super(message); + } + +} diff --git a/src/java/nginx/unit/websocket/Authenticator.java b/src/java/nginx/unit/websocket/Authenticator.java new file mode 100644 index 00000000..87b3ce6d --- /dev/null +++ b/src/java/nginx/unit/websocket/Authenticator.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Base class for the authentication methods used by the websocket client. + */ +public abstract class Authenticator { + private static final Pattern pattern = Pattern + .compile("(\\w+)\\s*=\\s*(\"([^\"]+)\"|([^,=\"]+))\\s*,?"); + + /** + * Generate the authentication header that will be sent to the server. + * @param requestUri The request URI + * @param WWWAuthenticate The server auth challenge + * @param UserProperties The user information + * @return The auth header + * @throws AuthenticationException When an error occurs + */ + public abstract String getAuthorization(String requestUri, String WWWAuthenticate, + Map UserProperties) throws AuthenticationException; + + /** + * Get the authentication method. + * @return the auth scheme + */ + public abstract String getSchemeName(); + + /** + * Utility method to parse the authentication header. + * @param WWWAuthenticate The server auth challenge + * @return the parsed header + */ + public Map parseWWWAuthenticateHeader(String WWWAuthenticate) { + + Matcher m = pattern.matcher(WWWAuthenticate); + Map challenge = new HashMap<>(); + + while (m.find()) { + String key = m.group(1); + String qtedValue = m.group(3); + String value = m.group(4); + + challenge.put(key, qtedValue != null ? qtedValue : value); + + } + + return challenge; + + } + +} diff --git a/src/java/nginx/unit/websocket/AuthenticatorFactory.java b/src/java/nginx/unit/websocket/AuthenticatorFactory.java new file mode 100644 index 00000000..7d46d7f9 --- /dev/null +++ b/src/java/nginx/unit/websocket/AuthenticatorFactory.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.Iterator; +import java.util.ServiceLoader; + +/** + * Utility method to return the appropriate authenticator according to + * the scheme that the server uses. + */ +public class AuthenticatorFactory { + + /** + * Return a new authenticator instance. + * @param authScheme The scheme used + * @return the authenticator + */ + public static Authenticator getAuthenticator(String authScheme) { + + Authenticator auth = null; + switch (authScheme.toLowerCase()) { + + case BasicAuthenticator.schemeName: + auth = new BasicAuthenticator(); + break; + + case DigestAuthenticator.schemeName: + auth = new DigestAuthenticator(); + break; + + default: + auth = loadAuthenticators(authScheme); + break; + } + + return auth; + + } + + private static Authenticator loadAuthenticators(String authScheme) { + ServiceLoader serviceLoader = ServiceLoader.load(Authenticator.class); + Iterator auths = serviceLoader.iterator(); + + while (auths.hasNext()) { + Authenticator auth = auths.next(); + if (auth.getSchemeName().equalsIgnoreCase(authScheme)) + return auth; + } + + return null; + } + +} diff --git a/src/java/nginx/unit/websocket/BackgroundProcess.java b/src/java/nginx/unit/websocket/BackgroundProcess.java new file mode 100644 index 00000000..0d2e1288 --- /dev/null +++ b/src/java/nginx/unit/websocket/BackgroundProcess.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public interface BackgroundProcess { + + void backgroundProcess(); + + void setProcessPeriod(int period); + + int getProcessPeriod(); +} diff --git a/src/java/nginx/unit/websocket/BackgroundProcessManager.java b/src/java/nginx/unit/websocket/BackgroundProcessManager.java new file mode 100644 index 00000000..d8b1b950 --- /dev/null +++ b/src/java/nginx/unit/websocket/BackgroundProcessManager.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.HashSet; +import java.util.Set; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.res.StringManager; + +/** + * Provides a background processing mechanism that triggers roughly once a + * second. The class maintains a thread that only runs when there is at least + * one instance of {@link BackgroundProcess} registered. + */ +public class BackgroundProcessManager { + + private final Log log = + LogFactory.getLog(BackgroundProcessManager.class); + private static final StringManager sm = + StringManager.getManager(BackgroundProcessManager.class); + private static final BackgroundProcessManager instance; + + + static { + instance = new BackgroundProcessManager(); + } + + + public static BackgroundProcessManager getInstance() { + return instance; + } + + private final Set processes = new HashSet<>(); + private final Object processesLock = new Object(); + private WsBackgroundThread wsBackgroundThread = null; + + private BackgroundProcessManager() { + // Hide default constructor + } + + + public void register(BackgroundProcess process) { + synchronized (processesLock) { + if (processes.size() == 0) { + wsBackgroundThread = new WsBackgroundThread(this); + wsBackgroundThread.setContextClassLoader( + this.getClass().getClassLoader()); + wsBackgroundThread.setDaemon(true); + wsBackgroundThread.start(); + } + processes.add(process); + } + } + + + public void unregister(BackgroundProcess process) { + synchronized (processesLock) { + processes.remove(process); + if (wsBackgroundThread != null && processes.size() == 0) { + wsBackgroundThread.halt(); + wsBackgroundThread = null; + } + } + } + + + private void process() { + Set currentProcesses = new HashSet<>(); + synchronized (processesLock) { + currentProcesses.addAll(processes); + } + for (BackgroundProcess process : currentProcesses) { + try { + process.backgroundProcess(); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.error(sm.getString( + "backgroundProcessManager.processFailed"), t); + } + } + } + + + /* + * For unit testing. + */ + int getProcessCount() { + synchronized (processesLock) { + return processes.size(); + } + } + + + void shutdown() { + synchronized (processesLock) { + processes.clear(); + if (wsBackgroundThread != null) { + wsBackgroundThread.halt(); + wsBackgroundThread = null; + } + } + } + + + private static class WsBackgroundThread extends Thread { + + private final BackgroundProcessManager manager; + private volatile boolean running = true; + + public WsBackgroundThread(BackgroundProcessManager manager) { + setName("WebSocket background processing"); + this.manager = manager; + } + + @Override + public void run() { + while (running) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + // Ignore + } + manager.process(); + } + } + + public void halt() { + setName("WebSocket background processing - stopping"); + running = false; + } + } +} diff --git a/src/java/nginx/unit/websocket/BasicAuthenticator.java b/src/java/nginx/unit/websocket/BasicAuthenticator.java new file mode 100644 index 00000000..1b1a6b83 --- /dev/null +++ b/src/java/nginx/unit/websocket/BasicAuthenticator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Map; + +/** + * Authenticator supporting the BASIC auth method. + */ +public class BasicAuthenticator extends Authenticator { + + public static final String schemeName = "basic"; + public static final String charsetparam = "charset"; + + @Override + public String getAuthorization(String requestUri, String WWWAuthenticate, + Map userProperties) throws AuthenticationException { + + String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME); + String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD); + + if (userName == null || password == null) { + throw new AuthenticationException( + "Failed to perform Basic authentication due to missing user/password"); + } + + Map wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate); + + String userPass = userName + ":" + password; + Charset charset; + + if (wwwAuthenticate.get(charsetparam) != null + && wwwAuthenticate.get(charsetparam).equalsIgnoreCase("UTF-8")) { + charset = StandardCharsets.UTF_8; + } else { + charset = StandardCharsets.ISO_8859_1; + } + + String base64 = Base64.getEncoder().encodeToString(userPass.getBytes(charset)); + + return " Basic " + base64; + } + + @Override + public String getSchemeName() { + return schemeName; + } + +} diff --git a/src/java/nginx/unit/websocket/Constants.java b/src/java/nginx/unit/websocket/Constants.java new file mode 100644 index 00000000..38b22fe0 --- /dev/null +++ b/src/java/nginx/unit/websocket/Constants.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import javax.websocket.Extension; + +/** + * Internal implementation constants. + */ +public class Constants { + + // OP Codes + public static final byte OPCODE_CONTINUATION = 0x00; + public static final byte OPCODE_TEXT = 0x01; + public static final byte OPCODE_BINARY = 0x02; + public static final byte OPCODE_CLOSE = 0x08; + public static final byte OPCODE_PING = 0x09; + public static final byte OPCODE_PONG = 0x0A; + + // Internal OP Codes + // RFC 6455 limits OP Codes to 4 bits so these should never clash + // Always set bit 4 so these will be treated as control codes + static final byte INTERNAL_OPCODE_FLUSH = 0x18; + + // Buffers + static final int DEFAULT_BUFFER_SIZE = Integer.getInteger( + "nginx.unit.websocket.DEFAULT_BUFFER_SIZE", 8 * 1024) + .intValue(); + + // Client connection + /** + * Property name to set to configure the value that is passed to + * {@link javax.net.ssl.SSLEngine#setEnabledProtocols(String[])}. The value + * should be a comma separated string. + */ + public static final String SSL_PROTOCOLS_PROPERTY = + "nginx.unit.websocket.SSL_PROTOCOLS"; + public static final String SSL_TRUSTSTORE_PROPERTY = + "nginx.unit.websocket.SSL_TRUSTSTORE"; + public static final String SSL_TRUSTSTORE_PWD_PROPERTY = + "nginx.unit.websocket.SSL_TRUSTSTORE_PWD"; + public static final String SSL_TRUSTSTORE_PWD_DEFAULT = "changeit"; + /** + * Property name to set to configure used SSLContext. The value should be an + * instance of SSLContext. If this property is present, the SSL_TRUSTSTORE* + * properties are ignored. + */ + public static final String SSL_CONTEXT_PROPERTY = + "nginx.unit.websocket.SSL_CONTEXT"; + /** + * Property name to set to configure the timeout (in milliseconds) when + * establishing a WebSocket connection to server. The default is + * {@link #IO_TIMEOUT_MS_DEFAULT}. + */ + public static final String IO_TIMEOUT_MS_PROPERTY = + "nginx.unit.websocket.IO_TIMEOUT_MS"; + public static final long IO_TIMEOUT_MS_DEFAULT = 5000; + + // RFC 2068 recommended a limit of 5 + // Most browsers have a default limit of 20 + public static final String MAX_REDIRECTIONS_PROPERTY = + "nginx.unit.websocket.MAX_REDIRECTIONS"; + public static final int MAX_REDIRECTIONS_DEFAULT = 20; + + // HTTP upgrade header names and values + public static final String HOST_HEADER_NAME = "Host"; + public static final String UPGRADE_HEADER_NAME = "Upgrade"; + public static final String UPGRADE_HEADER_VALUE = "websocket"; + public static final String ORIGIN_HEADER_NAME = "Origin"; + public static final String CONNECTION_HEADER_NAME = "Connection"; + public static final String CONNECTION_HEADER_VALUE = "upgrade"; + public static final String LOCATION_HEADER_NAME = "Location"; + public static final String AUTHORIZATION_HEADER_NAME = "Authorization"; + public static final String WWW_AUTHENTICATE_HEADER_NAME = "WWW-Authenticate"; + public static final String WS_VERSION_HEADER_NAME = "Sec-WebSocket-Version"; + public static final String WS_VERSION_HEADER_VALUE = "13"; + public static final String WS_KEY_HEADER_NAME = "Sec-WebSocket-Key"; + public static final String WS_PROTOCOL_HEADER_NAME = "Sec-WebSocket-Protocol"; + public static final String WS_EXTENSIONS_HEADER_NAME = "Sec-WebSocket-Extensions"; + + /// HTTP redirection status codes + public static final int MULTIPLE_CHOICES = 300; + public static final int MOVED_PERMANENTLY = 301; + public static final int FOUND = 302; + public static final int SEE_OTHER = 303; + public static final int USE_PROXY = 305; + public static final int TEMPORARY_REDIRECT = 307; + + // Configuration for Origin header in client + static final String DEFAULT_ORIGIN_HEADER_VALUE = + System.getProperty("nginx.unit.websocket.DEFAULT_ORIGIN_HEADER_VALUE"); + + // Configuration for blocking sends + public static final String BLOCKING_SEND_TIMEOUT_PROPERTY = + "nginx.unit.websocket.BLOCKING_SEND_TIMEOUT"; + // Milliseconds so this is 20 seconds + public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + + // Configuration for background processing checks intervals + static final int DEFAULT_PROCESS_PERIOD = Integer.getInteger( + "nginx.unit.websocket.DEFAULT_PROCESS_PERIOD", 10) + .intValue(); + + public static final String WS_AUTHENTICATION_USER_NAME = "nginx.unit.websocket.WS_AUTHENTICATION_USER_NAME"; + public static final String WS_AUTHENTICATION_PASSWORD = "nginx.unit.websocket.WS_AUTHENTICATION_PASSWORD"; + + /* Configuration for extensions + * Note: These options are primarily present to enable this implementation + * to pass compliance tests. They are expected to be removed once + * the WebSocket API includes a mechanism for adding custom extensions + * and disabling built-in extensions. + */ + static final boolean DISABLE_BUILTIN_EXTENSIONS = + Boolean.getBoolean("nginx.unit.websocket.DISABLE_BUILTIN_EXTENSIONS"); + static final boolean ALLOW_UNSUPPORTED_EXTENSIONS = + Boolean.getBoolean("nginx.unit.websocket.ALLOW_UNSUPPORTED_EXTENSIONS"); + + // Configuration for stream behavior + static final boolean STREAMS_DROP_EMPTY_MESSAGES = + Boolean.getBoolean("nginx.unit.websocket.STREAMS_DROP_EMPTY_MESSAGES"); + + public static final boolean STRICT_SPEC_COMPLIANCE = + Boolean.getBoolean("nginx.unit.websocket.STRICT_SPEC_COMPLIANCE"); + + public static final List INSTALLED_EXTENSIONS; + + static { + if (DISABLE_BUILTIN_EXTENSIONS) { + INSTALLED_EXTENSIONS = Collections.unmodifiableList(new ArrayList()); + } else { + List installed = new ArrayList<>(1); + installed.add(new WsExtension("permessage-deflate")); + INSTALLED_EXTENSIONS = Collections.unmodifiableList(installed); + } + } + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/DecoderEntry.java b/src/java/nginx/unit/websocket/DecoderEntry.java new file mode 100644 index 00000000..36112ef4 --- /dev/null +++ b/src/java/nginx/unit/websocket/DecoderEntry.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.Decoder; + +public class DecoderEntry { + + private final Class clazz; + private final Class decoderClazz; + + public DecoderEntry(Class clazz, + Class decoderClazz) { + this.clazz = clazz; + this.decoderClazz = decoderClazz; + } + + public Class getClazz() { + return clazz; + } + + public Class getDecoderClazz() { + return decoderClazz; + } +} diff --git a/src/java/nginx/unit/websocket/DigestAuthenticator.java b/src/java/nginx/unit/websocket/DigestAuthenticator.java new file mode 100644 index 00000000..9530c303 --- /dev/null +++ b/src/java/nginx/unit/websocket/DigestAuthenticator.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.Map; + +import org.apache.tomcat.util.security.MD5Encoder; + +/** + * Authenticator supporting the DIGEST auth method. + */ +public class DigestAuthenticator extends Authenticator { + + public static final String schemeName = "digest"; + private SecureRandom cnonceGenerator; + private int nonceCount = 0; + private long cNonce; + + @Override + public String getAuthorization(String requestUri, String WWWAuthenticate, + Map userProperties) throws AuthenticationException { + + String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME); + String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD); + + if (userName == null || password == null) { + throw new AuthenticationException( + "Failed to perform Digest authentication due to missing user/password"); + } + + Map wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate); + + String realm = wwwAuthenticate.get("realm"); + String nonce = wwwAuthenticate.get("nonce"); + String messageQop = wwwAuthenticate.get("qop"); + String algorithm = wwwAuthenticate.get("algorithm") == null ? "MD5" + : wwwAuthenticate.get("algorithm"); + String opaque = wwwAuthenticate.get("opaque"); + + StringBuilder challenge = new StringBuilder(); + + if (!messageQop.isEmpty()) { + if (cnonceGenerator == null) { + cnonceGenerator = new SecureRandom(); + } + + cNonce = cnonceGenerator.nextLong(); + nonceCount++; + } + + challenge.append("Digest "); + challenge.append("username =\"" + userName + "\","); + challenge.append("realm=\"" + realm + "\","); + challenge.append("nonce=\"" + nonce + "\","); + challenge.append("uri=\"" + requestUri + "\","); + + try { + challenge.append("response=\"" + calculateRequestDigest(requestUri, userName, password, + realm, nonce, messageQop, algorithm) + "\","); + } + + catch (NoSuchAlgorithmException e) { + throw new AuthenticationException( + "Unable to generate request digest " + e.getMessage()); + } + + challenge.append("algorithm=" + algorithm + ","); + challenge.append("opaque=\"" + opaque + "\","); + + if (!messageQop.isEmpty()) { + challenge.append("qop=\"" + messageQop + "\""); + challenge.append(",cnonce=\"" + cNonce + "\","); + challenge.append("nc=" + String.format("%08X", Integer.valueOf(nonceCount))); + } + + return challenge.toString(); + + } + + private String calculateRequestDigest(String requestUri, String userName, String password, + String realm, String nonce, String qop, String algorithm) + throws NoSuchAlgorithmException { + + StringBuilder preDigest = new StringBuilder(); + String A1; + + if (algorithm.equalsIgnoreCase("MD5")) + A1 = userName + ":" + realm + ":" + password; + + else + A1 = encodeMD5(userName + ":" + realm + ":" + password) + ":" + nonce + ":" + cNonce; + + /* + * If the "qop" value is "auth-int", then A2 is: A2 = Method ":" + * digest-uri-value ":" H(entity-body) since we do not have an entity-body, A2 = + * Method ":" digest-uri-value for auth and auth_int + */ + String A2 = "GET:" + requestUri; + + preDigest.append(encodeMD5(A1)); + preDigest.append(":"); + preDigest.append(nonce); + + if (qop.toLowerCase().contains("auth")) { + preDigest.append(":"); + preDigest.append(String.format("%08X", Integer.valueOf(nonceCount))); + preDigest.append(":"); + preDigest.append(String.valueOf(cNonce)); + preDigest.append(":"); + preDigest.append(qop); + } + + preDigest.append(":"); + preDigest.append(encodeMD5(A2)); + + return encodeMD5(preDigest.toString()); + + } + + private String encodeMD5(String value) throws NoSuchAlgorithmException { + byte[] bytesOfMessage = value.getBytes(StandardCharsets.ISO_8859_1); + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] thedigest = md.digest(bytesOfMessage); + + return MD5Encoder.encode(thedigest); + } + + @Override + public String getSchemeName() { + return schemeName; + } +} diff --git a/src/java/nginx/unit/websocket/FutureToSendHandler.java b/src/java/nginx/unit/websocket/FutureToSendHandler.java new file mode 100644 index 00000000..4a0809cb --- /dev/null +++ b/src/java/nginx/unit/websocket/FutureToSendHandler.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.tomcat.util.res.StringManager; + + +/** + * Converts a Future to a SendHandler. + */ +class FutureToSendHandler implements Future, SendHandler { + + private static final StringManager sm = StringManager.getManager(FutureToSendHandler.class); + + private final CountDownLatch latch = new CountDownLatch(1); + private final WsSession wsSession; + private volatile AtomicReference result = new AtomicReference<>(null); + + public FutureToSendHandler(WsSession wsSession) { + this.wsSession = wsSession; + } + + + // --------------------------------------------------------- SendHandler + + @Override + public void onResult(SendResult result) { + this.result.compareAndSet(null, result); + latch.countDown(); + } + + + // -------------------------------------------------------------- Future + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isCancelled() { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isDone() { + return latch.getCount() == 0; + } + + @Override + public Void get() throws InterruptedException, + ExecutionException { + try { + wsSession.registerFuture(this); + latch.await(); + } finally { + wsSession.unregisterFuture(this); + } + if (result.get().getException() != null) { + throw new ExecutionException(result.get().getException()); + } + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + boolean retval = false; + try { + wsSession.registerFuture(this); + retval = latch.await(timeout, unit); + } finally { + wsSession.unregisterFuture(this); + + } + if (retval == false) { + throw new TimeoutException(sm.getString("futureToSendHandler.timeout", + Long.valueOf(timeout), unit.toString().toLowerCase())); + } + if (result.get().getException() != null) { + throw new ExecutionException(result.get().getException()); + } + return null; + } +} diff --git a/src/java/nginx/unit/websocket/LocalStrings.properties b/src/java/nginx/unit/websocket/LocalStrings.properties new file mode 100644 index 00000000..aeafe082 --- /dev/null +++ b/src/java/nginx/unit/websocket/LocalStrings.properties @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +asyncChannelGroup.createFail=Unable to create dedicated AsynchronousChannelGroup for WebSocket clients which is required to prevent memory leaks in complex class loader environments like JavaEE containers + +asyncChannelWrapperSecure.closeFail=Failed to close channel cleanly +asyncChannelWrapperSecure.check.notOk=TLS handshake returned an unexpected status [{0}] +asyncChannelWrapperSecure.check.unwrap=Bytes were written to the output during a read +asyncChannelWrapperSecure.check.wrap=Bytes were consumed from the input during a write +asyncChannelWrapperSecure.concurrentRead=Concurrent read operations are not permitted +asyncChannelWrapperSecure.concurrentWrite=Concurrent write operations are not permitted +asyncChannelWrapperSecure.eof=Unexpected end of stream +asyncChannelWrapperSecure.notHandshaking=Unexpected state [NOT_HANDSHAKING] during TLS handshake +asyncChannelWrapperSecure.readOverflow=Buffer overflow. [{0}] bytes to write into a [{1}] byte buffer that already contained [{2}] bytes. +asyncChannelWrapperSecure.statusUnwrap=Unexpected Status of SSLEngineResult after an unwrap() operation +asyncChannelWrapperSecure.statusWrap=Unexpected Status of SSLEngineResult after a wrap() operation +asyncChannelWrapperSecure.tooBig=The result [{0}] is too big to be expressed as an Integer +asyncChannelWrapperSecure.wrongStateRead=Flag that indicates a read is in progress was found to be false (it should have been true) when trying to complete a read operation +asyncChannelWrapperSecure.wrongStateWrite=Flag that indicates a write is in progress was found to be false (it should have been true) when trying to complete a write operation + +backgroundProcessManager.processFailed=A background process failed + +caseInsensitiveKeyMap.nullKey=Null keys are not permitted + +futureToSendHandler.timeout=Operation timed out after waiting [{0}] [{1}] to complete + +perMessageDeflate.deflateFailed=Failed to decompress a compressed WebSocket frame +perMessageDeflate.duplicateParameter=Duplicate definition of the [{0}] extension parameter +perMessageDeflate.invalidWindowSize=An invalid windows of [{1}] size was specified for [{0}]. Valid values are whole numbers from 8 to 15 inclusive. +perMessageDeflate.unknownParameter=An unknown extension parameter [{0}] was defined + +transformerFactory.unsupportedExtension=The extension [{0}] is not supported + +util.notToken=An illegal extension parameter was specified with name [{0}] and value [{1}] +util.invalidMessageHandler=The message handler provided does not have an onMessage(Object) method +util.invalidType=Unable to coerce value [{0}] to type [{1}]. That type is not supported. +util.unknownDecoderType=The Decoder type [{0}] is not recognized + +# Note the wsFrame.* messages are used as close reasons in WebSocket control +# frames and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsFrame.alreadyResumed=Message receiving has already been resumed. +wsFrame.alreadySuspended=Message receiving has already been suspended. +wsFrame.bufferTooSmall=No async message support and buffer too small. Buffer size: [{0}], Message size: [{1}] +wsFrame.byteToLongFail=Too many bytes ([{0}]) were provided to be converted into a long +wsFrame.closed=New frame received after a close control frame +wsFrame.controlFragmented=A fragmented control frame was received but control frames may not be fragmented +wsFrame.controlPayloadTooBig=A control frame was sent with a payload of size [{0}] which is larger than the maximum permitted of 125 bytes +wsFrame.controlNoFin=A control frame was sent that did not have the fin bit set. Control frames are not permitted to use continuation frames. +wsFrame.illegalReadState=Unexpected read state [{0}] +wsFrame.invalidOpCode= A WebSocket frame was sent with an unrecognised opCode of [{0}] +wsFrame.invalidUtf8=A WebSocket text frame was received that could not be decoded to UTF-8 because it contained invalid byte sequences +wsFrame.invalidUtf8Close=A WebSocket close frame was received with a close reason that contained invalid UTF-8 byte sequences +wsFrame.ioeTriggeredClose=An unrecoverable IOException occurred so the connection was closed +wsFrame.messageTooBig=The message was [{0}] bytes long but the MessageHandler has a limit of [{1}] bytes +wsFrame.noContinuation=A new message was started when a continuation frame was expected +wsFrame.notMasked=The client frame was not masked but all client frames must be masked +wsFrame.oneByteCloseCode=The client sent a close frame with a single byte payload which is not valid +wsFrame.partialHeaderComplete=WebSocket frame received. fin [{0}], rsv [{1}], OpCode [{2}], payload length [{3}] +wsFrame.sessionClosed=The client data cannot be processed because the session has already been closed +wsFrame.suspendRequested=Suspend of the message receiving has already been requested. +wsFrame.textMessageTooBig=The decoded text message was too big for the output buffer and the endpoint does not support partial messages +wsFrame.wrongRsv=The client frame set the reserved bits to [{0}] for a message with opCode [{1}] which was not supported by this endpoint + +wsFrameClient.ioe=Failure while reading data sent by server + +wsHandshakeRequest.invalidUri=The string [{0}] cannot be used to construct a valid URI +wsHandshakeRequest.unknownScheme=The scheme [{0}] in the request is not recognised + +wsRemoteEndpoint.acquireTimeout=The current message was not fully sent within the specified timeout +wsRemoteEndpoint.closed=Message will not be sent because the WebSocket session has been closed +wsRemoteEndpoint.closedDuringMessage=The remainder of the message will not be sent because the WebSocket session has been closed +wsRemoteEndpoint.closedOutputStream=This method may not be called as the OutputStream has been closed +wsRemoteEndpoint.closedWriter=This method may not be called as the Writer has been closed +wsRemoteEndpoint.changeType=When sending a fragmented message, all fragments must be of the same type +wsRemoteEndpoint.concurrentMessageSend=Messages may not be sent concurrently even when using the asynchronous send messages. The client must wait for the previous message to complete before sending the next. +wsRemoteEndpoint.flushOnCloseFailed=Batched messages still enabled after session has been closed. Unable to flush remaining batched message. +wsRemoteEndpoint.invalidEncoder=The specified encoder of type [{0}] could not be instantiated +wsRemoteEndpoint.noEncoder=No encoder specified for object of class [{0}] +wsRemoteEndpoint.nullData=Invalid null data argument +wsRemoteEndpoint.nullHandler=Invalid null handler argument +wsRemoteEndpoint.sendInterrupt=The current thread was interrupted while waiting for a blocking send to complete +wsRemoteEndpoint.tooMuchData=Ping or pong may not send more than 125 bytes +wsRemoteEndpoint.wrongState=The remote endpoint was in state [{0}] which is an invalid state for called method + +# Note the following message is used as a close reason in a WebSocket control +# frame and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsSession.timeout=The WebSocket session [{0}] timeout expired + +wsSession.closed=The WebSocket session [{0}] has been closed and no method (apart from close()) may be called on a closed session +wsSession.created=Created WebSocket session [{0}] +wsSession.doClose=Closing WebSocket session [{1}] +wsSession.duplicateHandlerBinary=A binary message handler has already been configured +wsSession.duplicateHandlerPong=A pong message handler has already been configured +wsSession.duplicateHandlerText=A text message handler has already been configured +wsSession.invalidHandlerTypePong=A pong message handler must implement MessageHandler.Whole +wsSession.flushFailOnClose=Failed to flush batched messages on session close +wsSession.messageFailed=Unable to write the complete message as the WebSocket connection has been closed +wsSession.sendCloseFail=Failed to send close message for session [{0}] to remote endpoint +wsSession.removeHandlerFailed=Unable to remove the handler [{0}] as it was not registered with this session +wsSession.unknownHandler=Unable to add the message handler [{0}] as it was for the unrecognised type [{1}] +wsSession.unknownHandlerType=Unable to add the message handler [{0}] as it was wrapped as the unrecognised type [{1}] +wsSession.instanceNew=Endpoint instance registration failed +wsSession.instanceDestroy=Endpoint instance unregistration failed + +# Note the following message is used as a close reason in a WebSocket control +# frame and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsWebSocketContainer.shutdown=The web application is stopping + +wsWebSocketContainer.defaultConfiguratorFail=Failed to create the default configurator +wsWebSocketContainer.endpointCreateFail=Failed to create a local endpoint of type [{0}] +wsWebSocketContainer.maxBuffer=This implementation limits the maximum size of a buffer to Integer.MAX_VALUE +wsWebSocketContainer.missingAnnotation=Cannot use POJO class [{0}] as it is not annotated with @ClientEndpoint +wsWebSocketContainer.sessionCloseFail=Session with ID [{0}] did not close cleanly + +wsWebSocketContainer.asynchronousSocketChannelFail=Unable to open a connection to the server +wsWebSocketContainer.httpRequestFailed=The HTTP request to initiate the WebSocket connection failed +wsWebSocketContainer.invalidExtensionParameters=The server responded with extension parameters the client is unable to support +wsWebSocketContainer.invalidHeader=Unable to parse HTTP header as no colon is present to delimit header name and header value in [{0}]. The header has been skipped. +wsWebSocketContainer.invalidStatus=The HTTP response from the server [{0}] did not permit the HTTP upgrade to WebSocket +wsWebSocketContainer.invalidSubProtocol=The WebSocket server returned multiple values for the Sec-WebSocket-Protocol header +wsWebSocketContainer.pathNoHost=No host was specified in URI +wsWebSocketContainer.pathWrongScheme=The scheme [{0}] is not supported. The supported schemes are ws and wss +wsWebSocketContainer.proxyConnectFail=Failed to connect to the configured Proxy [{0}]. The HTTP response code was [{1}] +wsWebSocketContainer.sslEngineFail=Unable to create SSLEngine to support SSL/TLS connections +wsWebSocketContainer.missingLocationHeader=Failed to handle HTTP response code [{0}]. Missing Location header in response +wsWebSocketContainer.redirectThreshold=Cyclic Location header [{0}] detected / reached max number of redirects [{1}] of max [{2}] +wsWebSocketContainer.unsupportedAuthScheme=Failed to handle HTTP response code [{0}]. Unsupported Authentication scheme [{1}] returned in response +wsWebSocketContainer.failedAuthentication=Failed to handle HTTP response code [{0}]. Authentication header was not accepted by server. +wsWebSocketContainer.missingWWWAuthenticateHeader=Failed to handle HTTP response code [{0}]. Missing WWW-Authenticate header in response diff --git a/src/java/nginx/unit/websocket/MessageHandlerResult.java b/src/java/nginx/unit/websocket/MessageHandlerResult.java new file mode 100644 index 00000000..8d532d1e --- /dev/null +++ b/src/java/nginx/unit/websocket/MessageHandlerResult.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.MessageHandler; + +public class MessageHandlerResult { + + private final MessageHandler handler; + private final MessageHandlerResultType type; + + + public MessageHandlerResult(MessageHandler handler, + MessageHandlerResultType type) { + this.handler = handler; + this.type = type; + } + + + public MessageHandler getHandler() { + return handler; + } + + + public MessageHandlerResultType getType() { + return type; + } +} diff --git a/src/java/nginx/unit/websocket/MessageHandlerResultType.java b/src/java/nginx/unit/websocket/MessageHandlerResultType.java new file mode 100644 index 00000000..1961bb4f --- /dev/null +++ b/src/java/nginx/unit/websocket/MessageHandlerResultType.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public enum MessageHandlerResultType { + BINARY, + TEXT, + PONG +} diff --git a/src/java/nginx/unit/websocket/MessagePart.java b/src/java/nginx/unit/websocket/MessagePart.java new file mode 100644 index 00000000..b52c26f1 --- /dev/null +++ b/src/java/nginx/unit/websocket/MessagePart.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; + +import javax.websocket.SendHandler; + +class MessagePart { + private final boolean fin; + private final int rsv; + private final byte opCode; + private final ByteBuffer payload; + private final SendHandler intermediateHandler; + private volatile SendHandler endHandler; + private final long blockingWriteTimeoutExpiry; + + public MessagePart( boolean fin, int rsv, byte opCode, ByteBuffer payload, + SendHandler intermediateHandler, SendHandler endHandler, + long blockingWriteTimeoutExpiry) { + this.fin = fin; + this.rsv = rsv; + this.opCode = opCode; + this.payload = payload; + this.intermediateHandler = intermediateHandler; + this.endHandler = endHandler; + this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; + } + + + public boolean isFin() { + return fin; + } + + + public int getRsv() { + return rsv; + } + + + public byte getOpCode() { + return opCode; + } + + + public ByteBuffer getPayload() { + return payload; + } + + + public SendHandler getIntermediateHandler() { + return intermediateHandler; + } + + + public SendHandler getEndHandler() { + return endHandler; + } + + public void setEndHandler(SendHandler endHandler) { + this.endHandler = endHandler; + } + + public long getBlockingWriteTimeoutExpiry() { + return blockingWriteTimeoutExpiry; + } +} + + diff --git a/src/java/nginx/unit/websocket/PerMessageDeflate.java b/src/java/nginx/unit/websocket/PerMessageDeflate.java new file mode 100644 index 00000000..88e0a0bc --- /dev/null +++ b/src/java/nginx/unit/websocket/PerMessageDeflate.java @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; + +import javax.websocket.Extension; +import javax.websocket.Extension.Parameter; +import javax.websocket.SendHandler; + +import org.apache.tomcat.util.res.StringManager; + +public class PerMessageDeflate implements Transformation { + + private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class); + + private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover"; + private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover"; + private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits"; + private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits"; + + private static final int RSV_BITMASK = 0b100; + private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1}; + + public static final String NAME = "permessage-deflate"; + + private final boolean serverContextTakeover; + private final int serverMaxWindowBits; + private final boolean clientContextTakeover; + private final int clientMaxWindowBits; + private final boolean isServer; + private final Inflater inflater = new Inflater(true); + private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); + private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1]; + + private volatile Transformation next; + private volatile boolean skipDecompression = false; + private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private volatile boolean firstCompressedFrameWritten = false; + // Flag to track if a message is completely empty + private volatile boolean emptyMessage = true; + + static PerMessageDeflate negotiate(List> preferences, boolean isServer) { + // Accept the first preference that the endpoint is able to support + for (List preference : preferences) { + boolean ok = true; + boolean serverContextTakeover = true; + int serverMaxWindowBits = -1; + boolean clientContextTakeover = true; + int clientMaxWindowBits = -1; + + for (Parameter param : preference) { + if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) { + if (serverContextTakeover) { + serverContextTakeover = false; + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + SERVER_NO_CONTEXT_TAKEOVER )); + } + } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) { + if (clientContextTakeover) { + clientContextTakeover = false; + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + CLIENT_NO_CONTEXT_TAKEOVER )); + } + } else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) { + if (serverMaxWindowBits == -1) { + serverMaxWindowBits = Integer.parseInt(param.getValue()); + if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) { + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.invalidWindowSize", + SERVER_MAX_WINDOW_BITS, + Integer.valueOf(serverMaxWindowBits))); + } + // Java SE API (as of Java 8) does not expose the API to + // control the Window size. It is effectively hard-coded + // to 15 + if (isServer && serverMaxWindowBits != 15) { + ok = false; + break; + // Note server window size is not an issue for the + // client since the client will assume 15 and if the + // server uses a smaller window everything will + // still work + } + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + SERVER_MAX_WINDOW_BITS )); + } + } else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) { + if (clientMaxWindowBits == -1) { + if (param.getValue() == null) { + // Hint to server that the client supports this + // option. Java SE API (as of Java 8) does not + // expose the API to control the Window size. It is + // effectively hard-coded to 15 + clientMaxWindowBits = 15; + } else { + clientMaxWindowBits = Integer.parseInt(param.getValue()); + if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) { + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.invalidWindowSize", + CLIENT_MAX_WINDOW_BITS, + Integer.valueOf(clientMaxWindowBits))); + } + } + // Java SE API (as of Java 8) does not expose the API to + // control the Window size. It is effectively hard-coded + // to 15 + if (!isServer && clientMaxWindowBits != 15) { + ok = false; + break; + // Note client window size is not an issue for the + // server since the server will assume 15 and if the + // client uses a smaller window everything will + // still work + } + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + CLIENT_MAX_WINDOW_BITS )); + } + } else { + // Unknown parameter + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.unknownParameter", param.getName())); + } + } + if (ok) { + return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits, + clientContextTakeover, clientMaxWindowBits, isServer); + } + } + // Failed to negotiate agreeable terms + return null; + } + + + private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits, + boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) { + this.serverContextTakeover = serverContextTakeover; + this.serverMaxWindowBits = serverMaxWindowBits; + this.clientContextTakeover = clientContextTakeover; + this.clientMaxWindowBits = clientMaxWindowBits; + this.isServer = isServer; + } + + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) + throws IOException { + // Control frames are never compressed and may appear in the middle of + // a WebSocket method. Pass them straight through. + if (Util.isControl(opCode)) { + return next.getMoreData(opCode, fin, rsv, dest); + } + + if (!Util.isContinuation(opCode)) { + // First frame in new message + skipDecompression = (rsv & RSV_BITMASK) == 0; + } + + // Pass uncompressed frames straight through. + if (skipDecompression) { + return next.getMoreData(opCode, fin, rsv, dest); + } + + int written; + boolean usedEomBytes = false; + + while (dest.remaining() > 0) { + // Space available in destination. Try and fill it. + try { + written = inflater.inflate( + dest.array(), dest.arrayOffset() + dest.position(), dest.remaining()); + } catch (DataFormatException e) { + throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e); + } + dest.position(dest.position() + written); + + if (inflater.needsInput() && !usedEomBytes ) { + if (dest.hasRemaining()) { + readBuffer.clear(); + TransformationResult nextResult = + next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer); + inflater.setInput( + readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position()); + if (TransformationResult.UNDERFLOW.equals(nextResult)) { + return nextResult; + } else if (TransformationResult.END_OF_FRAME.equals(nextResult) && + readBuffer.position() == 0) { + if (fin) { + inflater.setInput(EOM_BYTES); + usedEomBytes = true; + } else { + return TransformationResult.END_OF_FRAME; + } + } + } + } else if (written == 0) { + if (fin && (isServer && !clientContextTakeover || + !isServer && !serverContextTakeover)) { + inflater.reset(); + } + return TransformationResult.END_OF_FRAME; + } + } + + return TransformationResult.OVERFLOW; + } + + + @Override + public boolean validateRsv(int rsv, byte opCode) { + if (Util.isControl(opCode)) { + if ((rsv & RSV_BITMASK) != 0) { + return false; + } else { + if (next == null) { + return true; + } else { + return next.validateRsv(rsv, opCode); + } + } + } else { + int rsvNext = rsv; + if ((rsv & RSV_BITMASK) != 0) { + rsvNext = rsv ^ RSV_BITMASK; + } + if (next == null) { + return true; + } else { + return next.validateRsv(rsvNext, opCode); + } + } + } + + + @Override + public Extension getExtensionResponse() { + Extension result = new WsExtension(NAME); + + List params = result.getParameters(); + + if (!serverContextTakeover) { + params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null)); + } + if (serverMaxWindowBits != -1) { + params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS, + Integer.toString(serverMaxWindowBits))); + } + if (!clientContextTakeover) { + params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null)); + } + if (clientMaxWindowBits != -1) { + params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS, + Integer.toString(clientMaxWindowBits))); + } + + return result; + } + + + @Override + public void setNext(Transformation t) { + if (next == null) { + this.next = t; + } else { + next.setNext(t); + } + } + + + @Override + public boolean validateRsvBits(int i) { + if ((i & RSV_BITMASK) != 0) { + return false; + } + if (next == null) { + return true; + } else { + return next.validateRsvBits(i | RSV_BITMASK); + } + } + + + @Override + public List sendMessagePart(List uncompressedParts) { + List allCompressedParts = new ArrayList<>(); + + for (MessagePart uncompressedPart : uncompressedParts) { + byte opCode = uncompressedPart.getOpCode(); + boolean emptyPart = uncompressedPart.getPayload().limit() == 0; + emptyMessage = emptyMessage && emptyPart; + if (Util.isControl(opCode)) { + // Control messages can appear in the middle of other messages + // and must not be compressed. Pass it straight through + allCompressedParts.add(uncompressedPart); + } else if (emptyMessage && uncompressedPart.isFin()) { + // Zero length messages can't be compressed so pass the + // final (empty) part straight through. + allCompressedParts.add(uncompressedPart); + } else { + List compressedParts = new ArrayList<>(); + ByteBuffer uncompressedPayload = uncompressedPart.getPayload(); + SendHandler uncompressedIntermediateHandler = + uncompressedPart.getIntermediateHandler(); + + deflater.setInput(uncompressedPayload.array(), + uncompressedPayload.arrayOffset() + uncompressedPayload.position(), + uncompressedPayload.remaining()); + + int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH); + boolean deflateRequired = true; + + while (deflateRequired) { + ByteBuffer compressedPayload = writeBuffer; + + int written = deflater.deflate(compressedPayload.array(), + compressedPayload.arrayOffset() + compressedPayload.position(), + compressedPayload.remaining(), flush); + compressedPayload.position(compressedPayload.position() + written); + + if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) { + // This message part has been fully processed by the + // deflater. Fire the send handler for this message part + // and move on to the next message part. + break; + } + + // If this point is reached, a new compressed message part + // will be created... + MessagePart compressedPart; + + // .. and a new writeBuffer will be required. + writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + + // Flip the compressed payload ready for writing + compressedPayload.flip(); + + boolean fin = uncompressedPart.isFin(); + boolean full = compressedPayload.limit() == compressedPayload.capacity(); + boolean needsInput = deflater.needsInput(); + long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry(); + + if (fin && !full && needsInput) { + // End of compressed message. Drop EOM bytes and output. + compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length); + compressedPart = new MessagePart(true, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + deflateRequired = false; + startNewMessage(); + } else if (full && !needsInput) { + // Write buffer full and input message not fully read. + // Output and start new compressed part. + compressedPart = new MessagePart(false, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + } else if (!fin && full && needsInput) { + // Write buffer full and input message not fully read. + // Output and get more data. + compressedPart = new MessagePart(false, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + deflateRequired = false; + } else if (fin && full && needsInput) { + // Write buffer full. Input fully read. Deflater may be + // in one of four states: + // - output complete (just happened to align with end of + // buffer + // - in middle of EOM bytes + // - about to write EOM bytes + // - more data to write + int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH); + if (eomBufferWritten < EOM_BUFFER.length) { + // EOM has just been completed + compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten); + compressedPart = new MessagePart(true, + getRsv(uncompressedPart), opCode, compressedPayload, + uncompressedIntermediateHandler, uncompressedIntermediateHandler, + blockingWriteTimeoutExpiry); + deflateRequired = false; + startNewMessage(); + } else { + // More data to write + // Copy bytes to new write buffer + writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten); + compressedPart = new MessagePart(false, + getRsv(uncompressedPart), opCode, compressedPayload, + uncompressedIntermediateHandler, uncompressedIntermediateHandler, + blockingWriteTimeoutExpiry); + } + } else { + throw new IllegalStateException("Should never happen"); + } + + // Add the newly created compressed part to the set of parts + // to pass on to the next transformation. + compressedParts.add(compressedPart); + } + + SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler(); + int size = compressedParts.size(); + if (size > 0) { + compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler); + } + + allCompressedParts.addAll(compressedParts); + } + } + + if (next == null) { + return allCompressedParts; + } else { + return next.sendMessagePart(allCompressedParts); + } + } + + + private void startNewMessage() { + firstCompressedFrameWritten = false; + emptyMessage = true; + if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) { + deflater.reset(); + } + } + + + private int getRsv(MessagePart uncompressedMessagePart) { + int result = uncompressedMessagePart.getRsv(); + if (!firstCompressedFrameWritten) { + result += RSV_BITMASK; + firstCompressedFrameWritten = true; + } + return result; + } + + + @Override + public void close() { + // There will always be a next transformation + next.close(); + inflater.end(); + deflater.end(); + } +} diff --git a/src/java/nginx/unit/websocket/ReadBufferOverflowException.java b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java new file mode 100644 index 00000000..9ce7ac27 --- /dev/null +++ b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; + +public class ReadBufferOverflowException extends IOException { + + private static final long serialVersionUID = 1L; + + private final int minBufferSize; + + public ReadBufferOverflowException(int minBufferSize) { + this.minBufferSize = minBufferSize; + } + + public int getMinBufferSize() { + return minBufferSize; + } +} diff --git a/src/java/nginx/unit/websocket/Transformation.java b/src/java/nginx/unit/websocket/Transformation.java new file mode 100644 index 00000000..45474c7d --- /dev/null +++ b/src/java/nginx/unit/websocket/Transformation.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import javax.websocket.Extension; + +/** + * The internal representation of the transformation that a WebSocket extension + * performs on a message. + */ +public interface Transformation { + + /** + * Sets the next transformation in the pipeline. + * @param t The next transformation + */ + void setNext(Transformation t); + + /** + * Validate that the RSV bit(s) required by this transformation are not + * being used by another extension. The implementation is expected to set + * any bits it requires before passing the set of in-use bits to the next + * transformation. + * + * @param i The RSV bits marked as in use so far as an int in the + * range zero to seven with RSV1 as the MSB and RSV3 as the + * LSB + * + * @return true if the combination of RSV bits used by the + * transformations in the pipeline do not conflict otherwise + * false + */ + boolean validateRsvBits(int i); + + /** + * Obtain the extension that describes the information to be returned to the + * client. + * + * @return The extension information that describes the parameters that have + * been agreed for this transformation + */ + Extension getExtensionResponse(); + + /** + * Obtain more input data. + * + * @param opCode The opcode for the frame currently being processed + * @param fin Is this the final frame in this WebSocket message? + * @param rsv The reserved bits for the frame currently being + * processed + * @param dest The buffer in which the data is to be written + * + * @return The result of trying to read more data from the transform + * + * @throws IOException If an I/O error occurs while reading data from the + * transform + */ + TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) throws IOException; + + /** + * Validates the RSV and opcode combination (assumed to have been extracted + * from a WebSocket Frame) for this extension. The implementation is + * expected to unset any RSV bits it has validated before passing the + * remaining RSV bits to the next transformation in the pipeline. + * + * @param rsv The RSV bits received as an int in the range zero to + * seven with RSV1 as the MSB and RSV3 as the LSB + * @param opCode The opCode received + * + * @return true if the RSV is valid otherwise + * false + */ + boolean validateRsv(int rsv, byte opCode); + + /** + * Takes the provided list of messages, transforms them, passes the + * transformed list on to the next transformation (if any) and then returns + * the resulting list of message parts after all of the transformations have + * been applied. + * + * @param messageParts The list of messages to be transformed + * + * @return The list of messages after this any any subsequent + * transformations have been applied. The size of the returned list + * may be bigger or smaller than the size of the input list + */ + List sendMessagePart(List messageParts); + + /** + * Clean-up any resources that were used by the transformation. + */ + void close(); +} diff --git a/src/java/nginx/unit/websocket/TransformationFactory.java b/src/java/nginx/unit/websocket/TransformationFactory.java new file mode 100644 index 00000000..fac04555 --- /dev/null +++ b/src/java/nginx/unit/websocket/TransformationFactory.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.List; + +import javax.websocket.Extension; + +import org.apache.tomcat.util.res.StringManager; + +public class TransformationFactory { + + private static final StringManager sm = StringManager.getManager(TransformationFactory.class); + + private static final TransformationFactory factory = new TransformationFactory(); + + private TransformationFactory() { + // Hide default constructor + } + + public static TransformationFactory getInstance() { + return factory; + } + + public Transformation create(String name, List> preferences, + boolean isServer) { + if (PerMessageDeflate.NAME.equals(name)) { + return PerMessageDeflate.negotiate(preferences, isServer); + } + if (Constants.ALLOW_UNSUPPORTED_EXTENSIONS) { + return null; + } else { + throw new IllegalArgumentException( + sm.getString("transformerFactory.unsupportedExtension", name)); + } + } +} diff --git a/src/java/nginx/unit/websocket/TransformationResult.java b/src/java/nginx/unit/websocket/TransformationResult.java new file mode 100644 index 00000000..0de35e55 --- /dev/null +++ b/src/java/nginx/unit/websocket/TransformationResult.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public enum TransformationResult { + /** + * The end of the available data was reached before the WebSocket frame was + * completely read. + */ + UNDERFLOW, + + /** + * The provided destination buffer was filled before all of the available + * data from the WebSocket frame could be processed. + */ + OVERFLOW, + + /** + * The end of the WebSocket frame was reached and all the data from that + * frame processed into the provided destination buffer. + */ + END_OF_FRAME +} diff --git a/src/java/nginx/unit/websocket/Util.java b/src/java/nginx/unit/websocket/Util.java new file mode 100644 index 00000000..6acf3ade --- /dev/null +++ b/src/java/nginx/unit/websocket/Util.java @@ -0,0 +1,666 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; + +import javax.websocket.CloseReason.CloseCode; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Binary; +import javax.websocket.Decoder.BinaryStream; +import javax.websocket.Decoder.Text; +import javax.websocket.Decoder.TextStream; +import javax.websocket.DeploymentException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.PongMessage; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.pojo.PojoMessageHandlerPartialBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeText; + +/** + * Utility class for internal use only within the + * {@link nginx.unit.websocket} package. + */ +public class Util { + + private static final StringManager sm = StringManager.getManager(Util.class); + private static final Queue randoms = + new ConcurrentLinkedQueue<>(); + + private Util() { + // Hide default constructor + } + + + static boolean isControl(byte opCode) { + return (opCode & 0x08) != 0; + } + + + static boolean isText(byte opCode) { + return opCode == Constants.OPCODE_TEXT; + } + + + static boolean isContinuation(byte opCode) { + return opCode == Constants.OPCODE_CONTINUATION; + } + + + static CloseCode getCloseCode(int code) { + if (code > 2999 && code < 5000) { + return CloseCodes.getCloseCode(code); + } + switch (code) { + case 1000: + return CloseCodes.NORMAL_CLOSURE; + case 1001: + return CloseCodes.GOING_AWAY; + case 1002: + return CloseCodes.PROTOCOL_ERROR; + case 1003: + return CloseCodes.CANNOT_ACCEPT; + case 1004: + // Should not be used in a close frame + // return CloseCodes.RESERVED; + return CloseCodes.PROTOCOL_ERROR; + case 1005: + // Should not be used in a close frame + // return CloseCodes.NO_STATUS_CODE; + return CloseCodes.PROTOCOL_ERROR; + case 1006: + // Should not be used in a close frame + // return CloseCodes.CLOSED_ABNORMALLY; + return CloseCodes.PROTOCOL_ERROR; + case 1007: + return CloseCodes.NOT_CONSISTENT; + case 1008: + return CloseCodes.VIOLATED_POLICY; + case 1009: + return CloseCodes.TOO_BIG; + case 1010: + return CloseCodes.NO_EXTENSION; + case 1011: + return CloseCodes.UNEXPECTED_CONDITION; + case 1012: + // Not in RFC6455 + // return CloseCodes.SERVICE_RESTART; + return CloseCodes.PROTOCOL_ERROR; + case 1013: + // Not in RFC6455 + // return CloseCodes.TRY_AGAIN_LATER; + return CloseCodes.PROTOCOL_ERROR; + case 1015: + // Should not be used in a close frame + // return CloseCodes.TLS_HANDSHAKE_FAILURE; + return CloseCodes.PROTOCOL_ERROR; + default: + return CloseCodes.PROTOCOL_ERROR; + } + } + + + static byte[] generateMask() { + // SecureRandom is not thread-safe so need to make sure only one thread + // uses it at a time. In theory, the pool could grow to the same size + // as the number of request processing threads. In reality it will be + // a lot smaller. + + // Get a SecureRandom from the pool + SecureRandom sr = randoms.poll(); + + // If one isn't available, generate a new one + if (sr == null) { + try { + sr = SecureRandom.getInstance("SHA1PRNG"); + } catch (NoSuchAlgorithmException e) { + // Fall back to platform default + sr = new SecureRandom(); + } + } + + // Generate the mask + byte[] result = new byte[4]; + sr.nextBytes(result); + + // Put the SecureRandom back in the poll + randoms.add(sr); + + return result; + } + + + static Class getMessageType(MessageHandler listener) { + return Util.getGenericType(MessageHandler.class, + listener.getClass()).getClazz(); + } + + + private static Class getDecoderType(Class decoder) { + return Util.getGenericType(Decoder.class, decoder).getClazz(); + } + + + static Class getEncoderType(Class encoder) { + return Util.getGenericType(Encoder.class, encoder).getClazz(); + } + + + private static TypeResult getGenericType(Class type, + Class clazz) { + + // Look to see if this class implements the interface of interest + + // Get all the interfaces + Type[] interfaces = clazz.getGenericInterfaces(); + for (Type iface : interfaces) { + // Only need to check interfaces that use generics + if (iface instanceof ParameterizedType) { + ParameterizedType pi = (ParameterizedType) iface; + // Look for the interface of interest + if (pi.getRawType() instanceof Class) { + if (type.isAssignableFrom((Class) pi.getRawType())) { + return getTypeParameter( + clazz, pi.getActualTypeArguments()[0]); + } + } + } + } + + // Interface not found on this class. Look at the superclass. + @SuppressWarnings("unchecked") + Class superClazz = + (Class) clazz.getSuperclass(); + if (superClazz == null) { + // Finished looking up the class hierarchy without finding anything + return null; + } + + TypeResult superClassTypeResult = getGenericType(type, superClazz); + int dimension = superClassTypeResult.getDimension(); + if (superClassTypeResult.getIndex() == -1 && dimension == 0) { + // Superclass implements interface and defines explicit type for + // the interface of interest + return superClassTypeResult; + } + + if (superClassTypeResult.getIndex() > -1) { + // Superclass implements interface and defines unknown type for + // the interface of interest + // Map that unknown type to the generic types defined in this class + ParameterizedType superClassType = + (ParameterizedType) clazz.getGenericSuperclass(); + TypeResult result = getTypeParameter(clazz, + superClassType.getActualTypeArguments()[ + superClassTypeResult.getIndex()]); + result.incrementDimension(superClassTypeResult.getDimension()); + if (result.getClazz() != null && result.getDimension() > 0) { + superClassTypeResult = result; + } else { + return result; + } + } + + if (superClassTypeResult.getDimension() > 0) { + StringBuilder className = new StringBuilder(); + for (int i = 0; i < dimension; i++) { + className.append('['); + } + className.append('L'); + className.append(superClassTypeResult.getClazz().getCanonicalName()); + className.append(';'); + + Class arrayClazz; + try { + arrayClazz = Class.forName(className.toString()); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + + return new TypeResult(arrayClazz, -1, 0); + } + + // Error will be logged further up the call stack + return null; + } + + + /* + * For a generic parameter, return either the Class used or if the type + * is unknown, the index for the type in definition of the class + */ + private static TypeResult getTypeParameter(Class clazz, Type argType) { + if (argType instanceof Class) { + return new TypeResult((Class) argType, -1, 0); + } else if (argType instanceof ParameterizedType) { + return new TypeResult((Class)((ParameterizedType) argType).getRawType(), -1, 0); + } else if (argType instanceof GenericArrayType) { + Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType(); + TypeResult result = getTypeParameter(clazz, arrayElementType); + result.incrementDimension(1); + return result; + } else { + TypeVariable[] tvs = clazz.getTypeParameters(); + for (int i = 0; i < tvs.length; i++) { + if (tvs[i].equals(argType)) { + return new TypeResult(null, i, 0); + } + } + return null; + } + } + + + public static boolean isPrimitive(Class clazz) { + if (clazz.isPrimitive()) { + return true; + } else if(clazz.equals(Boolean.class) || + clazz.equals(Byte.class) || + clazz.equals(Character.class) || + clazz.equals(Double.class) || + clazz.equals(Float.class) || + clazz.equals(Integer.class) || + clazz.equals(Long.class) || + clazz.equals(Short.class)) { + return true; + } + return false; + } + + + public static Object coerceToType(Class type, String value) { + if (type.equals(String.class)) { + return value; + } else if (type.equals(boolean.class) || type.equals(Boolean.class)) { + return Boolean.valueOf(value); + } else if (type.equals(byte.class) || type.equals(Byte.class)) { + return Byte.valueOf(value); + } else if (type.equals(char.class) || type.equals(Character.class)) { + return Character.valueOf(value.charAt(0)); + } else if (type.equals(double.class) || type.equals(Double.class)) { + return Double.valueOf(value); + } else if (type.equals(float.class) || type.equals(Float.class)) { + return Float.valueOf(value); + } else if (type.equals(int.class) || type.equals(Integer.class)) { + return Integer.valueOf(value); + } else if (type.equals(long.class) || type.equals(Long.class)) { + return Long.valueOf(value); + } else if (type.equals(short.class) || type.equals(Short.class)) { + return Short.valueOf(value); + } else { + throw new IllegalArgumentException(sm.getString( + "util.invalidType", value, type.getName())); + } + } + + + public static List getDecoders( + List> decoderClazzes) + throws DeploymentException{ + + List result = new ArrayList<>(); + if (decoderClazzes != null) { + for (Class decoderClazz : decoderClazzes) { + // Need to instantiate decoder to ensure it is valid and that + // deployment can be failed if it is not + @SuppressWarnings("unused") + Decoder instance; + try { + instance = decoderClazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("pojoMethodMapping.invalidDecoder", + decoderClazz.getName()), e); + } + DecoderEntry entry = new DecoderEntry( + Util.getDecoderType(decoderClazz), decoderClazz); + result.add(entry); + } + } + + return result; + } + + + static Set getMessageHandlers(Class target, + MessageHandler listener, EndpointConfig endpointConfig, + Session session) { + + // Will never be more than 2 types + Set results = new HashSet<>(2); + + // Simple cases - handlers already accepts one of the types expected by + // the frame handling code + if (String.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.TEXT); + results.add(result); + } else if (ByteBuffer.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.BINARY); + results.add(result); + } else if (PongMessage.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.PONG); + results.add(result); + // Handler needs wrapping and optional decoder to convert it to one of + // the types expected by the frame handling code + } else if (byte[].class.isAssignableFrom(target)) { + boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass()); + MessageHandlerResult result = new MessageHandlerResult( + whole ? new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, false, -1) : + new PojoMessageHandlerPartialBinary(listener, + getOnMessagePartialMethod(listener), session, + new Object[2], 0, true, 1, -1, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (InputStream.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, true, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (Reader.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, false), + new Object[1], 0, true, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } else { + // Handler needs wrapping and requires decoder to convert it to one + // of the types expected by the frame handling code + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + Method m = getOnMessageMethod(listener); + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, m, session, + endpointConfig, + decoderMatch.getBinaryDecoders(), new Object[1], + 0, false, -1, false, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, m, session, + endpointConfig, + decoderMatch.getTextDecoders(), new Object[1], + 0, false, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } + } + + if (results.size() == 0) { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandler", listener, target)); + } + + return results; + } + + private static List> matchDecoders(Class target, + EndpointConfig endpointConfig, boolean binary) { + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + if (binary) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + return decoderMatch.getBinaryDecoders(); + } + } else if (decoderMatch.getTextDecoders().size() > 0) { + return decoderMatch.getTextDecoders(); + } + return null; + } + + private static DecoderMatch matchDecoders(Class target, + EndpointConfig endpointConfig) { + DecoderMatch decoderMatch; + try { + List> decoders = + endpointConfig.getDecoders(); + List decoderEntries = getDecoders(decoders); + decoderMatch = new DecoderMatch(target, decoderEntries); + } catch (DeploymentException e) { + throw new IllegalArgumentException(e); + } + return decoderMatch; + } + + public static void parseExtensionHeader(List extensions, + String header) { + // The relevant ABNF for the Sec-WebSocket-Extensions is as follows: + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ; When using the quoted-string syntax variant, the value + // ; after quoted-string unescaping MUST conform to the + // ; 'token' ABNF. + // + // The limiting of parameter values to tokens or "quoted tokens" makes + // the parsing of the header significantly simpler and allows a number + // of short-cuts to be taken. + + // Step one, split the header into individual extensions using ',' as a + // separator + String unparsedExtensions[] = header.split(","); + for (String unparsedExtension : unparsedExtensions) { + // Step two, split the extension into the registered name and + // parameter/value pairs using ';' as a separator + String unparsedParameters[] = unparsedExtension.split(";"); + WsExtension extension = new WsExtension(unparsedParameters[0].trim()); + + for (int i = 1; i < unparsedParameters.length; i++) { + int equalsPos = unparsedParameters[i].indexOf('='); + String name; + String value; + if (equalsPos == -1) { + name = unparsedParameters[i].trim(); + value = null; + } else { + name = unparsedParameters[i].substring(0, equalsPos).trim(); + value = unparsedParameters[i].substring(equalsPos + 1).trim(); + int len = value.length(); + if (len > 1) { + if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') { + value = value.substring(1, value.length() - 1); + } + } + } + // Make sure value doesn't contain any of the delimiters since + // that would indicate something went wrong + if (containsDelims(name) || containsDelims(value)) { + throw new IllegalArgumentException(sm.getString( + "util.notToken", name, value)); + } + if (value != null && + (value.indexOf(',') > -1 || value.indexOf(';') > -1 || + value.indexOf('\"') > -1 || value.indexOf('=') > -1)) { + throw new IllegalArgumentException(sm.getString("", value)); + } + extension.addParameter(new WsExtensionParameter(name, value)); + } + extensions.add(extension); + } + } + + + private static boolean containsDelims(String input) { + if (input == null || input.length() == 0) { + return false; + } + for (char c : input.toCharArray()) { + switch (c) { + case ',': + case ';': + case '\"': + case '=': + return true; + default: + // NO_OP + } + + } + return false; + } + + private static Method getOnMessageMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + private static Method getOnMessagePartialMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + + public static class DecoderMatch { + + private final List> textDecoders = + new ArrayList<>(); + private final List> binaryDecoders = + new ArrayList<>(); + private final Class target; + + public DecoderMatch(Class target, List decoderEntries) { + this.target = target; + for (DecoderEntry decoderEntry : decoderEntries) { + if (decoderEntry.getClazz().isAssignableFrom(target)) { + if (Binary.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (BinaryStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else if (Text.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (TextStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else { + throw new IllegalArgumentException( + sm.getString("util.unknownDecoderType")); + } + } + } + } + + + public List> getTextDecoders() { + return textDecoders; + } + + + public List> getBinaryDecoders() { + return binaryDecoders; + } + + + public Class getTarget() { + return target; + } + + + public boolean hasMatches() { + return (textDecoders.size() > 0) || (binaryDecoders.size() > 0); + } + } + + + private static class TypeResult { + private final Class clazz; + private final int index; + private int dimension; + + public TypeResult(Class clazz, int index, int dimension) { + this.clazz= clazz; + this.index = index; + this.dimension = dimension; + } + + public Class getClazz() { + return clazz; + } + + public int getIndex() { + return index; + } + + public int getDimension() { + return dimension; + } + + public void incrementDimension(int inc) { + dimension += inc; + } + } +} diff --git a/src/java/nginx/unit/websocket/WrappedMessageHandler.java b/src/java/nginx/unit/websocket/WrappedMessageHandler.java new file mode 100644 index 00000000..2557a73e --- /dev/null +++ b/src/java/nginx/unit/websocket/WrappedMessageHandler.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.MessageHandler; + +public interface WrappedMessageHandler { + long getMaxMessageSize(); + + MessageHandler getWrappedHandler(); +} diff --git a/src/java/nginx/unit/websocket/WsContainerProvider.java b/src/java/nginx/unit/websocket/WsContainerProvider.java new file mode 100644 index 00000000..f8a404a1 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsContainerProvider.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.ContainerProvider; +import javax.websocket.WebSocketContainer; + +public class WsContainerProvider extends ContainerProvider { + + @Override + protected WebSocketContainer getContainer() { + return new WsWebSocketContainer(); + } +} diff --git a/src/java/nginx/unit/websocket/WsExtension.java b/src/java/nginx/unit/websocket/WsExtension.java new file mode 100644 index 00000000..3846feb1 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsExtension.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.Extension; + +public class WsExtension implements Extension { + + private final String name; + private final List parameters = new ArrayList<>(); + + WsExtension(String name) { + this.name = name; + } + + void addParameter(Parameter parameter) { + parameters.add(parameter); + } + + @Override + public String getName() { + return name; + } + + @Override + public List getParameters() { + return parameters; + } +} diff --git a/src/java/nginx/unit/websocket/WsExtensionParameter.java b/src/java/nginx/unit/websocket/WsExtensionParameter.java new file mode 100644 index 00000000..9b82f1c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsExtensionParameter.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.Extension.Parameter; + +public class WsExtensionParameter implements Parameter { + + private final String name; + private final String value; + + WsExtensionParameter(String name, String value) { + this.name = name; + this.value = value; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getValue() { + return value; + } +} diff --git a/src/java/nginx/unit/websocket/WsFrameBase.java b/src/java/nginx/unit/websocket/WsFrameBase.java new file mode 100644 index 00000000..06d20bf4 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsFrameBase.java @@ -0,0 +1,1010 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.util.List; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.PongMessage; + +import org.apache.juli.logging.Log; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.buf.Utf8Decoder; +import org.apache.tomcat.util.res.StringManager; + +/** + * Takes the ServletInputStream, processes the WebSocket frames it contains and + * extracts the messages. WebSocket Pings received will be responded to + * automatically without any action required by the application. + */ +public abstract class WsFrameBase { + + private static final StringManager sm = StringManager.getManager(WsFrameBase.class); + + // Connection level attributes + protected final WsSession wsSession; + protected final ByteBuffer inputBuffer; + private final Transformation transformation; + + // Attributes for control messages + // Control messages can appear in the middle of other messages so need + // separate attributes + private final ByteBuffer controlBufferBinary = ByteBuffer.allocate(125); + private final CharBuffer controlBufferText = CharBuffer.allocate(125); + + // Attributes of the current message + private final CharsetDecoder utf8DecoderControl = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + private boolean continuationExpected = false; + private boolean textMessage = false; + private ByteBuffer messageBufferBinary; + private CharBuffer messageBufferText; + // Cache the message handler in force when the message starts so it is used + // consistently for the entire message + private MessageHandler binaryMsgHandler = null; + private MessageHandler textMsgHandler = null; + + // Attributes of the current frame + private boolean fin = false; + private int rsv = 0; + private byte opCode = 0; + private final byte[] mask = new byte[4]; + private int maskIndex = 0; + private long payloadLength = 0; + private volatile long payloadWritten = 0; + + // Attributes tracking state + private volatile State state = State.NEW_FRAME; + private volatile boolean open = true; + + private static final AtomicReferenceFieldUpdater READ_STATE_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(WsFrameBase.class, ReadState.class, "readState"); + private volatile ReadState readState = ReadState.WAITING; + + public WsFrameBase(WsSession wsSession, Transformation transformation) { + inputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + inputBuffer.position(0).limit(0); + messageBufferBinary = ByteBuffer.allocate(wsSession.getMaxBinaryMessageBufferSize()); + messageBufferText = CharBuffer.allocate(wsSession.getMaxTextMessageBufferSize()); + this.wsSession = wsSession; + Transformation finalTransformation; + if (isMasked()) { + finalTransformation = new UnmaskTransformation(); + } else { + finalTransformation = new NoopTransformation(); + } + if (transformation == null) { + this.transformation = finalTransformation; + } else { + transformation.setNext(finalTransformation); + this.transformation = transformation; + } + } + + + protected void processInputBuffer() throws IOException { + while (!isSuspended()) { + wsSession.updateLastActive(); + if (state == State.NEW_FRAME) { + if (!processInitialHeader()) { + break; + } + // If a close frame has been received, no further data should + // have seen + if (!open) { + throw new IOException(sm.getString("wsFrame.closed")); + } + } + if (state == State.PARTIAL_HEADER) { + if (!processRemainingHeader()) { + break; + } + } + if (state == State.DATA) { + if (!processData()) { + break; + } + } + } + } + + + /** + * @return true if sufficient data was present to process all + * of the initial header + */ + private boolean processInitialHeader() throws IOException { + // Need at least two bytes of data to do this + if (inputBuffer.remaining() < 2) { + return false; + } + int b = inputBuffer.get(); + fin = (b & 0x80) != 0; + rsv = (b & 0x70) >>> 4; + opCode = (byte) (b & 0x0F); + if (!transformation.validateRsv(rsv, opCode)) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.wrongRsv", Integer.valueOf(rsv), Integer.valueOf(opCode)))); + } + + if (Util.isControl(opCode)) { + if (!fin) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlFragmented"))); + } + if (opCode != Constants.OPCODE_PING && + opCode != Constants.OPCODE_PONG && + opCode != Constants.OPCODE_CLOSE) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + } else { + if (continuationExpected) { + if (!Util.isContinuation(opCode)) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.noContinuation"))); + } + } else { + try { + if (opCode == Constants.OPCODE_BINARY) { + // New binary message + textMessage = false; + int size = wsSession.getMaxBinaryMessageBufferSize(); + if (size != messageBufferBinary.capacity()) { + messageBufferBinary = ByteBuffer.allocate(size); + } + binaryMsgHandler = wsSession.getBinaryMessageHandler(); + textMsgHandler = null; + } else if (opCode == Constants.OPCODE_TEXT) { + // New text message + textMessage = true; + int size = wsSession.getMaxTextMessageBufferSize(); + if (size != messageBufferText.capacity()) { + messageBufferText = CharBuffer.allocate(size); + } + binaryMsgHandler = null; + textMsgHandler = wsSession.getTextMessageHandler(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + } catch (IllegalStateException ise) { + // Thrown if the session is already closed + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.sessionClosed"))); + } + } + continuationExpected = !fin; + } + b = inputBuffer.get(); + // Client data must be masked + if ((b & 0x80) == 0 && isMasked()) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.notMasked"))); + } + payloadLength = b & 0x7F; + state = State.PARTIAL_HEADER; + if (getLog().isDebugEnabled()) { + getLog().debug(sm.getString("wsFrame.partialHeaderComplete", Boolean.toString(fin), + Integer.toString(rsv), Integer.toString(opCode), Long.toString(payloadLength))); + } + return true; + } + + + protected abstract boolean isMasked(); + protected abstract Log getLog(); + + + /** + * @return true if sufficient data was present to complete the + * processing of the header + */ + private boolean processRemainingHeader() throws IOException { + // Ignore the 2 bytes already read. 4 for the mask + int headerLength; + if (isMasked()) { + headerLength = 4; + } else { + headerLength = 0; + } + // Add additional bytes depending on length + if (payloadLength == 126) { + headerLength += 2; + } else if (payloadLength == 127) { + headerLength += 8; + } + if (inputBuffer.remaining() < headerLength) { + return false; + } + // Calculate new payload length if necessary + if (payloadLength == 126) { + payloadLength = byteArrayToLong(inputBuffer.array(), + inputBuffer.arrayOffset() + inputBuffer.position(), 2); + inputBuffer.position(inputBuffer.position() + 2); + } else if (payloadLength == 127) { + payloadLength = byteArrayToLong(inputBuffer.array(), + inputBuffer.arrayOffset() + inputBuffer.position(), 8); + inputBuffer.position(inputBuffer.position() + 8); + } + if (Util.isControl(opCode)) { + if (payloadLength > 125) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlPayloadTooBig", Long.valueOf(payloadLength)))); + } + if (!fin) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlNoFin"))); + } + } + if (isMasked()) { + inputBuffer.get(mask, 0, 4); + } + state = State.DATA; + return true; + } + + + private boolean processData() throws IOException { + boolean result; + if (Util.isControl(opCode)) { + result = processDataControl(); + } else if (textMessage) { + if (textMsgHandler == null) { + result = swallowInput(); + } else { + result = processDataText(); + } + } else { + if (binaryMsgHandler == null) { + result = swallowInput(); + } else { + result = processDataBinary(); + } + } + checkRoomPayload(); + return result; + } + + + private boolean processDataControl() throws IOException { + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, controlBufferBinary); + if (TransformationResult.UNDERFLOW.equals(tr)) { + return false; + } + // Control messages have fixed message size so + // TransformationResult.OVERFLOW is not possible here + + controlBufferBinary.flip(); + if (opCode == Constants.OPCODE_CLOSE) { + open = false; + String reason = null; + int code = CloseCodes.NORMAL_CLOSURE.getCode(); + if (controlBufferBinary.remaining() == 1) { + controlBufferBinary.clear(); + // Payload must be zero or 2+ bytes long + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.oneByteCloseCode"))); + } + if (controlBufferBinary.remaining() > 1) { + code = controlBufferBinary.getShort(); + if (controlBufferBinary.remaining() > 0) { + CoderResult cr = utf8DecoderControl.decode(controlBufferBinary, + controlBufferText, true); + if (cr.isError()) { + controlBufferBinary.clear(); + controlBufferText.clear(); + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidUtf8Close"))); + } + // There will be no overflow as the output buffer is big + // enough. There will be no underflow as all the data is + // passed to the decoder in a single call. + controlBufferText.flip(); + reason = controlBufferText.toString(); + } + } + wsSession.onClose(new CloseReason(Util.getCloseCode(code), reason)); + } else if (opCode == Constants.OPCODE_PING) { + if (wsSession.isOpen()) { + wsSession.getBasicRemote().sendPong(controlBufferBinary); + } + } else if (opCode == Constants.OPCODE_PONG) { + MessageHandler.Whole mhPong = wsSession.getPongMessageHandler(); + if (mhPong != null) { + try { + mhPong.onMessage(new WsPongMessage(controlBufferBinary)); + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + controlBufferBinary.clear(); + } + } + } else { + // Should have caught this earlier but just in case... + controlBufferBinary.clear(); + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + controlBufferBinary.clear(); + newFrame(); + return true; + } + + + @SuppressWarnings("unchecked") + protected void sendMessageText(boolean last) throws WsIOException { + if (textMsgHandler instanceof WrappedMessageHandler) { + long maxMessageSize = ((WrappedMessageHandler) textMsgHandler).getMaxMessageSize(); + if (maxMessageSize > -1 && messageBufferText.remaining() > maxMessageSize) { + throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.messageTooBig", + Long.valueOf(messageBufferText.remaining()), + Long.valueOf(maxMessageSize)))); + } + } + + try { + if (textMsgHandler instanceof MessageHandler.Partial) { + ((MessageHandler.Partial) textMsgHandler) + .onMessage(messageBufferText.toString(), last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole) textMsgHandler) + .onMessage(messageBufferText.toString()); + } + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + messageBufferText.clear(); + } + } + + + private boolean processDataText() throws IOException { + // Copy the available data to the buffer + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + while (!TransformationResult.END_OF_FRAME.equals(tr)) { + // Frame not complete - we ran out of something + // Convert bytes to UTF-8 + messageBufferBinary.flip(); + while (true) { + CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText, + false); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow()) { + // Compact what we have to create as much space as possible + messageBufferBinary.compact(); + + // Need more input + // What did we run out of? + if (TransformationResult.OVERFLOW.equals(tr)) { + // Ran out of message buffer - exit inner loop and + // refill + break; + } else { + // TransformationResult.UNDERFLOW + // Ran out of input data - get some more + return false; + } + } + } + // Read more input data + tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + } + + messageBufferBinary.flip(); + boolean last = false; + // Frame is fully received + // Convert bytes to UTF-8 + while (true) { + CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText, + last); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow() && !last) { + // End of frame and possible message as well. + + if (continuationExpected) { + // If partial messages are supported, send what we have + // managed to decode + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } + messageBufferBinary.compact(); + newFrame(); + // Process next frame + return true; + } else { + // Make sure coder has flushed all output + last = true; + } + } else { + // End of message + messageBufferText.flip(); + sendMessageText(true); + + newMessage(); + return true; + } + } + } + + + private boolean processDataBinary() throws IOException { + // Copy the available data to the buffer + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + while (!TransformationResult.END_OF_FRAME.equals(tr)) { + // Frame not complete - what did we run out of? + if (TransformationResult.UNDERFLOW.equals(tr)) { + // Ran out of input data - get some more + return false; + } + + // Ran out of message buffer - flush it + if (!usePartial()) { + CloseReason cr = new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.bufferTooSmall", + Integer.valueOf(messageBufferBinary.capacity()), + Long.valueOf(payloadLength))); + throw new WsIOException(cr); + } + messageBufferBinary.flip(); + ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit()); + copy.put(messageBufferBinary); + copy.flip(); + sendMessageBinary(copy, false); + messageBufferBinary.clear(); + // Read more data + tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + } + + // Frame is fully received + // Send the message if either: + // - partial messages are supported + // - the message is complete + if (usePartial() || !continuationExpected) { + messageBufferBinary.flip(); + ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit()); + copy.put(messageBufferBinary); + copy.flip(); + sendMessageBinary(copy, !continuationExpected); + messageBufferBinary.clear(); + } + + if (continuationExpected) { + // More data for this message expected, start a new frame + newFrame(); + } else { + // Message is complete, start a new message + newMessage(); + } + + return true; + } + + + private void handleThrowableOnSend(Throwable t) throws WsIOException { + ExceptionUtils.handleThrowable(t); + wsSession.getLocal().onError(wsSession, t); + CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, + sm.getString("wsFrame.ioeTriggeredClose")); + throw new WsIOException(cr); + } + + + @SuppressWarnings("unchecked") + protected void sendMessageBinary(ByteBuffer msg, boolean last) throws WsIOException { + if (binaryMsgHandler instanceof WrappedMessageHandler) { + long maxMessageSize = ((WrappedMessageHandler) binaryMsgHandler).getMaxMessageSize(); + if (maxMessageSize > -1 && msg.remaining() > maxMessageSize) { + throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.messageTooBig", + Long.valueOf(msg.remaining()), + Long.valueOf(maxMessageSize)))); + } + } + try { + if (binaryMsgHandler instanceof MessageHandler.Partial) { + ((MessageHandler.Partial) binaryMsgHandler).onMessage(msg, last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole) binaryMsgHandler).onMessage(msg); + } + } catch (Throwable t) { + handleThrowableOnSend(t); + } + } + + + private void newMessage() { + messageBufferBinary.clear(); + messageBufferText.clear(); + utf8DecoderMessage.reset(); + continuationExpected = false; + newFrame(); + } + + + private void newFrame() { + if (inputBuffer.remaining() == 0) { + inputBuffer.position(0).limit(0); + } + + maskIndex = 0; + payloadWritten = 0; + state = State.NEW_FRAME; + + // These get reset in processInitialHeader() + // fin, rsv, opCode, payloadLength, mask + + checkRoomHeaders(); + } + + + private void checkRoomHeaders() { + // Is the start of the current frame too near the end of the input + // buffer? + if (inputBuffer.capacity() - inputBuffer.position() < 131) { + // Limit based on a control frame with a full payload + makeRoom(); + } + } + + + private void checkRoomPayload() { + if (inputBuffer.capacity() - inputBuffer.position() - payloadLength + payloadWritten < 0) { + makeRoom(); + } + } + + + private void makeRoom() { + inputBuffer.compact(); + inputBuffer.flip(); + } + + + private boolean usePartial() { + if (Util.isControl(opCode)) { + return false; + } else if (textMessage) { + return textMsgHandler instanceof MessageHandler.Partial; + } else { + // Must be binary + return binaryMsgHandler instanceof MessageHandler.Partial; + } + } + + + private boolean swallowInput() { + long toSkip = Math.min(payloadLength - payloadWritten, inputBuffer.remaining()); + inputBuffer.position(inputBuffer.position() + (int) toSkip); + payloadWritten += toSkip; + if (payloadWritten == payloadLength) { + if (continuationExpected) { + newFrame(); + } else { + newMessage(); + } + return true; + } else { + return false; + } + } + + + protected static long byteArrayToLong(byte[] b, int start, int len) throws IOException { + if (len > 8) { + throw new IOException(sm.getString("wsFrame.byteToLongFail", Long.valueOf(len))); + } + int shift = 0; + long result = 0; + for (int i = start + len - 1; i >= start; i--) { + result = result + ((b[i] & 0xFF) << shift); + shift += 8; + } + return result; + } + + + protected boolean isOpen() { + return open; + } + + + protected Transformation getTransformation() { + return transformation; + } + + + private enum State { + NEW_FRAME, PARTIAL_HEADER, DATA + } + + + /** + * WAITING - not suspended + * Server case: waiting for a notification that data + * is ready to be read from the socket, the socket is + * registered to the poller + * Client case: data has been read from the socket and + * is waiting for data to be processed + * PROCESSING - not suspended + * Server case: reading from the socket and processing + * the data + * Client case: processing the data if such has + * already been read and more data will be read from + * the socket + * SUSPENDING_WAIT - suspended, a call to suspend() was made while in + * WAITING state. A call to resume() will do nothing + * and will transition to WAITING state + * SUSPENDING_PROCESS - suspended, a call to suspend() was made while in + * PROCESSING state. A call to resume() will do + * nothing and will transition to PROCESSING state + * SUSPENDED - suspended + * Server case: processing data finished + * (SUSPENDING_PROCESS) / a notification was received + * that data is ready to be read from the socket + * (SUSPENDING_WAIT), socket is not registered to the + * poller + * Client case: processing data finished + * (SUSPENDING_PROCESS) / data has been read from the + * socket and is available for processing + * (SUSPENDING_WAIT) + * A call to resume() will: + * Server case: register the socket to the poller + * Client case: resume data processing + * CLOSING - not suspended, a close will be send + * + *
+     *     resume           data to be        resume
+     *     no action        processed         no action
+     *  |---------------| |---------------| |----------|
+     *  |               v |               v v          |
+     *  |  |----------WAITING --------PROCESSING----|  |
+     *  |  |             ^   processing             |  |
+     *  |  |             |   finished               |  |
+     *  |  |             |                          |  |
+     *  | suspend        |                     suspend |
+     *  |  |             |                          |  |
+     *  |  |          resume                        |  |
+     *  |  |    register socket to poller (server)  |  |
+     *  |  |    resume data processing (client)     |  |
+     *  |  |             |                          |  |
+     *  |  v             |                          v  |
+     * SUSPENDING_WAIT   |                  SUSPENDING_PROCESS
+     *  |                |                             |
+     *  | data available |        processing finished  |
+     *  |------------- SUSPENDED ----------------------|
+     * 
+ */ + protected enum ReadState { + WAITING (false), + PROCESSING (false), + SUSPENDING_WAIT (true), + SUSPENDING_PROCESS(true), + SUSPENDED (true), + CLOSING (false); + + private final boolean isSuspended; + + ReadState(boolean isSuspended) { + this.isSuspended = isSuspended; + } + + public boolean isSuspended() { + return isSuspended; + } + } + + public void suspend() { + while (true) { + switch (readState) { + case WAITING: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.WAITING, + ReadState.SUSPENDING_WAIT)) { + continue; + } + return; + case PROCESSING: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.PROCESSING, + ReadState.SUSPENDING_PROCESS)) { + continue; + } + return; + case SUSPENDING_WAIT: + if (readState != ReadState.SUSPENDING_WAIT) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.suspendRequested")); + } + } + return; + case SUSPENDING_PROCESS: + if (readState != ReadState.SUSPENDING_PROCESS) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.suspendRequested")); + } + } + return; + case SUSPENDED: + if (readState != ReadState.SUSPENDED) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadySuspended")); + } + } + return; + case CLOSING: + return; + default: + throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state)); + } + } + } + + public void resume() { + while (true) { + switch (readState) { + case WAITING: + if (readState != ReadState.WAITING) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadyResumed")); + } + } + return; + case PROCESSING: + if (readState != ReadState.PROCESSING) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadyResumed")); + } + } + return; + case SUSPENDING_WAIT: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_WAIT, + ReadState.WAITING)) { + continue; + } + return; + case SUSPENDING_PROCESS: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_PROCESS, + ReadState.PROCESSING)) { + continue; + } + return; + case SUSPENDED: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDED, + ReadState.WAITING)) { + continue; + } + resumeProcessing(); + return; + case CLOSING: + return; + default: + throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state)); + } + } + } + + protected boolean isSuspended() { + return readState.isSuspended(); + } + + protected ReadState getReadState() { + return readState; + } + + protected void changeReadState(ReadState newState) { + READ_STATE_UPDATER.set(this, newState); + } + + protected boolean changeReadState(ReadState oldState, ReadState newState) { + return READ_STATE_UPDATER.compareAndSet(this, oldState, newState); + } + + /** + * This method will be invoked when the read operation is resumed. + * As the suspend of the read operation can be invoked at any time, when + * implementing this method one should consider that there might still be + * data remaining into the internal buffers that needs to be processed + * before reading again from the socket. + */ + protected abstract void resumeProcessing(); + + + private abstract class TerminalTransformation implements Transformation { + + @Override + public boolean validateRsvBits(int i) { + // Terminal transformations don't use RSV bits and there is no next + // transformation so always return true. + return true; + } + + @Override + public Extension getExtensionResponse() { + // Return null since terminal transformations are not extensions + return null; + } + + @Override + public void setNext(Transformation t) { + // NO-OP since this is the terminal transformation + } + + /** + * {@inheritDoc} + *

+ * Anything other than a value of zero for rsv is invalid. + */ + @Override + public boolean validateRsv(int rsv, byte opCode) { + return rsv == 0; + } + + @Override + public void close() { + // NO-OP for the terminal transformations + } + } + + + /** + * For use by the client implementation that needs to obtain payload data + * without the need for unmasking. + */ + private final class NoopTransformation extends TerminalTransformation { + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, + ByteBuffer dest) { + // opCode is ignored as the transformation is the same for all + // opCodes + // rsv is ignored as it known to be zero at this point + long toWrite = Math.min(payloadLength - payloadWritten, inputBuffer.remaining()); + toWrite = Math.min(toWrite, dest.remaining()); + + int orgLimit = inputBuffer.limit(); + inputBuffer.limit(inputBuffer.position() + (int) toWrite); + dest.put(inputBuffer); + inputBuffer.limit(orgLimit); + payloadWritten += toWrite; + + if (payloadWritten == payloadLength) { + return TransformationResult.END_OF_FRAME; + } else if (inputBuffer.remaining() == 0) { + return TransformationResult.UNDERFLOW; + } else { + // !dest.hasRemaining() + return TransformationResult.OVERFLOW; + } + } + + + @Override + public List sendMessagePart(List messageParts) { + // TODO Masking should move to this method + // NO-OP send so simply return the message unchanged. + return messageParts; + } + } + + + /** + * For use by the server implementation that needs to obtain payload data + * and unmask it before any further processing. + */ + private final class UnmaskTransformation extends TerminalTransformation { + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, + ByteBuffer dest) { + // opCode is ignored as the transformation is the same for all + // opCodes + // rsv is ignored as it known to be zero at this point + while (payloadWritten < payloadLength && inputBuffer.remaining() > 0 && + dest.hasRemaining()) { + byte b = (byte) ((inputBuffer.get() ^ mask[maskIndex]) & 0xFF); + maskIndex++; + if (maskIndex == 4) { + maskIndex = 0; + } + payloadWritten++; + dest.put(b); + } + if (payloadWritten == payloadLength) { + return TransformationResult.END_OF_FRAME; + } else if (inputBuffer.remaining() == 0) { + return TransformationResult.UNDERFLOW; + } else { + // !dest.hasRemaining() + return TransformationResult.OVERFLOW; + } + } + + @Override + public List sendMessagePart(List messageParts) { + // NO-OP send so simply return the message unchanged. + return messageParts; + } + } +} diff --git a/src/java/nginx/unit/websocket/WsFrameClient.java b/src/java/nginx/unit/websocket/WsFrameClient.java new file mode 100644 index 00000000..3174c766 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsFrameClient.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +public class WsFrameClient extends WsFrameBase { + + private final Log log = LogFactory.getLog(WsFrameClient.class); // must not be static + private static final StringManager sm = StringManager.getManager(WsFrameClient.class); + + private final AsyncChannelWrapper channel; + private final CompletionHandler handler; + // Not final as it may need to be re-sized + private volatile ByteBuffer response; + + public WsFrameClient(ByteBuffer response, AsyncChannelWrapper channel, WsSession wsSession, + Transformation transformation) { + super(wsSession, transformation); + this.response = response; + this.channel = channel; + this.handler = new WsFrameClientCompletionHandler(); + } + + + void startInputProcessing() { + try { + processSocketRead(); + } catch (IOException e) { + close(e); + } + } + + + private void processSocketRead() throws IOException { + while (true) { + switch (getReadState()) { + case WAITING: + if (!changeReadState(ReadState.WAITING, ReadState.PROCESSING)) { + continue; + } + while (response.hasRemaining()) { + if (isSuspended()) { + if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) { + continue; + } + // There is still data available in the response buffer + // Return here so that the response buffer will not be + // cleared and there will be no data read from the + // socket. Thus when the read operation is resumed first + // the data left in the response buffer will be consumed + // and then a new socket read will be performed + return; + } + inputBuffer.mark(); + inputBuffer.position(inputBuffer.limit()).limit(inputBuffer.capacity()); + + int toCopy = Math.min(response.remaining(), inputBuffer.remaining()); + + // Copy remaining bytes read in HTTP phase to input buffer used by + // frame processing + + int orgLimit = response.limit(); + response.limit(response.position() + toCopy); + inputBuffer.put(response); + response.limit(orgLimit); + + inputBuffer.limit(inputBuffer.position()).reset(); + + // Process the data we have + processInputBuffer(); + } + response.clear(); + + // Get some more data + if (isOpen()) { + channel.read(response, null, handler); + } else { + changeReadState(ReadState.CLOSING); + } + return; + case SUSPENDING_WAIT: + if (!changeReadState(ReadState.SUSPENDING_WAIT, ReadState.SUSPENDED)) { + continue; + } + return; + default: + throw new IllegalStateException( + sm.getString("wsFrameServer.illegalReadState", getReadState())); + } + } + } + + + private final void close(Throwable t) { + changeReadState(ReadState.CLOSING); + CloseReason cr; + if (t instanceof WsIOException) { + cr = ((WsIOException) t).getCloseReason(); + } else { + cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, t.getMessage()); + } + + try { + wsSession.close(cr); + } catch (IOException ignore) { + // Ignore + } + } + + + @Override + protected boolean isMasked() { + // Data is from the server so it is not masked + return false; + } + + + @Override + protected Log getLog() { + return log; + } + + private class WsFrameClientCompletionHandler implements CompletionHandler { + + @Override + public void completed(Integer result, Void attachment) { + if (result.intValue() == -1) { + // BZ 57762. A dropped connection will get reported as EOF + // rather than as an error so handle it here. + if (isOpen()) { + // No close frame was received + close(new EOFException()); + } + // No data to process + return; + } + response.flip(); + doResumeProcessing(true); + } + + @Override + public void failed(Throwable exc, Void attachment) { + if (exc instanceof ReadBufferOverflowException) { + // response will be empty if this exception is thrown + response = ByteBuffer + .allocate(((ReadBufferOverflowException) exc).getMinBufferSize()); + response.flip(); + doResumeProcessing(false); + } else { + close(exc); + } + } + + private void doResumeProcessing(boolean checkOpenOnError) { + while (true) { + switch (getReadState()) { + case PROCESSING: + if (!changeReadState(ReadState.PROCESSING, ReadState.WAITING)) { + continue; + } + resumeProcessing(checkOpenOnError); + return; + case SUSPENDING_PROCESS: + if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) { + continue; + } + return; + default: + throw new IllegalStateException( + sm.getString("wsFrame.illegalReadState", getReadState())); + } + } + } + } + + + @Override + protected void resumeProcessing() { + resumeProcessing(true); + } + + private void resumeProcessing(boolean checkOpenOnError) { + try { + processSocketRead(); + } catch (IOException e) { + if (checkOpenOnError) { + // Only send a close message on an IOException if the client + // has not yet received a close control message from the server + // as the IOException may be in response to the client + // continuing to send a message after the server sent a close + // control message. + if (isOpen()) { + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsFrameClient.ioe"), e); + } + close(e); + } + } else { + close(e); + } + } + } +} diff --git a/src/java/nginx/unit/websocket/WsHandshakeResponse.java b/src/java/nginx/unit/websocket/WsHandshakeResponse.java new file mode 100644 index 00000000..6e57ffd5 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsHandshakeResponse.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.websocket.HandshakeResponse; + +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; + +/** + * Represents the response to a WebSocket handshake. + */ +public class WsHandshakeResponse implements HandshakeResponse { + + private final Map> headers = new CaseInsensitiveKeyMap<>(); + + + public WsHandshakeResponse() { + } + + + public WsHandshakeResponse(Map> headers) { + for (Entry> entry : headers.entrySet()) { + if (this.headers.containsKey(entry.getKey())) { + this.headers.get(entry.getKey()).addAll(entry.getValue()); + } else { + List values = new ArrayList<>(entry.getValue()); + this.headers.put(entry.getKey(), values); + } + } + } + + + @Override + public Map> getHeaders() { + return headers; + } +} diff --git a/src/java/nginx/unit/websocket/WsIOException.java b/src/java/nginx/unit/websocket/WsIOException.java new file mode 100644 index 00000000..0362dc1d --- /dev/null +++ b/src/java/nginx/unit/websocket/WsIOException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; + +import javax.websocket.CloseReason; + +/** + * Allows the WebSocket implementation to throw an {@link IOException} that + * includes a {@link CloseReason} specific to the error that can be passed back + * to the client. + */ +public class WsIOException extends IOException { + + private static final long serialVersionUID = 1L; + + private final CloseReason closeReason; + + public WsIOException(CloseReason closeReason) { + this.closeReason = closeReason; + } + + public CloseReason getCloseReason() { + return closeReason; + } +} diff --git a/src/java/nginx/unit/websocket/WsPongMessage.java b/src/java/nginx/unit/websocket/WsPongMessage.java new file mode 100644 index 00000000..531bcda9 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsPongMessage.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; + +import javax.websocket.PongMessage; + +public class WsPongMessage implements PongMessage { + + private final ByteBuffer applicationData; + + + public WsPongMessage(ByteBuffer applicationData) { + byte[] dst = new byte[applicationData.limit()]; + applicationData.get(dst); + this.applicationData = ByteBuffer.wrap(dst); + } + + + @Override + public ByteBuffer getApplicationData() { + return applicationData; + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java new file mode 100644 index 00000000..0ea20795 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; +import java.util.concurrent.Future; + +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; + +public class WsRemoteEndpointAsync extends WsRemoteEndpointBase + implements RemoteEndpoint.Async { + + WsRemoteEndpointAsync(WsRemoteEndpointImplBase base) { + super(base); + } + + + @Override + public long getSendTimeout() { + return base.getSendTimeout(); + } + + + @Override + public void setSendTimeout(long timeout) { + base.setSendTimeout(timeout); + } + + + @Override + public void sendText(String text, SendHandler completion) { + base.sendStringByCompletion(text, completion); + } + + + @Override + public Future sendText(String text) { + return base.sendStringByFuture(text); + } + + + @Override + public Future sendBinary(ByteBuffer data) { + return base.sendBytesByFuture(data); + } + + + @Override + public void sendBinary(ByteBuffer data, SendHandler completion) { + base.sendBytesByCompletion(data, completion); + } + + + @Override + public Future sendObject(Object obj) { + return base.sendObjectByFuture(obj); + } + + + @Override + public void sendObject(Object obj, SendHandler completion) { + base.sendObjectByCompletion(obj, completion); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java new file mode 100644 index 00000000..21cb2040 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import javax.websocket.RemoteEndpoint; + +public abstract class WsRemoteEndpointBase implements RemoteEndpoint { + + protected final WsRemoteEndpointImplBase base; + + + WsRemoteEndpointBase(WsRemoteEndpointImplBase base) { + this.base = base; + } + + + @Override + public final void setBatchingAllowed(boolean batchingAllowed) throws IOException { + base.setBatchingAllowed(batchingAllowed); + } + + + @Override + public final boolean getBatchingAllowed() { + return base.getBatchingAllowed(); + } + + + @Override + public final void flushBatch() throws IOException { + base.flushBatch(); + } + + + @Override + public final void sendPing(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + base.sendPing(applicationData); + } + + + @Override + public final void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + base.sendPong(applicationData); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java new file mode 100644 index 00000000..2a93cc7b --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; + +import javax.websocket.EncodeException; +import javax.websocket.RemoteEndpoint; + +public class WsRemoteEndpointBasic extends WsRemoteEndpointBase + implements RemoteEndpoint.Basic { + + WsRemoteEndpointBasic(WsRemoteEndpointImplBase base) { + super(base); + } + + + @Override + public void sendText(String text) throws IOException { + base.sendString(text); + } + + + @Override + public void sendBinary(ByteBuffer data) throws IOException { + base.sendBytes(data); + } + + + @Override + public void sendText(String fragment, boolean isLast) throws IOException { + base.sendPartialString(fragment, isLast); + } + + + @Override + public void sendBinary(ByteBuffer partialByte, boolean isLast) + throws IOException { + base.sendPartialBytes(partialByte, isLast); + } + + + @Override + public OutputStream getSendStream() throws IOException { + return base.getSendStream(); + } + + + @Override + public Writer getSendWriter() throws IOException { + return base.getSendWriter(); + } + + + @Override + public void sendObject(Object o) throws IOException, EncodeException { + base.sendObject(o); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java new file mode 100644 index 00000000..776124fd --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java @@ -0,0 +1,1234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CoderResult; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.EncodeException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.buf.Utf8Encoder; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.Request; + +public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint { + + private static final StringManager sm = + StringManager.getManager(WsRemoteEndpointImplBase.class); + + protected static final SendResult SENDRESULT_OK = new SendResult(); + + private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static + + private final StateMachine stateMachine = new StateMachine(); + + private final IntermediateMessageHandler intermediateMessageHandler = + new IntermediateMessageHandler(this); + + private Transformation transformation = null; + private final Semaphore messagePartInProgress = new Semaphore(1); + private final Queue messagePartQueue = new ArrayDeque<>(); + private final Object messagePartLock = new Object(); + + // State + private volatile boolean closed = false; + private boolean fragmented = false; + private boolean nextFragmented = false; + private boolean text = false; + private boolean nextText = false; + + // Max size of WebSocket header is 14 bytes + private final ByteBuffer headerBuffer = ByteBuffer.allocate(14); + private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final CharsetEncoder encoder = new Utf8Encoder(); + private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final AtomicBoolean batchingAllowed = new AtomicBoolean(false); + private volatile long sendTimeout = -1; + private WsSession wsSession; + private List encoderEntries = new ArrayList<>(); + + private Request request; + + + protected void setTransformation(Transformation transformation) { + this.transformation = transformation; + } + + + public long getSendTimeout() { + return sendTimeout; + } + + + public void setSendTimeout(long timeout) { + this.sendTimeout = timeout; + } + + + @Override + public void setBatchingAllowed(boolean batchingAllowed) throws IOException { + boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed); + + if (oldValue && !batchingAllowed) { + flushBatch(); + } + } + + + @Override + public boolean getBatchingAllowed() { + return batchingAllowed.get(); + } + + + @Override + public void flushBatch() throws IOException { + sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true); + } + + + public void sendBytes(ByteBuffer data) throws IOException { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryStart(); + sendMessageBlock(Constants.OPCODE_BINARY, data, true); + stateMachine.complete(true); + } + + + public Future sendBytesByFuture(ByteBuffer data) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendBytesByCompletion(data, f2sh); + return f2sh; + } + + + public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine); + stateMachine.binaryStart(); + startMessage(Constants.OPCODE_BINARY, data, true, sush); + } + + + public void sendPartialBytes(ByteBuffer partialByte, boolean last) + throws IOException { + if (partialByte == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryPartialStart(); + sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last); + stateMachine.complete(last); + } + + + @Override + public void sendPing(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PING, applicationData, true); + } + + + @Override + public void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PONG, applicationData, true); + } + + + public void sendString(String text) throws IOException { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textStart(); + sendMessageBlock(CharBuffer.wrap(text), true); + } + + + public Future sendStringByFuture(String text) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendStringByCompletion(text, f2sh); + return f2sh; + } + + + public void sendStringByCompletion(String text, SendHandler handler) { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + stateMachine.textStart(); + TextMessageSendHandler tmsh = new TextMessageSendHandler(handler, + CharBuffer.wrap(text), true, encoder, encoderBuffer, this); + tmsh.write(); + // TextMessageSendHandler will update stateMachine when it completes + } + + + public void sendPartialString(String fragment, boolean isLast) + throws IOException { + if (fragment == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textPartialStart(); + sendMessageBlock(CharBuffer.wrap(fragment), isLast); + } + + + public OutputStream getSendStream() { + stateMachine.streamStart(); + return new WsOutputStream(this); + } + + + public Writer getSendWriter() { + stateMachine.writeStart(); + return new WsWriter(this); + } + + + void sendMessageBlock(CharBuffer part, boolean last) throws IOException { + long timeoutExpiry = getTimeoutExpiry(); + boolean isDone = false; + while (!isDone) { + encoderBuffer.clear(); + CoderResult cr = encoder.encode(part, encoderBuffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + encoderBuffer.flip(); + sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry); + } + stateMachine.complete(last); + } + + + void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last) + throws IOException { + sendMessageBlock(opCode, payload, last, getTimeoutExpiry()); + } + + + private long getTimeoutExpiry() { + // Get the timeout before we send the message. The message may + // trigger a session close and depending on timing the client + // session may close before we can read the timeout. + long timeout = getBlockingSendTimeout(); + if (timeout < 0) { + return Long.MAX_VALUE; + } else { + return System.currentTimeMillis() + timeout; + } + } + + private byte currentOpCode = Constants.OPCODE_CONTINUATION; + + private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last, + long timeoutExpiry) throws IOException { + wsSession.updateLastActive(); + + if (opCode == currentOpCode) { + opCode = Constants.OPCODE_CONTINUATION; + } + + request.sendWsFrame(payload, opCode, last, timeoutExpiry); + + if (!last && opCode != Constants.OPCODE_CONTINUATION) { + currentOpCode = opCode; + } + + if (last && opCode == Constants.OPCODE_CONTINUATION) { + currentOpCode = Constants.OPCODE_CONTINUATION; + } + } + + + void startMessage(byte opCode, ByteBuffer payload, boolean last, + SendHandler handler) { + + wsSession.updateLastActive(); + + List messageParts = new ArrayList<>(); + messageParts.add(new MessagePart(last, 0, opCode, payload, + intermediateMessageHandler, + new EndMessageHandler(this, handler), -1)); + + messageParts = transformation.sendMessagePart(messageParts); + + // Some extensions/transformations may buffer messages so it is possible + // that no message parts will be returned. If this is the case the + // trigger the supplied SendHandler + if (messageParts.size() == 0) { + handler.onResult(new SendResult()); + return; + } + + MessagePart mp = messageParts.remove(0); + + boolean doWrite = false; + synchronized (messagePartLock) { + if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) { + // Should not happen. To late to send batched messages now since + // the session has been closed. Complain loudly. + log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed")); + } + if (messagePartInProgress.tryAcquire()) { + doWrite = true; + } else { + // When a control message is sent while another message is being + // sent, the control message is queued. Chances are the + // subsequent data message part will end up queued while the + // control message is sent. The logic in this class (state + // machine, EndMessageHandler, TextMessageSendHandler) ensures + // that there will only ever be one data message part in the + // queue. There could be multiple control messages in the queue. + + // Add it to the queue + messagePartQueue.add(mp); + } + // Add any remaining messages to the queue + messagePartQueue.addAll(messageParts); + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mp); + } + } + + + void endMessage(SendHandler handler, SendResult result) { + boolean doWrite = false; + MessagePart mpNext = null; + synchronized (messagePartLock) { + + fragmented = nextFragmented; + text = nextText; + + mpNext = messagePartQueue.poll(); + if (mpNext == null) { + messagePartInProgress.release(); + } else if (!closed){ + // Session may have been closed unexpectedly in the middle of + // sending a fragmented message closing the endpoint. If this + // happens, clearly there is no point trying to send the rest of + // the message. + doWrite = true; + } + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mpNext); + } + + wsSession.updateLastActive(); + + // Some handlers, such as the IntermediateMessageHandler, do not have a + // nested handler so handler may be null. + if (handler != null) { + handler.onResult(result); + } + } + + + void writeMessagePart(MessagePart mp) { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closed")); + } + + if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) { + nextFragmented = fragmented; + nextText = text; + outputBuffer.flip(); + SendHandler flushHandler = new OutputBufferFlushSendHandler( + outputBuffer, mp.getEndHandler()); + doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer); + return; + } + + // Control messages may be sent in the middle of fragmented message + // so they have no effect on the fragmented or text flags + boolean first; + if (Util.isControl(mp.getOpCode())) { + nextFragmented = fragmented; + nextText = text; + if (mp.getOpCode() == Constants.OPCODE_CLOSE) { + closed = true; + } + first = true; + } else { + boolean isText = Util.isText(mp.getOpCode()); + + if (fragmented) { + // Currently fragmented + if (text != isText) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.changeType")); + } + nextText = text; + nextFragmented = !mp.isFin(); + first = false; + } else { + // Wasn't fragmented. Might be now + if (mp.isFin()) { + nextFragmented = false; + } else { + nextFragmented = true; + nextText = isText; + } + first = true; + } + } + + byte[] mask; + + if (isMasked()) { + mask = Util.generateMask(); + } else { + mask = null; + } + + headerBuffer.clear(); + writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(), + isMasked(), mp.getPayload(), mask, first); + headerBuffer.flip(); + + if (getBatchingAllowed() || isMasked()) { + // Need to write via output buffer + OutputBufferSendHandler obsh = new OutputBufferSendHandler( + mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload(), mask, + outputBuffer, !getBatchingAllowed(), this); + obsh.write(); + } else { + // Can write directly + doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload()); + } + } + + + private long getBlockingSendTimeout() { + Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY); + Long userTimeout = null; + if (obj instanceof Long) { + userTimeout = (Long) obj; + } + if (userTimeout == null) { + return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT; + } else { + return userTimeout.longValue(); + } + } + + + /** + * Wraps the user provided handler so that the end point is notified when + * the message is complete. + */ + private static class EndMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + private final SendHandler handler; + + public EndMessageHandler(WsRemoteEndpointImplBase endpoint, + SendHandler handler) { + this.endpoint = endpoint; + this.handler = handler; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(handler, result); + } + } + + + /** + * If a transformation needs to split a {@link MessagePart} into multiple + * {@link MessagePart}s, it uses this handler as the end handler for each of + * the additional {@link MessagePart}s. This handler notifies this this + * class that the {@link MessagePart} has been processed and that the next + * {@link MessagePart} in the queue should be started. The final + * {@link MessagePart} will use the {@link EndMessageHandler} provided with + * the original {@link MessagePart}. + */ + private static class IntermediateMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + + public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(null, result); + } + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObject(Object obj) throws IOException, EncodeException { + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendString(msg); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytes(msg); + return; + } + + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendString(msg); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytes(msg); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } + + + public Future sendObjectByFuture(Object obj) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendObjectByCompletion(obj, f2sh); + return f2sh; + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObjectByCompletion(Object obj, SendHandler completion) { + + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (completion == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendStringByCompletion(msg, completion); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytesByCompletion(msg, completion); + return; + } + + try { + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendStringByCompletion(msg, completion); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + completion.onResult(new SendResult()); + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytesByCompletion(msg, completion); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + completion.onResult(new SendResult()); + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } catch (Exception e) { + SendResult sr = new SendResult(e); + completion.onResult(sr); + } + } + + + protected void setSession(WsSession wsSession) { + this.wsSession = wsSession; + } + + + protected void setRequest(Request request) { + this.request = request; + } + + protected void setEncoders(EndpointConfig endpointConfig) + throws DeploymentException { + encoderEntries.clear(); + for (Class encoderClazz : + endpointConfig.getEncoders()) { + Encoder instance; + try { + instance = encoderClazz.getConstructor().newInstance(); + instance.init(endpointConfig); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("wsRemoteEndpoint.invalidEncoder", + encoderClazz.getName()), e); + } + EncoderEntry entry = new EncoderEntry( + Util.getEncoderType(encoderClazz), instance); + encoderEntries.add(entry); + } + } + + + private Encoder findEncoder(Object obj) { + for (EncoderEntry entry : encoderEntries) { + if (entry.getClazz().isAssignableFrom(obj.getClass())) { + return entry.getEncoder(); + } + } + return null; + } + + + public final void close() { + for (EncoderEntry entry : encoderEntries) { + entry.getEncoder().destroy(); + } + + request.closeWs(); + } + + + protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... data); + protected abstract boolean isMasked(); + protected abstract void doClose(); + + private static void writeHeader(ByteBuffer headerBuffer, boolean fin, + int rsv, byte opCode, boolean masked, ByteBuffer payload, + byte[] mask, boolean first) { + + byte b = 0; + + if (fin) { + // Set the fin bit + b -= 128; + } + + b += (rsv << 4); + + if (first) { + // This is the first fragment of this message + b += opCode; + } + // If not the first fragment, it is a continuation with opCode of zero + + headerBuffer.put(b); + + if (masked) { + b = (byte) 0x80; + } else { + b = 0; + } + + // Next write the mask && length length + if (payload.limit() < 126) { + headerBuffer.put((byte) (payload.limit() | b)); + } else if (payload.limit() < 65536) { + headerBuffer.put((byte) (126 | b)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } else { + // Will never be more than 2^31-1 + headerBuffer.put((byte) (127 | b)); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) (payload.limit() >>> 24)); + headerBuffer.put((byte) (payload.limit() >>> 16)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } + if (masked) { + headerBuffer.put(mask[0]); + headerBuffer.put(mask[1]); + headerBuffer.put(mask[2]); + headerBuffer.put(mask[3]); + } + } + + + private class TextMessageSendHandler implements SendHandler { + + private final SendHandler handler; + private final CharBuffer message; + private final boolean isLast; + private final CharsetEncoder encoder; + private final ByteBuffer buffer; + private final WsRemoteEndpointImplBase endpoint; + private volatile boolean isDone = false; + + public TextMessageSendHandler(SendHandler handler, CharBuffer message, + boolean isLast, CharsetEncoder encoder, + ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) { + this.handler = handler; + this.message = message; + this.isLast = isLast; + this.encoder = encoder.reset(); + this.buffer = encoderBuffer; + this.endpoint = endpoint; + } + + public void write() { + buffer.clear(); + CoderResult cr = encoder.encode(message, buffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + buffer.flip(); + endpoint.startMessage(Constants.OPCODE_TEXT, buffer, + isDone && isLast, this); + } + + @Override + public void onResult(SendResult result) { + if (isDone) { + endpoint.stateMachine.complete(isLast); + handler.onResult(result); + } else if(!result.isOK()) { + handler.onResult(result); + } else if (closed){ + SendResult sr = new SendResult(new IOException( + sm.getString("wsRemoteEndpoint.closedDuringMessage"))); + handler.onResult(sr); + } else { + write(); + } + } + } + + + /** + * Used to write data to the output buffer, flushing the buffer if it fills + * up. + */ + private static class OutputBufferSendHandler implements SendHandler { + + private final SendHandler handler; + private final long blockingWriteTimeoutExpiry; + private final ByteBuffer headerBuffer; + private final ByteBuffer payload; + private final byte[] mask; + private final ByteBuffer outputBuffer; + private final boolean flushRequired; + private final WsRemoteEndpointImplBase endpoint; + private int maskIndex = 0; + + public OutputBufferSendHandler(SendHandler completion, + long blockingWriteTimeoutExpiry, + ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, + ByteBuffer outputBuffer, boolean flushRequired, + WsRemoteEndpointImplBase endpoint) { + this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; + this.handler = completion; + this.headerBuffer = headerBuffer; + this.payload = payload; + this.mask = mask; + this.outputBuffer = outputBuffer; + this.flushRequired = flushRequired; + this.endpoint = endpoint; + } + + public void write() { + // Write the header + while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) { + outputBuffer.put(headerBuffer.get()); + } + if (headerBuffer.hasRemaining()) { + // Still more headers to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + // Write the payload + int payloadLeft = payload.remaining(); + int payloadLimit = payload.limit(); + int outputSpace = outputBuffer.remaining(); + int toWrite = payloadLeft; + + if (payloadLeft > outputSpace) { + toWrite = outputSpace; + // Temporarily reduce the limit + payload.limit(payload.position() + toWrite); + } + + if (mask == null) { + // Use a bulk copy + outputBuffer.put(payload); + } else { + for (int i = 0; i < toWrite; i++) { + outputBuffer.put( + (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF))); + if (maskIndex > 3) { + maskIndex = 0; + } + } + } + + if (payloadLeft > outputSpace) { + // Restore the original limit + payload.limit(payloadLimit); + // Still more data to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + if (flushRequired) { + outputBuffer.flip(); + if (outputBuffer.remaining() == 0) { + handler.onResult(SENDRESULT_OK); + } else { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } + } else { + handler.onResult(SENDRESULT_OK); + } + } + + // ------------------------------------------------- SendHandler methods + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + if (outputBuffer.hasRemaining()) { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } else { + outputBuffer.clear(); + write(); + } + } else { + handler.onResult(result); + } + } + } + + + /** + * Ensures that the output buffer is cleared after it has been flushed. + */ + private static class OutputBufferFlushSendHandler implements SendHandler { + + private final ByteBuffer outputBuffer; + private final SendHandler handler; + + public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) { + this.outputBuffer = outputBuffer; + this.handler = handler; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + outputBuffer.clear(); + } + handler.onResult(result); + } + } + + + private static class WsOutputStream extends OutputStream { + + private final WsRemoteEndpointImplBase endpoint; + private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsOutputStream(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(int b) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + buffer.put((byte) b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > b.length) || (len < 0) || + ((off + len) > b.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(b, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(b, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + // Optimisation. If there is no data to flush then do not send an + // empty message. + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last); + } + endpoint.stateMachine.complete(last); + buffer.clear(); + } + } + + + private static class WsWriter extends Writer { + + private final WsRemoteEndpointImplBase endpoint; + private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsWriter(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(char[] cbuf, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > cbuf.length) || (len < 0) || + ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(cbuf, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(cbuf, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(buffer, last); + buffer.clear(); + } else { + endpoint.stateMachine.complete(last); + } + } + } + + + private static class EncoderEntry { + + private final Class clazz; + private final Encoder encoder; + + public EncoderEntry(Class clazz, Encoder encoder) { + this.clazz = clazz; + this.encoder = encoder; + } + + public Class getClazz() { + return clazz; + } + + public Encoder getEncoder() { + return encoder; + } + } + + + private enum State { + OPEN, + STREAM_WRITING, + WRITER_WRITING, + BINARY_PARTIAL_WRITING, + BINARY_PARTIAL_READY, + BINARY_FULL_WRITING, + TEXT_PARTIAL_WRITING, + TEXT_PARTIAL_READY, + TEXT_FULL_WRITING + } + + + private static class StateMachine { + private State state = State.OPEN; + + public synchronized void streamStart() { + checkState(State.OPEN); + state = State.STREAM_WRITING; + } + + public synchronized void writeStart() { + checkState(State.OPEN); + state = State.WRITER_WRITING; + } + + public synchronized void binaryPartialStart() { + checkState(State.OPEN, State.BINARY_PARTIAL_READY); + state = State.BINARY_PARTIAL_WRITING; + } + + public synchronized void binaryStart() { + checkState(State.OPEN); + state = State.BINARY_FULL_WRITING; + } + + public synchronized void textPartialStart() { + checkState(State.OPEN, State.TEXT_PARTIAL_READY); + state = State.TEXT_PARTIAL_WRITING; + } + + public synchronized void textStart() { + checkState(State.OPEN); + state = State.TEXT_FULL_WRITING; + } + + public synchronized void complete(boolean last) { + if (last) { + checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING, + State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + state = State.OPEN; + } else { + checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + if (state == State.TEXT_PARTIAL_WRITING) { + state = State.TEXT_PARTIAL_READY; + } else if (state == State.BINARY_PARTIAL_WRITING){ + state = State.BINARY_PARTIAL_READY; + } else if (state == State.WRITER_WRITING) { + // NO-OP. Leave state as is. + } else if (state == State.STREAM_WRITING) { + // NO-OP. Leave state as is. + } else { + // Should never happen + // The if ... else ... blocks above should cover all states + // permitted by the preceding checkState() call + throw new IllegalStateException( + "BUG: This code should never be called"); + } + } + } + + private void checkState(State... required) { + for (State state : required) { + if (this.state == state) { + return; + } + } + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.wrongState", this.state)); + } + } + + + private static class StateUpdateSendHandler implements SendHandler { + + private final SendHandler handler; + private final StateMachine stateMachine; + + public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) { + this.handler = handler; + this.stateMachine = stateMachine; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + stateMachine.complete(true); + } + handler.onResult(result); + } + } + + + private static class BlockingSendHandler implements SendHandler { + + private SendResult sendResult = null; + + @Override + public void onResult(SendResult result) { + sendResult = result; + } + + public SendResult getSendResult() { + return sendResult; + } + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java new file mode 100644 index 00000000..70b66789 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase { + + private final AsyncChannelWrapper channel; + + public WsRemoteEndpointImplClient(AsyncChannelWrapper channel) { + this.channel = channel; + } + + + @Override + protected boolean isMasked() { + return true; + } + + + @Override + protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... data) { + long timeout; + for (ByteBuffer byteBuffer : data) { + if (blockingWriteTimeoutExpiry == -1) { + timeout = getSendTimeout(); + if (timeout < 1) { + timeout = Long.MAX_VALUE; + } + } else { + timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis(); + if (timeout < 0) { + SendResult sr = new SendResult(new IOException("Blocking write timeout")); + handler.onResult(sr); + } + } + + try { + channel.write(byteBuffer).get(timeout, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + handler.onResult(new SendResult(e)); + return; + } + } + handler.onResult(SENDRESULT_OK); + } + + @Override + protected void doClose() { + channel.close(); + } +} diff --git a/src/java/nginx/unit/websocket/WsSession.java b/src/java/nginx/unit/websocket/WsSession.java new file mode 100644 index 00000000..b654eb37 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsSession.java @@ -0,0 +1,1070 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.CharBuffer; +import java.nio.channels.WritePendingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCode; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.MessageHandler.Partial; +import javax.websocket.MessageHandler.Whole; +import javax.websocket.PongMessage; +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendResult; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.InstanceManagerBindings; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.buf.Utf8Decoder; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.Request; + +public class WsSession implements Session { + + // An ellipsis is a single character that looks like three periods in a row + // and is used to indicate a continuation. + private static final byte[] ELLIPSIS_BYTES = "\u2026".getBytes(StandardCharsets.UTF_8); + // An ellipsis is three bytes in UTF-8 + private static final int ELLIPSIS_BYTES_LEN = ELLIPSIS_BYTES.length; + + private static final StringManager sm = StringManager.getManager(WsSession.class); + private static AtomicLong ids = new AtomicLong(0); + + private final Log log = LogFactory.getLog(WsSession.class); // must not be static + + private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + + private final Endpoint localEndpoint; + private final WsRemoteEndpointImplBase wsRemoteEndpoint; + private final RemoteEndpoint.Async remoteEndpointAsync; + private final RemoteEndpoint.Basic remoteEndpointBasic; + private final ClassLoader applicationClassLoader; + private final WsWebSocketContainer webSocketContainer; + private final URI requestUri; + private final Map> requestParameterMap; + private final String queryString; + private final Principal userPrincipal; + private final EndpointConfig endpointConfig; + + private final List negotiatedExtensions; + private final String subProtocol; + private final Map pathParameters; + private final boolean secure; + private final String httpSessionId; + private final String id; + + // Expected to handle message types of only + private volatile MessageHandler textMessageHandler = null; + // Expected to handle message types of only + private volatile MessageHandler binaryMessageHandler = null; + private volatile MessageHandler.Whole pongMessageHandler = null; + private volatile State state = State.OPEN; + private final Object stateLock = new Object(); + private final Map userProperties = new ConcurrentHashMap<>(); + private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile long maxIdleTimeout = 0; + private volatile long lastActive = System.currentTimeMillis(); + private Map futures = new ConcurrentHashMap<>(); + + private CharBuffer messageBufferText; + private ByteBuffer binaryBuffer; + private byte startOpCode = Constants.OPCODE_CONTINUATION; + + /** + * Creates a new WebSocket session for communication between the two + * provided end points. The result of {@link Thread#getContextClassLoader()} + * at the time this constructor is called will be used when calling + * {@link Endpoint#onClose(Session, CloseReason)}. + * + * @param localEndpoint The end point managed by this code + * @param wsRemoteEndpoint The other / remote endpoint + * @param wsWebSocketContainer The container that created this session + * @param requestUri The URI used to connect to this endpoint or + * null is this is a client session + * @param requestParameterMap The parameters associated with the request + * that initiated this session or + * null if this is a client session + * @param queryString The query string associated with the request + * that initiated this session or + * null if this is a client session + * @param userPrincipal The principal associated with the request + * that initiated this session or + * null if this is a client session + * @param httpSessionId The HTTP session ID associated with the + * request that initiated this session or + * null if this is a client session + * @param negotiatedExtensions The agreed extensions to use for this session + * @param subProtocol The agreed subprotocol to use for this + * session + * @param pathParameters The path parameters associated with the + * request that initiated this session or + * null if this is a client session + * @param secure Was this session initiated over a secure + * connection? + * @param endpointConfig The configuration information for the + * endpoint + * @throws DeploymentException if an invalid encode is specified + */ + public WsSession(Endpoint localEndpoint, + WsRemoteEndpointImplBase wsRemoteEndpoint, + WsWebSocketContainer wsWebSocketContainer, + URI requestUri, Map> requestParameterMap, + String queryString, Principal userPrincipal, String httpSessionId, + List negotiatedExtensions, String subProtocol, Map pathParameters, + boolean secure, EndpointConfig endpointConfig, + Request request) throws DeploymentException { + this.localEndpoint = localEndpoint; + this.wsRemoteEndpoint = wsRemoteEndpoint; + this.wsRemoteEndpoint.setSession(this); + this.wsRemoteEndpoint.setRequest(request); + + request.setWsSession(this); + + this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint); + this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint); + this.webSocketContainer = wsWebSocketContainer; + applicationClassLoader = Thread.currentThread().getContextClassLoader(); + wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout()); + this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize(); + this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize(); + this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout(); + this.requestUri = requestUri; + if (requestParameterMap == null) { + this.requestParameterMap = Collections.emptyMap(); + } else { + this.requestParameterMap = requestParameterMap; + } + this.queryString = queryString; + this.userPrincipal = userPrincipal; + this.httpSessionId = httpSessionId; + this.negotiatedExtensions = negotiatedExtensions; + if (subProtocol == null) { + this.subProtocol = ""; + } else { + this.subProtocol = subProtocol; + } + this.pathParameters = pathParameters; + this.secure = secure; + this.wsRemoteEndpoint.setEncoders(endpointConfig); + this.endpointConfig = endpointConfig; + + this.userProperties.putAll(endpointConfig.getUserProperties()); + this.id = Long.toHexString(ids.getAndIncrement()); + + InstanceManager instanceManager = webSocketContainer.getInstanceManager(); + if (instanceManager == null) { + instanceManager = InstanceManagerBindings.get(applicationClassLoader); + } + if (instanceManager != null) { + try { + instanceManager.newInstance(localEndpoint); + } catch (Exception e) { + throw new DeploymentException(sm.getString("wsSession.instanceNew"), e); + } + } + + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.created", id)); + } + + messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize); + } + + public static String wsSession_test() { + return sm.getString("wsSession.instanceNew"); + } + + + @Override + public WebSocketContainer getContainer() { + checkState(); + return webSocketContainer; + } + + + @Override + public void addMessageHandler(MessageHandler listener) { + Class target = Util.getMessageType(listener); + doAddMessageHandler(target, listener); + } + + + @Override + public void addMessageHandler(Class clazz, Partial handler) + throws IllegalStateException { + doAddMessageHandler(clazz, handler); + } + + + @Override + public void addMessageHandler(Class clazz, Whole handler) + throws IllegalStateException { + doAddMessageHandler(clazz, handler); + } + + + @SuppressWarnings("unchecked") + private void doAddMessageHandler(Class target, MessageHandler listener) { + checkState(); + + // Message handlers that require decoders may map to text messages, + // binary messages, both or neither. + + // The frame processing code expects binary message handlers to + // accept ByteBuffer + + // Use the POJO message handler wrappers as they are designed to wrap + // arbitrary objects with MessageHandlers and can wrap MessageHandlers + // just as easily. + + Set mhResults = Util.getMessageHandlers(target, listener, + endpointConfig, this); + + for (MessageHandlerResult mhResult : mhResults) { + switch (mhResult.getType()) { + case TEXT: { + if (textMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText")); + } + textMessageHandler = mhResult.getHandler(); + break; + } + case BINARY: { + if (binaryMessageHandler != null) { + throw new IllegalStateException( + sm.getString("wsSession.duplicateHandlerBinary")); + } + binaryMessageHandler = mhResult.getHandler(); + break; + } + case PONG: { + if (pongMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong")); + } + MessageHandler handler = mhResult.getHandler(); + if (handler instanceof MessageHandler.Whole) { + pongMessageHandler = (MessageHandler.Whole) handler; + } else { + throw new IllegalStateException( + sm.getString("wsSession.invalidHandlerTypePong")); + } + + break; + } + default: { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType())); + } + } + } + } + + + @Override + public Set getMessageHandlers() { + checkState(); + Set result = new HashSet<>(); + if (binaryMessageHandler != null) { + result.add(binaryMessageHandler); + } + if (textMessageHandler != null) { + result.add(textMessageHandler); + } + if (pongMessageHandler != null) { + result.add(pongMessageHandler); + } + return result; + } + + + @Override + public void removeMessageHandler(MessageHandler listener) { + checkState(); + if (listener == null) { + return; + } + + MessageHandler wrapped = null; + + if (listener instanceof WrappedMessageHandler) { + wrapped = ((WrappedMessageHandler) listener).getWrappedHandler(); + } + + if (wrapped == null) { + wrapped = listener; + } + + boolean removed = false; + if (wrapped.equals(textMessageHandler) || listener.equals(textMessageHandler)) { + textMessageHandler = null; + removed = true; + } + + if (wrapped.equals(binaryMessageHandler) || listener.equals(binaryMessageHandler)) { + binaryMessageHandler = null; + removed = true; + } + + if (wrapped.equals(pongMessageHandler) || listener.equals(pongMessageHandler)) { + pongMessageHandler = null; + removed = true; + } + + if (!removed) { + // ISE for now. Could swallow this silently / log this if the ISE + // becomes a problem + throw new IllegalStateException( + sm.getString("wsSession.removeHandlerFailed", listener)); + } + } + + + @Override + public String getProtocolVersion() { + checkState(); + return Constants.WS_VERSION_HEADER_VALUE; + } + + + @Override + public String getNegotiatedSubprotocol() { + checkState(); + return subProtocol; + } + + + @Override + public List getNegotiatedExtensions() { + checkState(); + return negotiatedExtensions; + } + + + @Override + public boolean isSecure() { + checkState(); + return secure; + } + + + @Override + public boolean isOpen() { + return state == State.OPEN; + } + + + @Override + public long getMaxIdleTimeout() { + checkState(); + return maxIdleTimeout; + } + + + @Override + public void setMaxIdleTimeout(long timeout) { + checkState(); + this.maxIdleTimeout = timeout; + } + + + @Override + public void setMaxBinaryMessageBufferSize(int max) { + checkState(); + this.maxBinaryMessageBufferSize = max; + } + + + @Override + public int getMaxBinaryMessageBufferSize() { + checkState(); + return maxBinaryMessageBufferSize; + } + + + @Override + public void setMaxTextMessageBufferSize(int max) { + checkState(); + this.maxTextMessageBufferSize = max; + } + + + @Override + public int getMaxTextMessageBufferSize() { + checkState(); + return maxTextMessageBufferSize; + } + + + @Override + public Set getOpenSessions() { + checkState(); + return webSocketContainer.getOpenSessions(localEndpoint); + } + + + @Override + public RemoteEndpoint.Async getAsyncRemote() { + checkState(); + return remoteEndpointAsync; + } + + + @Override + public RemoteEndpoint.Basic getBasicRemote() { + checkState(); + return remoteEndpointBasic; + } + + + @Override + public void close() throws IOException { + close(new CloseReason(CloseCodes.NORMAL_CLOSURE, "")); + } + + + @Override + public void close(CloseReason closeReason) throws IOException { + doClose(closeReason, closeReason); + } + + + /** + * WebSocket 1.0. Section 2.1.5. + * Need internal close method as spec requires that the local endpoint + * receives a 1006 on timeout. + * + * @param closeReasonMessage The close reason to pass to the remote endpoint + * @param closeReasonLocal The close reason to pass to the local endpoint + */ + public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal) { + // Double-checked locking. OK because state is volatile + if (state != State.OPEN) { + return; + } + + synchronized (stateLock) { + if (state != State.OPEN) { + return; + } + + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.doClose", id)); + } + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + + state = State.OUTPUT_CLOSED; + + sendCloseMessage(closeReasonMessage); + fireEndpointOnClose(closeReasonLocal); + } + + IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); + SendResult sr = new SendResult(ioe); + for (FutureToSendHandler f2sh : futures.keySet()) { + f2sh.onResult(sr); + } + } + + + /** + * Called when a close message is received. Should only ever happen once. + * Also called after a protocol error when the ProtocolHandler needs to + * force the closing of the connection. + * + * @param closeReason The reason contained within the received close + * message. + */ + public void onClose(CloseReason closeReason) { + + synchronized (stateLock) { + if (state != State.CLOSED) { + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + if (state == State.OPEN) { + state = State.OUTPUT_CLOSED; + sendCloseMessage(closeReason); + fireEndpointOnClose(closeReason); + } + state = State.CLOSED; + + // Close the socket + wsRemoteEndpoint.close(); + } + } + } + + + public void onClose() { + + synchronized (stateLock) { + if (state != State.CLOSED) { + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + if (state == State.OPEN) { + state = State.OUTPUT_CLOSED; + fireEndpointOnClose(new CloseReason( + CloseReason.CloseCodes.NORMAL_CLOSURE, "")); + } + state = State.CLOSED; + + // Close the socket + wsRemoteEndpoint.close(); + } + } + } + + + private void fireEndpointOnClose(CloseReason closeReason) { + + // Fire the onClose event + Throwable throwable = null; + InstanceManager instanceManager = webSocketContainer.getInstanceManager(); + if (instanceManager == null) { + instanceManager = InstanceManagerBindings.get(applicationClassLoader); + } + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + localEndpoint.onClose(this, closeReason); + } catch (Throwable t1) { + ExceptionUtils.handleThrowable(t1); + throwable = t1; + } finally { + if (instanceManager != null) { + try { + instanceManager.destroyInstance(localEndpoint); + } catch (Throwable t2) { + ExceptionUtils.handleThrowable(t2); + if (throwable == null) { + throwable = t2; + } + } + } + t.setContextClassLoader(cl); + } + + if (throwable != null) { + fireEndpointOnError(throwable); + } + } + + + private void fireEndpointOnError(Throwable throwable) { + + // Fire the onError event + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + localEndpoint.onError(this, throwable); + } finally { + t.setContextClassLoader(cl); + } + } + + + private void sendCloseMessage(CloseReason closeReason) { + // 125 is maximum size for the payload of a control message + ByteBuffer msg = ByteBuffer.allocate(125); + CloseCode closeCode = closeReason.getCloseCode(); + // CLOSED_ABNORMALLY should not be put on the wire + if (closeCode == CloseCodes.CLOSED_ABNORMALLY) { + // PROTOCOL_ERROR is probably better than GOING_AWAY here + msg.putShort((short) CloseCodes.PROTOCOL_ERROR.getCode()); + } else { + msg.putShort((short) closeCode.getCode()); + } + + String reason = closeReason.getReasonPhrase(); + if (reason != null && reason.length() > 0) { + appendCloseReasonWithTruncation(msg, reason); + } + msg.flip(); + try { + wsRemoteEndpoint.sendMessageBlock(Constants.OPCODE_CLOSE, msg, true); + } catch (IOException | WritePendingException e) { + // Failed to send close message. Close the socket and let the caller + // deal with the Exception + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.sendCloseFail", id), e); + } + wsRemoteEndpoint.close(); + // Failure to send a close message is not unexpected in the case of + // an abnormal closure (usually triggered by a failure to read/write + // from/to the client. In this case do not trigger the endpoint's + // error handling + if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { + localEndpoint.onError(this, e); + } + } finally { + webSocketContainer.unregisterSession(localEndpoint, this); + } + } + + + /** + * Use protected so unit tests can access this method directly. + * @param msg The message + * @param reason The reason + */ + protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String reason) { + // Once the close code has been added there are a maximum of 123 bytes + // left for the reason phrase. If it is truncated then care needs to be + // taken to ensure the bytes are not truncated in the middle of a + // multi-byte UTF-8 character. + byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8); + + if (reasonBytes.length <= 123) { + // No need to truncate + msg.put(reasonBytes); + } else { + // Need to truncate + int remaining = 123 - ELLIPSIS_BYTES_LEN; + int pos = 0; + byte[] bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8); + while (remaining >= bytesNext.length) { + msg.put(bytesNext); + remaining -= bytesNext.length; + pos++; + bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8); + } + msg.put(ELLIPSIS_BYTES); + } + } + + + /** + * Make the session aware of a {@link FutureToSendHandler} that will need to + * be forcibly closed if the session closes before the + * {@link FutureToSendHandler} completes. + * @param f2sh The handler + */ + protected void registerFuture(FutureToSendHandler f2sh) { + // Ideally, this code should sync on stateLock so that the correct + // action is taken based on the current state of the connection. + // However, a sync on stateLock can't be used here as it will create the + // possibility of a dead-lock. See BZ 61183. + // Therefore, a slightly less efficient approach is used. + + // Always register the future. + futures.put(f2sh, f2sh); + + if (state == State.OPEN) { + // The session is open. The future has been registered with the open + // session. Normal processing continues. + return; + } + + // The session is closed. The future may or may not have been registered + // in time for it to be processed during session closure. + + if (f2sh.isDone()) { + // The future has completed. It is not known if the future was + // completed normally by the I/O layer or in error by doClose(). It + // doesn't matter which. There is nothing more to do here. + return; + } + + // The session is closed. The Future had not completed when last checked. + // There is a small timing window that means the Future may have been + // completed since the last check. There is also the possibility that + // the Future was not registered in time to be cleaned up during session + // close. + // Attempt to complete the Future with an error result as this ensures + // that the Future completes and any client code waiting on it does not + // hang. It is slightly inefficient since the Future may have been + // completed in another thread or another thread may be about to + // complete the Future but knowing if this is the case requires the sync + // on stateLock (see above). + // Note: If multiple attempts are made to complete the Future, the + // second and subsequent attempts are ignored. + + IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); + SendResult sr = new SendResult(ioe); + f2sh.onResult(sr); + } + + + /** + * Remove a {@link FutureToSendHandler} from the set of tracked instances. + * @param f2sh The handler + */ + protected void unregisterFuture(FutureToSendHandler f2sh) { + futures.remove(f2sh); + } + + + @Override + public URI getRequestURI() { + checkState(); + return requestUri; + } + + + @Override + public Map> getRequestParameterMap() { + checkState(); + return requestParameterMap; + } + + + @Override + public String getQueryString() { + checkState(); + return queryString; + } + + + @Override + public Principal getUserPrincipal() { + checkState(); + return userPrincipal; + } + + + @Override + public Map getPathParameters() { + checkState(); + return pathParameters; + } + + + @Override + public String getId() { + return id; + } + + + @Override + public Map getUserProperties() { + checkState(); + return userProperties; + } + + + public Endpoint getLocal() { + return localEndpoint; + } + + + public String getHttpSessionId() { + return httpSessionId; + } + + private ByteBuffer rawFragments; + + public void processFrame(ByteBuffer buf, byte opCode, boolean last) + throws IOException + { + if (state == State.CLOSED) { + return; + } + + if (opCode == Constants.OPCODE_CONTINUATION) { + opCode = startOpCode; + + if (rawFragments != null && rawFragments.position() > 0) { + rawFragments.put(buf); + rawFragments.flip(); + buf = rawFragments; + } + } else { + if (!last && (opCode == Constants.OPCODE_BINARY || + opCode == Constants.OPCODE_TEXT)) { + startOpCode = opCode; + + if (rawFragments != null) { + rawFragments.clear(); + } + } + } + + if (last) { + startOpCode = Constants.OPCODE_CONTINUATION; + } + + if (opCode == Constants.OPCODE_PONG) { + if (pongMessageHandler != null) { + final ByteBuffer b = buf; + + PongMessage pongMessage = new PongMessage() { + @Override + public ByteBuffer getApplicationData() { + return b; + } + }; + + pongMessageHandler.onMessage(pongMessage); + } + } + + if (opCode == Constants.OPCODE_CLOSE) { + CloseReason closeReason; + + if (buf.remaining() >= 2) { + short closeCode = buf.order(ByteOrder.BIG_ENDIAN).getShort(); + + closeReason = new CloseReason( + CloseReason.CloseCodes.getCloseCode(closeCode), + buf.asCharBuffer().toString()); + } else { + closeReason = new CloseReason( + CloseReason.CloseCodes.NORMAL_CLOSURE, ""); + } + + onClose(closeReason); + } + + if (opCode == Constants.OPCODE_BINARY) { + onMessage(buf, last); + } + + if (opCode == Constants.OPCODE_TEXT) { + if (messageBufferText.position() == 0 && maxTextMessageBufferSize != messageBufferText.capacity()) { + messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize); + } + + CoderResult cr = utf8DecoderMessage.decode(buf, messageBufferText, last); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (hasTextPartial()) { + do { + onMessage(messageBufferText, false); + + cr = utf8DecoderMessage.decode(buf, messageBufferText, last); + } while (cr.isOverflow()); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow() && !last) { + updateRawFragments(buf, last); + + if (hasTextPartial()) { + onMessage(messageBufferText, false); + } + + return; + } + + if (last) { + utf8DecoderMessage.reset(); + } + + updateRawFragments(buf, last); + + onMessage(messageBufferText, last); + } + } + + + private boolean hasTextPartial() { + return textMessageHandler instanceof MessageHandler.Partial; + } + + + private void onMessage(CharBuffer buf, boolean last) throws IOException { + buf.flip(); + try { + onMessage(buf.toString(), last); + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + buf.clear(); + } + } + + + private void updateRawFragments(ByteBuffer buf, boolean last) { + if (!last && buf.remaining() > 0) { + if (buf == rawFragments) { + buf.compact(); + } else { + if (rawFragments == null || (rawFragments.position() == 0 && maxTextMessageBufferSize != rawFragments.capacity())) { + rawFragments = ByteBuffer.allocateDirect(maxTextMessageBufferSize); + } + rawFragments.put(buf); + } + } else { + if (rawFragments != null) { + rawFragments.clear(); + } + } + } + + + @SuppressWarnings("unchecked") + public void onMessage(String text, boolean last) { + if (hasTextPartial()) { + ((MessageHandler.Partial) textMessageHandler).onMessage(text, last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole) textMessageHandler).onMessage(text); + } + } + + + @SuppressWarnings("unchecked") + public void onMessage(ByteBuffer buf, boolean last) + throws IOException + { + if (binaryMessageHandler instanceof MessageHandler.Partial) { + ((MessageHandler.Partial) binaryMessageHandler).onMessage(buf, last); + } else { + if (last && (binaryBuffer == null || binaryBuffer.position() == 0)) { + ((MessageHandler.Whole) binaryMessageHandler).onMessage(buf); + return; + } + + if (binaryBuffer == null || + (binaryBuffer.position() == 0 && binaryBuffer.capacity() != maxBinaryMessageBufferSize)) + { + binaryBuffer = ByteBuffer.allocateDirect(maxBinaryMessageBufferSize); + } + + if (binaryBuffer.remaining() < buf.remaining()) { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + + binaryBuffer.put(buf); + + if (last) { + binaryBuffer.flip(); + try { + ((MessageHandler.Whole) binaryMessageHandler).onMessage(binaryBuffer); + } finally { + binaryBuffer.clear(); + } + } + } + } + + + private void handleThrowableOnSend(Throwable t) throws WsIOException { + ExceptionUtils.handleThrowable(t); + getLocal().onError(this, t); + CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, + sm.getString("wsFrame.ioeTriggeredClose")); + throw new WsIOException(cr); + } + + + protected MessageHandler getTextMessageHandler() { + return textMessageHandler; + } + + + protected MessageHandler getBinaryMessageHandler() { + return binaryMessageHandler; + } + + + protected MessageHandler.Whole getPongMessageHandler() { + return pongMessageHandler; + } + + + protected void updateLastActive() { + lastActive = System.currentTimeMillis(); + } + + + protected void checkExpiration() { + long timeout = maxIdleTimeout; + if (timeout < 1) { + return; + } + + if (System.currentTimeMillis() - lastActive > timeout) { + String msg = sm.getString("wsSession.timeout", getId()); + if (log.isDebugEnabled()) { + log.debug(msg); + } + doClose(new CloseReason(CloseCodes.GOING_AWAY, msg), + new CloseReason(CloseCodes.CLOSED_ABNORMALLY, msg)); + } + } + + + private void checkState() { + if (state == State.CLOSED) { + /* + * As per RFC 6455, a WebSocket connection is considered to be + * closed once a peer has sent and received a WebSocket close frame. + */ + throw new IllegalStateException(sm.getString("wsSession.closed", id)); + } + } + + private enum State { + OPEN, + OUTPUT_CLOSED, + CLOSED + } +} diff --git a/src/java/nginx/unit/websocket/WsWebSocketContainer.java b/src/java/nginx/unit/websocket/WsWebSocketContainer.java new file mode 100644 index 00000000..282665ef --- /dev/null +++ b/src/java/nginx/unit/websocket/WsWebSocketContainer.java @@ -0,0 +1,1123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.net.ProxySelector; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousChannelGroup; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.charset.StandardCharsets; +import java.security.KeyStore; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.TrustManagerFactory; +import javax.websocket.ClientEndpoint; +import javax.websocket.ClientEndpointConfig; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.util.buf.StringUtils; +import org.apache.tomcat.util.codec.binary.Base64; +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.pojo.PojoEndpointClient; + +public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess { + + private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class); + private static final Random RANDOM = new Random(); + private static final byte[] CRLF = new byte[] { 13, 10 }; + + private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1); + private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1); + private static final byte[] HTTP_VERSION_BYTES = + " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1); + + private volatile AsynchronousChannelGroup asynchronousChannelGroup = null; + private final Object asynchronousChannelGroupLock = new Object(); + + private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static + private final Map> endpointSessionMap = + new HashMap<>(); + private final Map sessions = new ConcurrentHashMap<>(); + private final Object endPointSessionMapLock = new Object(); + + private long defaultAsyncTimeout = -1; + private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile long defaultMaxSessionIdleTimeout = 0; + private int backgroundProcessCount = 0; + private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD; + + private InstanceManager instanceManager; + + InstanceManager getInstanceManager() { + return instanceManager; + } + + protected void setInstanceManager(InstanceManager instanceManager) { + this.instanceManager = instanceManager; + } + + @Override + public Session connectToServer(Object pojo, URI path) + throws DeploymentException { + + ClientEndpoint annotation = + pojo.getClass().getAnnotation(ClientEndpoint.class); + if (annotation == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.missingAnnotation", + pojo.getClass().getName())); + } + + Endpoint ep = new PojoEndpointClient(pojo, Arrays.asList(annotation.decoders())); + + Class configuratorClazz = + annotation.configurator(); + + ClientEndpointConfig.Configurator configurator = null; + if (!ClientEndpointConfig.Configurator.class.equals( + configuratorClazz)) { + try { + configurator = configuratorClazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.defaultConfiguratorFail"), e); + } + } + + ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create(); + // Avoid NPE when using RI API JAR - see BZ 56343 + if (configurator != null) { + builder.configurator(configurator); + } + ClientEndpointConfig config = builder. + decoders(Arrays.asList(annotation.decoders())). + encoders(Arrays.asList(annotation.encoders())). + preferredSubprotocols(Arrays.asList(annotation.subprotocols())). + build(); + return connectToServer(ep, config, path); + } + + + @Override + public Session connectToServer(Class annotatedEndpointClass, URI path) + throws DeploymentException { + + Object pojo; + try { + pojo = annotatedEndpointClass.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.endpointCreateFail", + annotatedEndpointClass.getName()), e); + } + + return connectToServer(pojo, path); + } + + + @Override + public Session connectToServer(Class clazz, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException { + + Endpoint endpoint; + try { + endpoint = clazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.endpointCreateFail", clazz.getName()), + e); + } + + return connectToServer(endpoint, clientEndpointConfiguration, path); + } + + + @Override + public Session connectToServer(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException { + return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, new HashSet<>()); + } + + private Session connectToServerRecursive(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path, + Set redirectSet) + throws DeploymentException { + + boolean secure = false; + ByteBuffer proxyConnect = null; + URI proxyPath; + + // Validate scheme (and build proxyPath) + String scheme = path.getScheme(); + if ("ws".equalsIgnoreCase(scheme)) { + proxyPath = URI.create("http" + path.toString().substring(2)); + } else if ("wss".equalsIgnoreCase(scheme)) { + proxyPath = URI.create("https" + path.toString().substring(3)); + secure = true; + } else { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.pathWrongScheme", scheme)); + } + + // Validate host + String host = path.getHost(); + if (host == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.pathNoHost")); + } + int port = path.getPort(); + + SocketAddress sa = null; + + // Check to see if a proxy is configured. Javadoc indicates return value + // will never be null + List proxies = ProxySelector.getDefault().select(proxyPath); + Proxy selectedProxy = null; + for (Proxy proxy : proxies) { + if (proxy.type().equals(Proxy.Type.HTTP)) { + sa = proxy.address(); + if (sa instanceof InetSocketAddress) { + InetSocketAddress inet = (InetSocketAddress) sa; + if (inet.isUnresolved()) { + sa = new InetSocketAddress(inet.getHostName(), inet.getPort()); + } + } + selectedProxy = proxy; + break; + } + } + + // If the port is not explicitly specified, compute it based on the + // scheme + if (port == -1) { + if ("ws".equalsIgnoreCase(scheme)) { + port = 80; + } else { + // Must be wss due to scheme validation above + port = 443; + } + } + + // If sa is null, no proxy is configured so need to create sa + if (sa == null) { + sa = new InetSocketAddress(host, port); + } else { + proxyConnect = createProxyRequest(host, port); + } + + // Create the initial HTTP request to open the WebSocket connection + Map> reqHeaders = createRequestHeaders(host, port, + clientEndpointConfiguration); + clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders); + if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null + && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) { + List originValues = new ArrayList<>(1); + originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE); + reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues); + } + ByteBuffer request = createRequest(path, reqHeaders); + + AsynchronousSocketChannel socketChannel; + try { + socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup()); + } catch (IOException ioe) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.asynchronousSocketChannelFail"), ioe); + } + + Map userProperties = clientEndpointConfiguration.getUserProperties(); + + // Get the connection timeout + long timeout = Constants.IO_TIMEOUT_MS_DEFAULT; + String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY); + if (timeoutValue != null) { + timeout = Long.valueOf(timeoutValue).intValue(); + } + + // Set-up + // Same size as the WsFrame input buffer + ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize()); + String subProtocol; + boolean success = false; + List extensionsAgreed = new ArrayList<>(); + Transformation transformation = null; + + // Open the connection + Future fConnect = socketChannel.connect(sa); + AsyncChannelWrapper channel = null; + + if (proxyConnect != null) { + try { + fConnect.get(timeout, TimeUnit.MILLISECONDS); + // Proxy CONNECT is clear text + channel = new AsyncChannelWrapperNonSecure(socketChannel); + writeRequest(channel, proxyConnect, timeout); + HttpResponse httpResponse = processResponse(response, channel, timeout); + if (httpResponse.getStatus() != 200) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.proxyConnectFail", selectedProxy, + Integer.toString(httpResponse.getStatus()))); + } + } catch (TimeoutException | InterruptedException | ExecutionException | + EOFException e) { + if (channel != null) { + channel.close(); + } + throw new DeploymentException( + sm.getString("wsWebSocketContainer.httpRequestFailed"), e); + } + } + + if (secure) { + // Regardless of whether a non-secure wrapper was created for a + // proxy CONNECT, need to use TLS from this point on so wrap the + // original AsynchronousSocketChannel + SSLEngine sslEngine = createSSLEngine(userProperties, host, port); + channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine); + } else if (channel == null) { + // Only need to wrap as this point if it wasn't wrapped to process a + // proxy CONNECT + channel = new AsyncChannelWrapperNonSecure(socketChannel); + } + + try { + fConnect.get(timeout, TimeUnit.MILLISECONDS); + + Future fHandshake = channel.handshake(); + fHandshake.get(timeout, TimeUnit.MILLISECONDS); + + writeRequest(channel, request, timeout); + + HttpResponse httpResponse = processResponse(response, channel, timeout); + + // Check maximum permitted redirects + int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT; + String maxRedirectsValue = + (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY); + if (maxRedirectsValue != null) { + maxRedirects = Integer.parseInt(maxRedirectsValue); + } + + if (httpResponse.status != 101) { + if(isRedirectStatus(httpResponse.status)){ + List locationHeader = + httpResponse.getHandshakeResponse().getHeaders().get( + Constants.LOCATION_HEADER_NAME); + + if (locationHeader == null || locationHeader.isEmpty() || + locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.missingLocationHeader", + Integer.toString(httpResponse.status))); + } + + URI redirectLocation = URI.create(locationHeader.get(0)).normalize(); + + if (!redirectLocation.isAbsolute()) { + redirectLocation = path.resolve(redirectLocation); + } + + String redirectScheme = redirectLocation.getScheme().toLowerCase(); + + if (redirectScheme.startsWith("http")) { + redirectLocation = new URI(redirectScheme.replace("http", "ws"), + redirectLocation.getUserInfo(), redirectLocation.getHost(), + redirectLocation.getPort(), redirectLocation.getPath(), + redirectLocation.getQuery(), redirectLocation.getFragment()); + } + + if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.redirectThreshold", redirectLocation, + Integer.toString(redirectSet.size()), + Integer.toString(maxRedirects))); + } + + return connectToServerRecursive(endpoint, clientEndpointConfiguration, redirectLocation, redirectSet); + + } + + else if (httpResponse.status == 401) { + + if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.failedAuthentication", + Integer.valueOf(httpResponse.status))); + } + + List wwwAuthenticateHeaders = httpResponse.getHandshakeResponse() + .getHeaders().get(Constants.WWW_AUTHENTICATE_HEADER_NAME); + + if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty() || + wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.missingWWWAuthenticateHeader", + Integer.toString(httpResponse.status))); + } + + String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0]; + String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1) + .split("\\s", 3)[1]; + + Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme); + + if (auth == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.unsupportedAuthScheme", + Integer.valueOf(httpResponse.status), authScheme)); + } + + userProperties.put(Constants.AUTHORIZATION_HEADER_NAME, auth.getAuthorization( + requestUri, wwwAuthenticateHeaders.get(0), userProperties)); + + return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, redirectSet); + + } + + else { + throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", + Integer.toString(httpResponse.status))); + } + } + HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse(); + clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse); + + // Sub-protocol + List protocolHeaders = handshakeResponse.getHeaders().get( + Constants.WS_PROTOCOL_HEADER_NAME); + if (protocolHeaders == null || protocolHeaders.size() == 0) { + subProtocol = null; + } else if (protocolHeaders.size() == 1) { + subProtocol = protocolHeaders.get(0); + } else { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.invalidSubProtocol")); + } + + // Extensions + // Should normally only be one header but handle the case of + // multiple headers + List extHeaders = handshakeResponse.getHeaders().get( + Constants.WS_EXTENSIONS_HEADER_NAME); + if (extHeaders != null) { + for (String extHeader : extHeaders) { + Util.parseExtensionHeader(extensionsAgreed, extHeader); + } + } + + // Build the transformations + TransformationFactory factory = TransformationFactory.getInstance(); + for (Extension extension : extensionsAgreed) { + List> wrapper = new ArrayList<>(1); + wrapper.add(extension.getParameters()); + Transformation t = factory.create(extension.getName(), wrapper, false); + if (t == null) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidExtensionParameters")); + } + if (transformation == null) { + transformation = t; + } else { + transformation.setNext(t); + } + } + + success = true; + } catch (ExecutionException | InterruptedException | SSLException | + EOFException | TimeoutException | URISyntaxException | AuthenticationException e) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.httpRequestFailed"), e); + } finally { + if (!success) { + channel.close(); + } + } + + // Switch to WebSocket + WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel); + + WsSession wsSession = new WsSession(endpoint, wsRemoteEndpointClient, + this, null, null, null, null, null, extensionsAgreed, + subProtocol, Collections.emptyMap(), secure, + clientEndpointConfiguration, null); + + WsFrameClient wsFrameClient = new WsFrameClient(response, channel, + wsSession, transformation); + // WsFrame adds the necessary final transformations. Copy the + // completed transformation chain to the remote end point. + wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation()); + + endpoint.onOpen(wsSession, clientEndpointConfiguration); + registerSession(endpoint, wsSession); + + /* It is possible that the server sent one or more messages as soon as + * the WebSocket connection was established. Depending on the exact + * timing of when those messages were sent they could be sat in the + * input buffer waiting to be read and will not trigger a "data + * available to read" event. Therefore, it is necessary to process the + * input buffer here. Note that this happens on the current thread which + * means that this thread will be used for any onMessage notifications. + * This is a special case. Subsequent "data available to read" events + * will be handled by threads from the AsyncChannelGroup's executor. + */ + wsFrameClient.startInputProcessing(); + + return wsSession; + } + + + private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request, + long timeout) throws TimeoutException, InterruptedException, ExecutionException { + int toWrite = request.limit(); + + Future fWrite = channel.write(request); + Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); + toWrite -= thisWrite.intValue(); + + while (toWrite > 0) { + fWrite = channel.write(request); + thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); + toWrite -= thisWrite.intValue(); + } + } + + + private static boolean isRedirectStatus(int httpResponseCode) { + + boolean isRedirect = false; + + switch (httpResponseCode) { + case Constants.MULTIPLE_CHOICES: + case Constants.MOVED_PERMANENTLY: + case Constants.FOUND: + case Constants.SEE_OTHER: + case Constants.USE_PROXY: + case Constants.TEMPORARY_REDIRECT: + isRedirect = true; + break; + default: + break; + } + + return isRedirect; + } + + + private static ByteBuffer createProxyRequest(String host, int port) { + StringBuilder request = new StringBuilder(); + request.append("CONNECT "); + request.append(host); + request.append(':'); + request.append(port); + + request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keepalive\r\nHost: "); + request.append(host); + request.append(':'); + request.append(port); + + request.append("\r\n\r\n"); + + byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1); + return ByteBuffer.wrap(bytes); + } + + protected void registerSession(Endpoint endpoint, WsSession wsSession) { + + if (!wsSession.isOpen()) { + // The session was closed during onOpen. No need to register it. + return; + } + synchronized (endPointSessionMapLock) { + if (endpointSessionMap.size() == 0) { + BackgroundProcessManager.getInstance().register(this); + } + Set wsSessions = endpointSessionMap.get(endpoint); + if (wsSessions == null) { + wsSessions = new HashSet<>(); + endpointSessionMap.put(endpoint, wsSessions); + } + wsSessions.add(wsSession); + } + sessions.put(wsSession, wsSession); + } + + + protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + + synchronized (endPointSessionMapLock) { + Set wsSessions = endpointSessionMap.get(endpoint); + if (wsSessions != null) { + wsSessions.remove(wsSession); + if (wsSessions.size() == 0) { + endpointSessionMap.remove(endpoint); + } + } + if (endpointSessionMap.size() == 0) { + BackgroundProcessManager.getInstance().unregister(this); + } + } + sessions.remove(wsSession); + } + + + Set getOpenSessions(Endpoint endpoint) { + HashSet result = new HashSet<>(); + synchronized (endPointSessionMapLock) { + Set sessions = endpointSessionMap.get(endpoint); + if (sessions != null) { + result.addAll(sessions); + } + } + return result; + } + + private static Map> createRequestHeaders(String host, int port, + ClientEndpointConfig clientEndpointConfiguration) { + + Map> headers = new HashMap<>(); + List extensions = clientEndpointConfiguration.getExtensions(); + List subProtocols = clientEndpointConfiguration.getPreferredSubprotocols(); + Map userProperties = clientEndpointConfiguration.getUserProperties(); + + if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { + List authValues = new ArrayList<>(1); + authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME)); + headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues); + } + + // Host header + List hostValues = new ArrayList<>(1); + if (port == -1) { + hostValues.add(host); + } else { + hostValues.add(host + ':' + port); + } + + headers.put(Constants.HOST_HEADER_NAME, hostValues); + + // Upgrade header + List upgradeValues = new ArrayList<>(1); + upgradeValues.add(Constants.UPGRADE_HEADER_VALUE); + headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues); + + // Connection header + List connectionValues = new ArrayList<>(1); + connectionValues.add(Constants.CONNECTION_HEADER_VALUE); + headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues); + + // WebSocket version header + List wsVersionValues = new ArrayList<>(1); + wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE); + headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues); + + // WebSocket key + List wsKeyValues = new ArrayList<>(1); + wsKeyValues.add(generateWsKeyValue()); + headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues); + + // WebSocket sub-protocols + if (subProtocols != null && subProtocols.size() > 0) { + headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols); + } + + // WebSocket extensions + if (extensions != null && extensions.size() > 0) { + headers.put(Constants.WS_EXTENSIONS_HEADER_NAME, + generateExtensionHeaders(extensions)); + } + + return headers; + } + + + private static List generateExtensionHeaders(List extensions) { + List result = new ArrayList<>(extensions.size()); + for (Extension extension : extensions) { + StringBuilder header = new StringBuilder(); + header.append(extension.getName()); + for (Extension.Parameter param : extension.getParameters()) { + header.append(';'); + header.append(param.getName()); + String value = param.getValue(); + if (value != null && value.length() > 0) { + header.append('='); + header.append(value); + } + } + result.add(header.toString()); + } + return result; + } + + + private static String generateWsKeyValue() { + byte[] keyBytes = new byte[16]; + RANDOM.nextBytes(keyBytes); + return Base64.encodeBase64String(keyBytes); + } + + + private static ByteBuffer createRequest(URI uri, Map> reqHeaders) { + ByteBuffer result = ByteBuffer.allocate(4 * 1024); + + // Request line + result.put(GET_BYTES); + if (null == uri.getPath() || "".equals(uri.getPath())) { + result.put(ROOT_URI_BYTES); + } else { + result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1)); + } + String query = uri.getRawQuery(); + if (query != null) { + result.put((byte) '?'); + result.put(query.getBytes(StandardCharsets.ISO_8859_1)); + } + result.put(HTTP_VERSION_BYTES); + + // Headers + for (Entry> entry : reqHeaders.entrySet()) { + result = addHeader(result, entry.getKey(), entry.getValue()); + } + + // Terminating CRLF + result.put(CRLF); + + result.flip(); + + return result; + } + + + private static ByteBuffer addHeader(ByteBuffer result, String key, List values) { + if (values.isEmpty()) { + return result; + } + + result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, CRLF); + + return result; + } + + + private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) { + if (bytes.length > input.remaining()) { + int newSize; + if (bytes.length > input.capacity()) { + newSize = 2 * bytes.length; + } else { + newSize = input.capacity() * 2; + } + ByteBuffer expanded = ByteBuffer.allocate(newSize); + input.flip(); + expanded.put(input); + input = expanded; + } + return input.put(bytes); + } + + + /** + * Process response, blocking until HTTP response has been fully received. + * @throws ExecutionException + * @throws InterruptedException + * @throws DeploymentException + * @throws TimeoutException + */ + private HttpResponse processResponse(ByteBuffer response, + AsyncChannelWrapper channel, long timeout) throws InterruptedException, + ExecutionException, DeploymentException, EOFException, + TimeoutException { + + Map> headers = new CaseInsensitiveKeyMap<>(); + + int status = 0; + boolean readStatus = false; + boolean readHeaders = false; + String line = null; + while (!readHeaders) { + // On entering loop buffer will be empty and at the start of a new + // loop the buffer will have been fully read. + response.clear(); + // Blocking read + Future read = channel.read(response); + Integer bytesRead = read.get(timeout, TimeUnit.MILLISECONDS); + if (bytesRead.intValue() == -1) { + throw new EOFException(); + } + response.flip(); + while (response.hasRemaining() && !readHeaders) { + if (line == null) { + line = readLine(response); + } else { + line += readLine(response); + } + if ("\r\n".equals(line)) { + readHeaders = true; + } else if (line.endsWith("\r\n")) { + if (readStatus) { + parseHeaders(line, headers); + } else { + status = parseStatus(line); + readStatus = true; + } + line = null; + } + } + } + + return new HttpResponse(status, new WsHandshakeResponse(headers)); + } + + + private int parseStatus(String line) throws DeploymentException { + // This client only understands HTTP 1. + // RFC2616 is case specific + String[] parts = line.trim().split(" "); + // CONNECT for proxy may return a 1.0 response + if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidStatus", line)); + } + try { + return Integer.parseInt(parts[1]); + } catch (NumberFormatException nfe) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidStatus", line)); + } + } + + + private void parseHeaders(String line, Map> headers) { + // Treat headers as single values by default. + + int index = line.indexOf(':'); + if (index == -1) { + log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line)); + return; + } + // Header names are case insensitive so always use lower case + String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH); + // Multi-value headers are stored as a single header and the client is + // expected to handle splitting into individual values + String headerValue = line.substring(index + 1).trim(); + + List values = headers.get(headerName); + if (values == null) { + values = new ArrayList<>(1); + headers.put(headerName, values); + } + values.add(headerValue); + } + + private String readLine(ByteBuffer response) { + // All ISO-8859-1 + StringBuilder sb = new StringBuilder(); + + char c = 0; + while (response.hasRemaining()) { + c = (char) response.get(); + sb.append(c); + if (c == 10) { + break; + } + } + + return sb.toString(); + } + + + private SSLEngine createSSLEngine(Map userProperties, String host, int port) + throws DeploymentException { + + try { + // See if a custom SSLContext has been provided + SSLContext sslContext = + (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY); + + if (sslContext == null) { + // Create the SSL Context + sslContext = SSLContext.getInstance("TLS"); + + // Trust store + String sslTrustStoreValue = + (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY); + if (sslTrustStoreValue != null) { + String sslTrustStorePwdValue = (String) userProperties.get( + Constants.SSL_TRUSTSTORE_PWD_PROPERTY); + if (sslTrustStorePwdValue == null) { + sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT; + } + + File keyStoreFile = new File(sslTrustStoreValue); + KeyStore ks = KeyStore.getInstance("JKS"); + try (InputStream is = new FileInputStream(keyStoreFile)) { + ks.load(is, sslTrustStorePwdValue.toCharArray()); + } + + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ks); + + sslContext.init(null, tmf.getTrustManagers(), null); + } else { + sslContext.init(null, null, null); + } + } + + SSLEngine engine = sslContext.createSSLEngine(host, port); + + String sslProtocolsValue = + (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY); + if (sslProtocolsValue != null) { + engine.setEnabledProtocols(sslProtocolsValue.split(",")); + } + + engine.setUseClientMode(true); + + // Enable host verification + // Start with current settings (returns a copy) + SSLParameters sslParams = engine.getSSLParameters(); + // Use HTTPS since WebSocket starts over HTTP(S) + sslParams.setEndpointIdentificationAlgorithm("HTTPS"); + // Write the parameters back + engine.setSSLParameters(sslParams); + + return engine; + } catch (Exception e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.sslEngineFail"), e); + } + } + + + @Override + public long getDefaultMaxSessionIdleTimeout() { + return defaultMaxSessionIdleTimeout; + } + + + @Override + public void setDefaultMaxSessionIdleTimeout(long timeout) { + this.defaultMaxSessionIdleTimeout = timeout; + } + + + @Override + public int getDefaultMaxBinaryMessageBufferSize() { + return maxBinaryMessageBufferSize; + } + + + @Override + public void setDefaultMaxBinaryMessageBufferSize(int max) { + maxBinaryMessageBufferSize = max; + } + + + @Override + public int getDefaultMaxTextMessageBufferSize() { + return maxTextMessageBufferSize; + } + + + @Override + public void setDefaultMaxTextMessageBufferSize(int max) { + maxTextMessageBufferSize = max; + } + + + /** + * {@inheritDoc} + * + * Currently, this implementation does not support any extensions. + */ + @Override + public Set getInstalledExtensions() { + return Collections.emptySet(); + } + + + /** + * {@inheritDoc} + * + * The default value for this implementation is -1. + */ + @Override + public long getDefaultAsyncSendTimeout() { + return defaultAsyncTimeout; + } + + + /** + * {@inheritDoc} + * + * The default value for this implementation is -1. + */ + @Override + public void setAsyncSendTimeout(long timeout) { + this.defaultAsyncTimeout = timeout; + } + + + /** + * Cleans up the resources still in use by WebSocket sessions created from + * this container. This includes closing sessions and cancelling + * {@link Future}s associated with blocking read/writes. + */ + public void destroy() { + CloseReason cr = new CloseReason( + CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown")); + + for (WsSession session : sessions.keySet()) { + try { + session.close(cr); + } catch (IOException ioe) { + log.debug(sm.getString( + "wsWebSocketContainer.sessionCloseFail", session.getId()), ioe); + } + } + + // Only unregister with AsyncChannelGroupUtil if this instance + // registered with it + if (asynchronousChannelGroup != null) { + synchronized (asynchronousChannelGroupLock) { + if (asynchronousChannelGroup != null) { + AsyncChannelGroupUtil.unregister(); + asynchronousChannelGroup = null; + } + } + } + } + + + private AsynchronousChannelGroup getAsynchronousChannelGroup() { + // Use AsyncChannelGroupUtil to share a common group amongst all + // WebSocket clients + AsynchronousChannelGroup result = asynchronousChannelGroup; + if (result == null) { + synchronized (asynchronousChannelGroupLock) { + if (asynchronousChannelGroup == null) { + asynchronousChannelGroup = AsyncChannelGroupUtil.register(); + } + result = asynchronousChannelGroup; + } + } + return result; + } + + + // ----------------------------------------------- BackgroundProcess methods + + @Override + public void backgroundProcess() { + // This method gets called once a second. + backgroundProcessCount ++; + if (backgroundProcessCount >= processPeriod) { + backgroundProcessCount = 0; + + for (WsSession wsSession : sessions.keySet()) { + wsSession.checkExpiration(); + } + } + + } + + + @Override + public void setProcessPeriod(int period) { + this.processPeriod = period; + } + + + /** + * {@inheritDoc} + * + * The default value is 10 which means session expirations are processed + * every 10 seconds. + */ + @Override + public int getProcessPeriod() { + return processPeriod; + } + + + private static class HttpResponse { + private final int status; + private final HandshakeResponse handshakeResponse; + + public HttpResponse(int status, HandshakeResponse handshakeResponse) { + this.status = status; + this.handshakeResponse = handshakeResponse; + } + + + public int getStatus() { + return status; + } + + + public HandshakeResponse getHandshakeResponse() { + return handshakeResponse; + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/Constants.java b/src/java/nginx/unit/websocket/pojo/Constants.java new file mode 100644 index 00000000..93cdecc7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/Constants.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +/** + * Internal implementation constants. + */ +public class Constants { + + public static final String POJO_PATH_PARAM_KEY = + "nginx.unit.websocket.pojo.PojoEndpoint.pathParams"; + public static final String POJO_METHOD_MAPPING_KEY = + "nginx.unit.websocket.pojo.PojoEndpoint.methodMapping"; + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/pojo/LocalStrings.properties b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties new file mode 100644 index 00000000..00ab7e6b --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pojoEndpointBase.closeSessionFail=Failed to close WebSocket session during error handling +pojoEndpointBase.onCloseFail=Failed to call onClose method of POJO end point for POJO of type [{0}] +pojoEndpointBase.onError=No error handling configured for [{0}] and the following error occurred +pojoEndpointBase.onErrorFail=Failed to call onError method of POJO end point for POJO of type [{0}] +pojoEndpointBase.onOpenFail=Failed to call onOpen method of POJO end point for POJO of type [{0}] +pojoEndpointServer.getPojoInstanceFail=Failed to create instance of POJO of type [{0}] +pojoMethodMapping.decodePathParamFail=Failed to decode path parameter value [{0}] to expected type [{1}] +pojoMethodMapping.duplicateAnnotation=Duplicate annotations [{0}] present on class [{1}] +pojoMethodMapping.duplicateLastParam=Multiple boolean (last) parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicateMessageParam=Multiple message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicatePongMessageParam=Multiple PongMessage parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicateSessionParam=Multiple session parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.invalidDecoder=The specified decoder of type [{0}] could not be instantiated +pojoMethodMapping.invalidPathParamType=Parameters annotated with @PathParam may only be Strings, Java primitives or a boxed version thereof +pojoMethodMapping.methodNotPublic=The annotated method [{0}] is not public +pojoMethodMapping.noPayload=No payload parameter present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.onErrorNoThrowable=No Throwable parameter was present on the method [{0}] of class [{1}] that was annotated with OnError +pojoMethodMapping.paramWithoutAnnotation=A parameter of type [{0}] was found on method[{1}] of class [{2}] that did not have a @PathParam annotation +pojoMethodMapping.partialInputStream=Invalid InputStream and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialObject=Invalid Object and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialPong=Invalid PongMessage and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialReader=Invalid Reader and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.pongWithPayload=Invalid PongMessage and Message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMessageHandlerWhole.decodeIoFail=IO error while decoding message +pojoMessageHandlerWhole.maxBufferSize=The maximum supported message size for this implementation is Integer.MAX_VALUE diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java new file mode 100644 index 00000000..be679a35 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.res.StringManager; + +/** + * Base implementation (client and server have different concrete + * implementations) of the wrapper that converts a POJO instance into a + * WebSocket endpoint instance. + */ +public abstract class PojoEndpointBase extends Endpoint { + + private final Log log = LogFactory.getLog(PojoEndpointBase.class); // must not be static + private static final StringManager sm = StringManager.getManager(PojoEndpointBase.class); + + private Object pojo; + private Map pathParameters; + private PojoMethodMapping methodMapping; + + + protected final void doOnOpen(Session session, EndpointConfig config) { + PojoMethodMapping methodMapping = getMethodMapping(); + Object pojo = getPojo(); + Map pathParameters = getPathParameters(); + + // Add message handlers before calling onOpen since that may trigger a + // message which in turn could trigger a response and/or close the + // session + for (MessageHandler mh : methodMapping.getMessageHandlers(pojo, + pathParameters, session, config)) { + session.addMessageHandler(mh); + } + + if (methodMapping.getOnOpen() != null) { + try { + methodMapping.getOnOpen().invoke(pojo, + methodMapping.getOnOpenArgs( + pathParameters, session, config)); + + } catch (IllegalAccessException e) { + // Reflection related problems + log.error(sm.getString( + "pojoEndpointBase.onOpenFail", + pojo.getClass().getName()), e); + handleOnOpenOrCloseError(session, e); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + handleOnOpenOrCloseError(session, cause); + } catch (Throwable t) { + handleOnOpenOrCloseError(session, t); + } + } + } + + + private void handleOnOpenOrCloseError(Session session, Throwable t) { + // If really fatal - re-throw + ExceptionUtils.handleThrowable(t); + + // Trigger the error handler and close the session + onError(session, t); + try { + session.close(); + } catch (IOException ioe) { + log.warn(sm.getString("pojoEndpointBase.closeSessionFail"), ioe); + } + } + + @Override + public final void onClose(Session session, CloseReason closeReason) { + + if (methodMapping.getOnClose() != null) { + try { + methodMapping.getOnClose().invoke(pojo, + methodMapping.getOnCloseArgs(pathParameters, session, closeReason)); + } catch (Throwable t) { + log.error(sm.getString("pojoEndpointBase.onCloseFail", + pojo.getClass().getName()), t); + handleOnOpenOrCloseError(session, t); + } + } + + // Trigger the destroy method for any associated decoders + Set messageHandlers = session.getMessageHandlers(); + for (MessageHandler messageHandler : messageHandlers) { + if (messageHandler instanceof PojoMessageHandlerWholeBase) { + ((PojoMessageHandlerWholeBase) messageHandler).onClose(); + } + } + } + + + @Override + public final void onError(Session session, Throwable throwable) { + + if (methodMapping.getOnError() == null) { + log.error(sm.getString("pojoEndpointBase.onError", + pojo.getClass().getName()), throwable); + } else { + try { + methodMapping.getOnError().invoke( + pojo, + methodMapping.getOnErrorArgs(pathParameters, session, + throwable)); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.error(sm.getString("pojoEndpointBase.onErrorFail", + pojo.getClass().getName()), t); + } + } + } + + protected Object getPojo() { return pojo; } + protected void setPojo(Object pojo) { this.pojo = pojo; } + + + protected Map getPathParameters() { return pathParameters; } + protected void setPathParameters(Map pathParameters) { + this.pathParameters = pathParameters; + } + + + protected PojoMethodMapping getMethodMapping() { return methodMapping; } + protected void setMethodMapping(PojoMethodMapping methodMapping) { + this.methodMapping = methodMapping; + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java new file mode 100644 index 00000000..6e569487 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.util.Collections; +import java.util.List; + +import javax.websocket.Decoder; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + + +/** + * Wrapper class for instances of POJOs annotated with + * {@link javax.websocket.ClientEndpoint} so they appear as standard + * {@link javax.websocket.Endpoint} instances. + */ +public class PojoEndpointClient extends PojoEndpointBase { + + public PojoEndpointClient(Object pojo, + List> decoders) throws DeploymentException { + setPojo(pojo); + setMethodMapping( + new PojoMethodMapping(pojo.getClass(), decoders, null)); + setPathParameters(Collections.emptyMap()); + } + + @Override + public void onOpen(Session session, EndpointConfig config) { + doOnOpen(session, config); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java new file mode 100644 index 00000000..499f8274 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.util.Map; + +import javax.websocket.EndpointConfig; +import javax.websocket.Session; +import javax.websocket.server.ServerEndpointConfig; + +import org.apache.tomcat.util.res.StringManager; + +/** + * Wrapper class for instances of POJOs annotated with + * {@link javax.websocket.server.ServerEndpoint} so they appear as standard + * {@link javax.websocket.Endpoint} instances. + */ +public class PojoEndpointServer extends PojoEndpointBase { + + private static final StringManager sm = + StringManager.getManager(PojoEndpointServer.class); + + @Override + public void onOpen(Session session, EndpointConfig endpointConfig) { + + ServerEndpointConfig sec = (ServerEndpointConfig) endpointConfig; + + Object pojo; + try { + pojo = sec.getConfigurator().getEndpointInstance( + sec.getEndpointClass()); + } catch (InstantiationException e) { + throw new IllegalArgumentException(sm.getString( + "pojoEndpointServer.getPojoInstanceFail", + sec.getEndpointClass().getName()), e); + } + setPojo(pojo); + + @SuppressWarnings("unchecked") + Map pathParameters = + (Map) sec.getUserProperties().get( + Constants.POJO_PATH_PARAM_KEY); + setPathParameters(pathParameters); + + PojoMethodMapping methodMapping = + (PojoMethodMapping) sec.getUserProperties().get( + Constants.POJO_METHOD_MAPPING_KEY); + setMethodMapping(methodMapping); + + doOnOpen(session, endpointConfig); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java new file mode 100644 index 00000000..b72d719a --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.EncodeException; +import javax.websocket.MessageHandler; +import javax.websocket.RemoteEndpoint; +import javax.websocket.Session; + +import org.apache.tomcat.util.ExceptionUtils; +import nginx.unit.websocket.WrappedMessageHandler; + +/** + * Common implementation code for the POJO message handlers. + * + * @param The type of message to handle + */ +public abstract class PojoMessageHandlerBase + implements WrappedMessageHandler { + + protected final Object pojo; + protected final Method method; + protected final Session session; + protected final Object[] params; + protected final int indexPayload; + protected final boolean convert; + protected final int indexSession; + protected final long maxMessageSize; + + public PojoMessageHandlerBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexSession, long maxMessageSize) { + this.pojo = pojo; + this.method = method; + // TODO: The method should already be accessible here but the following + // code seems to be necessary in some as yet not fully understood cases. + try { + this.method.setAccessible(true); + } catch (Exception e) { + // It is better to make sure the method is accessible, but + // ignore exceptions and hope for the best + } + this.session = session; + this.params = params; + this.indexPayload = indexPayload; + this.convert = convert; + this.indexSession = indexSession; + this.maxMessageSize = maxMessageSize; + } + + + protected final void processResult(Object result) { + if (result == null) { + return; + } + + RemoteEndpoint.Basic remoteEndpoint = session.getBasicRemote(); + try { + if (result instanceof String) { + remoteEndpoint.sendText((String) result); + } else if (result instanceof ByteBuffer) { + remoteEndpoint.sendBinary((ByteBuffer) result); + } else if (result instanceof byte[]) { + remoteEndpoint.sendBinary(ByteBuffer.wrap((byte[]) result)); + } else { + remoteEndpoint.sendObject(result); + } + } catch (IOException | EncodeException ioe) { + throw new IllegalStateException(ioe); + } + } + + + /** + * Expose the POJO if it is a message handler so the Session is able to + * match requests to remove handlers if the original handler has been + * wrapped. + */ + @Override + public final MessageHandler getWrappedHandler() { + if (pojo instanceof MessageHandler) { + return (MessageHandler) pojo; + } else { + return null; + } + } + + + @Override + public final long getMaxMessageSize() { + return maxMessageSize; + } + + + protected final void handlePojoMethodException(Throwable t) { + t = ExceptionUtils.unwrapInvocationTargetException(t); + ExceptionUtils.handleThrowable(t); + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t.getMessage(), t); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java new file mode 100644 index 00000000..d6f37724 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.DecodeException; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import nginx.unit.websocket.WsSession; + +/** + * Common implementation code for the POJO partial message handlers. All + * the real work is done in this class and in the superclass. + * + * @param The type of message to handle + */ +public abstract class PojoMessageHandlerPartialBase + extends PojoMessageHandlerBase implements MessageHandler.Partial { + + private final int indexBoolean; + + public PojoMessageHandlerPartialBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, + boolean convert, int indexBoolean, int indexSession, + long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + this.indexBoolean = indexBoolean; + } + + + @Override + public final void onMessage(T message, boolean last) { + if (params.length == 1 && params[0] instanceof DecodeException) { + ((WsSession) session).getLocal().onError(session, + (DecodeException) params[0]); + return; + } + Object[] parameters = params.clone(); + if (indexBoolean != -1) { + parameters[indexBoolean] = Boolean.valueOf(last); + } + if (indexSession != -1) { + parameters[indexSession] = session; + } + if (convert) { + parameters[indexPayload] = ((ByteBuffer) message).array(); + } else { + parameters[indexPayload] = message; + } + Object result = null; + try { + result = method.invoke(pojo, parameters); + } catch (IllegalAccessException | InvocationTargetException e) { + handlePojoMethodException(e); + } + processResult(result); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java new file mode 100644 index 00000000..1d334017 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.Session; + +/** + * ByteBuffer specific concrete implementation for handling partial messages. + */ +public class PojoMessageHandlerPartialBinary + extends PojoMessageHandlerPartialBase { + + public PojoMessageHandlerPartialBinary(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexBoolean, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, indexBoolean, + indexSession, maxMessageSize); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java new file mode 100644 index 00000000..8f7c1a0d --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; + +import javax.websocket.Session; + +/** + * Text specific concrete implementation for handling partial messages. + */ +public class PojoMessageHandlerPartialText + extends PojoMessageHandlerPartialBase { + + public PojoMessageHandlerPartialText(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexBoolean, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, indexBoolean, + indexSession, maxMessageSize); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java new file mode 100644 index 00000000..23333eb7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import javax.websocket.DecodeException; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import nginx.unit.websocket.WsSession; + +/** + * Common implementation code for the POJO whole message handlers. All the real + * work is done in this class and in the superclass. + * + * @param The type of message to handle + */ +public abstract class PojoMessageHandlerWholeBase + extends PojoMessageHandlerBase implements MessageHandler.Whole { + + public PojoMessageHandlerWholeBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, + boolean convert, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + } + + + @Override + public final void onMessage(T message) { + + if (params.length == 1 && params[0] instanceof DecodeException) { + ((WsSession) session).getLocal().onError(session, + (DecodeException) params[0]); + return; + } + + // Can this message be decoded? + Object payload; + try { + payload = decode(message); + } catch (DecodeException de) { + ((WsSession) session).getLocal().onError(session, de); + return; + } + + if (payload == null) { + // Not decoded. Convert if required. + if (convert) { + payload = convert(message); + } else { + payload = message; + } + } + + Object[] parameters = params.clone(); + if (indexSession != -1) { + parameters[indexSession] = session; + } + parameters[indexPayload] = payload; + + Object result = null; + try { + result = method.invoke(pojo, parameters); + } catch (IllegalAccessException | InvocationTargetException e) { + handlePojoMethodException(e); + } + processResult(result); + } + + protected Object convert(T message) { + return message; + } + + + protected abstract Object decode(T message) throws DecodeException; + protected abstract void onClose(); +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java new file mode 100644 index 00000000..07ff0648 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Binary; +import javax.websocket.Decoder.BinaryStream; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; + +/** + * ByteBuffer specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholeBinary + extends PojoMessageHandlerWholeBase { + + private static final StringManager sm = + StringManager.getManager(PojoMessageHandlerWholeBinary.class); + + private final List decoders = new ArrayList<>(); + + private final boolean isForInputStream; + + public PojoMessageHandlerWholeBinary(Object pojo, Method method, + Session session, EndpointConfig config, + List> decoderClazzes, Object[] params, + int indexPayload, boolean convert, int indexSession, + boolean isForInputStream, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + + // Update binary text size handled by session + if (maxMessageSize > -1 && maxMessageSize > session.getMaxBinaryMessageBufferSize()) { + if (maxMessageSize > Integer.MAX_VALUE) { + throw new IllegalArgumentException(sm.getString( + "pojoMessageHandlerWhole.maxBufferSize")); + } + session.setMaxBinaryMessageBufferSize((int) maxMessageSize); + } + + try { + if (decoderClazzes != null) { + for (Class decoderClazz : decoderClazzes) { + if (Binary.class.isAssignableFrom(decoderClazz)) { + Binary decoder = (Binary) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else if (BinaryStream.class.isAssignableFrom( + decoderClazz)) { + BinaryStream decoder = (BinaryStream) + decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else { + // Text decoder - ignore it + } + } + } + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException(e); + } + this.isForInputStream = isForInputStream; + } + + + @Override + protected Object decode(ByteBuffer message) throws DecodeException { + for (Decoder decoder : decoders) { + if (decoder instanceof Binary) { + if (((Binary) decoder).willDecode(message)) { + return ((Binary) decoder).decode(message); + } + } else { + byte[] array = new byte[message.limit() - message.position()]; + message.get(array); + ByteArrayInputStream bais = new ByteArrayInputStream(array); + try { + return ((BinaryStream) decoder).decode(bais); + } catch (IOException ioe) { + throw new DecodeException(message, sm.getString( + "pojoMessageHandlerWhole.decodeIoFail"), ioe); + } + } + } + return null; + } + + + @Override + protected Object convert(ByteBuffer message) { + byte[] array = new byte[message.remaining()]; + message.get(array); + if (isForInputStream) { + return new ByteArrayInputStream(array); + } else { + return array; + } + } + + + @Override + protected void onClose() { + for (Decoder decoder : decoders) { + decoder.destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java new file mode 100644 index 00000000..bdedd7de --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; + +import javax.websocket.PongMessage; +import javax.websocket.Session; + +/** + * PongMessage specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholePong + extends PojoMessageHandlerWholeBase { + + public PojoMessageHandlerWholePong(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexSession) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, -1); + } + + @Override + protected Object decode(PongMessage message) { + // Never decoded + return null; + } + + + @Override + protected void onClose() { + // NO-OP + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java new file mode 100644 index 00000000..59007349 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.io.StringReader; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Text; +import javax.websocket.Decoder.TextStream; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.Util; + + +/** + * Text specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholeText + extends PojoMessageHandlerWholeBase { + + private static final StringManager sm = + StringManager.getManager(PojoMessageHandlerWholeText.class); + + private final List decoders = new ArrayList<>(); + private final Class primitiveType; + + public PojoMessageHandlerWholeText(Object pojo, Method method, + Session session, EndpointConfig config, + List> decoderClazzes, Object[] params, + int indexPayload, boolean convert, int indexSession, + long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + + // Update max text size handled by session + if (maxMessageSize > -1 && maxMessageSize > session.getMaxTextMessageBufferSize()) { + if (maxMessageSize > Integer.MAX_VALUE) { + throw new IllegalArgumentException(sm.getString( + "pojoMessageHandlerWhole.maxBufferSize")); + } + session.setMaxTextMessageBufferSize((int) maxMessageSize); + } + + // Check for primitives + Class type = method.getParameterTypes()[indexPayload]; + if (Util.isPrimitive(type)) { + primitiveType = type; + return; + } else { + primitiveType = null; + } + + try { + if (decoderClazzes != null) { + for (Class decoderClazz : decoderClazzes) { + if (Text.class.isAssignableFrom(decoderClazz)) { + Text decoder = (Text) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else if (TextStream.class.isAssignableFrom( + decoderClazz)) { + TextStream decoder = + (TextStream) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else { + // Binary decoder - ignore it + } + } + } + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException(e); + } + } + + + @Override + protected Object decode(String message) throws DecodeException { + // Handle primitives + if (primitiveType != null) { + return Util.coerceToType(primitiveType, message); + } + // Handle full decoders + for (Decoder decoder : decoders) { + if (decoder instanceof Text) { + if (((Text) decoder).willDecode(message)) { + return ((Text) decoder).decode(message); + } + } else { + StringReader r = new StringReader(message); + try { + return ((TextStream) decoder).decode(r); + } catch (IOException ioe) { + throw new DecodeException(message, sm.getString( + "pojoMessageHandlerWhole.decodeIoFail"), ioe); + } + } + } + return null; + } + + + @Override + protected Object convert(String message) { + return new StringReader(message); + } + + + @Override + protected void onClose() { + for (Decoder decoder : decoders) { + decoder.destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java new file mode 100644 index 00000000..2385b5c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java @@ -0,0 +1,731 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.OnClose; +import javax.websocket.OnError; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.PongMessage; +import javax.websocket.Session; +import javax.websocket.server.PathParam; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.DecoderEntry; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.Util.DecoderMatch; + +/** + * For a POJO class annotated with + * {@link javax.websocket.server.ServerEndpoint}, an instance of this class + * creates and caches the method handler, method information and parameter + * information for the onXXX calls. + */ +public class PojoMethodMapping { + + private static final StringManager sm = + StringManager.getManager(PojoMethodMapping.class); + + private final Method onOpen; + private final Method onClose; + private final Method onError; + private final PojoPathParam[] onOpenParams; + private final PojoPathParam[] onCloseParams; + private final PojoPathParam[] onErrorParams; + private final List onMessage = new ArrayList<>(); + private final String wsPath; + + + public PojoMethodMapping(Class clazzPojo, + List> decoderClazzes, String wsPath) + throws DeploymentException { + + this.wsPath = wsPath; + + List decoders = Util.getDecoders(decoderClazzes); + Method open = null; + Method close = null; + Method error = null; + Method[] clazzPojoMethods = null; + Class currentClazz = clazzPojo; + while (!currentClazz.equals(Object.class)) { + Method[] currentClazzMethods = currentClazz.getDeclaredMethods(); + if (currentClazz == clazzPojo) { + clazzPojoMethods = currentClazzMethods; + } + for (Method method : currentClazzMethods) { + if (method.getAnnotation(OnOpen.class) != null) { + checkPublic(method); + if (open == null) { + open = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(open, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnOpen.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnClose.class) != null) { + checkPublic(method); + if (close == null) { + close = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(close, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnClose.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnError.class) != null) { + checkPublic(method); + if (error == null) { + error = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(error, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnError.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnMessage.class) != null) { + checkPublic(method); + MessageHandlerInfo messageHandler = new MessageHandlerInfo(method, decoders); + boolean found = false; + for (MessageHandlerInfo otherMessageHandler : onMessage) { + if (messageHandler.targetsSameWebSocketMessageType(otherMessageHandler)) { + found = true; + if (currentClazz == clazzPojo || + !isMethodOverride(messageHandler.m, otherMessageHandler.m)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnMessage.class, currentClazz)); + } + } + } + if (!found) { + onMessage.add(messageHandler); + } + } else { + // Method not annotated + } + } + currentClazz = currentClazz.getSuperclass(); + } + // If the methods are not on clazzPojo and they are overridden + // by a non annotated method in clazzPojo, they should be ignored + if (open != null && open.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, open, OnOpen.class)) { + open = null; + } + } + if (close != null && close.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, close, OnClose.class)) { + close = null; + } + } + if (error != null && error.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, error, OnError.class)) { + error = null; + } + } + List overriddenOnMessage = new ArrayList<>(); + for (MessageHandlerInfo messageHandler : onMessage) { + if (messageHandler.m.getDeclaringClass() != clazzPojo + && isOverridenWithoutAnnotation(clazzPojoMethods, messageHandler.m, OnMessage.class)) { + overriddenOnMessage.add(messageHandler); + } + } + for (MessageHandlerInfo messageHandler : overriddenOnMessage) { + onMessage.remove(messageHandler); + } + this.onOpen = open; + this.onClose = close; + this.onError = error; + onOpenParams = getPathParams(onOpen, MethodType.ON_OPEN); + onCloseParams = getPathParams(onClose, MethodType.ON_CLOSE); + onErrorParams = getPathParams(onError, MethodType.ON_ERROR); + } + + + private void checkPublic(Method m) throws DeploymentException { + if (!Modifier.isPublic(m.getModifiers())) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.methodNotPublic", m.getName())); + } + } + + + private boolean isMethodOverride(Method method1, Method method2) { + return method1.getName().equals(method2.getName()) + && method1.getReturnType().equals(method2.getReturnType()) + && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes()); + } + + + private boolean isOverridenWithoutAnnotation(Method[] methods, + Method superclazzMethod, Class annotation) { + for (Method method : methods) { + if (isMethodOverride(method, superclazzMethod) + && (method.getAnnotation(annotation) == null)) { + return true; + } + } + return false; + } + + + public String getWsPath() { + return wsPath; + } + + + public Method getOnOpen() { + return onOpen; + } + + + public Object[] getOnOpenArgs(Map pathParameters, + Session session, EndpointConfig config) throws DecodeException { + return buildArgs(onOpenParams, pathParameters, session, config, null, + null); + } + + + public Method getOnClose() { + return onClose; + } + + + public Object[] getOnCloseArgs(Map pathParameters, + Session session, CloseReason closeReason) throws DecodeException { + return buildArgs(onCloseParams, pathParameters, session, null, null, + closeReason); + } + + + public Method getOnError() { + return onError; + } + + + public Object[] getOnErrorArgs(Map pathParameters, + Session session, Throwable throwable) throws DecodeException { + return buildArgs(onErrorParams, pathParameters, session, null, + throwable, null); + } + + + public boolean hasMessageHandlers() { + return !onMessage.isEmpty(); + } + + + public Set getMessageHandlers(Object pojo, + Map pathParameters, Session session, + EndpointConfig config) { + Set result = new HashSet<>(); + for (MessageHandlerInfo messageMethod : onMessage) { + result.addAll(messageMethod.getMessageHandlers(pojo, pathParameters, + session, config)); + } + return result; + } + + + private static PojoPathParam[] getPathParams(Method m, + MethodType methodType) throws DeploymentException { + if (m == null) { + return new PojoPathParam[0]; + } + boolean foundThrowable = false; + Class[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + PojoPathParam[] result = new PojoPathParam[types.length]; + for (int i = 0; i < types.length; i++) { + Class type = types[i]; + if (type.equals(Session.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_OPEN && + type.equals(EndpointConfig.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_ERROR + && type.equals(Throwable.class)) { + foundThrowable = true; + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_CLOSE && + type.equals(CloseReason.class)) { + result[i] = new PojoPathParam(type, null); + } else { + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + // Check that the type is valid. "0" coerces to every + // valid type + try { + Util.coerceToType(type, "0"); + } catch (IllegalArgumentException iae) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.invalidPathParamType"), + iae); + } + result[i] = new PojoPathParam(type, + ((PathParam) paramAnnotation).value()); + break; + } + } + // Parameters without annotations are not permitted + if (result[i] == null) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.paramWithoutAnnotation", + type, m.getName(), m.getClass().getName())); + } + } + } + if (methodType == MethodType.ON_ERROR && !foundThrowable) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.onErrorNoThrowable", + m.getName(), m.getDeclaringClass().getName())); + } + return result; + } + + + private static Object[] buildArgs(PojoPathParam[] pathParams, + Map pathParameters, Session session, + EndpointConfig config, Throwable throwable, CloseReason closeReason) + throws DecodeException { + Object[] result = new Object[pathParams.length]; + for (int i = 0; i < pathParams.length; i++) { + Class type = pathParams[i].getType(); + if (type.equals(Session.class)) { + result[i] = session; + } else if (type.equals(EndpointConfig.class)) { + result[i] = config; + } else if (type.equals(Throwable.class)) { + result[i] = throwable; + } else if (type.equals(CloseReason.class)) { + result[i] = closeReason; + } else { + String name = pathParams[i].getName(); + String value = pathParameters.get(name); + try { + result[i] = Util.coerceToType(type, value); + } catch (Exception e) { + throw new DecodeException(value, sm.getString( + "pojoMethodMapping.decodePathParamFail", + value, type), e); + } + } + } + return result; + } + + + private static class MessageHandlerInfo { + + private final Method m; + private int indexString = -1; + private int indexByteArray = -1; + private int indexByteBuffer = -1; + private int indexPong = -1; + private int indexBoolean = -1; + private int indexSession = -1; + private int indexInputStream = -1; + private int indexReader = -1; + private int indexPrimitive = -1; + private Class primitiveType = null; + private Map indexPathParams = new HashMap<>(); + private int indexPayload = -1; + private DecoderMatch decoderMatch = null; + private long maxMessageSize = -1; + + public MessageHandlerInfo(Method m, List decoderEntries) { + this.m = m; + + Class[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + + for (int i = 0; i < types.length; i++) { + boolean paramFound = false; + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + indexPathParams.put( + Integer.valueOf(i), new PojoPathParam(types[i], + ((PathParam) paramAnnotation).value())); + paramFound = true; + break; + } + } + if (paramFound) { + continue; + } + if (String.class.isAssignableFrom(types[i])) { + if (indexString == -1) { + indexString = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Reader.class.isAssignableFrom(types[i])) { + if (indexReader == -1) { + indexReader = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (boolean.class == types[i]) { + if (indexBoolean == -1) { + indexBoolean = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateLastParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (ByteBuffer.class.isAssignableFrom(types[i])) { + if (indexByteBuffer == -1) { + indexByteBuffer = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (byte[].class == types[i]) { + if (indexByteArray == -1) { + indexByteArray = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (InputStream.class.isAssignableFrom(types[i])) { + if (indexInputStream == -1) { + indexInputStream = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Util.isPrimitive(types[i])) { + if (indexPrimitive == -1) { + indexPrimitive = i; + primitiveType = types[i]; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Session.class.isAssignableFrom(types[i])) { + if (indexSession == -1) { + indexSession = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateSessionParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (PongMessage.class.isAssignableFrom(types[i])) { + if (indexPong == -1) { + indexPong = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicatePongMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else { + if (decoderMatch != null && decoderMatch.hasMatches()) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + decoderMatch = new DecoderMatch(types[i], decoderEntries); + + if (decoderMatch.hasMatches()) { + indexPayload = i; + } + } + } + + // Additional checks required + if (indexString != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexString; + } + } + if (indexReader != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexReader; + } + } + if (indexByteArray != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteArray; + } + } + if (indexByteBuffer != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteBuffer; + } + } + if (indexInputStream != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexInputStream; + } + } + if (indexPrimitive != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPrimitive; + } + } + if (indexPong != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.pongWithPayload", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPong; + } + } + if (indexPayload == -1 && indexPrimitive == -1 && + indexBoolean != -1) { + // The boolean we found is a payload, not a last flag + indexPayload = indexBoolean; + indexPrimitive = indexBoolean; + primitiveType = Boolean.TYPE; + indexBoolean = -1; + } + if (indexPayload == -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.noPayload", + m.getName(), m.getDeclaringClass().getName())); + } + if (indexPong != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialPong", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexReader != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialReader", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexInputStream != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialInputStream", + m.getName(), m.getDeclaringClass().getName())); + } + if (decoderMatch != null && decoderMatch.hasMatches() && + indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialObject", + m.getName(), m.getDeclaringClass().getName())); + } + + maxMessageSize = m.getAnnotation(OnMessage.class).maxMessageSize(); + } + + + public boolean targetsSameWebSocketMessageType(MessageHandlerInfo otherHandler) { + if (otherHandler == null) { + return false; + } + if (indexByteArray >= 0 && otherHandler.indexByteArray >= 0) { + return true; + } + if (indexByteBuffer >= 0 && otherHandler.indexByteBuffer >= 0) { + return true; + } + if (indexInputStream >= 0 && otherHandler.indexInputStream >= 0) { + return true; + } + if (indexPong >= 0 && otherHandler.indexPong >= 0) { + return true; + } + if (indexPrimitive >= 0 && otherHandler.indexPrimitive >= 0 + && primitiveType == otherHandler.primitiveType) { + return true; + } + if (indexReader >= 0 && otherHandler.indexReader >= 0) { + return true; + } + if (indexString >= 0 && otherHandler.indexString >= 0) { + return true; + } + if (decoderMatch != null && otherHandler.decoderMatch != null + && decoderMatch.getTarget().equals(otherHandler.decoderMatch.getTarget())) { + return true; + } + return false; + } + + + public Set getMessageHandlers(Object pojo, + Map pathParameters, Session session, + EndpointConfig config) { + Object[] params = new Object[m.getParameterTypes().length]; + + for (Map.Entry entry : + indexPathParams.entrySet()) { + PojoPathParam pathParam = entry.getValue(); + String valueString = pathParameters.get(pathParam.getName()); + Object value = null; + try { + value = Util.coerceToType(pathParam.getType(), valueString); + } catch (Exception e) { + DecodeException de = new DecodeException(valueString, + sm.getString( + "pojoMethodMapping.decodePathParamFail", + valueString, pathParam.getType()), e); + params = new Object[] { de }; + break; + } + params[entry.getKey().intValue()] = value; + } + + Set results = new HashSet<>(2); + if (indexBoolean == -1) { + // Basic + if (indexString != -1 || indexPrimitive != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexPayload, false, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexReader != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexReader, true, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteArray, + true, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexByteBuffer != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteBuffer, + false, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexInputStream != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexInputStream, + true, indexSession, true, maxMessageSize); + results.add(mh); + } else if (decoderMatch != null && decoderMatch.hasMatches()) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeBinary( + pojo, m, session, config, + decoderMatch.getBinaryDecoders(), params, + indexPayload, true, indexSession, true, + maxMessageSize); + results.add(mh); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeText( + pojo, m, session, config, + decoderMatch.getTextDecoders(), params, + indexPayload, true, indexSession, maxMessageSize); + results.add(mh); + } + } else { + MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m, + session, params, indexPong, false, indexSession); + results.add(mh); + } + } else { + // ASync + if (indexString != -1) { + MessageHandler mh = new PojoMessageHandlerPartialText(pojo, + m, session, params, indexString, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteArray, true, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteBuffer, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } + } + return results; + } + } + + + private enum MethodType { + ON_OPEN, + ON_CLOSE, + ON_ERROR + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoPathParam.java b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java new file mode 100644 index 00000000..859b6d68 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +/** + * Stores the parameter type and name for a parameter that needs to be passed to + * an onXxx method of {@link javax.websocket.Endpoint}. The name is only present + * for parameters annotated with + * {@link javax.websocket.server.PathParam}. For the + * {@link javax.websocket.Session} and {@link java.lang.Throwable} parameters, + * {@link #getName()} will always return null. + */ +public class PojoPathParam { + + private final Class type; + private final String name; + + + public PojoPathParam(Class type, String name) { + this.type = type; + this.name = name; + } + + + public Class getType() { + return type; + } + + + public String getName() { + return name; + } +} diff --git a/src/java/nginx/unit/websocket/pojo/package-info.java b/src/java/nginx/unit/websocket/pojo/package-info.java new file mode 100644 index 00000000..39cf80c8 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package provides the necessary plumbing to convert an annotated POJO + * into a WebSocket {@link javax.websocket.Endpoint}. + */ +package nginx.unit.websocket.pojo; diff --git a/src/java/nginx/unit/websocket/server/Constants.java b/src/java/nginx/unit/websocket/server/Constants.java new file mode 100644 index 00000000..5210c4ba --- /dev/null +++ b/src/java/nginx/unit/websocket/server/Constants.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +/** + * Internal implementation constants. + */ +public class Constants { + + public static final String BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.binaryBufferSize"; + public static final String TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.textBufferSize"; + public static final String ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.noAddAfterHandshake"; + + public static final String SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE = + "javax.websocket.server.ServerContainer"; + + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java new file mode 100644 index 00000000..43ffe2bc --- /dev/null +++ b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.HandshakeRequest; +import javax.websocket.server.ServerEndpointConfig; + +public class DefaultServerEndpointConfigurator + extends ServerEndpointConfig.Configurator { + + @Override + public T getEndpointInstance(Class clazz) + throws InstantiationException { + try { + return clazz.getConstructor().newInstance(); + } catch (InstantiationException e) { + throw e; + } catch (ReflectiveOperationException e) { + InstantiationException ie = new InstantiationException(); + ie.initCause(e); + throw ie; + } + } + + + @Override + public String getNegotiatedSubprotocol(List supported, + List requested) { + + for (String request : requested) { + if (supported.contains(request)) { + return request; + } + } + return ""; + } + + + @Override + public List getNegotiatedExtensions(List installed, + List requested) { + Set installedNames = new HashSet<>(); + for (Extension e : installed) { + installedNames.add(e.getName()); + } + List result = new ArrayList<>(); + for (Extension request : requested) { + if (installedNames.contains(request.getName())) { + result.add(request); + } + } + return result; + } + + + @Override + public boolean checkOrigin(String originHeaderValue) { + return true; + } + + @Override + public void modifyHandshake(ServerEndpointConfig sec, + HandshakeRequest request, HandshakeResponse response) { + // NO-OP + } + +} diff --git a/src/java/nginx/unit/websocket/server/LocalStrings.properties b/src/java/nginx/unit/websocket/server/LocalStrings.properties new file mode 100644 index 00000000..5bc12501 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/LocalStrings.properties @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +serverContainer.addNotAllowed=No further Endpoints may be registered once an attempt has been made to use one of the previously registered endpoints +serverContainer.configuratorFail=Failed to create configurator of type [{0}] for POJO of type [{1}] +serverContainer.duplicatePaths=Multiple Endpoints may not be deployed to the same path [{0}] : existing endpoint was [{1}] and new endpoint is [{2}] +serverContainer.encoderFail=Unable to create encoder of type [{0}] +serverContainer.endpointDeploy=Endpoint class [{0}] deploying to path [{1}] in ServletContext [{2}] +serverContainer.missingAnnotation=Cannot deploy POJO class [{0}] as it is not annotated with @ServerEndpoint +serverContainer.missingEndpoint=An Endpoint instance has been request for path [{0}] but no matching Endpoint class was found +serverContainer.pojoDeploy=POJO class [{0}] deploying to path [{1}] in ServletContext [{2}] +serverContainer.servletContextMismatch=Attempted to register a POJO annotated for WebSocket at path [{0}] in the ServletContext with context path [{1}] when the WebSocket ServerContainer is allocated to the ServletContext with context path [{2}] +serverContainer.servletContextMissing=No ServletContext was specified + +upgradeUtil.incompatibleRsv=Extensions were specified that have incompatible RSV bit usage + +uriTemplate.duplicateParameter=The parameter [{0}] appears more than once in the path which is not permitted +uriTemplate.emptySegment=The path [{0}] contains one or more empty segments which are is not permitted +uriTemplate.invalidPath=The path [{0}] is not valid. +uriTemplate.invalidSegment=The segment [{0}] is not valid in the provided path [{1}] + +wsFrameServer.bytesRead=Read [{0}] bytes into input buffer ready for processing +wsFrameServer.illegalReadState=Unexpected read state [{0}] +wsFrameServer.onDataAvailable=Method entry + +wsHttpUpgradeHandler.closeOnError=Closing WebSocket connection due to an error +wsHttpUpgradeHandler.destroyFailed=Failed to close WebConnection while destroying the WebSocket HttpUpgradeHandler +wsHttpUpgradeHandler.noPreInit=The preInit() method must be called to configure the WebSocket HttpUpgradeHandler before the container calls init(). Usually, this means the Servlet that created the WsHttpUpgradeHandler instance should also call preInit() +wsHttpUpgradeHandler.serverStop=The server is stopping + +wsRemoteEndpointServer.closeFailed=Failed to close the ServletOutputStream connection cleanly diff --git a/src/java/nginx/unit/websocket/server/UpgradeUtil.java b/src/java/nginx/unit/websocket/server/UpgradeUtil.java new file mode 100644 index 00000000..162f01c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/UpgradeUtil.java @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.ServerEndpointConfig; + +import nginx.unit.Request; + +import org.apache.tomcat.util.codec.binary.Base64; +import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.security.ConcurrentMessageDigest; +import nginx.unit.websocket.Constants; +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.TransformationFactory; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.WsHandshakeResponse; +import nginx.unit.websocket.pojo.PojoEndpointServer; + +public class UpgradeUtil { + + private static final StringManager sm = + StringManager.getManager(UpgradeUtil.class.getPackage().getName()); + private static final byte[] WS_ACCEPT = + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes( + StandardCharsets.ISO_8859_1); + + private UpgradeUtil() { + // Utility class. Hide default constructor. + } + + /** + * Checks to see if this is an HTTP request that includes a valid upgrade + * request to web socket. + *

+ * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java + * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC + * 6455 section 4.1 requires that a WebSocket Upgrade uses GET. + * @param request The request to check if it is an HTTP upgrade request for + * a WebSocket connection + * @param response The response associated with the request + * @return true if the request includes a HTTP Upgrade request + * for the WebSocket protocol, otherwise false + */ + public static boolean isWebSocketUpgradeRequest(ServletRequest request, + ServletResponse response) { + + Request r = (Request) request.getAttribute(Request.BARE); + + return ((request instanceof HttpServletRequest) && + (response instanceof HttpServletResponse) && + (r != null) && + (r.isUpgrade())); + } + + + public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, + HttpServletResponse resp, ServerEndpointConfig sec, + Map pathParams) + throws ServletException, IOException { + + + // Origin check + String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME); + + if (!sec.getConfigurator().checkOrigin(origin)) { + resp.sendError(HttpServletResponse.SC_FORBIDDEN); + return; + } + // Sub-protocols + List subProtocols = getTokensFromHeader(req, + Constants.WS_PROTOCOL_HEADER_NAME); + String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol( + sec.getSubprotocols(), subProtocols); + + // Extensions + // Should normally only be one header but handle the case of multiple + // headers + List extensionsRequested = new ArrayList<>(); + Enumeration extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME); + while (extHeaders.hasMoreElements()) { + Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement()); + } + + // Negotiation phase 1. By default this simply filters out the + // extensions that the server does not support but applications could + // use a custom configurator to do more than this. + List installedExtensions = null; + if (sec.getExtensions().size() == 0) { + installedExtensions = Constants.INSTALLED_EXTENSIONS; + } else { + installedExtensions = new ArrayList<>(); + installedExtensions.addAll(sec.getExtensions()); + installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS); + } + List negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions( + installedExtensions, extensionsRequested); + + // Negotiation phase 2. Create the Transformations that will be applied + // to this connection. Note than an extension may be dropped at this + // point if the client has requested a configuration that the server is + // unable to support. + List transformations = createTransformations(negotiatedExtensionsPhase1); + + List negotiatedExtensionsPhase2; + if (transformations.isEmpty()) { + negotiatedExtensionsPhase2 = Collections.emptyList(); + } else { + negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size()); + for (Transformation t : transformations) { + negotiatedExtensionsPhase2.add(t.getExtensionResponse()); + } + } + + WsHttpUpgradeHandler wsHandler = + req.upgrade(WsHttpUpgradeHandler.class); + + WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams); + WsHandshakeResponse wsResponse = new WsHandshakeResponse(); + WsPerSessionServerEndpointConfig perSessionServerEndpointConfig = + new WsPerSessionServerEndpointConfig(sec); + sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig, + wsRequest, wsResponse); + //wsRequest.finished(); + + // Add any additional headers + for (Entry> entry : + wsResponse.getHeaders().entrySet()) { + for (String headerValue: entry.getValue()) { + resp.addHeader(entry.getKey(), headerValue); + } + } + + Endpoint ep; + try { + Class clazz = sec.getEndpointClass(); + if (Endpoint.class.isAssignableFrom(clazz)) { + ep = (Endpoint) sec.getConfigurator().getEndpointInstance( + clazz); + } else { + ep = new PojoEndpointServer(); + // Need to make path params available to POJO + perSessionServerEndpointConfig.getUserProperties().put( + nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams); + } + } catch (InstantiationException e) { + throw new ServletException(e); + } + + wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest, + negotiatedExtensionsPhase2, subProtocol, null, pathParams, + req.isSecure()); + + wsHandler.init(null); + } + + + private static List createTransformations( + List negotiatedExtensions) { + + TransformationFactory factory = TransformationFactory.getInstance(); + + LinkedHashMap>> extensionPreferences = + new LinkedHashMap<>(); + + // Result will likely be smaller than this + List result = new ArrayList<>(negotiatedExtensions.size()); + + for (Extension extension : negotiatedExtensions) { + List> preferences = + extensionPreferences.get(extension.getName()); + + if (preferences == null) { + preferences = new ArrayList<>(); + extensionPreferences.put(extension.getName(), preferences); + } + + preferences.add(extension.getParameters()); + } + + for (Map.Entry>> entry : + extensionPreferences.entrySet()) { + Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true); + if (transformation != null) { + result.add(transformation); + } + } + return result; + } + + + private static void append(StringBuilder sb, Extension extension) { + if (extension == null || extension.getName() == null || extension.getName().length() == 0) { + return; + } + + sb.append(extension.getName()); + + for (Extension.Parameter p : extension.getParameters()) { + sb.append(';'); + sb.append(p.getName()); + if (p.getValue() != null) { + sb.append('='); + sb.append(p.getValue()); + } + } + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static boolean headerContainsToken(HttpServletRequest req, + String headerName, String target) { + Enumeration headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + if (target.equalsIgnoreCase(token.trim())) { + return true; + } + } + } + return false; + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static List getTokensFromHeader(HttpServletRequest req, + String headerName) { + List result = new ArrayList<>(); + Enumeration headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + result.add(token.trim()); + } + } + return result; + } + + + private static String getWebSocketAccept(String key) { + byte[] digest = ConcurrentMessageDigest.digestSHA1( + key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT); + return Base64.encodeBase64String(digest); + } +} diff --git a/src/java/nginx/unit/websocket/server/UriTemplate.java b/src/java/nginx/unit/websocket/server/UriTemplate.java new file mode 100644 index 00000000..7877fac9 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/UriTemplate.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.DeploymentException; + +import org.apache.tomcat.util.res.StringManager; + +/** + * Extracts path parameters from URIs used to create web socket connections + * using the URI template defined for the associated Endpoint. + */ +public class UriTemplate { + + private static final StringManager sm = StringManager.getManager(UriTemplate.class); + + private final String normalized; + private final List segments = new ArrayList<>(); + private final boolean hasParameters; + + + public UriTemplate(String path) throws DeploymentException { + + if (path == null || path.length() ==0 || !path.startsWith("/")) { + throw new DeploymentException( + sm.getString("uriTemplate.invalidPath", path)); + } + + StringBuilder normalized = new StringBuilder(path.length()); + Set paramNames = new HashSet<>(); + + // Include empty segments. + String[] segments = path.split("/", -1); + int paramCount = 0; + int segmentCount = 0; + + for (int i = 0; i < segments.length; i++) { + String segment = segments[i]; + if (segment.length() == 0) { + if (i == 0 || (i == segments.length - 1 && paramCount == 0)) { + // Ignore the first empty segment as the path must always + // start with '/' + // Ending with a '/' is also OK for instances used for + // matches but not for parameterised templates. + continue; + } else { + // As per EG discussion, all other empty segments are + // invalid + throw new IllegalArgumentException(sm.getString( + "uriTemplate.emptySegment", path)); + } + } + normalized.append('/'); + int index = -1; + if (segment.startsWith("{") && segment.endsWith("}")) { + index = segmentCount; + segment = segment.substring(1, segment.length() - 1); + normalized.append('{'); + normalized.append(paramCount++); + normalized.append('}'); + if (!paramNames.add(segment)) { + throw new IllegalArgumentException(sm.getString( + "uriTemplate.duplicateParameter", segment)); + } + } else { + if (segment.contains("{") || segment.contains("}")) { + throw new IllegalArgumentException(sm.getString( + "uriTemplate.invalidSegment", segment, path)); + } + normalized.append(segment); + } + this.segments.add(new Segment(index, segment)); + segmentCount++; + } + + this.normalized = normalized.toString(); + this.hasParameters = paramCount > 0; + } + + + public Map match(UriTemplate candidate) { + + Map result = new HashMap<>(); + + // Should not happen but for safety + if (candidate.getSegmentCount() != getSegmentCount()) { + return null; + } + + Iterator candidateSegments = + candidate.getSegments().iterator(); + Iterator targetSegments = segments.iterator(); + + while (candidateSegments.hasNext()) { + Segment candidateSegment = candidateSegments.next(); + Segment targetSegment = targetSegments.next(); + + if (targetSegment.getParameterIndex() == -1) { + // Not a parameter - values must match + if (!targetSegment.getValue().equals( + candidateSegment.getValue())) { + // Not a match. Stop here + return null; + } + } else { + // Parameter + result.put(targetSegment.getValue(), + candidateSegment.getValue()); + } + } + + return result; + } + + + public boolean hasParameters() { + return hasParameters; + } + + + public int getSegmentCount() { + return segments.size(); + } + + + public String getNormalizedPath() { + return normalized; + } + + + private List getSegments() { + return segments; + } + + + private static class Segment { + private final int parameterIndex; + private final String value; + + public Segment(int parameterIndex, String value) { + this.parameterIndex = parameterIndex; + this.value = value; + } + + + public int getParameterIndex() { + return parameterIndex; + } + + + public String getValue() { + return value; + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsContextListener.java b/src/java/nginx/unit/websocket/server/WsContextListener.java new file mode 100644 index 00000000..07137856 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsContextListener.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import javax.servlet.ServletContext; +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; + +/** + * In normal usage, this {@link ServletContextListener} does not need to be + * explicitly configured as the {@link WsSci} performs all the necessary + * bootstrap and installs this listener in the {@link ServletContext}. If the + * {@link WsSci} is disabled, this listener must be added manually to every + * {@link ServletContext} that uses WebSocket to bootstrap the + * {@link WsServerContainer} correctly. + */ +public class WsContextListener implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + ServletContext sc = sce.getServletContext(); + // Don't trigger WebSocket initialization if a WebSocket Server + // Container is already present + if (sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE) == null) { + WsSci.init(sce.getServletContext(), false); + } + } + + @Override + public void contextDestroyed(ServletContextEvent sce) { + ServletContext sc = sce.getServletContext(); + Object obj = sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + if (obj instanceof WsServerContainer) { + ((WsServerContainer) obj).destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsFilter.java b/src/java/nginx/unit/websocket/server/WsFilter.java new file mode 100644 index 00000000..abea71fc --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsFilter.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.GenericFilter; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Handles the initial HTTP connection for WebSocket connections. + */ +public class WsFilter extends GenericFilter { + + private static final long serialVersionUID = 1L; + + private transient WsServerContainer sc; + + + @Override + public void init() throws ServletException { + sc = (WsServerContainer) getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + } + + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + + // This filter only needs to handle WebSocket upgrade requests + if (!sc.areEndpointsRegistered() || + !UpgradeUtil.isWebSocketUpgradeRequest(request, response)) { + chain.doFilter(request, response); + return; + } + + // HTTP request with an upgrade header for WebSocket present + HttpServletRequest req = (HttpServletRequest) request; + HttpServletResponse resp = (HttpServletResponse) response; + + // Check to see if this WebSocket implementation has a matching mapping + String path; + String pathInfo = req.getPathInfo(); + if (pathInfo == null) { + path = req.getServletPath(); + } else { + path = req.getServletPath() + pathInfo; + } + WsMappingResult mappingResult = sc.findMapping(path); + + if (mappingResult == null) { + // No endpoint registered for the requested path. Let the + // application handle it (it might redirect or forward for example) + chain.doFilter(request, response); + return; + } + + UpgradeUtil.doUpgrade(sc, req, resp, mappingResult.getConfig(), + mappingResult.getPathParams()); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java new file mode 100644 index 00000000..fa774302 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.net.URI; +import java.net.URISyntaxException; +import java.security.Principal; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.servlet.http.HttpServletRequest; +import javax.websocket.server.HandshakeRequest; + +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; +import org.apache.tomcat.util.res.StringManager; + +/** + * Represents the request that this session was opened under. + */ +public class WsHandshakeRequest implements HandshakeRequest { + + private static final StringManager sm = StringManager.getManager(WsHandshakeRequest.class); + + private final URI requestUri; + private final Map> parameterMap; + private final String queryString; + private final Principal userPrincipal; + private final Map> headers; + private final Object httpSession; + + private volatile HttpServletRequest request; + + + public WsHandshakeRequest(HttpServletRequest request, Map pathParams) { + + this.request = request; + + queryString = request.getQueryString(); + userPrincipal = request.getUserPrincipal(); + httpSession = request.getSession(false); + requestUri = buildRequestUri(request); + + // ParameterMap + Map originalParameters = request.getParameterMap(); + Map> newParameters = + new HashMap<>(originalParameters.size()); + for (Entry entry : originalParameters.entrySet()) { + newParameters.put(entry.getKey(), + Collections.unmodifiableList( + Arrays.asList(entry.getValue()))); + } + for (Entry entry : pathParams.entrySet()) { + newParameters.put(entry.getKey(), + Collections.unmodifiableList( + Collections.singletonList(entry.getValue()))); + } + parameterMap = Collections.unmodifiableMap(newParameters); + + // Headers + Map> newHeaders = new CaseInsensitiveKeyMap<>(); + + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + + newHeaders.put(headerName, Collections.unmodifiableList( + Collections.list(request.getHeaders(headerName)))); + } + + headers = Collections.unmodifiableMap(newHeaders); + } + + @Override + public URI getRequestURI() { + return requestUri; + } + + @Override + public Map> getParameterMap() { + return parameterMap; + } + + @Override + public String getQueryString() { + return queryString; + } + + @Override + public Principal getUserPrincipal() { + return userPrincipal; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public boolean isUserInRole(String role) { + if (request == null) { + throw new IllegalStateException(); + } + + return request.isUserInRole(role); + } + + @Override + public Object getHttpSession() { + return httpSession; + } + + /** + * Called when the HandshakeRequest is no longer required. Since an instance + * of this class retains a reference to the current HttpServletRequest that + * reference needs to be cleared as the HttpServletRequest may be reused. + * + * There is no reason for instances of this class to be accessed once the + * handshake has been completed. + */ + void finished() { + request = null; + } + + + /* + * See RequestUtil.getRequestURL() + */ + private static URI buildRequestUri(HttpServletRequest req) { + + StringBuffer uri = new StringBuffer(); + String scheme = req.getScheme(); + int port = req.getServerPort(); + if (port < 0) { + // Work around java.net.URL bug + port = 80; + } + + if ("http".equals(scheme)) { + uri.append("ws"); + } else if ("https".equals(scheme)) { + uri.append("wss"); + } else { + // Should never happen + throw new IllegalArgumentException( + sm.getString("wsHandshakeRequest.unknownScheme", scheme)); + } + + uri.append("://"); + uri.append(req.getServerName()); + + if ((scheme.equals("http") && (port != 80)) + || (scheme.equals("https") && (port != 443))) { + uri.append(':'); + uri.append(port); + } + + uri.append(req.getRequestURI()); + + if (req.getQueryString() != null) { + uri.append("?"); + uri.append(req.getQueryString()); + } + + try { + return new URI(uri.toString()); + } catch (URISyntaxException e) { + // Should never happen + throw new IllegalArgumentException( + sm.getString("wsHandshakeRequest.invalidUri", uri.toString()), e); + } + } + + public Object getAttribute(String name) + { + return request != null ? request.getAttribute(name) : null; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java new file mode 100644 index 00000000..cc39ab73 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpUpgradeHandler; +import javax.servlet.http.WebConnection; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.WsIOException; +import nginx.unit.websocket.WsSession; + +import nginx.unit.Request; + +/** + * Servlet 3.1 HTTP upgrade handler for WebSocket connections. + */ +public class WsHttpUpgradeHandler implements HttpUpgradeHandler { + + private final Log log = LogFactory.getLog(WsHttpUpgradeHandler.class); // must not be static + private static final StringManager sm = StringManager.getManager(WsHttpUpgradeHandler.class); + + private final ClassLoader applicationClassLoader; + + private Endpoint ep; + private EndpointConfig endpointConfig; + private WsServerContainer webSocketContainer; + private WsHandshakeRequest handshakeRequest; + private List negotiatedExtensions; + private String subProtocol; + private Transformation transformation; + private Map pathParameters; + private boolean secure; + private WebConnection connection; + private WsRemoteEndpointImplServer wsRemoteEndpointServer; + private WsSession wsSession; + + + public WsHttpUpgradeHandler() { + applicationClassLoader = Thread.currentThread().getContextClassLoader(); + } + + public void preInit(Endpoint ep, EndpointConfig endpointConfig, + WsServerContainer wsc, WsHandshakeRequest handshakeRequest, + List negotiatedExtensionsPhase2, String subProtocol, + Transformation transformation, Map pathParameters, + boolean secure) { + this.ep = ep; + this.endpointConfig = endpointConfig; + this.webSocketContainer = wsc; + this.handshakeRequest = handshakeRequest; + this.negotiatedExtensions = negotiatedExtensionsPhase2; + this.subProtocol = subProtocol; + this.transformation = transformation; + this.pathParameters = pathParameters; + this.secure = secure; + } + + + @Override + public void init(WebConnection connection) { + if (ep == null) { + throw new IllegalStateException( + sm.getString("wsHttpUpgradeHandler.noPreInit")); + } + + String httpSessionId = null; + Object session = handshakeRequest.getHttpSession(); + if (session != null ) { + httpSessionId = ((HttpSession) session).getId(); + } + + nginx.unit.Context.trace("UpgradeHandler.init(" + connection + ")"); + +/* + // Need to call onOpen using the web application's class loader + // Create the frame using the application's class loader so it can pick + // up application specific config from the ServerContainerImpl + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); +*/ + try { + Request r = (Request) handshakeRequest.getAttribute(Request.BARE); + + wsRemoteEndpointServer = new WsRemoteEndpointImplServer(webSocketContainer); + wsSession = new WsSession(ep, wsRemoteEndpointServer, + webSocketContainer, handshakeRequest.getRequestURI(), + handshakeRequest.getParameterMap(), + handshakeRequest.getQueryString(), + handshakeRequest.getUserPrincipal(), httpSessionId, + negotiatedExtensions, subProtocol, pathParameters, secure, + endpointConfig, r); + + ep.onOpen(wsSession, endpointConfig); + webSocketContainer.registerSession(ep, wsSession); + } catch (DeploymentException e) { + throw new IllegalArgumentException(e); +/* + } finally { + t.setContextClassLoader(cl); +*/ + } + } + + + + @Override + public void destroy() { + if (connection != null) { + try { + connection.close(); + } catch (Exception e) { + log.error(sm.getString("wsHttpUpgradeHandler.destroyFailed"), e); + } + } + } + + + private void onError(Throwable throwable) { + // Need to call onError using the web application's class loader + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + ep.onError(wsSession, throwable); + } finally { + t.setContextClassLoader(cl); + } + } + + + private void close(CloseReason cr) { + /* + * Any call to this method is a result of a problem reading from the + * client. At this point that state of the connection is unknown. + * Attempt to send a close frame to the client and then close the socket + * immediately. There is no point in waiting for a close frame from the + * client because there is no guarantee that we can recover from + * whatever messed up state the client put the connection into. + */ + wsSession.onClose(cr); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsMappingResult.java b/src/java/nginx/unit/websocket/server/WsMappingResult.java new file mode 100644 index 00000000..a7a4c022 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsMappingResult.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.Map; + +import javax.websocket.server.ServerEndpointConfig; + +class WsMappingResult { + + private final ServerEndpointConfig config; + private final Map pathParams; + + + WsMappingResult(ServerEndpointConfig config, + Map pathParams) { + this.config = config; + this.pathParams = pathParams; + } + + + ServerEndpointConfig getConfig() { + return config; + } + + + Map getPathParams() { + return pathParams; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java new file mode 100644 index 00000000..2be050cb --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.Extension; +import javax.websocket.server.ServerEndpointConfig; + +/** + * Wraps the provided {@link ServerEndpointConfig} and provides a per session + * view - the difference being that the map returned by {@link + * #getUserProperties()} is unique to this instance rather than shared with the + * wrapped {@link ServerEndpointConfig}. + */ +class WsPerSessionServerEndpointConfig implements ServerEndpointConfig { + + private final ServerEndpointConfig perEndpointConfig; + private final Map perSessionUserProperties = + new ConcurrentHashMap<>(); + + WsPerSessionServerEndpointConfig(ServerEndpointConfig perEndpointConfig) { + this.perEndpointConfig = perEndpointConfig; + perSessionUserProperties.putAll(perEndpointConfig.getUserProperties()); + } + + @Override + public List> getEncoders() { + return perEndpointConfig.getEncoders(); + } + + @Override + public List> getDecoders() { + return perEndpointConfig.getDecoders(); + } + + @Override + public Map getUserProperties() { + return perSessionUserProperties; + } + + @Override + public Class getEndpointClass() { + return perEndpointConfig.getEndpointClass(); + } + + @Override + public String getPath() { + return perEndpointConfig.getPath(); + } + + @Override + public List getSubprotocols() { + return perEndpointConfig.getSubprotocols(); + } + + @Override + public List getExtensions() { + return perEndpointConfig.getExtensions(); + } + + @Override + public Configurator getConfigurator() { + return perEndpointConfig.getConfigurator(); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java new file mode 100644 index 00000000..6d10a3be --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.EOFException; +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.nio.channels.InterruptedByTimeoutException; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.WsRemoteEndpointImplBase; + +/** + * This is the server side {@link javax.websocket.RemoteEndpoint} implementation + * - i.e. what the server uses to send data to the client. + */ +public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase { + + private static final StringManager sm = + StringManager.getManager(WsRemoteEndpointImplServer.class); + private final Log log = LogFactory.getLog(WsRemoteEndpointImplServer.class); // must not be static + + private volatile SendHandler handler = null; + private volatile ByteBuffer[] buffers = null; + + private volatile long timeoutExpiry = -1; + private volatile boolean close; + + public WsRemoteEndpointImplServer( + WsServerContainer serverContainer) { + } + + + @Override + protected final boolean isMasked() { + return false; + } + + @Override + protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... buffers) { + } + + @Override + protected void doClose() { + if (handler != null) { + // close() can be triggered by a wide range of scenarios. It is far + // simpler just to always use a dispatch than it is to try and track + // whether or not this method was called by the same thread that + // triggered the write + clearHandler(new EOFException(), true); + } + } + + + protected long getTimeoutExpiry() { + return timeoutExpiry; + } + + + /* + * Currently this is only called from the background thread so we could just + * call clearHandler() with useDispatch == false but the method parameter + * was added in case other callers started to use this method to make sure + * that those callers think through what the correct value of useDispatch is + * for them. + */ + protected void onTimeout(boolean useDispatch) { + if (handler != null) { + clearHandler(new SocketTimeoutException(), useDispatch); + } + close(); + } + + + @Override + protected void setTransformation(Transformation transformation) { + // Overridden purely so it is visible to other classes in this package + super.setTransformation(transformation); + } + + + /** + * + * @param t The throwable associated with any error that + * occurred + * @param useDispatch Should {@link SendHandler#onResult(SendResult)} be + * called from a new thread, keeping in mind the + * requirements of + * {@link javax.websocket.RemoteEndpoint.Async} + */ + private void clearHandler(Throwable t, boolean useDispatch) { + // Setting the result marks this (partial) message as + // complete which means the next one may be sent which + // could update the value of the handler. Therefore, keep a + // local copy before signalling the end of the (partial) + // message. + SendHandler sh = handler; + handler = null; + buffers = null; + if (sh != null) { + if (useDispatch) { + OnResultRunnable r = new OnResultRunnable(sh, t); + } else { + if (t == null) { + sh.onResult(new SendResult()); + } else { + sh.onResult(new SendResult(t)); + } + } + } + } + + + private static class OnResultRunnable implements Runnable { + + private final SendHandler sh; + private final Throwable t; + + private OnResultRunnable(SendHandler sh, Throwable t) { + this.sh = sh; + this.t = t; + } + + @Override + public void run() { + if (t == null) { + sh.onResult(new SendResult()); + } else { + sh.onResult(new SendResult(t)); + } + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsSci.java b/src/java/nginx/unit/websocket/server/WsSci.java new file mode 100644 index 00000000..cdecce27 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsSci.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.lang.reflect.Modifier; +import java.util.HashSet; +import java.util.Set; + +import javax.servlet.ServletContainerInitializer; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.annotation.HandlesTypes; +import javax.websocket.ContainerProvider; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.server.ServerApplicationConfig; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; + +/** + * Registers an interest in any class that is annotated with + * {@link ServerEndpoint} so that Endpoint can be published via the WebSocket + * server. + */ +@HandlesTypes({ServerEndpoint.class, ServerApplicationConfig.class, + Endpoint.class}) +public class WsSci implements ServletContainerInitializer { + + @Override + public void onStartup(Set> clazzes, ServletContext ctx) + throws ServletException { + + WsServerContainer sc = init(ctx, true); + + if (clazzes == null || clazzes.size() == 0) { + return; + } + + // Group the discovered classes by type + Set serverApplicationConfigs = new HashSet<>(); + Set> scannedEndpointClazzes = new HashSet<>(); + Set> scannedPojoEndpoints = new HashSet<>(); + + try { + // wsPackage is "javax.websocket." + String wsPackage = ContainerProvider.class.getName(); + wsPackage = wsPackage.substring(0, wsPackage.lastIndexOf('.') + 1); + for (Class clazz : clazzes) { + int modifiers = clazz.getModifiers(); + if (!Modifier.isPublic(modifiers) || + Modifier.isAbstract(modifiers)) { + // Non-public or abstract - skip it. + continue; + } + // Protect against scanning the WebSocket API JARs + if (clazz.getName().startsWith(wsPackage)) { + continue; + } + if (ServerApplicationConfig.class.isAssignableFrom(clazz)) { + serverApplicationConfigs.add( + (ServerApplicationConfig) clazz.getConstructor().newInstance()); + } + if (Endpoint.class.isAssignableFrom(clazz)) { + @SuppressWarnings("unchecked") + Class endpoint = + (Class) clazz; + scannedEndpointClazzes.add(endpoint); + } + if (clazz.isAnnotationPresent(ServerEndpoint.class)) { + scannedPojoEndpoints.add(clazz); + } + } + } catch (ReflectiveOperationException e) { + throw new ServletException(e); + } + + // Filter the results + Set filteredEndpointConfigs = new HashSet<>(); + Set> filteredPojoEndpoints = new HashSet<>(); + + if (serverApplicationConfigs.isEmpty()) { + filteredPojoEndpoints.addAll(scannedPojoEndpoints); + } else { + for (ServerApplicationConfig config : serverApplicationConfigs) { + Set configFilteredEndpoints = + config.getEndpointConfigs(scannedEndpointClazzes); + if (configFilteredEndpoints != null) { + filteredEndpointConfigs.addAll(configFilteredEndpoints); + } + Set> configFilteredPojos = + config.getAnnotatedEndpointClasses( + scannedPojoEndpoints); + if (configFilteredPojos != null) { + filteredPojoEndpoints.addAll(configFilteredPojos); + } + } + } + + try { + // Deploy endpoints + for (ServerEndpointConfig config : filteredEndpointConfigs) { + sc.addEndpoint(config); + } + // Deploy POJOs + for (Class clazz : filteredPojoEndpoints) { + sc.addEndpoint(clazz); + } + } catch (DeploymentException e) { + throw new ServletException(e); + } + } + + + static WsServerContainer init(ServletContext servletContext, + boolean initBySciMechanism) { + + WsServerContainer sc = new WsServerContainer(servletContext); + + servletContext.setAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE, sc); + + servletContext.addListener(new WsSessionListener(sc)); + // Can't register the ContextListener again if the ContextListener is + // calling this method + if (initBySciMechanism) { + servletContext.addListener(new WsContextListener()); + } + + return sc; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsServerContainer.java b/src/java/nginx/unit/websocket/server/WsServerContainer.java new file mode 100644 index 00000000..069fc54f --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsServerContainer.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.EnumSet; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; +import java.util.concurrent.ConcurrentHashMap; + +import javax.servlet.DispatcherType; +import javax.servlet.FilterRegistration; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Encoder; +import javax.websocket.Endpoint; +import javax.websocket.server.ServerContainer; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; +import javax.websocket.server.ServerEndpointConfig.Configurator; + +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.WsSession; +import nginx.unit.websocket.WsWebSocketContainer; +import nginx.unit.websocket.pojo.PojoMethodMapping; + +/** + * Provides a per class loader (i.e. per web application) instance of a + * ServerContainer. Web application wide defaults may be configured by setting + * the following servlet context initialisation parameters to the desired + * values. + *

    + *
  • {@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}
  • + *
  • {@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}
  • + *
+ */ +public class WsServerContainer extends WsWebSocketContainer + implements ServerContainer { + + private static final StringManager sm = StringManager.getManager(WsServerContainer.class); + + private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED = + new CloseReason(CloseCodes.VIOLATED_POLICY, + "This connection was established under an authenticated " + + "HTTP session that has ended."); + + private final ServletContext servletContext; + private final Map configExactMatchMap = + new ConcurrentHashMap<>(); + private final Map> configTemplateMatchMap = + new ConcurrentHashMap<>(); + private volatile boolean enforceNoAddAfterHandshake = + nginx.unit.websocket.Constants.STRICT_SPEC_COMPLIANCE; + private volatile boolean addAllowed = true; + private final Map> authenticatedSessions = new ConcurrentHashMap<>(); + private volatile boolean endpointsRegistered = false; + + WsServerContainer(ServletContext servletContext) { + + this.servletContext = servletContext; + setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName())); + + // Configure servlet context wide defaults + String value = servletContext.getInitParameter( + Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); + if (value != null) { + setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value)); + } + + value = servletContext.getInitParameter( + Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); + if (value != null) { + setDefaultMaxTextMessageBufferSize(Integer.parseInt(value)); + } + + value = servletContext.getInitParameter( + Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM); + if (value != null) { + setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value)); + } + + FilterRegistration.Dynamic fr = servletContext.addFilter( + "Tomcat WebSocket (JSR356) Filter", new WsFilter()); + fr.setAsyncSupported(true); + + EnumSet types = EnumSet.of(DispatcherType.REQUEST, + DispatcherType.FORWARD); + + fr.addMappingForUrlPatterns(types, true, "/*"); + } + + + /** + * Published the provided endpoint implementation at the specified path with + * the specified configuration. {@link #WsServerContainer(ServletContext)} + * must be called before calling this method. + * + * @param sec The configuration to use when creating endpoint instances + * @throws DeploymentException if the endpoint cannot be published as + * requested + */ + @Override + public void addEndpoint(ServerEndpointConfig sec) + throws DeploymentException { + + if (enforceNoAddAfterHandshake && !addAllowed) { + throw new DeploymentException( + sm.getString("serverContainer.addNotAllowed")); + } + + if (servletContext == null) { + throw new DeploymentException( + sm.getString("serverContainer.servletContextMissing")); + } + String path = sec.getPath(); + + // Add method mapping to user properties + PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(), + sec.getDecoders(), path); + if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null + || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) { + sec.getUserProperties().put(nginx.unit.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY, + methodMapping); + } + + UriTemplate uriTemplate = new UriTemplate(path); + if (uriTemplate.hasParameters()) { + Integer key = Integer.valueOf(uriTemplate.getSegmentCount()); + SortedSet templateMatches = + configTemplateMatchMap.get(key); + if (templateMatches == null) { + // Ensure that if concurrent threads execute this block they + // both end up using the same TreeSet instance + templateMatches = new TreeSet<>( + TemplatePathMatchComparator.getInstance()); + configTemplateMatchMap.putIfAbsent(key, templateMatches); + templateMatches = configTemplateMatchMap.get(key); + } + if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) { + // Duplicate uriTemplate; + throw new DeploymentException( + sm.getString("serverContainer.duplicatePaths", path, + sec.getEndpointClass(), + sec.getEndpointClass())); + } + } else { + // Exact match + ServerEndpointConfig old = configExactMatchMap.put(path, sec); + if (old != null) { + // Duplicate path mappings + throw new DeploymentException( + sm.getString("serverContainer.duplicatePaths", path, + old.getEndpointClass(), + sec.getEndpointClass())); + } + } + + endpointsRegistered = true; + } + + + /** + * Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)} + * for publishing plain old java objects (POJOs) that have been annotated as + * WebSocket endpoints. + * + * @param pojo The annotated POJO + */ + @Override + public void addEndpoint(Class pojo) throws DeploymentException { + + ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class); + if (annotation == null) { + throw new DeploymentException( + sm.getString("serverContainer.missingAnnotation", + pojo.getName())); + } + String path = annotation.value(); + + // Validate encoders + validateEncoders(annotation.encoders()); + + // ServerEndpointConfig + ServerEndpointConfig sec; + Class configuratorClazz = + annotation.configurator(); + Configurator configurator = null; + if (!configuratorClazz.equals(Configurator.class)) { + try { + configurator = annotation.configurator().getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "serverContainer.configuratorFail", + annotation.configurator().getName(), + pojo.getClass().getName()), e); + } + } + if (configurator == null) { + configurator = new nginx.unit.websocket.server.DefaultServerEndpointConfigurator(); + } + sec = ServerEndpointConfig.Builder.create(pojo, path). + decoders(Arrays.asList(annotation.decoders())). + encoders(Arrays.asList(annotation.encoders())). + subprotocols(Arrays.asList(annotation.subprotocols())). + configurator(configurator). + build(); + + addEndpoint(sec); + } + + + boolean areEndpointsRegistered() { + return endpointsRegistered; + } + + + /** + * Until the WebSocket specification provides such a mechanism, this Tomcat + * proprietary method is provided to enable applications to programmatically + * determine whether or not to upgrade an individual request to WebSocket. + *

+ * Note: This method is not used by Tomcat but is used directly by + * third-party code and must not be removed. + * + * @param request The request object to be upgraded + * @param response The response object to be populated with the result of + * the upgrade + * @param sec The server endpoint to use to process the upgrade request + * @param pathParams The path parameters associated with the upgrade request + * + * @throws ServletException If a configuration error prevents the upgrade + * from taking place + * @throws IOException If an I/O error occurs during the upgrade process + */ + public void doUpgrade(HttpServletRequest request, + HttpServletResponse response, ServerEndpointConfig sec, + Map pathParams) + throws ServletException, IOException { + UpgradeUtil.doUpgrade(this, request, response, sec, pathParams); + } + + + public WsMappingResult findMapping(String path) { + + // Prevent registering additional endpoints once the first attempt has + // been made to use one + if (addAllowed) { + addAllowed = false; + } + + // Check an exact match. Simple case as there are no templates. + ServerEndpointConfig sec = configExactMatchMap.get(path); + if (sec != null) { + return new WsMappingResult(sec, Collections.emptyMap()); + } + + // No exact match. Need to look for template matches. + UriTemplate pathUriTemplate = null; + try { + pathUriTemplate = new UriTemplate(path); + } catch (DeploymentException e) { + // Path is not valid so can't be matched to a WebSocketEndpoint + return null; + } + + // Number of segments has to match + Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount()); + SortedSet templateMatches = + configTemplateMatchMap.get(key); + + if (templateMatches == null) { + // No templates with an equal number of segments so there will be + // no matches + return null; + } + + // List is in alphabetical order of normalised templates. + // Correct match is the first one that matches. + Map pathParams = null; + for (TemplatePathMatch templateMatch : templateMatches) { + pathParams = templateMatch.getUriTemplate().match(pathUriTemplate); + if (pathParams != null) { + sec = templateMatch.getConfig(); + break; + } + } + + if (sec == null) { + // No match + return null; + } + + return new WsMappingResult(sec, pathParams); + } + + + + public boolean isEnforceNoAddAfterHandshake() { + return enforceNoAddAfterHandshake; + } + + + public void setEnforceNoAddAfterHandshake( + boolean enforceNoAddAfterHandshake) { + this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake; + } + + + /** + * {@inheritDoc} + * + * Overridden to make it visible to other classes in this package. + */ + @Override + protected void registerSession(Endpoint endpoint, WsSession wsSession) { + super.registerSession(endpoint, wsSession); + if (wsSession.isOpen() && + wsSession.getUserPrincipal() != null && + wsSession.getHttpSessionId() != null) { + registerAuthenticatedSession(wsSession, + wsSession.getHttpSessionId()); + } + } + + + /** + * {@inheritDoc} + * + * Overridden to make it visible to other classes in this package. + */ + @Override + protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + if (wsSession.getUserPrincipal() != null && + wsSession.getHttpSessionId() != null) { + unregisterAuthenticatedSession(wsSession, + wsSession.getHttpSessionId()); + } + super.unregisterSession(endpoint, wsSession); + } + + + private void registerAuthenticatedSession(WsSession wsSession, + String httpSessionId) { + Set wsSessions = authenticatedSessions.get(httpSessionId); + if (wsSessions == null) { + wsSessions = Collections.newSetFromMap( + new ConcurrentHashMap()); + authenticatedSessions.putIfAbsent(httpSessionId, wsSessions); + wsSessions = authenticatedSessions.get(httpSessionId); + } + wsSessions.add(wsSession); + } + + + private void unregisterAuthenticatedSession(WsSession wsSession, + String httpSessionId) { + Set wsSessions = authenticatedSessions.get(httpSessionId); + // wsSessions will be null if the HTTP session has ended + if (wsSessions != null) { + wsSessions.remove(wsSession); + } + } + + + public void closeAuthenticatedSession(String httpSessionId) { + Set wsSessions = authenticatedSessions.remove(httpSessionId); + + if (wsSessions != null && !wsSessions.isEmpty()) { + for (WsSession wsSession : wsSessions) { + try { + wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED); + } catch (IOException e) { + // Any IOExceptions during close will have been caught and the + // onError method called. + } + } + } + } + + + private static void validateEncoders(Class[] encoders) + throws DeploymentException { + + for (Class encoder : encoders) { + // Need to instantiate decoder to ensure it is valid and that + // deployment can be failed if it is not + @SuppressWarnings("unused") + Encoder instance; + try { + encoder.getConstructor().newInstance(); + } catch(ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "serverContainer.encoderFail", encoder.getName()), e); + } + } + } + + + private static class TemplatePathMatch { + private final ServerEndpointConfig config; + private final UriTemplate uriTemplate; + + public TemplatePathMatch(ServerEndpointConfig config, + UriTemplate uriTemplate) { + this.config = config; + this.uriTemplate = uriTemplate; + } + + + public ServerEndpointConfig getConfig() { + return config; + } + + + public UriTemplate getUriTemplate() { + return uriTemplate; + } + } + + + /** + * This Comparator implementation is thread-safe so only create a single + * instance. + */ + private static class TemplatePathMatchComparator + implements Comparator { + + private static final TemplatePathMatchComparator INSTANCE = + new TemplatePathMatchComparator(); + + public static TemplatePathMatchComparator getInstance() { + return INSTANCE; + } + + private TemplatePathMatchComparator() { + // Hide default constructor + } + + @Override + public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) { + return tpm1.getUriTemplate().getNormalizedPath().compareTo( + tpm2.getUriTemplate().getNormalizedPath()); + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsSessionListener.java b/src/java/nginx/unit/websocket/server/WsSessionListener.java new file mode 100644 index 00000000..fc2bc9c5 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsSessionListener.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import javax.servlet.http.HttpSessionEvent; +import javax.servlet.http.HttpSessionListener; + +public class WsSessionListener implements HttpSessionListener{ + + private final WsServerContainer wsServerContainer; + + + public WsSessionListener(WsServerContainer wsServerContainer) { + this.wsServerContainer = wsServerContainer; + } + + + @Override + public void sessionDestroyed(HttpSessionEvent se) { + wsServerContainer.closeAuthenticatedSession(se.getSession().getId()); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsWriteTimeout.java b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java new file mode 100644 index 00000000..2dfc4ab2 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.Comparator; +import java.util.Set; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.atomic.AtomicInteger; + +import nginx.unit.websocket.BackgroundProcess; +import nginx.unit.websocket.BackgroundProcessManager; + +/** + * Provides timeouts for asynchronous web socket writes. On the server side we + * only have access to {@link javax.servlet.ServletOutputStream} and + * {@link javax.servlet.ServletInputStream} so there is no way to set a timeout + * for writes to the client. + */ +public class WsWriteTimeout implements BackgroundProcess { + + private final Set endpoints = + new ConcurrentSkipListSet<>(new EndpointComparator()); + private final AtomicInteger count = new AtomicInteger(0); + private int backgroundProcessCount = 0; + private volatile int processPeriod = 1; + + @Override + public void backgroundProcess() { + // This method gets called once a second. + backgroundProcessCount ++; + + if (backgroundProcessCount >= processPeriod) { + backgroundProcessCount = 0; + + long now = System.currentTimeMillis(); + for (WsRemoteEndpointImplServer endpoint : endpoints) { + if (endpoint.getTimeoutExpiry() < now) { + // Background thread, not the thread that triggered the + // write so no need to use a dispatch + endpoint.onTimeout(false); + } else { + // Endpoints are ordered by timeout expiry so if this point + // is reached there is no need to check the remaining + // endpoints + break; + } + } + } + } + + + @Override + public void setProcessPeriod(int period) { + this.processPeriod = period; + } + + + /** + * {@inheritDoc} + * + * The default value is 1 which means asynchronous write timeouts are + * processed every 1 second. + */ + @Override + public int getProcessPeriod() { + return processPeriod; + } + + + public void register(WsRemoteEndpointImplServer endpoint) { + boolean result = endpoints.add(endpoint); + if (result) { + int newCount = count.incrementAndGet(); + if (newCount == 1) { + BackgroundProcessManager.getInstance().register(this); + } + } + } + + + public void unregister(WsRemoteEndpointImplServer endpoint) { + boolean result = endpoints.remove(endpoint); + if (result) { + int newCount = count.decrementAndGet(); + if (newCount == 0) { + BackgroundProcessManager.getInstance().unregister(this); + } + } + } + + + /** + * Note: this comparator imposes orderings that are inconsistent with equals + */ + private static class EndpointComparator implements + Comparator { + + @Override + public int compare(WsRemoteEndpointImplServer o1, + WsRemoteEndpointImplServer o2) { + + long t1 = o1.getTimeoutExpiry(); + long t2 = o2.getTimeoutExpiry(); + + if (t1 < t2) { + return -1; + } else if (t1 == t2) { + return 0; + } else { + return 1; + } + } + } +} diff --git a/src/java/nginx/unit/websocket/server/package-info.java b/src/java/nginx/unit/websocket/server/package-info.java new file mode 100644 index 00000000..87bc85a3 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Server-side specific implementation classes. These are in a separate package + * to make packaging a pure client JAR simpler. + */ +package nginx.unit.websocket.server; diff --git a/src/java/nxt_jni_Request.c b/src/java/nxt_jni_Request.c index 733290dd..2e9dce67 100644 --- a/src/java/nxt_jni_Request.c +++ b/src/java/nxt_jni_Request.c @@ -58,16 +58,28 @@ static jint JNICALL nxt_java_Request_getServerPort(JNIEnv *env, jclass cls, jlong req_ptr); static jboolean JNICALL nxt_java_Request_isSecure(JNIEnv *env, jclass cls, jlong req_ptr); +static void JNICALL nxt_java_Request_upgrade(JNIEnv *env, jclass cls, + jlong req_info_ptr); +static jboolean JNICALL nxt_java_Request_isUpgrade(JNIEnv *env, jclass cls, + jlong req_info_ptr); static void JNICALL nxt_java_Request_log(JNIEnv *env, jclass cls, jlong req_info_ptr, jstring msg, jint msg_len); static void JNICALL nxt_java_Request_trace(JNIEnv *env, jclass cls, jlong req_info_ptr, jstring msg, jint msg_len); static jobject JNICALL nxt_java_Request_getResponse(JNIEnv *env, jclass cls, jlong req_info_ptr); +static void JNICALL nxt_java_Request_sendWsFrameBuf(JNIEnv *env, jclass cls, + jlong req_info_ptr, jobject buf, jint pos, jint len, jbyte opCode, jboolean last); +static void JNICALL nxt_java_Request_sendWsFrameArr(JNIEnv *env, jclass cls, + jlong req_info_ptr, jarray arr, jint pos, jint len, jbyte opCode, jboolean last); +static void JNICALL nxt_java_Request_closeWs(JNIEnv *env, jclass cls, + jlong req_info_ptr); static jclass nxt_java_Request_class; static jmethodID nxt_java_Request_ctor; +static jmethodID nxt_java_Request_processWsFrame; +static jmethodID nxt_java_Request_closeWsSession; int @@ -91,6 +103,18 @@ nxt_java_initRequest(JNIEnv *env, jobject cl) return NXT_UNIT_ERROR; } + nxt_java_Request_processWsFrame = (*env)->GetMethodID(env, cls, "processWsFrame", "(Ljava/nio/ByteBuffer;BZ)V"); + if (nxt_java_Request_processWsFrame == NULL) { + (*env)->DeleteGlobalRef(env, cls); + return NXT_UNIT_ERROR; + } + + nxt_java_Request_closeWsSession = (*env)->GetMethodID(env, cls, "closeWsSession", "()V"); + if (nxt_java_Request_closeWsSession == NULL) { + (*env)->DeleteGlobalRef(env, cls); + return NXT_UNIT_ERROR; + } + JNINativeMethod request_methods[] = { { (char *) "getHeader", (char *) "(JLjava/lang/String;I)Ljava/lang/String;", @@ -172,6 +196,14 @@ nxt_java_initRequest(JNIEnv *env, jobject cl) (char *) "(J)Z", nxt_java_Request_isSecure }, + { (char *) "upgrade", + (char *) "(J)V", + nxt_java_Request_upgrade }, + + { (char *) "isUpgrade", + (char *) "(J)Z", + nxt_java_Request_isUpgrade }, + { (char *) "log", (char *) "(JLjava/lang/String;I)V", nxt_java_Request_log }, @@ -184,6 +216,18 @@ nxt_java_initRequest(JNIEnv *env, jobject cl) (char *) "(J)Lnginx/unit/Response;", nxt_java_Request_getResponse }, + { (char *) "sendWsFrame", + (char *) "(JLjava/nio/ByteBuffer;IIBZ)V", + nxt_java_Request_sendWsFrameBuf }, + + { (char *) "sendWsFrame", + (char *) "(J[BIIBZ)V", + nxt_java_Request_sendWsFrameArr }, + + { (char *) "closeWs", + (char *) "(J)V", + nxt_java_Request_closeWs }, + }; res = (*env)->RegisterNatives(env, nxt_java_Request_class, @@ -624,6 +668,32 @@ nxt_java_Request_isSecure(JNIEnv *env, jclass cls, jlong req_ptr) } +static void JNICALL +nxt_java_Request_upgrade(JNIEnv *env, jclass cls, jlong req_info_ptr) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + + if (!nxt_unit_response_is_init(req)) { + nxt_unit_response_init(req, 101, 0, 0); + } + + (void) nxt_unit_response_upgrade(req); +} + + +static jboolean JNICALL +nxt_java_Request_isUpgrade(JNIEnv *env, jclass cls, jlong req_info_ptr) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + + return nxt_unit_request_is_websocket_handshake(req); +} + + static void JNICALL nxt_java_Request_log(JNIEnv *env, jclass cls, jlong req_info_ptr, jstring msg, jint msg_len) @@ -677,3 +747,77 @@ nxt_java_Request_getResponse(JNIEnv *env, jclass cls, jlong req_info_ptr) return data->jresp; } + + +static void JNICALL +nxt_java_Request_sendWsFrameBuf(JNIEnv *env, jclass cls, + jlong req_info_ptr, jobject buf, jint pos, jint len, jbyte opCode, jboolean last) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + uint8_t *b = (*env)->GetDirectBufferAddress(env, buf); + + if (b != NULL) { + nxt_unit_websocket_send(req, opCode, last, b + pos, len); + + } else { + nxt_unit_req_debug(req, "sendWsFrameBuf: b == NULL"); + } +} + + +static void JNICALL +nxt_java_Request_sendWsFrameArr(JNIEnv *env, jclass cls, + jlong req_info_ptr, jarray arr, jint pos, jint len, jbyte opCode, jboolean last) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + uint8_t *b = (*env)->GetPrimitiveArrayCritical(env, arr, NULL); + + if (b != NULL) { + if (!nxt_unit_response_is_sent(req)) { + nxt_unit_response_send(req); + } + + nxt_unit_websocket_send(req, opCode, last, b + pos, len); + + (*env)->ReleasePrimitiveArrayCritical(env, arr, b, 0); + + } else { + nxt_unit_req_debug(req, "sendWsFrameArr: b == NULL"); + } +} + + +static void JNICALL +nxt_java_Request_closeWs(JNIEnv *env, jclass cls, jlong req_info_ptr) +{ + nxt_unit_request_info_t *req; + nxt_java_request_data_t *data; + + req = nxt_jlong2ptr(req_info_ptr); + + data = req->data; + + (*env)->DeleteGlobalRef(env, data->jresp); + (*env)->DeleteGlobalRef(env, data->jreq); + + nxt_unit_request_done(req, NXT_UNIT_OK); +} + + +void +nxt_java_Request_websocket(JNIEnv *env, jobject jreq, jobject jbuf, + uint8_t opcode, uint8_t fin) +{ + (*env)->CallVoidMethod(env, jreq, nxt_java_Request_processWsFrame, jbuf, opcode, fin); +} + + +void +nxt_java_Request_close(JNIEnv *env, jobject jreq) +{ + (*env)->CallVoidMethod(env, jreq, nxt_java_Request_closeWsSession); +} diff --git a/src/java/nxt_jni_Request.h b/src/java/nxt_jni_Request.h index 1c9c1428..9187d878 100644 --- a/src/java/nxt_jni_Request.h +++ b/src/java/nxt_jni_Request.h @@ -15,4 +15,9 @@ int nxt_java_initRequest(JNIEnv *env, jobject cl); jobject nxt_java_newRequest(JNIEnv *env, jobject ctx, nxt_unit_request_info_t *req); +void nxt_java_Request_websocket(JNIEnv *env, jobject jreq, jobject jbuf, + uint8_t opcode, uint8_t fin); + +void nxt_java_Request_close(JNIEnv *env, jobject jreq); + #endif /* _NXT_JAVA_REQUEST_H_INCLUDED_ */ diff --git a/src/nxt_java.c b/src/nxt_java.c index 3421d825..08e24595 100644 --- a/src/nxt_java.c +++ b/src/nxt_java.c @@ -13,6 +13,7 @@ #include #include #include +#include #include @@ -30,6 +31,8 @@ static nxt_int_t nxt_java_pre_init(nxt_task_t *task, nxt_common_app_conf_t *conf); static nxt_int_t nxt_java_init(nxt_task_t *task, nxt_common_app_conf_t *conf); static void nxt_java_request_handler(nxt_unit_request_info_t *req); +static void nxt_java_websocket_handler(nxt_unit_websocket_frame_t *ws); +static void nxt_java_close_handler(nxt_unit_request_info_t *req); static uint32_t compat[] = { NXT_VERNUM, NXT_DEBUG, @@ -347,6 +350,8 @@ nxt_java_init(nxt_task_t *task, nxt_common_app_conf_t *conf) nxt_unit_default_init(task, &java_init); java_init.callbacks.request_handler = nxt_java_request_handler; + java_init.callbacks.websocket_handler = nxt_java_websocket_handler; + java_init.callbacks.close_handler = nxt_java_close_handler; java_init.request_data_size = sizeof(nxt_java_request_data_t); java_init.data = &data; @@ -361,14 +366,14 @@ nxt_java_init(nxt_task_t *task, nxt_common_app_conf_t *conf) /* TODO report error */ } - nxt_unit_done(ctx); - nxt_java_stopContext(env, data.ctx); if ((*env)->ExceptionCheck(env)) { (*env)->ExceptionDescribe(env); } + nxt_unit_done(ctx); + (*jvm)->DestroyJavaVM(jvm); exit(0); @@ -454,8 +459,71 @@ nxt_java_request_handler(nxt_unit_request_info_t *req) data->buf = NULL; } + if (nxt_unit_response_is_websocket(req)) { + data->jreq = (*env)->NewGlobalRef(env, jreq); + data->jresp = (*env)->NewGlobalRef(env, jresp); + + } else { + nxt_unit_request_done(req, NXT_UNIT_OK); + } + (*env)->DeleteLocalRef(env, jresp); (*env)->DeleteLocalRef(env, jreq); +} + + +static void +nxt_java_websocket_handler(nxt_unit_websocket_frame_t *ws) +{ + void *b; + JNIEnv *env; + jobject jbuf; + nxt_java_data_t *java_data; + nxt_java_request_data_t *data; + + java_data = ws->req->unit->data; + env = java_data->env; + data = ws->req->data; + + b = malloc(ws->payload_len); + if (b != NULL) { + nxt_unit_websocket_read(ws, b, ws->payload_len); + + jbuf = (*env)->NewDirectByteBuffer(env, b, ws->payload_len); + if (jbuf != NULL) { + nxt_java_Request_websocket(env, data->jreq, jbuf, + ws->header->opcode, ws->header->fin); + + if ((*env)->ExceptionCheck(env)) { + (*env)->ExceptionDescribe(env); + (*env)->ExceptionClear(env); + } + + (*env)->DeleteLocalRef(env, jbuf); + } + + free(b); + } + + nxt_unit_websocket_done(ws); +} + + +static void +nxt_java_close_handler(nxt_unit_request_info_t *req) +{ + JNIEnv *env; + nxt_java_data_t *java_data; + nxt_java_request_data_t *data; + + java_data = req->unit->data; + env = java_data->env; + data = req->data; + + nxt_java_Request_close(env, data->jreq); + + (*env)->DeleteGlobalRef(env, data->jresp); + (*env)->DeleteGlobalRef(env, data->jreq); nxt_unit_request_done(req, NXT_UNIT_OK); } -- cgit From 2a6af1d21407cffbb687fffa8b6a1e7c8a273ddd Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Mon, 9 Sep 2019 17:39:24 +0300 Subject: Added "extern" to nxt_http_fields_hash_proto to avoid link issues. --- src/nxt_http_parse.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/nxt_http_parse.h b/src/nxt_http_parse.h index c5b11bf3..42189182 100644 --- a/src/nxt_http_parse.h +++ b/src/nxt_http_parse.h @@ -113,7 +113,7 @@ nxt_int_t nxt_http_fields_process(nxt_list_t *fields, nxt_lvlhsh_t *hash, void *ctx); -const nxt_lvlhsh_proto_t nxt_http_fields_hash_proto; +extern const nxt_lvlhsh_proto_t nxt_http_fields_hash_proto; nxt_inline nxt_int_t nxt_http_field_process(nxt_http_field_t *field, nxt_lvlhsh_t *hash, void *ctx) -- cgit From 64be8717bdc2f0f8f11cbb8d18a0f96d2c24c6d3 Mon Sep 17 00:00:00 2001 From: Valentin Bartenev Date: Mon, 16 Sep 2019 20:17:42 +0300 Subject: Configuration: added ability to access object members with slashes. Now URI encoding can be used to escape "/" in the request path: GET /config/listeners/unix:%2Fpath%2Fto%2Fsocket/ --- src/nxt_conf.c | 15 ++++++++++++++ src/nxt_controller.c | 2 ++ src/nxt_http_parse.c | 30 ++++++++++++++++++++++++---- src/nxt_http_parse.h | 3 +++ src/nxt_string.c | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/nxt_string.h | 2 ++ 6 files changed, 104 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/nxt_conf.c b/src/nxt_conf.c index 57870838..ff82e1d2 100644 --- a/src/nxt_conf.c +++ b/src/nxt_conf.c @@ -416,10 +416,13 @@ static void nxt_conf_path_next_token(nxt_conf_path_parse_t *parse, nxt_conf_value_t * nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) { + u_char *end; nxt_str_t token; nxt_int_t index; nxt_conf_path_parse_t parse; + u_char buf[256]; + parse.start = path->start; parse.end = path->start + path->length; parse.last = 0; @@ -436,6 +439,18 @@ nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) return NULL; } + if (nxt_slow_path(token.length > 256)) { + return NULL; + } + + end = nxt_decode_uri(buf, token.start, token.length); + if (nxt_slow_path(end == NULL)) { + return NULL; + } + + token.length = end - buf; + token.start = buf; + switch (value->type) { case NXT_CONF_VALUE_OBJECT: diff --git a/src/nxt_controller.c b/src/nxt_controller.c index 49afbe46..86ba1246 100644 --- a/src/nxt_controller.c +++ b/src/nxt_controller.c @@ -478,6 +478,8 @@ nxt_controller_conn_init(nxt_task_t *task, void *obj, void *data) return; } + r->parser.encoded_slashes = 1; + b = nxt_buf_mem_alloc(c->mem_pool, 1024, 0); if (nxt_slow_path(b == NULL)) { nxt_controller_conn_free(task, c, NULL); diff --git a/src/nxt_http_parse.c b/src/nxt_http_parse.c index 8b4bf47c..945f4e93 100644 --- a/src/nxt_http_parse.c +++ b/src/nxt_http_parse.c @@ -1046,12 +1046,25 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) if (ch >= '0' && ch <= '9') { ch = (u_char) ((high << 4) + ch - '0'); - if (ch == '%' || ch == '#') { + if (ch == '%') { state = sw_normal; - *u++ = ch; + *u++ = '%'; + + if (rp->encoded_slashes) { + *u++ = '2'; + *u++ = '5'; + } + + continue; + } + + if (ch == '#') { + state = sw_normal; + *u++ = '#'; continue; + } - } else if (ch == '\0') { + if (ch == '\0') { return NXT_HTTP_PARSE_INVALID; } @@ -1067,8 +1080,17 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_normal; *u++ = ch; continue; + } + + if (ch == '/' && rp->encoded_slashes) { + state = sw_normal; + *u++ = '%'; + *u++ = '2'; + *u++ = p[-1]; /* 'f' or 'F' */ + continue; + } - } else if (ch == '+') { + if (ch == '+') { rp->plus_in_target = 1; } diff --git a/src/nxt_http_parse.h b/src/nxt_http_parse.h index 42189182..e42ec9a4 100644 --- a/src/nxt_http_parse.h +++ b/src/nxt_http_parse.h @@ -68,6 +68,9 @@ struct nxt_http_request_parse_s { unsigned space_in_target:1; /* target with "+" */ unsigned plus_in_target:1; + + /* Preserve encoded '/' (%2F) and '%' (%25). */ + unsigned encoded_slashes:1; }; diff --git a/src/nxt_string.c b/src/nxt_string.c index 7d8c1ce3..4d3b3954 100644 --- a/src/nxt_string.c +++ b/src/nxt_string.c @@ -451,3 +451,59 @@ nxt_strvers_match(u_char *version, u_char *prefix, size_t length) return 0; } + + +u_char * +nxt_decode_uri(u_char *dst, u_char *src, size_t length) +{ + u_char *end, ch; + uint8_t d0, d1; + + static const uint8_t hex[256] + nxt_aligned(32) = + { + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 16, 16, 16, 16, 16, + 16, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + }; + + nxt_prefetch(&hex['0']); + + end = src + length; + + while (src < end) { + ch = *src++; + + if (ch == '%') { + if (nxt_slow_path(end - src < 2)) { + return NULL; + } + + d0 = hex[*src++]; + d1 = hex[*src++]; + + if (nxt_slow_path((d0 | d1) >= 16)) { + return NULL; + } + + ch = (d0 << 4) + d1; + } + + *dst++ = ch; + } + + return dst; +} diff --git a/src/nxt_string.h b/src/nxt_string.h index 22a63a17..56d7316b 100644 --- a/src/nxt_string.h +++ b/src/nxt_string.h @@ -168,5 +168,7 @@ NXT_EXPORT nxt_int_t nxt_strverscmp(const u_char *s1, const u_char *s2); NXT_EXPORT nxt_bool_t nxt_strvers_match(u_char *version, u_char *prefix, size_t length); +NXT_EXPORT u_char *nxt_decode_uri(u_char *dst, u_char *src, size_t length); + #endif /* _NXT_STRING_H_INCLUDED_ */ -- cgit From 2fb7a1bfb90c4c8d484e4b0c4bd5ae2a594fd6e1 Mon Sep 17 00:00:00 2001 From: Valentin Bartenev Date: Mon, 16 Sep 2019 20:17:42 +0300 Subject: HTTP parser: removed unused "exten_start" and "args_start" fields. --- src/nxt_http_parse.c | 58 +++++++++++++++++++++--------------------- src/nxt_http_parse.h | 2 -- src/test/nxt_http_parse_test.c | 24 ++++++----------- 3 files changed, 37 insertions(+), 47 deletions(-) (limited to 'src') diff --git a/src/nxt_http_parse.c b/src/nxt_http_parse.c index 945f4e93..9d591b86 100644 --- a/src/nxt_http_parse.c +++ b/src/nxt_http_parse.c @@ -163,7 +163,7 @@ static nxt_int_t nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, u_char *end) { - u_char *p, ch, *after_slash; + u_char *p, ch, *after_slash, *exten, *args; nxt_int_t rc; nxt_http_ver_t ver; nxt_http_target_traps_e trap; @@ -255,6 +255,8 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, rp->target_start = p; after_slash = p + 1; + exten = NULL; + args = NULL; for ( ;; ) { p++; @@ -269,8 +271,7 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, } after_slash = p + 1; - - rp->exten_start = NULL; + exten = NULL; continue; case NXT_HTTP_TARGET_DOT: @@ -279,11 +280,11 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, goto rest_of_target; } - rp->exten_start = p + 1; + exten = p + 1; continue; case NXT_HTTP_TARGET_ARGS_MARK: - rp->args_start = p + 1; + args = p + 1; goto rest_of_target; case NXT_HTTP_TARGET_SPACE: @@ -437,20 +438,19 @@ space_after_target: rp->path.start = rp->target_start; - if (rp->args_start != NULL) { - rp->path.length = rp->args_start - rp->target_start - 1; + if (args != NULL) { + rp->path.length = args - rp->target_start - 1; - rp->args.start = rp->args_start; - rp->args.length = rp->target_end - rp->args_start; + rp->args.length = rp->target_end - args; + rp->args.start = args; } else { rp->path.length = rp->target_end - rp->target_start; } - if (rp->exten_start) { - rp->exten.length = rp->path.start + rp->path.length - - rp->exten_start; - rp->exten.start = rp->exten_start; + if (exten != NULL) { + rp->exten.length = (rp->path.start + rp->path.length) - exten; + rp->exten.start = exten; } return nxt_http_parse_field_name(rp, pos, end); @@ -835,7 +835,8 @@ static const uint8_t nxt_http_normal[32] nxt_aligned(32) = { static nxt_int_t nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) { - u_char *p, *u, c, ch, high; + u_char *p, *u, c, ch, high, *exten, *args; + enum { sw_normal = 0, sw_slash, @@ -852,7 +853,6 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) p = rp->target_start; u = nxt_mp_alloc(rp->mem_pool, rp->target_end - p + 1); - if (nxt_slow_path(u == NULL)) { return NXT_ERROR; } @@ -861,8 +861,8 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) rp->path.start = u; high = '\0'; - rp->exten_start = NULL; - rp->args_start = NULL; + exten = NULL; + args = NULL; while (p < rp->target_end) { @@ -881,7 +881,7 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) switch (ch) { case '/': - rp->exten_start = NULL; + exten = NULL; state = sw_slash; *u++ = ch; continue; @@ -890,12 +890,12 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; case '.': - rp->exten_start = u + 1; + exten = u + 1; *u++ = ch; continue; case '+': @@ -928,7 +928,7 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; @@ -965,7 +965,7 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; @@ -1009,7 +1009,7 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; @@ -1114,18 +1114,18 @@ args: } } - if (rp->args_start != NULL) { - rp->args.length = p - rp->args_start; - rp->args.start = rp->args_start; + if (args != NULL) { + rp->args.length = p - args; + rp->args.start = args; } done: rp->path.length = u - rp->path.start; - if (rp->exten_start) { - rp->exten.length = u - rp->exten_start; - rp->exten.start = rp->exten_start; + if (exten) { + rp->exten.length = u - exten; + rp->exten.start = exten; } return NXT_OK; diff --git a/src/nxt_http_parse.h b/src/nxt_http_parse.h index e42ec9a4..29e3688e 100644 --- a/src/nxt_http_parse.h +++ b/src/nxt_http_parse.h @@ -43,8 +43,6 @@ struct nxt_http_request_parse_s { u_char *target_start; u_char *target_end; - u_char *exten_start; - u_char *args_start; nxt_str_t path; nxt_str_t args; diff --git a/src/test/nxt_http_parse_test.c b/src/test/nxt_http_parse_test.c index 572e91b2..3e86f6ef 100644 --- a/src/test/nxt_http_parse_test.c +++ b/src/test/nxt_http_parse_test.c @@ -80,7 +80,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { { .request_line = { nxt_string("XXX-METHOD"), nxt_string("/d.ir/fi+le.ext?key=val"), - nxt_string("ext?key=val"), + nxt_string("ext"), nxt_string("key=val"), "HTTP/1.2", 0, 0, 0, 1 @@ -163,7 +163,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_string("GET"), nxt_string("/?#"), nxt_null_string, - nxt_string("#"), + nxt_string(""), "HTTP/1.0", 1, 0, 0, 0 }} @@ -729,31 +729,23 @@ nxt_http_parse_test_request_line(nxt_http_request_parse_t *rp, return NXT_ERROR; } - str.length = (rp->exten_start != NULL) ? rp->target_end - rp->exten_start - : 0; - str.start = rp->exten_start; - - if (str.start != test->exten.start - && !nxt_strstr_eq(&str, &test->exten)) + if (rp->exten.start != test->exten.start + && !nxt_strstr_eq(&rp->exten, &test->exten)) { nxt_log_alert(log, "http parse test case failed:\n" " - request:\n\"%V\"\n" " - exten: \"%V\" (expected: \"%V\")", - request, &str, &test->exten); + request, &rp->exten, &test->exten); return NXT_ERROR; } - str.length = (rp->args_start != NULL) ? rp->target_end - rp->args_start - : 0; - str.start = rp->args_start; - - if (str.start != test->args.start - && !nxt_strstr_eq(&str, &test->args)) + if (rp->args.start != test->args.start + && !nxt_strstr_eq(&rp->args, &test->args)) { nxt_log_alert(log, "http parse test case failed:\n" " - request:\n\"%V\"\n" " - args: \"%V\" (expected: \"%V\")", - request, &str, &test->args); + request, &rp->args, &test->args); return NXT_ERROR; } -- cgit From 56f4085b9d1d00ecdd359e3deb9c06a3b8ba0874 Mon Sep 17 00:00:00 2001 From: Valentin Bartenev Date: Mon, 16 Sep 2019 20:17:42 +0300 Subject: HTTP parser: removed unused "offset" field. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thanks to 洪志道 (Hong Zhi Dao). --- src/nxt_http_parse.h | 2 -- 1 file changed, 2 deletions(-) (limited to 'src') diff --git a/src/nxt_http_parse.h b/src/nxt_http_parse.h index 29e3688e..e66475a1 100644 --- a/src/nxt_http_parse.h +++ b/src/nxt_http_parse.h @@ -37,8 +37,6 @@ struct nxt_http_request_parse_s { nxt_int_t (*handler)(nxt_http_request_parse_t *rp, u_char **pos, u_char *end); - size_t offset; - nxt_str_t method; u_char *target_start; -- cgit From 3b77e402a903d9f7a6eeb32f7930d8979f8e0c9e Mon Sep 17 00:00:00 2001 From: Valentin Bartenev Date: Mon, 16 Sep 2019 20:17:42 +0300 Subject: HTTP parser: removed unused "plus_in_target" flag. --- src/nxt_http_parse.c | 23 +---------------------- src/nxt_http_parse.h | 10 ++++------ src/test/nxt_http_parse_test.c | 30 ++++++++++-------------------- 3 files changed, 15 insertions(+), 48 deletions(-) (limited to 'src') diff --git a/src/nxt_http_parse.c b/src/nxt_http_parse.c index 9d591b86..4d806c61 100644 --- a/src/nxt_http_parse.c +++ b/src/nxt_http_parse.c @@ -47,7 +47,6 @@ typedef enum { NXT_HTTP_TARGET_DOT, /* . */ NXT_HTTP_TARGET_ARGS_MARK, /* ? */ NXT_HTTP_TARGET_QUOTE_MARK, /* % */ - NXT_HTTP_TARGET_PLUS, /* + */ } nxt_http_target_traps_e; @@ -57,7 +56,7 @@ static const uint8_t nxt_http_target_chars[256] nxt_aligned(64) = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* \s ! " # $ % & ' ( ) * + , - . / */ - 1, 0, 0, 2, 0, 8, 0, 0, 0, 0, 0, 9, 0, 0, 6, 5, + 1, 0, 0, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, /* 0 1 2 3 4 5 6 7 8 9 : ; < = > ? */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, @@ -295,10 +294,6 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, rp->quoted_target = 1; goto rest_of_target; - case NXT_HTTP_TARGET_PLUS: - rp->plus_in_target = 1; - continue; - case NXT_HTTP_TARGET_HASH: rp->complex_target = 1; goto rest_of_target; @@ -898,9 +893,6 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) exten = u + 1; *u++ = ch; continue; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: *u++ = ch; continue; @@ -932,9 +924,6 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) goto args; case '#': goto done; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: state = sw_normal; *u++ = ch; @@ -969,9 +958,6 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) goto args; case '#': goto done; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: state = sw_normal; *u++ = ch; @@ -1013,9 +999,6 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) goto args; case '#': goto done; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: state = sw_normal; *u++ = ch; @@ -1090,10 +1073,6 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) continue; } - if (ch == '+') { - rp->plus_in_target = 1; - } - state = saved_state; goto again; } diff --git a/src/nxt_http_parse.h b/src/nxt_http_parse.h index e66475a1..a307ea73 100644 --- a/src/nxt_http_parse.h +++ b/src/nxt_http_parse.h @@ -57,16 +57,14 @@ struct nxt_http_request_parse_s { uint32_t field_hash; /* target with "/." */ - unsigned complex_target:1; + uint8_t complex_target; /* 1 bit */ /* target with "%" */ - unsigned quoted_target:1; + uint8_t quoted_target; /* 1 bit */ /* target with " " */ - unsigned space_in_target:1; - /* target with "+" */ - unsigned plus_in_target:1; + uint8_t space_in_target; /* 1 bit */ /* Preserve encoded '/' (%2F) and '%' (%25). */ - unsigned encoded_slashes:1; + uint8_t encoded_slashes; /* 1 bit */ }; diff --git a/src/test/nxt_http_parse_test.c b/src/test/nxt_http_parse_test.c index 3e86f6ef..6fbda25a 100644 --- a/src/test/nxt_http_parse_test.c +++ b/src/test/nxt_http_parse_test.c @@ -21,8 +21,6 @@ typedef struct { unsigned quoted_target:1; /* target with " " */ unsigned space_in_target:1; - /* target with "+" */ - unsigned plus_in_target:1; } nxt_http_parse_test_request_line_t; @@ -70,7 +68,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 0, 0, 0, 0 + 0, 0, 0 }} }, { @@ -83,7 +81,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_string("ext"), nxt_string("key=val"), "HTTP/1.2", - 0, 0, 0, 1 + 0, 0, 0 }} }, { @@ -96,7 +94,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_string(""), "HTTP/1.0", - 0, 0, 0, 0 + 0, 0, 0 }} }, { @@ -139,7 +137,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -152,7 +150,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -165,7 +163,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_string(""), "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -178,7 +176,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -191,7 +189,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 0, 1, 0, 0 + 0, 1, 0 }} }, { @@ -204,7 +202,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 0, 0, 1, 0 + 0, 0, 1 }} }, { @@ -217,7 +215,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.1", - 0, 0, 1, 0 + 0, 0, 1 }} }, { @@ -782,14 +780,6 @@ nxt_http_parse_test_request_line(nxt_http_request_parse_t *rp, return NXT_ERROR; } - if (rp->plus_in_target != test->plus_in_target) { - nxt_log_alert(log, "http parse test case failed:\n" - " - request:\n\"%V\"\n" - " - plus_in_target: %d (expected: %d)", - request, rp->plus_in_target, test->plus_in_target); - return NXT_ERROR; - } - return NXT_OK; } -- cgit From 6352c21a58d66db99f8f981c37e6d57e62fc24a2 Mon Sep 17 00:00:00 2001 From: Valentin Bartenev Date: Tue, 17 Sep 2019 18:40:21 +0300 Subject: HTTP parser: fixed parsing of target after literal space character. In theory, all space characters in request target must be encoded; however, some clients may violate the specification. For the sake of interoperability, Unit supports unencoded space characters. Previously, if there was a space character before the extension or arguments parts, those parts weren't recognized. Also, quoted symbols and complex target weren't detected after a space character. --- src/nxt_http_parse.c | 20 ++++++++++++++++++-- src/test/nxt_http_parse_test.c | 15 ++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/nxt_http_parse.c b/src/nxt_http_parse.c index 4d806c61..5b009d96 100644 --- a/src/nxt_http_parse.c +++ b/src/nxt_http_parse.c @@ -164,6 +164,7 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, { u_char *p, ch, *after_slash, *exten, *args; nxt_int_t rc; + nxt_bool_t rest; nxt_http_ver_t ver; nxt_http_target_traps_e trap; @@ -256,6 +257,9 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, after_slash = p + 1; exten = NULL; args = NULL; + rest = 0; + +continue_target: for ( ;; ) { p++; @@ -312,6 +316,8 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, rest_of_target: + rest = 1; + for ( ;; ) { p++; @@ -378,7 +384,12 @@ space_after_target: } rp->space_in_target = 1; - goto rest_of_target; + + if (rest) { + goto rest_of_target; + } + + goto continue_target; } /* " HTTP/1.1\r\n" or " HTTP/1.1\n" */ @@ -392,7 +403,12 @@ space_after_target: } rp->space_in_target = 1; - goto rest_of_target; + + if (rest) { + goto rest_of_target; + } + + goto continue_target; } nxt_memcpy(ver.str, &p[1], 8); diff --git a/src/test/nxt_http_parse_test.c b/src/test/nxt_http_parse_test.c index 6fbda25a..5498cb1f 100644 --- a/src/test/nxt_http_parse_test.c +++ b/src/test/nxt_http_parse_test.c @@ -205,6 +205,19 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { 0, 0, 1 }} }, + { + nxt_string("GET /na %20me.ext?args HTTP/1.0\r\n\r\n"), + NXT_DONE, + &nxt_http_parse_test_request_line, + { .request_line = { + nxt_string("GET"), + nxt_string("/na %20me.ext?args"), + nxt_string("ext"), + nxt_string("args"), + "HTTP/1.0", + 0, 1, 1 + }} + }, { nxt_string("GET / HTTP/1.0 HTTP/1.1\r\n\r\n"), NXT_DONE, @@ -212,7 +225,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { { .request_line = { nxt_string("GET"), nxt_string("/ HTTP/1.0"), - nxt_null_string, + nxt_string("0"), nxt_null_string, "HTTP/1.1", 0, 0, 1 -- cgit From ca01845d89968ef8b726f8e4a683a06c518755e4 Mon Sep 17 00:00:00 2001 From: Valentin Bartenev Date: Wed, 18 Sep 2019 16:03:28 +0300 Subject: Configuration: added ability to modify object members with slashes. Example: PUT/POST/DELETE /config/listeners/unix:%2Fpath%2Fto%2Fsocket This follows a49ee872e83d. --- src/nxt_conf.c | 88 ++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 49 insertions(+), 39 deletions(-) (limited to 'src') diff --git a/src/nxt_conf.c b/src/nxt_conf.c index ff82e1d2..59eddd77 100644 --- a/src/nxt_conf.c +++ b/src/nxt_conf.c @@ -16,6 +16,8 @@ #define NXT_CONF_MAX_SHORT_STRING 14 #define NXT_CONF_MAX_STRING NXT_INT32_T_MAX +#define NXT_CONF_MAX_TOKEN_LEN 256 + typedef enum { NXT_CONF_VALUE_NULL = 0, @@ -90,6 +92,17 @@ struct nxt_conf_op_s { }; +typedef struct { + u_char *start; + u_char *end; + nxt_bool_t last; + u_char buf[NXT_CONF_MAX_TOKEN_LEN]; +} nxt_conf_path_parse_t; + + +static nxt_int_t nxt_conf_path_next_token(nxt_conf_path_parse_t *parse, + nxt_str_t *token); + static u_char *nxt_conf_json_skip_space(u_char *start, u_char *end); static u_char *nxt_conf_json_parse_value(nxt_mp_t *mp, nxt_conf_value_t *value, u_char *start, u_char *end, nxt_conf_json_error_t *error); @@ -402,33 +415,22 @@ nxt_conf_type(nxt_conf_value_t *value) } -typedef struct { - u_char *start; - u_char *end; - nxt_bool_t last; -} nxt_conf_path_parse_t; - - -static void nxt_conf_path_next_token(nxt_conf_path_parse_t *parse, - nxt_str_t *token); - - nxt_conf_value_t * nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) { - u_char *end; nxt_str_t token; - nxt_int_t index; + nxt_int_t ret, index; nxt_conf_path_parse_t parse; - u_char buf[256]; - parse.start = path->start; parse.end = path->start + path->length; parse.last = 0; do { - nxt_conf_path_next_token(&parse, &token); + ret = nxt_conf_path_next_token(&parse, &token); + if (nxt_slow_path(ret != NXT_OK)) { + return NULL; + } if (token.length == 0) { @@ -439,18 +441,6 @@ nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) return NULL; } - if (nxt_slow_path(token.length > 256)) { - return NULL; - } - - end = nxt_decode_uri(buf, token.start, token.length); - if (nxt_slow_path(end == NULL)) { - return NULL; - } - - token.length = end - buf; - token.start = buf; - switch (value->type) { case NXT_CONF_VALUE_OBJECT: @@ -481,24 +471,38 @@ nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) } -static void +static nxt_int_t nxt_conf_path_next_token(nxt_conf_path_parse_t *parse, nxt_str_t *token) { - u_char *p, *end; + u_char *p, *start, *end; + size_t length; - end = parse->end; - p = parse->start + 1; + start = parse->start + 1; - token->start = p; + p = start; - while (p < end && *p != '/') { + while (p < parse->end && *p != '/') { p++; } parse->start = p; - parse->last = (p >= end); + parse->last = (p >= parse->end); + + length = p - start; + + if (nxt_slow_path(length > NXT_CONF_MAX_TOKEN_LEN)) { + return NXT_ERROR; + } + + end = nxt_decode_uri(parse->buf, start, length); + if (nxt_slow_path(end == NULL)) { + return NXT_ERROR; + } + + token->length = end - parse->buf; + token->start = parse->buf; - token->length = p - token->start; + return NXT_OK; } @@ -757,7 +761,7 @@ nxt_conf_op_compile(nxt_mp_t *mp, nxt_conf_op_t **ops, nxt_conf_value_t *root, nxt_str_t *path, nxt_conf_value_t *value, nxt_bool_t add) { nxt_str_t token; - nxt_int_t index; + nxt_int_t ret, index; nxt_conf_op_t *op, **parent; nxt_conf_value_t *node; nxt_conf_path_parse_t parse; @@ -778,7 +782,10 @@ nxt_conf_op_compile(nxt_mp_t *mp, nxt_conf_op_t **ops, nxt_conf_value_t *root, *parent = op; parent = (nxt_conf_op_t **) &op->ctx; - nxt_conf_path_next_token(&parse, &token); + ret = nxt_conf_path_next_token(&parse, &token); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_CONF_OP_ERROR; + } switch (root->type) { @@ -872,7 +879,10 @@ nxt_conf_op_compile(nxt_mp_t *mp, nxt_conf_op_t **ops, nxt_conf_value_t *root, return NXT_CONF_OP_ERROR; } - nxt_conf_set_string(&member->name, &token); + ret = nxt_conf_set_string_dup(&member->name, mp, &token); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_CONF_OP_ERROR; + } member->value = *value; -- cgit From 9bacd21405de021fa846bac95b7e3fb796763a80 Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Wed, 18 Sep 2019 18:30:58 +0300 Subject: Protecting context structures with mutex. By design, Unit context is created for the thread which reads messages from the router. However, Go request handlers are called in a separate goroutine that may be executed in a different thread. To avoid a racing condition, access to lists of free structures in the context should be serialized. This patch should fix random crashes in Go applications under high load. This is related to #253 and #309 issues on GitHub. --- src/nxt_unit.c | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/nxt_unit.c b/src/nxt_unit.c index 28a0de20..3f6a945f 100644 --- a/src/nxt_unit.c +++ b/src/nxt_unit.c @@ -31,7 +31,7 @@ typedef struct nxt_unit_request_info_impl_s nxt_unit_request_info_impl_t; typedef struct nxt_unit_websocket_frame_impl_s nxt_unit_websocket_frame_impl_t; static nxt_unit_impl_t *nxt_unit_create(nxt_unit_init_t *init); -static void nxt_unit_ctx_init(nxt_unit_impl_t *lib, +static int nxt_unit_ctx_init(nxt_unit_impl_t *lib, nxt_unit_ctx_impl_t *ctx_impl, void *data); nxt_inline void nxt_unit_mmap_buf_insert(nxt_unit_mmap_buf_t **head, nxt_unit_mmap_buf_t *mmap_buf); @@ -204,6 +204,8 @@ struct nxt_unit_websocket_frame_impl_s { struct nxt_unit_ctx_impl_s { nxt_unit_ctx_t ctx; + pthread_mutex_t mutex; + nxt_unit_port_id_t read_port_id; int read_port_fd; @@ -402,7 +404,10 @@ nxt_unit_create(nxt_unit_init_t *init) nxt_queue_init(&lib->contexts); - nxt_unit_ctx_init(lib, &lib->main_ctx, init->ctx_data); + rc = nxt_unit_ctx_init(lib, &lib->main_ctx, init->ctx_data); + if (nxt_slow_path(rc != NXT_UNIT_OK)) { + goto fail; + } cb = &lib->callbacks; @@ -446,15 +451,24 @@ fail: } -static void +static int nxt_unit_ctx_init(nxt_unit_impl_t *lib, nxt_unit_ctx_impl_t *ctx_impl, void *data) { + int rc; + ctx_impl->ctx.data = data; ctx_impl->ctx.unit = &lib->unit; nxt_queue_insert_tail(&lib->contexts, &ctx_impl->link); + rc = pthread_mutex_init(&ctx_impl->mutex, NULL); + if (nxt_slow_path(rc != 0)) { + nxt_unit_alert(NULL, "failed to initialize mutex (%d)", rc); + + return NXT_UNIT_ERROR; + } + nxt_queue_init(&ctx_impl->free_req); nxt_queue_init(&ctx_impl->free_ws); nxt_queue_init(&ctx_impl->active_req); @@ -470,6 +484,8 @@ nxt_unit_ctx_init(nxt_unit_impl_t *lib, nxt_unit_ctx_impl_t *ctx_impl, ctx_impl->read_port_fd = -1; ctx_impl->requests.slot = 0; + + return NXT_UNIT_OK; } @@ -1029,7 +1045,11 @@ nxt_unit_request_info_get(nxt_unit_ctx_t *ctx) lib = nxt_container_of(ctx->unit, nxt_unit_impl_t, unit); + pthread_mutex_lock(&ctx_impl->mutex); + if (nxt_queue_is_empty(&ctx_impl->free_req)) { + pthread_mutex_unlock(&ctx_impl->mutex); + req_impl = malloc(sizeof(nxt_unit_request_info_impl_t) + lib->request_data_size); if (nxt_slow_path(req_impl == NULL)) { @@ -1041,6 +1061,8 @@ nxt_unit_request_info_get(nxt_unit_ctx_t *ctx) req_impl->req.unit = ctx->unit; req_impl->req.ctx = ctx; + pthread_mutex_lock(&ctx_impl->mutex); + } else { lnk = nxt_queue_first(&ctx_impl->free_req); nxt_queue_remove(lnk); @@ -1050,6 +1072,8 @@ nxt_unit_request_info_get(nxt_unit_ctx_t *ctx) nxt_queue_insert_tail(&ctx_impl->active_req, &req_impl->link); + pthread_mutex_unlock(&ctx_impl->mutex); + req_impl->req.data = lib->request_data_size ? req_impl->extra_data : NULL; return req_impl; @@ -1088,10 +1112,14 @@ nxt_unit_request_info_release(nxt_unit_request_info_t *req) nxt_unit_mmap_buf_free(req_impl->incoming_buf); } + pthread_mutex_lock(&ctx_impl->mutex); + nxt_queue_remove(&req_impl->link); nxt_queue_insert_tail(&ctx_impl->free_req, &req_impl->link); + pthread_mutex_unlock(&ctx_impl->mutex); + req_impl->state = NXT_UNIT_RS_RELEASED; } @@ -1120,7 +1148,11 @@ nxt_unit_websocket_frame_get(nxt_unit_ctx_t *ctx) ctx_impl = nxt_container_of(ctx, nxt_unit_ctx_impl_t, ctx); + pthread_mutex_lock(&ctx_impl->mutex); + if (nxt_queue_is_empty(&ctx_impl->free_ws)) { + pthread_mutex_unlock(&ctx_impl->mutex); + ws_impl = malloc(sizeof(nxt_unit_websocket_frame_impl_t)); if (nxt_slow_path(ws_impl == NULL)) { nxt_unit_warn(ctx, "websocket frame allocation failed"); @@ -1132,6 +1164,8 @@ nxt_unit_websocket_frame_get(nxt_unit_ctx_t *ctx) lnk = nxt_queue_first(&ctx_impl->free_ws); nxt_queue_remove(lnk); + pthread_mutex_unlock(&ctx_impl->mutex); + ws_impl = nxt_container_of(lnk, nxt_unit_websocket_frame_impl_t, link); } @@ -1160,7 +1194,11 @@ nxt_unit_websocket_frame_release(nxt_unit_websocket_frame_t *ws) ws_impl->retain_buf = NULL; } + pthread_mutex_lock(&ws_impl->ctx_impl->mutex); + nxt_queue_insert_tail(&ws_impl->ctx_impl->free_ws, &ws_impl->link); + + pthread_mutex_unlock(&ws_impl->ctx_impl->mutex); } @@ -1688,16 +1726,24 @@ nxt_unit_mmap_buf_get(nxt_unit_ctx_t *ctx) ctx_impl = nxt_container_of(ctx, nxt_unit_ctx_impl_t, ctx); + pthread_mutex_lock(&ctx_impl->mutex); + if (ctx_impl->free_buf == NULL) { + pthread_mutex_unlock(&ctx_impl->mutex); + mmap_buf = malloc(sizeof(nxt_unit_mmap_buf_t)); if (nxt_slow_path(mmap_buf == NULL)) { nxt_unit_warn(ctx, "failed to allocate buf"); + + return NULL; } } else { mmap_buf = ctx_impl->free_buf; nxt_unit_mmap_buf_remove(mmap_buf); + + pthread_mutex_unlock(&ctx_impl->mutex); } mmap_buf->ctx_impl = ctx_impl; @@ -1711,7 +1757,11 @@ nxt_unit_mmap_buf_release(nxt_unit_mmap_buf_t *mmap_buf) { nxt_unit_mmap_buf_remove(mmap_buf); + pthread_mutex_lock(&mmap_buf->ctx_impl->mutex); + nxt_unit_mmap_buf_insert(&mmap_buf->ctx_impl->free_buf, mmap_buf); + + pthread_mutex_unlock(&mmap_buf->ctx_impl->mutex); } @@ -3298,7 +3348,14 @@ nxt_unit_ctx_alloc(nxt_unit_ctx_t *ctx, void *data) close(fd); - nxt_unit_ctx_init(lib, new_ctx, data); + rc = nxt_unit_ctx_init(lib, new_ctx, data); + if (nxt_slow_path(rc != NXT_UNIT_OK)) { + lib->callbacks.remove_port(ctx, &new_port_id); + + free(new_ctx); + + return NULL; + } new_ctx->read_port_id = new_port_id; @@ -3350,6 +3407,8 @@ nxt_unit_ctx_free(nxt_unit_ctx_t *ctx) } nxt_queue_loop; + pthread_mutex_destroy(&ctx_impl->mutex); + nxt_queue_remove(&ctx_impl->link); if (ctx_impl != &lib->main_ctx) { -- cgit From 4ea9ed309e958ad087ddcfd358688e2a2f105e39 Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Wed, 18 Sep 2019 18:31:06 +0300 Subject: Reducing number of warning messages. One alert per failed allocation is enough. --- src/nxt_unit.c | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/nxt_unit.c b/src/nxt_unit.c index 3f6a945f..18a07f9e 100644 --- a/src/nxt_unit.c +++ b/src/nxt_unit.c @@ -978,6 +978,9 @@ nxt_unit_process_websocket(nxt_unit_ctx_t *ctx, nxt_unit_recv_msg_t *recv_msg) } else { b = nxt_unit_mmap_buf_get(ctx); if (nxt_slow_path(b == NULL)) { + nxt_unit_alert(ctx, "#%"PRIu32": failed to allocate buf", + req_impl->stream); + return NXT_UNIT_ERROR; } @@ -1053,8 +1056,6 @@ nxt_unit_request_info_get(nxt_unit_ctx_t *ctx) req_impl = malloc(sizeof(nxt_unit_request_info_impl_t) + lib->request_data_size); if (nxt_slow_path(req_impl == NULL)) { - nxt_unit_warn(ctx, "request info allocation failed"); - return NULL; } @@ -1155,8 +1156,6 @@ nxt_unit_websocket_frame_get(nxt_unit_ctx_t *ctx) ws_impl = malloc(sizeof(nxt_unit_websocket_frame_impl_t)); if (nxt_slow_path(ws_impl == NULL)) { - nxt_unit_warn(ctx, "websocket frame allocation failed"); - return NULL; } @@ -1673,6 +1672,8 @@ nxt_unit_response_buf_alloc(nxt_unit_request_info_t *req, uint32_t size) mmap_buf = nxt_unit_mmap_buf_get(req->ctx); if (nxt_slow_path(mmap_buf == NULL)) { + nxt_unit_req_alert(req, "response_buf_alloc: failed to allocate buf"); + return NULL; } @@ -1733,8 +1734,6 @@ nxt_unit_mmap_buf_get(nxt_unit_ctx_t *ctx) mmap_buf = malloc(sizeof(nxt_unit_mmap_buf_t)); if (nxt_slow_path(mmap_buf == NULL)) { - nxt_unit_warn(ctx, "failed to allocate buf"); - return NULL; } -- cgit From 84e4a6437f69b14b1cafbe7915447fdcf1d87f68 Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Wed, 18 Sep 2019 18:31:14 +0300 Subject: Go: do not store pointer to Go object. To pass Go object references to C and back we use hack with casting to unsafe and then to uintptr. However, we should not store such references because Go not guaratnee it will be available by the same address. Introducing map with integer key helps to avoid dereference stored address. This closes #253 and #309 issues on GitHub. --- src/go/unit/request.go | 2 +- src/go/unit/unit.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/go/unit/request.go b/src/go/unit/request.go index ad56cabb..1d8c6702 100644 --- a/src/go/unit/request.go +++ b/src/go/unit/request.go @@ -135,7 +135,7 @@ func nxt_go_request_set_tls(go_req uintptr) { //export nxt_go_request_handler func nxt_go_request_handler(go_req uintptr, h uintptr) { r := get_request(go_req) - handler := *(*http.Handler)(unsafe.Pointer(h)) + handler := get_handler(h) go func(r *request) { handler.ServeHTTP(r.response(), &r.req) diff --git a/src/go/unit/unit.go b/src/go/unit/unit.go index 06257768..1534479e 100644 --- a/src/go/unit/unit.go +++ b/src/go/unit/unit.go @@ -13,6 +13,7 @@ import "C" import ( "fmt" "net/http" + "sync" "unsafe" ) @@ -87,12 +88,58 @@ func nxt_go_warn(format string, args ...interface{}) { C.nxt_cgo_warn(str_ref(str), C.uint32_t(len(str))) } +type handler_registry struct { + sync.RWMutex + next uintptr + m map[uintptr]*http.Handler +} + +var handler_registry_ handler_registry + +func set_handler(handler *http.Handler) uintptr { + + handler_registry_.Lock() + if handler_registry_.m == nil { + handler_registry_.m = make(map[uintptr]*http.Handler) + handler_registry_.next = 1 + } + + h := handler_registry_.next + handler_registry_.next += 1 + handler_registry_.m[h] = handler + + handler_registry_.Unlock() + + return h +} + +func get_handler(h uintptr) http.Handler { + handler_registry_.RLock() + defer handler_registry_.RUnlock() + + return *handler_registry_.m[h] +} + +func reset_handler(h uintptr) { + + handler_registry_.Lock() + if handler_registry_.m != nil { + delete(handler_registry_.m, h) + } + + handler_registry_.Unlock() +} + func ListenAndServe(addr string, handler http.Handler) error { if handler == nil { handler = http.DefaultServeMux } - rc := C.nxt_cgo_run(C.uintptr_t(uintptr(unsafe.Pointer(&handler)))) + h := set_handler(&handler) + + rc := C.nxt_cgo_run(C.uintptr_t(h)) + + reset_handler(h) if rc != 0 { return http.ListenAndServe(addr, handler) -- cgit From a216b15e991d6e7cdc7b61fe018d58dd387937bd Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Wed, 18 Sep 2019 18:31:22 +0300 Subject: Fixing request release order to avoid crashes on exit. Each request references the router process structure that owns all memory maps. The process structure has a reference counter; each request increases the counter to lock the structure in memory until request processing ends. Incoming and outgoing buffers reference memory maps that the process owns, so the process structure should be released only when all buffers are released to avoid invalid memory access and a crash. This describes the libunit library mechanism used for application processes. The background of this issue is as follows: The issue was found on buildbot when the router crashed during Java websocket tests. The Java application receives a notification from the master process; when the notification is processed, libunit deletes the process structure from its process hash and decrements the use counter; however, active websocket connections maintain their use counts on the process structure. After that, when the master process is stopping the application, libunit releases active websocket connections. At this point, it's important to release the connections' memory buffers before the corresponding process structure and all shared memory segments are released. --- src/nxt_unit.c | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/nxt_unit.c b/src/nxt_unit.c index 18a07f9e..9696e9cd 100644 --- a/src/nxt_unit.c +++ b/src/nxt_unit.c @@ -1093,12 +1093,6 @@ nxt_unit_request_info_release(nxt_unit_request_info_t *req) req->response = NULL; req->response_buf = NULL; - if (req_impl->process != NULL) { - nxt_unit_process_use(req->ctx, req_impl->process, -1); - - req_impl->process = NULL; - } - if (req_impl->websocket) { nxt_unit_request_hash_find(&ctx_impl->requests, req_impl->stream, 1); @@ -1113,6 +1107,16 @@ nxt_unit_request_info_release(nxt_unit_request_info_t *req) nxt_unit_mmap_buf_free(req_impl->incoming_buf); } + /* + * Process release should go after buffers release to guarantee mmap + * existence. + */ + if (req_impl->process != NULL) { + nxt_unit_process_use(req->ctx, req_impl->process, -1); + + req_impl->process = NULL; + } + pthread_mutex_lock(&ctx_impl->mutex); nxt_queue_remove(&req_impl->link); -- cgit From ab40ce3757e6f9d50c9c1a8aa4ea01cb5278eafa Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Wed, 18 Sep 2019 18:31:29 +0300 Subject: Go: removing nxt_main.h usage. One small step to Go modules support. --- src/go/unit/nxt_cgo_lib.c | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/go/unit/nxt_cgo_lib.c b/src/go/unit/nxt_cgo_lib.c index cc1228f5..5cb31b5a 100644 --- a/src/go/unit/nxt_cgo_lib.c +++ b/src/go/unit/nxt_cgo_lib.c @@ -6,7 +6,6 @@ #include "_cgo_export.h" -#include #include #include @@ -39,7 +38,7 @@ nxt_cgo_run(uintptr_t handler) init.data = (void *) handler; ctx = nxt_unit_init(&init); - if (nxt_slow_path(ctx == NULL)) { + if (ctx == NULL) { return NXT_UNIT_ERROR; } @@ -171,7 +170,7 @@ nxt_cgo_response_write(uintptr_t req, uintptr_t start, uint32_t len) rc = nxt_unit_response_write((nxt_unit_request_info_t *) req, (void *) start, len); - if (nxt_slow_path(rc != NXT_UNIT_OK)) { + if (rc != NXT_UNIT_OK) { return -1; } -- cgit From 1fac43eebe4136a2f57c56f23fc90a33783e63f2 Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Wed, 18 Sep 2019 22:45:30 +0300 Subject: Fixing master process crash after failed fork. This closes #312 issue on GitHub. --- src/nxt_main_process.c | 9 +++++++-- src/nxt_port.c | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/nxt_main_process.c b/src/nxt_main_process.c index 40682eb9..83c6d188 100644 --- a/src/nxt_main_process.c +++ b/src/nxt_main_process.c @@ -690,15 +690,18 @@ nxt_main_create_worker_process(nxt_task_t *task, nxt_runtime_t *rt, pid = nxt_process_create(task, process); - nxt_port_use(task, port, -1); - switch (pid) { case -1: + nxt_port_close(task, port); + nxt_port_use(task, port, -1); + return NXT_ERROR; case 0: /* A worker process, return to the event engine work queue loop. */ + nxt_port_use(task, port, -1); + return NXT_AGAIN; default: @@ -707,6 +710,8 @@ nxt_main_create_worker_process(nxt_task_t *task, nxt_runtime_t *rt, nxt_port_read_close(port); nxt_port_write_enable(task, port); + nxt_port_use(task, port, -1); + return NXT_OK; } } diff --git a/src/nxt_port.c b/src/nxt_port.c index cef65cab..9029353a 100644 --- a/src/nxt_port.c +++ b/src/nxt_port.c @@ -546,7 +546,7 @@ nxt_port_use(nxt_task_t *task, nxt_port_t *port, int i) if (i < 0 && c == -i) { - if (task->thread->engine == port->engine) { + if (port->engine == NULL || task->thread->engine == port->engine) { nxt_port_release(task, port); return; -- cgit From 6346e641eef4aacf92e81e0f1ea4f42ed1e62834 Mon Sep 17 00:00:00 2001 From: Max Romanov Date: Thu, 19 Sep 2019 16:28:03 +0300 Subject: Releasing WebSocket frame in case of buffer allocation failure. Found by Coverity (CID 349456). --- src/nxt_unit.c | 2 ++ 1 file changed, 2 insertions(+) (limited to 'src') diff --git a/src/nxt_unit.c b/src/nxt_unit.c index 9696e9cd..4497d09d 100644 --- a/src/nxt_unit.c +++ b/src/nxt_unit.c @@ -981,6 +981,8 @@ nxt_unit_process_websocket(nxt_unit_ctx_t *ctx, nxt_unit_recv_msg_t *recv_msg) nxt_unit_alert(ctx, "#%"PRIu32": failed to allocate buf", req_impl->stream); + nxt_unit_websocket_frame_release(&ws_impl->ws); + return NXT_UNIT_ERROR; } -- cgit From c554941b4f826d83d92d5ca8d7713bea4167896e Mon Sep 17 00:00:00 2001 From: Tiago de Bem Natel de Moura Date: Thu, 19 Sep 2019 15:25:23 +0300 Subject: Initial applications isolation support using Linux namespaces. --- src/nxt_application.h | 2 + src/nxt_capability.c | 104 ++++++++++++++++ src/nxt_capability.h | 17 +++ src/nxt_clone.c | 263 ++++++++++++++++++++++++++++++++++++++++ src/nxt_clone.h | 17 +++ src/nxt_conf_validation.c | 303 ++++++++++++++++++++++++++++++++++++---------- src/nxt_main.h | 1 + src/nxt_main_process.c | 207 +++++++++++++++++++++++++++---- src/nxt_process.c | 251 ++++++++++++++++++++++++++++---------- src/nxt_process.h | 31 +++-- src/nxt_runtime.c | 14 ++- src/nxt_runtime.h | 1 + src/nxt_unit.c | 2 +- 13 files changed, 1048 insertions(+), 165 deletions(-) create mode 100644 src/nxt_capability.c create mode 100644 src/nxt_capability.h create mode 100644 src/nxt_clone.c create mode 100644 src/nxt_clone.h (limited to 'src') diff --git a/src/nxt_application.h b/src/nxt_application.h index 7ff4bb11..2a1fa39e 100644 --- a/src/nxt_application.h +++ b/src/nxt_application.h @@ -88,6 +88,8 @@ struct nxt_common_app_conf_s { char *working_directory; nxt_conf_value_t *environment; + nxt_conf_value_t *isolation; + union { nxt_external_app_conf_t external; nxt_python_app_conf_t python; diff --git a/src/nxt_capability.c b/src/nxt_capability.c new file mode 100644 index 00000000..805faff6 --- /dev/null +++ b/src/nxt_capability.c @@ -0,0 +1,104 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#include + +#if (NXT_HAVE_LINUX_CAPABILITY) + +#include +#include + +#define nxt_capget(hdrp, datap) \ + syscall(SYS_capget, hdrp, datap) +#define nxt_capset(hdrp, datap) \ + syscall(SYS_capset, hdrp, datap) + +#endif /* NXT_HAVE_LINUX_CAPABILITY */ + + +static nxt_int_t nxt_capability_specific_set(nxt_task_t *task, + nxt_capabilities_t *cap); + + +nxt_int_t +nxt_capability_set(nxt_task_t *task, nxt_capabilities_t *cap) +{ + nxt_assert(cap->setid == 0); + + if (geteuid() == 0) { + cap->setid = 1; + return NXT_OK; + } + + return nxt_capability_specific_set(task, cap); +} + + +#if (NXT_HAVE_LINUX_CAPABILITY) + +static uint32_t +nxt_capability_linux_get_version() +{ + struct __user_cap_header_struct hdr; + + hdr.version = _LINUX_CAPABILITY_VERSION; + hdr.pid = nxt_pid; + + nxt_capget(&hdr, NULL); + return hdr.version; +} + + +static nxt_int_t +nxt_capability_specific_set(nxt_task_t *task, nxt_capabilities_t *cap) +{ + struct __user_cap_data_struct *val, data[2]; + struct __user_cap_header_struct hdr; + + /* + * Linux capability v1 fills an u32 struct. + * Linux capability v2 and v3 fills an u64 struct. + * We allocate data[2] for compatibility, we waste 4 bytes on v1. + * + * This is safe as we only need to check CAP_SETUID and CAP_SETGID + * that resides in the first 32-bit chunk. + */ + + val = &data[0]; + + /* + * Ask the kernel the preferred capability version + * instead of using _LINUX_CAPABILITY_VERSION from header. + * This is safer when distributing a pre-compiled Unit binary. + */ + hdr.version = nxt_capability_linux_get_version(); + hdr.pid = nxt_pid; + + if (nxt_slow_path(nxt_capget(&hdr, val) == -1)) { + nxt_alert(task, "failed to get process capabilities: %E", nxt_errno); + return NXT_ERROR; + } + + if ((val->effective & (1 << CAP_SETUID)) == 0) { + return NXT_OK; + } + + if ((val->effective & (1 << CAP_SETGID)) == 0) { + return NXT_OK; + } + + cap->setid = 1; + return NXT_OK; +} + +#else + +static nxt_int_t +nxt_capability_specific_set(nxt_task_t *task, nxt_capabilities_t *cap) +{ + return NXT_OK; +} + +#endif diff --git a/src/nxt_capability.h b/src/nxt_capability.h new file mode 100644 index 00000000..60bbd5f8 --- /dev/null +++ b/src/nxt_capability.h @@ -0,0 +1,17 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#ifndef _NXT_CAPABILITY_INCLUDED_ +#define _NXT_CAPABILITY_INCLUDED_ + +typedef struct { + uint8_t setid; /* 1 bit */ +} nxt_capabilities_t; + + +NXT_EXPORT nxt_int_t nxt_capability_set(nxt_task_t *task, + nxt_capabilities_t *cap); + +#endif /* _NXT_CAPABILITY_INCLUDED_ */ diff --git a/src/nxt_clone.c b/src/nxt_clone.c new file mode 100644 index 00000000..0fddd6c7 --- /dev/null +++ b/src/nxt_clone.c @@ -0,0 +1,263 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#include +#include +#include +#include + +#if (NXT_HAVE_CLONE) + +pid_t +nxt_clone(nxt_int_t flags) +{ +#if defined(__s390x__) || defined(__s390__) || defined(__CRIS__) + return syscall(__NR_clone, NULL, flags); +#else + return syscall(__NR_clone, flags, NULL); +#endif +} + +#endif + + +#if (NXT_HAVE_CLONE_NEWUSER) + +/* map uid 65534 to unit pid */ +#define NXT_DEFAULT_UNPRIV_MAP "65534 %d 1" + +nxt_int_t nxt_clone_proc_setgroups(nxt_task_t *task, pid_t child_pid, + const char *str); +nxt_int_t nxt_clone_proc_map_set(nxt_task_t *task, const char* mapfile, + pid_t pid, nxt_int_t defval, nxt_conf_value_t *mapobj); +nxt_int_t nxt_clone_proc_map_write(nxt_task_t *task, const char *mapfile, + pid_t pid, u_char *mapinfo); + + +typedef struct { + nxt_int_t container; + nxt_int_t host; + nxt_int_t size; +} nxt_clone_procmap_t; + + +nxt_int_t +nxt_clone_proc_setgroups(nxt_task_t *task, pid_t child_pid, const char *str) +{ + int fd, n; + u_char *p, *end; + u_char path[PATH_MAX]; + + end = path + PATH_MAX; + p = nxt_sprintf(path, end, "/proc/%d/setgroups", child_pid); + *p = '\0'; + + if (nxt_slow_path(p == end)) { + nxt_alert(task, "error write past the buffer: %s", path); + return NXT_ERROR; + } + + fd = open((char *)path, O_RDWR); + + if (fd == -1) { + /* + * If the /proc/pid/setgroups doesn't exists, we are + * safe to set uid/gid maps. But if the error is anything + * other than ENOENT, then we should abort and let user know. + */ + + if (errno != ENOENT) { + nxt_alert(task, "open(%s): %E", path, nxt_errno); + return NXT_ERROR; + } + + return NXT_OK; + } + + n = write(fd, str, strlen(str)); + close(fd); + + if (nxt_slow_path(n == -1)) { + nxt_alert(task, "write(%s): %E", path, nxt_errno); + return NXT_ERROR; + } + + return NXT_OK; +} + + +nxt_int_t +nxt_clone_proc_map_write(nxt_task_t *task, const char *mapfile, pid_t pid, + u_char *mapinfo) +{ + int len, mapfd; + u_char *p, *end; + ssize_t n; + u_char buf[256]; + + end = buf + sizeof(buf); + + p = nxt_sprintf(buf, end, "/proc/%d/%s", pid, mapfile); + if (nxt_slow_path(p == end)) { + nxt_alert(task, "writing past the buffer"); + return NXT_ERROR; + } + + *p = '\0'; + + mapfd = open((char*)buf, O_RDWR); + if (nxt_slow_path(mapfd == -1)) { + nxt_alert(task, "failed to open proc map (%s) %E", buf, nxt_errno); + return NXT_ERROR; + } + + len = nxt_strlen(mapinfo); + + n = write(mapfd, (char *)mapinfo, len); + if (nxt_slow_path(n != len)) { + + if (n == -1 && nxt_errno == EINVAL) { + nxt_alert(task, "failed to write %s: Check kernel maximum " \ + "allowed lines %E", buf, nxt_errno); + + } else { + nxt_alert(task, "failed to write proc map (%s) %E", buf, + nxt_errno); + } + + return NXT_ERROR; + } + + return NXT_OK; +} + + +nxt_int_t +nxt_clone_proc_map_set(nxt_task_t *task, const char* mapfile, pid_t pid, + nxt_int_t defval, nxt_conf_value_t *mapobj) +{ + u_char *p, *end, *mapinfo; + nxt_int_t container, host, size; + nxt_int_t ret, len, count, i; + nxt_conf_value_t *obj, *value; + + static nxt_str_t str_cont = nxt_string("container"); + static nxt_str_t str_host = nxt_string("host"); + static nxt_str_t str_size = nxt_string("size"); + + /* + * uid_map one-entry size: + * alloc space for 3 numbers (32bit) plus 2 spaces and \n. + */ + len = sizeof(u_char) * (10 + 10 + 10 + 2 + 1); + + if (mapobj != NULL) { + count = nxt_conf_array_elements_count(mapobj); + + if (count == 0) { + goto default_map; + } + + len = len * count + 1; + + mapinfo = nxt_malloc(len); + if (nxt_slow_path(mapinfo == NULL)) { + nxt_alert(task, "failed to allocate uid_map buffer"); + return NXT_ERROR; + } + + p = mapinfo; + end = mapinfo + len; + + for (i = 0; i < count; i++) { + obj = nxt_conf_get_array_element(mapobj, i); + + value = nxt_conf_get_object_member(obj, &str_cont, NULL); + container = nxt_conf_get_integer(value); + + value = nxt_conf_get_object_member(obj, &str_host, NULL); + host = nxt_conf_get_integer(value); + + value = nxt_conf_get_object_member(obj, &str_size, NULL); + size = nxt_conf_get_integer(value); + + p = nxt_sprintf(p, end, "%d %d %d", container, host, size); + if (nxt_slow_path(p == end)) { + nxt_alert(task, "write past the uid_map buffer"); + nxt_free(mapinfo); + return NXT_ERROR; + } + + if (i+1 < count) { + *p++ = '\n'; + + } else { + *p = '\0'; + } + } + + } else { + +default_map: + + mapinfo = nxt_malloc(len); + if (nxt_slow_path(mapinfo == NULL)) { + nxt_alert(task, "failed to allocate uid_map buffer"); + return NXT_ERROR; + } + + end = mapinfo + len; + p = nxt_sprintf(mapinfo, end, NXT_DEFAULT_UNPRIV_MAP, defval); + *p = '\0'; + + if (nxt_slow_path(p == end)) { + nxt_alert(task, "write past the %s buffer", mapfile); + nxt_free(mapinfo); + return NXT_ERROR; + } + } + + ret = nxt_clone_proc_map_write(task, mapfile, pid, mapinfo); + + nxt_free(mapinfo); + + return ret; +} + + +nxt_int_t +nxt_clone_proc_map(nxt_task_t *task, pid_t pid, nxt_process_clone_t *clone) +{ + nxt_int_t ret; + nxt_int_t uid, gid; + const char *rule; + nxt_runtime_t *rt; + + rt = task->thread->runtime; + uid = geteuid(); + gid = getegid(); + + rule = rt->capabilities.setid ? "allow" : "deny"; + + ret = nxt_clone_proc_map_set(task, "uid_map", pid, uid, clone->uidmap); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + ret = nxt_clone_proc_setgroups(task, pid, rule); + if (nxt_slow_path(ret != NXT_OK)) { + nxt_alert(task, "failed to write /proc/%d/setgroups", pid); + return NXT_ERROR; + } + + ret = nxt_clone_proc_map_set(task, "gid_map", pid, gid, clone->gidmap); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + return NXT_OK; +} + +#endif diff --git a/src/nxt_clone.h b/src/nxt_clone.h new file mode 100644 index 00000000..50dec0b4 --- /dev/null +++ b/src/nxt_clone.h @@ -0,0 +1,17 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#ifndef _NXT_CLONE_INCLUDED_ +#define _NXT_CLONE_INCLUDED_ + + +pid_t nxt_clone(nxt_int_t flags); + +#if (NXT_HAVE_CLONE_NEWUSER) +nxt_int_t nxt_clone_proc_map(nxt_task_t *task, pid_t pid, + nxt_process_clone_t *clone); +#endif + +#endif /* _NXT_CLONE_INCLUDED_ */ diff --git a/src/nxt_conf_validation.c b/src/nxt_conf_validation.c index ca8ec62e..078ddd17 100644 --- a/src/nxt_conf_validation.c +++ b/src/nxt_conf_validation.c @@ -39,9 +39,6 @@ typedef nxt_int_t (*nxt_conf_vldt_member_t)(nxt_conf_validation_t *vldt, nxt_conf_value_t *value); typedef nxt_int_t (*nxt_conf_vldt_element_t)(nxt_conf_validation_t *vldt, nxt_conf_value_t *value); -typedef nxt_int_t (*nxt_conf_vldt_system_t)(nxt_conf_validation_t *vldt, - char *name); - static nxt_int_t nxt_conf_vldt_type(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value, nxt_conf_vldt_type_t type); @@ -86,10 +83,6 @@ static nxt_int_t nxt_conf_vldt_object_iterator(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, void *data); static nxt_int_t nxt_conf_vldt_array_iterator(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, void *data); -static nxt_int_t nxt_conf_vldt_system(nxt_conf_validation_t *vldt, - nxt_conf_value_t *value, void *data); -static nxt_int_t nxt_conf_vldt_user(nxt_conf_validation_t *vldt, char *name); -static nxt_int_t nxt_conf_vldt_group(nxt_conf_validation_t *vldt, char *name); static nxt_int_t nxt_conf_vldt_environment(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value); static nxt_int_t nxt_conf_vldt_argument(nxt_conf_validation_t *vldt, @@ -101,6 +94,21 @@ static nxt_int_t nxt_conf_vldt_java_classpath(nxt_conf_validation_t *vldt, static nxt_int_t nxt_conf_vldt_java_option(nxt_conf_validation_t *vldt, nxt_conf_value_t *value); +static nxt_int_t +nxt_conf_vldt_isolation(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data); +static nxt_int_t +nxt_conf_vldt_clone_namespaces(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value, void *data); + +#if (NXT_HAVE_CLONE_NEWUSER) +static nxt_int_t nxt_conf_vldt_clone_procmap(nxt_conf_validation_t *vldt, + const char* mapfile, nxt_conf_value_t *value); +static nxt_int_t nxt_conf_vldt_clone_uidmap(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value); +static nxt_int_t nxt_conf_vldt_clone_gidmap(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value); +#endif static nxt_conf_vldt_object_t nxt_conf_vldt_websocket_members[] = { { nxt_string("read_timeout"), @@ -340,6 +348,100 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_app_processes_members[] = { }; +static nxt_conf_vldt_object_t nxt_conf_vldt_app_namespaces_members[] = { + +#if (NXT_HAVE_CLONE_NEWUSER) + { nxt_string("credential"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWPID) + { nxt_string("pid"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWNET) + { nxt_string("network"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWNS) + { nxt_string("mount"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWUTS) + { nxt_string("uname"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWCGROUP) + { nxt_string("cgroup"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + + NXT_CONF_VLDT_END +}; + + +#if (NXT_HAVE_CLONE_NEWUSER) + +static nxt_conf_vldt_object_t nxt_conf_vldt_app_procmap_members[] = { + { nxt_string("container"), + NXT_CONF_VLDT_INTEGER, + NULL, + NULL }, + + { nxt_string("host"), + NXT_CONF_VLDT_INTEGER, + NULL, + NULL }, + + { nxt_string("size"), + NXT_CONF_VLDT_INTEGER, + NULL, + NULL }, +}; + +#endif + + +static nxt_conf_vldt_object_t nxt_conf_vldt_app_isolation_members[] = { + { nxt_string("namespaces"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_clone_namespaces, + (void *) &nxt_conf_vldt_app_namespaces_members }, + +#if (NXT_HAVE_CLONE_NEWUSER) + + { nxt_string("uidmap"), + NXT_CONF_VLDT_ARRAY, + &nxt_conf_vldt_array_iterator, + (void *) &nxt_conf_vldt_clone_uidmap }, + + { nxt_string("gidmap"), + NXT_CONF_VLDT_ARRAY, + &nxt_conf_vldt_array_iterator, + (void *) &nxt_conf_vldt_clone_gidmap }, + +#endif + + NXT_CONF_VLDT_END +}; + + static nxt_conf_vldt_object_t nxt_conf_vldt_common_members[] = { { nxt_string("type"), NXT_CONF_VLDT_STRING, @@ -358,13 +460,13 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_common_members[] = { { nxt_string("user"), NXT_CONF_VLDT_STRING, - nxt_conf_vldt_system, - (void *) &nxt_conf_vldt_user }, + NULL, + NULL }, { nxt_string("group"), NXT_CONF_VLDT_STRING, - nxt_conf_vldt_system, - (void *) &nxt_conf_vldt_group }, + NULL, + NULL }, { nxt_string("working_directory"), NXT_CONF_VLDT_STRING, @@ -376,6 +478,11 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_common_members[] = { &nxt_conf_vldt_object_iterator, (void *) &nxt_conf_vldt_environment }, + { nxt_string("isolation"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_isolation, + (void *) &nxt_conf_vldt_app_isolation_members }, + NXT_CONF_VLDT_END }; @@ -1252,106 +1359,168 @@ nxt_conf_vldt_array_iterator(nxt_conf_validation_t *vldt, static nxt_int_t -nxt_conf_vldt_system(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, - void *data) +nxt_conf_vldt_environment(nxt_conf_validation_t *vldt, nxt_str_t *name, + nxt_conf_value_t *value) { - size_t length; - nxt_str_t name; - nxt_conf_vldt_system_t validator; - char string[32]; + nxt_str_t str; + + if (name->length == 0) { + return nxt_conf_vldt_error(vldt, + "The environment name must not be empty."); + } - /* The cast is required by Sun C. */ - validator = (nxt_conf_vldt_system_t) data; + if (nxt_memchr(name->start, '\0', name->length) != NULL) { + return nxt_conf_vldt_error(vldt, "The environment name must not " + "contain null character."); + } - nxt_conf_get_string(value, &name); + if (nxt_memchr(name->start, '=', name->length) != NULL) { + return nxt_conf_vldt_error(vldt, "The environment name must not " + "contain '=' character."); + } + + if (nxt_conf_type(value) != NXT_CONF_STRING) { + return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must be " + "a string.", name); + } - length = name.length + 1; - length = nxt_min(length, sizeof(string)); + nxt_conf_get_string(value, &str); - nxt_cpystrn((u_char *) string, name.start, length); + if (nxt_memchr(str.start, '\0', str.length) != NULL) { + return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must " + "not contain null character.", name); + } - return validator(vldt, string); + return NXT_OK; } static nxt_int_t -nxt_conf_vldt_user(nxt_conf_validation_t *vldt, char *user) +nxt_conf_vldt_clone_namespaces(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data) { - struct passwd *pwd; + return nxt_conf_vldt_object(vldt, value, data); +} - nxt_errno = 0; - pwd = getpwnam(user); +static nxt_int_t +nxt_conf_vldt_isolation(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data) +{ + return nxt_conf_vldt_object(vldt, value, data); +} - if (pwd != NULL) { - return NXT_OK; - } - if (nxt_errno == 0) { - return nxt_conf_vldt_error(vldt, "User \"%s\" is not found.", user); - } +#if (NXT_HAVE_CLONE_NEWUSER) - return NXT_ERROR; -} +typedef struct { + nxt_int_t container; + nxt_int_t host; + nxt_int_t size; +} nxt_conf_vldt_clone_procmap_conf_t; + + +static nxt_conf_map_t nxt_conf_vldt_clone_procmap_conf_map[] = { + { + nxt_string("container"), + NXT_CONF_MAP_INT32, + offsetof(nxt_conf_vldt_clone_procmap_conf_t, container), + }, + + { + nxt_string("host"), + NXT_CONF_MAP_INT32, + offsetof(nxt_conf_vldt_clone_procmap_conf_t, host), + }, + + { + nxt_string("size"), + NXT_CONF_MAP_INT32, + offsetof(nxt_conf_vldt_clone_procmap_conf_t, size), + }, + +}; static nxt_int_t -nxt_conf_vldt_group(nxt_conf_validation_t *vldt, char *group) +nxt_conf_vldt_clone_procmap(nxt_conf_validation_t *vldt, const char *mapfile, + nxt_conf_value_t *value) { - struct group *grp; + nxt_int_t ret; + nxt_conf_vldt_clone_procmap_conf_t procmap; - nxt_errno = 0; + procmap.container = -1; + procmap.host = -1; + procmap.size = -1; - grp = getgrnam(group); + ret = nxt_conf_map_object(vldt->pool, value, + nxt_conf_vldt_clone_procmap_conf_map, + nxt_nitems(nxt_conf_vldt_clone_procmap_conf_map), + &procmap); + if (ret != NXT_OK) { + return ret; + } - if (grp != NULL) { - return NXT_OK; + if (procmap.container == -1) { + return nxt_conf_vldt_error(vldt, "The %s requires the " + "\"container\" field set.", mapfile); } - if (nxt_errno == 0) { - return nxt_conf_vldt_error(vldt, "Group \"%s\" is not found.", group); + if (procmap.host == -1) { + return nxt_conf_vldt_error(vldt, "The %s requires the " + "\"host\" field set.", mapfile); } - return NXT_ERROR; + if (procmap.size == -1) { + return nxt_conf_vldt_error(vldt, "The %s requires the " + "\"size\" field set.", mapfile); + } + + return NXT_OK; } static nxt_int_t -nxt_conf_vldt_environment(nxt_conf_validation_t *vldt, nxt_str_t *name, - nxt_conf_value_t *value) +nxt_conf_vldt_clone_uidmap(nxt_conf_validation_t *vldt, nxt_conf_value_t *value) { - nxt_str_t str; + nxt_int_t ret; - if (name->length == 0) { - return nxt_conf_vldt_error(vldt, - "The environment name must not be empty."); + if (nxt_conf_type(value) != NXT_CONF_OBJECT) { + return nxt_conf_vldt_error(vldt, "The \"uidmap\" array " + "must contain only object values."); } - if (nxt_memchr(name->start, '\0', name->length) != NULL) { - return nxt_conf_vldt_error(vldt, "The environment name must not " - "contain null character."); + ret = nxt_conf_vldt_object(vldt, value, + (void *) nxt_conf_vldt_app_procmap_members); + if (nxt_slow_path(ret != NXT_OK)) { + return ret; } - if (nxt_memchr(name->start, '=', name->length) != NULL) { - return nxt_conf_vldt_error(vldt, "The environment name must not " - "contain '=' character."); - } + return nxt_conf_vldt_clone_procmap(vldt, "uid_map", value); +} - if (nxt_conf_type(value) != NXT_CONF_STRING) { - return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must be " - "a string.", name); - } - nxt_conf_get_string(value, &str); +static nxt_int_t +nxt_conf_vldt_clone_gidmap(nxt_conf_validation_t *vldt, nxt_conf_value_t *value) +{ + nxt_int_t ret; - if (nxt_memchr(str.start, '\0', str.length) != NULL) { - return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must " - "not contain null character.", name); + if (nxt_conf_type(value) != NXT_CONF_OBJECT) { + return nxt_conf_vldt_error(vldt, "The \"gidmap\" array " + "must contain only object values."); } - return NXT_OK; + ret = nxt_conf_vldt_object(vldt, value, + (void *) nxt_conf_vldt_app_procmap_members); + if (nxt_slow_path(ret != NXT_OK)) { + return ret; + } + + return nxt_conf_vldt_clone_procmap(vldt, "gid_map", value); } +#endif + static nxt_int_t nxt_conf_vldt_argument(nxt_conf_validation_t *vldt, nxt_conf_value_t *value) diff --git a/src/nxt_main.h b/src/nxt_main.h index 23c55002..0afebb96 100644 --- a/src/nxt_main.h +++ b/src/nxt_main.h @@ -57,6 +57,7 @@ typedef uint16_t nxt_port_id_t; #include #include #include +#include #include #include #include diff --git a/src/nxt_main_process.c b/src/nxt_main_process.c index 83c6d188..44deb272 100644 --- a/src/nxt_main_process.c +++ b/src/nxt_main_process.c @@ -14,6 +14,10 @@ #include #endif +#ifdef NXT_LINUX +#include +#endif + typedef struct { nxt_socket_t socket; @@ -68,6 +72,10 @@ static void nxt_main_port_conf_store_handler(nxt_task_t *task, static void nxt_main_port_access_log_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg); +static nxt_int_t nxt_init_set_isolation(nxt_task_t *task, + nxt_process_init_t *init, nxt_conf_value_t *isolation); +static nxt_int_t nxt_init_set_ns(nxt_task_t *task, + nxt_process_init_t *init, nxt_conf_value_t *ns); const nxt_sig_event_t nxt_main_process_signals[] = { nxt_event_signal(SIGHUP, nxt_main_process_signal_handler), @@ -134,6 +142,12 @@ static nxt_conf_map_t nxt_common_app_conf[] = { NXT_CONF_MAP_PTR, offsetof(nxt_common_app_conf_t, environment), }, + + { + nxt_string("isolation"), + NXT_CONF_MAP_PTR, + offsetof(nxt_common_app_conf_t, isolation), + } }; @@ -271,12 +285,11 @@ nxt_port_main_start_worker_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) nxt_int_t ret; nxt_buf_t *b; nxt_port_t *port; + nxt_runtime_t *rt; nxt_app_type_t idx; nxt_conf_value_t *conf; nxt_common_app_conf_t app_conf; - static nxt_str_t nobody = nxt_string("nobody"); - ret = NXT_ERROR; mp = nxt_mp_create(1024, 128, 256, 32); @@ -311,7 +324,10 @@ nxt_port_main_start_worker_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) goto failed; } - app_conf.user = nobody; + rt = task->thread->runtime; + + app_conf.user.start = (u_char*)rt->user_cred.user; + app_conf.user.length = nxt_strlen(rt->user_cred.user); ret = nxt_conf_map_object(mp, conf, nxt_common_app_conf, nxt_nitems(nxt_common_app_conf), &app_conf); @@ -458,6 +474,8 @@ nxt_main_start_controller_process(nxt_task_t *task, nxt_runtime_t *rt) return NXT_ERROR; } + nxt_memzero(init, sizeof(nxt_process_init_t)); + init->start = nxt_controller_start; init->name = "controller"; init->user_cred = &rt->user_cred; @@ -552,6 +570,8 @@ nxt_main_start_discovery_process(nxt_task_t *task, nxt_runtime_t *rt) return NXT_ERROR; } + nxt_memzero(init, sizeof(nxt_process_init_t)); + init->start = nxt_discovery_start; init->name = "discovery"; init->user_cred = &rt->user_cred; @@ -576,6 +596,8 @@ nxt_main_start_router_process(nxt_task_t *task, nxt_runtime_t *rt) return NXT_ERROR; } + nxt_memzero(init, sizeof(nxt_process_init_t)); + init->start = nxt_router_start; init->name = "router"; init->user_cred = &rt->user_cred; @@ -589,7 +611,6 @@ nxt_main_start_router_process(nxt_task_t *task, nxt_runtime_t *rt) return nxt_main_create_worker_process(task, rt, init); } - static nxt_int_t nxt_main_start_worker_process(nxt_task_t *task, nxt_runtime_t *rt, nxt_common_app_conf_t *app_conf, uint32_t stream) @@ -597,41 +618,72 @@ nxt_main_start_worker_process(nxt_task_t *task, nxt_runtime_t *rt, char *user, *group; u_char *title, *last, *end; size_t size; + nxt_int_t ret; nxt_process_init_t *init; size = sizeof(nxt_process_init_t) - + sizeof(nxt_user_cred_t) - + app_conf->user.length + 1 - + app_conf->group.length + 1 - + app_conf->name.length + sizeof("\"\" application"); + + app_conf->name.length + + sizeof("\"\" application"); + + if (rt->capabilities.setid) { + size += sizeof(nxt_user_cred_t) + + app_conf->user.length + 1 + + app_conf->group.length + 1; + } init = nxt_malloc(size); if (nxt_slow_path(init == NULL)) { return NXT_ERROR; } - init->user_cred = nxt_pointer_to(init, sizeof(nxt_process_init_t)); - user = nxt_pointer_to(init->user_cred, sizeof(nxt_user_cred_t)); + nxt_memzero(init, sizeof(nxt_process_init_t)); - nxt_memcpy(user, app_conf->user.start, app_conf->user.length); - last = nxt_pointer_to(user, app_conf->user.length); - *last++ = '\0'; + if (rt->capabilities.setid) { + init->user_cred = nxt_pointer_to(init, sizeof(nxt_process_init_t)); + user = nxt_pointer_to(init->user_cred, sizeof(nxt_user_cred_t)); - init->user_cred->user = user; + nxt_memcpy(user, app_conf->user.start, app_conf->user.length); + last = nxt_pointer_to(user, app_conf->user.length); + *last++ = '\0'; - if (app_conf->group.start != NULL) { - group = (char *) last; + init->user_cred->user = user; - nxt_memcpy(group, app_conf->group.start, app_conf->group.length); - last = nxt_pointer_to(group, app_conf->group.length); - *last++ = '\0'; + if (app_conf->group.start != NULL) { + group = (char *) last; + + nxt_memcpy(group, app_conf->group.start, app_conf->group.length); + last = nxt_pointer_to(group, app_conf->group.length); + *last++ = '\0'; + + } else { + group = NULL; + } + + ret = nxt_user_cred_get(task, init->user_cred, group); + if (ret != NXT_OK) { + return NXT_ERROR; + } } else { - group = NULL; - } + if (!nxt_str_eq(&app_conf->user, (u_char *) rt->user_cred.user, + nxt_strlen(rt->user_cred.user))) + { + nxt_alert(task, "cannot set user \"%V\" for app \"%V\": " + "missing capabilities", &app_conf->user, &app_conf->name); + return NXT_ERROR; + } - if (nxt_user_cred_get(task, init->user_cred, group) != NXT_OK) { - return NXT_ERROR; + if (app_conf->group.length > 0 + && !nxt_str_eq(&app_conf->group, (u_char *) rt->group, + nxt_strlen(rt->group))) + { + nxt_alert(task, "cannot set group \"%V\" for app \"%V\": " + "missing capabilities", &app_conf->group, + &app_conf->name); + return NXT_ERROR; + } + + last = nxt_pointer_to(init, sizeof(nxt_process_init_t)); } title = last; @@ -648,6 +700,11 @@ nxt_main_start_worker_process(nxt_task_t *task, nxt_runtime_t *rt, init->stream = stream; init->restart = NULL; + ret = nxt_init_set_isolation(task, init, app_conf->isolation); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + return nxt_main_create_worker_process(task, rt, init); } @@ -1246,7 +1303,7 @@ nxt_main_port_modules_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) nxt_conf_value_t *conf, *root, *value; nxt_app_lang_module_t *lang; - static nxt_str_t root_path = nxt_string("/"); + static nxt_str_t root_path = nxt_string("/"); rt = task->thread->runtime; @@ -1438,3 +1495,105 @@ nxt_main_port_access_log_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) msg->port_msg.stream, 0, NULL); } } + + +static nxt_int_t +nxt_init_set_isolation(nxt_task_t *task, nxt_process_init_t *init, + nxt_conf_value_t *isolation) +{ + nxt_int_t ret; + nxt_conf_value_t *object; + + static nxt_str_t nsname = nxt_string("namespaces"); + static nxt_str_t uidname = nxt_string("uidmap"); + static nxt_str_t gidname = nxt_string("gidmap"); + + if (isolation == NULL) { + return NXT_OK; + } + + object = nxt_conf_get_object_member(isolation, &nsname, NULL); + if (object != NULL) { + ret = nxt_init_set_ns(task, init, object); + if (ret != NXT_OK) { + return ret; + } + } + + object = nxt_conf_get_object_member(isolation, &uidname, NULL); + if (object != NULL) { + init->isolation.clone.uidmap = object; + } + + object = nxt_conf_get_object_member(isolation, &gidname, NULL); + if (object != NULL) { + init->isolation.clone.gidmap = object; + } + + return NXT_OK; +} + + +static nxt_int_t +nxt_init_set_ns(nxt_task_t *task, nxt_process_init_t *init, nxt_conf_value_t *namespaces) +{ + uint32_t index; + nxt_str_t name; + nxt_int_t flag; + nxt_conf_value_t *value; + + index = 0; + + while ((value = nxt_conf_next_object_member(namespaces, &name, &index)) != NULL) { + flag = 0; + +#if (NXT_HAVE_CLONE_NEWUSER) + if (nxt_str_eq(&name, "credential", 10)) { + flag = CLONE_NEWUSER; + } +#endif + +#if (NXT_HAVE_CLONE_NEWPID) + if (nxt_str_eq(&name, "pid", 3)) { + flag = CLONE_NEWPID; + } +#endif + +#if (NXT_HAVE_CLONE_NEWNET) + if (nxt_str_eq(&name, "network", 7)) { + flag = CLONE_NEWNET; + } +#endif + +#if (NXT_HAVE_CLONE_NEWUTS) + if (nxt_str_eq(&name, "uname", 5)) { + flag = CLONE_NEWUTS; + } +#endif + +#if (NXT_HAVE_CLONE_NEWNS) + if (nxt_str_eq(&name, "mount", 5)) { + flag = CLONE_NEWNS; + } +#endif + +#if (NXT_HAVE_CLONE_NEWCGROUP) + if (nxt_str_eq(&name, "cgroup", 6)) { + flag = CLONE_NEWCGROUP; + } +#endif + + if (!flag) { + nxt_alert(task, "unknown namespace flag: \"%V\"", &name); + return NXT_ERROR; + } + + if (nxt_conf_get_integer(value) == 0) { + continue; /* process shares everything by default */ + } + + init->isolation.clone.flags |= flag; + } + + return NXT_OK; +} diff --git a/src/nxt_process.c b/src/nxt_process.c index c4aef21c..638765a4 100644 --- a/src/nxt_process.c +++ b/src/nxt_process.c @@ -7,10 +7,16 @@ #include #include +#if (NXT_HAVE_CLONE) +#include +#endif + +#include static void nxt_process_start(nxt_task_t *task, nxt_process_t *process); static nxt_int_t nxt_user_groups_get(nxt_task_t *task, nxt_user_cred_t *uc); - +static nxt_int_t nxt_process_worker_setup(nxt_task_t *task, + nxt_process_t *process, int parentfd); /* A cached process pid. */ nxt_pid_t nxt_pid; @@ -34,84 +40,217 @@ nxt_bool_t nxt_proc_remove_notify_matrix[NXT_PROCESS_MAX][NXT_PROCESS_MAX] = { { 0, 0, 0, 1, 0 }, }; -nxt_pid_t -nxt_process_create(nxt_task_t *task, nxt_process_t *process) -{ - nxt_pid_t pid; + +static nxt_int_t +nxt_process_worker_setup(nxt_task_t *task, nxt_process_t *process, int parentfd) { + pid_t rpid, pid; + ssize_t n; + nxt_int_t parent_status; nxt_process_t *p; nxt_runtime_t *rt; + nxt_process_init_t *init; nxt_process_type_t ptype; - rt = task->thread->runtime; + pid = getpid(); + rpid = 0; + rt = task->thread->runtime; + init = process->init; - pid = fork(); + /* Setup the worker process. */ - switch (pid) { + n = read(parentfd, &rpid, sizeof(rpid)); + if (nxt_slow_path(n == -1 || n != sizeof(rpid))) { + nxt_alert(task, "failed to read real pid"); + return NXT_ERROR; + } - case -1: - nxt_alert(task, "fork() failed while creating \"%s\" %E", - process->init->name, nxt_errno); - break; + if (nxt_slow_path(rpid == 0)) { + nxt_alert(task, "failed to get real pid from parent"); + return NXT_ERROR; + } - case 0: - /* A child. */ - nxt_pid = getpid(); + nxt_pid = rpid; + + /* Clean inherited cached thread tid. */ + task->thread->tid = 0; + + process->pid = nxt_pid; + + if (nxt_pid != pid) { + nxt_debug(task, "app \"%s\" real pid %d", init->name, nxt_pid); + nxt_debug(task, "app \"%s\" isolated pid: %d", init->name, pid); + } - /* Clean inherited cached thread tid. */ - task->thread->tid = 0; + n = read(parentfd, &parent_status, sizeof(parent_status)); + if (nxt_slow_path(n == -1 || n != sizeof(parent_status))) { + nxt_alert(task, "failed to read parent status"); + return NXT_ERROR; + } - process->pid = nxt_pid; + if (nxt_slow_path(close(parentfd) == -1)) { + nxt_alert(task, "failed to close reader pipe fd"); + return NXT_ERROR; + } - ptype = process->init->type; + if (nxt_slow_path(parent_status != NXT_OK)) { + return parent_status; + } - nxt_port_reset_next_id(); + ptype = init->type; - nxt_event_engine_thread_adopt(task->thread->engine); + nxt_port_reset_next_id(); - /* Remove not ready processes */ - nxt_runtime_process_each(rt, p) { + nxt_event_engine_thread_adopt(task->thread->engine); - if (nxt_proc_conn_matrix[ptype][nxt_process_type(p)] == 0) { - nxt_debug(task, "remove not required process %PI", p->pid); + /* Remove not ready processes. */ + nxt_runtime_process_each(rt, p) { - nxt_process_close_ports(task, p); + if (nxt_proc_conn_matrix[ptype][nxt_process_type(p)] == 0) { + nxt_debug(task, "remove not required process %PI", p->pid); - continue; - } + nxt_process_close_ports(task, p); - if (!p->ready) { - nxt_debug(task, "remove not ready process %PI", p->pid); + continue; + } - nxt_process_close_ports(task, p); + if (!p->ready) { + nxt_debug(task, "remove not ready process %PI", p->pid); - continue; - } + nxt_process_close_ports(task, p); - nxt_port_mmaps_destroy(&p->incoming, 0); - nxt_port_mmaps_destroy(&p->outgoing, 0); + continue; + } - } nxt_runtime_process_loop; + nxt_port_mmaps_destroy(&p->incoming, 0); + nxt_port_mmaps_destroy(&p->outgoing, 0); - nxt_runtime_process_add(task, process); + } nxt_runtime_process_loop; - nxt_process_start(task, process); + nxt_runtime_process_add(task, process); - process->ready = 1; + nxt_process_start(task, process); - break; + process->ready = 1; - default: - /* A parent. */ - nxt_debug(task, "fork(\"%s\"): %PI", process->init->name, pid); + return NXT_OK; +} - process->pid = pid; - nxt_runtime_process_add(task, process); +nxt_pid_t +nxt_process_create(nxt_task_t *task, nxt_process_t *process) +{ + int pipefd[2]; + nxt_int_t ret; + nxt_pid_t pid; + nxt_process_init_t *init; - break; + if (nxt_slow_path(pipe(pipefd) == -1)) { + nxt_alert(task, "failed to create process pipe for passing rpid"); + return -1; + } + + init = process->init; + +#if (NXT_HAVE_CLONE) + pid = nxt_clone(SIGCHLD|init->isolation.clone.flags); +#else + pid = fork(); +#endif + + if (nxt_slow_path(pid < 0)) { +#if (NXT_HAVE_CLONE) + nxt_alert(task, "clone() failed while creating \"%s\" %E", + init->name, nxt_errno); +#else + nxt_alert(task, "fork() failed while creating \"%s\" %E", + init->name, nxt_errno); +#endif + + return pid; + } + + if (pid == 0) { + /* Child. */ + + if (nxt_slow_path(close(pipefd[1]) == -1)) { + nxt_alert(task, "failed to close writer pipe fd"); + return NXT_ERROR; + } + + ret = nxt_process_worker_setup(task, process, pipefd[0]); + if (nxt_slow_path(ret != NXT_OK)) { + exit(1); + } + + /* + * Explicitly return 0 to notice the caller function this is the child. + * The caller must return to the event engine work queue loop. + */ + return 0; + } + + /* Parent. */ + + if (nxt_slow_path(close(pipefd[0]) != 0)) { + nxt_alert(task, "failed to close pipe: %E", nxt_errno); + } + + /* + * At this point, the child process is blocked reading the + * pipe fd to get its real pid (rpid). + * + * If anything goes wrong now, we need to terminate the child + * process by sending a NXT_ERROR in the pipe. + */ + +#if (NXT_HAVE_CLONE) + nxt_debug(task, "clone(\"%s\"): %PI", init->name, pid); +#else + nxt_debug(task, "fork(\"%s\"): %PI", init->name, pid); +#endif + + if (nxt_slow_path(write(pipefd[1], &pid, sizeof(pid)) == -1)) { + nxt_alert(task, "failed to write real pid"); + goto fail_cleanup; + } + +#if (NXT_HAVE_CLONE_NEWUSER) + if ((init->isolation.clone.flags & CLONE_NEWUSER) == CLONE_NEWUSER) { + ret = nxt_clone_proc_map(task, pid, &init->isolation.clone); + if (nxt_slow_path(ret != NXT_OK)) { + goto fail_cleanup; + } + } +#endif + + ret = NXT_OK; + + if (nxt_slow_path(write(pipefd[1], &ret, sizeof(ret)) == -1)) { + nxt_alert(task, "failed to write status"); + goto fail_cleanup; } + process->pid = pid; + + nxt_runtime_process_add(task, process); + return pid; + +fail_cleanup: + + ret = NXT_ERROR; + + if (nxt_slow_path(write(pipefd[1], &ret, sizeof(ret)) == -1)) { + nxt_alert(task, "failed to write status"); + } + + if (nxt_slow_path(close(pipefd[1]) != 0)) { + nxt_alert(task, "failed to close pipe: %E", nxt_errno); + } + + waitpid(pid, NULL, 0); + + return -1; } @@ -133,22 +272,17 @@ nxt_process_start(nxt_task_t *task, nxt_process_t *process) nxt_process_title(task, "unit: %s", init->name); thread = task->thread; + rt = thread->runtime; nxt_random_init(&thread->random); - if (init->user_cred != NULL) { - /* - * Changing user credentials requires either root privileges - * or CAP_SETUID and CAP_SETGID capabilities on Linux. - */ + if (rt->capabilities.setid && init->user_cred != NULL) { ret = nxt_user_cred_set(task, init->user_cred); if (ret != NXT_OK) { goto fail; } } - rt = thread->runtime; - rt->type = init->type; engine = thread->engine; @@ -592,15 +726,8 @@ nxt_user_cred_set(nxt_task_t *task, nxt_user_cred_t *uc) uc->user, (uint64_t) uc->uid, (uint64_t) uc->base_gid); if (setgid(uc->base_gid) != 0) { - if (nxt_errno == NXT_EPERM) { - nxt_log(task, NXT_LOG_NOTICE, "setgid(%d) failed %E, ignored", - uc->base_gid, nxt_errno); - return NXT_OK; - - } else { - nxt_alert(task, "setgid(%d) failed %E", uc->base_gid, nxt_errno); - return NXT_ERROR; - } + nxt_alert(task, "setgid(%d) failed %E", uc->base_gid, nxt_errno); + return NXT_ERROR; } if (uc->gids != NULL) { diff --git a/src/nxt_process.h b/src/nxt_process.h index c6e19f97..df9ca038 100644 --- a/src/nxt_process.h +++ b/src/nxt_process.h @@ -7,6 +7,8 @@ #ifndef _NXT_PROCESS_H_INCLUDED_ #define _NXT_PROCESS_H_INCLUDED_ +#include + typedef pid_t nxt_pid_t; typedef uid_t nxt_uid_t; @@ -21,26 +23,35 @@ typedef struct { nxt_gid_t *gids; } nxt_user_cred_t; +typedef struct { + nxt_int_t flags; + nxt_conf_value_t *uidmap; + nxt_conf_value_t *gidmap; +} nxt_process_clone_t; + typedef struct nxt_process_init_s nxt_process_init_t; typedef nxt_int_t (*nxt_process_start_t)(nxt_task_t *task, void *data); typedef nxt_int_t (*nxt_process_restart_t)(nxt_task_t *task, nxt_runtime_t *rt, nxt_process_init_t *init); - struct nxt_process_init_s { - nxt_process_start_t start; - const char *name; - nxt_user_cred_t *user_cred; + nxt_process_start_t start; + const char *name; + nxt_user_cred_t *user_cred; + + nxt_port_handlers_t *port_handlers; + const nxt_sig_event_t *signals; - nxt_port_handlers_t *port_handlers; - const nxt_sig_event_t *signals; + nxt_process_type_t type; - nxt_process_type_t type; + void *data; + uint32_t stream; - void *data; - uint32_t stream; + nxt_process_restart_t restart; - nxt_process_restart_t restart; + union { + nxt_process_clone_t clone; + } isolation; }; diff --git a/src/nxt_runtime.c b/src/nxt_runtime.c index 06478f72..de41ba4d 100644 --- a/src/nxt_runtime.c +++ b/src/nxt_runtime.c @@ -692,14 +692,26 @@ nxt_runtime_conf_init(nxt_task_t *task, nxt_runtime_t *rt) rt->state = NXT_STATE; rt->control = NXT_CONTROL_SOCK; + nxt_memzero(&rt->capabilities, sizeof(nxt_capabilities_t)); + if (nxt_runtime_conf_read_cmd(task, rt) != NXT_OK) { return NXT_ERROR; } - if (nxt_user_cred_get(task, &rt->user_cred, rt->group) != NXT_OK) { + if (nxt_capability_set(task, &rt->capabilities) != NXT_OK) { return NXT_ERROR; } + if (rt->capabilities.setid) { + if (nxt_user_cred_get(task, &rt->user_cred, rt->group) != NXT_OK) { + return NXT_ERROR; + } + + } else { + nxt_log(task, NXT_LOG_WARN, "Unit is running unprivileged, then it " + "cannot use arbitrary user and group."); + } + /* An engine's parameters. */ interface = nxt_service_get(rt->services, "engine", rt->engine); diff --git a/src/nxt_runtime.h b/src/nxt_runtime.h index 496ae478..0791f8e7 100644 --- a/src/nxt_runtime.h +++ b/src/nxt_runtime.h @@ -59,6 +59,7 @@ struct nxt_runtime_s { uint32_t engine_connections; uint32_t auxiliary_threads; nxt_user_cred_t user_cred; + nxt_capabilities_t capabilities; const char *group; const char *pid; const char *log; diff --git a/src/nxt_unit.c b/src/nxt_unit.c index 4497d09d..9ccd1fd9 100644 --- a/src/nxt_unit.c +++ b/src/nxt_unit.c @@ -333,6 +333,7 @@ nxt_unit_init(nxt_unit_init_t *init) } } + lib->pid = read_port.id.pid; ctx = &lib->main_ctx.ctx; rc = lib->callbacks.add_port(ctx, &ready_port); @@ -398,7 +399,6 @@ nxt_unit_create(nxt_unit_init_t *init) lib->processes.slot = NULL; lib->ports.slot = NULL; - lib->pid = getpid(); lib->log_fd = STDERR_FILENO; lib->online = 1; -- cgit From 08a8d1510d5f73d91112ead9e6ac075fb7d2bac0 Mon Sep 17 00:00:00 2001 From: Valentin Bartenev Date: Thu, 19 Sep 2019 02:47:09 +0300 Subject: Basic support for serving static files. --- src/nxt_conf.h | 1 + src/nxt_conf_validation.c | 129 ++++++++++ src/nxt_http.h | 29 +++ src/nxt_http_request.c | 25 +- src/nxt_http_route.c | 28 ++- src/nxt_http_static.c | 599 ++++++++++++++++++++++++++++++++++++++++++++++ src/nxt_router.c | 93 ++++++- src/nxt_router.h | 6 +- src/nxt_sprintf.c | 4 +- src/nxt_string.c | 66 +++++ src/nxt_string.h | 1 + 11 files changed, 952 insertions(+), 29 deletions(-) create mode 100644 src/nxt_http_static.c (limited to 'src') diff --git a/src/nxt_conf.h b/src/nxt_conf.h index 2435b0e2..725a6c95 100644 --- a/src/nxt_conf.h +++ b/src/nxt_conf.h @@ -71,6 +71,7 @@ typedef struct { nxt_conf_value_t *conf; nxt_mp_t *pool; nxt_str_t error; + void *ctx; } nxt_conf_validation_t; diff --git a/src/nxt_conf_validation.c b/src/nxt_conf_validation.c index 078ddd17..c934b10b 100644 --- a/src/nxt_conf_validation.c +++ b/src/nxt_conf_validation.c @@ -8,6 +8,7 @@ #include #include #include +#include typedef enum { @@ -45,6 +46,12 @@ static nxt_int_t nxt_conf_vldt_type(nxt_conf_validation_t *vldt, static nxt_int_t nxt_conf_vldt_error(nxt_conf_validation_t *vldt, const char *fmt, ...); +static nxt_int_t nxt_conf_vldt_mtypes(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value, void *data); +static nxt_int_t nxt_conf_vldt_mtypes_type(nxt_conf_validation_t *vldt, + nxt_str_t *name, nxt_conf_value_t *value); +static nxt_int_t nxt_conf_vldt_mtypes_extension(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value); static nxt_int_t nxt_conf_vldt_listener(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value); #if (NXT_TLS) @@ -130,6 +137,16 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_websocket_members[] = { }; +static nxt_conf_vldt_object_t nxt_conf_vldt_static_members[] = { + { nxt_string("mime_types"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_mtypes, + NULL }, + + NXT_CONF_VLDT_END +}; + + static nxt_conf_vldt_object_t nxt_conf_vldt_http_members[] = { { nxt_string("header_read_timeout"), NXT_CONF_VLDT_INTEGER, @@ -161,6 +178,11 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_http_members[] = { &nxt_conf_vldt_object, (void *) &nxt_conf_vldt_websocket_members }, + { nxt_string("static"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_object, + (void *) &nxt_conf_vldt_static_members }, + NXT_CONF_VLDT_END }; @@ -289,6 +311,11 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_action_members[] = { &nxt_conf_vldt_pass, NULL }, + { nxt_string("share"), + NXT_CONF_VLDT_STRING, + NULL, + NULL }, + NXT_CONF_VLDT_END }; @@ -735,6 +762,108 @@ nxt_conf_vldt_error(nxt_conf_validation_t *vldt, const char *fmt, ...) } +typedef struct { + nxt_mp_t *pool; + nxt_str_t *type; + nxt_lvlhsh_t hash; +} nxt_conf_vldt_mtypes_ctx_t; + + +static nxt_int_t +nxt_conf_vldt_mtypes(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data) +{ + nxt_int_t ret; + nxt_conf_vldt_mtypes_ctx_t ctx; + + ctx.pool = nxt_mp_create(1024, 128, 256, 32); + if (nxt_slow_path(ctx.pool == NULL)) { + return NXT_ERROR; + } + + nxt_lvlhsh_init(&ctx.hash); + + vldt->ctx = &ctx; + + ret = nxt_conf_vldt_object_iterator(vldt, value, + &nxt_conf_vldt_mtypes_type); + + vldt->ctx = NULL; + + nxt_mp_destroy(ctx.pool); + + return ret; +} + + +static nxt_int_t +nxt_conf_vldt_mtypes_type(nxt_conf_validation_t *vldt, nxt_str_t *name, + nxt_conf_value_t *value) +{ + nxt_int_t ret; + nxt_conf_vldt_mtypes_ctx_t *ctx; + + ret = nxt_conf_vldt_type(vldt, name, value, + NXT_CONF_VLDT_STRING|NXT_CONF_VLDT_ARRAY); + if (ret != NXT_OK) { + return ret; + } + + ctx = vldt->ctx; + + ctx->type = nxt_mp_get(ctx->pool, sizeof(nxt_str_t)); + if (nxt_slow_path(ctx->type == NULL)) { + return NXT_ERROR; + } + + *ctx->type = *name; + + if (nxt_conf_type(value) == NXT_CONF_ARRAY) { + return nxt_conf_vldt_array_iterator(vldt, value, + &nxt_conf_vldt_mtypes_extension); + } + + /* NXT_CONF_STRING */ + + return nxt_conf_vldt_mtypes_extension(vldt, value); +} + + +static nxt_int_t +nxt_conf_vldt_mtypes_extension(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value) +{ + nxt_str_t ext, *dup_type; + nxt_conf_vldt_mtypes_ctx_t *ctx; + + ctx = vldt->ctx; + + if (nxt_conf_type(value) != NXT_CONF_STRING) { + return nxt_conf_vldt_error(vldt, "The \"%V\" MIME type array must " + "contain only strings.", ctx->type); + } + + nxt_conf_get_string(value, &ext); + + if (ext.length == 0) { + return nxt_conf_vldt_error(vldt, "An empty file extension for " + "the \"%V\" MIME type.", ctx->type); + } + + dup_type = nxt_http_static_mtypes_hash_find(&ctx->hash, &ext); + + if (dup_type != NULL) { + return nxt_conf_vldt_error(vldt, "The \"%V\" file extension has been " + "declared for \"%V\" and \"%V\" " + "MIME types at the same time.", + &ext, dup_type, ctx->type); + } + + return nxt_http_static_mtypes_hash_add(ctx->pool, &ctx->hash, + &ext, ctx->type); +} + + static nxt_int_t nxt_conf_vldt_listener(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value) diff --git a/src/nxt_http.h b/src/nxt_http.h index 2ca3f536..560b7310 100644 --- a/src/nxt_http.h +++ b/src/nxt_http.h @@ -24,7 +24,9 @@ typedef enum { NXT_HTTP_NOT_MODIFIED = 304, NXT_HTTP_BAD_REQUEST = 400, + NXT_HTTP_FORBIDDEN = 403, NXT_HTTP_NOT_FOUND = 404, + NXT_HTTP_METHOD_NOT_ALLOWED = 405, NXT_HTTP_REQUEST_TIMEOUT = 408, NXT_HTTP_LENGTH_REQUIRED = 411, NXT_HTTP_PAYLOAD_TOO_LARGE = 413, @@ -187,6 +189,25 @@ typedef struct { } nxt_http_proto_table_t; +#define NXT_HTTP_DATE_LEN nxt_length("Wed, 31 Dec 1986 16:40:00 GMT") + +nxt_inline u_char * +nxt_http_date(u_char *buf, struct tm *tm) +{ + static const char *week[] = { "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", + "Sat" }; + + static const char *month[] = { "Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" }; + + return nxt_sprintf(buf, buf + NXT_HTTP_DATE_LEN, + "%s, %02d %s %4d %02d:%02d:%02d GMT", + week[tm->tm_wday], tm->tm_mday, + month[tm->tm_mon], tm->tm_year + 1900, + tm->tm_hour, tm->tm_min, tm->tm_sec); +} + + nxt_int_t nxt_http_init(nxt_task_t *task, nxt_runtime_t *rt); nxt_int_t nxt_h1p_init(nxt_task_t *task, nxt_runtime_t *rt); nxt_int_t nxt_http_response_hash_init(nxt_task_t *task, nxt_runtime_t *rt); @@ -225,6 +246,14 @@ nxt_http_pass_t *nxt_http_pass_application(nxt_task_t *task, void nxt_http_routes_cleanup(nxt_task_t *task, nxt_http_routes_t *routes); void nxt_http_pass_cleanup(nxt_task_t *task, nxt_http_pass_t *pass); +nxt_http_pass_t *nxt_http_static_handler(nxt_task_t *task, + nxt_http_request_t *r, nxt_http_pass_t *pass); +nxt_int_t nxt_http_static_mtypes_init(nxt_mp_t *mp, nxt_lvlhsh_t *hash); +nxt_int_t nxt_http_static_mtypes_hash_add(nxt_mp_t *mp, nxt_lvlhsh_t *hash, + nxt_str_t *extension, nxt_str_t *type); +nxt_str_t *nxt_http_static_mtypes_hash_find(nxt_lvlhsh_t *hash, + nxt_str_t *extension); + nxt_http_pass_t *nxt_http_request_application(nxt_task_t *task, nxt_http_request_t *r, nxt_http_pass_t *pass); diff --git a/src/nxt_http_request.c b/src/nxt_http_request.c index 8c38eeb2..a18a02e7 100644 --- a/src/nxt_http_request.c +++ b/src/nxt_http_request.c @@ -17,8 +17,8 @@ static void nxt_http_request_mem_buf_completion(nxt_task_t *task, void *obj, void *data); static void nxt_http_request_done(nxt_task_t *task, void *obj, void *data); -static u_char *nxt_http_date(u_char *buf, nxt_realtime_t *now, struct tm *tm, - size_t size, const char *format); +static u_char *nxt_http_date_cache_handler(u_char *buf, nxt_realtime_t *now, + struct tm *tm, size_t size, const char *format); static const nxt_http_request_state_t nxt_http_request_init_state; @@ -27,9 +27,9 @@ static const nxt_http_request_state_t nxt_http_request_body_state; nxt_time_string_t nxt_http_date_cache = { (nxt_atomic_uint_t) -1, - nxt_http_date, - "%s, %02d %s %4d %02d:%02d:%02d GMT", - nxt_length("Wed, 31 Dec 1986 16:40:00 GMT"), + nxt_http_date_cache_handler, + NULL, + NXT_HTTP_DATE_LEN, NXT_THREAD_TIME_GMT, NXT_THREAD_TIME_SEC, }; @@ -578,17 +578,8 @@ nxt_http_request_close_handler(nxt_task_t *task, void *obj, void *data) static u_char * -nxt_http_date(u_char *buf, nxt_realtime_t *now, struct tm *tm, size_t size, - const char *format) +nxt_http_date_cache_handler(u_char *buf, nxt_realtime_t *now, struct tm *tm, + size_t size, const char *format) { - static const char *week[] = { "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", - "Sat" }; - - static const char *month[] = { "Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" }; - - return nxt_sprintf(buf, buf + size, format, - week[tm->tm_wday], tm->tm_mday, - month[tm->tm_mon], tm->tm_year + 1900, - tm->tm_hour, tm->tm_min, tm->tm_sec); + return nxt_http_date(buf, tm); } diff --git a/src/nxt_http_route.c b/src/nxt_http_route.c index 0b665573..c3c11faa 100644 --- a/src/nxt_http_route.c +++ b/src/nxt_http_route.c @@ -376,15 +376,9 @@ nxt_http_route_match_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_http_route_match_conf_t mtcf; static nxt_str_t pass_path = nxt_string("/action/pass"); + static nxt_str_t share_path = nxt_string("/action/share"); static nxt_str_t match_path = nxt_string("/match"); - pass_conf = nxt_conf_get_path(cv, &pass_path); - if (nxt_slow_path(pass_conf == NULL)) { - return NULL; - } - - nxt_conf_get_string(pass_conf, &pass); - match_conf = nxt_conf_get_path(cv, &match_path); n = (match_conf != NULL) ? nxt_conf_object_members_count(match_conf) : 0; @@ -401,6 +395,19 @@ nxt_http_route_match_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, match->pass.handler = NULL; match->items = n; + pass_conf = nxt_conf_get_path(cv, &pass_path); + + if (pass_conf == NULL) { + pass_conf = nxt_conf_get_path(cv, &share_path); + if (nxt_slow_path(pass_conf == NULL)) { + return NULL; + } + + match->pass.handler = nxt_http_static_handler; + } + + nxt_conf_get_string(pass_conf, &pass); + string = nxt_str_dup(mp, &match->pass.name, &pass); if (nxt_slow_path(string == NULL)) { return NULL; @@ -870,13 +877,18 @@ static void nxt_http_route_resolve(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_http_route_t *route) { + nxt_http_pass_t *pass; nxt_http_route_match_t **match, **end; match = &route->match[0]; end = match + route->items; while (match < end) { - nxt_http_pass_resolve(task, tmcf, &(*match)->pass); + pass = &(*match)->pass; + + if (pass->handler == NULL) { + nxt_http_pass_resolve(task, tmcf, &(*match)->pass); + } match++; } diff --git a/src/nxt_http_static.c b/src/nxt_http_static.c new file mode 100644 index 00000000..44b85389 --- /dev/null +++ b/src/nxt_http_static.c @@ -0,0 +1,599 @@ + +/* + * Copyright (C) NGINX, Inc. + */ + +#include +#include + + +#define NXT_HTTP_STATIC_BUF_COUNT 2 +#define NXT_HTTP_STATIC_BUF_SIZE (128 * 1024) + + +static void nxt_http_static_extract_extension(nxt_str_t *path, + nxt_str_t *extension); +static void nxt_http_static_body_handler(nxt_task_t *task, void *obj, + void *data); +static void nxt_http_static_buf_completion(nxt_task_t *task, void *obj, + void *data); + +static nxt_int_t nxt_http_static_mtypes_hash_test(nxt_lvlhsh_query_t *lhq, + void *data); +static void *nxt_http_static_mtypes_hash_alloc(void *data, size_t size); +static void nxt_http_static_mtypes_hash_free(void *data, void *p); + + +static const nxt_http_request_state_t nxt_http_static_send_state; + + +nxt_http_pass_t * +nxt_http_static_handler(nxt_task_t *task, nxt_http_request_t *r, + nxt_http_pass_t *pass) +{ + size_t alloc, encode; + u_char *p; + struct tm tm; + nxt_buf_t *fb; + nxt_int_t ret; + nxt_str_t index, extension, *mtype; + nxt_uint_t level; + nxt_bool_t need_body; + nxt_file_t *f; + nxt_file_info_t fi; + nxt_http_field_t *field; + nxt_http_status_t status; + nxt_router_conf_t *rtcf; + nxt_work_handler_t body_handler; + + if (nxt_slow_path(!nxt_str_eq(r->method, "GET", 3))) { + + if (!nxt_str_eq(r->method, "HEAD", 4)) { + nxt_http_request_error(task, r, NXT_HTTP_METHOD_NOT_ALLOWED); + return NULL; + } + + need_body = 0; + + } else { + need_body = 1; + } + + f = nxt_mp_zget(r->mem_pool, sizeof(nxt_file_t)); + if (nxt_slow_path(f == NULL)) { + goto fail; + } + + f->fd = NXT_FILE_INVALID; + + if (r->path->start[r->path->length - 1] == '/') { + /* TODO: dynamic index setting. */ + nxt_str_set(&index, "index.html"); + nxt_str_set(&extension, ".html"); + + } else { + nxt_str_null(&index); + nxt_str_null(&extension); + } + + alloc = pass->name.length + r->path->length + index.length + 1; + + f->name = nxt_mp_nget(r->mem_pool, alloc); + if (nxt_slow_path(f->name == NULL)) { + goto fail; + } + + p = f->name; + p = nxt_cpymem(p, pass->name.start, pass->name.length); + p = nxt_cpymem(p, r->path->start, r->path->length); + p = nxt_cpymem(p, index.start, index.length); + *p = '\0'; + + ret = nxt_file_open(task, f, NXT_FILE_RDONLY, NXT_FILE_OPEN, 0); + + if (nxt_slow_path(ret != NXT_OK)) { + switch (f->error) { + + case NXT_ENOENT: + case NXT_ENOTDIR: + case NXT_ENAMETOOLONG: + level = NXT_LOG_ERR; + status = NXT_HTTP_NOT_FOUND; + break; + + case NXT_EACCES: + level = NXT_LOG_ERR; + status = NXT_HTTP_FORBIDDEN; + break; + + default: + level = NXT_LOG_ALERT; + status = NXT_HTTP_INTERNAL_SERVER_ERROR; + break; + } + + if (status != NXT_HTTP_NOT_FOUND) { + nxt_log(task, level, "open(\"%FN\") failed %E", f->name, f->error); + } + + nxt_http_request_error(task, r, status); + return NULL; + } + + ret = nxt_file_info(f, &fi); + if (nxt_slow_path(ret != NXT_OK)) { + goto fail; + } + + if (nxt_fast_path(nxt_is_file(&fi))) { + r->status = NXT_HTTP_OK; + r->resp.content_length_n = nxt_file_size(&fi); + + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "Last-Modified"); + + p = nxt_mp_nget(r->mem_pool, NXT_HTTP_DATE_LEN); + if (nxt_slow_path(p == NULL)) { + goto fail; + } + + nxt_localtime(nxt_file_mtime(&fi), &tm); + + field->value = p; + field->value_length = nxt_http_date(p, &tm) - p; + + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "ETag"); + + alloc = NXT_TIME_T_HEXLEN + NXT_OFF_T_HEXLEN + 3; + + p = nxt_mp_nget(r->mem_pool, alloc); + if (nxt_slow_path(p == NULL)) { + goto fail; + } + + field->value = p; + field->value_length = nxt_sprintf(p, p + alloc, "\"%xT-%xO\"", + nxt_file_mtime(&fi), + nxt_file_size(&fi)) + - p; + + if (extension.start == NULL) { + nxt_http_static_extract_extension(r->path, &extension); + } + + rtcf = r->conf->socket_conf->router_conf; + + mtype = nxt_http_static_mtypes_hash_find(&rtcf->mtypes_hash, + &extension); + + if (mtype != NULL) { + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "Content-Type"); + + field->value = mtype->start; + field->value_length = mtype->length; + } + + if (need_body && nxt_file_size(&fi) > 0) { + fb = nxt_mp_zget(r->mem_pool, NXT_BUF_FILE_SIZE); + if (nxt_slow_path(fb == NULL)) { + goto fail; + } + + fb->file = f; + fb->file_end = nxt_file_size(&fi); + + r->out = fb; + + body_handler = &nxt_http_static_body_handler; + + } else { + nxt_file_close(task, f); + body_handler = NULL; + } + + } else { + /* Not a file. */ + + nxt_file_close(task, f); + f = NULL; + + if (nxt_slow_path(!nxt_is_dir(&fi))) { + nxt_log(task, NXT_LOG_ERR, "\"%FN\" is not a regular file", + f->name); + nxt_http_request_error(task, r, NXT_HTTP_NOT_FOUND); + return NULL; + } + + r->status = NXT_HTTP_MOVED_PERMANENTLY; + r->resp.content_length_n = 0; + + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "Location"); + + encode = nxt_encode_uri(NULL, r->path->start, r->path->length); + alloc = r->path->length + encode * 2 + 1; + + if (r->args->length > 0) { + alloc += 1 + r->args->length; + } + + p = nxt_mp_nget(r->mem_pool, alloc); + if (nxt_slow_path(p == NULL)) { + goto fail; + } + + field->value = p; + field->value_length = alloc; + + if (encode > 0) { + p = (u_char *) nxt_encode_uri(p, r->path->start, r->path->length); + + } else { + p = nxt_cpymem(p, r->path->start, r->path->length); + } + + *p++ = '/'; + + if (r->args->length > 0) { + *p++ = '?'; + nxt_memcpy(p, r->args->start, r->args->length); + } + + body_handler = NULL; + } + + nxt_http_request_header_send(task, r, body_handler); + + r->state = &nxt_http_static_send_state; + return NULL; + +fail: + + nxt_http_request_error(task, r, NXT_HTTP_INTERNAL_SERVER_ERROR); + + if (f != NULL && f->fd != NXT_FILE_INVALID) { + nxt_file_close(task, f); + } + + return NULL; +} + + +static void +nxt_http_static_extract_extension(nxt_str_t *path, nxt_str_t *extension) +{ + u_char ch, *p, *end; + + end = path->start + path->length; + p = end; + + for ( ;; ) { + /* There's always '/' in the beginning of the request path. */ + + p--; + ch = *p; + + switch (ch) { + case '/': + p++; + /* Fall through. */ + case '.': + extension->length = end - p; + extension->start = p; + return; + } + } +} + + +static void +nxt_http_static_body_handler(nxt_task_t *task, void *obj, void *data) +{ + size_t alloc; + nxt_buf_t *fb, *b, **next, *out; + nxt_off_t rest; + nxt_int_t n; + nxt_work_queue_t *wq; + nxt_http_request_t *r; + + r = obj; + fb = r->out; + + rest = fb->file_end - fb->file_pos; + out = NULL; + next = &out; + n = 0; + + do { + alloc = nxt_min(rest, NXT_HTTP_STATIC_BUF_SIZE); + + b = nxt_buf_mem_alloc(r->mem_pool, alloc, 0); + if (nxt_slow_path(b == NULL)) { + goto fail; + } + + b->completion_handler = nxt_http_static_buf_completion; + b->parent = r; + + nxt_mp_retain(r->mem_pool); + + *next = b; + next = &b->next; + + rest -= alloc; + + } while (rest > 0 && ++n < NXT_HTTP_STATIC_BUF_COUNT); + + wq = &task->thread->engine->fast_work_queue; + + nxt_sendbuf_drain(task, wq, out); + return; + +fail: + + while (out != NULL) { + b = out; + out = b->next; + + nxt_mp_free(r->mem_pool, b); + nxt_mp_release(r->mem_pool); + } +} + + +static const nxt_http_request_state_t nxt_http_static_send_state + nxt_aligned(64) = +{ + .error_handler = nxt_http_request_error_handler, +}; + + +static void +nxt_http_static_buf_completion(nxt_task_t *task, void *obj, void *data) +{ + ssize_t n, size; + nxt_buf_t *b, *fb; + nxt_off_t rest; + nxt_http_request_t *r; + + b = obj; + r = data; + fb = r->out; + + if (nxt_slow_path(fb == NULL || r->error)) { + goto clean; + } + + rest = fb->file_end - fb->file_pos; + size = nxt_buf_mem_size(&b->mem); + + size = nxt_min(rest, (nxt_off_t) size); + + n = nxt_file_read(fb->file, b->mem.start, size, fb->file_pos); + + if (n != size) { + if (n >= 0) { + nxt_log(task, NXT_LOG_ERR, "file \"%FN\" has changed " + "while sending response to a client", fb->file->name); + } + + nxt_http_request_error_handler(task, r, r->proto.any); + goto clean; + } + + if (n == rest) { + nxt_file_close(task, fb->file); + r->out = NULL; + + b->next = nxt_http_buf_last(r); + + } else { + fb->file_pos += n; + b->next = NULL; + } + + b->mem.pos = b->mem.start; + b->mem.free = b->mem.pos + n; + + nxt_http_request_send(task, r, b); + return; + +clean: + + nxt_mp_free(r->mem_pool, b); + nxt_mp_release(r->mem_pool); + + if (fb != NULL) { + nxt_file_close(task, fb->file); + r->out = NULL; + } +} + + +nxt_int_t +nxt_http_static_mtypes_init(nxt_mp_t *mp, nxt_lvlhsh_t *hash) +{ + nxt_str_t *type, extension; + nxt_int_t ret; + nxt_uint_t i; + + static const struct { + nxt_str_t type; + const char *extension; + } default_types[] = { + + { nxt_string("text/html"), ".html" }, + { nxt_string("text/html"), ".htm" }, + { nxt_string("text/css"), ".css" }, + + { nxt_string("image/svg+xml"), ".svg" }, + { nxt_string("image/svg+xml"), ".svg" }, + { nxt_string("image/webp"), ".webp" }, + { nxt_string("image/png"), ".png" }, + { nxt_string("image/jpeg"), ".jpeg" }, + { nxt_string("image/jpeg"), ".jpg" }, + { nxt_string("image/gif"), ".gif" }, + { nxt_string("image/x-icon"), ".ico" }, + + { nxt_string("font/woff"), ".woff" }, + { nxt_string("font/woff2"), ".woff2" }, + { nxt_string("font/otf"), ".otf" }, + { nxt_string("font/ttf"), ".ttf" }, + + { nxt_string("text/plain"), ".txt" }, + { nxt_string("text/markdown"), ".md" }, + { nxt_string("text/x-rst"), ".rst" }, + + { nxt_string("application/javascript"), ".js" }, + { nxt_string("application/json"), ".json" }, + { nxt_string("application/xml"), ".xml" }, + { nxt_string("application/rss+xml"), ".rss" }, + { nxt_string("application/atom+xml"), ".atom" }, + { nxt_string("application/pdf"), ".pdf" }, + + { nxt_string("application/zip"), ".zip" }, + + { nxt_string("audio/mpeg"), ".mp3" }, + { nxt_string("audio/ogg"), ".ogg" }, + { nxt_string("audio/midi"), ".midi" }, + { nxt_string("audio/midi"), ".mid" }, + { nxt_string("audio/flac"), ".flac" }, + { nxt_string("audio/aac"), ".aac" }, + { nxt_string("audio/wav"), ".wav" }, + + { nxt_string("video/mpeg"), ".mpeg" }, + { nxt_string("video/mpeg"), ".mpg" }, + { nxt_string("video/mp4"), ".mp4" }, + { nxt_string("video/webm"), ".webm" }, + { nxt_string("video/x-msvideo"), ".avi" }, + + { nxt_string("application/octet-stream"), ".exe" }, + { nxt_string("application/octet-stream"), ".bin" }, + { nxt_string("application/octet-stream"), ".dll" }, + { nxt_string("application/octet-stream"), ".iso" }, + { nxt_string("application/octet-stream"), ".img" }, + { nxt_string("application/octet-stream"), ".msi" }, + + { nxt_string("application/octet-stream"), ".deb" }, + { nxt_string("application/octet-stream"), ".rpm" }, + }; + + for (i = 0; i < nxt_nitems(default_types); i++) { + type = (nxt_str_t *) &default_types[i].type; + + extension.start = (u_char *) default_types[i].extension; + extension.length = nxt_strlen(extension.start); + + ret = nxt_http_static_mtypes_hash_add(mp, hash, &extension, type); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + } + + return NXT_OK; +} + + +static const nxt_lvlhsh_proto_t nxt_http_static_mtypes_hash_proto + nxt_aligned(64) = +{ + NXT_LVLHSH_DEFAULT, + nxt_http_static_mtypes_hash_test, + nxt_http_static_mtypes_hash_alloc, + nxt_http_static_mtypes_hash_free, +}; + + +typedef struct { + nxt_str_t extension; + nxt_str_t *type; +} nxt_http_static_mtype_t; + + +nxt_int_t +nxt_http_static_mtypes_hash_add(nxt_mp_t *mp, nxt_lvlhsh_t *hash, + nxt_str_t *extension, nxt_str_t *type) +{ + nxt_lvlhsh_query_t lhq; + nxt_http_static_mtype_t *mtype; + + mtype = nxt_mp_get(mp, sizeof(nxt_http_static_mtype_t)); + if (nxt_slow_path(mtype == NULL)) { + return NXT_ERROR; + } + + mtype->extension = *extension; + mtype->type = type; + + lhq.key = *extension; + lhq.key_hash = nxt_djb_hash_lowcase(lhq.key.start, lhq.key.length); + lhq.replace = 1; + lhq.value = mtype; + lhq.proto = &nxt_http_static_mtypes_hash_proto; + lhq.pool = mp; + + return nxt_lvlhsh_insert(hash, &lhq); +} + + +nxt_str_t * +nxt_http_static_mtypes_hash_find(nxt_lvlhsh_t *hash, nxt_str_t *extension) +{ + nxt_lvlhsh_query_t lhq; + nxt_http_static_mtype_t *mtype; + + lhq.key = *extension; + lhq.key_hash = nxt_djb_hash_lowcase(lhq.key.start, lhq.key.length); + lhq.proto = &nxt_http_static_mtypes_hash_proto; + + if (nxt_lvlhsh_find(hash, &lhq) == NXT_OK) { + mtype = lhq.value; + return mtype->type; + } + + return NULL; +} + + +static nxt_int_t +nxt_http_static_mtypes_hash_test(nxt_lvlhsh_query_t *lhq, void *data) +{ + nxt_http_static_mtype_t *mtype; + + mtype = data; + + return nxt_strcasestr_eq(&lhq->key, &mtype->extension) ? NXT_OK + : NXT_DECLINED; +} + + +static void * +nxt_http_static_mtypes_hash_alloc(void *data, size_t size) +{ + return nxt_mp_align(data, size, size); +} + + +static void +nxt_http_static_mtypes_hash_free(void *data, void *p) +{ + nxt_mp_free(data, p); +} diff --git a/src/nxt_router.c b/src/nxt_router.c index 89ea5fe6..28781600 100644 --- a/src/nxt_router.c +++ b/src/nxt_router.c @@ -122,6 +122,8 @@ static void nxt_router_conf_send(nxt_task_t *task, static nxt_int_t nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, u_char *start, u_char *end); +static nxt_int_t nxt_router_conf_process_static(nxt_task_t *task, + nxt_router_conf_t *rtcf, nxt_conf_value_t *conf); static nxt_app_t *nxt_router_app_find(nxt_queue_t *queue, nxt_str_t *name); static void nxt_router_listen_socket_rpc_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_socket_conf_t *skcf); @@ -1399,7 +1401,7 @@ nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_conf_value_t *conf, *http, *value, *websocket; nxt_conf_value_t *applications, *application; nxt_conf_value_t *listeners, *listener; - nxt_conf_value_t *routes_conf; + nxt_conf_value_t *routes_conf, *static_conf; nxt_socket_conf_t *skcf; nxt_http_routes_t *routes; nxt_event_engine_t *engine; @@ -1419,6 +1421,7 @@ nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, #if (NXT_TLS) static nxt_str_t certificate_path = nxt_string("/tls/certificate"); #endif + static nxt_str_t static_path = nxt_string("/settings/http/static"); static nxt_str_t websocket_path = nxt_string("/settings/http/websocket"); conf = nxt_conf_json_parse(tmcf->mem_pool, start, end, NULL); @@ -1440,6 +1443,13 @@ nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, tmcf->router_conf->threads = nxt_ncpu; } + static_conf = nxt_conf_get_path(conf, &static_path); + + ret = nxt_router_conf_process_static(task, tmcf->router_conf, static_conf); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + router = tmcf->router_conf->router; applications = nxt_conf_get_path(conf, &applications_path); @@ -1788,6 +1798,87 @@ fail: } +static nxt_int_t +nxt_router_conf_process_static(nxt_task_t *task, nxt_router_conf_t *rtcf, + nxt_conf_value_t *conf) +{ + uint32_t next, i; + nxt_mp_t *mp; + nxt_str_t *type, extension, str; + nxt_int_t ret; + nxt_uint_t exts; + nxt_conf_value_t *mtypes_conf, *ext_conf, *value; + + static nxt_str_t mtypes_path = nxt_string("/mime_types"); + + mp = rtcf->mem_pool; + + ret = nxt_http_static_mtypes_init(mp, &rtcf->mtypes_hash); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + if (conf == NULL) { + return NXT_OK; + } + + mtypes_conf = nxt_conf_get_path(conf, &mtypes_path); + + if (mtypes_conf != NULL) { + next = 0; + + for ( ;; ) { + ext_conf = nxt_conf_next_object_member(mtypes_conf, &str, &next); + + if (ext_conf == NULL) { + break; + } + + type = nxt_str_dup(mp, NULL, &str); + if (nxt_slow_path(type == NULL)) { + return NXT_ERROR; + } + + if (nxt_conf_type(ext_conf) == NXT_CONF_STRING) { + nxt_conf_get_string(ext_conf, &str); + + if (nxt_slow_path(nxt_str_dup(mp, &extension, &str) == NULL)) { + return NXT_ERROR; + } + + ret = nxt_http_static_mtypes_hash_add(mp, &rtcf->mtypes_hash, + &extension, type); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + continue; + } + + exts = nxt_conf_array_elements_count(ext_conf); + + for (i = 0; i < exts; i++) { + value = nxt_conf_get_array_element(ext_conf, i); + + nxt_conf_get_string(value, &str); + + if (nxt_slow_path(nxt_str_dup(mp, &extension, &str) == NULL)) { + return NXT_ERROR; + } + + ret = nxt_http_static_mtypes_hash_add(mp, &rtcf->mtypes_hash, + &extension, type); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + } + } + } + + return NXT_OK; +} + + static nxt_app_t * nxt_router_app_find(nxt_queue_t *queue, nxt_str_t *name) { diff --git a/src/nxt_router.h b/src/nxt_router.h index b55a4de3..ec18ff48 100644 --- a/src/nxt_router.h +++ b/src/nxt_router.h @@ -38,9 +38,13 @@ typedef struct { typedef struct { uint32_t count; uint32_t threads; + + nxt_mp_t *mem_pool; + nxt_router_t *router; nxt_http_routes_t *routes; - nxt_mp_t *mem_pool; + + nxt_lvlhsh_t mtypes_hash; nxt_router_access_log_t *access_log; } nxt_router_conf_t; diff --git a/src/nxt_sprintf.c b/src/nxt_sprintf.c index 240f47ef..9557b327 100644 --- a/src/nxt_sprintf.c +++ b/src/nxt_sprintf.c @@ -12,8 +12,8 @@ /* * Supported formats: * - * %[0][width][x][X]O nxt_off_t - * %[0][width]T nxt_time_t + * %[0][width][x|X]O nxt_off_t + * %[0][width][x|X]T nxt_time_t * %[0][width][u][x|X]z ssize_t/size_t * %[0][width][u][x|X]d int/u_int * %[0][width][u][x|X]l long diff --git a/src/nxt_string.c b/src/nxt_string.c index 4d3b3954..b89e9555 100644 --- a/src/nxt_string.c +++ b/src/nxt_string.c @@ -507,3 +507,69 @@ nxt_decode_uri(u_char *dst, u_char *src, size_t length) return dst; } + + +uintptr_t +nxt_encode_uri(u_char *dst, u_char *src, size_t length) +{ + u_char *end; + nxt_uint_t n; + + static const u_char hex[16] = "0123456789ABCDEF"; + + /* " ", "#", "%", "?", %00-%1F, %7F-%FF */ + + static const uint32_t escape[] = { + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + + /* ?>=< ;:98 7654 3210 /.-, +*)( '&%$ #"! */ + 0x80000029, /* 1000 0000 0000 0000 0000 0000 0010 1001 */ + + /* _^]\ [ZYX WVUT SRQP ONML KJIH GFED CBA@ */ + 0x00000000, /* 0000 0000 0000 0000 0000 0000 0000 0000 */ + + /* ~}| {zyx wvut srqp onml kjih gfed cba` */ + 0x80000000, /* 1000 0000 0000 0000 0000 0000 0000 0000 */ + + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + 0xffffffff /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + }; + + end = src + length; + + if (dst == NULL) { + + /* Find the number of the characters to be escaped. */ + + n = 0; + + while (src < end) { + + if (escape[*src >> 5] & (1U << (*src & 0x1f))) { + n++; + } + + src++; + } + + return (uintptr_t) n; + } + + while (src < end) { + + if (escape[*src >> 5] & (1U << (*src & 0x1f))) { + *dst++ = '%'; + *dst++ = hex[*src >> 4]; + *dst++ = hex[*src & 0xf]; + + } else { + *dst++ = *src; + } + + src++; + } + + return (uintptr_t) dst; +} diff --git a/src/nxt_string.h b/src/nxt_string.h index 56d7316b..8d7b3b73 100644 --- a/src/nxt_string.h +++ b/src/nxt_string.h @@ -169,6 +169,7 @@ NXT_EXPORT nxt_bool_t nxt_strvers_match(u_char *version, u_char *prefix, size_t length); NXT_EXPORT u_char *nxt_decode_uri(u_char *dst, u_char *src, size_t length); +NXT_EXPORT uintptr_t nxt_encode_uri(u_char *dst, u_char *src, size_t length); #endif /* _NXT_STRING_H_INCLUDED_ */ -- cgit