📄 pmtp.c
字号:
#include <stdio.h>#include <stdlib.h>#include <string.h>#include <netinet/in.h>#include <unistd.h>#include <sys/stat.h>#include <sys/types.h>#include <fcntl.h>#ifdef DMALLOC#include "dmalloc.h"#endif#include <rs.h>#include <openssl/sha.h>#include "pmtp.h"#include "tigertree/tigertree.h"uint64_t get_tigertree (unsigned char *tt, unsigned char *filename, uint64_t offset, uint64_t length) { TT_CONTEXT ctx; unsigned char *buff; int fd; int i = 0, bytes_read = TIGERTREE_CHUNKSIZE, total_bytes = 0; if ((buff = (unsigned char *) malloc (TIGERTREE_CHUNKSIZE * sizeof (unsigned char))) == NULL) { fprintf (stderr, "Couldn't malloc %d bytes for tigertree chunk\n", TIGERTREE_CHUNKSIZE * sizeof (unsigned char)); exit (1); } if (offset == length) length = get_filesize (filename) - offset; tt_init (&ctx); if ((fd = open (filename, O_RDONLY)) < 0) { fprintf (stderr, "Error opening %s for reading\n", filename); exit (1); } if (lseek (fd, offset, SEEK_SET) == -1) { perror ("lseek"); exit (1); } while (i < length && bytes_read == TIGERTREE_CHUNKSIZE) { bytes_read = read (fd, buff, TIGERTREE_CHUNKSIZE); total_bytes += bytes_read; tt_update(&ctx,buff, bytes_read); i += TIGERTREE_CHUNKSIZE; } tt_digest (&ctx, tt); free(buff); close (fd); return total_bytes;}off_t get_filesize (const unsigned char *filename) { struct stat filestat; if (stat (filename, &filestat) != 0) { fprintf (stderr, "Couldn't stat %s\n", filename); exit (1); } return filestat.st_size;}unsigned char * get_printable_hex (unsigned char *dest, unsigned char *string, uint16_t length) { int count; for (count = 0; count < length; count++) sprintf(&dest[count * 2], "%X%X", (string[count] & 0xf0) >> 4, string[count] & 0x0f); string[length * 2] = '\0'; return dest;} unsigned char * get_printable_sha (unsigned char *dest, unsigned char *string, uint16_t length) { unsigned char sha_val[20]; SHA1(string, length, sha_val); return get_printable_hex (dest, sha_val, 20);}uint32_t init_segment (struct pmtp_segment *segment, void *codec) { uint32_t old_segment = segment->segmentid; segment->segmentid = 0; segment->packets_received = 0; segment->duplicate_packets = 0; segment->packets = NULL; segment->payload_size = 0; segment->codec = codec; memset (segment->received_packets, 0, sizeof (segment->received_packets)); return old_segment;}uint32_t free_segment (struct pmtp_segment *segment) { if (segment->packets != NULL) free (segment->packets); return init_segment (segment, NULL);} uint8_t * encode_segment (struct pmtp_segment *segment, uint8_t nroots) { uint16_t b, p; uint8_t *temp, block[NN]; /* TODO: optimize so that a temp buffer is not used. */ if ((temp = (uint8_t *) malloc (NN * segment->payload_size)) == NULL) { fprintf (stderr, "Couldn't malloc %d bytes for temp.\n", NN * segment->payload_size); exit (1); } memcpy (temp, segment->packets, (NN - nroots) * segment->payload_size); for (b = 0; b < segment->payload_size; b++) { memcpy (block, temp + b * (NN-nroots), NN-nroots); encode_rs_char (segment->codec, block, &block[NN-nroots]); for (p = 0; p < NN; p++) *(segment->packets + segment->payload_size * p + b) = block[p]; } free (temp); return segment->packets;}uint32_t decode_segment (struct pmtp_segment *segment, int fd, uint16_t nroots, uint32_t bytes_left) { unsigned char block[NN]; int *errlocs; int erasures = 0, erasure_pos; uint32_t p, b, e, bytes = 0; unsigned char sha_printable[60]; if ((errlocs = (int *) malloc (nroots * sizeof (int))) == NULL) { fprintf (stderr, "Couldn't malloc %d bytes for errlocs\n", nroots * sizeof (uint16_t)); exit (1); } for (p = 0; p < NN; p++) { if (segment->received_packets[p] == 0) { if (DEBUG & DEBUG_PACKETS_MISSED) fprintf (stderr, "Missed packet %d from segment %d\n", p, segment->segmentid); if (erasures < nroots) { errlocs[erasures++] = p; } } } if (DEBUG & DEBUG_SEGMENTS) { fprintf (stderr, "Processing segment %d with %d bytes to go and %d erasures", segment->segmentid, bytes_left, erasures); if (erasures > 0) { fprintf (stderr, " at "); for (e = 0; e < erasures; e++) { fprintf (stderr, "%d ", errlocs[e]); } } fprintf (stderr, "\n"); } if (erasures == nroots) { fprintf (stderr, "Too many erasures (%d) in segment %d.\n", erasures, segment->segmentid); exit (1); } if (DEBUG & DEBUG_SEGMENTS_SHA) fprintf (stderr, "Segment %d, Packets %d, Payload %d, SHA: 0x%s\n", segment->segmentid, segment->packets_received, segment->payload_size, get_printable_sha (sha_printable, segment->packets, NN * segment->payload_size)); for (b = 0; b < segment->payload_size; b++) { if (DEBUG & DEBUG_CODING) fprintf (stderr, "Decoding block %d with %d erasures\n", b, erasures); erasure_pos = 0; for (p = 0; p < NN; p++) { if (erasures > 0 && errlocs[erasure_pos] == p) { erasure_pos++; block[p] = 0; } else { block[p] = *(segment->packets + p * segment->payload_size + b); } if (DEBUG & DEBUG_CODING_DUMP) fprintf (stderr, "%X ", block[p]); } if (DEBUG & DEBUG_CODING_DUMP) fprintf (stderr, "\n"); /* The decode function writes the erasure positions it found to * the errlocs array, which we need later. So we use a temp copy */ if (erasures >= 0 && errlocs[0] < (NN - nroots)) { if ((erasures = decode_rs_char (segment->codec, block, errlocs, erasures)) == -1) { fprintf (stderr, "Decode failed! Erasures: %d, %d!\n", erasures, b); exit (1); } } bytes = bytes_left > (NN - nroots) ? (NN - nroots) : bytes_left; write (fd, block, bytes); bytes_left -= bytes; } free (errlocs); return bytes_left;}int32_t buffer_segment(struct pmtp_upload *upload,struct pmtp_segment *segment){ uint16_t p; uint8_t erasures = 0; uint8_t *packet; unsigned char sha_printable[60]; if ((packet = (uint8_t *) malloc (PMTP_HEADER_SIZE + segment->payload_size)) == NULL) { fprintf (stderr, "Couldn't malloc %d bytes for packet\n", PMTP_HEADER_SIZE + segment->payload_size); exit (1); } for (p = 0; p < NN; p++) { encode_header_16 (&packet[PMTP_HEADER_PACKETID], p & 0x0fff); encode_header_16 (&packet[PMTP_HEADER_SESSIONID], upload->session.sessionid); encode_header_32 (&packet[PMTP_HEADER_SEGMENTID], segment->segmentid); memcpy (packet + PMTP_HEADER_SIZE, segment->packets + p * segment->payload_size, segment->payload_size); if (DEBUG & DEBUG_PACKETS_SHA) fprintf(stderr, "Packet %d, segment %d, SHA 0x%s\n", p, segment->segmentid, get_printable_sha (sha_printable, packet, PMTP_HEADER_SIZE + segment->payload_size)); if (DEBUG && ((random() & 255) < PMTP_SIMULATE_ERRORS)) { erasures++; if (DEBUG & DEBUG_PACKETS_MISSED) fprintf (stderr, "On purpose losing packet: %d\n", p); } else { if (DEBUG & DEBUG_PACKETS_SENT) fprintf (stderr, "sending packet %d, segment %d, payload size %d\n", p, segment->segmentid, segment->payload_size); if (buffer_packet (&(upload->buffer), packet, PMTP_HEADER_SIZE + segment->payload_size) == -1) { fprintf (stderr, "buffer_packet"); exit (1); } } } if (erasures > NROOTS) { fprintf (stderr, "Exiting because there were %d packets lost\n", erasures); exit (1); } free (packet); return NN * segment->payload_size;}uint8_t decode_header_8 (unsigned char* header) { return (uint8_t) *header;}uint16_t decode_header_16 (unsigned char* header) { unsigned char temp[2]; uint16_t value; if (header == NULL || sizeof header < 2) { fprintf (stderr, "Invalid header in decode_header_16"); exit (1); } memcpy(temp, header, 2); value = *((uint16_t *) temp); return htons(value);}uint32_t decode_header_32 (unsigned char* header) { unsigned char temp[4]; uint32_t value; if (header == NULL || sizeof header < 4) { fprintf (stderr, "Invalid header in decode_header_32"); exit (1); } memcpy(temp, header, 4); value = *((uint32_t *) temp); return htonl(value);}uint8_t encode_header_8 (unsigned char* header, uint8_t value) { if (header == NULL) { fprintf (stderr, "Invalid header in encode_header_8\n"); exit (1); } return *(header) = (unsigned char) value;}uint16_t encode_header_16 (unsigned char* header, uint16_t value) { if (header == NULL || sizeof header < 2) { fprintf (stderr, "Invalid header in encode_header_16\n"); exit (1); } return *((uint16_t *) header) = ntohs(value);}uint32_t encode_header_32 (unsigned char* header, uint32_t value) { if (header == NULL || sizeof header < 4) { fprintf (stderr, "Invalid header in decode_header_32\n"); exit (1); } return *((uint32_t *) header) = ntohl(value);}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -