Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions core/internal/server/network/captive_portal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package network

import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"

"github.com/AvengeMedia/DankMaterialShell/core/internal/log"
)

const (
portalProbeURL = "http://nmcheck.gnome.org/check_network_status.txt"
portalProbeExpect = "NetworkManager is online"
portalProbeTimeout = 5 * time.Second
portalProbeInterval = 30 * time.Second
portalProbeMaxBody = 4096
portalFullProbeTicks = 10 // re-probe a healthy connection every ~5min, not every tick
)

// portalProbe checks a known endpoint to spot a captive portal: a redirect or an
// unexpected 200 body means traffic is being intercepted.
type portalProbe struct {
mgr *Manager
client *http.Client
url string
trigger chan struct{}
stopChan chan struct{}
wg sync.WaitGroup
lastKey string
fullTicks int
}

func newPortalProbe(m *Manager) *portalProbe {
probeURL := portalProbeURL
if v := os.Getenv("DMS_CAPTIVE_PROBE_URL"); v != "" {
probeURL = v
}
return &portalProbe{
mgr: m,
url: probeURL,
trigger: make(chan struct{}, 1),
stopChan: make(chan struct{}),
client: &http.Client{
Timeout: portalProbeTimeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
},
}
}

func (p *portalProbe) start() {
p.wg.Add(1)
go p.run()
p.kick()
}

func (p *portalProbe) stop() {
close(p.stopChan)
p.wg.Wait()
}

func (p *portalProbe) kick() {
select {
case p.trigger <- struct{}{}:
default:
}
}

func (p *portalProbe) run() {
defer p.wg.Done()
ticker := time.NewTicker(portalProbeInterval)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a fan of using a ticker here, better to rewrite polling to be event-driven. I'd re-work that as a timer thats conditionalled armed, cancelled when sign in is "full" . We can probably subscribe to NM PropertiesChanged, or something like that to catch portal changes.

defer ticker.Stop()
for {
select {
case <-p.stopChan:
return
case <-p.trigger:
p.probe(false)
case <-ticker.C:
p.probe(true)
}
}
}

func (p *portalProbe) probe(periodic bool) {
m := p.mgr
m.stateMutex.RLock()
connected := m.state.WiFiConnected || m.state.EthernetConnected
curConn := m.state.Connectivity
key := fmt.Sprintf("%v|%s|%s", connected, m.state.WiFiSSID, m.state.EthernetIP)
m.stateMutex.RUnlock()

if !connected {
p.lastKey = key
p.set(ConnectivityNone, "")
return
}

if periodic {
if curConn == ConnectivityFull {
p.fullTicks++
if p.fullTicks < portalFullProbeTicks {
return
}
p.fullTicks = 0
}
} else if key == p.lastKey {
return
}
p.lastKey = key

conn, loc := p.check()
p.set(conn, loc)
}

func (p *portalProbe) check() (Connectivity, string) {
req, err := http.NewRequest(http.MethodGet, p.url, nil)
if err != nil {
return ConnectivityUnknown, ""
}
req.Header.Set("User-Agent", "DankMaterialShell/captive-portal-check")

resp, err := p.client.Do(req)
if err != nil {
return ConnectivityNone, ""
}
defer resp.Body.Close()

if resp.StatusCode >= 300 && resp.StatusCode < 400 {
return ConnectivityPortal, p.resolveLocation(resp.Header.Get("Location"))
}
if resp.StatusCode == http.StatusNoContent {
return ConnectivityFull, ""
}
// server/infra errors are not a portal
if resp.StatusCode >= 400 {
return ConnectivityUnknown, ""
}

body, _ := io.ReadAll(io.LimitReader(resp.Body, portalProbeMaxBody))
if resp.StatusCode == http.StatusOK && strings.Contains(string(body), portalProbeExpect) {
return ConnectivityFull, ""
}

return ConnectivityPortal, p.url
}

// resolveLocation turns a possibly-relative redirect target into an absolute url.
func (p *portalProbe) resolveLocation(loc string) string {
if loc == "" {
return p.url
}
base, err := url.Parse(p.url)
if err != nil {
return loc
}
ref, err := url.Parse(loc)
if err != nil {
return loc
}
return base.ResolveReference(ref).String()
}

func (p *portalProbe) set(conn Connectivity, loc string) {
m := p.mgr
m.stateMutex.Lock()
changed := m.state.Connectivity != conn || m.state.PortalURL != loc
m.state.Connectivity = conn
m.state.PortalURL = loc
m.stateMutex.Unlock()

if !changed {
return
}
if conn == ConnectivityPortal {
log.Infof("[captive-portal] portal detected, login url: %s", loc)
}
m.notifySubscribers()
}
99 changes: 99 additions & 0 deletions core/internal/server/network/captive_portal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package network

import (
"net/http"
"net/http/httptest"
"testing"
)

func TestPortalProbeCheck(t *testing.T) {
tests := []struct {
name string
handler http.HandlerFunc
wantConn Connectivity
wantURL func(srv string) string
}{
{
name: "online when expected body returned",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(portalProbeExpect + "\n"))
},
wantConn: ConnectivityFull,
wantURL: func(string) string { return "" },
},
{
name: "portal on redirect, url from location",
handler: func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://portal.example/login", http.StatusFound)
},
wantConn: ConnectivityPortal,
wantURL: func(string) string { return "http://portal.example/login" },
},
{
name: "portal on unexpected 200 body",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("<html>please sign in</html>"))
},
wantConn: ConnectivityPortal,
wantURL: func(srv string) string { return srv },
},
{
name: "relative redirect location resolved to absolute",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", "/login")
w.WriteHeader(http.StatusFound)
},
wantConn: ConnectivityPortal,
wantURL: func(srv string) string { return srv + "/login" },
},
{
name: "204 no content means online",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
},
wantConn: ConnectivityFull,
wantURL: func(string) string { return "" },
},
{
name: "server error is not a portal",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
},
wantConn: ConnectivityUnknown,
wantURL: func(string) string { return "" },
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := httptest.NewServer(tt.handler)
defer srv.Close()

p := newPortalProbe(nil)
p.url = srv.URL

conn, url := p.check()
if conn != tt.wantConn {
t.Errorf("connectivity = %q, want %q", conn, tt.wantConn)
}
if want := tt.wantURL(srv.URL); url != want {
t.Errorf("url = %q, want %q", url, want)
}
})
}
}

func TestPortalProbeCheckUnreachable(t *testing.T) {
p := newPortalProbe(nil)
p.url = "http://127.0.0.1:1"

conn, url := p.check()
if conn != ConnectivityNone {
t.Errorf("connectivity = %q, want %q", conn, ConnectivityNone)
}
if url != "" {
t.Errorf("url = %q, want empty", url)
}
}
17 changes: 17 additions & 0 deletions core/internal/server/network/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func NewManager() (*Manager, error) {
backend: backend,
state: &NetworkState{
NetworkStatus: StatusDisconnected,
Connectivity: ConnectivityUnknown,
Preference: PreferenceAuto,
WiFiNetworks: []WiFiNetwork{},
SavedWiFiNetworks: []WiFiNetwork{},
Expand Down Expand Up @@ -96,6 +97,9 @@ func NewManager() (*Manager, error) {
return nil, fmt.Errorf("failed to start monitoring: %w", err)
}

m.portalProbe = newPortalProbe(m)
m.portalProbe.start()

return m, nil
}

Expand Down Expand Up @@ -139,6 +143,9 @@ func (m *Manager) onBackendStateChange() {
if err := m.syncStateFromBackend(); err != nil {
log.Errorf("failed to sync state from backend: %v", err)
}
if m.portalProbe != nil {
m.portalProbe.kick()
}
m.notifySubscribers()
}

Expand Down Expand Up @@ -171,6 +178,12 @@ func stateChangedMeaningfully(old, new *NetworkState) bool {
if old.NetworkStatus != new.NetworkStatus {
return true
}
if old.Connectivity != new.Connectivity {
return true
}
if old.PortalURL != new.PortalURL {
return true
}
if old.Preference != new.Preference {
return true
}
Expand Down Expand Up @@ -420,6 +433,10 @@ func (m *Manager) GetPromptBroker() PromptBroker {
}

func (m *Manager) Close() {
if m.portalProbe != nil {
m.portalProbe.stop()
}

close(m.stopChan)
m.notifierWg.Wait()

Expand Down
13 changes: 13 additions & 0 deletions core/internal/server/network/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ const (
PreferenceEthernet ConnectionPreference = "ethernet"
)

type Connectivity string

const (
ConnectivityUnknown Connectivity = "unknown"
ConnectivityNone Connectivity = "none"
ConnectivityPortal Connectivity = "portal"
ConnectivityLimited Connectivity = "limited"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesnt seem to do anything with the limited case, should we? I'd rather remove it or add a // TODO than leave the unused value

ConnectivityFull Connectivity = "full"
)

type WiFiNetwork struct {
SSID string `json:"ssid"`
BSSID string `json:"bssid"`
Expand Down Expand Up @@ -98,6 +108,8 @@ type VPNState struct {
type NetworkState struct {
Backend string `json:"backend"`
NetworkStatus NetworkStatus `json:"networkStatus"`
Connectivity Connectivity `json:"connectivity"`
PortalURL string `json:"portalURL"`
Preference ConnectionPreference `json:"preference"`
EthernetIP string `json:"ethernetIP"`
EthernetDevice string `json:"ethernetDevice"`
Expand Down Expand Up @@ -162,6 +174,7 @@ type Manager struct {
notifierWg sync.WaitGroup
lastNotifiedState *NetworkState
credentialSubscribers syncmap.Map[string, chan CredentialPrompt]
portalProbe *portalProbe
}

type EventType string
Expand Down
1 change: 1 addition & 0 deletions quickshell/Common/SettingsData.qml
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ Singleton {
property bool weatherEnabled: true

property string networkPreference: "auto"
property bool captivePortalAutoOpen: true

property string iconThemeDark: "System Default"
property string iconThemeLight: "System Default"
Expand Down
Loading
Loading