blob: 0c6238076810c7729c61434be157d9f3ffe39a67 [file] [log] [blame]
#region Copyright notice and license
// Copyright 2015-2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using CommandLine;
using CommandLine.Text;
using Grpc.Core;
using Grpc.Core.Logging;
using Grpc.Core.Utils;
using Grpc.Testing;
namespace Grpc.IntegrationTesting
{
/// <summary>
/// Stress test client: repeatedly runs a weighted random mix of interop test cases against one or
/// more servers from many concurrent workers, and exposes the achieved QPS through a metrics service.
/// </summary>
public class StressTestClient
{
    static readonly ILogger Logger = GrpcEnvironment.Logger.ForType<StressTestClient>();

    const double SecondsToNanos = 1e9;

    /// <summary>Command-line options accepted by the stress test client.</summary>
    private class ClientOptions
    {
        // Comma-separated list of server addresses.
        [Option("server_addresses", Default = "localhost:8080")]
        public string ServerAddresses { get; set; }

        // Comma-separated list of "test_case:weight" pairs.
        [Option("test_cases", Default = "large_unary:100")]
        public string TestCases { get; set; }

        // Test duration in seconds; a negative value means the workers never get a stop signal.
        [Option("test_duration_secs", Default = -1)]
        public int TestDurationSecs { get; set; }

        [Option("num_channels_per_server", Default = 1)]
        public int NumChannelsPerServer { get; set; }

        [Option("num_stubs_per_channel", Default = 1)]
        public int NumStubsPerChannel { get; set; }

        // Port on which the metrics (QPS gauge) server listens.
        [Option("metrics_port", Default = 8081)]
        public int MetricsPort { get; set; }
    }

    readonly ClientOptions options;
    readonly List<string> serverAddresses;
    readonly Dictionary<string, int> weightedTestCases;
    readonly WeightedRandomGenerator testCaseGenerator;

    // Cancelled once test_duration_secs has elapsed; each worker checks it before starting the next test case.
    readonly CancellationTokenSource finishedTokenSource = new CancellationTokenSource();

    // Latency of every completed test case, in nanoseconds; shared by all workers and read by the metrics service.
    readonly Histogram histogram = new Histogram(0.01, 60 * SecondsToNanos);

    private StressTestClient(ClientOptions options, List<string> serverAddresses, Dictionary<string, int> weightedTestCases)
    {
        this.options = options;
        this.serverAddresses = serverAddresses;
        this.weightedTestCases = weightedTestCases;
        this.testCaseGenerator = new WeightedRandomGenerator(this.weightedTestCases);
    }

    /// <summary>
    /// Entry point: parses and validates <paramref name="args"/>, then runs the stress test to completion.
    /// The process exits with code 1 if the arguments cannot be parsed.
    /// </summary>
    public static void Run(string[] args)
    {
        GrpcEnvironment.SetLogger(new ConsoleLogger());
        Parser.Default.ParseArguments<ClientOptions>(args)
            .WithNotParsed((x) => Environment.Exit(1))
            .WithParsed(options =>
            {
                GrpcPreconditions.CheckArgument(options.NumChannelsPerServer > 0);
                GrpcPreconditions.CheckArgument(options.NumStubsPerChannel > 0);

                // string.Split never returns an empty array, so drop blank entries; otherwise the
                // "at least one" check below could never fail. Trimming tolerates "a:1, b:2" style lists.
                var serverAddresses = options.ServerAddresses.Split(',')
                    .Select(address => address.Trim())
                    .Where(address => address.Length > 0)
                    .ToList();
                GrpcPreconditions.CheckArgument(serverAddresses.Count > 0, "You need to provide at least one server address");

                var testCases = ParseWeightedTestCases(options.TestCases);
                GrpcPreconditions.CheckArgument(testCases.Count > 0, "You need to provide at least one test case");

                var stressTestClient = new StressTestClient(options, serverAddresses, testCases);
                stressTestClient.Run().Wait();
            });
    }

    /// <summary>
    /// Starts the metrics server, opens NumChannelsPerServer channels to every server address and runs
    /// NumStubsPerChannel workers per channel until they all finish. Channels and the metrics server are
    /// shut down afterwards, also when a worker (or the setup itself) fails.
    /// </summary>
    async Task Run()
    {
        var metricsServer = new Server()
        {
            Services = { MetricsService.BindService(new MetricsServiceImpl(histogram)) },
            Ports = { { "[::]", options.MetricsPort, ServerCredentials.Insecure } }
        };
        metricsServer.Start();

        var tasks = new List<Task>();
        var channels = new List<Channel>();
        try
        {
            if (options.TestDurationSecs >= 0)
            {
                finishedTokenSource.CancelAfter(TimeSpan.FromSeconds(options.TestDurationSecs));
            }

            foreach (var serverAddress in serverAddresses)
            {
                for (int i = 0; i < options.NumChannelsPerServer; i++)
                {
                    var channel = new Channel(serverAddress, ChannelCredentials.Insecure);
                    channels.Add(channel);
                    for (int j = 0; j < options.NumStubsPerChannel; j++)
                    {
                        var client = new TestService.TestServiceClient(channel);
                        // Each worker gets a LongRunning task and blocks on its async body for the whole test.
                        var task = Task.Factory.StartNew(() => RunBodyAsync(client).GetAwaiter().GetResult(),
                            TaskCreationOptions.LongRunning);
                        tasks.Add(task);
                    }
                }
            }
            await Task.WhenAll(tasks);
        }
        finally
        {
            // Workers can only still be running here if setup failed part-way; tell them to stop.
            finishedTokenSource.Cancel();
            await Task.WhenAll(channels.Select(channel => channel.ShutdownAsync()));
            await metricsServer.ShutdownAsync();
        }
    }

    /// <summary>
    /// Worker loop: runs randomly chosen test cases back to back until <see cref="finishedTokenSource"/>
    /// is cancelled, recording each test case's latency (in nanoseconds) in the shared histogram.
    /// An exception thrown by a test case is not caught here and ends this worker.
    /// </summary>
    async Task RunBodyAsync(TestService.TestServiceClient client)
    {
        Logger.Info("Starting stress test client thread.");
        while (!finishedTokenSource.Token.IsCancellationRequested)
        {
            var testCase = testCaseGenerator.GetNext();
            var stopwatch = Stopwatch.StartNew();
            await RunTestCaseAsync(client, testCase);
            stopwatch.Stop();
            histogram.AddObservation(stopwatch.Elapsed.TotalSeconds * SecondsToNanos);
        }
        Logger.Info("Stress test client thread finished.");
    }

    /// <summary>
    /// Runs a single interop test case identified by <paramref name="testCase"/>.
    /// Throws <see cref="ArgumentException"/> for names it does not recognize.
    /// </summary>
    async Task RunTestCaseAsync(TestService.TestServiceClient client, string testCase)
    {
        switch (testCase)
        {
            // The two unary cases are invoked without await.
            case "empty_unary":
                InteropClient.RunEmptyUnary(client);
                break;
            case "large_unary":
                InteropClient.RunLargeUnary(client);
                break;
            case "client_streaming":
                await InteropClient.RunClientStreamingAsync(client);
                break;
            case "server_streaming":
                await InteropClient.RunServerStreamingAsync(client);
                break;
            case "ping_pong":
                await InteropClient.RunPingPongAsync(client);
                break;
            case "empty_stream":
                await InteropClient.RunEmptyStreamAsync(client);
                break;
            case "cancel_after_begin":
                await InteropClient.RunCancelAfterBeginAsync(client);
                break;
            case "cancel_after_first_response":
                await InteropClient.RunCancelAfterFirstResponseAsync(client);
                break;
            case "timeout_on_sleeping_server":
                await InteropClient.RunTimeoutOnSleepingServerAsync(client);
                break;
            case "custom_metadata":
                await InteropClient.RunCustomMetadataAsync(client);
                break;
            case "status_code_and_message":
                await InteropClient.RunStatusCodeAndMessageAsync(client);
                break;
            default:
                throw new ArgumentException("Unsupported test case " + testCase);
        }
    }

    /// <summary>
    /// Parses a comma-separated list of <c>test_case:weight</c> pairs, e.g. <c>"large_unary:100,empty_unary:50"</c>.
    /// Names are trimmed; weights are parsed with the invariant culture and must be non-negative.
    /// Malformed entries and duplicate names are rejected via <c>GrpcPreconditions.CheckArgument</c>.
    /// </summary>
    static Dictionary<string, int> ParseWeightedTestCases(string weightedTestCases)
    {
        var result = new Dictionary<string, int>();
        foreach (var weightedTestCase in weightedTestCases.Split(','))
        {
            var parts = weightedTestCase.Split(new char[] {':'}, 2);
            GrpcPreconditions.CheckArgument(parts.Length == 2, "Malformed test_cases option.");

            var testCase = parts[0].Trim();
            int weight;
            // NumberStyles.Integer already allows surrounding whitespace and a leading sign.
            bool weightParsed = int.TryParse(parts[1], NumberStyles.Integer, CultureInfo.InvariantCulture, out weight);
            GrpcPreconditions.CheckArgument(testCase.Length > 0 && weightParsed && weight >= 0, "Malformed test_cases option.");
            GrpcPreconditions.CheckArgument(!result.ContainsKey(testCase), "Duplicate test case in test_cases option: " + testCase);

            result.Add(testCase, weight);
        }
        return result;
    }

    /// <summary>
    /// Picks items at random with probability proportional to their non-negative integer weights.
    /// <see cref="GetNext"/> may be called concurrently from multiple worker threads.
    /// </summary>
    class WeightedRandomGenerator
    {
        // System.Random is not thread-safe and every worker calls GetNext() concurrently,
        // so all access to it is serialized through this lock.
        readonly object randomLock = new object();
        readonly Random random = new Random();
        // Each entry is (running total of weights up to and including this item, item).
        readonly List<Tuple<int, string>> cumulativeSums;
        readonly int weightSum;

        public WeightedRandomGenerator(Dictionary<string, int> weightedItems)
        {
            cumulativeSums = new List<Tuple<int, string>>();
            weightSum = 0;
            foreach (var entry in weightedItems)
            {
                GrpcPreconditions.CheckArgument(entry.Value >= 0, "Test case weights must be non-negative.");
                weightSum = checked(weightSum + entry.Value);
                cumulativeSums.Add(Tuple.Create(weightSum, entry.Key));
            }
            // With a zero total, Random.Next(0) always returns 0 and GetNext() could never select anything.
            GrpcPreconditions.CheckArgument(weightSum > 0, "At least one test case must have a positive weight.");
        }

        /// <summary>Returns a randomly chosen item, weighted by the configured weights.</summary>
        public string GetNext()
        {
            int rand;
            lock (randomLock)
            {
                rand = random.Next(weightSum);
            }
            foreach (var entry in cumulativeSums)
            {
                if (rand < entry.Item1)
                {
                    return entry.Item2;
                }
            }
            // Unreachable while weightSum > 0: rand < weightSum, which equals the last cumulative sum.
            throw new InvalidOperationException("GetNext() failed.");
        }
    }

    /// <summary>
    /// Metrics service exposing a single gauge with the QPS observed since the previous query.
    /// </summary>
    class MetricsServiceImpl : MetricsService.MetricsServiceBase
    {
        const string GaugeName = "csharp_overall_qps";
        readonly Histogram histogram;
        readonly TimeStats timeStats = new TimeStats();

        public MetricsServiceImpl(Histogram histogram)
        {
            this.histogram = histogram;
        }

        /// <summary>Returns the QPS gauge; any other gauge name fails with InvalidArgument.</summary>
        public override Task<GaugeResponse> GetGauge(GaugeRequest request, ServerCallContext context)
        {
            if (request.Name == GaugeName)
            {
                long qps = GetQpsAndReset();
                return Task.FromResult(new GaugeResponse
                {
                    Name = GaugeName,
                    LongValue = qps
                });
            }
            throw new RpcException(new Status(StatusCode.InvalidArgument, "Gauge does not exist"));
        }

        /// <summary>Streams the single QPS gauge.</summary>
        public override async Task GetAllGauges(EmptyMessage request, IServerStreamWriter<GaugeResponse> responseStream, ServerCallContext context)
        {
            long qps = GetQpsAndReset();
            var response = new GaugeResponse
            {
                Name = GaugeName,
                LongValue = qps
            };
            await responseStream.WriteAsync(response);
        }

        /// <summary>
        /// Computes completed test cases per second since the last call and resets both the
        /// histogram and the wall-clock statistics, so each query reports a fresh interval.
        /// </summary>
        long GetQpsAndReset()
        {
            var snapshot = histogram.GetSnapshot(true);
            var timeSnapshot = timeStats.GetSnapshot(true);
            double elapsedSeconds = timeSnapshot.WallClockTime.TotalSeconds;
            if (elapsedSeconds <= 0)
            {
                // Dividing by a zero-length interval gives infinity/NaN, whose cast to long is unspecified.
                return 0;
            }
            return (long) (snapshot.Count / elapsedSeconds);
        }
    }
}
}